coralnet-toolbox 0.0.67__py2.py3-none-any.whl → 0.0.68__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,13 +5,11 @@ import warnings
5
5
 
6
6
  from ultralytics import YOLO
7
7
 
8
- from coralnet_toolbox.MachineLearning.Community.cfg import get_available_configs
9
-
10
8
  from coralnet_toolbox.Icons import get_icon
11
9
  from coralnet_toolbox.utilities import pixmap_to_numpy
12
10
 
13
- from PyQt5.QtGui import QIcon, QPen, QColor, QPainter, QImage, QBrush, QPainterPath, QPolygonF, QMouseEvent
14
- from PyQt5.QtCore import Qt, QTimer, QSize, QRect, QRectF, QPointF, pyqtSignal, QSignalBlocker, pyqtSlot
11
+ from PyQt5.QtGui import QIcon, QPen, QColor, QPainter, QBrush, QPainterPath, QMouseEvent
12
+ from PyQt5.QtCore import Qt, QTimer, QRect, QRectF, QPointF, pyqtSignal, QSignalBlocker, pyqtSlot
15
13
 
16
14
  from PyQt5.QtWidgets import (QVBoxLayout, QHBoxLayout, QGraphicsView, QScrollArea,
17
15
  QGraphicsScene, QPushButton, QComboBox, QLabel, QWidget, QGridLayout,
@@ -21,12 +19,12 @@ from PyQt5.QtWidgets import (QVBoxLayout, QHBoxLayout, QGraphicsView, QScrollAre
21
19
  QGraphicsRectItem, QRubberBand, QStyleOptionGraphicsItem,
22
20
  QTabWidget, QLineEdit, QFileDialog)
23
21
 
24
- from .QtAnnotationDataItem import AnnotationDataItem
25
- from .QtEmbeddingPointItem import EmbeddingPointItem
26
- from .QtAnnotationImageWidget import AnnotationImageWidget
27
- from .QtSettingsWidgets import ModelSettingsWidget
28
- from .QtSettingsWidgets import EmbeddingSettingsWidget
29
- from .QtSettingsWidgets import AnnotationSettingsWidget
22
+ from coralnet_toolbox.Explorer.QtDataItem import AnnotationDataItem
23
+ from coralnet_toolbox.Explorer.QtDataItem import EmbeddingPointItem
24
+ from coralnet_toolbox.Explorer.QtDataItem import AnnotationImageWidget
25
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import ModelSettingsWidget
26
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import EmbeddingSettingsWidget
27
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import AnnotationSettingsWidget
30
28
 
31
29
  from coralnet_toolbox.QtProgressBar import ProgressBar
32
30
 
@@ -50,60 +48,46 @@ warnings.filterwarnings("ignore", category=DeprecationWarning)
50
48
  # Constants
51
49
  # ----------------------------------------------------------------------------------------------------------------------
52
50
 
53
-
54
- POINT_SIZE = 15
55
51
  POINT_WIDTH = 3
56
52
 
57
-
58
53
  # ----------------------------------------------------------------------------------------------------------------------
59
54
  # Viewers
60
55
  # ----------------------------------------------------------------------------------------------------------------------
61
56
 
62
57
 
63
- class EmbeddingViewer(QWidget): # Change inheritance to QWidget
58
+ class EmbeddingViewer(QWidget):
64
59
  """Custom QGraphicsView for interactive embedding visualization with zooming, panning, and selection."""
65
-
66
- # Define signal to report selection changes
67
- selection_changed = pyqtSignal(list) # list of all currently selected annotation IDs
68
- reset_view_requested = pyqtSignal() # Signal to reset the view to fit all points
69
-
60
+ selection_changed = pyqtSignal(list)
61
+ reset_view_requested = pyqtSignal()
62
+
70
63
  def __init__(self, parent=None):
71
- # Create the graphics scene first
64
+ """Initialize the EmbeddingViewer widget."""
72
65
  self.graphics_scene = QGraphicsScene()
73
66
  self.graphics_scene.setSceneRect(-5000, -5000, 10000, 10000)
74
-
75
- # Initialize as a QWidget
67
+
76
68
  super(EmbeddingViewer, self).__init__(parent)
77
69
  self.explorer_window = parent
78
-
79
- # Create the actual graphics view
70
+
80
71
  self.graphics_view = QGraphicsView(self.graphics_scene)
81
72
  self.graphics_view.setRenderHint(QPainter.Antialiasing)
82
73
  self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
83
74
  self.graphics_view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
84
75
  self.graphics_view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
85
76
  self.graphics_view.setMinimumHeight(200)
86
-
87
- # Custom rubber_band state variables
77
+
88
78
  self.rubber_band = None
89
79
  self.rubber_band_origin = QPointF()
90
80
  self.selection_at_press = None
91
-
92
- self.points_by_id = {} # Map annotation ID to embedding point
93
- self.previous_selection_ids = set() # Track previous selection to detect changes
94
-
81
+ self.points_by_id = {}
82
+ self.previous_selection_ids = set()
83
+
95
84
  self.animation_offset = 0
96
85
  self.animation_timer = QTimer()
97
86
  self.animation_timer.timeout.connect(self.animate_selection)
98
87
  self.animation_timer.setInterval(100)
99
-
100
- # Connect the scene's selection signal
88
+
101
89
  self.graphics_scene.selectionChanged.connect(self.on_selection_changed)
102
-
103
- # Setup the UI with header
104
90
  self.setup_ui()
105
-
106
- # Connect mouse events to the graphics view
107
91
  self.graphics_view.mousePressEvent = self.mousePressEvent
108
92
  self.graphics_view.mouseDoubleClickEvent = self.mouseDoubleClickEvent
109
93
  self.graphics_view.mouseReleaseEvent = self.mouseReleaseEvent
@@ -115,31 +99,21 @@ class EmbeddingViewer(QWidget): # Change inheritance to QWidget
115
99
  layout = QVBoxLayout(self)
116
100
  layout.setContentsMargins(0, 0, 0, 0)
117
101
 
118
- # Header layout
119
102
  header_layout = QHBoxLayout()
120
-
121
- # Home button
122
103
  self.home_button = QPushButton("Home")
123
104
  self.home_button.setToolTip("Reset view to fit all points")
124
105
  self.home_button.clicked.connect(self.reset_view)
125
106
  header_layout.addWidget(self.home_button)
126
-
127
- # Add stretch to push future controls to the right if needed
128
107
  header_layout.addStretch()
129
-
130
108
  layout.addLayout(header_layout)
131
-
132
- # Add the graphics view
133
109
  layout.addWidget(self.graphics_view)
134
- # Add a placeholder label when no embedding is available
135
110
  self.placeholder_label = QLabel(
136
111
  "No embedding data available.\nPress 'Apply Embedding' to generate visualization."
137
112
  )
138
113
  self.placeholder_label.setAlignment(Qt.AlignCenter)
139
114
  self.placeholder_label.setStyleSheet("color: gray; font-size: 14px;")
140
- layout.addWidget(self.placeholder_label)
141
115
 
142
- # Initially show placeholder
116
+ layout.addWidget(self.placeholder_label)
143
117
  self.show_placeholder()
144
118
 
145
119
  def reset_view(self):
@@ -159,42 +133,88 @@ class EmbeddingViewer(QWidget): # Change inheritance to QWidget
159
133
  self.home_button.setEnabled(True)
160
134
 
161
135
  # Delegate graphics view methods
162
- def setRenderHint(self, hint):
136
+ def setRenderHint(self, hint):
137
+ """Set render hint for the graphics view."""
163
138
  self.graphics_view.setRenderHint(hint)
164
-
165
- def setDragMode(self, mode):
139
+
140
+ def setDragMode(self, mode):
141
+ """Set drag mode for the graphics view."""
166
142
  self.graphics_view.setDragMode(mode)
167
-
168
- def setTransformationAnchor(self, anchor):
143
+
144
+ def setTransformationAnchor(self, anchor):
145
+ """Set transformation anchor for the graphics view."""
169
146
  self.graphics_view.setTransformationAnchor(anchor)
170
-
171
- def setResizeAnchor(self, anchor):
147
+
148
+ def setResizeAnchor(self, anchor):
149
+ """Set resize anchor for the graphics view."""
172
150
  self.graphics_view.setResizeAnchor(anchor)
173
-
174
- def mapToScene(self, point):
151
+
152
+ def mapToScene(self, point):
153
+ """Map a point to the scene coordinates."""
175
154
  return self.graphics_view.mapToScene(point)
176
155
 
177
- def scale(self, sx, sy):
156
+ def scale(self, sx, sy):
157
+ """Scale the graphics view."""
178
158
  self.graphics_view.scale(sx, sy)
179
-
180
- def translate(self, dx, dy):
159
+
160
+ def translate(self, dx, dy):
161
+ """Translate the graphics view."""
181
162
  self.graphics_view.translate(dx, dy)
182
-
183
- def fitInView(self, rect, aspect_ratio):
163
+
164
+ def fitInView(self, rect, aspect_ratio):
165
+ """Fit the view to a rectangle with aspect ratio."""
184
166
  self.graphics_view.fitInView(rect, aspect_ratio)
167
+
168
+ def keyPressEvent(self, event):
169
+ """Handles key presses for deleting selected points."""
170
+ # Check if the pressed key is Delete/Backspace AND the Control key is held down
171
+ if event.key() in (Qt.Key_Delete, Qt.Key_Backspace) and event.modifiers() == Qt.ControlModifier:
172
+ # Get the currently selected items from the graphics scene
173
+ selected_items = self.graphics_scene.selectedItems()
174
+
175
+ if not selected_items:
176
+ super().keyPressEvent(event)
177
+ return
178
+
179
+ print(f"Marking {len(selected_items)} points for deletion.")
180
+
181
+ # Mark each item for deletion and remove it from the scene
182
+ for item in selected_items:
183
+ if isinstance(item, EmbeddingPointItem):
184
+ # Mark the central data item for deletion
185
+ item.data_item.mark_for_deletion()
186
+
187
+ # Remove the point from our internal lookup
188
+ ann_id = item.data_item.annotation.id
189
+ if ann_id in self.points_by_id:
190
+ del self.points_by_id[ann_id]
191
+
192
+ # Remove the point from the visual scene
193
+ self.graphics_scene.removeItem(item)
194
+
195
+ # Trigger a selection change to clear the selection state
196
+ # and notify the ExplorerWindow.
197
+ self.on_selection_changed()
198
+
199
+ # Accept the event to prevent it from being processed further
200
+ event.accept()
201
+ else:
202
+ # Pass any other key presses to the default handler
203
+ super().keyPressEvent(event)
185
204
 
186
205
  def mousePressEvent(self, event):
187
206
  """Handle mouse press for selection (point or rubber band) and panning."""
188
207
  if event.button() == Qt.LeftButton and event.modifiers() == Qt.ControlModifier:
189
- # Check if the click is on an existing point
190
208
  item_at_pos = self.graphics_view.itemAt(event.pos())
191
209
  if isinstance(item_at_pos, EmbeddingPointItem):
192
- # If so, toggle its selection state and do nothing else
193
210
  self.graphics_view.setDragMode(QGraphicsView.NoDrag)
194
- item_at_pos.setSelected(not item_at_pos.isSelected())
195
- return # Event handled
211
+ # The viewer (controller) directly changes the state on the data item.
212
+ is_currently_selected = item_at_pos.data_item.is_selected
213
+ item_at_pos.data_item.set_selected(not is_currently_selected)
214
+ item_at_pos.setSelected(not is_currently_selected) # Keep scene selection in sync
215
+ self.on_selection_changed() # Manually trigger update
216
+ return
196
217
 
197
- # If the click was on the background, proceed with rubber band selection
198
218
  self.selection_at_press = set(self.graphics_scene.selectedItems())
199
219
  self.graphics_view.setDragMode(QGraphicsView.NoDrag)
200
220
  self.rubber_band_origin = self.graphics_view.mapToScene(event.pos())
@@ -204,87 +224,61 @@ class EmbeddingViewer(QWidget): # Change inheritance to QWidget
204
224
  self.graphics_scene.addItem(self.rubber_band)
205
225
 
206
226
  elif event.button() == Qt.RightButton:
207
- # Handle panning
208
227
  self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
209
228
  left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
210
229
  QGraphicsView.mousePressEvent(self.graphics_view, left_event)
211
230
  else:
212
- # Handle standard single-item selection
213
231
  self.graphics_view.setDragMode(QGraphicsView.NoDrag)
214
232
  QGraphicsView.mousePressEvent(self.graphics_view, event)
215
233
 
216
234
  def mouseDoubleClickEvent(self, event):
217
235
  """Handle double-click to clear selection and reset the main view."""
218
236
  if event.button() == Qt.LeftButton:
219
- # Clear selection if any items are selected
220
237
  if self.graphics_scene.selectedItems():
221
- self.graphics_scene.clearSelection() # This triggers on_selection_changed
222
-
223
- # Signal the main window to revert from isolation mode
238
+ self.graphics_scene.clearSelection()
224
239
  self.reset_view_requested.emit()
225
240
  event.accept()
226
241
  else:
227
- # Pass other double-clicks to the base class
228
242
  super().mouseDoubleClickEvent(event)
229
243
 
230
244
  def mouseMoveEvent(self, event):
231
245
  """Handle mouse move for dynamic selection and panning."""
232
246
  if self.rubber_band:
233
- # Update the rubber band geometry
247
+ # Update the rubber band rectangle as the mouse moves
234
248
  current_pos = self.graphics_view.mapToScene(event.pos())
235
249
  self.rubber_band.setRect(QRectF(self.rubber_band_origin, current_pos).normalized())
236
-
250
+ # Create a selection path from the rubber band rectangle
237
251
  path = QPainterPath()
238
252
  path.addRect(self.rubber_band.rect())
239
-
240
- # Block signals to perform a compound selection operation
253
+ # Block signals to avoid recursive selectionChanged events
241
254
  self.graphics_scene.blockSignals(True)
242
-
243
- # 1. Perform the "fancy" dynamic selection, which replaces the current selection
244
- # with only the items inside the rubber band.
245
255
  self.graphics_scene.setSelectionArea(path)
246
-
247
- # 2. Add back the items that were selected at the start of the drag.
256
+ # Restore selection for items that were already selected at press
248
257
  if self.selection_at_press:
249
258
  for item in self.selection_at_press:
250
259
  item.setSelected(True)
251
-
252
- # Unblock signals and manually trigger our handler to process the final result.
253
260
  self.graphics_scene.blockSignals(False)
261
+ # Manually trigger selection changed logic
254
262
  self.on_selection_changed()
255
-
256
263
  elif event.buttons() == Qt.RightButton:
257
- # Handle right-click panning
258
- left_event = QMouseEvent(event.type(),
259
- event.localPos(),
260
- Qt.LeftButton,
261
- Qt.LeftButton,
262
- event.modifiers())
264
+ # Forward right-drag as left-drag for panning
265
+ left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
263
266
  QGraphicsView.mouseMoveEvent(self.graphics_view, left_event)
264
267
  else:
268
+ # Default mouse move handling
265
269
  QGraphicsView.mouseMoveEvent(self.graphics_view, event)
266
270
 
267
271
  def mouseReleaseEvent(self, event):
268
272
  """Handle mouse release to finalize the action and clean up."""
269
273
  if self.rubber_band:
270
- # Clean up the visual rectangle
271
274
  self.graphics_scene.removeItem(self.rubber_band)
272
275
  self.rubber_band = None
273
-
274
- # Clean up the stored selection state.
275
276
  self.selection_at_press = None
276
-
277
277
  elif event.button() == Qt.RightButton:
278
- # Finalize the pan
279
- left_event = QMouseEvent(event.type(),
280
- event.localPos(),
281
- Qt.LeftButton,
282
- Qt.LeftButton,
283
- event.modifiers())
278
+ left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
284
279
  QGraphicsView.mouseReleaseEvent(self.graphics_view, left_event)
285
280
  self.graphics_view.setDragMode(QGraphicsView.NoDrag)
286
281
  else:
287
- # Finalize a single click
288
282
  QGraphicsView.mouseReleaseEvent(self.graphics_view, event)
289
283
  self.graphics_view.setDragMode(QGraphicsView.NoDrag)
290
284
 
@@ -293,41 +287,60 @@ class EmbeddingViewer(QWidget): # Change inheritance to QWidget
293
287
  zoom_in_factor = 1.25
294
288
  zoom_out_factor = 1 / zoom_in_factor
295
289
 
290
+ # Set anchor points so zoom occurs at mouse position
296
291
  self.graphics_view.setTransformationAnchor(QGraphicsView.NoAnchor)
297
292
  self.graphics_view.setResizeAnchor(QGraphicsView.NoAnchor)
298
293
 
294
+ # Get the scene position before zooming
299
295
  old_pos = self.graphics_view.mapToScene(event.pos())
296
+
297
+ # Determine zoom direction
300
298
  zoom_factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
299
+
300
+ # Apply zoom
301
301
  self.graphics_view.scale(zoom_factor, zoom_factor)
302
+
303
+ # Get the scene position after zooming
302
304
  new_pos = self.graphics_view.mapToScene(event.pos())
303
-
305
+
306
+ # Translate view to keep mouse position stable
304
307
  delta = new_pos - old_pos
305
308
  self.graphics_view.translate(delta.x(), delta.y())
306
309
 
307
310
  def update_embeddings(self, data_items):
308
- """Update the embedding visualization with new data.
309
-
310
- Args:
311
- data_items: List of AnnotationDataItem objects.
312
- """
311
+ """Update the embedding visualization. Creates an EmbeddingPointItem for
312
+ each AnnotationDataItem and links them."""
313
313
  self.clear_points()
314
-
315
314
  for item in data_items:
316
- point = EmbeddingPointItem(0, 0, POINT_SIZE, POINT_SIZE)
317
- point.setPos(item.embedding_x, item.embedding_y)
318
-
319
- # No need to set initial brush - paint() will handle it
320
- point.setPen(QPen(QColor("black"), POINT_WIDTH))
321
-
322
- point.setFlag(QGraphicsItem.ItemIgnoresTransformations)
323
- point.setFlag(QGraphicsItem.ItemIsSelectable)
324
-
325
- # This is the crucial link: store the shared AnnotationDataItem
326
- point.setData(0, item)
327
-
315
+ # Create the point item directly from the data_item.
316
+ # The item's constructor now handles setting position, flags, etc.
317
+ point = EmbeddingPointItem(item)
328
318
  self.graphics_scene.addItem(point)
329
319
  self.points_by_id[item.annotation.id] = point
330
320
 
321
+ def refresh_points(self):
322
+ """Refreshes the points in the view to match the current state of the master data list."""
323
+ if not self.explorer_window or not self.explorer_window.current_data_items:
324
+ return
325
+
326
+ # Get the set of IDs for points currently in the scene
327
+ current_point_ids = set(self.points_by_id.keys())
328
+
329
+ # Get the master list of data items from the parent window
330
+ all_data_items = self.explorer_window.current_data_items
331
+
332
+ something_changed = False
333
+ for item in all_data_items:
334
+ # If a data item is NOT marked for deletion but is also NOT in the scene, add it back.
335
+ if not item.is_marked_for_deletion() and item.annotation.id not in current_point_ids:
336
+ point = EmbeddingPointItem(item)
337
+ self.graphics_scene.addItem(point)
338
+ self.points_by_id[item.annotation.id] = point
339
+ something_changed = True
340
+
341
+ if something_changed:
342
+ print("Refreshed embedding points to show reverted items.")
343
+
331
344
  def clear_points(self):
332
345
  """Clear all embedding points from the scene."""
333
346
  for point in self.points_by_id.values():
@@ -335,79 +348,76 @@ class EmbeddingViewer(QWidget): # Change inheritance to QWidget
335
348
  self.points_by_id.clear()
336
349
 
337
350
  def on_selection_changed(self):
338
- """Handle point selection changes and emit a signal to the controller."""
339
- # Check if graphics_scene still exists and is valid
340
- if not self.graphics_scene or not hasattr(self.graphics_scene, 'selectedItems'):
351
+ """
352
+ Handles selection changes in the scene. Updates the central data model
353
+ and emits a signal to notify other parts of the application.
354
+ """
355
+ if not self.graphics_scene:
341
356
  return
342
-
343
357
  try:
344
358
  selected_items = self.graphics_scene.selectedItems()
345
359
  except RuntimeError:
346
- # Scene has been deleted
347
360
  return
348
361
 
349
- current_selection_ids = {item.data(0).annotation.id for item in selected_items}
362
+ current_selection_ids = {item.data_item.annotation.id for item in selected_items}
350
363
 
351
- # If the selection has actually changed, update the model and emit
352
364
  if current_selection_ids != self.previous_selection_ids:
353
- # Update the central model (the AnnotationDataItem) for all points
365
+ # Update the central model (AnnotationDataItem) for all points
354
366
  for point_id, point in self.points_by_id.items():
355
367
  is_selected = point_id in current_selection_ids
356
- point.data(0).set_selected(is_selected)
368
+ # The data_item is the single source of truth
369
+ point.data_item.set_selected(is_selected)
357
370
 
358
- # Emit the complete list of currently selected IDs
359
371
  self.selection_changed.emit(list(current_selection_ids))
360
372
  self.previous_selection_ids = current_selection_ids
361
373
 
362
- # Handle local animation - check if animation_timer still exists
363
- if hasattr(self, 'animation_timer') and self.animation_timer:
374
+ # Handle animation
375
+ if hasattr(self, 'animation_timer') and self.animation_timer:
364
376
  self.animation_timer.stop()
365
377
 
366
378
  for point in self.points_by_id.values():
367
379
  if not point.isSelected():
368
380
  point.setPen(QPen(QColor("black"), POINT_WIDTH))
369
-
370
381
  if selected_items and hasattr(self, 'animation_timer') and self.animation_timer:
371
382
  self.animation_timer.start()
372
383
 
373
384
  def animate_selection(self):
374
- """Animate selected points with marching ants effect using darker versions of point colors."""
375
- # Check if graphics_scene still exists and is valid
376
- if not self.graphics_scene or not hasattr(self.graphics_scene, 'selectedItems'):
385
+ """Animate selected points with a marching ants effect."""
386
+ if not self.graphics_scene:
377
387
  return
378
-
379
388
  try:
380
389
  selected_items = self.graphics_scene.selectedItems()
381
390
  except RuntimeError:
382
- # Scene has been deleted
383
391
  return
384
392
 
385
393
  self.animation_offset = (self.animation_offset + 1) % 20
386
-
387
- # This logic remains the same. It applies the custom pen to the selected items.
388
- # Because the items are EmbeddingPointItem, the default selection box won't be drawn.
389
394
  for item in selected_items:
390
- original_color = item.brush().color()
395
+ # Get the color directly from the source of truth
396
+ original_color = item.data_item.effective_color
391
397
  darker_color = original_color.darker(150)
392
-
393
398
  animated_pen = QPen(darker_color, POINT_WIDTH)
394
399
  animated_pen.setStyle(Qt.CustomDashLine)
395
400
  animated_pen.setDashPattern([1, 2])
396
401
  animated_pen.setDashOffset(self.animation_offset)
397
-
398
402
  item.setPen(animated_pen)
399
403
 
400
404
  def render_selection_from_ids(self, selected_ids):
401
- """Update the visual selection of points based on a set of IDs from the controller."""
402
- # Block this scene's own selectionChanged signal to prevent an infinite loop
405
+ """
406
+ Updates the visual selection of points based on a set of annotation IDs
407
+ provided by an external controller.
408
+ """
403
409
  blocker = QSignalBlocker(self.graphics_scene)
404
410
 
405
411
  for ann_id, point in self.points_by_id.items():
406
- point.setSelected(ann_id in selected_ids)
412
+ is_selected = ann_id in selected_ids
413
+ # 1. Update the state on the central data item
414
+ point.data_item.set_selected(is_selected)
415
+ # 2. Update the selection state of the graphics item itself
416
+ point.setSelected(is_selected)
407
417
 
408
- self.previous_selection_ids = selected_ids
418
+ blocker.unblock()
409
419
 
410
- # Trigger animation update
420
+ # Manually trigger on_selection_changed to update animation and emit signals
411
421
  self.on_selection_changed()
412
422
 
413
423
  def fit_view_to_points(self):
@@ -415,40 +425,34 @@ class EmbeddingViewer(QWidget): # Change inheritance to QWidget
415
425
  if self.points_by_id:
416
426
  self.graphics_view.fitInView(self.graphics_scene.itemsBoundingRect(), Qt.KeepAspectRatio)
417
427
  else:
418
- # If no points, reset to default view
419
428
  self.graphics_view.fitInView(-2500, -2500, 5000, 5000, Qt.KeepAspectRatio)
420
429
 
421
430
 
422
431
  class AnnotationViewer(QScrollArea):
423
- """Scrollable grid widget for displaying annotation image crops with selection,
424
- filtering, and isolation support.
425
- """
426
-
427
- # Define signals to report changes to the ExplorerWindow
428
- selection_changed = pyqtSignal(list) # list of changed annotation IDs
429
- preview_changed = pyqtSignal(list) # list of annotation IDs with new previews
430
- reset_view_requested = pyqtSignal() # Signal to reset the view to fit all points
432
+ """Scrollable grid widget for displaying annotation image crops with selection,
433
+ filtering, and isolation support. Acts as a controller for the widgets."""
434
+ selection_changed = pyqtSignal(list)
435
+ preview_changed = pyqtSignal(list)
436
+ reset_view_requested = pyqtSignal()
431
437
 
432
438
  def __init__(self, parent=None):
439
+ """Initialize the AnnotationViewer widget."""
433
440
  super(AnnotationViewer, self).__init__(parent)
441
+ self.explorer_window = parent
442
+
434
443
  self.annotation_widgets_by_id = {}
435
444
  self.selected_widgets = []
436
445
  self.last_selected_index = -1
437
446
  self.current_widget_size = 96
438
-
439
447
  self.selection_at_press = set()
440
448
  self.rubber_band = None
441
449
  self.rubber_band_origin = None
442
450
  self.drag_threshold = 5
443
451
  self.mouse_pressed_on_widget = False
444
-
445
452
  self.preview_label_assignments = {}
446
453
  self.original_label_assignments = {}
447
-
448
- # New state variables for Isolate/Focus mode
449
454
  self.isolated_mode = False
450
455
  self.isolated_widgets = set()
451
-
452
456
  self.setup_ui()
453
457
 
454
458
  def setup_ui(self):
@@ -457,49 +461,35 @@ class AnnotationViewer(QScrollArea):
457
461
  self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
458
462
  self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
459
463
 
460
- # Main container and layout
461
464
  main_container = QWidget()
462
465
  main_layout = QVBoxLayout(main_container)
463
466
  main_layout.setContentsMargins(0, 0, 0, 0)
464
- main_layout.setSpacing(4) # Add a little space between toolbar and content
467
+ main_layout.setSpacing(4)
465
468
 
466
- # --- New Toolbar ---
467
469
  toolbar_widget = QWidget()
468
470
  toolbar_layout = QHBoxLayout(toolbar_widget)
469
471
  toolbar_layout.setContentsMargins(4, 2, 4, 2)
470
472
 
471
- # Isolate/Focus controls
472
473
  self.isolate_button = QPushButton("Isolate Selection")
473
- isolate_icon = get_icon("focus.png")
474
- if not isolate_icon.isNull():
475
- self.isolate_button.setIcon(isolate_icon)
476
474
  self.isolate_button.setToolTip("Hide all non-selected annotations")
477
475
  self.isolate_button.clicked.connect(self.isolate_selection)
478
476
  toolbar_layout.addWidget(self.isolate_button)
479
477
 
480
478
  self.show_all_button = QPushButton("Show All")
481
- show_all_icon = get_icon("show_all.png")
482
- if not show_all_icon.isNull():
483
- self.show_all_button.setIcon(show_all_icon)
484
479
  self.show_all_button.setToolTip("Show all filtered annotations")
485
480
  self.show_all_button.clicked.connect(self.show_all_annotations)
486
481
  toolbar_layout.addWidget(self.show_all_button)
487
482
 
488
- # Add a separator
489
483
  toolbar_layout.addWidget(self._create_separator())
490
484
 
491
- # Sort controls
492
485
  sort_label = QLabel("Sort By:")
493
486
  toolbar_layout.addWidget(sort_label)
494
487
  self.sort_combo = QComboBox()
495
488
  self.sort_combo.addItems(["None", "Label", "Image"])
496
489
  self.sort_combo.currentTextChanged.connect(self.on_sort_changed)
497
490
  toolbar_layout.addWidget(self.sort_combo)
498
-
499
- # Add a spacer to push the size controls to the right
500
491
  toolbar_layout.addStretch()
501
492
 
502
- # Size controls
503
493
  size_label = QLabel("Size:")
504
494
  toolbar_layout.addWidget(size_label)
505
495
  self.size_slider = QSlider(Qt.Horizontal)
@@ -514,23 +504,19 @@ class AnnotationViewer(QScrollArea):
514
504
  self.size_value_label = QLabel("96")
515
505
  self.size_value_label.setMinimumWidth(30)
516
506
  toolbar_layout.addWidget(self.size_value_label)
517
-
518
507
  main_layout.addWidget(toolbar_widget)
519
-
520
- # --- Content Area ---
508
+
521
509
  self.content_widget = QWidget()
522
510
  content_scroll = QScrollArea()
523
511
  content_scroll.setWidgetResizable(True)
524
512
  content_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
525
513
  content_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
526
514
  content_scroll.setWidget(self.content_widget)
527
-
515
+
528
516
  main_layout.addWidget(content_scroll)
529
517
  self.setWidget(main_container)
530
-
531
- # Set the initial state of the toolbar buttons
532
518
  self._update_toolbar_state()
533
-
519
+
534
520
  @pyqtSlot()
535
521
  def isolate_selection(self):
536
522
  """Hides all annotation widgets that are not currently selected."""
@@ -549,6 +535,7 @@ class AnnotationViewer(QScrollArea):
549
535
  self.content_widget.setUpdatesEnabled(True)
550
536
 
551
537
  self._update_toolbar_state()
538
+ self.explorer_window.main_window.label_window.update_annotation_count()
552
539
 
553
540
  @pyqtSlot()
554
541
  def show_all_annotations(self):
@@ -568,13 +555,11 @@ class AnnotationViewer(QScrollArea):
568
555
  self.content_widget.setUpdatesEnabled(True)
569
556
 
570
557
  self._update_toolbar_state()
571
-
558
+ self.explorer_window.main_window.label_window.update_annotation_count()
559
+
572
560
  def _update_toolbar_state(self):
573
- """Updates the visibility and enabled state of the toolbar buttons
574
- based on the current selection and isolation mode.
575
- """
561
+ """Updates the toolbar buttons based on selection and isolation mode."""
576
562
  selection_exists = bool(self.selected_widgets)
577
-
578
563
  if self.isolated_mode:
579
564
  self.isolate_button.hide()
580
565
  self.show_all_button.show()
@@ -583,9 +568,9 @@ class AnnotationViewer(QScrollArea):
583
568
  self.isolate_button.show()
584
569
  self.show_all_button.hide()
585
570
  self.isolate_button.setEnabled(selection_exists)
586
-
571
+
587
572
  def _create_separator(self):
588
- """Create a vertical separator line for the toolbar."""
573
+ """Creates a vertical separator for the toolbar."""
589
574
  separator = QLabel("|")
590
575
  separator.setStyleSheet("color: gray; margin: 0 5px;")
591
576
  return separator
@@ -597,30 +582,21 @@ class AnnotationViewer(QScrollArea):
597
582
  def _get_sorted_widgets(self):
598
583
  """Get widgets sorted according to the current sort setting."""
599
584
  sort_type = self.sort_combo.currentText()
600
-
601
- if sort_type == "None":
602
- return list(self.annotation_widgets_by_id.values())
603
-
604
585
  widgets = list(self.annotation_widgets_by_id.values())
605
-
606
586
  if sort_type == "Label":
607
587
  widgets.sort(key=lambda w: w.data_item.effective_label.short_label_code)
608
588
  elif sort_type == "Image":
609
589
  widgets.sort(key=lambda w: os.path.basename(w.data_item.annotation.image_path))
610
-
611
590
  return widgets
612
591
 
613
592
  def _group_widgets_by_sort_key(self, widgets):
614
- """Group widgets by the current sort key and return groups with headers."""
593
+ """Group widgets by the current sort key."""
615
594
  sort_type = self.sort_combo.currentText()
616
-
617
595
  if sort_type == "None":
618
596
  return [("", widgets)]
619
-
620
597
  groups = []
621
598
  current_group = []
622
599
  current_key = None
623
-
624
600
  for widget in widgets:
625
601
  if sort_type == "Label":
626
602
  key = widget.data_item.effective_label.short_label_code
@@ -628,7 +604,6 @@ class AnnotationViewer(QScrollArea):
628
604
  key = os.path.basename(widget.data_item.annotation.image_path)
629
605
  else:
630
606
  key = ""
631
-
632
607
  if current_key != key:
633
608
  if current_group:
634
609
  groups.append((current_key, current_group))
@@ -636,12 +611,10 @@ class AnnotationViewer(QScrollArea):
636
611
  current_key = key
637
612
  else:
638
613
  current_group.append(widget)
639
-
640
614
  if current_group:
641
615
  groups.append((current_key, current_group))
642
-
643
616
  return groups
644
-
617
+
645
618
  def _clear_separator_labels(self):
646
619
  """Remove any existing group header labels."""
647
620
  if hasattr(self, '_group_headers'):
@@ -654,130 +627,90 @@ class AnnotationViewer(QScrollArea):
654
627
  """Create a group header label."""
655
628
  if not hasattr(self, '_group_headers'):
656
629
  self._group_headers = []
657
-
658
- header = QLabel(text)
659
- header.setParent(self.content_widget)
660
- header.setStyleSheet("""
661
- QLabel {
662
- font-weight: bold;
663
- font-size: 12px;
664
- color: #555;
665
- background-color: #f0f0f0;
666
- border: 1px solid #ccc;
667
- border-radius: 3px;
668
- padding: 5px 8px;
669
- margin: 2px 0px;
670
- }
671
- """)
672
- header.setFixedHeight(30) # Increased from 25 to 30
630
+ header = QLabel(text, self.content_widget)
631
+ header.setStyleSheet(
632
+ "QLabel {"
633
+ " font-weight: bold;"
634
+ " font-size: 12px;"
635
+ " color: #555;"
636
+ " background-color: #f0f0f0;"
637
+ " border: 1px solid #ccc;"
638
+ " border-radius: 3px;"
639
+ " padding: 5px 8px;"
640
+ " margin: 2px 0px;"
641
+ " }"
642
+ )
643
+ header.setFixedHeight(30)
673
644
  header.setMinimumWidth(self.viewport().width() - 20)
674
645
  header.show()
675
-
676
646
  self._group_headers.append(header)
677
647
  return header
678
648
 
679
649
  def on_size_changed(self, value):
680
650
  """Handle slider value change to resize annotation widgets."""
681
- if value % 2 != 0:
651
+ if value % 2 != 0:
682
652
  value -= 1
653
+
683
654
  self.current_widget_size = value
684
655
  self.size_value_label.setText(str(value))
685
-
686
- # Disable updates for performance while resizing many items
687
656
  self.content_widget.setUpdatesEnabled(False)
657
+
688
658
  for widget in self.annotation_widgets_by_id.values():
689
- widget.update_height(value) # Call the new, more descriptive method
659
+ widget.update_height(value)
660
+
690
661
  self.content_widget.setUpdatesEnabled(True)
691
-
692
- # After resizing, reflow the layout
693
662
  self.recalculate_widget_positions()
694
663
 
695
- def recalculate_grid_layout(self):
696
- """Recalculate the grid layout based on current widget width."""
697
- if not self.annotation_widgets_by_id:
698
- return
699
-
700
- available_width = self.viewport().width() - 20
701
- widget_width = self.current_widget_size + self.grid_layout.spacing()
702
- cols = max(1, available_width // widget_width)
703
-
704
- for i, widget in enumerate(self.annotation_widgets_by_id.values()):
705
- self.grid_layout.addWidget(widget, i // cols, i % cols)
706
-
707
664
  def recalculate_widget_positions(self):
708
665
  """Manually positions widgets in a flow layout with sorting and group headers."""
709
666
  if not self.annotation_widgets_by_id:
710
667
  self.content_widget.setMinimumSize(1, 1)
711
668
  return
712
669
 
713
- # Clear any existing separator labels
714
670
  self._clear_separator_labels()
715
-
716
- # Get sorted widgets
717
- all_widgets = self._get_sorted_widgets()
718
-
719
- # Filter to only visible widgets
720
- visible_widgets = [w for w in all_widgets if not w.isHidden()]
721
-
671
+ visible_widgets = [w for w in self._get_sorted_widgets() if not w.isHidden()]
722
672
  if not visible_widgets:
723
673
  self.content_widget.setMinimumSize(1, 1)
724
674
  return
725
675
 
726
- # Group widgets by sort key
676
+ # Create groups based on the current sort key
727
677
  groups = self._group_widgets_by_sort_key(visible_widgets)
728
-
729
- # Calculate spacing
730
678
  spacing = max(5, int(self.current_widget_size * 0.08))
731
679
  available_width = self.viewport().width()
732
-
733
680
  x, y = spacing, spacing
734
681
  max_height_in_row = 0
735
682
 
683
+ # Calculate the maximum height of the widgets in each row
736
684
  for group_name, group_widgets in groups:
737
- # Add group header if sorting is enabled and group has a name
738
685
  if group_name and self.sort_combo.currentText() != "None":
739
- # Ensure we're at the start of a new line for headers
740
686
  if x > spacing:
741
687
  x = spacing
742
688
  y += max_height_in_row + spacing
743
689
  max_height_in_row = 0
744
-
745
- # Create and position header label
746
690
  header_label = self._create_group_header(group_name)
747
691
  header_label.move(x, y)
748
-
749
- # Move to next line after header
750
692
  y += header_label.height() + spacing
751
693
  x = spacing
752
694
  max_height_in_row = 0
753
-
754
- # Position widgets in this group
695
+
755
696
  for widget in group_widgets:
756
697
  widget_size = widget.size()
757
-
758
- # Check if widget fits on current line
759
698
  if x > spacing and x + widget_size.width() > available_width:
760
699
  x = spacing
761
700
  y += max_height_in_row + spacing
762
701
  max_height_in_row = 0
763
-
764
702
  widget.move(x, y)
765
703
  x += widget_size.width() + spacing
766
704
  max_height_in_row = max(max_height_in_row, widget_size.height())
767
705
 
768
- # Update content widget size
769
706
  total_height = y + max_height_in_row + spacing
770
707
  self.content_widget.setMinimumSize(available_width, total_height)
771
-
708
+
772
709
  def update_annotations(self, data_items):
773
- """Update displayed annotations, creating new widgets. This will also
774
- reset any active isolation view.
775
- """
776
- # Reset isolation state before updating to avoid confusion
710
+ """Update displayed annotations, creating new widgets for them."""
777
711
  if self.isolated_mode:
778
712
  self.show_all_annotations()
779
713
 
780
- # Clear any existing widgets and ensure they are deleted
781
714
  for widget in self.annotation_widgets_by_id.values():
782
715
  widget.setParent(None)
783
716
  widget.deleteLater()
@@ -785,193 +718,202 @@ class AnnotationViewer(QScrollArea):
785
718
  self.annotation_widgets_by_id.clear()
786
719
  self.selected_widgets.clear()
787
720
  self.last_selected_index = -1
788
-
789
- # Create new widgets, parenting them to the content_widget
721
+
790
722
  for data_item in data_items:
791
723
  annotation_widget = AnnotationImageWidget(
792
- data_item,
793
- self.current_widget_size,
794
- annotation_viewer=self,
795
- parent=self.content_widget
796
- )
797
- annotation_widget.show()
724
+ data_item, self.current_widget_size, self, self.content_widget)
725
+
726
+ annotation_widget.show()
798
727
  self.annotation_widgets_by_id[data_item.annotation.id] = annotation_widget
799
-
728
+
800
729
  self.recalculate_widget_positions()
801
- # Ensure toolbar is in the correct state after a refresh
802
730
  self._update_toolbar_state()
803
731
 
804
732
  def resizeEvent(self, event):
805
733
  """On window resize, reflow the annotation widgets."""
806
734
  super().resizeEvent(event)
807
- # Use a QTimer to avoid rapid, expensive reflows while dragging the resize handle
808
735
  if not hasattr(self, '_resize_timer'):
809
736
  self._resize_timer = QTimer(self)
810
737
  self._resize_timer.setSingleShot(True)
811
738
  self._resize_timer.timeout.connect(self.recalculate_widget_positions)
812
- # Restart the timer on each resize event
813
- self._resize_timer.start(100) # 100ms delay
739
+ self._resize_timer.start(100)
740
+
741
+ def keyPressEvent(self, event):
742
+ """Handles key presses for deleting selected annotations."""
743
+ # Check if the pressed key is Delete/Backspace AND the Control key is held down
744
+ if event.key() in (Qt.Key_Delete, Qt.Key_Backspace) and event.modifiers() == Qt.ControlModifier:
745
+ # Proceed only if there are selected widgets
746
+ if not self.selected_widgets:
747
+ super().keyPressEvent(event)
748
+ return
749
+
750
+ print(f"Marking {len(self.selected_widgets)} annotations for deletion.")
751
+
752
+ # Keep track of which annotations were affected
753
+ changed_ids = []
754
+
755
+ # Mark each selected item for deletion and hide it
756
+ for widget in self.selected_widgets:
757
+ widget.data_item.mark_for_deletion()
758
+ widget.hide()
759
+ changed_ids.append(widget.data_item.annotation.id)
760
+
761
+ # Clear the list of selected widgets
762
+ self.selected_widgets.clear()
763
+
764
+ # Recalculate the layout to fill in the empty space
765
+ self.recalculate_widget_positions()
766
+
767
+ # Emit a signal to notify the ExplorerWindow that the selection is now empty
768
+ # This will also clear the selection in the EmbeddingViewer
769
+ if changed_ids:
770
+ self.selection_changed.emit(changed_ids)
771
+
772
+ # Accept the event to prevent it from being processed further
773
+ event.accept()
774
+ else:
775
+ # Pass any other key presses to the default handler
776
+ super().keyPressEvent(event)
814
777
 
815
778
  def mousePressEvent(self, event):
816
779
  """Handle mouse press for starting rubber band selection OR clearing selection."""
817
-
818
- # Handle plain left-clicks
819
780
  if event.button() == Qt.LeftButton:
820
-
821
- # This is the new logic for clearing selection on a background click.
822
- if not event.modifiers(): # Check for NO modifiers (e.g., Ctrl, Shift)
823
-
781
+ if not event.modifiers():
782
+ # If left click with no modifiers, check if click is outside widgets
824
783
  is_on_widget = False
825
784
  child_at_pos = self.childAt(event.pos())
826
-
827
- # Determine if the click was on an actual annotation widget or empty space
785
+
828
786
  if child_at_pos:
829
787
  widget = child_at_pos
788
+ # Traverse up the parent chain to see if click is on an annotation widget
830
789
  while widget and widget != self:
831
790
  if hasattr(widget, 'annotation_viewer') and widget.annotation_viewer == self:
832
791
  is_on_widget = True
833
792
  break
834
793
  widget = widget.parent()
835
-
836
- # If click was on empty space AND something is currently selected...
794
+
795
+ # If click is outside widgets and there is a selection, clear it
837
796
  if not is_on_widget and self.selected_widgets:
838
- # Get IDs of widgets that are about to be deselected to emit a signal
839
797
  changed_ids = [w.data_item.annotation.id for w in self.selected_widgets]
840
798
  self.clear_selection()
841
799
  self.selection_changed.emit(changed_ids)
842
- # The event is handled, but we don't call super() to prevent
843
- # the scroll area from doing anything else, like starting a drag.
844
800
  return
845
-
846
- # Handle Ctrl+Click for rubber band
801
+
847
802
  elif event.modifiers() == Qt.ControlModifier:
848
- # Store the set of currently selected items.
803
+ # Start rubber band selection with Ctrl+Left click
849
804
  self.selection_at_press = set(self.selected_widgets)
850
805
  self.rubber_band_origin = event.pos()
851
- # We determine mouse_pressed_on_widget here but use it in mouseMove
852
806
  self.mouse_pressed_on_widget = False
853
807
  child_widget = self.childAt(event.pos())
854
808
  if child_widget:
855
809
  widget = child_widget
810
+ # Check if click is on a widget to avoid starting rubber band
856
811
  while widget and widget != self:
857
812
  if hasattr(widget, 'annotation_viewer') and widget.annotation_viewer == self:
858
813
  self.mouse_pressed_on_widget = True
859
814
  break
860
815
  widget = widget.parent()
861
816
  return
862
-
863
- # Handle right-clicks
817
+
864
818
  elif event.button() == Qt.RightButton:
819
+ # Ignore right clicks
865
820
  event.ignore()
866
821
  return
867
-
868
- # For all other cases (e.g., a click on a widget that should be handled
869
- # by the widget itself), pass the event to the default handler.
870
- super().mousePressEvent(event)
871
822
 
823
+ # Default handler for other cases
824
+ super().mousePressEvent(event)
825
+
872
826
  def mouseDoubleClickEvent(self, event):
873
827
  """Handle double-click to clear selection and exit isolation mode."""
874
828
  if event.button() == Qt.LeftButton:
875
829
  changed_ids = []
876
-
877
- # If items are selected, clear the selection and record their IDs
878
830
  if self.selected_widgets:
879
831
  changed_ids = [w.data_item.annotation.id for w in self.selected_widgets]
880
832
  self.clear_selection()
881
833
  self.selection_changed.emit(changed_ids)
882
-
883
- # If in isolation mode, revert to showing all annotations
884
834
  if self.isolated_mode:
885
835
  self.show_all_annotations()
886
-
887
- # Signal the main window to reset its view (e.g., switch tabs)
888
836
  self.reset_view_requested.emit()
889
837
  event.accept()
890
838
  else:
891
839
  super().mouseDoubleClickEvent(event)
892
-
840
+
893
841
  def mouseMoveEvent(self, event):
894
842
  """Handle mouse move for DYNAMIC rubber band selection."""
895
- if self.rubber_band_origin is None or \
896
- event.buttons() != Qt.LeftButton or \
897
- event.modifiers() != Qt.ControlModifier:
843
+ # Only proceed if Ctrl+Left mouse drag is active and not on a widget
844
+ if (
845
+ self.rubber_band_origin is None or
846
+ event.buttons() != Qt.LeftButton or
847
+ event.modifiers() != Qt.ControlModifier
848
+ ):
898
849
  super().mouseMoveEvent(event)
899
850
  return
900
851
 
901
- # If the mouse was pressed on a widget, let that widget handle the event.
902
852
  if self.mouse_pressed_on_widget:
853
+ # If drag started on a widget, do not start rubber band
903
854
  super().mouseMoveEvent(event)
904
855
  return
905
856
 
906
- # Only start the rubber band after dragging a minimum distance
857
+ # Only start selection if drag distance exceeds threshold
907
858
  distance = (event.pos() - self.rubber_band_origin).manhattanLength()
908
859
  if distance < self.drag_threshold:
909
860
  return
910
861
 
911
- # Create and show the rubber band if it doesn't exist
862
+ # Create and show the rubber band if not already present
912
863
  if not self.rubber_band:
913
864
  self.rubber_band = QRubberBand(QRubberBand.Rectangle, self.viewport())
914
-
865
+
915
866
  rect = QRect(self.rubber_band_origin, event.pos()).normalized()
916
867
  self.rubber_band.setGeometry(rect)
917
868
  self.rubber_band.show()
918
-
919
- # Perform dynamic selection on every move
920
869
  selection_rect = self.rubber_band.geometry()
921
870
  content_widget = self.content_widget
922
871
  changed_ids = []
923
872
 
873
+ # Iterate over all annotation widgets to update selection state
924
874
  for widget in self.annotation_widgets_by_id.values():
925
875
  widget_rect_in_content = widget.geometry()
926
- # Map widget's geometry from the content area to the visible viewport
876
+ # Map widget's rect to viewport coordinates
927
877
  widget_rect_in_viewport = QRect(
928
878
  content_widget.mapTo(self.viewport(), widget_rect_in_content.topLeft()),
929
879
  widget_rect_in_content.size()
930
880
  )
931
-
932
881
  is_in_band = selection_rect.intersects(widget_rect_in_viewport)
933
-
934
- # A widget should be selected if it was selected at the start OR is in the band now.
935
882
  should_be_selected = (widget in self.selection_at_press) or is_in_band
936
883
 
884
+ # Select or deselect widgets as needed
937
885
  if should_be_selected and not widget.is_selected():
938
886
  if self.select_widget(widget):
939
887
  changed_ids.append(widget.data_item.annotation.id)
888
+
940
889
  elif not should_be_selected and widget.is_selected():
941
890
  if self.deselect_widget(widget):
942
891
  changed_ids.append(widget.data_item.annotation.id)
943
-
892
+
893
+ # Emit signal if any selection state changed
944
894
  if changed_ids:
945
895
  self.selection_changed.emit(changed_ids)
946
-
896
+
947
897
  def mouseReleaseEvent(self, event):
948
898
  """Handle mouse release to complete rubber band selection."""
949
- # Check if a rubber band drag was in progress
950
899
  if self.rubber_band_origin is not None and event.button() == Qt.LeftButton:
951
900
  if self.rubber_band and self.rubber_band.isVisible():
952
901
  self.rubber_band.hide()
953
902
  self.rubber_band.deleteLater()
954
903
  self.rubber_band = None
955
-
956
- # **NEEDED CHANGE**: Clean up the stored selection state.
904
+
957
905
  self.selection_at_press = set()
958
906
  self.rubber_band_origin = None
959
907
  self.mouse_pressed_on_widget = False
960
908
  event.accept()
961
909
  return
962
-
910
+
963
911
  super().mouseReleaseEvent(event)
964
912
 
965
913
  def handle_annotation_selection(self, widget, event):
966
- """Handle selection of annotation widgets with different modes."""
967
- # Get the list of widgets to work with based on isolation mode
968
- if self.isolated_mode:
969
- # Only work with visible widgets when in isolation mode
970
- widget_list = [w for w in self.annotation_widgets_by_id.values() if not w.isHidden()]
971
- else:
972
- # Use all widgets when not in isolation mode
973
- widget_list = list(self.annotation_widgets_by_id.values())
974
-
914
+ """Handle selection of annotation widgets with different modes (single, ctrl, shift)."""
915
+ widget_list = [w for w in self._get_sorted_widgets() if not w.isHidden()]
916
+
975
917
  try:
976
918
  widget_index = widget_list.index(widget)
977
919
  except ValueError:
@@ -980,44 +922,42 @@ class AnnotationViewer(QScrollArea):
980
922
  modifiers = event.modifiers()
981
923
  changed_ids = []
982
924
 
983
- # --- The selection logic now identifies which items to change ---
984
- # --- but the core state change happens in select/deselect ---
985
-
925
+ # Shift or Shift+Ctrl: range selection
986
926
  if modifiers == Qt.ShiftModifier or modifiers == (Qt.ShiftModifier | Qt.ControlModifier):
987
- # Range selection
988
927
  if self.last_selected_index != -1:
989
- # Find the last selected widget in the current widget list
928
+ # Find the last selected widget in the current list
990
929
  last_selected_widget = None
991
930
  for w in self.selected_widgets:
992
931
  if w in widget_list:
993
932
  try:
994
933
  last_index_in_current_list = widget_list.index(w)
995
- if last_selected_widget is None or \
996
- last_index_in_current_list > widget_list.index(last_selected_widget):
934
+ if (
935
+ last_selected_widget is None
936
+ or last_index_in_current_list > widget_list.index(last_selected_widget)
937
+ ):
997
938
  last_selected_widget = w
998
939
  except ValueError:
999
940
  continue
1000
-
941
+
1001
942
  if last_selected_widget:
1002
943
  last_selected_index_in_current_list = widget_list.index(last_selected_widget)
1003
944
  start = min(last_selected_index_in_current_list, widget_index)
1004
945
  end = max(last_selected_index_in_current_list, widget_index)
1005
946
  else:
1006
- # Fallback if no previously selected widget is found in current list
1007
- start = widget_index
1008
- end = widget_index
1009
-
947
+ start, end = widget_index, widget_index
948
+
949
+ # Select all widgets in the range
1010
950
  for i in range(start, end + 1):
1011
- # select_widget will return True if a change occurred
1012
951
  if self.select_widget(widget_list[i]):
1013
952
  changed_ids.append(widget_list[i].data_item.annotation.id)
1014
953
  else:
954
+ # No previous selection, just select the clicked widget
1015
955
  if self.select_widget(widget):
1016
956
  changed_ids.append(widget.data_item.annotation.id)
1017
957
  self.last_selected_index = widget_index
1018
-
958
+
959
+ # Ctrl: toggle selection of the clicked widget
1019
960
  elif modifiers == Qt.ControlModifier:
1020
- # Toggle selection
1021
961
  if widget.is_selected():
1022
962
  if self.deselect_widget(widget):
1023
963
  changed_ids.append(widget.data_item.annotation.id)
@@ -1025,207 +965,164 @@ class AnnotationViewer(QScrollArea):
1025
965
  if self.select_widget(widget):
1026
966
  changed_ids.append(widget.data_item.annotation.id)
1027
967
  self.last_selected_index = widget_index
1028
-
968
+
969
+ # No modifier: single selection
1029
970
  else:
1030
- # Normal click: clear all others and select this one
1031
971
  newly_selected_id = widget.data_item.annotation.id
1032
- # Deselect all widgets that are not the clicked one
972
+
973
+ # Deselect all others
1033
974
  for w in list(self.selected_widgets):
1034
975
  if w.data_item.annotation.id != newly_selected_id:
1035
976
  if self.deselect_widget(w):
1036
977
  changed_ids.append(w.data_item.annotation.id)
978
+
1037
979
  # Select the clicked widget
1038
980
  if self.select_widget(widget):
1039
981
  changed_ids.append(newly_selected_id)
1040
982
  self.last_selected_index = widget_index
1041
-
1042
- # Update isolation if in isolated mode
983
+
984
+ # If in isolated mode, update which widgets are visible
1043
985
  if self.isolated_mode:
1044
986
  self._update_isolation()
1045
-
1046
- # If any selections were changed, emit the signal
987
+
988
+ # Emit signal if any selection state changed
1047
989
  if changed_ids:
1048
990
  self.selection_changed.emit(changed_ids)
1049
-
991
+
1050
992
  def _update_isolation(self):
1051
993
  """Update the isolated view to show only currently selected widgets."""
1052
- if not self.isolated_mode:
994
+ if not self.isolated_mode:
1053
995
  return
1054
-
996
+ # If in isolated mode, only show selected widgets
1055
997
  if self.selected_widgets:
1056
- # ADD TO isolation instead of replacing it
1057
- self.isolated_widgets.update(self.selected_widgets) # Use update() to add, not replace
998
+ self.isolated_widgets.update(self.selected_widgets)
1058
999
  self.setUpdatesEnabled(False)
1059
1000
  try:
1060
1001
  for widget in self.annotation_widgets_by_id.values():
1061
- if widget not in self.isolated_widgets:
1002
+ if widget not in self.isolated_widgets:
1062
1003
  widget.hide()
1063
- else:
1004
+ else:
1064
1005
  widget.show()
1065
1006
  self.recalculate_widget_positions()
1007
+
1066
1008
  finally:
1067
1009
  self.setUpdatesEnabled(True)
1068
- else:
1069
- # If no widgets are selected, keep the current isolation (don't exit)
1070
- # This prevents accidentally exiting isolation mode when clearing selection
1071
- pass
1072
1010
 
1073
1011
  def select_widget(self, widget):
1074
- """Select a widget, update the data_item, and return True if state changed."""
1075
- if not widget.is_selected():
1076
- widget.set_selected(True)
1012
+ """Selects a widget, updates its data_item, and returns True if state changed."""
1013
+ if not widget.is_selected(): # is_selected() checks the data_item
1014
+ # 1. Controller modifies the state on the data item
1077
1015
  widget.data_item.set_selected(True)
1016
+ # 2. Controller tells the view to update its appearance
1017
+ widget.update_selection_visuals()
1078
1018
  self.selected_widgets.append(widget)
1079
- self.update_label_window_selection()
1080
- self._update_toolbar_state() # Update button states
1019
+ self._update_toolbar_state()
1081
1020
  return True
1082
1021
  return False
1083
1022
 
1084
1023
  def deselect_widget(self, widget):
1085
- """Deselect a widget, update the data_item, and return True if state changed."""
1024
+ """Deselects a widget, updates its data_item, and returns True if state changed."""
1086
1025
  if widget.is_selected():
1087
- widget.set_selected(False)
1026
+ # 1. Controller modifies the state on the data item
1088
1027
  widget.data_item.set_selected(False)
1028
+ # 2. Controller tells the view to update its appearance
1029
+ widget.update_selection_visuals()
1089
1030
  if widget in self.selected_widgets:
1090
1031
  self.selected_widgets.remove(widget)
1091
- self.update_label_window_selection()
1092
- self._update_toolbar_state() # Update button states
1032
+ self._update_toolbar_state()
1093
1033
  return True
1094
1034
  return False
1095
1035
 
1096
1036
  def clear_selection(self):
1097
1037
  """Clear all selected widgets and update toolbar state."""
1098
1038
  for widget in list(self.selected_widgets):
1099
- widget.set_selected(False)
1100
- self.selected_widgets.clear()
1101
- self.update_label_window_selection()
1102
- self._update_toolbar_state() # Update button states
1103
-
1104
- def update_label_window_selection(self):
1105
- """Update the label window selection based on currently selected annotations."""
1106
- explorer_window = self.parent()
1107
- while explorer_window and not hasattr(explorer_window, 'main_window'):
1108
- explorer_window = explorer_window.parent()
1109
-
1110
- if not explorer_window or not hasattr(explorer_window, 'main_window'):
1111
- return
1039
+ # This will internally call deselect_widget, which is fine
1040
+ self.deselect_widget(widget)
1112
1041
 
1113
- main_window = explorer_window.main_window
1114
- label_window = main_window.label_window
1115
- annotation_window = main_window.annotation_window
1116
-
1117
- if not self.selected_widgets:
1118
- label_window.deselect_active_label()
1119
- label_window.update_annotation_count()
1120
- return
1121
-
1122
- selected_data_items = [widget.data_item for widget in self.selected_widgets]
1123
-
1124
- first_effective_label = selected_data_items[0].effective_label
1125
- all_same_current_label = True
1126
- for item in selected_data_items:
1127
- if item.effective_label.id != first_effective_label.id:
1128
- all_same_current_label = False
1129
- break
1130
-
1131
- if all_same_current_label:
1132
- label_window.set_active_label(first_effective_label)
1133
- if not selected_data_items[0].has_preview_changes():
1134
- annotation_window.labelSelected.emit(first_effective_label.id)
1135
- else:
1136
- label_window.deselect_active_label()
1137
-
1138
- label_window.update_annotation_count()
1042
+ self.selected_widgets.clear()
1043
+ self._update_toolbar_state()
1139
1044
 
1140
1045
  def get_selected_annotations(self):
1141
1046
  """Get the annotations corresponding to selected widgets."""
1142
1047
  return [widget.annotation for widget in self.selected_widgets]
1143
-
1048
+
1144
1049
  def render_selection_from_ids(self, selected_ids):
1145
1050
  """Update the visual selection of widgets based on a set of IDs from the controller."""
1146
- # Block signals temporarily to prevent cascade updates
1147
1051
  self.setUpdatesEnabled(False)
1148
-
1149
1052
  try:
1150
1053
  for ann_id, widget in self.annotation_widgets_by_id.items():
1151
1054
  is_selected = ann_id in selected_ids
1152
- widget.set_selected(is_selected)
1055
+ # 1. Update the state on the central data item
1056
+ widget.data_item.set_selected(is_selected)
1057
+ # 2. Tell the widget to update its visuals based on the new state
1058
+ widget.update_selection_visuals()
1153
1059
 
1154
- # Resync internal list of selected widgets
1060
+ # Resync internal list of selected widgets from the source of truth
1155
1061
  self.selected_widgets = [w for w in self.annotation_widgets_by_id.values() if w.is_selected()]
1156
1062
 
1157
- # If we're in isolated mode, ADD to the isolation instead of replacing it
1158
1063
  if self.isolated_mode and self.selected_widgets:
1159
- self.isolated_widgets.update(self.selected_widgets) # Add to existing isolation
1160
- # Hide all widgets except those in the isolated set
1064
+ self.isolated_widgets.update(self.selected_widgets)
1161
1065
  for widget in self.annotation_widgets_by_id.values():
1162
- if widget not in self.isolated_widgets:
1163
- widget.hide()
1164
- else:
1165
- widget.show()
1066
+ widget.setHidden(widget not in self.isolated_widgets)
1166
1067
  self.recalculate_widget_positions()
1167
-
1168
1068
  finally:
1169
1069
  self.setUpdatesEnabled(True)
1170
-
1171
- # Update label window once at the end
1172
- self.update_label_window_selection()
1173
- # Update toolbar state to enable/disable Isolate button
1174
1070
  self._update_toolbar_state()
1175
-
1071
+
1176
1072
  def apply_preview_label_to_selected(self, preview_label):
1177
1073
  """Apply a preview label and emit a signal for the embedding view to update."""
1178
- if not self.selected_widgets or not preview_label:
1074
+ if not self.selected_widgets or not preview_label:
1179
1075
  return
1180
-
1181
1076
  changed_ids = []
1182
1077
  for widget in self.selected_widgets:
1183
1078
  widget.data_item.set_preview_label(preview_label)
1184
- widget.update() # Force repaint with new color
1079
+ widget.update() # Force repaint with new color
1185
1080
  changed_ids.append(widget.data_item.annotation.id)
1186
1081
 
1187
- # Recalculate positions to update sorting based on new effective labels
1188
- if self.sort_combo.currentText() == "Label":
1082
+ if self.sort_combo.currentText() == "Label":
1189
1083
  self.recalculate_widget_positions()
1190
-
1191
- if changed_ids:
1084
+ if changed_ids:
1192
1085
  self.preview_changed.emit(changed_ids)
1193
1086
 
1194
1087
  def clear_preview_states(self):
1195
- """Clear all preview states and revert to original labels."""
1196
- # We just need to iterate through all widgets and tell their data_items to clear
1197
- something_cleared = False
1088
+ """
1089
+ Clears all preview states, including label changes and items marked
1090
+ for deletion, reverting them to their original state.
1091
+ """
1092
+ something_changed = False
1198
1093
  for widget in self.annotation_widgets_by_id.values():
1094
+ # Check for and clear preview labels
1199
1095
  if widget.data_item.has_preview_changes():
1200
1096
  widget.data_item.clear_preview_label()
1201
- widget.update() # Repaint to show original color
1202
- something_cleared = True
1203
-
1204
- if something_cleared:
1205
- # Recalculate positions to update sorting based on reverted labels
1206
- if self.sort_combo.currentText() == "Label":
1097
+ widget.update() # Repaint to show original color
1098
+ something_changed = True
1099
+
1100
+ # Check for and un-mark items for deletion
1101
+ if widget.data_item.is_marked_for_deletion():
1102
+ widget.data_item.unmark_for_deletion()
1103
+ widget.show() # Make the widget visible again
1104
+ something_changed = True
1105
+
1106
+ if something_changed:
1107
+ # Recalculate positions to update sorting and re-flow the layout
1108
+ if self.sort_combo.currentText() in ("Label", "Image"):
1207
1109
  self.recalculate_widget_positions()
1208
- self.update_label_window_selection()
1209
1110
 
1210
1111
  def has_preview_changes(self):
1211
- """Check if there are any pending preview changes."""
1112
+ """Return True if there are preview changes."""
1212
1113
  return any(w.data_item.has_preview_changes() for w in self.annotation_widgets_by_id.values())
1213
1114
 
1214
1115
  def get_preview_changes_summary(self):
1215
- """Get a summary of preview changes for user feedback."""
1116
+ """Get a summary of preview changes."""
1216
1117
  change_count = sum(1 for w in self.annotation_widgets_by_id.values() if w.data_item.has_preview_changes())
1217
- if not change_count:
1218
- return "No preview changes"
1219
- return f"{change_count} annotation(s) with preview changes"
1118
+ return f"{change_count} annotation(s) with preview changes" if change_count else "No preview changes"
1220
1119
 
1221
1120
  def apply_preview_changes_permanently(self):
1222
- """Apply all preview changes permanently to the annotation data."""
1121
+ """Apply preview changes permanently."""
1223
1122
  applied_annotations = []
1224
1123
  for widget in self.annotation_widgets_by_id.values():
1225
- # Tell the data_item to apply its changes to the underlying annotation
1226
1124
  if widget.data_item.apply_preview_permanently():
1227
1125
  applied_annotations.append(widget.annotation)
1228
-
1229
1126
  return applied_annotations
1230
1127
 
1231
1128
 
@@ -1236,48 +1133,41 @@ class AnnotationViewer(QScrollArea):
1236
1133
 
1237
1134
  class ExplorerWindow(QMainWindow):
1238
1135
  def __init__(self, main_window, parent=None):
1136
+ """Initialize the ExplorerWindow."""
1239
1137
  super(ExplorerWindow, self).__init__(parent)
1240
1138
  self.main_window = main_window
1241
1139
  self.image_window = main_window.image_window
1242
1140
  self.label_window = main_window.label_window
1243
1141
  self.annotation_window = main_window.annotation_window
1244
1142
 
1245
- self.device = main_window.device # Use the same device as the main window
1143
+ self.device = main_window.device
1246
1144
  self.model_path = ""
1247
1145
  self.loaded_model = None
1248
-
1249
- # Store current filtered data items for embedding
1250
1146
  self.current_data_items = []
1251
-
1252
- # Cache for extracted features and the model that generated them ---
1253
1147
  self.current_features = None
1254
1148
  self.current_feature_generating_model = ""
1149
+ self._ui_initialized = False
1255
1150
 
1256
1151
  self.setWindowTitle("Explorer")
1257
- # Set the window icon
1258
1152
  explorer_icon_path = get_icon("magic.png")
1259
1153
  self.setWindowIcon(QIcon(explorer_icon_path))
1260
1154
 
1261
- # Create a central widget and main layout
1262
1155
  self.central_widget = QWidget()
1263
1156
  self.setCentralWidget(self.central_widget)
1264
1157
  self.main_layout = QVBoxLayout(self.central_widget)
1265
- # Create a left panel widget and layout for the re-parented LabelWindow
1266
1158
  self.left_panel = QWidget()
1267
1159
  self.left_layout = QVBoxLayout(self.left_panel)
1268
1160
 
1269
- # Create widgets in __init__ so they're always available
1270
- self.annotation_settings_widget = AnnotationSettingsWidget(self.main_window, self)
1271
- self.model_settings_widget = ModelSettingsWidget(self.main_window, self)
1272
- self.embedding_settings_widget = EmbeddingSettingsWidget(self.main_window, self)
1273
- self.annotation_viewer = AnnotationViewer(self)
1274
- self.embedding_viewer = EmbeddingViewer(self)
1161
+ self.annotation_settings_widget = None
1162
+ self.model_settings_widget = None
1163
+ self.embedding_settings_widget = None
1164
+ self.annotation_viewer = None
1165
+ self.embedding_viewer = None
1275
1166
 
1276
- # Create buttons
1277
1167
  self.clear_preview_button = QPushButton('Clear Preview', self)
1278
1168
  self.clear_preview_button.clicked.connect(self.clear_preview_changes)
1279
1169
  self.clear_preview_button.setToolTip("Clear all preview changes and revert to original labels")
1280
- self.clear_preview_button.setEnabled(False) # Initially disabled
1170
+ self.clear_preview_button.setEnabled(False)
1281
1171
 
1282
1172
  self.exit_button = QPushButton('Exit', self)
1283
1173
  self.exit_button.clicked.connect(self.close)
@@ -1286,27 +1176,26 @@ class ExplorerWindow(QMainWindow):
1286
1176
  self.apply_button = QPushButton('Apply', self)
1287
1177
  self.apply_button.clicked.connect(self.apply)
1288
1178
  self.apply_button.setToolTip("Apply changes")
1289
- self.apply_button.setEnabled(False) # Initially disabled
1179
+ self.apply_button.setEnabled(False)
1290
1180
 
1291
1181
  def showEvent(self, event):
1292
- self.setup_ui()
1182
+ """Handle show event."""
1183
+ if not self._ui_initialized:
1184
+ self.setup_ui()
1185
+ self._ui_initialized = True
1293
1186
  super(ExplorerWindow, self).showEvent(event)
1294
1187
 
1295
1188
  def closeEvent(self, event):
1296
- """
1297
- Handles the window close event.
1298
- This now calls the resource cleanup method.
1299
- """
1189
+ """Handle close event."""
1300
1190
  # Stop any running timers to prevent errors
1301
1191
  if hasattr(self, 'embedding_viewer') and self.embedding_viewer:
1302
1192
  if hasattr(self.embedding_viewer, 'animation_timer') and self.embedding_viewer.animation_timer:
1303
1193
  self.embedding_viewer.animation_timer.stop()
1304
1194
 
1305
- # Clear any unsaved preview states
1306
- if hasattr(self, 'annotation_viewer'):
1307
- self.annotation_viewer.clear_preview_states()
1195
+ # Call the main cancellation method to revert any pending changes
1196
+ self.clear_preview_changes()
1308
1197
 
1309
- # --- NEW: Call the dedicated cleanup method ---
1198
+ # Call the dedicated cleanup method
1310
1199
  self._cleanup_resources()
1311
1200
 
1312
1201
  # Re-enable the main window before closing
@@ -1319,76 +1208,80 @@ class ExplorerWindow(QMainWindow):
1319
1208
 
1320
1209
  # Clear the reference in the main_window to allow garbage collection
1321
1210
  self.main_window.explorer_window = None
1322
-
1211
+
1212
+ # Set the ui_initialized flag to False so it can be re-initialized next time
1213
+ self._ui_initialized = False
1214
+
1323
1215
  event.accept()
1324
1216
 
1325
1217
  def setup_ui(self):
1326
- # Clear the main layout to remove any existing widgets
1218
+ """Set up the UI for the ExplorerWindow."""
1327
1219
  while self.main_layout.count():
1328
1220
  child = self.main_layout.takeAt(0)
1329
1221
  if child.widget():
1330
- child.widget().setParent(None) # Remove from layout but don't delete
1222
+ child.widget().setParent(None)
1331
1223
 
1332
- # Top section: Conditions and Settings side by side
1333
- top_layout = QHBoxLayout()
1224
+ # Lazily initialize the settings and viewer widgets if they haven't been created yet.
1225
+ # This ensures that the widgets are only created once per ExplorerWindow instance.
1334
1226
 
1335
- # Add existing widgets to layout
1336
- top_layout.addWidget(self.annotation_settings_widget, 2) # Give annotation settings more space
1337
- top_layout.addWidget(self.model_settings_widget, 1) # Model settings in the middle
1338
- top_layout.addWidget(self.embedding_settings_widget, 1) # Embedding settings on the right
1227
+ # Annotation settings panel (filters by image, type, label)
1228
+ if self.annotation_settings_widget is None:
1229
+ self.annotation_settings_widget = AnnotationSettingsWidget(self.main_window, self)
1230
+
1231
+ # Model selection panel (choose feature extraction model)
1232
+ if self.model_settings_widget is None:
1233
+ self.model_settings_widget = ModelSettingsWidget(self.main_window, self)
1234
+
1235
+ # Embedding settings panel (choose dimensionality reduction method)
1236
+ if self.embedding_settings_widget is None:
1237
+ self.embedding_settings_widget = EmbeddingSettingsWidget(self.main_window, self)
1238
+
1239
+ # Annotation viewer (shows annotation image crops in a grid)
1240
+ if self.annotation_viewer is None:
1241
+ self.annotation_viewer = AnnotationViewer(self)
1242
+
1243
+ # Embedding viewer (shows 2D embedding scatter plot)
1244
+ if self.embedding_viewer is None:
1245
+ self.embedding_viewer = EmbeddingViewer(self)
1339
1246
 
1340
- # Create container widget for top layout
1247
+ top_layout = QHBoxLayout()
1248
+ top_layout.addWidget(self.annotation_settings_widget, 2)
1249
+ top_layout.addWidget(self.model_settings_widget, 1)
1250
+ top_layout.addWidget(self.embedding_settings_widget, 1)
1341
1251
  top_container = QWidget()
1342
1252
  top_container.setLayout(top_layout)
1343
1253
  self.main_layout.addWidget(top_container)
1344
1254
 
1345
- # Middle section: Annotation Viewer (left) and Embedding Viewer (right)
1346
1255
  middle_splitter = QSplitter(Qt.Horizontal)
1347
-
1348
- # Wrap annotation viewer in a group box
1349
1256
  annotation_group = QGroupBox("Annotation Viewer")
1350
1257
  annotation_layout = QVBoxLayout(annotation_group)
1351
1258
  annotation_layout.addWidget(self.annotation_viewer)
1352
1259
  middle_splitter.addWidget(annotation_group)
1353
1260
 
1354
- # Wrap embedding viewer in a group box
1355
1261
  embedding_group = QGroupBox("Embedding Viewer")
1356
1262
  embedding_layout = QVBoxLayout(embedding_group)
1357
1263
  embedding_layout.addWidget(self.embedding_viewer)
1358
1264
  middle_splitter.addWidget(embedding_group)
1359
-
1360
- # Set splitter proportions (annotation viewer wider)
1361
1265
  middle_splitter.setSizes([500, 500])
1362
-
1363
- # Add middle section to main layout with stretch factor
1364
1266
  self.main_layout.addWidget(middle_splitter, 1)
1365
-
1366
- # Note: LabelWindow will be re-parented here by MainWindow.open_explorer_window()
1367
- # The LabelWindow will be added to self.left_layout at index 1 by the MainWindow
1368
1267
  self.main_layout.addWidget(self.label_window)
1369
1268
 
1370
- # Bottom control buttons
1371
1269
  self.buttons_layout = QHBoxLayout()
1372
- # Add stretch to push buttons to the right
1373
1270
  self.buttons_layout.addStretch(1)
1374
-
1375
- # Add existing buttons to layout
1376
1271
  self.buttons_layout.addWidget(self.clear_preview_button)
1377
1272
  self.buttons_layout.addWidget(self.exit_button)
1378
1273
  self.buttons_layout.addWidget(self.apply_button)
1379
-
1380
1274
  self.main_layout.addLayout(self.buttons_layout)
1381
1275
 
1382
- # Set default condition to current image and refresh filters
1383
1276
  self.annotation_settings_widget.set_default_to_current_image()
1384
1277
  self.refresh_filters()
1385
1278
 
1386
- # Connect label selection to preview updates (only connect once)
1387
1279
  try:
1388
1280
  self.label_window.labelSelected.disconnect(self.on_label_selected_for_preview)
1389
1281
  except TypeError:
1390
- pass # Signal wasn't connected yet
1391
-
1282
+ pass
1283
+
1284
+ # Connect signals to slots
1392
1285
  self.label_window.labelSelected.connect(self.on_label_selected_for_preview)
1393
1286
  self.annotation_viewer.selection_changed.connect(self.on_annotation_view_selection_changed)
1394
1287
  self.annotation_viewer.preview_changed.connect(self.on_preview_changed)
@@ -1398,229 +1291,202 @@ class ExplorerWindow(QMainWindow):
1398
1291
 
1399
1292
  @pyqtSlot(list)
1400
1293
  def on_annotation_view_selection_changed(self, changed_ann_ids):
1401
- """A selection was made in the AnnotationViewer, so update the EmbeddingViewer."""
1294
+ """Syncs selection from AnnotationViewer to EmbeddingViewer."""
1402
1295
  all_selected_ids = {w.data_item.annotation.id for w in self.annotation_viewer.selected_widgets}
1403
-
1404
- # Only try to sync the selection with the EmbeddingViewer if it has points.
1405
- # This prevents the feedback loop that was clearing the selection.
1406
1296
  if self.embedding_viewer.points_by_id:
1407
1297
  self.embedding_viewer.render_selection_from_ids(all_selected_ids)
1408
-
1409
- self.update_label_window_selection() # Keep label window in sync
1298
+
1299
+ # Call the new centralized method
1300
+ self.update_label_window_selection()
1410
1301
 
1411
1302
  @pyqtSlot(list)
1412
1303
  def on_embedding_view_selection_changed(self, all_selected_ann_ids):
1413
- """A selection was made in the EmbeddingViewer, so update the AnnotationViewer."""
1414
- # Check if this is a new selection being made when nothing was previously selected
1304
+ """Syncs selection from EmbeddingViewer to AnnotationViewer."""
1305
+ # Check the state BEFORE the selection is changed
1415
1306
  was_empty_selection = len(self.annotation_viewer.selected_widgets) == 0
1416
- is_new_selection = len(all_selected_ann_ids) > 0
1417
-
1418
- # Update the annotation viewer with the new selection
1307
+
1308
+ # Now, update the selection in the annotation viewer
1419
1309
  self.annotation_viewer.render_selection_from_ids(set(all_selected_ann_ids))
1420
-
1421
- # Auto-switch to isolation mode if conditions are met
1422
- if (was_empty_selection and
1423
- is_new_selection and
1424
- not self.annotation_viewer.isolated_mode):
1425
- print("Auto-switching to isolation mode due to new selection in embedding viewer")
1310
+
1311
+ # The rest of the logic now works correctly
1312
+ is_new_selection = len(all_selected_ann_ids) > 0
1313
+ if (
1314
+ was_empty_selection and
1315
+ is_new_selection and
1316
+ not self.annotation_viewer.isolated_mode
1317
+ ):
1426
1318
  self.annotation_viewer.isolate_selection()
1427
-
1428
- self.update_label_window_selection() # Keep label window in sync
1319
+
1320
+ self.update_label_window_selection()
1429
1321
 
1430
1322
  @pyqtSlot(list)
1431
1323
  def on_preview_changed(self, changed_ann_ids):
1432
- """A preview color was changed in the AnnotationViewer, so update the EmbeddingViewer points."""
1324
+ """Updates embedding point colors when a preview label is applied."""
1433
1325
  for ann_id in changed_ann_ids:
1434
1326
  point = self.embedding_viewer.points_by_id.get(ann_id)
1435
1327
  if point:
1436
- point.update() # Force the point to repaint itself
1437
-
1328
+ point.update()
1329
+
1438
1330
  @pyqtSlot()
1439
1331
  def on_reset_view_requested(self):
1440
1332
  """Handle reset view requests from double-click in either viewer."""
1441
1333
  # Clear all selections in both viewers
1442
1334
  self.annotation_viewer.clear_selection()
1443
1335
  self.embedding_viewer.render_selection_from_ids(set())
1444
-
1336
+
1445
1337
  # Exit isolation mode if currently active
1446
1338
  if self.annotation_viewer.isolated_mode:
1447
1339
  self.annotation_viewer.show_all_annotations()
1448
-
1449
- # Update button states
1340
+
1341
+ self.update_label_window_selection()
1450
1342
  self.update_button_states()
1451
-
1343
+
1452
1344
  print("Reset view: cleared selections and exited isolation mode")
1453
1345
 
1454
1346
  def update_label_window_selection(self):
1455
- """Update the label window based on the selection in the annotation viewer."""
1456
- self.annotation_viewer.update_label_window_selection()
1347
+ """
1348
+ Updates the label window based on the selection state of the currently
1349
+ loaded data items. This is the single, centralized point of logic.
1350
+ """
1351
+ # Get selected items directly from the master data list
1352
+ selected_data_items = [
1353
+ item for item in self.current_data_items if item.is_selected
1354
+ ]
1355
+
1356
+ if not selected_data_items:
1357
+ self.label_window.deselect_active_label()
1358
+ self.label_window.update_annotation_count()
1359
+ return
1360
+
1361
+ first_effective_label = selected_data_items[0].effective_label
1362
+ all_same_current_label = all(
1363
+ item.effective_label.id == first_effective_label.id
1364
+ for item in selected_data_items
1365
+ )
1366
+
1367
+ if all_same_current_label:
1368
+ self.label_window.set_active_label(first_effective_label)
1369
+ # This emit is what updates other UI elements, like the annotation list
1370
+ self.annotation_window.labelSelected.emit(first_effective_label.id)
1371
+ else:
1372
+ self.label_window.deselect_active_label()
1373
+
1374
+ self.label_window.update_annotation_count()
1457
1375
 
1458
1376
  def get_filtered_data_items(self):
1459
- """Get annotations that match all conditions, returned as AnnotationDataItem objects."""
1377
+ """Gets annotations matching all conditions as AnnotationDataItem objects."""
1460
1378
  data_items = []
1461
1379
  if not hasattr(self.main_window.annotation_window, 'annotations_dict'):
1462
1380
  return data_items
1463
1381
 
1464
- # Get current filter conditions
1465
1382
  selected_images = self.annotation_settings_widget.get_selected_images()
1466
1383
  selected_types = self.annotation_settings_widget.get_selected_annotation_types()
1467
1384
  selected_labels = self.annotation_settings_widget.get_selected_labels()
1468
1385
 
1469
- annotations_to_process = []
1470
- for annotation in self.main_window.annotation_window.annotations_dict.values():
1471
- annotation_matches = True
1472
-
1473
- # Check image condition - if empty list, no annotations match
1474
- if selected_images:
1475
- annotation_image = os.path.basename(annotation.image_path)
1476
- if annotation_image not in selected_images:
1477
- annotation_matches = False
1478
- else:
1479
- # No images selected means no annotations should match
1480
- annotation_matches = False
1481
-
1482
- # Check annotation type condition - if empty list, no annotations match
1483
- if annotation_matches:
1484
- if selected_types:
1485
- annotation_type = type(annotation).__name__
1486
- if annotation_type not in selected_types:
1487
- annotation_matches = False
1488
- else:
1489
- # No types selected means no annotations should match
1490
- annotation_matches = False
1491
-
1492
- # Check label condition - if empty list, no annotations match
1493
- if annotation_matches:
1494
- if selected_labels:
1495
- annotation_label = annotation.label.short_label_code
1496
- if annotation_label not in selected_labels:
1497
- annotation_matches = False
1498
- else:
1499
- # No labels selected means no annotations should match
1500
- annotation_matches = False
1386
+ if not all([selected_images, selected_types, selected_labels]):
1387
+ return []
1501
1388
 
1502
- if annotation_matches:
1503
- annotations_to_process.append(annotation)
1389
+ annotations_to_process = [
1390
+ ann for ann in self.main_window.annotation_window.annotations_dict.values()
1391
+ if (os.path.basename(ann.image_path) in selected_images and
1392
+ type(ann).__name__ in selected_types and
1393
+ ann.label.short_label_code in selected_labels)
1394
+ ]
1504
1395
 
1505
- # Ensure all filtered annotations have cropped images
1506
1396
  self._ensure_cropped_images(annotations_to_process)
1507
-
1508
- # Wrap in AnnotationDataItem
1509
- for ann in annotations_to_process:
1510
- data_items.append(AnnotationDataItem(ann))
1397
+ return [AnnotationDataItem(ann) for ann in annotations_to_process]
1511
1398
 
1512
- return data_items
1513
-
1514
1399
  def _ensure_cropped_images(self, annotations):
1515
- """Ensure all provided annotations have a cropped image available."""
1400
+ """Ensures all provided annotations have a cropped image available."""
1516
1401
  annotations_by_image = {}
1402
+
1517
1403
  for annotation in annotations:
1518
- # Only process annotations that don't have a cropped image yet
1519
1404
  if not annotation.cropped_image:
1520
1405
  image_path = annotation.image_path
1521
1406
  if image_path not in annotations_by_image:
1522
1407
  annotations_by_image[image_path] = []
1523
1408
  annotations_by_image[image_path].append(annotation)
1524
-
1525
- # Only proceed if there are annotations that actually need cropping
1526
- if annotations_by_image:
1527
- progress_bar = ProgressBar(self, "Cropping Image Annotations")
1528
- progress_bar.show()
1529
- progress_bar.start_progress(len(annotations_by_image))
1530
-
1531
- try:
1532
- # Crop annotations for each image using the AnnotationWindow method
1533
- # This ensures consistency with how cropped images are generated elsewhere
1534
- for image_path, image_annotations in annotations_by_image.items():
1535
- self.annotation_window.crop_annotations(
1536
- image_path=image_path,
1537
- annotations=image_annotations,
1538
- return_annotations=False, # We don't need the return value
1539
- verbose=False
1540
- )
1541
- # Update progress bar
1542
- progress_bar.update_progress()
1543
-
1544
- except Exception as e:
1545
- print(f"Error cropping annotations: {e}")
1546
1409
 
1547
- finally:
1548
- progress_bar.finish_progress()
1549
- progress_bar.stop_progress()
1550
- progress_bar.close()
1410
+ if not annotations_by_image:
1411
+ return
1412
+
1413
+ progress_bar = ProgressBar(self, "Cropping Image Annotations")
1414
+ progress_bar.show()
1415
+ progress_bar.start_progress(len(annotations_by_image))
1416
+
1417
+ try:
1418
+ for image_path, image_annotations in annotations_by_image.items():
1419
+ self.annotation_window.crop_annotations(image_path=image_path,
1420
+ annotations=image_annotations,
1421
+ return_annotations=False,
1422
+ verbose=False)
1423
+ progress_bar.update_progress()
1424
+ finally:
1425
+ progress_bar.finish_progress()
1426
+ progress_bar.stop_progress()
1427
+ progress_bar.close()
1551
1428
 
1552
1429
  def _extract_color_features(self, data_items, progress_bar=None, bins=32):
1553
1430
  """
1554
- Extracts a comprehensive set of color features from annotation crops using only NumPy.
1555
-
1556
- For each image, it calculates:
1557
- 1. Color Moments (per channel):
1558
- - Mean (1st moment)
1559
- - Standard Deviation (2nd moment)
1560
- - Skewness (3rd moment)
1561
- - Kurtosis (4th moment)
1562
- 2. Color Histogram (per channel)
1563
- 3. Grayscale Statistics:
1564
- - Mean Brightness
1565
- - Contrast (Std Dev)
1566
- - Intensity Range
1567
- 4. Geometric Features:
1568
- - Area
1569
- - Perimeter
1431
+ Extracts color-based features from annotation crops.
1432
+
1433
+ Features extracted per annotation:
1434
+ - Mean, standard deviation, skewness, and kurtosis for each RGB channel
1435
+ - Normalized histogram for each RGB channel
1436
+ - Grayscale statistics: mean, std, range
1437
+ - Geometric features: area, perimeter (if available)
1438
+ Returns:
1439
+ features: np.ndarray of shape (N, feature_dim)
1440
+ valid_data_items: list of AnnotationDataItem with valid crops
1570
1441
  """
1571
1442
  if progress_bar:
1572
- progress_bar.set_title("Extracting Color Features...")
1443
+ progress_bar.set_title("Extracting features...")
1573
1444
  progress_bar.start_progress(len(data_items))
1574
1445
 
1575
1446
  features = []
1576
1447
  valid_data_items = []
1448
+
1577
1449
  for item in data_items:
1578
1450
  pixmap = item.annotation.get_cropped_image()
1579
1451
  if pixmap and not pixmap.isNull():
1580
- # arr has shape (height, width, 3)
1452
+ # Convert QPixmap to numpy array (H, W, 3)
1581
1453
  arr = pixmap_to_numpy(pixmap)
1582
-
1583
- # Reshape for channel-wise statistics: (num_pixels, 3)
1584
1454
  pixels = arr.reshape(-1, 3)
1585
1455
 
1586
- # --- 1. Calculate Color Moments using only NumPy ---
1456
+ # Basic color statistics
1587
1457
  mean_color = np.mean(pixels, axis=0)
1588
1458
  std_color = np.std(pixels, axis=0)
1589
-
1590
- # Center the data (subtract the mean) for skew/kurtosis calculation
1459
+
1460
+ # Skewness and kurtosis for each channel
1461
+ epsilon = 1e-8 # Prevent division by zero
1591
1462
  centered_pixels = pixels - mean_color
1592
-
1593
- # A small value (epsilon) is added to the denominator to prevent division by zero
1594
- epsilon = 1e-8
1595
- skew_color = np.mean(centered_pixels**3, axis=0) / (std_color**3 + epsilon)
1596
- kurt_color = np.mean(centered_pixels**4, axis=0) / (std_color**4 + epsilon) - 3
1597
-
1598
- # --- 2. Calculate Color Histograms ---
1599
- histograms = []
1600
- for i in range(3): # For each channel (R, G, B)
1601
- hist, _ = np.histogram(pixels[:, i], bins=bins, range=(0, 255))
1602
-
1603
- # Normalize histogram
1604
- hist_sum = np.sum(hist)
1605
- if hist_sum > 0:
1606
- histograms.append(hist / hist_sum)
1607
- else: # Avoid division by zero
1608
- histograms.append(np.zeros(bins))
1609
-
1610
- # --- 3. Calculate Grayscale Statistics ---
1463
+ skew_color = np.mean(centered_pixels ** 3, axis=0) / (std_color ** 3 + epsilon)
1464
+ kurt_color = np.mean(centered_pixels ** 4, axis=0) / (std_color ** 4 + epsilon) - 3
1465
+
1466
+ # Normalized histograms for each channel
1467
+ histograms = [
1468
+ np.histogram(pixels[:, i], bins=bins, range=(0, 255))[0]
1469
+ for i in range(3)
1470
+ ]
1471
+ histograms = [
1472
+ h / h.sum() if h.sum() > 0 else np.zeros(bins)
1473
+ for h in histograms
1474
+ ]
1475
+
1476
+ # Grayscale statistics
1611
1477
  gray_arr = np.dot(arr[..., :3], [0.2989, 0.5870, 0.1140])
1612
1478
  grayscale_stats = np.array([
1613
- np.mean(gray_arr), # Overall brightness
1614
- np.std(gray_arr), # Overall contrast
1615
- np.ptp(gray_arr) # Peak-to-peak intensity range
1479
+ np.mean(gray_arr),
1480
+ np.std(gray_arr),
1481
+ np.ptp(gray_arr)
1616
1482
  ])
1617
-
1618
- # --- 4. Calculate Geometric Features ---
1483
+
1484
+ # Geometric features (area, perimeter)
1619
1485
  area = getattr(item.annotation, 'area', 0.0)
1620
1486
  perimeter = getattr(item.annotation, 'perimeter', 0.0)
1621
1487
  geometric_features = np.array([area, perimeter])
1622
-
1623
- # --- 5. Concatenate all features into a single vector ---
1488
+
1489
+ # Concatenate all features into a single vector
1624
1490
  current_features = np.concatenate([
1625
1491
  mean_color,
1626
1492
  std_color,
@@ -1630,230 +1496,163 @@ class ExplorerWindow(QMainWindow):
1630
1496
  grayscale_stats,
1631
1497
  geometric_features
1632
1498
  ])
1633
-
1499
+
1634
1500
  features.append(current_features)
1635
1501
  valid_data_items.append(item)
1636
- else:
1637
- print(f"Warning: Could not get cropped image for annotation ID {item.annotation.id}. Skipping.")
1638
-
1502
+
1639
1503
  if progress_bar:
1640
1504
  progress_bar.update_progress()
1641
1505
 
1642
1506
  return np.array(features), valid_data_items
1643
1507
 
1644
1508
  def _extract_yolo_features(self, data_items, model_info, progress_bar=None):
1645
- """
1646
- Extracts features from annotation crops using a specified YOLO model.
1647
- Uses model.embed() for embedding features or model.predict() for classification probabilities.
1648
- """
1649
- # Unpack model information
1509
+ """Extracts features from annotation crops using a YOLO model."""
1510
+ # Extract model name and feature mode from the provided model_info tuple
1650
1511
  model_name, feature_mode = model_info
1651
1512
 
1652
- # Load or retrieve the cached model
1653
1513
  if model_name != self.model_path or self.loaded_model is None:
1654
1514
  try:
1655
- print(f"Loading new model: {model_name}")
1656
1515
  self.loaded_model = YOLO(model_name)
1657
1516
  self.model_path = model_name
1658
-
1659
- # Determine image size from model config if possible
1660
- try:
1661
- self.imgsz = self.loaded_model.model.args['imgsz']
1662
- if self.imgsz > 224:
1663
- self.imgsz = 128
1664
- except (AttributeError, KeyError):
1665
- self.imgsz = 128
1666
-
1667
- # Run a dummy inference to warm up the model
1668
- print(f"Warming up model on device '{self.device}'...")
1517
+ self.imgsz = getattr(self.loaded_model.model.args, 'imgsz', 128)
1669
1518
  dummy_image = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8)
1670
1519
  self.loaded_model.predict(dummy_image, imgsz=self.imgsz, half=True, device=self.device, verbose=False)
1671
-
1520
+
1672
1521
  except Exception as e:
1673
1522
  print(f"ERROR: Could not load YOLO model '{model_name}': {e}")
1674
1523
  return np.array([]), []
1675
-
1676
- if progress_bar:
1677
- progress_bar.set_title(f"Preparing images...")
1678
- progress_bar.start_progress(len(data_items))
1679
1524
 
1680
- # 1. Prepare a list of all valid images and their corresponding data items.
1681
- image_list = []
1682
- valid_data_items = []
1525
+ if progress_bar:
1526
+ progress_bar.set_title("Preparing images...")
1527
+ progress_bar.start_progress(len(data_items))
1528
+
1529
+ image_list, valid_data_items = [], []
1683
1530
  for item in data_items:
1531
+ # Get the cropped image from the annotation
1684
1532
  pixmap = item.annotation.get_cropped_image()
1533
+
1685
1534
  if pixmap and not pixmap.isNull():
1686
- image_np = pixmap_to_numpy(pixmap)
1687
- image_list.append(image_np)
1535
+ image_list.append(pixmap_to_numpy(pixmap))
1688
1536
  valid_data_items.append(item)
1689
- else:
1690
- print(f"Warning: Could not get cropped image for annotation ID {item.annotation.id}. Skipping.")
1691
-
1692
- if progress_bar:
1537
+
1538
+ if progress_bar:
1693
1539
  progress_bar.update_progress()
1694
-
1695
- if not valid_data_items:
1696
- print("Warning: No valid images found to process.")
1540
+
1541
+ if not valid_data_items:
1697
1542
  return np.array([]), []
1698
-
1699
- embeddings_list = []
1700
1543
 
1701
- try:
1702
- if progress_bar:
1703
- progress_bar.set_busy_mode(f"Extracting features with {os.path.basename(model_name)}...")
1704
-
1705
- kwargs = {
1706
- 'stream': True,
1707
- 'imgsz': self.imgsz,
1708
- 'half': True,
1709
- 'device': self.device,
1710
- 'verbose': False
1711
- }
1712
-
1713
- # 2. Choose between embed() and predict() based on feature mode
1714
- using_embed_method = feature_mode == "Embed Features"
1715
-
1716
- if using_embed_method:
1717
- print("Using embed() method")
1718
-
1719
- # Use model.embed() method (uses the second to last layer)
1720
- results_generator = self.loaded_model.embed(image_list, **kwargs)
1721
-
1722
- else:
1723
- print("Using predict() method")
1724
-
1725
- # Use model.predict() method for classification probabilities
1726
- results_generator = self.loaded_model.predict(image_list, **kwargs)
1727
-
1728
- if progress_bar:
1729
- progress_bar.set_title(f"Extracting features with {os.path.basename(model_name)}...")
1730
- progress_bar.start_progress(len(valid_data_items))
1544
+ # Specify the kwargs for YOLO model prediction
1545
+ kwargs = {'stream': True,
1546
+ 'imgsz': self.imgsz,
1547
+ 'half': True,
1548
+ 'device': self.device,
1549
+ 'verbose': False}
1550
+
1551
+ # Run the model to extract features
1552
+ if feature_mode == "Embed Features":
1553
+ results_generator = self.loaded_model.embed(image_list, **kwargs)
1554
+ else:
1555
+ results_generator = self.loaded_model.predict(image_list, **kwargs)
1731
1556
 
1732
- # 3. Process the results from the generator - different handling based on method
1733
- for result in results_generator:
1734
- if using_embed_method:
1735
- try:
1736
- # With embed(), result is directly a tensor
1737
- embedding = result.cpu().numpy().flatten()
1738
- embeddings_list.append(embedding)
1739
- except Exception as e:
1740
- print(f"Error processing embedding: {e}")
1741
- raise TypeError(
1742
- f"Model '{os.path.basename(model_name)}' did not return valid embeddings. "
1743
- f"Error: {str(e)}"
1744
- )
1557
+ if progress_bar:
1558
+ progress_bar.set_title("Extracting features...")
1559
+ progress_bar.start_progress(len(valid_data_items))
1560
+
1561
+ # Prepare a list to hold the extracted features
1562
+ embeddings_list = []
1563
+
1564
+ try:
1565
+ # Iterate through the results and extract features based on the mode
1566
+ for i, result in enumerate(results_generator):
1567
+ if feature_mode == "Embed Features":
1568
+ embeddings_list.append(result.cpu().numpy().flatten())
1569
+
1570
+ elif hasattr(result, 'probs') and result.probs is not None:
1571
+ embeddings_list.append(result.probs.data.cpu().numpy().squeeze())
1572
+
1745
1573
  else:
1746
- # Classification mode: We expect probability vectors
1747
- if hasattr(result, 'probs') and result.probs is not None:
1748
- embedding = result.probs.data.cpu().numpy().squeeze()
1749
- embeddings_list.append(embedding)
1750
- else:
1751
- raise TypeError(
1752
- f"Model '{os.path.basename(model_name)}' did not return probability vectors. "
1753
- "Make sure this is a classification model."
1754
- )
1755
-
1756
- # Update progress for each item as it's processed from the stream.
1757
- if progress_bar:
1574
+ raise TypeError("Model did not return expected output")
1575
+
1576
+ if progress_bar:
1758
1577
  progress_bar.update_progress()
1759
-
1760
- if not embeddings_list:
1761
- print("Warning: No features were extracted. The model may have failed.")
1762
- return np.array([]), []
1763
-
1764
- embeddings = np.array(embeddings_list)
1765
- except Exception as e:
1766
- print(f"ERROR: An error occurred during feature extraction: {e}")
1767
- return np.array([]), []
1768
1578
  finally:
1769
- # Clean up CUDA memory after the operation
1770
1579
  if torch.cuda.is_available():
1771
1580
  torch.cuda.empty_cache()
1772
-
1773
- print(f"Successfully extracted {len(embeddings)} features with shape {embeddings.shape}")
1774
- return embeddings, valid_data_items
1581
+
1582
+ return np.array(embeddings_list), valid_data_items
1775
1583
 
1776
1584
  def _extract_features(self, data_items, progress_bar=None):
1777
- """
1778
- Dispatcher method to call the appropriate feature extraction function.
1779
- It now passes the progress_bar object to the sub-methods.
1780
- """
1585
+ """Dispatcher to call the appropriate feature extraction function."""
1586
+ # Get the selected model and feature mode from the model settings widget
1781
1587
  model_name, feature_mode = self.model_settings_widget.get_selected_model()
1782
1588
 
1783
- # Handle tuple or string return value
1784
- if isinstance(model_name, tuple) and len(model_name) >= 3:
1589
+ if isinstance(model_name, tuple):
1785
1590
  model_name = model_name[0]
1786
- else:
1787
- model_name = model_name
1788
-
1789
- if not model_name:
1790
- print("No model selected or path provided.")
1591
+
1592
+ if not model_name:
1791
1593
  return np.array([]), []
1792
-
1793
- # --- MODIFIED: Pass the progress_bar object ---
1594
+
1794
1595
  if model_name == "Color Features":
1795
1596
  return self._extract_color_features(data_items, progress_bar=progress_bar)
1597
+
1796
1598
  elif ".pt" in model_name:
1797
- # Pass the full model_info which may include embed layers
1798
1599
  return self._extract_yolo_features(data_items, (model_name, feature_mode), progress_bar=progress_bar)
1799
- else:
1800
- print(f"Unknown or invalid feature model selected: {model_name}")
1801
- return np.array([]), []
1600
+
1601
+ return np.array([]), []
1802
1602
 
1803
1603
  def _run_dimensionality_reduction(self, features, params):
1804
1604
  """
1805
- Runs PCA, UMAP or t-SNE on the feature matrix using provided parameters.
1605
+ Runs dimensionality reduction (PCA, UMAP, or t-SNE) on the feature matrix.
1606
+
1607
+ Args:
1608
+ features (np.ndarray): Feature matrix of shape (N, D).
1609
+ params (dict): Embedding parameters, including technique and its hyperparameters.
1610
+
1611
+ Returns:
1612
+ np.ndarray or None: 2D embedded features of shape (N, 2), or None on failure.
1806
1613
  """
1807
- technique = params.get('technique', 'UMAP') # Changed default to UMAP
1808
- random_state = 42
1614
+ technique = params.get('technique', 'UMAP')
1809
1615
 
1810
- print(f"Running {technique} on {len(features)} items with params: {params}")
1811
- if len(features) <= 2: # UMAP/t-SNE need at least a few points
1812
- print("Not enough data points for dimensionality reduction.")
1616
+ if len(features) <= 2:
1617
+ # Not enough samples for dimensionality reduction
1813
1618
  return None
1814
1619
 
1815
1620
  try:
1816
- # Scaling is crucial, your implementation is already correct
1817
- scaler = StandardScaler()
1818
- features_scaled = scaler.fit_transform(features)
1819
-
1820
- if technique == "UMAP":
1821
- # Get hyperparameters from params, with sensible defaults
1822
- n_neighbors = params.get('n_neighbors', 15)
1823
- min_dist = params.get('min_dist', 0.1)
1824
- metric = params.get('metric', 'cosine') # Allow metric to be specified
1621
+ # Standardize features before reduction
1622
+ features_scaled = StandardScaler().fit_transform(features)
1825
1623
 
1624
+ if technique == "UMAP":
1625
+ # UMAP: n_neighbors must be < n_samples
1626
+ n_neighbors = min(params.get('n_neighbors', 15), len(features_scaled) - 1)
1627
+
1826
1628
  reducer = UMAP(
1827
- n_components=2,
1828
- random_state=random_state,
1829
- n_neighbors=min(n_neighbors, len(features_scaled) - 1),
1830
- min_dist=min_dist,
1831
- metric=metric # Use the specified metric
1629
+ n_components=2,
1630
+ random_state=42,
1631
+ n_neighbors=n_neighbors,
1632
+ min_dist=params.get('min_dist', 0.1),
1633
+ metric=params.get('metric', 'cosine')
1832
1634
  )
1833
-
1834
1635
  elif technique == "TSNE":
1835
- perplexity = params.get('perplexity', 30)
1836
- early_exaggeration = params.get('early_exaggeration', 12.0)
1837
- learning_rate = params.get('learning_rate', 'auto')
1838
-
1636
+ # t-SNE: perplexity must be < n_samples
1637
+ perplexity = min(params.get('perplexity', 30), len(features_scaled) - 1)
1638
+
1839
1639
  reducer = TSNE(
1840
- n_components=2,
1841
- random_state=random_state,
1842
- perplexity=min(perplexity, len(features_scaled) - 1),
1843
- early_exaggeration=early_exaggeration,
1844
- learning_rate=learning_rate,
1845
- init='pca' # Improves stability and speed
1640
+ n_components=2,
1641
+ random_state=42,
1642
+ perplexity=perplexity,
1643
+ early_exaggeration=params.get('early_exaggeration', 12.0),
1644
+ learning_rate=params.get('learning_rate', 'auto'),
1645
+ init='pca'
1846
1646
  )
1847
-
1848
1647
  elif technique == "PCA":
1849
- reducer = PCA(n_components=2, random_state=random_state)
1850
-
1648
+ reducer = PCA(n_components=2, random_state=42)
1851
1649
  else:
1852
- print(f"Unknown dimensionality reduction technique: {technique}")
1650
+ # Unknown technique
1853
1651
  return None
1854
-
1652
+
1653
+ # Fit and transform the features
1855
1654
  return reducer.fit_transform(features_scaled)
1856
-
1655
+
1857
1656
  except Exception as e:
1858
1657
  print(f"Error during {technique} dimensionality reduction: {e}")
1859
1658
  return None
@@ -1861,91 +1660,59 @@ class ExplorerWindow(QMainWindow):
1861
1660
  def _update_data_items_with_embedding(self, data_items, embedded_features):
1862
1661
  """Updates AnnotationDataItem objects with embedding results."""
1863
1662
  scale_factor = 4000
1864
- min_vals = np.min(embedded_features, axis=0)
1865
- max_vals = np.max(embedded_features, axis=0)
1663
+ min_vals, max_vals = np.min(embedded_features, axis=0), np.max(embedded_features, axis=0)
1866
1664
  range_vals = max_vals - min_vals
1867
-
1868
1665
  for i, item in enumerate(data_items):
1869
- # Normalize coordinates for consistent display
1870
1666
  norm_x = (embedded_features[i, 0] - min_vals[0]) / range_vals[0] if range_vals[0] > 0 else 0.5
1871
1667
  norm_y = (embedded_features[i, 1] - min_vals[1]) / range_vals[1] if range_vals[1] > 0 else 0.5
1872
- # Scale and center the points in the view
1873
1668
  item.embedding_x = (norm_x * scale_factor) - (scale_factor / 2)
1874
1669
  item.embedding_y = (norm_y * scale_factor) - (scale_factor / 2)
1875
1670
 
1876
1671
  def run_embedding_pipeline(self):
1877
- """
1878
- Orchestrates the feature extraction and dimensionality reduction pipeline.
1879
- This version correctly re-runs reduction on cached features when parameters change.
1880
- """
1881
- if not self.current_data_items:
1882
- print("No items to process for embedding.")
1672
+ """Orchestrates the feature extraction and dimensionality reduction pipeline."""
1673
+ if not self.current_data_items:
1883
1674
  return
1884
-
1885
- # 1. Get current parameters from the UI
1886
- embedding_params = self.embedding_settings_widget.get_embedding_parameters()
1887
- model_info = self.model_settings_widget.get_selected_model() # Now returns tuple (model_name, feature_mode)
1888
1675
 
1889
- # Unpack model info - both model_name and feature_mode are used for caching
1890
- if isinstance(model_info, tuple):
1891
- selected_model, selected_feature_mode = model_info
1892
- # Create a unique cache key that includes both model and feature mode
1893
- cache_key = f"{selected_model}_{selected_feature_mode}"
1894
- else:
1895
- selected_model = model_info
1896
- selected_feature_mode = "default"
1897
- cache_key = f"{selected_model}_{selected_feature_mode}"
1676
+ self.annotation_viewer.clear_selection()
1677
+ if self.annotation_viewer.isolated_mode:
1678
+ self.annotation_viewer.show_all_annotations()
1898
1679
 
1899
- technique = embedding_params['technique']
1680
+ self.embedding_viewer.render_selection_from_ids(set())
1681
+ self.update_button_states()
1682
+
1683
+ # Get embedding parameters and selected model; create a cache key to avoid re-computing features
1684
+ embedding_params = self.embedding_settings_widget.get_embedding_parameters()
1685
+ model_info = self.model_settings_widget.get_selected_model()
1686
+ selected_model, selected_feature_mode = model_info if isinstance(model_info, tuple) else (model_info, "default")
1687
+ cache_key = f"{selected_model}_{selected_feature_mode}"
1900
1688
 
1901
1689
  QApplication.setOverrideCursor(Qt.WaitCursor)
1902
1690
  progress_bar = ProgressBar(self, "Generating Embedding Visualization")
1903
1691
  progress_bar.show()
1904
-
1905
1692
  try:
1906
- # 2. Decide whether to use cached features or extract new ones
1907
- # Now checks both model name AND feature mode
1908
1693
  if self.current_features is None or cache_key != self.current_feature_generating_model:
1909
- # SLOW PATH: Extract and cache new features
1910
1694
  features, valid_data_items = self._extract_features(self.current_data_items, progress_bar=progress_bar)
1911
-
1912
1695
  self.current_features = features
1913
- self.current_feature_generating_model = cache_key # Store the complete cache key
1696
+ self.current_feature_generating_model = cache_key
1914
1697
  self.current_data_items = valid_data_items
1915
1698
  self.annotation_viewer.update_annotations(self.current_data_items)
1916
1699
  else:
1917
- # FAST PATH: Use existing features
1918
- print("Using cached features. Skipping feature extraction.")
1919
1700
  features = self.current_features
1920
-
1921
- if features is None or len(features) == 0:
1922
- print("No valid features available. Aborting embedding.")
1701
+
1702
+ if features is None or len(features) == 0:
1923
1703
  return
1924
1704
 
1925
- # 3. Run dimensionality reduction with the latest parameters
1926
- progress_bar.set_busy_mode(f"Running {technique} dimensionality reduction...")
1705
+ progress_bar.set_busy_mode("Running dimensionality reduction...")
1927
1706
  embedded_features = self._run_dimensionality_reduction(features, embedding_params)
1928
- progress_bar.update_progress()
1929
1707
 
1930
- if embedded_features is None:
1708
+ if embedded_features is None:
1931
1709
  return
1932
1710
 
1933
- # 4. Update the visualization with the new 2D layout
1934
1711
  progress_bar.set_busy_mode("Updating visualization...")
1935
1712
  self._update_data_items_with_embedding(self.current_data_items, embedded_features)
1936
-
1937
1713
  self.embedding_viewer.update_embeddings(self.current_data_items)
1938
1714
  self.embedding_viewer.show_embedding()
1939
1715
  self.embedding_viewer.fit_view_to_points()
1940
- progress_bar.update_progress()
1941
-
1942
- print(f"Successfully generated embedding for {len(self.current_data_items)} annotations using {technique}")
1943
-
1944
- except Exception as e:
1945
- print(f"Error during embedding pipeline: {e}")
1946
- self.embedding_viewer.clear_points()
1947
- self.embedding_viewer.show_placeholder()
1948
-
1949
1716
  finally:
1950
1717
  QApplication.restoreOverrideCursor()
1951
1718
  progress_bar.finish_progress()
@@ -1956,20 +1723,11 @@ class ExplorerWindow(QMainWindow):
1956
1723
  """Refresh display: filter data and update annotation viewer."""
1957
1724
  QApplication.setOverrideCursor(Qt.WaitCursor)
1958
1725
  try:
1959
- # Get filtered data and store for potential embedding
1960
1726
  self.current_data_items = self.get_filtered_data_items()
1961
-
1962
- # --- MODIFIED: Invalidate the feature cache ---
1963
- # Since the filtered items have changed, the old features are no longer valid.
1964
- self.current_features = None
1965
-
1966
- # Update annotation viewer with filtered data
1727
+ self.current_features = None
1967
1728
  self.annotation_viewer.update_annotations(self.current_data_items)
1968
-
1969
- # Clear embedding viewer and show placeholder, as it is now out of sync
1970
1729
  self.embedding_viewer.clear_points()
1971
1730
  self.embedding_viewer.show_placeholder()
1972
-
1973
1731
  finally:
1974
1732
  QApplication.restoreOverrideCursor()
1975
1733
 
@@ -1980,88 +1738,88 @@ class ExplorerWindow(QMainWindow):
1980
1738
  self.update_button_states()
1981
1739
 
1982
1740
  def clear_preview_changes(self):
1983
- """Clear all preview changes and revert to original labels."""
1741
+ """
1742
+ Clears all preview changes in both viewers, including label changes
1743
+ and items marked for deletion.
1744
+ """
1984
1745
  if hasattr(self, 'annotation_viewer'):
1985
1746
  self.annotation_viewer.clear_preview_states()
1986
- self.update_button_states()
1987
- print("Cleared all preview changes")
1747
+
1748
+ if hasattr(self, 'embedding_viewer'):
1749
+ self.embedding_viewer.refresh_points()
1750
+
1751
+ # After reverting all changes, update the button states
1752
+ self.update_button_states()
1753
+ print("Cleared all pending changes.")
1988
1754
 
1989
1755
  def update_button_states(self):
1990
1756
  """Update the state of Clear Preview and Apply buttons."""
1991
- has_changes = (hasattr(self, 'annotation_viewer') and self.annotation_viewer.has_preview_changes())
1992
-
1757
+ has_changes = self.annotation_viewer.has_preview_changes()
1993
1758
  self.clear_preview_button.setEnabled(has_changes)
1994
1759
  self.apply_button.setEnabled(has_changes)
1995
-
1996
1760
  summary = self.annotation_viewer.get_preview_changes_summary()
1997
1761
  self.clear_preview_button.setToolTip(f"Clear all preview changes - {summary}")
1998
1762
  self.apply_button.setToolTip(f"Apply changes - {summary}")
1999
1763
 
2000
1764
  def apply(self):
2001
- """Apply any modifications to the actual annotations."""
2002
- # Make cursor busy
1765
+ """
1766
+ Apply all pending changes, including label modifications and deletions,
1767
+ to the main application's data.
1768
+ """
2003
1769
  QApplication.setOverrideCursor(Qt.WaitCursor)
2004
-
2005
1770
  try:
2006
- applied_annotations = self.annotation_viewer.apply_preview_changes_permanently()
2007
-
2008
- if applied_annotations:
2009
- # Find which data items were affected and tell their visual components to update
2010
- changed_ids = {ann.id for ann in applied_annotations}
2011
- for item in self.current_data_items:
2012
- if item.annotation.id in changed_ids:
2013
- # Update annotation widget in the grid
2014
- widget = self.annotation_viewer.annotation_widgets_by_id.get(item.annotation.id)
2015
- if widget:
2016
- widget.update() # Repaint to show new permanent color
2017
-
2018
- # Update point in the embedding viewer
2019
- point = self.embedding_viewer.points_by_id.get(item.annotation.id)
2020
- if point:
2021
- point.update() # Repaint to show new permanent color
2022
-
2023
- # Update sorting if currently sorting by label
2024
- if self.annotation_viewer.sort_combo.currentText() == "Label":
2025
- self.annotation_viewer.recalculate_widget_positions()
2026
-
2027
- # Update the main application's data
2028
- affected_images = {ann.image_path for ann in applied_annotations}
2029
- for image_path in affected_images:
2030
- self.image_window.update_image_annotations(image_path)
2031
- self.annotation_window.load_annotations()
2032
-
2033
- # Clear selection and button states
2034
- self.annotation_viewer.clear_selection()
2035
- self.embedding_viewer.render_selection_from_ids(set()) # Clear embedding selection
2036
- self.update_button_states()
2037
-
2038
- print(f"Applied changes to {len(applied_annotations)} annotation(s)")
2039
- else:
2040
- print("No preview changes to apply")
1771
+ # Separate items into those to be deleted and those to be kept
1772
+ items_to_delete = [item for item in self.current_data_items if item.is_marked_for_deletion()]
1773
+ items_to_keep = [item for item in self.current_data_items if not item.is_marked_for_deletion()]
1774
+
1775
+ # --- 1. Process Deletions ---
1776
+ deleted_annotations = []
1777
+ if items_to_delete:
1778
+ deleted_annotations = [item.annotation for item in items_to_delete]
1779
+ print(f"Permanently deleting {len(deleted_annotations)} annotation(s).")
1780
+ self.annotation_window.delete_annotations(deleted_annotations)
1781
+
1782
+ # --- 2. Process Label Changes on remaining items ---
1783
+ applied_label_changes = []
1784
+ for item in items_to_keep:
1785
+ if item.apply_preview_permanently():
1786
+ applied_label_changes.append(item.annotation)
1787
+
1788
+ # --- 3. Update UI if any changes were made ---
1789
+ if not deleted_annotations and not applied_label_changes:
1790
+ print("No pending changes to apply.")
1791
+ return
1792
+
1793
+ # Update the Explorer's internal list of data items
1794
+ self.current_data_items = items_to_keep
1795
+
1796
+ # Update the main application's data and UI
1797
+ all_affected_annotations = deleted_annotations + applied_label_changes
1798
+ affected_images = {ann.image_path for ann in all_affected_annotations}
1799
+ for image_path in affected_images:
1800
+ self.image_window.update_image_annotations(image_path)
1801
+ self.annotation_window.load_annotations()
1802
+
1803
+ # Refresh the annotation viewer since its underlying data has changed
1804
+ self.annotation_viewer.update_annotations(self.current_data_items)
1805
+
1806
+ # Reset selections and button states
1807
+ self.embedding_viewer.render_selection_from_ids(set())
1808
+ self.update_label_window_selection()
1809
+ self.update_button_states()
1810
+
1811
+ print(f"Applied changes successfully.")
2041
1812
 
2042
1813
  except Exception as e:
2043
1814
  print(f"Error applying modifications: {e}")
2044
1815
  finally:
2045
- # Restore cursor
2046
1816
  QApplication.restoreOverrideCursor()
2047
-
1817
+
2048
1818
  def _cleanup_resources(self):
2049
- """
2050
- Clean up heavy resources like the loaded model and clear GPU cache.
2051
- This is called when the window is closed to free up memory.
2052
- """
2053
- print("Cleaning up Explorer resources...")
2054
-
2055
- # Reset model and feature caches
2056
- self.imgsz = 128
1819
+ """Clean up resources."""
2057
1820
  self.loaded_model = None
2058
1821
  self.model_path = ""
2059
1822
  self.current_features = None
2060
1823
  self.current_feature_generating_model = ""
2061
-
2062
- # Clear CUDA cache if available to free up GPU memory
2063
1824
  if torch.cuda.is_available():
2064
- print("Clearing CUDA cache.")
2065
- torch.cuda.empty_cache()
2066
-
2067
- print("Cleanup complete.")
1825
+ torch.cuda.empty_cache()