bmtool 0.7.0.6.2__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/SLURM.py +162 -109
- bmtool/__init__.py +1 -1
- bmtool/__main__.py +8 -7
- bmtool/analysis/entrainment.py +250 -143
- bmtool/analysis/lfp.py +279 -134
- bmtool/analysis/netcon_reports.py +41 -44
- bmtool/analysis/spikes.py +114 -73
- bmtool/bmplot/connections.py +658 -325
- bmtool/bmplot/entrainment.py +17 -18
- bmtool/bmplot/lfp.py +24 -17
- bmtool/bmplot/netcon_reports.py +0 -4
- bmtool/bmplot/spikes.py +97 -48
- bmtool/connectors.py +394 -251
- bmtool/debug/commands.py +13 -7
- bmtool/debug/debug.py +2 -2
- bmtool/graphs.py +26 -19
- bmtool/manage.py +6 -11
- bmtool/plot_commands.py +350 -151
- bmtool/singlecell.py +357 -195
- bmtool/synapses.py +564 -470
- bmtool/util/commands.py +1079 -627
- bmtool/util/neuron/celltuner.py +989 -609
- bmtool/util/util.py +992 -588
- {bmtool-0.7.0.6.2.dist-info → bmtool-0.7.1.dist-info}/METADATA +41 -3
- bmtool-0.7.1.dist-info/RECORD +34 -0
- {bmtool-0.7.0.6.2.dist-info → bmtool-0.7.1.dist-info}/WHEEL +1 -1
- bmtool-0.7.0.6.2.dist-info/RECORD +0 -34
- {bmtool-0.7.0.6.2.dist-info → bmtool-0.7.1.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.0.6.2.dist-info → bmtool-0.7.1.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.0.6.2.dist-info → bmtool-0.7.1.dist-info}/top_level.txt +0 -0
bmtool/bmplot/entrainment.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
|
-
|
2
1
|
import matplotlib.pyplot as plt
|
3
2
|
import seaborn as sns
|
4
3
|
|
4
|
+
|
5
5
|
def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
|
6
6
|
"""
|
7
7
|
Plot the correlation between population spike rates and LFP power.
|
8
|
-
|
8
|
+
|
9
9
|
Parameters:
|
10
10
|
-----------
|
11
11
|
correlation_results : dict
|
@@ -17,35 +17,34 @@ def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
|
|
17
17
|
"""
|
18
18
|
sns.set_style("whitegrid")
|
19
19
|
plt.figure(figsize=(10, 6))
|
20
|
-
|
20
|
+
|
21
21
|
for pop in pop_names:
|
22
22
|
# Extract correlation values for each frequency
|
23
23
|
corr_values = []
|
24
24
|
valid_freqs = []
|
25
|
-
|
25
|
+
|
26
26
|
for freq in frequencies:
|
27
27
|
if freq in correlation_results[pop]:
|
28
|
-
corr_values.append(correlation_results[pop][freq][
|
28
|
+
corr_values.append(correlation_results[pop][freq]["correlation"])
|
29
29
|
valid_freqs.append(freq)
|
30
|
-
|
30
|
+
|
31
31
|
# Plot correlation line
|
32
|
-
plt.plot(valid_freqs, corr_values, marker=
|
33
|
-
|
34
|
-
|
35
|
-
plt.
|
36
|
-
plt.
|
37
|
-
plt.title('Spike rate LFP power correlation during stimulus', fontsize=14)
|
32
|
+
plt.plot(valid_freqs, corr_values, marker="o", label=pop, linewidth=2, markersize=6)
|
33
|
+
|
34
|
+
plt.xlabel("Frequency (Hz)", fontsize=12)
|
35
|
+
plt.ylabel("Spike Rate-Power Correlation", fontsize=12)
|
36
|
+
plt.title("Spike rate LFP power correlation during stimulus", fontsize=14)
|
38
37
|
plt.grid(True, alpha=0.3)
|
39
38
|
plt.legend(fontsize=12)
|
40
39
|
plt.xticks(frequencies[::2]) # Display every other frequency on x-axis
|
41
|
-
|
40
|
+
|
42
41
|
# Add horizontal line at zero for reference
|
43
|
-
plt.axhline(y=0, color=
|
44
|
-
|
42
|
+
plt.axhline(y=0, color="gray", linestyle="-", alpha=0.5)
|
43
|
+
|
45
44
|
# Set y-axis limits to make zero visible
|
46
45
|
y_min, y_max = plt.ylim()
|
47
46
|
plt.ylim(min(y_min, -0.1), max(y_max, 0.1))
|
48
|
-
|
47
|
+
|
49
48
|
plt.tight_layout()
|
50
|
-
|
51
|
-
plt.show()
|
49
|
+
|
50
|
+
plt.show()
|
bmtool/bmplot/lfp.py
CHANGED
@@ -1,35 +1,43 @@
|
|
1
|
+
import matplotlib.pyplot as plt
|
1
2
|
import numpy as np
|
3
|
+
|
2
4
|
from bmtool.analysis.lfp import gen_aperiodic
|
3
|
-
import matplotlib.pyplot as plt
|
4
5
|
|
5
6
|
|
6
|
-
def plot_spectrogram(
|
7
|
-
|
7
|
+
def plot_spectrogram(
|
8
|
+
sxx_xarray,
|
9
|
+
remove_aperiodic=None,
|
10
|
+
log_power=False,
|
11
|
+
plt_range=None,
|
12
|
+
clr_freq_range=None,
|
13
|
+
pad=0.03,
|
14
|
+
ax=None,
|
15
|
+
):
|
8
16
|
"""Plot spectrogram. Determine color limits using value in frequency band clr_freq_range"""
|
9
17
|
sxx = sxx_xarray.PSD.values.copy()
|
10
18
|
t = sxx_xarray.time.values.copy()
|
11
19
|
f = sxx_xarray.frequency.values.copy()
|
12
20
|
|
13
|
-
cbar_label =
|
21
|
+
cbar_label = "PSD" if remove_aperiodic is None else "PSD Residual"
|
14
22
|
if log_power:
|
15
|
-
with np.errstate(divide=
|
23
|
+
with np.errstate(divide="ignore"):
|
16
24
|
sxx = np.log10(sxx)
|
17
|
-
cbar_label +=
|
25
|
+
cbar_label += " dB" if log_power == "dB" else " log(power)"
|
18
26
|
|
19
27
|
if remove_aperiodic is not None:
|
20
28
|
f1_idx = 0 if f[0] else 1
|
21
29
|
ap_fit = gen_aperiodic(f[f1_idx:], remove_aperiodic.aperiodic_params)
|
22
|
-
sxx[f1_idx:, :] -= (ap_fit if log_power else 10
|
23
|
-
sxx[:f1_idx, :] = 0.
|
30
|
+
sxx[f1_idx:, :] -= (ap_fit if log_power else 10**ap_fit)[:, None]
|
31
|
+
sxx[:f1_idx, :] = 0.0
|
24
32
|
|
25
|
-
if log_power ==
|
33
|
+
if log_power == "dB":
|
26
34
|
sxx *= 10
|
27
35
|
|
28
36
|
if ax is None:
|
29
37
|
_, ax = plt.subplots(1, 1)
|
30
38
|
plt_range = np.array(f[-1]) if plt_range is None else np.array(plt_range)
|
31
39
|
if plt_range.size == 1:
|
32
|
-
plt_range = [f[0 if f[0] else 1] if log_power else 0
|
40
|
+
plt_range = [f[0 if f[0] else 1] if log_power else 0.0, plt_range.item()]
|
33
41
|
f_idx = (f >= plt_range[0]) & (f <= plt_range[1])
|
34
42
|
if clr_freq_range is None:
|
35
43
|
vmin, vmax = None, None
|
@@ -38,16 +46,15 @@ def plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False,
|
|
38
46
|
vmin, vmax = sxx[c_idx, :].min(), sxx[c_idx, :].max()
|
39
47
|
|
40
48
|
f = f[f_idx]
|
41
|
-
pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading=
|
42
|
-
if
|
49
|
+
pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading="gouraud", vmin=vmin, vmax=vmax)
|
50
|
+
if "cone_of_influence_frequency" in sxx_xarray:
|
43
51
|
coif = sxx_xarray.cone_of_influence_frequency
|
44
52
|
ax.plot(t, coif)
|
45
|
-
ax.fill_between(t, coif, step=
|
53
|
+
ax.fill_between(t, coif, step="mid", alpha=0.2)
|
46
54
|
ax.set_xlim(t[0], t[-1])
|
47
|
-
#ax.set_xlim(t[0],0.2)
|
55
|
+
# ax.set_xlim(t[0],0.2)
|
48
56
|
ax.set_ylim(f[0], f[-1])
|
49
57
|
plt.colorbar(mappable=pcm, ax=ax, label=cbar_label, pad=pad)
|
50
|
-
ax.set_xlabel(
|
51
|
-
ax.set_ylabel(
|
58
|
+
ax.set_xlabel("Time (sec)")
|
59
|
+
ax.set_ylabel("Frequency (Hz)")
|
52
60
|
return sxx
|
53
|
-
|
bmtool/bmplot/netcon_reports.py
CHANGED
bmtool/bmplot/spikes.py
CHANGED
@@ -1,19 +1,27 @@
|
|
1
|
-
from
|
2
|
-
|
3
|
-
import pandas as pd
|
4
|
-
from matplotlib.axes import Axes
|
1
|
+
from typing import Dict, List, Optional, Union
|
2
|
+
|
5
3
|
import matplotlib.pyplot as plt
|
6
|
-
import seaborn as sns
|
7
4
|
import numpy as np
|
5
|
+
import pandas as pd
|
6
|
+
import seaborn as sns
|
7
|
+
from matplotlib.axes import Axes
|
8
8
|
|
9
|
+
from ..util.util import load_nodes_from_config
|
9
10
|
|
10
11
|
|
11
|
-
def raster(
|
12
|
-
|
13
|
-
|
12
|
+
def raster(
|
13
|
+
spikes_df: Optional[pd.DataFrame] = None,
|
14
|
+
config: Optional[str] = None,
|
15
|
+
network_name: Optional[str] = None,
|
16
|
+
groupby: Optional[str] = "pop_name",
|
17
|
+
ax: Optional[Axes] = None,
|
18
|
+
tstart: Optional[float] = None,
|
19
|
+
tstop: Optional[float] = None,
|
20
|
+
color_map: Optional[Dict[str, str]] = None,
|
21
|
+
) -> Axes:
|
14
22
|
"""
|
15
23
|
Plots a raster plot of neural spikes, with different colors for each population.
|
16
|
-
|
24
|
+
|
17
25
|
Parameters:
|
18
26
|
----------
|
19
27
|
spikes_df : pd.DataFrame, optional
|
@@ -30,12 +38,12 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
|
|
30
38
|
Stop time for filtering spikes; only spikes with timestamps less than `tstop` will be plotted.
|
31
39
|
color_map : dict, optional
|
32
40
|
Dictionary specifying colors for each population. Keys should be population names, and values should be color values.
|
33
|
-
|
41
|
+
|
34
42
|
Returns:
|
35
43
|
-------
|
36
44
|
matplotlib.axes.Axes
|
37
45
|
Axes with the raster plot.
|
38
|
-
|
46
|
+
|
39
47
|
Notes:
|
40
48
|
-----
|
41
49
|
- If `config` is provided, the function merges population names from the node data with `spikes_df`.
|
@@ -48,9 +56,9 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
|
|
48
56
|
|
49
57
|
# Filter spikes by time range if specified
|
50
58
|
if tstart is not None:
|
51
|
-
spikes_df = spikes_df[spikes_df[
|
59
|
+
spikes_df = spikes_df[spikes_df["timestamps"] > tstart]
|
52
60
|
if tstop is not None:
|
53
|
-
spikes_df = spikes_df[spikes_df[
|
61
|
+
spikes_df = spikes_df[spikes_df["timestamps"] < tstop]
|
54
62
|
|
55
63
|
# Load and merge node population data if config is provided
|
56
64
|
if config:
|
@@ -59,45 +67,59 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
|
|
59
67
|
nodes = nodes.get(network_name, {})
|
60
68
|
else:
|
61
69
|
nodes = list(nodes.values())[0] if nodes else {}
|
62
|
-
print(
|
63
|
-
|
70
|
+
print(
|
71
|
+
"Grabbing first network; specify a network name to ensure correct node population is selected."
|
72
|
+
)
|
73
|
+
|
64
74
|
# Find common columns, but exclude the join key from the list
|
65
75
|
common_columns = spikes_df.columns.intersection(nodes.columns).tolist()
|
66
|
-
common_columns = [
|
76
|
+
common_columns = [
|
77
|
+
col for col in common_columns if col != "node_ids"
|
78
|
+
] # Remove our join key from the common list
|
67
79
|
|
68
80
|
# Drop all intersecting columns except the join key column from df2
|
69
81
|
spikes_df = spikes_df.drop(columns=common_columns)
|
70
82
|
# merge nodes and spikes df
|
71
|
-
spikes_df = spikes_df.merge(
|
72
|
-
|
83
|
+
spikes_df = spikes_df.merge(
|
84
|
+
nodes[groupby], left_on="node_ids", right_index=True, how="left"
|
85
|
+
)
|
73
86
|
|
74
87
|
# Get unique population names
|
75
88
|
unique_pop_names = spikes_df[groupby].unique()
|
76
|
-
|
89
|
+
|
77
90
|
# Generate colors if no color_map is provided
|
78
91
|
if color_map is None:
|
79
|
-
cmap = plt.get_cmap(
|
80
|
-
color_map = {
|
92
|
+
cmap = plt.get_cmap("tab10") # Default colormap
|
93
|
+
color_map = {
|
94
|
+
pop_name: cmap(i / len(unique_pop_names)) for i, pop_name in enumerate(unique_pop_names)
|
95
|
+
}
|
81
96
|
else:
|
82
97
|
# Ensure color_map contains all population names
|
83
98
|
missing_colors = [pop for pop in unique_pop_names if pop not in color_map]
|
84
99
|
if missing_colors:
|
85
100
|
raise ValueError(f"color_map is missing colors for populations: {missing_colors}")
|
86
|
-
|
101
|
+
|
87
102
|
# Plot each population with its specified or generated color
|
88
103
|
for pop_name, group in spikes_df.groupby(groupby):
|
89
|
-
ax.scatter(
|
104
|
+
ax.scatter(
|
105
|
+
group["timestamps"], group["node_ids"], label=pop_name, color=color_map[pop_name], s=0.5
|
106
|
+
)
|
90
107
|
|
91
108
|
# Label axes
|
92
109
|
ax.set_xlabel("Time")
|
93
110
|
ax.set_ylabel("Node ID")
|
94
|
-
ax.legend(title="Population", loc=
|
95
|
-
|
111
|
+
ax.legend(title="Population", loc="upper right", framealpha=0.9, markerfirst=False)
|
112
|
+
|
96
113
|
return ax
|
97
|
-
|
114
|
+
|
115
|
+
|
98
116
|
# uses df from bmtool.analysis.spikes compute_firing_rate_stats
|
99
|
-
def plot_firing_rate_pop_stats(
|
100
|
-
|
117
|
+
def plot_firing_rate_pop_stats(
|
118
|
+
firing_stats: pd.DataFrame,
|
119
|
+
groupby: Union[str, List[str]],
|
120
|
+
ax: Optional[Axes] = None,
|
121
|
+
color_map: Optional[Dict[str, str]] = None,
|
122
|
+
) -> Axes:
|
101
123
|
"""
|
102
124
|
Plots a bar graph of mean firing rates with error bars (standard deviation).
|
103
125
|
|
@@ -129,7 +151,7 @@ def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, L
|
|
129
151
|
|
130
152
|
# Generate colors if no color_map is provided
|
131
153
|
if color_map is None:
|
132
|
-
cmap = plt.get_cmap(
|
154
|
+
cmap = plt.get_cmap("viridis")
|
133
155
|
color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
|
134
156
|
else:
|
135
157
|
# Ensure color_map contains all groups
|
@@ -157,13 +179,13 @@ def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, L
|
|
157
179
|
|
158
180
|
# Add error bars manually with caps
|
159
181
|
_, caps, _ = ax.errorbar(
|
160
|
-
x=np.arange(len(x_labels)),
|
161
|
-
y=means,
|
162
|
-
yerr=std_devs,
|
163
|
-
fmt=
|
164
|
-
capsize=5,
|
165
|
-
capthick=2,
|
166
|
-
color="black"
|
182
|
+
x=np.arange(len(x_labels)),
|
183
|
+
y=means,
|
184
|
+
yerr=std_devs,
|
185
|
+
fmt="none",
|
186
|
+
capsize=5,
|
187
|
+
capthick=2,
|
188
|
+
color="black",
|
167
189
|
)
|
168
190
|
|
169
191
|
# Formatting
|
@@ -172,14 +194,20 @@ def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, L
|
|
172
194
|
ax.set_xlabel("Population Group")
|
173
195
|
ax.set_ylabel("Mean Firing Rate (spikes/s)")
|
174
196
|
ax.set_title("Firing Rate Statistics by Population")
|
175
|
-
ax.grid(axis=
|
197
|
+
ax.grid(axis="y", linestyle="--", alpha=0.7)
|
176
198
|
|
177
199
|
return ax
|
178
200
|
|
201
|
+
|
179
202
|
# uses df from bmtool.analysis.spikes compute_firing_rate_stats
|
180
|
-
def plot_firing_rate_distribution(
|
181
|
-
|
182
|
-
|
203
|
+
def plot_firing_rate_distribution(
|
204
|
+
individual_stats: pd.DataFrame,
|
205
|
+
groupby: Union[str, list],
|
206
|
+
ax: Optional[Axes] = None,
|
207
|
+
color_map: Optional[Dict[str, str]] = None,
|
208
|
+
plot_type: Union[str, list] = "box",
|
209
|
+
swarm_alpha: float = 0.6,
|
210
|
+
) -> Axes:
|
183
211
|
"""
|
184
212
|
Plots a distribution of individual firing rates using one or more plot types
|
185
213
|
(box plot, violin plot, or swarm plot), overlaying them on top of each other.
|
@@ -214,7 +242,7 @@ def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union
|
|
214
242
|
# Validate plot_type (it can be a list or a single type)
|
215
243
|
if isinstance(plot_type, str):
|
216
244
|
plot_type = [plot_type]
|
217
|
-
|
245
|
+
|
218
246
|
for pt in plot_type:
|
219
247
|
if pt not in ["box", "violin", "swarm"]:
|
220
248
|
raise ValueError("plot_type must be one of: 'box', 'violin', 'swarm'.")
|
@@ -224,9 +252,9 @@ def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union
|
|
224
252
|
|
225
253
|
# Generate colors if no color_map is provided
|
226
254
|
if color_map is None:
|
227
|
-
cmap = plt.get_cmap(
|
255
|
+
cmap = plt.get_cmap("viridis")
|
228
256
|
color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
|
229
|
-
|
257
|
+
|
230
258
|
# Ensure color_map contains all groups
|
231
259
|
missing_colors = [group for group in unique_groups if group not in color_map]
|
232
260
|
if missing_colors:
|
@@ -242,18 +270,39 @@ def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union
|
|
242
270
|
# Loop over each plot type and overlay them
|
243
271
|
for pt in plot_type:
|
244
272
|
if pt == "box":
|
245
|
-
sns.boxplot(
|
273
|
+
sns.boxplot(
|
274
|
+
data=individual_stats,
|
275
|
+
x="group",
|
276
|
+
y="firing_rate",
|
277
|
+
ax=ax,
|
278
|
+
palette=color_map,
|
279
|
+
width=0.5,
|
280
|
+
)
|
246
281
|
elif pt == "violin":
|
247
|
-
sns.violinplot(
|
282
|
+
sns.violinplot(
|
283
|
+
data=individual_stats,
|
284
|
+
x="group",
|
285
|
+
y="firing_rate",
|
286
|
+
ax=ax,
|
287
|
+
palette=color_map,
|
288
|
+
inner="quartile",
|
289
|
+
alpha=0.4,
|
290
|
+
)
|
248
291
|
elif pt == "swarm":
|
249
|
-
sns.swarmplot(
|
292
|
+
sns.swarmplot(
|
293
|
+
data=individual_stats,
|
294
|
+
x="group",
|
295
|
+
y="firing_rate",
|
296
|
+
ax=ax,
|
297
|
+
palette=color_map,
|
298
|
+
alpha=swarm_alpha,
|
299
|
+
)
|
250
300
|
|
251
301
|
# Formatting
|
252
302
|
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
|
253
303
|
ax.set_xlabel("Population Group")
|
254
304
|
ax.set_ylabel("Firing Rate (spikes/s)")
|
255
305
|
ax.set_title("Firing Rate Distribution for individual cells")
|
256
|
-
ax.grid(axis=
|
306
|
+
ax.grid(axis="y", linestyle="--", alpha=0.7)
|
257
307
|
|
258
308
|
return ax
|
259
|
-
|