coralnet-toolbox 0.0.71__py2.py3-none-any.whl → 0.0.73__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +31 -2
- coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
- coralnet_toolbox/Explorer/QtDataItem.py +53 -21
- coralnet_toolbox/Explorer/QtExplorer.py +581 -276
- coralnet_toolbox/Explorer/QtFeatureStore.py +15 -0
- coralnet_toolbox/Explorer/QtSettingsWidgets.py +49 -7
- coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
- coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
- coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
- coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
- coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
- coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
- coralnet_toolbox/QtAnnotationWindow.py +52 -16
- coralnet_toolbox/QtEventFilter.py +8 -2
- coralnet_toolbox/QtImageWindow.py +17 -18
- coralnet_toolbox/QtLabelWindow.py +1 -1
- coralnet_toolbox/QtMainWindow.py +203 -8
- coralnet_toolbox/Rasters/QtRaster.py +59 -7
- coralnet_toolbox/Rasters/RasterTableModel.py +34 -6
- coralnet_toolbox/SAM/QtBatchInference.py +0 -2
- coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
- coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1016 -0
- coralnet_toolbox/SeeAnything/QtDeployPredictor.py +69 -53
- coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
- coralnet_toolbox/SeeAnything/__init__.py +2 -0
- coralnet_toolbox/Tools/QtResizeSubTool.py +6 -1
- coralnet_toolbox/Tools/QtSAMTool.py +150 -7
- coralnet_toolbox/Tools/QtSeeAnythingTool.py +220 -55
- coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
- coralnet_toolbox/Tools/QtSelectTool.py +48 -6
- coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
- coralnet_toolbox/__init__.py +1 -1
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/METADATA +1 -1
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/RECORD +39 -38
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1016 @@
|
|
1
|
+
import warnings
|
2
|
+
|
3
|
+
import os
|
4
|
+
import gc
|
5
|
+
import json
|
6
|
+
import copy
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
|
10
|
+
import torch
|
11
|
+
from torch.cuda import empty_cache
|
12
|
+
|
13
|
+
from ultralytics import YOLOE
|
14
|
+
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
|
15
|
+
from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
|
16
|
+
|
17
|
+
from PyQt5.QtCore import Qt
|
18
|
+
from PyQt5.QtGui import QColor
|
19
|
+
from PyQt5.QtWidgets import (QMessageBox, QCheckBox, QVBoxLayout, QApplication,
|
20
|
+
QLabel, QDialog, QDialogButtonBox, QGroupBox, QLineEdit,
|
21
|
+
QFormLayout, QComboBox, QSpinBox, QSlider, QPushButton,
|
22
|
+
QHBoxLayout, QWidget, QFileDialog)
|
23
|
+
|
24
|
+
from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
|
25
|
+
from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
|
26
|
+
|
27
|
+
from coralnet_toolbox.Results import ResultsProcessor
|
28
|
+
from coralnet_toolbox.Results import MapResults
|
29
|
+
from coralnet_toolbox.Results import CombineResults
|
30
|
+
|
31
|
+
from coralnet_toolbox.QtProgressBar import ProgressBar
|
32
|
+
from coralnet_toolbox.QtImageWindow import ImageWindow
|
33
|
+
|
34
|
+
from coralnet_toolbox.Icons import get_icon
|
35
|
+
|
36
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
37
|
+
warnings.filterwarnings("ignore", category=UserWarning)
|
38
|
+
|
39
|
+
|
40
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
41
|
+
# Classes
|
42
|
+
# ----------------------------------------------------------------------------------------------------------------------
|
43
|
+
|
44
|
+
|
45
|
+
class DeployGeneratorDialog(QDialog):
|
46
|
+
"""
|
47
|
+
Perform See Anything (YOLOE) on multiple images using a reference image and label.
|
48
|
+
|
49
|
+
:param main_window: MainWindow object
|
50
|
+
:param parent: Parent widget
|
51
|
+
"""
|
52
|
+
def __init__(self, main_window, parent=None):
    """
    Create the dialog, store references to the main window's sub-windows,
    set default inference settings and assemble the two-panel UI.

    :param main_window: MainWindow object
    :param parent: Parent widget
    """
    super().__init__(parent)

    # Handles to the application windows this dialog reads from
    self.main_window = main_window
    self.label_window = main_window.label_window
    self.image_window = main_window.image_window
    self.annotation_window = main_window.annotation_window
    self.sam_dialog = None

    # Window appearance; 800x800 leaves room for the side-by-side panels
    self.setWindowIcon(get_icon("eye.png"))
    self.setWindowTitle("See Anything (YOLOE) Generator (Ctrl + 5)")
    self.resize(800, 800)

    self.deploy_model_dialog = None
    self.loaded_model = None
    self.last_selected_label_code = None

    # Default inference settings (thresholds are fractions in 0.0-1.0)
    self.imgsz = 1024
    self.iou_thresh = 0.20
    self.uncertainty_thresh = 0.30
    self.area_thresh_min, self.area_thresh_max = 0.00, 0.40

    self.task = 'detect'
    self.max_detect = 300
    self.model_path = None
    self.class_mapping = {}

    # Reference images / label, and the images to run on
    self.source_images = []
    self.source_label = None
    self.target_images = []

    # Top-level vertical layout; the info box goes first so it spans both panels
    self.layout = QVBoxLayout(self)
    self.setup_info_layout()

    # Two side-by-side panels below the info box
    self.horizontal_layout = QHBoxLayout()
    self.layout.addLayout(self.horizontal_layout)

    self.left_panel = QVBoxLayout()
    self.horizontal_layout.addLayout(self.left_panel)

    self.right_panel = QVBoxLayout()
    self.horizontal_layout.addLayout(self.right_panel)

    # Left panel: model choice, parameters, SAM option, actions, status
    for build_section in (self.setup_models_layout,
                          self.setup_parameters_layout,
                          self.setup_sam_layout,
                          self.setup_model_buttons_layout,
                          self.setup_status_layout):
        build_section()

    # Right panel: reference label selector above an embedded ImageWindow
    self.setup_source_layout()
    self.image_selection_window = ImageWindow(self.main_window)
    self.right_panel.addWidget(self.image_selection_window)

    # OK / Cancel row at the very bottom
    self.setup_buttons_layout()
|
121
|
+
|
122
|
+
def configure_image_window_for_dialog(self):
    """
    Lock down the embedded ImageWindow so it behaves as a selection-only list.

    The filter checkboxes and search controls are disabled, the checkboxes are
    forced to a fixed state (only "has annotations" checked) without emitting
    signals, the table's double-click handler is disconnected and the window's
    automatic "load first filtered image" hook is replaced with a no-op.
    """
    window = self.image_selection_window

    # (checkbox, forced checked state), applied in this order
    forced_states = (
        (window.highlighted_checkbox, False),
        (window.has_predictions_checkbox, False),
        (window.no_annotations_checkbox, False),
        (window.has_annotations_checkbox, True),
    )

    # Silence the checkboxes first so setChecked() below cannot trigger the
    # ImageWindow's own filtering logic.
    for checkbox, _ in forced_states:
        checkbox.blockSignals(True)

    for checkbox, _ in forced_states:
        checkbox.setEnabled(False)

    for checkbox, checked in forced_states:
        checkbox.setChecked(checked)

    for checkbox, _ in forced_states:
        checkbox.blockSignals(False)

    # Search / navigation controls are not usable inside this dialog
    for control in (window.home_button,
                    window.image_search_button,
                    window.label_search_button,
                    window.search_bar_images,
                    window.search_bar_labels,
                    window.top_k_combo):
        control.setEnabled(False)

    window.top_k_combo.setCurrentText("Top1")

    # Double-clicking a row must not load an image in the main window.
    # disconnect() raises TypeError when nothing is connected (e.g. on a
    # repeated call), which is safe to ignore.
    try:
        window.tableView.doubleClicked.disconnect()
    except TypeError:
        pass

    # Replace the auto-load hook with a no-op so filtering never triggers an image load.
    window.load_first_filtered_image = lambda: None
|
175
|
+
|
176
|
+
def showEvent(self, event):
    """
    Refresh the dialog's state every time it becomes visible.

    :param event: Show event
    """
    super().showEvent(event)

    # Re-read the thresholds held by the main window into the sliders
    for refresh_threshold in (self.initialize_uncertainty_threshold,
                              self.initialize_iou_threshold,
                              self.initialize_area_threshold):
        refresh_threshold()

    # Restrict the embedded image window's UI for use inside this dialog
    self.configure_image_window_for_dialog()

    # The image list must be synced BEFORE the label dropdown is rebuilt,
    # because the dropdown is derived from the synced rasters.
    self.sync_image_window()

    # Repopulates the dropdown, restores the previous selection and then
    # runs the image filtering explicitly.
    self.update_source_labels()
|
194
|
+
|
195
|
+
def sync_image_window(self):
    """
    Point the embedded image window at the same rasters as the main window.

    The main manager's raster dictionary is shallow-copied, so the Raster
    objects are shared by reference and no per-image annotation info is
    recomputed here.
    """
    source_manager = self.main_window.image_window.raster_manager
    target_manager = self.image_selection_window.raster_manager

    # Shallow copy: new dict, same Raster objects.
    target_manager.rasters = source_manager.rasters.copy()

    # Keep the path list consistent with the dictionary that was just adopted.
    target_manager.image_paths = list(target_manager.rasters)

    # NOTE: this relies on each shared raster already carrying up-to-date
    # cached annotation info; nothing is refreshed in this method.
|
215
|
+
|
216
|
+
def filter_images_by_label_and_type(self):
    """
    Show only images where the selected label occurs on a Rectangle or Polygon.

    The check uses each raster's ``label_to_types_map`` (label code -> set of
    annotation type names). With no label selected the list is emptied.
    The combo box text is remembered so the selection can be restored later.
    """
    combo = self.source_label_combo_box
    chosen_label = combo.currentData()
    chosen_text = combo.currentText()

    # Remember the selection for the next time the dialog is opened.
    if chosen_text:
        self.last_selected_label_code = chosen_text

    table_model = self.image_selection_window.table_model

    # Nothing selected (e.g. while the dialog is still initializing).
    if not chosen_label:
        table_model.set_filtered_paths([])
        return

    manager = self.image_selection_window.raster_manager
    reference_types = {"RectangleAnnotation", "PolygonAnnotation"}
    wanted_code = chosen_label.short_label_code

    matching_paths = []
    for image_path in manager.image_paths:
        raster = manager.get_raster(image_path)
        if not raster:
            continue

        # Types recorded for this label on this image; empty when the label is absent.
        types_on_image = raster.label_to_types_map.get(wanted_code, set())

        # Keep the image when at least one of those types is a usable reference type.
        if not reference_types.isdisjoint(types_on_image):
            matching_paths.append(image_path)

    table_model.set_filtered_paths(matching_paths)
|
259
|
+
|
260
|
+
def accept(self):
    """
    Validate the selections and, when valid, store them and close as accepted.

    Checks, in order: a model is loaded, a source label is selected, and at
    least one image is highlighted. The first failing check shows a warning
    and closes the dialog as rejected. Prediction itself is left to the caller.
    """
    problem = None
    chosen_label = None
    highlighted_images = None

    if not self.loaded_model:
        problem = ("No Model",
                   "A model must be loaded before running predictions.")
    else:
        chosen_label = self.source_label_combo_box.currentData()
        if not chosen_label:
            problem = ("No Source Label",
                       "A source label must be selected.")
        else:
            # Highlighted rows of the embedded image window are the targets
            highlighted_images = self.image_selection_window.table_model.get_highlighted_paths()
            if not highlighted_images:
                problem = ("No Target Images",
                           "You must highlight at least one image in the list to process.")

    if problem is not None:
        title, message = problem
        QMessageBox.warning(self, title, message)
        super().reject()
        return

    # Keep the selections so the caller can read them after the dialog closes.
    self.source_label = chosen_label
    self.target_images = highlighted_images

    super().accept()
|
295
|
+
|
296
|
+
def setup_info_layout(self):
    """
    Build the "Information" group box shown across the top of the dialog.
    """
    help_text = ("Choose a Generator to deploy. "
                 "Select a reference label, then highlight reference images that contain examples. "
                 "Each additional reference image may increase accuracy but also processing time.")

    # Word-wrapped explanatory label
    info_label = QLabel(help_text)
    info_label.setOpenExternalLinks(True)
    info_label.setWordWrap(True)

    box_layout = QVBoxLayout()
    box_layout.addWidget(info_label)

    info_box = QGroupBox("Information")
    info_box.setLayout(box_layout)

    # Added to the main layout (not a panel) so it spans the full width
    self.layout.addWidget(info_box)
|
314
|
+
|
315
|
+
def setup_models_layout(self):
    """
    Build the "Model Selection" group box with an editable model combo box.
    """
    # Model weight files offered by default; the combo box is editable,
    # so other names can be typed in.
    self.models = [
        "yoloe-v8s-seg.pt",
        "yoloe-v8m-seg.pt",
        "yoloe-v8l-seg.pt",
        "yoloe-11s-seg.pt",
        "yoloe-11m-seg.pt",
        "yoloe-11l-seg.pt",
    ]

    self.model_combo = QComboBox()
    self.model_combo.setEditable(True)
    self.model_combo.addItems(self.models)

    # Default model
    self.model_combo.setCurrentText("yoloe-v8s-seg.pt")

    box_layout = QVBoxLayout()
    box_layout.addWidget(QLabel("Select Model:"))
    box_layout.addWidget(self.model_combo)

    model_box = QGroupBox("Model Selection")
    model_box.setLayout(box_layout)
    self.left_panel.addWidget(model_box)
|
347
|
+
|
348
|
+
def setup_parameters_layout(self):
    """
    Build the "Parameters" group box: task, max detections, image size and
    the uncertainty / IoU / area threshold sliders.

    Sliders work in whole percent (0-100); the matching ``*_thresh``
    attributes hold fractions (0.0-1.0), the same unit used by
    ``update_area_label`` and ``initialize_area_threshold``. Initial threshold
    values are read from the main window.
    """
    group_box = QGroupBox("Parameters")
    layout = QFormLayout()

    # Task dropdown
    self.use_task_dropdown = QComboBox()
    self.use_task_dropdown.addItems(["detect", "segment"])
    self.use_task_dropdown.currentIndexChanged.connect(self.update_task)
    layout.addRow("Task:", self.use_task_dropdown)

    # Max detections spinbox
    self.max_detections_spinbox = QSpinBox()
    self.max_detections_spinbox.setRange(1, 10000)
    self.max_detections_spinbox.setValue(self.max_detect)
    layout.addRow("Max Detections:", self.max_detections_spinbox)

    # Resize image dropdown (fixed to "True" and greyed out)
    self.resize_image_dropdown = QComboBox()
    self.resize_image_dropdown.addItems(["True", "False"])
    self.resize_image_dropdown.setCurrentIndex(0)
    self.resize_image_dropdown.setEnabled(False)
    layout.addRow("Resize Image:", self.resize_image_dropdown)

    # Image size control
    self.imgsz_spinbox = QSpinBox()
    self.imgsz_spinbox.setRange(512, 65536)
    self.imgsz_spinbox.setSingleStep(1024)
    self.imgsz_spinbox.setValue(self.imgsz)
    layout.addRow("Image Size (imgsz):", self.imgsz_spinbox)

    # Uncertainty threshold controls.
    # round() instead of int(): 0.29 * 100 == 28.999999999999996, which int()
    # would truncate to 28 and leave the slider one tick below the real value.
    self.uncertainty_thresh = self.main_window.get_uncertainty_thresh()
    self.uncertainty_threshold_slider = QSlider(Qt.Horizontal)
    self.uncertainty_threshold_slider.setRange(0, 100)
    self.uncertainty_threshold_slider.setValue(int(round(self.uncertainty_thresh * 100)))
    self.uncertainty_threshold_slider.setTickPosition(QSlider.TicksBelow)
    self.uncertainty_threshold_slider.setTickInterval(10)
    self.uncertainty_threshold_slider.valueChanged.connect(self.update_uncertainty_label)
    self.uncertainty_threshold_label = QLabel(f"{self.uncertainty_thresh:.2f}")
    layout.addRow("Uncertainty Threshold", self.uncertainty_threshold_slider)
    layout.addRow("", self.uncertainty_threshold_label)

    # IoU threshold controls
    self.iou_thresh = self.main_window.get_iou_thresh()
    self.iou_threshold_slider = QSlider(Qt.Horizontal)
    self.iou_threshold_slider.setRange(0, 100)
    self.iou_threshold_slider.setValue(int(round(self.iou_thresh * 100)))
    self.iou_threshold_slider.setTickPosition(QSlider.TicksBelow)
    self.iou_threshold_slider.setTickInterval(10)
    self.iou_threshold_slider.valueChanged.connect(self.update_iou_label)
    self.iou_threshold_label = QLabel(f"{self.iou_thresh:.2f}")
    layout.addRow("IoU Threshold", self.iou_threshold_slider)
    layout.addRow("", self.iou_threshold_label)

    # Area threshold controls. The attributes keep the fractional values;
    # only the slider positions are expressed in percent.
    min_val, max_val = self.main_window.get_area_thresh()
    self.area_thresh_min = min_val
    self.area_thresh_max = max_val
    self.area_threshold_min_slider = QSlider(Qt.Horizontal)
    self.area_threshold_min_slider.setRange(0, 100)
    self.area_threshold_min_slider.setValue(int(round(min_val * 100)))
    self.area_threshold_min_slider.setTickPosition(QSlider.TicksBelow)
    self.area_threshold_min_slider.setTickInterval(10)
    self.area_threshold_min_slider.valueChanged.connect(self.update_area_label)
    self.area_threshold_max_slider = QSlider(Qt.Horizontal)
    self.area_threshold_max_slider.setRange(0, 100)
    self.area_threshold_max_slider.setValue(int(round(max_val * 100)))
    self.area_threshold_max_slider.setTickPosition(QSlider.TicksBelow)
    self.area_threshold_max_slider.setTickInterval(10)
    self.area_threshold_max_slider.valueChanged.connect(self.update_area_label)
    self.area_threshold_label = QLabel(f"{self.area_thresh_min:.2f} - {self.area_thresh_max:.2f}")
    layout.addRow("Area Threshold Min", self.area_threshold_min_slider)
    layout.addRow("Area Threshold Max", self.area_threshold_max_slider)
    layout.addRow("", self.area_threshold_label)

    group_box.setLayout(layout)
    self.left_panel.addWidget(group_box)  # Add to left panel
|
428
|
+
|
429
|
+
def setup_sam_layout(self):
    """
    Build the group box holding the "Use SAM Polygons" True/False dropdown.
    """
    # Defaults to "False"; every change is routed through is_sam_model_deployed.
    self.use_sam_dropdown = QComboBox()
    self.use_sam_dropdown.addItems(["False", "True"])
    self.use_sam_dropdown.currentIndexChanged.connect(self.is_sam_model_deployed)

    form = QFormLayout()
    form.addRow("Use SAM Polygons:", self.use_sam_dropdown)

    sam_box = QGroupBox("Use SAM Model for Creating Polygons")
    sam_box.setLayout(form)
    self.left_panel.addWidget(sam_box)
|
442
|
+
|
443
|
+
def setup_model_buttons_layout(self):
    """
    Build the "Actions" group box with the load and deactivate model buttons.
    """
    row = QHBoxLayout()

    # (button caption, click handler), left to right
    for caption, handler in (("Load Model", self.load_model),
                             ("Deactivate Model", self.deactivate_model)):
        button = QPushButton(caption)
        button.clicked.connect(handler)
        row.addWidget(button)

    actions_box = QGroupBox("Actions")
    actions_box.setLayout(row)
    self.left_panel.addWidget(actions_box)
|
460
|
+
|
461
|
+
def setup_status_layout(self):
    """
    Build the "Status" group box containing the model status label.
    """
    # Starts out reporting that no model is loaded
    self.status_bar = QLabel("No model loaded")

    box_layout = QVBoxLayout()
    box_layout.addWidget(self.status_bar)

    status_box = QGroupBox("Status")
    status_box.setLayout(box_layout)
    self.left_panel.addWidget(status_box)
|
473
|
+
|
474
|
+
def setup_source_layout(self):
    """
    Build the "Reference Label" group box with the label selection combo box.

    Changing the selection re-filters the image list.
    """
    self.source_label_combo_box = QComboBox()
    self.source_label_combo_box.currentIndexChanged.connect(self.filter_images_by_label_and_type)

    form = QFormLayout()
    form.addRow("Reference Label:", self.source_label_combo_box)

    reference_box = QGroupBox("Reference Label")
    reference_box.setLayout(form)
    self.right_panel.addWidget(reference_box)
|
489
|
+
|
490
|
+
def setup_buttons_layout(self):
    """
    Add the OK / Cancel button row to the bottom of the dialog.
    """
    standard_buttons = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
    buttons = QDialogButtonBox(standard_buttons)

    # OK runs this dialog's validating accept(); Cancel rejects.
    buttons.accepted.connect(self.accept)
    buttons.rejected.connect(self.reject)

    self.layout.addWidget(buttons)
|
500
|
+
|
501
|
+
def initialize_uncertainty_threshold(self):
    """
    Load the main window's uncertainty threshold into the slider and attribute.

    The slider works in whole percent. The conversion rounds rather than
    truncates: 0.29 * 100 is 28.999999999999996 in floating point, so int()
    alone would move the slider to 28, and the slider's valueChanged handler
    would then write the degraded 0.28 back to the main window.
    """
    current_value = self.main_window.get_uncertainty_thresh()
    self.uncertainty_threshold_slider.setValue(int(round(current_value * 100)))
    self.uncertainty_thresh = current_value
|
506
|
+
|
507
|
+
def initialize_iou_threshold(self):
    """
    Load the main window's IoU threshold into the slider and attribute.

    The slider works in whole percent. The conversion rounds rather than
    truncates: values such as 0.57 * 100 evaluate to 56.99999999999999, which
    int() alone would turn into 56 and silently lower the threshold.
    """
    current_value = self.main_window.get_iou_thresh()
    self.iou_threshold_slider.setValue(int(round(current_value * 100)))
    self.iou_thresh = current_value
|
512
|
+
|
513
|
+
def initialize_area_threshold(self):
    """
    Load the main window's (min, max) area thresholds into both sliders,
    the attributes and the range label.

    Signals are blocked while the sliders move. Otherwise setting the min
    slider first fires ``update_area_label`` against the *old* max position,
    which can clamp the new minimum and push a wrong range back to the main
    window. Percent positions are rounded, not truncated, so values such as
    0.29 do not drop to 28.
    """
    current_min, current_max = self.main_window.get_area_thresh()
    sliders = (self.area_threshold_min_slider, self.area_threshold_max_slider)

    for slider in sliders:
        slider.blockSignals(True)
    try:
        self.area_threshold_min_slider.setValue(int(round(current_min * 100)))
        self.area_threshold_max_slider.setValue(int(round(current_max * 100)))
    finally:
        for slider in sliders:
            slider.blockSignals(False)

    self.area_thresh_min = current_min
    self.area_thresh_max = current_max

    # The blocked signals would normally refresh the label, so do it here.
    self.area_threshold_label.setText(f"{current_min:.2f} - {current_max:.2f}")
|
520
|
+
|
521
|
+
def update_uncertainty_label(self, value):
    """
    Apply a new uncertainty slider position.

    :param value: Slider position in whole percent (0-100)
    """
    fraction = value / 100.0

    # Store locally, propagate to the main window, then refresh the readout
    self.uncertainty_thresh = fraction
    self.main_window.update_uncertainty_thresh(fraction)
    self.uncertainty_threshold_label.setText(f"{fraction:.2f}")
|
527
|
+
|
528
|
+
def update_iou_label(self, value):
    """
    Apply a new IoU slider position.

    :param value: Slider position in whole percent (0-100)
    """
    fraction = value / 100.0

    # Store locally, propagate to the main window, then refresh the readout
    self.iou_thresh = fraction
    self.main_window.update_iou_thresh(fraction)
    self.iou_threshold_label.setText(f"{fraction:.2f}")
|
534
|
+
|
535
|
+
def update_area_label(self):
    """
    React to a change of either area slider.

    If the minimum exceeds the maximum, the minimum slider is pulled down to
    the maximum. The resulting range (as fractions) is stored, pushed to the
    main window and shown in the range label.
    """
    upper = self.area_threshold_max_slider.value()
    lower = self.area_threshold_min_slider.value()

    # Keep the range well-formed: the minimum may never exceed the maximum.
    if upper < lower:
        lower = upper
        self.area_threshold_min_slider.setValue(lower)

    # Sliders are in whole percent; attributes are fractions.
    self.area_thresh_min = lower / 100.0
    self.area_thresh_max = upper / 100.0

    self.main_window.update_area_thresh(self.area_thresh_min, self.area_thresh_max)

    range_text = f"{self.area_thresh_min:.2f} - {self.area_thresh_max:.2f}"
    self.area_threshold_label.setText(range_text)
|
546
|
+
|
547
|
+
def get_max_detections(self):
    """
    Read the max-detections spinbox, cache it on ``self.max_detect`` and return it.

    :return: Current value of the max-detections spinbox
    """
    limit = self.max_detections_spinbox.value()
    self.max_detect = limit
    return limit
|
551
|
+
|
552
|
+
def is_sam_model_deployed(self):
    """
    Check whether a SAM model is deployed, reverting the SAM dropdown if not.

    This method is also the slot for the dropdown's ``currentIndexChanged``
    signal. The error box is shown only when the dropdown actually requests
    SAM ("True"), and the revert to "False" is done with signals blocked.
    Without that guard the revert re-enters this slot and the same error is
    shown twice, and merely selecting "False" would also raise the error.

    :return: Boolean indicating whether the SAM model is deployed
    """
    if not hasattr(self.main_window, 'sam_deploy_predictor_dialog'):
        return False

    self.sam_dialog = self.main_window.sam_deploy_predictor_dialog

    # A missing (None) predictor dialog is treated the same as "no model loaded".
    if self.sam_dialog is not None and self.sam_dialog.loaded_model:
        return True

    if self.use_sam_dropdown.currentText() == "True":
        # Revert silently so this slot is not re-entered by its own change.
        self.use_sam_dropdown.blockSignals(True)
        try:
            self.use_sam_dropdown.setCurrentText("False")
        finally:
            self.use_sam_dropdown.blockSignals(False)
        QMessageBox.critical(self, "Error", "Please deploy the SAM model first.")

    return False
|
569
|
+
|
570
|
+
def update_sam_task_state(self):
    """
    Reconcile the task with the "Use SAM" dropdown.

    When SAM is requested and a SAM model is loaded, the task is forced to
    'segment'. When SAM is requested but unavailable, the dropdown is reverted
    to "False". When SAM is not requested, nothing is changed, so the task
    stays whatever was chosen in the 'Task' dropdown.
    """
    # SAM not requested: leave self.task untouched.
    if self.use_sam_dropdown.currentText() != "True":
        return

    sam_dialog = getattr(self, 'sam_dialog', None)
    sam_ready = sam_dialog is not None and sam_dialog.loaded_model is not None

    if not sam_ready:
        # Requested but not available: only revert the dropdown here.
        self.use_sam_dropdown.setCurrentText("False")
        return

    # Requested and available: polygons from SAM require the segment task.
    self.task = 'segment'
|
595
|
+
|
596
|
+
def update_task(self):
    """
    Store the task chosen in the dropdown; switching to 'segment' while a
    model is loaded deactivates that model.
    """
    selected_task = self.use_task_dropdown.currentText()
    self.task = selected_task

    # A loaded model is dropped when the segment task is selected.
    if selected_task == "segment" and self.loaded_model:
        self.deactivate_model()
|
605
|
+
|
606
|
+
def update_source_labels(self):
    """
    Rebuild the reference label dropdown from the synced rasters.

    A label is offered when some raster's ``label_to_types_map`` lists it with
    a Rectangle or Polygon annotation type and the label exists in the label
    window. Entries are sorted by short label code, the previous selection is
    restored when still present, and the image list is re-filtered afterwards.

    :return: Always True
    """
    combo = self.source_label_combo_box

    # Keep the combo box quiet while it is rebuilt; filtering runs once at the end.
    combo.blockSignals(True)
    try:
        combo.clear()

        reference_types = {"RectangleAnnotation", "PolygonAnnotation"}

        # Lookup from short label code to the full Label object
        labels_by_code = {
            label.short_label_code: label
            for label in self.main_window.label_window.labels
        }

        # Collect the Label objects that appear with a usable annotation type
        usable_labels = set()
        for raster in self.image_selection_window.raster_manager.rasters.values():
            for code, types_for_code in raster.label_to_types_map.items():
                if reference_types.isdisjoint(types_for_code):
                    continue
                if code in labels_by_code:
                    usable_labels.add(labels_by_code[code])

        # Alphabetical by short code; the Label object rides along as item data
        for label in sorted(usable_labels, key=lambda lbl: lbl.short_label_code):
            combo.addItem(label.short_label_code, label)

        # Re-select the label used last time, if it is still offered
        previous_code = self.last_selected_label_code
        if previous_code:
            position = combo.findText(previous_code)
            if position != -1:
                combo.setCurrentIndex(position)
    finally:
        combo.blockSignals(False)

    # Signals were blocked above, so run the filtering explicitly.
    self.filter_images_by_label_and_type()

    return True
|
652
|
+
|
653
|
+
def get_source_annotations(self, reference_label, reference_image_path):
    """
    Gather the box and outline prompts one image provides for a label.

    Only polygon and rectangle annotations whose label code equals the code
    of ``reference_label`` are used. Each one contributes its
    ``cropped_bbox`` plus one outline: the polygon's own points, or the
    four corners derived from the rectangle's box.

    :param reference_label: The Label object to filter annotations by.
    :param reference_image_path: The path of the image to get annotations from.
    :return: A tuple ``(bboxes, masks)``: a numpy array built from the
             collected boxes and a list with one numpy array of (x, y)
             points per annotation. Both are empty when either argument
             is falsy.
    """
    if not reference_label or not reference_image_path:
        return np.array([]), []

    boxes = []
    outlines = []

    for item in self.annotation_window.get_image_annotations(reference_image_path):
        # Ignore annotations that carry a different label
        if item.label.short_label_code != reference_label.short_label_code:
            continue
        # Ignore annotation types that cannot be used as a reference
        if not isinstance(item, (PolygonAnnotation, RectangleAnnotation)):
            continue

        box = item.cropped_bbox
        boxes.append(box)

        if isinstance(item, PolygonAnnotation):
            outline = np.array([[vertex.x(), vertex.y()] for vertex in item.points])
        else:
            # Only a rectangle can reach this branch; expand its box into corners
            xa, ya, xb, yb = box
            outline = np.array([[xa, ya], [xb, ya], [xb, yb], [xa, yb]])
        outlines.append(outline)

    return np.array(boxes), outlines
|
685
|
+
|
686
|
+
def load_model(self):
    """
    Load the model selected in the model combo box and warm it up.

    The model name is read from ``self.model_combo``, a YOLOE model is
    built from it and moved to ``self.main_window.device``, and a single
    dummy visual-prompt prediction is run on a blank 640x640 image.

    On success the status bar and an information dialog report the load.
    On any failure the model state is reset (``loaded_model`` and
    ``model_path`` become None) so a half-initialised model is never used
    by a later prediction, and the error is shown in a dialog.
    """
    QApplication.setOverrideCursor(Qt.WaitCursor)
    progress_bar = ProgressBar(self.annotation_window, title="Loading Model")
    progress_bar.show()

    try:
        # Get selected model path and download weights if needed
        self.model_path = self.model_combo.currentText()

        # Load model using registry
        self.loaded_model = YOLOE(self.model_path).to(self.main_window.device)

        # Create a dummy visual dictionary (one box, class 0)
        visuals = dict(
            bboxes=np.array(
                [
                    [120, 425, 160, 445],
                ],
            ),
            cls=np.zeros(1),
        )

        # Run a dummy prediction to load the model
        self.loaded_model.predict(
            np.zeros((640, 640, 3), dtype=np.uint8),
            visual_prompts=visuals.copy(),
            predictor=YOLOEVPDetectPredictor,
            imgsz=640,
            conf=0.99,
        )

        progress_bar.finish_progress()
        self.status_bar.setText("Model loaded")
        QMessageBox.information(self.annotation_window,
                                "Model Loaded",
                                "Model loaded successfully")

    except Exception as e:
        # Do not keep a model that failed to build or to run its warm-up;
        # predict() treats any non-empty loaded_model as usable.
        self.loaded_model = None
        self.model_path = None
        self.status_bar.setText("No model loaded")
        QMessageBox.critical(self.annotation_window,
                             "Error Loading Model",
                             f"Error loading model: {e}")

    finally:
        # Restore cursor
        QApplication.restoreOverrideCursor()
        # Stop the progress bar
        progress_bar.stop_progress()
        progress_bar.close()
|
740
|
+
|
741
|
+
def predict(self, image_paths=None):
    """
    Run the prediction workflow on the given images with the loaded model.

    Returns immediately unless both a model and a source label are set.
    For every image the inputs are gathered, the model (and SAM, when
    applicable) is applied and the results are passed to a
    ResultsProcessor. Any exception raised while processing is printed
    and ends the loop; the cursor and progress bar are always restored.

    Args:
        image_paths: List of image paths to process. If None (or empty),
            uses the current image.
    """
    if not (self.loaded_model and self.source_label):
        return

    # Class index 0 always stands for the selected reference label
    self.class_mapping = {0: self.source_label}

    # Build the processor from the thresholds currently set in the main window
    window = self.main_window
    processor = ResultsProcessor(
        window,
        self.class_mapping,
        uncertainty_thresh=window.get_uncertainty_thresh(),
        iou_thresh=window.get_iou_thresh(),
        min_area_thresh=window.get_area_thresh_min(),
        max_area_thresh=window.get_area_thresh_max()
    )

    # Fall back to the image currently shown when no paths were given
    targets = image_paths or [self.annotation_window.current_image_path]

    # Show a busy cursor and a progress bar sized to the number of images
    QApplication.setOverrideCursor(Qt.WaitCursor)
    bar = ProgressBar(self.annotation_window, title="Prediction Workflow")
    bar.show()
    bar.start_progress(len(targets))

    try:
        for target in targets:
            model_inputs = self._get_inputs(target)
            if model_inputs is not None:
                detections = self._apply_model(model_inputs)
                detections = self._apply_sam(detections, target)
                self._process_results(processor, detections, target)

                # One step per processed image
                bar.update_progress()

    except Exception as e:
        print("An error occurred during prediction:", e)
    finally:
        QApplication.restoreOverrideCursor()
        bar.finish_progress()
        bar.stop_progress()
        bar.close()

    # Release memory held by the run
    gc.collect()
    empty_cache()
|
799
|
+
|
800
|
+
def _get_inputs(self, image_path):
    """
    Get the inputs for the model prediction.

    When the selected tool is not "work_area" the input is the raster's
    image path; otherwise it is the raster's work-area data.

    Args:
        image_path: Path of the image whose raster should be looked up.

    Returns:
        A list of inputs, or None when there is nothing to predict on
        (no raster was returned for the path, or the work-area tool is
        active but yields no work areas). ``predict`` skips the image
        when None is returned.
    """
    raster = self.image_window.raster_manager.get_raster(image_path)
    if raster is None:
        return None

    if self.annotation_window.get_selected_tool() != "work_area":
        # Use the image path
        return [raster.image_path]

    # Get the work areas
    work_areas_data = raster.get_work_areas_data()
    if work_areas_data is None or len(work_areas_data) == 0:
        # An empty list would fail later when the first input is indexed
        return None

    return work_areas_data
|
811
|
+
|
812
|
+
def _apply_model(self, inputs):
    """
    Apply the model to the target inputs, using each highlighted source
    image as an individual reference for a separate prediction run.

    Only the first entry of ``inputs`` is predicted on. Reference images
    that have no annotations for the source label are skipped. The
    per-reference results are merged with CombineResults.

    Args:
        inputs: List of inputs for one target image (image path or
            work-area data).

    Returns:
        ``[[combined_results]]`` when something was combined, otherwise
        an empty list (no inputs, no reference images, or nothing to
        combine).
    """
    # Nothing to predict on
    if inputs is None or len(inputs) == 0:
        return []

    # Update the model with user parameters
    self.loaded_model.conf = self.main_window.get_uncertainty_thresh()
    self.loaded_model.iou = self.main_window.get_iou_thresh()
    self.loaded_model.max_det = self.get_max_detections()

    # NOTE: self.target_images contains the reference images highlighted in the dialog
    reference_image_paths = self.target_images

    if not reference_image_paths:
        QMessageBox.warning(self,
                            "No Reference Images",
                            "You must highlight at least one reference image.")
        return []

    # Get the selected reference label from the stored variable
    source_label = self.source_label

    # Create a dictionary of reference annotations, with image path as the key.
    reference_annotations_dict = {}
    for path in reference_image_paths:
        bboxes, masks = self.get_source_annotations(source_label, path)
        if bboxes.size > 0:
            reference_annotations_dict[path] = {
                'bboxes': bboxes,
                'masks': masks,
                'cls': np.zeros(len(bboxes))
            }

    # Set the task
    self.task = self.use_task_dropdown.currentText()
    predictor = YOLOEVPSegPredictor if self.task == "segment" else YOLOEVPDetectPredictor

    # Create a progress bar for iterating through reference images
    QApplication.setOverrideCursor(Qt.WaitCursor)
    progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
    progress_bar.show()
    progress_bar.start_progress(len(reference_annotations_dict))

    results_list = []
    # The 'inputs' list contains work areas from the single target image.
    # We will predict on the first work area/full image.
    input_image = inputs[0]

    try:
        # Iterate through each reference image and its annotations
        for ref_path, ref_annotations in reference_annotations_dict.items():
            # The 'refer_image' parameter is the path to the current reference image
            # The 'visual_prompts' are the annotations from that same reference image
            visuals = {
                'bboxes': ref_annotations['bboxes'],
                'cls': ref_annotations['cls'],
            }
            if self.task == 'segment':
                visuals['masks'] = ref_annotations['masks']

            # Make predictions on the target using the current reference
            results = self.loaded_model.predict(input_image,
                                                refer_image=ref_path,
                                                visual_prompts=visuals,
                                                predictor=predictor,
                                                imgsz=self.imgsz_spinbox.value(),
                                                conf=self.main_window.get_uncertainty_thresh(),
                                                iou=self.main_window.get_iou_thresh(),
                                                max_det=self.get_max_detections(),
                                                retina_masks=self.task == "segment")

            if not len(results[0].boxes):
                # If no boxes were detected, skip to the next reference
                progress_bar.update_progress()
                continue

            # Update the name of the results and append to the list
            results[0].names = {0: self.class_mapping[0].short_label_code}
            results_list.extend(results[0])

            progress_bar.update_progress()
            gc.collect()
            empty_cache()
    finally:
        # Clean up even when a prediction raises, so the override cursor
        # pushed above is always popped and the dialog is closed.
        QApplication.restoreOverrideCursor()
        progress_bar.finish_progress()
        progress_bar.stop_progress()
        progress_bar.close()

    # Combine results if there are any
    combined_results = CombineResults().combine_results(results_list)
    if combined_results is None:
        return []

    return [[combined_results]]
|
907
|
+
|
908
|
+
def _apply_sam(self, results_list, image_path):
    """
    Apply SAM to the results if needed.

    The input list is returned unchanged when the task is not 'segment',
    when no SAM dialog is available, when the SAM dropdown reads "False",
    or when the SAM dialog has no loaded model (in which case the task is
    switched to 'detect' and the dropdown is set to "False").

    Args:
        results_list: List of results lists, one entry per input.
        image_path: Path of the image the results belong to.

    Returns:
        The original list, or a new list holding the SAM output for every
        non-empty entry.
    """
    # Check if SAM model is deployed and loaded
    self.update_sam_task_state()
    if self.task != 'segment':
        return results_list

    if not self.sam_dialog or self.use_sam_dropdown.currentText() == "False":
        # If SAM is not deployed or not selected, return the results as is
        return results_list

    if self.sam_dialog.loaded_model is None:
        # If SAM is not loaded, ensure we do not use it accidentally
        self.task = 'detect'
        self.use_sam_dropdown.setCurrentText("False")
        return results_list

    # Make cursor busy
    QApplication.setOverrideCursor(Qt.WaitCursor)
    progress_bar = ProgressBar(self.annotation_window, title="Predicting with SAM")
    progress_bar.show()
    progress_bar.start_progress(len(results_list))

    updated_results = []

    try:
        for results in results_list:
            # Each Results is a list (within the results_list, [[], ]
            if results:
                # Run it through the SAM model
                results = self.sam_dialog.predict_from_results(results, image_path)
                updated_results.append(results)

            # Update the progress bar
            progress_bar.update_progress()
    finally:
        # Make cursor normal and close the dialog even if SAM raises
        QApplication.restoreOverrideCursor()
        progress_bar.finish_progress()
        progress_bar.stop_progress()
        progress_bar.close()

    return updated_results
|
950
|
+
|
951
|
+
def _process_results(self, results_processor, results_list, image_path):
    """
    Process the results using the result processor.

    Every non-empty entry gets its path and class names set, is mapped
    from its work area back to the full image when the work-area tool is
    active and work areas exist, and is then handed to the processor as
    segmentation or detection results depending on the task / SAM setting.

    Args:
        results_processor: Processor that receives the final results.
        results_list: List of results lists, one entry per input.
        image_path: Path of the image the results belong to.
    """
    # Get the raster object and number of work items
    raster = self.image_window.raster_manager.get_raster(image_path)
    total = raster.count_work_items()

    # Get the work areas (if any)
    work_areas = raster.get_work_areas()

    # Start the progress bar
    progress_bar = ProgressBar(self.annotation_window, title="Processing Results")
    progress_bar.show()
    progress_bar.start_progress(total)

    updated_results = []

    try:
        # idx is the position in results_list and selects the matching work area
        for idx, results in enumerate(results_list):
            # Each Results is a list (within the results_list, [[], ]
            if results:
                # Update path and names
                # This needs to be done again, in case SAM was used
                results[0].path = image_path
                results[0].names = {0: self.class_mapping[0].short_label_code}

                # Check if the work area is valid, or the image path is being used
                if work_areas and self.annotation_window.get_selected_tool() == "work_area":
                    # Map results from work area to the full image
                    results = MapResults().map_results_from_work_area(results[0],
                                                                      raster,
                                                                      work_areas[idx],
                                                                      self.task == "segment")
                else:
                    results = results[0]

                # Append the result object (not a list) to the updated results list
                updated_results.append(results)
                progress_bar.update_progress()

        # Process the Results
        if self.task == 'segment' or self.use_sam_dropdown.currentText() == "True":
            results_processor.process_segmentation_results(updated_results)
        else:
            results_processor.process_detection_results(updated_results)
    finally:
        # Close the progress bar even if mapping or processing raises
        progress_bar.finish_progress()
        progress_bar.stop_progress()
        progress_bar.close()
|
1002
|
+
|
1003
|
+
def deactivate_model(self):
    """
    Drop the loaded model and return the dialog to its "no model" state.

    Clears the model reference and path, runs garbage collection and
    empties the CUDA cache, untoggles all tools in the main window,
    updates the status bar and informs the user with a dialog.
    """
    # Forget the model and where it came from
    self.loaded_model = self.model_path = None

    # Reclaim memory now that the model is no longer referenced here
    gc.collect()
    torch.cuda.empty_cache()

    # Reset the UI: no active tool, status text, confirmation dialog
    self.main_window.untoggle_all_tools()
    self.status_bar.setText("No model loaded")
    QMessageBox.information(self, "Model Deactivated", "Model deactivated")
|