bmtool 0.7.7__py3-none-any.whl → 0.7.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bmtool might be problematic. Click here for more details.

bmtool/SLURM.py CHANGED
@@ -409,6 +409,16 @@ class BlockRunner:
409
409
  self.param_name = param_name
410
410
  self.json_file_path = json_file_path
411
411
  self.syn_dict = syn_dict
412
+ # Store original component paths to restore later
413
+ self.original_component_paths = [block.component_path for block in self.blocks]
414
+
415
def restore_component_paths(self):
    """
    Reset each block's ``component_path`` from ``self.original_component_paths``.

    The saved list is read by position: entry ``k`` is written back onto
    ``self.blocks[k]``. A confirmation line is printed once every block has
    been updated.
    """
    for position, blk in enumerate(self.blocks):
        # Positional lookup pairs each block with the path stored at the same index.
        saved_path = self.original_component_paths[position]
        blk.component_path = saved_path
    print("Component paths restored to original values.", flush=True)
412
422
 
413
423
  def submit_blocks_sequentially(self):
414
424
  """
@@ -473,6 +483,8 @@ class BlockRunner:
473
483
 
474
484
  print(f"Block {block.block_name} completed.", flush=True)
475
485
  print("All blocks are done!", flush=True)
486
+ # Restore component paths to their original values
487
+ self.restore_component_paths()
476
488
  if self.webhook:
477
489
  message = "SIMULATION UPDATE: Simulation are Done!"
478
490
  send_teams_message(self.webhook, message)
@@ -531,6 +543,9 @@ class BlockRunner:
531
543
  print(f"Waiting for the last block {i} to complete...")
532
544
  time.sleep(self.check_interval)
533
545
 
546
+ print("All blocks are done!", flush=True)
547
+ # Restore component paths to their original values
548
+ self.restore_component_paths()
534
549
  if self.webhook:
535
550
  message = "SIMULATION UPDATE: Simulations are Done!"
536
551
  send_teams_message(self.webhook, message)
@@ -697,3 +697,116 @@ def get_spikes_in_cycle(
697
697
  phase_data[pop] = phase[valid_samples]
698
698
 
699
699
  return phase_data
700
+
701
+
702
def compute_fr_hist_phase_amplitude(
    spike_df: pd.DataFrame,
    lfp_data: xr.DataArray,
    pop_names: List[str],
    freqs: List[float],
    spike_fs: float = 1000,
    lfp_fs: float = 1000,
    nbins_pha: int = 16,
    nbins_amp: int = 16,
    pop_num: Optional[Dict[str, int]] = None,
    duration: Optional[float] = None
) -> np.ndarray:
    """
    Compute firing rate histograms binned by LFP phase and amplitude quantiles.

    For every frequency the LFP is passed through ``wavelet_filter``; the angle
    and magnitude of the result are used as instantaneous phase and amplitude.
    Each population's spikes are then counted in a 2D grid of uniform phase
    bins and amplitude-quantile bins, normalised by ``pop_num[pop] * duration``
    and expressed as percentage change from the mean over all bins.

    Parameters
    ----------
    spike_df : pd.DataFrame
        Spike data with 'timestamps', 'pop_name' and 'node_ids' columns.
        Timestamps are treated as milliseconds when converted to sample indices.
    lfp_data : xr.DataArray
        LFP data; ``lfp_data.values`` is what gets filtered.
    pop_names : List[str]
        Population names to analyze.
    freqs : List[float]
        Frequencies to analyze (Hz).
    spike_fs : float, default=1000
        Spike sampling frequency (Hz). Only compared against ``lfp_fs`` to emit
        a warning; spike indexing always uses ``lfp_fs``.
    lfp_fs : float, default=1000
        LFP sampling frequency (Hz).
    nbins_pha : int, default=16
        Number of phase bins spanning [-pi, pi].
    nbins_amp : int, default=16
        Number of amplitude quantile bins.
    pop_num : Dict[str, int], optional
        Number of cells per population. If None, the number of distinct
        'node_ids' seen in ``spike_df`` for each population is used.
    duration : float, optional
        Duration of the data in seconds. If None, ``len(lfp_data) / lfp_fs``.

    Returns
    -------
    np.ndarray
        Array of shape (n_pop, n_freq, nbins_pha, nbins_amp) with % change in
        firing rate. A (population, frequency) slice is left at zero when the
        mean rate is not positive or when ``pop_num[pop] * duration`` is not
        positive (no defined firing rate).

    Notes
    -----
    Spike sample indices that fall outside the filtered signal are clipped to
    its first/last sample rather than discarded.

    Examples
    --------
    >>> # Basic usage
    >>> fr_hist = compute_fr_hist_phase_amplitude(
    ...     spike_df, lfp_data, ['PV', 'SST'], [25, 40],
    ...     spike_fs=1000, lfp_fs=1000
    ... )

    >>> # With custom binning
    >>> fr_hist = compute_fr_hist_phase_amplitude(
    ...     spike_df, lfp_data, pop_names, [30],
    ...     nbins_pha=16, nbins_amp=16
    ... )
    """
    # Ensure spike and LFP sampling rates are consistent
    if spike_fs != lfp_fs:
        print(f"Warning: spike_fs ({spike_fs}) != lfp_fs ({lfp_fs}). Using lfp_fs for spike indexing.")

    if pop_num is None:
        pop_num = {p: len(spike_df[spike_df['pop_name'] == p]['node_ids'].unique()) for p in pop_names}
    if duration is None:
        duration = len(lfp_data) / lfp_fs  # in seconds

    pha_bins = np.linspace(-np.pi, np.pi, nbins_pha + 1)
    quantiles = np.linspace(0, 1, nbins_amp + 1)

    fr_hist = np.zeros((len(pop_names), len(freqs), nbins_pha, nbins_amp))

    # Spike times (ms) -> LFP sample indices. This does not depend on the
    # frequency, so it is done once per population instead of once per
    # (population, frequency) pair.
    raw_spike_indices = []
    for pop in pop_names:
        spike_times = spike_df.loc[spike_df['pop_name'] == pop, 'timestamps'].values
        raw_spike_indices.append(np.round(spike_times / 1000 * lfp_fs).astype(int))

    lfp_values = lfp_data.values
    for j, freq in enumerate(freqs):
        # The filtered signal depends only on the frequency, so filter once
        # here and reuse the phase/amplitude for every population.
        filtered_lfp = wavelet_filter(lfp_values, freq, lfp_fs)
        phase = np.angle(filtered_lfp)
        amplitude = np.abs(filtered_lfp)

        # Amplitude bin edges are quantiles of the whole amplitude trace
        amp_bins = np.quantile(amplitude, quantiles)
        last_sample = len(phase) - 1

        for i, pop in enumerate(pop_names):
            norm = pop_num[pop] * duration
            if norm <= 0:
                # No cells or no duration: a firing rate is undefined. Leave
                # the zeros in place instead of dividing by zero.
                continue

            spike_indices = np.clip(raw_spike_indices[i], 0, last_sample)
            spike_phases = phase[spike_indices]
            spike_amps = amplitude[spike_indices]

            # Bin spikes
            fr, _, _ = np.histogram2d(spike_phases, spike_amps, bins=(pha_bins, amp_bins))

            # Normalize to firing rate
            fr /= norm

            # Compute % change from mean; a zero mean keeps the slice at zero
            fr_mean = fr.mean()
            if fr_mean > 0:
                fr_hist[i, j] = 100 * (fr - fr_mean) / fr_mean

    return fr_hist
bmtool/analysis/lfp.py CHANGED
@@ -227,7 +227,7 @@ def fit_fooof(
227
227
 
228
228
  if plot:
229
229
  plt_range = set_range(plt_range)
230
- fm.plot(plt_log=plt_log)
230
+ fm.plot(ax=plt.gca(), plt_log=plt_log)
231
231
  plt.xlim(np.log10(plt_range) if plt_log else plt_range)
232
232
  # plt.ylim(-8, -5.5)
233
233
  if figsize: