bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1.1__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- bmtool/SLURM.py +162 -109
- bmtool/__init__.py +1 -1
- bmtool/__main__.py +8 -7
- bmtool/analysis/entrainment.py +290 -147
- bmtool/analysis/lfp.py +279 -134
- bmtool/analysis/netcon_reports.py +41 -44
- bmtool/analysis/spikes.py +114 -73
- bmtool/bmplot/connections.py +658 -325
- bmtool/bmplot/entrainment.py +17 -18
- bmtool/bmplot/lfp.py +24 -17
- bmtool/bmplot/netcon_reports.py +0 -4
- bmtool/bmplot/spikes.py +97 -48
- bmtool/connectors.py +394 -251
- bmtool/debug/commands.py +13 -7
- bmtool/debug/debug.py +2 -2
- bmtool/graphs.py +26 -19
- bmtool/manage.py +6 -11
- bmtool/plot_commands.py +350 -151
- bmtool/singlecell.py +357 -195
- bmtool/synapses.py +564 -470
- bmtool/util/commands.py +1079 -627
- bmtool/util/neuron/celltuner.py +989 -609
- bmtool/util/util.py +992 -588
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/METADATA +40 -2
- bmtool-0.7.1.1.dist-info/RECORD +34 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/WHEEL +1 -1
- bmtool-0.7.0.6.4.dist-info/RECORD +0 -34
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/top_level.txt +0 -0
bmtool/analysis/netcon_reports.py CHANGED
@@ -1,97 +1,94 @@
 import h5py
 import numpy as np
 import xarray as xr
-
+
 from ..util.util import load_nodes_from_config
 
+
 def load_synapse_report(h5_file_path, config_path, network):
     """
     Load and process a synapse report from a bmtk simulation into an xarray.
-
+
     Parameters:
     -----------
     h5_file_path : str
         Path to the h5 file containing the synapse report
     config_path : str
         Path to the simulation configuration file
-
+
     Returns:
     --------
     xarray.Dataset
         An xarray containing the synapse report data with proper population labeling
     """
     # Load the h5 file
-    with h5py.File(h5_file_path,
+    with h5py.File(h5_file_path, "r") as file:
         # Get the report data
-        report = file[
-        mapping = report[
-
+        report = file["report"][network]
+        mapping = report["mapping"]
+
         # Get the data - shape is (n_timesteps, n_synapses)
-        data = report[
-
+        data = report["data"][:]
+
         # Get time information
-        time_info = mapping[
+        time_info = mapping["time"][:]  # [start_time, end_time, dt]
         start_time = time_info[0]
         end_time = time_info[1]
         dt = time_info[2]
-
+
         # Create time array
         n_steps = data.shape[0]
-        time = np.linspace(start_time, start_time + (n_steps-1)*dt, n_steps)
-
+        time = np.linspace(start_time, start_time + (n_steps - 1) * dt, n_steps)
+
         # Get mapping information
-        src_ids = mapping[
-        trg_ids = mapping[
-        sec_id = mapping[
-        sec_x = mapping[
-
+        src_ids = mapping["src_ids"][:]
+        trg_ids = mapping["trg_ids"][:]
+        sec_id = mapping["element_ids"][:]
+        sec_x = mapping["element_pos"][:]
+
         # Load node information
         nodes = load_nodes_from_config(config_path)
         nodes = nodes[network]
-
+
         # Create a mapping from node IDs to population names
-        node_to_pop = dict(zip(nodes.index, nodes[
-
+        node_to_pop = dict(zip(nodes.index, nodes["pop_name"]))
+
         # Get the number of synapses
         n_synapses = data.shape[1]
-
+
         # Create arrays to hold the source and target populations for each synapse
         source_pops = []
         target_pops = []
         connection_labels = []
-
+
         # Process each synapse
         for i in range(n_synapses):
             src_id = src_ids[i]
             trg_id = trg_ids[i]
-
+
             # Get population names (with fallback for unknown IDs)
-            src_pop = node_to_pop.get(src_id, f
-            trg_pop = node_to_pop.get(trg_id, f
-
+            src_pop = node_to_pop.get(src_id, f"unknown_{src_id}")
+            trg_pop = node_to_pop.get(trg_id, f"unknown_{trg_id}")
+
             source_pops.append(src_pop)
             target_pops.append(trg_pop)
             connection_labels.append(f"{src_pop}->{trg_pop}")
-
+
         # Create xarray dataset
         ds = xr.Dataset(
-            data_vars={
-                'synapse_value': (['time', 'synapse'], data)
-            },
+            data_vars={"synapse_value": (["time", "synapse"], data)},
             coords={
-
-
-
-
-
-
-
-
-
+                "time": time,
+                "synapse": np.arange(n_synapses),
+                "source_pop": ("synapse", source_pops),
+                "target_pop": ("synapse", target_pops),
+                "source_id": ("synapse", src_ids),
+                "target_id": ("synapse", trg_ids),
+                "sec_id": ("synapse", sec_id),
+                "sec_x": ("synapse", sec_x),
+                "connection_label": ("synapse", connection_labels),
             },
-            attrs={
-                'description': 'Synapse report data from bmtk simulation'
-            }
+            attrs={"description": "Synapse report data from bmtk simulation"},
        )
-
+
     return ds
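The rewritten function returns an xarray.Dataset with per-synapse population coordinates, so a report can be grouped by connection type directly. A minimal usage sketch; the file paths and network name below are illustrative assumptions, not values taken from the diff:

from bmtool.analysis.netcon_reports import load_synapse_report

# Hypothetical paths and network name, for illustration only
ds = load_synapse_report("output/synapse_report.h5", "simulation_config.json", "cortex")

# Average over time, then group synapses by the new connection_label coordinate
mean_per_connection = ds["synapse_value"].mean(dim="time").groupby("connection_label").mean()
print(mean_per_connection)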
bmtool/analysis/spikes.py CHANGED
@@ -2,16 +2,24 @@
 Module for processing BMTK spikes output.
 """
 
+import os
+from typing import Dict, List, Optional, Tuple, Union
+
 import h5py
-import pandas as pd
-from bmtool.util.util import load_nodes_from_config
-from typing import Dict, Optional,Tuple, Union, List
 import numpy as np
+import pandas as pd
 from scipy.stats import mannwhitneyu
-
+
+from bmtool.util.util import load_nodes_from_config
 
 
-def load_spikes_to_df(
+def load_spikes_to_df(
+    spike_file: str,
+    network_name: str,
+    sort: bool = True,
+    config: str = None,
+    groupby: Union[str, List[str]] = "pop_name",
+) -> pd.DataFrame:
     """
     Load spike data from an HDF5 file into a pandas DataFrame.
 
@@ -33,21 +41,23 @@ def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, con
     pd.DataFrame
         A pandas DataFrame containing 'node_ids' and 'timestamps' columns from the spike data,
         with additional columns if a config file is provided
-
+
     Examples
     --------
     >>> df = load_spikes_to_df("spikes.h5", "cortex")
     >>> df = load_spikes_to_df("spikes.h5", "cortex", config="config.json", groupby=["pop_name", "model_type"])
     """
     with h5py.File(spike_file) as f:
-        spikes_df = pd.DataFrame(
-
-
-
+        spikes_df = pd.DataFrame(
+            {
+                "node_ids": f["spikes"][network_name]["node_ids"],
+                "timestamps": f["spikes"][network_name]["timestamps"],
+            }
+        )
 
     if sort:
-        spikes_df.sort_values(by=
-
+        spikes_df.sort_values(by="timestamps", inplace=True, ignore_index=True)
+
     if config:
         nodes = load_nodes_from_config(config)
         nodes = nodes[network_name]
@@ -61,12 +71,19 @@ def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, con
         if missing_cols:
             raise KeyError(f"Columns {missing_cols} not found in nodes DataFrame.")
 
-        spikes_df = spikes_df.merge(
+        spikes_df = spikes_df.merge(
+            nodes[groupby], left_on="node_ids", right_index=True, how="left"
+        )
 
     return spikes_df
 
 
-def compute_firing_rate_stats(
+def compute_firing_rate_stats(
+    df: pd.DataFrame,
+    groupby: Union[str, List[str]] = "pop_name",
+    start_time: float = None,
+    stop_time: float = None,
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
     """
     Computes the firing rates of individual nodes and the mean and standard deviation of firing rates per group.
 
@@ -77,7 +94,7 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] =
         stop_time (float, optional): Stop time for the analysis window. Defaults to the maximum timestamp in the data.
 
     Returns:
-        Tuple[pd.DataFrame, pd.DataFrame]: 
+        Tuple[pd.DataFrame, pd.DataFrame]:
        - The first DataFrame (`pop_stats`) contains the mean and standard deviation of firing rates per group.
        - The second DataFrame (`individual_stats`) contains the firing rate of each individual node.
     """
@@ -85,7 +102,7 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] =
     # Ensure groupby is a list
     if isinstance(groupby, str):
         groupby = [groupby]
-
+
     # Ensure all columns exist in the dataframe
     for col in groupby:
         if col not in df.columns:
@@ -102,42 +119,47 @@ def compute_firing_rate_stats(df: pd.DataFrame, groupby: Union[str, List[str]] =
         min_time = df["timestamps"].min()
     else:
         min_time = start_time
-
-    if stop_time is None:
+
+    if stop_time is None:
         max_time = df["timestamps"].max()
     else:
         max_time = stop_time
-
+
     duration = max_time - min_time  # Avoid division by zero
 
     if duration <= 0:
         raise ValueError("Invalid time window: Stop time must be greater than start time.")
 
     # Compute firing rate for each node
-    import pandas as pd
 
     # Compute spike counts per node
     spike_counts = df["node_ids"].value_counts().reset_index()
     spike_counts.columns = ["node_ids", "spike_count"]  # Rename columns
 
     # Merge with original dataframe to get corresponding labels (e.g., 'pop_name')
-    spike_counts = spike_counts.merge(
+    spike_counts = spike_counts.merge(
+        df[["node_ids"] + groupby].drop_duplicates(), on="node_ids", how="left"
+    )
 
     # Compute firing rate
-    spike_counts["firing_rate"] = spike_counts["spike_count"] / duration * 1000
+    spike_counts["firing_rate"] = spike_counts["spike_count"] / duration * 1000  # scale to Hz
     indivdual_stats = spike_counts
-
+
     # Compute mean and standard deviation per group
     pop_stats = spike_counts.groupby(groupby)["firing_rate"].agg(["mean", "std"]).reset_index()
 
     # Rename columns
     pop_stats.rename(columns={"mean": "firing_rate_mean", "std": "firing_rate_std"}, inplace=True)
 
-    return pop_stats,indivdual_stats
+    return pop_stats, indivdual_stats
 
 
-def _pop_spike_rate(
-
+def _pop_spike_rate(
+    spike_times: Union[np.ndarray, list],
+    time: Optional[Tuple[float, float, float]] = None,
+    time_points: Optional[Union[np.ndarray, list]] = None,
+    frequency: bool = False,
+) -> np.ndarray:
     """
     Calculate the spike count or frequency histogram over specified time intervals.
 
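Taken together, the two reworked functions above load spikes into a labeled DataFrame and aggregate per-node rates into per-population statistics (rates are in Hz: spike_count / duration_ms * 1000). A minimal usage sketch; the file names below are illustrative assumptions:

from bmtool.analysis.spikes import compute_firing_rate_stats, load_spikes_to_df

# Hypothetical file names, for illustration only
spikes_df = load_spikes_to_df("output/spikes.h5", "cortex", config="simulation_config.json")

# pop_stats holds firing_rate_mean / firing_rate_std per group;
# individual_stats holds the per-node rates
pop_stats, individual_stats = compute_firing_rate_stats(spikes_df, groupby="pop_name")
print(pop_stats)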
@@ -146,7 +168,7 @@ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[f
     spike_times : Union[np.ndarray, list]
         Array or list of spike times in milliseconds
     time : Optional[Tuple[float, float, float]], optional
-        Tuple specifying (start, stop, step) in milliseconds. Used to create evenly spaced time points 
+        Tuple specifying (start, stop, step) in milliseconds. Used to create evenly spaced time points
         if `time_points` is not provided. Default is None.
     time_points : Optional[Union[np.ndarray, list]], optional
         Array or list of specific time points for binning. If provided, `time` is ignored. Default is None.
@@ -171,20 +193,27 @@ def _pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[f
     else:
         time_points = np.asarray(time_points).ravel()
         dt = (time_points[-1] - time_points[0]) / (time_points.size - 1)
-
+
     bins = np.append(time_points, time_points[-1] + dt)
     spike_rate, _ = np.histogram(np.asarray(spike_times), bins)
-
+
     if frequency:
         spike_rate = 1000 / dt * spike_rate
-
+
     return spike_rate
 
 
-def get_population_spike_rate(
-
-
-
+def get_population_spike_rate(
+    spike_data: pd.DataFrame,
+    fs: float = 400.0,
+    t_start: float = 0,
+    t_stop: Optional[float] = None,
+    config: Optional[str] = None,
+    network_name: Optional[str] = None,
+    save: bool = False,
+    save_path: Optional[str] = None,
+    normalize: bool = False,
+) -> Dict[str, np.ndarray]:
     """
     Calculate the population spike rate for each population in the given spike data, with an option to normalize.
 
@@ -231,39 +260,43 @@ def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_sta
     node_number = {}
 
     if config is None:
-        print(
+        print(
+            "Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect."
+        )
         print("You can provide a config to calculate the correct amount of nodes!")
-
+
     if config:
         if not network_name:
-            print(
+            print(
+                "Grabbing first network; specify a network name to ensure correct node population is selected."
+            )
+
+    for pop_name in spike_data["pop_name"].unique():
+        ps = spike_data[spike_data["pop_name"] == pop_name]
 
-    for pop_name in spike_data['pop_name'].unique():
-        ps = spike_data[spike_data['pop_name'] == pop_name]
-
         if config:
             nodes = load_nodes_from_config(config)
             if network_name:
                 nodes = nodes[network_name]
             else:
                 nodes = list(nodes.values())[0] if nodes else {}
-            nodes = nodes[nodes[
+            nodes = nodes[nodes["pop_name"] == pop_name]
             node_number[pop_name] = nodes.index.nunique()
         else:
-            node_number[pop_name] = ps[
+            node_number[pop_name] = ps["node_ids"].nunique()
 
         if t_stop is None:
-            t_stop = spike_data[
+            t_stop = spike_data["timestamps"].max()
 
         filtered_spikes = spike_data[
-            (spike_data[
-            (spike_data[
-            (spike_data[
+            (spike_data["pop_name"] == pop_name)
+            & (spike_data["timestamps"] > t_start)
+            & (spike_data["timestamps"] < t_stop)
         ]
         pop_spikes[pop_name] = filtered_spikes
 
     time = np.array([t_start, t_stop, 1000 / fs])
-    pop_rspk = {p: _pop_spike_rate(spk[
+    pop_rspk = {p: _pop_spike_rate(spk["timestamps"], time) for p, spk in pop_spikes.items()}
     spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}
 
     # Normalize each spike rate series if normalize=True
@@ -273,25 +306,27 @@ def get_population_spike_rate(spike_data: pd.DataFrame, fs: float = 400.0, t_sta
     if save:
         if save_path is None:
             raise ValueError("save_path must be provided if save is True.")
-
+
         os.makedirs(save_path, exist_ok=True)
-
-        save_file = os.path.join(save_path,
-        with h5py.File(save_file,
-            f.create_dataset(
-            grp = f.create_group(
+
+        save_file = os.path.join(save_path, "spike_rate.h5")
+        with h5py.File(save_file, "w") as f:
+            f.create_dataset("time", data=time)
+            grp = f.create_group("populations")
             for p, rspk in spike_rate.items():
                 pop_grp = grp.create_group(p)
-                pop_grp.create_dataset(
+                pop_grp.create_dataset("data", data=rspk)
 
     return spike_rate
 
 
-def compare_firing_over_times(
+def compare_firing_over_times(
+    spike_df: pd.DataFrame, group_by: str, time_window_1: List[float], time_window_2: List[float]
+) -> None:
     """
     Compares the firing rates of a population during two different time windows and performs
     a statistical test to determine if there is a significant difference.
-
+
     Parameters
     ----------
     spike_df : pd.DataFrame
@@ -302,12 +337,12 @@ def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window
         First time window as [start, stop] in milliseconds
     time_window_2 : List[float]
         Second time window as [start, stop] in milliseconds
-
+
     Returns
     -------
     None
         Results are printed to the console
-
+
     Notes
     -----
     Uses Mann-Whitney U test (non-parametric) to compare firing rates between the two windows
@@ -316,38 +351,44 @@ def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window
     for pop_name in spike_df[group_by].unique():
         print(f"Population: {pop_name}")
         pop_spikes = spike_df[spike_df[group_by] == pop_name]
-
+
         # Filter by time windows
-        pop_spikes_1 = pop_spikes[
-
-
+        pop_spikes_1 = pop_spikes[
+            (pop_spikes["timestamps"] >= time_window_1[0])
+            & (pop_spikes["timestamps"] <= time_window_1[1])
+        ]
+        pop_spikes_2 = pop_spikes[
+            (pop_spikes["timestamps"] >= time_window_2[0])
+            & (pop_spikes["timestamps"] <= time_window_2[1])
+        ]
+
         # Get unique neuron IDs
-        unique_neurons = pop_spikes[
-
+        unique_neurons = pop_spikes["node_ids"].unique()
+
         # Calculate firing rates per neuron for each time window in Hz
         neuron_rates_1 = []
         neuron_rates_2 = []
-
+
         for neuron in unique_neurons:
             # Count spikes for this neuron in each window
-            n_spikes_1 = len(pop_spikes_1[pop_spikes_1[
-            n_spikes_2 = len(pop_spikes_2[pop_spikes_2[
-
+            n_spikes_1 = len(pop_spikes_1[pop_spikes_1["node_ids"] == neuron])
+            n_spikes_2 = len(pop_spikes_2[pop_spikes_2["node_ids"] == neuron])
+
             # Calculate firing rate in Hz (convert ms to seconds by dividing by 1000)
             rate_1 = n_spikes_1 / ((time_window_1[1] - time_window_1[0]) / 1000)
             rate_2 = n_spikes_2 / ((time_window_2[1] - time_window_2[0]) / 1000)
-
+
             neuron_rates_1.append(rate_1)
             neuron_rates_2.append(rate_2)
-
+
         # Calculate average firing rates
         avg_firing_rate_1 = np.mean(neuron_rates_1) if neuron_rates_1 else 0
         avg_firing_rate_2 = np.mean(neuron_rates_2) if neuron_rates_2 else 0
-
+
         # Perform Mann-Whitney U test
         # Handle the case when one or both arrays are empty
         if len(neuron_rates_1) > 0 and len(neuron_rates_2) > 0:
-            u_stat, p_val = mannwhitneyu(neuron_rates_1, neuron_rates_2, alternative=
+            u_stat, p_val = mannwhitneyu(neuron_rates_1, neuron_rates_2, alternative="two-sided")
         else:
             u_stat, p_val = np.nan, np.nan
 
@@ -356,4 +397,4 @@ def compare_firing_over_times(spike_df: pd.DataFrame, group_by: str, time_window
         print(f"  U-statistic: {u_stat:.2f}")
         print(f"  p-value: {p_val}")
         print(f"  Significant difference (p<0.05): {'Yes' if p_val < 0.05 else 'No'}")
-    return
+    return
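For reference, a short end-to-end sketch of the two remaining helpers changed above; the sampling rate, window bounds, and paths are illustrative assumptions:

from bmtool.analysis.spikes import (
    compare_firing_over_times,
    get_population_spike_rate,
    load_spikes_to_df,
)

# Hypothetical inputs, for illustration only
spikes_df = load_spikes_to_df("output/spikes.h5", "cortex", config="simulation_config.json")

# Per-population rate traces sampled at fs (Hz); save=True would write
# spike_rate.h5 under save_path, as in the hunk above
rates = get_population_spike_rate(
    spikes_df, fs=400.0, config="simulation_config.json", network_name="cortex"
)

# Mann-Whitney U comparison of per-neuron rates between two windows (ms);
# results are printed rather than returned
compare_firing_over_times(spikes_df, "pop_name", [0.0, 1000.0], [1000.0, 2000.0])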