coralnet-toolbox 0.0.66__py2.py3-none-any.whl → 0.0.67__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +1 -1
  2. coralnet_toolbox/Annotations/QtPatchAnnotation.py +1 -1
  3. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +1 -1
  4. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +1 -1
  5. coralnet_toolbox/AutoDistill/QtDeployModel.py +4 -0
  6. coralnet_toolbox/Explorer/QtAnnotationDataItem.py +97 -0
  7. coralnet_toolbox/Explorer/QtAnnotationImageWidget.py +183 -0
  8. coralnet_toolbox/Explorer/QtEmbeddingPointItem.py +30 -0
  9. coralnet_toolbox/Explorer/QtExplorer.py +2067 -0
  10. coralnet_toolbox/Explorer/QtSettingsWidgets.py +490 -0
  11. coralnet_toolbox/Explorer/__init__.py +7 -0
  12. coralnet_toolbox/IO/QtImportViscoreAnnotations.py +2 -4
  13. coralnet_toolbox/IO/QtOpenProject.py +2 -1
  14. coralnet_toolbox/Icons/magic.png +0 -0
  15. coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +4 -0
  16. coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +4 -0
  17. coralnet_toolbox/MachineLearning/TrainModel/QtClassify.py +1 -1
  18. coralnet_toolbox/QtConfidenceWindow.py +2 -23
  19. coralnet_toolbox/QtEventFilter.py +2 -2
  20. coralnet_toolbox/QtLabelWindow.py +23 -8
  21. coralnet_toolbox/QtMainWindow.py +81 -2
  22. coralnet_toolbox/QtProgressBar.py +12 -0
  23. coralnet_toolbox/SAM/QtDeployGenerator.py +4 -0
  24. coralnet_toolbox/__init__.py +1 -1
  25. coralnet_toolbox/utilities.py +24 -0
  26. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/METADATA +2 -1
  27. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/RECORD +31 -24
  28. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/WHEEL +0 -0
  29. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/entry_points.txt +0 -0
  30. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/licenses/LICENSE.txt +0 -0
  31. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2067 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import warnings
5
+
6
+ from ultralytics import YOLO
7
+
8
+ from coralnet_toolbox.MachineLearning.Community.cfg import get_available_configs
9
+
10
+ from coralnet_toolbox.Icons import get_icon
11
+ from coralnet_toolbox.utilities import pixmap_to_numpy
12
+
13
+ from PyQt5.QtGui import QIcon, QPen, QColor, QPainter, QImage, QBrush, QPainterPath, QPolygonF, QMouseEvent
14
+ from PyQt5.QtCore import Qt, QTimer, QSize, QRect, QRectF, QPointF, pyqtSignal, QSignalBlocker, pyqtSlot
15
+
16
+ from PyQt5.QtWidgets import (QVBoxLayout, QHBoxLayout, QGraphicsView, QScrollArea,
17
+ QGraphicsScene, QPushButton, QComboBox, QLabel, QWidget, QGridLayout,
18
+ QMainWindow, QSplitter, QGroupBox, QFormLayout,
19
+ QSpinBox, QGraphicsEllipseItem, QGraphicsItem, QSlider,
20
+ QListWidget, QDoubleSpinBox, QApplication, QStyle,
21
+ QGraphicsRectItem, QRubberBand, QStyleOptionGraphicsItem,
22
+ QTabWidget, QLineEdit, QFileDialog)
23
+
24
+ from .QtAnnotationDataItem import AnnotationDataItem
25
+ from .QtEmbeddingPointItem import EmbeddingPointItem
26
+ from .QtAnnotationImageWidget import AnnotationImageWidget
27
+ from .QtSettingsWidgets import ModelSettingsWidget
28
+ from .QtSettingsWidgets import EmbeddingSettingsWidget
29
+ from .QtSettingsWidgets import AnnotationSettingsWidget
30
+
31
+ from coralnet_toolbox.QtProgressBar import ProgressBar
32
+
33
+ try:
34
+ from sklearn.preprocessing import StandardScaler
35
+ from sklearn.decomposition import PCA
36
+ from sklearn.manifold import TSNE
37
+ from umap import UMAP
38
+ except ImportError:
39
+ print("Warning: sklearn or umap not installed. Some features may be unavailable.")
40
+ StandardScaler = None
41
+ PCA = None
42
+ TSNE = None
43
+ UMAP = None
44
+
45
+
46
# Silence DeprecationWarnings; note the warnings filter list is global, so this
# applies to the whole process, not only to this module.
warnings.filterwarnings("ignore", category=DeprecationWarning)


# ----------------------------------------------------------------------------------------------------------------------
# Constants
# ----------------------------------------------------------------------------------------------------------------------


# Embedding point markers. Points are created with ItemIgnoresTransformations,
# so these sizes do not change while the view is zoomed.
POINT_SIZE, POINT_WIDTH = 15, 3  # marker width/height, marker outline pen width
56
+
57
+
58
+ # ----------------------------------------------------------------------------------------------------------------------
59
+ # Viewers
60
+ # ----------------------------------------------------------------------------------------------------------------------
61
+
62
+
63
class EmbeddingViewer(QWidget):
    """Interactive 2-D embedding plot with zooming, panning and point selection.

    This is a plain QWidget (not a QGraphicsView subclass). It owns a header bar
    with a "Home" button, a QGraphicsView/QGraphicsScene pair holding one
    EmbeddingPointItem per annotation, and a placeholder label shown when there
    is nothing to plot. The mouse and wheel handlers defined below are installed
    on the inner graphics view.

    Interaction:
        * Left click: default QGraphicsView single-item selection.
        * Ctrl + left click on a point: toggle that point's selection.
        * Ctrl + left drag on empty space: rubber-band selection that keeps the
          selection present when the drag started.
        * Right drag: pan the view.
        * Left double click: clear the selection and emit reset_view_requested.
        * Mouse wheel: zoom around the cursor.
    """

    # Emitted with the IDs of *all* currently selected annotations
    selection_changed = pyqtSignal(list)
    # Emitted on a left double click so the owning window can reset its view
    reset_view_requested = pyqtSignal()

    def __init__(self, parent=None):
        """Build the scene, the view, the header and the selection animation.

        Args:
            parent: Optional parent widget; also stored as ``explorer_window``.
        """
        super(EmbeddingViewer, self).__init__(parent)
        self.explorer_window = parent

        # Scene with a large fixed extent so points can be placed anywhere in it
        self.graphics_scene = QGraphicsScene()
        self.graphics_scene.setSceneRect(-5000, -5000, 10000, 10000)

        # The view that actually renders the scene
        self.graphics_view = QGraphicsView(self.graphics_scene)
        self.graphics_view.setRenderHint(QPainter.Antialiasing)
        self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
        self.graphics_view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.graphics_view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.graphics_view.setMinimumHeight(200)

        # Custom rubber-band state (used only for Ctrl + drag on empty space)
        self.rubber_band = None
        self.rubber_band_origin = QPointF()
        self.selection_at_press = None

        self.points_by_id = {}  # annotation ID -> EmbeddingPointItem
        self.previous_selection_ids = set()  # last reported selection, used to detect changes

        # "Marching ants" animation for selected points, ticking every 100 ms
        self.animation_offset = 0
        self.animation_timer = QTimer()
        self.animation_timer.timeout.connect(self.animate_selection)
        self.animation_timer.setInterval(100)

        self.graphics_scene.selectionChanged.connect(self.on_selection_changed)

        # Header bar, graphics view and placeholder
        self.setup_ui()

        # Monkey-patch the view's event handlers so they run the logic below
        self.graphics_view.mousePressEvent = self.mousePressEvent
        self.graphics_view.mouseDoubleClickEvent = self.mouseDoubleClickEvent
        self.graphics_view.mouseReleaseEvent = self.mouseReleaseEvent
        self.graphics_view.mouseMoveEvent = self.mouseMoveEvent
        self.graphics_view.wheelEvent = self.wheelEvent

    def setup_ui(self):
        """Lay out the header (Home button), the graphics view and the placeholder label."""
        layout = QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)

        header_layout = QHBoxLayout()

        self.home_button = QPushButton("Home")
        self.home_button.setToolTip("Reset view to fit all points")
        self.home_button.clicked.connect(self.reset_view)
        header_layout.addWidget(self.home_button)

        # Keep any future header controls right-aligned
        header_layout.addStretch()

        layout.addLayout(header_layout)
        layout.addWidget(self.graphics_view)

        # Shown instead of the view while there is no embedding to display
        self.placeholder_label = QLabel(
            "No embedding data available.\nPress 'Apply Embedding' to generate visualization."
        )
        self.placeholder_label.setAlignment(Qt.AlignCenter)
        self.placeholder_label.setStyleSheet("color: gray; font-size: 14px;")
        layout.addWidget(self.placeholder_label)

        self.show_placeholder()

    def reset_view(self):
        """Reset the view to fit all embedding points (Home button slot)."""
        self.fit_view_to_points()

    def show_placeholder(self):
        """Hide the graphics view, show the placeholder and disable the Home button."""
        self.graphics_view.setVisible(False)
        self.placeholder_label.setVisible(True)
        self.home_button.setEnabled(False)

    def show_embedding(self):
        """Show the graphics view, hide the placeholder and enable the Home button."""
        self.graphics_view.setVisible(True)
        self.placeholder_label.setVisible(False)
        self.home_button.setEnabled(True)

    # --- Thin delegates to the inner QGraphicsView -------------------------------------------------------------------

    def setRenderHint(self, hint):
        self.graphics_view.setRenderHint(hint)

    def setDragMode(self, mode):
        self.graphics_view.setDragMode(mode)

    def setTransformationAnchor(self, anchor):
        self.graphics_view.setTransformationAnchor(anchor)

    def setResizeAnchor(self, anchor):
        self.graphics_view.setResizeAnchor(anchor)

    def mapToScene(self, point):
        return self.graphics_view.mapToScene(point)

    def scale(self, sx, sy):
        self.graphics_view.scale(sx, sy)

    def translate(self, dx, dy):
        self.graphics_view.translate(dx, dy)

    def fitInView(self, rect, aspect_ratio):
        self.graphics_view.fitInView(rect, aspect_ratio)

    # --- Mouse / wheel handling (installed on graphics_view) ---------------------------------------------------------

    def mousePressEvent(self, event):
        """Start a Ctrl-toggle, a Ctrl rubber band, a right-button pan or a plain click."""
        if event.button() == Qt.LeftButton and event.modifiers() == Qt.ControlModifier:
            item_at_pos = self.graphics_view.itemAt(event.pos())
            if isinstance(item_at_pos, EmbeddingPointItem):
                # Ctrl + click on a point toggles just that point
                self.graphics_view.setDragMode(QGraphicsView.NoDrag)
                item_at_pos.setSelected(not item_at_pos.isSelected())
                return

            # Ctrl + click on empty space starts an additive rubber band; remember
            # the selection so it can be re-applied on every move
            self.selection_at_press = set(self.graphics_scene.selectedItems())
            self.graphics_view.setDragMode(QGraphicsView.NoDrag)
            self.rubber_band_origin = self.graphics_view.mapToScene(event.pos())
            self.rubber_band = QGraphicsRectItem(QRectF(self.rubber_band_origin, self.rubber_band_origin))
            self.rubber_band.setPen(QPen(QColor(0, 100, 255), 1, Qt.DotLine))
            self.rubber_band.setBrush(QBrush(QColor(0, 100, 255, 50)))
            self.graphics_scene.addItem(self.rubber_band)

        elif event.button() == Qt.RightButton:
            # Pan with the right button by feeding a synthetic left press to
            # QGraphicsView's ScrollHandDrag implementation
            self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
            left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
            QGraphicsView.mousePressEvent(self.graphics_view, left_event)
        else:
            # Plain click: default single-item selection
            self.graphics_view.setDragMode(QGraphicsView.NoDrag)
            QGraphicsView.mousePressEvent(self.graphics_view, event)

    def mouseDoubleClickEvent(self, event):
        """Left double click clears the selection and requests a view reset."""
        if event.button() == Qt.LeftButton:
            if self.graphics_scene.selectedItems():
                self.graphics_scene.clearSelection()  # triggers on_selection_changed
            self.reset_view_requested.emit()
            event.accept()
        else:
            # QWidget's default implementation is documented to call
            # mousePressEvent(), which is overridden above, so the second
            # click of e.g. a right double click still starts a pan.
            super().mouseDoubleClickEvent(event)

    def mouseMoveEvent(self, event):
        """Update the rubber band selection or continue a right-button pan."""
        if self.rubber_band:
            current_pos = self.graphics_view.mapToScene(event.pos())
            self.rubber_band.setRect(QRectF(self.rubber_band_origin, current_pos).normalized())

            path = QPainterPath()
            path.addRect(self.rubber_band.rect())

            # Compound update: suppress intermediate selectionChanged signals
            self.graphics_scene.blockSignals(True)

            # setSelectionArea replaces the selection with the items in the band...
            self.graphics_scene.setSelectionArea(path)

            # ...so re-add whatever was selected when the drag started
            if self.selection_at_press:
                for item in self.selection_at_press:
                    item.setSelected(True)

            self.graphics_scene.blockSignals(False)
            self.on_selection_changed()

        elif event.buttons() == Qt.RightButton:
            left_event = QMouseEvent(event.type(),
                                     event.localPos(),
                                     Qt.LeftButton,
                                     Qt.LeftButton,
                                     event.modifiers())
            QGraphicsView.mouseMoveEvent(self.graphics_view, left_event)
        else:
            QGraphicsView.mouseMoveEvent(self.graphics_view, event)

    def mouseReleaseEvent(self, event):
        """Finish a rubber band, a pan or a plain click."""
        if self.rubber_band:
            self.graphics_scene.removeItem(self.rubber_band)
            self.rubber_band = None
            self.selection_at_press = None

        elif event.button() == Qt.RightButton:
            left_event = QMouseEvent(event.type(),
                                     event.localPos(),
                                     Qt.LeftButton,
                                     Qt.LeftButton,
                                     event.modifiers())
            QGraphicsView.mouseReleaseEvent(self.graphics_view, left_event)
            self.graphics_view.setDragMode(QGraphicsView.NoDrag)
        else:
            QGraphicsView.mouseReleaseEvent(self.graphics_view, event)
            self.graphics_view.setDragMode(QGraphicsView.NoDrag)

    def wheelEvent(self, event):
        """Zoom by a factor of 1.25 per notch, intended to stay centred on the cursor.

        Wheel events with no vertical component (e.g. horizontal touchpad
        scrolling) are ignored; previously they fell into the zoom-out branch.
        """
        delta_y = event.angleDelta().y()
        if delta_y == 0:
            return

        zoom_in_factor = 1.25
        zoom_factor = zoom_in_factor if delta_y > 0 else 1 / zoom_in_factor

        # Anchor manually: scale, then translate by how far the scene point
        # under the cursor moved
        self.graphics_view.setTransformationAnchor(QGraphicsView.NoAnchor)
        self.graphics_view.setResizeAnchor(QGraphicsView.NoAnchor)

        old_pos = self.graphics_view.mapToScene(event.pos())
        self.graphics_view.scale(zoom_factor, zoom_factor)
        new_pos = self.graphics_view.mapToScene(event.pos())

        delta = new_pos - old_pos
        self.graphics_view.translate(delta.x(), delta.y())

    # --- Data / selection ----------------------------------------------------------------------------------------------

    def update_embeddings(self, data_items):
        """Replace all points with one selectable point per data item.

        Args:
            data_items: Iterable of AnnotationDataItem objects providing
                ``embedding_x``, ``embedding_y`` and ``annotation.id``.
        """
        self.clear_points()

        for item in data_items:
            point = EmbeddingPointItem(0, 0, POINT_SIZE, POINT_SIZE)
            point.setPos(item.embedding_x, item.embedding_y)

            # Fill is left to EmbeddingPointItem.paint(); only the outline is set here
            point.setPen(QPen(QColor("black"), POINT_WIDTH))

            point.setFlag(QGraphicsItem.ItemIgnoresTransformations)
            point.setFlag(QGraphicsItem.ItemIsSelectable)

            # Link the point to its shared AnnotationDataItem
            point.setData(0, item)

            self.graphics_scene.addItem(point)
            self.points_by_id[item.annotation.id] = point

    def clear_points(self):
        """Remove every embedding point from the scene and forget them."""
        for point in self.points_by_id.values():
            self.graphics_scene.removeItem(point)
        self.points_by_id.clear()

    def on_selection_changed(self):
        """Sync the data model with the scene selection, emit, and restart the animation."""
        # Guard against being called while the scene is being torn down
        if not self.graphics_scene or not hasattr(self.graphics_scene, 'selectedItems'):
            return

        try:
            selected_items = self.graphics_scene.selectedItems()
        except RuntimeError:
            # Underlying C++ scene already deleted
            return

        current_selection_ids = {item.data(0).annotation.id for item in selected_items}

        if current_selection_ids != self.previous_selection_ids:
            # Push the new state into every shared AnnotationDataItem
            for point_id, point in self.points_by_id.items():
                is_selected = point_id in current_selection_ids
                point.data(0).set_selected(is_selected)

            self.selection_changed.emit(list(current_selection_ids))
            self.previous_selection_ids = current_selection_ids

        if hasattr(self, 'animation_timer') and self.animation_timer:
            self.animation_timer.stop()

        # Restore the plain outline on points that are no longer selected
        for point in self.points_by_id.values():
            if not point.isSelected():
                point.setPen(QPen(QColor("black"), POINT_WIDTH))

        if selected_items and hasattr(self, 'animation_timer') and self.animation_timer:
            self.animation_timer.start()

    def animate_selection(self):
        """Timer slot: draw selected points with a moving dashed outline in a darker tint."""
        if not self.graphics_scene or not hasattr(self.graphics_scene, 'selectedItems'):
            return

        try:
            selected_items = self.graphics_scene.selectedItems()
        except RuntimeError:
            return

        self.animation_offset = (self.animation_offset + 1) % 20

        for item in selected_items:
            darker_color = item.brush().color().darker(150)

            animated_pen = QPen(darker_color, POINT_WIDTH)
            animated_pen.setStyle(Qt.CustomDashLine)
            animated_pen.setDashPattern([1, 2])
            animated_pen.setDashOffset(self.animation_offset)

            item.setPen(animated_pen)

    def render_selection_from_ids(self, selected_ids):
        """Select exactly the points whose annotation IDs are in ``selected_ids``.

        Called by the controller. The scene's selectionChanged signal is blocked
        while points are (de)selected so on_selection_changed is not re-entered
        once per point; it is then called once directly.

        Args:
            selected_ids: Iterable of annotation IDs (set, list, ...).
        """
        # Normalize to a set: O(1) membership tests, and on_selection_changed
        # compares previous_selection_ids with a set, so storing a list here
        # would always look like a change and re-emit selection_changed (which
        # can bounce back through the controller into this method).
        selected_ids = set(selected_ids)

        # Keeps the scene's signals blocked until this method returns
        blocker = QSignalBlocker(self.graphics_scene)

        for ann_id, point in self.points_by_id.items():
            point.setSelected(ann_id in selected_ids)

        self.previous_selection_ids = selected_ids

        # Refresh outlines / animation for the new selection
        self.on_selection_changed()

    def fit_view_to_points(self):
        """Fit all points into the view, or a default 5000x5000 area when there are none."""
        if self.points_by_id:
            self.graphics_view.fitInView(self.graphics_scene.itemsBoundingRect(), Qt.KeepAspectRatio)
        else:
            self.graphics_view.fitInView(-2500, -2500, 5000, 5000, Qt.KeepAspectRatio)
420
+
421
+
422
+ class AnnotationViewer(QScrollArea):
423
+ """Scrollable grid widget for displaying annotation image crops with selection,
424
+ filtering, and isolation support.
425
+ """
426
+
427
+ # Define signals to report changes to the ExplorerWindow
428
+ selection_changed = pyqtSignal(list) # list of changed annotation IDs
429
+ preview_changed = pyqtSignal(list) # list of annotation IDs with new previews
430
+ reset_view_requested = pyqtSignal() # Signal to reset the view to fit all points
431
+
432
def __init__(self, parent=None):
    """Create an empty annotation viewer and build its UI.

    Args:
        parent: Optional parent widget.
    """
    super(AnnotationViewer, self).__init__(parent)

    # Widget bookkeeping
    self.annotation_widgets_by_id = {}  # annotation ID -> AnnotationImageWidget
    self.selected_widgets = []
    self.last_selected_index = -1
    self.current_widget_size = 96  # same as the size slider's initial value

    # Rubber-band selection state
    self.selection_at_press = set()
    self.rubber_band, self.rubber_band_origin = None, None
    self.drag_threshold = 5
    self.mouse_pressed_on_widget = False

    # Label-preview state
    self.preview_label_assignments, self.original_label_assignments = {}, {}

    # Isolate/Focus mode state
    self.isolated_mode = False
    self.isolated_widgets = set()

    self.setup_ui()
453
+
454
def setup_ui(self):
    """Build the toolbar (isolate / show-all, sort, size) above an inner scroll
    area whose content widget hosts the manually positioned annotation widgets.
    """
    # Outer scroll area configuration
    self.setWidgetResizable(True)
    self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
    self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)

    container = QWidget()
    container_layout = QVBoxLayout(container)
    container_layout.setContentsMargins(0, 0, 0, 0)
    container_layout.setSpacing(4)  # small gap between toolbar and content

    # ---- Toolbar ----
    toolbar = QWidget()
    bar = QHBoxLayout(toolbar)
    bar.setContentsMargins(4, 2, 4, 2)

    def make_button(text, icon_file, tip, slot):
        # The icon is optional: it is only applied when it could be loaded
        button = QPushButton(text)
        icon = get_icon(icon_file)
        if not icon.isNull():
            button.setIcon(icon)
        button.setToolTip(tip)
        button.clicked.connect(slot)
        bar.addWidget(button)
        return button

    self.isolate_button = make_button(
        "Isolate Selection", "focus.png", "Hide all non-selected annotations", self.isolate_selection
    )
    self.show_all_button = make_button(
        "Show All", "show_all.png", "Show all filtered annotations", self.show_all_annotations
    )

    bar.addWidget(self._create_separator())

    # Sort controls; items are added before connecting so populating the
    # combo box does not trigger a re-flow
    bar.addWidget(QLabel("Sort By:"))
    self.sort_combo = QComboBox()
    self.sort_combo.addItems(["None", "Label", "Image"])
    self.sort_combo.currentTextChanged.connect(self.on_sort_changed)
    bar.addWidget(self.sort_combo)

    # Push the size controls to the right edge
    bar.addStretch()

    # Size controls; the value is set before connecting for the same reason
    bar.addWidget(QLabel("Size:"))
    self.size_slider = slider = QSlider(Qt.Horizontal)
    slider.setMinimum(32)
    slider.setMaximum(256)
    slider.setValue(96)
    slider.setTickPosition(QSlider.TicksBelow)
    slider.setTickInterval(32)
    slider.valueChanged.connect(self.on_size_changed)
    bar.addWidget(slider)

    self.size_value_label = QLabel("96")
    self.size_value_label.setMinimumWidth(30)
    bar.addWidget(self.size_value_label)

    container_layout.addWidget(toolbar)

    # ---- Content area: its own scroll area around the content widget ----
    self.content_widget = QWidget()
    inner_scroll = QScrollArea()
    inner_scroll.setWidgetResizable(True)
    inner_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
    inner_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
    inner_scroll.setWidget(self.content_widget)
    container_layout.addWidget(inner_scroll)

    self.setWidget(container)

    # Start with the toolbar buttons matching the (empty) selection
    self._update_toolbar_state()
533
+
534
@pyqtSlot()
def isolate_selection(self):
    """Hide every annotation widget that is not part of the current selection.

    Does nothing when nothing is selected or isolation is already active.
    """
    if self.isolated_mode or not self.selected_widgets:
        return

    self.isolated_widgets = set(self.selected_widgets)
    content = self.content_widget
    content.setUpdatesEnabled(False)
    try:
        outsiders = (w for w in self.annotation_widgets_by_id.values() if w not in self.isolated_widgets)
        for widget in outsiders:
            widget.hide()
        self.isolated_mode = True
        self.recalculate_widget_positions()
    finally:
        # Always resume painting, even if the re-flow failed
        content.setUpdatesEnabled(True)

    self._update_toolbar_state()
552
+
553
@pyqtSlot()
def show_all_annotations(self):
    """Leave isolation mode and make every annotation widget visible again."""
    if not self.isolated_mode:
        return

    self.isolated_mode = False
    self.isolated_widgets.clear()

    content = self.content_widget
    content.setUpdatesEnabled(False)
    try:
        for widget in self.annotation_widgets_by_id.values():
            widget.show()
        self.recalculate_widget_positions()
    finally:
        # Always resume painting, even if the re-flow failed
        content.setUpdatesEnabled(True)

    self._update_toolbar_state()
571
+
572
+ def _update_toolbar_state(self):
573
+ """Updates the visibility and enabled state of the toolbar buttons
574
+ based on the current selection and isolation mode.
575
+ """
576
+ selection_exists = bool(self.selected_widgets)
577
+
578
+ if self.isolated_mode:
579
+ self.isolate_button.hide()
580
+ self.show_all_button.show()
581
+ self.show_all_button.setEnabled(True)
582
+ else:
583
+ self.isolate_button.show()
584
+ self.show_all_button.hide()
585
+ self.isolate_button.setEnabled(selection_exists)
586
+
587
def _create_separator(self):
    """Return a gray "|" label that acts as a vertical divider in the toolbar."""
    divider = QLabel("|")
    divider.setStyleSheet("color: gray; margin: 0 5px;")
    return divider
592
+
593
def on_sort_changed(self, sort_type):
    """Slot for the sort combo box: re-flow the widgets in the new order.

    Args:
        sort_type: Newly selected sort text. Not used directly; the re-flow
            reads the current value from the combo box itself.
    """
    return self.recalculate_widget_positions() and None
596
+
597
+ def _get_sorted_widgets(self):
598
+ """Get widgets sorted according to the current sort setting."""
599
+ sort_type = self.sort_combo.currentText()
600
+
601
+ if sort_type == "None":
602
+ return list(self.annotation_widgets_by_id.values())
603
+
604
+ widgets = list(self.annotation_widgets_by_id.values())
605
+
606
+ if sort_type == "Label":
607
+ widgets.sort(key=lambda w: w.data_item.effective_label.short_label_code)
608
+ elif sort_type == "Image":
609
+ widgets.sort(key=lambda w: os.path.basename(w.data_item.annotation.image_path))
610
+
611
+ return widgets
612
+
613
+ def _group_widgets_by_sort_key(self, widgets):
614
+ """Group widgets by the current sort key and return groups with headers."""
615
+ sort_type = self.sort_combo.currentText()
616
+
617
+ if sort_type == "None":
618
+ return [("", widgets)]
619
+
620
+ groups = []
621
+ current_group = []
622
+ current_key = None
623
+
624
+ for widget in widgets:
625
+ if sort_type == "Label":
626
+ key = widget.data_item.effective_label.short_label_code
627
+ elif sort_type == "Image":
628
+ key = os.path.basename(widget.data_item.annotation.image_path)
629
+ else:
630
+ key = ""
631
+
632
+ if current_key != key:
633
+ if current_group:
634
+ groups.append((current_key, current_group))
635
+ current_group = [widget]
636
+ current_key = key
637
+ else:
638
+ current_group.append(widget)
639
+
640
+ if current_group:
641
+ groups.append((current_key, current_group))
642
+
643
+ return groups
644
+
645
+ def _clear_separator_labels(self):
646
+ """Remove any existing group header labels."""
647
+ if hasattr(self, '_group_headers'):
648
+ for header in self._group_headers:
649
+ header.setParent(None)
650
+ header.deleteLater()
651
+ self._group_headers = []
652
+
653
def _create_group_header(self, text):
    """Create, show and register a bold header label for one sort group.

    The label is parented to the content widget, fixed at 30 px tall and at
    least as wide as the viewport minus 20 px. It is appended to
    ``_group_headers`` so it can be removed on the next re-flow.

    Args:
        text: Header text (the group's sort key).

    Returns:
        The new QLabel.
    """
    if not hasattr(self, '_group_headers'):
        self._group_headers = []

    label = QLabel(text)
    label.setParent(self.content_widget)
    label.setStyleSheet(
        "\n"
        "QLabel {\n"
        "font-weight: bold;\n"
        "font-size: 12px;\n"
        "color: #555;\n"
        "background-color: #f0f0f0;\n"
        "border: 1px solid #ccc;\n"
        "border-radius: 3px;\n"
        "padding: 5px 8px;\n"
        "margin: 2px 0px;\n"
        "}\n"
    )
    label.setFixedHeight(30)
    label.setMinimumWidth(self.viewport().width() - 20)
    label.show()

    self._group_headers.append(label)
    return label
678
+
679
def on_size_changed(self, value):
    """Resize every annotation widget to the slider's value and re-flow the layout.

    Odd values are rounded down to the nearest even number.

    Args:
        value: New size from the size slider.
    """
    if value % 2 != 0:
        value -= 1
    self.current_widget_size = value
    self.size_value_label.setText(str(value))

    # Suspend painting while many widgets resize. Re-enable it in `finally`
    # (as isolate_selection/show_all_annotations do) so a failing widget
    # cannot leave the content area permanently frozen.
    self.content_widget.setUpdatesEnabled(False)
    try:
        for widget in self.annotation_widgets_by_id.values():
            widget.update_height(value)
    finally:
        self.content_widget.setUpdatesEnabled(True)

    # Sizes changed, so positions must be recomputed
    self.recalculate_widget_positions()
694
+
695
def recalculate_grid_layout(self):
    """Arrange the annotation widgets in a grid sized to the viewport width.

    Legacy entry point. The visible setup_ui never creates ``grid_layout``;
    widgets are positioned manually by recalculate_widget_positions. When no
    grid layout exists, this delegates to that flow layout instead of raising
    AttributeError.
    """
    if not self.annotation_widgets_by_id:
        return

    grid_layout = getattr(self, 'grid_layout', None)
    if grid_layout is None:
        self.recalculate_widget_positions()
        return

    available_width = self.viewport().width() - 20
    widget_width = self.current_widget_size + grid_layout.spacing()
    cols = max(1, available_width // widget_width)

    for i, widget in enumerate(self.annotation_widgets_by_id.values()):
        grid_layout.addWidget(widget, i // cols, i % cols)
706
+
707
def recalculate_widget_positions(self):
    """Lay out the visible annotation widgets left-to-right in wrapping rows.

    Widgets are ordered by the current sort mode. When sorting is enabled,
    each non-empty group key gets a header label on its own row. Finally the
    content widget's minimum size is set so its scroll area can scroll.
    """
    # Drop old headers first so they never linger when there is nothing (or
    # nothing visible) to lay out; the early returns below used to skip this.
    self._clear_separator_labels()

    if not self.annotation_widgets_by_id:
        self.content_widget.setMinimumSize(1, 1)
        return

    # Widgets hidden by isolation mode take no space
    visible_widgets = [w for w in self._get_sorted_widgets() if not w.isHidden()]
    if not visible_widgets:
        self.content_widget.setMinimumSize(1, 1)
        return

    groups = self._group_widgets_by_sort_key(visible_widgets)
    headers_enabled = self.sort_combo.currentText() != "None"

    # Gap between items scales with the thumbnail size, never below 5
    spacing = max(5, int(self.current_widget_size * 0.08))

    # The content widget sits inside its own inner scroll area, whose viewport
    # (its parent) is narrower than this viewer's viewport by its frame and
    # vertical scrollbar. Measuring the outer viewport clipped the right-most
    # column, so use the content widget's actual container when available.
    content_viewport = self.content_widget.parentWidget()
    if content_viewport is not None:
        available_width = content_viewport.width()
    else:
        available_width = self.viewport().width()

    x, y = spacing, spacing
    max_height_in_row = 0

    for group_name, group_widgets in groups:
        if group_name and headers_enabled:
            # Headers always start on a fresh row
            if x > spacing:
                x = spacing
                y += max_height_in_row + spacing
                max_height_in_row = 0

            header_label = self._create_group_header(group_name)
            header_label.move(x, y)

            # Widgets of this group start on the row below the header
            y += header_label.height() + spacing
            x = spacing
            max_height_in_row = 0

        for widget in group_widgets:
            widget_size = widget.size()

            # Wrap, unless this widget is already the first one in its row
            if x > spacing and x + widget_size.width() > available_width:
                x = spacing
                y += max_height_in_row + spacing
                max_height_in_row = 0

            widget.move(x, y)
            x += widget_size.width() + spacing
            max_height_in_row = max(max_height_in_row, widget_size.height())

    total_height = y + max_height_in_row + spacing
    self.content_widget.setMinimumSize(available_width, total_height)
771
+
772
def update_annotations(self, data_items):
    """Replace all annotation widgets with fresh ones built from ``data_items``.

    Any active isolation is exited first. Existing widgets are detached and
    scheduled for deletion, and the selection bookkeeping is reset.

    Args:
        data_items: Iterable of AnnotationDataItem objects to display.
    """
    if self.isolated_mode:
        self.show_all_annotations()

    # Tear down the previous widgets
    for stale in self.annotation_widgets_by_id.values():
        stale.setParent(None)
        stale.deleteLater()
    self.annotation_widgets_by_id.clear()
    self.selected_widgets.clear()
    self.last_selected_index = -1

    # One widget per data item, parented to the content area
    for data_item in data_items:
        fresh = AnnotationImageWidget(
            data_item,
            self.current_widget_size,
            annotation_viewer=self,
            parent=self.content_widget,
        )
        fresh.show()
        self.annotation_widgets_by_id[data_item.annotation.id] = fresh

    self.recalculate_widget_positions()
    self._update_toolbar_state()
803
+
804
def resizeEvent(self, event):
    """Reflow the annotation widgets after a resize, debounced by 100 ms.

    A single-shot timer is created lazily on the first resize. Every resize
    restarts it, so a burst of resize events (e.g. dragging the window edge)
    triggers only one ``recalculate_widget_positions`` call.
    """
    super().resizeEvent(event)
    if not hasattr(self, '_resize_timer'):
        debounce = QTimer(self)
        debounce.setSingleShot(True)
        debounce.timeout.connect(self.recalculate_widget_positions)
        self._resize_timer = debounce
    # (Re)start the countdown; only the last resize in a burst reflows.
    self._resize_timer.start(100)
814
+
815
def mousePressEvent(self, event):
    """Route mouse presses that reach the viewer itself.

    * Plain left-click on empty space (not on an annotation widget) while
      something is selected: clear the selection, emit ``selection_changed``
      with the deselected ids, and swallow the event.
    * Ctrl+left-click: record the state ``mouseMoveEvent`` needs for a
      rubber-band drag (the selection snapshot, the press origin, and whether
      the press landed on an annotation widget), then swallow the event.
    * Right-click: ``event.ignore()`` so Qt may pass it on to the parent.
    * Anything else goes to the base-class handler.
    """
    def _press_hits_annotation(pos):
        # Climb from the child under the cursor up to this viewer, looking for
        # a widget that names this viewer as its owning annotation_viewer.
        node = self.childAt(pos)
        while node and node != self:
            if hasattr(node, 'annotation_viewer') and node.annotation_viewer == self:
                return True
            node = node.parent()
        return False

    button = event.button()

    if button == Qt.RightButton:
        event.ignore()
        return

    if button == Qt.LeftButton:
        if not event.modifiers():
            if not _press_hits_annotation(event.pos()) and self.selected_widgets:
                deselected_ids = [w.data_item.annotation.id for w in self.selected_widgets]
                self.clear_selection()
                self.selection_changed.emit(deselected_ids)
                # Handled here; skipping super() keeps the scroll area from
                # reacting (e.g. starting a drag).
                return
        elif event.modifiers() == Qt.ControlModifier:
            self.selection_at_press = set(self.selected_widgets)
            self.rubber_band_origin = event.pos()
            # Consulted by mouseMoveEvent to decide whether to draw a band.
            self.mouse_pressed_on_widget = _press_hits_annotation(event.pos())
            return

    super().mousePressEvent(event)
871
+
872
def mouseDoubleClickEvent(self, event):
    """Double-left-click: clear the selection, leave isolation, request a reset.

    If anything was selected, it is cleared and ``selection_changed`` is
    emitted with the cleared ids. Then ``reset_view_requested`` is emitted
    and the event accepted. Other buttons go to the base-class handler.
    """
    if event.button() != Qt.LeftButton:
        super().mouseDoubleClickEvent(event)
        return

    if self.selected_widgets:
        cleared_ids = [w.data_item.annotation.id for w in self.selected_widgets]
        self.clear_selection()
        self.selection_changed.emit(cleared_ids)

    if self.isolated_mode:
        self.show_all_annotations()

    # Ask the owning window to reset its own view state too.
    self.reset_view_requested.emit()
    event.accept()
892
+
893
def mouseMoveEvent(self, event):
    """Drive the Ctrl+drag rubber-band selection.

    Only active while a Ctrl+left-button drag that began on empty space is in
    progress (see ``mousePressEvent``); otherwise the event goes to the
    base-class handler. Once the cursor is at least ``self.drag_threshold``
    (Manhattan distance) from the press origin, the rubber band is shown and
    the selection is recomputed on every move. A widget ends up selected if
    it was selected when the drag started or it currently intersects the
    band. Hidden widgets (e.g. those filtered out by isolation mode) are
    never picked up by the band. Emits ``selection_changed`` with the ids
    whose state changed on this move.
    """
    if (self.rubber_band_origin is None
            or event.buttons() != Qt.LeftButton
            or event.modifiers() != Qt.ControlModifier):
        super().mouseMoveEvent(event)
        return

    # The press landed on an annotation widget: let default handling run.
    if self.mouse_pressed_on_widget:
        super().mouseMoveEvent(event)
        return

    # Ignore small jitters so a plain Ctrl+click does not start a band.
    distance = (event.pos() - self.rubber_band_origin).manhattanLength()
    if distance < self.drag_threshold:
        return

    viewport = self.viewport()
    if not self.rubber_band:
        self.rubber_band = QRubberBand(QRubberBand.Rectangle, viewport)

    self.rubber_band.setGeometry(QRect(self.rubber_band_origin, event.pos()).normalized())
    self.rubber_band.show()

    selection_rect = self.rubber_band.geometry()
    content_widget = self.content_widget
    changed_ids = []

    for widget in self.annotation_widgets_by_id.values():
        # Hidden widgets are not repositioned by the layout pass (it only
        # places visible widgets), so their geometry can be stale and overlap
        # visible ones; they must not be selected through the band.
        if widget.isHidden():
            continue

        geometry = widget.geometry()
        # Widget geometry is relative to content_widget; map it into viewport
        # coordinates to compare against the band.
        rect_in_viewport = QRect(
            content_widget.mapTo(viewport, geometry.topLeft()),
            geometry.size(),
        )
        should_be_selected = (widget in self.selection_at_press) or \
            selection_rect.intersects(rect_in_viewport)

        if should_be_selected and not widget.is_selected():
            if self.select_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
        elif not should_be_selected and widget.is_selected():
            if self.deselect_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)

    if changed_ids:
        self.selection_changed.emit(changed_ids)
946
+
947
def mouseReleaseEvent(self, event):
    """Finish a rubber-band drag on left-button release.

    Hides and schedules deletion of the rubber band if it is showing, then
    resets the drag state recorded by ``mousePressEvent`` and accepts the
    event. Releases that do not end a drag go to the base-class handler.
    """
    drag_in_progress = self.rubber_band_origin is not None and event.button() == Qt.LeftButton
    if not drag_in_progress:
        super().mouseReleaseEvent(event)
        return

    band = self.rubber_band
    if band and band.isVisible():
        band.hide()
        band.deleteLater()
        self.rubber_band = None

    # Forget the snapshot/origin so the next drag starts from a clean state.
    self.selection_at_press = set()
    self.rubber_band_origin = None
    self.mouse_pressed_on_widget = False
    event.accept()
964
+
965
def handle_annotation_selection(self, widget, event):
    """Update the selection in response to a click on ``widget``.

    The keyboard modifiers on ``event`` choose the mode:

    * Shift (or Shift+Ctrl): range-select from the selected widget that sits
      furthest along the current widget order to ``widget``. If no click has
      set an anchor yet (``last_selected_index == -1``), only ``widget`` is
      selected.
    * Ctrl: toggle ``widget``.
    * No modifier: select ``widget`` and deselect everything else.

    In isolation mode only visible widgets take part in the ordering. The
    isolation view is refreshed afterwards, and ``selection_changed`` is
    emitted with the ids whose state actually changed. Clicks on widgets not
    in the current list are ignored.
    """
    if self.isolated_mode:
        # Hidden (non-isolated) widgets do not take part in selection ranges.
        widget_list = [w for w in self.annotation_widgets_by_id.values() if not w.isHidden()]
    else:
        widget_list = list(self.annotation_widgets_by_id.values())

    # Build the position lookup once; repeated list.index()/`in` scans made a
    # Shift+click cost O(selected * widgets).
    position_of = {w: i for i, w in enumerate(widget_list)}
    widget_index = position_of.get(widget)
    if widget_index is None:
        return

    modifiers = event.modifiers()
    changed_ids = []

    if modifiers == Qt.ShiftModifier or modifiers == (Qt.ShiftModifier | Qt.ControlModifier):
        # Range selection
        if self.last_selected_index != -1:
            # Anchor on the selected widget appearing furthest along the
            # current list; fall back to just the clicked widget.
            selected_positions = [position_of[w] for w in self.selected_widgets if w in position_of]
            anchor = max(selected_positions) if selected_positions else widget_index
            start, end = min(anchor, widget_index), max(anchor, widget_index)
            for i in range(start, end + 1):
                target = widget_list[i]
                # select_widget returns True only if the state changed
                if self.select_widget(target):
                    changed_ids.append(target.data_item.annotation.id)
        else:
            if self.select_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
            self.last_selected_index = widget_index

    elif modifiers == Qt.ControlModifier:
        # Toggle the clicked widget
        if widget.is_selected():
            if self.deselect_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
        else:
            if self.select_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
        self.last_selected_index = widget_index

    else:
        # Plain click: the clicked widget becomes the only selection
        newly_selected_id = widget.data_item.annotation.id
        for w in list(self.selected_widgets):
            if w.data_item.annotation.id != newly_selected_id:
                if self.deselect_widget(w):
                    changed_ids.append(w.data_item.annotation.id)
        if self.select_widget(widget):
            changed_ids.append(newly_selected_id)
        self.last_selected_index = widget_index

    if self.isolated_mode:
        self._update_isolation()

    if changed_ids:
        self.selection_changed.emit(changed_ids)
1049
+
1050
def _update_isolation(self):
    """Grow the isolated view with the current selection.

    Does nothing outside isolation mode or when nothing is selected; an
    empty selection deliberately keeps the current isolation rather than
    exiting it. Otherwise the selected widgets are added to
    ``self.isolated_widgets``, only isolated widgets are shown, and the
    layout is recomputed with repaints suspended.
    """
    if not self.isolated_mode or not self.selected_widgets:
        return

    # Add to the isolated set rather than replacing it.
    self.isolated_widgets.update(self.selected_widgets)
    self.setUpdatesEnabled(False)
    try:
        keep = self.isolated_widgets
        for candidate in self.annotation_widgets_by_id.values():
            if candidate in keep:
                candidate.show()
            else:
                candidate.hide()
        self.recalculate_widget_positions()
    finally:
        self.setUpdatesEnabled(True)
1072
+
1073
def select_widget(self, widget):
    """Mark ``widget`` and its data item as selected.

    Returns:
        bool: True if the widget was unselected and has now been appended to
        ``self.selected_widgets``; False if it was already selected.
    """
    if widget.is_selected():
        return False
    widget.set_selected(True)
    widget.data_item.set_selected(True)
    self.selected_widgets.append(widget)
    # Keep the label window and toolbar buttons in step with the selection.
    self.update_label_window_selection()
    self._update_toolbar_state()
    return True
1083
+
1084
def deselect_widget(self, widget):
    """Mark ``widget`` and its data item as unselected.

    The widget is removed from ``self.selected_widgets`` if it is listed
    there.

    Returns:
        bool: True if the widget was selected and is now unselected; False if
        it was not selected to begin with.
    """
    if not widget.is_selected():
        return False
    widget.set_selected(False)
    widget.data_item.set_selected(False)
    if widget in self.selected_widgets:
        self.selected_widgets.remove(widget)
    # Keep the label window and toolbar buttons in step with the selection.
    self.update_label_window_selection()
    self._update_toolbar_state()
    return True
1095
+
1096
def clear_selection(self):
    """Deselect every selected widget and refresh the dependent UI.

    Both each widget and its data item are marked unselected, mirroring
    ``deselect_widget``, so data items are not left flagged as selected.
    Then the label window and toolbar are refreshed. This does not emit
    ``selection_changed``; callers that need to notify listeners emit it
    themselves.
    """
    for widget in list(self.selected_widgets):
        widget.set_selected(False)
        widget.data_item.set_selected(False)
    self.selected_widgets.clear()
    self.update_label_window_selection()
    self._update_toolbar_state()
1103
+
1104
def update_label_window_selection(self):
    """Reflect the viewer's selection in the main window's label window.

    Walks up the parent chain to the first ancestor that has a
    ``main_window`` attribute (the explorer window); if there is none,
    nothing happens. Otherwise:

    * empty selection: the active label is deselected;
    * every selected item shares one effective label (compared by ``id``):
      that label becomes active, and ``annotation_window.labelSelected`` is
      emitted with its id unless the first selected item has preview changes;
    * mixed labels: the active label is deselected.

    In every case that reaches the label window, its annotation count is refreshed.
    """
    host = self.parent()
    while host and not hasattr(host, 'main_window'):
        host = host.parent()
    if not host or not hasattr(host, 'main_window'):
        return

    main_window = host.main_window
    label_window = main_window.label_window
    annotation_window = main_window.annotation_window

    items = [w.data_item for w in self.selected_widgets]
    if not items:
        label_window.deselect_active_label()
        label_window.update_annotation_count()
        return

    reference_label = items[0].effective_label
    if all(i.effective_label.id == reference_label.id for i in items):
        label_window.set_active_label(reference_label)
        # Skip the emit when the first item carries an unapplied preview label.
        if not items[0].has_preview_changes():
            annotation_window.labelSelected.emit(reference_label.id)
    else:
        label_window.deselect_active_label()

    label_window.update_annotation_count()
1139
+
1140
def get_selected_annotations(self):
    """Return the ``annotation`` of each selected widget, in selection order."""
    annotations = []
    for selected in self.selected_widgets:
        annotations.append(selected.annotation)
    return annotations
1143
+
1144
def render_selection_from_ids(self, selected_ids):
    """Make the viewer's selection match ``selected_ids``.

    Used when the selection comes from elsewhere (e.g. the embedding view).
    Each widget and its data item are marked selected exactly when the
    widget's annotation id is in ``selected_ids``, which keeps data items in
    sync the same way ``select_widget``/``deselect_widget`` do.
    ``self.selected_widgets`` is rebuilt from the result. In isolation mode a
    non-empty selection is added to the isolated set, and visibility and
    layout are refreshed. Repaints are suspended during the update. This does
    not emit ``selection_changed``.

    Args:
        selected_ids: Container of annotation ids supporting ``in``.
    """
    # Suspend repaints (signals are NOT blocked) so the bulk update draws once.
    self.setUpdatesEnabled(False)
    try:
        for ann_id, widget in self.annotation_widgets_by_id.items():
            is_selected = ann_id in selected_ids
            widget.set_selected(is_selected)
            widget.data_item.set_selected(is_selected)

        # Resync the internal list of selected widgets.
        self.selected_widgets = [w for w in self.annotation_widgets_by_id.values() if w.is_selected()]

        # In isolation mode, ADD to the isolated set instead of replacing it.
        if self.isolated_mode and self.selected_widgets:
            self.isolated_widgets.update(self.selected_widgets)
            for widget in self.annotation_widgets_by_id.values():
                if widget in self.isolated_widgets:
                    widget.show()
                else:
                    widget.hide()
            self.recalculate_widget_positions()
    finally:
        self.setUpdatesEnabled(True)

    # Refresh the label window and the toolbar (e.g. the Isolate button) once.
    self.update_label_window_selection()
    self._update_toolbar_state()
1175
+
1176
def apply_preview_label_to_selected(self, preview_label):
    """Give every selected item ``preview_label`` as a preview and notify listeners.

    Does nothing if nothing is selected or ``preview_label`` is falsy. Each
    affected widget is repainted. When sorting by "Label", the layout is
    recomputed so items move to their new group. Emits ``preview_changed``
    with the affected annotation ids.
    """
    if not self.selected_widgets or not preview_label:
        return

    touched = []
    for selected in self.selected_widgets:
        selected.data_item.set_preview_label(preview_label)
        selected.update()  # repaint with the preview color
        touched.append(selected.data_item.annotation.id)

    # Sorting by label depends on the effective (preview) label.
    if self.sort_combo.currentText() == "Label":
        self.recalculate_widget_positions()

    if touched:
        self.preview_changed.emit(touched)
1193
+
1194
def clear_preview_states(self):
    """Drop every pending preview label, reverting items to their original labels.

    Each affected widget is repainted. If anything was cleared, the layout is
    recomputed when sorting by "Label", and the label window is resynced.
    """
    any_cleared = False
    for widget in self.annotation_widgets_by_id.values():
        item = widget.data_item
        if not item.has_preview_changes():
            continue
        item.clear_preview_label()
        widget.update()  # repaint with the original color
        any_cleared = True

    if not any_cleared:
        return
    if self.sort_combo.currentText() == "Label":
        self.recalculate_widget_positions()
    self.update_label_window_selection()
1209
+
1210
def has_preview_changes(self):
    """Return True if at least one item has an unapplied preview label."""
    for widget in self.annotation_widgets_by_id.values():
        if widget.data_item.has_preview_changes():
            return True
    return False
1213
+
1214
def get_preview_changes_summary(self):
    """Return a short, user-facing description of pending preview changes."""
    pending = len([w for w in self.annotation_widgets_by_id.values()
                   if w.data_item.has_preview_changes()])
    if pending == 0:
        return "No preview changes"
    return f"{pending} annotation(s) with preview changes"
1220
+
1221
def apply_preview_changes_permanently(self):
    """Commit every pending preview label to its underlying annotation.

    Returns:
        list: The ``annotation`` of each widget whose data item reported (via a
        truthy ``apply_preview_permanently()`` result) that it applied a change.
    """
    committed = []
    for widget in self.annotation_widgets_by_id.values():
        was_applied = widget.data_item.apply_preview_permanently()
        if was_applied:
            committed.append(widget.annotation)
    return committed
1230
+
1231
+
1232
+ # ----------------------------------------------------------------------------------------------------------------------
1233
+ # ExplorerWindow
1234
+ # ----------------------------------------------------------------------------------------------------------------------
1235
+
1236
+
1237
+ class ExplorerWindow(QMainWindow):
1238
def __init__(self, main_window, parent=None):
    """Create the Explorer window's state and child widgets.

    Layout is deferred to ``setup_ui``; here the settings panels, viewers and
    bottom buttons are only created so they always exist.

    Args:
        main_window: The application main window. Its image, label and
            annotation windows and its ``device`` are reused.
        parent: Optional Qt parent widget.
    """
    super(ExplorerWindow, self).__init__(parent)

    # Handles to the main window and the panels shared with it.
    self.main_window = main_window
    self.image_window = main_window.image_window
    self.label_window = main_window.label_window
    self.annotation_window = main_window.annotation_window

    # Model state; the compute device is shared with the main window.
    self.device = main_window.device
    self.model_path = ""
    self.loaded_model = None

    # Data items currently passing the filters (used for embedding).
    self.current_data_items = []

    # Cached features plus an identifier of the model that produced them.
    self.current_features = None
    self.current_feature_generating_model = ""

    self.setWindowTitle("Explorer")
    self.setWindowIcon(QIcon(get_icon("magic.png")))

    self.central_widget = QWidget()
    self.setCentralWidget(self.central_widget)
    self.main_layout = QVBoxLayout(self.central_widget)

    # Left panel intended for the re-parented LabelWindow.
    self.left_panel = QWidget()
    self.left_layout = QVBoxLayout(self.left_panel)

    # Created up front so they are always available to setup_ui().
    self.annotation_settings_widget = AnnotationSettingsWidget(self.main_window, self)
    self.model_settings_widget = ModelSettingsWidget(self.main_window, self)
    self.embedding_settings_widget = EmbeddingSettingsWidget(self.main_window, self)
    self.annotation_viewer = AnnotationViewer(self)
    self.embedding_viewer = EmbeddingViewer(self)

    def make_button(text, slot, tooltip):
        button = QPushButton(text, self)
        button.clicked.connect(slot)
        button.setToolTip(tooltip)
        return button

    self.clear_preview_button = make_button(
        'Clear Preview', self.clear_preview_changes,
        "Clear all preview changes and revert to original labels")
    self.clear_preview_button.setEnabled(False)  # nothing to clear yet

    self.exit_button = make_button('Exit', self.close, "Close the window")

    self.apply_button = make_button('Apply', self.apply, "Apply changes")
    self.apply_button.setEnabled(False)  # nothing to apply yet
1290
+
1291
def showEvent(self, event):
    """Build the window layout the first time the window is shown.

    Qt also delivers (spontaneous) show events, e.g. when a minimized window
    is restored. Re-running ``setup_ui`` then would reset the filters and
    re-run its signal wiring. The layout is therefore built only once per
    window instance; later show events only reach the base class.
    """
    if not getattr(self, '_ui_initialized', False):
        self.setup_ui()
        self._ui_initialized = True
    super(ExplorerWindow, self).showEvent(event)
1294
+
1295
def closeEvent(self, event):
    """Tear down the Explorer and hand control back to the main window.

    In order, this method:

    * stops the embedding viewer's animation timer, if any;
    * discards unapplied preview labels;
    * releases resources via ``_cleanup_resources``;
    * re-enables the main window;
    * lets the main window reclaim the label window via ``explorer_closed``
      (when defined);
    * clears the main window's reference to this window;
    * accepts the event.
    """
    embedding_viewer = getattr(self, 'embedding_viewer', None)
    if embedding_viewer:
        timer = getattr(embedding_viewer, 'animation_timer', None)
        if timer:
            timer.stop()

    if hasattr(self, 'annotation_viewer'):
        self.annotation_viewer.clear_preview_states()

    self._cleanup_resources()

    if self.main_window:
        self.main_window.setEnabled(True)

    # Move the label window back to the main window.
    if hasattr(self.main_window, 'explorer_closed'):
        self.main_window.explorer_closed()

    # Drop the main window's reference so this window can be garbage collected.
    self.main_window.explorer_window = None

    event.accept()
1324
+
1325
def setup_ui(self):
    """(Re)build the Explorer layout from the widgets created in ``__init__``.

    Clears ``self.main_layout``, then adds from top to bottom:

    * the annotation/model/embedding settings panels (stretch 2:1:1);
    * a horizontal splitter holding the annotation and embedding viewers in
      group boxes;
    * the label window;
    * a right-aligned row with the Clear Preview, Exit and Apply buttons.

    It then resets the annotation filter to the current image, refreshes the
    filters and wires up signals. Each signal connection first removes an
    existing connection of the same slot, so calling this more than once
    does not make slots fire several times per emission.
    """
    # Clear the main layout to remove any existing widgets.
    while self.main_layout.count():
        child = self.main_layout.takeAt(0)
        if child.widget():
            # NOTE(review): in PyQt, setParent(None) hands ownership to Python,
            # so a container with no other reference may be destroyed along
            # with its children; confirm persistent widgets survive a rebuild.
            child.widget().setParent(None)

    # Top section: settings panels side by side.
    top_layout = QHBoxLayout()
    top_layout.addWidget(self.annotation_settings_widget, 2)
    top_layout.addWidget(self.model_settings_widget, 1)
    top_layout.addWidget(self.embedding_settings_widget, 1)

    top_container = QWidget()
    top_container.setLayout(top_layout)
    self.main_layout.addWidget(top_container)

    # Middle section: Annotation Viewer (left) and Embedding Viewer (right).
    middle_splitter = QSplitter(Qt.Horizontal)

    annotation_group = QGroupBox("Annotation Viewer")
    annotation_layout = QVBoxLayout(annotation_group)
    annotation_layout.addWidget(self.annotation_viewer)
    middle_splitter.addWidget(annotation_group)

    embedding_group = QGroupBox("Embedding Viewer")
    embedding_layout = QVBoxLayout(embedding_group)
    embedding_layout.addWidget(self.embedding_viewer)
    middle_splitter.addWidget(embedding_group)

    # Start with both viewers at equal width.
    middle_splitter.setSizes([500, 500])

    # Stretch factor 1: the viewers take up spare vertical space.
    self.main_layout.addWidget(middle_splitter, 1)

    # The main window's LabelWindow (re-parented while the Explorer is open)
    # is placed directly in the main layout, below the viewers.
    self.main_layout.addWidget(self.label_window)

    # Bottom control buttons, pushed to the right by a stretch.
    self.buttons_layout = QHBoxLayout()
    self.buttons_layout.addStretch(1)
    self.buttons_layout.addWidget(self.clear_preview_button)
    self.buttons_layout.addWidget(self.exit_button)
    self.buttons_layout.addWidget(self.apply_button)
    self.main_layout.addLayout(self.buttons_layout)

    # Set default condition to current image and refresh filters.
    self.annotation_settings_widget.set_default_to_current_image()
    self.refresh_filters()

    def connect_once(signal, slot):
        # Remove any earlier connection of this slot first (PyQt raises
        # TypeError when there is none), so repeated calls don't stack slots.
        try:
            signal.disconnect(slot)
        except TypeError:
            pass
        signal.connect(slot)

    connect_once(self.label_window.labelSelected, self.on_label_selected_for_preview)
    connect_once(self.annotation_viewer.selection_changed, self.on_annotation_view_selection_changed)
    connect_once(self.annotation_viewer.preview_changed, self.on_preview_changed)
    connect_once(self.annotation_viewer.reset_view_requested, self.on_reset_view_requested)
    connect_once(self.embedding_viewer.selection_changed, self.on_embedding_view_selection_changed)
    connect_once(self.embedding_viewer.reset_view_requested, self.on_reset_view_requested)
1398
+
1399
@pyqtSlot(list)
def on_annotation_view_selection_changed(self, changed_ann_ids):
    """Mirror an AnnotationViewer selection change into the EmbeddingViewer.

    ``changed_ann_ids`` is not used; the full current selection is read from
    the annotation viewer instead. The embedding view is updated only when it
    has points, which avoids a feedback loop that was clearing the selection.
    The label window is always resynced.
    """
    selected_ids = set()
    for w in self.annotation_viewer.selected_widgets:
        selected_ids.add(w.data_item.annotation.id)

    if self.embedding_viewer.points_by_id:
        self.embedding_viewer.render_selection_from_ids(selected_ids)

    self.update_label_window_selection()
1410
+
1411
@pyqtSlot(list)
def on_embedding_view_selection_changed(self, all_selected_ann_ids):
    """Mirror an EmbeddingViewer selection into the AnnotationViewer.

    ``all_selected_ann_ids`` is the complete selection. The annotation viewer
    switches to isolation mode automatically when all of these hold:

    * it had nothing selected before;
    * the new selection is non-empty;
    * it is not already isolated.

    The label window is then resynced.
    """
    viewer = self.annotation_viewer
    had_no_selection = len(viewer.selected_widgets) == 0
    has_selection_now = len(all_selected_ann_ids) > 0

    viewer.render_selection_from_ids(set(all_selected_ann_ids))

    should_isolate = had_no_selection and has_selection_now and not viewer.isolated_mode
    if should_isolate:
        print("Auto-switching to isolation mode due to new selection in embedding viewer")
        viewer.isolate_selection()

    self.update_label_window_selection()
1429
+
1430
+ @pyqtSlot(list)
1431
+ def on_preview_changed(self, changed_ann_ids):
1432
+ """A preview color was changed in the AnnotationViewer, so update the EmbeddingViewer points."""
1433
+ for ann_id in changed_ann_ids:
1434
+ point = self.embedding_viewer.points_by_id.get(ann_id)
1435
+ if point:
1436
+ point.update() # Force the point to repaint itself
1437
+
1438
@pyqtSlot()
def on_reset_view_requested(self):
    """Reset both viewers after a double-click in either one.

    Clears the selection in both viewers, leaves isolation mode if it is
    active, refreshes the button states and logs the reset.
    """
    annotation_viewer = self.annotation_viewer
    annotation_viewer.clear_selection()
    self.embedding_viewer.render_selection_from_ids(set())

    if annotation_viewer.isolated_mode:
        annotation_viewer.show_all_annotations()

    self.update_button_states()
    print("Reset view: cleared selections and exited isolation mode")
1453
+
1454
def update_label_window_selection(self):
    """Sync the label window with the annotation viewer's selection.

    This simply delegates to ``AnnotationViewer.update_label_window_selection``.
    """
    viewer = self.annotation_viewer
    viewer.update_label_window_selection()
1457
+
1458
def get_filtered_data_items(self):
    """Return ``AnnotationDataItem`` wrappers for annotations passing every filter.

    An annotation passes only if all three of these are among the values
    selected in the annotation settings widget:

    * its image basename;
    * its annotation class name;
    * its label's ``short_label_code``.

    An empty selection for any filter matches nothing. Missing cropped images
    are generated for the matching annotations before they are wrapped.
    Returns an empty list if the annotation window has no ``annotations_dict``.
    """
    if not hasattr(self.main_window.annotation_window, 'annotations_dict'):
        return []

    settings = self.annotation_settings_widget
    wanted_images = settings.get_selected_images()
    wanted_types = settings.get_selected_annotation_types()
    wanted_labels = settings.get_selected_labels()

    def passes_filters(annotation):
        # An empty filter list is falsy and therefore rejects everything.
        if not wanted_images or os.path.basename(annotation.image_path) not in wanted_images:
            return False
        if not wanted_types or type(annotation).__name__ not in wanted_types:
            return False
        if not wanted_labels or annotation.label.short_label_code not in wanted_labels:
            return False
        return True

    matching = [ann for ann in self.main_window.annotation_window.annotations_dict.values()
                if passes_filters(ann)]

    # Make sure every matching annotation has a cropped image.
    self._ensure_cropped_images(matching)

    return [AnnotationDataItem(ann) for ann in matching]
1513
+
1514
def _ensure_cropped_images(self, annotations):
    """Crop, image by image, every annotation that has no cropped image yet.

    Annotations lacking ``cropped_image`` are grouped by ``image_path`` and
    passed to ``annotation_window.crop_annotations`` one image at a time. A
    progress bar tracks the number of images processed. This is best-effort:
    errors are printed rather than raised, and the progress bar is always
    closed. No progress bar appears if nothing needs cropping.
    """
    pending_by_image = {}
    for annotation in annotations:
        if annotation.cropped_image:
            continue
        pending_by_image.setdefault(annotation.image_path, []).append(annotation)

    if not pending_by_image:
        return

    progress_bar = ProgressBar(self, "Cropping Image Annotations")
    progress_bar.show()
    progress_bar.start_progress(len(pending_by_image))
    try:
        for image_path, pending in pending_by_image.items():
            # Reuse AnnotationWindow's cropping so crops match those made elsewhere.
            self.annotation_window.crop_annotations(
                image_path=image_path,
                annotations=pending,
                return_annotations=False,
                verbose=False,
            )
            progress_bar.update_progress()
    except Exception as e:
        print(f"Error cropping annotations: {e}")
    finally:
        progress_bar.finish_progress()
        progress_bar.stop_progress()
        progress_bar.close()
1551
+
1552
def _extract_color_features(self, data_items, progress_bar=None, bins=32):
    """
    Extract a hand-crafted colour/shape feature vector for each annotation crop.

    For each crop, ``_compute_color_feature_vector`` produces the following,
    concatenated in this order:
      1. Colour moments per R/G/B channel: mean, standard deviation, skewness,
         excess kurtosis.
      2. A normalised ``bins``-bin histogram per channel over [0, 255].
      3. Grayscale statistics: mean brightness, contrast (std dev), and
         intensity range (peak-to-peak).
      4. Geometric features: the annotation's ``area`` and ``perimeter``.

    Args:
        data_items: Items whose ``annotation`` exposes ``get_cropped_image()``,
            ``id`` and optionally ``area`` / ``perimeter``.
        progress_bar: Optional progress bar advanced once per item.
        bins: Number of histogram bins per channel.

    Returns:
        tuple: ``(features, valid_data_items)``. Row ``i`` of ``features``
        belongs to ``valid_data_items[i]``. Items without a usable crop are
        skipped with a warning.
    """
    if progress_bar:
        progress_bar.set_title("Extracting Color Features...")
        progress_bar.start_progress(len(data_items))

    features = []
    valid_data_items = []
    for item in data_items:
        pixmap = item.annotation.get_cropped_image()
        if pixmap and not pixmap.isNull():
            area = getattr(item.annotation, 'area', 0.0)
            perimeter = getattr(item.annotation, 'perimeter', 0.0)
            features.append(
                self._compute_color_feature_vector(pixmap_to_numpy(pixmap), area, perimeter, bins=bins)
            )
            valid_data_items.append(item)
        else:
            print(f"Warning: Could not get cropped image for annotation ID {item.annotation.id}. Skipping.")

        if progress_bar:
            progress_bar.update_progress()

    return np.array(features), valid_data_items


def _compute_color_feature_vector(self, arr, area=0.0, perimeter=0.0, bins=32):
    """
    Build the colour/shape feature vector for a single image array.

    Args:
        arr: Image array of shape (H, W), (H, W, 3) or (H, W, 4). Only the first
            three channels are used, and grayscale input is replicated to 3.
        area: Annotation area; ``None`` is treated as 0.
        perimeter: Annotation perimeter; ``None`` is treated as 0.
        bins: Histogram bins per channel.

    Returns:
        np.ndarray: 1-D vector of length ``12 + 3 * bins + 3 + 2``.
    """
    arr = np.asarray(arr)
    if arr.ndim == 2:
        # Grayscale crop: replicate to 3 channels so the per-channel stats still apply.
        arr = np.stack([arr] * 3, axis=-1)
    # Drop any alpha channel; reshaping RGBA with (-1, 3) would interleave channels.
    rgb = arr[..., :3].astype(np.float64)
    pixels = rgb.reshape(-1, 3)

    # --- 1. Colour moments ---
    mean_color = np.mean(pixels, axis=0)
    std_color = np.std(pixels, axis=0)
    centered_pixels = pixels - mean_color
    # epsilon keeps skew/kurtosis finite on constant channels
    epsilon = 1e-8
    skew_color = np.mean(centered_pixels**3, axis=0) / (std_color**3 + epsilon)
    kurt_color = np.mean(centered_pixels**4, axis=0) / (std_color**4 + epsilon) - 3

    # --- 2. Normalised per-channel histograms ---
    histograms = []
    for channel in range(3):
        hist, _ = np.histogram(pixels[:, channel], bins=bins, range=(0, 255))
        hist_sum = np.sum(hist)
        histograms.append(hist / hist_sum if hist_sum > 0 else np.zeros(bins))

    # --- 3. Grayscale statistics (ITU-R 601 luma weights) ---
    gray_arr = np.dot(rgb, [0.2989, 0.5870, 0.1140])
    grayscale_stats = np.array([
        np.mean(gray_arr),  # overall brightness
        np.std(gray_arr),   # overall contrast
        np.ptp(gray_arr),   # peak-to-peak intensity range
    ])

    # --- 4. Geometric features (None -> 0 to keep the row numeric) ---
    geometric_features = np.array(
        [0.0 if area is None else area, 0.0 if perimeter is None else perimeter],
        dtype=np.float64,
    )

    return np.concatenate([
        mean_color,
        std_color,
        skew_color,
        kurt_color,
        *histograms,
        grayscale_stats,
        geometric_features,
    ])
1643
+
1644
def _extract_yolo_features(self, data_items, model_info, progress_bar=None):
    """
    Extract one feature vector per annotation crop with a YOLO model.

    The model is loaded into ``self.loaded_model`` (path in ``self.model_path``,
    inference size in ``self.imgsz``) and warmed up only when it differs from the
    cached one. If loading or warm-up fails, the cache is reset so the next call
    retries the load instead of reusing a half-initialised model.

    Args:
        data_items: Items whose ``annotation`` exposes ``get_cropped_image()`` and
            ``id``. Items without a usable crop are skipped with a warning.
        model_info: ``(model_name, feature_mode)``. ``"Embed Features"`` uses
            ``model.embed()``. Any other mode uses ``model.predict()`` and reads
            ``result.probs``, so it requires a classification model.
        progress_bar: Optional progress bar updated while preparing images and
            while consuming results.

    Returns:
        tuple: ``(features, valid_data_items)``, where row ``i`` of ``features``
        belongs to ``valid_data_items[i]``, or ``(np.array([]), [])`` on failure.
    """
    model_name, feature_mode = model_info

    # Load (and warm up) the model only when it differs from the cached one.
    if model_name != self.model_path or self.loaded_model is None:
        try:
            print(f"Loading new model: {model_name}")
            self.loaded_model = YOLO(model_name)
            self.model_path = model_name

            # Use the model's configured size, but fall back to 128 for large sizes.
            try:
                self.imgsz = self.loaded_model.model.args['imgsz']
                if self.imgsz > 224:
                    self.imgsz = 128
            except (AttributeError, KeyError, TypeError):
                # TypeError: args not subscriptable, or imgsz stored as a list/tuple.
                self.imgsz = 128

            # Run a dummy inference to warm up the model
            print(f"Warming up model on device '{self.device}'...")
            dummy_image = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8)
            self.loaded_model.predict(dummy_image, imgsz=self.imgsz, half=True, device=self.device, verbose=False)

        except Exception as e:
            print(f"ERROR: Could not load YOLO model '{model_name}': {e}")
            # Forget the partially initialised model so the next call reloads it.
            self.loaded_model = None
            self.model_path = ""
            return np.array([]), []

    if progress_bar:
        progress_bar.set_title("Preparing images...")
        progress_bar.start_progress(len(data_items))

    # 1. Convert every usable crop to a NumPy image, keeping items aligned with images.
    image_list = []
    valid_data_items = []
    for item in data_items:
        pixmap = item.annotation.get_cropped_image()
        if pixmap and not pixmap.isNull():
            image_list.append(pixmap_to_numpy(pixmap))
            valid_data_items.append(item)
        else:
            print(f"Warning: Could not get cropped image for annotation ID {item.annotation.id}. Skipping.")

        if progress_bar:
            progress_bar.update_progress()

    if not valid_data_items:
        print("Warning: No valid images found to process.")
        return np.array([]), []

    model_label = os.path.basename(model_name)
    embeddings_list = []

    try:
        if progress_bar:
            progress_bar.set_busy_mode(f"Extracting features with {model_label}...")

        kwargs = {
            'stream': True,
            'imgsz': self.imgsz,
            'half': True,
            'device': self.device,
            'verbose': False
        }

        # 2. embed() yields feature tensors (second-to-last layer); predict() yields results with probs.
        using_embed_method = feature_mode == "Embed Features"
        if using_embed_method:
            print("Using embed() method")
            results_generator = self.loaded_model.embed(image_list, **kwargs)
        else:
            print("Using predict() method")
            results_generator = self.loaded_model.predict(image_list, **kwargs)

        if progress_bar:
            progress_bar.set_title(f"Extracting features with {model_label}...")
            progress_bar.start_progress(len(valid_data_items))

        # 3. Consume the stream; one result is expected per input image.
        for result in results_generator:
            if using_embed_method:
                try:
                    # With embed(), each result is the embedding tensor itself.
                    embeddings_list.append(result.cpu().numpy().flatten())
                except Exception as e:
                    print(f"Error processing embedding: {e}")
                    raise TypeError(
                        f"Model '{model_label}' did not return valid embeddings. "
                        f"Error: {str(e)}"
                    ) from e
            else:
                # Classification mode: expect a probability vector per image.
                if hasattr(result, 'probs') and result.probs is not None:
                    embeddings_list.append(result.probs.data.cpu().numpy().squeeze())
                else:
                    raise TypeError(
                        f"Model '{model_label}' did not return probability vectors. "
                        "Make sure this is a classification model."
                    )

            if progress_bar:
                progress_bar.update_progress()

        if not embeddings_list:
            print("Warning: No features were extracted. The model may have failed.")
            return np.array([]), []

        # Rows are matched to items by position, so any count mismatch would
        # attribute features to the wrong annotations.
        if len(embeddings_list) != len(valid_data_items):
            print(
                f"ERROR: Model '{model_label}' returned {len(embeddings_list)} feature vectors "
                f"for {len(valid_data_items)} images. Aborting to avoid misaligned features."
            )
            return np.array([]), []

        embeddings = np.array(embeddings_list)
    except Exception as e:
        print(f"ERROR: An error occurred during feature extraction: {e}")
        return np.array([]), []
    finally:
        # Release cached CUDA blocks after inference.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"Successfully extracted {len(embeddings)} features with shape {embeddings.shape}")
    return embeddings, valid_data_items
1775
+
1776
def _extract_features(self, data_items, progress_bar=None):
    """
    Dispatch to the feature extractor matching the selected model.

    ``self.model_settings_widget.get_selected_model()`` may return either a
    ``(model_name, feature_mode)`` tuple or a bare model name. A bare name gets
    feature mode ``"default"``, matching how ``run_embedding_pipeline`` builds
    its cache key. A nested tuple entry of length >= 3 is reduced to its first
    element.

    Args:
        data_items: Items to extract features for.
        progress_bar: Optional progress bar forwarded to the extractor.

    Returns:
        tuple: ``(features, valid_data_items)`` from the chosen extractor, or
        ``(np.array([]), [])`` when no model or an unknown model is selected.
    """
    model_info = self.model_settings_widget.get_selected_model()

    if isinstance(model_info, tuple):
        model_name, feature_mode = model_info
    else:
        model_name, feature_mode = model_info, "default"

    # Some entries carry extra metadata in a tuple; the first element is the model name/path.
    if isinstance(model_name, tuple) and len(model_name) >= 3:
        model_name = model_name[0]

    if not model_name:
        print("No model selected or path provided.")
        return np.array([]), []

    if model_name == "Color Features":
        return self._extract_color_features(data_items, progress_bar=progress_bar)
    if ".pt" in model_name:
        return self._extract_yolo_features(data_items, (model_name, feature_mode), progress_bar=progress_bar)

    print(f"Unknown or invalid feature model selected: {model_name}")
    return np.array([]), []
1802
+
1803
def _run_dimensionality_reduction(self, features, params):
    """
    Project a feature matrix to 2-D with PCA, UMAP or t-SNE.

    ``params['technique']`` picks the method (``"UMAP"`` when absent, or
    ``"TSNE"`` / ``"PCA"``). The other keys supply per-method hyperparameters,
    each with a default. Features are standardised first, and a fixed random
    state of 42 is used.

    Returns:
        The 2-D projection, or ``None`` when there are two or fewer points, the
        technique is unknown, or the reduction raises.
    """
    technique = params.get('technique', 'UMAP')
    seed = 42
    n_points = len(features)

    print(f"Running {technique} on {n_points} items with params: {params}")
    # UMAP/t-SNE need more than a couple of points to produce a layout.
    if n_points <= 2:
        print("Not enough data points for dimensionality reduction.")
        return None

    try:
        # Standardise so no single feature dominates the distance computations.
        scaled = StandardScaler().fit_transform(features)
        # Neighbourhood-style sizes must stay below the number of samples.
        neighbour_cap = len(scaled) - 1

        if technique == "UMAP":
            reducer = UMAP(
                n_components=2,
                random_state=seed,
                n_neighbors=min(params.get('n_neighbors', 15), neighbour_cap),
                min_dist=params.get('min_dist', 0.1),
                metric=params.get('metric', 'cosine'),
            )
        elif technique == "TSNE":
            reducer = TSNE(
                n_components=2,
                random_state=seed,
                perplexity=min(params.get('perplexity', 30), neighbour_cap),
                early_exaggeration=params.get('early_exaggeration', 12.0),
                learning_rate=params.get('learning_rate', 'auto'),
                init='pca',  # PCA initialisation is more stable than random
            )
        elif technique == "PCA":
            reducer = PCA(n_components=2, random_state=seed)
        else:
            print(f"Unknown dimensionality reduction technique: {technique}")
            return None

        return reducer.fit_transform(scaled)

    except Exception as e:
        print(f"Error during {technique} dimensionality reduction: {e}")
        return None
1860
+
1861
def _update_data_items_with_embedding(self, data_items, embedded_features, scale_factor=4000):
    """
    Store normalised 2-D embedding coordinates on each data item.

    The first two columns of ``embedded_features`` are each min-max normalised
    to [0, 1], and a constant column maps to 0.5. They are then scaled to
    ``[-scale_factor / 2, scale_factor / 2]`` so the point cloud is centred on
    the origin.

    Args:
        data_items: Objects that receive ``embedding_x`` / ``embedding_y``.
            Item ``i`` takes row ``i`` of ``embedded_features``.
        embedded_features: Array of shape (n, >= 2) produced by the reducer.
        scale_factor: Side length of the square the points are spread over.
            The default of 4000 is the previously hard-coded value.
    """
    if len(data_items) == 0:
        return

    coords = np.asarray(embedded_features)[:, :2]
    min_vals = coords.min(axis=0)
    range_vals = coords.max(axis=0) - min_vals

    # Divide constant axes by 1 to avoid 0/0, then replace them with the 0.5 midpoint.
    safe_ranges = np.where(range_vals > 0, range_vals, 1)
    normalized = np.where(range_vals > 0, (coords - min_vals) / safe_ranges, 0.5)
    scaled = normalized * scale_factor - scale_factor / 2

    for i, item in enumerate(data_items):
        item.embedding_x = scaled[i, 0]
        item.embedding_y = scaled[i, 1]
1875
+
1876
def run_embedding_pipeline(self):
    """
    Extract features for the current items, reduce them to 2-D and display them.

    Features are cached on ``self.current_features`` under a key built from the
    selected model and feature mode (``self.current_feature_generating_model``).
    Changing only the reduction parameters therefore re-runs the reduction on
    cached features. After a successful extraction, items that could not be
    featurised are dropped from ``self.current_data_items`` and the annotation
    viewer is refreshed.

    A failed extraction leaves the current items, the annotation viewer and the
    feature cache untouched. The grid is not emptied, and the next run retries
    extraction. Any unexpected error resets the embedding viewer to its
    placeholder.
    """
    if not self.current_data_items:
        print("No items to process for embedding.")
        return

    # 1. Get current parameters from the UI.
    embedding_params = self.embedding_settings_widget.get_embedding_parameters()
    model_info = self.model_settings_widget.get_selected_model()

    # The cache key combines model and feature mode; a bare model name uses "default".
    if isinstance(model_info, tuple):
        selected_model, selected_feature_mode = model_info
    else:
        selected_model, selected_feature_mode = model_info, "default"
    cache_key = f"{selected_model}_{selected_feature_mode}"

    technique = embedding_params['technique']

    QApplication.setOverrideCursor(Qt.WaitCursor)
    progress_bar = ProgressBar(self, "Generating Embedding Visualization")
    progress_bar.show()

    try:
        # 2. Reuse cached features only if they came from the same model and mode.
        use_cache = (
            self.current_features is not None
            and cache_key == self.current_feature_generating_model
        )
        if use_cache:
            print("Using cached features. Skipping feature extraction.")
            features = self.current_features
        else:
            features, valid_data_items = self._extract_features(self.current_data_items, progress_bar=progress_bar)

        if features is None or len(features) == 0:
            print("No valid features available. Aborting embedding.")
            return

        if not use_cache:
            # Commit only successful extractions so a failure cannot empty the grid
            # or cache an empty feature matrix under this model's key.
            self.current_features = features
            self.current_feature_generating_model = cache_key
            self.current_data_items = valid_data_items
            self.annotation_viewer.update_annotations(self.current_data_items)

        # 3. Run dimensionality reduction with the latest parameters.
        progress_bar.set_busy_mode(f"Running {technique} dimensionality reduction...")
        embedded_features = self._run_dimensionality_reduction(features, embedding_params)
        progress_bar.update_progress()

        if embedded_features is None:
            return

        # 4. Update the visualization with the new 2-D layout.
        progress_bar.set_busy_mode("Updating visualization...")
        self._update_data_items_with_embedding(self.current_data_items, embedded_features)

        self.embedding_viewer.update_embeddings(self.current_data_items)
        self.embedding_viewer.show_embedding()
        self.embedding_viewer.fit_view_to_points()
        progress_bar.update_progress()

        print(f"Successfully generated embedding for {len(self.current_data_items)} annotations using {technique}")

    except Exception as e:
        print(f"Error during embedding pipeline: {e}")
        self.embedding_viewer.clear_points()
        self.embedding_viewer.show_placeholder()

    finally:
        QApplication.restoreOverrideCursor()
        progress_bar.finish_progress()
        progress_bar.stop_progress()
        progress_bar.close()
1954
+
1955
def refresh_filters(self):
    """
    Re-apply the current filters and reset everything derived from the old item set.

    Rebuilds ``self.current_data_items`` via ``get_filtered_data_items()``.
    Because the cached feature matrix was computed for the previous items, it is
    dropped. The annotation grid is then refreshed, and the embedding viewer goes
    back to its placeholder. A wait cursor is shown for the duration and is
    always restored.
    """
    QApplication.setOverrideCursor(Qt.WaitCursor)
    try:
        filtered_items = self.get_filtered_data_items()
        self.current_data_items = filtered_items

        # Features belong to the previous item set; force re-extraction on the next run.
        self.current_features = None

        self.annotation_viewer.update_annotations(filtered_items)

        # The embedding no longer matches the grid, so show the placeholder instead.
        embedding_viewer = self.embedding_viewer
        embedding_viewer.clear_points()
        embedding_viewer.show_placeholder()
    finally:
        QApplication.restoreOverrideCursor()
1975
+
1976
def on_label_selected_for_preview(self, label):
    """
    Preview ``label`` on the currently selected annotation widgets.

    Does nothing unless an annotation viewer exists and has a non-empty
    selection. Otherwise it applies the preview label and refreshes the
    Clear Preview / Apply button states.
    """
    if not hasattr(self, 'annotation_viewer'):
        return
    viewer = self.annotation_viewer
    if not viewer.selected_widgets:
        return
    viewer.apply_preview_label_to_selected(label)
    self.update_button_states()
1981
+
1982
def clear_preview_changes(self):
    """
    Discard every pending preview label and revert widgets to their original labels.

    When no annotation viewer exists this is a no-op. Otherwise it clears the
    preview states, refreshes the button states and logs a confirmation.
    """
    if not hasattr(self, 'annotation_viewer'):
        return
    self.annotation_viewer.clear_preview_states()
    self.update_button_states()
    print("Cleared all preview changes")
1988
+
1989
def update_button_states(self):
    """
    Enable or disable the Clear Preview and Apply buttons and refresh their tooltips.

    Both buttons are enabled only when the annotation viewer reports pending
    preview changes. The tooltips include the viewer's change summary. If no
    annotation viewer exists yet, both buttons are disabled and the tooltips are
    left unchanged, where previously the summary lookup raised
    ``AttributeError``.
    """
    has_viewer = hasattr(self, 'annotation_viewer')
    # setEnabled expects a real bool, so coerce whatever has_preview_changes() returns.
    has_changes = bool(has_viewer and self.annotation_viewer.has_preview_changes())

    self.clear_preview_button.setEnabled(has_changes)
    self.apply_button.setEnabled(has_changes)

    if not has_viewer:
        return

    summary = self.annotation_viewer.get_preview_changes_summary()
    self.clear_preview_button.setToolTip(f"Clear all preview changes - {summary}")
    self.apply_button.setToolTip(f"Apply changes - {summary}")
1999
+
2000
def apply(self):
    """
    Make all pending preview label changes permanent.

    Applies the annotation viewer's preview changes to the real annotations.
    It then repaints the affected grid widgets and embedding points, re-sorts
    the grid when sorting by label, and refreshes the main window's annotations
    for each affected image. Finally it clears the selection and the button
    states. Errors are reported, and the wait cursor is always restored.
    """
    QApplication.setOverrideCursor(Qt.WaitCursor)

    try:
        applied = self.annotation_viewer.apply_preview_changes_permanently()
        if not applied:
            print("No preview changes to apply")
            return

        # Repaint only the visual elements whose annotation actually changed.
        changed_ids = {annotation.id for annotation in applied}
        affected_items = (item for item in self.current_data_items if item.annotation.id in changed_ids)
        for item in affected_items:
            annotation_id = item.annotation.id

            grid_widget = self.annotation_viewer.annotation_widgets_by_id.get(annotation_id)
            if grid_widget:
                grid_widget.update()  # show the new permanent colour

            embedding_point = self.embedding_viewer.points_by_id.get(annotation_id)
            if embedding_point:
                embedding_point.update()  # show the new permanent colour

        # Label changes can reorder the grid when it is sorted by label.
        if self.annotation_viewer.sort_combo.currentText() == "Label":
            self.annotation_viewer.recalculate_widget_positions()

        # Propagate the changes to the main application's windows.
        for image_path in {annotation.image_path for annotation in applied}:
            self.image_window.update_image_annotations(image_path)
        self.annotation_window.load_annotations()

        # Reset selection in both viewers and the button states.
        self.annotation_viewer.clear_selection()
        self.embedding_viewer.render_selection_from_ids(set())
        self.update_button_states()

        print(f"Applied changes to {len(applied)} annotation(s)")

    except Exception as e:
        print(f"Error applying modifications: {e}")
    finally:
        QApplication.restoreOverrideCursor()
2047
+
2048
def _cleanup_resources(self):
    """
    Release the loaded model and feature caches, then free cached GPU memory.

    Resets ``imgsz`` to 128 and clears ``loaded_model``, ``model_path``,
    ``current_features`` and ``current_feature_generating_model``. A garbage
    collection pass then runs so that a model held only by reference cycles is
    actually freed. After that, the CUDA allocator cache is emptied when CUDA
    is available. This is called when the window is closed to free up memory.
    """
    import gc

    print("Cleaning up Explorer resources...")

    # Reset model and feature caches.
    self.imgsz = 128
    self.loaded_model = None
    self.model_path = ""
    self.current_features = None
    self.current_feature_generating_model = ""

    # empty_cache() cannot release memory still owned by live tensors, so collect
    # any cycles keeping the dropped model alive first.
    gc.collect()

    if torch.cuda.is_available():
        print("Clearing CUDA cache.")
        torch.cuda.empty_cache()

    print("Cleanup complete.")