"""
The spike_tools submodule provides functions to analyse the
Brian2 SpikeMonitors and Brian2 StateMonitors. The code provided here is
not optimized for performance and there is no guarantee for correctness.
Relevant book chapters:
- http://neuronaldynamics.epfl.ch/online/Ch19.S2.html#SS1.p6
"""
# This file is part of the exercise code repository accompanying
# the book: Neuronal Dynamics (see http://neuronaldynamics.epfl.ch)
# located at http://github.com/EPFL-LCN/neuronaldynamics-exercises.
# This is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License 2.0 as published by the
# Free Software Foundation. You should have received a copy of the
# GNU General Public License along with the repository. If not,
# see http://www.gnu.org/licenses/.
# Should you reuse and publish the code for your own purposes,
# please cite the book or point to the webpage http://neuronaldynamics.epfl.ch.
# Wulfram Gerstner, Werner M. Kistler, Richard Naud, and Liam Paninski.
# Neuronal Dynamics: From Single Neurons to Networks and Models of Cognition.
# Cambridge University Press, 2014.
import brian2 as b2
import numpy as np
import math
def get_spike_time(voltage_monitor, spike_threshold):
    """
    Detect the upward threshold crossings in a recorded voltage trace.

    A spike time is DEFINED as the entry voltage_monitor.t[idx] for which
    voltage_monitor.v[0][idx] is above spike_threshold while the preceding sample
    voltage_monitor.v[0][idx-1] is not (crossing from below).

    Note: only the first trace of the monitor (voltage_monitor.v[0]) is analysed.
    Matrix-like monitors are not supported.

    Args:
        voltage_monitor (StateMonitor): A state monitor with at least the fields "v" and "t"
        spike_threshold (Quantity): The spike threshold voltage. e.g. -50*b2.mV

    Returns:
        The spike times: the entries of voltage_monitor.t at the upward crossings.
    """
    assert isinstance(voltage_monitor, b2.StateMonitor), "voltage_monitor is not of type StateMonitor"
    assert b2.units.fundamentalunits.have_same_dimensions(spike_threshold, b2.volt),\
        "spike_threshold must be a voltage. e.g. brian2.mV"
    # 1 where the trace is above threshold, 0 elsewhere
    is_above = np.asarray(voltage_monitor.v[0] > spike_threshold, dtype=int)
    # a positive step between two consecutive samples marks a crossing from below
    rising_edges = np.diff(is_above) > 0
    # diff[k] compares samples k and k+1, hence the mask lines up with t[1:]
    return (voltage_monitor.t[1:])[rising_edges]
def get_spike_stats(voltage_monitor, spike_threshold):
    """
    Detect spike times and derive the ISIs, their mean and variance, and the firing rate.

    Spike times are obtained from get_spike_time, i.e. a spike time is the value in
    voltage_monitor.t at which the voltage crosses spike_threshold from below.

    Note: mean_isi, spike_rate and var_isi are set to numpy.nan (with units) if fewer
    than two spikes are detected.
    Note: only the first trace of voltage_monitor is analysed. Matrix-like
    monitors are not supported.

    Args:
        voltage_monitor (StateMonitor): A state monitor with at least the fields "v" and "t"
        spike_threshold (Quantity): The spike threshold voltage. e.g. -50*b2.mV

    Returns:
        tuple: (nr_of_spikes, spike_times, isi, mean_isi, spike_rate, var_isi)
    """
    spike_times = get_spike_time(voltage_monitor, spike_threshold)
    nr_of_spikes = len(spike_times)
    isi = np.diff(spike_times)
    if nr_of_spikes >= 2:
        mean_isi = np.mean(isi)
        var_isi = np.var(isi)
        spike_rate = 1. / mean_isi
    else:
        # not enough spikes to form an interval: report nan, but keep the units
        mean_isi = np.nan * b2.ms
        var_isi = np.nan * (b2.ms ** 2)
        spike_rate = np.nan * b2.Hz
    return nr_of_spikes, spike_times, isi, mean_isi, spike_rate, var_isi
def pretty_print_spike_train_stats(voltage_monitor, spike_threshold):
    """
    Compute the statistics of get_spike_stats and print them to the console.

    Spike times and ISIs are only printed in full if at most 20 spikes were detected.

    Args:
        voltage_monitor: forwarded to get_spike_stats
        spike_threshold: forwarded to get_spike_stats

    Returns:
        tuple: (spike_times, isi, mean_isi, spike_rate, var_isi). Note that, unlike
        get_spike_stats, the number of spikes is NOT part of the returned tuple.
    """
    stats = get_spike_stats(voltage_monitor, spike_threshold)
    nr_of_spikes, spike_times, isi, mean_isi, spike_freq, var_isi = stats

    print("nr of spikes: {}".format(nr_of_spikes))
    print("mean ISI: {}".format(mean_isi))
    print("ISI variance: {}".format(var_isi))
    print("spike freq: {}".format(spike_freq))

    if nr_of_spikes <= 20:
        print("spike times: {}".format(spike_times))
        print("ISI: {}".format(isi))
    else:
        # long spike trains would flood the console
        print("spike times: too many values")
        print("ISI: too many values")
    return spike_times, isi, mean_isi, spike_freq, var_isi
class PopulationSpikeStats:
    """
    Read-only container for a few spike-train related properties of a population.
    """

    def __init__(self, nr_neurons, nr_spikes, all_ISI, filtered_spike_trains):
        """
        Args:
            nr_neurons: number of neurons in the original population
            nr_spikes: number of spikes
            all_ISI: list of ISI values (can be used to plot a histogram)
            filtered_spike_trains: the spike trains used to compute the stats. It's a
                time-window filtered copy of the original spike_monitor.all_spike_trains.

        Returns:
            An instance of PopulationSpikeStats
        """
        self._filtered_spike_trains = filtered_spike_trains
        self._all_ISI = all_ISI
        self._nr_spikes = nr_spikes
        self._nr_neurons = nr_neurons

    @property
    def nr_neurons(self):
        """
        Number of neurons in the original population
        """
        return self._nr_neurons

    @property
    def nr_spikes(self):
        """
        Nr of spikes
        """
        return self._nr_spikes

    @property
    def filtered_spike_trains(self):
        """
        a time-window filtered copy of the original spike_monitor.all_spike_trains
        """
        return self._filtered_spike_trains

    @property
    def mean_isi(self):
        """
        Mean Inter Spike Interval
        """
        # NOTE(review): the mean is multiplied by b2.second regardless of whether
        # all_ISI already carries units - confirm the intended unit handling.
        return np.mean(self._all_ISI)*b2.second

    @property
    def std_isi(self):
        """
        Standard deviation of the ISI
        """
        # NOTE(review): same unconditional scaling by b2.second as in mean_isi.
        return np.std(self._all_ISI)*b2.second

    @property
    def all_ISI(self):
        """
        all ISIs in no specific order
        """
        return self._all_ISI

    @property
    def CV(self):
        """
        Coefficient of Variation (std_isi / mean_isi); numpy.nan unless mean_isi > 0
        """
        return self.std_isi / self.mean_isi if self.mean_isi > 0. else np.nan
def filter_spike_trains(spike_trains, window_t_min=0.*b2.ms, window_t_max=None, idx_subset=None):
    """
    creates a new dictionary neuron_idx=>spike_times where all spike_times are in the
    half open interval [window_t_min,window_t_max)

    Args:
        spike_trains (dict): a dictionary of spike trains. Typically obtained by
            calling spike_monitor.spike_trains()
        window_t_min (Quantity): Lower bound of the time window: t>=window_t_min. Default is 0ms.
        window_t_max (Quantity): Upper bound of the time window: t<window_t_max.
            Default is None, in which case no upper bound is set.
        idx_subset (iterable, optional): neuron indexes (dict keys) specifying a subset of neurons.
            Neurons NOT in idx_subset are NOT added to the resulting dictionary. Default is None,
            in which case all neurons are added to the resulting dictionary.

    Returns:
        a filtered copy of spike_trains (the arrays of the input dictionary are not modified)

    Raises:
        KeyError: if idx_subset contains a key that is not in spike_trains
    """
    assert isinstance(spike_trains, dict), \
        "spike_trains is not of type dict"
    # materialise the keys once so that any iterable (also a one-shot iterator) is accepted
    neuron_ids = list(spike_trains) if idx_subset is None else list(idx_subset)

    if (window_t_max is None) and (window_t_min == 0.*b2.ms):
        # default window: nothing to filter, hand out plain copies
        return {k: spike_trains[k].copy() for k in neuron_ids}

    filtered_spike_trains = dict()
    for k in neuron_ids:
        train = spike_trains[k]
        in_window = (train >= window_t_min)
        if window_t_max is not None:
            in_window = in_window & (train < window_t_max)
        # boolean-mask indexing yields a new array, so the input stays untouched
        filtered_spike_trains[k] = train[in_window]
    return filtered_spike_trains
def get_spike_train_stats(spike_monitor, window_t_min=0.*b2.ms, window_t_max=None):
    """
    Analyses the spike monitor and returns a PopulationSpikeStats instance.

    Args:
        spike_monitor (SpikeMonitor): Brian2 spike monitor
        window_t_min (Quantity): Lower bound of the time window: t>=window_t_min. The ISIs are computed
            for spikes within the time window. Default is 0ms
        window_t_max (Quantity): Upper bound of the time window: t<window_t_max. The ISIs are computed
            for spikes within the time window. Default is None, in which case no upper bound is set.

    Returns:
        PopulationSpikeStats: its nr_spikes is taken from spike_monitor.num_spikes (it is not
        recomputed from the filtered spike trains); all_ISI holds the ISIs of every neuron
        that has at least two spikes inside the window.
    """
    assert isinstance(spike_monitor, b2.SpikeMonitor), \
        "spike_monitor is not of type SpikeMonitor"
    filtered_spike_trains = filter_spike_trains(spike_monitor.spike_trains(), window_t_min, window_t_max)
    nr_neurons = len(filtered_spike_trains)
    # Collect the per-neuron ISIs and stack them once at the end instead of growing the
    # array inside the loop. The leading empty list keeps the result a float array and
    # covers the case where no neuron contributes any ISI.
    isi_chunks = [[]]
    for i in range(nr_neurons):
        spike_times = filtered_spike_trains[i]/b2.ms
        if len(spike_times) >= 2:
            isi_chunks.append(np.diff(spike_times))
    all_ISI = np.hstack(isi_chunks)*b2.ms
    stats = PopulationSpikeStats(nr_neurons, spike_monitor.num_spikes, all_ISI, filtered_spike_trains)
    return stats
def _spike_train_2_binary_vector(spike_train, vector_length, discretization_dt):
"""
Convert the time-stamps of the spike_train into a binary vector of the given length.
Note: if more than one spike fall into the same time bin, only one is counted, surplus spikes are ignored.
Args:
spike_train:
vector_length:
discretization_dt:
Returns:
Discretized spike train: a fixed-length, binary vector.
"""
vec = np.zeros(vector_length, int)
idx = spike_train / discretization_dt
idx = np.floor(idx).astype(int)
vec[idx] = 1
return vec
def _get_spike_train_power_spectrum(spike_train, delta_t, subtract_mean=False):
    """
    Compute the power spectrum of a (discretized) spike train.

    The signal is divided by b2.ms, optionally mean-subtracted, Fourier transformed,
    and the squared magnitude is scaled by delta_t / N, where N is the signal length.
    Only the first N // 2 entries of the spectrum and of the frequency axis are returned.

    Args:
        spike_train: the signal to analyse
        delta_t: sample spacing of the signal
        subtract_mean (bool): If true, the mean value of the signal is subtracted before
            the FFT. Default is False

    Returns:
        tuple: (ps, freqs), the power spectrum and the matching frequencies
    """
    st = spike_train/b2.ms
    data = st-np.mean(st) if subtract_mean else st
    N_signal = data.size
    ps = np.abs(np.fft.fft(data))**2
    # normalize
    ps = ps * delta_t / N_signal  # TODO: verify: subtract 1 (N_signal-1)?
    freqs = np.fft.fftfreq(N_signal, delta_t)
    # floor division: a slice bound must be an integer (true division yields a float)
    half = N_signal // 2
    return ps[:half], freqs[:half]
def get_averaged_single_neuron_power_spectrum(spike_monitor, sampling_frequency,
                                              window_t_min, window_t_max,
                                              nr_neurons_average=100, subtract_mean=False):
    """
    averaged power-spectrum of spike trains in the time window [window_t_min, window_t_max).

    The power spectrum of every single neuron's spike train is computed. Then the average
    across all single-neuron powers is computed. In order to limit the computation time, the
    number of neurons taken to compute the average is limited to nr_neurons_average which
    defaults to 100. If the population is larger, a random subset is drawn with numpy.random.

    The mean firing rate is also printed to the console.

    Args:
        spike_monitor (SpikeMonitor) : Brian2 SpikeMonitor
        sampling_frequency (Quantity): sampling frequency used to discretize the spike trains.
        window_t_min (Quantity): Lower bound of the time window: t>=window_t_min. Spikes
            before window_t_min are not taken into account (set a lower bound if you want to
            exclude an initial transient in the population activity)
        window_t_max (Quantity): Upper bound of the time window: t<window_t_max. If None,
            no upper bound is applied to the spikes and max(spike_monitor.t) is used as the
            end of the window.
        nr_neurons_average (int): Number of neurons over which the average is taken.
        subtract_mean (bool): If true, the mean value of the signal is subtracted before FFT.
            Default is False

    Returns:
        freq, mean_ps, all_ps_dict, mean_firing_rate, mean_firing_freqs_per_neuron
        where all_ps_dict and mean_firing_freqs_per_neuron are dictionaries keyed by the
        index of the sampled neuron.
    """
    assert isinstance(spike_monitor, b2.SpikeMonitor), \
        "spike_monitor is not of type SpikeMonitor"
    spiketrains = spike_monitor.spike_trains()
    nr_neurons = len(spiketrains)

    if nr_neurons <= nr_neurons_average:
        sample_neurons = range(nr_neurons)
        nr_samples = nr_neurons
    else:
        # population too large: analyse a random subset only
        idxs = np.arange(nr_neurons)
        np.random.shuffle(idxs)
        sample_neurons = idxs[:nr_neurons_average]
        nr_samples = nr_neurons_average

    sptrs = filter_spike_trains(spiketrains, window_t_min, window_t_max, sample_neurons)

    # resolve an open upper bound BEFORE it is used in any arithmetic
    if window_t_max is None:
        window_t_max = max(spike_monitor.t)
    time_window_size = window_t_max - window_t_min
    discretization_dt = 1./sampling_frequency
    vector_length = 1+int(math.ceil(time_window_size/discretization_dt))  # +1: space for rounding issues
    # floor division: an array dimension must be an integer
    half_length = vector_length // 2

    freq = 0
    spike_count = 0
    all_ps = np.zeros([nr_samples, half_length], float)
    all_ps_dict = dict()
    mean_firing_freqs_per_neuron = dict()
    for i in range(nr_samples):
        idx = sample_neurons[i]
        # shift the spike times so that the window starts at bin 0
        vec = _spike_train_2_binary_vector(
            sptrs[idx]-window_t_min, vector_length, discretization_dt=discretization_dt)
        ps, freq = _get_spike_train_power_spectrum(vec, discretization_dt, subtract_mean)
        all_ps[i, :] = ps
        all_ps_dict[idx] = ps
        nr_spikes = len(sptrs[idx])
        mean_firing_freqs_per_neuron[idx] = nr_spikes / time_window_size
        spike_count += nr_spikes  # count in the subsample which is filtered to [window_t_min, window_t_max]
    mean_ps = np.mean(all_ps, 0)
    mean_firing_rate = spike_count / nr_samples / time_window_size
    print("mean_firing_rate:{}".format(mean_firing_rate))
    return freq, mean_ps, all_ps_dict, mean_firing_rate, mean_firing_freqs_per_neuron
def get_population_activity_power_spectrum(
        rate_monitor, delta_f, k_repetitions, T_init=100*b2.ms, subtract_mean_activity=False):
    """
    Computes the power spectrum of the population activity A(t) (=rate_monitor.rate)

    Args:
        rate_monitor (RateMonitor): Brian2 rate monitor. rate_monitor.rate is the signal being
            analysed here. The temporal resolution is read from rate_monitor.clock.dt
        delta_f (Quantity): The desired frequency resolution.
        k_repetitions (int): The data rate_monitor.rate is split into k_repetitions which are FFT'd
            independently and then averaged in frequency domain.
        T_init (Quantity): The first int(T_init/dt) samples are removed before doing the
            Fourier transform. Use this parameter to ignore the initial transient signals
            of the simulation.
        subtract_mean_activity (bool): If true, the mean value of the signal is subtracted.
            Default is False

    Returns:
        freqs, ps, average_population_rate

    Raises:
        ValueError: if rate_monitor.rate holds fewer samples than required by
            k_repetitions, delta_f and T_init.
    """
    data = rate_monitor.rate/b2.Hz
    delta_t = rate_monitor.clock.dt
    f_max = 1./(2. * delta_t)
    # number of samples per repetition that yields the requested frequency resolution
    N_signal = int(2 * f_max / delta_f)
    N_init = int(T_init/delta_t)
    N_required = k_repetitions * N_signal + N_init
    N_data = len(data)
    if N_data < N_required:
        err_msg = "Inconsistent parameters. k_repetitions require {} samples." \
                  " rate_monitor.rate contains {} samples.".format(N_required, N_data)
        raise ValueError(err_msg)
    # drop the initial transient and any surplus samples at the end
    data = data[N_init:N_required]
    average_population_rate = np.mean(data)
    if subtract_mean_activity:
        data = data - average_population_rate
    average_population_rate *= b2.Hz
    data = data.reshape(k_repetitions, N_signal)  # reshape into one row per repetition (k)
    k_ps = np.abs(np.fft.fft(data))**2
    ps = np.mean(k_ps, 0)
    # normalize
    ps = ps * delta_t / N_signal  # TODO: verify: subtract 1 (N_signal-1)?
    freqs = np.fft.fftfreq(N_signal, delta_t)
    # floor division: a slice bound must be an integer (true division yields a float)
    half = N_signal // 2
    return freqs[:half], ps[:half], average_population_rate