singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,3187 @@
1
+ """
2
+ Clustering Widget for SingleBehavior Lab.
3
+ Integrates preprocessing and clustering (UMAP, Leiden, HDBSCAN) of behaviorome embeddings.
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import pandas as pd
9
+
10
+ logger = logging.getLogger(__name__)
11
+ import numpy as np
12
+ import umap
13
+ import leidenalg as la
14
+ import igraph as ig
15
+ from sklearn.neighbors import kneighbors_graph
16
+ from sklearn.preprocessing import StandardScaler, MinMaxScaler, Normalizer
17
+ from sklearn.decomposition import PCA
18
+ import hdbscan
19
+ import plotly.express as px
20
+ import plotly.graph_objects as go
21
+ from plotly.subplots import make_subplots
22
+ import plotly.io as pio
23
+ from datetime import datetime
24
+ import pickle
25
+
26
+ from PyQt6.QtWidgets import (
27
+ QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel,
28
+ QComboBox, QSlider, QCheckBox, QGroupBox, QScrollArea, QSplitter,
29
+ QMessageBox, QListWidget, QTextEdit, QFileDialog, QProgressBar, QDialog,
30
+ QSizePolicy, QDialogButtonBox
31
+ )
32
+ from PyQt6.QtCore import Qt, QThread, pyqtSignal
33
+ from PyQt6.QtGui import QFont
34
+
35
+ from .plot_integration import PlotlyWidget
36
+ from .qt_helpers import create_status_label, update_status_label, create_section
37
+
38
class ClusteringWorker(QThread):
    """Run ``perform_clustering`` of a clustering widget off the GUI thread.

    The outcome is reported through ``finished`` as a status message plus a
    figure object; on failure the figure is ``None`` and the message starts
    with ``"Error: "``.
    """

    # NOTE(review): this name shadows QThread's own no-argument ``finished``
    # signal on this subclass - confirm no caller relies on the built-in one.
    finished = pyqtSignal(str, object)  # status message, figure

    def __init__(self, clustering_widget, params, data):
        super().__init__()
        self.data = data
        self.params = params
        self.clustering_widget = clustering_widget

    def run(self):
        """Execute the clustering call and emit its result (or the error)."""
        try:
            outcome = self.clustering_widget.perform_clustering(self.data, **self.params)
            message, figure = outcome
            self.finished.emit(message, figure)
        except Exception as exc:
            # Thread boundary: log the traceback and hand the message to the GUI.
            logger.error("Error running clustering: %s", exc, exc_info=True)
            self.finished.emit(f"Error: {str(exc)}", None)
56
+
57
+
58
class LoadDataWorker(QThread):
    """Background loader for large embedding matrices.

    Emits ``loaded(matrix_df, metadata_df, metadata_path)`` when both files
    were read, or ``error(message)`` if anything raised while loading.
    The file format is chosen from the (case-insensitive) file extension:
    ``.npz``, ``.parquet``, anything else is read as CSV.
    """
    loaded = pyqtSignal(object, object, str)  # matrix_df, metadata_df, metadata_path
    error = pyqtSignal(str)

    def __init__(self, matrix_path: str, metadata_path: str | None):
        super().__init__()
        self.matrix_path = matrix_path
        self.metadata_path = metadata_path

    def run(self):
        """Load the matrix and, when its file exists, the metadata."""
        try:
            matrix_df = self._load_matrix(self.matrix_path)
            metadata_df = None
            meta_path = self.metadata_path
            if self.metadata_path and os.path.exists(self.metadata_path):
                metadata_df = self._load_metadata(self.metadata_path)
            else:
                # Missing/unset metadata: report "no metadata" instead of a stale path.
                meta_path = None
            # NOTE(review): meta_path may be None although the signal declares
            # ``str`` - confirm the None -> str conversion is what slots expect.
            self.loaded.emit(matrix_df, metadata_df, meta_path)
        except Exception as e:
            logger.error("Error loading data: %s", e, exc_info=True)
            self.error.emit(str(e))

    def _load_matrix(self, path: str) -> pd.DataFrame:
        """Read the feature matrix at *path* into a DataFrame.

        NPZ archives must contain ``matrix`` and ``feature_names``; column
        labels come from ``snippet_ids``, falling back to the legacy
        ``span_ids`` key and finally to generated ``snippetN`` names.
        """
        lowered = path.lower()
        if lowered.endswith(".npz"):
            # NOTE(security): allow_pickle=True unpickles object arrays, which can
            # execute arbitrary code. Only open archives from trusted sources.
            with np.load(path, allow_pickle=True) as data:
                matrix = data["matrix"]
                feature_names = data["feature_names"]
                # Use ``in`` rather than ``.get``: NpzFile only gained the
                # Mapping API (and therefore ``.get``) in newer NumPy releases.
                if "snippet_ids" in data:
                    snippet_ids = data["snippet_ids"]
                elif "span_ids" in data:  # Backward compatibility
                    snippet_ids = data["span_ids"]
                else:
                    # Fallback: generate snippet IDs
                    snippet_ids = np.array([f'snippet{i+1}' for i in range(matrix.shape[1])])
            return pd.DataFrame(matrix, index=feature_names, columns=snippet_ids)
        if lowered.endswith(".parquet"):
            return pd.read_parquet(path)
        # default csv
        return pd.read_csv(path, index_col=0)

    def _load_metadata(self, path: str) -> pd.DataFrame:
        """Read the metadata table at *path* into a DataFrame.

        NPZ archives must contain ``metadata`` (values) and ``columns``.
        """
        lowered = path.lower()
        if lowered.endswith(".npz"):
            # NOTE(security): see _load_matrix - pickled arrays from untrusted
            # files are unsafe to load.
            with np.load(path, allow_pickle=True) as data:
                metadata_values = data["metadata"]
                columns = list(data["columns"])
            return pd.DataFrame(metadata_values, columns=columns)
        if lowered.endswith(".parquet"):
            return pd.read_parquet(path)
        return pd.read_csv(path)
107
+
108
+
109
class ClusterExportDialog(QDialog):
    """Confirmation dialog for a finished cluster export.

    Shows *message* with two buttons. Both buttons accept the dialog; only
    "Load dataset" additionally sets ``load_requested`` to True so the caller
    can decide whether to open ``matrix_path`` / ``metadata_path``.
    """

    def __init__(self, parent, message: str, matrix_path: str, metadata_path: str):
        super().__init__(parent)
        self.load_requested = False
        self.matrix_path = matrix_path
        self.metadata_path = metadata_path

        self.setWindowTitle("Cluster export complete")
        self.setMinimumWidth(500)

        root = QVBoxLayout(self)

        # Export summary text
        text = QLabel(message)
        text.setWordWrap(True)
        root.addWidget(text)

        # Button row: [Load dataset] <stretch> [OK]
        load_button = QPushButton("Load dataset")
        load_button.clicked.connect(self._on_load_clicked)
        dismiss_button = QPushButton("OK")
        dismiss_button.clicked.connect(self.accept)

        button_row = QHBoxLayout()
        button_row.addWidget(load_button)
        button_row.addStretch()
        button_row.addWidget(dismiss_button)
        root.addLayout(button_row)

    def _on_load_clicked(self):
        """Record the load request, then close the dialog as accepted."""
        self.load_requested = True
        self.accept()
143
+
144
+
145
+ class ClusteringWidget(QWidget):
146
+ """Widget for clustering behaviorome embeddings."""
147
+
148
def __init__(self, config: dict):
    """Initialise empty analysis state and build the UI.

    Args:
        config: Application configuration dictionary, stored as ``self.config``.
    """
    super().__init__()
    self.config = config

    # Loaded inputs (metadata_file_path is kept so the file can be updated later)
    self.matrix_data = self.metadata = self.metadata_file_path = None
    # Results of preprocessing / projection / clustering
    self.processed_data = self.embedding = self.clusters = None
    # Most recent figure and its backing table
    self.current_fig = self.current_df = None
    # Map snippet_id -> clip_path
    self.snippet_to_clip_map = {}
    # Preprocessing state
    self.selected_features = None

    self._setup_ui()
165
+
166
def update_config(self, config: dict) -> None:
    """Replace the stored configuration dictionary with *config*."""
    self.config = config
169
+
170
def _setup_ui(self):
    """Build the widget: settings (left), plot (middle), plot settings (right).

    Each pane and each group box is created by its own ``_build_*`` helper;
    the helpers assign the interactive widgets to ``self`` attributes.
    """
    self.main_layout = QVBoxLayout(self)

    # Splitter: Settings on left, Plot in the middle, Plot settings on right
    splitter = QSplitter(Qt.Orientation.Horizontal)
    splitter.addWidget(self._build_settings_panel())
    splitter.addWidget(self._build_plot_panel())
    splitter.addWidget(self._build_plot_settings_panel())
    splitter.setStretchFactor(1, 3)  # Plot takes most space

    self.main_layout.addWidget(splitter)

def _build_side_panel(self, groups):
    """Stack *groups* vertically inside a 300-350 px wide scroll area."""
    content = QWidget()
    content_layout = QVBoxLayout(content)
    content_layout.setContentsMargins(5, 5, 5, 5)
    content_layout.setSpacing(5)
    for group in groups:
        content_layout.addWidget(group)
    content_layout.addStretch()

    scroll = QScrollArea()
    scroll.setWidgetResizable(True)
    scroll.setMinimumWidth(300)
    scroll.setMaximumWidth(350)
    scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)
    scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)
    scroll.setWidget(content)
    return scroll

def _build_settings_panel(self):
    """Left pane: data loading, preprocessing, clustering, metadata, export."""
    return self._build_side_panel([
        self._build_data_group(),
        self._build_preprocess_group(),
        self._build_cluster_params_group(),
        self._build_metadata_group(),
        self._build_cluster_export_group(),
    ])

def _build_plot_settings_panel(self):
    """Right pane: plot settings, export buttons, analysis state."""
    return self._build_side_panel([
        self._build_plot_settings_group(),
        self._build_export_group(),
        self._build_state_group(),
    ])

def _build_data_group(self):
    """Group box with the load status, both load buttons and a progress bar."""
    group = QGroupBox("Data loading")
    layout = QVBoxLayout()
    layout.setSpacing(5)

    self.load_status_label = QLabel("No data loaded")
    self.load_status_label.setWordWrap(True)
    self.load_status_label.setTextFormat(Qt.TextFormat.PlainText)
    self.load_status_label.setAlignment(Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft)
    # Set size policy to prevent expansion
    self.load_status_label.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum)
    layout.addWidget(self.load_status_label)

    self.load_btn = QPushButton("Load from registration")
    self.load_btn.clicked.connect(self.load_data)
    layout.addWidget(self.load_btn)

    self.load_file_btn = QPushButton("Load external matrix...")
    self.load_file_btn.clicked.connect(self.load_external_data)
    layout.addWidget(self.load_file_btn)

    self.load_progress = QProgressBar()
    self.load_progress.setVisible(False)
    self.load_progress.setRange(0, 0)  # indeterminate
    layout.addWidget(self.load_progress)

    group.setLayout(layout)
    return group

def _build_preprocess_group(self):
    """Group box with the normalization selector and its apply button."""
    group = QGroupBox("Preprocessing")
    layout = QVBoxLayout()

    # Normalization
    norm_row = QHBoxLayout()
    norm_row.addWidget(QLabel("Normalization:"))
    self.normalization_method = QComboBox()
    self.normalization_method.addItems(["none", "l2", "standard", "minmax"])
    self.normalization_method.setCurrentText("none")  # Default for embeddings
    self.normalization_method.setToolTip(
        "None: Embeddings usually normalized\n"
        "L2: Unit norm (good for embeddings)\n"
        "Standard: Zero mean, unit var\n"
        "MinMax: [0,1] range"
    )
    norm_row.addWidget(self.normalization_method)
    layout.addLayout(norm_row)

    self.preprocess_btn = QPushButton("Apply preprocessing")
    self.preprocess_btn.clicked.connect(self.apply_preprocessing)
    layout.addWidget(self.preprocess_btn)

    self.preprocess_status = QLabel("Ready")
    self.preprocess_status.setWordWrap(True)
    self.preprocess_status.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Minimum)
    layout.addWidget(self.preprocess_status)

    group.setLayout(layout)
    return group

def _build_cluster_params_group(self):
    """Group box with UMAP sliders, method selector, per-method sliders, run button."""
    group = QGroupBox("Clustering & projection")
    layout = QVBoxLayout()

    layout.addWidget(QLabel("<b>Dimensionality Reduction (UMAP)</b>"))

    self.n_neighbors, self.n_neighbors_lbl = self._create_slider(
        "Neighbors:", 2, 200, 30, 1
    )
    layout.addWidget(self._slider_widget(self.n_neighbors, self.n_neighbors_lbl))

    self.min_dist, self.min_dist_lbl = self._create_slider(
        "Min Dist:", 0.0, 0.99, 0.1, 0.01, is_float=True
    )
    layout.addWidget(self._slider_widget(self.min_dist, self.min_dist_lbl))

    self.n_components, self.n_components_lbl = self._create_slider(
        "Components:", 2, 3, 2, 1
    )
    layout.addWidget(self._slider_widget(self.n_components, self.n_components_lbl))

    # Clustering Method
    layout.addWidget(QLabel("<b>Clustering Method</b>"))
    self.clustering_method = QComboBox()
    self.clustering_method.addItems(['leiden', 'hdbscan'])  # Only these two + UMAP visualization
    self.clustering_method.currentIndexChanged.connect(self._toggle_clustering_params)
    layout.addWidget(self.clustering_method)

    # Leiden Params
    self.leiden_container = QWidget()
    leiden_layout = QVBoxLayout(self.leiden_container)
    leiden_layout.setContentsMargins(0, 0, 0, 0)

    self.leiden_resolution, self.leiden_res_lbl = self._create_slider(
        "Resolution:", 0.1, 5.0, 1.0, 0.1, is_float=True
    )
    leiden_layout.addWidget(self._slider_widget(self.leiden_resolution, self.leiden_res_lbl))

    self.leiden_k, self.leiden_k_lbl = self._create_slider(
        "K-Neighbors:", 2, 100, 15, 1
    )
    leiden_layout.addWidget(self._slider_widget(self.leiden_k, self.leiden_k_lbl))

    layout.addWidget(self.leiden_container)

    # HDBSCAN Params
    self.hdbscan_container = QWidget()
    hdbscan_layout = QVBoxLayout(self.hdbscan_container)
    hdbscan_layout.setContentsMargins(0, 0, 0, 0)

    self.min_cluster_size, self.min_cluster_size_lbl = self._create_slider(
        "Min Cluster Size:", 2, 100, 5, 1
    )
    hdbscan_layout.addWidget(self._slider_widget(self.min_cluster_size, self.min_cluster_size_lbl))

    self.min_samples, self.min_samples_lbl = self._create_slider(
        "Min Samples:", 1, 50, 1, 1
    )
    hdbscan_layout.addWidget(self._slider_widget(self.min_samples, self.min_samples_lbl))

    self.hdbscan_epsilon, self.hdbscan_eps_lbl = self._create_slider(
        "Epsilon:", 0.0, 5.0, 0.0, 0.1, is_float=True
    )
    hdbscan_layout.addWidget(self._slider_widget(self.hdbscan_epsilon, self.hdbscan_eps_lbl))

    layout.addWidget(self.hdbscan_container)
    self.hdbscan_container.hide()  # Hide initially ('leiden' is the first item)

    # Run Button
    self.run_btn = QPushButton("Run analysis")
    self.run_btn.setStyleSheet("background-color: #007bff; color: white; font-weight: bold; padding: 5px;")
    self.run_btn.clicked.connect(self.run_clustering)
    layout.addWidget(self.run_btn)

    group.setLayout(layout)
    return group

def _build_metadata_group(self):
    """Group box holding the button that opens the metadata manager."""
    group = QGroupBox("Metadata")
    layout = QVBoxLayout()
    self.manage_metadata_btn = QPushButton("Manage metadata")
    self.manage_metadata_btn.clicked.connect(self.open_metadata_manager)
    layout.addWidget(self.manage_metadata_btn)
    group.setLayout(layout)
    return group

def _build_cluster_export_group(self):
    """Group box with the cluster list and the extract/exclude actions."""
    group = QGroupBox("Cluster export")
    layout = QVBoxLayout()

    # Help button with explanation
    help_row = QHBoxLayout()
    help_row.addWidget(QLabel("Select clusters:"))
    self.cluster_export_help_btn = QPushButton("?")
    self.cluster_export_help_btn.setMaximumWidth(30)
    self.cluster_export_help_btn.setToolTip("Click for help")
    self.cluster_export_help_btn.clicked.connect(self._show_cluster_export_help)
    help_row.addWidget(self.cluster_export_help_btn)
    help_row.addStretch()
    layout.addLayout(help_row)

    self.cluster_export_list = QListWidget()
    self.cluster_export_list.setSelectionMode(QListWidget.SelectionMode.ExtendedSelection)
    layout.addWidget(self.cluster_export_list)

    self.use_raw_data_checkbox = QCheckBox("Use raw data")
    self.use_raw_data_checkbox.setToolTip("If checked, exports raw (unnormalized) data. Otherwise exports preprocessed data.")
    layout.addWidget(self.use_raw_data_checkbox)

    # Both actions start disabled; they are enabled elsewhere in the class.
    self.extract_cluster_btn = QPushButton("Extract selected cluster")
    self.extract_cluster_btn.clicked.connect(self._extract_cluster)
    self.extract_cluster_btn.setEnabled(False)
    layout.addWidget(self.extract_cluster_btn)

    self.exclude_clusters_btn = QPushButton("Exclude selected clusters")
    self.exclude_clusters_btn.clicked.connect(self._exclude_clusters)
    self.exclude_clusters_btn.setEnabled(False)
    layout.addWidget(self.exclude_clusters_btn)

    group.setLayout(layout)
    return group

def _build_plot_panel(self):
    """Middle pane: the Plotly view with a one-line status label below it."""
    container = QWidget()
    layout = QVBoxLayout(container)
    layout.setContentsMargins(0, 0, 0, 0)

    self.plot_widget = PlotlyWidget()
    # Set up click callback for snippet selection
    self.plot_widget.set_click_callback(self._on_umap_point_clicked)
    layout.addWidget(self.plot_widget)

    # Status label (minimal, at bottom)
    self.status_label = QLabel("Ready")
    self.status_label.setMaximumHeight(20)
    self.status_label.setWordWrap(True)
    self.status_label.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Maximum)
    layout.addWidget(self.status_label)

    return container

def _build_plot_settings_group(self):
    """Group box with theme, grouping, point size, plot type and sub-selectors."""
    group = QGroupBox("Plot settings")
    layout = QVBoxLayout()
    layout.setSpacing(5)

    # Color Theme Selector
    layout.addWidget(QLabel("<b>Color Theme:</b>"))
    self.color_theme_combo = QComboBox()
    self.color_theme_combo.addItems(["plotly", "plotly_white", "plotly_dark", "ggplot2", "seaborn", "simple_white", "none"])
    self.color_theme_combo.setCurrentText("simple_white")
    self.color_theme_combo.currentIndexChanged.connect(self._update_plots_by_metadata)
    layout.addWidget(self.color_theme_combo)

    # Metadata Column Selector
    layout.addWidget(QLabel("<b>Group by metadata column:</b>"))
    self.metadata_column_combo = QComboBox()
    self.metadata_column_combo.addItem("None (Show all clusters)")
    self.metadata_column_combo.currentTextChanged.connect(self._update_plots_by_metadata)
    layout.addWidget(self.metadata_column_combo)

    # Point Size for grouped plots
    point_size_label = QLabel("Point size:")
    self.plot_point_size_slider = QSlider(Qt.Orientation.Horizontal)
    self.plot_point_size_slider.setMinimum(1)
    self.plot_point_size_slider.setMaximum(20)
    self.plot_point_size_slider.setValue(5)
    self.plot_point_size_label = QLabel("5")
    # One slot does both jobs: refresh the numeric label, then redraw.
    self.plot_point_size_slider.valueChanged.connect(
        lambda v: (self.plot_point_size_label.setText(str(v)), self._update_plots_by_metadata())
    )
    point_size_layout = QHBoxLayout()
    point_size_layout.addWidget(point_size_label)
    point_size_layout.addWidget(self.plot_point_size_slider)
    point_size_layout.addWidget(self.plot_point_size_label)
    layout.addLayout(point_size_layout)

    # Plot Type Selector
    layout.addWidget(QLabel("<b>Plot type:</b>"))
    self.plot_type_combo = QComboBox()
    self.plot_type_combo.addItems(["UMAP", "Cluster Proportions", "Single Cluster Analysis", "Spatial cluster distribution"])
    self.plot_type_combo.currentTextChanged.connect(self._update_plot_type)
    layout.addWidget(self.plot_type_combo)

    # Single Cluster Analysis (only show when relevant)
    self.single_cluster_label = QLabel("<b>Select cluster:</b>")
    self.single_cluster_combo = QComboBox()
    self.single_cluster_combo.addItem("None")
    self.single_cluster_combo.currentTextChanged.connect(self._update_plots_by_metadata)
    layout.addWidget(self.single_cluster_label)
    layout.addWidget(self.single_cluster_combo)
    self.single_cluster_label.setVisible(False)
    self.single_cluster_combo.setVisible(False)

    # Spatial distribution: Video and Object selectors
    self.spatial_video_label = QLabel("<b>Video:</b>")
    self.spatial_video_combo = QComboBox()
    self.spatial_video_combo.addItem("All")
    self.spatial_video_combo.currentTextChanged.connect(self._on_spatial_video_changed)
    layout.addWidget(self.spatial_video_label)
    layout.addWidget(self.spatial_video_combo)
    self.spatial_video_label.setVisible(False)
    self.spatial_video_combo.setVisible(False)

    self.spatial_object_label = QLabel("<b>Object:</b>")
    self.spatial_object_combo = QComboBox()
    self.spatial_object_combo.addItem("All")
    self.spatial_object_combo.currentTextChanged.connect(self._update_plots_by_metadata)
    layout.addWidget(self.spatial_object_label)
    layout.addWidget(self.spatial_object_combo)
    self.spatial_object_label.setVisible(False)
    self.spatial_object_combo.setVisible(False)

    group.setLayout(layout)
    return group

def _build_export_group(self):
    """Group box with the plot and CSV export buttons (initially disabled)."""
    group = QGroupBox("Export")
    layout = QVBoxLayout()

    self.export_plot_btn = QPushButton("Export plot (PDF/SVG)")
    self.export_plot_btn.clicked.connect(self._export_plot)
    self.export_plot_btn.setEnabled(False)
    layout.addWidget(self.export_plot_btn)

    self.export_csv_btn = QPushButton("Export results (CSV)")
    self.export_csv_btn.clicked.connect(self.export_results)
    self.export_csv_btn.setEnabled(False)
    layout.addWidget(self.export_csv_btn)

    group.setLayout(layout)
    return group

def _build_state_group(self):
    """Group box with save/load of the full analysis and an info label."""
    group = QGroupBox("Analysis state")
    layout = QVBoxLayout()

    self.save_state_btn = QPushButton("Save full analysis")
    self.save_state_btn.clicked.connect(self._save_analysis_state)
    layout.addWidget(self.save_state_btn)

    self.load_state_btn = QPushButton("Load full analysis")
    self.load_state_btn.clicked.connect(self._load_analysis_state)
    layout.addWidget(self.load_state_btn)

    self.file_info_label = QLabel("No analysis loaded")
    self.file_info_label.setWordWrap(True)
    self.file_info_label.setStyleSheet("color: gray; font-style: italic;")
    layout.addWidget(self.file_info_label)

    group.setLayout(layout)
    return group
530
+
531
def _create_slider(self, label_text, min_val, max_val, default, step, is_float=False):
    """Create a horizontal slider and a label that mirrors its value.

    QSlider only holds integers, so float sliders are stored in hundredths
    (value * 100) and displayed with two decimals; read them back with
    ``_get_slider_value(slider, is_float=True)``.

    Args:
        label_text: Caption, stored on the slider as ``slider.label_text``.
        min_val, max_val, default, step: Range, initial value and single step
            in user units (floats when ``is_float`` is True, ints otherwise).
        is_float: Whether the slider represents a float with 0.01 resolution.

    Returns:
        Tuple ``(slider, value_label)``.
    """
    slider = QSlider(Qt.Orientation.Horizontal)
    if is_float:
        # round(), not int(): int() truncates binary float artefacts
        # (e.g. 0.29 * 100 == 28.999999999999996 would become 28).
        slider.setMinimum(round(min_val * 100))
        slider.setMaximum(round(max_val * 100))
        slider.setValue(round(default * 100))
        # Steps finer than the 0.01 resolution would otherwise become 0.
        slider.setSingleStep(max(1, round(step * 100)))
        value_label = QLabel(f"{default:.2f}")
        slider.valueChanged.connect(lambda v: value_label.setText(f"{v/100:.2f}"))
    else:
        slider.setMinimum(min_val)
        slider.setMaximum(max_val)
        slider.setValue(default)
        slider.setSingleStep(step)
        value_label = QLabel(str(default))
        slider.valueChanged.connect(lambda v: value_label.setText(str(v)))
    # Caption travels with the slider so _slider_widget can render it.
    slider.label_text = label_text
    return slider, value_label
549
+
550
def _slider_widget(self, slider, label):
    """Wrap a slider in a margin-free row: caption, slider, value label.

    The caption text is read from ``slider.label_text``.
    """
    row = QWidget()
    row_layout = QHBoxLayout(row)
    row_layout.setContentsMargins(0, 0, 0, 0)
    for child in (QLabel(slider.label_text), slider, label):
        row_layout.addWidget(child)
    return row
559
+
560
+ def _get_slider_value(self, slider, is_float=False):
561
+ if is_float:
562
+ return slider.value() / 100.0
563
+ return slider.value()
564
+
565
+ def _toggle_clustering_params(self):
566
+ method = self.clustering_method.currentText()
567
+ if method == 'leiden':
568
+ self.leiden_container.show()
569
+ self.hdbscan_container.hide()
570
+ else:
571
+ self.leiden_container.hide()
572
+ self.hdbscan_container.show()
573
+
574
+ def _find_latest_behaviorome(self, directory: str):
575
+ """Pick the latest behaviorome matrix with preferred extensions."""
576
+ exts = ["npz", "parquet", "csv"]
577
+ candidates = []
578
+ for fname in os.listdir(directory):
579
+ for ext in exts:
580
+ if fname.startswith("behaviorome_") and fname.endswith(f"_matrix.{ext}"):
581
+ candidates.append(os.path.join(directory, fname))
582
+ if not candidates:
583
+ return None, None
584
+ candidates.sort(reverse=True)
585
+ matrix_path = candidates[0]
586
+ base, ext = os.path.splitext(matrix_path)
587
+ # ext includes .npz etc; build metadata preference (NPZ first, then Parquet, then CSV)
588
+ metadata_path = None
589
+ for meta_ext in ["npz", "parquet", "csv"]:
590
+ candidate_meta = base.replace("_matrix", "_metadata") + f".{meta_ext}"
591
+ if os.path.exists(candidate_meta):
592
+ metadata_path = candidate_meta
593
+ break
594
+ return matrix_path, metadata_path
595
+
596
def _load_files_async(self, matrix_path: str, metadata_path: str | None):
    """Start background load with progress bar.

    Creates a ``LoadDataWorker`` kept in ``self.load_worker`` and routes its
    ``loaded`` / ``error`` signals to ``_on_loaded_data`` / ``_on_load_error``.
    A request made while a previous load is still running is ignored.
    """
    # Rebinding self.load_worker would drop the last reference to the previous
    # QThread; destroying a QThread that is still running aborts the process.
    previous = getattr(self, "load_worker", None)
    if previous is not None and previous.isRunning():
        logger.warning(
            "Ignoring load request for %s: a load is already in progress", matrix_path
        )
        return

    self.load_progress.setVisible(True)
    self.load_status_label.setText("Loading data...")
    self.load_status_label.setStyleSheet("color: black;")
    self.load_progress.setRange(0, 0)  # indeterminate
    self.load_worker = LoadDataWorker(matrix_path, metadata_path)
    self.load_worker.loaded.connect(self._on_loaded_data)
    self.load_worker.error.connect(self._on_load_error)
    self.load_worker.start()
606
+
607
def _on_loaded_data(self, matrix_df: pd.DataFrame, metadata_df: pd.DataFrame | None, metadata_path: str | None):
    """Store freshly loaded data, update the status labels and preprocess.

    Args:
        matrix_df: Loaded feature matrix.
        metadata_df: Loaded metadata table, or ``None``.
        metadata_path: Path the metadata came from; falsy when there is none.
    """
    self.load_progress.setVisible(False)
    # Downcast to float32 to reduce memory for large matrices
    try:
        self.matrix_data = matrix_df.astype(np.float32, copy=False)
    except Exception:
        # Best effort: keep the data as loaded, but leave a trace of why the
        # downcast was skipped instead of swallowing the error silently.
        logger.warning(
            "Could not convert matrix to float32; keeping original dtypes",
            exc_info=True,
        )
        self.matrix_data = matrix_df
    self.metadata = metadata_df
    self.metadata_file_path = metadata_path
    self.processed_data = None  # any previous preprocessing result is stale
    self.preprocess_status.setText("Raw data loaded")
    shape = (self.matrix_data.shape[0], self.matrix_data.shape[1]) if self.matrix_data is not None else (0, 0)

    # Format status text to fit in container (truncate long filenames)
    max_filename_len = 35
    status_text = f"Matrix: {shape[0]} x {shape[1]}"
    if metadata_path:
        meta_filename = os.path.basename(metadata_path)
        if len(meta_filename) > max_filename_len:
            meta_filename = meta_filename[:max_filename_len-3] + "..."
        status_text += f"\nMeta: {meta_filename}"

    self.load_status_label.setText(status_text)
    self.load_status_label.setStyleSheet("color: green;")
    self.apply_preprocessing()
632
+
633
def _on_load_error(self, msg: str):
    """Hide the progress bar, mark the load as failed and show *msg*."""
    self.load_progress.setVisible(False)
    # _load_files_async left the label at "Loading data..."; without this
    # reset it would keep claiming a load is in progress after the failure.
    self.load_status_label.setText("Load failed")
    self.load_status_label.setStyleSheet("color: red;")
    QMessageBox.critical(self, "Load Error", f"Failed to load data: {msg}")
636
+
637
def load_data(self):
    """Load data from experiment folder.

    Looks in ``<experiment_path>/registered_clips`` (``experiment_path`` is
    read from ``self.config``) for the latest behaviorome matrix and starts an
    asynchronous load. Each missing prerequisite is reported in a warning box.
    """
    experiment_path = self.config.get("experiment_path")
    if not experiment_path:
        QMessageBox.warning(self, "No Experiment", "Please create or load an experiment first.")
        return

    registered_clips_dir = os.path.join(experiment_path, "registered_clips")
    # isdir, not exists: the path is listed below, so a regular file with
    # this name must be treated as "no directory" rather than crash listdir.
    if not os.path.isdir(registered_clips_dir):
        QMessageBox.warning(self, "No Data", "No registered clips directory found.")
        return

    matrix_path, metadata_path = self._find_latest_behaviorome(registered_clips_dir)
    if not matrix_path:
        QMessageBox.warning(self, "No Data", "No behaviorome matrix files found (npz/parquet/csv).")
        return
    self._load_files_async(matrix_path, metadata_path)
654
+
655
def load_external_data(self):
    """Ask the user for a matrix file and an optional metadata file, then load them.

    Cancelling the first dialog aborts; cancelling the second one loads the
    matrix without metadata.
    """
    matrix_filter = "Matrices (*.npz *.parquet *.csv);;All Files (*)"
    matrix_path, _ = QFileDialog.getOpenFileName(self, "Open Feature Matrix", "", matrix_filter)
    if not matrix_path:
        return

    # The metadata dialog opens in the folder of the chosen matrix.
    metadata_filter = "Metadata (*.parquet *.csv);;All Files (*)"
    start_dir = os.path.dirname(matrix_path)
    metadata_path, _ = QFileDialog.getOpenFileName(self, "Open Metadata", start_dir, metadata_filter)

    if not metadata_path:
        metadata_path = None
    self._load_files_async(matrix_path, metadata_path)
673
+
674
def load_from_registration(self, matrix_path: str, metadata_path: str):
    """Load the matrix/metadata files handed over by the registration tab.

    Delegates to the synchronous file loader.
    """
    self._load_data_files(matrix_path=matrix_path, metadata_path=metadata_path)
677
+
678
+ def _load_csvs(self, matrix_path, metadata_path):
679
+ """Load from CSV files (legacy support)."""
680
+ self._load_data_files(matrix_path, metadata_path)
681
+
682
+ def _load_data_files(self, matrix_path: str, metadata_path: str = None):
683
+ """Load data from matrix and metadata files (supports NPZ, Parquet, CSV)."""
684
+ try:
685
+ # Load matrix
686
+ if matrix_path.endswith('.npz'):
687
+ npz_data = np.load(matrix_path, allow_pickle=True) # Need allow_pickle for string arrays
688
+ matrix = npz_data['matrix'] # features x snippets
689
+ feature_names = npz_data['feature_names']
690
+ snippet_ids = npz_data['snippet_ids'] if 'snippet_ids' in npz_data else npz_data.get('span_ids', None) # Backward compatibility
691
+ if snippet_ids is None:
692
+ # Fallback: generate snippet IDs
693
+ snippet_ids = np.array([f'snippet{i+1}' for i in range(matrix.shape[1])])
694
+ self.matrix_data = pd.DataFrame(matrix, index=feature_names, columns=snippet_ids)
695
+ elif matrix_path.endswith('.parquet'):
696
+ self.matrix_data = pd.read_parquet(matrix_path, engine='pyarrow')
697
+ else: # CSV
698
+ self.matrix_data = pd.read_csv(matrix_path, index_col=0)
699
+
700
+ # Load metadata
701
+ if metadata_path:
702
+ if metadata_path.endswith('.npz'):
703
+ npz_meta = np.load(metadata_path, allow_pickle=True) # Need allow_pickle for string arrays
704
+ metadata_array = npz_meta['metadata']
705
+ columns = npz_meta['columns']
706
+ self.metadata = pd.DataFrame(metadata_array, columns=columns)
707
+ elif metadata_path.endswith('.parquet'):
708
+ self.metadata = pd.read_parquet(metadata_path, engine='pyarrow')
709
+ else: # CSV
710
+ self.metadata = pd.read_csv(metadata_path)
711
+ self.metadata_file_path = metadata_path
712
+ else:
713
+ self.metadata_file_path = None
714
+ self.metadata = None
715
+
716
+ # Reset processing
717
+ self.processed_data = None
718
+ self.preprocess_status.setText("Raw data loaded")
719
+
720
+ # Format status text to fit in container
721
+ filename = os.path.basename(matrix_path)
722
+ max_filename_len = 35
723
+ if len(filename) > max_filename_len:
724
+ filename = filename[:max_filename_len-3] + "..."
725
+
726
+ status_text = f"Matrix: {self.matrix_data.shape[0]} x {self.matrix_data.shape[1]}\nFile: {filename}"
727
+ if metadata_path:
728
+ meta_filename = os.path.basename(metadata_path)
729
+ if len(meta_filename) > max_filename_len:
730
+ meta_filename = meta_filename[:max_filename_len-3] + "..."
731
+ status_text += f"\nMeta: {meta_filename}"
732
+
733
+ self.load_status_label.setText(status_text)
734
+ self.load_status_label.setStyleSheet("color: green;")
735
+
736
+ # Auto-apply default preprocessing
737
+ self.apply_preprocessing()
738
+
739
+ # Refresh metadata columns in plot settings
740
+ self._refresh_metadata_columns()
741
+
742
+ except Exception as e:
743
+ QMessageBox.critical(self, "Load Error", f"Failed to load data: {e}")
744
+
745
def apply_preprocessing(self):
    """Normalize the loaded matrix and store it as ``self.processed_data``.

    ``self.matrix_data`` is laid out features x samples; scikit-learn
    expects samples as rows, so the frame is transposed first. Infinite
    values are replaced by NaN (NaNs themselves are left in place). The
    scaler is chosen from ``self.normalization_method``: ``'standard'``,
    ``'minmax'`` or ``'l2'``; any other text stores the cleaned data
    unscaled. The result is a samples x features DataFrame.

    Does nothing when no matrix is loaded. Failures are reported in a
    "Preprocessing Error" dialog.
    """
    if self.matrix_data is None:
        return

    try:
        # No defensive copy of the raw matrix: ``replace`` below already
        # returns a new frame, so the raw data is never mutated and peak
        # memory is not doubled for large matrices.
        X = self.matrix_data.T.replace([np.inf, -np.inf], np.nan)

        norm_method = self.normalization_method.currentText()
        if norm_method == 'standard':
            scaler = StandardScaler()
        elif norm_method == 'minmax':
            scaler = MinMaxScaler()
        elif norm_method == 'l2':
            scaler = Normalizer(norm='l2')
        else:
            scaler = None

        X_norm = scaler.fit_transform(X) if scaler is not None else X

        # Store processed data (Samples x Features)
        self.processed_data = pd.DataFrame(X_norm, index=X.index, columns=X.columns)

        self.preprocess_status.setText(f"Normalized: {norm_method}")
        self.preprocess_status.setStyleSheet("color: green;")

    except Exception as e:
        QMessageBox.critical(self, "Preprocessing Error", f"Error: {e}")
788
+
789
def run_clustering(self):
    """Collect the current parameter values and start the clustering worker.

    Warns and returns when no processed data is available. Otherwise the
    run button is disabled, the status texts are updated, and a
    ``ClusteringWorker`` is started whose ``finished`` signal is connected
    to ``on_clustering_finished``.
    """
    if self.processed_data is None:
        QMessageBox.warning(self, "No Data", "Please load and preprocess data first.")
        return

    self.run_btn.setEnabled(False)
    self.run_btn.setText("Running...")
    self.status_label.setText("Computing clustering...")

    # Parameter name == widget attribute name; float sliders are read
    # with is_float=True, the rest with the default.
    integer_sliders = ('n_neighbors', 'n_components', 'leiden_k', 'min_cluster_size', 'min_samples')
    float_sliders = ('min_dist', 'leiden_resolution', 'hdbscan_epsilon')

    params = {}
    for name in integer_sliders:
        params[name] = self._get_slider_value(getattr(self, name))
    for name in float_sliders:
        params[name] = self._get_slider_value(getattr(self, name), is_float=True)
    params['method'] = self.clustering_method.currentText()

    worker = ClusteringWorker(self, params, self.processed_data)
    self.worker = worker
    worker.finished.connect(self.on_clustering_finished)
    worker.start()
814
+
815
def perform_clustering(self, data, **params):
    """Embed ``data`` with UMAP, cluster it and build the scatter figure.

    Args:
        data: Samples x features DataFrame; its index supplies the sample
            IDs shown in the plot.
        **params: ``n_neighbors``, ``min_dist``, ``n_components`` for UMAP;
            ``method`` (``'leiden'`` or ``'hdbscan'``); ``leiden_k`` and
            ``leiden_resolution`` for Leiden; ``min_cluster_size``,
            ``min_samples`` and ``hdbscan_epsilon`` for HDBSCAN.

    Returns:
        A ``(status_message, figure)`` tuple.

    Raises:
        ValueError: If ``method`` is not a supported clustering method.

    Side effects: stores the embedding in ``self.embedding`` and the
    cluster ids in ``self.clusters``.
    """
    method = params['method']
    # Fail fast: previously an unknown method surfaced as an
    # UnboundLocalError on ``clusters`` only after the UMAP fit had run.
    if method not in ('leiden', 'hdbscan'):
        raise ValueError(f"Unsupported clustering method: {method!r}")

    # UMAP Embedding
    reducer = umap.UMAP(
        n_neighbors=params['n_neighbors'],
        min_dist=params['min_dist'],
        n_components=params['n_components'],
        random_state=42
    )
    embedding = reducer.fit_transform(data)
    self.embedding = embedding  # Store for export

    # Clustering (performed on the input features, not on the embedding)
    if method == 'leiden':
        # Construct k-NN graph
        knn_graph = kneighbors_graph(
            data,
            n_neighbors=params['leiden_k'],
            mode='connectivity',
            include_self=False
        )
        sources, targets = knn_graph.nonzero()
        edges = list(zip(sources.tolist(), targets.tolist()))
        g = ig.Graph(n=data.shape[0], edges=edges, directed=False)

        partition = la.find_partition(
            g,
            la.RBConfigurationVertexPartition,
            resolution_parameter=params['leiden_resolution']
        )
        clusters = np.array(partition.membership)
    else:  # 'hdbscan'
        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=params['min_cluster_size'],
            min_samples=params['min_samples'],
            cluster_selection_epsilon=params['hdbscan_epsilon']
        )
        clusters = clusterer.fit_predict(data)

    self.clusters = clusters

    # Plotting
    df_plot = pd.DataFrame({
        'UMAP1': embedding[:, 0],
        'UMAP2': embedding[:, 1],
        'Cluster': [f'Cluster_{c}' if c >= 0 else 'Noise' for c in clusters],
        'Sample': data.index
    })

    # Arguments shared by the 2D and 3D scatter; custom_data carries the
    # snippet IDs for click handling.
    scatter_kwargs = dict(
        color='Cluster',
        hover_data=['Sample'],
        title=f"UMAP + {method.title()} Clustering",
        custom_data=[data.index.tolist()],
    )

    if params['n_components'] == 3:
        df_plot['UMAP3'] = embedding[:, 2]
        fig = px.scatter_3d(df_plot, x='UMAP1', y='UMAP2', z='UMAP3', **scatter_kwargs)
    else:
        fig = px.scatter(df_plot, x='UMAP1', y='UMAP2', **scatter_kwargs)

    theme = self._get_plot_theme()
    fig.update_layout(template=theme if theme else None)
    return "Clustering Complete", fig
888
+
889
def on_clustering_finished(self, status, fig):
    """Handle the worker result: restore the controls and show the figure.

    With a falsy ``fig`` the status is shown in a warning dialog. Otherwise
    the figure is displayed, the export buttons are enabled, cluster
    assignments are pushed into the metadata and the dependent selectors
    are refreshed.
    """
    self.run_btn.setEnabled(True)
    self.run_btn.setText("Run analysis")
    self.status_label.setText(status)

    if not fig:
        QMessageBox.warning(self, "Error", status)
        return

    # Marker size follows the slider when that control exists.
    if hasattr(self, 'plot_point_size_slider'):
        marker_size = self.plot_point_size_slider.value()
    else:
        marker_size = 5
    fig.update_traces(marker=dict(size=marker_size))
    self.current_fig = fig
    self.plot_widget.update_plot(fig)

    # Export buttons are optional widgets.
    for button_name in ('export_plot_btn', 'export_csv_btn'):
        if hasattr(self, button_name):
            getattr(self, button_name).setEnabled(True)

    # Snippet-to-clip map first, then cluster labels into the metadata,
    # then every list that depends on them.
    self._build_snippet_to_clip_map()
    self._update_metadata_with_clusters()
    self._refresh_metadata_columns()
    self._refresh_cluster_list()
    self._refresh_cluster_export_list()

    # Re-draw the grouped plots only when a metadata column is selected.
    if not hasattr(self, 'metadata_column_combo'):
        return
    if self.metadata_column_combo.currentText() != "None (Show all clusters)":
        self._update_plots_by_metadata()
922
+
923
+ def _update_metadata_with_clusters(self):
924
+ """Update metadata CSV file with cluster assignments."""
925
+ if self.metadata is None or self.clusters is None or self.processed_data is None:
926
+ return
927
+
928
+ try:
929
+ # Map clusters to snippet_ids
930
+ # processed_data.index contains snippet_ids (from matrix columns)
931
+ cluster_series = pd.Series(self.clusters, index=self.processed_data.index, name='Cluster')
932
+
933
+ # Convert cluster numbers to cluster labels
934
+ cluster_labels = cluster_series.map(lambda c: f'Cluster_{c}' if c >= 0 else 'Noise')
935
+
936
+ # Update metadata
937
+ # Check for both 'snippet' and 'span_id' (backward compatibility)
938
+ snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
939
+ if snippet_col:
940
+ # Merge cluster assignments by snippet/span_id
941
+ cluster_df = pd.DataFrame({
942
+ snippet_col: self.processed_data.index,
943
+ 'Cluster': cluster_labels
944
+ })
945
+
946
+ # Update or add Cluster column
947
+ if 'Cluster' in self.metadata.columns:
948
+ # Remove old Cluster column
949
+ self.metadata = self.metadata.drop(columns=['Cluster'])
950
+
951
+ # Merge new cluster assignments
952
+ self.metadata = self.metadata.merge(cluster_df, on=snippet_col, how='left')
953
+
954
+ # Save updated metadata back to file
955
+ if self.metadata_file_path:
956
+ # Use the original metadata file path, respecting format
957
+ self._save_metadata_to_file(self.metadata, self.metadata_file_path)
958
+ current_text = self.status_label.text()
959
+ self.status_label.setText(f"{current_text}\nMetadata updated with cluster assignments")
960
+ else:
961
+ # Try to find metadata file in experiment folder
962
+ experiment_path = self.config.get("experiment_path")
963
+ if experiment_path:
964
+ registered_clips_dir = os.path.join(experiment_path, "registered_clips")
965
+ if os.path.exists(registered_clips_dir):
966
+ # Find the latest metadata file (prefer parquet, then npz, then csv)
967
+ meta_files = [f for f in os.listdir(registered_clips_dir)
968
+ if f.startswith("behaviorome_") and "_metadata" in f]
969
+ if meta_files:
970
+ # Prefer parquet > npz > csv
971
+ parquet_files = [f for f in meta_files if f.endswith(".parquet")]
972
+ npz_files = [f for f in meta_files if f.endswith(".npz")]
973
+ csv_files = [f for f in meta_files if f.endswith(".csv")]
974
+
975
+ chosen_file = None
976
+ if parquet_files:
977
+ parquet_files.sort(reverse=True)
978
+ chosen_file = parquet_files[0]
979
+ elif csv_files:
980
+ csv_files.sort(reverse=True)
981
+ chosen_file = csv_files[0]
982
+ # Skip npz for saving to avoid format issues
983
+
984
+ if chosen_file:
985
+ metadata_file = os.path.join(registered_clips_dir, chosen_file)
986
+ self._save_metadata_to_file(self.metadata, metadata_file)
987
+ self.metadata_file_path = metadata_file
988
+ current_text = self.status_label.text()
989
+ self.status_label.setText(f"{current_text}\nMetadata updated with cluster assignments")
990
+
991
+ except Exception as e:
992
+ logger.error("Could not update metadata file: %s", e, exc_info=True)
993
+ QMessageBox.warning(self, "Metadata Update Warning",
994
+ f"Could not update metadata file:\n{str(e)}\n\nCluster assignments are still available for export.")
995
+
996
def export_results(self):
    """Export the UMAP coordinates, cluster labels and metadata to CSV.

    The file is written to ``<experiment_path>/analysis_results``. Its name
    depends on the selected plot type and grouping column and always ends
    with a timestamp. Grouping-column and cluster names are sanitised
    before being used in the file name, so characters such as ``/`` cannot
    break the output path.

    Warns and returns when there are no clustering results or when no
    experiment path is configured.
    """
    import re  # local import: only used to sanitise file-name fragments

    if (self.processed_data is None or self.clusters is None
            or getattr(self, 'embedding', None) is None):
        QMessageBox.warning(self, "No Data", "No results to export.")
        return

    plot_type = self.plot_type_combo.currentText() if hasattr(self, 'plot_type_combo') else "UMAP"
    group_name = self.metadata_column_combo.currentText() if hasattr(self, 'metadata_column_combo') else None

    experiment_path = self.config.get("experiment_path")
    if not experiment_path:
        # Previously a silent return; tell the user why nothing was saved.
        QMessageBox.warning(self, "No Experiment", "Please create or load an experiment first.")
        return

    output_dir = os.path.join(experiment_path, "analysis_results")
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Base frame: UMAP coordinates and cluster labels per sample
    df_export = pd.DataFrame({
        'UMAP1': self.embedding[:, 0],
        'UMAP2': self.embedding[:, 1],
        'Cluster': [f'Cluster_{c}' if c >= 0 else 'Noise' for c in self.clusters],
        'Sample': self.processed_data.index
    })

    if self.embedding.shape[1] > 2:
        df_export['UMAP3'] = self.embedding[:, 2]

    # Add metadata if available
    if self.metadata is not None:
        snippet_col = next(
            (col for col in ('snippet', 'span_id') if col in self.metadata.columns),
            None,
        )
        if snippet_col:
            df_export = df_export.merge(
                self.metadata,
                left_on='Sample',
                right_on=snippet_col,
                how='left'
            )

    def _safe(fragment):
        # Keep word characters, dots and dashes; collapse the rest to "_".
        return re.sub(r'[^\w.-]+', '_', str(fragment))

    # Only the file name depends on the plot type; the exported rows are
    # the same in every case.
    grouped = bool(group_name) and group_name != "None (Show all clusters)"
    group_in_export = grouped and group_name in df_export.columns

    if plot_type == "Cluster Proportions" and grouped:
        if group_in_export:
            stem = f"cluster_proportions_{_safe(group_name)}"
        else:
            stem = "cluster_proportions"
    elif plot_type == "Single Cluster Analysis" and grouped:
        selected_cluster = None
        if hasattr(self, 'single_cluster_combo') and self.single_cluster_combo.currentText() != "None":
            selected_cluster = self.single_cluster_combo.currentText()
        if selected_cluster and group_in_export:
            stem = f"single_cluster_{_safe(selected_cluster)}_{_safe(group_name)}"
        else:
            stem = "single_cluster_analysis"
    else:
        stem = "clustering_results"

    filename = os.path.join(output_dir, f"{stem}_{timestamp}.csv")

    df_export.to_csv(filename, index=False)
    QMessageBox.information(self, "Export", f"Results saved to:\n{filename}")
1054
+
1055
def open_metadata_manager(self):
    """Open the metadata management dialog and adopt its result when accepted.

    Warns and returns when no metadata is loaded. On acceptance the edited
    metadata and its path replace the current ones, preprocessing is
    re-applied if a matrix is loaded, the metadata column list is
    refreshed, grouped plots are redrawn when a column is selected, and a
    success message is shown.
    """
    if self.metadata is None:
        QMessageBox.warning(self, "No Data", "Please load metadata first.")
        return

    # Import here to avoid circular imports
    from .metadata_management_widget import MetadataManagementDialog

    # The dialog works on a copy so cancelling leaves the metadata untouched.
    working_copy = self.metadata.copy()
    dialog = MetadataManagementDialog(working_copy, self.metadata_file_path, self.config, self)
    accepted = dialog.exec() == QDialog.DialogCode.Accepted
    if not accepted:
        return

    self.metadata = dialog.get_metadata()
    self.metadata_file_path = dialog.get_metadata_path()

    if self.matrix_data is not None:
        self.apply_preprocessing()

    self._refresh_metadata_columns()

    column_chosen = (
        hasattr(self, 'metadata_column_combo')
        and self.metadata_column_combo.currentText() != "None (Show all clusters)"
    )
    if column_chosen:
        self._update_plots_by_metadata()

    QMessageBox.information(self, "Success", "Metadata updated successfully.")
1083
+
1084
+ def _refresh_metadata_columns(self):
1085
+ """Refresh metadata columns in combo box."""
1086
+ if not hasattr(self, 'metadata_column_combo'):
1087
+ return
1088
+
1089
+ # Block signals to prevent recursion
1090
+ self.metadata_column_combo.blockSignals(True)
1091
+
1092
+ try:
1093
+ current_selection = self.metadata_column_combo.currentText()
1094
+ self.metadata_column_combo.clear()
1095
+ self.metadata_column_combo.addItem("None (Show all clusters)")
1096
+
1097
+ if self.metadata is not None:
1098
+ # Include all columns from metadata (user can choose which to use)
1099
+ available_cols = list(self.metadata.columns)
1100
+ self.metadata_column_combo.addItems(available_cols)
1101
+
1102
+ # Restore previous selection if still available
1103
+ if current_selection in available_cols:
1104
+ idx = self.metadata_column_combo.findText(current_selection)
1105
+ if idx >= 0:
1106
+ self.metadata_column_combo.setCurrentIndex(idx)
1107
+ finally:
1108
+ self.metadata_column_combo.blockSignals(False)
1109
+
1110
+ def _refresh_cluster_list(self):
1111
+ """Refresh cluster list in single cluster combo."""
1112
+ if not hasattr(self, 'single_cluster_combo') or self.clusters is None:
1113
+ return
1114
+
1115
+ # Block signals to prevent recursion
1116
+ self.single_cluster_combo.blockSignals(True)
1117
+
1118
+ try:
1119
+ current_selection = self.single_cluster_combo.currentText()
1120
+ self.single_cluster_combo.clear()
1121
+ self.single_cluster_combo.addItem("None")
1122
+
1123
+ # Normalize cluster identifiers to string labels: Cluster_X or Noise
1124
+ labels = set()
1125
+ for cid in set(self.clusters):
1126
+ label = None
1127
+ if isinstance(cid, str):
1128
+ lc = cid.lower()
1129
+ if lc.startswith("cluster_"):
1130
+ label = cid
1131
+ elif lc == "noise":
1132
+ label = "Noise"
1133
+ else:
1134
+ # Try to parse numeric from string
1135
+ try:
1136
+ num = int(cid)
1137
+ label = f"Cluster_{num}" if num >= 0 else "Noise"
1138
+ except Exception:
1139
+ label = None
1140
+ else:
1141
+ # numeric cluster id
1142
+ try:
1143
+ num = int(cid)
1144
+ label = f"Cluster_{num}" if num >= 0 else "Noise"
1145
+ except Exception:
1146
+ label = None
1147
+
1148
+ if label:
1149
+ labels.add(label)
1150
+
1151
+ # Sort labels numerically, keep Noise last if present
1152
+ def _label_key(lbl):
1153
+ if lbl == "Noise":
1154
+ return (1e9,)
1155
+ if lbl.lower().startswith("cluster_"):
1156
+ try:
1157
+ return (int(lbl.split("_", 1)[1]),)
1158
+ except Exception:
1159
+ return (1e8,)
1160
+ return (1e8,)
1161
+
1162
+ for lbl in sorted(labels, key=_label_key):
1163
+ self.single_cluster_combo.addItem(lbl)
1164
+
1165
+ # Restore previous selection if still available
1166
+ if current_selection and current_selection != "None":
1167
+ idx = self.single_cluster_combo.findText(current_selection)
1168
+ if idx >= 0:
1169
+ self.single_cluster_combo.setCurrentIndex(idx)
1170
+ finally:
1171
+ # Always unblock signals
1172
+ self.single_cluster_combo.blockSignals(False)
1173
+
1174
+ def _refresh_spatial_selectors(self):
1175
+ """Refresh video and object selectors for spatial distribution plot."""
1176
+ if self.metadata is None:
1177
+ return
1178
+
1179
+ # Block signals to prevent recursion
1180
+ self.spatial_video_combo.blockSignals(True)
1181
+ self.spatial_object_combo.blockSignals(True)
1182
+
1183
+ try:
1184
+ # Get current selections
1185
+ current_video = self.spatial_video_combo.currentText()
1186
+ current_object = self.spatial_object_combo.currentText()
1187
+
1188
+ # Get unique videos from metadata
1189
+ # Prefer 'group' column (contains original video/animal name)
1190
+ # Otherwise extract from 'video_id' (clip filenames)
1191
+ videos = ["All"]
1192
+ video_names = set()
1193
+
1194
+ if 'group' in self.metadata.columns:
1195
+ # Use group column which has original video names
1196
+ for v in self.metadata['group'].dropna().unique():
1197
+ v_str = str(v).strip()
1198
+ if v_str:
1199
+ video_names.add(v_str)
1200
+ elif 'video_id' in self.metadata.columns:
1201
+ # Extract video name from clip filenames
1202
+ # Clip format: {video_name}_clip_{clip_idx:06d}_obj{obj_id}.mp4
1203
+ import re
1204
+ for v in self.metadata['video_id'].dropna().unique():
1205
+ v_str = str(v)
1206
+ base = os.path.splitext(os.path.basename(v_str))[0]
1207
+ # Remove _clip_XXXXXX and _objX suffixes
1208
+ match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', base)
1209
+ if match:
1210
+ video_name = match.group(1)
1211
+ if video_name:
1212
+ video_names.add(video_name)
1213
+ else:
1214
+ # Fallback: use base name if pattern doesn't match
1215
+ if base:
1216
+ video_names.add(base)
1217
+
1218
+ for v in sorted(video_names):
1219
+ videos.append(v)
1220
+
1221
+ self.spatial_video_combo.clear()
1222
+ for v in videos:
1223
+ self.spatial_video_combo.addItem(v)
1224
+
1225
+ # Restore video selection
1226
+ if current_video in videos:
1227
+ idx = self.spatial_video_combo.findText(current_video)
1228
+ if idx >= 0:
1229
+ self.spatial_video_combo.setCurrentIndex(idx)
1230
+
1231
+ # Get unique objects from metadata
1232
+ object_col = 'object_id' if 'object_id' in self.metadata.columns else None
1233
+ objects = ["All"]
1234
+ if object_col:
1235
+ unique_objects = self.metadata[object_col].dropna().unique()
1236
+ for o in sorted(set(str(obj) for obj in unique_objects if str(obj).strip())):
1237
+ if o not in objects:
1238
+ objects.append(o)
1239
+
1240
+ self.spatial_object_combo.clear()
1241
+ for o in objects:
1242
+ self.spatial_object_combo.addItem(o)
1243
+
1244
+ # Restore object selection
1245
+ if current_object in objects:
1246
+ idx = self.spatial_object_combo.findText(current_object)
1247
+ if idx >= 0:
1248
+ self.spatial_object_combo.setCurrentIndex(idx)
1249
+
1250
+ finally:
1251
+ self.spatial_video_combo.blockSignals(False)
1252
+ self.spatial_object_combo.blockSignals(False)
1253
+
1254
+ def _on_spatial_video_changed(self):
1255
+ """Handle video selection change - refresh object list and update plot."""
1256
+ # Refresh object list based on selected video
1257
+ if self.metadata is None:
1258
+ return
1259
+
1260
+ selected_video = self.spatial_video_combo.currentText()
1261
+ object_col = 'object_id' if 'object_id' in self.metadata.columns else None
1262
+ video_col = 'video_id' if 'video_id' in self.metadata.columns else None
1263
+
1264
+ self.spatial_object_combo.blockSignals(True)
1265
+ try:
1266
+ current_object = self.spatial_object_combo.currentText()
1267
+ self.spatial_object_combo.clear()
1268
+ self.spatial_object_combo.addItem("All")
1269
+
1270
+ if object_col:
1271
+ if selected_video == "All":
1272
+ # Show all objects across all videos
1273
+ unique_objects = self.metadata[object_col].dropna().unique()
1274
+ else:
1275
+ # Filter objects by selected video
1276
+ if 'group' in self.metadata.columns:
1277
+ # Use group column for matching
1278
+ mask = self.metadata['group'].apply(lambda x: str(x).strip() == selected_video)
1279
+ elif video_col:
1280
+ # Extract video name from clip filenames and match
1281
+ import re
1282
+ def extract_video_name(clip_name):
1283
+ base = os.path.splitext(os.path.basename(str(clip_name)))[0]
1284
+ match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', base)
1285
+ return match.group(1) if match else base
1286
+ mask = self.metadata[video_col].apply(lambda x: extract_video_name(x) == selected_video)
1287
+ else:
1288
+ mask = pd.Series([True] * len(self.metadata), index=self.metadata.index)
1289
+ unique_objects = self.metadata.loc[mask, object_col].dropna().unique()
1290
+
1291
+ for o in sorted(set(str(obj) for obj in unique_objects if str(obj).strip())):
1292
+ self.spatial_object_combo.addItem(o)
1293
+
1294
+ # Try to restore previous selection
1295
+ idx = self.spatial_object_combo.findText(current_object)
1296
+ if idx >= 0:
1297
+ self.spatial_object_combo.setCurrentIndex(idx)
1298
+ finally:
1299
+ self.spatial_object_combo.blockSignals(False)
1300
+
1301
+ # Update plot
1302
+ self._update_plots_by_metadata()
1303
+
1304
+ def _refresh_cluster_export_list(self):
1305
+ """Refresh cluster list in export list widget."""
1306
+ if not hasattr(self, 'cluster_export_list') or self.clusters is None:
1307
+ return
1308
+
1309
+ self.cluster_export_list.clear()
1310
+
1311
+ # Get unique cluster labels
1312
+ labels = set()
1313
+ for cid in set(self.clusters):
1314
+ label = None
1315
+ if isinstance(cid, str):
1316
+ lc = cid.lower()
1317
+ if lc.startswith("cluster_"):
1318
+ label = cid
1319
+ elif lc == "noise":
1320
+ label = "Noise"
1321
+ else:
1322
+ try:
1323
+ num = int(cid)
1324
+ label = f"Cluster_{num}" if num >= 0 else "Noise"
1325
+ except Exception:
1326
+ label = None
1327
+ else:
1328
+ try:
1329
+ num = int(cid)
1330
+ label = f"Cluster_{num}" if num >= 0 else "Noise"
1331
+ except Exception:
1332
+ label = None
1333
+
1334
+ if label:
1335
+ labels.add(label)
1336
+
1337
+ # Sort labels numerically, keep Noise last if present
1338
+ def _label_key(lbl):
1339
+ if lbl == "Noise":
1340
+ return (1e9,)
1341
+ if lbl.lower().startswith("cluster_"):
1342
+ try:
1343
+ num = int(lbl.split('_')[1])
1344
+ return (0, num)
1345
+ except Exception:
1346
+ return (2, lbl)
1347
+ return (2, lbl)
1348
+
1349
+ sorted_labels = sorted(labels, key=_label_key)
1350
+
1351
+ for label in sorted_labels:
1352
+ self.cluster_export_list.addItem(label)
1353
+
1354
+ # Enable buttons if clusters are available
1355
+ if hasattr(self, 'extract_cluster_btn'):
1356
+ self.extract_cluster_btn.setEnabled(len(sorted_labels) > 0)
1357
+ if hasattr(self, 'exclude_clusters_btn'):
1358
+ self.exclude_clusters_btn.setEnabled(len(sorted_labels) > 0)
1359
+
1360
def _extract_cluster(self):
    """Save the samples of one selected cluster as NPZ matrix and metadata files.

    Exactly one entry of the cluster export list must be selected. The
    matrix is taken from the raw data or from the processed data (transposed
    back to features x samples) depending on the raw-data checkbox. Files
    are written to ``<experiment_path>/analysis_results``, or the current
    working directory when no experiment path is configured. A dialog then
    offers to load the exported dataset.

    Problems are reported through message boxes; unexpected exceptions are
    logged and shown in an error dialog.
    """
    if self.clusters is None:
        QMessageBox.warning(self, "No Data", "No clustering data available. Please perform clustering first.")
        return

    chosen_items = self.cluster_export_list.selectedItems()
    if not chosen_items:
        QMessageBox.warning(self, "No Selection", "Please select a cluster to extract.")
        return
    if len(chosen_items) > 1:
        QMessageBox.warning(self, "Multiple Selection", "Please select only one cluster for extraction. Use 'Exclude Selected Clusters' for multiple clusters.")
        return

    selected_cluster = chosen_items[0].text()
    use_raw_data = self.use_raw_data_checkbox.isChecked()

    try:
        # "Cluster_0" -> 0, "Noise" -> -1
        is_noise = selected_cluster.lower() == "noise"
        cluster_num = -1 if is_noise else int(selected_cluster.split('_')[-1])

        # Positions of the samples that belong to the selected cluster
        cluster_indices = [pos for pos, cid in enumerate(self.clusters) if cid == cluster_num]
        if not cluster_indices:
            QMessageBox.warning(self, "No Data", f"No samples found in {selected_cluster}")
            return

        # Pick the data source the user asked for
        if use_raw_data:
            if self.matrix_data is None:
                QMessageBox.warning(self, "No Data", "Raw data not available.")
                return
            data, data_type = self.matrix_data, "raw"
        else:
            if self.processed_data is None:
                QMessageBox.warning(self, "No Data", "Processed data not available. Please apply preprocessing first.")
                return
            # Transpose back to features x samples format
            data, data_type = self.processed_data.T, "processed"

        metadata = self.metadata
        if data is None or metadata is None:
            QMessageBox.warning(self, "No Data", "Original data not available. Please ensure data is loaded.")
            return

        # Columns are samples/snippets
        subset_data = data.iloc[:, cluster_indices]
        snippet_ids = subset_data.columns.tolist()

        # Matching metadata rows: by snippet id when a key column exists,
        # otherwise by position when the lengths line up.
        if 'snippet' in metadata.columns:
            snippet_col = 'snippet'
        elif 'span_id' in metadata.columns:
            snippet_col = 'span_id'
        else:
            snippet_col = None

        if snippet_col:
            subset_metadata = metadata[metadata[snippet_col].isin(snippet_ids)].copy()
        elif len(metadata) == len(self.clusters):
            subset_metadata = metadata.iloc[cluster_indices].copy()
        else:
            subset_metadata = metadata.copy()

        # Timestamp keeps file names unique
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        experiment_path = self.config.get("experiment_path")
        if experiment_path:
            output_dir = os.path.join(experiment_path, "analysis_results")
        else:
            output_dir = os.getcwd()
        os.makedirs(output_dir, exist_ok=True)

        # Matrix archive (features x samples)
        matrix_filename = f"matrix_{selected_cluster}_{data_type}_{timestamp}.npz"
        matrix_path = os.path.join(output_dir, matrix_filename)
        np.savez_compressed(
            matrix_path,
            matrix=subset_data.values,
            feature_names=np.array(subset_data.index.tolist(), dtype=object),
            snippet_ids=np.array(snippet_ids, dtype=object),
        )

        # Metadata archive
        metadata_filename = f"metadata_{selected_cluster}_{timestamp}.npz"
        metadata_path = os.path.join(output_dir, metadata_filename)
        np.savez_compressed(
            metadata_path,
            metadata=subset_metadata.values,
            columns=np.array(subset_metadata.columns, dtype=object),
        )

        msg = (f"Successfully extracted {len(cluster_indices)} samples from {selected_cluster}.\n"
               f"Data type: {data_type}\n"
               f"Saved as:\n- {matrix_filename} (shape: {subset_data.shape})\n"
               f"- {metadata_filename} (shape: {subset_metadata.shape})\n\n"
               f"Click 'Load Dataset' to load this data now, or 'OK' to continue with current data.")

        dialog = ClusterExportDialog(self, msg, matrix_path, metadata_path)
        dialog.exec()

        # Load dataset if user clicked "Load Dataset"
        if dialog.load_requested:
            self._load_files_async(matrix_path, metadata_path)

    except Exception as e:
        error_msg = f"Error extracting cluster data: {str(e)}"
        logger.error("Error extracting cluster data: %s", e, exc_info=True)
        QMessageBox.critical(self, "Error", error_msg)
1478
+
1479
def _exclude_clusters(self):
    """Save the current dataset without the selected clusters as NPZ files.

    The clusters to drop are read from ``cluster_export_list``; the
    ``use_raw_data_checkbox`` decides whether ``matrix_data`` or the
    transposed ``processed_data`` is exported. Two files (matrix and
    metadata) are written to ``<experiment_path>/analysis_results`` or, when
    no experiment path is configured, to the current working directory.
    Afterwards a ``ClusterExportDialog`` is shown and, if the user asks for
    it, the exported files are loaded via ``_load_files_async``.

    Any exception raised during the export is logged and reported in a
    message box instead of propagating.
    """
    if self.clusters is None:
        QMessageBox.warning(self, "No Data", "No clustering data available. Please perform clustering first.")
        return

    chosen_items = self.cluster_export_list.selectedItems()
    if not chosen_items:
        QMessageBox.warning(self, "No Selection", "Please select at least one cluster to exclude.")
        return

    clusters_to_exclude = [entry.text() for entry in chosen_items]
    use_raw_data = self.use_raw_data_checkbox.isChecked()

    try:
        # "Noise" stands for label -1; every other entry ends in "_<number>".
        exclude_nums = [
            -1 if label.lower() == "noise" else int(label.split('_')[-1])
            for label in clusters_to_exclude
        ]

        # Positions of the samples that survive the exclusion.
        keep_indices = []
        for position, cluster_id in enumerate(self.clusters):
            if cluster_id not in exclude_nums:
                keep_indices.append(position)

        if not keep_indices:
            QMessageBox.warning(self, "Error", "Excluding these clusters would remove all data!")
            return

        # Pick the matrix to export; both variants end up features x samples.
        if use_raw_data:
            if self.matrix_data is None:
                QMessageBox.warning(self, "No Data", "Raw data not available.")
                return
            data, data_type = self.matrix_data, "raw"
        else:
            if self.processed_data is None:
                QMessageBox.warning(self, "No Data", "Processed data not available. Please apply preprocessing first.")
                return
            data, data_type = self.processed_data.T, "processed"

        metadata = self.metadata
        if metadata is None or data is None:
            QMessageBox.warning(self, "No Data", "Original data not available. Please ensure data is loaded.")
            return

        subset_data = data.iloc[:, keep_indices]
        snippet_ids = subset_data.columns.tolist()

        # Locate the metadata column that carries the sample identifiers.
        if 'snippet' in metadata.columns:
            snippet_col = 'snippet'
        elif 'span_id' in metadata.columns:
            snippet_col = 'span_id'
        else:
            snippet_col = None

        if snippet_col:
            subset_metadata = metadata[metadata[snippet_col].isin(snippet_ids)].copy()
        elif len(metadata) == len(self.clusters):
            # No identifier column: fall back to positional alignment.
            subset_metadata = metadata.iloc[keep_indices].copy()
        else:
            subset_metadata = metadata.copy()

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        excluded_str = "_".join(str(num) for num in sorted(exclude_nums))

        experiment_path = self.config.get("experiment_path")
        output_dir = (
            os.path.join(experiment_path, "analysis_results")
            if experiment_path
            else os.getcwd()
        )
        os.makedirs(output_dir, exist_ok=True)

        # Matrix file
        matrix_filename = f"matrix_excluding_clusters_{excluded_str}_{data_type}_{timestamp}.npz"
        matrix_path = os.path.join(output_dir, matrix_filename)
        np.savez_compressed(
            matrix_path,
            matrix=subset_data.values,  # features x samples
            feature_names=np.array(subset_data.index.tolist(), dtype=object),
            snippet_ids=np.array(snippet_ids, dtype=object),
        )

        # Metadata file
        metadata_filename = f"metadata_excluding_clusters_{excluded_str}_{timestamp}.npz"
        metadata_path = os.path.join(output_dir, metadata_filename)
        np.savez_compressed(
            metadata_path,
            metadata=subset_metadata.values,
            columns=np.array(subset_metadata.columns, dtype=object),
        )

        # Summary numbers for the confirmation dialog.
        remaining_count = len(keep_indices)
        excluded_count = len(self.clusters) - remaining_count

        remaining_clusters = sorted({self.clusters[i] for i in keep_indices})
        remaining_labels = [
            cid if isinstance(cid, str) else (f"Cluster_{cid}" if cid >= 0 else "Noise")
            for cid in remaining_clusters
        ]

        msg = (f"Successfully excluded {len(clusters_to_exclude)} clusters.\n"
               f"Data type: {data_type}\n"
               f"Removed {excluded_count} samples, kept {remaining_count} samples.\n"
               f"Remaining clusters: {remaining_labels}\n\n"
               f"Saved as:\n- {matrix_filename} (shape: {subset_data.shape})\n"
               f"- {metadata_filename} (shape: {subset_metadata.shape})\n\n"
               f"Click 'Load Dataset' to load this data now, or 'OK' to continue with current data.")

        dialog = ClusterExportDialog(self, msg, matrix_path, metadata_path)
        dialog.exec()

        if dialog.load_requested:
            self._load_files_async(matrix_path, metadata_path)

    except Exception as e:
        error_msg = f"Error excluding clusters: {str(e)}"
        logger.error("Error excluding clusters: %s", e, exc_info=True)
        QMessageBox.critical(self, "Error", error_msg)
1612
+
1613
def _show_cluster_export_help(self):
    """Open an information box describing the cluster export options."""
    dialog_title = "Cluster Export Help"
    help_text = (
        "Cluster Export allows you to:\n\n"
        "-Extract Selected Cluster: Export a single cluster for subclustering analysis. "
        "This creates a new dataset containing only samples from the selected cluster, "
        "which you can then load and analyze separately.\n\n"
        "-Exclude Selected Clusters: Export data excluding specific clusters. "
        "This is useful for removing noise or artifacts from your analysis. "
        "You can select multiple clusters to exclude at once.\n\n"
        "-Use raw data: If checked, exports the original (unnormalized) data. "
        "Otherwise, exports the preprocessed data.\n\n"
        "All exports are saved as NPZ files in the analysis_results folder."
    )
    QMessageBox.information(self, dialog_title, help_text)
1629
+
1630
+ def _get_plot_theme(self):
1631
+ """Get the currently selected plot theme."""
1632
+ if hasattr(self, 'color_theme_combo'):
1633
+ theme = self.color_theme_combo.currentText()
1634
+ return theme if theme != "none" else None
1635
+ return "simple_white" # Default
1636
+
1637
def _export_plot(self):
    """Export the currently displayed figure as a PDF or SVG file.

    A save dialog is opened in ``<experiment_path>/analysis_results`` (or the
    current working directory when no experiment path is configured) with a
    timestamped default name. The output format follows the file extension;
    when the chosen name has neither ``.pdf`` nor ``.svg``, the format follows
    the filter selected in the dialog and the matching extension is appended
    (PDF remains the default).

    The image size is taken from the figure's explicit layout width/height
    when set, otherwise from the plot widget's on-screen size, otherwise
    1200x800. Failures are logged and shown in a message box.
    """
    if self.current_fig is None:
        QMessageBox.warning(self, "No Plot", "No plot available to export. Please run clustering first.")
        return

    # Determine default directory
    experiment_path = self.config.get("experiment_path")
    if experiment_path:
        default_dir = os.path.join(experiment_path, "analysis_results")
        os.makedirs(default_dir, exist_ok=True)
    else:
        default_dir = os.getcwd()

    # Timestamped default filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    default_filename = os.path.join(default_dir, f"clustering_plot_{timestamp}.pdf")

    file_path, selected_filter = QFileDialog.getSaveFileName(
        self,
        "Export Plot",
        default_filename,
        "PDF Files (*.pdf);;SVG Files (*.svg)"
    )

    if not file_path:
        return

    try:
        lowered_path = file_path.lower()
        if lowered_path.endswith('.svg'):
            format_type = 'svg'
        elif lowered_path.endswith('.pdf'):
            format_type = 'pdf'
        else:
            # No recognised extension: honour the filter picked in the dialog
            # instead of silently writing a PDF for an "SVG Files" selection.
            wants_svg = 'svg' in (selected_filter or '').lower()
            format_type = 'svg' if wants_svg else 'pdf'
            file_path += f'.{format_type}'

        # kaleido is imported only to verify that it is installed.
        try:
            import kaleido  # noqa: F401
        except ImportError:
            QMessageBox.critical(
                self,
                "Export Error",
                "The 'kaleido' package is required for PDF/SVG export.\n\n"
                "Please install it with: pip install kaleido"
            )
            return

        # Fallback dimensions
        fig_width = 1200
        fig_height = 800

        # Prefer the plot widget's rendered size to avoid stretching
        if hasattr(self, 'plot_widget') and self.plot_widget is not None:
            try:
                w = self.plot_widget.width()
                h = self.plot_widget.height()
                if w and h:
                    fig_width = max(int(w), 400)
                    fig_height = max(int(h), 300)
            except Exception as e:
                logger.debug("Could not read plot widget dimensions: %s", e)

        # Explicit layout sizes on the figure take precedence
        layout = getattr(self.current_fig, 'layout', None)
        if layout:
            if getattr(layout, 'width', None):
                fig_width = layout.width
            if getattr(layout, 'height', None):
                fig_height = layout.height

        pio.write_image(
            self.current_fig,
            file_path,
            format=format_type,
            width=fig_width,
            height=fig_height
        )

        QMessageBox.information(self, "Success", f"Plot exported successfully to:\n{file_path}")

    except Exception as e:
        error_msg = f"Error exporting plot: {str(e)}"
        logger.error("Error exporting plot: %s", e, exc_info=True)
        QMessageBox.critical(self, "Export Error", error_msg)
1727
+
1728
+ def _update_plot_type(self):
1729
+ """Update plot type based on selection."""
1730
+ self._update_plots_by_metadata()
1731
+
1732
+ def _update_plots_by_metadata(self):
1733
+ """Update plots based on selected metadata column and plot type."""
1734
+ if self.embedding is None or self.clusters is None or self.processed_data is None:
1735
+ return
1736
+
1737
+ plot_type = self.plot_type_combo.currentText()
1738
+
1739
+ # Show/hide single cluster selector based on plot type
1740
+ if plot_type in ["Single Cluster Analysis", "Spatial cluster distribution"]:
1741
+ self.single_cluster_label.setVisible(True)
1742
+ self.single_cluster_combo.setVisible(True)
1743
+ # Make sure cluster list is populated
1744
+ if self.clusters is not None:
1745
+ self._refresh_cluster_list()
1746
+ else:
1747
+ self.single_cluster_label.setVisible(False)
1748
+ self.single_cluster_combo.setVisible(False)
1749
+
1750
+ # Show/hide video and object selectors for spatial distribution
1751
+ if plot_type == "Spatial cluster distribution":
1752
+ self.spatial_video_label.setVisible(True)
1753
+ self.spatial_video_combo.setVisible(True)
1754
+ self.spatial_object_label.setVisible(True)
1755
+ self.spatial_object_combo.setVisible(True)
1756
+ # Populate video/object lists
1757
+ self._refresh_spatial_selectors()
1758
+ else:
1759
+ self.spatial_video_label.setVisible(False)
1760
+ self.spatial_video_combo.setVisible(False)
1761
+ self.spatial_object_label.setVisible(False)
1762
+ self.spatial_object_combo.setVisible(False)
1763
+
1764
+ # Handle "UMAP" plot type
1765
+ if plot_type == "UMAP":
1766
+ group_name = self.metadata_column_combo.currentText()
1767
+
1768
+ # Handle empty or "None" selection
1769
+ if not group_name or group_name.strip() == "" or group_name == "None (Show all clusters)":
1770
+ # Regenerate default clustering plot to show all clusters
1771
+ self._regenerate_default_plot()
1772
+ if hasattr(self, 'export_plot_btn'):
1773
+ self.export_plot_btn.setEnabled(True)
1774
+ if hasattr(self, 'export_csv_btn'):
1775
+ self.export_csv_btn.setEnabled(True)
1776
+ return
1777
+
1778
+ if self.metadata is None:
1779
+ QMessageBox.warning(self, "No Metadata", "Please load metadata first.")
1780
+ return
1781
+
1782
+ # Validate column exists
1783
+ if group_name not in self.metadata.columns:
1784
+ QMessageBox.warning(self, "Invalid Column", f"Column '{group_name}' not found in metadata.\n\nAvailable columns: {', '.join(self.metadata.columns)}")
1785
+ # Reset to "None" selection
1786
+ self.metadata_column_combo.setCurrentIndex(0)
1787
+ return
1788
+
1789
+ try:
1790
+ point_size = self.plot_point_size_slider.value()
1791
+ umap_fig, props_fig, single_fig = self._create_grouped_plots(group_name, point_size, None)
1792
+
1793
+ if umap_fig:
1794
+ self.current_fig = umap_fig
1795
+ self.plot_widget.update_plot(umap_fig)
1796
+ if hasattr(self, 'export_plot_btn'):
1797
+ self.export_plot_btn.setEnabled(True)
1798
+ if hasattr(self, 'export_csv_btn'):
1799
+ self.export_csv_btn.setEnabled(True)
1800
+
1801
+ except Exception as e:
1802
+ logger.error("Error creating plots: %s", e, exc_info=True)
1803
+ QMessageBox.critical(self, "Plot Error", f"Error creating plots: {e}")
1804
+
1805
+ # Handle "Cluster Proportions" plot type
1806
+ elif plot_type == "Cluster Proportions":
1807
+ group_name = self.metadata_column_combo.currentText()
1808
+
1809
+ if not group_name or group_name.strip() == "" or group_name == "None (Show all clusters)":
1810
+ QMessageBox.warning(self, "No Group Selected", "Please select a metadata column to show proportions.")
1811
+ return
1812
+
1813
+ if self.metadata is None:
1814
+ QMessageBox.warning(self, "No Metadata", "Please load metadata first.")
1815
+ return
1816
+
1817
+ if group_name not in self.metadata.columns:
1818
+ QMessageBox.warning(self, "Invalid Column", f"Column '{group_name}' not found in metadata.")
1819
+ return
1820
+
1821
+ try:
1822
+ point_size = self.plot_point_size_slider.value()
1823
+ umap_fig, props_fig, single_fig = self._create_grouped_plots(group_name, point_size, None)
1824
+
1825
+ if props_fig:
1826
+ self.current_fig = props_fig
1827
+ self.plot_widget.update_plot(props_fig)
1828
+ if hasattr(self, 'export_plot_btn'):
1829
+ self.export_plot_btn.setEnabled(True)
1830
+ if hasattr(self, 'export_csv_btn'):
1831
+ self.export_csv_btn.setEnabled(True)
1832
+ else:
1833
+ QMessageBox.warning(self, "Plot Error", "Could not create proportions plot.")
1834
+
1835
+ except Exception as e:
1836
+ logger.error("Error creating plots: %s", e, exc_info=True)
1837
+ QMessageBox.critical(self, "Plot Error", f"Error creating plots: {e}")
1838
+
1839
+ # Handle "Single Cluster Analysis" plot type
1840
+ elif plot_type == "Single Cluster Analysis":
1841
+ # Check if clusters are available
1842
+ if self.clusters is None:
1843
+ QMessageBox.warning(self, "No Clusters", "Please run clustering first.")
1844
+ return
1845
+
1846
+ # Make sure cluster list is populated
1847
+ self._refresh_cluster_list()
1848
+
1849
+ group_name = self.metadata_column_combo.currentText()
1850
+ selected_cluster = self.single_cluster_combo.currentText() if self.single_cluster_combo.currentText() != "None" else None
1851
+
1852
+ if not selected_cluster:
1853
+ # Don't show warning immediately - let user select first
1854
+ # Just clear the plot or show a message
1855
+ self.status_label.setText("Please select a cluster from the dropdown above.")
1856
+ return
1857
+
1858
+ if not group_name or group_name.strip() == "" or group_name == "None (Show all clusters)":
1859
+ QMessageBox.warning(self, "No Group Selected", "Please select a metadata column to show single cluster proportions.")
1860
+ return
1861
+
1862
+ if self.metadata is None:
1863
+ QMessageBox.warning(self, "No Metadata", "Please load metadata first.")
1864
+ return
1865
+
1866
+ if group_name not in self.metadata.columns:
1867
+ QMessageBox.warning(self, "Invalid Column", f"Column '{group_name}' not found in metadata.")
1868
+ return
1869
+
1870
+ try:
1871
+ point_size = self.plot_point_size_slider.value()
1872
+ umap_fig, props_fig, single_fig = self._create_grouped_plots(group_name, point_size, selected_cluster)
1873
+
1874
+ if single_fig:
1875
+ self.current_fig = single_fig
1876
+ self.plot_widget.update_plot(single_fig)
1877
+ if hasattr(self, 'export_plot_btn'):
1878
+ self.export_plot_btn.setEnabled(True)
1879
+ if hasattr(self, 'export_csv_btn'):
1880
+ self.export_csv_btn.setEnabled(True)
1881
+ else:
1882
+ QMessageBox.warning(self, "Plot Error", "Could not create single cluster plot.")
1883
+
1884
+ except Exception as e:
1885
+ logger.error("Error creating plots: %s", e, exc_info=True)
1886
+ QMessageBox.critical(self, "Plot Error", f"Error creating plots: {e}")
1887
+
1888
+ # Handle "Spatial cluster distribution" plot type
1889
+ elif plot_type == "Spatial cluster distribution":
1890
+ # Check if clusters are available
1891
+ if self.clusters is None:
1892
+ self.status_label.setText("Please run clustering first.")
1893
+ return
1894
+
1895
+ # Make sure cluster list is populated
1896
+ if hasattr(self, 'single_cluster_combo'):
1897
+ self._refresh_cluster_list()
1898
+ selected_cluster = self.single_cluster_combo.currentText() if self.single_cluster_combo.currentText() != "None" else None
1899
+ else:
1900
+ selected_cluster = None
1901
+
1902
+ if not selected_cluster:
1903
+ self.status_label.setText("Please select a cluster from the dropdown above to visualize its spatial distribution.")
1904
+ return
1905
+
1906
+ try:
1907
+ spatial_fig = self._create_spatial_distribution_plot(selected_cluster)
1908
+
1909
+ if spatial_fig:
1910
+ self.current_fig = spatial_fig
1911
+ self.plot_widget.update_plot(spatial_fig)
1912
+ if hasattr(self, 'export_plot_btn'):
1913
+ self.export_plot_btn.setEnabled(True)
1914
+ if hasattr(self, 'export_csv_btn'):
1915
+ self.export_csv_btn.setEnabled(True)
1916
+ self.status_label.setText(f"Spatial distribution plot for {selected_cluster} displayed.")
1917
+ else:
1918
+ self.status_label.setText("Could not create spatial distribution plot. Make sure mask data is available in the experiment folder.")
1919
+
1920
+ except Exception as e:
1921
+ error_msg = f"Error creating spatial distribution plot: {str(e)}"
1922
+ self.status_label.setText(error_msg)
1923
+ logger.error("Error creating spatial distribution plot: %s", e, exc_info=True)
1924
+ QMessageBox.critical(self, "Plot Error", error_msg)
1925
+
1926
+ def _create_grouped_plots(self, group_name, point_size, selected_cluster=None):
1927
+ """Create UMAP subplots grouped by metadata column, proportion plot, and single cluster plot."""
1928
+ # Create base dataframe with UMAP coordinates and clusters
1929
+ df_plot = pd.DataFrame({
1930
+ 'UMAP1': self.embedding[:, 0],
1931
+ 'UMAP2': self.embedding[:, 1],
1932
+ 'Cluster': [f'Cluster_{c}' if c >= 0 else 'Noise' for c in self.clusters],
1933
+ 'Sample': self.processed_data.index
1934
+ })
1935
+
1936
+ # Validate group_name exists in metadata
1937
+ if group_name not in self.metadata.columns:
1938
+ return None, None, None
1939
+
1940
+ # Merge metadata with processed data
1941
+ snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
1942
+
1943
+ # Prepare columns to merge (only what we need)
1944
+ cols_to_merge = [group_name]
1945
+ if snippet_col and snippet_col in self.metadata.columns:
1946
+ cols_to_merge.append(snippet_col)
1947
+
1948
+ if snippet_col and snippet_col in self.metadata.columns:
1949
+ # Merge using snippet column
1950
+ df_plot = df_plot.merge(
1951
+ self.metadata[cols_to_merge],
1952
+ left_on='Sample',
1953
+ right_on=snippet_col,
1954
+ how='left'
1955
+ )
1956
+ else:
1957
+ # Try to align by index - reset index if needed
1958
+ metadata_aligned = self.metadata.copy()
1959
+ if len(metadata_aligned) == len(df_plot):
1960
+ # Align by position
1961
+ metadata_aligned.index = df_plot.index
1962
+ df_plot = pd.concat([df_plot, metadata_aligned[[group_name]]], axis=1)
1963
+ else:
1964
+ # Try to merge by resetting index
1965
+ metadata_reset = self.metadata.reset_index()
1966
+ if 'index' in metadata_reset.columns:
1967
+ df_plot = df_plot.merge(metadata_reset[cols_to_merge + ['index']], left_on='Sample', right_on='index', how='left')
1968
+ else:
1969
+ # Last resort: align by position if lengths match
1970
+ if len(metadata_aligned) == len(df_plot):
1971
+ df_plot[group_name] = metadata_aligned[group_name].values
1972
+ else:
1973
+ return None, None, None
1974
+
1975
+ # Ensure Cluster column is still present (merge might have dropped it)
1976
+ if 'Cluster' not in df_plot.columns:
1977
+ df_plot['Cluster'] = [f'Cluster_{c}' if c >= 0 else 'Noise' for c in self.clusters]
1978
+
1979
+ if group_name not in df_plot.columns or df_plot[group_name].isnull().all():
1980
+ return None, None, None
1981
+
1982
+ # Filter out NaN values
1983
+ df_plot = df_plot.dropna(subset=[group_name])
1984
+
1985
+ # Generate cluster colors
1986
+ unique_clusters = sorted(df_plot['Cluster'].unique())
1987
+ n_clusters = len(unique_clusters)
1988
+ cluster_colors = self._generate_cluster_colors(n_clusters)
1989
+ cluster_color_map = {f'Cluster_{i}': cluster_colors[i] for i in range(n_clusters) if f'Cluster_{i}' in unique_clusters}
1990
+ if 'Noise' in unique_clusters:
1991
+ cluster_color_map['Noise'] = '#808080'
1992
+
1993
+ # Get unique values in the group
1994
+ unique_values = sorted(df_plot[group_name].dropna().unique())
1995
+ n_groups = len(unique_values)
1996
+
1997
+ if n_groups == 0:
1998
+ return None, None, None
1999
+
2000
+ # Create UMAP subplots
2001
+ n_cols = 2
2002
+ n_rows = (n_groups + 1 + n_cols - 1) // n_cols # Add 1 for combined view
2003
+
2004
+ subplot_titles = [f'{group_name}: {val}' for val in unique_values]
2005
+ subplot_titles.append(f'All {group_name} Groups Combined')
2006
+
2007
+ umap_fig = make_subplots(
2008
+ rows=n_rows,
2009
+ cols=n_cols,
2010
+ subplot_titles=subplot_titles,
2011
+ horizontal_spacing=0.05,
2012
+ vertical_spacing=0.08
2013
+ )
2014
+
2015
+ # Calculate global axis ranges
2016
+ umap1_min, umap1_max = df_plot['UMAP1'].min(), df_plot['UMAP1'].max()
2017
+ umap2_min, umap2_max = df_plot['UMAP2'].min(), df_plot['UMAP2'].max()
2018
+ range1 = umap1_max - umap1_min
2019
+ range2 = umap2_max - umap2_min
2020
+ max_range = max(range1, range2)
2021
+ center1 = (umap1_min + umap1_max) / 2
2022
+ center2 = (umap2_min + umap2_max) / 2
2023
+ padding = max_range * 0.1
2024
+ axis_min1 = center1 - (max_range / 2) - padding
2025
+ axis_max1 = center1 + (max_range / 2) + padding
2026
+ axis_min2 = center2 - (max_range / 2) - padding
2027
+ axis_max2 = center2 + (max_range / 2) + padding
2028
+
2029
+ # Plot individual groups
2030
+ for idx, value in enumerate(unique_values):
2031
+ row = idx // n_cols + 1
2032
+ col = idx % n_cols + 1
2033
+
2034
+ mask = df_plot[group_name] == value
2035
+ df_subset = df_plot[mask]
2036
+
2037
+ for cluster in unique_clusters:
2038
+ cluster_data = df_subset[df_subset['Cluster'] == cluster]
2039
+ if len(cluster_data) == 0:
2040
+ continue
2041
+
2042
+ # Add snippet IDs as customdata for click handling
2043
+ snippet_ids = cluster_data['Sample'].tolist()
2044
+
2045
+ umap_fig.add_trace(
2046
+ go.Scatter(
2047
+ x=cluster_data['UMAP1'],
2048
+ y=cluster_data['UMAP2'],
2049
+ mode='markers',
2050
+ marker=dict(size=point_size, color=cluster_color_map.get(cluster, '#CCCCCC')),
2051
+ name=cluster,
2052
+ showlegend=(idx == 0),
2053
+ legendgroup=cluster,
2054
+ customdata=[[sid] for sid in snippet_ids], # Wrap in list for each point
2055
+ hovertemplate='<b>%{hovertext}</b><br>Cluster: ' + cluster + '<br>X: %{x:.2f}<br>Y: %{y:.2f}<extra></extra>',
2056
+ hovertext=snippet_ids
2057
+ ),
2058
+ row=row, col=col
2059
+ )
2060
+
2061
+ umap_fig.update_xaxes(range=[axis_min1, axis_max1], row=row, col=col)
2062
+ umap_fig.update_yaxes(range=[axis_min2, axis_max2], row=row, col=col)
2063
+
2064
+ # Combined view
2065
+ combined_row = (n_groups // n_cols) + 1
2066
+ combined_col = (n_groups % n_cols) + 1
2067
+
2068
+ # Generate group colors for combined view
2069
+ group_colors = self._generate_pastel_palette(n_groups)
2070
+ group_color_map = {val: group_colors[i] for i, val in enumerate(unique_values)}
2071
+
2072
+ for value in unique_values:
2073
+ mask = df_plot[group_name] == value
2074
+ df_subset = df_plot[mask]
2075
+
2076
+ # Add snippet IDs as customdata for click handling
2077
+ snippet_ids_combined = df_subset['Sample'].tolist()
2078
+
2079
+ umap_fig.add_trace(
2080
+ go.Scatter(
2081
+ x=df_subset['UMAP1'],
2082
+ y=df_subset['UMAP2'],
2083
+ mode='markers',
2084
+ marker=dict(size=point_size, color=group_color_map[value]),
2085
+ name=f'{group_name}: {value}',
2086
+ showlegend=True,
2087
+ legendgroup=f'group_{value}',
2088
+ customdata=[[sid] for sid in snippet_ids_combined], # Wrap in list for each point
2089
+ hovertemplate='<b>%{hovertext}</b><br>' + group_name + ': ' + str(value) + '<br>X: %{x:.2f}<br>Y: %{y:.2f}<extra></extra>',
2090
+ hovertext=snippet_ids_combined
2091
+ ),
2092
+ row=combined_row, col=combined_col
2093
+ )
2094
+
2095
+ umap_fig.update_xaxes(range=[axis_min1, axis_max1], row=combined_row, col=combined_col)
2096
+ umap_fig.update_yaxes(range=[axis_min2, axis_max2], row=combined_row, col=combined_col)
2097
+
2098
+ theme = self._get_plot_theme()
2099
+ umap_fig.update_layout(
2100
+ title=f'UMAP + Leiden Clustering by {group_name}',
2101
+ template=theme if theme else None,
2102
+ height=300 * n_rows
2103
+ )
2104
+
2105
+ # Create proportion plot
2106
+ props_fig = self._create_proportion_plot(df_plot, group_name, cluster_color_map)
2107
+
2108
+ # Create single cluster plot if selected and exists in data
2109
+ single_fig = None
2110
+ if selected_cluster:
2111
+ if selected_cluster in df_plot['Cluster'].unique():
2112
+ single_fig = self._create_single_cluster_plot(df_plot, group_name, selected_cluster, cluster_color_map)
2113
+ else:
2114
+ single_fig = None
2115
+
2116
+ return umap_fig, props_fig, single_fig
2117
+
2118
+ def _create_proportion_plot(self, df_plot, group_name, cluster_color_map):
2119
+ """Create cluster proportion bar plot."""
2120
+ # Calculate proportions
2121
+ props = []
2122
+ for group_val in sorted(df_plot[group_name].dropna().unique()):
2123
+ group_data = df_plot[df_plot[group_name] == group_val]
2124
+ total = len(group_data)
2125
+ cluster_props = group_data['Cluster'].value_counts() / total * 100
2126
+ for cluster in cluster_props.index:
2127
+ props.append({
2128
+ group_name: group_val,
2129
+ 'Cluster': cluster,
2130
+ 'Proportion (%)': cluster_props[cluster]
2131
+ })
2132
+
2133
+ df_props = pd.DataFrame(props)
2134
+
2135
+ # Create bar plot
2136
+ unique_groups = sorted(df_props[group_name].unique())
2137
+ unique_clusters = sorted(df_props['Cluster'].unique())
2138
+
2139
+ fig = go.Figure()
2140
+
2141
+ n_groups = len(unique_groups)
2142
+ n_clusters = len(unique_clusters)
2143
+ bar_width = 0.8 / n_groups
2144
+
2145
+ for i, group in enumerate(unique_groups):
2146
+ group_data = df_props[df_props[group_name] == group]
2147
+
2148
+ x_positions = []
2149
+ proportions = []
2150
+
2151
+ for j, cluster in enumerate(unique_clusters):
2152
+ cluster_row = group_data[group_data['Cluster'] == cluster]
2153
+ prop = cluster_row['Proportion (%)'].iloc[0] if len(cluster_row) > 0 else 0
2154
+ x_pos = j + (i - n_groups/2 + 0.5) * bar_width
2155
+ x_positions.append(x_pos)
2156
+ proportions.append(prop)
2157
+
2158
+ fig.add_trace(go.Bar(
2159
+ x=x_positions,
2160
+ y=proportions,
2161
+ name=f"{group_name}: {group}",
2162
+ width=bar_width
2163
+ ))
2164
+
2165
+ theme = self._get_plot_theme()
2166
+ fig.update_layout(
2167
+ title=f'Cluster Proportions by {group_name}',
2168
+ xaxis_title='Cluster',
2169
+ yaxis_title='Proportion (%)',
2170
+ height=400,
2171
+ template=theme if theme else None,
2172
+ barmode='group',
2173
+ xaxis=dict(
2174
+ tickmode='array',
2175
+ tickvals=list(range(n_clusters)),
2176
+ ticktext=[str(c) for c in unique_clusters]
2177
+ )
2178
+ )
2179
+
2180
+ return fig
2181
+
2182
+ def _create_single_cluster_plot(self, df_plot, group_name, selected_cluster, cluster_color_map):
2183
+ """Create single cluster proportion bar plot."""
2184
+ # Extract cluster number from "Cluster_X" format
2185
+ cluster_num = selected_cluster.replace('Cluster_', '')
2186
+
2187
+ props = []
2188
+ for group_val in sorted(df_plot[group_name].dropna().unique()):
2189
+ group_data = df_plot[df_plot[group_name] == group_val]
2190
+ total = len(group_data)
2191
+ cluster_count = len(group_data[group_data['Cluster'] == selected_cluster])
2192
+ proportion = (cluster_count / total) * 100 if total > 0 else 0
2193
+ props.append({
2194
+ group_name: group_val,
2195
+ 'Proportion (%)': proportion
2196
+ })
2197
+
2198
+ df_props = pd.DataFrame(props)
2199
+
2200
+ color = cluster_color_map.get(selected_cluster, '#CCCCCC')
2201
+
2202
+ fig = go.Figure()
2203
+ fig.add_trace(go.Bar(
2204
+ x=df_props[group_name],
2205
+ y=df_props['Proportion (%)'],
2206
+ marker_color=color,
2207
+ name=selected_cluster,
2208
+ text=df_props['Proportion (%)'].round(1).astype(str) + '%',
2209
+ textposition='auto'
2210
+ ))
2211
+
2212
+ theme = self._get_plot_theme()
2213
+ fig.update_layout(
2214
+ title=f'Proportion of {selected_cluster} Across {group_name}',
2215
+ xaxis_title=group_name,
2216
+ yaxis_title='Proportion (%)',
2217
+ height=400,
2218
+ template=theme if theme else None
2219
+ )
2220
+
2221
+ return fig
2222
+
2223
+ def _create_spatial_distribution_plot(self, selected_cluster):
2224
+ """Create spatial distribution plot showing trajectory with cluster clip segments overlaid."""
2225
+ try:
2226
+ if self.metadata is None or self.clusters is None or self.processed_data is None:
2227
+ return None
2228
+
2229
+ # Get selected video and object filters
2230
+ selected_video = self.spatial_video_combo.currentText() if hasattr(self, 'spatial_video_combo') else "All"
2231
+ selected_object = self.spatial_object_combo.currentText() if hasattr(self, 'spatial_object_combo') else "All"
2232
+
2233
+ # Get full trajectory and clip trajectories from mask data
2234
+ trajectory_data = self._get_trajectory_data(selected_video, selected_object)
2235
+ if trajectory_data is None:
2236
+ return None
2237
+
2238
+ all_trajectories = trajectory_data.get('all_trajectories', [])
2239
+ clip_trajectories = trajectory_data.get('clip_trajectories', {})
2240
+
2241
+ if not all_trajectories:
2242
+ return None
2243
+
2244
+ # Get theme
2245
+ theme = self._get_plot_theme()
2246
+
2247
+ # Create figure
2248
+ fig = go.Figure()
2249
+
2250
+ # Plot all trajectories as base layer (grey lines)
2251
+ for traj_idx, traj_info in enumerate(all_trajectories):
2252
+ traj = traj_info['trajectory']
2253
+ traj_name = traj_info.get('name', f'Trajectory {traj_idx + 1}')
2254
+
2255
+ if len(traj) > 0:
2256
+ traj_x = [p[0] for p in traj]
2257
+ traj_y = [p[1] for p in traj]
2258
+
2259
+ fig.add_trace(go.Scatter(
2260
+ x=traj_x,
2261
+ y=traj_y,
2262
+ mode='lines',
2263
+ line=dict(
2264
+ color='lightgrey',
2265
+ width=1
2266
+ ),
2267
+ name=traj_name if traj_idx == 0 else None,
2268
+ showlegend=(traj_idx == 0),
2269
+ legendgroup='trajectories',
2270
+ hoverinfo='skip'
2271
+ ))
2272
+
2273
+ # Get cluster color
2274
+ unique_clusters = sorted(set(self.clusters))
2275
+ cluster_colors = self._generate_cluster_colors(len(unique_clusters))
2276
+ cluster_to_color = {f'Cluster_{c}' if c >= 0 else 'Noise': cluster_colors[i]
2277
+ for i, c in enumerate(unique_clusters)}
2278
+ cluster_color = cluster_to_color.get(selected_cluster, '#e74c3c')
2279
+
2280
+ # Get snippets belonging to selected cluster (filtered by video/object)
2281
+ snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
2282
+ video_col = 'video_id' if 'video_id' in self.metadata.columns else None
2283
+ object_col = 'object_id' if 'object_id' in self.metadata.columns else None
2284
+
2285
+ if snippet_col is None:
2286
+ return None
2287
+
2288
+ # Build filter mask for metadata
2289
+ filter_mask = pd.Series([True] * len(self.metadata), index=self.metadata.index)
2290
+ if selected_video != "All":
2291
+ if 'group' in self.metadata.columns:
2292
+ # Use group column for matching
2293
+ filter_mask &= self.metadata['group'].apply(lambda x: str(x).strip() == selected_video)
2294
+ elif video_col:
2295
+ # Extract video name from clip filenames and match
2296
+ import re
2297
+ def extract_video_name(clip_name):
2298
+ base = os.path.splitext(os.path.basename(str(clip_name)))[0]
2299
+ match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', base)
2300
+ return match.group(1) if match else base
2301
+ filter_mask &= self.metadata[video_col].apply(lambda x: extract_video_name(x) == selected_video)
2302
+ if object_col and selected_object != "All":
2303
+ filter_mask &= self.metadata[object_col].apply(lambda x: str(x) == selected_object)
2304
+
2305
+ filtered_snippets = set(self.metadata.loc[filter_mask, snippet_col].values)
2306
+
2307
+ # Map snippets to clusters
2308
+ cluster_snippets = []
2309
+ for i, snip in enumerate(self.processed_data.index):
2310
+ if snip not in filtered_snippets:
2311
+ continue
2312
+ cluster_label = f'Cluster_{self.clusters[i]}' if self.clusters[i] >= 0 else 'Noise'
2313
+ if cluster_label == selected_cluster:
2314
+ cluster_snippets.append(snip)
2315
+
2316
+ # Plot clip trajectory segments for selected cluster
2317
+ segment_count = 0
2318
+ for snippet_id in cluster_snippets:
2319
+ if snippet_id in clip_trajectories:
2320
+ clip_traj = clip_trajectories[snippet_id]
2321
+ if len(clip_traj) >= 1:
2322
+ clip_x = [p[0] for p in clip_traj]
2323
+ clip_y = [p[1] for p in clip_traj]
2324
+
2325
+ fig.add_trace(go.Scatter(
2326
+ x=clip_x,
2327
+ y=clip_y,
2328
+ mode='lines+markers',
2329
+ line=dict(
2330
+ color=cluster_color,
2331
+ width=3
2332
+ ),
2333
+ marker=dict(
2334
+ size=4,
2335
+ color=cluster_color
2336
+ ),
2337
+ name=selected_cluster if segment_count == 0 else None,
2338
+ showlegend=(segment_count == 0),
2339
+ legendgroup='cluster',
2340
+ hovertext=f'{snippet_id}',
2341
+ hoverinfo='text'
2342
+ ))
2343
+ segment_count += 1
2344
+
2345
+ # Build title with filter info
2346
+ filter_info = []
2347
+ if selected_video != "All":
2348
+ filter_info.append(f"Video: {selected_video}")
2349
+ if selected_object != "All":
2350
+ filter_info.append(f"Object: {selected_object}")
2351
+ filter_str = f" | {', '.join(filter_info)}" if filter_info else ""
2352
+
2353
+ # Update layout
2354
+ fig.update_layout(
2355
+ title=f'Spatial Distribution: {selected_cluster} ({segment_count} clips){filter_str}',
2356
+ xaxis_title='X Position (pixels)',
2357
+ yaxis_title='Y Position (pixels)',
2358
+ xaxis=dict(scaleanchor="y", scaleratio=1),
2359
+ yaxis=dict(autorange='reversed'),
2360
+ height=600,
2361
+ template=theme if theme else None,
2362
+ hovermode='closest'
2363
+ )
2364
+
2365
+ return fig
2366
+ except Exception as e:
2367
+ logger.error("Error creating spatial distribution plot: %s", e, exc_info=True)
2368
+ return None
2369
+
2370
def get_representative_snippets(self, n_samples=10):
    """Find representative snippets for each cluster (closest to centroid).

    For every non-noise cluster, the centroid of its rows in
    ``self.processed_data`` is computed and the rows with the smallest
    Euclidean distance to that centroid are selected.

    Args:
        n_samples: Maximum number of snippets to return per cluster.
            Negative values are treated as zero.

    Returns:
        dict: {cluster_label: [snippet_id1, snippet_id2, ...]}, ordered from
        closest to farthest from the centroid. Clusters with a negative id
        (noise) are skipped. An empty dict is returned when processed data
        or cluster labels are missing.
    """
    if self.processed_data is None or self.clusters is None:
        return {}

    labels = np.asarray(self.clusters)
    limit = max(0, int(n_samples))
    representative_snippets = {}

    for cluster_id in sorted(set(labels.tolist())):
        if cluster_id < 0:
            # Negative ids are the 'Noise' label; noise has no centroid.
            continue

        # Positional indices of the rows that belong to this cluster.
        positions = np.flatnonzero(labels == cluster_id)
        members = self.processed_data.iloc[positions]
        features = members.to_numpy(dtype=float)
        snippet_ids = members.index.tolist()

        centroid = features.mean(axis=0)
        distances = np.linalg.norm(features - centroid, axis=1)

        # A stable sort keeps the original row order for equidistant rows,
        # so the result is reproducible between runs.
        closest = np.argsort(distances, kind='stable')[:limit]
        representative_snippets[f'Cluster_{cluster_id}'] = [
            snippet_ids[i] for i in closest
        ]

    return representative_snippets
2412
+
2413
+ def _get_trajectory_data(self, selected_video="All", selected_object="All"):
2414
+ """Get full trajectory and clip trajectory segments from mask data.
2415
+
2416
+ Args:
2417
+ selected_video: Filter by video name, or "All" for all videos
2418
+ selected_object: Filter by object ID, or "All" for all objects
2419
+ """
2420
+ try:
2421
+ if self.metadata is None:
2422
+ return None
2423
+
2424
+ # Get experiment path
2425
+ experiment_path = self.config.get("experiment_path")
2426
+ if not experiment_path or not os.path.exists(experiment_path):
2427
+ return None
2428
+
2429
+ # Try to find mask files in common locations
2430
+ possible_mask_dirs = [
2431
+ os.path.join(experiment_path, "masks"),
2432
+ os.path.join(experiment_path, "segmentation_masks"),
2433
+ experiment_path
2434
+ ]
2435
+
2436
+ mask_files_found = []
2437
+ for mask_dir in possible_mask_dirs:
2438
+ if os.path.exists(mask_dir):
2439
+ try:
2440
+ for f in os.listdir(mask_dir):
2441
+ if f.endswith(('.h5', '.hdf5')):
2442
+ mask_files_found.append((mask_dir, f))
2443
+ except (OSError, PermissionError):
2444
+ continue
2445
+
2446
+ if not mask_files_found:
2447
+ return None
2448
+
2449
+ # Get video names from metadata for matching
2450
+ metadata_video_names = set()
2451
+ if 'group' in self.metadata.columns:
2452
+ for v in self.metadata['group'].dropna().unique():
2453
+ metadata_video_names.add(str(v).strip())
2454
+ elif 'video_id' in self.metadata.columns:
2455
+ import re
2456
+ for v in self.metadata['video_id'].dropna().unique():
2457
+ base = os.path.splitext(os.path.basename(str(v)))[0]
2458
+ match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', base)
2459
+ if match:
2460
+ metadata_video_names.add(match.group(1))
2461
+ else:
2462
+ metadata_video_names.add(base)
2463
+
2464
+ from singlebehaviorlab.backend.video_processor import load_segmentation_data
2465
+
2466
+ all_trajectories = [] # List of {name, trajectory}
2467
+ all_frame_centroids = {} # (video_name, obj_id, frame_idx) -> (cx, cy)
2468
+
2469
+ # Load trajectories from each mask file
2470
+ for mask_dir, mask_file in mask_files_found:
2471
+ mask_path = os.path.join(mask_dir, mask_file)
2472
+
2473
+ # Extract video name from mask filename
2474
+ mask_base = os.path.splitext(mask_file)[0]
2475
+ mask_video_name = mask_base.replace('_mask', '').replace('_objects', '').replace('_segmentation', '')
2476
+
2477
+ # Find matching video name from metadata
2478
+ matched_video_name = None
2479
+ for meta_video_name in metadata_video_names:
2480
+ # Check if mask filename contains metadata video name or vice versa
2481
+ if meta_video_name in mask_video_name or mask_video_name in meta_video_name:
2482
+ matched_video_name = meta_video_name
2483
+ break
2484
+
2485
+ # If no match found, use mask filename as video name
2486
+ if matched_video_name is None:
2487
+ matched_video_name = mask_video_name
2488
+
2489
+ # Filter by video if not "All"
2490
+ if selected_video != "All" and selected_video != matched_video_name:
2491
+ continue
2492
+
2493
+ try:
2494
+ mask_data = load_segmentation_data(mask_path)
2495
+ frame_objects = mask_data.get('frame_objects', [])
2496
+
2497
+ if not frame_objects:
2498
+ continue
2499
+
2500
+ video_height = mask_data.get('height', 1080)
2501
+ video_width = mask_data.get('width', 1920)
2502
+
2503
+ # Get unique object IDs in this mask file
2504
+ obj_ids = set()
2505
+ for frame_objs in frame_objects:
2506
+ for obj in frame_objs:
2507
+ obj_id = obj.get('obj_id', 0)
2508
+ obj_ids.add(str(obj_id))
2509
+
2510
+ # Process each object
2511
+ for obj_id_str in obj_ids:
2512
+ # Filter by object if not "All"
2513
+ if selected_object != "All" and selected_object != obj_id_str:
2514
+ continue
2515
+
2516
+ trajectory = []
2517
+
2518
+ for frame_idx, frame_objs in enumerate(frame_objects):
2519
+ for obj in frame_objs:
2520
+ if str(obj.get('obj_id', 0)) == obj_id_str:
2521
+ bbox = obj.get('bbox', (0, 0, video_width, video_height))
2522
+ x_min, y_min, x_max, y_max = bbox
2523
+ cx = (x_min + x_max) / 2.0
2524
+ cy = (y_min + y_max) / 2.0
2525
+ trajectory.append((cx, cy))
2526
+ # Store for clip trajectory lookup using matched video name
2527
+ all_frame_centroids[(matched_video_name, obj_id_str, frame_idx)] = (cx, cy)
2528
+ break
2529
+
2530
+ if len(trajectory) > 0:
2531
+ traj_name = f"{matched_video_name} obj{obj_id_str}" if obj_id_str != "0" else matched_video_name
2532
+ all_trajectories.append({
2533
+ 'name': traj_name,
2534
+ 'trajectory': trajectory,
2535
+ 'video': matched_video_name,
2536
+ 'object_id': obj_id_str
2537
+ })
2538
+ except Exception as e:
2539
+ logger.debug("Could not process trajectory for video: %s", e)
2540
+ continue
2541
+
2542
+ if not all_trajectories:
2543
+ return None
2544
+
2545
+ # Build clip trajectories for each snippet
2546
+ clip_trajectories = {}
2547
+ snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
2548
+ video_col = 'video_id' if 'video_id' in self.metadata.columns else None
2549
+ object_col = 'object_id' if 'object_id' in self.metadata.columns else None
2550
+
2551
+ if snippet_col is None:
2552
+ return {'all_trajectories': all_trajectories, 'clip_trajectories': {}}
2553
+
2554
+ for idx, row in self.metadata.iterrows():
2555
+ try:
2556
+ snippet_id = row[snippet_col]
2557
+ start_frame = row.get('start_frame')
2558
+ end_frame = row.get('end_frame')
2559
+
2560
+ # Get video name for this snippet (from group or extract from video_id)
2561
+ snippet_video_name = None
2562
+ if 'group' in self.metadata.columns:
2563
+ snippet_video_name = str(row.get('group', '')).strip()
2564
+ elif video_col:
2565
+ import re
2566
+ snippet_video = str(row.get(video_col, ''))
2567
+ base = os.path.splitext(os.path.basename(snippet_video))[0]
2568
+ match = re.match(r'^(.+?)_clip_\d+(?:_obj\d+)?$', base)
2569
+ snippet_video_name = match.group(1) if match else base
2570
+
2571
+ snippet_obj = str(row.get(object_col, '0')) if object_col else '0'
2572
+
2573
+ if pd.notna(start_frame) and pd.notna(end_frame):
2574
+ try:
2575
+ start = int(float(start_frame))
2576
+ end = int(float(end_frame))
2577
+ except (ValueError, TypeError):
2578
+ continue
2579
+ else:
2580
+ clip_idx = row.get('clip_index')
2581
+ if pd.notna(clip_idx):
2582
+ try:
2583
+ start = int(clip_idx) * 16
2584
+ end = start + 15
2585
+ except (ValueError, TypeError):
2586
+ continue
2587
+ else:
2588
+ continue
2589
+
2590
+ # Get trajectory segment for this clip
2591
+ clip_traj = []
2592
+ if snippet_video_name:
2593
+ for f in range(start, end + 1):
2594
+ # Try with object ID first, then without
2595
+ key = (snippet_video_name, snippet_obj, f)
2596
+ if key in all_frame_centroids:
2597
+ clip_traj.append(all_frame_centroids[key])
2598
+ else:
2599
+ key = (snippet_video_name, '0', f)
2600
+ if key in all_frame_centroids:
2601
+ clip_traj.append(all_frame_centroids[key])
2602
+
2603
+ if len(clip_traj) >= 1:
2604
+ clip_trajectories[snippet_id] = clip_traj
2605
+
2606
+ except Exception as e:
2607
+ logger.debug("Could not build clip trajectory for snippet: %s", e)
2608
+ continue
2609
+
2610
+ return {
2611
+ 'all_trajectories': all_trajectories,
2612
+ 'clip_trajectories': clip_trajectories
2613
+ }
2614
+
2615
+ except Exception as e:
2616
+ logger.error("Error creating spatial distribution plot: %s", e, exc_info=True)
2617
+ return None
2618
+
2619
def _save_analysis_state(self):
    """Save full analysis state to a pickle file.

    Asks the user for a destination (defaulting to
    ``<experiment_path>/analysis_results``) and pickles the loaded matrix,
    metadata, processed data, embedding, cluster labels, selected features
    and clip map. Shows a warning when no data is loaded and an error
    dialog when saving fails.
    """
    if self.matrix_data is None:
        QMessageBox.warning(self, "No Data", "No data to save. Please load data first.")
        return

    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        default_name = f"clustering_analysis_{timestamp}.pkl"

        # Determine output directory
        experiment_path = self.config.get("experiment_path")
        initial_dir = ""
        if experiment_path:
            initial_dir = os.path.join(experiment_path, "analysis_results")
            os.makedirs(initial_dir, exist_ok=True)

        path, _ = QFileDialog.getSaveFileName(
            self, "Save Analysis State", os.path.join(initial_dir, default_name), "Pickle Files (*.pkl)"
        )

        if not path:
            return

        # The dialog does not reliably append the filter's suffix on every
        # platform; without it the file is hidden by the "*.pkl" filter of
        # the load dialog.
        if not os.path.splitext(path)[1]:
            path += ".pkl"

        state = {
            'matrix_data': self.matrix_data,
            'metadata': self.metadata,
            'processed_data': self.processed_data,
            'embedding': self.embedding,
            'clusters': self.clusters,
            'selected_features': self.selected_features,
            'snippet_to_clip_map': self.snippet_to_clip_map,
            'metadata_file_path': self.metadata_file_path,
            'timestamp': timestamp,
            'version': '1.0'
        }

        with open(path, 'wb') as f:
            pickle.dump(state, f)

        QMessageBox.information(self, "Success", f"Analysis saved to:\n{path}")

    except Exception as e:
        logger.error("Failed to save analysis: %s", e, exc_info=True)
        QMessageBox.critical(self, "Error", f"Failed to save analysis: {str(e)}")
2664
+
2665
def _load_analysis_state(self):
    """Load full analysis state from a pickle file.

    Asks the user for a ``.pkl`` file, restores the stored attributes,
    rebuilds the snippet-to-clip map when the file carried none, refreshes
    the selectors and re-creates the plots when an embedding and cluster
    labels are present. Any failure is logged and shown in an error dialog.
    """
    try:
        experiment_path = self.config.get("experiment_path")
        initial_dir = ""
        if experiment_path:
            initial_dir = os.path.join(experiment_path, "analysis_results")

        path, _ = QFileDialog.getOpenFileName(
            self, "Load Analysis State", initial_dir, "Pickle Files (*.pkl)"
        )

        if not path:
            return

        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only analysis files from a trusted source should be opened here.
        with open(path, 'rb') as f:
            state = pickle.load(f)

        if not isinstance(state, dict):
            # Reject foreign pickles before any attribute is overwritten.
            raise ValueError(
                f"Unexpected file content: expected a saved analysis state, got {type(state).__name__}"
            )

        # Restore state
        self.matrix_data = state.get('matrix_data')
        self.metadata = state.get('metadata')
        self.processed_data = state.get('processed_data')
        self.embedding = state.get('embedding')
        self.clusters = state.get('clusters')
        self.selected_features = state.get('selected_features')
        # A stored None is normalised to an empty map so lookups keep working.
        self.snippet_to_clip_map = state.get('snippet_to_clip_map') or {}
        self.metadata_file_path = state.get('metadata_file_path')

        # Refresh UI
        self.status_label.setText("Analysis state loaded.")
        self.file_info_label.setText(f"Loaded analysis from: {os.path.basename(path)}")

        # Re-build clip map if missing
        if not self.snippet_to_clip_map:
            self._build_snippet_to_clip_map()

        # Update UI elements
        self._refresh_metadata_columns()
        self._refresh_cluster_list()
        self._refresh_cluster_export_list()
        self._refresh_spatial_selectors()

        # Enable buttons
        has_data = self.matrix_data is not None
        self.run_btn.setEnabled(has_data)

        # Update plot if embedding exists
        if self.embedding is not None and self.clusters is not None:
            # Regenerate default UMAP plot from loaded data
            self._regenerate_default_plot()

            # Trigger plot update via plot type
            self._update_plots_by_metadata()

            # Enable export buttons
            if hasattr(self, 'export_plot_btn'):
                self.export_plot_btn.setEnabled(True)
            if hasattr(self, 'export_csv_btn'):
                self.export_csv_btn.setEnabled(True)

        QMessageBox.information(self, "Success", "Analysis state loaded successfully.")

    except Exception as e:
        logger.error("Failed to load analysis: %s", e, exc_info=True)
        QMessageBox.critical(self, "Error", f"Failed to load analysis: {str(e)}")
2730
+
2731
+ def _regenerate_default_plot(self):
2732
+ """Regenerate the default UMAP plot from loaded embedding and clusters."""
2733
+ if self.embedding is None or self.clusters is None:
2734
+ return
2735
+
2736
+ try:
2737
+ # Create sample index
2738
+ if self.processed_data is not None:
2739
+ sample_index = self.processed_data.index.tolist()
2740
+ else:
2741
+ sample_index = [f"snippet{i}" for i in range(len(self.clusters))]
2742
+
2743
+ df_plot = pd.DataFrame({
2744
+ 'UMAP1': self.embedding[:, 0],
2745
+ 'UMAP2': self.embedding[:, 1],
2746
+ 'Cluster': [f'Cluster_{c}' if c >= 0 else 'Noise' for c in self.clusters],
2747
+ 'Sample': sample_index
2748
+ })
2749
+
2750
+ if self.embedding.shape[1] >= 3:
2751
+ df_plot['UMAP3'] = self.embedding[:, 2]
2752
+ fig = px.scatter_3d(
2753
+ df_plot, x='UMAP1', y='UMAP2', z='UMAP3',
2754
+ color='Cluster', hover_data=['Sample'],
2755
+ title="UMAP Clustering (Loaded)",
2756
+ custom_data=[sample_index]
2757
+ )
2758
+ else:
2759
+ fig = px.scatter(
2760
+ df_plot, x='UMAP1', y='UMAP2',
2761
+ color='Cluster', hover_data=['Sample'],
2762
+ title="UMAP Clustering (Loaded)",
2763
+ custom_data=[sample_index]
2764
+ )
2765
+
2766
+ theme = self._get_plot_theme()
2767
+ fig.update_layout(template=theme if theme else None)
2768
+ point_size = self.plot_point_size_slider.value() if hasattr(self, 'plot_point_size_slider') else 5
2769
+ fig.update_traces(marker=dict(size=point_size))
2770
+ self.current_fig = fig
2771
+ self.plot_widget.update_plot(fig)
2772
+
2773
+ except Exception as e:
2774
+ logger.error("Error regenerating plot: %s", e, exc_info=True)
2775
+
2776
+ def _generate_cluster_colors(self, n_clusters):
2777
+ """Generate colors for clusters."""
2778
+ base_colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12', '#9b59b6', '#1abc9c', '#e67e22', '#34495e']
2779
+ if n_clusters <= len(base_colors):
2780
+ return base_colors[:n_clusters]
2781
+ # Generate more colors if needed
2782
+ import colorsys
2783
+ colors = []
2784
+ for i in range(n_clusters):
2785
+ hue = i / n_clusters
2786
+ rgb = colorsys.hsv_to_rgb(hue, 0.7, 0.9)
2787
+ colors.append(f'#{int(rgb[0]*255):02x}{int(rgb[1]*255):02x}{int(rgb[2]*255):02x}')
2788
+ return colors
2789
+
2790
+ def _generate_pastel_palette(self, n_colors):
2791
+ """Generate pastel color palette."""
2792
+ base_colors = [
2793
+ '#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA500', '#FF1493',
2794
+ '#32CD32', '#9B59B6', '#FF8C00', '#00CED1', '#DA70D6'
2795
+ ]
2796
+ if n_colors <= len(base_colors):
2797
+ return base_colors[:n_colors]
2798
+ # Cycle through colors
2799
+ return [base_colors[i % len(base_colors)] for i in range(n_colors)]
2800
+
2801
+ def _save_metadata_to_file(self, df: pd.DataFrame, path: str):
2802
+ """Save metadata DataFrame to file, respecting the file format based on extension."""
2803
+ if path.endswith(".npz"):
2804
+ # Save as NPZ format
2805
+ np.savez_compressed(
2806
+ path,
2807
+ metadata=df.values,
2808
+ columns=np.array(df.columns, dtype=object),
2809
+ )
2810
+ elif path.endswith(".parquet"):
2811
+ df.to_parquet(path, index=False)
2812
+ else:
2813
+ # Default to CSV
2814
+ df.to_csv(path, index=False)
2815
+
2816
+ def _build_snippet_to_clip_map(self):
2817
+ """Build mapping from snippet IDs to clip file paths."""
2818
+ self.snippet_to_clip_map = {}
2819
+
2820
+ if self.metadata is None:
2821
+ return
2822
+
2823
+ # Get experiment path and registered_clips directory
2824
+ experiment_path = self.config.get("experiment_path")
2825
+ if not experiment_path:
2826
+ return
2827
+
2828
+ registered_clips_dir = os.path.join(experiment_path, "registered_clips")
2829
+ if not os.path.exists(registered_clips_dir):
2830
+ return
2831
+
2832
+ # Get snippet column name
2833
+ snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
2834
+ if snippet_col is None:
2835
+ return
2836
+
2837
+ # Find all clip files recursively (clips are in subdirectories: registered_clips/video_name/clip_XXXXXX.mp4)
2838
+ import glob
2839
+ clip_files = glob.glob(os.path.join(registered_clips_dir, "**", "*.avi"), recursive=True)
2840
+ # Also find legacy .mp4 clips for backwards compatibility
2841
+ clip_files += glob.glob(os.path.join(registered_clips_dir, "**", "*.mp4"), recursive=True)
2842
+ clip_name_to_path = {}
2843
+ for clip_file in clip_files:
2844
+ clip_name = os.path.basename(clip_file)
2845
+ clip_name_to_path[clip_name] = clip_file
2846
+ clip_name_to_path[clip_name.lower()] = clip_file
2847
+
2848
+ # Map snippets to clips using video_id from metadata
2849
+ # The metadata has 'video_id' which contains the clip filename (e.g., 'clip_000000_obj1.mp4')
2850
+ for idx, row in self.metadata.iterrows():
2851
+ snippet_id = str(row.get(snippet_col, ''))
2852
+ if not snippet_id:
2853
+ continue
2854
+
2855
+ # Get video_id (clip filename) from metadata
2856
+ video_id = str(row.get('video_id', ''))
2857
+
2858
+ # Try exact match first
2859
+ if video_id and video_id in clip_name_to_path:
2860
+ self.snippet_to_clip_map[snippet_id] = clip_name_to_path[video_id]
2861
+ continue
2862
+
2863
+ # Try case-insensitive match
2864
+ if video_id and video_id.lower() in clip_name_to_path:
2865
+ self.snippet_to_clip_map[snippet_id] = clip_name_to_path[video_id.lower()]
2866
+ continue
2867
+
2868
+ # Fallback: try to match by clip filename pattern
2869
+ clip_index = row.get('clip_index', None)
2870
+ object_id = str(row.get('object_id', '')) if pd.notna(row.get('object_id')) and row.get('object_id') else None
2871
+
2872
+ if clip_index is not None:
2873
+ clip_idx_str = f"{int(clip_index):06d}"
2874
+ for clip_name, clip_path in clip_name_to_path.items():
2875
+ if clip_name.islower():
2876
+ continue
2877
+ if clip_idx_str in clip_name:
2878
+ if object_id:
2879
+ if f"_obj{object_id}" in clip_name:
2880
+ self.snippet_to_clip_map[snippet_id] = clip_path
2881
+ break
2882
+ else:
2883
+ if "_obj" not in clip_name:
2884
+ self.snippet_to_clip_map[snippet_id] = clip_path
2885
+ break
2886
+
2887
+ # If still not found, try matching by position in metadata
2888
+ if snippet_id not in self.snippet_to_clip_map:
2889
+ try:
2890
+ snippet_num = int(snippet_id.replace('snippet', '')) - 1
2891
+ if 0 <= snippet_num < len(clip_files):
2892
+ sorted_clips = sorted(clip_files)
2893
+ self.snippet_to_clip_map[snippet_id] = sorted_clips[snippet_num]
2894
+ except (ValueError, IndexError):
2895
+ pass
2896
+
2897
+ def _on_umap_point_clicked(self, snippet_id: str):
2898
+ """Handle click on UMAP point - open video popup for corresponding clip."""
2899
+ if not snippet_id:
2900
+ return
2901
+
2902
+ # Build snippet-to-clip mapping if not already done
2903
+ if not self.snippet_to_clip_map:
2904
+ self._build_snippet_to_clip_map()
2905
+
2906
+ # Find clip file
2907
+ clip_path = self.snippet_to_clip_map.get(snippet_id)
2908
+ if not clip_path or not os.path.exists(clip_path):
2909
+ QMessageBox.warning(self, "Clip Not Found",
2910
+ f"Could not find clip file for snippet: {snippet_id}\n\n"
2911
+ f"Please ensure clips are extracted in the Registration tab.")
2912
+ return
2913
+
2914
+ # Get metadata for this snippet
2915
+ clip_metadata = None
2916
+ if self.metadata is not None:
2917
+ snippet_col = 'snippet' if 'snippet' in self.metadata.columns else ('span_id' if 'span_id' in self.metadata.columns else None)
2918
+ if snippet_col:
2919
+ snippet_row = self.metadata[self.metadata[snippet_col].astype(str) == str(snippet_id)]
2920
+ if len(snippet_row) > 0:
2921
+ row = snippet_row.iloc[0]
2922
+ start_frame = row.get('start_frame')
2923
+ end_frame = row.get('end_frame')
2924
+ video_id = row.get('video_id', '')
2925
+
2926
+ # Get FPS from video file
2927
+ fps = None
2928
+ try:
2929
+ import cv2
2930
+ cap = cv2.VideoCapture(clip_path)
2931
+ if cap.isOpened():
2932
+ fps = cap.get(cv2.CAP_PROP_FPS)
2933
+ cap.release()
2934
+ except Exception as e:
2935
+ logger.debug("Could not read clip FPS: %s", e)
2936
+
2937
+ # Only set clip metadata if frame indices are valid numbers
2938
+ try:
2939
+ # Check if values are not NaN/None and not empty strings
2940
+ start_valid = (pd.notna(start_frame) and str(start_frame).strip() != '')
2941
+ end_valid = (pd.notna(end_frame) and str(end_frame).strip() != '')
2942
+
2943
+ if start_valid and end_valid:
2944
+ clip_metadata = {
2945
+ 'start_frame': int(float(start_frame)),
2946
+ 'end_frame': int(float(end_frame)),
2947
+ 'context_frames': 30, # Default context frames
2948
+ 'fps': fps if fps else 30.0
2949
+ }
2950
+ except (ValueError, TypeError):
2951
+ # If conversion fails, skip metadata
2952
+ clip_metadata = None
2953
+
2954
+ # Open video popup
2955
+ self._open_video_popup(clip_path, snippet_id, clip_metadata)
2956
+
2957
+ def _open_video_popup(self, file_path: str, label: str = None, clip_metadata: dict = None):
2958
+ """Open video in popup dialog with timeline indicator (from clustering_behavior)"""
2959
+ from PyQt6.QtWidgets import QDialog, QVBoxLayout, QHBoxLayout, QCheckBox
2960
+ from PyQt6.QtMultimedia import QMediaPlayer, QAudioOutput
2961
+ from PyQt6.QtMultimediaWidgets import QVideoWidget
2962
+ from PyQt6.QtCore import QUrl
2963
+ from .plot_integration import TimelineWidget
2964
+
2965
+ dialog = QDialog(self)
2966
+ dialog.setWindowTitle(f"Video: {label if label else os.path.basename(file_path)}")
2967
+ dialog.setMinimumSize(800, 650)
2968
+
2969
+ layout = QVBoxLayout(dialog)
2970
+ layout.setSpacing(5)
2971
+
2972
+ # Video container with timeline
2973
+ video_container = QWidget()
2974
+ video_layout = QVBoxLayout(video_container)
2975
+ video_layout.setContentsMargins(0, 0, 0, 0)
2976
+ video_layout.setSpacing(0)
2977
+
2978
+ # Expanded video widget
2979
+ expanded_video = QVideoWidget()
2980
+ expanded_video.setMinimumSize(800, 600)
2981
+
2982
+ # Calculate clip boundaries in milliseconds (for clip-only playback)
2983
+ clip_start_ms = None
2984
+ clip_end_ms = None
2985
+ if clip_metadata:
2986
+ start_frame = clip_metadata.get('start_frame')
2987
+ end_frame = clip_metadata.get('end_frame')
2988
+ context_frames = clip_metadata.get('context_frames', 30)
2989
+ fps = clip_metadata.get('fps')
2990
+
2991
+ if start_frame is not None and end_frame is not None and fps:
2992
+ # Calculate clip boundaries in milliseconds
2993
+ # The extracted video starts at context_start = max(0, start_frame - context_frames)
2994
+ # So in the extracted video, the clip starts at (start_frame - context_start)
2995
+ context_start = max(0, start_frame - context_frames)
2996
+ clip_start_in_extracted = start_frame - context_start
2997
+ clip_end_in_extracted = end_frame - context_start
2998
+ clip_start_ms = int((clip_start_in_extracted / fps) * 1000)
2999
+ clip_end_ms = int((clip_end_in_extracted / fps) * 1000)
3000
+
3001
+ # Timeline widget
3002
+ timeline_widget = TimelineWidget(clip_metadata=clip_metadata)
3003
+ timeline_widget.setFixedHeight(30)
3004
+ timeline_widget.setStyleSheet("""
3005
+ QWidget {
3006
+ background-color: #2b2b2b;
3007
+ border-top: 1px solid #555;
3008
+ }
3009
+ """)
3010
+
3011
+ # Expanded player
3012
+ expanded_player = QMediaPlayer()
3013
+ expanded_audio = QAudioOutput()
3014
+ expanded_player.setAudioOutput(expanded_audio)
3015
+ expanded_player.setVideoOutput(expanded_video)
3016
+ expanded_player.setSource(QUrl.fromLocalFile(os.path.abspath(file_path)))
3017
+
3018
+ # Clip-only mode state
3019
+ clip_only_mode = [False]
3020
+
3021
+ # Enhanced position tracking with clip boundary enforcement
3022
+ def update_timeline_position_with_clip(position):
3023
+ try:
3024
+ if timeline_widget and timeline_widget.isVisible():
3025
+ timeline_widget.set_current_position(position, expanded_player.duration())
3026
+
3027
+ # Enforce clip boundaries when clip-only mode is enabled
3028
+ if clip_only_mode[0] and clip_start_ms is not None and clip_end_ms is not None:
3029
+ if position < clip_start_ms:
3030
+ # Jump to clip start if before clip
3031
+ expanded_player.setPosition(clip_start_ms)
3032
+ elif position > clip_end_ms:
3033
+ # If looping is enabled, jump back to clip start
3034
+ if loop_enabled[0]:
3035
+ expanded_player.setPosition(clip_start_ms)
3036
+ else:
3037
+ # Otherwise pause at clip end
3038
+ expanded_player.pause()
3039
+ play_btn.setText("▶ Play")
3040
+ except RuntimeError:
3041
+ # Widget was deleted, ignore
3042
+ pass
3043
+
3044
+ def update_timeline_duration(duration):
3045
+ try:
3046
+ if timeline_widget and timeline_widget.isVisible():
3047
+ timeline_widget.set_duration(duration)
3048
+ except RuntimeError:
3049
+ # Widget was deleted, ignore
3050
+ pass
3051
+
3052
+ expanded_player.positionChanged.connect(update_timeline_position_with_clip)
3053
+ expanded_player.durationChanged.connect(update_timeline_duration)
3054
+
3055
+ # Loop state
3056
+ loop_enabled = [False]
3057
+
3058
+ def on_playback_state_changed(state):
3059
+ if loop_enabled[0] and state == QMediaPlayer.PlaybackState.StoppedState:
3060
+ if clip_only_mode[0] and clip_start_ms is not None:
3061
+ expanded_player.setPosition(clip_start_ms)
3062
+ else:
3063
+ expanded_player.setPosition(0)
3064
+ expanded_player.play()
3065
+
3066
+ def on_media_status_changed(status):
3067
+ if loop_enabled[0] and status == QMediaPlayer.MediaStatus.EndOfMedia:
3068
+ if clip_only_mode[0] and clip_start_ms is not None:
3069
+ expanded_player.setPosition(clip_start_ms)
3070
+ else:
3071
+ expanded_player.setPosition(0)
3072
+ expanded_player.play()
3073
+
3074
+ expanded_player.playbackStateChanged.connect(on_playback_state_changed)
3075
+ expanded_player.mediaStatusChanged.connect(on_media_status_changed)
3076
+
3077
+ video_layout.addWidget(expanded_video)
3078
+ video_layout.addWidget(timeline_widget)
3079
+
3080
+ # Controls
3081
+ controls = QHBoxLayout()
3082
+
3083
+ def toggle_play():
3084
+ if expanded_player.playbackState() == QMediaPlayer.PlaybackState.PlayingState:
3085
+ expanded_player.pause()
3086
+ play_btn.setText("▶ Play")
3087
+ else:
3088
+ expanded_player.play()
3089
+ play_btn.setText("⏸ Pause")
3090
+
3091
+ play_btn = QPushButton("▶ Play")
3092
+ play_btn.clicked.connect(toggle_play)
3093
+
3094
+ def toggle_loop(checked):
3095
+ loop_enabled[0] = checked
3096
+ if checked:
3097
+ loop_btn.setText("🔁 Loop ON")
3098
+ else:
3099
+ loop_btn.setText("🔁 Loop")
3100
+
3101
+ loop_btn = QPushButton("🔁 Loop")
3102
+ loop_btn.setCheckable(True)
3103
+ loop_btn.toggled.connect(toggle_loop)
3104
+
3105
+ # Clip-only mode checkbox (only show if clip metadata is available)
3106
+ clip_only_chk = None
3107
+ if clip_start_ms is not None and clip_end_ms is not None:
3108
+ clip_only_chk = QCheckBox("Clip only")
3109
+ clip_only_chk.setToolTip("When checked, play and loop only the clip area (without context frames)")
3110
+
3111
+ def toggle_clip_only(checked):
3112
+ clip_only_mode[0] = checked
3113
+ if checked:
3114
+ # Jump to clip start when enabling clip-only mode
3115
+ expanded_player.setPosition(clip_start_ms)
3116
+ # Auto-enable loop for better UX in clip-only mode
3117
+ if not loop_enabled[0]:
3118
+ loop_btn.setChecked(True)
3119
+ # If unchecked, allow playback from current position (full video)
3120
+
3121
+ clip_only_chk.toggled.connect(toggle_clip_only)
3122
+
3123
+ # Playback speed controls
3124
+ speed_label = QLabel("Speed:")
3125
+ speed_1x_btn = QPushButton("1x")
3126
+ speed_1x_btn.setCheckable(True)
3127
+ speed_1x_btn.setChecked(True)
3128
+ speed_05x_btn = QPushButton("0.5x")
3129
+ speed_05x_btn.setCheckable(True)
3130
+ speed_025x_btn = QPushButton("0.25x")
3131
+ speed_025x_btn.setCheckable(True)
3132
+ speed_0166x_btn = QPushButton("0.166x")
3133
+ speed_0166x_btn.setCheckable(True)
3134
+
3135
+ # Speed button group
3136
+ speed_buttons = [speed_1x_btn, speed_05x_btn, speed_025x_btn, speed_0166x_btn]
3137
+ current_speed = [1.0] # Use list to allow modification in nested functions
3138
+
3139
def set_speed(speed, button):
    """Apply a playback rate and mark its button as the selected one.

    Args:
        speed: Playback rate handed to the expanded player.
        button: The speed button that should end up checked; every other
            button in ``speed_buttons`` is unchecked.
    """
    current_speed[0] = speed
    expanded_player.setPlaybackRate(speed)

    # Emulate an exclusive button group: exactly one button stays checked.
    for candidate in speed_buttons:
        is_selected = candidate == button
        candidate.setChecked(is_selected)
3145
+
3146
+ speed_1x_btn.clicked.connect(lambda: set_speed(1.0, speed_1x_btn))
3147
+ speed_05x_btn.clicked.connect(lambda: set_speed(0.5, speed_05x_btn))
3148
+ speed_025x_btn.clicked.connect(lambda: set_speed(0.25, speed_025x_btn))
3149
+ speed_0166x_btn.clicked.connect(lambda: set_speed(1.0/6.0, speed_0166x_btn))
3150
+
3151
+ close_btn = QPushButton("Close")
3152
+ close_btn.clicked.connect(dialog.close)
3153
+
3154
+ controls.addWidget(play_btn)
3155
+ controls.addWidget(loop_btn)
3156
+ if clip_only_chk is not None:
3157
+ controls.addWidget(clip_only_chk)
3158
+ controls.addWidget(speed_label)
3159
+ controls.addWidget(speed_1x_btn)
3160
+ controls.addWidget(speed_05x_btn)
3161
+ controls.addWidget(speed_025x_btn)
3162
+ controls.addWidget(speed_0166x_btn)
3163
+ controls.addStretch()
3164
+ controls.addWidget(close_btn)
3165
+
3166
+ layout.addWidget(video_container)
3167
+ layout.addLayout(controls)
3168
+
3169
+ # Auto-play when opened (if clip-only mode is enabled, start at clip start)
3170
+ if clip_only_mode[0] and clip_start_ms is not None:
3171
+ expanded_player.setPosition(clip_start_ms)
3172
+ expanded_player.play()
3173
+ play_btn.setText("⏸ Pause")
3174
+
3175
+ dialog.exec()
3176
+
3177
+ # Cleanup - disconnect signals before stopping to prevent errors
3178
+ try:
3179
+ expanded_player.positionChanged.disconnect()
3180
+ expanded_player.durationChanged.disconnect()
3181
+ expanded_player.playbackStateChanged.disconnect()
3182
+ expanded_player.mediaStatusChanged.disconnect()
3183
+ except (RuntimeError, TypeError):
3184
+ # Signals already disconnected or widget deleted
3185
+ pass
3186
+
3187
+ expanded_player.stop()