coralnet-toolbox 0.0.74__py2.py3-none-any.whl → 0.0.75__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25):
  1. coralnet_toolbox/Explorer/QtDataItem.py +52 -22
  2. coralnet_toolbox/Explorer/QtExplorer.py +277 -1600
  3. coralnet_toolbox/Explorer/QtSettingsWidgets.py +101 -15
  4. coralnet_toolbox/Explorer/QtViewers.py +1568 -0
  5. coralnet_toolbox/Explorer/transformer_models.py +59 -0
  6. coralnet_toolbox/Explorer/yolo_models.py +112 -0
  7. coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +239 -147
  8. coralnet_toolbox/MachineLearning/VideoInference/YOLO3D/run.py +102 -16
  9. coralnet_toolbox/QtAnnotationWindow.py +16 -10
  10. coralnet_toolbox/QtImageWindow.py +3 -7
  11. coralnet_toolbox/Rasters/RasterTableModel.py +20 -0
  12. coralnet_toolbox/SAM/QtDeployGenerator.py +1 -4
  13. coralnet_toolbox/SAM/QtDeployPredictor.py +1 -3
  14. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +131 -106
  15. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +45 -3
  16. coralnet_toolbox/Tools/QtPolygonTool.py +42 -3
  17. coralnet_toolbox/Tools/QtRectangleTool.py +30 -0
  18. coralnet_toolbox/__init__.py +1 -1
  19. coralnet_toolbox/utilities.py +21 -0
  20. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.75.dist-info}/METADATA +6 -3
  21. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.75.dist-info}/RECORD +25 -22
  22. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.75.dist-info}/WHEEL +0 -0
  23. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.75.dist-info}/entry_points.txt +0 -0
  24. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.75.dist-info}/licenses/LICENSE.txt +0 -0
  25. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.75.dist-info}/top_level.txt +0 -0
@@ -7,30 +7,27 @@ import torch
7
7
 
8
8
  from ultralytics import YOLO
9
9
 
10
- from coralnet_toolbox.Icons import get_icon
11
- from coralnet_toolbox.utilities import pixmap_to_numpy
12
-
13
- from PyQt5.QtGui import QIcon, QPen, QColor, QPainter, QBrush, QPainterPath, QMouseEvent
14
- from PyQt5.QtCore import Qt, QTimer, QRect, QRectF, QPointF, pyqtSignal, QSignalBlocker, pyqtSlot, QEvent
15
- from PyQt5.QtWidgets import (QVBoxLayout, QHBoxLayout, QGraphicsView, QScrollArea,
16
- QGraphicsScene, QPushButton, QComboBox, QLabel, QWidget,
17
- QMainWindow, QSplitter, QGroupBox, QSlider, QMessageBox,
18
- QApplication, QGraphicsRectItem, QRubberBand, QMenu,
19
- QWidgetAction, QToolButton, QAction, QDoubleSpinBox)
20
-
10
+ from PyQt5.QtGui import QIcon
11
+ from PyQt5.QtCore import Qt, QSignalBlocker, pyqtSlot
12
+ from PyQt5.QtWidgets import (QVBoxLayout, QHBoxLayout, QPushButton, QWidget,
13
+ QMainWindow, QSplitter, QGroupBox, QMessageBox,
14
+ QApplication)
15
+
16
+ from coralnet_toolbox.Explorer.QtViewers import AnnotationViewer
17
+ from coralnet_toolbox.Explorer.QtViewers import EmbeddingViewer
21
18
  from coralnet_toolbox.Explorer.QtFeatureStore import FeatureStore
22
19
  from coralnet_toolbox.Explorer.QtDataItem import AnnotationDataItem
23
- from coralnet_toolbox.Explorer.QtDataItem import EmbeddingPointItem
24
- from coralnet_toolbox.Explorer.QtDataItem import AnnotationImageWidget
25
20
  from coralnet_toolbox.Explorer.QtSettingsWidgets import ModelSettingsWidget
26
- from coralnet_toolbox.Explorer.QtSettingsWidgets import SimilaritySettingsWidget
27
- from coralnet_toolbox.Explorer.QtSettingsWidgets import UncertaintySettingsWidget
28
- from coralnet_toolbox.Explorer.QtSettingsWidgets import MislabelSettingsWidget
29
21
  from coralnet_toolbox.Explorer.QtSettingsWidgets import EmbeddingSettingsWidget
30
22
  from coralnet_toolbox.Explorer.QtSettingsWidgets import AnnotationSettingsWidget
31
- from coralnet_toolbox.Explorer.QtSettingsWidgets import DuplicateSettingsWidget
32
23
 
33
- from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
24
+ from coralnet_toolbox.Explorer.yolo_models import is_yolo_model
25
+ from coralnet_toolbox.Explorer.transformer_models import is_transformer_model
26
+
27
+ from coralnet_toolbox.utilities import pixmap_to_numpy
28
+ from coralnet_toolbox.utilities import pixmap_to_pil
29
+
30
+ from coralnet_toolbox.Icons import get_icon
34
31
 
35
32
  from coralnet_toolbox.QtProgressBar import ProgressBar
36
33
 
@@ -56,1544 +53,6 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
56
53
 
57
54
  POINT_WIDTH = 3
58
55
 
59
- # ----------------------------------------------------------------------------------------------------------------------
60
- # Viewers
61
- # ----------------------------------------------------------------------------------------------------------------------
62
-
63
-
64
- class EmbeddingViewer(QWidget):
65
- """Custom QGraphicsView for interactive embedding visualization with an isolate mode."""
66
- selection_changed = pyqtSignal(list)
67
- reset_view_requested = pyqtSignal()
68
- find_mislabels_requested = pyqtSignal()
69
- mislabel_parameters_changed = pyqtSignal(dict)
70
- find_uncertain_requested = pyqtSignal()
71
- uncertainty_parameters_changed = pyqtSignal(dict)
72
- find_duplicates_requested = pyqtSignal()
73
- duplicate_parameters_changed = pyqtSignal(dict)
74
-
75
- def __init__(self, parent=None):
76
- """Initialize the EmbeddingViewer widget."""
77
- super(EmbeddingViewer, self).__init__(parent)
78
- self.explorer_window = parent
79
-
80
- self.graphics_scene = QGraphicsScene()
81
- self.graphics_scene.setSceneRect(-5000, -5000, 10000, 10000)
82
-
83
- self.graphics_view = QGraphicsView(self.graphics_scene)
84
- self.graphics_view.setRenderHint(QPainter.Antialiasing)
85
- self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
86
- self.graphics_view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
87
- self.graphics_view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
88
- self.graphics_view.setMinimumHeight(200)
89
-
90
- self.rubber_band = None
91
- self.rubber_band_origin = QPointF()
92
- self.selection_at_press = None
93
- self.points_by_id = {}
94
- self.previous_selection_ids = set()
95
-
96
- # State for isolate mode
97
- self.isolated_mode = False
98
- self.isolated_points = set()
99
-
100
- self.is_uncertainty_analysis_available = False
101
-
102
- self.animation_offset = 0
103
- self.animation_timer = QTimer()
104
- self.animation_timer.timeout.connect(self.animate_selection)
105
- self.animation_timer.setInterval(100)
106
-
107
- # New timer for virtualization
108
- self.view_update_timer = QTimer(self)
109
- self.view_update_timer.setSingleShot(True)
110
- self.view_update_timer.timeout.connect(self._update_visible_points)
111
-
112
- self.graphics_scene.selectionChanged.connect(self.on_selection_changed)
113
- self.setup_ui()
114
- self.graphics_view.mousePressEvent = self.mousePressEvent
115
- self.graphics_view.mouseDoubleClickEvent = self.mouseDoubleClickEvent
116
- self.graphics_view.mouseReleaseEvent = self.mouseReleaseEvent
117
- self.graphics_view.mouseMoveEvent = self.mouseMoveEvent
118
- self.graphics_view.wheelEvent = self.wheelEvent
119
-
120
- def setup_ui(self):
121
- """Set up the UI with toolbar layout and graphics view."""
122
- layout = QVBoxLayout(self)
123
- layout.setContentsMargins(0, 0, 0, 0)
124
-
125
- toolbar_layout = QHBoxLayout()
126
-
127
- # Isolate/Show All buttons
128
- self.isolate_button = QPushButton("Isolate Selection")
129
- self.isolate_button.setToolTip("Hide all non-selected points")
130
- self.isolate_button.clicked.connect(self.isolate_selection)
131
- toolbar_layout.addWidget(self.isolate_button)
132
-
133
- self.show_all_button = QPushButton("Show All")
134
- self.show_all_button.setToolTip("Show all embedding points")
135
- self.show_all_button.clicked.connect(self.show_all_points)
136
- toolbar_layout.addWidget(self.show_all_button)
137
-
138
- toolbar_layout.addWidget(self._create_separator())
139
-
140
- # Create a QToolButton to have both a primary action and a dropdown menu
141
- self.find_mislabels_button = QToolButton()
142
- self.find_mislabels_button.setText("Find Potential Mislabels")
143
- self.find_mislabels_button.setPopupMode(QToolButton.MenuButtonPopup) # Key change for split-button style
144
- self.find_mislabels_button.setToolButtonStyle(Qt.ToolButtonTextOnly)
145
- self.find_mislabels_button.setStyleSheet(
146
- "QToolButton::menu-indicator {"
147
- " subcontrol-position: right center;"
148
- " subcontrol-origin: padding;"
149
- " left: -4px;"
150
- " }"
151
- )
152
-
153
- # The primary action (clicking the button) triggers the analysis
154
- run_analysis_action = QAction("Find Potential Mislabels", self)
155
- run_analysis_action.triggered.connect(self.find_mislabels_requested.emit)
156
- self.find_mislabels_button.setDefaultAction(run_analysis_action)
157
-
158
- # The dropdown menu contains the settings
159
- mislabel_settings_widget = MislabelSettingsWidget()
160
- settings_menu = QMenu(self)
161
- widget_action = QWidgetAction(settings_menu)
162
- widget_action.setDefaultWidget(mislabel_settings_widget)
163
- settings_menu.addAction(widget_action)
164
- self.find_mislabels_button.setMenu(settings_menu)
165
-
166
- # Connect the widget's signal to the viewer's signal
167
- mislabel_settings_widget.parameters_changed.connect(self.mislabel_parameters_changed.emit)
168
- toolbar_layout.addWidget(self.find_mislabels_button)
169
-
170
- # Create a QToolButton for uncertainty analysis
171
- self.find_uncertain_button = QToolButton()
172
- self.find_uncertain_button.setText("Review Uncertain")
173
- self.find_uncertain_button.setToolTip(
174
- "Find annotations where the model is least confident.\n"
175
- "Requires a .pt classification model and 'Predictions' mode."
176
- )
177
- self.find_uncertain_button.setPopupMode(QToolButton.MenuButtonPopup)
178
- self.find_uncertain_button.setToolButtonStyle(Qt.ToolButtonTextOnly)
179
- self.find_uncertain_button.setStyleSheet(
180
- "QToolButton::menu-indicator { "
181
- "subcontrol-position: right center; "
182
- "subcontrol-origin: padding; "
183
- "left: -4px; }"
184
- )
185
-
186
- run_uncertainty_action = QAction("Review Uncertain", self)
187
- run_uncertainty_action.triggered.connect(self.find_uncertain_requested.emit)
188
- self.find_uncertain_button.setDefaultAction(run_uncertainty_action)
189
-
190
- uncertainty_settings_widget = UncertaintySettingsWidget()
191
- uncertainty_menu = QMenu(self)
192
- uncertainty_widget_action = QWidgetAction(uncertainty_menu)
193
- uncertainty_widget_action.setDefaultWidget(uncertainty_settings_widget)
194
- uncertainty_menu.addAction(uncertainty_widget_action)
195
- self.find_uncertain_button.setMenu(uncertainty_menu)
196
-
197
- uncertainty_settings_widget.parameters_changed.connect(self.uncertainty_parameters_changed.emit)
198
- toolbar_layout.addWidget(self.find_uncertain_button)
199
-
200
- # Create a QToolButton for duplicate detection
201
- self.find_duplicates_button = QToolButton()
202
- self.find_duplicates_button.setText("Find Duplicates")
203
- self.find_duplicates_button.setToolTip(
204
- "Find annotations that are likely duplicates based on feature similarity."
205
- )
206
- self.find_duplicates_button.setPopupMode(QToolButton.MenuButtonPopup)
207
- self.find_duplicates_button.setToolButtonStyle(Qt.ToolButtonTextOnly)
208
- self.find_duplicates_button.setStyleSheet(
209
- "QToolButton::menu-indicator { "
210
- "subcontrol-position: right center; "
211
- "subcontrol-origin: padding; "
212
- "left: -4px; }"
213
- )
214
-
215
- run_duplicates_action = QAction("Find Duplicates", self)
216
- run_duplicates_action.triggered.connect(self.find_duplicates_requested.emit)
217
- self.find_duplicates_button.setDefaultAction(run_duplicates_action)
218
-
219
- duplicate_settings_widget = DuplicateSettingsWidget()
220
- duplicate_menu = QMenu(self)
221
- duplicate_widget_action = QWidgetAction(duplicate_menu)
222
- duplicate_widget_action.setDefaultWidget(duplicate_settings_widget)
223
- duplicate_menu.addAction(duplicate_widget_action)
224
- self.find_duplicates_button.setMenu(duplicate_menu)
225
-
226
- duplicate_settings_widget.parameters_changed.connect(self.duplicate_parameters_changed.emit)
227
- toolbar_layout.addWidget(self.find_duplicates_button)
228
-
229
- # Add a stretch and separator
230
- toolbar_layout.addStretch()
231
- toolbar_layout.addWidget(self._create_separator())
232
-
233
- # Center on selection button
234
- self.center_on_selection_button = QPushButton()
235
- self.center_on_selection_button.setIcon(get_icon("target.png"))
236
- self.center_on_selection_button.setToolTip("Center view on selected point(s)")
237
- self.center_on_selection_button.clicked.connect(self.center_on_selection)
238
- toolbar_layout.addWidget(self.center_on_selection_button)
239
-
240
- # Home button to reset view
241
- self.home_button = QPushButton()
242
- self.home_button.setIcon(get_icon("home.png"))
243
- self.home_button.setToolTip("Reset view to fit all points")
244
- self.home_button.clicked.connect(self.reset_view)
245
- toolbar_layout.addWidget(self.home_button)
246
-
247
- layout.addLayout(toolbar_layout)
248
- layout.addWidget(self.graphics_view)
249
-
250
- self.placeholder_label = QLabel(
251
- "No embedding data available.\nPress 'Apply Embedding' to generate visualization."
252
- )
253
- self.placeholder_label.setAlignment(Qt.AlignCenter)
254
- self.placeholder_label.setStyleSheet("color: gray; font-size: 14px;")
255
- layout.addWidget(self.placeholder_label)
256
-
257
- self.show_placeholder()
258
- self._update_toolbar_state()
259
-
260
- def _create_separator(self):
261
- """Creates a vertical separator for the toolbar."""
262
- separator = QLabel("|")
263
- separator.setStyleSheet("color: gray; margin: 0 5px;")
264
- return separator
265
-
266
- def _schedule_view_update(self):
267
- """Schedules a delayed update of visible points to avoid performance issues."""
268
- self.view_update_timer.start(50) # 50ms delay
269
-
270
- def _update_visible_points(self):
271
- """Sets visibility for points based on whether they are in the viewport."""
272
- if self.isolated_mode or not self.points_by_id:
273
- return
274
-
275
- # Get the visible rectangle in scene coordinates
276
- visible_rect = self.graphics_view.mapToScene(self.graphics_view.viewport().rect()).boundingRect()
277
-
278
- # Add a buffer to make scrolling smoother by loading points before they enter the view
279
- buffer_x = visible_rect.width() * 0.2
280
- buffer_y = visible_rect.height() * 0.2
281
- buffered_visible_rect = visible_rect.adjusted(-buffer_x, -buffer_y, buffer_x, buffer_y)
282
-
283
- for point in self.points_by_id.values():
284
- point.setVisible(buffered_visible_rect.contains(point.pos()) or point.isSelected())
285
-
286
- @pyqtSlot()
287
- def isolate_selection(self):
288
- """Hides all points that are not currently selected."""
289
- selected_items = self.graphics_scene.selectedItems()
290
- if not selected_items or self.isolated_mode:
291
- return
292
-
293
- self.isolated_points = set(selected_items)
294
- self.graphics_view.setUpdatesEnabled(False)
295
- try:
296
- for point in self.points_by_id.values():
297
- point.setVisible(point in self.isolated_points)
298
- self.isolated_mode = True
299
- finally:
300
- self.graphics_view.setUpdatesEnabled(True)
301
-
302
- self._update_toolbar_state()
303
-
304
- @pyqtSlot()
305
- def show_all_points(self):
306
- """Shows all embedding points, exiting isolated mode."""
307
- if not self.isolated_mode:
308
- return
309
-
310
- self.isolated_mode = False
311
- self.isolated_points.clear()
312
- self.graphics_view.setUpdatesEnabled(False)
313
- try:
314
- # Instead of showing all, let the virtualization logic take over
315
- self._update_visible_points()
316
- finally:
317
- self.graphics_view.setUpdatesEnabled(True)
318
-
319
- self._update_toolbar_state()
320
-
321
- def _update_toolbar_state(self):
322
- """Updates toolbar buttons based on selection and isolation mode."""
323
- selection_exists = bool(self.graphics_scene.selectedItems())
324
- points_exist = bool(self.points_by_id)
325
-
326
- self.find_mislabels_button.setEnabled(points_exist)
327
- self.find_uncertain_button.setEnabled(points_exist and self.is_uncertainty_analysis_available)
328
- self.find_duplicates_button.setEnabled(points_exist)
329
- self.center_on_selection_button.setEnabled(points_exist and selection_exists)
330
-
331
- if self.isolated_mode:
332
- self.isolate_button.hide()
333
- self.show_all_button.show()
334
- else:
335
- self.isolate_button.show()
336
- self.show_all_button.hide()
337
- self.isolate_button.setEnabled(selection_exists)
338
-
339
- def reset_view(self):
340
- """Reset the view to fit all embedding points."""
341
- self.fit_view_to_points()
342
-
343
- def center_on_selection(self):
344
- """Centers the view on selected point(s) or maintains the current view if no points are selected."""
345
- selected_items = self.graphics_scene.selectedItems()
346
- if not selected_items:
347
- # No selection, show a message
348
- QMessageBox.information(self, "No Selection", "Please select one or more points first.")
349
- return
350
-
351
- # Create a bounding rect that encompasses all selected points
352
- selection_rect = None
353
-
354
- for item in selected_items:
355
- if isinstance(item, EmbeddingPointItem):
356
- # Get the item's bounding rect in scene coordinates
357
- item_rect = item.sceneBoundingRect()
358
-
359
- # Add padding around the point for better visibility
360
- padding = 50 # pixels
361
- item_rect = item_rect.adjusted(-padding, -padding, padding, padding)
362
-
363
- if selection_rect is None:
364
- selection_rect = item_rect
365
- else:
366
- selection_rect = selection_rect.united(item_rect)
367
-
368
- if selection_rect:
369
- # Add extra margin for better visibility
370
- margin = 20
371
- selection_rect = selection_rect.adjusted(-margin, -margin, margin, margin)
372
-
373
- # Fit the view to the selection rect
374
- self.graphics_view.fitInView(selection_rect, Qt.KeepAspectRatio)
375
-
376
- def show_placeholder(self):
377
- """Show the placeholder message and hide the graphics view."""
378
- self.graphics_view.setVisible(False)
379
- self.placeholder_label.setVisible(True)
380
- self.home_button.setEnabled(False)
381
- self.center_on_selection_button.setEnabled(False) # Disable center button
382
- self.find_mislabels_button.setEnabled(False)
383
- self.find_uncertain_button.setEnabled(False)
384
- self.find_duplicates_button.setEnabled(False)
385
-
386
- self.isolate_button.show()
387
- self.isolate_button.setEnabled(False)
388
- self.show_all_button.hide()
389
-
390
- def show_embedding(self):
391
- """Show the graphics view and hide the placeholder message."""
392
- self.graphics_view.setVisible(True)
393
- self.placeholder_label.setVisible(False)
394
- self.home_button.setEnabled(True)
395
- self._update_toolbar_state()
396
-
397
- # Delegate graphics view methods
398
- def setRenderHint(self, hint):
399
- """Set render hint for the graphics view."""
400
- self.graphics_view.setRenderHint(hint)
401
-
402
- def setDragMode(self, mode):
403
- """Set drag mode for the graphics view."""
404
- self.graphics_view.setDragMode(mode)
405
-
406
- def setTransformationAnchor(self, anchor):
407
- """Set transformation anchor for the graphics view."""
408
- self.graphics_view.setTransformationAnchor(anchor)
409
-
410
- def setResizeAnchor(self, anchor):
411
- """Set resize anchor for the graphics view."""
412
- self.graphics_view.setResizeAnchor(anchor)
413
-
414
- def mapToScene(self, point):
415
- """Map a point to the scene coordinates."""
416
- return self.graphics_view.mapToScene(point)
417
-
418
- def scale(self, sx, sy):
419
- """Scale the graphics view."""
420
- self.graphics_view.scale(sx, sy)
421
-
422
- def translate(self, dx, dy):
423
- """Translate the graphics view."""
424
- self.graphics_view.translate(dx, dy)
425
-
426
- def fitInView(self, rect, aspect_ratio):
427
- """Fit the view to a rectangle with aspect ratio."""
428
- self.graphics_view.fitInView(rect, aspect_ratio)
429
-
430
- def keyPressEvent(self, event):
431
- """Handles key presses for deleting selected points."""
432
- if event.key() in (Qt.Key_Delete, Qt.Key_Backspace) and event.modifiers() == Qt.ControlModifier:
433
- selected_items = self.graphics_scene.selectedItems()
434
- if not selected_items:
435
- super().keyPressEvent(event)
436
- return
437
-
438
- # Extract the central data items from the selected graphics points
439
- data_items_to_delete = [
440
- item.data_item for item in selected_items if isinstance(item, EmbeddingPointItem)
441
- ]
442
-
443
- # Delegate the actual deletion to the main ExplorerWindow
444
- if data_items_to_delete:
445
- self.explorer_window.delete_data_items(data_items_to_delete)
446
-
447
- event.accept()
448
- else:
449
- super().keyPressEvent(event)
450
-
451
- def mousePressEvent(self, event):
452
- """Handle mouse press for selection (point or rubber band) and panning."""
453
- # Ctrl+Right-Click for context menu selection
454
- if event.button() == Qt.RightButton and event.modifiers() == Qt.ControlModifier:
455
- item_at_pos = self.graphics_view.itemAt(event.pos())
456
- if isinstance(item_at_pos, EmbeddingPointItem):
457
- # 1. Clear all selections in both viewers
458
- self.graphics_scene.clearSelection()
459
- item_at_pos.setSelected(True)
460
- self.on_selection_changed() # Updates internal state and emits signals
461
-
462
- # 2. Sync annotation viewer selection
463
- ann_id = item_at_pos.data_item.annotation.id
464
- self.explorer_window.annotation_viewer.render_selection_from_ids({ann_id})
465
-
466
- # 3. Update annotation window (set image, select, center)
467
- explorer = self.explorer_window
468
- annotation = item_at_pos.data_item.annotation
469
- image_path = annotation.image_path
470
-
471
- if hasattr(explorer, 'annotation_window'):
472
- if explorer.annotation_window.current_image_path != image_path:
473
- if hasattr(explorer.annotation_window, 'set_image'):
474
- explorer.annotation_window.set_image(image_path)
475
- if hasattr(explorer.annotation_window, 'select_annotation'):
476
- explorer.annotation_window.select_annotation(annotation)
477
- if hasattr(explorer.annotation_window, 'center_on_annotation'):
478
- explorer.annotation_window.center_on_annotation(annotation)
479
-
480
- explorer.update_label_window_selection()
481
- explorer.update_button_states()
482
- event.accept()
483
- return
484
-
485
- # Handle left-click for selection or rubber band
486
- if event.button() == Qt.LeftButton and event.modifiers() == Qt.ControlModifier:
487
- item_at_pos = self.graphics_view.itemAt(event.pos())
488
- if isinstance(item_at_pos, EmbeddingPointItem):
489
- self.graphics_view.setDragMode(QGraphicsView.NoDrag)
490
- # The viewer (controller) directly changes the state on the data item.
491
- is_currently_selected = item_at_pos.data_item.is_selected
492
- item_at_pos.data_item.set_selected(not is_currently_selected)
493
- item_at_pos.setSelected(not is_currently_selected) # Keep scene selection in sync
494
- self.on_selection_changed() # Manually trigger update
495
- return
496
-
497
- self.selection_at_press = set(self.graphics_scene.selectedItems())
498
- self.graphics_view.setDragMode(QGraphicsView.NoDrag)
499
- self.rubber_band_origin = self.graphics_view.mapToScene(event.pos())
500
- self.rubber_band = QGraphicsRectItem(QRectF(self.rubber_band_origin, self.rubber_band_origin))
501
- self.rubber_band.setPen(QPen(QColor(0, 100, 255), 1, Qt.DotLine))
502
- self.rubber_band.setBrush(QBrush(QColor(0, 100, 255, 50)))
503
- self.graphics_scene.addItem(self.rubber_band)
504
-
505
- elif event.button() == Qt.RightButton:
506
- self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
507
- left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
508
- QGraphicsView.mousePressEvent(self.graphics_view, left_event)
509
- else:
510
- self.graphics_view.setDragMode(QGraphicsView.NoDrag)
511
- QGraphicsView.mousePressEvent(self.graphics_view, event)
512
-
513
- def mouseDoubleClickEvent(self, event):
514
- """Handle double-click to clear selection and reset the main view."""
515
- if event.button() == Qt.LeftButton:
516
- if self.graphics_scene.selectedItems():
517
- self.graphics_scene.clearSelection()
518
- self.reset_view_requested.emit()
519
- event.accept()
520
- else:
521
- super().mouseDoubleClickEvent(event)
522
-
523
- def mouseMoveEvent(self, event):
524
- """Handle mouse move for dynamic selection and panning."""
525
- if self.rubber_band:
526
- # Update the rubber band rectangle as the mouse moves
527
- current_pos = self.graphics_view.mapToScene(event.pos())
528
- self.rubber_band.setRect(QRectF(self.rubber_band_origin, current_pos).normalized())
529
- # Create a selection path from the rubber band rectangle
530
- path = QPainterPath()
531
- path.addRect(self.rubber_band.rect())
532
- # Block signals to avoid recursive selectionChanged events
533
- self.graphics_scene.blockSignals(True)
534
- self.graphics_scene.setSelectionArea(path)
535
- # Restore selection for items that were already selected at press
536
- if self.selection_at_press:
537
- for item in self.selection_at_press:
538
- item.setSelected(True)
539
- self.graphics_scene.blockSignals(False)
540
- # Manually trigger selection changed logic
541
- self.on_selection_changed()
542
- elif event.buttons() == Qt.RightButton:
543
- # Forward right-drag as left-drag for panning
544
- left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
545
- QGraphicsView.mouseMoveEvent(self.graphics_view, left_event)
546
- self._schedule_view_update()
547
- else:
548
- # Default mouse move handling
549
- QGraphicsView.mouseMoveEvent(self.graphics_view, event)
550
-
551
- def mouseReleaseEvent(self, event):
552
- """Handle mouse release to finalize the action and clean up."""
553
- if self.rubber_band:
554
- self.graphics_scene.removeItem(self.rubber_band)
555
- self.rubber_band = None
556
- self.selection_at_press = None
557
- elif event.button() == Qt.RightButton:
558
- left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
559
- QGraphicsView.mouseReleaseEvent(self.graphics_view, left_event)
560
- self._schedule_view_update()
561
- self.graphics_view.setDragMode(QGraphicsView.NoDrag)
562
- else:
563
- QGraphicsView.mouseReleaseEvent(self.graphics_view, event)
564
- self.graphics_view.setDragMode(QGraphicsView.NoDrag)
565
-
566
- def wheelEvent(self, event):
567
- """Handle mouse wheel for zooming."""
568
- zoom_in_factor = 1.25
569
- zoom_out_factor = 1 / zoom_in_factor
570
-
571
- # Set anchor points so zoom occurs at mouse position
572
- self.graphics_view.setTransformationAnchor(QGraphicsView.NoAnchor)
573
- self.graphics_view.setResizeAnchor(QGraphicsView.NoAnchor)
574
-
575
- # Get the scene position before zooming
576
- old_pos = self.graphics_view.mapToScene(event.pos())
577
-
578
- # Determine zoom direction
579
- zoom_factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
580
-
581
- # Apply zoom
582
- self.graphics_view.scale(zoom_factor, zoom_factor)
583
-
584
- # Get the scene position after zooming
585
- new_pos = self.graphics_view.mapToScene(event.pos())
586
-
587
- # Translate view to keep mouse position stable
588
- delta = new_pos - old_pos
589
- self.graphics_view.translate(delta.x(), delta.y())
590
- self._schedule_view_update()
591
-
592
- def update_embeddings(self, data_items):
593
- """Update the embedding visualization. Creates an EmbeddingPointItem for
594
- each AnnotationDataItem and links them."""
595
- # Reset isolation state when loading new points
596
- if self.isolated_mode:
597
- self.show_all_points()
598
-
599
- self.clear_points()
600
- for item in data_items:
601
- point = EmbeddingPointItem(item)
602
- self.graphics_scene.addItem(point)
603
- self.points_by_id[item.annotation.id] = point
604
-
605
- # Ensure buttons are in the correct initial state
606
- self._update_toolbar_state()
607
- # Set initial visibility
608
- self._update_visible_points()
609
-
610
- def clear_points(self):
611
- """Clear all embedding points from the scene."""
612
- if self.isolated_mode:
613
- self.show_all_points()
614
-
615
- for point in self.points_by_id.values():
616
- self.graphics_scene.removeItem(point)
617
- self.points_by_id.clear()
618
- self._update_toolbar_state()
619
-
620
- def on_selection_changed(self):
621
- """
622
- Handles selection changes in the scene. Updates the central data model
623
- and emits a signal to notify other parts of the application.
624
- """
625
- if not self.graphics_scene:
626
- return
627
- try:
628
- selected_items = self.graphics_scene.selectedItems()
629
- except RuntimeError:
630
- return
631
-
632
- current_selection_ids = {item.data_item.annotation.id for item in selected_items}
633
-
634
- if current_selection_ids != self.previous_selection_ids:
635
- for point_id, point in self.points_by_id.items():
636
- is_selected = point_id in current_selection_ids
637
- point.data_item.set_selected(is_selected)
638
-
639
- self.selection_changed.emit(list(current_selection_ids))
640
- self.previous_selection_ids = current_selection_ids
641
-
642
- if hasattr(self, 'animation_timer') and self.animation_timer:
643
- self.animation_timer.stop()
644
-
645
- for point in self.points_by_id.values():
646
- if not point.isSelected():
647
- point.setPen(QPen(QColor("black"), POINT_WIDTH))
648
- if selected_items and hasattr(self, 'animation_timer') and self.animation_timer:
649
- self.animation_timer.start()
650
-
651
- # Update button states based on new selection
652
- self._update_toolbar_state()
653
-
654
- # A selection change can affect visibility (e.g., deselecting an off-screen point)
655
- self._schedule_view_update()
656
-
657
- def animate_selection(self):
658
- """Animate selected points with a marching ants effect."""
659
- if not self.graphics_scene:
660
- return
661
- try:
662
- selected_items = self.graphics_scene.selectedItems()
663
- except RuntimeError:
664
- return
665
-
666
- self.animation_offset = (self.animation_offset + 1) % 20
667
- for item in selected_items:
668
- # Get the color directly from the source of truth
669
- original_color = item.data_item.effective_color
670
- darker_color = original_color.darker(150)
671
- animated_pen = QPen(darker_color, POINT_WIDTH)
672
- animated_pen.setStyle(Qt.CustomDashLine)
673
- animated_pen.setDashPattern([1, 2])
674
- animated_pen.setDashOffset(self.animation_offset)
675
- item.setPen(animated_pen)
676
-
677
- def render_selection_from_ids(self, selected_ids):
678
- """
679
- Updates the visual selection of points based on a set of annotation IDs
680
- provided by an external controller.
681
- """
682
- blocker = QSignalBlocker(self.graphics_scene)
683
-
684
- for ann_id, point in self.points_by_id.items():
685
- is_selected = ann_id in selected_ids
686
- # 1. Update the state on the central data item
687
- point.data_item.set_selected(is_selected)
688
- # 2. Update the selection state of the graphics item itself
689
- point.setSelected(is_selected)
690
-
691
- blocker.unblock()
692
-
693
- # Manually trigger on_selection_changed to update animation and emit signals
694
- self.on_selection_changed()
695
-
696
- # After selection, update visibility to ensure newly selected points are shown
697
- self._update_visible_points()
698
-
699
- def fit_view_to_points(self):
700
- """Fit the view to show all embedding points."""
701
- if self.points_by_id:
702
- self.graphics_view.fitInView(self.graphics_scene.itemsBoundingRect(), Qt.KeepAspectRatio)
703
- else:
704
- self.graphics_view.fitInView(-2500, -2500, 5000, 5000, Qt.KeepAspectRatio)
705
-
706
-
707
- class AnnotationViewer(QWidget):
708
- """
709
- Widget containing a toolbar and a scrollable grid for displaying annotation image crops.
710
- Implements virtualization to only render visible widgets.
711
- """
712
- selection_changed = pyqtSignal(list)
713
- preview_changed = pyqtSignal(list)
714
- reset_view_requested = pyqtSignal()
715
- find_similar_requested = pyqtSignal()
716
-
717
- def __init__(self, parent=None):
718
- """Initialize the AnnotationViewer widget."""
719
- super(AnnotationViewer, self).__init__(parent)
720
- self.explorer_window = parent
721
-
722
- self.annotation_widgets_by_id = {}
723
- self.selected_widgets = []
724
- self.last_selected_item_id = None # Use a persistent ID for the selection anchor
725
- self.current_widget_size = 96
726
- self.selection_at_press = set()
727
- self.rubber_band = None
728
- self.rubber_band_origin = None
729
- self.drag_threshold = 5
730
- self.mouse_pressed_on_widget = False
731
- self.preview_label_assignments = {}
732
- self.original_label_assignments = {}
733
- self.isolated_mode = False
734
- self.isolated_widgets = set()
735
-
736
- # State for sorting options
737
- self.active_ordered_ids = []
738
- self.is_confidence_sort_available = False
739
-
740
- # New attributes for virtualization
741
- self.all_data_items = []
742
- self.widget_positions = {} # ann_id -> QRect
743
- self.update_timer = QTimer(self)
744
- self.update_timer.setSingleShot(True)
745
- self.update_timer.timeout.connect(self._update_visible_widgets)
746
-
747
- self.setup_ui()
748
-
749
- # Connect scrollbar value changed to schedule an update for virtualization
750
- self.scroll_area.verticalScrollBar().valueChanged.connect(self._schedule_update)
751
- # Install an event filter on the viewport to handle mouse events for rubber band selection
752
- self.scroll_area.viewport().installEventFilter(self)
753
-
754
- def setup_ui(self):
755
- """Set up the UI with a toolbar and a scrollable content area."""
756
- # This widget is the main container with its own layout
757
- main_layout = QVBoxLayout(self)
758
- main_layout.setContentsMargins(0, 0, 0, 0)
759
- main_layout.setSpacing(4)
760
-
761
- # Create and add the toolbar to the main layout
762
- toolbar_widget = QWidget()
763
- toolbar_layout = QHBoxLayout(toolbar_widget)
764
- toolbar_layout.setContentsMargins(4, 2, 4, 2)
765
-
766
- self.isolate_button = QPushButton("Isolate Selection")
767
- self.isolate_button.setToolTip("Hide all non-selected annotations")
768
- self.isolate_button.clicked.connect(self.isolate_selection)
769
- toolbar_layout.addWidget(self.isolate_button)
770
-
771
- self.show_all_button = QPushButton("Show All")
772
- self.show_all_button.setToolTip("Show all filtered annotations")
773
- self.show_all_button.clicked.connect(self.show_all_annotations)
774
- toolbar_layout.addWidget(self.show_all_button)
775
-
776
- toolbar_layout.addWidget(self._create_separator())
777
-
778
- sort_label = QLabel("Sort By:")
779
- toolbar_layout.addWidget(sort_label)
780
- self.sort_combo = QComboBox()
781
- # Remove "Similarity" as it's now an implicit action
782
- self.sort_combo.addItems(["None", "Label", "Image", "Confidence"])
783
- self.sort_combo.insertSeparator(3) # Add separator before "Confidence"
784
- self.sort_combo.currentTextChanged.connect(self.on_sort_changed)
785
- toolbar_layout.addWidget(self.sort_combo)
786
-
787
- toolbar_layout.addWidget(self._create_separator())
788
-
789
- self.find_similar_button = QToolButton()
790
- self.find_similar_button.setText("Find Similar")
791
- self.find_similar_button.setToolTip("Find annotations visually similar to the selection.")
792
- self.find_similar_button.setPopupMode(QToolButton.MenuButtonPopup)
793
- self.find_similar_button.setToolButtonStyle(Qt.ToolButtonTextOnly)
794
- self.find_similar_button.setStyleSheet(
795
- "QToolButton::menu-indicator { subcontrol-position: right center; subcontrol-origin: padding; left: -4px; }"
796
- )
797
-
798
- run_similar_action = QAction("Find Similar", self)
799
- run_similar_action.triggered.connect(self.find_similar_requested.emit)
800
- self.find_similar_button.setDefaultAction(run_similar_action)
801
-
802
- self.similarity_settings_widget = SimilaritySettingsWidget()
803
- settings_menu = QMenu(self)
804
- widget_action = QWidgetAction(settings_menu)
805
- widget_action.setDefaultWidget(self.similarity_settings_widget)
806
- settings_menu.addAction(widget_action)
807
- self.find_similar_button.setMenu(settings_menu)
808
- toolbar_layout.addWidget(self.find_similar_button)
809
-
810
- toolbar_layout.addStretch()
811
-
812
- size_label = QLabel("Size:")
813
- toolbar_layout.addWidget(size_label)
814
- self.size_slider = QSlider(Qt.Horizontal)
815
- self.size_slider.setMinimum(32)
816
- self.size_slider.setMaximum(256)
817
- self.size_slider.setValue(96)
818
- self.size_slider.setTickPosition(QSlider.TicksBelow)
819
- self.size_slider.setTickInterval(32)
820
- self.size_slider.valueChanged.connect(self.on_size_changed)
821
- toolbar_layout.addWidget(self.size_slider)
822
-
823
- self.size_value_label = QLabel("96")
824
- self.size_value_label.setMinimumWidth(30)
825
- toolbar_layout.addWidget(self.size_value_label)
826
- main_layout.addWidget(toolbar_widget)
827
-
828
- # Create the scroll area which will contain the content
829
- self.scroll_area = QScrollArea()
830
- self.scroll_area.setWidgetResizable(True)
831
- self.scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
832
- self.scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
833
-
834
- self.content_widget = QWidget()
835
- self.scroll_area.setWidget(self.content_widget)
836
- main_layout.addWidget(self.scroll_area)
837
-
838
- # Set the initial state of the sort options
839
- self._update_sort_options_state()
840
- self._update_toolbar_state()
841
-
842
- def _create_separator(self):
843
- """Creates a vertical separator for the toolbar."""
844
- separator = QLabel("|")
845
- separator.setStyleSheet("color: gray; margin: 0 5px;")
846
- return separator
847
-
848
- def _update_sort_options_state(self):
849
- """Enable/disable sort options based on available data."""
850
- model = self.sort_combo.model()
851
-
852
- # Enable/disable "Confidence" option
853
- confidence_item_index = self.sort_combo.findText("Confidence")
854
- if confidence_item_index != -1:
855
- model.item(confidence_item_index).setEnabled(self.is_confidence_sort_available)
856
-
857
- def handle_annotation_context_menu(self, widget, event):
858
- """Handle context menu requests (e.g., right-click) on an annotation widget."""
859
- if event.modifiers() == Qt.ControlModifier:
860
- explorer = self.explorer_window
861
- image_path = widget.annotation.image_path
862
- annotation_to_select = widget.annotation
863
-
864
- # ctrl+right click to only select this annotation (single selection):
865
- self.clear_selection()
866
- self.select_widget(widget)
867
- changed_ids = [widget.data_item.annotation.id]
868
-
869
- if changed_ids:
870
- self.selection_changed.emit(changed_ids)
871
-
872
- if hasattr(explorer, 'annotation_window'):
873
- # Check if the image needs to be changed
874
- if explorer.annotation_window.current_image_path != image_path:
875
- if hasattr(explorer.annotation_window, 'set_image'):
876
- explorer.annotation_window.set_image(image_path)
877
-
878
- # Now, select the annotation in the annotation_window (activates animation)
879
- if hasattr(explorer.annotation_window, 'select_annotation'):
880
- explorer.annotation_window.select_annotation(annotation_to_select, quiet_mode=True)
881
-
882
- # Center the annotation window view on the selected annotation
883
- if hasattr(explorer.annotation_window, 'center_on_annotation'):
884
- explorer.annotation_window.center_on_annotation(annotation_to_select)
885
-
886
- # Show resize handles for Rectangle annotations
887
- if isinstance(annotation_to_select, RectangleAnnotation):
888
- explorer.annotation_window.set_selected_tool('select') # Accidentally unselects in AnnotationWindow
889
- explorer.annotation_window.select_annotation(annotation_to_select, quiet_mode=True)
890
- select_tool = explorer.annotation_window.tools.get('select')
891
-
892
- if select_tool:
893
- # Engage the selection lock.
894
- select_tool.selection_locked = True
895
- # Show the resize handles for the now-selected annotation.
896
- select_tool._show_resize_handles()
897
-
898
- # Also clear any existing selection in the explorer window itself
899
- explorer.embedding_viewer.render_selection_from_ids({widget.data_item.annotation.id})
900
- explorer.update_label_window_selection()
901
- explorer.update_button_states()
902
-
903
- event.accept()
904
-
905
- @pyqtSlot()
906
- def isolate_selection(self):
907
- """Hides all annotation widgets that are not currently selected."""
908
- if not self.selected_widgets:
909
- return
910
-
911
- self.isolated_widgets = set(self.selected_widgets)
912
- self.content_widget.setUpdatesEnabled(False)
913
- try:
914
- for widget in self.annotation_widgets_by_id.values():
915
- if widget not in self.isolated_widgets:
916
- widget.hide()
917
- self.isolated_mode = True
918
- self.recalculate_layout()
919
- finally:
920
- self.content_widget.setUpdatesEnabled(True)
921
-
922
- self._update_toolbar_state()
923
- self.explorer_window.main_window.label_window.update_annotation_count()
924
-
925
- def isolate_and_select_from_ids(self, ids_to_isolate):
926
- """
927
- Enters isolated mode showing only widgets for the given IDs, and also
928
- selects them. This is the primary entry point from external viewers.
929
- The isolated set is 'sticky' and will not change on subsequent internal
930
- selection changes.
931
- """
932
- # Get the widget objects from the IDs
933
- widgets_to_isolate = {
934
- self.annotation_widgets_by_id[ann_id]
935
- for ann_id in ids_to_isolate
936
- if ann_id in self.annotation_widgets_by_id
937
- }
938
-
939
- if not widgets_to_isolate:
940
- return
941
-
942
- self.isolated_widgets = widgets_to_isolate
943
- self.isolated_mode = True
944
-
945
- self.render_selection_from_ids(ids_to_isolate)
946
- self.recalculate_layout()
947
-
948
- def display_and_isolate_ordered_results(self, ordered_ids):
949
- """
950
- Isolates the view to a specific set of ordered widgets, ensuring the
951
- grid is always updated. This is the new primary method for showing
952
- similarity results.
953
- """
954
- self.active_ordered_ids = ordered_ids
955
-
956
- # Render the selection based on the new order
957
- self.render_selection_from_ids(set(ordered_ids))
958
-
959
- # Now, perform the isolation logic directly to bypass the guard clause
960
- self.isolated_widgets = set(self.selected_widgets)
961
- self.content_widget.setUpdatesEnabled(False)
962
- try:
963
- for widget in self.annotation_widgets_by_id.values():
964
- # Show widget if it's in our target set, hide otherwise
965
- if widget in self.isolated_widgets:
966
- widget.show()
967
- else:
968
- widget.hide()
969
-
970
- self.isolated_mode = True
971
- self.recalculate_layout() # Crucial grid update
972
- finally:
973
- self.content_widget.setUpdatesEnabled(True)
974
-
975
- self._update_toolbar_state()
976
- self.explorer_window.main_window.label_window.update_annotation_count()
977
-
978
- @pyqtSlot()
979
- def show_all_annotations(self):
980
- """Shows all annotation widgets, exiting the isolated mode."""
981
- if not self.isolated_mode:
982
- return
983
-
984
- self.isolated_mode = False
985
- self.isolated_widgets.clear()
986
- self.active_ordered_ids = [] # Clear similarity sort context
987
-
988
- self.content_widget.setUpdatesEnabled(False)
989
- try:
990
- # Show all widgets that are managed by the viewer
991
- for widget in self.annotation_widgets_by_id.values():
992
- widget.show()
993
-
994
- self.recalculate_layout()
995
- finally:
996
- self.content_widget.setUpdatesEnabled(True)
997
-
998
- self._update_toolbar_state()
999
- self.explorer_window.main_window.label_window.update_annotation_count()
1000
-
1001
- def _update_toolbar_state(self):
1002
- """Updates the toolbar buttons based on selection and isolation mode."""
1003
- selection_exists = bool(self.selected_widgets)
1004
- if self.isolated_mode:
1005
- self.isolate_button.hide()
1006
- self.show_all_button.show()
1007
- self.show_all_button.setEnabled(True)
1008
- else:
1009
- self.isolate_button.show()
1010
- self.show_all_button.hide()
1011
- self.isolate_button.setEnabled(selection_exists)
1012
-
1013
- def on_sort_changed(self, sort_type):
1014
- """Handle sort type change."""
1015
- self.active_ordered_ids = [] # Clear any special ordering
1016
- self.recalculate_layout()
1017
-
1018
- def set_confidence_sort_availability(self, is_available):
1019
- """Sets the availability of the confidence sort option."""
1020
- self.is_confidence_sort_available = is_available
1021
- self._update_sort_options_state()
1022
-
1023
- def _get_sorted_data_items(self):
1024
- """Get data items sorted according to the current sort setting."""
1025
- # If a specific order is active (e.g., from similarity search), use it.
1026
- if self.active_ordered_ids:
1027
- item_map = {i.annotation.id: i for i in self.all_data_items}
1028
- ordered_items = [item_map[ann_id] for ann_id in self.active_ordered_ids if ann_id in item_map]
1029
- return ordered_items
1030
-
1031
- # Otherwise, use the dropdown sort logic
1032
- sort_type = self.sort_combo.currentText()
1033
- items = list(self.all_data_items)
1034
-
1035
- if sort_type == "Label":
1036
- items.sort(key=lambda i: i.effective_label.short_label_code)
1037
- elif sort_type == "Image":
1038
- items.sort(key=lambda i: os.path.basename(i.annotation.image_path))
1039
- elif sort_type == "Confidence":
1040
- # Sort by confidence, descending. Handles cases with no confidence gracefully.
1041
- items.sort(key=lambda i: i.get_effective_confidence(), reverse=True)
1042
-
1043
- return items
1044
-
1045
- def _get_sorted_widgets(self):
1046
- """
1047
- Get widgets sorted according to the current sort setting.
1048
- This is kept for compatibility with selection logic.
1049
- """
1050
- sorted_data_items = self._get_sorted_data_items()
1051
- return [self.annotation_widgets_by_id[item.annotation.id]
1052
- for item in sorted_data_items if item.annotation.id in self.annotation_widgets_by_id]
1053
-
1054
- def _group_data_items_by_sort_key(self, data_items):
1055
- """Group data items by the current sort key."""
1056
- sort_type = self.sort_combo.currentText()
1057
- if not self.active_ordered_ids and sort_type == "None":
1058
- return [("", data_items)]
1059
-
1060
- if self.active_ordered_ids: # Don't show group headers for similarity results
1061
- return [("", data_items)]
1062
-
1063
- groups = []
1064
- current_group = []
1065
- current_key = None
1066
- for item in data_items:
1067
- if sort_type == "Label":
1068
- key = item.effective_label.short_label_code
1069
- elif sort_type == "Image":
1070
- key = os.path.basename(item.annotation.image_path)
1071
- else:
1072
- key = "" # No headers for Confidence or None
1073
-
1074
- if key and current_key != key:
1075
- if current_group:
1076
- groups.append((current_key, current_group))
1077
- current_group = [item]
1078
- current_key = key
1079
- else:
1080
- current_group.append(item)
1081
- if current_group:
1082
- groups.append((current_key, current_group))
1083
- return groups
1084
-
1085
- def _clear_separator_labels(self):
1086
- """Remove any existing group header labels."""
1087
- if hasattr(self, '_group_headers'):
1088
- for header in self._group_headers:
1089
- header.setParent(None)
1090
- header.deleteLater()
1091
- self._group_headers = []
1092
-
1093
- def _create_group_header(self, text):
1094
- """Create a group header label."""
1095
- if not hasattr(self, '_group_headers'):
1096
- self._group_headers = []
1097
- header = QLabel(text, self.content_widget)
1098
- header.setStyleSheet(
1099
- "QLabel {"
1100
- " font-weight: bold;"
1101
- " font-size: 12px;"
1102
- " color: #555;"
1103
- " background-color: #f0f0f0;"
1104
- " border: 1px solid #ccc;"
1105
- " border-radius: 3px;"
1106
- " padding: 5px 8px;"
1107
- " margin: 2px 0px;"
1108
- " }"
1109
- )
1110
- header.setFixedHeight(30)
1111
- header.setMinimumWidth(self.scroll_area.viewport().width() - 20)
1112
- header.show()
1113
- self._group_headers.append(header)
1114
- return header
1115
-
1116
- def on_size_changed(self, value):
1117
- """Handle slider value change to resize annotation widgets."""
1118
- if value % 2 != 0:
1119
- value -= 1
1120
-
1121
- self.current_widget_size = value
1122
- self.size_value_label.setText(str(value))
1123
- self.recalculate_layout()
1124
-
1125
- def _schedule_update(self):
1126
- """Schedules a delayed update of visible widgets to avoid performance issues during rapid scrolling."""
1127
- self.update_timer.start(50) # 50ms delay
1128
-
1129
- def _update_visible_widgets(self):
1130
- """Shows and loads widgets that are in the viewport, and hides/unloads others."""
1131
- if not self.widget_positions:
1132
- return
1133
-
1134
- self.content_widget.setUpdatesEnabled(False)
1135
-
1136
- # Determine the visible rectangle in the content widget's coordinates
1137
- scroll_y = self.scroll_area.verticalScrollBar().value()
1138
- visible_content_rect = QRect(0,
1139
- scroll_y,
1140
- self.scroll_area.viewport().width(),
1141
- self.scroll_area.viewport().height())
1142
-
1143
- # Add a buffer to load images slightly before they become visible
1144
- buffer = self.scroll_area.viewport().height() // 2
1145
- visible_content_rect.adjust(0, -buffer, 0, buffer)
1146
-
1147
- visible_ids = set()
1148
- for ann_id, rect in self.widget_positions.items():
1149
- if rect.intersects(visible_content_rect):
1150
- visible_ids.add(ann_id)
1151
-
1152
- # Update widgets based on visibility
1153
- for ann_id, widget in self.annotation_widgets_by_id.items():
1154
- if ann_id in visible_ids:
1155
- # This widget should be visible
1156
- widget.setGeometry(self.widget_positions[ann_id])
1157
- widget.load_image() # Lazy-loads the image
1158
- widget.show()
1159
- else:
1160
- # This widget is not visible
1161
- if widget.isVisible():
1162
- widget.hide()
1163
- widget.unload_image() # Free up memory
1164
-
1165
- self.content_widget.setUpdatesEnabled(True)
1166
-
1167
- def recalculate_layout(self):
1168
- """Calculates the positions for all widgets and the total size of the content area."""
1169
- if not self.all_data_items:
1170
- self.content_widget.setMinimumSize(1, 1)
1171
- return
1172
-
1173
- self._clear_separator_labels()
1174
- sorted_data_items = self._get_sorted_data_items()
1175
-
1176
- # If in isolated mode, only consider the isolated widgets for layout
1177
- if self.isolated_mode:
1178
- isolated_ids = {w.data_item.annotation.id for w in self.isolated_widgets}
1179
- sorted_data_items = [item for item in sorted_data_items if item.annotation.id in isolated_ids]
1180
-
1181
- if not sorted_data_items:
1182
- self.content_widget.setMinimumSize(1, 1)
1183
- return
1184
-
1185
- # Create groups based on the current sort key
1186
- groups = self._group_data_items_by_sort_key(sorted_data_items)
1187
- spacing = max(5, int(self.current_widget_size * 0.08))
1188
- available_width = self.scroll_area.viewport().width()
1189
- x, y = spacing, spacing
1190
- max_height_in_row = 0
1191
-
1192
- self.widget_positions.clear()
1193
-
1194
- # Calculate positions
1195
- for group_name, group_data_items in groups:
1196
- if group_name and self.sort_combo.currentText() != "None":
1197
- if x > spacing:
1198
- x = spacing
1199
- y += max_height_in_row + spacing
1200
- max_height_in_row = 0
1201
- header_label = self._create_group_header(group_name)
1202
- header_label.move(x, y)
1203
- y += header_label.height() + spacing
1204
- x = spacing
1205
- max_height_in_row = 0
1206
-
1207
- for data_item in group_data_items:
1208
- ann_id = data_item.annotation.id
1209
- # Get or create widget to determine its size
1210
- if ann_id in self.annotation_widgets_by_id:
1211
- widget = self.annotation_widgets_by_id[ann_id]
1212
- widget.update_height(self.current_widget_size) # Ensure size is up-to-date
1213
- else:
1214
- widget = AnnotationImageWidget(data_item, self.current_widget_size, self, self.content_widget)
1215
- self.annotation_widgets_by_id[ann_id] = widget
1216
- widget.hide() # Hide by default
1217
-
1218
- widget_size = widget.size()
1219
- if x > spacing and x + widget_size.width() > available_width:
1220
- x = spacing
1221
- y += max_height_in_row + spacing
1222
- max_height_in_row = 0
1223
-
1224
- self.widget_positions[ann_id] = QRect(x, y, widget_size.width(), widget_size.height())
1225
-
1226
- x += widget_size.width() + spacing
1227
- max_height_in_row = max(max_height_in_row, widget_size.height())
1228
-
1229
- total_height = y + max_height_in_row + spacing
1230
- self.content_widget.setMinimumSize(available_width, total_height)
1231
-
1232
- # After calculating layout, update what's visible
1233
- self._update_visible_widgets()
1234
-
1235
- def update_annotations(self, data_items):
1236
- """Update displayed annotations, creating new widgets for them."""
1237
- if self.isolated_mode:
1238
- self.show_all_annotations()
1239
-
1240
- # Clear out widgets for data items that are no longer in the new set
1241
- all_ann_ids = {item.annotation.id for item in data_items}
1242
- for ann_id, widget in list(self.annotation_widgets_by_id.items()):
1243
- if ann_id not in all_ann_ids:
1244
- if widget in self.selected_widgets:
1245
- self.selected_widgets.remove(widget)
1246
- widget.setParent(None)
1247
- widget.deleteLater()
1248
- del self.annotation_widgets_by_id[ann_id]
1249
-
1250
- self.all_data_items = data_items
1251
- self.selected_widgets.clear()
1252
- self.last_selected_item_id = None
1253
-
1254
- self.recalculate_layout()
1255
- self._update_toolbar_state()
1256
- # Update the label window with the new annotation count
1257
- self.explorer_window.main_window.label_window.update_annotation_count()
1258
-
1259
- def resizeEvent(self, event):
1260
- """On window resize, reflow the annotation widgets."""
1261
- super(AnnotationViewer, self).resizeEvent(event)
1262
- if not hasattr(self, '_resize_timer'):
1263
- self._resize_timer = QTimer(self)
1264
- self._resize_timer.setSingleShot(True)
1265
- self._resize_timer.timeout.connect(self.recalculate_layout)
1266
- self._resize_timer.start(100)
1267
-
1268
- def keyPressEvent(self, event):
1269
- """Handles key presses for deleting selected annotations."""
1270
- if event.key() in (Qt.Key_Delete, Qt.Key_Backspace) and event.modifiers() == Qt.ControlModifier:
1271
- if not self.selected_widgets:
1272
- super().keyPressEvent(event)
1273
- return
1274
-
1275
- # Extract the central data items from the selected widgets
1276
- data_items_to_delete = [widget.data_item for widget in self.selected_widgets]
1277
-
1278
- # Delegate the actual deletion to the main ExplorerWindow
1279
- if data_items_to_delete:
1280
- self.explorer_window.delete_data_items(data_items_to_delete)
1281
-
1282
- event.accept()
1283
- else:
1284
- super().keyPressEvent(event)
1285
-
1286
- def eventFilter(self, source, event):
1287
- """Filters events from the scroll area's viewport to handle mouse interactions."""
1288
- if source is self.scroll_area.viewport():
1289
- if event.type() == QEvent.MouseButtonPress:
1290
- return self.viewport_mouse_press(event)
1291
- elif event.type() == QEvent.MouseMove:
1292
- return self.viewport_mouse_move(event)
1293
- elif event.type() == QEvent.MouseButtonRelease:
1294
- return self.viewport_mouse_release(event)
1295
- elif event.type() == QEvent.MouseButtonDblClick:
1296
- return self.viewport_mouse_double_click(event)
1297
-
1298
- return super(AnnotationViewer, self).eventFilter(source, event)
1299
-
1300
- def viewport_mouse_press(self, event):
1301
- """Handle mouse press inside the viewport for selection."""
1302
- if event.button() == Qt.LeftButton and event.modifiers() == Qt.ControlModifier:
1303
- # Start rubber band selection
1304
- self.selection_at_press = set(self.selected_widgets)
1305
- self.rubber_band_origin = event.pos()
1306
-
1307
- # Check if the press was on a widget to avoid starting rubber band on a widget click
1308
- content_pos = self.content_widget.mapFrom(self.scroll_area.viewport(), event.pos())
1309
- child_at_pos = self.content_widget.childAt(content_pos)
1310
- self.mouse_pressed_on_widget = isinstance(child_at_pos, AnnotationImageWidget)
1311
-
1312
- return True # Event handled
1313
-
1314
- elif event.button() == Qt.LeftButton and not event.modifiers():
1315
- # Clear selection if clicking on the background
1316
- content_pos = self.content_widget.mapFrom(self.scroll_area.viewport(), event.pos())
1317
- if self.content_widget.childAt(content_pos) is None:
1318
- if self.selected_widgets:
1319
- changed_ids = [w.data_item.annotation.id for w in self.selected_widgets]
1320
- self.clear_selection()
1321
- self.selection_changed.emit(changed_ids)
1322
- if hasattr(self.explorer_window.annotation_window, 'unselect_annotations'):
1323
- self.explorer_window.annotation_window.unselect_annotations()
1324
- return True
1325
-
1326
- return False # Let the event propagate for default behaviors like scrolling
1327
-
1328
- def viewport_mouse_double_click(self, event):
1329
- """Handle double-click in the viewport to clear selection and reset view."""
1330
- if event.button() == Qt.LeftButton:
1331
- if self.selected_widgets:
1332
- changed_ids = [w.data_item.annotation.id for w in self.selected_widgets]
1333
- self.clear_selection()
1334
- self.selection_changed.emit(changed_ids)
1335
- if self.isolated_mode:
1336
- self.show_all_annotations()
1337
- self.reset_view_requested.emit()
1338
- return True
1339
- return False
1340
-
1341
- def viewport_mouse_move(self, event):
1342
- """Handle mouse move in the viewport for dynamic rubber band selection."""
1343
- if (
1344
- self.rubber_band_origin is None or
1345
- event.buttons() != Qt.LeftButton or
1346
- event.modifiers() != Qt.ControlModifier or
1347
- self.mouse_pressed_on_widget
1348
- ):
1349
- return False
1350
-
1351
- # Only start selection if drag distance exceeds threshold
1352
- distance = (event.pos() - self.rubber_band_origin).manhattanLength()
1353
- if distance < self.drag_threshold:
1354
- return True
1355
-
1356
- # Create and show the rubber band if not already present
1357
- if not self.rubber_band:
1358
- self.rubber_band = QRubberBand(QRubberBand.Rectangle, self.scroll_area.viewport())
1359
-
1360
- rect = QRect(self.rubber_band_origin, event.pos()).normalized()
1361
- self.rubber_band.setGeometry(rect)
1362
- self.rubber_band.show()
1363
-
1364
- selection_rect = self.rubber_band.geometry()
1365
- content_widget = self.content_widget
1366
- changed_ids = []
1367
-
1368
- # Iterate over all annotation widgets to update selection state
1369
- for widget in self.annotation_widgets_by_id.values():
1370
- # Map widget's geometry from content_widget coordinates to viewport coordinates
1371
- mapped_top_left = content_widget.mapTo(self.scroll_area.viewport(), widget.geometry().topLeft())
1372
- widget_rect_in_viewport = QRect(mapped_top_left, widget.geometry().size())
1373
-
1374
- is_in_band = selection_rect.intersects(widget_rect_in_viewport)
1375
- should_be_selected = (widget in self.selection_at_press) or is_in_band
1376
-
1377
- # Select or deselect widgets as needed
1378
- if should_be_selected and not widget.is_selected():
1379
- if self.select_widget(widget):
1380
- changed_ids.append(widget.data_item.annotation.id)
1381
-
1382
- elif not should_be_selected and widget.is_selected():
1383
- if self.deselect_widget(widget):
1384
- changed_ids.append(widget.data_item.annotation.id)
1385
-
1386
- # Emit signal if any selection state changed
1387
- if changed_ids:
1388
- self.selection_changed.emit(changed_ids)
1389
-
1390
- return True
1391
-
1392
- def viewport_mouse_release(self, event):
1393
- """Handle mouse release in the viewport to finalize rubber band selection."""
1394
- if self.rubber_band_origin is not None and event.button() == Qt.LeftButton:
1395
- if self.rubber_band and self.rubber_band.isVisible():
1396
- self.rubber_band.hide()
1397
- self.rubber_band.deleteLater()
1398
- self.rubber_band = None
1399
- self.rubber_band_origin = None
1400
- return True
1401
- return False
1402
-
1403
- def handle_annotation_selection(self, widget, event):
1404
- """Handle selection of annotation widgets with different modes (single, ctrl, shift)."""
1405
- # The list for range selection should be based on the sorted data items
1406
- sorted_data_items = self._get_sorted_data_items()
1407
-
1408
- # In isolated mode, the list should only contain isolated items
1409
- if self.isolated_mode:
1410
- isolated_ids = {w.data_item.annotation.id for w in self.isolated_widgets}
1411
- sorted_data_items = [item for item in sorted_data_items if item.annotation.id in isolated_ids]
1412
-
1413
- try:
1414
- # Find the index of the clicked widget's data item
1415
- widget_data_item = widget.data_item
1416
- current_index = sorted_data_items.index(widget_data_item)
1417
- except ValueError:
1418
- return
1419
-
1420
- modifiers = event.modifiers()
1421
- changed_ids = []
1422
-
1423
- # Shift or Shift+Ctrl: range selection.
1424
- if modifiers in (Qt.ShiftModifier, Qt.ShiftModifier | Qt.ControlModifier):
1425
- last_index = -1
1426
- if self.last_selected_item_id:
1427
- try:
1428
- # Find the data item corresponding to the last selected ID
1429
- last_item = self.explorer_window.data_item_cache[self.last_selected_item_id]
1430
- # Find its index in the *current* sorted list
1431
- last_index = sorted_data_items.index(last_item)
1432
- except (KeyError, ValueError):
1433
- # The last selected item is not in the current view or cache, so no anchor
1434
- last_index = -1
1435
-
1436
- if last_index != -1:
1437
- start = min(last_index, current_index)
1438
- end = max(last_index, current_index)
1439
-
1440
- # Select all widgets in the range
1441
- for i in range(start, end + 1):
1442
- item_to_select = sorted_data_items[i]
1443
- widget_to_select = self.annotation_widgets_by_id.get(item_to_select.annotation.id)
1444
- if widget_to_select and self.select_widget(widget_to_select):
1445
- changed_ids.append(item_to_select.annotation.id)
1446
- else:
1447
- # No previous selection, just select the clicked widget
1448
- if self.select_widget(widget):
1449
- changed_ids.append(widget.data_item.annotation.id)
1450
-
1451
- self.last_selected_item_id = widget.data_item.annotation.id
1452
-
1453
- # Ctrl: toggle selection of the clicked widget
1454
- elif modifiers == Qt.ControlModifier:
1455
- # Toggle selection and update the anchor
1456
- if self.toggle_widget_selection(widget):
1457
- changed_ids.append(widget.data_item.annotation.id)
1458
- self.last_selected_item_id = widget.data_item.annotation.id
1459
-
1460
- # No modifier: single selection
1461
- else:
1462
- newly_selected_id = widget.data_item.annotation.id
1463
-
1464
- # Deselect all others
1465
- for w in list(self.selected_widgets):
1466
- if w.data_item.annotation.id != newly_selected_id:
1467
- if self.deselect_widget(w):
1468
- changed_ids.append(w.data_item.annotation.id)
1469
-
1470
- # Select the clicked widget
1471
- if self.select_widget(widget):
1472
- changed_ids.append(newly_selected_id)
1473
- self.last_selected_item_id = widget.data_item.annotation.id
1474
-
1475
- # If in isolated mode, update which widgets are visible
1476
- if self.isolated_mode:
1477
- pass # Do not change the isolated set on internal selection changes
1478
-
1479
- # Emit signal if any selection state changed
1480
- if changed_ids:
1481
- self.selection_changed.emit(changed_ids)
1482
-
1483
- def toggle_widget_selection(self, widget):
1484
- """Toggles the selection state of a widget and returns True if changed."""
1485
- if widget.is_selected():
1486
- return self.deselect_widget(widget)
1487
- else:
1488
- return self.select_widget(widget)
1489
-
1490
- def select_widget(self, widget):
1491
- """Selects a widget, updates its data_item, and returns True if state changed."""
1492
- if not widget.is_selected(): # is_selected() checks the data_item
1493
- # 1. Controller modifies the state on the data item
1494
- widget.data_item.set_selected(True)
1495
- # 2. Controller tells the view to update its appearance
1496
- widget.update_selection_visuals()
1497
- self.selected_widgets.append(widget)
1498
- self._update_toolbar_state()
1499
- return True
1500
- return False
1501
-
1502
- def deselect_widget(self, widget):
1503
- """Deselects a widget, updates its data_item, and returns True if state changed."""
1504
- if widget.is_selected():
1505
- # 1. Controller modifies the state on the data item
1506
- widget.data_item.set_selected(False)
1507
- # 2. Controller tells the view to update its appearance
1508
- widget.update_selection_visuals()
1509
- if widget in self.selected_widgets:
1510
- self.selected_widgets.remove(widget)
1511
- self._update_toolbar_state()
1512
- return True
1513
- return False
1514
-
1515
- def clear_selection(self):
1516
- """Clear all selected widgets and update toolbar state."""
1517
- for widget in list(self.selected_widgets):
1518
- # This will internally call deselect_widget, which is fine
1519
- self.deselect_widget(widget)
1520
-
1521
- self.selected_widgets.clear()
1522
- self._update_toolbar_state()
1523
-
1524
- def get_selected_annotations(self):
1525
- """Get the annotations corresponding to selected widgets."""
1526
- return [widget.annotation for widget in self.selected_widgets]
1527
-
1528
- def render_selection_from_ids(self, selected_ids):
1529
- """Update the visual selection of widgets based on a set of IDs from the controller."""
1530
- self.setUpdatesEnabled(False)
1531
- try:
1532
- for ann_id, widget in self.annotation_widgets_by_id.items():
1533
- is_selected = ann_id in selected_ids
1534
- # 1. Update the state on the central data item
1535
- widget.data_item.set_selected(is_selected)
1536
- # 2. Tell the widget to update its visuals based on the new state
1537
- widget.update_selection_visuals()
1538
-
1539
- # Resync internal list of selected widgets from the source of truth
1540
- self.selected_widgets = [w for w in self.annotation_widgets_by_id.values() if w.is_selected()]
1541
-
1542
- finally:
1543
- self.setUpdatesEnabled(True)
1544
- self._update_toolbar_state()
1545
-
1546
- def apply_preview_label_to_selected(self, preview_label):
1547
- """Apply a preview label and emit a signal for the embedding view to update."""
1548
- if not self.selected_widgets or not preview_label:
1549
- return
1550
- changed_ids = []
1551
- for widget in self.selected_widgets:
1552
- widget.data_item.set_preview_label(preview_label)
1553
- widget.update() # Force repaint with new color
1554
- changed_ids.append(widget.data_item.annotation.id)
1555
-
1556
- if self.sort_combo.currentText() == "Label":
1557
- self.recalculate_layout()
1558
- if changed_ids:
1559
- self.preview_changed.emit(changed_ids)
1560
-
1561
- def clear_preview_states(self):
1562
- """
1563
- Clears all preview states, including label changes,
1564
- reverting them to their original state.
1565
- """
1566
- something_changed = False
1567
- for widget in self.annotation_widgets_by_id.values():
1568
- # Check for and clear preview labels
1569
- if widget.data_item.has_preview_changes():
1570
- widget.data_item.clear_preview_label()
1571
- widget.update() # Repaint to show original color
1572
- something_changed = True
1573
-
1574
- if something_changed:
1575
- # Recalculate positions to update sorting and re-flow the layout
1576
- if self.sort_combo.currentText() == "Label":
1577
- self.recalculate_layout()
1578
-
1579
- def has_preview_changes(self):
1580
- """Return True if there are preview changes."""
1581
- return any(w.data_item.has_preview_changes() for w in self.annotation_widgets_by_id.values())
1582
-
1583
- def get_preview_changes_summary(self):
1584
- """Get a summary of preview changes."""
1585
- change_count = sum(1 for w in self.annotation_widgets_by_id.values() if w.data_item.has_preview_changes())
1586
- return f"{change_count} annotation(s) with preview changes" if change_count else "No preview changes"
1587
-
1588
- def apply_preview_changes_permanently(self):
1589
- """Apply preview changes permanently."""
1590
- applied_annotations = []
1591
- for widget in self.annotation_widgets_by_id.values():
1592
- if widget.data_item.apply_preview_permanently():
1593
- applied_annotations.append(widget.annotation)
1594
- return applied_annotations
1595
-
1596
-
1597
56
  # ----------------------------------------------------------------------------------------------------------------------
1598
57
  # ExplorerWindow
1599
58
  # ----------------------------------------------------------------------------------------------------------------------
@@ -1610,7 +69,7 @@ class ExplorerWindow(QMainWindow):
1610
69
 
1611
70
  self.device = main_window.device
1612
71
  self.loaded_model = None
1613
- self.loaded_model_imgsz = 128
72
+ self.imgsz = 128
1614
73
 
1615
74
  self.feature_store = FeatureStore()
1616
75
 
@@ -2357,15 +816,20 @@ class ExplorerWindow(QMainWindow):
2357
816
  model_name, feature_mode = model_info
2358
817
 
2359
818
  # Load the model
2360
- model, imgsz = self._load_yolo_model(model_name, feature_mode)
819
+ model = self._load_yolo_model(model_name, feature_mode)
2361
820
  if model is None:
2362
821
  QMessageBox.warning(self,
2363
822
  "Model Load Error",
2364
823
  f"Could not load YOLO model '{model_name}'.")
2365
824
  return None
2366
825
 
2367
- # Prepare images from data items
2368
- image_list, valid_data_items = self._prepare_images_from_data_items(data_items)
826
+ # Prepare images from data items with proper resizing
827
+ image_list, valid_data_items = self._prepare_images_from_data_items(
828
+ data_items,
829
+ format='numpy',
830
+ target_size=(self.imgsz, self.imgsz)
831
+ )
832
+
2369
833
  if not image_list:
2370
834
  return None
2371
835
 
@@ -2373,7 +837,7 @@ class ExplorerWindow(QMainWindow):
2373
837
  # We need probabilities for uncertainty analysis, so we always use predict
2374
838
  results = model.predict(image_list,
2375
839
  stream=False, # Use batch processing for uncertainty
2376
- imgsz=imgsz,
840
+ imgsz=self.imgsz,
2377
841
  half=True,
2378
842
  device=self.device,
2379
843
  verbose=False)
@@ -2430,7 +894,7 @@ class ExplorerWindow(QMainWindow):
2430
894
  feature_mode (str): Mode for feature extraction ("Embed Features" or "Predictions")
2431
895
 
2432
896
  Returns:
2433
- tuple: (model, image_size) or (None, None) if loading fails
897
+ ultralytics.yolo.engine.model.Model: The loaded YOLO model object, or None if loading fails.
2434
898
  """
2435
899
  current_run_key = (model_name, feature_mode)
2436
900
 
@@ -2454,7 +918,7 @@ class ExplorerWindow(QMainWindow):
2454
918
  # On failure, reset the model cache
2455
919
  self.loaded_model = None
2456
920
  self.current_feature_generating_model = None
2457
- return None, None
921
+ return None
2458
922
 
2459
923
  # Update the cache key to the new successful combination
2460
924
  self.current_feature_generating_model = current_run_key
@@ -2462,31 +926,109 @@ class ExplorerWindow(QMainWindow):
2462
926
 
2463
927
  # Get the imgsz, but if it's larger than 128, default to 128
2464
928
  imgsz = min(getattr(model.model.args, 'imgsz', 128), 128)
2465
- self.loaded_model_imgsz = imgsz
929
+ self.imgsz = imgsz
2466
930
 
2467
931
  # Warm up the model
2468
932
  dummy_image = np.zeros((imgsz, imgsz, 3), dtype=np.uint8)
2469
933
  model.predict(dummy_image, imgsz=imgsz, half=True, device=self.device, verbose=False)
2470
934
 
2471
- return model, self.loaded_model_imgsz
935
+ return model
936
+
937
+ except Exception as e:
938
+ QMessageBox.critical(self,
939
+ "Model Load Error",
940
+ f"Could not load the YOLO model '{model_name}'.\n\nError: {e}")
941
+
942
+ # On failure, reset the model cache
943
+ self.loaded_model = None
944
+ self.current_feature_generating_model = None
945
+ return None
946
+
947
+ # Model already loaded and cached, return it and its image size
948
+ return self.loaded_model
949
+
950
+ def _load_transformer_model(self, model_name):
951
+ """
952
+ Helper function to load a transformer model and cache it.
953
+
954
+ Args:
955
+ model_name (str): Name of the transformer model to use (e.g., "google/vit-base-patch16-224")
956
+
957
+ Returns:
958
+ transformers.pipelines.base.Pipeline: The feature extractor pipeline object, or None if loading fails.
959
+ """
960
+ current_run_key = (model_name, "transformer")
961
+
962
+ # Force a reload if the model path has changed
963
+ if current_run_key != self.current_feature_generating_model or self.loaded_model is None:
964
+ print(f"Model changed. Loading transformer model {model_name}...")
965
+
966
+ try:
967
+ # Lazy import to avoid unnecessary dependencies
968
+ from transformers import pipeline
969
+ from huggingface_hub import snapshot_download
970
+
971
+ # Pre-download the model to show progress if it's not cached
972
+ model_path = snapshot_download(repo_id=model_name,
973
+ allow_patterns=["*.json", "*.bin", "*.safetensors", "*.txt"])
974
+
975
+ # Convert device string to appropriate format for transformers pipeline
976
+ if self.device.startswith('cuda'):
977
+ # Extract device number from 'cuda:0' format for CUDA GPUs
978
+ device_num = int(self.device.split(':')[-1]) if ':' in self.device else 0
979
+ elif self.device == 'mps':
980
+ # MPS (Metal Performance Shaders) - Apple's GPU acceleration for macOS
981
+ device_num = 'mps'
982
+ else:
983
+ # Default to CPU for any other device string
984
+ device_num = -1
985
+
986
+ # Initialize the feature extractor pipeline with local model path
987
+ feature_extractor = pipeline(
988
+ model=model_path,
989
+ task="image-feature-extraction",
990
+ device=device_num,
991
+ )
992
+ try:
993
+ image_processor = feature_extractor.image_processor
994
+ if hasattr(image_processor, 'size'):
995
+ # For older transformers versions
996
+ self.imgsz = image_processor.size['height']
997
+ else:
998
+ # For newer transformers versions
999
+ self.imgsz = image_processor.crop_size['height']
1000
+
1001
+ except Exception:
1002
+ self.imgsz = 128
1003
+
1004
+ # Update the cache key to the new successful combination
1005
+ self.current_feature_generating_model = current_run_key
1006
+ self.loaded_model = feature_extractor
1007
+
1008
+ return feature_extractor
2472
1009
 
2473
1010
  except Exception as e:
2474
- print(f"ERROR: Could not load YOLO model '{model_name}': {e}")
1011
+ QMessageBox.critical(self,
1012
+ "Model Load Error",
1013
+ f"Could not load the transformer model '{model_name}'.\n\nError: {e}")
1014
+
2475
1015
  # On failure, reset the model cache
2476
1016
  self.loaded_model = None
2477
1017
  self.current_feature_generating_model = None
2478
- return None, None
1018
+ return None
2479
1019
 
2480
1020
  # Model already loaded and cached, return it and its image size
2481
- return self.loaded_model, self.loaded_model_imgsz
1021
+ return self.loaded_model
2482
1022
 
2483
- def _prepare_images_from_data_items(self, data_items, progress_bar=None):
1023
+ def _prepare_images_from_data_items(self, data_items, progress_bar=None, format='numpy', target_size=None):
2484
1024
  """
2485
1025
  Prepare images from data items for model prediction.
2486
1026
 
2487
1027
  Args:
2488
1028
  data_items (list): List of AnnotationDataItem objects
2489
1029
  progress_bar (ProgressBar, optional): Progress bar for UI updates
1030
+ format (str, optional): Output format, either 'numpy' or 'pil'. Default is 'numpy'.
1031
+ target_size (tuple, optional): Target size for resizing (width, height). If None, no resizing is performed.
2490
1032
 
2491
1033
  Returns:
2492
1034
  tuple: (image_list, valid_data_items)
@@ -2499,7 +1041,20 @@ class ExplorerWindow(QMainWindow):
2499
1041
  for item in data_items:
2500
1042
  pixmap = item.annotation.get_cropped_image()
2501
1043
  if pixmap and not pixmap.isNull():
2502
- image_list.append(pixmap_to_numpy(pixmap))
1044
+ # Always convert to PIL first for easier resizing
1045
+ pil_img = pixmap_to_pil(pixmap)
1046
+
1047
+ # Resize if target size is specified
1048
+ if target_size and isinstance(target_size, (tuple, list)) and len(target_size) == 2:
1049
+ pil_img = pil_img.resize(target_size, resample=2) # 2 = PIL.Image.BILINEAR
1050
+
1051
+ # Convert to the requested format
1052
+ if format.lower() == 'pil':
1053
+ image_list.append(pil_img)
1054
+ else: # Convert to numpy
1055
+ img_array = np.array(pil_img)
1056
+ image_list.append(img_array)
1057
+
2503
1058
  valid_data_items.append(item)
2504
1059
 
2505
1060
  if progress_bar:
@@ -2539,25 +1094,35 @@ class ExplorerWindow(QMainWindow):
2539
1094
  features_list.append(embedding)
2540
1095
 
2541
1096
  elif hasattr(result, 'probs') and result.probs is not None:
2542
- probs = result.probs.data.cpu().numpy().squeeze()
2543
- features_list.append(probs)
2544
- probabilities_dict[ann_id] = probs
2545
-
2546
- # Store the probabilities directly on the data item for confidence sorting
2547
- item.prediction_probabilities = probs
2548
-
2549
- # Format and store prediction details for tooltips
2550
- if len(probs) > 0:
2551
- # Get top 5 predictions
2552
- top_indices = probs.argsort()[::-1][:5]
2553
- top_probs = probs[top_indices]
2554
-
2555
- formatted_preds = ["<b>Top Predictions:</b>"]
2556
- for idx, prob in zip(top_indices, top_probs):
2557
- class_name = class_names.get(int(idx), f"Class {idx}")
2558
- formatted_preds.append(f"{class_name}: {prob*100:.1f}%")
2559
-
2560
- item.prediction_details = "<br>".join(formatted_preds)
1097
+ try:
1098
+ probs = result.probs.data.cpu().numpy().squeeze()
1099
+ features_list.append(probs)
1100
+ probabilities_dict[ann_id] = probs
1101
+
1102
+ # Store the probabilities directly on the data item for confidence sorting
1103
+ item.prediction_probabilities = probs
1104
+
1105
+ # Format and store prediction details for tooltips
1106
+ # This check will fail with a TypeError if probs is a scalar (unsized)
1107
+ if len(probs) > 0:
1108
+ # Get top 5 predictions
1109
+ top_indices = probs.argsort()[::-1][:5]
1110
+ top_probs = probs[top_indices]
1111
+
1112
+ formatted_preds = ["<b>Top Predictions:</b>"]
1113
+ for idx, prob in zip(top_indices, top_probs):
1114
+ class_name = class_names.get(int(idx), f"Class {idx}")
1115
+ formatted_preds.append(f"{class_name}: {prob*100:.1f}%")
1116
+
1117
+ item.prediction_details = "<br>".join(formatted_preds)
1118
+
1119
+ except TypeError:
1120
+ # This error is raised if len(probs) fails on a scalar value.
1121
+ raise TypeError(
1122
+ "The selected model is not compatible with 'Predictions' mode. "
1123
+ "Its output does not appear to be a list of class probabilities. "
1124
+ "Try using 'Embed Features' mode instead."
1125
+ )
2561
1126
  else:
2562
1127
  raise TypeError(
2563
1128
  "The 'Predictions' feature mode requires a classification model "
@@ -2660,19 +1225,25 @@ class ExplorerWindow(QMainWindow):
2660
1225
  model_name, feature_mode = model_info
2661
1226
 
2662
1227
  # Load the model
2663
- model, imgsz = self._load_yolo_model(model_name, feature_mode)
1228
+ model = self._load_yolo_model(model_name, feature_mode)
2664
1229
  if model is None:
2665
1230
  return np.array([]), []
2666
1231
 
2667
- # Prepare images from data items
2668
- image_list, valid_data_items = self._prepare_images_from_data_items(data_items, progress_bar)
1232
+ # Prepare images from data items with proper resizing
1233
+ image_list, valid_data_items = self._prepare_images_from_data_items(
1234
+ data_items,
1235
+ progress_bar,
1236
+ format='numpy',
1237
+ target_size=(self.imgsz, self.imgsz)
1238
+ )
1239
+
2669
1240
  if not valid_data_items:
2670
1241
  return np.array([]), []
2671
1242
 
2672
1243
  # Set up prediction parameters
2673
1244
  kwargs = {
2674
1245
  'stream': True,
2675
- 'imgsz': imgsz,
1246
+ 'imgsz': self.imgsz,
2676
1247
  'half': True,
2677
1248
  'device': self.device,
2678
1249
  'verbose': False
@@ -2689,13 +1260,106 @@ class ExplorerWindow(QMainWindow):
2689
1260
  progress_bar.start_progress(len(valid_data_items))
2690
1261
 
2691
1262
  try:
2692
- features_list, _ = self._process_model_results(results_generator,
2693
- valid_data_items,
1263
+ features_list, _ = self._process_model_results(results_generator,
1264
+ valid_data_items,
2694
1265
  feature_mode,
2695
1266
  progress_bar=progress_bar)
2696
-
1267
+
2697
1268
  return np.array(features_list), valid_data_items
1269
+
1270
+ except TypeError as e:
1271
+ QMessageBox.warning(self, "Model Incompatibility Error", str(e))
1272
+ return np.array([]), [] # Return empty results to safely stop the pipeline
1273
+
1274
+ finally:
1275
+ if torch.cuda.is_available():
1276
+ torch.cuda.empty_cache()
1277
+
1278
+ def _extract_transformer_features(self, data_items, model_name, progress_bar=None):
1279
+ """
1280
+ Extract features using transformer models from HuggingFace.
1281
+
1282
+ Args:
1283
+ data_items: List of AnnotationDataItem objects
1284
+ model_name: Name of the transformer model to use
1285
+ progress_bar: Optional progress bar for tracking
1286
+
1287
+ Returns:
1288
+ tuple: (features array, valid data items list)
1289
+ """
1290
+ try:
1291
+ if progress_bar:
1292
+ progress_bar.set_busy_mode(f"Loading model {model_name}...")
1293
+
1294
+ # Load the model with caching support
1295
+ feature_extractor = self._load_transformer_model(model_name)
1296
+
1297
+ if feature_extractor is None:
1298
+ print(f"Failed to load transformer model: {model_name}")
1299
+ return np.array([]), []
1300
+
1301
+ # Prepare images from data items - get PIL images directly with proper sizing
1302
+ image_list, valid_data_items = self._prepare_images_from_data_items(
1303
+ data_items,
1304
+ progress_bar,
1305
+ format='pil',
1306
+ target_size=(self.imgsz, self.imgsz)
1307
+ )
1308
+
1309
+ if not image_list:
1310
+ return np.array([]), []
1311
+
1312
+ if progress_bar:
1313
+ progress_bar.set_title("Extracting features...")
1314
+ progress_bar.start_progress(len(valid_data_items))
2698
1315
 
1316
+ features_list = []
1317
+ valid_items = []
1318
+
1319
+ # Process images in batches or individually
1320
+ for i, image in enumerate(image_list):
1321
+ try:
1322
+ # Extract features
1323
+ features = feature_extractor(image)
1324
+
1325
+ # Handle different output formats from transformers
1326
+ if isinstance(features, list):
1327
+ feature_tensor = features[0] if len(features) > 0 else features
1328
+ else:
1329
+ feature_tensor = features
1330
+
1331
+ # Convert to numpy array, handling GPU tensors properly
1332
+ if hasattr(feature_tensor, 'cpu'):
1333
+ # Move tensor to CPU before converting to numpy
1334
+ feature_vector = feature_tensor.cpu().numpy().flatten()
1335
+ else:
1336
+ # Already numpy array or other CPU-compatible format
1337
+ feature_vector = np.array(feature_tensor).flatten()
1338
+
1339
+ features_list.append(feature_vector)
1340
+ valid_items.append(valid_data_items[i])
1341
+
1342
+ except Exception as e:
1343
+ print(f"Error extracting features for item {i}: {e}")
1344
+
1345
+ finally:
1346
+ if progress_bar:
1347
+ progress_bar.update_progress()
1348
+
1349
+ # Make sure we have consistent feature dimensions
1350
+ if features_list:
1351
+ features_array = np.array(features_list)
1352
+ return features_array, valid_items
1353
+ else:
1354
+ return np.array([]), []
1355
+
1356
+ except Exception as e:
1357
+ QMessageBox.warning(self,
1358
+ "Feature Extraction Error",
1359
+ f"An error occurred during transformer feature extraction.\n\nError: {e}")
1360
+
1361
+ return np.array([]), []
1362
+
2699
1363
  finally:
2700
1364
  if torch.cuda.is_available():
2701
1365
  torch.cuda.empty_cache()
@@ -2711,11 +1375,17 @@ class ExplorerWindow(QMainWindow):
2711
1375
  if not model_name:
2712
1376
  return np.array([]), []
2713
1377
 
1378
+ # Check if it's Color Features first
2714
1379
  if model_name == "Color Features":
2715
1380
  return self._extract_color_features(data_items, progress_bar=progress_bar)
2716
1381
 
2717
- elif ".pt" in model_name:
1382
+ # Then check if it's a YOLO model (file path with .pt)
1383
+ elif is_yolo_model(model_name):
2718
1384
  return self._extract_yolo_features(data_items, (model_name, feature_mode), progress_bar=progress_bar)
1385
+
1386
+ # Finally check if it's a transformer model using the shared utility function
1387
+ elif is_transformer_model(model_name):
1388
+ return self._extract_transformer_features(data_items, model_name, progress_bar=progress_bar)
2719
1389
 
2720
1390
  return np.array([]), []
2721
1391
 
@@ -2782,7 +1452,10 @@ class ExplorerWindow(QMainWindow):
2782
1452
  return reducer.fit_transform(features_scaled)
2783
1453
 
2784
1454
  except Exception as e:
2785
- print(f"Error during {technique} dimensionality reduction: {e}")
1455
+ QMessageBox.warning(self,
1456
+ "Embedding Error",
1457
+ f"An error occurred during dimensionality reduction with {technique}.\n\nError: {e}")
1458
+
2786
1459
  return None
2787
1460
 
2788
1461
  def _update_data_items_with_embedding(self, data_items, embedded_features):
@@ -2991,7 +1664,9 @@ class ExplorerWindow(QMainWindow):
2991
1664
  self.annotation_window.load_annotations()
2992
1665
 
2993
1666
  except Exception as e:
2994
- print(f"Error during item deletion: {e}")
1667
+ QMessageBox.warning(self,
1668
+ "Deletion Error",
1669
+ f"An error occurred while deleting annotations.\n\nError: {e}")
2995
1670
  finally:
2996
1671
  QApplication.restoreOverrideCursor()
2997
1672
 
@@ -3066,7 +1741,9 @@ class ExplorerWindow(QMainWindow):
3066
1741
  print("Applied changes successfully.")
3067
1742
 
3068
1743
  except Exception as e:
3069
- print(f"Error applying modifications: {e}")
1744
+ QMessageBox.warning(self,
1745
+ "Apply Error",
1746
+ f"An error occurred while applying changes.\n\nError: {e}")
3070
1747
  finally:
3071
1748
  QApplication.restoreOverrideCursor()
3072
1749