coralnet-toolbox 0.0.72__py2.py3-none-any.whl → 0.0.74__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
- coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
- coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
- coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
- coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
- coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
- coralnet_toolbox/CoralNet/QtDownload.py +2 -1
- coralnet_toolbox/Explorer/QtDataItem.py +1 -1
- coralnet_toolbox/Explorer/QtExplorer.py +159 -17
- coralnet_toolbox/Explorer/QtSettingsWidgets.py +160 -86
- coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
- coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
- coralnet_toolbox/IO/QtOpenProject.py +46 -78
- coralnet_toolbox/IO/QtSaveProject.py +18 -43
- coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
- coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
- coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
- coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
- coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
- coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
- coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +42 -22
- coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
- coralnet_toolbox/QtAnnotationWindow.py +42 -14
- coralnet_toolbox/QtEventFilter.py +19 -2
- coralnet_toolbox/QtImageWindow.py +134 -86
- coralnet_toolbox/QtLabelWindow.py +14 -2
- coralnet_toolbox/QtMainWindow.py +122 -9
- coralnet_toolbox/QtProgressBar.py +52 -27
- coralnet_toolbox/Rasters/QtRaster.py +59 -7
- coralnet_toolbox/Rasters/RasterTableModel.py +42 -14
- coralnet_toolbox/SAM/QtBatchInference.py +0 -2
- coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
- coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
- coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1634 -0
- coralnet_toolbox/SeeAnything/QtDeployPredictor.py +107 -154
- coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
- coralnet_toolbox/SeeAnything/__init__.py +2 -0
- coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
- coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
- coralnet_toolbox/Tools/QtSAMTool.py +222 -57
- coralnet_toolbox/Tools/QtSeeAnythingTool.py +223 -55
- coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
- coralnet_toolbox/Tools/QtSelectTool.py +27 -3
- coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
- coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
- coralnet_toolbox/Tools/__init__.py +2 -0
- coralnet_toolbox/__init__.py +1 -1
- coralnet_toolbox/utilities.py +137 -47
- coralnet_toolbox-0.0.74.dist-info/METADATA +375 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/RECORD +56 -53
- coralnet_toolbox-0.0.72.dist-info/METADATA +0 -341
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/top_level.txt +0 -0
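The largest change is the new module coralnet_toolbox/SeeAnything/QtDeployGenerator.py (+1634 lines, matching the hunk header below); the hunk follows, truncated partway through the module.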
@@ -0,0 +1,1634 @@
import warnings

import os
import gc

import numpy as np
from sklearn.decomposition import PCA

import torch
from torch.cuda import empty_cache

import pyqtgraph as pg
from pyqtgraph.Qt import QtGui

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import (QMessageBox, QVBoxLayout, QApplication, QFileDialog,
                             QLabel, QDialog, QDialogButtonBox, QGroupBox, QLineEdit,
                             QFormLayout, QComboBox, QSpinBox, QSlider, QPushButton,
                             QHBoxLayout)

from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation

from coralnet_toolbox.Results import ResultsProcessor
from coralnet_toolbox.Results import MapResults
from coralnet_toolbox.Results import CombineResults

from coralnet_toolbox.QtProgressBar import ProgressBar
from coralnet_toolbox.QtImageWindow import ImageWindow

from coralnet_toolbox.Icons import get_icon

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


# ----------------------------------------------------------------------------------------------------------------------
# Classes
# ----------------------------------------------------------------------------------------------------------------------


class DeployGeneratorDialog(QDialog):
    """
    Perform See Anything (YOLOE) on multiple images using a reference image and label.

    :param main_window: MainWindow object
    :param parent: Parent widget
    """
    def __init__(self, main_window, parent=None):
        super().__init__(parent)
        self.main_window = main_window
        self.label_window = main_window.label_window
        self.image_window = main_window.image_window
        self.annotation_window = main_window.annotation_window
        self.sam_dialog = None

        self.setWindowIcon(get_icon("eye.png"))
        self.setWindowTitle("See Anything (YOLOE) Generator (Ctrl + 5)")
        self.resize(800, 800)  # Increased size to accommodate the horizontal layout

        self.deploy_model_dialog = None
        self.loaded_model = None
        self.last_selected_label_code = None

        # Initialize variables
        self.imgsz = 1024
        self.iou_thresh = 0.20
        self.uncertainty_thresh = 0.30
        self.area_thresh_min = 0.00
        self.area_thresh_max = 0.40

        self.task = 'detect'
        self.max_detect = 300
        self.loaded_model = None
        self.model_path = None
        self.class_mapping = {}

        # Reference image and label
        self.reference_label = None
        self.reference_image_paths = []

        # Visual Prompting Encoding (VPE) - legacy single tensor variable
        self.vpe_path = None
        self.vpe = None

        # New separate VPE collections
        self.imported_vpes = []  # VPEs loaded from file
        self.reference_vpes = []  # VPEs created from reference images

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Main vertical layout for the dialog
        self.layout = QVBoxLayout(self)

        # Setup the info layout at the top
        self.setup_info_layout()

        # Create horizontal layout for the two panels
        self.horizontal_layout = QHBoxLayout()
        self.layout.addLayout(self.horizontal_layout)

        # Create left panel
        self.left_panel = QVBoxLayout()
        self.horizontal_layout.addLayout(self.left_panel)

        # Create right panel
        self.right_panel = QVBoxLayout()
        self.horizontal_layout.addLayout(self.right_panel)

        # Add layouts to the left panel
        self.setup_models_layout()
        self.setup_parameters_layout()
        self.setup_sam_layout()
        self.setup_model_buttons_layout()
        self.setup_status_layout()

        # Add layouts to the right panel
        self.setup_reference_layout()

        # Add a full ImageWindow instance for target image selection
        self.image_selection_window = ImageWindow(self.main_window)
        self.right_panel.addWidget(self.image_selection_window)

        # Setup the buttons layout at the bottom
        self.setup_buttons_layout()

    def configure_image_window_for_dialog(self):
        """
        Disables parts of the internal ImageWindow UI to guide user selection.
        This forces the image list to only show images with annotations
        matching the selected reference label.
        """
        iw = self.image_selection_window

        # Block signals to prevent setChecked from triggering the ImageWindow's
        # own filtering logic. We want to be in complete control.
        iw.highlighted_checkbox.blockSignals(True)
        iw.has_predictions_checkbox.blockSignals(True)
        iw.no_annotations_checkbox.blockSignals(True)
        iw.has_annotations_checkbox.blockSignals(True)

        # Disable and set filter checkboxes
        iw.highlighted_checkbox.setEnabled(False)
        iw.has_predictions_checkbox.setEnabled(False)
        iw.no_annotations_checkbox.setEnabled(False)
        iw.has_annotations_checkbox.setEnabled(False)

        iw.highlighted_checkbox.setChecked(False)
        iw.has_predictions_checkbox.setChecked(False)
        iw.no_annotations_checkbox.setChecked(False)
        iw.has_annotations_checkbox.setChecked(True)  # This will no longer trigger a filter

        # Unblock signals now that we're done.
        iw.highlighted_checkbox.blockSignals(False)
        iw.has_predictions_checkbox.blockSignals(False)
        iw.no_annotations_checkbox.blockSignals(False)
        iw.has_annotations_checkbox.blockSignals(False)

        # Disable search UI elements
        iw.home_button.setEnabled(False)
        iw.image_search_button.setEnabled(False)
        iw.label_search_button.setEnabled(False)
        iw.search_bar_images.setEnabled(False)
        iw.search_bar_labels.setEnabled(False)
        iw.top_k_combo.setEnabled(False)

        # Hide the "Current" label as it is not applicable in this dialog
        iw.current_image_index_label.hide()

        # Set Top-K to Top1
        iw.top_k_combo.setCurrentText("Top1")

        # Disconnect the double-click signal to prevent it from loading an image
        # in the main window, as this dialog is for selection only.
        try:
            iw.tableView.doubleClicked.disconnect()
        except TypeError:
            pass

        # CRITICAL: Override the load_first_filtered_image method to prevent auto-loading
        # This is the key fix to prevent unwanted load_image_by_path calls
        iw.load_first_filtered_image = lambda: None

    def showEvent(self, event):
        """
        Set up the layout when the dialog is shown.

        :param event: Show event
        """
        super().showEvent(event)
        self.initialize_uncertainty_threshold()
        self.initialize_iou_threshold()
        self.initialize_area_threshold()

        # Configure the image window's UI elements for this specific dialog
        self.configure_image_window_for_dialog()
        # Sync with main window's images BEFORE updating labels
        self.sync_image_window()
        # This now populates the dropdown, restores the last selection,
        # and then manually triggers the image filtering.
        self.update_reference_labels()

    def sync_image_window(self):
        """
        Syncs by directly adopting the main manager's up-to-date raster objects,
        avoiding redundant and slow re-calculation of annotation info.
        """
        main_manager = self.main_window.image_window.raster_manager
        dialog_manager = self.image_selection_window.raster_manager

        # Since the main_manager's rasters are always up-to-date, we can
        # simply replace the dialog's raster dictionary and path list entirely.
        # This is a shallow copy of the dictionary, which is extremely fast.
        # The Raster objects themselves are not copied, just referenced.
        dialog_manager.rasters = main_manager.rasters.copy()

        # Update the path list to match the new dictionary of rasters.
        dialog_manager.image_paths = list(dialog_manager.rasters.keys())

        # The slow 'for' loop that called update_annotation_info is now gone.
        # We are trusting that each raster object from the main_manager
        # already has its .label_set and .annotation_type_set correctly populated.

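    # Illustration of the shallow-copy semantics relied on above (hypothetical
    # paths, not from the package): dict.copy() duplicates only the mapping, so
    # both managers keep referencing the same Raster instances.
    #   src = {"reef_01.png": raster}
    #   dst = src.copy()
    #   assert dst["reef_01.png"] is src["reef_01.png"]  # same object, no re-compute
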
    def filter_images_by_label_and_type(self):
        """
        Filters the image list to show only images that contain at least one
        annotation that has BOTH the selected label AND a valid type (Polygon or Rectangle).
        This uses the fast, pre-computed cache for performance.
        """
        # Persist the user's current highlights from the table model before filtering.
        # This ensures that if the user highlights items and then changes the filter,
        # their selection is not lost.
        current_highlights = self.image_selection_window.table_model.get_highlighted_paths()
        if current_highlights:
            self.reference_image_paths = current_highlights

        reference_label = self.reference_label_combo_box.currentData()
        reference_label_text = self.reference_label_combo_box.currentText()

        # Store the last selected label for a better user experience on re-opening.
        if reference_label_text:
            self.last_selected_label_code = reference_label_text
            # Also store the reference label object itself
            self.reference_label = reference_label

        if not reference_label:
            # If no label is selected (e.g., during initialization), show an empty list.
            self.image_selection_window.table_model.set_filtered_paths([])
            return

        all_paths = self.image_selection_window.raster_manager.image_paths
        final_filtered_paths = []

        valid_types = {"RectangleAnnotation", "PolygonAnnotation"}
        selected_label_code = reference_label.short_label_code

        # Loop through paths and check the pre-computed map on each raster
        for path in all_paths:
            raster = self.image_selection_window.raster_manager.get_raster(path)
            if not raster:
                continue

            # 1. From the cache, get the set of annotation types specifically for our selected label.
            #    Use .get() to safely return an empty set if the label isn't on this image at all.
            types_for_this_label = raster.label_to_types_map.get(selected_label_code, set())

            # 2. Check for any overlap between the types found FOR THIS LABEL and the
            #    valid types we need (Polygon/Rectangle). This is the key check.
            if not valid_types.isdisjoint(types_for_this_label):
                # This image is a valid reference because the selected label exists
                # on a Polygon or Rectangle. Add it to the list.
                final_filtered_paths.append(path)

        # Directly set the filtered list in the table model.
        self.image_selection_window.table_model.set_filtered_paths(final_filtered_paths)

        # Try to preserve any previous selections
        if hasattr(self, 'reference_image_paths') and self.reference_image_paths:
            # Find which of our previously selected paths are still in the filtered list
            valid_selections = [p for p in self.reference_image_paths if p in final_filtered_paths]
            if valid_selections:
                # Highlight previously selected paths that are still valid
                self.image_selection_window.table_model.set_highlighted_paths(valid_selections)

        # After filtering, update all labels with the correct counts.
        dialog_iw = self.image_selection_window
        dialog_iw.update_image_count_label(len(final_filtered_paths))  # Set "Total" to filtered count
        dialog_iw.update_current_image_index_label()
        dialog_iw.update_highlighted_count_label()

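    # For illustration, the pre-computed cache consulted above presumably maps
    # short label codes to the annotation types present on the image, e.g.:
    #   raster.label_to_types_map == {"Coral": {"PolygonAnnotation"},
    #                                 "Sand": {"PatchAnnotation"}}
    # so the isdisjoint() check is a per-image set-overlap test, not a scan of
    # every annotation. (Label names here are hypothetical.)
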
    def accept(self):
        """
        Validate selections and store them before closing the dialog.
        A prediction is valid if a model and label are selected, and the user
        has provided either reference images or an imported VPE file.
        """
        if not self.loaded_model:
            QMessageBox.warning(self,
                                "No Model",
                                "A model must be loaded before running predictions.")
            return

        # Set reference label from combo box
        self.reference_label = self.reference_label_combo_box.currentData()
        if not self.reference_label:
            QMessageBox.warning(self,
                                "No Reference Label",
                                "A reference label must be selected.")
            return

        # Stash the current UI selection before validating.
        self.update_stashed_references_from_ui()

        # Check for a valid VPE source using the now-stashed list.
        has_reference_images = bool(self.reference_image_paths)
        has_imported_vpes = bool(self.imported_vpes)

        if not has_reference_images and not has_imported_vpes:
            QMessageBox.warning(self,
                                "No VPE Source Provided",
                                "You must highlight at least one reference image or load a VPE file to proceed.")
            return

        # If validation passes, close the dialog.
        super().accept()

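    # Note: accept() is QDialog's built-in OK handler; overriding it lets the
    # dialog veto closing (by returning early) until validation passes, and
    # only then call super().accept() to close with the Accepted result code.
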
    def setup_info_layout(self):
        """
        Set up the layout and widgets for the info layout that spans the top.
        """
        group_box = QGroupBox("Information")
        layout = QVBoxLayout()

        # Create a QLabel with explanatory text and hyperlink
        info_label = QLabel("Choose a Generator to deploy. "
                            "Select a reference label, then highlight reference images that contain examples. "
                            "Each additional reference image may increase accuracy but also processing time.")

        info_label.setOpenExternalLinks(True)
        info_label.setWordWrap(True)
        layout.addWidget(info_label)

        group_box.setLayout(layout)
        self.layout.addWidget(group_box)  # Add to main layout so it spans both panels

    def setup_models_layout(self):
        """
        Setup the models layout with a simple model selection combo box (no tabs).
        """
        group_box = QGroupBox("Model Selection")
        layout = QFormLayout()

        self.model_combo = QComboBox()
        self.model_combo.setEditable(True)

        # Define available models (keep the existing dictionary)
        self.models = [
            'yoloe-v8s-seg.pt',
            'yoloe-v8m-seg.pt',
            'yoloe-v8l-seg.pt',
            'yoloe-11s-seg.pt',
            'yoloe-11m-seg.pt',
            'yoloe-11l-seg.pt',
        ]

        # Add all models to combo box
        for model_name in self.models:
            self.model_combo.addItem(model_name)

        # Set the default model
        self.model_combo.setCurrentIndex(self.models.index('yoloe-v8s-seg.pt'))
        # Create a layout for the model selection
        layout.addRow(QLabel("Models:"), self.model_combo)

        # Add custom vpe file selection
        self.vpe_path_edit = QLineEdit()
        browse_button = QPushButton("Browse...")
        browse_button.clicked.connect(self.browse_vpe_file)

        vpe_path_layout = QHBoxLayout()
        vpe_path_layout.addWidget(self.vpe_path_edit)
        vpe_path_layout.addWidget(browse_button)
        layout.addRow("Custom VPE:", vpe_path_layout)

        group_box.setLayout(layout)
        self.left_panel.addWidget(group_box)  # Add to left panel

    def setup_parameters_layout(self):
        """
        Setup parameter control section in a group box.
        """
        group_box = QGroupBox("Parameters")
        layout = QFormLayout()

        # Task dropdown
        self.use_task_dropdown = QComboBox()
        self.use_task_dropdown.addItems(["detect", "segment"])
        self.use_task_dropdown.currentIndexChanged.connect(self.update_task)
        layout.addRow("Task:", self.use_task_dropdown)

        # Max detections spinbox
        self.max_detections_spinbox = QSpinBox()
        self.max_detections_spinbox.setRange(1, 10000)
        self.max_detections_spinbox.setValue(self.max_detect)
        layout.addRow("Max Detections:", self.max_detections_spinbox)

        # Resize image dropdown
        self.resize_image_dropdown = QComboBox()
        self.resize_image_dropdown.addItems(["True", "False"])
        self.resize_image_dropdown.setCurrentIndex(0)
        self.resize_image_dropdown.setEnabled(False)  # Grey out the dropdown
        layout.addRow("Resize Image:", self.resize_image_dropdown)

        # Image size control
        self.imgsz_spinbox = QSpinBox()
        self.imgsz_spinbox.setRange(512, 65536)
        self.imgsz_spinbox.setSingleStep(1024)
        self.imgsz_spinbox.setValue(self.imgsz)
        layout.addRow("Image Size (imgsz):", self.imgsz_spinbox)

        # Uncertainty threshold controls
        self.uncertainty_thresh = self.main_window.get_uncertainty_thresh()
        self.uncertainty_threshold_slider = QSlider(Qt.Horizontal)
        self.uncertainty_threshold_slider.setRange(0, 100)
        self.uncertainty_threshold_slider.setValue(int(self.main_window.get_uncertainty_thresh() * 100))
        self.uncertainty_threshold_slider.setTickPosition(QSlider.TicksBelow)
        self.uncertainty_threshold_slider.setTickInterval(10)
        self.uncertainty_threshold_slider.valueChanged.connect(self.update_uncertainty_label)
        self.uncertainty_threshold_label = QLabel(f"{self.uncertainty_thresh:.2f}")
        layout.addRow("Uncertainty Threshold", self.uncertainty_threshold_slider)
        layout.addRow("", self.uncertainty_threshold_label)

        # IoU threshold controls
        self.iou_thresh = self.main_window.get_iou_thresh()
        self.iou_threshold_slider = QSlider(Qt.Horizontal)
        self.iou_threshold_slider.setRange(0, 100)
        self.iou_threshold_slider.setValue(int(self.iou_thresh * 100))
        self.iou_threshold_slider.setTickPosition(QSlider.TicksBelow)
        self.iou_threshold_slider.setTickInterval(10)
        self.iou_threshold_slider.valueChanged.connect(self.update_iou_label)
        self.iou_threshold_label = QLabel(f"{self.iou_thresh:.2f}")
        layout.addRow("IoU Threshold", self.iou_threshold_slider)
        layout.addRow("", self.iou_threshold_label)

        # Area threshold controls
        min_val, max_val = self.main_window.get_area_thresh()
        self.area_thresh_min = int(min_val * 100)
        self.area_thresh_max = int(max_val * 100)
        self.area_threshold_min_slider = QSlider(Qt.Horizontal)
        self.area_threshold_min_slider.setRange(0, 100)
        self.area_threshold_min_slider.setValue(self.area_thresh_min)
        self.area_threshold_min_slider.setTickPosition(QSlider.TicksBelow)
        self.area_threshold_min_slider.setTickInterval(10)
        self.area_threshold_min_slider.valueChanged.connect(self.update_area_label)
        self.area_threshold_max_slider = QSlider(Qt.Horizontal)
        self.area_threshold_max_slider.setRange(0, 100)
        self.area_threshold_max_slider.setValue(self.area_thresh_max)
        self.area_threshold_max_slider.setTickPosition(QSlider.TicksBelow)
        self.area_threshold_max_slider.setTickInterval(10)
        self.area_threshold_max_slider.valueChanged.connect(self.update_area_label)
        self.area_threshold_label = QLabel(f"{self.area_thresh_min / 100.0:.2f} - {self.area_thresh_max / 100.0:.2f}")
        layout.addRow("Area Threshold Min", self.area_threshold_min_slider)
        layout.addRow("Area Threshold Max", self.area_threshold_max_slider)
        layout.addRow("", self.area_threshold_label)

        group_box.setLayout(layout)
        self.left_panel.addWidget(group_box)  # Add to left panel

    def setup_sam_layout(self):
        """Use SAM model for segmentation."""
        group_box = QGroupBox("Use SAM Model for Creating Polygons")
        layout = QFormLayout()

        # SAM dropdown
        self.use_sam_dropdown = QComboBox()
        self.use_sam_dropdown.addItems(["False", "True"])
        self.use_sam_dropdown.currentIndexChanged.connect(self.is_sam_model_deployed)
        layout.addRow("Use SAM Polygons:", self.use_sam_dropdown)

        group_box.setLayout(layout)
        self.left_panel.addWidget(group_box)  # Add to left panel

    def setup_model_buttons_layout(self):
        """
        Setup action buttons in a group box.
        """
        group_box = QGroupBox("Actions")
        main_layout = QVBoxLayout()

        # First row: Load and Deactivate buttons side by side
        button_row = QHBoxLayout()
        load_button = QPushButton("Load Model")
        load_button.clicked.connect(self.load_model)
        button_row.addWidget(load_button)

        deactivate_button = QPushButton("Deactivate Model")
        deactivate_button.clicked.connect(self.deactivate_model)
        button_row.addWidget(deactivate_button)

        main_layout.addLayout(button_row)

        # Second row: Save VPE button + Show VPE button side by side
        vpe_row = QHBoxLayout()
        save_vpe_button = QPushButton("Save VPE")
        save_vpe_button.clicked.connect(self.save_vpe)
        vpe_row.addWidget(save_vpe_button)

        show_vpe_button = QPushButton("Show VPE")
        show_vpe_button.clicked.connect(self.show_vpe)
        vpe_row.addWidget(show_vpe_button)

        main_layout.addLayout(vpe_row)

        group_box.setLayout(main_layout)
        self.left_panel.addWidget(group_box)  # Add to left panel

    def setup_status_layout(self):
        """
        Setup status display in a group box.
        """
        group_box = QGroupBox("Status")
        layout = QVBoxLayout()

        self.status_bar = QLabel("No model loaded")
        layout.addWidget(self.status_bar)

        group_box.setLayout(layout)
        self.left_panel.addWidget(group_box)  # Add to left panel

    def setup_reference_layout(self):
        """
        Set up the layout with reference label selection.
        The reference image is implicitly the currently active image.
        """
        group_box = QGroupBox("Reference")
        layout = QFormLayout()

        # Create the reference label combo box
        self.reference_label_combo_box = QComboBox()
        self.reference_label_combo_box.currentIndexChanged.connect(self.filter_images_by_label_and_type)
        layout.addRow("Reference Label:", self.reference_label_combo_box)

        # Create a Reference model combobox (VPE, Images)
        self.reference_method_combo_box = QComboBox()
        self.reference_method_combo_box.addItems(["VPE", "Images"])
        layout.addRow("Reference Method:", self.reference_method_combo_box)

        group_box.setLayout(layout)
        self.right_panel.addWidget(group_box)  # Add to right panel

    def setup_buttons_layout(self):
        """
        Set up the layout with buttons.
        """
        # Create a button box for the buttons
        button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        button_box.accepted.connect(self.accept)
        button_box.rejected.connect(self.reject)

        self.layout.addWidget(button_box)

    def initialize_uncertainty_threshold(self):
        """Initialize the uncertainty threshold slider with the current value"""
        current_value = self.main_window.get_uncertainty_thresh()
        self.uncertainty_threshold_slider.setValue(int(current_value * 100))
        self.uncertainty_thresh = current_value

    def initialize_iou_threshold(self):
        """Initialize the IOU threshold slider with the current value"""
        current_value = self.main_window.get_iou_thresh()
        self.iou_threshold_slider.setValue(int(current_value * 100))
        self.iou_thresh = current_value

    def initialize_area_threshold(self):
        """Initialize the area threshold range slider"""
        current_min, current_max = self.main_window.get_area_thresh()
        self.area_threshold_min_slider.setValue(int(current_min * 100))
        self.area_threshold_max_slider.setValue(int(current_max * 100))
        self.area_thresh_min = current_min
        self.area_thresh_max = current_max

    def update_uncertainty_label(self, value):
        """Update uncertainty threshold and label"""
        value = value / 100.0
        self.uncertainty_thresh = value
        self.main_window.update_uncertainty_thresh(value)
        self.uncertainty_threshold_label.setText(f"{value:.2f}")

    def update_iou_label(self, value):
        """Update IoU threshold and label"""
        value = value / 100.0
        self.iou_thresh = value
        self.main_window.update_iou_thresh(value)
        self.iou_threshold_label.setText(f"{value:.2f}")

    def update_area_label(self):
        """Handle changes to area threshold range slider"""
        min_val = self.area_threshold_min_slider.value()
        max_val = self.area_threshold_max_slider.value()
        if min_val > max_val:
            min_val = max_val
            self.area_threshold_min_slider.setValue(min_val)
        self.area_thresh_min = min_val / 100.0
        self.area_thresh_max = max_val / 100.0
        self.main_window.update_area_thresh(self.area_thresh_min, self.area_thresh_max)
        self.area_threshold_label.setText(f"{self.area_thresh_min:.2f} - {self.area_thresh_max:.2f}")

    def update_stashed_references_from_ui(self):
        """Updates the internal reference path list from the current UI selection."""
        self.reference_image_paths = self.image_selection_window.table_model.get_highlighted_paths()

    def get_max_detections(self):
        """Get the maximum number of detections to return."""
        self.max_detect = self.max_detections_spinbox.value()
        return self.max_detect

    def is_sam_model_deployed(self):
        """
        Check if the SAM model is deployed and update the checkbox state accordingly.

        :return: Boolean indicating whether the SAM model is deployed
        """
        if not hasattr(self.main_window, 'sam_deploy_predictor_dialog'):
            return False

        self.sam_dialog = self.main_window.sam_deploy_predictor_dialog

        if not self.sam_dialog.loaded_model:
            self.use_sam_dropdown.setCurrentText("False")
            QMessageBox.critical(self, "Error", "Please deploy the SAM model first.")
            return False

        return True

    def update_sam_task_state(self):
        """
        Centralized method to check if SAM is loaded and update task accordingly.
        If the user has selected to use SAM, this function ensures the task is set to 'segment'.
        Crucially, it does NOT alter the task if SAM is not selected, respecting the
        user's choice from the 'Task' dropdown.
        """
        # Check if the user wants to use the SAM model
        if self.use_sam_dropdown.currentText() == "True":
            # SAM is requested. Check if it's actually available.
            sam_is_available = (
                hasattr(self, 'sam_dialog') and
                self.sam_dialog is not None and
                self.sam_dialog.loaded_model is not None
            )

            if sam_is_available:
                # If SAM is wanted and available, the task must be segmentation.
                self.task = 'segment'
            else:
                # If SAM is wanted but not available, revert the dropdown and do nothing else.
                # The 'is_sam_model_deployed' function already handles showing an error message.
                self.use_sam_dropdown.setCurrentText("False")

        # If use_sam_dropdown is "False", do nothing. Let self.task be whatever the user set.

    def update_task(self):
        """Update the task based on the dropdown selection and handle UI/model effects."""
        self.task = self.use_task_dropdown.currentText()

        # Update UI elements based on task
        if self.task == "segment":
            # Deactivate model if one is loaded and we're switching to segment task
            if self.loaded_model:
                self.deactivate_model()

    def update_reference_labels(self):
        """
        Updates the reference label combo box with ALL available project labels.
        This dropdown now serves as the "Output Label" for all predictions.
        The "Review" label with id "-1" is excluded.
        """
        self.reference_label_combo_box.blockSignals(True)

        try:
            self.reference_label_combo_box.clear()

            # Get all labels from the main label window
            all_project_labels = self.main_window.label_window.labels

            # Filter out the special "Review" label and create a list of valid labels
            valid_labels = [
                label_obj for label_obj in all_project_labels
                if not (label_obj.short_label_code == "Review" and str(label_obj.id) == "-1")
            ]

            # Add the valid labels to the combo box, sorted alphabetically.
            sorted_valid_labels = sorted(valid_labels, key=lambda x: x.short_label_code)
            for label_obj in sorted_valid_labels:
                self.reference_label_combo_box.addItem(label_obj.short_label_code, label_obj)

            # Restore the last selected label if it's still present in the list.
            if self.last_selected_label_code:
                index = self.reference_label_combo_box.findText(self.last_selected_label_code)
                if index != -1:
                    self.reference_label_combo_box.setCurrentIndex(index)
        finally:
            self.reference_label_combo_box.blockSignals(False)

        # Manually trigger the image filtering now that the combo box is stable.
        # This will still filter the image list to help find references if needed.
        self.filter_images_by_label_and_type()

    def get_reference_annotations(self, reference_label, reference_image_path):
        """
        Return a list of bboxes and masks for a specific image
        belonging to the selected label.

        :param reference_label: The Label object to filter annotations by.
        :param reference_image_path: The path of the image to get annotations from.
        :return: A tuple containing a numpy array of bboxes and a list of masks.
        """
        if not all([reference_label, reference_image_path]):
            return np.array([]), []

        # Get all annotations for the specified image
        annotations = self.annotation_window.get_image_annotations(reference_image_path)

        # Filter annotations by the provided label
        reference_bboxes = []
        reference_masks = []
        for annotation in annotations:
            if annotation.label.short_label_code == reference_label.short_label_code:
                if isinstance(annotation, (PolygonAnnotation, RectangleAnnotation)):
                    bbox = annotation.cropped_bbox
                    reference_bboxes.append(bbox)
                    if isinstance(annotation, PolygonAnnotation):
                        points = np.array([[p.x(), p.y()] for p in annotation.points])
                        reference_masks.append(points)
                    elif isinstance(annotation, RectangleAnnotation):
                        x1, y1, x2, y2 = bbox
                        rect_points = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
                        reference_masks.append(rect_points)

        return np.array(reference_bboxes), reference_masks

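    # Shape notes (inferred from the code above, not asserted by the package):
    # the returned bboxes form an (N, 4) array of xyxy coordinates, and each
    # mask is an (M, 2) array of polygon vertices; rectangles are converted to
    # their four corner points so both annotation types share one mask format.
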
    def browse_vpe_file(self):
        """
        Open a file dialog to browse for a VPE file and load it.
        Stores imported VPEs separately from reference-generated VPEs.
        """
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Select Visual Prompt Encoding (VPE) File",
            "",
            "VPE Files (*.pt);;All Files (*)"
        )

        if not file_path:
            return

        self.vpe_path_edit.setText(file_path)
        self.vpe_path = file_path

        try:
            # Load the VPE file
            loaded_data = torch.load(file_path)

            # TODO Move tensors to the appropriate device
            # device = self.main_window.device

            # Check format type and handle appropriately
            if isinstance(loaded_data, list):
                # New format: list of VPE tensors
                self.imported_vpes = [vpe.to(self.device) for vpe in loaded_data]
                vpe_count = len(self.imported_vpes)
                self.status_bar.setText(f"Loaded {vpe_count} VPE tensors from file")

            elif isinstance(loaded_data, torch.Tensor):
                # Legacy format: single tensor - convert to list for consistency
                loaded_vpe = loaded_data.to(self.device)
                # Store as a single-item list
                self.imported_vpes = [loaded_vpe]
                self.status_bar.setText("Loaded 1 VPE tensor from file (legacy format)")

            else:
                # Invalid format
                self.imported_vpes = []
                self.status_bar.setText("Invalid VPE file format")
                QMessageBox.warning(
                    self,
                    "Invalid VPE",
                    "The file does not appear to be a valid VPE format."
                )
                # Clear the VPE path edit field
                self.vpe_path_edit.clear()

            # For backward compatibility - set self.vpe to the average of imported VPEs
            # This ensures older code paths still work
            if self.imported_vpes:
                combined_vpe = torch.cat(self.imported_vpes).mean(dim=0, keepdim=True)
                self.vpe = torch.nn.functional.normalize(combined_vpe, p=2, dim=-1)

        except Exception as e:
            self.imported_vpes = []
            self.vpe = None
            self.status_bar.setText(f"Error loading VPE: {str(e)}")
            QMessageBox.critical(
                self,
                "Error Loading VPE",
                f"Failed to load VPE file: {str(e)}"
            )

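    # On-disk format, as implied above: a VPE file is an ordinary torch.save()
    # payload holding either a list of embedding tensors (current format) or a
    # single tensor (legacy). A minimal reader outside the GUI could look like:
    #   data = torch.load("prompts.pt", map_location="cpu")  # hypothetical path
    #   vpes = data if isinstance(data, list) else [data]
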
    def save_vpe(self):
        """
        Save the combined collection of VPEs (imported and reference-generated) to disk.
        """
        # Always sync with the live UI selection before saving.
        self.update_stashed_references_from_ui()

        # Create a list to hold all VPEs
        all_vpes = []

        # Add imported VPEs if available
        if self.imported_vpes:
            all_vpes.extend(self.imported_vpes)

        # Check if we should generate new VPEs from reference images
        references_dict = self._get_references()
        if references_dict:
            # Reload the model to ensure clean state
            self.reload_model()

            # Convert references to VPEs without updating self.reference_vpes yet
            new_vpes = self.references_to_vpe(references_dict, update_reference_vpes=False)

            if new_vpes:
                # Add new VPEs to collection
                all_vpes.extend(new_vpes)
                # Update reference_vpes with the new ones
                self.reference_vpes = new_vpes
        else:
            # Include existing reference VPEs if we have them
            if self.reference_vpes:
                all_vpes.extend(self.reference_vpes)

        # Check if we have any VPEs to save
        if not all_vpes:
            QMessageBox.warning(
                self,
                "No VPEs Available",
                "No VPEs available to save. Please either load a VPE file or select reference images."
            )
            return

        # Open file dialog to select save location
        file_path, _ = QFileDialog.getSaveFileName(
            self,
            "Save VPE Collection",
            "",
            "PyTorch Tensor (*.pt);;All Files (*)"
        )

        if not file_path:
            return  # User canceled the dialog

        # Add .pt extension if not present
        if not file_path.endswith('.pt'):
            file_path += '.pt'

        try:
            # Move tensors to CPU before saving
            vpe_list_cpu = [vpe.cpu() for vpe in all_vpes]

            # Save the list of tensors
            torch.save(vpe_list_cpu, file_path)

            self.status_bar.setText(f"Saved {len(all_vpes)} VPE tensors to {os.path.basename(file_path)}")
            QMessageBox.information(
                self,
                "VPE Saved",
                f"Saved {len(all_vpes)} VPE tensors to {file_path}"
            )
        except Exception as e:
            QMessageBox.critical(
                self,
                "Error Saving VPE",
                f"Failed to save VPE: {str(e)}"
            )

    def load_model(self):
        """
        Load the selected model.
        """
        QApplication.setOverrideCursor(Qt.WaitCursor)
        progress_bar = ProgressBar(self.annotation_window, title="Loading Model")
        progress_bar.show()

        try:
            # Load the model using reload_model method
            self.reload_model()

            # Calculate total number of VPEs from both sources
            total_vpes = len(self.imported_vpes) + len(self.reference_vpes)

            if total_vpes > 0:
                if self.imported_vpes and self.reference_vpes:
                    message = f"Model loaded with {len(self.imported_vpes)} imported VPEs "
                    message += f"and {len(self.reference_vpes)} reference VPEs"
                elif self.imported_vpes:
                    message = f"Model loaded with {len(self.imported_vpes)} imported VPEs"
                else:
                    message = f"Model loaded with {len(self.reference_vpes)} reference VPEs"

                self.status_bar.setText(message)
            else:
                message = "Model loaded with default VPE"
                self.status_bar.setText("Model loaded with default VPE")

            # Finish progress bar
            progress_bar.finish_progress()
            QMessageBox.information(self.annotation_window, "Model Loaded", message)

        except Exception as e:
            self.loaded_model = None
            QMessageBox.critical(self.annotation_window,
                                 "Error Loading Model",
                                 f"Error loading model: {e}")

        finally:
            # Restore cursor
            QApplication.restoreOverrideCursor()
            # Stop the progress bar
            progress_bar.stop_progress()
            progress_bar.close()
            progress_bar = None

    def reload_model(self):
        """
        Subset of the load_model method. This is needed when additional
        reference images and annotations (i.e., VPEs) are added (we have
        to re-load the model each time).

        This method also ensures that we stash the currently highlighted reference
        image paths before reloading, so they're available for predictions
        even if the user switches the active image in the main window.
        """
        self.loaded_model = None

        # Get selected model path and download weights if needed
        self.model_path = self.model_combo.currentText()

        # Load model using registry
        self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device)  # TODO

        # Create a dummy visual dictionary for standard model loading
        visual_prompts = dict(
            bboxes=np.array(
                [
                    [120, 425, 160, 445],  # Random box
                ],
            ),
            cls=np.array(
                np.zeros(1),
            ),
        )

        # Run a dummy prediction to load the model
        self.loaded_model.predict(
            np.zeros((640, 640, 3), dtype=np.uint8),
            visual_prompts=visual_prompts.copy(),  # This needs to happen to properly initialize the predictor
            predictor=YOLOEVPSegPredictor,  # This also needs to be SegPredictor, no matter what
            imgsz=640,
            conf=0.99,
        )

        # If a VPE file was loaded, use it with the model after the dummy prediction
        if self.vpe is not None and isinstance(self.vpe, torch.Tensor):
            # Directly set the final tensor as the prompt for the predictor
            self.loaded_model.is_fused = lambda: False
            self.loaded_model.set_classes(["object0"], self.vpe)

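    # Note on the monkey-patch above: is_fused is replaced with a lambda that
    # always returns False, presumably so set_classes() will accept the custom
    # VPE tensor on a model whose layers have already been fused; "object0" is
    # just the placeholder class name paired with the single embedding.
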
    def predict(self, image_paths=None):
        """
        Make predictions on the given image paths using the loaded model.

        Args:
            image_paths: List of image paths to process. If None, uses the current image.
        """
        if not self.loaded_model or not self.reference_label:
            return

        # Update class mapping with the selected reference label
        self.class_mapping = {0: self.reference_label}

        # Create a results processor
        results_processor = ResultsProcessor(
            self.main_window,
            self.class_mapping,
            uncertainty_thresh=self.main_window.get_uncertainty_thresh(),
            iou_thresh=self.main_window.get_iou_thresh(),
            min_area_thresh=self.main_window.get_area_thresh_min(),
            max_area_thresh=self.main_window.get_area_thresh_max()
        )

        if not image_paths:
            # Predict only the current image
            image_paths = [self.annotation_window.current_image_path]

        # Make cursor busy
        QApplication.setOverrideCursor(Qt.WaitCursor)

        # Start the progress bar
        progress_bar = ProgressBar(self.annotation_window, title="Prediction Workflow")
        progress_bar.show()
        progress_bar.start_progress(len(image_paths))

        try:
            for image_path in image_paths:
                inputs = self._get_inputs(image_path)
                if inputs is None:
                    continue

                results = self._apply_model(inputs)
                results = self._apply_sam(results, image_path)
                self._process_results(results_processor, results, image_path)

                # Update the progress bar
                progress_bar.update_progress()

        except Exception as e:
            print("An error occurred during prediction:", e)
        finally:
            QApplication.restoreOverrideCursor()
            progress_bar.finish_progress()
            progress_bar.stop_progress()
            progress_bar.close()

        gc.collect()
        empty_cache()

    def _get_inputs(self, image_path):
        """Get the inputs for the model prediction."""
        raster = self.image_window.raster_manager.get_raster(image_path)
        if self.annotation_window.get_selected_tool() != "work_area":
            # Use the image path
            work_areas_data = [raster.image_path]
        else:
            # Get the work areas
            work_areas_data = raster.get_work_areas_data()

        return work_areas_data

    def _get_references(self):
        """
        Get the reference annotations using the stashed list of reference images
        that was saved when the user accepted the dialog.

        Returns:
            dict: Dictionary mapping image paths to annotation data, or empty dict if no valid references.
        """
        # Use the "stashed" list of paths. Do NOT query the table_model again,
        # as the UI's highlight state may have been cleared by other actions.
        reference_paths = self.reference_image_paths

        if not reference_paths:
            print("No reference image paths were stashed to use for prediction.")
            return {}

        # Get the reference label that was also stashed
        reference_label = self.reference_label
        if not reference_label:
            # This check is a safeguard; the accept() method should prevent this.
            print("No reference label was selected.")
            return {}

        # Create a dictionary of reference annotations from the stashed paths
        reference_annotations_dict = {}
        for path in reference_paths:
            bboxes, masks = self.get_reference_annotations(reference_label, path)
            if bboxes.size > 0:
                reference_annotations_dict[path] = {
                    'bboxes': bboxes,
                    'masks': masks,
                    'cls': np.zeros(len(bboxes))
                }

        return reference_annotations_dict

    def _apply_model_using_images(self, inputs, reference_dict):
        """
        Apply the model using the provided images and reference annotations (dict). This method
        loops through each reference image using its annotations; we then aggregate
        all the results together. Less efficient, but potentially more accurate.

        Args:
            inputs (list): List of input images.
            reference_dict (dict): Dictionary containing reference annotations for each image.

        Returns:
            list: List of prediction results.
        """
        # Create a progress bar for iterating through reference images
        QApplication.setOverrideCursor(Qt.WaitCursor)
        progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
        progress_bar.show()
        progress_bar.start_progress(len(reference_dict))

        results_list = []
        # The 'inputs' list contains work areas from the single target image.
        # We will predict on the first work area/full image.
        input_image = inputs[0]

        # Iterate through each reference image and its annotations
        for ref_path, ref_annotations in reference_dict.items():
            # The 'refer_image' parameter is the path to the current reference image
            # The 'visual_prompts' are the annotations from that same reference image
            visual_prompts = {
                'bboxes': ref_annotations['bboxes'],
                'cls': ref_annotations['cls'],
            }
            if self.task == 'segment':
                visual_prompts['masks'] = ref_annotations['masks']

            # Make predictions on the target using the current reference
            results = self.loaded_model.predict(input_image,
                                                refer_image=ref_path,
                                                visual_prompts=visual_prompts,
                                                predictor=YOLOEVPSegPredictor,  # TODO This is necessary here?
                                                imgsz=self.imgsz_spinbox.value(),
                                                conf=self.main_window.get_uncertainty_thresh(),
                                                iou=self.main_window.get_iou_thresh(),
                                                max_det=self.get_max_detections(),
                                                retina_masks=self.task == "segment")

            if not len(results[0].boxes):
                # If no boxes were detected, skip to the next reference
                progress_bar.update_progress()
                continue

            # Update the name of the results and append to the list
            results[0].names = {0: self.class_mapping[0].short_label_code}
            results_list.extend(results[0])

            progress_bar.update_progress()
            gc.collect()
            empty_cache()

        # Clean up
        QApplication.restoreOverrideCursor()
        progress_bar.finish_progress()
        progress_bar.stop_progress()
        progress_bar.close()

        # Combine results if there are any
        combined_results = CombineResults().combine_results(results_list)
        if combined_results is None:
            return []

        return [[combined_results]]

    def references_to_vpe(self, reference_dict, update_reference_vpes=True):
        """
        Converts the contents of a reference dictionary to VPEs (Visual Prompt Embeddings).
        Reference dictionaries contain information about the visual prompts for each reference image:
        dict[image_path]: {bboxes, masks, cls}

        Args:
            reference_dict (dict): The reference dictionary containing visual prompts for each image.
            update_reference_vpes (bool): Whether to update self.reference_vpes with the results.

        Returns:
            list: List of individual VPE tensors (normalized), or None if empty reference_dict
        """
        # Check if the reference dictionary is empty
        if not reference_dict:
            return None

        # Create a list to hold the individual VPE tensors
        vpe_list = []

        for ref_path, ref_annotations in reference_dict.items():
            # Set the prompts to the model predictor
            self.loaded_model.predictor.set_prompts(ref_annotations)

            # Get the VPE from the model
            vpe = self.loaded_model.predictor.get_vpe(ref_path)

            # Normalize individual VPE
            vpe_normalized = torch.nn.functional.normalize(vpe, p=2, dim=-1)
            vpe_list.append(vpe_normalized)

        # Check if we have any valid VPEs
        if not vpe_list:
            return None

        # Update the reference_vpes list if requested
        if update_reference_vpes:
            self.reference_vpes = vpe_list

        return vpe_list

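    # How the VPEs are merged downstream (a sketch; the (1, 1, D) shape is an
    # assumption, not stated in the diff): torch.cat(vpes) stacks N unit-norm
    # embeddings into (N, 1, D), .mean(dim=0, keepdim=True) averages them back
    # to (1, 1, D), and the result is re-normalized because a mean of unit
    # vectors generally has norm < 1.
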
+    def _apply_model_using_vpe(self, inputs, references_dict):
+        """
+        Apply the model to the inputs using combined VPEs from both imported files
+        and reference annotations.
+
+        Args:
+            inputs (list): List of input images.
+            references_dict (dict): Dictionary containing reference annotations for each image.
+
+        Returns:
+            list: List of prediction results.
+        """
+        # First reload the model to clear any cached data
+        self.reload_model()
+
+        # Initialize the combined_vpes list
+        combined_vpes = []
+
+        # Add imported VPEs if available
+        if self.imported_vpes:
+            combined_vpes.extend(self.imported_vpes)
+
+        # Convert reference images to VPEs if any exist
+        if references_dict:
+            # Only update reference_vpes when references_dict is not empty
+            reference_vpes = self.references_to_vpe(references_dict, update_reference_vpes=True)
+            if reference_vpes:
+                combined_vpes.extend(reference_vpes)
+        else:
+            # Otherwise fall back to any existing reference VPEs
+            if self.reference_vpes:
+                combined_vpes.extend(self.reference_vpes)
+
+        # Check if we have any VPEs to use
+        if not combined_vpes:
+            QMessageBox.warning(
+                self,
+                "No VPEs Available",
+                "No VPEs available for prediction. Please either load a VPE file or select reference images."
+            )
+            return []
+
+        # Average all the VPEs together to create a final VPE tensor
+        averaged_vpe = torch.cat(combined_vpes).mean(dim=0, keepdim=True)
+        final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
+
+        # For backward compatibility, update self.vpe
+        self.vpe = final_vpe
+
+        # Set the final VPE on the model
+        self.loaded_model.is_fused = lambda: False
+        self.loaded_model.set_classes(["object0"], final_vpe)
+
+        # Make predictions on the target using the averaged VPE
+        results = self.loaded_model.predict(inputs[0],
+                                            visual_prompts=[],
+                                            imgsz=self.imgsz_spinbox.value(),
+                                            conf=self.main_window.get_uncertainty_thresh(),
+                                            iou=self.main_window.get_iou_thresh(),
+                                            max_det=self.get_max_detections(),
+                                            retina_masks=self.task == "segment")
+
+        return [results]
+
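The combination step above is a concatenate-average-renormalize pipeline: the mean of several unit vectors is generally shorter than unit length, so the averaged tensor must be normalized again before it is handed to `set_classes`. A small sketch under the same shape assumption as before:

```python
import torch

# Three hypothetical normalized VPEs, each of shape (1, num_prompts, embed_dim).
vpes = [torch.nn.functional.normalize(torch.randn(1, 3, 512), p=2, dim=-1)
        for _ in range(3)]

# torch.cat stacks them along dim 0 -> (3, 3, 512); the mean over dim 0
# (keepdim=True) collapses them back into a single (1, 3, 512) tensor.
averaged_vpe = torch.cat(vpes).mean(dim=0, keepdim=True)

# Averaging unit vectors shrinks their length, so renormalize.
final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)

print(averaged_vpe.norm(p=2, dim=-1))  # typically < 1.0 per vector
print(final_vpe.norm(p=2, dim=-1))     # ~1.0 everywhere
```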
+    def _apply_model(self, inputs):
+        """
+        Apply the model to the target inputs. This method handles both image-based
+        references and VPE-based references.
+        """
+        # Update the model with user parameters
+        self.task = self.use_task_dropdown.currentText()
+
+        self.loaded_model.conf = self.main_window.get_uncertainty_thresh()
+        self.loaded_model.iou = self.main_window.get_iou_thresh()
+        self.loaded_model.max_det = self.get_max_detections()
+
+        # Get the reference information for the currently selected rows
+        references_dict = self._get_references()
+
+        # Check whether the user is using VPEs or Reference Images
+        if self.reference_method_combo_box.currentText() == "VPE":
+            # Check if we have any VPEs available (imported or reference-generated)
+            has_vpes = bool(self.imported_vpes or self.reference_vpes)
+
+            # Warn the user only if there are no VPEs and also no reference
+            # images from which VPEs could be generated
+            if not has_vpes and not references_dict:
+                QMessageBox.warning(
+                    self,
+                    "No VPEs Available",
+                    "No VPEs available for prediction. Please either load a VPE file or select reference images."
+                )
+                return []
+
+            # Use the VPE method, which combines imported and reference VPEs
+            results = self._apply_model_using_vpe(inputs, references_dict)
+        else:
+            # The Reference Images method requires reference images
+            if not references_dict:
+                QMessageBox.warning(
+                    self,
+                    "No References Selected",
+                    "No reference images with valid annotations were selected. "
+                    "Please select at least one reference image."
+                )
+                return []
+
+            results = self._apply_model_using_images(inputs, references_dict)
+
+        return results
+
+    def _apply_sam(self, results_list, image_path):
+        """Apply SAM to the results if needed."""
+        # Check if a SAM model is deployed and loaded
+        self.update_sam_task_state()
+        if self.task != 'segment':
+            return results_list
+
+        if not self.sam_dialog or self.use_sam_dropdown.currentText() == "False":
+            # If SAM is not deployed or not selected, return the results as-is
+            return results_list
+
+        if self.sam_dialog.loaded_model is None:
+            # If SAM is not loaded, ensure we do not use it accidentally
+            self.task = 'detect'
+            self.use_sam_dropdown.setCurrentText("False")
+            return results_list
+
+        # Make cursor busy
+        QApplication.setOverrideCursor(Qt.WaitCursor)
+        progress_bar = ProgressBar(self.annotation_window, title="Predicting with SAM")
+        progress_bar.show()
+        progress_bar.start_progress(len(results_list))
+
+        updated_results = []
+
+        for idx, results in enumerate(results_list):
+            # Each entry in results_list is itself a list of Results: [[Results], ...]
+            if results:
+                # Run the results through the SAM model
+                results = self.sam_dialog.predict_from_results(results, image_path)
+                updated_results.append(results)
+
+            # Update the progress bar
+            progress_bar.update_progress()
+
+        # Make cursor normal
+        QApplication.restoreOverrideCursor()
+        progress_bar.finish_progress()
+        progress_bar.stop_progress()
+        progress_bar.close()
+
+        return updated_results
+
+    def _process_results(self, results_processor, results_list, image_path):
+        """Process the results using the results processor."""
+        # Get the raster object and the number of work items
+        raster = self.image_window.raster_manager.get_raster(image_path)
+        total = raster.count_work_items()
+
+        # Get the work areas (if any)
+        work_areas = raster.get_work_areas()
+
+        # Start the progress bar
+        progress_bar = ProgressBar(self.annotation_window, title="Processing Results")
+        progress_bar.show()
+        progress_bar.start_progress(total)
+
+        updated_results = []
+
+        for idx, results in enumerate(results_list):
+            # Each entry in results_list is itself a list of Results: [[Results], ...]
+            if results:
+                # Update the path and names again here, in case SAM replaced the results
+                results[0].path = image_path
+                results[0].names = {0: self.class_mapping[0].short_label_code}
+
+                # Check whether a work area is in use, or the full image path
+                if work_areas and self.annotation_window.get_selected_tool() == "work_area":
+                    # Map results from the work area back to the full image
+                    results = MapResults().map_results_from_work_area(results[0],
+                                                                      raster,
+                                                                      work_areas[idx],
+                                                                      self.task == "segment")
+                else:
+                    results = results[0]
+
+                # Append the result object (not a list) to the updated results list
+                updated_results.append(results)
+
+            progress_bar.update_progress()
+
+        # Process the Results
+        if self.task == 'segment' or self.use_sam_dropdown.currentText() == "True":
+            results_processor.process_segmentation_results(updated_results)
+        else:
+            results_processor.process_detection_results(updated_results)
+
+        # Close the progress bar
+        progress_bar.finish_progress()
+        progress_bar.stop_progress()
+        progress_bar.close()
+
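`MapResults.map_results_from_work_area` is not part of this diff, but the mapping it performs is the standard crop-to-full-image translation: geometry predicted inside a work area is offset by the work area's top-left corner. A hedged sketch of the idea (the function, parameters, and box layout below are illustrative assumptions, not the toolbox's actual API):

```python
import numpy as np

def map_boxes_to_full_image(boxes_xyxy: np.ndarray, wa_x: float, wa_y: float) -> np.ndarray:
    """Translate (N, 4) xyxy boxes from work-area coordinates to full-image coordinates."""
    offset = np.array([wa_x, wa_y, wa_x, wa_y], dtype=boxes_xyxy.dtype)
    return boxes_xyxy + offset

# Example: a box detected at (10, 20)-(50, 60) inside a work area whose
# top-left corner sits at (100, 200) in the full image.
boxes = np.array([[10.0, 20.0, 50.0, 60.0]])
print(map_boxes_to_full_image(boxes, 100.0, 200.0))  # [[110. 220. 150. 260.]]
```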
+    def show_vpe(self):
+        """
+        Show a visualization of the VPEs using PyQtGraph.
+        This method always recalculates VPEs from the currently highlighted reference images.
+        """
+        # Set the cursor to busy while loading VPEs
+        QApplication.setOverrideCursor(Qt.WaitCursor)
+
+        try:
+            # Always sync with the live UI selection before visualizing
+            self.update_stashed_references_from_ui()
+
+            vpes_with_source = []
+
+            # 1. Add any VPEs that were loaded from a file
+            if self.imported_vpes:
+                for vpe in self.imported_vpes:
+                    vpes_with_source.append((vpe, "Import"))
+
+            # 2. Get the currently selected reference images from the stashed list
+            references_dict = self._get_references()
+
+            # 3. If there are reference images, calculate their VPEs and tag them with a source type
+            if references_dict:
+                self.reload_model()
+                new_reference_vpes = self.references_to_vpe(references_dict, update_reference_vpes=True)
+                if new_reference_vpes:
+                    for vpe in new_reference_vpes:
+                        vpes_with_source.append((vpe, "Reference"))
+
+            # 4. Check if there is anything to visualize
+            if not vpes_with_source:
+                QMessageBox.warning(
+                    self,
+                    "No VPEs Available",
+                    "No VPEs available to visualize. Please either load a VPE file or select reference images."
+                )
+                return
+
+            # 5. Compute the final (averaged) VPE and open the visualization dialog,
+            #    passing the list of (VPE, source) tuples
+            all_vpe_tensors = [vpe for vpe, source in vpes_with_source]
+            averaged_vpe = torch.cat(all_vpe_tensors).mean(dim=0, keepdim=True)
+            final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
+
+            dialog = VPEVisualizationDialog(vpes_with_source, final_vpe, self)
+            dialog.exec_()
+
+        finally:
+            # Always restore the cursor, even if an exception occurs
+            QApplication.restoreOverrideCursor()
+
+    def deactivate_model(self):
+        """
+        Deactivate the currently loaded model and clean up resources.
+        """
+        self.loaded_model = None
+        self.model_path = None
+
+        # Clear all VPE-related data
+        self.vpe_path_edit.clear()
+        self.vpe_path = None
+        self.vpe = None
+        self.imported_vpes = []
+        self.reference_vpes = []
+
+        # Clean up references and release cached GPU memory
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # Untoggle all tools
+        self.main_window.untoggle_all_tools()
+
+        # Update the status bar
+        self.status_bar.setText("No model loaded")
+        QMessageBox.information(self, "Model Deactivated", "Model deactivated")
+
+
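The teardown above drops the Python references, forces a garbage-collection pass, and asks PyTorch to return cached GPU memory to the driver. `torch.cuda.empty_cache()` is a no-op when CUDA was never initialized, so the unguarded call is safe on CPU-only machines; a sketch of the same pattern with an explicit guard (the helper name is an assumption):

```python
import gc

import torch

def release_model(model):
    """Drop a model reference and return cached CUDA memory to the driver."""
    # Only frees memory if the caller holds no other references to the model.
    del model
    gc.collect()  # collect reference cycles that may still pin tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release PyTorch's cached CUDA blocks
```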
+class VPEVisualizationDialog(QDialog):
+    """
+    Dialog for visualizing VPE embeddings in 2D space using PCA.
+    """
+    def __init__(self, vpe_list_with_source, final_vpe=None, parent=None):
+        """
+        Initialize the dialog with a list of VPE tensors and their sources.
+
+        Args:
+            vpe_list_with_source (list): List of (VPE tensor, source_str) tuples
+            final_vpe (torch.Tensor, optional): The final (averaged) VPE
+            parent (QWidget, optional): Parent widget
+        """
+        super().__init__(parent)
+        self.setWindowTitle("VPE Visualization")
+        self.resize(1000, 1000)
+
+        # Add a maximize button to the dialog's title bar
+        self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)
+
+        # Store the VPEs and their sources
+        self.vpe_list_with_source = vpe_list_with_source
+        self.final_vpe = final_vpe
+
+        # Create the layout
+        layout = QVBoxLayout(self)
+
+        # Create the plot widget
+        self.plot_widget = pg.PlotWidget()
+        self.plot_widget.setBackground('w')  # White background
+        self.plot_widget.setTitle("PCA Visualization of Visual Prompt Embeddings", color="#000000", size="10pt")
+        self.plot_widget.showGrid(x=True, y=True, alpha=0.3)
+
+        # Add the plot widget to the layout
+        layout.addWidget(self.plot_widget)
+
+        # Add spacing between the plot widget and the info label
+        layout.addSpacing(20)
+
+        # Add an information label at the bottom
+        self.info_label = QLabel()
+        self.info_label.setAlignment(Qt.AlignCenter)
+        layout.addWidget(self.info_label)
+
+        # Create the button box
+        button_box = QDialogButtonBox(QDialogButtonBox.Close)
+        button_box.rejected.connect(self.reject)
+        layout.addWidget(button_box)
+
+        # Visualize the VPEs
+        self.visualize_vpes()
+
+    def visualize_vpes(self):
+        """
+        Apply PCA to the VPE tensors and visualize them in 2D space.
+        """
+        if not self.vpe_list_with_source:
+            self.info_label.setText("No VPEs available to visualize.")
+            return
+
+        # Convert tensors to numpy arrays for PCA, separating them from the source strings
+        vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, source in self.vpe_list_with_source]
+
+        # If a final VPE is provided, add it to the arrays
+        final_vpe_array = None
+        if self.final_vpe is not None:
+            final_vpe_array = self.final_vpe.detach().cpu().numpy().squeeze()
+            all_vpes = np.vstack(vpe_arrays + [final_vpe_array])
+        else:
+            all_vpes = np.vstack(vpe_arrays)
+
+        # Apply PCA to reduce to 2 dimensions
+        pca = PCA(n_components=2)
+        vpes_2d = pca.fit_transform(all_vpes)
+
+        # Clear the plot
+        self.plot_widget.clear()
+
+        # Generate distinct colors for the individual VPEs
+        num_vpes = len(vpe_arrays)
+        colors = self.generate_distinct_colors(num_vpes)
+
+        # Create a legend with 3 columns to keep it compact
+        legend = self.plot_widget.addLegend(colCount=3)
+
+        # Plot the individual VPEs
+        for i, (vpe_tuple, vpe_2d) in enumerate(zip(self.vpe_list_with_source, vpes_2d[:num_vpes])):
+            source_char = 'I' if vpe_tuple[1] == 'Import' else 'R'
+            color = pg.mkColor(colors[i])
+            scatter = pg.ScatterPlotItem(
+                x=[vpe_2d[0]],
+                y=[vpe_2d[1]],
+                brush=color,
+                size=15,
+                name=f"VPE {i+1} ({source_char})"
+            )
+            self.plot_widget.addItem(scatter)
+
+        # Plot the final (averaged) VPE if available
+        if final_vpe_array is not None:
+            final_vpe_2d = vpes_2d[-1]
+            scatter = pg.ScatterPlotItem(
+                x=[final_vpe_2d[0]],
+                y=[final_vpe_2d[1]],
+                brush=pg.mkBrush(color='r'),
+                size=20,
+                symbol='star',
+                name="Final VPE"
+            )
+            self.plot_widget.addItem(scatter)
+
+        # Update the information label
+        orig_dim = self.vpe_list_with_source[0][0].shape[-1]
+        explained_variance = sum(pca.explained_variance_ratio_)
+        self.info_label.setText(
+            f"Original dimension: {orig_dim} → Reduced to 2D\n"
+            f"Total explained variance: {explained_variance:.2%}\n"
+            f"PC1: {pca.explained_variance_ratio_[0]:.2%} variance, "
+            f"PC2: {pca.explained_variance_ratio_[1]:.2%} variance"
+        )
+
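`visualize_vpes` stacks every embedding into one matrix and projects it to 2D with scikit-learn's `PCA`; `explained_variance_ratio_` then reports how much of the embeddings' variance the two plotted axes retain. A standalone sketch of the same projection (the sample count and dimensionality are assumptions):

```python
import numpy as np
from sklearn.decomposition import PCA

# Five hypothetical 512-dimensional embeddings, one per row.
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(5, 512))

# Project onto the two directions of greatest variance.
pca = PCA(n_components=2)
points_2d = pca.fit_transform(embeddings)  # shape (5, 2)

print(points_2d.shape)
print(f"Explained variance: {pca.explained_variance_ratio_.sum():.2%}")
```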
+    def generate_distinct_colors(self, num_colors):
+        """
+        Generate visually distinct colors by using evenly spaced hues
+        with random saturation and value.
+
+        Args:
+            num_colors (int): Number of colors to generate
+
+        Returns:
+            list: List of color hex strings
+        """
+        import random
+        from colorsys import hsv_to_rgb
+
+        colors = []
+        for i in range(num_colors):
+            # Use the golden ratio to space hues evenly
+            hue = (i * 0.618033988749895) % 1.0
+            # Random saturation between 0.6-1.0 (avoids overly pale colors)
+            saturation = random.uniform(0.6, 1.0)
+            # Random value between 0.7-1.0 (avoids overly dark colors)
+            value = random.uniform(0.7, 1.0)
+
+            # Convert HSV to RGB (0-1 range)
+            r, g, b = hsv_to_rgb(hue, saturation, value)
+
+            # Convert RGB to a hex string
+            hex_color = f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
+            colors.append(hex_color)
+
+        return colors
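The golden-ratio increment (≈0.618) used above is a classic low-discrepancy trick: each successive hue lands far from all earlier ones on the color wheel, however many colors are requested. A quick standalone check (the fixed saturation and value here are chosen for determinism, unlike the randomized ones above):

```python
import colorsys

GOLDEN_RATIO_CONJUGATE = 0.618033988749895

# The first five hues, generated the same way as generate_distinct_colors().
hues = [(i * GOLDEN_RATIO_CONJUGATE) % 1.0 for i in range(5)]
print([round(h, 3) for h in hues])  # [0.0, 0.618, 0.236, 0.854, 0.472]

# Render them as hex strings with fixed saturation/value.
palette = ["#%02x%02x%02x" % tuple(int(c * 255) for c in colorsys.hsv_to_rgb(h, 0.8, 0.9))
           for h in hues]
print(palette)
```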