bmtool 0.6.7__tar.gz → 0.6.7.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {bmtool-0.6.7 → bmtool-0.6.7.1}/PKG-INFO +1 -1
  2. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/SLURM.py +10 -14
  3. bmtool-0.6.7.1/bmtool/analysis/lfp.py +408 -0
  4. bmtool-0.6.7.1/bmtool/analysis/spikes.py +181 -0
  5. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/bmplot.py +5 -5
  6. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/singlecell.py +5 -1
  7. bmtool-0.6.7.1/bmtool/util/neuron/__init__.py +0 -0
  8. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool.egg-info/PKG-INFO +1 -1
  9. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool.egg-info/SOURCES.txt +3 -0
  10. {bmtool-0.6.7 → bmtool-0.6.7.1}/setup.py +1 -1
  11. {bmtool-0.6.7 → bmtool-0.6.7.1}/LICENSE +0 -0
  12. {bmtool-0.6.7 → bmtool-0.6.7.1}/README.md +0 -0
  13. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/__init__.py +0 -0
  14. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/__main__.py +0 -0
  15. {bmtool-0.6.7/bmtool/debug → bmtool-0.6.7.1/bmtool/analysis}/__init__.py +0 -0
  16. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/connectors.py +0 -0
  17. {bmtool-0.6.7/bmtool/util → bmtool-0.6.7.1/bmtool/debug}/__init__.py +0 -0
  18. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/debug/commands.py +0 -0
  19. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/debug/debug.py +0 -0
  20. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/graphs.py +0 -0
  21. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/manage.py +0 -0
  22. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/plot_commands.py +0 -0
  23. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/synapses.py +0 -0
  24. {bmtool-0.6.7/bmtool/util/neuron → bmtool-0.6.7.1/bmtool/util}/__init__.py +0 -0
  25. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/util/commands.py +0 -0
  26. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/util/neuron/celltuner.py +0 -0
  27. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/util/util.py +0 -0
  28. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool.egg-info/dependency_links.txt +0 -0
  29. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool.egg-info/entry_points.txt +0 -0
  30. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool.egg-info/requires.txt +0 -0
  31. {bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool.egg-info/top_level.txt +0 -0
  32. {bmtool-0.6.7 → bmtool-0.6.7.1}/setup.cfg +0 -0
{bmtool-0.6.7 → bmtool-0.6.7.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: bmtool
- Version: 0.6.7
+ Version: 0.6.7.1
  Summary: BMTool
  Home-page: https://github.com/cyneuro/bmtool
  Download-URL:
{bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/SLURM.py
@@ -353,17 +353,15 @@ class BlockRunner:
          shutil.copytree(source_dir, destination_dir) # create new components folder
          json_file_path = os.path.join(destination_dir,self.json_file_path)

-         # need to keep the original around
-         syn_dict_temp = copy.deepcopy(self.syn_dict)
-         print(self.syn_dict['json_file_path'])
-         json_to_be_ratioed = syn_dict_temp['json_file_path']
-         corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
-         syn_dict_temp['json_file_path'] = corrected_ratio_path
-
          if self.syn_dict == None:
              json_editor = seedSweep(json_file_path , self.param_name)
              json_editor.edit_json(new_value)
          else:
+             # need to keep the original around
+             syn_dict_temp = copy.deepcopy(self.syn_dict)
+             json_to_be_ratioed = syn_dict_temp['json_file_path']
+             corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
+             syn_dict_temp['json_file_path'] = corrected_ratio_path
              json_editor = multiSeedSweep(json_file_path ,self.param_name,
                                           syn_dict=syn_dict_temp,base_ratio=1)
              json_editor.edit_all_jsons(new_value)
@@ -415,17 +413,15 @@ class BlockRunner:
          shutil.copytree(source_dir, destination_dir) # create new components folder
          json_file_path = os.path.join(destination_dir,self.json_file_path)

-         # need to keep the original around
-         syn_dict_temp = copy.deepcopy(self.syn_dict)
-         print(self.syn_dict['json_file_path'])
-         json_to_be_ratioed = syn_dict_temp['json_file_path']
-         corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
-         syn_dict_temp['json_file_path'] = corrected_ratio_path
-
          if self.syn_dict == None:
              json_editor = seedSweep(json_file_path , self.param_name)
              json_editor.edit_json(new_value)
          else:
+             # need to keep the original around
+             syn_dict_temp = copy.deepcopy(self.syn_dict)
+             json_to_be_ratioed = syn_dict_temp['json_file_path']
+             corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
+             syn_dict_temp['json_file_path'] = corrected_ratio_path
              json_editor = multiSeedSweep(json_file_path ,self.param_name,
                                           syn_dict_temp,base_ratio=1)
              json_editor.edit_all_jsons(new_value)
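Both SLURM.py hunks make the same fix: the deep copy of `self.syn_dict` (and the rewrite of its `json_file_path` entry) now runs only on the `else` branch, so a `BlockRunner` configured without a `syn_dict` no longer touches it; a stray debug `print` is dropped as well. A minimal sketch of the failure mode in the old code path (variable names follow the diff; the surrounding class setup is assumed):

    import copy

    syn_dict = None                          # BlockRunner built without a syn_dict
    syn_dict_temp = copy.deepcopy(syn_dict)  # old code ran this unconditionally
    # syn_dict_temp['json_file_path']        # would raise TypeError:
    #                                        # 'NoneType' object is not subscriptable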
bmtool-0.6.7.1/bmtool/analysis/lfp.py (new file)
@@ -0,0 +1,408 @@
+ """
+ Module for processing BMTK LFP output.
+ """
+
+ import h5py
+ import numpy as np
+ import xarray as xr
+ from fooof import FOOOF
+ from fooof.sim.gen import gen_model
+ import matplotlib.pyplot as plt
+ from scipy import signal
+ import pywt
+ from bmtool.bmplot import is_notebook
+
+
+ def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
+     """
+     Load ECP data from an HDF5 file (BMTK sim) into an xarray DataArray.
+
+     Parameters:
+     ----------
+     ecp_file : str
+         Path to the HDF5 file containing ECP data.
+     demean : bool, optional
+         If True, the mean of the data will be subtracted (default is False).
+
+     Returns:
+     -------
+     xr.DataArray
+         An xarray DataArray containing the ECP data, with time as one dimension
+         and channel_id as another.
+     """
+     with h5py.File(ecp_file, 'r') as f:
+         ecp = xr.DataArray(
+             f['ecp']['data'][()].T,
+             coords=dict(
+                 channel_id=f['ecp']['channel_id'][()],
+                 time=np.arange(*f['ecp']['time'])  # ms
+             ),
+             attrs=dict(
+                 fs=1000 / f['ecp']['time'][2]  # Hz
+             )
+         )
+     if demean:
+         ecp -= ecp.mean(dim='time')
+     return ecp
+
+
+ def ecp_to_lfp(ecp_data: xr.DataArray, cutoff: float = 250, fs: float = 10000,
+                downsample_freq: float = 1000) -> xr.DataArray:
+     """
+     Apply a low-pass Butterworth filter to an xarray DataArray and optionally downsample.
+     This filters out the high-end frequencies, turning the ECP into an LFP.
+
+     Parameters:
+     ----------
+     ecp_data : xr.DataArray
+         The input data array containing LFP data with time as one dimension.
+     cutoff : float
+         The cutoff frequency for the low-pass filter in Hz (default is 250 Hz).
+     fs : float, optional
+         The sampling frequency of the data (default is 10000 Hz).
+     downsample_freq : float, optional
+         The frequency to downsample to (default is 1000 Hz).
+
+     Returns:
+     -------
+     xr.DataArray
+         The filtered (and possibly downsampled) data as an xarray DataArray.
+     """
+     # Low-pass filter design
+     nyq = 0.5 * fs
+     cut = cutoff / nyq
+     b, a = signal.butter(8, cut, btype='low', analog=False)
+
+     # Initialize an array to hold filtered data
+     filtered_data = xr.DataArray(np.zeros_like(ecp_data), coords=ecp_data.coords, dims=ecp_data.dims)
+
+     # Apply the filter to each channel
+     for channel in ecp_data.channel_id:
+         filtered_data.loc[channel, :] = signal.filtfilt(b, a, ecp_data.sel(channel_id=channel).values)
+
+     # Downsample the filtered data if a downsample frequency is provided
+     if downsample_freq is not None:
+         downsample_factor = int(fs / downsample_freq)
+         filtered_data = filtered_data.isel(time=slice(None, None, downsample_factor))
+         # Update the sampling frequency attribute
+         filtered_data.attrs['fs'] = downsample_freq
+
+     return filtered_data
+
+
+ def slice_time_series(data: xr.DataArray, time_ranges: tuple) -> xr.DataArray:
+     """
+     Slice the xarray DataArray based on provided time ranges.
+     Can be used to get the LFP during certain stimulus times.
+
+     Parameters:
+     ----------
+     data : xr.DataArray
+         The input xarray DataArray containing time-series data.
+     time_ranges : tuple or list of tuples
+         One or more tuples representing the (start, stop) time points for slicing.
+         For example: (start, stop) or [(start1, stop1), (start2, stop2)]
+
+     Returns:
+     -------
+     xr.DataArray
+         A new xarray DataArray containing the concatenated slices.
+     """
+     # Ensure time_ranges is a list of tuples
+     if isinstance(time_ranges, tuple) and len(time_ranges) == 2:
+         time_ranges = [time_ranges]
+
+     # List to hold sliced data
+     slices = []
+
+     # Slice the data for each time range
+     for start, stop in time_ranges:
+         sliced_data = data.sel(time=slice(start, stop))
+         slices.append(sliced_data)
+
+     # Concatenate all slices along the time dimension if more than one slice
+     if len(slices) > 1:
+         return xr.concat(slices, dim='time')
+     else:
+         return slices[0]
+
+
+ def fit_fooof(f: np.ndarray, pxx: np.ndarray, aperiodic_mode: str = 'fixed',
+               dB_threshold: float = 3.0, max_n_peaks: int = 10,
+               freq_range: tuple = None, peak_width_limits: tuple = None,
+               report: bool = False, plot: bool = False,
+               plt_log: bool = False, plt_range: tuple = None,
+               figsize: tuple = None, title: str = None) -> tuple:
+     """
+     Fit a FOOOF model to power spectral density data.
+
+     Parameters:
+     ----------
+     f : array-like
+         Frequencies corresponding to the power spectral density data.
+     pxx : array-like
+         Power spectral density data to fit.
+     aperiodic_mode : str, optional
+         The mode for fitting aperiodic components ('fixed' or 'knee', default is 'fixed').
+     dB_threshold : float, optional
+         Minimum peak height in dB (default is 3).
+     max_n_peaks : int, optional
+         Maximum number of peaks to fit (default is 10).
+     freq_range : tuple, optional
+         Frequency range to fit (default is None, which uses the full range).
+     peak_width_limits : tuple, optional
+         Limits on the width of peaks (default is None).
+     report : bool, optional
+         If True, will print fitting results (default is False).
+     plot : bool, optional
+         If True, will plot the fitting results (default is False).
+     plt_log : bool, optional
+         If True, use a logarithmic scale for the y-axis in plots (default is False).
+     plt_range : tuple, optional
+         Range for plotting (default is None).
+     figsize : tuple, optional
+         Size of the figure (default is None).
+     title : str, optional
+         Title for the plot (default is None).
+
+     Returns:
+     -------
+     tuple
+         A tuple containing the fitting results and the FOOOF model object.
+     """
+     if aperiodic_mode != 'knee':
+         aperiodic_mode = 'fixed'
+
+     def set_range(x, upper=f[-1]):
+         x = np.array(upper) if x is None else np.array(x)
+         return [f[2], x.item()] if x.size == 1 else x.tolist()
+
+     freq_range = set_range(freq_range)
+     peak_width_limits = set_range(peak_width_limits, np.inf)
+
+     # Initialize a FOOOF object
+     fm = FOOOF(peak_width_limits=peak_width_limits, min_peak_height=dB_threshold / 10,
+                peak_threshold=0., max_n_peaks=max_n_peaks, aperiodic_mode=aperiodic_mode)
+
+     # Fit the model
+     try:
+         fm.fit(f, pxx, freq_range)
+     except Exception as e:
+         fl = np.linspace(f[0], f[-1], int((f[-1] - f[0]) / np.min(np.diff(f))) + 1)
+         fm.fit(fl, np.interp(fl, f, pxx), freq_range)
+
+     results = fm.get_results()
+
+     if report:
+         fm.print_results()
+         if aperiodic_mode == 'knee':
+             ap_params = results.aperiodic_params
+             if ap_params[1] <= 0:
+                 print('Negative value of knee parameter occurred. Suggestion: Fit without knee parameter.')
+             knee_freq = np.abs(ap_params[1]) ** (1 / ap_params[2])
+             print(f'Knee location: {knee_freq:.2f} Hz')
+
+     if plot:
+         plt_range = set_range(plt_range)
+         fm.plot(plt_log=plt_log)
+         plt.xlim(np.log10(plt_range) if plt_log else plt_range)
+         #plt.ylim(-8, -5.5)
+         if figsize:
+             plt.gcf().set_size_inches(figsize)
+         if title:
+             plt.title(title)
+         if is_notebook():
+             pass
+         else:
+             plt.show()
+
+     return results, fm
+
+
+ def generate_resd_from_fooof(fooof_model: FOOOF) -> tuple:
+     """
+     Generate residuals from a fitted FOOOF model.
+
+     Parameters:
+     ----------
+     fooof_model : FOOOF
+         A fitted FOOOF model object.
+
+     Returns:
+     -------
+     tuple
+         A tuple containing the residual power spectral density and the aperiodic fit.
+     """
+     results = fooof_model.get_results()
+     full_fit, _, ap_fit = gen_model(fooof_model.freqs[1:], results.aperiodic_params,
+                                     results.gaussian_params, return_components=True)
+
+     full_fit, ap_fit = 10 ** full_fit, 10 ** ap_fit  # Convert back from log
+     res_psd = np.insert((10 ** fooof_model.power_spectrum[1:]) - ap_fit, 0, 0.)  # Convert back from log
+     res_fit = np.insert(full_fit - ap_fit, 0, 0.)
+     ap_fit = np.insert(ap_fit, 0, 0.)
+
+     return res_psd, ap_fit
+
+
+ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
+     """
+     Calculate the signal-to-noise ratio (SNR) from a fitted FOOOF model.
+
+     Parameters:
+     ----------
+     fooof_model : FOOOF
+         A fitted FOOOF model object.
+     freq_band : tuple
+         Frequency band (min, max) for SNR calculation.
+
+     Returns:
+     -------
+     float
+         The calculated SNR for the specified frequency band.
+     """
+     periodic, ap = generate_resd_from_fooof(fooof_model)
+     freq = fooof_model.freqs  # Get frequencies from model
+     indices = (freq >= freq_band[0]) & (freq <= freq_band[1])  # Get only the band we care about
+     band_periodic = periodic[indices]  # Filter based on band
+     band_ap = ap[indices]  # Filter
+     band_freq = freq[indices]  # Another filter
+     periodic_power = np.trapz(band_periodic, band_freq)  # Integrate periodic power
+     ap_power = np.trapz(band_ap, band_freq)  # Integrate aperiodic power
+     normalized_power = periodic_power / ap_power  # Compute the SNR
+     return normalized_power
+
+
+ def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1) -> np.ndarray:
+     """
+     Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
+     """
+     wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
+     scale = pywt.scale2frequency(wavelet, 1) * fs / freq
+     x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
+     return x_a
+
+
+ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs: float, order: int = 5, axis: int = -1) -> np.ndarray:
+     """
+     Apply a Butterworth bandpass filter to the input data.
+     """
+     sos = signal.butter(order, [lowcut, highcut], fs=fs, btype='band', output='sos')
+     x_a = signal.sosfiltfilt(sos, data, axis=axis)
+     return x_a
+
+
+ def calculate_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
+                   method: str = 'wavelet', lowcut: float = None, highcut: float = None,
+                   bandwidth: float = 2.0) -> np.ndarray:
+     """
+     Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
+
+     Parameters:
+     - x1, x2: Input signals (1D arrays, same length)
+     - fs: Sampling frequency
+     - freq_of_interest: Desired frequency for wavelet PLV calculation
+     - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
+     - lowcut, highcut: Cutoff frequencies for the Hilbert method
+     - bandwidth: Bandwidth parameter for the wavelet
+
+     Returns:
+     - plv: Phase Locking Value (1D array)
+     """
+     if len(x1) != len(x2):
+         raise ValueError("Input signals must have the same length.")
+
+     if method == 'wavelet':
+         if freq_of_interest is None:
+             raise ValueError("freq_of_interest must be provided for the wavelet method.")
+
+         # Apply CWT to both signals
+         theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+         theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+
+     elif method == 'hilbert':
+         if lowcut is None or highcut is None:
+             print("lowcut and/or highcut were not defined; the signal will not be filtered, and the Hilbert transform will be applied directly for the PLV calculation")
+
+         if lowcut and highcut:
+             # Bandpass filter and get the analytic signal using the Hilbert transform
+             x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
+             x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)
+
+         # Get phase using the Hilbert transform
+         theta1 = signal.hilbert(x1)
+         theta2 = signal.hilbert(x2)
+
+     else:
+         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
+
+     # Calculate phase difference
+     phase_diff = np.angle(theta1) - np.angle(theta2)
+
+     # Calculate PLV using the standard equation from 'Measuring phase synchrony in brain signals' (1999)
+     plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
+
+     return plv
+
+
+ def calculate_plv_over_time(x1: np.ndarray, x2: np.ndarray, fs: float,
+                             window_size: float, step_size: float,
+                             method: str = 'wavelet', freq_of_interest: float = None,
+                             lowcut: float = None, highcut: float = None,
+                             bandwidth: float = 2.0):
+     """
+     Calculate the time-resolved Phase Locking Value (PLV) between two signals using a sliding window approach.
+
+     Parameters:
+     ----------
+     x1, x2 : array-like
+         Input signals (1D arrays, same length).
+     fs : float
+         Sampling frequency of the input signals.
+     window_size : float
+         Length of the window in seconds for PLV calculation.
+     step_size : float
+         Step size in seconds to slide the window across the signals.
+     method : str, optional
+         Method to calculate PLV ('wavelet' or 'hilbert'). Defaults to 'wavelet'.
+     freq_of_interest : float, optional
+         Frequency of interest for the wavelet method. Required if method is 'wavelet'.
+     lowcut, highcut : float, optional
+         Cutoff frequencies for the Hilbert method. Required if method is 'hilbert'.
+     bandwidth : float, optional
+         Bandwidth parameter for the wavelet. Defaults to 2.0.
+
+     Returns:
+     -------
+     plv_over_time : 1D array
+         Array of PLV values calculated over each window.
+     times : 1D array
+         The center times of each window where the PLV was calculated.
+     """
+     # Convert window and step size from seconds to samples
+     window_samples = int(window_size * fs)
+     step_samples = int(step_size * fs)
+
+     # Initialize results
+     plv_over_time = []
+     times = []
+
+     # Iterate over the signal with a sliding window
+     for start in range(0, len(x1) - window_samples + 1, step_samples):
+         end = start + window_samples
+         window_x1 = x1[start:end]
+         window_x2 = x2[start:end]
+
+         # Use the updated calculate_plv function within each window
+         plv = calculate_plv(x1=window_x1, x2=window_x2, fs=fs,
+                             method=method, freq_of_interest=freq_of_interest,
+                             lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
+         plv_over_time.append(plv)
+
+         # Store the time at the center of the window
+         center_time = (start + end) / 2 / fs
+         times.append(center_time)
+
+     return np.array(plv_over_time), np.array(times)
+
+
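The new analysis/lfp.py module is meant to be chained: load the extracellular potential, low-pass it into an LFP, slice out a stimulus window, then run spectral or phase analyses (the PLV functions implement |mean(exp(i*Δφ))| over the phase difference). A minimal usage sketch; the file path, time window, channel ids, and the 50 Hz frequency of interest are all hypothetical:

    from bmtool.analysis.lfp import (load_ecp_to_xarray, ecp_to_lfp,
                                     slice_time_series, calculate_plv)

    ecp = load_ecp_to_xarray('output/ecp.h5', demean=True)    # hypothetical path
    lfp = ecp_to_lfp(ecp, cutoff=250, fs=ecp.attrs['fs'], downsample_freq=1000)
    stim = slice_time_series(lfp, (500, 1500))                # ms window, assumed
    plv = calculate_plv(stim.sel(channel_id=0).values,        # channel ids assumed
                        stim.sel(channel_id=1).values,
                        fs=stim.attrs['fs'], method='wavelet',
                        freq_of_interest=50)                  # Hz, hypothetical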
bmtool-0.6.7.1/bmtool/analysis/spikes.py (new file)
@@ -0,0 +1,181 @@
+ """
+ Module for processing BMTK spikes output.
+ """
+
+ import h5py
+ import pandas as pd
+ from bmtool.util.util import load_nodes_from_config
+ from typing import Dict, Optional, Tuple, Union
+ import numpy as np
+ import os
+
+
+ def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
+     """
+     Load spike data from an HDF5 file into a pandas DataFrame.
+
+     Args:
+         spike_file (str): Path to the HDF5 file containing spike data.
+         network_name (str): The name of the network within the HDF5 file from which to load spike data.
+         sort (bool, optional): Whether to sort the DataFrame by 'timestamps'. Defaults to True.
+         config (str, optional): Path to a config file; if provided, used to label the cell type of each spike.
+         groupby (str or list of str, optional): The column(s) to group by. Defaults to 'pop_name'.
+
+     Returns:
+         pd.DataFrame: A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data.
+
+     Example:
+         df = load_spikes_to_df("spikes.h5", "cortex")
+     """
+     with h5py.File(spike_file) as f:
+         spikes_df = pd.DataFrame({
+             'node_ids': f['spikes'][network_name]['node_ids'],
+             'timestamps': f['spikes'][network_name]['timestamps']
+         })
+
+     if sort:
+         spikes_df.sort_values(by='timestamps', inplace=True, ignore_index=True)
+
+     if config:
+         nodes = load_nodes_from_config(config)
+         nodes = nodes[network_name]
+
+         # Check if 'groupby' is a string or a list of strings and handle accordingly
+         if isinstance(groupby, str):
+             spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
+         elif isinstance(groupby, list):
+             for group in groupby:
+                 spikes_df = spikes_df.merge(nodes[group], left_on='node_ids', right_index=True, how='left')
+
+     return spikes_df
+
+
+
+ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
+                     time_points: Optional[Union[np.ndarray, list]] = None, frequeny: bool = False) -> np.ndarray:
+     """
+     Calculate the spike count or frequency histogram over specified time intervals.
+
+     Args:
+         spike_times (Union[np.ndarray, list]): Array or list of spike times in milliseconds.
+         time (Optional[Tuple[float, float, float]], optional): Tuple specifying (start, stop, step) in milliseconds.
+             Used to create evenly spaced time points if `time_points` is not provided. Default is None.
+         time_points (Optional[Union[np.ndarray, list]], optional): Array or list of specific time points for binning.
+             If provided, `time` is ignored. Default is None.
+         frequeny (bool, optional): If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
+
+     Returns:
+         np.ndarray: Array of spike counts or frequencies, depending on the `frequeny` flag.
+
+     Raises:
+         ValueError: If both `time` and `time_points` are None.
+     """
+     if time_points is None:
+         if time is None:
+             raise ValueError("Either `time` or `time_points` must be provided.")
+         time_points = np.arange(*time)
+         dt = time[2]
+     else:
+         time_points = np.asarray(time_points).ravel()
+         dt = (time_points[-1] - time_points[0]) / (time_points.size - 1)
+
+     bins = np.append(time_points, time_points[-1] + dt)
+     spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
+
+     if frequeny:
+         spike_rate = 1000 / dt * spike_rate
+
+     return spike_rate
+
+
+ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
+                               config: Optional[str] = None, network_name: Optional[str] = None,
+                               save: bool = False, save_path: Optional[str] = None,
+                               normalize: bool = False) -> Dict[str, np.ndarray]:
+     """
+     Calculate the population spike rate for each population in the given spike data, with an option to normalize.
+
+     Args:
+         spikes (pd.DataFrame): A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'.
+         fs (float, optional): Sampling frequency in Hz, which determines the time bin size for calculating the spike rate. Default is 400.
+         t_start (float, optional): Start time (in milliseconds) for spike rate calculation. Default is 0.
+         t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data.
+         config (Optional[str], optional): Path to a configuration file containing node information, used to determine the correct number of nodes per population.
+             If None, node count is estimated from unique node spikes. Default is None.
+         network_name (Optional[str], optional): Name of the network used in the configuration file, allowing selection of nodes for that network.
+             Required if `config` is provided. Default is None.
+         save (bool, optional): Whether to save the calculated population spike rate to a file. Default is False.
+         save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. If `save` is True and `save_path` is None, a ValueError is raised.
+         normalize (bool, optional): Whether to normalize the spike rates for each population to a range of [0, 1]. Default is False.
+
+     Returns:
+         Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
+             If `normalize` is True, each population's spike rate is scaled to [0, 1].
+
+     Raises:
+         ValueError: If `save` is True but `save_path` is not provided.
+
+     Notes:
+         - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
+         - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
+
+     """
+     pop_spikes = {}
+     node_number = {}
+
+     if config is None:
+         print("Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect.")
+         print("You can provide a config to calculate the correct number of nodes!")
+
+     if config:
+         if not network_name:
+             print("Grabbing first network; specify a network name to ensure correct node population is selected.")
+
+     for pop_name in spikes['pop_name'].unique():
+         ps = spikes[spikes['pop_name'] == pop_name]
+
+         if config:
+             nodes = load_nodes_from_config(config)
+             if network_name:
+                 nodes = nodes[network_name]
+             else:
+                 nodes = list(nodes.values())[0] if nodes else {}
+             nodes = nodes[nodes['pop_name'] == pop_name]
+             node_number[pop_name] = nodes.index.nunique()
+         else:
+             node_number[pop_name] = ps['node_ids'].nunique()
+
+         if t_stop is None:
+             t_stop = spikes['timestamps'].max()
+
+         filtered_spikes = spikes[
+             (spikes['pop_name'] == pop_name) &
+             (spikes['timestamps'] > t_start) &
+             (spikes['timestamps'] < t_stop)
+         ]
+         pop_spikes[pop_name] = filtered_spikes
+
+     time = np.array([t_start, t_stop, 1000 / fs])
+     pop_rspk = {p: _pop_spike_rate(spk['timestamps'], time) for p, spk in pop_spikes.items()}
+     spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}
+
+     # Normalize each spike rate series if normalize=True
+     if normalize:
+         spike_rate = {p: (sr - sr.min()) / (sr.max() - sr.min()) for p, sr in spike_rate.items()}
+
+     if save:
+         if save_path is None:
+             raise ValueError("save_path must be provided if save is True.")
+
+         os.makedirs(save_path, exist_ok=True)
+
+         save_file = os.path.join(save_path, 'spike_rate.h5')
+         with h5py.File(save_file, 'w') as f:
+             f.create_dataset('time', data=time)
+             grp = f.create_group('populations')
+             for p, rspk in spike_rate.items():
+                 pop_grp = grp.create_group(p)
+                 pop_grp.create_dataset('data', data=rspk)
+
+     return spike_rate
+
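A sketch of how the new analysis/spikes.py functions fit together, based only on the signatures above; the file names and network name are hypothetical:

    from bmtool.analysis.spikes import load_spikes_to_df, get_population_spike_rate

    spikes = load_spikes_to_df('output/spikes.h5', 'cortex',    # hypothetical paths
                               config='simulation_config.json')
    rates = get_population_spike_rate(spikes, fs=400,
                                      config='simulation_config.json',
                                      network_name='cortex', normalize=True)
    # rates maps each pop_name to its spike-rate array over time

Passing the config in both calls lets get_population_spike_rate count nodes that never fired; otherwise, as its docstring warns, the node count is estimated from unique spiking nodes.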
{bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/bmplot.py
@@ -762,7 +762,7 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
      plt.tight_layout()
      plt.show()

- def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None,
+ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None, groupby:Optional[str] = 'pop_name',
             ax: Optional[Axes] = None,tstart: Optional[float] = None,tstop: Optional[float] = None,
             color_map: Optional[Dict[str, str]] = None) -> Axes:
      """
@@ -793,7 +793,7 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
      Notes:
      -----
      - If `config` is provided, the function merges population names from the node data with `spikes_df`.
-     - Each unique population (`pop_name`) in `spikes_df` will be represented by a different color if `color_map` is not specified.
+     - Each unique population from groupby in `spikes_df` will be represented by a different color if `color_map` is not specified.
      - If `color_map` is provided, it should contain colors for all unique `pop_name` values in `spikes_df`.
      """
      # Initialize axes if none provided
@@ -822,11 +822,11 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
          # Drop all intersecting columns except the join key column from df2
          spikes_df = spikes_df.drop(columns=common_columns)
          # merge nodes and spikes df
-         spikes_df = spikes_df.merge(nodes['pop_name'], left_on='node_ids', right_index=True, how='left')
+         spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')


      # Get unique population names
-     unique_pop_names = spikes_df['pop_name'].unique()
+     unique_pop_names = spikes_df[groupby].unique()

      # Generate colors if no color_map is provided
      if color_map is None:
@@ -839,7 +839,7 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
          raise ValueError(f"color_map is missing colors for populations: {missing_colors}")

      # Plot each population with its specified or generated color
-     for pop_name, group in spikes_df.groupby('pop_name'):
+     for pop_name, group in spikes_df.groupby(groupby):
          ax.scatter(group['timestamps'], group['node_ids'], label=pop_name, color=color_map[pop_name], s=0.5)

      # Label axes
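These bmplot.py hunks thread a new `groupby` argument through `raster`, so spikes can be colored by any node column rather than only `pop_name`. A sketch (paths hypothetical, and the `model_name` column is assumed to exist in the nodes table):

    from bmtool.analysis.spikes import load_spikes_to_df
    from bmtool.bmplot import raster

    spikes = load_spikes_to_df('output/spikes.h5', 'cortex')
    ax = raster(spikes_df=spikes, config='simulation_config.json',
                network_name='cortex', groupby='model_name')

The default `groupby='pop_name'` keeps the old behavior.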
{bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool/singlecell.py
@@ -90,7 +90,11 @@ class CurrentClamp(object):
          self.inj_dur = inj_dur
          self.inj_amp = inj_amp * 1e-3  # pA to nA

-         self.cell = self.create_cell()
+         # sometimes people may put a hoc object in for the template name
+         if callable(template_name):
+             self.cell = template_name
+         else:
+             self.cell = self.create_cell()
          if post_init_function:
              eval(f"self.cell.{post_init_function}")
 
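The singlecell.py change lets `CurrentClamp` accept an already-built cell object in place of a template-name string, dispatching on `callable(template_name)`. A sketch under those assumptions; the HOC file and `MyCell` template are hypothetical, other constructor arguments are elided, and the passed object must test as callable for the new branch to take effect:

    from neuron import h
    from bmtool.singlecell import CurrentClamp

    h.load_file('my_templates.hoc')   # hypothetical template file
    cell = h.MyCell()                 # hypothetical template instance
    clamp = CurrentClamp(cell, inj_amp=100, inj_dur=500)  # object passed directly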
{bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: bmtool
- Version: 0.6.7
+ Version: 0.6.7.1
  Summary: BMTool
  Home-page: https://github.com/cyneuro/bmtool
  Download-URL:
{bmtool-0.6.7 → bmtool-0.6.7.1}/bmtool.egg-info/SOURCES.txt
@@ -17,6 +17,9 @@ bmtool.egg-info/dependency_links.txt
  bmtool.egg-info/entry_points.txt
  bmtool.egg-info/requires.txt
  bmtool.egg-info/top_level.txt
+ bmtool/analysis/__init__.py
+ bmtool/analysis/lfp.py
+ bmtool/analysis/spikes.py
  bmtool/debug/__init__.py
  bmtool/debug/commands.py
  bmtool/debug/debug.py
{bmtool-0.6.7 → bmtool-0.6.7.1}/setup.py
@@ -6,7 +6,7 @@ with open("README.md", "r") as fh:
  setup(
      name="bmtool",
-     version='0.6.7',
+     version='0.6.7.1',
      author="Neural Engineering Laboratory at the University of Missouri",
      author_email="gregglickert@mail.missouri.edu",
      description="BMTool",