canns 0.12.6__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/__init__.py +39 -3
- canns/analyzer/__init__.py +7 -6
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/metrics/spatial_metrics.py +70 -100
- canns/analyzer/metrics/systematic_ratemap.py +12 -17
- canns/analyzer/metrics/utils.py +28 -0
- canns/analyzer/model_specific/hopfield.py +19 -16
- canns/analyzer/slow_points/checkpoint.py +32 -9
- canns/analyzer/slow_points/finder.py +33 -6
- canns/analyzer/slow_points/fixed_points.py +12 -0
- canns/analyzer/slow_points/visualization.py +22 -10
- canns/analyzer/visualization/core/backend.py +15 -26
- canns/analyzer/visualization/core/config.py +120 -15
- canns/analyzer/visualization/core/jupyter_utils.py +34 -16
- canns/analyzer/visualization/core/rendering.py +42 -40
- canns/analyzer/visualization/core/writers.py +10 -20
- canns/analyzer/visualization/energy_plots.py +78 -28
- canns/analyzer/visualization/spatial_plots.py +81 -36
- canns/analyzer/visualization/spike_plots.py +27 -7
- canns/analyzer/visualization/theta_sweep_plots.py +159 -72
- canns/analyzer/visualization/tuning_plots.py +11 -3
- canns/data/__init__.py +7 -4
- canns/models/__init__.py +10 -0
- canns/models/basic/cann.py +102 -40
- canns/models/basic/grid_cell.py +9 -8
- canns/models/basic/hierarchical_model.py +57 -11
- canns/models/brain_inspired/hopfield.py +26 -14
- canns/models/brain_inspired/linear.py +15 -16
- canns/models/brain_inspired/spiking.py +23 -12
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/closed_loop_navigation.py +54 -13
- canns/task/open_loop_navigation.py +230 -147
- canns/task/tracking.py +156 -24
- canns/trainer/__init__.py +8 -5
- canns/utils/__init__.py +12 -4
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- canns-0.13.0.dist-info/RECORD +91 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- canns-0.12.6.dist-info/RECORD +0 -72
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
canns/pipeline/asa/runner.py (new file)
@@ -0,0 +1,1095 @@
"""Pipeline execution wrapper for ASA TUI.

This module provides async pipeline execution that integrates with the existing
canns.analyzer.data.asa module. It wraps the analysis functions and provides
progress callbacks for the TUI.
"""

from __future__ import annotations

import hashlib
import json
import time
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np

from .state import WorkflowState, resolve_path


@dataclass
class PipelineResult:
    """Result from pipeline execution."""

    success: bool
    artifacts: dict[str, Path]
    summary: str
    error: str | None = None
    elapsed_time: float = 0.0


class ProcessingError(RuntimeError):
    """Raised when a pipeline stage fails."""

    pass


class PipelineRunner:
    """Async pipeline execution wrapper."""

    def __init__(self):
        """Initialize pipeline runner."""
        self._asa_data: dict[str, Any] | None = None
        self._embed_data: np.ndarray | None = None  # Preprocessed data
        self._aligned_pos: dict[str, np.ndarray] | None = None
        self._input_hash: str | None = None
        self._embed_hash: str | None = None
        self._mpl_ready: bool = False

    def has_preprocessed_data(self) -> bool:
        """Check if preprocessing has been completed."""
        return self._embed_data is not None

    def reset_input(self) -> None:
        """Clear cached input/preprocessing state when input files change."""
        self._asa_data = None
        self._embed_data = None
        self._aligned_pos = None
        self._input_hash = None
        self._embed_hash = None

    def _json_safe(self, obj: Any) -> Any:
        """Convert objects to JSON-serializable structures."""
        if isinstance(obj, Path):
            return str(obj)
        if isinstance(obj, tuple):
            return [self._json_safe(v) for v in obj]
        if isinstance(obj, list):
            return [self._json_safe(v) for v in obj]
        if isinstance(obj, dict):
            return {str(k): self._json_safe(v) for k, v in obj.items()}
        if hasattr(obj, "item") and callable(obj.item):
            try:
                return obj.item()
            except Exception:
                return str(obj)
        return obj

    def _hash_bytes(self, data: bytes) -> str:
        return hashlib.md5(data).hexdigest()

    def _hash_file(self, path: Path) -> str:
        """Compute md5 hash for a file."""
        md5 = hashlib.md5()
        with path.open("rb") as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
                md5.update(chunk)
        return md5.hexdigest()

    def _hash_obj(self, obj: Any) -> str:
        payload = json.dumps(self._json_safe(obj), sort_keys=True, ensure_ascii=True).encode(
            "utf-8"
        )
        return self._hash_bytes(payload)

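    # Cache-key sketch (illustrative only, not executed anywhere in this module): every
    # stage key is an md5 over a JSON document, e.g.
    #     self._hash_obj({"input_hash": input_md5, "method": "embed_spike_trains",
    #                     "params": effective_params})
    # so changing either the input files or any parameter yields a new key and
    # invalidates the cached artifact. md5 is used purely for cache addressing here,
    # not for any security purpose.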
    def _ensure_matplotlib_backend(self) -> None:
        """Force a non-interactive Matplotlib backend for worker threads."""
        if self._mpl_ready:
            return
        try:
            import os

            os.environ.setdefault("MPLBACKEND", "Agg")
            import matplotlib

            try:
                matplotlib.use("Agg", force=True)
            except TypeError:
                matplotlib.use("Agg")
        except Exception:
            pass
        self._mpl_ready = True

    def _cache_dir(self, state: WorkflowState) -> Path:
        return self._results_dir(state) / ".asa_cache"

    def _results_dir(self, state: WorkflowState) -> Path:
        base = state.workdir / "Results"
        dataset_id = self._dataset_id(state)
        return base / dataset_id

    def results_dir(self, state: WorkflowState) -> Path:
        """Public accessor for results directory."""
        return self._results_dir(state)

    def _dataset_id(self, state: WorkflowState) -> str:
        """Create a stable dataset id based on input filename and md5 prefix."""
        try:
            input_hash = self._input_hash or self._compute_input_hash(state)
        except Exception:
            input_hash = "unknown"
        prefix = input_hash[:8]

        if state.input_mode == "asa":
            path = resolve_path(state, state.asa_file)
            stem = path.stem if path is not None else "asa"
            return f"{stem}_{prefix}"
        if state.input_mode == "neuron_traj":
            neuron_path = resolve_path(state, state.neuron_file)
            traj_path = resolve_path(state, state.traj_file)
            neuron_stem = neuron_path.stem if neuron_path is not None else "neuron"
            traj_stem = traj_path.stem if traj_path is not None else "traj"
            return f"{neuron_stem}_{traj_stem}_{prefix}"
        return f"{state.input_mode}_{prefix}"

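    # Resulting on-disk layout (assumed from the stage directories created further down):
    #     <workdir>/Results/<dataset_id>/
    #         .asa_cache/    embed_<hash>.npz + .json (cached embeddings)
    #         TDA/           persistence.npz, barcode.png, cache.json
    #         CohoMap/       decoding.npz, cohomap.png, cache.json
    #         PathCompare/   path_compare.png, cache.json
    #         CohoSpace/     cohospace_*.png, cache.json
    #         FR/            fr_heatmap.png, cache.json
    #         FRM/           frm_neuron_<id>.png, cache.json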
    def _stage_cache_path(self, stage_dir: Path) -> Path:
        return stage_dir / "cache.json"

    def _load_cache_meta(self, path: Path) -> dict[str, Any]:
        if not path.exists():
            return {}
        try:
            return json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            return {}

    def _write_cache_meta(self, path: Path, payload: dict[str, Any]) -> None:
        path.write_text(
            json.dumps(self._json_safe(payload), ensure_ascii=True, indent=2), encoding="utf-8"
        )

    def _stage_cache_hit(
        self, stage_dir: Path, expected_hash: str, required_files: list[Path]
    ) -> bool:
        if not all(p.exists() for p in required_files):
            return False
        meta = self._load_cache_meta(self._stage_cache_path(stage_dir))
        return meta.get("hash") == expected_hash

    def _compute_input_hash(self, state: WorkflowState) -> str:
        """Compute md5 hash for input data files."""
        if state.input_mode == "asa":
            path = resolve_path(state, state.asa_file)
            if path is None:
                raise ProcessingError("ASA file not set.")
            return self._hash_obj({"mode": "asa", "file": self._hash_file(path)})
        if state.input_mode == "neuron_traj":
            neuron_path = resolve_path(state, state.neuron_file)
            traj_path = resolve_path(state, state.traj_file)
            if neuron_path is None or traj_path is None:
                raise ProcessingError("Neuron/trajectory files not set.")
            return self._hash_obj(
                {
                    "mode": "neuron_traj",
                    "neuron": self._hash_file(neuron_path),
                    "traj": self._hash_file(traj_path),
                }
            )
        return self._hash_obj({"mode": state.input_mode})

    def _load_npz_dict(self, path: Path) -> dict[str, Any]:
        """Load npz into a dict, handling wrapped dict entries."""
        data = np.load(path, allow_pickle=True)
        for key in ("persistence_result", "decode_result"):
            if key in data.files:
                return data[key].item()
        return {k: data[k] for k in data.files}

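    # Example stage cache.json as written by _write_cache_meta (values illustrative):
    #     {
    #       "hash": "<md5 of stage inputs>",
    #       "embed_hash": "<md5 of embedding inputs>",
    #       "params": {"dim": 6, "maxdim": 1}
    #     }
    # _stage_cache_hit reuses a stage only when this "hash" matches and every expected
    # artifact file still exists on disk.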
    async def run_preprocessing(
        self,
        state: WorkflowState,
        log_callback: Callable[[str], None],
        progress_callback: Callable[[int], None],
    ) -> PipelineResult:
        """Run preprocessing pipeline to generate embed_data.

        Args:
            state: Current workflow state
            log_callback: Callback for log messages
            progress_callback: Callback for progress updates (0-100)

        Returns:
            PipelineResult with preprocessing status
        """
        t0 = time.time()

        try:
            # Stage 1: Load data
            log_callback("Loading data...")
            progress_callback(10)
            asa_data = self._load_data(state)
            self._asa_data = asa_data
            self._aligned_pos = None
            self._input_hash = self._compute_input_hash(state)

            # Stage 2: Preprocess
            log_callback(f"Preprocessing with {state.preprocess_method}...")
            progress_callback(30)

            if state.preprocess_method == "embed_spike_trains":
                from canns.analyzer.data.asa import SpikeEmbeddingConfig, embed_spike_trains

                # Get preprocessing parameters from state or use config defaults
                params = state.preprocess_params if state.preprocess_params else {}
                base_config = SpikeEmbeddingConfig()
                effective_params = {
                    "res": base_config.res,
                    "dt": base_config.dt,
                    "sigma": base_config.sigma,
                    "smooth": base_config.smooth,
                    "speed_filter": base_config.speed_filter,
                    "min_speed": base_config.min_speed,
                }
                effective_params.update(params)

                self._embed_hash = self._hash_obj(
                    {
                        "input_hash": self._input_hash,
                        "method": state.preprocess_method,
                        "params": effective_params,
                    }
                )
                cache_dir = self._cache_dir(state)
                cache_dir.mkdir(parents=True, exist_ok=True)
                cache_path = cache_dir / f"embed_{self._embed_hash}.npz"
                meta_path = cache_dir / f"embed_{self._embed_hash}.json"

                if cache_path.exists():
                    log_callback("♻️ Using cached embedding.")
                    cached = np.load(cache_path, allow_pickle=True)
                    self._embed_data = cached["embed_data"]
                    if {"x", "y", "t"}.issubset(set(cached.files)):
                        self._aligned_pos = {
                            "x": cached["x"],
                            "y": cached["y"],
                            "t": cached["t"],
                        }
                    progress_callback(100)
                    elapsed = time.time() - t0
                    return PipelineResult(
                        success=True,
                        artifacts={"embedding": cache_path},
                        summary=f"Preprocessing reused cached embedding in {elapsed:.1f}s",
                        elapsed_time=elapsed,
                    )

                config = SpikeEmbeddingConfig(**effective_params)

                log_callback("Running embed_spike_trains...")
                progress_callback(50)
                embed_result = embed_spike_trains(asa_data, config)

                if isinstance(embed_result, tuple):
                    embed_data = embed_result[0]
                    if len(embed_result) >= 4 and embed_result[1] is not None:
                        self._aligned_pos = {
                            "x": embed_result[1],
                            "y": embed_result[2],
                            "t": embed_result[3],
                        }
                else:
                    embed_data = embed_result

                self._embed_data = embed_data
                log_callback(f"Embed data shape: {embed_data.shape}")

                try:
                    payload = {"embed_data": embed_data}
                    if self._aligned_pos is not None:
                        payload.update(self._aligned_pos)
                    np.savez_compressed(cache_path, **payload)
                    self._write_cache_meta(
                        meta_path,
                        {
                            "hash": self._embed_hash,
                            "input_hash": self._input_hash,
                            "params": effective_params,
                        },
                    )
                except Exception as e:
                    log_callback(f"Warning: failed to cache embedding: {e}")
            else:
                # No preprocessing - use spike data directly
                log_callback("No preprocessing - using raw spike data")
                spike = asa_data.get("spike")
                self._embed_hash = self._hash_obj(
                    {
                        "input_hash": self._input_hash,
                        "method": state.preprocess_method,
                        "params": {},
                    }
                )

                # Check if already a dense matrix
                if isinstance(spike, np.ndarray) and spike.ndim == 2:
                    self._embed_data = spike
                    log_callback(f"Using spike matrix shape: {spike.shape}")
                else:
                    log_callback(
                        "Warning: spike data is not a dense matrix, some analyses may fail"
                    )
                    self._embed_data = spike

            progress_callback(100)
            elapsed = time.time() - t0

            return PipelineResult(
                success=True,
                artifacts={},
                summary=f"Preprocessing completed in {elapsed:.1f}s",
                elapsed_time=elapsed,
            )

        except Exception as e:
            elapsed = time.time() - t0
            log_callback(f"Error: {e}")
            return PipelineResult(
                success=False,
                artifacts={},
                summary=f"Failed after {elapsed:.1f}s",
                error=str(e),
                elapsed_time=elapsed,
            )

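    # Illustrative call sequence (not part of this module) from an async TUI handler:
    #     runner = PipelineRunner()
    #     pre = await runner.run_preprocessing(state, log_callback, progress_callback)
    #     if pre.success:
    #         res = await runner.run_analysis(state, log_callback, progress_callback)
    # where log_callback appends to the log panel and progress_callback drives a
    # 0-100 progress bar.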
    async def run_analysis(
        self,
        state: WorkflowState,
        log_callback: Callable[[str], None],
        progress_callback: Callable[[int], None],
    ) -> PipelineResult:
        """Run analysis pipeline based on workflow state.

        Args:
            state: Current workflow state
            log_callback: Callback for log messages
            progress_callback: Callback for progress updates (0-100)

        Returns:
            PipelineResult with success status and artifacts
        """
        t0 = time.time()
        artifacts = {}

        try:
            # Stage 1: Load data
            log_callback("Loading data...")
            progress_callback(10)
            asa_data = self._asa_data if self._asa_data is not None else self._load_data(state)
            if self._input_hash is None:
                self._input_hash = self._compute_input_hash(state)

            self._ensure_matplotlib_backend()

            # Stage 3: Analysis (mode-dependent)
            log_callback(f"Running {state.analysis_mode} analysis...")
            progress_callback(40)

            mode = state.analysis_mode.lower()
            if mode == "tda":
                artifacts = self._run_tda(asa_data, state, log_callback)
            elif mode == "cohomap":
                artifacts = self._run_cohomap(asa_data, state, log_callback)
            elif mode == "pathcompare":
                artifacts = self._run_pathcompare(asa_data, state, log_callback)
            elif mode == "cohospace":
                artifacts = self._run_cohospace(asa_data, state, log_callback)
            elif mode == "fr":
                artifacts = self._run_fr(asa_data, state, log_callback)
            elif mode == "frm":
                artifacts = self._run_frm(asa_data, state, log_callback)
            elif mode == "gridscore":
                artifacts = self._run_gridscore(asa_data, state, log_callback)
            else:
                raise ProcessingError(f"Unknown analysis mode: {state.analysis_mode}")

            progress_callback(100)
            elapsed = time.time() - t0

            return PipelineResult(
                success=True,
                artifacts=artifacts,
                summary=f"Completed {state.analysis_mode} analysis in {elapsed:.1f}s",
                elapsed_time=elapsed,
            )

        except Exception as e:
            elapsed = time.time() - t0
            log_callback(f"Error: {e}")
            return PipelineResult(
                success=False,
                artifacts=artifacts,
                summary=f"Failed after {elapsed:.1f}s",
                error=str(e),
                elapsed_time=elapsed,
            )

    def _load_data(self, state: WorkflowState) -> dict[str, Any]:
        """Load data based on input mode."""
        if state.input_mode == "asa":
            path = resolve_path(state, state.asa_file)
            data = np.load(path, allow_pickle=True)
            return {k: data[k] for k in data.files}
        elif state.input_mode == "neuron_traj":
            neuron_path = resolve_path(state, state.neuron_file)
            traj_path = resolve_path(state, state.traj_file)
            neuron_data = np.load(neuron_path, allow_pickle=True)
            traj_data = np.load(traj_path, allow_pickle=True)
            return {
                "spike": neuron_data.get("spike", neuron_data),
                "x": traj_data["x"],
                "y": traj_data["y"],
                "t": traj_data["t"],
            }
        else:
            raise ProcessingError(f"Unknown input mode: {state.input_mode}")

    def _run_preprocess(self, asa_data: dict[str, Any], state: WorkflowState) -> dict[str, Any]:
        """Run preprocessing on ASA data."""
        if state.preprocess_method == "embed_spike_trains":
            from canns.analyzer.data.asa import SpikeEmbeddingConfig, embed_spike_trains

            params = state.preprocess_params
            base_config = SpikeEmbeddingConfig()
            effective_params = {
                "res": base_config.res,
                "dt": base_config.dt,
                "sigma": base_config.sigma,
                "smooth": base_config.smooth,
                "speed_filter": base_config.speed_filter,
                "min_speed": base_config.min_speed,
            }
            effective_params.update(params)
            config = SpikeEmbeddingConfig(**effective_params)

            spike_main = embed_spike_trains(asa_data, config)
            asa_data["spike_main"] = spike_main

        return asa_data

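    # Input format sketch (assumed from _load_data above): "asa" mode expects a single
    # .npz bundling everything, e.g.
    #     np.savez_compressed("session.npz", spike=spike_matrix, x=pos_x, y=pos_y, t=times)
    # while "neuron_traj" mode splits the spike archive and the trajectory (x, y, t)
    # into two .npz files. "session.npz" and the variable names here are placeholders.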
    def _run_tda(
        self, asa_data: dict[str, Any], state: WorkflowState, log_callback
    ) -> dict[str, Path]:
        """Run TDA analysis."""
        from canns.analyzer.data.asa import TDAConfig, tda_vis
        from canns.analyzer.data.asa.tda import _plot_barcode, _plot_barcode_with_shuffle

        # Create output directory
        out_dir = self._results_dir(state) / "TDA"
        out_dir.mkdir(parents=True, exist_ok=True)

        # Get parameters
        params = state.analysis_params
        config = TDAConfig(
            dim=params.get("dim", 6),
            num_times=params.get("num_times", 5),
            active_times=params.get("active_times", 15000),
            k=params.get("k", 1000),
            n_points=params.get("n_points", 1200),
            metric=params.get("metric", "cosine"),
            nbs=params.get("nbs", 800),
            maxdim=params.get("maxdim", 1),
            coeff=params.get("coeff", 47),
            show=False,
            do_shuffle=params.get("do_shuffle", False),
            num_shuffles=params.get("num_shuffles", 1000),
            progress_bar=False,
        )

        log_callback("Computing persistent homology...")

        if self._embed_data is None:
            raise ProcessingError("No preprocessed data available. Run preprocessing first.")
        if not isinstance(self._embed_data, np.ndarray) or self._embed_data.ndim != 2:
            raise ProcessingError(
                "TDA requires a dense spike matrix (T,N). "
                "Please choose 'Embed Spike Trains' in preprocessing or provide a dense spike matrix in the .npz."
            )

        persistence_path = out_dir / "persistence.npz"
        barcode_path = out_dir / "barcode.png"

        embed_hash = self._embed_hash or self._hash_obj({"embed": "unknown"})
        tda_hash = self._hash_obj({"embed_hash": embed_hash, "params": params})

        if self._stage_cache_hit(out_dir, tda_hash, [persistence_path, barcode_path]):
            log_callback("♻️ Using cached TDA results.")
            return {"persistence": persistence_path, "barcode": barcode_path}

        embed_data = self._embed_data
        if params.get("standardize", False):
            try:
                from sklearn.preprocessing import StandardScaler

                embed_data = StandardScaler().fit_transform(embed_data)
            except Exception as e:
                raise ProcessingError(f"StandardScaler failed: {e}") from e

        result = tda_vis(
            embed_data=embed_data,
            config=config,
        )

        np.savez_compressed(persistence_path, persistence_result=result)

        try:
            persistence = result.get("persistence")
            shuffle_max = result.get("shuffle_max")
            if config.do_shuffle and shuffle_max is not None:
                fig = _plot_barcode_with_shuffle(persistence, shuffle_max)
            else:
                fig = _plot_barcode(persistence)
            fig.savefig(barcode_path, dpi=200, bbox_inches="tight")
            try:
                import matplotlib.pyplot as plt

                plt.close(fig)
            except Exception:
                pass
        except Exception as e:
            log_callback(f"Warning: failed to save barcode: {e}")

        self._write_cache_meta(
            self._stage_cache_path(out_dir),
            {"hash": tda_hash, "embed_hash": embed_hash, "params": params},
        )

        return {"persistence": persistence_path, "barcode": barcode_path}

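    # The saved persistence.npz can be reloaded outside the TUI in the same way
    # _load_npz_dict does it:
    #     data = np.load(results_dir / "TDA" / "persistence.npz", allow_pickle=True)
    #     persistence_result = data["persistence_result"].item()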
    def _load_or_run_decode(
        self,
        asa_data: dict[str, Any],
        persistence_path: Path,
        state: WorkflowState,
        log_callback,
    ) -> dict[str, Any]:
        """Load cached decoding or run decode_circular_coordinates."""
        from canns.analyzer.data.asa import (
            decode_circular_coordinates,
            decode_circular_coordinates_multi,
        )

        decode_dir = self._results_dir(state) / "CohoMap"
        decode_dir.mkdir(parents=True, exist_ok=True)
        decode_path = decode_dir / "decoding.npz"

        params = state.analysis_params
        decode_version = str(params.get("decode_version", "v2"))
        num_circ = int(params.get("num_circ", 2))
        decode_params = {
            "real_ground": params.get("real_ground", True),
            "real_of": params.get("real_of", True),
            "decode_version": decode_version,
            "num_circ": num_circ,
        }
        persistence_hash = self._hash_file(persistence_path)
        decode_hash = self._hash_obj(
            {"persistence_hash": persistence_hash, "params": decode_params}
        )

        meta_path = self._stage_cache_path(decode_dir)
        meta = self._load_cache_meta(meta_path)
        if decode_path.exists() and meta.get("decode_hash") == decode_hash:
            log_callback("♻️ Using cached decoding.")
            return self._load_npz_dict(decode_path)

        log_callback("Decoding circular coordinates...")
        persistence_result = self._load_npz_dict(persistence_path)
        if decode_version == "v0":
            decode_result = decode_circular_coordinates(
                persistence_result=persistence_result,
                spike_data=asa_data,
                real_ground=decode_params["real_ground"],
                real_of=decode_params["real_of"],
                save_path=str(decode_path),
            )
        else:
            if self._embed_data is None:
                raise ProcessingError("No preprocessed data available for decode v2.")
            spike_data = dict(asa_data)
            spike_data["spike"] = self._embed_data
            decode_result = decode_circular_coordinates_multi(
                persistence_result=persistence_result,
                spike_data=spike_data,
                save_path=str(decode_path),
                num_circ=num_circ,
            )

        meta["decode_hash"] = decode_hash
        meta["persistence_hash"] = persistence_hash
        meta["decode_params"] = decode_params
        self._write_cache_meta(meta_path, meta)
        return decode_result

    def _run_cohomap(
        self, asa_data: dict[str, Any], state: WorkflowState, log_callback
    ) -> dict[str, Path]:
        """Run CohoMap analysis (TDA + decode + plotting)."""
        from canns.analyzer.data.asa import plot_cohomap_multi
        from canns.analyzer.visualization import PlotConfigs

        tda_dir = self._results_dir(state) / "TDA"
        persistence_path = tda_dir / "persistence.npz"
        if not persistence_path.exists():
            raise ProcessingError("TDA results not found. Run TDA analysis first.")

        out_dir = self._results_dir(state) / "CohoMap"
        out_dir.mkdir(parents=True, exist_ok=True)

        decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

        params = state.analysis_params
        subsample = int(params.get("cohomap_subsample", 10))

        cohomap_path = out_dir / "cohomap.png"
        stage_hash = self._hash_obj(
            {
                "decode_hash": self._load_cache_meta(self._stage_cache_path(out_dir)).get(
                    "decode_hash"
                ),
                "plot": "cohomap",
                "subsample": subsample,
            }
        )
        if self._stage_cache_hit(out_dir, stage_hash, [cohomap_path]):
            log_callback("♻️ Using cached CohoMap plot.")
            return {"decoding": out_dir / "decoding.npz", "cohomap": cohomap_path}

        log_callback("Generating cohomology map...")
        pos = self._aligned_pos if self._aligned_pos is not None else asa_data
        config = PlotConfigs.cohomap(show=False, save_path=str(cohomap_path))
        plot_cohomap_multi(
            decoding_result=decode_result,
            position_data={"x": pos["x"], "y": pos["y"]},
            config=config,
            subsample=subsample,
        )

        self._write_cache_meta(
            self._stage_cache_path(out_dir),
            {
                **self._load_cache_meta(self._stage_cache_path(out_dir)),
                "hash": stage_hash,
            },
        )

        return {"decoding": out_dir / "decoding.npz", "cohomap": cohomap_path}

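    # Stage ordering: PathCompare and CohoSpace (below) reuse TDA/persistence.npz and the
    # shared CohoMap/decoding.npz via _load_or_run_decode, so TDA must be run first;
    # the guards in each stage raise ProcessingError otherwise.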
    def _run_pathcompare(
        self, asa_data: dict[str, Any], state: WorkflowState, log_callback
    ) -> dict[str, Path]:
        """Run path comparison visualization."""
        from canns.analyzer.data.asa import (
            align_coords_to_position,
            apply_angle_scale,
            plot_path_compare,
        )
        from canns.analyzer.data.asa.path import (
            find_coords_matrix,
            find_times_box,
            resolve_time_slice,
        )
        from canns.analyzer.visualization import PlotConfigs

        tda_dir = self._results_dir(state) / "TDA"
        persistence_path = tda_dir / "persistence.npz"
        if not persistence_path.exists():
            raise ProcessingError("TDA results not found. Run TDA analysis first.")

        decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

        # Create output directory
        out_dir = self._results_dir(state) / "PathCompare"
        out_dir.mkdir(parents=True, exist_ok=True)

        params = state.analysis_params
        angle_scale = params.get("angle_scale", "rad")
        dim_mode = params.get("dim_mode", "2d")
        dim = int(params.get("dim", 1))
        dim1 = int(params.get("dim1", 1))
        dim2 = int(params.get("dim2", 2))
        use_box = bool(params.get("use_box", False))
        interp_full = bool(params.get("interp_full", True))
        coords_key = params.get("coords_key")
        times_key = params.get("times_key")
        slice_mode = params.get("slice_mode", "time")
        tmin = params.get("tmin")
        tmax = params.get("tmax")
        imin = params.get("imin")
        imax = params.get("imax")
        stride = int(params.get("stride", 1))

        coords_raw, _ = find_coords_matrix(
            decode_result,
            coords_key=coords_key,
            prefer_box_fallback=use_box,
        )

        if dim_mode == "1d":
            idx = max(0, dim - 1)
            if idx >= coords_raw.shape[1]:
                raise ProcessingError(f"dim out of range for coords shape {coords_raw.shape}")
            coords2 = coords_raw[:, [idx]]
        else:
            idx1 = max(0, dim1 - 1)
            idx2 = max(0, dim2 - 1)
            if idx1 >= coords_raw.shape[1] or idx2 >= coords_raw.shape[1]:
                raise ProcessingError(f"dim1/dim2 out of range for coords shape {coords_raw.shape}")
            coords2 = coords_raw[:, [idx1, idx2]]

        pos = self._aligned_pos if self._aligned_pos is not None else asa_data
        t_full = np.asarray(pos["t"]).ravel()
        x_full = np.asarray(pos["x"]).ravel()
        y_full = np.asarray(pos["y"]).ravel()

        if use_box:
            if times_key:
                times_box = decode_result.get(times_key)
            else:
                times_box, _ = find_times_box(decode_result)
        else:
            times_box = None

        log_callback("Aligning decoded coordinates to position...")
        t_use, x_use, y_use, coords_use, _ = align_coords_to_position(
            t_full=t_full,
            x_full=x_full,
            y_full=y_full,
            coords2=coords2,
            use_box=use_box,
            times_box=times_box,
            interp_to_full=interp_full,
        )
        scale = str(angle_scale) if str(angle_scale) in {"rad", "deg", "unit", "auto"} else "rad"
        coords_use = apply_angle_scale(coords_use, scale)

        if slice_mode == "index":
            i0, i1 = resolve_time_slice(t_use, None, None, imin, imax)
        else:
            i0, i1 = resolve_time_slice(t_use, tmin, tmax, None, None)

        stride = max(1, stride)
        idx = slice(i0, i1, stride)
        t_use = t_use[idx]
        x_use = x_use[idx]
        y_use = y_use[idx]
        coords_use = coords_use[idx]

        out_path = out_dir / "path_compare.png"
        decode_meta = self._load_cache_meta(
            self._stage_cache_path(self._results_dir(state) / "CohoMap")
        )
        stage_hash = self._hash_obj(
            {
                "persistence": self._hash_file(persistence_path),
                "decode_hash": decode_meta.get("decode_hash"),
                "params": {
                    "angle_scale": scale,
                    "dim_mode": dim_mode,
                    "dim": dim,
                    "dim1": dim1,
                    "dim2": dim2,
                    "use_box": use_box,
                    "interp_full": interp_full,
                    "coords_key": coords_key,
                    "times_key": times_key,
                    "slice_mode": slice_mode,
                    "tmin": tmin,
                    "tmax": tmax,
                    "imin": imin,
                    "imax": imax,
                    "stride": stride,
                },
            }
        )
        if self._stage_cache_hit(out_dir, stage_hash, [out_path]):
            log_callback("♻️ Using cached PathCompare plot.")
            return {"path_compare": out_path}

        log_callback("Generating path comparison...")
        config = PlotConfigs.path_compare(show=False, save_path=str(out_path))
        plot_path_compare(x_use, y_use, coords_use, config=config)

        self._write_cache_meta(self._stage_cache_path(out_dir), {"hash": stage_hash})
        return {"path_compare": out_path}

    def _run_cohospace(
        self, asa_data: dict[str, Any], state: WorkflowState, log_callback
    ) -> dict[str, Path]:
        """Run cohomology space visualization."""
        from canns.analyzer.data.asa import (
            plot_cohospace_neuron,
            plot_cohospace_population,
            plot_cohospace_trajectory,
        )
        from canns.analyzer.data.asa.cohospace import (
            plot_cohospace_neuron_skewed,
            plot_cohospace_population_skewed,
        )
        from canns.analyzer.visualization import PlotConfigs

        tda_dir = self._results_dir(state) / "TDA"
        persistence_path = tda_dir / "persistence.npz"
        if not persistence_path.exists():
            raise ProcessingError("TDA results not found. Run TDA analysis first.")

        decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

        out_dir = self._results_dir(state) / "CohoSpace"
        out_dir.mkdir(parents=True, exist_ok=True)

        params = state.analysis_params
        artifacts: dict[str, Path] = {}

        coords = np.asarray(decode_result.get("coords"))
        coordsbox = np.asarray(decode_result.get("coordsbox"))
        if coords.ndim != 2 or coords.shape[1] < 1:
            raise ProcessingError("decode_result['coords'] must be 2D.")

        dim_mode = str(params.get("dim_mode", "2d"))
        dim = int(params.get("dim", 1))
        dim1 = int(params.get("dim1", 1))
        dim2 = int(params.get("dim2", 2))
        mode = str(params.get("mode", "fr"))
        top_percent = float(params.get("top_percent", 5.0))
        view = str(params.get("view", "both"))
        subsample = int(params.get("subsample", 2))
        unfold = str(params.get("unfold", "square"))
        skew_show_grid = bool(params.get("skew_show_grid", True))
        skew_tiles = int(params.get("skew_tiles", 0))

        def pick_coords(arr: np.ndarray) -> np.ndarray:
            if dim_mode == "1d":
                idx = max(0, dim - 1)
                if idx >= arr.shape[1]:
                    raise ProcessingError(f"dim out of range for coords shape {arr.shape}")
                one = arr[:, [idx]]
                return np.hstack([one, np.zeros_like(one)])
            idx1 = max(0, dim1 - 1)
            idx2 = max(0, dim2 - 1)
            if idx1 >= arr.shape[1] or idx2 >= arr.shape[1]:
                raise ProcessingError(f"dim1/dim2 out of range for coords shape {arr.shape}")
            return arr[:, [idx1, idx2]]

        coords2 = pick_coords(coords)
        coordsbox2 = pick_coords(coordsbox) if coordsbox.ndim == 2 else coords2

        if mode == "spike":
            activity = np.asarray(asa_data.get("spike"))
        else:
            activity = (
                self._embed_data
                if self._embed_data is not None
                else np.asarray(asa_data.get("spike"))
            )

        decode_meta = self._load_cache_meta(
            self._stage_cache_path(self._results_dir(state) / "CohoMap")
        )
        stage_hash = self._hash_obj(
            {
                "persistence": self._hash_file(persistence_path),
                "decode_hash": decode_meta.get("decode_hash"),
                "params": params,
            }
        )
        meta_path = self._stage_cache_path(out_dir)
        required = [out_dir / "cohospace_trajectory.png"]
        view = str(params.get("view", "both"))
        neuron_id = params.get("neuron_id")
        if view in {"both", "population"}:
            required.append(out_dir / "cohospace_population.png")
        if neuron_id is not None and view in {"both", "single"}:
            required.append(out_dir / f"cohospace_neuron_{neuron_id}.png")

        if self._stage_cache_hit(out_dir, stage_hash, required):
            log_callback("♻️ Using cached CohoSpace plots.")
            artifacts = {"trajectory": out_dir / "cohospace_trajectory.png"}
            if neuron_id is not None and view in {"both", "single"}:
                artifacts["neuron"] = out_dir / f"cohospace_neuron_{neuron_id}.png"
            if view in {"both", "population"}:
                artifacts["population"] = out_dir / "cohospace_population.png"
            return artifacts

        log_callback("Plotting cohomology space trajectory...")
        traj_path = out_dir / "cohospace_trajectory.png"
        traj_cfg = PlotConfigs.cohospace_trajectory(show=False, save_path=str(traj_path))
        plot_cohospace_trajectory(coords=coords2, times=None, subsample=subsample, config=traj_cfg)
        artifacts["trajectory"] = traj_path

        neuron_id = params.get("neuron_id", None)
        if neuron_id is not None and view in {"both", "single"}:
            log_callback(f"Plotting neuron {neuron_id}...")
            neuron_path = out_dir / f"cohospace_neuron_{neuron_id}.png"
            if unfold == "skew":
                plot_cohospace_neuron_skewed(
                    coords=coordsbox2,
                    activity=activity,
                    neuron_id=int(neuron_id),
                    mode=mode,
                    top_percent=top_percent,
                    save_path=str(neuron_path),
                    show=False,
                    show_grid=skew_show_grid,
                    n_tiles=skew_tiles,
                )
            else:
                neuron_cfg = PlotConfigs.cohospace_neuron(show=False, save_path=str(neuron_path))
                plot_cohospace_neuron(
                    coords=coordsbox2,
                    activity=activity,
                    neuron_id=int(neuron_id),
                    mode=mode,
                    top_percent=top_percent,
                    config=neuron_cfg,
                )
            artifacts["neuron"] = neuron_path

        if view in {"both", "population"}:
            log_callback("Plotting population activity...")
            pop_path = out_dir / "cohospace_population.png"
            neuron_ids = list(range(activity.shape[1]))
            if unfold == "skew":
                plot_cohospace_population_skewed(
                    coords=coords2,
                    activity=activity,
                    neuron_ids=neuron_ids,
                    mode=mode,
                    top_percent=top_percent,
                    save_path=str(pop_path),
                    show=False,
                    show_grid=skew_show_grid,
                    n_tiles=skew_tiles,
                )
            else:
                pop_cfg = PlotConfigs.cohospace_population(show=False, save_path=str(pop_path))
                plot_cohospace_population(
                    coords=coords2,
                    activity=activity,
                    neuron_ids=neuron_ids,
                    mode=mode,
                    top_percent=top_percent,
                    config=pop_cfg,
                )
            artifacts["population"] = pop_path

        self._write_cache_meta(meta_path, {"hash": stage_hash})
        return artifacts

    def _run_fr(
        self, asa_data: dict[str, Any], state: WorkflowState, log_callback
    ) -> dict[str, Path]:
        """Run firing rate heatmap analysis."""
        from canns.analyzer.data.asa import compute_fr_heatmap_matrix, save_fr_heatmap_png
        from canns.analyzer.visualization import PlotConfigs

        out_dir = self._results_dir(state) / "FR"
        out_dir.mkdir(parents=True, exist_ok=True)

        params = state.analysis_params
        neuron_range = params.get("neuron_range", None)
        time_range = params.get("time_range", None)
        normalize = params.get("normalize", "zscore_per_neuron")
        mode = params.get("mode", "fr")

        if mode == "spike":
            spike_data = asa_data.get("spike")
        else:
            spike_data = self._embed_data

        if spike_data is None:
            raise ProcessingError("No spike data available for FR.")

        out_path = out_dir / "fr_heatmap.png"
        stage_hash = self._hash_obj(
            {
                "embed_hash": self._embed_hash,
                "params": params,
            }
        )
        if self._stage_cache_hit(out_dir, stage_hash, [out_path]):
            log_callback("♻️ Using cached FR heatmap.")
            return {"fr_heatmap": out_path}

        log_callback("Computing firing rate heatmap...")
        fr_matrix = compute_fr_heatmap_matrix(
            spike_data,
            neuron_range=neuron_range,
            time_range=time_range,
            normalize=normalize,
        )

        config = PlotConfigs.fr_heatmap(show=False, save_path=str(out_path))
        save_fr_heatmap_png(fr_matrix, config=config, dpi=200)

        self._write_cache_meta(self._stage_cache_path(out_dir), {"hash": stage_hash})
        return {"fr_heatmap": out_path}

    def _run_frm(
        self, asa_data: dict[str, Any], state: WorkflowState, log_callback
    ) -> dict[str, Path]:
        """Run single neuron firing rate map."""
        from canns.analyzer.data.asa import compute_frm, plot_frm
        from canns.analyzer.visualization import PlotConfigs

        out_dir = self._results_dir(state) / "FRM"
        out_dir.mkdir(parents=True, exist_ok=True)

        params = state.analysis_params
        neuron_id = int(params.get("neuron_id", 0))
        bins = int(params.get("bin_size", 50))
        min_occupancy = int(params.get("min_occupancy", 1))
        smoothing = bool(params.get("smoothing", False))
        smooth_sigma = float(params.get("smooth_sigma", 2.0))
        mode = str(params.get("mode", "fr"))

        spike_data = self._embed_data if mode != "spike" else asa_data.get("spike")
        if spike_data is None:
            raise ProcessingError("No spike data available for FRM.")

        pos = self._aligned_pos if self._aligned_pos is not None else asa_data
        x = pos.get("x")
        y = pos.get("y")

        if x is None or y is None:
            raise ProcessingError("Position data (x,y) is required for FRM.")
        x = np.asarray(x)
        y = np.asarray(y)

        out_path = out_dir / f"frm_neuron_{neuron_id}.png"
        stage_hash = self._hash_obj(
            {
                "embed_hash": self._embed_hash,
                "params": params,
            }
        )
        if self._stage_cache_hit(out_dir, stage_hash, [out_path]):
            log_callback("♻️ Using cached FRM.")
            return {"frm": out_path}

        log_callback(f"Computing firing rate map for neuron {neuron_id}...")
        frm_result = compute_frm(
            spike_data,
            x,
            y,
            neuron_id=neuron_id,
            bins=max(1, bins),
            min_occupancy=min_occupancy,
            smoothing=smoothing,
            sigma=smooth_sigma,
            nan_for_empty=True,
        )

        config = PlotConfigs.frm(show=False, save_path=str(out_path))
        plot_frm(frm_result.frm, config=config, dpi=200)

        self._write_cache_meta(self._stage_cache_path(out_dir), {"hash": stage_hash})
        return {"frm": out_path}

    def _run_gridscore(
        self, asa_data: dict[str, Any], state: WorkflowState, log_callback
    ) -> dict[str, Path]:
        """Run grid score analysis."""

        log_callback("GridScore analysis is not implemented in the TUI yet.")
        raise ProcessingError("GridScore analysis is not implemented yet.")