coralnet-toolbox 0.0.67__py2.py3-none-any.whl → 0.0.69__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,32 +1,33 @@
1
+ import warnings
2
+
1
3
  import os
4
+
2
5
  import numpy as np
3
6
  import torch
4
- import warnings
5
7
 
6
8
  from ultralytics import YOLO
7
9
 
8
- from coralnet_toolbox.MachineLearning.Community.cfg import get_available_configs
9
-
10
10
  from coralnet_toolbox.Icons import get_icon
11
11
  from coralnet_toolbox.utilities import pixmap_to_numpy
12
12
 
13
- from PyQt5.QtGui import QIcon, QPen, QColor, QPainter, QImage, QBrush, QPainterPath, QPolygonF, QMouseEvent
14
- from PyQt5.QtCore import Qt, QTimer, QSize, QRect, QRectF, QPointF, pyqtSignal, QSignalBlocker, pyqtSlot
15
-
13
+ from PyQt5.QtGui import QIcon, QPen, QColor, QPainter, QBrush, QPainterPath, QMouseEvent
14
+ from PyQt5.QtCore import Qt, QTimer, QRect, QRectF, QPointF, pyqtSignal, QSignalBlocker, pyqtSlot
16
15
  from PyQt5.QtWidgets import (QVBoxLayout, QHBoxLayout, QGraphicsView, QScrollArea,
17
- QGraphicsScene, QPushButton, QComboBox, QLabel, QWidget, QGridLayout,
18
- QMainWindow, QSplitter, QGroupBox, QFormLayout,
19
- QSpinBox, QGraphicsEllipseItem, QGraphicsItem, QSlider,
20
- QListWidget, QDoubleSpinBox, QApplication, QStyle,
21
- QGraphicsRectItem, QRubberBand, QStyleOptionGraphicsItem,
22
- QTabWidget, QLineEdit, QFileDialog)
23
-
24
- from .QtAnnotationDataItem import AnnotationDataItem
25
- from .QtEmbeddingPointItem import EmbeddingPointItem
26
- from .QtAnnotationImageWidget import AnnotationImageWidget
27
- from .QtSettingsWidgets import ModelSettingsWidget
28
- from .QtSettingsWidgets import EmbeddingSettingsWidget
29
- from .QtSettingsWidgets import AnnotationSettingsWidget
16
+ QGraphicsScene, QPushButton, QComboBox, QLabel, QWidget,
17
+ QMainWindow, QSplitter, QGroupBox, QSlider, QMessageBox,
18
+ QApplication, QGraphicsRectItem, QRubberBand, QMenu,
19
+ QWidgetAction, QToolButton, QAction)
20
+
21
+ from coralnet_toolbox.Explorer.QtFeatureStore import FeatureStore
22
+ from coralnet_toolbox.Explorer.QtDataItem import AnnotationDataItem
23
+ from coralnet_toolbox.Explorer.QtDataItem import EmbeddingPointItem
24
+ from coralnet_toolbox.Explorer.QtDataItem import AnnotationImageWidget
25
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import ModelSettingsWidget
26
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import SimilaritySettingsWidget
27
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import UncertaintySettingsWidget
28
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import MislabelSettingsWidget
29
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import EmbeddingSettingsWidget
30
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import AnnotationSettingsWidget
30
31
 
31
32
  from coralnet_toolbox.QtProgressBar import ProgressBar
32
33
 
@@ -34,13 +35,13 @@ try:
34
35
  from sklearn.preprocessing import StandardScaler
35
36
  from sklearn.decomposition import PCA
36
37
  from sklearn.manifold import TSNE
37
- from umap import UMAP
38
+ from umap import UMAP
38
39
  except ImportError:
39
40
  print("Warning: sklearn or umap not installed. Some features may be unavailable.")
40
41
  StandardScaler = None
41
42
  PCA = None
42
43
  TSNE = None
43
- UMAP = None
44
+ UMAP = None
44
45
 
45
46
 
46
47
  warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -50,60 +51,56 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
50
51
  # Constants
51
52
  # ----------------------------------------------------------------------------------------------------------------------
52
53
 
53
-
54
- POINT_SIZE = 15
55
54
  POINT_WIDTH = 3
56
55
 
57
-
58
56
  # ----------------------------------------------------------------------------------------------------------------------
59
57
  # Viewers
60
58
  # ----------------------------------------------------------------------------------------------------------------------
61
59
 
62
60
 
63
- class EmbeddingViewer(QWidget): # Change inheritance to QWidget
64
- """Custom QGraphicsView for interactive embedding visualization with zooming, panning, and selection."""
65
-
66
- # Define signal to report selection changes
67
- selection_changed = pyqtSignal(list) # list of all currently selected annotation IDs
68
- reset_view_requested = pyqtSignal() # Signal to reset the view to fit all points
69
-
61
+ class EmbeddingViewer(QWidget):
62
+ """Custom QGraphicsView for interactive embedding visualization with an isolate mode."""
63
+ selection_changed = pyqtSignal(list)
64
+ reset_view_requested = pyqtSignal()
65
+ find_mislabels_requested = pyqtSignal()
66
+ mislabel_parameters_changed = pyqtSignal(dict)
67
+ find_uncertain_requested = pyqtSignal()
68
+ uncertainty_parameters_changed = pyqtSignal(dict)
69
+
70
70
  def __init__(self, parent=None):
71
- # Create the graphics scene first
72
- self.graphics_scene = QGraphicsScene()
73
- self.graphics_scene.setSceneRect(-5000, -5000, 10000, 10000)
74
-
75
- # Initialize as a QWidget
71
+ """Initialize the EmbeddingViewer widget."""
76
72
  super(EmbeddingViewer, self).__init__(parent)
77
73
  self.explorer_window = parent
78
-
79
- # Create the actual graphics view
74
+
75
+ self.graphics_scene = QGraphicsScene()
76
+ self.graphics_scene.setSceneRect(-5000, -5000, 10000, 10000)
77
+
80
78
  self.graphics_view = QGraphicsView(self.graphics_scene)
81
79
  self.graphics_view.setRenderHint(QPainter.Antialiasing)
82
80
  self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
83
81
  self.graphics_view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
84
82
  self.graphics_view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
85
83
  self.graphics_view.setMinimumHeight(200)
86
-
87
- # Custom rubber_band state variables
84
+
88
85
  self.rubber_band = None
89
86
  self.rubber_band_origin = QPointF()
90
87
  self.selection_at_press = None
88
+ self.points_by_id = {}
89
+ self.previous_selection_ids = set()
90
+
91
+ # State for isolate mode
92
+ self.isolated_mode = False
93
+ self.isolated_points = set()
91
94
 
92
- self.points_by_id = {} # Map annotation ID to embedding point
93
- self.previous_selection_ids = set() # Track previous selection to detect changes
94
-
95
+ self.is_uncertainty_analysis_available = False
96
+
95
97
  self.animation_offset = 0
96
98
  self.animation_timer = QTimer()
97
99
  self.animation_timer.timeout.connect(self.animate_selection)
98
100
  self.animation_timer.setInterval(100)
99
-
100
- # Connect the scene's selection signal
101
+
101
102
  self.graphics_scene.selectionChanged.connect(self.on_selection_changed)
102
-
103
- # Setup the UI with header
104
103
  self.setup_ui()
105
-
106
- # Connect mouse events to the graphics view
107
104
  self.graphics_view.mousePressEvent = self.mousePressEvent
108
105
  self.graphics_view.mouseDoubleClickEvent = self.mouseDoubleClickEvent
109
106
  self.graphics_view.mouseReleaseEvent = self.mouseReleaseEvent
@@ -111,36 +108,164 @@ class EmbeddingViewer(QWidget): # Change inheritance to QWidget
111
108
  self.graphics_view.wheelEvent = self.wheelEvent
112
109
 
113
110
  def setup_ui(self):
114
- """Set up the UI with header layout and graphics view."""
111
+ """Set up the UI with toolbar layout and graphics view."""
115
112
  layout = QVBoxLayout(self)
116
113
  layout.setContentsMargins(0, 0, 0, 0)
114
+
115
+ toolbar_layout = QHBoxLayout()
116
+
117
+ # Isolate/Show All buttons
118
+ self.isolate_button = QPushButton("Isolate Selection")
119
+ self.isolate_button.setToolTip("Hide all non-selected points")
120
+ self.isolate_button.clicked.connect(self.isolate_selection)
121
+ toolbar_layout.addWidget(self.isolate_button)
122
+
123
+ self.show_all_button = QPushButton("Show All")
124
+ self.show_all_button.setToolTip("Show all embedding points")
125
+ self.show_all_button.clicked.connect(self.show_all_points)
126
+ toolbar_layout.addWidget(self.show_all_button)
117
127
 
118
- # Header layout
119
- header_layout = QHBoxLayout()
128
+ toolbar_layout.addWidget(self._create_separator())
129
+
130
+ # Create a QToolButton to have both a primary action and a dropdown menu
131
+ self.find_mislabels_button = QToolButton()
132
+ self.find_mislabels_button.setText("Find Potential Mislabels")
133
+ self.find_mislabels_button.setPopupMode(QToolButton.MenuButtonPopup) # Key change for split-button style
134
+ self.find_mislabels_button.setToolButtonStyle(Qt.ToolButtonTextOnly)
135
+ self.find_mislabels_button.setStyleSheet(
136
+ "QToolButton::menu-indicator {"
137
+ " subcontrol-position: right center;"
138
+ " subcontrol-origin: padding;"
139
+ " left: -4px;"
140
+ " }"
141
+ )
142
+
143
+ # The primary action (clicking the button) triggers the analysis
144
+ run_analysis_action = QAction("Find Potential Mislabels", self)
145
+ run_analysis_action.triggered.connect(self.find_mislabels_requested.emit)
146
+ self.find_mislabels_button.setDefaultAction(run_analysis_action)
147
+
148
+ # The dropdown menu contains the settings
149
+ mislabel_settings_widget = MislabelSettingsWidget()
150
+ settings_menu = QMenu(self)
151
+ widget_action = QWidgetAction(settings_menu)
152
+ widget_action.setDefaultWidget(mislabel_settings_widget)
153
+ settings_menu.addAction(widget_action)
154
+ self.find_mislabels_button.setMenu(settings_menu)
155
+
156
+ # Connect the widget's signal to the viewer's signal
157
+ mislabel_settings_widget.parameters_changed.connect(self.mislabel_parameters_changed.emit)
158
+ toolbar_layout.addWidget(self.find_mislabels_button)
159
+
160
+ # Create a QToolButton for uncertainty analysis
161
+ self.find_uncertain_button = QToolButton()
162
+ self.find_uncertain_button.setText("Review Uncertain")
163
+ self.find_uncertain_button.setToolTip(
164
+ "Find annotations where the model is least confident.\n"
165
+ "Requires a .pt classification model and 'Predictions' mode."
166
+ )
167
+ self.find_uncertain_button.setPopupMode(QToolButton.MenuButtonPopup)
168
+ self.find_uncertain_button.setToolButtonStyle(Qt.ToolButtonTextOnly)
169
+ self.find_uncertain_button.setStyleSheet(
170
+ "QToolButton::menu-indicator { "
171
+ "subcontrol-position: right center; "
172
+ "subcontrol-origin: padding; "
173
+ "left: -4px; }"
174
+ )
120
175
 
121
- # Home button
122
- self.home_button = QPushButton("Home")
123
- self.home_button.setToolTip("Reset view to fit all points")
124
- self.home_button.clicked.connect(self.reset_view)
125
- header_layout.addWidget(self.home_button)
176
+ run_uncertainty_action = QAction("Review Uncertain", self)
177
+ run_uncertainty_action.triggered.connect(self.find_uncertain_requested.emit)
178
+ self.find_uncertain_button.setDefaultAction(run_uncertainty_action)
179
+
180
+ uncertainty_settings_widget = UncertaintySettingsWidget()
181
+ uncertainty_menu = QMenu(self)
182
+ uncertainty_widget_action = QWidgetAction(uncertainty_menu)
183
+ uncertainty_widget_action.setDefaultWidget(uncertainty_settings_widget)
184
+ uncertainty_menu.addAction(uncertainty_widget_action)
185
+ self.find_uncertain_button.setMenu(uncertainty_menu)
126
186
 
127
- # Add stretch to push future controls to the right if needed
128
- header_layout.addStretch()
187
+ uncertainty_settings_widget.parameters_changed.connect(self.uncertainty_parameters_changed.emit)
188
+ toolbar_layout.addWidget(self.find_uncertain_button)
189
+
190
+ toolbar_layout.addStretch()
129
191
 
130
- layout.addLayout(header_layout)
192
+ # Home button to reset view
193
+ self.home_button = QPushButton()
194
+ self.home_button.setIcon(get_icon("home.png"))
195
+ self.home_button.setToolTip("Reset view to fit all points")
196
+ self.home_button.clicked.connect(self.reset_view)
197
+ toolbar_layout.addWidget(self.home_button)
131
198
 
132
- # Add the graphics view
199
+ layout.addLayout(toolbar_layout)
133
200
  layout.addWidget(self.graphics_view)
134
- # Add a placeholder label when no embedding is available
201
+
135
202
  self.placeholder_label = QLabel(
136
203
  "No embedding data available.\nPress 'Apply Embedding' to generate visualization."
137
204
  )
138
205
  self.placeholder_label.setAlignment(Qt.AlignCenter)
139
206
  self.placeholder_label.setStyleSheet("color: gray; font-size: 14px;")
140
207
  layout.addWidget(self.placeholder_label)
141
-
142
- # Initially show placeholder
208
+
143
209
  self.show_placeholder()
210
+ self._update_toolbar_state()
211
+
212
+ def _create_separator(self):
213
+ """Creates a vertical separator for the toolbar."""
214
+ separator = QLabel("|")
215
+ separator.setStyleSheet("color: gray; margin: 0 5px;")
216
+ return separator
217
+
218
+ @pyqtSlot()
219
+ def isolate_selection(self):
220
+ """Hides all points that are not currently selected."""
221
+ selected_items = self.graphics_scene.selectedItems()
222
+ if not selected_items or self.isolated_mode:
223
+ return
224
+
225
+ self.isolated_points = set(selected_items)
226
+ self.graphics_view.setUpdatesEnabled(False)
227
+ try:
228
+ for point in self.points_by_id.values():
229
+ if point not in self.isolated_points:
230
+ point.hide()
231
+ self.isolated_mode = True
232
+ finally:
233
+ self.graphics_view.setUpdatesEnabled(True)
234
+
235
+ self._update_toolbar_state()
236
+
237
+ @pyqtSlot()
238
+ def show_all_points(self):
239
+ """Shows all embedding points, exiting isolated mode."""
240
+ if not self.isolated_mode:
241
+ return
242
+
243
+ self.isolated_mode = False
244
+ self.isolated_points.clear()
245
+ self.graphics_view.setUpdatesEnabled(False)
246
+ try:
247
+ for point in self.points_by_id.values():
248
+ point.show()
249
+ finally:
250
+ self.graphics_view.setUpdatesEnabled(True)
251
+
252
+ self._update_toolbar_state()
253
+
254
+ def _update_toolbar_state(self):
255
+ """Updates toolbar buttons based on selection and isolation mode."""
256
+ selection_exists = bool(self.graphics_scene.selectedItems())
257
+ points_exist = bool(self.points_by_id)
258
+
259
+ self.find_mislabels_button.setEnabled(points_exist)
260
+ self.find_uncertain_button.setEnabled(points_exist and self.is_uncertainty_analysis_available)
261
+
262
+ if self.isolated_mode:
263
+ self.isolate_button.hide()
264
+ self.show_all_button.show()
265
+ else:
266
+ self.isolate_button.show()
267
+ self.show_all_button.hide()
268
+ self.isolate_button.setEnabled(selection_exists)
144
269
 
145
270
  def reset_view(self):
146
271
  """Reset the view to fit all embedding points."""
@@ -151,50 +276,87 @@ class EmbeddingViewer(QWidget): # Change inheritance to QWidget
151
276
  self.graphics_view.setVisible(False)
152
277
  self.placeholder_label.setVisible(True)
153
278
  self.home_button.setEnabled(False)
279
+ self.find_mislabels_button.setEnabled(False)
280
+ self.find_uncertain_button.setEnabled(False)
281
+
282
+ self.isolate_button.show()
283
+ self.isolate_button.setEnabled(False)
284
+ self.show_all_button.hide()
154
285
 
155
286
  def show_embedding(self):
156
287
  """Show the graphics view and hide the placeholder message."""
157
288
  self.graphics_view.setVisible(True)
158
289
  self.placeholder_label.setVisible(False)
159
290
  self.home_button.setEnabled(True)
291
+ self._update_toolbar_state()
160
292
 
161
293
  # Delegate graphics view methods
162
294
  def setRenderHint(self, hint):
295
+ """Set render hint for the graphics view."""
163
296
  self.graphics_view.setRenderHint(hint)
164
-
297
+
165
298
  def setDragMode(self, mode):
299
+ """Set drag mode for the graphics view."""
166
300
  self.graphics_view.setDragMode(mode)
167
-
301
+
168
302
  def setTransformationAnchor(self, anchor):
303
+ """Set transformation anchor for the graphics view."""
169
304
  self.graphics_view.setTransformationAnchor(anchor)
170
-
305
+
171
306
  def setResizeAnchor(self, anchor):
307
+ """Set resize anchor for the graphics view."""
172
308
  self.graphics_view.setResizeAnchor(anchor)
173
-
309
+
174
310
  def mapToScene(self, point):
311
+ """Map a point to the scene coordinates."""
175
312
  return self.graphics_view.mapToScene(point)
176
-
313
+
177
314
  def scale(self, sx, sy):
315
+ """Scale the graphics view."""
178
316
  self.graphics_view.scale(sx, sy)
179
-
317
+
180
318
  def translate(self, dx, dy):
319
+ """Translate the graphics view."""
181
320
  self.graphics_view.translate(dx, dy)
182
-
321
+
183
322
  def fitInView(self, rect, aspect_ratio):
323
+ """Fit the view to a rectangle with aspect ratio."""
184
324
  self.graphics_view.fitInView(rect, aspect_ratio)
185
325
 
326
+ def keyPressEvent(self, event):
327
+ """Handles key presses for deleting selected points."""
328
+ if event.key() in (Qt.Key_Delete, Qt.Key_Backspace) and event.modifiers() == Qt.ControlModifier:
329
+ selected_items = self.graphics_scene.selectedItems()
330
+ if not selected_items:
331
+ super().keyPressEvent(event)
332
+ return
333
+
334
+ # Extract the central data items from the selected graphics points
335
+ data_items_to_delete = [
336
+ item.data_item for item in selected_items if isinstance(item, EmbeddingPointItem)
337
+ ]
338
+
339
+ # Delegate the actual deletion to the main ExplorerWindow
340
+ if data_items_to_delete:
341
+ self.explorer_window.delete_data_items(data_items_to_delete)
342
+
343
+ event.accept()
344
+ else:
345
+ super().keyPressEvent(event)
346
+
186
347
  def mousePressEvent(self, event):
187
348
  """Handle mouse press for selection (point or rubber band) and panning."""
188
349
  if event.button() == Qt.LeftButton and event.modifiers() == Qt.ControlModifier:
189
- # Check if the click is on an existing point
190
350
  item_at_pos = self.graphics_view.itemAt(event.pos())
191
351
  if isinstance(item_at_pos, EmbeddingPointItem):
192
- # If so, toggle its selection state and do nothing else
193
352
  self.graphics_view.setDragMode(QGraphicsView.NoDrag)
194
- item_at_pos.setSelected(not item_at_pos.isSelected())
195
- return # Event handled
353
+ # The viewer (controller) directly changes the state on the data item.
354
+ is_currently_selected = item_at_pos.data_item.is_selected
355
+ item_at_pos.data_item.set_selected(not is_currently_selected)
356
+ item_at_pos.setSelected(not is_currently_selected) # Keep scene selection in sync
357
+ self.on_selection_changed() # Manually trigger update
358
+ return
196
359
 
197
- # If the click was on the background, proceed with rubber band selection
198
360
  self.selection_at_press = set(self.graphics_scene.selectedItems())
199
361
  self.graphics_view.setDragMode(QGraphicsView.NoDrag)
200
362
  self.rubber_band_origin = self.graphics_view.mapToScene(event.pos())
@@ -204,210 +366,186 @@ class EmbeddingViewer(QWidget): # Change inheritance to QWidget
204
366
  self.graphics_scene.addItem(self.rubber_band)
205
367
 
206
368
  elif event.button() == Qt.RightButton:
207
- # Handle panning
208
369
  self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
209
370
  left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
210
371
  QGraphicsView.mousePressEvent(self.graphics_view, left_event)
211
372
  else:
212
- # Handle standard single-item selection
213
373
  self.graphics_view.setDragMode(QGraphicsView.NoDrag)
214
374
  QGraphicsView.mousePressEvent(self.graphics_view, event)
215
-
375
+
216
376
  def mouseDoubleClickEvent(self, event):
217
377
  """Handle double-click to clear selection and reset the main view."""
218
378
  if event.button() == Qt.LeftButton:
219
- # Clear selection if any items are selected
220
379
  if self.graphics_scene.selectedItems():
221
- self.graphics_scene.clearSelection() # This triggers on_selection_changed
222
-
223
- # Signal the main window to revert from isolation mode
380
+ self.graphics_scene.clearSelection()
224
381
  self.reset_view_requested.emit()
225
382
  event.accept()
226
383
  else:
227
- # Pass other double-clicks to the base class
228
384
  super().mouseDoubleClickEvent(event)
229
385
 
230
386
  def mouseMoveEvent(self, event):
231
387
  """Handle mouse move for dynamic selection and panning."""
232
388
  if self.rubber_band:
233
- # Update the rubber band geometry
389
+ # Update the rubber band rectangle as the mouse moves
234
390
  current_pos = self.graphics_view.mapToScene(event.pos())
235
391
  self.rubber_band.setRect(QRectF(self.rubber_band_origin, current_pos).normalized())
236
-
392
+ # Create a selection path from the rubber band rectangle
237
393
  path = QPainterPath()
238
394
  path.addRect(self.rubber_band.rect())
239
-
240
- # Block signals to perform a compound selection operation
395
+ # Block signals to avoid recursive selectionChanged events
241
396
  self.graphics_scene.blockSignals(True)
242
-
243
- # 1. Perform the "fancy" dynamic selection, which replaces the current selection
244
- # with only the items inside the rubber band.
245
397
  self.graphics_scene.setSelectionArea(path)
246
-
247
- # 2. Add back the items that were selected at the start of the drag.
398
+ # Restore selection for items that were already selected at press
248
399
  if self.selection_at_press:
249
400
  for item in self.selection_at_press:
250
401
  item.setSelected(True)
251
-
252
- # Unblock signals and manually trigger our handler to process the final result.
253
402
  self.graphics_scene.blockSignals(False)
403
+ # Manually trigger selection changed logic
254
404
  self.on_selection_changed()
255
-
256
405
  elif event.buttons() == Qt.RightButton:
257
- # Handle right-click panning
258
- left_event = QMouseEvent(event.type(),
259
- event.localPos(),
260
- Qt.LeftButton,
261
- Qt.LeftButton,
262
- event.modifiers())
406
+ # Forward right-drag as left-drag for panning
407
+ left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
263
408
  QGraphicsView.mouseMoveEvent(self.graphics_view, left_event)
264
409
  else:
410
+ # Default mouse move handling
265
411
  QGraphicsView.mouseMoveEvent(self.graphics_view, event)
266
412
 
267
413
  def mouseReleaseEvent(self, event):
268
414
  """Handle mouse release to finalize the action and clean up."""
269
415
  if self.rubber_band:
270
- # Clean up the visual rectangle
271
416
  self.graphics_scene.removeItem(self.rubber_band)
272
417
  self.rubber_band = None
273
-
274
- # Clean up the stored selection state.
275
418
  self.selection_at_press = None
276
-
277
419
  elif event.button() == Qt.RightButton:
278
- # Finalize the pan
279
- left_event = QMouseEvent(event.type(),
280
- event.localPos(),
281
- Qt.LeftButton,
282
- Qt.LeftButton,
283
- event.modifiers())
420
+ left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
284
421
  QGraphicsView.mouseReleaseEvent(self.graphics_view, left_event)
285
422
  self.graphics_view.setDragMode(QGraphicsView.NoDrag)
286
423
  else:
287
- # Finalize a single click
288
424
  QGraphicsView.mouseReleaseEvent(self.graphics_view, event)
289
425
  self.graphics_view.setDragMode(QGraphicsView.NoDrag)
290
-
426
+
291
427
  def wheelEvent(self, event):
292
428
  """Handle mouse wheel for zooming."""
293
429
  zoom_in_factor = 1.25
294
430
  zoom_out_factor = 1 / zoom_in_factor
295
431
 
432
+ # Set anchor points so zoom occurs at mouse position
296
433
  self.graphics_view.setTransformationAnchor(QGraphicsView.NoAnchor)
297
434
  self.graphics_view.setResizeAnchor(QGraphicsView.NoAnchor)
298
435
 
436
+ # Get the scene position before zooming
299
437
  old_pos = self.graphics_view.mapToScene(event.pos())
438
+
439
+ # Determine zoom direction
300
440
  zoom_factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
441
+
442
+ # Apply zoom
301
443
  self.graphics_view.scale(zoom_factor, zoom_factor)
444
+
445
+ # Get the scene position after zooming
302
446
  new_pos = self.graphics_view.mapToScene(event.pos())
303
-
447
+
448
+ # Translate view to keep mouse position stable
304
449
  delta = new_pos - old_pos
305
450
  self.graphics_view.translate(delta.x(), delta.y())
306
451
 
307
452
  def update_embeddings(self, data_items):
308
- """Update the embedding visualization with new data.
453
+ """Update the embedding visualization. Creates an EmbeddingPointItem for
454
+ each AnnotationDataItem and links them."""
455
+ # Reset isolation state when loading new points
456
+ if self.isolated_mode:
457
+ self.show_all_points()
309
458
 
310
- Args:
311
- data_items: List of AnnotationDataItem objects.
312
- """
313
459
  self.clear_points()
314
-
315
460
  for item in data_items:
316
- point = EmbeddingPointItem(0, 0, POINT_SIZE, POINT_SIZE)
317
- point.setPos(item.embedding_x, item.embedding_y)
318
-
319
- # No need to set initial brush - paint() will handle it
320
- point.setPen(QPen(QColor("black"), POINT_WIDTH))
321
-
322
- point.setFlag(QGraphicsItem.ItemIgnoresTransformations)
323
- point.setFlag(QGraphicsItem.ItemIsSelectable)
324
-
325
- # This is the crucial link: store the shared AnnotationDataItem
326
- point.setData(0, item)
327
-
461
+ point = EmbeddingPointItem(item)
328
462
  self.graphics_scene.addItem(point)
329
463
  self.points_by_id[item.annotation.id] = point
330
-
464
+
465
+ # Ensure buttons are in the correct initial state
466
+ self._update_toolbar_state()
467
+
331
468
  def clear_points(self):
332
469
  """Clear all embedding points from the scene."""
470
+ if self.isolated_mode:
471
+ self.show_all_points()
472
+
333
473
  for point in self.points_by_id.values():
334
474
  self.graphics_scene.removeItem(point)
335
475
  self.points_by_id.clear()
476
+ self._update_toolbar_state()
336
477
 
337
478
  def on_selection_changed(self):
338
- """Handle point selection changes and emit a signal to the controller."""
339
- # Check if graphics_scene still exists and is valid
340
- if not self.graphics_scene or not hasattr(self.graphics_scene, 'selectedItems'):
479
+ """
480
+ Handles selection changes in the scene. Updates the central data model
481
+ and emits a signal to notify other parts of the application.
482
+ """
483
+ if not self.graphics_scene:
341
484
  return
342
-
343
485
  try:
344
486
  selected_items = self.graphics_scene.selectedItems()
345
487
  except RuntimeError:
346
- # Scene has been deleted
347
488
  return
348
-
349
- current_selection_ids = {item.data(0).annotation.id for item in selected_items}
350
489
 
351
- # If the selection has actually changed, update the model and emit
490
+ current_selection_ids = {item.data_item.annotation.id for item in selected_items}
491
+
352
492
  if current_selection_ids != self.previous_selection_ids:
353
- # Update the central model (the AnnotationDataItem) for all points
354
493
  for point_id, point in self.points_by_id.items():
355
494
  is_selected = point_id in current_selection_ids
356
- point.data(0).set_selected(is_selected)
495
+ point.data_item.set_selected(is_selected)
357
496
 
358
- # Emit the complete list of currently selected IDs
359
497
  self.selection_changed.emit(list(current_selection_ids))
360
498
  self.previous_selection_ids = current_selection_ids
361
499
 
362
- # Handle local animation - check if animation_timer still exists
363
500
  if hasattr(self, 'animation_timer') and self.animation_timer:
364
501
  self.animation_timer.stop()
365
-
502
+
366
503
  for point in self.points_by_id.values():
367
504
  if not point.isSelected():
368
505
  point.setPen(QPen(QColor("black"), POINT_WIDTH))
369
-
370
506
  if selected_items and hasattr(self, 'animation_timer') and self.animation_timer:
371
507
  self.animation_timer.start()
372
508
 
509
+ # Update button states based on new selection
510
+ self._update_toolbar_state()
511
+
373
512
  def animate_selection(self):
374
- """Animate selected points with marching ants effect using darker versions of point colors."""
375
- # Check if graphics_scene still exists and is valid
376
- if not self.graphics_scene or not hasattr(self.graphics_scene, 'selectedItems'):
513
+ """Animate selected points with a marching ants effect."""
514
+ if not self.graphics_scene:
377
515
  return
378
-
379
516
  try:
380
517
  selected_items = self.graphics_scene.selectedItems()
381
518
  except RuntimeError:
382
- # Scene has been deleted
383
519
  return
384
-
520
+
385
521
  self.animation_offset = (self.animation_offset + 1) % 20
386
-
387
- # This logic remains the same. It applies the custom pen to the selected items.
388
- # Because the items are EmbeddingPointItem, the default selection box won't be drawn.
389
522
  for item in selected_items:
390
- original_color = item.brush().color()
523
+ # Get the color directly from the source of truth
524
+ original_color = item.data_item.effective_color
391
525
  darker_color = original_color.darker(150)
392
-
393
526
  animated_pen = QPen(darker_color, POINT_WIDTH)
394
527
  animated_pen.setStyle(Qt.CustomDashLine)
395
528
  animated_pen.setDashPattern([1, 2])
396
529
  animated_pen.setDashOffset(self.animation_offset)
397
-
398
530
  item.setPen(animated_pen)
399
-
531
+
400
532
  def render_selection_from_ids(self, selected_ids):
401
- """Update the visual selection of points based on a set of IDs from the controller."""
402
- # Block this scene's own selectionChanged signal to prevent an infinite loop
533
+ """
534
+ Updates the visual selection of points based on a set of annotation IDs
535
+ provided by an external controller.
536
+ """
403
537
  blocker = QSignalBlocker(self.graphics_scene)
404
-
538
+
405
539
  for ann_id, point in self.points_by_id.items():
406
- point.setSelected(ann_id in selected_ids)
407
-
408
- self.previous_selection_ids = selected_ids
409
-
410
- # Trigger animation update
540
+ is_selected = ann_id in selected_ids
541
+ # 1. Update the state on the central data item
542
+ point.data_item.set_selected(is_selected)
543
+ # 2. Update the selection state of the graphics item itself
544
+ point.setSelected(is_selected)
545
+
546
+ blocker.unblock()
547
+
548
+ # Manually trigger on_selection_changed to update animation and emit signals
411
549
  self.on_selection_changed()
412
550
 
413
551
  def fit_view_to_points(self):
@@ -415,40 +553,40 @@ class EmbeddingViewer(QWidget): # Change inheritance to QWidget
415
553
  if self.points_by_id:
416
554
  self.graphics_view.fitInView(self.graphics_scene.itemsBoundingRect(), Qt.KeepAspectRatio)
417
555
  else:
418
- # If no points, reset to default view
419
556
  self.graphics_view.fitInView(-2500, -2500, 5000, 5000, Qt.KeepAspectRatio)
420
-
557
+
421
558
 
422
559
  class AnnotationViewer(QScrollArea):
423
560
  """Scrollable grid widget for displaying annotation image crops with selection,
424
- filtering, and isolation support.
425
- """
426
-
427
- # Define signals to report changes to the ExplorerWindow
428
- selection_changed = pyqtSignal(list) # list of changed annotation IDs
429
- preview_changed = pyqtSignal(list) # list of annotation IDs with new previews
430
- reset_view_requested = pyqtSignal() # Signal to reset the view to fit all points
561
+ filtering, and isolation support. Acts as a controller for the widgets."""
562
+ selection_changed = pyqtSignal(list)
563
+ preview_changed = pyqtSignal(list)
564
+ reset_view_requested = pyqtSignal()
565
+ find_similar_requested = pyqtSignal()
431
566
 
432
567
  def __init__(self, parent=None):
568
+ """Initialize the AnnotationViewer widget."""
433
569
  super(AnnotationViewer, self).__init__(parent)
570
+ self.explorer_window = parent
571
+
434
572
  self.annotation_widgets_by_id = {}
435
573
  self.selected_widgets = []
436
574
  self.last_selected_index = -1
437
575
  self.current_widget_size = 96
438
-
439
576
  self.selection_at_press = set()
440
577
  self.rubber_band = None
441
578
  self.rubber_band_origin = None
442
579
  self.drag_threshold = 5
443
580
  self.mouse_pressed_on_widget = False
444
-
445
581
  self.preview_label_assignments = {}
446
582
  self.original_label_assignments = {}
447
-
448
- # New state variables for Isolate/Focus mode
449
583
  self.isolated_mode = False
450
584
  self.isolated_widgets = set()
451
585
 
586
+ # State for new sorting options
587
+ self.active_ordered_ids = []
588
+ self.is_confidence_sort_available = False
589
+
452
590
  self.setup_ui()
453
591
 
454
592
  def setup_ui(self):
@@ -457,49 +595,61 @@ class AnnotationViewer(QScrollArea):
457
595
  self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
458
596
  self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
459
597
 
460
- # Main container and layout
461
598
  main_container = QWidget()
462
599
  main_layout = QVBoxLayout(main_container)
463
600
  main_layout.setContentsMargins(0, 0, 0, 0)
464
- main_layout.setSpacing(4) # Add a little space between toolbar and content
601
+ main_layout.setSpacing(4)
465
602
 
466
- # --- New Toolbar ---
467
603
  toolbar_widget = QWidget()
468
604
  toolbar_layout = QHBoxLayout(toolbar_widget)
469
605
  toolbar_layout.setContentsMargins(4, 2, 4, 2)
470
606
 
471
- # Isolate/Focus controls
472
607
  self.isolate_button = QPushButton("Isolate Selection")
473
- isolate_icon = get_icon("focus.png")
474
- if not isolate_icon.isNull():
475
- self.isolate_button.setIcon(isolate_icon)
476
608
  self.isolate_button.setToolTip("Hide all non-selected annotations")
477
609
  self.isolate_button.clicked.connect(self.isolate_selection)
478
610
  toolbar_layout.addWidget(self.isolate_button)
479
611
 
480
612
  self.show_all_button = QPushButton("Show All")
481
- show_all_icon = get_icon("show_all.png")
482
- if not show_all_icon.isNull():
483
- self.show_all_button.setIcon(show_all_icon)
484
613
  self.show_all_button.setToolTip("Show all filtered annotations")
485
614
  self.show_all_button.clicked.connect(self.show_all_annotations)
486
615
  toolbar_layout.addWidget(self.show_all_button)
487
616
 
488
- # Add a separator
489
617
  toolbar_layout.addWidget(self._create_separator())
490
618
 
491
- # Sort controls
492
619
  sort_label = QLabel("Sort By:")
493
620
  toolbar_layout.addWidget(sort_label)
494
621
  self.sort_combo = QComboBox()
495
- self.sort_combo.addItems(["None", "Label", "Image"])
622
+ # Remove "Similarity" as it's now an implicit action
623
+ self.sort_combo.addItems(["None", "Label", "Image", "Confidence"])
624
+ self.sort_combo.insertSeparator(3) # Add separator before "Confidence"
496
625
  self.sort_combo.currentTextChanged.connect(self.on_sort_changed)
497
626
  toolbar_layout.addWidget(self.sort_combo)
627
+
628
+ toolbar_layout.addWidget(self._create_separator())
629
+
630
+ self.find_similar_button = QToolButton()
631
+ self.find_similar_button.setText("Find Similar")
632
+ self.find_similar_button.setToolTip("Find annotations visually similar to the selection.")
633
+ self.find_similar_button.setPopupMode(QToolButton.MenuButtonPopup)
634
+ self.find_similar_button.setToolButtonStyle(Qt.ToolButtonTextOnly)
635
+ self.find_similar_button.setStyleSheet(
636
+ "QToolButton::menu-indicator { subcontrol-position: right center; subcontrol-origin: padding; left: -4px; }"
637
+ )
498
638
 
499
- # Add a spacer to push the size controls to the right
639
+ run_similar_action = QAction("Find Similar", self)
640
+ run_similar_action.triggered.connect(self.find_similar_requested.emit)
641
+ self.find_similar_button.setDefaultAction(run_similar_action)
642
+
643
+ self.similarity_settings_widget = SimilaritySettingsWidget()
644
+ settings_menu = QMenu(self)
645
+ widget_action = QWidgetAction(settings_menu)
646
+ widget_action.setDefaultWidget(self.similarity_settings_widget)
647
+ settings_menu.addAction(widget_action)
648
+ self.find_similar_button.setMenu(settings_menu)
649
+ toolbar_layout.addWidget(self.find_similar_button)
650
+
500
651
  toolbar_layout.addStretch()
501
652
 
502
- # Size controls
503
653
  size_label = QLabel("Size:")
504
654
  toolbar_layout.addWidget(size_label)
505
655
  self.size_slider = QSlider(Qt.Horizontal)
@@ -514,23 +664,63 @@ class AnnotationViewer(QScrollArea):
514
664
  self.size_value_label = QLabel("96")
515
665
  self.size_value_label.setMinimumWidth(30)
516
666
  toolbar_layout.addWidget(self.size_value_label)
517
-
518
667
  main_layout.addWidget(toolbar_widget)
519
-
520
- # --- Content Area ---
668
+
521
669
  self.content_widget = QWidget()
522
670
  content_scroll = QScrollArea()
523
671
  content_scroll.setWidgetResizable(True)
524
672
  content_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
525
673
  content_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
526
674
  content_scroll.setWidget(self.content_widget)
527
-
675
+
528
676
  main_layout.addWidget(content_scroll)
529
677
  self.setWidget(main_container)
530
678
 
531
- # Set the initial state of the toolbar buttons
679
+ # Set the initial state of the sort options
680
+ self._update_sort_options_state()
532
681
  self._update_toolbar_state()
533
-
682
+
683
+ def _create_separator(self):
684
+ """Creates a vertical separator for the toolbar."""
685
+ separator = QLabel("|")
686
+ separator.setStyleSheet("color: gray; margin: 0 5px;")
687
+ return separator
688
+
689
+ def _update_sort_options_state(self):
690
+ """Enable/disable sort options based on available data."""
691
+ model = self.sort_combo.model()
692
+
693
+ # Enable/disable "Confidence" option
694
+ confidence_item_index = self.sort_combo.findText("Confidence")
695
+ if confidence_item_index != -1:
696
+ model.item(confidence_item_index).setEnabled(self.is_confidence_sort_available)
697
+
698
+ def handle_annotation_context_menu(self, widget, event):
699
+ """Handle context menu requests (e.g., right-click) on an annotation widget."""
700
+ if event.modifiers() == Qt.ControlModifier:
701
+ explorer = self.explorer_window
702
+ image_path = widget.annotation.image_path
703
+ annotation_to_select = widget.annotation
704
+
705
+ if hasattr(explorer, 'annotation_window'):
706
+ # Check if the image needs to be changed
707
+ if explorer.annotation_window.current_image_path != image_path:
708
+ if hasattr(explorer.annotation_window, 'set_image'):
709
+ explorer.annotation_window.set_image(image_path)
710
+
711
+ # Now, select the annotation in the annotation_window
712
+ if hasattr(explorer.annotation_window, 'select_annotation'):
713
+ # This method by default unselects other annotations
714
+ explorer.annotation_window.select_annotation(annotation_to_select)
715
+
716
+ # Also clear any existing selection in the explorer window itself
717
+ explorer.annotation_viewer.clear_selection()
718
+ explorer.embedding_viewer.render_selection_from_ids(set())
719
+ explorer.update_label_window_selection()
720
+ explorer.update_button_states()
721
+
722
+ event.accept()
723
+
534
724
  @pyqtSlot()
535
725
  def isolate_selection(self):
536
726
  """Hides all annotation widgets that are not currently selected."""
@@ -549,32 +739,64 @@ class AnnotationViewer(QScrollArea):
549
739
  self.content_widget.setUpdatesEnabled(True)
550
740
 
551
741
  self._update_toolbar_state()
742
+ self.explorer_window.main_window.label_window.update_annotation_count()
743
+
744
+ def display_and_isolate_ordered_results(self, ordered_ids):
745
+ """
746
+ Isolates the view to a specific set of ordered widgets, ensuring the
747
+ grid is always updated. This is the new primary method for showing
748
+ similarity results.
749
+ """
750
+ self.active_ordered_ids = ordered_ids
751
+
752
+ # Render the selection based on the new order
753
+ self.render_selection_from_ids(set(ordered_ids))
754
+
755
+ # Now, perform the isolation logic directly to bypass the guard clause
756
+ self.isolated_widgets = set(self.selected_widgets)
757
+ self.content_widget.setUpdatesEnabled(False)
758
+ try:
759
+ for widget in self.annotation_widgets_by_id.values():
760
+ # Show widget if it's in our target set, hide otherwise
761
+ if widget in self.isolated_widgets:
762
+ widget.show()
763
+ else:
764
+ widget.hide()
765
+
766
+ self.isolated_mode = True
767
+ self.recalculate_widget_positions() # Crucial grid update
768
+ finally:
769
+ self.content_widget.setUpdatesEnabled(True)
770
+
771
+ self._update_toolbar_state()
772
+ self.explorer_window.main_window.label_window.update_annotation_count()
552
773
 
553
774
  @pyqtSlot()
554
775
  def show_all_annotations(self):
555
776
  """Shows all annotation widgets, exiting the isolated mode."""
556
777
  if not self.isolated_mode:
557
778
  return
558
-
779
+
559
780
  self.isolated_mode = False
560
781
  self.isolated_widgets.clear()
561
-
782
+ self.active_ordered_ids = [] # Clear similarity sort context
783
+
562
784
  self.content_widget.setUpdatesEnabled(False)
563
785
  try:
786
+ # Show all widgets that are managed by the viewer
564
787
  for widget in self.annotation_widgets_by_id.values():
565
788
  widget.show()
789
+
566
790
  self.recalculate_widget_positions()
567
791
  finally:
568
792
  self.content_widget.setUpdatesEnabled(True)
569
-
793
+
570
794
  self._update_toolbar_state()
571
-
795
+ self.explorer_window.main_window.label_window.update_annotation_count()
796
+
572
797
  def _update_toolbar_state(self):
573
- """Updates the visibility and enabled state of the toolbar buttons
574
- based on the current selection and isolation mode.
575
- """
798
+ """Updates the toolbar buttons based on selection and isolation mode."""
576
799
  selection_exists = bool(self.selected_widgets)
577
-
578
800
  if self.isolated_mode:
579
801
  self.isolate_button.hide()
580
802
  self.show_all_button.show()
@@ -583,65 +805,70 @@ class AnnotationViewer(QScrollArea):
583
805
  self.isolate_button.show()
584
806
  self.show_all_button.hide()
585
807
  self.isolate_button.setEnabled(selection_exists)
586
-
587
- def _create_separator(self):
588
- """Create a vertical separator line for the toolbar."""
589
- separator = QLabel("|")
590
- separator.setStyleSheet("color: gray; margin: 0 5px;")
591
- return separator
592
808
 
593
809
  def on_sort_changed(self, sort_type):
594
810
  """Handle sort type change."""
811
+ self.active_ordered_ids = [] # Clear any special ordering
595
812
  self.recalculate_widget_positions()
596
813
 
814
+ def set_confidence_sort_availability(self, is_available):
815
+ """Sets the availability of the confidence sort option."""
816
+ self.is_confidence_sort_available = is_available
817
+ self._update_sort_options_state()
818
+
597
819
  def _get_sorted_widgets(self):
598
820
  """Get widgets sorted according to the current sort setting."""
821
+ # If a specific order is active (e.g., from similarity search), use it.
822
+ if self.active_ordered_ids:
823
+ widget_map = {w.data_item.annotation.id: w for w in self.annotation_widgets_by_id.values()}
824
+ ordered_widgets = [widget_map[ann_id] for ann_id in self.active_ordered_ids if ann_id in widget_map]
825
+ return ordered_widgets
826
+
827
+ # Otherwise, use the dropdown sort logic
599
828
  sort_type = self.sort_combo.currentText()
600
-
601
- if sort_type == "None":
602
- return list(self.annotation_widgets_by_id.values())
603
-
604
829
  widgets = list(self.annotation_widgets_by_id.values())
605
-
830
+
606
831
  if sort_type == "Label":
607
832
  widgets.sort(key=lambda w: w.data_item.effective_label.short_label_code)
608
833
  elif sort_type == "Image":
609
834
  widgets.sort(key=lambda w: os.path.basename(w.data_item.annotation.image_path))
835
+ elif sort_type == "Confidence":
836
+ # Sort by confidence, descending. Handles cases with no confidence gracefully.
837
+ widgets.sort(key=lambda w: w.data_item.get_effective_confidence(), reverse=True)
610
838
 
611
839
  return widgets
612
840
 
613
841
  def _group_widgets_by_sort_key(self, widgets):
614
- """Group widgets by the current sort key and return groups with headers."""
842
+ """Group widgets by the current sort key."""
615
843
  sort_type = self.sort_combo.currentText()
616
-
617
- if sort_type == "None":
844
+ if not self.active_ordered_ids and sort_type == "None":
618
845
  return [("", widgets)]
619
846
 
847
+ if self.active_ordered_ids: # Don't show group headers for similarity results
848
+ return [("", widgets)]
849
+
620
850
  groups = []
621
851
  current_group = []
622
852
  current_key = None
623
-
624
853
  for widget in widgets:
625
854
  if sort_type == "Label":
626
855
  key = widget.data_item.effective_label.short_label_code
627
856
  elif sort_type == "Image":
628
857
  key = os.path.basename(widget.data_item.annotation.image_path)
629
858
  else:
630
- key = ""
859
+ key = "" # No headers for Confidence or None
631
860
 
632
- if current_key != key:
861
+ if key and current_key != key:
633
862
  if current_group:
634
863
  groups.append((current_key, current_group))
635
864
  current_group = [widget]
636
865
  current_key = key
637
866
  else:
638
867
  current_group.append(widget)
639
-
640
868
  if current_group:
641
869
  groups.append((current_key, current_group))
642
-
643
870
  return groups
644
-
871
+
645
872
  def _clear_separator_labels(self):
646
873
  """Remove any existing group header labels."""
647
874
  if hasattr(self, '_group_headers'):
@@ -654,25 +881,22 @@ class AnnotationViewer(QScrollArea):
654
881
  """Create a group header label."""
655
882
  if not hasattr(self, '_group_headers'):
656
883
  self._group_headers = []
657
-
658
- header = QLabel(text)
659
- header.setParent(self.content_widget)
660
- header.setStyleSheet("""
661
- QLabel {
662
- font-weight: bold;
663
- font-size: 12px;
664
- color: #555;
665
- background-color: #f0f0f0;
666
- border: 1px solid #ccc;
667
- border-radius: 3px;
668
- padding: 5px 8px;
669
- margin: 2px 0px;
670
- }
671
- """)
672
- header.setFixedHeight(30) # Increased from 25 to 30
884
+ header = QLabel(text, self.content_widget)
885
+ header.setStyleSheet(
886
+ "QLabel {"
887
+ " font-weight: bold;"
888
+ " font-size: 12px;"
889
+ " color: #555;"
890
+ " background-color: #f0f0f0;"
891
+ " border: 1px solid #ccc;"
892
+ " border-radius: 3px;"
893
+ " padding: 5px 8px;"
894
+ " margin: 2px 0px;"
895
+ " }"
896
+ )
897
+ header.setFixedHeight(30)
673
898
  header.setMinimumWidth(self.viewport().width() - 20)
674
899
  header.show()
675
-
676
900
  self._group_headers.append(header)
677
901
  return header
678
902
 
@@ -680,298 +904,251 @@ class AnnotationViewer(QScrollArea):
680
904
  """Handle slider value change to resize annotation widgets."""
681
905
  if value % 2 != 0:
682
906
  value -= 1
907
+
683
908
  self.current_widget_size = value
684
909
  self.size_value_label.setText(str(value))
685
-
686
- # Disable updates for performance while resizing many items
687
910
  self.content_widget.setUpdatesEnabled(False)
911
+
688
912
  for widget in self.annotation_widgets_by_id.values():
689
- widget.update_height(value) # Call the new, more descriptive method
690
- self.content_widget.setUpdatesEnabled(True)
913
+ widget.update_height(value)
691
914
 
692
- # After resizing, reflow the layout
915
+ self.content_widget.setUpdatesEnabled(True)
693
916
  self.recalculate_widget_positions()
694
917
 
695
- def recalculate_grid_layout(self):
696
- """Recalculate the grid layout based on current widget width."""
697
- if not self.annotation_widgets_by_id:
698
- return
699
-
700
- available_width = self.viewport().width() - 20
701
- widget_width = self.current_widget_size + self.grid_layout.spacing()
702
- cols = max(1, available_width // widget_width)
703
-
704
- for i, widget in enumerate(self.annotation_widgets_by_id.values()):
705
- self.grid_layout.addWidget(widget, i // cols, i % cols)
706
-
707
918
  def recalculate_widget_positions(self):
708
919
  """Manually positions widgets in a flow layout with sorting and group headers."""
709
920
  if not self.annotation_widgets_by_id:
710
921
  self.content_widget.setMinimumSize(1, 1)
711
922
  return
712
923
 
713
- # Clear any existing separator labels
714
924
  self._clear_separator_labels()
715
-
716
- # Get sorted widgets
717
- all_widgets = self._get_sorted_widgets()
718
-
719
- # Filter to only visible widgets
720
- visible_widgets = [w for w in all_widgets if not w.isHidden()]
721
-
925
+ visible_widgets = [w for w in self._get_sorted_widgets() if not w.isHidden()]
722
926
  if not visible_widgets:
723
927
  self.content_widget.setMinimumSize(1, 1)
724
928
  return
725
929
 
726
- # Group widgets by sort key
930
+ # Create groups based on the current sort key
727
931
  groups = self._group_widgets_by_sort_key(visible_widgets)
728
-
729
- # Calculate spacing
730
932
  spacing = max(5, int(self.current_widget_size * 0.08))
731
933
  available_width = self.viewport().width()
732
-
733
934
  x, y = spacing, spacing
734
935
  max_height_in_row = 0
735
936
 
937
+ # Calculate the maximum height of the widgets in each row
736
938
  for group_name, group_widgets in groups:
737
- # Add group header if sorting is enabled and group has a name
738
939
  if group_name and self.sort_combo.currentText() != "None":
739
- # Ensure we're at the start of a new line for headers
740
940
  if x > spacing:
741
941
  x = spacing
742
942
  y += max_height_in_row + spacing
743
943
  max_height_in_row = 0
744
-
745
- # Create and position header label
746
944
  header_label = self._create_group_header(group_name)
747
945
  header_label.move(x, y)
748
-
749
- # Move to next line after header
750
946
  y += header_label.height() + spacing
751
947
  x = spacing
752
948
  max_height_in_row = 0
753
949
 
754
- # Position widgets in this group
755
950
  for widget in group_widgets:
756
951
  widget_size = widget.size()
757
-
758
- # Check if widget fits on current line
759
952
  if x > spacing and x + widget_size.width() > available_width:
760
953
  x = spacing
761
954
  y += max_height_in_row + spacing
762
955
  max_height_in_row = 0
763
-
764
956
  widget.move(x, y)
765
957
  x += widget_size.width() + spacing
766
958
  max_height_in_row = max(max_height_in_row, widget_size.height())
767
959
 
768
- # Update content widget size
769
960
  total_height = y + max_height_in_row + spacing
770
961
  self.content_widget.setMinimumSize(available_width, total_height)
771
-
962
+
772
963
  def update_annotations(self, data_items):
773
- """Update displayed annotations, creating new widgets. This will also
774
- reset any active isolation view.
775
- """
776
- # Reset isolation state before updating to avoid confusion
964
+ """Update displayed annotations, creating new widgets for them."""
777
965
  if self.isolated_mode:
778
966
  self.show_all_annotations()
779
-
780
- # Clear any existing widgets and ensure they are deleted
967
+
781
968
  for widget in self.annotation_widgets_by_id.values():
782
969
  widget.setParent(None)
783
970
  widget.deleteLater()
784
-
971
+
785
972
  self.annotation_widgets_by_id.clear()
786
973
  self.selected_widgets.clear()
787
974
  self.last_selected_index = -1
788
975
 
789
- # Create new widgets, parenting them to the content_widget
790
976
  for data_item in data_items:
791
977
  annotation_widget = AnnotationImageWidget(
792
- data_item,
793
- self.current_widget_size,
794
- annotation_viewer=self,
795
- parent=self.content_widget
796
- )
797
- annotation_widget.show()
978
+ data_item, self.current_widget_size, self, self.content_widget)
979
+
980
+ annotation_widget.show()
798
981
  self.annotation_widgets_by_id[data_item.annotation.id] = annotation_widget
799
-
982
+
800
983
  self.recalculate_widget_positions()
801
- # Ensure toolbar is in the correct state after a refresh
802
984
  self._update_toolbar_state()
803
985
 
804
986
  def resizeEvent(self, event):
805
987
  """On window resize, reflow the annotation widgets."""
806
988
  super().resizeEvent(event)
807
- # Use a QTimer to avoid rapid, expensive reflows while dragging the resize handle
808
989
  if not hasattr(self, '_resize_timer'):
809
990
  self._resize_timer = QTimer(self)
810
991
  self._resize_timer.setSingleShot(True)
811
992
  self._resize_timer.timeout.connect(self.recalculate_widget_positions)
812
- # Restart the timer on each resize event
813
- self._resize_timer.start(100) # 100ms delay
993
+ self._resize_timer.start(100)
994
+
995
+ def keyPressEvent(self, event):
996
+ """Handles key presses for deleting selected annotations."""
997
+ if event.key() in (Qt.Key_Delete, Qt.Key_Backspace) and event.modifiers() == Qt.ControlModifier:
998
+ if not self.selected_widgets:
999
+ super().keyPressEvent(event)
1000
+ return
1001
+
1002
+ # Extract the central data items from the selected widgets
1003
+ data_items_to_delete = [widget.data_item for widget in self.selected_widgets]
1004
+
1005
+ # Delegate the actual deletion to the main ExplorerWindow
1006
+ if data_items_to_delete:
1007
+ self.explorer_window.delete_data_items(data_items_to_delete)
1008
+
1009
+ event.accept()
1010
+ else:
1011
+ super().keyPressEvent(event)
814
1012
 
815
1013
  def mousePressEvent(self, event):
816
1014
  """Handle mouse press for starting rubber band selection OR clearing selection."""
817
-
818
- # Handle plain left-clicks
819
1015
  if event.button() == Qt.LeftButton:
820
-
821
- # This is the new logic for clearing selection on a background click.
822
- if not event.modifiers(): # Check for NO modifiers (e.g., Ctrl, Shift)
823
-
1016
+ if not event.modifiers():
1017
+ # If left click with no modifiers, check if click is outside widgets
824
1018
  is_on_widget = False
825
1019
  child_at_pos = self.childAt(event.pos())
826
1020
 
827
- # Determine if the click was on an actual annotation widget or empty space
828
1021
  if child_at_pos:
829
1022
  widget = child_at_pos
1023
+ # Traverse up the parent chain to see if click is on an annotation widget
830
1024
  while widget and widget != self:
831
1025
  if hasattr(widget, 'annotation_viewer') and widget.annotation_viewer == self:
832
1026
  is_on_widget = True
833
1027
  break
834
1028
  widget = widget.parent()
835
-
836
- # If click was on empty space AND something is currently selected...
1029
+
1030
+ # If click is outside widgets and there is a selection, clear it
837
1031
  if not is_on_widget and self.selected_widgets:
838
- # Get IDs of widgets that are about to be deselected to emit a signal
839
1032
  changed_ids = [w.data_item.annotation.id for w in self.selected_widgets]
840
1033
  self.clear_selection()
841
1034
  self.selection_changed.emit(changed_ids)
842
- # The event is handled, but we don't call super() to prevent
843
- # the scroll area from doing anything else, like starting a drag.
844
1035
  return
845
1036
 
846
- # Handle Ctrl+Click for rubber band
847
1037
  elif event.modifiers() == Qt.ControlModifier:
848
- # Store the set of currently selected items.
1038
+ # Start rubber band selection with Ctrl+Left click
849
1039
  self.selection_at_press = set(self.selected_widgets)
850
1040
  self.rubber_band_origin = event.pos()
851
- # We determine mouse_pressed_on_widget here but use it in mouseMove
852
1041
  self.mouse_pressed_on_widget = False
853
1042
  child_widget = self.childAt(event.pos())
854
1043
  if child_widget:
855
1044
  widget = child_widget
1045
+ # Check if click is on a widget to avoid starting rubber band
856
1046
  while widget and widget != self:
857
1047
  if hasattr(widget, 'annotation_viewer') and widget.annotation_viewer == self:
858
1048
  self.mouse_pressed_on_widget = True
859
1049
  break
860
1050
  widget = widget.parent()
861
1051
  return
862
-
863
- # Handle right-clicks
1052
+
864
1053
  elif event.button() == Qt.RightButton:
1054
+ # Ignore right clicks
865
1055
  event.ignore()
866
1056
  return
867
-
868
- # For all other cases (e.g., a click on a widget that should be handled
869
- # by the widget itself), pass the event to the default handler.
1057
+
1058
+ # Default handler for other cases
870
1059
  super().mousePressEvent(event)
871
-
1060
+
872
1061
  def mouseDoubleClickEvent(self, event):
873
1062
  """Handle double-click to clear selection and exit isolation mode."""
874
1063
  if event.button() == Qt.LeftButton:
875
1064
  changed_ids = []
876
-
877
- # If items are selected, clear the selection and record their IDs
878
1065
  if self.selected_widgets:
879
1066
  changed_ids = [w.data_item.annotation.id for w in self.selected_widgets]
880
1067
  self.clear_selection()
881
1068
  self.selection_changed.emit(changed_ids)
882
-
883
- # If in isolation mode, revert to showing all annotations
884
1069
  if self.isolated_mode:
885
1070
  self.show_all_annotations()
886
-
887
- # Signal the main window to reset its view (e.g., switch tabs)
888
1071
  self.reset_view_requested.emit()
889
1072
  event.accept()
890
1073
  else:
891
1074
  super().mouseDoubleClickEvent(event)
892
-
1075
+
893
1076
  def mouseMoveEvent(self, event):
894
1077
  """Handle mouse move for DYNAMIC rubber band selection."""
895
- if self.rubber_band_origin is None or \
896
- event.buttons() != Qt.LeftButton or \
897
- event.modifiers() != Qt.ControlModifier:
1078
+ # Only proceed if Ctrl+Left mouse drag is active and not on a widget
1079
+ if (
1080
+ self.rubber_band_origin is None or
1081
+ event.buttons() != Qt.LeftButton or
1082
+ event.modifiers() != Qt.ControlModifier
1083
+ ):
898
1084
  super().mouseMoveEvent(event)
899
1085
  return
900
1086
 
901
- # If the mouse was pressed on a widget, let that widget handle the event.
902
1087
  if self.mouse_pressed_on_widget:
1088
+ # If drag started on a widget, do not start rubber band
903
1089
  super().mouseMoveEvent(event)
904
1090
  return
905
1091
 
906
- # Only start the rubber band after dragging a minimum distance
1092
+ # Only start selection if drag distance exceeds threshold
907
1093
  distance = (event.pos() - self.rubber_band_origin).manhattanLength()
908
1094
  if distance < self.drag_threshold:
909
1095
  return
910
1096
 
911
- # Create and show the rubber band if it doesn't exist
1097
+ # Create and show the rubber band if not already present
912
1098
  if not self.rubber_band:
913
1099
  self.rubber_band = QRubberBand(QRubberBand.Rectangle, self.viewport())
914
-
1100
+
915
1101
  rect = QRect(self.rubber_band_origin, event.pos()).normalized()
916
1102
  self.rubber_band.setGeometry(rect)
917
1103
  self.rubber_band.show()
918
-
919
- # Perform dynamic selection on every move
920
1104
  selection_rect = self.rubber_band.geometry()
921
1105
  content_widget = self.content_widget
922
1106
  changed_ids = []
923
1107
 
1108
+ # Iterate over all annotation widgets to update selection state
924
1109
  for widget in self.annotation_widgets_by_id.values():
925
1110
  widget_rect_in_content = widget.geometry()
926
- # Map widget's geometry from the content area to the visible viewport
1111
+ # Map widget's rect to viewport coordinates
927
1112
  widget_rect_in_viewport = QRect(
928
1113
  content_widget.mapTo(self.viewport(), widget_rect_in_content.topLeft()),
929
1114
  widget_rect_in_content.size()
930
1115
  )
931
-
932
1116
  is_in_band = selection_rect.intersects(widget_rect_in_viewport)
933
-
934
- # A widget should be selected if it was selected at the start OR is in the band now.
935
1117
  should_be_selected = (widget in self.selection_at_press) or is_in_band
936
1118
 
1119
+ # Select or deselect widgets as needed
937
1120
  if should_be_selected and not widget.is_selected():
938
1121
  if self.select_widget(widget):
939
1122
  changed_ids.append(widget.data_item.annotation.id)
1123
+
940
1124
  elif not should_be_selected and widget.is_selected():
941
1125
  if self.deselect_widget(widget):
942
1126
  changed_ids.append(widget.data_item.annotation.id)
943
-
1127
+
1128
+ # Emit signal if any selection state changed
944
1129
  if changed_ids:
945
1130
  self.selection_changed.emit(changed_ids)
946
-
1131
+
947
1132
  def mouseReleaseEvent(self, event):
948
1133
  """Handle mouse release to complete rubber band selection."""
949
- # Check if a rubber band drag was in progress
950
1134
  if self.rubber_band_origin is not None and event.button() == Qt.LeftButton:
951
1135
  if self.rubber_band and self.rubber_band.isVisible():
952
1136
  self.rubber_band.hide()
953
1137
  self.rubber_band.deleteLater()
954
1138
  self.rubber_band = None
955
1139
 
956
- # **NEEDED CHANGE**: Clean up the stored selection state.
957
1140
  self.selection_at_press = set()
958
1141
  self.rubber_band_origin = None
959
1142
  self.mouse_pressed_on_widget = False
960
1143
  event.accept()
961
1144
  return
962
-
1145
+
963
1146
  super().mouseReleaseEvent(event)
964
1147
 
965
1148
  def handle_annotation_selection(self, widget, event):
966
- """Handle selection of annotation widgets with different modes."""
967
- # Get the list of widgets to work with based on isolation mode
968
- if self.isolated_mode:
969
- # Only work with visible widgets when in isolation mode
970
- widget_list = [w for w in self.annotation_widgets_by_id.values() if not w.isHidden()]
971
- else:
972
- # Use all widgets when not in isolation mode
973
- widget_list = list(self.annotation_widgets_by_id.values())
974
-
1149
+ """Handle selection of annotation widgets with different modes (single, ctrl, shift)."""
1150
+ widget_list = [w for w in self._get_sorted_widgets() if not w.isHidden()]
1151
+
975
1152
  try:
976
1153
  widget_index = widget_list.index(widget)
977
1154
  except ValueError:
@@ -980,44 +1157,42 @@ class AnnotationViewer(QScrollArea):
980
1157
  modifiers = event.modifiers()
981
1158
  changed_ids = []
982
1159
 
983
- # --- The selection logic now identifies which items to change ---
984
- # --- but the core state change happens in select/deselect ---
985
-
1160
+ # Shift or Shift+Ctrl: range selection
986
1161
  if modifiers == Qt.ShiftModifier or modifiers == (Qt.ShiftModifier | Qt.ControlModifier):
987
- # Range selection
988
1162
  if self.last_selected_index != -1:
989
- # Find the last selected widget in the current widget list
1163
+ # Find the last selected widget in the current list
990
1164
  last_selected_widget = None
991
1165
  for w in self.selected_widgets:
992
1166
  if w in widget_list:
993
1167
  try:
994
1168
  last_index_in_current_list = widget_list.index(w)
995
- if last_selected_widget is None or \
996
- last_index_in_current_list > widget_list.index(last_selected_widget):
1169
+ if (
1170
+ last_selected_widget is None
1171
+ or last_index_in_current_list > widget_list.index(last_selected_widget)
1172
+ ):
997
1173
  last_selected_widget = w
998
1174
  except ValueError:
999
1175
  continue
1000
-
1176
+
1001
1177
  if last_selected_widget:
1002
1178
  last_selected_index_in_current_list = widget_list.index(last_selected_widget)
1003
1179
  start = min(last_selected_index_in_current_list, widget_index)
1004
1180
  end = max(last_selected_index_in_current_list, widget_index)
1005
1181
  else:
1006
- # Fallback if no previously selected widget is found in current list
1007
- start = widget_index
1008
- end = widget_index
1009
-
1182
+ start, end = widget_index, widget_index
1183
+
1184
+ # Select all widgets in the range
1010
1185
  for i in range(start, end + 1):
1011
- # select_widget will return True if a change occurred
1012
1186
  if self.select_widget(widget_list[i]):
1013
1187
  changed_ids.append(widget_list[i].data_item.annotation.id)
1014
1188
  else:
1189
+ # No previous selection, just select the clicked widget
1015
1190
  if self.select_widget(widget):
1016
1191
  changed_ids.append(widget.data_item.annotation.id)
1017
1192
  self.last_selected_index = widget_index
1018
-
1193
+
1194
+ # Ctrl: toggle selection of the clicked widget
1019
1195
  elif modifiers == Qt.ControlModifier:
1020
- # Toggle selection
1021
1196
  if widget.is_selected():
1022
1197
  if self.deselect_widget(widget):
1023
1198
  changed_ids.append(widget.data_item.annotation.id)
@@ -1025,36 +1200,37 @@ class AnnotationViewer(QScrollArea):
1025
1200
  if self.select_widget(widget):
1026
1201
  changed_ids.append(widget.data_item.annotation.id)
1027
1202
  self.last_selected_index = widget_index
1028
-
1203
+
1204
+ # No modifier: single selection
1029
1205
  else:
1030
- # Normal click: clear all others and select this one
1031
1206
  newly_selected_id = widget.data_item.annotation.id
1032
- # Deselect all widgets that are not the clicked one
1207
+
1208
+ # Deselect all others
1033
1209
  for w in list(self.selected_widgets):
1034
1210
  if w.data_item.annotation.id != newly_selected_id:
1035
1211
  if self.deselect_widget(w):
1036
1212
  changed_ids.append(w.data_item.annotation.id)
1213
+
1037
1214
  # Select the clicked widget
1038
1215
  if self.select_widget(widget):
1039
1216
  changed_ids.append(newly_selected_id)
1040
1217
  self.last_selected_index = widget_index
1041
-
1042
- # Update isolation if in isolated mode
1218
+
1219
+ # If in isolated mode, update which widgets are visible
1043
1220
  if self.isolated_mode:
1044
1221
  self._update_isolation()
1045
-
1046
- # If any selections were changed, emit the signal
1222
+
1223
+ # Emit signal if any selection state changed
1047
1224
  if changed_ids:
1048
1225
  self.selection_changed.emit(changed_ids)
1049
-
1226
+
1050
1227
  def _update_isolation(self):
1051
1228
  """Update the isolated view to show only currently selected widgets."""
1052
1229
  if not self.isolated_mode:
1053
1230
  return
1054
-
1231
+ # If in isolated mode, only show selected widgets
1055
1232
  if self.selected_widgets:
1056
- # ADD TO isolation instead of replacing it
1057
- self.isolated_widgets.update(self.selected_widgets) # Use update() to add, not replace
1233
+ self.isolated_widgets.update(self.selected_widgets)
1058
1234
  self.setUpdatesEnabled(False)
1059
1235
  try:
1060
1236
  for widget in self.annotation_widgets_by_id.values():
@@ -1063,169 +1239,119 @@ class AnnotationViewer(QScrollArea):
1063
1239
  else:
1064
1240
  widget.show()
1065
1241
  self.recalculate_widget_positions()
1242
+
1066
1243
  finally:
1067
1244
  self.setUpdatesEnabled(True)
1068
- else:
1069
- # If no widgets are selected, keep the current isolation (don't exit)
1070
- # This prevents accidentally exiting isolation mode when clearing selection
1071
- pass
1072
1245
 
1073
1246
  def select_widget(self, widget):
1074
- """Select a widget, update the data_item, and return True if state changed."""
1075
- if not widget.is_selected():
1076
- widget.set_selected(True)
1247
+ """Selects a widget, updates its data_item, and returns True if state changed."""
1248
+ if not widget.is_selected(): # is_selected() checks the data_item
1249
+ # 1. Controller modifies the state on the data item
1077
1250
  widget.data_item.set_selected(True)
1251
+ # 2. Controller tells the view to update its appearance
1252
+ widget.update_selection_visuals()
1078
1253
  self.selected_widgets.append(widget)
1079
- self.update_label_window_selection()
1080
- self._update_toolbar_state() # Update button states
1254
+ self._update_toolbar_state()
1081
1255
  return True
1082
1256
  return False
1083
1257
 
1084
1258
  def deselect_widget(self, widget):
1085
- """Deselect a widget, update the data_item, and return True if state changed."""
1259
+ """Deselects a widget, updates its data_item, and returns True if state changed."""
1086
1260
  if widget.is_selected():
1087
- widget.set_selected(False)
1261
+ # 1. Controller modifies the state on the data item
1088
1262
  widget.data_item.set_selected(False)
1263
+ # 2. Controller tells the view to update its appearance
1264
+ widget.update_selection_visuals()
1089
1265
  if widget in self.selected_widgets:
1090
1266
  self.selected_widgets.remove(widget)
1091
- self.update_label_window_selection()
1092
- self._update_toolbar_state() # Update button states
1267
+ self._update_toolbar_state()
1093
1268
  return True
1094
1269
  return False
1095
1270
 
1096
1271
  def clear_selection(self):
1097
1272
  """Clear all selected widgets and update toolbar state."""
1098
1273
  for widget in list(self.selected_widgets):
1099
- widget.set_selected(False)
1100
- self.selected_widgets.clear()
1101
- self.update_label_window_selection()
1102
- self._update_toolbar_state() # Update button states
1274
+ # This will internally call deselect_widget, which is fine
1275
+ self.deselect_widget(widget)
1103
1276
 
1104
- def update_label_window_selection(self):
1105
- """Update the label window selection based on currently selected annotations."""
1106
- explorer_window = self.parent()
1107
- while explorer_window and not hasattr(explorer_window, 'main_window'):
1108
- explorer_window = explorer_window.parent()
1109
-
1110
- if not explorer_window or not hasattr(explorer_window, 'main_window'):
1111
- return
1112
-
1113
- main_window = explorer_window.main_window
1114
- label_window = main_window.label_window
1115
- annotation_window = main_window.annotation_window
1116
-
1117
- if not self.selected_widgets:
1118
- label_window.deselect_active_label()
1119
- label_window.update_annotation_count()
1120
- return
1121
-
1122
- selected_data_items = [widget.data_item for widget in self.selected_widgets]
1123
-
1124
- first_effective_label = selected_data_items[0].effective_label
1125
- all_same_current_label = True
1126
- for item in selected_data_items:
1127
- if item.effective_label.id != first_effective_label.id:
1128
- all_same_current_label = False
1129
- break
1130
-
1131
- if all_same_current_label:
1132
- label_window.set_active_label(first_effective_label)
1133
- if not selected_data_items[0].has_preview_changes():
1134
- annotation_window.labelSelected.emit(first_effective_label.id)
1135
- else:
1136
- label_window.deselect_active_label()
1137
-
1138
- label_window.update_annotation_count()
1277
+ self.selected_widgets.clear()
1278
+ self._update_toolbar_state()
1139
1279
 
1140
1280
  def get_selected_annotations(self):
1141
1281
  """Get the annotations corresponding to selected widgets."""
1142
1282
  return [widget.annotation for widget in self.selected_widgets]
1143
-
1283
+
1144
1284
  def render_selection_from_ids(self, selected_ids):
1145
1285
  """Update the visual selection of widgets based on a set of IDs from the controller."""
1146
- # Block signals temporarily to prevent cascade updates
1147
1286
  self.setUpdatesEnabled(False)
1148
-
1149
1287
  try:
1150
1288
  for ann_id, widget in self.annotation_widgets_by_id.items():
1151
1289
  is_selected = ann_id in selected_ids
1152
- widget.set_selected(is_selected)
1153
-
1154
- # Resync internal list of selected widgets
1290
+ # 1. Update the state on the central data item
1291
+ widget.data_item.set_selected(is_selected)
1292
+ # 2. Tell the widget to update its visuals based on the new state
1293
+ widget.update_selection_visuals()
1294
+
1295
+ # Resync internal list of selected widgets from the source of truth
1155
1296
  self.selected_widgets = [w for w in self.annotation_widgets_by_id.values() if w.is_selected()]
1156
-
1157
- # If we're in isolated mode, ADD to the isolation instead of replacing it
1297
+
1158
1298
  if self.isolated_mode and self.selected_widgets:
1159
- self.isolated_widgets.update(self.selected_widgets) # Add to existing isolation
1160
- # Hide all widgets except those in the isolated set
1299
+ self.isolated_widgets.update(self.selected_widgets)
1161
1300
  for widget in self.annotation_widgets_by_id.values():
1162
- if widget not in self.isolated_widgets:
1163
- widget.hide()
1164
- else:
1165
- widget.show()
1301
+ widget.setHidden(widget not in self.isolated_widgets)
1166
1302
  self.recalculate_widget_positions()
1167
-
1168
1303
  finally:
1169
1304
  self.setUpdatesEnabled(True)
1170
-
1171
- # Update label window once at the end
1172
- self.update_label_window_selection()
1173
- # Update toolbar state to enable/disable Isolate button
1174
1305
  self._update_toolbar_state()
1175
-
1306
+
1176
1307
  def apply_preview_label_to_selected(self, preview_label):
1177
1308
  """Apply a preview label and emit a signal for the embedding view to update."""
1178
1309
  if not self.selected_widgets or not preview_label:
1179
1310
  return
1180
-
1181
1311
  changed_ids = []
1182
1312
  for widget in self.selected_widgets:
1183
1313
  widget.data_item.set_preview_label(preview_label)
1184
- widget.update() # Force repaint with new color
1314
+ widget.update() # Force repaint with new color
1185
1315
  changed_ids.append(widget.data_item.annotation.id)
1186
-
1187
- # Recalculate positions to update sorting based on new effective labels
1316
+
1188
1317
  if self.sort_combo.currentText() == "Label":
1189
1318
  self.recalculate_widget_positions()
1190
-
1191
1319
  if changed_ids:
1192
1320
  self.preview_changed.emit(changed_ids)
1193
1321
 
1194
1322
  def clear_preview_states(self):
1195
- """Clear all preview states and revert to original labels."""
1196
- # We just need to iterate through all widgets and tell their data_items to clear
1197
- something_cleared = False
1323
+ """
1324
+ Clears all preview states, including label changes,
1325
+ reverting them to their original state.
1326
+ """
1327
+ something_changed = False
1198
1328
  for widget in self.annotation_widgets_by_id.values():
1329
+ # Check for and clear preview labels
1199
1330
  if widget.data_item.has_preview_changes():
1200
1331
  widget.data_item.clear_preview_label()
1201
- widget.update() # Repaint to show original color
1202
- something_cleared = True
1203
-
1204
- if something_cleared:
1205
- # Recalculate positions to update sorting based on reverted labels
1206
- if self.sort_combo.currentText() == "Label":
1332
+ widget.update() # Repaint to show original color
1333
+ something_changed = True
1334
+
1335
+ if something_changed:
1336
+ # Recalculate positions to update sorting and re-flow the layout
1337
+ if self.sort_combo.currentText() in ("Label", "Image"):
1207
1338
  self.recalculate_widget_positions()
1208
- self.update_label_window_selection()
1209
1339
 
1210
1340
  def has_preview_changes(self):
1211
- """Check if there are any pending preview changes."""
1341
+ """Return True if there are preview changes."""
1212
1342
  return any(w.data_item.has_preview_changes() for w in self.annotation_widgets_by_id.values())
1213
1343
 
1214
1344
  def get_preview_changes_summary(self):
1215
- """Get a summary of preview changes for user feedback."""
1345
+ """Get a summary of preview changes."""
1216
1346
  change_count = sum(1 for w in self.annotation_widgets_by_id.values() if w.data_item.has_preview_changes())
1217
- if not change_count:
1218
- return "No preview changes"
1219
- return f"{change_count} annotation(s) with preview changes"
1347
+ return f"{change_count} annotation(s) with preview changes" if change_count else "No preview changes"
1220
1348
 
1221
1349
  def apply_preview_changes_permanently(self):
1222
- """Apply all preview changes permanently to the annotation data."""
1350
+ """Apply preview changes permanently."""
1223
1351
  applied_annotations = []
1224
1352
  for widget in self.annotation_widgets_by_id.values():
1225
- # Tell the data_item to apply its changes to the underlying annotation
1226
1353
  if widget.data_item.apply_preview_permanently():
1227
1354
  applied_annotations.append(widget.annotation)
1228
-
1229
1355
  return applied_annotations
1230
1356
 
1231
1357
 
@@ -1236,48 +1362,51 @@ class AnnotationViewer(QScrollArea):
1236
1362
 
1237
1363
  class ExplorerWindow(QMainWindow):
1238
1364
  def __init__(self, main_window, parent=None):
1365
+ """Initialize the ExplorerWindow."""
1239
1366
  super(ExplorerWindow, self).__init__(parent)
1240
1367
  self.main_window = main_window
1241
1368
  self.image_window = main_window.image_window
1242
1369
  self.label_window = main_window.label_window
1243
1370
  self.annotation_window = main_window.annotation_window
1244
1371
 
1245
- self.device = main_window.device # Use the same device as the main window
1246
- self.model_path = ""
1372
+ self.device = main_window.device
1247
1373
  self.loaded_model = None
1248
1374
 
1249
- # Store current filtered data items for embedding
1250
- self.current_data_items = []
1375
+ self.feature_store = FeatureStore()
1251
1376
 
1252
- # Cache for extracted features and the model that generated them ---
1377
+ # Add a property to store the parameters with defaults
1378
+ self.mislabel_params = {'k': 20, 'threshold': 0.6}
1379
+ self.uncertainty_params = {'confidence': 0.6, 'margin': 0.1}
1380
+ self.similarity_params = {'k': 30}
1381
+
1382
+ self.data_item_cache = {} # Cache for AnnotationDataItem objects
1383
+
1384
+ self.current_data_items = []
1253
1385
  self.current_features = None
1254
1386
  self.current_feature_generating_model = ""
1387
+ self.current_embedding_model_info = None
1388
+ self._ui_initialized = False
1255
1389
 
1256
1390
  self.setWindowTitle("Explorer")
1257
- # Set the window icon
1258
1391
  explorer_icon_path = get_icon("magic.png")
1259
1392
  self.setWindowIcon(QIcon(explorer_icon_path))
1260
1393
 
1261
- # Create a central widget and main layout
1262
1394
  self.central_widget = QWidget()
1263
1395
  self.setCentralWidget(self.central_widget)
1264
1396
  self.main_layout = QVBoxLayout(self.central_widget)
1265
- # Create a left panel widget and layout for the re-parented LabelWindow
1266
1397
  self.left_panel = QWidget()
1267
1398
  self.left_layout = QVBoxLayout(self.left_panel)
1268
1399
 
1269
- # Create widgets in __init__ so they're always available
1270
- self.annotation_settings_widget = AnnotationSettingsWidget(self.main_window, self)
1271
- self.model_settings_widget = ModelSettingsWidget(self.main_window, self)
1272
- self.embedding_settings_widget = EmbeddingSettingsWidget(self.main_window, self)
1273
- self.annotation_viewer = AnnotationViewer(self)
1274
- self.embedding_viewer = EmbeddingViewer(self)
1400
+ self.annotation_settings_widget = None
1401
+ self.model_settings_widget = None
1402
+ self.embedding_settings_widget = None
1403
+ self.annotation_viewer = None
1404
+ self.embedding_viewer = None
1275
1405
 
1276
- # Create buttons
1277
1406
  self.clear_preview_button = QPushButton('Clear Preview', self)
1278
1407
  self.clear_preview_button.clicked.connect(self.clear_preview_changes)
1279
1408
  self.clear_preview_button.setToolTip("Clear all preview changes and revert to original labels")
1280
- self.clear_preview_button.setEnabled(False) # Initially disabled
1409
+ self.clear_preview_button.setEnabled(False)
1281
1410
 
1282
1411
  self.exit_button = QPushButton('Exit', self)
1283
1412
  self.exit_button.clicked.connect(self.close)
@@ -1286,27 +1415,30 @@ class ExplorerWindow(QMainWindow):
1286
1415
  self.apply_button = QPushButton('Apply', self)
1287
1416
  self.apply_button.clicked.connect(self.apply)
1288
1417
  self.apply_button.setToolTip("Apply changes")
1289
- self.apply_button.setEnabled(False) # Initially disabled
1418
+ self.apply_button.setEnabled(False)
1290
1419
 
1291
1420
  def showEvent(self, event):
1292
- self.setup_ui()
1421
+ """Handle show event."""
1422
+ if not self._ui_initialized:
1423
+ self.setup_ui()
1424
+ self._ui_initialized = True
1293
1425
  super(ExplorerWindow, self).showEvent(event)
1294
1426
 
1295
1427
  def closeEvent(self, event):
1296
- """
1297
- Handles the window close event.
1298
- This now calls the resource cleanup method.
1299
- """
1428
+ """Handle close event."""
1300
1429
  # Stop any running timers to prevent errors
1301
1430
  if hasattr(self, 'embedding_viewer') and self.embedding_viewer:
1302
1431
  if hasattr(self.embedding_viewer, 'animation_timer') and self.embedding_viewer.animation_timer:
1303
1432
  self.embedding_viewer.animation_timer.stop()
1304
1433
 
1305
- # Clear any unsaved preview states
1306
- if hasattr(self, 'annotation_viewer'):
1307
- self.annotation_viewer.clear_preview_states()
1434
+ # Call the main cancellation method to revert any pending changes
1435
+ self.clear_preview_changes()
1308
1436
 
1309
- # --- NEW: Call the dedicated cleanup method ---
1437
+ # Clean up the feature store by deleting its files
1438
+ if hasattr(self, 'feature_store') and self.feature_store:
1439
+ self.feature_store.delete_storage()
1440
+
1441
+ # Call the dedicated cleanup method
1310
1442
  self._cleanup_resources()
1311
1443
 
1312
1444
  # Re-enable the main window before closing
@@ -1319,541 +1451,908 @@ class ExplorerWindow(QMainWindow):
1319
1451
 
1320
1452
  # Clear the reference in the main_window to allow garbage collection
1321
1453
  self.main_window.explorer_window = None
1322
-
1454
+
1455
+ # Set the ui_initialized flag to False so it can be re-initialized next time
1456
+ self._ui_initialized = False
1457
+
1323
1458
  event.accept()
1324
1459
 
1325
1460
  def setup_ui(self):
1326
- # Clear the main layout to remove any existing widgets
1461
+ """Set up the UI for the ExplorerWindow."""
1327
1462
  while self.main_layout.count():
1328
1463
  child = self.main_layout.takeAt(0)
1329
1464
  if child.widget():
1330
- child.widget().setParent(None) # Remove from layout but don't delete
1465
+ child.widget().setParent(None)
1331
1466
 
1332
- # Top section: Conditions and Settings side by side
1333
- top_layout = QHBoxLayout()
1467
+ # Lazily initialize the settings and viewer widgets if they haven't been created yet.
1468
+ # This ensures that the widgets are only created once per ExplorerWindow instance.
1469
+
1470
+ # Annotation settings panel (filters by image, type, label)
1471
+ if self.annotation_settings_widget is None:
1472
+ self.annotation_settings_widget = AnnotationSettingsWidget(self.main_window, self)
1334
1473
 
1335
- # Add existing widgets to layout
1336
- top_layout.addWidget(self.annotation_settings_widget, 2) # Give annotation settings more space
1337
- top_layout.addWidget(self.model_settings_widget, 1) # Model settings in the middle
1338
- top_layout.addWidget(self.embedding_settings_widget, 1) # Embedding settings on the right
1474
+ # Model selection panel (choose feature extraction model)
1475
+ if self.model_settings_widget is None:
1476
+ self.model_settings_widget = ModelSettingsWidget(self.main_window, self)
1339
1477
 
1340
- # Create container widget for top layout
1478
+ # Embedding settings panel (choose dimensionality reduction method)
1479
+ if self.embedding_settings_widget is None:
1480
+ self.embedding_settings_widget = EmbeddingSettingsWidget(self.main_window, self)
1481
+
1482
+ # Annotation viewer (shows annotation image crops in a grid)
1483
+ if self.annotation_viewer is None:
1484
+ self.annotation_viewer = AnnotationViewer(self)
1485
+
1486
+ # Embedding viewer (shows 2D embedding scatter plot)
1487
+ if self.embedding_viewer is None:
1488
+ self.embedding_viewer = EmbeddingViewer(self)
1489
+
1490
+ top_layout = QHBoxLayout()
1491
+ top_layout.addWidget(self.annotation_settings_widget, 2)
1492
+ top_layout.addWidget(self.model_settings_widget, 1)
1493
+ top_layout.addWidget(self.embedding_settings_widget, 1)
1341
1494
  top_container = QWidget()
1342
1495
  top_container.setLayout(top_layout)
1343
1496
  self.main_layout.addWidget(top_container)
1344
1497
 
1345
- # Middle section: Annotation Viewer (left) and Embedding Viewer (right)
1346
1498
  middle_splitter = QSplitter(Qt.Horizontal)
1347
-
1348
- # Wrap annotation viewer in a group box
1349
1499
  annotation_group = QGroupBox("Annotation Viewer")
1350
1500
  annotation_layout = QVBoxLayout(annotation_group)
1351
1501
  annotation_layout.addWidget(self.annotation_viewer)
1352
1502
  middle_splitter.addWidget(annotation_group)
1353
1503
 
1354
- # Wrap embedding viewer in a group box
1355
1504
  embedding_group = QGroupBox("Embedding Viewer")
1356
1505
  embedding_layout = QVBoxLayout(embedding_group)
1357
1506
  embedding_layout.addWidget(self.embedding_viewer)
1358
1507
  middle_splitter.addWidget(embedding_group)
1359
-
1360
- # Set splitter proportions (annotation viewer wider)
1361
1508
  middle_splitter.setSizes([500, 500])
1362
-
1363
- # Add middle section to main layout with stretch factor
1364
1509
  self.main_layout.addWidget(middle_splitter, 1)
1365
-
1366
- # Note: LabelWindow will be re-parented here by MainWindow.open_explorer_window()
1367
- # The LabelWindow will be added to self.left_layout at index 1 by the MainWindow
1368
1510
  self.main_layout.addWidget(self.label_window)
1369
1511
 
1370
- # Bottom control buttons
1371
1512
  self.buttons_layout = QHBoxLayout()
1372
- # Add stretch to push buttons to the right
1373
1513
  self.buttons_layout.addStretch(1)
1374
-
1375
- # Add existing buttons to layout
1376
1514
  self.buttons_layout.addWidget(self.clear_preview_button)
1377
1515
  self.buttons_layout.addWidget(self.exit_button)
1378
1516
  self.buttons_layout.addWidget(self.apply_button)
1379
-
1380
1517
  self.main_layout.addLayout(self.buttons_layout)
1381
-
1382
- # Set default condition to current image and refresh filters
1518
+
1519
+ self._initialize_data_item_cache()
1520
+ self.annotation_settings_widget.set_default_to_current_image()
1521
+ self.refresh_filters()
1522
+
1383
1523
  self.annotation_settings_widget.set_default_to_current_image()
1384
1524
  self.refresh_filters()
1385
1525
 
1386
- # Connect label selection to preview updates (only connect once)
1387
1526
  try:
1388
1527
  self.label_window.labelSelected.disconnect(self.on_label_selected_for_preview)
1389
1528
  except TypeError:
1390
- pass # Signal wasn't connected yet
1529
+ pass
1391
1530
 
1531
+ # Connect signals to slots
1392
1532
  self.label_window.labelSelected.connect(self.on_label_selected_for_preview)
1393
1533
  self.annotation_viewer.selection_changed.connect(self.on_annotation_view_selection_changed)
1394
1534
  self.annotation_viewer.preview_changed.connect(self.on_preview_changed)
1395
1535
  self.annotation_viewer.reset_view_requested.connect(self.on_reset_view_requested)
1396
1536
  self.embedding_viewer.selection_changed.connect(self.on_embedding_view_selection_changed)
1397
1537
  self.embedding_viewer.reset_view_requested.connect(self.on_reset_view_requested)
1398
-
1538
+ self.embedding_viewer.find_mislabels_requested.connect(self.find_potential_mislabels)
1539
+ self.embedding_viewer.mislabel_parameters_changed.connect(self.on_mislabel_params_changed)
1540
+ self.model_settings_widget.selection_changed.connect(self.on_model_selection_changed)
1541
+ self.embedding_viewer.find_uncertain_requested.connect(self.find_uncertain_annotations)
1542
+ self.embedding_viewer.uncertainty_parameters_changed.connect(self.on_uncertainty_params_changed)
1543
+ self.annotation_viewer.find_similar_requested.connect(self.find_similar_annotations)
1544
+ self.annotation_viewer.similarity_settings_widget.parameters_changed.connect(self.on_similarity_params_changed)
1545
+
1399
1546
  @pyqtSlot(list)
1400
1547
  def on_annotation_view_selection_changed(self, changed_ann_ids):
1401
- """A selection was made in the AnnotationViewer, so update the EmbeddingViewer."""
1548
+ """Syncs selection from AnnotationViewer to EmbeddingViewer."""
1549
+ # Per request, unselect any annotation in the main AnnotationWindow
1550
+ if hasattr(self, 'annotation_window'):
1551
+ self.annotation_window.unselect_annotations()
1552
+
1402
1553
  all_selected_ids = {w.data_item.annotation.id for w in self.annotation_viewer.selected_widgets}
1403
-
1404
- # Only try to sync the selection with the EmbeddingViewer if it has points.
1405
- # This prevents the feedback loop that was clearing the selection.
1406
1554
  if self.embedding_viewer.points_by_id:
1407
1555
  self.embedding_viewer.render_selection_from_ids(all_selected_ids)
1408
-
1409
- self.update_label_window_selection() # Keep label window in sync
1556
+
1557
+ # Call the new centralized method
1558
+ self.update_label_window_selection()
1410
1559
 
1411
1560
  @pyqtSlot(list)
1412
1561
  def on_embedding_view_selection_changed(self, all_selected_ann_ids):
1413
- """A selection was made in the EmbeddingViewer, so update the AnnotationViewer."""
1414
- # Check if this is a new selection being made when nothing was previously selected
1562
+ """Syncs selection from EmbeddingViewer to AnnotationViewer."""
1563
+ # Per request, unselect any annotation in the main AnnotationWindow
1564
+ if hasattr(self, 'annotation_window'):
1565
+ self.annotation_window.unselect_annotations()
1566
+
1567
+ # Check the state BEFORE the selection is changed
1415
1568
  was_empty_selection = len(self.annotation_viewer.selected_widgets) == 0
1416
- is_new_selection = len(all_selected_ann_ids) > 0
1417
-
1418
- # Update the annotation viewer with the new selection
1569
+
1570
+ # Now, update the selection in the annotation viewer
1419
1571
  self.annotation_viewer.render_selection_from_ids(set(all_selected_ann_ids))
1420
-
1421
- # Auto-switch to isolation mode if conditions are met
1422
- if (was_empty_selection and
1423
- is_new_selection and
1424
- not self.annotation_viewer.isolated_mode):
1425
- print("Auto-switching to isolation mode due to new selection in embedding viewer")
1572
+
1573
+ # The rest of the logic now works correctly
1574
+ is_new_selection = len(all_selected_ann_ids) > 0
1575
+ if (
1576
+ was_empty_selection and
1577
+ is_new_selection and
1578
+ not self.annotation_viewer.isolated_mode
1579
+ ):
1426
1580
  self.annotation_viewer.isolate_selection()
1427
-
1428
- self.update_label_window_selection() # Keep label window in sync
1581
+
1582
+ self.update_label_window_selection()
1429
1583
 
1430
1584
  @pyqtSlot(list)
1431
1585
  def on_preview_changed(self, changed_ann_ids):
1432
- """A preview color was changed in the AnnotationViewer, so update the EmbeddingViewer points."""
1586
+ """Updates embedding point colors and tooltips when a preview label is applied."""
1433
1587
  for ann_id in changed_ann_ids:
1588
+ # Update embedding point color
1434
1589
  point = self.embedding_viewer.points_by_id.get(ann_id)
1435
1590
  if point:
1436
- point.update() # Force the point to repaint itself
1437
-
1591
+ point.update()
1592
+ point.update_tooltip() # Refresh tooltip to show new effective label
1593
+
1594
+ # Update annotation widget tooltip
1595
+ widget = self.annotation_viewer.annotation_widgets_by_id.get(ann_id)
1596
+ if widget:
1597
+ widget.update_tooltip()
1598
+
1438
1599
  @pyqtSlot()
1439
1600
  def on_reset_view_requested(self):
1440
1601
  """Handle reset view requests from double-click in either viewer."""
1441
1602
  # Clear all selections in both viewers
1442
1603
  self.annotation_viewer.clear_selection()
1443
1604
  self.embedding_viewer.render_selection_from_ids(set())
1444
-
1445
- # Exit isolation mode if currently active
1605
+
1606
+ # Exit isolation mode if currently active in AnnotationViewer
1446
1607
  if self.annotation_viewer.isolated_mode:
1447
1608
  self.annotation_viewer.show_all_annotations()
1448
-
1449
- # Update button states
1609
+
1610
+ if self.embedding_viewer.isolated_mode:
1611
+ self.embedding_viewer.show_all_points()
1612
+
1613
+ # Clear similarity sort context
1614
+ self.annotation_viewer.active_ordered_ids = []
1615
+
1616
+ self.update_label_window_selection()
1450
1617
  self.update_button_states()
1451
-
1618
+
1452
1619
  print("Reset view: cleared selections and exited isolation mode")
1620
+
1621
+ @pyqtSlot(dict)
1622
+ def on_mislabel_params_changed(self, params):
1623
+ """Updates the stored parameters for mislabel detection."""
1624
+ self.mislabel_params = params
1625
+ print(f"Mislabel detection parameters updated: {self.mislabel_params}")
1626
+
1627
+ @pyqtSlot(dict)
1628
+ def on_uncertainty_params_changed(self, params):
1629
+ """Updates the stored parameters for uncertainty analysis."""
1630
+ self.uncertainty_params = params
1631
+ print(f"Uncertainty parameters updated: {self.uncertainty_params}")
1632
+
1633
+ @pyqtSlot(dict)
1634
+ def on_similarity_params_changed(self, params):
1635
+ """Updates the stored parameters for similarity search."""
1636
+ self.similarity_params = params
1637
+ print(f"Similarity search parameters updated: {self.similarity_params}")
1638
+
1639
+ @pyqtSlot()
1640
+ def on_model_selection_changed(self):
1641
+ """
1642
+ Handles changes in the model settings to enable/disable model-dependent features.
1643
+ """
1644
+ if not self._ui_initialized:
1645
+ return
1646
+
1647
+ model_name, feature_mode = self.model_settings_widget.get_selected_model()
1648
+ is_predict_mode = ".pt" in model_name and feature_mode == "Predictions"
1649
+
1650
+ self.embedding_viewer.is_uncertainty_analysis_available = is_predict_mode
1651
+ self.embedding_viewer._update_toolbar_state()
1652
+
1653
+ def _initialize_data_item_cache(self):
1654
+ """
1655
+ Creates a persistent AnnotationDataItem for every annotation,
1656
+ caching them for the duration of the session.
1657
+ """
1658
+ self.data_item_cache.clear()
1659
+ if not hasattr(self.main_window.annotation_window, 'annotations_dict'):
1660
+ return
1661
+
1662
+ all_annotations = self.main_window.annotation_window.annotations_dict.values()
1663
+ for ann in all_annotations:
1664
+ if ann.id not in self.data_item_cache:
1665
+ self.data_item_cache[ann.id] = AnnotationDataItem(ann)
1453
1666
 
1454
1667
  def update_label_window_selection(self):
1455
- """Update the label window based on the selection in the annotation viewer."""
1456
- self.annotation_viewer.update_label_window_selection()
1668
+ """
1669
+ Updates the label window based on the selection state of the currently
1670
+ loaded data items. This is the single, centralized point of logic.
1671
+ """
1672
+ # Get selected items directly from the master data list
1673
+ selected_data_items = [
1674
+ item for item in self.current_data_items if item.is_selected
1675
+ ]
1676
+
1677
+ if not selected_data_items:
1678
+ self.label_window.deselect_active_label()
1679
+ self.label_window.update_annotation_count()
1680
+ return
1681
+
1682
+ first_effective_label = selected_data_items[0].effective_label
1683
+ all_same_current_label = all(
1684
+ item.effective_label.id == first_effective_label.id
1685
+ for item in selected_data_items
1686
+ )
1687
+
1688
+ if all_same_current_label:
1689
+ self.label_window.set_active_label(first_effective_label)
1690
+ # This emit is what updates other UI elements, like the annotation list
1691
+ self.annotation_window.labelSelected.emit(first_effective_label.id)
1692
+ else:
1693
+ self.label_window.deselect_active_label()
1694
+
1695
+ self.label_window.update_annotation_count()
1457
1696
 
1458
1697
  def get_filtered_data_items(self):
1459
- """Get annotations that match all conditions, returned as AnnotationDataItem objects."""
1460
- data_items = []
1698
+ """
1699
+ Gets annotations matching all conditions by retrieving their
1700
+ persistent AnnotationDataItem objects from the cache.
1701
+ """
1461
1702
  if not hasattr(self.main_window.annotation_window, 'annotations_dict'):
1462
- return data_items
1703
+ return []
1463
1704
 
1464
- # Get current filter conditions
1465
1705
  selected_images = self.annotation_settings_widget.get_selected_images()
1466
1706
  selected_types = self.annotation_settings_widget.get_selected_annotation_types()
1467
1707
  selected_labels = self.annotation_settings_widget.get_selected_labels()
1468
1708
 
1469
- annotations_to_process = []
1470
- for annotation in self.main_window.annotation_window.annotations_dict.values():
1471
- annotation_matches = True
1472
-
1473
- # Check image condition - if empty list, no annotations match
1474
- if selected_images:
1475
- annotation_image = os.path.basename(annotation.image_path)
1476
- if annotation_image not in selected_images:
1477
- annotation_matches = False
1478
- else:
1479
- # No images selected means no annotations should match
1480
- annotation_matches = False
1481
-
1482
- # Check annotation type condition - if empty list, no annotations match
1483
- if annotation_matches:
1484
- if selected_types:
1485
- annotation_type = type(annotation).__name__
1486
- if annotation_type not in selected_types:
1487
- annotation_matches = False
1488
- else:
1489
- # No types selected means no annotations should match
1490
- annotation_matches = False
1491
-
1492
- # Check label condition - if empty list, no annotations match
1493
- if annotation_matches:
1494
- if selected_labels:
1495
- annotation_label = annotation.label.short_label_code
1496
- if annotation_label not in selected_labels:
1497
- annotation_matches = False
1498
- else:
1499
- # No labels selected means no annotations should match
1500
- annotation_matches = False
1709
+ if not all([selected_images, selected_types, selected_labels]):
1710
+ return []
1501
1711
 
1502
- if annotation_matches:
1503
- annotations_to_process.append(annotation)
1712
+ annotations_to_process = [
1713
+ ann for ann in self.main_window.annotation_window.annotations_dict.values()
1714
+ if (os.path.basename(ann.image_path) in selected_images and
1715
+ type(ann).__name__ in selected_types and
1716
+ ann.label.short_label_code in selected_labels)
1717
+ ]
1504
1718
 
1505
- # Ensure all filtered annotations have cropped images
1506
1719
  self._ensure_cropped_images(annotations_to_process)
1507
1720
 
1508
- # Wrap in AnnotationDataItem
1509
- for ann in annotations_to_process:
1510
- data_items.append(AnnotationDataItem(ann))
1511
-
1512
- return data_items
1721
+ return [self.data_item_cache[ann.id] for ann in annotations_to_process if ann.id in self.data_item_cache]
1513
1722
 
1514
- def _ensure_cropped_images(self, annotations):
1515
- """Ensure all provided annotations have a cropped image available."""
1516
- annotations_by_image = {}
1517
- for annotation in annotations:
1518
- # Only process annotations that don't have a cropped image yet
1519
- if not annotation.cropped_image:
1520
- image_path = annotation.image_path
1521
- if image_path not in annotations_by_image:
1522
- annotations_by_image[image_path] = []
1523
- annotations_by_image[image_path].append(annotation)
1524
-
1525
- # Only proceed if there are annotations that actually need cropping
1526
- if annotations_by_image:
1527
- progress_bar = ProgressBar(self, "Cropping Image Annotations")
1528
- progress_bar.show()
1529
- progress_bar.start_progress(len(annotations_by_image))
1530
-
1531
- try:
1532
- # Crop annotations for each image using the AnnotationWindow method
1533
- # This ensures consistency with how cropped images are generated elsewhere
1534
- for image_path, image_annotations in annotations_by_image.items():
1535
- self.annotation_window.crop_annotations(
1536
- image_path=image_path,
1537
- annotations=image_annotations,
1538
- return_annotations=False, # We don't need the return value
1539
- verbose=False
1540
- )
1541
- # Update progress bar
1542
- progress_bar.update_progress()
1543
-
1544
- except Exception as e:
1545
- print(f"Error cropping annotations: {e}")
1546
-
1547
- finally:
1548
- progress_bar.finish_progress()
1549
- progress_bar.stop_progress()
1550
- progress_bar.close()
1551
-
1552
- def _extract_color_features(self, data_items, progress_bar=None, bins=32):
1723
+ def find_potential_mislabels(self):
1553
1724
  """
1554
- Extracts a comprehensive set of color features from annotation crops using only NumPy.
1555
-
1556
- For each image, it calculates:
1557
- 1. Color Moments (per channel):
1558
- - Mean (1st moment)
1559
- - Standard Deviation (2nd moment)
1560
- - Skewness (3rd moment)
1561
- - Kurtosis (4th moment)
1562
- 2. Color Histogram (per channel)
1563
- 3. Grayscale Statistics:
1564
- - Mean Brightness
1565
- - Contrast (Std Dev)
1566
- - Intensity Range
1567
- 4. Geometric Features:
1568
- - Area
1569
- - Perimeter
1725
+ Identifies annotations whose label does not match the majority of its
1726
+ k-nearest neighbors in the high-dimensional feature space.
1570
1727
  """
1571
- if progress_bar:
1572
- progress_bar.set_title("Extracting Color Features...")
1573
- progress_bar.start_progress(len(data_items))
1728
+ # Get parameters from the stored property instead of hardcoding
1729
+ K = self.mislabel_params.get('k', 5)
1730
+ agreement_threshold = self.mislabel_params.get('threshold', 0.6)
1574
1731
 
1575
- features = []
1576
- valid_data_items = []
1577
- for item in data_items:
1578
- pixmap = item.annotation.get_cropped_image()
1579
- if pixmap and not pixmap.isNull():
1580
- # arr has shape (height, width, 3)
1581
- arr = pixmap_to_numpy(pixmap)
1582
-
1583
- # Reshape for channel-wise statistics: (num_pixels, 3)
1584
- pixels = arr.reshape(-1, 3)
1732
+ if not self.embedding_viewer.points_by_id or len(self.embedding_viewer.points_by_id) < K:
1733
+ QMessageBox.information(self, "Not Enough Data",
1734
+ f"This feature requires at least {K} points in the embedding viewer.")
1735
+ return
1585
1736
 
1586
- # --- 1. Calculate Color Moments using only NumPy ---
1587
- mean_color = np.mean(pixels, axis=0)
1588
- std_color = np.std(pixels, axis=0)
1737
+ items_in_view = list(self.embedding_viewer.points_by_id.values())
1738
+ data_items_in_view = [p.data_item for p in items_in_view]
1739
+
1740
+ # Get the model key used for the current embedding
1741
+ model_info = self.model_settings_widget.get_selected_model()
1742
+ model_name, feature_mode = model_info if isinstance(model_info, tuple) else (model_info, "default")
1743
+ sanitized_model_name = os.path.basename(model_name).replace(' ', '_')
1744
+ # FIX: Also replace the forward slash to handle "N/A"
1745
+ sanitized_feature_mode = feature_mode.replace(' ', '_').replace('/', '_')
1746
+ model_key = f"{sanitized_model_name}_{sanitized_feature_mode}"
1747
+
1748
+ QApplication.setOverrideCursor(Qt.WaitCursor)
1749
+ try:
1750
+ # Get the FAISS index and the mapping from index to annotation ID
1751
+ index = self.feature_store._get_or_load_index(model_key)
1752
+ faiss_idx_to_ann_id = self.feature_store.get_faiss_index_to_annotation_id_map(model_key)
1753
+ if index is None or not faiss_idx_to_ann_id:
1754
+ QMessageBox.warning(self, "Error", "Could not find a valid feature index for the current model.")
1755
+ return
1756
+
1757
+ # Get the high-dimensional features for the points in the current view
1758
+ features_dict, _ = self.feature_store.get_features(data_items_in_view, model_key)
1759
+ if not features_dict:
1760
+ QMessageBox.warning(self, "Error", "Could not retrieve features for the items in view.")
1761
+ return
1762
+
1763
+ query_ann_ids = list(features_dict.keys())
1764
+ query_vectors = np.array([features_dict[ann_id] for ann_id in query_ann_ids]).astype('float32')
1765
+
1766
+ # Perform k-NN search. We search for K+1 because the point itself will be the first result.
1767
+ _, I = index.search(query_vectors, K + 1)
1768
+
1769
+ mislabeled_ann_ids = []
1770
+ for i, ann_id in enumerate(query_ann_ids):
1771
+ current_label = self.data_item_cache[ann_id].effective_label.id
1589
1772
 
1590
- # Center the data (subtract the mean) for skew/kurtosis calculation
1591
- centered_pixels = pixels - mean_color
1773
+ # Get neighbor labels, ignoring the first result (the point itself)
1774
+ neighbor_faiss_indices = I[i][1:]
1592
1775
 
1593
- # A small value (epsilon) is added to the denominator to prevent division by zero
1594
- epsilon = 1e-8
1595
- skew_color = np.mean(centered_pixels**3, axis=0) / (std_color**3 + epsilon)
1596
- kurt_color = np.mean(centered_pixels**4, axis=0) / (std_color**4 + epsilon) - 3
1776
+ neighbor_labels = []
1777
+ for n_idx in neighbor_faiss_indices:
1778
+ # THIS IS THE CORRECTED LOGIC
1779
+ if n_idx in faiss_idx_to_ann_id:
1780
+ neighbor_ann_id = faiss_idx_to_ann_id[n_idx]
1781
+ # ADD THIS CHECK to ensure the neighbor hasn't been deleted
1782
+ if neighbor_ann_id in self.data_item_cache:
1783
+ neighbor_labels.append(self.data_item_cache[neighbor_ann_id].effective_label.id)
1784
+
1785
+ if not neighbor_labels:
1786
+ continue
1787
+
1788
+ # Use the agreement threshold instead of strict majority
1789
+ num_matching_neighbors = neighbor_labels.count(current_label)
1790
+ agreement_ratio = num_matching_neighbors / len(neighbor_labels)
1791
+
1792
+ if agreement_ratio < agreement_threshold:
1793
+ mislabeled_ann_ids.append(ann_id)
1794
+
1795
+ self.embedding_viewer.render_selection_from_ids(set(mislabeled_ann_ids))
1796
+
1797
+ finally:
1798
+ QApplication.restoreOverrideCursor()
1799
+
1800
+ def find_uncertain_annotations(self):
1801
+ """
1802
+ Identifies annotations where the model's prediction is uncertain.
1803
+ It reuses cached predictions if available, otherwise runs a temporary prediction.
1804
+ """
1805
+ if not self.embedding_viewer.points_by_id:
1806
+ QMessageBox.information(self, "No Data", "Please generate an embedding first.")
1807
+ return
1808
+
1809
+ if self.current_embedding_model_info is None:
1810
+ QMessageBox.information(self,
1811
+ "No Embedding",
1812
+ "Could not determine the model used for the embedding. Please run it again.")
1813
+ return
1814
+
1815
+ items_in_view = list(self.embedding_viewer.points_by_id.values())
1816
+ data_items_in_view = [p.data_item for p in items_in_view]
1817
+
1818
+ model_name_from_embedding, feature_mode_from_embedding = self.current_embedding_model_info
1819
+
1820
+ QApplication.setOverrideCursor(Qt.WaitCursor)
1821
+ try:
1822
+ probabilities_dict = {}
1823
+
1824
+ # Decide whether to reuse cached features or run a new prediction
1825
+ if feature_mode_from_embedding == "Predictions":
1826
+ print("Reusing cached prediction vectors from the FeatureStore.")
1827
+ sanitized_model_name = os.path.basename(model_name_from_embedding).replace(' ', '_').replace('/', '_')
1828
+ sanitized_feature_mode = feature_mode_from_embedding.replace(' ', '_').replace('/', '_')
1829
+ model_key = f"{sanitized_model_name}_{sanitized_feature_mode}"
1597
1830
 
1598
- # --- 2. Calculate Color Histograms ---
1599
- histograms = []
1600
- for i in range(3): # For each channel (R, G, B)
1601
- hist, _ = np.histogram(pixels[:, i], bins=bins, range=(0, 255))
1602
-
1603
- # Normalize histogram
1604
- hist_sum = np.sum(hist)
1605
- if hist_sum > 0:
1606
- histograms.append(hist / hist_sum)
1607
- else: # Avoid division by zero
1608
- histograms.append(np.zeros(bins))
1831
+ probabilities_dict, _ = self.feature_store.get_features(data_items_in_view, model_key)
1832
+ if not probabilities_dict:
1833
+ QMessageBox.warning(self,
1834
+ "Cache Error",
1835
+ "Could not retrieve cached predictions.")
1836
+ return
1837
+ else:
1838
+ print("Embedding not based on 'Predictions' mode. Running a temporary prediction.")
1839
+ model_info_for_predict = self.model_settings_widget.get_selected_model()
1840
+ probabilities_dict = self._get_yolo_predictions_for_uncertainty(data_items_in_view,
1841
+ model_info_for_predict)
1842
+
1843
+ if not probabilities_dict:
1844
+ # The helper function will show its own, more specific errors.
1845
+ return
1846
+
1847
+ uncertain_ids = []
1848
+ params = self.uncertainty_params
1849
+ for ann_id, probs in probabilities_dict.items():
1850
+ if len(probs) < 2:
1851
+ continue # Cannot calculate margin
1852
+
1853
+ sorted_probs = np.sort(probs)[::-1]
1854
+ top1_conf = sorted_probs[0]
1855
+ top2_conf = sorted_probs[1]
1856
+ margin = top1_conf - top2_conf
1857
+
1858
+ if top1_conf < params['confidence'] or margin < params['margin']:
1859
+ uncertain_ids.append(ann_id)
1860
+
1861
+ self.embedding_viewer.render_selection_from_ids(set(uncertain_ids))
1862
+ print(f"Found {len(uncertain_ids)} uncertain annotations.")
1863
+
1864
+ finally:
1865
+ QApplication.restoreOverrideCursor()
1866
+
1867
+ @pyqtSlot()
1868
+ def find_similar_annotations(self):
1869
+ """
1870
+ Finds k-nearest neighbors to the selected annotation(s) and updates
1871
+ the UI to show the results in an isolated, ordered view. This method
1872
+ now ensures the grid is always updated and resets the sort-by dropdown.
1873
+ """
1874
+ k = self.similarity_params.get('k', 10)
1875
+
1876
+ if not self.annotation_viewer.selected_widgets:
1877
+ QMessageBox.information(self, "No Selection", "Please select one or more annotations first.")
1878
+ return
1879
+
1880
+ if not self.current_embedding_model_info:
1881
+ QMessageBox.warning(self, "No Embedding", "Please run an embedding before searching for similar items.")
1882
+ return
1883
+
1884
+ selected_data_items = [widget.data_item for widget in self.annotation_viewer.selected_widgets]
1885
+ model_name, feature_mode = self.current_embedding_model_info
1886
+ sanitized_model_name = os.path.basename(model_name).replace(' ', '_')
1887
+ sanitized_feature_mode = feature_mode.replace(' ', '_').replace('/', '_')
1888
+ model_key = f"{sanitized_model_name}_{sanitized_feature_mode}"
1889
+
1890
+ QApplication.setOverrideCursor(Qt.WaitCursor)
1891
+ try:
1892
+ features_dict, _ = self.feature_store.get_features(selected_data_items, model_key)
1893
+ if not features_dict:
1894
+ QMessageBox.warning(self,
1895
+ "Features Not Found",
1896
+ "Could not retrieve feature vectors for the selected items.")
1897
+ return
1898
+
1899
+ source_vectors = np.array(list(features_dict.values()))
1900
+ query_vector = np.mean(source_vectors, axis=0, keepdims=True).astype('float32')
1901
+
1902
+ index = self.feature_store._get_or_load_index(model_key)
1903
+ faiss_idx_to_ann_id = self.feature_store.get_faiss_index_to_annotation_id_map(model_key)
1904
+ if index is None or not faiss_idx_to_ann_id:
1905
+ QMessageBox.warning(self,
1906
+ "Index Error",
1907
+ "Could not find a valid feature index for the current model.")
1908
+ return
1909
+
1910
+ # Find k results, plus more to account for the query items possibly being in the results
1911
+ num_to_find = k + len(selected_data_items)
1912
+ if num_to_find > index.ntotal:
1913
+ num_to_find = index.ntotal
1914
+
1915
+ _, I = index.search(query_vector, num_to_find)
1916
+
1917
+ source_ids = {item.annotation.id for item in selected_data_items}
1918
+ similar_ann_ids = []
1919
+ for faiss_idx in I[0]:
1920
+ ann_id = faiss_idx_to_ann_id.get(faiss_idx)
1921
+ if ann_id and ann_id in self.data_item_cache and ann_id not in source_ids:
1922
+ similar_ann_ids.append(ann_id)
1923
+ if len(similar_ann_ids) == k:
1924
+ break
1925
+
1926
+ # Create the final ordered list: original selection first, then similar items.
1927
+ ordered_ids_to_display = list(source_ids) + similar_ann_ids
1928
+
1929
+ # --- FIX IMPLEMENTATION ---
1930
+ # 1. Force sort combo to "None" to avoid user confusion.
1931
+ self.annotation_viewer.sort_combo.setCurrentText("None")
1932
+
1933
+ # 2. Update the embedding viewer selection.
1934
+ self.embedding_viewer.render_selection_from_ids(set(ordered_ids_to_display))
1935
+
1936
+ # 3. Call the new robust method in AnnotationViewer to handle isolation and grid updates.
1937
+ self.annotation_viewer.display_and_isolate_ordered_results(ordered_ids_to_display)
1938
+
1939
+ self.update_button_states()
1940
+
1941
+ finally:
1942
+ QApplication.restoreOverrideCursor()
1943
+
1944
+ def _get_yolo_predictions_for_uncertainty(self, data_items, model_info):
1945
+ """
1946
+ Runs a YOLO classification model to get probabilities for uncertainty analysis.
1947
+ This is a streamlined method that does NOT use the feature store.
1948
+ """
1949
+ model_name, feature_mode = model_info
1950
+
1951
+ # Load the model
1952
+ model, imgsz = self._load_yolo_model(model_name, feature_mode)
1953
+ if model is None:
1954
+ QMessageBox.warning(self,
1955
+ "Model Load Error",
1956
+ f"Could not load YOLO model '{model_name}'.")
1957
+ return None
1958
+
1959
+ # Prepare images from data items
1960
+ image_list, valid_data_items = self._prepare_images_from_data_items(data_items)
1961
+ if not image_list:
1962
+ return None
1963
+
1964
+ try:
1965
+ # We need probabilities for uncertainty analysis, so we always use predict
1966
+ results = model.predict(image_list,
1967
+ stream=False, # Use batch processing for uncertainty
1968
+ imgsz=imgsz,
1969
+ half=True,
1970
+ device=self.device,
1971
+ verbose=False)
1609
1972
 
1610
- # --- 3. Calculate Grayscale Statistics ---
1611
- gray_arr = np.dot(arr[..., :3], [0.2989, 0.5870, 0.1140])
1612
- grayscale_stats = np.array([
1613
- np.mean(gray_arr), # Overall brightness
1614
- np.std(gray_arr), # Overall contrast
1615
- np.ptp(gray_arr) # Peak-to-peak intensity range
1616
- ])
1973
+ _, probabilities_dict = self._process_model_results(results, valid_data_items, "Predictions")
1974
+ return probabilities_dict
1975
+
1976
+ except TypeError:
1977
+ QMessageBox.warning(self,
1978
+ "Invalid Model",
1979
+ "The selected model is not compatible with uncertainty analysis.")
1980
+ return None
1981
+
1982
+ finally:
1983
+ if torch.cuda.is_available():
1984
+ torch.cuda.empty_cache()
1985
+
1986
+ def _ensure_cropped_images(self, annotations):
1987
+ """Ensures all provided annotations have a cropped image available."""
1988
+ annotations_by_image = {}
1989
+
1990
+ for annotation in annotations:
1991
+ if not annotation.cropped_image:
1992
+ image_path = annotation.image_path
1993
+ if image_path not in annotations_by_image:
1994
+ annotations_by_image[image_path] = []
1995
+ annotations_by_image[image_path].append(annotation)
1996
+
1997
+ if not annotations_by_image:
1998
+ return
1999
+
2000
+ progress_bar = ProgressBar(self, "Cropping Image Annotations")
2001
+ progress_bar.show()
2002
+ progress_bar.start_progress(len(annotations_by_image))
2003
+
2004
+ try:
2005
+ for image_path, image_annotations in annotations_by_image.items():
2006
+ self.annotation_window.crop_annotations(image_path=image_path,
2007
+ annotations=image_annotations,
2008
+ return_annotations=False,
2009
+ verbose=False)
2010
+ progress_bar.update_progress()
2011
+ finally:
2012
+ progress_bar.finish_progress()
2013
+ progress_bar.stop_progress()
2014
+ progress_bar.close()
2015
+
2016
+ def _load_yolo_model(self, model_name, feature_mode):
2017
+ """
2018
+ Helper function to load a YOLO model and cache it.
2019
+
2020
+ Args:
2021
+ model_name (str): Path to the YOLO model file
2022
+ feature_mode (str): Mode for feature extraction ("Embed Features" or "Predictions")
2023
+
2024
+ Returns:
2025
+ tuple: (model, image_size) or (None, None) if loading fails
2026
+ """
2027
+ current_run_key = (model_name, feature_mode)
2028
+
2029
+ # Force a reload if the model path OR the feature mode has changed
2030
+ if current_run_key != self.current_feature_generating_model or self.loaded_model is None:
2031
+ print(f"Model or mode changed. Reloading {model_name} for '{feature_mode}'.")
2032
+ try:
2033
+ model = YOLO(model_name)
2034
+ # Update the cache key to the new successful combination
2035
+ self.current_feature_generating_model = current_run_key
2036
+ self.loaded_model = model
2037
+ imgsz = getattr(model.model.args, 'imgsz', 128)
1617
2038
 
1618
- # --- 4. Calculate Geometric Features ---
1619
- area = getattr(item.annotation, 'area', 0.0)
1620
- perimeter = getattr(item.annotation, 'perimeter', 0.0)
1621
- geometric_features = np.array([area, perimeter])
2039
+ # Warm up the model
2040
+ dummy_image = np.zeros((imgsz, imgsz, 3), dtype=np.uint8)
2041
+ model.predict(dummy_image, imgsz=imgsz, half=True, device=self.device, verbose=False)
1622
2042
 
1623
- # --- 5. Concatenate all features into a single vector ---
1624
- current_features = np.concatenate([
1625
- mean_color,
1626
- std_color,
1627
- skew_color,
1628
- kurt_color,
1629
- *histograms,
1630
- grayscale_stats,
1631
- geometric_features
1632
- ])
2043
+ return model, imgsz
1633
2044
 
1634
- features.append(current_features)
2045
+ except Exception as e:
2046
+ print(f"ERROR: Could not load YOLO model '{model_name}': {e}")
2047
+ # On failure, reset the model cache
2048
+ self.loaded_model = None
2049
+ self.current_feature_generating_model = None
2050
+ return None, None
2051
+
2052
+ # Model already loaded and cached
2053
+ return self.loaded_model, getattr(self.loaded_model.model.args, 'imgsz', 128)
2054
+
2055
+ def _prepare_images_from_data_items(self, data_items, progress_bar=None):
2056
+ """
2057
+ Prepare images from data items for model prediction.
2058
+
2059
+ Args:
2060
+ data_items (list): List of AnnotationDataItem objects
2061
+ progress_bar (ProgressBar, optional): Progress bar for UI updates
2062
+
2063
+ Returns:
2064
+ tuple: (image_list, valid_data_items)
2065
+ """
2066
+ if progress_bar:
2067
+ progress_bar.set_title("Preparing images...")
2068
+ progress_bar.start_progress(len(data_items))
2069
+
2070
+ image_list, valid_data_items = [], []
2071
+ for item in data_items:
2072
+ pixmap = item.annotation.get_cropped_image()
2073
+ if pixmap and not pixmap.isNull():
2074
+ image_list.append(pixmap_to_numpy(pixmap))
1635
2075
  valid_data_items.append(item)
1636
- else:
1637
- print(f"Warning: Could not get cropped image for annotation ID {item.annotation.id}. Skipping.")
1638
2076
 
1639
2077
  if progress_bar:
1640
2078
  progress_bar.update_progress()
2079
+
2080
+ return image_list, valid_data_items
1641
2081
 
1642
- return np.array(features), valid_data_items
1643
-
1644
- def _extract_yolo_features(self, data_items, model_info, progress_bar=None):
2082
+ def _process_model_results(self, results, valid_data_items, feature_mode, progress_bar=None):
1645
2083
  """
1646
- Extracts features from annotation crops using a specified YOLO model.
1647
- Uses model.embed() for embedding features or model.predict() for classification probabilities.
2084
+ Process model results and update data item tooltips.
2085
+
2086
+ Args:
2087
+ results: Model prediction results
2088
+ valid_data_items (list): List of valid data items
2089
+ feature_mode (str): Mode for feature extraction
2090
+ progress_bar (ProgressBar, optional): Progress bar for UI updates
2091
+
2092
+ Returns:
2093
+ tuple: (features_list, probabilities_dict)
1648
2094
  """
1649
- # Unpack model information
1650
- model_name, feature_mode = model_info
2095
+ features_list = []
2096
+ probabilities_dict = {}
1651
2097
 
1652
- # Load or retrieve the cached model
1653
- if model_name != self.model_path or self.loaded_model is None:
1654
- try:
1655
- print(f"Loading new model: {model_name}")
1656
- self.loaded_model = YOLO(model_name)
1657
- self.model_path = model_name
2098
+ # Get class names from the model for better tooltips
2099
+ model = self.loaded_model.model if hasattr(self.loaded_model, 'model') else None
2100
+ class_names = model.names if model and hasattr(model, 'names') else {}
2101
+
2102
+ for i, result in enumerate(results):
2103
+ if i >= len(valid_data_items):
2104
+ break
1658
2105
 
1659
- # Determine image size from model config if possible
1660
- try:
1661
- self.imgsz = self.loaded_model.model.args['imgsz']
1662
- if self.imgsz > 224:
1663
- self.imgsz = 128
1664
- except (AttributeError, KeyError):
1665
- self.imgsz = 128
2106
+ item = valid_data_items[i]
2107
+ ann_id = item.annotation.id
2108
+
2109
+ if feature_mode == "Embed Features":
2110
+ embedding = result.cpu().numpy().flatten()
2111
+ features_list.append(embedding)
2112
+
2113
+ elif hasattr(result, 'probs') and result.probs is not None:
2114
+ probs = result.probs.data.cpu().numpy().squeeze()
2115
+ features_list.append(probs)
2116
+ probabilities_dict[ann_id] = probs
2117
+
2118
+ # Store the probabilities directly on the data item for confidence sorting
2119
+ item.prediction_probabilities = probs
1666
2120
 
1667
- # Run a dummy inference to warm up the model
1668
- print(f"Warming up model on device '{self.device}'...")
1669
- dummy_image = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8)
1670
- self.loaded_model.predict(dummy_image, imgsz=self.imgsz, half=True, device=self.device, verbose=False)
2121
+ # Format and store prediction details for tooltips
2122
+ if len(probs) > 0:
2123
+ # Get top 5 predictions
2124
+ top_indices = probs.argsort()[::-1][:5]
2125
+ top_probs = probs[top_indices]
1671
2126
 
1672
- except Exception as e:
1673
- print(f"ERROR: Could not load YOLO model '{model_name}': {e}")
1674
- return np.array([]), []
2127
+ formatted_preds = ["<b>Top Predictions:</b>"]
2128
+ for idx, prob in zip(top_indices, top_probs):
2129
+ class_name = class_names.get(int(idx), f"Class {idx}")
2130
+ formatted_preds.append(f"{class_name}: {prob*100:.1f}%")
2131
+
2132
+ item.prediction_details = "<br>".join(formatted_preds)
2133
+ else:
2134
+ raise TypeError(
2135
+ "The 'Predictions' feature mode requires a classification model "
2136
+ "(e.g., 'yolov8n-cls.pt') that returns class probabilities. "
2137
+ "The selected model did not provide this output. "
2138
+ "Please use 'Embed Features' mode for this model."
2139
+ )
1675
2140
 
2141
+ if progress_bar:
2142
+ progress_bar.update_progress()
2143
+
2144
+ # After processing is complete, update tooltips
2145
+ for item in valid_data_items:
2146
+ if hasattr(item, 'update_tooltip'):
2147
+ item.update_tooltip()
2148
+
2149
+ return features_list, probabilities_dict
2150
+
2151
+ def _extract_color_features(self, data_items, progress_bar=None, bins=32):
2152
+ """
2153
+ Extracts color-based features from annotation crops.
2154
+
2155
+ Features extracted per annotation:
2156
+ - Mean, standard deviation, skewness, and kurtosis for each RGB channel
2157
+ - Normalized histogram for each RGB channel
2158
+ - Grayscale statistics: mean, std, range
2159
+ - Geometric features: area, perimeter (if available)
2160
+ Returns:
2161
+ features: np.ndarray of shape (N, feature_dim)
2162
+ valid_data_items: list of AnnotationDataItem with valid crops
2163
+ """
1676
2164
  if progress_bar:
1677
- progress_bar.set_title(f"Preparing images...")
2165
+ progress_bar.set_title("Extracting features...")
1678
2166
  progress_bar.start_progress(len(data_items))
1679
2167
 
1680
- # 1. Prepare a list of all valid images and their corresponding data items.
1681
- image_list = []
2168
+ features = []
1682
2169
  valid_data_items = []
2170
+
1683
2171
  for item in data_items:
1684
2172
  pixmap = item.annotation.get_cropped_image()
1685
2173
  if pixmap and not pixmap.isNull():
1686
- image_np = pixmap_to_numpy(pixmap)
1687
- image_list.append(image_np)
2174
+ # Convert QPixmap to numpy array (H, W, 3)
2175
+ arr = pixmap_to_numpy(pixmap)
2176
+ pixels = arr.reshape(-1, 3)
2177
+
2178
+ # Basic color statistics
2179
+ mean_color = np.mean(pixels, axis=0)
2180
+ std_color = np.std(pixels, axis=0)
2181
+
2182
+ # Skewness and kurtosis for each channel
2183
+ epsilon = 1e-8 # Prevent division by zero
2184
+ centered_pixels = pixels - mean_color
2185
+ skew_color = np.mean(centered_pixels ** 3, axis=0) / (std_color ** 3 + epsilon)
2186
+ kurt_color = np.mean(centered_pixels ** 4, axis=0) / (std_color ** 4 + epsilon) - 3
2187
+
2188
+ # Normalized histograms for each channel
2189
+ histograms = [
2190
+ np.histogram(pixels[:, i], bins=bins, range=(0, 255))[0]
2191
+ for i in range(3)
2192
+ ]
2193
+ histograms = [
2194
+ h / h.sum() if h.sum() > 0 else np.zeros(bins)
2195
+ for h in histograms
2196
+ ]
2197
+
2198
+ # Grayscale statistics
2199
+ gray_arr = np.dot(arr[..., :3], [0.2989, 0.5870, 0.1140])
2200
+ grayscale_stats = np.array([
2201
+ np.mean(gray_arr),
2202
+ np.std(gray_arr),
2203
+ np.ptp(gray_arr)
2204
+ ])
2205
+
2206
+ # Geometric features (area, perimeter)
2207
+ area = getattr(item.annotation, 'area', 0.0)
2208
+ perimeter = getattr(item.annotation, 'perimeter', 0.0)
2209
+ geometric_features = np.array([area, perimeter])
2210
+
2211
+ # Concatenate all features into a single vector
2212
+ current_features = np.concatenate([
2213
+ mean_color,
2214
+ std_color,
2215
+ skew_color,
2216
+ kurt_color,
2217
+ *histograms,
2218
+ grayscale_stats,
2219
+ geometric_features
2220
+ ])
2221
+
2222
+ features.append(current_features)
1688
2223
  valid_data_items.append(item)
1689
- else:
1690
- print(f"Warning: Could not get cropped image for annotation ID {item.annotation.id}. Skipping.")
1691
2224
 
1692
2225
  if progress_bar:
1693
2226
  progress_bar.update_progress()
1694
2227
 
2228
+ return np.array(features), valid_data_items
2229
+
2230
+ def _extract_yolo_features(self, data_items, model_info, progress_bar=None):
2231
+ """Extracts features from annotation crops using a YOLO model."""
2232
+ model_name, feature_mode = model_info
2233
+
2234
+ # Load the model
2235
+ model, imgsz = self._load_yolo_model(model_name, feature_mode)
2236
+ if model is None:
2237
+ return np.array([]), []
2238
+
2239
+ # Prepare images from data items
2240
+ image_list, valid_data_items = self._prepare_images_from_data_items(data_items, progress_bar)
1695
2241
  if not valid_data_items:
1696
- print("Warning: No valid images found to process.")
1697
2242
  return np.array([]), []
1698
-
1699
- embeddings_list = []
2243
+
2244
+ # Set up prediction parameters
2245
+ kwargs = {
2246
+ 'stream': True,
2247
+ 'imgsz': imgsz,
2248
+ 'half': True,
2249
+ 'device': self.device,
2250
+ 'verbose': False
2251
+ }
2252
+
2253
+ # Get results based on feature mode
2254
+ if feature_mode == "Embed Features":
2255
+ results_generator = model.embed(image_list, **kwargs)
2256
+ else:
2257
+ results_generator = model.predict(image_list, **kwargs)
2258
+
2259
+ if progress_bar:
2260
+ progress_bar.set_title("Extracting features...")
2261
+ progress_bar.start_progress(len(valid_data_items))
1700
2262
 
1701
2263
  try:
1702
- if progress_bar:
1703
- progress_bar.set_busy_mode(f"Extracting features with {os.path.basename(model_name)}...")
1704
-
1705
- kwargs = {
1706
- 'stream': True,
1707
- 'imgsz': self.imgsz,
1708
- 'half': True,
1709
- 'device': self.device,
1710
- 'verbose': False
1711
- }
1712
-
1713
- # 2. Choose between embed() and predict() based on feature mode
1714
- using_embed_method = feature_mode == "Embed Features"
2264
+ features_list, _ = self._process_model_results(results_generator,
2265
+ valid_data_items,
2266
+ feature_mode,
2267
+ progress_bar=progress_bar)
1715
2268
 
1716
- if using_embed_method:
1717
- print("Using embed() method")
1718
-
1719
- # Use model.embed() method (uses the second to last layer)
1720
- results_generator = self.loaded_model.embed(image_list, **kwargs)
1721
-
1722
- else:
1723
- print("Using predict() method")
1724
-
1725
- # Use model.predict() method for classification probabilities
1726
- results_generator = self.loaded_model.predict(image_list, **kwargs)
1727
-
1728
- if progress_bar:
1729
- progress_bar.set_title(f"Extracting features with {os.path.basename(model_name)}...")
1730
- progress_bar.start_progress(len(valid_data_items))
2269
+ return np.array(features_list), valid_data_items
1731
2270
 
1732
- # 3. Process the results from the generator - different handling based on method
1733
- for result in results_generator:
1734
- if using_embed_method:
1735
- try:
1736
- # With embed(), result is directly a tensor
1737
- embedding = result.cpu().numpy().flatten()
1738
- embeddings_list.append(embedding)
1739
- except Exception as e:
1740
- print(f"Error processing embedding: {e}")
1741
- raise TypeError(
1742
- f"Model '{os.path.basename(model_name)}' did not return valid embeddings. "
1743
- f"Error: {str(e)}"
1744
- )
1745
- else:
1746
- # Classification mode: We expect probability vectors
1747
- if hasattr(result, 'probs') and result.probs is not None:
1748
- embedding = result.probs.data.cpu().numpy().squeeze()
1749
- embeddings_list.append(embedding)
1750
- else:
1751
- raise TypeError(
1752
- f"Model '{os.path.basename(model_name)}' did not return probability vectors. "
1753
- "Make sure this is a classification model."
1754
- )
1755
-
1756
- # Update progress for each item as it's processed from the stream.
1757
- if progress_bar:
1758
- progress_bar.update_progress()
1759
-
1760
- if not embeddings_list:
1761
- print("Warning: No features were extracted. The model may have failed.")
1762
- return np.array([]), []
1763
-
1764
- embeddings = np.array(embeddings_list)
1765
- except Exception as e:
1766
- print(f"ERROR: An error occurred during feature extraction: {e}")
1767
- return np.array([]), []
1768
2271
  finally:
1769
- # Clean up CUDA memory after the operation
1770
2272
  if torch.cuda.is_available():
1771
2273
  torch.cuda.empty_cache()
1772
-
1773
- print(f"Successfully extracted {len(embeddings)} features with shape {embeddings.shape}")
1774
- return embeddings, valid_data_items
1775
2274
 
1776
2275
  def _extract_features(self, data_items, progress_bar=None):
1777
- """
1778
- Dispatcher method to call the appropriate feature extraction function.
1779
- It now passes the progress_bar object to the sub-methods.
1780
- """
2276
+ """Dispatcher to call the appropriate feature extraction function."""
2277
+ # Get the selected model and feature mode from the model settings widget
1781
2278
  model_name, feature_mode = self.model_settings_widget.get_selected_model()
1782
-
1783
- # Handle tuple or string return value
1784
- if isinstance(model_name, tuple) and len(model_name) >= 3:
2279
+
2280
+ if isinstance(model_name, tuple):
1785
2281
  model_name = model_name[0]
1786
- else:
1787
- model_name = model_name
1788
2282
 
1789
2283
  if not model_name:
1790
- print("No model selected or path provided.")
1791
2284
  return np.array([]), []
1792
2285
 
1793
- # --- MODIFIED: Pass the progress_bar object ---
1794
2286
  if model_name == "Color Features":
1795
2287
  return self._extract_color_features(data_items, progress_bar=progress_bar)
2288
+
1796
2289
  elif ".pt" in model_name:
1797
- # Pass the full model_info which may include embed layers
1798
2290
  return self._extract_yolo_features(data_items, (model_name, feature_mode), progress_bar=progress_bar)
1799
- else:
1800
- print(f"Unknown or invalid feature model selected: {model_name}")
1801
- return np.array([]), []
2291
+
2292
+ return np.array([]), []
1802
2293
 
1803
2294
  def _run_dimensionality_reduction(self, features, params):
1804
2295
  """
1805
- Runs PCA, UMAP or t-SNE on the feature matrix using provided parameters.
2296
+ Runs dimensionality reduction with automatic PCA preprocessing for UMAP and t-SNE.
2297
+
2298
+ Args:
2299
+ features (np.ndarray): Feature matrix of shape (N, D).
2300
+ params (dict): Embedding parameters, including technique and its hyperparameters.
2301
+
2302
+ Returns:
2303
+ np.ndarray or None: 2D embedded features of shape (N, 2), or None on failure.
1806
2304
  """
1807
- technique = params.get('technique', 'UMAP') # Changed default to UMAP
1808
- random_state = 42
1809
-
1810
- print(f"Running {technique} on {len(features)} items with params: {params}")
1811
- if len(features) <= 2: # UMAP/t-SNE need at least a few points
1812
- print("Not enough data points for dimensionality reduction.")
2305
+ technique = params.get('technique', 'UMAP')
2306
+ # Default number of components to use for PCA preprocessing
2307
+ pca_components = params.get('pca_components', 50)
2308
+
2309
+ if len(features) <= 2:
2310
+ # Not enough samples for dimensionality reduction
1813
2311
  return None
1814
2312
 
1815
2313
  try:
1816
- # Scaling is crucial, your implementation is already correct
1817
- scaler = StandardScaler()
1818
- features_scaled = scaler.fit_transform(features)
1819
-
2314
+ # Standardize features before reduction
2315
+ features_scaled = StandardScaler().fit_transform(features)
2316
+
2317
+ # Apply PCA preprocessing automatically for UMAP or TSNE
2318
+ # (only if the feature dimension is larger than the target PCA components)
2319
+ if technique in ["UMAP", "TSNE"] and features_scaled.shape[1] > pca_components:
2320
+ # Ensure pca_components doesn't exceed number of samples or features
2321
+ pca_components = min(pca_components, features_scaled.shape[0] - 1, features_scaled.shape[1])
2322
+ print(f"Applying PCA preprocessing to {pca_components} components before {technique}")
2323
+ pca = PCA(n_components=pca_components, random_state=42)
2324
+ features_scaled = pca.fit_transform(features_scaled)
2325
+ variance_explained = sum(pca.explained_variance_ratio_) * 100
2326
+ print(f"Variance explained by PCA: {variance_explained:.1f}%")
2327
+
2328
+ # Proceed with the selected dimensionality reduction technique
1820
2329
  if technique == "UMAP":
1821
- # Get hyperparameters from params, with sensible defaults
1822
- n_neighbors = params.get('n_neighbors', 15)
1823
- min_dist = params.get('min_dist', 0.1)
1824
- metric = params.get('metric', 'cosine') # Allow metric to be specified
1825
-
2330
+ n_neighbors = min(params.get('n_neighbors', 15), len(features_scaled) - 1)
1826
2331
  reducer = UMAP(
1827
- n_components=2,
1828
- random_state=random_state,
1829
- n_neighbors=min(n_neighbors, len(features_scaled) - 1),
1830
- min_dist=min_dist,
1831
- metric=metric # Use the specified metric
2332
+ n_components=2,
2333
+ random_state=42,
2334
+ n_neighbors=n_neighbors,
2335
+ min_dist=params.get('min_dist', 0.1),
2336
+ metric=params.get('metric', 'cosine')
1832
2337
  )
1833
-
1834
2338
  elif technique == "TSNE":
1835
- perplexity = params.get('perplexity', 30)
1836
- early_exaggeration = params.get('early_exaggeration', 12.0)
1837
- learning_rate = params.get('learning_rate', 'auto')
1838
-
2339
+ perplexity = min(params.get('perplexity', 30), len(features_scaled) - 1)
1839
2340
  reducer = TSNE(
1840
- n_components=2,
1841
- random_state=random_state,
1842
- perplexity=min(perplexity, len(features_scaled) - 1),
1843
- early_exaggeration=early_exaggeration,
1844
- learning_rate=learning_rate,
1845
- init='pca' # Improves stability and speed
2341
+ n_components=2,
2342
+ random_state=42,
2343
+ perplexity=perplexity,
2344
+ early_exaggeration=params.get('early_exaggeration', 12.0),
2345
+ learning_rate=params.get('learning_rate', 'auto'),
2346
+ init='pca'
1846
2347
  )
1847
-
1848
2348
  elif technique == "PCA":
1849
- reducer = PCA(n_components=2, random_state=random_state)
1850
-
2349
+ reducer = PCA(n_components=2, random_state=42)
1851
2350
  else:
1852
- print(f"Unknown dimensionality reduction technique: {technique}")
1853
2351
  return None
1854
-
2352
+
2353
+ # Fit and transform the features
1855
2354
  return reducer.fit_transform(features_scaled)
1856
-
2355
+
1857
2356
  except Exception as e:
1858
2357
  print(f"Error during {technique} dimensionality reduction: {e}")
1859
2358
  return None
@@ -1861,91 +2360,115 @@ class ExplorerWindow(QMainWindow):
1861
2360
  def _update_data_items_with_embedding(self, data_items, embedded_features):
1862
2361
  """Updates AnnotationDataItem objects with embedding results."""
1863
2362
  scale_factor = 4000
1864
- min_vals = np.min(embedded_features, axis=0)
1865
- max_vals = np.max(embedded_features, axis=0)
2363
+ min_vals, max_vals = np.min(embedded_features, axis=0), np.max(embedded_features, axis=0)
1866
2364
  range_vals = max_vals - min_vals
1867
-
1868
2365
  for i, item in enumerate(data_items):
1869
- # Normalize coordinates for consistent display
1870
2366
  norm_x = (embedded_features[i, 0] - min_vals[0]) / range_vals[0] if range_vals[0] > 0 else 0.5
1871
2367
  norm_y = (embedded_features[i, 1] - min_vals[1]) / range_vals[1] if range_vals[1] > 0 else 0.5
1872
- # Scale and center the points in the view
1873
2368
  item.embedding_x = (norm_x * scale_factor) - (scale_factor / 2)
1874
2369
  item.embedding_y = (norm_y * scale_factor) - (scale_factor / 2)
1875
2370
 
1876
2371
  def run_embedding_pipeline(self):
1877
2372
  """
1878
- Orchestrates the feature extraction and dimensionality reduction pipeline.
1879
- This version correctly re-runs reduction on cached features when parameters change.
2373
+ Orchestrates feature extraction and dimensionality reduction.
2374
+ If the EmbeddingViewer is in isolate mode, it will use only the visible
2375
+ (isolated) points as input for the pipeline.
1880
2376
  """
1881
- if not self.current_data_items:
2377
+ items_to_embed = []
2378
+ if self.embedding_viewer.isolated_mode:
2379
+ items_to_embed = [point.data_item for point in self.embedding_viewer.isolated_points]
2380
+ else:
2381
+ items_to_embed = self.current_data_items
2382
+
2383
+ if not items_to_embed:
1882
2384
  print("No items to process for embedding.")
1883
2385
  return
1884
2386
 
1885
- # 1. Get current parameters from the UI
2387
+ self.annotation_viewer.clear_selection()
2388
+ if self.annotation_viewer.isolated_mode:
2389
+ self.annotation_viewer.show_all_annotations()
2390
+
2391
+ self.embedding_viewer.render_selection_from_ids(set())
2392
+ self.update_button_states()
2393
+
2394
+ self.current_embedding_model_info = self.model_settings_widget.get_selected_model()
2395
+
1886
2396
  embedding_params = self.embedding_settings_widget.get_embedding_parameters()
1887
- model_info = self.model_settings_widget.get_selected_model() # Now returns tuple (model_name, feature_mode)
1888
-
1889
- # Unpack model info - both model_name and feature_mode are used for caching
1890
- if isinstance(model_info, tuple):
1891
- selected_model, selected_feature_mode = model_info
1892
- # Create a unique cache key that includes both model and feature mode
1893
- cache_key = f"{selected_model}_{selected_feature_mode}"
2397
+ selected_model, selected_feature_mode = self.current_embedding_model_info
2398
+
2399
+ # If the model name is a path, use only its base name.
2400
+ if os.path.sep in selected_model or '/' in selected_model:
2401
+ sanitized_model_name = os.path.basename(selected_model)
1894
2402
  else:
1895
- selected_model = model_info
1896
- selected_feature_mode = "default"
1897
- cache_key = f"{selected_model}_{selected_feature_mode}"
1898
-
1899
- technique = embedding_params['technique']
2403
+ sanitized_model_name = selected_model
2404
+
2405
+ # Replace characters that might be problematic in filenames
2406
+ sanitized_model_name = sanitized_model_name.replace(' ', '_')
2407
+ # Also replace the forward slash to handle "N/A"
2408
+ sanitized_feature_mode = selected_feature_mode.replace(' ', '_').replace('/', '_')
2409
+
2410
+ model_key = f"{sanitized_model_name}_{sanitized_feature_mode}"
1900
2411
 
1901
2412
  QApplication.setOverrideCursor(Qt.WaitCursor)
1902
- progress_bar = ProgressBar(self, "Generating Embedding Visualization")
2413
+ progress_bar = ProgressBar(self, "Processing Annotations")
1903
2414
  progress_bar.show()
1904
-
1905
- try:
1906
- # 2. Decide whether to use cached features or extract new ones
1907
- # Now checks both model name AND feature mode
1908
- if self.current_features is None or cache_key != self.current_feature_generating_model:
1909
- # SLOW PATH: Extract and cache new features
1910
- features, valid_data_items = self._extract_features(self.current_data_items, progress_bar=progress_bar)
1911
-
1912
- self.current_features = features
1913
- self.current_feature_generating_model = cache_key # Store the complete cache key
1914
- self.current_data_items = valid_data_items
1915
- self.annotation_viewer.update_annotations(self.current_data_items)
1916
- else:
1917
- # FAST PATH: Use existing features
1918
- print("Using cached features. Skipping feature extraction.")
1919
- features = self.current_features
1920
2415
 
1921
- if features is None or len(features) == 0:
1922
- print("No valid features available. Aborting embedding.")
2416
+ try:
2417
+ progress_bar.set_busy_mode("Checking feature cache...")
2418
+ cached_features, items_to_process = self.feature_store.get_features(items_to_embed, model_key)
2419
+ print(f"Found {len(cached_features)} features in cache. Need to compute {len(items_to_process)}.")
2420
+
2421
+ if items_to_process:
2422
+ newly_extracted_features, valid_items_processed = self._extract_features(items_to_process,
2423
+ progress_bar=progress_bar)
2424
+ if len(newly_extracted_features) > 0:
2425
+ progress_bar.set_busy_mode("Saving new features to cache...")
2426
+ self.feature_store.add_features(valid_items_processed, newly_extracted_features, model_key)
2427
+ new_features_dict = {item.annotation.id: vec for item, vec in zip(valid_items_processed,
2428
+ newly_extracted_features)}
2429
+ cached_features.update(new_features_dict)
2430
+
2431
+ if not cached_features:
2432
+ print("No features found or computed. Aborting.")
1923
2433
  return
1924
2434
 
1925
- # 3. Run dimensionality reduction with the latest parameters
1926
- progress_bar.set_busy_mode(f"Running {technique} dimensionality reduction...")
2435
+ final_feature_list = []
2436
+ final_data_items = []
2437
+ for item in items_to_embed:
2438
+ if item.annotation.id in cached_features:
2439
+ final_feature_list.append(cached_features[item.annotation.id])
2440
+ final_data_items.append(item)
2441
+
2442
+ features = np.array(final_feature_list)
2443
+ self.current_data_items = final_data_items
2444
+ self.annotation_viewer.update_annotations(self.current_data_items)
2445
+
2446
+ progress_bar.set_busy_mode("Running dimensionality reduction...")
1927
2447
  embedded_features = self._run_dimensionality_reduction(features, embedding_params)
1928
- progress_bar.update_progress()
1929
-
2448
+
1930
2449
  if embedded_features is None:
1931
2450
  return
1932
2451
 
1933
- # 4. Update the visualization with the new 2D layout
1934
2452
  progress_bar.set_busy_mode("Updating visualization...")
1935
2453
  self._update_data_items_with_embedding(self.current_data_items, embedded_features)
1936
-
1937
2454
  self.embedding_viewer.update_embeddings(self.current_data_items)
1938
2455
  self.embedding_viewer.show_embedding()
1939
2456
  self.embedding_viewer.fit_view_to_points()
1940
- progress_bar.update_progress()
1941
-
1942
- print(f"Successfully generated embedding for {len(self.current_data_items)} annotations using {technique}")
1943
-
1944
- except Exception as e:
1945
- print(f"Error during embedding pipeline: {e}")
1946
- self.embedding_viewer.clear_points()
1947
- self.embedding_viewer.show_placeholder()
1948
-
2457
+
2458
+ # Check if confidence scores are available to enable sorting
2459
+ _, feature_mode = self.current_embedding_model_info
2460
+ is_predict_mode = feature_mode == "Predictions"
2461
+ self.annotation_viewer.set_confidence_sort_availability(is_predict_mode)
2462
+
2463
+ # If using Predictions mode, update data items with probabilities for confidence sorting
2464
+ if is_predict_mode:
2465
+ for item in self.current_data_items:
2466
+ if item.annotation.id in cached_features:
2467
+ item.prediction_probabilities = cached_features[item.annotation.id]
2468
+
2469
+ # When a new embedding is run, any previous similarity sort becomes irrelevant
2470
+ self.annotation_viewer.active_ordered_ids = []
2471
+
1949
2472
  finally:
1950
2473
  QApplication.restoreOverrideCursor()
1951
2474
  progress_bar.finish_progress()
@@ -1956,20 +2479,15 @@ class ExplorerWindow(QMainWindow):
1956
2479
  """Refresh display: filter data and update annotation viewer."""
1957
2480
  QApplication.setOverrideCursor(Qt.WaitCursor)
1958
2481
  try:
1959
- # Get filtered data and store for potential embedding
1960
2482
  self.current_data_items = self.get_filtered_data_items()
1961
-
1962
- # --- MODIFIED: Invalidate the feature cache ---
1963
- # Since the filtered items have changed, the old features are no longer valid.
1964
2483
  self.current_features = None
1965
-
1966
- # Update annotation viewer with filtered data
1967
2484
  self.annotation_viewer.update_annotations(self.current_data_items)
1968
-
1969
- # Clear embedding viewer and show placeholder, as it is now out of sync
1970
2485
  self.embedding_viewer.clear_points()
1971
2486
  self.embedding_viewer.show_placeholder()
1972
-
2487
+
2488
+ # Reset sort options when filters change
2489
+ self.annotation_viewer.active_ordered_ids = []
2490
+ self.annotation_viewer.set_confidence_sort_availability(False)
1973
2491
  finally:
1974
2492
  QApplication.restoreOverrideCursor()
1975
2493
 
@@ -1979,89 +2497,144 @@ class ExplorerWindow(QMainWindow):
1979
2497
  self.annotation_viewer.apply_preview_label_to_selected(label)
1980
2498
  self.update_button_states()
1981
2499
 
2500
+ def delete_data_items(self, data_items_to_delete):
2501
+ """
2502
+ Permanently deletes a list of data items and their associated annotations
2503
+ and visual components from the explorer and the main application.
2504
+ """
2505
+ if not data_items_to_delete:
2506
+ return
2507
+
2508
+ print(f"Permanently deleting {len(data_items_to_delete)} item(s).")
2509
+ QApplication.setOverrideCursor(Qt.WaitCursor)
2510
+ try:
2511
+ deleted_ann_ids = {item.annotation.id for item in data_items_to_delete}
2512
+ annotations_to_delete_from_main_app = [item.annotation for item in data_items_to_delete]
2513
+
2514
+ # 1. Delete from the main application's data store
2515
+ self.annotation_window.delete_annotations(annotations_to_delete_from_main_app)
2516
+
2517
+ # 2. Remove from Explorer's internal data structures
2518
+ self.current_data_items = [
2519
+ item for item in self.current_data_items if item.annotation.id not in deleted_ann_ids
2520
+ ]
2521
+ for ann_id in deleted_ann_ids:
2522
+ if ann_id in self.data_item_cache:
2523
+ del self.data_item_cache[ann_id]
2524
+
2525
+ # 3. Remove from AnnotationViewer
2526
+ blocker = QSignalBlocker(self.annotation_viewer) # Block signals during mass removal
2527
+ for ann_id in deleted_ann_ids:
2528
+ if ann_id in self.annotation_viewer.annotation_widgets_by_id:
2529
+ widget = self.annotation_viewer.annotation_widgets_by_id.pop(ann_id)
2530
+ if widget in self.annotation_viewer.selected_widgets:
2531
+ self.annotation_viewer.selected_widgets.remove(widget)
2532
+ widget.setParent(None)
2533
+ widget.deleteLater()
2534
+ blocker.unblock()
2535
+ self.annotation_viewer.recalculate_widget_positions()
2536
+
2537
+ # 4. Remove from EmbeddingViewer
2538
+ blocker = QSignalBlocker(self.embedding_viewer.graphics_scene)
2539
+ for ann_id in deleted_ann_ids:
2540
+ if ann_id in self.embedding_viewer.points_by_id:
2541
+ point = self.embedding_viewer.points_by_id.pop(ann_id)
2542
+ self.embedding_viewer.graphics_scene.removeItem(point)
2543
+ blocker.unblock()
2544
+ self.embedding_viewer.on_selection_changed() # Trigger update of selection state
2545
+
2546
+ # 5. Update UI
2547
+ self.update_label_window_selection()
2548
+ self.update_button_states()
2549
+
2550
+ # 6. Refresh main window annotations list
2551
+ affected_images = {ann.image_path for ann in annotations_to_delete_from_main_app}
2552
+ for image_path in affected_images:
2553
+ self.image_window.update_image_annotations(image_path)
2554
+ self.annotation_window.load_annotations()
2555
+
2556
+ except Exception as e:
2557
+ print(f"Error during item deletion: {e}")
2558
+ finally:
2559
+ QApplication.restoreOverrideCursor()
2560
+
1982
2561
  def clear_preview_changes(self):
1983
- """Clear all preview changes and revert to original labels."""
2562
+ """
2563
+ Clears all preview changes in the annotation viewer and updates tooltips.
2564
+ """
1984
2565
  if hasattr(self, 'annotation_viewer'):
1985
2566
  self.annotation_viewer.clear_preview_states()
1986
- self.update_button_states()
1987
- print("Cleared all preview changes")
2567
+
2568
+ # After reverting, tooltips need to be updated to reflect original labels
2569
+ for widget in self.annotation_viewer.annotation_widgets_by_id.values():
2570
+ widget.update_tooltip()
2571
+ for point in self.embedding_viewer.points_by_id.values():
2572
+ point.update_tooltip()
2573
+
2574
+ # After reverting all changes, update the button states
2575
+ self.update_button_states()
2576
+ print("Cleared all pending changes.")
1988
2577
 
1989
2578
  def update_button_states(self):
1990
- """Update the state of Clear Preview and Apply buttons."""
1991
- has_changes = (hasattr(self, 'annotation_viewer') and self.annotation_viewer.has_preview_changes())
1992
-
2579
+ """Update the state of Clear Preview, Apply, and Find Similar buttons."""
2580
+ has_changes = self.annotation_viewer.has_preview_changes()
1993
2581
  self.clear_preview_button.setEnabled(has_changes)
1994
2582
  self.apply_button.setEnabled(has_changes)
1995
2583
 
2584
+ # Update tooltips with a summary of changes
1996
2585
  summary = self.annotation_viewer.get_preview_changes_summary()
1997
2586
  self.clear_preview_button.setToolTip(f"Clear all preview changes - {summary}")
1998
2587
  self.apply_button.setToolTip(f"Apply changes - {summary}")
1999
2588
 
2589
+ # Logic for the "Find Similar" button
2590
+ selection_exists = bool(self.annotation_viewer.selected_widgets)
2591
+ embedding_exists = bool(self.embedding_viewer.points_by_id) and self.current_embedding_model_info is not None
2592
+ self.annotation_viewer.find_similar_button.setEnabled(selection_exists and embedding_exists)
2593
+
2000
2594
  def apply(self):
2001
- """Apply any modifications to the actual annotations."""
2002
- # Make cursor busy
2595
+ """
2596
+ Apply all pending label modifications to the main application's data.
2597
+ """
2003
2598
  QApplication.setOverrideCursor(Qt.WaitCursor)
2004
-
2005
2599
  try:
2006
- applied_annotations = self.annotation_viewer.apply_preview_changes_permanently()
2007
-
2008
- if applied_annotations:
2009
- # Find which data items were affected and tell their visual components to update
2010
- changed_ids = {ann.id for ann in applied_annotations}
2011
- for item in self.current_data_items:
2012
- if item.annotation.id in changed_ids:
2013
- # Update annotation widget in the grid
2014
- widget = self.annotation_viewer.annotation_widgets_by_id.get(item.annotation.id)
2015
- if widget:
2016
- widget.update() # Repaint to show new permanent color
2017
-
2018
- # Update point in the embedding viewer
2019
- point = self.embedding_viewer.points_by_id.get(item.annotation.id)
2020
- if point:
2021
- point.update() # Repaint to show new permanent color
2022
-
2023
- # Update sorting if currently sorting by label
2024
- if self.annotation_viewer.sort_combo.currentText() == "Label":
2025
- self.annotation_viewer.recalculate_widget_positions()
2026
-
2027
- # Update the main application's data
2028
- affected_images = {ann.image_path for ann in applied_annotations}
2029
- for image_path in affected_images:
2030
- self.image_window.update_image_annotations(image_path)
2031
- self.annotation_window.load_annotations()
2032
-
2033
- # Clear selection and button states
2034
- self.annotation_viewer.clear_selection()
2035
- self.embedding_viewer.render_selection_from_ids(set()) # Clear embedding selection
2036
- self.update_button_states()
2037
-
2038
- print(f"Applied changes to {len(applied_annotations)} annotation(s)")
2039
- else:
2040
- print("No preview changes to apply")
2600
+ # --- 1. Process Label Changes ---
2601
+ applied_label_changes = []
2602
+ # Iterate over all current data items
2603
+ for item in self.current_data_items:
2604
+ if item.apply_preview_permanently():
2605
+ applied_label_changes.append(item.annotation)
2606
+
2607
+ # --- 2. Update UI if any changes were made ---
2608
+ if not applied_label_changes:
2609
+ print("No pending changes to apply.")
2610
+ return
2611
+
2612
+ # Update the main application's data and UI
2613
+ affected_images = {ann.image_path for ann in applied_label_changes}
2614
+ for image_path in affected_images:
2615
+ self.image_window.update_image_annotations(image_path)
2616
+ self.annotation_window.load_annotations()
2617
+
2618
+ # Refresh the annotation viewer since its underlying data has changed
2619
+ self.annotation_viewer.update_annotations(self.current_data_items)
2620
+
2621
+ # Reset selections and button states
2622
+ self.embedding_viewer.render_selection_from_ids(set())
2623
+ self.update_label_window_selection()
2624
+ self.update_button_states()
2625
+
2626
+ print("Applied changes successfully.")
2041
2627
 
2042
2628
  except Exception as e:
2043
2629
  print(f"Error applying modifications: {e}")
2044
2630
  finally:
2045
- # Restore cursor
2046
2631
  QApplication.restoreOverrideCursor()
2047
2632
 
2048
2633
  def _cleanup_resources(self):
2049
- """
2050
- Clean up heavy resources like the loaded model and clear GPU cache.
2051
- This is called when the window is closed to free up memory.
2052
- """
2053
- print("Cleaning up Explorer resources...")
2054
-
2055
- # Reset model and feature caches
2056
- self.imgsz = 128
2634
+ """Clean up resources."""
2057
2635
  self.loaded_model = None
2058
2636
  self.model_path = ""
2059
2637
  self.current_features = None
2060
2638
  self.current_feature_generating_model = ""
2061
-
2062
- # Clear CUDA cache if available to free up GPU memory
2063
2639
  if torch.cuda.is_available():
2064
- print("Clearing CUDA cache.")
2065
- torch.cuda.empty_cache()
2066
-
2067
- print("Cleanup complete.")
2640
+ torch.cuda.empty_cache()