canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. canns/analyzer/data/__init__.py +5 -1
  2. canns/analyzer/data/asa/__init__.py +27 -12
  3. canns/analyzer/data/asa/cohospace.py +336 -10
  4. canns/analyzer/data/asa/config.py +3 -0
  5. canns/analyzer/data/asa/embedding.py +48 -45
  6. canns/analyzer/data/asa/path.py +104 -2
  7. canns/analyzer/data/asa/plotting.py +88 -19
  8. canns/analyzer/data/asa/tda.py +11 -4
  9. canns/analyzer/data/cell_classification/__init__.py +97 -0
  10. canns/analyzer/data/cell_classification/core/__init__.py +26 -0
  11. canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
  12. canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
  13. canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
  14. canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
  15. canns/analyzer/data/cell_classification/io/__init__.py +5 -0
  16. canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
  17. canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
  18. canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
  19. canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
  20. canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
  21. canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
  22. canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
  23. canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
  24. canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
  25. canns/analyzer/metrics/__init__.py +2 -1
  26. canns/analyzer/visualization/core/config.py +46 -4
  27. canns/data/__init__.py +6 -1
  28. canns/data/datasets.py +154 -1
  29. canns/data/loaders.py +37 -0
  30. canns/pipeline/__init__.py +13 -9
  31. canns/pipeline/__main__.py +6 -0
  32. canns/pipeline/asa/runner.py +105 -41
  33. canns/pipeline/asa_gui/__init__.py +68 -0
  34. canns/pipeline/asa_gui/__main__.py +6 -0
  35. canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
  36. canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
  37. canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
  38. canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
  39. canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
  40. canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
  41. canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
  42. canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
  43. canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
  44. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
  45. canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
  46. canns/pipeline/asa_gui/app.py +29 -0
  47. canns/pipeline/asa_gui/controllers/__init__.py +6 -0
  48. canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
  49. canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
  50. canns/pipeline/asa_gui/core/__init__.py +15 -0
  51. canns/pipeline/asa_gui/core/cache.py +14 -0
  52. canns/pipeline/asa_gui/core/runner.py +1936 -0
  53. canns/pipeline/asa_gui/core/state.py +324 -0
  54. canns/pipeline/asa_gui/core/worker.py +260 -0
  55. canns/pipeline/asa_gui/main_window.py +184 -0
  56. canns/pipeline/asa_gui/models/__init__.py +7 -0
  57. canns/pipeline/asa_gui/models/config.py +14 -0
  58. canns/pipeline/asa_gui/models/job.py +31 -0
  59. canns/pipeline/asa_gui/models/presets.py +21 -0
  60. canns/pipeline/asa_gui/resources/__init__.py +16 -0
  61. canns/pipeline/asa_gui/resources/dark.qss +167 -0
  62. canns/pipeline/asa_gui/resources/light.qss +163 -0
  63. canns/pipeline/asa_gui/resources/styles.qss +130 -0
  64. canns/pipeline/asa_gui/utils/__init__.py +1 -0
  65. canns/pipeline/asa_gui/utils/formatters.py +15 -0
  66. canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
  67. canns/pipeline/asa_gui/utils/validators.py +41 -0
  68. canns/pipeline/asa_gui/views/__init__.py +1 -0
  69. canns/pipeline/asa_gui/views/help_content.py +171 -0
  70. canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
  71. canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
  72. canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
  73. canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
  74. canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
  75. canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
  76. canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
  77. canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
  78. canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
  79. canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
  80. canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
  81. canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
  82. canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
  83. canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
  84. canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
  85. canns/pipeline/gallery/__init__.py +15 -5
  86. canns/pipeline/gallery/__main__.py +11 -0
  87. canns/pipeline/gallery/app.py +705 -0
  88. canns/pipeline/gallery/runner.py +790 -0
  89. canns/pipeline/gallery/state.py +51 -0
  90. canns/pipeline/gallery/styles.tcss +123 -0
  91. canns/pipeline/launcher.py +81 -0
  92. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
  93. canns-0.14.0.dist-info/RECORD +163 -0
  94. canns-0.14.0.dist-info/entry_points.txt +5 -0
  95. canns/pipeline/_base.py +0 -50
  96. canns-0.13.1.dist-info/RECORD +0 -89
  97. canns-0.13.1.dist-info/entry_points.txt +0 -3
  98. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
  99. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1936 @@
1
+ """Pipeline execution wrapper for ASA GUI.
2
+
3
+ Provides synchronous pipeline execution that wraps canns.analyzer.data.asa APIs
4
+ and mirrors the TUI runner behavior for caching and artifacts.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import hashlib
10
+ import json
11
+ import time
12
+ from collections.abc import Callable
13
+ from dataclasses import dataclass
14
+ from pathlib import Path
15
+ from typing import Any
16
+
17
+ import numpy as np
18
+
19
+ from .state import WorkflowState, resolve_path
20
+
21
+
22
@dataclass
class PipelineResult:
    """Result from pipeline execution.

    Returned by ``PipelineRunner.run_preprocessing`` and
    ``PipelineRunner.run_analysis`` for both successful and failed runs.
    """

    # True when the run finished without raising.
    success: bool
    # Artifact label -> path, as reported by the stage that ran.
    artifacts: dict[str, Path]
    # One-line human-readable status message.
    summary: str
    # ``str(exception)`` for a failed run; left as None on success.
    error: str | None = None
    # Wall-clock duration of the run, in seconds.
    elapsed_time: float = 0.0
31
+
32
+
33
class ProcessingError(RuntimeError):
    """Signals that a pipeline stage could not be completed."""
37
+
38
+
39
+ class PipelineRunner:
40
+ """Synchronous pipeline execution wrapper."""
41
+
42
+ def __init__(self) -> None:
43
+ self._asa_data: dict[str, Any] | None = None
44
+ self._embed_data: np.ndarray | None = None
45
+ self._aligned_pos: dict[str, np.ndarray] | None = None
46
+ self._input_hash: str | None = None
47
+ self._embed_hash: str | None = None
48
+ self._mpl_ready: bool = False
49
+
50
def has_preprocessed_data(self) -> bool:
    """Report whether preprocessed data is currently held in memory."""
    nothing_loaded = self._embed_data is None
    return not nothing_loaded
52
+
53
@property
def embed_data(self) -> np.ndarray | None:
    """Preprocessed data stored by the last preprocessing run, or ``None``."""
    current = self._embed_data
    return current
56
+
57
@property
def aligned_pos(self) -> dict[str, np.ndarray] | None:
    """Position arrays stored alongside the preprocessed data, or ``None``."""
    current = self._aligned_pos
    return current
60
+
61
def reset_input(self) -> None:
    """Forget all state tied to the current input.

    Clears the loaded data, the preprocessed data, the aligned positions and
    both hashes.  The matplotlib-backend flag is left untouched.
    """
    for attr in ("_asa_data", "_embed_data", "_aligned_pos", "_input_hash", "_embed_hash"):
        setattr(self, attr, None)
67
+
68
+ def _json_safe(self, obj: Any) -> Any:
69
+ if isinstance(obj, Path):
70
+ return str(obj)
71
+ if isinstance(obj, tuple):
72
+ return [self._json_safe(v) for v in obj]
73
+ if isinstance(obj, list):
74
+ return [self._json_safe(v) for v in obj]
75
+ if isinstance(obj, dict):
76
+ return {str(k): self._json_safe(v) for k, v in obj.items()}
77
+ if hasattr(obj, "item") and callable(obj.item):
78
+ try:
79
+ return obj.item()
80
+ except Exception:
81
+ return str(obj)
82
+ return obj
83
+
84
+ def _hash_bytes(self, data: bytes) -> str:
85
+ return hashlib.md5(data).hexdigest()
86
+
87
+ def _hash_file(self, path: Path) -> str:
88
+ md5 = hashlib.md5()
89
+ with path.open("rb") as f:
90
+ for chunk in iter(lambda: f.read(1024 * 1024), b""):
91
+ md5.update(chunk)
92
+ return md5.hexdigest()
93
+
94
+ def _hash_obj(self, obj: Any) -> str:
95
+ payload = json.dumps(self._json_safe(obj), sort_keys=True, ensure_ascii=True).encode(
96
+ "utf-8"
97
+ )
98
+ return self._hash_bytes(payload)
99
+
100
+ def _ensure_matplotlib_backend(self) -> None:
101
+ if self._mpl_ready:
102
+ return
103
+ try:
104
+ import os
105
+
106
+ os.environ.setdefault("MPLBACKEND", "Agg")
107
+ import matplotlib
108
+
109
+ try:
110
+ matplotlib.use("Agg", force=True)
111
+ except TypeError:
112
+ matplotlib.use("Agg")
113
+ except Exception:
114
+ pass
115
+ self._mpl_ready = True
116
+
117
+ def _cache_dir(self, state: WorkflowState) -> Path:
118
+ return self._results_dir(state) / ".asa_cache"
119
+
120
+ def _results_dir(self, state: WorkflowState) -> Path:
121
+ base = state.workdir / "Results"
122
+ dataset_id = self._dataset_id(state)
123
+ return base / dataset_id
124
+
125
def results_dir(self, state: WorkflowState) -> Path:
    """Public accessor for the results directory computed by ``_results_dir``."""
    location = self._results_dir(state)
    return location
127
+
128
+ def _dataset_id(self, state: WorkflowState) -> str:
129
+ try:
130
+ input_hash = self._input_hash or self._compute_input_hash(state)
131
+ except Exception:
132
+ input_hash = "unknown"
133
+ prefix = input_hash[:8]
134
+
135
+ if state.input_mode == "asa":
136
+ path = resolve_path(state, state.asa_file)
137
+ stem = path.stem if path is not None else "asa"
138
+ return f"{stem}_{prefix}"
139
+ if state.input_mode == "neuron_traj":
140
+ neuron_path = resolve_path(state, state.neuron_file)
141
+ traj_path = resolve_path(state, state.traj_file)
142
+ neuron_stem = neuron_path.stem if neuron_path is not None else "neuron"
143
+ traj_stem = traj_path.stem if traj_path is not None else "traj"
144
+ return f"{neuron_stem}_{traj_stem}_{prefix}"
145
+ return f"{state.input_mode}_{prefix}"
146
+
147
+ def _stage_cache_path(self, stage_dir: Path) -> Path:
148
+ return stage_dir / "cache.json"
149
+
150
+ def _load_cache_meta(self, path: Path) -> dict[str, Any]:
151
+ if not path.exists():
152
+ return {}
153
+ try:
154
+ return json.loads(path.read_text(encoding="utf-8"))
155
+ except Exception:
156
+ return {}
157
+
158
+ def _write_cache_meta(self, path: Path, payload: dict[str, Any]) -> None:
159
+ path.write_text(
160
+ json.dumps(self._json_safe(payload), ensure_ascii=True, indent=2), encoding="utf-8"
161
+ )
162
+
163
+ def _stage_cache_hit(
164
+ self, stage_dir: Path, expected_hash: str, required_files: list[Path]
165
+ ) -> bool:
166
+ if not all(p.exists() for p in required_files):
167
+ return False
168
+ meta = self._load_cache_meta(self._stage_cache_path(stage_dir))
169
+ return meta.get("hash") == expected_hash
170
+
171
+ def _compute_input_hash(self, state: WorkflowState) -> str:
172
+ if state.input_mode == "asa":
173
+ path = resolve_path(state, state.asa_file)
174
+ if path is None:
175
+ raise ProcessingError("ASA file not set.")
176
+ return self._hash_obj({"mode": "asa", "file": self._hash_file(path)})
177
+ if state.input_mode == "neuron_traj":
178
+ neuron_path = resolve_path(state, state.neuron_file)
179
+ traj_path = resolve_path(state, state.traj_file)
180
+ if neuron_path is None or traj_path is None:
181
+ raise ProcessingError("Neuron/trajectory files not set.")
182
+ return self._hash_obj(
183
+ {
184
+ "mode": "neuron_traj",
185
+ "neuron": self._hash_file(neuron_path),
186
+ "traj": self._hash_file(traj_path),
187
+ }
188
+ )
189
+ return self._hash_obj({"mode": state.input_mode})
190
+
191
+ def _load_npz_dict(self, path: Path) -> dict[str, Any]:
192
+ data = np.load(path, allow_pickle=True)
193
+ for key in ("persistence_result", "decode_result"):
194
+ if key in data.files:
195
+ return data[key].item()
196
+ return {k: data[k] for k in data.files}
197
+
198
+ def _build_spike_matrix_from_events(self, asa: dict[str, Any]) -> np.ndarray:
199
+ if "t" not in asa:
200
+ raise ProcessingError("asa dict missing key 't' for spike mode.")
201
+ t = np.asarray(asa["t"])
202
+ if t.ndim != 1:
203
+ raise ProcessingError(f"asa['t'] must be 1D, got shape={t.shape}")
204
+ total_steps = t.shape[0]
205
+
206
+ raw = asa.get("spike")
207
+ if raw is None:
208
+ raise ProcessingError("asa dict missing key 'spike' for spike mode.")
209
+ arr = np.asarray(raw)
210
+
211
+ if isinstance(raw, np.ndarray) and arr.ndim == 2 and np.issubdtype(arr.dtype, np.number):
212
+ if arr.shape[0] != total_steps:
213
+ raise ProcessingError(
214
+ f"asa['spike'] matrix first dim {arr.shape[0]} != len(t)={total_steps}"
215
+ )
216
+ return arr.astype(float, copy=False)
217
+
218
+ if isinstance(raw, np.ndarray) and arr.dtype == object and arr.size == 1:
219
+ raw = arr.item()
220
+
221
+ if isinstance(raw, dict):
222
+ keys = sorted(raw.keys())
223
+ spike_dict = {k: np.asarray(raw[k], dtype=float).ravel() for k in keys}
224
+ elif isinstance(raw, (list, tuple)):
225
+ keys = list(range(len(raw)))
226
+ spike_dict = {i: np.asarray(raw[i], dtype=float).ravel() for i in keys}
227
+ else:
228
+ raise ProcessingError(
229
+ "asa['spike'] must be a (T,N) numeric array, dict, or list-of-arrays for spike mode."
230
+ )
231
+
232
+ neuron_count = len(spike_dict)
233
+ spike_mat = np.zeros((total_steps, neuron_count), dtype=float)
234
+ if total_steps > 1:
235
+ dt = float(t[1] - t[0])
236
+ else:
237
+ dt = 1.0
238
+ t0 = float(t[0])
239
+
240
+ for col, key in enumerate(keys):
241
+ times = spike_dict[key]
242
+ if times.size == 0:
243
+ continue
244
+ idx = np.rint((times - t0) / dt).astype(int)
245
+ idx = idx[(idx >= 0) & (idx < total_steps)]
246
+ if idx.size == 0:
247
+ continue
248
+ np.add.at(spike_mat[:, col], idx, 1.0)
249
+
250
+ return spike_mat
251
+
252
+ def _check_cancel(self, cancel_check: Callable[[], bool] | None) -> None:
253
+ if cancel_check and cancel_check():
254
+ raise ProcessingError("Cancelled by user")
255
+
256
def run_preprocessing(
    self,
    state: WorkflowState,
    log_callback: Callable[[str], None],
    progress_callback: Callable[[int], None],
    cancel_check: Callable[[], bool] | None = None,
) -> PipelineResult:
    """Load the input data and prepare the matrix later analyses consume.

    When ``state.preprocess_method`` is ``"embed_spike_trains"`` the data is
    embedded with ``canns.analyzer.data.asa.embed_spike_trains`` and cached on
    disk under a hash of the input and the effective parameters.  Any other
    method stores ``asa_data["spike"]`` unchanged.

    Args:
        state: Workflow state describing the input and preprocessing options.
        log_callback: Receives human-readable progress messages.
        progress_callback: Receives progress values (10, 30, 50, 100 here).
        cancel_check: Optional callable; a truthy return value cancels the run.

    Returns:
        A ``PipelineResult``.  Exceptions are not propagated; they are logged
        and reported through ``success=False`` and ``error``.
    """
    t0 = time.time()

    try:
        self._check_cancel(cancel_check)

        log_callback("Loading data...")
        progress_callback(10)
        asa_data = self._load_data(state)
        self._asa_data = asa_data
        self._aligned_pos = None
        self._input_hash = self._compute_input_hash(state)
        # NOTE(review): ``_embed_data`` / ``_embed_hash`` from an earlier run
        # survive a failure or cancellation past this point -- confirm that
        # callers expect this.

        self._check_cancel(cancel_check)

        log_callback(f"Preprocessing with {state.preprocess_method}...")
        progress_callback(30)

        if state.preprocess_method == "embed_spike_trains":
            from canns.analyzer.data.asa import SpikeEmbeddingConfig, embed_spike_trains

            params = state.preprocess_params if state.preprocess_params else {}
            base_config = SpikeEmbeddingConfig()
            # Start from the library defaults so the hash is the same whether a
            # default is given explicitly or omitted.
            effective_params = {
                "res": base_config.res,
                "dt": base_config.dt,
                "sigma": base_config.sigma,
                "smooth": base_config.smooth,
                "speed_filter": base_config.speed_filter,
                "min_speed": base_config.min_speed,
            }
            effective_params.update(params)

            self._embed_hash = self._hash_obj(
                {
                    "input_hash": self._input_hash,
                    "method": state.preprocess_method,
                    "params": effective_params,
                }
            )
            cache_dir = self._cache_dir(state)
            cache_dir.mkdir(parents=True, exist_ok=True)
            cache_path = cache_dir / f"embed_{self._embed_hash}.npz"
            meta_path = cache_dir / f"embed_{self._embed_hash}.json"

            if cache_path.exists():
                cached_embed = None
                cached_pos = None
                try:
                    # The context manager closes the archive once the arrays
                    # have been read into memory.
                    with np.load(cache_path, allow_pickle=True) as cached:
                        cached_embed = cached["embed_data"]
                        if {"x", "y", "t"}.issubset(set(cached.files)):
                            cached_pos = {
                                "x": cached["x"],
                                "y": cached["y"],
                                "t": cached["t"],
                            }
                except Exception as e:
                    # A truncated or corrupt cache file must not block
                    # preprocessing; fall through and recompute.
                    log_callback(f"Warning: cached embedding unreadable, recomputing: {e}")
                    cached_embed = None
                    cached_pos = None

                if cached_embed is not None:
                    log_callback("♻️ Using cached embedding.")
                    self._embed_data = cached_embed
                    if cached_pos is not None:
                        self._aligned_pos = cached_pos
                    progress_callback(100)
                    elapsed = time.time() - t0
                    return PipelineResult(
                        success=True,
                        artifacts={"embedding": cache_path},
                        summary=f"Preprocessing reused cached embedding in {elapsed:.1f}s",
                        elapsed_time=elapsed,
                    )

            config = SpikeEmbeddingConfig(**effective_params)

            log_callback("Running embed_spike_trains...")
            progress_callback(50)
            embed_result = embed_spike_trains(asa_data, config)

            if isinstance(embed_result, tuple):
                embed_data = embed_result[0]
                # A tuple of at least four items carries x, y, t after the
                # embedding itself.
                if len(embed_result) >= 4 and embed_result[1] is not None:
                    self._aligned_pos = {
                        "x": embed_result[1],
                        "y": embed_result[2],
                        "t": embed_result[3],
                    }
            else:
                embed_data = embed_result

            self._embed_data = embed_data
            log_callback(f"Embed data shape: {embed_data.shape}")

            try:
                payload = {"embed_data": embed_data}
                if self._aligned_pos is not None:
                    payload.update(self._aligned_pos)
                np.savez_compressed(cache_path, **payload)
                self._write_cache_meta(
                    meta_path,
                    {
                        "hash": self._embed_hash,
                        "input_hash": self._input_hash,
                        "params": effective_params,
                    },
                )
            except Exception as e:
                # Caching is an optimisation only; the in-memory result stands.
                log_callback(f"Warning: failed to cache embedding: {e}")
        else:
            log_callback("No preprocessing - using raw spike data")
            spike = asa_data.get("spike")
            self._embed_hash = self._hash_obj(
                {
                    "input_hash": self._input_hash,
                    "method": state.preprocess_method,
                    "params": {},
                }
            )

            if isinstance(spike, np.ndarray) and spike.ndim == 2:
                self._embed_data = spike
                log_callback(f"Using spike matrix shape: {spike.shape}")
            else:
                log_callback(
                    "Warning: spike data is not a dense matrix, some analyses may fail"
                )
                self._embed_data = spike

        progress_callback(100)
        elapsed = time.time() - t0

        return PipelineResult(
            success=True,
            artifacts={},
            summary=f"Preprocessing completed in {elapsed:.1f}s",
            elapsed_time=elapsed,
        )

    except Exception as e:
        elapsed = time.time() - t0
        log_callback(f"Error: {e}")
        return PipelineResult(
            success=False,
            artifacts={},
            summary=f"Failed after {elapsed:.1f}s",
            error=str(e),
            elapsed_time=elapsed,
        )
401
+
402
def run_analysis(
    self,
    state: WorkflowState,
    log_callback: Callable[[str], None],
    progress_callback: Callable[[int], None],
    cancel_check: Callable[[], bool] | None = None,
) -> PipelineResult:
    """Run the analysis selected by ``state.analysis_mode``.

    The mode is matched case-insensitively.  Data already held by the runner
    is reused; otherwise it is loaded from the files named in ``state``.

    Returns:
        A ``PipelineResult``.  Exceptions, including an unknown mode, are not
        propagated; they are logged and reported through ``success=False``.
    """
    started = time.time()
    artifacts: dict[str, Path] = {}

    # mode -> (handler method name, handler also receives progress_callback)
    handlers = {
        "tda": ("_run_tda", False),
        "decode": ("_run_decode_only", False),
        "cohomap": ("_run_cohomap", False),
        "pathcompare": ("_run_pathcompare", False),
        "cohospace": ("_run_cohospace", False),
        "fr": ("_run_fr", False),
        "frm": ("_run_frm", False),
        "gridscore": ("_run_gridscore", True),
        "gridscore_inspect": ("_run_gridscore_inspect", True),
    }

    try:
        self._check_cancel(cancel_check)

        log_callback("Loading data...")
        progress_callback(10)
        asa_data = self._asa_data if self._asa_data is not None else self._load_data(state)
        if self._input_hash is None:
            self._input_hash = self._compute_input_hash(state)

        self._ensure_matplotlib_backend()

        log_callback(f"Running {state.analysis_mode} analysis...")
        progress_callback(40)

        entry = handlers.get(state.analysis_mode.lower())
        if entry is None:
            raise ProcessingError(f"Unknown analysis mode: {state.analysis_mode}")
        method_name, wants_progress = entry
        # Resolve the handler only now, so just the selected one is looked up.
        handler = getattr(self, method_name)
        extra_args = (progress_callback,) if wants_progress else ()
        artifacts = handler(asa_data, state, log_callback, *extra_args)

        progress_callback(100)
        duration = time.time() - started

        return PipelineResult(
            success=True,
            artifacts=artifacts,
            summary=f"Completed {state.analysis_mode} analysis in {duration:.1f}s",
            elapsed_time=duration,
        )

    except Exception as exc:
        duration = time.time() - started
        log_callback(f"Error: {exc}")
        return PipelineResult(
            success=False,
            artifacts=artifacts,
            summary=f"Failed after {duration:.1f}s",
            error=str(exc),
            elapsed_time=duration,
        )
470
+
471
def _load_data(self, state: WorkflowState) -> dict[str, Any]:
    """Load the input arrays named in ``state`` into a plain dictionary.

    * ``"asa"`` mode returns every array stored in the ASA archive.
    * ``"neuron_traj"`` mode returns ``spike`` from the neuron file together
      with ``x``, ``y`` and ``t`` from the trajectory file.

    All archives are closed before returning.

    Raises:
        ProcessingError: If a required file is not set or the input mode is
            unknown.
    """
    # NOTE(security): ``allow_pickle=True`` can execute arbitrary code while
    # unpickling; only open files from trusted sources.
    if state.input_mode == "asa":
        path = resolve_path(state, state.asa_file)
        if path is None:
            # Same message as ``_compute_input_hash`` instead of the obscure
            # failure ``np.load(None)`` would produce.
            raise ProcessingError("ASA file not set.")
        with np.load(path, allow_pickle=True) as data:
            return {k: data[k] for k in data.files}

    if state.input_mode == "neuron_traj":
        neuron_path = resolve_path(state, state.neuron_file)
        traj_path = resolve_path(state, state.traj_file)
        if neuron_path is None or traj_path is None:
            raise ProcessingError("Neuron/trajectory files not set.")

        neuron_data = np.load(neuron_path, allow_pickle=True)
        if isinstance(neuron_data, np.ndarray):
            # A bare ``.npy`` file: the array itself is the spike data.
            spike = neuron_data
        else:
            with neuron_data:
                if "spike" in neuron_data.files:
                    spike = neuron_data["spike"]
                else:
                    # No "spike" entry: expose the whole archive as a mapping.
                    # It is materialised because the archive is closed here.
                    spike = {k: neuron_data[k] for k in neuron_data.files}

        with np.load(traj_path, allow_pickle=True) as traj_data:
            return {
                "spike": spike,
                "x": traj_data["x"],
                "y": traj_data["y"],
                "t": traj_data["t"],
            }

    raise ProcessingError(f"Unknown input mode: {state.input_mode}")
488
+
489
def _run_tda(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Compute persistent homology on the preprocessed data and plot a barcode.

    Results are written to ``<results dir>/TDA`` and reused when the embedding
    hash and the analysis parameters are unchanged and both output files
    exist.

    Args:
        asa_data: Loaded input data (not read by this stage).
        state: Workflow state supplying ``analysis_params`` and the work dir.
        log_callback: Receives human-readable progress messages.

    Returns:
        Paths under the keys ``"persistence"`` and ``"barcode"``.  The barcode
        file may be absent if plotting failed (a warning is logged).

    Raises:
        ProcessingError: If no preprocessed 2-D matrix is available or
            standardisation fails.
    """
    from canns.analyzer.data.asa import TDAConfig, tda_vis
    from canns.analyzer.data.asa.tda import _plot_barcode, _plot_barcode_with_shuffle

    out_dir = self._results_dir(state) / "TDA"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Tolerate unset parameters, as ``_run_pathcompare`` does.
    params = state.analysis_params or {}
    config = TDAConfig(
        dim=params.get("dim", 6),
        num_times=params.get("num_times", 5),
        active_times=params.get("active_times", 15000),
        k=params.get("k", 1000),
        n_points=params.get("n_points", 1200),
        metric=params.get("metric", "cosine"),
        nbs=params.get("nbs", 800),
        maxdim=params.get("maxdim", 1),
        coeff=params.get("coeff", 47),
        show=False,
        do_shuffle=params.get("do_shuffle", False),
        num_shuffles=params.get("num_shuffles", 1000),
        progress_bar=False,
        # Standardisation, when requested, is applied below before ``tda_vis``.
        standardize=False,
    )

    log_callback("Computing persistent homology...")

    if self._embed_data is None:
        raise ProcessingError("No preprocessed data available. Run preprocessing first.")
    if not isinstance(self._embed_data, np.ndarray) or self._embed_data.ndim != 2:
        raise ProcessingError(
            "TDA requires a dense spike matrix (T,N). "
            "Please choose 'Embed Spike Trains' in preprocessing or provide a dense spike matrix in the .npz."
        )

    persistence_path = out_dir / "persistence.npz"
    barcode_path = out_dir / "barcode.png"

    embed_hash = self._embed_hash or self._hash_obj({"embed": "unknown"})
    tda_hash = self._hash_obj({"embed_hash": embed_hash, "params": params})

    if self._stage_cache_hit(out_dir, tda_hash, [persistence_path, barcode_path]):
        log_callback("♻️ Using cached TDA results.")
        return {"persistence": persistence_path, "barcode": barcode_path}

    embed_data = self._embed_data
    if params.get("standardize", False):
        try:
            from sklearn.preprocessing import StandardScaler

            embed_data = StandardScaler().fit_transform(embed_data)
        except Exception as e:
            raise ProcessingError(f"StandardScaler failed: {e}") from e

    result = tda_vis(
        embed_data=embed_data,
        config=config,
    )

    np.savez_compressed(persistence_path, persistence_result=result)

    try:
        persistence = result.get("persistence")
        shuffle_max = result.get("shuffle_max")
        if config.do_shuffle and shuffle_max is not None:
            fig = _plot_barcode_with_shuffle(persistence, shuffle_max)
        else:
            fig = _plot_barcode(persistence)
        try:
            fig.savefig(barcode_path, dpi=200, bbox_inches="tight")
        finally:
            # Release the figure even when saving fails.
            try:
                import matplotlib.pyplot as plt

                plt.close(fig)
            except Exception:
                pass
    except Exception as e:
        log_callback(f"Warning: failed to save barcode: {e}")
        # Remove any barcode left from an earlier run or a partial write.  It
        # would not match the persistence file just written, and its presence
        # would let a later cache check succeed.
        try:
            barcode_path.unlink()
        except OSError:
            pass

    self._write_cache_meta(
        self._stage_cache_path(out_dir),
        {"hash": tda_hash, "embed_hash": embed_hash, "params": params},
    )

    return {"persistence": persistence_path, "barcode": barcode_path}
575
+
576
def _load_or_run_decode(
    self,
    asa_data: dict[str, Any],
    persistence_path: Path,
    state: WorkflowState,
    log_callback,
) -> dict[str, Any]:
    """Return decoded circular coordinates, reusing a cached result if valid.

    The decoding is stored as ``<results dir>/CohoMap/decoding.npz``.  It is
    reused when that file exists and the ``decode_hash`` recorded in the
    stage's cache metadata matches the hash of the persistence file and the
    decode parameters.

    Args:
        asa_data: Loaded input data; passed to the decoder.
        persistence_path: Location of the TDA ``persistence.npz`` file.
        state: Workflow state supplying ``analysis_params`` and the work dir.
        log_callback: Receives human-readable progress messages.

    Returns:
        The decoding result dictionary.

    Raises:
        ProcessingError: If a decode version other than ``"v0"`` is requested
            and no preprocessed data is available.
    """
    from canns.analyzer.data.asa import (
        decode_circular_coordinates,
        decode_circular_coordinates_multi,
    )

    decode_dir = self._results_dir(state) / "CohoMap"
    decode_dir.mkdir(parents=True, exist_ok=True)
    decode_path = decode_dir / "decoding.npz"

    # Tolerate unset parameters, as ``_run_pathcompare`` does.
    params = state.analysis_params or {}
    decode_version = str(params.get("decode_version", "v2"))
    num_circ = int(params.get("num_circ", 2))
    decode_params = {
        "real_ground": params.get("real_ground", True),
        "real_of": params.get("real_of", True),
        "decode_version": decode_version,
        "num_circ": num_circ,
    }
    persistence_hash = self._hash_file(persistence_path)
    decode_hash = self._hash_obj(
        {"persistence_hash": persistence_hash, "params": decode_params}
    )

    meta_path = self._stage_cache_path(decode_dir)
    meta = self._load_cache_meta(meta_path)
    if decode_path.exists() and meta.get("decode_hash") == decode_hash:
        log_callback("♻️ Using cached decoding.")
        return self._load_npz_dict(decode_path)

    log_callback("Decoding circular coordinates...")
    persistence_result = self._load_npz_dict(persistence_path)
    if decode_version == "v0":
        decode_result = decode_circular_coordinates(
            persistence_result=persistence_result,
            spike_data=asa_data,
            real_ground=decode_params["real_ground"],
            real_of=decode_params["real_of"],
            save_path=str(decode_path),
        )
    else:
        if self._embed_data is None:
            raise ProcessingError("No preprocessed data available for decode v2.")
        # Copy so the caller's dict keeps its original "spike" entry.
        spike_data = dict(asa_data)
        spike_data["spike"] = self._embed_data
        decode_result = decode_circular_coordinates_multi(
            persistence_result=persistence_result,
            spike_data=spike_data,
            save_path=str(decode_path),
            num_circ=num_circ,
        )

    # Update only the decode-related entries; other keys already present in
    # the metadata (e.g. a plot "hash") are written back unchanged.
    meta["decode_hash"] = decode_hash
    meta["persistence_hash"] = persistence_hash
    meta["decode_params"] = decode_params
    self._write_cache_meta(meta_path, meta)
    return decode_result
639
+
640
+ def _run_decode_only(
641
+ self, asa_data: dict[str, Any], state: WorkflowState, log_callback
642
+ ) -> dict[str, Path]:
643
+ tda_dir = self._results_dir(state) / "TDA"
644
+ persistence_path = tda_dir / "persistence.npz"
645
+ if not persistence_path.exists():
646
+ raise ProcessingError("TDA results not found. Run TDA analysis first.")
647
+
648
+ self._load_or_run_decode(asa_data, persistence_path, state, log_callback)
649
+ return {"decoding": self._results_dir(state) / "CohoMap" / "decoding.npz"}
650
+
651
def _run_cohomap(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Decode circular coordinates (if needed) and render the CohoMap plot.

    The plot is written to ``<results dir>/CohoMap/cohomap.png`` and reused
    when the decode hash and the subsample setting are unchanged.

    Args:
        asa_data: Loaded input data; also the source of ``x``/``y`` when no
            aligned positions were stored by preprocessing.
        state: Workflow state supplying ``analysis_params`` and the work dir.
        log_callback: Receives human-readable progress messages.

    Returns:
        Paths under the keys ``"decoding"`` and ``"cohomap"``.

    Raises:
        ProcessingError: If TDA results are missing or no ``x``/``y``
            position data is available.
    """
    from canns.analyzer.data.asa import plot_cohomap_multi
    from canns.analyzer.visualization import PlotConfigs

    tda_dir = self._results_dir(state) / "TDA"
    persistence_path = tda_dir / "persistence.npz"
    if not persistence_path.exists():
        raise ProcessingError("TDA results not found. Run TDA analysis first.")

    out_dir = self._results_dir(state) / "CohoMap"
    out_dir.mkdir(parents=True, exist_ok=True)

    decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

    # Tolerate unset parameters, as ``_run_pathcompare`` does.
    params = state.analysis_params or {}
    subsample = int(params.get("cohomap_subsample", 10))

    cohomap_path = out_dir / "cohomap.png"
    meta_path = self._stage_cache_path(out_dir)
    # Read after decoding so a freshly written ``decode_hash`` is picked up.
    stage_hash = self._hash_obj(
        {
            "decode_hash": self._load_cache_meta(meta_path).get("decode_hash"),
            "plot": "cohomap",
            "subsample": subsample,
        }
    )
    if self._stage_cache_hit(out_dir, stage_hash, [cohomap_path]):
        log_callback("♻️ Using cached CohoMap plot.")
        return {"decoding": out_dir / "decoding.npz", "cohomap": cohomap_path}

    log_callback("Generating cohomology map...")
    pos = self._aligned_pos if self._aligned_pos is not None else asa_data
    if "x" not in pos or "y" not in pos:
        # Report the actual problem instead of a bare KeyError.
        raise ProcessingError("Position data ('x', 'y') not available for CohoMap.")
    config = PlotConfigs.cohomap(show=False, save_path=str(cohomap_path))
    plot_cohomap_multi(
        decoding_result=decode_result,
        position_data={"x": pos["x"], "y": pos["y"]},
        config=config,
        subsample=subsample,
    )

    # Re-read the metadata and add the plot hash to whatever is stored there.
    self._write_cache_meta(
        meta_path,
        {
            **self._load_cache_meta(meta_path),
            "hash": stage_hash,
        },
    )

    return {"decoding": out_dir / "decoding.npz", "cohomap": cohomap_path}
703
+
704
def _run_pathcompare(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Plot decoded coordinates next to the physical path, with optional animation.

    Requires ``TDA/persistence.npz`` under the results directory. Decoding is
    obtained through ``_load_or_run_decode``. The static figure is written to
    ``PathCompare/path_compare.png``; when ``animation_format`` is ``"gif"`` or
    ``"mp4"`` a matching ``path_compare.<ext>`` is rendered as well. Outputs are
    reused when the stage cache reports a hit for the same inputs/parameters.

    Parameters are read from ``state.analysis_params`` (or from its
    ``"pathcompare"`` sub-dict when that is a dict): ``angle_scale`` (alias
    ``theta_scale``), ``dim_mode``, ``dim``, ``dim1``, ``dim2``, ``use_box``,
    ``interp_full``, ``coords_key``, ``times_key`` (alias ``times_box_key``),
    ``slice_mode``, ``tmin``, ``tmax``, ``imin``, ``imax``, ``stride``,
    ``tail``, ``fps``, ``no_wrap`` and ``animation_format``.

    Args:
        asa_data: Loaded ASA input; supplies ``t``/``x``/``y`` when no aligned
            position is stored on the instance.
        state: Workflow state holding analysis parameters and results location.
        log_callback: Callable that receives progress and warning messages.

    Returns:
        Mapping with ``"path_compare"`` and, if an animation was produced,
        ``"path_compare_gif"`` or ``"path_compare_mp4"``.

    Raises:
        ProcessingError: If TDA results are missing or the requested
            dimension indices exceed the decoded coordinate matrix.
    """
    from canns.analyzer.data.asa import (
        align_coords_to_position_1d,
        align_coords_to_position_2d,
        apply_angle_scale,
        plot_path_compare_1d,
        plot_path_compare_2d,
    )
    from canns.analyzer.data.asa.path import (
        find_coords_matrix,
        find_times_box,
        resolve_time_slice,
    )
    from canns.analyzer.visualization import PlotConfigs

    tda_dir = self._results_dir(state) / "TDA"
    persistence_path = tda_dir / "persistence.npz"
    if not persistence_path.exists():
        raise ProcessingError("TDA results not found. Run TDA analysis first.")

    decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

    out_dir = self._results_dir(state) / "PathCompare"
    out_dir.mkdir(parents=True, exist_ok=True)

    params = state.analysis_params or {}
    # Accept both a nested {"pathcompare": {...}} layout and a flat dict.
    pc_params = (
        params.get("pathcompare") if isinstance(params.get("pathcompare"), dict) else params
    )

    def _param(key: str, default: Any = None) -> Any:
        return pc_params.get(key, default) if isinstance(pc_params, dict) else default

    angle_scale = _param("angle_scale", _param("theta_scale", "rad"))
    dim_mode = _param("dim_mode", "2d")
    dim = int(_param("dim", 1))
    dim1 = int(_param("dim1", 1))
    dim2 = int(_param("dim2", 2))
    use_box = bool(_param("use_box", False))
    interp_full = bool(_param("interp_full", True))
    coords_key = _param("coords_key")
    times_key = _param("times_key", _param("times_box_key"))
    slice_mode = _param("slice_mode", "time")
    tmin = _param("tmin")
    tmax = _param("tmax")
    imin = _param("imin")
    imax = _param("imax")
    stride = int(_param("stride", 1))
    tail = int(_param("tail", 300))
    fps = int(_param("fps", 20))
    no_wrap = bool(_param("no_wrap", False))
    animation_format = str(_param("animation_format", "none")).lower()
    if animation_format not in {"none", "gif", "mp4"}:
        animation_format = "none"

    coords_raw, _ = find_coords_matrix(
        decode_result,
        coords_key=coords_key,
        prefer_box_fallback=use_box,
    )

    # User-facing dimension numbers are 1-based; convert to column indices.
    if dim_mode == "1d":
        idx = max(0, dim - 1)
        if idx >= coords_raw.shape[1]:
            raise ProcessingError(f"dim out of range for coords shape {coords_raw.shape}")
        coords1 = coords_raw[:, idx]
    else:
        idx1 = max(0, dim1 - 1)
        idx2 = max(0, dim2 - 1)
        if idx1 >= coords_raw.shape[1] or idx2 >= coords_raw.shape[1]:
            raise ProcessingError(f"dim1/dim2 out of range for coords shape {coords_raw.shape}")
        coords2 = coords_raw[:, [idx1, idx2]]

    pos = self._aligned_pos if self._aligned_pos is not None else asa_data
    t_full = np.asarray(pos["t"]).ravel()
    x_full = np.asarray(pos["x"]).ravel()
    y_full = np.asarray(pos["y"]).ravel()

    if use_box:
        if times_key:
            times_box = decode_result.get(times_key)
        else:
            times_box, _ = find_times_box(decode_result)
    else:
        times_box = None

    log_callback("Aligning decoded coordinates to position...")
    if dim_mode == "1d":
        t_use, x_use, y_use, coords_use, _ = align_coords_to_position_1d(
            t_full=t_full,
            x_full=x_full,
            y_full=y_full,
            coords1=coords1,
            use_box=use_box,
            times_box=times_box,
            interp_to_full=interp_full,
        )
    else:
        t_use, x_use, y_use, coords_use, _ = align_coords_to_position_2d(
            t_full=t_full,
            x_full=x_full,
            y_full=y_full,
            coords2=coords2,
            use_box=use_box,
            times_box=times_box,
            interp_to_full=interp_full,
        )
    # Unknown scale names silently fall back to radians.
    scale = str(angle_scale) if str(angle_scale) in {"rad", "deg", "unit", "auto"} else "rad"
    coords_use = apply_angle_scale(coords_use, scale)
    if not no_wrap:
        coords_use = np.mod(coords_use, 2 * np.pi)

    if slice_mode == "index":
        i0, i1 = resolve_time_slice(t_use, None, None, imin, imax)
    else:
        i0, i1 = resolve_time_slice(t_use, tmin, tmax, None, None)

    stride = max(1, stride)
    window = slice(i0, i1, stride)
    t_use = t_use[window]
    x_use = x_use[window]
    y_use = y_use[window]
    coords_use = coords_use[window]

    out_path = out_dir / "path_compare.png"
    anim_path: Path | None = None
    if animation_format == "gif":
        anim_path = out_dir / "path_compare.gif"
    elif animation_format == "mp4":
        anim_path = out_dir / "path_compare.mp4"
    # Artifact key under which the animation is reported, if any.
    anim_key = "path_compare_gif" if animation_format == "gif" else "path_compare_mp4"

    decode_meta = self._load_cache_meta(
        self._stage_cache_path(self._results_dir(state) / "CohoMap")
    )
    stage_hash = self._hash_obj(
        {
            "persistence": self._hash_file(persistence_path),
            "decode_hash": decode_meta.get("decode_hash"),
            "params": {
                "angle_scale": scale,
                "dim_mode": dim_mode,
                "dim": dim,
                "dim1": dim1,
                "dim2": dim2,
                "use_box": use_box,
                "interp_full": interp_full,
                "coords_key": coords_key,
                "times_key": times_key,
                "slice_mode": slice_mode,
                "tmin": tmin,
                "tmax": tmax,
                "imin": imin,
                "imax": imax,
                "stride": stride,
                "tail": tail,
                "fps": fps,
                "no_wrap": no_wrap,
                "animation_format": animation_format,
            },
        }
    )
    required = [out_path]
    if anim_path is not None:
        required.append(anim_path)
    if self._stage_cache_hit(out_dir, stage_hash, required):
        log_callback("♻️ Using cached PathCompare plot.")
        artifacts = {"path_compare": out_path}
        if anim_path is not None:
            artifacts[anim_key] = anim_path
        return artifacts

    log_callback("Generating path comparison...")
    if dim_mode == "1d":
        config = PlotConfigs.path_compare_1d(show=False, save_path=str(out_path))
        plot_path_compare_1d(x_use, y_use, coords_use, config=config)
    else:
        config = PlotConfigs.path_compare_2d(show=False, save_path=str(out_path))
        plot_path_compare_2d(x_use, y_use, coords_use, config=config)

    artifacts: dict[str, Path] = {"path_compare": out_path}

    if anim_path is not None:
        render = (
            self._render_pathcompare_1d_animation
            if dim_mode == "1d"
            else self._render_pathcompare_2d_animation
        )
        try:
            render(
                x_use,
                y_use,
                coords_use,
                t_use,
                anim_path,
                tail=tail,
                fps=fps,
                log_callback=log_callback,
            )
            artifacts[anim_key] = anim_path
        except Exception as e:
            # Animation is best-effort: the static plot is still returned.
            log_callback(f"Warning: failed to render PathCompare animation: {e}")
            # A writer that failed midway can leave a truncated file behind.
            # Remove it so a later run cannot pick it up as a cached artifact.
            try:
                anim_path.unlink(missing_ok=True)
            except OSError as cleanup_err:
                log_callback(
                    f"Warning: could not remove partial animation {anim_path}: {cleanup_err}"
                )

    self._write_cache_meta(self._stage_cache_path(out_dir), {"hash": stage_hash})
    return artifacts
922
+
923
def _render_pathcompare_1d_animation(
    self,
    x: np.ndarray,
    y: np.ndarray,
    coords: np.ndarray,
    t: np.ndarray,
    save_path: Path,
    *,
    tail: int,
    fps: int,
    log_callback,
) -> None:
    """Render a two-panel animation: physical path and decoded angle on a unit circle.

    Args:
        x: Physical x positions, one per frame (flattened).
        y: Physical y positions, one per frame (flattened).
        coords: Decoded angle per frame (flattened); drawn as
            ``(cos(theta), sin(theta))``.
        t: Time stamp per frame, shown in the figure title.
        save_path: Output file; its suffix selects the writer in
            ``_save_animation``.
        tail: Number of previous frames kept in the trail; ``<= 0`` keeps the
            whole history.
        fps: Frames per second passed to the writer.
        log_callback: Callable that receives progress messages.

    Raises:
        ProcessingError: If there are no frames to render.
    """
    import matplotlib.pyplot as plt

    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()
    theta = np.asarray(coords).ravel()
    t = np.asarray(t).ravel()

    n_frames = len(theta)
    if n_frames == 0:
        raise ProcessingError("PathCompare animation has no frames.")

    # Downsample if too many frames for animation
    if n_frames > 20000:
        factor = int(np.ceil(n_frames / 20000))
        keep = np.arange(0, n_frames, factor)
        x = x[keep]
        y = y[keep]
        theta = theta[keep]
        t = t[keep]
        n_frames = len(theta)
        log_callback(f"PathCompare animation downsampled by x{factor} (frames={n_frames}).")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), dpi=120)
    # Close the figure even when saving fails; the caller swallows the
    # exception, so without this the figure would stay alive in pyplot.
    try:
        ax1.set_title("Physical path")
        ax1.set_xlabel("x")
        ax1.set_ylabel("y")
        ax1.set_aspect("equal")
        x_min, x_max = np.min(x), np.max(x)
        y_min, y_max = np.min(y), np.max(y)
        # Fall back to a unit pad when the path is degenerate along an axis.
        pad_x = (x_max - x_min) * 0.05 if x_max > x_min else 1.0
        pad_y = (y_max - y_min) * 0.05 if y_max > y_min else 1.0
        ax1.set_xlim(x_min - pad_x, x_max + pad_x)
        ax1.set_ylim(y_min - pad_y, y_max + pad_y)

        ax2.set_title("Decoded coho path (1D)")
        ax2.set_aspect("equal")
        ax2.axis("off")
        ax2.set_xlim(-1.2, 1.2)
        ax2.set_ylim(-1.2, 1.2)

        (phys_trail,) = ax1.plot([], [], lw=1.0)
        phys_dot = ax1.scatter([], [], s=30)
        (circ_trail,) = ax2.plot([], [], lw=1.0)
        circ_dot = ax2.scatter([], [], s=30)
        title_text = fig.suptitle("", y=1.02)

        def update(k: int) -> None:
            a0 = max(0, k - tail) if tail > 0 else 0
            phys_trail.set_data(x[a0 : k + 1], y[a0 : k + 1])
            phys_dot.set_offsets(np.array([[x[k], y[k]]]))

            trail_theta = theta[a0 : k + 1]
            circ_trail.set_data(np.cos(trail_theta), np.sin(trail_theta))
            circ_dot.set_offsets(np.array([[np.cos(theta[k]), np.sin(theta[k])]]))

            title_text.set_text(f"t = {float(t[k]):.3f}s (frame {k + 1}/{n_frames})")

        fig.tight_layout()
        self._save_animation(fig, update, n_frames, save_path, fps, log_callback)
    finally:
        plt.close(fig)
999
+
1000
def _render_pathcompare_2d_animation(
    self,
    x: np.ndarray,
    y: np.ndarray,
    coords: np.ndarray,
    t: np.ndarray,
    save_path: Path,
    *,
    tail: int,
    fps: int,
    log_callback,
) -> None:
    """Render a two-panel animation: physical path and decoded path in a skewed torus cell.

    Args:
        x: Physical x positions, one per frame (flattened).
        y: Physical y positions, one per frame (flattened).
        coords: 2D array with at least two columns; the first two are passed
            through ``skew_transform`` for display.
        t: Time stamp per frame, shown in the figure title.
        save_path: Output file; its suffix selects the writer in
            ``_save_animation``.
        tail: Number of previous frames kept in the trail; ``<= 0`` keeps the
            whole history.
        fps: Frames per second passed to the writer.
        log_callback: Callable that receives progress messages.

    Raises:
        ProcessingError: If ``coords`` is not 2D with two or more columns, or
            there are no frames to render.
    """
    import matplotlib.pyplot as plt

    from canns.analyzer.data.asa.path import (
        draw_base_parallelogram,
        skew_transform,
        snake_wrap_trail_in_parallelogram,
    )

    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()
    coords = np.asarray(coords)
    t = np.asarray(t).ravel()

    if coords.ndim != 2 or coords.shape[1] < 2:
        raise ProcessingError("PathCompare 2D animation requires coords with 2 columns.")

    n_frames = len(coords)
    if n_frames == 0:
        raise ProcessingError("PathCompare animation has no frames.")

    # Cap the frame count by uniform striding.
    if n_frames > 20000:
        factor = int(np.ceil(n_frames / 20000))
        keep = np.arange(0, n_frames, factor)
        x = x[keep]
        y = y[keep]
        coords = coords[keep]
        t = t[keep]
        n_frames = len(coords)
        log_callback(f"PathCompare animation downsampled by x{factor} (frames={n_frames}).")

    xy_skew = skew_transform(coords[:, :2])

    # Edge vectors of the base parallelogram used for the axis limits and
    # handed to the trail-wrapping helper.
    e1 = np.array([2 * np.pi, 0.0])
    e2 = np.array([np.pi, np.sqrt(3) * np.pi])
    corners = np.vstack([[0.0, 0.0], e1, e2, e1 + e2])
    xm, ym = corners.min(axis=0)
    xM, yM = corners.max(axis=0)
    px2 = 0.05 * (xM - xm + 1e-9)
    py2 = 0.05 * (yM - ym + 1e-9)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), dpi=120)
    # Close the figure even when saving fails; the caller swallows the
    # exception, so without this the figure would stay alive in pyplot.
    try:
        ax1.set_title("Physical path")
        ax1.set_xlabel("x")
        ax1.set_ylabel("y")
        ax1.set_aspect("equal")
        x_min, x_max = np.min(x), np.max(x)
        y_min, y_max = np.min(y), np.max(y)
        pad_x = 0.05 * (x_max - x_min + 1e-9)
        pad_y = 0.05 * (y_max - y_min + 1e-9)
        ax1.set_xlim(x_min - pad_x, x_max + pad_x)
        ax1.set_ylim(y_min - pad_y, y_max + pad_y)

        ax2.set_title("Torus path (skew)")
        ax2.set_xlabel(r"$\theta_1 + \frac{1}{2}\theta_2$")
        ax2.set_ylabel(r"$\frac{\sqrt{3}}{2}\theta_2$")
        ax2.set_aspect("equal")
        draw_base_parallelogram(ax2)
        ax2.set_xlim(xm - px2, xM + px2)
        ax2.set_ylim(ym - py2, yM + py2)

        (phys_trail,) = ax1.plot([], [], lw=1.0)
        phys_dot = ax1.scatter([], [], s=30)
        (tor_trail,) = ax2.plot([], [], lw=1.0)
        tor_dot = ax2.scatter([], [], s=30)
        title_text = fig.suptitle("", y=1.02)

        def update(k: int) -> None:
            a0 = max(0, k - tail) if tail > 0 else 0
            phys_trail.set_data(x[a0 : k + 1], y[a0 : k + 1])
            phys_dot.set_offsets(np.array([[x[k], y[k]]]))

            seg = snake_wrap_trail_in_parallelogram(xy_skew[a0 : k + 1], e1=e1, e2=e2)
            tor_trail.set_data(seg[:, 0], seg[:, 1])
            tor_dot.set_offsets(np.array([[xy_skew[k, 0], xy_skew[k, 1]]]))

            title_text.set_text(f"t = {float(t[k]):.3f}s (frame {k + 1}/{n_frames})")

        fig.tight_layout()
        self._save_animation(fig, update, n_frames, save_path, fps, log_callback)
    finally:
        plt.close(fig)
1095
+
1096
def _save_animation(
    self,
    fig,
    update_func,
    n_frames: int,
    save_path: Path,
    fps: int,
    log_callback,
) -> None:
    """Write an animation of ``fig`` to ``save_path``.

    ``.gif`` targets are written with matplotlib's ``PillowWriter``. Other
    targets go through ``select_animation_backend``: the ``imageio`` backend
    draws and appends frames one by one, anything else falls back to a
    matplotlib writer from ``get_matplotlib_writer``.

    Progress for the GIF and imageio paths is reported through
    ``log_callback`` as ``"__PCANIM__ <pct> <done>/<total>"`` lines, emitted
    only when the integer percentage changes.

    Args:
        fig: Matplotlib figure to animate.
        update_func: Callable ``update_func(k)`` that updates ``fig`` for
            frame index ``k``.
        n_frames: Number of frames to render.
        save_path: Output file path; the suffix selects the writer.
        fps: Requested frames per second; values below 1 are treated as 1.
        log_callback: Callable that receives progress and warning messages.
    """
    # Clamp once so the frame interval and every writer agree. Previously only
    # the interval was clamped, so fps <= 0 reached the writers unchanged.
    fps = max(1, int(fps))
    interval_ms = int(1000 / fps)

    def _update(k: int):
        update_func(k)
        return []

    if save_path.suffix.lower() == ".gif":
        from matplotlib.animation import FuncAnimation, PillowWriter

        ani = FuncAnimation(fig, _update, frames=n_frames, interval=interval_ms, blit=True)

        last_pct = {"v": -1}

        def _progress(i: int, total: int) -> None:
            if not total:
                return
            pct = int((i + 1) * 100 / total)
            if pct != last_pct["v"]:
                last_pct["v"] = pct
                log_callback(f"__PCANIM__ {pct} {i + 1}/{total}")

        ani.save(
            str(save_path),
            writer=PillowWriter(fps=fps),
            progress_callback=_progress,
        )
        return

    from canns.analyzer.visualization.core import (
        get_imageio_writer_kwargs,
        get_matplotlib_writer,
        select_animation_backend,
    )

    backend_selection = select_animation_backend(
        save_path=str(save_path),
        requested_backend="auto",
        check_imageio_plugins=True,
    )
    for warning in backend_selection.warnings:
        log_callback(f"⚠️ {warning}")

    if backend_selection.backend == "imageio":
        import imageio

        writer_kwargs, mode = get_imageio_writer_kwargs(str(save_path), fps)
        reported_pct = -1
        with imageio.get_writer(str(save_path), mode=mode, **writer_kwargs) as writer:
            for k in range(n_frames):
                update_func(k)
                fig.canvas.draw()
                frame = np.asarray(fig.canvas.buffer_rgba())
                # Drop the alpha channel before handing the frame to the writer.
                writer.append_data(frame[:, :, :3])
                pct = int((k + 1) * 100 / n_frames)
                if pct != reported_pct:
                    reported_pct = pct
                    log_callback(f"__PCANIM__ {pct} {k + 1}/{n_frames}")
        return

    from matplotlib.animation import FuncAnimation

    ani = FuncAnimation(fig, _update, frames=n_frames, interval=interval_ms, blit=False)
    writer = get_matplotlib_writer(str(save_path), fps=fps)
    ani.save(str(save_path), writer=writer)
1173
+
1174
def _run_cohospace(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Produce CohoSpace trajectory, single-neuron and population plots.

    Requires ``TDA/persistence.npz`` under the results directory. Decoding is
    obtained through ``_load_or_run_decode``. When scoring is enabled the
    per-neuron CohoScore is saved to ``CohoSpace/cohoscore.npy`` and the
    lowest-scoring ``top_k`` neuron ids to ``CohoSpace/cohospace_top_ids.npy``.
    Plots are reused when the stage cache reports a hit.

    Parameters read from ``state.analysis_params``: ``dim_mode``, ``dim``,
    ``dim1``, ``dim2``, ``mode``, ``top_percent``, ``view``, ``subsample``,
    ``unfold``, ``skew_show_grid``, ``skew_tiles``, ``enable_score``,
    ``top_k``, ``use_best`` and ``neuron_id``.

    Args:
        asa_data: Loaded ASA input; used to build an event-based spike matrix
            when needed.
        state: Workflow state holding analysis parameters and results location.
        log_callback: Callable that receives progress and warning messages.

    Returns:
        Mapping with ``"trajectory"`` and, depending on ``view``,
        ``"neuron"`` and/or ``"population"`` plot paths.

    Raises:
        ProcessingError: If TDA results are missing, decoded coordinates are
            not 2D, or the requested dimension indices are out of range.
    """
    from canns.analyzer.data.asa import (
        plot_cohospace_neuron_1d,
        plot_cohospace_neuron_2d,
        plot_cohospace_population_1d,
        plot_cohospace_population_2d,
        plot_cohospace_trajectory_1d,
        plot_cohospace_trajectory_2d,
    )
    from canns.analyzer.data.asa.cohospace import (
        compute_cohoscore_1d,
        compute_cohoscore_2d,
        plot_cohospace_neuron_skewed,
        plot_cohospace_population_skewed,
    )
    from canns.analyzer.visualization import PlotConfigs

    tda_dir = self._results_dir(state) / "TDA"
    persistence_path = tda_dir / "persistence.npz"
    if not persistence_path.exists():
        raise ProcessingError("TDA results not found. Run TDA analysis first.")

    decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

    out_dir = self._results_dir(state) / "CohoSpace"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Guard against unset parameters, as the other stage runners do.
    params = state.analysis_params or {}
    artifacts: dict[str, Path] = {}

    coords = np.asarray(decode_result.get("coords"))
    coordsbox = np.asarray(decode_result.get("coordsbox"))
    if coords.ndim != 2 or coords.shape[1] < 1:
        raise ProcessingError("decode_result['coords'] must be 2D.")

    dim_mode = str(params.get("dim_mode", "2d"))
    dim = int(params.get("dim", 1))
    dim1 = int(params.get("dim1", 1))
    dim2 = int(params.get("dim2", 2))
    mode = str(params.get("mode", "fr"))
    top_percent = float(params.get("top_percent", 5.0))
    view = str(params.get("view", "both"))
    subsample = int(params.get("subsample", 2))
    unfold = str(params.get("unfold", "square"))
    skew_show_grid = bool(params.get("skew_show_grid", True))
    skew_tiles = int(params.get("skew_tiles", 0))
    enable_score = bool(params.get("enable_score", True))
    top_k = int(params.get("top_k", 10))
    use_best = bool(params.get("use_best", True))
    times = decode_result.get("times")

    def pick_coords(arr: np.ndarray) -> np.ndarray:
        # User-facing dimension numbers are 1-based; convert to column indices.
        if dim_mode == "1d":
            idx = max(0, dim - 1)
            if idx >= arr.shape[1]:
                raise ProcessingError(f"dim out of range for coords shape {arr.shape}")
            return arr[:, idx]
        idx1 = max(0, dim1 - 1)
        idx2 = max(0, dim2 - 1)
        if idx1 >= arr.shape[1] or idx2 >= arr.shape[1]:
            raise ProcessingError(f"dim1/dim2 out of range for coords shape {arr.shape}")
        return arr[:, [idx1, idx2]]

    coords2 = pick_coords(coords)
    # Fall back to the regular coordinates when no 2D "coordsbox" is available.
    coordsbox2 = pick_coords(coordsbox) if coordsbox.ndim == 2 else coords2

    if mode == "spike":
        activity = self._build_spike_matrix_from_events(asa_data)
    else:
        activity = (
            self._embed_data
            if self._embed_data is not None
            else self._build_spike_matrix_from_events(asa_data)
        )

    scores = None
    top_ids = None
    neuron_id = int(params.get("neuron_id", 0))
    if enable_score:
        try:
            if dim_mode == "1d":
                scores = compute_cohoscore_1d(
                    coords2, activity, top_percent=top_percent, times=times
                )
            else:
                scores = compute_cohoscore_2d(
                    coords2, activity, top_percent=top_percent, times=times
                )
            cohoscore_path = out_dir / "cohoscore.npy"
            np.save(cohoscore_path, scores)
        except Exception as e:
            # Scoring is best-effort: plots are still produced without it.
            log_callback(f"⚠️ CohoScore computation failed: {e}")
            scores = None

    if scores is not None:
        valid = np.where(~np.isnan(scores))[0]
        if valid.size > 0:
            # Ascending sort: the lowest scores come first.
            sorted_idx = valid[np.argsort(scores[valid])]
            top_ids = sorted_idx[: min(top_k, len(sorted_idx))]
            top_ids_path = out_dir / "cohospace_top_ids.npy"
            np.save(top_ids_path, top_ids)
        else:
            log_callback("⚠️ CohoScore: all values are NaN.")

    if view in {"both", "single"} and enable_score and scores is not None and use_best:
        valid = np.where(~np.isnan(scores))[0]
        if valid.size > 0:
            best_id = int(valid[np.argmin(scores[valid])])
            neuron_id = best_id
            log_callback(f"🎯 CohoSpace neuron auto-selected by best CohoScore: {neuron_id}")

    decode_meta = self._load_cache_meta(
        self._stage_cache_path(self._results_dir(state) / "CohoMap")
    )
    stage_hash = self._hash_obj(
        {
            "persistence": self._hash_file(persistence_path),
            "decode_hash": decode_meta.get("decode_hash"),
            "params": params,
        }
    )
    meta_path = self._stage_cache_path(out_dir)
    traj_path = out_dir / "cohospace_trajectory.png"
    neuron_path = out_dir / f"cohospace_neuron_{neuron_id}.png"
    pop_path = out_dir / "cohospace_population.png"

    required = [traj_path]
    if view in {"both", "population"}:
        required.append(pop_path)
    if view in {"both", "single"}:
        required.append(neuron_path)

    if self._stage_cache_hit(out_dir, stage_hash, required):
        log_callback("♻️ Using cached CohoSpace plots.")
        artifacts = {"trajectory": traj_path}
        if view in {"both", "single"}:
            artifacts["neuron"] = neuron_path
        if view in {"both", "population"}:
            artifacts["population"] = pop_path
        return artifacts

    log_callback("Plotting cohomology space trajectory...")
    if dim_mode == "1d":
        traj_cfg = PlotConfigs.cohospace_trajectory_1d(show=False, save_path=str(traj_path))
        plot_cohospace_trajectory_1d(
            coords=coords2,
            times=None,
            subsample=subsample,
            config=traj_cfg,
        )
    else:
        traj_cfg = PlotConfigs.cohospace_trajectory_2d(show=False, save_path=str(traj_path))
        plot_cohospace_trajectory_2d(
            coords=coords2,
            times=None,
            subsample=subsample,
            config=traj_cfg,
        )
    artifacts["trajectory"] = traj_path

    if view in {"both", "single"}:
        log_callback(f"Plotting neuron {neuron_id}...")
        if unfold == "skew" and dim_mode != "1d":
            plot_cohospace_neuron_skewed(
                coords=coordsbox2,
                activity=activity,
                neuron_id=neuron_id,
                mode=mode,
                top_percent=top_percent,
                times=times,
                save_path=str(neuron_path),
                show=False,
                show_grid=skew_show_grid,
                n_tiles=skew_tiles,
            )
        elif dim_mode == "1d":
            neuron_cfg = PlotConfigs.cohospace_neuron_1d(show=False, save_path=str(neuron_path))
            plot_cohospace_neuron_1d(
                coords=coordsbox2,
                activity=activity,
                neuron_id=neuron_id,
                mode=mode,
                top_percent=top_percent,
                times=times,
                config=neuron_cfg,
            )
        else:
            neuron_cfg = PlotConfigs.cohospace_neuron_2d(show=False, save_path=str(neuron_path))
            plot_cohospace_neuron_2d(
                coords=coordsbox2,
                activity=activity,
                neuron_id=neuron_id,
                mode=mode,
                top_percent=top_percent,
                times=times,
                config=neuron_cfg,
            )
        artifacts["neuron"] = neuron_path

    if view in {"both", "population"}:
        log_callback("Plotting population activity...")
        if enable_score and top_ids is not None:
            neuron_ids = [int(i) for i in top_ids.tolist()]
            log_callback(f"CohoSpace: aggregating top-{len(neuron_ids)} neurons by CohoScore.")
        else:
            neuron_ids = list(range(activity.shape[1]))
        if unfold == "skew" and dim_mode != "1d":
            plot_cohospace_population_skewed(
                coords=coords2,
                activity=activity,
                neuron_ids=neuron_ids,
                mode=mode,
                top_percent=top_percent,
                times=times,
                save_path=str(pop_path),
                show=False,
                show_grid=skew_show_grid,
                n_tiles=skew_tiles,
            )
        elif dim_mode == "1d":
            pop_cfg = PlotConfigs.cohospace_population_1d(show=False, save_path=str(pop_path))
            plot_cohospace_population_1d(
                coords=coords2,
                activity=activity,
                neuron_ids=neuron_ids,
                mode=mode,
                top_percent=top_percent,
                times=times,
                config=pop_cfg,
            )
        else:
            pop_cfg = PlotConfigs.cohospace_population_2d(show=False, save_path=str(pop_path))
            plot_cohospace_population_2d(
                coords=coords2,
                activity=activity,
                neuron_ids=neuron_ids,
                mode=mode,
                top_percent=top_percent,
                times=times,
                config=pop_cfg,
            )
        artifacts["population"] = pop_path

    self._write_cache_meta(meta_path, {"hash": stage_hash})
    return artifacts
1430
+
1431
def _run_fr(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Compute and save the firing-rate heatmap to ``FR/fr_heatmap.png``.

    Parameters read from ``state.analysis_params``: ``neuron_range``,
    ``time_range``, ``normalize`` (``"none"``, ``""`` and ``None`` disable
    normalization) and ``mode`` (``"spike"`` builds an event-based matrix,
    anything else uses the stored embedded data). The plot is reused when the
    stage cache reports a hit.

    Args:
        asa_data: Loaded ASA input; used to build the spike matrix in
            ``"spike"`` mode.
        state: Workflow state holding analysis parameters and results location.
        log_callback: Callable that receives progress messages.

    Returns:
        Mapping with the single key ``"fr_heatmap"``.

    Raises:
        ProcessingError: If no activity matrix is available.
    """
    from canns.analyzer.data.asa import compute_fr_heatmap_matrix, save_fr_heatmap_png
    from canns.analyzer.visualization import PlotConfigs

    out_dir = self._results_dir(state) / "FR"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Guard against unset parameters, as the other stage runners do.
    params = state.analysis_params or {}
    neuron_range = params.get("neuron_range", None)
    time_range = params.get("time_range", None)
    normalize = params.get("normalize", "zscore_per_neuron")
    if normalize in {"none", "", None}:
        normalize = None
    mode = params.get("mode", "fr")

    spike_data = (
        self._build_spike_matrix_from_events(asa_data) if mode == "spike" else self._embed_data
    )
    if spike_data is None:
        raise ProcessingError("No spike data available for FR.")

    out_path = out_dir / "fr_heatmap.png"
    stage_hash = self._hash_obj(
        {
            "embed_hash": self._embed_hash,
            "params": params,
        }
    )
    if self._stage_cache_hit(out_dir, stage_hash, [out_path]):
        log_callback("♻️ Using cached FR heatmap.")
        return {"fr_heatmap": out_path}

    log_callback("Computing firing rate heatmap...")
    fr_matrix = compute_fr_heatmap_matrix(
        spike_data,
        neuron_range=neuron_range,
        time_range=time_range,
        normalize=normalize,
    )

    config = PlotConfigs.fr_heatmap(show=False, save_path=str(out_path))
    save_fr_heatmap_png(fr_matrix, config=config, dpi=200)

    self._write_cache_meta(self._stage_cache_path(out_dir), {"hash": stage_hash})
    return {"fr_heatmap": out_path}
1480
+
1481
def _run_frm(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Compute and plot the firing-rate map of one neuron.

    The figure is written to ``FRM/frm_neuron_<neuron_id>.png`` and reused
    when the stage cache reports a hit.

    Parameters read from ``state.analysis_params``: ``neuron_id``,
    ``bin_size``, ``min_occupancy``, ``smoothing``, ``smooth_sigma`` and
    ``mode`` (``"spike"`` builds an event-based matrix, anything else uses the
    stored embedded data).

    Args:
        asa_data: Loaded ASA input; supplies ``x``/``y`` when no aligned
            position is stored on the instance, and spike events in
            ``"spike"`` mode.
        state: Workflow state holding analysis parameters and results location.
        log_callback: Callable that receives progress messages.

    Returns:
        Mapping with the single key ``"frm"``.

    Raises:
        ProcessingError: If no activity matrix or no position data is
            available.
    """
    from canns.analyzer.data.asa import compute_frm, plot_frm
    from canns.analyzer.visualization import PlotConfigs

    out_dir = self._results_dir(state) / "FRM"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Guard against unset parameters, as the other stage runners do.
    params = state.analysis_params or {}
    neuron_id = int(params.get("neuron_id", 0))
    bins = int(params.get("bin_size", 50))
    min_occupancy = int(params.get("min_occupancy", 1))
    smoothing = bool(params.get("smoothing", False))
    smooth_sigma = float(params.get("smooth_sigma", 2.0))
    mode = str(params.get("mode", "fr"))

    spike_data = (
        self._build_spike_matrix_from_events(asa_data) if mode == "spike" else self._embed_data
    )
    if spike_data is None:
        raise ProcessingError("No spike data available for FRM.")

    pos = self._aligned_pos if self._aligned_pos is not None else asa_data
    x_raw = pos.get("x")
    y_raw = pos.get("y")
    # Check before converting: np.asarray(None) yields a 0-d object array,
    # never None, so testing the converted arrays could not detect a miss.
    if x_raw is None or y_raw is None:
        raise ProcessingError("Position data (x,y) is required for FRM.")
    x = np.asarray(x_raw)
    y = np.asarray(y_raw)

    out_path = out_dir / f"frm_neuron_{neuron_id}.png"
    stage_hash = self._hash_obj(
        {
            "embed_hash": self._embed_hash,
            "params": params,
        }
    )
    if self._stage_cache_hit(out_dir, stage_hash, [out_path]):
        log_callback("♻️ Using cached FRM.")
        return {"frm": out_path}

    log_callback(f"Computing firing rate map for neuron {neuron_id}...")
    frm_result = compute_frm(
        spike_data,
        x,
        y,
        neuron_id=neuron_id,
        bins=max(1, bins),
        min_occupancy=min_occupancy,
        smoothing=smoothing,
        sigma=smooth_sigma,
        nan_for_empty=True,
    )

    config = PlotConfigs.frm(show=False, save_path=str(out_path))
    plot_frm(frm_result.frm, config=config, dpi=200)

    self._write_cache_meta(self._stage_cache_path(out_dir), {"hash": stage_hash})
    return {"frm": out_path}
1541
+
1542
+ def _run_gridscore(
1543
+ self,
1544
+ asa_data: dict[str, Any],
1545
+ state: WorkflowState,
1546
+ log_callback,
1547
+ progress_callback: Callable[[int], None],
1548
+ ) -> dict[str, Path]:
1549
+ """Run batch gridness score computation."""
1550
+ import csv
1551
+
1552
+ from canns.analyzer.data.cell_classification import (
1553
+ GridnessAnalyzer,
1554
+ compute_2d_autocorrelation,
1555
+ )
1556
+
1557
+ params = state.analysis_params or {}
1558
+ gs_params = params.get("gridscore") if isinstance(params.get("gridscore"), dict) else {}
1559
+
1560
+ def _param(key: str, default: Any) -> Any:
1561
+ if key in params:
1562
+ return params.get(key, default)
1563
+ if gs_params and key in gs_params:
1564
+ return gs_params.get(key, default)
1565
+ return default
1566
+
1567
+ n_start = int(_param("neuron_start", 0))
1568
+ n_end = int(_param("neuron_end", 0))
1569
+ bins = int(_param("bins", 50))
1570
+ min_occ = int(_param("min_occupancy", 1))
1571
+ smoothing = bool(_param("smoothing", False))
1572
+ sigma = float(_param("sigma", 1.0))
1573
+ overlap = float(_param("overlap", 0.8))
1574
+ mode = str(_param("mode", "fr")).strip().lower()
1575
+ score_thr = float(_param("score_thr", 0.3))
1576
+
1577
+ if mode not in {"fr", "spike"}:
1578
+ mode = "fr"
1579
+ bins = max(5, bins)
1580
+ overlap = max(0.1, min(1.0, overlap))
1581
+
1582
+ pos = self._aligned_pos if self._aligned_pos is not None else asa_data
1583
+ if "x" not in pos or "y" not in pos or pos["x"] is None or pos["y"] is None:
1584
+ raise ProcessingError("GridScore requires position data (x, y).")
1585
+
1586
+ if mode == "fr":
1587
+ if isinstance(self._embed_data, np.ndarray) and self._embed_data.ndim == 2:
1588
+ activity_full = self._embed_data
1589
+ log_callback("GridScore[FR]: using preprocessed spike matrix.")
1590
+ else:
1591
+ log_callback("GridScore[FR]: no preprocessed matrix, falling back to spike mode.")
1592
+ activity_full = self._build_spike_matrix_from_events(asa_data)
1593
+ else:
1594
+ activity_full = self._build_spike_matrix_from_events(asa_data)
1595
+ log_callback("GridScore[spike]: using event-based spike matrix.")
1596
+
1597
+ sp = np.asarray(activity_full)
1598
+ if sp.ndim != 2:
1599
+ raise ProcessingError(f"GridScore expects 2D spike matrix, got ndim={sp.ndim}.")
1600
+
1601
+ x = np.asarray(pos["x"]).ravel()
1602
+ y = np.asarray(pos["y"]).ravel()
1603
+ m = min(len(x), len(y), sp.shape[0])
1604
+ x = x[:m]
1605
+ y = y[:m]
1606
+ sp = sp[:m, :]
1607
+
1608
+ finite = np.isfinite(x) & np.isfinite(y)
1609
+ if not np.all(finite):
1610
+ x = x[finite]
1611
+ y = y[finite]
1612
+ sp = sp[finite, :]
1613
+
1614
+ total_neurons = sp.shape[1]
1615
+ if n_end <= 0 or n_end > total_neurons:
1616
+ n_end = total_neurons
1617
+ n_start = max(0, min(n_start, total_neurons - 1))
1618
+ n_end = max(n_start + 1, n_end)
1619
+
1620
+ xmin, xmax = float(np.min(x)), float(np.max(x))
1621
+ ymin, ymax = float(np.min(y)), float(np.max(y))
1622
+ eps = 1e-12
1623
+ if xmax - xmin < eps:
1624
+ xmax = xmin + 1.0
1625
+ if ymax - ymin < eps:
1626
+ ymax = ymin + 1.0
1627
+
1628
+ ix = np.floor((x - xmin) / (xmax - xmin + eps) * bins).astype(int)
1629
+ iy = np.floor((y - ymin) / (ymax - ymin + eps) * bins).astype(int)
1630
+ ix = np.clip(ix, 0, bins - 1)
1631
+ iy = np.clip(iy, 0, bins - 1)
1632
+ flat = (iy * bins + ix).astype(int)
1633
+
1634
+ occ = np.bincount(flat, minlength=bins * bins).astype(float).reshape(bins, bins)
1635
+ occ_mask = occ >= float(min_occ)
1636
+
1637
+ gaussian_filter = None
1638
+ if smoothing and sigma > 0:
1639
+ try:
1640
+ from scipy.ndimage import gaussian_filter as _gaussian_filter
1641
+ except Exception as e: # pragma: no cover - optional dependency
1642
+ raise ProcessingError(f"GridScore requires scipy for smoothing: {e}") from e
1643
+ gaussian_filter = _gaussian_filter
1644
+
1645
+ def _rate_map_for_neuron(col: int) -> np.ndarray:
1646
+ weights = sp[:, col].astype(float, copy=False)
1647
+ spike_map = (
1648
+ np.bincount(flat, weights=weights, minlength=bins * bins)
1649
+ .astype(float)
1650
+ .reshape(bins, bins)
1651
+ )
1652
+ rate_map = np.zeros_like(spike_map)
1653
+ rate_map[occ_mask] = spike_map[occ_mask] / occ[occ_mask]
1654
+ if gaussian_filter is not None:
1655
+ rate_map = gaussian_filter(rate_map, sigma=float(sigma), mode="nearest")
1656
+ return rate_map
1657
+
1658
+ analyzer = GridnessAnalyzer()
1659
+ n_sel = n_end - n_start
1660
+ scores = np.full((n_sel,), np.nan, dtype=float)
1661
+ spacing = np.full((n_sel, 3), np.nan, dtype=float)
1662
+ orientation = np.full((n_sel, 3), np.nan, dtype=float)
1663
+ ellipse = np.full((n_sel, 5), np.nan, dtype=float)
1664
+ ellipse_theta_deg = np.full((n_sel,), np.nan, dtype=float)
1665
+ center_radius = np.full((n_sel,), np.nan, dtype=float)
1666
+ optimal_radius = np.full((n_sel,), np.nan, dtype=float)
1667
+
1668
+ log_callback(
1669
+ f"GridScore: computing neurons [{n_start}, {n_end}) with bins={bins}, overlap={overlap:.2f}."
1670
+ )
1671
+
1672
+ for j, nid in enumerate(range(n_start, n_end)):
1673
+ rate_map = _rate_map_for_neuron(nid)
1674
+ autocorr = compute_2d_autocorrelation(rate_map, overlap=overlap)
1675
+ result = analyzer.compute_gridness_score(autocorr)
1676
+
1677
+ scores[j] = float(result.score)
1678
+ if result.spacing is not None and np.size(result.spacing) >= 3:
1679
+ spacing[j, :] = np.asarray(result.spacing).ravel()[:3]
1680
+ if result.orientation is not None and np.size(result.orientation) >= 3:
1681
+ orientation[j, :] = np.asarray(result.orientation).ravel()[:3]
1682
+ if result.ellipse is not None and np.size(result.ellipse) >= 5:
1683
+ ellipse[j, :] = np.asarray(result.ellipse).ravel()[:5]
1684
+ ellipse_theta_deg[j] = float(result.ellipse_theta_deg)
1685
+ center_radius[j] = float(result.center_radius)
1686
+ optimal_radius[j] = float(result.optimal_radius)
1687
+
1688
+ if (j + 1) % max(1, n_sel // 10) == 0:
1689
+ progress = 60 + int(35 * (j + 1) / max(1, n_sel))
1690
+ progress_callback(min(98, progress))
1691
+
1692
+ out_dir = self._results_dir(state) / "GRIDScore"
1693
+ out_dir.mkdir(parents=True, exist_ok=True)
1694
+
1695
+ gridscore_npz = out_dir / "gridscore.npz"
1696
+ np.savez_compressed(
1697
+ str(gridscore_npz),
1698
+ neuron_start=n_start,
1699
+ neuron_end=n_end,
1700
+ neuron_ids=np.arange(n_start, n_end, dtype=int),
1701
+ bins=bins,
1702
+ min_occupancy=min_occ,
1703
+ smoothing=smoothing,
1704
+ sigma=sigma,
1705
+ overlap=overlap,
1706
+ mode=mode,
1707
+ score=scores,
1708
+ grid_score=scores,
1709
+ spacing=spacing,
1710
+ orientation=orientation,
1711
+ ellipse=ellipse,
1712
+ ellipse_theta_deg=ellipse_theta_deg,
1713
+ center_radius=center_radius,
1714
+ optimal_radius=optimal_radius,
1715
+ )
1716
+
1717
+ gridscore_csv = out_dir / "gridscore.csv"
1718
+ with gridscore_csv.open("w", newline="", encoding="utf-8") as f:
1719
+ writer = csv.writer(f)
1720
+ writer.writerow(
1721
+ [
1722
+ "neuron_id",
1723
+ "grid_score",
1724
+ "spacing1",
1725
+ "spacing2",
1726
+ "spacing3",
1727
+ "orient1_deg",
1728
+ "orient2_deg",
1729
+ "orient3_deg",
1730
+ "ellipse_cx",
1731
+ "ellipse_cy",
1732
+ "ellipse_rx",
1733
+ "ellipse_ry",
1734
+ "ellipse_theta_deg",
1735
+ "center_radius",
1736
+ "optimal_radius",
1737
+ ]
1738
+ )
1739
+ for j, nid in enumerate(range(n_start, n_end)):
1740
+ writer.writerow(
1741
+ [
1742
+ nid,
1743
+ scores[j],
1744
+ spacing[j, 0],
1745
+ spacing[j, 1],
1746
+ spacing[j, 2],
1747
+ orientation[j, 0],
1748
+ orientation[j, 1],
1749
+ orientation[j, 2],
1750
+ ellipse[j, 0],
1751
+ ellipse[j, 1],
1752
+ ellipse[j, 2],
1753
+ ellipse[j, 3],
1754
+ ellipse_theta_deg[j],
1755
+ center_radius[j],
1756
+ optimal_radius[j],
1757
+ ]
1758
+ )
1759
+
1760
+ gridscore_png = out_dir / "gridscore_summary.png"
1761
+ try:
1762
+ import matplotlib.pyplot as plt
1763
+
1764
+ fig, ax = plt.subplots(1, 1, figsize=(10, 3.5))
1765
+ valid = scores[np.isfinite(scores)]
1766
+ ax.hist(valid, bins=30)
1767
+ ax.axvline(score_thr, linestyle="--")
1768
+ ax.set_title("Grid score distribution")
1769
+ ax.set_xlabel("grid score")
1770
+ ax.set_ylabel("count")
1771
+ fig.tight_layout()
1772
+ fig.savefig(gridscore_png, dpi=200)
1773
+ plt.close(fig)
1774
+ except Exception as e:
1775
+ log_callback(f"Warning: failed to save GridScore summary png: {e}")
1776
+
1777
+ log_callback(f"GridScore done. Saved: {gridscore_npz} , {gridscore_csv}")
1778
+ return {
1779
+ "gridscore_npz": gridscore_npz,
1780
+ "gridscore_csv": gridscore_csv,
1781
+ "gridscore_png": gridscore_png,
1782
+ }
1783
+
1784
def _run_gridscore_inspect(
    self,
    asa_data: dict[str, Any],
    state: WorkflowState,
    log_callback,
    progress_callback: Callable[[int], None],
) -> dict[str, Path]:
    """Run single-neuron GridScore inspection.

    Builds an occupancy-normalised 2D rate map for one neuron, computes its
    2D autocorrelation and gridness score, then writes an ``.npz`` with the
    metrics and a ``.png`` diagnostic figure into ``<results_dir>/GRIDScore``.

    Parameters are read from ``state.analysis_params``; a top-level key wins
    over the same key inside the nested ``"gridscore"`` dict:

    * ``neuron_id`` (alias ``neuron``): column to inspect, default 0;
      clamped into the valid column range.
    * ``bins``: bins per axis, default 50, at least 5.
    * ``min_occupancy``: minimum sample count for a bin to be used, default 1.
    * ``smoothing`` / ``sigma``: optional Gaussian smoothing of the rate map.
    * ``overlap``: forwarded to ``compute_2d_autocorrelation``, clamped to
      [0.1, 1.0].
    * ``mode``: ``"fr"`` or ``"spike"``; anything else is treated as ``"fr"``.

    Args:
        asa_data: Loaded data dict; used for positions when no aligned
            positions are cached on ``self``, and for building the
            event-based spike matrix.
        state: Workflow state supplying ``analysis_params`` and the results
            directory.
        log_callback: Callable receiving human-readable log strings.
        progress_callback: Callable receiving an integer progress value;
            called once with 100 when the outputs are written.

    Returns:
        Dict with ``"gridscore_neuron_npz"`` and ``"gridscore_neuron_png"``
        paths.

    Raises:
        ProcessingError: If position data is missing, the activity matrix is
            not 2D or has no neurons, no sample has a finite position, or
            smoothing is requested but scipy cannot be imported.
    """
    from canns.analyzer.data.cell_classification import (
        GridnessAnalyzer,
        compute_2d_autocorrelation,
        plot_gridness_analysis,
    )

    params = state.analysis_params or {}
    nested = params.get("gridscore")
    gs_params = nested if isinstance(nested, dict) else {}

    def _param(key: str, default: Any) -> Any:
        # Top-level parameters take precedence over the nested "gridscore" dict.
        if key in params:
            return params[key]
        if key in gs_params:
            return gs_params[key]
        return default

    neuron_id = int(_param("neuron_id", _param("neuron", 0)))
    bins = int(_param("bins", 50))
    min_occ = int(_param("min_occupancy", 1))
    smoothing = bool(_param("smoothing", False))
    sigma = float(_param("sigma", 1.0))
    overlap = float(_param("overlap", 0.8))
    mode = str(_param("mode", "fr")).strip().lower()

    if mode not in {"fr", "spike"}:
        mode = "fr"
    bins = max(5, bins)
    overlap = max(0.1, min(1.0, overlap))

    pos = self._aligned_pos if self._aligned_pos is not None else asa_data
    if "x" not in pos or "y" not in pos or pos["x"] is None or pos["y"] is None:
        raise ProcessingError("GridScore inspect requires position data (x, y).")

    # NOTE: `mode` keeps the requested value even when the FR branch falls
    # back to the event-based matrix; the fallback is only reported via the log.
    if mode == "fr":
        if isinstance(self._embed_data, np.ndarray) and self._embed_data.ndim == 2:
            activity_full = self._embed_data
            log_callback("GridScoreInspect[FR]: using preprocessed spike matrix.")
        else:
            log_callback(
                "GridScoreInspect[FR]: no preprocessed matrix, falling back to spike mode."
            )
            activity_full = self._build_spike_matrix_from_events(asa_data)
    else:
        activity_full = self._build_spike_matrix_from_events(asa_data)
        log_callback("GridScoreInspect[spike]: using event-based spike matrix.")

    sp = np.asarray(activity_full)
    if sp.ndim != 2:
        raise ProcessingError(f"GridScore inspect expects 2D spike matrix, got ndim={sp.ndim}.")

    # Truncate positions and activity to their common length before masking.
    x = np.asarray(pos["x"]).ravel()
    y = np.asarray(pos["y"]).ravel()
    m = min(len(x), len(y), sp.shape[0])
    x = x[:m]
    y = y[:m]
    sp = sp[:m, :]

    # Drop samples whose position is NaN/inf.
    finite = np.isfinite(x) & np.isfinite(y)
    if not np.all(finite):
        x = x[finite]
        y = y[finite]
        sp = sp[finite, :]

    # Fail with a domain error instead of numpy's "zero-size array" ValueError.
    if x.size == 0:
        raise ProcessingError("GridScore inspect found no samples with finite position (x, y).")

    total_neurons = sp.shape[1]
    if total_neurons == 0:
        raise ProcessingError("GridScore inspect requires at least one neuron in the spike matrix.")

    requested_id = neuron_id
    neuron_id = max(0, min(int(neuron_id), total_neurons - 1))
    if neuron_id != requested_id:
        log_callback(
            f"GridScoreInspect: neuron_id {requested_id} out of range "
            f"[0, {total_neurons - 1}], using {neuron_id}."
        )

    xmin, xmax = float(np.min(x)), float(np.max(x))
    ymin, ymax = float(np.min(y)), float(np.max(y))
    eps = 1e-12
    # Degenerate (constant) coordinates get a unit extent to avoid dividing by ~0.
    if xmax - xmin < eps:
        xmax = xmin + 1.0
    if ymax - ymin < eps:
        ymax = ymin + 1.0

    # Map each sample to a flat bin index (row-major: iy * bins + ix).
    ix = np.floor((x - xmin) / (xmax - xmin + eps) * bins).astype(int)
    iy = np.floor((y - ymin) / (ymax - ymin + eps) * bins).astype(int)
    ix = np.clip(ix, 0, bins - 1)
    iy = np.clip(iy, 0, bins - 1)
    flat = (iy * bins + ix).astype(int)

    occ = np.bincount(flat, minlength=bins * bins).astype(float).reshape(bins, bins)
    # `occ > 0` guards against 0/0 -> NaN when min_occupancy is 0 or negative.
    occ_mask = (occ >= float(min_occ)) & (occ > 0)

    weights = sp[:, neuron_id].astype(float, copy=False)
    spike_map = (
        np.bincount(flat, weights=weights, minlength=bins * bins)
        .astype(float)
        .reshape(bins, bins)
    )
    # Bins below the occupancy threshold stay at zero.
    rate_map = np.zeros_like(spike_map)
    rate_map[occ_mask] = spike_map[occ_mask] / occ[occ_mask]
    if smoothing and sigma > 0:
        try:
            from scipy.ndimage import gaussian_filter as _gaussian_filter
        except Exception as e:  # pragma: no cover - optional dependency
            raise ProcessingError(f"GridScore requires scipy for smoothing: {e}") from e
        rate_map = _gaussian_filter(rate_map, sigma=float(sigma), mode="nearest")

    autocorr = compute_2d_autocorrelation(rate_map, overlap=overlap)
    analyzer = GridnessAnalyzer()
    result = analyzer.compute_gridness_score(autocorr)

    out_dir = self._results_dir(state) / "GRIDScore"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Missing array-valued metrics are stored as NaN placeholders.
    spacing = np.asarray(result.spacing) if result.spacing is not None else np.full((3,), np.nan)
    orientation = (
        np.asarray(result.orientation)
        if result.orientation is not None
        else np.full((3,), np.nan)
    )
    ellipse = np.asarray(result.ellipse) if result.ellipse is not None else np.full((5,), np.nan)

    gridscore_neuron_npz = out_dir / f"gridscore_neuron_{neuron_id}.npz"
    np.savez_compressed(
        str(gridscore_neuron_npz),
        neuron_id=neuron_id,
        bins=bins,
        min_occupancy=min_occ,
        smoothing=smoothing,
        sigma=sigma,
        overlap=overlap,
        mode=mode,
        grid_score=float(result.score),
        spacing=spacing,
        orientation=orientation,
        ellipse=ellipse,
        ellipse_theta_deg=float(getattr(result, "ellipse_theta_deg", np.nan)),
        center_radius=float(getattr(result, "center_radius", np.nan)),
        optimal_radius=float(getattr(result, "optimal_radius", np.nan)),
    )

    gridscore_neuron_png = out_dir / f"gridscore_neuron_{neuron_id}.png"
    plot_gridness_analysis(
        rate_map=rate_map,
        autocorr=autocorr,
        result=result,
        save_path=str(gridscore_neuron_png),
        show=False,
        title=f"GridScore neuron {neuron_id}",
    )
    progress_callback(100)

    return {
        "gridscore_neuron_npz": gridscore_neuron_npz,
        "gridscore_neuron_png": gridscore_neuron_png,
    }