singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
|
@@ -0,0 +1,729 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Plot integration utilities for PyQt6.
|
|
3
|
+
Handles both matplotlib and plotly plots.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from PyQt6.QtWidgets import (
|
|
8
|
+
QWidget, QVBoxLayout, QScrollArea, QPushButton, QFileDialog, QMessageBox
|
|
9
|
+
)
|
|
10
|
+
from PyQt6.QtCore import Qt, pyqtSignal
|
|
11
|
+
from PyQt6.QtGui import QPainter, QColor, QPen, QFont
|
|
12
|
+
from PyQt6.QtCore import QRect
|
|
13
|
+
import matplotlib
|
|
14
|
+
matplotlib.use('QtAgg') # Use Qt backend
|
|
15
|
+
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg
|
|
16
|
+
from matplotlib.figure import Figure
|
|
17
|
+
import plotly.graph_objects as go
|
|
18
|
+
import plotly.io as pio
|
|
19
|
+
import io
|
|
20
|
+
from PIL import Image
|
|
21
|
+
|
|
22
|
+
# Module-level logger used by the widgets below to report errors and warnings.
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TimelineWidget(QWidget):
    """Timeline widget showing context (grey) and clip (green) sections.

    The extracted clip video is laid out as ``context_before + clip +
    context_after``. Given ``clip_metadata`` (keys ``fps``, ``start_frame``,
    ``end_frame``, ``context_frames``) and the video duration, the widget
    colours the three regions and draws a white playhead at the current
    playback position.
    """

    def __init__(self, parent=None, clip_metadata: dict = None):
        super().__init__(parent)
        self.clip_metadata = clip_metadata or {}
        self.duration_ms = 0
        self.current_position_ms = 0
        self.context_start_ms = 0
        self.clip_start_ms = 0
        self.clip_end_ms = 0
        self.context_end_ms = 0

        if clip_metadata:
            self._calculate_segments()

    def _calculate_segments(self):
        """Calculate timeline segment positions (ms) from the clip metadata.

        Reads ``fps`` (default 30), ``start_frame``/``end_frame`` (inclusive,
        default 0) and ``context_frames`` (default 30) from ``clip_metadata``.
        Does nothing until both metadata and a non-zero duration are known.
        """
        if not self.clip_metadata or self.duration_ms == 0:
            return

        fps = self.clip_metadata.get('fps', 30)
        start_frame = self.clip_metadata.get('start_frame', 0)
        end_frame = self.clip_metadata.get('end_frame', 0)
        context_frames = self.clip_metadata.get('context_frames', 30)

        clip_frames = end_frame - start_frame + 1
        # The extracted video is laid out as: context_before + clip + context_after.
        total_frames_in_video = clip_frames + 2 * context_frames
        if total_frames_in_video == 0:
            return

        frame_duration_ms = 1000.0 / fps if fps > 0 else 33.33

        self.context_start_ms = 0
        self.clip_start_ms = context_frames * frame_duration_ms
        self.clip_end_ms = (context_frames + clip_frames) * frame_duration_ms
        self.context_end_ms = self.duration_ms

    def set_duration(self, duration_ms: int):
        """Set total video duration (ms) and recompute the segments."""
        self.duration_ms = duration_ms
        if self.clip_metadata:
            self._calculate_segments()
        self.update()

    def set_current_position(self, position_ms: int, duration_ms: int):
        """Update the playback position; a positive ``duration_ms`` also refreshes the duration."""
        if duration_ms > 0:
            self.duration_ms = duration_ms
            if self.clip_metadata:
                self._calculate_segments()
        self.current_position_ms = position_ms
        self.update()

    def _ms_to_x(self, ms, width):
        """Map a time in ms to an x pixel coordinate clamped to ``[0, width]``."""
        return max(0, min(width, int((ms / self.duration_ms) * width)))

    def paintEvent(self, event):
        """Draw timeline with context and clip sections."""
        if self.duration_ms == 0:
            return

        painter = QPainter(self)
        try:
            self._paint_timeline(painter)
        finally:
            # End explicitly instead of relying on the Python wrapper being collected.
            painter.end()

    def _paint_timeline(self, painter):
        """Paint background, segments, playhead and labels with ``painter``."""
        painter.setRenderHint(QPainter.RenderHint.Antialiasing)
        width = self.width()
        height = self.height()

        # Draw timeline background
        painter.fillRect(0, 0, width, height, QColor(40, 40, 40))

        # A non-empty clip span means _calculate_segments produced usable segments.
        # (Testing clip_start_ms == 0 instead would hide clips that have no context.)
        if not self.clip_metadata or self.clip_end_ms <= self.clip_start_ms:
            # No usable metadata: draw a plain timeline.
            painter.fillRect(0, 0, width, height, QColor(60, 60, 60))
            self._draw_playhead(painter, width, height)
            return

        # Clamp to the widget so metadata that disagrees with the real duration
        # cannot yield negative widths or draw beyond the right edge.
        clip_start_x = self._ms_to_x(self.clip_start_ms, width)
        clip_end_x = max(clip_start_x, self._ms_to_x(self.clip_end_ms, width))
        context_before_width = clip_start_x
        clip_width = clip_end_x - clip_start_x
        context_after_start_x = clip_end_x
        context_after_width = width - clip_end_x

        grey = QColor(100, 100, 100)
        # Context before (grey), clip (green), context after (grey)
        painter.fillRect(0, 0, context_before_width, height, grey)
        painter.fillRect(clip_start_x, 0, clip_width, height, QColor(0, 200, 0))
        painter.fillRect(context_after_start_x, 0, context_after_width, height, grey)

        self._draw_playhead(painter, width, height)

        # Draw labels only where a segment is wide enough to hold them.
        painter.setPen(QPen(QColor(255, 255, 255), 1))
        painter.setFont(QFont("Arial", 8, QFont.Weight.Bold))
        align = Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter
        if clip_width > 100:
            painter.drawText(QRect(clip_start_x + 5, 2, clip_width - 10, height - 4),
                             align, "Clip to evaluate")
        if context_before_width > 50:
            painter.drawText(QRect(5, 2, context_before_width - 10, height - 4),
                             align, "Context")
        if context_after_width > 50:
            painter.drawText(QRect(context_after_start_x + 5, 2, context_after_width - 10, height - 4),
                             align, "Context")

    def _draw_playhead(self, painter, width, height):
        """Draw the white current-position line (skipped while the position is 0)."""
        if self.current_position_ms > 0:
            pos_x = self._ms_to_x(self.current_position_ms, width)
            painter.setPen(QPen(QColor(255, 255, 255), 2))
            painter.drawLine(pos_x, 0, pos_x, height)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class MatplotlibWidget(QWidget):
    """Widget for displaying matplotlib figures.

    Incoming figures are rasterized to PNG and shown as an image inside the
    widget's own figure. The caller's figure is kept in ``original_figure`` so
    it can be saved later at full quality.
    """

    def __init__(self, parent=None, width=8, height=6, dpi=100):
        super().__init__(parent)
        self.figure = Figure(figsize=(width, height), dpi=dpi)
        self.canvas = FigureCanvasQTAgg(self.figure)
        self.original_figure = None  # Store original figure for saving

        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        layout.addWidget(self.canvas)
        self.setLayout(layout)

    def update_plot(self, fig):
        """Update the plot with a new figure.

        Matplotlib artists (especially collections from seaborn heatmaps) cannot
        be moved between figures, so ``fig`` is rendered to a PNG and that image
        is displayed instead. This avoids "artist in more than one figure" errors.

        Args:
            fig: A matplotlib ``Figure``, or any object exposing a
                ``savefig(fp, format=..., ...)`` method.
        """
        # Store the original figure for saving
        self.original_figure = fig
        self.figure.clear()

        try:
            from matplotlib.image import imread

            # Match the figure size
            if hasattr(fig, 'get_size_inches'):
                self.figure.set_size_inches(fig.get_size_inches())

            # savefig() renders through a temporary canvas when needed. Building a
            # FigureCanvasAgg(fig) would permanently re-attach the caller's figure
            # to a new canvas, and would fail for objects that only expose savefig.
            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', pad_inches=0.1)
                buf.seek(0)
                img = imread(buf)

            # Display the image in our figure
            ax = self.figure.add_subplot(111)
            ax.imshow(img, aspect='auto')
            ax.axis('off')

        except Exception as e:
            logger.error("Error updating plot: %s", e, exc_info=True)
            # On error, at least try to show something
            ax = self.figure.add_subplot(111)
            ax.text(0.5, 0.5, f"Error displaying plot:\n{str(e)}",
                    ha='center', va='center', transform=ax.transAxes)

        self.canvas.draw()

    def clear(self):
        """Clear the plot and drop the reference to the last figure shown."""
        self.original_figure = None
        self.figure.clear()
        self.canvas.draw()

    def get_figure(self):
        """Get the matplotlib figure owned by this widget (the rasterized display)."""
        return self.figure
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class ScrollablePlotContainer(QWidget):
    """Container widget with a scrollable plot and a floating save button."""

    # Gap (px) between the floating save button and the top/right edges.
    _SAVE_BUTTON_MARGIN = 10

    # Export formats recognised from a typed file extension (lower-case).
    _SAVE_FORMATS = {'.png': 'png', '.pdf': 'pdf', '.svg': 'svg'}

    def __init__(self, plot_widget, parent=None):
        super().__init__(parent)
        self.plot_widget = plot_widget
        self.current_figure = None  # Store current figure for saving

        scroll_area = QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(plot_widget)
        scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)
        scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)

        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(0, 0, 0, 0)
        main_layout.setSpacing(0)
        main_layout.addWidget(scroll_area)
        self.setLayout(main_layout)

        # The button is a layout-free child of this container, so it floats over
        # the scroll area. resizeEvent keeps it pinned to the top-right corner.
        # A full-size transparent overlay would swallow mouse events meant for
        # the plot, and would not follow parent resizes.
        self.save_btn = QPushButton("Save Plot", self)
        self.save_btn.setStyleSheet("""
            QPushButton {
                background-color: #28a745;
                color: white;
                border: none;
                padding: 8px 16px;
                border-radius: 4px;
                font-weight: bold;
            }
            QPushButton:hover {
                background-color: #218838;
            }
            QPushButton:pressed {
                background-color: #1e7e34;
            }
        """)
        self.save_btn.clicked.connect(self._save_plot)
        self.save_btn.setFixedSize(120, 35)
        self._position_save_button()

    def resizeEvent(self, event):
        """Keep the floating save button in the top-right corner."""
        super().resizeEvent(event)
        self._position_save_button()

    def _position_save_button(self):
        """Move the save button to the top-right corner and keep it above the plot."""
        margin = self._SAVE_BUTTON_MARGIN
        x = max(0, self.width() - self.save_btn.width() - margin)
        self.save_btn.move(x, margin)
        self.save_btn.raise_()

    def _save_plot(self):
        """Save the current plot as PNG, PDF or SVG.

        A recognised typed extension (.png/.pdf/.svg) decides the format. Only
        when the name has no such extension does the selected dialog filter
        decide (PNG by default).
        """
        import os

        if self.current_figure is None:
            QMessageBox.warning(self, "No Plot", "No plot to save. Please generate a plot first.")
            return

        file_path, selected_filter = QFileDialog.getSaveFileName(
            self, "Save Plot", "plot",
            "PNG Files (*.png);;PDF Files (*.pdf);;SVG Files (*.svg)"
        )
        if not file_path:
            return

        try:
            # The extension wins over the filter, so typing "x.pdf" while the
            # default PNG filter is selected still writes a PDF.
            fmt = self._SAVE_FORMATS.get(os.path.splitext(file_path)[1].lower())
            if fmt is None:
                if selected_filter.startswith("PDF"):
                    fmt = 'pdf'
                elif selected_filter.startswith("SVG"):
                    fmt = 'svg'
                else:
                    fmt = 'png'

            if isinstance(self.plot_widget, PlotlyWidget):
                # Only raster output benefits from supersampling.
                extra = {'scale': 2} if fmt == 'png' else {}
                pio.write_image(self.current_figure, file_path, format=fmt,
                                width=1200, height=800, **extra)
            elif isinstance(self.plot_widget, MatplotlibWidget):
                # Prefer the caller's original figure: the widget's own figure
                # only holds a rasterized snapshot of it.
                original = getattr(self.plot_widget, 'original_figure', None)
                if original is not None:
                    target = original
                elif hasattr(self.current_figure, 'savefig'):
                    target = self.current_figure
                else:
                    # Fallback: save the widget's figure
                    target = self.plot_widget.figure
                target.savefig(file_path, format=fmt, dpi=300, bbox_inches='tight')
            else:
                # Never report success when nothing was written.
                raise TypeError(
                    f"Saving is not supported for {type(self.plot_widget).__name__}"
                )

            QMessageBox.information(self, "Success", f"Plot saved to:\n{file_path}")
        except Exception as e:
            logger.error("Error saving plot: %s", e, exc_info=True)
            QMessageBox.critical(self, "Error", f"Error saving plot:\n{str(e)}")

    def update_plot(self, fig):
        """Update the plot and store the figure."""
        self.current_figure = fig
        if hasattr(self.plot_widget, 'update_plot'):
            self.plot_widget.update_plot(fig)

    def clear(self):
        """Clear the plot."""
        self.current_figure = None
        if hasattr(self.plot_widget, 'clear'):
            self.plot_widget.clear()
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class PlotlyWidget(QWidget):
    """Widget for displaying plotly figures using HTML export with full interactivity.

    With PyQt6-WebEngine available, the figure is rendered as interactive HTML
    in a QWebEngineView. Otherwise a static PNG from ``plotly.io.to_image`` is
    shown in a QLabel. Clicking a point whose ``customdata`` carries a snippet
    id emits ``point_clicked`` and calls the callback registered with
    ``set_click_callback``.
    """

    # Signal emitted when a point is clicked (snippet_id)
    point_clicked = pyqtSignal(str)

    # Click-handler script injected into the exported HTML. __PRELUDE__ and
    # __DISPATCH__ are filled per transport by _inject_click_handler. Plain
    # str.replace is used, so JavaScript braces need no escaping.
    _CLICK_JS_TEMPLATE = """
<script>
__PRELUDE__

// Wait for Plotly to be loaded and plot to be ready
function setupClickHandler() {
    if (typeof Plotly === 'undefined') {
        setTimeout(setupClickHandler, 100);
        return;
    }

    var checkPlot = setInterval(function() {
        var plotDivs = document.getElementsByClassName('plotly-graph-div');
        if (plotDivs.length > 0) {
            var plotDiv = plotDivs[0];
            if (plotDiv && (plotDiv.data || plotDiv._fullLayout)) {
                clearInterval(checkPlot);

                // Attach click handler using Plotly's event system
                plotDiv.on('plotly_click', function(data) {
                    if (!(data && data.points && data.points.length > 0)) {
                        return;
                    }
                    var point = data.points[0];
                    // snippet_id is the first customdata entry (one nested level unwrapped)
                    var snippet_id = point.customdata;
                    if (Array.isArray(snippet_id)) {
                        snippet_id = snippet_id.length > 0 ? snippet_id[0] : null;
                        if (Array.isArray(snippet_id)) {
                            snippet_id = snippet_id.length > 0 ? snippet_id[0] : null;
                        }
                    }
                    if (snippet_id === undefined || snippet_id === null || snippet_id === '') {
                        return;
                    }
                    __DISPATCH__
                });
            }
        }
    }, 100);

    // Stop polling for the plot after 10 seconds
    setTimeout(function() {
        clearInterval(checkPlot);
    }, 10000);
}

if (document.readyState === 'loading') {
    document.addEventListener('DOMContentLoaded', setupClickHandler);
} else {
    setupClickHandler();
}
</script>
"""

    def __init__(self, parent=None):
        super().__init__(parent)
        self._click_callback = None  # Callback for point clicks
        self.click_bridge = None
        self.web_channel = None
        try:
            from PyQt6.QtWebEngineWidgets import QWebEngineView
            from PyQt6.QtWebEngineCore import QWebEngineSettings

            self.web_view = QWebEngineView()

            # Configure settings for maximum interactivity
            settings = self.web_view.settings()
            try:
                settings.setAttribute(QWebEngineSettings.WebAttribute.JavascriptEnabled, True)
                settings.setAttribute(QWebEngineSettings.WebAttribute.LocalContentCanAccessRemoteUrls, True)
                settings.setAttribute(QWebEngineSettings.WebAttribute.LocalContentCanAccessFileUrls, True)
                settings.setAttribute(QWebEngineSettings.WebAttribute.ErrorPageEnabled, True)
                settings.setAttribute(QWebEngineSettings.WebAttribute.PluginsEnabled, True)
            except AttributeError:
                pass

            # Set up QWebChannel for JavaScript-Python communication
            try:
                from PyQt6.QtWebChannel import QWebChannel
                from PyQt6.QtCore import QObject, pyqtSlot

                class ClickBridge(QObject):
                    def __init__(self, callback):
                        super().__init__()
                        self.callback = callback

                    @pyqtSlot(str)
                    def on_click(self, snippet_id):
                        if self.callback:
                            self.callback(snippet_id)

                # Always route through _handle_snippet_click so both the
                # point_clicked signal and the registered callback fire.
                self.click_bridge = ClickBridge(self._handle_snippet_click)
                self.web_channel = QWebChannel()
                self.web_channel.registerObject('bridge', self.click_bridge)
                self.web_view.page().setWebChannel(self.web_channel)
            except ImportError:
                # QWebChannel not available, fall back to URL scheme
                self.click_bridge = None
                self.web_channel = None

            layout = QVBoxLayout()
            layout.setContentsMargins(0, 0, 0, 0)
            layout.addWidget(self.web_view)
            self.setLayout(layout)
            self.use_webview = True

            import tempfile
            self.temp_dir = tempfile.gettempdir()

        except ImportError as e:
            logger.warning("QWebEngineWidgets not available: %s. Plotly plots will be static images.", e)
            # Fallback to static image if WebEngine not available
            from PyQt6.QtWidgets import QLabel
            self.image_label = QLabel()
            self.image_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
            self.image_label.setText("Plotly interactive plots require PyQt6.QtWebEngineWidgets.\nPlease install: pip install PyQt6-WebEngine")
            layout = QVBoxLayout()
            layout.addWidget(self.image_label)
            self.setLayout(layout)
            self.use_webview = False

    def update_plot(self, fig):
        """Update the plot with a plotly figure.

        In web-view mode the figure's layout is updated in place (autosize,
        closest-point hover, pan drag mode). The figure is then exported to a
        temporary HTML file that the view loads. Without QtWebEngine, a static
        PNG is shown instead.
        """
        if not self.use_webview:
            self._render_static_image(fig)
            return

        try:
            self._render_interactive(fig)
        except Exception as e:
            logger.error("Error updating plotly plot: %s", e, exc_info=True)
            try:
                # setHtml() cannot display content larger than 2 MB, and inline
                # plotly.js alone exceeds that, so reference the CDN copy here.
                html = pio.to_html(fig, include_plotlyjs='cdn')
                self.web_view.setHtml(html)
            except Exception as e2:
                logger.error("Error with setHtml fallback: %s", e2)

    def _render_interactive(self, fig):
        """Export ``fig`` to a temporary HTML file and load it in the web view."""
        import os
        import tempfile
        from PyQt6.QtCore import QUrl

        # Make layout responsive and ensure interactivity
        fig.update_layout(
            autosize=True,
            hovermode='closest',
            dragmode='pan'  # Allow panning by default
        )

        # Embedding plotly.js inline is more reliable than the CDN path
        # when QtWebEngine loads from a local file URL.
        html = pio.to_html(
            fig,
            include_plotlyjs='inline',
            div_id='plotly-div',
            config={
                'displayModeBar': True,  # Show toolbar
                'displaylogo': False,  # Hide plotly logo
                'modeBarButtonsToAdd': ['pan2d', 'select2d', 'lasso2d', 'resetScale2d', 'zoomIn2d', 'zoomOut2d'],
                'toImageButtonOptions': {
                    'format': 'png',
                    'filename': 'plot',
                    'height': None,
                    'width': None,
                    'scale': 1
                },
                'responsive': True,  # Enable responsive behavior
                'staticPlot': False,  # Ensure interactivity is enabled
                'doubleClick': 'reset',  # Double-click to reset zoom
                'showTips': True,  # Show interaction tips
                'showLink': False  # Hide "Edit chart" link
            }
        )

        # With a web channel, clicks are always wired so point_clicked works
        # without a callback. The URL-scheme fallback is injected only when a
        # callback is set.
        if self._click_callback or self.web_channel is not None:
            html = self._inject_click_handler(html)

        # Create the temp file only once the HTML exists, so a failed export
        # does not leave an orphaned file behind.
        fd, temp_path = tempfile.mkstemp(suffix='.html', dir=self.temp_dir)
        try:
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                f.write(html)
        except BaseException:
            self._remove_temp_file(temp_path)
            raise

        if not self.web_view.isVisible():
            self.web_view.show()
        if self.web_view.width() < 100 or self.web_view.height() < 100:
            self.web_view.setMinimumSize(400, 300)

        # Loading from a file URL is more reliable than setHtml when
        # injected JavaScript (click handlers) must run.
        self.web_view.setUrl(QUrl.fromLocalFile(temp_path))

        # The previous page has been replaced; its backing file is no longer needed.
        previous = getattr(self, '_last_temp_file', None)
        if previous and previous != temp_path:
            self._remove_temp_file(previous)
        self._last_temp_file = temp_path

    def _render_static_image(self, fig):
        """Static image fallback when QtWebEngine is not available."""
        try:
            img_bytes = pio.to_image(fig, format='png', width=1200, height=800)
            from PyQt6.QtGui import QPixmap
            pixmap = QPixmap()
            pixmap.loadFromData(img_bytes)
            self.image_label.setPixmap(pixmap)
        except Exception as e:
            self.image_label.setText(f"Error rendering plot: {str(e)}")

    @staticmethod
    def _remove_temp_file(path):
        """Delete ``path`` if it exists; failures are logged rather than raised."""
        import os
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.debug("Could not remove temporary plot file %s: %s", path, e)

    def set_click_callback(self, callback):
        """Set callback function for point clicks. Callback receives snippet_id (str)."""
        # The web-channel bridge already routes through _handle_snippet_click,
        # which reads this attribute at click time.
        self._click_callback = callback

    def _handle_snippet_click(self, snippet_id):
        """Emit ``point_clicked`` and invoke the registered callback for a click."""
        self.point_clicked.emit(snippet_id)
        if self._click_callback:
            self._click_callback(snippet_id)

    def _inject_click_handler(self, html):
        """Inject JavaScript to handle plotly_click events.

        Uses the QWebChannel bridge when available. Otherwise it navigates to a
        ``snippet://<id>`` URL (less reliable, but works without QWebChannel).
        The script goes in front of the last ``</body>`` tag, or is appended if
        there is none.
        """
        if self.web_channel is not None:
            head = '<script src="qrc:///qtwebchannel/qwebchannel.js"></script>'
            prelude = (
                "var bridge = null;\n"
                "new QWebChannel(qt.webChannelTransport, function(channel) {\n"
                "    bridge = channel.objects.bridge;\n"
                "});"
            )
            dispatch = "if (bridge) { bridge.on_click(String(snippet_id)); }"
        else:
            head = ''
            prelude = ''
            dispatch = "window.location.href = 'snippet://' + encodeURIComponent(String(snippet_id));"

        script = (self._CLICK_JS_TEMPLATE
                  .replace('__PRELUDE__', prelude)
                  .replace('__DISPATCH__', dispatch))
        injection = head + script

        # Only touch the final closing body tag; any other '</body>' text is left alone.
        before, sep, after = html.rpartition('</body>')
        if not sep:
            return html + injection
        return before + injection + '</body>' + after

    def clear(self):
        """Clear the plot."""
        if self.use_webview:
            self.web_view.setHtml("")
        else:
            self.image_label.clear()

    def __del__(self):
        """Cleanup temporary files when widget is destroyed."""
        path = getattr(self, '_last_temp_file', None)
        if not path:
            return
        try:
            import os
            os.unlink(path)
        except Exception:
            # Best effort only: this can run during interpreter shutdown, and the
            # file may already be gone or still locked by the web engine.
            pass
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
class ScrollablePlotWidget(QWidget):
    """Wrap a plot widget in a centred, resizable scroll area (handy for large plots).

    ``update_plot`` and ``clear`` are forwarded to the wrapped widget whenever
    it defines them.
    """

    def __init__(self, plot_widget: QWidget, parent=None):
        super().__init__(parent)

        area = QScrollArea()
        area.setWidget(plot_widget)
        area.setWidgetResizable(True)
        area.setAlignment(Qt.AlignmentFlag.AlignCenter)

        outer = QVBoxLayout()
        outer.setContentsMargins(0, 0, 0, 0)
        outer.addWidget(area)
        self.setLayout(outer)

        self.plot_widget = plot_widget

    def update_plot(self, fig):
        """Forward ``fig`` to the wrapped widget's ``update_plot`` if it has one."""
        if not hasattr(self.plot_widget, 'update_plot'):
            return
        self.plot_widget.update_plot(fig)

    def clear(self):
        """Forward to the wrapped widget's ``clear`` if it has one."""
        if not hasattr(self.plot_widget, 'clear'):
            return
        self.plot_widget.clear()
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
def create_plot_widget(plot_type='matplotlib', width=8, height=6, scrollable=False):
|
|
707
|
+
"""
|
|
708
|
+
Factory function to create appropriate plot widget.
|
|
709
|
+
|
|
710
|
+
Args:
|
|
711
|
+
plot_type: 'matplotlib' or 'plotly'
|
|
712
|
+
width: Figure width (for matplotlib)
|
|
713
|
+
height: Figure height (for matplotlib)
|
|
714
|
+
scrollable: Whether to wrap in scrollable container
|
|
715
|
+
|
|
716
|
+
Returns:
|
|
717
|
+
Plot widget instance
|
|
718
|
+
"""
|
|
719
|
+
if plot_type == 'matplotlib':
|
|
720
|
+
widget = MatplotlibWidget(width=width, height=height)
|
|
721
|
+
elif plot_type == 'plotly':
|
|
722
|
+
widget = PlotlyWidget()
|
|
723
|
+
else:
|
|
724
|
+
raise ValueError(f"Unknown plot_type: {plot_type}")
|
|
725
|
+
|
|
726
|
+
if scrollable:
|
|
727
|
+
return ScrollablePlotWidget(widget)
|
|
728
|
+
return widget
|
|
729
|
+
|