bmtool 0.7.7__py3-none-any.whl → 0.7.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bmtool might be problematic. Click here for more details.

@@ -1,6 +1,7 @@
1
- from typing import List, Tuple, Union
1
+ from typing import Dict, List, Optional, Tuple, Union
2
2
 
3
3
  import matplotlib.pyplot as plt
4
+ from matplotlib.figure import Figure
4
5
  import numpy as np
5
6
  import pandas as pd
6
7
  import seaborn as sns
@@ -18,16 +19,16 @@ def plot_spike_power_correlation(
18
19
  lfp_data: xr.DataArray,
19
20
  firing_quantile: float,
20
21
  fs: float,
21
- pop_names: list,
22
+ pop_names: List[str],
22
23
  filter_method: str = "wavelet",
23
24
  bandwidth: float = 2.0,
24
- lowcut: float = None,
25
- highcut: float = None,
26
- freq_range: tuple = (10, 100),
25
+ lowcut: Optional[float] = None,
26
+ highcut: Optional[float] = None,
27
+ freq_range: Tuple[float, float] = (10, 100),
27
28
  freq_step: float = 5,
28
29
  type_name: str = "raw",
29
- time_windows: list = None,
30
- error_type: str = "ci", # New parameter: "ci" for confidence interval, "sem" for standard error, "std" for standard deviation
30
+ time_windows: Optional[List[Tuple[float, float]]] = None,
31
+ error_type: str = "ci",
31
32
  ):
32
33
  """
33
34
  Calculate and plot correlation between population spike rates and LFP power across frequencies.
@@ -35,13 +36,15 @@ def plot_spike_power_correlation(
35
36
 
36
37
  Parameters
37
38
  ----------
38
- spike_rate : xr.DataArray
39
- Population spike rates with dimensions (time, population[, type])
39
+ spike_df : pd.DataFrame
40
+ DataFrame containing spike data with columns 'timestamps', 'node_ids', and 'pop_name'.
40
41
  lfp_data : xr.DataArray
41
42
  LFP data
43
+ firing_quantile : float
44
+ Upper quantile threshold for selecting high-firing cells (e.g., 0.8 for top 20%)
42
45
  fs : float
43
46
  Sampling frequency
44
- pop_names : list
47
+ pop_names : List[str]
45
48
  List of population names to analyze
46
49
  filter_method : str, optional
47
50
  Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
@@ -51,16 +54,21 @@ def plot_spike_power_correlation(
51
54
  Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
52
55
  highcut : float, optional
53
56
  Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
54
- freq_range : tuple, optional
57
+ freq_range : Tuple[float, float], optional
55
58
  Min and max frequency to analyze (default: (10, 100))
56
59
  freq_step : float, optional
57
60
  Step size for frequency analysis (default: 5)
58
61
  type_name : str, optional
59
62
  Which type of spike rate to use if 'type' dimension exists (default: 'raw')
60
- time_windows : list, optional
63
+ time_windows : List[Tuple[float, float]], optional
61
64
  List of (start, end) time tuples for trial-based analysis. If None, analyze entire signal
62
65
  error_type : str, optional
63
66
  Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
67
+
68
+ Returns
69
+ -------
70
+ matplotlib.figure.Figure
71
+ The figure containing the correlation plot
64
72
  """
65
73
 
66
74
  if not (0 <= firing_quantile < 1):
@@ -206,7 +214,7 @@ def plot_spike_power_correlation(
206
214
 
207
215
  # Plotting
208
216
  sns.set_style("whitegrid")
209
- plt.figure(figsize=(12, 8))
217
+ fig = plt.figure(figsize=(12, 8))
210
218
 
211
219
  for i, pop in enumerate(pop_names):
212
220
  # Extract data for plotting
@@ -294,28 +302,31 @@ def plot_spike_power_correlation(
294
302
  plt.ylim(min(y_min, -0.1), max(y_max, 0.1))
295
303
 
296
304
  plt.tight_layout()
297
- plt.show()
305
+ return fig
298
306
 
299
307
 
300
- def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
308
+ def plot_cycle_with_spike_histograms(phase_data, pop_names: List[str], bins: int = 36):
301
309
  """
302
310
  Plot an idealized cycle with spike histograms for different neuron populations.
303
311
 
304
- Parameters:
312
+ Parameters
305
313
  -----------
306
314
  phase_data : dict
307
315
  Dictionary containing phase values for each spike and neuron population
308
- fs : float
309
- Sampling frequency of LFP in Hz
310
- bins : int
311
- Number of bins for the phase histogram (default 36 gives 10-degree bins)
312
- pop_name : list
316
+ pop_names : List[str]
313
317
  List of population names to be plotted
318
+ bins : int, optional
319
+ Number of bins for the phase histogram (default 36 gives 10-degree bins)
320
+
321
+ Returns
322
+ -------
323
+ matplotlib.figure.Figure
324
+ The figure containing the cycle and histograms
314
325
  """
315
326
  sns.set_style("whitegrid")
316
327
  # Create a figure with subplots
317
328
  fig = plt.figure(figsize=(12, 8))
318
- gs = GridSpec(len(pop_name) + 1, 1, height_ratios=[1.5] + [1] * len(pop_name))
329
+ gs = GridSpec(len(pop_names) + 1, 1, height_ratios=[1.5] + [1] * len(pop_names))
319
330
 
320
331
  # Top subplot: Idealized gamma cycle
321
332
  ax_gamma = fig.add_subplot(gs[0])
@@ -335,10 +346,10 @@ def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
335
346
  ax_gamma.axvline(x=0, color="k", linestyle="--", alpha=0.3)
336
347
 
337
348
  # Generate a color map for the different populations
338
- colors = plt.cm.tab10(np.linspace(0, 1, len(pop_name)))
349
+ colors = plt.cm.tab10(np.linspace(0, 1, len(pop_names)))
339
350
 
340
351
  # Add histograms for each neuron population
341
- for i, pop_name in enumerate(pop_name):
352
+ for i, pop_name in enumerate(pop_names):
342
353
  ax_hist = fig.add_subplot(gs[i + 1], sharex=ax_gamma)
343
354
 
344
355
  # Compute histogram
@@ -361,27 +372,34 @@ def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
361
372
  ax_hist.set_xlabel("Phase (degrees)", fontsize=12)
362
373
 
363
374
  plt.tight_layout()
364
- plt.show()
375
+ return fig
365
376
 
366
377
 
367
- def plot_entrainment_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
378
+ def plot_entrainment_by_population(ppc_dict: Dict[str, Dict[str, Dict[float, float]]], pop_names: List[str], freqs: List[float], figsize: Tuple[float, float] = (15, 8), title: Optional[str] = None):
368
379
  """
369
380
  Plot PPC for all node populations on one graph with mean and standard error.
370
381
 
371
382
  Parameters:
372
383
  -----------
373
- ppc_dict : dict
384
+ ppc_dict : Dict[str, Dict[str, Dict[float, float]]]
374
385
  Dictionary containing PPC data organized by population, node, and frequency
375
- pop_names : list
386
+ pop_names : List[str]
376
387
  List of population names to plot data for
377
- freqs : list
388
+ freqs : List[float]
378
389
  List of frequencies to plot
379
- figsize : tuple
390
+ figsize : Tuple[float, float], optional
380
391
  Figure size for the plot
392
+ title : str, optional
393
+ Title for the plot
394
+
395
+ Returns
396
+ -------
397
+ matplotlib.figure.Figure
398
+ The figure containing the bar plot
381
399
  """
382
400
  # Set up the visualization style
383
401
  sns.set_style("whitegrid")
384
- plt.figure(figsize=figsize)
402
+ fig = plt.figure(figsize=figsize)
385
403
 
386
404
  # Calculate the width of each group of bars
387
405
  n_groups = len(freqs)
@@ -443,23 +461,25 @@ def plot_entrainment_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8),
443
461
 
444
462
  # Adjust layout and save
445
463
  plt.tight_layout()
446
- plt.show()
464
+ return fig
447
465
 
448
466
 
449
- def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title=None):
467
+ def plot_entrainment_swarm_plot(ppc_dict: Dict[str, Dict[str, Dict[float, float]]], pop_names: List[str], freq: Union[float, int], save_path: Optional[str] = None, title: Optional[str] = None):
450
468
  """
451
469
  Plot a swarm plot of the entrainment for different populations at a single frequency.
452
470
 
453
471
  Parameters:
454
472
  -----------
455
- ppc_dict : dict
473
+ ppc_dict : Dict[str, Dict[str, Dict[float, float]]]
456
474
  Dictionary containing PPC values organized by population, node, and frequency
457
- pop_names : list
475
+ pop_names : List[str]
458
476
  List of population names to include in the plot
459
- freq : float or int
477
+ freq : Union[float, int]
460
478
  The specific frequency to plot
461
479
  save_path : str, optional
462
480
  Path to save the figure. If None, figure is just displayed.
481
+ title : str, optional
482
+ Title for the plot
463
483
 
464
484
  Returns:
465
485
  --------
@@ -500,7 +520,7 @@ def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title
500
520
  print(f"{pop}: {mean_val:.4f} ± {sem_val:.4f} (n={n})")
501
521
 
502
522
  # Create figure
503
- plt.figure(figsize=(max(8, len(pop_names) * 1.5), 8))
523
+ fig = plt.figure(figsize=(max(8, len(pop_names) * 1.5), 8))
504
524
 
505
525
  # Create swarm plot
506
526
  ax = sns.swarmplot(
@@ -608,20 +628,20 @@ def plot_entrainment_swarm_plot(ppc_dict, pop_names, freq, save_path=None, title
608
628
  if save_path:
609
629
  plt.savefig(f"{save_path}/ppc_change_swarm_plot_{freq}Hz.png", dpi=300, bbox_inches="tight")
610
630
 
611
- plt.show()
631
+ return fig
612
632
 
613
633
 
614
634
  def plot_trial_avg_entrainment(
615
635
  spike_df: pd.DataFrame,
616
- lfp: np.ndarray,
636
+ lfp: xr.DataArray,
617
637
  time_windows: List[Tuple[float, float]],
618
638
  entrainment_method: str,
619
639
  pop_names: List[str],
620
640
  freqs: Union[List[float], np.ndarray],
621
641
  firing_quantile: float,
622
642
  spike_fs: float = 1000,
623
- error_type: str = "ci", # New parameter: "ci" for confidence interval, "sem" for standard error, "std" for standard deviation
624
- ) -> None:
643
+ error_type: str = "ci",
644
+ ) -> Figure:
625
645
  """
626
646
  Plot trial-averaged entrainment for specified population names. Only supports wavelet filter current, could easily add other support
627
647
 
@@ -629,9 +649,7 @@ def plot_trial_avg_entrainment(
629
649
  -----------
630
650
  spike_df : pd.DataFrame
631
651
  Spike data containing timestamps, node_ids, and pop_name columns
632
- spike_fs : float
633
- fs for spike data. Default is 1000
634
- lfp : xarray
652
+ lfp : xr.DataArray
635
653
  Xarray for a channel of the lfp data
636
654
  time_windows : List[Tuple[float, float]]
637
655
  List of windows to analysis with start and stp time [(start_time, end_time), ...] for each trial
@@ -643,7 +661,9 @@ def plot_trial_avg_entrainment(
643
661
  Array of frequencies to analyze (Hz)
644
662
  firing_quantile : float
645
663
  Upper quantile threshold for selecting high-firing cells (e.g., 0.8 for top 20%)
646
- error_type : str
664
+ spike_fs : float, optional
665
+ fs for spike data. Default is 1000
666
+ error_type : str, optional
647
667
  Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
648
668
 
649
669
  Raises:
@@ -655,8 +675,8 @@ def plot_trial_avg_entrainment(
655
675
 
656
676
  Returns:
657
677
  --------
658
- None
659
- Displays plot and prints summary statistics
678
+ matplotlib.figure.Figure
679
+ The figure containing the plot
660
680
  """
661
681
  sns.set_style("whitegrid")
662
682
  # Validate inputs
@@ -813,7 +833,7 @@ def plot_trial_avg_entrainment(
813
833
  error_plv[pop_name] = np.nanstd(all_plv_data[pop_name], axis=0, ddof=1)
814
834
 
815
835
  # Create the combined plot
816
- plt.figure(figsize=(12, 8))
836
+ fig = plt.figure(figsize=(12, 8))
817
837
 
818
838
  # Define markers and colors for different populations
819
839
  markers = ["o-", "s-", "^-", "D-", "v-", "<-", ">-", "p-"]
@@ -862,4 +882,104 @@ def plot_trial_avg_entrainment(
862
882
  plt.legend(fontsize=10)
863
883
  plt.grid(True, alpha=0.3)
864
884
  plt.tight_layout()
865
- plt.show()
885
+ return fig
886
+
887
+
888
def plot_fr_hist_phase_amplitude(
    fr_hist: np.ndarray,
    pop_names: List[str],
    freq_labels: List[str],
    nbins_pha: int = 16,
    nbins_amp: int = 16,
    common_clim: bool = True,
    figsize: Tuple[float, float] = (3, 2),
    cmap: str = 'viridis',
    title: Optional[str] = None
) -> Tuple[plt.Figure, np.ndarray]:
    """
    Plot firing rate histograms binned by LFP phase and amplitude.

    Draws an ``n_pop x n_freq`` grid of ``pcolormesh`` panels, one per
    (population, frequency) pair. See ``compute_fr_hist_phase_amplitude`` in
    ``bmtool/analysis/entrainment.py`` for how ``fr_hist`` is produced.

    Parameters
    ----------
    fr_hist : np.ndarray
        Firing rate histogram of shape (n_pop, n_freq, nbins_pha, nbins_amp).
        Index ``[i, j]`` is plotted for ``pop_names[i]`` and ``freq_labels[j]``.
    pop_names : List[str]
        List of population names (used as subplot titles).
    freq_labels : List[str]
        Frequency labels used for the bottom-row x-axis labels
        (e.g., ['Beta', 'Gamma']).
    nbins_pha : int, default=16
        Number of phase bins spanning [-pi, pi].
    nbins_amp : int, default=16
        Number of amplitude bins spanning quantiles [0, 1].
    common_clim : bool, default=True
        If True, every panel shares color limits taken from the NaN-ignoring
        min/max of ``fr_hist`` and one colorbar is drawn per row. If False,
        each panel is autoscaled and gets its own colorbar.
    figsize : Tuple[float, float], default=(3, 2)
        Size (width, height) of each subplot; the figure size is this
        multiplied by the grid dimensions.
    cmap : str, default='viridis'
        Colormap to use.
    title : Optional[str], default=None
        Overall title for the figure.

    Returns
    -------
    Tuple[plt.Figure, np.ndarray]
        The figure and a 2-D array of axes with shape (n_pop, n_freq).

    Raises
    ------
    ValueError
        If ``pop_names`` or ``freq_labels`` is empty, if ``fr_hist`` is not
        4-D, or if its first two dimensions are too small for the requested
        populations/frequencies.

    Examples
    --------
    >>> fig, axs = plot_fr_hist_phase_amplitude(
    ...     fr_hist, ['PV', 'SST'], ['Beta', 'Gamma'],
    ...     common_clim=True, cmap='RdBu_r', title='LFP Phase-Amplitude Coupling'
    ... )
    """
    n_pop = len(pop_names)
    n_freq = len(freq_labels)
    if n_pop == 0 or n_freq == 0:
        raise ValueError("pop_names and freq_labels must both be non-empty.")

    fr_hist = np.asarray(fr_hist)
    if fr_hist.ndim != 4:
        raise ValueError(
            "fr_hist must be 4-D (n_pop, n_freq, nbins_pha, nbins_amp); "
            f"got shape {fr_hist.shape}."
        )
    if fr_hist.shape[0] < n_pop or fr_hist.shape[1] < n_freq:
        raise ValueError(
            f"fr_hist has shape {fr_hist.shape}, which is too small for "
            f"{n_pop} populations and {n_freq} frequencies."
        )

    # Bin edges: phase in radians, amplitude as quantiles.
    pha_bins = np.linspace(-np.pi, np.pi, nbins_pha + 1)
    quantiles = np.linspace(0, 1, nbins_amp + 1)

    # squeeze=False keeps axs 2-D even for a single row or column.
    fig, axs = plt.subplots(n_pop, n_freq,
                            figsize=(figsize[0] * n_freq, figsize[1] * n_pop),
                            squeeze=False)

    if title:
        fig.suptitle(title, fontsize=14, y=0.98)

    # Shared limits do not depend on the panel, so compute them once.
    # NaN-aware reductions keep empty bins from blanking the color scale.
    if common_clim:
        vmin, vmax = np.nanmin(fr_hist), np.nanmax(fr_hist)
    else:
        vmin, vmax = None, None

    cbar_label = 'Firing rate (% Change)'
    for i, pop in enumerate(pop_names):
        for j, freq_label in enumerate(freq_labels):
            ax = axs[i, j]
            # Transpose so phase runs along x and amplitude along y.
            pcm = ax.pcolormesh(pha_bins, quantiles, fr_hist[i, j].T,
                                vmin=vmin, vmax=vmax, cmap=cmap)
            ax.set_title(pop)

            # Only the bottom row shows phase tick labels.
            if i < n_pop - 1:
                ax.get_xaxis().set_visible(False)
            else:
                ax.set_xlabel(freq_label.title() + ' Phase')
                ax.set_xticks((-np.pi, 0, np.pi))
                ax.set_xticklabels([r'$-\pi$', '0', r'$\pi$'])

            # Only the left column shows the amplitude axis.
            if j > 0:
                ax.get_yaxis().set_visible(False)
            else:
                ax.set_ylabel('Amplitude (quantile)')

            if not common_clim:
                # Independent scales: one colorbar per panel, labelled on the last column.
                fig.colorbar(mappable=pcm, ax=ax,
                             label=cbar_label if j == n_freq - 1 else None,
                             pad=0.02)

        if common_clim:
            # All panels share limits, so any mappable from the row represents it.
            fig.colorbar(mappable=pcm, ax=axs[i], label=cbar_label, pad=0.02)

    return fig, axs
bmtool/bmplot/lfp.py CHANGED
@@ -1,18 +1,54 @@
1
1
  import matplotlib.pyplot as plt
2
2
  import numpy as np
3
3
  from fooof.sim.gen import gen_aperiodic
4
+ from typing import Optional, List, Dict, Tuple, Any
5
+ import pandas as pd
6
+ from ..analysis.spikes import get_population_spike_rate
7
+ from ..analysis.lfp import get_lfp_power, load_ecp_to_xarray, ecp_to_lfp
8
+ from matplotlib.figure import Figure
4
9
 
5
10
 
6
11
  def plot_spectrogram(
7
- sxx_xarray,
8
- remove_aperiodic=None,
9
- log_power=False,
10
- plt_range=None,
11
- clr_freq_range=None,
12
- pad=0.03,
13
- ax=None,
14
- ):
15
- """Plot spectrogram. Determine color limits using value in frequency band clr_freq_range"""
12
+ sxx_xarray: Any,
13
+ remove_aperiodic: Optional[Any] = None,
14
+ log_power: bool = False,
15
+ plt_range: Optional[Tuple[float, float]] = None,
16
+ clr_freq_range: Optional[Tuple[float, float]] = None,
17
+ pad: float = 0.03,
18
+ ax: Optional[plt.Axes] = None,
19
+ ) -> Figure:
20
+ """
21
+ Plot a power spectrogram with optional aperiodic removal and frequency-based coloring.
22
+
23
+ Parameters
24
+ ----------
25
+ sxx_xarray : array-like
26
+ Spectrogram data as an xarray DataArray with PSD values.
27
+ remove_aperiodic : optional
28
+ FOOOF model object for aperiodic subtraction. If None, raw spectrum is displayed.
29
+ log_power : bool or str, optional
30
+ If True or 'dB', convert power to log scale. Default is False.
31
+ plt_range : tuple of float, optional
32
+ Frequency range to display as (f_min, f_max). If None, displays full range.
33
+ clr_freq_range : tuple of float, optional
34
+ Frequency range to use for determining color limits. If None, uses full range.
35
+ pad : float, optional
36
+ Padding for colorbar. Default is 0.03.
37
+ ax : matplotlib.axes.Axes, optional
38
+ Axes to plot on. If None, creates a new figure and axes.
39
+
40
+ Returns
41
+ -------
42
+ matplotlib.figure.Figure
43
+ The figure object containing the spectrogram.
44
+
45
+ Examples
46
+ --------
47
+ >>> fig = plot_spectrogram(
48
+ ... sxx_xarray, log_power='dB',
49
+ ... plt_range=(10, 100), clr_freq_range=(20, 50)
50
+ ... )
51
+ """
16
52
  sxx = sxx_xarray.PSD.values.copy()
17
53
  t = sxx_xarray.time.values.copy()
18
54
  f = sxx_xarray.frequency.values.copy()
@@ -45,7 +81,7 @@ def plot_spectrogram(
45
81
  vmin, vmax = sxx[c_idx, :].min(), sxx[c_idx, :].max()
46
82
 
47
83
  f = f[f_idx]
48
- pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading="gouraud", vmin=vmin, vmax=vmax)
84
+ pcm = ax.pcolormesh(t, f, sxx[f_idx, :], shading="gouraud", vmin=vmin, vmax=vmax, rasterized=True)
49
85
  if "cone_of_influence_frequency" in sxx_xarray:
50
86
  coif = sxx_xarray.cone_of_influence_frequency
51
87
  ax.plot(t, coif)
@@ -56,4 +92,103 @@ def plot_spectrogram(
56
92
  plt.colorbar(mappable=pcm, ax=ax, label=cbar_label, pad=pad)
57
93
  ax.set_xlabel("Time (sec)")
58
94
  ax.set_ylabel("Frequency (Hz)")
59
- return sxx
95
+ return ax.figure
96
+
97
+
98
def plot_population_spike_rates_with_lfp(
    spikes_df: pd.DataFrame,
    freq_of_interest: List[float],
    freq_labels: List[str],
    freq_colors: List[str],
    time_range: Tuple[float, float],
    pop_names: List[str],
    pop_color: Dict[str, str],
    trial_path: str,
    filter_column: Optional[str] = None,
    filter_value: Optional[Any] = None,
    spike_fs: float = 400,
    network_name: str = 'cortex',
) -> Optional[Figure]:
    """
    Plot population spike rates with LFP power overlays.

    Each requested population that is present in the computed spike rates
    gets one panel. The panel shows the 'raw' spike rate on the left axis and
    the wavelet LFP power (bandwidth 1.0) at each frequency of interest on a
    twin right axis. The LFP is read from ``<trial_path>/ecp.h5``.

    Parameters
    ----------
    spikes_df : pd.DataFrame
        DataFrame with spike data.
    freq_of_interest : list of float
        Frequencies (Hz) at which LFP power is computed.
    freq_labels : list of str
        Legend labels for the frequencies (paired with freq_of_interest by position).
    freq_colors : list of str
        Line colors for the frequencies (paired with freq_of_interest by position).
    time_range : tuple of float
        (start, end) x-axis time limits.
    pop_names : list of str
        Population names to plot. Names not present in the spike rates are skipped.
    pop_color : dict
        Mapping from population name to color.
    trial_path : str
        Path to the trial output directory containing ``ecp.h5``.
    filter_column : str, optional
        Column of spikes_df to filter on. If it is not a column of spikes_df,
        a warning is issued and overall spike rates are plotted.
    filter_value : any, optional
        Value to keep in filter_column.
    spike_fs : float, optional
        Sampling frequency passed to get_population_spike_rate (default 400).
    network_name : str, optional
        Network name passed to get_population_spike_rate (default 'cortex').

    Returns
    -------
    matplotlib.figure.Figure or None
        Figure containing the plot, or None if the filter matched no rows or
        none of pop_names is present in the spike rates.

    Examples
    --------
    >>> fig = plot_population_spike_rates_with_lfp(
    ...     spikes_df, [20, 40], ['Beta', 'Gamma'],
    ...     ['blue', 'red'], (0, 10), ['PV', 'SST'],
    ...     {'PV': 'blue', 'SST': 'red'}, 'path/to/trial_output'
    ... )
    """
    import os
    import warnings

    # Choose which spikes feed the rate computation.
    if filter_column and filter_column in spikes_df.columns:
        filtered_df = spikes_df[spikes_df[filter_column] == filter_value]
        if filtered_df.empty:
            print(f"No data found for {filter_column} == {filter_value}.")
            return None
        source_df = filtered_df
        plot_title = f'{filter_column} {filter_value}'
    else:
        if filter_column:
            # Previously a silent fallback; make it visible to the caller.
            warnings.warn(
                f"Column '{filter_column}' not found in spikes_df; "
                "plotting overall spike rates instead.",
                stacklevel=2,
            )
        source_df = spikes_df
        plot_title = 'Overall Spike Rates'

    spike_rate = get_population_spike_rate(source_df, fs=spike_fs, network_name=network_name)

    # One panel per requested population that actually has a rate trace, so
    # panel indices always line up with the populations drawn on them.
    pops_to_plot = [pop for pop in pop_names if pop in spike_rate.population.values]
    if not pops_to_plot:
        print("None of the requested populations were found in the spike rates.")
        return None

    # Load LFP data and compute power for each frequency of interest.
    ecp = load_ecp_to_xarray(ecp_file=os.path.join(trial_path, "ecp.h5"))
    lfp = ecp_to_lfp(ecp)
    powers = [
        get_lfp_power(lfp, freq_of_interest=freq, fs=lfp.fs, filter_method="wavelet", bandwidth=1.0)
        for freq in freq_of_interest
    ]

    # squeeze=False keeps `axes` a 2-D array even when only one panel is drawn.
    fig, axes = plt.subplots(len(pops_to_plot), 1, figsize=(12, 10), squeeze=False)
    for ax, pop in zip(axes[:, 0], pops_to_plot):
        spike_rate.sel(type='raw', population=pop).plot(ax=ax, color=pop_color[pop])
        ax.set_title(f'{pop}')
        ax.set_ylabel('Spike Rate (Hz)', color=pop_color[pop])
        ax.tick_params(axis='y', labelcolor=pop_color[pop])

        # Twin axis for LFP power; it shares x with the spike-rate axis.
        ax2 = ax.twinx()
        for power, label, color in zip(powers, freq_labels, freq_colors):
            ax2.plot(power['time'], power.values.squeeze(), color=color, label=label)
        ax2.set_ylabel('LFP Power', color='black')
        ax2.tick_params(axis='y', labelcolor='black')
        ax2.legend(loc='upper right')

        ax.set_xlim(time_range)

    fig.suptitle(plot_title, fontsize=16, y=0.98)
    # Leave headroom at the top for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    return fig
@@ -0,0 +1 @@
1
+ # NOT IMPLEMENTED YET