singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,3187 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Clustering Widget for SingleBehavior Lab.
|
|
3
|
+
Integrates preprocessing and clustering (UMAP, Leiden, HDBSCAN) of behaviorome embeddings.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
import numpy as np
|
|
12
|
+
import umap
|
|
13
|
+
import leidenalg as la
|
|
14
|
+
import igraph as ig
|
|
15
|
+
from sklearn.neighbors import kneighbors_graph
|
|
16
|
+
from sklearn.preprocessing import StandardScaler, MinMaxScaler, Normalizer
|
|
17
|
+
from sklearn.decomposition import PCA
|
|
18
|
+
import hdbscan
|
|
19
|
+
import plotly.express as px
|
|
20
|
+
import plotly.graph_objects as go
|
|
21
|
+
from plotly.subplots import make_subplots
|
|
22
|
+
import plotly.io as pio
|
|
23
|
+
from datetime import datetime
|
|
24
|
+
import pickle
|
|
25
|
+
|
|
26
|
+
from PyQt6.QtWidgets import (
|
|
27
|
+
QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel,
|
|
28
|
+
QComboBox, QSlider, QCheckBox, QGroupBox, QScrollArea, QSplitter,
|
|
29
|
+
QMessageBox, QListWidget, QTextEdit, QFileDialog, QProgressBar, QDialog,
|
|
30
|
+
QSizePolicy, QDialogButtonBox
|
|
31
|
+
)
|
|
32
|
+
from PyQt6.QtCore import Qt, QThread, pyqtSignal
|
|
33
|
+
from PyQt6.QtGui import QFont
|
|
34
|
+
|
|
35
|
+
from .plot_integration import PlotlyWidget
|
|
36
|
+
from .qt_helpers import create_status_label, update_status_label, create_section
|
|
37
|
+
|
|
38
|
+
class ClusteringWorker(QThread):
    """Worker thread that runs clustering off the GUI thread.

    Delegates the actual computation to the owning widget's
    ``perform_clustering`` and reports the outcome via ``finished``.
    """

    # status message, figure (figure is None on failure)
    finished = pyqtSignal(str, object)

    def __init__(self, clustering_widget, params, data):
        """Remember the widget, the clustering kwargs and the input data."""
        super().__init__()
        self.clustering_widget = clustering_widget
        self.params = params
        self.data = data

    def run(self):
        """Run clustering in the background and emit the result or an error."""
        try:
            status_message, figure = self.clustering_widget.perform_clustering(
                self.data, **self.params
            )
            self.finished.emit(status_message, figure)
        except Exception as exc:
            logger.error("Error running clustering: %s", exc, exc_info=True)
            self.finished.emit(f"Error: {str(exc)}", None)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class LoadDataWorker(QThread):
    """Background loader for large embedding matrices."""

    # matrix_df, metadata_df, metadata_path
    loaded = pyqtSignal(object, object, str)
    error = pyqtSignal(str)

    def __init__(self, matrix_path: str, metadata_path: str | None):
        """Store the matrix path and the (optional) metadata path."""
        super().__init__()
        self.matrix_path = matrix_path
        self.metadata_path = metadata_path

    def run(self):
        """Load the matrix (and metadata if present) off the GUI thread."""
        try:
            frame = self._load_matrix(self.matrix_path)
            meta_frame = None
            resolved_meta_path = None
            if self.metadata_path and os.path.exists(self.metadata_path):
                meta_frame = self._load_metadata(self.metadata_path)
                resolved_meta_path = self.metadata_path
            self.loaded.emit(frame, meta_frame, resolved_meta_path)
        except Exception as exc:
            logger.error("Error loading data: %s", exc, exc_info=True)
            self.error.emit(str(exc))

    def _load_matrix(self, path: str) -> pd.DataFrame:
        """Read the behaviorome matrix from an .npz, .parquet or .csv file."""
        if path.endswith(".npz"):
            with np.load(path, allow_pickle=True) as archive:
                values = archive["matrix"]
                row_names = archive["feature_names"]
                # Backward compatibility: older exports stored "span_ids".
                if "snippet_ids" in archive:
                    col_names = archive["snippet_ids"]
                else:
                    col_names = archive.get("span_ids", None)
                if col_names is None:
                    # Fallback: generate snippet IDs
                    col_names = np.array([f'snippet{i+1}' for i in range(values.shape[1])])
            return pd.DataFrame(values, index=row_names, columns=col_names)
        if path.endswith(".parquet"):
            return pd.read_parquet(path)
        # default csv
        return pd.read_csv(path, index_col=0)

    def _load_metadata(self, path: str) -> pd.DataFrame:
        """Read the per-snippet metadata table (.npz, .parquet or .csv)."""
        if path.endswith(".npz"):
            with np.load(path, allow_pickle=True) as archive:
                frame = pd.DataFrame(archive["metadata"], columns=list(archive["columns"]))
            return frame
        if path.endswith(".parquet"):
            return pd.read_parquet(path)
        return pd.read_csv(path)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class ClusterExportDialog(QDialog):
    """Dialog shown after cluster export with option to load the new dataset."""

    def __init__(self, parent, message: str, matrix_path: str, metadata_path: str):
        """Store the exported file paths and build the message/button UI."""
        super().__init__(parent)
        self.matrix_path = matrix_path
        self.metadata_path = metadata_path
        # Flipped to True when the user presses "Load dataset".
        self.load_requested = False

        self.setWindowTitle("Cluster export complete")
        self.setMinimumWidth(500)

        root = QVBoxLayout(self)

        # Word-wrapped export summary.
        summary = QLabel(message)
        summary.setWordWrap(True)
        root.addWidget(summary)

        # "Load dataset" on the left, "OK" pushed to the right edge.
        buttons = QHBoxLayout()
        load_button = QPushButton("Load dataset")
        load_button.clicked.connect(self._on_load_clicked)
        ok_button = QPushButton("OK")
        ok_button.clicked.connect(self.accept)
        buttons.addWidget(load_button)
        buttons.addStretch()
        buttons.addWidget(ok_button)
        root.addLayout(buttons)

    def _on_load_clicked(self):
        """Record the load request and close the dialog."""
        self.load_requested = True
        self.accept()
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class ClusteringWidget(QWidget):
|
|
146
|
+
"""Widget for clustering behaviorome embeddings."""
|
|
147
|
+
|
|
148
|
+
def __init__(self, config: dict):
|
|
149
|
+
super().__init__()
|
|
150
|
+
self.config = config
|
|
151
|
+
self.matrix_data = None
|
|
152
|
+
self.metadata = None
|
|
153
|
+
self.metadata_file_path = None # Store path to metadata file for updates
|
|
154
|
+
self.processed_data = None
|
|
155
|
+
self.embedding = None
|
|
156
|
+
self.clusters = None
|
|
157
|
+
self.current_fig = None
|
|
158
|
+
self.current_df = None
|
|
159
|
+
self.snippet_to_clip_map = {} # Map snippet_id -> clip_path
|
|
160
|
+
|
|
161
|
+
# Preprocessing state
|
|
162
|
+
self.selected_features = None
|
|
163
|
+
|
|
164
|
+
self._setup_ui()
|
|
165
|
+
|
|
166
|
+
def update_config(self, config: dict):
|
|
167
|
+
"""Update configuration."""
|
|
168
|
+
self.config = config
|
|
169
|
+
|
|
170
|
+
    def _setup_ui(self):
        """Build the three-pane UI: settings | plot | plot settings.

        Left pane (scrollable): data loading, preprocessing, UMAP/clustering
        parameters, metadata management and cluster export.  Middle pane:
        the Plotly plot plus a one-line status label.  Right pane
        (scrollable): plot styling, export buttons and analysis-state
        save/load.  Every widget the rest of the class interacts with is
        stored as an attribute on ``self``.
        """
        self.main_layout = QVBoxLayout(self)

        # Splitter: Settings on left, Plot on right
        splitter = QSplitter(Qt.Orientation.Horizontal)

        # Left Panel: Settings (Scrollable)
        settings_scroll = QScrollArea()
        settings_scroll.setWidgetResizable(True)
        settings_scroll.setMinimumWidth(300)
        settings_scroll.setMaximumWidth(350)
        settings_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)
        settings_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)

        settings_widget = QWidget()
        settings_layout = QVBoxLayout(settings_widget)
        settings_layout.setContentsMargins(5, 5, 5, 5)
        settings_layout.setSpacing(5)

        # 1. Data Loading Section
        data_group = QGroupBox("Data loading")
        data_layout = QVBoxLayout()
        data_layout.setSpacing(5)

        self.load_status_label = QLabel("No data loaded")
        self.load_status_label.setWordWrap(True)
        self.load_status_label.setTextFormat(Qt.TextFormat.PlainText)
        self.load_status_label.setAlignment(Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft)
        # Set size policy to prevent expansion
        self.load_status_label.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum)
        data_layout.addWidget(self.load_status_label)

        self.load_btn = QPushButton("Load from registration")
        self.load_btn.clicked.connect(self.load_data)
        data_layout.addWidget(self.load_btn)

        self.load_file_btn = QPushButton("Load external matrix...")
        self.load_file_btn.clicked.connect(self.load_external_data)
        data_layout.addWidget(self.load_file_btn)

        # Indeterminate progress bar shown while LoadDataWorker runs.
        self.load_progress = QProgressBar()
        self.load_progress.setVisible(False)
        self.load_progress.setRange(0, 0)  # indeterminate
        data_layout.addWidget(self.load_progress)

        data_group.setLayout(data_layout)
        settings_layout.addWidget(data_group)

        # 2. Preprocessing Section (Collapsible-ish via GroupBox)
        preprocess_group = QGroupBox("Preprocessing")
        preprocess_layout = QVBoxLayout()

        # Normalization
        norm_row = QHBoxLayout()
        norm_row.addWidget(QLabel("Normalization:"))
        self.normalization_method = QComboBox()
        self.normalization_method.addItems(["none", "l2", "standard", "minmax"])
        self.normalization_method.setCurrentText("none")  # Default for embeddings
        self.normalization_method.setToolTip(
            "None: Embeddings usually normalized\n"
            "L2: Unit norm (good for embeddings)\n"
            "Standard: Zero mean, unit var\n"
            "MinMax: [0,1] range"
        )
        norm_row.addWidget(self.normalization_method)
        preprocess_layout.addLayout(norm_row)

        self.preprocess_btn = QPushButton("Apply preprocessing")
        self.preprocess_btn.clicked.connect(self.apply_preprocessing)
        preprocess_layout.addWidget(self.preprocess_btn)

        self.preprocess_status = QLabel("Ready")
        self.preprocess_status.setWordWrap(True)
        self.preprocess_status.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum)
        preprocess_layout.addWidget(self.preprocess_status)

        preprocess_group.setLayout(preprocess_layout)
        settings_layout.addWidget(preprocess_group)

        # 3. Clustering Parameters
        cluster_params_group = QGroupBox("Clustering & projection")
        cluster_params_layout = QVBoxLayout()

        cluster_params_layout.addWidget(QLabel("<b>Dimensionality Reduction (UMAP)</b>"))

        # Sliders come from _create_slider as (slider, value-label) pairs;
        # _slider_widget wraps each pair in a captioned row.
        self.n_neighbors, self.n_neighbors_lbl = self._create_slider(
            "Neighbors:", 2, 200, 30, 1
        )
        cluster_params_layout.addWidget(self._slider_widget(self.n_neighbors, self.n_neighbors_lbl))

        self.min_dist, self.min_dist_lbl = self._create_slider(
            "Min Dist:", 0.0, 0.99, 0.1, 0.01, is_float=True
        )
        cluster_params_layout.addWidget(self._slider_widget(self.min_dist, self.min_dist_lbl))

        self.n_components, self.n_components_lbl = self._create_slider(
            "Components:", 2, 3, 2, 1
        )
        cluster_params_layout.addWidget(self._slider_widget(self.n_components, self.n_components_lbl))

        # Clustering Method
        cluster_params_layout.addWidget(QLabel("<b>Clustering Method</b>"))
        self.clustering_method = QComboBox()
        self.clustering_method.addItems(['leiden', 'hdbscan'])  # Only these two + UMAP visualization
        self.clustering_method.currentIndexChanged.connect(self._toggle_clustering_params)
        cluster_params_layout.addWidget(self.clustering_method)

        # Leiden Params
        self.leiden_container = QWidget()
        leiden_layout = QVBoxLayout(self.leiden_container)
        leiden_layout.setContentsMargins(0,0,0,0)

        self.leiden_resolution, self.leiden_res_lbl = self._create_slider(
            "Resolution:", 0.1, 5.0, 1.0, 0.1, is_float=True
        )
        leiden_layout.addWidget(self._slider_widget(self.leiden_resolution, self.leiden_res_lbl))

        self.leiden_k, self.leiden_k_lbl = self._create_slider(
            "K-Neighbors:", 2, 100, 15, 1
        )
        leiden_layout.addWidget(self._slider_widget(self.leiden_k, self.leiden_k_lbl))

        cluster_params_layout.addWidget(self.leiden_container)

        # HDBSCAN Params
        self.hdbscan_container = QWidget()
        hdbscan_layout = QVBoxLayout(self.hdbscan_container)
        hdbscan_layout.setContentsMargins(0,0,0,0)

        self.min_cluster_size, self.min_cluster_size_lbl = self._create_slider(
            "Min Cluster Size:", 2, 100, 5, 1
        )
        hdbscan_layout.addWidget(self._slider_widget(self.min_cluster_size, self.min_cluster_size_lbl))

        self.min_samples, self.min_samples_lbl = self._create_slider(
            "Min Samples:", 1, 50, 1, 1
        )
        hdbscan_layout.addWidget(self._slider_widget(self.min_samples, self.min_samples_lbl))

        self.hdbscan_epsilon, self.hdbscan_eps_lbl = self._create_slider(
            "Epsilon:", 0.0, 5.0, 0.0, 0.1, is_float=True
        )
        hdbscan_layout.addWidget(self._slider_widget(self.hdbscan_epsilon, self.hdbscan_eps_lbl))

        cluster_params_layout.addWidget(self.hdbscan_container)
        self.hdbscan_container.hide()  # Hide initially

        # Run Button
        self.run_btn = QPushButton("Run analysis")
        self.run_btn.setStyleSheet("background-color: #007bff; color: white; font-weight: bold; padding: 5px;")
        self.run_btn.clicked.connect(self.run_clustering)
        cluster_params_layout.addWidget(self.run_btn)

        cluster_params_group.setLayout(cluster_params_layout)
        settings_layout.addWidget(cluster_params_group)

        # 4. Metadata Management
        metadata_group = QGroupBox("Metadata")
        metadata_layout = QVBoxLayout()
        self.manage_metadata_btn = QPushButton("Manage metadata")
        self.manage_metadata_btn.clicked.connect(self.open_metadata_manager)
        metadata_layout.addWidget(self.manage_metadata_btn)
        metadata_group.setLayout(metadata_layout)
        settings_layout.addWidget(metadata_group)

        # 5. Cluster Export
        cluster_export_group = QGroupBox("Cluster export")
        cluster_export_layout = QVBoxLayout()

        # Help button with explanation
        help_row = QHBoxLayout()
        help_row.addWidget(QLabel("Select clusters:"))
        self.cluster_export_help_btn = QPushButton("?")
        self.cluster_export_help_btn.setMaximumWidth(30)
        self.cluster_export_help_btn.setToolTip("Click for help")
        self.cluster_export_help_btn.clicked.connect(self._show_cluster_export_help)
        help_row.addWidget(self.cluster_export_help_btn)
        help_row.addStretch()
        cluster_export_layout.addLayout(help_row)

        self.cluster_export_list = QListWidget()
        self.cluster_export_list.setSelectionMode(QListWidget.SelectionMode.ExtendedSelection)
        cluster_export_layout.addWidget(self.cluster_export_list)

        self.use_raw_data_checkbox = QCheckBox("Use raw data")
        self.use_raw_data_checkbox.setToolTip("If checked, exports raw (unnormalized) data. Otherwise exports preprocessed data.")
        cluster_export_layout.addWidget(self.use_raw_data_checkbox)

        # Export buttons start disabled; presumably enabled once clustering
        # has produced clusters to export — TODO confirm against the enabling code.
        self.extract_cluster_btn = QPushButton("Extract selected cluster")
        self.extract_cluster_btn.clicked.connect(self._extract_cluster)
        self.extract_cluster_btn.setEnabled(False)
        cluster_export_layout.addWidget(self.extract_cluster_btn)

        self.exclude_clusters_btn = QPushButton("Exclude selected clusters")
        self.exclude_clusters_btn.clicked.connect(self._exclude_clusters)
        self.exclude_clusters_btn.setEnabled(False)
        cluster_export_layout.addWidget(self.exclude_clusters_btn)

        cluster_export_group.setLayout(cluster_export_layout)
        settings_layout.addWidget(cluster_export_group)

        settings_layout.addStretch()
        settings_widget.setLayout(settings_layout)
        settings_scroll.setWidget(settings_widget)

        splitter.addWidget(settings_scroll)

        # Middle Panel: Plot
        plot_container = QWidget()
        plot_layout = QVBoxLayout(plot_container)
        plot_layout.setContentsMargins(0, 0, 0, 0)

        self.plot_widget = PlotlyWidget()
        # Set up click callback for snippet selection
        self.plot_widget.set_click_callback(self._on_umap_point_clicked)
        plot_layout.addWidget(self.plot_widget)

        # Status label (minimal, at bottom)
        self.status_label = QLabel("Ready")
        self.status_label.setMaximumHeight(20)
        self.status_label.setWordWrap(True)
        self.status_label.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Maximum)
        plot_layout.addWidget(self.status_label)

        splitter.addWidget(plot_container)

        # Right Panel: Plot Settings
        plot_settings_scroll = QScrollArea()
        plot_settings_scroll.setWidgetResizable(True)
        plot_settings_scroll.setMinimumWidth(300)
        plot_settings_scroll.setMaximumWidth(350)
        plot_settings_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)
        plot_settings_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)

        plot_settings_widget = QWidget()
        plot_settings_layout = QVBoxLayout(plot_settings_widget)
        plot_settings_layout.setContentsMargins(5, 5, 5, 5)
        plot_settings_layout.setSpacing(5)

        # Plot Settings Group
        plot_settings_group = QGroupBox("Plot settings")
        plot_settings_group_layout = QVBoxLayout()
        plot_settings_group_layout.setSpacing(5)

        # Color Theme Selector
        plot_settings_group_layout.addWidget(QLabel("<b>Color Theme:</b>"))
        self.color_theme_combo = QComboBox()
        self.color_theme_combo.addItems(["plotly", "plotly_white", "plotly_dark", "ggplot2", "seaborn", "simple_white", "none"])
        self.color_theme_combo.setCurrentText("simple_white")
        self.color_theme_combo.currentIndexChanged.connect(self._update_plots_by_metadata)
        plot_settings_group_layout.addWidget(self.color_theme_combo)

        # Metadata Column Selector
        plot_settings_group_layout.addWidget(QLabel("<b>Group by metadata column:</b>"))
        self.metadata_column_combo = QComboBox()
        self.metadata_column_combo.addItem("None (Show all clusters)")
        self.metadata_column_combo.currentTextChanged.connect(self._update_plots_by_metadata)
        plot_settings_group_layout.addWidget(self.metadata_column_combo)

        # Point Size for grouped plots
        point_size_label = QLabel("Point size:")
        self.plot_point_size_slider = QSlider(Qt.Orientation.Horizontal)
        self.plot_point_size_slider.setMinimum(1)
        self.plot_point_size_slider.setMaximum(20)
        self.plot_point_size_slider.setValue(5)
        self.plot_point_size_label = QLabel("5")
        # Update both the numeric label and the plot on every slider change.
        self.plot_point_size_slider.valueChanged.connect(
            lambda v: (self.plot_point_size_label.setText(str(v)), self._update_plots_by_metadata())
        )
        point_size_layout = QHBoxLayout()
        point_size_layout.addWidget(point_size_label)
        point_size_layout.addWidget(self.plot_point_size_slider)
        point_size_layout.addWidget(self.plot_point_size_label)
        plot_settings_group_layout.addLayout(point_size_layout)

        # Plot Type Selector
        plot_settings_group_layout.addWidget(QLabel("<b>Plot type:</b>"))
        self.plot_type_combo = QComboBox()
        self.plot_type_combo.addItems(["UMAP", "Cluster Proportions", "Single Cluster Analysis", "Spatial cluster distribution"])
        self.plot_type_combo.currentTextChanged.connect(self._update_plot_type)
        plot_settings_group_layout.addWidget(self.plot_type_combo)

        # Single Cluster Analysis (only show when relevant)
        self.single_cluster_label = QLabel("<b>Select cluster:</b>")
        self.single_cluster_combo = QComboBox()
        self.single_cluster_combo.addItem("None")
        self.single_cluster_combo.currentTextChanged.connect(self._update_plots_by_metadata)
        plot_settings_group_layout.addWidget(self.single_cluster_label)
        plot_settings_group_layout.addWidget(self.single_cluster_combo)
        self.single_cluster_label.setVisible(False)
        self.single_cluster_combo.setVisible(False)

        # Spatial distribution: Video and Object selectors
        self.spatial_video_label = QLabel("<b>Video:</b>")
        self.spatial_video_combo = QComboBox()
        self.spatial_video_combo.addItem("All")
        self.spatial_video_combo.currentTextChanged.connect(self._on_spatial_video_changed)
        plot_settings_group_layout.addWidget(self.spatial_video_label)
        plot_settings_group_layout.addWidget(self.spatial_video_combo)
        self.spatial_video_label.setVisible(False)
        self.spatial_video_combo.setVisible(False)

        self.spatial_object_label = QLabel("<b>Object:</b>")
        self.spatial_object_combo = QComboBox()
        self.spatial_object_combo.addItem("All")
        self.spatial_object_combo.currentTextChanged.connect(self._update_plots_by_metadata)
        plot_settings_group_layout.addWidget(self.spatial_object_label)
        plot_settings_group_layout.addWidget(self.spatial_object_combo)
        self.spatial_object_label.setVisible(False)
        self.spatial_object_combo.setVisible(False)

        plot_settings_group.setLayout(plot_settings_group_layout)
        plot_settings_layout.addWidget(plot_settings_group)

        # Export Group
        export_group = QGroupBox("Export")
        export_layout = QVBoxLayout()

        self.export_plot_btn = QPushButton("Export plot (PDF/SVG)")
        self.export_plot_btn.clicked.connect(self._export_plot)
        self.export_plot_btn.setEnabled(False)
        export_layout.addWidget(self.export_plot_btn)

        self.export_csv_btn = QPushButton("Export results (CSV)")
        self.export_csv_btn.clicked.connect(self.export_results)
        self.export_csv_btn.setEnabled(False)
        export_layout.addWidget(self.export_csv_btn)

        export_group.setLayout(export_layout)
        plot_settings_layout.addWidget(export_group)

        # Analysis State Group
        state_group = QGroupBox("Analysis state")
        state_layout = QVBoxLayout()

        self.save_state_btn = QPushButton("Save full analysis")
        self.save_state_btn.clicked.connect(self._save_analysis_state)
        state_layout.addWidget(self.save_state_btn)

        self.load_state_btn = QPushButton("Load full analysis")
        self.load_state_btn.clicked.connect(self._load_analysis_state)
        state_layout.addWidget(self.load_state_btn)

        self.file_info_label = QLabel("No analysis loaded")
        self.file_info_label.setWordWrap(True)
        self.file_info_label.setStyleSheet("color: gray; font-style: italic;")
        state_layout.addWidget(self.file_info_label)

        state_group.setLayout(state_layout)
        plot_settings_layout.addWidget(state_group)

        plot_settings_layout.addStretch()
        plot_settings_widget.setLayout(plot_settings_layout)
        plot_settings_scroll.setWidget(plot_settings_widget)

        splitter.addWidget(plot_settings_scroll)
        splitter.setStretchFactor(1, 3)  # Plot takes most space

        self.main_layout.addWidget(splitter)
|
|
530
|
+
|
|
531
|
+
def _create_slider(self, label_text, min_val, max_val, default, step, is_float=False):
|
|
532
|
+
slider = QSlider(Qt.Orientation.Horizontal)
|
|
533
|
+
if is_float:
|
|
534
|
+
slider.setMinimum(int(min_val * 100))
|
|
535
|
+
slider.setMaximum(int(max_val * 100))
|
|
536
|
+
slider.setValue(int(default * 100))
|
|
537
|
+
slider.setSingleStep(int(step * 100))
|
|
538
|
+
value_label = QLabel(f"{default:.2f}")
|
|
539
|
+
slider.valueChanged.connect(lambda v: value_label.setText(f"{v/100:.2f}"))
|
|
540
|
+
else:
|
|
541
|
+
slider.setMinimum(min_val)
|
|
542
|
+
slider.setMaximum(max_val)
|
|
543
|
+
slider.setValue(default)
|
|
544
|
+
slider.setSingleStep(step)
|
|
545
|
+
value_label = QLabel(str(default))
|
|
546
|
+
slider.valueChanged.connect(lambda v: value_label.setText(str(v)))
|
|
547
|
+
slider.label_text = label_text
|
|
548
|
+
return slider, value_label
|
|
549
|
+
|
|
550
|
+
def _slider_widget(self, slider, label):
|
|
551
|
+
widget = QWidget()
|
|
552
|
+
layout = QHBoxLayout()
|
|
553
|
+
layout.setContentsMargins(0, 0, 0, 0)
|
|
554
|
+
layout.addWidget(QLabel(slider.label_text))
|
|
555
|
+
layout.addWidget(slider)
|
|
556
|
+
layout.addWidget(label)
|
|
557
|
+
widget.setLayout(layout)
|
|
558
|
+
return widget
|
|
559
|
+
|
|
560
|
+
def _get_slider_value(self, slider, is_float=False):
|
|
561
|
+
if is_float:
|
|
562
|
+
return slider.value() / 100.0
|
|
563
|
+
return slider.value()
|
|
564
|
+
|
|
565
|
+
def _toggle_clustering_params(self):
|
|
566
|
+
method = self.clustering_method.currentText()
|
|
567
|
+
if method == 'leiden':
|
|
568
|
+
self.leiden_container.show()
|
|
569
|
+
self.hdbscan_container.hide()
|
|
570
|
+
else:
|
|
571
|
+
self.leiden_container.hide()
|
|
572
|
+
self.hdbscan_container.show()
|
|
573
|
+
|
|
574
|
+
def _find_latest_behaviorome(self, directory: str):
|
|
575
|
+
"""Pick the latest behaviorome matrix with preferred extensions."""
|
|
576
|
+
exts = ["npz", "parquet", "csv"]
|
|
577
|
+
candidates = []
|
|
578
|
+
for fname in os.listdir(directory):
|
|
579
|
+
for ext in exts:
|
|
580
|
+
if fname.startswith("behaviorome_") and fname.endswith(f"_matrix.{ext}"):
|
|
581
|
+
candidates.append(os.path.join(directory, fname))
|
|
582
|
+
if not candidates:
|
|
583
|
+
return None, None
|
|
584
|
+
candidates.sort(reverse=True)
|
|
585
|
+
matrix_path = candidates[0]
|
|
586
|
+
base, ext = os.path.splitext(matrix_path)
|
|
587
|
+
# ext includes .npz etc; build metadata preference (NPZ first, then Parquet, then CSV)
|
|
588
|
+
metadata_path = None
|
|
589
|
+
for meta_ext in ["npz", "parquet", "csv"]:
|
|
590
|
+
candidate_meta = base.replace("_matrix", "_metadata") + f".{meta_ext}"
|
|
591
|
+
if os.path.exists(candidate_meta):
|
|
592
|
+
metadata_path = candidate_meta
|
|
593
|
+
break
|
|
594
|
+
return matrix_path, metadata_path
|
|
595
|
+
|
|
596
|
+
def _load_files_async(self, matrix_path: str, metadata_path: str | None):
    """Kick off a background data load, showing an indeterminate progress bar."""
    self.load_status_label.setText("Loading data...")
    self.load_status_label.setStyleSheet("color: black;")
    self.load_progress.setVisible(True)
    self.load_progress.setRange(0, 0)  # (0, 0) puts QProgressBar in busy mode

    worker = LoadDataWorker(matrix_path, metadata_path)
    worker.loaded.connect(self._on_loaded_data)
    worker.error.connect(self._on_load_error)
    # Keep a reference on self so the thread object isn't garbage-collected.
    self.load_worker = worker
    worker.start()
|
|
606
|
+
|
|
607
|
+
def _on_loaded_data(self, matrix_df: pd.DataFrame, metadata_df: pd.DataFrame | None, metadata_path: str | None):
    """Slot: background load finished — store the data and refresh the UI."""
    self.load_progress.setVisible(False)

    # float32 halves the footprint of large matrices; keep the original
    # dtype when the cast is not possible.
    try:
        self.matrix_data = matrix_df.astype(np.float32, copy=False)
    except Exception:
        self.matrix_data = matrix_df

    self.metadata = metadata_df
    self.metadata_file_path = metadata_path
    self.processed_data = None
    self.preprocess_status.setText("Raw data loaded")

    if self.matrix_data is not None:
        shape = (self.matrix_data.shape[0], self.matrix_data.shape[1])
    else:
        shape = (0, 0)

    # Status text with long filenames truncated to fit the container.
    limit = 35

    def _shorten(name: str) -> str:
        return name if len(name) <= limit else name[:limit - 3] + "..."

    parts = [f"Matrix: {shape[0]} x {shape[1]}"]
    if metadata_path:
        parts.append(f"Meta: {_shorten(os.path.basename(metadata_path))}")

    self.load_status_label.setText("\n".join(parts))
    self.load_status_label.setStyleSheet("color: green;")
    self.apply_preprocessing()
|
|
632
|
+
|
|
633
|
+
def _on_load_error(self, msg: str):
    """Slot for the load worker's error signal: hide progress, show a dialog."""
    self.load_progress.setVisible(False)
    QMessageBox.critical(self, "Load Error", f"Failed to load data: {msg}")
|
|
636
|
+
|
|
637
|
+
def load_data(self):
    """Locate and asynchronously load the newest behaviorome matrix of the experiment."""
    exp_dir = self.config.get("experiment_path")
    if not exp_dir:
        QMessageBox.warning(self, "No Experiment", "Please create or load an experiment first.")
        return

    clips_dir = os.path.join(exp_dir, "registered_clips")
    if not os.path.exists(clips_dir):
        QMessageBox.warning(self, "No Data", "No registered clips directory found.")
        return

    matrix_file, meta_file = self._find_latest_behaviorome(clips_dir)
    if not matrix_file:
        QMessageBox.warning(self, "No Data", "No behaviorome matrix files found (npz/parquet/csv).")
        return

    self._load_files_async(matrix_file, meta_file)
|
|
654
|
+
|
|
655
|
+
def load_external_data(self):
    """Prompt for an external feature matrix (npz/parquet/csv) plus optional metadata."""
    matrix_file, _filter = QFileDialog.getOpenFileName(
        self,
        "Open Feature Matrix",
        "",
        "Matrices (*.npz *.parquet *.csv);;All Files (*)"
    )
    if not matrix_file:
        return  # user cancelled

    meta_file, _filter = QFileDialog.getOpenFileName(
        self,
        "Open Metadata",
        os.path.dirname(matrix_file),
        "Metadata (*.parquet *.csv);;All Files (*)"
    )
    # Metadata is optional: an empty string (cancel) becomes None.
    self._load_files_async(matrix_file, meta_file or None)
|
|
673
|
+
|
|
674
|
+
def load_from_registration(self, matrix_path: str, metadata_path: str):
    """Load data from registration tab (NPZ/Parquet paths).

    Thin wrapper that delegates to the format-agnostic ``_load_data_files``.
    """
    self._load_data_files(matrix_path, metadata_path)
|
|
677
|
+
|
|
678
|
+
def _load_csvs(self, matrix_path, metadata_path):
    """Load from CSV files (legacy support).

    Kept for backward compatibility; delegates to ``_load_data_files``,
    which also handles NPZ and Parquet.
    """
    self._load_data_files(matrix_path, metadata_path)
|
|
681
|
+
|
|
682
|
+
def _load_data_files(self, matrix_path: str, metadata_path: str | None = None):
    """Load matrix and metadata files (supports NPZ, Parquet, CSV) synchronously.

    On success stores ``matrix_data`` / ``metadata``, resets preprocessing,
    updates the status label, auto-applies default preprocessing and
    refreshes the metadata-column selectors. Errors are reported via a
    critical dialog rather than raised.
    """
    try:
        # --- Load matrix ---
        if matrix_path.endswith('.npz'):
            npz_data = np.load(matrix_path, allow_pickle=True)  # allow_pickle for string arrays
            matrix = npz_data['matrix']  # features x snippets
            feature_names = npz_data['feature_names']
            # 'span_ids' is the legacy key for 'snippet_ids' (backward compat).
            snippet_ids = npz_data['snippet_ids'] if 'snippet_ids' in npz_data else npz_data.get('span_ids', None)
            if snippet_ids is None:
                # Fallback: synthesize sequential snippet IDs.
                snippet_ids = np.array([f'snippet{i+1}' for i in range(matrix.shape[1])])
            self.matrix_data = pd.DataFrame(matrix, index=feature_names, columns=snippet_ids)
        elif matrix_path.endswith('.parquet'):
            self.matrix_data = pd.read_parquet(matrix_path, engine='pyarrow')
        else:  # CSV
            self.matrix_data = pd.read_csv(matrix_path, index_col=0)

        # --- Load metadata (optional) ---
        if metadata_path:
            if metadata_path.endswith('.npz'):
                npz_meta = np.load(metadata_path, allow_pickle=True)  # allow_pickle for string arrays
                self.metadata = pd.DataFrame(npz_meta['metadata'], columns=npz_meta['columns'])
            elif metadata_path.endswith('.parquet'):
                self.metadata = pd.read_parquet(metadata_path, engine='pyarrow')
            else:  # CSV
                self.metadata = pd.read_csv(metadata_path)
            self.metadata_file_path = metadata_path
        else:
            self.metadata_file_path = None
            self.metadata = None

        # Reset any previous preprocessing result.
        self.processed_data = None
        self.preprocess_status.setText("Raw data loaded")

        # Status text with filenames truncated to fit the container.
        max_filename_len = 35
        filename = os.path.basename(matrix_path)
        if len(filename) > max_filename_len:
            filename = filename[:max_filename_len-3] + "..."

        # BUG FIX: the truncated matrix filename was computed but never
        # displayed — the label previously showed a placeholder instead.
        status_text = f"Matrix: {self.matrix_data.shape[0]} x {self.matrix_data.shape[1]}\nFile: {filename}"
        if metadata_path:
            meta_filename = os.path.basename(metadata_path)
            if len(meta_filename) > max_filename_len:
                meta_filename = meta_filename[:max_filename_len-3] + "..."
            status_text += f"\nMeta: {meta_filename}"

        self.load_status_label.setText(status_text)
        self.load_status_label.setStyleSheet("color: green;")

        # Auto-apply default preprocessing, then refresh dependent UI.
        self.apply_preprocessing()
        self._refresh_metadata_columns()

    except Exception as e:
        QMessageBox.critical(self, "Load Error", f"Failed to load data: {e}")
|
|
744
|
+
|
|
745
|
+
def apply_preprocessing(self):
    """Normalize the raw matrix and cache it as ``processed_data`` (samples x features)."""
    if self.matrix_data is None:
        return

    try:
        # Stored matrix is features x samples (registration widget builds it
        # as pd.DataFrame(feature_matrix.T, index=feature_names,
        # columns=snippet_ids)); sklearn wants samples as rows, so transpose.
        samples = self.matrix_data.copy().T

        # Infinities would poison the scalers; treat them as missing values.
        samples = samples.replace([np.inf, -np.inf], np.nan)

        norm_method = self.normalization_method.currentText()
        scaler_factories = {
            'standard': StandardScaler,
            'minmax': MinMaxScaler,
            'l2': lambda: Normalizer(norm='l2'),
        }
        factory = scaler_factories.get(norm_method)
        normalized = factory().fit_transform(samples) if factory else samples

        # Cache as samples x features with the original labels preserved.
        self.processed_data = pd.DataFrame(normalized, index=samples.index, columns=samples.columns)

        self.preprocess_status.setText(f"Normalized: {norm_method}")
        self.preprocess_status.setStyleSheet("color: green;")

    except Exception as e:
        QMessageBox.critical(self, "Preprocessing Error", f"Error: {e}")
|
|
788
|
+
|
|
789
|
+
def run_clustering(self):
    """Collect UI parameters and launch the clustering worker thread."""
    if self.processed_data is None:
        QMessageBox.warning(self, "No Data", "Please load and preprocess data first.")
        return

    # Lock the button and show progress while the worker runs.
    self.run_btn.setEnabled(False)
    self.run_btn.setText("Running...")
    self.status_label.setText("Computing clustering...")

    value = self._get_slider_value  # local alias for brevity
    params = {
        'n_neighbors': value(self.n_neighbors),
        'min_dist': value(self.min_dist, is_float=True),
        'n_components': value(self.n_components),
        'method': self.clustering_method.currentText(),
        'leiden_resolution': value(self.leiden_resolution, is_float=True),
        'leiden_k': value(self.leiden_k),
        'min_cluster_size': value(self.min_cluster_size),
        'min_samples': value(self.min_samples),
        'hdbscan_epsilon': value(self.hdbscan_epsilon, is_float=True),
    }

    worker = ClusteringWorker(self, params, self.processed_data)
    worker.finished.connect(self.on_clustering_finished)
    self.worker = worker  # keep a reference so the thread isn't GC'd
    worker.start()
|
|
814
|
+
|
|
815
|
+
def perform_clustering(self, data, **params):
    """Run UMAP embedding plus clustering on *data* (samples x features).

    Executed inside the worker thread. Side effects: stores
    ``self.embedding`` and ``self.clusters`` for later export.

    Args:
        data: DataFrame of preprocessed samples; index holds snippet IDs.
        **params: slider/combo values collected by ``run_clustering``.

    Returns:
        Tuple ``(status_message, plotly_figure)``.

    Raises:
        ValueError: if ``params['method']`` is neither 'leiden' nor 'hdbscan'.
    """
    # --- UMAP embedding ---
    reducer = umap.UMAP(
        n_neighbors=params['n_neighbors'],
        min_dist=params['min_dist'],
        n_components=params['n_components'],
        random_state=42  # fixed seed for a reproducible layout
    )
    embedding = reducer.fit_transform(data)
    self.embedding = embedding  # kept for export

    # --- Clustering ---
    method = params['method']
    if method == 'leiden':
        # Leiden operates on a k-NN connectivity graph of the data.
        knn_graph = kneighbors_graph(
            data,
            n_neighbors=params['leiden_k'],
            mode='connectivity',
            include_self=False
        )
        sources, targets = knn_graph.nonzero()
        edges = list(zip(sources.tolist(), targets.tolist()))
        g = ig.Graph(n=data.shape[0], edges=edges, directed=False)

        partition = la.find_partition(
            g,
            la.RBConfigurationVertexPartition,
            resolution_parameter=params['leiden_resolution']
        )
        clusters = np.array(partition.membership)

    elif method == 'hdbscan':
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=params['min_cluster_size'],
            min_samples=params['min_samples'],
            cluster_selection_epsilon=params['hdbscan_epsilon']
        )
        clusters = clusterer.fit_predict(data)
    else:
        # BUG FIX: an unknown method previously fell through both branches
        # and crashed later with a NameError on 'clusters'.
        raise ValueError(f"Unknown clustering method: {method}")

    self.clusters = clusters

    # --- Plotting ---
    df_plot = pd.DataFrame({
        'UMAP1': embedding[:, 0],
        'UMAP2': embedding[:, 1],
        'Cluster': [f'Cluster_{c}' if c >= 0 else 'Noise' for c in clusters],
        'Sample': data.index
    })

    # custom_data carries snippet IDs so click events can map back to clips.
    snippet_ids = data.index.tolist()

    if params['n_components'] == 3:
        df_plot['UMAP3'] = embedding[:, 2]
        fig = px.scatter_3d(
            df_plot, x='UMAP1', y='UMAP2', z='UMAP3',
            color='Cluster', hover_data=['Sample'],
            title=f"UMAP + {method.title()} Clustering",
            custom_data=[snippet_ids]
        )
    else:
        fig = px.scatter(
            df_plot, x='UMAP1', y='UMAP2',
            color='Cluster', hover_data=['Sample'],
            title=f"UMAP + {method.title()} Clustering",
            custom_data=[snippet_ids]
        )

    theme = self._get_plot_theme()
    fig.update_layout(template=theme if theme else None)
    return "Clustering Complete", fig
|
|
888
|
+
|
|
889
|
+
def on_clustering_finished(self, status, fig):
    """Slot: clustering worker done — restore the UI and render the result."""
    self.run_btn.setEnabled(True)
    self.run_btn.setText("Run analysis")
    self.status_label.setText(status)

    # Guard clause: a missing figure means the run failed.
    if not fig:
        QMessageBox.warning(self, "Error", status)
        return

    point_size = self.plot_point_size_slider.value() if hasattr(self, 'plot_point_size_slider') else 5
    fig.update_traces(marker=dict(size=point_size))
    self.current_fig = fig
    self.plot_widget.update_plot(fig)

    # Enable export buttons (they may not exist in every layout).
    if hasattr(self, 'export_plot_btn'):
        self.export_plot_btn.setEnabled(True)
    if hasattr(self, 'export_csv_btn'):
        self.export_csv_btn.setEnabled(True)

    # Post-clustering bookkeeping: snippet->clip map, persisted metadata,
    # and the combo/list widgets that depend on the cluster labels.
    self._build_snippet_to_clip_map()
    self._update_metadata_with_clusters()
    self._refresh_metadata_columns()
    self._refresh_cluster_list()
    self._refresh_cluster_export_list()

    # Re-render grouped plots when a metadata column is active.
    if hasattr(self, 'metadata_column_combo') and self.metadata_column_combo.currentText() != "None (Show all clusters)":
        self._update_plots_by_metadata()
|
|
922
|
+
|
|
923
|
+
def _update_metadata_with_clusters(self):
    """Write cluster assignments into ``self.metadata`` and persist them.

    Merges the current cluster labels (keyed by snippet/span id) into the
    metadata frame, then saves it either to the known metadata file or to
    the newest metadata file discovered in the experiment folder.
    Failures are logged and reported as a non-fatal warning dialog.
    """
    if self.metadata is None or self.clusters is None or self.processed_data is None:
        return

    try:
        # Map clusters to snippet_ids
        # processed_data.index contains snippet_ids (from matrix columns)
        cluster_series = pd.Series(self.clusters, index=self.processed_data.index, name='Cluster')

        # Convert cluster numbers to cluster labels (negative ids = noise)
        cluster_labels = cluster_series.map(lambda c: f'Cluster_{c}' if c >= 0 else 'Noise')

        # Update metadata
        # Check for both 'snippet' and 'span_id' (backward compatibility)
        snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
        if snippet_col:
            # Merge cluster assignments by snippet/span_id
            cluster_df = pd.DataFrame({
                snippet_col: self.processed_data.index,
                'Cluster': cluster_labels
            })

            # Update or add Cluster column
            if 'Cluster' in self.metadata.columns:
                # Remove old Cluster column so the merge doesn't suffix it
                self.metadata = self.metadata.drop(columns=['Cluster'])

            # Merge new cluster assignments
            self.metadata = self.metadata.merge(cluster_df, on=snippet_col, how='left')

            # Save updated metadata back to file
            if self.metadata_file_path:
                # Use the original metadata file path, respecting format
                self._save_metadata_to_file(self.metadata, self.metadata_file_path)
                current_text = self.status_label.text()
                self.status_label.setText(f"{current_text}\nMetadata updated with cluster assignments")
            else:
                # Try to find metadata file in experiment folder
                experiment_path = self.config.get("experiment_path")
                if experiment_path:
                    registered_clips_dir = os.path.join(experiment_path, "registered_clips")
                    if os.path.exists(registered_clips_dir):
                        # Find the latest metadata file (prefer parquet, then npz, then csv)
                        meta_files = [f for f in os.listdir(registered_clips_dir)
                                      if f.startswith("behaviorome_") and "_metadata" in f]
                        if meta_files:
                            # Prefer parquet > npz > csv
                            parquet_files = [f for f in meta_files if f.endswith(".parquet")]
                            # NOTE(review): npz_files is collected but never
                            # used — npz is deliberately skipped for saving
                            # (see comment below); confirm before removing.
                            npz_files = [f for f in meta_files if f.endswith(".npz")]
                            csv_files = [f for f in meta_files if f.endswith(".csv")]

                            chosen_file = None
                            if parquet_files:
                                # Timestamped names: reverse sort = newest first
                                parquet_files.sort(reverse=True)
                                chosen_file = parquet_files[0]
                            elif csv_files:
                                csv_files.sort(reverse=True)
                                chosen_file = csv_files[0]
                            # Skip npz for saving to avoid format issues

                            if chosen_file:
                                metadata_file = os.path.join(registered_clips_dir, chosen_file)
                                self._save_metadata_to_file(self.metadata, metadata_file)
                                self.metadata_file_path = metadata_file
                                current_text = self.status_label.text()
                                self.status_label.setText(f"{current_text}\nMetadata updated with cluster assignments")

    except Exception as e:
        # Non-fatal: cluster labels remain in memory for export even when
        # the file write fails.
        logger.error("Could not update metadata file: %s", e, exc_info=True)
        QMessageBox.warning(self, "Metadata Update Warning",
                            f"Could not update metadata file:\n{str(e)}\n\nCluster assignments are still available for export.")
|
|
995
|
+
|
|
996
|
+
def export_results(self):
    """Export UMAP coordinates, cluster labels and merged metadata to CSV.

    Output lands in ``<experiment>/analysis_results``; the filename encodes
    the active plot type / grouping plus a timestamp.
    """
    if self.processed_data is None or self.clusters is None:
        QMessageBox.warning(self, "No Data", "No results to export.")
        return

    # Determine plot type and grouping so the filename reflects the view.
    plot_type = self.plot_type_combo.currentText() if hasattr(self, 'plot_type_combo') else "UMAP"
    group_name = self.metadata_column_combo.currentText() if hasattr(self, 'metadata_column_combo') else None

    experiment_path = self.config.get("experiment_path")
    if not experiment_path:
        return

    output_dir = os.path.join(experiment_path, "analysis_results")
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Base frame: UMAP coordinates + cluster label per sample.
    df_export = pd.DataFrame({
        'UMAP1': self.embedding[:, 0],
        'UMAP2': self.embedding[:, 1],
        'Cluster': [f'Cluster_{c}' if c >= 0 else 'Noise' for c in self.clusters],
        'Sample': self.processed_data.index
    })

    if self.embedding.shape[1] > 2:
        df_export['UMAP3'] = self.embedding[:, 2]

    # Attach metadata when a snippet/span_id join column exists.
    if self.metadata is not None:
        snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
        if snippet_col and snippet_col in self.metadata.columns:
            df_export = df_export.merge(
                self.metadata,
                left_on='Sample',
                right_on=snippet_col,
                how='left'
            )

    # Choose a filename describing the current plot / grouping.
    if plot_type == "Cluster Proportions" and group_name and group_name != "None (Show all clusters)":
        if group_name in df_export.columns:
            filename = os.path.join(output_dir, f"cluster_proportions_{group_name}_{timestamp}.csv")
        else:
            filename = os.path.join(output_dir, f"cluster_proportions_{timestamp}.csv")
    elif plot_type == "Single Cluster Analysis" and group_name and group_name != "None (Show all clusters)":
        selected_cluster = self.single_cluster_combo.currentText() if hasattr(self, 'single_cluster_combo') and self.single_cluster_combo.currentText() != "None" else None
        if selected_cluster and group_name in df_export.columns:
            filename = os.path.join(output_dir, f"single_cluster_{selected_cluster}_{group_name}_{timestamp}.csv")
        else:
            filename = os.path.join(output_dir, f"single_cluster_analysis_{timestamp}.csv")
    else:
        filename = os.path.join(output_dir, f"clustering_results_{timestamp}.csv")

    df_export.to_csv(filename, index=False)
    # BUG FIX: show the actual output path — the dialog previously displayed
    # a placeholder instead of the written filename.
    QMessageBox.information(self, "Export", f"Results saved to:\n{filename}")
|
|
1054
|
+
|
|
1055
|
+
def open_metadata_manager(self):
    """Open the metadata management dialog and apply any accepted edits."""
    if self.metadata is None:
        QMessageBox.warning(self, "No Data", "Please load metadata first.")
        return

    # Import here to avoid circular imports
    from .metadata_management_widget import MetadataManagementDialog

    dialog = MetadataManagementDialog(self.metadata.copy(), self.metadata_file_path, self.config, self)
    if dialog.exec() != QDialog.DialogCode.Accepted:
        return

    # Pull the (possibly edited) metadata back out of the dialog.
    self.metadata = dialog.get_metadata()
    self.metadata_file_path = dialog.get_metadata_path()

    # Re-run preprocessing so downstream views reflect the changes.
    if self.matrix_data is not None:
        self.apply_preprocessing()

    self._refresh_metadata_columns()

    # Re-render grouped plots when a metadata column is active.
    if hasattr(self, 'metadata_column_combo') and self.metadata_column_combo.currentText() != "None (Show all clusters)":
        self._update_plots_by_metadata()

    QMessageBox.information(self, "Success", "Metadata updated successfully.")
|
|
1083
|
+
|
|
1084
|
+
def _refresh_metadata_columns(self):
    """Repopulate the metadata-column combo, preserving the current choice."""
    if not hasattr(self, 'metadata_column_combo'):
        return

    combo = self.metadata_column_combo
    combo.blockSignals(True)  # avoid re-entrant change handlers while rebuilding
    try:
        previous = combo.currentText()
        combo.clear()
        combo.addItem("None (Show all clusters)")

        if self.metadata is not None:
            # Offer every metadata column; the user decides which to group by.
            cols = list(self.metadata.columns)
            combo.addItems(cols)

            # Restore the previous selection when it still exists.
            if previous in cols:
                idx = combo.findText(previous)
                if idx >= 0:
                    combo.setCurrentIndex(idx)
    finally:
        combo.blockSignals(False)
|
|
1109
|
+
|
|
1110
|
+
def _refresh_cluster_list(self):
    """Repopulate the single-cluster combo from the current cluster ids."""
    if not hasattr(self, 'single_cluster_combo') or self.clusters is None:
        return

    combo = self.single_cluster_combo
    combo.blockSignals(True)  # avoid re-entrant handlers while rebuilding
    try:
        previous = combo.currentText()
        combo.clear()
        combo.addItem("None")

        def _normalize(cid):
            # Map a raw cluster id (int or str) to 'Cluster_N' / 'Noise';
            # None when it cannot be interpreted.
            if isinstance(cid, str):
                low = cid.lower()
                if low.startswith("cluster_"):
                    return cid
                if low == "noise":
                    return "Noise"
                # fall through: may be a numeric string
            try:
                num = int(cid)
            except Exception:
                return None
            return f"Cluster_{num}" if num >= 0 else "Noise"

        labels = {lbl for lbl in map(_normalize, set(self.clusters)) if lbl}

        def _label_key(lbl):
            # Numeric clusters sort by index; Noise last; oddballs before Noise.
            if lbl == "Noise":
                return (1e9,)
            if lbl.lower().startswith("cluster_"):
                try:
                    return (int(lbl.split("_", 1)[1]),)
                except Exception:
                    return (1e8,)
            return (1e8,)

        for lbl in sorted(labels, key=_label_key):
            combo.addItem(lbl)

        # Restore the prior selection when still available.
        if previous and previous != "None":
            idx = combo.findText(previous)
            if idx >= 0:
                combo.setCurrentIndex(idx)
    finally:
        # Always unblock signals, even on error.
        combo.blockSignals(False)
|
|
1173
|
+
|
|
1174
|
+
def _refresh_spatial_selectors(self):
    """Refresh video and object selectors for spatial distribution plot.

    Rebuilds both combos from metadata (video names from 'group' when
    present, otherwise parsed out of 'video_id' clip filenames; objects
    from 'object_id'), restoring prior selections when still valid.
    """
    if self.metadata is None:
        return

    # Block signals to prevent recursion while the combos are rebuilt
    self.spatial_video_combo.blockSignals(True)
    self.spatial_object_combo.blockSignals(True)

    try:
        # Get current selections so they can be restored afterwards
        current_video = self.spatial_video_combo.currentText()
        current_object = self.spatial_object_combo.currentText()

        # Get unique videos from metadata
        # Prefer 'group' column (contains original video/animal name)
        # Otherwise extract from 'video_id' (clip filenames)
        videos = ["All"]
        video_names = set()

        if 'group' in self.metadata.columns:
            # Use group column which has original video names
            for v in self.metadata['group'].dropna().unique():
                v_str = str(v).strip()
                if v_str:
                    video_names.add(v_str)
        elif 'video_id' in self.metadata.columns:
            # Extract video name from clip filenames
            # Clip format: {video_name}_clip_{clip_idx:06d}_obj{obj_id}.mp4
            import re
            for v in self.metadata['video_id'].dropna().unique():
                v_str = str(v)
                base = os.path.splitext(os.path.basename(v_str))[0]
                # Remove _clip_XXXXXX and _objX suffixes
                match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', base)
                if match:
                    video_name = match.group(1)
                    if video_name:
                        video_names.add(video_name)
                else:
                    # Fallback: use base name if pattern doesn't match
                    if base:
                        video_names.add(base)

        for v in sorted(video_names):
            videos.append(v)

        self.spatial_video_combo.clear()
        for v in videos:
            self.spatial_video_combo.addItem(v)

        # Restore video selection when still present
        if current_video in videos:
            idx = self.spatial_video_combo.findText(current_video)
            if idx >= 0:
                self.spatial_video_combo.setCurrentIndex(idx)

        # Get unique objects from metadata
        object_col = 'object_id' if 'object_id' in self.metadata.columns else None
        objects = ["All"]
        if object_col:
            unique_objects = self.metadata[object_col].dropna().unique()
            for o in sorted(set(str(obj) for obj in unique_objects if str(obj).strip())):
                if o not in objects:
                    objects.append(o)

        self.spatial_object_combo.clear()
        for o in objects:
            self.spatial_object_combo.addItem(o)

        # Restore object selection when still present
        if current_object in objects:
            idx = self.spatial_object_combo.findText(current_object)
            if idx >= 0:
                self.spatial_object_combo.setCurrentIndex(idx)

    finally:
        # Always re-enable signals, even when an error occurred above
        self.spatial_video_combo.blockSignals(False)
        self.spatial_object_combo.blockSignals(False)
|
|
1253
|
+
|
|
1254
|
+
def _on_spatial_video_changed(self):
    """Handle video selection change - refresh object list and update plot.

    Rebuilds the object combo restricted to the selected video (matched via
    'group', or parsed from 'video_id' clip filenames), then re-renders the
    plot. Signals on the object combo are blocked during the rebuild.
    """
    # Refresh object list based on selected video
    if self.metadata is None:
        return

    selected_video = self.spatial_video_combo.currentText()
    object_col = 'object_id' if 'object_id' in self.metadata.columns else None
    video_col = 'video_id' if 'video_id' in self.metadata.columns else None

    self.spatial_object_combo.blockSignals(True)
    try:
        current_object = self.spatial_object_combo.currentText()
        self.spatial_object_combo.clear()
        self.spatial_object_combo.addItem("All")

        if object_col:
            if selected_video == "All":
                # Show all objects across all videos
                unique_objects = self.metadata[object_col].dropna().unique()
            else:
                # Filter objects by selected video
                if 'group' in self.metadata.columns:
                    # Use group column for matching
                    mask = self.metadata['group'].apply(lambda x: str(x).strip() == selected_video)
                elif video_col:
                    # Extract video name from clip filenames and match
                    import re
                    def extract_video_name(clip_name):
                        # Clip format: {video_name}_clip_{idx:06d}_obj{id}.mp4
                        base = os.path.splitext(os.path.basename(str(clip_name)))[0]
                        match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', base)
                        return match.group(1) if match else base
                    mask = self.metadata[video_col].apply(lambda x: extract_video_name(x) == selected_video)
                else:
                    # No way to filter: include every row
                    mask = pd.Series([True] * len(self.metadata), index=self.metadata.index)
                unique_objects = self.metadata.loc[mask, object_col].dropna().unique()

            for o in sorted(set(str(obj) for obj in unique_objects if str(obj).strip())):
                self.spatial_object_combo.addItem(o)

        # Try to restore previous selection
        idx = self.spatial_object_combo.findText(current_object)
        if idx >= 0:
            self.spatial_object_combo.setCurrentIndex(idx)
    finally:
        self.spatial_object_combo.blockSignals(False)

    # Update plot
    self._update_plots_by_metadata()
|
|
1303
|
+
|
|
1304
|
+
def _refresh_cluster_export_list(self):
    """Refresh the cluster list shown in the export list widget.

    Normalizes each raw cluster id in ``self.clusters`` to a display label
    ("Cluster_<n>" for ids >= 0, "Noise" for negative ids or a literal
    "noise" string, strings already starting with "cluster_" kept as-is),
    sorts the labels numerically with "Noise" last, repopulates the list
    widget, and enables the extract/exclude buttons when any label exists.
    """
    if not hasattr(self, 'cluster_export_list') or self.clusters is None:
        return

    self.cluster_export_list.clear()

    def _numeric_label(cid):
        # Shared int-parsing path; was duplicated verbatim for the string
        # and non-string branches below.
        try:
            num = int(cid)
        except Exception:
            return None
        return f"Cluster_{num}" if num >= 0 else "Noise"

    # Get unique cluster labels
    labels = set()
    for cid in set(self.clusters):
        if isinstance(cid, str):
            lc = cid.lower()
            if lc.startswith("cluster_"):
                # Keep already-formatted labels untouched (preserves case).
                label = cid
            elif lc == "noise":
                label = "Noise"
            else:
                label = _numeric_label(cid)
        else:
            label = _numeric_label(cid)

        if label:
            labels.add(label)

    # Sort labels numerically, keep Noise last if present
    def _label_key(lbl):
        if lbl == "Noise":
            return (1e9,)  # sorts after every (0, n) and (2, s) tuple
        if lbl.lower().startswith("cluster_"):
            try:
                return (0, int(lbl.split('_')[1]))
            except Exception:
                return (2, lbl)
        return (2, lbl)

    sorted_labels = sorted(labels, key=_label_key)

    for label in sorted_labels:
        self.cluster_export_list.addItem(label)

    # Enable buttons if clusters are available
    if hasattr(self, 'extract_cluster_btn'):
        self.extract_cluster_btn.setEnabled(len(sorted_labels) > 0)
    if hasattr(self, 'exclude_clusters_btn'):
        self.exclude_clusters_btn.setEnabled(len(sorted_labels) > 0)
def _extract_cluster(self):
    """Extract data for a specific cluster and save as NPZ files.

    Takes the single cluster selected in ``cluster_export_list``, slices the
    raw or processed matrix (per the "use raw data" checkbox) down to that
    cluster's samples plus the matching metadata rows, writes both as
    compressed NPZ files under ``<experiment_path>/analysis_results`` (or the
    CWD), and offers to load the new dataset via ``ClusterExportDialog``.
    Any failure is logged and reported with a critical message box.
    """
    if self.clusters is None:
        QMessageBox.warning(self, "No Data", "No clustering data available. Please perform clustering first.")
        return

    selected_items = self.cluster_export_list.selectedItems()
    if not selected_items:
        QMessageBox.warning(self, "No Selection", "Please select a cluster to extract.")
        return

    # Extraction is single-cluster only; multi-select is the exclude workflow.
    if len(selected_items) > 1:
        QMessageBox.warning(self, "Multiple Selection", "Please select only one cluster for extraction. Use 'Exclude Selected Clusters' for multiple clusters.")
        return

    selected_cluster = selected_items[0].text()
    use_raw_data = self.use_raw_data_checkbox.isChecked()

    try:
        # Get cluster number from string (e.g., "Cluster_0" -> 0, "Noise" -> -1)
        if selected_cluster.lower() == "noise":
            cluster_num = -1
        else:
            cluster_num = int(selected_cluster.split('_')[-1])

        # Get indices of samples in the selected cluster.
        # NOTE(review): assumes entries of self.clusters are numeric here;
        # string cluster ids (handled in _refresh_cluster_export_list) would
        # never compare equal to cluster_num — confirm upstream.
        cluster_indices = [i for i, c in enumerate(self.clusters) if c == cluster_num]

        if not cluster_indices:
            QMessageBox.warning(self, "No Data", f"No samples found in {selected_cluster}")
            return

        # Choose data source based on user preference
        if use_raw_data:
            if self.matrix_data is None:
                QMessageBox.warning(self, "No Data", "Raw data not available.")
                return
            data = self.matrix_data
            data_type = "raw"
        else:
            if self.processed_data is None:
                QMessageBox.warning(self, "No Data", "Processed data not available. Please apply preprocessing first.")
                return
            data = self.processed_data.T  # Transpose back to features x samples format
            data_type = "processed"

        metadata = self.metadata

        if data is None or metadata is None:
            QMessageBox.warning(self, "No Data", "Original data not available. Please ensure data is loaded.")
            return

        # Extract subset of data (columns are samples/snippets)
        subset_data = data.iloc[:, cluster_indices]

        # Get snippet IDs for the selected samples
        snippet_ids = subset_data.columns.tolist()

        # Extract corresponding metadata; prefer an explicit snippet id
        # column, falling back to positional alignment.
        snippet_col = 'snippet' if 'snippet' in metadata.columns else ('span_id' if 'span_id' in metadata.columns else None)
        if snippet_col:
            subset_metadata = metadata[metadata[snippet_col].isin(snippet_ids)].copy()
        else:
            # Fallback: align by index if snippet column not found
            subset_metadata = metadata.iloc[cluster_indices].copy() if len(metadata) == len(self.clusters) else metadata.copy()

        # Create timestamp for unique filenames
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Determine output directory
        experiment_path = self.config.get("experiment_path")
        if experiment_path:
            output_dir = os.path.join(experiment_path, "analysis_results")
        else:
            output_dir = os.getcwd()
        os.makedirs(output_dir, exist_ok=True)

        # Save matrix as NPZ
        matrix_filename = f"matrix_{selected_cluster}_{data_type}_{timestamp}.npz"
        matrix_path = os.path.join(output_dir, matrix_filename)

        feature_names = subset_data.index.tolist()
        matrix_array = subset_data.values  # features x samples

        np.savez_compressed(
            matrix_path,
            matrix=matrix_array,
            feature_names=np.array(feature_names, dtype=object),
            snippet_ids=np.array(snippet_ids, dtype=object),
        )

        # Save metadata as NPZ
        metadata_filename = f"metadata_{selected_cluster}_{timestamp}.npz"
        metadata_path = os.path.join(output_dir, metadata_filename)

        np.savez_compressed(
            metadata_path,
            metadata=subset_metadata.values,
            columns=np.array(subset_metadata.columns, dtype=object),
        )

        msg = (f"Successfully extracted {len(cluster_indices)} samples from {selected_cluster}.\n"
               f"Data type: {data_type}\n"
               f"Saved as:\n- {matrix_filename} (shape: {subset_data.shape})\n"
               f"- {metadata_filename} (shape: {subset_metadata.shape})\n\n"
               f"Click 'Load Dataset' to load this data now, or 'OK' to continue with current data.")

        dialog = ClusterExportDialog(self, msg, matrix_path, metadata_path)
        dialog.exec()

        # Load dataset if user clicked "Load Dataset"
        if dialog.load_requested:
            self._load_files_async(matrix_path, metadata_path)

    except Exception as e:
        error_msg = f"Error extracting cluster data: {str(e)}"
        logger.error("Error extracting cluster data: %s", e, exc_info=True)
        QMessageBox.critical(self, "Error", error_msg)
def _exclude_clusters(self):
    """Export data excluding specific clusters as NPZ files.

    Inverse of :meth:`_extract_cluster`: keeps every sample whose cluster id
    is NOT among the clusters selected in ``cluster_export_list``, slices the
    raw or processed matrix and matching metadata accordingly, saves both as
    compressed NPZ files under ``<experiment_path>/analysis_results`` (or the
    CWD), and offers to load the filtered dataset via ``ClusterExportDialog``.
    Failures are logged and reported with a critical message box.
    """
    if self.clusters is None:
        QMessageBox.warning(self, "No Data", "No clustering data available. Please perform clustering first.")
        return

    selected_items = self.cluster_export_list.selectedItems()
    if not selected_items:
        QMessageBox.warning(self, "No Selection", "Please select at least one cluster to exclude.")
        return

    clusters_to_exclude = [item.text() for item in selected_items]
    use_raw_data = self.use_raw_data_checkbox.isChecked()

    try:
        # Get cluster numbers from strings ("Cluster_3" -> 3, "Noise" -> -1)
        exclude_nums = []
        for cluster_str in clusters_to_exclude:
            if cluster_str.lower() == "noise":
                exclude_nums.append(-1)
            else:
                exclude_nums.append(int(cluster_str.split('_')[-1]))

        # Get indices of samples NOT in the excluded clusters.
        # NOTE(review): assumes self.clusters entries are numeric — string
        # ids would never match exclude_nums and thus always be kept.
        keep_indices = [i for i, c in enumerate(self.clusters) if c not in exclude_nums]

        if not keep_indices:
            QMessageBox.warning(self, "Error", "Excluding these clusters would remove all data!")
            return

        # Choose data source based on user preference
        if use_raw_data:
            if self.matrix_data is None:
                QMessageBox.warning(self, "No Data", "Raw data not available.")
                return
            data = self.matrix_data
            data_type = "raw"
        else:
            if self.processed_data is None:
                QMessageBox.warning(self, "No Data", "Processed data not available. Please apply preprocessing first.")
                return
            data = self.processed_data.T  # Transpose back to features x samples format
            data_type = "processed"

        metadata = self.metadata

        if data is None or metadata is None:
            QMessageBox.warning(self, "No Data", "Original data not available. Please ensure data is loaded.")
            return

        # Extract subset of data (excluding the selected clusters)
        subset_data = data.iloc[:, keep_indices]

        # Get snippet IDs for the kept samples
        snippet_ids = subset_data.columns.tolist()

        # Extract corresponding metadata; prefer an explicit snippet id
        # column, falling back to positional alignment.
        snippet_col = 'snippet' if 'snippet' in metadata.columns else ('span_id' if 'span_id' in metadata.columns else None)
        if snippet_col:
            subset_metadata = metadata[metadata[snippet_col].isin(snippet_ids)].copy()
        else:
            # Fallback: align by index if snippet column not found
            subset_metadata = metadata.iloc[keep_indices].copy() if len(metadata) == len(self.clusters) else metadata.copy()

        # Create timestamp and descriptive filename
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        excluded_str = "_".join([str(num) for num in sorted(exclude_nums)])

        # Determine output directory
        experiment_path = self.config.get("experiment_path")
        if experiment_path:
            output_dir = os.path.join(experiment_path, "analysis_results")
        else:
            output_dir = os.getcwd()
        os.makedirs(output_dir, exist_ok=True)

        # Save matrix as NPZ
        matrix_filename = f"matrix_excluding_clusters_{excluded_str}_{data_type}_{timestamp}.npz"
        matrix_path = os.path.join(output_dir, matrix_filename)

        feature_names = subset_data.index.tolist()
        matrix_array = subset_data.values  # features x samples

        np.savez_compressed(
            matrix_path,
            matrix=matrix_array,
            feature_names=np.array(feature_names, dtype=object),
            snippet_ids=np.array(snippet_ids, dtype=object),
        )

        # Save metadata as NPZ
        metadata_filename = f"metadata_excluding_clusters_{excluded_str}_{timestamp}.npz"
        metadata_path = os.path.join(output_dir, metadata_filename)

        np.savez_compressed(
            metadata_path,
            metadata=subset_metadata.values,
            columns=np.array(subset_metadata.columns, dtype=object),
        )

        # Calculate statistics for the summary dialog
        original_count = len(self.clusters)
        remaining_count = len(keep_indices)
        excluded_count = original_count - remaining_count

        # Get remaining clusters (display labels)
        remaining_clusters = sorted(set(self.clusters[i] for i in keep_indices))
        remaining_labels = []
        for cid in remaining_clusters:
            if isinstance(cid, str):
                remaining_labels.append(cid)
            else:
                remaining_labels.append(f"Cluster_{cid}" if cid >= 0 else "Noise")

        msg = (f"Successfully excluded {len(clusters_to_exclude)} clusters.\n"
               f"Data type: {data_type}\n"
               f"Removed {excluded_count} samples, kept {remaining_count} samples.\n"
               f"Remaining clusters: {remaining_labels}\n\n"
               f"Saved as:\n- {matrix_filename} (shape: {subset_data.shape})\n"
               f"- {metadata_filename} (shape: {subset_metadata.shape})\n\n"
               f"Click 'Load Dataset' to load this data now, or 'OK' to continue with current data.")

        dialog = ClusterExportDialog(self, msg, matrix_path, metadata_path)
        dialog.exec()

        # Load dataset if user clicked "Load Dataset"
        if dialog.load_requested:
            self._load_files_async(matrix_path, metadata_path)

    except Exception as e:
        error_msg = f"Error excluding clusters: {str(e)}"
        logger.error("Error excluding clusters: %s", e, exc_info=True)
        QMessageBox.critical(self, "Error", error_msg)
def _show_cluster_export_help(self):
    """Display an informational dialog describing the cluster export options."""
    # Assembled once so the dialog call stays readable; the text itself is
    # unchanged user-facing copy.
    help_text = (
        "Cluster Export allows you to:\n\n"
        "-Extract Selected Cluster: Export a single cluster for subclustering analysis. "
        "This creates a new dataset containing only samples from the selected cluster, "
        "which you can then load and analyze separately.\n\n"
        "-Exclude Selected Clusters: Export data excluding specific clusters. "
        "This is useful for removing noise or artifacts from your analysis. "
        "You can select multiple clusters to exclude at once.\n\n"
        "-Use raw data: If checked, exports the original (unnormalized) data. "
        "Otherwise, exports the preprocessed data.\n\n"
        "All exports are saved as NPZ files in the analysis_results folder."
    )
    QMessageBox.information(self, "Cluster Export Help", help_text)
def _get_plot_theme(self):
    """Return the plot theme chosen in the UI, or the default.

    Yields ``None`` when the user picked "none", and falls back to
    "simple_white" when the theme combo has not been created yet.
    """
    if not hasattr(self, 'color_theme_combo'):
        return "simple_white"  # Default before the combo exists
    selected = self.color_theme_combo.currentText()
    return None if selected == "none" else selected
def _export_plot(self):
    """Export current plot as PDF or SVG.

    Prompts for a destination via a save-file dialog (defaulting to a
    timestamped PDF in ``<experiment_path>/analysis_results``), infers the
    output format from the chosen extension, verifies the ``kaleido``
    renderer is installed, and writes ``self.current_fig`` with
    ``pio.write_image`` sized to match the on-screen plot widget. Errors are
    logged and surfaced with a critical message box.
    """
    if self.current_fig is None:
        QMessageBox.warning(self, "No Plot", "No plot available to export. Please run clustering first.")
        return

    # Determine default directory
    experiment_path = self.config.get("experiment_path")
    if experiment_path:
        default_dir = os.path.join(experiment_path, "analysis_results")
        os.makedirs(default_dir, exist_ok=True)
    else:
        default_dir = os.getcwd()

    # Create timestamp for default filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    default_filename = os.path.join(default_dir, f"clustering_plot_{timestamp}.pdf")

    # Open file dialog
    file_path, selected_filter = QFileDialog.getSaveFileName(
        self,
        "Export Plot",
        default_filename,
        "PDF Files (*.pdf);;SVG Files (*.svg)"
    )

    if not file_path:
        # User cancelled the dialog.
        return

    try:
        # Determine format from file extension
        if file_path.lower().endswith('.svg'):
            format_type = 'svg'
        elif file_path.lower().endswith('.pdf'):
            format_type = 'pdf'
        else:
            # Default to PDF if extension not recognized
            format_type = 'pdf'
            if not file_path.endswith('.pdf'):
                file_path += '.pdf'

        # Check if kaleido is available (required for PDF/SVG export)
        try:
            import kaleido
        except ImportError:
            QMessageBox.critical(
                self,
                "Export Error",
                "The 'kaleido' package is required for PDF/SVG export.\n\n"
                "Please install it with: pip install kaleido"
            )
            return

        # Get figure dimensions (use on-screen canvas size when possible)
        fig_width = 1200
        fig_height = 800

        # Prefer the plot widget's rendered size to avoid stretching
        if hasattr(self, 'plot_widget') and self.plot_widget is not None:
            try:
                w = self.plot_widget.width()
                h = self.plot_widget.height()
                if w and h:
                    # Clamp to sane minimums in case the widget is tiny.
                    fig_width = max(int(w), 400)
                    fig_height = max(int(h), 300)
            except Exception as e:
                logger.debug("Could not read plot widget dimensions: %s", e)

        # If explicit layout sizes are set on the figure, prefer them
        if hasattr(self.current_fig, 'layout') and self.current_fig.layout:
            if getattr(self.current_fig.layout, 'width', None):
                fig_width = self.current_fig.layout.width
            if getattr(self.current_fig.layout, 'height', None):
                fig_height = self.current_fig.layout.height

        # Export using plotly.io.write_image
        pio.write_image(
            self.current_fig,
            file_path,
            format=format_type,
            width=fig_width,
            height=fig_height
        )

        QMessageBox.information(self, "Success", f"Plot exported successfully to:\n{file_path}")

    except Exception as e:
        error_msg = f"Error exporting plot: {str(e)}"
        logger.error("Error exporting plot: %s", e, exc_info=True)
        QMessageBox.critical(self, "Export Error", error_msg)
def _update_plot_type(self):
    """Re-render the plots after the plot-type selection changes.

    Thin slot that delegates to the shared metadata-driven refresh, which
    reads the current plot type itself.
    """
    self._update_plots_by_metadata()
def _update_plots_by_metadata(self):
    """Update plots based on selected metadata column and plot type.

    Central dispatcher for the plot panel: toggles the visibility of the
    cluster/video/object selectors to match the chosen plot type, then
    branches on that type ("UMAP", "Cluster Proportions", "Single Cluster
    Analysis", "Spatial cluster distribution") to build and display the
    corresponding figure. Requires embedding, clusters, and processed data
    to already exist; validation failures are reported via message boxes or
    the status label and abort the update.
    """
    if self.embedding is None or self.clusters is None or self.processed_data is None:
        return

    plot_type = self.plot_type_combo.currentText()

    # Show/hide single cluster selector based on plot type
    if plot_type in ["Single Cluster Analysis", "Spatial cluster distribution"]:
        self.single_cluster_label.setVisible(True)
        self.single_cluster_combo.setVisible(True)
        # Make sure cluster list is populated
        if self.clusters is not None:
            self._refresh_cluster_list()
    else:
        self.single_cluster_label.setVisible(False)
        self.single_cluster_combo.setVisible(False)

    # Show/hide video and object selectors for spatial distribution
    if plot_type == "Spatial cluster distribution":
        self.spatial_video_label.setVisible(True)
        self.spatial_video_combo.setVisible(True)
        self.spatial_object_label.setVisible(True)
        self.spatial_object_combo.setVisible(True)
        # Populate video/object lists
        self._refresh_spatial_selectors()
    else:
        self.spatial_video_label.setVisible(False)
        self.spatial_video_combo.setVisible(False)
        self.spatial_object_label.setVisible(False)
        self.spatial_object_combo.setVisible(False)

    # Handle "UMAP" plot type
    if plot_type == "UMAP":
        group_name = self.metadata_column_combo.currentText()

        # Handle empty or "None" selection
        if not group_name or group_name.strip() == "" or group_name == "None (Show all clusters)":
            # Regenerate default clustering plot to show all clusters
            self._regenerate_default_plot()
            if hasattr(self, 'export_plot_btn'):
                self.export_plot_btn.setEnabled(True)
            if hasattr(self, 'export_csv_btn'):
                self.export_csv_btn.setEnabled(True)
            return

        if self.metadata is None:
            QMessageBox.warning(self, "No Metadata", "Please load metadata first.")
            return

        # Validate column exists
        if group_name not in self.metadata.columns:
            QMessageBox.warning(self, "Invalid Column", f"Column '{group_name}' not found in metadata.\n\nAvailable columns: {', '.join(self.metadata.columns)}")
            # Reset to "None" selection
            self.metadata_column_combo.setCurrentIndex(0)
            return

        try:
            point_size = self.plot_point_size_slider.value()
            # Only the UMAP figure is used in this branch.
            umap_fig, props_fig, single_fig = self._create_grouped_plots(group_name, point_size, None)

            if umap_fig:
                self.current_fig = umap_fig
                self.plot_widget.update_plot(umap_fig)
                if hasattr(self, 'export_plot_btn'):
                    self.export_plot_btn.setEnabled(True)
                if hasattr(self, 'export_csv_btn'):
                    self.export_csv_btn.setEnabled(True)

        except Exception as e:
            logger.error("Error creating plots: %s", e, exc_info=True)
            QMessageBox.critical(self, "Plot Error", f"Error creating plots: {e}")

    # Handle "Cluster Proportions" plot type
    elif plot_type == "Cluster Proportions":
        group_name = self.metadata_column_combo.currentText()

        if not group_name or group_name.strip() == "" or group_name == "None (Show all clusters)":
            QMessageBox.warning(self, "No Group Selected", "Please select a metadata column to show proportions.")
            return

        if self.metadata is None:
            QMessageBox.warning(self, "No Metadata", "Please load metadata first.")
            return

        if group_name not in self.metadata.columns:
            QMessageBox.warning(self, "Invalid Column", f"Column '{group_name}' not found in metadata.")
            return

        try:
            point_size = self.plot_point_size_slider.value()
            # Only the proportions figure is used in this branch.
            umap_fig, props_fig, single_fig = self._create_grouped_plots(group_name, point_size, None)

            if props_fig:
                self.current_fig = props_fig
                self.plot_widget.update_plot(props_fig)
                if hasattr(self, 'export_plot_btn'):
                    self.export_plot_btn.setEnabled(True)
                if hasattr(self, 'export_csv_btn'):
                    self.export_csv_btn.setEnabled(True)
            else:
                QMessageBox.warning(self, "Plot Error", "Could not create proportions plot.")

        except Exception as e:
            logger.error("Error creating plots: %s", e, exc_info=True)
            QMessageBox.critical(self, "Plot Error", f"Error creating plots: {e}")

    # Handle "Single Cluster Analysis" plot type
    elif plot_type == "Single Cluster Analysis":
        # Check if clusters are available
        if self.clusters is None:
            QMessageBox.warning(self, "No Clusters", "Please run clustering first.")
            return

        # Make sure cluster list is populated
        self._refresh_cluster_list()

        group_name = self.metadata_column_combo.currentText()
        selected_cluster = self.single_cluster_combo.currentText() if self.single_cluster_combo.currentText() != "None" else None

        if not selected_cluster:
            # Don't show warning immediately - let user select first
            # Just clear the plot or show a message
            self.status_label.setText("Please select a cluster from the dropdown above.")
            return

        if not group_name or group_name.strip() == "" or group_name == "None (Show all clusters)":
            QMessageBox.warning(self, "No Group Selected", "Please select a metadata column to show single cluster proportions.")
            return

        if self.metadata is None:
            QMessageBox.warning(self, "No Metadata", "Please load metadata first.")
            return

        if group_name not in self.metadata.columns:
            QMessageBox.warning(self, "Invalid Column", f"Column '{group_name}' not found in metadata.")
            return

        try:
            point_size = self.plot_point_size_slider.value()
            # Only the single-cluster figure is used in this branch.
            umap_fig, props_fig, single_fig = self._create_grouped_plots(group_name, point_size, selected_cluster)

            if single_fig:
                self.current_fig = single_fig
                self.plot_widget.update_plot(single_fig)
                if hasattr(self, 'export_plot_btn'):
                    self.export_plot_btn.setEnabled(True)
                if hasattr(self, 'export_csv_btn'):
                    self.export_csv_btn.setEnabled(True)
            else:
                QMessageBox.warning(self, "Plot Error", "Could not create single cluster plot.")

        except Exception as e:
            logger.error("Error creating plots: %s", e, exc_info=True)
            QMessageBox.critical(self, "Plot Error", f"Error creating plots: {e}")

    # Handle "Spatial cluster distribution" plot type
    elif plot_type == "Spatial cluster distribution":
        # Check if clusters are available
        if self.clusters is None:
            self.status_label.setText("Please run clustering first.")
            return

        # Make sure cluster list is populated
        if hasattr(self, 'single_cluster_combo'):
            self._refresh_cluster_list()
            selected_cluster = self.single_cluster_combo.currentText() if self.single_cluster_combo.currentText() != "None" else None
        else:
            selected_cluster = None

        if not selected_cluster:
            self.status_label.setText("Please select a cluster from the dropdown above to visualize its spatial distribution.")
            return

        try:
            spatial_fig = self._create_spatial_distribution_plot(selected_cluster)

            if spatial_fig:
                self.current_fig = spatial_fig
                self.plot_widget.update_plot(spatial_fig)
                if hasattr(self, 'export_plot_btn'):
                    self.export_plot_btn.setEnabled(True)
                if hasattr(self, 'export_csv_btn'):
                    self.export_csv_btn.setEnabled(True)
                self.status_label.setText(f"Spatial distribution plot for {selected_cluster} displayed.")
            else:
                self.status_label.setText("Could not create spatial distribution plot. Make sure mask data is available in the experiment folder.")

        except Exception as e:
            error_msg = f"Error creating spatial distribution plot: {str(e)}"
            self.status_label.setText(error_msg)
            logger.error("Error creating spatial distribution plot: %s", e, exc_info=True)
            QMessageBox.critical(self, "Plot Error", error_msg)
def _create_grouped_plots(self, group_name, point_size, selected_cluster=None):
    """Create UMAP subplots grouped by metadata column, proportion plot, and single cluster plot.

    Builds three related figures from the current embedding/clustering:
      1. a grid of per-group UMAP scatter subplots plus one "all groups
         combined" panel,
      2. a cluster-proportion bar plot (via ``_create_proportion_plot``),
      3. optionally, a single-cluster proportion plot (via
         ``_create_single_cluster_plot``) when ``selected_cluster`` is given
         and present in the data.

    Args:
        group_name: Metadata column used to split samples into subplots.
        point_size: Marker size for the scatter traces.
        selected_cluster: Optional cluster label (e.g. 'Cluster_3') for the
            third figure; ignored if absent from the data.

    Returns:
        tuple: (umap_fig, props_fig, single_fig); (None, None, None) when the
        group column is missing, cannot be aligned to the samples, or has no
        non-null values.
    """
    # Create base dataframe with UMAP coordinates and clusters.
    # Negative cluster ids are treated as noise (HDBSCAN-style convention).
    df_plot = pd.DataFrame({
        'UMAP1': self.embedding[:, 0],
        'UMAP2': self.embedding[:, 1],
        'Cluster': [f'Cluster_{c}' if c >= 0 else 'Noise' for c in self.clusters],
        'Sample': self.processed_data.index
    })

    # Validate group_name exists in metadata
    if group_name not in self.metadata.columns:
        return None, None, None

    # Merge metadata with processed data.
    # 'snippet' is preferred; 'span_id' is a legacy fallback column name.
    snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)

    # Prepare columns to merge (only what we need)
    cols_to_merge = [group_name]
    if snippet_col and snippet_col in self.metadata.columns:
        cols_to_merge.append(snippet_col)

    if snippet_col and snippet_col in self.metadata.columns:
        # Merge using snippet column
        df_plot = df_plot.merge(
            self.metadata[cols_to_merge],
            left_on='Sample',
            right_on=snippet_col,
            how='left'
        )
    else:
        # Try to align by index - reset index if needed
        metadata_aligned = self.metadata.copy()
        if len(metadata_aligned) == len(df_plot):
            # Align by position
            metadata_aligned.index = df_plot.index
            df_plot = pd.concat([df_plot, metadata_aligned[[group_name]]], axis=1)
        else:
            # Try to merge by resetting index
            metadata_reset = self.metadata.reset_index()
            if 'index' in metadata_reset.columns:
                df_plot = df_plot.merge(metadata_reset[cols_to_merge + ['index']], left_on='Sample', right_on='index', how='left')
            else:
                # Last resort: align by position if lengths match
                if len(metadata_aligned) == len(df_plot):
                    df_plot[group_name] = metadata_aligned[group_name].values
                else:
                    return None, None, None

    # Ensure Cluster column is still present (merge might have dropped it)
    if 'Cluster' not in df_plot.columns:
        df_plot['Cluster'] = [f'Cluster_{c}' if c >= 0 else 'Noise' for c in self.clusters]

    if group_name not in df_plot.columns or df_plot[group_name].isnull().all():
        return None, None, None

    # Filter out NaN values
    df_plot = df_plot.dropna(subset=[group_name])

    # Generate cluster colors.
    # NOTE(review): the map is built from range(n_clusters); if cluster ids
    # are non-contiguous (e.g. only Cluster_0 and Cluster_5 survive the NaN
    # filter) the missing labels fall back to '#CCCCCC' below — confirm this
    # is the intended behavior.
    unique_clusters = sorted(df_plot['Cluster'].unique())
    n_clusters = len(unique_clusters)
    cluster_colors = self._generate_cluster_colors(n_clusters)
    cluster_color_map = {f'Cluster_{i}': cluster_colors[i] for i in range(n_clusters) if f'Cluster_{i}' in unique_clusters}
    if 'Noise' in unique_clusters:
        cluster_color_map['Noise'] = '#808080'

    # Get unique values in the group
    unique_values = sorted(df_plot[group_name].dropna().unique())
    n_groups = len(unique_values)

    if n_groups == 0:
        return None, None, None

    # Create UMAP subplots: 2 columns, enough rows for every group value
    # plus one extra panel for the combined view.
    n_cols = 2
    n_rows = (n_groups + 1 + n_cols - 1) // n_cols  # Add 1 for combined view

    subplot_titles = [f'{group_name}: {val}' for val in unique_values]
    subplot_titles.append(f'All {group_name} Groups Combined')

    umap_fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=subplot_titles,
        horizontal_spacing=0.05,
        vertical_spacing=0.08
    )

    # Calculate global axis ranges so every subplot shares the same square
    # viewport (equal x/y extent, 10% padding).
    umap1_min, umap1_max = df_plot['UMAP1'].min(), df_plot['UMAP1'].max()
    umap2_min, umap2_max = df_plot['UMAP2'].min(), df_plot['UMAP2'].max()
    range1 = umap1_max - umap1_min
    range2 = umap2_max - umap2_min
    max_range = max(range1, range2)
    center1 = (umap1_min + umap1_max) / 2
    center2 = (umap2_min + umap2_max) / 2
    padding = max_range * 0.1
    axis_min1 = center1 - (max_range / 2) - padding
    axis_max1 = center1 + (max_range / 2) + padding
    axis_min2 = center2 - (max_range / 2) - padding
    axis_max2 = center2 + (max_range / 2) + padding

    # Plot individual groups: one subplot per group value, one trace per
    # cluster, colored by cluster. Legend entries only for the first subplot.
    for idx, value in enumerate(unique_values):
        row = idx // n_cols + 1
        col = idx % n_cols + 1

        mask = df_plot[group_name] == value
        df_subset = df_plot[mask]

        for cluster in unique_clusters:
            cluster_data = df_subset[df_subset['Cluster'] == cluster]
            if len(cluster_data) == 0:
                continue

            # Add snippet IDs as customdata for click handling
            snippet_ids = cluster_data['Sample'].tolist()

            umap_fig.add_trace(
                go.Scatter(
                    x=cluster_data['UMAP1'],
                    y=cluster_data['UMAP2'],
                    mode='markers',
                    marker=dict(size=point_size, color=cluster_color_map.get(cluster, '#CCCCCC')),
                    name=cluster,
                    showlegend=(idx == 0),
                    legendgroup=cluster,
                    customdata=[[sid] for sid in snippet_ids],  # Wrap in list for each point
                    hovertemplate='<b>%{hovertext}</b><br>Cluster: ' + cluster + '<br>X: %{x:.2f}<br>Y: %{y:.2f}<extra></extra>',
                    hovertext=snippet_ids
                ),
                row=row, col=col
            )

        umap_fig.update_xaxes(range=[axis_min1, axis_max1], row=row, col=col)
        umap_fig.update_yaxes(range=[axis_min2, axis_max2], row=row, col=col)

    # Combined view goes into the first free grid cell after the per-group
    # panels (row/col derived from n_groups).
    combined_row = (n_groups // n_cols) + 1
    combined_col = (n_groups % n_cols) + 1

    # Generate group colors for combined view (colored by group, not cluster)
    group_colors = self._generate_pastel_palette(n_groups)
    group_color_map = {val: group_colors[i] for i, val in enumerate(unique_values)}

    for value in unique_values:
        mask = df_plot[group_name] == value
        df_subset = df_plot[mask]

        # Add snippet IDs as customdata for click handling
        snippet_ids_combined = df_subset['Sample'].tolist()

        umap_fig.add_trace(
            go.Scatter(
                x=df_subset['UMAP1'],
                y=df_subset['UMAP2'],
                mode='markers',
                marker=dict(size=point_size, color=group_color_map[value]),
                name=f'{group_name}: {value}',
                showlegend=True,
                legendgroup=f'group_{value}',
                customdata=[[sid] for sid in snippet_ids_combined],  # Wrap in list for each point
                hovertemplate='<b>%{hovertext}</b><br>' + group_name + ': ' + str(value) + '<br>X: %{x:.2f}<br>Y: %{y:.2f}<extra></extra>',
                hovertext=snippet_ids_combined
            ),
            row=combined_row, col=combined_col
        )

    umap_fig.update_xaxes(range=[axis_min1, axis_max1], row=combined_row, col=combined_col)
    umap_fig.update_yaxes(range=[axis_min2, axis_max2], row=combined_row, col=combined_col)

    theme = self._get_plot_theme()
    umap_fig.update_layout(
        title=f'UMAP + Leiden Clustering by {group_name}',
        template=theme if theme else None,
        height=300 * n_rows
    )

    # Create proportion plot
    props_fig = self._create_proportion_plot(df_plot, group_name, cluster_color_map)

    # Create single cluster plot if selected and exists in data
    single_fig = None
    if selected_cluster:
        if selected_cluster in df_plot['Cluster'].unique():
            single_fig = self._create_single_cluster_plot(df_plot, group_name, selected_cluster, cluster_color_map)
        else:
            single_fig = None

    return umap_fig, props_fig, single_fig
def _create_proportion_plot(self, df_plot, group_name, cluster_color_map):
    """Create cluster proportion bar plot.

    For every value of ``group_name``, compute the percentage of samples
    that fall into each cluster and render them as side-by-side bars: one
    slot on the x-axis per cluster, one bar (trace) per metadata group.

    Args:
        df_plot: DataFrame with at least the 'Cluster' column and
            ``group_name``.
        group_name: Metadata column whose values define the bar groups.
        cluster_color_map: Accepted for interface compatibility; bar colors
            come from the default plotly cycle (one color per group).

    Returns:
        plotly.graph_objects.Figure: grouped bar chart of proportions.
    """
    # Build a long-format table of per-group cluster percentages.
    records = []
    for value in sorted(df_plot[group_name].dropna().unique()):
        subset = df_plot[df_plot[group_name] == value]
        percentages = subset['Cluster'].value_counts() / len(subset) * 100
        records.extend(
            {group_name: value, 'Cluster': label, 'Proportion (%)': percentages[label]}
            for label in percentages.index
        )

    df_props = pd.DataFrame(records)

    group_labels = sorted(df_props[group_name].unique())
    cluster_labels = sorted(df_props['Cluster'].unique())
    total_groups = len(group_labels)
    total_clusters = len(cluster_labels)
    bar_width = 0.8 / total_groups

    fig = go.Figure()
    for g_idx, group in enumerate(group_labels):
        rows = df_props[df_props[group_name] == group]

        xs = []
        ys = []
        for c_idx, label in enumerate(cluster_labels):
            match = rows[rows['Cluster'] == label]
            # Offset each group's bar within the cluster slot so the groups
            # sit side by side; missing group/cluster combinations plot as 0.
            xs.append(c_idx + (g_idx - total_groups / 2 + 0.5) * bar_width)
            ys.append(match['Proportion (%)'].iloc[0] if len(match) > 0 else 0)

        fig.add_trace(go.Bar(
            x=xs,
            y=ys,
            name=f"{group_name}: {group}",
            width=bar_width
        ))

    theme = self._get_plot_theme()
    fig.update_layout(
        title=f'Cluster Proportions by {group_name}',
        xaxis_title='Cluster',
        yaxis_title='Proportion (%)',
        height=400,
        template=theme if theme else None,
        barmode='group',
        xaxis=dict(
            tickmode='array',
            tickvals=list(range(total_clusters)),
            ticktext=[str(c) for c in cluster_labels]
        )
    )

    return fig
def _create_single_cluster_plot(self, df_plot, group_name, selected_cluster, cluster_color_map):
    """Create single cluster proportion bar plot.

    Shows, for each value of ``group_name``, what percentage of that
    group's samples belong to ``selected_cluster``.

    Args:
        df_plot: DataFrame with at least the 'Cluster' column and
            ``group_name``.
        group_name: Metadata column whose values form the x-axis categories.
        selected_cluster: Cluster label (e.g. 'Cluster_3' or 'Noise').
        cluster_color_map: Mapping of cluster label to hex color; falls back
            to '#CCCCCC' when the label is missing.

    Returns:
        plotly.graph_objects.Figure: bar chart with one bar per group value.
    """
    # (Fixed) removed dead local: the original computed
    # cluster_num = selected_cluster.replace('Cluster_', '') but never used it.
    props = []
    for group_val in sorted(df_plot[group_name].dropna().unique()):
        group_data = df_plot[df_plot[group_name] == group_val]
        total = len(group_data)
        cluster_count = len(group_data[group_data['Cluster'] == selected_cluster])
        # Guard against an empty group to avoid division by zero.
        proportion = (cluster_count / total) * 100 if total > 0 else 0
        props.append({
            group_name: group_val,
            'Proportion (%)': proportion
        })

    df_props = pd.DataFrame(props)

    color = cluster_color_map.get(selected_cluster, '#CCCCCC')

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=df_props[group_name],
        y=df_props['Proportion (%)'],
        marker_color=color,
        name=selected_cluster,
        # Label each bar with its percentage rounded to one decimal.
        text=df_props['Proportion (%)'].round(1).astype(str) + '%',
        textposition='auto'
    ))

    theme = self._get_plot_theme()
    fig.update_layout(
        title=f'Proportion of {selected_cluster} Across {group_name}',
        xaxis_title=group_name,
        yaxis_title='Proportion (%)',
        height=400,
        template=theme if theme else None
    )

    return fig
def _create_spatial_distribution_plot(self, selected_cluster):
    """Create spatial distribution plot showing trajectory with cluster clip segments overlaid.

    Draws every available object trajectory (from mask data) as light-grey
    base lines, then overlays the trajectory segments of clips belonging to
    ``selected_cluster`` in that cluster's color. Respects the video/object
    filter combo boxes when present.

    Args:
        selected_cluster: Cluster label such as 'Cluster_3' or 'Noise'.

    Returns:
        plotly.graph_objects.Figure, or None when prerequisites are missing
        (no metadata/clusters/data, no trajectory data, no snippet column)
        or an unexpected error occurs (logged).
    """
    try:
        if self.metadata is None or self.clusters is None or self.processed_data is None:
            return None

        # Get selected video and object filters; combos may not exist yet
        # (e.g. before the spatial tab is initialized), defaulting to "All".
        selected_video = self.spatial_video_combo.currentText() if hasattr(self, 'spatial_video_combo') else "All"
        selected_object = self.spatial_object_combo.currentText() if hasattr(self, 'spatial_object_combo') else "All"

        # Get full trajectory and clip trajectories from mask data
        trajectory_data = self._get_trajectory_data(selected_video, selected_object)
        if trajectory_data is None:
            return None

        all_trajectories = trajectory_data.get('all_trajectories', [])
        clip_trajectories = trajectory_data.get('clip_trajectories', {})

        if not all_trajectories:
            return None

        # Get theme
        theme = self._get_plot_theme()

        # Create figure
        fig = go.Figure()

        # Plot all trajectories as base layer (grey lines); only the first
        # trace shows in the legend, all share one legend group.
        for traj_idx, traj_info in enumerate(all_trajectories):
            traj = traj_info['trajectory']
            traj_name = traj_info.get('name', f'Trajectory {traj_idx + 1}')

            if len(traj) > 0:
                traj_x = [p[0] for p in traj]
                traj_y = [p[1] for p in traj]

                fig.add_trace(go.Scatter(
                    x=traj_x,
                    y=traj_y,
                    mode='lines',
                    line=dict(
                        color='lightgrey',
                        width=1
                    ),
                    name=traj_name if traj_idx == 0 else None,
                    showlegend=(traj_idx == 0),
                    legendgroup='trajectories',
                    hoverinfo='skip'
                ))

        # Get cluster color: rebuild the label->color map the same way the
        # UMAP plots do, falling back to red if the label is unknown.
        unique_clusters = sorted(set(self.clusters))
        cluster_colors = self._generate_cluster_colors(len(unique_clusters))
        cluster_to_color = {f'Cluster_{c}' if c >= 0 else 'Noise': cluster_colors[i]
                            for i, c in enumerate(unique_clusters)}
        cluster_color = cluster_to_color.get(selected_cluster, '#e74c3c')

        # Get snippets belonging to selected cluster (filtered by video/object)
        snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
        video_col = 'video_id' if 'video_id' in self.metadata.columns else None
        object_col = 'object_id' if 'object_id' in self.metadata.columns else None

        if snippet_col is None:
            return None

        # Build filter mask for metadata
        filter_mask = pd.Series([True] * len(self.metadata), index=self.metadata.index)
        if selected_video != "All":
            if 'group' in self.metadata.columns:
                # Use group column for matching
                filter_mask &= self.metadata['group'].apply(lambda x: str(x).strip() == selected_video)
            elif video_col:
                # Extract video name from clip filenames and match
                import re
                def extract_video_name(clip_name):
                    # Strip directory/extension and the trailing
                    # "_clip_<n>[_obj<m>]" suffix to recover the video name.
                    base = os.path.splitext(os.path.basename(str(clip_name)))[0]
                    match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', base)
                    return match.group(1) if match else base
                filter_mask &= self.metadata[video_col].apply(lambda x: extract_video_name(x) == selected_video)
        if object_col and selected_object != "All":
            filter_mask &= self.metadata[object_col].apply(lambda x: str(x) == selected_object)

        filtered_snippets = set(self.metadata.loc[filter_mask, snippet_col].values)

        # Map snippets to clusters: keep only snippets that pass the filter
        # AND belong to the selected cluster.
        cluster_snippets = []
        for i, snip in enumerate(self.processed_data.index):
            if snip not in filtered_snippets:
                continue
            cluster_label = f'Cluster_{self.clusters[i]}' if self.clusters[i] >= 0 else 'Noise'
            if cluster_label == selected_cluster:
                cluster_snippets.append(snip)

        # Plot clip trajectory segments for selected cluster
        segment_count = 0
        for snippet_id in cluster_snippets:
            if snippet_id in clip_trajectories:
                clip_traj = clip_trajectories[snippet_id]
                if len(clip_traj) >= 1:
                    clip_x = [p[0] for p in clip_traj]
                    clip_y = [p[1] for p in clip_traj]

                    fig.add_trace(go.Scatter(
                        x=clip_x,
                        y=clip_y,
                        mode='lines+markers',
                        line=dict(
                            color=cluster_color,
                            width=3
                        ),
                        marker=dict(
                            size=4,
                            color=cluster_color
                        ),
                        name=selected_cluster if segment_count == 0 else None,
                        showlegend=(segment_count == 0),
                        legendgroup='cluster',
                        hovertext=f'{snippet_id}',
                        hoverinfo='text'
                    ))
                    segment_count += 1

        # Build title with filter info
        filter_info = []
        if selected_video != "All":
            filter_info.append(f"Video: {selected_video}")
        if selected_object != "All":
            filter_info.append(f"Object: {selected_object}")
        filter_str = f" | {', '.join(filter_info)}" if filter_info else ""

        # Update layout: equal-aspect axes with y reversed so the plot uses
        # image coordinates (origin at top-left, as in video frames).
        fig.update_layout(
            title=f'Spatial Distribution: {selected_cluster} ({segment_count} clips){filter_str}',
            xaxis_title='X Position (pixels)',
            yaxis_title='Y Position (pixels)',
            xaxis=dict(scaleanchor="y", scaleratio=1),
            yaxis=dict(autorange='reversed'),
            height=600,
            template=theme if theme else None,
            hovermode='closest'
        )

        return fig
    except Exception as e:
        logger.error("Error creating spatial distribution plot: %s", e, exc_info=True)
        return None
def get_representative_snippets(self, n_samples=10):
    """Find representative snippets for each cluster (closest to centroid).

    Args:
        n_samples: Maximum number of representatives returned per cluster.

    Returns:
        dict: {cluster_label: [snippet_id1, snippet_id2, ...]} with snippet
        ids ordered by increasing distance to the cluster centroid. Noise
        points (cluster id < 0) are skipped. Returns {} when no data or
        clustering is available.
    """
    if self.processed_data is None or self.clusters is None:
        return {}

    # (Fixed) the original imported sklearn's pairwise_distances_argmin_min
    # without using it, and imported euclidean_distances inside the loop;
    # plain NumPy computes the same centroid distances with no extra
    # dependency.
    representative_snippets = {}
    unique_clusters = sorted(set(self.clusters))

    for cluster_id in unique_clusters:
        label = f'Cluster_{cluster_id}' if cluster_id >= 0 else 'Noise'
        if label == 'Noise':
            continue

        # Row positions of the samples assigned to this cluster.
        indices = [i for i, c in enumerate(self.clusters) if c == cluster_id]
        if not indices:
            continue

        # Feature matrix and snippet ids for this cluster's members.
        cluster_features = self.processed_data.iloc[indices].values
        snippet_ids = self.processed_data.iloc[indices].index.tolist()

        # Euclidean distance of every member to the cluster centroid.
        centroid = np.mean(cluster_features, axis=0)
        distances = np.linalg.norm(cluster_features - centroid, axis=1)

        # Keep the n closest members, nearest first.
        n = min(n_samples, len(indices))
        closest_indices = np.argsort(distances)[:n]

        representative_snippets[label] = [snippet_ids[i] for i in closest_indices]

    return representative_snippets
def _get_trajectory_data(self, selected_video="All", selected_object="All"):
    """Get full trajectory and clip trajectory segments from mask data.

    Scans the experiment folder for HDF5 mask files, extracts per-object
    bounding-box centroid trajectories, and maps each metadata snippet's
    frame range onto those centroids.

    Args:
        selected_video: Filter by video name, or "All" for all videos
        selected_object: Filter by object ID, or "All" for all objects

    Returns:
        dict with keys 'all_trajectories' (list of {name, trajectory,
        video, object_id}) and 'clip_trajectories' (snippet_id ->
        list of (cx, cy)), or None when no metadata, no experiment path,
        no mask files, or no trajectories are found.
    """
    try:
        if self.metadata is None:
            return None

        # Get experiment path
        experiment_path = self.config.get("experiment_path")
        if not experiment_path or not os.path.exists(experiment_path):
            return None

        # Try to find mask files in common locations
        possible_mask_dirs = [
            os.path.join(experiment_path, "masks"),
            os.path.join(experiment_path, "segmentation_masks"),
            experiment_path
        ]

        mask_files_found = []
        for mask_dir in possible_mask_dirs:
            if os.path.exists(mask_dir):
                try:
                    for f in os.listdir(mask_dir):
                        if f.endswith(('.h5', '.hdf5')):
                            mask_files_found.append((mask_dir, f))
                except (OSError, PermissionError):
                    # Unreadable directory: skip it, keep scanning the rest.
                    continue

        if not mask_files_found:
            return None

        # Get video names from metadata for matching: prefer the explicit
        # 'group' column, else derive names from clip filenames in 'video_id'.
        metadata_video_names = set()
        if 'group' in self.metadata.columns:
            for v in self.metadata['group'].dropna().unique():
                metadata_video_names.add(str(v).strip())
        elif 'video_id' in self.metadata.columns:
            import re
            for v in self.metadata['video_id'].dropna().unique():
                base = os.path.splitext(os.path.basename(str(v)))[0]
                match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', base)
                if match:
                    metadata_video_names.add(match.group(1))
                else:
                    metadata_video_names.add(base)

        from singlebehaviorlab.backend.video_processor import load_segmentation_data

        all_trajectories = []  # List of {name, trajectory}
        all_frame_centroids = {}  # (video_name, obj_id, frame_idx) -> (cx, cy)

        # Load trajectories from each mask file
        for mask_dir, mask_file in mask_files_found:
            mask_path = os.path.join(mask_dir, mask_file)

            # Extract video name from mask filename
            mask_base = os.path.splitext(mask_file)[0]
            mask_video_name = mask_base.replace('_mask', '').replace('_objects', '').replace('_segmentation', '')

            # Find matching video name from metadata
            matched_video_name = None
            for meta_video_name in metadata_video_names:
                # Check if mask filename contains metadata video name or vice versa
                if meta_video_name in mask_video_name or mask_video_name in meta_video_name:
                    matched_video_name = meta_video_name
                    break

            # If no match found, use mask filename as video name
            if matched_video_name is None:
                matched_video_name = mask_video_name

            # Filter by video if not "All"
            if selected_video != "All" and selected_video != matched_video_name:
                continue

            try:
                mask_data = load_segmentation_data(mask_path)
                frame_objects = mask_data.get('frame_objects', [])

                if not frame_objects:
                    continue

                video_height = mask_data.get('height', 1080)
                video_width = mask_data.get('width', 1920)

                # Get unique object IDs in this mask file
                obj_ids = set()
                for frame_objs in frame_objects:
                    for obj in frame_objs:
                        obj_id = obj.get('obj_id', 0)
                        obj_ids.add(str(obj_id))

                # Process each object
                for obj_id_str in obj_ids:
                    # Filter by object if not "All"
                    if selected_object != "All" and selected_object != obj_id_str:
                        continue

                    trajectory = []

                    for frame_idx, frame_objs in enumerate(frame_objects):
                        for obj in frame_objs:
                            if str(obj.get('obj_id', 0)) == obj_id_str:
                                # Bounding-box centroid; default bbox spans the
                                # whole frame if the object record has none.
                                bbox = obj.get('bbox', (0, 0, video_width, video_height))
                                x_min, y_min, x_max, y_max = bbox
                                cx = (x_min + x_max) / 2.0
                                cy = (y_min + y_max) / 2.0
                                trajectory.append((cx, cy))
                                # Store for clip trajectory lookup using matched video name
                                all_frame_centroids[(matched_video_name, obj_id_str, frame_idx)] = (cx, cy)
                                break

                    if len(trajectory) > 0:
                        traj_name = f"{matched_video_name} obj{obj_id_str}" if obj_id_str != "0" else matched_video_name
                        all_trajectories.append({
                            'name': traj_name,
                            'trajectory': trajectory,
                            'video': matched_video_name,
                            'object_id': obj_id_str
                        })
            except Exception as e:
                # Best effort: a single bad mask file must not abort the scan.
                logger.debug("Could not process trajectory for video: %s", e)
                continue

        if not all_trajectories:
            return None

        # Build clip trajectories for each snippet
        clip_trajectories = {}
        snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
        video_col = 'video_id' if 'video_id' in self.metadata.columns else None
        object_col = 'object_id' if 'object_id' in self.metadata.columns else None

        if snippet_col is None:
            return {'all_trajectories': all_trajectories, 'clip_trajectories': {}}

        for idx, row in self.metadata.iterrows():
            try:
                snippet_id = row[snippet_col]
                start_frame = row.get('start_frame')
                end_frame = row.get('end_frame')

                # Get video name for this snippet (from group or extract from video_id)
                snippet_video_name = None
                if 'group' in self.metadata.columns:
                    snippet_video_name = str(row.get('group', '')).strip()
                elif video_col:
                    import re
                    snippet_video = str(row.get(video_col, ''))
                    base = os.path.splitext(os.path.basename(snippet_video))[0]
                    match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', base)
                    snippet_video_name = match.group(1) if match else base

                snippet_obj = str(row.get(object_col, '0')) if object_col else '0'

                if pd.notna(start_frame) and pd.notna(end_frame):
                    try:
                        start = int(float(start_frame))
                        end = int(float(end_frame))
                    except (ValueError, TypeError):
                        continue
                else:
                    # Fall back to clip_index assuming fixed 16-frame clips.
                    # NOTE(review): the 16-frame clip length is hard-coded —
                    # confirm it matches the clip extraction settings.
                    clip_idx = row.get('clip_index')
                    if pd.notna(clip_idx):
                        try:
                            start = int(clip_idx) * 16
                            end = start + 15
                        except (ValueError, TypeError):
                            continue
                    else:
                        continue

                # Get trajectory segment for this clip
                clip_traj = []
                if snippet_video_name:
                    for f in range(start, end + 1):
                        # Try with object ID first, then without
                        key = (snippet_video_name, snippet_obj, f)
                        if key in all_frame_centroids:
                            clip_traj.append(all_frame_centroids[key])
                        else:
                            key = (snippet_video_name, '0', f)
                            if key in all_frame_centroids:
                                clip_traj.append(all_frame_centroids[key])

                if len(clip_traj) >= 1:
                    clip_trajectories[snippet_id] = clip_traj

            except Exception as e:
                logger.debug("Could not build clip trajectory for snippet: %s", e)
                continue

        return {
            'all_trajectories': all_trajectories,
            'clip_trajectories': clip_trajectories
        }

    except Exception as e:
        # (Fixed) the original logged the copy-pasted message
        # "Error creating spatial distribution plot" here.
        logger.error("Error getting trajectory data: %s", e, exc_info=True)
        return None
def _save_analysis_state(self):
    """Save full analysis state to a pickle file.

    Prompts the user for a destination file (defaulting into the
    experiment's ``analysis_results`` folder when available) and pickles
    every analysis artifact so the session can be restored later by
    ``_load_analysis_state``. Shows a warning when no data is loaded and a
    message box on success or failure.
    """
    if self.matrix_data is None:
        QMessageBox.warning(self, "No Data", "No data to save. Please load data first.")
        return

    try:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Default the save dialog into <experiment>/analysis_results.
        start_dir = ""
        experiment_path = self.config.get("experiment_path")
        if experiment_path:
            start_dir = os.path.join(experiment_path, "analysis_results")
            os.makedirs(start_dir, exist_ok=True)

        suggested = os.path.join(start_dir, f"clustering_analysis_{stamp}.pkl")
        path, _ = QFileDialog.getSaveFileName(
            self, "Save Analysis State", suggested, "Pickle Files (*.pkl)"
        )
        if not path:
            # User cancelled the dialog.
            return

        payload = {
            'matrix_data': self.matrix_data,
            'metadata': self.metadata,
            'processed_data': self.processed_data,
            'embedding': self.embedding,
            'clusters': self.clusters,
            'selected_features': self.selected_features,
            'snippet_to_clip_map': self.snippet_to_clip_map,
            'metadata_file_path': self.metadata_file_path,
            'timestamp': stamp,
            'version': '1.0'
        }

        with open(path, 'wb') as fh:
            pickle.dump(payload, fh)

        QMessageBox.information(self, "Success", f"Analysis saved to:\n{path}")

    except Exception as e:
        logger.error("Failed to save analysis: %s", e, exc_info=True)
        QMessageBox.critical(self, "Error", f"Failed to save analysis: {str(e)}")
def _load_analysis_state(self):
    """Load full analysis state from a pickle file.

    Counterpart of ``_save_analysis_state``: restores every saved
    attribute, rebuilds derived state (snippet→clip map), refreshes the
    UI lists/selectors, and regenerates the plot when an embedding and
    clusters are present. All failures are reported via a message box.

    NOTE(review): ``pickle.load`` executes arbitrary code from the file;
    only load analysis files from trusted sources.
    """
    try:
        # Start the file dialog in <experiment>/analysis_results when available.
        experiment_path = self.config.get("experiment_path")
        initial_dir = ""
        if experiment_path:
            initial_dir = os.path.join(experiment_path, "analysis_results")

        path, _ = QFileDialog.getOpenFileName(
            self, "Load Analysis State", initial_dir, "Pickle Files (*.pkl)"
        )

        if not path:
            # User cancelled the dialog.
            return

        with open(path, 'rb') as f:
            state = pickle.load(f)

        # Restore state — .get() tolerates older files missing newer keys.
        self.matrix_data = state.get('matrix_data')
        self.metadata = state.get('metadata')
        self.processed_data = state.get('processed_data')
        self.embedding = state.get('embedding')
        self.clusters = state.get('clusters')
        self.selected_features = state.get('selected_features')
        self.snippet_to_clip_map = state.get('snippet_to_clip_map', {})
        self.metadata_file_path = state.get('metadata_file_path')

        # Refresh UI status labels.
        self.status_label.setText("Analysis state loaded.")
        self.file_info_label.setText(f"Loaded analysis from: {os.path.basename(path)}")

        # Re-build clip map if missing (older saves may not contain it).
        if not self.snippet_to_clip_map:
            self._build_snippet_to_clip_map()

        # Update UI elements that depend on the restored data.
        self._refresh_metadata_columns()
        self._refresh_cluster_list()
        self._refresh_cluster_export_list()
        self._refresh_spatial_selectors()

        # Enable the Run button only when matrix data came back.
        has_data = self.matrix_data is not None
        self.run_btn.setEnabled(has_data)

        # Update plot if embedding exists.
        if self.embedding is not None and self.clusters is not None:
            # Regenerate default UMAP plot from loaded data.
            self._regenerate_default_plot()

            # Trigger plot update via plot type.
            self._update_plots_by_metadata()

            # Enable export buttons (guarded: these widgets may not exist
            # in every configuration of the tab).
            if hasattr(self, 'export_plot_btn'):
                self.export_plot_btn.setEnabled(True)
            if hasattr(self, 'export_csv_btn'):
                self.export_csv_btn.setEnabled(True)

        QMessageBox.information(self, "Success", "Analysis state loaded successfully.")

    except Exception as e:
        logger.error("Failed to load analysis: %s", e, exc_info=True)
        QMessageBox.critical(self, "Error", f"Failed to load analysis: {str(e)}")
|
|
2730
|
+
|
|
2731
|
+
def _regenerate_default_plot(self):
    """Rebuild the default UMAP scatter plot from a loaded embedding.

    Uses a 3-D scatter when the embedding has at least three components,
    otherwise a 2-D scatter; applies the current theme and point size
    and pushes the figure to the plot widget. No-op without both an
    embedding and cluster labels; plotting errors are logged, not raised.
    """
    if self.embedding is None or self.clusters is None:
        return

    try:
        # Per-point labels: prefer the processed-data index, otherwise
        # fall back to synthetic "snippetN" identifiers.
        if self.processed_data is not None:
            sample_index = self.processed_data.index.tolist()
        else:
            sample_index = [f"snippet{i}" for i in range(len(self.clusters))]

        # HDBSCAN-style convention: negative labels are noise points.
        cluster_labels = [f'Cluster_{c}' if c >= 0 else 'Noise' for c in self.clusters]

        df_plot = pd.DataFrame({
            'UMAP1': self.embedding[:, 0],
            'UMAP2': self.embedding[:, 1],
            'Cluster': cluster_labels,
            'Sample': sample_index,
        })

        # Keyword arguments shared by the 2-D and 3-D variants; custom_data
        # carries the sample ids so click handlers can resolve snippets.
        common_kwargs = dict(
            color='Cluster',
            hover_data=['Sample'],
            title="UMAP Clustering (Loaded)",
            custom_data=[sample_index],
        )

        if self.embedding.shape[1] >= 3:
            df_plot['UMAP3'] = self.embedding[:, 2]
            fig = px.scatter_3d(df_plot, x='UMAP1', y='UMAP2', z='UMAP3', **common_kwargs)
        else:
            fig = px.scatter(df_plot, x='UMAP1', y='UMAP2', **common_kwargs)

        # Apply current theme and user-selected marker size (default 5).
        theme = self._get_plot_theme()
        fig.update_layout(template=theme if theme else None)
        marker_size = self.plot_point_size_slider.value() if hasattr(self, 'plot_point_size_slider') else 5
        fig.update_traces(marker=dict(size=marker_size))

        self.current_fig = fig
        self.plot_widget.update_plot(fig)

    except Exception as e:
        logger.error("Error regenerating plot: %s", e, exc_info=True)
|
|
2775
|
+
|
|
2776
|
+
def _generate_cluster_colors(self, n_clusters):
|
|
2777
|
+
"""Generate colors for clusters."""
|
|
2778
|
+
base_colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12', '#9b59b6', '#1abc9c', '#e67e22', '#34495e']
|
|
2779
|
+
if n_clusters <= len(base_colors):
|
|
2780
|
+
return base_colors[:n_clusters]
|
|
2781
|
+
# Generate more colors if needed
|
|
2782
|
+
import colorsys
|
|
2783
|
+
colors = []
|
|
2784
|
+
for i in range(n_clusters):
|
|
2785
|
+
hue = i / n_clusters
|
|
2786
|
+
rgb = colorsys.hsv_to_rgb(hue, 0.7, 0.9)
|
|
2787
|
+
colors.append(f'#{int(rgb[0]*255):02x}{int(rgb[1]*255):02x}{int(rgb[2]*255):02x}')
|
|
2788
|
+
return colors
|
|
2789
|
+
|
|
2790
|
+
def _generate_pastel_palette(self, n_colors):
|
|
2791
|
+
"""Generate pastel color palette."""
|
|
2792
|
+
base_colors = [
|
|
2793
|
+
'#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA500', '#FF1493',
|
|
2794
|
+
'#32CD32', '#9B59B6', '#FF8C00', '#00CED1', '#DA70D6'
|
|
2795
|
+
]
|
|
2796
|
+
if n_colors <= len(base_colors):
|
|
2797
|
+
return base_colors[:n_colors]
|
|
2798
|
+
# Cycle through colors
|
|
2799
|
+
return [base_colors[i % len(base_colors)] for i in range(n_colors)]
|
|
2800
|
+
|
|
2801
|
+
def _save_metadata_to_file(self, df: pd.DataFrame, path: str):
|
|
2802
|
+
"""Save metadata DataFrame to file, respecting the file format based on extension."""
|
|
2803
|
+
if path.endswith(".npz"):
|
|
2804
|
+
# Save as NPZ format
|
|
2805
|
+
np.savez_compressed(
|
|
2806
|
+
path,
|
|
2807
|
+
metadata=df.values,
|
|
2808
|
+
columns=np.array(df.columns, dtype=object),
|
|
2809
|
+
)
|
|
2810
|
+
elif path.endswith(".parquet"):
|
|
2811
|
+
df.to_parquet(path, index=False)
|
|
2812
|
+
else:
|
|
2813
|
+
# Default to CSV
|
|
2814
|
+
df.to_csv(path, index=False)
|
|
2815
|
+
|
|
2816
|
+
def _build_snippet_to_clip_map(self):
    """Build mapping from snippet IDs to clip file paths.

    Populates ``self.snippet_to_clip_map`` (snippet id → absolute clip
    path) from the clips under ``<experiment>/registered_clips``.
    Matching strategy per metadata row, in order: exact ``video_id``
    filename match, case-insensitive match, filename-pattern match via
    ``clip_index``/``object_id``, and finally positional match of
    "snippetN" ids against the sorted clip list. Silently returns with
    an empty map when metadata, experiment path, clip directory, or a
    usable snippet column is missing.
    """
    self.snippet_to_clip_map = {}

    if self.metadata is None:
        return

    # Get experiment path and registered_clips directory.
    experiment_path = self.config.get("experiment_path")
    if not experiment_path:
        return

    registered_clips_dir = os.path.join(experiment_path, "registered_clips")
    if not os.path.exists(registered_clips_dir):
        return

    # Get snippet column name: prefer 'snippet', fall back to 'span_id'.
    snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
    if snippet_col is None:
        return

    # Find all clip files recursively (clips are in subdirectories:
    # registered_clips/video_name/clip_XXXXXX.mp4).
    import glob
    clip_files = glob.glob(os.path.join(registered_clips_dir, "**", "*.avi"), recursive=True)
    # Also find legacy .mp4 clips for backwards compatibility.
    clip_files += glob.glob(os.path.join(registered_clips_dir, "**", "*.mp4"), recursive=True)
    clip_name_to_path = {}
    for clip_file in clip_files:
        clip_name = os.path.basename(clip_file)
        # Store both the original and a lowercased key so lookups can be
        # case-insensitive without a second dict.
        clip_name_to_path[clip_name] = clip_file
        clip_name_to_path[clip_name.lower()] = clip_file

    # Map snippets to clips using video_id from metadata.
    # The metadata has 'video_id' which contains the clip filename
    # (e.g., 'clip_000000_obj1.mp4').
    for idx, row in self.metadata.iterrows():
        snippet_id = str(row.get(snippet_col, ''))
        if not snippet_id:
            continue

        # Get video_id (clip filename) from metadata.
        video_id = str(row.get('video_id', ''))

        # Try exact match first.
        if video_id and video_id in clip_name_to_path:
            self.snippet_to_clip_map[snippet_id] = clip_name_to_path[video_id]
            continue

        # Try case-insensitive match via the lowercased keys.
        if video_id and video_id.lower() in clip_name_to_path:
            self.snippet_to_clip_map[snippet_id] = clip_name_to_path[video_id.lower()]
            continue

        # Fallback: try to match by clip filename pattern using the
        # zero-padded clip index and (optionally) the object id.
        clip_index = row.get('clip_index', None)
        object_id = str(row.get('object_id', '')) if pd.notna(row.get('object_id')) and row.get('object_id') else None

        if clip_index is not None:
            clip_idx_str = f"{int(clip_index):06d}"
            for clip_name, clip_path in clip_name_to_path.items():
                # Skip the lowercased duplicate keys added above.
                # NOTE(review): this heuristic also skips clips whose real
                # filename is entirely lowercase — confirm clip names
                # always contain an uppercase character or digit pattern
                # that defeats islower(); digits alone make islower()
                # True, e.g. "clip_000001.mp4".islower() is True.
                if clip_name.islower():
                    continue
                if clip_idx_str in clip_name:
                    if object_id:
                        # Per-object clips carry an "_obj<id>" suffix.
                        if f"_obj{object_id}" in clip_name:
                            self.snippet_to_clip_map[snippet_id] = clip_path
                            break
                    else:
                        # Without an object id, only accept clips that are
                        # not object-specific.
                        if "_obj" not in clip_name:
                            self.snippet_to_clip_map[snippet_id] = clip_path
                            break

        # If still not found, try matching by position in metadata:
        # "snippetN" is assumed 1-based against the sorted clip list.
        if snippet_id not in self.snippet_to_clip_map:
            try:
                snippet_num = int(snippet_id.replace('snippet', '')) - 1
                if 0 <= snippet_num < len(clip_files):
                    sorted_clips = sorted(clip_files)
                    self.snippet_to_clip_map[snippet_id] = sorted_clips[snippet_num]
            except (ValueError, IndexError):
                # Non-numeric snippet ids simply stay unmapped.
                pass
|
|
2896
|
+
|
|
2897
|
+
def _on_umap_point_clicked(self, snippet_id: str):
    """Handle click on UMAP point - open video popup for corresponding clip.

    Resolves ``snippet_id`` to a clip file via the snippet→clip map
    (building it lazily), extracts frame-range metadata and the clip FPS
    when available, and opens the video popup. Warns the user when no
    clip file can be found.
    """
    if not snippet_id:
        return

    # Build snippet-to-clip mapping if not already done.
    if not self.snippet_to_clip_map:
        self._build_snippet_to_clip_map()

    # Find clip file; bail out with a warning when missing on disk.
    clip_path = self.snippet_to_clip_map.get(snippet_id)
    if not clip_path or not os.path.exists(clip_path):
        QMessageBox.warning(self, "Clip Not Found",
                            f"Could not find clip file for snippet: {snippet_id}\n\n"
                            f"Please ensure clips are extracted in the Registration tab.")
        return

    # Get metadata for this snippet (frame range + fps for the popup timeline).
    clip_metadata = None
    if self.metadata is not None:
        # Same column preference as _build_snippet_to_clip_map.
        snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
        if snippet_col:
            snippet_row = self.metadata[self.metadata[snippet_col].astype(str) == str(snippet_id)]
            if len(snippet_row) > 0:
                row = snippet_row.iloc[0]
                start_frame = row.get('start_frame')
                end_frame = row.get('end_frame')
                # NOTE(review): video_id is read but never used below —
                # confirm whether it was meant to feed the popup title.
                video_id = row.get('video_id', '')

                # Get FPS from the clip file itself (best effort).
                fps = None
                try:
                    import cv2
                    cap = cv2.VideoCapture(clip_path)
                    if cap.isOpened():
                        fps = cap.get(cv2.CAP_PROP_FPS)
                    cap.release()
                except Exception as e:
                    logger.debug("Could not read clip FPS: %s", e)

                # Only set clip metadata if frame indices are valid numbers.
                try:
                    # Check if values are not NaN/None and not empty strings.
                    start_valid = (pd.notna(start_frame) and str(start_frame).strip() != '')
                    end_valid = (pd.notna(end_frame) and str(end_frame).strip() != '')

                    if start_valid and end_valid:
                        clip_metadata = {
                            # int(float(...)) tolerates "12.0"-style strings.
                            'start_frame': int(float(start_frame)),
                            'end_frame': int(float(end_frame)),
                            'context_frames': 30,  # Default context frames
                            # cv2 returns 0.0 on failure, which is falsy →
                            # fall back to an assumed 30 fps.
                            'fps': fps if fps else 30.0
                        }
                except (ValueError, TypeError):
                    # If conversion fails, skip metadata.
                    clip_metadata = None

    # Open video popup.
    self._open_video_popup(clip_path, snippet_id, clip_metadata)
|
|
2956
|
+
|
|
2957
|
+
def _open_video_popup(self, file_path: str, label: str = None, clip_metadata: dict = None):
    """Open video in popup dialog with timeline indicator (from clustering_behavior).

    Shows a modal player for ``file_path`` with play/pause, loop,
    optional "clip only" playback (restricting play/loop to the
    annotated frame range when ``clip_metadata`` provides
    start_frame/end_frame/fps), and playback-speed buttons. Blocks in
    ``dialog.exec()`` until closed, then disconnects the player signals
    and stops playback.
    """
    from PyQt6.QtWidgets import QDialog, QVBoxLayout, QHBoxLayout, QCheckBox
    from PyQt6.QtMultimedia import QMediaPlayer, QAudioOutput
    from PyQt6.QtMultimediaWidgets import QVideoWidget
    from PyQt6.QtCore import QUrl
    from .plot_integration import TimelineWidget

    dialog = QDialog(self)
    dialog.setWindowTitle(f"Video: {label if label else os.path.basename(file_path)}")
    dialog.setMinimumSize(800, 650)

    layout = QVBoxLayout(dialog)
    layout.setSpacing(5)

    # Video container with timeline stacked directly beneath the video.
    video_container = QWidget()
    video_layout = QVBoxLayout(video_container)
    video_layout.setContentsMargins(0, 0, 0, 0)
    video_layout.setSpacing(0)

    # Expanded video widget.
    expanded_video = QVideoWidget()
    expanded_video.setMinimumSize(800, 600)

    # Calculate clip boundaries in milliseconds (for clip-only playback).
    clip_start_ms = None
    clip_end_ms = None
    if clip_metadata:
        start_frame = clip_metadata.get('start_frame')
        end_frame = clip_metadata.get('end_frame')
        context_frames = clip_metadata.get('context_frames', 30)
        fps = clip_metadata.get('fps')

        if start_frame is not None and end_frame is not None and fps:
            # Calculate clip boundaries in milliseconds.
            # The extracted video starts at context_start = max(0, start_frame - context_frames),
            # so in the extracted video the clip starts at (start_frame - context_start).
            context_start = max(0, start_frame - context_frames)
            clip_start_in_extracted = start_frame - context_start
            clip_end_in_extracted = end_frame - context_start
            clip_start_ms = int((clip_start_in_extracted / fps) * 1000)
            clip_end_ms = int((clip_end_in_extracted / fps) * 1000)

    # Timeline widget showing position and (optionally) the clip region.
    timeline_widget = TimelineWidget(clip_metadata=clip_metadata)
    timeline_widget.setFixedHeight(30)
    timeline_widget.setStyleSheet("""
        QWidget {
            background-color: #2b2b2b;
            border-top: 1px solid #555;
        }
    """)

    # Expanded player wired to the video widget and an audio output.
    expanded_player = QMediaPlayer()
    expanded_audio = QAudioOutput()
    expanded_player.setAudioOutput(expanded_audio)
    expanded_player.setVideoOutput(expanded_video)
    expanded_player.setSource(QUrl.fromLocalFile(os.path.abspath(file_path)))

    # Clip-only mode state (single-element list so the nested closures
    # can mutate it without `nonlocal`).
    clip_only_mode = [False]

    # Enhanced position tracking with clip boundary enforcement.
    def update_timeline_position_with_clip(position):
        try:
            if timeline_widget and timeline_widget.isVisible():
                timeline_widget.set_current_position(position, expanded_player.duration())

            # Enforce clip boundaries when clip-only mode is enabled.
            if clip_only_mode[0] and clip_start_ms is not None and clip_end_ms is not None:
                if position < clip_start_ms:
                    # Jump to clip start if before clip.
                    expanded_player.setPosition(clip_start_ms)
                elif position > clip_end_ms:
                    # If looping is enabled, jump back to clip start.
                    if loop_enabled[0]:
                        expanded_player.setPosition(clip_start_ms)
                    else:
                        # Otherwise pause at clip end.
                        expanded_player.pause()
                        play_btn.setText("▶ Play")
        except RuntimeError:
            # Widget was deleted, ignore.
            pass

    def update_timeline_duration(duration):
        try:
            if timeline_widget and timeline_widget.isVisible():
                timeline_widget.set_duration(duration)
        except RuntimeError:
            # Widget was deleted, ignore.
            pass

    expanded_player.positionChanged.connect(update_timeline_position_with_clip)
    expanded_player.durationChanged.connect(update_timeline_duration)

    # Loop state (same closure-mutation pattern as clip_only_mode).
    loop_enabled = [False]

    def on_playback_state_changed(state):
        # Restart from the appropriate start position when looping.
        if loop_enabled[0] and state == QMediaPlayer.PlaybackState.StoppedState:
            if clip_only_mode[0] and clip_start_ms is not None:
                expanded_player.setPosition(clip_start_ms)
            else:
                expanded_player.setPosition(0)
            expanded_player.play()

    def on_media_status_changed(status):
        # Some backends signal end-of-media here rather than via the
        # playback state, so both handlers implement the loop restart.
        if loop_enabled[0] and status == QMediaPlayer.MediaStatus.EndOfMedia:
            if clip_only_mode[0] and clip_start_ms is not None:
                expanded_player.setPosition(clip_start_ms)
            else:
                expanded_player.setPosition(0)
            expanded_player.play()

    expanded_player.playbackStateChanged.connect(on_playback_state_changed)
    expanded_player.mediaStatusChanged.connect(on_media_status_changed)

    video_layout.addWidget(expanded_video)
    video_layout.addWidget(timeline_widget)

    # Controls row (play / loop / clip-only / speed / close).
    controls = QHBoxLayout()

    def toggle_play():
        if expanded_player.playbackState() == QMediaPlayer.PlaybackState.PlayingState:
            expanded_player.pause()
            play_btn.setText("▶ Play")
        else:
            expanded_player.play()
            play_btn.setText("⏸ Pause")

    play_btn = QPushButton("▶ Play")
    play_btn.clicked.connect(toggle_play)

    def toggle_loop(checked):
        loop_enabled[0] = checked
        if checked:
            loop_btn.setText("🔁 Loop ON")
        else:
            loop_btn.setText("🔁 Loop")

    loop_btn = QPushButton("🔁 Loop")
    loop_btn.setCheckable(True)
    loop_btn.toggled.connect(toggle_loop)

    # Clip-only mode checkbox (only show if clip metadata is available).
    clip_only_chk = None
    if clip_start_ms is not None and clip_end_ms is not None:
        clip_only_chk = QCheckBox("Clip only")
        clip_only_chk.setToolTip("When checked, play and loop only the clip area (without context frames)")

        def toggle_clip_only(checked):
            clip_only_mode[0] = checked
            if checked:
                # Jump to clip start when enabling clip-only mode.
                expanded_player.setPosition(clip_start_ms)
                # Auto-enable loop for better UX in clip-only mode.
                if not loop_enabled[0]:
                    loop_btn.setChecked(True)
            # If unchecked, allow playback from current position (full video).

        clip_only_chk.toggled.connect(toggle_clip_only)

    # Playback speed controls.
    speed_label = QLabel("Speed:")
    speed_1x_btn = QPushButton("1x")
    speed_1x_btn.setCheckable(True)
    speed_1x_btn.setChecked(True)
    speed_05x_btn = QPushButton("0.5x")
    speed_05x_btn.setCheckable(True)
    speed_025x_btn = QPushButton("0.25x")
    speed_025x_btn.setCheckable(True)
    speed_0166x_btn = QPushButton("0.166x")
    speed_0166x_btn.setCheckable(True)

    # Speed button group (mutually exclusive, managed manually in set_speed).
    speed_buttons = [speed_1x_btn, speed_05x_btn, speed_025x_btn, speed_0166x_btn]
    current_speed = [1.0]  # Use list to allow modification in nested functions.

    def set_speed(speed, button):
        current_speed[0] = speed
        expanded_player.setPlaybackRate(speed)
        # Update button states so only the active speed appears checked.
        for btn in speed_buttons:
            btn.setChecked(btn == button)

    speed_1x_btn.clicked.connect(lambda: set_speed(1.0, speed_1x_btn))
    speed_05x_btn.clicked.connect(lambda: set_speed(0.5, speed_05x_btn))
    speed_025x_btn.clicked.connect(lambda: set_speed(0.25, speed_025x_btn))
    speed_0166x_btn.clicked.connect(lambda: set_speed(1.0/6.0, speed_0166x_btn))

    close_btn = QPushButton("Close")
    close_btn.clicked.connect(dialog.close)

    controls.addWidget(play_btn)
    controls.addWidget(loop_btn)
    if clip_only_chk is not None:
        controls.addWidget(clip_only_chk)
    controls.addWidget(speed_label)
    controls.addWidget(speed_1x_btn)
    controls.addWidget(speed_05x_btn)
    controls.addWidget(speed_025x_btn)
    controls.addWidget(speed_0166x_btn)
    controls.addStretch()
    controls.addWidget(close_btn)

    layout.addWidget(video_container)
    layout.addLayout(controls)

    # Auto-play when opened (if clip-only mode is enabled, start at clip start).
    # NOTE(review): clip_only_mode starts as [False], so this branch only
    # fires if a future change pre-enables clip-only mode — confirm intent.
    if clip_only_mode[0] and clip_start_ms is not None:
        expanded_player.setPosition(clip_start_ms)
        expanded_player.play()
        play_btn.setText("⏸ Pause")

    dialog.exec()

    # Cleanup - disconnect signals before stopping to prevent errors.
    try:
        expanded_player.positionChanged.disconnect()
        expanded_player.durationChanged.disconnect()
        expanded_player.playbackStateChanged.disconnect()
        expanded_player.mediaStatusChanged.disconnect()
    except (RuntimeError, TypeError):
        # Signals already disconnected or widget deleted.
        pass

    expanded_player.stop()