canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +5 -1
- canns/analyzer/data/asa/__init__.py +27 -12
- canns/analyzer/data/asa/cohospace.py +336 -10
- canns/analyzer/data/asa/config.py +3 -0
- canns/analyzer/data/asa/embedding.py +48 -45
- canns/analyzer/data/asa/path.py +104 -2
- canns/analyzer/data/asa/plotting.py +88 -19
- canns/analyzer/data/asa/tda.py +11 -4
- canns/analyzer/data/cell_classification/__init__.py +97 -0
- canns/analyzer/data/cell_classification/core/__init__.py +26 -0
- canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
- canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
- canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
- canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
- canns/analyzer/data/cell_classification/io/__init__.py +5 -0
- canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
- canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
- canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
- canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
- canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
- canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
- canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
- canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
- canns/analyzer/metrics/__init__.py +2 -1
- canns/analyzer/visualization/core/config.py +46 -4
- canns/data/__init__.py +6 -1
- canns/data/datasets.py +154 -1
- canns/data/loaders.py +37 -0
- canns/pipeline/__init__.py +13 -9
- canns/pipeline/__main__.py +6 -0
- canns/pipeline/asa/runner.py +105 -41
- canns/pipeline/asa_gui/__init__.py +68 -0
- canns/pipeline/asa_gui/__main__.py +6 -0
- canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
- canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
- canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
- canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
- canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
- canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
- canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
- canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
- canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
- canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
- canns/pipeline/asa_gui/app.py +29 -0
- canns/pipeline/asa_gui/controllers/__init__.py +6 -0
- canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
- canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
- canns/pipeline/asa_gui/core/__init__.py +15 -0
- canns/pipeline/asa_gui/core/cache.py +14 -0
- canns/pipeline/asa_gui/core/runner.py +1936 -0
- canns/pipeline/asa_gui/core/state.py +324 -0
- canns/pipeline/asa_gui/core/worker.py +260 -0
- canns/pipeline/asa_gui/main_window.py +184 -0
- canns/pipeline/asa_gui/models/__init__.py +7 -0
- canns/pipeline/asa_gui/models/config.py +14 -0
- canns/pipeline/asa_gui/models/job.py +31 -0
- canns/pipeline/asa_gui/models/presets.py +21 -0
- canns/pipeline/asa_gui/resources/__init__.py +16 -0
- canns/pipeline/asa_gui/resources/dark.qss +167 -0
- canns/pipeline/asa_gui/resources/light.qss +163 -0
- canns/pipeline/asa_gui/resources/styles.qss +130 -0
- canns/pipeline/asa_gui/utils/__init__.py +1 -0
- canns/pipeline/asa_gui/utils/formatters.py +15 -0
- canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
- canns/pipeline/asa_gui/utils/validators.py +41 -0
- canns/pipeline/asa_gui/views/__init__.py +1 -0
- canns/pipeline/asa_gui/views/help_content.py +171 -0
- canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
- canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
- canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
- canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
- canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
- canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
- canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
- canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
- canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
- canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
- canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
- canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
- canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
- canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
- canns/pipeline/gallery/__init__.py +15 -5
- canns/pipeline/gallery/__main__.py +11 -0
- canns/pipeline/gallery/app.py +705 -0
- canns/pipeline/gallery/runner.py +790 -0
- canns/pipeline/gallery/state.py +51 -0
- canns/pipeline/gallery/styles.tcss +123 -0
- canns/pipeline/launcher.py +81 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
- canns-0.14.0.dist-info/RECORD +163 -0
- canns-0.14.0.dist-info/entry_points.txt +5 -0
- canns/pipeline/_base.py +0 -50
- canns-0.13.1.dist-info/RECORD +0 -89
- canns-0.13.1.dist-info/entry_points.txt +0 -3
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,565 @@
|
|
|
1
|
+
"""Analysis page for ASA GUI."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from PySide6.QtCore import QSettings, Qt, Signal
|
|
9
|
+
from PySide6.QtGui import QColor
|
|
10
|
+
from PySide6.QtWidgets import (
|
|
11
|
+
QCheckBox,
|
|
12
|
+
QFrame,
|
|
13
|
+
QGraphicsDropShadowEffect,
|
|
14
|
+
QGroupBox,
|
|
15
|
+
QHBoxLayout,
|
|
16
|
+
QLabel,
|
|
17
|
+
QProgressBar,
|
|
18
|
+
QPushButton,
|
|
19
|
+
QScrollArea,
|
|
20
|
+
QSplitter,
|
|
21
|
+
QTabWidget,
|
|
22
|
+
QVBoxLayout,
|
|
23
|
+
QWidget,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
from ...analysis_modes import AbstractAnalysisMode, get_analysis_modes
|
|
27
|
+
from ...controllers import AnalysisController
|
|
28
|
+
from ...core import WorkerManager
|
|
29
|
+
from ..help_content import analysis_help_markdown
|
|
30
|
+
from ..widgets.artifacts_tab import ArtifactsTab
|
|
31
|
+
from ..widgets.gridscore_tab import GridScoreTab
|
|
32
|
+
from ..widgets.help_dialog import show_help_dialog
|
|
33
|
+
from ..widgets.image_tab import ImageTab
|
|
34
|
+
from ..widgets.log_box import LogBox
|
|
35
|
+
from ..widgets.pathcompare_tab import PathCompareTab
|
|
36
|
+
from ..widgets.popup_combo import PopupComboBox
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class AnalysisPage(QWidget):
    """Page for running analyses and viewing results.

    Layout: an info line on top, then a horizontal splitter. The left side
    holds the analysis-mode selector, a shared "StandardScaler" checkbox, the
    per-mode parameter widgets, Run/Stop controls with a progress bar and a
    log box. The right side holds one result tab per analysis type plus a
    "Files" tab listing the artifacts of the last run.

    Signals:
        analysis_completed: Emitted after a run finishes successfully and its
            artifacts have been handed to the controller and displayed.
    """

    analysis_completed = Signal()

    # Modes that are registered (their parameter widgets exist and they can be
    # run programmatically) but are not offered in the mode selector.
    _HIDDEN_MODES = frozenset({"decode", "gridscore_inspect", "batch"})

    def __init__(
        self,
        controller: AnalysisController,
        worker_manager: WorkerManager,
        parent=None,
    ) -> None:
        super().__init__(parent)
        self._controller = controller
        self._workers = worker_manager
        # Most recent state passed to load_state(); used to re-render the info
        # line when the UI language changes.
        self._last_state = None
        self._lang = "en"
        self._build_ui()

    def _build_ui(self) -> None:
        """Create all widgets, wire signals and apply the saved UI language."""
        root = QVBoxLayout(self)

        info_row = QHBoxLayout()
        self.info_label = QLabel("Mode=— | preset=— | preprocess=— | spike_main_shape=—")
        self.info_label.setObjectName("muted")
        info_row.addWidget(self.info_label, 1)
        root.addLayout(info_row)

        splitter = QSplitter(Qt.Horizontal)
        root.addWidget(splitter, 1)

        left_wrap = QWidget()
        right_wrap = QWidget()
        left = QVBoxLayout(left_wrap)
        right = QVBoxLayout(right_wrap)

        # --- Parameter panel -------------------------------------------------
        self.param_container = QGroupBox("Analysis Parameters")
        self.param_container.setObjectName("card")
        self.param_layout = QVBoxLayout(self.param_container)
        self.param_layout.setContentsMargins(0, 0, 0, 0)
        self.param_layout.setSpacing(12)

        mode_row = QHBoxLayout()
        self.analysis_mode = PopupComboBox()
        self.analysis_mode.setToolTip("Select an analysis mode to run.")
        self._modes: dict[str, AbstractAnalysisMode] = {}
        for mode in get_analysis_modes():
            self._modes[mode.name] = mode
            if mode.name not in self._HIDDEN_MODES:
                # userData carries the internal mode name; read back via currentData().
                self.analysis_mode.addItem(mode.display_name, userData=mode.name)

        self.label_analysis_module = QLabel("Analysis module:")
        mode_row.addWidget(self.label_analysis_module)
        mode_row.addWidget(self.analysis_mode, 1)
        self.help_btn = QPushButton("Help")
        self.help_btn.setToolTip("Show parameter guide for the selected mode.")
        self.help_btn.clicked.connect(self._show_help)
        mode_row.addWidget(self.help_btn)
        self.param_layout.addLayout(mode_row)

        # Shared standardization toggle; mirrored onto the TDA mode's own
        # checkbox (see _sync_standardize / _on_standardize_toggled).
        self.grp_standardize = QGroupBox("Preprocess (Standardization)")
        std_layout = QHBoxLayout(self.grp_standardize)
        self.chk_standardize = QCheckBox("StandardScaler")
        std_layout.addWidget(self.chk_standardize)
        std_layout.addStretch(1)
        self.param_layout.addWidget(self.grp_standardize)

        # One parameter widget per registered mode (hidden modes included);
        # visibility is switched in _on_mode_changed.
        self.param_widgets: dict[str, QWidget] = {}
        for mode in self._modes.values():
            widget = mode.create_params_widget()
            widget.setObjectName("card")
            btn_show = getattr(mode, "btn_show", None)
            if btn_show is not None:
                # Bind mode_name as a default so each button runs its own mode,
                # and swallow the clicked(bool) argument in `_`.
                btn_show.clicked.connect(
                    lambda _=False, mode_name=mode.name: self._run_analysis(mode_override=mode_name)
                )
            self.param_widgets[mode.name] = widget
            self.param_layout.addWidget(widget)
        self.param_layout.addStretch(1)
        self.param_scroll = QScrollArea()
        self.param_scroll.setWidgetResizable(True)
        self.param_scroll.setFrameShape(QFrame.NoFrame)
        self.param_scroll.setWidget(self.param_container)

        # --- Run controls and log --------------------------------------------
        controls = QHBoxLayout()
        self.run_btn = QPushButton("Run Analysis")
        self.run_btn.setObjectName("btn_run")
        self.stop_btn = QPushButton("Stop")
        self.stop_btn.setObjectName("btn_stop")
        self.stop_btn.setEnabled(False)
        self.progress = QProgressBar()
        self.progress.setRange(0, 100)
        self.progress.setValue(0)
        controls.addWidget(self.run_btn)
        controls.addWidget(self.stop_btn)
        controls.addWidget(self.progress, 1)

        self.log_box = LogBox()
        log_wrap = QWidget()
        log_layout = QVBoxLayout(log_wrap)
        self.logs_label = QLabel("Logs")
        log_layout.addWidget(self.logs_label)
        log_layout.addWidget(self.log_box, 1)

        left.addWidget(self.param_scroll, 2)
        left.addLayout(controls)
        left.addWidget(log_wrap, 1)

        # --- Results ---------------------------------------------------------
        self.tabs = QTabWidget()
        self.tab_barcode = ImageTab("TDA Barcode")
        self.tab_cohomap = ImageTab("CohoMap")
        self.tab_pathcompare = PathCompareTab("Path Compare")
        self.tab_cohospace = ImageTab("CohoSpace")
        self.tab_fr = ImageTab("FR Heatmap")
        self.tab_frm = ImageTab("FRM")
        self.tab_gridscore = GridScoreTab("Grid Score")

        self.tabs.addTab(self.tab_barcode, "Barcode")
        self.tabs.addTab(self.tab_cohomap, "CohoMap")
        self.tabs.addTab(self.tab_pathcompare, "Path Compare")
        self.tabs.addTab(self.tab_cohospace, "CohoSpace")
        self.tabs.addTab(self.tab_fr, "FR")
        self.tabs.addTab(self.tab_frm, "FRM")
        self.tabs.addTab(self.tab_gridscore, "GridScore")

        self.tab_files = ArtifactsTab()
        self.tabs.addTab(self.tab_files, "Files")

        right.addWidget(self.tabs, 1)

        splitter.addWidget(left_wrap)
        splitter.addWidget(right_wrap)
        splitter.setStretchFactor(0, 1)
        splitter.setStretchFactor(1, 2)

        # --- Signals ---------------------------------------------------------
        self.analysis_mode.currentIndexChanged.connect(self._on_mode_changed)
        # clicked(bool) would otherwise be forwarded as `mode_override`.
        self.run_btn.clicked.connect(lambda _checked=False: self._run_analysis())
        self.stop_btn.clicked.connect(self._stop_analysis)
        self.tab_gridscore.inspectRequested.connect(self._run_gridscore_inspect)
        # Connected once here rather than re-wired on every mode change.
        self.chk_standardize.toggled.connect(self._on_standardize_toggled)

        self._on_mode_changed()
        self._apply_card_effects([self.param_container] + list(self.param_widgets.values()))
        self._sync_standardize()
        self.apply_language(str(QSettings("canns", "asa_gui").value("lang", "en")))

    def _apply_card_effects(self, widgets: list[QWidget]) -> None:
        """Give each widget its own soft drop shadow (one effect per widget)."""
        for widget in widgets:
            effect = QGraphicsDropShadowEffect(self)
            effect.setBlurRadius(18)
            effect.setOffset(0, 3)
            effect.setColor(QColor(0, 0, 0, 40))
            widget.setGraphicsEffect(effect)

    def apply_language(self, lang: str) -> None:
        """Re-label the page's static texts; any language starting with "zh" is Chinese."""
        self._lang = str(lang or "en")
        is_zh = self._lang.lower().startswith("zh")
        self.param_container.setTitle("分析参数" if is_zh else "Analysis Parameters")
        self.label_analysis_module.setText("分析模块:" if is_zh else "Analysis module:")
        self.help_btn.setText("帮助" if is_zh else "Help")
        self.help_btn.setToolTip(
            "查看参数说明" if is_zh else "Show parameter guide for the selected mode."
        )
        self.grp_standardize.setTitle(
            "预处理(标准化)" if is_zh else "Preprocess (Standardization)"
        )
        self.chk_standardize.setText("StandardScaler")
        self.run_btn.setText("运行分析" if is_zh else "Run Analysis")
        self.stop_btn.setText("停止" if is_zh else "Stop")
        self.logs_label.setText("日志" if is_zh else "Logs")

        if self._last_state is not None:
            self._update_info(self._last_state)
        else:
            self.info_label.setText(
                "模式=— | 预设=— | 预处理=— | spike_main_shape=—"
                if is_zh
                else "Mode=— | preset=— | preprocess=— | spike_main_shape=—"
            )

    def load_state(self, state) -> None:
        """Refresh the page from a workflow *state* object.

        Updates the info line, applies ``state.preset`` (when set) to every
        mode, and passes the best-known (neuron_count, total_steps) to each
        mode's ``apply_ranges``. Dimensions come from a 2-D ``embed_data``
        array when present; otherwise they are inferred from ``aligned_pos``
        and the input files.
        """
        self._last_state = state
        self._update_info(state)
        preset = getattr(state, "preset", None)
        if preset:
            for mode in self._modes.values():
                mode.apply_preset(preset)

        total_steps = None
        neuron_count = None

        embed_data = getattr(state, "embed_data", None)
        if isinstance(embed_data, np.ndarray) and embed_data.ndim == 2:
            total_steps, neuron_count = embed_data.shape
        else:
            aligned_pos = getattr(state, "aligned_pos", None)
            if isinstance(aligned_pos, dict) and "t" in aligned_pos:
                total_steps = self._safe_len(aligned_pos["t"])
            inferred = self._infer_counts_from_state(state)
            if inferred is not None:
                inferred_steps, inferred_neurons = inferred
                total_steps = total_steps or inferred_steps
                neuron_count = neuron_count or inferred_neurons

        for mode in self._modes.values():
            mode.apply_ranges(neuron_count, total_steps)

    def _update_info(self, state) -> None:
        """Render the one-line summary (mode, preset, preprocess, preclass, embed shape)."""
        mode = getattr(state, "input_mode", "—")
        preset = getattr(state, "preset", "—")
        preprocess = getattr(state, "preprocess_method", "—")
        preclass = getattr(state, "preclass", None)
        shape = "None"
        embed = getattr(state, "embed_data", None)
        if isinstance(embed, np.ndarray) and embed.ndim == 2:
            shape = f"{embed.shape}"
        is_zh = str(self._lang).lower().startswith("zh")
        if is_zh:
            parts = [f"模式={mode}", f"预设={preset}", f"预处理={preprocess}"]
        else:
            parts = [f"Mode={mode}", f"preset={preset}", f"preprocess={preprocess}"]
        if preclass is not None:
            parts.append(("预分类" if is_zh else "preclass") + f"={preclass}")
        parts.append(f"spike_main_shape={shape}")
        self.info_label.setText(" | ".join(parts))

    # --- Dimension inference from input files -------------------------------

    @staticmethod
    def _safe_len(obj) -> int | None:
        """Return ``len(obj)``, or None when *obj* has no length (e.g. a 0-d array)."""
        try:
            return len(obj)
        except TypeError:
            return None

    @staticmethod
    def _infer_from_spike(spike_obj) -> tuple[int | None, int | None]:
        """Guess (total_steps, neuron_count) from a spike container.

        - 2-D array: (shape[0], shape[1]).
        - 1-D object array: (None, length), one entry per neuron.
        - dict / list / tuple: (None, len(obj)).

        Size-1 object arrays (e.g. a dict saved with ``np.save``) are unwrapped
        first. Anything else yields (None, None).
        """
        while (
            isinstance(spike_obj, np.ndarray)
            and spike_obj.dtype == object
            and spike_obj.ndim != 2
            and spike_obj.size == 1
        ):
            spike_obj = spike_obj.item()

        if spike_obj is None:
            return None, None
        if isinstance(spike_obj, np.ndarray):
            if spike_obj.ndim == 2:
                return int(spike_obj.shape[0]), int(spike_obj.shape[1])
            if spike_obj.dtype == object and spike_obj.ndim == 1:
                return None, int(spike_obj.shape[0])
            return None, None
        if isinstance(spike_obj, (dict, list, tuple)):
            return None, int(len(spike_obj))
        return None, None

    @classmethod
    def _counts_from_asa_data(cls, data) -> tuple[int | None, int | None]:
        """(total_steps, neuron_count) from a loaded ASA file (expects .npz keys "t"/"spike")."""
        if not isinstance(data, np.lib.npyio.NpzFile):
            return None, None
        total_steps = cls._safe_len(data["t"]) if "t" in data else None
        neuron_count = None
        if "spike" in data:
            t_guess, n_guess = cls._infer_from_spike(data["spike"])
            total_steps = total_steps or t_guess
            neuron_count = n_guess
        return total_steps, neuron_count

    @classmethod
    def _counts_from_neuron_data(cls, data) -> tuple[int | None, int | None]:
        """(total_steps, neuron_count) from a loaded neuron file.

        For .npz, uses the "spike" entry, falling back to the first stored array.
        """
        if isinstance(data, np.lib.npyio.NpzFile):
            if "spike" in data:
                spike_obj = data["spike"]
            elif data.files:
                spike_obj = data[data.files[0]]
            else:
                spike_obj = None
            return cls._infer_from_spike(spike_obj)
        return cls._infer_from_spike(data)

    @classmethod
    def _steps_from_traj_data(cls, data) -> int | None:
        """Number of samples in a loaded trajectory file (first of "t", "x", "y" for .npz)."""
        if isinstance(data, np.lib.npyio.NpzFile):
            for key in ("t", "x", "y"):
                if key in data:
                    return cls._safe_len(data[key])
            return None
        shape = getattr(data, "shape", None)
        if shape:  # non-empty shape tuple
            return int(shape[0])
        return None

    def _read_numpy_file(self, path, reader):
        """Load *path* with ``np.load``, return ``reader(data)``, and always close it.

        Returns None (and reports the problem in the log box) when the file
        cannot be loaded or inspected: the inferred counts only pre-fill
        parameter ranges, so a bad file must not abort ``load_state``.
        """
        # NOTE(review): allow_pickle=True lets a crafted .npy/.npz execute code
        # when loaded; only open files from trusted sources.
        try:
            data = np.load(path, allow_pickle=True)
        except Exception as exc:  # FileNotFoundError, ValueError, BadZipFile, pickle errors, ...
            self.log_box.log(f"Could not read {Path(path).name}: {exc}")
            return None
        try:
            return reader(data)
        except Exception as exc:
            self.log_box.log(f"Could not inspect {Path(path).name}: {exc}")
            return None
        finally:
            # For .npz files np.load returns an NpzFile that keeps the file open.
            if isinstance(data, np.lib.npyio.NpzFile):
                data.close()

    def _infer_counts_from_state(self, state) -> tuple[int | None, int | None] | None:
        """Best-effort (total_steps, neuron_count) from the state's input files.

        Handles ``input_mode`` "asa" (single file) and "neuron_traj" (neuron
        file plus trajectory file). Returns None when nothing can be inferred.
        """
        try:
            from ...core.state import resolve_path
        except ImportError:
            return None

        input_mode = getattr(state, "input_mode", None)
        total_steps: int | None = None
        neuron_count: int | None = None

        if input_mode == "asa":
            path = resolve_path(state, state.asa_file)
            if path is None:
                return None
            counts = self._read_numpy_file(path, self._counts_from_asa_data)
            if counts is not None:
                total_steps, neuron_count = counts
        elif input_mode == "neuron_traj":
            neuron_path = resolve_path(state, state.neuron_file)
            traj_path = resolve_path(state, state.traj_file)
            if neuron_path is not None and neuron_path.exists():
                counts = self._read_numpy_file(neuron_path, self._counts_from_neuron_data)
                if counts is not None:
                    total_steps, neuron_count = counts
            if traj_path is not None and traj_path.exists():
                traj_steps = self._read_numpy_file(traj_path, self._steps_from_traj_data)
                total_steps = total_steps or traj_steps

        if total_steps is None and neuron_count is None:
            return None
        return total_steps, neuron_count

    # --- Mode selection / standardization -----------------------------------

    def _on_mode_changed(self) -> None:
        """Show only the parameter widget(s) for the selected mode."""
        mode = self.analysis_mode.currentData() or "tda"
        self._sync_standardize()
        # Each mode shows its own widget; gridscore_inspect reuses gridscore's.
        visible = {"gridscore_inspect": {"gridscore"}}.get(mode, {mode})

        for name, widget in self.param_widgets.items():
            widget.setVisible(name in visible)

        self.grp_standardize.setVisible(mode in {"tda", "cohomap"})

    def _tda_standardize_checkbox(self):
        """Return the TDA mode's ``standardize`` checkbox, or None if unavailable."""
        tda_mode = self._modes.get("tda")
        if tda_mode is None:
            return None
        return getattr(tda_mode, "standardize", None)

    def _sync_standardize(self) -> None:
        """Copy the TDA mode's standardize state onto the shared checkbox."""
        checkbox = self._tda_standardize_checkbox()
        if checkbox is None:
            return
        try:
            checked = bool(checkbox.isChecked())
        except (AttributeError, RuntimeError):
            # Not a checkbox, or its underlying Qt object is gone.
            return
        if checked != self.chk_standardize.isChecked():
            self.chk_standardize.setChecked(checked)

    def _on_standardize_toggled(self, checked: bool) -> None:
        """Forward the shared checkbox state to the TDA mode's checkbox."""
        checkbox = self._tda_standardize_checkbox()
        if checkbox is None:
            return
        try:
            checkbox.setChecked(bool(checked))
        except (AttributeError, RuntimeError):
            pass

    # --- Running -------------------------------------------------------------

    def _reset_run_controls(self, mode: str, *, failed: bool = False) -> None:
        """Re-enable Run / disable Stop; on failure clear a pending inspect status."""
        self.run_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        if failed and mode == "gridscore_inspect":
            self.tab_gridscore.set_status("")

    def _run_analysis(self, mode_override: str | None = None) -> None:
        """Validate inputs and start an analysis in the background worker.

        Args:
            mode_override: Mode name to run instead of the selector's current
                mode (used by per-mode "show" buttons and gridscore inspect).
                Non-string values are ignored.
        """
        if self._workers.is_running():
            self.log_box.log("A task is already running.")
            return

        if isinstance(mode_override, str) and mode_override:
            mode = mode_override
        else:
            mode = self.analysis_mode.currentData() or "tda"
        params = self._collect_params(mode)

        state = self._controller.get_state()
        from ...core.state import validate_files, validate_preprocessing

        ok, msg = validate_files(state)
        if not ok:
            self.log_box.log(f"Input error: {msg}")
            return

        if mode == "tda":
            ok, msg = validate_preprocessing(state)
            if not ok:
                self.log_box.log(f"Preprocess required for TDA: {msg}")
                return

        # Modes with an "fr" sub-mode need the preprocessed embedding.
        if mode in {"fr", "frm", "gridscore", "gridscore_inspect", "cohospace"}:
            if mode in {"gridscore", "gridscore_inspect"}:
                mode_flag = (params.get("gridscore") or {}).get("mode")
            else:
                mode_flag = params.get("mode")
            if mode_flag == "fr" and getattr(state, "embed_data", None) is None:
                self.log_box.log(
                    "Preprocess required for FR-mode. Use spike-mode or run preprocess."
                )
                return

        self._controller.update_analysis(analysis_mode=mode, analysis_params=params)

        self.progress.setValue(0)
        self.run_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)
        self.log_box.log(f"Starting analysis: {mode}")
        if mode == "gridscore_inspect":
            self.tab_gridscore.set_status("Computing gridscore inspect…")

        def _on_log(msg: str) -> None:
            # "__PCANIM__ <pct>" lines carry Path Compare animation progress
            # rather than log text.
            if msg.startswith("__PCANIM__"):
                parts = msg.split()
                if len(parts) >= 2:
                    try:
                        pct = int(parts[1])
                    except ValueError:
                        pct = None
                    if pct is not None:
                        self.tab_pathcompare.set_animation_progress(pct)
                        return
            self.log_box.log(msg)

        def _on_progress(pct: int) -> None:
            self.progress.setValue(pct)

        def _on_finished(result) -> None:
            if hasattr(result, "success") and not result.success:
                self._controller.mark_idle()
                self.log_box.log(getattr(result, "error", None) or "Analysis failed")
                self._reset_run_controls(mode, failed=True)
                return
            artifacts = getattr(result, "artifacts", None) or {}
            self._controller.finalize_analysis(artifacts)
            self._populate_artifacts(artifacts)
            self._select_result_tab(mode, artifacts)
            summary = getattr(result, "summary", None)
            if summary:
                self.log_box.log(summary)
            self._reset_run_controls(mode)
            self.analysis_completed.emit()

        def _on_error(msg: str) -> None:
            self._controller.mark_idle()
            self.log_box.log(f"Error: {msg}")
            self._reset_run_controls(mode, failed=True)

        def _on_cleanup() -> None:
            self._controller.mark_idle()

        self._controller.run_analysis(
            worker_manager=self._workers,
            on_log=_on_log,
            on_progress=_on_progress,
            on_finished=_on_finished,
            on_error=_on_error,
            on_cleanup=_on_cleanup,
        )

    def _show_help(self) -> None:
        """Open the parameter guide for the selected mode in the page's language."""
        mode = self.analysis_mode.currentData()
        lang = self._lang
        markdown = analysis_help_markdown(str(mode) if mode is not None else "", lang=lang)
        title = "ASA Help" if not str(lang).lower().startswith("zh") else "ASA 参数说明"
        show_help_dialog(self, title, markdown)

    def _stop_analysis(self) -> None:
        """Ask the running worker (if any) to cancel."""
        if self._workers.is_running():
            self._workers.request_cancel()
            self.log_box.log("Cancel requested.")

    # --- Results -------------------------------------------------------------

    def _populate_artifacts(self, artifacts: dict) -> None:
        """Push each known artifact key into its result tab; list all in the Files tab."""
        if "barcode" in artifacts:
            self.tab_barcode.set_image(artifacts.get("barcode"))
        if "cohomap" in artifacts:
            self.tab_cohomap.set_image(artifacts.get("cohomap"))
        if "path_compare" in artifacts or "path_compare_gif" in artifacts:
            self.tab_pathcompare.set_artifacts(
                artifacts.get("path_compare"),
                artifacts.get("path_compare_gif"),
            )
        if "path_compare_mp4" in artifacts:
            self.tab_pathcompare.set_animation(Path(artifacts["path_compare_mp4"]))
        # CohoSpace shows one image, by priority: neuron > population > trajectory.
        if "neuron" in artifacts:
            self.tab_cohospace.set_image(artifacts.get("neuron"))
        elif "population" in artifacts:
            self.tab_cohospace.set_image(artifacts.get("population"))
        elif "trajectory" in artifacts:
            self.tab_cohospace.set_image(artifacts.get("trajectory"))
        if "fr_heatmap" in artifacts:
            self.tab_fr.set_image(artifacts.get("fr_heatmap"))
        if "frm" in artifacts:
            self.tab_frm.set_image(artifacts.get("frm"))

        if "gridscore_png" in artifacts:
            self.tab_gridscore.set_distribution_image(Path(artifacts["gridscore_png"]))
        if "gridscore_npz" in artifacts:
            try:
                self.tab_gridscore.load_gridscore_npz(Path(artifacts["gridscore_npz"]))
            except Exception as e:
                self.log_box.log(f"GridScore: failed to load gridscore.npz: {e}")
        if "gridscore_neuron_png" in artifacts:
            self.tab_gridscore.set_autocorr_image(Path(artifacts["gridscore_neuron_png"]))
            self.tab_gridscore.set_status("")

        self.tab_files.set_artifacts(artifacts)

    def _select_result_tab(self, mode: str, artifacts: dict) -> None:
        """Switch to the result tab for *mode*.

        Decode results go to the Files tab, and only when there are artifacts.
        """
        mapping = {
            "tda": self.tab_barcode,
            "decode": self.tab_files,
            "cohomap": self.tab_cohomap,
            "pathcompare": self.tab_pathcompare,
            "cohospace": self.tab_cohospace,
            "fr": self.tab_fr,
            "frm": self.tab_frm,
            "gridscore": self.tab_gridscore,
            "gridscore_inspect": self.tab_gridscore,
        }
        target = mapping.get(mode)
        if target is None:
            return
        idx = self.tabs.indexOf(target)
        if idx < 0:
            return
        if mode == "decode" and not artifacts:
            return
        self.tabs.setCurrentIndex(idx)

    def _run_gridscore_inspect(self, neuron_id: int, meta: dict) -> None:
        """Handle GridScoreTab.inspectRequested: run gridscore_inspect for one neuron.

        *meta* (plus ``neuron_id``) is passed to the mode's ``apply_meta``
        when the mode provides one.
        """
        if self._workers.is_running():
            self.log_box.log("A task is already running.")
            return
        mode_obj = self._modes.get("gridscore_inspect")
        if mode_obj is None:
            self.log_box.log("GridScore Inspect is not available.")
            return
        if hasattr(mode_obj, "apply_meta"):
            meta = dict(meta or {})
            meta["neuron_id"] = int(neuron_id)
            mode_obj.apply_meta(meta)
        self._run_analysis(mode_override="gridscore_inspect")

    def _collect_params(self, mode: str) -> dict:
        """Return the parameter dict from *mode*'s widget, or {} for unknown modes."""
        mode_obj = self._modes.get(mode)
        if mode_obj is None:
            return {}
        return mode_obj.collect_params()