canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99):
  1. canns/analyzer/data/__init__.py +5 -1
  2. canns/analyzer/data/asa/__init__.py +27 -12
  3. canns/analyzer/data/asa/cohospace.py +336 -10
  4. canns/analyzer/data/asa/config.py +3 -0
  5. canns/analyzer/data/asa/embedding.py +48 -45
  6. canns/analyzer/data/asa/path.py +104 -2
  7. canns/analyzer/data/asa/plotting.py +88 -19
  8. canns/analyzer/data/asa/tda.py +11 -4
  9. canns/analyzer/data/cell_classification/__init__.py +97 -0
  10. canns/analyzer/data/cell_classification/core/__init__.py +26 -0
  11. canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
  12. canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
  13. canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
  14. canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
  15. canns/analyzer/data/cell_classification/io/__init__.py +5 -0
  16. canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
  17. canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
  18. canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
  19. canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
  20. canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
  21. canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
  22. canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
  23. canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
  24. canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
  25. canns/analyzer/metrics/__init__.py +2 -1
  26. canns/analyzer/visualization/core/config.py +46 -4
  27. canns/data/__init__.py +6 -1
  28. canns/data/datasets.py +154 -1
  29. canns/data/loaders.py +37 -0
  30. canns/pipeline/__init__.py +13 -9
  31. canns/pipeline/__main__.py +6 -0
  32. canns/pipeline/asa/runner.py +105 -41
  33. canns/pipeline/asa_gui/__init__.py +68 -0
  34. canns/pipeline/asa_gui/__main__.py +6 -0
  35. canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
  36. canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
  37. canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
  38. canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
  39. canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
  40. canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
  41. canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
  42. canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
  43. canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
  44. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
  45. canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
  46. canns/pipeline/asa_gui/app.py +29 -0
  47. canns/pipeline/asa_gui/controllers/__init__.py +6 -0
  48. canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
  49. canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
  50. canns/pipeline/asa_gui/core/__init__.py +15 -0
  51. canns/pipeline/asa_gui/core/cache.py +14 -0
  52. canns/pipeline/asa_gui/core/runner.py +1936 -0
  53. canns/pipeline/asa_gui/core/state.py +324 -0
  54. canns/pipeline/asa_gui/core/worker.py +260 -0
  55. canns/pipeline/asa_gui/main_window.py +184 -0
  56. canns/pipeline/asa_gui/models/__init__.py +7 -0
  57. canns/pipeline/asa_gui/models/config.py +14 -0
  58. canns/pipeline/asa_gui/models/job.py +31 -0
  59. canns/pipeline/asa_gui/models/presets.py +21 -0
  60. canns/pipeline/asa_gui/resources/__init__.py +16 -0
  61. canns/pipeline/asa_gui/resources/dark.qss +167 -0
  62. canns/pipeline/asa_gui/resources/light.qss +163 -0
  63. canns/pipeline/asa_gui/resources/styles.qss +130 -0
  64. canns/pipeline/asa_gui/utils/__init__.py +1 -0
  65. canns/pipeline/asa_gui/utils/formatters.py +15 -0
  66. canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
  67. canns/pipeline/asa_gui/utils/validators.py +41 -0
  68. canns/pipeline/asa_gui/views/__init__.py +1 -0
  69. canns/pipeline/asa_gui/views/help_content.py +171 -0
  70. canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
  71. canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
  72. canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
  73. canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
  74. canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
  75. canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
  76. canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
  77. canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
  78. canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
  79. canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
  80. canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
  81. canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
  82. canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
  83. canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
  84. canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
  85. canns/pipeline/gallery/__init__.py +15 -5
  86. canns/pipeline/gallery/__main__.py +11 -0
  87. canns/pipeline/gallery/app.py +705 -0
  88. canns/pipeline/gallery/runner.py +790 -0
  89. canns/pipeline/gallery/state.py +51 -0
  90. canns/pipeline/gallery/styles.tcss +123 -0
  91. canns/pipeline/launcher.py +81 -0
  92. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
  93. canns-0.14.0.dist-info/RECORD +163 -0
  94. canns-0.14.0.dist-info/entry_points.txt +5 -0
  95. canns/pipeline/_base.py +0 -50
  96. canns-0.13.1.dist-info/RECORD +0 -89
  97. canns-0.13.1.dist-info/entry_points.txt +0 -3
  98. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
  99. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,705 @@
1
+ """Model gallery TUI for quick CANN visualizations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ from textual.app import App, ComposeResult
8
+ from textual.binding import Binding
9
+ from textual.containers import Horizontal, Vertical, VerticalScroll
10
+ from textual.widgets import Button, Footer, Header, Input, Label, ProgressBar, Select, Static
11
+ from textual.worker import Worker
12
+
13
+ from canns.pipeline.asa.screens import ErrorScreen, TerminalSizeWarning, WorkdirScreen
14
+ from canns.pipeline.asa.widgets import ImagePreview, LogViewer, ParamGroup
15
+
16
+ from .runner import GalleryRunner
17
+ from .state import GalleryState, get_analysis_options, get_default_analysis
18
+
19
# Prefill text for the CANN1D parameter form. Values are kept as strings
# because they seed ``Input`` widgets; ``collect_params`` converts them.
CANN1D_DEFAULTS = dict(
    # Model construction
    seed="42",
    num="256",
    tau="1.0",
    k="8.1",
    a="0.5",
    A="10.0",
    J0="4.0",
    dt="0.1",
    # Energy landscape
    energy_pos="1.0",
    energy_duration="10.0",
    # Tuning curve (0, ~pi, ~2*pi sweep)
    tuning_start="0.0",
    tuning_mid="3.1416",
    tuning_end="6.2832",
    tuning_duration="40.0",
    tuning_bins="50",
    tuning_neurons="64,128,192",
    # Template matching
    template_pos="1.0",
    template_duration="10.0",
    # Neural manifold
    manifold_segment="40.0",
    manifold_warmup="5.0",
)
41
+
42
# Prefill text for the CANN2D parameter form. Values are kept as strings
# because they seed ``Input`` widgets; ``collect_params`` converts them.
CANN2D_DEFAULTS = dict(
    # Model construction
    seed="42",
    length="64",
    tau="1.0",
    k="8.1",
    a="0.5",
    A="10.0",
    J0="4.0",
    dt="0.1",
    # Energy landscape
    energy_x="1.0",
    energy_y="1.0",
    energy_duration="10.0",
    # Firing field
    field_duration="8000.0",
    field_box="6.2832",
    field_resolution="80",
    field_sigma="2.0",
    field_speed="0.3",
    field_speed_std="0.1",
    # Trajectory comparison
    traj_segment="40.0",
    traj_warmup="5.0",
    # Neural manifold
    manifold_duration="22000.0",
    manifold_warmup="2000.0",
    manifold_speed="0.05",
    manifold_speed_std="0.02",
    manifold_downsample="10",
)
68
+
69
# Prefill text for the grid-cell parameter form. Values are kept as strings
# because they seed ``Input`` widgets; ``collect_params`` converts them.
GRID_DEFAULTS = dict(
    # Model construction
    seed="74",
    length="40",
    tau="0.01",
    alpha="0.1",
    W_l="3.0",
    lambda_net="17.0",
    dt="0.0005",
    # Shown in the firing-field group
    box_size="2.2",
    # Energy landscape
    energy_duration="10.0",
    energy_speed="0.2",
    energy_speed_std="0.05",
    energy_heal_steps="10000",
    # Firing field
    field_resolution="100",
    field_sigma="2.0",
    field_speed="0.3",
    field_batches="10",
    # Path integration
    path_duration="10.0",
    path_dt="0.01",
    path_speed="0.5",
    path_speed_std="0.05",
    path_heal_steps="5000",
)
92
+
93
+
94
class GalleryApp(App):
    """Main TUI application for the model gallery.

    The left panel selects a model/analysis and launches a run. The middle
    panel holds one parameter form per model; only the active model's form and
    the active analysis' group are shown. The right panel previews the output
    image and shows the run log.
    """

    CSS_PATH = "styles.tcss"
    TITLE = "CANNs Model Gallery"

    # Terminal size thresholds; check_terminal_size warns (once) when the
    # terminal is smaller than MIN_WIDTH x MIN_HEIGHT.
    MIN_WIDTH = 100
    RECOMMENDED_WIDTH = 120
    MIN_HEIGHT = 28
    RECOMMENDED_HEIGHT = 36

    BINDINGS = [
        Binding("ctrl+w", "change_workdir", "Workdir"),
        Binding("ctrl+r", "run", "Run"),
        Binding("f5", "refresh", "Refresh"),
        Binding("escape", "quit", "Quit"),
    ]

    # (group selector, owning model, analyses for which the group is shown).
    # Consumed by _update_param_visibility.
    _ANALYSIS_GROUPS = (
        ("#c1-analysis-connectivity", "cann1d", ("connectivity",)),
        ("#c1-analysis-energy", "cann1d", ("energy",)),
        ("#c1-analysis-tuning", "cann1d", ("tuning",)),
        ("#c1-analysis-template", "cann1d", ("template",)),
        ("#c1-analysis-manifold", "cann1d", ("manifold",)),
        ("#c2-analysis-connectivity", "cann2d", ("connectivity",)),
        ("#c2-analysis-energy", "cann2d", ("energy",)),
        ("#c2-analysis-firing", "cann2d", ("firing_field",)),
        ("#c2-analysis-trajectory", "cann2d", ("trajectory",)),
        ("#c2-analysis-manifold", "cann2d", ("manifold",)),
        ("#g-analysis-connectivity", "gridcell", ("connectivity",)),
        ("#g-analysis-energy", "gridcell", ("energy",)),
        # The grid-cell manifold analysis shows the firing-field inputs too.
        ("#g-analysis-firing", "gridcell", ("firing_field", "manifold")),
        ("#g-analysis-path", "gridcell", ("path_integration",)),
    )

    def __init__(self) -> None:
        super().__init__()
        self.state = GalleryState()
        self.runner = GalleryRunner()
        self.current_worker: Worker | None = None
        self._size_warning_shown = False

    # ------------------------------------------------------------------ layout

    @staticmethod
    def _param_fields(
        defaults: dict[str, str], *specs: tuple[str, str, str]
    ) -> ComposeResult:
        """Yield a ``Label``/``Input`` pair per ``(label, defaults key, input id)`` spec."""
        for text, key, input_id in specs:
            yield Label(text)
            yield Input(value=defaults[key], id=input_id)

    def compose(self) -> ComposeResult:
        """Build the three-column layout.

        Every model's parameter form is created up front;
        ``_update_param_visibility`` toggles the ``hidden`` class afterwards.
        """
        yield Header()

        with Horizontal(id="main-container"):
            with Vertical(id="left-panel"):
                yield Label(f"Workdir: {self.state.workdir}", id="workdir-label")
                yield Button("Change Workdir", id="change-workdir-btn")

                yield Label("Model")
                yield Select(
                    [
                        ("CANN 1D", "cann1d"),
                        ("CANN 2D", "cann2d"),
                        ("Grid Cell", "gridcell"),
                    ],
                    value=self.state.model,
                    id="model-select",
                )

                yield Label("Analysis")
                yield Select(
                    get_analysis_options(self.state.model),
                    value=self.state.analysis,
                    id="analysis-select",
                )

                yield Button("Run", variant="primary", id="run-btn")
                yield ProgressBar(id="progress-bar")
                yield Static("Status: Idle", id="run-status")

            with Vertical(id="middle-panel"):
                with Vertical(id="params-panel"):
                    yield Static("Parameters", id="params-header")
                    with VerticalScroll(id="params-scroll"):
                        yield from self._compose_cann1d_params()
                        yield from self._compose_cann2d_params()
                        yield from self._compose_gridcell_params()

            with Vertical(id="right-panel"):
                yield ImagePreview(id="result-preview")
                yield LogViewer(id="log-viewer")

        yield Footer()

    def _compose_cann1d_params(self) -> ComposeResult:
        """Yield the CANN1D parameter form (not hidden initially)."""
        fields = self._param_fields
        defaults = CANN1D_DEFAULTS
        with Vertical(id="params-cann1d"):
            with ParamGroup("CANN1D Model"):
                yield from fields(
                    defaults,
                    ("seed", "seed", "c1-seed"),
                    ("num", "num", "c1-num"),
                    ("tau", "tau", "c1-tau"),
                    ("k", "k", "c1-k"),
                    ("a", "a", "c1-a"),
                    ("A", "A", "c1-A"),
                    ("J0", "J0", "c1-J0"),
                    ("dt", "dt", "c1-dt"),
                )
            with ParamGroup("Energy Landscape", id="c1-analysis-energy"):
                yield from fields(
                    defaults,
                    ("stimulus_pos", "energy_pos", "c1-energy-pos"),
                    ("duration", "energy_duration", "c1-energy-duration"),
                )
            with ParamGroup("Tuning Curve", id="c1-analysis-tuning"):
                yield from fields(
                    defaults,
                    ("start_pos", "tuning_start", "c1-tuning-start"),
                    ("mid_pos", "tuning_mid", "c1-tuning-mid"),
                    ("end_pos", "tuning_end", "c1-tuning-end"),
                    ("segment_duration", "tuning_duration", "c1-tuning-duration"),
                    ("num_bins", "tuning_bins", "c1-tuning-bins"),
                    ("neuron_indices", "tuning_neurons", "c1-tuning-neurons"),
                )
            with ParamGroup("Template Matching", id="c1-analysis-template"):
                yield from fields(
                    defaults,
                    ("stimulus_pos", "template_pos", "c1-template-pos"),
                    ("duration", "template_duration", "c1-template-duration"),
                )
            with ParamGroup("Neural Manifold", id="c1-analysis-manifold"):
                yield from fields(
                    defaults,
                    ("segment_duration", "manifold_segment", "c1-manifold-segment"),
                    ("warmup", "manifold_warmup", "c1-manifold-warmup"),
                )
            with ParamGroup("Connectivity", id="c1-analysis-connectivity"):
                yield Static("No extra parameters")

    def _compose_cann2d_params(self) -> ComposeResult:
        """Yield the CANN2D parameter form (hidden initially)."""
        fields = self._param_fields
        defaults = CANN2D_DEFAULTS
        with Vertical(id="params-cann2d", classes="hidden"):
            with ParamGroup("CANN2D Model"):
                yield from fields(
                    defaults,
                    ("seed", "seed", "c2-seed"),
                    ("length", "length", "c2-length"),
                    ("tau", "tau", "c2-tau"),
                    ("k", "k", "c2-k"),
                    ("a", "a", "c2-a"),
                    ("A", "A", "c2-A"),
                    ("J0", "J0", "c2-J0"),
                    ("dt", "dt", "c2-dt"),
                )
            with ParamGroup("Energy Landscape", id="c2-analysis-energy"):
                yield from fields(
                    defaults,
                    ("stimulus_x", "energy_x", "c2-energy-x"),
                    ("stimulus_y", "energy_y", "c2-energy-y"),
                    ("duration", "energy_duration", "c2-energy-duration"),
                )
            with ParamGroup("Firing Field", id="c2-analysis-firing"):
                yield from fields(
                    defaults,
                    ("duration", "field_duration", "c2-field-duration"),
                    ("box_size", "field_box", "c2-field-box"),
                    ("resolution", "field_resolution", "c2-field-resolution"),
                    ("smooth_sigma", "field_sigma", "c2-field-sigma"),
                    ("speed_mean", "field_speed", "c2-field-speed"),
                    ("speed_std", "field_speed_std", "c2-field-speed-std"),
                )
            with ParamGroup("Trajectory Comparison", id="c2-analysis-trajectory"):
                yield from fields(
                    defaults,
                    ("segment_duration", "traj_segment", "c2-traj-segment"),
                    ("warmup", "traj_warmup", "c2-traj-warmup"),
                )
            with ParamGroup("Neural Manifold", id="c2-analysis-manifold"):
                yield from fields(
                    defaults,
                    ("duration", "manifold_duration", "c2-manifold-duration"),
                    ("warmup", "manifold_warmup", "c2-manifold-warmup"),
                    ("speed_mean", "manifold_speed", "c2-manifold-speed"),
                    ("speed_std", "manifold_speed_std", "c2-manifold-speed-std"),
                    ("downsample", "manifold_downsample", "c2-manifold-downsample"),
                )
            with ParamGroup("Connectivity", id="c2-analysis-connectivity"):
                yield Static("No extra parameters")

    def _compose_gridcell_params(self) -> ComposeResult:
        """Yield the grid-cell parameter form (hidden initially)."""
        fields = self._param_fields
        defaults = GRID_DEFAULTS
        with Vertical(id="params-gridcell", classes="hidden"):
            with ParamGroup("Grid Cell Model"):
                yield from fields(
                    defaults,
                    ("seed", "seed", "g-seed"),
                    ("length", "length", "g-length"),
                    ("tau", "tau", "g-tau"),
                    ("alpha", "alpha", "g-alpha"),
                    ("W_l", "W_l", "g-Wl"),
                    ("lambda_net", "lambda_net", "g-lambda"),
                    ("dt", "dt", "g-dt"),
                )
            with ParamGroup("Energy Landscape", id="g-analysis-energy"):
                yield from fields(
                    defaults,
                    ("duration", "energy_duration", "g-energy-duration"),
                    ("speed_mean", "energy_speed", "g-energy-speed"),
                    ("speed_std", "energy_speed_std", "g-energy-speed-std"),
                    ("heal_steps", "energy_heal_steps", "g-energy-heal"),
                )
            with ParamGroup("Firing Field", id="g-analysis-firing"):
                yield from fields(
                    defaults,
                    ("box_size", "box_size", "g-box-size"),
                    ("resolution", "field_resolution", "g-field-resolution"),
                    ("smooth_sigma", "field_sigma", "g-field-sigma"),
                    ("speed", "field_speed", "g-field-speed"),
                    ("num_batches", "field_batches", "g-field-batches"),
                )
            with ParamGroup("Path Integration", id="g-analysis-path"):
                yield from fields(
                    defaults,
                    ("duration", "path_duration", "g-path-duration"),
                    ("dt", "path_dt", "g-path-dt"),
                    ("speed_mean", "path_speed", "g-path-speed"),
                    ("speed_std", "path_speed_std", "g-path-speed-std"),
                    ("heal_steps", "path_heal_steps", "g-path-heal"),
                )
            with ParamGroup("Connectivity", id="g-analysis-connectivity"):
                yield Static("No extra parameters")

    # --------------------------------------------------------------- lifecycle

    def on_mount(self) -> None:
        """Check the terminal size, then sync the selects and visible forms with state."""
        self.check_terminal_size()
        self._update_analysis_options()
        self._update_param_visibility()

    def on_resize(self, event) -> None:
        """Re-check the terminal size whenever the terminal is resized."""
        self.check_terminal_size()

    def check_terminal_size(self) -> None:
        """Push a TerminalSizeWarning the first time the terminal is below the minimum."""
        width, height = self.size
        if self._size_warning_shown:
            return
        if width < self.MIN_WIDTH or height < self.MIN_HEIGHT:
            self._size_warning_shown = True
            self.push_screen(TerminalSizeWarning(width, height))

    # ----------------------------------------------------------------- workdir

    def action_change_workdir(self) -> None:
        """Open the workdir picker; the result is handled by on_workdir_selected."""
        self.push_screen(WorkdirScreen(), self.on_workdir_selected)

    def on_workdir_selected(self, path: Path | None) -> None:
        """Store the picked workdir (ignored when the picker returned nothing)."""
        if path:
            self.state.workdir = path
            self.update_workdir_label()

    def update_workdir_label(self) -> None:
        """Refresh the workdir label from state."""
        label = self.query_one("#workdir-label", Label)
        label.update(f"Workdir: {self.state.workdir}")

    # ------------------------------------------------------------------ events

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Route button presses to the matching action."""
        if event.button.id == "change-workdir-btn":
            self.action_change_workdir()
        elif event.button.id == "run-btn":
            self.action_run()

    def on_select_changed(self, event: Select.Changed) -> None:
        """Sync state with the model/analysis selects and refresh the visible forms.

        A ``Select`` can report a "no selection" sentinel rather than one of our
        option values (e.g. when the blank entry is chosen). The sentinel's name
        differs across Textual versions (``Select.BLANK`` / ``Select.NULL``), so
        values are validated against the real options instead. This avoids
        storing ``str(sentinel)`` as a model or analysis name.
        """
        if event.select.id == "model-select":
            # Model option values are the plain strings declared in compose().
            if not isinstance(event.value, str):
                return
            self.state.model = event.value
            self.state.analysis = get_default_analysis(self.state.model)
            self._update_analysis_options()
            self._update_param_visibility()
        elif event.select.id == "analysis-select":
            valid = [value for _label, value in get_analysis_options(self.state.model)]
            if event.value not in valid:
                return
            self.state.analysis = str(event.value)
            self._update_param_visibility()

    # ----------------------------------------------------------------- running

    def action_run(self) -> None:
        """Validate the form and start the selected analysis in a background worker.

        Output goes to ``<workdir>/Results/gallery/<model>``. Only one run may be
        active at a time.
        """
        if self.current_worker and not self.current_worker.is_finished:
            self.log_message("A task is already running.")
            return

        try:
            model_params, analysis_params = self.collect_params()
        except ValueError as exc:
            self.push_screen(ErrorScreen("Parameter Error", str(exc)))
            return

        output_dir = self.state.workdir / "Results" / "gallery" / self.state.model
        self.set_run_status("Status: Running...", "running")
        self.update_progress(0)

        # NOTE(review): runner.run(...) is evaluated here, so it presumably
        # returns a coroutine that the thread worker then runs. If the runner
        # invokes log_callback/progress_callback from the worker thread, those
        # widget updates should go through App.call_from_thread. Confirm
        # against GalleryRunner.
        self.current_worker = self.run_worker(
            self.runner.run(
                self.state.model,
                self.state.analysis,
                model_params,
                analysis_params,
                output_dir,
                log_callback=self.log_message,
                progress_callback=self.update_progress,
            ),
            name="gallery_worker",
            thread=True,
            # Report unexpected exceptions through on_worker_state_changed
            # instead of letting Textual exit the whole app (the default).
            exit_on_error=False,
        )

    def action_refresh(self) -> None:
        """Reload the preview image from the last successful run, if any."""
        preview_path = self.state.artifacts.get("output")
        if preview_path:
            preview = self.query_one("#result-preview", ImagePreview)
            preview.update_image(preview_path)
            self.log_message(f"Refreshed preview: {preview_path}")

    def on_worker_state_changed(self, event: Worker.StateChanged) -> None:
        """Publish the outcome of a finished gallery run to the UI.

        ``Worker.result`` is only set when the worker completes successfully.
        For ERROR and CANCELLED states it is ``None``, so those states are
        handled explicitly rather than dereferencing ``result.success``.
        """
        worker = event.worker
        if worker.name != "gallery_worker" or not worker.is_finished:
            return

        result = worker.result
        if result is None:
            error = worker.error
            if error is None:
                # Cancelled: there is nothing to report beyond the status line.
                self.set_run_status("Status: Cancelled", "error")
                return
            self.set_run_status("Status: Failed", "error")
            self.push_screen(
                ErrorScreen("Gallery Error", f"{type(error).__name__}: {error}")
            )
            return

        if result.success:
            self.state.artifacts = result.artifacts
            self.set_run_status("Status: Complete", "success")
            self.log_message(result.summary)
            output_path = result.artifacts.get("output")
            if output_path:
                preview = self.query_one("#result-preview", ImagePreview)
                preview.update_image(output_path)
        else:
            self.set_run_status("Status: Failed", "error")
            self.push_screen(ErrorScreen("Gallery Error", result.error or "Unknown error"))

    # ------------------------------------------------------------ UI feedback

    def set_run_status(self, message: str, status_class: str | None = None) -> None:
        """Show *message* in the status line and apply at most one status CSS class."""
        status = self.query_one("#run-status", Static)
        status.update(message)
        status.remove_class("running", "success", "error")
        if status_class:
            status.add_class(status_class)

    def update_progress(self, percent: int) -> None:
        """Set the progress bar to *percent* out of 100."""
        progress = self.query_one("#progress-bar", ProgressBar)
        progress.update(total=100, progress=percent)

    def log_message(self, message: str) -> None:
        """Append *message* to the log viewer."""
        log_viewer = self.query_one("#log-viewer", LogViewer)
        log_viewer.add_log(message)

    # ------------------------------------------------------------- parameters

    def collect_params(
        self,
    ) -> tuple[dict[str, float | int], dict[str, float | int | str]]:
        """Read and convert the active model's form into ``(model_params, analysis_params)``.

        Every analysis group of the model is read, not only the visible one.
        ``tuning_neurons`` (CANN1D) is passed through as the raw string.

        Raises:
            ValueError: if a field cannot be converted, or the model is unknown.
        """
        if self.state.model == "cann1d":
            model_params = {
                "seed": self._int("#c1-seed"),
                "num": self._int("#c1-num"),
                "tau": self._float("#c1-tau"),
                "k": self._float("#c1-k"),
                "a": self._float("#c1-a"),
                "A": self._float("#c1-A"),
                "J0": self._float("#c1-J0"),
                "dt": self._float("#c1-dt"),
            }
            analysis_params = {
                "energy_pos": self._float("#c1-energy-pos"),
                "energy_duration": self._float("#c1-energy-duration"),
                "tuning_start": self._float("#c1-tuning-start"),
                "tuning_mid": self._float("#c1-tuning-mid"),
                "tuning_end": self._float("#c1-tuning-end"),
                "tuning_duration": self._float("#c1-tuning-duration"),
                "tuning_bins": self._int("#c1-tuning-bins"),
                "tuning_neurons": self._str("#c1-tuning-neurons"),
                "template_pos": self._float("#c1-template-pos"),
                "template_duration": self._float("#c1-template-duration"),
                "manifold_segment": self._float("#c1-manifold-segment"),
                "manifold_warmup": self._float("#c1-manifold-warmup"),
            }
            return model_params, analysis_params

        if self.state.model == "cann2d":
            model_params = {
                "seed": self._int("#c2-seed"),
                "length": self._int("#c2-length"),
                "tau": self._float("#c2-tau"),
                "k": self._float("#c2-k"),
                "a": self._float("#c2-a"),
                "A": self._float("#c2-A"),
                "J0": self._float("#c2-J0"),
                "dt": self._float("#c2-dt"),
            }
            analysis_params = {
                "energy_x": self._float("#c2-energy-x"),
                "energy_y": self._float("#c2-energy-y"),
                "energy_duration": self._float("#c2-energy-duration"),
                "field_duration": self._float("#c2-field-duration"),
                "field_box": self._float("#c2-field-box"),
                "field_resolution": self._int("#c2-field-resolution"),
                "field_sigma": self._float("#c2-field-sigma"),
                "field_speed": self._float("#c2-field-speed"),
                "field_speed_std": self._float("#c2-field-speed-std"),
                "traj_segment": self._float("#c2-traj-segment"),
                "traj_warmup": self._float("#c2-traj-warmup"),
                "manifold_duration": self._float("#c2-manifold-duration"),
                "manifold_warmup": self._float("#c2-manifold-warmup"),
                "manifold_speed": self._float("#c2-manifold-speed"),
                "manifold_speed_std": self._float("#c2-manifold-speed-std"),
                "manifold_downsample": self._int("#c2-manifold-downsample"),
            }
            return model_params, analysis_params

        if self.state.model == "gridcell":
            model_params = {
                "seed": self._int("#g-seed"),
                "length": self._int("#g-length"),
                "tau": self._float("#g-tau"),
                "alpha": self._float("#g-alpha"),
                "W_l": self._float("#g-Wl"),
                "lambda_net": self._float("#g-lambda"),
                "dt": self._float("#g-dt"),
            }
            analysis_params = {
                "box_size": self._float("#g-box-size"),
                "energy_duration": self._float("#g-energy-duration"),
                "energy_speed": self._float("#g-energy-speed"),
                "energy_speed_std": self._float("#g-energy-speed-std"),
                "energy_heal_steps": self._int("#g-energy-heal"),
                "field_resolution": self._int("#g-field-resolution"),
                "field_sigma": self._float("#g-field-sigma"),
                "field_speed": self._float("#g-field-speed"),
                "field_batches": self._int("#g-field-batches"),
                "path_duration": self._float("#g-path-duration"),
                "path_dt": self._float("#g-path-dt"),
                "path_speed": self._float("#g-path-speed"),
                "path_speed_std": self._float("#g-path-speed-std"),
                "path_heal_steps": self._int("#g-path-heal"),
            }
            return model_params, analysis_params

        raise ValueError(f"Unknown model: {self.state.model}")

    def _input_value(self, selector: str) -> str:
        """Return the raw text of the ``Input`` matching *selector*."""
        return self.query_one(selector, Input).value

    def _int(self, selector: str) -> int:
        """Parse the ``Input`` at *selector* as an integer.

        Raises:
            ValueError: naming the field when its text is not an integer, so the
                Parameter Error dialog tells the user which input to fix.
        """
        raw = self._input_value(selector)
        try:
            return int(raw)
        except ValueError as exc:
            field = selector.lstrip("#")
            raise ValueError(f"{field}: expected an integer, got {raw!r}") from exc

    def _float(self, selector: str) -> float:
        """Parse the ``Input`` at *selector* as a finite float.

        ``float()`` accepts "nan"/"inf", which are never meaningful simulation
        parameters (an infinite duration would never finish), so they are
        rejected.

        Raises:
            ValueError: naming the field when its text is not a finite number.
        """
        import math

        raw = self._input_value(selector)
        field = selector.lstrip("#")
        try:
            value = float(raw)
        except ValueError as exc:
            raise ValueError(f"{field}: expected a number, got {raw!r}") from exc
        if not math.isfinite(value):
            raise ValueError(f"{field}: expected a finite number, got {raw!r}")
        return value

    def _str(self, selector: str) -> str:
        """Return the ``Input`` text at *selector* unchanged."""
        return self._input_value(selector)

    # ------------------------------------------------------------- visibility

    def _update_analysis_options(self) -> None:
        """Reload the analysis select with the current model's options and selection."""
        select = self.query_one("#analysis-select", Select)
        options = get_analysis_options(self.state.model)
        select.set_options(options)
        select.value = self.state.analysis

    def _update_param_visibility(self) -> None:
        """Show only the active model's form and the groups used by the active analysis."""
        model = self.state.model
        analysis = self.state.analysis

        for form_model in ("cann1d", "cann2d", "gridcell"):
            self._set_visible(f"#params-{form_model}", model == form_model)

        for selector, owner, analyses in self._ANALYSIS_GROUPS:
            self._set_visible(selector, model == owner and analysis in analyses)

    def _set_visible(self, selector: str, visible: bool) -> None:
        """Toggle the ``hidden`` CSS class on the widget matching *selector*."""
        widget = self.query_one(selector)
        if visible:
            widget.remove_class("hidden")
        else:
            widget.add_class("hidden")