bmtool 0.7.1.1__py3-none-any.whl → 0.7.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,7 @@
2
2
  Module for entrainment analysis
3
3
  """
4
4
 
5
- from typing import Dict, List
5
+ from typing import Dict, List, Optional, Union
6
6
 
7
7
  import numba
8
8
  import numpy as np
@@ -41,7 +41,7 @@ def align_spike_times_with_lfp(lfp: xr.DataArray, timestamps: np.ndarray) -> np.
41
41
  (timestamps >= lfp.time.values[0]) & (timestamps <= lfp.time.values[-1])
42
42
  ].copy()
43
43
  # set the time axis of the spikes to match the lfp
44
- timestamps = timestamps - lfp.time.values[0]
44
+ # timestamps = timestamps - lfp.time.values[0]
45
45
  return timestamps
46
46
 
47
47
 
@@ -127,33 +127,33 @@ def calculate_signal_signal_plv(
127
127
  return plv
128
128
 
129
129
 
130
- def calculate_spike_lfp_plv(
131
- spike_times: np.ndarray = None,
132
- lfp_data=None,
133
- spike_fs: float = None,
134
- lfp_fs: float = None,
135
- filter_method: str = "butter",
136
- freq_of_interest: float = None,
137
- lowcut: float = None,
138
- highcut: float = None,
130
+ def _get_spike_phases(
131
+ spike_times: np.ndarray,
132
+ lfp_data: Union[np.ndarray, xr.DataArray],
133
+ spike_fs: float,
134
+ lfp_fs: float,
135
+ filter_method: str = "wavelet",
136
+ freq_of_interest: Optional[float] = None,
137
+ lowcut: Optional[float] = None,
138
+ highcut: Optional[float] = None,
139
139
  bandwidth: float = 2.0,
140
- filtered_lfp_phase: np.ndarray = None,
141
- ) -> float:
140
+ filtered_lfp_phase: Optional[Union[np.ndarray, xr.DataArray]] = None,
141
+ ) -> np.ndarray:
142
142
  """
143
- Calculate spike-lfp unbiased phase locking value
143
+ Helper function to get spike phases from LFP data.
144
144
 
145
145
  Parameters
146
146
  ----------
147
147
  spike_times : np.ndarray
148
148
  Array of spike times
149
- lfp_data : np.ndarray
149
+ lfp_data : Union[np.ndarray, xr.DataArray]
150
150
  Local field potential time series data. Not required if filtered_lfp_phase is provided.
151
- spike_fs : float, optional
152
- Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
151
+ spike_fs : float
152
+ Sampling frequency in Hz of the spike times
153
153
  lfp_fs : float
154
154
  Sampling frequency in Hz of the LFP data
155
155
  filter_method : str, optional
156
- Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
156
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'wavelet')
157
157
  freq_of_interest : float, optional
158
158
  Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
159
159
  lowcut : float, optional
@@ -167,12 +167,9 @@ def calculate_spike_lfp_plv(
167
167
 
168
168
  Returns
169
169
  -------
170
- float
171
- Phase Locking Value (unbiased)
170
+ np.ndarray
171
+ Array of phases at spike times
172
172
  """
173
-
174
- if spike_fs is None:
175
- spike_fs = lfp_fs
176
173
  # Convert spike times to sample indices
177
174
  spike_times_seconds = spike_times / spike_fs
178
175
 
@@ -194,7 +191,7 @@ def calculate_spike_lfp_plv(
194
191
  valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
195
192
 
196
193
  if len(valid_indices) <= 1:
197
- return 0
194
+ return np.array([])
198
195
 
199
196
  # Get instantaneous phase
200
197
  if filtered_lfp_phase is None:
@@ -212,10 +209,73 @@ def calculate_spike_lfp_plv(
212
209
 
213
210
  # Get phases at spike times
214
211
  if isinstance(instantaneous_phase, xr.DataArray):
215
- spike_phases = instantaneous_phase.sel(time=valid_indices).values
212
+ spike_phases = instantaneous_phase.sel(time=valid_indices, method="nearest").values
216
213
  else:
217
214
  spike_phases = instantaneous_phase[valid_indices]
218
215
 
216
+ return spike_phases
217
+
218
+
219
+ def calculate_spike_lfp_plv(
220
+ spike_times: np.ndarray = None,
221
+ lfp_data=None,
222
+ spike_fs: float = None,
223
+ lfp_fs: float = None,
224
+ filter_method: str = "butter",
225
+ freq_of_interest: float = None,
226
+ lowcut: float = None,
227
+ highcut: float = None,
228
+ bandwidth: float = 2.0,
229
+ filtered_lfp_phase: np.ndarray = None,
230
+ ) -> float:
231
+ """
232
+ Calculate spike-lfp unbiased phase locking value
233
+
234
+ Parameters
235
+ ----------
236
+ spike_times : np.ndarray
237
+ Array of spike times
238
+ lfp_data : np.ndarray
239
+ Local field potential time series data. Not required if filtered_lfp_phase is provided.
240
+ spike_fs : float, optional
241
+ Sampling frequency in Hz of the spike times, only needed if spike times and LFP have different sampling rates
242
+ lfp_fs : float
243
+ Sampling frequency in Hz of the LFP data
244
+ filter_method : str, optional
245
+ Method to use for filtering, either 'wavelet' or 'butter' (default: 'butter')
246
+ freq_of_interest : float, optional
247
+ Desired frequency for wavelet phase extraction, required if filter_method='wavelet'
248
+ lowcut : float, optional
249
+ Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
250
+ highcut : float, optional
251
+ Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
252
+ bandwidth : float, optional
253
+ Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
254
+ filtered_lfp_phase : np.ndarray, optional
255
+ Pre-computed instantaneous phase of the filtered LFP. If provided, the function will skip the filtering step.
256
+
257
+ Returns
258
+ -------
259
+ float
260
+ Phase Locking Value (unbiased)
261
+ """
262
+
263
+ spike_phases = _get_spike_phases(
264
+ spike_times=spike_times,
265
+ lfp_data=lfp_data,
266
+ spike_fs=spike_fs,
267
+ lfp_fs=lfp_fs,
268
+ filter_method=filter_method,
269
+ freq_of_interest=freq_of_interest,
270
+ lowcut=lowcut,
271
+ highcut=highcut,
272
+ bandwidth=bandwidth,
273
+ filtered_lfp_phase=filtered_lfp_phase,
274
+ )
275
+
276
+ if len(spike_phases) <= 1:
277
+ return 0
278
+
219
279
  # Number of spikes
220
280
  N = len(spike_phases)
221
281
 
@@ -316,57 +376,26 @@ def calculate_ppc(
316
376
  float
317
377
  Pairwise Phase Consistency value
318
378
  """
319
- if spike_fs is None:
320
- spike_fs = lfp_fs
321
- # Convert spike times to sample indices
322
- spike_times_seconds = spike_times / spike_fs
323
-
324
- # Then convert from seconds to samples at the new sampling rate
325
- spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
326
379
 
327
- # Filter indices to ensure they're within bounds of the LFP signal
328
- if isinstance(lfp_data, xr.DataArray):
329
- if filtered_lfp_phase is not None:
330
- valid_indices = align_spike_times_with_lfp(
331
- lfp=filtered_lfp_phase, timestamps=spike_indices
332
- )
333
- else:
334
- valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
335
- elif isinstance(lfp_data, np.ndarray):
336
- if filtered_lfp_phase is not None:
337
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
338
- else:
339
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
340
-
341
- if len(valid_indices) <= 1:
380
+ spike_phases = _get_spike_phases(
381
+ spike_times=spike_times,
382
+ lfp_data=lfp_data,
383
+ spike_fs=spike_fs,
384
+ lfp_fs=lfp_fs,
385
+ filter_method=filter_method,
386
+ freq_of_interest=freq_of_interest,
387
+ lowcut=lowcut,
388
+ highcut=highcut,
389
+ bandwidth=bandwidth,
390
+ filtered_lfp_phase=filtered_lfp_phase,
391
+ )
392
+
393
+ if len(spike_phases) <= 1:
342
394
  return 0
343
395
 
344
- # Get instantaneous phase
345
- if filtered_lfp_phase is None:
346
- instantaneous_phase = get_lfp_phase(
347
- lfp_data=lfp_data,
348
- filter_method=filter_method,
349
- freq_of_interest=freq_of_interest,
350
- lowcut=lowcut,
351
- highcut=highcut,
352
- bandwidth=bandwidth,
353
- fs=lfp_fs,
354
- )
355
- else:
356
- instantaneous_phase = filtered_lfp_phase
357
-
358
- # Get phases at spike times
359
- if isinstance(instantaneous_phase, xr.DataArray):
360
- spike_phases = instantaneous_phase.sel(time=valid_indices).values
361
- else:
362
- spike_phases = instantaneous_phase[valid_indices]
363
-
364
396
  n_spikes = len(spike_phases)
365
397
 
366
398
  # Calculate PPC (Pairwise Phase Consistency)
367
- if n_spikes <= 1:
368
- return 0
369
-
370
399
  # Explicit calculation of pairwise phase consistency
371
400
  # Vectorized computation for efficiency
372
401
  if ppc_method == "numpy":
@@ -434,56 +463,25 @@ def calculate_ppc2(
434
463
  Pairwise Phase Consistency 2 (PPC2) value
435
464
  """
436
465
 
437
- if spike_fs is None:
438
- spike_fs = lfp_fs
439
- # Convert spike times to sample indices
440
- spike_times_seconds = spike_times / spike_fs
441
-
442
- # Then convert from seconds to samples at the new sampling rate
443
- spike_indices = np.round(spike_times_seconds * lfp_fs).astype(int)
444
-
445
- # Filter indices to ensure they're within bounds of the LFP signal
446
- if isinstance(lfp_data, xr.DataArray):
447
- if filtered_lfp_phase is not None:
448
- valid_indices = align_spike_times_with_lfp(
449
- lfp=filtered_lfp_phase, timestamps=spike_indices
450
- )
451
- else:
452
- valid_indices = align_spike_times_with_lfp(lfp=lfp_data, timestamps=spike_indices)
453
- elif isinstance(lfp_data, np.ndarray):
454
- if filtered_lfp_phase is not None:
455
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(filtered_lfp_phase)]
456
- else:
457
- valid_indices = [idx for idx in spike_indices if 0 <= idx < len(lfp_data)]
458
-
459
- if len(valid_indices) <= 1:
466
+ spike_phases = _get_spike_phases(
467
+ spike_times=spike_times,
468
+ lfp_data=lfp_data,
469
+ spike_fs=spike_fs,
470
+ lfp_fs=lfp_fs,
471
+ filter_method=filter_method,
472
+ freq_of_interest=freq_of_interest,
473
+ lowcut=lowcut,
474
+ highcut=highcut,
475
+ bandwidth=bandwidth,
476
+ filtered_lfp_phase=filtered_lfp_phase,
477
+ )
478
+
479
+ if len(spike_phases) <= 1:
460
480
  return 0
461
481
 
462
- # Get instantaneous phase
463
- if filtered_lfp_phase is None:
464
- instantaneous_phase = get_lfp_phase(
465
- lfp_data=lfp_data,
466
- filter_method=filter_method,
467
- freq_of_interest=freq_of_interest,
468
- lowcut=lowcut,
469
- highcut=highcut,
470
- bandwidth=bandwidth,
471
- fs=lfp_fs,
472
- )
473
- else:
474
- instantaneous_phase = filtered_lfp_phase
475
-
476
- # Get phases at spike times
477
- if isinstance(instantaneous_phase, xr.DataArray):
478
- spike_phases = instantaneous_phase.sel(time=valid_indices).values
479
- else:
480
- spike_phases = instantaneous_phase[valid_indices]
481
482
  # Calculate PPC2 according to Vinck et al. (2010), Equation 6
482
483
  n = len(spike_phases)
483
484
 
484
- if n <= 1:
485
- return 0
486
-
487
485
  # Convert phases to unit vectors in the complex plane
488
486
  unit_vectors = np.exp(1j * spike_phases)
489
487
 
@@ -575,7 +573,7 @@ def calculate_entrainment_per_cell(
575
573
  for pop in pop_names:
576
574
  skip_count = 0
577
575
  pop_spikes = spike_df[spike_df["pop_name"] == pop]
578
- nodes = pop_spikes["node_ids"].unique()
576
+ nodes = sorted(pop_spikes["node_ids"].unique()) # sort so all nodes are processed in order
579
577
  entrainment_dict[pop] = {}
580
578
  print(f"Processing {pop} population")
581
579
  for node in tqdm(nodes):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.7.1.1
3
+ Version: 0.7.1.2
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -8,7 +8,7 @@ bmtool/plot_commands.py,sha256=Dxm_RaT4CtHnfsltTtUopJ4KVbfhxtktEB_b7bFEXII,12716
8
8
  bmtool/singlecell.py,sha256=I2yolbAnNC8qpnRkNdnDCLidNW7CktmBuRrcowMZJ3A,45041
9
9
  bmtool/synapses.py,sha256=wlRY7IixefPzafqG6k2sPIK4s6PLG9Kct-oCaVR29wA,64269
10
10
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- bmtool/analysis/entrainment.py,sha256=Nzrt1pKHEDzrx9TYnsE0sbevxl4iht1fymA29khq8Pg,26841
11
+ bmtool/analysis/entrainment.py,sha256=sjfSxPs1Y0dnEtX9a3IIMEeQ09L6WbhO3KMt-O8SN64,26480
12
12
  bmtool/analysis/lfp.py,sha256=S2JvxkjcK3-EH93wCrhqNSFY6cX7fOq74pz64ibHKrc,26556
13
13
  bmtool/analysis/netcon_reports.py,sha256=VnPZNKPaQA7oh1q9cIatsqQudm4cOtzNtbGPXoiDCD0,2909
14
14
  bmtool/analysis/spikes.py,sha256=ScP4EeX2QuEd_FXyj3W0WWgZKvZwwneuWuKFe3xwaCY,15115
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=Nn-R-4e9g8ZhSPZvTkr38xeKRPfEMANB9Lugppj82UI,68564
26
26
  bmtool/util/util.py,sha256=owce5BEusZO_8T5x05N2_B583G26vWAy7QX29V0Pj0Y,62818
27
27
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  bmtool/util/neuron/celltuner.py,sha256=lokRLUM1rsdSYBYrNbLBBo39j14mm8TBNVNRnSlhHCk,94868
29
- bmtool-0.7.1.1.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
- bmtool-0.7.1.1.dist-info/METADATA,sha256=ONMkDOU2sGv2xzC9jbP5z9HdGFI2HGO_2fwS6cIBj3w,3623
31
- bmtool-0.7.1.1.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
32
- bmtool-0.7.1.1.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
- bmtool-0.7.1.1.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
- bmtool-0.7.1.1.dist-info/RECORD,,
29
+ bmtool-0.7.1.2.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
+ bmtool-0.7.1.2.dist-info/METADATA,sha256=LO1VUW641H9cxsHp1vi809dqrJoZGc4GfJaKqTbOZGc,3623
31
+ bmtool-0.7.1.2.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
32
+ bmtool-0.7.1.2.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
+ bmtool-0.7.1.2.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
+ bmtool-0.7.1.2.dist-info/RECORD,,