singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,2291 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import numpy as np
|
|
5
|
+
import cv2
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
from PyQt6.QtWidgets import (
|
|
10
|
+
QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QComboBox,
|
|
11
|
+
QGroupBox, QTableWidget, QTableWidgetItem, QHeaderView, QDialog,
|
|
12
|
+
QFormLayout, QLineEdit, QMessageBox, QSplitter, QCheckBox, QFileDialog,
|
|
13
|
+
QSizePolicy, QScrollArea, QTabWidget, QListWidget, QListWidgetItem,
|
|
14
|
+
QTextEdit, QAbstractItemView, QSlider, QToolButton, QInputDialog,
|
|
15
|
+
QButtonGroup
|
|
16
|
+
)
|
|
17
|
+
from PyQt6.QtCore import Qt, QUrl, QPointF, QRectF
|
|
18
|
+
from PyQt6.QtGui import QPainter, QPen, QBrush, QColor, QFont, QPolygonF, QPainterPath
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
from PyQt6.QtWebEngineWidgets import QWebEngineView
|
|
22
|
+
HAS_WEBENGINE = True
|
|
23
|
+
except ImportError:
|
|
24
|
+
HAS_WEBENGINE = False
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
import plotly.graph_objects as go
|
|
28
|
+
import plotly.express as px
|
|
29
|
+
import pandas as pd
|
|
30
|
+
HAS_PLOTLY = True
|
|
31
|
+
except ImportError:
|
|
32
|
+
HAS_PLOTLY = False
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
import scipy.stats as stats
|
|
36
|
+
HAS_SCIPY = True
|
|
37
|
+
except ImportError:
|
|
38
|
+
HAS_SCIPY = False
|
|
39
|
+
|
|
40
|
+
class AnalysisWidget(QWidget):
|
|
41
|
+
"""Widget for downstream analysis of behavior data."""
|
|
42
|
+
|
|
43
|
+
    def __init__(self, config: dict):
        """Build the analysis widget.

        Args:
            config: Project configuration dict (experiment path, saved
                analysis groups, config file location, ...).
        """
        super().__init__()
        self.config = config
        self.results = {}  # raw inference results loaded from JSON
        self.groups = {}  # {group_name: [video_path, ...]}
        self.video_groups = {}  # {video_path: group_name}
        self.spatial_regions = []  # [{name, type, vertices}, ...]
        self.merged_data = []  # List of dicts; rows built by _process_data
        self.all_behaviors = []  # behaviors discovered in loaded results
        self.selected_behaviors = set()  # behavior filter (empty = show all)
        self.visible_groups = set()  # groups currently checked in the sidebar
        self._setup_ui()
|
56
|
+
    def _setup_ui(self):
        """Assemble the top-level layout: global data controls above a tab widget."""
        layout = QVBoxLayout()

        # Global Controls (Load, Manage Groups)
        global_controls = QGroupBox("Data controls")
        global_controls.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Maximum)
        global_layout = QHBoxLayout()

        self.load_btn = QPushButton("Load results")
        self.load_btn.clicked.connect(self._load_results)
        global_layout.addWidget(self.load_btn)

        # Disabled until results are loaded (see _load_results_from_file)
        self.manage_groups_btn = QPushButton("Manage groups")
        self.manage_groups_btn.clicked.connect(self._manage_groups)
        self.manage_groups_btn.setEnabled(False)
        global_layout.addWidget(self.manage_groups_btn)

        self.filter_behaviors_btn = QPushButton("Filter behaviors")
        self.filter_behaviors_btn.clicked.connect(self._filter_behaviors)
        self.filter_behaviors_btn.setEnabled(False)
        global_layout.addWidget(self.filter_behaviors_btn)

        global_layout.addStretch()
        global_controls.setLayout(global_layout)
        layout.addWidget(global_controls)

        # Tab Widget
        self.tabs = QTabWidget()

        # Overview Tab
        self.overview_tab = self._create_overview_tab()
        self.tabs.addTab(self.overview_tab, "Overview")

        # Behavior Transitions Tab
        self.transitions_tab = self._create_transitions_tab()
        self.tabs.addTab(self.transitions_tab, "Behavior Transitions")

        layout.addWidget(self.tabs)
        self.setLayout(layout)
|
96
|
+
    def _create_overview_tab(self):
        """Create the Overview tab with split layout.

        Left 70%: plot area (QWebEngineView when available, otherwise a
        placeholder label). Right 30%: a scrollable sidebar with plot
        settings, graph appearance, spatial distribution, group selection,
        statistics, and export controls.
        """
        tab = QWidget()
        main_layout = QHBoxLayout()

        # Left: Plot area (70%)
        if HAS_WEBENGINE:
            self.webview = QWebEngineView()
            self.webview.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
            main_layout.addWidget(self.webview, 70)
        else:
            lbl = QLabel("PyQt6.QtWebEngineWidgets not installed. Plots will open in default browser.")
            lbl.setAlignment(Qt.AlignmentFlag.AlignCenter)
            main_layout.addWidget(lbl, 70)

        # Right: Controls & Analysis (30%)
        sidebar_scroll = QScrollArea()
        sidebar_scroll.setWidgetResizable(True)
        sidebar_widget = QWidget()
        sidebar_layout = QVBoxLayout()
        sidebar_layout.setSpacing(15)

        # 1. Plot Settings
        plot_group = QGroupBox("Plot settings")
        plot_layout = QFormLayout()

        self.metric_combo = QComboBox()
        self.metric_combo.addItems(["Occurrences (Count)", "Average Bout Duration (s)", "Total Duration (s)", "Percent Time (%)"])
        self.metric_combo.currentIndexChanged.connect(self._update_plots)
        plot_layout.addRow("Metric:", self.metric_combo)

        self.plot_mode_combo = QComboBox()
        self.plot_mode_combo.addItems(["General Overview", "Group Comparison"])
        self.plot_mode_combo.currentIndexChanged.connect(self._update_plots)
        plot_layout.addRow("Analysis Mode:", self.plot_mode_combo)

        plot_group.setLayout(plot_layout)
        sidebar_layout.addWidget(plot_group)

        # 1b. Graph Appearance
        app_group = QGroupBox("Graph appearance")
        app_layout = QFormLayout()

        self.graph_type_combo = QComboBox()
        self.graph_type_combo.addItems(["Auto", "Bar Chart", "Box Plot", "Violin Plot", "Strip Plot", "Line Plot"])
        self.graph_type_combo.currentIndexChanged.connect(self._update_plots)
        app_layout.addRow("Graph Type:", self.graph_type_combo)

        # Values mirror plotly's built-in template names
        self.color_theme_combo = QComboBox()
        self.color_theme_combo.addItems(["plotly", "plotly_white", "plotly_dark", "ggplot2", "seaborn", "simple_white", "none"])
        self.color_theme_combo.setCurrentText("simple_white")
        self.color_theme_combo.currentIndexChanged.connect(self._update_plots)
        app_layout.addRow("Color Theme:", self.color_theme_combo)

        app_group.setLayout(app_layout)
        sidebar_layout.addWidget(app_group)

        # 1c. Spatial Distribution (uses localization data)
        spatial_group = QGroupBox("Spatial distribution")
        spatial_layout = QVBoxLayout()

        self.spatial_behavior_combo = QComboBox()
        self.spatial_behavior_combo.addItem("All behaviors")
        self.spatial_behavior_combo.currentIndexChanged.connect(self._update_spatial_plot)
        spatial_form = QFormLayout()
        spatial_form.addRow("Behavior:", self.spatial_behavior_combo)

        self.spatial_video_combo = QComboBox()
        self.spatial_video_combo.addItem("All")
        self.spatial_video_combo.currentIndexChanged.connect(self._update_spatial_plot)
        spatial_form.addRow("Video:", self.spatial_video_combo)

        self.spatial_dot_size_slider = QSlider(Qt.Orientation.Horizontal)
        self.spatial_dot_size_slider.setMinimum(1)
        self.spatial_dot_size_slider.setMaximum(25)
        self.spatial_dot_size_slider.setValue(7)
        self.spatial_dot_size_slider.valueChanged.connect(self._update_spatial_plot)
        spatial_form.addRow("Dot size:", self.spatial_dot_size_slider)

        spatial_layout.addLayout(spatial_form)

        # Disabled until results with localization data are loaded
        self.spatial_show_btn = QPushButton("Show spatial distribution")
        self.spatial_show_btn.clicked.connect(self._update_spatial_plot)
        self.spatial_show_btn.setEnabled(False)
        spatial_layout.addWidget(self.spatial_show_btn)

        self.spatial_save_btn = QPushButton("Save spatial plot (PDF/SVG)")
        self.spatial_save_btn.clicked.connect(self._save_spatial_plot)
        spatial_layout.addWidget(self.spatial_save_btn)

        self.spatial_info_label = QLabel("Load results with localization data to use.")
        self.spatial_info_label.setWordWrap(True)
        self.spatial_info_label.setStyleSheet("color: grey; font-size: 11px;")
        spatial_layout.addWidget(self.spatial_info_label)

        spatial_group.setToolTip(
            "One point per inference clip (step-frame grid). Not based on aggregated bout boundaries."
        )

        self.manage_regions_btn = QPushButton("Manage spatial regions")
        self.manage_regions_btn.setToolTip("Draw named regions on the spatial map, then filter analysis by region.")
        self.manage_regions_btn.clicked.connect(self._manage_spatial_regions)
        self.manage_regions_btn.setEnabled(False)
        spatial_layout.addWidget(self.manage_regions_btn)

        self.region_filter_combo = QComboBox()
        self.region_filter_combo.addItem("All Regions")
        self.region_filter_combo.currentIndexChanged.connect(self._on_region_filter_changed)
        spatial_form.addRow("Region:", self.region_filter_combo)

        spatial_group.setLayout(spatial_layout)
        sidebar_layout.addWidget(spatial_group)

        # 2. Group Selection
        group_group = QGroupBox("Groups")
        group_layout = QVBoxLayout()
        self.group_list_widget = QListWidget()
        self.group_list_widget.setSelectionMode(QAbstractItemView.SelectionMode.NoSelection)  # Checkboxes only
        self.group_list_widget.setFixedHeight(150)
        self.group_list_widget.itemChanged.connect(self._on_group_selection_changed)
        group_layout.addWidget(self.group_list_widget)
        group_group.setLayout(group_layout)
        sidebar_layout.addWidget(group_group)

        # 3. Statistics
        stats_group = QGroupBox("Statistics (Group comp.)")
        stats_layout = QVBoxLayout()

        self.stats_test_combo = QComboBox()
        self.stats_test_combo.addItems(["T-Test / Mann-Whitney (2 groups)", "ANOVA / Kruskal-Wallis (>2 groups)"])
        stats_layout.addWidget(QLabel("Test type:"))
        stats_layout.addWidget(self.stats_test_combo)

        self.run_stats_btn = QPushButton("Run statistics")
        self.run_stats_btn.clicked.connect(self._run_statistics)
        stats_layout.addWidget(self.run_stats_btn)

        self.stats_output = QTextEdit()
        self.stats_output.setReadOnly(True)
        self.stats_output.setPlaceholderText("Results will appear here...")
        self.stats_output.setMaximumHeight(150)
        stats_layout.addWidget(self.stats_output)

        stats_group.setLayout(stats_layout)
        sidebar_layout.addWidget(stats_group)

        # 4. Export
        export_group = QGroupBox("Export")
        export_layout = QVBoxLayout()

        self.save_graph_btn = QPushButton("Save graph")
        self.save_graph_btn.clicked.connect(self._save_graph)
        export_layout.addWidget(self.save_graph_btn)

        self.save_csv_btn = QPushButton("Save data (.csv)")
        self.save_csv_btn.clicked.connect(self._save_csv)
        export_layout.addWidget(self.save_csv_btn)

        export_group.setLayout(export_layout)
        sidebar_layout.addWidget(export_group)

        sidebar_layout.addStretch()
        sidebar_widget.setLayout(sidebar_layout)
        sidebar_scroll.setWidget(sidebar_widget)

        main_layout.addWidget(sidebar_scroll, 30)

        tab.setLayout(main_layout)
        return tab
|
266
|
+
def _update_sidebar_groups(self):
|
|
267
|
+
"""Update group list in sidebar based on current data."""
|
|
268
|
+
self.group_list_widget.blockSignals(True)
|
|
269
|
+
self.group_list_widget.clear()
|
|
270
|
+
|
|
271
|
+
# Get all unique groups from data or metadata
|
|
272
|
+
groups = sorted(list(self.groups.keys()))
|
|
273
|
+
if not groups and self.merged_data:
|
|
274
|
+
# Fallback if groups dict not fully populated but data exists
|
|
275
|
+
df = pd.DataFrame(self.merged_data)
|
|
276
|
+
if "Group" in df.columns:
|
|
277
|
+
groups = sorted(df["Group"].dropna().unique().tolist())
|
|
278
|
+
|
|
279
|
+
if not groups:
|
|
280
|
+
groups = ["Unassigned"]
|
|
281
|
+
|
|
282
|
+
for grp in groups:
|
|
283
|
+
item = QListWidgetItem(grp)
|
|
284
|
+
item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable)
|
|
285
|
+
item.setCheckState(Qt.CheckState.Checked)
|
|
286
|
+
self.group_list_widget.addItem(item)
|
|
287
|
+
|
|
288
|
+
self.visible_groups = set(groups)
|
|
289
|
+
self.group_list_widget.blockSignals(False)
|
|
290
|
+
|
|
291
|
+
def _on_group_selection_changed(self, item):
|
|
292
|
+
"""Handle group checkbox toggles."""
|
|
293
|
+
self.visible_groups = set()
|
|
294
|
+
for i in range(self.group_list_widget.count()):
|
|
295
|
+
it = self.group_list_widget.item(i)
|
|
296
|
+
if it.checkState() == Qt.CheckState.Checked:
|
|
297
|
+
self.visible_groups.add(it.text())
|
|
298
|
+
self._update_plots()
|
|
299
|
+
|
|
300
|
+
def _run_statistics(self):
|
|
301
|
+
if not HAS_SCIPY:
|
|
302
|
+
QMessageBox.warning(self, "Error", "Scipy is not installed. Cannot run statistics.")
|
|
303
|
+
return
|
|
304
|
+
|
|
305
|
+
if not self.merged_data:
|
|
306
|
+
return
|
|
307
|
+
|
|
308
|
+
df = pd.DataFrame(self.merged_data)
|
|
309
|
+
|
|
310
|
+
# Apply filters
|
|
311
|
+
if self.selected_behaviors:
|
|
312
|
+
df = df[df["Behavior"].isin(self.selected_behaviors)]
|
|
313
|
+
if self.visible_groups:
|
|
314
|
+
df = df[df["Group"].isin(self.visible_groups)]
|
|
315
|
+
sel_region = self.region_filter_combo.currentText()
|
|
316
|
+
if sel_region != "All Regions" and "Region" in df.columns:
|
|
317
|
+
df = df[df["Region"] == sel_region]
|
|
318
|
+
|
|
319
|
+
if df.empty:
|
|
320
|
+
self.stats_output.setText("No data available for selected filters.")
|
|
321
|
+
return
|
|
322
|
+
|
|
323
|
+
metric = self.metric_combo.currentText()
|
|
324
|
+
test_type = self.stats_test_combo.currentText()
|
|
325
|
+
|
|
326
|
+
# Prepare data for stats
|
|
327
|
+
if "Occurrences" in metric:
|
|
328
|
+
agg = df.groupby(["Video", "Group", "Behavior"]).size().reset_index(name="Value")
|
|
329
|
+
elif "Average" in metric:
|
|
330
|
+
agg = df.groupby(["Video", "Group", "Behavior"])["Duration"].mean().reset_index(name="Value")
|
|
331
|
+
elif "Total" in metric:
|
|
332
|
+
agg = df.groupby(["Video", "Group", "Behavior"])["Duration"].sum().reset_index(name="Value")
|
|
333
|
+
elif "Percent" in metric:
|
|
334
|
+
total_times = df.groupby("Video")["Duration"].sum().to_dict()
|
|
335
|
+
agg = df.groupby(["Video", "Group", "Behavior"])["Duration"].sum().reset_index(name="Value")
|
|
336
|
+
agg["Value"] = agg.apply(lambda x: (x["Value"] / total_times.get(x["Video"], 1)) * 100, axis=1)
|
|
337
|
+
|
|
338
|
+
output = []
|
|
339
|
+
output.append(f"Metric: {metric}")
|
|
340
|
+
output.append(f"Test: {test_type}\n")
|
|
341
|
+
|
|
342
|
+
unique_behaviors = sorted(agg["Behavior"].unique())
|
|
343
|
+
unique_groups = sorted(agg["Group"].unique())
|
|
344
|
+
|
|
345
|
+
if len(unique_groups) < 2:
|
|
346
|
+
self.stats_output.setText("Need at least 2 groups for comparison.")
|
|
347
|
+
return
|
|
348
|
+
|
|
349
|
+
for beh in unique_behaviors:
|
|
350
|
+
beh_data = agg[agg["Behavior"] == beh]
|
|
351
|
+
groups_data = [beh_data[beh_data["Group"] == g]["Value"].values for g in unique_groups]
|
|
352
|
+
|
|
353
|
+
# Filter valid data
|
|
354
|
+
valid_groups_data = []
|
|
355
|
+
valid_group_names = []
|
|
356
|
+
for g, d in zip(unique_groups, groups_data):
|
|
357
|
+
if len(d) > 0:
|
|
358
|
+
valid_groups_data.append(d)
|
|
359
|
+
valid_group_names.append(g)
|
|
360
|
+
|
|
361
|
+
if len(valid_groups_data) < 2:
|
|
362
|
+
output.append(f"{beh}: Not enough data")
|
|
363
|
+
continue
|
|
364
|
+
|
|
365
|
+
try:
|
|
366
|
+
if "2 groups" in test_type:
|
|
367
|
+
# Pairwise Mann-Whitney
|
|
368
|
+
import itertools
|
|
369
|
+
for g1, g2 in itertools.combinations(valid_group_names, 2):
|
|
370
|
+
d1 = valid_groups_data[valid_group_names.index(g1)]
|
|
371
|
+
d2 = valid_groups_data[valid_group_names.index(g2)]
|
|
372
|
+
stat, p = stats.mannwhitneyu(d1, d2)
|
|
373
|
+
output.append(f"{beh} ({g1} vs {g2}): p={p:.4f}")
|
|
374
|
+
|
|
375
|
+
else:
|
|
376
|
+
# Kruskal-Wallis
|
|
377
|
+
stat, p = stats.kruskal(*valid_groups_data)
|
|
378
|
+
output.append(f"{beh} (Kruskal-Wallis): p={p:.4f}")
|
|
379
|
+
|
|
380
|
+
except Exception as e:
|
|
381
|
+
output.append(f"{beh}: Error ({str(e)})")
|
|
382
|
+
|
|
383
|
+
self.stats_output.setText("\n".join(output))
|
|
384
|
+
|
|
385
|
+
def _save_graph(self):
|
|
386
|
+
if not hasattr(self, 'last_fig') or self.last_fig is None:
|
|
387
|
+
QMessageBox.warning(self, "Error", "No plot to save.")
|
|
388
|
+
return
|
|
389
|
+
|
|
390
|
+
path, filter_ = QFileDialog.getSaveFileName(
|
|
391
|
+
self, "Save Graph",
|
|
392
|
+
self.config.get("experiment_path", ""),
|
|
393
|
+
"PDF Files (*.pdf);;SVG Files (*.svg);;PNG Files (*.png);;HTML Files (*.html)"
|
|
394
|
+
)
|
|
395
|
+
|
|
396
|
+
if not path:
|
|
397
|
+
return
|
|
398
|
+
|
|
399
|
+
try:
|
|
400
|
+
if path.lower().endswith(".html"):
|
|
401
|
+
self.last_fig.write_html(path)
|
|
402
|
+
else:
|
|
403
|
+
self.last_fig.write_image(path)
|
|
404
|
+
QMessageBox.information(self, "Success", f"Graph saved to {path}")
|
|
405
|
+
except Exception as e:
|
|
406
|
+
if "kaleido" in str(e).lower() or "executable" in str(e).lower():
|
|
407
|
+
QMessageBox.warning(self, "Error", "Saving as static image requires 'kaleido'.\nPlease install it (pip install kaleido) or save as HTML.")
|
|
408
|
+
else:
|
|
409
|
+
QMessageBox.critical(self, "Error", f"Failed to save graph: {e}")
|
|
410
|
+
|
|
411
|
+
def _save_csv(self):
|
|
412
|
+
if not self.merged_data:
|
|
413
|
+
return
|
|
414
|
+
|
|
415
|
+
path, _ = QFileDialog.getSaveFileName(
|
|
416
|
+
self, "Save Data",
|
|
417
|
+
self.config.get("experiment_path", ""),
|
|
418
|
+
"CSV Files (*.csv)"
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
if not path:
|
|
422
|
+
return
|
|
423
|
+
|
|
424
|
+
df = pd.DataFrame(self.merged_data)
|
|
425
|
+
if self.selected_behaviors:
|
|
426
|
+
df = df[df["Behavior"].isin(self.selected_behaviors)]
|
|
427
|
+
if self.visible_groups:
|
|
428
|
+
df = df[df["Group"].isin(self.visible_groups)]
|
|
429
|
+
sel_region = self.region_filter_combo.currentText()
|
|
430
|
+
if sel_region != "All Regions" and "Region" in df.columns:
|
|
431
|
+
df = df[df["Region"] == sel_region]
|
|
432
|
+
|
|
433
|
+
df.to_csv(path, index=False)
|
|
434
|
+
QMessageBox.information(self, "Success", f"Data saved to {path}")
|
|
435
|
+
|
|
436
|
+
    def _create_transitions_tab(self):
        """Create the Behavior Transitions tab (Markov chain analysis).

        Top: controls for analysis type (per-video vs group), target
        selection, graph layout, and significance filtering. Middle: the
        transition matrix table. Bottom: the transition graph web view
        (when QtWebEngine is available).
        """
        tab = QWidget()
        layout = QVBoxLayout()

        # Controls
        controls_group = QGroupBox("Transition analysis controls")
        controls_group.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Maximum)
        controls_layout = QHBoxLayout()

        controls_layout.addWidget(QLabel("Analysis type:"))
        self.transition_type_combo = QComboBox()
        self.transition_type_combo.addItems(["Individual Video", "Group Comparison"])
        self.transition_type_combo.currentIndexChanged.connect(self._on_transition_type_changed)
        controls_layout.addWidget(self.transition_type_combo)

        # Populated elsewhere depending on the analysis type (video or group names)
        controls_layout.addWidget(QLabel("Select:"))
        self.transition_select_combo = QComboBox()
        controls_layout.addWidget(self.transition_select_combo)

        # Controls for Layout and Filtering
        self.layout_combo = QComboBox()
        self.layout_combo.addItems(["Circular Layout", "Network Layout"])
        controls_layout.addWidget(QLabel("Layout:"))
        controls_layout.addWidget(self.layout_combo)

        self.sig_filter_check = QCheckBox("Significant only")
        self.sig_filter_check.setToolTip("Only show transitions that occur significantly more than expected by chance (Z-score > 1.96)")
        controls_layout.addWidget(self.sig_filter_check)

        # Disabled until results are loaded
        self.compute_transition_btn = QPushButton("Compute transitions")
        self.compute_transition_btn.clicked.connect(self._compute_transitions)
        self.compute_transition_btn.setEnabled(False)
        controls_layout.addWidget(self.compute_transition_btn)

        controls_layout.addStretch()
        controls_group.setLayout(controls_layout)
        layout.addWidget(controls_group)

        # Transition Matrix Display
        matrix_group = QGroupBox("Transition matrix")
        matrix_layout = QVBoxLayout()
        self.transition_matrix_table = QTableWidget()
        self.transition_matrix_table.setMaximumHeight(250)
        matrix_layout.addWidget(self.transition_matrix_table)
        matrix_group.setLayout(matrix_layout)
        layout.addWidget(matrix_group)

        # Transition Graph
        if HAS_WEBENGINE:
            self.transition_webview = QWebEngineView()
            self.transition_webview.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
            layout.addWidget(self.transition_webview, 1)
        else:
            layout.addWidget(QLabel("PyQt6.QtWebEngineWidgets not installed."))

        tab.setLayout(layout)
        return tab
|
495
|
+
    def update_config(self, config: dict):
        """Replace the active config, restore saved groups, and auto-load results.

        Called when the project/experiment configuration changes. Group
        restoration runs first so that groups loaded from the config take
        precedence over any metadata loaded during auto-load.
        """
        self.config = config
        self._load_groups_from_config()
        # Try to auto-load results from the experiment's default location
        self._auto_load_results()
|
501
|
+
def _load_groups_from_config(self):
|
|
502
|
+
groups = self.config.get("analysis_groups", {})
|
|
503
|
+
video_groups = self.config.get("analysis_video_groups", {})
|
|
504
|
+
if isinstance(groups, dict):
|
|
505
|
+
self.groups = {
|
|
506
|
+
str(group_name): [str(path) for path in (paths or [])]
|
|
507
|
+
for group_name, paths in groups.items()
|
|
508
|
+
}
|
|
509
|
+
if isinstance(video_groups, dict):
|
|
510
|
+
self.video_groups = {
|
|
511
|
+
str(video_path): str(group_name)
|
|
512
|
+
for video_path, group_name in video_groups.items()
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
def _save_groups_to_config(self):
|
|
516
|
+
self.config["analysis_groups"] = {
|
|
517
|
+
str(group_name): [str(path) for path in sorted(set(paths or []))]
|
|
518
|
+
for group_name, paths in self.groups.items()
|
|
519
|
+
}
|
|
520
|
+
self.config["analysis_video_groups"] = {
|
|
521
|
+
str(video_path): str(group_name)
|
|
522
|
+
for video_path, group_name in self.video_groups.items()
|
|
523
|
+
}
|
|
524
|
+
config_path = self.config.get("config_path")
|
|
525
|
+
if not config_path:
|
|
526
|
+
return
|
|
527
|
+
try:
|
|
528
|
+
os.makedirs(os.path.dirname(config_path), exist_ok=True)
|
|
529
|
+
with open(config_path, "w", encoding="utf-8") as f:
|
|
530
|
+
yaml.safe_dump(dict(self.config), f, sort_keys=False)
|
|
531
|
+
except Exception as e:
|
|
532
|
+
logger.error("Error saving group config: %s", e)
|
|
533
|
+
|
|
534
|
+
def _auto_load_results(self):
|
|
535
|
+
exp_path = self.config.get("experiment_path")
|
|
536
|
+
if exp_path:
|
|
537
|
+
results_path = os.path.join(exp_path, "results", "inference_results.json")
|
|
538
|
+
if os.path.exists(results_path):
|
|
539
|
+
self._load_results_from_file(results_path)
|
|
540
|
+
|
|
541
|
+
def _load_results(self):
|
|
542
|
+
path, _ = QFileDialog.getOpenFileName(
|
|
543
|
+
self, "Load Inference Results",
|
|
544
|
+
self.config.get("experiment_path", ""),
|
|
545
|
+
"JSON Files (*.json)"
|
|
546
|
+
)
|
|
547
|
+
if path:
|
|
548
|
+
self._load_results_from_file(path)
|
|
549
|
+
|
|
550
|
+
    def _load_results_from_file(self, path):
        """Load inference results (and sibling analysis metadata) from *path*.

        Populates ``self.results``; if an ``analysis_metadata.json`` sits
        next to the results file, also restores saved groups, video-group
        assignments, and spatial regions — without overwriting groups that
        were already loaded from the project config. Finally rebuilds the
        merged data table and refreshes the plots. Any failure is reported
        via a message box instead of raising.
        """
        try:
            with open(path, 'r') as f:
                self.results = json.load(f)

            self.manage_groups_btn.setEnabled(True)

            # Load metadata if it exists (saved next to the results file)
            meta_path = path.replace("inference_results.json", "analysis_metadata.json")
            if os.path.exists(meta_path):
                with open(meta_path, 'r') as f:
                    meta = json.load(f)
                # Config-loaded groups take precedence over metadata
                if not self.groups:
                    self.groups = meta.get("groups", {})
                if not self.video_groups:
                    self.video_groups = meta.get("video_groups", {})
                raw_regions = meta.get("spatial_regions", [])
                self.spatial_regions = []
                for r in raw_regions:
                    self.spatial_regions.append({
                        "name": r["name"],
                        "type": r["type"],
                        # JSON yields lists; region hit-testing expects tuples
                        "vertices": [tuple(v) for v in r["vertices"]],
                    })

            self._update_region_filter_combo()
            self.manage_regions_btn.setEnabled(True)
            self._process_data()
            self._update_plots()

        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to load results: {e}")
|
583
|
+
@staticmethod
|
|
584
|
+
def _point_in_region(cx, cy, region):
|
|
585
|
+
verts = region["vertices"]
|
|
586
|
+
if region["type"] == "rect":
|
|
587
|
+
return verts[0][0] <= cx <= verts[1][0] and verts[0][1] <= cy <= verts[1][1]
|
|
588
|
+
from matplotlib.path import Path as MplPath
|
|
589
|
+
return MplPath(verts).contains_point((cx, cy))
|
|
590
|
+
|
|
591
|
+
def _region_for_point(self, cx, cy):
|
|
592
|
+
"""Return region name for a centroid, or 'Outside'."""
|
|
593
|
+
for r in self.spatial_regions:
|
|
594
|
+
if self._point_in_region(cx, cy, r):
|
|
595
|
+
return r["name"]
|
|
596
|
+
return "Outside"
|
|
597
|
+
|
|
598
|
+
def _build_clip_centroids(self, v_data):
|
|
599
|
+
"""Return list of (cx, cy) per clip index, or None for clips without bbox."""
|
|
600
|
+
loc_bboxes = v_data.get("localization_bboxes", [])
|
|
601
|
+
centroids = []
|
|
602
|
+
for raw in loc_bboxes:
|
|
603
|
+
cx, cy = self._bbox_centroid(raw)
|
|
604
|
+
centroids.append((cx, cy) if cx is not None else None)
|
|
605
|
+
return centroids
|
|
606
|
+
|
|
607
|
+
    def _process_data(self):
        """Merge timelines and prepare data structures.

        Rebuilds ``self.merged_data`` — one dict per behavior bout with keys
        Video, VideoPath, Group, Behavior, Duration (seconds) and Region —
        from ``self.results``, then refreshes behavior filter state and
        dependent UI. Frame-level aggregated segments are preferred when
        present; otherwise bouts are reconstructed from per-clip predictions.
        """
        self.merged_data = []
        has_regions = bool(self.spatial_regions)

        # Results may be wrapped ({"results": ..., "classes": ..., "parameters": ...})
        # or be a bare per-video dict from an older save format.
        data_container = self.results
        if "results" in data_container:
            results_dict = data_container["results"]
            classes = data_container.get("classes", [])
            params = data_container.get("parameters", {})
            target_fps = params.get("target_fps", 30)
            clip_length = params.get("clip_length", 16)
            step_frames = params.get("step_frames", 16)
        else:
            results_dict = data_container
            classes = []
            params = {}
            target_fps = 30
            clip_length = 16
            step_frames = 16

        frame_agg_enabled = params.get("frame_aggregation_enabled", False)
        use_ovr_param = params.get("use_ovr", None)

        for video_path, v_data in results_dict.items():
            agg_segments = v_data.get("aggregated_segments", [])
            agg_multiclass = v_data.get("aggregated_multiclass_segments", [])
            # Fall back to "multiclass segments present" when the parameter was not saved.
            use_ovr = bool(use_ovr_param) if use_ovr_param is not None else bool(agg_multiclass)
            clip_starts = v_data.get("clip_starts", [])
            total_frames = int(v_data.get("total_frames", 0) or 0)

            # Determine orig_fps: prefer saved per-video metadata.
            orig_fps = v_data.get("orig_fps", 0)
            if orig_fps <= 0:
                # Second choice: probe the video file itself, if still on disk.
                if os.path.exists(video_path):
                    try:
                        cap = cv2.VideoCapture(video_path)
                        orig_fps = cap.get(cv2.CAP_PROP_FPS)
                        cap.release()
                    except Exception:
                        orig_fps = 0
            # Third choice: infer fps from clip-start spacing and the inference params.
            if orig_fps <= 0 and len(clip_starts) >= 2 and step_frames > 0:
                actual_step = clip_starts[1] - clip_starts[0]
                frame_interval = actual_step / step_frames
                orig_fps = frame_interval * target_fps
            # Last resort default.
            if orig_fps <= 0:
                orig_fps = 30.0

            # Prefer stored inference-time interval when available.
            frame_interval = v_data.get("frame_interval", 0)
            try:
                frame_interval = int(frame_interval)
            except Exception:
                frame_interval = 0
            if frame_interval <= 0:
                if len(clip_starts) >= 2 and step_frames > 0:
                    inferred = int(round((clip_starts[1] - clip_starts[0]) / step_frames))
                    frame_interval = max(1, inferred)
                else:
                    frame_interval = max(1, int(round(orig_fps / max(1e-6, float(target_fps)))))

            # Build per-clip centroid lookup for spatial region assignment
            clip_centroids = self._build_clip_centroids(v_data) if has_regions else []

            def _region_for_frame_range(start_f, end_f):
                """Find region for a segment by checking the clip whose start is nearest the midpoint."""
                if not has_regions or not clip_centroids or not clip_starts:
                    return "All"
                mid = (start_f + end_f) / 2.0
                best_ci = 0
                best_dist = float("inf")
                for ci, cs in enumerate(clip_starts):
                    d = abs(cs - mid)
                    if d < best_dist:
                        best_dist = d
                        best_ci = ci
                if best_ci < len(clip_centroids) and clip_centroids[best_ci] is not None:
                    return self._region_for_point(*clip_centroids[best_ci])
                return "Outside"

            def _region_for_clip_range(bout_start, bout_end):
                """Find region for a clip-based bout by majority vote of clip centroids."""
                if not has_regions or not clip_centroids:
                    return "All"
                region_votes = {}
                for ci in range(bout_start, min(bout_end + 1, len(clip_centroids))):
                    c = clip_centroids[ci]
                    if c is not None:
                        rn = self._region_for_point(*c)
                        region_votes[rn] = region_votes.get(rn, 0) + 1
                if not region_votes:
                    return "Outside"
                return max(region_votes, key=region_votes.get)

            # Use frame-level aggregated segments when available (precise boundaries).
            # For OvR, prefer per-class multiclass segments so downstream analysis
            # sees overlapping labels as independent behavior bouts.
            #
            # IMPORTANT: saved aggregated_segments have per-class ignore thresholds
            # applied (frames below threshold are labelled class=-1 / "Filtered").
            # When the threshold for a class is set higher than the model's actual
            # confidence range for that class, the entire class disappears from the
            # saved segments even though the model predicted it as dominant.
            # To avoid this, we rebuild segments from the raw frame probabilities
            # (argmax, no threshold) when they are available so that analysis always
            # reflects true model predictions.
            agg_probs = v_data.get("aggregated_frame_probs")

            def _relabel_filtered_segs(segs, probs_list):
                """Relabel class=-1 (Filtered) segments using the argmax of the
                underlying frame probabilities over that segment's frame range.

                The saved aggregated_segments already encode the correct temporal
                structure (temporal smoothing, merge-gap, min-segment all applied).
                The only problem is that the per-class ignore threshold may have
                labelled some segments as Filtered (-1) even though the model
                predicted a real class with high confidence. This function restores
                the true model prediction for those segments without touching the
                segment boundaries.
                """
                if not probs_list or not segs:
                    return segs
                n_probs = len(probs_list)
                result = []
                for seg in segs:
                    if seg.get("class", -1) >= 0:
                        result.append(seg)
                        continue
                    # Filtered segment: determine label from mean probs over its range
                    s = max(0, int(seg["start"]))
                    e = min(n_probs - 1, int(seg["end"]))
                    if s > e:
                        result.append(seg)
                        continue
                    # Sum probabilities over the frame range and pick argmax
                    n_cls = len(probs_list[s])
                    totals = [0.0] * n_cls
                    for fi in range(s, e + 1):
                        for ci, p in enumerate(probs_list[fi]):
                            totals[ci] += p
                    best = int(max(range(n_cls), key=lambda ci: totals[ci]))
                    new_seg = dict(seg)
                    new_seg["class"] = best
                    result.append(new_seg)
                return result

            if use_ovr and agg_multiclass:
                seg_source = agg_multiclass
            elif frame_agg_enabled and agg_probs:
                # Use saved segments (correct temporal structure: temporal smoothing,
                # merge-gap, min-segment all already applied) but relabel any
                # Filtered (-1) segments with the true argmax from raw probabilities.
                seg_source = _relabel_filtered_segs(agg_segments, agg_probs)
            else:
                seg_source = agg_segments

            if frame_agg_enabled and seg_source:
                # Frame-accurate path: one merged_data row per saved segment.
                covered_frames = 0
                for seg in seg_source:
                    pred_idx = seg["class"]
                    start_frame = seg["start"]
                    end_frame = seg["end"]
                    covered_frames += max(0, (end_frame - start_frame + 1))

                    if pred_idx < 0:
                        label_name = "Filtered"
                    elif classes and pred_idx < len(classes):
                        label_name = classes[pred_idx]
                    else:
                        label_name = f"Class {pred_idx}"

                    duration_sec = (end_frame - start_frame + 1) / orig_fps

                    self.merged_data.append({
                        "Video": os.path.basename(video_path),
                        "VideoPath": video_path,
                        "Group": self.video_groups.get(video_path, "Unassigned"),
                        "Behavior": label_name,
                        "Duration": duration_sec,
                        "Region": _region_for_frame_range(start_frame, end_frame),
                    })

                # Keep totals comparable across videos by accounting for uncovered tail.
                if total_frames > 0 and covered_frames < total_frames:
                    self.merged_data.append({
                        "Video": os.path.basename(video_path),
                        "VideoPath": video_path,
                        "Group": self.video_groups.get(video_path, "Unassigned"),
                        "Behavior": "Uncovered",
                        "Duration": (total_frames - covered_frames) / orig_fps,
                        "Region": "All",
                    })
                continue

            # Fallback: clip-based bout detection
            preds = v_data.get("predictions", [])
            corrections = v_data.get("corrected_labels", {})
            confs = v_data.get("confidences", [])

            if not preds:
                continue

            # Reconstruct ignore threshold from saved parameters
            apply_ignore = params.get("use_ignore_threshold", False)
            ignore_thr = float(params.get("ignore_threshold", 0.5))

            # Apply corrections and threshold filtering
            # (corrections may be keyed by str or int clip index, depending on
            # whether they round-tripped through JSON)
            final_preds = []
            for i, p in enumerate(preds):
                if str(i) in corrections:
                    pred = corrections[str(i)]
                elif i in corrections:
                    pred = corrections[i]
                else:
                    pred = p
                if apply_ignore and i < len(confs) and float(confs[i]) < ignore_thr:
                    pred = -1
                final_preds.append(pred)

            if not final_preds:
                continue

            # Derive bouts by clip index, then convert to seconds in original timeline.
            bout_start_idx = 0
            current_label = final_preds[0]
            covered_frames = 0

            # Iterate one past the end so the final bout is flushed.
            for i in range(1, len(final_preds) + 1):
                boundary = (i == len(final_preds)) or (final_preds[i] != current_label)
                if not boundary:
                    continue

                bout_end_idx = i - 1
                if clip_starts and len(clip_starts) == len(final_preds):
                    start_frame = int(clip_starts[bout_start_idx])
                    if i < len(clip_starts):
                        end_frame_exclusive = int(clip_starts[i])
                    else:
                        # Last bout: extend by one clip's worth of original frames.
                        end_frame_exclusive = start_frame + (clip_length * frame_interval)
                        if total_frames > 0:
                            end_frame_exclusive = min(end_frame_exclusive, total_frames)
                    duration_frames = max(0, end_frame_exclusive - start_frame)
                    covered_frames += duration_frames
                    duration_sec = duration_frames / orig_fps
                else:
                    # Legacy fallback when clip starts are missing.
                    clip_count = bout_end_idx - bout_start_idx + 1
                    duration_subsampled = ((clip_count - 1) * step_frames + clip_length)
                    duration_sec = duration_subsampled / max(1e-6, float(target_fps))

                if current_label < 0:
                    label_name = "Filtered"
                elif classes and current_label < len(classes):
                    label_name = classes[current_label]
                else:
                    label_name = f"Class {current_label}"

                self.merged_data.append({
                    "Video": os.path.basename(video_path),
                    "VideoPath": video_path,
                    "Group": self.video_groups.get(video_path, "Unassigned"),
                    "Behavior": label_name,
                    "Duration": duration_sec,
                    "Region": _region_for_clip_range(bout_start_idx, bout_end_idx),
                })

                if i < len(final_preds):
                    current_label = final_preds[i]
                    bout_start_idx = i

            if total_frames > 0 and covered_frames < total_frames:
                self.merged_data.append({
                    "Video": os.path.basename(video_path),
                    "VideoPath": video_path,
                    "Group": self.video_groups.get(video_path, "Unassigned"),
                    "Behavior": "Uncovered",
                    "Duration": (total_frames - covered_frames) / orig_fps,
                    "Region": "All",
                })

        # Update available behaviors and selection (include all model classes so behaviors
        # with zero or few segments still appear in filters and plots)
        if self.merged_data:
            df = pd.DataFrame(self.merged_data)
            from_data = set(df["Behavior"].unique().tolist())
            classes = []
            if isinstance(self.results, dict):
                classes = self.results.get("classes", [])
            self.all_behaviors = sorted(from_data | set(classes))

            # Initialize selection if empty (first load)
            if not self.selected_behaviors:
                self.selected_behaviors = set(self.all_behaviors)
            else:
                # Clean up stale behaviors
                self.selected_behaviors = self.selected_behaviors.intersection(set(self.all_behaviors))
                # If nothing selected (e.g. all prev behaviors gone), select all new ones
                if not self.selected_behaviors:
                    self.selected_behaviors = set(self.all_behaviors)

            self.filter_behaviors_btn.setEnabled(True)

            # Update transition tab
            self._on_transition_type_changed()
            self._update_sidebar_groups()

            # Populate spatial distribution combos
            self._populate_spatial_combos()
        else:
            self.filter_behaviors_btn.setEnabled(False)
|
|
917
|
+
|
|
918
|
+
# ------------------------------------------------------------------
|
|
919
|
+
# Spatial distribution (localization-based)
|
|
920
|
+
# ------------------------------------------------------------------
|
|
921
|
+
|
|
922
|
+
def _populate_spatial_combos(self):
|
|
923
|
+
"""Populate behavior and video combos for spatial distribution."""
|
|
924
|
+
data_container = self.results
|
|
925
|
+
results_dict = data_container.get("results", data_container)
|
|
926
|
+
classes = data_container.get("classes", [])
|
|
927
|
+
|
|
928
|
+
# Check if any video has localization data
|
|
929
|
+
has_loc = any(
|
|
930
|
+
"localization_bboxes" in v_data
|
|
931
|
+
for v_data in (results_dict.values() if isinstance(results_dict, dict) else [])
|
|
932
|
+
)
|
|
933
|
+
|
|
934
|
+
self.spatial_behavior_combo.blockSignals(True)
|
|
935
|
+
self.spatial_behavior_combo.clear()
|
|
936
|
+
self.spatial_behavior_combo.addItem("All behaviors")
|
|
937
|
+
for b in self.all_behaviors:
|
|
938
|
+
self.spatial_behavior_combo.addItem(b)
|
|
939
|
+
self.spatial_behavior_combo.blockSignals(False)
|
|
940
|
+
|
|
941
|
+
self.spatial_video_combo.blockSignals(True)
|
|
942
|
+
self.spatial_video_combo.clear()
|
|
943
|
+
self.spatial_video_combo.addItem("All")
|
|
944
|
+
if isinstance(results_dict, dict):
|
|
945
|
+
for vp in results_dict:
|
|
946
|
+
self.spatial_video_combo.addItem(os.path.basename(vp))
|
|
947
|
+
self.spatial_video_combo.blockSignals(False)
|
|
948
|
+
|
|
949
|
+
self.spatial_show_btn.setEnabled(has_loc)
|
|
950
|
+
self.manage_regions_btn.setEnabled(has_loc)
|
|
951
|
+
if has_loc:
|
|
952
|
+
self.spatial_info_label.setText("Localization data available.")
|
|
953
|
+
else:
|
|
954
|
+
self.spatial_info_label.setText("No localization data found in results.")
|
|
955
|
+
|
|
956
|
+
def _extract_spatial_data(self):
|
|
957
|
+
"""Extract per-clip centroids grouped by behavior from loaded results.
|
|
958
|
+
|
|
959
|
+
Returns dict:
|
|
960
|
+
{
|
|
961
|
+
video_basename: {
|
|
962
|
+
"all_centroids": [(cx, cy), ...],
|
|
963
|
+
"behavior_centroids": {behavior_name: [(cx, cy), ...], ...}
|
|
964
|
+
}
|
|
965
|
+
}
|
|
966
|
+
"""
|
|
967
|
+
data_container = self.results
|
|
968
|
+
results_dict = data_container.get("results", data_container)
|
|
969
|
+
classes = data_container.get("classes", [])
|
|
970
|
+
params = data_container.get("parameters", {})
|
|
971
|
+
|
|
972
|
+
spatial_data = {}
|
|
973
|
+
|
|
974
|
+
for video_path, v_data in (results_dict.items() if isinstance(results_dict, dict) else []):
|
|
975
|
+
loc_bboxes = v_data.get("localization_bboxes", [])
|
|
976
|
+
if not loc_bboxes:
|
|
977
|
+
continue
|
|
978
|
+
|
|
979
|
+
preds = v_data.get("predictions", [])
|
|
980
|
+
corrections = v_data.get("corrected_labels", {})
|
|
981
|
+
video_name = os.path.basename(video_path)
|
|
982
|
+
|
|
983
|
+
all_centroids = []
|
|
984
|
+
behavior_centroids = {}
|
|
985
|
+
|
|
986
|
+
for clip_idx, raw in enumerate(loc_bboxes):
|
|
987
|
+
# Compute centroid from bbox
|
|
988
|
+
cx, cy = self._bbox_centroid(raw)
|
|
989
|
+
if cx is None:
|
|
990
|
+
continue
|
|
991
|
+
|
|
992
|
+
all_centroids.append((cx, cy))
|
|
993
|
+
|
|
994
|
+
# Determine behavior label for this clip
|
|
995
|
+
if clip_idx < len(preds):
|
|
996
|
+
pred_idx = preds[clip_idx]
|
|
997
|
+
# Apply correction if exists
|
|
998
|
+
if str(clip_idx) in corrections:
|
|
999
|
+
pred_idx = corrections[str(clip_idx)]
|
|
1000
|
+
elif clip_idx in corrections:
|
|
1001
|
+
pred_idx = corrections[clip_idx]
|
|
1002
|
+
|
|
1003
|
+
if classes and pred_idx < len(classes):
|
|
1004
|
+
label = classes[pred_idx]
|
|
1005
|
+
else:
|
|
1006
|
+
label = f"Class {pred_idx}"
|
|
1007
|
+
|
|
1008
|
+
behavior_centroids.setdefault(label, []).append((cx, cy))
|
|
1009
|
+
|
|
1010
|
+
if all_centroids:
|
|
1011
|
+
spatial_data[video_name] = {
|
|
1012
|
+
"all_centroids": all_centroids,
|
|
1013
|
+
"behavior_centroids": behavior_centroids,
|
|
1014
|
+
}
|
|
1015
|
+
|
|
1016
|
+
return spatial_data
|
|
1017
|
+
|
|
1018
|
+
@staticmethod
|
|
1019
|
+
def _bbox_centroid(raw):
|
|
1020
|
+
"""Return (cx, cy) from a localization bbox entry, or (None, None)."""
|
|
1021
|
+
try:
|
|
1022
|
+
if not isinstance(raw, (list, tuple)) or len(raw) == 0:
|
|
1023
|
+
return None, None
|
|
1024
|
+
# Single bbox [x1, y1, x2, y2]
|
|
1025
|
+
if len(raw) == 4 and all(not isinstance(v, (list, tuple)) for v in raw):
|
|
1026
|
+
x1, y1, x2, y2 = [float(v) for v in raw]
|
|
1027
|
+
return (x1 + x2) / 2.0, (y1 + y2) / 2.0
|
|
1028
|
+
# Per-frame bboxes [[x1,y1,x2,y2], ...] — use middle frame
|
|
1029
|
+
if isinstance(raw[0], (list, tuple)):
|
|
1030
|
+
mid = len(raw) // 2
|
|
1031
|
+
box = raw[mid]
|
|
1032
|
+
x1, y1, x2, y2 = [float(v) for v in box]
|
|
1033
|
+
return (x1 + x2) / 2.0, (y1 + y2) / 2.0
|
|
1034
|
+
except Exception as e:
|
|
1035
|
+
logger.debug("Could not parse localization bbox center: %s", e)
|
|
1036
|
+
return None, None
|
|
1037
|
+
|
|
1038
|
+
    def _update_spatial_plot(self):
        """Build and display the spatial distribution plot.

        Renders a plotly scatter of clip centroids (grey base layer, colored
        per-behavior layers, region overlays) into an HTML file under the
        configured data directory and points the embedded webview at it.
        Stores the figure on ``self.last_spatial_fig`` for later export.
        """
        if not HAS_PLOTLY:
            return

        spatial_data = self._extract_spatial_data()
        if not spatial_data:
            self.spatial_info_label.setText("No localization centroids could be extracted.")
            return

        selected_behavior = self.spatial_behavior_combo.currentText()
        selected_video = self.spatial_video_combo.currentText()
        theme = self.color_theme_combo.currentText()
        if theme == "none":
            # None lets plotly fall back to its default template.
            theme = None

        fig = go.Figure()

        # Collect centroids based on video filter
        all_cx, all_cy = [], []
        beh_centroids = {}  # {behavior: [(cx, cy), ...]}

        for video_name, vdata in spatial_data.items():
            if selected_video != "All" and video_name != selected_video:
                continue
            for cx, cy in vdata["all_centroids"]:
                all_cx.append(cx)
                all_cy.append(cy)
            for beh, pts in vdata["behavior_centroids"].items():
                beh_centroids.setdefault(beh, []).extend(pts)

        if not all_cx:
            self.spatial_info_label.setText("No centroids for the selected filter.")
            return

        dot_size = self.spatial_dot_size_slider.value()
        # Base layer drawn slightly smaller so behavior markers sit on top visibly.
        base_size = max(1, dot_size - 3)

        # Base layer: all centroids as light grey
        fig.add_trace(go.Scatter(
            x=all_cx,
            y=all_cy,
            mode='markers',
            marker=dict(color='lightgrey', size=base_size, opacity=0.4),
            name='All clips',
            showlegend=True,
            legendgroup='all',
            hoverinfo='skip',
        ))

        # Use same palette as Bar Plot (General Overview) so behavior colors match
        behaviors = sorted(beh_centroids.keys())
        color_palette = px.colors.qualitative.Plotly

        if selected_behavior == "All behaviors":
            # One trace per behavior, colored by sorted index to keep colors stable.
            for i, beh in enumerate(behaviors):
                pts = beh_centroids[beh]
                bx = [p[0] for p in pts]
                by = [p[1] for p in pts]
                color = color_palette[i % len(color_palette)]
                fig.add_trace(go.Scatter(
                    x=bx, y=by,
                    mode='markers',
                    marker=dict(color=color, size=dot_size, opacity=0.8),
                    name=beh,
                    showlegend=True,
                    legendgroup=beh,
                    hovertext=[beh] * len(bx),
                    hoverinfo='text',
                ))
            title = "Spatial Distribution: All Behaviors"
        else:
            # Single-behavior view: reuse the color that behavior would get in
            # the all-behaviors view so the two modes stay visually consistent.
            pts = beh_centroids.get(selected_behavior, [])
            bx = [p[0] for p in pts]
            by = [p[1] for p in pts]
            color = color_palette[behaviors.index(selected_behavior) % len(color_palette)] if selected_behavior in behaviors else color_palette[0]
            fig.add_trace(go.Scatter(
                x=bx, y=by,
                mode='markers',
                marker=dict(color=color, size=dot_size, opacity=0.8),
                name=selected_behavior,
                showlegend=True,
                legendgroup='behavior',
                hovertext=[selected_behavior] * len(bx),
                hoverinfo='text',
            ))
            title = f"Spatial Distribution: {selected_behavior} ({len(pts)} clips)"

        # Overlay defined spatial regions as semi-transparent fills
        region_colors = px.colors.qualitative.Set2
        for ri, region in enumerate(self.spatial_regions):
            verts = region["vertices"]
            rc = region_colors[ri % len(region_colors)]
            if region["type"] == "rect" and len(verts) == 2:
                # Expand the two stored corners into a closed 5-point outline.
                rx = [verts[0][0], verts[1][0], verts[1][0], verts[0][0], verts[0][0]]
                ry = [verts[0][1], verts[0][1], verts[1][1], verts[1][1], verts[0][1]]
            else:
                # Polygon: close the ring by repeating the first vertex.
                rx = [v[0] for v in verts] + [verts[0][0]]
                ry = [v[1] for v in verts] + [verts[0][1]]
            fig.add_trace(go.Scatter(
                x=rx, y=ry, fill="toself",
                fillcolor=rc, opacity=0.15,
                line=dict(color=rc, width=2),
                name=region["name"],
                showlegend=True,
                legendgroup=f"region_{ri}",
                hoverinfo="name",
            ))

        filter_parts = []
        if selected_video != "All":
            filter_parts.append(f"Video: {selected_video}")
        filter_str = f" | {', '.join(filter_parts)}" if filter_parts else ""

        fig.update_layout(
            title=f'{title}{filter_str}',
            xaxis_title='X (normalized)',
            yaxis_title='Y (normalized)',
            # Equal aspect ratio; y reversed to match image coordinates (origin top-left).
            xaxis=dict(scaleanchor="y", scaleratio=1, range=[0, 1]),
            yaxis=dict(autorange='reversed', range=[0, 1]),
            height=600,
            template=theme,
            hovermode='closest',
        )

        self.last_spatial_fig = fig

        # Render to same webview
        # NOTE(review): the tempfile import appears unused — the plot is written
        # under data_dir/temp_plots, not an OS temp directory.
        import tempfile
        temp_dir = os.path.join(self.config.get("data_dir", "."), "temp_plots")
        os.makedirs(temp_dir, exist_ok=True)
        plot_path = os.path.join(temp_dir, "spatial_distribution.html")

        with open(plot_path, 'w', encoding="utf-8") as f:
            f.write(fig.to_html(include_plotlyjs=True))

        if HAS_WEBENGINE:
            self.webview.setUrl(QUrl.fromLocalFile(os.path.abspath(plot_path)))

        self.spatial_info_label.setText(f"Showing {len(all_cx)} total centroids.")
|
|
1178
|
+
|
|
1179
|
+
def _save_spatial_plot(self):
|
|
1180
|
+
"""Save the current spatial distribution plot as PDF or SVG."""
|
|
1181
|
+
if not HAS_PLOTLY or not getattr(self, "last_spatial_fig", None):
|
|
1182
|
+
QMessageBox.warning(self, "Save spatial plot", "No spatial distribution plot to save. Show the plot first.")
|
|
1183
|
+
return
|
|
1184
|
+
path, selected_filter = QFileDialog.getSaveFileName(
|
|
1185
|
+
self, "Save spatial plot",
|
|
1186
|
+
self.config.get("experiment_path", ""),
|
|
1187
|
+
"PDF Files (*.pdf);;SVG Files (*.svg);;PNG Files (*.png);;HTML Files (*.html)"
|
|
1188
|
+
)
|
|
1189
|
+
if not path:
|
|
1190
|
+
return
|
|
1191
|
+
try:
|
|
1192
|
+
if path.lower().endswith(".html"):
|
|
1193
|
+
self.last_spatial_fig.write_html(path)
|
|
1194
|
+
else:
|
|
1195
|
+
self.last_spatial_fig.write_image(path)
|
|
1196
|
+
QMessageBox.information(self, "Success", f"Spatial plot saved to {path}")
|
|
1197
|
+
except Exception as e:
|
|
1198
|
+
if "kaleido" in str(e).lower() or "executable" in str(e).lower():
|
|
1199
|
+
QMessageBox.warning(self, "Error", "Saving as PDF/SVG requires 'kaleido'. Install with: pip install kaleido")
|
|
1200
|
+
else:
|
|
1201
|
+
QMessageBox.critical(self, "Error", f"Failed to save: {e}")
|
|
1202
|
+
|
|
1203
|
+
def _filter_behaviors(self):
|
|
1204
|
+
"""Open dialog to filter behaviors."""
|
|
1205
|
+
dialog = QDialog(self)
|
|
1206
|
+
dialog.setWindowTitle("Select behaviors")
|
|
1207
|
+
dialog.resize(300, 400)
|
|
1208
|
+
layout = QVBoxLayout()
|
|
1209
|
+
|
|
1210
|
+
# Buttons
|
|
1211
|
+
btn_layout = QHBoxLayout()
|
|
1212
|
+
all_btn = QPushButton("Select all")
|
|
1213
|
+
none_btn = QPushButton("Deselect all")
|
|
1214
|
+
btn_layout.addWidget(all_btn)
|
|
1215
|
+
btn_layout.addWidget(none_btn)
|
|
1216
|
+
layout.addLayout(btn_layout)
|
|
1217
|
+
|
|
1218
|
+
# Scrollable Checkbox List
|
|
1219
|
+
scroll = QScrollArea()
|
|
1220
|
+
scroll.setWidgetResizable(True)
|
|
1221
|
+
widget = QWidget()
|
|
1222
|
+
vbox = QVBoxLayout()
|
|
1223
|
+
|
|
1224
|
+
checkboxes = []
|
|
1225
|
+
for beh in self.all_behaviors:
|
|
1226
|
+
cb = QCheckBox(beh)
|
|
1227
|
+
if beh in self.selected_behaviors:
|
|
1228
|
+
cb.setChecked(True)
|
|
1229
|
+
checkboxes.append((beh, cb))
|
|
1230
|
+
vbox.addWidget(cb)
|
|
1231
|
+
|
|
1232
|
+
vbox.addStretch()
|
|
1233
|
+
widget.setLayout(vbox)
|
|
1234
|
+
scroll.setWidget(widget)
|
|
1235
|
+
layout.addWidget(scroll)
|
|
1236
|
+
|
|
1237
|
+
# Connect All/None
|
|
1238
|
+
def select_all():
|
|
1239
|
+
for _, cb in checkboxes: cb.setChecked(True)
|
|
1240
|
+
def select_none():
|
|
1241
|
+
for _, cb in checkboxes: cb.setChecked(False)
|
|
1242
|
+
|
|
1243
|
+
all_btn.clicked.connect(select_all)
|
|
1244
|
+
none_btn.clicked.connect(select_none)
|
|
1245
|
+
|
|
1246
|
+
# OK Button
|
|
1247
|
+
ok_btn = QPushButton("Update plots")
|
|
1248
|
+
ok_btn.clicked.connect(dialog.accept)
|
|
1249
|
+
layout.addWidget(ok_btn)
|
|
1250
|
+
|
|
1251
|
+
dialog.setLayout(layout)
|
|
1252
|
+
|
|
1253
|
+
if dialog.exec():
|
|
1254
|
+
self.selected_behaviors = {beh for beh, cb in checkboxes if cb.isChecked()}
|
|
1255
|
+
self._update_plots()
|
|
1256
|
+
|
|
1257
|
+
def _manage_groups(self):
|
|
1258
|
+
# Get video paths from results
|
|
1259
|
+
if "results" in self.results:
|
|
1260
|
+
video_paths = list(self.results["results"].keys())
|
|
1261
|
+
else:
|
|
1262
|
+
video_paths = list(self.results.keys())
|
|
1263
|
+
|
|
1264
|
+
dialog = GroupManagementDialog(video_paths, self.groups, self.video_groups, self)
|
|
1265
|
+
if dialog.exec():
|
|
1266
|
+
self.groups = dialog.groups
|
|
1267
|
+
self.video_groups = dialog.video_groups
|
|
1268
|
+
self._save_groups_to_config()
|
|
1269
|
+
self._save_metadata()
|
|
1270
|
+
|
|
1271
|
+
# Re-process to update groups in merged_data
|
|
1272
|
+
self._process_data()
|
|
1273
|
+
self._update_plots()
|
|
1274
|
+
|
|
1275
|
+
def _save_metadata(self):
|
|
1276
|
+
exp_path = self.config.get("experiment_path")
|
|
1277
|
+
if exp_path:
|
|
1278
|
+
meta_path = os.path.join(exp_path, "results", "analysis_metadata.json")
|
|
1279
|
+
try:
|
|
1280
|
+
os.makedirs(os.path.dirname(meta_path), exist_ok=True)
|
|
1281
|
+
serializable_regions = []
|
|
1282
|
+
for r in self.spatial_regions:
|
|
1283
|
+
serializable_regions.append({
|
|
1284
|
+
"name": r["name"],
|
|
1285
|
+
"type": r["type"],
|
|
1286
|
+
"vertices": [list(v) for v in r["vertices"]],
|
|
1287
|
+
})
|
|
1288
|
+
with open(meta_path, 'w') as f:
|
|
1289
|
+
json.dump({
|
|
1290
|
+
"groups": self.groups,
|
|
1291
|
+
"video_groups": self.video_groups,
|
|
1292
|
+
"spatial_regions": serializable_regions,
|
|
1293
|
+
}, f, indent=2)
|
|
1294
|
+
except Exception as e:
|
|
1295
|
+
logger.error("Error saving metadata: %s", e)
|
|
1296
|
+
|
|
1297
|
+
def _on_region_filter_changed(self):
|
|
1298
|
+
if self.merged_data:
|
|
1299
|
+
self._update_plots()
|
|
1300
|
+
|
|
1301
|
+
def _update_region_filter_combo(self):
|
|
1302
|
+
self.region_filter_combo.blockSignals(True)
|
|
1303
|
+
prev = self.region_filter_combo.currentText()
|
|
1304
|
+
self.region_filter_combo.clear()
|
|
1305
|
+
self.region_filter_combo.addItem("All Regions")
|
|
1306
|
+
for r in self.spatial_regions:
|
|
1307
|
+
self.region_filter_combo.addItem(r["name"])
|
|
1308
|
+
self.region_filter_combo.addItem("Outside")
|
|
1309
|
+
idx = self.region_filter_combo.findText(prev)
|
|
1310
|
+
if idx >= 0:
|
|
1311
|
+
self.region_filter_combo.setCurrentIndex(idx)
|
|
1312
|
+
self.region_filter_combo.blockSignals(False)
|
|
1313
|
+
|
|
1314
|
+
def _manage_spatial_regions(self):
|
|
1315
|
+
spatial_data = self._extract_spatial_data()
|
|
1316
|
+
all_centroids = []
|
|
1317
|
+
for vdata in spatial_data.values():
|
|
1318
|
+
all_centroids.extend(vdata["all_centroids"])
|
|
1319
|
+
dialog = SpatialRegionEditor(all_centroids, self.spatial_regions, self)
|
|
1320
|
+
if dialog.exec():
|
|
1321
|
+
self.spatial_regions = dialog.regions
|
|
1322
|
+
self._save_metadata()
|
|
1323
|
+
self._update_region_filter_combo()
|
|
1324
|
+
self._process_data()
|
|
1325
|
+
self._update_plots()
|
|
1326
|
+
|
|
1327
|
+
def _update_plots(self):
|
|
1328
|
+
if not HAS_PLOTLY or not self.merged_data:
|
|
1329
|
+
if not HAS_PLOTLY:
|
|
1330
|
+
QMessageBox.warning(self, "Error", "Plotly is not installed.")
|
|
1331
|
+
return
|
|
1332
|
+
|
|
1333
|
+
df = pd.DataFrame(self.merged_data)
|
|
1334
|
+
|
|
1335
|
+
# Filter behaviors
|
|
1336
|
+
if self.selected_behaviors:
|
|
1337
|
+
df = df[df["Behavior"].isin(self.selected_behaviors)]
|
|
1338
|
+
|
|
1339
|
+
# Filter groups
|
|
1340
|
+
if self.visible_groups:
|
|
1341
|
+
df = df[df["Group"].isin(self.visible_groups)]
|
|
1342
|
+
|
|
1343
|
+
# Filter by spatial region
|
|
1344
|
+
selected_region = self.region_filter_combo.currentText()
|
|
1345
|
+
if selected_region != "All Regions" and "Region" in df.columns:
|
|
1346
|
+
df = df[df["Region"] == selected_region]
|
|
1347
|
+
|
|
1348
|
+
if df.empty:
|
|
1349
|
+
self.last_fig = None
|
|
1350
|
+
return
|
|
1351
|
+
|
|
1352
|
+
metric = self.metric_combo.currentText()
|
|
1353
|
+
analysis_mode = self.plot_mode_combo.currentText()
|
|
1354
|
+
graph_type = self.graph_type_combo.currentText()
|
|
1355
|
+
theme = self.color_theme_combo.currentText()
|
|
1356
|
+
|
|
1357
|
+
fig = None
|
|
1358
|
+
|
|
1359
|
+
# Aggregation
|
|
1360
|
+
if "Occurrences" in metric:
|
|
1361
|
+
agg = df.groupby(["Video", "Group", "Behavior"]).size().reset_index(name="Value")
|
|
1362
|
+
y_label = "Count"
|
|
1363
|
+
elif "Average" in metric:
|
|
1364
|
+
agg = df.groupby(["Video", "Group", "Behavior"])["Duration"].mean().reset_index(name="Value")
|
|
1365
|
+
y_label = "Avg Duration (s)"
|
|
1366
|
+
elif "Total" in metric:
|
|
1367
|
+
agg = df.groupby(["Video", "Group", "Behavior"])["Duration"].sum().reset_index(name="Value")
|
|
1368
|
+
y_label = "Total Duration (s)"
|
|
1369
|
+
elif "Percent" in metric:
|
|
1370
|
+
# Calculate total time per video
|
|
1371
|
+
total_times = df.groupby("Video")["Duration"].sum().to_dict()
|
|
1372
|
+
agg = df.groupby(["Video", "Group", "Behavior"])["Duration"].sum().reset_index(name="Value")
|
|
1373
|
+
agg["Value"] = agg.apply(lambda x: (x["Value"] / total_times.get(x["Video"], 1)) * 100, axis=1)
|
|
1374
|
+
y_label = "Percent Time (%)"
|
|
1375
|
+
|
|
1376
|
+
# Auto-determine graph type if needed
|
|
1377
|
+
if graph_type == "Auto":
|
|
1378
|
+
if analysis_mode == "General Overview":
|
|
1379
|
+
if "Average" in metric: graph_type = "Box Plot"
|
|
1380
|
+
else: graph_type = "Bar Chart"
|
|
1381
|
+
else: # Group Comparison
|
|
1382
|
+
graph_type = "Box Plot"
|
|
1383
|
+
|
|
1384
|
+
# Common Plot Settings
|
|
1385
|
+
template = theme if theme != "none" else None
|
|
1386
|
+
title = f"{analysis_mode}: {metric}"
|
|
1387
|
+
labels = {"Value": y_label, "Duration": "Bout Duration (s)", "Video": "Video", "Group": "Group"}
|
|
1388
|
+
plot_args = {"template": template, "title": title, "labels": labels}
|
|
1389
|
+
|
|
1390
|
+
if analysis_mode == "General Overview":
|
|
1391
|
+
# X=Video, Color=Behavior
|
|
1392
|
+
|
|
1393
|
+
# Decide data source: Bout-level (df) or Video-level (agg)
|
|
1394
|
+
if "Average" in metric and graph_type in ["Box Plot", "Violin Plot", "Strip Plot"]:
|
|
1395
|
+
data = df
|
|
1396
|
+
y_col = "Duration"
|
|
1397
|
+
else:
|
|
1398
|
+
data = agg
|
|
1399
|
+
y_col = "Value"
|
|
1400
|
+
|
|
1401
|
+
if graph_type == "Bar Chart":
|
|
1402
|
+
fig = px.bar(data, x="Video", y=y_col, color="Behavior", **plot_args)
|
|
1403
|
+
elif graph_type == "Box Plot":
|
|
1404
|
+
fig = px.box(data, x="Video", y=y_col, color="Behavior", points="all", **plot_args)
|
|
1405
|
+
elif graph_type == "Violin Plot":
|
|
1406
|
+
fig = px.violin(data, x="Video", y=y_col, color="Behavior", points="all", box=True, **plot_args)
|
|
1407
|
+
elif graph_type == "Strip Plot":
|
|
1408
|
+
fig = px.strip(data, x="Video", y=y_col, color="Behavior", **plot_args)
|
|
1409
|
+
elif graph_type == "Line Plot":
|
|
1410
|
+
fig = px.line(agg, x="Video", y="Value", color="Behavior", markers=True, **plot_args)
|
|
1411
|
+
|
|
1412
|
+
elif analysis_mode == "Group Comparison":
|
|
1413
|
+
# X=Behavior, Color=Group, Data=Video Aggregates (agg)
|
|
1414
|
+
data = agg
|
|
1415
|
+
|
|
1416
|
+
if graph_type == "Bar Chart":
|
|
1417
|
+
# Create bar chart with mean, SEM error bars, and individual points
|
|
1418
|
+
fig = go.Figure()
|
|
1419
|
+
|
|
1420
|
+
# Get unique behaviors and groups
|
|
1421
|
+
behaviors = sorted(data["Behavior"].unique())
|
|
1422
|
+
groups = sorted(data["Group"].unique())
|
|
1423
|
+
num_groups = len(groups)
|
|
1424
|
+
|
|
1425
|
+
# Color palette
|
|
1426
|
+
colors = px.colors.qualitative.Plotly
|
|
1427
|
+
|
|
1428
|
+
# Layout settings
|
|
1429
|
+
group_width = 0.8
|
|
1430
|
+
bar_width = group_width / num_groups
|
|
1431
|
+
|
|
1432
|
+
for i, group in enumerate(groups):
|
|
1433
|
+
group_data = data[data["Group"] == group]
|
|
1434
|
+
|
|
1435
|
+
means = []
|
|
1436
|
+
sems = []
|
|
1437
|
+
x_positions_bar = []
|
|
1438
|
+
x_positions_points = []
|
|
1439
|
+
y_points = []
|
|
1440
|
+
|
|
1441
|
+
# Calculate offset for this group
|
|
1442
|
+
# We map behaviors to integers 0..N-1
|
|
1443
|
+
# Center of bar i is at: index - 0.4 + (i + 0.5) * bar_width
|
|
1444
|
+
offset = -0.4 + bar_width * (i + 0.5)
|
|
1445
|
+
|
|
1446
|
+
for j, behavior in enumerate(behaviors):
|
|
1447
|
+
beh_data = group_data[group_data["Behavior"] == behavior]["Value"]
|
|
1448
|
+
|
|
1449
|
+
# Stats
|
|
1450
|
+
if len(beh_data) > 0:
|
|
1451
|
+
means.append(beh_data.mean())
|
|
1452
|
+
sems.append(beh_data.sem() if len(beh_data) > 1 else 0)
|
|
1453
|
+
else:
|
|
1454
|
+
means.append(0)
|
|
1455
|
+
sems.append(0)
|
|
1456
|
+
|
|
1457
|
+
# Bar Position
|
|
1458
|
+
x_positions_bar.append(j + offset)
|
|
1459
|
+
|
|
1460
|
+
# Points Position
|
|
1461
|
+
if len(beh_data) > 0:
|
|
1462
|
+
x_positions_points.extend([j + offset] * len(beh_data))
|
|
1463
|
+
y_points.extend(beh_data.values)
|
|
1464
|
+
|
|
1465
|
+
# Add bar trace with error bars
|
|
1466
|
+
fig.add_trace(go.Bar(
|
|
1467
|
+
name=group,
|
|
1468
|
+
x=x_positions_bar,
|
|
1469
|
+
y=means,
|
|
1470
|
+
error_y=dict(type='data', array=sems, visible=True),
|
|
1471
|
+
marker_color=colors[i % len(colors)],
|
|
1472
|
+
width=bar_width,
|
|
1473
|
+
showlegend=True
|
|
1474
|
+
))
|
|
1475
|
+
|
|
1476
|
+
# Add individual points as scatter (black)
|
|
1477
|
+
if x_positions_points:
|
|
1478
|
+
fig.add_trace(go.Scatter(
|
|
1479
|
+
x=x_positions_points,
|
|
1480
|
+
y=y_points,
|
|
1481
|
+
mode='markers',
|
|
1482
|
+
marker=dict(
|
|
1483
|
+
color='black',
|
|
1484
|
+
size=5,
|
|
1485
|
+
opacity=0.7,
|
|
1486
|
+
),
|
|
1487
|
+
showlegend=False,
|
|
1488
|
+
hovertemplate=f'{group}<br>%{{y:.2f}}<extra></extra>'
|
|
1489
|
+
))
|
|
1490
|
+
|
|
1491
|
+
fig.update_layout(
|
|
1492
|
+
title=title,
|
|
1493
|
+
xaxis_title="Behavior",
|
|
1494
|
+
yaxis_title=y_label,
|
|
1495
|
+
xaxis=dict(
|
|
1496
|
+
tickmode='array',
|
|
1497
|
+
tickvals=list(range(len(behaviors))),
|
|
1498
|
+
ticktext=behaviors
|
|
1499
|
+
),
|
|
1500
|
+
template=template,
|
|
1501
|
+
hovermode='closest'
|
|
1502
|
+
)
|
|
1503
|
+
|
|
1504
|
+
elif graph_type == "Box Plot":
|
|
1505
|
+
fig = px.box(data, x="Behavior", y="Value", color="Group", points="all", **plot_args)
|
|
1506
|
+
elif graph_type == "Violin Plot":
|
|
1507
|
+
fig = px.violin(data, x="Behavior", y="Value", color="Group", points="all", box=True, **plot_args)
|
|
1508
|
+
elif graph_type == "Strip Plot":
|
|
1509
|
+
fig = px.strip(data, x="Behavior", y="Value", color="Group", **plot_args)
|
|
1510
|
+
elif graph_type == "Line Plot":
|
|
1511
|
+
fig = px.line(data, x="Behavior", y="Value", color="Group", markers=True, **plot_args)
|
|
1512
|
+
|
|
1513
|
+
self.last_fig = fig
|
|
1514
|
+
|
|
1515
|
+
if fig:
|
|
1516
|
+
# Save to temp html
|
|
1517
|
+
import tempfile
|
|
1518
|
+
import shutil
|
|
1519
|
+
|
|
1520
|
+
# Use a fixed temp file in app dir to avoid permission issues sometimes
|
|
1521
|
+
temp_dir = os.path.join(self.config.get("data_dir", "."), "temp_plots")
|
|
1522
|
+
os.makedirs(temp_dir, exist_ok=True)
|
|
1523
|
+
plot_path = os.path.join(temp_dir, "current_plot.html")
|
|
1524
|
+
|
|
1525
|
+
with open(plot_path, 'w', encoding="utf-8") as f:
|
|
1526
|
+
# include_plotlyjs=True embeds the ~3MB library directly in the HTML
|
|
1527
|
+
# preventing 'Plotly is not defined' errors if CDN fails
|
|
1528
|
+
f.write(fig.to_html(include_plotlyjs=True))
|
|
1529
|
+
|
|
1530
|
+
if HAS_WEBENGINE:
|
|
1531
|
+
self.webview.setUrl(QUrl.fromLocalFile(os.path.abspath(plot_path)))
|
|
1532
|
+
|
|
1533
|
+
def _on_transition_type_changed(self):
|
|
1534
|
+
"""Update transition select combo when type changes."""
|
|
1535
|
+
self.transition_select_combo.clear()
|
|
1536
|
+
|
|
1537
|
+
if not self.merged_data:
|
|
1538
|
+
return
|
|
1539
|
+
|
|
1540
|
+
analysis_type = self.transition_type_combo.currentText()
|
|
1541
|
+
|
|
1542
|
+
if analysis_type == "Individual Video":
|
|
1543
|
+
# Populate with video names
|
|
1544
|
+
df = pd.DataFrame(self.merged_data)
|
|
1545
|
+
videos = sorted(df["Video"].unique().tolist())
|
|
1546
|
+
self.transition_select_combo.addItems(videos)
|
|
1547
|
+
else: # Group Comparison
|
|
1548
|
+
# Populate with group names
|
|
1549
|
+
groups = sorted([g for g in self.groups.keys() if g])
|
|
1550
|
+
if not groups:
|
|
1551
|
+
groups = ["Unassigned"]
|
|
1552
|
+
self.transition_select_combo.addItems(groups)
|
|
1553
|
+
|
|
1554
|
+
self.compute_transition_btn.setEnabled(len(self.transition_select_combo) > 0)
|
|
1555
|
+
|
|
1556
|
+
    def _compute_transitions(self):
        """Compute transition matrix and plot transition graph.

        For the current selection (one video, or all videos in a group),
        builds per-video behavior sequences from the row order of
        ``merged_data`` (assumed chronological), counts first-order
        transitions, converts them to row-normalized probabilities, and
        computes standardized residuals (z-scores) against an independence
        model. Results are shown in the matrix table and the network graph.
        """
        if not HAS_PLOTLY:
            QMessageBox.warning(self, "Error", "Plotly is required for transition analysis.")
            return

        analysis_type = self.transition_type_combo.currentText()
        selection = self.transition_select_combo.currentText()

        if not selection:
            return

        df = pd.DataFrame(self.merged_data)

        # Filter data based on selection
        if analysis_type == "Individual Video":
            df_filtered = df[df["Video"] == selection]
            title_suffix = f"Video: {selection}"
        else:  # Group Comparison
            df_filtered = df[df["Group"] == selection]
            title_suffix = f"Group: {selection}"

        if df_filtered.empty:
            QMessageBox.warning(self, "No Data", f"No data found for {selection}")
            return

        # Build sequence of behaviors for this selection.
        # Group by video, sort by implicit order (row order in merged_data
        # represents time) — NOTE(review): assumes rows are chronological;
        # confirm against how merged_data is assembled.
        sequences = []

        if analysis_type == "Individual Video":
            # Single video - one sequence
            sequence = df_filtered["Behavior"].tolist()
            sequences.append(sequence)
        else:
            # Group - multiple videos, analyze separately then aggregate,
            # so no artificial transition is counted across video boundaries.
            for video in df_filtered["Video"].unique():
                video_df = df_filtered[df_filtered["Video"] == video]
                sequence = video_df["Behavior"].tolist()
                sequences.append(sequence)

        # Compute transition matrix over the full behavior catalog (so rows
        # and columns stay stable across selections), falling back to the
        # behaviors seen in the data.
        behaviors = sorted(self.all_behaviors) if self.all_behaviors else sorted(df["Behavior"].unique().tolist())
        transition_counts = {b: {b2: 0 for b2 in behaviors} for b in behaviors}

        for sequence in sequences:
            for i in range(len(sequence) - 1):
                from_beh = sequence[i]
                to_beh = sequence[i + 1]
                # Guard: skip behaviors not present in the catalog
                if from_beh in transition_counts and to_beh in transition_counts[from_beh]:
                    transition_counts[from_beh][to_beh] += 1

        # Normalize to probabilities and calculate residuals
        total_transitions = sum(sum(row.values()) for row in transition_counts.values())

        # Row and Column totals for Expected values
        row_totals = {b: sum(transition_counts[b].values()) for b in behaviors}
        col_totals = {b: sum(transition_counts[row][b] for row in behaviors) for b in behaviors}

        transition_matrix = {}   # row-normalized P(to | from)
        residuals_matrix = {}    # standardized residuals (z-scores)

        for from_beh in behaviors:
            transition_matrix[from_beh] = {}
            residuals_matrix[from_beh] = {}
            row_sum = row_totals[from_beh]

            for to_beh in behaviors:
                # Probability: fraction of this row's transitions going to to_beh
                if row_sum > 0:
                    transition_matrix[from_beh][to_beh] = transition_counts[from_beh][to_beh] / row_sum
                else:
                    transition_matrix[from_beh][to_beh] = 0.0

                # Residual (Observed - Expected) / sqrt(Expected)
                # Expected = (RowTotal * ColTotal) / GrandTotal
                # (chi-square independence model for the contingency table)
                if total_transitions > 0:
                    expected = (row_totals[from_beh] * col_totals[to_beh]) / total_transitions
                else:
                    expected = 0

                observed = transition_counts[from_beh][to_beh]

                if expected > 0:
                    # Standardized Residual
                    z_score = (observed - expected) / np.sqrt(expected)
                else:
                    z_score = 0.0

                residuals_matrix[from_beh][to_beh] = z_score

        # Display matrix in table
        self._display_transition_matrix(behaviors, transition_matrix)

        # Plot transition graph
        self._plot_transition_graph(behaviors, transition_matrix, residuals_matrix, title_suffix)
|
|
1652
|
+
|
|
1653
|
+
def _display_transition_matrix(self, behaviors, transition_matrix):
|
|
1654
|
+
"""Display transition matrix in QTableWidget."""
|
|
1655
|
+
self.transition_matrix_table.clear()
|
|
1656
|
+
self.transition_matrix_table.setRowCount(len(behaviors))
|
|
1657
|
+
self.transition_matrix_table.setColumnCount(len(behaviors) + 1)
|
|
1658
|
+
|
|
1659
|
+
# Headers
|
|
1660
|
+
self.transition_matrix_table.setHorizontalHeaderLabels(["From \\ To"] + behaviors)
|
|
1661
|
+
|
|
1662
|
+
for i, from_beh in enumerate(behaviors):
|
|
1663
|
+
# Row label
|
|
1664
|
+
self.transition_matrix_table.setItem(i, 0, QTableWidgetItem(from_beh))
|
|
1665
|
+
|
|
1666
|
+
# Probabilities
|
|
1667
|
+
for j, to_beh in enumerate(behaviors):
|
|
1668
|
+
prob = transition_matrix[from_beh][to_beh]
|
|
1669
|
+
item = QTableWidgetItem(f"{prob:.3f}")
|
|
1670
|
+
self.transition_matrix_table.setItem(i, j + 1, item)
|
|
1671
|
+
|
|
1672
|
+
self.transition_matrix_table.resizeColumnsToContents()
|
|
1673
|
+
|
|
1674
|
+
    def _plot_transition_graph(self, behaviors, transition_matrix, residuals_matrix, title_suffix):
        """Plot transition graph using Plotly network graph with circular layout.

        Edges are filtered either by significance (z > 1.96) or by probability
        (> 0.05), drawn as spline curves whose width/opacity scale with the
        transition probability and whose color matches the source behavior's
        node. Layout is either a networkx spring layout or a circle. The
        figure is written to a temp HTML file and shown in the web view.
        """
        if not HAS_PLOTLY:
            return

        use_sig_filter = self.sig_filter_check.isChecked()
        layout_mode = self.layout_combo.currentText()

        # Build edges
        threshold = 0.05
        edges = []
        edge_weights = []
        edge_texts = []

        for from_beh in behaviors:
            for to_beh in behaviors:
                prob = transition_matrix[from_beh][to_beh]
                resid = residuals_matrix[from_beh][to_beh]

                # Filtering condition
                include = False
                if use_sig_filter:
                    # Significant positive deviation (Z > 1.96 corresponds to p < 0.05)
                    if resid > 1.96:
                        include = True
                else:
                    # Probability threshold
                    if prob > threshold:
                        include = True

                if include:
                    edges.append((from_beh, to_beh))
                    edge_weights.append(prob)  # Always visualize probability width

                    # Tooltip text
                    txt = f"{from_beh} → {to_beh}<br>Prob: {prob:.2%}"
                    if resid != 0:
                        txt += f"<br>Z-score: {resid:.2f}"
                    edge_texts.append(txt)

        if not edges:
            msg = "No significant transitions found (Z > 1.96)" if use_sig_filter else "No significant transitions found (prob > 0.05)"
            QMessageBox.information(self, "No Transitions", msg)
            return

        # Calculate Layout
        import math
        n = len(behaviors)
        node_positions = {}  # behavior -> (x, y) in plot coordinates

        if layout_mode == "Network Layout":
            try:
                import networkx as nx
                # Create graph for layout calculation
                G = nx.DiGraph()
                G.add_nodes_from(behaviors)
                # Add edges with weights (inverse of probability for 'distance')
                for (u, v), w in zip(edges, edge_weights):
                    if w > 0:
                        G.add_edge(u, v, weight=w)

                # Spring layout (seeded for deterministic placement)
                pos = nx.spring_layout(G, k=2.0/math.sqrt(n) if n > 0 else 1, seed=42, iterations=50)
                # Scale to fit roughly in range
                for node, p in pos.items():
                    node_positions[node] = (p[0] * 2, p[1] * 2)
            except ImportError:
                logger.warning("NetworkX not installed, falling back to Circular Layout")
                layout_mode = "Circular Layout"  # Fallback

        if layout_mode == "Circular Layout":
            # Evenly spaced on a circle, starting at 12 o'clock
            radius = 1.5
            for i, beh in enumerate(behaviors):
                angle = 2 * math.pi * i / n - math.pi / 2
                node_positions[beh] = (radius * math.cos(angle), radius * math.sin(angle))

        # Generate distinct colors for nodes
        import plotly.colors as pc
        if n <= 10:
            node_colors = pc.qualitative.Set3[:n]
        else:
            # More behaviors than the qualitative palette: sample a colorscale
            node_colors = pc.sample_colorscale("hsv", [i/n for i in range(n)])

        behavior_to_color = {beh: node_colors[i] for i, beh in enumerate(behaviors)}

        # Build Plotly figure with curved edges
        edge_traces = []
        for (from_beh, to_beh), weight, txt in zip(edges, edge_weights, edge_texts):
            x0, y0 = node_positions[from_beh]
            x1, y1 = node_positions[to_beh]

            # Curve control points
            if layout_mode == "Circular Layout":
                # Control point slightly outside circle (bows edges outward)
                mid_x, mid_y = (x0 + x1) / 2, (y0 + y1) / 2
                dist = math.sqrt(mid_x**2 + mid_y**2)
                if dist > 0:
                    offset = 0.2
                    ctrl_x = mid_x + offset * mid_x / dist
                    ctrl_y = mid_y + offset * mid_y / dist
                else:
                    ctrl_x, ctrl_y = mid_x, mid_y
            else:
                # Network layout: simple quadratic curve to avoid overlap
                # Perpendicular offset from the edge midpoint
                mid_x, mid_y = (x0 + x1) / 2, (y0 + y1) / 2
                dx, dy = x1 - x0, y1 - y0
                perp_x, perp_y = -dy, dx
                norm = math.sqrt(perp_x**2 + perp_y**2)
                if norm > 0:
                    offset = 0.2  # Fixed offset amount
                    ctrl_x = mid_x + offset * perp_x / norm
                    ctrl_y = mid_y + offset * perp_y / norm
                else:
                    ctrl_x, ctrl_y = mid_x + 0.2, mid_y + 0.2  # Loop (self-transition)

            # Approximate curve: Plotly's spline interpolates through the
            # control point; None breaks the line between edges
            curve_x = [x0, ctrl_x, x1, None]
            curve_y = [y0, ctrl_y, y1, None]

            # Style: edge inherits source node's color; opacity/width scale
            # with transition probability
            edge_color = behavior_to_color[from_beh]
            opacity = 0.3 + 0.5 * weight

            edge_trace = go.Scatter(
                x=curve_x,
                y=curve_y,
                mode='lines',
                line=dict(
                    width=max(1, weight * 10),
                    # NOTE(review): assumes color strings are 'rgb(...)' form
                    # (true for Set3 and sample_colorscale) — verify if the
                    # palette source changes
                    color=edge_color.replace('rgb', 'rgba').replace(')', f',{opacity})'),
                    shape='spline'
                ),
                hoverinfo='text',
                text=txt,
                showlegend=False
            )
            edge_traces.append(edge_trace)

        # Node trace
        node_x = [node_positions[beh][0] for beh in behaviors]
        node_y = [node_positions[beh][1] for beh in behaviors]
        node_colors_list = [behavior_to_color[beh] for beh in behaviors]

        node_trace = go.Scatter(
            x=node_x,
            y=node_y,
            mode='markers+text',
            text=behaviors,
            textposition='middle center',
            textfont=dict(size=10, color='black', family='Arial Black'),
            marker=dict(
                size=50,
                color=node_colors_list,
                line=dict(width=3, color='white'),
                opacity=0.9
            ),
            hoverinfo='text',
            hovertext=behaviors,
            showlegend=False
        )

        # Edges first so nodes are drawn on top
        fig = go.Figure(data=edge_traces + [node_trace])
        fig.update_layout(
            title=dict(
                text=f"Behavior Transition Graph - {title_suffix}" + (" (Significant Only)" if use_sig_filter else "") + "<br><span style='font-size:12px;color:grey'>Edge color matches SOURCE behavior (From -> To)</span>",
                font=dict(size=18)
            ),
            showlegend=False,
            hovermode='closest',
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-2.5, 2.5]),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-2.5, 2.5]),
            plot_bgcolor='white',
            height=700,
            width=700
        )

        # Save and display (same fixed temp-dir scheme as _update_plots)
        temp_dir = os.path.join(self.config.get("data_dir", "."), "temp_plots")
        os.makedirs(temp_dir, exist_ok=True)
        plot_path = os.path.join(temp_dir, "transition_graph.html")

        with open(plot_path, 'w', encoding="utf-8") as f:
            f.write(fig.to_html(include_plotlyjs=True))

        if HAS_WEBENGINE:
            self.transition_webview.setUrl(QUrl.fromLocalFile(os.path.abspath(plot_path)))
|
|
1861
|
+
|
|
1862
|
+
class SpatialCanvas(QWidget):
    """Interactive canvas for drawing spatial regions over centroid scatter.

    All region vertices and in-progress shapes are stored in coordinates
    normalized to [0, 1] x [0, 1]; ``_norm_to_pixel`` / ``_pixel_to_norm``
    convert between normalized space and widget pixels inside a fixed MARGIN.
    ``regions`` is a mutable list shared with the owning SpatialRegionEditor.

    Change vs. previous version: the duplicated parent-walk used to notify
    the editor after a shape is finished is factored into ``_notify_editor``.
    """

    MARGIN = 30  # pixel padding around the drawing area
    # Semi-transparent fills, cycled per region index
    REGION_COLORS = [
        QColor(102, 194, 165, 60), QColor(252, 141, 98, 60),
        QColor(141, 160, 203, 60), QColor(231, 138, 195, 60),
        QColor(166, 216, 84, 60), QColor(255, 217, 47, 60),
    ]
    # Matching opaque border colors
    REGION_BORDER_COLORS = [
        QColor(102, 194, 165), QColor(252, 141, 98),
        QColor(141, 160, 203), QColor(231, 138, 195),
        QColor(166, 216, 84), QColor(255, 217, 47),
    ]

    def __init__(self, centroids, regions, parent=None):
        super().__init__(parent)
        self.centroids = centroids  # [(cx, cy), ...] normalized points
        self.regions = regions  # list of region dicts (mutable reference)
        self.setMinimumSize(500, 500)
        self.setMouseTracking(True)

        self.draw_mode = None  # "polygon" or "rect"
        self._poly_points = []  # in-progress polygon vertices (normalized)
        self._rect_start = None  # in-progress rect start (normalized)
        self._rect_end = None
        self._mouse_norm = None  # current mouse in normalized coords
        self._pending_name = None  # name for the region being drawn

    def _norm_to_pixel(self, nx, ny):
        """Map normalized (nx, ny) in [0,1]^2 to widget pixel coordinates."""
        m = self.MARGIN
        w = self.width() - 2 * m
        h = self.height() - 2 * m
        return int(m + nx * w), int(m + ny * h)

    def _pixel_to_norm(self, px, py):
        """Map widget pixels to normalized coordinates, clamped to [0, 1]."""
        m = self.MARGIN
        w = self.width() - 2 * m
        h = self.height() - 2 * m
        nx = max(0.0, min(1.0, (px - m) / max(1, w)))
        ny = max(0.0, min(1.0, (py - m) / max(1, h)))
        return nx, ny

    def start_drawing(self, mode, name):
        """Enter drawing mode ("polygon" or "rect") for a region named *name*."""
        self.draw_mode = mode
        self._pending_name = name
        self._poly_points = []
        self._rect_start = None
        self._rect_end = None
        self.setCursor(Qt.CursorShape.CrossCursor)
        self.update()

    def cancel_drawing(self):
        """Abort the in-progress shape and restore the normal cursor."""
        self.draw_mode = None
        self._poly_points = []
        self._rect_start = None
        self._rect_end = None
        self.setCursor(Qt.CursorShape.ArrowCursor)
        self.update()

    def _notify_editor(self):
        """Walk up the parent chain and refresh the enclosing editor's list."""
        parent = self.parent()
        while parent and not isinstance(parent, SpatialRegionEditor):
            parent = parent.parent()
        if parent:
            parent._refresh_region_list()

    def mousePressEvent(self, event):
        """Add a polygon vertex or anchor the rectangle start point."""
        if not self.draw_mode:
            return
        pos = event.position()
        nx, ny = self._pixel_to_norm(int(pos.x()), int(pos.y()))
        if self.draw_mode == "polygon":
            self._poly_points.append((nx, ny))
            self.update()
        elif self.draw_mode == "rect":
            self._rect_start = (nx, ny)
            self._rect_end = (nx, ny)
            self.update()

    def mouseMoveEvent(self, event):
        """Track the cursor for rubber-band previews (mouse tracking is on)."""
        pos = event.position()
        self._mouse_norm = self._pixel_to_norm(int(pos.x()), int(pos.y()))
        if self.draw_mode == "rect" and self._rect_start:
            self._rect_end = self._mouse_norm
        self.update()

    def mouseReleaseEvent(self, event):
        """Finish a rectangle: store it (if non-degenerate) and exit draw mode."""
        if self.draw_mode == "rect" and self._rect_start and self._rect_end:
            x0, y0 = self._rect_start
            x1, y1 = self._rect_end
            # Ignore near-zero-area rectangles (accidental clicks)
            if abs(x1 - x0) > 0.005 and abs(y1 - y0) > 0.005:
                self.regions.append({
                    "name": self._pending_name or f"Region {len(self.regions)+1}",
                    "type": "rect",
                    # Normalize to (top-left, bottom-right) regardless of drag direction
                    "vertices": [(min(x0, x1), min(y0, y1)), (max(x0, x1), max(y0, y1))],
                })
            self.draw_mode = None
            self._rect_start = None
            self._rect_end = None
            self.setCursor(Qt.CursorShape.ArrowCursor)
            self._notify_editor()
            self.update()

    def mouseDoubleClickEvent(self, event):
        """Finish a polygon (needs >= 3 vertices) and exit draw mode."""
        if self.draw_mode == "polygon" and len(self._poly_points) >= 3:
            self.regions.append({
                "name": self._pending_name or f"Region {len(self.regions)+1}",
                "type": "polygon",
                "vertices": list(self._poly_points),
            })
            self._poly_points = []
            self.draw_mode = None
            self.setCursor(Qt.CursorShape.ArrowCursor)
            self._notify_editor()
            self.update()

    def paintEvent(self, event):
        """Draw background, centroids, saved regions, and in-progress shapes."""
        p = QPainter(self)
        p.setRenderHint(QPainter.RenderHint.Antialiasing)
        m = self.MARGIN
        draw_w = self.width() - 2 * m
        draw_h = self.height() - 2 * m

        # Background and drawing-area border
        p.fillRect(self.rect(), QColor(255, 255, 255))
        p.setPen(QPen(QColor(200, 200, 200), 1))
        p.drawRect(m, m, draw_w, draw_h)

        # Centroids as small translucent grey dots
        p.setPen(Qt.PenStyle.NoPen)
        p.setBrush(QBrush(QColor(180, 180, 180, 120)))
        for cx, cy in self.centroids:
            px, py = self._norm_to_pixel(cx, cy)
            p.drawEllipse(QPointF(px, py), 3, 3)

        # Saved regions with per-index cycled colors
        for ri, region in enumerate(self.regions):
            col = self.REGION_COLORS[ri % len(self.REGION_COLORS)]
            border = self.REGION_BORDER_COLORS[ri % len(self.REGION_BORDER_COLORS)]
            p.setBrush(QBrush(col))
            p.setPen(QPen(border, 2))
            verts = region["vertices"]
            if region["type"] == "rect" and len(verts) == 2:
                px0, py0 = self._norm_to_pixel(*verts[0])
                px1, py1 = self._norm_to_pixel(*verts[1])
                p.drawRect(QRectF(QPointF(px0, py0), QPointF(px1, py1)))
            else:
                poly = QPolygonF([QPointF(*self._norm_to_pixel(*v)) for v in verts])
                p.drawPolygon(poly)
            # Label at the vertex centroid of the region
            if verts:
                cx_avg = sum(v[0] for v in verts) / len(verts)
                cy_avg = sum(v[1] for v in verts) / len(verts)
                lx, ly = self._norm_to_pixel(cx_avg, cy_avg)
                p.setPen(QPen(border.darker(130), 1))
                p.setFont(QFont("Arial", 9, QFont.Weight.Bold))
                p.drawText(lx - 30, ly - 5, 60, 20, Qt.AlignmentFlag.AlignCenter, region["name"])

        # In-progress polygon (dashed red, rubber-banded to the cursor)
        if self.draw_mode == "polygon" and self._poly_points:
            p.setPen(QPen(QColor(255, 80, 80), 2, Qt.PenStyle.DashLine))
            p.setBrush(QBrush(QColor(255, 80, 80, 40)))
            pts = [QPointF(*self._norm_to_pixel(*v)) for v in self._poly_points]
            if self._mouse_norm:
                pts.append(QPointF(*self._norm_to_pixel(*self._mouse_norm)))
            if len(pts) >= 3:
                p.drawPolygon(QPolygonF(pts))
            elif len(pts) == 2:
                p.drawLine(pts[0], pts[1])
            # Vertex handles (skip the cursor-follow point)
            for pt in pts[:-1]:
                p.setBrush(QBrush(QColor(255, 80, 80)))
                p.drawEllipse(pt, 4, 4)
                p.setBrush(QBrush(QColor(255, 80, 80, 40)))

        # In-progress rectangle (dashed blue)
        if self.draw_mode == "rect" and self._rect_start and self._rect_end:
            p.setPen(QPen(QColor(80, 80, 255), 2, Qt.PenStyle.DashLine))
            p.setBrush(QBrush(QColor(80, 80, 255, 40)))
            px0, py0 = self._norm_to_pixel(*self._rect_start)
            px1, py1 = self._norm_to_pixel(*self._rect_end)
            p.drawRect(QRectF(QPointF(px0, py0), QPointF(px1, py1)))

        # Axis labels marking the normalized corners
        p.setPen(QColor(100, 100, 100))
        p.setFont(QFont("Arial", 8))
        p.drawText(m, m - 5, "0,0")
        p.drawText(m + draw_w - 20, m - 5, "1,0")
        p.drawText(m, m + draw_h + 14, "0,1")

        p.end()
|
|
2055
|
+
|
|
2056
|
+
|
|
2057
|
+
class SpatialRegionEditor(QDialog):
    """Dialog for drawing and managing named spatial regions.

    Edits are made on a working copy of ``existing_regions`` so that
    rejecting the dialog leaves the caller's data untouched; the caller
    reads ``self.regions`` back after the dialog is accepted.
    """

    def __init__(self, centroids, existing_regions, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Spatial Region Editor")
        self.resize(850, 600)
        # Deep copy to allow cancel: each region dict and its vertex list
        # are copied so in-place edits never leak back into the caller.
        self.regions = [
            {"name": r["name"], "type": r["type"], "vertices": list(r["vertices"])}
            for r in existing_regions
        ]
        self.centroids = centroids
        self._setup_ui()

    def _setup_ui(self):
        """Build the toolbar, the canvas/region-list splitter, and OK/Cancel row."""
        layout = QVBoxLayout()

        # Toolbar: drawing-mode buttons plus an abort button that is only
        # enabled while a draw is in progress.
        toolbar = QHBoxLayout()
        self.polygon_btn = QToolButton()
        self.polygon_btn.setText("Draw Polygon")
        self.polygon_btn.setToolTip("Click vertices, double-click to finish")
        self.polygon_btn.clicked.connect(self._start_polygon)

        self.rect_btn = QToolButton()
        self.rect_btn.setText("Draw Rectangle")
        self.rect_btn.setToolTip("Click and drag to define rectangle")
        self.rect_btn.clicked.connect(self._start_rect)

        self.cancel_draw_btn = QPushButton("Cancel drawing")
        self.cancel_draw_btn.clicked.connect(self._cancel_draw)
        self.cancel_draw_btn.setEnabled(False)

        toolbar.addWidget(self.polygon_btn)
        toolbar.addWidget(self.rect_btn)
        toolbar.addWidget(self.cancel_draw_btn)
        toolbar.addStretch()
        layout.addLayout(toolbar)

        # Main splitter: canvas (left) + region list (right)
        splitter = QSplitter(Qt.Orientation.Horizontal)

        self.canvas = SpatialCanvas(self.centroids, self.regions)
        splitter.addWidget(self.canvas)

        right = QWidget()
        right_layout = QVBoxLayout()
        right_layout.addWidget(QLabel("Defined Regions"))
        self.region_list = QListWidget()
        self.region_list.setSelectionMode(QListWidget.SelectionMode.SingleSelection)
        right_layout.addWidget(self.region_list)

        btn_row = QHBoxLayout()
        self.rename_btn = QPushButton("Rename")
        self.rename_btn.clicked.connect(self._rename_region)
        self.delete_btn = QPushButton("Delete")
        self.delete_btn.clicked.connect(self._delete_region)
        btn_row.addWidget(self.rename_btn)
        btn_row.addWidget(self.delete_btn)
        right_layout.addLayout(btn_row)
        right.setLayout(right_layout)
        splitter.addWidget(right)
        splitter.setSizes([600, 250])
        layout.addWidget(splitter, stretch=1)

        # OK/Cancel
        btns = QHBoxLayout()
        ok_btn = QPushButton("Done")
        ok_btn.clicked.connect(self.accept)
        cancel_btn = QPushButton("Cancel")
        cancel_btn.clicked.connect(self.reject)
        btns.addWidget(ok_btn)
        btns.addWidget(cancel_btn)
        layout.addLayout(btns)

        self.setLayout(layout)
        self._refresh_region_list()

    def _begin_drawing(self, mode, kind_label):
        """Prompt for a region name, then put the canvas into draw mode.

        Shared implementation for the polygon and rectangle buttons.
        Does nothing if the user cancels the prompt or enters a blank name.

        Args:
            mode: canvas drawing mode keyword ("polygon" or "rect").
            kind_label: human-readable shape name used in the prompt.
        """
        name, ok = QInputDialog.getText(
            self, "Region Name", f"Name for the new {kind_label} region:"
        )
        if not ok or not name.strip():
            return
        self.cancel_draw_btn.setEnabled(True)
        self.canvas.start_drawing(mode, name.strip())

    def _start_polygon(self):
        """Toolbar slot: begin drawing a polygon region."""
        self._begin_drawing("polygon", "polygon")

    def _start_rect(self):
        """Toolbar slot: begin drawing a rectangle region."""
        self._begin_drawing("rect", "rectangle")

    def _cancel_draw(self):
        """Abort any in-progress drawing on the canvas."""
        self.canvas.cancel_drawing()
        self.cancel_draw_btn.setEnabled(False)

    def _refresh_region_list(self):
        """Rebuild the region list from ``self.regions`` and repaint the canvas."""
        # A refresh implies drawing is over, so the abort button resets too.
        self.cancel_draw_btn.setEnabled(False)
        self.region_list.clear()
        for r in self.regions:
            n_verts = len(r["vertices"])
            self.region_list.addItem(f"{r['name']} ({r['type']}, {n_verts} pts)")
        self.canvas.update()

    def _rename_region(self):
        """Rename the currently selected region; no-op without a selection."""
        idx = self.region_list.currentRow()
        if idx < 0 or idx >= len(self.regions):
            return
        old = self.regions[idx]["name"]
        name, ok = QInputDialog.getText(self, "Rename Region", "New name:", text=old)
        if ok and name.strip():
            self.regions[idx]["name"] = name.strip()
            self._refresh_region_list()

    def _delete_region(self):
        """Delete the currently selected region; no-op without a selection."""
        idx = self.region_list.currentRow()
        if idx < 0 or idx >= len(self.regions):
            return
        self.regions.pop(idx)
        self._refresh_region_list()
|
|
2178
|
+
|
|
2179
|
+
|
|
2180
|
+
class GroupManagementDialog(QDialog):
    """Dialog for creating groups and assigning videos to them.

    Works on copies of *groups* and *video_groups* so that rejecting the
    dialog leaves the caller's data untouched; the caller reads
    ``self.groups`` / ``self.video_groups`` back after acceptance.
    """

    def __init__(self, video_paths, groups, video_groups, parent=None):
        super().__init__(parent)
        self.video_paths = sorted(list(video_paths))
        # Copy each membership list as well: _assign_group mutates the
        # lists in place, and a shallow dict copy would leak those edits
        # back into the caller even when the dialog is cancelled.
        self.groups = {name: list(members) for name, members in groups.items()}
        self.video_groups = video_groups.copy()
        self.setWindowTitle("Manage groups")
        self.resize(800, 600)
        self._setup_ui()

    def _setup_ui(self):
        """Build the group-creation row, video/group splitter, and OK/Cancel row."""
        layout = QVBoxLayout()

        # Top: Group creation
        group_input_layout = QHBoxLayout()
        self.group_name_edit = QLineEdit()
        self.group_name_edit.setPlaceholderText("New Group Name")
        self.add_group_btn = QPushButton("Add group")
        self.add_group_btn.clicked.connect(self._add_group)
        group_input_layout.addWidget(self.group_name_edit)
        group_input_layout.addWidget(self.add_group_btn)
        layout.addLayout(group_input_layout)

        # Main: Splitter with Videos (left) and Groups (right)
        splitter = QSplitter(Qt.Orientation.Horizontal)

        # Videos List (Table to show current group)
        video_widget = QWidget()
        video_layout = QVBoxLayout()
        video_layout.addWidget(QLabel("Videos"))
        self.video_table = QTableWidget()
        self.video_table.setColumnCount(2)
        self.video_table.setHorizontalHeaderLabels(["Video", "Assigned Group"])
        self.video_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
        self.video_table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
        self.video_table.setSelectionMode(QTableWidget.SelectionMode.ExtendedSelection)
        video_layout.addWidget(self.video_table)
        video_widget.setLayout(video_layout)
        splitter.addWidget(video_widget)

        # Groups List
        group_widget = QWidget()
        group_layout = QVBoxLayout()
        group_layout.addWidget(QLabel("Groups (Select to assign)"))
        self.group_list = QTableWidget()
        self.group_list.setColumnCount(1)
        self.group_list.setHorizontalHeaderLabels(["Group Name"])
        self.group_list.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
        # Selecting a group assigns it to all selected videos immediately.
        self.group_list.itemSelectionChanged.connect(self._assign_group)
        group_layout.addWidget(self.group_list)
        group_widget.setLayout(group_layout)
        splitter.addWidget(group_widget)

        layout.addWidget(splitter)

        btns = QHBoxLayout()
        ok_btn = QPushButton("OK")
        ok_btn.clicked.connect(self.accept)
        cancel_btn = QPushButton("Cancel")
        cancel_btn.clicked.connect(self.reject)
        btns.addWidget(ok_btn)
        btns.addWidget(cancel_btn)
        layout.addLayout(btns)

        self.setLayout(layout)
        self._refresh_lists()

    def _refresh_lists(self):
        """Repopulate both tables from ``self.video_paths`` / ``self.groups``."""
        # Videos: table row order mirrors self.video_paths (sorted), which
        # _assign_group relies on to map selected rows back to full paths.
        self.video_table.setRowCount(len(self.video_paths))
        for i, path in enumerate(self.video_paths):
            name = os.path.basename(path)
            self.video_table.setItem(i, 0, QTableWidgetItem(name))
            group = self.video_groups.get(path, "Unassigned")
            self.video_table.setItem(i, 1, QTableWidgetItem(group))

        # Groups
        self.group_list.setRowCount(len(self.groups))
        for i, group in enumerate(sorted(self.groups.keys())):
            self.group_list.setItem(i, 0, QTableWidgetItem(group))

    def _add_group(self):
        """Create a new empty group from the text field; ignores blanks/duplicates."""
        name = self.group_name_edit.text().strip()
        if name and name not in self.groups:
            self.groups[name] = []
            self._refresh_lists()
            self.group_name_edit.clear()

    def _assign_group(self):
        """Assign the selected group to every selected video row.

        Triggered by selection changes in the group table. Moves each
        selected video out of its previous group's membership list (if
        any) and into the newly selected group's list, then refreshes.
        """
        selected_group_items = self.group_list.selectedItems()
        if not selected_group_items:
            # Also guards against re-entry when _refresh_lists clears
            # the selection below.
            return

        group_name = selected_group_items[0].text()

        selected_video_rows = set(idx.row() for idx in self.video_table.selectedIndexes())

        for row in selected_video_rows:
            # Using index from video_paths since table order matches
            full_path = self.video_paths[row]

            # Remove from old group
            old_group = self.video_groups.get(full_path)
            if old_group and old_group in self.groups and full_path in self.groups[old_group]:
                self.groups[old_group].remove(full_path)

            # Assign new
            self.video_groups[full_path] = group_name
            if full_path not in self.groups[group_name]:
                self.groups[group_name].append(full_path)

        self._refresh_lists()
|