bmtool 0.7.1.2__py3-none-any.whl → 0.7.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -712,3 +712,55 @@ def calculate_spike_rate_power_correlation(
712
712
  correlation_results[pop][freq] = {"correlation": corr, "p_value": p_val}
713
713
 
714
714
  return correlation_results, frequencies
715
+
716
+
717
def get_spikes_in_cycle(spike_df, lfp_data, spike_fs=1000, lfp_fs=400, band=(30, 80)):
    """
    Analyze spike timing relative to oscillation phases.

    The LFP is band-pass filtered to ``band`` (via ``butter_bandpass_filter``)
    and its instantaneous phase and amplitude are taken from the Hilbert
    transform. Each spike time is mapped to the nearest LFP sample, and the
    phase at that sample is collected per population.

    Parameters:
    -----------
    spike_df : pd.DataFrame
        Spike table with at least ``"pop_name"`` and ``"timestamps"`` columns.
    lfp_data : np.array
        Raw LFP signal.
    spike_fs : float
        Divisor that converts ``timestamps`` to seconds
        (``timestamps / spike_fs``); e.g. 1000 for timestamps in milliseconds.
    lfp_fs : float
        Sampling frequency of the LFP in Hz.
    band : tuple
        (low, high) bounds of the frequency band in Hz.

    Returns:
    --------
    phase_data : dict
        Maps population name to an array of LFP phases (radians, as returned
        by ``np.angle``) at that population's spike times. Populations with no
        spike inside the LFP time range are omitted from the dict.
    filtered_lfp : np.array
        Band-pass filtered LFP.
    phase : np.array
        Instantaneous phase of ``filtered_lfp`` in radians.
    amplitude : np.array
        Instantaneous amplitude (magnitude of the analytic signal).
    """
    filtered_lfp = butter_bandpass_filter(lfp_data, band[0], band[1], lfp_fs)

    # Instantaneous phase and amplitude from the analytic (Hilbert) signal
    analytic_signal = signal.hilbert(filtered_lfp)
    phase = np.angle(analytic_signal)
    amplitude = np.abs(analytic_signal)

    phase_data = {}
    for pop in spike_df["pop_name"].unique():
        pop_spikes = spike_df.loc[spike_df["pop_name"] == pop, "timestamps"].to_numpy()

        # timestamps -> seconds -> nearest sample index at the LFP rate
        spike_indices = np.round(pop_spikes / spike_fs * lfp_fs).astype(int)

        # Keep only spikes that fall inside the LFP recording
        in_range = (spike_indices >= 0) & (spike_indices < len(phase))
        if np.any(in_range):
            phase_data[pop] = phase[spike_indices[in_range]]

    return phase_data, filtered_lfp, phase, amplitude
bmtool/analysis/spikes.py CHANGED
@@ -398,3 +398,58 @@ def compare_firing_over_times(
398
398
  print(f" p-value: {p_val}")
399
399
  print(f" Significant difference (p<0.05): {'Yes' if p_val < 0.05 else 'No'}")
400
400
  return
401
+
402
+
403
def find_bursting_cells(
    df: pd.DataFrame, burst_threshold: float = 10, rename_bursting_cells: bool = False
) -> pd.DataFrame:
    """
    Finds bursting cells in a population based on a time difference threshold.

    A cell (``node_ids`` value) is flagged as bursting when any two of its
    temporally consecutive spikes are less than ``burst_threshold`` apart.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing spike data with columns for timestamps, node_ids, and pop_name
    burst_threshold : float, optional
        Inter-spike interval threshold, in the same units as ``timestamps``
        (documented as milliseconds), below which two spikes count as a burst.
    rename_bursting_cells : bool, optional
        If True, append "_bursters" to the pop_name of every bursting cell
        whose pop_name does not already contain "_bursters".

    Returns
    -------
    pd.DataFrame
        ``df`` inner-joined on node_ids with a per-cell boolean ``is_burst``
        column (and renamed pop_name values when ``rename_bursting_cells``).
    """
    # Sort so intervals are computed between temporally adjacent spikes;
    # on unsorted input .diff() can be negative and be mistaken for a burst.
    ordered = df.sort_values(["node_ids", "timestamps"])
    isi = ordered.groupby("node_ids")["timestamps"].diff()

    # A cell bursts if any interval is under threshold; each cell's first
    # spike has a NaN interval, which compares False.
    burst_summary = (
        ordered.assign(is_burst=isi < burst_threshold)
        .groupby("node_ids")["is_burst"]
        .any()
    )
    burst_cells = burst_summary.reset_index()

    # merge with original df to get timestamps
    burst_cells = pd.merge(burst_cells, df, on="node_ids")

    if rename_bursting_cells:
        # Use the boolean Series itself: `series is True` is an identity
        # test that is always False and would select no rows.
        already_tagged = burst_cells["pop_name"].str.contains(
            "_bursters", regex=False, na=False
        )
        burst_mask = burst_cells["is_burst"] & ~already_tagged
        burst_cells.loc[burst_mask, "pop_name"] = (
            burst_cells.loc[burst_mask, "pop_name"] + "_bursters"
        )

    # Report only cells actually flagged as bursting in each population
    for pop in burst_cells["pop_name"].unique():
        in_pop = burst_cells["pop_name"] == pop
        n_bursters = burst_cells.loc[in_pop & burst_cells["is_burst"], "node_ids"].nunique()
        print(f"Number of bursters in {pop}: {n_bursters}")

    return burst_cells
@@ -1,5 +1,8 @@
1
1
  import matplotlib.pyplot as plt
2
+ import numpy as np
2
3
  import seaborn as sns
4
+ from matplotlib.gridspec import GridSpec
5
+ from scipy import stats
3
6
 
4
7
 
5
8
  def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
@@ -48,3 +51,152 @@ def plot_spike_power_correlation(correlation_results, frequencies, pop_names):
48
51
  plt.tight_layout()
49
52
 
50
53
  plt.show()
54
+
55
+
56
def plot_cycle_with_spike_histograms(phase_data, bins=36, pop_name=None):
    """
    Plot an idealized cycle above spike-phase histograms for neuron populations.

    Parameters:
    -----------
    phase_data : dict
        Maps population name to an array of spike phases in radians; the
        histogram range is [-pi, pi].
    bins : int
        Number of bins for the phase histogram (default 36 gives 10-degree bins)
    pop_name : list, optional
        Population names to plot, in order. Defaults to all keys of
        ``phase_data``. Names missing from ``phase_data`` get an empty
        histogram instead of raising.
    """
    pop_names = list(phase_data.keys()) if pop_name is None else list(pop_name)
    n_pops = len(pop_names)

    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(12, 8))
    gs = GridSpec(n_pops + 1, 1, height_ratios=[1.5] + [1] * n_pops)

    # Top subplot: one idealized sine cycle used as the phase reference
    ax_gamma = fig.add_subplot(gs[0])
    x = np.linspace(-np.pi, np.pi, 1000)
    ax_gamma.plot(x, np.sin(x), "b-", linewidth=2)
    ax_gamma.set_title("Cycle with Neuron Population Spike Distributions", fontsize=14)
    ax_gamma.set_ylabel("Amplitude", fontsize=12)
    ax_gamma.set_xlim(-np.pi, np.pi)
    ax_gamma.set_xticks(np.linspace(-np.pi, np.pi, 9))
    ax_gamma.set_xticklabels(["-180°", "-135°", "-90°", "-45°", "0°", "45°", "90°", "135°", "180°"])
    ax_gamma.grid(True)
    ax_gamma.axhline(y=0, color="k", linestyle="-", alpha=0.3)
    ax_gamma.axvline(x=0, color="k", linestyle="--", alpha=0.3)

    colors = plt.cm.tab10(np.linspace(0, 1, n_pops))
    bar_width = 2 * np.pi / bins

    last_ax = ax_gamma
    for i, pop in enumerate(pop_names):
        ax_hist = fig.add_subplot(gs[i + 1], sharex=ax_gamma)
        last_ax = ax_hist

        phases = phase_data.get(pop, np.array([]))
        hist, bin_edges = np.histogram(phases, bins=bins, range=(-np.pi, np.pi))
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

        # Express counts as a percentage of this population's spikes
        total = np.sum(hist)
        if total > 0:
            hist = hist / total * 100

        ax_hist.bar(bin_centers, hist, width=bar_width, alpha=0.7, color=colors[i])
        ax_hist.set_ylabel(f"{pop}\nSpikes (%)", fontsize=10)
        ax_hist.grid(True, alpha=0.3)
        # Headroom above the tallest bar; avoid identical (0, 0) limits when
        # the population has no spikes.
        ax_hist.set_ylim(0, max(hist) * 1.2 if total > 0 else 1)

    # x-label only on the bottom-most subplot
    last_ax.set_xlabel("Phase (degrees)", fontsize=12)

    plt.tight_layout()
    plt.show()
121
+
122
+
123
def plot_ppc_by_population(ppc_dict, pop_names, freqs, figsize=(15, 8), title=None):
    """
    Plot PPC for all node populations on one graph with mean and standard error.

    For each population and frequency, the PPC values of every node that has
    an entry for that frequency are averaged. Bars show the mean, and error
    bars show the standard error of the mean (``scipy.stats.sem``). Bars are
    grouped by frequency with one bar per population.

    Parameters:
    -----------
    ppc_dict : dict
        Nested mapping ``ppc_dict[pop][node][freq] -> PPC value``.
    pop_names : list
        Population names to plot, in legend order. Populations absent from
        ``ppc_dict`` or with no values for any of ``freqs`` are skipped.
    freqs : list
        Frequencies to plot; used as ``ppc_dict`` keys and as tick labels.
    figsize : tuple
        Figure size for the plot
    title : str, optional
        Plot title; none is drawn when omitted or empty.
    """
    sns.set_style("whitegrid")
    plt.figure(figsize=figsize)

    # Each frequency group spans group_width on the x axis, split across populations
    n_groups = len(freqs)
    n_populations = len(pop_names)
    group_width = 0.8
    bar_width = group_width / max(n_populations, 1)

    pop_colors = sns.color_palette(n_colors=n_populations)

    x_centers = np.arange(n_groups)
    tick_labels = [str(freq) for freq in freqs]

    for i, pop in enumerate(pop_names):
        nodes = ppc_dict.get(pop, {})
        means = []
        errors = []
        valid_freqs_idx = []

        for freq_idx, freq in enumerate(freqs):
            # Collect values across all nodes that have this frequency
            freq_values = []
            for node in nodes:
                try:
                    freq_values.append(nodes[node][freq])
                except KeyError:
                    continue

            if freq_values:
                means.append(np.mean(freq_values))
                errors.append(stats.sem(freq_values))
                valid_freqs_idx.append(freq_idx)

        if not means:
            continue

        # Offset this population's bars within each frequency group
        x_positions = x_centers[valid_freqs_idx] + (i - n_populations / 2 + 0.5) * bar_width

        plt.bar(
            x_positions, means, width=bar_width * 0.9, color=pop_colors[i], alpha=0.7, label=pop
        )
        plt.errorbar(x_positions, means, yerr=errors, fmt="none", ecolor="black", capsize=4)

    plt.xlabel("Frequency")
    plt.ylabel("PPC Value")
    if title:
        plt.title(title)
    plt.xticks(x_centers, tick_labels)
    plt.legend(title="Population")

    # Adjust layout and display
    plt.tight_layout()
    plt.show()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bmtool
3
- Version: 0.7.1.2
3
+ Version: 0.7.1.3
4
4
  Summary: BMTool
5
5
  Home-page: https://github.com/cyneuro/bmtool
6
6
  Download-URL:
@@ -38,7 +38,6 @@ Requires-Dist: PyWavelets
38
38
  Requires-Dist: numba
39
39
  Provides-Extra: dev
40
40
  Requires-Dist: ruff>=0.1.0; extra == "dev"
41
- Requires-Dist: pyright>=1.1.0; extra == "dev"
42
41
  Requires-Dist: pytest>=7.0.0; extra == "dev"
43
42
  Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
44
43
  Dynamic: author
@@ -8,13 +8,13 @@ bmtool/plot_commands.py,sha256=Dxm_RaT4CtHnfsltTtUopJ4KVbfhxtktEB_b7bFEXII,12716
8
8
  bmtool/singlecell.py,sha256=I2yolbAnNC8qpnRkNdnDCLidNW7CktmBuRrcowMZJ3A,45041
9
9
  bmtool/synapses.py,sha256=wlRY7IixefPzafqG6k2sPIK4s6PLG9Kct-oCaVR29wA,64269
10
10
  bmtool/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- bmtool/analysis/entrainment.py,sha256=sjfSxPs1Y0dnEtX9a3IIMEeQ09L6WbhO3KMt-O8SN64,26480
11
+ bmtool/analysis/entrainment.py,sha256=uG2TWbeYJEg_VQB6pKEWlrVBzQ6M4h6FSAZR4GMKp-E,28178
12
12
  bmtool/analysis/lfp.py,sha256=S2JvxkjcK3-EH93wCrhqNSFY6cX7fOq74pz64ibHKrc,26556
13
13
  bmtool/analysis/netcon_reports.py,sha256=VnPZNKPaQA7oh1q9cIatsqQudm4cOtzNtbGPXoiDCD0,2909
14
- bmtool/analysis/spikes.py,sha256=ScP4EeX2QuEd_FXyj3W0WWgZKvZwwneuWuKFe3xwaCY,15115
14
+ bmtool/analysis/spikes.py,sha256=Mz0e7XOey_6eWQDZAU0ePjDzDMDTFkMbWSA5YWDooYk,17122
15
15
  bmtool/bmplot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  bmtool/bmplot/connections.py,sha256=P1JBG4xCbLVq4sfQuUE6c3dO949qajrjdQcrazdmDS4,53861
17
- bmtool/bmplot/entrainment.py,sha256=TC2qJV0Z6YiK5X7bEpiEyf7nZiWU__466Uhq0kbrhIY,1611
17
+ bmtool/bmplot/entrainment.py,sha256=VSlZvcSeXLr5OxGvmWcGU4s7JS7vOL38lq1XC69O_AE,6926
18
18
  bmtool/bmplot/lfp.py,sha256=SNpbWGOUnYEgnkeBw5S--aPN5mIGD22Gw2Pwus0_lvY,2034
19
19
  bmtool/bmplot/netcon_reports.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
20
  bmtool/bmplot/spikes.py,sha256=Lg8V3ynYCqk-QJvq-BOInjZMHYHrxHgXjtDOX67df-A,11148
@@ -26,9 +26,9 @@ bmtool/util/commands.py,sha256=Nn-R-4e9g8ZhSPZvTkr38xeKRPfEMANB9Lugppj82UI,68564
26
26
  bmtool/util/util.py,sha256=owce5BEusZO_8T5x05N2_B583G26vWAy7QX29V0Pj0Y,62818
27
27
  bmtool/util/neuron/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  bmtool/util/neuron/celltuner.py,sha256=lokRLUM1rsdSYBYrNbLBBo39j14mm8TBNVNRnSlhHCk,94868
29
- bmtool-0.7.1.2.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
- bmtool-0.7.1.2.dist-info/METADATA,sha256=LO1VUW641H9cxsHp1vi809dqrJoZGc4GfJaKqTbOZGc,3623
31
- bmtool-0.7.1.2.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
32
- bmtool-0.7.1.2.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
- bmtool-0.7.1.2.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
- bmtool-0.7.1.2.dist-info/RECORD,,
29
+ bmtool-0.7.1.3.dist-info/licenses/LICENSE,sha256=qrXg2jj6kz5d0EnN11hllcQt2fcWVNumx0xNbV05nyM,1068
30
+ bmtool-0.7.1.3.dist-info/METADATA,sha256=GpvzWhlfyNgHHCfhakOBeeablgTAOCHeQFvhnExZ-X8,3577
31
+ bmtool-0.7.1.3.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
32
+ bmtool-0.7.1.3.dist-info/entry_points.txt,sha256=0-BHZ6nUnh0twWw9SXNTiRmKjDnb1VO2DfG_-oprhAc,45
33
+ bmtool-0.7.1.3.dist-info/top_level.txt,sha256=gpd2Sj-L9tWbuJEd5E8C8S8XkNm5yUE76klUYcM-eWM,7
34
+ bmtool-0.7.1.3.dist-info/RECORD,,