# Source code for neurodynex3.tools.plot_tools


# This file is part of the exercise code repository accompanying
# the book: Neuronal Dynamics (see http://neuronaldynamics.epfl.ch)
# located at http://github.com/EPFL-LCN/neuronaldynamics-exercises.

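# This free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License 2.0 as published by the
# Free Software Foundation. You should have received a copy of the
# GNU General Public License along with the repository. If not,
# see http://www.gnu.org/licenses/.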

# Should you reuse and publish the code for your own purposes,
# please cite the book or point to the webpage http://neuronaldynamics.epfl.ch.

# Wulfram Gerstner, Werner M. Kistler, Richard Naud, and Liam Paninski.
# Neuronal Dynamics: From Single Neurons to Networks and Models of Cognition.
# Cambridge University Press, 2014.

import matplotlib.pyplot as plt
import brian2 as b2
import numpy

def plot_voltage_and_current_traces(voltage_monitor, current, title=None, firing_threshold=None, legend_location=0):
    """Plots the injected current (top panel) and the membrane voltage (bottom panel).

    Both panels are drawn into the current pyplot figure (subplots 211 and 212) against the time
    points recorded by ``voltage_monitor``.

    Args:
        voltage_monitor (StateMonitor): recorded voltage. Only the trace ``voltage_monitor[0].v`` is plotted.
        current (TimedArray): injected current. It is sampled as ``current(voltage_monitor.t, 0)``.
        title (string, optional): suptitle of the figure
        firing_threshold (Quantity, optional): if set to a value, the firing threshold is plotted as a
            dashed red line in the voltage panel and a legend is added.
        legend_location (int): legend location. default = 0 (="best")

    Returns:
        tuple: ``(axis_c, axis_v)``, the axes of the current panel and of the voltage panel.

    Raises:
        AssertionError: if voltage_monitor is not a StateMonitor or current is not a TimedArray.
    """

    assert isinstance(voltage_monitor, b2.StateMonitor), "voltage_monitor is not of type StateMonitor"
    assert isinstance(current, b2.TimedArray), "current is not of type TimedArray"

    time_values_ms = voltage_monitor.t / b2.ms

    # top panel: input current, time-aligned with the voltage monitor
    axis_c = plt.subplot(211)
    # sample the TimedArray once and reuse the samples for the extrema and for the plot
    c = current(voltage_monitor.t, 0)
    max_current = max(c)
    min_current = min(c)
    margin = 1.05 * (max_current - min_current)
    plt.plot(time_values_ms, c, "r", lw=2)
    if margin > 0.:
        # a constant current has zero margin; in that case keep matplotlib's automatic limits
        plt.ylim((min_current - margin) / b2.amp, (max_current + margin) / b2.amp)
    plt.ylabel("Input current [A] \n min: {0} \nmax: {1}".format(min_current, max_current))
    plt.grid()

    # bottom panel: membrane voltage of the first recorded neuron
    axis_v = plt.subplot(212)
    voltage = voltage_monitor[0].v
    plt.plot(time_values_ms, voltage / b2.mV, lw=2)
    if firing_threshold is not None:
        threshold_mv = firing_threshold / b2.mV
        # horizontal line spanning from the first to the last recorded time point
        plt.plot(time_values_ms[[0, -1]], [threshold_mv, threshold_mv], "r--", lw=2)
    max_val = max(voltage)
    if firing_threshold is not None:
        # make sure the threshold line is inside the visible range
        max_val = max(max_val, firing_threshold)
    min_val = min(voltage)
    margin = 0.05 * (max_val - min_val)
    plt.ylim((min_val - margin) / b2.mV, (max_val + margin) / b2.mV)
    plt.xlabel("t [ms]")
    plt.ylabel("membrane voltage [mV]\n min: {0}\n max: {1}".format(min_val, max_val))
    plt.grid()

    if firing_threshold is not None:
        plt.legend(["vm", "firing threshold"], fontsize=12, loc=legend_location)

    if title is not None:
        plt.suptitle(title)
    return axis_c, axis_v

def plot_network_activity(rate_monitor, spike_monitor, voltage_monitor=None, spike_train_idx_list=None,
                          t_min=None, t_max=None, N_highlighted_spiketrains=3, avg_window_width=1.0 * b2.ms,
                          sup_title=None, figure_size=(10, 4)):
    """
    Visualizes the results of a network simulation: spike-train, population activity and voltage-traces.

    Args:
        rate_monitor (PopulationRateMonitor): rate of the population
        spike_monitor (SpikeMonitor): spike trains of individual neurons
        voltage_monitor (StateMonitor): optional. voltage traces of some (same as in spike_train_idx_list) neurons
        spike_train_idx_list (list): optional. A list of neuron indices whose spike-train is plotted.
            If None and a voltage_monitor is given, the (sorted) indices in voltage_monitor.record are used.
            If None and no voltage_monitor is given, all spike-trains in the spike_monitor are plotted.
            In both default cases at most the first 5000 indices are kept.
        t_min (Quantity): optional. lower bound of the plotted time interval.
            if t_min is None, it is set to the larger of [0ms, (t_max - 100ms)]
        t_max (Quantity): optional. upper bound of the plotted time interval.
            if t_max is None, it is set to the largest time stamp in rate_monitor.t
        N_highlighted_spiketrains (int): optional. Number of spike trains visually highlighted, defaults to 3
            If N_highlighted_spiketrains==0 and voltage_monitor is not None, then all voltage traces of
            the neurons in spike_train_idx_list are plotted. Otherwise N_highlighted_spiketrains voltage
            traces are plotted.
        avg_window_width (Quantity): optional. Before plotting the population rate (PopulationRateMonitor), the rate
            is smoothed using a flat window of width = avg_window_width. Default is 1.0ms
        sup_title (String): figure suptitle. Default is None.
        figure_size (tuple): (width,height) tuple passed to pyplot's figsize parameter.

    Returns:
        Figure: The whole figure
        Axes: Top panel, Raster plot
        Axes: Middle panel, population activity
        Axes: Bottom panel, voltage traces. None if no voltage monitor is provided.

    Raises:
        AssertionError: if one of the monitors or spike_train_idx_list has an unexpected type.
    """

    assert isinstance(rate_monitor, b2.PopulationRateMonitor), \
        "rate_monitor  is not of type PopulationRateMonitor"
    assert isinstance(spike_monitor, b2.SpikeMonitor), \
        "spike_monitor is not of type SpikeMonitor"
    assert (voltage_monitor is None) or (isinstance(voltage_monitor, b2.StateMonitor)), \
        "voltage_monitor is not of type StateMonitor"
    assert (spike_train_idx_list is None) or (isinstance(spike_train_idx_list, list)), \
        "spike_train_idx_list is not of type list"

    all_spike_trains = spike_monitor.spike_trains()
    if spike_train_idx_list is None:
        if voltage_monitor is not None:
            # if no index list is provided use the one from the voltage monitor
            spike_train_idx_list = numpy.sort(voltage_monitor.record)
        else:
            # no index list AND no voltage monitor: plot all spike trains.
            # list(...) is needed on Python 3: numpy.sort cannot sort a dict_keys view.
            spike_train_idx_list = numpy.sort(list(all_spike_trains.keys()))
        if len(spike_train_idx_list) > 5000:
            # avoid slow plotting of a large set
            print("Warning: raster plot with more than 5000 neurons truncated!")
            spike_train_idx_list = spike_train_idx_list[:5000]

    # get a reasonable default interval. From here on t_min and t_max are plain numbers in ms.
    if t_max is None:
        t_max = max(rate_monitor.t / b2.ms)
    else:
        t_max = t_max / b2.ms
    if t_min is None:
        t_min = max(0., t_max - 100.)  # if none, plot at most the last 100ms
    else:
        t_min = t_min / b2.ms

    ax_voltage = None  # stays None (and is returned as None) if there is no voltage monitor
    if voltage_monitor is None:
        fig, (ax_raster, ax_rate) = plt.subplots(2, 1, sharex=True, figsize=figure_size)
    else:
        fig, (ax_raster, ax_rate, ax_voltage) = plt.subplots(3, 1, sharex=True, figsize=figure_size)

    # nested helpers to plot the parts, note that they use parameters defined outside.
    def get_spike_train_ts_indices(spike_train):
        """
        Helper. Extracts the spikes within the time window from the spike train.
        Returns the boolean mask and the selected spike times in ms.
        """
        ts = spike_train / b2.ms
        idx_spikes = (ts >= t_min) & (ts <= t_max)
        ts_spikes = ts[idx_spikes]
        return idx_spikes, ts_spikes

    def plot_raster():
        """
        Helper. Plots the spike trains of the neurons in spike_train_idx_list, one row per neuron
        """
        for row, neuron_index in enumerate(spike_train_idx_list):
            _, ts_spikes = get_spike_train_ts_indices(all_spike_trains[neuron_index])
            ax_raster.scatter(ts_spikes, row * numpy.ones(ts_spikes.shape),
                              marker=".", c="k", s=15, lw=0)
        ax_raster.set_ylim([0, len(spike_train_idx_list)])

    def highlight_raster(neuron_idxs):
        """
        Helper. Highlights the given rows of the raster plot (first one in red) and labels the panel
        """
        for i, raster_plot_index in enumerate(neuron_idxs):
            color = "r" if i == 0 else "k"
            # raster_plot_index is a row of the raster plot, not a neuron index of the population
            population_index = spike_train_idx_list[raster_plot_index]
            _, ts_spikes = get_spike_train_ts_indices(all_spike_trains[population_index])
            ax_raster.axhline(y=raster_plot_index, linewidth=.5, linestyle="-", color=[.9, .9, .9])
            ax_raster.scatter(
                ts_spikes, raster_plot_index * numpy.ones(ts_spikes.shape),
                marker=".", c=color, s=144, lw=0)
        ax_raster.set_ylabel("neuron #")
        ax_raster.set_title("Raster Plot", fontsize=10)

    def plot_population_activity(window_width=0.5 * b2.ms):
        """
        Helper. Plots the smoothed population rate within the time window
        """
        ts = rate_monitor.t / b2.ms
        idx_rate = (ts >= t_min) & (ts <= t_max)
        smoothed_rates = rate_monitor.smooth_rate(window="flat", width=window_width) / b2.Hz
        ax_rate.plot(ts[idx_rate], smoothed_rates[idx_rate])
        ax_rate.set_ylabel("A(t) [Hz]")
        ax_rate.set_title("Population Activity", fontsize=10)

    def plot_voltage_traces(voltage_traces_i):
        """
        Helper. Plots the voltage traces of the given raster rows (first one in red, others in gray)
        """
        ts = voltage_monitor.t / b2.ms
        idx_voltage = (ts >= t_min) & (ts <= t_max)
        for i, raster_plot_index in enumerate(voltage_traces_i):
            color = "r" if i == 0 else ".7"
            population_index = spike_train_idx_list[raster_plot_index]
            ax_voltage.plot(
                ts[idx_voltage], voltage_monitor[population_index].v[idx_voltage] / b2.mV,
                c=color, lw=1.)
        ax_voltage.set_ylabel("V(t) [mV]")
        ax_voltage.set_title("Voltage Traces", fontsize=10)

    plot_raster()
    plot_population_activity(avg_window_width)
    nr_neurons = len(spike_train_idx_list)
    highlighted_neurons_i = []  # default to an empty list.
    # without any neuron there is no row to highlight (index 0 would be out of range)
    if N_highlighted_spiketrains > 0 and nr_neurons > 0:
        # spread the highlighted rows evenly over the raster, excluding both ends
        fract = numpy.linspace(0, 1, N_highlighted_spiketrains + 2)[1:-1]
        highlighted_neurons_i = [int(nr_neurons * v) for v in fract]
    # always called: it also sets the label and title of the raster panel
    highlight_raster(highlighted_neurons_i)

    if voltage_monitor is not None:
        if N_highlighted_spiketrains == 0:
            traces_i = range(nr_neurons)
        else:
            traces_i = highlighted_neurons_i
        plot_voltage_traces(traces_i)

    plt.xlabel("t [ms]")

    if sup_title is not None:
        plt.suptitle(sup_title)

    return fig, ax_raster, ax_rate, ax_voltage

def plot_ISI_distribution(spike_stats, hist_nr_bins=50, xlim_max_ISI=None):
    """
    Displays the inter-spike-interval (ISI) distribution stored in spike_stats as a histogram.

    Args:
        spike_stats (neurodynex3.tools.spike_tools.PopulationSpikeStats): statistics of a population activity
        hist_nr_bins (int): Number of histogram bins. Default:50
        xlim_max_ISI (Quantity): Default: None. If not None, only ISIs strictly below xlim_max_ISI enter the
            histogram and the x-axis is set to [0, xlim_max_ISI]. The CV shown in the title is taken from
            spike_stats and does not change if this bound is set.

    Returns:
        the figure
    """
    from neurodynex3.tools import spike_tools
    assert isinstance(spike_stats, spike_tools.PopulationSpikeStats), \
        "spike_stats is not of type spike_tools.PopulationSpikeStats"

    intervals_ms = spike_stats.all_ISI / b2.ms
    clip_x_axis = xlim_max_ISI is not None
    if clip_x_axis:
        upper_bound_ms = xlim_max_ISI / b2.ms
        # drop the long intervals so that all bins fall inside the visible range
        intervals_ms = intervals_ms[intervals_ms < upper_bound_ms]

    figure = plt.figure()
    plt.hist(intervals_ms, bins=hist_nr_bins)
    if clip_x_axis:
        plt.xlim([0, upper_bound_ms])
    cv_rounded = round(spike_stats.CV, 3)
    plt.title("ISI histogram, CV={}".format(cv_rounded))
    plt.xlabel("ISI [ms]")
    return figure

def plot_spike_train_power_spectrum(freq, mean_ps, all_ps, max_freq,
                                    nr_highlighted_neurons=2, mean_firing_freqs_per_neuron=None, plot_f0=False):
    """
    Visualizes the power spectrum of the spike trains.

    Args:
        freq: frequencies (= x axis)
        mean_ps: average power taken over all neurons (typically all of a subsample).
        all_ps (dict): power spectra for each single neuron
        max_freq (Quantity): The x-lim of the plot is [-0.05*max_freq, max_freq]
        nr_highlighted_neurons (int): number of randomly drawn neurons (with replacement) whose individual
            power spectrum is plotted. The first one is drawn in red, all others in gray. Default is 2.
            If all_ps is empty, no single neuron spectrum is plotted.
        mean_firing_freqs_per_neuron (dict): None or the mean firing rate of each neuron, using the same keys
            as all_ps. Default is None in which case the rates are not shown in the legend
        plot_f0 (bool): if true, the power at frequency 0 is plotted. Default is False and the value is not plotted.

    Returns:
        the figure and the list of keys of the randomly selected neurons (all_ps[key] are the plotted spectra)
    """
    # materialize the keys once: dict views cannot be indexed on Python 3
    neuron_keys = list(all_ps.keys())
    nr_neurons = len(neuron_keys)
    f = plt.figure()
    color = "r"
    msize = 10
    legend_text = []
    random_neuron_index = []

    first_idx_to_plot = 0 if plot_f0 else 1
    # numpy.random.randint(0) raises, so skip the highlighting if there is no neuron to pick from
    nr_to_highlight = nr_highlighted_neurons if nr_neurons > 0 else 0
    for _ in range(nr_to_highlight):
        rand_key = neuron_keys[numpy.random.randint(nr_neurons)]
        rand_neuron_ps = all_ps[rand_key]
        plt.plot(freq[first_idx_to_plot:], rand_neuron_ps[first_idx_to_plot:],
                 marker=".", linestyle=" ", markersize=msize, c=color)
        color = [.75, .75, .75]  # print the first neuron in red and all others in gray
        msize = 8
        random_neuron_index.append(rand_key)
        if mean_firing_freqs_per_neuron is None:
            legend_text.append("PS Neuron #{}".format(rand_key))
        else:
            legend_text.append("PS Neuron #{}, avg rate={}"
                               .format(rand_key, round(mean_firing_freqs_per_neuron[rand_key], 1)))

    plt.plot(freq[first_idx_to_plot:], mean_ps[first_idx_to_plot:], ".b")

    if mean_firing_freqs_per_neuron is None:
        legend_text.append("averaged PS")
    else:
        # list(...) is needed on Python 3: numpy.mean cannot average a dict_values view
        avg_rate = numpy.mean(list(mean_firing_freqs_per_neuron.values()))
        legend_text.append("averaged PS, avg rate={}".format(round(avg_rate, 1)))

    plt.legend(legend_text)
    plt.xlim([-0.05*max_freq/b2.Hz, max_freq/b2.Hz])
    plt.axvline(x=0., lw=1, color="k")
    plt.grid()
    plt.xlabel("Frequency [Hz]")
    plt.ylabel("Power")
    plt.title("Single neuron power spectrum and average")
    return f, random_neuron_index

def plot_population_activity_power_spectrum(freq, ps, max_freq, average_At=None, plot_f0=False):
    """
    Draws the power spectrum of the population activity A(t) into a new figure.

    Args:
        freq: frequencies (= x axis)
        ps: power spectrum of the population activity
        max_freq (Quantity): The x-axis is limited to the interval [-.05*max_freq, max_freq]
        average_At: optional. If not None, the value (rounded to one decimal) is appended to the
            title as the average rate. Default is None.
        plot_f0 (bool): if true, the power at frequency 0 is plotted. Default is False and the value is not plotted.

    Returns:
        the figure
    """
    # index 0 holds the power at frequency 0, skip it unless explicitly requested
    start = 0 if plot_f0 else 1
    figure = plt.figure()
    plt.plot(freq[start:], ps[start:], ".b")
    plt.axvline(x=0., lw=1, color="k")
    x_limits = [-.05*max_freq/b2.Hz, max_freq/b2.Hz]
    plt.xlim(x_limits)
    plt.grid()
    plt.xlabel("Frequency [Hz]")
    plt.ylabel("Power")

    if average_At is not None:
        heading = "Power Spectrum of population activity A(t). Avg. rate = {}".format(round(average_At, 1))
    else:
        heading = "Power Spectrum of population activity A(t)."
    plt.title(heading)
    return figure