bmtool 0.6.7__tar.gz → 0.6.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {bmtool-0.6.7 → bmtool-0.6.8}/PKG-INFO +7 -2
  2. {bmtool-0.6.7 → bmtool-0.6.8}/README.md +6 -1
  3. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/SLURM.py +10 -14
  4. bmtool-0.6.8/bmtool/analysis/lfp.py +408 -0
  5. bmtool-0.6.8/bmtool/analysis/spikes.py +254 -0
  6. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/bmplot.py +170 -167
  7. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/singlecell.py +5 -1
  8. bmtool-0.6.8/bmtool/util/neuron/__init__.py +0 -0
  9. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/PKG-INFO +7 -2
  10. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/SOURCES.txt +3 -0
  11. {bmtool-0.6.7 → bmtool-0.6.8}/setup.py +1 -1
  12. {bmtool-0.6.7 → bmtool-0.6.8}/LICENSE +0 -0
  13. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/__init__.py +0 -0
  14. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/__main__.py +0 -0
  15. {bmtool-0.6.7/bmtool/debug → bmtool-0.6.8/bmtool/analysis}/__init__.py +0 -0
  16. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/connectors.py +0 -0
  17. {bmtool-0.6.7/bmtool/util → bmtool-0.6.8/bmtool/debug}/__init__.py +0 -0
  18. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/debug/commands.py +0 -0
  19. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/debug/debug.py +0 -0
  20. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/graphs.py +0 -0
  21. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/manage.py +0 -0
  22. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/plot_commands.py +0 -0
  23. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/synapses.py +0 -0
  24. {bmtool-0.6.7/bmtool/util/neuron → bmtool-0.6.8/bmtool/util}/__init__.py +0 -0
  25. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/util/commands.py +0 -0
  26. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/util/neuron/celltuner.py +0 -0
  27. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool/util/util.py +0 -0
  28. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/dependency_links.txt +0 -0
  29. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/entry_points.txt +0 -0
  30. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/requires.txt +0 -0
  31. {bmtool-0.6.7 → bmtool-0.6.8}/bmtool.egg-info/top_level.txt +0 -0
  32. {bmtool-0.6.7 → bmtool-0.6.8}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: bmtool
- Version: 0.6.7
+ Version: 0.6.8
  Summary: BMTool
  Home-page: https://github.com/cyneuro/bmtool
  Download-URL:
@@ -53,6 +53,7 @@ A collection of modules to make developing [Neuron](https://www.neuron.yale.edu/
  - [Synapses](#synapses-module)
  - [Connectors](#connectors-module)
  - [Bmplot](#bmplot-module)
+ - [Analysis](#analysis-module)
  - [SLURM](#slurm-module)
  - [Graphs](#graphs-module)

@@ -471,7 +472,11 @@ bmplot.plot_network_graph(config='config.json',sources='LA',targets='LA',tids='p


  ![png](readme_figures/output_35_0.png)
-
+
+
+ ## Analysis Module
+ ### A notebook example of how to use the spikes module can be found [here](examples/analysis/using_spikes.ipynb)
+
  ## SLURM Module
  ### This is an extremely helpful module that can simplify using SLURM to submit your models. There are also features for doing a seedSweep, which varies the parameters of the simulation and makes tuning the model easier. An example can be found [here](examples/SLURM/using_BlockRunner.ipynb)

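For readers without the notebook at hand, a minimal sketch of the new spikes module in use; the file path "output/spikes.h5", the network name "cortex", and the config file name are hypothetical placeholders:

    from bmtool.analysis.spikes import load_spikes_to_df, compute_firing_rate_stats

    # Load spikes and label each one with its population via the simulation config
    spikes_df = load_spikes_to_df("output/spikes.h5", "cortex", config="simulation_config.json")

    # Per-population mean/std firing rates plus per-node rates
    pop_stats, individual_stats = compute_firing_rate_stats(spikes_df, groupby="pop_name")
    print(pop_stats)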
@@ -10,6 +10,7 @@ A collection of modules to make developing [Neuron](https://www.neuron.yale.edu/
  - [Synapses](#synapses-module)
  - [Connectors](#connectors-module)
  - [Bmplot](#bmplot-module)
+ - [Analysis](#analysis-module)
  - [SLURM](#slurm-module)
  - [Graphs](#graphs-module)

@@ -428,7 +429,11 @@ bmplot.plot_network_graph(config='config.json',sources='LA',targets='LA',tids='p


  ![png](readme_figures/output_35_0.png)
-
+
+
+ ## Analysis Module
+ ### A notebook example of how to use the spikes module can be found [here](examples/analysis/using_spikes.ipynb)
+
  ## SLURM Module
  ### This is an extremely helpful module that can simplify using SLURM to submit your models. There are also features for doing a seedSweep, which varies the parameters of the simulation and makes tuning the model easier. An example can be found [here](examples/SLURM/using_BlockRunner.ipynb)

@@ -353,17 +353,15 @@ class BlockRunner:
              shutil.copytree(source_dir, destination_dir) # create new components folder
              json_file_path = os.path.join(destination_dir,self.json_file_path)

-             # need to keep the orignal around
-             syn_dict_temp = copy.deepcopy(self.syn_dict)
-             print(self.syn_dict['json_file_path'])
-             json_to_be_ratioed = syn_dict_temp['json_file_path']
-             corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
-             syn_dict_temp['json_file_path'] = corrected_ratio_path
-
              if self.syn_dict == None:
                  json_editor = seedSweep(json_file_path, self.param_name)
                  json_editor.edit_json(new_value)
              else:
+                 # need to keep the original around
+                 syn_dict_temp = copy.deepcopy(self.syn_dict)
+                 json_to_be_ratioed = syn_dict_temp['json_file_path']
+                 corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
+                 syn_dict_temp['json_file_path'] = corrected_ratio_path
                  json_editor = multiSeedSweep(json_file_path, self.param_name,
                                               syn_dict=syn_dict_temp, base_ratio=1)
                  json_editor.edit_all_jsons(new_value)
@@ -415,17 +413,15 @@ class BlockRunner:
              shutil.copytree(source_dir, destination_dir) # create new components folder
              json_file_path = os.path.join(destination_dir,self.json_file_path)

-             # need to keep the orignal around
-             syn_dict_temp = copy.deepcopy(self.syn_dict)
-             print(self.syn_dict['json_file_path'])
-             json_to_be_ratioed = syn_dict_temp['json_file_path']
-             corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
-             syn_dict_temp['json_file_path'] = corrected_ratio_path
-
              if self.syn_dict == None:
                  json_editor = seedSweep(json_file_path, self.param_name)
                  json_editor.edit_json(new_value)
              else:
+                 # need to keep the original around
+                 syn_dict_temp = copy.deepcopy(self.syn_dict)
+                 json_to_be_ratioed = syn_dict_temp['json_file_path']
+                 corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
+                 syn_dict_temp['json_file_path'] = corrected_ratio_path
                  json_editor = multiSeedSweep(json_file_path, self.param_name,
                                               syn_dict_temp, base_ratio=1)
                  json_editor.edit_all_jsons(new_value)
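The same fix is applied in both methods: previously `self.syn_dict['json_file_path']` was read (and printed) before the `if self.syn_dict == None` check, so runs without a syn_dict would fail; the bookkeeping now lives inside the else branch and the stray debug print is gone. A condensed view of the new control flow, using only names from the diff above:

    if self.syn_dict == None:
        seedSweep(json_file_path, self.param_name).edit_json(new_value)
    else:
        syn_dict_temp = copy.deepcopy(self.syn_dict)            # keep the original intact across blocks
        syn_dict_temp['json_file_path'] = os.path.join(
            destination_dir, syn_dict_temp['json_file_path'])   # re-point at the copied components folder
        multiSeedSweep(json_file_path, self.param_name,
                       syn_dict=syn_dict_temp, base_ratio=1).edit_all_jsons(new_value)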
@@ -0,0 +1,408 @@
+ """
+ Module for processing BMTK LFP output.
+ """
+
+ import h5py
+ import numpy as np
+ import xarray as xr
+ from fooof import FOOOF
+ from fooof.sim.gen import gen_model
+ import matplotlib.pyplot as plt
+ from scipy import signal
+ import pywt
+ from bmtool.bmplot import is_notebook
+
+
+ def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
+     """
+     Load ECP data from an HDF5 file (BMTK sim) into an xarray DataArray.
+
+     Parameters:
+     ----------
+     ecp_file : str
+         Path to the HDF5 file containing ECP data.
+     demean : bool, optional
+         If True, the mean of the data will be subtracted (default is False).
+
+     Returns:
+     -------
+     xr.DataArray
+         An xarray DataArray containing the ECP data, with time as one dimension
+         and channel_id as another.
+     """
+     with h5py.File(ecp_file, 'r') as f:
+         ecp = xr.DataArray(
+             f['ecp']['data'][()].T,
+             coords=dict(
+                 channel_id=f['ecp']['channel_id'][()],
+                 time=np.arange(*f['ecp']['time'])  # ms
+             ),
+             attrs=dict(
+                 fs=1000 / f['ecp']['time'][2]  # Hz
+             )
+         )
+     if demean:
+         ecp -= ecp.mean(dim='time')
+     return ecp
+
+
+ def ecp_to_lfp(ecp_data: xr.DataArray, cutoff: float = 250, fs: float = 10000,
+                downsample_freq: float = 1000) -> xr.DataArray:
+     """
+     Apply a low-pass Butterworth filter to an xarray DataArray and optionally downsample.
+     This filters out the high-end frequencies, turning the ECP into an LFP.
+
+     Parameters:
+     ----------
+     ecp_data : xr.DataArray
+         The input data array containing LFP data with time as one dimension.
+     cutoff : float
+         The cutoff frequency for the low-pass filter in Hz (default is 250 Hz).
+     fs : float, optional
+         The sampling frequency of the data (default is 10000 Hz).
+     downsample_freq : float, optional
+         The frequency to downsample to (default is 1000 Hz).
+
+     Returns:
+     -------
+     xr.DataArray
+         The filtered (and possibly downsampled) data as an xarray DataArray.
+     """
+     # Low-pass filter design
+     nyq = 0.5 * fs
+     cut = cutoff / nyq
+     b, a = signal.butter(8, cut, btype='low', analog=False)
+
+     # Initialize an array to hold filtered data
+     filtered_data = xr.DataArray(np.zeros_like(ecp_data), coords=ecp_data.coords, dims=ecp_data.dims)
+
+     # Apply the filter to each channel
+     for channel in ecp_data.channel_id:
+         filtered_data.loc[channel, :] = signal.filtfilt(b, a, ecp_data.sel(channel_id=channel).values)
+
+     # Downsample the filtered data if a downsample frequency is provided
+     if downsample_freq is not None:
+         downsample_factor = int(fs / downsample_freq)
+         filtered_data = filtered_data.isel(time=slice(None, None, downsample_factor))
+         # Update the sampling frequency attribute
+         filtered_data.attrs['fs'] = downsample_freq
+
+     return filtered_data
+
+
+ def slice_time_series(data: xr.DataArray, time_ranges: tuple) -> xr.DataArray:
+     """
+     Slice the xarray DataArray based on provided time ranges.
+     Can be used to get LFP during certain stimulus times.
+
+     Parameters:
+     ----------
+     data : xr.DataArray
+         The input xarray DataArray containing time-series data.
+     time_ranges : tuple or list of tuples
+         One or more tuples representing the (start, stop) time points for slicing.
+         For example: (start, stop) or [(start1, stop1), (start2, stop2)]
+
+     Returns:
+     -------
+     xr.DataArray
+         A new xarray DataArray containing the concatenated slices.
+     """
+     # Ensure time_ranges is a list of tuples
+     if isinstance(time_ranges, tuple) and len(time_ranges) == 2:
+         time_ranges = [time_ranges]
+
+     # List to hold sliced data
+     slices = []
+
+     # Slice the data for each time range
+     for start, stop in time_ranges:
+         sliced_data = data.sel(time=slice(start, stop))
+         slices.append(sliced_data)
+
+     # Concatenate all slices along the time dimension if more than one slice
+     if len(slices) > 1:
+         return xr.concat(slices, dim='time')
+     else:
+         return slices[0]
+
+
+ def fit_fooof(f: np.ndarray, pxx: np.ndarray, aperiodic_mode: str = 'fixed',
+               dB_threshold: float = 3.0, max_n_peaks: int = 10,
+               freq_range: tuple = None, peak_width_limits: tuple = None,
+               report: bool = False, plot: bool = False,
+               plt_log: bool = False, plt_range: tuple = None,
+               figsize: tuple = None, title: str = None) -> tuple:
+     """
+     Fit a FOOOF model to power spectral density data.
+
+     Parameters:
+     ----------
+     f : array-like
+         Frequencies corresponding to the power spectral density data.
+     pxx : array-like
+         Power spectral density data to fit.
+     aperiodic_mode : str, optional
+         The mode for fitting aperiodic components ('fixed' or 'knee', default is 'fixed').
+     dB_threshold : float, optional
+         Minimum peak height in dB (default is 3).
+     max_n_peaks : int, optional
+         Maximum number of peaks to fit (default is 10).
+     freq_range : tuple, optional
+         Frequency range to fit (default is None, which uses the full range).
+     peak_width_limits : tuple, optional
+         Limits on the width of peaks (default is None).
+     report : bool, optional
+         If True, will print fitting results (default is False).
+     plot : bool, optional
+         If True, will plot the fitting results (default is False).
+     plt_log : bool, optional
+         If True, use a logarithmic scale for the y-axis in plots (default is False).
+     plt_range : tuple, optional
+         Range for plotting (default is None).
+     figsize : tuple, optional
+         Size of the figure (default is None).
+     title : str, optional
+         Title for the plot (default is None).
+
+     Returns:
+     -------
+     tuple
+         A tuple containing the fitting results and the FOOOF model object.
+     """
+     if aperiodic_mode != 'knee':
+         aperiodic_mode = 'fixed'
+
+     def set_range(x, upper=f[-1]):
+         x = np.array(upper) if x is None else np.array(x)
+         return [f[2], x.item()] if x.size == 1 else x.tolist()
+
+     freq_range = set_range(freq_range)
+     peak_width_limits = set_range(peak_width_limits, np.inf)
+
+     # Initialize a FOOOF object
+     fm = FOOOF(peak_width_limits=peak_width_limits, min_peak_height=dB_threshold / 10,
+                peak_threshold=0., max_n_peaks=max_n_peaks, aperiodic_mode=aperiodic_mode)
+
+     # Fit the model
+     try:
+         fm.fit(f, pxx, freq_range)
+     except Exception as e:
+         fl = np.linspace(f[0], f[-1], int((f[-1] - f[0]) / np.min(np.diff(f))) + 1)
+         fm.fit(fl, np.interp(fl, f, pxx), freq_range)
+
+     results = fm.get_results()
+
+     if report:
+         fm.print_results()
+         if aperiodic_mode == 'knee':
+             ap_params = results.aperiodic_params
+             if ap_params[1] <= 0:
+                 print('Negative value of knee parameter occurred. Suggestion: Fit without knee parameter.')
+             knee_freq = np.abs(ap_params[1]) ** (1 / ap_params[2])
+             print(f'Knee location: {knee_freq:.2f} Hz')
+
+     if plot:
+         plt_range = set_range(plt_range)
+         fm.plot(plt_log=plt_log)
+         plt.xlim(np.log10(plt_range) if plt_log else plt_range)
+         # plt.ylim(-8, -5.5)
+         if figsize:
+             plt.gcf().set_size_inches(figsize)
+         if title:
+             plt.title(title)
+         if is_notebook():
+             pass
+         else:
+             plt.show()
+
+     return results, fm
+
+
+ def generate_resd_from_fooof(fooof_model: FOOOF) -> tuple:
+     """
+     Generate residuals from a fitted FOOOF model.
+
+     Parameters:
+     ----------
+     fooof_model : FOOOF
+         A fitted FOOOF model object.
+
+     Returns:
+     -------
+     tuple
+         A tuple containing the residual power spectral density and the aperiodic fit.
+     """
+     results = fooof_model.get_results()
+     full_fit, _, ap_fit = gen_model(fooof_model.freqs[1:], results.aperiodic_params,
+                                     results.gaussian_params, return_components=True)
+
+     full_fit, ap_fit = 10 ** full_fit, 10 ** ap_fit  # Convert back from log
+     res_psd = np.insert((10 ** fooof_model.power_spectrum[1:]) - ap_fit, 0, 0.)  # Convert back from log
+     res_fit = np.insert(full_fit - ap_fit, 0, 0.)
+     ap_fit = np.insert(ap_fit, 0, 0.)
+
+     return res_psd, ap_fit
+
+
+ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
+     """
+     Calculate the signal-to-noise ratio (SNR) from a fitted FOOOF model.
+
+     Parameters:
+     ----------
+     fooof_model : FOOOF
+         A fitted FOOOF model object.
+     freq_band : tuple
+         Frequency band (min, max) for SNR calculation.
+
+     Returns:
+     -------
+     float
+         The calculated SNR for the specified frequency band.
+     """
+     periodic, ap = generate_resd_from_fooof(fooof_model)
+     freq = fooof_model.freqs  # Get frequencies from model
+     indices = (freq >= freq_band[0]) & (freq <= freq_band[1])  # Get only the band we care about
+     band_periodic = periodic[indices]  # Filter based on band
+     band_ap = ap[indices]  # Filter
+     band_freq = freq[indices]  # Another filter
+     periodic_power = np.trapz(band_periodic, band_freq)  # Integrate periodic power
+     ap_power = np.trapz(band_ap, band_freq)  # Integrate aperiodic power
+     normalized_power = periodic_power / ap_power  # Compute the SNR
+     return normalized_power
+
+
+ def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1) -> np.ndarray:
+     """
+     Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
+     """
+     wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
+     scale = pywt.scale2frequency(wavelet, 1) * fs / freq
+     x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
+     return x_a
+
+
+ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs: float, order: int = 5, axis: int = -1) -> np.ndarray:
+     """
+     Apply a Butterworth bandpass filter to the input data.
+     """
+     sos = signal.butter(order, [lowcut, highcut], fs=fs, btype='band', output='sos')
+     x_a = signal.sosfiltfilt(sos, data, axis=axis)
+     return x_a
+
+
+ def calculate_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
+                   method: str = 'wavelet', lowcut: float = None, highcut: float = None,
+                   bandwidth: float = 2.0) -> np.ndarray:
+     """
+     Calculate Phase Locking Value (PLV) between two signals using the wavelet or Hilbert method.
+
+     Parameters:
+     - x1, x2: Input signals (1D arrays, same length)
+     - fs: Sampling frequency
+     - freq_of_interest: Desired frequency for wavelet PLV calculation
+     - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
+     - lowcut, highcut: Cutoff frequencies for the Hilbert method
+     - bandwidth: Bandwidth parameter for the wavelet
+
+     Returns:
+     - plv: Phase Locking Value (1D array)
+     """
+     if len(x1) != len(x2):
+         raise ValueError("Input signals must have the same length.")
+
+     if method == 'wavelet':
+         if freq_of_interest is None:
+             raise ValueError("freq_of_interest must be provided for the wavelet method.")
+
+         # Apply CWT to both signals
+         theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+         theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+
+     elif method == 'hilbert':
+         if lowcut is None or highcut is None:
+             print("lowcut and/or highcut were not defined; the signal will not be filtered, and the Hilbert transform alone will be used for the PLV calculation")
+
+         if lowcut and highcut:
+             # Bandpass filter and get the analytic signal using the Hilbert transform
+             x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
+             x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)
+
+         # Get phase using the Hilbert transform
+         theta1 = signal.hilbert(x1)
+         theta2 = signal.hilbert(x2)
+
+     else:
+         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
+
+     # Calculate phase difference
+     phase_diff = np.angle(theta1) - np.angle(theta2)
+
+     # Calculate PLV using the standard equation from "Measuring phase synchrony in brain signals" (1999)
+     plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
+
+     return plv
+
+
+ def calculate_plv_over_time(x1: np.ndarray, x2: np.ndarray, fs: float,
+                             window_size: float, step_size: float,
+                             method: str = 'wavelet', freq_of_interest: float = None,
+                             lowcut: float = None, highcut: float = None,
+                             bandwidth: float = 2.0):
+     """
+     Calculate the time-resolved Phase Locking Value (PLV) between two signals using a sliding window approach.
+
+     Parameters:
+     ----------
+     x1, x2 : array-like
+         Input signals (1D arrays, same length).
+     fs : float
+         Sampling frequency of the input signals.
+     window_size : float
+         Length of the window in seconds for PLV calculation.
+     step_size : float
+         Step size in seconds to slide the window across the signals.
+     method : str, optional
+         Method to calculate PLV ('wavelet' or 'hilbert'). Defaults to 'wavelet'.
+     freq_of_interest : float, optional
+         Frequency of interest for the wavelet method. Required if method is 'wavelet'.
+     lowcut, highcut : float, optional
+         Cutoff frequencies for the Hilbert method. Required if method is 'hilbert'.
+     bandwidth : float, optional
+         Bandwidth parameter for the wavelet. Defaults to 2.0.
+
+     Returns:
+     -------
+     plv_over_time : 1D array
+         Array of PLV values calculated over each window.
+     times : 1D array
+         The center times of each window where the PLV was calculated.
+     """
+     # Convert window and step size from seconds to samples
+     window_samples = int(window_size * fs)
+     step_samples = int(step_size * fs)
+
+     # Initialize results
+     plv_over_time = []
+     times = []
+
+     # Iterate over the signal with a sliding window
+     for start in range(0, len(x1) - window_samples + 1, step_samples):
+         end = start + window_samples
+         window_x1 = x1[start:end]
+         window_x2 = x2[start:end]
+
+         # Use the calculate_plv function within each window
+         plv = calculate_plv(x1=window_x1, x2=window_x2, fs=fs,
+                             method=method, freq_of_interest=freq_of_interest,
+                             lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
+         plv_over_time.append(plv)
+
+         # Store the time at the center of the window
+         center_time = (start + end) / 2 / fs
+         times.append(center_time)
+
+     return np.array(plv_over_time), np.array(times)
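Putting the new module together, a plausible end-to-end sketch; the file path, channel ids, window times, and frequency values are hypothetical, and xarray is assumed to carry the `fs` attribute through slicing:

    from scipy import signal
    from bmtool.analysis.lfp import (load_ecp_to_xarray, ecp_to_lfp, slice_time_series,
                                     fit_fooof, calculate_SNR, calculate_plv)

    ecp = load_ecp_to_xarray("output/ecp.h5", demean=True)         # hypothetical path
    lfp = ecp_to_lfp(ecp, cutoff=250, fs=ecp.attrs['fs'], downsample_freq=1000)
    stim = slice_time_series(lfp, [(500, 1500), (2500, 3500)])     # ms windows

    # Spectral model: Welch PSD -> FOOOF fit -> band SNR (band edges hypothetical)
    f, pxx = signal.welch(lfp.sel(channel_id=0).values, fs=lfp.attrs['fs'], nperseg=1024)
    results, fm = fit_fooof(f, pxx, aperiodic_mode='knee', freq_range=(1, 100))
    snr = calculate_SNR(fm, freq_band=(4, 12))

    # Phase locking at 8 Hz between two channels via the wavelet method
    x1 = stim.sel(channel_id=0).values
    x2 = stim.sel(channel_id=1).values
    plv = calculate_plv(x1, x2, fs=stim.attrs['fs'], freq_of_interest=8.0, method='wavelet')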
@@ -0,0 +1,254 @@
+ """
+ Module for processing BMTK spikes output.
+ """
+
+ import h5py
+ import pandas as pd
+ from bmtool.util.util import load_nodes_from_config
+ from typing import Dict, Optional, Tuple, Union, List
+ import numpy as np
+ import os
+
+
+ def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
+     """
+     Load spike data from an HDF5 file into a pandas DataFrame.
+
+     Args:
+         spike_file (str): Path to the HDF5 file containing spike data.
+         network_name (str): The name of the network within the HDF5 file from which to load spike data.
+         sort (bool, optional): Whether to sort the DataFrame by 'timestamps'. Defaults to True.
+         config (str, optional): Path to a simulation config; if provided, each spike is labeled with its cell type.
+         groupby (str or list of str, optional): The column(s) to group by. Defaults to 'pop_name'.
+
+     Returns:
+         pd.DataFrame: A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data.
+
+     Example:
+         df = load_spikes_to_df("spikes.h5", "cortex")
+     """
+     with h5py.File(spike_file) as f:
+         spikes_df = pd.DataFrame({
+             'node_ids': f['spikes'][network_name]['node_ids'],
+             'timestamps': f['spikes'][network_name]['timestamps']
+         })
+
+     if sort:
+         spikes_df.sort_values(by='timestamps', inplace=True, ignore_index=True)
+
+     if config:
+         nodes = load_nodes_from_config(config)
+         nodes = nodes[network_name]
+
+         # Convert single string to a list for uniform handling
+         if isinstance(groupby, str):
+             groupby = [groupby]
+
+         # Ensure all requested columns exist
+         missing_cols = [col for col in groupby if col not in nodes.columns]
+         if missing_cols:
+             raise KeyError(f"Columns {missing_cols} not found in nodes DataFrame.")
+
+         spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
+
+     return spikes_df
+
+
+ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] = "pop_name", start_time: float = None, stop_time: float = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
+     """
+     Computes the firing rates of individual nodes and the mean and standard deviation of firing rates per group.
+
+     Args:
+         df (pd.DataFrame): Dataframe containing spike timestamps and node IDs.
+         groupby (str or list of str, optional): Column(s) to group by (e.g., 'pop_name' or ['pop_name', 'layer']).
+         start_time (float, optional): Start time for the analysis window. Defaults to the minimum timestamp in the data.
+         stop_time (float, optional): Stop time for the analysis window. Defaults to the maximum timestamp in the data.
+
+     Returns:
+         Tuple[pd.DataFrame, pd.DataFrame]:
+             - The first DataFrame (`pop_stats`) contains the mean and standard deviation of firing rates per group.
+             - The second DataFrame (`individual_stats`) contains the firing rate of each individual node.
+     """
+
+     # Ensure groupby is a list
+     if isinstance(groupby, str):
+         groupby = [groupby]
+
+     # Ensure all columns exist in the dataframe
+     for col in groupby:
+         if col not in df.columns:
+             raise ValueError(f"Column '{col}' not found in dataframe.")
+
+     # Filter dataframe based on start/stop time
+     if start_time is not None:
+         df = df[df["timestamps"] >= start_time]
+     if stop_time is not None:
+         df = df[df["timestamps"] <= stop_time]
+
+     # Compute total duration for firing rate calculation
+     if start_time is None:
+         min_time = df["timestamps"].min()
+     else:
+         min_time = start_time
+
+     if stop_time is None:
+         max_time = df["timestamps"].max()
+     else:
+         max_time = stop_time
+
+     duration = max_time - min_time  # total window length (ms)
+
+     if duration <= 0:
+         raise ValueError("Invalid time window: Stop time must be greater than start time.")
+
+     # Compute spike counts per node
+     spike_counts = df["node_ids"].value_counts().reset_index()
+     spike_counts.columns = ["node_ids", "spike_count"]  # Rename columns
+
+     # Merge with original dataframe to get corresponding labels (e.g., 'pop_name')
+     spike_counts = spike_counts.merge(df[["node_ids"] + groupby].drop_duplicates(), on="node_ids", how="left")
+
+     # Compute firing rate
+     spike_counts["firing_rate"] = spike_counts["spike_count"] / duration * 1000  # scale to Hz
+     indivdual_stats = spike_counts
+
+     # Compute mean and standard deviation per group
+     pop_stats = spike_counts.groupby(groupby)["firing_rate"].agg(["mean", "std"]).reset_index()
+
+     # Rename columns
+     pop_stats.rename(columns={"mean": "firing_rate_mean", "std": "firing_rate_std"}, inplace=True)
+
+     return pop_stats, indivdual_stats
+
+
+ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
+                     time_points: Optional[Union[np.ndarray, list]] = None, frequeny: bool = False) -> np.ndarray:
+     """
+     Calculate the spike count or frequency histogram over specified time intervals.
+
+     Args:
+         spike_times (Union[np.ndarray, list]): Array or list of spike times in milliseconds.
+         time (Optional[Tuple[float, float, float]], optional): Tuple specifying (start, stop, step) in milliseconds.
+             Used to create evenly spaced time points if `time_points` is not provided. Default is None.
+         time_points (Optional[Union[np.ndarray, list]], optional): Array or list of specific time points for binning.
+             If provided, `time` is ignored. Default is None.
+         frequeny (bool, optional): If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
+
+     Returns:
+         np.ndarray: Array of spike counts or frequencies, depending on the `frequeny` flag.
+
+     Raises:
+         ValueError: If both `time` and `time_points` are None.
+     """
+     if time_points is None:
+         if time is None:
+             raise ValueError("Either `time` or `time_points` must be provided.")
+         time_points = np.arange(*time)
+         dt = time[2]
+     else:
+         time_points = np.asarray(time_points).ravel()
+         dt = (time_points[-1] - time_points[0]) / (time_points.size - 1)
+
+     bins = np.append(time_points, time_points[-1] + dt)
+     spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
+
+     if frequeny:
+         spike_rate = 1000 / dt * spike_rate
+
+     return spike_rate
+
+
+ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
+                               config: Optional[str] = None, network_name: Optional[str] = None,
+                               save: bool = False, save_path: Optional[str] = None,
+                               normalize: bool = False) -> Dict[str, np.ndarray]:
+     """
+     Calculate the population spike rate for each population in the given spike data, with an option to normalize.
+
+     Args:
+         spikes (pd.DataFrame): A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'.
+         fs (float, optional): Sampling frequency in Hz, which determines the time bin size for calculating the spike rate. Default is 400.
+         t_start (float, optional): Start time (in milliseconds) for spike rate calculation. Default is 0.
+         t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data.
+         config (Optional[str], optional): Path to a configuration file containing node information, used to determine the correct number of nodes per population.
+             If None, node count is estimated from unique node spikes. Default is None.
+         network_name (Optional[str], optional): Name of the network used in the configuration file, allowing selection of nodes for that network.
+             Required if `config` is provided. Default is None.
+         save (bool, optional): Whether to save the calculated population spike rate to a file. Default is False.
+         save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. If `save` is True and `save_path` is None, a ValueError is raised.
+         normalize (bool, optional): Whether to normalize the spike rates for each population to a range of [0, 1]. Default is False.
+
+     Returns:
+         Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
+             If `normalize` is True, each population's spike rate is scaled to [0, 1].
+
+     Raises:
+         ValueError: If `save` is True but `save_path` is not provided.
+
+     Notes:
+         - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
+         - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
+     """
+     pop_spikes = {}
+     node_number = {}
+
+     if config is None:
+         print("Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect.")
+         print("You can provide a config to calculate the correct number of nodes!")
+
+     if config:
+         if not network_name:
+             print("Grabbing first network; specify a network name to ensure correct node population is selected.")
+
+     for pop_name in spikes['pop_name'].unique():
+         ps = spikes[spikes['pop_name'] == pop_name]
+
+         if config:
+             nodes = load_nodes_from_config(config)
+             if network_name:
+                 nodes = nodes[network_name]
+             else:
+                 nodes = list(nodes.values())[0] if nodes else {}
+             nodes = nodes[nodes['pop_name'] == pop_name]
+             node_number[pop_name] = nodes.index.nunique()
+         else:
+             node_number[pop_name] = ps['node_ids'].nunique()
+
+         if t_stop is None:
+             t_stop = spikes['timestamps'].max()
+
+         filtered_spikes = spikes[
+             (spikes['pop_name'] == pop_name) &
+             (spikes['timestamps'] > t_start) &
+             (spikes['timestamps'] < t_stop)
+         ]
+         pop_spikes[pop_name] = filtered_spikes
+
+     time = np.array([t_start, t_stop, 1000 / fs])
+     pop_rspk = {p: _pop_spike_rate(spk['timestamps'], time) for p, spk in pop_spikes.items()}
+     spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}
+
+     # Normalize each spike rate series if normalize=True
+     if normalize:
+         spike_rate = {p: (sr - sr.min()) / (sr.max() - sr.min()) for p, sr in spike_rate.items()}
+
+     if save:
+         if save_path is None:
+             raise ValueError("save_path must be provided if save is True.")
+
+         os.makedirs(save_path, exist_ok=True)
+
+         save_file = os.path.join(save_path, 'spike_rate.h5')
+         with h5py.File(save_file, 'w') as f:
+             f.create_dataset('time', data=time)
+             grp = f.create_group('populations')
+             for p, rspk in spike_rate.items():
+                 pop_grp = grp.create_group(p)
+                 pop_grp.create_dataset('data', data=rspk)
+
+     return spike_rate
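And a short sketch of the population-rate path; paths and the network name are placeholders:

    from bmtool.analysis.spikes import load_spikes_to_df, get_population_spike_rate

    spikes_df = load_spikes_to_df("output/spikes.h5", "cortex", config="simulation_config.json")

    # Per-population rate traces binned at fs; normalize scales each trace to [0, 1]
    rates = get_population_spike_rate(spikes_df, fs=400.0, config="simulation_config.json",
                                      network_name="cortex", normalize=True)
    for pop, rate in rates.items():
        print(pop, rate.shape)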
@@ -13,6 +13,7 @@ import matplotlib.colors as colors
  import matplotlib.gridspec as gridspec
  from mpl_toolkits.mplot3d import Axes3D
  from matplotlib.axes import Axes
+ import seaborn as sns
  from IPython import get_ipython
  from IPython.display import display, HTML
  import statistics
@@ -20,7 +21,7 @@ import pandas as pd
  import os
  import sys
  import re
- from typing import Optional, Dict
+ from typing import Optional, Dict, Union, List

  from .util.util import CellVarsFile, load_nodes_from_config  #, missing_units
  from bmtk.analyzer.utils import listify
@@ -762,7 +763,7 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
      plt.tight_layout()
      plt.show()

- def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None,
+ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None, groupby: Optional[str] = 'pop_name',
             ax: Optional[Axes] = None, tstart: Optional[float] = None, tstop: Optional[float] = None,
             color_map: Optional[Dict[str, str]] = None) -> Axes:
      """
@@ -793,7 +794,7 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
      Notes:
      -----
      - If `config` is provided, the function merges population names from the node data with `spikes_df`.
-     - Each unique population (`pop_name`) in `spikes_df` will be represented by a different color if `color_map` is not specified.
+     - Each unique population from `groupby` in `spikes_df` will be represented by a different color if `color_map` is not specified.
      - If `color_map` is provided, it should contain colors for all unique `pop_name` values in `spikes_df`.
      """
      # Initialize axes if none provided
@@ -822,11 +823,11 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
          # Drop all intersecting columns except the join key column from df2
          spikes_df = spikes_df.drop(columns=common_columns)
          # merge nodes and spikes df
-         spikes_df = spikes_df.merge(nodes['pop_name'], left_on='node_ids', right_index=True, how='left')
+         spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')


      # Get unique population names
-     unique_pop_names = spikes_df['pop_name'].unique()
+     unique_pop_names = spikes_df[groupby].unique()

      # Generate colors if no color_map is provided
      if color_map is None:
@@ -839,7 +840,7 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
          raise ValueError(f"color_map is missing colors for populations: {missing_colors}")

      # Plot each population with its specified or generated color
-     for pop_name, group in spikes_df.groupby('pop_name'):
+     for pop_name, group in spikes_df.groupby(groupby):
          ax.scatter(group['timestamps'], group['node_ids'], label=pop_name, color=color_map[pop_name], s=0.5)

      # Label axes
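The net effect of these raster changes is that the node column used for grouping and coloring is now configurable. A hedged sketch, reusing spikes_df from the earlier examples; 'layer' is a hypothetical node column:

    from bmtool import bmplot

    # Default behavior is unchanged (colors by pop_name)
    ax = bmplot.raster(spikes_df=spikes_df, config='simulation_config.json',
                       network_name='cortex', groupby='pop_name')

    # Any other node column can now drive the grouping/colors
    ax = bmplot.raster(spikes_df=spikes_df, config='simulation_config.json',
                       network_name='cortex', groupby='layer')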
@@ -849,6 +850,169 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non

      return ax

+ # uses df from bmtool.analysis.spikes compute_firing_rate_stats
+ def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, List[str]], ax: Optional[Axes] = None,
+                                color_map: Optional[Dict[str, str]] = None) -> Axes:
+     """
+     Plots a bar graph of mean firing rates with error bars (standard deviation).
+
+     Parameters:
+     ----------
+     firing_stats : pd.DataFrame
+         Dataframe containing 'firing_rate_mean' and 'firing_rate_std'.
+     groupby : str or list of str
+         Column(s) used for grouping.
+     ax : matplotlib.axes.Axes, optional
+         Axes on which to plot the bar chart; if None, a new figure and axes are created.
+     color_map : dict, optional
+         Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
+
+     Returns:
+     -------
+     matplotlib.axes.Axes
+         Axes with the bar plot.
+     """
+     # Ensure groupby is a list for consistent handling
+     if isinstance(groupby, str):
+         groupby = [groupby]
+
+     # Create a categorical column for grouping
+     firing_stats["group"] = firing_stats[groupby].astype(str).agg("_".join, axis=1)
+
+     # Get unique group names
+     unique_groups = firing_stats["group"].unique()
+
+     # Generate colors if no color_map is provided
+     if color_map is None:
+         cmap = plt.get_cmap('viridis')
+         color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
+     else:
+         # Ensure color_map contains all groups
+         missing_colors = [group for group in unique_groups if group not in color_map]
+         if missing_colors:
+             raise ValueError(f"color_map is missing colors for groups: {missing_colors}")
+
+     # Create new figure and axes if ax is not provided
+     if ax is None:
+         fig, ax = plt.subplots(figsize=(10, 6))
+
+     # Sort data for consistent plotting
+     firing_stats = firing_stats.sort_values(by="group")
+
+     # Extract values for plotting
+     x_labels = firing_stats["group"]
+     means = firing_stats["firing_rate_mean"]
+     std_devs = firing_stats["firing_rate_std"]
+
+     # Get colors for each group
+     colors = [color_map[group] for group in x_labels]
+
+     # Create bar plot
+     bars = ax.bar(x_labels, means, yerr=std_devs, capsize=5, color=colors, edgecolor="black")
+
+     # Add error bars manually with caps
+     _, caps, _ = ax.errorbar(
+         x=np.arange(len(x_labels)),
+         y=means,
+         yerr=std_devs,
+         fmt='none',
+         capsize=5,
+         capthick=2,
+         color="black"
+     )
+
+     # Formatting
+     ax.set_xticks(np.arange(len(x_labels)))
+     ax.set_xticklabels(x_labels, rotation=45, ha="right")
+     ax.set_xlabel("Population Group")
+     ax.set_ylabel("Mean Firing Rate (spikes/s)")
+     ax.set_title("Firing Rate Statistics by Population")
+     ax.grid(axis='y', linestyle='--', alpha=0.7)
+
+     return ax
+
+ # uses df from bmtool.analysis.spikes compute_firing_rate_stats
+ def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union[str, list], ax: Optional[Axes] = None,
+                                   color_map: Optional[Dict[str, str]] = None,
+                                   plot_type: Union[str, list] = "box", swarm_alpha: float = 0.6) -> Axes:
+     """
+     Plots a distribution of individual firing rates using one or more plot types
+     (box plot, violin plot, or swarm plot), overlaying them on top of each other.
+
+     Parameters:
+     ----------
+     individual_stats : pd.DataFrame
+         Dataframe containing individual firing rates and corresponding group labels.
+     groupby : str or list of str
+         Column(s) used for grouping.
+     ax : matplotlib.axes.Axes, optional
+         Axes on which to plot the graph; if None, a new figure and axes are created.
+     color_map : dict, optional
+         Dictionary specifying colors for each group. Keys should be group names, and values should be color values.
+     plot_type : str or list of str, optional
+         List of plot types to generate. Options: "box", "violin", "swarm". Default is "box".
+     swarm_alpha : float, optional
+         Transparency of swarm plot points. Default is 0.6.
+
+     Returns:
+     -------
+     matplotlib.axes.Axes
+         Axes with the selected plot type(s) overlaid.
+     """
+     # Ensure groupby is a list for consistent handling
+     if isinstance(groupby, str):
+         groupby = [groupby]
+
+     # Create a categorical column for grouping
+     individual_stats["group"] = individual_stats[groupby].astype(str).agg("_".join, axis=1)
+
+     # Validate plot_type (it can be a list or a single type)
+     if isinstance(plot_type, str):
+         plot_type = [plot_type]
+
+     for pt in plot_type:
+         if pt not in ["box", "violin", "swarm"]:
+             raise ValueError("plot_type must be one of: 'box', 'violin', 'swarm'.")
+
+     # Get unique groups for coloring
+     unique_groups = individual_stats["group"].unique()
+
+     # Generate colors if no color_map is provided
+     if color_map is None:
+         cmap = plt.get_cmap('viridis')
+         color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
+
+     # Ensure color_map contains all groups
+     missing_colors = [group for group in unique_groups if group not in color_map]
+     if missing_colors:
+         raise ValueError(f"color_map is missing colors for groups: {missing_colors}")
+
+     # Create new figure and axes if ax is not provided
+     if ax is None:
+         fig, ax = plt.subplots(figsize=(10, 6))
+
+     # Sort data for consistent plotting
+     individual_stats = individual_stats.sort_values(by="group")
+
+     # Loop over each plot type and overlay them
+     for pt in plot_type:
+         if pt == "box":
+             sns.boxplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, width=0.5)
+         elif pt == "violin":
+             sns.violinplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, inner="quartile", alpha=0.4)
+         elif pt == "swarm":
+             sns.swarmplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, alpha=swarm_alpha)
+
+     # Formatting
+     ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
+     ax.set_xlabel("Population Group")
+     ax.set_ylabel("Firing Rate (spikes/s)")
+     ax.set_title("Firing Rate Distribution by Population")
+     ax.grid(axis='y', linestyle='--', alpha=0.7)
+
+     return ax
+
+
  def plot_3d_positions(config=None, populations_list=None, group_by=None, title=None, save_file=None, subset=None):
      """
      Plots a 3D graph of all cells with x, y, z location.
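As the comments above note, both new helpers are meant to consume the two DataFrames returned by bmtool.analysis.spikes.compute_firing_rate_stats. A minimal sketch, reusing spikes_df from the earlier examples:

    from bmtool.analysis.spikes import compute_firing_rate_stats
    from bmtool import bmplot

    pop_stats, individual_stats = compute_firing_rate_stats(spikes_df, groupby="pop_name")

    # Bar chart of mean rate per group with std error bars
    bmplot.plot_firing_rate_pop_stats(pop_stats, groupby="pop_name")

    # Box plot with a swarm overlay of the per-node rates
    bmplot.plot_firing_rate_distribution(individual_stats, groupby="pop_name",
                                         plot_type=["box", "swarm"])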
@@ -1154,164 +1318,3 @@ def plot_report_default(config, report_name, variables, gids):
      plot_report(config_file=config, report_file=report_file, report_name=report_name, variables=variables, gids=gids);

      return
-
- # The following code was developed by Matthew Stroud 7/15/21 neural engineering supervisor: Satish Nair
- # This is an extension of bmtool: a development of Tyler Banks.
- # The goal of the sim_setup() function is to output relevant simulation information that can be gathered by providing only the main configuration file.
-
-
- def sim_setup(config_file='simulation_config.json',network=None):
-     if "JPY_PARENT_PID" in os.environ:
-         print("Inside a notebook:")
-         get_ipython().run_line_magic('matplotlib', 'tk')
-
-
-     # Output tables that contain the cells involved in the configuration file given. Also returns the first biophysical network found
-     bio=plot_basic_cell_info(config_file)
-     if network == None:
-         network=bio
-
-     print("Please wait. This may take a while depending on your network size...")
-     # Plot connection probabilities
-     plt.close(1)
-     probability_connection_matrix(config=config_file,sources=network,targets=network, no_prepend_pop=True,sids= 'pop_name', tids= 'pop_name', bins=10,line_plot=True,verbose=False)
-     # Gives current clamp information
-     plot_I_clamps(config_file)
-     # Plot spike train info
-     plot_inspikes(config_file)
-     # Using bmtool, print total number of connections between cell groups
-     total_connection_matrix(config=config_file,sources='all',targets='all',sids='pop_name',tids='pop_name',title='All Connections found', size_scalar=2, no_prepend_pop=True, synaptic_info='0')
-     # Plot 3d positions of the network
-     plot_3d_positions(populations='all',config=config_file,group_by='pop_name',title='3D Positions',save_file=None)
-
- def plot_I_clamps(fp):
-     print("Plotting current clamp info...")
-     clamps = util.load_I_clamp_from_config(fp)
-     if not clamps:
-         print("No current clamps were found.")
-         return
-     time=[]
-     num_clamps=0
-     fig, ax = plt.subplots()
-     ax = plt.gca()
-     for clinfo in clamps:
-         simtime=len(clinfo[0])*clinfo[1]
-         time.append(np.arange(0,simtime,clinfo[1]).tolist())
-
-         line,=ax.plot(time[num_clamps],clinfo[0],drawstyle='steps')
-         line.set_label('I Clamp to: '+str(clinfo[2]))
-         plt.legend()
-         num_clamps=num_clamps+1
-
- def plot_basic_cell_info(config_file):
-     print("Network and node info:")
-     nodes=util.load_nodes_from_config(config_file)
-     if not nodes:
-         print("No nodes were found.")
-         return
-     pd.set_option("display.max_rows", None, "display.max_columns", None)
-     bio=[]
-     i=0
-     j=0
-     for j in nodes:
-         node=nodes[j]
-         node_type_id=node['node_type_id']
-         num_cells=len(node['node_type_id'])
-         if node['model_type'][0]=='virtual':
-             CELLS=[]
-             count=1
-             for i in range(num_cells-1):
-                 if(node_type_id[i]==node_type_id[i+1]):
-                     count+=1
-                 else:
-                     node_type=node_type_id[i]
-                     pop_name=node['pop_name'][i]
-                     model_type=node['model_type'][i]
-                     CELLS.append([node_type,pop_name,model_type,count])
-                     count=1
-             else:
-                 node_type=node_type_id[i]
-                 pop_name=node['pop_name'][i]
-                 model_type=node['model_type'][i]
-                 CELLS.append([node_type,pop_name,model_type,count])
-                 count=1
-             df1 = pd.DataFrame(CELLS, columns = ["node_type","pop_name","model_type","count"])
-             print(j+':')
-             notebook = is_notebook()
-             if notebook == True:
-                 display(HTML(df1.to_html()))
-             else:
-                 print(df1)
-         elif node['model_type'][0]=='biophysical':
-             CELLS=[]
-             count=1
-             node_type_id=node['node_type_id']
-             num_cells=len(node['node_type_id'])
-             for i in range(num_cells-1):
-                 if(node_type_id[i]==node_type_id[i+1]):
-                     count+=1
-                 else:
-                     node_type=node_type_id[i]
-                     pop_name=node['pop_name'][i]
-                     model_type=node['model_type'][i]
-                     model_template=node['model_template'][i]
-                     morphology=node['morphology'][i] if node['morphology'][i] else ''
-                     CELLS.append([node_type,pop_name,model_type,model_template,morphology,count])
-                     count=1
-             else:
-                 node_type=node_type_id[i]
-                 pop_name=node['pop_name'][i]
-                 model_type=node['model_type'][i]
-                 model_template=node['model_template'][i]
-                 morphology=node['morphology'][i] if node['morphology'][i] else ''
-                 CELLS.append([node_type,pop_name,model_type,model_template,morphology,count])
-                 count=1
-             df2 = pd.DataFrame(CELLS, columns = ["node_type","pop_name","model_type","model_template","morphology","count"])
-             print(j+':')
-             bio.append(j)
-             notebook = is_notebook()
-             if notebook == True:
-                 display(HTML(df2.to_html()))
-             else:
-                 print(df2)
-     if len(bio)>0:
-         return bio[0]
-
- def plot_inspikes(fp):
-
-     print("Plotting spike Train info...")
-     trains = util.load_inspikes_from_config(fp)
-     if not trains:
-         print("No spike trains were found.")
-     num_trains=len(trains)
-
-     time=[]
-     node=[]
-     fig, ax = plt.subplots(num_trains, figsize=(12,12),squeeze=False)
-     fig.subplots_adjust(hspace=0.5, wspace=0.5)
-
-     pos=0
-     for tr in trains:
-         node_group=tr[0][2]
-         if node_group=='':
-             node_group='Defined by gids (y-axis)'
-         time=[]
-         node=[]
-         for sp in tr:
-             node.append(sp[1])
-             time.append(sp[0])
-
-         #plotting spike train
-
-         ax[pos,0].scatter(time,node,s=1)
-         ax[pos,0].title.set_text('Input Spike Train to: '+node_group)
-         plt.xticks(rotation = 45)
-         if num_trains <=4:
-             ax[pos,0].xaxis.set_major_locator(plt.MaxNLocator(20))
-         if num_trains <=9 and num_trains >4:
-             ax[pos,0].xaxis.set_major_locator(plt.MaxNLocator(4))
-         elif num_trains <9:
-             ax[pos,0].xaxis.set_major_locator(plt.MaxNLocator(2))
-         #fig.suptitle('Input Spike Train to: '+node_group, fontsize=14)
-         fig.show()
-         pos+=1
@@ -90,7 +90,11 @@ class CurrentClamp(object):
          self.inj_dur = inj_dur
          self.inj_amp = inj_amp * 1e-3  # pA to nA

-         self.cell = self.create_cell()
+         # sometimes people may put a hoc object in for the template name
+         if callable(template_name):
+             self.cell = template_name
+         else:
+             self.cell = self.create_cell()
          if post_init_function:
              eval(f"self.cell.{post_init_function}")

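A sketch of what the new branch allows; 'Cell_A' is a hypothetical HOC template assumed to be loaded already, and any other CurrentClamp constructor arguments are elided:

    from neuron import h
    from bmtool.singlecell import CurrentClamp

    # Before: only a template name (string) worked; the cell was built via create_cell()
    clamp = CurrentClamp('Cell_A', inj_amp=100, inj_dur=500)

    # Now: a callable HOC template object is accepted directly, skipping create_cell()
    clamp = CurrentClamp(h.Cell_A, inj_amp=100, inj_dur=500)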
File without changes
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: bmtool
- Version: 0.6.7
+ Version: 0.6.8
  Summary: BMTool
  Home-page: https://github.com/cyneuro/bmtool
  Download-URL:
@@ -53,6 +53,7 @@ A collection of modules to make developing [Neuron](https://www.neuron.yale.edu/
  - [Synapses](#synapses-module)
  - [Connectors](#connectors-module)
  - [Bmplot](#bmplot-module)
+ - [Analysis](#analysis-module)
  - [SLURM](#slurm-module)
  - [Graphs](#graphs-module)

@@ -471,7 +472,11 @@ bmplot.plot_network_graph(config='config.json',sources='LA',targets='LA',tids='p


  ![png](readme_figures/output_35_0.png)
-
+
+
+ ## Analysis Module
+ ### A notebook example of how to use the spikes module can be found [here](examples/analysis/using_spikes.ipynb)
+
  ## SLURM Module
  ### This is an extremely helpful module that can simplify using SLURM to submit your models. There are also features for doing a seedSweep, which varies the parameters of the simulation and makes tuning the model easier. An example can be found [here](examples/SLURM/using_BlockRunner.ipynb)

@@ -17,6 +17,9 @@ bmtool.egg-info/dependency_links.txt
  bmtool.egg-info/entry_points.txt
  bmtool.egg-info/requires.txt
  bmtool.egg-info/top_level.txt
+ bmtool/analysis/__init__.py
+ bmtool/analysis/lfp.py
+ bmtool/analysis/spikes.py
  bmtool/debug/__init__.py
  bmtool/debug/commands.py
  bmtool/debug/debug.py
@@ -6,7 +6,7 @@ with open("README.md", "r") as fh:

  setup(
      name="bmtool",
-     version='0.6.7',
+     version='0.6.8',
      author="Neural Engineering Laboratory at the University of Missouri",
      author_email="gregglickert@mail.missouri.edu",
      description="BMTool",
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes