bmtool 0.6.7__tar.gz → 0.6.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {bmtool-0.6.7 → bmtool-0.6.8}/PKG-INFO +7 -2
- {bmtool-0.6.7 → bmtool-0.6.8}/README.md +6 -1
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/SLURM.py +10 -14
- bmtool-0.6.8/bmtool/analysis/lfp.py +408 -0
- bmtool-0.6.8/bmtool/analysis/spikes.py +254 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/bmplot.py +170 -167
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/singlecell.py +5 -1
- bmtool-0.6.8/bmtool/util/neuron/__init__.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/PKG-INFO +7 -2
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/SOURCES.txt +3 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/setup.py +1 -1
- {bmtool-0.6.7 → bmtool-0.6.8}/LICENSE +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/__init__.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/__main__.py +0 -0
- {bmtool-0.6.7/bmtool/debug → bmtool-0.6.8/bmtool/analysis}/__init__.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/connectors.py +0 -0
- {bmtool-0.6.7/bmtool/util → bmtool-0.6.8/bmtool/debug}/__init__.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/debug/commands.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/debug/debug.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/graphs.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/manage.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/plot_commands.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/synapses.py +0 -0
- {bmtool-0.6.7/bmtool/util/neuron → bmtool-0.6.8/bmtool/util}/__init__.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/util/commands.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/util/neuron/celltuner.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/util/util.py +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/dependency_links.txt +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/entry_points.txt +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/requires.txt +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/top_level.txt +0 -0
- {bmtool-0.6.7 → bmtool-0.6.8}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: bmtool
|
3
|
-
Version: 0.6.7
|
3
|
+
Version: 0.6.8
|
4
4
|
Summary: BMTool
|
5
5
|
Home-page: https://github.com/cyneuro/bmtool
|
6
6
|
Download-URL:
|
@@ -53,6 +53,7 @@ A collection of modules to make developing [Neuron](https://www.neuron.yale.edu/
|
|
53
53
|
- [Synapses](#synapses-module)
|
54
54
|
- [Connectors](#connectors-module)
|
55
55
|
- [Bmplot](#bmplot-module)
|
56
|
+
- [Analysis](#analysis-module)
|
56
57
|
- [SLURM](#slurm-module)
|
57
58
|
- [Graphs](#graphs-module)
|
58
59
|
|
@@ -471,7 +472,11 @@ bmplot.plot_network_graph(config='config.json',sources='LA',targets='LA',tids='p
|
|
471
472
|
|
472
473
|
|
473
474
|

|
474
|
-
|
475
|
+
|
476
|
+
|
477
|
+
## Analysis Module
|
478
|
+
### A notebook example of how to use the spikes module can be found [here](examples/analysis/using_spikes.ipynb)
|
479
|
+
|
475
480
|
## SLURM Module
|
476
481
|
### This is an extremely helpful module that can simplify using SLURM to submit your models. There are also features that enable a seedSweep, which varies the parameters of the simulation and makes tuning the model easier. An example can be found [here](examples/SLURM/using_BlockRunner.ipynb)
|
477
482
|
|
@@ -10,6 +10,7 @@ A collection of modules to make developing [Neuron](https://www.neuron.yale.edu/
|
|
10
10
|
- [Synapses](#synapses-module)
|
11
11
|
- [Connectors](#connectors-module)
|
12
12
|
- [Bmplot](#bmplot-module)
|
13
|
+
- [Analysis](#analysis-module)
|
13
14
|
- [SLURM](#slurm-module)
|
14
15
|
- [Graphs](#graphs-module)
|
15
16
|
|
@@ -428,7 +429,11 @@ bmplot.plot_network_graph(config='config.json',sources='LA',targets='LA',tids='p
|
|
428
429
|
|
429
430
|
|
430
431
|

|
431
|
-
|
432
|
+
|
433
|
+
|
434
|
+
## Analysis Module
|
435
|
+
### A notebook example of how to use the spikes module can be found [here](examples/analysis/using_spikes.ipynb)
|
436
|
+
|
432
437
|
## SLURM Module
|
433
438
|
### This is an extremely helpful module that can simplify using SLURM to submit your models. There are also features that enable a seedSweep, which varies the parameters of the simulation and makes tuning the model easier. An example can be found [here](examples/SLURM/using_BlockRunner.ipynb)
|
434
439
|
|
@@ -353,17 +353,15 @@ class BlockRunner:
|
|
353
353
|
shutil.copytree(source_dir, destination_dir) # create new components folder
|
354
354
|
json_file_path = os.path.join(destination_dir,self.json_file_path)
|
355
355
|
|
356
|
-
# need to keep the orignal around
|
357
|
-
syn_dict_temp = copy.deepcopy(self.syn_dict)
|
358
|
-
print(self.syn_dict['json_file_path'])
|
359
|
-
json_to_be_ratioed = syn_dict_temp['json_file_path']
|
360
|
-
corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
|
361
|
-
syn_dict_temp['json_file_path'] = corrected_ratio_path
|
362
|
-
|
363
356
|
if self.syn_dict == None:
|
364
357
|
json_editor = seedSweep(json_file_path , self.param_name)
|
365
358
|
json_editor.edit_json(new_value)
|
366
359
|
else:
|
360
|
+
# need to keep the orignal around
|
361
|
+
syn_dict_temp = copy.deepcopy(self.syn_dict)
|
362
|
+
json_to_be_ratioed = syn_dict_temp['json_file_path']
|
363
|
+
corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
|
364
|
+
syn_dict_temp['json_file_path'] = corrected_ratio_path
|
367
365
|
json_editor = multiSeedSweep(json_file_path ,self.param_name,
|
368
366
|
syn_dict=syn_dict_temp,base_ratio=1)
|
369
367
|
json_editor.edit_all_jsons(new_value)
|
@@ -415,17 +413,15 @@ class BlockRunner:
|
|
415
413
|
shutil.copytree(source_dir, destination_dir) # create new components folder
|
416
414
|
json_file_path = os.path.join(destination_dir,self.json_file_path)
|
417
415
|
|
418
|
-
# need to keep the orignal around
|
419
|
-
syn_dict_temp = copy.deepcopy(self.syn_dict)
|
420
|
-
print(self.syn_dict['json_file_path'])
|
421
|
-
json_to_be_ratioed = syn_dict_temp['json_file_path']
|
422
|
-
corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
|
423
|
-
syn_dict_temp['json_file_path'] = corrected_ratio_path
|
424
|
-
|
425
416
|
if self.syn_dict == None:
|
426
417
|
json_editor = seedSweep(json_file_path , self.param_name)
|
427
418
|
json_editor.edit_json(new_value)
|
428
419
|
else:
|
420
|
+
# need to keep the orignal around
|
421
|
+
syn_dict_temp = copy.deepcopy(self.syn_dict)
|
422
|
+
json_to_be_ratioed = syn_dict_temp['json_file_path']
|
423
|
+
corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
|
424
|
+
syn_dict_temp['json_file_path'] = corrected_ratio_path
|
429
425
|
json_editor = multiSeedSweep(json_file_path ,self.param_name,
|
430
426
|
syn_dict_temp,base_ratio=1)
|
431
427
|
json_editor.edit_all_jsons(new_value)
|
@@ -0,0 +1,408 @@
|
|
1
|
+
"""
|
2
|
+
Module for processing BMTK LFP output.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import h5py
|
6
|
+
import numpy as np
|
7
|
+
import xarray as xr
|
8
|
+
from fooof import FOOOF
|
9
|
+
from fooof.sim.gen import gen_model
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
from scipy import signal
|
12
|
+
import pywt
|
13
|
+
from bmtool.bmplot import is_notebook
|
14
|
+
|
15
|
+
|
16
|
+
def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
    """
    Load ECP data from an HDF5 file (BMTK sim) into an xarray DataArray.

    Parameters:
    ----------
    ecp_file : str
        Path to the HDF5 file containing ECP data. The file is read from
        ``ecp/data`` (transposed so channels come first), ``ecp/channel_id``
        and ``ecp/time``, which holds ``[start, stop, step]`` in ms.
    demean : bool, optional
        If True, the per-channel mean over time is subtracted (default is False).

    Returns:
    -------
    xr.DataArray
        An xarray DataArray with dims ``(channel_id, time)``. ``time`` is in ms
        and ``attrs['fs']`` holds the sampling frequency in Hz.
    """
    with h5py.File(ecp_file, 'r') as f:
        data = f['ecp']['data'][()].T  # -> (channel_id, time)
        channel_ids = f['ecp']['channel_id'][()]
        start, _, step = f['ecp']['time'][()][:3]

    # Build the time axis from the actual number of samples. A floating-point
    # np.arange(start, stop, step) can produce one element more or fewer than
    # the data holds, which makes xarray reject the coordinate.
    time = start + step * np.arange(data.shape[-1])  # ms

    ecp = xr.DataArray(
        data,
        coords=dict(channel_id=channel_ids, time=time),
        dims=('channel_id', 'time'),
        attrs=dict(fs=1000 / step),  # Hz
    )
    if demean:
        ecp -= ecp.mean(dim='time')
    return ecp
|
47
|
+
|
48
|
+
|
49
|
+
def ecp_to_lfp(ecp_data: xr.DataArray, cutoff: float = 250, fs: float = 10000,
               downsample_freq: float = 1000) -> xr.DataArray:
    """
    Apply a low-pass Butterworth filter to an xarray DataArray and optionally downsample.
    This filters out the high end frequencies turning the ECP into a LFP.

    Parameters:
    ----------
    ecp_data : xr.DataArray
        The input data array containing ECP data with ``time`` as one dimension.
        Filtering is applied along the ``time`` axis, whatever its position.
    cutoff : float
        The cutoff frequency for the low-pass filter in Hz (default is 250Hz).
    fs : float, optional
        The sampling frequency of the data (default is 10000 Hz).
    downsample_freq : float, optional
        The frequency to downsample to (default is 1000 Hz). Pass None to skip
        downsampling. Decimation keeps every ``int(fs / downsample_freq)``-th
        sample, so the achieved rate is ``fs / int(fs / downsample_freq)``.

    Returns:
    -------
    xr.DataArray
        The filtered (and possibly downsampled) data. Input attrs are copied,
        and ``attrs['fs']`` is set to the sampling rate of the returned data.

    Raises:
    ------
    ValueError
        If ``downsample_freq`` exceeds ``fs`` (decimation factor would be 0).
    """
    # 8th-order low-pass Butterworth in second-order sections. The (b, a)
    # transfer-function form is numerically fragile at this order and low
    # normalized cutoff.
    sos = signal.butter(8, cutoff, btype='low', fs=fs, output='sos')

    # Zero-phase filtering of all channels at once along the time axis.
    time_axis = ecp_data.get_axis_num('time')
    filtered_values = signal.sosfiltfilt(sos, ecp_data.values, axis=time_axis)
    filtered_data = xr.DataArray(filtered_values, coords=ecp_data.coords,
                                 dims=ecp_data.dims, attrs=dict(ecp_data.attrs))
    filtered_data.attrs['fs'] = fs

    # Downsample the filtered data if a downsample frequency is provided.
    if downsample_freq is not None:
        downsample_factor = int(fs / downsample_freq)
        if downsample_factor < 1:
            raise ValueError(
                f"downsample_freq ({downsample_freq}) must not exceed fs ({fs}).")
        filtered_data = filtered_data.isel(time=slice(None, None, downsample_factor))
        # Record the rate actually produced by integer decimation. This equals
        # downsample_freq whenever fs is an integer multiple of it.
        filtered_data.attrs['fs'] = fs / downsample_factor

    return filtered_data
|
91
|
+
|
92
|
+
|
93
|
+
def slice_time_series(data: xr.DataArray, time_ranges: tuple) -> xr.DataArray:
    """
    Slice the xarray DataArray based on provided time ranges.
    Can be used to get LFP during certain stimulus times.

    Parameters:
    ----------
    data : xr.DataArray
        The input xarray DataArray containing time-series data.
    time_ranges : tuple or list of tuples
        One or more (start, stop) pairs used to select along ``time``.
        For example: (start, stop), [(start1, stop1), (start2, stop2)]
        or ((start1, stop1), (start2, stop2)).

    Returns:
    -------
    xr.DataArray
        The single slice, or the slices concatenated along ``time``.

    Raises:
    ------
    ValueError
        If ``time_ranges`` contains no ranges.
    """
    # Treat a lone (start, stop) pair as a one-element list of ranges. The
    # elements are inspected, not only the length, so a tuple holding two
    # ranges, e.g. ((0, 100), (200, 300)), is not mistaken for a single range.
    is_single_range = (
        isinstance(time_ranges, (tuple, list))
        and len(time_ranges) == 2
        and not any(isinstance(bound, (tuple, list, np.ndarray)) for bound in time_ranges)
    )
    if is_single_range:
        time_ranges = [time_ranges]

    slices = [data.sel(time=slice(start, stop)) for start, stop in time_ranges]
    if not slices:
        raise ValueError("time_ranges must contain at least one (start, stop) pair.")

    # Concatenate along the time dimension only when there are several pieces.
    return xr.concat(slices, dim='time') if len(slices) > 1 else slices[0]
|
128
|
+
|
129
|
+
|
130
|
+
def fit_fooof(f: np.ndarray, pxx: np.ndarray, aperiodic_mode: str = 'fixed',
              dB_threshold: float = 3.0, max_n_peaks: int = 10,
              freq_range: tuple = None, peak_width_limits: tuple = None,
              report: bool = False, plot: bool = False,
              plt_log: bool = False, plt_range: tuple = None,
              figsize: tuple = None, title: str = None) -> tuple:
    """
    Fit a FOOOF model to power spectral density data.

    Parameters:
    ----------
    f : array-like
        Frequencies corresponding to the power spectral density data.
    pxx : array-like
        Power spectral density data to fit.
    aperiodic_mode : str, optional
        'knee' to fit a knee; any other value is treated as 'fixed' (default 'fixed').
    dB_threshold : float, optional
        Minimum peak height in dB (default is 3). It is passed to FOOOF as
        ``min_peak_height = dB_threshold / 10``.
    max_n_peaks : int, optional
        Maximum number of peaks to fit (default is 10).
    freq_range : tuple or float, optional
        Frequency range to fit. None gives ``[f[2], f[-1]]``. A single value
        gives ``[f[2], value]``. A pair is used as-is.
    peak_width_limits : tuple or float, optional
        Limits on the width of peaks. None gives ``[f[2], inf]``. A single
        value gives ``[f[2], value]``.
    report : bool, optional
        If True, will print fitting results (default is False). For the 'knee'
        mode the knee location is also printed.
    plot : bool, optional
        If True, will plot the fitting results (default is False).
    plt_log : bool, optional
        If True, use a logarithmic scale for the x-axis in plots (default is False).
    plt_range : tuple or float, optional
        x-range for plotting, interpreted like ``freq_range`` (default None).
    figsize : tuple, optional
        Size of the figure (default is None).
    title : str, optional
        Title for the plot (default is None).

    Returns:
    -------
    tuple
        ``(results, fm)``: the object returned by ``fm.get_results()`` and the
        fitted FOOOF model object.
    """
    # Anything other than 'knee' falls back to the 'fixed' aperiodic model.
    if aperiodic_mode != 'knee':
        aperiodic_mode = 'fixed'

    def set_range(x, upper=f[-1]):
        # None -> [f[2], upper]; scalar -> [f[2], scalar]; pair -> list(pair).
        x = np.array(upper) if x is None else np.array(x)
        return [f[2], x.item()] if x.size == 1 else x.tolist()

    freq_range = set_range(freq_range)
    peak_width_limits = set_range(peak_width_limits, np.inf)

    # The model works on log10 power, so a threshold of N dB corresponds to
    # N / 10 in log10 units.
    fm = FOOOF(peak_width_limits=peak_width_limits, min_peak_height=dB_threshold / 10,
               peak_threshold=0., max_n_peaks=max_n_peaks, aperiodic_mode=aperiodic_mode)

    try:
        fm.fit(f, pxx, freq_range)
    except Exception:
        # Fallback: resample the spectrum onto a uniform frequency grid whose
        # spacing is the smallest step found in ``f``, then fit again.
        fl = np.linspace(f[0], f[-1], int((f[-1] - f[0]) / np.min(np.diff(f))) + 1)
        fm.fit(fl, np.interp(fl, f, pxx), freq_range)

    results = fm.get_results()

    if report:
        fm.print_results()
        if aperiodic_mode == 'knee':
            ap_params = results.aperiodic_params
            if ap_params[1] <= 0:
                print('Negative value of knee parameter occurred. Suggestion: Fit without knee parameter.')
            knee_freq = np.abs(ap_params[1]) ** (1 / ap_params[2])
            print(f'Knee location: {knee_freq:.2f} Hz')

    if plot:
        plt_range = set_range(plt_range)
        fm.plot(plt_log=plt_log)
        plt.xlim(np.log10(plt_range) if plt_log else plt_range)
        if figsize:
            plt.gcf().set_size_inches(figsize)
        if title:
            plt.title(title)
        # Notebooks render figures inline; elsewhere the figure must be shown.
        if not is_notebook():
            plt.show()

    return results, fm
|
220
|
+
|
221
|
+
|
222
|
+
def generate_resd_from_fooof(fooof_model: FOOOF) -> tuple:
    """
    Generate residuals from a fitted FOOOF model.

    Parameters:
    ----------
    fooof_model : FOOOF
        A fitted FOOOF model object.

    Returns:
    -------
    tuple
        ``(res_psd, ap_fit)`` in linear power. ``res_psd`` is the measured
        spectrum minus the aperiodic fit, and ``ap_fit`` is the aperiodic fit.
        Both are aligned with ``fooof_model.freqs``. Element 0 is set to 0
        because the first frequency bin is left out of the evaluation.
    """
    results = fooof_model.get_results()
    # Evaluate only the aperiodic component on every frequency but the first.
    _, _, ap_fit = gen_model(fooof_model.freqs[1:], results.aperiodic_params,
                             results.gaussian_params, return_components=True)

    ap_fit = 10 ** ap_fit  # Convert back from log10 power
    measured = 10 ** fooof_model.power_spectrum[1:]  # Convert back from log10 power

    # Re-insert a 0 at index 0 so both outputs line up with fooof_model.freqs.
    res_psd = np.insert(measured - ap_fit, 0, 0.)
    ap_fit = np.insert(ap_fit, 0, 0.)

    return res_psd, ap_fit
|
246
|
+
|
247
|
+
|
248
|
+
def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
    """
    Calculate the signal-to-noise ratio (SNR) from a fitted FOOOF model.

    The SNR is the integrated residual (periodic) power divided by the
    integrated aperiodic power. Both are integrated with the trapezoid rule
    over frequencies with ``freq_band[0] <= f <= freq_band[1]``.

    Parameters:
    ----------
    fooof_model : FOOOF
        A fitted FOOOF model object.
    freq_band : tuple
        Frequency band (min, max) for SNR calculation, inclusive.

    Returns:
    -------
    float
        The calculated SNR for the specified frequency band.
    """
    periodic, ap = generate_resd_from_fooof(fooof_model)
    freq = fooof_model.freqs
    in_band = (freq >= freq_band[0]) & (freq <= freq_band[1])
    band_freq = freq[in_band]

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid. Use the
    # new name when available and fall back for older NumPy releases.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    periodic_power = integrate(periodic[in_band], band_freq)
    ap_power = integrate(ap[in_band], band_freq)
    return periodic_power / ap_power
|
274
|
+
|
275
|
+
|
276
|
+
def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1) -> np.ndarray:
    """
    Compute the Continuous Wavelet Transform (CWT) of ``x`` at one frequency
    using a complex Morlet wavelet.

    Parameters:
    ----------
    x : array-like
        Input signal.
    freq : float
        Frequency to extract, in the same units as ``fs``.
    fs : float
        Sampling frequency of ``x``.
    bandwidth : float, optional
        Sets the Morlet bandwidth parameter to ``2 * bandwidth ** 2``. The
        center frequency is fixed at 1.0 (default bandwidth is 1.0).
    axis : int, optional
        Axis of ``x`` along which the transform is taken (default -1).

    Returns:
    -------
    np.ndarray
        Complex CWT coefficients for the single scale matching ``freq``.
    """
    morlet = f'cmor{2 * bandwidth ** 2}-1.0'
    # The wavelet's normalized center frequency at scale 1 fixes the scale for `freq`.
    center = pywt.scale2frequency(morlet, 1)
    scale = center * fs / freq
    coefs, _ = pywt.cwt(x, [scale], wavelet=morlet, axis=axis)
    return coefs[0]
|
284
|
+
|
285
|
+
|
286
|
+
def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs: float, order: int = 5, axis: int = -1) -> np.ndarray:
    """
    Apply a zero-phase Butterworth band-pass filter to the input data.

    Parameters:
    ----------
    data : array-like
        Signal(s) to filter.
    lowcut, highcut : float
        Pass-band edges, in the same units as ``fs``.
    fs : float
        Sampling frequency.
    order : int, optional
        Butterworth filter order (default 5).
    axis : int, optional
        Axis of ``data`` to filter along (default -1).

    Returns:
    -------
    np.ndarray
        The filtered data, same shape as ``data``.
    """
    band = [lowcut, highcut]
    return signal.sosfiltfilt(
        signal.butter(order, band, btype='band', output='sos', fs=fs),
        data,
        axis=axis,
    )
|
293
|
+
|
294
|
+
|
295
|
+
def calculate_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
                  method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                  bandwidth: float = 2.0) -> np.ndarray:
    """
    Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.

    Parameters:
    - x1, x2: Input signals of identical shape; time runs along the last axis
    - fs: Sampling frequency
    - freq_of_interest: Desired frequency for wavelet PLV calculation
    - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
    - lowcut, highcut: Cutoff frequencies for the Hilbert method. If either is
      missing, the signals are not band-pass filtered before the Hilbert transform
    - bandwidth: Bandwidth parameter for the wavelet

    Returns:
    - plv: Phase Locking Value in [0, 1], averaged over the last (time) axis.
      For 1D inputs this is a single numpy float; for ND inputs it is an
      array over the leading axes.

    Raises:
    - ValueError: if the inputs differ in shape, if freq_of_interest is missing
      for the wavelet method, or if the method is unknown.
    """
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    # Compare full shapes: all processing runs along the last axis, so equal
    # len() alone would let mismatched signals broadcast silently.
    if x1.shape != x2.shape:
        raise ValueError("Input signals must have the same length.")

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        # Apply CWT to both signals
        theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
        theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")

        if lowcut and highcut:
            # Band-pass filter before extracting the analytic signal.
            x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
            x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)

        # Analytic signals via the Hilbert transform
        theta1 = signal.hilbert(x1)
        theta2 = signal.hilbert(x2)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # Instantaneous phase difference
    phase_diff = np.angle(theta1) - np.angle(theta2)

    # PLV = |mean(exp(i * dphi))|, Lachaux et al., "Measuring phase synchrony in brain signals" (1999)
    plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))

    return plv
|
346
|
+
|
347
|
+
|
348
|
+
def calculate_plv_over_time(x1: np.ndarray, x2: np.ndarray, fs: float,
                            window_size: float, step_size: float,
                            method: str = 'wavelet', freq_of_interest: float = None,
                            lowcut: float = None, highcut: float = None,
                            bandwidth: float = 2.0):
    """
    Calculate the time-resolved Phase Locking Value (PLV) between two signals using a sliding window approach.

    Parameters:
    ----------
    x1, x2 : array-like
        Input signals (1D arrays, same length).
    fs : float
        Sampling frequency of the input signals.
    window_size : float
        Length of the window in seconds for PLV calculation.
    step_size : float
        Step size in seconds to slide the window across the signals.
    method : str, optional
        Method to calculate PLV ('wavelet' or 'hilbert'). Defaults to 'wavelet'.
    freq_of_interest : float, optional
        Frequency of interest for the wavelet method. Required if method is 'wavelet'.
    lowcut, highcut : float, optional
        Cutoff frequencies for the Hilbert method.
    bandwidth : float, optional
        Bandwidth parameter for the wavelet. Defaults to 2.0.

    Returns:
    -------
    plv_over_time : 1D array
        Array of PLV values calculated over each window (empty if the signal
        is shorter than one window).
    times : 1D array
        Center time of each window in seconds, measured from the first sample.

    Raises:
    ------
    ValueError
        If ``window_size * fs`` or ``step_size * fs`` is shorter than one sample.
    """
    # Convert window and step size from seconds to samples
    window_samples = int(window_size * fs)
    step_samples = int(step_size * fs)

    # Reject sub-sample windows/steps explicitly. A zero step would make
    # range() fail, and a zero-length window has no phase to compare.
    if window_samples < 1:
        raise ValueError("window_size * fs must correspond to at least one sample.")
    if step_samples < 1:
        raise ValueError("step_size * fs must correspond to at least one sample.")

    plv_over_time = []
    times = []

    # Slide a fixed-length window across the signals
    for start in range(0, len(x1) - window_samples + 1, step_samples):
        end = start + window_samples
        plv_over_time.append(
            calculate_plv(x1=x1[start:end], x2=x2[start:end], fs=fs,
                          method=method, freq_of_interest=freq_of_interest,
                          lowcut=lowcut, highcut=highcut, bandwidth=bandwidth))
        # Center of the [start, end) window, in seconds
        times.append((start + end) / 2 / fs)

    return np.array(plv_over_time), np.array(times)
|
407
|
+
|
408
|
+
|
@@ -0,0 +1,254 @@
|
|
1
|
+
"""
|
2
|
+
Module for processing BMTK spikes output.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import h5py
|
6
|
+
import pandas as pd
|
7
|
+
from bmtool.util.util import load_nodes_from_config
|
8
|
+
from typing import Dict, Optional,Tuple, Union, List
|
9
|
+
import numpy as np
|
10
|
+
import os
|
11
|
+
|
12
|
+
|
13
|
+
def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
    """
    Load spike data from an HDF5 file into a pandas DataFrame.

    Args:
        spike_file (str): Path to the HDF5 file containing spike data.
        network_name (str): The name of the network within the HDF5 file from which to load spike data.
        sort (bool, optional): Whether to sort the DataFrame by 'timestamps'. Defaults to True.
        config (str, optional): Path to a config; when given, the node table for
            ``network_name`` is loaded and the ``groupby`` column(s) are merged onto each spike.
        groupby (str or list of str, optional): The node column(s) to attach when ``config``
            is given. Defaults to 'pop_name'.

    Returns:
        pd.DataFrame: A DataFrame with 'node_ids' and 'timestamps' columns, plus the
        ``groupby`` column(s) when ``config`` is provided.

    Raises:
        KeyError: If a requested ``groupby`` column is missing from the node table.

    Example:
        df = load_spikes_to_df("spikes.h5", "cortex")
    """
    # Open read-only explicitly. h5py releases before 3.0 defaulted to a
    # writable mode that could create or modify the file.
    with h5py.File(spike_file, 'r') as f:
        spikes_group = f['spikes'][network_name]
        # Read the datasets fully into memory while the file is still open.
        spikes_df = pd.DataFrame({
            'node_ids': spikes_group['node_ids'][()],
            'timestamps': spikes_group['timestamps'][()],
        })

    if sort:
        spikes_df.sort_values(by='timestamps', inplace=True, ignore_index=True)

    if config:
        nodes = load_nodes_from_config(config)
        nodes = nodes[network_name]

        # Convert single string to a list for uniform handling
        if isinstance(groupby, str):
            groupby = [groupby]

        # Ensure all requested columns exist
        missing_cols = [col for col in groupby if col not in nodes.columns]
        if missing_cols:
            raise KeyError(f"Columns {missing_cols} not found in nodes DataFrame.")

        # Node table is indexed by node id; attach the labels to each spike.
        spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')

    return spikes_df
|
55
|
+
|
56
|
+
|
57
|
+
def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] = "pop_name", start_time: float = None, stop_time: float = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Computes the firing rates of individual nodes and the mean and standard deviation of firing rates per group.

    Only nodes with at least one spike inside the analysis window appear in
    the results, so silent nodes do not contribute to the group statistics.

    Args:
        df (pd.DataFrame): Dataframe containing spike timestamps (ms) and node IDs.
        groupby (str or list of str, optional): Column(s) to group by (e.g., 'pop_name' or ['pop_name', 'layer']).
        start_time (float, optional): Start time for the analysis window. Defaults to the minimum timestamp in the data.
        stop_time (float, optional): Stop time for the analysis window. Defaults to the maximum timestamp in the data.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]:
            - The first DataFrame (`pop_stats`) contains the mean and standard deviation of firing rates per group.
            - The second DataFrame (`individual_stats`) contains the firing rate of each individual node.

    Raises:
        ValueError: If a ``groupby`` column is missing, or the window length is not positive.
    """
    # Ensure groupby is a list
    if isinstance(groupby, str):
        groupby = [groupby]

    # Ensure all columns exist in the dataframe
    for col in groupby:
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found in dataframe.")

    # Restrict to the analysis window
    if start_time is not None:
        df = df[df["timestamps"] >= start_time]
    if stop_time is not None:
        df = df[df["timestamps"] <= stop_time]

    # Window bounds default to the observed spike times
    min_time = df["timestamps"].min() if start_time is None else start_time
    max_time = df["timestamps"].max() if stop_time is None else stop_time
    duration = max_time - min_time

    # Reject empty/inverted windows up front to avoid dividing by zero below.
    if duration <= 0:
        raise ValueError("Invalid time window: Stop time must be greater than start time.")

    # Spike counts per node
    spike_counts = df["node_ids"].value_counts().reset_index()
    spike_counts.columns = ["node_ids", "spike_count"]

    # Attach each node's group labels (e.g., 'pop_name')
    labels = df[["node_ids"] + groupby].drop_duplicates()
    spike_counts = spike_counts.merge(labels, on="node_ids", how="left")

    # Timestamps are in ms, so scale counts per ms to Hz
    spike_counts["firing_rate"] = spike_counts["spike_count"] / duration * 1000
    individual_stats = spike_counts

    # Mean and standard deviation per group
    pop_stats = (
        spike_counts.groupby(groupby)["firing_rate"]
        .agg(["mean", "std"])
        .reset_index()
        .rename(columns={"mean": "firing_rate_mean", "std": "firing_rate_std"})
    )

    return pop_stats, individual_stats
|
125
|
+
|
126
|
+
|
127
|
+
def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
                    time_points: Optional[Union[np.ndarray, list]] = None, frequeny: bool = False) -> np.ndarray:
    """
    Calculate the spike count or frequency histogram over specified time intervals.

    Each time point opens one bin of width ``dt``, so bins are
    ``[t_i, t_i + dt)`` and the last bin ends at ``time_points[-1] + dt``.

    Args:
        spike_times (Union[np.ndarray, list]): Array or list of spike times in milliseconds.
        time (Optional[Tuple[float, float, float]], optional): Tuple specifying (start, stop, step) in milliseconds.
            Used to create evenly spaced time points if `time_points` is not provided. Default is None.
        time_points (Optional[Union[np.ndarray, list]], optional): Array or list of specific time points for binning.
            If provided, `time` is ignored. The bin width is inferred as the average spacing. Default is None.
        frequeny (bool, optional): If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.

    Returns:
        np.ndarray: Array of spike counts or frequencies, depending on the `frequeny` flag.

    Raises:
        ValueError: If both `time` and `time_points` are None, if `time` yields no
            bins, or if `time_points` has fewer than two points.
    """
    if time_points is None:
        if time is None:
            raise ValueError("Either `time` or `time_points` must be provided.")
        time_points = np.arange(*time)
        dt = time[2]
        if time_points.size == 0:
            raise ValueError("`time` must describe at least one bin (stop must exceed start).")
    else:
        time_points = np.asarray(time_points).ravel()
        # At least two points are needed to infer the bin width.
        if time_points.size < 2:
            raise ValueError("`time_points` must contain at least two points to infer the bin width.")
        dt = (time_points[-1] - time_points[0]) / (time_points.size - 1)

    # One extra edge so the last time point also gets a full-width bin.
    bins = np.append(time_points, time_points[-1] + dt)
    spike_rate, _ = np.histogram(np.asarray(spike_times), bins)

    if frequeny:
        # Counts per bin of dt milliseconds -> Hz
        spike_rate = 1000 / dt * spike_rate

    return spike_rate
|
162
|
+
|
163
|
+
|
164
|
+
def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
                              config: Optional[str] = None, network_name: Optional[str] = None,
                              save: bool = False, save_path: Optional[str] = None,
                              normalize: bool = False) -> Dict[str, np.ndarray]:
    """
    Calculate the population spike rate for each population in the given spike data, with an option to normalize.

    Args:
        spikes (pd.DataFrame): A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'.
        fs (float, optional): Sampling frequency in Hz, which determines the time bin size (1000 / fs ms) for calculating the spike rate. Default is 400.
        t_start (float, optional): Start time (in milliseconds) for spike rate calculation. Default is 0.
        t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data.
        config (Optional[str], optional): Path to a configuration file containing node information, used to determine the correct number of nodes per population.
            If None, node count is estimated from unique node spikes. Default is None.
        network_name (Optional[str], optional): Name of the network used in the configuration file, allowing selection of nodes for that network.
            If omitted while `config` is given, the first network in the config is used. Default is None.
        save (bool, optional): Whether to save the calculated population spike rate to a file. Default is False.
        save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. If `save` is True and `save_path` is None, a ValueError is raised.
        normalize (bool, optional): Whether to normalize the spike rates for each population to a range of [0, 1]. Default is False.

    Returns:
        Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
            If `normalize` is True, each population's spike rate is scaled to [0, 1].

    Raises:
        ValueError: If `save` is True but `save_path` is not provided, or if `config` contains no networks.

    Notes:
        - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
        - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
          A population whose rate is constant (max == min) is normalized to all zeros rather than NaN.
        - Only spikes strictly between `t_start` and `t_stop` are counted.
        - When saving, `spike_rate.h5` holds a `time` dataset with `[t_start, t_stop, bin_width_ms]` and one
          `populations/<pop_name>/data` dataset per population.
    """
    # Fail before doing any work if the output location is missing.
    if save and save_path is None:
        raise ValueError("save_path must be provided if save is True.")

    if config is None:
        print("Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect.")
        print("You can provide a config to calculate the correct amount of nodes!")

    # The node table is the same for every population, so load it only once.
    network_nodes = None
    if config:
        if not network_name:
            print("Grabbing first network; specify a network name to ensure correct node population is selected.")
        all_nodes = load_nodes_from_config(config)
        if network_name:
            network_nodes = all_nodes[network_name]
        elif all_nodes:
            network_nodes = list(all_nodes.values())[0]
        else:
            raise ValueError(f"No networks were found in config: {config}")

    if t_stop is None:
        t_stop = spikes['timestamps'].max()

    # Time window shared by every population (both bounds exclusive).
    in_window = (spikes['timestamps'] > t_start) & (spikes['timestamps'] < t_stop)

    pop_spikes = {}
    node_number = {}
    for pop_name in spikes['pop_name'].unique():
        in_pop = spikes['pop_name'] == pop_name
        if network_nodes is not None:
            pop_nodes = network_nodes[network_nodes['pop_name'] == pop_name]
            node_number[pop_name] = pop_nodes.index.nunique()
        else:
            node_number[pop_name] = spikes.loc[in_pop, 'node_ids'].nunique()
        pop_spikes[pop_name] = spikes[in_pop & in_window]

    # (start, stop, bin width in ms) where the bin width follows from fs in Hz.
    time = np.array([t_start, t_stop, 1000 / fs])
    spike_rate = {}
    for p, spk in pop_spikes.items():
        counts = _pop_spike_rate(spk['timestamps'], time)
        # spikes per bin * bins per second / number of cells
        spike_rate[p] = fs / node_number[p] * counts

    if normalize:
        normalized = {}
        for p, sr in spike_rate.items():
            if sr.size == 0:
                normalized[p] = sr.astype(float)
                continue
            lo, hi = sr.min(), sr.max()
            span = hi - lo
            # A constant series has no range to scale by; map it to zeros instead of NaN.
            normalized[p] = (sr - lo) / span if span > 0 else np.zeros_like(sr, dtype=float)
        spike_rate = normalized

    if save:
        os.makedirs(save_path, exist_ok=True)

        save_file = os.path.join(save_path, 'spike_rate.h5')
        with h5py.File(save_file, 'w') as f:
            f.create_dataset('time', data=time)
            grp = f.create_group('populations')
            for p, rspk in spike_rate.items():
                pop_grp = grp.create_group(p)
                pop_grp.create_dataset('data', data=rspk)

    return spike_rate
|
254
|
+
|
@@ -13,6 +13,7 @@ import matplotlib.colors as colors
|
|
13
13
|
import matplotlib.gridspec as gridspec
|
14
14
|
from mpl_toolkits.mplot3d import Axes3D
|
15
15
|
from matplotlib.axes import Axes
|
16
|
+
import seaborn as sns
|
16
17
|
from IPython import get_ipython
|
17
18
|
from IPython.display import display, HTML
|
18
19
|
import statistics
|
@@ -20,7 +21,7 @@ import pandas as pd
|
|
20
21
|
import os
|
21
22
|
import sys
|
22
23
|
import re
|
23
|
-
from typing import Optional, Dict
|
24
|
+
from typing import Optional, Dict, Union, List
|
24
25
|
|
25
26
|
from .util.util import CellVarsFile,load_nodes_from_config #, missing_units
|
26
27
|
from bmtk.analyzer.utils import listify
|
@@ -762,7 +763,7 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
|
|
762
763
|
plt.tight_layout()
|
763
764
|
plt.show()
|
764
765
|
|
765
|
-
def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None,
|
766
|
+
def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None, groupby:Optional[str] = 'pop_name',
|
766
767
|
ax: Optional[Axes] = None,tstart: Optional[float] = None,tstop: Optional[float] = None,
|
767
768
|
color_map: Optional[Dict[str, str]] = None) -> Axes:
|
768
769
|
"""
|
@@ -793,7 +794,7 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
|
|
793
794
|
Notes:
|
794
795
|
-----
|
795
796
|
- If `config` is provided, the function merges population names from the node data with `spikes_df`.
|
796
|
-
- Each unique population
|
797
|
+
- Each unique population from groupby in `spikes_df` will be represented by a different color if `color_map` is not specified.
|
797
798
|
- If `color_map` is provided, it should contain colors for all unique `pop_name` values in `spikes_df`.
|
798
799
|
"""
|
799
800
|
# Initialize axes if none provided
|
@@ -822,11 +823,11 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
|
|
822
823
|
# Drop all intersecting columns except the join key column from df2
|
823
824
|
spikes_df = spikes_df.drop(columns=common_columns)
|
824
825
|
# merge nodes and spikes df
|
825
|
-
spikes_df = spikes_df.merge(nodes[
|
826
|
+
spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
|
826
827
|
|
827
828
|
|
828
829
|
# Get unique population names
|
829
|
-
unique_pop_names = spikes_df[
|
830
|
+
unique_pop_names = spikes_df[groupby].unique()
|
830
831
|
|
831
832
|
# Generate colors if no color_map is provided
|
832
833
|
if color_map is None:
|
@@ -839,7 +840,7 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
|
|
839
840
|
raise ValueError(f"color_map is missing colors for populations: {missing_colors}")
|
840
841
|
|
841
842
|
# Plot each population with its specified or generated color
|
842
|
-
for pop_name, group in spikes_df.groupby(
|
843
|
+
for pop_name, group in spikes_df.groupby(groupby):
|
843
844
|
ax.scatter(group['timestamps'], group['node_ids'], label=pop_name, color=color_map[pop_name], s=0.5)
|
844
845
|
|
845
846
|
# Label axes
|
@@ -849,6 +850,169 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
|
|
849
850
|
|
850
851
|
return ax
|
851
852
|
|
853
|
+
# uses df from bmtool.analysis.spikes compute_firing_rate_stats
|
854
|
+
def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, List[str]], ax: Optional[Axes] = None,
                               color_map: Optional[Dict[str, str]] = None) -> Axes:
    """
    Plots a bar graph of mean firing rates with error bars (standard deviation).

    Parameters:
    ----------
    firing_stats : pd.DataFrame
        Dataframe containing 'firing_rate_mean', 'firing_rate_std' and the `groupby` column(s).
        The DataFrame is not modified.
    groupby : str or list of str
        Column(s) used for grouping. Multiple columns are combined into one label joined by "_".
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the bar chart; if None, a new figure and axes are created.
    color_map : dict, optional
        Dictionary specifying colors for each group. Keys should be group names, and values should be color values.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the bar plot.

    Raises:
    ------
    ValueError
        If `color_map` is given but lacks a color for one of the groups.
    """
    # Ensure groupby is a list for consistent handling
    if isinstance(groupby, str):
        groupby = [groupby]

    # Work on a copy so the caller's DataFrame does not gain a "group" column.
    stats = firing_stats.copy()
    # One categorical label per row: the groupby values joined with "_".
    stats["group"] = stats[groupby].astype(str).agg("_".join, axis=1)

    unique_groups = stats["group"].unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap('viridis')
        color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
    else:
        # Ensure color_map contains all groups
        missing_colors = [group for group in unique_groups if group not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for groups: {missing_colors}")

    # Create new figure and axes if ax is not provided
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Sort data for consistent plotting
    stats = stats.sort_values(by="group")

    x_labels = stats["group"].tolist()
    positions = np.arange(len(x_labels))
    colors = [color_map[group] for group in x_labels]

    # Bars at explicit numeric positions so bars, error bars and tick labels always line up,
    # with a single set of black, capped error bars.
    ax.bar(positions, stats["firing_rate_mean"], yerr=stats["firing_rate_std"], color=colors,
           edgecolor="black", error_kw={"ecolor": "black", "capsize": 5, "capthick": 2})

    # Formatting
    ax.set_xticks(positions)
    ax.set_xticklabels(x_labels, rotation=45, ha="right")
    ax.set_xlabel("Population Group")
    ax.set_ylabel("Mean Firing Rate (spikes/s)")
    ax.set_title("Firing Rate Statistics by Population")
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    return ax
|
933
|
+
|
934
|
+
# uses df from bmtool.analysis.spikes compute_firing_rate_stats
|
935
|
+
def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union[str, list], ax: Optional[Axes] = None,
                                  color_map: Optional[Dict[str, str]] = None,
                                  plot_type: Union[str, list] = "box", swarm_alpha: float = 0.6) -> Axes:
    """
    Plots a distribution of individual firing rates using one or more plot types
    (box plot, violin plot, or swarm plot), overlaying them on top of each other.

    Parameters:
    ----------
    individual_stats : pd.DataFrame
        Dataframe containing a 'firing_rate' column and the `groupby` column(s).
        The DataFrame is not modified.
    groupby : str or list of str
        Column(s) used for grouping. Multiple columns are combined into one label joined by "_".
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the graph; if None, a new figure and axes are created.
    color_map : dict, optional
        Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
    plot_type : str or list of str, optional
        List of plot types to generate. Options: "box", "violin", "swarm". Default is "box".
    swarm_alpha : float, optional
        Transparency of swarm plot points. Default is 0.6.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the selected plot type(s) overlayed.

    Raises:
    ------
    ValueError
        If a plot type is not recognised, or `color_map` lacks a color for one of the groups.
    """
    # Ensure groupby is a list for consistent handling
    if isinstance(groupby, str):
        groupby = [groupby]

    # Validate plot_type (it can be a list or a single type) before doing any work
    if isinstance(plot_type, str):
        plot_type = [plot_type]
    if any(pt not in ("box", "violin", "swarm") for pt in plot_type):
        raise ValueError("plot_type must be one of: 'box', 'violin', 'swarm'.")

    # Work on a copy so the caller's DataFrame does not gain a "group" column.
    stats = individual_stats.copy()
    stats["group"] = stats[groupby].astype(str).agg("_".join, axis=1)

    # Get unique groups for coloring
    unique_groups = stats["group"].unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap('viridis')
        color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}

    # Ensure color_map contains all groups
    missing_colors = [group for group in unique_groups if group not in color_map]
    if missing_colors:
        raise ValueError(f"color_map is missing colors for groups: {missing_colors}")

    # Create new figure and axes if ax is not provided
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Sort data for consistent plotting
    stats = stats.sort_values(by="group")

    # Loop over each plot type and overlay them on the same axes
    for pt in plot_type:
        if pt == "box":
            sns.boxplot(data=stats, x="group", y="firing_rate", ax=ax, palette=color_map, width=0.5)
        elif pt == "violin":
            sns.violinplot(data=stats, x="group", y="firing_rate", ax=ax, palette=color_map, inner="quartile", alpha=0.4)
        elif pt == "swarm":
            sns.swarmplot(data=stats, x="group", y="firing_rate", ax=ax, palette=color_map, alpha=swarm_alpha)

    # Formatting
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_xlabel("Population Group")
    ax.set_ylabel("Firing Rate (spikes/s)")
    ax.set_title("Firing Rate Distribution by Population")
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    return ax
|
1014
|
+
|
1015
|
+
|
852
1016
|
def plot_3d_positions(config=None, populations_list=None, group_by=None, title=None, save_file=None, subset=None):
|
853
1017
|
"""
|
854
1018
|
Plots a 3D graph of all cells with x, y, z location.
|
@@ -1154,164 +1318,3 @@ def plot_report_default(config, report_name, variables, gids):
|
|
1154
1318
|
plot_report(config_file=config, report_file=report_file, report_name=report_name, variables=variables, gids=gids);
|
1155
1319
|
|
1156
1320
|
return
|
1157
|
-
|
1158
|
-
# The following code was developed by Matthew Stroud 7/15/21 neural engineering supervisor: Satish Nair
|
1159
|
-
# This is an extension of bmtool: a development of Tyler Banks.
|
1160
|
-
# The goal of the sim_setup() function is to output relevant simulation information that can be gathered by providing only the main configuration file.
|
1161
|
-
|
1162
|
-
|
1163
|
-
def sim_setup(config_file='simulation_config.json', network=None):
    """Print and plot an overview of a simulation from its main configuration file.

    Shows node tables, then plots connection probabilities, current clamps, input
    spike trains, total connection counts and 3D cell positions.

    Args:
        config_file: Path to the simulation configuration file.
        network: Network used for the connection-probability plot. When None, the
            first biophysical network returned by ``plot_basic_cell_info`` is used.
    """
    if "JPY_PARENT_PID" in os.environ:
        print("Inside a notebook:")
        # Switch matplotlib to the Tk backend.
        get_ipython().run_line_magic('matplotlib', 'tk')

    # Output tables that contain the cells involved in the configuration file given.
    # Also returns the first biophysical network found.
    bio = plot_basic_cell_info(config_file)
    if network is None:
        network = bio

    print("Please wait. This may take a while depending on your network size...")
    # Plot connection probabilities
    plt.close(1)
    probability_connection_matrix(config=config_file, sources=network, targets=network, no_prepend_pop=True,
                                  sids='pop_name', tids='pop_name', bins=10, line_plot=True, verbose=False)
    # Gives current clamp information
    plot_I_clamps(config_file)
    # Plot spike train info
    plot_inspikes(config_file)
    # Using bmtool, print total number of connections between cell groups
    total_connection_matrix(config=config_file, sources='all', targets='all', sids='pop_name', tids='pop_name',
                            title='All Connections found', size_scalar=2, no_prepend_pop=True, synaptic_info='0')
    # Plot 3d positions of the network (plot_3d_positions takes `populations_list`, not `populations`).
    plot_3d_positions(populations_list='all', config=config_file, group_by='pop_name', title='3D Positions', save_file=None)
|
1186
|
-
|
1187
|
-
def _clamp_times(amplitudes, dt):
    """Return one time stamp per amplitude sample, spaced ``dt`` apart and starting at 0."""
    # np.arange(0, n*dt, dt) can yield n+1 points under floating-point rounding
    # (e.g. n=3, dt=0.1), which makes x and y lengths disagree when plotting.
    return np.arange(len(amplitudes)) * dt


def plot_I_clamps(fp):
    """Plot every current clamp found in the simulation config ``fp`` on one axes.

    Each entry from ``util.load_I_clamp_from_config`` is indexed as
    ``[0]`` = amplitude samples, ``[1]`` = sample spacing, ``[2]`` = clamp target
    (used only for the legend label).
    """
    print("Plotting current clamp info...")
    clamps = util.load_I_clamp_from_config(fp)
    if not clamps:
        print(" No current clamps were found.")
        return
    fig, ax = plt.subplots()
    for clinfo in clamps:
        times = _clamp_times(clinfo[0], clinfo[1])
        line, = ax.plot(times, clinfo[0], drawstyle='steps')
        line.set_label('I Clamp to: ' + str(clinfo[2]))
    # One legend covering all clamps, built after every line is labelled.
    ax.legend()
|
1205
|
-
|
1206
|
-
def _summarize_node_runs(node, columns):
    """Collapse consecutive cells sharing a ``node_type_id`` into one row per run.

    Returns ``[node_type, *values, count]`` rows, where ``values`` are taken from
    ``columns`` at the last cell of each run and ``count`` is the run length.
    """
    type_ids = node['node_type_id'].to_numpy()
    column_values = [node[col].to_numpy() for col in columns]
    n = len(type_ids)
    rows = []
    run_start = 0
    for idx in range(1, n + 1):
        # A run ends at the last cell or where the next cell has a different type.
        if idx == n or type_ids[idx] != type_ids[run_start]:
            last = idx - 1
            rows.append([type_ids[last], *(vals[last] for vals in column_values), idx - run_start])
            run_start = idx
    return rows


def _show_table(df):
    """Render ``df`` as HTML when running in a notebook, otherwise print it."""
    if is_notebook():
        display(HTML(df.to_html()))
    else:
        print(df)


def plot_basic_cell_info(config_file):
    """Print a table of node types for each virtual and biophysical network in the config.

    Consecutive cells with the same ``node_type_id`` are summarised as one row with a count.

    Returns:
        The name of the first biophysical network found, or None if there is none.
    """
    print("Network and node info:")
    nodes = util.load_nodes_from_config(config_file)
    if not nodes:
        print("No nodes were found.")
        return
    pd.set_option("display.max_rows", None, "display.max_columns", None)
    bio = []
    for name, node in nodes.items():
        model_type = node['model_type'].iloc[0]
        if model_type == 'virtual':
            rows = _summarize_node_runs(node, ['pop_name', 'model_type'])
            df1 = pd.DataFrame(rows, columns=["node_type", "pop_name", "model_type", "count"])
            print(name + ':')
            _show_table(df1)
        elif model_type == 'biophysical':
            rows = _summarize_node_runs(node, ['pop_name', 'model_type', 'model_template', 'morphology'])
            for row in rows:
                # Falsy morphology values (e.g. None or '') are shown as an empty string.
                row[4] = row[4] if row[4] else ''
            df2 = pd.DataFrame(rows, columns=["node_type", "pop_name", "model_type", "model_template", "morphology", "count"])
            print(name + ':')
            bio.append(name)
            _show_table(df2)
    if bio:
        return bio[0]
|
1279
|
-
|
1280
|
-
def _max_xticks(num_trains):
    """Return the x-axis tick budget per subplot: fewer ticks as the figure holds more spike trains."""
    if num_trains <= 4:
        return 20
    if num_trains <= 9:
        return 4
    return 2


def plot_inspikes(fp):
    """Plot each input spike train from the simulation config ``fp`` in its own subplot.

    Each train from ``util.load_inspikes_from_config`` is read as a sequence of
    entries indexed ``[0]`` = spike time, ``[1]`` = node, ``[2]`` = node group
    (the group is taken from the first entry and used in the subplot title).
    """
    print("Plotting spike Train info...")
    trains = util.load_inspikes_from_config(fp)
    if not trains:
        print("No spike trains were found.")
        return
    num_trains = len(trains)

    fig, ax = plt.subplots(num_trains, figsize=(12, 12), squeeze=False)
    fig.subplots_adjust(hspace=0.5, wspace=0.5)
    max_ticks = _max_xticks(num_trains)

    for pos, tr in enumerate(trains):
        node_group = tr[0][2]
        if node_group == '':
            node_group = 'Defined by gids (y-axis)'
        time = [sp[0] for sp in tr]
        node = [sp[1] for sp in tr]

        # plotting spike train
        sub = ax[pos, 0]
        sub.scatter(time, node, s=1)
        sub.title.set_text('Input Spike Train to: ' + node_group)
        # Rotate this subplot's labels (plt.xticks would only touch the current axes).
        sub.tick_params(axis='x', labelrotation=45)
        sub.xaxis.set_major_locator(plt.MaxNLocator(max_ticks))
    fig.show()
|
@@ -90,7 +90,11 @@ class CurrentClamp(object):
|
|
90
90
|
self.inj_dur = inj_dur
|
91
91
|
self.inj_amp = inj_amp * 1e-3 # pA to nA
|
92
92
|
|
93
|
-
|
93
|
+
# sometimes people may put a hoc object in for the template name
|
94
|
+
if callable(template_name):
|
95
|
+
self.cell = template_name
|
96
|
+
else:
|
97
|
+
self.cell = self.create_cell()
|
94
98
|
if post_init_function:
|
95
99
|
eval(f"self.cell.{post_init_function}")
|
96
100
|
|
File without changes
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: bmtool
|
3
|
-
Version: 0.6.7
|
3
|
+
Version: 0.6.8
|
4
4
|
Summary: BMTool
|
5
5
|
Home-page: https://github.com/cyneuro/bmtool
|
6
6
|
Download-URL:
|
@@ -53,6 +53,7 @@ A collection of modules to make developing [Neuron](https://www.neuron.yale.edu/
|
|
53
53
|
- [Synapses](#synapses-module)
|
54
54
|
- [Connectors](#connectors-module)
|
55
55
|
- [Bmplot](#bmplot-module)
|
56
|
+
- [Analysis](#analysis-module)
|
56
57
|
- [SLURM](#slurm-module)
|
57
58
|
- [Graphs](#graphs-module)
|
58
59
|
|
@@ -471,7 +472,11 @@ bmplot.plot_network_graph(config='config.json',sources='LA',targets='LA',tids='p
|
|
471
472
|
|
472
473
|
|
473
474
|

|
474
|
-
|
475
|
+
|
476
|
+
|
477
|
+
## Analysis Module
|
478
|
+
### A notebook example of how to use the spikes module can be found [here](examples/analysis/using_spikes.ipynb)
|
479
|
+
|
475
480
|
## SLURM Module
|
476
481
|
### This is an extremely helpful module that can simplify using SLURM to submit your models. There are also features to enable doing a seedSweep, which varies the parameters of the simulation and makes tuning the model easier. An example can be found [here](examples/SLURM/using_BlockRunner.ipynb)
|
477
482
|
|
@@ -17,6 +17,9 @@ bmtool.egg-info/dependency_links.txt
|
|
17
17
|
bmtool.egg-info/entry_points.txt
|
18
18
|
bmtool.egg-info/requires.txt
|
19
19
|
bmtool.egg-info/top_level.txt
|
20
|
+
bmtool/analysis/__init__.py
|
21
|
+
bmtool/analysis/lfp.py
|
22
|
+
bmtool/analysis/spikes.py
|
20
23
|
bmtool/debug/__init__.py
|
21
24
|
bmtool/debug/commands.py
|
22
25
|
bmtool/debug/debug.py
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|