canns 0.14.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. canns/analyzer/data/asa/__init__.py +77 -21
  2. canns/analyzer/data/asa/coho.py +97 -0
  3. canns/analyzer/data/asa/cohomap.py +408 -0
  4. canns/analyzer/data/asa/cohomap_scatter.py +10 -0
  5. canns/analyzer/data/asa/cohomap_vectors.py +311 -0
  6. canns/analyzer/data/asa/cohospace.py +173 -1153
  7. canns/analyzer/data/asa/cohospace_phase_centers.py +137 -0
  8. canns/analyzer/data/asa/cohospace_scatter.py +1220 -0
  9. canns/analyzer/data/asa/embedding.py +3 -4
  10. canns/analyzer/data/asa/plotting.py +4 -4
  11. canns/analyzer/data/cell_classification/__init__.py +10 -0
  12. canns/analyzer/data/cell_classification/core/__init__.py +4 -0
  13. canns/analyzer/data/cell_classification/core/btn.py +272 -0
  14. canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
  15. canns/analyzer/data/cell_classification/visualization/btn_plots.py +258 -0
  16. canns/analyzer/visualization/__init__.py +2 -0
  17. canns/analyzer/visualization/core/config.py +20 -0
  18. canns/analyzer/visualization/theta_sweep_plots.py +142 -0
  19. canns/pipeline/asa/runner.py +19 -19
  20. canns/pipeline/asa_gui/__init__.py +5 -3
  21. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +32 -4
  22. canns/pipeline/asa_gui/core/runner.py +23 -23
  23. canns/pipeline/asa_gui/views/pages/preprocess_page.py +250 -8
  24. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/METADATA +2 -1
  25. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/RECORD +28 -20
  26. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/WHEEL +0 -0
  27. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/entry_points.txt +0 -0
  28. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/licenses/LICENSE +0 -0
@@ -760,6 +760,148 @@ def plot_grid_cell_manifold(
760
760
  raise e
761
761
 
762
762
 
763
def plot_internal_position_trajectory(
    internal_position: np.ndarray,
    position: np.ndarray,
    max_activity: np.ndarray | None = None,
    env_size: float | tuple[float, float] | tuple[float, float, float, float] | None = None,
    config: PlotConfig | None = None,
    ax: plt.Axes | None = None,
    # Backward compatibility parameters
    title: str = "Internal Position (GC bump) vs. Real Trajectory",
    figsize: tuple[int, int] = (6, 4),
    cmap: str = "cool",
    show: bool = True,
    save_path: str | None = None,
    **kwargs,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot internal position (GC bump) against the real trajectory.

    The internal (decoded) positions are drawn as a scatter, optionally colored
    by ``max_activity``. The real trajectory is drawn as a line over it.

    Args:
        internal_position: Internal decoded positions ``(T, 2)``. Any array-like
            is accepted and converted with ``np.asarray``.
        position: Real positions ``(T, 2)``. Any array-like is accepted.
        max_activity: Optional per-time values (length ``T``) used to color the
            internal positions. When omitted, no colormap and no colorbar are used.
        env_size: Environment size. If float, uses ``[0, env_size]`` for both axes.
            If a tuple of 2, treats as ``(width, height)``. If a tuple of 4, treats as
            ``(xmin, xmax, ymin, ymax)``.
        config: PlotConfig object for unified configuration. When given, the
            backward-compatibility parameters below are ignored.
        ax: Optional axis to draw on instead of creating a new figure. When given,
            the figure is not shown and ``tight_layout`` is not applied here.
        title, figsize, cmap, show, save_path: Used to build a ``PlotConfig`` only
            when ``config`` is None.
        **kwargs: Additional parameters for backward compatibility, stored in the
            constructed ``PlotConfig``'s ``kwargs``. Keys consumed here (when present
            in ``config.to_matplotlib_kwargs()``): ``trajectory_color``,
            ``trajectory_linewidth``, ``scatter_size``, ``scatter_alpha``,
            ``add_colorbar`` and ``colorbar`` (a dict with optional ``pad``,
            ``size``, ``label`` and ``kwargs``). Any other entries are forwarded
            to ``ax.scatter``.

    Returns:
        tuple: ``(figure, axis)`` objects.

    Raises:
        ValueError: If the inputs do not have shape ``(T, 2)``, their lengths
            disagree, or ``env_size`` is a sequence whose length is not 2 or 4.
    """
    # Accept lists/tuples as well as ndarrays; no copy is made for ndarray inputs.
    internal_position = np.asarray(internal_position)
    position = np.asarray(position)
    if max_activity is not None:
        max_activity = np.asarray(max_activity)

    if internal_position.ndim != 2 or internal_position.shape[1] != 2:
        raise ValueError(
            f"internal_position must be (T, 2), got shape {internal_position.shape}"
        )
    if position.ndim != 2 or position.shape[1] != 2:
        raise ValueError(f"position must be (T, 2), got shape {position.shape}")
    if internal_position.shape[0] != position.shape[0]:
        raise ValueError(
            "internal_position and position must have the same length: "
            f"{internal_position.shape[0]} != {position.shape[0]}"
        )
    if max_activity is not None and max_activity.shape[0] != internal_position.shape[0]:
        raise ValueError(
            "max_activity must have same length as internal_position: "
            f"{max_activity.shape[0]} != {internal_position.shape[0]}"
        )

    if config is None:
        config = PlotConfig(
            title=title,
            figsize=figsize,
            show=show,
            save_path=save_path,
            kwargs={"cmap": cmap, **kwargs},
        )

    plot_kwargs = config.to_matplotlib_kwargs()
    # Pull out the options this function interprets itself; they are not valid
    # ax.scatter() keywords and must not be forwarded.
    trajectory_color = plot_kwargs.pop("trajectory_color", "black")
    trajectory_linewidth = plot_kwargs.pop("trajectory_linewidth", 1.0)
    scatter_size = plot_kwargs.pop("scatter_size", 4)
    scatter_alpha = plot_kwargs.pop("scatter_alpha", 0.9)
    add_colorbar = bool(plot_kwargs.pop("add_colorbar", True))
    # Always pop "colorbar", even when the colorbar is disabled. Otherwise it
    # would leak into ax.scatter(**plot_kwargs) and raise on an unknown keyword.
    colorbar_options = plot_kwargs.pop("colorbar", {})
    plot_kwargs.setdefault("cmap", "cool")

    if max_activity is None:
        # No color data: a colorbar is meaningless. matplotlib would also warn
        # that the cmap is ignored, so drop it.
        add_colorbar = False
        plot_kwargs.pop("cmap", None)

    axis_provided = ax is not None
    if axis_provided:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=config.figsize)

    try:
        scatter = ax.scatter(
            internal_position[:, 0],
            internal_position[:, 1],
            c=max_activity,
            s=scatter_size,
            alpha=scatter_alpha,
            **plot_kwargs,
        )
        ax.plot(position[:, 0], position[:, 1], color=trajectory_color, lw=trajectory_linewidth)

        ax.set_aspect("equal", adjustable="box")
        if config.title:
            ax.set_title(config.title, fontsize=14, fontweight="bold")

        if env_size is not None:
            if isinstance(env_size, (tuple, list, np.ndarray)):
                if len(env_size) == 2:
                    ax.set_xlim(0, env_size[0])
                    ax.set_ylim(0, env_size[1])
                elif len(env_size) == 4:
                    ax.set_xlim(env_size[0], env_size[1])
                    ax.set_ylim(env_size[2], env_size[3])
                else:
                    raise ValueError("env_size tuple must be length 2 or 4.")
            else:
                ax.set_xlim(0, env_size)
                ax.set_ylim(0, env_size)

        sns.despine(ax=ax)

        if add_colorbar:
            default_cbar_opts = {"pad": 0.15, "size": "5%", "label": "Max GC activity"}
            if isinstance(colorbar_options, dict):
                # "kwargs" is forwarded verbatim to fig.colorbar; pad/size/label
                # override the layout defaults above.
                extra_cbar_kwargs = colorbar_options.get("kwargs", {})
                for key in ("pad", "size", "label"):
                    if key in colorbar_options:
                        default_cbar_opts[key] = colorbar_options[key]
            else:
                extra_cbar_kwargs = {}

            divider = make_axes_locatable(ax)
            cax = divider.append_axes(
                "right", size=default_cbar_opts["size"], pad=default_cbar_opts["pad"]
            )
            cbar = fig.colorbar(scatter, cax=cax, **extra_cbar_kwargs)
            if default_cbar_opts["label"]:
                cbar.set_label(default_cbar_opts["label"], fontsize=12)

        if not axis_provided:
            fig.tight_layout()

        # Never show a figure the caller owns (ax was supplied).
        # NOTE(review): with ax supplied, show_flag is False and always_close
        # becomes True. Confirm finalize_figure does not close a caller-owned
        # figure in that case.
        show_flag = config.show and not axis_provided
        finalize_figure(
            fig,
            replace(config, show=show_flag),
            rasterize_artists=[scatter] if config.rasterized else None,
            always_close=not show_flag,
        )
        return fig, ax

    except Exception:
        # Only close a figure this function created. A caller-supplied axis
        # belongs to a figure the caller still owns.
        if not axis_provided:
            plt.close(fig)
        raise
+
904
+
763
905
  @dataclass(slots=True)
764
906
  class _PlaceCellAnimationData:
765
907
  """Immutable container for place cell animation arrays."""
@@ -628,7 +628,7 @@ class PipelineRunner:
628
628
  self, asa_data: dict[str, Any], state: WorkflowState, log_callback
629
629
  ) -> dict[str, Path]:
630
630
  """Run CohoMap analysis (TDA + decode + plotting)."""
631
- from canns.analyzer.data.asa import plot_cohomap_multi
631
+ from canns.analyzer.data.asa import plot_cohomap_scatter_multi
632
632
  from canns.analyzer.visualization import PlotConfigs
633
633
 
634
634
  tda_dir = self._results_dir(state) / "TDA"
@@ -661,7 +661,7 @@ class PipelineRunner:
661
661
  log_callback("Generating cohomology map...")
662
662
  pos = self._aligned_pos if self._aligned_pos is not None else asa_data
663
663
  config = PlotConfigs.cohomap(show=False, save_path=str(cohomap_path))
664
- plot_cohomap_multi(
664
+ plot_cohomap_scatter_multi(
665
665
  decoding_result=decode_result,
666
666
  position_data={"x": pos["x"], "y": pos["y"]},
667
667
  config=config,
@@ -838,16 +838,16 @@ class PipelineRunner:
838
838
  ) -> dict[str, Path]:
839
839
  """Run cohomology space visualization."""
840
840
  from canns.analyzer.data.asa import (
841
- plot_cohospace_neuron_1d,
842
- plot_cohospace_neuron_2d,
843
- plot_cohospace_population_1d,
844
- plot_cohospace_population_2d,
845
- plot_cohospace_trajectory_1d,
846
- plot_cohospace_trajectory_2d,
841
+ plot_cohospace_scatter_neuron_1d,
842
+ plot_cohospace_scatter_neuron_2d,
843
+ plot_cohospace_scatter_population_1d,
844
+ plot_cohospace_scatter_population_2d,
845
+ plot_cohospace_scatter_trajectory_1d,
846
+ plot_cohospace_scatter_trajectory_2d,
847
847
  )
848
- from canns.analyzer.data.asa.cohospace import (
849
- plot_cohospace_neuron_skewed,
850
- plot_cohospace_population_skewed,
848
+ from canns.analyzer.data.asa.cohospace_scatter import (
849
+ plot_cohospace_scatter_neuron_skewed,
850
+ plot_cohospace_scatter_population_skewed,
851
851
  )
852
852
  from canns.analyzer.visualization import PlotConfigs
853
853
 
@@ -937,7 +937,7 @@ class PipelineRunner:
937
937
  traj_path = out_dir / "cohospace_trajectory.png"
938
938
  if dim_mode == "1d":
939
939
  traj_cfg = PlotConfigs.cohospace_trajectory_1d(show=False, save_path=str(traj_path))
940
- plot_cohospace_trajectory_1d(
940
+ plot_cohospace_scatter_trajectory_1d(
941
941
  coords=coords2,
942
942
  times=None,
943
943
  subsample=subsample,
@@ -945,7 +945,7 @@ class PipelineRunner:
945
945
  )
946
946
  else:
947
947
  traj_cfg = PlotConfigs.cohospace_trajectory_2d(show=False, save_path=str(traj_path))
948
- plot_cohospace_trajectory_2d(
948
+ plot_cohospace_scatter_trajectory_2d(
949
949
  coords=coords2,
950
950
  times=None,
951
951
  subsample=subsample,
@@ -958,7 +958,7 @@ class PipelineRunner:
958
958
  log_callback(f"Plotting neuron {neuron_id}...")
959
959
  neuron_path = out_dir / f"cohospace_neuron_{neuron_id}.png"
960
960
  if unfold == "skew" and dim_mode != "1d":
961
- plot_cohospace_neuron_skewed(
961
+ plot_cohospace_scatter_neuron_skewed(
962
962
  coords=coordsbox2,
963
963
  activity=activity,
964
964
  neuron_id=int(neuron_id),
@@ -974,7 +974,7 @@ class PipelineRunner:
974
974
  neuron_cfg = PlotConfigs.cohospace_neuron_1d(
975
975
  show=False, save_path=str(neuron_path)
976
976
  )
977
- plot_cohospace_neuron_1d(
977
+ plot_cohospace_scatter_neuron_1d(
978
978
  coords=coordsbox2,
979
979
  activity=activity,
980
980
  neuron_id=int(neuron_id),
@@ -986,7 +986,7 @@ class PipelineRunner:
986
986
  neuron_cfg = PlotConfigs.cohospace_neuron_2d(
987
987
  show=False, save_path=str(neuron_path)
988
988
  )
989
- plot_cohospace_neuron_2d(
989
+ plot_cohospace_scatter_neuron_2d(
990
990
  coords=coordsbox2,
991
991
  activity=activity,
992
992
  neuron_id=int(neuron_id),
@@ -1001,7 +1001,7 @@ class PipelineRunner:
1001
1001
  pop_path = out_dir / "cohospace_population.png"
1002
1002
  neuron_ids = list(range(activity.shape[1]))
1003
1003
  if unfold == "skew" and dim_mode != "1d":
1004
- plot_cohospace_population_skewed(
1004
+ plot_cohospace_scatter_population_skewed(
1005
1005
  coords=coords2,
1006
1006
  activity=activity,
1007
1007
  neuron_ids=neuron_ids,
@@ -1017,7 +1017,7 @@ class PipelineRunner:
1017
1017
  pop_cfg = PlotConfigs.cohospace_population_1d(
1018
1018
  show=False, save_path=str(pop_path)
1019
1019
  )
1020
- plot_cohospace_population_1d(
1020
+ plot_cohospace_scatter_population_1d(
1021
1021
  coords=coords2,
1022
1022
  activity=activity,
1023
1023
  neuron_ids=neuron_ids,
@@ -1029,7 +1029,7 @@ class PipelineRunner:
1029
1029
  pop_cfg = PlotConfigs.cohospace_population_2d(
1030
1030
  show=False, save_path=str(pop_path)
1031
1031
  )
1032
- plot_cohospace_population_2d(
1032
+ plot_cohospace_scatter_population_2d(
1033
1033
  coords=coords2,
1034
1034
  activity=activity,
1035
1035
  neuron_ids=neuron_ids,
@@ -2,9 +2,9 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import sys
6
- import os
7
5
  import importlib.util
6
+ import os
7
+ import sys
8
8
 
9
9
  __all__ = ["main", "ASAGuiApp"]
10
10
 
@@ -14,7 +14,9 @@ if _pyside6_missing:
14
14
  try: # pragma: no cover - only used in CI/test runs
15
15
  import pytest
16
16
 
17
- pytest.skip("PySide6 is not installed; skipping asa_gui module.", allow_module_level=True)
17
+ pytest.skip(
18
+ "PySide6 is not installed; skipping asa_gui module.", allow_module_level=True
19
+ )
18
20
  except Exception:
19
21
  pass
20
22
 
@@ -10,6 +10,7 @@ from PySide6.QtWidgets import (
10
10
  QHBoxLayout,
11
11
  QLabel,
12
12
  QLineEdit,
13
+ QPushButton,
13
14
  QSpinBox,
14
15
  QWidget,
15
16
  )
@@ -49,7 +50,7 @@ class PathCompareMode(AbstractAnalysisMode):
49
50
  self.dim2.setValue(2)
50
51
 
51
52
  self.use_box = QCheckBox("Use coordsbox / times_box")
52
- self.use_box.setChecked(False)
53
+ self.use_box.setChecked(True)
53
54
 
54
55
  self.interp_full = QCheckBox("Interpolate to full trajectory")
55
56
  self.interp_full.setChecked(True)
@@ -57,8 +58,13 @@ class PathCompareMode(AbstractAnalysisMode):
57
58
 
58
59
  self.coords_key = QLineEdit()
59
60
  self.coords_key.setPlaceholderText("coords / coordsbox (optional)")
61
+ self.btn_coordsbox = QPushButton("coordsbox")
62
+ self.btn_coordsbox.clicked.connect(lambda: self.coords_key.setText("coordsbox"))
63
+
60
64
  self.times_key = QLineEdit()
61
65
  self.times_key.setPlaceholderText("times_box (optional)")
66
+ self.btn_times_box = QPushButton("times_box")
67
+ self.btn_times_box.clicked.connect(lambda: self.times_key.setText("times_box"))
62
68
 
63
69
  self.slice_mode = PopupComboBox()
64
70
  self.slice_mode.addItem("Time (tmin/tmax)", userData="time")
@@ -125,8 +131,20 @@ class PathCompareMode(AbstractAnalysisMode):
125
131
  form.addRow(self._dims2d_label, dims_2d)
126
132
  form.addRow(self.use_box)
127
133
  form.addRow(self.interp_full)
128
- form.addRow("coords key", self.coords_key)
129
- form.addRow("times key", self.times_key)
134
+ coords_row = QWidget()
135
+ coords_layout = QHBoxLayout(coords_row)
136
+ coords_layout.setContentsMargins(0, 0, 0, 0)
137
+ coords_layout.addWidget(self.coords_key, 1)
138
+ coords_layout.addWidget(self.btn_coordsbox)
139
+
140
+ times_row = QWidget()
141
+ times_layout = QHBoxLayout(times_row)
142
+ times_layout.setContentsMargins(0, 0, 0, 0)
143
+ times_layout.addWidget(self.times_key, 1)
144
+ times_layout.addWidget(self.btn_times_box)
145
+
146
+ form.addRow("coords key", coords_row)
147
+ form.addRow("times key", times_row)
130
148
  form.addRow("Slice mode", self.slice_mode)
131
149
  form.addRow("tmin (sec, -1=auto)", self.tmin)
132
150
  form.addRow("tmax (sec, -1=auto)", self.tmax)
@@ -150,6 +168,10 @@ class PathCompareMode(AbstractAnalysisMode):
150
168
  def _refresh_enabled() -> None:
151
169
  use_box = bool(self.use_box.isChecked())
152
170
  self.interp_full.setEnabled(use_box)
171
+ self.coords_key.setEnabled(use_box)
172
+ self.times_key.setEnabled(use_box)
173
+ self.btn_coordsbox.setEnabled(use_box)
174
+ self.btn_times_box.setEnabled(use_box)
153
175
 
154
176
  def _refresh_slice_mode() -> None:
155
177
  is_time = self.slice_mode.currentData() == "time"
@@ -210,6 +232,8 @@ class PathCompareMode(AbstractAnalysisMode):
210
232
  self.interp_full.setToolTip("插值回完整轨迹。")
211
233
  self.coords_key.setToolTip("可选:解码坐标键(默认 coords/coordsbox)。")
212
234
  self.times_key.setToolTip("可选:times_box 键名。")
235
+ self.btn_coordsbox.setToolTip("填入 coordsbox。")
236
+ self.btn_times_box.setToolTip("填入 times_box。")
213
237
  self.slice_mode.setToolTip("按时间或索引裁剪。")
214
238
  self.tmin.setToolTip("起始时间(秒),-1 自动。")
215
239
  self.tmax.setToolTip("结束时间(秒),-1 自动。")
@@ -226,10 +250,14 @@ class PathCompareMode(AbstractAnalysisMode):
226
250
  self.dim.setToolTip("1D decoded dimension index.")
227
251
  self.dim1.setToolTip("2D decoded dimension 1.")
228
252
  self.dim2.setToolTip("2D decoded dimension 2.")
229
- self.use_box.setToolTip("Use coordsbox/times_box alignment (recommended with speed_filter).")
253
+ self.use_box.setToolTip(
254
+ "Use coordsbox/times_box alignment (recommended with speed_filter)."
255
+ )
230
256
  self.interp_full.setToolTip("Interpolate back to full trajectory.")
231
257
  self.coords_key.setToolTip("Optional decode coords key (default coords/coordsbox).")
232
258
  self.times_key.setToolTip("Optional times_box key.")
259
+ self.btn_coordsbox.setToolTip("Fill coordsbox.")
260
+ self.btn_times_box.setToolTip("Fill times_box.")
233
261
  self.slice_mode.setToolTip("Slice by time or index.")
234
262
  self.tmin.setToolTip("Start time (sec), -1 = auto.")
235
263
  self.tmax.setToolTip("End time (sec), -1 = auto.")
@@ -651,7 +651,7 @@ class PipelineRunner:
651
651
  def _run_cohomap(
652
652
  self, asa_data: dict[str, Any], state: WorkflowState, log_callback
653
653
  ) -> dict[str, Path]:
654
- from canns.analyzer.data.asa import plot_cohomap_multi
654
+ from canns.analyzer.data.asa import plot_cohomap_scatter_multi
655
655
  from canns.analyzer.visualization import PlotConfigs
656
656
 
657
657
  tda_dir = self._results_dir(state) / "TDA"
@@ -684,7 +684,7 @@ class PipelineRunner:
684
684
  log_callback("Generating cohomology map...")
685
685
  pos = self._aligned_pos if self._aligned_pos is not None else asa_data
686
686
  config = PlotConfigs.cohomap(show=False, save_path=str(cohomap_path))
687
- plot_cohomap_multi(
687
+ plot_cohomap_scatter_multi(
688
688
  decoding_result=decode_result,
689
689
  position_data={"x": pos["x"], "y": pos["y"]},
690
690
  config=config,
@@ -1175,18 +1175,18 @@ class PipelineRunner:
1175
1175
  self, asa_data: dict[str, Any], state: WorkflowState, log_callback
1176
1176
  ) -> dict[str, Path]:
1177
1177
  from canns.analyzer.data.asa import (
1178
- plot_cohospace_neuron_1d,
1179
- plot_cohospace_neuron_2d,
1180
- plot_cohospace_population_1d,
1181
- plot_cohospace_population_2d,
1182
- plot_cohospace_trajectory_1d,
1183
- plot_cohospace_trajectory_2d,
1178
+ plot_cohospace_scatter_neuron_1d,
1179
+ plot_cohospace_scatter_neuron_2d,
1180
+ plot_cohospace_scatter_population_1d,
1181
+ plot_cohospace_scatter_population_2d,
1182
+ plot_cohospace_scatter_trajectory_1d,
1183
+ plot_cohospace_scatter_trajectory_2d,
1184
1184
  )
1185
- from canns.analyzer.data.asa.cohospace import (
1186
- compute_cohoscore_1d,
1187
- compute_cohoscore_2d,
1188
- plot_cohospace_neuron_skewed,
1189
- plot_cohospace_population_skewed,
1185
+ from canns.analyzer.data.asa.cohospace_scatter import (
1186
+ compute_cohoscore_scatter_1d,
1187
+ compute_cohoscore_scatter_2d,
1188
+ plot_cohospace_scatter_neuron_skewed,
1189
+ plot_cohospace_scatter_population_skewed,
1190
1190
  )
1191
1191
  from canns.analyzer.visualization import PlotConfigs
1192
1192
 
@@ -1254,11 +1254,11 @@ class PipelineRunner:
1254
1254
  if enable_score:
1255
1255
  try:
1256
1256
  if dim_mode == "1d":
1257
- scores = compute_cohoscore_1d(
1257
+ scores = compute_cohoscore_scatter_1d(
1258
1258
  coords2, activity, top_percent=top_percent, times=times
1259
1259
  )
1260
1260
  else:
1261
- scores = compute_cohoscore_2d(
1261
+ scores = compute_cohoscore_scatter_2d(
1262
1262
  coords2, activity, top_percent=top_percent, times=times
1263
1263
  )
1264
1264
  cohoscore_path = out_dir / "cohoscore.npy"
@@ -1314,7 +1314,7 @@ class PipelineRunner:
1314
1314
  traj_path = out_dir / "cohospace_trajectory.png"
1315
1315
  if dim_mode == "1d":
1316
1316
  traj_cfg = PlotConfigs.cohospace_trajectory_1d(show=False, save_path=str(traj_path))
1317
- plot_cohospace_trajectory_1d(
1317
+ plot_cohospace_scatter_trajectory_1d(
1318
1318
  coords=coords2,
1319
1319
  times=None,
1320
1320
  subsample=subsample,
@@ -1322,7 +1322,7 @@ class PipelineRunner:
1322
1322
  )
1323
1323
  else:
1324
1324
  traj_cfg = PlotConfigs.cohospace_trajectory_2d(show=False, save_path=str(traj_path))
1325
- plot_cohospace_trajectory_2d(
1325
+ plot_cohospace_scatter_trajectory_2d(
1326
1326
  coords=coords2,
1327
1327
  times=None,
1328
1328
  subsample=subsample,
@@ -1334,7 +1334,7 @@ class PipelineRunner:
1334
1334
  log_callback(f"Plotting neuron {neuron_id}...")
1335
1335
  neuron_path = out_dir / f"cohospace_neuron_{neuron_id}.png"
1336
1336
  if unfold == "skew" and dim_mode != "1d":
1337
- plot_cohospace_neuron_skewed(
1337
+ plot_cohospace_scatter_neuron_skewed(
1338
1338
  coords=coordsbox2,
1339
1339
  activity=activity,
1340
1340
  neuron_id=int(neuron_id),
@@ -1351,7 +1351,7 @@ class PipelineRunner:
1351
1351
  neuron_cfg = PlotConfigs.cohospace_neuron_1d(
1352
1352
  show=False, save_path=str(neuron_path)
1353
1353
  )
1354
- plot_cohospace_neuron_1d(
1354
+ plot_cohospace_scatter_neuron_1d(
1355
1355
  coords=coordsbox2,
1356
1356
  activity=activity,
1357
1357
  neuron_id=int(neuron_id),
@@ -1364,7 +1364,7 @@ class PipelineRunner:
1364
1364
  neuron_cfg = PlotConfigs.cohospace_neuron_2d(
1365
1365
  show=False, save_path=str(neuron_path)
1366
1366
  )
1367
- plot_cohospace_neuron_2d(
1367
+ plot_cohospace_scatter_neuron_2d(
1368
1368
  coords=coordsbox2,
1369
1369
  activity=activity,
1370
1370
  neuron_id=int(neuron_id),
@@ -1384,7 +1384,7 @@ class PipelineRunner:
1384
1384
  else:
1385
1385
  neuron_ids = list(range(activity.shape[1]))
1386
1386
  if unfold == "skew" and dim_mode != "1d":
1387
- plot_cohospace_population_skewed(
1387
+ plot_cohospace_scatter_population_skewed(
1388
1388
  coords=coords2,
1389
1389
  activity=activity,
1390
1390
  neuron_ids=neuron_ids,
@@ -1401,7 +1401,7 @@ class PipelineRunner:
1401
1401
  pop_cfg = PlotConfigs.cohospace_population_1d(
1402
1402
  show=False, save_path=str(pop_path)
1403
1403
  )
1404
- plot_cohospace_population_1d(
1404
+ plot_cohospace_scatter_population_1d(
1405
1405
  coords=coords2,
1406
1406
  activity=activity,
1407
1407
  neuron_ids=neuron_ids,
@@ -1414,7 +1414,7 @@ class PipelineRunner:
1414
1414
  pop_cfg = PlotConfigs.cohospace_population_2d(
1415
1415
  show=False, save_path=str(pop_path)
1416
1416
  )
1417
- plot_cohospace_population_2d(
1417
+ plot_cohospace_scatter_population_2d(
1418
1418
  coords=coords2,
1419
1419
  activity=activity,
1420
1420
  neuron_ids=neuron_ids,