coralnet_toolbox-0.0.71-py2.py3-none-any.whl → coralnet_toolbox-0.0.73-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +31 -2
  2. coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
  3. coralnet_toolbox/Explorer/QtDataItem.py +53 -21
  4. coralnet_toolbox/Explorer/QtExplorer.py +581 -276
  5. coralnet_toolbox/Explorer/QtFeatureStore.py +15 -0
  6. coralnet_toolbox/Explorer/QtSettingsWidgets.py +49 -7
  7. coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
  8. coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
  9. coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
  10. coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
  11. coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
  12. coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
  13. coralnet_toolbox/QtAnnotationWindow.py +52 -16
  14. coralnet_toolbox/QtEventFilter.py +8 -2
  15. coralnet_toolbox/QtImageWindow.py +17 -18
  16. coralnet_toolbox/QtLabelWindow.py +1 -1
  17. coralnet_toolbox/QtMainWindow.py +203 -8
  18. coralnet_toolbox/Rasters/QtRaster.py +59 -7
  19. coralnet_toolbox/Rasters/RasterTableModel.py +34 -6
  20. coralnet_toolbox/SAM/QtBatchInference.py +0 -2
  21. coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
  22. coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
  23. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1016 -0
  24. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +69 -53
  25. coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
  26. coralnet_toolbox/SeeAnything/__init__.py +2 -0
  27. coralnet_toolbox/Tools/QtResizeSubTool.py +6 -1
  28. coralnet_toolbox/Tools/QtSAMTool.py +150 -7
  29. coralnet_toolbox/Tools/QtSeeAnythingTool.py +220 -55
  30. coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
  31. coralnet_toolbox/Tools/QtSelectTool.py +48 -6
  32. coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
  33. coralnet_toolbox/__init__.py +1 -1
  34. {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/METADATA +1 -1
  35. {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/RECORD +39 -38
  36. {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/WHEEL +0 -0
  37. {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/entry_points.txt +0 -0
  38. {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/licenses/LICENSE.txt +0 -0
  39. {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/top_level.txt +0 -0
coralnet_toolbox/SeeAnything/QtDeployGenerator.py (new file)
@@ -0,0 +1,1016 @@
+ import warnings
+
+ import os
+ import gc
+ import json
+ import copy
+
+ import numpy as np
+
+ import torch
+ from torch.cuda import empty_cache
+
+ from ultralytics import YOLOE
+ from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
+ from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
+
+ from PyQt5.QtCore import Qt
+ from PyQt5.QtGui import QColor
+ from PyQt5.QtWidgets import (QMessageBox, QCheckBox, QVBoxLayout, QApplication,
+                              QLabel, QDialog, QDialogButtonBox, QGroupBox, QLineEdit,
+                              QFormLayout, QComboBox, QSpinBox, QSlider, QPushButton,
+                              QHBoxLayout, QWidget, QFileDialog)
+
+ from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
+ from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
+
+ from coralnet_toolbox.Results import ResultsProcessor
+ from coralnet_toolbox.Results import MapResults
+ from coralnet_toolbox.Results import CombineResults
+
+ from coralnet_toolbox.QtProgressBar import ProgressBar
+ from coralnet_toolbox.QtImageWindow import ImageWindow
+
+ from coralnet_toolbox.Icons import get_icon
+
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+
+
+ # ----------------------------------------------------------------------------------------------------------------------
+ # Classes
+ # ----------------------------------------------------------------------------------------------------------------------
+
+
+ class DeployGeneratorDialog(QDialog):
+     """
+     Perform See Anything (YOLOE) on multiple images using a reference image and label.
+
+     :param main_window: MainWindow object
+     :param parent: Parent widget
+     """
+     def __init__(self, main_window, parent=None):
+         super().__init__(parent)
+         self.main_window = main_window
+         self.label_window = main_window.label_window
+         self.image_window = main_window.image_window
+         self.annotation_window = main_window.annotation_window
+         self.sam_dialog = None
+
+         self.setWindowIcon(get_icon("eye.png"))
+         self.setWindowTitle("See Anything (YOLOE) Generator (Ctrl + 5)")
+         self.resize(800, 800)  # Increased size to accommodate the horizontal layout
+
+         self.deploy_model_dialog = None
+         self.loaded_model = None
+         self.last_selected_label_code = None
+
+         # Initialize variables
+         self.imgsz = 1024
+         self.iou_thresh = 0.20
+         self.uncertainty_thresh = 0.30
+         self.area_thresh_min = 0.00
+         self.area_thresh_max = 0.40
+
+         self.task = 'detect'
+         self.max_detect = 300
+         self.loaded_model = None
+         self.model_path = None
+         self.class_mapping = {}
+
+         # Reference image and label
+         self.source_images = []
+         self.source_label = None
+         # Target images
+         self.target_images = []
+
+         # Main vertical layout for the dialog
+         self.layout = QVBoxLayout(self)
+
+         # Setup the info layout at the top
+         self.setup_info_layout()
+
+         # Create horizontal layout for the two panels
+         self.horizontal_layout = QHBoxLayout()
+         self.layout.addLayout(self.horizontal_layout)
+
+         # Create left panel
+         self.left_panel = QVBoxLayout()
+         self.horizontal_layout.addLayout(self.left_panel)
+
+         # Create right panel
+         self.right_panel = QVBoxLayout()
+         self.horizontal_layout.addLayout(self.right_panel)
+
+         # Add layouts to the left panel
+         self.setup_models_layout()
+         self.setup_parameters_layout()
+         self.setup_sam_layout()
+         self.setup_model_buttons_layout()
+         self.setup_status_layout()
+
+         # Add layouts to the right panel
+         self.setup_source_layout()
+
+         # Add a full ImageWindow instance for target image selection
+         self.image_selection_window = ImageWindow(self.main_window)
+         self.right_panel.addWidget(self.image_selection_window)
+
+         # Setup the buttons layout at the bottom
+         self.setup_buttons_layout()
+
+     def configure_image_window_for_dialog(self):
+         """
+         Disables parts of the internal ImageWindow UI to guide user selection.
+         This forces the image list to only show images with annotations
+         matching the selected reference label.
+         """
+         iw = self.image_selection_window
+
+         # Block signals to prevent setChecked from triggering the ImageWindow's
+         # own filtering logic. We want to be in complete control.
+         iw.highlighted_checkbox.blockSignals(True)
+         iw.has_predictions_checkbox.blockSignals(True)
+         iw.no_annotations_checkbox.blockSignals(True)
+         iw.has_annotations_checkbox.blockSignals(True)
+
+         # Disable and set filter checkboxes
+         iw.highlighted_checkbox.setEnabled(False)
+         iw.has_predictions_checkbox.setEnabled(False)
+         iw.no_annotations_checkbox.setEnabled(False)
+         iw.has_annotations_checkbox.setEnabled(False)
+
+         iw.highlighted_checkbox.setChecked(False)
+         iw.has_predictions_checkbox.setChecked(False)
+         iw.no_annotations_checkbox.setChecked(False)
+         iw.has_annotations_checkbox.setChecked(True)  # This will no longer trigger a filter
+
+         # Unblock signals now that we're done.
+         iw.highlighted_checkbox.blockSignals(False)
+         iw.has_predictions_checkbox.blockSignals(False)
+         iw.no_annotations_checkbox.blockSignals(False)
+         iw.has_annotations_checkbox.blockSignals(False)
+
+         # Disable search UI elements
+         iw.home_button.setEnabled(False)
+         iw.image_search_button.setEnabled(False)
+         iw.label_search_button.setEnabled(False)
+         iw.search_bar_images.setEnabled(False)
+         iw.search_bar_labels.setEnabled(False)
+         iw.top_k_combo.setEnabled(False)
+
+         # Set Top-K to Top1
+         iw.top_k_combo.setCurrentText("Top1")
+
+         # Disconnect the double-click signal to prevent it from loading an image
+         # in the main window, as this dialog is for selection only.
+         try:
+             iw.tableView.doubleClicked.disconnect()
+         except TypeError:
+             pass
+
+         # CRITICAL: Override the load_first_filtered_image method to prevent auto-loading
+         # This is the key fix to prevent unwanted load_image_by_path calls
+         iw.load_first_filtered_image = lambda: None
+
+     def showEvent(self, event):
+         """
+         Set up the layout when the dialog is shown.
+
+         :param event: Show event
+         """
+         super().showEvent(event)
+         self.initialize_uncertainty_threshold()
+         self.initialize_iou_threshold()
+         self.initialize_area_threshold()
+
+         # Configure the image window's UI elements for this specific dialog
+         self.configure_image_window_for_dialog()
+         # Sync with main window's images BEFORE updating labels
+         self.sync_image_window()
+         # This now populates the dropdown, restores the last selection,
+         # and then manually triggers the image filtering.
+         self.update_source_labels()
+
+     def sync_image_window(self):
+         """
+         Syncs by directly adopting the main manager's up-to-date raster objects,
+         avoiding redundant and slow re-calculation of annotation info.
+         """
+         main_manager = self.main_window.image_window.raster_manager
+         dialog_manager = self.image_selection_window.raster_manager
+
+         # Since the main_manager's rasters are always up-to-date, we can
+         # simply replace the dialog's raster dictionary and path list entirely.
+         # This is a shallow copy of the dictionary, which is extremely fast.
+         # The Raster objects themselves are not copied, just referenced.
+         dialog_manager.rasters = main_manager.rasters.copy()
+
+         # Update the path list to match the new dictionary of rasters.
+         dialog_manager.image_paths = list(dialog_manager.rasters.keys())
+
+         # The slow 'for' loop that called update_annotation_info is now gone.
+         # We are trusting that each raster object from the main_manager
+         # already has its .label_set and .annotation_type_set correctly populated.
+
+     def filter_images_by_label_and_type(self):
+         """
+         Filters the image list to show only images that contain at least one
+         annotation that has BOTH the selected label AND a valid type (Polygon or Rectangle).
+         This uses the fast, pre-computed cache for performance.
+         """
+         source_label = self.source_label_combo_box.currentData()
+         source_label_text = self.source_label_combo_box.currentText()
+
+         # Store the last selected label for a better user experience on re-opening.
+         if source_label_text:
+             self.last_selected_label_code = source_label_text
+
+         if not source_label:
+             # If no label is selected (e.g., during initialization), show an empty list.
+             self.image_selection_window.table_model.set_filtered_paths([])
+             return
+
+         all_paths = self.image_selection_window.raster_manager.image_paths
+         final_filtered_paths = []
+
+         valid_types = {"RectangleAnnotation", "PolygonAnnotation"}
+         selected_label_code = source_label.short_label_code
+
+         # Loop through paths and check the pre-computed map on each raster
+         for path in all_paths:
+             raster = self.image_selection_window.raster_manager.get_raster(path)
+             if not raster:
+                 continue
+
+             # 1. From the cache, get the set of annotation types specifically for our selected label.
+             #    Use .get() to safely return an empty set if the label isn't on this image at all.
+             types_for_this_label = raster.label_to_types_map.get(selected_label_code, set())
+
+             # 2. Check for any overlap between the types found FOR THIS LABEL and the
+             #    valid types we need (Polygon/Rectangle). This is the key check.
+             if not valid_types.isdisjoint(types_for_this_label):
+                 # This image is a valid reference because the selected label exists
+                 # on a Polygon or Rectangle. Add it to the list.
+                 final_filtered_paths.append(path)
+
+         # Directly set the filtered list in the table model.
+         self.image_selection_window.table_model.set_filtered_paths(final_filtered_paths)
+
+     def accept(self):
+         """
+         Validate selections and store them before closing the dialog.
+         """
+         if not self.loaded_model:
+             QMessageBox.warning(self,
+                                 "No Model",
+                                 "A model must be loaded before running predictions.")
+             super().reject()
+             return
+
+         current_label = self.source_label_combo_box.currentData()
+         if not current_label:
+             QMessageBox.warning(self,
+                                 "No Source Label",
+                                 "A source label must be selected.")
+             super().reject()
+             return
+
+         # Get highlighted paths from our internal image window to use as targets
+         highlighted_images = self.image_selection_window.table_model.get_highlighted_paths()
+
+         if not highlighted_images:
+             QMessageBox.warning(self,
+                                 "No Target Images",
+                                 "You must highlight at least one image in the list to process.")
+             super().reject()
+             return
+
+         # Store the selections for the caller to use after the dialog closes.
+         self.source_label = current_label
+         self.target_images = highlighted_images
+
+         # Do not call self.predict here; just close the dialog and let the caller handle prediction
+         super().accept()
+
+     def setup_info_layout(self):
+         """
+         Set up the layout and widgets for the info layout that spans the top.
+         """
+         group_box = QGroupBox("Information")
+         layout = QVBoxLayout()
+
+         # Create a QLabel with explanatory text and hyperlink
+         info_label = QLabel("Choose a Generator to deploy. "
+                             "Select a reference label, then highlight reference images that contain examples. "
+                             "Each additional reference image may increase accuracy but also processing time.")
+
+         info_label.setOpenExternalLinks(True)
+         info_label.setWordWrap(True)
+         layout.addWidget(info_label)
+
+         group_box.setLayout(layout)
+         self.layout.addWidget(group_box)  # Add to main layout so it spans both panels
+
+     def setup_models_layout(self):
+         """
+         Setup the models layout with a simple model selection combo box (no tabs).
+         """
+         group_box = QGroupBox("Model Selection")
+         layout = QVBoxLayout()
+
+         self.model_combo = QComboBox()
+         self.model_combo.setEditable(True)
+
+         # Define available models
+         self.models = [
+             "yoloe-v8s-seg.pt",
+             "yoloe-v8m-seg.pt",
+             "yoloe-v8l-seg.pt",
+             "yoloe-11s-seg.pt",
+             "yoloe-11m-seg.pt",
+             "yoloe-11l-seg.pt",
+         ]
+
+         # Add all models to combo box
+         for model_name in self.models:
+             self.model_combo.addItem(model_name)
+
+         # Set the default model
+         self.model_combo.setCurrentText("yoloe-v8s-seg.pt")
+
+         layout.addWidget(QLabel("Select Model:"))
+         layout.addWidget(self.model_combo)
+
+         group_box.setLayout(layout)
+         self.left_panel.addWidget(group_box)  # Add to left panel
+
+     def setup_parameters_layout(self):
+         """
+         Setup parameter control section in a group box.
+         """
+         group_box = QGroupBox("Parameters")
+         layout = QFormLayout()
+
+         # Task dropdown
+         self.use_task_dropdown = QComboBox()
+         self.use_task_dropdown.addItems(["detect", "segment"])
+         self.use_task_dropdown.currentIndexChanged.connect(self.update_task)
+         layout.addRow("Task:", self.use_task_dropdown)
+
+         # Max detections spinbox
+         self.max_detections_spinbox = QSpinBox()
+         self.max_detections_spinbox.setRange(1, 10000)
+         self.max_detections_spinbox.setValue(self.max_detect)
+         layout.addRow("Max Detections:", self.max_detections_spinbox)
+
+         # Resize image dropdown
+         self.resize_image_dropdown = QComboBox()
+         self.resize_image_dropdown.addItems(["True", "False"])
+         self.resize_image_dropdown.setCurrentIndex(0)
+         self.resize_image_dropdown.setEnabled(False)  # Grey out the dropdown
+         layout.addRow("Resize Image:", self.resize_image_dropdown)
+
+         # Image size control
+         self.imgsz_spinbox = QSpinBox()
+         self.imgsz_spinbox.setRange(512, 65536)
+         self.imgsz_spinbox.setSingleStep(1024)
+         self.imgsz_spinbox.setValue(self.imgsz)
+         layout.addRow("Image Size (imgsz):", self.imgsz_spinbox)
+
+         # Uncertainty threshold controls
+         self.uncertainty_thresh = self.main_window.get_uncertainty_thresh()
+         self.uncertainty_threshold_slider = QSlider(Qt.Horizontal)
+         self.uncertainty_threshold_slider.setRange(0, 100)
+         self.uncertainty_threshold_slider.setValue(int(self.main_window.get_uncertainty_thresh() * 100))
+         self.uncertainty_threshold_slider.setTickPosition(QSlider.TicksBelow)
+         self.uncertainty_threshold_slider.setTickInterval(10)
+         self.uncertainty_threshold_slider.valueChanged.connect(self.update_uncertainty_label)
+         self.uncertainty_threshold_label = QLabel(f"{self.uncertainty_thresh:.2f}")
+         layout.addRow("Uncertainty Threshold", self.uncertainty_threshold_slider)
+         layout.addRow("", self.uncertainty_threshold_label)
+
+         # IoU threshold controls
+         self.iou_thresh = self.main_window.get_iou_thresh()
+         self.iou_threshold_slider = QSlider(Qt.Horizontal)
+         self.iou_threshold_slider.setRange(0, 100)
+         self.iou_threshold_slider.setValue(int(self.iou_thresh * 100))
+         self.iou_threshold_slider.setTickPosition(QSlider.TicksBelow)
+         self.iou_threshold_slider.setTickInterval(10)
+         self.iou_threshold_slider.valueChanged.connect(self.update_iou_label)
+         self.iou_threshold_label = QLabel(f"{self.iou_thresh:.2f}")
+         layout.addRow("IoU Threshold", self.iou_threshold_slider)
+         layout.addRow("", self.iou_threshold_label)
+
+         # Area threshold controls
+         min_val, max_val = self.main_window.get_area_thresh()
+         self.area_thresh_min = int(min_val * 100)
+         self.area_thresh_max = int(max_val * 100)
+         self.area_threshold_min_slider = QSlider(Qt.Horizontal)
+         self.area_threshold_min_slider.setRange(0, 100)
+         self.area_threshold_min_slider.setValue(self.area_thresh_min)
+         self.area_threshold_min_slider.setTickPosition(QSlider.TicksBelow)
+         self.area_threshold_min_slider.setTickInterval(10)
+         self.area_threshold_min_slider.valueChanged.connect(self.update_area_label)
+         self.area_threshold_max_slider = QSlider(Qt.Horizontal)
+         self.area_threshold_max_slider.setRange(0, 100)
+         self.area_threshold_max_slider.setValue(self.area_thresh_max)
+         self.area_threshold_max_slider.setTickPosition(QSlider.TicksBelow)
+         self.area_threshold_max_slider.setTickInterval(10)
+         self.area_threshold_max_slider.valueChanged.connect(self.update_area_label)
+         self.area_threshold_label = QLabel(f"{self.area_thresh_min / 100.0:.2f} - {self.area_thresh_max / 100.0:.2f}")
+         layout.addRow("Area Threshold Min", self.area_threshold_min_slider)
+         layout.addRow("Area Threshold Max", self.area_threshold_max_slider)
+         layout.addRow("", self.area_threshold_label)
+
+         group_box.setLayout(layout)
+         self.left_panel.addWidget(group_box)  # Add to left panel
+
+     def setup_sam_layout(self):
+         """Use SAM model for segmentation."""
+         group_box = QGroupBox("Use SAM Model for Creating Polygons")
+         layout = QFormLayout()
+
+         # SAM dropdown
+         self.use_sam_dropdown = QComboBox()
+         self.use_sam_dropdown.addItems(["False", "True"])
+         self.use_sam_dropdown.currentIndexChanged.connect(self.is_sam_model_deployed)
+         layout.addRow("Use SAM Polygons:", self.use_sam_dropdown)
+
+         group_box.setLayout(layout)
+         self.left_panel.addWidget(group_box)  # Add to left panel
+
+     def setup_model_buttons_layout(self):
+         """
+         Setup action buttons in a group box.
+         """
+         group_box = QGroupBox("Actions")
+         layout = QHBoxLayout()
+
+         load_button = QPushButton("Load Model")
+         load_button.clicked.connect(self.load_model)
+         layout.addWidget(load_button)
+
+         deactivate_button = QPushButton("Deactivate Model")
+         deactivate_button.clicked.connect(self.deactivate_model)
+         layout.addWidget(deactivate_button)
+
+         group_box.setLayout(layout)
+         self.left_panel.addWidget(group_box)  # Add to left panel
+
+     def setup_status_layout(self):
+         """
+         Setup status display in a group box.
+         """
+         group_box = QGroupBox("Status")
+         layout = QVBoxLayout()
+
+         self.status_bar = QLabel("No model loaded")
+         layout.addWidget(self.status_bar)
+
+         group_box.setLayout(layout)
+         self.left_panel.addWidget(group_box)  # Add to left panel
+
+     def setup_source_layout(self):
+         """
+         Set up the layout with source label selection.
+         The source image is implicitly the currently active image.
+         """
+         group_box = QGroupBox("Reference Label")
+         layout = QFormLayout()
+
+         # Create the source label combo box
+         self.source_label_combo_box = QComboBox()
+         self.source_label_combo_box.currentIndexChanged.connect(self.filter_images_by_label_and_type)
+         layout.addRow("Reference Label:", self.source_label_combo_box)
+
+         group_box.setLayout(layout)
+         self.right_panel.addWidget(group_box)  # Add to right panel
+
+     def setup_buttons_layout(self):
+         """
+         Set up the layout with buttons.
+         """
+         # Create a button box for the buttons
+         button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
+         button_box.accepted.connect(self.accept)
+         button_box.rejected.connect(self.reject)
+
+         self.layout.addWidget(button_box)
+
+     def initialize_uncertainty_threshold(self):
+         """Initialize the uncertainty threshold slider with the current value"""
+         current_value = self.main_window.get_uncertainty_thresh()
+         self.uncertainty_threshold_slider.setValue(int(current_value * 100))
+         self.uncertainty_thresh = current_value
+
+     def initialize_iou_threshold(self):
+         """Initialize the IoU threshold slider with the current value"""
+         current_value = self.main_window.get_iou_thresh()
+         self.iou_threshold_slider.setValue(int(current_value * 100))
+         self.iou_thresh = current_value
+
+     def initialize_area_threshold(self):
+         """Initialize the area threshold range slider"""
+         current_min, current_max = self.main_window.get_area_thresh()
+         self.area_threshold_min_slider.setValue(int(current_min * 100))
+         self.area_threshold_max_slider.setValue(int(current_max * 100))
+         self.area_thresh_min = current_min
+         self.area_thresh_max = current_max
+
+     def update_uncertainty_label(self, value):
+         """Update uncertainty threshold and label"""
+         value = value / 100.0
+         self.uncertainty_thresh = value
+         self.main_window.update_uncertainty_thresh(value)
+         self.uncertainty_threshold_label.setText(f"{value:.2f}")
+
+     def update_iou_label(self, value):
+         """Update IoU threshold and label"""
+         value = value / 100.0
+         self.iou_thresh = value
+         self.main_window.update_iou_thresh(value)
+         self.iou_threshold_label.setText(f"{value:.2f}")
+
+     def update_area_label(self):
+         """Handle changes to area threshold range slider"""
+         min_val = self.area_threshold_min_slider.value()
+         max_val = self.area_threshold_max_slider.value()
+         if min_val > max_val:
+             min_val = max_val
+             self.area_threshold_min_slider.setValue(min_val)
+         self.area_thresh_min = min_val / 100.0
+         self.area_thresh_max = max_val / 100.0
+         self.main_window.update_area_thresh(self.area_thresh_min, self.area_thresh_max)
+         self.area_threshold_label.setText(f"{self.area_thresh_min:.2f} - {self.area_thresh_max:.2f}")
+
+     def get_max_detections(self):
+         """Get the maximum number of detections to return."""
+         self.max_detect = self.max_detections_spinbox.value()
+         return self.max_detect
+
+     def is_sam_model_deployed(self):
+         """
+         Check if the SAM model is deployed and update the checkbox state accordingly.
+
+         :return: Boolean indicating whether the SAM model is deployed
+         """
+         if not hasattr(self.main_window, 'sam_deploy_predictor_dialog'):
+             return False
+
+         self.sam_dialog = self.main_window.sam_deploy_predictor_dialog
+
+         if not self.sam_dialog.loaded_model:
+             self.use_sam_dropdown.setCurrentText("False")
+             QMessageBox.critical(self, "Error", "Please deploy the SAM model first.")
+             return False
+
+         return True
+
+     def update_sam_task_state(self):
+         """
+         Centralized method to check if SAM is loaded and update task accordingly.
+         If the user has selected to use SAM, this function ensures the task is set to 'segment'.
+         Crucially, it does NOT alter the task if SAM is not selected, respecting the
+         user's choice from the 'Task' dropdown.
+         """
+         # Check if the user wants to use the SAM model
+         if self.use_sam_dropdown.currentText() == "True":
+             # SAM is requested. Check if it's actually available.
+             sam_is_available = (
+                 hasattr(self, 'sam_dialog') and
+                 self.sam_dialog is not None and
+                 self.sam_dialog.loaded_model is not None
+             )
+
+             if sam_is_available:
+                 # If SAM is wanted and available, the task must be segmentation.
+                 self.task = 'segment'
+             else:
+                 # If SAM is wanted but not available, revert the dropdown and do nothing else.
+                 # The 'is_sam_model_deployed' function already handles showing an error message.
+                 self.use_sam_dropdown.setCurrentText("False")
+
+         # If use_sam_dropdown is "False", do nothing. Let self.task be whatever the user set.
+
+     def update_task(self):
+         """Update the task based on the dropdown selection and handle UI/model effects."""
+         self.task = self.use_task_dropdown.currentText()
+
+         # Update UI elements based on task
+         if self.task == "segment":
+             # Deactivate model if one is loaded and we're switching to segment task
+             if self.loaded_model:
+                 self.deactivate_model()
+
+     def update_source_labels(self):
+         """
+         Updates the source label combo box with labels that are associated with
+         valid reference annotations (Polygons or Rectangles), using the fast cache.
+         """
+         self.source_label_combo_box.blockSignals(True)
+
+         try:
+             self.source_label_combo_box.clear()
+
+             dialog_manager = self.image_selection_window.raster_manager
+             valid_types = {"RectangleAnnotation", "PolygonAnnotation"}
+             valid_labels = set()  # This will store the full Label objects
+
+             # Create a lookup map to get full label objects from their codes
+             all_project_labels = {lbl.short_label_code: lbl for lbl in self.main_window.label_window.labels}
+
+             # Use the cached data to find all labels that have valid reference types.
+             for raster in dialog_manager.rasters.values():
+                 # raster.label_to_types_map is like: {'coral': {'Point'}, 'rock': {'Polygon'}}
+                 for label_code, types_for_label in raster.label_to_types_map.items():
+                     # Check if the set of types for this specific label
+                     # has any overlap with our valid reference types.
+                     if not valid_types.isdisjoint(types_for_label):
+                         # This label is a valid reference label.
+                         # Add its full Label object to our set of valid labels.
+                         if label_code in all_project_labels:
+                             valid_labels.add(all_project_labels[label_code])
+
+             # Add the valid labels to the combo box, sorted alphabetically.
+             sorted_valid_labels = sorted(list(valid_labels), key=lambda x: x.short_label_code)
+             for label_obj in sorted_valid_labels:
+                 self.source_label_combo_box.addItem(label_obj.short_label_code, label_obj)
+
+             # Restore the last selected label if it's still present in the list.
+             if self.last_selected_label_code:
+                 index = self.source_label_combo_box.findText(self.last_selected_label_code)
+                 if index != -1:
+                     self.source_label_combo_box.setCurrentIndex(index)
+         finally:
+             self.source_label_combo_box.blockSignals(False)
+
+         # Manually trigger the filtering now that the combo box is stable.
+         self.filter_images_by_label_and_type()
+
+         return True
+
+     def get_source_annotations(self, reference_label, reference_image_path):
+         """
+         Return a list of bboxes and masks for a specific image
+         belonging to the selected label.
+
+         :param reference_label: The Label object to filter annotations by.
+         :param reference_image_path: The path of the image to get annotations from.
+         :return: A tuple containing a numpy array of bboxes and a list of masks.
+         """
+         if not all([reference_label, reference_image_path]):
+             return np.array([]), []
+
+         # Get all annotations for the specified image
+         annotations = self.annotation_window.get_image_annotations(reference_image_path)
+
+         # Filter annotations by the provided label
+         source_bboxes = []
+         source_masks = []
+         for annotation in annotations:
+             if annotation.label.short_label_code == reference_label.short_label_code:
+                 if isinstance(annotation, (PolygonAnnotation, RectangleAnnotation)):
+                     bbox = annotation.cropped_bbox
+                     source_bboxes.append(bbox)
+                     if isinstance(annotation, PolygonAnnotation):
+                         points = np.array([[p.x(), p.y()] for p in annotation.points])
+                         source_masks.append(points)
+                     elif isinstance(annotation, RectangleAnnotation):
+                         x1, y1, x2, y2 = bbox
+                         rect_points = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
+                         source_masks.append(rect_points)
+
+         return np.array(source_bboxes), source_masks
+
+     def load_model(self):
+         """
+         Load the selected model.
+         """
+         QApplication.setOverrideCursor(Qt.WaitCursor)
+         progress_bar = ProgressBar(self.annotation_window, title="Loading Model")
+         progress_bar.show()
+
+         try:
+             # Get selected model path and download weights if needed
+             self.model_path = self.model_combo.currentText()
+
+             # Load model using registry
+             self.loaded_model = YOLOE(self.model_path).to(self.main_window.device)
+
+             # Create a dummy visual dictionary
+             visuals = dict(
+                 bboxes=np.array(
+                     [
+                         [120, 425, 160, 445],
+                     ],
+                 ),
+                 cls=np.array(
+                     np.zeros(1),
+                 ),
+             )
+
+             # Run a dummy prediction to load the model
+             self.loaded_model.predict(
+                 np.zeros((640, 640, 3), dtype=np.uint8),
+                 visual_prompts=visuals.copy(),
+                 predictor=YOLOEVPDetectPredictor,
+                 imgsz=640,
+                 conf=0.99,
+             )
+
+             progress_bar.finish_progress()
+             self.status_bar.setText("Model loaded")
+             QMessageBox.information(self.annotation_window,
+                                     "Model Loaded",
+                                     "Model loaded successfully")
+
+         except Exception as e:
+             QMessageBox.critical(self.annotation_window,
+                                  "Error Loading Model",
+                                  f"Error loading model: {e}")
+
+         finally:
+             # Restore cursor
+             QApplication.restoreOverrideCursor()
+             # Stop the progress bar
+             progress_bar.stop_progress()
+             progress_bar.close()
+             progress_bar = None
+
+     def predict(self, image_paths=None):
+         """
+         Make predictions on the given image paths using the loaded model.
+
+         Args:
+             image_paths: List of image paths to process. If None, uses the current image.
+         """
+         if not self.loaded_model or not self.source_label:
+             return
+
+         # Update class mapping with the selected reference label
+         self.class_mapping = {0: self.source_label}
+
+         # Create a results processor
+         results_processor = ResultsProcessor(
+             self.main_window,
+             self.class_mapping,
+             uncertainty_thresh=self.main_window.get_uncertainty_thresh(),
+             iou_thresh=self.main_window.get_iou_thresh(),
+             min_area_thresh=self.main_window.get_area_thresh_min(),
+             max_area_thresh=self.main_window.get_area_thresh_max()
+         )
+
+         if not image_paths:
+             # Predict only the current image
+             image_paths = [self.annotation_window.current_image_path]
+
+         # Make cursor busy
+         QApplication.setOverrideCursor(Qt.WaitCursor)
+
+         # Start the progress bar
+         progress_bar = ProgressBar(self.annotation_window, title="Prediction Workflow")
+         progress_bar.show()
+         progress_bar.start_progress(len(image_paths))
+
+         try:
+             for image_path in image_paths:
+                 inputs = self._get_inputs(image_path)
+                 if inputs is None:
+                     continue
+
+                 results = self._apply_model(inputs)
+                 results = self._apply_sam(results, image_path)
+                 self._process_results(results_processor, results, image_path)
+
+                 # Update the progress bar
+                 progress_bar.update_progress()
+
+         except Exception as e:
+             print("An error occurred during prediction:", e)
+         finally:
+             QApplication.restoreOverrideCursor()
+             progress_bar.finish_progress()
+             progress_bar.stop_progress()
+             progress_bar.close()
+
+         gc.collect()
+         empty_cache()
+
+     def _get_inputs(self, image_path):
+         """Get the inputs for the model prediction."""
+         raster = self.image_window.raster_manager.get_raster(image_path)
+         if self.annotation_window.get_selected_tool() != "work_area":
+             # Use the image path
+             work_areas_data = [raster.image_path]
+         else:
+             # Get the work areas
+             work_areas_data = raster.get_work_areas_data()
+
+         return work_areas_data
+
+     def _apply_model(self, inputs):
+         """
+         Apply the model to the target inputs, using each highlighted source
+         image as an individual reference for a separate prediction run.
+         """
+         # Update the model with user parameters
+         self.loaded_model.conf = self.main_window.get_uncertainty_thresh()
+         self.loaded_model.iou = self.main_window.get_iou_thresh()
+         self.loaded_model.max_det = self.get_max_detections()
+
+         # NOTE: self.target_images contains the reference images highlighted in the dialog
+         reference_image_paths = self.target_images
+
+         if not reference_image_paths:
+             QMessageBox.warning(self,
+                                 "No Reference Images",
+                                 "You must highlight at least one reference image.")
+             return []
+
+         # Get the selected reference label from the stored variable
+         source_label = self.source_label
+
+         # Create a dictionary of reference annotations, with image path as the key.
+         reference_annotations_dict = {}
+         for path in reference_image_paths:
+             bboxes, masks = self.get_source_annotations(source_label, path)
+             if bboxes.size > 0:
+                 reference_annotations_dict[path] = {
+                     'bboxes': bboxes,
+                     'masks': masks,
+                     'cls': np.zeros(len(bboxes))
+                 }
+
+         # Set the task
+         self.task = self.use_task_dropdown.currentText()
+         predictor = YOLOEVPSegPredictor if self.task == "segment" else YOLOEVPDetectPredictor
+
+         # Create a progress bar for iterating through reference images
+         QApplication.setOverrideCursor(Qt.WaitCursor)
+         progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
+         progress_bar.show()
+         progress_bar.start_progress(len(reference_annotations_dict))
+
+         results_list = []
+         # The 'inputs' list contains work areas from the single target image.
+         # We will predict on the first work area/full image.
+         input_image = inputs[0]
+
+         # Iterate through each reference image and its annotations
+         for ref_path, ref_annotations in reference_annotations_dict.items():
+             # The 'refer_image' parameter is the path to the current reference image
+             # The 'visual_prompts' are the annotations from that same reference image
+             visuals = {
+                 'bboxes': ref_annotations['bboxes'],
+                 'cls': ref_annotations['cls'],
+             }
+             if self.task == 'segment':
+                 visuals['masks'] = ref_annotations['masks']
+
+             # Make predictions on the target using the current reference
+             results = self.loaded_model.predict(input_image,
+                                                 refer_image=ref_path,
+                                                 visual_prompts=visuals,
+                                                 predictor=predictor,
+                                                 imgsz=self.imgsz_spinbox.value(),
+                                                 conf=self.main_window.get_uncertainty_thresh(),
+                                                 iou=self.main_window.get_iou_thresh(),
+                                                 max_det=self.get_max_detections(),
+                                                 retina_masks=self.task == "segment")
+
+             if not len(results[0].boxes):
+                 # If no boxes were detected, skip to the next reference
+                 progress_bar.update_progress()
+                 continue
+
+             # Update the name of the results and append to the list
+             results[0].names = {0: self.class_mapping[0].short_label_code}
+             results_list.extend(results[0])
+
+             progress_bar.update_progress()
+             gc.collect()
+             empty_cache()
+
+         # Clean up
+         QApplication.restoreOverrideCursor()
+         progress_bar.finish_progress()
+         progress_bar.stop_progress()
+         progress_bar.close()
+
+         # Combine results if there are any
+         combined_results = CombineResults().combine_results(results_list)
+         if combined_results is None:
+             return []
+
+         return [[combined_results]]
+
+     def _apply_sam(self, results_list, image_path):
+         """Apply SAM to the results if needed."""
+         # Check if SAM model is deployed and loaded
+         self.update_sam_task_state()
+         if self.task != 'segment':
+             return results_list
+
+         if not self.sam_dialog or self.use_sam_dropdown.currentText() == "False":
+             # If SAM is not deployed or not selected, return the results as is
+             return results_list
+
+         if self.sam_dialog.loaded_model is None:
+             # If SAM is not loaded, ensure we do not use it accidentally
+             self.task = 'detect'
+             self.use_sam_dropdown.setCurrentText("False")
+             return results_list
+
+         # Make cursor busy
+         QApplication.setOverrideCursor(Qt.WaitCursor)
+         progress_bar = ProgressBar(self.annotation_window, title="Predicting with SAM")
+         progress_bar.show()
+         progress_bar.start_progress(len(results_list))
+
+         updated_results = []
+
+         for idx, results in enumerate(results_list):
+             # Each item in results_list is itself a list of Results objects, i.e. [[...], ...]
+             if results:
+                 # Run it through the SAM model
+                 results = self.sam_dialog.predict_from_results(results, image_path)
+                 updated_results.append(results)
+
+             # Update the progress bar
+             progress_bar.update_progress()
+
+         # Make cursor normal
+         QApplication.restoreOverrideCursor()
+         progress_bar.finish_progress()
+         progress_bar.stop_progress()
+         progress_bar.close()
+
+         return updated_results
+
+     def _process_results(self, results_processor, results_list, image_path):
+         """Process the results using the result processor."""
+         # Get the raster object and number of work items
+         raster = self.image_window.raster_manager.get_raster(image_path)
+         total = raster.count_work_items()
+
+         # Get the work areas (if any)
+         work_areas = raster.get_work_areas()
+
+         # Start the progress bar
+         progress_bar = ProgressBar(self.annotation_window, title="Processing Results")
+         progress_bar.show()
+         progress_bar.start_progress(total)
+
+         updated_results = []
+
+         for idx, results in enumerate(results_list):
+             # Each item in results_list is itself a list of Results objects, i.e. [[...], ...]
+             if results:
+                 # Update path and names
+                 results[0].path = image_path
+                 results[0].names = {0: self.class_mapping[0].short_label_code}
+                 # This needs to be done again, in case SAM was used
+
+                 # Check if the work area is valid, or the image path is being used
+                 if work_areas and self.annotation_window.get_selected_tool() == "work_area":
+                     # Map results from work area to the full image
+                     results = MapResults().map_results_from_work_area(results[0],
+                                                                       raster,
+                                                                       work_areas[idx],
+                                                                       self.task == "segment")
+                 else:
+                     results = results[0]
+
+                 # Append the result object (not a list) to the updated results list
+                 updated_results.append(results)
+
+             # Update the index for the next work area
+             idx += 1
+             progress_bar.update_progress()
+
+         # Process the Results
+         if self.task == 'segment' or self.use_sam_dropdown.currentText() == "True":
+             results_processor.process_segmentation_results(updated_results)
+         else:
+             results_processor.process_detection_results(updated_results)
+
+         # Close the progress bar
+         progress_bar.finish_progress()
+         progress_bar.stop_progress()
+         progress_bar.close()
+
+     def deactivate_model(self):
+         """
+         Deactivate the currently loaded model and clean up resources.
+         """
+         self.loaded_model = None
+         self.model_path = None
+         # Clean up resources
+         gc.collect()
+         torch.cuda.empty_cache()
+         # Untoggle all tools
+         self.main_window.untoggle_all_tools()
+         # Update status bar
+         self.status_bar.setText("No model loaded")
+         QMessageBox.information(self, "Model Deactivated", "Model deactivated")
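
Note: as the accept() method above shows, the dialog only stores the selected reference label and highlighted images; prediction is left to the caller. A minimal sketch of how a caller might drive the new class, assuming a main_window object exposing the attributes the dialog reads (label_window, image_window, annotation_window, device); the call site and names below are illustrative and not part of this release:

    from coralnet_toolbox.SeeAnything.QtDeployGenerator import DeployGeneratorDialog

    # Hypothetical caller -- not part of the diff above.
    dialog = DeployGeneratorDialog(main_window)
    dialog.load_model()              # normally triggered by the "Load Model" button
    if dialog.exec_():               # accept() stores source_label and target_images
        # The dialog does not predict on accept; the caller runs the workflow,
        # passing the image paths it wants annotated (defaults to the current image).
        dialog.predict([main_window.annotation_window.current_image_path])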