bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,97 +1,94 @@
 import h5py
 import numpy as np
 import xarray as xr
-import pandas as pd
+
 from ..util.util import load_nodes_from_config
 
+
 def load_synapse_report(h5_file_path, config_path, network):
     """
     Load and process a synapse report from a bmtk simulation into an xarray.
-
+
     Parameters:
     -----------
     h5_file_path : str
         Path to the h5 file containing the synapse report
     config_path : str
         Path to the simulation configuration file
-
+
     Returns:
     --------
     xarray.Dataset
         An xarray containing the synapse report data with proper population labeling
     """
     # Load the h5 file
-    with h5py.File(h5_file_path, 'r') as file:
+    with h5py.File(h5_file_path, "r") as file:
         # Get the report data
-        report = file['report'][network]
-        mapping = report['mapping']
-
+        report = file["report"][network]
+        mapping = report["mapping"]
+
         # Get the data - shape is (n_timesteps, n_synapses)
-        data = report['data'][:]
-
+        data = report["data"][:]
+
         # Get time information
-        time_info = mapping['time'][:] # [start_time, end_time, dt]
+        time_info = mapping["time"][:]  # [start_time, end_time, dt]
         start_time = time_info[0]
         end_time = time_info[1]
         dt = time_info[2]
-
+
         # Create time array
         n_steps = data.shape[0]
-        time = np.linspace(start_time, start_time + (n_steps-1)*dt, n_steps)
-
+        time = np.linspace(start_time, start_time + (n_steps - 1) * dt, n_steps)
+
         # Get mapping information
-        src_ids = mapping['src_ids'][:]
-        trg_ids = mapping['trg_ids'][:]
-        sec_id = mapping['element_ids'][:]
-        sec_x = mapping['element_pos'][:]
-
+        src_ids = mapping["src_ids"][:]
+        trg_ids = mapping["trg_ids"][:]
+        sec_id = mapping["element_ids"][:]
+        sec_x = mapping["element_pos"][:]
+
         # Load node information
         nodes = load_nodes_from_config(config_path)
         nodes = nodes[network]
-
+
         # Create a mapping from node IDs to population names
-        node_to_pop = dict(zip(nodes.index, nodes['pop_name']))
-
+        node_to_pop = dict(zip(nodes.index, nodes["pop_name"]))
+
         # Get the number of synapses
         n_synapses = data.shape[1]
-
+
         # Create arrays to hold the source and target populations for each synapse
         source_pops = []
         target_pops = []
         connection_labels = []
-
+
         # Process each synapse
         for i in range(n_synapses):
             src_id = src_ids[i]
             trg_id = trg_ids[i]
-
+
             # Get population names (with fallback for unknown IDs)
-            src_pop = node_to_pop.get(src_id, f'unknown_{src_id}')
-            trg_pop = node_to_pop.get(trg_id, f'unknown_{trg_id}')
-
+            src_pop = node_to_pop.get(src_id, f"unknown_{src_id}")
+            trg_pop = node_to_pop.get(trg_id, f"unknown_{trg_id}")
+
             source_pops.append(src_pop)
             target_pops.append(trg_pop)
             connection_labels.append(f"{src_pop}->{trg_pop}")
-
+
         # Create xarray dataset
         ds = xr.Dataset(
-            data_vars={
-                'synapse_value': (['time', 'synapse'], data)
-            },
+            data_vars={"synapse_value": (["time", "synapse"], data)},
             coords={
-                'time': time,
-                'synapse': np.arange(n_synapses),
-                'source_pop': ('synapse', source_pops),
-                'target_pop': ('synapse', target_pops),
-                'source_id': ('synapse', src_ids),
-                'target_id': ('synapse', trg_ids),
-                'sec_id': ('synapse', sec_id),
-                'sec_x': ('synapse', sec_x),
-                'connection_label': ('synapse', connection_labels)
+                "time": time,
+                "synapse": np.arange(n_synapses),
+                "source_pop": ("synapse", source_pops),
+                "target_pop": ("synapse", target_pops),
+                "source_id": ("synapse", src_ids),
+                "target_id": ("synapse", trg_ids),
+                "sec_id": ("synapse", sec_id),
+                "sec_x": ("synapse", sec_x),
+                "connection_label": ("synapse", connection_labels),
             },
-            attrs={
-                'description': 'Synapse report data from bmtk simulation'
-            }
+            attrs={"description": "Synapse report data from bmtk simulation"},
         )
-
+
         return ds
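
The changes above are mechanical reformatting (double quotes, spacing, trailing-whitespace cleanup) plus removal of the unused import pandas as pd; the behavior of load_synapse_report is unchanged. A minimal usage sketch follows; the diff does not name the module this function lives in, so the import is omitted, and the file paths, network name, and connection label are hypothetical placeholders.

# Minimal sketch, assuming load_synapse_report has been imported from its
# bmtool.analysis module; all paths and names here are hypothetical.
ds = load_synapse_report(
    h5_file_path="output/synapse_report.h5",  # hypothetical report file
    config_path="simulation_config.json",  # hypothetical bmtk config
    network="cortex",  # hypothetical network name
)

# synapse_value is indexed by (time, synapse); population labels ride along
# as per-synapse coordinates, so connection classes can be selected directly.
pn_itn = ds.where(ds["connection_label"] == "PN->ITN", drop=True)  # hypothetical label
mean_trace = pn_itn["synapse_value"].mean("synapse")  # mean time course of that class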
bmtool/analysis/spikes.py CHANGED
@@ -2,16 +2,24 @@
 Module for processing BMTK spikes output.
 """
 
+import os
+from typing import Dict, List, Optional, Tuple, Union
+
 import h5py
-import pandas as pd
-from bmtool.util.util import load_nodes_from_config
-from typing import Dict, Optional,Tuple, Union, List
 import numpy as np
+import pandas as pd
 from scipy.stats import mannwhitneyu
-import os
+
+from bmtool.util.util import load_nodes_from_config
 
 
-def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None, groupby: Union[str, List[str]] = 'pop_name') -> pd.DataFrame:
+def load_spikes_to_df(
+    spike_file: str,
+    network_name: str,
+    sort: bool = True,
+    config: str = None,
+    groupby: Union[str, List[str]] = "pop_name",
+) -> pd.DataFrame:
     """
     Load spike data from an HDF5 file into a pandas DataFrame.
 
@@ -33,21 +41,23 @@ def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, con
     pd.DataFrame
         A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data,
         with additional columns if a config file is provided
-
+
     Examples
     --------
     >>> df = load_spikes_to_df("spikes.h5", "cortex")
     >>> df = load_spikes_to_df("spikes.h5", "cortex", config="config.json", groupby=["pop_name", "model_type"])
     """
     with h5py.File(spike_file) as f:
-        spikes_df = pd.DataFrame({
-            'node_ids': f['spikes'][network_name]['node_ids'],
-            'timestamps': f['spikes'][network_name]['timestamps']
-        })
+        spikes_df = pd.DataFrame(
+            {
+                "node_ids": f["spikes"][network_name]["node_ids"],
+                "timestamps": f["spikes"][network_name]["timestamps"],
+            }
+        )
 
         if sort:
-            spikes_df.sort_values(by='timestamps', inplace=True, ignore_index=True)
-
+            spikes_df.sort_values(by="timestamps", inplace=True, ignore_index=True)
+
         if config:
             nodes = load_nodes_from_config(config)
             nodes = nodes[network_name]
@@ -61,12 +71,19 @@
             if missing_cols:
                 raise KeyError(f"Columns {missing_cols} not found in nodes DataFrame.")
 
-            spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
+            spikes_df = spikes_df.merge(
+                nodes[groupby], left_on="node_ids", right_index=True, how="left"
+            )
 
     return spikes_df
 
 
-def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] = "pop_name", start_time: float = None, stop_time: float = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
+def compute_firing_rate_stats(
+    df: pd.DataFrame,
+    groupby: Union[str, List[str]] = "pop_name",
+    start_time: float = None,
+    stop_time: float = None,
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
     """
     Computes the firing rates of individual nodes and the mean and standard deviation of firing rates per group.
 
@@ -77,7 +94,7 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] =
        stop_time (float, optional): Stop time for the analysis window. Defaults to the maximum timestamp in the data.
 
    Returns:
-        Tuple[pd.DataFrame, pd.DataFrame]:
+        Tuple[pd.DataFrame, pd.DataFrame]:
            - The first DataFrame (`pop_stats`) contains the mean and standard deviation of firing rates per group.
            - The second DataFrame (`individual_stats`) contains the firing rate of each individual node.
    """
@@ -85,7 +102,7 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] =
     # Ensure groupby is a list
     if isinstance(groupby, str):
         groupby = [groupby]
-
+
     # Ensure all columns exist in the dataframe
     for col in groupby:
         if col not in df.columns:
@@ -102,42 +119,47 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] =
         min_time = df["timestamps"].min()
     else:
         min_time = start_time
-
-    if stop_time is None:
+
+    if stop_time is None:
         max_time = df["timestamps"].max()
     else:
         max_time = stop_time
-
+
     duration = max_time - min_time  # Avoid division by zero
 
     if duration <= 0:
         raise ValueError("Invalid time window: Stop time must be greater than start time.")
 
     # Compute firing rate for each node
-    import pandas as pd
 
     # Compute spike counts per node
     spike_counts = df["node_ids"].value_counts().reset_index()
     spike_counts.columns = ["node_ids", "spike_count"]  # Rename columns
 
     # Merge with original dataframe to get corresponding labels (e.g., 'pop_name')
-    spike_counts = spike_counts.merge(df[["node_ids"] + groupby].drop_duplicates(), on="node_ids", how="left")
+    spike_counts = spike_counts.merge(
+        df[["node_ids"] + groupby].drop_duplicates(), on="node_ids", how="left"
+    )
 
     # Compute firing rate
-    spike_counts["firing_rate"] = spike_counts["spike_count"] / duration * 1000 # scale to Hz
+    spike_counts["firing_rate"] = spike_counts["spike_count"] / duration * 1000  # scale to Hz
     indivdual_stats = spike_counts
-
+
     # Compute mean and standard deviation per group
     pop_stats = spike_counts.groupby(groupby)["firing_rate"].agg(["mean", "std"]).reset_index()
 
     # Rename columns
     pop_stats.rename(columns={"mean": "firing_rate_mean", "std": "firing_rate_std"}, inplace=True)
 
-    return pop_stats,indivdual_stats
+    return pop_stats, indivdual_stats
 
 
-def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
-                    time_points: Optional[Union[np.ndarray, list]] = None, frequency: bool = False) -> np.ndarray:
+def _pop_spike_rate(
+    spike_times: Union[np.ndarray, list],
+    time: Optional[Tuple[float, float, float]] = None,
+    time_points: Optional[Union[np.ndarray, list]] = None,
+    frequency: bool = False,
+) -> np.ndarray:
     """
     Calculate the spike count or frequency histogram over specified time intervals.
 
@@ -146,7 +168,7 @@ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[f
     spike_times : Union[np.ndarray, list]
         Array or list of spike times in milliseconds
     time : Optional[Tuple[float, float, float]], optional
-        Tuple specifying (start, stop, step) in milliseconds. Used to create evenly spaced time points
+        Tuple specifying (start, stop, step) in milliseconds. Used to create evenly spaced time points
         if `time_points` is not provided. Default is None.
     time_points : Optional[Union[np.ndarray, list]], optional
         Array or list of specific time points for binning. If provided, `time` is ignored. Default is None.
@@ -171,20 +193,27 @@ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[f
     else:
         time_points = np.asarray(time_points).ravel()
         dt = (time_points[-1] - time_points[0]) / (time_points.size - 1)
-
+
     bins = np.append(time_points, time_points[-1] + dt)
     spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
-
+
     if frequency:
         spike_rate = 1000 / dt * spike_rate
-
+
     return spike_rate
 
 
-def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
-                              config: Optional[str] = None, network_name: Optional[str] = None,
-                              save: bool = False, save_path: Optional[str] = None,
-                              normalize: bool = False) -> Dict[str, np.ndarray]:
+def get_population_spike_rate(
+    spike_data: pd.DataFrame,
+    fs: float = 400.0,
+    t_start: float = 0,
+    t_stop: Optional[float] = None,
+    config: Optional[str] = None,
+    network_name: Optional[str] = None,
+    save: bool = False,
+    save_path: Optional[str] = None,
+    normalize: bool = False,
+) -> Dict[str, np.ndarray]:
     """
     Calculate the population spike rate for each population in the given spike data, with an option to normalize.
 
@@ -231,39 +260,43 @@ def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_sta
     node_number = {}
 
     if config is None:
-        print("Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect.")
+        print(
+            "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect."
+        )
         print("You can provide a config to calculate the correct amount of nodes!")
-
+
     if config:
         if not network_name:
-            print("Grabbing first network; specify a network name to ensure correct node population is selected.")
+            print(
+                "Grabbing first network; specify a network name to ensure correct node population is selected."
+            )
+
+    for pop_name in spike_data["pop_name"].unique():
+        ps = spike_data[spike_data["pop_name"] == pop_name]
 
-    for pop_name in spike_data['pop_name'].unique():
-        ps = spike_data[spike_data['pop_name'] == pop_name]
-
         if config:
             nodes = load_nodes_from_config(config)
             if network_name:
                 nodes = nodes[network_name]
             else:
                 nodes = list(nodes.values())[0] if nodes else {}
-            nodes = nodes[nodes['pop_name'] == pop_name]
+            nodes = nodes[nodes["pop_name"] == pop_name]
             node_number[pop_name] = nodes.index.nunique()
         else:
-            node_number[pop_name] = ps['node_ids'].nunique()
+            node_number[pop_name] = ps["node_ids"].nunique()
 
         if t_stop is None:
-            t_stop = spike_data['timestamps'].max()
+            t_stop = spike_data["timestamps"].max()
 
         filtered_spikes = spike_data[
-            (spike_data['pop_name'] == pop_name) &
-            (spike_data['timestamps'] > t_start) &
-            (spike_data['timestamps'] < t_stop)
+            (spike_data["pop_name"] == pop_name)
+            & (spike_data["timestamps"] > t_start)
+            & (spike_data["timestamps"] < t_stop)
         ]
         pop_spikes[pop_name] = filtered_spikes
 
     time = np.array([t_start, t_stop, 1000 / fs])
-    pop_rspk = {p: _pop_spike_rate(spk['timestamps'], time) for p, spk in pop_spikes.items()}
+    pop_rspk = {p: _pop_spike_rate(spk["timestamps"], time) for p, spk in pop_spikes.items()}
     spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}
 
     # Normalize each spike rate series if normalize=True
@@ -273,25 +306,27 @@ def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_sta
     if save:
         if save_path is None:
             raise ValueError("save_path must be provided if save is True.")
-
+
         os.makedirs(save_path, exist_ok=True)
-
-        save_file = os.path.join(save_path, 'spike_rate.h5')
-        with h5py.File(save_file, 'w') as f:
-            f.create_dataset('time', data=time)
-            grp = f.create_group('populations')
+
+        save_file = os.path.join(save_path, "spike_rate.h5")
+        with h5py.File(save_file, "w") as f:
+            f.create_dataset("time", data=time)
+            grp = f.create_group("populations")
             for p, rspk in spike_rate.items():
                 pop_grp = grp.create_group(p)
-                pop_grp.create_dataset('data', data=rspk)
+                pop_grp.create_dataset("data", data=rspk)
 
     return spike_rate
 
 
-def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window_1: List[float], time_window_2: List[float]) -> None:
+def compare_firing_over_times(
+    spike_df: pd.DataFrame, group_by: str, time_window_1: List[float], time_window_2: List[float]
+) -> None:
     """
     Compares the firing rates of a population during two different time windows and performs
     a statistical test to determine if there is a significant difference.
-
+
     Parameters
     ----------
     spike_df : pd.DataFrame
@@ -302,12 +337,12 @@ def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window
        First time window as [start, stop] in milliseconds
     time_window_2 : List[float]
        Second time window as [start, stop] in milliseconds
-
+
     Returns
     -------
     None
        Results are printed to the console
-
+
     Notes
     -----
     Uses Mann-Whitney U test (non-parametric) to compare firing rates between the two windows
@@ -316,38 +351,44 @@ def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window
     for pop_name in spike_df[group_by].unique():
         print(f"Population: {pop_name}")
         pop_spikes = spike_df[spike_df[group_by] == pop_name]
-
+
         # Filter by time windows
-        pop_spikes_1 = pop_spikes[(pop_spikes['timestamps'] >= time_window_1[0]) & (pop_spikes['timestamps'] <= time_window_1[1])]
-        pop_spikes_2 = pop_spikes[(pop_spikes['timestamps'] >= time_window_2[0]) & (pop_spikes['timestamps'] <= time_window_2[1])]
-
+        pop_spikes_1 = pop_spikes[
+            (pop_spikes["timestamps"] >= time_window_1[0])
+            & (pop_spikes["timestamps"] <= time_window_1[1])
+        ]
+        pop_spikes_2 = pop_spikes[
+            (pop_spikes["timestamps"] >= time_window_2[0])
+            & (pop_spikes["timestamps"] <= time_window_2[1])
+        ]
+
         # Get unique neuron IDs
-        unique_neurons = pop_spikes['node_ids'].unique()
-
+        unique_neurons = pop_spikes["node_ids"].unique()
+
         # Calculate firing rates per neuron for each time window in Hz
         neuron_rates_1 = []
         neuron_rates_2 = []
-
+
         for neuron in unique_neurons:
             # Count spikes for this neuron in each window
-            n_spikes_1 = len(pop_spikes_1[pop_spikes_1['node_ids'] == neuron])
-            n_spikes_2 = len(pop_spikes_2[pop_spikes_2['node_ids'] == neuron])
-
+            n_spikes_1 = len(pop_spikes_1[pop_spikes_1["node_ids"] == neuron])
+            n_spikes_2 = len(pop_spikes_2[pop_spikes_2["node_ids"] == neuron])
+
             # Calculate firing rate in Hz (convert ms to seconds by dividing by 1000)
             rate_1 = n_spikes_1 / ((time_window_1[1] - time_window_1[0]) / 1000)
             rate_2 = n_spikes_2 / ((time_window_2[1] - time_window_2[0]) / 1000)
-
+
             neuron_rates_1.append(rate_1)
             neuron_rates_2.append(rate_2)
-
+
         # Calculate average firing rates
         avg_firing_rate_1 = np.mean(neuron_rates_1) if neuron_rates_1 else 0
         avg_firing_rate_2 = np.mean(neuron_rates_2) if neuron_rates_2 else 0
-
+
         # Perform Mann-Whitney U test
         # Handle the case when one or both arrays are empty
         if len(neuron_rates_1) > 0 and len(neuron_rates_2) > 0:
-            u_stat, p_val = mannwhitneyu(neuron_rates_1, neuron_rates_2, alternative='two-sided')
+            u_stat, p_val = mannwhitneyu(neuron_rates_1, neuron_rates_2, alternative="two-sided")
         else:
             u_stat, p_val = np.nan, np.nan
 
@@ -356,4 +397,4 @@
         print(f" U-statistic: {u_stat:.2f}")
         print(f" p-value: {p_val}")
         print(f" Significant difference (p<0.05): {'Yes' if p_val < 0.05 else 'No'}")
-    return
+    return
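
As with the first file, every hunk above is formatting-only (import regrouping, black/ruff-style signature wrapping, double quotes, whitespace cleanup) apart from deleting the stray import pandas as pd inside compute_firing_rate_stats; no public signature changed. A minimal end-to-end sketch of the module as it stands after this diff follows; the file names, network name, and time windows are hypothetical placeholders, not values taken from this diff.

# Minimal usage sketch; paths, network name, and windows are hypothetical.
from bmtool.analysis.spikes import (
    compare_firing_over_times,
    compute_firing_rate_stats,
    get_population_spike_rate,
    load_spikes_to_df,
)

# Load spikes and merge in pop_name labels from the simulation config
spikes = load_spikes_to_df("output/spikes.h5", "cortex", config="simulation_config.json")

# Per-population mean/std firing rates plus per-node rates (Hz)
pop_stats, individual_stats = compute_firing_rate_stats(spikes, groupby="pop_name")

# Population rate time series sampled at 400 Hz, scaled by population size
rates = get_population_spike_rate(
    spikes, fs=400.0, config="simulation_config.json", network_name="cortex"
)

# Mann-Whitney U comparison of per-neuron rates between two windows (ms)
compare_firing_over_times(spikes, "pop_name", [0.0, 500.0], [500.0, 1000.0])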