canns 0.12.7__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. canns/analyzer/data/__init__.py +3 -11
  2. canns/analyzer/data/asa/__init__.py +84 -0
  3. canns/analyzer/data/asa/cohospace.py +905 -0
  4. canns/analyzer/data/asa/config.py +246 -0
  5. canns/analyzer/data/asa/decode.py +445 -0
  6. canns/analyzer/data/asa/embedding.py +269 -0
  7. canns/analyzer/data/asa/filters.py +208 -0
  8. canns/analyzer/data/{cann1d.py → asa/fly_roi.py} +98 -45
  9. canns/analyzer/data/asa/fr.py +431 -0
  10. canns/analyzer/data/asa/path.py +389 -0
  11. canns/analyzer/data/asa/plotting.py +1287 -0
  12. canns/analyzer/data/asa/tda.py +901 -0
  13. canns/analyzer/visualization/core/backend.py +1 -1
  14. canns/analyzer/visualization/core/config.py +77 -0
  15. canns/analyzer/visualization/core/rendering.py +10 -6
  16. canns/analyzer/visualization/energy_plots.py +22 -8
  17. canns/analyzer/visualization/spatial_plots.py +31 -11
  18. canns/analyzer/visualization/theta_sweep_plots.py +15 -6
  19. canns/pipeline/__init__.py +4 -8
  20. canns/pipeline/asa/__init__.py +21 -0
  21. canns/pipeline/asa/__main__.py +11 -0
  22. canns/pipeline/asa/app.py +1000 -0
  23. canns/pipeline/asa/runner.py +1095 -0
  24. canns/pipeline/asa/screens.py +215 -0
  25. canns/pipeline/asa/state.py +248 -0
  26. canns/pipeline/asa/styles.tcss +221 -0
  27. canns/pipeline/asa/widgets.py +233 -0
  28. canns/pipeline/gallery/__init__.py +7 -0
  29. canns/task/open_loop_navigation.py +3 -1
  30. {canns-0.12.7.dist-info → canns-0.13.1.dist-info}/METADATA +6 -3
  31. {canns-0.12.7.dist-info → canns-0.13.1.dist-info}/RECORD +34 -17
  32. {canns-0.12.7.dist-info → canns-0.13.1.dist-info}/entry_points.txt +1 -0
  33. canns/analyzer/data/cann2d.py +0 -2565
  34. canns/pipeline/theta_sweep.py +0 -573
  35. {canns-0.12.7.dist-info → canns-0.13.1.dist-info}/WHEEL +0 -0
  36. {canns-0.12.7.dist-info → canns-0.13.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1095 @@
1
+ """Pipeline execution wrapper for ASA TUI.
2
+
3
+ This module provides async pipeline execution that integrates with the existing
4
+ canns.analyzer.data.asa module. It wraps the analysis functions and provides
5
+ progress callbacks for the TUI.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import hashlib
11
+ import json
12
+ import time
13
+ from collections.abc import Callable
14
+ from dataclasses import dataclass
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ import numpy as np
19
+
20
+ from .state import WorkflowState, resolve_path
21
+
22
+
23
@dataclass()
class PipelineResult:
    """Outcome of one pipeline stage (preprocessing or analysis)."""

    # True when the stage completed without raising.
    success: bool
    # Named output files produced (or reused from cache) by the stage.
    artifacts: dict[str, Path]
    # Short human-readable description of the run.
    summary: str
    # Error message when the stage failed; None on success.
    error: str | None = None
    # Wall-clock duration of the stage, in seconds.
    elapsed_time: float = 0.0
32
+
33
+
34
class ProcessingError(RuntimeError):
    """Signal that a pipeline stage cannot complete (unset inputs, missing prerequisites, unknown modes)."""
38
+
39
+
40
+ class PipelineRunner:
41
+ """Async pipeline execution wrapper."""
42
+
43
def __init__(self):
    """Create a runner with no loaded input and no preprocessing results."""
    # Arrays loaded from the input file(s).
    self._asa_data: dict[str, Any] | None = None
    # Output of the preprocessing stage (embedded matrix, or raw spike data).
    self._embed_data: np.ndarray | None = None
    # x/y/t positions returned alongside the embedding, when available.
    self._aligned_pos: dict[str, np.ndarray] | None = None
    # Cache keys: one for the input files, one for the preprocessing result.
    self._input_hash: str | None = None
    self._embed_hash: str | None = None
    # Becomes True after the Matplotlib backend switch has been attempted.
    self._mpl_ready: bool = False
51
+
52
def has_preprocessed_data(self) -> bool:
    """Return True once preprocessing has stored data for the analysis stages."""
    embedded = self._embed_data
    return embedded is not None
55
+
56
def reset_input(self) -> None:
    """Forget loaded input data, preprocessing output and their cache keys.

    Call this whenever the input files change. The Matplotlib-backend flag
    is not touched.
    """
    for attr_name in (
        "_asa_data",
        "_embed_data",
        "_aligned_pos",
        "_input_hash",
        "_embed_hash",
    ):
        setattr(self, attr_name, None)
63
+
64
def _json_safe(self, obj: Any) -> Any:
    """Convert *obj* into a structure that ``json.dumps`` can encode.

    Conversions:
    - ``Path`` -> ``str``
    - tuples and lists -> lists (elements converted recursively)
    - dicts -> dicts with ``str`` keys (values converted recursively)
    - sets/frozensets -> lists sorted by ``repr`` so the encoding, and any
      hash built from it, is deterministic
    - NumPy arrays with more than one element -> nested lists via
      ``tolist()``; ``str()`` would truncate large arrays and round floats,
      letting different arrays produce the same cache hash
    - anything else exposing a callable ``.item()`` (NumPy scalars, size-1
      arrays) -> the Python scalar, or ``str(obj)`` if ``.item()`` fails
    Other objects are returned unchanged.
    """
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, (tuple, list)):
        return [self._json_safe(v) for v in obj]
    if isinstance(obj, dict):
        return {str(k): self._json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (set, frozenset)):
        return sorted((self._json_safe(v) for v in obj), key=repr)
    if isinstance(obj, np.ndarray) and obj.size > 1:
        # tolist() yields Python scalars; recurse for object arrays (e.g. Paths).
        return self._json_safe(obj.tolist())
    if hasattr(obj, "item") and callable(obj.item):
        try:
            return obj.item()
        except Exception:
            return str(obj)
    return obj
80
+
81
def _hash_bytes(self, data: bytes) -> str:
    """Return the hex md5 digest of *data* (used to build cache keys and dataset ids)."""
    digest = hashlib.md5()
    digest.update(data)
    return digest.hexdigest()
83
+
84
def _hash_file(self, path: Path) -> str:
    """Return the hex md5 digest of the file at *path*, read in 1 MiB chunks."""
    hasher = hashlib.md5()
    chunk_size = 1024 * 1024
    with path.open("rb") as handle:
        while True:
            piece = handle.read(chunk_size)
            if not piece:
                break
            hasher.update(piece)
    return hasher.hexdigest()
91
+
92
def _hash_obj(self, obj: Any) -> str:
    """Hash *obj* by its canonical JSON form (``_json_safe`` applied, keys sorted, ASCII)."""
    canonical = json.dumps(self._json_safe(obj), sort_keys=True, ensure_ascii=True)
    return self._hash_bytes(canonical.encode("utf-8"))
97
+
98
def _ensure_matplotlib_backend(self) -> None:
    """Switch Matplotlib to the non-interactive Agg backend, at most once per runner.

    This is best effort: any failure (including Matplotlib not being
    importable) is ignored, and the runner is marked ready regardless, so
    the switch is never retried.
    """
    if not self._mpl_ready:
        try:
            import os

            os.environ.setdefault("MPLBACKEND", "Agg")
            import matplotlib

            try:
                matplotlib.use("Agg", force=True)
            except TypeError:
                # Fallback for a use() signature without the ``force`` keyword.
                matplotlib.use("Agg")
        except Exception:
            pass
        self._mpl_ready = True
115
+
116
def _cache_dir(self, state: WorkflowState) -> Path:
    """Return the hidden ``.asa_cache`` folder inside the dataset's results directory."""
    results = self._results_dir(state)
    return results.joinpath(".asa_cache")
118
+
119
def _results_dir(self, state: WorkflowState) -> Path:
    """Return ``<workdir>/Results/<dataset_id>``; the directory is not created here."""
    return state.workdir.joinpath("Results", self._dataset_id(state))
123
+
124
def results_dir(self, state: WorkflowState) -> Path:
    """Return the directory where this dataset's outputs are written.

    Public wrapper around ``_results_dir``; nothing is created on disk.
    """
    resolved = self._results_dir(state)
    return resolved
127
+
128
def _dataset_id(self, state: WorkflowState) -> str:
    """Build a stable dataset id: input file stem(s) plus the first 8 chars of the input hash.

    Uses the cached input hash when present, otherwise computes it; if that
    fails the prefix is taken from the literal ``"unknown"``. Unset paths
    fall back to the names ``asa``, ``neuron`` or ``traj``; unrecognised
    modes use the mode itself as the label.
    """
    try:
        input_hash = self._input_hash or self._compute_input_hash(state)
    except Exception:
        input_hash = "unknown"
    prefix = input_hash[:8]

    def stem_or(raw_path, fallback: str) -> str:
        # Stem of the resolved path, or *fallback* when the path is unset.
        resolved = resolve_path(state, raw_path)
        return fallback if resolved is None else resolved.stem

    mode = state.input_mode
    if mode == "asa":
        label = stem_or(state.asa_file, "asa")
    elif mode == "neuron_traj":
        label = f"{stem_or(state.neuron_file, 'neuron')}_{stem_or(state.traj_file, 'traj')}"
    else:
        label = f"{mode}"
    return f"{label}_{prefix}"
147
+
148
def _stage_cache_path(self, stage_dir: Path) -> Path:
    """Return the ``cache.json`` metadata file that belongs to *stage_dir*."""
    return stage_dir.joinpath("cache.json")
150
+
151
def _load_cache_meta(self, path: Path) -> dict[str, Any]:
    """Read a JSON cache-metadata file, returning ``{}`` when it is unusable.

    Missing, unreadable or malformed files, and files whose top-level JSON
    value is not an object, all yield an empty dict. Callers therefore see
    a cache miss instead of crashing on ``.get``.
    """
    if not path.exists():
        return {}
    try:
        meta = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        return {}
    return meta if isinstance(meta, dict) else {}
158
+
159
def _write_cache_meta(self, path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* (after ``_json_safe``) to *path* as indented ASCII JSON, overwriting it."""
    text = json.dumps(self._json_safe(payload), ensure_ascii=True, indent=2)
    path.write_text(text, encoding="utf-8")
163
+
164
def _stage_cache_hit(
    self, stage_dir: Path, expected_hash: str, required_files: list[Path]
) -> bool:
    """Tell whether a stage's previous outputs can be reused.

    Reuse requires every path in *required_files* to exist, and the ``"hash"``
    recorded in ``<stage_dir>/cache.json`` to equal *expected_hash*.
    """
    for required in required_files:
        if not required.exists():
            return False
    recorded = self._load_cache_meta(self._stage_cache_path(stage_dir))
    return recorded.get("hash") == expected_hash
171
+
172
def _compute_input_hash(self, state: WorkflowState) -> str:
    """Fingerprint the input file(s) of *state*.

    Returns the md5 of a JSON payload holding the input mode and the md5 of
    each input file's contents. File paths are not part of the payload.
    Unrecognised modes hash only the mode name.

    Raises:
        ProcessingError: when a file required by the current mode is not set.
    """
    mode = state.input_mode
    if mode == "neuron_traj":
        neuron_path = resolve_path(state, state.neuron_file)
        traj_path = resolve_path(state, state.traj_file)
        if neuron_path is None or traj_path is None:
            raise ProcessingError("Neuron/trajectory files not set.")
        payload = {
            "mode": "neuron_traj",
            "neuron": self._hash_file(neuron_path),
            "traj": self._hash_file(traj_path),
        }
        return self._hash_obj(payload)
    if mode == "asa":
        asa_path = resolve_path(state, state.asa_file)
        if asa_path is None:
            raise ProcessingError("ASA file not set.")
        return self._hash_obj({"mode": "asa", "file": self._hash_file(asa_path)})
    return self._hash_obj({"mode": mode})
192
+
193
def _load_npz_dict(self, path: Path) -> dict[str, Any]:
    """Load an ``.npz`` file into a plain dict and close the archive.

    If the archive contains a ``persistence_result`` or ``decode_result``
    entry, that entry is unwrapped with ``.item()`` and returned. Such
    entries are result dicts saved as 0-d object arrays. Otherwise every
    stored array is returned, keyed by name.

    ``allow_pickle=True`` is needed for the wrapped dicts, so only use this
    on trusted files.
    """
    # The context manager closes the underlying zip file; previously the
    # NpzFile was left open until garbage collection.
    with np.load(path, allow_pickle=True) as archive:
        for wrapped_key in ("persistence_result", "decode_result"):
            if wrapped_key in archive.files:
                return archive[wrapped_key].item()
        return {name: archive[name] for name in archive.files}
200
+
201
async def run_preprocessing(
    self,
    state: WorkflowState,
    log_callback: Callable[[str], None],
    progress_callback: Callable[[int], None],
) -> PipelineResult:
    """Run the preprocessing stage and keep its output on the runner.

    The input data is loaded first. For ``embed_spike_trains`` the spike
    trains are embedded, reusing ``.asa_cache/embed_<hash>.npz`` when a
    readable cache file exists for the same input hash, method and
    parameters. For any other method the raw ``spike`` entry is used as-is.

    All cached input/preprocessing state is cleared before the stage starts.
    Each piece is stored only after it has been computed successfully, so a
    failed run cannot leave old ``_embed_data`` paired with a new
    ``_embed_hash``, which downstream stages use as their cache key.

    Note: the body contains no ``await`` points, so it blocks the event loop
    while it runs.

    Args:
        state: Current workflow state.
        log_callback: Receives human-readable log lines.
        progress_callback: Receives progress percentages (0-100).

    Returns:
        PipelineResult describing the outcome. Exceptions are logged and
        reported with ``success=False`` instead of being raised.
    """
    t0 = time.time()

    try:
        # Drop everything derived from previous inputs before doing any work.
        self.reset_input()

        # Stage 1: Load data
        log_callback("Loading data...")
        progress_callback(10)
        asa_data = self._load_data(state)
        input_hash = self._compute_input_hash(state)
        self._asa_data = asa_data
        self._input_hash = input_hash

        # Stage 2: Preprocess
        log_callback(f"Preprocessing with {state.preprocess_method}...")
        progress_callback(30)

        if state.preprocess_method == "embed_spike_trains":
            from canns.analyzer.data.asa import SpikeEmbeddingConfig, embed_spike_trains

            # Start from config defaults, overridden by any user parameters.
            params = state.preprocess_params if state.preprocess_params else {}
            base_config = SpikeEmbeddingConfig()
            effective_params = {
                "res": base_config.res,
                "dt": base_config.dt,
                "sigma": base_config.sigma,
                "smooth": base_config.smooth,
                "speed_filter": base_config.speed_filter,
                "min_speed": base_config.min_speed,
            }
            effective_params.update(params)

            embed_hash = self._hash_obj(
                {
                    "input_hash": input_hash,
                    "method": state.preprocess_method,
                    "params": effective_params,
                }
            )
            cache_dir = self._cache_dir(state)
            cache_dir.mkdir(parents=True, exist_ok=True)
            cache_path = cache_dir / f"embed_{embed_hash}.npz"
            meta_path = cache_dir / f"embed_{embed_hash}.json"

            if cache_path.exists():
                try:
                    with np.load(cache_path, allow_pickle=True) as cached:
                        cached_embed = cached["embed_data"]
                        cached_pos = None
                        if {"x", "y", "t"}.issubset(cached.files):
                            cached_pos = {key: cached[key] for key in ("x", "y", "t")}
                except Exception as e:
                    # Unreadable (e.g. truncated) cache: recompute and overwrite it.
                    log_callback(f"Warning: failed to read cached embedding ({e}); recomputing.")
                else:
                    log_callback("♻️ Using cached embedding.")
                    self._embed_data = cached_embed
                    self._aligned_pos = cached_pos
                    self._embed_hash = embed_hash
                    progress_callback(100)
                    elapsed = time.time() - t0
                    return PipelineResult(
                        success=True,
                        artifacts={"embedding": cache_path},
                        summary=f"Preprocessing reused cached embedding in {elapsed:.1f}s",
                        elapsed_time=elapsed,
                    )

            config = SpikeEmbeddingConfig(**effective_params)

            log_callback("Running embed_spike_trains...")
            progress_callback(50)
            embed_result = embed_spike_trains(asa_data, config)

            # embed_spike_trains may return the matrix alone or a tuple whose
            # entries 1..3 are x/y/t positions.
            aligned_pos = None
            if isinstance(embed_result, tuple):
                embed_data = embed_result[0]
                if len(embed_result) >= 4 and embed_result[1] is not None:
                    aligned_pos = {
                        "x": embed_result[1],
                        "y": embed_result[2],
                        "t": embed_result[3],
                    }
            else:
                embed_data = embed_result

            self._embed_data = embed_data
            self._aligned_pos = aligned_pos
            self._embed_hash = embed_hash
            log_callback(f"Embed data shape: {embed_data.shape}")

            # Caching is best effort; a failure only costs recomputation later.
            try:
                payload = {"embed_data": embed_data}
                if aligned_pos is not None:
                    payload.update(aligned_pos)
                np.savez_compressed(cache_path, **payload)
                self._write_cache_meta(
                    meta_path,
                    {
                        "hash": embed_hash,
                        "input_hash": input_hash,
                        "params": effective_params,
                    },
                )
            except Exception as e:
                log_callback(f"Warning: failed to cache embedding: {e}")
        else:
            # No preprocessing - use spike data directly
            log_callback("No preprocessing - using raw spike data")
            spike = asa_data.get("spike")
            embed_hash = self._hash_obj(
                {
                    "input_hash": input_hash,
                    "method": state.preprocess_method,
                    "params": {},
                }
            )

            # Check if already a dense matrix
            if isinstance(spike, np.ndarray) and spike.ndim == 2:
                log_callback(f"Using spike matrix shape: {spike.shape}")
            else:
                log_callback(
                    "Warning: spike data is not a dense matrix, some analyses may fail"
                )
            self._embed_data = spike
            self._embed_hash = embed_hash

        progress_callback(100)
        elapsed = time.time() - t0

        return PipelineResult(
            success=True,
            artifacts={},
            summary=f"Preprocessing completed in {elapsed:.1f}s",
            elapsed_time=elapsed,
        )

    except Exception as e:
        elapsed = time.time() - t0
        log_callback(f"Error: {e}")
        return PipelineResult(
            success=False,
            artifacts={},
            summary=f"Failed after {elapsed:.1f}s",
            error=str(e),
            elapsed_time=elapsed,
        )
356
+
357
async def run_analysis(
    self,
    state: WorkflowState,
    log_callback: Callable[[str], None],
    progress_callback: Callable[[int], None],
) -> PipelineResult:
    """Run the analysis named by ``state.analysis_mode`` (case-insensitive).

    Recognised modes: ``tda``, ``cohomap``, ``pathcompare``, ``cohospace``,
    ``fr``, ``frm`` and ``gridscore``. Input data kept from preprocessing is
    reused; otherwise it is loaded from disk.

    Args:
        state: Current workflow state.
        log_callback: Receives human-readable log lines.
        progress_callback: Receives progress percentages (0-100).

    Returns:
        PipelineResult with the stage's artifacts. Exceptions are logged and
        reported with ``success=False`` instead of being raised.
    """
    started = time.time()
    artifacts = {}

    try:
        log_callback("Loading data...")
        progress_callback(10)
        if self._asa_data is None:
            asa_data = self._load_data(state)
        else:
            asa_data = self._asa_data
        if self._input_hash is None:
            self._input_hash = self._compute_input_hash(state)

        self._ensure_matplotlib_backend()

        log_callback(f"Running {state.analysis_mode} analysis...")
        progress_callback(40)

        mode = state.analysis_mode.lower()
        # Handler names are resolved lazily, only for the selected mode.
        handler_names = {
            "tda": "_run_tda",
            "cohomap": "_run_cohomap",
            "pathcompare": "_run_pathcompare",
            "cohospace": "_run_cohospace",
            "fr": "_run_fr",
            "frm": "_run_frm",
            "gridscore": "_run_gridscore",
        }
        if mode not in handler_names:
            raise ProcessingError(f"Unknown analysis mode: {state.analysis_mode}")
        handler = getattr(self, handler_names[mode])
        artifacts = handler(asa_data, state, log_callback)

        progress_callback(100)
        elapsed = time.time() - started
        return PipelineResult(
            success=True,
            artifacts=artifacts,
            summary=f"Completed {state.analysis_mode} analysis in {elapsed:.1f}s",
            elapsed_time=elapsed,
        )

    except Exception as e:
        elapsed = time.time() - started
        log_callback(f"Error: {e}")
        return PipelineResult(
            success=False,
            artifacts=artifacts,
            summary=f"Failed after {elapsed:.1f}s",
            error=str(e),
            elapsed_time=elapsed,
        )
428
+
429
def _load_data(self, state: WorkflowState) -> dict[str, Any]:
    """Load the raw input arrays for ``state.input_mode``.

    * ``"asa"``: every array stored in the ASA ``.npz`` file, keyed by name.
    * ``"neuron_traj"``: ``{"spike", "x", "y", "t"}``.
      - ``x``, ``y`` and ``t`` come from the trajectory ``.npz``.
      - ``spike`` is the neuron ``.npz``'s ``"spike"`` entry.
      - A plain ``.npy`` neuron file is used as the spike array directly.
      - An ``.npz`` without ``"spike"`` yields a dict of all its arrays.

    Archives are read eagerly and closed before returning. ``allow_pickle``
    is enabled, so only load trusted files.

    Raises:
        ProcessingError: if a required input path is not set or the input
            mode is unknown.
    """
    if state.input_mode == "asa":
        path = resolve_path(state, state.asa_file)
        if path is None:
            raise ProcessingError("ASA file not set.")
        with np.load(path, allow_pickle=True) as data:
            return {k: data[k] for k in data.files}

    if state.input_mode == "neuron_traj":
        neuron_path = resolve_path(state, state.neuron_file)
        traj_path = resolve_path(state, state.traj_file)
        if neuron_path is None or traj_path is None:
            raise ProcessingError("Neuron/trajectory files not set.")

        neuron_data = np.load(neuron_path, allow_pickle=True)
        if isinstance(neuron_data, np.lib.npyio.NpzFile):
            with neuron_data:
                if "spike" in neuron_data.files:
                    spike = neuron_data["spike"]
                else:
                    # Materialize everything so the archive can be closed.
                    spike = {k: neuron_data[k] for k in neuron_data.files}
        elif isinstance(neuron_data, np.ndarray):
            # A plain .npy file has no keys; it is the spike data itself.
            spike = neuron_data
        else:
            # Other unpickled objects keep the mapping-style lookup.
            spike = neuron_data.get("spike", neuron_data)

        with np.load(traj_path, allow_pickle=True) as traj_data:
            return {
                "spike": spike,
                "x": traj_data["x"],
                "y": traj_data["y"],
                "t": traj_data["t"],
            }

    raise ProcessingError(f"Unknown input mode: {state.input_mode}")
448
+
449
def _run_preprocess(self, asa_data: dict[str, Any], state: WorkflowState) -> dict[str, Any]:
    """Store a spike-train embedding in ``asa_data["spike_main"]`` when requested.

    Only acts when ``state.preprocess_method == "embed_spike_trains"``. The
    config starts from ``SpikeEmbeddingConfig`` defaults, overridden by
    ``state.preprocess_params``, which may be empty or None. ``asa_data`` is
    modified in place and also returned.

    NOTE(review): ``run_preprocessing`` unwraps tuple results of
    ``embed_spike_trains``; this helper stores the return value unchanged.
    """
    if state.preprocess_method == "embed_spike_trains":
        from canns.analyzer.data.asa import SpikeEmbeddingConfig, embed_spike_trains

        base_config = SpikeEmbeddingConfig()
        effective_params = {
            "res": base_config.res,
            "dt": base_config.dt,
            "sigma": base_config.sigma,
            "smooth": base_config.smooth,
            "speed_filter": base_config.speed_filter,
            "min_speed": base_config.min_speed,
        }
        # preprocess_params may be None/empty (run_preprocessing guards the
        # same case); dict.update(None) would raise TypeError.
        effective_params.update(state.preprocess_params or {})
        config = SpikeEmbeddingConfig(**effective_params)

        asa_data["spike_main"] = embed_spike_trains(asa_data, config)

    return asa_data
471
+
472
def _run_tda(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Compute persistent homology of the preprocessed data and plot its barcode.

    Outputs, under the results directory:
    - ``TDA/persistence.npz``: the ``tda_vis`` result.
    - ``TDA/barcode.png``: the barcode plot.
    - ``TDA/cache.json``: a hash of the preprocessing hash plus the analysis
      params.

    When that hash matches and both files exist, the stored results are
    reused. ``asa_data`` is not used; TDA runs on ``self._embed_data``.

    Raises:
        ProcessingError: if no dense 2-D preprocessed matrix is available,
            or if standardization fails.
    """
    from canns.analyzer.data.asa import TDAConfig, tda_vis
    from canns.analyzer.data.asa.tda import _plot_barcode, _plot_barcode_with_shuffle

    # Create output directory
    out_dir = self._results_dir(state) / "TDA"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Get parameters
    params = state.analysis_params
    config = TDAConfig(
        dim=params.get("dim", 6),
        num_times=params.get("num_times", 5),
        active_times=params.get("active_times", 15000),
        k=params.get("k", 1000),
        n_points=params.get("n_points", 1200),
        metric=params.get("metric", "cosine"),
        nbs=params.get("nbs", 800),
        maxdim=params.get("maxdim", 1),
        coeff=params.get("coeff", 47),
        show=False,
        do_shuffle=params.get("do_shuffle", False),
        num_shuffles=params.get("num_shuffles", 1000),
        progress_bar=False,
    )

    log_callback("Computing persistent homology...")

    if self._embed_data is None:
        raise ProcessingError("No preprocessed data available. Run preprocessing first.")
    if not isinstance(self._embed_data, np.ndarray) or self._embed_data.ndim != 2:
        raise ProcessingError(
            "TDA requires a dense spike matrix (T,N). "
            "Please choose 'Embed Spike Trains' in preprocessing or provide a dense spike matrix in the .npz."
        )

    persistence_path = out_dir / "persistence.npz"
    barcode_path = out_dir / "barcode.png"

    embed_hash = self._embed_hash or self._hash_obj({"embed": "unknown"})
    tda_hash = self._hash_obj({"embed_hash": embed_hash, "params": params})

    if self._stage_cache_hit(out_dir, tda_hash, [persistence_path, barcode_path]):
        log_callback("♻️ Using cached TDA results.")
        return {"persistence": persistence_path, "barcode": barcode_path}

    embed_data = self._embed_data
    if params.get("standardize", False):
        try:
            from sklearn.preprocessing import StandardScaler

            embed_data = StandardScaler().fit_transform(embed_data)
        except Exception as e:
            raise ProcessingError(f"StandardScaler failed: {e}") from e

    result = tda_vis(
        embed_data=embed_data,
        config=config,
    )

    np.savez_compressed(persistence_path, persistence_result=result)

    fig = None
    try:
        persistence = result.get("persistence")
        shuffle_max = result.get("shuffle_max")
        if config.do_shuffle and shuffle_max is not None:
            fig = _plot_barcode_with_shuffle(persistence, shuffle_max)
        else:
            fig = _plot_barcode(persistence)
        fig.savefig(barcode_path, dpi=200, bbox_inches="tight")
    except Exception as e:
        log_callback(f"Warning: failed to save barcode: {e}")
        # A barcode left from an earlier run no longer matches the new
        # persistence file. Remove it so the next run is a cache miss instead
        # of reusing the stale image under the new hash.
        try:
            if barcode_path.exists():
                barcode_path.unlink()
        except OSError:
            pass
    finally:
        # Close even when savefig failed, so a pyplot-managed figure is not leaked.
        if fig is not None:
            try:
                import matplotlib.pyplot as plt

                plt.close(fig)
            except Exception:
                pass

    self._write_cache_meta(
        self._stage_cache_path(out_dir),
        {"hash": tda_hash, "embed_hash": embed_hash, "params": params},
    )

    return {"persistence": persistence_path, "barcode": barcode_path}
560
+
561
def _load_or_run_decode(
    self,
    asa_data: dict[str, Any],
    persistence_path: Path,
    state: WorkflowState,
    log_callback,
) -> dict[str, Any]:
    """Return the circular-coordinate decoding, reusing ``CohoMap/decoding.npz`` when valid.

    Which decoder runs depends on ``decode_version``:
    - ``"v0"``: ``decode_circular_coordinates`` on the raw ``asa_data``.
    - any other version: ``decode_circular_coordinates_multi`` with
      ``self._embed_data`` substituted for ``spike``.

    The cache key covers the persistence file hash and the decode params.
    For the non-v0 path it also covers the preprocessing hash, because that
    path's output depends on which embedding is loaded, not only on the
    persistence file. The decode hash is merged into ``CohoMap/cache.json``;
    other keys there are preserved.

    Raises:
        ProcessingError: if the non-v0 decoder must run but nothing has been
            preprocessed.
    """
    from canns.analyzer.data.asa import (
        decode_circular_coordinates,
        decode_circular_coordinates_multi,
    )

    decode_dir = self._results_dir(state) / "CohoMap"
    decode_dir.mkdir(parents=True, exist_ok=True)
    decode_path = decode_dir / "decoding.npz"

    params = state.analysis_params
    decode_version = str(params.get("decode_version", "v2"))
    num_circ = int(params.get("num_circ", 2))
    decode_params = {
        "real_ground": params.get("real_ground", True),
        "real_of": params.get("real_of", True),
        "decode_version": decode_version,
        "num_circ": num_circ,
    }
    persistence_hash = self._hash_file(persistence_path)
    decode_key: dict[str, Any] = {"persistence_hash": persistence_hash, "params": decode_params}
    if decode_version != "v0":
        decode_key["embed_hash"] = self._embed_hash
    decode_hash = self._hash_obj(decode_key)

    meta_path = self._stage_cache_path(decode_dir)
    meta = self._load_cache_meta(meta_path)
    if decode_path.exists() and meta.get("decode_hash") == decode_hash:
        log_callback("♻️ Using cached decoding.")
        return self._load_npz_dict(decode_path)

    log_callback("Decoding circular coordinates...")
    persistence_result = self._load_npz_dict(persistence_path)
    if decode_version == "v0":
        decode_result = decode_circular_coordinates(
            persistence_result=persistence_result,
            spike_data=asa_data,
            real_ground=decode_params["real_ground"],
            real_of=decode_params["real_of"],
            save_path=str(decode_path),
        )
    else:
        if self._embed_data is None:
            raise ProcessingError("No preprocessed data available for decode v2.")
        # Shallow copy so the caller's asa_data keeps its original "spike".
        spike_data = dict(asa_data)
        spike_data["spike"] = self._embed_data
        decode_result = decode_circular_coordinates_multi(
            persistence_result=persistence_result,
            spike_data=spike_data,
            save_path=str(decode_path),
            num_circ=num_circ,
        )

    meta["decode_hash"] = decode_hash
    meta["persistence_hash"] = persistence_hash
    meta["decode_params"] = decode_params
    self._write_cache_meta(meta_path, meta)
    return decode_result
625
+
626
def _run_cohomap(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Plot the cohomology map from the (possibly cached) decoding.

    Requires ``TDA/persistence.npz``. Writes ``CohoMap/cohomap.png`` and
    merges the plot's cache hash into ``CohoMap/cache.json``. That hash is
    computed from the recorded decode hash and the subsample factor.
    Positions come from ``self._aligned_pos`` when set, otherwise from
    ``asa_data``.

    Raises:
        ProcessingError: if TDA results are missing.
    """
    from canns.analyzer.data.asa import plot_cohomap_multi
    from canns.analyzer.visualization import PlotConfigs

    results_root = self._results_dir(state)
    persistence_file = results_root / "TDA" / "persistence.npz"
    if not persistence_file.exists():
        raise ProcessingError("TDA results not found. Run TDA analysis first.")

    cohomap_dir = results_root / "CohoMap"
    cohomap_dir.mkdir(parents=True, exist_ok=True)

    decoded = self._load_or_run_decode(asa_data, persistence_file, state, log_callback)

    subsample = int(state.analysis_params.get("cohomap_subsample", 10))
    meta_file = self._stage_cache_path(cohomap_dir)
    image_path = cohomap_dir / "cohomap.png"
    outputs = {"decoding": cohomap_dir / "decoding.npz", "cohomap": image_path}

    recorded_decode = self._load_cache_meta(meta_file).get("decode_hash")
    plot_hash = self._hash_obj(
        {"decode_hash": recorded_decode, "plot": "cohomap", "subsample": subsample}
    )
    if self._stage_cache_hit(cohomap_dir, plot_hash, [image_path]):
        log_callback("♻️ Using cached CohoMap plot.")
        return outputs

    log_callback("Generating cohomology map...")
    source = asa_data if self._aligned_pos is None else self._aligned_pos
    plot_config = PlotConfigs.cohomap(show=False, save_path=str(image_path))
    plot_cohomap_multi(
        decoding_result=decoded,
        position_data={"x": source["x"], "y": source["y"]},
        config=plot_config,
        subsample=subsample,
    )

    # Keep the decode metadata stored in the same file; only update "hash".
    merged = {**self._load_cache_meta(meta_file), "hash": plot_hash}
    self._write_cache_meta(meta_file, merged)

    return outputs
679
+
680
def _run_pathcompare(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Plot the tracked path next to the decoded circular coordinates.

    Steps:
    1. Use (or compute) the CohoMap decoding.
    2. Select one decoded dimension (``dim_mode == "1d"``, ``dim``) or two
       (``dim1``/``dim2``).
    3. Align the coordinates to the position samples.
    4. Rescale the angles.
    5. Slice by time or by index, with a stride.
    6. Write ``PathCompare/path_compare.png``.

    The plot is cached in ``PathCompare/cache.json`` under a hash of the
    persistence file, the decode hash and every plot parameter. The cache is
    checked before coordinates are extracted and aligned, so a hit skips that
    work entirely.

    Raises:
        ProcessingError: if TDA results are missing, or a requested dimension
            is out of range for the decoded coordinates.
    """
    from canns.analyzer.data.asa import (
        align_coords_to_position,
        apply_angle_scale,
        plot_path_compare,
    )
    from canns.analyzer.data.asa.path import (
        find_coords_matrix,
        find_times_box,
        resolve_time_slice,
    )
    from canns.analyzer.visualization import PlotConfigs

    tda_dir = self._results_dir(state) / "TDA"
    persistence_path = tda_dir / "persistence.npz"
    if not persistence_path.exists():
        raise ProcessingError("TDA results not found. Run TDA analysis first.")

    decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

    # Create output directory
    out_dir = self._results_dir(state) / "PathCompare"
    out_dir.mkdir(parents=True, exist_ok=True)

    params = state.analysis_params
    angle_scale = params.get("angle_scale", "rad")
    dim_mode = params.get("dim_mode", "2d")
    dim = int(params.get("dim", 1))
    dim1 = int(params.get("dim1", 1))
    dim2 = int(params.get("dim2", 2))
    use_box = bool(params.get("use_box", False))
    interp_full = bool(params.get("interp_full", True))
    coords_key = params.get("coords_key")
    times_key = params.get("times_key")
    slice_mode = params.get("slice_mode", "time")
    tmin = params.get("tmin")
    tmax = params.get("tmax")
    imin = params.get("imin")
    imax = params.get("imax")
    stride = max(1, int(params.get("stride", 1)))
    # Unrecognised scale names fall back to radians.
    scale = str(angle_scale) if str(angle_scale) in {"rad", "deg", "unit", "auto"} else "rad"

    # The cache key depends only on inputs known at this point, so check it
    # before doing any coordinate extraction or alignment.
    out_path = out_dir / "path_compare.png"
    decode_meta = self._load_cache_meta(
        self._stage_cache_path(self._results_dir(state) / "CohoMap")
    )
    stage_hash = self._hash_obj(
        {
            "persistence": self._hash_file(persistence_path),
            "decode_hash": decode_meta.get("decode_hash"),
            "params": {
                "angle_scale": scale,
                "dim_mode": dim_mode,
                "dim": dim,
                "dim1": dim1,
                "dim2": dim2,
                "use_box": use_box,
                "interp_full": interp_full,
                "coords_key": coords_key,
                "times_key": times_key,
                "slice_mode": slice_mode,
                "tmin": tmin,
                "tmax": tmax,
                "imin": imin,
                "imax": imax,
                "stride": stride,
            },
        }
    )
    if self._stage_cache_hit(out_dir, stage_hash, [out_path]):
        log_callback("♻️ Using cached PathCompare plot.")
        return {"path_compare": out_path}

    coords_raw, _ = find_coords_matrix(
        decode_result,
        coords_key=coords_key,
        prefer_box_fallback=use_box,
    )

    # dim/dim1/dim2 are treated as 1-based column indices (values < 1 clamp to 0).
    if dim_mode == "1d":
        idx = max(0, dim - 1)
        if idx >= coords_raw.shape[1]:
            raise ProcessingError(f"dim out of range for coords shape {coords_raw.shape}")
        coords2 = coords_raw[:, [idx]]
    else:
        idx1 = max(0, dim1 - 1)
        idx2 = max(0, dim2 - 1)
        if idx1 >= coords_raw.shape[1] or idx2 >= coords_raw.shape[1]:
            raise ProcessingError(f"dim1/dim2 out of range for coords shape {coords_raw.shape}")
        coords2 = coords_raw[:, [idx1, idx2]]

    pos = self._aligned_pos if self._aligned_pos is not None else asa_data
    t_full = np.asarray(pos["t"]).ravel()
    x_full = np.asarray(pos["x"]).ravel()
    y_full = np.asarray(pos["y"]).ravel()

    if use_box:
        if times_key:
            times_box = decode_result.get(times_key)
        else:
            times_box, _ = find_times_box(decode_result)
    else:
        times_box = None

    log_callback("Aligning decoded coordinates to position...")
    t_use, x_use, y_use, coords_use, _ = align_coords_to_position(
        t_full=t_full,
        x_full=x_full,
        y_full=y_full,
        coords2=coords2,
        use_box=use_box,
        times_box=times_box,
        interp_to_full=interp_full,
    )
    coords_use = apply_angle_scale(coords_use, scale)

    if slice_mode == "index":
        i0, i1 = resolve_time_slice(t_use, None, None, imin, imax)
    else:
        i0, i1 = resolve_time_slice(t_use, tmin, tmax, None, None)

    window = slice(i0, i1, stride)
    t_use = t_use[window]
    x_use = x_use[window]
    y_use = y_use[window]
    coords_use = coords_use[window]

    log_callback("Generating path comparison...")
    config = PlotConfigs.path_compare(show=False, save_path=str(out_path))
    plot_path_compare(x_use, y_use, coords_use, config=config)

    self._write_cache_meta(self._stage_cache_path(out_dir), {"hash": stage_hash})
    return {"path_compare": out_path}
817
+
818
def _run_cohospace(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Run cohomology space visualization.

    Loads (or runs) decoding on top of the TDA persistence results, projects
    the decoded coordinates onto the configured dimension(s) and writes PNGs
    into ``<results>/CohoSpace``:

    * ``cohospace_trajectory.png`` - always;
    * ``cohospace_neuron_<neuron_id>.png`` - when ``neuron_id`` is set and
      ``view`` is ``"both"`` or ``"single"``;
    * ``cohospace_population.png`` - when ``view`` is ``"both"`` or
      ``"population"``.

    Args:
        asa_data: Loaded dataset. Its ``"spike"`` entry is used as activity in
            ``"spike"`` mode, or when no embedding is available.
        state: Workflow state supplying the results directory and
            ``analysis_params``.
        log_callback: Callable that receives progress messages.

    Returns:
        Mapping of artifact name (``"trajectory"``, ``"neuron"``,
        ``"population"``) to the written PNG path.

    Raises:
        ProcessingError: If the TDA results are missing, the decoded coords are
            not 2D, a requested dimension is out of range, or a neuron or
            population plot is requested but no 2D activity matrix is
            available.
    """
    from canns.analyzer.data.asa import (
        plot_cohospace_neuron,
        plot_cohospace_population,
        plot_cohospace_trajectory,
    )
    from canns.analyzer.data.asa.cohospace import (
        plot_cohospace_neuron_skewed,
        plot_cohospace_population_skewed,
    )
    from canns.analyzer.visualization import PlotConfigs

    tda_dir = self._results_dir(state) / "TDA"
    persistence_path = tda_dir / "persistence.npz"
    if not persistence_path.exists():
        raise ProcessingError("TDA results not found. Run TDA analysis first.")

    decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

    out_dir = self._results_dir(state) / "CohoSpace"
    out_dir.mkdir(parents=True, exist_ok=True)

    params = state.analysis_params

    coords = np.asarray(decode_result.get("coords"))
    # A missing "coordsbox" becomes a 0-d array here; pick_coords is then
    # skipped and coords2 is reused below.
    coordsbox = np.asarray(decode_result.get("coordsbox"))
    if coords.ndim != 2 or coords.shape[1] < 1:
        raise ProcessingError("decode_result['coords'] must be 2D.")

    dim_mode = str(params.get("dim_mode", "2d"))
    dim = int(params.get("dim", 1))
    dim1 = int(params.get("dim1", 1))
    dim2 = int(params.get("dim2", 2))
    mode = str(params.get("mode", "fr"))
    top_percent = float(params.get("top_percent", 5.0))
    view = str(params.get("view", "both"))
    subsample = int(params.get("subsample", 2))
    unfold = str(params.get("unfold", "square"))
    skew_show_grid = bool(params.get("skew_show_grid", True))
    skew_tiles = int(params.get("skew_tiles", 0))
    neuron_id = params.get("neuron_id")

    show_neuron = neuron_id is not None and view in {"both", "single"}
    show_population = view in {"both", "population"}

    def pick_coords(arr: np.ndarray) -> np.ndarray:
        # Dimension params are 1-based. In 1d mode the chosen column is paired
        # with a zero column so that downstream plots always get two columns.
        if dim_mode == "1d":
            idx = max(0, dim - 1)
            if idx >= arr.shape[1]:
                raise ProcessingError(f"dim out of range for coords shape {arr.shape}")
            one = arr[:, [idx]]
            return np.hstack([one, np.zeros_like(one)])
        idx1 = max(0, dim1 - 1)
        idx2 = max(0, dim2 - 1)
        if idx1 >= arr.shape[1] or idx2 >= arr.shape[1]:
            raise ProcessingError(f"dim1/dim2 out of range for coords shape {arr.shape}")
        return arr[:, [idx1, idx2]]

    coords2 = pick_coords(coords)
    coordsbox2 = pick_coords(coordsbox) if coordsbox.ndim == 2 else coords2

    # NOTE(review): the decode hash is read from the CohoMap stage's cache
    # metadata; presumably _load_or_run_decode records it there - confirm.
    decode_meta = self._load_cache_meta(
        self._stage_cache_path(self._results_dir(state) / "CohoMap")
    )
    stage_hash = self._hash_obj(
        {
            "persistence": self._hash_file(persistence_path),
            "decode_hash": decode_meta.get("decode_hash"),
            # The activity plotted in "fr" mode is self._embed_data, so the
            # embedding must invalidate this cache, as in the FR/FRM stages.
            "embed_hash": self._embed_hash,
            "params": params,
        }
    )
    meta_path = self._stage_cache_path(out_dir)

    traj_path = out_dir / "cohospace_trajectory.png"
    neuron_path = out_dir / f"cohospace_neuron_{neuron_id}.png"
    pop_path = out_dir / "cohospace_population.png"

    # Build the artifact set once; it drives both the cache check and the
    # cache-hit return value, so the two cannot drift apart.
    expected: dict[str, Path] = {"trajectory": traj_path}
    if show_neuron:
        expected["neuron"] = neuron_path
    if show_population:
        expected["population"] = pop_path

    if self._stage_cache_hit(out_dir, stage_hash, list(expected.values())):
        log_callback("♻️ Using cached CohoSpace plots.")
        return dict(expected)

    if mode == "spike" or self._embed_data is None:
        activity = np.asarray(asa_data.get("spike"))
    else:
        activity = self._embed_data

    # Neuron and population plots index neurons along axis 1. Fail with a
    # clear error instead of an IndexError deep inside the plotting code.
    if (show_neuron or show_population) and np.ndim(activity) != 2:
        raise ProcessingError(
            "Activity data must be a 2D array with neurons along axis 1; "
            f"got shape {np.shape(activity)}."
        )

    artifacts: dict[str, Path] = {}

    log_callback("Plotting cohomology space trajectory...")
    traj_cfg = PlotConfigs.cohospace_trajectory(show=False, save_path=str(traj_path))
    plot_cohospace_trajectory(coords=coords2, times=None, subsample=subsample, config=traj_cfg)
    artifacts["trajectory"] = traj_path

    if show_neuron:
        log_callback(f"Plotting neuron {neuron_id}...")
        if unfold == "skew":
            plot_cohospace_neuron_skewed(
                coords=coordsbox2,
                activity=activity,
                neuron_id=int(neuron_id),
                mode=mode,
                top_percent=top_percent,
                save_path=str(neuron_path),
                show=False,
                show_grid=skew_show_grid,
                n_tiles=skew_tiles,
            )
        else:
            neuron_cfg = PlotConfigs.cohospace_neuron(show=False, save_path=str(neuron_path))
            plot_cohospace_neuron(
                coords=coordsbox2,
                activity=activity,
                neuron_id=int(neuron_id),
                mode=mode,
                top_percent=top_percent,
                config=neuron_cfg,
            )
        artifacts["neuron"] = neuron_path

    if show_population:
        log_callback("Plotting population activity...")
        neuron_ids = list(range(activity.shape[1]))
        if unfold == "skew":
            plot_cohospace_population_skewed(
                coords=coords2,
                activity=activity,
                neuron_ids=neuron_ids,
                mode=mode,
                top_percent=top_percent,
                save_path=str(pop_path),
                show=False,
                show_grid=skew_show_grid,
                n_tiles=skew_tiles,
            )
        else:
            pop_cfg = PlotConfigs.cohospace_population(show=False, save_path=str(pop_path))
            plot_cohospace_population(
                coords=coords2,
                activity=activity,
                neuron_ids=neuron_ids,
                mode=mode,
                top_percent=top_percent,
                config=pop_cfg,
            )
        artifacts["population"] = pop_path

    self._write_cache_meta(meta_path, {"hash": stage_hash})
    return artifacts
980
+
981
def _run_fr(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Run firing rate heatmap analysis.

    Builds a (optionally normalized) firing-rate matrix and saves it as
    ``<results>/FR/fr_heatmap.png``. If the stage cache matches, the
    previously rendered image is reused.

    Args:
        asa_data: Loaded dataset. Its ``"spike"`` entry is the source when the
            ``mode`` param is ``"spike"``.
        state: Workflow state providing the results directory and
            ``analysis_params`` (``neuron_range``, ``time_range``,
            ``normalize``, ``mode``).
        log_callback: Callable that receives progress messages.

    Returns:
        ``{"fr_heatmap": <png path>}``.

    Raises:
        ProcessingError: If the selected data source is unavailable.
    """
    from canns.analyzer.data.asa import compute_fr_heatmap_matrix, save_fr_heatmap_png
    from canns.analyzer.visualization import PlotConfigs

    stage_dir = self._results_dir(state) / "FR"
    stage_dir.mkdir(parents=True, exist_ok=True)

    opts = state.analysis_params

    # "spike" mode reads raw spikes; every other mode uses the preprocessed
    # embedding held on the runner.
    use_raw_spikes = opts.get("mode", "fr") == "spike"
    source = asa_data.get("spike") if use_raw_spikes else self._embed_data
    if source is None:
        raise ProcessingError("No spike data available for FR.")

    heatmap_png = stage_dir / "fr_heatmap.png"
    fingerprint = self._hash_obj({"embed_hash": self._embed_hash, "params": opts})

    if self._stage_cache_hit(stage_dir, fingerprint, [heatmap_png]):
        log_callback("♻️ Using cached FR heatmap.")
        return {"fr_heatmap": heatmap_png}

    log_callback("Computing firing rate heatmap...")
    matrix = compute_fr_heatmap_matrix(
        source,
        neuron_range=opts.get("neuron_range", None),
        time_range=opts.get("time_range", None),
        normalize=opts.get("normalize", "zscore_per_neuron"),
    )

    plot_cfg = PlotConfigs.fr_heatmap(show=False, save_path=str(heatmap_png))
    save_fr_heatmap_png(matrix, config=plot_cfg, dpi=200)

    self._write_cache_meta(self._stage_cache_path(stage_dir), {"hash": fingerprint})
    return {"fr_heatmap": heatmap_png}
1029
+
1030
def _run_frm(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Run single neuron firing rate map.

    Computes a spatial firing-rate map for one neuron over the (x, y)
    positions and saves it as ``<results>/FRM/frm_neuron_<neuron_id>.png``.
    If the stage cache matches, the previously rendered image is reused.

    Args:
        asa_data: Loaded dataset. Provides ``"spike"`` in ``"spike"`` mode, and
            ``"x"``/``"y"`` when no aligned positions are cached on the runner.
        state: Workflow state providing the results directory and
            ``analysis_params`` (``neuron_id``, ``bin_size``,
            ``min_occupancy``, ``smoothing``, ``smooth_sigma``, ``mode``).
        log_callback: Callable that receives progress messages.

    Returns:
        ``{"frm": <png path>}``.

    Raises:
        ProcessingError: If no activity data is available, or the position
            source lacks ``"x"`` or ``"y"``.
    """
    from canns.analyzer.data.asa import compute_frm, plot_frm
    from canns.analyzer.visualization import PlotConfigs

    out_dir = self._results_dir(state) / "FRM"
    out_dir.mkdir(parents=True, exist_ok=True)

    params = state.analysis_params
    neuron_id = int(params.get("neuron_id", 0))
    bins = int(params.get("bin_size", 50))
    min_occupancy = int(params.get("min_occupancy", 1))
    smoothing = bool(params.get("smoothing", False))
    smooth_sigma = float(params.get("smooth_sigma", 2.0))
    mode = str(params.get("mode", "fr"))

    spike_data = self._embed_data if mode != "spike" else asa_data.get("spike")
    if spike_data is None:
        raise ProcessingError("No spike data available for FRM.")

    # Prefer positions already aligned to the embedding; otherwise fall back
    # to the raw dataset.
    pos = self._aligned_pos if self._aligned_pos is not None else asa_data
    raw_x = pos.get("x")
    raw_y = pos.get("y")
    # Check before np.asarray: asarray(None) returns a 0-d object array, never
    # None, so checking afterwards would never catch missing positions.
    if raw_x is None or raw_y is None:
        raise ProcessingError("Position data (x,y) is required for FRM.")
    x = np.asarray(raw_x)
    y = np.asarray(raw_y)

    out_path = out_dir / f"frm_neuron_{neuron_id}.png"
    stage_hash = self._hash_obj(
        {
            "embed_hash": self._embed_hash,
            "params": params,
        }
    )
    if self._stage_cache_hit(out_dir, stage_hash, [out_path]):
        log_callback("♻️ Using cached FRM.")
        return {"frm": out_path}

    log_callback(f"Computing firing rate map for neuron {neuron_id}...")
    frm_result = compute_frm(
        spike_data,
        x,
        y,
        neuron_id=neuron_id,
        bins=max(1, bins),  # guard against a zero/negative bin count from the UI
        min_occupancy=min_occupancy,
        smoothing=smoothing,
        sigma=smooth_sigma,
        nan_for_empty=True,
    )

    config = PlotConfigs.frm(show=False, save_path=str(out_path))
    plot_frm(frm_result.frm, config=config, dpi=200)

    self._write_cache_meta(self._stage_cache_path(out_dir), {"hash": stage_hash})
    return {"frm": out_path}
1088
+
1089
def _run_gridscore(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Placeholder for the grid score analysis stage.

    This stage has no implementation yet. It reports that through
    ``log_callback`` and then always raises ``ProcessingError``. The
    ``asa_data`` and ``state`` arguments are accepted only to match the
    signature of the other stage runners.
    """
    notice = "GridScore analysis is not implemented in the TUI yet."
    log_callback(notice)
    raise ProcessingError("GridScore analysis is not implemented yet.")