canns 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31):
  1. canns/analyzer/data/asa/__init__.py +56 -21
  2. canns/analyzer/data/asa/coho.py +21 -0
  3. canns/analyzer/data/asa/cohomap.py +453 -0
  4. canns/analyzer/data/asa/cohomap_vectors.py +365 -0
  5. canns/analyzer/data/asa/cohospace.py +155 -1165
  6. canns/analyzer/data/asa/cohospace_phase_centers.py +119 -0
  7. canns/analyzer/data/asa/cohospace_scatter.py +1115 -0
  8. canns/analyzer/data/asa/embedding.py +5 -7
  9. canns/analyzer/data/asa/fr.py +1 -8
  10. canns/analyzer/data/asa/path.py +70 -0
  11. canns/analyzer/data/asa/plotting.py +5 -30
  12. canns/analyzer/data/asa/utils.py +160 -0
  13. canns/analyzer/data/cell_classification/__init__.py +10 -0
  14. canns/analyzer/data/cell_classification/core/__init__.py +4 -0
  15. canns/analyzer/data/cell_classification/core/btn.py +272 -0
  16. canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
  17. canns/analyzer/data/cell_classification/visualization/btn_plots.py +241 -0
  18. canns/analyzer/visualization/__init__.py +2 -0
  19. canns/analyzer/visualization/core/config.py +20 -0
  20. canns/analyzer/visualization/theta_sweep_plots.py +142 -0
  21. canns/pipeline/asa/runner.py +19 -19
  22. canns/pipeline/asa_gui/__init__.py +5 -3
  23. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +3 -1
  24. canns/pipeline/asa_gui/core/runner.py +23 -23
  25. canns/pipeline/asa_gui/views/pages/preprocess_page.py +7 -12
  26. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/METADATA +1 -1
  27. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/RECORD +30 -23
  28. canns/analyzer/data/asa/filters.py +0 -208
  29. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/WHEEL +0 -0
  30. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/entry_points.txt +0 -0
  31. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,241 @@
1
+ """BTN visualization utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ from matplotlib import pyplot as plt
9
+ from scipy.ndimage import gaussian_filter1d
10
+
11
+ from canns.analyzer.data.asa.utils import _ensure_plot_config
12
+ from canns.analyzer.visualization.core.config import PlotConfig, finalize_figure
13
+
14
+ _DEFAULT_BTN_COLORS = {
15
+ "B": "#1f77b4",
16
+ "T": "#000000",
17
+ "N": "#2ca02c",
18
+ }
19
+
20
+
21
+ def _canonical_label(label: str) -> str:
22
+ lab = label.strip().lower()
23
+ if lab in ("b", "bursty"):
24
+ return "B"
25
+ if lab in ("t", "theta", "theta-modulated", "theta_modulated", "theta modulated"):
26
+ return "T"
27
+ if lab in ("n", "nonbursty", "non-bursty", "non_bursty"):
28
+ return "N"
29
+ return label
30
+
31
+
32
+ def _cluster_order(labels: np.ndarray, mapping: dict[int, str] | None) -> list[int]:
33
+ cids = [int(c) for c in np.unique(labels)]
34
+ if mapping is None:
35
+ return sorted(cids)
36
+
37
+ def _key(cid: int) -> tuple[int, str]:
38
+ lab = _canonical_label(mapping.get(int(cid), str(cid)))
39
+ order = {"B": 0, "T": 1, "N": 2}.get(lab, 999)
40
+ return (order, str(lab))
41
+
42
+ return sorted(cids, key=_key)
43
+
44
+
45
+ def _label_color(
46
+ label: str,
47
+ colors: dict[str, str] | None,
48
+ fallback_idx: int,
49
+ ) -> str:
50
+ if colors and label in colors:
51
+ return colors[label]
52
+ if label in _DEFAULT_BTN_COLORS:
53
+ return _DEFAULT_BTN_COLORS[label]
54
+ cmap = plt.get_cmap("tab10")
55
+ return cmap(fallback_idx % 10)
56
+
57
+
58
+ def _normalize_rows(acorr: np.ndarray, mode: str | None) -> np.ndarray:
59
+ if mode is None or mode == "none":
60
+ return acorr
61
+ if mode == "probability":
62
+ denom = acorr.sum(axis=1, keepdims=True)
63
+ elif mode == "peak":
64
+ denom = acorr.max(axis=1, keepdims=True)
65
+ elif mode == "first":
66
+ denom = acorr[:, :1]
67
+ else:
68
+ raise ValueError(f"Unknown normalize mode: {mode!r}")
69
+ denom = np.where(denom == 0, 1.0, denom)
70
+ return acorr / denom
71
+
72
+
73
def plot_btn_distance_matrix(
    *,
    dist: np.ndarray | None = None,
    labels: np.ndarray | None = None,
    mapping: dict[int, str] | None = None,
    sort_by_label: bool = True,
    title: str = "BTN distance matrix",
    cmap: str = "afmhot",
    figsize: tuple[int, int] = (5, 5),
    save_path: str | None = None,
    show: bool = True,
    ax: plt.Axes | None = None,
    config: PlotConfig | None = None,
) -> tuple[plt.Figure, plt.Axes, np.ndarray]:
    """Plot a distance matrix heatmap sorted by BTN cluster labels.

    Args:
        dist: Square ``(N, N)`` pairwise distance matrix between neurons (the
            colorbar is labelled "Cosine distance"). Array-likes are accepted.
        labels: Integer cluster id per neuron, length ``N``.
        mapping: Optional cluster-id -> label-name mapping; when given, clusters
            are ordered B, T, N and then any other names.
        sort_by_label: If True, permute rows/columns so each cluster is
            contiguous and draw white separators between clusters.
        title, figsize, save_path, show: Forwarded to ``_ensure_plot_config``.
        cmap: Matplotlib colormap name for the heatmap.
        ax: Optional existing axis to draw into. In that case the figure is not
            shown by this function, because the caller owns it.
        config: Optional PlotConfig.

    Returns:
        ``(figure, axis, order)``; ``order`` is the neuron permutation applied to
        both axes of ``dist``.

    Raises:
        ValueError: If ``dist``/``labels`` are missing, ``labels`` is not a
            non-empty 1D array, or ``dist`` is not ``(N, N)`` for ``N = len(labels)``.
    """
    from dataclasses import replace

    if dist is None or labels is None:
        raise ValueError("dist and labels are required.")

    dist = np.asarray(dist)
    labels = np.asarray(labels).astype(int)
    if labels.ndim != 1 or labels.size == 0:
        raise ValueError(f"labels must be a non-empty 1D array, got shape {labels.shape}.")
    n_neurons = labels.size
    # Without this check a larger ``dist`` would be silently sub-selected by np.ix_.
    if dist.shape != (n_neurons, n_neurons):
        raise ValueError(
            f"dist must have shape ({n_neurons}, {n_neurons}) to match labels, "
            f"got {dist.shape}."
        )

    if sort_by_label:
        cids = _cluster_order(labels, mapping)
        # Clusters in display order; original neuron order kept within a cluster.
        order = np.concatenate([np.where(labels == c)[0] for c in cids])
    else:
        order = np.arange(n_neurons)

    dist_sorted = dist[np.ix_(order, order)]

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        figsize=figsize,
        save_path=save_path,
        show=show,
        kwargs={},
    )

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(1, 1, figsize=config.figsize)
    else:
        fig = ax.figure

    im = ax.imshow(dist_sorted, cmap=cmap, origin="lower", interpolation="nearest")
    fig.colorbar(im, ax=ax, label="Cosine distance")
    ax.set_title(config.title)
    ax.set_xlabel("Neuron")
    ax.set_ylabel("Neuron")

    if sort_by_label:
        sizes = [np.sum(labels == c) for c in cids]
        # Cluster boundaries sit between pixel centres, hence the -0.5 offset.
        for b in np.cumsum(sizes)[:-1]:
            ax.axhline(b - 0.5, color="w", linewidth=0.6, alpha=0.7)
            ax.axvline(b - 0.5, color="w", linewidth=0.6, alpha=0.7)

    if created_fig:
        fig.tight_layout()

    # Do not pop up a window for a caller-supplied axis: the caller owns (and
    # may still be composing) that figure.
    show_flag = config.show and created_fig
    finalize_figure(
        fig,
        replace(config, show=show_flag),
        rasterize_artists=[im] if config.rasterized else None,
    )

    return fig, ax, order
136
+
137
+
138
+ def plot_btn_autocorr_summary(
139
+ *,
140
+ acorr: np.ndarray | None = None,
141
+ labels: np.ndarray | None = None,
142
+ bin_times: np.ndarray | None = None,
143
+ res: float | None = None,
144
+ mapping: dict[int, str] | None = None,
145
+ colors: dict[str, str] | None = None,
146
+ normalize: str | None = "probability",
147
+ smooth_sigma: float | None = None,
148
+ long_max_ms: float | None = 200.0,
149
+ short_max_ms: float | None = None,
150
+ title: str = "BTN temporal autocorr",
151
+ figsize: tuple[int, int] = (8, 3),
152
+ save_path: str | None = None,
153
+ show: bool = True,
154
+ config: PlotConfig | None = None,
155
+ ) -> plt.Figure:
156
+ """Plot class-averaged ISI autocorr curves (mean +/- SEM)."""
157
+ if acorr is None or labels is None:
158
+ raise ValueError("acorr and labels are required.")
159
+
160
+ labels = np.asarray(labels).astype(int)
161
+ acorr = np.asarray(acorr)
162
+ acorr_plot = _normalize_rows(acorr.astype(float, copy=False), normalize)
163
+ if smooth_sigma is not None:
164
+ acorr_plot = gaussian_filter1d(acorr_plot, sigma=float(smooth_sigma), axis=1)
165
+
166
+ if bin_times is not None:
167
+ bin_times = np.asarray(bin_times)
168
+ x = 0.5 * (bin_times[:-1] + bin_times[1:])
169
+ elif res is not None:
170
+ x = np.arange(acorr.shape[1]) * float(res)
171
+ else:
172
+ raise ValueError("Provide bin_times or res to define lag axis.")
173
+
174
+ x_ms = x * 1000.0
175
+
176
+ cids = _cluster_order(labels, mapping)
177
+ label_strings = []
178
+ for c in cids:
179
+ if mapping is not None:
180
+ label_strings.append(_canonical_label(mapping.get(int(c), str(c))))
181
+ else:
182
+ label_strings.append(str(c))
183
+
184
+ show_short = short_max_ms is not None
185
+ ncols = 2 if show_short else 1
186
+
187
+ config = _ensure_plot_config(
188
+ config,
189
+ PlotConfig.for_static_plot,
190
+ title=title,
191
+ figsize=figsize if ncols == 1 else (figsize[0] * 1.6, figsize[1]),
192
+ save_path=save_path,
193
+ show=show,
194
+ kwargs={},
195
+ )
196
+
197
+ fig, axes = plt.subplots(1, ncols, figsize=config.figsize)
198
+ if ncols == 1:
199
+ axes = [axes]
200
+
201
+ def _plot_panel(ax: plt.Axes, max_ms: float | None, panel_title: str):
202
+ if max_ms is None:
203
+ mask = np.ones_like(x_ms, dtype=bool)
204
+ else:
205
+ mask = x_ms <= max_ms
206
+
207
+ for idx, (cid, label_str) in enumerate(zip(cids, label_strings, strict=False)):
208
+ rows = acorr_plot[labels == cid]
209
+ if rows.size == 0:
210
+ continue
211
+ mean = rows.mean(axis=0)
212
+ sem = rows.std(axis=0) / np.sqrt(rows.shape[0])
213
+ color = _label_color(label_str, colors, idx)
214
+ ax.plot(x_ms[mask], mean[mask], color=color, lw=2, label=label_str)
215
+ ax.fill_between(
216
+ x_ms[mask],
217
+ (mean - sem)[mask],
218
+ (mean + sem)[mask],
219
+ color=color,
220
+ alpha=0.25,
221
+ linewidth=0,
222
+ )
223
+
224
+ ax.set_xlabel("Lag (ms)")
225
+ ax.set_title(panel_title)
226
+ ax.grid(False)
227
+
228
+ _plot_panel(axes[0], long_max_ms, "Long lag")
229
+ ylabel = "Probability" if normalize == "probability" else "Autocorr (norm)"
230
+ axes[0].set_ylabel(ylabel)
231
+
232
+ if show_short:
233
+ _plot_panel(axes[1], float(short_max_ms), "Short lag")
234
+
235
+ for ax in axes:
236
+ ax.legend(frameon=False)
237
+
238
+ fig.suptitle(config.title)
239
+ fig.tight_layout()
240
+ finalize_figure(fig, config)
241
+ return fig
@@ -34,6 +34,7 @@ from .theta_sweep_plots import (
34
34
  create_theta_sweep_grid_cell_animation,
35
35
  create_theta_sweep_place_cell_animation,
36
36
  plot_grid_cell_manifold,
37
+ plot_internal_position_trajectory,
37
38
  plot_population_activity_with_theta,
38
39
  )
39
40
  from .tuning_plots import tuning_curve
@@ -70,5 +71,6 @@ __all__ = [
70
71
  "create_theta_sweep_grid_cell_animation",
71
72
  "create_theta_sweep_place_cell_animation",
72
73
  "plot_grid_cell_manifold",
74
+ "plot_internal_position_trajectory",
73
75
  "plot_population_activity_with_theta",
74
76
  ]
@@ -469,6 +469,26 @@ class PlotConfigs:
469
469
  defaults.update(kwargs)
470
470
  return PlotConfig.for_static_plot(**defaults)
471
471
 
472
@staticmethod
def internal_position_trajectory_static(**kwargs: Any) -> PlotConfig:
    """Configuration for the internal-position vs. real-trajectory static plot.

    Top-level keyword arguments (e.g. ``title``, ``figsize``) override the
    defaults outright. A nested ``kwargs`` mapping is merged into the default
    plotting options (colormap, colorbar, trajectory/scatter styling) rather
    than replacing them.
    """
    overrides = kwargs.pop("kwargs", {})
    plot_kwargs: dict[str, Any] = {
        "cmap": "cool",
        "add_colorbar": True,
        "colorbar": {"label": "Max GC activity"},
        "trajectory_color": "black",
        "trajectory_linewidth": 1.0,
        "scatter_size": 4,
        "scatter_alpha": 0.9,
    }
    plot_kwargs.update(overrides)

    settings: dict[str, Any] = {
        "title": "Internal Position vs. Real Trajectory",
        "figsize": (6, 4),
        "kwargs": plot_kwargs,
    }
    settings.update(kwargs)
    return PlotConfig.for_static_plot(**settings)
491
+
472
492
  @staticmethod
473
493
  def direction_cell_polar(**kwargs: Any) -> PlotConfig:
474
494
  """Configuration for direction cell polar plot visualization.
@@ -760,6 +760,148 @@ def plot_grid_cell_manifold(
760
760
  raise e
761
761
 
762
762
 
763
def plot_internal_position_trajectory(
    internal_position: np.ndarray,
    position: np.ndarray,
    max_activity: np.ndarray | None = None,
    env_size: float | tuple[float, float] | tuple[float, float, float, float] | None = None,
    config: PlotConfig | None = None,
    ax: plt.Axes | None = None,
    # Backward compatibility parameters
    title: str = "Internal Position (GC bump) vs. Real Trajectory",
    figsize: tuple[int, int] = (6, 4),
    cmap: str = "cool",
    show: bool = True,
    save_path: str | None = None,
    **kwargs,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot internal position (GC bump) against the real trajectory.

    Internal positions are drawn as a scatter (coloured by ``max_activity``
    when given) and the real trajectory as a line on the same axis.

    Args:
        internal_position: Internal decoded positions ``(T, 2)``. Array-likes
            (e.g. nested lists) are converted with ``np.asarray``.
        position: Real positions ``(T, 2)``.
        max_activity: Optional per-time max activity ``(T,)`` used to colour the
            internal position. Without it, no colormap or colorbar is applied.
        env_size: Environment size. If float, uses ``[0, env_size]`` for both axes.
            If a tuple of 2, treats as ``(width, height)``. If a tuple of 4, treats as
            ``(xmin, xmax, ymin, ymax)``.
        config: PlotConfig object for unified configuration. Options read from
            ``config.to_matplotlib_kwargs()``: ``trajectory_color``,
            ``trajectory_linewidth``, ``scatter_size``, ``scatter_alpha``,
            ``add_colorbar`` and ``colorbar`` (dict with ``pad``/``size``/
            ``label``/``kwargs``). Remaining entries go to ``Axes.scatter``.
        ax: Optional axis to draw on instead of creating a new figure. When
            given, the figure is not shown by this function.
        title, figsize, cmap, show, save_path: Used to build a PlotConfig only
            when ``config`` is None.
        **kwargs: Additional parameters for backward compatibility (merged into
            that fallback config's plotting options).

    Returns:
        tuple: ``(figure, axis)`` objects.

    Raises:
        ValueError: If the inputs have incompatible shapes, or ``env_size`` is a
            sequence whose length is not 2 or 4.
    """
    internal_position = np.asarray(internal_position)
    position = np.asarray(position)
    if max_activity is not None:
        max_activity = np.asarray(max_activity)

    if internal_position.ndim != 2 or internal_position.shape[1] != 2:
        raise ValueError(
            f"internal_position must be (T, 2), got shape {internal_position.shape}"
        )
    if position.ndim != 2 or position.shape[1] != 2:
        raise ValueError(f"position must be (T, 2), got shape {position.shape}")
    if internal_position.shape[0] != position.shape[0]:
        raise ValueError(
            "internal_position and position must have the same length: "
            f"{internal_position.shape[0]} != {position.shape[0]}"
        )
    if max_activity is not None:
        # A 0-d array has no length; treat it as length 0 so it fails cleanly.
        n_activity = max_activity.shape[0] if max_activity.ndim > 0 else 0
        if n_activity != internal_position.shape[0]:
            raise ValueError(
                "max_activity must have same length as internal_position: "
                f"{n_activity} != {internal_position.shape[0]}"
            )

    if config is None:
        config = PlotConfig(
            title=title,
            figsize=figsize,
            show=show,
            save_path=save_path,
            kwargs={"cmap": cmap, **kwargs},
        )

    plot_kwargs = config.to_matplotlib_kwargs()
    trajectory_color = plot_kwargs.pop("trajectory_color", "black")
    trajectory_linewidth = plot_kwargs.pop("trajectory_linewidth", 1.0)
    scatter_size = plot_kwargs.pop("scatter_size", 4)
    scatter_alpha = plot_kwargs.pop("scatter_alpha", 0.9)
    add_colorbar = bool(plot_kwargs.pop("add_colorbar", True))
    # Always remove "colorbar": it is not an Axes.scatter argument, and leaving it
    # in (e.g. with add_colorbar=False) would make the scatter call fail.
    colorbar_options = plot_kwargs.pop("colorbar", {})

    if max_activity is None:
        add_colorbar = False
        # No colour data: matplotlib would ignore these (with a warning).
        for key in ("cmap", "norm", "vmin", "vmax"):
            plot_kwargs.pop(key, None)
    elif "cmap" not in plot_kwargs:
        plot_kwargs["cmap"] = "cool"

    axis_provided = ax is not None
    if axis_provided:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=config.figsize)

    try:
        scatter = ax.scatter(
            internal_position[:, 0],
            internal_position[:, 1],
            c=max_activity,
            s=scatter_size,
            alpha=scatter_alpha,
            **plot_kwargs,
        )
        ax.plot(position[:, 0], position[:, 1], color=trajectory_color, lw=trajectory_linewidth)

        ax.set_aspect("equal", adjustable="box")
        if config.title:
            ax.set_title(config.title, fontsize=14, fontweight="bold")

        if env_size is not None:
            if isinstance(env_size, (tuple, list, np.ndarray)):
                if len(env_size) == 2:
                    ax.set_xlim(0, env_size[0])
                    ax.set_ylim(0, env_size[1])
                elif len(env_size) == 4:
                    ax.set_xlim(env_size[0], env_size[1])
                    ax.set_ylim(env_size[2], env_size[3])
                else:
                    raise ValueError("env_size tuple must be length 2 or 4.")
            else:
                ax.set_xlim(0, env_size)
                ax.set_ylim(0, env_size)

        sns.despine(ax=ax)

        if add_colorbar:
            default_cbar_opts = {"pad": 0.15, "size": "5%", "label": "Max GC activity"}
            if isinstance(colorbar_options, dict):
                extra_cbar_kwargs = colorbar_options.get("kwargs", {})
                for key in ("pad", "size", "label"):
                    if key in colorbar_options:
                        default_cbar_opts[key] = colorbar_options[key]
            else:
                extra_cbar_kwargs = {}

            divider = make_axes_locatable(ax)
            cax = divider.append_axes(
                "right", size=default_cbar_opts["size"], pad=default_cbar_opts["pad"]
            )
            cbar = fig.colorbar(scatter, cax=cax, **extra_cbar_kwargs)
            if default_cbar_opts["label"]:
                cbar.set_label(default_cbar_opts["label"], fontsize=12)

        if not axis_provided:
            fig.tight_layout()

        # Never show a figure the caller supplied the axis for.
        show_flag = config.show and not axis_provided
        finalize_figure(
            fig,
            replace(config, show=show_flag),
            rasterize_artists=[scatter] if config.rasterized else None,
            always_close=not show_flag,
        )
        return fig, ax

    except Exception:
        plt.close(fig)
        raise
903
+
904
+
763
905
  @dataclass(slots=True)
764
906
  class _PlaceCellAnimationData:
765
907
  """Immutable container for place cell animation arrays."""
@@ -628,7 +628,7 @@ class PipelineRunner:
628
628
  self, asa_data: dict[str, Any], state: WorkflowState, log_callback
629
629
  ) -> dict[str, Path]:
630
630
  """Run CohoMap analysis (TDA + decode + plotting)."""
631
- from canns.analyzer.data.asa import plot_cohomap_multi
631
+ from canns.analyzer.data.asa import plot_cohomap_scatter_multi
632
632
  from canns.analyzer.visualization import PlotConfigs
633
633
 
634
634
  tda_dir = self._results_dir(state) / "TDA"
@@ -661,7 +661,7 @@ class PipelineRunner:
661
661
  log_callback("Generating cohomology map...")
662
662
  pos = self._aligned_pos if self._aligned_pos is not None else asa_data
663
663
  config = PlotConfigs.cohomap(show=False, save_path=str(cohomap_path))
664
- plot_cohomap_multi(
664
+ plot_cohomap_scatter_multi(
665
665
  decoding_result=decode_result,
666
666
  position_data={"x": pos["x"], "y": pos["y"]},
667
667
  config=config,
@@ -838,16 +838,16 @@ class PipelineRunner:
838
838
  ) -> dict[str, Path]:
839
839
  """Run cohomology space visualization."""
840
840
  from canns.analyzer.data.asa import (
841
- plot_cohospace_neuron_1d,
842
- plot_cohospace_neuron_2d,
843
- plot_cohospace_population_1d,
844
- plot_cohospace_population_2d,
845
- plot_cohospace_trajectory_1d,
846
- plot_cohospace_trajectory_2d,
841
+ plot_cohospace_scatter_neuron_1d,
842
+ plot_cohospace_scatter_neuron_2d,
843
+ plot_cohospace_scatter_population_1d,
844
+ plot_cohospace_scatter_population_2d,
845
+ plot_cohospace_scatter_trajectory_1d,
846
+ plot_cohospace_scatter_trajectory_2d,
847
847
  )
848
- from canns.analyzer.data.asa.cohospace import (
849
- plot_cohospace_neuron_skewed,
850
- plot_cohospace_population_skewed,
848
+ from canns.analyzer.data.asa.cohospace_scatter import (
849
+ plot_cohospace_scatter_neuron_skewed,
850
+ plot_cohospace_scatter_population_skewed,
851
851
  )
852
852
  from canns.analyzer.visualization import PlotConfigs
853
853
 
@@ -937,7 +937,7 @@ class PipelineRunner:
937
937
  traj_path = out_dir / "cohospace_trajectory.png"
938
938
  if dim_mode == "1d":
939
939
  traj_cfg = PlotConfigs.cohospace_trajectory_1d(show=False, save_path=str(traj_path))
940
- plot_cohospace_trajectory_1d(
940
+ plot_cohospace_scatter_trajectory_1d(
941
941
  coords=coords2,
942
942
  times=None,
943
943
  subsample=subsample,
@@ -945,7 +945,7 @@ class PipelineRunner:
945
945
  )
946
946
  else:
947
947
  traj_cfg = PlotConfigs.cohospace_trajectory_2d(show=False, save_path=str(traj_path))
948
- plot_cohospace_trajectory_2d(
948
+ plot_cohospace_scatter_trajectory_2d(
949
949
  coords=coords2,
950
950
  times=None,
951
951
  subsample=subsample,
@@ -958,7 +958,7 @@ class PipelineRunner:
958
958
  log_callback(f"Plotting neuron {neuron_id}...")
959
959
  neuron_path = out_dir / f"cohospace_neuron_{neuron_id}.png"
960
960
  if unfold == "skew" and dim_mode != "1d":
961
- plot_cohospace_neuron_skewed(
961
+ plot_cohospace_scatter_neuron_skewed(
962
962
  coords=coordsbox2,
963
963
  activity=activity,
964
964
  neuron_id=int(neuron_id),
@@ -974,7 +974,7 @@ class PipelineRunner:
974
974
  neuron_cfg = PlotConfigs.cohospace_neuron_1d(
975
975
  show=False, save_path=str(neuron_path)
976
976
  )
977
- plot_cohospace_neuron_1d(
977
+ plot_cohospace_scatter_neuron_1d(
978
978
  coords=coordsbox2,
979
979
  activity=activity,
980
980
  neuron_id=int(neuron_id),
@@ -986,7 +986,7 @@ class PipelineRunner:
986
986
  neuron_cfg = PlotConfigs.cohospace_neuron_2d(
987
987
  show=False, save_path=str(neuron_path)
988
988
  )
989
- plot_cohospace_neuron_2d(
989
+ plot_cohospace_scatter_neuron_2d(
990
990
  coords=coordsbox2,
991
991
  activity=activity,
992
992
  neuron_id=int(neuron_id),
@@ -1001,7 +1001,7 @@ class PipelineRunner:
1001
1001
  pop_path = out_dir / "cohospace_population.png"
1002
1002
  neuron_ids = list(range(activity.shape[1]))
1003
1003
  if unfold == "skew" and dim_mode != "1d":
1004
- plot_cohospace_population_skewed(
1004
+ plot_cohospace_scatter_population_skewed(
1005
1005
  coords=coords2,
1006
1006
  activity=activity,
1007
1007
  neuron_ids=neuron_ids,
@@ -1017,7 +1017,7 @@ class PipelineRunner:
1017
1017
  pop_cfg = PlotConfigs.cohospace_population_1d(
1018
1018
  show=False, save_path=str(pop_path)
1019
1019
  )
1020
- plot_cohospace_population_1d(
1020
+ plot_cohospace_scatter_population_1d(
1021
1021
  coords=coords2,
1022
1022
  activity=activity,
1023
1023
  neuron_ids=neuron_ids,
@@ -1029,7 +1029,7 @@ class PipelineRunner:
1029
1029
  pop_cfg = PlotConfigs.cohospace_population_2d(
1030
1030
  show=False, save_path=str(pop_path)
1031
1031
  )
1032
- plot_cohospace_population_2d(
1032
+ plot_cohospace_scatter_population_2d(
1033
1033
  coords=coords2,
1034
1034
  activity=activity,
1035
1035
  neuron_ids=neuron_ids,
@@ -2,9 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import sys
6
- import os
7
5
  import importlib.util
6
+ import os
7
+ import sys
8
8
 
9
9
  __all__ = ["main", "ASAGuiApp"]
10
10
 
@@ -14,7 +14,9 @@ if _pyside6_missing:
14
14
  try: # pragma: no cover - only used in CI/test runs
15
15
  import pytest
16
16
 
17
- pytest.skip("PySide6 is not installed; skipping asa_gui module.", allow_module_level=True)
17
+ pytest.skip(
18
+ "PySide6 is not installed; skipping asa_gui module.", allow_module_level=True
19
+ )
18
20
  except Exception:
19
21
  pass
20
22
 
@@ -250,7 +250,9 @@ class PathCompareMode(AbstractAnalysisMode):
250
250
  self.dim.setToolTip("1D decoded dimension index.")
251
251
  self.dim1.setToolTip("2D decoded dimension 1.")
252
252
  self.dim2.setToolTip("2D decoded dimension 2.")
253
- self.use_box.setToolTip("Use coordsbox/times_box alignment (recommended with speed_filter).")
253
+ self.use_box.setToolTip(
254
+ "Use coordsbox/times_box alignment (recommended with speed_filter)."
255
+ )
254
256
  self.interp_full.setToolTip("Interpolate back to full trajectory.")
255
257
  self.coords_key.setToolTip("Optional decode coords key (default coords/coordsbox).")
256
258
  self.times_key.setToolTip("Optional times_box key.")