canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +5 -1
- canns/analyzer/data/asa/__init__.py +27 -12
- canns/analyzer/data/asa/cohospace.py +336 -10
- canns/analyzer/data/asa/config.py +3 -0
- canns/analyzer/data/asa/embedding.py +48 -45
- canns/analyzer/data/asa/path.py +104 -2
- canns/analyzer/data/asa/plotting.py +88 -19
- canns/analyzer/data/asa/tda.py +11 -4
- canns/analyzer/data/cell_classification/__init__.py +97 -0
- canns/analyzer/data/cell_classification/core/__init__.py +26 -0
- canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
- canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
- canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
- canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
- canns/analyzer/data/cell_classification/io/__init__.py +5 -0
- canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
- canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
- canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
- canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
- canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
- canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
- canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
- canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
- canns/analyzer/metrics/__init__.py +2 -1
- canns/analyzer/visualization/core/config.py +46 -4
- canns/data/__init__.py +6 -1
- canns/data/datasets.py +154 -1
- canns/data/loaders.py +37 -0
- canns/pipeline/__init__.py +13 -9
- canns/pipeline/__main__.py +6 -0
- canns/pipeline/asa/runner.py +105 -41
- canns/pipeline/asa_gui/__init__.py +68 -0
- canns/pipeline/asa_gui/__main__.py +6 -0
- canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
- canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
- canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
- canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
- canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
- canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
- canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
- canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
- canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
- canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
- canns/pipeline/asa_gui/app.py +29 -0
- canns/pipeline/asa_gui/controllers/__init__.py +6 -0
- canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
- canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
- canns/pipeline/asa_gui/core/__init__.py +15 -0
- canns/pipeline/asa_gui/core/cache.py +14 -0
- canns/pipeline/asa_gui/core/runner.py +1936 -0
- canns/pipeline/asa_gui/core/state.py +324 -0
- canns/pipeline/asa_gui/core/worker.py +260 -0
- canns/pipeline/asa_gui/main_window.py +184 -0
- canns/pipeline/asa_gui/models/__init__.py +7 -0
- canns/pipeline/asa_gui/models/config.py +14 -0
- canns/pipeline/asa_gui/models/job.py +31 -0
- canns/pipeline/asa_gui/models/presets.py +21 -0
- canns/pipeline/asa_gui/resources/__init__.py +16 -0
- canns/pipeline/asa_gui/resources/dark.qss +167 -0
- canns/pipeline/asa_gui/resources/light.qss +163 -0
- canns/pipeline/asa_gui/resources/styles.qss +130 -0
- canns/pipeline/asa_gui/utils/__init__.py +1 -0
- canns/pipeline/asa_gui/utils/formatters.py +15 -0
- canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
- canns/pipeline/asa_gui/utils/validators.py +41 -0
- canns/pipeline/asa_gui/views/__init__.py +1 -0
- canns/pipeline/asa_gui/views/help_content.py +171 -0
- canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
- canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
- canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
- canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
- canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
- canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
- canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
- canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
- canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
- canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
- canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
- canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
- canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
- canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
- canns/pipeline/gallery/__init__.py +15 -5
- canns/pipeline/gallery/__main__.py +11 -0
- canns/pipeline/gallery/app.py +705 -0
- canns/pipeline/gallery/runner.py +790 -0
- canns/pipeline/gallery/state.py +51 -0
- canns/pipeline/gallery/styles.tcss +123 -0
- canns/pipeline/launcher.py +81 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
- canns-0.14.0.dist-info/RECORD +163 -0
- canns-0.14.0.dist-info/entry_points.txt +5 -0
- canns/pipeline/_base.py +0 -50
- canns-0.13.1.dist-info/RECORD +0 -89
- canns-0.13.1.dist-info/entry_points.txt +0 -3
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1936 @@
|
|
|
1
|
+
"""Pipeline execution wrapper for ASA GUI.
|
|
2
|
+
|
|
3
|
+
Provides synchronous pipeline execution that wraps canns.analyzer.data.asa APIs
|
|
4
|
+
and mirrors the TUI runner behavior for caching and artifacts.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import hashlib
|
|
10
|
+
import json
|
|
11
|
+
import time
|
|
12
|
+
from collections.abc import Callable
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from .state import WorkflowState, resolve_path
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class PipelineResult:
    """Outcome of a single pipeline stage run.

    Carries everything the GUI needs to report: whether the stage
    succeeded, which artifact files it produced, a one-line summary,
    an optional error message, and the wall-clock duration.
    """

    # True when the stage finished without raising.
    success: bool
    # Artifact label -> path of the produced file.
    artifacts: dict[str, Path]
    # Human-readable one-line status for the log panel.
    summary: str
    # Error text when success is False; None otherwise.
    error: str | None = None
    # Wall-clock seconds spent in the stage.
    elapsed_time: float = 0.0
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ProcessingError(RuntimeError):
    """Raised when a pipeline stage fails (bad input, missing artifact, or user cancel)."""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class PipelineRunner:
|
|
40
|
+
"""Synchronous pipeline execution wrapper."""
|
|
41
|
+
|
|
42
|
+
def __init__(self) -> None:
|
|
43
|
+
self._asa_data: dict[str, Any] | None = None
|
|
44
|
+
self._embed_data: np.ndarray | None = None
|
|
45
|
+
self._aligned_pos: dict[str, np.ndarray] | None = None
|
|
46
|
+
self._input_hash: str | None = None
|
|
47
|
+
self._embed_hash: str | None = None
|
|
48
|
+
self._mpl_ready: bool = False
|
|
49
|
+
|
|
50
|
+
def has_preprocessed_data(self) -> bool:
    """Return True once preprocessing has produced an embedding matrix."""
    embedded = self._embed_data
    return embedded is not None
|
|
52
|
+
|
|
53
|
+
@property
def embed_data(self) -> np.ndarray | None:
    """Read-only view of the current embedding matrix (None before preprocessing)."""
    current = self._embed_data
    return current
|
|
56
|
+
|
|
57
|
+
@property
def aligned_pos(self) -> dict[str, np.ndarray] | None:
    """Read-only view of the aligned trajectory dict (None when unavailable)."""
    current = self._aligned_pos
    return current
|
|
60
|
+
|
|
61
|
+
def reset_input(self) -> None:
    """Drop all loaded data and hashes so a fresh input can be processed."""
    for attr in ("_asa_data", "_embed_data", "_aligned_pos", "_input_hash", "_embed_hash"):
        setattr(self, attr, None)
|
|
67
|
+
|
|
68
|
+
def _json_safe(self, obj: Any) -> Any:
    """Recursively convert *obj* into something json.dumps can serialize.

    Paths become strings, tuples and lists become lists, dict keys are
    stringified, and numpy scalars are unwrapped via .item(); anything
    that still resists conversion is stringified as a last resort.
    """
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, (list, tuple)):
        return [self._json_safe(element) for element in obj]
    if isinstance(obj, dict):
        return {str(key): self._json_safe(value) for key, value in obj.items()}
    item_getter = getattr(obj, "item", None)
    if callable(item_getter):
        # numpy scalars (and 1-element arrays) expose .item(); anything
        # for which it fails falls back to its string form.
        try:
            return item_getter()
        except Exception:
            return str(obj)
    return obj
|
|
83
|
+
|
|
84
|
+
def _hash_bytes(self, data: bytes) -> str:
    """Hex MD5 digest of raw bytes; used only as a cache key, not for security."""
    digest = hashlib.md5(data)
    return digest.hexdigest()
|
|
86
|
+
|
|
87
|
+
def _hash_file(self, path: Path) -> str:
    """Return the MD5 hex digest of a file, streamed in 1 MiB chunks."""
    digest = hashlib.md5()
    with path.open("rb") as handle:
        while chunk := handle.read(1024 * 1024):
            digest.update(chunk)
    return digest.hexdigest()
|
|
93
|
+
|
|
94
|
+
def _hash_obj(self, obj: Any) -> str:
    """Deterministically hash an arbitrary (JSON-safe-convertible) object.

    sort_keys makes the serialization independent of dict insertion order,
    so equal dicts always produce equal hashes.
    """
    serialized = json.dumps(self._json_safe(obj), sort_keys=True, ensure_ascii=True)
    return self._hash_bytes(serialized.encode("utf-8"))
|
|
99
|
+
|
|
100
|
+
def _ensure_matplotlib_backend(self) -> None:
    """Force the non-interactive Agg backend, once per runner instance.

    Best-effort: any failure (matplotlib missing, backend locked in)
    is swallowed because plotting degrades gracefully elsewhere.
    """
    if self._mpl_ready:
        return
    try:
        import os

        os.environ.setdefault("MPLBACKEND", "Agg")
        import matplotlib

        try:
            matplotlib.use("Agg", force=True)
        except TypeError:
            # Older matplotlib signatures reject the force keyword.
            matplotlib.use("Agg")
    except Exception:
        pass
    # Mark done even on failure so we do not retry on every call.
    self._mpl_ready = True
|
|
116
|
+
|
|
117
|
+
def _cache_dir(self, state: WorkflowState) -> Path:
    """Hidden directory holding cached embeddings for this dataset."""
    return self._results_dir(state) / ".asa_cache"
|
|
119
|
+
|
|
120
|
+
def _results_dir(self, state: WorkflowState) -> Path:
    """Per-dataset results root: <workdir>/Results/<dataset_id>."""
    return state.workdir / "Results" / self._dataset_id(state)
|
|
124
|
+
|
|
125
|
+
def results_dir(self, state: WorkflowState) -> Path:
    """Public accessor for the per-dataset results directory."""
    return self._results_dir(state)
|
|
127
|
+
|
|
128
|
+
def _dataset_id(self, state: WorkflowState) -> str:
    """Build a stable, human-readable directory name for the current input.

    Combines the input file stem(s) with an 8-character prefix of the
    content hash so distinct files with the same name do not collide.
    """
    try:
        digest = self._input_hash or self._compute_input_hash(state)
    except Exception:
        # Hashing can fail before files are selected; degrade gracefully.
        digest = "unknown"
    prefix = digest[:8]

    if state.input_mode == "asa":
        asa_path = resolve_path(state, state.asa_file)
        stem = "asa" if asa_path is None else asa_path.stem
        return f"{stem}_{prefix}"
    if state.input_mode == "neuron_traj":
        neuron_path = resolve_path(state, state.neuron_file)
        traj_path = resolve_path(state, state.traj_file)
        neuron_stem = "neuron" if neuron_path is None else neuron_path.stem
        traj_stem = "traj" if traj_path is None else traj_path.stem
        return f"{neuron_stem}_{traj_stem}_{prefix}"
    return f"{state.input_mode}_{prefix}"
|
|
146
|
+
|
|
147
|
+
def _stage_cache_path(self, stage_dir: Path) -> Path:
    """Location of a stage's cache metadata file."""
    return stage_dir / "cache.json"
|
|
149
|
+
|
|
150
|
+
def _load_cache_meta(self, path: Path) -> dict[str, Any]:
    """Read stage cache metadata; return {} when absent or unreadable."""
    if not path.exists():
        return {}
    try:
        return json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        # Corrupt or half-written metadata is simply treated as a cache miss.
        return {}
|
|
157
|
+
|
|
158
|
+
def _write_cache_meta(self, path: Path, payload: dict[str, Any]) -> None:
    """Persist stage cache metadata as pretty-printed ASCII JSON."""
    text = json.dumps(self._json_safe(payload), ensure_ascii=True, indent=2)
    path.write_text(text, encoding="utf-8")
|
|
162
|
+
|
|
163
|
+
def _stage_cache_hit(
    self, stage_dir: Path, expected_hash: str, required_files: list[Path]
) -> bool:
    """True when every required artifact exists and the stored stage hash matches."""
    if not all(artifact.exists() for artifact in required_files):
        return False
    meta = self._load_cache_meta(self._stage_cache_path(stage_dir))
    return meta.get("hash") == expected_hash
|
|
170
|
+
|
|
171
|
+
def _compute_input_hash(self, state: WorkflowState) -> str:
    """Hash the selected input file(s) so results can be keyed by content.

    Raises ProcessingError when the files for the current mode are unset.
    """
    if state.input_mode == "asa":
        path = resolve_path(state, state.asa_file)
        if path is None:
            raise ProcessingError("ASA file not set.")
        return self._hash_obj({"mode": "asa", "file": self._hash_file(path)})
    if state.input_mode == "neuron_traj":
        neuron_path = resolve_path(state, state.neuron_file)
        traj_path = resolve_path(state, state.traj_file)
        if neuron_path is None or traj_path is None:
            raise ProcessingError("Neuron/trajectory files not set.")
        return self._hash_obj(
            {
                "mode": "neuron_traj",
                "neuron": self._hash_file(neuron_path),
                "traj": self._hash_file(traj_path),
            }
        )
    # Other modes are hashed by mode name alone.
    return self._hash_obj({"mode": state.input_mode})
|
|
190
|
+
|
|
191
|
+
def _load_npz_dict(self, path: Path) -> dict[str, Any]:
    """Load an .npz archive, unwrapping pickled result payloads when present.

    Persistence/decoding archives store a single pickled dict under a
    well-known key; plain archives are returned as a key -> array dict.
    """
    data = np.load(path, allow_pickle=True)
    for payload_key in ("persistence_result", "decode_result"):
        if payload_key in data.files:
            return data[payload_key].item()
    return {name: data[name] for name in data.files}
|
|
197
|
+
|
|
198
|
+
def _build_spike_matrix_from_events(self, asa: dict[str, Any]) -> np.ndarray:
    """Convert spike data into a dense (T, N) float count matrix.

    Accepts three layouts for asa['spike']:
      * a (T, N) numeric matrix aligned with asa['t'] — returned as float;
      * a dict of per-neuron event-time arrays (columns sorted by key);
      * a list/tuple of per-neuron event-time arrays (columns in order).
    Event times are binned onto the time grid asa['t'] assuming a uniform
    step dt = t[1] - t[0]; events falling outside the grid are dropped.

    Raises ProcessingError on missing keys, bad shapes, unsupported spike
    layouts, or a non-increasing time grid.
    """
    if "t" not in asa:
        raise ProcessingError("asa dict missing key 't' for spike mode.")
    t = np.asarray(asa["t"])
    if t.ndim != 1:
        raise ProcessingError(f"asa['t'] must be 1D, got shape={t.shape}")
    total_steps = t.shape[0]

    raw = asa.get("spike")
    if raw is None:
        raise ProcessingError("asa dict missing key 'spike' for spike mode.")
    arr = np.asarray(raw)

    # Case 1: already a dense numeric (T, N) matrix — validate and cast.
    if isinstance(raw, np.ndarray) and arr.ndim == 2 and np.issubdtype(arr.dtype, np.number):
        if arr.shape[0] != total_steps:
            raise ProcessingError(
                f"asa['spike'] matrix first dim {arr.shape[0]} != len(t)={total_steps}"
            )
        return arr.astype(float, copy=False)

    # A 0-d object array is a pickled container; unwrap it first.
    if isinstance(raw, np.ndarray) and arr.dtype == object and arr.size == 1:
        raw = arr.item()

    if isinstance(raw, dict):
        keys = sorted(raw.keys())
        spike_dict = {k: np.asarray(raw[k], dtype=float).ravel() for k in keys}
    elif isinstance(raw, (list, tuple)):
        keys = list(range(len(raw)))
        spike_dict = {i: np.asarray(raw[i], dtype=float).ravel() for i in keys}
    else:
        raise ProcessingError(
            "asa['spike'] must be a (T,N) numeric array, dict, or list-of-arrays for spike mode."
        )

    neuron_count = len(spike_dict)
    spike_mat = np.zeros((total_steps, neuron_count), dtype=float)
    if total_steps > 1:
        dt = float(t[1] - t[0])
    else:
        dt = 1.0
    # Fix: a zero or negative step (e.g. repeated timestamps) previously
    # flowed into the division below, yielding inf/garbage bin indices.
    if dt <= 0.0:
        raise ProcessingError(
            "asa['t'] must be strictly increasing to bin spike events."
        )
    t0 = float(t[0])

    for col, key in enumerate(keys):
        times = spike_dict[key]
        if times.size == 0:
            continue
        # Round each event onto the nearest grid index, then keep in-range.
        idx = np.rint((times - t0) / dt).astype(int)
        idx = idx[(idx >= 0) & (idx < total_steps)]
        if idx.size == 0:
            continue
        # Accumulate counts; add.at handles repeated indices correctly.
        np.add.at(spike_mat[:, col], idx, 1.0)

    return spike_mat
|
|
251
|
+
|
|
252
|
+
def _check_cancel(self, cancel_check: Callable[[], bool] | None) -> None:
    """Abort with ProcessingError when the GUI has requested cancellation."""
    if cancel_check is not None and cancel_check():
        raise ProcessingError("Cancelled by user")
|
|
255
|
+
|
|
256
|
+
def run_preprocessing(
    self,
    state: WorkflowState,
    log_callback: Callable[[str], None],
    progress_callback: Callable[[int], None],
    cancel_check: Callable[[], bool] | None = None,
) -> PipelineResult:
    """Load the input and produce an embedding (or pass raw spikes through).

    Results are cached on disk keyed by input content + parameters, so a
    repeated run with identical settings reuses the stored embedding.
    Never raises: failures are reported via PipelineResult.error.
    """
    start = time.time()

    try:
        self._check_cancel(cancel_check)

        log_callback("Loading data...")
        progress_callback(10)
        asa_data = self._load_data(state)
        self._asa_data = asa_data
        self._aligned_pos = None
        self._input_hash = self._compute_input_hash(state)

        self._check_cancel(cancel_check)

        log_callback(f"Preprocessing with {state.preprocess_method}...")
        progress_callback(30)

        if state.preprocess_method == "embed_spike_trains":
            from canns.analyzer.data.asa import SpikeEmbeddingConfig, embed_spike_trains

            user_params = state.preprocess_params if state.preprocess_params else {}
            # Start from library defaults, then overlay user overrides so
            # the cache key reflects the *effective* configuration.
            defaults = SpikeEmbeddingConfig()
            effective_params = {
                "res": defaults.res,
                "dt": defaults.dt,
                "sigma": defaults.sigma,
                "smooth": defaults.smooth,
                "speed_filter": defaults.speed_filter,
                "min_speed": defaults.min_speed,
            }
            effective_params.update(user_params)

            self._embed_hash = self._hash_obj(
                {
                    "input_hash": self._input_hash,
                    "method": state.preprocess_method,
                    "params": effective_params,
                }
            )
            cache_dir = self._cache_dir(state)
            cache_dir.mkdir(parents=True, exist_ok=True)
            cache_path = cache_dir / f"embed_{self._embed_hash}.npz"
            meta_path = cache_dir / f"embed_{self._embed_hash}.json"

            if cache_path.exists():
                # Cache hit: restore embedding (and trajectory, if stored).
                log_callback("♻️ Using cached embedding.")
                cached = np.load(cache_path, allow_pickle=True)
                self._embed_data = cached["embed_data"]
                if {"x", "y", "t"}.issubset(set(cached.files)):
                    self._aligned_pos = {
                        "x": cached["x"],
                        "y": cached["y"],
                        "t": cached["t"],
                    }
                progress_callback(100)
                elapsed = time.time() - start
                return PipelineResult(
                    success=True,
                    artifacts={"embedding": cache_path},
                    summary=f"Preprocessing reused cached embedding in {elapsed:.1f}s",
                    elapsed_time=elapsed,
                )

            config = SpikeEmbeddingConfig(**effective_params)

            log_callback("Running embed_spike_trains...")
            progress_callback(50)
            embed_result = embed_spike_trains(asa_data, config)

            # The API may return just the matrix, or (matrix, x, y, t).
            if isinstance(embed_result, tuple):
                embed_data = embed_result[0]
                if len(embed_result) >= 4 and embed_result[1] is not None:
                    self._aligned_pos = {
                        "x": embed_result[1],
                        "y": embed_result[2],
                        "t": embed_result[3],
                    }
            else:
                embed_data = embed_result

            self._embed_data = embed_data
            log_callback(f"Embed data shape: {embed_data.shape}")

            # Best-effort cache write; a failure only costs recomputation.
            try:
                payload = {"embed_data": embed_data}
                if self._aligned_pos is not None:
                    payload.update(self._aligned_pos)
                np.savez_compressed(cache_path, **payload)
                self._write_cache_meta(
                    meta_path,
                    {
                        "hash": self._embed_hash,
                        "input_hash": self._input_hash,
                        "params": effective_params,
                    },
                )
            except Exception as exc:
                log_callback(f"Warning: failed to cache embedding: {exc}")
        else:
            log_callback("No preprocessing - using raw spike data")
            spike = asa_data.get("spike")
            self._embed_hash = self._hash_obj(
                {
                    "input_hash": self._input_hash,
                    "method": state.preprocess_method,
                    "params": {},
                }
            )

            if isinstance(spike, np.ndarray) and spike.ndim == 2:
                self._embed_data = spike
                log_callback(f"Using spike matrix shape: {spike.shape}")
            else:
                log_callback(
                    "Warning: spike data is not a dense matrix, some analyses may fail"
                )
                self._embed_data = spike

        progress_callback(100)
        elapsed = time.time() - start

        return PipelineResult(
            success=True,
            artifacts={},
            summary=f"Preprocessing completed in {elapsed:.1f}s",
            elapsed_time=elapsed,
        )

    except Exception as exc:
        elapsed = time.time() - start
        log_callback(f"Error: {exc}")
        return PipelineResult(
            success=False,
            artifacts={},
            summary=f"Failed after {elapsed:.1f}s",
            error=str(exc),
            elapsed_time=elapsed,
        )
|
|
401
|
+
|
|
402
|
+
def run_analysis(
    self,
    state: WorkflowState,
    log_callback: Callable[[str], None],
    progress_callback: Callable[[int], None],
    cancel_check: Callable[[], bool] | None = None,
) -> PipelineResult:
    """Execute the selected analysis mode and collect its artifact paths.

    Never raises: failures (including an unknown mode) are reported via
    PipelineResult.error, with any partial artifacts attached.
    """
    start = time.time()
    artifacts: dict[str, Path] = {}

    try:
        self._check_cancel(cancel_check)

        log_callback("Loading data...")
        progress_callback(10)
        # Reuse data already loaded by preprocessing when available.
        asa_data = self._asa_data if self._asa_data is not None else self._load_data(state)
        if self._input_hash is None:
            self._input_hash = self._compute_input_hash(state)

        self._ensure_matplotlib_backend()

        log_callback(f"Running {state.analysis_mode} analysis...")
        progress_callback(40)

        mode = state.analysis_mode.lower()
        # Table-driven dispatch; each handler returns a name -> path dict.
        handlers = {
            "tda": lambda: self._run_tda(asa_data, state, log_callback),
            "decode": lambda: self._run_decode_only(asa_data, state, log_callback),
            "cohomap": lambda: self._run_cohomap(asa_data, state, log_callback),
            "pathcompare": lambda: self._run_pathcompare(asa_data, state, log_callback),
            "cohospace": lambda: self._run_cohospace(asa_data, state, log_callback),
            "fr": lambda: self._run_fr(asa_data, state, log_callback),
            "frm": lambda: self._run_frm(asa_data, state, log_callback),
            "gridscore": lambda: self._run_gridscore(
                asa_data, state, log_callback, progress_callback
            ),
            "gridscore_inspect": lambda: self._run_gridscore_inspect(
                asa_data, state, log_callback, progress_callback
            ),
        }
        handler = handlers.get(mode)
        if handler is None:
            raise ProcessingError(f"Unknown analysis mode: {state.analysis_mode}")
        artifacts = handler()

        progress_callback(100)
        elapsed = time.time() - start

        return PipelineResult(
            success=True,
            artifacts=artifacts,
            summary=f"Completed {state.analysis_mode} analysis in {elapsed:.1f}s",
            elapsed_time=elapsed,
        )

    except Exception as exc:
        elapsed = time.time() - start
        log_callback(f"Error: {exc}")
        return PipelineResult(
            success=False,
            artifacts=artifacts,
            summary=f"Failed after {elapsed:.1f}s",
            error=str(exc),
            elapsed_time=elapsed,
        )
|
|
470
|
+
|
|
471
|
+
def _load_data(self, state: WorkflowState) -> dict[str, Any]:
    """Load input arrays for the configured input mode.

    "asa" mode returns the whole .npz archive as a dict; "neuron_traj"
    merges a spike file and a trajectory file into the standard ASA
    layout (keys "spike", "x", "y", "t").
    """
    if state.input_mode == "asa":
        archive = np.load(resolve_path(state, state.asa_file), allow_pickle=True)
        return {name: archive[name] for name in archive.files}
    if state.input_mode == "neuron_traj":
        neuron_data = np.load(resolve_path(state, state.neuron_file), allow_pickle=True)
        traj_data = np.load(resolve_path(state, state.traj_file), allow_pickle=True)
        return {
            "spike": neuron_data.get("spike", neuron_data),
            "x": traj_data["x"],
            "y": traj_data["y"],
            "t": traj_data["t"],
        }
    raise ProcessingError(f"Unknown input mode: {state.input_mode}")
|
|
488
|
+
|
|
489
|
+
def _run_tda(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Run persistent homology on the embedding and render a barcode plot.

    Returns {"persistence": <npz path>, "barcode": <png path>}; reuses
    cached results when the embedding hash and parameters are unchanged.
    """
    from canns.analyzer.data.asa import TDAConfig, tda_vis
    from canns.analyzer.data.asa.tda import _plot_barcode, _plot_barcode_with_shuffle

    out_dir = self._results_dir(state) / "TDA"
    out_dir.mkdir(parents=True, exist_ok=True)

    params = state.analysis_params
    config = TDAConfig(
        dim=params.get("dim", 6),
        num_times=params.get("num_times", 5),
        active_times=params.get("active_times", 15000),
        k=params.get("k", 1000),
        n_points=params.get("n_points", 1200),
        metric=params.get("metric", "cosine"),
        nbs=params.get("nbs", 800),
        maxdim=params.get("maxdim", 1),
        coeff=params.get("coeff", 47),
        show=False,
        do_shuffle=params.get("do_shuffle", False),
        num_shuffles=params.get("num_shuffles", 1000),
        progress_bar=False,
        standardize=False,
    )

    log_callback("Computing persistent homology...")

    if self._embed_data is None:
        raise ProcessingError("No preprocessed data available. Run preprocessing first.")
    if not isinstance(self._embed_data, np.ndarray) or self._embed_data.ndim != 2:
        raise ProcessingError(
            "TDA requires a dense spike matrix (T,N). "
            "Please choose 'Embed Spike Trains' in preprocessing or provide a dense spike matrix in the .npz."
        )

    persistence_path = out_dir / "persistence.npz"
    barcode_path = out_dir / "barcode.png"

    # Cache key covers the embedding content and every TDA parameter.
    embed_hash = self._embed_hash or self._hash_obj({"embed": "unknown"})
    tda_hash = self._hash_obj({"embed_hash": embed_hash, "params": params})

    if self._stage_cache_hit(out_dir, tda_hash, [persistence_path, barcode_path]):
        log_callback("♻️ Using cached TDA results.")
        return {"persistence": persistence_path, "barcode": barcode_path}

    embed_data = self._embed_data
    if params.get("standardize", False):
        # Standardization is applied here (not via TDAConfig) so failures
        # surface as a ProcessingError with context.
        try:
            from sklearn.preprocessing import StandardScaler

            embed_data = StandardScaler().fit_transform(embed_data)
        except Exception as exc:
            raise ProcessingError(f"StandardScaler failed: {exc}") from exc

    result = tda_vis(embed_data=embed_data, config=config)

    np.savez_compressed(persistence_path, persistence_result=result)

    # Plot saving is best-effort; the persistence data is already on disk.
    try:
        persistence = result.get("persistence")
        shuffle_max = result.get("shuffle_max")
        if config.do_shuffle and shuffle_max is not None:
            fig = _plot_barcode_with_shuffle(persistence, shuffle_max)
        else:
            fig = _plot_barcode(persistence)
        fig.savefig(barcode_path, dpi=200, bbox_inches="tight")
        try:
            import matplotlib.pyplot as plt

            plt.close(fig)
        except Exception:
            pass
    except Exception as exc:
        log_callback(f"Warning: failed to save barcode: {exc}")

    self._write_cache_meta(
        self._stage_cache_path(out_dir),
        {"hash": tda_hash, "embed_hash": embed_hash, "params": params},
    )

    return {"persistence": persistence_path, "barcode": barcode_path}
|
|
575
|
+
|
|
576
|
+
def _load_or_run_decode(
    self,
    asa_data: dict[str, Any],
    persistence_path: Path,
    state: WorkflowState,
    log_callback,
) -> dict[str, Any]:
    """Return circular-coordinate decoding, reusing the on-disk cache when valid.

    The cache key combines the persistence file content with the decode
    options, so either changing invalidates the cached decoding.
    """
    from canns.analyzer.data.asa import (
        decode_circular_coordinates,
        decode_circular_coordinates_multi,
    )

    decode_dir = self._results_dir(state) / "CohoMap"
    decode_dir.mkdir(parents=True, exist_ok=True)
    decode_path = decode_dir / "decoding.npz"

    params = state.analysis_params
    decode_version = str(params.get("decode_version", "v2"))
    num_circ = int(params.get("num_circ", 2))
    decode_params = {
        "real_ground": params.get("real_ground", True),
        "real_of": params.get("real_of", True),
        "decode_version": decode_version,
        "num_circ": num_circ,
    }
    persistence_hash = self._hash_file(persistence_path)
    decode_hash = self._hash_obj(
        {"persistence_hash": persistence_hash, "params": decode_params}
    )

    meta_path = self._stage_cache_path(decode_dir)
    meta = self._load_cache_meta(meta_path)
    if decode_path.exists() and meta.get("decode_hash") == decode_hash:
        log_callback("♻️ Using cached decoding.")
        return self._load_npz_dict(decode_path)

    log_callback("Decoding circular coordinates...")
    persistence_result = self._load_npz_dict(persistence_path)
    if decode_version == "v0":
        decode_result = decode_circular_coordinates(
            persistence_result=persistence_result,
            spike_data=asa_data,
            real_ground=decode_params["real_ground"],
            real_of=decode_params["real_of"],
            save_path=str(decode_path),
        )
    else:
        # v2 decodes against the preprocessed embedding, not raw spikes.
        if self._embed_data is None:
            raise ProcessingError("No preprocessed data available for decode v2.")
        spike_data = dict(asa_data)
        spike_data["spike"] = self._embed_data
        decode_result = decode_circular_coordinates_multi(
            persistence_result=persistence_result,
            spike_data=spike_data,
            save_path=str(decode_path),
            num_circ=num_circ,
        )

    # Record the new hashes, preserving any other metadata in the file.
    meta["decode_hash"] = decode_hash
    meta["persistence_hash"] = persistence_hash
    meta["decode_params"] = decode_params
    self._write_cache_meta(meta_path, meta)
    return decode_result
|
|
639
|
+
|
|
640
|
+
def _run_decode_only(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Run only the decoding stage; requires existing TDA persistence output."""
    persistence_path = self._results_dir(state) / "TDA" / "persistence.npz"
    if not persistence_path.exists():
        raise ProcessingError("TDA results not found. Run TDA analysis first.")

    self._load_or_run_decode(asa_data, persistence_path, state, log_callback)
    return {"decoding": self._results_dir(state) / "CohoMap" / "decoding.npz"}
|
|
650
|
+
|
|
651
|
+
def _run_cohomap(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Render the cohomology-map ("CohoMap") figure from decoded coordinates.

    Requires the persistence diagram written by a prior TDA stage; runs (or
    reuses a cached) circular-coordinate decode, then plots the decoded
    coordinates against the animal's position. The plot itself is cached by
    a stage hash so identical re-runs skip the plotting step.

    Returns a mapping of artifact names ("decoding", "cohomap") to paths.
    Raises ProcessingError when the TDA persistence file is missing.
    """
    # Local imports so plotting dependencies load only when this stage runs.
    from canns.analyzer.data.asa import plot_cohomap_multi
    from canns.analyzer.visualization import PlotConfigs

    tda_dir = self._results_dir(state) / "TDA"
    persistence_path = tda_dir / "persistence.npz"
    if not persistence_path.exists():
        raise ProcessingError("TDA results not found. Run TDA analysis first.")

    out_dir = self._results_dir(state) / "CohoMap"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Decode runs first: it also writes "decode_hash" into this stage's
    # cache meta, which the stage hash below reads back.
    decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

    params = state.analysis_params
    subsample = int(params.get("cohomap_subsample", 10))

    cohomap_path = out_dir / "cohomap.png"
    # Stage hash keys the plot on the decode identity plus plot parameters.
    stage_hash = self._hash_obj(
        {
            "decode_hash": self._load_cache_meta(self._stage_cache_path(out_dir)).get(
                "decode_hash"
            ),
            "plot": "cohomap",
            "subsample": subsample,
        }
    )
    if self._stage_cache_hit(out_dir, stage_hash, [cohomap_path]):
        log_callback("♻️ Using cached CohoMap plot.")
        return {"decoding": out_dir / "decoding.npz", "cohomap": cohomap_path}

    log_callback("Generating cohomology map...")
    # Prefer time-aligned position when preprocessing produced it.
    pos = self._aligned_pos if self._aligned_pos is not None else asa_data
    config = PlotConfigs.cohomap(show=False, save_path=str(cohomap_path))
    plot_cohomap_multi(
        decoding_result=decode_result,
        position_data={"x": pos["x"], "y": pos["y"]},
        config=config,
        subsample=subsample,
    )

    # Merge the new stage hash into the existing meta so the decode_hash
    # written earlier by _load_or_run_decode is preserved.
    self._write_cache_meta(
        self._stage_cache_path(out_dir),
        {
            **self._load_cache_meta(self._stage_cache_path(out_dir)),
            "hash": stage_hash,
        },
    )

    return {"decoding": out_dir / "decoding.npz", "cohomap": cohomap_path}
|
|
703
|
+
|
|
704
|
+
def _run_pathcompare(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Compare the physical trajectory against the decoded circular path.

    Pipeline: load/reuse the decode result, pick decoded dimension(s),
    align decoded coordinates to the position timeline, apply angle scaling
    and optional wrapping, slice by time or index, then render a static
    comparison figure and (optionally) a GIF/MP4 animation. Outputs are
    cached per stage hash over all effective parameters.

    Returns artifact paths keyed by "path_compare" and, when an animation
    was produced, "path_compare_gif" or "path_compare_mp4".
    Raises ProcessingError when TDA results are missing or the requested
    dimensions are out of range.
    """
    from canns.analyzer.data.asa import (
        align_coords_to_position_1d,
        align_coords_to_position_2d,
        apply_angle_scale,
        plot_path_compare_1d,
        plot_path_compare_2d,
    )
    from canns.analyzer.data.asa.path import (
        find_coords_matrix,
        find_times_box,
        resolve_time_slice,
    )
    from canns.analyzer.visualization import PlotConfigs

    tda_dir = self._results_dir(state) / "TDA"
    persistence_path = tda_dir / "persistence.npz"
    if not persistence_path.exists():
        raise ProcessingError("TDA results not found. Run TDA analysis first.")

    decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

    out_dir = self._results_dir(state) / "PathCompare"
    out_dir.mkdir(parents=True, exist_ok=True)

    params = state.analysis_params or {}
    # Parameters may live under a "pathcompare" sub-dict or flat at top level.
    pc_params = (
        params.get("pathcompare") if isinstance(params.get("pathcompare"), dict) else params
    )

    def _param(key: str, default: Any = None) -> Any:
        # Safe lookup that tolerates pc_params not being a dict at all.
        return pc_params.get(key, default) if isinstance(pc_params, dict) else default

    # "theta_scale" is the legacy name for "angle_scale".
    angle_scale = _param("angle_scale", _param("theta_scale", "rad"))
    dim_mode = _param("dim_mode", "2d")
    dim = int(_param("dim", 1))  # 1-based decoded dimension (1d mode)
    dim1 = int(_param("dim1", 1))  # 1-based first dimension (2d mode)
    dim2 = int(_param("dim2", 2))  # 1-based second dimension (2d mode)
    use_box = bool(_param("use_box", False))
    interp_full = bool(_param("interp_full", True))
    coords_key = _param("coords_key")
    # "times_box_key" is the legacy name for "times_key".
    times_key = _param("times_key", _param("times_box_key"))
    slice_mode = _param("slice_mode", "time")
    tmin = _param("tmin")
    tmax = _param("tmax")
    imin = _param("imin")
    imax = _param("imax")
    stride = int(_param("stride", 1))
    tail = int(_param("tail", 300))  # trailing-window length for animations
    fps = int(_param("fps", 20))
    no_wrap = bool(_param("no_wrap", False))
    animation_format = str(_param("animation_format", "none")).lower()
    if animation_format not in {"none", "gif", "mp4"}:
        animation_format = "none"

    # Locate the decoded coordinate matrix inside the decode result.
    coords_raw, _ = find_coords_matrix(
        decode_result,
        coords_key=coords_key,
        prefer_box_fallback=use_box,
    )

    # Convert 1-based user dims to 0-based columns and validate range.
    if dim_mode == "1d":
        idx = max(0, dim - 1)
        if idx >= coords_raw.shape[1]:
            raise ProcessingError(f"dim out of range for coords shape {coords_raw.shape}")
        coords1 = coords_raw[:, idx]
    else:
        idx1 = max(0, dim1 - 1)
        idx2 = max(0, dim2 - 1)
        if idx1 >= coords_raw.shape[1] or idx2 >= coords_raw.shape[1]:
            raise ProcessingError(f"dim1/dim2 out of range for coords shape {coords_raw.shape}")
        coords2 = coords_raw[:, [idx1, idx2]]

    # Prefer the aligned position produced during preprocessing.
    pos = self._aligned_pos if self._aligned_pos is not None else asa_data
    t_full = np.asarray(pos["t"]).ravel()
    x_full = np.asarray(pos["x"]).ravel()
    y_full = np.asarray(pos["y"]).ravel()

    # When using box coordinates, also fetch their matching time axis.
    if use_box:
        if times_key:
            times_box = decode_result.get(times_key)
        else:
            times_box, _ = find_times_box(decode_result)
    else:
        times_box = None

    log_callback("Aligning decoded coordinates to position...")
    if dim_mode == "1d":
        t_use, x_use, y_use, coords_use, _ = align_coords_to_position_1d(
            t_full=t_full,
            x_full=x_full,
            y_full=y_full,
            coords1=coords1,
            use_box=use_box,
            times_box=times_box,
            interp_to_full=interp_full,
        )
    else:
        t_use, x_use, y_use, coords_use, _ = align_coords_to_position_2d(
            t_full=t_full,
            x_full=x_full,
            y_full=y_full,
            coords2=coords2,
            use_box=use_box,
            times_box=times_box,
            interp_to_full=interp_full,
        )
    # Unknown scale names silently fall back to radians.
    scale = str(angle_scale) if str(angle_scale) in {"rad", "deg", "unit", "auto"} else "rad"
    coords_use = apply_angle_scale(coords_use, scale)
    if not no_wrap:
        # NOTE(review): wraps modulo 2*pi regardless of the chosen scale —
        # presumably apply_angle_scale returns radians-compatible values
        # here; confirm for "deg"/"unit" scales.
        coords_use = np.mod(coords_use, 2 * np.pi)

    # Restrict either by sample index or by time range.
    if slice_mode == "index":
        i0, i1 = resolve_time_slice(t_use, None, None, imin, imax)
    else:
        i0, i1 = resolve_time_slice(t_use, tmin, tmax, None, None)

    stride = max(1, stride)
    idx = slice(i0, i1, stride)
    t_use = t_use[idx]
    x_use = x_use[idx]
    y_use = y_use[idx]
    coords_use = coords_use[idx]

    out_path = out_dir / "path_compare.png"
    anim_path: Path | None = None
    if animation_format == "gif":
        anim_path = out_dir / "path_compare.gif"
    elif animation_format == "mp4":
        anim_path = out_dir / "path_compare.mp4"
    # decode_hash lives in the CohoMap stage meta (written by the decode step).
    decode_meta = self._load_cache_meta(
        self._stage_cache_path(self._results_dir(state) / "CohoMap")
    )
    # Hash covers every parameter that changes the rendered output.
    stage_hash = self._hash_obj(
        {
            "persistence": self._hash_file(persistence_path),
            "decode_hash": decode_meta.get("decode_hash"),
            "params": {
                "angle_scale": scale,
                "dim_mode": dim_mode,
                "dim": dim,
                "dim1": dim1,
                "dim2": dim2,
                "use_box": use_box,
                "interp_full": interp_full,
                "coords_key": coords_key,
                "times_key": times_key,
                "slice_mode": slice_mode,
                "tmin": tmin,
                "tmax": tmax,
                "imin": imin,
                "imax": imax,
                "stride": stride,
                "tail": tail,
                "fps": fps,
                "no_wrap": no_wrap,
                "animation_format": animation_format,
            },
        }
    )
    required = [out_path]
    if anim_path is not None:
        required.append(anim_path)
    if self._stage_cache_hit(out_dir, stage_hash, required):
        log_callback("♻️ Using cached PathCompare plot.")
        artifacts = {"path_compare": out_path}
        if anim_path is not None:
            if anim_path.suffix == ".gif":
                artifacts["path_compare_gif"] = anim_path
            else:
                artifacts["path_compare_mp4"] = anim_path
        return artifacts

    log_callback("Generating path comparison...")
    if dim_mode == "1d":
        config = PlotConfigs.path_compare_1d(show=False, save_path=str(out_path))
        plot_path_compare_1d(x_use, y_use, coords_use, config=config)
    else:
        config = PlotConfigs.path_compare_2d(show=False, save_path=str(out_path))
        plot_path_compare_2d(x_use, y_use, coords_use, config=config)

    artifacts: dict[str, Path] = {"path_compare": out_path}

    # Animation rendering is best-effort: a failure is logged, not raised,
    # so the static figure above is still delivered.
    if anim_path is not None:
        try:
            if dim_mode == "1d":
                self._render_pathcompare_1d_animation(
                    x_use,
                    y_use,
                    coords_use,
                    t_use,
                    anim_path,
                    tail=tail,
                    fps=fps,
                    log_callback=log_callback,
                )
            else:
                self._render_pathcompare_2d_animation(
                    x_use,
                    y_use,
                    coords_use,
                    t_use,
                    anim_path,
                    tail=tail,
                    fps=fps,
                    log_callback=log_callback,
                )
            if anim_path.suffix == ".gif":
                artifacts["path_compare_gif"] = anim_path
            else:
                artifacts["path_compare_mp4"] = anim_path
        except Exception as e:
            log_callback(f"Warning: failed to render PathCompare animation: {e}")

    self._write_cache_meta(self._stage_cache_path(out_dir), {"hash": stage_hash})
    return artifacts
|
|
922
|
+
|
|
923
|
+
def _render_pathcompare_1d_animation(
    self,
    x: np.ndarray,
    y: np.ndarray,
    coords: np.ndarray,
    t: np.ndarray,
    save_path: Path,
    *,
    tail: int,
    fps: int,
    log_callback,
) -> None:
    """Animate the physical trajectory beside the decoded 1D circular path.

    Left panel: the (x, y) position with a trailing window of ``tail``
    frames. Right panel: the decoded angle mapped onto the unit circle.
    Frame count is capped at 20000 by uniform downsampling before render.

    Raises:
        ProcessingError: if there are no frames to render.
    """
    import matplotlib.pyplot as plt

    pos_x = np.asarray(x).ravel()
    pos_y = np.asarray(y).ravel()
    angles = np.asarray(coords).ravel()
    times = np.asarray(t).ravel()

    frame_total = len(angles)
    if frame_total == 0:
        raise ProcessingError("PathCompare animation has no frames.")

    # Cap rendered frames to keep runtime and file size reasonable.
    if frame_total > 20000:
        factor = int(np.ceil(frame_total / 20000))
        keep = np.arange(0, frame_total, factor)
        pos_x = pos_x[keep]
        pos_y = pos_y[keep]
        angles = angles[keep]
        times = times[keep]
        frame_total = len(angles)
        log_callback(f"PathCompare animation downsampled by x{factor} (frames={frame_total}).")

    fig, (ax_phys, ax_circ) = plt.subplots(1, 2, figsize=(12, 5), dpi=120)

    # Physical-space panel with a 5% margin around the data extent
    # (degenerate extents get a fixed margin of 1.0).
    ax_phys.set_title("Physical path")
    ax_phys.set_xlabel("x")
    ax_phys.set_ylabel("y")
    ax_phys.set_aspect("equal")
    lo_x, hi_x = np.min(pos_x), np.max(pos_x)
    lo_y, hi_y = np.min(pos_y), np.max(pos_y)
    margin_x = (hi_x - lo_x) * 0.05 if hi_x > lo_x else 1.0
    margin_y = (hi_y - lo_y) * 0.05 if hi_y > lo_y else 1.0
    ax_phys.set_xlim(lo_x - margin_x, hi_x + margin_x)
    ax_phys.set_ylim(lo_y - margin_y, hi_y + margin_y)

    # Unit-circle panel for the decoded angle.
    ax_circ.set_title("Decoded coho path (1D)")
    ax_circ.set_aspect("equal")
    ax_circ.axis("off")
    ax_circ.set_xlim(-1.2, 1.2)
    ax_circ.set_ylim(-1.2, 1.2)

    (trail_phys,) = ax_phys.plot([], [], lw=1.0)
    dot_phys = ax_phys.scatter([], [], s=30)
    (trail_circ,) = ax_circ.plot([], [], lw=1.0)
    dot_circ = ax_circ.scatter([], [], s=30)
    caption = fig.suptitle("", y=1.02)

    def update(k: int) -> None:
        # Trailing window start; tail <= 0 means "show everything so far".
        start = max(0, k - tail) if tail > 0 else 0
        trail_phys.set_data(pos_x[start : k + 1], pos_y[start : k + 1])
        dot_phys.set_offsets(np.array([[pos_x[k], pos_y[k]]]))

        window = angles[start : k + 1]
        trail_circ.set_data(np.cos(window), np.sin(window))
        dot_circ.set_offsets(np.array([[np.cos(angles[k]), np.sin(angles[k])]]))

        caption.set_text(f"t = {float(times[k]):.3f}s (frame {k + 1}/{frame_total})")

    fig.tight_layout()
    self._save_animation(fig, update, frame_total, save_path, fps, log_callback)
    plt.close(fig)
|
|
999
|
+
|
|
1000
|
+
def _render_pathcompare_2d_animation(
    self,
    x: np.ndarray,
    y: np.ndarray,
    coords: np.ndarray,
    t: np.ndarray,
    save_path: Path,
    *,
    tail: int,
    fps: int,
    log_callback,
) -> None:
    """Animate the physical trajectory beside the decoded 2D torus path.

    The two decoded angles are skew-transformed into a fundamental
    parallelogram (hexagonal-lattice unfolding); the trail is wrapped
    inside it via snake_wrap_trail_in_parallelogram. Frame count is capped
    at 20000 by uniform downsampling.

    Raises:
        ProcessingError: if coords is not (N, >=2) or there are no frames.
    """
    import matplotlib.pyplot as plt

    from canns.analyzer.data.asa.path import (
        draw_base_parallelogram,
        skew_transform,
        snake_wrap_trail_in_parallelogram,
    )

    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()
    coords = np.asarray(coords)
    t = np.asarray(t).ravel()

    if coords.ndim != 2 or coords.shape[1] < 2:
        raise ProcessingError("PathCompare 2D animation requires coords with 2 columns.")

    n_frames = len(coords)
    if n_frames == 0:
        raise ProcessingError("PathCompare animation has no frames.")

    # Cap rendered frames to keep runtime and file size reasonable.
    if n_frames > 20000:
        factor = int(np.ceil(n_frames / 20000))
        idx = np.arange(0, n_frames, factor)
        x = x[idx]
        y = y[idx]
        coords = coords[idx]
        t = t[idx]
        n_frames = len(coords)
        log_callback(f"PathCompare animation downsampled by x{factor} (frames={n_frames}).")

    # Map the two angle columns into skewed (parallelogram) coordinates.
    xy_skew = skew_transform(coords[:, :2])

    # Basis vectors of the fundamental parallelogram for a 2*pi x 2*pi
    # torus under the skew transform; bounds come from its four corners.
    e1 = np.array([2 * np.pi, 0.0])
    e2 = np.array([np.pi, np.sqrt(3) * np.pi])
    corners = np.vstack([[0.0, 0.0], e1, e2, e1 + e2])
    xm, ym = corners.min(axis=0)
    xM, yM = corners.max(axis=0)
    px2 = 0.05 * (xM - xm + 1e-9)  # 5% margin (epsilon guards zero extent)
    py2 = 0.05 * (yM - ym + 1e-9)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), dpi=120)

    # Left panel: physical position with a 5% margin around the data.
    ax1.set_title("Physical path")
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.set_aspect("equal")
    x_min, x_max = np.min(x), np.max(x)
    y_min, y_max = np.min(y), np.max(y)
    pad_x = 0.05 * (x_max - x_min + 1e-9)
    pad_y = 0.05 * (y_max - y_min + 1e-9)
    ax1.set_xlim(x_min - pad_x, x_max + pad_x)
    ax1.set_ylim(y_min - pad_y, y_max + pad_y)

    # Right panel: skewed torus coordinates with the base parallelogram.
    ax2.set_title("Torus path (skew)")
    ax2.set_xlabel(r"$\theta_1 + \frac{1}{2}\theta_2$")
    ax2.set_ylabel(r"$\frac{\sqrt{3}}{2}\theta_2$")
    ax2.set_aspect("equal")
    draw_base_parallelogram(ax2)
    ax2.set_xlim(xm - px2, xM + px2)
    ax2.set_ylim(ym - py2, yM + py2)

    (phys_trail,) = ax1.plot([], [], lw=1.0)
    phys_dot = ax1.scatter([], [], s=30)
    (tor_trail,) = ax2.plot([], [], lw=1.0)
    tor_dot = ax2.scatter([], [], s=30)
    title_text = fig.suptitle("", y=1.02)

    def update(k: int) -> None:
        # Trailing window start; tail <= 0 means "show everything so far".
        a0 = max(0, k - tail) if tail > 0 else 0
        xs = x[a0 : k + 1]
        ys = y[a0 : k + 1]
        phys_trail.set_data(xs, ys)
        phys_dot.set_offsets(np.array([[x[k], y[k]]]))

        # Wrap the trail segment inside the parallelogram so it does not
        # draw long jumps when crossing the torus boundary.
        seg = snake_wrap_trail_in_parallelogram(xy_skew[a0 : k + 1], e1=e1, e2=e2)
        tor_trail.set_data(seg[:, 0], seg[:, 1])
        tor_dot.set_offsets(np.array([[xy_skew[k, 0], xy_skew[k, 1]]]))

        title_text.set_text(f"t = {float(t[k]):.3f}s (frame {k + 1}/{n_frames})")

    fig.tight_layout()
    self._save_animation(fig, update, n_frames, save_path, fps, log_callback)
    plt.close(fig)
|
|
1095
|
+
|
|
1096
|
+
def _save_animation(
    self,
    fig,
    update_func,
    n_frames: int,
    save_path: Path,
    fps: int,
    log_callback,
) -> None:
    """Save a frame-by-frame animation of ``fig`` to ``save_path``.

    Backend selection by target: ".gif" always uses matplotlib's
    PillowWriter; otherwise a backend is chosen via
    select_animation_backend — either imageio (manual frame capture) or a
    matplotlib writer. Progress is reported through ``log_callback`` as
    "__PCANIM__ <pct> <i>/<total>" lines (one per percent change).

    ``update_func(k)`` must mutate the figure's artists for frame k.
    """
    if save_path.suffix.lower() == ".gif":
        from matplotlib.animation import FuncAnimation, PillowWriter

        def _update(k: int):
            # FuncAnimation expects an artist iterable; side effects only.
            update_func(k)
            return []

        interval_ms = int(1000 / max(1, fps))
        # NOTE(review): blit=True here vs blit=False in the matplotlib
        # fallback below — presumably harmless when only saving; confirm.
        ani = FuncAnimation(fig, _update, frames=n_frames, interval=interval_ms, blit=True)

        # Mutable cell so the nested callback can track the last percent.
        last_pct = {"v": -1}

        def _progress(i: int, total: int) -> None:
            if not total:
                return
            pct = int((i + 1) * 100 / total)
            if pct != last_pct["v"]:
                last_pct["v"] = pct
                log_callback(f"__PCANIM__ {pct} {i + 1}/{total}")

        ani.save(
            str(save_path),
            writer=PillowWriter(fps=fps),
            progress_callback=_progress,
        )
        return

    from canns.analyzer.visualization.core import (
        get_imageio_writer_kwargs,
        get_matplotlib_writer,
        select_animation_backend,
    )

    backend_selection = select_animation_backend(
        save_path=str(save_path),
        requested_backend="auto",
        check_imageio_plugins=True,
    )
    for warning in backend_selection.warnings:
        log_callback(f"⚠️ {warning}")

    if backend_selection.backend == "imageio":
        import imageio

        writer_kwargs, mode = get_imageio_writer_kwargs(str(save_path), fps)
        last_pct = -1
        with imageio.get_writer(str(save_path), mode=mode, **writer_kwargs) as writer:
            for k in range(n_frames):
                # Draw each frame and hand the RGB pixels to imageio
                # (alpha channel dropped).
                update_func(k)
                fig.canvas.draw()
                frame = np.asarray(fig.canvas.buffer_rgba())
                writer.append_data(frame[:, :, :3])
                pct = int((k + 1) * 100 / n_frames)
                if pct != last_pct:
                    last_pct = pct
                    log_callback(f"__PCANIM__ {pct} {k + 1}/{n_frames}")
        return

    # Fallback: let matplotlib drive the chosen writer directly.
    from matplotlib.animation import FuncAnimation

    def _update(k: int):
        update_func(k)
        return []

    interval_ms = int(1000 / max(1, fps))
    ani = FuncAnimation(fig, _update, frames=n_frames, interval=interval_ms, blit=False)
    writer = get_matplotlib_writer(str(save_path), fps=fps)
    ani.save(str(save_path), writer=writer)
|
|
1173
|
+
|
|
1174
|
+
def _run_cohospace(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Render CohoSpace plots: trajectory, per-neuron, and population views.

    Runs (or reuses) the circular-coordinate decode, selects decoded
    dimension(s), optionally scores neurons ("CohoScore", lower is better)
    to auto-pick the best neuron and the top-k population subset, then
    renders up to three figures depending on the "view" parameter. All
    plots are cached per stage hash.

    Returns artifact paths keyed by "trajectory", "neuron", "population".
    Raises ProcessingError on missing TDA results or invalid coords/dims.
    """
    from canns.analyzer.data.asa import (
        plot_cohospace_neuron_1d,
        plot_cohospace_neuron_2d,
        plot_cohospace_population_1d,
        plot_cohospace_population_2d,
        plot_cohospace_trajectory_1d,
        plot_cohospace_trajectory_2d,
    )
    from canns.analyzer.data.asa.cohospace import (
        compute_cohoscore_1d,
        compute_cohoscore_2d,
        plot_cohospace_neuron_skewed,
        plot_cohospace_population_skewed,
    )
    from canns.analyzer.visualization import PlotConfigs

    tda_dir = self._results_dir(state) / "TDA"
    persistence_path = tda_dir / "persistence.npz"
    if not persistence_path.exists():
        raise ProcessingError("TDA results not found. Run TDA analysis first.")

    decode_result = self._load_or_run_decode(asa_data, persistence_path, state, log_callback)

    out_dir = self._results_dir(state) / "CohoSpace"
    out_dir.mkdir(parents=True, exist_ok=True)

    params = state.analysis_params
    artifacts: dict[str, Path] = {}

    # np.asarray(None) yields a 0-d array, so a missing "coordsbox" is
    # handled by the ndim check further down.
    coords = np.asarray(decode_result.get("coords"))
    coordsbox = np.asarray(decode_result.get("coordsbox"))
    if coords.ndim != 2 or coords.shape[1] < 1:
        raise ProcessingError("decode_result['coords'] must be 2D.")

    dim_mode = str(params.get("dim_mode", "2d"))
    dim = int(params.get("dim", 1))  # 1-based decoded dimension (1d mode)
    dim1 = int(params.get("dim1", 1))  # 1-based first dimension (2d mode)
    dim2 = int(params.get("dim2", 2))  # 1-based second dimension (2d mode)
    mode = str(params.get("mode", "fr"))  # activity source: "fr" or "spike"
    top_percent = float(params.get("top_percent", 5.0))
    view = str(params.get("view", "both"))  # which plots: both/single/population
    subsample = int(params.get("subsample", 2))
    unfold = str(params.get("unfold", "square"))  # "skew" enables parallelogram views
    skew_show_grid = bool(params.get("skew_show_grid", True))
    skew_tiles = int(params.get("skew_tiles", 0))
    enable_score = bool(params.get("enable_score", True))
    top_k = int(params.get("top_k", 10))
    use_best = bool(params.get("use_best", True))
    times = decode_result.get("times")

    def pick_coords(arr: np.ndarray) -> np.ndarray:
        # Select the user-requested column(s), converting 1-based dims
        # to 0-based indices and validating range.
        if dim_mode == "1d":
            idx = max(0, dim - 1)
            if idx >= arr.shape[1]:
                raise ProcessingError(f"dim out of range for coords shape {arr.shape}")
            return arr[:, idx]
        idx1 = max(0, dim1 - 1)
        idx2 = max(0, dim2 - 1)
        if idx1 >= arr.shape[1] or idx2 >= arr.shape[1]:
            raise ProcessingError(f"dim1/dim2 out of range for coords shape {arr.shape}")
        return arr[:, [idx1, idx2]]

    coords2 = pick_coords(coords)
    # Fall back to the full-trace coords when no box coords were decoded.
    coordsbox2 = pick_coords(coordsbox) if coordsbox.ndim == 2 else coords2

    # Activity matrix: raw spikes on demand, otherwise the preprocessed
    # embedding when available.
    if mode == "spike":
        activity = self._build_spike_matrix_from_events(asa_data)
    else:
        activity = (
            self._embed_data
            if self._embed_data is not None
            else self._build_spike_matrix_from_events(asa_data)
        )

    scores = None
    top_ids = None
    neuron_id = int(params.get("neuron_id", 0))
    # CohoScore is best-effort: failures are logged and scoring disabled.
    if enable_score:
        try:
            if dim_mode == "1d":
                scores = compute_cohoscore_1d(
                    coords2, activity, top_percent=top_percent, times=times
                )
            else:
                scores = compute_cohoscore_2d(
                    coords2, activity, top_percent=top_percent, times=times
                )
            cohoscore_path = out_dir / "cohoscore.npy"
            np.save(cohoscore_path, scores)
        except Exception as e:
            log_callback(f"⚠️ CohoScore computation failed: {e}")
            scores = None

    # Rank scored neurons ascending and persist the top-k ids.
    if scores is not None:
        valid = np.where(~np.isnan(scores))[0]
        if valid.size > 0:
            sorted_idx = valid[np.argsort(scores[valid])]
            top_ids = sorted_idx[: min(top_k, len(sorted_idx))]
            top_ids_path = out_dir / "cohospace_top_ids.npy"
            np.save(top_ids_path, top_ids)
        else:
            log_callback("⚠️ CohoScore: all values are NaN.")

    # Auto-select the neuron with the minimum (best) score for the
    # single-neuron view when requested.
    if view in {"both", "single"} and enable_score and scores is not None and use_best:
        valid = np.where(~np.isnan(scores))[0]
        if valid.size > 0:
            best_id = int(valid[np.argmin(scores[valid])])
            neuron_id = best_id
            log_callback(f"🎯 CohoSpace neuron auto-selected by best CohoScore: {neuron_id}")

    # decode_hash lives in the CohoMap stage meta (written by the decode step).
    decode_meta = self._load_cache_meta(
        self._stage_cache_path(self._results_dir(state) / "CohoMap")
    )
    stage_hash = self._hash_obj(
        {
            "persistence": self._hash_file(persistence_path),
            "decode_hash": decode_meta.get("decode_hash"),
            "params": params,
        }
    )
    meta_path = self._stage_cache_path(out_dir)
    # Required output files depend on which views are enabled.
    required = [out_dir / "cohospace_trajectory.png"]
    if view in {"both", "population"}:
        required.append(out_dir / "cohospace_population.png")
    if view in {"both", "single"}:
        required.append(out_dir / f"cohospace_neuron_{neuron_id}.png")

    if self._stage_cache_hit(out_dir, stage_hash, required):
        log_callback("♻️ Using cached CohoSpace plots.")
        artifacts = {"trajectory": out_dir / "cohospace_trajectory.png"}
        if view in {"both", "single"}:
            artifacts["neuron"] = out_dir / f"cohospace_neuron_{neuron_id}.png"
        if view in {"both", "population"}:
            artifacts["population"] = out_dir / "cohospace_population.png"
        return artifacts

    log_callback("Plotting cohomology space trajectory...")
    traj_path = out_dir / "cohospace_trajectory.png"
    if dim_mode == "1d":
        traj_cfg = PlotConfigs.cohospace_trajectory_1d(show=False, save_path=str(traj_path))
        plot_cohospace_trajectory_1d(
            coords=coords2,
            times=None,
            subsample=subsample,
            config=traj_cfg,
        )
    else:
        traj_cfg = PlotConfigs.cohospace_trajectory_2d(show=False, save_path=str(traj_path))
        plot_cohospace_trajectory_2d(
            coords=coords2,
            times=None,
            subsample=subsample,
            config=traj_cfg,
        )
    artifacts["trajectory"] = traj_path

    # Single-neuron view; uses box coords (coordsbox2) for plotting.
    if neuron_id is not None and view in {"both", "single"}:
        log_callback(f"Plotting neuron {neuron_id}...")
        neuron_path = out_dir / f"cohospace_neuron_{neuron_id}.png"
        if unfold == "skew" and dim_mode != "1d":
            plot_cohospace_neuron_skewed(
                coords=coordsbox2,
                activity=activity,
                neuron_id=int(neuron_id),
                mode=mode,
                top_percent=top_percent,
                times=times,
                save_path=str(neuron_path),
                show=False,
                show_grid=skew_show_grid,
                n_tiles=skew_tiles,
            )
        else:
            if dim_mode == "1d":
                neuron_cfg = PlotConfigs.cohospace_neuron_1d(
                    show=False, save_path=str(neuron_path)
                )
                plot_cohospace_neuron_1d(
                    coords=coordsbox2,
                    activity=activity,
                    neuron_id=int(neuron_id),
                    mode=mode,
                    top_percent=top_percent,
                    times=times,
                    config=neuron_cfg,
                )
            else:
                neuron_cfg = PlotConfigs.cohospace_neuron_2d(
                    show=False, save_path=str(neuron_path)
                )
                plot_cohospace_neuron_2d(
                    coords=coordsbox2,
                    activity=activity,
                    neuron_id=int(neuron_id),
                    mode=mode,
                    top_percent=top_percent,
                    times=times,
                    config=neuron_cfg,
                )
        artifacts["neuron"] = neuron_path

    # Population view: top-k scored neurons when available, else all.
    if view in {"both", "population"}:
        log_callback("Plotting population activity...")
        pop_path = out_dir / "cohospace_population.png"
        if enable_score and top_ids is not None:
            neuron_ids = [int(i) for i in top_ids.tolist()]
            log_callback(f"CohoSpace: aggregating top-{len(neuron_ids)} neurons by CohoScore.")
        else:
            neuron_ids = list(range(activity.shape[1]))
        if unfold == "skew" and dim_mode != "1d":
            plot_cohospace_population_skewed(
                coords=coords2,
                activity=activity,
                neuron_ids=neuron_ids,
                mode=mode,
                top_percent=top_percent,
                times=times,
                save_path=str(pop_path),
                show=False,
                show_grid=skew_show_grid,
                n_tiles=skew_tiles,
            )
        else:
            if dim_mode == "1d":
                pop_cfg = PlotConfigs.cohospace_population_1d(
                    show=False, save_path=str(pop_path)
                )
                plot_cohospace_population_1d(
                    coords=coords2,
                    activity=activity,
                    neuron_ids=neuron_ids,
                    mode=mode,
                    top_percent=top_percent,
                    times=times,
                    config=pop_cfg,
                )
            else:
                pop_cfg = PlotConfigs.cohospace_population_2d(
                    show=False, save_path=str(pop_path)
                )
                plot_cohospace_population_2d(
                    coords=coords2,
                    activity=activity,
                    neuron_ids=neuron_ids,
                    mode=mode,
                    top_percent=top_percent,
                    times=times,
                    config=pop_cfg,
                )
        artifacts["population"] = pop_path

    self._write_cache_meta(meta_path, {"hash": stage_hash})
    return artifacts
|
|
1430
|
+
|
|
1431
|
+
def _run_fr(
    self, asa_data: dict[str, Any], state: WorkflowState, log_callback
) -> dict[str, Path]:
    """Compute and save the firing-rate heatmap, reusing a cached image
    when the stage hash matches.

    Raises:
        ProcessingError: when no spike/embedding data is available.
    """
    from canns.analyzer.data.asa import compute_fr_heatmap_matrix, save_fr_heatmap_png
    from canns.analyzer.visualization import PlotConfigs

    fr_dir = self._results_dir(state) / "FR"
    fr_dir.mkdir(parents=True, exist_ok=True)

    params = state.analysis_params
    norm_choice = params.get("normalize", "zscore_per_neuron")
    # "none"/empty/missing all mean: apply no normalization.
    if norm_choice in {"none", "", None}:
        norm_choice = None

    # "spike" mode rebuilds a matrix from raw events; otherwise reuse the
    # preprocessed embedding (which may legitimately be absent).
    if params.get("mode", "fr") == "spike":
        matrix_source = self._build_spike_matrix_from_events(asa_data)
    else:
        matrix_source = self._embed_data
    if matrix_source is None:
        raise ProcessingError("No spike data available for FR.")

    heatmap_file = fr_dir / "fr_heatmap.png"
    stage_hash = self._hash_obj(
        {
            "embed_hash": self._embed_hash,
            "params": params,
        }
    )
    if self._stage_cache_hit(fr_dir, stage_hash, [heatmap_file]):
        log_callback("♻️ Using cached FR heatmap.")
        return {"fr_heatmap": heatmap_file}

    log_callback("Computing firing rate heatmap...")
    heatmap = compute_fr_heatmap_matrix(
        matrix_source,
        neuron_range=params.get("neuron_range", None),
        time_range=params.get("time_range", None),
        normalize=norm_choice,
    )
    save_fr_heatmap_png(
        heatmap,
        config=PlotConfigs.fr_heatmap(show=False, save_path=str(heatmap_file)),
        dpi=200,
    )

    self._write_cache_meta(self._stage_cache_path(fr_dir), {"hash": stage_hash})
    return {"fr_heatmap": heatmap_file}
|
|
1480
|
+
|
|
1481
|
+
def _run_frm(
|
|
1482
|
+
self, asa_data: dict[str, Any], state: WorkflowState, log_callback
|
|
1483
|
+
) -> dict[str, Path]:
|
|
1484
|
+
from canns.analyzer.data.asa import compute_frm, plot_frm
|
|
1485
|
+
from canns.analyzer.visualization import PlotConfigs
|
|
1486
|
+
|
|
1487
|
+
out_dir = self._results_dir(state) / "FRM"
|
|
1488
|
+
out_dir.mkdir(parents=True, exist_ok=True)
|
|
1489
|
+
|
|
1490
|
+
params = state.analysis_params
|
|
1491
|
+
neuron_id = int(params.get("neuron_id", 0))
|
|
1492
|
+
bins = int(params.get("bin_size", 50))
|
|
1493
|
+
min_occupancy = int(params.get("min_occupancy", 1))
|
|
1494
|
+
smoothing = bool(params.get("smoothing", False))
|
|
1495
|
+
smooth_sigma = float(params.get("smooth_sigma", 2.0))
|
|
1496
|
+
mode = str(params.get("mode", "fr"))
|
|
1497
|
+
|
|
1498
|
+
if mode == "spike":
|
|
1499
|
+
spike_data = self._build_spike_matrix_from_events(asa_data)
|
|
1500
|
+
else:
|
|
1501
|
+
spike_data = self._embed_data
|
|
1502
|
+
if spike_data is None:
|
|
1503
|
+
raise ProcessingError("No spike data available for FRM.")
|
|
1504
|
+
|
|
1505
|
+
pos = self._aligned_pos if self._aligned_pos is not None else asa_data
|
|
1506
|
+
x = np.asarray(pos.get("x"))
|
|
1507
|
+
y = np.asarray(pos.get("y"))
|
|
1508
|
+
|
|
1509
|
+
if x is None or y is None:
|
|
1510
|
+
raise ProcessingError("Position data (x,y) is required for FRM.")
|
|
1511
|
+
|
|
1512
|
+
out_path = out_dir / f"frm_neuron_{neuron_id}.png"
|
|
1513
|
+
stage_hash = self._hash_obj(
|
|
1514
|
+
{
|
|
1515
|
+
"embed_hash": self._embed_hash,
|
|
1516
|
+
"params": params,
|
|
1517
|
+
}
|
|
1518
|
+
)
|
|
1519
|
+
if self._stage_cache_hit(out_dir, stage_hash, [out_path]):
|
|
1520
|
+
log_callback("♻️ Using cached FRM.")
|
|
1521
|
+
return {"frm": out_path}
|
|
1522
|
+
|
|
1523
|
+
log_callback(f"Computing firing rate map for neuron {neuron_id}...")
|
|
1524
|
+
frm_result = compute_frm(
|
|
1525
|
+
spike_data,
|
|
1526
|
+
x,
|
|
1527
|
+
y,
|
|
1528
|
+
neuron_id=neuron_id,
|
|
1529
|
+
bins=max(1, bins),
|
|
1530
|
+
min_occupancy=min_occupancy,
|
|
1531
|
+
smoothing=smoothing,
|
|
1532
|
+
sigma=smooth_sigma,
|
|
1533
|
+
nan_for_empty=True,
|
|
1534
|
+
)
|
|
1535
|
+
|
|
1536
|
+
config = PlotConfigs.frm(show=False, save_path=str(out_path))
|
|
1537
|
+
plot_frm(frm_result.frm, config=config, dpi=200)
|
|
1538
|
+
|
|
1539
|
+
self._write_cache_meta(self._stage_cache_path(out_dir), {"hash": stage_hash})
|
|
1540
|
+
return {"frm": out_path}
|
|
1541
|
+
|
|
1542
|
+
    def _run_gridscore(
        self,
        asa_data: dict[str, Any],
        state: WorkflowState,
        log_callback,
        progress_callback: Callable[[int], None],
    ) -> dict[str, Path]:
        """Run batch gridness score computation.

        For every neuron in ``[neuron_start, neuron_end)`` this bins spikes
        against the aligned (x, y) trajectory into a rate map, computes its 2D
        autocorrelation and gridness score, and saves the aggregate results as
        ``gridscore.npz``, ``gridscore.csv``, and a score-histogram PNG in the
        ``GRIDScore`` results directory.

        Raises:
            ProcessingError: if position data is missing, the spike matrix is
                not 2D, or scipy is unavailable when smoothing is requested.
        """
        import csv

        from canns.analyzer.data.cell_classification import (
            GridnessAnalyzer,
            compute_2d_autocorrelation,
        )

        params = state.analysis_params or {}
        # Parameters may live flat in analysis_params or nested under "gridscore".
        gs_params = params.get("gridscore") if isinstance(params.get("gridscore"), dict) else {}

        def _param(key: str, default: Any) -> Any:
            # Flat params take precedence over the nested "gridscore" dict.
            if key in params:
                return params.get(key, default)
            if gs_params and key in gs_params:
                return gs_params.get(key, default)
            return default

        n_start = int(_param("neuron_start", 0))
        n_end = int(_param("neuron_end", 0))
        bins = int(_param("bins", 50))
        min_occ = int(_param("min_occupancy", 1))
        smoothing = bool(_param("smoothing", False))
        sigma = float(_param("sigma", 1.0))
        overlap = float(_param("overlap", 0.8))
        mode = str(_param("mode", "fr")).strip().lower()
        score_thr = float(_param("score_thr", 0.3))  # only used for the histogram marker

        # Sanitize user-provided settings to safe ranges.
        if mode not in {"fr", "spike"}:
            mode = "fr"
        bins = max(5, bins)
        overlap = max(0.1, min(1.0, overlap))

        pos = self._aligned_pos if self._aligned_pos is not None else asa_data
        if "x" not in pos or "y" not in pos or pos["x"] is None or pos["y"] is None:
            raise ProcessingError("GridScore requires position data (x, y).")

        # Choose the activity source: preprocessed matrix for "fr" mode when
        # available, otherwise rebuild a spike matrix from raw events.
        if mode == "fr":
            if isinstance(self._embed_data, np.ndarray) and self._embed_data.ndim == 2:
                activity_full = self._embed_data
                log_callback("GridScore[FR]: using preprocessed spike matrix.")
            else:
                log_callback("GridScore[FR]: no preprocessed matrix, falling back to spike mode.")
                activity_full = self._build_spike_matrix_from_events(asa_data)
        else:
            activity_full = self._build_spike_matrix_from_events(asa_data)
            log_callback("GridScore[spike]: using event-based spike matrix.")

        sp = np.asarray(activity_full)
        if sp.ndim != 2:
            raise ProcessingError(f"GridScore expects 2D spike matrix, got ndim={sp.ndim}.")

        # Align trajectory and activity to a common length (shortest wins).
        x = np.asarray(pos["x"]).ravel()
        y = np.asarray(pos["y"]).ravel()
        m = min(len(x), len(y), sp.shape[0])
        x = x[:m]
        y = y[:m]
        sp = sp[:m, :]

        # Drop samples with NaN/inf positions (and their activity rows).
        finite = np.isfinite(x) & np.isfinite(y)
        if not np.all(finite):
            x = x[finite]
            y = y[finite]
            sp = sp[finite, :]

        # Clamp the neuron range to what the matrix actually holds; guarantee
        # a non-empty selection of at least one neuron.
        total_neurons = sp.shape[1]
        if n_end <= 0 or n_end > total_neurons:
            n_end = total_neurons
        n_start = max(0, min(n_start, total_neurons - 1))
        n_end = max(n_start + 1, n_end)

        # Degenerate (constant) trajectories get a unit extent so binning
        # below never divides by ~zero.
        xmin, xmax = float(np.min(x)), float(np.max(x))
        ymin, ymax = float(np.min(y)), float(np.max(y))
        eps = 1e-12
        if xmax - xmin < eps:
            xmax = xmin + 1.0
        if ymax - ymin < eps:
            ymax = ymin + 1.0

        # Map each (x, y) sample to a flat spatial-bin index in [0, bins*bins).
        ix = np.floor((x - xmin) / (xmax - xmin + eps) * bins).astype(int)
        iy = np.floor((y - ymin) / (ymax - ymin + eps) * bins).astype(int)
        ix = np.clip(ix, 0, bins - 1)
        iy = np.clip(iy, 0, bins - 1)
        flat = (iy * bins + ix).astype(int)

        # Occupancy per spatial bin; bins visited fewer than min_occ times are
        # excluded from the rate maps.
        occ = np.bincount(flat, minlength=bins * bins).astype(float).reshape(bins, bins)
        occ_mask = occ >= float(min_occ)

        # scipy is an optional dependency; only required when smoothing is on.
        gaussian_filter = None
        if smoothing and sigma > 0:
            try:
                from scipy.ndimage import gaussian_filter as _gaussian_filter
            except Exception as e:  # pragma: no cover - optional dependency
                raise ProcessingError(f"GridScore requires scipy for smoothing: {e}") from e
            gaussian_filter = _gaussian_filter

        def _rate_map_for_neuron(col: int) -> np.ndarray:
            # Occupancy-normalized (optionally smoothed) rate map for one neuron,
            # reusing the precomputed bin indices and occupancy mask above.
            weights = sp[:, col].astype(float, copy=False)
            spike_map = (
                np.bincount(flat, weights=weights, minlength=bins * bins)
                .astype(float)
                .reshape(bins, bins)
            )
            rate_map = np.zeros_like(spike_map)
            rate_map[occ_mask] = spike_map[occ_mask] / occ[occ_mask]
            if gaussian_filter is not None:
                rate_map = gaussian_filter(rate_map, sigma=float(sigma), mode="nearest")
            return rate_map

        analyzer = GridnessAnalyzer()
        n_sel = n_end - n_start
        # Per-neuron result arrays, NaN-filled so failed/missing fields stay NaN.
        scores = np.full((n_sel,), np.nan, dtype=float)
        spacing = np.full((n_sel, 3), np.nan, dtype=float)
        orientation = np.full((n_sel, 3), np.nan, dtype=float)
        ellipse = np.full((n_sel, 5), np.nan, dtype=float)
        ellipse_theta_deg = np.full((n_sel,), np.nan, dtype=float)
        center_radius = np.full((n_sel,), np.nan, dtype=float)
        optimal_radius = np.full((n_sel,), np.nan, dtype=float)

        log_callback(
            f"GridScore: computing neurons [{n_start}, {n_end}) with bins={bins}, overlap={overlap:.2f}."
        )

        for j, nid in enumerate(range(n_start, n_end)):
            rate_map = _rate_map_for_neuron(nid)
            autocorr = compute_2d_autocorrelation(rate_map, overlap=overlap)
            result = analyzer.compute_gridness_score(autocorr)

            scores[j] = float(result.score)
            # Only copy optional fields when present and large enough to slice.
            if result.spacing is not None and np.size(result.spacing) >= 3:
                spacing[j, :] = np.asarray(result.spacing).ravel()[:3]
            if result.orientation is not None and np.size(result.orientation) >= 3:
                orientation[j, :] = np.asarray(result.orientation).ravel()[:3]
            if result.ellipse is not None and np.size(result.ellipse) >= 5:
                ellipse[j, :] = np.asarray(result.ellipse).ravel()[:5]
            ellipse_theta_deg[j] = float(result.ellipse_theta_deg)
            center_radius[j] = float(result.center_radius)
            optimal_radius[j] = float(result.optimal_radius)

            # Report progress roughly every 10%, mapped into the 60-98 band of
            # the overall workflow progress bar.
            if (j + 1) % max(1, n_sel // 10) == 0:
                progress = 60 + int(35 * (j + 1) / max(1, n_sel))
                progress_callback(min(98, progress))

        out_dir = self._results_dir(state) / "GRIDScore"
        out_dir.mkdir(parents=True, exist_ok=True)

        gridscore_npz = out_dir / "gridscore.npz"
        np.savez_compressed(
            str(gridscore_npz),
            neuron_start=n_start,
            neuron_end=n_end,
            neuron_ids=np.arange(n_start, n_end, dtype=int),
            bins=bins,
            min_occupancy=min_occ,
            smoothing=smoothing,
            sigma=sigma,
            overlap=overlap,
            mode=mode,
            score=scores,
            # "grid_score" duplicates "score" — presumably kept for backward
            # compatibility of the npz schema; TODO confirm which consumers
            # read which key.
            grid_score=scores,
            spacing=spacing,
            orientation=orientation,
            ellipse=ellipse,
            ellipse_theta_deg=ellipse_theta_deg,
            center_radius=center_radius,
            optimal_radius=optimal_radius,
        )

        # Flat CSV export mirroring the npz contents, one row per neuron.
        gridscore_csv = out_dir / "gridscore.csv"
        with gridscore_csv.open("w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(
                [
                    "neuron_id",
                    "grid_score",
                    "spacing1",
                    "spacing2",
                    "spacing3",
                    "orient1_deg",
                    "orient2_deg",
                    "orient3_deg",
                    "ellipse_cx",
                    "ellipse_cy",
                    "ellipse_rx",
                    "ellipse_ry",
                    "ellipse_theta_deg",
                    "center_radius",
                    "optimal_radius",
                ]
            )
            for j, nid in enumerate(range(n_start, n_end)):
                writer.writerow(
                    [
                        nid,
                        scores[j],
                        spacing[j, 0],
                        spacing[j, 1],
                        spacing[j, 2],
                        orientation[j, 0],
                        orientation[j, 1],
                        orientation[j, 2],
                        ellipse[j, 0],
                        ellipse[j, 1],
                        ellipse[j, 2],
                        ellipse[j, 3],
                        ellipse_theta_deg[j],
                        center_radius[j],
                        optimal_radius[j],
                    ]
                )

        # Best-effort summary histogram; failures only log a warning (the npz
        # and csv above are the primary artifacts). Note the returned png path
        # may not exist if this block fails.
        gridscore_png = out_dir / "gridscore_summary.png"
        try:
            import matplotlib.pyplot as plt

            fig, ax = plt.subplots(1, 1, figsize=(10, 3.5))
            valid = scores[np.isfinite(scores)]
            ax.hist(valid, bins=30)
            ax.axvline(score_thr, linestyle="--")
            ax.set_title("Grid score distribution")
            ax.set_xlabel("grid score")
            ax.set_ylabel("count")
            fig.tight_layout()
            fig.savefig(gridscore_png, dpi=200)
            plt.close(fig)
        except Exception as e:
            log_callback(f"Warning: failed to save GridScore summary png: {e}")

        log_callback(f"GridScore done. Saved: {gridscore_npz} , {gridscore_csv}")
        return {
            "gridscore_npz": gridscore_npz,
            "gridscore_csv": gridscore_csv,
            "gridscore_png": gridscore_png,
        }
|
|
1783
|
+
|
|
1784
|
+
    def _run_gridscore_inspect(
        self,
        asa_data: dict[str, Any],
        state: WorkflowState,
        log_callback,
        progress_callback: Callable[[int], None],
    ) -> dict[str, Path]:
        """Run single-neuron GridScore inspection.

        Builds a rate map for one neuron, computes its 2D autocorrelation and
        gridness score, and saves a per-neuron npz plus a diagnostic figure in
        the ``GRIDScore`` results directory. Returns the paths of both files.

        Raises:
            ProcessingError: if position data is missing, the spike matrix is
                not 2D, or scipy is unavailable when smoothing is requested.
        """
        from canns.analyzer.data.cell_classification import (
            GridnessAnalyzer,
            compute_2d_autocorrelation,
            plot_gridness_analysis,
        )

        params = state.analysis_params or {}
        # Parameters may live flat in analysis_params or nested under "gridscore".
        gs_params = params.get("gridscore") if isinstance(params.get("gridscore"), dict) else {}

        def _param(key: str, default: Any) -> Any:
            # Flat params take precedence over the nested "gridscore" dict.
            if key in params:
                return params.get(key, default)
            if gs_params and key in gs_params:
                return gs_params.get(key, default)
            return default

        # Accept either "neuron_id" or the shorter "neuron" key.
        neuron_id = int(_param("neuron_id", _param("neuron", 0)))
        bins = int(_param("bins", 50))
        min_occ = int(_param("min_occupancy", 1))
        smoothing = bool(_param("smoothing", False))
        sigma = float(_param("sigma", 1.0))
        overlap = float(_param("overlap", 0.8))
        mode = str(_param("mode", "fr")).strip().lower()

        # Sanitize user-provided settings to safe ranges.
        if mode not in {"fr", "spike"}:
            mode = "fr"
        bins = max(5, bins)
        overlap = max(0.1, min(1.0, overlap))

        pos = self._aligned_pos if self._aligned_pos is not None else asa_data
        if "x" not in pos or "y" not in pos or pos["x"] is None or pos["y"] is None:
            raise ProcessingError("GridScore inspect requires position data (x, y).")

        # Choose the activity source: preprocessed matrix for "fr" mode when
        # available, otherwise rebuild a spike matrix from raw events.
        if mode == "fr":
            if isinstance(self._embed_data, np.ndarray) and self._embed_data.ndim == 2:
                activity_full = self._embed_data
                log_callback("GridScoreInspect[FR]: using preprocessed spike matrix.")
            else:
                log_callback(
                    "GridScoreInspect[FR]: no preprocessed matrix, falling back to spike mode."
                )
                activity_full = self._build_spike_matrix_from_events(asa_data)
        else:
            activity_full = self._build_spike_matrix_from_events(asa_data)
            log_callback("GridScoreInspect[spike]: using event-based spike matrix.")

        sp = np.asarray(activity_full)
        if sp.ndim != 2:
            raise ProcessingError(f"GridScore inspect expects 2D spike matrix, got ndim={sp.ndim}.")

        # Align trajectory and activity to a common length (shortest wins).
        x = np.asarray(pos["x"]).ravel()
        y = np.asarray(pos["y"]).ravel()
        m = min(len(x), len(y), sp.shape[0])
        x = x[:m]
        y = y[:m]
        sp = sp[:m, :]

        # Drop samples with NaN/inf positions (and their activity rows).
        finite = np.isfinite(x) & np.isfinite(y)
        if not np.all(finite):
            x = x[finite]
            y = y[finite]
            sp = sp[finite, :]

        # Clamp the requested neuron index into the valid column range.
        total_neurons = sp.shape[1]
        neuron_id = max(0, min(int(neuron_id), total_neurons - 1))

        # Degenerate (constant) trajectories get a unit extent so binning
        # below never divides by ~zero.
        xmin, xmax = float(np.min(x)), float(np.max(x))
        ymin, ymax = float(np.min(y)), float(np.max(y))
        eps = 1e-12
        if xmax - xmin < eps:
            xmax = xmin + 1.0
        if ymax - ymin < eps:
            ymax = ymin + 1.0

        # Map each (x, y) sample to a flat spatial-bin index in [0, bins*bins).
        ix = np.floor((x - xmin) / (xmax - xmin + eps) * bins).astype(int)
        iy = np.floor((y - ymin) / (ymax - ymin + eps) * bins).astype(int)
        ix = np.clip(ix, 0, bins - 1)
        iy = np.clip(iy, 0, bins - 1)
        flat = (iy * bins + ix).astype(int)

        # Occupancy per spatial bin; bins visited fewer than min_occ times are
        # excluded from the rate map.
        occ = np.bincount(flat, minlength=bins * bins).astype(float).reshape(bins, bins)
        occ_mask = occ >= float(min_occ)

        # Occupancy-normalized rate map for the selected neuron.
        weights = sp[:, neuron_id].astype(float, copy=False)
        spike_map = (
            np.bincount(flat, weights=weights, minlength=bins * bins)
            .astype(float)
            .reshape(bins, bins)
        )
        rate_map = np.zeros_like(spike_map)
        rate_map[occ_mask] = spike_map[occ_mask] / occ[occ_mask]
        if smoothing and sigma > 0:
            # scipy is an optional dependency; only required when smoothing is on.
            try:
                from scipy.ndimage import gaussian_filter as _gaussian_filter
            except Exception as e:  # pragma: no cover - optional dependency
                raise ProcessingError(f"GridScore requires scipy for smoothing: {e}") from e
            rate_map = _gaussian_filter(rate_map, sigma=float(sigma), mode="nearest")

        autocorr = compute_2d_autocorrelation(rate_map, overlap=overlap)
        analyzer = GridnessAnalyzer()
        result = analyzer.compute_gridness_score(autocorr)

        out_dir = self._results_dir(state) / "GRIDScore"
        out_dir.mkdir(parents=True, exist_ok=True)

        # Persist all parameters and result fields; missing optional fields are
        # written as NaN placeholders so the npz schema stays fixed.
        gridscore_neuron_npz = out_dir / f"gridscore_neuron_{neuron_id}.npz"
        np.savez_compressed(
            str(gridscore_neuron_npz),
            neuron_id=neuron_id,
            bins=bins,
            min_occupancy=min_occ,
            smoothing=smoothing,
            sigma=sigma,
            overlap=overlap,
            mode=mode,
            grid_score=float(result.score),
            spacing=np.asarray(result.spacing)
            if result.spacing is not None
            else np.full((3,), np.nan),
            orientation=np.asarray(result.orientation)
            if result.orientation is not None
            else np.full((3,), np.nan),
            ellipse=np.asarray(result.ellipse)
            if result.ellipse is not None
            else np.full((5,), np.nan),
            ellipse_theta_deg=float(getattr(result, "ellipse_theta_deg", np.nan)),
            center_radius=float(getattr(result, "center_radius", np.nan)),
            optimal_radius=float(getattr(result, "optimal_radius", np.nan)),
        )

        # Diagnostic figure combining rate map, autocorrelation, and score.
        gridscore_neuron_png = out_dir / f"gridscore_neuron_{neuron_id}.png"
        plot_gridness_analysis(
            rate_map=rate_map,
            autocorr=autocorr,
            result=result,
            save_path=str(gridscore_neuron_png),
            show=False,
            title=f"GridScore neuron {neuron_id}",
        )
        progress_callback(100)

        return {
            "gridscore_neuron_npz": gridscore_neuron_npz,
            "gridscore_neuron_png": gridscore_neuron_png,
        }
|