coralnet-toolbox 0.0.66__py2.py3-none-any.whl → 0.0.68__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +1 -1
  2. coralnet_toolbox/Annotations/QtPatchAnnotation.py +1 -1
  3. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +1 -1
  4. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +1 -1
  5. coralnet_toolbox/AutoDistill/QtDeployModel.py +4 -0
  6. coralnet_toolbox/Explorer/QtDataItem.py +300 -0
  7. coralnet_toolbox/Explorer/QtExplorer.py +1825 -0
  8. coralnet_toolbox/Explorer/QtSettingsWidgets.py +494 -0
  9. coralnet_toolbox/Explorer/__init__.py +7 -0
  10. coralnet_toolbox/IO/QtImportViscoreAnnotations.py +2 -4
  11. coralnet_toolbox/IO/QtOpenProject.py +2 -1
  12. coralnet_toolbox/Icons/magic.png +0 -0
  13. coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +4 -0
  14. coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +4 -0
  15. coralnet_toolbox/MachineLearning/TrainModel/QtClassify.py +1 -1
  16. coralnet_toolbox/QtConfidenceWindow.py +2 -23
  17. coralnet_toolbox/QtEventFilter.py +18 -7
  18. coralnet_toolbox/QtLabelWindow.py +35 -8
  19. coralnet_toolbox/QtMainWindow.py +81 -2
  20. coralnet_toolbox/QtProgressBar.py +12 -0
  21. coralnet_toolbox/SAM/QtDeployGenerator.py +4 -0
  22. coralnet_toolbox/__init__.py +1 -1
  23. coralnet_toolbox/utilities.py +24 -0
  24. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.68.dist-info}/METADATA +12 -6
  25. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.68.dist-info}/RECORD +29 -24
  26. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.68.dist-info}/WHEEL +0 -0
  27. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.68.dist-info}/entry_points.txt +0 -0
  28. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.68.dist-info}/licenses/LICENSE.txt +0 -0
  29. {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.68.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1825 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import warnings
5
+
6
+ from ultralytics import YOLO
7
+
8
+ from coralnet_toolbox.Icons import get_icon
9
+ from coralnet_toolbox.utilities import pixmap_to_numpy
10
+
11
+ from PyQt5.QtGui import QIcon, QPen, QColor, QPainter, QBrush, QPainterPath, QMouseEvent
12
+ from PyQt5.QtCore import Qt, QTimer, QRect, QRectF, QPointF, pyqtSignal, QSignalBlocker, pyqtSlot
13
+
14
+ from PyQt5.QtWidgets import (QVBoxLayout, QHBoxLayout, QGraphicsView, QScrollArea,
15
+ QGraphicsScene, QPushButton, QComboBox, QLabel, QWidget, QGridLayout,
16
+ QMainWindow, QSplitter, QGroupBox, QFormLayout,
17
+ QSpinBox, QGraphicsEllipseItem, QGraphicsItem, QSlider,
18
+ QListWidget, QDoubleSpinBox, QApplication, QStyle,
19
+ QGraphicsRectItem, QRubberBand, QStyleOptionGraphicsItem,
20
+ QTabWidget, QLineEdit, QFileDialog)
21
+
22
+ from coralnet_toolbox.Explorer.QtDataItem import AnnotationDataItem
23
+ from coralnet_toolbox.Explorer.QtDataItem import EmbeddingPointItem
24
+ from coralnet_toolbox.Explorer.QtDataItem import AnnotationImageWidget
25
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import ModelSettingsWidget
26
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import EmbeddingSettingsWidget
27
+ from coralnet_toolbox.Explorer.QtSettingsWidgets import AnnotationSettingsWidget
28
+
29
+ from coralnet_toolbox.QtProgressBar import ProgressBar
30
+
31
+ try:
32
+ from sklearn.preprocessing import StandardScaler
33
+ from sklearn.decomposition import PCA
34
+ from sklearn.manifold import TSNE
35
+ from umap import UMAP
36
+ except ImportError:
37
+ print("Warning: sklearn or umap not installed. Some features may be unavailable.")
38
+ StandardScaler = None
39
+ PCA = None
40
+ TSNE = None
41
+ UMAP = None
42
+
43
+
44
# Suppress DeprecationWarning messages process-wide once this module is imported.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)


# ----------------------------------------------------------------------------------------------------------------------
# Constants
# ----------------------------------------------------------------------------------------------------------------------

# Pen width used when outlining embedding points.
POINT_WIDTH = 3
52
+
53
+ # ----------------------------------------------------------------------------------------------------------------------
54
+ # Viewers
55
+ # ----------------------------------------------------------------------------------------------------------------------
56
+
57
+
58
+ class EmbeddingViewer(QWidget):
59
+ """Custom QGraphicsView for interactive embedding visualization with zooming, panning, and selection."""
60
+ selection_changed = pyqtSignal(list)
61
+ reset_view_requested = pyqtSignal()
62
+
63
def __init__(self, parent=None):
    """Initialize the EmbeddingViewer widget.

    Args:
        parent: Parent widget. It is also kept as ``explorer_window`` and
            later read for ``current_data_items``.
    """
    # Initialise the Qt base class before attaching any state to the
    # instance; the wrapped object is not fully constructed until then.
    super(EmbeddingViewer, self).__init__(parent)
    self.explorer_window = parent

    self.graphics_scene = QGraphicsScene()
    self.graphics_scene.setSceneRect(-5000, -5000, 10000, 10000)

    self.graphics_view = QGraphicsView(self.graphics_scene)
    self.graphics_view.setRenderHint(QPainter.Antialiasing)
    self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
    self.graphics_view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
    self.graphics_view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
    self.graphics_view.setMinimumHeight(200)

    # Rubber-band selection state.
    self.rubber_band = None
    self.rubber_band_origin = QPointF()
    self.selection_at_press = None

    # Lookup of scene points keyed by annotation id, plus the last
    # selection that was announced through ``selection_changed``.
    self.points_by_id = {}
    self.previous_selection_ids = set()

    # Timer driving the dashed-outline animation of selected points.
    self.animation_offset = 0
    self.animation_timer = QTimer()
    self.animation_timer.timeout.connect(self.animate_selection)
    self.animation_timer.setInterval(100)

    self.graphics_scene.selectionChanged.connect(self.on_selection_changed)
    self.setup_ui()

    # Route the inner view's mouse/wheel events through this widget's handlers.
    self.graphics_view.mousePressEvent = self.mousePressEvent
    self.graphics_view.mouseDoubleClickEvent = self.mouseDoubleClickEvent
    self.graphics_view.mouseReleaseEvent = self.mouseReleaseEvent
    self.graphics_view.mouseMoveEvent = self.mouseMoveEvent
    self.graphics_view.wheelEvent = self.wheelEvent
96
+
97
def setup_ui(self):
    """Build the widget layout: a header row, the graphics view and a placeholder label."""
    root_layout = QVBoxLayout(self)
    root_layout.setContentsMargins(0, 0, 0, 0)

    # Header row: a "Home" button on the left, stretch filling the rest.
    self.home_button = QPushButton("Home")
    self.home_button.setToolTip("Reset view to fit all points")
    self.home_button.clicked.connect(self.reset_view)

    header_row = QHBoxLayout()
    header_row.addWidget(self.home_button)
    header_row.addStretch()

    # Message shown in place of the view until an embedding exists.
    placeholder_text = "No embedding data available.\nPress 'Apply Embedding' to generate visualization."
    self.placeholder_label = QLabel(placeholder_text)
    self.placeholder_label.setAlignment(Qt.AlignCenter)
    self.placeholder_label.setStyleSheet("color: gray; font-size: 14px;")

    root_layout.addLayout(header_row)
    root_layout.addWidget(self.graphics_view)
    root_layout.addWidget(self.placeholder_label)

    self.show_placeholder()
118
+
119
def reset_view(self):
    """Re-fit the view around the embedding points (handler for the Home button)."""
    return self.fit_view_to_points() and None
122
+
123
def show_placeholder(self):
    """Hide the graphics view, reveal the placeholder label and disable the Home button."""
    embedding_visible = False
    self.graphics_view.setVisible(embedding_visible)
    self.placeholder_label.setVisible(not embedding_visible)
    self.home_button.setEnabled(embedding_visible)
128
+
129
def show_embedding(self):
    """Reveal the graphics view, hide the placeholder label and enable the Home button."""
    embedding_visible = True
    self.graphics_view.setVisible(embedding_visible)
    self.placeholder_label.setVisible(not embedding_visible)
    self.home_button.setEnabled(embedding_visible)
134
+
135
+ # Delegate graphics view methods
136
def setRenderHint(self, hint):
    """Forward ``hint`` to the inner graphics view's ``setRenderHint``."""
    view = self.graphics_view
    view.setRenderHint(hint)
139
+
140
def setDragMode(self, mode):
    """Forward ``mode`` to the inner graphics view's ``setDragMode``."""
    view = self.graphics_view
    view.setDragMode(mode)
143
+
144
def setTransformationAnchor(self, anchor):
    """Forward ``anchor`` to the inner graphics view's ``setTransformationAnchor``."""
    view = self.graphics_view
    view.setTransformationAnchor(anchor)
147
+
148
def setResizeAnchor(self, anchor):
    """Forward ``anchor`` to the inner graphics view's ``setResizeAnchor``."""
    view = self.graphics_view
    view.setResizeAnchor(anchor)
151
+
152
def mapToScene(self, point):
    """Return ``point`` converted by the inner graphics view's ``mapToScene``."""
    scene_point = self.graphics_view.mapToScene(point)
    return scene_point
155
+
156
def scale(self, sx, sy):
    """Forward the scale factors ``sx`` and ``sy`` to the inner graphics view."""
    view = self.graphics_view
    view.scale(sx, sy)
159
+
160
def translate(self, dx, dy):
    """Forward the offsets ``dx`` and ``dy`` to the inner graphics view's ``translate``."""
    view = self.graphics_view
    view.translate(dx, dy)
163
+
164
def fitInView(self, rect, aspect_ratio):
    """Forward ``rect`` and ``aspect_ratio`` to the inner graphics view's ``fitInView``."""
    view = self.graphics_view
    view.fitInView(rect, aspect_ratio)
167
+
168
def keyPressEvent(self, event):
    """Mark the selected points for deletion on Ctrl+Delete / Ctrl+Backspace.

    Any other key press, or the shortcut with nothing selected, is passed
    to the base-class handler.
    """
    is_delete_key = event.key() in (Qt.Key_Delete, Qt.Key_Backspace)
    if not (is_delete_key and event.modifiers() == Qt.ControlModifier):
        super().keyPressEvent(event)
        return

    doomed_items = self.graphics_scene.selectedItems()
    if not doomed_items:
        super().keyPressEvent(event)
        return

    print(f"Marking {len(doomed_items)} points for deletion.")

    for scene_item in doomed_items:
        if not isinstance(scene_item, EmbeddingPointItem):
            continue
        # Flag the shared data item, forget the point locally, then drop
        # it from the scene.
        scene_item.data_item.mark_for_deletion()
        self.points_by_id.pop(scene_item.data_item.annotation.id, None)
        self.graphics_scene.removeItem(scene_item)

    # Re-run the selection logic so listeners learn about the new state.
    self.on_selection_changed()
    event.accept()
204
+
205
def mousePressEvent(self, event):
    """Dispatch a mouse press to point toggling, rubber-band selection or panning.

    * Ctrl + left button on a point toggles that point's selection.
    * Ctrl + left button elsewhere starts a rubber-band rectangle.
    * Right button is replayed to the view as a left press in
      ScrollHandDrag mode.
    * Anything else is forwarded unchanged with dragging disabled.
    """
    view = self.graphics_view
    pressed_button = event.button()

    if pressed_button == Qt.LeftButton and event.modifiers() == Qt.ControlModifier:
        hit_item = view.itemAt(event.pos())
        if isinstance(hit_item, EmbeddingPointItem):
            view.setDragMode(QGraphicsView.NoDrag)
            # Flip the state on the data item and mirror it on the scene item.
            new_state = not hit_item.data_item.is_selected
            hit_item.data_item.set_selected(new_state)
            hit_item.setSelected(new_state)
            self.on_selection_changed()
            return

        # Remember what was selected so the drag can add to it.
        self.selection_at_press = set(self.graphics_scene.selectedItems())
        view.setDragMode(QGraphicsView.NoDrag)
        self.rubber_band_origin = view.mapToScene(event.pos())

        band = QGraphicsRectItem(QRectF(self.rubber_band_origin, self.rubber_band_origin))
        band.setPen(QPen(QColor(0, 100, 255), 1, Qt.DotLine))
        band.setBrush(QBrush(QColor(0, 100, 255, 50)))
        self.rubber_band = band
        self.graphics_scene.addItem(band)
        return

    if pressed_button == Qt.RightButton:
        view.setDragMode(QGraphicsView.ScrollHandDrag)
        synthetic_press = QMouseEvent(
            event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers()
        )
        QGraphicsView.mousePressEvent(view, synthetic_press)
        return

    view.setDragMode(QGraphicsView.NoDrag)
    QGraphicsView.mousePressEvent(view, event)
233
+
234
def mouseDoubleClickEvent(self, event):
    """Clear the selection and request a view reset on a left double-click.

    Other buttons are forwarded to the inner ``QGraphicsView``. This handler
    is installed on ``graphics_view``, so the fallback targets the view, as
    the press, move and release handlers do.
    """
    if event.button() == Qt.LeftButton:
        if self.graphics_scene.selectedItems():
            self.graphics_scene.clearSelection()
        self.reset_view_requested.emit()
        event.accept()
    else:
        QGraphicsView.mouseDoubleClickEvent(self.graphics_view, event)
243
+
244
def mouseMoveEvent(self, event):
    """Update the rubber-band selection, or pan / forward the move event.

    While a rubber band is active, the band is resized to the cursor and
    the scene selection becomes the items inside it plus whatever was
    selected when the drag started. A right-button drag is replayed to the
    view as a left-button drag; any other move is forwarded unchanged.
    """
    if self.rubber_band:
        cursor_scene_pos = self.graphics_view.mapToScene(event.pos())
        self.rubber_band.setRect(QRectF(self.rubber_band_origin, cursor_scene_pos).normalized())

        selection_path = QPainterPath()
        selection_path.addRect(self.rubber_band.rect())

        # Silence selectionChanged while the selection is rebuilt, then put
        # the scene's blocking state back exactly as it was, even if
        # something below raises.
        was_blocked = self.graphics_scene.blockSignals(True)
        try:
            self.graphics_scene.setSelectionArea(selection_path)
            # Keep items that were already selected when the drag began.
            if self.selection_at_press:
                for item in self.selection_at_press:
                    item.setSelected(True)
        finally:
            self.graphics_scene.blockSignals(was_blocked)

        # Signals were suppressed above, so run the selection logic by hand.
        self.on_selection_changed()
    elif event.buttons() == Qt.RightButton:
        # Replay the right-drag as a left-drag for panning.
        left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
        QGraphicsView.mouseMoveEvent(self.graphics_view, left_event)
    else:
        QGraphicsView.mouseMoveEvent(self.graphics_view, event)
270
+
271
def mouseReleaseEvent(self, event):
    """Finish a rubber-band drag, or forward the release and disable dragging."""
    if self.rubber_band:
        # End of a rubber-band drag: discard the band and the press snapshot.
        self.graphics_scene.removeItem(self.rubber_band)
        self.rubber_band = None
        self.selection_at_press = None
        return

    outgoing_event = event
    if event.button() == Qt.RightButton:
        # A right-button release is replayed to the view as a left-button release.
        outgoing_event = QMouseEvent(
            event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers()
        )

    QGraphicsView.mouseReleaseEvent(self.graphics_view, outgoing_event)
    self.graphics_view.setDragMode(QGraphicsView.NoDrag)
284
+
285
def wheelEvent(self, event):
    """Zoom the view in or out around the cursor position.

    A positive vertical wheel delta zooms in by 1.25x and a negative one
    zooms out by the inverse factor. Events with no vertical component
    (e.g. a purely horizontal scroll) leave the view untouched instead of
    being treated as a zoom-out.
    """
    vertical_delta = event.angleDelta().y()
    if vertical_delta == 0:
        return

    zoom_in_factor = 1.25
    zoom_out_factor = 1 / zoom_in_factor

    # Disable anchoring; the translate below positions the view explicitly.
    self.graphics_view.setTransformationAnchor(QGraphicsView.NoAnchor)
    self.graphics_view.setResizeAnchor(QGraphicsView.NoAnchor)

    # Scene position under the cursor before zooming.
    old_pos = self.graphics_view.mapToScene(event.pos())

    zoom_factor = zoom_in_factor if vertical_delta > 0 else zoom_out_factor
    self.graphics_view.scale(zoom_factor, zoom_factor)

    # Scene position under the cursor after zooming; shift the view by the
    # difference so the point under the cursor stays put.
    new_pos = self.graphics_view.mapToScene(event.pos())
    delta = new_pos - old_pos
    self.graphics_view.translate(delta.x(), delta.y())
309
+
310
def update_embeddings(self, data_items):
    """Rebuild the scene so it holds one ``EmbeddingPointItem`` per data item.

    Existing points are cleared first; each new point is added to the scene
    and indexed in ``points_by_id`` under its annotation id.
    """
    self.clear_points()
    scene = self.graphics_scene
    for data_item in data_items:
        new_point = EmbeddingPointItem(data_item)
        scene.addItem(new_point)
        self.points_by_id[data_item.annotation.id] = new_point
320
+
321
def refresh_points(self):
    """Re-add points for data items that are absent from the scene and not marked for deletion.

    Does nothing when there is no explorer window or it holds no data items.
    """
    window = self.explorer_window
    if not window or not window.current_data_items:
        return

    # Snapshot of the ids that already have a point in the scene.
    known_ids = set(self.points_by_id.keys())

    restored_any = False
    for data_item in self.explorer_window.current_data_items:
        if data_item.is_marked_for_deletion() or data_item.annotation.id in known_ids:
            continue
        restored_point = EmbeddingPointItem(data_item)
        self.graphics_scene.addItem(restored_point)
        self.points_by_id[data_item.annotation.id] = restored_point
        restored_any = True

    if restored_any:
        print("Refreshed embedding points to show reverted items.")
343
+
344
def clear_points(self):
    """Remove every tracked point from the scene and empty the id lookup."""
    scene = self.graphics_scene
    for tracked_point in self.points_by_id.values():
        scene.removeItem(tracked_point)
    self.points_by_id.clear()
349
+
350
def on_selection_changed(self):
    """Sync the data items with the scene selection, announce changes and restart the animation.

    When the set of selected annotation ids differs from the previously
    announced one, every tracked point's data item is updated and
    ``selection_changed`` is emitted with the new ids. Unselected points
    always get their black outline back, and the animation timer runs only
    while something is selected.
    """
    if not self.graphics_scene:
        return
    try:
        chosen_items = self.graphics_scene.selectedItems()
    except RuntimeError:
        return

    chosen_ids = {chosen.data_item.annotation.id for chosen in chosen_items}

    if chosen_ids != self.previous_selection_ids:
        # Push the new state into every tracked point's data item.
        for ann_id, tracked_point in self.points_by_id.items():
            tracked_point.data_item.set_selected(ann_id in chosen_ids)

        self.selection_changed.emit(list(chosen_ids))
        self.previous_selection_ids = chosen_ids

    # Animation handling: stop, restore plain outlines, restart if needed.
    timer = getattr(self, 'animation_timer', None)
    if timer:
        timer.stop()

    for tracked_point in self.points_by_id.values():
        if tracked_point.isSelected():
            continue
        tracked_point.setPen(QPen(QColor("black"), POINT_WIDTH))

    if chosen_items and timer:
        timer.start()
383
+
384
def animate_selection(self):
    """Advance the dashed-outline animation by one step for every selected point."""
    if not self.graphics_scene:
        return
    try:
        chosen_items = self.graphics_scene.selectedItems()
    except RuntimeError:
        return

    # The dash offset cycles through 0..19.
    self.animation_offset = (self.animation_offset + 1) % 20

    for chosen in chosen_items:
        # Outline in a darker shade of the data item's effective color.
        outline_color = chosen.data_item.effective_color.darker(150)
        dashed_pen = QPen(outline_color, POINT_WIDTH)
        dashed_pen.setStyle(Qt.CustomDashLine)
        dashed_pen.setDashPattern([1, 2])
        dashed_pen.setDashOffset(self.animation_offset)
        chosen.setPen(dashed_pen)
403
+
404
def render_selection_from_ids(self, selected_ids):
    """Apply an externally supplied selection to the points in the scene.

    Args:
        selected_ids: Iterable of annotation ids that should end up
            selected. Every other tracked point is deselected.
    """
    # Normalise once so each membership test below is O(1), whatever
    # container the caller passed (the selection signals carry lists).
    selected_lookup = set(selected_ids or ())

    # Keep the scene from emitting selectionChanged for every point touched.
    blocker = QSignalBlocker(self.graphics_scene)
    try:
        for ann_id, point in self.points_by_id.items():
            is_selected = ann_id in selected_lookup
            # 1. Update the state on the central data item.
            point.data_item.set_selected(is_selected)
            # 2. Update the selection state of the graphics item itself.
            point.setSelected(is_selected)
    finally:
        # Always release the blocker so the scene is not left muted.
        blocker.unblock()

    # Signals were blocked, so run the selection logic by hand to update
    # the animation and notify listeners.
    self.on_selection_changed()
422
+
423
def fit_view_to_points(self):
    """Fit the view to the scene items, or to a default 5000x5000 area when no points exist."""
    view = self.graphics_view
    if not self.points_by_id:
        view.fitInView(-2500, -2500, 5000, 5000, Qt.KeepAspectRatio)
        return
    view.fitInView(self.graphics_scene.itemsBoundingRect(), Qt.KeepAspectRatio)
429
+
430
+
431
+ class AnnotationViewer(QScrollArea):
432
+ """Scrollable grid widget for displaying annotation image crops with selection,
433
+ filtering, and isolation support. Acts as a controller for the widgets."""
434
+ selection_changed = pyqtSignal(list)
435
+ preview_changed = pyqtSignal(list)
436
+ reset_view_requested = pyqtSignal()
437
+
438
def __init__(self, parent=None):
    """Initialize the AnnotationViewer widget.

    Args:
        parent: Parent widget; also kept as ``explorer_window``.
    """
    super(AnnotationViewer, self).__init__(parent)
    self.explorer_window = parent

    # Widget bookkeeping.
    self.annotation_widgets_by_id = {}
    self.current_widget_size = 96

    # Selection state.
    self.selected_widgets = []
    self.last_selected_index = -1
    self.selection_at_press = set()

    # Rubber-band / drag state.
    self.rubber_band = None
    self.rubber_band_origin = None
    self.drag_threshold = 5
    self.mouse_pressed_on_widget = False

    # Label assignment bookkeeping.
    self.preview_label_assignments = {}
    self.original_label_assignments = {}

    # Isolation mode state.
    self.isolated_mode = False
    self.isolated_widgets = set()

    self.setup_ui()
457
+
458
def setup_ui(self):
    """Build the toolbar (isolate/show-all, sort, size) and the scrollable content area."""
    self.setWidgetResizable(True)
    self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
    self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)

    container = QWidget()
    container_layout = QVBoxLayout(container)
    container_layout.setContentsMargins(0, 0, 0, 0)
    container_layout.setSpacing(4)

    toolbar = QWidget()
    toolbar_row = QHBoxLayout(toolbar)
    toolbar_row.setContentsMargins(4, 2, 4, 2)

    # Isolation buttons: (attribute name, caption, tooltip, click handler).
    button_specs = (
        ("isolate_button", "Isolate Selection", "Hide all non-selected annotations", self.isolate_selection),
        ("show_all_button", "Show All", "Show all filtered annotations", self.show_all_annotations),
    )
    for attribute, caption, tooltip, handler in button_specs:
        button = QPushButton(caption)
        setattr(self, attribute, button)
        button.setToolTip(tooltip)
        button.clicked.connect(handler)
        toolbar_row.addWidget(button)

    toolbar_row.addWidget(self._create_separator())

    # Sort selector.
    toolbar_row.addWidget(QLabel("Sort By:"))
    self.sort_combo = QComboBox()
    self.sort_combo.addItems(["None", "Label", "Image"])
    self.sort_combo.currentTextChanged.connect(self.on_sort_changed)
    toolbar_row.addWidget(self.sort_combo)
    toolbar_row.addStretch()

    # Thumbnail size slider with a numeric read-out.
    toolbar_row.addWidget(QLabel("Size:"))
    slider = QSlider(Qt.Horizontal)
    slider.setMinimum(32)
    slider.setMaximum(256)
    slider.setValue(96)
    slider.setTickPosition(QSlider.TicksBelow)
    slider.setTickInterval(32)
    slider.valueChanged.connect(self.on_size_changed)
    self.size_slider = slider
    toolbar_row.addWidget(slider)

    self.size_value_label = QLabel("96")
    self.size_value_label.setMinimumWidth(30)
    toolbar_row.addWidget(self.size_value_label)
    container_layout.addWidget(toolbar)

    # Inner scroll area hosting the manually positioned annotation widgets.
    self.content_widget = QWidget()
    inner_scroll = QScrollArea()
    inner_scroll.setWidgetResizable(True)
    inner_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
    inner_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
    inner_scroll.setWidget(self.content_widget)
    container_layout.addWidget(inner_scroll)

    self.setWidget(container)
    self._update_toolbar_state()
519
+
520
@pyqtSlot()
def isolate_selection(self):
    """Enter isolated mode: hide every annotation widget that is not selected.

    Does nothing when nothing is selected or isolated mode is already on.
    """
    if self.isolated_mode or not self.selected_widgets:
        return

    self.isolated_widgets = set(self.selected_widgets)

    # Suspend repaints while widgets are hidden and the layout is reflowed.
    self.content_widget.setUpdatesEnabled(False)
    try:
        hidden_candidates = (
            w for w in self.annotation_widgets_by_id.values() if w not in self.isolated_widgets
        )
        for candidate in hidden_candidates:
            candidate.hide()
        self.isolated_mode = True
        self.recalculate_widget_positions()
    finally:
        self.content_widget.setUpdatesEnabled(True)

    self._update_toolbar_state()
    label_window = self.explorer_window.main_window.label_window
    label_window.update_annotation_count()
539
+
540
@pyqtSlot()
def show_all_annotations(self):
    """Leave isolated mode and show the annotation widgets again.

    Widgets whose data item is marked for deletion stay hidden: they were
    hidden when the deletion was requested and must not reappear here.
    Does nothing when isolated mode is off.
    """
    if not self.isolated_mode:
        return

    self.isolated_mode = False
    self.isolated_widgets.clear()

    # Suspend repaints while widgets are shown and the layout is reflowed.
    self.content_widget.setUpdatesEnabled(False)
    try:
        for widget in self.annotation_widgets_by_id.values():
            if widget.data_item.is_marked_for_deletion():
                continue
            widget.show()
        self.recalculate_widget_positions()
    finally:
        self.content_widget.setUpdatesEnabled(True)

    self._update_toolbar_state()
    self.explorer_window.main_window.label_window.update_annotation_count()
559
+
560
+ def _update_toolbar_state(self):
561
+ """Updates the toolbar buttons based on selection and isolation mode."""
562
+ selection_exists = bool(self.selected_widgets)
563
+ if self.isolated_mode:
564
+ self.isolate_button.hide()
565
+ self.show_all_button.show()
566
+ self.show_all_button.setEnabled(True)
567
+ else:
568
+ self.isolate_button.show()
569
+ self.show_all_button.hide()
570
+ self.isolate_button.setEnabled(selection_exists)
571
+
572
def _create_separator(self):
    """Return a gray "|" label used as a vertical divider in the toolbar."""
    divider = QLabel("|")
    divider.setStyleSheet("color: gray; margin: 0 5px;")
    return divider
577
+
578
def on_sort_changed(self, sort_type):
    """Reflow the widgets after the sort selection changes.

    ``sort_type`` is not used here; the layout code reads the current
    choice from the combo box itself.
    """
    self.recalculate_widget_positions()
    return None
581
+
582
+ def _get_sorted_widgets(self):
583
+ """Get widgets sorted according to the current sort setting."""
584
+ sort_type = self.sort_combo.currentText()
585
+ widgets = list(self.annotation_widgets_by_id.values())
586
+ if sort_type == "Label":
587
+ widgets.sort(key=lambda w: w.data_item.effective_label.short_label_code)
588
+ elif sort_type == "Image":
589
+ widgets.sort(key=lambda w: os.path.basename(w.data_item.annotation.image_path))
590
+ return widgets
591
+
592
+ def _group_widgets_by_sort_key(self, widgets):
593
+ """Group widgets by the current sort key."""
594
+ sort_type = self.sort_combo.currentText()
595
+ if sort_type == "None":
596
+ return [("", widgets)]
597
+ groups = []
598
+ current_group = []
599
+ current_key = None
600
+ for widget in widgets:
601
+ if sort_type == "Label":
602
+ key = widget.data_item.effective_label.short_label_code
603
+ elif sort_type == "Image":
604
+ key = os.path.basename(widget.data_item.annotation.image_path)
605
+ else:
606
+ key = ""
607
+ if current_key != key:
608
+ if current_group:
609
+ groups.append((current_key, current_group))
610
+ current_group = [widget]
611
+ current_key = key
612
+ else:
613
+ current_group.append(widget)
614
+ if current_group:
615
+ groups.append((current_key, current_group))
616
+ return groups
617
+
618
+ def _clear_separator_labels(self):
619
+ """Remove any existing group header labels."""
620
+ if hasattr(self, '_group_headers'):
621
+ for header in self._group_headers:
622
+ header.setParent(None)
623
+ header.deleteLater()
624
+ self._group_headers = []
625
+
626
def _create_group_header(self, text):
    """Create, show and register a styled header label for a widget group.

    The label is parented to the content widget, gets a fixed height of 30
    and a minimum width of the viewport width minus 20, and is appended to
    ``_group_headers``.
    """
    if not hasattr(self, '_group_headers'):
        self._group_headers = []

    header_style = (
        "QLabel {"
        " font-weight: bold;"
        " font-size: 12px;"
        " color: #555;"
        " background-color: #f0f0f0;"
        " border: 1px solid #ccc;"
        " border-radius: 3px;"
        " padding: 5px 8px;"
        " margin: 2px 0px;"
        " }"
    )

    group_label = QLabel(text, self.content_widget)
    group_label.setStyleSheet(header_style)
    group_label.setFixedHeight(30)
    group_label.setMinimumWidth(self.viewport().width() - 20)
    group_label.show()

    self._group_headers.append(group_label)
    return group_label
648
+
649
def on_size_changed(self, value):
    """Resize every annotation widget to the slider value and reflow the grid.

    Args:
        value: Requested widget size. Odd values are rounded down to the
            nearest even number before being applied.
    """
    if value % 2 != 0:
        value -= 1

    self.current_widget_size = value
    self.size_value_label.setText(str(value))

    # Suspend repaints during the bulk resize. The try/finally guarantees
    # they are switched back on even if a widget fails to resize, matching
    # the isolate/show-all handlers.
    self.content_widget.setUpdatesEnabled(False)
    try:
        for widget in self.annotation_widgets_by_id.values():
            widget.update_height(value)
    finally:
        self.content_widget.setUpdatesEnabled(True)

    self.recalculate_widget_positions()
663
+
664
def recalculate_widget_positions(self):
    """Lay out the visible widgets left-to-right with wrapping, under per-group headers.

    Hidden widgets are skipped. When sorting is active each non-empty
    group key gets a header on its own row. The content widget's minimum
    size is finally set to the viewport width and the total height used.
    """
    if not self.annotation_widgets_by_id:
        self.content_widget.setMinimumSize(1, 1)
        return

    self._clear_separator_labels()
    shown = [w for w in self._get_sorted_widgets() if not w.isHidden()]
    if not shown:
        self.content_widget.setMinimumSize(1, 1)
        return

    grouped = self._group_widgets_by_sort_key(shown)
    gap = max(5, int(self.current_widget_size * 0.08))
    row_limit = self.viewport().width()

    cursor_x = gap
    cursor_y = gap
    row_height = 0  # tallest widget placed so far on the current row

    for title, members in grouped:
        if title and self.sort_combo.currentText() != "None":
            # Finish a partially filled row before placing the header.
            if cursor_x > gap:
                cursor_x = gap
                cursor_y += row_height + gap
                row_height = 0
            header = self._create_group_header(title)
            header.move(cursor_x, cursor_y)
            cursor_y += header.height() + gap
            cursor_x = gap
            row_height = 0

        for member in members:
            footprint = member.size()
            overflows = cursor_x + footprint.width() > row_limit
            # Wrap only if something is already on this row.
            if cursor_x > gap and overflows:
                cursor_x = gap
                cursor_y += row_height + gap
                row_height = 0
            member.move(cursor_x, cursor_y)
            cursor_x += footprint.width() + gap
            row_height = max(row_height, footprint.height())

    self.content_widget.setMinimumSize(row_limit, cursor_y + row_height + gap)
708
+
709
def update_annotations(self, data_items):
    """Replace the displayed widgets with fresh ones built from ``data_items``.

    Isolated mode is left first, existing widgets are detached and
    scheduled for deletion, the selection state is reset, and one
    ``AnnotationImageWidget`` is created and shown per data item.
    """
    if self.isolated_mode:
        self.show_all_annotations()

    # Tear down the widgets from the previous call.
    for old_widget in self.annotation_widgets_by_id.values():
        old_widget.setParent(None)
        old_widget.deleteLater()
    self.annotation_widgets_by_id.clear()

    # Nothing can remain selected once the widgets are gone.
    self.selected_widgets.clear()
    self.last_selected_index = -1

    for item in data_items:
        fresh_widget = AnnotationImageWidget(item, self.current_widget_size, self, self.content_widget)
        fresh_widget.show()
        self.annotation_widgets_by_id[item.annotation.id] = fresh_widget

    self.recalculate_widget_positions()
    self._update_toolbar_state()
731
+
732
def resizeEvent(self, event):
    """Reflow the annotation widgets 100 ms after a resize event.

    A single-shot timer, created on first use, is restarted on every call,
    so a reflow runs only once the timer is allowed to expire.
    """
    super().resizeEvent(event)

    timer = getattr(self, '_resize_timer', None)
    if timer is None:
        timer = QTimer(self)
        self._resize_timer = timer
        timer.setSingleShot(True)
        timer.timeout.connect(self.recalculate_widget_positions)

    self._resize_timer.start(100)
740
+
741
def keyPressEvent(self, event):
    """Mark the selected annotations for deletion on Ctrl+Delete / Ctrl+Backspace.

    Each selected widget's data item is marked for deletion and the widget
    is hidden. The selection is then cleared, the toolbar is refreshed to
    reflect the empty selection, the grid is reflowed, and
    ``selection_changed`` is emitted with the ids of the affected
    annotations. Any other key press, or the shortcut with nothing
    selected, is passed to the base-class handler.
    """
    is_delete_key = event.key() in (Qt.Key_Delete, Qt.Key_Backspace)
    if not (is_delete_key and event.modifiers() == Qt.ControlModifier):
        super().keyPressEvent(event)
        return

    if not self.selected_widgets:
        super().keyPressEvent(event)
        return

    print(f"Marking {len(self.selected_widgets)} annotations for deletion.")

    # Ids of the annotations touched by this deletion.
    changed_ids = []
    for widget in self.selected_widgets:
        widget.data_item.mark_for_deletion()
        widget.hide()
        changed_ids.append(widget.data_item.annotation.id)

    self.selected_widgets.clear()

    # The selection is now empty, so the "Isolate Selection" button must
    # not stay enabled.
    self._update_toolbar_state()

    # Close the gaps left by the hidden widgets.
    self.recalculate_widget_positions()

    # Tell listeners which annotations changed.
    if changed_ids:
        self.selection_changed.emit(changed_ids)

    event.accept()
777
+
778
def mousePressEvent(self, event):
    """Handle a mouse press: clear the selection or arm a rubber-band drag.

    * Plain left click outside any annotation widget, with a selection
      present: the selection is cleared and ``selection_changed`` is emitted.
    * Ctrl+left click: the current selection and press position are recorded
      so a later drag can drive the rubber band; whether the press landed on
      an annotation widget is remembered too.
    * Right click: ignored.
    * Everything else falls through to the base implementation.
    """
    def lands_on_annotation_widget(pos):
        # Walk up from the child under ``pos`` until a widget that belongs to
        # this viewer is found, or the viewer itself / the root is reached.
        node = self.childAt(pos)
        while node and node != self:
            if hasattr(node, 'annotation_viewer') and node.annotation_viewer == self:
                return True
            node = node.parent()
        return False

    pressed = event.button()

    if pressed == Qt.RightButton:
        event.ignore()
        return

    if pressed == Qt.LeftButton:
        if not event.modifiers():
            hit_widget = lands_on_annotation_widget(event.pos())
            if not hit_widget and self.selected_widgets:
                cleared_ids = [w.data_item.annotation.id for w in self.selected_widgets]
                self.clear_selection()
                self.selection_changed.emit(cleared_ids)
                return

        elif event.modifiers() == Qt.ControlModifier:
            # Remember what was selected before the drag so the rubber band
            # adds to it.
            self.selection_at_press = set(self.selected_widgets)
            self.rubber_band_origin = event.pos()
            self.mouse_pressed_on_widget = lands_on_annotation_widget(event.pos())
            return

    # Default handler for other cases.
    super().mousePressEvent(event)
825
+
826
def mouseDoubleClickEvent(self, event):
    """Reset the view on a left double-click.

    Clears any selection (emitting ``selection_changed`` with the ids that
    were selected), leaves isolation mode if it is active, and emits
    ``reset_view_requested``. Double-clicks with other buttons go to the base
    class.
    """
    if event.button() != Qt.LeftButton:
        super().mouseDoubleClickEvent(event)
        return

    if self.selected_widgets:
        cleared_ids = [w.data_item.annotation.id for w in self.selected_widgets]
        self.clear_selection()
        self.selection_changed.emit(cleared_ids)

    if self.isolated_mode:
        self.show_all_annotations()

    self.reset_view_requested.emit()
    event.accept()
840
+
841
def mouseMoveEvent(self, event):
    """Update the rubber band and the selection while Ctrl+left dragging.

    The drag is only honoured when it was armed by ``mousePressEvent`` (origin
    set, press not on a widget) and has travelled at least
    ``self.drag_threshold``. On every move, each annotation widget is selected
    if it was selected at press time or intersects the band, and deselected
    otherwise; ``selection_changed`` is emitted with the ids whose state
    flipped.
    """
    drag_armed = (
        self.rubber_band_origin is not None
        and event.buttons() == Qt.LeftButton
        and event.modifiers() == Qt.ControlModifier
    )
    # Not a rubber-band drag, or the drag began on a widget: default handling.
    if not drag_armed or self.mouse_pressed_on_widget:
        super().mouseMoveEvent(event)
        return

    # Ignore jitter below the drag threshold.
    travelled = (event.pos() - self.rubber_band_origin).manhattanLength()
    if travelled < self.drag_threshold:
        return

    viewport = self.viewport()

    # Create the rubber band on first use.
    if not self.rubber_band:
        self.rubber_band = QRubberBand(QRubberBand.Rectangle, viewport)

    band = self.rubber_band
    band.setGeometry(QRect(self.rubber_band_origin, event.pos()).normalized())
    band.show()
    band_rect = band.geometry()

    canvas = self.content_widget
    changed_ids = []

    for tile in self.annotation_widgets_by_id.values():
        # Express the widget's rectangle in viewport coordinates so it can be
        # compared with the band.
        tile_geometry = tile.geometry()
        tile_in_viewport = QRect(
            canvas.mapTo(viewport, tile_geometry.topLeft()),
            tile_geometry.size(),
        )
        touched_by_band = band_rect.intersects(tile_in_viewport)
        wanted = (tile in self.selection_at_press) or touched_by_band

        if wanted == tile.is_selected():
            continue  # already in the desired state

        toggle = self.select_widget if wanted else self.deselect_widget
        if toggle(tile):
            changed_ids.append(tile.data_item.annotation.id)

    # Emit signal if any selection state changed.
    if changed_ids:
        self.selection_changed.emit(changed_ids)
896
+
897
def mouseReleaseEvent(self, event):
    """Finish a rubber-band drag when the left button is released.

    A visible rubber band is hidden and scheduled for deletion, and the
    per-drag state (``selection_at_press``, ``rubber_band_origin``,
    ``mouse_pressed_on_widget``) is reset. Releases that do not end an armed
    drag are forwarded to the base class.
    """
    if self.rubber_band_origin is None or event.button() != Qt.LeftButton:
        super().mouseReleaseEvent(event)
        return

    band = self.rubber_band
    if band and band.isVisible():
        band.hide()
        band.deleteLater()
        self.rubber_band = None

    # Reset everything mousePressEvent armed for this drag.
    self.selection_at_press = set()
    self.rubber_band_origin = None
    self.mouse_pressed_on_widget = False
    event.accept()
912
+
913
def handle_annotation_selection(self, widget, event):
    """Apply a click on ``widget`` to the selection, honouring modifier keys.

    * Shift or Shift+Ctrl: select the range between the clicked widget and the
      selected widget that sits furthest along the visible, sorted order. If
      nothing was selected before, just the clicked widget is selected.
    * Ctrl: toggle the clicked widget.
    * Anything else: make the clicked widget the only selection.

    In isolation mode the isolated set is refreshed afterwards, and
    ``selection_changed`` is emitted with the ids whose state changed.

    Args:
        widget: The annotation widget that was clicked.
        event: The mouse event; only ``event.modifiers()`` is read.
    """
    visible_widgets = [w for w in self._get_sorted_widgets() if not w.isHidden()]

    # One pass builds the position lookup, replacing repeated list.index()
    # scans (widgets are hashable; they are stored in sets elsewhere).
    position_of = {w: i for i, w in enumerate(visible_widgets)}

    widget_index = position_of.get(widget)
    if widget_index is None:
        # The clicked widget is hidden or unknown: nothing to do.
        return

    modifiers = event.modifiers()
    changed_ids = []

    # Shift or Shift+Ctrl: range selection
    if modifiers == Qt.ShiftModifier or modifiers == (Qt.ShiftModifier | Qt.ControlModifier):
        if self.last_selected_index != -1:
            # Anchor = the visible selected widget with the largest position.
            anchor_index = max(
                (position_of[w] for w in self.selected_widgets if w in position_of),
                default=None,
            )
            if anchor_index is None:
                start = end = widget_index
            else:
                start = min(anchor_index, widget_index)
                end = max(anchor_index, widget_index)

            # Select all widgets in the range
            for candidate in visible_widgets[start:end + 1]:
                if self.select_widget(candidate):
                    changed_ids.append(candidate.data_item.annotation.id)
        else:
            # No previous selection, just select the clicked widget
            if self.select_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
            self.last_selected_index = widget_index

    # Ctrl: toggle selection of the clicked widget
    elif modifiers == Qt.ControlModifier:
        if widget.is_selected():
            if self.deselect_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
        else:
            if self.select_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
        self.last_selected_index = widget_index

    # No modifier: single selection
    else:
        newly_selected_id = widget.data_item.annotation.id

        # Deselect all others (iterate a copy: deselect_widget mutates the list)
        for other in list(self.selected_widgets):
            if other.data_item.annotation.id != newly_selected_id:
                if self.deselect_widget(other):
                    changed_ids.append(other.data_item.annotation.id)

        # Select the clicked widget
        if self.select_widget(widget):
            changed_ids.append(newly_selected_id)
        self.last_selected_index = widget_index

    # If in isolated mode, update which widgets are visible
    if self.isolated_mode:
        self._update_isolation()

    # Emit signal if any selection state changed
    if changed_ids:
        self.selection_changed.emit(changed_ids)
991
+
992
+ def _update_isolation(self):
993
+ """Update the isolated view to show only currently selected widgets."""
994
+ if not self.isolated_mode:
995
+ return
996
+ # If in isolated mode, only show selected widgets
997
+ if self.selected_widgets:
998
+ self.isolated_widgets.update(self.selected_widgets)
999
+ self.setUpdatesEnabled(False)
1000
+ try:
1001
+ for widget in self.annotation_widgets_by_id.values():
1002
+ if widget not in self.isolated_widgets:
1003
+ widget.hide()
1004
+ else:
1005
+ widget.show()
1006
+ self.recalculate_widget_positions()
1007
+
1008
+ finally:
1009
+ self.setUpdatesEnabled(True)
1010
+
1011
def select_widget(self, widget):
    """Select ``widget`` if it is not selected yet.

    The selected flag is written to the widget's data item, the widget is
    asked to refresh its selection visuals, the widget is appended to
    ``selected_widgets`` and the toolbar state is refreshed.

    Returns:
        True if the widget's state changed, False if it was already selected.
    """
    if widget.is_selected():
        return False

    # State lives on the data item; the widget only mirrors it visually.
    widget.data_item.set_selected(True)
    widget.update_selection_visuals()

    self.selected_widgets.append(widget)
    self._update_toolbar_state()
    return True
1022
+
1023
def deselect_widget(self, widget):
    """Deselect ``widget`` if it is currently selected.

    The selected flag is cleared on the widget's data item, the widget is
    asked to refresh its selection visuals, the widget is removed from
    ``selected_widgets`` when present, and the toolbar state is refreshed.

    Returns:
        True if the widget's state changed, False if it was not selected.
    """
    if not widget.is_selected():
        return False

    # State lives on the data item; the widget only mirrors it visually.
    widget.data_item.set_selected(False)
    widget.update_selection_visuals()

    # The list may be out of sync with the data item, so removal is guarded.
    if widget in self.selected_widgets:
        self.selected_widgets.remove(widget)

    self._update_toolbar_state()
    return True
1035
+
1036
def clear_selection(self):
    """Deselect every selected widget and refresh the toolbar state."""
    # Work on a snapshot: deselect_widget() removes entries from the live list.
    for selected in tuple(self.selected_widgets):
        self.deselect_widget(selected)

    # Drop anything deselect_widget() left behind, then sync the toolbar.
    self.selected_widgets.clear()
    self._update_toolbar_state()
1044
+
1045
def get_selected_annotations(self):
    """Return the ``annotation`` of every selected widget, in selection order."""
    annotations = []
    for selected in self.selected_widgets:
        annotations.append(selected.annotation)
    return annotations
1048
+
1049
def render_selection_from_ids(self, selected_ids):
    """Make the viewer's selection match ``selected_ids``.

    Every widget's data item is flagged selected or not according to
    membership of its annotation id in ``selected_ids`` and its visuals are
    refreshed; ``selected_widgets`` is then rebuilt from the widgets that
    report themselves selected. In isolation mode with a non-empty selection,
    the selection is added to the isolated set and widgets outside that set
    are hidden before the grid is reflowed.

    Args:
        selected_ids: Container of annotation ids supporting ``in``.
    """
    widgets_by_id = self.annotation_widgets_by_id

    self.setUpdatesEnabled(False)
    try:
        for annotation_id, tile in widgets_by_id.items():
            # Update the state on the data item first, then the visuals.
            tile.data_item.set_selected(annotation_id in selected_ids)
            tile.update_selection_visuals()

        # Rebuild the internal list from the data items' state.
        self.selected_widgets = [t for t in widgets_by_id.values() if t.is_selected()]

        if self.isolated_mode and self.selected_widgets:
            self.isolated_widgets.update(self.selected_widgets)
            for tile in widgets_by_id.values():
                tile.setHidden(tile not in self.isolated_widgets)
            self.recalculate_widget_positions()
    finally:
        # Repainting is restored even if something above raised.
        self.setUpdatesEnabled(True)

    self._update_toolbar_state()
1071
+
1072
def apply_preview_label_to_selected(self, preview_label):
    """Give every selected widget's data item ``preview_label`` as a preview.

    Each affected widget is repainted, the grid is reflowed when the current
    sort mode is "Label", and ``preview_changed`` is emitted with the ids of
    the affected annotations. Nothing happens when there is no selection or
    ``preview_label`` is falsy.
    """
    if not (self.selected_widgets and preview_label):
        return

    changed_ids = []
    for selected in self.selected_widgets:
        selected.data_item.set_preview_label(preview_label)
        selected.update()  # force a repaint so the new label shows
        changed_ids.append(selected.data_item.annotation.id)

    # Widget order depends on the label only when sorting by label.
    sorting_by_label = self.sort_combo.currentText() == "Label"
    if sorting_by_label:
        self.recalculate_widget_positions()

    if changed_ids:
        self.preview_changed.emit(changed_ids)
1086
+
1087
def clear_preview_states(self):
    """Revert all pending preview state on the displayed annotations.

    Preview labels are cleared (and the widget repainted), and items marked
    for deletion are un-marked and their widgets shown again. The grid is
    reflowed when:

    * any widget was restored from "marked for deletion" — it was hidden and
      the remaining widgets were reflowed without it, so it has to be given a
      slot again regardless of the sort mode; or
    * a preview label was reverted while sorting by "Label" or "Image".
    """
    labels_reverted = False
    widgets_restored = False

    for widget in self.annotation_widgets_by_id.values():
        item = widget.data_item

        # Check for and clear preview labels
        if item.has_preview_changes():
            item.clear_preview_label()
            widget.update()  # Repaint to show original color
            labels_reverted = True

        # Check for and un-mark items for deletion
        if item.is_marked_for_deletion():
            item.unmark_for_deletion()
            widget.show()  # Make the widget visible again
            widgets_restored = True

    needs_reflow = widgets_restored or (
        labels_reverted and self.sort_combo.currentText() in ("Label", "Image")
    )
    if needs_reflow:
        self.recalculate_widget_positions()
1110
+
1111
def has_preview_changes(self):
    """Return True if at least one displayed data item reports preview changes."""
    for tile in self.annotation_widgets_by_id.values():
        if tile.data_item.has_preview_changes():
            return True
    return False
1114
+
1115
def get_preview_changes_summary(self):
    """Return a one-line, human-readable count of items with preview changes."""
    pending = [
        tile for tile in self.annotation_widgets_by_id.values()
        if tile.data_item.has_preview_changes()
    ]
    if not pending:
        return "No preview changes"
    return f"{len(pending)} annotation(s) with preview changes"
1119
+
1120
def apply_preview_changes_permanently(self):
    """Commit every pending preview and return the annotations that changed.

    ``apply_preview_permanently()`` is called on each displayed data item; the
    ``annotation`` of every widget for which it returned a truthy value is
    collected, in display-dictionary order.
    """
    return [
        tile.annotation
        for tile in self.annotation_widgets_by_id.values()
        if tile.data_item.apply_preview_permanently()
    ]
1127
+
1128
+
1129
+ # ----------------------------------------------------------------------------------------------------------------------
1130
+ # ExplorerWindow
1131
+ # ----------------------------------------------------------------------------------------------------------------------
1132
+
1133
+
1134
+ class ExplorerWindow(QMainWindow):
1135
def __init__(self, main_window, parent=None):
    """Create the Explorer window shell; the heavy UI is built on first show.

    Args:
        main_window: The application main window; its ``image_window``,
            ``label_window``, ``annotation_window`` and ``device`` attributes
            are read here.
        parent: Optional Qt parent widget.
    """
    super(ExplorerWindow, self).__init__(parent)

    def make_button(caption, slot, tooltip, enabled=True):
        # Small factory so the three footer buttons are built the same way.
        button = QPushButton(caption, self)
        button.clicked.connect(slot)
        button.setToolTip(tooltip)
        if not enabled:
            button.setEnabled(False)
        return button

    # Handles onto the rest of the application.
    self.main_window = main_window
    self.image_window = main_window.image_window
    self.label_window = main_window.label_window
    self.annotation_window = main_window.annotation_window

    # Feature-extraction state.
    self.device = main_window.device
    self.model_path = ""
    self.loaded_model = None
    self.current_data_items = []
    self.current_features = None
    self.current_feature_generating_model = ""
    self._ui_initialized = False

    # Window chrome.
    self.setWindowTitle("Explorer")
    explorer_icon_path = get_icon("magic.png")
    self.setWindowIcon(QIcon(explorer_icon_path))

    # Layout skeleton.
    self.central_widget = QWidget()
    self.setCentralWidget(self.central_widget)
    self.main_layout = QVBoxLayout(self.central_widget)
    self.left_panel = QWidget()
    self.left_layout = QVBoxLayout(self.left_panel)

    # Settings panels and viewers are created lazily in setup_ui().
    self.annotation_settings_widget = None
    self.model_settings_widget = None
    self.embedding_settings_widget = None
    self.annotation_viewer = None
    self.embedding_viewer = None

    # Footer buttons; "Clear Preview" and "Apply" start disabled.
    self.clear_preview_button = make_button(
        'Clear Preview',
        self.clear_preview_changes,
        "Clear all preview changes and revert to original labels",
        enabled=False,
    )
    self.exit_button = make_button('Exit', self.close, "Close the window")
    self.apply_button = make_button('Apply', self.apply, "Apply changes", enabled=False)
1180
+
1181
def showEvent(self, event):
    """Build the UI the first time the window is shown, then show as usual."""
    needs_setup = not self._ui_initialized
    if needs_setup:
        self.setup_ui()
        self._ui_initialized = True

    super(ExplorerWindow, self).showEvent(event)
1187
+
1188
def closeEvent(self, event):
    """Tear the Explorer down and hand control back to the main window.

    Stops the embedding viewer's animation timer (when present), reverts
    pending preview changes, releases resources, re-enables the main window,
    notifies it through ``explorer_closed()`` when it offers that method,
    drops the main window's reference to this window, and marks the UI as
    uninitialised before accepting the event.
    """
    # Stop any running timer first so it cannot fire during teardown.
    viewer = getattr(self, 'embedding_viewer', None)
    animation_timer = getattr(viewer, 'animation_timer', None) if viewer else None
    if animation_timer:
        animation_timer.stop()

    # Revert pending changes, then release resources.
    self.clear_preview_changes()
    self._cleanup_resources()

    main_window = self.main_window

    # Re-enable the main window before closing.
    if main_window:
        main_window.setEnabled(True)

    # Let the main window take the label window back.
    if hasattr(main_window, 'explorer_closed'):
        main_window.explorer_closed()

    # Drop the back-reference so this window can be garbage collected.
    main_window.explorer_window = None

    # Allow the UI to be rebuilt on the next show.
    self._ui_initialized = False

    event.accept()
1216
+
1217
def setup_ui(self):
    """Assemble the Explorer layout and wire up its signals.

    The main layout is emptied first, the settings panels and the two viewers
    are created if they do not exist yet, and the window is laid out top to
    bottom as: settings row, viewer splitter, label window, button row.
    """
    layout = self.main_layout

    # Empty the main layout, un-parenting any widget it still holds.
    while layout.count():
        entry = layout.takeAt(0)
        if entry.widget():
            entry.widget().setParent(None)

    # Lazily create the panels and viewers (at most once each).
    if self.annotation_settings_widget is None:
        # Filters by image, type, label.
        self.annotation_settings_widget = AnnotationSettingsWidget(self.main_window, self)
    if self.model_settings_widget is None:
        # Chooses the feature extraction model.
        self.model_settings_widget = ModelSettingsWidget(self.main_window, self)
    if self.embedding_settings_widget is None:
        # Chooses the dimensionality reduction method.
        self.embedding_settings_widget = EmbeddingSettingsWidget(self.main_window, self)
    if self.annotation_viewer is None:
        # Grid of annotation image crops.
        self.annotation_viewer = AnnotationViewer(self)
    if self.embedding_viewer is None:
        # 2D embedding scatter plot.
        self.embedding_viewer = EmbeddingViewer(self)

    # --- Settings row -------------------------------------------------------
    settings_row = QHBoxLayout()
    for panel, stretch in (
        (self.annotation_settings_widget, 2),
        (self.model_settings_widget, 1),
        (self.embedding_settings_widget, 1),
    ):
        settings_row.addWidget(panel, stretch)
    settings_container = QWidget()
    settings_container.setLayout(settings_row)
    layout.addWidget(settings_container)

    # --- Viewers, side by side ----------------------------------------------
    splitter = QSplitter(Qt.Horizontal)

    annotation_box = QGroupBox("Annotation Viewer")
    QVBoxLayout(annotation_box).addWidget(self.annotation_viewer)
    splitter.addWidget(annotation_box)

    embedding_box = QGroupBox("Embedding Viewer")
    QVBoxLayout(embedding_box).addWidget(self.embedding_viewer)
    splitter.addWidget(embedding_box)

    splitter.setSizes([500, 500])
    layout.addWidget(splitter, 1)
    layout.addWidget(self.label_window)

    # --- Button row ---------------------------------------------------------
    self.buttons_layout = QHBoxLayout()
    self.buttons_layout.addStretch(1)
    for button in (self.clear_preview_button, self.exit_button, self.apply_button):
        self.buttons_layout.addWidget(button)
    layout.addLayout(self.buttons_layout)

    self.annotation_settings_widget.set_default_to_current_image()
    self.refresh_filters()

    # Drop a previous connection, if any, so the slot is not connected twice.
    label_selected = self.label_window.labelSelected
    try:
        label_selected.disconnect(self.on_label_selected_for_preview)
    except TypeError:
        pass

    # Connect signals to slots.
    label_selected.connect(self.on_label_selected_for_preview)

    annotation_viewer = self.annotation_viewer
    annotation_viewer.selection_changed.connect(self.on_annotation_view_selection_changed)
    annotation_viewer.preview_changed.connect(self.on_preview_changed)
    annotation_viewer.reset_view_requested.connect(self.on_reset_view_requested)

    embedding_viewer = self.embedding_viewer
    embedding_viewer.selection_changed.connect(self.on_embedding_view_selection_changed)
    embedding_viewer.reset_view_requested.connect(self.on_reset_view_requested)
1291
+
1292
@pyqtSlot(list)
def on_annotation_view_selection_changed(self, changed_ann_ids):
    """Mirror the annotation viewer's selection into the embedding viewer.

    ``changed_ann_ids`` is not used; the full selection is read back from the
    annotation viewer instead.
    """
    selected_ids = set()
    for tile in self.annotation_viewer.selected_widgets:
        selected_ids.add(tile.data_item.annotation.id)

    # Only forward the selection when the embedding viewer has points.
    embedding_viewer = self.embedding_viewer
    if embedding_viewer.points_by_id:
        embedding_viewer.render_selection_from_ids(selected_ids)

    self.update_label_window_selection()
1301
+
1302
@pyqtSlot(list)
def on_embedding_view_selection_changed(self, all_selected_ann_ids):
    """Mirror the embedding viewer's selection into the annotation viewer.

    When the annotation viewer had nothing selected before, now receives a
    non-empty selection, and is not already in isolation mode, it is switched
    to isolating that selection.
    """
    viewer = self.annotation_viewer

    # Must be sampled BEFORE the selection is pushed into the viewer.
    started_empty = len(viewer.selected_widgets) == 0

    viewer.render_selection_from_ids(set(all_selected_ann_ids))

    has_selection_now = len(all_selected_ann_ids) > 0
    if started_empty and has_selection_now and not viewer.isolated_mode:
        viewer.isolate_selection()

    self.update_label_window_selection()
1321
+
1322
@pyqtSlot(list)
def on_preview_changed(self, changed_ann_ids):
    """Repaint the embedding points of annotations whose preview label changed."""
    points_by_id = self.embedding_viewer.points_by_id
    for annotation_id in changed_ann_ids:
        point = points_by_id.get(annotation_id)
        # Ids without a point are skipped.
        if point:
            point.update()
1329
+
1330
@pyqtSlot()
def on_reset_view_requested(self):
    """Reset both viewers after a double-click in either of them.

    Clears the selection in both viewers, leaves isolation mode when active,
    then refreshes the label window and the button states.
    """
    annotation_viewer = self.annotation_viewer

    # Clear all selections in both viewers.
    annotation_viewer.clear_selection()
    self.embedding_viewer.render_selection_from_ids(set())

    # Exit isolation mode if currently active.
    if annotation_viewer.isolated_mode:
        annotation_viewer.show_all_annotations()

    self.update_label_window_selection()
    self.update_button_states()

    print("Reset view: cleared selections and exited isolation mode")
1345
+
1346
def update_label_window_selection(self):
    """Sync the label window with the labels of the selected data items.

    The selection is read from ``current_data_items``. When all selected items
    share one effective label, that label becomes the active label and
    ``annotation_window.labelSelected`` is emitted with its id; otherwise
    (nothing selected, or mixed labels) the active label is deselected. The
    annotation count is refreshed in every case.
    """
    label_window = self.label_window
    chosen = [item for item in self.current_data_items if item.is_selected]

    shared_label = None
    if chosen:
        candidate = chosen[0].effective_label
        if all(item.effective_label.id == candidate.id for item in chosen):
            shared_label = candidate

    if shared_label is not None:
        label_window.set_active_label(shared_label)
        # This emit is what updates other UI elements, like the annotation list.
        self.annotation_window.labelSelected.emit(shared_label.id)
    else:
        label_window.deselect_active_label()

    label_window.update_annotation_count()
1375
+
1376
def get_filtered_data_items(self):
    """Return data items for the annotations that pass every active filter.

    An annotation passes when the base name of its image path, its class name
    and its label's short code are all among the values selected in the
    annotation settings widget. Cropped images are ensured for the matching
    annotations before they are wrapped in ``AnnotationDataItem``.

    Returns:
        A list of ``AnnotationDataItem``; empty when the annotation window has
        no ``annotations_dict`` or when any of the three filters is empty.
    """
    annotation_window = self.main_window.annotation_window
    if not hasattr(annotation_window, 'annotations_dict'):
        return []

    settings = self.annotation_settings_widget
    selected_images = settings.get_selected_images()
    selected_types = settings.get_selected_annotation_types()
    selected_labels = settings.get_selected_labels()

    # An empty filter means nothing can match.
    if not all([selected_images, selected_types, selected_labels]):
        return []

    # Build the lookup sets once; the membership tests below run for every
    # annotation in the project.
    image_names = set(selected_images)
    type_names = set(selected_types)
    label_codes = set(selected_labels)

    annotations_to_process = [
        ann for ann in annotation_window.annotations_dict.values()
        if (os.path.basename(ann.image_path) in image_names and
            type(ann).__name__ in type_names and
            ann.label.short_label_code in label_codes)
    ]

    self._ensure_cropped_images(annotations_to_process)
    return [AnnotationDataItem(ann) for ann in annotations_to_process]
1398
+
1399
+ def _ensure_cropped_images(self, annotations):
1400
+ """Ensures all provided annotations have a cropped image available."""
1401
+ annotations_by_image = {}
1402
+
1403
+ for annotation in annotations:
1404
+ if not annotation.cropped_image:
1405
+ image_path = annotation.image_path
1406
+ if image_path not in annotations_by_image:
1407
+ annotations_by_image[image_path] = []
1408
+ annotations_by_image[image_path].append(annotation)
1409
+
1410
+ if not annotations_by_image:
1411
+ return
1412
+
1413
+ progress_bar = ProgressBar(self, "Cropping Image Annotations")
1414
+ progress_bar.show()
1415
+ progress_bar.start_progress(len(annotations_by_image))
1416
+
1417
+ try:
1418
+ for image_path, image_annotations in annotations_by_image.items():
1419
+ self.annotation_window.crop_annotations(image_path=image_path,
1420
+ annotations=image_annotations,
1421
+ return_annotations=False,
1422
+ verbose=False)
1423
+ progress_bar.update_progress()
1424
+ finally:
1425
+ progress_bar.finish_progress()
1426
+ progress_bar.stop_progress()
1427
+ progress_bar.close()
1428
+
1429
+ def _extract_color_features(self, data_items, progress_bar=None, bins=32):
1430
+ """
1431
+ Extracts color-based features from annotation crops.
1432
+
1433
+ Features extracted per annotation:
1434
+ - Mean, standard deviation, skewness, and kurtosis for each RGB channel
1435
+ - Normalized histogram for each RGB channel
1436
+ - Grayscale statistics: mean, std, range
1437
+ - Geometric features: area, perimeter (if available)
1438
+ Returns:
1439
+ features: np.ndarray of shape (N, feature_dim)
1440
+ valid_data_items: list of AnnotationDataItem with valid crops
1441
+ """
1442
+ if progress_bar:
1443
+ progress_bar.set_title("Extracting features...")
1444
+ progress_bar.start_progress(len(data_items))
1445
+
1446
+ features = []
1447
+ valid_data_items = []
1448
+
1449
+ for item in data_items:
1450
+ pixmap = item.annotation.get_cropped_image()
1451
+ if pixmap and not pixmap.isNull():
1452
+ # Convert QPixmap to numpy array (H, W, 3)
1453
+ arr = pixmap_to_numpy(pixmap)
1454
+ pixels = arr.reshape(-1, 3)
1455
+
1456
+ # Basic color statistics
1457
+ mean_color = np.mean(pixels, axis=0)
1458
+ std_color = np.std(pixels, axis=0)
1459
+
1460
+ # Skewness and kurtosis for each channel
1461
+ epsilon = 1e-8 # Prevent division by zero
1462
+ centered_pixels = pixels - mean_color
1463
+ skew_color = np.mean(centered_pixels ** 3, axis=0) / (std_color ** 3 + epsilon)
1464
+ kurt_color = np.mean(centered_pixels ** 4, axis=0) / (std_color ** 4 + epsilon) - 3
1465
+
1466
+ # Normalized histograms for each channel
1467
+ histograms = [
1468
+ np.histogram(pixels[:, i], bins=bins, range=(0, 255))[0]
1469
+ for i in range(3)
1470
+ ]
1471
+ histograms = [
1472
+ h / h.sum() if h.sum() > 0 else np.zeros(bins)
1473
+ for h in histograms
1474
+ ]
1475
+
1476
+ # Grayscale statistics
1477
+ gray_arr = np.dot(arr[..., :3], [0.2989, 0.5870, 0.1140])
1478
+ grayscale_stats = np.array([
1479
+ np.mean(gray_arr),
1480
+ np.std(gray_arr),
1481
+ np.ptp(gray_arr)
1482
+ ])
1483
+
1484
+ # Geometric features (area, perimeter)
1485
+ area = getattr(item.annotation, 'area', 0.0)
1486
+ perimeter = getattr(item.annotation, 'perimeter', 0.0)
1487
+ geometric_features = np.array([area, perimeter])
1488
+
1489
+ # Concatenate all features into a single vector
1490
+ current_features = np.concatenate([
1491
+ mean_color,
1492
+ std_color,
1493
+ skew_color,
1494
+ kurt_color,
1495
+ *histograms,
1496
+ grayscale_stats,
1497
+ geometric_features
1498
+ ])
1499
+
1500
+ features.append(current_features)
1501
+ valid_data_items.append(item)
1502
+
1503
+ if progress_bar:
1504
+ progress_bar.update_progress()
1505
+
1506
+ return np.array(features), valid_data_items
1507
+
1508
+ def _extract_yolo_features(self, data_items, model_info, progress_bar=None):
1509
+ """Extracts features from annotation crops using a YOLO model."""
1510
+ # Extract model name and feature mode from the provided model_info tuple
1511
+ model_name, feature_mode = model_info
1512
+
1513
+ if model_name != self.model_path or self.loaded_model is None:
1514
+ try:
1515
+ self.loaded_model = YOLO(model_name)
1516
+ self.model_path = model_name
1517
+ self.imgsz = getattr(self.loaded_model.model.args, 'imgsz', 128)
1518
+ dummy_image = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8)
1519
+ self.loaded_model.predict(dummy_image, imgsz=self.imgsz, half=True, device=self.device, verbose=False)
1520
+
1521
+ except Exception as e:
1522
+ print(f"ERROR: Could not load YOLO model '{model_name}': {e}")
1523
+ return np.array([]), []
1524
+
1525
+ if progress_bar:
1526
+ progress_bar.set_title("Preparing images...")
1527
+ progress_bar.start_progress(len(data_items))
1528
+
1529
+ image_list, valid_data_items = [], []
1530
+ for item in data_items:
1531
+ # Get the cropped image from the annotation
1532
+ pixmap = item.annotation.get_cropped_image()
1533
+
1534
+ if pixmap and not pixmap.isNull():
1535
+ image_list.append(pixmap_to_numpy(pixmap))
1536
+ valid_data_items.append(item)
1537
+
1538
+ if progress_bar:
1539
+ progress_bar.update_progress()
1540
+
1541
+ if not valid_data_items:
1542
+ return np.array([]), []
1543
+
1544
+ # Specify the kwargs for YOLO model prediction
1545
+ kwargs = {'stream': True,
1546
+ 'imgsz': self.imgsz,
1547
+ 'half': True,
1548
+ 'device': self.device,
1549
+ 'verbose': False}
1550
+
1551
+ # Run the model to extract features
1552
+ if feature_mode == "Embed Features":
1553
+ results_generator = self.loaded_model.embed(image_list, **kwargs)
1554
+ else:
1555
+ results_generator = self.loaded_model.predict(image_list, **kwargs)
1556
+
1557
+ if progress_bar:
1558
+ progress_bar.set_title("Extracting features...")
1559
+ progress_bar.start_progress(len(valid_data_items))
1560
+
1561
+ # Prepare a list to hold the extracted features
1562
+ embeddings_list = []
1563
+
1564
+ try:
1565
+ # Iterate through the results and extract features based on the mode
1566
+ for i, result in enumerate(results_generator):
1567
+ if feature_mode == "Embed Features":
1568
+ embeddings_list.append(result.cpu().numpy().flatten())
1569
+
1570
+ elif hasattr(result, 'probs') and result.probs is not None:
1571
+ embeddings_list.append(result.probs.data.cpu().numpy().squeeze())
1572
+
1573
+ else:
1574
+ raise TypeError("Model did not return expected output")
1575
+
1576
+ if progress_bar:
1577
+ progress_bar.update_progress()
1578
+ finally:
1579
+ if torch.cuda.is_available():
1580
+ torch.cuda.empty_cache()
1581
+
1582
+ return np.array(embeddings_list), valid_data_items
1583
+
1584
+ def _extract_features(self, data_items, progress_bar=None):
1585
+ """Dispatcher to call the appropriate feature extraction function."""
1586
+ # Get the selected model and feature mode from the model settings widget
1587
+ model_name, feature_mode = self.model_settings_widget.get_selected_model()
1588
+
1589
+ if isinstance(model_name, tuple):
1590
+ model_name = model_name[0]
1591
+
1592
+ if not model_name:
1593
+ return np.array([]), []
1594
+
1595
+ if model_name == "Color Features":
1596
+ return self._extract_color_features(data_items, progress_bar=progress_bar)
1597
+
1598
+ elif ".pt" in model_name:
1599
+ return self._extract_yolo_features(data_items, (model_name, feature_mode), progress_bar=progress_bar)
1600
+
1601
+ return np.array([]), []
1602
+
1603
+ def _run_dimensionality_reduction(self, features, params):
1604
+ """
1605
+ Runs dimensionality reduction (PCA, UMAP, or t-SNE) on the feature matrix.
1606
+
1607
+ Args:
1608
+ features (np.ndarray): Feature matrix of shape (N, D).
1609
+ params (dict): Embedding parameters, including technique and its hyperparameters.
1610
+
1611
+ Returns:
1612
+ np.ndarray or None: 2D embedded features of shape (N, 2), or None on failure.
1613
+ """
1614
+ technique = params.get('technique', 'UMAP')
1615
+
1616
+ if len(features) <= 2:
1617
+ # Not enough samples for dimensionality reduction
1618
+ return None
1619
+
1620
+ try:
1621
+ # Standardize features before reduction
1622
+ features_scaled = StandardScaler().fit_transform(features)
1623
+
1624
+ if technique == "UMAP":
1625
+ # UMAP: n_neighbors must be < n_samples
1626
+ n_neighbors = min(params.get('n_neighbors', 15), len(features_scaled) - 1)
1627
+
1628
+ reducer = UMAP(
1629
+ n_components=2,
1630
+ random_state=42,
1631
+ n_neighbors=n_neighbors,
1632
+ min_dist=params.get('min_dist', 0.1),
1633
+ metric=params.get('metric', 'cosine')
1634
+ )
1635
+ elif technique == "TSNE":
1636
+ # t-SNE: perplexity must be < n_samples
1637
+ perplexity = min(params.get('perplexity', 30), len(features_scaled) - 1)
1638
+
1639
+ reducer = TSNE(
1640
+ n_components=2,
1641
+ random_state=42,
1642
+ perplexity=perplexity,
1643
+ early_exaggeration=params.get('early_exaggeration', 12.0),
1644
+ learning_rate=params.get('learning_rate', 'auto'),
1645
+ init='pca'
1646
+ )
1647
+ elif technique == "PCA":
1648
+ reducer = PCA(n_components=2, random_state=42)
1649
+ else:
1650
+ # Unknown technique
1651
+ return None
1652
+
1653
+ # Fit and transform the features
1654
+ return reducer.fit_transform(features_scaled)
1655
+
1656
+ except Exception as e:
1657
+ print(f"Error during {technique} dimensionality reduction: {e}")
1658
+ return None
1659
+
1660
+ def _update_data_items_with_embedding(self, data_items, embedded_features):
1661
+ """Updates AnnotationDataItem objects with embedding results."""
1662
+ scale_factor = 4000
1663
+ min_vals, max_vals = np.min(embedded_features, axis=0), np.max(embedded_features, axis=0)
1664
+ range_vals = max_vals - min_vals
1665
+ for i, item in enumerate(data_items):
1666
+ norm_x = (embedded_features[i, 0] - min_vals[0]) / range_vals[0] if range_vals[0] > 0 else 0.5
1667
+ norm_y = (embedded_features[i, 1] - min_vals[1]) / range_vals[1] if range_vals[1] > 0 else 0.5
1668
+ item.embedding_x = (norm_x * scale_factor) - (scale_factor / 2)
1669
+ item.embedding_y = (norm_y * scale_factor) - (scale_factor / 2)
1670
+
1671
def run_embedding_pipeline(self):
    """Orchestrate feature extraction, dimensionality reduction and display.

    Features are re-extracted only when none are cached or the selected
    model / feature mode differs from the one that produced the cache. When
    extraction yields nothing (for example the model failed to load), the
    current data items and viewers are left untouched and nothing is cached,
    so a later run retries the extraction.
    """
    if not self.current_data_items:
        return

    self.annotation_viewer.clear_selection()
    if self.annotation_viewer.isolated_mode:
        self.annotation_viewer.show_all_annotations()

    self.embedding_viewer.render_selection_from_ids(set())
    self.update_button_states()

    # Get embedding parameters and selected model; create a cache key to avoid re-computing features
    embedding_params = self.embedding_settings_widget.get_embedding_parameters()
    model_info = self.model_settings_widget.get_selected_model()
    selected_model, selected_feature_mode = model_info if isinstance(model_info, tuple) else (model_info, "default")
    cache_key = f"{selected_model}_{selected_feature_mode}"

    QApplication.setOverrideCursor(Qt.WaitCursor)
    progress_bar = None
    try:
        # Created inside the try so the wait cursor is restored even if the
        # progress bar itself cannot be constructed or shown.
        progress_bar = ProgressBar(self, "Generating Embedding Visualization")
        progress_bar.show()

        if self.current_features is None or cache_key != self.current_feature_generating_model:
            features, valid_data_items = self._extract_features(self.current_data_items, progress_bar=progress_bar)

            if features is None or len(features) == 0:
                # Extraction produced nothing: do not replace the item list
                # with an empty one and do not cache the empty result.
                self.current_features = None
                return

            self.current_features = features
            self.current_feature_generating_model = cache_key
            self.current_data_items = valid_data_items
            self.annotation_viewer.update_annotations(self.current_data_items)
        else:
            features = self.current_features

        if features is None or len(features) == 0:
            return

        progress_bar.set_busy_mode("Running dimensionality reduction...")
        embedded_features = self._run_dimensionality_reduction(features, embedding_params)

        if embedded_features is None:
            return

        progress_bar.set_busy_mode("Updating visualization...")
        self._update_data_items_with_embedding(self.current_data_items, embedded_features)
        self.embedding_viewer.update_embeddings(self.current_data_items)
        self.embedding_viewer.show_embedding()
        self.embedding_viewer.fit_view_to_points()
    finally:
        QApplication.restoreOverrideCursor()
        if progress_bar is not None:
            progress_bar.finish_progress()
            progress_bar.stop_progress()
            progress_bar.close()
1721
+
1722
def refresh_filters(self):
    """Re-apply the filters and reset both viewers to the filtered items.

    Cached features are discarded, the annotation viewer is given the
    filtered items, and the embedding viewer is cleared and switched to its
    placeholder. The wait cursor is shown for the duration.
    """
    QApplication.setOverrideCursor(Qt.WaitCursor)
    try:
        filtered_items = self.get_filtered_data_items()
        self.current_data_items = filtered_items

        # Features belonged to the previous item set
        self.current_features = None

        self.annotation_viewer.update_annotations(self.current_data_items)

        # No embedding exists for the new item set yet
        self.embedding_viewer.clear_points()
        self.embedding_viewer.show_placeholder()
    finally:
        QApplication.restoreOverrideCursor()
1733
+
1734
def on_label_selected_for_preview(self, label):
    """Preview ``label`` on the selected annotations, if there are any.

    Does nothing when the annotation viewer does not exist yet or nothing is
    selected in it; otherwise applies the preview label and refreshes the
    button states.
    """
    if not hasattr(self, 'annotation_viewer'):
        return

    if not self.annotation_viewer.selected_widgets:
        return

    self.annotation_viewer.apply_preview_label_to_selected(label)
    self.update_button_states()
1739
+
1740
def clear_preview_changes(self):
    """
    Discard every pending preview change in both viewers.

    The annotation viewer's preview states are cleared and the embedding
    viewer's points are refreshed (each only if that viewer exists), then
    the button states are updated.
    """
    # (viewer attribute, method to call on it), in the order they must run
    viewer_resets = (
        ('annotation_viewer', 'clear_preview_states'),
        ('embedding_viewer', 'refresh_points'),
    )
    for viewer_attr, reset_method in viewer_resets:
        if hasattr(self, viewer_attr):
            getattr(getattr(self, viewer_attr), reset_method)()

    # After reverting all changes, update the button states
    self.update_button_states()
    print("Cleared all pending changes.")
1754
+
1755
def update_button_states(self):
    """Enable or disable the Clear Preview / Apply buttons and set their tooltips.

    Both buttons are enabled exactly when the annotation viewer reports
    pending preview changes; each tooltip ends with the viewer's summary of
    those changes.
    """
    viewer = self.annotation_viewer
    pending = viewer.has_preview_changes()

    clear_button = self.clear_preview_button
    clear_button.setEnabled(pending)
    apply_button = self.apply_button
    apply_button.setEnabled(pending)

    change_summary = viewer.get_preview_changes_summary()
    clear_button.setToolTip(f"Clear all preview changes - {change_summary}")
    apply_button.setToolTip(f"Apply changes - {change_summary}")
1763
+
1764
def apply(self):
    """
    Apply all pending changes, including label modifications and deletions,
    to the main application's data.

    Items marked for deletion are removed through the annotation window;
    remaining items have their preview label applied permanently. Cached
    features are kept row-aligned with the surviving items (or invalidated
    when alignment cannot be established) so a later embedding run does not
    pair items with another item's features. Any exception is reported on
    stdout rather than propagated.
    """
    QApplication.setOverrideCursor(Qt.WaitCursor)
    try:
        # Separate items into those to be deleted and those to be kept.
        # keep_mask records, row by row, which items survive.
        items_to_delete, items_to_keep, keep_mask = [], [], []
        for item in self.current_data_items:
            marked = item.is_marked_for_deletion()
            keep_mask.append(not marked)
            (items_to_delete if marked else items_to_keep).append(item)

        # --- 1. Process Deletions ---
        deleted_annotations = [item.annotation for item in items_to_delete]
        if deleted_annotations:
            print(f"Permanently deleting {len(deleted_annotations)} annotation(s).")
            self.annotation_window.delete_annotations(deleted_annotations)

        # --- 2. Process Label Changes on remaining items ---
        applied_label_changes = []
        for item in items_to_keep:
            if item.apply_preview_permanently():
                applied_label_changes.append(item.annotation)

        # --- 3. Update UI if any changes were made ---
        if not deleted_annotations and not applied_label_changes:
            print("No pending changes to apply.")
            return

        # Update the Explorer's internal list of data items
        self.current_data_items = items_to_keep

        # Keep cached features aligned with the surviving items
        cached_features = getattr(self, 'current_features', None)
        if deleted_annotations and cached_features is not None:
            if len(cached_features) == len(keep_mask):
                self.current_features = cached_features[np.asarray(keep_mask, dtype=bool)]
            else:
                # Alignment unknown; force re-extraction on the next run
                self.current_features = None

        # Update the main application's data and UI
        all_affected_annotations = deleted_annotations + applied_label_changes
        affected_images = {ann.image_path for ann in all_affected_annotations}
        for image_path in affected_images:
            self.image_window.update_image_annotations(image_path)
        self.annotation_window.load_annotations()

        # Refresh the annotation viewer since its underlying data has changed
        self.annotation_viewer.update_annotations(self.current_data_items)

        # Reset selections and button states
        self.embedding_viewer.render_selection_from_ids(set())
        self.update_label_window_selection()
        self.update_button_states()

        print("Applied changes successfully.")

    except Exception as e:
        print(f"Error applying modifications: {e}")
    finally:
        QApplication.restoreOverrideCursor()
1817
+
1818
+ def _cleanup_resources(self):
1819
+ """Clean up resources."""
1820
+ self.loaded_model = None
1821
+ self.model_path = ""
1822
+ self.current_features = None
1823
+ self.current_feature_generating_model = ""
1824
+ if torch.cuda.is_available():
1825
+ torch.cuda.empty_cache()