singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,2291 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import numpy as np
5
+ import cv2
6
+ import yaml
7
+
8
+ logger = logging.getLogger(__name__)
9
+ from PyQt6.QtWidgets import (
10
+ QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QComboBox,
11
+ QGroupBox, QTableWidget, QTableWidgetItem, QHeaderView, QDialog,
12
+ QFormLayout, QLineEdit, QMessageBox, QSplitter, QCheckBox, QFileDialog,
13
+ QSizePolicy, QScrollArea, QTabWidget, QListWidget, QListWidgetItem,
14
+ QTextEdit, QAbstractItemView, QSlider, QToolButton, QInputDialog,
15
+ QButtonGroup
16
+ )
17
+ from PyQt6.QtCore import Qt, QUrl, QPointF, QRectF
18
+ from PyQt6.QtGui import QPainter, QPen, QBrush, QColor, QFont, QPolygonF, QPainterPath
19
+
20
+ try:
21
+ from PyQt6.QtWebEngineWidgets import QWebEngineView
22
+ HAS_WEBENGINE = True
23
+ except ImportError:
24
+ HAS_WEBENGINE = False
25
+
26
+ try:
27
+ import plotly.graph_objects as go
28
+ import plotly.express as px
29
+ import pandas as pd
30
+ HAS_PLOTLY = True
31
+ except ImportError:
32
+ HAS_PLOTLY = False
33
+
34
+ try:
35
+ import scipy.stats as stats
36
+ HAS_SCIPY = True
37
+ except ImportError:
38
+ HAS_SCIPY = False
39
+
40
+ class AnalysisWidget(QWidget):
41
+ """Widget for downstream analysis of behavior data."""
42
+
43
def __init__(self, config: dict):
    """Create the analysis widget and build its UI.

    Args:
        config: Application config dict, kept by reference. Other methods read
            keys such as ``experiment_path`` and the saved analysis groups.
    """
    super().__init__()
    self.config = config

    # Raw inference results and the flattened bout table derived from them.
    self.results = {}
    self.merged_data = []  # list of dicts

    # Group membership, stored in both directions.
    self.groups = {}        # {group_name: [video_path, ...]}
    self.video_groups = {}  # {video_path: group_name}

    # User-drawn regions: [{name, type, vertices}, ...]
    self.spatial_regions = []

    # Current filter state shared by plots, statistics and export.
    self.all_behaviors = []
    self.selected_behaviors = set()
    self.visible_groups = set()

    # Must run last: the UI builders rely on the attributes above existing.
    self._setup_ui()
55
+
56
def _setup_ui(self):
    """Assemble the widget: a "Data controls" button row above the analysis tabs."""
    outer = QVBoxLayout()

    # ---- Data controls: load results, manage groups, filter behaviors ----
    data_box = QGroupBox("Data controls")
    data_box.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Maximum)
    button_row = QHBoxLayout()

    def add_button(label, handler, enabled=True):
        """Create a push button wired to *handler* and append it to the row."""
        button = QPushButton(label)
        button.clicked.connect(handler)
        if not enabled:
            # Starts disabled; it is enabled elsewhere once data is loaded.
            button.setEnabled(False)
        button_row.addWidget(button)
        return button

    self.load_btn = add_button("Load results", self._load_results)
    self.manage_groups_btn = add_button("Manage groups", self._manage_groups, enabled=False)
    self.filter_behaviors_btn = add_button("Filter behaviors", self._filter_behaviors, enabled=False)

    button_row.addStretch()
    data_box.setLayout(button_row)
    outer.addWidget(data_box)

    # ---- Analysis tabs ----
    self.tabs = QTabWidget()

    self.overview_tab = self._create_overview_tab()
    self.tabs.addTab(self.overview_tab, "Overview")

    self.transitions_tab = self._create_transitions_tab()
    self.tabs.addTab(self.transitions_tab, "Behavior Transitions")

    outer.addWidget(self.tabs)
    self.setLayout(outer)
95
+
96
def _create_overview_tab(self):
    """Build and return the Overview tab.

    The tab is split horizontally:
    - left (stretch 70): the plot area;
    - right (stretch 30): a scrollable sidebar with plot settings, graph
      appearance, spatial-distribution controls, group checkboxes,
      statistics and export buttons.

    Widgets that other methods use are stored as attributes on ``self``.
    """
    page = QWidget()
    columns = QHBoxLayout()

    # ---- Left: plot area ----
    if not HAS_WEBENGINE:
        # Nothing to embed without QtWebEngine; show a notice instead.
        notice = QLabel("PyQt6.QtWebEngineWidgets not installed. Plots will open in default browser.")
        notice.setAlignment(Qt.AlignmentFlag.AlignCenter)
        columns.addWidget(notice, 70)
    else:
        self.webview = QWebEngineView()
        self.webview.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        columns.addWidget(self.webview, 70)

    # ---- Right: sidebar sections ----
    panel_layout = QVBoxLayout()
    panel_layout.setSpacing(15)

    def add_section(title, inner_layout):
        """Wrap *inner_layout* in a titled group box appended to the sidebar."""
        section = QGroupBox(title)
        section.setLayout(inner_layout)
        panel_layout.addWidget(section)
        return section

    # Plot settings: metric and analysis mode. Items are added before the
    # signals are connected so that populating the combos triggers no redraw.
    plot_form = QFormLayout()
    self.metric_combo = QComboBox()
    self.metric_combo.addItems(["Occurrences (Count)", "Average Bout Duration (s)", "Total Duration (s)", "Percent Time (%)"])
    self.metric_combo.currentIndexChanged.connect(self._update_plots)
    plot_form.addRow("Metric:", self.metric_combo)

    self.plot_mode_combo = QComboBox()
    self.plot_mode_combo.addItems(["General Overview", "Group Comparison"])
    self.plot_mode_combo.currentIndexChanged.connect(self._update_plots)
    plot_form.addRow("Analysis Mode:", self.plot_mode_combo)
    add_section("Plot settings", plot_form)

    # Graph appearance: chart type and colour theme.
    look_form = QFormLayout()
    self.graph_type_combo = QComboBox()
    self.graph_type_combo.addItems(["Auto", "Bar Chart", "Box Plot", "Violin Plot", "Strip Plot", "Line Plot"])
    self.graph_type_combo.currentIndexChanged.connect(self._update_plots)
    look_form.addRow("Graph Type:", self.graph_type_combo)

    self.color_theme_combo = QComboBox()
    self.color_theme_combo.addItems(["plotly", "plotly_white", "plotly_dark", "ggplot2", "seaborn", "simple_white", "none"])
    self.color_theme_combo.setCurrentText("simple_white")  # default chosen before connecting
    self.color_theme_combo.currentIndexChanged.connect(self._update_plots)
    look_form.addRow("Color Theme:", self.color_theme_combo)
    add_section("Graph appearance", look_form)

    # Spatial distribution (driven by localization data).
    spatial_form = QFormLayout()
    self.spatial_behavior_combo = QComboBox()
    self.spatial_behavior_combo.addItem("All behaviors")
    self.spatial_behavior_combo.currentIndexChanged.connect(self._update_spatial_plot)
    spatial_form.addRow("Behavior:", self.spatial_behavior_combo)

    self.spatial_video_combo = QComboBox()
    self.spatial_video_combo.addItem("All")
    self.spatial_video_combo.currentIndexChanged.connect(self._update_spatial_plot)
    spatial_form.addRow("Video:", self.spatial_video_combo)

    self.spatial_dot_size_slider = QSlider(Qt.Orientation.Horizontal)
    self.spatial_dot_size_slider.setRange(1, 25)
    self.spatial_dot_size_slider.setValue(7)
    self.spatial_dot_size_slider.valueChanged.connect(self._update_spatial_plot)
    spatial_form.addRow("Dot size:", self.spatial_dot_size_slider)

    spatial_layout = QVBoxLayout()
    spatial_layout.addLayout(spatial_form)

    self.spatial_show_btn = QPushButton("Show spatial distribution")
    self.spatial_show_btn.clicked.connect(self._update_spatial_plot)
    self.spatial_show_btn.setEnabled(False)
    spatial_layout.addWidget(self.spatial_show_btn)

    self.spatial_save_btn = QPushButton("Save spatial plot (PDF/SVG)")
    self.spatial_save_btn.clicked.connect(self._save_spatial_plot)
    spatial_layout.addWidget(self.spatial_save_btn)

    self.spatial_info_label = QLabel("Load results with localization data to use.")
    self.spatial_info_label.setWordWrap(True)
    self.spatial_info_label.setStyleSheet("color: grey; font-size: 11px;")
    spatial_layout.addWidget(self.spatial_info_label)

    self.manage_regions_btn = QPushButton("Manage spatial regions")
    self.manage_regions_btn.setToolTip("Draw named regions on the spatial map, then filter analysis by region.")
    self.manage_regions_btn.clicked.connect(self._manage_spatial_regions)
    self.manage_regions_btn.setEnabled(False)
    spatial_layout.addWidget(self.manage_regions_btn)

    # The region filter becomes the last row of the form at the top of the section.
    self.region_filter_combo = QComboBox()
    self.region_filter_combo.addItem("All Regions")
    self.region_filter_combo.currentIndexChanged.connect(self._on_region_filter_changed)
    spatial_form.addRow("Region:", self.region_filter_combo)

    spatial_section = add_section("Spatial distribution", spatial_layout)
    spatial_section.setToolTip(
        "One point per inference clip (step-frame grid). Not based on aggregated bout boundaries."
    )

    # Group visibility: rows are toggled via checkboxes only.
    groups_layout = QVBoxLayout()
    self.group_list_widget = QListWidget()
    self.group_list_widget.setSelectionMode(QAbstractItemView.SelectionMode.NoSelection)
    self.group_list_widget.setFixedHeight(150)
    self.group_list_widget.itemChanged.connect(self._on_group_selection_changed)
    groups_layout.addWidget(self.group_list_widget)
    add_section("Groups", groups_layout)

    # Statistics.
    stats_layout = QVBoxLayout()
    self.stats_test_combo = QComboBox()
    self.stats_test_combo.addItems(["T-Test / Mann-Whitney (2 groups)", "ANOVA / Kruskal-Wallis (>2 groups)"])
    stats_layout.addWidget(QLabel("Test type:"))
    stats_layout.addWidget(self.stats_test_combo)

    self.run_stats_btn = QPushButton("Run statistics")
    self.run_stats_btn.clicked.connect(self._run_statistics)
    stats_layout.addWidget(self.run_stats_btn)

    self.stats_output = QTextEdit()
    self.stats_output.setReadOnly(True)
    self.stats_output.setPlaceholderText("Results will appear here...")
    self.stats_output.setMaximumHeight(150)
    stats_layout.addWidget(self.stats_output)
    add_section("Statistics (Group comp.)", stats_layout)

    # Export.
    export_layout = QVBoxLayout()
    self.save_graph_btn = QPushButton("Save graph")
    self.save_graph_btn.clicked.connect(self._save_graph)
    export_layout.addWidget(self.save_graph_btn)

    self.save_csv_btn = QPushButton("Save data (.csv)")
    self.save_csv_btn.clicked.connect(self._save_csv)
    export_layout.addWidget(self.save_csv_btn)
    add_section("Export", export_layout)

    panel_layout.addStretch()

    # Put the sidebar into a resizable scroll area in the right column.
    panel = QWidget()
    panel.setLayout(panel_layout)
    scroller = QScrollArea()
    scroller.setWidgetResizable(True)
    scroller.setWidget(panel)
    columns.addWidget(scroller, 30)

    page.setLayout(columns)
    return page
265
+
266
def _update_sidebar_groups(self):
    """Rebuild the sidebar's checkable group list and mark every group visible.

    Group names come from ``self.groups``. If that is empty, the distinct
    non-missing ``"Group"`` values found in ``self.merged_data`` are used
    instead. If there are none either, a single "Unassigned" entry is shown.

    Signals are blocked while the list is rebuilt, so
    ``_on_group_selection_changed`` does not fire once per item. The
    ``finally`` guarantees they are re-enabled even if the rebuild raises;
    otherwise the list would silently stop reacting to checkbox toggles.
    """
    import math

    groups = sorted(self.groups)
    if not groups and self.merged_data:
        # Fallback when the groups dict is empty but bout data exists. Same
        # result as DataFrame(...)["Group"].dropna().unique(), without
        # materializing a DataFrame of every bout.
        found = set()
        for row in self.merged_data:
            grp = row.get("Group")
            if grp is None or (isinstance(grp, float) and math.isnan(grp)):
                continue  # missing value
            found.add(grp)
        groups = sorted(found)

    if not groups:
        groups = ["Unassigned"]

    widget = self.group_list_widget
    widget.blockSignals(True)
    try:
        widget.clear()
        for grp in groups:
            item = QListWidgetItem(grp)
            item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable)
            item.setCheckState(Qt.CheckState.Checked)
            widget.addItem(item)
        self.visible_groups = set(groups)
    finally:
        widget.blockSignals(False)
290
+
291
def _on_group_selection_changed(self, item):
    """Slot for ``itemChanged`` on the group list.

    The toggled *item* itself is not used. The visible-group set is rebuilt
    from the check state of every row, then the plots are redrawn.
    """
    lw = self.group_list_widget
    rows = (lw.item(i) for i in range(lw.count()))
    self.visible_groups = {
        row.text() for row in rows if row.checkState() == Qt.CheckState.Checked
    }
    self._update_plots()
299
+
300
def _run_statistics(self):
    """Test each behavior for group differences in the selected metric.

    Processing steps:
    1. Filter the bout table in ``self.merged_data`` by the selected
       behaviors, visible groups and region (the same filters as the CSV
       export).
    2. Reduce it to one value per (video, behavior).
    3. Compare those values across groups. In "2 groups" mode this runs
       pairwise two-sided Mann-Whitney U tests. Otherwise it runs one
       Kruskal-Wallis H test per behavior.

    The report is written to ``self.stats_output``.
    """
    import itertools

    if not HAS_SCIPY:
        QMessageBox.warning(self, "Error", "Scipy is not installed. Cannot run statistics.")
        return

    if not self.merged_data:
        return

    df = pd.DataFrame(self.merged_data)

    # Apply filters
    if self.selected_behaviors:
        df = df[df["Behavior"].isin(self.selected_behaviors)]
    if self.visible_groups:
        df = df[df["Group"].isin(self.visible_groups)]
    sel_region = self.region_filter_combo.currentText()
    if sel_region != "All Regions" and "Region" in df.columns:
        df = df[df["Region"] == sel_region]

    if df.empty:
        self.stats_output.setText("No data available for selected filters.")
        return

    metric = self.metric_combo.currentText()
    test_type = self.stats_test_combo.currentText()
    keys = ["Video", "Group", "Behavior"]

    if "Average" in metric:
        # A video without bouts of a behavior has no mean bout duration, so it
        # is correctly absent from that behavior's sample.
        agg = df.groupby(keys)["Duration"].mean().reset_index(name="Value")
    else:
        if "Occurrences" in metric:
            agg = df.groupby(keys).size().reset_index(name="Value")
        else:  # "Total Duration" or "Percent Time"
            agg = df.groupby(keys)["Duration"].sum().reset_index(name="Value")
        # groupby only emits observed (video, behavior) pairs. A video that
        # never shows a behavior has a true count/duration of 0. Dropping it
        # would leave each group's sample with only the videos where the
        # behavior occurred, which biases the test. Fill every pair of
        # (video still present after filtering) x (behavior) with 0.
        videos = df[["Video", "Group"]].dropna().drop_duplicates()
        behaviors = pd.DataFrame({"Behavior": df["Behavior"].dropna().unique()})
        agg = videos.merge(behaviors, how="cross").merge(agg, on=keys, how="left")
        agg["Value"] = agg["Value"].fillna(0)
        if "Percent" in metric:
            # Share of the video's total (filtered) bout time. A zero total
            # implies a zero numerator, so divide by 1 to yield 0 instead of
            # raising or producing NaN.
            totals = agg["Video"].map(df.groupby("Video")["Duration"].sum())
            agg["Value"] = agg["Value"] / totals.where(totals > 0, 1) * 100

    output = [f"Metric: {metric}", f"Test: {test_type}\n"]

    unique_behaviors = sorted(agg["Behavior"].unique())
    unique_groups = sorted(agg["Group"].unique())

    if len(unique_groups) < 2:
        self.stats_output.setText("Need at least 2 groups for comparison.")
        return

    for beh in unique_behaviors:
        beh_data = agg[agg["Behavior"] == beh]

        # {group: per-video values}, in sorted group order; empty groups dropped.
        samples = {}
        for grp in unique_groups:
            values = beh_data.loc[beh_data["Group"] == grp, "Value"].values
            if len(values) > 0:
                samples[grp] = values

        if len(samples) < 2:
            output.append(f"{beh}: Not enough data")
            continue

        try:
            if "2 groups" in test_type:
                # Pairwise tests. p-values are not corrected for multiple
                # comparisons. alternative is explicit because older SciPy
                # releases used a different default.
                for g1, g2 in itertools.combinations(samples, 2):
                    _, p = stats.mannwhitneyu(samples[g1], samples[g2], alternative="two-sided")
                    output.append(f"{beh} ({g1} vs {g2}): p={p:.4f}")
            else:
                _, p = stats.kruskal(*samples.values())
                output.append(f"{beh} (Kruskal-Wallis): p={p:.4f}")
        except Exception as e:
            # e.g. kruskal raises when every value is identical
            output.append(f"{beh}: Error ({str(e)})")

    self.stats_output.setText("\n".join(output))
384
+
385
def _save_graph(self):
    """Save the most recently rendered figure (``self.last_fig``) to disk.

    ``.html`` files are written with ``write_html``. Every other extension
    goes through ``write_image``, which needs the optional ``kaleido``
    package. If the chosen name has no extension, the extension of the
    selected file filter is appended, so the output format matches what the
    user picked in the dialog.
    """
    if not hasattr(self, 'last_fig') or self.last_fig is None:
        QMessageBox.warning(self, "Error", "No plot to save.")
        return

    path, filter_ = QFileDialog.getSaveFileName(
        self, "Save Graph",
        self.config.get("experiment_path", ""),
        "PDF Files (*.pdf);;SVG Files (*.svg);;PNG Files (*.png);;HTML Files (*.html)"
    )

    if not path:
        return

    path = self._with_filter_extension(path, filter_)

    try:
        if path.lower().endswith(".html"):
            self.last_fig.write_html(path)
        else:
            self.last_fig.write_image(path)
        QMessageBox.information(self, "Success", f"Graph saved to {path}")
    except Exception as e:
        if "kaleido" in str(e).lower() or "executable" in str(e).lower():
            QMessageBox.warning(self, "Error", "Saving as static image requires 'kaleido'.\nPlease install it (pip install kaleido) or save as HTML.")
        else:
            QMessageBox.critical(self, "Error", f"Failed to save graph: {e}")

@staticmethod
def _with_filter_extension(path, selected_filter):
    """Return *path*, adding the filter's ``*.ext`` suffix if *path* has no extension.

    Args:
        path: File name returned by the save dialog.
        selected_filter: The dialog's selected filter string, e.g.
            ``"PDF Files (*.pdf)"``. May be empty or None.

    Returns:
        *path* unchanged if it already has an extension or the filter names
        none. Otherwise *path* with the first extension from the filter.
    """
    import re

    if os.path.splitext(path)[1]:
        return path
    match = re.search(r"\*(\.\w+)", selected_filter or "")
    return path + match.group(1) if match else path
410
+
411
def _save_csv(self):
    """Export the currently filtered bout table to a CSV file.

    Applies the same behavior, group and region filters as the statistics
    panel. A ``.csv`` suffix is added when the chosen name has no extension.

    Write errors are reported in a dialog and logged rather than escaping the
    slot. By default, PyQt6 aborts the application on an unhandled exception
    raised inside a slot.
    """
    if not self.merged_data:
        return

    path, _ = QFileDialog.getSaveFileName(
        self, "Save Data",
        self.config.get("experiment_path", ""),
        "CSV Files (*.csv)"
    )

    if not path:
        return
    # Non-native dialogs do not always append the filter's suffix.
    if not os.path.splitext(path)[1]:
        path += ".csv"

    df = pd.DataFrame(self.merged_data)
    if self.selected_behaviors:
        df = df[df["Behavior"].isin(self.selected_behaviors)]
    if self.visible_groups:
        df = df[df["Group"].isin(self.visible_groups)]
    sel_region = self.region_filter_combo.currentText()
    if sel_region != "All Regions" and "Region" in df.columns:
        df = df[df["Region"] == sel_region]

    try:
        df.to_csv(path, index=False)
    except Exception as e:  # UI boundary: report instead of crashing the app
        logger.exception("Failed to save CSV to %s", path)
        QMessageBox.critical(self, "Error", f"Failed to save data: {e}")
        return
    QMessageBox.information(self, "Success", f"Data saved to {path}")
435
+
436
def _create_transitions_tab(self):
    """Build and return the Behavior Transitions tab (Markov-chain analysis).

    Layout, top to bottom:
    - a row of analysis controls;
    - the transition-matrix table;
    - a web view for the transition graph, or a notice when QtWebEngine is
      unavailable.
    """
    page = QWidget()
    stack = QVBoxLayout()

    # ---- Controls row ----
    controls_box = QGroupBox("Transition analysis controls")
    controls_box.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Maximum)
    row = QHBoxLayout()

    self.transition_type_combo = QComboBox()
    self.transition_type_combo.addItems(["Individual Video", "Group Comparison"])
    # Connected after populating so filling the combo does not fire the slot.
    self.transition_type_combo.currentIndexChanged.connect(self._on_transition_type_changed)

    self.transition_select_combo = QComboBox()

    self.layout_combo = QComboBox()
    self.layout_combo.addItems(["Circular Layout", "Network Layout"])

    self.sig_filter_check = QCheckBox("Significant only")
    self.sig_filter_check.setToolTip("Only show transitions that occur significantly more than expected by chance (Z-score > 1.96)")

    self.compute_transition_btn = QPushButton("Compute transitions")
    self.compute_transition_btn.clicked.connect(self._compute_transitions)
    self.compute_transition_btn.setEnabled(False)

    # Left-to-right order of the controls in the row.
    for widget in (
        QLabel("Analysis type:"), self.transition_type_combo,
        QLabel("Select:"), self.transition_select_combo,
        QLabel("Layout:"), self.layout_combo,
        self.sig_filter_check,
        self.compute_transition_btn,
    ):
        row.addWidget(widget)
    row.addStretch()
    controls_box.setLayout(row)
    stack.addWidget(controls_box)

    # ---- Transition matrix ----
    matrix_box = QGroupBox("Transition matrix")
    matrix_layout = QVBoxLayout()
    self.transition_matrix_table = QTableWidget()
    self.transition_matrix_table.setMaximumHeight(250)
    matrix_layout.addWidget(self.transition_matrix_table)
    matrix_box.setLayout(matrix_layout)
    stack.addWidget(matrix_box)

    # ---- Transition graph ----
    if not HAS_WEBENGINE:
        stack.addWidget(QLabel("PyQt6.QtWebEngineWidgets not installed."))
    else:
        self.transition_webview = QWebEngineView()
        self.transition_webview.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
        stack.addWidget(self.transition_webview, 1)

    page.setLayout(stack)
    return page
494
+
495
def update_config(self, config: dict):
    """Adopt a new config: reload saved groups, then try to auto-load results."""
    self.config = config
    # Groups first, so results loaded afterwards see the configured groups.
    for step in (self._load_groups_from_config, self._auto_load_results):
        step()
500
+
501
def _load_groups_from_config(self):
    """Read saved group assignments from ``self.config``.

    Two config keys are read:
    - ``analysis_groups``: group name -> list of video paths;
    - ``analysis_video_groups``: video path -> group name.

    Each value is copied only when it is a dict. A missing key counts as an
    empty dict, and any other type leaves the current attribute untouched.
    Names and paths are coerced to ``str``, and a None path list becomes
    ``[]``.
    """
    saved_groups = self.config.get("analysis_groups", {})
    saved_video_groups = self.config.get("analysis_video_groups", {})

    if isinstance(saved_groups, dict):
        rebuilt = {}
        for name, members in saved_groups.items():
            rebuilt[str(name)] = list(map(str, members or []))
        self.groups = rebuilt

    if isinstance(saved_video_groups, dict):
        self.video_groups = {str(video): str(grp) for video, grp in saved_video_groups.items()}
514
+
515
def _save_groups_to_config(self):
    """Store group assignments in ``self.config`` and persist it as YAML.

    Two keys are written into the config:
    - ``analysis_groups``: each group name -> sorted, de-duplicated list of
      video paths;
    - ``analysis_video_groups``: each video path -> its group.

    Names and paths are coerced to ``str``. The config is only written to
    disk when ``config_path`` is set; write errors are logged, not raised.
    """
    self.config["analysis_groups"] = {
        # Coerce before sorting so mixed str/Path entries cannot make sorted() fail.
        str(group_name): sorted({str(path) for path in paths or []})
        for group_name, paths in self.groups.items()
    }
    self.config["analysis_video_groups"] = {
        str(video_path): str(group_name)
        for video_path, group_name in self.video_groups.items()
    }
    config_path = self.config.get("config_path")
    if not config_path:
        return
    try:
        # Serialize before opening the file. open(..., "w") truncates it, so a
        # dump failing half-way (e.g. on an unrepresentable value) would
        # otherwise leave a corrupted config on disk.
        text = yaml.safe_dump(dict(self.config), sort_keys=False)
        config_dir = os.path.dirname(config_path)
        if config_dir:  # os.makedirs("") raises for a bare file name
            os.makedirs(config_dir, exist_ok=True)
        with open(config_path, "w", encoding="utf-8") as f:
            f.write(text)
    except Exception as e:
        logger.error("Error saving group config: %s", e)
533
+
534
def _auto_load_results(self):
    """Load ``<experiment_path>/results/inference_results.json`` if present.

    Does nothing when no experiment path is configured or the file does not
    exist.
    """
    experiment = self.config.get("experiment_path")
    if not experiment:
        return
    candidate = os.path.join(experiment, "results", "inference_results.json")
    if os.path.exists(candidate):
        self._load_results_from_file(candidate)
540
+
541
def _load_results(self):
    """Ask the user for an inference-results JSON file and load it."""
    chosen, _selected_filter = QFileDialog.getOpenFileName(
        self,
        "Load Inference Results",
        self.config.get("experiment_path", ""),
        "JSON Files (*.json)",
    )
    if not chosen:
        return  # dialog cancelled
    self._load_results_from_file(chosen)
549
+
550
def _load_results_from_file(self, path):
    """Load inference results from *path* and refresh the analysis views.

    When the file name contains ``inference_results.json``, a sibling file
    with that part replaced by ``analysis_metadata.json`` is also read, if it
    exists:
    - its ``groups`` and ``video_groups`` are used only when none are loaded
      yet;
    - its ``spatial_regions`` replace the current regions.

    Any failure is shown in an error dialog.
    """
    try:
        # JSON interchange is UTF-8; don't depend on the platform default encoding.
        with open(path, "r", encoding="utf-8") as f:
            self.results = json.load(f)

        self.manage_groups_btn.setEnabled(True)

        # Companion metadata. Only the file name is rewritten, never the
        # directory part. A results file that does not follow the naming
        # convention must not be re-read as its own metadata.
        folder, file_name = os.path.split(path)
        meta_name = file_name.replace("inference_results.json", "analysis_metadata.json")
        meta_path = os.path.join(folder, meta_name)
        if meta_name != file_name and os.path.exists(meta_path):
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
            # Groups already loaded (e.g. from the config) take precedence.
            if not self.groups:
                self.groups = meta.get("groups", {})
            if not self.video_groups:
                self.video_groups = meta.get("video_groups", {})
            # Regions always come from the metadata; JSON lists become tuples.
            self.spatial_regions = [
                {
                    "name": r["name"],
                    "type": r["type"],
                    "vertices": [tuple(v) for v in r["vertices"]],
                }
                for r in meta.get("spatial_regions", [])
            ]

        self._update_region_filter_combo()
        self.manage_regions_btn.setEnabled(True)
        self._process_data()
        self._update_plots()

    except Exception as e:
        QMessageBox.critical(self, "Error", f"Failed to load results: {e}")
582
+
583
@staticmethod
def _point_in_region(cx, cy, region):
    """Return True if the point (cx, cy) lies inside *region*.

    Args:
        cx, cy: Point coordinates, in the same space as the region vertices.
        region: Dict with ``"type"`` and ``"vertices"``.

    Returns:
        For ``"rect"`` regions, whether the point is inside the rectangle
        whose two opposite corners are ``vertices[0]`` and ``vertices[1]``.
        The corners may come in any order and edges count as inside. Any
        other type is treated as a polygon over ``vertices`` and tested with
        ``matplotlib.path.Path.contains_point``.
    """
    verts = region["vertices"]
    if region["type"] == "rect":
        # Normalize the corners: a rectangle stored from bottom-right to
        # top-left (or any mixed order) must still contain its interior.
        x0, y0 = verts[0][0], verts[0][1]
        x1, y1 = verts[1][0], verts[1][1]
        return min(x0, x1) <= cx <= max(x0, x1) and min(y0, y1) <= cy <= max(y0, y1)
    from matplotlib.path import Path as MplPath  # lazy: only polygons need it
    return MplPath(verts).contains_point((cx, cy))
590
+
591
def _region_for_point(self, cx, cy):
    """Return the name of the first spatial region containing (cx, cy).

    Regions are checked in ``self.spatial_regions`` order. Returns
    ``"Outside"`` when none contains the point.
    """
    return next(
        (
            region["name"]
            for region in self.spatial_regions
            if self._point_in_region(cx, cy, region)
        ),
        "Outside",
    )
597
+
598
def _build_clip_centroids(self, v_data):
    """Map each entry of ``v_data["localization_bboxes"]`` to a centroid.

    Returns one item per bbox, in order. Each item is ``(cx, cy)``, or None
    when ``self._bbox_centroid`` reports no x coordinate for that clip.
    A missing ``localization_bboxes`` key yields an empty list.
    """
    def as_point(raw_bbox):
        x, y = self._bbox_centroid(raw_bbox)
        return None if x is None else (x, y)

    return [as_point(b) for b in v_data.get("localization_bboxes", [])]
606
+
607
+ def _process_data(self):
608
+ """Merge timelines and prepare data structures."""
609
+ self.merged_data = []
610
+ has_regions = bool(self.spatial_regions)
611
+
612
+ data_container = self.results
613
+ if "results" in data_container:
614
+ results_dict = data_container["results"]
615
+ classes = data_container.get("classes", [])
616
+ params = data_container.get("parameters", {})
617
+ target_fps = params.get("target_fps", 30)
618
+ clip_length = params.get("clip_length", 16)
619
+ step_frames = params.get("step_frames", 16)
620
+ else:
621
+ results_dict = data_container
622
+ classes = []
623
+ params = {}
624
+ target_fps = 30
625
+ clip_length = 16
626
+ step_frames = 16
627
+
628
+ frame_agg_enabled = params.get("frame_aggregation_enabled", False)
629
+ use_ovr_param = params.get("use_ovr", None)
630
+
631
+ for video_path, v_data in results_dict.items():
632
+ agg_segments = v_data.get("aggregated_segments", [])
633
+ agg_multiclass = v_data.get("aggregated_multiclass_segments", [])
634
+ use_ovr = bool(use_ovr_param) if use_ovr_param is not None else bool(agg_multiclass)
635
+ clip_starts = v_data.get("clip_starts", [])
636
+ total_frames = int(v_data.get("total_frames", 0) or 0)
637
+
638
+ # Determine orig_fps: prefer saved per-video metadata.
639
+ orig_fps = v_data.get("orig_fps", 0)
640
+ if orig_fps <= 0:
641
+ if os.path.exists(video_path):
642
+ try:
643
+ cap = cv2.VideoCapture(video_path)
644
+ orig_fps = cap.get(cv2.CAP_PROP_FPS)
645
+ cap.release()
646
+ except Exception:
647
+ orig_fps = 0
648
+ if orig_fps <= 0 and len(clip_starts) >= 2 and step_frames > 0:
649
+ actual_step = clip_starts[1] - clip_starts[0]
650
+ frame_interval = actual_step / step_frames
651
+ orig_fps = frame_interval * target_fps
652
+ if orig_fps <= 0:
653
+ orig_fps = 30.0
654
+
655
+ # Prefer stored inference-time interval when available.
656
+ frame_interval = v_data.get("frame_interval", 0)
657
+ try:
658
+ frame_interval = int(frame_interval)
659
+ except Exception:
660
+ frame_interval = 0
661
+ if frame_interval <= 0:
662
+ if len(clip_starts) >= 2 and step_frames > 0:
663
+ inferred = int(round((clip_starts[1] - clip_starts[0]) / step_frames))
664
+ frame_interval = max(1, inferred)
665
+ else:
666
+ frame_interval = max(1, int(round(orig_fps / max(1e-6, float(target_fps)))))
667
+
668
+ # Build per-clip centroid lookup for spatial region assignment
669
+ clip_centroids = self._build_clip_centroids(v_data) if has_regions else []
670
+
671
+ def _region_for_frame_range(start_f, end_f):
672
+ """Find region for a segment by checking the clip whose start is nearest the midpoint."""
673
+ if not has_regions or not clip_centroids or not clip_starts:
674
+ return "All"
675
+ mid = (start_f + end_f) / 2.0
676
+ best_ci = 0
677
+ best_dist = float("inf")
678
+ for ci, cs in enumerate(clip_starts):
679
+ d = abs(cs - mid)
680
+ if d < best_dist:
681
+ best_dist = d
682
+ best_ci = ci
683
+ if best_ci < len(clip_centroids) and clip_centroids[best_ci] is not None:
684
+ return self._region_for_point(*clip_centroids[best_ci])
685
+ return "Outside"
686
+
687
+ def _region_for_clip_range(bout_start, bout_end):
688
+ """Find region for a clip-based bout by majority vote of clip centroids."""
689
+ if not has_regions or not clip_centroids:
690
+ return "All"
691
+ region_votes = {}
692
+ for ci in range(bout_start, min(bout_end + 1, len(clip_centroids))):
693
+ c = clip_centroids[ci]
694
+ if c is not None:
695
+ rn = self._region_for_point(*c)
696
+ region_votes[rn] = region_votes.get(rn, 0) + 1
697
+ if not region_votes:
698
+ return "Outside"
699
+ return max(region_votes, key=region_votes.get)
700
+
701
+ # Use frame-level aggregated segments when available (precise boundaries).
702
+ # For OvR, prefer per-class multiclass segments so downstream analysis
703
+ # sees overlapping labels as independent behavior bouts.
704
+ #
705
+ # IMPORTANT: saved aggregated_segments have per-class ignore thresholds
706
+ # applied (frames below threshold are labelled class=-1 / "Filtered").
707
+ # When the threshold for a class is set higher than the model's actual
708
+ # confidence range for that class, the entire class disappears from the
709
+ # saved segments even though the model predicted it as dominant.
710
+ # To avoid this, we rebuild segments from the raw frame probabilities
711
+ # (argmax, no threshold) when they are available so that analysis always
712
+ # reflects true model predictions.
713
+ agg_probs = v_data.get("aggregated_frame_probs")
714
+
715
+ def _relabel_filtered_segs(segs, probs_list):
716
+ """Relabel class=-1 (Filtered) segments using the argmax of the
717
+ underlying frame probabilities over that segment's frame range.
718
+
719
+ The saved aggregated_segments already encode the correct temporal
720
+ structure (temporal smoothing, merge-gap, min-segment all applied).
721
+ The only problem is that the per-class ignore threshold may have
722
+ labelled some segments as Filtered (-1) even though the model
723
+ predicted a real class with high confidence. This function restores
724
+ the true model prediction for those segments without touching the
725
+ segment boundaries.
726
+ """
727
+ if not probs_list or not segs:
728
+ return segs
729
+ n_probs = len(probs_list)
730
+ result = []
731
+ for seg in segs:
732
+ if seg.get("class", -1) >= 0:
733
+ result.append(seg)
734
+ continue
735
+ # Filtered segment: determine label from mean probs over its range
736
+ s = max(0, int(seg["start"]))
737
+ e = min(n_probs - 1, int(seg["end"]))
738
+ if s > e:
739
+ result.append(seg)
740
+ continue
741
+ # Sum probabilities over the frame range and pick argmax
742
+ n_cls = len(probs_list[s])
743
+ totals = [0.0] * n_cls
744
+ for fi in range(s, e + 1):
745
+ for ci, p in enumerate(probs_list[fi]):
746
+ totals[ci] += p
747
+ best = int(max(range(n_cls), key=lambda ci: totals[ci]))
748
+ new_seg = dict(seg)
749
+ new_seg["class"] = best
750
+ result.append(new_seg)
751
+ return result
752
+
753
+ if use_ovr and agg_multiclass:
754
+ seg_source = agg_multiclass
755
+ elif frame_agg_enabled and agg_probs:
756
+ # Use saved segments (correct temporal structure: temporal smoothing,
757
+ # merge-gap, min-segment all already applied) but relabel any
758
+ # Filtered (-1) segments with the true argmax from raw probabilities.
759
+ seg_source = _relabel_filtered_segs(agg_segments, agg_probs)
760
+ else:
761
+ seg_source = agg_segments
762
+
763
+ if frame_agg_enabled and seg_source:
764
+ covered_frames = 0
765
+ for seg in seg_source:
766
+ pred_idx = seg["class"]
767
+ start_frame = seg["start"]
768
+ end_frame = seg["end"]
769
+ covered_frames += max(0, (end_frame - start_frame + 1))
770
+
771
+ if pred_idx < 0:
772
+ label_name = "Filtered"
773
+ elif classes and pred_idx < len(classes):
774
+ label_name = classes[pred_idx]
775
+ else:
776
+ label_name = f"Class {pred_idx}"
777
+
778
+ duration_sec = (end_frame - start_frame + 1) / orig_fps
779
+
780
+ self.merged_data.append({
781
+ "Video": os.path.basename(video_path),
782
+ "VideoPath": video_path,
783
+ "Group": self.video_groups.get(video_path, "Unassigned"),
784
+ "Behavior": label_name,
785
+ "Duration": duration_sec,
786
+ "Region": _region_for_frame_range(start_frame, end_frame),
787
+ })
788
+
789
+ # Keep totals comparable across videos by accounting for uncovered tail.
790
+ if total_frames > 0 and covered_frames < total_frames:
791
+ self.merged_data.append({
792
+ "Video": os.path.basename(video_path),
793
+ "VideoPath": video_path,
794
+ "Group": self.video_groups.get(video_path, "Unassigned"),
795
+ "Behavior": "Uncovered",
796
+ "Duration": (total_frames - covered_frames) / orig_fps,
797
+ "Region": "All",
798
+ })
799
+ continue
800
+
801
+ # Fallback: clip-based bout detection
802
+ preds = v_data.get("predictions", [])
803
+ corrections = v_data.get("corrected_labels", {})
804
+ confs = v_data.get("confidences", [])
805
+
806
+ if not preds:
807
+ continue
808
+
809
+ # Reconstruct ignore threshold from saved parameters
810
+ apply_ignore = params.get("use_ignore_threshold", False)
811
+ ignore_thr = float(params.get("ignore_threshold", 0.5))
812
+
813
+ # Apply corrections and threshold filtering
814
+ final_preds = []
815
+ for i, p in enumerate(preds):
816
+ if str(i) in corrections:
817
+ pred = corrections[str(i)]
818
+ elif i in corrections:
819
+ pred = corrections[i]
820
+ else:
821
+ pred = p
822
+ if apply_ignore and i < len(confs) and float(confs[i]) < ignore_thr:
823
+ pred = -1
824
+ final_preds.append(pred)
825
+
826
+ if not final_preds:
827
+ continue
828
+
829
+ # Derive bouts by clip index, then convert to seconds in original timeline.
830
+ bout_start_idx = 0
831
+ current_label = final_preds[0]
832
+ covered_frames = 0
833
+
834
+ for i in range(1, len(final_preds) + 1):
835
+ boundary = (i == len(final_preds)) or (final_preds[i] != current_label)
836
+ if not boundary:
837
+ continue
838
+
839
+ bout_end_idx = i - 1
840
+ if clip_starts and len(clip_starts) == len(final_preds):
841
+ start_frame = int(clip_starts[bout_start_idx])
842
+ if i < len(clip_starts):
843
+ end_frame_exclusive = int(clip_starts[i])
844
+ else:
845
+ end_frame_exclusive = start_frame + (clip_length * frame_interval)
846
+ if total_frames > 0:
847
+ end_frame_exclusive = min(end_frame_exclusive, total_frames)
848
+ duration_frames = max(0, end_frame_exclusive - start_frame)
849
+ covered_frames += duration_frames
850
+ duration_sec = duration_frames / orig_fps
851
+ else:
852
+ # Legacy fallback when clip starts are missing.
853
+ clip_count = bout_end_idx - bout_start_idx + 1
854
+ duration_subsampled = ((clip_count - 1) * step_frames + clip_length)
855
+ duration_sec = duration_subsampled / max(1e-6, float(target_fps))
856
+
857
+ if current_label < 0:
858
+ label_name = "Filtered"
859
+ elif classes and current_label < len(classes):
860
+ label_name = classes[current_label]
861
+ else:
862
+ label_name = f"Class {current_label}"
863
+
864
+ self.merged_data.append({
865
+ "Video": os.path.basename(video_path),
866
+ "VideoPath": video_path,
867
+ "Group": self.video_groups.get(video_path, "Unassigned"),
868
+ "Behavior": label_name,
869
+ "Duration": duration_sec,
870
+ "Region": _region_for_clip_range(bout_start_idx, bout_end_idx),
871
+ })
872
+
873
+ if i < len(final_preds):
874
+ current_label = final_preds[i]
875
+ bout_start_idx = i
876
+
877
+ if total_frames > 0 and covered_frames < total_frames:
878
+ self.merged_data.append({
879
+ "Video": os.path.basename(video_path),
880
+ "VideoPath": video_path,
881
+ "Group": self.video_groups.get(video_path, "Unassigned"),
882
+ "Behavior": "Uncovered",
883
+ "Duration": (total_frames - covered_frames) / orig_fps,
884
+ "Region": "All",
885
+ })
886
+
887
+ # Update available behaviors and selection (include all model classes so behaviors
888
+ # with zero or few segments still appear in filters and plots)
889
+ if self.merged_data:
890
+ df = pd.DataFrame(self.merged_data)
891
+ from_data = set(df["Behavior"].unique().tolist())
892
+ classes = []
893
+ if isinstance(self.results, dict):
894
+ classes = self.results.get("classes", [])
895
+ self.all_behaviors = sorted(from_data | set(classes))
896
+
897
+ # Initialize selection if empty (first load)
898
+ if not self.selected_behaviors:
899
+ self.selected_behaviors = set(self.all_behaviors)
900
+ else:
901
+ # Clean up stale behaviors
902
+ self.selected_behaviors = self.selected_behaviors.intersection(set(self.all_behaviors))
903
+ # If nothing selected (e.g. all prev behaviors gone), select all new ones
904
+ if not self.selected_behaviors:
905
+ self.selected_behaviors = set(self.all_behaviors)
906
+
907
+ self.filter_behaviors_btn.setEnabled(True)
908
+
909
+ # Update transition tab
910
+ self._on_transition_type_changed()
911
+ self._update_sidebar_groups()
912
+
913
+ # Populate spatial distribution combos
914
+ self._populate_spatial_combos()
915
+ else:
916
+ self.filter_behaviors_btn.setEnabled(False)
917
+
918
+ # ------------------------------------------------------------------
919
+ # Spatial distribution (localization-based)
920
+ # ------------------------------------------------------------------
921
+
922
def _populate_spatial_combos(self):
    """Fill the spatial-distribution combos and toggle the spatial controls.

    The behavior combo receives "All behaviors" followed by
    ``self.all_behaviors``; the video combo receives "All" followed by the
    basename of every video in the results.  The show-plot and
    manage-regions buttons are enabled only when at least one video has a
    non-empty ``localization_bboxes`` list, which is the same condition
    ``_extract_spatial_data`` uses to include a video.
    """
    data_container = self.results
    results_dict = data_container.get("results", data_container)
    videos = results_dict if isinstance(results_dict, dict) else {}

    # A present-but-empty bbox list yields no centroids, so it must not
    # enable the spatial controls.
    has_loc = any(
        isinstance(v_data, dict) and bool(v_data.get("localization_bboxes"))
        for v_data in videos.values()
    )

    self.spatial_behavior_combo.blockSignals(True)
    self.spatial_behavior_combo.clear()
    self.spatial_behavior_combo.addItem("All behaviors")
    for b in self.all_behaviors:
        self.spatial_behavior_combo.addItem(b)
    self.spatial_behavior_combo.blockSignals(False)

    self.spatial_video_combo.blockSignals(True)
    self.spatial_video_combo.clear()
    self.spatial_video_combo.addItem("All")
    for vp in videos:
        self.spatial_video_combo.addItem(os.path.basename(vp))
    self.spatial_video_combo.blockSignals(False)

    self.spatial_show_btn.setEnabled(has_loc)
    self.manage_regions_btn.setEnabled(has_loc)
    if has_loc:
        self.spatial_info_label.setText("Localization data available.")
    else:
        self.spatial_info_label.setText("No localization data found in results.")
955
+
956
def _extract_spatial_data(self):
    """Extract per-clip centroids, grouped by behavior, from the loaded results.

    Only videos with a non-empty ``localization_bboxes`` list are included.
    Each clip's bbox entry is reduced to a centroid via ``_bbox_centroid``.
    Clips without a usable box are skipped.  The clip's label is its
    prediction, overridden by ``corrected_labels`` (keyed by str or int clip
    index).  A negative class index is labelled "Filtered", matching the
    labelling used by ``_process_data``, instead of wrapping around to the
    last class name.

    Returns dict:
        {
            video_basename: {
                "all_centroids": [(cx, cy), ...],
                "behavior_centroids": {behavior_name: [(cx, cy), ...], ...}
            }
        }
    """
    data_container = self.results
    results_dict = data_container.get("results", data_container)
    classes = data_container.get("classes", [])

    spatial_data = {}
    if not isinstance(results_dict, dict):
        return spatial_data

    for video_path, v_data in results_dict.items():
        loc_bboxes = v_data.get("localization_bboxes", [])
        if not loc_bboxes:
            continue

        preds = v_data.get("predictions") or []
        corrections = v_data.get("corrected_labels") or {}

        all_centroids = []
        behavior_centroids = {}

        for clip_idx, raw in enumerate(loc_bboxes):
            cx, cy = self._bbox_centroid(raw)
            if cx is None:
                continue
            all_centroids.append((cx, cy))

            # Clips beyond the prediction list contribute only to all_centroids.
            if clip_idx >= len(preds):
                continue

            pred_idx = preds[clip_idx]
            if str(clip_idx) in corrections:
                pred_idx = corrections[str(clip_idx)]
            elif clip_idx in corrections:
                pred_idx = corrections[clip_idx]

            if pred_idx < 0:
                # A negative index would otherwise select classes[-1].
                label = "Filtered"
            elif classes and pred_idx < len(classes):
                label = classes[pred_idx]
            else:
                label = f"Class {pred_idx}"

            behavior_centroids.setdefault(label, []).append((cx, cy))

        if all_centroids:
            spatial_data[os.path.basename(video_path)] = {
                "all_centroids": all_centroids,
                "behavior_centroids": behavior_centroids,
            }

    return spatial_data
1017
+
1018
@staticmethod
def _bbox_centroid(raw):
    """Return the ``(cx, cy)`` centre of a localization bbox entry.

    Accepted layouts:

    * a single box ``[x1, y1, x2, y2]``;
    * per-frame boxes ``[[x1, y1, x2, y2], ...]``.  The middle frame is
      preferred.  If its entry is missing or malformed (e.g. ``None`` for a
      frame without a detection), the nearest valid frame is used instead,
      so one bad frame no longer discards the whole clip.

    A box is usable when it has exactly four finite numeric coordinates.
    Returns ``(None, None)`` when no usable box can be found.
    """
    import math

    def _center(box):
        # Centre of a 4-number box, or None if malformed or non-finite.
        try:
            if len(box) != 4:
                return None
            x1, y1, x2, y2 = (float(v) for v in box)
        except (TypeError, ValueError):
            return None
        if not all(math.isfinite(v) for v in (x1, y1, x2, y2)):
            return None
        return (x1 + x2) / 2.0, (y1 + y2) / 2.0

    if not isinstance(raw, (list, tuple)) or len(raw) == 0:
        return None, None

    # Single bbox: a flat sequence with no nested boxes.
    if not any(isinstance(v, (list, tuple)) for v in raw):
        if len(raw) != 4:
            return None, None
        center = _center(raw)
        return center if center is not None else (None, None)

    # Per-frame bboxes: probe mid, mid-1, mid+1, mid-2, ... so the frame
    # closest to the middle of the clip wins.
    n = len(raw)
    mid = n // 2
    for offset in range(n):
        for idx in ((mid,) if offset == 0 else (mid - offset, mid + offset)):
            if 0 <= idx < n:
                center = _center(raw[idx])
                if center is not None:
                    return center
    return None, None
1037
+
1038
def _update_spatial_plot(self):
    """Build and display the spatial distribution plot.

    Draws every extracted clip centroid as a light-grey background layer.
    On top, it draws the centroids of all behaviors (one coloured trace each)
    or only of the behavior chosen in ``spatial_behavior_combo``, optionally
    restricted to the video chosen in ``spatial_video_combo``.  User-defined
    spatial regions are outlined on top.  Regions without vertices are
    skipped instead of raising ``IndexError``.

    The figure is kept in ``self.last_spatial_fig`` (used by
    ``_save_spatial_plot``) and written to
    ``<data_dir>/temp_plots/spatial_distribution.html`` for the web view.
    """
    if not HAS_PLOTLY:
        return

    spatial_data = self._extract_spatial_data()
    if not spatial_data:
        self.spatial_info_label.setText("No localization centroids could be extracted.")
        return

    selected_behavior = self.spatial_behavior_combo.currentText()
    selected_video = self.spatial_video_combo.currentText()
    theme = self.color_theme_combo.currentText()
    if theme == "none":
        theme = None

    fig = go.Figure()

    # Collect centroids based on video filter
    all_cx, all_cy = [], []
    beh_centroids = {}  # {behavior: [(cx, cy), ...]}

    for video_name, vdata in spatial_data.items():
        if selected_video != "All" and video_name != selected_video:
            continue
        for cx, cy in vdata["all_centroids"]:
            all_cx.append(cx)
            all_cy.append(cy)
        for beh, pts in vdata["behavior_centroids"].items():
            beh_centroids.setdefault(beh, []).extend(pts)

    if not all_cx:
        self.spatial_info_label.setText("No centroids for the selected filter.")
        return

    dot_size = self.spatial_dot_size_slider.value()
    base_size = max(1, dot_size - 3)

    # Base layer: all centroids as light grey
    fig.add_trace(go.Scatter(
        x=all_cx,
        y=all_cy,
        mode='markers',
        marker=dict(color='lightgrey', size=base_size, opacity=0.4),
        name='All clips',
        showlegend=True,
        legendgroup='all',
        hoverinfo='skip',
    ))

    # Plotly qualitative palette, indexed by sorted behavior name so a
    # behavior keeps its colour whether shown alone or with all others.
    behaviors = sorted(beh_centroids.keys())
    color_palette = px.colors.qualitative.Plotly

    if selected_behavior == "All behaviors":
        for i, beh in enumerate(behaviors):
            pts = beh_centroids[beh]
            bx = [p[0] for p in pts]
            by = [p[1] for p in pts]
            color = color_palette[i % len(color_palette)]
            fig.add_trace(go.Scatter(
                x=bx, y=by,
                mode='markers',
                marker=dict(color=color, size=dot_size, opacity=0.8),
                name=beh,
                showlegend=True,
                legendgroup=beh,
                hovertext=[beh] * len(bx),
                hoverinfo='text',
            ))
        title = "Spatial Distribution: All Behaviors"
    else:
        pts = beh_centroids.get(selected_behavior, [])
        bx = [p[0] for p in pts]
        by = [p[1] for p in pts]
        if selected_behavior in behaviors:
            color = color_palette[behaviors.index(selected_behavior) % len(color_palette)]
        else:
            color = color_palette[0]
        fig.add_trace(go.Scatter(
            x=bx, y=by,
            mode='markers',
            marker=dict(color=color, size=dot_size, opacity=0.8),
            name=selected_behavior,
            showlegend=True,
            legendgroup='behavior',
            hovertext=[selected_behavior] * len(bx),
            hoverinfo='text',
        ))
        title = f"Spatial Distribution: {selected_behavior} ({len(pts)} clips)"

    # Overlay defined spatial regions as semi-transparent fills
    region_colors = px.colors.qualitative.Set2
    for ri, region in enumerate(self.spatial_regions):
        verts = region.get("vertices") or []
        if not verts:
            # Nothing to outline; verts[0] below would raise IndexError.
            continue
        rc = region_colors[ri % len(region_colors)]
        if region.get("type") == "rect" and len(verts) == 2:
            # Two opposite corners -> closed rectangle outline.
            rx = [verts[0][0], verts[1][0], verts[1][0], verts[0][0], verts[0][0]]
            ry = [verts[0][1], verts[0][1], verts[1][1], verts[1][1], verts[0][1]]
        else:
            # Polygon: repeat the first vertex to close the outline.
            rx = [v[0] for v in verts] + [verts[0][0]]
            ry = [v[1] for v in verts] + [verts[0][1]]
        fig.add_trace(go.Scatter(
            x=rx, y=ry, fill="toself",
            fillcolor=rc, opacity=0.15,
            line=dict(color=rc, width=2),
            name=region["name"],
            showlegend=True,
            legendgroup=f"region_{ri}",
            hoverinfo="name",
        ))

    filter_parts = []
    if selected_video != "All":
        filter_parts.append(f"Video: {selected_video}")
    filter_str = f" | {', '.join(filter_parts)}" if filter_parts else ""

    fig.update_layout(
        title=f'{title}{filter_str}',
        xaxis_title='X (normalized)',
        yaxis_title='Y (normalized)',
        xaxis=dict(scaleanchor="y", scaleratio=1, range=[0, 1]),
        yaxis=dict(autorange='reversed', range=[0, 1]),
        height=600,
        template=theme,
        hovermode='closest',
    )

    self.last_spatial_fig = fig

    # Render to same webview
    temp_dir = os.path.join(self.config.get("data_dir", "."), "temp_plots")
    os.makedirs(temp_dir, exist_ok=True)
    plot_path = os.path.join(temp_dir, "spatial_distribution.html")

    with open(plot_path, 'w', encoding="utf-8") as f:
        f.write(fig.to_html(include_plotlyjs=True))

    if HAS_WEBENGINE:
        self.webview.setUrl(QUrl.fromLocalFile(os.path.abspath(plot_path)))

    self.spatial_info_label.setText(f"Showing {len(all_cx)} total centroids.")
1178
+
1179
def _save_spatial_plot(self):
    """Save the current spatial distribution plot as PDF, SVG, PNG or HTML.

    The output format follows the file extension.  If the user typed a name
    without an extension (non-native file dialogs do not always append one),
    the extension of the selected filter is appended so the format can still
    be inferred.  HTML is written with ``write_html``.  Other formats use
    ``write_image``, which requires the ``kaleido`` package.
    """
    if not HAS_PLOTLY or not getattr(self, "last_spatial_fig", None):
        QMessageBox.warning(self, "Save spatial plot", "No spatial distribution plot to save. Show the plot first.")
        return
    path, selected_filter = QFileDialog.getSaveFileName(
        self, "Save spatial plot",
        self.config.get("experiment_path", ""),
        "PDF Files (*.pdf);;SVG Files (*.svg);;PNG Files (*.png);;HTML Files (*.html)"
    )
    if not path:
        return
    if not os.path.splitext(path)[1] and "(*" in (selected_filter or ""):
        # e.g. "PDF Files (*.pdf)" -> ".pdf"
        path += selected_filter.split("(*", 1)[1].split(")", 1)[0]
    try:
        if path.lower().endswith(".html"):
            self.last_spatial_fig.write_html(path)
        else:
            self.last_spatial_fig.write_image(path)
        QMessageBox.information(self, "Success", f"Spatial plot saved to {path}")
    except Exception as e:
        if "kaleido" in str(e).lower() or "executable" in str(e).lower():
            QMessageBox.warning(self, "Error", "Saving as PDF/SVG requires 'kaleido'. Install with: pip install kaleido")
        else:
            QMessageBox.critical(self, "Error", f"Failed to save: {e}")
1202
+
1203
def _filter_behaviors(self):
    """Let the user choose which behaviors appear in the plots.

    Opens a modal dialog with one checkbox per entry of
    ``self.all_behaviors`` (pre-checked when currently selected), plus
    select-all / deselect-all buttons.  When the dialog is accepted,
    ``self.selected_behaviors`` becomes the set of checked behaviors and the
    plots are redrawn.
    """
    dialog = QDialog(self)
    dialog.setWindowTitle("Select behaviors")
    dialog.resize(300, 400)
    root = QVBoxLayout()

    # Bulk (de)selection buttons
    bulk_row = QHBoxLayout()
    check_all_btn = QPushButton("Select all")
    uncheck_all_btn = QPushButton("Deselect all")
    for bulk_btn in (check_all_btn, uncheck_all_btn):
        bulk_row.addWidget(bulk_btn)
    root.addLayout(bulk_row)

    # Scrollable list of behavior checkboxes
    scroller = QScrollArea()
    scroller.setWidgetResizable(True)
    holder = QWidget()
    list_layout = QVBoxLayout()

    entries = []
    for name in self.all_behaviors:
        box = QCheckBox(name)
        if name in self.selected_behaviors:
            box.setChecked(True)
        entries.append((name, box))
        list_layout.addWidget(box)

    list_layout.addStretch()
    holder.setLayout(list_layout)
    scroller.setWidget(holder)
    root.addWidget(scroller)

    def _check_everything():
        for _, box in entries:
            box.setChecked(True)

    def _uncheck_everything():
        for _, box in entries:
            box.setChecked(False)

    check_all_btn.clicked.connect(_check_everything)
    uncheck_all_btn.clicked.connect(_uncheck_everything)

    # Accept button
    apply_btn = QPushButton("Update plots")
    apply_btn.clicked.connect(dialog.accept)
    root.addWidget(apply_btn)

    dialog.setLayout(root)

    if not dialog.exec():
        return
    self.selected_behaviors = {name for name, box in entries if box.isChecked()}
    self._update_plots()
1256
+
1257
def _manage_groups(self):
    """Open the group-management dialog and apply the user's changes.

    Video paths come from ``self.results["results"]`` when that key exists,
    otherwise from ``self.results`` itself.  If the dialog is accepted, the
    new groups and video-to-group mapping are saved (config and metadata).
    The merged data and plots are then rebuilt, because every merged row
    carries its group.
    """
    container = self.results
    source = container["results"] if "results" in container else container
    video_paths = list(source.keys())

    dialog = GroupManagementDialog(video_paths, self.groups, self.video_groups, self)
    if not dialog.exec():
        return

    self.groups = dialog.groups
    self.video_groups = dialog.video_groups
    self._save_groups_to_config()
    self._save_metadata()

    # Rebuild rows so their "Group" column reflects the new assignment.
    self._process_data()
    self._update_plots()
1274
+
1275
def _save_metadata(self):
    """Persist groups, video-group assignments and spatial regions as JSON.

    The target is ``<experiment_path>/results/analysis_metadata.json``.
    Nothing is written when ``experiment_path`` is not configured.  Region
    vertices are converted to lists before dumping.  Any error is logged
    rather than raised.
    """
    exp_path = self.config.get("experiment_path")
    if not exp_path:
        return

    meta_path = os.path.join(exp_path, "results", "analysis_metadata.json")
    try:
        os.makedirs(os.path.dirname(meta_path), exist_ok=True)
        regions_out = [
            {
                "name": region["name"],
                "type": region["type"],
                "vertices": [list(vertex) for vertex in region["vertices"]],
            }
            for region in self.spatial_regions
        ]
        with open(meta_path, 'w') as fh:
            json.dump(
                {
                    "groups": self.groups,
                    "video_groups": self.video_groups,
                    "spatial_regions": regions_out,
                },
                fh,
                indent=2,
            )
    except Exception as exc:
        logger.error("Error saving metadata: %s", exc)
1296
+
1297
def _on_region_filter_changed(self):
    """Redraw the plots after the region filter changes, if any data is loaded."""
    if not self.merged_data:
        return
    self._update_plots()
1300
+
1301
def _update_region_filter_combo(self):
    """Rebuild the region filter combo: "All Regions", each region, then "Outside".

    The previously selected text is re-selected when it still exists.
    Signals are blocked for the duration of the rebuild, so it does not emit
    change notifications.
    """
    combo = self.region_filter_combo
    combo.blockSignals(True)
    previous = combo.currentText()
    combo.clear()
    combo.addItem("All Regions")
    for region in self.spatial_regions:
        combo.addItem(region["name"])
    combo.addItem("Outside")
    restored = combo.findText(previous)
    if restored >= 0:
        combo.setCurrentIndex(restored)
    combo.blockSignals(False)
1313
+
1314
def _manage_spatial_regions(self):
    """Open the spatial region editor and apply the edited regions.

    The editor is seeded with every extracted clip centroid across all
    videos.  If it is accepted, the regions are stored and saved to the
    metadata file.  The region filter combo is refreshed, and the merged
    data and plots are rebuilt so each bout's region is recomputed.
    """
    centroids = [
        point
        for vdata in self._extract_spatial_data().values()
        for point in vdata["all_centroids"]
    ]
    dialog = SpatialRegionEditor(centroids, self.spatial_regions, self)
    if not dialog.exec():
        return

    self.spatial_regions = dialog.regions
    self._save_metadata()
    self._update_region_filter_combo()
    self._process_data()
    self._update_plots()
1326
+
1327
def _update_plots(self):
    """Rebuild and display the main analysis plot from ``self.merged_data``.

    Steps:
      1. Filter bouts by ``self.selected_behaviors``, ``self.visible_groups``
         and the region chosen in ``region_filter_combo``.
      2. Aggregate per (Video, Group, Behavior) into a "Value" column
         according to the metric: bout count, mean duration, total duration,
         or percent of (filtered) time.
      3. Build a Plotly figure for the analysis mode ("General Overview" or
         "Group Comparison") and graph type ("Auto" picks a default).
      4. Store it in ``self.last_fig`` and render it to the web view.

    A warning is shown when Plotly is not installed.  Nothing is drawn when
    ``merged_data`` is empty.  ``self.last_fig`` is cleared when the filters
    leave no rows or when the metric text is not recognised.
    """
    if not HAS_PLOTLY or not self.merged_data:
        if not HAS_PLOTLY:
            QMessageBox.warning(self, "Error", "Plotly is not installed.")
        return

    df = pd.DataFrame(self.merged_data)

    # Filter behaviors
    if self.selected_behaviors:
        df = df[df["Behavior"].isin(self.selected_behaviors)]

    # Filter groups
    if self.visible_groups:
        df = df[df["Group"].isin(self.visible_groups)]

    # Filter by spatial region
    selected_region = self.region_filter_combo.currentText()
    if selected_region != "All Regions" and "Region" in df.columns:
        df = df[df["Region"] == selected_region]

    if df.empty:
        self.last_fig = None
        return

    metric = self.metric_combo.currentText()
    analysis_mode = self.plot_mode_combo.currentText()
    graph_type = self.graph_type_combo.currentText()
    theme = self.color_theme_combo.currentText()

    fig = None

    # Aggregation: one row per (Video, Group, Behavior) with a "Value" column.
    keys = ["Video", "Group", "Behavior"]
    if "Occurrences" in metric:
        agg = df.groupby(keys).size().reset_index(name="Value")
        y_label = "Count"
    elif "Average" in metric:
        agg = df.groupby(keys)["Duration"].mean().reset_index(name="Value")
        y_label = "Avg Duration (s)"
    elif "Total" in metric:
        agg = df.groupby(keys)["Duration"].sum().reset_index(name="Value")
        y_label = "Total Duration (s)"
    elif "Percent" in metric:
        # Totals come from the *filtered* frame, so percentages are relative
        # to the behaviors/groups/region currently shown.
        total_times = df.groupby("Video")["Duration"].sum().to_dict()
        agg = df.groupby(keys)["Duration"].sum().reset_index(name="Value")
        agg["Value"] = agg.apply(lambda x: (x["Value"] / total_times.get(x["Video"], 1)) * 100, axis=1)
        y_label = "Percent Time (%)"
    else:
        # Without this branch, agg / y_label would be unbound below.
        logger.warning("Unknown metric %r; plot not updated.", metric)
        self.last_fig = None
        return

    # Auto-determine graph type: Bar Chart for General Overview (except
    # averages), Box Plot otherwise.
    if graph_type == "Auto":
        if analysis_mode == "General Overview" and "Average" not in metric:
            graph_type = "Bar Chart"
        else:
            graph_type = "Box Plot"

    # Common Plot Settings
    template = theme if theme != "none" else None
    title = f"{analysis_mode}: {metric}"
    labels = {"Value": y_label, "Duration": "Bout Duration (s)", "Video": "Video", "Group": "Group"}
    plot_args = {"template": template, "title": title, "labels": labels}

    if analysis_mode == "General Overview":
        # X=Video, Color=Behavior.
        # Distribution plots of averages use bout-level rows (df); everything
        # else uses the per-video aggregates (agg).
        if "Average" in metric and graph_type in ["Box Plot", "Violin Plot", "Strip Plot"]:
            data = df
            y_col = "Duration"
        else:
            data = agg
            y_col = "Value"

        if graph_type == "Bar Chart":
            fig = px.bar(data, x="Video", y=y_col, color="Behavior", **plot_args)
        elif graph_type == "Box Plot":
            fig = px.box(data, x="Video", y=y_col, color="Behavior", points="all", **plot_args)
        elif graph_type == "Violin Plot":
            fig = px.violin(data, x="Video", y=y_col, color="Behavior", points="all", box=True, **plot_args)
        elif graph_type == "Strip Plot":
            fig = px.strip(data, x="Video", y=y_col, color="Behavior", **plot_args)
        elif graph_type == "Line Plot":
            fig = px.line(agg, x="Video", y="Value", color="Behavior", markers=True, **plot_args)

    elif analysis_mode == "Group Comparison":
        # X=Behavior, Color=Group, Data=Video Aggregates (agg)
        data = agg

        if graph_type == "Bar Chart":
            # Grouped bars of the per-group mean with SEM error bars, plus
            # one black point per video.
            fig = go.Figure()

            behaviors = sorted(data["Behavior"].unique())
            groups = sorted(data["Group"].unique())
            num_groups = len(groups)

            colors = px.colors.qualitative.Plotly

            # Behaviors sit at integer x positions 0..N-1; the groups share a
            # band of width group_width centred on each position.
            group_width = 0.8
            bar_width = group_width / num_groups

            for i, group in enumerate(groups):
                group_data = data[data["Group"] == group]

                means = []
                sems = []
                x_positions_bar = []
                x_positions_points = []
                y_points = []

                # Centre of this group's bar relative to the behavior position.
                offset = -group_width / 2 + bar_width * (i + 0.5)

                for j, behavior in enumerate(behaviors):
                    beh_data = group_data[group_data["Behavior"] == behavior]["Value"]

                    if len(beh_data) > 0:
                        means.append(beh_data.mean())
                        sems.append(beh_data.sem() if len(beh_data) > 1 else 0)
                    else:
                        means.append(0)
                        sems.append(0)

                    x_positions_bar.append(j + offset)

                    if len(beh_data) > 0:
                        x_positions_points.extend([j + offset] * len(beh_data))
                        y_points.extend(beh_data.values)

                fig.add_trace(go.Bar(
                    name=group,
                    x=x_positions_bar,
                    y=means,
                    error_y=dict(type='data', array=sems, visible=True),
                    marker_color=colors[i % len(colors)],
                    width=bar_width,
                    showlegend=True
                ))

                if x_positions_points:
                    fig.add_trace(go.Scatter(
                        x=x_positions_points,
                        y=y_points,
                        mode='markers',
                        marker=dict(
                            color='black',
                            size=5,
                            opacity=0.7,
                        ),
                        showlegend=False,
                        hovertemplate=f'{group}<br>%{{y:.2f}}<extra></extra>'
                    ))

            fig.update_layout(
                title=title,
                xaxis_title="Behavior",
                yaxis_title=y_label,
                xaxis=dict(
                    tickmode='array',
                    tickvals=list(range(len(behaviors))),
                    ticktext=behaviors
                ),
                template=template,
                hovermode='closest'
            )

        elif graph_type == "Box Plot":
            fig = px.box(data, x="Behavior", y="Value", color="Group", points="all", **plot_args)
        elif graph_type == "Violin Plot":
            fig = px.violin(data, x="Behavior", y="Value", color="Group", points="all", box=True, **plot_args)
        elif graph_type == "Strip Plot":
            fig = px.strip(data, x="Behavior", y="Value", color="Group", **plot_args)
        elif graph_type == "Line Plot":
            fig = px.line(data, x="Behavior", y="Value", color="Group", markers=True, **plot_args)

    self.last_fig = fig

    if fig:
        # Use a fixed temp file in app dir to avoid permission issues sometimes
        temp_dir = os.path.join(self.config.get("data_dir", "."), "temp_plots")
        os.makedirs(temp_dir, exist_ok=True)
        plot_path = os.path.join(temp_dir, "current_plot.html")

        with open(plot_path, 'w', encoding="utf-8") as f:
            # include_plotlyjs=True embeds the ~3MB library directly in the HTML
            # preventing 'Plotly is not defined' errors if CDN fails
            f.write(fig.to_html(include_plotlyjs=True))

        if HAS_WEBENGINE:
            self.webview.setUrl(QUrl.fromLocalFile(os.path.abspath(plot_path)))
1532
+
1533
def _on_transition_type_changed(self):
    """Refill the transition selection combo for the chosen analysis type.

    "Individual Video" lists the videos present in ``self.merged_data``
    (sorted).  Any other type lists the non-empty group names (sorted),
    falling back to "Unassigned" when none are defined.  The compute button
    is enabled only when the combo ends up with at least one entry.  It is
    disabled when no data is loaded.
    """
    combo = self.transition_select_combo
    combo.clear()

    if not self.merged_data:
        # Nothing to analyse; don't leave the button enabled from a previous load.
        self.compute_transition_btn.setEnabled(False)
        return

    analysis_type = self.transition_type_combo.currentText()

    if analysis_type == "Individual Video":
        df = pd.DataFrame(self.merged_data)
        combo.addItems(sorted(df["Video"].unique().tolist()))
    else:  # Group Comparison
        groups = sorted(g for g in self.groups.keys() if g)
        combo.addItems(groups or ["Unassigned"])

    # count() is the documented Qt API; len() on the widget depends on a
    # binding-provided __len__.
    self.compute_transition_btn.setEnabled(combo.count() > 0)
1555
+
1556
def _compute_transitions(self):
    """Compute the first-order transition matrix for the current selection and plot it.

    Rows of ``self.merged_data`` are taken in their stored order as the
    temporal order of behaviours. "Individual Video" builds one sequence
    from the selected video. Otherwise every video of the selected group
    contributes its own sequence, so no transition is counted across the
    boundary between two videos.

    For every (from, to) pair this computes:

    * the conditional probability ``P(to | from)`` (row-normalised counts);
    * Haberman's adjusted residual
      ``(O - E) / sqrt(E * (1 - r/N) * (1 - c/N))`` with ``E = r * c / N``,
      where r/c are the row/column totals and N is the number of transitions.
      Unlike the plain Pearson residual ``(O - E) / sqrt(E)``, the adjusted
      residual is approximately standard normal under independence. That is
      what the ``Z > 1.96`` (p < 0.05) significance filter in the graph
      assumes.

    Probabilities are shown in the matrix table; probabilities and residuals
    are passed to the transition graph.
    """
    if not HAS_PLOTLY:
        QMessageBox.warning(self, "Error", "Plotly is required for transition analysis.")
        return

    analysis_type = self.transition_type_combo.currentText()
    selection = self.transition_select_combo.currentText()

    # An empty merged_data would produce a column-less DataFrame (KeyError below).
    if not selection or not self.merged_data:
        return

    df = pd.DataFrame(self.merged_data)

    # Filter data based on selection
    if analysis_type == "Individual Video":
        df_filtered = df[df["Video"] == selection]
        title_suffix = f"Video: {selection}"
    else:  # Group Comparison
        df_filtered = df[df["Group"] == selection]
        title_suffix = f"Group: {selection}"

    if df_filtered.empty:
        QMessageBox.warning(self, "No Data", f"No data found for {selection}")
        return

    # One behaviour sequence per video; row order represents time.
    if analysis_type == "Individual Video":
        sequences = [df_filtered["Behavior"].tolist()]
    else:
        sequences = [
            df_filtered.loc[df_filtered["Video"] == video, "Behavior"].tolist()
            for video in df_filtered["Video"].unique()
        ]

    behaviors = sorted(self.all_behaviors) if self.all_behaviors else sorted(df["Behavior"].unique().tolist())
    transition_counts = {b: {b2: 0 for b2 in behaviors} for b in behaviors}

    for sequence in sequences:
        for from_beh, to_beh in zip(sequence, sequence[1:]):
            # Labels outside the analysed behaviour set are ignored.
            if from_beh in transition_counts and to_beh in transition_counts[from_beh]:
                transition_counts[from_beh][to_beh] += 1

    # Marginals for the expected counts under independence
    total_transitions = sum(sum(row.values()) for row in transition_counts.values())
    row_totals = {b: sum(transition_counts[b].values()) for b in behaviors}
    col_totals = {b: sum(transition_counts[row][b] for row in behaviors) for b in behaviors}

    transition_matrix = {}
    residuals_matrix = {}

    for from_beh in behaviors:
        row_sum = row_totals[from_beh]
        transition_matrix[from_beh] = {}
        residuals_matrix[from_beh] = {}

        for to_beh in behaviors:
            observed = transition_counts[from_beh][to_beh]

            # Conditional probability P(to | from)
            transition_matrix[from_beh][to_beh] = observed / row_sum if row_sum > 0 else 0.0

            # Adjusted (Haberman) residual
            z_score = 0.0
            if total_transitions > 0:
                col_sum = col_totals[to_beh]
                expected = row_sum * col_sum / total_transitions
                variance = (
                    expected
                    * (1 - row_sum / total_transitions)
                    * (1 - col_sum / total_transitions)
                )
                # Zero variance (behaviour never observed, or a single row/column
                # holding every transition) leaves nothing to test: keep z = 0.
                if variance > 0:
                    z_score = (observed - expected) / np.sqrt(variance)
            residuals_matrix[from_beh][to_beh] = z_score

    self._display_transition_matrix(behaviors, transition_matrix)
    self._plot_transition_graph(behaviors, transition_matrix, residuals_matrix, title_suffix)
1652
+
1653
def _display_transition_matrix(self, behaviors, transition_matrix):
    """Fill the transition table with P(to | from) values.

    Column 0 holds the "from" behaviour of each row. Columns 1..n hold the
    probability of moving to each "to" behaviour, formatted to 3 decimals.
    """
    table = self.transition_matrix_table
    n_behaviors = len(behaviors)

    table.clear()
    table.setRowCount(n_behaviors)
    table.setColumnCount(n_behaviors + 1)
    table.setHorizontalHeaderLabels(["From \\ To"] + behaviors)

    for row, source in enumerate(behaviors):
        table.setItem(row, 0, QTableWidgetItem(source))
        probabilities = transition_matrix[source]
        for col, target in enumerate(behaviors, start=1):
            table.setItem(row, col, QTableWidgetItem(f"{probabilities[target]:.3f}"))

    table.resizeColumnsToContents()
1673
+
1674
def _plot_transition_graph(self, behaviors, transition_matrix, residuals_matrix, title_suffix):
    """Render the behaviour transition graph with Plotly and show it in the web view.

    Args:
        behaviors: Ordered behaviour labels; one node per label.
        transition_matrix: ``transition_matrix[a][b]`` = probability of ``a -> b``.
        residuals_matrix: ``residuals_matrix[a][b]`` = residual z-score of ``a -> b``.
        title_suffix: Text appended to the figure title.

    An edge is drawn when its z-score exceeds 1.96 (significance filter
    checked) or its probability exceeds 0.05 (otherwise). Edge width and
    opacity grow with the probability, and each edge takes its source node's
    colour. The figure is written to
    ``<data_dir>/temp_plots/transition_graph.html`` and loaded into
    ``self.transition_webview`` when QtWebEngine is available.
    """
    if not HAS_PLOTLY:
        return

    use_sig_filter = self.sig_filter_check.isChecked()
    layout_mode = self.layout_combo.currentText()

    # Select edges
    threshold = 0.05
    edges = []
    edge_weights = []
    edge_texts = []

    for from_beh in behaviors:
        for to_beh in behaviors:
            prob = transition_matrix[from_beh][to_beh]
            resid = residuals_matrix[from_beh][to_beh]

            # Significance mode: positive deviation beyond the two-sided 5%
            # cut-off of a standard-normal score. Otherwise: probability threshold.
            include = resid > 1.96 if use_sig_filter else prob > threshold
            if not include:
                continue

            edges.append((from_beh, to_beh))
            edge_weights.append(prob)  # edge width always encodes probability

            txt = f"{from_beh} → {to_beh}<br>Prob: {prob:.2%}"
            if resid != 0:
                txt += f"<br>Z-score: {resid:.2f}"
            edge_texts.append(txt)

    if not edges:
        msg = "No significant transitions found (Z > 1.96)" if use_sig_filter else "No significant transitions found (prob > 0.05)"
        QMessageBox.information(self, "No Transitions", msg)
        return

    node_positions, layout_mode = self._transition_node_positions(
        behaviors, edges, edge_weights, layout_mode
    )

    # Distinct node colours: qualitative palette for small sets, hue wheel otherwise
    import plotly.colors as pc
    n = len(behaviors)
    if n <= 10:
        node_colors = pc.qualitative.Set3[:n]
    else:
        node_colors = pc.sample_colorscale("hsv", [i / n for i in range(n)])

    behavior_to_color = {beh: node_colors[i] for i, beh in enumerate(behaviors)}

    # Curved edges
    edge_traces = []
    for (from_beh, to_beh), weight, txt in zip(edges, edge_weights, edge_texts):
        x0, y0 = node_positions[from_beh]
        x1, y1 = node_positions[to_beh]
        ctrl_x, ctrl_y = self._transition_edge_control_point(x0, y0, x1, y1, layout_mode)

        # The rgba rewrite below assumes "rgb(r, g, b)" colour strings, as
        # produced by the palettes above; likelier edges are more opaque.
        opacity = 0.3 + 0.5 * weight
        edge_color = behavior_to_color[from_beh].replace('rgb', 'rgba').replace(')', f',{opacity})')

        edge_traces.append(go.Scatter(
            x=[x0, ctrl_x, x1, None],
            y=[y0, ctrl_y, y1, None],
            mode='lines',
            line=dict(
                width=max(1, weight * 10),
                color=edge_color,
                shape='spline',
            ),
            hoverinfo='text',
            text=txt,
            showlegend=False,
        ))

    node_trace = go.Scatter(
        x=[node_positions[beh][0] for beh in behaviors],
        y=[node_positions[beh][1] for beh in behaviors],
        mode='markers+text',
        text=behaviors,
        textposition='middle center',
        textfont=dict(size=10, color='black', family='Arial Black'),
        marker=dict(
            size=50,
            color=[behavior_to_color[beh] for beh in behaviors],
            line=dict(width=3, color='white'),
            opacity=0.9,
        ),
        hoverinfo='text',
        hovertext=behaviors,
        showlegend=False,
    )

    fig = go.Figure(data=edge_traces + [node_trace])
    fig.update_layout(
        title=dict(
            text=f"Behavior Transition Graph - {title_suffix}" + (" (Significant Only)" if use_sig_filter else "") + "<br><span style='font-size:12px;color:grey'>Edge color matches SOURCE behavior (From -> To)</span>",
            font=dict(size=18)
        ),
        showlegend=False,
        hovermode='closest',
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-2.5, 2.5]),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-2.5, 2.5]),
        plot_bgcolor='white',
        height=700,
        width=700
    )

    # Save and display
    temp_dir = os.path.join(self.config.get("data_dir", "."), "temp_plots")
    os.makedirs(temp_dir, exist_ok=True)
    plot_path = os.path.join(temp_dir, "transition_graph.html")

    with open(plot_path, 'w', encoding="utf-8") as f:
        f.write(fig.to_html(include_plotlyjs=True))

    if HAS_WEBENGINE:
        self.transition_webview.setUrl(QUrl.fromLocalFile(os.path.abspath(plot_path)))

def _transition_node_positions(self, behaviors, edges, edge_weights, layout_mode):
    """Return ``(positions, effective_layout)`` for the transition-graph nodes.

    ``"Network Layout"`` uses a NetworkX spring layout (seed 42, scaled by 2).
    The edge weight is the transition probability, so likelier transitions
    attract their nodes more strongly. Any other mode, or a missing NetworkX
    install, falls back to ``"Circular Layout"``: nodes evenly spaced on a
    radius-1.5 circle, starting at the top.
    """
    import math

    n = len(behaviors)
    if layout_mode == "Network Layout":
        try:
            import networkx as nx
        except ImportError:
            logger.warning("NetworkX not installed, falling back to Circular Layout")
        else:
            G = nx.DiGraph()
            G.add_nodes_from(behaviors)
            for (u, v), w in zip(edges, edge_weights):
                if w > 0:
                    G.add_edge(u, v, weight=w)
            pos = nx.spring_layout(G, k=2.0 / math.sqrt(n) if n > 0 else 1, seed=42, iterations=50)
            return {node: (p[0] * 2, p[1] * 2) for node, p in pos.items()}, layout_mode

    radius = 1.5
    positions = {}
    for i, beh in enumerate(behaviors):
        angle = 2 * math.pi * i / n - math.pi / 2
        positions[beh] = (radius * math.cos(angle), radius * math.sin(angle))
    return positions, "Circular Layout"

def _transition_edge_control_point(self, x0, y0, x1, y1, layout_mode):
    """Return the spline control point that bows the edge (x0, y0) -> (x1, y1).

    Edges between distinct nodes are offset 0.2 units perpendicular to their
    direction, to the left of travel. The two directions of a reciprocal pair
    (A -> B and B -> A) therefore bow to opposite sides instead of
    overlapping. Self-loops are pushed 0.2 outward from the origin in the
    circular layout, and by (+0.2, +0.2) otherwise.
    """
    import math

    offset = 0.2
    mid_x, mid_y = (x0 + x1) / 2, (y0 + y1) / 2
    perp_x, perp_y = -(y1 - y0), x1 - x0
    norm = math.hypot(perp_x, perp_y)
    if norm > 0:
        return mid_x + offset * perp_x / norm, mid_y + offset * perp_y / norm

    # Self-loop: both endpoints coincide.
    if layout_mode == "Circular Layout":
        dist = math.hypot(mid_x, mid_y)
        if dist > 0:
            return mid_x + offset * mid_x / dist, mid_y + offset * mid_y / dist
        return mid_x, mid_y
    return mid_x + offset, mid_y + offset
1861
+
1862
class SpatialCanvas(QWidget):
    """Interactive canvas for drawing named regions over a centroid scatter.

    Coordinates are normalised to [0, 1] on both axes and mapped onto the
    widget area inside a fixed ``MARGIN``. Finished regions are appended to
    the ``regions`` list passed in, which is shared by reference with the
    owner. Each region is a ``{"name", "type", "vertices"}`` dict:

    * ``"rect"``: ``vertices`` = [top-left, bottom-right] in normalised coords;
    * ``"polygon"``: ``vertices`` = list of normalised (x, y) points.

    Usage: call ``start_drawing("polygon" | "rect", name)``.

    * Polygons: left-click each vertex, then double-click to finish
      (at least 3 vertices).
    * Rectangles: left-click and drag.

    After a region is added, the nearest ``SpatialRegionEditor`` ancestor is
    asked to refresh its region list.
    """

    MARGIN = 30
    REGION_COLORS = [
        QColor(102, 194, 165, 60), QColor(252, 141, 98, 60),
        QColor(141, 160, 203, 60), QColor(231, 138, 195, 60),
        QColor(166, 216, 84, 60), QColor(255, 217, 47, 60),
    ]
    REGION_BORDER_COLORS = [
        QColor(102, 194, 165), QColor(252, 141, 98),
        QColor(141, 160, 203), QColor(231, 138, 195),
        QColor(166, 216, 84), QColor(255, 217, 47),
    ]

    def __init__(self, centroids, regions, parent=None):
        super().__init__(parent)
        self.centroids = centroids  # [(cx, cy), ...] in normalised coordinates
        self.regions = regions  # list of region dicts, appended to in place
        self.setMinimumSize(500, 500)
        # Receive move events without a pressed button (rubber-band preview).
        self.setMouseTracking(True)

        self.draw_mode = None  # None, "polygon" or "rect"
        self._poly_points = []  # committed polygon vertices (normalised)
        self._rect_start = None  # rectangle drag start (normalised)
        self._rect_end = None  # rectangle drag end (normalised)
        self._mouse_norm = None  # last cursor position (normalised)
        self._pending_name = None  # name for the region being drawn

    def _norm_to_pixel(self, nx, ny):
        """Map normalised (nx, ny) to integer widget pixels inside the margin."""
        m = self.MARGIN
        w = self.width() - 2 * m
        h = self.height() - 2 * m
        return int(m + nx * w), int(m + ny * h)

    def _pixel_to_norm(self, px, py):
        """Map widget pixels to normalised coordinates, clamped to [0, 1]."""
        m = self.MARGIN
        w = self.width() - 2 * m
        h = self.height() - 2 * m
        nx = max(0.0, min(1.0, (px - m) / max(1, w)))
        ny = max(0.0, min(1.0, (py - m) / max(1, h)))
        return nx, ny

    def start_drawing(self, mode, name):
        """Enter drawing mode ``"polygon"`` or ``"rect"`` for a region called *name*."""
        self.draw_mode = mode
        self._pending_name = name
        self._poly_points = []
        self._rect_start = None
        self._rect_end = None
        self.setCursor(Qt.CursorShape.CrossCursor)
        self.update()

    def cancel_drawing(self):
        """Abort the current drawing and discard any in-progress shape."""
        self.draw_mode = None
        self._poly_points = []
        self._rect_start = None
        self._rect_end = None
        self.setCursor(Qt.CursorShape.ArrowCursor)
        self.update()

    def _notify_editor(self):
        """Ask the closest SpatialRegionEditor ancestor (if any) to refresh its list."""
        parent = self.parent()
        while parent is not None and not isinstance(parent, SpatialRegionEditor):
            parent = parent.parent()
        if parent is not None:
            parent._refresh_region_list()

    def mousePressEvent(self, event):
        # Only the left button draws; other buttons are ignored.
        if not self.draw_mode or event.button() != Qt.MouseButton.LeftButton:
            return
        pos = event.position()
        norm_x, norm_y = self._pixel_to_norm(int(pos.x()), int(pos.y()))
        if self.draw_mode == "polygon":
            self._poly_points.append((norm_x, norm_y))
        elif self.draw_mode == "rect":
            self._rect_start = (norm_x, norm_y)
            self._rect_end = (norm_x, norm_y)
        self.update()

    def mouseMoveEvent(self, event):
        pos = event.position()
        self._mouse_norm = self._pixel_to_norm(int(pos.x()), int(pos.y()))
        if self.draw_mode == "rect" and self._rect_start:
            self._rect_end = self._mouse_norm
        # The cursor only affects the picture while a shape is being drawn.
        if self.draw_mode:
            self.update()

    def mouseReleaseEvent(self, event):
        if event.button() != Qt.MouseButton.LeftButton:
            return
        if self.draw_mode == "rect" and self._rect_start and self._rect_end:
            x0, y0 = self._rect_start
            x1, y1 = self._rect_end
            # Ignore accidental clicks that produce an almost empty rectangle.
            if abs(x1 - x0) > 0.005 and abs(y1 - y0) > 0.005:
                self.regions.append({
                    "name": self._pending_name or f"Region {len(self.regions)+1}",
                    "type": "rect",
                    "vertices": [(min(x0, x1), min(y0, y1)), (max(x0, x1), max(y0, y1))],
                })
            self.draw_mode = None
            self._rect_start = None
            self._rect_end = None
            self.setCursor(Qt.CursorShape.ArrowCursor)
            self._notify_editor()
        self.update()

    def mouseDoubleClickEvent(self, event):
        if event.button() != Qt.MouseButton.LeftButton:
            return
        if self.draw_mode == "polygon" and len(self._poly_points) >= 3:
            self.regions.append({
                "name": self._pending_name or f"Region {len(self.regions)+1}",
                "type": "polygon",
                "vertices": list(self._poly_points),
            })
            self._poly_points = []
            self.draw_mode = None
            self.setCursor(Qt.CursorShape.ArrowCursor)
            self._notify_editor()
        self.update()

    def paintEvent(self, event):
        p = QPainter(self)
        p.setRenderHint(QPainter.RenderHint.Antialiasing)
        m = self.MARGIN
        draw_w = self.width() - 2 * m
        draw_h = self.height() - 2 * m

        # Background and plot frame
        p.fillRect(self.rect(), QColor(255, 255, 255))
        p.setPen(QPen(QColor(200, 200, 200), 1))
        p.drawRect(m, m, draw_w, draw_h)

        # Centroids
        p.setPen(Qt.PenStyle.NoPen)
        p.setBrush(QBrush(QColor(180, 180, 180, 120)))
        for cx, cy in self.centroids:
            px, py = self._norm_to_pixel(cx, cy)
            p.drawEllipse(QPointF(px, py), 3, 3)

        # Saved regions; colours cycle through the palette
        for ri, region in enumerate(self.regions):
            col = self.REGION_COLORS[ri % len(self.REGION_COLORS)]
            border = self.REGION_BORDER_COLORS[ri % len(self.REGION_BORDER_COLORS)]
            p.setBrush(QBrush(col))
            p.setPen(QPen(border, 2))
            verts = region["vertices"]
            if region["type"] == "rect" and len(verts) == 2:
                px0, py0 = self._norm_to_pixel(*verts[0])
                px1, py1 = self._norm_to_pixel(*verts[1])
                p.drawRect(QRectF(QPointF(px0, py0), QPointF(px1, py1)))
            else:
                poly = QPolygonF([QPointF(*self._norm_to_pixel(*v)) for v in verts])
                p.drawPolygon(poly)
            # Label at the vertex average
            if verts:
                cx_avg = sum(v[0] for v in verts) / len(verts)
                cy_avg = sum(v[1] for v in verts) / len(verts)
                lx, ly = self._norm_to_pixel(cx_avg, cy_avg)
                p.setPen(QPen(border.darker(130), 1))
                p.setFont(QFont("Arial", 9, QFont.Weight.Bold))
                p.drawText(lx - 30, ly - 5, 60, 20, Qt.AlignmentFlag.AlignCenter, region["name"])

        # In-progress polygon, with a rubber-band edge to the cursor
        if self.draw_mode == "polygon" and self._poly_points:
            p.setPen(QPen(QColor(255, 80, 80), 2, Qt.PenStyle.DashLine))
            p.setBrush(QBrush(QColor(255, 80, 80, 40)))
            vertices = [QPointF(*self._norm_to_pixel(*v)) for v in self._poly_points]
            pts = list(vertices)
            if self._mouse_norm:
                pts.append(QPointF(*self._norm_to_pixel(*self._mouse_norm)))
            if len(pts) >= 3:
                p.drawPolygon(QPolygonF(pts))
            elif len(pts) == 2:
                p.drawLine(pts[0], pts[1])
            # Mark every committed vertex (never the cursor position itself).
            p.setBrush(QBrush(QColor(255, 80, 80)))
            for pt in vertices:
                p.drawEllipse(pt, 4, 4)
            p.setBrush(QBrush(QColor(255, 80, 80, 40)))

        # In-progress rectangle
        if self.draw_mode == "rect" and self._rect_start and self._rect_end:
            p.setPen(QPen(QColor(80, 80, 255), 2, Qt.PenStyle.DashLine))
            p.setBrush(QBrush(QColor(80, 80, 255, 40)))
            px0, py0 = self._norm_to_pixel(*self._rect_start)
            px1, py1 = self._norm_to_pixel(*self._rect_end)
            p.drawRect(QRectF(QPointF(px0, py0), QPointF(px1, py1)))

        # Corner coordinate labels
        p.setPen(QColor(100, 100, 100))
        p.setFont(QFont("Arial", 8))
        p.drawText(m, m - 5, "0,0")
        p.drawText(m + draw_w - 20, m - 5, "1,0")
        p.drawText(m, m + draw_h + 14, "0,1")

        p.end()
2055
+
2056
+
2057
class SpatialRegionEditor(QDialog):
    """Dialog for creating, renaming and deleting named spatial regions.

    The dialog edits its own copy of ``existing_regions``: each region dict
    and its vertex list are copied, so Cancel leaves the caller's list
    untouched. The edited list lives in ``self.regions``, which is also the
    list the drawing canvas appends new regions to.
    """

    def __init__(self, centroids, existing_regions, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Spatial Region Editor")
        self.resize(850, 600)
        # Copy each region (and its vertex list) so Cancel discards edits.
        self.regions = []
        for region in existing_regions:
            self.regions.append({
                "name": region["name"],
                "type": region["type"],
                "vertices": list(region["vertices"]),
            })
        self.centroids = centroids
        self._setup_ui()

    def _setup_ui(self):
        root = QVBoxLayout()
        root.addLayout(self._build_toolbar())

        # Canvas on the left, region list on the right
        splitter = QSplitter(Qt.Orientation.Horizontal)
        self.canvas = SpatialCanvas(self.centroids, self.regions)
        splitter.addWidget(self.canvas)
        splitter.addWidget(self._build_region_panel())
        splitter.setSizes([600, 250])
        root.addWidget(splitter, stretch=1)

        footer = QHBoxLayout()
        for caption, slot in (("Done", self.accept), ("Cancel", self.reject)):
            button = QPushButton(caption)
            button.clicked.connect(slot)
            footer.addWidget(button)
        root.addLayout(footer)

        self.setLayout(root)
        self._refresh_region_list()

    def _build_toolbar(self):
        """Create the row of drawing tools (polygon, rectangle, cancel)."""
        self.polygon_btn = QToolButton()
        self.polygon_btn.setText("Draw Polygon")
        self.polygon_btn.setToolTip("Click vertices, double-click to finish")
        self.polygon_btn.clicked.connect(self._start_polygon)

        self.rect_btn = QToolButton()
        self.rect_btn.setText("Draw Rectangle")
        self.rect_btn.setToolTip("Click and drag to define rectangle")
        self.rect_btn.clicked.connect(self._start_rect)

        self.cancel_draw_btn = QPushButton("Cancel drawing")
        self.cancel_draw_btn.clicked.connect(self._cancel_draw)
        self.cancel_draw_btn.setEnabled(False)

        row = QHBoxLayout()
        for widget in (self.polygon_btn, self.rect_btn, self.cancel_draw_btn):
            row.addWidget(widget)
        row.addStretch()
        return row

    def _build_region_panel(self):
        """Create the side panel: region list plus Rename/Delete buttons."""
        self.region_list = QListWidget()
        self.region_list.setSelectionMode(QListWidget.SelectionMode.SingleSelection)

        self.rename_btn = QPushButton("Rename")
        self.rename_btn.clicked.connect(self._rename_region)
        self.delete_btn = QPushButton("Delete")
        self.delete_btn.clicked.connect(self._delete_region)

        buttons = QHBoxLayout()
        buttons.addWidget(self.rename_btn)
        buttons.addWidget(self.delete_btn)

        panel_layout = QVBoxLayout()
        panel_layout.addWidget(QLabel("Defined Regions"))
        panel_layout.addWidget(self.region_list)
        panel_layout.addLayout(buttons)

        panel = QWidget()
        panel.setLayout(panel_layout)
        return panel

    def _start_polygon(self):
        self._prompt_and_draw("polygon", "Name for the new polygon region:")

    def _start_rect(self):
        self._prompt_and_draw("rect", "Name for the new rectangle region:")

    def _prompt_and_draw(self, mode, prompt):
        """Ask for a region name and, if one is given, put the canvas in *mode*."""
        name, ok = QInputDialog.getText(self, "Region Name", prompt)
        name = name.strip()
        if ok and name:
            self.cancel_draw_btn.setEnabled(True)
            self.canvas.start_drawing(mode, name)

    def _cancel_draw(self):
        self.canvas.cancel_drawing()
        self.cancel_draw_btn.setEnabled(False)

    def _refresh_region_list(self):
        """Rebuild the list widget from self.regions and repaint the canvas.

        Also disables the cancel button; the canvas calls this method when a
        region has been finished.
        """
        self.cancel_draw_btn.setEnabled(False)
        self.region_list.clear()
        self.region_list.addItems([
            f"{region['name']} ({region['type']}, {len(region['vertices'])} pts)"
            for region in self.regions
        ])
        self.canvas.update()

    def _selected_region_index(self):
        """Return the selected row if it indexes self.regions, else None."""
        row = self.region_list.currentRow()
        return row if 0 <= row < len(self.regions) else None

    def _rename_region(self):
        idx = self._selected_region_index()
        if idx is None:
            return
        current = self.regions[idx]["name"]
        name, ok = QInputDialog.getText(self, "Rename Region", "New name:", text=current)
        if ok and name.strip():
            self.regions[idx]["name"] = name.strip()
            self._refresh_region_list()

    def _delete_region(self):
        idx = self._selected_region_index()
        if idx is not None:
            # Delete in place: the canvas holds a reference to the same list.
            del self.regions[idx]
            self._refresh_region_list()
2178
+
2179
+
2180
+ class GroupManagementDialog(QDialog):
2181
def __init__(self, video_paths, groups, video_groups, parent=None):
    """Dialog for creating groups and assigning videos to them.

    Args:
        video_paths: Iterable of video file paths; shown sorted.
        groups: Mapping of group name -> list of member video paths.
        video_groups: Mapping of video path -> group name.
        parent: Optional parent widget.

    Both mappings are copied, including each group's member list, so edits
    made in the dialog never leak into the caller's objects. The edited state
    is kept on ``self.groups`` and ``self.video_groups``.
    """
    super().__init__(parent)
    self.video_paths = sorted(video_paths)
    # groups.copy() alone is shallow: the member lists would still be shared
    # with the caller, and _assign_group's remove()/append() would mutate them
    # even when the dialog is cancelled. Copy every member list too.
    self.groups = groups.copy()
    for name, members in self.groups.items():
        self.groups[name] = list(members)
    self.video_groups = video_groups.copy()
    self.setWindowTitle("Manage groups")
    self.resize(800, 600)
    self._setup_ui()
2189
+
2190
def _setup_ui(self):
    """Build the dialog: new-group row, videos and groups tables side by side, OK/Cancel.

    Both tables are read-only. Their cells are display-only, and an
    in-place-edited group name would no longer match a key of
    ``self.groups`` when it is later selected for assignment. The group
    table allows a single selection because ``_assign_group`` uses only the
    first selected group.
    """
    layout = QVBoxLayout()

    # Top: group creation
    group_input_layout = QHBoxLayout()
    self.group_name_edit = QLineEdit()
    self.group_name_edit.setPlaceholderText("New Group Name")
    self.add_group_btn = QPushButton("Add group")
    self.add_group_btn.clicked.connect(self._add_group)
    group_input_layout.addWidget(self.group_name_edit)
    group_input_layout.addWidget(self.add_group_btn)
    layout.addLayout(group_input_layout)

    # Main: videos (left) and groups (right)
    splitter = QSplitter(Qt.Orientation.Horizontal)

    # Videos table, showing each video's current group
    video_widget = QWidget()
    video_layout = QVBoxLayout()
    video_layout.addWidget(QLabel("Videos"))
    self.video_table = QTableWidget()
    self.video_table.setColumnCount(2)
    self.video_table.setHorizontalHeaderLabels(["Video", "Assigned Group"])
    self.video_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
    self.video_table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
    self.video_table.setSelectionMode(QTableWidget.SelectionMode.ExtendedSelection)
    self.video_table.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
    video_layout.addWidget(self.video_table)
    video_widget.setLayout(video_layout)
    splitter.addWidget(video_widget)

    # Groups table: selecting a group assigns the selected videos to it
    group_widget = QWidget()
    group_layout = QVBoxLayout()
    group_layout.addWidget(QLabel("Groups (Select to assign)"))
    self.group_list = QTableWidget()
    self.group_list.setColumnCount(1)
    self.group_list.setHorizontalHeaderLabels(["Group Name"])
    self.group_list.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
    self.group_list.setSelectionMode(QTableWidget.SelectionMode.SingleSelection)
    self.group_list.setEditTriggers(QTableWidget.EditTrigger.NoEditTriggers)
    self.group_list.itemSelectionChanged.connect(self._assign_group)
    group_layout.addWidget(self.group_list)
    group_widget.setLayout(group_layout)
    splitter.addWidget(group_widget)

    layout.addWidget(splitter)

    btns = QHBoxLayout()
    ok_btn = QPushButton("OK")
    ok_btn.clicked.connect(self.accept)
    cancel_btn = QPushButton("Cancel")
    cancel_btn.clicked.connect(self.reject)
    btns.addWidget(ok_btn)
    btns.addWidget(cancel_btn)
    layout.addLayout(btns)

    self.setLayout(layout)
    self._refresh_lists()
2246
+
2247
def _refresh_lists(self):
    """Rebuild both tables from the dialog's current state.

    The video table shows each path's basename next to its assigned group,
    or "Unassigned" when it has none. The group table lists the group names
    alphabetically.
    """
    video_table = self.video_table
    video_table.setRowCount(len(self.video_paths))
    for row, path in enumerate(self.video_paths):
        assigned = self.video_groups.get(path, "Unassigned")
        video_table.setItem(row, 0, QTableWidgetItem(os.path.basename(path)))
        video_table.setItem(row, 1, QTableWidgetItem(assigned))

    group_names = sorted(self.groups)
    self.group_list.setRowCount(len(group_names))
    for row, group in enumerate(group_names):
        self.group_list.setItem(row, 0, QTableWidgetItem(group))
2260
+
2261
def _add_group(self):
    """Create an empty group named after the text field (blank or duplicate names are ignored)."""
    name = self.group_name_edit.text().strip()
    if not name or name in self.groups:
        return
    self.groups[name] = []
    self._refresh_lists()
    self.group_name_edit.clear()
2267
+
2268
def _assign_group(self):
    """Assign every selected video to the currently selected group.

    Connected to the group table's selection changes. Each selected video is
    removed from its previous group's member list, appended to the new
    group's list, and recorded in ``self.video_groups``. Afterwards the group
    selection is cleared. Clicking an already-selected row does not change
    the selection, so without this the same group could not be chosen again
    for another batch of videos.
    """
    selected_group_items = self.group_list.selectedItems()
    if not selected_group_items:
        return

    group_name = selected_group_items[0].text()
    members = self.groups.get(group_name)
    if members is None:
        # The cell text no longer names a known group (e.g. edited in place).
        return

    selected_rows = sorted({idx.row() for idx in self.video_table.selectedIndexes()})
    for row in selected_rows:
        # Table rows are filled in self.video_paths order by _refresh_lists.
        full_path = self.video_paths[row]

        # Remove from the old group
        old_group = self.video_groups.get(full_path)
        old_members = self.groups.get(old_group) if old_group else None
        if old_members is not None and full_path in old_members:
            old_members.remove(full_path)

        # Assign to the new group
        self.video_groups[full_path] = group_name
        if full_path not in members:
            members.append(full_path)

    self._refresh_lists()
    # clearSelection() re-emits itemSelectionChanged; the empty-selection
    # guard at the top turns that re-entry into a no-op.
    self.group_list.clearSelection()