bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,11 @@
1
-
2
1
  import matplotlib.pyplot as plt
3
2
  import seaborn as sns
4
3
 
4
+
5
5
  def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
6
6
  """
7
7
  Plot the correlation between population spike rates and LFP power.
8
-
8
+
9
9
  Parameters:
10
10
  -----------
11
11
  correlation_results : dict
@@ -17,35 +17,34 @@ def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
17
17
  """
18
18
  sns.set_style("whitegrid")
19
19
  plt.figure(figsize=(10, 6))
20
-
20
+
21
21
  for pop in pop_names:
22
22
  # Extract correlation values for each frequency
23
23
  corr_values = []
24
24
  valid_freqs = []
25
-
25
+
26
26
  for freq in frequencies:
27
27
  if freq in correlation_results[pop]:
28
- corr_values.append(correlation_results[pop][freq]['correlation'])
28
+ corr_values.append(correlation_results[pop][freq]["correlation"])
29
29
  valid_freqs.append(freq)
30
-
30
+
31
31
  # Plot correlation line
32
- plt.plot(valid_freqs, corr_values, marker='o', label=pop,
33
- linewidth=2, markersize=6)
34
-
35
- plt.xlabel('Frequency (Hz)', fontsize=12)
36
- plt.ylabel('Spike Rate-Power Correlation', fontsize=12)
37
- plt.title('Spike rate LFP power correlation during stimulus', fontsize=14)
32
+ plt.plot(valid_freqs, corr_values, marker="o", label=pop, linewidth=2, markersize=6)
33
+
34
+ plt.xlabel("Frequency (Hz)", fontsize=12)
35
+ plt.ylabel("Spike Rate-Power Correlation", fontsize=12)
36
+ plt.title("Spike rate LFP power correlation during stimulus", fontsize=14)
38
37
  plt.grid(True, alpha=0.3)
39
38
  plt.legend(fontsize=12)
40
39
  plt.xticks(frequencies[::2]) # Display every other frequency on x-axis
41
-
40
+
42
41
  # Add horizontal line at zero for reference
43
- plt.axhline(y=0, color='gray', linestyle='-', alpha=0.5)
44
-
42
+ plt.axhline(y=0, color="gray", linestyle="-", alpha=0.5)
43
+
45
44
  # Set y-axis limits to make zero visible
46
45
  y_min, y_max = plt.ylim()
47
46
  plt.ylim(min(y_min, -0.1), max(y_max, 0.1))
48
-
47
+
49
48
  plt.tight_layout()
50
-
51
- plt.show()
49
+
50
+ plt.show()
bmtool/bmplot/lfp.py CHANGED
@@ -1,35 +1,43 @@
1
+ import matplotlib.pyplot as plt
1
2
  import numpy as np
3
+
2
4
  from bmtool.analysis.lfp import gen_aperiodic
3
- import matplotlib.pyplot as plt
4
5
 
5
6
 
6
- def plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False,
7
- plt_range=None, clr_freq_range=None, pad=0.03, ax=None):
7
+ def plot_spectrogram(
8
+ sxx_xarray,
9
+ remove_aperiodic=None,
10
+ log_power=False,
11
+ plt_range=None,
12
+ clr_freq_range=None,
13
+ pad=0.03,
14
+ ax=None,
15
+ ):
8
16
  """Plot spectrogram. Determine color limits using value in frequency band clr_freq_range"""
9
17
  sxx = sxx_xarray.PSD.values.copy()
10
18
  t = sxx_xarray.time.values.copy()
11
19
  f = sxx_xarray.frequency.values.copy()
12
20
 
13
- cbar_label = 'PSD' if remove_aperiodic is None else 'PSD Residual'
21
+ cbar_label = "PSD" if remove_aperiodic is None else "PSD Residual"
14
22
  if log_power:
15
- with np.errstate(divide='ignore'):
23
+ with np.errstate(divide="ignore"):
16
24
  sxx = np.log10(sxx)
17
- cbar_label += ' dB' if log_power == 'dB' else ' log(power)'
25
+ cbar_label += " dB" if log_power == "dB" else " log(power)"
18
26
 
19
27
  if remove_aperiodic is not None:
20
28
  f1_idx = 0 if f[0] else 1
21
29
  ap_fit = gen_aperiodic(f[f1_idx:], remove_aperiodic.aperiodic_params)
22
- sxx[f1_idx:, :] -= (ap_fit if log_power else 10 ** ap_fit)[:, None]
23
- sxx[:f1_idx, :] = 0.
30
+ sxx[f1_idx:, :] -= (ap_fit if log_power else 10**ap_fit)[:, None]
31
+ sxx[:f1_idx, :] = 0.0
24
32
 
25
- if log_power == 'dB':
33
+ if log_power == "dB":
26
34
  sxx *= 10
27
35
 
28
36
  if ax is None:
29
37
  _, ax = plt.subplots(1, 1)
30
38
  plt_range = np.array(f[-1]) if plt_range is None else np.array(plt_range)
31
39
  if plt_range.size == 1:
32
- plt_range = [f[0 if f[0] else 1] if log_power else 0., plt_range.item()]
40
+ plt_range = [f[0 if f[0] else 1] if log_power else 0.0, plt_range.item()]
33
41
  f_idx = (f >= plt_range[0]) & (f <= plt_range[1])
34
42
  if clr_freq_range is None:
35
43
  vmin, vmax = None, None
@@ -38,16 +46,15 @@ def plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False,
38
46
  vmin, vmax = sxx[c_idx, :].min(), sxx[c_idx, :].max()
39
47
 
40
48
  f = f[f_idx]
41
- pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading='gouraud', vmin=vmin, vmax=vmax)
42
- if 'cone_of_influence_frequency' in sxx_xarray:
49
+ pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading="gouraud", vmin=vmin, vmax=vmax)
50
+ if "cone_of_influence_frequency" in sxx_xarray:
43
51
  coif = sxx_xarray.cone_of_influence_frequency
44
52
  ax.plot(t, coif)
45
- ax.fill_between(t, coif, step='mid', alpha=0.2)
53
+ ax.fill_between(t, coif, step="mid", alpha=0.2)
46
54
  ax.set_xlim(t[0], t[-1])
47
- #ax.set_xlim(t[0],0.2)
55
+ # ax.set_xlim(t[0],0.2)
48
56
  ax.set_ylim(f[0], f[-1])
49
57
  plt.colorbar(mappable=pcm, ax=ax, label=cbar_label, pad=pad)
50
- ax.set_xlabel('Time (sec)')
51
- ax.set_ylabel('Frequency (Hz)')
58
+ ax.set_xlabel("Time (sec)")
59
+ ax.set_ylabel("Frequency (Hz)")
52
60
  return sxx
53
-
@@ -1,4 +0,0 @@
1
-
2
-
3
-
4
-
bmtool/bmplot/spikes.py CHANGED
@@ -1,19 +1,27 @@
1
- from ..util.util import load_nodes_from_config
2
- from typing import Optional, Dict, List, Union
3
- import pandas as pd
4
- from matplotlib.axes import Axes
1
+ from typing import Dict, List, Optional, Union
2
+
5
3
  import matplotlib.pyplot as plt
6
- import seaborn as sns
7
4
  import numpy as np
5
+ import pandas as pd
6
+ import seaborn as sns
7
+ from matplotlib.axes import Axes
8
8
 
9
+ from ..util.util import load_nodes_from_config
9
10
 
10
11
 
11
- def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = None, network_name: Optional[str] = None, groupby:Optional[str] = 'pop_name',
12
- ax: Optional[Axes] = None,tstart: Optional[float] = None,tstop: Optional[float] = None,
13
- color_map: Optional[Dict[str, str]] = None) -> Axes:
12
+ def raster(
13
+ spikes_df: Optional[pd.DataFrame] = None,
14
+ config: Optional[str] = None,
15
+ network_name: Optional[str] = None,
16
+ groupby: Optional[str] = "pop_name",
17
+ ax: Optional[Axes] = None,
18
+ tstart: Optional[float] = None,
19
+ tstop: Optional[float] = None,
20
+ color_map: Optional[Dict[str, str]] = None,
21
+ ) -> Axes:
14
22
  """
15
23
  Plots a raster plot of neural spikes, with different colors for each population.
16
-
24
+
17
25
  Parameters:
18
26
  ----------
19
27
  spikes_df : pd.DataFrame, optional
@@ -30,12 +38,12 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
30
38
  Stop time for filtering spikes; only spikes with timestamps less than `tstop` will be plotted.
31
39
  color_map : dict, optional
32
40
  Dictionary specifying colors for each population. Keys should be population names, and values should be color values.
33
-
41
+
34
42
  Returns:
35
43
  -------
36
44
  matplotlib.axes.Axes
37
45
  Axes with the raster plot.
38
-
46
+
39
47
  Notes:
40
48
  -----
41
49
  - If `config` is provided, the function merges population names from the node data with `spikes_df`.
@@ -48,9 +56,9 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
48
56
 
49
57
  # Filter spikes by time range if specified
50
58
  if tstart is not None:
51
- spikes_df = spikes_df[spikes_df['timestamps'] > tstart]
59
+ spikes_df = spikes_df[spikes_df["timestamps"] > tstart]
52
60
  if tstop is not None:
53
- spikes_df = spikes_df[spikes_df['timestamps'] < tstop]
61
+ spikes_df = spikes_df[spikes_df["timestamps"] < tstop]
54
62
 
55
63
  # Load and merge node population data if config is provided
56
64
  if config:
@@ -59,45 +67,59 @@ def raster(spikes_df: Optional[pd.DataFrame] = None, config: Optional[str] = Non
59
67
  nodes = nodes.get(network_name, {})
60
68
  else:
61
69
  nodes = list(nodes.values())[0] if nodes else {}
62
- print("Grabbing first network; specify a network name to ensure correct node population is selected.")
63
-
70
+ print(
71
+ "Grabbing first network; specify a network name to ensure correct node population is selected."
72
+ )
73
+
64
74
  # Find common columns, but exclude the join key from the list
65
75
  common_columns = spikes_df.columns.intersection(nodes.columns).tolist()
66
- common_columns = [col for col in common_columns if col != 'node_ids'] # Remove our join key from the common list
76
+ common_columns = [
77
+ col for col in common_columns if col != "node_ids"
78
+ ] # Remove our join key from the common list
67
79
 
68
80
  # Drop all intersecting columns except the join key column from df2
69
81
  spikes_df = spikes_df.drop(columns=common_columns)
70
82
  # merge nodes and spikes df
71
- spikes_df = spikes_df.merge(nodes[groupby], left_on='node_ids', right_index=True, how='left')
72
-
83
+ spikes_df = spikes_df.merge(
84
+ nodes[groupby], left_on="node_ids", right_index=True, how="left"
85
+ )
73
86
 
74
87
  # Get unique population names
75
88
  unique_pop_names = spikes_df[groupby].unique()
76
-
89
+
77
90
  # Generate colors if no color_map is provided
78
91
  if color_map is None:
79
- cmap = plt.get_cmap('tab10') # Default colormap
80
- color_map = {pop_name: cmap(i / len(unique_pop_names)) for i, pop_name in enumerate(unique_pop_names)}
92
+ cmap = plt.get_cmap("tab10") # Default colormap
93
+ color_map = {
94
+ pop_name: cmap(i / len(unique_pop_names)) for i, pop_name in enumerate(unique_pop_names)
95
+ }
81
96
  else:
82
97
  # Ensure color_map contains all population names
83
98
  missing_colors = [pop for pop in unique_pop_names if pop not in color_map]
84
99
  if missing_colors:
85
100
  raise ValueError(f"color_map is missing colors for populations: {missing_colors}")
86
-
101
+
87
102
  # Plot each population with its specified or generated color
88
103
  for pop_name, group in spikes_df.groupby(groupby):
89
- ax.scatter(group['timestamps'], group['node_ids'], label=pop_name, color=color_map[pop_name], s=0.5)
104
+ ax.scatter(
105
+ group["timestamps"], group["node_ids"], label=pop_name, color=color_map[pop_name], s=0.5
106
+ )
90
107
 
91
108
  # Label axes
92
109
  ax.set_xlabel("Time")
93
110
  ax.set_ylabel("Node ID")
94
- ax.legend(title="Population", loc='upper right', framealpha=0.9, markerfirst=False)
95
-
111
+ ax.legend(title="Population", loc="upper right", framealpha=0.9, markerfirst=False)
112
+
96
113
  return ax
97
-
114
+
115
+
98
116
  # uses df from bmtool.analysis.spikes compute_firing_rate_stats
99
- def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, List[str]], ax: Optional[Axes] = None,
100
- color_map: Optional[Dict[str, str]] = None) -> Axes:
117
+ def plot_firing_rate_pop_stats(
118
+ firing_stats: pd.DataFrame,
119
+ groupby: Union[str, List[str]],
120
+ ax: Optional[Axes] = None,
121
+ color_map: Optional[Dict[str, str]] = None,
122
+ ) -> Axes:
101
123
  """
102
124
  Plots a bar graph of mean firing rates with error bars (standard deviation).
103
125
 
@@ -129,7 +151,7 @@ def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, L
129
151
 
130
152
  # Generate colors if no color_map is provided
131
153
  if color_map is None:
132
- cmap = plt.get_cmap('viridis')
154
+ cmap = plt.get_cmap("viridis")
133
155
  color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
134
156
  else:
135
157
  # Ensure color_map contains all groups
@@ -157,13 +179,13 @@ def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, L
157
179
 
158
180
  # Add error bars manually with caps
159
181
  _, caps, _ = ax.errorbar(
160
- x=np.arange(len(x_labels)),
161
- y=means,
162
- yerr=std_devs,
163
- fmt='none',
164
- capsize=5,
165
- capthick=2,
166
- color="black"
182
+ x=np.arange(len(x_labels)),
183
+ y=means,
184
+ yerr=std_devs,
185
+ fmt="none",
186
+ capsize=5,
187
+ capthick=2,
188
+ color="black",
167
189
  )
168
190
 
169
191
  # Formatting
@@ -172,14 +194,20 @@ def plot_firing_rate_pop_stats(firing_stats: pd.DataFrame, groupby: Union[str, L
172
194
  ax.set_xlabel("Population Group")
173
195
  ax.set_ylabel("Mean Firing Rate (spikes/s)")
174
196
  ax.set_title("Firing Rate Statistics by Population")
175
- ax.grid(axis='y', linestyle='--', alpha=0.7)
197
+ ax.grid(axis="y", linestyle="--", alpha=0.7)
176
198
 
177
199
  return ax
178
200
 
201
+
179
202
  # uses df from bmtool.analysis.spikes compute_firing_rate_stats
180
- def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union[str, list], ax: Optional[Axes] = None,
181
- color_map: Optional[Dict[str, str]] = None,
182
- plot_type: Union[str, list] = "box", swarm_alpha: float = 0.6) -> Axes:
203
+ def plot_firing_rate_distribution(
204
+ individual_stats: pd.DataFrame,
205
+ groupby: Union[str, list],
206
+ ax: Optional[Axes] = None,
207
+ color_map: Optional[Dict[str, str]] = None,
208
+ plot_type: Union[str, list] = "box",
209
+ swarm_alpha: float = 0.6,
210
+ ) -> Axes:
183
211
  """
184
212
  Plots a distribution of individual firing rates using one or more plot types
185
213
  (box plot, violin plot, or swarm plot), overlaying them on top of each other.
@@ -214,7 +242,7 @@ def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union
214
242
  # Validate plot_type (it can be a list or a single type)
215
243
  if isinstance(plot_type, str):
216
244
  plot_type = [plot_type]
217
-
245
+
218
246
  for pt in plot_type:
219
247
  if pt not in ["box", "violin", "swarm"]:
220
248
  raise ValueError("plot_type must be one of: 'box', 'violin', 'swarm'.")
@@ -224,9 +252,9 @@ def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union
224
252
 
225
253
  # Generate colors if no color_map is provided
226
254
  if color_map is None:
227
- cmap = plt.get_cmap('viridis')
255
+ cmap = plt.get_cmap("viridis")
228
256
  color_map = {group: cmap(i / len(unique_groups)) for i, group in enumerate(unique_groups)}
229
-
257
+
230
258
  # Ensure color_map contains all groups
231
259
  missing_colors = [group for group in unique_groups if group not in color_map]
232
260
  if missing_colors:
@@ -242,18 +270,39 @@ def plot_firing_rate_distribution(individual_stats: pd.DataFrame, groupby: Union
242
270
  # Loop over each plot type and overlay them
243
271
  for pt in plot_type:
244
272
  if pt == "box":
245
- sns.boxplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, width=0.5)
273
+ sns.boxplot(
274
+ data=individual_stats,
275
+ x="group",
276
+ y="firing_rate",
277
+ ax=ax,
278
+ palette=color_map,
279
+ width=0.5,
280
+ )
246
281
  elif pt == "violin":
247
- sns.violinplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, inner="quartile", alpha=0.4)
282
+ sns.violinplot(
283
+ data=individual_stats,
284
+ x="group",
285
+ y="firing_rate",
286
+ ax=ax,
287
+ palette=color_map,
288
+ inner="quartile",
289
+ alpha=0.4,
290
+ )
248
291
  elif pt == "swarm":
249
- sns.swarmplot(data=individual_stats, x="group", y="firing_rate", ax=ax, palette=color_map, alpha=swarm_alpha)
292
+ sns.swarmplot(
293
+ data=individual_stats,
294
+ x="group",
295
+ y="firing_rate",
296
+ ax=ax,
297
+ palette=color_map,
298
+ alpha=swarm_alpha,
299
+ )
250
300
 
251
301
  # Formatting
252
302
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
253
303
  ax.set_xlabel("Population Group")
254
304
  ax.set_ylabel("Firing Rate (spikes/s)")
255
305
  ax.set_title("Firing Rate Distribution for individual cells")
256
- ax.grid(axis='y', linestyle='--', alpha=0.7)
306
+ ax.grid(axis="y", linestyle="--", alpha=0.7)
257
307
 
258
308
  return ax
259
-