canns 0.13.0__py3-none-any.whl → 0.13.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/asa/__init__.py +10 -0
- canns/analyzer/data/asa/decode.py +18 -21
- canns/analyzer/data/{legacy/cann1d.py → asa/fly_roi.py} +96 -43
- canns/analyzer/data/asa/fr.py +4 -12
- canns/analyzer/data/asa/plotting.py +12 -1
- canns/data/__init__.py +6 -1
- canns/data/datasets.py +154 -1
- canns/data/loaders.py +37 -0
- canns/pipeline/__init__.py +5 -13
- canns/pipeline/__main__.py +6 -0
- canns/pipeline/asa/widgets.py +1 -1
- canns/pipeline/gallery/__init__.py +15 -5
- canns/pipeline/gallery/__main__.py +11 -0
- canns/pipeline/gallery/app.py +707 -0
- canns/pipeline/gallery/runner.py +783 -0
- canns/pipeline/gallery/state.py +52 -0
- canns/pipeline/gallery/styles.tcss +123 -0
- canns/pipeline/launcher.py +81 -0
- {canns-0.13.0.dist-info → canns-0.13.2.dist-info}/METADATA +1 -1
- {canns-0.13.0.dist-info → canns-0.13.2.dist-info}/RECORD +23 -19
- canns-0.13.2.dist-info/entry_points.txt +4 -0
- canns/analyzer/data/legacy/__init__.py +0 -6
- canns/analyzer/data/legacy/cann2d.py +0 -2565
- canns/pipeline/_base.py +0 -50
- canns-0.13.0.dist-info/entry_points.txt +0 -3
- {canns-0.13.0.dist-info → canns-0.13.2.dist-info}/WHEEL +0 -0
- {canns-0.13.0.dist-info → canns-0.13.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -18,6 +18,12 @@ from .config import (
|
|
|
18
18
|
)
|
|
19
19
|
from .decode import decode_circular_coordinates, decode_circular_coordinates_multi
|
|
20
20
|
from .embedding import embed_spike_trains
|
|
21
|
+
from .fly_roi import (
|
|
22
|
+
BumpFitsConfig,
|
|
23
|
+
CANN1DPlotConfig,
|
|
24
|
+
create_1d_bump_animation,
|
|
25
|
+
roi_bump_fits,
|
|
26
|
+
)
|
|
21
27
|
from .fr import (
|
|
22
28
|
FRMResult,
|
|
23
29
|
compute_fr_heatmap_matrix,
|
|
@@ -60,6 +66,10 @@ __all__ = [
|
|
|
60
66
|
"plot_cohomap_multi",
|
|
61
67
|
"plot_3d_bump_on_torus",
|
|
62
68
|
"plot_2d_bump_on_manifold",
|
|
69
|
+
"BumpFitsConfig",
|
|
70
|
+
"CANN1DPlotConfig",
|
|
71
|
+
"create_1d_bump_animation",
|
|
72
|
+
"roi_bump_fits",
|
|
63
73
|
"compute_fr_heatmap_matrix",
|
|
64
74
|
"save_fr_heatmap_png",
|
|
65
75
|
"FRMResult",
|
|
@@ -33,7 +33,7 @@ def decode_circular_coordinates(
|
|
|
33
33
|
real_of : bool
|
|
34
34
|
Whether the experiment is open-field (controls box coordinate handling).
|
|
35
35
|
save_path : str, optional
|
|
36
|
-
Path to save decoding results.
|
|
36
|
+
Path to save decoding results. If ``None``, results are not saved.
|
|
37
37
|
|
|
38
38
|
Returns
|
|
39
39
|
-------
|
|
@@ -174,13 +174,12 @@ def decode_circular_coordinates(
|
|
|
174
174
|
"centsinall": centsinall,
|
|
175
175
|
}
|
|
176
176
|
|
|
177
|
-
# Save results
|
|
178
|
-
if save_path is None:
|
|
179
|
-
os.
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
np.savez_compressed(save_path, **results)
|
|
177
|
+
# Save results (only when requested)
|
|
178
|
+
if save_path is not None:
|
|
179
|
+
save_dir = os.path.dirname(save_path)
|
|
180
|
+
if save_dir:
|
|
181
|
+
os.makedirs(save_dir, exist_ok=True)
|
|
182
|
+
np.savez_compressed(save_path, **results)
|
|
184
183
|
|
|
185
184
|
return results
|
|
186
185
|
|
|
@@ -264,13 +263,12 @@ def decode_circular_coordinates1(
|
|
|
264
263
|
"centsinall": centsinall,
|
|
265
264
|
}
|
|
266
265
|
|
|
267
|
-
# Save results
|
|
268
|
-
if save_path is None:
|
|
269
|
-
os.
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
np.savez_compressed(save_path, **results)
|
|
266
|
+
# Save results (only when requested)
|
|
267
|
+
if save_path is not None:
|
|
268
|
+
save_dir = os.path.dirname(save_path)
|
|
269
|
+
if save_dir:
|
|
270
|
+
os.makedirs(save_dir, exist_ok=True)
|
|
271
|
+
np.savez_compressed(save_path, **results)
|
|
274
272
|
|
|
275
273
|
return results
|
|
276
274
|
|
|
@@ -291,7 +289,7 @@ def decode_circular_coordinates_multi(
|
|
|
291
289
|
spike_data : dict
|
|
292
290
|
Spike data dictionary containing ``'spike'``, ``'t'`` and optionally ``'x'``/``'y'``.
|
|
293
291
|
save_path : str, optional
|
|
294
|
-
Path to save decoding results.
|
|
292
|
+
Path to save decoding results. If ``None``, results are not saved.
|
|
295
293
|
num_circ : int
|
|
296
294
|
Number of H1 cocycles/circular coordinates to decode.
|
|
297
295
|
|
|
@@ -383,11 +381,10 @@ def decode_circular_coordinates_multi(
|
|
|
383
381
|
"centsinall": centsinall,
|
|
384
382
|
}
|
|
385
383
|
|
|
386
|
-
if save_path is None:
|
|
387
|
-
os.
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
|
384
|
+
if save_path is not None:
|
|
385
|
+
save_dir = os.path.dirname(save_path)
|
|
386
|
+
if save_dir:
|
|
387
|
+
os.makedirs(save_dir, exist_ok=True)
|
|
391
388
|
np.savez_compressed(save_path, **results)
|
|
392
389
|
return results
|
|
393
390
|
|
|
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from matplotlib import pyplot as plt
|
|
7
|
-
from matplotlib.animation import FuncAnimation
|
|
7
|
+
from matplotlib.animation import FuncAnimation
|
|
8
8
|
from scipy.optimize import linear_sum_assignment
|
|
9
9
|
from scipy.special import i0
|
|
10
10
|
from tqdm import tqdm
|
|
@@ -66,22 +66,6 @@ class BumpFitsConfig:
|
|
|
66
66
|
random_seed: int | None = None
|
|
67
67
|
|
|
68
68
|
|
|
69
|
-
@dataclass
|
|
70
|
-
class AnimationConfig:
|
|
71
|
-
"""Configuration for 1D CANN bump animation."""
|
|
72
|
-
|
|
73
|
-
show: bool = False
|
|
74
|
-
max_height_value: float = 0.5
|
|
75
|
-
max_width_range: int = 40
|
|
76
|
-
npoints: int = 300
|
|
77
|
-
nframes: int | None = None
|
|
78
|
-
fps: int = 5
|
|
79
|
-
bump_selection: str = "strongest"
|
|
80
|
-
show_progress_bar: bool = True
|
|
81
|
-
repeat: bool = False
|
|
82
|
-
title: str = "1D CANN Bump Animation"
|
|
83
|
-
|
|
84
|
-
|
|
85
69
|
@dataclass
|
|
86
70
|
class CANN1DPlotConfig(PlotConfig):
|
|
87
71
|
"""Specialized PlotConfig for CANN1D visualizations."""
|
|
@@ -141,9 +125,9 @@ class AnimationError(CANN1DError):
|
|
|
141
125
|
pass
|
|
142
126
|
|
|
143
127
|
|
|
144
|
-
def
|
|
128
|
+
def roi_bump_fits(data, config: BumpFitsConfig | None = None, save_path=None, **kwargs):
|
|
145
129
|
"""
|
|
146
|
-
Fit CANN1D bumps to data using MCMC optimization.
|
|
130
|
+
Fit CANN1D bumps to ROI data using MCMC optimization.
|
|
147
131
|
|
|
148
132
|
Parameters:
|
|
149
133
|
data : numpy.ndarray
|
|
@@ -318,10 +302,10 @@ def create_1d_bump_animation(
|
|
|
318
302
|
Parameters:
|
|
319
303
|
fits_data : numpy.ndarray
|
|
320
304
|
Shape (n_fits, 4) array with columns [time, position, amplitude, kappa]
|
|
321
|
-
config :
|
|
305
|
+
config : CANN1DPlotConfig, optional
|
|
322
306
|
Configuration object with all animation parameters
|
|
323
307
|
save_path : str, optional
|
|
324
|
-
Output path for the generated
|
|
308
|
+
Output path for the generated animation (e.g. .gif or .mp4)
|
|
325
309
|
**kwargs : backward compatibility parameters
|
|
326
310
|
|
|
327
311
|
Returns:
|
|
@@ -336,6 +320,8 @@ def create_1d_bump_animation(
|
|
|
336
320
|
for key, value in kwargs.items():
|
|
337
321
|
if hasattr(config, key):
|
|
338
322
|
setattr(config, key, value)
|
|
323
|
+
if save_path is not None:
|
|
324
|
+
config.save_path = save_path
|
|
339
325
|
|
|
340
326
|
try:
|
|
341
327
|
# ==== Smoothing functions ====
|
|
@@ -520,7 +506,10 @@ def create_1d_bump_animation(
|
|
|
520
506
|
fig, update, frames=nframes, init_func=init, blit=use_blitting, repeat=config.repeat
|
|
521
507
|
)
|
|
522
508
|
|
|
523
|
-
|
|
509
|
+
ani = None
|
|
510
|
+
progress_bar_enabled = getattr(config, "show_progress_bar", True)
|
|
511
|
+
|
|
512
|
+
# Save animation with unified backend selection
|
|
524
513
|
if config.save_path:
|
|
525
514
|
# Warn if both saving and showing (causes double rendering)
|
|
526
515
|
if config.show and nframes > 50:
|
|
@@ -528,31 +517,95 @@ def create_1d_bump_animation(
|
|
|
528
517
|
|
|
529
518
|
warn_double_rendering(nframes, config.save_path, stacklevel=2)
|
|
530
519
|
|
|
531
|
-
|
|
532
|
-
|
|
520
|
+
from ...visualization.core import (
|
|
521
|
+
emit_backend_warnings,
|
|
522
|
+
get_imageio_writer_kwargs,
|
|
523
|
+
get_matplotlib_writer,
|
|
524
|
+
select_animation_backend,
|
|
525
|
+
)
|
|
533
526
|
|
|
534
|
-
|
|
535
|
-
|
|
527
|
+
backend_selection = select_animation_backend(
|
|
528
|
+
save_path=config.save_path,
|
|
529
|
+
requested_backend=getattr(config, "render_backend", None),
|
|
530
|
+
check_imageio_plugins=True,
|
|
531
|
+
)
|
|
532
|
+
emit_backend_warnings(backend_selection.warnings, stacklevel=2)
|
|
533
|
+
backend = backend_selection.backend
|
|
536
534
|
|
|
535
|
+
if backend == "imageio":
|
|
537
536
|
try:
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
537
|
+
import imageio
|
|
538
|
+
|
|
539
|
+
writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
|
|
540
|
+
with imageio.get_writer(config.save_path, mode=mode, **writer_kwargs) as writer:
|
|
541
|
+
frames_iter = range(nframes)
|
|
542
|
+
if progress_bar_enabled:
|
|
543
|
+
frames_iter = tqdm(
|
|
544
|
+
frames_iter,
|
|
545
|
+
desc=f"Rendering {config.save_path}",
|
|
546
|
+
)
|
|
547
|
+
|
|
548
|
+
init()
|
|
549
|
+
for frame_idx in frames_iter:
|
|
550
|
+
update(frame_idx)
|
|
551
|
+
fig.canvas.draw()
|
|
552
|
+
frame = np.asarray(fig.canvas.buffer_rgba())
|
|
553
|
+
if frame.shape[-1] == 4:
|
|
554
|
+
frame = frame[:, :, :3]
|
|
555
|
+
writer.append_data(frame)
|
|
556
|
+
|
|
552
557
|
print(f"Animation saved to: {config.save_path}")
|
|
553
558
|
except Exception as e:
|
|
554
|
-
|
|
555
|
-
|
|
559
|
+
import warnings
|
|
560
|
+
|
|
561
|
+
warnings.warn(
|
|
562
|
+
f"imageio rendering failed: {e}. Falling back to matplotlib.",
|
|
563
|
+
RuntimeWarning,
|
|
564
|
+
stacklevel=2,
|
|
565
|
+
)
|
|
566
|
+
backend = "matplotlib"
|
|
567
|
+
|
|
568
|
+
if backend == "matplotlib":
|
|
569
|
+
ani = FuncAnimation(
|
|
570
|
+
fig,
|
|
571
|
+
update,
|
|
572
|
+
frames=nframes,
|
|
573
|
+
init_func=init,
|
|
574
|
+
blit=use_blitting,
|
|
575
|
+
repeat=config.repeat,
|
|
576
|
+
)
|
|
577
|
+
|
|
578
|
+
writer = get_matplotlib_writer(config.save_path, fps=config.fps)
|
|
579
|
+
|
|
580
|
+
if progress_bar_enabled:
|
|
581
|
+
pbar = tqdm(total=nframes, desc=f"Saving to {config.save_path}")
|
|
582
|
+
|
|
583
|
+
def progress_callback(current_frame, total_frames):
|
|
584
|
+
pbar.update(1)
|
|
585
|
+
|
|
586
|
+
try:
|
|
587
|
+
ani.save(
|
|
588
|
+
config.save_path,
|
|
589
|
+
writer=writer,
|
|
590
|
+
progress_callback=progress_callback,
|
|
591
|
+
)
|
|
592
|
+
print(f"Animation saved to: {config.save_path}")
|
|
593
|
+
finally:
|
|
594
|
+
pbar.close()
|
|
595
|
+
else:
|
|
596
|
+
ani.save(config.save_path, writer=writer)
|
|
597
|
+
print(f"Animation saved to: {config.save_path}")
|
|
598
|
+
|
|
599
|
+
# Create animation object for showing (if not already created)
|
|
600
|
+
if config.show and ani is None:
|
|
601
|
+
ani = FuncAnimation(
|
|
602
|
+
fig,
|
|
603
|
+
update,
|
|
604
|
+
frames=nframes,
|
|
605
|
+
init_func=init,
|
|
606
|
+
blit=use_blitting,
|
|
607
|
+
repeat=config.repeat,
|
|
608
|
+
)
|
|
556
609
|
|
|
557
610
|
if config.show:
|
|
558
611
|
# Automatically detect Jupyter and display as HTML/JS
|
|
@@ -1043,7 +1096,7 @@ def _mcmc(
|
|
|
1043
1096
|
|
|
1044
1097
|
if __name__ == "__main__":
|
|
1045
1098
|
data = load_roi_data()
|
|
1046
|
-
bumps, fits, nbump, centrbump =
|
|
1099
|
+
bumps, fits, nbump, centrbump = roi_bump_fits(
|
|
1047
1100
|
data, n_steps=5000, n_roi=16, save_path=os.path.join(os.getcwd(), "test.npz")
|
|
1048
1101
|
)
|
|
1049
1102
|
|
canns/analyzer/data/asa/fr.py
CHANGED
|
@@ -135,7 +135,8 @@ def save_fr_heatmap_png(
|
|
|
135
135
|
Plot configuration. Use ``config.save_path`` to specify output file.
|
|
136
136
|
**kwargs : Any
|
|
137
137
|
Additional ``imshow`` keyword arguments. ``save_path`` may be provided here
|
|
138
|
-
as a fallback if not set in ``config``.
|
|
138
|
+
as a fallback if not set in ``config``. If ``save_path`` is omitted, the
|
|
139
|
+
figure is only displayed when ``show=True``.
|
|
139
140
|
|
|
140
141
|
Notes
|
|
141
142
|
-----
|
|
@@ -172,11 +173,6 @@ def save_fr_heatmap_png(
|
|
|
172
173
|
if not config.ylabel:
|
|
173
174
|
config.ylabel = ylabel
|
|
174
175
|
|
|
175
|
-
if config.save_path is None:
|
|
176
|
-
raise ValueError(
|
|
177
|
-
"save_path must be provided via config.save_path or as a keyword argument."
|
|
178
|
-
)
|
|
179
|
-
|
|
180
176
|
config.save_dpi = dpi
|
|
181
177
|
|
|
182
178
|
M = np.asarray(M)
|
|
@@ -391,7 +387,8 @@ def plot_frm(
|
|
|
391
387
|
Plot configuration. Use ``config.save_path`` to specify output file.
|
|
392
388
|
**kwargs : Any
|
|
393
389
|
Additional ``imshow`` keyword arguments. ``save_path`` may be provided here
|
|
394
|
-
as a fallback if not set in ``config``.
|
|
390
|
+
as a fallback if not set in ``config``. If ``save_path`` is omitted, the
|
|
391
|
+
figure is only displayed when ``show=True``.
|
|
395
392
|
|
|
396
393
|
Examples
|
|
397
394
|
--------
|
|
@@ -423,11 +420,6 @@ def plot_frm(
|
|
|
423
420
|
if not config.ylabel:
|
|
424
421
|
config.ylabel = "Y bin"
|
|
425
422
|
|
|
426
|
-
if config.save_path is None:
|
|
427
|
-
raise ValueError(
|
|
428
|
-
"save_path must be provided via config.save_path or as a keyword argument."
|
|
429
|
-
)
|
|
430
|
-
|
|
431
423
|
config.save_dpi = dpi
|
|
432
424
|
|
|
433
425
|
frm = np.asarray(frm)
|
|
@@ -295,8 +295,19 @@ def plot_path_compare(
|
|
|
295
295
|
ax0 = axes[0]
|
|
296
296
|
ax0.set_title("Physical path (x,y)")
|
|
297
297
|
ax0.set_aspect("equal", "box")
|
|
298
|
-
ax0.axis("off")
|
|
299
298
|
ax0.plot(x, y, lw=0.9, alpha=0.8)
|
|
299
|
+
# Keep a visible frame while hiding ticks for a clean path outline.
|
|
300
|
+
ax0.set_xticks([])
|
|
301
|
+
ax0.set_yticks([])
|
|
302
|
+
for spine in ax0.spines.values():
|
|
303
|
+
spine.set_visible(True)
|
|
304
|
+
# Add a small padding so the frame doesn't touch the path.
|
|
305
|
+
x_min, x_max = np.min(x), np.max(x)
|
|
306
|
+
y_min, y_max = np.min(y), np.max(y)
|
|
307
|
+
pad_x = (x_max - x_min) * 0.03 if x_max > x_min else 1.0
|
|
308
|
+
pad_y = (y_max - y_min) * 0.03 if y_max > y_min else 1.0
|
|
309
|
+
ax0.set_xlim(x_min - pad_x, x_max + pad_x)
|
|
310
|
+
ax0.set_ylim(y_min - pad_y, y_max + pad_y)
|
|
300
311
|
|
|
301
312
|
ax1 = axes[1]
|
|
302
313
|
ax1.set_title("Decoded coho path")
|
canns/data/__init__.py
CHANGED
|
@@ -16,11 +16,13 @@ from .datasets import (
|
|
|
16
16
|
get_data_dir,
|
|
17
17
|
get_dataset_path,
|
|
18
18
|
get_huggingface_upload_guide,
|
|
19
|
+
get_left_right_data_session,
|
|
20
|
+
get_left_right_npz,
|
|
19
21
|
list_datasets,
|
|
20
22
|
load,
|
|
21
23
|
quick_setup,
|
|
22
24
|
)
|
|
23
|
-
from .loaders import load_grid_data, load_roi_data
|
|
25
|
+
from .loaders import load_grid_data, load_left_right_npz, load_roi_data
|
|
24
26
|
|
|
25
27
|
__all__ = [
|
|
26
28
|
# Dataset registry and management
|
|
@@ -31,6 +33,8 @@ __all__ = [
|
|
|
31
33
|
"list_datasets",
|
|
32
34
|
"download_dataset",
|
|
33
35
|
"get_dataset_path",
|
|
36
|
+
"get_left_right_data_session",
|
|
37
|
+
"get_left_right_npz",
|
|
34
38
|
"quick_setup",
|
|
35
39
|
"get_huggingface_upload_guide",
|
|
36
40
|
# Generic loading
|
|
@@ -38,4 +42,5 @@ __all__ = [
|
|
|
38
42
|
# Specialized loaders
|
|
39
43
|
"load_roi_data",
|
|
40
44
|
"load_grid_data",
|
|
45
|
+
"load_left_right_npz",
|
|
41
46
|
]
|
canns/data/datasets.py
CHANGED
|
@@ -38,6 +38,7 @@ DEFAULT_DATA_DIR = Path.home() / ".canns" / "data"
|
|
|
38
38
|
# URLs for datasets on Hugging Face
|
|
39
39
|
HUGGINGFACE_REPO = "canns-team/data-analysis-datasets"
|
|
40
40
|
BASE_URL = f"https://huggingface.co/datasets/{HUGGINGFACE_REPO}/resolve/main/"
|
|
41
|
+
LEFT_RIGHT_DATASET_DIR = "Left_Right_data_of"
|
|
41
42
|
|
|
42
43
|
# Dataset registry with metadata
|
|
43
44
|
DATASETS = {
|
|
@@ -68,6 +69,16 @@ DATASETS = {
|
|
|
68
69
|
"sha256": None,
|
|
69
70
|
"url": f"{BASE_URL}grid_2.npz",
|
|
70
71
|
},
|
|
72
|
+
"left_right_data_of": {
|
|
73
|
+
"filename": LEFT_RIGHT_DATASET_DIR,
|
|
74
|
+
"description": "ASA type data from Left-Right sweep paper",
|
|
75
|
+
"size_mb": 604.0,
|
|
76
|
+
"format": "directory",
|
|
77
|
+
"usage": "ASA analysis, left-right sweep sessions",
|
|
78
|
+
"sha256": None,
|
|
79
|
+
"url": f"{BASE_URL}{LEFT_RIGHT_DATASET_DIR}/",
|
|
80
|
+
"is_collection": True,
|
|
81
|
+
},
|
|
71
82
|
}
|
|
72
83
|
|
|
73
84
|
|
|
@@ -130,7 +141,10 @@ def list_datasets() -> None:
|
|
|
130
141
|
print("=" * 60)
|
|
131
142
|
|
|
132
143
|
for key, info in DATASETS.items():
|
|
133
|
-
|
|
144
|
+
if info.get("is_collection"):
|
|
145
|
+
status = "Collection (use session getter)"
|
|
146
|
+
else:
|
|
147
|
+
status = "Available" if info["url"] else "Setup required"
|
|
134
148
|
print(f"\nDataset: {key}")
|
|
135
149
|
print(f" File: {info['filename']}")
|
|
136
150
|
print(f" Size: {info['size_mb']} MB")
|
|
@@ -162,6 +176,11 @@ def download_dataset(dataset_key: str, force: bool = False) -> Path | None:
|
|
|
162
176
|
|
|
163
177
|
info = DATASETS[dataset_key]
|
|
164
178
|
|
|
179
|
+
if info.get("is_collection"):
|
|
180
|
+
print(f"{dataset_key} is a dataset collection.")
|
|
181
|
+
print("Use get_left_right_data_session(session_id) to download a session.")
|
|
182
|
+
return None
|
|
183
|
+
|
|
165
184
|
if not info["url"]:
|
|
166
185
|
print(f"{dataset_key} not yet available for download")
|
|
167
186
|
print("Please use setup_local_datasets() to copy from local repository")
|
|
@@ -213,6 +232,10 @@ def get_dataset_path(dataset_key: str, auto_setup: bool = True) -> Path | None:
|
|
|
213
232
|
if dataset_key not in DATASETS:
|
|
214
233
|
print(f"Unknown dataset: {dataset_key}")
|
|
215
234
|
return None
|
|
235
|
+
if DATASETS[dataset_key].get("is_collection"):
|
|
236
|
+
print(f"{dataset_key} is a dataset collection.")
|
|
237
|
+
print("Use get_left_right_data_session(session_id) to access session files.")
|
|
238
|
+
return None
|
|
216
239
|
|
|
217
240
|
data_dir = get_data_dir()
|
|
218
241
|
filepath = data_dir / DATASETS[dataset_key]["filename"]
|
|
@@ -236,6 +259,136 @@ def get_dataset_path(dataset_key: str, auto_setup: bool = True) -> Path | None:
|
|
|
236
259
|
return None
|
|
237
260
|
|
|
238
261
|
|
|
262
|
+
def get_left_right_data_session(
|
|
263
|
+
session_id: str, auto_download: bool = True, force: bool = False
|
|
264
|
+
) -> dict[str, Path | list[Path] | None] | None:
|
|
265
|
+
"""
|
|
266
|
+
Download and return files for a Left_Right_data_of session.
|
|
267
|
+
|
|
268
|
+
Parameters
|
|
269
|
+
----------
|
|
270
|
+
session_id : str
|
|
271
|
+
Session folder name, e.g. "24365_2".
|
|
272
|
+
auto_download : bool
|
|
273
|
+
Whether to download missing files automatically.
|
|
274
|
+
force : bool
|
|
275
|
+
Whether to force re-download of existing files.
|
|
276
|
+
|
|
277
|
+
Returns
|
|
278
|
+
-------
|
|
279
|
+
dict or None
|
|
280
|
+
Mapping with keys: "manifest", "full_file", "module_files".
|
|
281
|
+
"""
|
|
282
|
+
if not session_id:
|
|
283
|
+
raise ValueError("session_id must be non-empty")
|
|
284
|
+
|
|
285
|
+
session_dir = get_data_dir() / LEFT_RIGHT_DATASET_DIR / session_id
|
|
286
|
+
session_dir.mkdir(parents=True, exist_ok=True)
|
|
287
|
+
|
|
288
|
+
manifest_filename = f"{session_id}_ASA_manifest.json"
|
|
289
|
+
manifest_url = f"{BASE_URL}{LEFT_RIGHT_DATASET_DIR}/{session_id}/{manifest_filename}"
|
|
290
|
+
manifest_path = session_dir / manifest_filename
|
|
291
|
+
|
|
292
|
+
if auto_download and (force or not manifest_path.exists()):
|
|
293
|
+
if not download_file_with_progress(manifest_url, manifest_path):
|
|
294
|
+
print(f"Failed to download manifest for session {session_id}")
|
|
295
|
+
return None
|
|
296
|
+
|
|
297
|
+
if not manifest_path.exists():
|
|
298
|
+
print(f"Manifest not found for session {session_id}")
|
|
299
|
+
return None
|
|
300
|
+
|
|
301
|
+
import json
|
|
302
|
+
|
|
303
|
+
with open(manifest_path) as f:
|
|
304
|
+
manifest = json.load(f)
|
|
305
|
+
|
|
306
|
+
full_file = manifest.get("full_file")
|
|
307
|
+
module_files = manifest.get("module_files", [])
|
|
308
|
+
requested_files: list[str] = []
|
|
309
|
+
|
|
310
|
+
if isinstance(full_file, str):
|
|
311
|
+
requested_files.append(Path(full_file).name)
|
|
312
|
+
|
|
313
|
+
if isinstance(module_files, list):
|
|
314
|
+
for module_file in module_files:
|
|
315
|
+
if isinstance(module_file, str):
|
|
316
|
+
requested_files.append(Path(module_file).name)
|
|
317
|
+
|
|
318
|
+
# De-duplicate while preserving order
|
|
319
|
+
seen: set[str] = set()
|
|
320
|
+
unique_files: list[str] = []
|
|
321
|
+
for filename in requested_files:
|
|
322
|
+
if filename and filename not in seen:
|
|
323
|
+
seen.add(filename)
|
|
324
|
+
unique_files.append(filename)
|
|
325
|
+
|
|
326
|
+
for filename in unique_files:
|
|
327
|
+
file_path = session_dir / filename
|
|
328
|
+
if auto_download and (force or not file_path.exists()):
|
|
329
|
+
file_url = f"{BASE_URL}{LEFT_RIGHT_DATASET_DIR}/{session_id}/{filename}"
|
|
330
|
+
if not download_file_with_progress(file_url, file_path):
|
|
331
|
+
print(f"Failed to download {filename} for session {session_id}")
|
|
332
|
+
return None
|
|
333
|
+
|
|
334
|
+
return {
|
|
335
|
+
"manifest": manifest_path,
|
|
336
|
+
"full_file": session_dir / Path(full_file).name if isinstance(full_file, str) else None,
|
|
337
|
+
"module_files": [
|
|
338
|
+
session_dir / Path(module_file).name
|
|
339
|
+
for module_file in module_files
|
|
340
|
+
if isinstance(module_file, str)
|
|
341
|
+
],
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def get_left_right_npz(
|
|
346
|
+
session_id: str, filename: str, auto_download: bool = True, force: bool = False
|
|
347
|
+
) -> Path | None:
|
|
348
|
+
"""
|
|
349
|
+
Download and return a specific Left_Right_data_of NPZ file.
|
|
350
|
+
|
|
351
|
+
Parameters
|
|
352
|
+
----------
|
|
353
|
+
session_id : str
|
|
354
|
+
Session folder name, e.g. "26034_3".
|
|
355
|
+
filename : str
|
|
356
|
+
File name inside the session folder, e.g.
|
|
357
|
+
"26034_3_ASA_mec_gridModule02_n104_cm.npz".
|
|
358
|
+
auto_download : bool
|
|
359
|
+
Whether to download the file if missing.
|
|
360
|
+
force : bool
|
|
361
|
+
Whether to force re-download of existing files.
|
|
362
|
+
|
|
363
|
+
Returns
|
|
364
|
+
-------
|
|
365
|
+
Path or None
|
|
366
|
+
Path to the requested file if available, None otherwise.
|
|
367
|
+
"""
|
|
368
|
+
if not session_id:
|
|
369
|
+
raise ValueError("session_id must be non-empty")
|
|
370
|
+
if not filename:
|
|
371
|
+
raise ValueError("filename must be non-empty")
|
|
372
|
+
|
|
373
|
+
safe_name = Path(filename).name
|
|
374
|
+
session_dir = get_data_dir() / LEFT_RIGHT_DATASET_DIR / session_id
|
|
375
|
+
session_dir.mkdir(parents=True, exist_ok=True)
|
|
376
|
+
|
|
377
|
+
file_path = session_dir / safe_name
|
|
378
|
+
if file_path.exists() and not force:
|
|
379
|
+
return file_path
|
|
380
|
+
|
|
381
|
+
if not auto_download:
|
|
382
|
+
return None
|
|
383
|
+
|
|
384
|
+
file_url = f"{BASE_URL}{LEFT_RIGHT_DATASET_DIR}/{session_id}/{safe_name}"
|
|
385
|
+
if not download_file_with_progress(file_url, file_path):
|
|
386
|
+
print(f"Failed to download {safe_name} for session {session_id}")
|
|
387
|
+
return None
|
|
388
|
+
|
|
389
|
+
return file_path
|
|
390
|
+
|
|
391
|
+
|
|
239
392
|
def detect_file_type(filepath: Path) -> str:
|
|
240
393
|
"""Detect file type based on extension."""
|
|
241
394
|
suffix = filepath.suffix.lower()
|
canns/data/loaders.py
CHANGED
|
@@ -211,6 +211,43 @@ def load_grid_data(
|
|
|
211
211
|
return None
|
|
212
212
|
|
|
213
213
|
|
|
214
|
+
def load_left_right_npz(
|
|
215
|
+
session_id: str, filename: str, auto_download: bool = True, force: bool = False
|
|
216
|
+
) -> dict[str, Any] | None:
|
|
217
|
+
"""
|
|
218
|
+
Load a Left_Right_data_of NPZ file.
|
|
219
|
+
|
|
220
|
+
Parameters
|
|
221
|
+
----------
|
|
222
|
+
session_id : str
|
|
223
|
+
Session folder name, e.g. "26034_3".
|
|
224
|
+
filename : str
|
|
225
|
+
File name inside the session folder.
|
|
226
|
+
auto_download : bool
|
|
227
|
+
Whether to download the file if missing.
|
|
228
|
+
force : bool
|
|
229
|
+
Whether to force re-download of existing files.
|
|
230
|
+
|
|
231
|
+
Returns
|
|
232
|
+
-------
|
|
233
|
+
dict or None
|
|
234
|
+
Dictionary of npz arrays if successful, None otherwise.
|
|
235
|
+
"""
|
|
236
|
+
try:
|
|
237
|
+
path = _datasets.get_left_right_npz(
|
|
238
|
+
session_id=session_id,
|
|
239
|
+
filename=filename,
|
|
240
|
+
auto_download=auto_download,
|
|
241
|
+
force=force,
|
|
242
|
+
)
|
|
243
|
+
if path is None:
|
|
244
|
+
return None
|
|
245
|
+
return dict(np.load(path, allow_pickle=True))
|
|
246
|
+
except Exception as e:
|
|
247
|
+
print(f"Failed to load Left-Right npz {session_id}/{filename}: {e}")
|
|
248
|
+
return None
|
|
249
|
+
|
|
250
|
+
|
|
214
251
|
def validate_roi_data(data: np.ndarray) -> bool:
|
|
215
252
|
"""
|
|
216
253
|
Validate ROI data format for 1D CANN analysis.
|
canns/pipeline/__init__.py
CHANGED
|
@@ -1,17 +1,9 @@
|
|
|
1
|
-
"""
|
|
2
|
-
CANNs Pipeline Module
|
|
1
|
+
"""CANNs pipeline entrypoints."""
|
|
3
2
|
|
|
4
|
-
High-level pipelines for common analysis workflows, designed to make CANN models
|
|
5
|
-
accessible to experimental researchers without requiring detailed knowledge of
|
|
6
|
-
the underlying implementations.
|
|
7
|
-
"""
|
|
8
|
-
|
|
9
|
-
from ._base import Pipeline
|
|
10
3
|
from .asa import ASAApp
|
|
11
4
|
from .asa import main as asa_main
|
|
5
|
+
from .gallery import GalleryApp
|
|
6
|
+
from .gallery import main as gallery_main
|
|
7
|
+
from .launcher import main as launcher_main
|
|
12
8
|
|
|
13
|
-
__all__ = [
|
|
14
|
-
"Pipeline",
|
|
15
|
-
"ASAApp",
|
|
16
|
-
"asa_main",
|
|
17
|
-
]
|
|
9
|
+
__all__ = ["ASAApp", "asa_main", "GalleryApp", "gallery_main", "launcher_main"]
|
canns/pipeline/asa/widgets.py
CHANGED