canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. canns/analyzer/data/__init__.py +5 -1
  2. canns/analyzer/data/asa/__init__.py +27 -12
  3. canns/analyzer/data/asa/cohospace.py +336 -10
  4. canns/analyzer/data/asa/config.py +3 -0
  5. canns/analyzer/data/asa/embedding.py +48 -45
  6. canns/analyzer/data/asa/path.py +104 -2
  7. canns/analyzer/data/asa/plotting.py +88 -19
  8. canns/analyzer/data/asa/tda.py +11 -4
  9. canns/analyzer/data/cell_classification/__init__.py +97 -0
  10. canns/analyzer/data/cell_classification/core/__init__.py +26 -0
  11. canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
  12. canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
  13. canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
  14. canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
  15. canns/analyzer/data/cell_classification/io/__init__.py +5 -0
  16. canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
  17. canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
  18. canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
  19. canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
  20. canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
  21. canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
  22. canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
  23. canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
  24. canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
  25. canns/analyzer/metrics/__init__.py +2 -1
  26. canns/analyzer/visualization/core/config.py +46 -4
  27. canns/data/__init__.py +6 -1
  28. canns/data/datasets.py +154 -1
  29. canns/data/loaders.py +37 -0
  30. canns/pipeline/__init__.py +13 -9
  31. canns/pipeline/__main__.py +6 -0
  32. canns/pipeline/asa/runner.py +105 -41
  33. canns/pipeline/asa_gui/__init__.py +68 -0
  34. canns/pipeline/asa_gui/__main__.py +6 -0
  35. canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
  36. canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
  37. canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
  38. canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
  39. canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
  40. canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
  41. canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
  42. canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
  43. canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
  44. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
  45. canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
  46. canns/pipeline/asa_gui/app.py +29 -0
  47. canns/pipeline/asa_gui/controllers/__init__.py +6 -0
  48. canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
  49. canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
  50. canns/pipeline/asa_gui/core/__init__.py +15 -0
  51. canns/pipeline/asa_gui/core/cache.py +14 -0
  52. canns/pipeline/asa_gui/core/runner.py +1936 -0
  53. canns/pipeline/asa_gui/core/state.py +324 -0
  54. canns/pipeline/asa_gui/core/worker.py +260 -0
  55. canns/pipeline/asa_gui/main_window.py +184 -0
  56. canns/pipeline/asa_gui/models/__init__.py +7 -0
  57. canns/pipeline/asa_gui/models/config.py +14 -0
  58. canns/pipeline/asa_gui/models/job.py +31 -0
  59. canns/pipeline/asa_gui/models/presets.py +21 -0
  60. canns/pipeline/asa_gui/resources/__init__.py +16 -0
  61. canns/pipeline/asa_gui/resources/dark.qss +167 -0
  62. canns/pipeline/asa_gui/resources/light.qss +163 -0
  63. canns/pipeline/asa_gui/resources/styles.qss +130 -0
  64. canns/pipeline/asa_gui/utils/__init__.py +1 -0
  65. canns/pipeline/asa_gui/utils/formatters.py +15 -0
  66. canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
  67. canns/pipeline/asa_gui/utils/validators.py +41 -0
  68. canns/pipeline/asa_gui/views/__init__.py +1 -0
  69. canns/pipeline/asa_gui/views/help_content.py +171 -0
  70. canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
  71. canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
  72. canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
  73. canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
  74. canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
  75. canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
  76. canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
  77. canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
  78. canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
  79. canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
  80. canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
  81. canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
  82. canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
  83. canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
  84. canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
  85. canns/pipeline/gallery/__init__.py +15 -5
  86. canns/pipeline/gallery/__main__.py +11 -0
  87. canns/pipeline/gallery/app.py +705 -0
  88. canns/pipeline/gallery/runner.py +790 -0
  89. canns/pipeline/gallery/state.py +51 -0
  90. canns/pipeline/gallery/styles.tcss +123 -0
  91. canns/pipeline/launcher.py +81 -0
  92. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
  93. canns-0.14.0.dist-info/RECORD +163 -0
  94. canns-0.14.0.dist-info/entry_points.txt +5 -0
  95. canns/pipeline/_base.py +0 -50
  96. canns-0.13.1.dist-info/RECORD +0 -89
  97. canns-0.13.1.dist-info/entry_points.txt +0 -3
  98. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
  99. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,790 @@
1
+ """Execution helpers for the model gallery TUI."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import multiprocessing as mp
6
+ import sys
7
+ import time
8
+ import traceback
9
+ from collections.abc import Callable
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ import brainpy.math as bm
15
+ import matplotlib
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+
19
+ from canns.analyzer.metrics.spatial_metrics import compute_firing_field, gaussian_smooth_heatmaps
20
+ from canns.analyzer.visualization import (
21
+ PlotConfigs,
22
+ energy_landscape_1d_static,
23
+ energy_landscape_2d_static,
24
+ plot_firing_field_heatmap,
25
+ tuning_curve,
26
+ )
27
+ from canns.models.basic import CANN1D, CANN2D, GridCell2DVelocity
28
+ from canns.task.open_loop_navigation import OpenLoopNavigationTask
29
+ from canns.task.tracking import (
30
+ SmoothTracking1D,
31
+ SmoothTracking2D,
32
+ TemplateMatching1D,
33
+ TemplateMatching2D,
34
+ )
35
+
36
+
37
@dataclass
class GalleryResult:
    """Outcome of one gallery analysis run.

    Fields are listed in constructor order; ``error`` and ``elapsed_time``
    are optional and default to "no error" and zero seconds.
    """

    # Whether the analysis finished without raising.
    success: bool
    # Named output files produced by the run, keyed by a short label.
    artifacts: dict[str, Path]
    # Short human-readable status line for display.
    summary: str
    # Text of the exception that aborted the run, if any.
    error: str | None = None
    # Duration of the run in seconds.
    elapsed_time: float = 0.0
46
+
47
+
48
class GalleryRunner:
    """Run gallery model analyses and render their results as PNG figures."""

    def __init__(self) -> None:
        # Flipped to True by ``_ensure_matplotlib_backend`` after the first
        # backend switch, so later runs skip it.
        self._mpl_ready = False
53
+
54
+ def _ensure_matplotlib_backend(self) -> None:
55
+ if self._mpl_ready:
56
+ return
57
+ matplotlib.use("Agg", force=True)
58
+ self._mpl_ready = True
59
+
60
+ def _ensure_multiprocessing(self) -> None:
61
+ """Stabilize multiprocessing behavior on macOS within threads."""
62
+ if sys.platform == "darwin":
63
+ try:
64
+ mp.set_start_method("fork", force=True)
65
+ except RuntimeError:
66
+ pass
67
+
68
    def _ensure_brainpy_environment(self) -> None:
        """Initialize BrainPy environment for worker threads.

        Best-effort setup, called at the start of every run:

        1. Calls ``brainpy.math.defaults.setting()``.
        2. Configures ``brainpy.math`` with non-batching mode,
           ``bp_object_as_pytree=False`` and ``numpy_func_return="bp_array"``.
        3. Mirrors the same options, plus the current ``bm.get_dt()``, into
           ``brainstate.environ``.

        Any exception, including an ImportError for ``brainstate``, is
        swallowed, leaving the environment as it was.
        """
        try:
            import brainstate.environ as bs_env
            from brainpy.math import defaults as bm_defaults

            # NOTE(review): presumably (re)applies BrainPy's default settings
            # before the overrides below — confirm against brainpy docs.
            bm_defaults.setting()
            bm.set_environment(
                mode=bm.nonbatching_mode,
                bp_object_as_pytree=False,
                numpy_func_return="bp_array",
            )
            # Keep brainstate in sync with brainpy; dt is copied from whatever
            # BrainPy reports at this moment (the per-model dt is set later).
            bs_env.set(
                mode=bm.nonbatching_mode,
                dt=bm.get_dt(),
                bp_object_as_pytree=False,
                numpy_func_return="bp_array",
            )
        except Exception:
            # NOTE(review): deliberately silent; a failure here may only show
            # up later as a simulation error.
            pass
88
+
89
+ async def run(
90
+ self,
91
+ model: str,
92
+ analysis: str,
93
+ model_params: dict[str, Any],
94
+ analysis_params: dict[str, Any],
95
+ output_dir: Path,
96
+ log_callback: Callable[[str], None],
97
+ progress_callback: Callable[[int], None],
98
+ ) -> GalleryResult:
99
+ start_time = time.time()
100
+ artifacts: dict[str, Path] = {}
101
+
102
+ try:
103
+ self._ensure_matplotlib_backend()
104
+ self._ensure_multiprocessing()
105
+ self._ensure_brainpy_environment()
106
+ output_dir.mkdir(parents=True, exist_ok=True)
107
+
108
+ log_callback(f"Running {model} / {analysis}...")
109
+ progress_callback(5)
110
+
111
+ if model == "cann1d":
112
+ output_path = self._run_cann1d(
113
+ analysis,
114
+ model_params,
115
+ analysis_params,
116
+ output_dir,
117
+ log_callback,
118
+ progress_callback,
119
+ )
120
+ elif model == "cann2d":
121
+ output_path = self._run_cann2d(
122
+ analysis,
123
+ model_params,
124
+ analysis_params,
125
+ output_dir,
126
+ log_callback,
127
+ progress_callback,
128
+ )
129
+ elif model == "gridcell":
130
+ output_path = self._run_gridcell(
131
+ analysis,
132
+ model_params,
133
+ analysis_params,
134
+ output_dir,
135
+ log_callback,
136
+ progress_callback,
137
+ )
138
+ else:
139
+ raise ValueError(f"Unknown model: {model}")
140
+
141
+ artifacts["output"] = output_path
142
+ elapsed = time.time() - start_time
143
+ progress_callback(100)
144
+ return GalleryResult(
145
+ success=True,
146
+ artifacts=artifacts,
147
+ summary=f"Completed in {elapsed:.1f}s",
148
+ elapsed_time=elapsed,
149
+ )
150
+ except Exception as exc:
151
+ elapsed = time.time() - start_time
152
+ log_callback(f"Error: {exc}")
153
+ log_callback(traceback.format_exc())
154
+ return GalleryResult(
155
+ success=False,
156
+ artifacts=artifacts,
157
+ summary=f"Failed after {elapsed:.1f}s",
158
+ error=str(exc),
159
+ elapsed_time=elapsed,
160
+ )
161
+
162
+ def _run_cann1d(
163
+ self,
164
+ analysis: str,
165
+ model_params: dict[str, Any],
166
+ analysis_params: dict[str, Any],
167
+ output_dir: Path,
168
+ log_callback: Callable[[str], None],
169
+ progress_callback: Callable[[int], None],
170
+ ) -> Path:
171
+ seed = model_params["seed"]
172
+ np.random.seed(seed)
173
+ bm.random.seed(seed)
174
+ bm.set_dt(model_params["dt"])
175
+
176
+ model = CANN1D(
177
+ num=model_params["num"],
178
+ tau=model_params["tau"],
179
+ k=model_params["k"],
180
+ a=model_params["a"],
181
+ A=model_params["A"],
182
+ J0=model_params["J0"],
183
+ )
184
+
185
+ output_path = output_dir / f"cann1d_{analysis}_seed{seed}.png"
186
+
187
+ if analysis == "connectivity":
188
+ log_callback("Rendering connectivity matrix...")
189
+ progress_callback(30)
190
+ self._plot_connectivity(model.conn_mat, output_path, title="CANN1D Connectivity")
191
+ return output_path
192
+
193
+ if analysis == "energy":
194
+ log_callback("Simulating energy landscape...")
195
+ task = TemplateMatching1D(
196
+ model,
197
+ Iext=analysis_params["energy_pos"],
198
+ duration=analysis_params["energy_duration"],
199
+ time_step=model_params["dt"],
200
+ )
201
+ task.get_data(progress_bar=False)
202
+
203
+ def run_step(inputs):
204
+ model(inputs)
205
+ return model.u.value
206
+
207
+ us = bm.for_loop(run_step, operands=(task.data,), progress_bar=False)
208
+ select_index = len(task.data) // 2
209
+ config = PlotConfigs.energy_landscape_1d_static(
210
+ title="Energy Landscape 1D",
211
+ xlabel="State",
212
+ ylabel="Activity",
213
+ show=False,
214
+ save_path=str(output_path),
215
+ save_format="png",
216
+ )
217
+ energy_landscape_1d_static(
218
+ data_sets={"u": (np.asarray(model.x), np.asarray(us)[select_index])},
219
+ config=config,
220
+ )
221
+ return output_path
222
+
223
+ if analysis == "tuning":
224
+ log_callback("Simulating tuning curves...")
225
+ task = SmoothTracking1D(
226
+ model,
227
+ Iext=(
228
+ analysis_params["tuning_start"],
229
+ analysis_params["tuning_mid"],
230
+ analysis_params["tuning_end"],
231
+ ),
232
+ duration=(analysis_params["tuning_duration"],) * 2,
233
+ time_step=model_params["dt"],
234
+ )
235
+ task.get_data(progress_bar=False)
236
+
237
+ def run_step(inputs):
238
+ model(inputs)
239
+ return model.r.value
240
+
241
+ rs = bm.for_loop(run_step, operands=(task.data,), progress_bar=False)
242
+ neuron_indices = self._parse_indices(analysis_params["tuning_neurons"], len(model.x))
243
+ config = PlotConfigs.tuning_curve(
244
+ num_bins=analysis_params["tuning_bins"],
245
+ pref_stim=np.asarray(model.x),
246
+ title="Tuning Curves",
247
+ xlabel="Stimulus",
248
+ ylabel="Average Rate",
249
+ show=False,
250
+ save_path=str(output_path),
251
+ save_format="png",
252
+ )
253
+ tuning_curve(
254
+ stimulus=task.Iext_sequence.squeeze(),
255
+ firing_rates=np.asarray(rs),
256
+ neuron_indices=neuron_indices,
257
+ config=config,
258
+ )
259
+ return output_path
260
+
261
+ if analysis == "template":
262
+ log_callback("Simulating template matching...")
263
+ task = TemplateMatching1D(
264
+ model,
265
+ Iext=analysis_params["template_pos"],
266
+ duration=analysis_params["template_duration"],
267
+ time_step=model_params["dt"],
268
+ )
269
+ task.get_data(progress_bar=False)
270
+
271
+ def run_step(inputs):
272
+ model(inputs)
273
+ return model.u.value, model.inp.value
274
+
275
+ us, inps = bm.for_loop(run_step, operands=(task.data,), progress_bar=False)
276
+ select_index = len(task.data) // 2
277
+ fig, ax = plt.subplots(figsize=(6, 4))
278
+ ax.plot(
279
+ np.asarray(model.x), np.asarray(inps)[select_index], "r--", linewidth=2.0, alpha=0.6
280
+ )
281
+ ax.plot(np.asarray(model.x), np.asarray(us)[select_index], "b-", linewidth=2.5)
282
+ ax.grid(True, alpha=0.3)
283
+ ax.set_title("Template Matching", fontsize=12, fontweight="bold")
284
+ fig.tight_layout()
285
+ fig.savefig(output_path, dpi=150, bbox_inches="tight")
286
+ plt.close(fig)
287
+ return output_path
288
+
289
+ if analysis == "manifold":
290
+ log_callback("Computing neural manifold...")
291
+ segment = analysis_params["manifold_segment"]
292
+ warmup = analysis_params["manifold_warmup"]
293
+ iext = (0.0, 0.0, np.pi, 2 * np.pi, -2 * np.pi, 0.0)
294
+ durations = (warmup, segment, segment, segment, segment)
295
+ task = SmoothTracking1D(
296
+ model, Iext=iext, duration=durations, time_step=model_params["dt"]
297
+ )
298
+ task.get_data(progress_bar=False)
299
+
300
+ def run_step(t, inputs):
301
+ model(inputs)
302
+ return model.r.value
303
+
304
+ rs = bm.for_loop(run_step, (task.run_steps, task.data), progress_bar=False)
305
+ n_warmup = int(warmup / model_params["dt"])
306
+ firing_rates = np.asarray(rs[n_warmup:])
307
+ stimulus_pos = np.asarray(task.Iext_sequence).squeeze()[n_warmup:]
308
+ projected = self._pca_projection(firing_rates, n_components=2)
309
+
310
+ fig, ax = plt.subplots(figsize=(6, 4))
311
+ ax.scatter(
312
+ projected[:, 0], projected[:, 1], c=stimulus_pos, cmap="viridis", s=2, alpha=0.7
313
+ )
314
+ ax.set_title("Neural Manifold (PC1/PC2)", fontsize=12, fontweight="bold")
315
+ ax.set_xticks([])
316
+ ax.set_yticks([])
317
+ fig.tight_layout()
318
+ fig.savefig(output_path, dpi=150, bbox_inches="tight")
319
+ plt.close(fig)
320
+ return output_path
321
+
322
+ raise ValueError(f"Unsupported analysis for CANN1D: {analysis}")
323
+
324
+ def _run_cann2d(
325
+ self,
326
+ analysis: str,
327
+ model_params: dict[str, Any],
328
+ analysis_params: dict[str, Any],
329
+ output_dir: Path,
330
+ log_callback: Callable[[str], None],
331
+ progress_callback: Callable[[int], None],
332
+ ) -> Path:
333
+ seed = model_params["seed"]
334
+ np.random.seed(seed)
335
+ bm.random.seed(seed)
336
+ bm.set_dt(model_params["dt"])
337
+
338
+ model = CANN2D(
339
+ length=model_params["length"],
340
+ tau=model_params["tau"],
341
+ k=model_params["k"],
342
+ a=model_params["a"],
343
+ A=model_params["A"],
344
+ J0=model_params["J0"],
345
+ )
346
+
347
+ output_path = output_dir / f"cann2d_{analysis}_seed{seed}.png"
348
+
349
+ if analysis == "connectivity":
350
+ log_callback("Rendering connectivity matrix...")
351
+ progress_callback(30)
352
+ self._plot_connectivity(model.conn_mat, output_path, title="CANN2D Connectivity")
353
+ return output_path
354
+
355
+ if analysis == "energy":
356
+ log_callback("Simulating energy landscape...")
357
+ task = TemplateMatching2D(
358
+ model,
359
+ Iext=(analysis_params["energy_x"], analysis_params["energy_y"]),
360
+ duration=analysis_params["energy_duration"],
361
+ time_step=model_params["dt"],
362
+ )
363
+ task.get_data(progress_bar=False)
364
+
365
+ def run_step(inputs):
366
+ model(inputs)
367
+ return model.u.value
368
+
369
+ us = bm.for_loop(run_step, operands=(task.data,), progress_bar=False)
370
+ select_index = len(task.data) // 2
371
+ config = PlotConfigs.energy_landscape_2d_static(
372
+ title="Energy Landscape 2D",
373
+ xlabel="State",
374
+ ylabel="Activity",
375
+ show=False,
376
+ save_path=str(output_path),
377
+ save_format="png",
378
+ )
379
+ energy_landscape_2d_static(z_data=np.asarray(us)[select_index], config=config)
380
+ return output_path
381
+
382
+ if analysis == "firing_field":
383
+ log_callback("Computing firing field...")
384
+ box_size = analysis_params["field_box"]
385
+ task = OpenLoopNavigationTask(
386
+ duration=analysis_params["field_duration"],
387
+ width=box_size,
388
+ height=box_size,
389
+ start_pos=(box_size / 2.0, box_size / 2.0),
390
+ speed_mean=analysis_params["field_speed"],
391
+ speed_std=analysis_params["field_speed_std"],
392
+ dt=model_params["dt"],
393
+ rng_seed=seed,
394
+ progress_bar=False,
395
+ )
396
+ task.get_data()
397
+ positions = task.data.position
398
+
399
+ def run_step(inputs):
400
+ stimulus = model.get_stimulus_by_pos(inputs)
401
+ model(stimulus)
402
+ return model.r.value
403
+
404
+ rs = bm.for_loop(run_step, operands=(positions,), progress_bar=False)
405
+ activity = np.asarray(rs).reshape(rs.shape[0], -1)
406
+
407
+ firing_fields = compute_firing_field(
408
+ activity,
409
+ np.asarray(positions),
410
+ width=box_size,
411
+ height=box_size,
412
+ M=analysis_params["field_resolution"],
413
+ K=analysis_params["field_resolution"],
414
+ )
415
+ firing_fields = gaussian_smooth_heatmaps(
416
+ firing_fields, sigma=analysis_params["field_sigma"]
417
+ )
418
+ cell_idx = min(64, firing_fields.shape[0] - 1)
419
+ config = PlotConfigs.firing_field_heatmap(
420
+ title=f"Firing Field Cell {cell_idx}",
421
+ show=False,
422
+ save_path=str(output_path),
423
+ save_format="png",
424
+ )
425
+ plot_firing_field_heatmap(firing_fields[cell_idx], config=config)
426
+ return output_path
427
+
428
+ if analysis == "trajectory":
429
+ log_callback("Computing trajectory comparison...")
430
+ segment = analysis_params["traj_segment"]
431
+ warmup = analysis_params["traj_warmup"]
432
+ iext = (
433
+ (0.0, 0.0),
434
+ (0.0, 0.0),
435
+ (-2.0, 2.0),
436
+ (2.0, 2.0),
437
+ (2.0, -2.0),
438
+ (-2.0, -2.0),
439
+ )
440
+ durations = (warmup, segment, segment, segment, segment)
441
+ task = SmoothTracking2D(
442
+ model, Iext=iext, duration=durations, time_step=model_params["dt"]
443
+ )
444
+ task.get_data(progress_bar=False)
445
+ true_positions = np.asarray(task.Iext_sequence)
446
+
447
+ def run_step(inputs):
448
+ model(inputs)
449
+ return model.r.value
450
+
451
+ rs = bm.for_loop(run_step, operands=(task.data,), progress_bar=False)
452
+ decoded = self._decode_cann2d_center(np.asarray(rs), model.length)
453
+ decoded_pos = (decoded / model.length - 0.5) * 2 * np.pi
454
+ warmup_steps = int(warmup / model_params["dt"])
455
+ true_pos = true_positions[warmup_steps:]
456
+ decoded_pos = decoded_pos[warmup_steps:]
457
+
458
+ fig, ax = plt.subplots(figsize=(6, 4))
459
+ ax.plot(true_pos[:, 0], true_pos[:, 1], "b-", linewidth=1.5, alpha=0.6)
460
+ ax.plot(decoded_pos[:, 0], decoded_pos[:, 1], "r--", linewidth=1.5, alpha=0.8)
461
+ ax.grid(True, alpha=0.3)
462
+ ax.set_aspect("equal", adjustable="box")
463
+ ax.set_title("Trajectory Comparison", fontsize=12, fontweight="bold")
464
+ fig.tight_layout()
465
+ fig.savefig(output_path, dpi=150, bbox_inches="tight")
466
+ plt.close(fig)
467
+ return output_path
468
+
469
+ if analysis == "manifold":
470
+ log_callback("Computing neural manifold...")
471
+ box_size = 2 * np.pi
472
+ task = OpenLoopNavigationTask(
473
+ duration=analysis_params["manifold_duration"],
474
+ width=box_size,
475
+ height=box_size,
476
+ start_pos=(box_size / 2.0, box_size / 2.0),
477
+ speed_mean=analysis_params["manifold_speed"],
478
+ speed_std=analysis_params["manifold_speed_std"],
479
+ dt=model_params["dt"],
480
+ rng_seed=seed,
481
+ progress_bar=False,
482
+ )
483
+ task.get_data()
484
+ positions = task.data.position
485
+
486
+ def run_step(inputs):
487
+ stimulus = model.get_stimulus_by_pos(inputs)
488
+ model(stimulus)
489
+ return model.r.value
490
+
491
+ rs = bm.for_loop(run_step, operands=(positions,), progress_bar=False)
492
+ n_warmup = int(analysis_params["manifold_warmup"] / model_params["dt"])
493
+ firing_rates = np.asarray(rs[n_warmup:]).reshape(-1, model.length * model.length)
494
+ stimulus_pos = np.asarray(positions[n_warmup:])
495
+
496
+ downsample = max(1, analysis_params["manifold_downsample"])
497
+ firing_rates = firing_rates[::downsample]
498
+ stimulus_pos = stimulus_pos[::downsample]
499
+
500
+ projected = self._pca_projection(firing_rates, n_components=2)
501
+ fig, ax = plt.subplots(figsize=(6, 4))
502
+ ax.scatter(
503
+ projected[:, 0],
504
+ projected[:, 1],
505
+ c=stimulus_pos[:, 0],
506
+ cmap="viridis",
507
+ s=2,
508
+ alpha=0.7,
509
+ )
510
+ ax.set_xticks([])
511
+ ax.set_yticks([])
512
+ ax.set_title("Neural Manifold (PC1/PC2)", fontsize=12, fontweight="bold")
513
+ fig.tight_layout()
514
+ fig.savefig(output_path, dpi=150, bbox_inches="tight")
515
+ plt.close(fig)
516
+ return output_path
517
+
518
+ raise ValueError(f"Unsupported analysis for CANN2D: {analysis}")
519
+
520
+ def _run_gridcell(
521
+ self,
522
+ analysis: str,
523
+ model_params: dict[str, Any],
524
+ analysis_params: dict[str, Any],
525
+ output_dir: Path,
526
+ log_callback: Callable[[str], None],
527
+ progress_callback: Callable[[int], None],
528
+ ) -> Path:
529
+ seed = model_params["seed"]
530
+ np.random.seed(seed)
531
+ bm.random.seed(seed)
532
+ bm.set_dt(model_params["dt"])
533
+
534
+ output_path = output_dir / f"gridcell_{analysis}_seed{seed}.png"
535
+
536
+ if analysis == "connectivity":
537
+ np.random.seed(999)
538
+ bm.random.seed(999)
539
+ model = GridCell2DVelocity(
540
+ length=model_params["length"],
541
+ tau=model_params["tau"],
542
+ alpha=model_params["alpha"],
543
+ W_l=model_params["W_l"],
544
+ lambda_net=model_params["lambda_net"],
545
+ )
546
+ log_callback("Rendering connectivity matrix...")
547
+ progress_callback(30)
548
+ self._plot_connectivity(model.conn_mat, output_path, title="Grid Cell Connectivity")
549
+ return output_path
550
+
551
+ model = GridCell2DVelocity(
552
+ length=model_params["length"],
553
+ tau=model_params["tau"],
554
+ alpha=model_params["alpha"],
555
+ W_l=model_params["W_l"],
556
+ lambda_net=model_params["lambda_net"],
557
+ )
558
+
559
+ box_size = analysis_params["box_size"]
560
+ start_pos = (box_size / 2.0, box_size / 2.0)
561
+
562
+ if analysis == "energy":
563
+ log_callback("Computing energy landscape...")
564
+ task = OpenLoopNavigationTask(
565
+ duration=analysis_params["energy_duration"],
566
+ width=box_size,
567
+ height=box_size,
568
+ start_pos=start_pos,
569
+ speed_mean=analysis_params["energy_speed"],
570
+ speed_std=analysis_params["energy_speed_std"],
571
+ dt=bm.get_dt(),
572
+ rng_seed=seed,
573
+ progress_bar=False,
574
+ )
575
+ task.get_data()
576
+ model.heal_network(
577
+ num_healing_steps=analysis_params["energy_heal_steps"],
578
+ dt_healing=1e-4,
579
+ )
580
+
581
+ def run_step(vel):
582
+ model(vel)
583
+ return model.s.value
584
+
585
+ us = bm.for_loop(run_step, operands=(task.data.velocity,), progress_bar=False)
586
+ select_index = int(task.total_steps * 0.75)
587
+ energy_data = np.asarray(us)[select_index].reshape(model.length, model.length)
588
+ fig, ax = plt.subplots(figsize=(5, 5))
589
+ ax.imshow(energy_data, cmap="viridis", origin="lower")
590
+ ax.set_xticks([])
591
+ ax.set_yticks([])
592
+ ax.set_title("Energy Landscape", fontsize=12, fontweight="bold")
593
+ fig.tight_layout()
594
+ fig.savefig(output_path, dpi=150, bbox_inches="tight")
595
+ plt.close(fig)
596
+ return output_path
597
+
598
+ if analysis == "firing_field":
599
+ log_callback("Computing firing field...")
600
+ from canns.analyzer.metrics.systematic_ratemap import compute_systematic_ratemap
601
+
602
+ ratemaps = compute_systematic_ratemap(
603
+ model,
604
+ box_width=box_size,
605
+ box_height=box_size,
606
+ resolution=analysis_params["field_resolution"],
607
+ speed=analysis_params["field_speed"],
608
+ num_batches=analysis_params["field_batches"],
609
+ verbose=False,
610
+ )
611
+ firing_fields = np.transpose(ratemaps, (2, 0, 1))
612
+ firing_fields = gaussian_smooth_heatmaps(
613
+ firing_fields, sigma=analysis_params["field_sigma"]
614
+ )
615
+ cell_idx = model.num // 2
616
+ config = PlotConfigs.firing_field_heatmap(
617
+ title=f"Grid Cell Field {cell_idx}",
618
+ show=False,
619
+ save_path=str(output_path),
620
+ save_format="png",
621
+ )
622
+ plot_firing_field_heatmap(firing_fields[cell_idx], config=config)
623
+ return output_path
624
+
625
+ if analysis == "path_integration":
626
+ log_callback("Computing path integration...")
627
+ model.heal_network(
628
+ num_healing_steps=analysis_params["path_heal_steps"],
629
+ dt_healing=1e-4,
630
+ )
631
+ task = OpenLoopNavigationTask(
632
+ duration=analysis_params["path_duration"],
633
+ width=box_size,
634
+ height=box_size,
635
+ start_pos=start_pos,
636
+ speed_mean=analysis_params["path_speed"],
637
+ speed_std=analysis_params["path_speed_std"],
638
+ dt=analysis_params["path_dt"],
639
+ rng_seed=seed,
640
+ progress_bar=False,
641
+ )
642
+ task.get_data()
643
+ true_positions = np.asarray(task.data.position)
644
+
645
+ def run_step(vel):
646
+ model(vel)
647
+ return model.r.value
648
+
649
+ activities = bm.for_loop(run_step, operands=(task.data.velocity,), progress_bar=False)
650
+ activities = np.asarray(activities)
651
+ blob_centers = GridCell2DVelocity.track_blob_centers(activities, model.length)
652
+ blob_displacement = np.diff(blob_centers, axis=0)
653
+ displacement_norm = np.linalg.norm(blob_displacement, axis=1)
654
+ jump_indices = np.where(displacement_norm > 3.0)[0]
655
+ for idx in jump_indices:
656
+ if 0 < idx < len(blob_displacement) - 1:
657
+ blob_displacement[idx] = (
658
+ blob_displacement[idx - 1] + blob_displacement[idx + 1]
659
+ ) / 2
660
+ estimated_pos_neuron = np.cumsum(blob_displacement, axis=0)
661
+
662
+ true_pos_rel = true_positions - true_positions[0]
663
+ true_pos_aligned = true_pos_rel[: len(estimated_pos_neuron)]
664
+ X = estimated_pos_neuron.reshape(-1)
665
+ y = true_pos_aligned.reshape(-1)
666
+ scale = np.dot(X, y) / (np.dot(X, X) + 1e-8)
667
+ estimated_pos = scale * estimated_pos_neuron + true_positions[0]
668
+
669
+ fig, ax = plt.subplots(figsize=(6, 4))
670
+ ax.plot(
671
+ true_positions[: len(estimated_pos), 0],
672
+ true_positions[: len(estimated_pos), 1],
673
+ "b-",
674
+ alpha=0.5,
675
+ linewidth=1.5,
676
+ label="True",
677
+ )
678
+ ax.plot(
679
+ estimated_pos[:, 0],
680
+ estimated_pos[:, 1],
681
+ "r-",
682
+ alpha=0.7,
683
+ linewidth=1.5,
684
+ label="Estimated",
685
+ )
686
+ ax.set_aspect("equal", adjustable="box")
687
+ ax.grid(True, alpha=0.3)
688
+ ax.set_title("Path Integration", fontsize=12, fontweight="bold")
689
+ fig.tight_layout()
690
+ fig.savefig(output_path, dpi=150, bbox_inches="tight")
691
+ plt.close(fig)
692
+ return output_path
693
+
694
+ if analysis == "manifold":
695
+ log_callback("Computing neural manifold...")
696
+ from canns.analyzer.metrics.systematic_ratemap import compute_systematic_ratemap
697
+
698
+ ratemaps = compute_systematic_ratemap(
699
+ model,
700
+ box_width=box_size,
701
+ box_height=box_size,
702
+ resolution=analysis_params["field_resolution"],
703
+ speed=analysis_params["field_speed"],
704
+ num_batches=analysis_params["field_batches"],
705
+ verbose=False,
706
+ )
707
+ firing_fields = np.transpose(ratemaps, (2, 0, 1))
708
+ firing_fields = gaussian_smooth_heatmaps(
709
+ firing_fields, sigma=analysis_params["field_sigma"]
710
+ )
711
+ data_for_pca = firing_fields.reshape(firing_fields.shape[0], -1).T
712
+ projected = self._pca_projection(data_for_pca, n_components=3)
713
+ fig = plt.figure(figsize=(6, 5))
714
+ ax = fig.add_subplot(111, projection="3d")
715
+ ax.scatter(
716
+ projected[:, 0],
717
+ projected[:, 1],
718
+ projected[:, 2],
719
+ c=projected[:, 2],
720
+ cmap="viridis",
721
+ s=1,
722
+ alpha=0.7,
723
+ )
724
+ ax.axis("off")
725
+ ax.set_title("Grid Cell Manifold", fontsize=12, fontweight="bold")
726
+ fig.tight_layout()
727
+ fig.savefig(output_path, dpi=150, bbox_inches="tight")
728
+ plt.close(fig)
729
+ return output_path
730
+
731
+ raise ValueError(f"Unsupported analysis for Grid Cell: {analysis}")
732
+
733
+ def _plot_connectivity(self, conn_mat: Any, output_path: Path, title: str) -> None:
734
+ data = np.asarray(conn_mat)
735
+ fig, ax = plt.subplots(figsize=(6, 6))
736
+ im = ax.imshow(data, cmap="viridis")
737
+ ax.set_title(title)
738
+ ax.set_xlabel("Neuron Index")
739
+ ax.set_ylabel("Neuron Index")
740
+ fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
741
+ fig.tight_layout()
742
+ fig.savefig(output_path, dpi=150, bbox_inches="tight")
743
+ plt.close(fig)
744
+
745
+ def _pca_projection(self, data: np.ndarray, n_components: int = 3) -> np.ndarray:
746
+ centered = data - np.mean(data, axis=0, keepdims=True)
747
+ _, _, vt = np.linalg.svd(centered, full_matrices=False)
748
+ components = vt[:n_components].T
749
+ return centered @ components
750
+
751
+ def _decode_cann2d_center(self, activities: np.ndarray, length: int) -> np.ndarray:
752
+ from scipy.ndimage import center_of_mass, gaussian_filter, label
753
+
754
+ T = len(activities)
755
+ n = length
756
+ activities_2d = activities.reshape(T, n, n)
757
+ smoothed = np.array([gaussian_filter(activities_2d[t], sigma=1) for t in range(T)])
758
+ thresholds = smoothed.mean(axis=(1, 2)) + 0.5 * smoothed.std(axis=(1, 2))
759
+ binary_images = smoothed > thresholds[:, None, None]
760
+
761
+ centers = []
762
+ for i in range(T):
763
+ labeled, num_features = label(binary_images[i])
764
+ if num_features > 0:
765
+ blob_centers = np.array(
766
+ center_of_mass(binary_images[i], labeled, range(1, num_features + 1))
767
+ )
768
+ if blob_centers.ndim == 1:
769
+ blob_centers = blob_centers.reshape(1, -1)
770
+ blob_centers = blob_centers[:, [1, 0]]
771
+ dist = np.linalg.norm(blob_centers - n / 2, axis=1)
772
+ best_center = blob_centers[np.argmin(dist)]
773
+ else:
774
+ best_center = centers[-1] if centers else np.array([n / 2, n / 2])
775
+ centers.append(best_center)
776
+ return np.array(centers)
777
+
778
+ def _parse_indices(self, raw: str, max_size: int) -> list[int]:
779
+ cleaned = [p.strip() for p in raw.split(",") if p.strip()]
780
+ indices: list[int] = []
781
+ for part in cleaned:
782
+ try:
783
+ idx = int(part)
784
+ except ValueError:
785
+ continue
786
+ if 0 <= idx < max_size:
787
+ indices.append(idx)
788
+ if not indices:
789
+ indices = [max_size // 4, max_size // 2, (3 * max_size) // 4]
790
+ return indices