bmtool 0.7.1.5__py3-none-any.whl → 0.7.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/analysis/spikes.py CHANGED
@@ -3,11 +3,12 @@ Module for processing BMTK spikes output.
3
3
  """
4
4
 
5
5
  import os
6
- from typing import Dict, List, Optional, Tuple, Union
6
+ from typing import List, Optional, Tuple, Union
7
7
 
8
8
  import h5py
9
9
  import numpy as np
10
10
  import pandas as pd
11
+ import xarray as xr
11
12
  from scipy.stats import mannwhitneyu
12
13
 
13
14
  from bmtool.util.util import load_nodes_from_config
@@ -213,9 +214,11 @@ def get_population_spike_rate(
213
214
  save: bool = False,
214
215
  save_path: Optional[str] = None,
215
216
  normalize: bool = False,
216
- ) -> Dict[str, np.ndarray]:
217
+ smooth_window: int = 50, # Window size for smoothing (in time bins)
218
+ smooth_method: str = "gaussian", # Smoothing method: 'gaussian', 'boxcar', or 'exponential'
219
+ ) -> xr.DataArray:
217
220
  """
218
- Calculate the population spike rate for each population in the given spike data, with an option to normalize.
221
+ Calculate the population spike rate for each population in the given spike data.
219
222
 
220
223
  Parameters
221
224
  ----------
@@ -239,23 +242,41 @@ def get_population_spike_rate(
239
242
  Directory path where the file should be saved if `save` is True (default: None)
240
243
  normalize : bool, optional
241
244
  Whether to normalize the spike rates for each population to a range of [0, 1] (default: False)
245
+ smooth_window : int, optional
246
+ Window size for smoothing in number of time bins (default: 50)
247
+ smooth_method : str, optional
248
+ Smoothing method to use: 'gaussian', 'boxcar', or 'exponential' (default: 'gaussian')
242
249
 
243
250
  Returns
244
251
  -------
245
- Dict[str, np.ndarray]
246
- A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
247
- If `normalize` is True, each population's spike rate is scaled to [0, 1].
252
+ xr.DataArray
253
+ An xarray DataArray containing the spike rates with dimensions of time, population, and type.
254
+ The 'type' dimension includes 'raw' and 'smoothed' values.
255
+ The DataArray includes sampling frequency (fs) as an attribute.
256
+ If normalize is True, each population's spike rate is scaled to [0, 1].
248
257
 
249
258
  Raises
250
259
  ------
251
260
  ValueError
252
261
  If `save` is True but `save_path` is not provided.
262
+ If an invalid smooth_method is specified.
253
263
 
254
264
  Notes
255
265
  -----
256
- - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
257
- - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
266
+ - If `config` is None, the function assumes all cells in each population have fired at least once;
267
+ otherwise, the node count may be inaccurate.
268
+ - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization.
269
+ - Smoothing is applied using scipy.ndimage's filters based on the specified method.
258
270
  """
271
+ import numpy as np
272
+ from scipy import ndimage
273
+
274
+ # Validate smoothing method
275
+ if smooth_method not in ["gaussian", "boxcar", "exponential"]:
276
+ raise ValueError(
277
+ f"Invalid smooth_method: {smooth_method}. Choose from 'gaussian', 'boxcar', or 'exponential'."
278
+ )
279
+
259
280
  pop_spikes = {}
260
281
  node_number = {}
261
282
 
@@ -271,7 +292,13 @@ def get_population_spike_rate(
271
292
  "Grabbing first network; specify a network name to ensure correct node population is selected."
272
293
  )
273
294
 
274
- for pop_name in spike_data["pop_name"].unique():
295
+ # Get t_stop if not provided
296
+ if t_stop is None:
297
+ t_stop = spike_data["timestamps"].max()
298
+
299
+ # Get population names and prepare data
300
+ populations = spike_data["pop_name"].unique()
301
+ for pop_name in populations:
275
302
  ps = spike_data[spike_data["pop_name"] == pop_name]
276
303
 
277
304
  if config:
@@ -282,12 +309,10 @@ def get_population_spike_rate(
282
309
  nodes = list(nodes.values())[0] if nodes else {}
283
310
  nodes = nodes[nodes["pop_name"] == pop_name]
284
311
  node_number[pop_name] = nodes.index.nunique()
312
+
285
313
  else:
286
314
  node_number[pop_name] = ps["node_ids"].nunique()
287
315
 
288
- if t_stop is None:
289
- t_stop = spike_data["timestamps"].max()
290
-
291
316
  filtered_spikes = spike_data[
292
317
  (spike_data["pop_name"] == pop_name)
293
318
  & (spike_data["timestamps"] > t_start)
@@ -295,29 +320,153 @@ def get_population_spike_rate(
295
320
  ]
296
321
  pop_spikes[pop_name] = filtered_spikes
297
322
 
298
- time = np.array([t_start, t_stop, 1000 / fs])
299
- pop_rspk = {p: _pop_spike_rate(spk["timestamps"], time) for p, spk in pop_spikes.items()}
300
- spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}
323
+ # Calculate time points
324
+ time = np.arange(t_start, t_stop, 1000 / fs) # Convert sampling frequency to time steps
325
+
326
+ # Calculate spike rates for each population
327
+ spike_rates = []
328
+ for p in populations:
329
+ raw_rate = _pop_spike_rate(pop_spikes[p]["timestamps"], (t_start, t_stop, 1000 / fs))
330
+ rate = fs / node_number[p] * raw_rate
331
+ spike_rates.append(rate)
332
+
333
+ spike_rates_array = np.array(spike_rates).T # Transpose to have time as first dimension
334
+
335
+ # Calculate smoothed version for each population
336
+ smoothed_rates = []
337
+
338
+ for i in range(spike_rates_array.shape[1]):
339
+ pop_rate = spike_rates_array[:, i]
340
+
341
+ if smooth_method == "gaussian":
342
+ # Gaussian smoothing (sigma is approximately window/6 for a Gaussian filter)
343
+ sigma = smooth_window / 6
344
+ smoothed_pop_rate = ndimage.gaussian_filter1d(pop_rate, sigma=sigma)
345
+ elif smooth_method == "boxcar":
346
+ # Boxcar/uniform smoothing
347
+ kernel = np.ones(smooth_window) / smooth_window
348
+ smoothed_pop_rate = ndimage.convolve1d(pop_rate, kernel, mode="nearest")
349
+ elif smooth_method == "exponential":
350
+ # Exponential smoothing
351
+ alpha = 2 / (smooth_window + 1) # Equivalent to window size in exponential smoothing
352
+ smoothed_pop_rate = np.zeros_like(pop_rate)
353
+ smoothed_pop_rate[0] = pop_rate[0]
354
+ for t in range(1, len(pop_rate)):
355
+ smoothed_pop_rate[t] = alpha * pop_rate[t] + (1 - alpha) * smoothed_pop_rate[t - 1]
356
+
357
+ smoothed_rates.append(smoothed_pop_rate)
358
+
359
+ smoothed_rates_array = np.array(smoothed_rates).T # Transpose to have time as first dimension
360
+
361
+ # Stack raw and smoothed data
362
+ combined_data = np.stack([spike_rates_array, smoothed_rates_array], axis=2)
363
+
364
+ # Create DataArray with the additional 'type' dimension
365
+ spike_rate_array = xr.DataArray(
366
+ combined_data,
367
+ coords={"time": time, "population": populations, "type": ["raw", "smoothed"]},
368
+ dims=["time", "population", "type"],
369
+ attrs={
370
+ "fs": fs,
371
+ "normalized": False,
372
+ "smooth_method": smooth_method,
373
+ "smooth_window": smooth_window,
374
+ },
375
+ )
301
376
 
302
- # Normalize each spike rate series if normalize=True
377
+ # Normalize if requested
303
378
  if normalize:
304
- spike_rate = {p: (sr - sr.min()) / (sr.max() - sr.min()) for p, sr in spike_rate.items()}
305
-
379
+ # Apply normalization for each population and each type (raw/smoothed)
380
+ for pop_idx in range(len(populations)):
381
+ for type_idx, type_name in enumerate(["raw", "smoothed"]):
382
+ pop_data = spike_rate_array.sel(population=populations[pop_idx], type=type_name)
383
+ min_val = pop_data.min(dim="time")
384
+ max_val = pop_data.max(dim="time")
385
+
386
+ # Handle case where min == max (constant signal)
387
+ if max_val != min_val:
388
+ spike_rate_array.loc[:, populations[pop_idx], type_name] = (
389
+ pop_data - min_val
390
+ ) / (max_val - min_val)
391
+
392
+ spike_rate_array.attrs["normalized"] = True
393
+
394
+ # Save if requested
306
395
  if save:
307
396
  if save_path is None:
308
397
  raise ValueError("save_path must be provided if save is True.")
309
398
 
310
399
  os.makedirs(save_path, exist_ok=True)
311
-
312
400
  save_file = os.path.join(save_path, "spike_rate.h5")
313
- with h5py.File(save_file, "w") as f:
314
- f.create_dataset("time", data=time)
315
- grp = f.create_group("populations")
316
- for p, rspk in spike_rate.items():
317
- pop_grp = grp.create_group(p)
318
- pop_grp.create_dataset("data", data=rspk)
401
+ spike_rate_array.to_netcdf(save_file)
319
402
 
320
- return spike_rate
403
+ return spike_rate_array
404
+
405
+
406
def average_spike_rate_over_windows(
    spike_rate: xr.DataArray, windows: List[Tuple[float, float]]
) -> xr.DataArray:
    """
    Calculate the average spike rate over multiple time windows.

    Each window is cut out of ``spike_rate`` with label-based selection
    (``spike_rate.sel(time=slice(start, end))``, both bounds inclusive), its
    time axis is re-expressed relative to the window's ``start``, and the
    windows are averaged sample-by-sample.

    Windows are aligned by sample position rather than by their relative time
    values. Relative times computed as ``time - start`` are floats that can
    differ between windows by rounding error, or by a sub-sample offset when
    ``start`` does not fall exactly on a sample. Aligning on those labels would
    produce a union of nearly identical time points instead of an average.

    Parameters
    ----------
    spike_rate : xr.DataArray
        Spike rate data with a 'time' dimension, e.g. the (time, population,
        type) array returned by ``get_population_spike_rate``, where 'type'
        is 'raw' or 'smoothed'. Arrays without a 'type' dimension are also
        accepted.
    windows : List[Tuple[float, float]]
        List of (start, end) times in milliseconds defining the windows to
        average over.

    Returns
    -------
    xr.DataArray
        Averaged spike rate with the same dimensions, dimension order, and
        non-time coordinates as ``spike_rate``. The time coordinate is
        relative to the window start and is taken from the longest window
        (the first one when all windows have the same length). Windows
        shorter than the longest one only contribute to the samples they
        cover. The attributes of ``spike_rate`` are copied onto the result.

    Raises
    ------
    ValueError
        If ``windows`` is empty or a window selects no samples.
    """
    if not windows:
        raise ValueError("windows must contain at least one (start, end) pair.")

    # Cut out every window, keeping its start so relative time can be computed.
    segments = []
    for start, end in windows:
        segment = spike_rate.sel(time=slice(start, end))
        if segment.sizes["time"] == 0:
            raise ValueError(f"Window ({start}, {end}) contains no samples of spike_rate.")
        segments.append((start, segment))

    # The longest window defines the output time axis; max() returns the first
    # maximal element, so with equal-length windows the first window is used.
    longest_start, longest = max(segments, key=lambda item: item[1].sizes["time"])
    relative_time = longest["time"].values - longest_start

    # Replace each window's time labels with integer sample positions so the
    # windows line up exactly. The outer join pads shorter windows with NaN.
    positional = [
        segment.assign_coords(time=np.arange(segment.sizes["time"]))
        for _, segment in segments
    ]
    stacked = xr.concat(positional, dim="window", join="outer")

    # The NaN-skipping mean ignores padding; the remaining dims keep their order.
    averaged = stacked.mean(dim="window")
    averaged = averaged.assign_coords(time=relative_time)

    # Reductions drop attrs by default, so restore them from the input.
    averaged.attrs = dict(spike_rate.attrs)
    return averaged
321
470
 
322
471
 
323
472
  def compare_firing_over_times(
bmtool/bmplot/spikes.py CHANGED
@@ -1,3 +1,5 @@
1
+ """Plotting functions for neural spikes and firing rates."""
2
+
1
3
  from typing import Dict, List, Optional, Union
2
4
 
3
5
  import matplotlib.pyplot as plt
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.7.1.5
3
+ Version: 0.7.1.6
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -11,13 +11,13 @@ bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  bmtool/analysis/entrainment.py,sha256=PM4Do8Cl248Y2kIXLRFLPmUB_mH38Yhl8CUDDcunGq0,28241
12
12
  bmtool/analysis/lfp.py,sha256=S2JvxkjcK3-EH93wCrhqNSFY6cX7fOq74pz64ibHKrc,26556
13
13
  bmtool/analysis/netcon_reports.py,sha256=VnPZNKPaQA7oh1q9cIatsqQudm4cOtzNtbGPXoiDCD0,2909
14
- bmtool/analysis/spikes.py,sha256=IHxV7_X8ojh4NDVBjzHCzfHF8muPPef2UtH3yqYre78,17091
14
+ bmtool/analysis/spikes.py,sha256=iJfoVKl2k1X9s6C3PYz-18zlfahuRM_35wN5H9xDCIg,22715
15
15
  bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  bmtool/bmplot/connections.py,sha256=P1JBG4xCbLVq4sfQuUE6c3dO949qajrjdQcrazdmDS4,53861
17
17
  bmtool/bmplot/entrainment.py,sha256=VSlZvcSeXLr5OxGvmWcGU4s7JS7vOL38lq1XC69O_AE,6926
18
18
  bmtool/bmplot/lfp.py,sha256=SNpbWGOUnYEgnkeBw5S--aPN5mIGD22Gw2Pwus0_lvY,2034
19
19
  bmtool/bmplot/netcon_reports.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- bmtool/bmplot/spikes.py,sha256=Lg8V3ynYCqk-QJvq-BOInjZMHYHrxHgXjtDOX67df-A,11148
20
+ bmtool/bmplot/spikes.py,sha256=RJOOtmgWhTvyVi1CghoKTtxvt7MF9cJCrJVm5hV5wA4,11210
21
21
  bmtool/debug/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
22
  bmtool/debug/commands.py,sha256=VV00f6q5gzZI503vUPeG40ABLLen0bw_k4-EX-H5WZE,580
23
23
  bmtool/debug/debug.py,sha256=9yUFvA4_Bl-x9s29quIEG3pY-S8hNJF3RKBfRBHCl28,208
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=Nn-R-4e9g8ZhSPZvTkr38xeKRPfEMANB9Lugppj82UI,68564
26
26
  bmtool/util/util.py,sha256=owce5BEusZO_8T5x05N2_B583G26vWAy7QX29V0Pj0Y,62818
27
27
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  bmtool/util/neuron/celltuner.py,sha256=lokRLUM1rsdSYBYrNbLBBo39j14mm8TBNVNRnSlhHCk,94868
29
- bmtool-0.7.1.5.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
- bmtool-0.7.1.5.dist-info/METADATA,sha256=pvpABD7P2ytzO08EYUC8HrbUTg_fk8pc67mBaEvi7-M,3577
31
- bmtool-0.7.1.5.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
- bmtool-0.7.1.5.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
- bmtool-0.7.1.5.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
- bmtool-0.7.1.5.dist-info/RECORD,,
29
+ bmtool-0.7.1.6.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
+ bmtool-0.7.1.6.dist-info/METADATA,sha256=_jtey-F9b0kjpQ2CELf9SxupFMMv1c-RtgmthDweFJw,3577
31
+ bmtool-0.7.1.6.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
32
+ bmtool-0.7.1.6.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
+ bmtool-0.7.1.6.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
+ bmtool-0.7.1.6.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.7.1)
2
+ Generator: setuptools (80.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5