coralnet-toolbox 0.0.66__py2.py3-none-any.whl → 0.0.67__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +1 -1
- coralnet_toolbox/Annotations/QtPatchAnnotation.py +1 -1
- coralnet_toolbox/Annotations/QtPolygonAnnotation.py +1 -1
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +1 -1
- coralnet_toolbox/AutoDistill/QtDeployModel.py +4 -0
- coralnet_toolbox/Explorer/QtAnnotationDataItem.py +97 -0
- coralnet_toolbox/Explorer/QtAnnotationImageWidget.py +183 -0
- coralnet_toolbox/Explorer/QtEmbeddingPointItem.py +30 -0
- coralnet_toolbox/Explorer/QtExplorer.py +2067 -0
- coralnet_toolbox/Explorer/QtSettingsWidgets.py +490 -0
- coralnet_toolbox/Explorer/__init__.py +7 -0
- coralnet_toolbox/IO/QtImportViscoreAnnotations.py +2 -4
- coralnet_toolbox/IO/QtOpenProject.py +2 -1
- coralnet_toolbox/Icons/magic.png +0 -0
- coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +4 -0
- coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +4 -0
- coralnet_toolbox/MachineLearning/TrainModel/QtClassify.py +1 -1
- coralnet_toolbox/QtConfidenceWindow.py +2 -23
- coralnet_toolbox/QtEventFilter.py +2 -2
- coralnet_toolbox/QtLabelWindow.py +23 -8
- coralnet_toolbox/QtMainWindow.py +81 -2
- coralnet_toolbox/QtProgressBar.py +12 -0
- coralnet_toolbox/SAM/QtDeployGenerator.py +4 -0
- coralnet_toolbox/__init__.py +1 -1
- coralnet_toolbox/utilities.py +24 -0
- {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/METADATA +2 -1
- {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/RECORD +31 -24
- {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.66.dist-info → coralnet_toolbox-0.0.67.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2067 @@
|
|
1
|
+
import os
|
2
|
+
import numpy as np
|
3
|
+
import torch
|
4
|
+
import warnings
|
5
|
+
|
6
|
+
from ultralytics import YOLO
|
7
|
+
|
8
|
+
from coralnet_toolbox.MachineLearning.Community.cfg import get_available_configs
|
9
|
+
|
10
|
+
from coralnet_toolbox.Icons import get_icon
|
11
|
+
from coralnet_toolbox.utilities import pixmap_to_numpy
|
12
|
+
|
13
|
+
from PyQt5.QtGui import QIcon, QPen, QColor, QPainter, QImage, QBrush, QPainterPath, QPolygonF, QMouseEvent
|
14
|
+
from PyQt5.QtCore import Qt, QTimer, QSize, QRect, QRectF, QPointF, pyqtSignal, QSignalBlocker, pyqtSlot
|
15
|
+
|
16
|
+
from PyQt5.QtWidgets import (QVBoxLayout, QHBoxLayout, QGraphicsView, QScrollArea,
|
17
|
+
QGraphicsScene, QPushButton, QComboBox, QLabel, QWidget, QGridLayout,
|
18
|
+
QMainWindow, QSplitter, QGroupBox, QFormLayout,
|
19
|
+
QSpinBox, QGraphicsEllipseItem, QGraphicsItem, QSlider,
|
20
|
+
QListWidget, QDoubleSpinBox, QApplication, QStyle,
|
21
|
+
QGraphicsRectItem, QRubberBand, QStyleOptionGraphicsItem,
|
22
|
+
QTabWidget, QLineEdit, QFileDialog)
|
23
|
+
|
24
|
+
from .QtAnnotationDataItem import AnnotationDataItem
|
25
|
+
from .QtEmbeddingPointItem import EmbeddingPointItem
|
26
|
+
from .QtAnnotationImageWidget import AnnotationImageWidget
|
27
|
+
from .QtSettingsWidgets import ModelSettingsWidget
|
28
|
+
from .QtSettingsWidgets import EmbeddingSettingsWidget
|
29
|
+
from .QtSettingsWidgets import AnnotationSettingsWidget
|
30
|
+
|
31
|
+
from coralnet_toolbox.QtProgressBar import ProgressBar
|
32
|
+
|
33
|
+
try:
|
34
|
+
from sklearn.preprocessing import StandardScaler
|
35
|
+
from sklearn.decomposition import PCA
|
36
|
+
from sklearn.manifold import TSNE
|
37
|
+
from umap import UMAP
|
38
|
+
except ImportError:
|
39
|
+
print("Warning: sklearn or umap not installed. Some features may be unavailable.")
|
40
|
+
StandardScaler = None
|
41
|
+
PCA = None
|
42
|
+
TSNE = None
|
43
|
+
UMAP = None
|
44
|
+
|
45
|
+
|
46
|
+
# Silence DeprecationWarning once this module is imported.
# NOTE(review): this changes the process-wide warnings filter, not just this module's.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)


# ----------------------------------------------------------------------------------------------------------------------
# Constants
# ----------------------------------------------------------------------------------------------------------------------


# Width/height of an embedding point marker. Points are created with
# ItemIgnoresTransformations, so this size does not change with zoom.
POINT_SIZE = 15
# Outline pen width used for embedding point markers (also for the animated selection pen).
POINT_WIDTH = 3
|
56
|
+
|
57
|
+
|
58
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
59
|
+
# Viewers
|
60
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
61
|
+
|
62
|
+
|
63
|
+
class EmbeddingViewer(QWidget):
    """Interactive viewer for 2-D annotation embeddings.

    A QWidget wrapping a QGraphicsView/QGraphicsScene pair, with a header row
    holding a "Home" button and a placeholder label shown while no embedding is
    available. Supports wheel zooming anchored under the cursor, right-button
    panning, plain-click selection, Ctrl+click toggling and Ctrl+drag
    rubber-band selection that adds to the existing selection.

    NOTE(review): the mouse/wheel handlers below override this QWidget's own
    handlers AND are assigned onto ``graphics_view`` in ``__init__``. They
    therefore also run for events delivered to (or propagated up to) this
    wrapper widget, whose coordinates are not view coordinates.

    Signals:
        selection_changed (list): IDs of all currently selected annotations.
        reset_view_requested (): emitted on a left double-click.
    """

    # Reports the complete list of currently selected annotation IDs
    selection_changed = pyqtSignal(list)
    # Asks the owner to reset its view (emitted on a left double-click)
    reset_view_requested = pyqtSignal()

    def __init__(self, parent=None):
        """Build the scene, view, animation timer and header UI.

        Args:
            parent: Optional parent widget; also stored as ``explorer_window``.
        """
        # Construct the QWidget base before attaching any attributes to the wrapper.
        super(EmbeddingViewer, self).__init__(parent)
        self.explorer_window = parent

        # Large fixed scene rect so panning is not clamped to the items' extent
        self.graphics_scene = QGraphicsScene()
        self.graphics_scene.setSceneRect(-5000, -5000, 10000, 10000)

        self.graphics_view = QGraphicsView(self.graphics_scene)
        self.graphics_view.setRenderHint(QPainter.Antialiasing)
        self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
        self.graphics_view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.graphics_view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.graphics_view.setMinimumHeight(200)

        # Rubber-band (Ctrl+drag) state
        self.rubber_band = None
        self.rubber_band_origin = QPointF()
        self.selection_at_press = None  # items already selected when the drag started

        self.points_by_id = {}  # annotation ID -> EmbeddingPointItem
        self.previous_selection_ids = set()  # last selection seen, used to detect changes

        # "Marching ants" animation for selected points
        self.animation_offset = 0
        self.animation_timer = QTimer()
        self.animation_timer.timeout.connect(self.animate_selection)
        self.animation_timer.setInterval(100)

        self.graphics_scene.selectionChanged.connect(self.on_selection_changed)

        self.setup_ui()

        # Route the view's input events through this class's handlers
        self.graphics_view.mousePressEvent = self.mousePressEvent
        self.graphics_view.mouseDoubleClickEvent = self.mouseDoubleClickEvent
        self.graphics_view.mouseReleaseEvent = self.mouseReleaseEvent
        self.graphics_view.mouseMoveEvent = self.mouseMoveEvent
        self.graphics_view.wheelEvent = self.wheelEvent

    def setup_ui(self):
        """Lay out the header row (Home button), the graphics view and the placeholder label."""
        layout = QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)

        header_layout = QHBoxLayout()

        self.home_button = QPushButton("Home")
        self.home_button.setToolTip("Reset view to fit all points")
        self.home_button.clicked.connect(self.reset_view)
        header_layout.addWidget(self.home_button)

        # Stretch keeps the button left-aligned and leaves room for future controls
        header_layout.addStretch()

        layout.addLayout(header_layout)

        layout.addWidget(self.graphics_view)

        # Shown instead of the view until an embedding has been computed
        self.placeholder_label = QLabel(
            "No embedding data available.\nPress 'Apply Embedding' to generate visualization."
        )
        self.placeholder_label.setAlignment(Qt.AlignCenter)
        self.placeholder_label.setStyleSheet("color: gray; font-size: 14px;")
        layout.addWidget(self.placeholder_label)

        self.show_placeholder()

    def reset_view(self):
        """Reset the view to fit all embedding points."""
        self.fit_view_to_points()

    def show_placeholder(self):
        """Show the placeholder message, hide the graphics view and disable Home."""
        self.graphics_view.setVisible(False)
        self.placeholder_label.setVisible(True)
        self.home_button.setEnabled(False)

    def show_embedding(self):
        """Show the graphics view, hide the placeholder message and enable Home."""
        self.graphics_view.setVisible(True)
        self.placeholder_label.setVisible(False)
        self.home_button.setEnabled(True)

    # Thin delegates to the wrapped QGraphicsView
    def setRenderHint(self, hint):
        self.graphics_view.setRenderHint(hint)

    def setDragMode(self, mode):
        self.graphics_view.setDragMode(mode)

    def setTransformationAnchor(self, anchor):
        self.graphics_view.setTransformationAnchor(anchor)

    def setResizeAnchor(self, anchor):
        self.graphics_view.setResizeAnchor(anchor)

    def mapToScene(self, point):
        return self.graphics_view.mapToScene(point)

    def scale(self, sx, sy):
        self.graphics_view.scale(sx, sy)

    def translate(self, dx, dy):
        self.graphics_view.translate(dx, dy)

    def fitInView(self, rect, aspect_ratio):
        self.graphics_view.fitInView(rect, aspect_ratio)

    def mousePressEvent(self, event):
        """Start a Ctrl+click toggle, a Ctrl+drag rubber band, a right-button pan, or a plain click."""
        if event.button() == Qt.LeftButton and event.modifiers() == Qt.ControlModifier:
            item_at_pos = self.graphics_view.itemAt(event.pos())
            if isinstance(item_at_pos, EmbeddingPointItem):
                # Ctrl+click on a point toggles just that point
                self.graphics_view.setDragMode(QGraphicsView.NoDrag)
                item_at_pos.setSelected(not item_at_pos.isSelected())
                return

            # Ctrl+press on the background starts a rubber band that adds to the current selection
            self.selection_at_press = set(self.graphics_scene.selectedItems())
            self.graphics_view.setDragMode(QGraphicsView.NoDrag)
            self.rubber_band_origin = self.graphics_view.mapToScene(event.pos())
            self.rubber_band = QGraphicsRectItem(QRectF(self.rubber_band_origin, self.rubber_band_origin))
            self.rubber_band.setPen(QPen(QColor(0, 100, 255), 1, Qt.DotLine))
            self.rubber_band.setBrush(QBrush(QColor(0, 100, 255, 50)))
            self.graphics_scene.addItem(self.rubber_band)

        elif event.button() == Qt.RightButton:
            # Pan: enable hand-drag and replay the right-button press as a
            # left-button press so ScrollHandDrag picks it up.
            self.graphics_view.setDragMode(QGraphicsView.ScrollHandDrag)
            left_event = QMouseEvent(event.type(), event.localPos(), Qt.LeftButton, Qt.LeftButton, event.modifiers())
            QGraphicsView.mousePressEvent(self.graphics_view, left_event)
        else:
            # Plain click: standard single-item selection via the view/scene
            self.graphics_view.setDragMode(QGraphicsView.NoDrag)
            QGraphicsView.mousePressEvent(self.graphics_view, event)

    def mouseDoubleClickEvent(self, event):
        """Left double-click clears the selection and requests a view reset.

        Double-clicks with other buttons are treated as a press.
        """
        if event.button() == Qt.LeftButton:
            if self.graphics_scene.selectedItems():
                self.graphics_scene.clearSelection()  # triggers on_selection_changed

            # Ask the owner to revert from isolation mode
            self.reset_view_requested.emit()
            event.accept()
        else:
            # QWidget's default double-click handler just calls mousePressEvent().
            # Call ours explicitly: this handler is also installed on graphics_view,
            # where super() would address this wrapper rather than the view.
            self.mousePressEvent(event)

    def mouseMoveEvent(self, event):
        """Update an active rubber-band selection, continue a right-button pan, or forward the move."""
        if self.rubber_band:
            current_pos = self.graphics_view.mapToScene(event.pos())
            self.rubber_band.setRect(QRectF(self.rubber_band_origin, current_pos).normalized())

            path = QPainterPath()
            path.addRect(self.rubber_band.rect())

            # Perform the compound selection with the scene's signals blocked so
            # selectionChanged does not fire for each intermediate step. The
            # try/finally guarantees the scene is never left permanently muted.
            self.graphics_scene.blockSignals(True)
            try:
                # 1. Select only what lies inside the rubber band. Points use
                #    ItemIgnoresTransformations, so the view's device transform must
                #    be supplied for hit-testing to match what is drawn on screen.
                self.graphics_scene.setSelectionArea(path, self.graphics_view.viewportTransform())

                # 2. Add back the items that were selected when the drag started.
                if self.selection_at_press:
                    for item in self.selection_at_press:
                        item.setSelected(True)
            finally:
                self.graphics_scene.blockSignals(False)

            # Process the combined result once
            self.on_selection_changed()

        elif event.buttons() == Qt.RightButton:
            # Continue the right-button pan as a left-button hand drag
            left_event = QMouseEvent(event.type(),
                                     event.localPos(),
                                     Qt.LeftButton,
                                     Qt.LeftButton,
                                     event.modifiers())
            QGraphicsView.mouseMoveEvent(self.graphics_view, left_event)
        else:
            QGraphicsView.mouseMoveEvent(self.graphics_view, event)

    def mouseReleaseEvent(self, event):
        """Finish a rubber band, pan or click and restore NoDrag where applicable."""
        if self.rubber_band:
            # Remove the visual rectangle and forget the press-time selection
            self.graphics_scene.removeItem(self.rubber_band)
            self.rubber_band = None
            self.selection_at_press = None

        elif event.button() == Qt.RightButton:
            # Finish the pan (replayed as a left-button release)
            left_event = QMouseEvent(event.type(),
                                     event.localPos(),
                                     Qt.LeftButton,
                                     Qt.LeftButton,
                                     event.modifiers())
            QGraphicsView.mouseReleaseEvent(self.graphics_view, left_event)
            self.graphics_view.setDragMode(QGraphicsView.NoDrag)
        else:
            # Finish a plain click
            QGraphicsView.mouseReleaseEvent(self.graphics_view, event)
            self.graphics_view.setDragMode(QGraphicsView.NoDrag)

    def wheelEvent(self, event):
        """Zoom by 1.25x per wheel event, keeping the scene point under the cursor fixed.

        Events without a vertical component (e.g. horizontal trackpad scrolling)
        are ignored rather than treated as zoom-out.
        """
        delta_y = event.angleDelta().y()
        if delta_y == 0:
            return

        zoom_in_factor = 1.25
        zoom_out_factor = 1 / zoom_in_factor

        # Anchoring is done manually below, so disable Qt's own anchors
        self.graphics_view.setTransformationAnchor(QGraphicsView.NoAnchor)
        self.graphics_view.setResizeAnchor(QGraphicsView.NoAnchor)

        old_pos = self.graphics_view.mapToScene(event.pos())
        zoom_factor = zoom_in_factor if delta_y > 0 else zoom_out_factor
        self.graphics_view.scale(zoom_factor, zoom_factor)
        new_pos = self.graphics_view.mapToScene(event.pos())

        # Shift back so the point under the cursor stays put
        delta = new_pos - old_pos
        self.graphics_view.translate(delta.x(), delta.y())

    def update_embeddings(self, data_items):
        """Replace all points with one selectable marker per data item.

        Args:
            data_items: List of AnnotationDataItem objects. Each must provide
                ``embedding_x``, ``embedding_y`` and ``annotation.id``.
        """
        self.clear_points()

        for item in data_items:
            point = EmbeddingPointItem(0, 0, POINT_SIZE, POINT_SIZE)
            point.setPos(item.embedding_x, item.embedding_y)

            # No initial brush is set here; the item's paint() is expected to handle fill colour.
            point.setPen(QPen(QColor("black"), POINT_WIDTH))

            # Constant on-screen size regardless of zoom; selectable by click/rubber band
            point.setFlag(QGraphicsItem.ItemIgnoresTransformations)
            point.setFlag(QGraphicsItem.ItemIsSelectable)

            # Link the point to its shared AnnotationDataItem model
            point.setData(0, item)

            self.graphics_scene.addItem(point)
            self.points_by_id[item.annotation.id] = point

    def clear_points(self):
        """Remove all embedding points from the scene and forget them."""
        for point in self.points_by_id.values():
            self.graphics_scene.removeItem(point)
        self.points_by_id.clear()

    def on_selection_changed(self):
        """Sync the data models with the scene selection, emit changes and restart the animation."""
        # The scene may already be gone during teardown
        if not self.graphics_scene or not hasattr(self.graphics_scene, 'selectedItems'):
            return

        try:
            selected_items = self.graphics_scene.selectedItems()
        except RuntimeError:
            # Underlying C++ scene has been deleted
            return

        current_selection_ids = {item.data(0).annotation.id for item in selected_items}

        if current_selection_ids != self.previous_selection_ids:
            # Push the selection state into every point's AnnotationDataItem
            for point_id, point in self.points_by_id.items():
                point.data(0).set_selected(point_id in current_selection_ids)

            self.selection_changed.emit(list(current_selection_ids))
            self.previous_selection_ids = current_selection_ids

        # Restart the local animation (the timer may be gone during teardown)
        if hasattr(self, 'animation_timer') and self.animation_timer:
            self.animation_timer.stop()

        # Unselected points get the plain black outline back
        for point in self.points_by_id.values():
            if not point.isSelected():
                point.setPen(QPen(QColor("black"), POINT_WIDTH))

        if selected_items and hasattr(self, 'animation_timer') and self.animation_timer:
            self.animation_timer.start()

    def animate_selection(self):
        """Advance the dashed "marching ants" outline drawn on selected points.

        Each selected point gets a dashed pen in a darker shade of its brush colour.
        """
        if not self.graphics_scene or not hasattr(self.graphics_scene, 'selectedItems'):
            return

        try:
            selected_items = self.graphics_scene.selectedItems()
        except RuntimeError:
            # Underlying C++ scene has been deleted
            return

        self.animation_offset = (self.animation_offset + 1) % 20

        # The selection is indicated only by this custom pen. Presumably
        # EmbeddingPointItem suppresses Qt's default selection box — confirm in that class.
        for item in selected_items:
            original_color = item.brush().color()
            darker_color = original_color.darker(150)

            animated_pen = QPen(darker_color, POINT_WIDTH)
            animated_pen.setStyle(Qt.CustomDashLine)
            animated_pen.setDashPattern([1, 2])
            animated_pen.setDashOffset(self.animation_offset)

            item.setPen(animated_pen)

    def render_selection_from_ids(self, selected_ids):
        """Select exactly the points whose annotation IDs are in ``selected_ids``.

        Called by the controller. Does not echo ``selection_changed`` back for the
        selection it was given.

        Args:
            selected_ids: Iterable of annotation IDs (set, list, ...). IDs without
                a point in this view are ignored.
        """
        # A set gives O(1) membership tests and decouples us from the caller's collection
        selected_ids = set(selected_ids)

        # Block the scene's own selectionChanged to avoid a feedback loop. The
        # blocker stays active until this method returns.
        blocker = QSignalBlocker(self.graphics_scene)  # noqa: F841 (kept alive on purpose)

        for ann_id, point in self.points_by_id.items():
            point.setSelected(ann_id in selected_ids)

        # Record what was actually rendered, so the refresh below sees "no change"
        # and does not re-emit a narrowed selection to the controller.
        self.previous_selection_ids = {ann_id for ann_id in self.points_by_id if ann_id in selected_ids}

        # Refresh pens/animation for the new selection
        self.on_selection_changed()

    def fit_view_to_points(self):
        """Fit the view to all embedding points, or to a default area when there are none."""
        if self.points_by_id:
            self.graphics_view.fitInView(self.graphics_scene.itemsBoundingRect(), Qt.KeepAspectRatio)
        else:
            self.graphics_view.fitInView(-2500, -2500, 5000, 5000, Qt.KeepAspectRatio)
|
420
|
+
|
421
|
+
|
422
|
+
class AnnotationViewer(QScrollArea):
|
423
|
+
"""Scrollable grid widget for displaying annotation image crops with selection,
|
424
|
+
filtering, and isolation support.
|
425
|
+
"""
|
426
|
+
|
427
|
+
# Define signals to report changes to the ExplorerWindow
|
428
|
+
selection_changed = pyqtSignal(list) # list of changed annotation IDs
|
429
|
+
preview_changed = pyqtSignal(list) # list of annotation IDs with new previews
|
430
|
+
reset_view_requested = pyqtSignal() # Signal to reset the view to fit all points
|
431
|
+
|
432
|
+
def __init__(self, parent=None):
    """Initialise viewer state and build the toolbar/content UI.

    Args:
        parent: Optional parent widget passed to QScrollArea.
    """
    super(AnnotationViewer, self).__init__(parent)

    # Annotation widgets keyed by annotation ID, plus the current selection
    self.annotation_widgets_by_id = {}
    self.selected_widgets = []
    self.last_selected_index = -1
    self.current_widget_size = 96  # matches the size slider's initial value

    # Rubber-band / drag-selection state
    self.selection_at_press = set()
    self.rubber_band = None
    self.rubber_band_origin = None
    self.drag_threshold = 5
    self.mouse_pressed_on_widget = False

    # Label preview bookkeeping
    self.preview_label_assignments = {}
    self.original_label_assignments = {}

    # Isolate/Focus mode state
    self.isolated_mode = False
    self.isolated_widgets = set()

    # Group-header labels created by the sorted layout. Declared here so the
    # attribute always exists instead of only being created lazily on first use.
    self._group_headers = []

    self.setup_ui()
|
453
|
+
|
454
|
+
def setup_ui(self):
    """Build the viewer UI: a toolbar row stacked above a scrollable content area.

    The toolbar holds the Isolate Selection / Show All toggle pair, a sort
    selector and a thumbnail-size slider with its value label. Annotation
    widgets are later positioned by hand inside ``self.content_widget``.

    NOTE(review): the content area is a QScrollArea nested inside this
    QScrollArea; layout code that measures ``self.viewport()`` sees the outer
    viewport, not the inner one.
    """
    # Outer scroll-area behaviour
    self.setWidgetResizable(True)
    self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
    self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)

    container = QWidget()
    outer = QVBoxLayout(container)
    outer.setContentsMargins(0, 0, 0, 0)
    outer.setSpacing(4)  # small gap between the toolbar and the content

    toolbar = QWidget()
    bar = QHBoxLayout(toolbar)
    bar.setContentsMargins(4, 2, 4, 2)

    def add_button(text, icon_name, tooltip, slot):
        # The icon is optional: it is applied only if the resource could be loaded.
        button = QPushButton(text)
        icon = get_icon(icon_name)
        if not icon.isNull():
            button.setIcon(icon)
        button.setToolTip(tooltip)
        button.clicked.connect(slot)
        bar.addWidget(button)
        return button

    # Isolate / Show All toggle pair (only one is visible at a time)
    self.isolate_button = add_button("Isolate Selection", "focus.png",
                                     "Hide all non-selected annotations",
                                     self.isolate_selection)
    self.show_all_button = add_button("Show All", "show_all.png",
                                      "Show all filtered annotations",
                                      self.show_all_annotations)

    bar.addWidget(self._create_separator())

    # Sort selector (items are added before connecting, so no signal fires during setup)
    bar.addWidget(QLabel("Sort By:"))
    self.sort_combo = QComboBox()
    self.sort_combo.addItems(["None", "Label", "Image"])
    self.sort_combo.currentTextChanged.connect(self.on_sort_changed)
    bar.addWidget(self.sort_combo)

    # Everything after the stretch is pushed to the right edge
    bar.addStretch()

    # Size slider (value is set before connecting, so on_size_changed is not triggered here)
    bar.addWidget(QLabel("Size:"))
    self.size_slider = QSlider(Qt.Horizontal)
    self.size_slider.setRange(32, 256)
    self.size_slider.setValue(96)
    self.size_slider.setTickPosition(QSlider.TicksBelow)
    self.size_slider.setTickInterval(32)
    self.size_slider.valueChanged.connect(self.on_size_changed)
    bar.addWidget(self.size_slider)

    self.size_value_label = QLabel("96")
    self.size_value_label.setMinimumWidth(30)
    bar.addWidget(self.size_value_label)

    outer.addWidget(toolbar)

    # Content area that hosts the manually positioned annotation widgets
    self.content_widget = QWidget()
    inner_scroll = QScrollArea()
    inner_scroll.setWidgetResizable(True)
    inner_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
    inner_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
    inner_scroll.setWidget(self.content_widget)

    outer.addWidget(inner_scroll)
    self.setWidget(container)

    # Initial button visibility/enabled state
    self._update_toolbar_state()
|
533
|
+
|
534
|
+
@pyqtSlot()
def isolate_selection(self):
    """Enter isolation mode, hiding every annotation widget that is not selected.

    Does nothing when nothing is selected or isolation mode is already active.
    """
    if self.isolated_mode or not self.selected_widgets:
        return

    keep = set(self.selected_widgets)
    self.isolated_widgets = keep

    # Suspend repaints while many widgets are hidden at once
    self.content_widget.setUpdatesEnabled(False)
    try:
        others = (w for w in self.annotation_widgets_by_id.values() if w not in keep)
        for w in others:
            w.hide()
        self.isolated_mode = True
        self.recalculate_widget_positions()
    finally:
        self.content_widget.setUpdatesEnabled(True)

    self._update_toolbar_state()
|
552
|
+
|
553
|
+
@pyqtSlot()
def show_all_annotations(self):
    """Leave isolation mode and make every annotation widget visible again.

    Does nothing when isolation mode is not active.
    """
    if self.isolated_mode:
        self.isolated_mode = False
        self.isolated_widgets.clear()

        # Suspend repaints while all widgets are re-shown and re-laid out
        content = self.content_widget
        content.setUpdatesEnabled(False)
        try:
            for w in self.annotation_widgets_by_id.values():
                w.show()
            self.recalculate_widget_positions()
        finally:
            content.setUpdatesEnabled(True)

        self._update_toolbar_state()
|
571
|
+
|
572
|
+
def _update_toolbar_state(self):
    """Sync the Isolate Selection / Show All buttons with the current mode.

    In isolation mode only "Show All" is visible, and it is enabled. Otherwise
    only "Isolate Selection" is visible, enabled when at least one widget is selected.
    """
    isolated = self.isolated_mode

    # Exactly one of the two toggle buttons is visible at a time
    self.isolate_button.setVisible(not isolated)
    self.show_all_button.setVisible(isolated)

    if isolated:
        self.show_all_button.setEnabled(True)
    else:
        self.isolate_button.setEnabled(bool(self.selected_widgets))
|
586
|
+
|
587
|
+
def _create_separator(self):
    """Return a gray "|" label used as a vertical divider between toolbar groups."""
    divider = QLabel("|")
    divider.setStyleSheet("color: gray; margin: 0 5px;")
    return divider
|
592
|
+
|
593
|
+
def on_sort_changed(self, sort_type):
    """Slot for the sort combo box: lay the widgets out again in the new order.

    Args:
        sort_type: Text of the newly selected sort option. Unused, because the
            layout code reads the current option from ``self.sort_combo`` itself.
    """
    del sort_type  # the layout queries the combo box directly
    self.recalculate_widget_positions()
|
596
|
+
|
597
|
+
def _get_sorted_widgets(self):
    """Return all annotation widgets ordered by the sort combo's current option.

    "Label" orders by the effective label's short code and "Image" by the
    image file's base name. "None", or any unrecognised option, keeps the
    insertion order of ``annotation_widgets_by_id``.

    Returns:
        A new list of widgets.
    """
    ordered = list(self.annotation_widgets_by_id.values())
    mode = self.sort_combo.currentText()

    key_funcs = {
        "Label": lambda w: w.data_item.effective_label.short_label_code,
        "Image": lambda w: os.path.basename(w.data_item.annotation.image_path),
    }
    key_func = key_funcs.get(mode)
    if key_func is not None:
        ordered.sort(key=key_func)
    return ordered
|
612
|
+
|
613
|
+
def _group_widgets_by_sort_key(self, widgets):
    """Split already-sorted widgets into runs that share the current sort key.

    Args:
        widgets: Widgets, expected to be sorted by the current sort key.

    Returns:
        A list of ``(key, [widgets...])`` tuples, one per run of consecutive
        equal keys. With sort option "None" this is ``[("", widgets)]``;
        unrecognised options use ``""`` as the key for every widget.
    """
    mode = self.sort_combo.currentText()
    if mode == "None":
        return [("", widgets)]

    grouped = []
    for w in widgets:
        if mode == "Label":
            k = w.data_item.effective_label.short_label_code
        elif mode == "Image":
            k = os.path.basename(w.data_item.annotation.image_path)
        else:
            k = ""

        # Extend the current run, or start a new one when the key changes
        if grouped and grouped[-1][0] == k:
            grouped[-1][1].append(w)
        else:
            grouped.append((k, [w]))
    return grouped
|
644
|
+
|
645
|
+
def _clear_separator_labels(self):
    """Detach every tracked group-header label, schedule it for deletion, and reset the list.

    Safe to call before any header has been created: the list is then simply initialised.
    """
    for label in getattr(self, '_group_headers', ()):
        label.setParent(None)
        label.deleteLater()
    self._group_headers = []
|
652
|
+
|
653
|
+
def _create_group_header(self, text):
    """Create, show and track a bold header label for one sort group.

    The label is parented to ``self.content_widget`` and appended to
    ``self._group_headers`` (created on first use), so it can be removed later.

    Args:
        text: Group name displayed in the header.

    Returns:
        The new QLabel; the caller is responsible for positioning it.
    """
    if not hasattr(self, '_group_headers'):
        self._group_headers = []

    header = QLabel(text)
    header.setParent(self.content_widget)
    header.setStyleSheet("""
QLabel {
font-weight: bold;
font-size: 12px;
color: #555;
background-color: #f0f0f0;
border: 1px solid #ccc;
border-radius: 3px;
padding: 5px 8px;
margin: 2px 0px;
}
""")
    header.setFixedHeight(30)
    # Clamp at zero: a narrow or not-yet-laid-out viewport can be smaller than
    # the 20 px margin, and Qt warns about (and clamps) negative minimum sizes.
    header.setMinimumWidth(max(0, self.viewport().width() - 20))
    header.show()

    self._group_headers.append(header)
    return header
|
678
|
+
|
679
|
+
def on_size_changed(self, value):
    """Resize every annotation widget when the size slider moves, then reflow the layout.

    Odd values are rounded down to the nearest even number before use.

    Args:
        value: New slider value, passed to each widget's ``update_height``.
    """
    if value % 2 != 0:
        value -= 1
    self.current_widget_size = value
    self.size_value_label.setText(str(value))

    # Suspend repaints while many widgets are resized. The try/finally (matching
    # isolate_selection) guarantees repainting is re-enabled even if a widget raises.
    self.content_widget.setUpdatesEnabled(False)
    try:
        for widget in self.annotation_widgets_by_id.values():
            widget.update_height(value)
    finally:
        self.content_widget.setUpdatesEnabled(True)

    # Reflow now that the widget sizes changed
    self.recalculate_widget_positions()
|
694
|
+
|
695
|
+
def recalculate_grid_layout(self):
    """Place annotation widgets in ``self.grid_layout`` in row-major order.

    The column count is derived from the viewport width and the current widget
    size. When no ``grid_layout`` attribute exists, the call is delegated to
    ``recalculate_widget_positions``.
    """
    if not self.annotation_widgets_by_id:
        return

    grid_layout = getattr(self, 'grid_layout', None)
    if grid_layout is None:
        # NOTE(review): setup_ui positions widgets manually and, in the code
        # visible here, never creates a grid layout. Delegate to the manual
        # layout instead of raising AttributeError.
        self.recalculate_widget_positions()
        return

    available_width = self.viewport().width() - 20
    widget_width = self.current_widget_size + grid_layout.spacing()
    cols = max(1, available_width // widget_width)

    for i, widget in enumerate(self.annotation_widgets_by_id.values()):
        grid_layout.addWidget(widget, i // cols, i % cols)
|
706
|
+
|
707
|
+
def recalculate_widget_positions(self):
    """Lay out the visible annotation widgets in a left-to-right flow.

    Widgets are taken in the order given by ``_get_sorted_widgets``. Hidden
    widgets are skipped. A widget wraps to a new row when it would overflow
    the viewport width, unless it is the first widget on the row.

    When the sort combo is not "None", each named group starts on a fresh
    row beneath a header label. Finally, the content widget's minimum size
    is set to cover the whole layout.
    """
    # Drop the headers from the previous pass first, including when we bail
    # out early below. Otherwise stale headers would stay on screen after the
    # last annotation is removed or hidden.
    self._clear_separator_labels()

    if not self.annotation_widgets_by_id:
        self.content_widget.setMinimumSize(1, 1)
        return

    visible_widgets = [w for w in self._get_sorted_widgets() if not w.isHidden()]
    if not visible_widgets:
        self.content_widget.setMinimumSize(1, 1)
        return

    groups = self._group_widgets_by_sort_key(visible_widgets)

    # The gap between widgets scales with their size, with a 5 px minimum.
    spacing = max(5, int(self.current_widget_size * 0.08))
    available_width = self.viewport().width()
    show_headers = self.sort_combo.currentText() != "None"

    x, y = spacing, spacing
    max_height_in_row = 0

    for group_name, group_widgets in groups:
        if group_name and show_headers:
            # A header always starts at the beginning of a new row.
            if x > spacing:
                x = spacing
                y += max_height_in_row + spacing
                max_height_in_row = 0

            header_label = self._create_group_header(group_name)
            header_label.move(x, y)

            # Continue laying out widgets on the row below the header.
            y += header_label.height() + spacing
            x = spacing
            max_height_in_row = 0

        for widget in group_widgets:
            widget_size = widget.size()
            # Wrap to the next row unless this is the row's first widget.
            if x > spacing and x + widget_size.width() > available_width:
                x = spacing
                y += max_height_in_row + spacing
                max_height_in_row = 0

            widget.move(x, y)
            x += widget_size.width() + spacing
            max_height_in_row = max(max_height_in_row, widget_size.height())

    total_height = y + max_height_in_row + spacing
    self.content_widget.setMinimumSize(available_width, total_height)
def update_annotations(self, data_items):
    """Replace all annotation widgets with fresh ones built from ``data_items``.

    Any active isolation view is reset first. Every existing widget is
    detached and scheduled for deletion, and the selection is cleared. Then
    one ``AnnotationImageWidget`` is created per data item, keyed by
    ``data_item.annotation.id``. Finally the layout is reflowed and the
    toolbar state refreshed.

    Args:
        data_items: Iterable of data items, each exposing ``annotation.id``.
    """
    # Reset any active isolation view before the widgets are replaced.
    if self.isolated_mode:
        self.show_all_annotations()

    # Detach the old widgets now and let Qt delete them once it is safe.
    for old_widget in self.annotation_widgets_by_id.values():
        old_widget.setParent(None)
        old_widget.deleteLater()
    self.annotation_widgets_by_id.clear()

    self.selected_widgets.clear()
    self.last_selected_index = -1

    for item in data_items:
        new_widget = AnnotationImageWidget(
            item,
            self.current_widget_size,
            annotation_viewer=self,
            parent=self.content_widget,
        )
        new_widget.show()
        self.annotation_widgets_by_id[item.annotation.id] = new_widget

    self.recalculate_widget_positions()
    # Bring the toolbar in line with the rebuilt (empty-selection) view.
    self._update_toolbar_state()
def resizeEvent(self, event):
    """Reflow the annotation widgets after the viewer has been resized.

    The reflow is debounced. Each resize (re)starts a single-shot 100 ms
    timer, so dragging the window edge triggers only one layout pass once
    the resizing pauses.
    """
    super().resizeEvent(event)

    # Create the debounce timer lazily on the first resize.
    if not hasattr(self, '_resize_timer'):
        debounce = QTimer(self)
        debounce.setSingleShot(True)
        debounce.timeout.connect(self.recalculate_widget_positions)
        self._resize_timer = debounce

    # Restarting pushes the pending reflow back by another 100 ms.
    self._resize_timer.start(100)
def mousePressEvent(self, event):
    """Handle mouse presses that reach the viewer itself.

    * Plain left click on empty space while something is selected: clear the
      selection, emit ``selection_changed`` with the deselected ids, and
      consume the event.
    * Ctrl + left click: snapshot the current selection and the press
      position so ``mouseMoveEvent`` can run an additive rubber-band
      selection. Also record whether the press landed on an annotation
      widget.
    * Right click: ignored so it propagates to the parent.
    * Anything else: default handling.
    """

    def pressed_on_annotation_widget(pos):
        # True if ``pos`` hits one of this viewer's annotation widgets (or a
        # child of one). QAbstractScrollArea passes mouse events to these
        # handlers in viewport coordinates, so hit-test against the viewport
        # rather than ``self``, matching the viewport-parented rubber band.
        candidate = self.viewport().childAt(pos)
        while candidate and candidate != self:
            if getattr(candidate, 'annotation_viewer', None) == self:
                return True
            candidate = candidate.parent()
        return False

    if event.button() == Qt.LeftButton:
        if not event.modifiers():
            if self.selected_widgets and not pressed_on_annotation_widget(event.pos()):
                # Record what is being deselected so listeners can update.
                changed_ids = [w.data_item.annotation.id for w in self.selected_widgets]
                self.clear_selection()
                self.selection_changed.emit(changed_ids)
                # Consumed: don't let the scroll area start anything else
                # (e.g. a drag).
                return

        elif event.modifiers() == Qt.ControlModifier:
            # Snapshot the selection so the rubber band adds to it.
            self.selection_at_press = set(self.selected_widgets)
            self.rubber_band_origin = event.pos()
            # mouseMoveEvent defers to default handling if the press was
            # on a widget.
            self.mouse_pressed_on_widget = pressed_on_annotation_widget(event.pos())
            return

    elif event.button() == Qt.RightButton:
        event.ignore()
        return

    # Everything else (e.g. clicks the widgets handle themselves) goes to Qt.
    super().mousePressEvent(event)
def mouseDoubleClickEvent(self, event):
    """Reset the view on a left double-click; defer anything else to Qt.

    A reset does three things:
    * clears the selection, emitting ``selection_changed`` with the ids that
      were selected;
    * leaves isolation mode if it is active;
    * emits ``reset_view_requested`` so the owning window can reset as well.
    """
    if event.button() != Qt.LeftButton:
        super().mouseDoubleClickEvent(event)
        return

    if self.selected_widgets:
        deselected_ids = [w.data_item.annotation.id for w in self.selected_widgets]
        self.clear_selection()
        self.selection_changed.emit(deselected_ids)

    if self.isolated_mode:
        self.show_all_annotations()

    # Let the main window reset its own view as well (e.g. switch tabs).
    self.reset_view_requested.emit()
    event.accept()
def mouseMoveEvent(self, event):
    """Update an in-progress Ctrl+drag rubber-band selection.

    This is active only while exactly Ctrl is held with the left button down,
    after a Ctrl+left press that did not land on an annotation widget.

    Once the cursor has moved at least ``drag_threshold`` (Manhattan length),
    the rubber band is shown. Every visible widget that intersects the band,
    or was selected when the drag began, is selected; every other visible
    widget is deselected. ``selection_changed`` is emitted with the ids
    whose state changed.
    """
    dragging_band = (
        self.rubber_band_origin is not None
        and event.buttons() == Qt.LeftButton
        and event.modifiers() == Qt.ControlModifier
    )
    # Not a rubber-band drag, or the press landed on a widget: default handling.
    if not dragging_band or self.mouse_pressed_on_widget:
        super().mouseMoveEvent(event)
        return

    # Ignore small jitters before the drag really starts.
    if (event.pos() - self.rubber_band_origin).manhattanLength() < self.drag_threshold:
        return

    viewport = self.viewport()
    if not self.rubber_band:
        self.rubber_band = QRubberBand(QRubberBand.Rectangle, viewport)
    self.rubber_band.setGeometry(QRect(self.rubber_band_origin, event.pos()).normalized())
    self.rubber_band.show()

    selection_rect = self.rubber_band.geometry()
    changed_ids = []

    for widget in self.annotation_widgets_by_id.values():
        # Hidden widgets (e.g. outside the isolated set) keep the geometry of
        # their last visible position. Hit-testing them would select items
        # the user cannot see, so leave their state untouched.
        if widget.isHidden():
            continue

        # Widget geometry is in content-widget coordinates; the band lives in
        # the viewport, so map the top-left corner across.
        geom = widget.geometry()
        rect_in_viewport = QRect(self.content_widget.mapTo(viewport, geom.topLeft()), geom.size())

        should_be_selected = (
            widget in self.selection_at_press
            or selection_rect.intersects(rect_in_viewport)
        )

        if should_be_selected and not widget.is_selected():
            if self.select_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
        elif not should_be_selected and widget.is_selected():
            if self.deselect_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)

    if changed_ids:
        self.selection_changed.emit(changed_ids)
def mouseReleaseEvent(self, event):
    """Finish a rubber-band selection when the left button is released.

    The selection itself is applied incrementally in ``mouseMoveEvent``. Here
    the band is removed and the drag bookkeeping is reset. All other releases
    go to the default handler.
    """
    band_drag_active = self.rubber_band_origin is not None and event.button() == Qt.LeftButton
    if not band_drag_active:
        super().mouseReleaseEvent(event)
        return

    band = self.rubber_band
    if band and band.isVisible():
        band.hide()
        band.deleteLater()
        self.rubber_band = None

    # Reset the drag state so the next press starts clean.
    self.selection_at_press = set()
    self.rubber_band_origin = None
    self.mouse_pressed_on_widget = False
    event.accept()
def handle_annotation_selection(self, widget, event):
    """Update the selection in response to a click on ``widget``.

    Behaviour depends on the keyboard modifiers:

    * Shift (or Shift+Ctrl):
      - If a widget was clicked before (``last_selected_index != -1``),
        select every widget between the highest-positioned selected widget
        in the current list and ``widget`` (inclusive), adding to the
        existing selection.
      - Otherwise select ``widget`` and remember its index.
    * Ctrl: toggle ``widget`` and remember its index.
    * None: make ``widget`` the only selected widget and remember its index.

    In isolation mode only visible widgets are considered, and the isolated
    set is updated afterwards. ``selection_changed`` is emitted with the ids
    whose state changed. Clicks on widgets not in the current list are
    ignored.
    """
    all_widgets = self.annotation_widgets_by_id.values()
    if self.isolated_mode:
        # Only visible widgets take part while isolated.
        widget_list = [w for w in all_widgets if not w.isHidden()]
    else:
        widget_list = list(all_widgets)

    # Map each widget to its position once. This avoids repeated O(n)
    # ``list.index`` calls inside the loops below.
    position_of = {w: i for i, w in enumerate(widget_list)}
    widget_index = position_of.get(widget)
    if widget_index is None:
        return

    modifiers = event.modifiers()
    changed_ids = []

    # The branches below only decide WHICH widgets change; the state change
    # itself happens in select_widget / deselect_widget.
    if modifiers == Qt.ShiftModifier or modifiers == (Qt.ShiftModifier | Qt.ControlModifier):
        # Range selection
        if self.last_selected_index != -1:
            # Anchor the range at the selected widget furthest down the
            # current list. Fall back to the clicked widget when none of the
            # selected widgets is in the list.
            selected_positions = [position_of[w] for w in self.selected_widgets if w in position_of]
            anchor = max(selected_positions, default=widget_index)
            start, end = min(anchor, widget_index), max(anchor, widget_index)

            for candidate in widget_list[start:end + 1]:
                # select_widget returns True only if the state changed.
                if self.select_widget(candidate):
                    changed_ids.append(candidate.data_item.annotation.id)
        else:
            if self.select_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
            self.last_selected_index = widget_index

    elif modifiers == Qt.ControlModifier:
        # Toggle selection
        if widget.is_selected():
            if self.deselect_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
        else:
            if self.select_widget(widget):
                changed_ids.append(widget.data_item.annotation.id)
        self.last_selected_index = widget_index

    else:
        # Normal click: clear all others and select this one.
        newly_selected_id = widget.data_item.annotation.id
        for w in list(self.selected_widgets):
            if w.data_item.annotation.id != newly_selected_id:
                if self.deselect_widget(w):
                    changed_ids.append(w.data_item.annotation.id)
        if self.select_widget(widget):
            changed_ids.append(newly_selected_id)
        self.last_selected_index = widget_index

    # Keep the isolated set in step with the new selection.
    if self.isolated_mode:
        self._update_isolation()

    if changed_ids:
        self.selection_changed.emit(changed_ids)
def _update_isolation(self):
    """Grow the isolated view to include the current selection.

    This acts only in isolation mode with a non-empty selection:
    * the selected widgets are added to ``isolated_widgets``;
    * every widget in that set is shown and all others are hidden;
    * the layout is reflowed.

    An empty selection deliberately leaves the isolation as it is, so that
    clearing the selection does not accidentally exit isolation mode.
    """
    if not self.isolated_mode or not self.selected_widgets:
        return

    # Add to the isolated set; never replace it here.
    self.isolated_widgets.update(self.selected_widgets)

    self.setUpdatesEnabled(False)
    try:
        for candidate in self.annotation_widgets_by_id.values():
            if candidate in self.isolated_widgets:
                candidate.show()
            else:
                candidate.hide()
        self.recalculate_widget_positions()
    finally:
        self.setUpdatesEnabled(True)
def select_widget(self, widget):
    """Mark ``widget`` and its data item as selected.

    Returns:
        bool: True if the widget was unselected and is now selected. False if
        it was already selected, in which case nothing changes.
    """
    if widget.is_selected():
        return False

    widget.set_selected(True)
    widget.data_item.set_selected(True)
    self.selected_widgets.append(widget)
    # Keep the label window and toolbar buttons in sync with the selection.
    self.update_label_window_selection()
    self._update_toolbar_state()
    return True
def deselect_widget(self, widget):
    """Mark ``widget`` and its data item as not selected.

    Returns:
        bool: True if the widget was selected and has now been deselected.
        False if it was not selected, in which case nothing changes.
    """
    if not widget.is_selected():
        return False

    widget.set_selected(False)
    widget.data_item.set_selected(False)
    # Guard in case the widget is flagged selected but not tracked in the list.
    if widget in self.selected_widgets:
        self.selected_widgets.remove(widget)
    self.update_label_window_selection()
    self._update_toolbar_state()
    return True
def clear_selection(self):
    """Deselect every selected widget and refresh the dependent UI.

    Both the widget and its data item are marked unselected, as
    ``deselect_widget`` does for a single widget, so the data item's
    selection flag does not linger after a bulk clear. No signal is emitted;
    callers that need to notify listeners do so themselves.
    """
    # Iterate over a copy in case a widget callback touches the list.
    for widget in list(self.selected_widgets):
        widget.set_selected(False)
        widget.data_item.set_selected(False)
    self.selected_widgets.clear()
    self.update_label_window_selection()
    self._update_toolbar_state()
def update_label_window_selection(self):
    """Sync the main window's label window with the current selection.

    The first ancestor with a ``main_window`` attribute is located by walking
    up ``parent()``. If there is none, nothing happens. Otherwise:

    * no selection: the active label is deselected;
    * all selected items share one effective label: that label is made
      active, and if the first item has no pending preview change,
      ``annotation_window.labelSelected`` is emitted with the label's id;
    * mixed labels: the active label is deselected.

    The label window's annotation count is refreshed in every case.
    """
    host = self.parent()
    while host and not hasattr(host, 'main_window'):
        host = host.parent()
    if not host or not hasattr(host, 'main_window'):
        return

    main_window = host.main_window
    label_window = main_window.label_window
    annotation_window = main_window.annotation_window

    if not self.selected_widgets:
        label_window.deselect_active_label()
        label_window.update_annotation_count()
        return

    items = [w.data_item for w in self.selected_widgets]
    reference_label = items[0].effective_label
    shares_one_label = all(item.effective_label.id == reference_label.id for item in items)

    if shares_one_label:
        label_window.set_active_label(reference_label)
        if not items[0].has_preview_changes():
            annotation_window.labelSelected.emit(reference_label.id)
    else:
        label_window.deselect_active_label()

    label_window.update_annotation_count()
def get_selected_annotations(self):
    """Return the annotation objects behind the selected widgets.

    The result follows the order of ``selected_widgets``.
    """
    return [chosen.annotation for chosen in self.selected_widgets]
def render_selection_from_ids(self, selected_ids):
    """Make the widget selection match ``selected_ids``.

    Every widget whose annotation id is in ``selected_ids`` is selected and
    every other widget is deselected. Both the widget and its data item are
    updated, as in ``select_widget``/``deselect_widget``.
    ``selected_widgets`` is then rebuilt from the widgets' state.

    In isolation mode, a non-empty selection is added to
    ``isolated_widgets`` and visibility and layout are refreshed.
    ``selection_changed`` is not emitted.

    Args:
        selected_ids: Container of annotation ids supporting ``in``.
    """
    # Suspend repaints (not signals) while many widgets change state at once.
    self.setUpdatesEnabled(False)
    try:
        for ann_id, widget in self.annotation_widgets_by_id.items():
            is_selected = ann_id in selected_ids
            widget.set_selected(is_selected)
            # Keep the data item's flag in step with the widget.
            widget.data_item.set_selected(is_selected)

        # Resync the internal list of selected widgets.
        self.selected_widgets = [w for w in self.annotation_widgets_by_id.values() if w.is_selected()]

        # In isolation mode, ADD to the isolated set instead of replacing it.
        if self.isolated_mode and self.selected_widgets:
            self.isolated_widgets.update(self.selected_widgets)
            for widget in self.annotation_widgets_by_id.values():
                if widget in self.isolated_widgets:
                    widget.show()
                else:
                    widget.hide()
            self.recalculate_widget_positions()
    finally:
        self.setUpdatesEnabled(True)

    # Update the label window once, at the end.
    self.update_label_window_selection()
    # Enable or disable toolbar buttons (e.g. Isolate) for the new selection.
    self._update_toolbar_state()
def apply_preview_label_to_selected(self, preview_label):
    """Give every selected widget a preview label and notify listeners.

    Nothing happens if nothing is selected or ``preview_label`` is falsy.
    Otherwise:
    * each selected widget's data item receives the preview label and the
      widget is repainted;
    * the layout is recomputed when sorting by label;
    * ``preview_changed`` is emitted with the affected annotation ids.
    """
    if not (self.selected_widgets and preview_label):
        return

    touched_ids = []
    for selected in self.selected_widgets:
        selected.data_item.set_preview_label(preview_label)
        selected.update()  # Repaint with the preview colour.
        touched_ids.append(selected.data_item.annotation.id)

    # Label sorting depends on the effective labels that just changed.
    if self.sort_combo.currentText() == "Label":
        self.recalculate_widget_positions()

    if touched_ids:
        self.preview_changed.emit(touched_ids)
def clear_preview_states(self):
    """Revert every widget that has a pending preview label.

    Each affected data item drops its preview label and the widget is
    repainted. If anything was reverted, the layout is recomputed (when
    sorting by label) and the label window is re-synced.
    """
    reverted_any = False
    for widget in self.annotation_widgets_by_id.values():
        item = widget.data_item
        if not item.has_preview_changes():
            continue
        item.clear_preview_label()
        widget.update()  # Repaint in the original colour.
        reverted_any = True

    if not reverted_any:
        return

    # Label sorting depends on the labels that were just reverted.
    if self.sort_combo.currentText() == "Label":
        self.recalculate_widget_positions()
    self.update_label_window_selection()
def has_preview_changes(self):
    """Return True if at least one widget's data item has a pending preview."""
    for widget in self.annotation_widgets_by_id.values():
        if widget.data_item.has_preview_changes():
            return True
    return False
def get_preview_changes_summary(self):
    """Return a short, user-facing description of the pending preview changes."""
    pending = [w for w in self.annotation_widgets_by_id.values() if w.data_item.has_preview_changes()]
    if pending:
        return f"{len(pending)} annotation(s) with preview changes"
    return "No preview changes"
def apply_preview_changes_permanently(self):
    """Commit every pending preview label to its underlying annotation.

    Returns:
        list: The ``annotation`` of each widget whose data item's
        ``apply_preview_permanently()`` returned a truthy value.
    """
    committed = []
    for widget in self.annotation_widgets_by_id.values():
        # The data item applies its own preview; only report those that
        # say something was applied.
        if widget.data_item.apply_preview_permanently():
            committed.append(widget.annotation)
    return committed
|
1230
|
+
|
1231
|
+
|
1232
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
1233
|
+
# ExplorerWindow
|
1234
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
1235
|
+
|
1236
|
+
|
1237
|
+
class ExplorerWindow(QMainWindow):
|
1238
|
+
def __init__(self, main_window, parent=None):
    """Create the Explorer window and the widgets it lays out.

    Only the widgets are created here. They are arranged by ``setup_ui``,
    which ``showEvent`` calls.

    Args:
        main_window: The application's main window. Its image, label and
            annotation windows and its ``device`` are shared with the Explorer.
        parent: Optional Qt parent widget.
    """
    super(ExplorerWindow, self).__init__(parent)

    # References into the main application.
    self.main_window = main_window
    self.image_window = main_window.image_window
    self.label_window = main_window.label_window
    self.annotation_window = main_window.annotation_window

    # Model state; the device is shared with the main window.
    self.device = main_window.device
    self.model_path = ""
    self.loaded_model = None

    # The current filtered data items, used for embedding.
    self.current_data_items = []

    # Feature cache plus the name of the model that generated it.
    self.current_features = None
    self.current_feature_generating_model = ""

    self.setWindowTitle("Explorer")
    self.setWindowIcon(QIcon(get_icon("magic.png")))

    # Central widget holding the vertical main layout.
    self.central_widget = QWidget()
    self.setCentralWidget(self.central_widget)
    self.main_layout = QVBoxLayout(self.central_widget)

    # Left-hand panel and layout, intended for the re-parented LabelWindow.
    self.left_panel = QWidget()
    self.left_layout = QVBoxLayout(self.left_panel)

    # Child widgets are created once, here, so they exist before setup_ui
    # places them into the layout.
    self.annotation_settings_widget = AnnotationSettingsWidget(self.main_window, self)
    self.model_settings_widget = ModelSettingsWidget(self.main_window, self)
    self.embedding_settings_widget = EmbeddingSettingsWidget(self.main_window, self)
    self.annotation_viewer = AnnotationViewer(self)
    self.embedding_viewer = EmbeddingViewer(self)

    def make_button(text, slot, tooltip):
        # Build a push button owned by this window and wired to ``slot``.
        button = QPushButton(text, self)
        button.clicked.connect(slot)
        button.setToolTip(tooltip)
        return button

    # Bottom-row buttons; they are added to a layout in setup_ui.
    self.clear_preview_button = make_button(
        'Clear Preview', self.clear_preview_changes,
        "Clear all preview changes and revert to original labels")
    self.clear_preview_button.setEnabled(False)  # Initially disabled

    self.exit_button = make_button('Exit', self.close, "Close the window")

    self.apply_button = make_button('Apply', self.apply, "Apply changes")
    self.apply_button.setEnabled(False)  # Initially disabled
def showEvent(self, event):
    """(Re)build the window's layout, then let Qt complete the show.

    NOTE(review): Qt also delivers a (spontaneous) show event when a
    minimized window is restored, which re-runs ``setup_ui``. That rebuilds
    the layout and re-applies the default filter. Confirm this is intended.
    """
    self.setup_ui()
    super(ExplorerWindow, self).showEvent(event)
def closeEvent(self, event):
    """Tear the Explorer down when its window is closed.

    In order, this:
    1. stops the embedding viewer's animation timer, if there is one;
    2. discards unsaved preview labels;
    3. releases resources via ``_cleanup_resources``;
    4. re-enables the main window;
    5. calls ``main_window.explorer_closed`` (when it exists) so the main
       window can take the label window back;
    6. clears the main window's reference to this window;
    7. accepts the close event.
    """
    # Stop running timers so they cannot fire into a closing window.
    embedding_viewer = getattr(self, 'embedding_viewer', None)
    if embedding_viewer:
        timer = getattr(embedding_viewer, 'animation_timer', None)
        if timer:
            timer.stop()

    # Preview labels that were never applied are dropped.
    if hasattr(self, 'annotation_viewer'):
        self.annotation_viewer.clear_preview_states()

    self._cleanup_resources()

    # The main window is re-enabled before this one goes away.
    if self.main_window:
        self.main_window.setEnabled(True)

    # Move the label window back to the main window.
    if hasattr(self.main_window, 'explorer_closed'):
        self.main_window.explorer_closed()

    # Drop the main window's reference so this window can be garbage-collected.
    self.main_window.explorer_window = None

    event.accept()
def setup_ui(self):
    """Assemble the window layout and wire up the cross-view signals.

    Called from ``showEvent``, so it may run more than once. It first empties
    ``main_layout``, detaching but not deleting its widgets. It then adds,
    top to bottom:
    * the three settings widgets side by side;
    * a horizontal splitter holding the annotation and embedding viewers;
    * the label window;
    * the bottom row of buttons.

    Finally it selects the current image in the annotation settings,
    refreshes the filters, and (re)connects the viewer and label-window
    signals exactly once each.
    """
    # Empty the main layout. Widgets are detached, not deleted, because the
    # same instances are added again below.
    while self.main_layout.count():
        child = self.main_layout.takeAt(0)
        if child.widget():
            child.widget().setParent(None)

    # Top row: the settings panels side by side (stretch 2:1:1).
    top_layout = QHBoxLayout()
    top_layout.addWidget(self.annotation_settings_widget, 2)
    top_layout.addWidget(self.model_settings_widget, 1)
    top_layout.addWidget(self.embedding_settings_widget, 1)

    top_container = QWidget()
    top_container.setLayout(top_layout)
    self.main_layout.addWidget(top_container)

    # Middle: annotation viewer (left) and embedding viewer (right).
    middle_splitter = QSplitter(Qt.Horizontal)

    annotation_group = QGroupBox("Annotation Viewer")
    annotation_layout = QVBoxLayout(annotation_group)
    annotation_layout.addWidget(self.annotation_viewer)
    middle_splitter.addWidget(annotation_group)

    embedding_group = QGroupBox("Embedding Viewer")
    embedding_layout = QVBoxLayout(embedding_group)
    embedding_layout.addWidget(self.embedding_viewer)
    middle_splitter.addWidget(embedding_group)

    # Both viewers start with equal widths.
    middle_splitter.setSizes([500, 500])

    # Stretch factor 1: the viewers take the remaining vertical space.
    self.main_layout.addWidget(middle_splitter, 1)

    # The label window is borrowed from the main window and placed below
    # the viewers (directly in main_layout).
    self.main_layout.addWidget(self.label_window)

    # Bottom row: action buttons pushed to the right by a stretch.
    self.buttons_layout = QHBoxLayout()
    self.buttons_layout.addStretch(1)
    self.buttons_layout.addWidget(self.clear_preview_button)
    self.buttons_layout.addWidget(self.exit_button)
    self.buttons_layout.addWidget(self.apply_button)
    self.main_layout.addLayout(self.buttons_layout)

    # Default the annotation filter to the current image and apply it.
    self.annotation_settings_widget.set_default_to_current_image()
    self.refresh_filters()

    # setup_ui runs on every showEvent, so every connection must be
    # idempotent. Drop any existing connection before connecting; otherwise
    # each slot would fire once per previous setup_ui call.
    connections = (
        (self.label_window.labelSelected, self.on_label_selected_for_preview),
        (self.annotation_viewer.selection_changed, self.on_annotation_view_selection_changed),
        (self.annotation_viewer.preview_changed, self.on_preview_changed),
        (self.annotation_viewer.reset_view_requested, self.on_reset_view_requested),
        (self.embedding_viewer.selection_changed, self.on_embedding_view_selection_changed),
        (self.embedding_viewer.reset_view_requested, self.on_reset_view_requested),
    )
    for signal, slot in connections:
        try:
            signal.disconnect(slot)
        except TypeError:
            pass  # Signal wasn't connected to this slot yet.
        signal.connect(slot)
@pyqtSlot(list)
def on_annotation_view_selection_changed(self, changed_ann_ids):
    """Mirror a selection change from the AnnotationViewer in the EmbeddingViewer.

    ``changed_ann_ids`` is not used. The complete current selection is read
    from the annotation viewer instead.
    """
    current_ids = {widget.data_item.annotation.id for widget in self.annotation_viewer.selected_widgets}

    # Sync only when the embedding view actually has points. Syncing an empty
    # embedding view caused a feedback loop that cleared the selection.
    if self.embedding_viewer.points_by_id:
        self.embedding_viewer.render_selection_from_ids(current_ids)

    self.update_label_window_selection()
@pyqtSlot(list)
def on_embedding_view_selection_changed(self, all_selected_ann_ids):
    """Mirror a selection made in the EmbeddingViewer in the AnnotationViewer.

    ``all_selected_ann_ids`` is the complete selection. The annotation viewer
    is switched into isolation mode automatically when all of these hold:
    * it had nothing selected before;
    * the new selection is non-empty;
    * it is not already isolated.
    """
    had_no_selection = not self.annotation_viewer.selected_widgets

    self.annotation_viewer.render_selection_from_ids(set(all_selected_ann_ids))

    should_isolate = (
        had_no_selection
        and bool(all_selected_ann_ids)
        and not self.annotation_viewer.isolated_mode
    )
    if should_isolate:
        print("Auto-switching to isolation mode due to new selection in embedding viewer")
        self.annotation_viewer.isolate_selection()

    self.update_label_window_selection()
@pyqtSlot(list)
def on_preview_changed(self, changed_ann_ids):
    """Repaint the embedding points whose preview color changed in the AnnotationViewer.

    Args:
        changed_ann_ids: IDs of annotations whose preview label/color changed.
            IDs without a point in the embedding view are ignored.
    """
    lookups = (self.embedding_viewer.points_by_id.get(ann_id) for ann_id in changed_ann_ids)
    for point in lookups:
        if point:
            point.update()  # force a repaint with the new preview color
|
1437
|
+
|
1438
|
+
@pyqtSlot()
def on_reset_view_requested(self):
    """Reset both viewers after a double-click in either one.

    Clears the selection in the annotation and embedding viewers, leaves
    isolation mode if it is active, and refreshes the preview button states.
    """
    annotation_viewer = self.annotation_viewer

    # Deselect everywhere; an empty ID set clears the embedding view's selection
    annotation_viewer.clear_selection()
    self.embedding_viewer.render_selection_from_ids(set())

    if annotation_viewer.isolated_mode:
        annotation_viewer.show_all_annotations()

    self.update_button_states()
    print("Reset view: cleared selections and exited isolation mode")
|
1453
|
+
|
1454
|
+
def update_label_window_selection(self):
    """Sync the label window with the annotation viewer's current selection.

    Thin delegate to the annotation viewer's own ``update_label_window_selection``.
    """
    viewer = self.annotation_viewer
    viewer.update_label_window_selection()
|
1457
|
+
|
1458
|
+
def get_filtered_data_items(self):
    """Return the annotations matching every active filter, wrapped as AnnotationDataItem objects.

    An annotation matches when all of the following hold:
      * the basename of its ``image_path`` is among the selected images;
      * its class name is among the selected annotation types;
      * its label's ``short_label_code`` is among the selected labels.

    An empty (or None) selection for any filter matches nothing. All matching
    annotations have their cropped images generated before being wrapped.

    Returns:
        list of AnnotationDataItem, in the iteration order of ``annotations_dict``.
    """
    annotation_window = self.main_window.annotation_window
    if not hasattr(annotation_window, 'annotations_dict'):
        return []

    settings = self.annotation_settings_widget
    # Build sets once so every membership test below is O(1) instead of a list scan
    selected_images = set(settings.get_selected_images() or ())
    selected_types = set(settings.get_selected_annotation_types() or ())
    selected_labels = set(settings.get_selected_labels() or ())

    annotations_to_process = []
    # Any empty filter means nothing can match, so skip the scan entirely
    if selected_images and selected_types and selected_labels:
        annotations_to_process = [
            annotation
            for annotation in annotation_window.annotations_dict.values()
            if os.path.basename(annotation.image_path) in selected_images
            and type(annotation).__name__ in selected_types
            and annotation.label.short_label_code in selected_labels
        ]

    # Ensure all filtered annotations have cropped images
    self._ensure_cropped_images(annotations_to_process)

    return [AnnotationDataItem(annotation) for annotation in annotations_to_process]
|
1513
|
+
|
1514
|
+
def _ensure_cropped_images(self, annotations):
    """Crop every annotation that does not yet have a cropped image.

    Annotations still lacking a crop are grouped by source image so
    ``annotation_window.crop_annotations`` runs once per image. A failure on one
    image is reported and the remaining images are still processed. A progress
    bar is shown only when at least one annotation needs cropping, and it is
    always closed afterwards.

    Args:
        annotations: annotation objects exposing ``cropped_image`` and ``image_path``.
    """
    annotations_by_image = {}
    for annotation in annotations:
        # Only process annotations that don't have a cropped image yet
        if not annotation.cropped_image:
            annotations_by_image.setdefault(annotation.image_path, []).append(annotation)

    if not annotations_by_image:
        return

    progress_bar = ProgressBar(self, "Cropping Image Annotations")
    try:
        progress_bar.show()
        progress_bar.start_progress(len(annotations_by_image))

        # Use the AnnotationWindow's cropping so crops match those generated elsewhere
        for image_path, image_annotations in annotations_by_image.items():
            try:
                self.annotation_window.crop_annotations(
                    image_path=image_path,
                    annotations=image_annotations,
                    return_annotations=False,  # We don't need the return value
                    verbose=False
                )
            except Exception as e:
                # Best-effort: one unreadable image must not block cropping of the others
                print(f"Error cropping annotations for {image_path}: {e}")
            progress_bar.update_progress()
    finally:
        progress_bar.finish_progress()
        progress_bar.stop_progress()
        progress_bar.close()
|
1551
|
+
|
1552
|
+
def _extract_color_features(self, data_items, progress_bar=None, bins=32):
    """
    Extracts a set of color features from annotation crops using only NumPy.

    For each crop it calculates:
        1. Color moments per RGB channel: mean, standard deviation, skewness,
           and excess kurtosis (kurtosis - 3).
        2. Color histogram per RGB channel: ``bins`` bins over 0-255, each
           normalized to sum to 1.
        3. Grayscale statistics: mean brightness, contrast (std dev), and
           intensity range (peak-to-peak).
        4. Geometric features: the annotation's ``area`` and ``perimeter``
           attributes (0.0 when absent).

    Only the first three channels of each crop are used, so an extra (e.g. alpha)
    channel cannot leak into the per-channel statistics.

    Args:
        data_items: AnnotationDataItem-like objects exposing ``annotation.get_cropped_image()``.
        progress_bar: optional progress bar, advanced once per item.
        bins: number of histogram bins per channel.

    Returns:
        (features, valid_data_items). ``features`` has one row of length
        ``12 + 3 * bins + 5`` per item whose crop could be read (an empty array
        when none). ``valid_data_items`` holds those items in the same order.
        Items without a usable crop are skipped with a warning.
    """
    if progress_bar:
        progress_bar.set_title("Extracting Color Features...")
        progress_bar.start_progress(len(data_items))

    # Added to the moment denominators so uniform channels (std == 0) don't divide by zero
    epsilon = 1e-8

    features = []
    valid_data_items = []
    for item in data_items:
        pixmap = item.annotation.get_cropped_image()
        if pixmap and not pixmap.isNull():
            # NOTE(review): assumes pixmap_to_numpy returns (height, width, channels)
            arr = pixmap_to_numpy(pixmap)
            # Keep only the color channels; reshaping a 4-channel array to (-1, 3)
            # would interleave alpha values into the RGB statistics.
            rgb = arr[..., :3]
            pixels = rgb.reshape(-1, 3)

            # --- 1. Color moments per channel ---
            mean_color = np.mean(pixels, axis=0)
            std_color = np.std(pixels, axis=0)
            centered_pixels = pixels - mean_color
            skew_color = np.mean(centered_pixels**3, axis=0) / (std_color**3 + epsilon)
            kurt_color = np.mean(centered_pixels**4, axis=0) / (std_color**4 + epsilon) - 3

            # --- 2. Normalized color histograms per channel ---
            histograms = []
            for i in range(3):
                hist, _ = np.histogram(pixels[:, i], bins=bins, range=(0, 255))
                hist_sum = np.sum(hist)
                if hist_sum > 0:
                    histograms.append(hist / hist_sum)
                else:  # Avoid division by zero
                    histograms.append(np.zeros(bins))

            # --- 3. Grayscale statistics ---
            gray_arr = np.dot(rgb, [0.2989, 0.5870, 0.1140])
            grayscale_stats = np.array([
                np.mean(gray_arr),  # Overall brightness
                np.std(gray_arr),   # Overall contrast
                np.ptp(gray_arr)    # Peak-to-peak intensity range
            ])

            # --- 4. Geometric features ---
            # NOTE(review): falls back to 0.0 when the annotation has no `area`/`perimeter`
            # attribute; confirm the annotation classes expose these as plain attributes.
            area = getattr(item.annotation, 'area', 0.0)
            perimeter = getattr(item.annotation, 'perimeter', 0.0)
            geometric_features = np.array([area, perimeter])

            # --- 5. Concatenate into a single feature vector ---
            current_features = np.concatenate([
                mean_color,
                std_color,
                skew_color,
                kurt_color,
                *histograms,
                grayscale_stats,
                geometric_features
            ])

            features.append(current_features)
            valid_data_items.append(item)
        else:
            print(f"Warning: Could not get cropped image for annotation ID {item.annotation.id}. Skipping.")

        if progress_bar:
            progress_bar.update_progress()

    return np.array(features), valid_data_items
|
1643
|
+
|
1644
|
+
def _extract_yolo_features(self, data_items, model_info, progress_bar=None):
    """
    Extracts features from annotation crops using a specified YOLO model.

    When the feature mode is "Embed Features", ``model.embed()`` is used.
    Otherwise ``model.predict()`` is used and each crop's classification
    probability vector becomes its feature row. The loaded model is cached on
    ``self`` (``loaded_model`` / ``model_path``) and reused while the model
    name is unchanged.

    Args:
        data_items: AnnotationDataItem-like objects exposing ``annotation.get_cropped_image()``.
        model_info: ``(model_name, feature_mode)`` tuple.
        progress_bar: optional progress bar used for status updates.

    Returns:
        (embeddings, valid_data_items). There is one feature row per item whose
        crop could be read, and those items are returned in the same order.
        On any failure the result is ``(np.array([]), [])``.
    """
    model_name, feature_mode = model_info

    # Load (and warm up) only when the requested model differs from the cached one
    if model_name != self.model_path or self.loaded_model is None:
        try:
            print(f"Loading new model: {model_name}")
            self.loaded_model = YOLO(model_name)
            self.model_path = model_name

            # Use the model's configured image size; anything above 224 (or unreadable) becomes 128
            try:
                self.imgsz = self.loaded_model.model.args['imgsz']
                if self.imgsz > 224:
                    self.imgsz = 128
            except (AttributeError, KeyError, TypeError):
                # TypeError covers non-subscriptable args or a non-integer imgsz (e.g. a list or None)
                self.imgsz = 128

            # Run a dummy inference to warm up the model
            print(f"Warming up model on device '{self.device}'...")
            dummy_image = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8)
            self.loaded_model.predict(dummy_image, imgsz=self.imgsz, half=True, device=self.device, verbose=False)

        except Exception as e:
            print(f"ERROR: Could not load YOLO model '{model_name}': {e}")
            # Forget the half-initialized model so the next call retries loading and warm-up
            self.loaded_model = None
            self.model_path = ""
            return np.array([]), []

    if progress_bar:
        progress_bar.set_title("Preparing images...")
        progress_bar.start_progress(len(data_items))

    # 1. Convert every readable crop to a NumPy image, remembering which item it came from
    image_list = []
    valid_data_items = []
    for item in data_items:
        pixmap = item.annotation.get_cropped_image()
        if pixmap and not pixmap.isNull():
            image_list.append(pixmap_to_numpy(pixmap))
            valid_data_items.append(item)
        else:
            print(f"Warning: Could not get cropped image for annotation ID {item.annotation.id}. Skipping.")

        if progress_bar:
            progress_bar.update_progress()

    if not valid_data_items:
        print("Warning: No valid images found to process.")
        return np.array([]), []

    embeddings_list = []

    try:
        if progress_bar:
            progress_bar.set_busy_mode(f"Extracting features with {os.path.basename(model_name)}...")

        kwargs = {
            'stream': True,
            'imgsz': self.imgsz,
            'half': True,
            'device': self.device,
            'verbose': False
        }

        # 2. embed() returns intermediate (second-to-last layer) features; predict() returns results with probs
        using_embed_method = feature_mode == "Embed Features"
        if using_embed_method:
            print("Using embed() method")
            results_generator = self.loaded_model.embed(image_list, **kwargs)
        else:
            print("Using predict() method")
            results_generator = self.loaded_model.predict(image_list, **kwargs)

        if progress_bar:
            progress_bar.set_title(f"Extracting features with {os.path.basename(model_name)}...")
            progress_bar.start_progress(len(valid_data_items))

        # 3. Consume the stream; each result is expected to correspond to one input image
        for result in results_generator:
            if using_embed_method:
                try:
                    # With embed(), result is directly a tensor
                    embeddings_list.append(result.cpu().numpy().flatten())
                except Exception as e:
                    print(f"Error processing embedding: {e}")
                    raise TypeError(
                        f"Model '{os.path.basename(model_name)}' did not return valid embeddings. "
                        f"Error: {str(e)}"
                    ) from e
            else:
                # Classification mode: we expect probability vectors
                if hasattr(result, 'probs') and result.probs is not None:
                    embeddings_list.append(result.probs.data.cpu().numpy().squeeze())
                else:
                    raise TypeError(
                        f"Model '{os.path.basename(model_name)}' did not return probability vectors. "
                        "Make sure this is a classification model."
                    )

            if progress_bar:
                progress_bar.update_progress()

        if not embeddings_list:
            print("Warning: No features were extracted. The model may have failed.")
            return np.array([]), []

        # Rows are paired with valid_data_items by position; a count mismatch would mislabel every point
        if len(embeddings_list) != len(valid_data_items):
            print(f"ERROR: Model returned {len(embeddings_list)} feature vectors "
                  f"for {len(valid_data_items)} images.")
            return np.array([]), []

        embeddings = np.array(embeddings_list)
    except Exception as e:
        print(f"ERROR: An error occurred during feature extraction: {e}")
        return np.array([]), []
    finally:
        # Clean up CUDA memory after the operation
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"Successfully extracted {len(embeddings)} features with shape {embeddings.shape}")
    return embeddings, valid_data_items
|
1775
|
+
|
1776
|
+
def _extract_features(self, data_items, progress_bar=None):
    """
    Dispatch feature extraction to the extractor for the selected model.

    ``model_settings_widget.get_selected_model()`` normally returns
    ``(model_name, feature_mode)``. A bare model name (or None) is also accepted,
    with the feature mode defaulting to "default". "Color Features" uses
    the NumPy color extractor. Any name containing ".pt" is treated as a YOLO
    model. Anything else is rejected.

    Args:
        data_items: items to extract features for.
        progress_bar: optional progress bar forwarded to the extractor.

    Returns:
        (features, valid_data_items) from the chosen extractor, or
        ``(np.array([]), [])`` when no usable model is selected.
    """
    model_info = self.model_settings_widget.get_selected_model()

    # Accept a bare name as well as the (name, mode) pair instead of failing to unpack it
    if model_info is None or isinstance(model_info, str):
        model_name, feature_mode = model_info, "default"
    else:
        model_name, feature_mode = model_info

    # Some selections carry a nested (path, ...) tuple; the path is its first element
    if isinstance(model_name, tuple) and len(model_name) >= 3:
        model_name = model_name[0]

    if not model_name:
        print("No model selected or path provided.")
        return np.array([]), []

    if model_name == "Color Features":
        return self._extract_color_features(data_items, progress_bar=progress_bar)
    if ".pt" in model_name:
        return self._extract_yolo_features(data_items, (model_name, feature_mode), progress_bar=progress_bar)

    print(f"Unknown or invalid feature model selected: {model_name}")
    return np.array([]), []
|
1802
|
+
|
1803
|
+
def _run_dimensionality_reduction(self, features, params):
    """Project the feature matrix to 2-D with UMAP, t-SNE or PCA.

    Features are standardized with ``StandardScaler`` before reduction. UMAP's
    ``n_neighbors`` and t-SNE's ``perplexity`` are capped at ``n_samples - 1``.

    Args:
        features: feature matrix with one row per item.
        params: dict with a 'technique' key ('UMAP', 'TSNE' or 'PCA'; defaults
            to 'UMAP') plus optional hyperparameters: n_neighbors, min_dist,
            metric (UMAP); perplexity, early_exaggeration, learning_rate (TSNE).

    Returns:
        The reducer's ``fit_transform`` output. Returns None when there are
        2 or fewer rows, the technique is unknown, or any error occurs.
    """
    technique = params.get('technique', 'UMAP')
    seed = 42
    n_items = len(features)

    print(f"Running {technique} on {n_items} items with params: {params}")
    if n_items <= 2:
        # UMAP/t-SNE need at least a few points
        print("Not enough data points for dimensionality reduction.")
        return None

    try:
        # Standardize each feature column so large-valued features don't dominate distances
        features_scaled = StandardScaler().fit_transform(features)
        neighbor_cap = len(features_scaled) - 1

        if technique == "UMAP":
            reducer = UMAP(
                n_components=2,
                random_state=seed,
                n_neighbors=min(params.get('n_neighbors', 15), neighbor_cap),
                min_dist=params.get('min_dist', 0.1),
                metric=params.get('metric', 'cosine'),
            )
        elif technique == "TSNE":
            reducer = TSNE(
                n_components=2,
                random_state=seed,
                perplexity=min(params.get('perplexity', 30), neighbor_cap),
                early_exaggeration=params.get('early_exaggeration', 12.0),
                learning_rate=params.get('learning_rate', 'auto'),
                init='pca',  # PCA initialization improves stability and speed
            )
        elif technique == "PCA":
            reducer = PCA(n_components=2, random_state=seed)
        else:
            print(f"Unknown dimensionality reduction technique: {technique}")
            return None

        return reducer.fit_transform(features_scaled)

    except Exception as e:
        print(f"Error during {technique} dimensionality reduction: {e}")
        return None
|
1860
|
+
|
1861
|
+
def _update_data_items_with_embedding(self, data_items, embedded_features, scale_factor=4000):
    """Store 2-D embedding coordinates on each data item.

    Each axis is min-max normalized to [0, 1], then scaled to
    [-scale_factor / 2, scale_factor / 2], which centers the points on the origin.
    An axis with zero spread maps every point to the center (0) of that axis.

    Args:
        data_items: items to update. Item ``i`` receives row ``i`` of
            *embedded_features* as ``embedding_x`` / ``embedding_y``.
        embedded_features: array-like indexed as ``[i, 0]`` / ``[i, 1]``, with at
            least one row per item.
        scale_factor: side length of the square the points are spread over.
            The default of 4000 matches the previous hard-coded value.
    """
    if len(data_items) == 0 or len(embedded_features) == 0:
        return  # nothing to place; np.min/np.max would fail on an empty array

    embedded_features = np.asarray(embedded_features)
    min_vals = np.min(embedded_features, axis=0)
    max_vals = np.max(embedded_features, axis=0)
    range_vals = max_vals - min_vals
    half_extent = scale_factor / 2

    for i, item in enumerate(data_items):
        # Normalize coordinates; a degenerate (zero-range) axis is placed at the middle
        norm_x = (embedded_features[i, 0] - min_vals[0]) / range_vals[0] if range_vals[0] > 0 else 0.5
        norm_y = (embedded_features[i, 1] - min_vals[1]) / range_vals[1] if range_vals[1] > 0 else 0.5
        # Scale and center the points in the view
        item.embedding_x = (norm_x * scale_factor) - half_extent
        item.embedding_y = (norm_y * scale_factor) - half_extent
|
1875
|
+
|
1876
|
+
def run_embedding_pipeline(self):
    """
    Orchestrates the feature extraction and dimensionality reduction pipeline.

    Features are cached under a key built from the selected model and feature
    mode. When only the reduction parameters change, the cached features are
    re-reduced without re-extraction. A failed or empty extraction leaves
    ``current_data_items`` and the feature cache untouched, so a retry
    re-extracts instead of reusing an empty cache.
    """
    if not self.current_data_items:
        print("No items to process for embedding.")
        return

    # 1. Get current parameters from the UI
    embedding_params = self.embedding_settings_widget.get_embedding_parameters()
    model_info = self.model_settings_widget.get_selected_model()

    # The cache key covers both the model and its feature mode
    if isinstance(model_info, tuple):
        selected_model, selected_feature_mode = model_info
    else:
        selected_model, selected_feature_mode = model_info, "default"
    cache_key = f"{selected_model}_{selected_feature_mode}"

    technique = embedding_params['technique']

    QApplication.setOverrideCursor(Qt.WaitCursor)
    progress_bar = None
    try:
        # Created inside the try so the cursor is restored even if construction fails
        progress_bar = ProgressBar(self, "Generating Embedding Visualization")
        progress_bar.show()

        # 2. Extract only when nothing is cached or the model/feature mode changed
        needs_extraction = self.current_features is None or cache_key != self.current_feature_generating_model
        if needs_extraction:
            features, valid_data_items = self._extract_features(self.current_data_items, progress_bar=progress_bar)
        else:
            print("Using cached features. Skipping feature extraction.")
            features = self.current_features

        if features is None or len(features) == 0:
            print("No valid features available. Aborting embedding.")
            return

        if needs_extraction:
            # Cache only successful extractions so a failure is retried rather than reused
            self.current_features = features
            self.current_feature_generating_model = cache_key
            self.current_data_items = valid_data_items
            self.annotation_viewer.update_annotations(self.current_data_items)

        # 3. Run dimensionality reduction with the latest parameters
        progress_bar.set_busy_mode(f"Running {technique} dimensionality reduction...")
        embedded_features = self._run_dimensionality_reduction(features, embedding_params)
        progress_bar.update_progress()

        if embedded_features is None:
            return

        # 4. Update the visualization with the new 2D layout
        progress_bar.set_busy_mode("Updating visualization...")
        self._update_data_items_with_embedding(self.current_data_items, embedded_features)

        self.embedding_viewer.update_embeddings(self.current_data_items)
        self.embedding_viewer.show_embedding()
        self.embedding_viewer.fit_view_to_points()
        progress_bar.update_progress()

        print(f"Successfully generated embedding for {len(self.current_data_items)} annotations using {technique}")

    except Exception as e:
        print(f"Error during embedding pipeline: {e}")
        self.embedding_viewer.clear_points()
        self.embedding_viewer.show_placeholder()

    finally:
        QApplication.restoreOverrideCursor()
        if progress_bar is not None:
            progress_bar.finish_progress()
            progress_bar.stop_progress()
            progress_bar.close()
|
1954
|
+
|
1955
|
+
def refresh_filters(self):
    """Re-apply the current filters and reset the views that depend on them.

    Re-filters the annotations and drops the cached features, since they belong
    to the previous item set. It then refreshes the annotation grid and resets the
    embedding view to its placeholder. A wait cursor is shown throughout.
    """
    QApplication.setOverrideCursor(Qt.WaitCursor)
    try:
        self.current_data_items = self.get_filtered_data_items()
        # The item set changed, so previously extracted features are stale
        self.current_features = None

        self.annotation_viewer.update_annotations(self.current_data_items)

        # The embedding no longer matches the filtered items
        embedding_viewer = self.embedding_viewer
        embedding_viewer.clear_points()
        embedding_viewer.show_placeholder()
    finally:
        QApplication.restoreOverrideCursor()
|
1975
|
+
|
1976
|
+
def on_label_selected_for_preview(self, label):
    """Preview *label* on the currently selected annotations.

    Does nothing when the annotation viewer does not exist yet or has no
    selected widgets. Otherwise it applies the preview label and refreshes the
    Clear Preview / Apply button states.
    """
    if not hasattr(self, 'annotation_viewer'):
        return
    if not self.annotation_viewer.selected_widgets:
        return

    self.annotation_viewer.apply_preview_label_to_selected(label)
    self.update_button_states()
|
1981
|
+
|
1982
|
+
def clear_preview_changes(self):
    """Discard every pending preview label, reverting to the original labels.

    Does nothing if the annotation viewer has not been created yet.
    """
    if not hasattr(self, 'annotation_viewer'):
        return

    self.annotation_viewer.clear_preview_states()
    self.update_button_states()
    print("Cleared all preview changes")
|
1988
|
+
|
1989
|
+
def update_button_states(self):
    """Enable/disable the Clear Preview and Apply buttons and refresh their tooltips.

    Both buttons are enabled only when the annotation viewer reports pending
    preview changes. The tooltips include the viewer's change summary. When no
    annotation viewer exists yet, both buttons are disabled and the tooltips are
    left untouched, instead of failing on the summary lookup.
    """
    viewer = getattr(self, 'annotation_viewer', None)
    if viewer is None:
        self.clear_preview_button.setEnabled(False)
        self.apply_button.setEnabled(False)
        return

    # setEnabled expects a real bool, not just a truthy value
    has_changes = bool(viewer.has_preview_changes())
    self.clear_preview_button.setEnabled(has_changes)
    self.apply_button.setEnabled(has_changes)

    summary = viewer.get_preview_changes_summary()
    self.clear_preview_button.setToolTip(f"Clear all preview changes - {summary}")
    self.apply_button.setToolTip(f"Apply changes - {summary}")
|
1999
|
+
|
2000
|
+
def apply(self):
    """Commit all preview label changes to the real annotations.

    This method:
      1. Applies the previews permanently.
      2. Repaints the affected grid widgets and embedding points.
      3. Re-sorts the grid when it is sorted by label.
      4. Refreshes the image window's data for every affected image and reloads
         the annotation window.
      5. Clears selections.

    Errors are printed rather than raised. The wait cursor is always restored.
    """
    QApplication.setOverrideCursor(Qt.WaitCursor)
    try:
        applied = self.annotation_viewer.apply_preview_changes_permanently()
        if not applied:
            print("No preview changes to apply")
            return

        # Repaint the visual components of every affected data item
        changed_ids = {ann.id for ann in applied}
        for item in self.current_data_items:
            ann_id = item.annotation.id
            if ann_id not in changed_ids:
                continue
            widget = self.annotation_viewer.annotation_widgets_by_id.get(ann_id)
            if widget:
                widget.update()  # show the new permanent color in the grid
            point = self.embedding_viewer.points_by_id.get(ann_id)
            if point:
                point.update()  # show the new permanent color in the embedding

        # Labels changed, so a label-sorted grid must be re-laid out
        if self.annotation_viewer.sort_combo.currentText() == "Label":
            self.annotation_viewer.recalculate_widget_positions()

        # Propagate the changes to the main application's data
        for image_path in {ann.image_path for ann in applied}:
            self.image_window.update_image_annotations(image_path)
        self.annotation_window.load_annotations()

        # Reset selections (an empty ID set clears the embedding selection) and buttons
        self.annotation_viewer.clear_selection()
        self.embedding_viewer.render_selection_from_ids(set())
        self.update_button_states()

        print(f"Applied changes to {len(applied)} annotation(s)")

    except Exception as e:
        print(f"Error applying modifications: {e}")
    finally:
        QApplication.restoreOverrideCursor()
|
2047
|
+
|
2048
|
+
def _cleanup_resources(self):
    """Release the cached model and features, and free cached GPU memory.

    Resets the model and feature caches to their initial values. Clears the
    CUDA cache when CUDA is available. Called when the window is closed.
    """
    print("Cleaning up Explorer resources...")

    # Restore model/feature caches to their initial state (same order as before)
    initial_state = (
        ('imgsz', 128),
        ('loaded_model', None),
        ('model_path', ""),
        ('current_features', None),
        ('current_feature_generating_model', ""),
    )
    for attr_name, value in initial_state:
        setattr(self, attr_name, value)

    if torch.cuda.is_available():
        print("Clearing CUDA cache.")
        torch.cuda.empty_cache()

    print("Cleanup complete.")
|