bmtool 0.6.6.4__py3-none-any.whl → 0.6.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/SLURM.py CHANGED
@@ -4,6 +4,8 @@ import subprocess
  import json
  import requests
  import shutil
+ import time
+ import copy
 
 
  def check_job_status(job_id):
@@ -117,7 +119,7 @@ class multiSeedSweep(seedSweep):
      MultSeedSweeps are centered around some base JSON cell file. When that base JSON is updated, the other JSONs
      change according to their ratio with the base JSON.
      """
-     def __init__(self, base_json_file_path, param_name, syn_dict_list=[], base_ratio=1):
+     def __init__(self, base_json_file_path, param_name, syn_dict, base_ratio=1):
          """
          Initializes the multipleSeedSweep instance.
 
@@ -128,7 +130,7 @@ class multiSeedSweep(seedSweep):
          base_ratio (float): The ratio between the other JSONs; usually the current value for the parameter.
          """
          super().__init__(base_json_file_path, param_name)
-         self.syn_dict_list = syn_dict_list
+         self.syn_dict_for_multi = syn_dict
          self.base_ratio = base_ratio
 
      def edit_all_jsons(self, new_value):
@@ -140,19 +142,19 @@ class multiSeedSweep(seedSweep):
          """
          self.edit_json(new_value)
          base_ratio = self.base_ratio
-         for syn_dict in self.syn_dict_list:
-             json_file_path = syn_dict['json_file_path']
-             new_ratio = syn_dict['ratio'] / base_ratio
-
-             with open(json_file_path, 'r') as f:
-                 data = json.load(f)
-             altered_value = new_ratio * new_value
-             data[self.param_name] = altered_value
-
-             with open(json_file_path, 'w') as f:
-                 json.dump(data, f, indent=4)
+
+         json_file_path = self.syn_dict_for_multi['json_file_path']
+         new_ratio = self.syn_dict_for_multi['ratio'] / base_ratio
 
-         print(f"JSON file '{json_file_path}' modified successfully with {self.param_name}={altered_value}.", flush=True)
+         with open(json_file_path, 'r') as f:
+             data = json.load(f)
+         altered_value = new_ratio * new_value
+         data[self.param_name] = altered_value
+
+         with open(json_file_path, 'w') as f:
+             json.dump(data, f, indent=4)
+
+         print(f"JSON file '{json_file_path}' modified successfully with {self.param_name}={altered_value}.", flush=True)
 
 
  class SimulationBlock:
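With this change, multiSeedSweep keeps a single companion JSON in step with the base JSON rather than a list of them. A minimal usage sketch of the new signature; file names, parameter name, and values here are hypothetical, not from this diff:

from bmtool.SLURM import multiSeedSweep

# one companion JSON held at a fixed ratio to the base JSON (hypothetical files)
syn_dict = {'json_file_path': 'AMPA.json', 'ratio': 0.5}
sweep = multiSeedSweep('base_synapses.json', 'initW', syn_dict=syn_dict, base_ratio=1)
sweep.edit_all_jsons(2.0)  # writes 2.0 to the base JSON and 0.5/1 * 2.0 = 1.0 to AMPA.json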
@@ -273,6 +275,7 @@ export OUTPUT_DIR={case_output_dir}
          """
          for job_id in self.job_ids:
              status = check_job_status(job_id)
+             #print(f"status of job is {status}")
              if status != 'COMPLETED': # can add PENDING here for debugging NOT FOR ACTUALLY USING IT
                  return False
          return True
@@ -314,7 +317,7 @@ class BlockRunner:
      """
 
      def __init__(self, blocks, json_editor=None,json_file_path=None, param_name=None,
-                  param_values=None, check_interval=60,syn_dict_list = None,
+                  param_values=None, check_interval=60,syn_dict = None,
                   webhook=None):
          self.blocks = blocks
          self.json_editor = json_editor
@@ -323,29 +326,44 @@ class BlockRunner:
          self.webhook = webhook
          self.param_name = param_name
          self.json_file_path = json_file_path
-         self.syn_dict_list = syn_dict_list
+         self.syn_dict = syn_dict
 
      def submit_blocks_sequentially(self):
          """
          Submits all blocks sequentially, ensuring each block starts only after the previous block has completed or is running.
          Updates the JSON file with new parameters before each block run.
-         json file path should be the path WITH the components folder
          """
          for i, block in enumerate(self.blocks):
              # Update JSON file with new parameter value
-             if self.json_editor == None and self.param_values == None:
+             if self.json_file_path == None and self.param_values == None:
+                 source_dir = block.component_path
+                 destination_dir = f"{source_dir}{i+1}"
+                 block.component_path = destination_dir
+                 shutil.copytree(source_dir, destination_dir) # create new components folder
                  print(f"skipping json editing for block {block.block_name}",flush=True)
              else:
                  if len(self.blocks) != len(self.param_values):
                      raise Exception("Number of blocks needs to each number of params given")
                  new_value = self.param_values[i]
+                 # hope this path is correct
+                 source_dir = block.component_path
+                 destination_dir = f"{source_dir}{i+1}"
+                 block.component_path = destination_dir
+
+                 shutil.copytree(source_dir, destination_dir) # create new components folder
+                 json_file_path = os.path.join(destination_dir,self.json_file_path)
 
-                 if self.syn_dict_list == None:
-                     json_editor = seedSweep(self.json_file_path, self.param_name)
+                 if self.syn_dict == None:
+                     json_editor = seedSweep(json_file_path , self.param_name)
                      json_editor.edit_json(new_value)
                  else:
-                     json_editor = multiSeedSweep(self.json_file_path,self.param_name,
-                                                  self.syn_dict_list,base_ratio=1)
+                     # need to keep the orignal around
+                     syn_dict_temp = copy.deepcopy(self.syn_dict)
+                     json_to_be_ratioed = syn_dict_temp['json_file_path']
+                     corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
+                     syn_dict_temp['json_file_path'] = corrected_ratio_path
+                     json_editor = multiSeedSweep(json_file_path ,self.param_name,
+                                                  syn_dict=syn_dict_temp,base_ratio=1)
                      json_editor.edit_all_jsons(new_value)
 
              # Submit the block
@@ -357,7 +375,7 @@ class BlockRunner:
 
              # Wait for the block to complete
              if i == len(self.blocks) - 1:
-                 while not block.check_block_completed():
+                 while not block.check_block_status():
                      print(f"Waiting for the last block {i} to complete...")
                      time.sleep(self.check_interval)
              else: # Not the last block so if job is running lets start a new one (checks status list)
@@ -376,13 +394,14 @@ class BlockRunner:
          submits all the blocks at once onto the queue. To do this the components dir will be cloned and each block will have its own.
          Also the json_file_path should be the path after the components dir
          """
-         if self.webhook:
-             message = "SIMULATION UPDATE: Simulations have been submited in parallel!"
-             send_teams_message(self.webhook,message)
-         if self.param_values == None:
-             print(f"skipping json editing for block {block.block_name}",flush=True)
-         else:
-             for i, block in enumerate(self.blocks):
+         for i, block in enumerate(self.blocks):
+             if self.param_values == None:
+                 source_dir = block.component_path
+                 destination_dir = f"{source_dir}{i+1}"
+                 block.component_path = destination_dir
+                 shutil.copytree(source_dir, destination_dir) # create new components folder
+                 print(f"skipping json editing for block {block.block_name}",flush=True)
+             else:
                  if block.component_path == None:
                      raise Exception("Unable to use parallel submitter without defining the component path")
                  new_value = self.param_values[i]
@@ -393,22 +412,27 @@ class BlockRunner:
 
                  shutil.copytree(source_dir, destination_dir) # create new components folder
                  json_file_path = os.path.join(destination_dir,self.json_file_path)
-                 if self.syn_dict_list == None:
-                     json_editor = seedSweep(json_file_path, self.param_name)
+
+                 if self.syn_dict == None:
+                     json_editor = seedSweep(json_file_path , self.param_name)
                      json_editor.edit_json(new_value)
                  else:
-                     json_editor = multiSeedSweep(json_file_path,self.param_name,
-                                                  self.syn_dict_list,base_ratio=1)
-                     json_editor.edit_all_jsons(new_value)
-
-             # submit block with new component path
+                     # need to keep the orignal around
+                     syn_dict_temp = copy.deepcopy(self.syn_dict)
+                     json_to_be_ratioed = syn_dict_temp['json_file_path']
+                     corrected_ratio_path = os.path.join(destination_dir,json_to_be_ratioed)
+                     syn_dict_temp['json_file_path'] = corrected_ratio_path
+                     json_editor = multiSeedSweep(json_file_path ,self.param_name,
+                                                  syn_dict_temp,base_ratio=1)
+                     json_editor.edit_all_jsons(new_value)
+             # submit block with new component path
              print(f"Submitting block: {block.block_name}", flush=True)
              block.submit_block()
              if i == len(self.blocks) - 1:
-                 if self.webook:
-                     while not block.check_block_completed():
-                         print(f"Waiting for the last block {i} to complete...")
-                         time.sleep(self.check_interval)
+                 print("\nEverything has been submitted. You can close out of this or keep this script running to get a message when everything is finished\n")
+                 while not block.check_block_status():
+                     print(f"Waiting for the last block {i} to complete...")
+                     time.sleep(self.check_interval)
 
          if self.webhook:
              message = "SIMULATION UPDATE: Simulations are Done!"
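BlockRunner now clones the components directory once per block and takes a single syn_dict. A sketch of a sequential sweep, assuming `blocks` is a list of already-configured SimulationBlock objects and all paths and values are hypothetical:

from bmtool.SLURM import BlockRunner

runner = BlockRunner(blocks,                                      # SimulationBlock objects (hypothetical)
                     json_file_path='synaptic_models/AMPA.json',  # path after the components dir
                     param_name='initW',
                     param_values=[1.0, 2.0, 3.0],                # one value per block
                     check_interval=60,
                     syn_dict=None)  # or {'json_file_path': ..., 'ratio': ...} for multiSeedSweep
runner.submit_blocks_sequentially()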
bmtool/analysis/lfp.py ADDED
@@ -0,0 +1,408 @@
+ """
+ Module for processing BMTK LFP output.
+ """
+
+ import h5py
+ import numpy as np
+ import xarray as xr
+ from fooof import FOOOF
+ from fooof.sim.gen import gen_model
+ import matplotlib.pyplot as plt
+ from scipy import signal
+ import pywt
+ from bmtool.bmplot import is_notebook
+
+
+ def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
+     """
+     Load ECP data from an HDF5 file (BMTK sim) into an xarray DataArray.
+
+     Parameters:
+     ----------
+     ecp_file : str
+         Path to the HDF5 file containing ECP data.
+     demean : bool, optional
+         If True, the mean of the data will be subtracted (default is False).
+
+     Returns:
+     -------
+     xr.DataArray
+         An xarray DataArray containing the ECP data, with time as one dimension
+         and channel_id as another.
+     """
+     with h5py.File(ecp_file, 'r') as f:
+         ecp = xr.DataArray(
+             f['ecp']['data'][()].T,
+             coords=dict(
+                 channel_id=f['ecp']['channel_id'][()],
+                 time=np.arange(*f['ecp']['time']) # ms
+             ),
+             attrs=dict(
+                 fs=1000 / f['ecp']['time'][2] # Hz
+             )
+         )
+     if demean:
+         ecp -= ecp.mean(dim='time')
+     return ecp
+
+
+ def ecp_to_lfp(ecp_data: xr.DataArray, cutoff: float = 250, fs: float = 10000,
+                downsample_freq: float = 1000) -> xr.DataArray:
+     """
+     Apply a low-pass Butterworth filter to an xarray DataArray and optionally downsample.
+     This filters out the high end frequencies turning the ECP into a LFP
+
+     Parameters:
+     ----------
+     ecp_data : xr.DataArray
+         The input data array containing LFP data with time as one dimension.
+     cutoff : float
+         The cutoff frequency for the low-pass filter in Hz (default is 250Hz).
+     fs : float, optional
+         The sampling frequency of the data (default is 10000 Hz).
+     downsample_freq : float, optional
+         The frequency to downsample to (default is 1000 Hz).
+
+     Returns:
+     -------
+     xr.DataArray
+         The filtered (and possibly downsampled) data as an xarray DataArray.
+     """
+     # Bandpass filter design
+     nyq = 0.5 * fs
+     cut = cutoff / nyq
+     b, a = signal.butter(8, cut, btype='low', analog=False)
+
+     # Initialize an array to hold filtered data
+     filtered_data = xr.DataArray(np.zeros_like(ecp_data), coords=ecp_data.coords, dims=ecp_data.dims)
+
+     # Apply the filter to each channel
+     for channel in ecp_data.channel_id:
+         filtered_data.loc[channel, :] = signal.filtfilt(b, a, ecp_data.sel(channel_id=channel).values)
+
+     # Downsample the filtered data if a downsample frequency is provided
+     if downsample_freq is not None:
+         downsample_factor = int(fs / downsample_freq)
+         filtered_data = filtered_data.isel(time=slice(None, None, downsample_factor))
+         # Update the sampling frequency attribute
+         filtered_data.attrs['fs'] = downsample_freq
+
+     return filtered_data
+
+
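A usage sketch for the two functions above; the HDF5 path is hypothetical:

from bmtool.analysis.lfp import load_ecp_to_xarray, ecp_to_lfp

ecp = load_ecp_to_xarray('output/ecp.h5', demean=True)  # hypothetical BMTK output file
lfp = ecp_to_lfp(ecp, cutoff=250, fs=ecp.attrs['fs'], downsample_freq=1000)
print(lfp.attrs['fs'])  # 1000 Hz after downsampling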
+ def slice_time_series(data: xr.DataArray, time_ranges: tuple) -> xr.DataArray:
+     """
+     Slice the xarray DataArray based on provided time ranges.
+     Can be used to get LFP during certain stimulus times
+
+     Parameters:
+     ----------
+     data : xr.DataArray
+         The input xarray DataArray containing time-series data.
+     time_ranges : tuple or list of tuples
+         One or more tuples representing the (start, stop) time points for slicing.
+         For example: (start, stop) or [(start1, stop1), (start2, stop2)]
+
+     Returns:
+     -------
+     xr.DataArray
+         A new xarray DataArray containing the concatenated slices.
+     """
+     # Ensure time_ranges is a list of tuples
+     if isinstance(time_ranges, tuple) and len(time_ranges) == 2:
+         time_ranges = [time_ranges]
+
+     # List to hold sliced data
+     slices = []
+
+     # Slice the data for each time range
+     for start, stop in time_ranges:
+         sliced_data = data.sel(time=slice(start, stop))
+         slices.append(sliced_data)
+
+     # Concatenate all slices along the time dimension if more than one slice
+     if len(slices) > 1:
+         return xr.concat(slices, dim='time')
+     else:
+         return slices[0]
+
+
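For instance, keeping only the LFP around two stimulus windows (window times in ms are hypothetical):

from bmtool.analysis.lfp import slice_time_series

stim_lfp = slice_time_series(lfp, [(500, 1000), (1500, 2000)])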
+ def fit_fooof(f: np.ndarray, pxx: np.ndarray, aperiodic_mode: str = 'fixed',
+               dB_threshold: float = 3.0, max_n_peaks: int = 10,
+               freq_range: tuple = None, peak_width_limits: tuple = None,
+               report: bool = False, plot: bool = False,
+               plt_log: bool = False, plt_range: tuple = None,
+               figsize: tuple = None, title: str = None) -> tuple:
+     """
+     Fit a FOOOF model to power spectral density data.
+
+     Parameters:
+     ----------
+     f : array-like
+         Frequencies corresponding to the power spectral density data.
+     pxx : array-like
+         Power spectral density data to fit.
+     aperiodic_mode : str, optional
+         The mode for fitting aperiodic components ('fixed' or 'knee', default is 'fixed').
+     dB_threshold : float, optional
+         Minimum peak height in dB (default is 3).
+     max_n_peaks : int, optional
+         Maximum number of peaks to fit (default is 10).
+     freq_range : tuple, optional
+         Frequency range to fit (default is None, which uses the full range).
+     peak_width_limits : tuple, optional
+         Limits on the width of peaks (default is None).
+     report : bool, optional
+         If True, will print fitting results (default is False).
+     plot : bool, optional
+         If True, will plot the fitting results (default is False).
+     plt_log : bool, optional
+         If True, use a logarithmic scale for the y-axis in plots (default is False).
+     plt_range : tuple, optional
+         Range for plotting (default is None).
+     figsize : tuple, optional
+         Size of the figure (default is None).
+     title : str, optional
+         Title for the plot (default is None).
+
+     Returns:
+     -------
+     tuple
+         A tuple containing the fitting results and the FOOOF model object.
+     """
+     if aperiodic_mode != 'knee':
+         aperiodic_mode = 'fixed'
+
+     def set_range(x, upper=f[-1]):
+         x = np.array(upper) if x is None else np.array(x)
+         return [f[2], x.item()] if x.size == 1 else x.tolist()
+
+     freq_range = set_range(freq_range)
+     peak_width_limits = set_range(peak_width_limits, np.inf)
+
+     # Initialize a FOOOF object
+     fm = FOOOF(peak_width_limits=peak_width_limits, min_peak_height=dB_threshold / 10,
+                peak_threshold=0., max_n_peaks=max_n_peaks, aperiodic_mode=aperiodic_mode)
+
+     # Fit the model
+     try:
+         fm.fit(f, pxx, freq_range)
+     except Exception as e:
+         fl = np.linspace(f[0], f[-1], int((f[-1] - f[0]) / np.min(np.diff(f))) + 1)
+         fm.fit(fl, np.interp(fl, f, pxx), freq_range)
+
+     results = fm.get_results()
+
+     if report:
+         fm.print_results()
+         if aperiodic_mode == 'knee':
+             ap_params = results.aperiodic_params
+             if ap_params[1] <= 0:
+                 print('Negative value of knee parameter occurred. Suggestion: Fit without knee parameter.')
+             knee_freq = np.abs(ap_params[1]) ** (1 / ap_params[2])
+             print(f'Knee location: {knee_freq:.2f} Hz')
+
+     if plot:
+         plt_range = set_range(plt_range)
+         fm.plot(plt_log=plt_log)
+         plt.xlim(np.log10(plt_range) if plt_log else plt_range)
+         #plt.ylim(-8, -5.5)
+         if figsize:
+             plt.gcf().set_size_inches(figsize)
+         if title:
+             plt.title(title)
+         if is_notebook():
+             pass
+         else:
+             plt.show()
+
+     return results, fm
+
+
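A sketch of fitting a PSD computed with Welch's method, continuing from the lfp object above; the channel id and frequency range are hypothetical:

from scipy import signal
from bmtool.analysis.lfp import fit_fooof

x = lfp.sel(channel_id=0).values  # assumes a channel with id 0 exists
f, pxx = signal.welch(x, fs=lfp.attrs['fs'], nperseg=int(lfp.attrs['fs']))
results, fm = fit_fooof(f, pxx, aperiodic_mode='knee', dB_threshold=3.0,
                        freq_range=(1, 100), report=True)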
+ def generate_resd_from_fooof(fooof_model: FOOOF) -> tuple:
+     """
+     Generate residuals from a fitted FOOOF model.
+
+     Parameters:
+     ----------
+     fooof_model : FOOOF
+         A fitted FOOOF model object.
+
+     Returns:
+     -------
+     tuple
+         A tuple containing the residual power spectral density and the aperiodic fit.
+     """
+     results = fooof_model.get_results()
+     full_fit, _, ap_fit = gen_model(fooof_model.freqs[1:], results.aperiodic_params,
+                                     results.gaussian_params, return_components=True)
+
+     full_fit, ap_fit = 10 ** full_fit, 10 ** ap_fit # Convert back from log
+     res_psd = np.insert((10 ** fooof_model.power_spectrum[1:]) - ap_fit, 0, 0.) # Convert back from log
+     res_fit = np.insert(full_fit - ap_fit, 0, 0.)
+     ap_fit = np.insert(ap_fit, 0, 0.)
+
+     return res_psd, ap_fit
+
+
+ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
+     """
+     Calculate the signal-to-noise ratio (SNR) from a fitted FOOOF model.
+
+     Parameters:
+     ----------
+     fooof_model : FOOOF
+         A fitted FOOOF model object.
+     freq_band : tuple
+         Frequency band (min, max) for SNR calculation.
+
+     Returns:
+     -------
+     float
+         The calculated SNR for the specified frequency band.
+     """
+     periodic, ap = generate_resd_from_fooof(fooof_model)
+     freq = fooof_model.freqs # Get frequencies from model
+     indices = (freq >= freq_band[0]) & (freq <= freq_band[1]) # Get only the band we care about
+     band_periodic = periodic[indices] # Filter based on band
+     band_ap = ap[indices] # Filter
+     band_freq = freq[indices] # Another filter
+     periodic_power = np.trapz(band_periodic, band_freq) # Integrate periodic power
+     ap_power = np.trapz(band_ap, band_freq) # Integrate aperiodic power
+     normalized_power = periodic_power / ap_power # Compute the SNR
+     return normalized_power
+
+
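Continuing the FOOOF sketch, with a hypothetical band choice:

from bmtool.analysis.lfp import calculate_SNR

snr = calculate_SNR(fm, freq_band=(30, 80))  # periodic-to-aperiodic power ratio in a gamma band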
+ def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1) -> np.ndarray:
+     """
+     Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
+     """
+     wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
+     scale = pywt.scale2frequency(wavelet, 1) * fs / freq
+     x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
+     return x_a
+
+
+ def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs: float, order: int = 5, axis: int = -1) -> np.ndarray:
+     """
+     Apply a Butterworth bandpass filter to the input data.
+     """
+     sos = signal.butter(order, [lowcut, highcut], fs=fs, btype='band', output='sos')
+     x_a = signal.sosfiltfilt(sos, data, axis=axis)
+     return x_a
+
+
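A toy example of both filters on a synthetic 8 Hz signal; sampling rate and cutoffs are hypothetical:

import numpy as np
from bmtool.analysis.lfp import wavelet_filter, butter_bandpass_filter

fs = 1000.0                                  # Hz, hypothetical sampling rate
t = np.arange(0, 2, 1 / fs)
x = np.sin(2 * np.pi * 8 * t)                # 8 Hz test signal

analytic = wavelet_filter(x, freq=8, fs=fs)  # complex-valued CWT at 8 Hz
band = butter_bandpass_filter(x, lowcut=4, highcut=12, fs=fs)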
+ def calculate_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
+                   method: str = 'wavelet', lowcut: float = None, highcut: float = None,
+                   bandwidth: float = 2.0) -> np.ndarray:
+     """
+     Calculate Phase Locking Value (PLV) between two signals using wavelet or Hilbert method.
+
+     Parameters:
+     - x1, x2: Input signals (1D arrays, same length)
+     - fs: Sampling frequency
+     - freq_of_interest: Desired frequency for wavelet PLV calculation
+     - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
+     - lowcut, highcut: Cutoff frequencies for the Hilbert method
+     - bandwidth: Bandwidth parameter for the wavelet
+
+     Returns:
+     - plv: Phase Locking Value (1D array)
+     """
+     if len(x1) != len(x2):
+         raise ValueError("Input signals must have the same length.")
+
+     if method == 'wavelet':
+         if freq_of_interest is None:
+             raise ValueError("freq_of_interest must be provided for the wavelet method.")
+
+         # Apply CWT to both signals
+         theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+         theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
+
+     elif method == 'hilbert':
+         if lowcut is None or highcut is None:
+             print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")
+
+         if lowcut and highcut:
+             # Bandpass filter and get the analytic signal using the Hilbert transform
+             x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
+             x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)
+
+         # Get phase using the Hilbert transform
+         theta1 = signal.hilbert(x1)
+         theta2 = signal.hilbert(x2)
+
+     else:
+         raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
+
+     # Calculate phase difference
+     phase_diff = np.angle(theta1) - np.angle(theta2)
+
+     # Calculate PLV from standard equation from Measuring phase synchrony in brain signals(1999)
+     plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))
+
+     return plv
+
+
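Reusing the synthetic signals from the previous sketch, both PLV methods look like this:

from bmtool.analysis.lfp import calculate_plv

y = np.sin(2 * np.pi * 8 * t + np.pi / 4)  # phase-shifted copy of x
plv_w = calculate_plv(x, y, fs=fs, freq_of_interest=8, method='wavelet')
plv_h = calculate_plv(x, y, fs=fs, method='hilbert', lowcut=4, highcut=12)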
+ def calculate_plv_over_time(x1: np.ndarray, x2: np.ndarray, fs: float,
+                             window_size: float, step_size: float,
+                             method: str = 'wavelet', freq_of_interest: float = None,
+                             lowcut: float = None, highcut: float = None,
+                             bandwidth: float = 2.0):
+     """
+     Calculate the time-resolved Phase Locking Value (PLV) between two signals using a sliding window approach.
+
+     Parameters:
+     ----------
+     x1, x2 : array-like
+         Input signals (1D arrays, same length).
+     fs : float
+         Sampling frequency of the input signals.
+     window_size : float
+         Length of the window in seconds for PLV calculation.
+     step_size : float
+         Step size in seconds to slide the window across the signals.
+     method : str, optional
+         Method to calculate PLV ('wavelet' or 'hilbert'). Defaults to 'wavelet'.
+     freq_of_interest : float, optional
+         Frequency of interest for the wavelet method. Required if method is 'wavelet'.
+     lowcut, highcut : float, optional
+         Cutoff frequencies for the Hilbert method. Required if method is 'hilbert'.
+     bandwidth : float, optional
+         Bandwidth parameter for the wavelet. Defaults to 2.0.
+
+     Returns:
+     -------
+     plv_over_time : 1D array
+         Array of PLV values calculated over each window.
+     times : 1D array
+         The center times of each window where the PLV was calculated.
+     """
+     # Convert window and step size from seconds to samples
+     window_samples = int(window_size * fs)
+     step_samples = int(step_size * fs)
+
+     # Initialize results
+     plv_over_time = []
+     times = []
+
+     # Iterate over the signal with a sliding window
+     for start in range(0, len(x1) - window_samples + 1, step_samples):
+         end = start + window_samples
+         window_x1 = x1[start:end]
+         window_x2 = x2[start:end]
+
+         # Use the updated calculate_plv function within each window
+         plv = calculate_plv(x1=window_x1, x2=window_x2, fs=fs,
+                             method=method, freq_of_interest=freq_of_interest,
+                             lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
+         plv_over_time.append(plv)
+
+         # Store the time at the center of the window
+         center_time = (start + end) / 2 / fs
+         times.append(center_time)
+
+     return np.array(plv_over_time), np.array(times)
+
+
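And the time-resolved variant over the same pair of signals, with hypothetical window settings:

from bmtool.analysis.lfp import calculate_plv_over_time

plv_t, times = calculate_plv_over_time(x, y, fs=fs,
                                       window_size=0.5, step_size=0.1,
                                       method='wavelet', freq_of_interest=8)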
bmtool/analysis/spikes.py ADDED
@@ -0,0 +1,181 @@
+ """
+ Module for processing BMTK spikes output.
+ """
+
+ import h5py
+ import pandas as pd
+ from bmtool.util.util import load_nodes_from_config
+ from typing import Dict, Optional,Tuple, Union
+ import numpy as np
+ import os
+
+
+ def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: str = 'pop_name') -> pd.DataFrame:
+     """
+     Load spike data from an HDF5 file into a pandas DataFrame.
+
+     Args:
+         spike_file (str): Path to the HDF5 file containing spike data.
+         network_name (str): The name of the network within the HDF5 file from which to load spike data.
+         sort (bool, optional): Whether to sort the DataFrame by 'timestamps'. Defaults to True.
+         config (str, optional): Will label the cell type of each spike.
+         groupby (str or list of str, optional): The column(s) to group by. Defaults to 'pop_name'.
+
+     Returns:
+         pd.DataFrame: A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data.
+
+     Example:
+         df = load_spikes_to_df("spikes.h5", "cortex")
+     """
+     with h5py.File(spike_file) as f:
+         spikes_df = pd.DataFrame({
+             'node_ids': f['spikes'][network_name]['node_ids'],
+             'timestamps': f['spikes'][network_name]['timestamps']
+         })
+
+     if sort:
+         spikes_df.sort_values(by='timestamps', inplace=True, ignore_index=True)
+
+     if config:
+         nodes = load_nodes_from_config(config)
+         nodes = nodes[network_name]
+
+         # Check if 'groupby' is a string or a list of strings and handle accordingly
+         if isinstance(groupby, str):
+             spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
+         elif isinstance(groupby, list):
+             for group in groupby:
+                 spikes_df = spikes_df.merge(nodes[group], left_on='node_ids', right_index=True, how='left')
+
+     return spikes_df
+
+
+
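A usage sketch with hypothetical paths and network name:

from bmtool.analysis.spikes import load_spikes_to_df

spikes = load_spikes_to_df('output/spikes.h5', 'cortex',
                           config='simulation_config.json', groupby='pop_name')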
+ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
+                     time_points: Optional[Union[np.ndarray, list]] = None, frequeny: bool = False) -> np.ndarray:
+     """
+     Calculate the spike count or frequency histogram over specified time intervals.
+
+     Args:
+         spike_times (Union[np.ndarray, list]): Array or list of spike times in milliseconds.
+         time (Optional[Tuple[float, float, float]], optional): Tuple specifying (start, stop, step) in milliseconds.
+             Used to create evenly spaced time points if `time_points` is not provided. Default is None.
+         time_points (Optional[Union[np.ndarray, list]], optional): Array or list of specific time points for binning.
+             If provided, `time` is ignored. Default is None.
+         frequeny (bool, optional): If True, returns spike frequency in Hz; otherwise, returns spike count. Default is False.
+
+     Returns:
+         np.ndarray: Array of spike counts or frequencies, depending on the `frequeny` flag.
+
+     Raises:
+         ValueError: If both `time` and `time_points` are None.
+     """
+     if time_points is None:
+         if time is None:
+             raise ValueError("Either `time` or `time_points` must be provided.")
+         time_points = np.arange(*time)
+         dt = time[2]
+     else:
+         time_points = np.asarray(time_points).ravel()
+         dt = (time_points[-1] - time_points[0]) / (time_points.size - 1)
+
+     bins = np.append(time_points, time_points[-1] + dt)
+     spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
+
+     if frequeny:
+         spike_rate = 1000 / dt * spike_rate
+
+     return spike_rate
+
+
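A toy example of this internal helper; note the keyword really is spelled `frequeny` in this release:

import numpy as np
from bmtool.analysis.spikes import _pop_spike_rate

spike_times = np.array([5.0, 12.5, 14.0, 30.2])          # ms, toy data
counts = _pop_spike_rate(spike_times, time=(0, 40, 10))  # spike counts per 10 ms bin
rates = _pop_spike_rate(spike_times, time=(0, 40, 10), frequeny=True)  # Hz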
+ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
+                               config: Optional[str] = None, network_name: Optional[str] = None,
+                               save: bool = False, save_path: Optional[str] = None,
+                               normalize: bool = False) -> Dict[str, np.ndarray]:
+     """
+     Calculate the population spike rate for each population in the given spike data, with an option to normalize.
+
+     Args:
+         spikes (pd.DataFrame): A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'.
+         fs (float, optional): Sampling frequency in Hz, which determines the time bin size for calculating the spike rate. Default is 400.
+         t_start (float, optional): Start time (in milliseconds) for spike rate calculation. Default is 0.
+         t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data.
+         config (Optional[str], optional): Path to a configuration file containing node information, used to determine the correct number of nodes per population.
+             If None, node count is estimated from unique node spikes. Default is None.
+         network_name (Optional[str], optional): Name of the network used in the configuration file, allowing selection of nodes for that network.
+             Required if `config` is provided. Default is None.
+         save (bool, optional): Whether to save the calculated population spike rate to a file. Default is False.
+         save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. If `save` is True and `save_path` is None, a ValueError is raised.
+         normalize (bool, optional): Whether to normalize the spike rates for each population to a range of [0, 1]. Default is False.
+
+     Returns:
+         Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
+             If `normalize` is True, each population's spike rate is scaled to [0, 1].
+
+     Raises:
+         ValueError: If `save` is True but `save_path` is not provided.
+
+     Notes:
+         - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
+         - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
+
+     """
+     pop_spikes = {}
+     node_number = {}
+
+     if config is None:
+         print("Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect.")
+         print("You can provide a config to calculate the correct amount of nodes!")
+
+     if config:
+         if not network_name:
+             print("Grabbing first network; specify a network name to ensure correct node population is selected.")
+
+     for pop_name in spikes['pop_name'].unique():
+         ps = spikes[spikes['pop_name'] == pop_name]
+
+         if config:
+             nodes = load_nodes_from_config(config)
+             if network_name:
+                 nodes = nodes[network_name]
+             else:
+                 nodes = list(nodes.values())[0] if nodes else {}
+             nodes = nodes[nodes['pop_name'] == pop_name]
+             node_number[pop_name] = nodes.index.nunique()
+         else:
+             node_number[pop_name] = ps['node_ids'].nunique()
+
+         if t_stop is None:
+             t_stop = spikes['timestamps'].max()
+
+         filtered_spikes = spikes[
+             (spikes['pop_name'] == pop_name) &
+             (spikes['timestamps'] > t_start) &
+             (spikes['timestamps'] < t_stop)
+         ]
+         pop_spikes[pop_name] = filtered_spikes
+
+     time = np.array([t_start, t_stop, 1000 / fs])
+     pop_rspk = {p: _pop_spike_rate(spk['timestamps'], time) for p, spk in pop_spikes.items()}
+     spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}
+
+     # Normalize each spike rate series if normalize=True
+     if normalize:
+         spike_rate = {p: (sr - sr.min()) / (sr.max() - sr.min()) for p, sr in spike_rate.items()}
+
+     if save:
+         if save_path is None:
+             raise ValueError("save_path must be provided if save is True.")
+
+         os.makedirs(save_path, exist_ok=True)
+
+         save_file = os.path.join(save_path, 'spike_rate.h5')
+         with h5py.File(save_file, 'w') as f:
+             f.create_dataset('time', data=time)
+             grp = f.create_group('populations')
+             for p, rspk in spike_rate.items():
+                 pop_grp = grp.create_group(p)
+                 pop_grp.create_dataset('data', data=rspk)
+
+     return spike_rate
+
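A usage sketch building on the spikes DataFrame loaded earlier; paths and network name are hypothetical:

from bmtool.analysis.spikes import get_population_spike_rate

rates = get_population_spike_rate(spikes, fs=400.0,
                                  config='simulation_config.json',
                                  network_name='cortex', normalize=True)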
bmtool/bmplot.py CHANGED
@@ -762,7 +762,7 @@ def connector_percent_matrix(csv_path: str = None, exclude_strings=None, assemb_
      plt.tight_layout()
      plt.show()
 
- def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None,
+ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None, groupby:Optional[str] = 'pop_name',
             ax: Optional[Axes] = None,tstart: Optional[float] = None,tstop: Optional[float] = None,
             color_map: Optional[Dict[str, str]] = None) -> Axes:
      """
@@ -793,7 +793,7 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
      Notes:
      -----
      - If `config` is provided, the function merges population names from the node data with `spikes_df`.
-     - Each unique population (`pop_name`) in `spikes_df` will be represented by a different color if `color_map` is not specified.
+     - Each unique population from groupby in `spikes_df` will be represented by a different color if `color_map` is not specified.
      - If `color_map` is provided, it should contain colors for all unique `pop_name` values in `spikes_df`.
      """
      # Initialize axes if none provided
@@ -822,11 +822,11 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
          # Drop all intersecting columns except the join key column from df2
          spikes_df = spikes_df.drop(columns=common_columns)
          # merge nodes and spikes df
-         spikes_df = spikes_df.merge(nodes['pop_name'], left_on='node_ids', right_index=True, how='left')
+         spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
 
 
      # Get unique population names
-     unique_pop_names = spikes_df['pop_name'].unique()
+     unique_pop_names = spikes_df[groupby].unique()
 
      # Generate colors if no color_map is provided
      if color_map is None:
@@ -839,7 +839,7 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
          raise ValueError(f"color_map is missing colors for populations: {missing_colors}")
 
      # Plot each population with its specified or generated color
-     for pop_name, group in spikes_df.groupby('pop_name'):
+     for pop_name, group in spikes_df.groupby(groupby):
          ax.scatter(group['timestamps'], group['node_ids'], label=pop_name, color=color_map[pop_name], s=0.5)
 
      # Label axes
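With the new groupby argument, raster rows can be colored by any node attribute rather than only 'pop_name'. A sketch reusing the spikes DataFrame from above; the grouping column name is hypothetical:

from bmtool.bmplot import raster

ax = raster(spikes_df=spikes, config='simulation_config.json',
            network_name='cortex', groupby='model_name')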
bmtool/singlecell.py CHANGED
@@ -90,7 +90,11 @@ class CurrentClamp(object):
          self.inj_dur = inj_dur
          self.inj_amp = inj_amp * 1e-3 # pA to nA
 
-         self.cell = self.create_cell()
+         # sometimes people may put a hoc object in for the template name
+         if callable(template_name):
+             self.cell = template_name
+         else:
+             self.cell = self.create_cell()
          if post_init_function:
              eval(f"self.cell.{post_init_function}")
 
bmtool-0.6.6.4.dist-info/METADATA → bmtool-0.6.7.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: bmtool
- Version: 0.6.6.4
+ Version: 0.6.7.1
  Summary: BMTool
  Home-page: https://github.com/cyneuro/bmtool
  Download-URL:
bmtool-0.6.6.4.dist-info/RECORD → bmtool-0.6.7.1.dist-info/RECORD RENAMED
@@ -1,13 +1,16 @@
- bmtool/SLURM.py,sha256=KNt6X0vMuUvVr94OKleq3MAhOuuiCkHWEzzCwZbH-38,16679
+ bmtool/SLURM.py,sha256=AKxu_Ln9wuCBVLdOJP4yAN59jYt222DM-iUlsQojNvY,18145
  bmtool/__init__.py,sha256=ZStTNkAJHJxG7Pwiy5UgCzC4KlhMS5pUNPtUJZVwL_Y,136
  bmtool/__main__.py,sha256=TmFkmDxjZ6250nYD4cgGhn-tbJeEm0u-EMz2ajAN9vE,650
- bmtool/bmplot.py,sha256=Im-Jrv8TK3CmTtksFzHrVogAve0l9ZwRrCW4q2MFRiA,53966
+ bmtool/bmplot.py,sha256=iTK6q8XEqc8QEAKR152ut_1qdtnMoEe1Uq-4dCrkCA0,53992
  bmtool/connectors.py,sha256=hWkUUcJ4tmas8NDOFPPjQT-TgTlPcpjuZsYyAW2WkPA,72242
  bmtool/graphs.py,sha256=K8BiughRUeXFVvAgo8UzrwpSClIVg7UfmIcvtEsEsk0,6020
  bmtool/manage.py,sha256=_lCU0qBQZ4jSxjzAJUd09JEetb--cud7KZgxQFbLGSY,657
  bmtool/plot_commands.py,sha256=Tqujyf0c0u8olhiHOMwgUSJXIIE1hgjv6otb25G9cA0,12298
- bmtool/singlecell.py,sha256=MQiLucsI6OBIjtcJra3Z9PTFQOE-Zn5ST-R9SmFvrbQ,27049
+ bmtool/singlecell.py,sha256=XZAT_2n44EhwqVLnk3qur9aO7oJ-10axJZfwPBslM88,27219
  bmtool/synapses.py,sha256=gIkfLhKDG2dHHCVJJoKuQrFn_Qut843bfk_-s97wu6c,54553
+ bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bmtool/analysis/lfp.py,sha256=Zp-aJ8x2KmsI3h_mqvq4u9ixFYu4n0CxQpgdbYnrtYE,14909
+ bmtool/analysis/spikes.py,sha256=23k_wFOC9pKQgetxMp1V2z6cZaW2eoIZAZyjFiTfrrM,8560
  bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bmtool/debug/commands.py,sha256=AwtcR7BUUheM0NxvU1Nu234zCdpobhJv5noX8x5K2vY,583
  bmtool/debug/debug.py,sha256=xqnkzLiH3s-tS26Y5lZZL62qR2evJdi46Gud-HzxEN4,207
@@ -16,9 +19,9 @@ bmtool/util/commands.py,sha256=zJF-fiLk0b8LyzHDfvewUyS7iumOxVnj33IkJDzux4M,64396
  bmtool/util/util.py,sha256=00vOAwTVIifCqouBoFoT0lBashl4fCalrk8fhg_Uq4c,56654
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bmtool/util/neuron/celltuner.py,sha256=xSRpRN6DhPFz4q5buq_W8UmsD7BbUrkzYBEbKVloYss,87194
- bmtool-0.6.6.4.dist-info/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
- bmtool-0.6.6.4.dist-info/METADATA,sha256=loQf1_yqp2RYGGrghu71zSt4TnBtlmBpOSGXlm_21JY,20226
- bmtool-0.6.6.4.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- bmtool-0.6.6.4.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
- bmtool-0.6.6.4.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
- bmtool-0.6.6.4.dist-info/RECORD,,
+ bmtool-0.6.7.1.dist-info/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
+ bmtool-0.6.7.1.dist-info/METADATA,sha256=qlkGt9KNUlVP7C3J1xgvj23YUHaPD7SlVFzhVhSUceg,20226
+ bmtool-0.6.7.1.dist-info/WHEEL,sha256=nn6H5-ilmfVryoAQl3ZQ2l8SH5imPWFpm1A5FgEuFV4,91
+ bmtool-0.6.7.1.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
+ bmtool-0.6.7.1.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
+ bmtool-0.6.7.1.dist-info/RECORD,,
bmtool-0.6.6.4.dist-info/WHEEL → bmtool-0.6.7.1.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.8.0)
+ Generator: setuptools (75.8.1)
  Root-Is-Purelib: true
  Tag: py3-none-any