bmtool 0.7.1.4__py3-none-any.whl → 0.7.1.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
bmtool/analysis/entrainment.py CHANGED
@@ -714,7 +714,17 @@ def calculate_spike_rate_power_correlation(
      return correlation_results, frequencies
 
 
- def get_spikes_in_cycle(spike_df, lfp_data, spike_fs=1000, lfp_fs=400, band=(30, 80)):
+ def get_spikes_in_cycle(
+     spike_df,
+     lfp_data,
+     spike_fs=1000,
+     lfp_fs=400,
+     filter_method="butter",
+     lowcut=None,
+     highcut=None,
+     bandwidth=2.0,
+     freq_of_interest=None,
+ ):
      """
      Analyze spike timing relative to oscillation phases.
 
@@ -733,12 +743,15 @@ def get_spikes_in_cycle(spike_df, lfp_data, spike_fs=1000, lfp_fs=400, band=(30,
      phase_data : dict
          Dictionary containing phase values for each spike and neuron population
      """
-     filtered_lfp = butter_bandpass_filter(lfp_data, band[0], band[1], lfp_fs)
-
-     # Calculate phase using Hilbert transform
-     analytic_signal = signal.hilbert(filtered_lfp)
-     phase = np.angle(analytic_signal)
-     amplitude = np.abs(analytic_signal)
+     phase = get_lfp_phase(
+         lfp_data=lfp_data,
+         fs=lfp_fs,
+         filter_method=filter_method,
+         lowcut=lowcut,
+         highcut=highcut,
+         bandwidth=bandwidth,
+         freq_of_interest=freq_of_interest,
+     )
 
      # Get unique neuron populations
      neuron_pops = spike_df["pop_name"].unique()
@@ -763,4 +776,4 @@ def get_spikes_in_cycle(spike_df, lfp_data, spike_fs=1000, lfp_fs=400, band=(30,
          valid_samples = spike_indices[valid_indices]
          phase_data[pop] = phase[valid_samples]
 
-     return phase_data, filtered_lfp, phase, amplitude
+     return phase_data
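In 0.7.1.6 the hard-coded `band=(30, 80)` tuple is gone: callers pass explicit filter parameters that are forwarded to `get_lfp_phase`, and only the per-population phase dictionary is returned instead of the old 4-tuple. A minimal migration sketch (toy data and the import path are assumptions, not part of the diff):

```python
# Hypothetical usage sketch; assumes get_spikes_in_cycle lives in
# bmtool.analysis.entrainment as this diff suggests.
import numpy as np
import pandas as pd
from bmtool.analysis.entrainment import get_spikes_in_cycle

rng = np.random.default_rng(0)
lfp = rng.standard_normal(40_000)  # toy LFP trace sampled at lfp_fs = 400 Hz
spike_df = pd.DataFrame({
    "timestamps": rng.uniform(0, 100_000, size=500),  # spike times in ms
    "node_ids": rng.integers(0, 50, size=500),
    "pop_name": rng.choice(["PV", "Pyr"], size=500),
})

# 0.7.1.4: phase_data, filtered_lfp, phase, amplitude = get_spikes_in_cycle(
#              spike_df, lfp, band=(30, 80))
# 0.7.1.6: pass the band edges explicitly; only phase_data comes back.
phase_data = get_spikes_in_cycle(
    spike_df,
    lfp,
    spike_fs=1000,
    lfp_fs=400,
    filter_method="butter",
    lowcut=30,   # former band[0]
    highcut=80,  # former band[1]
)
```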
bmtool/analysis/spikes.py CHANGED
@@ -3,11 +3,12 @@ Module for processing BMTK spikes output.
  """
 
  import os
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import List, Optional, Tuple, Union
 
  import h5py
  import numpy as np
  import pandas as pd
+ import xarray as xr
  from scipy.stats import mannwhitneyu
 
  from bmtool.util.util import load_nodes_from_config
@@ -213,9 +214,11 @@ def get_population_spike_rate(
      save: bool = False,
      save_path: Optional[str] = None,
      normalize: bool = False,
- ) -> Dict[str, np.ndarray]:
+     smooth_window: int = 50,  # Window size for smoothing (in time bins)
+     smooth_method: str = "gaussian",  # Smoothing method: 'gaussian', 'boxcar', or 'exponential'
+ ) -> xr.DataArray:
      """
-     Calculate the population spike rate for each population in the given spike data, with an option to normalize.
+     Calculate the population spike rate for each population in the given spike data.
 
      Parameters
      ----------
@@ -239,23 +242,41 @@
          Directory path where the file should be saved if `save` is True (default: None)
      normalize : bool, optional
          Whether to normalize the spike rates for each population to a range of [0, 1] (default: False)
+     smooth_window : int, optional
+         Window size for smoothing in number of time bins (default: 50)
+     smooth_method : str, optional
+         Smoothing method to use: 'gaussian', 'boxcar', or 'exponential' (default: 'gaussian')
 
      Returns
      -------
-     Dict[str, np.ndarray]
-         A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
-         If `normalize` is True, each population's spike rate is scaled to [0, 1].
+     xr.DataArray
+         An xarray DataArray containing the spike rates with dimensions of time, population, and type.
+         The 'type' dimension includes 'raw' and 'smoothed' values.
+         The DataArray includes sampling frequency (fs) as an attribute.
+         If normalize is True, each population's spike rate is scaled to [0, 1].
 
      Raises
      ------
      ValueError
          If `save` is True but `save_path` is not provided.
+         If an invalid smooth_method is specified.
 
      Notes
      -----
-     - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
-     - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
+     - If `config` is None, the function assumes all cells in each population have fired at least once;
+       otherwise, the node count may be inaccurate.
+     - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization.
+     - Smoothing is applied using scipy.ndimage's filters based on the specified method.
      """
+     import numpy as np
+     from scipy import ndimage
+
+     # Validate smoothing method
+     if smooth_method not in ["gaussian", "boxcar", "exponential"]:
+         raise ValueError(
+             f"Invalid smooth_method: {smooth_method}. Choose from 'gaussian', 'boxcar', or 'exponential'."
+         )
+
      pop_spikes = {}
      node_number = {}
 
@@ -271,7 +292,13 @@
              "Grabbing first network; specify a network name to ensure correct node population is selected."
          )
 
-     for pop_name in spike_data["pop_name"].unique():
+     # Get t_stop if not provided
+     if t_stop is None:
+         t_stop = spike_data["timestamps"].max()
+
+     # Get population names and prepare data
+     populations = spike_data["pop_name"].unique()
+     for pop_name in populations:
          ps = spike_data[spike_data["pop_name"] == pop_name]
 
          if config:
@@ -282,12 +309,10 @@
              nodes = list(nodes.values())[0] if nodes else {}
              nodes = nodes[nodes["pop_name"] == pop_name]
              node_number[pop_name] = nodes.index.nunique()
+
          else:
              node_number[pop_name] = ps["node_ids"].nunique()
 
-         if t_stop is None:
-             t_stop = spike_data["timestamps"].max()
-
          filtered_spikes = spike_data[
              (spike_data["pop_name"] == pop_name)
              & (spike_data["timestamps"] > t_start)
@@ -295,29 +320,153 @@
          ]
          pop_spikes[pop_name] = filtered_spikes
 
-     time = np.array([t_start, t_stop, 1000 / fs])
-     pop_rspk = {p: _pop_spike_rate(spk["timestamps"], time) for p, spk in pop_spikes.items()}
-     spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}
+     # Calculate time points
+     time = np.arange(t_start, t_stop, 1000 / fs)  # Convert sampling frequency to time steps
+
+     # Calculate spike rates for each population
+     spike_rates = []
+     for p in populations:
+         raw_rate = _pop_spike_rate(pop_spikes[p]["timestamps"], (t_start, t_stop, 1000 / fs))
+         rate = fs / node_number[p] * raw_rate
+         spike_rates.append(rate)
+
+     spike_rates_array = np.array(spike_rates).T  # Transpose to have time as first dimension
+
+     # Calculate smoothed version for each population
+     smoothed_rates = []
+
+     for i in range(spike_rates_array.shape[1]):
+         pop_rate = spike_rates_array[:, i]
+
+         if smooth_method == "gaussian":
+             # Gaussian smoothing (sigma is approximately window/6 for a Gaussian filter)
+             sigma = smooth_window / 6
+             smoothed_pop_rate = ndimage.gaussian_filter1d(pop_rate, sigma=sigma)
+         elif smooth_method == "boxcar":
+             # Boxcar/uniform smoothing
+             kernel = np.ones(smooth_window) / smooth_window
+             smoothed_pop_rate = ndimage.convolve1d(pop_rate, kernel, mode="nearest")
+         elif smooth_method == "exponential":
+             # Exponential smoothing
+             alpha = 2 / (smooth_window + 1)  # Equivalent to window size in exponential smoothing
+             smoothed_pop_rate = np.zeros_like(pop_rate)
+             smoothed_pop_rate[0] = pop_rate[0]
+             for t in range(1, len(pop_rate)):
+                 smoothed_pop_rate[t] = alpha * pop_rate[t] + (1 - alpha) * smoothed_pop_rate[t - 1]
+
+         smoothed_rates.append(smoothed_pop_rate)
+
+     smoothed_rates_array = np.array(smoothed_rates).T  # Transpose to have time as first dimension
+
+     # Stack raw and smoothed data
+     combined_data = np.stack([spike_rates_array, smoothed_rates_array], axis=2)
+
+     # Create DataArray with the additional 'type' dimension
+     spike_rate_array = xr.DataArray(
+         combined_data,
+         coords={"time": time, "population": populations, "type": ["raw", "smoothed"]},
+         dims=["time", "population", "type"],
+         attrs={
+             "fs": fs,
+             "normalized": False,
+             "smooth_method": smooth_method,
+             "smooth_window": smooth_window,
+         },
+     )
 
-     # Normalize each spike rate series if normalize=True
+     # Normalize if requested
      if normalize:
-         spike_rate = {p: (sr - sr.min()) / (sr.max() - sr.min()) for p, sr in spike_rate.items()}
-
+         # Apply normalization for each population and each type (raw/smoothed)
+         for pop_idx in range(len(populations)):
+             for type_idx, type_name in enumerate(["raw", "smoothed"]):
+                 pop_data = spike_rate_array.sel(population=populations[pop_idx], type=type_name)
+                 min_val = pop_data.min(dim="time")
+                 max_val = pop_data.max(dim="time")
+
+                 # Handle case where min == max (constant signal)
+                 if max_val != min_val:
+                     spike_rate_array.loc[:, populations[pop_idx], type_name] = (
+                         pop_data - min_val
+                     ) / (max_val - min_val)
+
+         spike_rate_array.attrs["normalized"] = True
+
+     # Save if requested
      if save:
          if save_path is None:
              raise ValueError("save_path must be provided if save is True.")
 
          os.makedirs(save_path, exist_ok=True)
-
          save_file = os.path.join(save_path, "spike_rate.h5")
-         with h5py.File(save_file, "w") as f:
-             f.create_dataset("time", data=time)
-             grp = f.create_group("populations")
-             for p, rspk in spike_rate.items():
-                 pop_grp = grp.create_group(p)
-                 pop_grp.create_dataset("data", data=rspk)
+         spike_rate_array.to_netcdf(save_file)
 
-     return spike_rate
+     return spike_rate_array
+
+
+ def average_spike_rate_over_windows(
+     spike_rate: xr.DataArray, windows: List[Tuple[float, float]]
+ ) -> xr.DataArray:
+     """
+     Calculate the average spike rate over multiple time windows.
+
+     Parameters
+     ----------
+     spike_rate : xr.DataArray
+         The spike rate data array with dimensions (time, population, type)
+         where 'type' can be 'raw' or 'smoothed'
+     windows : List[Tuple[float, float]]
+         List of (start, end) times in milliseconds defining the windows to average over
+
+     Returns
+     -------
+     xr.DataArray
+         Averaged spike rate with time normalized to start at 0,
+         preserving all original dimensions (time, population, type)
+     """
+     # Check if the DataArray has a 'type' dimension (compatible with new format)
+     has_type_dim = "type" in spike_rate.dims
+
+     # Initialize list to store data from each window
+     window_data = []
+
+     # Get data for each window
+     for start, end in windows:
+         # Select data points within the window
+         window = spike_rate.sel(time=slice(start, end))
+
+         # Normalize time to start at 0 for this window
+         window = window.assign_coords(time=window.time - start)
+         window_data.append(window)
+
+     # Align and average windows
+     # First window determines the time coordinates
+     aligned_data = xr.concat(window_data, dim="window")
+     averaged_data = aligned_data.mean(dim="window")
+
+     # Create new DataArray with the averaged data
+     if has_type_dim:
+         # Create result with time, population, and type dimensions
+         result = xr.DataArray(
+             averaged_data.values,
+             coords={
+                 "time": averaged_data.time.values,
+                 "population": averaged_data.population,
+                 "type": averaged_data.type,
+             },
+             dims=["time", "population", "type"],
+         )
+     else:
+         # Handle older format without 'type' dimension (for backward compatibility)
+         result = xr.DataArray(
+             averaged_data.values,
+             coords={"time": averaged_data.time.values, "population": averaged_data.population},
+             dims=["time", "population"],
+         )
+
+     # Preserve attributes
+     result.attrs = spike_rate.attrs
+
+     return result
 
 
  def compare_firing_over_times(
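`get_population_spike_rate` now returns a single labeled `xr.DataArray` (dims time × population × type, with `fs`, the smoothing settings, and a `normalized` flag stored as attrs) instead of a dict of arrays, and saving writes NetCDF via `to_netcdf` while keeping the `spike_rate.h5` filename. A hypothetical migration sketch (placeholder path; `load_spikes_to_df` is assumed to be the module's existing loader):

```python
# Sketch of consuming the new xarray return value and the new
# average_spike_rate_over_windows helper added in this diff.
from bmtool.analysis.spikes import (
    average_spike_rate_over_windows,
    get_population_spike_rate,
    load_spikes_to_df,  # assumed existing loader in this module
)

spike_df = load_spikes_to_df("output/spikes.h5")  # placeholder path

rates = get_population_spike_rate(
    spike_df,
    fs=400,                    # sampling frequency of the rate trace (Hz)
    smooth_window=50,          # new in 0.7.1.6
    smooth_method="gaussian",  # 'gaussian', 'boxcar', or 'exponential'
)

# 0.7.1.4: rates["Pyr"] was a bare np.ndarray.
# 0.7.1.6: select by labeled coordinates instead ("Pyr" is a toy name).
pyr_smoothed = rates.sel(population="Pyr", type="smoothed")

# Average the rate over repeated, equal-length stimulus windows (ms);
# the result keeps the (time, population, type) dimensions with time
# re-zeroed to the window start.
trial_avg = average_spike_rate_over_windows(
    rates, windows=[(1000, 1500), (3000, 3500), (5000, 5500)]
)
```

Note that a file saved by the new code, despite its `.h5` name, should round-trip through `xr.open_dataarray` rather than raw `h5py`.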
@@ -401,7 +550,7 @@
 
 
  def find_bursting_cells(
-     df: pd.DataFrame, burst_threshold: float = 10, rename_bursting_cells: bool = False
+     df: pd.DataFrame, isi_threshold: float = 10, burst_count_threshold: int = 1
  ) -> pd.DataFrame:
      """
      Finds bursting cells in a population based on a time difference threshold.
@@ -410,10 +559,10 @@
      ----------
      df : pd.DataFrame
          DataFrame containing spike data with columns for timestamps, node_ids, and pop_name
-     burst_threshold : float, optional
+     isi_threshold : float, optional
          Time difference threshold in milliseconds to identify bursts
-     rename_bursting_cells : bool, optional
-         If True, returns a DataFrame with bursting cells renamed in their pop_name column
+     burst_count_threshold : int, optional
+         Number of bursts required to identify a bursting cell
 
      Returns
      -------
@@ -425,10 +574,11 @@
      diff_df["time_diff"] = df.groupby("node_ids")["timestamps"].diff()
 
      # Create a column indicating whether each time difference is a burst
-     diff_df["is_burst_instance"] = diff_df["time_diff"] < burst_threshold
+     diff_df["is_burst_instance"] = diff_df["time_diff"] < isi_threshold
 
      # Group by node_ids and check if any row has a burst instance
-     burst_summary = diff_df.groupby("node_ids")["is_burst_instance"].any()
+     # check if there are enough bursts
+     burst_summary = diff_df.groupby("node_ids")["is_burst_instance"].sum() >= burst_count_threshold
 
      # Convert to a DataFrame with reset index
      burst_cells = burst_summary.reset_index(name="is_burst")
@@ -442,14 +592,11 @@
      )
 
      # Add "_bursters" suffix only to those cells
-     if rename_bursting_cells:
-         burst_cells.loc[burst_mask, "pop_name"] = (
-             burst_cells.loc[burst_mask, "pop_name"] + "_bursters"
-         )
+     burst_cells.loc[burst_mask, "pop_name"] = burst_cells.loc[burst_mask, "pop_name"] + "_bursters"
 
-     for pop in burst_cells["pop_name"].unique():
+     for pop in sorted(burst_cells["pop_name"].unique()):
          print(
-             f"Number of bursters in {pop}: {burst_cells[burst_cells['pop_name'] == pop]['node_ids'].nunique()}"
+             f"Number of cells in {pop}: {burst_cells[burst_cells['pop_name'] == pop]['node_ids'].nunique()}"
          )
 
      return burst_cells
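`find_bursting_cells` drops the `rename_bursting_cells` flag (the `"_bursters"` suffix is now always applied), renames `burst_threshold` to `isi_threshold`, and adds `burst_count_threshold` so a single short inter-spike interval no longer marks a cell as bursting. A hypothetical sketch of the new call (toy spike table, times in ms):

```python
# Sketch of the reworked API shown in this diff.
import pandas as pd

from bmtool.analysis.spikes import find_bursting_cells

spike_df = pd.DataFrame({
    "timestamps": [0.0, 5.0, 8.0, 100.0, 500.0, 504.0, 507.0, 900.0],
    "node_ids":   [1,   1,   1,   1,     2,     2,     2,     3],
    "pop_name":   ["Pyr"] * 4 + ["PV"] * 3 + ["Pyr"],
})

# 0.7.1.4: find_bursting_cells(df, burst_threshold=10, rename_bursting_cells=True)
# 0.7.1.6: require at least two inter-spike intervals below isi_threshold
# before a cell counts as bursting; renaming is no longer optional.
burst_df = find_bursting_cells(spike_df, isi_threshold=10, burst_count_threshold=2)
# Nodes 1 and 2 each have two ISIs under 10 ms, so their pop_name becomes
# "Pyr_bursters" / "PV_bursters"; node 3 (a single spike) is unchanged.
```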
bmtool/bmplot/spikes.py CHANGED
@@ -1,3 +1,5 @@
+ """Plotting functions for neural spikes and firing rates."""
+
  from typing import Dict, List, Optional, Union
 
  import matplotlib.pyplot as plt
bmtool-0.7.1.6.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bmtool
- Version: 0.7.1.4
+ Version: 0.7.1.6
  Summary: BMTool
  Home-page: https://github.com/cyneuro/bmtool
  Download-URL:
bmtool-0.7.1.6.dist-info/RECORD CHANGED
@@ -8,16 +8,16 @@ bmtool/plot_commands.py,sha256=Dxm_RaT4CtHnfsltTtUopJ4KVbfhxtktEB_b7bFEXII,12716
  bmtool/singlecell.py,sha256=I2yolbAnNC8qpnRkNdnDCLidNW7CktmBuRrcowMZJ3A,45041
  bmtool/synapses.py,sha256=wlRY7IixefPzafqG6k2sPIK4s6PLG9Kct-oCaVR29wA,64269
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bmtool/analysis/entrainment.py,sha256=uG2TWbeYJEg_VQB6pKEWlrVBzQ6M4h6FSAZR4GMKp-E,28178
+ bmtool/analysis/entrainment.py,sha256=PM4Do8Cl248Y2kIXLRFLPmUB_mH38Yhl8CUDDcunGq0,28241
  bmtool/analysis/lfp.py,sha256=S2JvxkjcK3-EH93wCrhqNSFY6cX7fOq74pz64ibHKrc,26556
  bmtool/analysis/netcon_reports.py,sha256=VnPZNKPaQA7oh1q9cIatsqQudm4cOtzNtbGPXoiDCD0,2909
- bmtool/analysis/spikes.py,sha256=kIcTRRMr6ofxCwnE9lrlnvlez4z07IKX0mUr4jy9jP8,17120
+ bmtool/analysis/spikes.py,sha256=iJfoVKl2k1X9s6C3PYz-18zlfahuRM_35wN5H9xDCIg,22715
  bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bmtool/bmplot/connections.py,sha256=P1JBG4xCbLVq4sfQuUE6c3dO949qajrjdQcrazdmDS4,53861
  bmtool/bmplot/entrainment.py,sha256=VSlZvcSeXLr5OxGvmWcGU4s7JS7vOL38lq1XC69O_AE,6926
  bmtool/bmplot/lfp.py,sha256=SNpbWGOUnYEgnkeBw5S--aPN5mIGD22Gw2Pwus0_lvY,2034
  bmtool/bmplot/netcon_reports.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bmtool/bmplot/spikes.py,sha256=Lg8V3ynYCqk-QJvq-BOInjZMHYHrxHgXjtDOX67df-A,11148
+ bmtool/bmplot/spikes.py,sha256=RJOOtmgWhTvyVi1CghoKTtxvt7MF9cJCrJVm5hV5wA4,11210
  bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bmtool/debug/commands.py,sha256=VV00f6q5gzZI503vUPeG40ABLLen0bw_k4-EX-H5WZE,580
  bmtool/debug/debug.py,sha256=9yUFvA4_Bl-x9s29quIEG3pY-S8hNJF3RKBfRBHCl28,208
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=Nn-R-4e9g8ZhSPZvTkr38xeKRPfEMANB9Lugppj82UI,68564
  bmtool/util/util.py,sha256=owce5BEusZO_8T5x05N2_B583G26vWAy7QX29V0Pj0Y,62818
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  bmtool/util/neuron/celltuner.py,sha256=lokRLUM1rsdSYBYrNbLBBo39j14mm8TBNVNRnSlhHCk,94868
- bmtool-0.7.1.4.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
- bmtool-0.7.1.4.dist-info/METADATA,sha256=6lhzffBCmubbgsbW4WJyoZ4PdOo3Qoe1SjcOMaSgZwY,3577
- bmtool-0.7.1.4.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
- bmtool-0.7.1.4.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
- bmtool-0.7.1.4.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
- bmtool-0.7.1.4.dist-info/RECORD,,
+ bmtool-0.7.1.6.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
+ bmtool-0.7.1.6.dist-info/METADATA,sha256=_jtey-F9b0kjpQ2CELf9SxupFMMv1c-RtgmthDweFJw,3577
+ bmtool-0.7.1.6.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
+ bmtool-0.7.1.6.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
+ bmtool-0.7.1.6.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
+ bmtool-0.7.1.6.dist-info/RECORD,,
bmtool-0.7.1.6.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.4.0)
+ Generator: setuptools (80.8.0)
  Root-Is-Purelib: true
  Tag: py3-none-any