singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
@@ -0,0 +1,729 @@
1
+ """
2
+ Plot integration utilities for PyQt6.
3
+ Handles both matplotlib and plotly plots.
4
+ """
5
+
6
+ import logging
7
+ from PyQt6.QtWidgets import (
8
+ QWidget, QVBoxLayout, QScrollArea, QPushButton, QFileDialog, QMessageBox
9
+ )
10
+ from PyQt6.QtCore import Qt, pyqtSignal
11
+ from PyQt6.QtGui import QPainter, QColor, QPen, QFont
12
+ from PyQt6.QtCore import QRect
13
+ import matplotlib
14
+ matplotlib.use('QtAgg') # Use Qt backend
15
+ from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg
16
+ from matplotlib.figure import Figure
17
+ import plotly.graph_objects as go
18
+ import plotly.io as pio
19
+ import io
20
+ from PIL import Image
21
+
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
23
+
24
+
class TimelineWidget(QWidget):
    """Playback timeline for an extracted clip: grey context, green clip, white playhead.

    The extracted video is laid out as ``context_before + clip + context_after``,
    where both context sections are ``clip_metadata['context_frames']`` frames
    long. Segment boundaries (in ms) are derived from the metadata keys ``fps``,
    ``start_frame``, ``end_frame`` (inclusive) and ``context_frames`` once the
    video duration is known. Without usable metadata a plain bar with a playhead
    is drawn instead.
    """

    # Defaults used when a metadata key is missing or None.
    _DEFAULT_FPS = 30
    _DEFAULT_CONTEXT_FRAMES = 30
    # Frame length (ms) used when the metadata fps is not positive.
    _FALLBACK_FRAME_MS = 33.33

    def __init__(self, parent=None, clip_metadata: dict = None):
        super().__init__(parent)
        self.clip_metadata = clip_metadata or {}
        self.duration_ms = 0
        self.current_position_ms = 0
        self.context_start_ms = 0
        self.clip_start_ms = 0
        self.clip_end_ms = 0
        self.context_end_ms = 0

        if clip_metadata:
            self._calculate_segments()

    def _meta(self, key, default):
        """Return ``clip_metadata[key]``, or ``default`` if it is missing or None."""
        value = self.clip_metadata.get(key)
        return default if value is None else value

    def _calculate_segments(self):
        """Recompute the segment boundaries (ms) from ``clip_metadata``.

        Does nothing until both metadata and a non-zero ``duration_ms`` exist.
        """
        if not self.clip_metadata or self.duration_ms == 0:
            return

        fps = self._meta('fps', self._DEFAULT_FPS)
        start_frame = self._meta('start_frame', 0)
        end_frame = self._meta('end_frame', 0)
        context_frames = self._meta('context_frames', self._DEFAULT_CONTEXT_FRAMES)

        clip_frames = end_frame - start_frame + 1  # end_frame is inclusive
        if clip_frames + 2 * context_frames == 0:
            return

        frame_ms = 1000.0 / fps if fps > 0 else self._FALLBACK_FRAME_MS

        self.context_start_ms = 0
        self.clip_start_ms = context_frames * frame_ms
        self.clip_end_ms = (context_frames + clip_frames) * frame_ms
        self.context_end_ms = self.duration_ms

    def set_duration(self, duration_ms: int):
        """Set the total video duration (ms), recompute segments and repaint."""
        self.duration_ms = duration_ms
        if self.clip_metadata:
            self._calculate_segments()
        self.update()

    def set_current_position(self, position_ms: int, duration_ms: int):
        """Update the playhead position; a positive ``duration_ms`` also refreshes segments."""
        if duration_ms > 0:
            self.duration_ms = duration_ms
            if self.clip_metadata:
                self._calculate_segments()
        self.current_position_ms = position_ms
        self.update()

    def _ms_to_x(self, ms, width):
        """Map a time in ms to an x pixel, clamped to ``[0, width]``."""
        fraction = min(max(ms / self.duration_ms, 0.0), 1.0)
        return int(fraction * width)

    def _draw_playhead(self, painter, width, height):
        """Draw the white playback-position line (not drawn at position 0)."""
        if self.current_position_ms > 0:
            pos_x = int((self.current_position_ms / self.duration_ms) * width)
            painter.setPen(QPen(QColor(255, 255, 255), 2))
            painter.drawLine(pos_x, 0, pos_x, height)

    def paintEvent(self, event):
        """Paint the background, context/clip segments, labels and playhead."""
        if self.duration_ms <= 0:
            return

        painter = QPainter(self)
        try:
            painter.setRenderHint(QPainter.RenderHint.Antialiasing)
            width = self.width()
            height = self.height()

            painter.fillRect(0, 0, width, height, QColor(40, 40, 40))

            # A non-empty clip range means segments were computed. clip_start_ms
            # alone cannot be the test: it is legitimately 0 when context_frames == 0.
            if not self.clip_metadata or self.clip_end_ms <= self.clip_start_ms:
                painter.fillRect(0, 0, width, height, QColor(60, 60, 60))
                self._draw_playhead(painter, width, height)
                return

            # Pixel edges are clamped to the widget so a clip that overruns the
            # actual video duration cannot produce negative-width rectangles.
            clip_start_x = self._ms_to_x(self.clip_start_ms, width)
            clip_end_x = self._ms_to_x(self.clip_end_ms, width)
            clip_width = clip_end_x - clip_start_x
            context_after_width = width - clip_end_x

            grey = QColor(100, 100, 100)
            painter.fillRect(0, 0, clip_start_x, height, grey)
            painter.fillRect(clip_start_x, 0, clip_width, height, QColor(0, 200, 0))
            painter.fillRect(clip_end_x, 0, context_after_width, height, grey)

            self._draw_playhead(painter, width, height)

            painter.setPen(QPen(QColor(255, 255, 255), 1))
            painter.setFont(QFont("Arial", 8, QFont.Weight.Bold))
            align = Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter

            # Labels are only drawn where the segment is wide enough to hold them.
            if clip_width > 100:
                painter.drawText(QRect(clip_start_x + 5, 2, clip_width - 10, height - 4),
                                 align, "Clip to evaluate")
            if clip_start_x > 50:
                painter.drawText(QRect(5, 2, clip_start_x - 10, height - 4),
                                 align, "Context")
            if context_after_width > 50:
                painter.drawText(QRect(clip_end_x + 5, 2, context_after_width - 10, height - 4),
                                 align, "Context")
        finally:
            painter.end()
148
+
149
+
class MatplotlibWidget(QWidget):
    """Widget that shows a snapshot of a matplotlib figure on an embedded Qt canvas.

    Incoming figures are rasterised to PNG and drawn as an image on this
    widget's own figure. The caller's figure is left untouched and is kept in
    ``original_figure`` so it can be saved at full quality later.
    """

    def __init__(self, parent=None, width=8, height=6, dpi=100):
        super().__init__(parent)
        self.figure = Figure(figsize=(width, height), dpi=dpi)
        self.canvas = FigureCanvasQTAgg(self.figure)
        self.original_figure = None  # caller's figure, kept for saving

        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        layout.addWidget(self.canvas)
        self.setLayout(layout)

    def update_plot(self, fig):
        """Display ``fig`` by rendering it to PNG and showing that image.

        Matplotlib artists (e.g. collections from seaborn heatmaps) cannot be
        moved between figures, so the figure is rasterised instead of copied.
        ``fig.savefig`` is used rather than wrapping ``fig`` in a new
        ``FigureCanvasAgg``: constructing a canvas on a figure replaces
        ``fig.canvas`` for good, which would detach the caller's figure from
        whatever canvas it was shown on.

        On failure the error is logged and its message is drawn instead.
        """
        self.original_figure = fig
        self.figure.clear()

        try:
            from matplotlib.image import imread

            # Match the figure size.
            if hasattr(fig, 'get_size_inches'):
                self.figure.set_size_inches(fig.get_size_inches())

            with io.BytesIO() as buf:
                fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', pad_inches=0.1)
                buf.seek(0)
                img = imread(buf)

            ax = self.figure.add_subplot(111)
            ax.imshow(img, aspect='auto')
            ax.axis('off')

        except Exception as e:
            logger.error("Error updating plot: %s", e, exc_info=True)
            # Drop any half-built axes so the message is not drawn on top of them.
            self.figure.clear()
            ax = self.figure.add_subplot(111)
            ax.text(0.5, 0.5, f"Error displaying plot:\n{str(e)}",
                    ha='center', va='center', transform=ax.transAxes)

        self.canvas.draw()

    def clear(self):
        """Clear the displayed plot."""
        self.figure.clear()
        self.canvas.draw()

    def get_figure(self):
        """Return the widget's own (display) matplotlib figure."""
        return self.figure
216
+
217
+
class _FloatingButtonHost(QWidget):
    """Container that keeps one child button pinned to its top-right corner.

    Unlike a full-size transparent overlay, only the button itself sits on top
    of the container's other children, so mouse input elsewhere still reaches
    the plot underneath.
    """

    def __init__(self, button, margin=10, parent=None):
        super().__init__(parent)
        self._button = button
        self._margin = margin
        button.setParent(self)

    def resizeEvent(self, event):
        super().resizeEvent(event)
        self._button.move(self.width() - self._button.width() - self._margin, self._margin)
        # Keep the button stacked above siblings that the layout added later.
        self._button.raise_()


class ScrollablePlotContainer(QWidget):
    """Scrollable plot view with a floating "Save Plot" button.

    ``update_plot`` remembers the figure so ``_save_plot`` can export it as
    PNG, PDF or SVG. Plotly figures are exported with ``plotly.io.write_image``;
    matplotlib figures with ``Figure.savefig``.
    """

    def __init__(self, plot_widget, parent=None):
        super().__init__(parent)
        self.plot_widget = plot_widget
        self.current_figure = None  # last figure passed to update_plot; used for saving

        scroll_area = QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(plot_widget)
        scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)
        scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded)

        self.save_btn = QPushButton("Save Plot")
        self.save_btn.setStyleSheet("""
            QPushButton {
                background-color: #28a745;
                color: white;
                border: none;
                padding: 8px 16px;
                border-radius: 4px;
                font-weight: bold;
            }
            QPushButton:hover {
                background-color: #218838;
            }
            QPushButton:pressed {
                background-color: #1e7e34;
            }
        """)
        self.save_btn.clicked.connect(self._save_plot)
        self.save_btn.setFixedSize(120, 35)

        # The host floats the save button over the top-right of the scroll area.
        main_widget = _FloatingButtonHost(self.save_btn)
        main_widget_layout = QVBoxLayout()
        main_widget_layout.setContentsMargins(0, 0, 0, 0)
        main_widget_layout.addWidget(scroll_area)
        main_widget.setLayout(main_widget_layout)
        self.save_btn.raise_()

        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(0, 0, 0, 0)
        main_layout.setSpacing(0)
        main_layout.addWidget(main_widget)
        self.setLayout(main_layout)

    @staticmethod
    def _resolve_save_target(file_path, selected_filter):
        """Return ``(path, fmt)`` for a save-dialog result.

        An extension typed by the user (.png/.pdf/.svg, any case) decides the
        format. Otherwise the format comes from the selected filter (PNG when
        unrecognised) and the matching extension is appended, so the file
        name always agrees with its contents.
        """
        import os

        extension = os.path.splitext(file_path)[1].lower().lstrip('.')
        if extension in ('png', 'pdf', 'svg'):
            return file_path, extension
        by_filter = {'PNG': 'png', 'PDF': 'pdf', 'SVG': 'svg'}
        fmt = by_filter.get((selected_filter or '')[:3].upper(), 'png')
        return f"{file_path}.{fmt}", fmt

    def _save_plot(self):
        """Ask for a destination and save the current plot as PNG, PDF or SVG."""
        if self.current_figure is None:
            QMessageBox.warning(self, "No Plot", "No plot to save. Please generate a plot first.")
            return

        file_path, selected_filter = QFileDialog.getSaveFileName(
            self, "Save Plot", "plot",
            "PNG Files (*.png);;PDF Files (*.pdf);;SVG Files (*.svg)"
        )
        if not file_path:
            return

        file_path, fmt = self._resolve_save_target(file_path, selected_filter)

        try:
            if isinstance(self.plot_widget, PlotlyWidget):
                # Fixed 1200x800 export canvas; PNG output is additionally scaled 2x.
                extra = {'scale': 2} if fmt == 'png' else {}
                pio.write_image(self.current_figure, file_path, format=fmt,
                                width=1200, height=800, **extra)
            elif isinstance(self.plot_widget, MatplotlibWidget):
                # Prefer the caller's original figure: the widget's own figure
                # only holds a rasterised copy of it.
                figure = getattr(self.plot_widget, 'original_figure', None)
                if figure is None and hasattr(self.current_figure, 'savefig'):
                    figure = self.current_figure
                if figure is None:
                    figure = self.plot_widget.figure
                figure.savefig(file_path, format=fmt, dpi=300, bbox_inches='tight')
            else:
                # Unknown widget type: report an error rather than a false "saved".
                raise TypeError(f"Saving is not supported for {type(self.plot_widget).__name__}")

            QMessageBox.information(self, "Success", f"Plot saved to:\n{file_path}")
        except Exception as e:
            logger.error("Error saving plot: %s", e, exc_info=True)
            QMessageBox.critical(self, "Error", f"Error saving plot:\n{str(e)}")

    def update_plot(self, fig):
        """Remember ``fig`` for saving and forward it to the wrapped plot widget."""
        self.current_figure = fig
        if hasattr(self.plot_widget, 'update_plot'):
            self.plot_widget.update_plot(fig)

    def clear(self):
        """Forget the current figure and clear the wrapped plot widget."""
        self.current_figure = None
        if hasattr(self.plot_widget, 'clear'):
            self.plot_widget.clear()
355
+
356
+
class PlotlyWidget(QWidget):
    """Widget that renders plotly figures interactively in a QWebEngineView.

    If PyQt6-WebEngine cannot be imported, it falls back to a QLabel that
    shows a static PNG rendered with ``plotly.io.to_image``.

    Point clicks: when a click callback is set (or ``point_clicked`` has
    receivers), a JavaScript ``plotly_click`` handler is injected into the
    page. It reads a snippet id from the clicked point's ``customdata`` and
    sends it to Python over QWebChannel. On the Python side the callback is
    invoked and ``point_clicked`` is emitted.
    """

    # Signal emitted when a point is clicked (snippet_id)
    point_clicked = pyqtSignal(str)

    # Page-side click handling shared by both transports. It expects a
    # ``sendSnippet(id)`` function to be defined earlier in the same script.
    _CLICK_HANDLER_JS = """
function extractSnippetId(point) {
    var cd = point.customdata;
    if (cd === undefined || cd === null) {
        return null;
    }
    if (Array.isArray(cd)) {
        if (cd.length === 0) {
            return null;
        }
        // customdata may be nested one level deeper; take its first field.
        if (Array.isArray(cd[0])) {
            return cd[0].length > 0 ? cd[0][0] : null;
        }
        return cd[0];
    }
    return cd;
}

// Wait for Plotly to load and the plot div to be ready, then attach the handler.
function setupClickHandler() {
    if (typeof Plotly === 'undefined') {
        setTimeout(setupClickHandler, 100);
        return;
    }
    var checkPlot = setInterval(function() {
        var plotDivs = document.getElementsByClassName('plotly-graph-div');
        if (plotDivs.length === 0) {
            return;
        }
        var plotDiv = plotDivs[0];
        if (!(plotDiv.data || plotDiv._fullLayout)) {
            return;
        }
        clearInterval(checkPlot);
        plotDiv.on('plotly_click', function(data) {
            if (!data || !data.points || data.points.length === 0) {
                return;
            }
            var snippet_id = extractSnippetId(data.points[0]);
            if (snippet_id !== null && snippet_id !== undefined && snippet_id !== '') {
                sendSnippet(snippet_id);
            }
        });
    }, 100);
    // Stop polling if the plot has not appeared within 10 s.
    setTimeout(function() {
        clearInterval(checkPlot);
    }, 10000);
}

if (document.readyState === 'loading') {
    document.addEventListener('DOMContentLoaded', setupClickHandler);
} else {
    setupClickHandler();
}
"""

    def __init__(self, parent=None):
        super().__init__(parent)
        self._click_callback = None  # optional callable(snippet_id: str)
        self._last_temp_file = None  # HTML file currently loaded in the web view
        self.click_bridge = None
        self.web_channel = None
        self.use_webview = False
        try:
            self._init_web_view()
        except ImportError as e:
            logger.warning("QWebEngineWidgets not available: %s. Plotly plots will be static images.", e)
            self._init_static_view()

    def _init_web_view(self):
        """Build the QWebEngineView UI; raises ImportError before touching the widget if unavailable."""
        from PyQt6.QtWebEngineWidgets import QWebEngineView
        from PyQt6.QtWebEngineCore import QWebEngineSettings
        import tempfile

        self.web_view = QWebEngineView()

        # Enable each attribute individually so a Qt build that lacks one
        # still gets the others.
        settings = self.web_view.settings()
        web_attribute = QWebEngineSettings.WebAttribute
        for name in ('JavascriptEnabled', 'LocalContentCanAccessRemoteUrls',
                     'LocalContentCanAccessFileUrls', 'ErrorPageEnabled', 'PluginsEnabled'):
            attribute = getattr(web_attribute, name, None)
            if attribute is not None:
                settings.setAttribute(attribute, True)

        self.click_bridge, self.web_channel = self._create_click_bridge()

        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        layout.addWidget(self.web_view)
        self.setLayout(layout)
        self.use_webview = True

        # Rendered figures are written here as HTML files and loaded by URL.
        self.temp_dir = tempfile.gettempdir()

    def _create_click_bridge(self):
        """Expose a ``bridge`` object to page JavaScript through QWebChannel.

        Returns ``(bridge, channel)``, or ``(None, None)`` if QtWebChannel
        cannot be imported. The bridge always routes through
        ``_handle_snippet_click``.
        """
        try:
            from PyQt6.QtWebChannel import QWebChannel
            from PyQt6.QtCore import QObject, pyqtSlot
        except ImportError:
            return None, None

        class ClickBridge(QObject):
            def __init__(self, callback):
                super().__init__()
                self.callback = callback

            @pyqtSlot(str)
            def on_click(self, snippet_id):
                if self.callback:
                    self.callback(snippet_id)

        bridge = ClickBridge(self._handle_snippet_click)
        channel = QWebChannel()
        channel.registerObject('bridge', bridge)
        self.web_view.page().setWebChannel(channel)
        return bridge, channel

    def _init_static_view(self):
        """Fallback UI: a QLabel that shows PNG renders of each figure."""
        from PyQt6.QtWidgets import QLabel

        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.image_label.setText("Plotly interactive plots require PyQt6.QtWebEngineWidgets.\nPlease install: pip install PyQt6-WebEngine")
        layout = QVBoxLayout()
        layout.addWidget(self.image_label)
        self.setLayout(layout)
        self.use_webview = False

    def update_plot(self, fig):
        """Show a plotly figure: interactively if QtWebEngine is available, else as a PNG.

        Note: in web-view mode ``fig``'s layout is updated in place
        (autosize, closest-point hover, pan drag mode).
        """
        if self.use_webview:
            self._show_in_webview(fig)
        else:
            self._show_static(fig)

    def _wants_click_events(self):
        """True if a click callback is set or ``point_clicked`` has receivers."""
        if self._click_callback:
            return True
        try:
            return self.receivers(self.point_clicked) > 0
        except (TypeError, RuntimeError):
            return False

    def _show_in_webview(self, fig):
        """Write ``fig`` to a temporary HTML file and load it in the web view.

        Loading from a file URL is used instead of ``setHtml`` because the page
        embeds the full plotly.js bundle and injected JavaScript must run.
        """
        from PyQt6.QtCore import QUrl
        import os
        import tempfile

        temp_path = None
        try:
            fig.update_layout(
                autosize=True,
                hovermode='closest',
                dragmode='pan'  # allow panning by default
            )

            # include_plotlyjs=True embeds plotly.js in the page, which is more
            # reliable than a CDN when QtWebEngine loads from a local file URL.
            html = pio.to_html(
                fig,
                include_plotlyjs=True,
                div_id='plotly-div',
                config={
                    'displayModeBar': True,  # show toolbar
                    'displaylogo': False,  # hide plotly logo
                    'modeBarButtonsToAdd': ['pan2d', 'select2d', 'lasso2d', 'resetScale2d', 'zoomIn2d', 'zoomOut2d'],
                    'toImageButtonOptions': {
                        'format': 'png',
                        'filename': 'plot',
                        'height': None,
                        'width': None,
                        'scale': 1
                    },
                    'responsive': True,
                    'staticPlot': False,
                    'doubleClick': 'reset',  # double-click resets zoom
                    'showTips': True,
                    'showLink': False  # hide "Edit chart" link
                }
            )

            if self._wants_click_events():
                html = self._inject_click_handler(html)

            fd, temp_path = tempfile.mkstemp(suffix='.html', dir=self.temp_dir)
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                f.write(html)

            if not self.web_view.isVisible():
                self.web_view.show()
            if self.web_view.width() < 100 or self.web_view.height() < 100:
                self.web_view.setMinimumSize(400, 300)

            self.web_view.setUrl(QUrl.fromLocalFile(temp_path))

            # The previous page has been replaced, so its file is no longer needed.
            self._remove_temp_file(self._last_temp_file)
            self._last_temp_file = temp_path

        except Exception as e:
            logger.error("Error updating plotly plot: %s", e, exc_info=True)
            if temp_path is not None and temp_path != self._last_temp_file:
                self._remove_temp_file(temp_path)
            self._show_html_fallback(fig)

    def _show_html_fallback(self, fig):
        """Last-resort display through ``setHtml``.

        ``QWebEngineView.setHtml`` cannot display content larger than 2 MB, and
        the embedded plotly.js bundle alone is about 3 MB, so plotly.js is loaded
        from the CDN here (this requires network access).
        """
        try:
            html = pio.to_html(fig, include_plotlyjs='cdn')
            self.web_view.setHtml(html)
        except Exception as e2:
            logger.error("Error with setHtml fallback: %s", e2)

    def _show_static(self, fig):
        """Render ``fig`` to a 1200x800 PNG and show it in the fallback label."""
        try:
            img_bytes = pio.to_image(fig, format='png', width=1200, height=800)
            from PyQt6.QtGui import QPixmap
            pixmap = QPixmap()
            pixmap.loadFromData(img_bytes)
            self.image_label.setPixmap(pixmap)
        except Exception as e:
            logger.error("Error rendering static plotly image: %s", e, exc_info=True)
            self.image_label.setText(f"Error rendering plot: {str(e)}")

    def set_click_callback(self, callback):
        """Set callback function for point clicks. Callback receives snippet_id (str).

        If the current page was rendered without a click handler, clicks are
        only reported after the next ``update_plot``.
        """
        self._click_callback = callback

    def _handle_snippet_click(self, snippet_id):
        """Forward a clicked snippet id to the callback and emit ``point_clicked``."""
        if self._click_callback:
            self._click_callback(snippet_id)
        self.point_clicked.emit(snippet_id)

    def _inject_click_handler(self, html):
        """Return ``html`` with the plotly_click handler inserted before ``</body>``.

        With QWebChannel the snippet id is sent to ``bridge.on_click``;
        otherwise the page navigates to ``snippet://<id>``.
        """
        if getattr(self, 'web_channel', None) is not None:
            transport = (
                '<script src="qrc:///qtwebchannel/qwebchannel.js"></script>\n'
                '<script>\n'
                'var bridge = null;\n'
                'new QWebChannel(qt.webChannelTransport, function(channel) {\n'
                '    bridge = channel.objects.bridge;\n'
                '});\n'
                'function sendSnippet(snippet_id) {\n'
                '    if (bridge) {\n'
                '        bridge.on_click(String(snippet_id));\n'
                '    }\n'
                '}\n'
            )
        else:
            # NOTE(review): this widget installs no custom QWebEnginePage, so
            # nothing appears to intercept snippet:// navigation; this path
            # likely just navigates away from the plot.
            transport = (
                '<script>\n'
                'function sendSnippet(snippet_id) {\n'
                "    window.location.href = 'snippet://' + encodeURIComponent(String(snippet_id));\n"
                '}\n'
            )
        injection = transport + self._CLICK_HANDLER_JS + '</script>\n'

        # Insert before the last </body> only; a blanket replace could also hit
        # matching text inside the embedded plotly.js source.
        body_end = html.rfind('</body>')
        if body_end == -1:
            return html + injection
        return html[:body_end] + injection + html[body_end:]

    def clear(self):
        """Clear the plot."""
        if self.use_webview:
            self.web_view.setHtml("")
        else:
            self.image_label.clear()

    @staticmethod
    def _remove_temp_file(path):
        """Delete ``path`` if given; a missing file is ignored, other OS errors are logged."""
        if not path:
            return
        import os
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.debug("Could not remove temporary plot file %s: %s", path, e)

    def __del__(self):
        """Best-effort removal of the last temporary HTML file."""
        try:
            self._remove_temp_file(getattr(self, '_last_temp_file', None))
        except Exception:
            # Finalizers must not raise; module globals may already be gone at shutdown.
            pass
676
+
677
+
class ScrollablePlotWidget(QWidget):
    """Puts a plot widget inside a QScrollArea so oversized plots stay reachable.

    ``update_plot`` and ``clear`` are forwarded to the wrapped widget when it
    defines them.
    """

    def __init__(self, plot_widget: QWidget, parent=None):
        super().__init__(parent)

        area = QScrollArea()
        area.setWidget(plot_widget)
        area.setWidgetResizable(True)
        area.setAlignment(Qt.AlignmentFlag.AlignCenter)

        outer = QVBoxLayout()
        outer.setContentsMargins(0, 0, 0, 0)
        outer.addWidget(area)
        self.setLayout(outer)

        self.plot_widget = plot_widget

    def _forward(self, method_name, *args):
        """Call ``method_name(*args)`` on the wrapped widget if it has that attribute."""
        if hasattr(self.plot_widget, method_name):
            getattr(self.plot_widget, method_name)(*args)

    def update_plot(self, fig):
        """Pass ``fig`` on to the wrapped widget."""
        self._forward('update_plot', fig)

    def clear(self):
        """Clear the wrapped widget."""
        self._forward('clear')
704
+
705
+
def create_plot_widget(plot_type='matplotlib', width=8, height=6, scrollable=False):
    """
    Build a plot widget of the requested kind.

    Args:
        plot_type: Either 'matplotlib' or 'plotly'.
        width: Figure width in inches; only used for matplotlib.
        height: Figure height in inches; only used for matplotlib.
        scrollable: If True, return the widget wrapped in a ScrollablePlotWidget.

    Returns:
        The plot widget, or its scrollable wrapper.

    Raises:
        ValueError: If ``plot_type`` is not one of the supported names.
    """
    if plot_type not in ('matplotlib', 'plotly'):
        raise ValueError(f"Unknown plot_type: {plot_type}")

    if plot_type == 'matplotlib':
        plot = MatplotlibWidget(width=width, height=height)
    else:
        plot = PlotlyWidget()

    return ScrollablePlotWidget(plot) if scrollable else plot
729
+