bmtool 0.6.7__py3-none-any.whl → 0.6.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/SLURM.py +10 -14
- bmtool/analysis/__init__.py +0 -0
- bmtool/analysis/lfp.py +408 -0
- bmtool/analysis/spikes.py +181 -0
- bmtool/bmplot.py +5 -5
- bmtool/singlecell.py +5 -1
- {bmtool-0.6.7.dist-info → bmtool-0.6.7.1.dist-info}/METADATA +1 -1
- {bmtool-0.6.7.dist-info → bmtool-0.6.7.1.dist-info}/RECORD +12 -9
- {bmtool-0.6.7.dist-info → bmtool-0.6.7.1.dist-info}/WHEEL +1 -1
- {bmtool-0.6.7.dist-info → bmtool-0.6.7.1.dist-info}/LICENSE +0 -0
- {bmtool-0.6.7.dist-info → bmtool-0.6.7.1.dist-info}/entry_points.txt +0 -0
- {bmtool-0.6.7.dist-info → bmtool-0.6.7.1.dist-info}/top_level.txt +0 -0
bmtool/SLURM.py
CHANGED
@@ -353,17 +353,15 @@ class BlockRunner:
|
|
353
353
|
shutil.copytree(source_dir, destination_dir) # create new components folder
|
354
354
|
json_file_path = os.path.join(destination_dir,self.json_file_path)
|
355
355
|
|
356
|
-
# need to keep the orignal around
|
357
|
-
syn_dict_temp = copy.deepcopy(self.syn_dict)
|
358
|
-
print(self.syn_dict['json_file_path'])
|
359
|
-
json_to_be_ratioed = syn_dict_temp['json_file_path']
|
360
|
-
corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
|
361
|
-
syn_dict_temp['json_file_path'] = corrected_ratio_path
|
362
|
-
|
363
356
|
if self.syn_dict == None:
|
364
357
|
json_editor = seedSweep(json_file_path , self.param_name)
|
365
358
|
json_editor.edit_json(new_value)
|
366
359
|
else:
|
360
|
+
# need to keep the orignal around
|
361
|
+
syn_dict_temp = copy.deepcopy(self.syn_dict)
|
362
|
+
json_to_be_ratioed = syn_dict_temp['json_file_path']
|
363
|
+
corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
|
364
|
+
syn_dict_temp['json_file_path'] = corrected_ratio_path
|
367
365
|
json_editor = multiSeedSweep(json_file_path ,self.param_name,
|
368
366
|
syn_dict=syn_dict_temp,base_ratio=1)
|
369
367
|
json_editor.edit_all_jsons(new_value)
|
@@ -415,17 +413,15 @@ class BlockRunner:
|
|
415
413
|
shutil.copytree(source_dir, destination_dir) # create new components folder
|
416
414
|
json_file_path = os.path.join(destination_dir,self.json_file_path)
|
417
415
|
|
418
|
-
# need to keep the orignal around
|
419
|
-
syn_dict_temp = copy.deepcopy(self.syn_dict)
|
420
|
-
print(self.syn_dict['json_file_path'])
|
421
|
-
json_to_be_ratioed = syn_dict_temp['json_file_path']
|
422
|
-
corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
|
423
|
-
syn_dict_temp['json_file_path'] = corrected_ratio_path
|
424
|
-
|
425
416
|
if self.syn_dict == None:
|
426
417
|
json_editor = seedSweep(json_file_path , self.param_name)
|
427
418
|
json_editor.edit_json(new_value)
|
428
419
|
else:
|
420
|
+
# need to keep the orignal around
|
421
|
+
syn_dict_temp = copy.deepcopy(self.syn_dict)
|
422
|
+
json_to_be_ratioed = syn_dict_temp['json_file_path']
|
423
|
+
corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
|
424
|
+
syn_dict_temp['json_file_path'] = corrected_ratio_path
|
429
425
|
json_editor = multiSeedSweep(json_file_path ,self.param_name,
|
430
426
|
syn_dict_temp,base_ratio=1)
|
431
427
|
json_editor.edit_all_jsons(new_value)
|
File without changes
|
bmtool/analysis/lfp.py
ADDED
@@ -0,0 +1,408 @@
|
|
1
|
+
"""
|
2
|
+
Module for processing BMTK LFP output.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import h5py
|
6
|
+
import numpy as np
|
7
|
+
import xarray as xr
|
8
|
+
from fooof import FOOOF
|
9
|
+
from fooof.sim.gen import gen_model
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
from scipy import signal
|
12
|
+
import pywt
|
13
|
+
from bmtool.bmplot import is_notebook
|
14
|
+
|
15
|
+
|
16
|
+
def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
    """
    Load ECP data from an HDF5 file (BMTK sim) into an xarray DataArray.

    Parameters:
    ----------
    ecp_file : str
        Path to the HDF5 file. The datasets ``ecp/data``, ``ecp/channel_id``
        and ``ecp/time`` are read from it.
    demean : bool, optional
        If True, the mean over the time dimension is subtracted (default is False).

    Returns:
    -------
    xr.DataArray
        The transposed ``ecp/data`` array with dimensions ``('channel_id', 'time')``.
        The ``time`` coordinate is ``np.arange(*ecp/time)`` and ``attrs['fs']``
        is ``1000 / ecp/time[2]``.
    """
    with h5py.File(ecp_file, 'r') as f:
        ecp_group = f['ecp']
        # Read the time specification once; it is unpacked into np.arange below.
        # NOTE(review): presumably (start, stop, step) in ms -- confirm against the BMTK writer.
        time_spec = ecp_group['time'][()]
        ecp = xr.DataArray(
            ecp_group['data'][()].T,
            # Name the dimensions explicitly instead of relying on the key order of `coords`.
            dims=('channel_id', 'time'),
            coords=dict(
                channel_id=ecp_group['channel_id'][()],
                time=np.arange(*time_spec)  # ms
            ),
            attrs=dict(
                fs=1000 / time_spec[2]  # Hz
            )
        )
    if demean:
        # Subtract the temporal mean of every channel in place.
        ecp -= ecp.mean(dim='time')
    return ecp
|
47
|
+
|
48
|
+
|
49
|
+
def ecp_to_lfp(ecp_data: xr.DataArray, cutoff: float = 250, fs: float = 10000,
               downsample_freq: float = 1000) -> xr.DataArray:
    """
    Apply a low-pass Butterworth filter to an xarray DataArray and optionally downsample.
    This filters out the high end frequencies turning the ECP into a LFP

    Parameters:
    ----------
    ecp_data : xr.DataArray
        The input data array containing ECP data. It must have a ``time`` dimension;
        the filter is applied along that dimension.
    cutoff : float
        The cutoff frequency for the low-pass filter in Hz (default is 250Hz).
    fs : float, optional
        The sampling frequency of the data (default is 10000 Hz).
    downsample_freq : float, optional
        The frequency to downsample to (default is 1000 Hz). Pass None to skip
        downsampling.

    Returns:
    -------
    xr.DataArray
        The filtered (and possibly downsampled) data. The attributes of ``ecp_data``
        are carried over and ``attrs['fs']`` is set to ``fs``, or to
        ``downsample_freq`` when downsampling is applied.

    Raises:
    ------
    ValueError
        If ``downsample_freq`` is larger than ``fs``, so that the integer
        downsampling factor would be smaller than 1.
    """
    # Low-pass filter design; the cutoff is expressed as a fraction of the Nyquist frequency
    nyq = 0.5 * fs
    cut = cutoff / nyq
    b, a = signal.butter(8, cut, btype='low', analog=False)

    # Zero-phase filter every channel in one call along the time axis, wherever it sits
    time_axis = ecp_data.get_axis_num('time')
    filtered_values = signal.filtfilt(b, a, ecp_data.values, axis=time_axis)

    # Keep the input's coordinates and attributes so metadata such as `fs` is not lost
    filtered_data = xr.DataArray(filtered_values, coords=ecp_data.coords, dims=ecp_data.dims,
                                 attrs=dict(ecp_data.attrs))
    filtered_data.attrs['fs'] = fs

    # Downsample the filtered data if a downsample frequency is provided
    if downsample_freq is not None:
        downsample_factor = int(fs / downsample_freq)
        if downsample_factor < 1:
            raise ValueError("downsample_freq must not be larger than fs.")
        # Keep every `downsample_factor`-th sample along time
        filtered_data = filtered_data.isel(time=slice(None, None, downsample_factor))
        # Update the sampling frequency attribute
        filtered_data.attrs['fs'] = downsample_freq

    return filtered_data
|
91
|
+
|
92
|
+
|
93
|
+
def slice_time_series(data: xr.DataArray, time_ranges: tuple) -> xr.DataArray:
    """
    Slice the xarray DataArray based on provided time ranges.
    Can be used to get LFP during certain stimulus times

    Parameters:
    ----------
    data : xr.DataArray
        The input xarray DataArray containing time-series data with a ``time`` coordinate.
    time_ranges : tuple or list of tuples
        One or more (start, stop) pairs used for label-based slicing along ``time``.
        For example: (start, stop) or [(start1, stop1), (start2, stop2)]

    Returns:
    -------
    xr.DataArray
        The selected slice, or the slices concatenated along ``time`` when more
        than one range is given.

    Raises:
    ------
    ValueError
        If ``time_ranges`` is empty.
    """
    if len(time_ranges) == 0:
        raise ValueError("time_ranges must contain at least one (start, stop) pair.")

    # A single (start, stop) pair is recognised by its two scalar entries, so that a
    # sequence of exactly two ranges, e.g. ((s1, e1), (s2, e2)), is not mistaken for one.
    if len(time_ranges) == 2 and all(np.ndim(bound) == 0 for bound in time_ranges):
        time_ranges = [time_ranges]

    # Slice the data for each time range
    slices = [data.sel(time=slice(start, stop)) for start, stop in time_ranges]

    # Concatenate along the time dimension only when there is more than one slice
    if len(slices) > 1:
        return xr.concat(slices, dim='time')
    return slices[0]
|
128
|
+
|
129
|
+
|
130
|
+
def fit_fooof(f: np.ndarray, pxx: np.ndarray, aperiodic_mode: str = 'fixed',
              dB_threshold: float = 3.0, max_n_peaks: int = 10,
              freq_range: tuple = None, peak_width_limits: tuple = None,
              report: bool = False, plot: bool = False,
              plt_log: bool = False, plt_range: tuple = None,
              figsize: tuple = None, title: str = None) -> tuple:
    """
    Fit a FOOOF model to power spectral density data.

    Parameters:
    ----------
    f : array-like
        Frequencies corresponding to the power spectral density data.
    pxx : array-like
        Power spectral density data to fit.
    aperiodic_mode : str, optional
        'knee' selects the knee model; any other value is treated as 'fixed'
        (default is 'fixed').
    dB_threshold : float, optional
        Minimum peak height in dB; passed to FOOOF as ``dB_threshold / 10``
        (default is 3).
    max_n_peaks : int, optional
        Maximum number of peaks to fit (default is 10).
    freq_range : tuple, optional
        Frequency range to fit. A single value is used as the upper bound with
        ``f[2]`` as the lower bound; None means ``[f[2], f[-1]]``.
    peak_width_limits : tuple, optional
        Limits on the width of peaks. A single value is used as the upper limit
        with ``f[2]`` as the lower limit; None means ``[f[2], inf]``.
    report : bool, optional
        If True, print the fitting results and, in knee mode, the knee
        location (default is False).
    plot : bool, optional
        If True, plot the fitting results (default is False).
    plt_log : bool, optional
        Forwarded to ``fm.plot(plt_log=...)``; when True the x-limits are given
        in log10 of ``plt_range`` (default is False).
    plt_range : tuple, optional
        x-axis range for plotting, resolved like ``freq_range`` (default is None).
    figsize : tuple, optional
        Size of the figure in inches (default is None).
    title : str, optional
        Title for the plot (default is None).

    Returns:
    -------
    tuple
        ``(results, fm)``: the output of ``fm.get_results()`` and the FOOOF model object.
    """
    # Anything other than 'knee' falls back to the fixed aperiodic model
    aperiodic_mode = 'knee' if aperiodic_mode == 'knee' else 'fixed'

    def set_range(x, upper=f[-1]):
        # Turn None / scalar / pair into a two-element range
        bounds = np.array(x if x is not None else upper)
        if bounds.size == 1:
            return [f[2], bounds.item()]
        return bounds.tolist()

    freq_range = set_range(freq_range)
    peak_width_limits = set_range(peak_width_limits, np.inf)

    # Build the model
    fm = FOOOF(
        peak_width_limits=peak_width_limits,
        min_peak_height=dB_threshold / 10,
        peak_threshold=0.,
        max_n_peaks=max_n_peaks,
        aperiodic_mode=aperiodic_mode,
    )

    # Fit; if fitting on the given frequencies fails, retry on a uniform grid
    # spaced by the smallest frequency step, with the PSD linearly interpolated.
    try:
        fm.fit(f, pxx, freq_range)
    except Exception:
        n_points = int((f[-1] - f[0]) / np.min(np.diff(f))) + 1
        fl = np.linspace(f[0], f[-1], n_points)
        fm.fit(fl, np.interp(fl, f, pxx), freq_range)

    results = fm.get_results()

    if report:
        fm.print_results()
        if aperiodic_mode == 'knee':
            ap_params = results.aperiodic_params
            knee, exponent = ap_params[1], ap_params[2]
            if knee <= 0:
                print('Negative value of knee parameter occurred. Suggestion: Fit without knee parameter.')
            knee_freq = np.abs(knee) ** (1 / exponent)
            print(f'Knee location: {knee_freq:.2f} Hz')

    if plot:
        plt_range = set_range(plt_range)
        fm.plot(plt_log=plt_log)
        x_limits = np.log10(plt_range) if plt_log else plt_range
        plt.xlim(x_limits)
        if figsize:
            plt.gcf().set_size_inches(figsize)
        if title:
            plt.title(title)
        # Outside a notebook the figure has to be shown explicitly
        if not is_notebook():
            plt.show()

    return results, fm
|
220
|
+
|
221
|
+
|
222
|
+
def generate_resd_from_fooof(fooof_model: FOOOF) -> tuple:
    """
    Generate residuals from a fitted FOOOF model.

    Parameters:
    ----------
    fooof_model : FOOOF
        A fitted FOOOF model object.

    Returns:
    -------
    tuple
        ``(res_psd, ap_fit)``: the power spectrum minus the aperiodic fit, and
        the aperiodic fit itself, both in linear power. Both are evaluated on
        ``fooof_model.freqs[1:]`` and have a 0 inserted at index 0, so they have
        the same length as ``fooof_model.freqs``.
    """
    results = fooof_model.get_results()
    # Only the aperiodic component of the generated model is used here
    _, _, ap_fit = gen_model(fooof_model.freqs[1:], results.aperiodic_params,
                             results.gaussian_params, return_components=True)

    ap_fit = 10 ** ap_fit  # Convert back from log
    # Residual = measured power (converted back from log) minus aperiodic fit;
    # the skipped first frequency bin is filled with 0.
    res_psd = np.insert((10 ** fooof_model.power_spectrum[1:]) - ap_fit, 0, 0.)
    ap_fit = np.insert(ap_fit, 0, 0.)

    return res_psd, ap_fit
|
246
|
+
|
247
|
+
|
248
|
+
def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
    """
    Calculate the signal-to-noise ratio (SNR) from a fitted FOOOF model.

    The SNR is the integral of the residual (periodic) power over the band
    divided by the integral of the aperiodic power over the same band.

    Parameters:
    ----------
    fooof_model : FOOOF
        A fitted FOOOF model object.
    freq_band : tuple
        Frequency band (min, max) for SNR calculation; both ends are inclusive.

    Returns:
    -------
    float
        The calculated SNR for the specified frequency band.
    """
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid; use whichever exists
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    periodic, ap = generate_resd_from_fooof(fooof_model)
    freq = fooof_model.freqs  # Frequencies the model was fitted on
    in_band = (freq >= freq_band[0]) & (freq <= freq_band[1])  # Keep only the band we care about
    band_freq = freq[in_band]

    periodic_power = integrate(periodic[in_band], band_freq)  # Integrate periodic power
    ap_power = integrate(ap[in_band], band_freq)  # Integrate aperiodic power
    return periodic_power / ap_power  # Compute the SNR
|
274
|
+
|
275
|
+
|
276
|
+
def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1) -> np.ndarray:
    """
    Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.

    Parameters:
    ----------
    x : np.ndarray
        Input signal.
    freq : float
        Frequency at which the transform is evaluated.
    fs : float
        Sampling frequency of ``x``.
    bandwidth : float, optional
        Used to build the wavelet name ``cmor{2 * bandwidth ** 2}-1.0`` (default is 1.0).
    axis : int, optional
        Axis of ``x`` along which the transform is computed (default is -1).

    Returns:
    -------
    np.ndarray
        The CWT coefficients for the single requested scale.
    """
    wavelet_name = f'cmor{2 * bandwidth ** 2}-1.0'
    # Frequency of the wavelet at scale 1, used to find the scale matching `freq`
    unit_scale_frequency = pywt.scale2frequency(wavelet_name, 1)
    target_scale = unit_scale_frequency * fs / freq
    coefficients, _ = pywt.cwt(x, [target_scale], wavelet=wavelet_name, axis=axis)
    # Only one scale was requested, so return its coefficients directly
    return coefficients[0]
|
284
|
+
|
285
|
+
|
286
|
+
def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs: float, order: int = 5, axis: int = -1) -> np.ndarray:
|
287
|
+
"""
|
288
|
+
Apply a Butterworth bandpass filter to the input data.
|
289
|
+
"""
|
290
|
+
sos = signal.butter(order, [lowcut, highcut], fs=fs, btype='band', output='sos')
|
291
|
+
x_a = signal.sosfiltfilt(sos, data, axis=axis)
|
292
|
+
return x_a
|
293
|
+
|
294
|
+
|
295
|
+
def calculate_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
                  method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                  bandwidth: float = 2.0) -> np.ndarray:
    """
    Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.

    Parameters:
    - x1, x2: Input signals (same length)
    - fs: Sampling frequency
    - freq_of_interest: Frequency at which the wavelet transform is evaluated
      (required for the wavelet method)
    - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
    - lowcut, highcut: Band-pass cutoff frequencies for the Hilbert method; the
      signals are only filtered when both are given and non-zero
    - bandwidth: Bandwidth parameter for the wavelet

    Returns:
    - plv: magnitude of the mean of exp(1j * phase difference), taken over the last axis

    Raises:
    - ValueError: if the signals differ in length, if `method` is not 'wavelet'
      or 'hilbert', or if the wavelet method is used without `freq_of_interest`
    """
    if len(x1) != len(x2):
        raise ValueError("Input signals must have the same length.")

    if method not in ('wavelet', 'hilbert'):
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        # Complex wavelet coefficients of both signals at the frequency of interest
        analytic1, analytic2 = (
            wavelet_filter(x=sig, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
            for sig in (x1, x2)
        )
    else:
        if lowcut is None or highcut is None:
            print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")

        if lowcut and highcut:
            # Restrict both signals to the band before extracting the phase
            x1, x2 = (butter_bandpass_filter(sig, lowcut, highcut, fs) for sig in (x1, x2))

        # Analytic signals via the Hilbert transform
        analytic1 = signal.hilbert(x1)
        analytic2 = signal.hilbert(x2)

    # Instantaneous phase difference between the two signals
    relative_phase = np.angle(analytic1) - np.angle(analytic2)

    # PLV from the standard equation in "Measuring phase synchrony in brain signals" (1999)
    return np.abs(np.mean(np.exp(1j * relative_phase), axis=-1))
|
346
|
+
|
347
|
+
|
348
|
+
def calculate_plv_over_time(x1: np.ndarray, x2: np.ndarray, fs: float,
                            window_size: float, step_size: float,
                            method: str = 'wavelet', freq_of_interest: float = None,
                            lowcut: float = None, highcut: float = None,
                            bandwidth: float = 2.0):
    """
    Calculate the time-resolved Phase Locking Value (PLV) between two signals using a sliding window approach.

    Parameters:
    ----------
    x1, x2 : array-like
        Input signals (1D arrays, same length).
    fs : float
        Sampling frequency of the input signals.
    window_size : float
        Length of the window in seconds for PLV calculation.
    step_size : float
        Step size in seconds to slide the window across the signals.
    method : str, optional
        Method to calculate PLV ('wavelet' or 'hilbert'). Defaults to 'wavelet'.
    freq_of_interest : float, optional
        Frequency of interest for the wavelet method. Required if method is 'wavelet'.
    lowcut, highcut : float, optional
        Cutoff frequencies for the Hilbert method, forwarded to ``calculate_plv``.
    bandwidth : float, optional
        Bandwidth parameter for the wavelet. Defaults to 2.0.

    Returns:
    -------
    plv_over_time : 1D array
        Array of PLV values calculated over each window. Empty when the signals
        are shorter than one window.
    times : 1D array
        The center times (in seconds) of each window where the PLV was calculated.

    Raises:
    ------
    ValueError
        If ``window_size`` or ``step_size`` corresponds to less than one sample.
    """
    # Convert window and step size from seconds to samples
    window_samples = int(window_size * fs)
    step_samples = int(step_size * fs)

    # A zero-length window or a zero step cannot produce a meaningful sliding window
    if window_samples < 1:
        raise ValueError("window_size must span at least one sample at the given fs.")
    if step_samples < 1:
        raise ValueError("step_size must span at least one sample at the given fs.")

    plv_over_time = []
    times = []

    # Slide the window over the signals; only full windows are evaluated
    for start in range(0, len(x1) - window_samples + 1, step_samples):
        end = start + window_samples

        plv = calculate_plv(x1=x1[start:end], x2=x2[start:end], fs=fs,
                            method=method, freq_of_interest=freq_of_interest,
                            lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
        plv_over_time.append(plv)

        # Store the time at the center of the window
        times.append((start + end) / 2 / fs)

    return np.array(plv_over_time), np.array(times)
|
407
|
+
|
408
|
+
|
@@ -0,0 +1,181 @@
|
|
1
|
+
"""
|
2
|
+
Module for processing BMTK spikes output.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import h5py
|
6
|
+
import pandas as pd
|
7
|
+
from bmtool.util.util import load_nodes_from_config
|
8
|
+
from typing import Dict, Optional,Tuple, Union
|
9
|
+
import numpy as np
|
10
|
+
import os
|
11
|
+
|
12
|
+
|
13
|
+
def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
    """
    Load spike data from an HDF5 file into a pandas DataFrame.

    Args:
        spike_file (str): Path to the HDF5 file containing spike data.
        network_name (str): The name of the network within the HDF5 file from which to load spike data.
        sort (bool, optional): Whether to sort the DataFrame by 'timestamps'. Defaults to True.
        config (str, optional): If given, the nodes of `network_name` are loaded from this config and the
            `groupby` column(s) are merged onto the spikes by node id.
        groupby (str or list of str, optional): The node column(s) to merge in when `config` is given.
            Defaults to 'pop_name'.

    Returns:
        pd.DataFrame: A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data,
        plus the `groupby` column(s) when `config` is provided.

    Example:
        df = load_spikes_to_df("spikes.h5", "cortex")
    """
    # Open read-only and copy the datasets into memory before the file is closed
    with h5py.File(spike_file, 'r') as f:
        network_spikes = f['spikes'][network_name]
        spikes_df = pd.DataFrame({
            'node_ids': network_spikes['node_ids'][()],
            'timestamps': network_spikes['timestamps'][()]
        })

    if sort:
        spikes_df.sort_values(by='timestamps', inplace=True, ignore_index=True)

    if config:
        nodes = load_nodes_from_config(config)[network_name]

        # `groupby` may be one column name or a list of them; each is merged on the node id
        if isinstance(groupby, str):
            spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
        elif isinstance(groupby, list):
            for group in groupby:
                spikes_df = spikes_df.merge(nodes[group], left_on='node_ids', right_index=True, how='left')

    return spikes_df
|
51
|
+
|
52
|
+
|
53
|
+
|
54
|
+
def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
|
55
|
+
time_points: Optional[Union[np.ndarray, list]] = None, frequeny: bool = False) -> np.ndarray:
|
56
|
+
"""
|
57
|
+
Calculate the spike count or frequency histogram over specified time intervals.
|
58
|
+
|
59
|
+
Args:
|
60
|
+
spike_times (Union[np.ndarray, list]): Array or list of spike times in milliseconds.
|
61
|
+
time (Optional[Tuple[float, float, float]], optional): Tuple specifying (start, stop, step) in milliseconds.
|
62
|
+
Used to create evenly spaced time points if `time_points` is not provided. Default is None.
|
63
|
+
time_points (Optional[Union[np.ndarray, list]], optional): Array or list of specific time points for binning.
|
64
|
+
If provided, `time` is ignored. Default is None.
|
65
|
+
frequeny (bool, optional): If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
|
66
|
+
|
67
|
+
Returns:
|
68
|
+
np.ndarray: Array of spike counts or frequencies, depending on the `frequeny` flag.
|
69
|
+
|
70
|
+
Raises:
|
71
|
+
ValueError: If both `time` and `time_points` are None.
|
72
|
+
"""
|
73
|
+
if time_points is None:
|
74
|
+
if time is None:
|
75
|
+
raise ValueError("Either `time` or `time_points` must be provided.")
|
76
|
+
time_points = np.arange(*time)
|
77
|
+
dt = time[2]
|
78
|
+
else:
|
79
|
+
time_points = np.asarray(time_points).ravel()
|
80
|
+
dt = (time_points[-1] - time_points[0]) / (time_points.size - 1)
|
81
|
+
|
82
|
+
bins = np.append(time_points, time_points[-1] + dt)
|
83
|
+
spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
|
84
|
+
|
85
|
+
if frequeny:
|
86
|
+
spike_rate = 1000 / dt * spike_rate
|
87
|
+
|
88
|
+
return spike_rate
|
89
|
+
|
90
|
+
|
91
|
+
def _min_max_normalize(values: np.ndarray) -> np.ndarray:
    """Scale `values` to [0, 1]; an empty array is returned as is and a constant one maps to zeros."""
    if values.size == 0:
        return values
    low = values.min()
    span = values.max() - low
    if span == 0:
        # Avoid 0 / 0 for a constant series
        return np.zeros_like(values, dtype=float)
    return (values - low) / span


def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
                              config: Optional[str] = None, network_name: Optional[str] = None,
                              save: bool = False, save_path: Optional[str] = None,
                              normalize: bool = False) -> Dict[str, np.ndarray]:
    """
    Calculate the population spike rate for each population in the given spike data, with an option to normalize.

    Args:
        spikes (pd.DataFrame): A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'.
        fs (float, optional): Sampling frequency in Hz, which determines the time bin size (1000 / fs ms) for calculating the spike rate. Default is 400.
        t_start (float, optional): Start time (in milliseconds) for spike rate calculation. Default is 0.
        t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data.
        config (Optional[str], optional): Path to a configuration file containing node information, used to determine the correct number of nodes per population.
            If None, node count is estimated from unique node spikes. Default is None.
        network_name (Optional[str], optional): Name of the network used in the configuration file, allowing selection of nodes for that network.
            If `config` is provided without it, the first network in the config is used. Default is None.
        save (bool, optional): Whether to save the calculated population spike rate to a file. Default is False.
        save_path (Optional[str], optional): Directory path where 'spike_rate.h5' is written if `save` is True. If `save` is True and `save_path` is None, a ValueError is raised.
        normalize (bool, optional): Whether to normalize the spike rates for each population to a range of [0, 1]. Default is False.

    Returns:
        Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
        If `normalize` is True, each population's spike rate is scaled to [0, 1].

    Raises:
        ValueError: If `save` is True but `save_path` is not provided.

    Notes:
        - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
        - Only spikes strictly between `t_start` and `t_stop` are counted.
        - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values;
          a population whose rate is constant is returned as all zeros.
    """
    pop_spikes = {}
    node_number = {}

    if config is None:
        print("Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect.")
        print("You can provide a config to calculate the correct amount of nodes!")

    network_nodes = None
    if config:
        if not network_name:
            print("Grabbing first network; specify a network name to ensure correct node population is selected.")

        # Load the node tables once instead of once per population
        all_nodes = load_nodes_from_config(config)
        if network_name:
            network_nodes = all_nodes[network_name]
        else:
            network_nodes = list(all_nodes.values())[0] if all_nodes else {}

    # The stop time does not depend on the population, so resolve it up front
    if t_stop is None:
        t_stop = spikes['timestamps'].max()

    for pop_name in spikes['pop_name'].unique():
        ps = spikes[spikes['pop_name'] == pop_name]

        if config:
            pop_nodes = network_nodes[network_nodes['pop_name'] == pop_name]
            node_number[pop_name] = pop_nodes.index.nunique()
        else:
            node_number[pop_name] = ps['node_ids'].nunique()

        # Keep this population's spikes strictly inside (t_start, t_stop)
        pop_spikes[pop_name] = ps[(ps['timestamps'] > t_start) & (ps['timestamps'] < t_stop)]

    time = np.array([t_start, t_stop, 1000 / fs])
    pop_rspk = {p: _pop_spike_rate(spk['timestamps'], time) for p, spk in pop_spikes.items()}
    # Counts per bin -> rate in Hz per node: fs bins per second, divided by the node count
    spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}

    # Normalize each spike rate series if normalize=True
    if normalize:
        spike_rate = {p: _min_max_normalize(sr) for p, sr in spike_rate.items()}

    if save:
        if save_path is None:
            raise ValueError("save_path must be provided if save is True.")

        os.makedirs(save_path, exist_ok=True)

        save_file = os.path.join(save_path, 'spike_rate.h5')
        with h5py.File(save_file, 'w') as f:
            f.create_dataset('time', data=time)
            grp = f.create_group('populations')
            for p, rspk in spike_rate.items():
                pop_grp = grp.create_group(p)
                pop_grp.create_dataset('data', data=rspk)

    return spike_rate
|
181
|
+
|
bmtool/bmplot.py
CHANGED
@@ -762,7 +762,7 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
|
|
762
762
|
plt.tight_layout()
|
763
763
|
plt.show()
|
764
764
|
|
765
|
-
def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None,
|
765
|
+
def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None, groupby:Optional[str] = 'pop_name',
|
766
766
|
ax: Optional[Axes] = None,tstart: Optional[float] = None,tstop: Optional[float] = None,
|
767
767
|
color_map: Optional[Dict[str, str]] = None) -> Axes:
|
768
768
|
"""
|
@@ -793,7 +793,7 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
|
|
793
793
|
Notes:
|
794
794
|
-----
|
795
795
|
- If `config` is provided, the function merges population names from the node data with `spikes_df`.
|
796
|
-
- Each unique population
|
796
|
+
- Each unique population from groupby in `spikes_df` will be represented by a different color if `color_map` is not specified.
|
797
797
|
- If `color_map` is provided, it should contain colors for all unique `pop_name` values in `spikes_df`.
|
798
798
|
"""
|
799
799
|
# Initialize axes if none provided
|
@@ -822,11 +822,11 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
|
|
822
822
|
# Drop all intersecting columns except the join key column from df2
|
823
823
|
spikes_df = spikes_df.drop(columns=common_columns)
|
824
824
|
# merge nodes and spikes df
|
825
|
-
spikes_df = spikes_df.merge(nodes[
|
825
|
+
spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
|
826
826
|
|
827
827
|
|
828
828
|
# Get unique population names
|
829
|
-
unique_pop_names = spikes_df[
|
829
|
+
unique_pop_names = spikes_df[groupby].unique()
|
830
830
|
|
831
831
|
# Generate colors if no color_map is provided
|
832
832
|
if color_map is None:
|
@@ -839,7 +839,7 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
|
|
839
839
|
raise ValueError(f"color_map is missing colors for populations: {missing_colors}")
|
840
840
|
|
841
841
|
# Plot each population with its specified or generated color
|
842
|
-
for pop_name, group in spikes_df.groupby(
|
842
|
+
for pop_name, group in spikes_df.groupby(groupby):
|
843
843
|
ax.scatter(group['timestamps'], group['node_ids'], label=pop_name, color=color_map[pop_name], s=0.5)
|
844
844
|
|
845
845
|
# Label axes
|
bmtool/singlecell.py
CHANGED
@@ -90,7 +90,11 @@ class CurrentClamp(object):
|
|
90
90
|
self.inj_dur = inj_dur
|
91
91
|
self.inj_amp = inj_amp * 1e-3 # pA to nA
|
92
92
|
|
93
|
-
|
93
|
+
# sometimes people may put a hoc object in for the template name
|
94
|
+
if callable(template_name):
|
95
|
+
self.cell = template_name
|
96
|
+
else:
|
97
|
+
self.cell = self.create_cell()
|
94
98
|
if post_init_function:
|
95
99
|
eval(f"self.cell.{post_init_function}")
|
96
100
|
|
@@ -1,13 +1,16 @@
|
|
1
|
-
bmtool/SLURM.py,sha256=
|
1
|
+
bmtool/SLURM.py,sha256=AKxu_Ln9wuCBVLdOJP4yAN59jYt222DM-iUlsQojNvY,18145
|
2
2
|
bmtool/__init__.py,sha256=ZStTNkAJHJxG7Pwiy5UgCzC4KlhMS5pUNPtUJZVwL_Y,136
|
3
3
|
bmtool/__main__.py,sha256=TmFkmDxjZ6250nYD4cgGhn-tbJeEm0u-EMz2ajAN9vE,650
|
4
|
-
bmtool/bmplot.py,sha256=
|
4
|
+
bmtool/bmplot.py,sha256=iTK6q8XEqc8QEAKR152ut_1qdtnMoEe1Uq-4dCrkCA0,53992
|
5
5
|
bmtool/connectors.py,sha256=hWkUUcJ4tmas8NDOFPPjQT-TgTlPcpjuZsYyAW2WkPA,72242
|
6
6
|
bmtool/graphs.py,sha256=K8BiughRUeXFVvAgo8UzrwpSClIVg7UfmIcvtEsEsk0,6020
|
7
7
|
bmtool/manage.py,sha256=_lCU0qBQZ4jSxjzAJUd09JEetb--cud7KZgxQFbLGSY,657
|
8
8
|
bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
|
9
|
-
bmtool/singlecell.py,sha256=
|
9
|
+
bmtool/singlecell.py,sha256=XZAT_2n44EhwqVLnk3qur9aO7oJ-10axJZfwPBslM88,27219
|
10
10
|
bmtool/synapses.py,sha256=gIkfLhKDG2dHHCVJJoKuQrFn_Qut843bfk_-s97wu6c,54553
|
11
|
+
bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
+
bmtool/analysis/lfp.py,sha256=Zp-aJ8x2KmsI3h_mqvq4u9ixFYu4n0CxQpgdbYnrtYE,14909
|
13
|
+
bmtool/analysis/spikes.py,sha256=23k_wFOC9pKQgetxMp1V2z6cZaW2eoIZAZyjFiTfrrM,8560
|
11
14
|
bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
15
|
bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
|
13
16
|
bmtool/debug/debug.py,sha256=xqnkzLiH3s-tS26Y5lZZL62qR2evJdi46Gud-HzxEN4,207
|
@@ -16,9 +19,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
|
|
16
19
|
bmtool/util/util.py,sha256=00vOAwTVIifCqouBoFoT0lBashl4fCalrk8fhg_Uq4c,56654
|
17
20
|
bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
18
21
|
bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
|
19
|
-
bmtool-0.6.7.dist-info/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
|
20
|
-
bmtool-0.6.7.dist-info/METADATA,sha256=
|
21
|
-
bmtool-0.6.7.dist-info/WHEEL,sha256=
|
22
|
-
bmtool-0.6.7.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
|
23
|
-
bmtool-0.6.7.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
|
24
|
-
bmtool-0.6.7.dist-info/RECORD,,
|
22
|
+
bmtool-0.6.7.1.dist-info/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
|
23
|
+
bmtool-0.6.7.1.dist-info/METADATA,sha256=qlkGt9KNUlVP7C3J1xgvj23YUHaPD7SlVFzhVhSUceg,20226
|
24
|
+
bmtool-0.6.7.1.dist-info/WHEEL,sha256=nn6H5-ilmfVryoAQl3ZQ2l8SH5imPWFpm1A5FgEuFV4,91
|
25
|
+
bmtool-0.6.7.1.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
|
26
|
+
bmtool-0.6.7.1.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
|
27
|
+
bmtool-0.6.7.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|