canns 0.13.1__py3-none-any.whl → 0.13.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/data/__init__.py +6 -1
- canns/data/datasets.py +154 -1
- canns/data/loaders.py +37 -0
- canns/pipeline/__init__.py +5 -13
- canns/pipeline/__main__.py +6 -0
- canns/pipeline/gallery/__init__.py +15 -5
- canns/pipeline/gallery/__main__.py +11 -0
- canns/pipeline/gallery/app.py +707 -0
- canns/pipeline/gallery/runner.py +783 -0
- canns/pipeline/gallery/state.py +52 -0
- canns/pipeline/gallery/styles.tcss +123 -0
- canns/pipeline/launcher.py +81 -0
- {canns-0.13.1.dist-info → canns-0.13.2.dist-info}/METADATA +1 -1
- {canns-0.13.1.dist-info → canns-0.13.2.dist-info}/RECORD +17 -11
- canns-0.13.2.dist-info/entry_points.txt +4 -0
- canns/pipeline/_base.py +0 -50
- canns-0.13.1.dist-info/entry_points.txt +0 -3
- {canns-0.13.1.dist-info → canns-0.13.2.dist-info}/WHEEL +0 -0
- {canns-0.13.1.dist-info → canns-0.13.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,707 @@
|
|
|
1
|
+
"""Model gallery TUI for quick CANN visualizations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations

import threading
from collections.abc import Callable
from pathlib import Path

from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Horizontal, Vertical, VerticalScroll
from textual.widgets import Button, Footer, Header, Input, Label, ProgressBar, Select, Static
from textual.worker import Worker

from canns.pipeline.asa.screens import ErrorScreen, TerminalSizeWarning, WorkdirScreen
from canns.pipeline.asa.widgets import ImagePreview, LogViewer, ParamGroup

from .runner import GalleryRunner
from .state import GalleryState, get_analysis_options, get_default_analysis
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Initial text for every CANN1D form field. Values stay as strings because they
# prefill ``Input`` widgets; they are converted to numbers when a run starts.
CANN1D_DEFAULTS = dict(
    # Model construction
    seed="42",
    num="256",
    tau="1.0",
    k="8.1",
    a="0.5",
    A="10.0",
    J0="4.0",
    dt="0.1",
    # Energy landscape analysis
    energy_pos="1.0",
    energy_duration="10.0",
    # Tuning curve analysis
    tuning_start="0.0",
    tuning_mid="3.1416",
    tuning_end="6.2832",
    tuning_duration="40.0",
    tuning_bins="50",
    tuning_neurons="64,128,192",
    # Template matching analysis
    template_pos="1.0",
    template_duration="10.0",
    # Neural manifold analysis
    manifold_segment="40.0",
    manifold_warmup="5.0",
)
|
|
42
|
+
|
|
43
|
+
# Initial text for every CANN2D form field. Values stay as strings because they
# prefill ``Input`` widgets; they are converted to numbers when a run starts.
CANN2D_DEFAULTS = dict(
    # Model construction
    seed="42",
    length="64",
    tau="1.0",
    k="8.1",
    a="0.5",
    A="10.0",
    J0="4.0",
    dt="0.1",
    # Energy landscape analysis
    energy_x="1.0",
    energy_y="1.0",
    energy_duration="10.0",
    # Firing field analysis
    field_duration="8000.0",
    field_box="6.2832",
    field_resolution="80",
    field_sigma="2.0",
    field_speed="0.3",
    field_speed_std="0.1",
    # Trajectory comparison analysis
    traj_segment="40.0",
    traj_warmup="5.0",
    # Neural manifold analysis
    manifold_duration="22000.0",
    manifold_warmup="2000.0",
    manifold_speed="0.05",
    manifold_speed_std="0.02",
    manifold_downsample="10",
)
|
|
69
|
+
|
|
70
|
+
# Initial text for every grid-cell form field. Values stay as strings because
# they prefill ``Input`` widgets; they are converted to numbers when a run starts.
GRID_DEFAULTS = dict(
    # Model construction
    seed="74",
    length="40",
    tau="0.01",
    alpha="0.1",
    W_l="3.0",
    lambda_net="17.0",
    dt="0.0005",
    box_size="2.2",
    # Energy landscape analysis
    energy_duration="10.0",
    energy_speed="0.2",
    energy_speed_std="0.05",
    energy_heal_steps="10000",
    # Firing field analysis
    field_resolution="100",
    field_sigma="2.0",
    field_speed="0.3",
    field_batches="10",
    # Path integration analysis
    path_duration="10.0",
    path_dt="0.01",
    path_speed="0.5",
    path_speed_std="0.05",
    path_heal_steps="5000",
)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class GalleryApp(App):
    """Main TUI application for the model gallery.

    The layout has three columns: run controls on the left, the parameter form
    for the selected model/analysis in the middle, and the result preview plus
    a log viewer on the right. Runs are executed by ``GalleryRunner.run`` inside
    a threaded Textual worker.
    """

    CSS_PATH = "styles.tcss"
    TITLE = "CANNs Model Gallery"

    MIN_WIDTH = 100
    RECOMMENDED_WIDTH = 120
    MIN_HEIGHT = 28
    RECOMMENDED_HEIGHT = 36

    BINDINGS = [
        Binding("ctrl+w", "change_workdir", "Workdir"),
        Binding("ctrl+r", "run", "Run"),
        Binding("f5", "refresh", "Refresh"),
        Binding("escape", "quit", "Quit"),
    ]

    # Name given to the run worker so state-change events can be matched to it.
    _WORKER_NAME = "gallery_worker"

    # For each model: {analysis param-group selector: analyses that show it}.
    # The grid-cell firing-field group is shared by the "manifold" analysis.
    _ANALYSIS_GROUPS: dict[str, dict[str, frozenset[str]]] = {
        "cann1d": {
            "#c1-analysis-connectivity": frozenset({"connectivity"}),
            "#c1-analysis-energy": frozenset({"energy"}),
            "#c1-analysis-tuning": frozenset({"tuning"}),
            "#c1-analysis-template": frozenset({"template"}),
            "#c1-analysis-manifold": frozenset({"manifold"}),
        },
        "cann2d": {
            "#c2-analysis-connectivity": frozenset({"connectivity"}),
            "#c2-analysis-energy": frozenset({"energy"}),
            "#c2-analysis-firing": frozenset({"firing_field"}),
            "#c2-analysis-trajectory": frozenset({"trajectory"}),
            "#c2-analysis-manifold": frozenset({"manifold"}),
        },
        "gridcell": {
            "#g-analysis-connectivity": frozenset({"connectivity"}),
            "#g-analysis-energy": frozenset({"energy"}),
            "#g-analysis-firing": frozenset({"firing_field", "manifold"}),
            "#g-analysis-path": frozenset({"path_integration"}),
        },
    }

    def __init__(self) -> None:
        super().__init__()
        self.state = GalleryState()
        self.runner = GalleryRunner()
        self.current_worker: Worker | None = None
        self._size_warning_shown = False
        # Thread that owns the widgets; refreshed in on_mount once the app runs.
        self._ui_thread_id = threading.get_ident()

    # ------------------------------------------------------------------ layout

    def compose(self) -> ComposeResult:
        """Build the layout: controls | parameter forms | preview and log."""
        yield Header()

        with Horizontal(id="main-container"):
            with Vertical(id="left-panel"):
                yield from self._compose_controls()

            with Vertical(id="middle-panel"):
                with Vertical(id="params-panel"):
                    yield Static("Parameters", id="params-header")
                    with VerticalScroll(id="params-scroll"):
                        yield from self._compose_cann1d_params()
                        yield from self._compose_cann2d_params()
                        yield from self._compose_gridcell_params()

            with Vertical(id="right-panel"):
                yield ImagePreview(id="result-preview")
                yield LogViewer(id="log-viewer")

        yield Footer()

    @staticmethod
    def _labeled_input(label: str, value: str, input_id: str) -> ComposeResult:
        """Yield a caption label followed by a pre-filled text input."""
        yield Label(label)
        yield Input(value=value, id=input_id)

    def _compose_controls(self) -> ComposeResult:
        """Yield the workdir, model/analysis selectors, and run widgets."""
        yield Label(f"Workdir: {self.state.workdir}", id="workdir-label")
        yield Button("Change Workdir", id="change-workdir-btn")

        yield Label("Model")
        yield Select(
            [
                ("CANN 1D", "cann1d"),
                ("CANN 2D", "cann2d"),
                ("Grid Cell", "gridcell"),
            ],
            value=self.state.model,
            id="model-select",
        )

        yield Label("Analysis")
        yield Select(
            get_analysis_options(self.state.model),
            value=self.state.analysis,
            id="analysis-select",
        )

        yield Button("Run", variant="primary", id="run-btn")
        yield ProgressBar(id="progress-bar")
        yield Static("Status: Idle", id="run-status")

    def _compose_cann1d_params(self) -> ComposeResult:
        """Yield the CANN1D model form and its per-analysis parameter groups."""
        d = CANN1D_DEFAULTS
        row = self._labeled_input
        with Vertical(id="params-cann1d"):
            with ParamGroup("CANN1D Model"):
                # Model fields use the key as label and "c1-<key>" as input id.
                for name in ("seed", "num", "tau", "k", "a", "A", "J0", "dt"):
                    yield from row(name, d[name], f"c1-{name}")

            with ParamGroup("Energy Landscape", id="c1-analysis-energy"):
                yield from row("stimulus_pos", d["energy_pos"], "c1-energy-pos")
                yield from row("duration", d["energy_duration"], "c1-energy-duration")

            with ParamGroup("Tuning Curve", id="c1-analysis-tuning"):
                yield from row("start_pos", d["tuning_start"], "c1-tuning-start")
                yield from row("mid_pos", d["tuning_mid"], "c1-tuning-mid")
                yield from row("end_pos", d["tuning_end"], "c1-tuning-end")
                yield from row("segment_duration", d["tuning_duration"], "c1-tuning-duration")
                yield from row("num_bins", d["tuning_bins"], "c1-tuning-bins")
                yield from row("neuron_indices", d["tuning_neurons"], "c1-tuning-neurons")

            with ParamGroup("Template Matching", id="c1-analysis-template"):
                yield from row("stimulus_pos", d["template_pos"], "c1-template-pos")
                yield from row("duration", d["template_duration"], "c1-template-duration")

            with ParamGroup("Neural Manifold", id="c1-analysis-manifold"):
                yield from row("segment_duration", d["manifold_segment"], "c1-manifold-segment")
                yield from row("warmup", d["manifold_warmup"], "c1-manifold-warmup")

            with ParamGroup("Connectivity", id="c1-analysis-connectivity"):
                yield Static("No extra parameters")

    def _compose_cann2d_params(self) -> ComposeResult:
        """Yield the CANN2D model form and its per-analysis parameter groups."""
        d = CANN2D_DEFAULTS
        row = self._labeled_input
        with Vertical(id="params-cann2d", classes="hidden"):
            with ParamGroup("CANN2D Model"):
                # Model fields use the key as label and "c2-<key>" as input id.
                for name in ("seed", "length", "tau", "k", "a", "A", "J0", "dt"):
                    yield from row(name, d[name], f"c2-{name}")

            with ParamGroup("Energy Landscape", id="c2-analysis-energy"):
                yield from row("stimulus_x", d["energy_x"], "c2-energy-x")
                yield from row("stimulus_y", d["energy_y"], "c2-energy-y")
                yield from row("duration", d["energy_duration"], "c2-energy-duration")

            with ParamGroup("Firing Field", id="c2-analysis-firing"):
                yield from row("duration", d["field_duration"], "c2-field-duration")
                yield from row("box_size", d["field_box"], "c2-field-box")
                yield from row("resolution", d["field_resolution"], "c2-field-resolution")
                yield from row("smooth_sigma", d["field_sigma"], "c2-field-sigma")
                yield from row("speed_mean", d["field_speed"], "c2-field-speed")
                yield from row("speed_std", d["field_speed_std"], "c2-field-speed-std")

            with ParamGroup("Trajectory Comparison", id="c2-analysis-trajectory"):
                yield from row("segment_duration", d["traj_segment"], "c2-traj-segment")
                yield from row("warmup", d["traj_warmup"], "c2-traj-warmup")

            with ParamGroup("Neural Manifold", id="c2-analysis-manifold"):
                yield from row("duration", d["manifold_duration"], "c2-manifold-duration")
                yield from row("warmup", d["manifold_warmup"], "c2-manifold-warmup")
                yield from row("speed_mean", d["manifold_speed"], "c2-manifold-speed")
                yield from row("speed_std", d["manifold_speed_std"], "c2-manifold-speed-std")
                yield from row("downsample", d["manifold_downsample"], "c2-manifold-downsample")

            with ParamGroup("Connectivity", id="c2-analysis-connectivity"):
                yield Static("No extra parameters")

    def _compose_gridcell_params(self) -> ComposeResult:
        """Yield the grid-cell model form and its per-analysis parameter groups."""
        d = GRID_DEFAULTS
        row = self._labeled_input
        with Vertical(id="params-gridcell", classes="hidden"):
            with ParamGroup("Grid Cell Model"):
                # Input ids here do not all follow the key name (e.g. W_l -> g-Wl).
                yield from row("seed", d["seed"], "g-seed")
                yield from row("length", d["length"], "g-length")
                yield from row("tau", d["tau"], "g-tau")
                yield from row("alpha", d["alpha"], "g-alpha")
                yield from row("W_l", d["W_l"], "g-Wl")
                yield from row("lambda_net", d["lambda_net"], "g-lambda")
                yield from row("dt", d["dt"], "g-dt")

            with ParamGroup("Energy Landscape", id="g-analysis-energy"):
                yield from row("duration", d["energy_duration"], "g-energy-duration")
                yield from row("speed_mean", d["energy_speed"], "g-energy-speed")
                yield from row("speed_std", d["energy_speed_std"], "g-energy-speed-std")
                yield from row("heal_steps", d["energy_heal_steps"], "g-energy-heal")

            with ParamGroup("Firing Field", id="g-analysis-firing"):
                yield from row("box_size", d["box_size"], "g-box-size")
                yield from row("resolution", d["field_resolution"], "g-field-resolution")
                yield from row("smooth_sigma", d["field_sigma"], "g-field-sigma")
                yield from row("speed", d["field_speed"], "g-field-speed")
                yield from row("num_batches", d["field_batches"], "g-field-batches")

            with ParamGroup("Path Integration", id="g-analysis-path"):
                yield from row("duration", d["path_duration"], "g-path-duration")
                yield from row("dt", d["path_dt"], "g-path-dt")
                yield from row("speed_mean", d["path_speed"], "g-path-speed")
                yield from row("speed_std", d["path_speed_std"], "g-path-speed-std")
                yield from row("heal_steps", d["path_heal_steps"], "g-path-heal")

            with ParamGroup("Connectivity", id="g-analysis-connectivity"):
                yield Static("No extra parameters")

    # ---------------------------------------------------------- lifecycle / UI

    def on_mount(self) -> None:
        """Record the UI thread, check terminal size, and sync the form."""
        self._ui_thread_id = threading.get_ident()
        self.check_terminal_size()
        self._update_analysis_options()
        self._update_param_visibility()

    def on_resize(self, event) -> None:
        """Re-check the terminal size whenever the app is resized."""
        self.check_terminal_size()

    def check_terminal_size(self) -> None:
        """Show a size warning once if the terminal is below the minimum."""
        width, height = self.size
        if not self._size_warning_shown and (
            width < self.MIN_WIDTH or height < self.MIN_HEIGHT
        ):
            self._size_warning_shown = True
            self.push_screen(TerminalSizeWarning(width, height))

    def action_change_workdir(self) -> None:
        """Open the workdir picker; the choice arrives in on_workdir_selected."""
        self.push_screen(WorkdirScreen(), self.on_workdir_selected)

    def on_workdir_selected(self, path: Path | None) -> None:
        """Store a newly selected workdir (ignored when the picker returns None)."""
        if path:
            self.state.workdir = path
            self.update_workdir_label()

    def update_workdir_label(self) -> None:
        """Refresh the workdir label from the current state."""
        label = self.query_one("#workdir-label", Label)
        label.update(f"Workdir: {self.state.workdir}")

    def on_button_pressed(self, event: Button.Pressed) -> None:
        """Dispatch the left-panel buttons to their actions."""
        if event.button.id == "change-workdir-btn":
            self.action_change_workdir()
        elif event.button.id == "run-btn":
            self.action_run()

    @staticmethod
    def _is_blank(value: object) -> bool:
        """Return True for a Select "no selection" value."""
        # Newer Textual exposes a Select.BLANK sentinel; older releases used None.
        return value is None or value is getattr(Select, "BLANK", None)

    def on_select_changed(self, event: Select.Changed) -> None:
        """Sync state from the model/analysis selectors."""
        # A Select can report a blank value (e.g. while its options are being
        # replaced, or if the user picks the empty entry). str() of that
        # sentinel is not a valid model/analysis name, so ignore it.
        if self._is_blank(event.value):
            return
        if event.select.id == "model-select":
            self.state.model = str(event.value)
            self.state.analysis = get_default_analysis(self.state.model)
            self._update_analysis_options()
            self._update_param_visibility()
        elif event.select.id == "analysis-select":
            self.state.analysis = str(event.value)
            self._update_param_visibility()

    # ---------------------------------------------------------------- running

    def action_run(self) -> None:
        """Validate the form and start the selected analysis in a worker thread."""
        if self.current_worker and not self.current_worker.is_finished:
            self.log_message("A task is already running.")
            return

        try:
            model_params, analysis_params = self.collect_params()
        except ValueError as exc:
            self.push_screen(ErrorScreen("Parameter Error", str(exc)))
            return

        output_dir = self.state.workdir / "Results" / "gallery" / self.state.model
        self.set_run_status("Status: Running...", "running")
        self.update_progress(0)

        self.current_worker = self.run_worker(
            self.runner.run(
                self.state.model,
                self.state.analysis,
                model_params,
                analysis_params,
                output_dir,
                # The runner may invoke these from the worker thread; these
                # wrappers hop back to the UI thread before touching widgets.
                log_callback=self._log_from_worker,
                progress_callback=self._progress_from_worker,
            ),
            name=self._WORKER_NAME,
            thread=True,
            # Surface failures in on_worker_state_changed instead of exiting.
            exit_on_error=False,
        )

    def action_refresh(self) -> None:
        """Reload the last output image into the preview, if there is one."""
        preview_path = self.state.artifacts.get("output")
        if preview_path:
            preview = self.query_one("#result-preview", ImagePreview)
            preview.update_image(preview_path)
            self.log_message(f"Refreshed preview: {preview_path}")

    def on_worker_state_changed(self, event: Worker.StateChanged) -> None:
        """Report the outcome of a finished gallery run."""
        worker = event.worker
        if worker.name != self._WORKER_NAME or not worker.is_finished:
            return

        result = worker.result
        if result is None:
            # The worker raised, was cancelled, or the runner returned nothing.
            self.set_run_status("Status: Failed", "error")
            if worker.error is not None:
                message = f"{type(worker.error).__name__}: {worker.error}"
            elif worker.is_cancelled:
                message = "Run was cancelled."
            else:
                message = "Runner returned no result."
            self.push_screen(ErrorScreen("Gallery Error", message))
            return

        if result.success:
            self.state.artifacts = result.artifacts
            self.set_run_status("Status: Complete", "success")
            self.log_message(result.summary)
            output_path = result.artifacts.get("output")
            if output_path:
                preview = self.query_one("#result-preview", ImagePreview)
                preview.update_image(output_path)
        else:
            self.set_run_status("Status: Failed", "error")
            self.push_screen(ErrorScreen("Gallery Error", result.error or "Unknown error"))

    def set_run_status(self, message: str, status_class: str | None = None) -> None:
        """Show *message* in the status line, styled by *status_class* if given."""
        status = self.query_one("#run-status", Static)
        status.update(message)
        status.remove_class("running", "success", "error")
        if status_class:
            status.add_class(status_class)

    def update_progress(self, percent: int) -> None:
        """Set the progress bar to *percent* out of 100."""
        progress = self.query_one("#progress-bar", ProgressBar)
        progress.update(total=100, progress=percent)

    def log_message(self, message: str) -> None:
        """Append *message* to the log viewer (UI thread only)."""
        log_viewer = self.query_one("#log-viewer", LogViewer)
        log_viewer.add_log(message)

    def _call_on_ui_thread(self, func: Callable[..., object], *args: object) -> None:
        """Run *func* on the UI thread, hopping threads only when necessary."""
        # call_from_thread refuses to run on the app's own thread, so call
        # directly when we are already there.
        if threading.get_ident() == self._ui_thread_id:
            func(*args)
        else:
            self.call_from_thread(func, *args)

    def _log_from_worker(self, message: str) -> None:
        """Thread-safe ``log_callback`` handed to the runner."""
        self._call_on_ui_thread(self.log_message, message)

    def _progress_from_worker(self, percent: int) -> None:
        """Thread-safe ``progress_callback`` handed to the runner."""
        self._call_on_ui_thread(self.update_progress, percent)

    # ------------------------------------------------------------- parameters

    def collect_params(
        self,
    ) -> tuple[dict[str, float | int | str], dict[str, float | int | str]]:
        """Parse the visible model's form into (model_params, analysis_params).

        Every analysis field of the current model is collected, whichever
        analysis is selected.

        Raises:
            ValueError: if a field cannot be parsed, or the model is unknown.
        """
        if self.state.model == "cann1d":
            model_params = {
                "seed": self._int("#c1-seed"),
                "num": self._int("#c1-num"),
                "tau": self._float("#c1-tau"),
                "k": self._float("#c1-k"),
                "a": self._float("#c1-a"),
                "A": self._float("#c1-A"),
                "J0": self._float("#c1-J0"),
                "dt": self._float("#c1-dt"),
            }
            analysis_params = {
                "energy_pos": self._float("#c1-energy-pos"),
                "energy_duration": self._float("#c1-energy-duration"),
                "tuning_start": self._float("#c1-tuning-start"),
                "tuning_mid": self._float("#c1-tuning-mid"),
                "tuning_end": self._float("#c1-tuning-end"),
                "tuning_duration": self._float("#c1-tuning-duration"),
                "tuning_bins": self._int("#c1-tuning-bins"),
                # Passed through as raw text (e.g. "64,128,192").
                "tuning_neurons": self._str("#c1-tuning-neurons"),
                "template_pos": self._float("#c1-template-pos"),
                "template_duration": self._float("#c1-template-duration"),
                "manifold_segment": self._float("#c1-manifold-segment"),
                "manifold_warmup": self._float("#c1-manifold-warmup"),
            }
            return model_params, analysis_params

        if self.state.model == "cann2d":
            model_params = {
                "seed": self._int("#c2-seed"),
                "length": self._int("#c2-length"),
                "tau": self._float("#c2-tau"),
                "k": self._float("#c2-k"),
                "a": self._float("#c2-a"),
                "A": self._float("#c2-A"),
                "J0": self._float("#c2-J0"),
                "dt": self._float("#c2-dt"),
            }
            analysis_params = {
                "energy_x": self._float("#c2-energy-x"),
                "energy_y": self._float("#c2-energy-y"),
                "energy_duration": self._float("#c2-energy-duration"),
                "field_duration": self._float("#c2-field-duration"),
                "field_box": self._float("#c2-field-box"),
                "field_resolution": self._int("#c2-field-resolution"),
                "field_sigma": self._float("#c2-field-sigma"),
                "field_speed": self._float("#c2-field-speed"),
                "field_speed_std": self._float("#c2-field-speed-std"),
                "traj_segment": self._float("#c2-traj-segment"),
                "traj_warmup": self._float("#c2-traj-warmup"),
                "manifold_duration": self._float("#c2-manifold-duration"),
                "manifold_warmup": self._float("#c2-manifold-warmup"),
                "manifold_speed": self._float("#c2-manifold-speed"),
                "manifold_speed_std": self._float("#c2-manifold-speed-std"),
                "manifold_downsample": self._int("#c2-manifold-downsample"),
            }
            return model_params, analysis_params

        if self.state.model == "gridcell":
            model_params = {
                "seed": self._int("#g-seed"),
                "length": self._int("#g-length"),
                "tau": self._float("#g-tau"),
                "alpha": self._float("#g-alpha"),
                "W_l": self._float("#g-Wl"),
                "lambda_net": self._float("#g-lambda"),
                "dt": self._float("#g-dt"),
            }
            analysis_params = {
                "box_size": self._float("#g-box-size"),
                "energy_duration": self._float("#g-energy-duration"),
                "energy_speed": self._float("#g-energy-speed"),
                "energy_speed_std": self._float("#g-energy-speed-std"),
                "energy_heal_steps": self._int("#g-energy-heal"),
                "field_resolution": self._int("#g-field-resolution"),
                "field_sigma": self._float("#g-field-sigma"),
                "field_speed": self._float("#g-field-speed"),
                "field_batches": self._int("#g-field-batches"),
                "path_duration": self._float("#g-path-duration"),
                "path_dt": self._float("#g-path-dt"),
                "path_speed": self._float("#g-path-speed"),
                "path_speed_std": self._float("#g-path-speed-std"),
                "path_heal_steps": self._int("#g-path-heal"),
            }
            return model_params, analysis_params

        raise ValueError(f"Unknown model: {self.state.model}")

    def _int(self, selector: str) -> int:
        """Parse the input at *selector* as an int, naming the field on failure."""
        raw = self._str(selector)
        try:
            return int(raw)
        except ValueError:
            raise ValueError(
                f"Parameter {selector.lstrip('#')!r} must be an integer, got {raw!r}"
            ) from None

    def _float(self, selector: str) -> float:
        """Parse the input at *selector* as a float, naming the field on failure."""
        raw = self._str(selector)
        try:
            return float(raw)
        except ValueError:
            raise ValueError(
                f"Parameter {selector.lstrip('#')!r} must be a number, got {raw!r}"
            ) from None

    def _str(self, selector: str) -> str:
        """Return the raw text of the input at *selector*."""
        return self.query_one(selector, Input).value

    def _update_analysis_options(self) -> None:
        """Load the current model's analyses into the selector and select one."""
        select = self.query_one("#analysis-select", Select)
        options = get_analysis_options(self.state.model)
        select.set_options(options)
        select.value = self.state.analysis

    def _update_param_visibility(self) -> None:
        """Show only the current model's panel and its selected analysis groups."""
        model, analysis = self.state.model, self.state.analysis
        for panel_model, groups in self._ANALYSIS_GROUPS.items():
            is_current = model == panel_model
            self._set_visible(f"#params-{panel_model}", is_current)
            for selector, analyses in groups.items():
                self._set_visible(selector, is_current and analysis in analyses)

    def _set_visible(self, selector: str, visible: bool) -> None:
        """Toggle the ``hidden`` CSS class on the widget at *selector*."""
        widget = self.query_one(selector)
        if visible:
            widget.remove_class("hidden")
        else:
            widget.add_class("hidden")