canns 0.13.1__py3-none-any.whl → 0.13.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,707 @@
1
+ """Model gallery TUI for quick CANN visualizations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ from textual.app import App, ComposeResult
8
+ from textual.binding import Binding
9
+ from textual.containers import Horizontal, Vertical, VerticalScroll
10
+ from textual.widgets import Button, Footer, Header, Input, Label, ProgressBar, Select, Static
11
+ from textual.worker import Worker
12
+
13
+ from canns.pipeline.asa.screens import ErrorScreen, TerminalSizeWarning, WorkdirScreen
14
+ from canns.pipeline.asa.widgets import ImagePreview, LogViewer, ParamGroup
15
+
16
+ from .runner import GalleryRunner
17
+ from .state import GalleryState, get_analysis_options, get_default_analysis
18
+
19
+
20
# Initial text for every parameter ``Input`` in the gallery. The values are
# stored as strings because they are passed straight to ``Input(value=...)``;
# they are parsed into int/float only when a run is started.

# CANN1D model + analysis defaults (inputs with ids prefixed ``c1-``).
CANN1D_DEFAULTS = dict(
    seed="42",
    num="256",
    tau="1.0",
    k="8.1",
    a="0.5",
    A="10.0",
    J0="4.0",
    dt="0.1",
    energy_pos="1.0",
    energy_duration="10.0",
    tuning_start="0.0",
    tuning_mid="3.1416",
    tuning_end="6.2832",
    tuning_duration="40.0",
    tuning_bins="50",
    tuning_neurons="64,128,192",
    template_pos="1.0",
    template_duration="10.0",
    manifold_segment="40.0",
    manifold_warmup="5.0",
)

# CANN2D model + analysis defaults (inputs with ids prefixed ``c2-``).
CANN2D_DEFAULTS = dict(
    seed="42",
    length="64",
    tau="1.0",
    k="8.1",
    a="0.5",
    A="10.0",
    J0="4.0",
    dt="0.1",
    energy_x="1.0",
    energy_y="1.0",
    energy_duration="10.0",
    field_duration="8000.0",
    field_box="6.2832",
    field_resolution="80",
    field_sigma="2.0",
    field_speed="0.3",
    field_speed_std="0.1",
    traj_segment="40.0",
    traj_warmup="5.0",
    manifold_duration="22000.0",
    manifold_warmup="2000.0",
    manifold_speed="0.05",
    manifold_speed_std="0.02",
    manifold_downsample="10",
)

# Grid-cell model + analysis defaults (inputs with ids prefixed ``g-``).
GRID_DEFAULTS = dict(
    seed="74",
    length="40",
    tau="0.01",
    alpha="0.1",
    W_l="3.0",
    lambda_net="17.0",
    dt="0.0005",
    box_size="2.2",
    energy_duration="10.0",
    energy_speed="0.2",
    energy_speed_std="0.05",
    energy_heal_steps="10000",
    field_resolution="100",
    field_sigma="2.0",
    field_speed="0.3",
    field_batches="10",
    path_duration="10.0",
    path_dt="0.01",
    path_speed="0.5",
    path_speed_std="0.05",
    path_heal_steps="5000",
)
93
+
94
+
95
class GalleryApp(App):
    """Main TUI application for the model gallery.

    Layout: a left panel (workdir, model/analysis selection, run controls),
    a middle panel with per-model parameter inputs (only the groups relevant
    to the selected model/analysis are visible), and a right panel with an
    image preview and a log viewer.
    """

    CSS_PATH = "styles.tcss"
    TITLE = "CANNs Model Gallery"

    # Below MIN_WIDTH x MIN_HEIGHT a one-time TerminalSizeWarning is pushed.
    MIN_WIDTH = 100
    RECOMMENDED_WIDTH = 120
    MIN_HEIGHT = 28
    RECOMMENDED_HEIGHT = 36

    BINDINGS = [
        Binding("ctrl+w", "change_workdir", "Workdir"),
        Binding("ctrl+r", "run", "Run"),
        Binding("f5", "refresh", "Refresh"),
        Binding("escape", "quit", "Quit"),
    ]

    # Analysis parameter group selector -> (owning model, analyses for which the
    # group is shown). Order matches the widget order in ``compose``.
    _PARAM_GROUPS: dict[str, tuple[str, tuple[str, ...]]] = {
        "#c1-analysis-connectivity": ("cann1d", ("connectivity",)),
        "#c1-analysis-energy": ("cann1d", ("energy",)),
        "#c1-analysis-tuning": ("cann1d", ("tuning",)),
        "#c1-analysis-template": ("cann1d", ("template",)),
        "#c1-analysis-manifold": ("cann1d", ("manifold",)),
        "#c2-analysis-connectivity": ("cann2d", ("connectivity",)),
        "#c2-analysis-energy": ("cann2d", ("energy",)),
        "#c2-analysis-firing": ("cann2d", ("firing_field",)),
        "#c2-analysis-trajectory": ("cann2d", ("trajectory",)),
        "#c2-analysis-manifold": ("cann2d", ("manifold",)),
        "#g-analysis-connectivity": ("gridcell", ("connectivity",)),
        "#g-analysis-energy": ("gridcell", ("energy",)),
        # The grid-cell manifold analysis reuses the firing-field inputs.
        "#g-analysis-firing": ("gridcell", ("firing_field", "manifold")),
        "#g-analysis-path": ("gridcell", ("path_integration",)),
    }

    def __init__(self) -> None:
        super().__init__()
        self.state = GalleryState()
        self.runner = GalleryRunner()
        self.current_worker: Worker | None = None
        self._size_warning_shown = False

    @staticmethod
    def _labeled_inputs(
        defaults: dict[str, str], fields: tuple[tuple[str, str, str], ...]
    ) -> ComposeResult:
        """Yield a ``Label`` + ``Input`` pair per ``(label, default_key, input_id)``."""
        for label, key, input_id in fields:
            yield Label(label)
            yield Input(value=defaults[key], id=input_id)

    def compose(self) -> ComposeResult:
        """Build the three-panel layout (controls, parameters, results)."""
        yield Header()

        with Horizontal(id="main-container"):
            with Vertical(id="left-panel"):
                yield Label(f"Workdir: {self.state.workdir}", id="workdir-label")
                yield Button("Change Workdir", id="change-workdir-btn")

                yield Label("Model")
                yield Select(
                    [
                        ("CANN 1D", "cann1d"),
                        ("CANN 2D", "cann2d"),
                        ("Grid Cell", "gridcell"),
                    ],
                    value=self.state.model,
                    id="model-select",
                )

                yield Label("Analysis")
                yield Select(
                    get_analysis_options(self.state.model),
                    value=self.state.analysis,
                    id="analysis-select",
                )

                yield Button("Run", variant="primary", id="run-btn")
                yield ProgressBar(id="progress-bar")
                yield Static("Status: Idle", id="run-status")

            with Vertical(id="middle-panel"):
                with Vertical(id="params-panel"):
                    yield Static("Parameters", id="params-header")
                    with VerticalScroll(id="params-scroll"):
                        with Vertical(id="params-cann1d"):
                            with ParamGroup("CANN1D Model"):
                                yield from self._labeled_inputs(
                                    CANN1D_DEFAULTS,
                                    tuple(
                                        (key, key, f"c1-{key}")
                                        for key in ("seed", "num", "tau", "k", "a", "A", "J0", "dt")
                                    ),
                                )

                            with ParamGroup("Energy Landscape", id="c1-analysis-energy"):
                                yield from self._labeled_inputs(
                                    CANN1D_DEFAULTS,
                                    (
                                        ("stimulus_pos", "energy_pos", "c1-energy-pos"),
                                        ("duration", "energy_duration", "c1-energy-duration"),
                                    ),
                                )

                            with ParamGroup("Tuning Curve", id="c1-analysis-tuning"):
                                yield from self._labeled_inputs(
                                    CANN1D_DEFAULTS,
                                    (
                                        ("start_pos", "tuning_start", "c1-tuning-start"),
                                        ("mid_pos", "tuning_mid", "c1-tuning-mid"),
                                        ("end_pos", "tuning_end", "c1-tuning-end"),
                                        ("segment_duration", "tuning_duration", "c1-tuning-duration"),
                                        ("num_bins", "tuning_bins", "c1-tuning-bins"),
                                        ("neuron_indices", "tuning_neurons", "c1-tuning-neurons"),
                                    ),
                                )

                            with ParamGroup("Template Matching", id="c1-analysis-template"):
                                yield from self._labeled_inputs(
                                    CANN1D_DEFAULTS,
                                    (
                                        ("stimulus_pos", "template_pos", "c1-template-pos"),
                                        ("duration", "template_duration", "c1-template-duration"),
                                    ),
                                )

                            with ParamGroup("Neural Manifold", id="c1-analysis-manifold"):
                                yield from self._labeled_inputs(
                                    CANN1D_DEFAULTS,
                                    (
                                        ("segment_duration", "manifold_segment", "c1-manifold-segment"),
                                        ("warmup", "manifold_warmup", "c1-manifold-warmup"),
                                    ),
                                )

                            with ParamGroup("Connectivity", id="c1-analysis-connectivity"):
                                yield Static("No extra parameters")

                        with Vertical(id="params-cann2d", classes="hidden"):
                            with ParamGroup("CANN2D Model"):
                                yield from self._labeled_inputs(
                                    CANN2D_DEFAULTS,
                                    tuple(
                                        (key, key, f"c2-{key}")
                                        for key in ("seed", "length", "tau", "k", "a", "A", "J0", "dt")
                                    ),
                                )

                            with ParamGroup("Energy Landscape", id="c2-analysis-energy"):
                                yield from self._labeled_inputs(
                                    CANN2D_DEFAULTS,
                                    (
                                        ("stimulus_x", "energy_x", "c2-energy-x"),
                                        ("stimulus_y", "energy_y", "c2-energy-y"),
                                        ("duration", "energy_duration", "c2-energy-duration"),
                                    ),
                                )

                            with ParamGroup("Firing Field", id="c2-analysis-firing"):
                                yield from self._labeled_inputs(
                                    CANN2D_DEFAULTS,
                                    (
                                        ("duration", "field_duration", "c2-field-duration"),
                                        ("box_size", "field_box", "c2-field-box"),
                                        ("resolution", "field_resolution", "c2-field-resolution"),
                                        ("smooth_sigma", "field_sigma", "c2-field-sigma"),
                                        ("speed_mean", "field_speed", "c2-field-speed"),
                                        ("speed_std", "field_speed_std", "c2-field-speed-std"),
                                    ),
                                )

                            with ParamGroup("Trajectory Comparison", id="c2-analysis-trajectory"):
                                yield from self._labeled_inputs(
                                    CANN2D_DEFAULTS,
                                    (
                                        ("segment_duration", "traj_segment", "c2-traj-segment"),
                                        ("warmup", "traj_warmup", "c2-traj-warmup"),
                                    ),
                                )

                            with ParamGroup("Neural Manifold", id="c2-analysis-manifold"):
                                yield from self._labeled_inputs(
                                    CANN2D_DEFAULTS,
                                    (
                                        ("duration", "manifold_duration", "c2-manifold-duration"),
                                        ("warmup", "manifold_warmup", "c2-manifold-warmup"),
                                        ("speed_mean", "manifold_speed", "c2-manifold-speed"),
                                        ("speed_std", "manifold_speed_std", "c2-manifold-speed-std"),
                                        ("downsample", "manifold_downsample", "c2-manifold-downsample"),
                                    ),
                                )

                            with ParamGroup("Connectivity", id="c2-analysis-connectivity"):
                                yield Static("No extra parameters")

                        with Vertical(id="params-gridcell", classes="hidden"):
                            with ParamGroup("Grid Cell Model"):
                                yield from self._labeled_inputs(
                                    GRID_DEFAULTS,
                                    (
                                        ("seed", "seed", "g-seed"),
                                        ("length", "length", "g-length"),
                                        ("tau", "tau", "g-tau"),
                                        ("alpha", "alpha", "g-alpha"),
                                        ("W_l", "W_l", "g-Wl"),
                                        ("lambda_net", "lambda_net", "g-lambda"),
                                        ("dt", "dt", "g-dt"),
                                    ),
                                )

                            with ParamGroup("Energy Landscape", id="g-analysis-energy"):
                                yield from self._labeled_inputs(
                                    GRID_DEFAULTS,
                                    (
                                        ("duration", "energy_duration", "g-energy-duration"),
                                        ("speed_mean", "energy_speed", "g-energy-speed"),
                                        ("speed_std", "energy_speed_std", "g-energy-speed-std"),
                                        ("heal_steps", "energy_heal_steps", "g-energy-heal"),
                                    ),
                                )

                            with ParamGroup("Firing Field", id="g-analysis-firing"):
                                yield from self._labeled_inputs(
                                    GRID_DEFAULTS,
                                    (
                                        ("box_size", "box_size", "g-box-size"),
                                        ("resolution", "field_resolution", "g-field-resolution"),
                                        ("smooth_sigma", "field_sigma", "g-field-sigma"),
                                        ("speed", "field_speed", "g-field-speed"),
                                        ("num_batches", "field_batches", "g-field-batches"),
                                    ),
                                )

                            with ParamGroup("Path Integration", id="g-analysis-path"):
                                yield from self._labeled_inputs(
                                    GRID_DEFAULTS,
                                    (
                                        ("duration", "path_duration", "g-path-duration"),
                                        ("dt", "path_dt", "g-path-dt"),
                                        ("speed_mean", "path_speed", "g-path-speed"),
                                        ("speed_std", "path_speed_std", "g-path-speed-std"),
                                        ("heal_steps", "path_heal_steps", "g-path-heal"),
                                    ),
                                )

                            with ParamGroup("Connectivity", id="g-analysis-connectivity"):
                                yield Static("No extra parameters")

            with Vertical(id="right-panel"):
                yield ImagePreview(id="result-preview")
                yield LogViewer(id="log-viewer")

        yield Footer()

    def on_mount(self) -> None:
        """Check the terminal size and sync selects/parameter panels with state."""
        self.check_terminal_size()
        self._update_analysis_options()
        self._update_param_visibility()

    def on_resize(self, event) -> None:
        self.check_terminal_size()

    def check_terminal_size(self) -> None:
        """Show the size warning once if the terminal is below the minimum size."""
        width, height = self.size
        if not self._size_warning_shown and (
            width < self.MIN_WIDTH or height < self.MIN_HEIGHT
        ):
            self._size_warning_shown = True
            self.push_screen(TerminalSizeWarning(width, height))

    def action_change_workdir(self) -> None:
        self.push_screen(WorkdirScreen(), self.on_workdir_selected)

    def on_workdir_selected(self, path: Path | None) -> None:
        """Callback for WorkdirScreen; a falsy result leaves the workdir unchanged."""
        if path:
            self.state.workdir = path
            self.update_workdir_label()

    def update_workdir_label(self) -> None:
        label = self.query_one("#workdir-label", Label)
        label.update(f"Workdir: {self.state.workdir}")

    def on_button_pressed(self, event: Button.Pressed) -> None:
        if event.button.id == "change-workdir-btn":
            self.action_change_workdir()
        elif event.button.id == "run-btn":
            self.action_run()

    def on_select_changed(self, event: Select.Changed) -> None:
        """Update model/analysis state from the two Select widgets."""
        value = event.value
        if not isinstance(value, str):
            # All option values here are strings; anything else is a "no selection"
            # placeholder (e.g. while set_options replaces the analysis options) and
            # must not be stored as a model/analysis name.
            return
        if event.select.id == "model-select":
            self.state.model = value
            self.state.analysis = get_default_analysis(self.state.model)
            self._update_analysis_options()
            self._update_param_visibility()
        elif event.select.id == "analysis-select":
            self.state.analysis = value
            self._update_param_visibility()

    def action_run(self) -> None:
        """Validate inputs and start the gallery run in a background worker."""
        if self.current_worker and not self.current_worker.is_finished:
            self.log_message("A task is already running.")
            return

        try:
            model_params, analysis_params = self.collect_params()
        except ValueError as exc:
            self.push_screen(ErrorScreen("Parameter Error", str(exc)))
            return

        output_dir = self.state.workdir / "Results" / "gallery" / self.state.model
        self.set_run_status("Status: Running...", "running")
        self.update_progress(0)

        # NOTE(review): with thread=True the runner may invoke log_callback /
        # progress_callback from the worker thread; Textual widgets are not
        # thread-safe — confirm GalleryRunner marshals via App.call_from_thread.
        self.current_worker = self.run_worker(
            self.runner.run(
                self.state.model,
                self.state.analysis,
                model_params,
                analysis_params,
                output_dir,
                log_callback=self.log_message,
                progress_callback=self.update_progress,
            ),
            name="gallery_worker",
            thread=True,
            # Report failures through on_worker_state_changed instead of letting
            # an uncaught exception exit the whole app (Textual's default).
            exit_on_error=False,
        )

    def action_refresh(self) -> None:
        """Reload the preview image from the last successful run, if any."""
        preview_path = self.state.artifacts.get("output")
        if preview_path:
            preview = self.query_one("#result-preview", ImagePreview)
            preview.update_image(preview_path)
            self.log_message(f"Refreshed preview: {preview_path}")

    def on_worker_state_changed(self, event: Worker.StateChanged) -> None:
        """Handle completion (success, error, or cancellation) of the gallery worker."""
        worker = event.worker
        if worker.name != "gallery_worker" or not worker.is_finished:
            return

        result = worker.result
        if result is None:
            # Errored/cancelled workers carry no result; report the exception instead
            # of dereferencing None.
            self.set_run_status("Status: Failed", "error")
            error = worker.error
            if error is not None:
                message = str(error) or repr(error)
            else:
                message = f"Gallery run ended without a result ({worker.state})"
            self.push_screen(ErrorScreen("Gallery Error", message))
            return

        if result.success:
            self.state.artifacts = result.artifacts
            self.set_run_status("Status: Complete", "success")
            self.log_message(result.summary)
            output_path = result.artifacts.get("output")
            if output_path:
                preview = self.query_one("#result-preview", ImagePreview)
                preview.update_image(output_path)
        else:
            self.set_run_status("Status: Failed", "error")
            self.push_screen(ErrorScreen("Gallery Error", result.error or "Unknown error"))

    def set_run_status(self, message: str, status_class: str | None = None) -> None:
        """Set the status text and swap its running/success/error CSS class."""
        status = self.query_one("#run-status", Static)
        status.update(message)
        status.remove_class("running", "success", "error")
        if status_class:
            status.add_class(status_class)

    def update_progress(self, percent: int) -> None:
        progress = self.query_one("#progress-bar", ProgressBar)
        progress.update(total=100, progress=percent)

    def log_message(self, message: str) -> None:
        log_viewer = self.query_one("#log-viewer", LogViewer)
        log_viewer.add_log(message)

    def collect_params(self) -> tuple[dict[str, float | int], dict[str, float | int]]:
        """Parse the inputs for the current model.

        Returns:
            ``(model_params, analysis_params)``. Analysis params for every analysis
            of the model are collected, not only the selected one.

        Raises:
            ValueError: if an input cannot be parsed (the message names the input)
                or the current model is unknown.
        """
        if self.state.model == "cann1d":
            model_params = {
                "seed": self._int("#c1-seed"),
                "num": self._int("#c1-num"),
                "tau": self._float("#c1-tau"),
                "k": self._float("#c1-k"),
                "a": self._float("#c1-a"),
                "A": self._float("#c1-A"),
                "J0": self._float("#c1-J0"),
                "dt": self._float("#c1-dt"),
            }
            analysis_params = {
                "energy_pos": self._float("#c1-energy-pos"),
                "energy_duration": self._float("#c1-energy-duration"),
                "tuning_start": self._float("#c1-tuning-start"),
                "tuning_mid": self._float("#c1-tuning-mid"),
                "tuning_end": self._float("#c1-tuning-end"),
                "tuning_duration": self._float("#c1-tuning-duration"),
                "tuning_bins": self._int("#c1-tuning-bins"),
                "tuning_neurons": self._str("#c1-tuning-neurons"),
                "template_pos": self._float("#c1-template-pos"),
                "template_duration": self._float("#c1-template-duration"),
                "manifold_segment": self._float("#c1-manifold-segment"),
                "manifold_warmup": self._float("#c1-manifold-warmup"),
            }
            return model_params, analysis_params

        if self.state.model == "cann2d":
            model_params = {
                "seed": self._int("#c2-seed"),
                "length": self._int("#c2-length"),
                "tau": self._float("#c2-tau"),
                "k": self._float("#c2-k"),
                "a": self._float("#c2-a"),
                "A": self._float("#c2-A"),
                "J0": self._float("#c2-J0"),
                "dt": self._float("#c2-dt"),
            }
            analysis_params = {
                "energy_x": self._float("#c2-energy-x"),
                "energy_y": self._float("#c2-energy-y"),
                "energy_duration": self._float("#c2-energy-duration"),
                "field_duration": self._float("#c2-field-duration"),
                "field_box": self._float("#c2-field-box"),
                "field_resolution": self._int("#c2-field-resolution"),
                "field_sigma": self._float("#c2-field-sigma"),
                "field_speed": self._float("#c2-field-speed"),
                "field_speed_std": self._float("#c2-field-speed-std"),
                "traj_segment": self._float("#c2-traj-segment"),
                "traj_warmup": self._float("#c2-traj-warmup"),
                "manifold_duration": self._float("#c2-manifold-duration"),
                "manifold_warmup": self._float("#c2-manifold-warmup"),
                "manifold_speed": self._float("#c2-manifold-speed"),
                "manifold_speed_std": self._float("#c2-manifold-speed-std"),
                "manifold_downsample": self._int("#c2-manifold-downsample"),
            }
            return model_params, analysis_params

        if self.state.model == "gridcell":
            model_params = {
                "seed": self._int("#g-seed"),
                "length": self._int("#g-length"),
                "tau": self._float("#g-tau"),
                "alpha": self._float("#g-alpha"),
                "W_l": self._float("#g-Wl"),
                "lambda_net": self._float("#g-lambda"),
                "dt": self._float("#g-dt"),
            }
            analysis_params = {
                "box_size": self._float("#g-box-size"),
                "energy_duration": self._float("#g-energy-duration"),
                "energy_speed": self._float("#g-energy-speed"),
                "energy_speed_std": self._float("#g-energy-speed-std"),
                "energy_heal_steps": self._int("#g-energy-heal"),
                "field_resolution": self._int("#g-field-resolution"),
                "field_sigma": self._float("#g-field-sigma"),
                "field_speed": self._float("#g-field-speed"),
                "field_batches": self._int("#g-field-batches"),
                "path_duration": self._float("#g-path-duration"),
                "path_dt": self._float("#g-path-dt"),
                "path_speed": self._float("#g-path-speed"),
                "path_speed_std": self._float("#g-path-speed-std"),
                "path_heal_steps": self._int("#g-path-heal"),
            }
            return model_params, analysis_params

        raise ValueError(f"Unknown model: {self.state.model}")

    def _int(self, selector: str) -> int:
        """Parse the Input at ``selector`` as an int; ValueError names the field."""
        raw = self._str(selector)
        try:
            return int(raw)
        except ValueError:
            raise ValueError(
                f"{selector.lstrip('#')}: expected an integer, got {raw!r}"
            ) from None

    def _float(self, selector: str) -> float:
        """Parse the Input at ``selector`` as a float; ValueError names the field."""
        raw = self._str(selector)
        try:
            return float(raw)
        except ValueError:
            raise ValueError(
                f"{selector.lstrip('#')}: expected a number, got {raw!r}"
            ) from None

    def _str(self, selector: str) -> str:
        return self.query_one(selector, Input).value

    def _update_analysis_options(self) -> None:
        """Repopulate the analysis Select for the current model and reselect state."""
        select = self.query_one("#analysis-select", Select)
        options = get_analysis_options(self.state.model)
        select.set_options(options)
        select.value = self.state.analysis

    def _update_param_visibility(self) -> None:
        """Show only the current model's panel and its selected analysis group."""
        model, analysis = self.state.model, self.state.analysis
        for panel_model in ("cann1d", "cann2d", "gridcell"):
            self._set_visible(f"#params-{panel_model}", model == panel_model)
        for selector, (group_model, analyses) in self._PARAM_GROUPS.items():
            self._set_visible(selector, model == group_model and analysis in analyses)

    def _set_visible(self, selector: str, visible: bool) -> None:
        """Toggle the ``hidden`` CSS class on the widget matching ``selector``."""
        widget = self.query_one(selector)
        if visible:
            widget.remove_class("hidden")
        else:
            widget.add_class("hidden")