bmtool 0.7.1__py3-none-any.whl → 0.7.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,7 @@
2
2
  Module for entrainment analysis
3
3
  """
4
4
 
5
- from typing import Dict, List
5
+ from typing import Dict, List, Optional, Union
6
6
 
7
7
  import numba
8
8
  import numpy as np
@@ -41,7 +41,7 @@ def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.
41
41
  (timestamps >= lfp.time.values[0]) & (timestamps <= lfp.time.values[-1])
42
42
  ].copy()
43
43
  # set the time axis of the spikes to match the lfp
44
- timestamps = timestamps - lfp.time.values[0]
44
+ # timestamps = timestamps - lfp.time.values[0]
45
45
  return timestamps
46
46
 
47
47
 
@@ -127,33 +127,33 @@ def calculate_signal_signal_plv(
127
127
  return plv
128
128
 
129
129
 
130
- def calculate_spike_lfp_plv(
131
- spike_times: np.ndarray = None,
132
- lfp_data=None,
133
- spike_fs: float = None,
134
- lfp_fs: float = None,
135
- filter_method: str = "butter",
136
- freq_of_interest: float = None,
137
- lowcut: float = None,
138
- highcut: float = None,
130
+ def _get_spike_phases(
131
+ spike_times: np.ndarray,
132
+ lfp_data: Union[np.ndarray, xr.DataArray],
133
+ spike_fs: float,
134
+ lfp_fs: float,
135
+ filter_method: str = "wavelet",
136
+ freq_of_interest: Optional[float] = None,
137
+ lowcut: Optional[float] = None,
138
+ highcut: Optional[float] = None,
139
139
  bandwidth: float = 2.0,
140
- filtered_lfp_phase: np.ndarray = None,
141
- ) -> float:
140
+ filtered_lfp_phase: Optional[Union[np.ndarray, xr.DataArray]] = None,
141
+ ) -> np.ndarray:
142
142
  """
143
- Calculate spike-lfp unbiased phase locking value
143
+ Helper function to get spike phases from LFP data.
144
144
 
145
145
  Parameters
146
146
  ----------
147
147
  spike_times : np.ndarray
148
148
  Array of spike times
149
- lfp_data : np.ndarray
149
+ lfp_data : Union[np.ndarray, xr.DataArray]
150
150
  Local field potential time series data. Not required if filtered_lfp_phase is provided.
151
- spike_fs : float, optional
152
- Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
151
+ spike_fs : float
152
+ Sampling frequency in Hz of the spike times
153
153
  lfp_fs : float
154
154
  Sampling frequency in Hz of the LFP data
155
155
  filter_method : str, optional
156
- Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
156
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
157
157
  freq_of_interest : float, optional
158
158
  Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
159
159
  lowcut : float, optional
@@ -167,12 +167,9 @@ def calculate_spike_lfp_plv(
167
167
 
168
168
  Returns
169
169
  -------
170
- float
171
- Phase Locking Value (unbiased)
170
+ np.ndarray
171
+ Array of phases at spike times
172
172
  """
173
-
174
- if spike_fs is None:
175
- spike_fs = lfp_fs
176
173
  # Convert spike times to sample indices
177
174
  spike_times_seconds = spike_times / spike_fs
178
175
 
@@ -180,10 +177,21 @@ def calculate_spike_lfp_plv(
180
177
  spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
181
178
 
182
179
  # Filter indices to ensure they're within bounds of the LFP signal
183
- valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
180
+ if isinstance(lfp_data, xr.DataArray):
181
+ if filtered_lfp_phase is not None:
182
+ valid_indices = align_spike_times_with_lfp(
183
+ lfp=filtered_lfp_phase, timestamps=spike_indices
184
+ )
185
+ else:
186
+ valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
187
+ elif isinstance(lfp_data, np.ndarray):
188
+ if filtered_lfp_phase is not None:
189
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
190
+ else:
191
+ valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
184
192
 
185
193
  if len(valid_indices) <= 1:
186
- return 0
194
+ return np.array([])
187
195
 
188
196
  # Get instantaneous phase
189
197
  if filtered_lfp_phase is None:
@@ -200,7 +208,73 @@ def calculate_spike_lfp_plv(
200
208
  instantaneous_phase = filtered_lfp_phase
201
209
 
202
210
  # Get phases at spike times
203
- spike_phases = instantaneous_phase.sel(time=valid_indices).values
211
+ if isinstance(instantaneous_phase, xr.DataArray):
212
+ spike_phases = instantaneous_phase.sel(time=valid_indices, method="nearest").values
213
+ else:
214
+ spike_phases = instantaneous_phase[valid_indices]
215
+
216
+ return spike_phases
217
+
218
+
219
+ def calculate_spike_lfp_plv(
220
+ spike_times: np.ndarray = None,
221
+ lfp_data=None,
222
+ spike_fs: float = None,
223
+ lfp_fs: float = None,
224
+ filter_method: str = "butter",
225
+ freq_of_interest: float = None,
226
+ lowcut: float = None,
227
+ highcut: float = None,
228
+ bandwidth: float = 2.0,
229
+ filtered_lfp_phase: np.ndarray = None,
230
+ ) -> float:
231
+ """
232
+ Calculate spike-lfp unbiased phase locking value
233
+
234
+ Parameters
235
+ ----------
236
+ spike_times : np.ndarray
237
+ Array of spike times
238
+ lfp_data : np.ndarray
239
+ Local field potential time series data. Not required if filtered_lfp_phase is provided.
240
+ spike_fs : float, optional
241
+ Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
242
+ lfp_fs : float
243
+ Sampling frequency in Hz of the LFP data
244
+ filter_method : str, optional
245
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
246
+ freq_of_interest : float, optional
247
+ Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
248
+ lowcut : float, optional
249
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
250
+ highcut : float, optional
251
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
252
+ bandwidth : float, optional
253
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
254
+ filtered_lfp_phase : np.ndarray, optional
255
+ Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
256
+
257
+ Returns
258
+ -------
259
+ float
260
+ Phase Locking Value (unbiased)
261
+ """
262
+
263
+ spike_phases = _get_spike_phases(
264
+ spike_times=spike_times,
265
+ lfp_data=lfp_data,
266
+ spike_fs=spike_fs,
267
+ lfp_fs=lfp_fs,
268
+ filter_method=filter_method,
269
+ freq_of_interest=freq_of_interest,
270
+ lowcut=lowcut,
271
+ highcut=highcut,
272
+ bandwidth=bandwidth,
273
+ filtered_lfp_phase=filtered_lfp_phase,
274
+ )
275
+
276
+ if len(spike_phases) <= 1:
277
+ return 0
204
278
 
205
279
  # Number of spikes
206
280
  N = len(spike_phases)
@@ -302,46 +376,26 @@ def calculate_ppc(
302
376
  float
303
377
  Pairwise Phase Consistency value
304
378
  """
305
- if spike_fs is None:
306
- spike_fs = lfp_fs
307
- # Convert spike times to sample indices
308
- spike_times_seconds = spike_times / spike_fs
309
379
 
310
- # Then convert from seconds to samples at the new sampling rate
311
- spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
312
-
313
- # Filter indices to ensure they're within bounds of the LFP signal
314
- if filtered_lfp_phase is not None:
315
- valid_indices = align_spike_times_with_lfp(lfp=filtered_lfp_phase, timestamps=spike_indices)
316
- else:
317
- valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
318
-
319
- if len(valid_indices) <= 1:
380
+ spike_phases = _get_spike_phases(
381
+ spike_times=spike_times,
382
+ lfp_data=lfp_data,
383
+ spike_fs=spike_fs,
384
+ lfp_fs=lfp_fs,
385
+ filter_method=filter_method,
386
+ freq_of_interest=freq_of_interest,
387
+ lowcut=lowcut,
388
+ highcut=highcut,
389
+ bandwidth=bandwidth,
390
+ filtered_lfp_phase=filtered_lfp_phase,
391
+ )
392
+
393
+ if len(spike_phases) <= 1:
320
394
  return 0
321
395
 
322
- # Get instantaneous phase
323
- if filtered_lfp_phase is None:
324
- instantaneous_phase = get_lfp_phase(
325
- lfp_data=lfp_data,
326
- filter_method=filter_method,
327
- freq_of_interest=freq_of_interest,
328
- lowcut=lowcut,
329
- highcut=highcut,
330
- bandwidth=bandwidth,
331
- fs=lfp_fs,
332
- )
333
- else:
334
- instantaneous_phase = filtered_lfp_phase
335
-
336
- # Get phases at spike times
337
- spike_phases = instantaneous_phase.sel(time=valid_indices).values
338
-
339
396
  n_spikes = len(spike_phases)
340
397
 
341
398
  # Calculate PPC (Pairwise Phase Consistency)
342
- if n_spikes <= 1:
343
- return 0
344
-
345
399
  # Explicit calculation of pairwise phase consistency
346
400
  # Vectorized computation for efficiency
347
401
  if ppc_method == "numpy":
@@ -409,45 +463,25 @@ def calculate_ppc2(
409
463
  Pairwise Phase Consistency 2 (PPC2) value
410
464
  """
411
465
 
412
- if spike_fs is None:
413
- spike_fs = lfp_fs
414
- # Convert spike times to sample indices
415
- spike_times_seconds = spike_times / spike_fs
416
-
417
- # Then convert from seconds to samples at the new sampling rate
418
- spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
419
-
420
- # Filter indices to ensure they're within bounds of the LFP signal
421
- if filtered_lfp_phase is not None:
422
- valid_indices = align_spike_times_with_lfp(lfp=filtered_lfp_phase, timestamps=spike_indices)
423
- else:
424
- valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
425
-
426
- if len(valid_indices) <= 1:
466
+ spike_phases = _get_spike_phases(
467
+ spike_times=spike_times,
468
+ lfp_data=lfp_data,
469
+ spike_fs=spike_fs,
470
+ lfp_fs=lfp_fs,
471
+ filter_method=filter_method,
472
+ freq_of_interest=freq_of_interest,
473
+ lowcut=lowcut,
474
+ highcut=highcut,
475
+ bandwidth=bandwidth,
476
+ filtered_lfp_phase=filtered_lfp_phase,
477
+ )
478
+
479
+ if len(spike_phases) <= 1:
427
480
  return 0
428
481
 
429
- # Get instantaneous phase
430
- if filtered_lfp_phase is None:
431
- instantaneous_phase = get_lfp_phase(
432
- lfp_data=lfp_data,
433
- filter_method=filter_method,
434
- freq_of_interest=freq_of_interest,
435
- lowcut=lowcut,
436
- highcut=highcut,
437
- bandwidth=bandwidth,
438
- fs=lfp_fs,
439
- )
440
- else:
441
- instantaneous_phase = filtered_lfp_phase
442
-
443
- # Get phases at spike times
444
- spike_phases = instantaneous_phase.sel(time=valid_indices).values
445
482
  # Calculate PPC2 according to Vinck et al. (2010), Equation 6
446
483
  n = len(spike_phases)
447
484
 
448
- if n <= 1:
449
- return 0
450
-
451
485
  # Convert phases to unit vectors in the complex plane
452
486
  unit_vectors = np.exp(1j * spike_phases)
453
487
 
@@ -539,7 +573,7 @@ def calculate_entrainment_per_cell(
539
573
  for pop in pop_names:
540
574
  skip_count = 0
541
575
  pop_spikes = spike_df[spike_df["pop_name"] == pop]
542
- nodes = pop_spikes["node_ids"].unique()
576
+ nodes = sorted(pop_spikes["node_ids"].unique()) # sort so all nodes are processed in order
543
577
  entrainment_dict[pop] = {}
544
578
  print(f"Processing {pop} population")
545
579
  for node in tqdm(nodes):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.7.1
3
+ Version: 0.7.1.2
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -8,7 +8,7 @@ bmtool/plot_commands.py,sha256=Dxm_RaT4CtHnfsltTtUopJ4KVbfhxtktEB_b7bFEXII,12716
8
8
  bmtool/singlecell.py,sha256=I2yolbAnNC8qpnRkNdnDCLidNW7CktmBuRrcowMZJ3A,45041
9
9
  bmtool/synapses.py,sha256=wlRY7IixefPzafqG6k2sPIK4s6PLG9Kct-oCaVR29wA,64269
10
10
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- bmtool/analysis/entrainment.py,sha256=UVYhakyFGmH56UZ2jlshOH86rIKc9hSI-U2-kK2yp7o,25190
11
+ bmtool/analysis/entrainment.py,sha256=sjfSxPs1Y0dnEtX9a3IIMEeQ09L6WbhO3KMt-O8SN64,26480
12
12
  bmtool/analysis/lfp.py,sha256=S2JvxkjcK3-EH93wCrhqNSFY6cX7fOq74pz64ibHKrc,26556
13
13
  bmtool/analysis/netcon_reports.py,sha256=VnPZNKPaQA7oh1q9cIatsqQudm4cOtzNtbGPXoiDCD0,2909
14
14
  bmtool/analysis/spikes.py,sha256=ScP4EeX2QuEd_FXyj3W0WWgZKvZwwneuWuKFe3xwaCY,15115
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=Nn-R-4e9g8ZhSPZvTkr38xeKRPfEMANB9Lugppj82UI,68564
26
26
  bmtool/util/util.py,sha256=owce5BEusZO_8T5x05N2_B583G26vWAy7QX29V0Pj0Y,62818
27
27
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  bmtool/util/neuron/celltuner.py,sha256=lokRLUM1rsdSYBYrNbLBBo39j14mm8TBNVNRnSlhHCk,94868
29
- bmtool-0.7.1.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
- bmtool-0.7.1.dist-info/METADATA,sha256=qVCMtNEx1YXI0KlMV3hLlhTSdEbbv5xrj2V1ZFWY0ho,3621
31
- bmtool-0.7.1.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
32
- bmtool-0.7.1.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
- bmtool-0.7.1.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
- bmtool-0.7.1.dist-info/RECORD,,
29
+ bmtool-0.7.1.2.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
+ bmtool-0.7.1.2.dist-info/METADATA,sha256=LO1VUW641H9cxsHp1vi809dqrJoZGc4GfJaKqTbOZGc,3623
31
+ bmtool-0.7.1.2.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
32
+ bmtool-0.7.1.2.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
+ bmtool-0.7.1.2.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
+ bmtool-0.7.1.2.dist-info/RECORD,,