coralnet-toolbox 0.0.72__py2.py3-none-any.whl → 0.0.74__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
  2. coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
  3. coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
  4. coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
  5. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
  6. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
  7. coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
  8. coralnet_toolbox/CoralNet/QtDownload.py +2 -1
  9. coralnet_toolbox/Explorer/QtDataItem.py +1 -1
  10. coralnet_toolbox/Explorer/QtExplorer.py +159 -17
  11. coralnet_toolbox/Explorer/QtSettingsWidgets.py +160 -86
  12. coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
  13. coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
  14. coralnet_toolbox/IO/QtOpenProject.py +46 -78
  15. coralnet_toolbox/IO/QtSaveProject.py +18 -43
  16. coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
  17. coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
  18. coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
  19. coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
  20. coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
  21. coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
  22. coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +42 -22
  23. coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
  24. coralnet_toolbox/QtAnnotationWindow.py +42 -14
  25. coralnet_toolbox/QtEventFilter.py +19 -2
  26. coralnet_toolbox/QtImageWindow.py +134 -86
  27. coralnet_toolbox/QtLabelWindow.py +14 -2
  28. coralnet_toolbox/QtMainWindow.py +122 -9
  29. coralnet_toolbox/QtProgressBar.py +52 -27
  30. coralnet_toolbox/Rasters/QtRaster.py +59 -7
  31. coralnet_toolbox/Rasters/RasterTableModel.py +42 -14
  32. coralnet_toolbox/SAM/QtBatchInference.py +0 -2
  33. coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
  34. coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
  35. coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
  36. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1634 -0
  37. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +107 -154
  38. coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
  39. coralnet_toolbox/SeeAnything/__init__.py +2 -0
  40. coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
  41. coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
  42. coralnet_toolbox/Tools/QtSAMTool.py +222 -57
  43. coralnet_toolbox/Tools/QtSeeAnythingTool.py +223 -55
  44. coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
  45. coralnet_toolbox/Tools/QtSelectTool.py +27 -3
  46. coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
  47. coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
  48. coralnet_toolbox/Tools/__init__.py +2 -0
  49. coralnet_toolbox/__init__.py +1 -1
  50. coralnet_toolbox/utilities.py +137 -47
  51. coralnet_toolbox-0.0.74.dist-info/METADATA +375 -0
  52. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/RECORD +56 -53
  53. coralnet_toolbox-0.0.72.dist-info/METADATA +0 -341
  54. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/WHEEL +0 -0
  55. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/entry_points.txt +0 -0
  56. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/licenses/LICENSE.txt +0 -0
  57. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1634 @@
1
+ import warnings
2
+
3
+ import os
4
+ import gc
5
+
6
+ import numpy as np
7
+ from sklearn.decomposition import PCA
8
+
9
+ import torch
10
+ from torch.cuda import empty_cache
11
+
12
+ import pyqtgraph as pg
13
+ from pyqtgraph.Qt import QtGui
14
+
15
+ from ultralytics import YOLOE
16
+ from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
17
+
18
+ from PyQt5.QtCore import Qt
19
+ from PyQt5.QtWidgets import (QMessageBox, QVBoxLayout, QApplication, QFileDialog,
20
+ QLabel, QDialog, QDialogButtonBox, QGroupBox, QLineEdit,
21
+ QFormLayout, QComboBox, QSpinBox, QSlider, QPushButton,
22
+ QHBoxLayout)
23
+
24
+ from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
25
+ from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
26
+
27
+ from coralnet_toolbox.Results import ResultsProcessor
28
+ from coralnet_toolbox.Results import MapResults
29
+ from coralnet_toolbox.Results import CombineResults
30
+
31
+ from coralnet_toolbox.QtProgressBar import ProgressBar
32
+ from coralnet_toolbox.QtImageWindow import ImageWindow
33
+
34
+ from coralnet_toolbox.Icons import get_icon
35
+
36
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
37
+ warnings.filterwarnings("ignore", category=UserWarning)
38
+
39
+
40
+ # ----------------------------------------------------------------------------------------------------------------------
41
+ # Classes
42
+ # ----------------------------------------------------------------------------------------------------------------------
43
+
44
+
45
+ class DeployGeneratorDialog(QDialog):
46
+ """
47
+ Perform See Anything (YOLOE) on multiple images using a reference image and label.
48
+
49
+ :param main_window: MainWindow object
50
+ :param parent: Parent widget
51
+ """
52
def __init__(self, main_window, parent=None):
    """
    Build the See Anything (YOLOE) generator dialog.

    Layout: an information box across the top, a left panel (model selection,
    parameters, SAM option, action buttons, status), a right panel (reference
    label/method plus an embedded ImageWindow used to pick reference images),
    and OK/Cancel buttons at the bottom.

    :param main_window: MainWindow object
    :param parent: Parent widget
    """
    super().__init__(parent)
    self.main_window = main_window
    self.label_window = main_window.label_window
    self.image_window = main_window.image_window
    self.annotation_window = main_window.annotation_window
    self.sam_dialog = None

    self.setWindowIcon(get_icon("eye.png"))
    self.setWindowTitle("See Anything (YOLOE) Generator (Ctrl + 5)")
    self.resize(800, 800)  # Large enough for the side-by-side panels

    self.deploy_model_dialog = None
    self.last_selected_label_code = None

    # Default parameters; the thresholds are replaced with the main window's
    # values when the parameter widgets are built and again on every show.
    self.imgsz = 1024
    self.iou_thresh = 0.20
    self.uncertainty_thresh = 0.30
    self.area_thresh_min = 0.00
    self.area_thresh_max = 0.40

    self.task = 'detect'
    self.max_detect = 300
    self.loaded_model = None
    self.model_path = None
    self.class_mapping = {}

    # Reference label and the image paths highlighted as references
    self.reference_label = None
    self.reference_image_paths = []

    # Visual Prompt Encoding (VPE) - legacy single-tensor fields
    self.vpe_path = None
    self.vpe = None

    # VPE collections, kept separate by origin
    self.imported_vpes = []   # VPEs loaded from file
    self.reference_vpes = []  # VPEs created from reference images

    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Main vertical layout. NOTE: this attribute shadows QWidget.layout();
    # other methods (setup_info_layout, setup_buttons_layout) rely on it.
    self.layout = QVBoxLayout(self)

    # Information box spanning the full width at the top
    self.setup_info_layout()

    # Two side-by-side panels below the information box
    self.horizontal_layout = QHBoxLayout()
    self.layout.addLayout(self.horizontal_layout)

    self.left_panel = QVBoxLayout()
    self.horizontal_layout.addLayout(self.left_panel)

    self.right_panel = QVBoxLayout()
    self.horizontal_layout.addLayout(self.right_panel)

    # Left panel, top to bottom
    self.setup_models_layout()
    self.setup_parameters_layout()
    self.setup_sam_layout()
    self.setup_model_buttons_layout()
    self.setup_status_layout()

    # Right panel: reference controls, then a full ImageWindow used only
    # for choosing reference images
    self.setup_reference_layout()
    self.image_selection_window = ImageWindow(self.main_window)
    self.right_panel.addWidget(self.image_selection_window)

    # OK / Cancel along the bottom
    self.setup_buttons_layout()
129
+
130
def configure_image_window_for_dialog(self):
    """
    Lock down the embedded ImageWindow so it acts purely as a selection list.

    The filter checkboxes are forced to "has annotations only" and greyed out,
    the search widgets are disabled, the "Current" label is hidden, Top-K is
    pinned to "Top1", double-clicks are disconnected, and the ImageWindow's
    automatic loading of the first filtered image becomes a no-op.
    """
    iw = self.image_selection_window

    # (checkbox, desired checked state); only "has annotations" stays on.
    filter_states = [
        (iw.highlighted_checkbox, False),
        (iw.has_predictions_checkbox, False),
        (iw.no_annotations_checkbox, False),
        (iw.has_annotations_checkbox, True),
    ]
    checkboxes = [box for box, _ in filter_states]

    # Block signals so setChecked() cannot trigger the ImageWindow's own
    # filtering logic; this dialog performs filtering itself.
    for box in checkboxes:
        box.blockSignals(True)
    for box in checkboxes:
        box.setEnabled(False)
    for box, checked in filter_states:
        box.setChecked(checked)
    for box in checkboxes:
        box.blockSignals(False)

    # Search controls are not applicable in this dialog.
    search_widgets = (
        iw.home_button,
        iw.image_search_button,
        iw.label_search_button,
        iw.search_bar_images,
        iw.search_bar_labels,
        iw.top_k_combo,
    )
    for widget in search_widgets:
        widget.setEnabled(False)

    iw.current_image_index_label.hide()
    iw.top_k_combo.setCurrentText("Top1")

    # This dialog only selects images; double-clicking must not load an
    # image into the main window.
    try:
        iw.tableView.doubleClicked.disconnect()
    except TypeError:
        # Raised when no slot is connected (e.g. on a repeated call).
        pass

    # Prevent the ImageWindow from auto-loading the first filtered image.
    iw.load_first_filtered_image = lambda: None
186
+
187
def showEvent(self, event):
    """
    Refresh the dialog each time it becomes visible.

    Thresholds are re-read from the main window, the embedded image list is
    locked down and re-synced with the main window's rasters, and finally the
    reference-label dropdown is rebuilt (which also re-filters the images).

    :param event: Show event
    """
    super().showEvent(event)

    for initialize in (self.initialize_uncertainty_threshold,
                       self.initialize_iou_threshold,
                       self.initialize_area_threshold):
        initialize()

    self.configure_image_window_for_dialog()
    # Rasters must be synced first: rebuilding the label dropdown ends by
    # filtering the image list over those rasters.
    self.sync_image_window()
    self.update_reference_labels()
205
+
206
def sync_image_window(self):
    """
    Point the dialog's raster manager at the main window's rasters.

    The main manager's ``rasters`` dict is shallow-copied. The dialog gets its
    own dict, but it shares the same Raster objects. Their precomputed
    per-label data (e.g. ``label_to_types_map``, used for filtering) is
    therefore reused as-is, and ``image_paths`` is rebuilt from the copied
    dict's keys.
    """
    source_manager = self.main_window.image_window.raster_manager
    target_manager = self.image_selection_window.raster_manager

    rasters = source_manager.rasters.copy()
    target_manager.rasters = rasters
    target_manager.image_paths = [path for path in rasters]
226
+
227
def filter_images_by_label_and_type(self):
    """
    Restrict the embedded image list to valid reference images for the selected label.

    An image qualifies when its raster's precomputed ``label_to_types_map``
    shows at least one RectangleAnnotation or PolygonAnnotation carrying the
    selected reference label. No annotations are scanned.

    Side effects:
    - remembers the currently highlighted paths so a filter change does not
      discard them;
    - records the selected label object and its text;
    - re-highlights previously selected images that are still listed;
    - refreshes the total/current/highlighted count labels, including when no
      label is selected and the list is empty.
    """
    iw = self.image_selection_window
    table_model = iw.table_model

    # Keep the user's highlights so changing the label does not lose them.
    current_highlights = table_model.get_highlighted_paths()
    if current_highlights:
        self.reference_image_paths = current_highlights

    reference_label = self.reference_label_combo_box.currentData()
    reference_label_text = self.reference_label_combo_box.currentText()

    # Remember the choice so it can be restored when the dialog is reopened.
    if reference_label_text:
        self.last_selected_label_code = reference_label_text
        self.reference_label = reference_label

    valid_types = {"RectangleAnnotation", "PolygonAnnotation"}
    final_filtered_paths = []

    # With no label selected (e.g. an empty dropdown during initialization)
    # the list is simply left empty.
    if reference_label:
        selected_label_code = reference_label.short_label_code
        raster_manager = iw.raster_manager
        for path in raster_manager.image_paths:
            raster = raster_manager.get_raster(path)
            if not raster:
                continue
            # Types of annotation that use the selected label on this image.
            types_for_label = raster.label_to_types_map.get(selected_label_code, set())
            if not valid_types.isdisjoint(types_for_label):
                final_filtered_paths.append(path)

    table_model.set_filtered_paths(final_filtered_paths)

    # Re-highlight earlier selections that survived the new filter.
    previous_selection = getattr(self, 'reference_image_paths', None)
    if final_filtered_paths and previous_selection:
        visible = set(final_filtered_paths)  # O(1) membership checks
        still_valid = [p for p in previous_selection if p in visible]
        if still_valid:
            table_model.set_highlighted_paths(still_valid)

    # Refresh the counters; "Total" reflects the filtered count.
    iw.update_image_count_label(len(final_filtered_paths))
    iw.update_current_image_index_label()
    iw.update_highlighted_count_label()
293
+
294
def accept(self):
    """
    Validate the dialog state and close only when a prediction can run.

    Requirements, checked in order:
    1. a model is loaded;
    2. a reference label is selected (stored in ``self.reference_label``);
    3. at least one VPE source exists: highlighted reference images (captured
       from the UI just before the check) or VPEs imported from file.

    A warning is shown and the dialog stays open at the first unmet requirement.
    """
    if not self.loaded_model:
        QMessageBox.warning(self,
                            "No Model",
                            "A model must be loaded before running predictions.")
        return

    self.reference_label = self.reference_label_combo_box.currentData()
    if not self.reference_label:
        QMessageBox.warning(self,
                            "No Reference Label",
                            "A reference label must be selected.")
        return

    # Capture the current highlights so the check below sees the latest selection.
    self.update_stashed_references_from_ui()

    if not (self.reference_image_paths or self.imported_vpes):
        QMessageBox.warning(self,
                            "No VPE Source Provided",
                            "You must highlight at least one reference image or load a VPE file to proceed.")
        return

    super().accept()
329
+
330
def setup_info_layout(self):
    """Add the "Information" group box with usage instructions; it spans both panels."""
    info_label = QLabel("Choose a Generator to deploy. "
                        "Select a reference label, then highlight reference images that contain examples. "
                        "Each additional reference image may increase accuracy but also processing time.")
    info_label.setOpenExternalLinks(True)
    info_label.setWordWrap(True)

    info_layout = QVBoxLayout()
    info_layout.addWidget(info_label)

    info_group = QGroupBox("Information")
    info_group.setLayout(info_layout)
    # Added to the main vertical layout (not a panel) so it spans the full width.
    self.layout.addWidget(info_group)
348
+
349
def setup_models_layout(self):
    """
    Add the "Model Selection" group box to the left panel.

    It contains an editable dropdown of YOLOE checkpoints (default
    'yoloe-v8s-seg.pt') and a "Custom VPE" path field with a "Browse..."
    button connected to browse_vpe_file().
    """
    # Selectable YOLOE checkpoints. The combo is editable, so other names
    # can also be typed.
    self.models = [
        'yoloe-v8s-seg.pt',
        'yoloe-v8m-seg.pt',
        'yoloe-v8l-seg.pt',
        'yoloe-11s-seg.pt',
        'yoloe-11m-seg.pt',
        'yoloe-11l-seg.pt',
    ]

    self.model_combo = QComboBox()
    self.model_combo.setEditable(True)
    self.model_combo.addItems(self.models)
    self.model_combo.setCurrentIndex(self.models.index('yoloe-v8s-seg.pt'))

    self.vpe_path_edit = QLineEdit()
    browse_button = QPushButton("Browse...")
    browse_button.clicked.connect(self.browse_vpe_file)

    vpe_path_layout = QHBoxLayout()
    vpe_path_layout.addWidget(self.vpe_path_edit)
    vpe_path_layout.addWidget(browse_button)

    model_form = QFormLayout()
    model_form.addRow(QLabel("Models:"), self.model_combo)
    model_form.addRow("Custom VPE:", vpe_path_layout)

    model_group = QGroupBox("Model Selection")
    model_group.setLayout(model_form)
    self.left_panel.addWidget(model_group)
390
+
391
def setup_parameters_layout(self):
    """
    Add the "Parameters" group box to the left panel.

    Widgets:
    - task dropdown ("detect"/"segment");
    - max-detections spinbox;
    - a disabled "Resize Image" dropdown fixed at "True";
    - the imgsz spinbox;
    - sliders for the uncertainty, IoU and area (min/max) thresholds.

    The threshold sliders use an integer 0-100 scale for the fractions
    0.00-1.00 and are seeded from the main window's current values. The
    ``*_thresh*`` attributes are always stored as fractions.
    """
    group_box = QGroupBox("Parameters")
    layout = QFormLayout()

    # Task dropdown
    self.use_task_dropdown = QComboBox()
    self.use_task_dropdown.addItems(["detect", "segment"])
    self.use_task_dropdown.currentIndexChanged.connect(self.update_task)
    layout.addRow("Task:", self.use_task_dropdown)

    # Max detections spinbox
    self.max_detections_spinbox = QSpinBox()
    self.max_detections_spinbox.setRange(1, 10000)
    self.max_detections_spinbox.setValue(self.max_detect)
    layout.addRow("Max Detections:", self.max_detections_spinbox)

    # Resize image dropdown (disabled, so it stays at "True")
    self.resize_image_dropdown = QComboBox()
    self.resize_image_dropdown.addItems(["True", "False"])
    self.resize_image_dropdown.setCurrentIndex(0)
    self.resize_image_dropdown.setEnabled(False)
    layout.addRow("Resize Image:", self.resize_image_dropdown)

    # Image size control
    self.imgsz_spinbox = QSpinBox()
    self.imgsz_spinbox.setRange(512, 65536)
    self.imgsz_spinbox.setSingleStep(1024)
    self.imgsz_spinbox.setValue(self.imgsz)
    layout.addRow("Image Size (imgsz):", self.imgsz_spinbox)

    # Slider positions are integer percents. round() is used instead of int()
    # so that e.g. 0.29 maps to 29 rather than 28 (0.29 * 100 == 28.999...).
    # Each slider is given its value BEFORE its slot is connected, so the
    # initial setValue() does not call back into the main window.

    # Uncertainty threshold controls
    self.uncertainty_thresh = self.main_window.get_uncertainty_thresh()
    self.uncertainty_threshold_slider = QSlider(Qt.Horizontal)
    self.uncertainty_threshold_slider.setRange(0, 100)
    self.uncertainty_threshold_slider.setValue(round(self.uncertainty_thresh * 100))
    self.uncertainty_threshold_slider.setTickPosition(QSlider.TicksBelow)
    self.uncertainty_threshold_slider.setTickInterval(10)
    self.uncertainty_threshold_slider.valueChanged.connect(self.update_uncertainty_label)
    self.uncertainty_threshold_label = QLabel(f"{self.uncertainty_thresh:.2f}")
    layout.addRow("Uncertainty Threshold", self.uncertainty_threshold_slider)
    layout.addRow("", self.uncertainty_threshold_label)

    # IoU threshold controls
    self.iou_thresh = self.main_window.get_iou_thresh()
    self.iou_threshold_slider = QSlider(Qt.Horizontal)
    self.iou_threshold_slider.setRange(0, 100)
    self.iou_threshold_slider.setValue(round(self.iou_thresh * 100))
    self.iou_threshold_slider.setTickPosition(QSlider.TicksBelow)
    self.iou_threshold_slider.setTickInterval(10)
    self.iou_threshold_slider.valueChanged.connect(self.update_iou_label)
    self.iou_threshold_label = QLabel(f"{self.iou_thresh:.2f}")
    layout.addRow("IoU Threshold", self.iou_threshold_slider)
    layout.addRow("", self.iou_threshold_label)

    # Area threshold controls. The attributes stay fractions, matching
    # initialize_area_threshold() and update_area_label() (previously they
    # were stored here as integer percents).
    min_val, max_val = self.main_window.get_area_thresh()
    self.area_thresh_min = min_val
    self.area_thresh_max = max_val

    self.area_threshold_min_slider = QSlider(Qt.Horizontal)
    self.area_threshold_min_slider.setRange(0, 100)
    self.area_threshold_min_slider.setValue(round(min_val * 100))
    self.area_threshold_min_slider.setTickPosition(QSlider.TicksBelow)
    self.area_threshold_min_slider.setTickInterval(10)
    self.area_threshold_min_slider.valueChanged.connect(self.update_area_label)

    self.area_threshold_max_slider = QSlider(Qt.Horizontal)
    self.area_threshold_max_slider.setRange(0, 100)
    self.area_threshold_max_slider.setValue(round(max_val * 100))
    self.area_threshold_max_slider.setTickPosition(QSlider.TicksBelow)
    self.area_threshold_max_slider.setTickInterval(10)
    self.area_threshold_max_slider.valueChanged.connect(self.update_area_label)

    self.area_threshold_label = QLabel(f"{self.area_thresh_min:.2f} - {self.area_thresh_max:.2f}")
    layout.addRow("Area Threshold Min", self.area_threshold_min_slider)
    layout.addRow("Area Threshold Max", self.area_threshold_max_slider)
    layout.addRow("", self.area_threshold_label)

    group_box.setLayout(layout)
    self.left_panel.addWidget(group_box)
471
+
472
def setup_sam_layout(self):
    """
    Add the "Use SAM Model for Creating Polygons" group box to the left panel.

    It holds a "False"/"True" dropdown (default "False"). Every selection
    change calls is_sam_model_deployed().
    """
    self.use_sam_dropdown = QComboBox()
    self.use_sam_dropdown.addItems(["False", "True"])
    self.use_sam_dropdown.currentIndexChanged.connect(self.is_sam_model_deployed)

    sam_form = QFormLayout()
    sam_form.addRow("Use SAM Polygons:", self.use_sam_dropdown)

    sam_group = QGroupBox("Use SAM Model for Creating Polygons")
    sam_group.setLayout(sam_form)
    self.left_panel.addWidget(sam_group)
485
+
486
def setup_model_buttons_layout(self):
    """
    Add the "Actions" group box to the left panel.

    It has two rows of buttons: "Load Model" / "Deactivate Model", then
    "Save VPE" / "Show VPE". Each button is wired to the method of the
    same purpose.
    """
    button_rows = [
        [("Load Model", self.load_model), ("Deactivate Model", self.deactivate_model)],
        [("Save VPE", self.save_vpe), ("Show VPE", self.show_vpe)],
    ]

    actions_layout = QVBoxLayout()
    for row_spec in button_rows:
        row_layout = QHBoxLayout()
        for caption, handler in row_spec:
            button = QPushButton(caption)
            button.clicked.connect(handler)
            row_layout.addWidget(button)
        actions_layout.addLayout(row_layout)

    actions_group = QGroupBox("Actions")
    actions_group.setLayout(actions_layout)
    self.left_panel.addWidget(actions_group)
519
+
520
def setup_status_layout(self):
    """Add the "Status" group box to the left panel; ``status_bar`` starts as "No model loaded"."""
    self.status_bar = QLabel("No model loaded")

    status_layout = QVBoxLayout()
    status_layout.addWidget(self.status_bar)

    status_group = QGroupBox("Status")
    status_group.setLayout(status_layout)
    self.left_panel.addWidget(status_group)
532
+
533
def setup_reference_layout(self):
    """
    Add the "Reference" group box to the right panel.

    It contains:
    - the reference-label dropdown, whose index changes re-filter the
      embedded image list via filter_images_by_label_and_type();
    - a "Reference Method" dropdown offering "VPE" or "Images".

    Reference images themselves are chosen by highlighting rows in the
    embedded image list.
    """
    self.reference_label_combo_box = QComboBox()
    self.reference_label_combo_box.currentIndexChanged.connect(self.filter_images_by_label_and_type)

    self.reference_method_combo_box = QComboBox()
    self.reference_method_combo_box.addItems(["VPE", "Images"])

    reference_form = QFormLayout()
    reference_form.addRow("Reference Label:", self.reference_label_combo_box)
    reference_form.addRow("Reference Method:", self.reference_method_combo_box)

    reference_group = QGroupBox("Reference")
    reference_group.setLayout(reference_form)
    self.right_panel.addWidget(reference_group)
553
+
554
def setup_buttons_layout(self):
    """Append an OK/Cancel button box, wired to accept()/reject(), to the bottom of the dialog."""
    dialog_buttons = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
    dialog_buttons.accepted.connect(self.accept)
    dialog_buttons.rejected.connect(self.reject)
    self.layout.addWidget(dialog_buttons)
564
+
565
def initialize_uncertainty_threshold(self):
    """
    Sync the uncertainty slider, label and attribute with the main window's value.

    The slider works in integer percent, so the fraction is rounded rather than
    truncated (``int(0.29 * 100)`` is 28). Slider signals are blocked while the
    value is set. Otherwise update_uncertainty_label() would fire and write the
    converted percent back into the main window. The label is refreshed here
    instead.
    """
    current_value = self.main_window.get_uncertainty_thresh()
    slider = self.uncertainty_threshold_slider
    was_blocked = slider.blockSignals(True)
    try:
        slider.setValue(round(current_value * 100))
    finally:
        slider.blockSignals(was_blocked)
    self.uncertainty_thresh = current_value
    self.uncertainty_threshold_label.setText(f"{current_value:.2f}")
570
+
571
def initialize_iou_threshold(self):
    """
    Sync the IoU slider, label and attribute with the main window's value.

    The fraction is rounded, not truncated, to the slider's integer percent
    (``int(0.57 * 100)`` is 56). Slider signals are blocked while the value is
    set so update_iou_label() does not write the converted value back into the
    main window. The label is refreshed here instead.
    """
    current_value = self.main_window.get_iou_thresh()
    slider = self.iou_threshold_slider
    was_blocked = slider.blockSignals(True)
    try:
        slider.setValue(round(current_value * 100))
    finally:
        slider.blockSignals(was_blocked)
    self.iou_thresh = current_value
    self.iou_threshold_label.setText(f"{current_value:.2f}")
576
+
577
def initialize_area_threshold(self):
    """
    Sync both area sliders, the label and the attributes with the main window.

    Slider signals are blocked while the sliders move. Otherwise
    update_area_label() would fire after the first setValue() and compare the
    new min against the *old* max slider position. It would then clamp the min
    and push that corrupted pair into the main window. Fractions are rounded,
    not truncated, to slider percents.
    """
    current_min, current_max = self.main_window.get_area_thresh()
    for slider, fraction in ((self.area_threshold_min_slider, current_min),
                             (self.area_threshold_max_slider, current_max)):
        was_blocked = slider.blockSignals(True)
        try:
            slider.setValue(round(fraction * 100))
        finally:
            slider.blockSignals(was_blocked)
    self.area_thresh_min = current_min
    self.area_thresh_max = current_max
    self.area_threshold_label.setText(f"{current_min:.2f} - {current_max:.2f}")
584
+
585
def update_uncertainty_label(self, value):
    """
    Slot for the uncertainty slider.

    Converts the slider percent to a fraction, stores it, pushes it to the
    main window and shows it with two decimals.

    :param value: Slider position (0-100)
    """
    self.uncertainty_thresh = threshold = value / 100.0
    self.main_window.update_uncertainty_thresh(threshold)
    self.uncertainty_threshold_label.setText("{:.2f}".format(threshold))
591
+
592
def update_iou_label(self, value):
    """
    Slot for the IoU slider.

    Converts the slider percent to a fraction, stores it, pushes it to the
    main window and shows it with two decimals.

    :param value: Slider position (0-100)
    """
    self.iou_thresh = threshold = value / 100.0
    self.main_window.update_iou_thresh(threshold)
    self.iou_threshold_label.setText("{:.2f}".format(threshold))
598
+
599
def update_area_label(self):
    """
    Slot for both area sliders.

    Keeps min <= max by snapping the min slider down to the max. It then
    stores both values as fractions, pushes them to the main window and
    refreshes the "min - max" label.
    """
    low = self.area_threshold_min_slider.value()
    high = self.area_threshold_max_slider.value()
    if low > high:
        # Min overtook max: snap the min slider back (on a real QSlider this
        # emits valueChanged again).
        low = high
        self.area_threshold_min_slider.setValue(low)

    self.area_thresh_min, self.area_thresh_max = low / 100.0, high / 100.0
    self.main_window.update_area_thresh(self.area_thresh_min, self.area_thresh_max)
    self.area_threshold_label.setText(
        "{:.2f} - {:.2f}".format(self.area_thresh_min, self.area_thresh_max))
610
+
611
def update_stashed_references_from_ui(self):
    """Store the paths currently highlighted in the image list as ``reference_image_paths``."""
    table_model = self.image_selection_window.table_model
    self.reference_image_paths = table_model.get_highlighted_paths()
614
+
615
def get_max_detections(self):
    """Read the "Max Detections" spinbox, cache the value in ``max_detect`` and return it."""
    value = self.max_detections_spinbox.value()
    self.max_detect = value
    return value
619
+
620
def is_sam_model_deployed(self):
    """
    Report whether the SAM predictor has a loaded model, enforcing it when SAM is requested.

    Connected to the "Use SAM Polygons" dropdown. If the main window exposes
    ``sam_deploy_predictor_dialog``, it is cached in ``self.sam_dialog``.

    When the dropdown reads "True" but no SAM model is available, the dropdown
    is reset to "False" and a single error is shown. Signals are blocked during
    the reset so this slot is not re-entered and the error is not shown twice.
    Selecting "False" never shows an error.

    :return: True if the SAM predictor dialog exists and has a loaded model, else False
    """
    sam_dialog = getattr(self.main_window, 'sam_deploy_predictor_dialog', None)
    if sam_dialog is not None:
        self.sam_dialog = sam_dialog

    deployed = sam_dialog is not None and bool(sam_dialog.loaded_model)

    if not deployed and self.use_sam_dropdown.currentText() == "True":
        was_blocked = self.use_sam_dropdown.blockSignals(True)
        try:
            self.use_sam_dropdown.setCurrentText("False")
        finally:
            self.use_sam_dropdown.blockSignals(was_blocked)
        QMessageBox.critical(self, "Error", "Please deploy the SAM model first.")

    return deployed
637
+
638
def update_sam_task_state(self):
    """
    Force the 'segment' task when SAM polygons are requested and available.

    - Dropdown "True" and ``self.sam_dialog`` has a loaded model: ``self.task``
      becomes 'segment'.
    - Dropdown "True" but SAM unavailable: the dropdown is reset to "False".
      Signals are not blocked, so connected slots such as
      is_sam_model_deployed() run.
    - Dropdown "False": nothing changes, so the user's Task choice is kept.
    """
    if self.use_sam_dropdown.currentText() != "True":
        return

    sam_dialog = getattr(self, 'sam_dialog', None)
    if sam_dialog is not None and sam_dialog.loaded_model is not None:
        self.task = 'segment'
    else:
        self.use_sam_dropdown.setCurrentText("False")
663
+
664
def update_task(self):
    """
    Slot for the Task dropdown.

    Stores the selected task. If the new task is 'segment' and a model is
    currently loaded, that model is deactivated.
    """
    self.task = self.use_task_dropdown.currentText()
    switching_to_segment = self.task == "segment"
    if switching_to_segment and self.loaded_model:
        self.deactivate_model()
673
+
674
def update_reference_labels(self):
    """
    Rebuild the reference-label dropdown from every project label.

    All labels from the main label window are listed, sorted by short code,
    except the built-in "Review" label (short code "Review" with id -1). The
    previously chosen label code is re-selected if it is still present.

    Signals are blocked during the rebuild. The image list is then filtered
    once, explicitly, at the end.
    """
    combo = self.reference_label_combo_box
    combo.blockSignals(True)
    try:
        combo.clear()

        def is_review_label(label_obj):
            return label_obj.short_label_code == "Review" and str(label_obj.id) == "-1"

        selectable = sorted(
            (lbl for lbl in self.main_window.label_window.labels if not is_review_label(lbl)),
            key=lambda lbl: lbl.short_label_code,
        )
        for lbl in selectable:
            combo.addItem(lbl.short_label_code, lbl)

        previous_code = self.last_selected_label_code
        if previous_code:
            position = combo.findText(previous_code)
            if position != -1:
                combo.setCurrentIndex(position)
    finally:
        combo.blockSignals(False)

    self.filter_images_by_label_and_type()
710
+
711
def get_reference_annotations(self, reference_label, reference_image_path):
    """
    Collect the boxes and outlines of one label on one image.

    :param reference_label: Label object; annotations whose label has the same
                            short_label_code are selected.
    :param reference_image_path: Path of the image whose annotations are scanned.
    :return: Tuple (bboxes, masks). bboxes is a numpy array built from each
             matching annotation's cropped_bbox. masks is a list of numpy point
             arrays: the polygon vertices for PolygonAnnotation, or the four
             bbox corners for RectangleAnnotation. (np.array([]), []) is
             returned when either argument is falsy.
    """
    if not (reference_label and reference_image_path):
        return np.array([]), []

    bboxes, masks = [], []

    for annotation in self.annotation_window.get_image_annotations(reference_image_path):
        if annotation.label.short_label_code != reference_label.short_label_code:
            continue

        if isinstance(annotation, PolygonAnnotation):
            bboxes.append(annotation.cropped_bbox)
            masks.append(np.array([[pt.x(), pt.y()] for pt in annotation.points]))
        elif isinstance(annotation, RectangleAnnotation):
            box = annotation.cropped_bbox
            bboxes.append(box)
            # Unpacking assumes cropped_bbox is (x1, y1, x2, y2)
            left, top, right, bottom = box
            masks.append(np.array([[left, top], [right, top], [right, bottom], [left, bottom]]))

    return np.array(bboxes), masks
743
+
744
def browse_vpe_file(self):
    """
    Let the user pick a VPE (.pt) file and load it as the imported VPEs.

    Imported VPEs are kept in self.imported_vpes, separate from VPEs generated
    from reference images. Two on-disk formats are accepted:
      - a list of tensors (the format written by save_vpe), and
      - a single tensor (legacy format).
    After a successful load, self.vpe is set to the L2-normalized mean of the
    imported VPEs.

    When the file format is unrecognized or loading fails, several things are
    cleared so that no VPE from an earlier file stays active:
      - the imported VPEs,
      - self.vpe,
      - self.vpe_path,
      - the path field.
    """
    file_path, _ = QFileDialog.getOpenFileName(
        self,
        "Select Visual Prompt Encoding (VPE) File",
        "",
        "VPE Files (*.pt);;All Files (*)"
    )

    if not file_path:
        return

    self.vpe_path_edit.setText(file_path)
    self.vpe_path = file_path

    try:
        # Load onto the CPU first so files saved from a CUDA device can still be
        # opened without one; tensors are then moved to self.device below.
        loaded_data = torch.load(file_path, map_location='cpu')

        if isinstance(loaded_data, list):
            # Current format: list of VPE tensors
            self.imported_vpes = [vpe.to(self.device) for vpe in loaded_data]
            vpe_count = len(self.imported_vpes)
            self.status_bar.setText(f"Loaded {vpe_count} VPE tensors from file")

        elif isinstance(loaded_data, torch.Tensor):
            # Legacy format: a single tensor, wrapped in a list for consistency
            self.imported_vpes = [loaded_data.to(self.device)]
            self.status_bar.setText("Loaded 1 VPE tensor from file (legacy format)")

        else:
            # Invalid format: drop every trace of the rejected file
            self.imported_vpes = []
            self.vpe = None
            self.vpe_path = None
            self.status_bar.setText("Invalid VPE file format")
            QMessageBox.warning(
                self,
                "Invalid VPE",
                "The file does not appear to be a valid VPE format."
            )
            self.vpe_path_edit.clear()
            return

        # For backward compatibility, self.vpe holds the averaged imported VPE.
        # An empty list clears it instead of leaving a previous file's VPE active.
        if self.imported_vpes:
            combined_vpe = torch.cat(self.imported_vpes).mean(dim=0, keepdim=True)
            self.vpe = torch.nn.functional.normalize(combined_vpe, p=2, dim=-1)
        else:
            self.vpe = None

    except Exception as e:
        self.imported_vpes = []
        self.vpe = None
        self.vpe_path = None
        self.vpe_path_edit.clear()
        self.status_bar.setText(f"Error loading VPE: {str(e)}")
        QMessageBox.critical(
            self,
            "Error Loading VPE",
            f"Failed to load VPE file: {str(e)}"
        )
810
+
811
def save_vpe(self):
    """
    Save every available VPE (imported and reference-generated) to a .pt file.

    The stashed reference selection is first synced with the UI. If the
    selection yields reference annotations, the model is reloaded and new VPEs
    are computed from them and stored in self.reference_vpes. Otherwise any
    previously computed reference VPEs are included. The combined list is
    moved to the CPU and written with torch.save as a list of tensors, which
    is the list format browse_vpe_file reads.

    Errors while generating VPEs or writing the file are reported in a message
    box rather than escaping from the Qt slot.
    """
    # Always sync with the live UI selection before saving.
    self.update_stashed_references_from_ui()

    all_vpes = []
    if self.imported_vpes:
        all_vpes.extend(self.imported_vpes)

    references_dict = self._get_references()
    if references_dict:
        try:
            # Reload the model so the VPEs come from a clean predictor state
            self.reload_model()
            # self.reference_vpes is only replaced once generation succeeded
            new_vpes = self.references_to_vpe(references_dict, update_reference_vpes=False)
        except Exception as e:
            QMessageBox.critical(
                self,
                "Error Saving VPE",
                f"Failed to generate VPEs from reference images: {str(e)}"
            )
            return

        if new_vpes:
            all_vpes.extend(new_vpes)
            self.reference_vpes = new_vpes
    elif self.reference_vpes:
        # No current references: fall back to previously computed ones
        all_vpes.extend(self.reference_vpes)

    if not all_vpes:
        QMessageBox.warning(
            self,
            "No VPEs Available",
            "No VPEs available to save. Please either load a VPE file or select reference images."
        )
        return

    file_path, _ = QFileDialog.getSaveFileName(
        self,
        "Save VPE Collection",
        "",
        "PyTorch Tensor (*.pt);;All Files (*)"
    )

    if not file_path:
        return  # User canceled the dialog

    # Append the extension unless one is already present (any letter case)
    if not file_path.lower().endswith('.pt'):
        file_path += '.pt'

    try:
        # Tensors are stored on the CPU so the file loads on any device
        vpe_list_cpu = [vpe.cpu() for vpe in all_vpes]
        torch.save(vpe_list_cpu, file_path)

        self.status_bar.setText(f"Saved {len(all_vpes)} VPE tensors to {os.path.basename(file_path)}")
        QMessageBox.information(
            self,
            "VPE Saved",
            f"Saved {len(all_vpes)} VPE tensors to {file_path}"
        )
    except Exception as e:
        QMessageBox.critical(
            self,
            "Error Saving VPE",
            f"Failed to save VPE: {str(e)}"
        )
887
+
888
def load_model(self):
    """
    Load the model selected in the model combo box and report the result.

    The work is delegated to reload_model(). While it runs, a wait cursor and
    a progress bar are shown. Both are removed before the result dialog
    appears, so the user is not left with a busy cursor and an open progress
    bar while a modal message box is waiting for input. On failure,
    self.loaded_model is reset to None and an error dialog is shown.
    """
    QApplication.setOverrideCursor(Qt.WaitCursor)
    progress_bar = ProgressBar(self.annotation_window, title="Loading Model")
    progress_bar.show()

    error = None
    message = None
    try:
        self.reload_model()

        # Describe which VPE sources the model was loaded with
        imported_count = len(self.imported_vpes)
        reference_count = len(self.reference_vpes)
        if imported_count and reference_count:
            message = f"Model loaded with {imported_count} imported VPEs "
            message += f"and {reference_count} reference VPEs"
        elif imported_count:
            message = f"Model loaded with {imported_count} imported VPEs"
        elif reference_count:
            message = f"Model loaded with {reference_count} reference VPEs"
        else:
            message = "Model loaded with default VPE"

        self.status_bar.setText(message)
        progress_bar.finish_progress()

    except Exception as e:
        self.loaded_model = None
        error = e

    finally:
        # Clean up before any modal dialog is shown
        QApplication.restoreOverrideCursor()
        progress_bar.stop_progress()
        progress_bar.close()

    if error is not None:
        QMessageBox.critical(self.annotation_window,
                             "Error Loading Model",
                             f"Error loading model: {error}")
    else:
        QMessageBox.information(self.annotation_window, "Model Loaded", message)
934
+
935
def reload_model(self):
    """
    (Re)build the YOLOE model from the path selected in the model combo box.

    This is the model-construction part of load_model. It is also called
    whenever VPEs must be recomputed from reference annotations, because the
    model has to be re-created each time.

    Steps:
      1. Instantiate YOLOE from self.model_combo's current text and move it
         to self.device.
      2. Run one dummy prediction on a blank 640x640 image with a placeholder
         box prompt and YOLOEVPSegPredictor. This initializes the predictor
         that references_to_vpe later reads from self.loaded_model.predictor.
      3. If self.vpe holds a tensor, install it as the single class "object0".

    Note: this method does not read or modify the stashed reference image
    paths. Callers that need an up-to-date selection (save_vpe, show_vpe)
    call update_stashed_references_from_ui() themselves.
    """
    self.loaded_model = None

    # Model path exactly as selected in the UI
    self.model_path = self.model_combo.currentText()

    # NOTE(review): the original marked this line with a bare TODO; device
    # placement may need revisiting.
    self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device)

    # Placeholder visual prompt: one arbitrary box assigned to class 0
    visual_prompts = dict(
        bboxes=np.array(
            [
                [120, 425, 160, 445],
            ],
        ),
        cls=np.zeros(1),
    )

    # Dummy prediction whose purpose is to initialize the VP predictor
    self.loaded_model.predict(
        np.zeros((640, 640, 3), dtype=np.uint8),
        visual_prompts=visual_prompts.copy(),  # Needed to properly initialize the predictor
        predictor=YOLOEVPSegPredictor,  # Must be the Seg predictor regardless of task
        imgsz=640,
        conf=0.99,
    )

    # Re-apply a previously computed/loaded VPE after the dummy prediction
    if self.vpe is not None and isinstance(self.vpe, torch.Tensor):
        # is_fused is overridden to report False; presumably required for
        # set_classes to accept a new embedding — confirm against ultralytics.
        self.loaded_model.is_fused = lambda: False
        self.loaded_model.set_classes(["object0"], self.vpe)
979
+
980
def predict(self, image_paths=None):
    """
    Run the loaded model on the given images and hand the results to a
    ResultsProcessor.

    Args:
        image_paths: List of image paths to process. If None or empty, only the
            image currently shown in the annotation window is processed.

    Nothing happens unless both a model and a reference (output) label are set.
    Images for which no inputs can be produced are skipped but still advance
    the progress bar. An exception stops the remaining images and is printed.
    The cursor and progress bar are always restored.
    """
    if not self.loaded_model or not self.reference_label:
        return

    # Every prediction is assigned to the selected reference label
    self.class_mapping = {0: self.reference_label}

    results_processor = ResultsProcessor(
        self.main_window,
        self.class_mapping,
        uncertainty_thresh=self.main_window.get_uncertainty_thresh(),
        iou_thresh=self.main_window.get_iou_thresh(),
        min_area_thresh=self.main_window.get_area_thresh_min(),
        max_area_thresh=self.main_window.get_area_thresh_max()
    )

    if not image_paths:
        # Predict only the current image
        image_paths = [self.annotation_window.current_image_path]

    QApplication.setOverrideCursor(Qt.WaitCursor)

    progress_bar = ProgressBar(self.annotation_window, title="Prediction Workflow")
    progress_bar.show()
    progress_bar.start_progress(len(image_paths))

    try:
        for image_path in image_paths:
            inputs = self._get_inputs(image_path)
            # Skip images with nothing to predict on: None or an empty list
            # would otherwise fail at inputs[0] further down the pipeline.
            if inputs:
                results = self._apply_model(inputs)
                results = self._apply_sam(results, image_path)
                self._process_results(results_processor, results, image_path)

            # Advance once per image, skipped or not, so the bar can complete
            progress_bar.update_progress()

    except Exception as e:
        print("An error occurred during prediction:", e)
    finally:
        QApplication.restoreOverrideCursor()
        progress_bar.finish_progress()
        progress_bar.stop_progress()
        progress_bar.close()

    gc.collect()
    empty_cache()
1038
+
1039
+ def _get_inputs(self, image_path):
1040
+ """Get the inputs for the model prediction."""
1041
+ raster = self.image_window.raster_manager.get_raster(image_path)
1042
+ if self.annotation_window.get_selected_tool() != "work_area":
1043
+ # Use the image path
1044
+ work_areas_data = [raster.image_path]
1045
+ else:
1046
+ # Get the work areas
1047
+ work_areas_data = raster.get_work_areas_data()
1048
+
1049
+ return work_areas_data
1050
+
1051
+ def _get_references(self):
1052
+ """
1053
+ Get the reference annotations using the stashed list of reference images
1054
+ that was saved when the user accepted the dialog.
1055
+
1056
+ Returns:
1057
+ dict: Dictionary mapping image paths to annotation data, or empty dict if no valid references.
1058
+ """
1059
+ # Use the "stashed" list of paths. Do NOT query the table_model again,
1060
+ # as the UI's highlight state may have been cleared by other actions.
1061
+ reference_paths = self.reference_image_paths
1062
+
1063
+ if not reference_paths:
1064
+ print("No reference image paths were stashed to use for prediction.")
1065
+ return {}
1066
+
1067
+ # Get the reference label that was also stashed
1068
+ reference_label = self.reference_label
1069
+ if not reference_label:
1070
+ # This check is a safeguard; the accept() method should prevent this.
1071
+ print("No reference label was selected.")
1072
+ return {}
1073
+
1074
+ # Create a dictionary of reference annotations from the stashed paths
1075
+ reference_annotations_dict = {}
1076
+ for path in reference_paths:
1077
+ bboxes, masks = self.get_reference_annotations(reference_label, path)
1078
+ if bboxes.size > 0:
1079
+ reference_annotations_dict[path] = {
1080
+ 'bboxes': bboxes,
1081
+ 'masks': masks,
1082
+ 'cls': np.zeros(len(bboxes))
1083
+ }
1084
+
1085
+ return reference_annotations_dict
1086
+
1087
def _apply_model_using_images(self, inputs, reference_dict):
    """
    Predict on the target once per reference image and merge the detections.

    For each reference image, its boxes are used as visual prompts with that
    image as refer_image; for the 'segment' task its masks are added too.
    Each non-empty result is relabeled with the output label's short code and
    collected. Everything collected is then merged with CombineResults.
    This is less efficient than the VPE path, but potentially more accurate.

    Args:
        inputs (list): Inputs for the target image; only inputs[0] is used.
        reference_dict (dict): {image_path: {'bboxes', 'masks', 'cls'}}.

    Returns:
        list: [[combined_results]] when anything was combined, otherwise [].

    The wait cursor and progress bar are restored even if a prediction raises;
    the exception itself still propagates to the caller.
    """
    if not inputs:
        return []

    QApplication.setOverrideCursor(Qt.WaitCursor)
    progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
    progress_bar.show()
    progress_bar.start_progress(len(reference_dict))

    results_list = []
    # Only the first work area / full image of the target is predicted on
    input_image = inputs[0]

    try:
        for ref_path, ref_annotations in reference_dict.items():
            # Prompts come from the same image passed as refer_image
            visual_prompts = {
                'bboxes': ref_annotations['bboxes'],
                'cls': ref_annotations['cls'],
            }
            if self.task == 'segment':
                visual_prompts['masks'] = ref_annotations['masks']

            results = self.loaded_model.predict(input_image,
                                                refer_image=ref_path,
                                                visual_prompts=visual_prompts,
                                                # NOTE(review): original asked whether this is necessary here
                                                predictor=YOLOEVPSegPredictor,
                                                imgsz=self.imgsz_spinbox.value(),
                                                conf=self.main_window.get_uncertainty_thresh(),
                                                iou=self.main_window.get_iou_thresh(),
                                                max_det=self.get_max_detections(),
                                                retina_masks=self.task == "segment")

            if len(results[0].boxes):
                results[0].names = {0: self.class_mapping[0].short_label_code}
                # extend() iterates the Results object, as in the original;
                # presumably CombineResults expects these per-item entries.
                results_list.extend(results[0])

            progress_bar.update_progress()
            gc.collect()
            empty_cache()
    finally:
        QApplication.restoreOverrideCursor()
        progress_bar.finish_progress()
        progress_bar.stop_progress()
        progress_bar.close()

    combined_results = CombineResults().combine_results(results_list)
    if combined_results is None:
        return []

    return [[combined_results]]
1158
+
1159
def references_to_vpe(self, reference_dict, update_reference_vpes=True):
    """
    Turn reference annotations into L2-normalized Visual Prompt Embeddings.

    For every entry, the annotations are set as prompts on the loaded model's
    predictor. predictor.get_vpe(image_path) then yields that image's VPE,
    which is normalized along its last dimension.

    Args:
        reference_dict (dict): {image_path: {'bboxes', 'masks', 'cls'}}.
        update_reference_vpes (bool): When True, the resulting list is also
            stored in self.reference_vpes.

    Returns:
        list | None: One normalized VPE tensor per reference image, in dict
        order. None is returned when reference_dict is empty.
    """
    if not reference_dict:
        return None

    def embed(path, prompts):
        predictor = self.loaded_model.predictor
        predictor.set_prompts(prompts)
        raw_vpe = predictor.get_vpe(path)
        return torch.nn.functional.normalize(raw_vpe, p=2, dim=-1)

    vpe_list = [embed(path, prompts) for path, prompts in reference_dict.items()]

    if not vpe_list:
        return None

    if update_reference_vpes:
        self.reference_vpes = vpe_list

    return vpe_list
1199
+
1200
+ def _apply_model_using_vpe(self, inputs, references_dict):
1201
+ """
1202
+ Apply the model to the inputs using combined VPEs from both imported files
1203
+ and reference annotations.
1204
+
1205
+ Args:
1206
+ inputs (list): List of input images.
1207
+ references_dict (dict): Dictionary containing reference annotations for each image.
1208
+
1209
+ Returns:
1210
+ list: List of prediction results.
1211
+ """
1212
+ # First reload the model to clear any cached data
1213
+ self.reload_model()
1214
+
1215
+ # Initialize combined_vpes list
1216
+ combined_vpes = []
1217
+
1218
+ # Add imported VPEs if available
1219
+ if self.imported_vpes:
1220
+ combined_vpes.extend(self.imported_vpes)
1221
+
1222
+ # Process reference images to VPEs if any exist
1223
+ if references_dict:
1224
+ # Only update reference_vpes if references_dict is not empty
1225
+ reference_vpes = self.references_to_vpe(references_dict, update_reference_vpes=True)
1226
+ if reference_vpes:
1227
+ combined_vpes.extend(reference_vpes)
1228
+ else:
1229
+ # Use existing reference_vpes if we have them
1230
+ if self.reference_vpes:
1231
+ combined_vpes.extend(self.reference_vpes)
1232
+
1233
+ # Check if we have any VPEs to use
1234
+ if not combined_vpes:
1235
+ QMessageBox.warning(
1236
+ self,
1237
+ "No VPEs Available",
1238
+ "No VPEs available for prediction. Please either load a VPE file or select reference images."
1239
+ )
1240
+ return []
1241
+
1242
+ # Average all the VPEs together to create a final VPE tensor
1243
+ averaged_vpe = torch.cat(combined_vpes).mean(dim=0, keepdim=True)
1244
+ final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
1245
+
1246
+ # For backward compatibility, update self.vpe
1247
+ self.vpe = final_vpe
1248
+
1249
+ # Set the final VPE to the model
1250
+ self.loaded_model.is_fused = lambda: False
1251
+ self.loaded_model.set_classes(["object0"], final_vpe)
1252
+
1253
+ # Make predictions on the target using the averaged VPE
1254
+ results = self.loaded_model.predict(inputs[0],
1255
+ visual_prompts=[],
1256
+ imgsz=self.imgsz_spinbox.value(),
1257
+ conf=self.main_window.get_uncertainty_thresh(),
1258
+ iou=self.main_window.get_iou_thresh(),
1259
+ max_det=self.get_max_detections(),
1260
+ retina_masks=self.task == "segment")
1261
+
1262
+ return [results]
1263
+
1264
+ def _apply_model(self, inputs):
1265
+ """
1266
+ Apply the model to the target inputs. This method handles both image-based
1267
+ references and VPE-based references.
1268
+ """
1269
+ # Update the model with user parameters
1270
+ self.task = self.use_task_dropdown.currentText()
1271
+
1272
+ self.loaded_model.conf = self.main_window.get_uncertainty_thresh()
1273
+ self.loaded_model.iou = self.main_window.get_iou_thresh()
1274
+ self.loaded_model.max_det = self.get_max_detections()
1275
+
1276
+ # Get the reference information for the currently selected rows
1277
+ references_dict = self._get_references()
1278
+
1279
+ # Check if the user is using VPE or Reference Images
1280
+ if self.reference_method_combo_box.currentText() == "VPE":
1281
+ # Check if we have any VPEs available (imported or reference-generated)
1282
+ has_vpes = bool(self.imported_vpes or self.reference_vpes)
1283
+
1284
+ # If we have reference images selected but no imported VPEs yet,
1285
+ # warn the user only if we also don't have any reference images
1286
+ if not has_vpes and not references_dict:
1287
+ QMessageBox.warning(
1288
+ self,
1289
+ "No VPEs Available",
1290
+ "No VPEs available for prediction. Please either load a VPE file or select reference images."
1291
+ )
1292
+ return []
1293
+
1294
+ # Use the VPE method, which will combine imported and reference VPEs
1295
+ results = self._apply_model_using_vpe(inputs, references_dict)
1296
+ else:
1297
+ # Use Reference Images method - requires reference images
1298
+ if not references_dict:
1299
+ QMessageBox.warning(
1300
+ self,
1301
+ "No References Selected",
1302
+ "No reference images with valid annotations were selected. "
1303
+ "Please select at least one reference image."
1304
+ )
1305
+ return []
1306
+
1307
+ results = self._apply_model_using_images(inputs, references_dict)
1308
+
1309
+ return results
1310
+
1311
+ def _apply_sam(self, results_list, image_path):
1312
+ """Apply SAM to the results if needed."""
1313
+ # Check if SAM model is deployed and loaded
1314
+ self.update_sam_task_state()
1315
+ if self.task != 'segment':
1316
+ return results_list
1317
+
1318
+ if not self.sam_dialog or self.use_sam_dropdown.currentText() == "False":
1319
+ # If SAM is not deployed or not selected, return the results as is
1320
+ return results_list
1321
+
1322
+ if self.sam_dialog.loaded_model is None:
1323
+ # If SAM is not loaded, ensure we do not use it accidentally
1324
+ self.task = 'detect'
1325
+ self.use_sam_dropdown.setCurrentText("False")
1326
+ return results_list
1327
+
1328
+ # Make cursor busy
1329
+ QApplication.setOverrideCursor(Qt.WaitCursor)
1330
+ progress_bar = ProgressBar(self.annotation_window, title="Predicting with SAM")
1331
+ progress_bar.show()
1332
+ progress_bar.start_progress(len(results_list))
1333
+
1334
+ updated_results = []
1335
+
1336
+ for idx, results in enumerate(results_list):
1337
+ # Each Results is a list (within the results_list, [[], ]
1338
+ if results:
1339
+ # Run it rough the SAM model
1340
+ results = self.sam_dialog.predict_from_results(results, image_path)
1341
+ updated_results.append(results)
1342
+
1343
+ # Update the progress bar
1344
+ progress_bar.update_progress()
1345
+
1346
+ # Make cursor normal
1347
+ QApplication.restoreOverrideCursor()
1348
+ progress_bar.finish_progress()
1349
+ progress_bar.stop_progress()
1350
+ progress_bar.close()
1351
+
1352
+ return updated_results
1353
+
1354
+ def _process_results(self, results_processor, results_list, image_path):
1355
+ """Process the results using the result processor."""
1356
+ # Get the raster object and number of work items
1357
+ raster = self.image_window.raster_manager.get_raster(image_path)
1358
+ total = raster.count_work_items()
1359
+
1360
+ # Get the work areas (if any)
1361
+ work_areas = raster.get_work_areas()
1362
+
1363
+ # Start the progress bar
1364
+ progress_bar = ProgressBar(self.annotation_window, title="Processing Results")
1365
+ progress_bar.show()
1366
+ progress_bar.start_progress(total)
1367
+
1368
+ updated_results = []
1369
+
1370
+ for idx, results in enumerate(results_list):
1371
+ # Each Results is a list (within the results_list, [[], ]
1372
+ if results:
1373
+ # Update path and names
1374
+ results[0].path = image_path
1375
+ results[0].names = {0: self.class_mapping[0].short_label_code}
1376
+ # This needs to be done again, in case SAM was used
1377
+
1378
+ # Check if the work area is valid, or the image path is being used
1379
+ if work_areas and self.annotation_window.get_selected_tool() == "work_area":
1380
+ # Map results from work area to the full image
1381
+ results = MapResults().map_results_from_work_area(results[0],
1382
+ raster,
1383
+ work_areas[idx],
1384
+ self.task == "segment")
1385
+ else:
1386
+ results = results[0]
1387
+
1388
+ # Append the result object (not a list) to the updated results list
1389
+ updated_results.append(results)
1390
+
1391
+ # Update the index for the next work area
1392
+ idx += 1
1393
+ progress_bar.update_progress()
1394
+
1395
+ # Process the Results
1396
+ if self.task == 'segment' or self.use_sam_dropdown.currentText() == "True":
1397
+ results_processor.process_segmentation_results(updated_results)
1398
+ else:
1399
+ results_processor.process_detection_results(updated_results)
1400
+
1401
+ # Close the progress bar
1402
+ progress_bar.finish_progress()
1403
+ progress_bar.stop_progress()
1404
+ progress_bar.close()
1405
+
1406
def show_vpe(self):
    """
    Show a PCA plot of the available VPEs in a VPEVisualizationDialog.

    The reference selection is synced from the UI first. Then two groups of
    VPEs are combined:
      - the imported VPEs, tagged "Import";
      - VPEs freshly computed from the selected reference images, tagged
        "Reference". Computing these reloads the model and refreshes
        self.reference_vpes.
    The L2-normalized mean of all of them is passed as the final VPE.

    The wait cursor covers only the computation; it is restored before any
    message box or the visualization dialog is shown. Errors while computing
    VPEs are reported in a message box.
    """
    vpes_with_source = []
    final_vpe = None
    error = None

    QApplication.setOverrideCursor(Qt.WaitCursor)
    try:
        # Always sync with the live UI selection before visualizing.
        self.update_stashed_references_from_ui()

        # 1. VPEs loaded from a file
        for vpe in self.imported_vpes or []:
            vpes_with_source.append((vpe, "Import"))

        # 2./3. VPEs computed from the currently selected reference images
        references_dict = self._get_references()
        if references_dict:
            self.reload_model()
            new_reference_vpes = self.references_to_vpe(references_dict, update_reference_vpes=True)
            for vpe in new_reference_vpes or []:
                vpes_with_source.append((vpe, "Reference"))

        # Final VPE = normalized mean of everything collected
        if vpes_with_source:
            all_vpe_tensors = [vpe for vpe, _source in vpes_with_source]
            averaged_vpe = torch.cat(all_vpe_tensors).mean(dim=0, keepdim=True)
            final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
    except Exception as e:
        error = e
    finally:
        # Never leave the busy cursor up while a modal dialog is open
        QApplication.restoreOverrideCursor()

    if error is not None:
        QMessageBox.critical(
            self,
            "Error Visualizing VPE",
            f"Failed to compute VPEs: {str(error)}"
        )
        return

    # 4. Nothing to visualize
    if not vpes_with_source:
        QMessageBox.warning(
            self,
            "No VPEs Available",
            "No VPEs available to visualize. Please either load a VPE file or select reference images."
        )
        return

    # 5. Visualization dialog receives the (tensor, source) tuples
    dialog = VPEVisualizationDialog(vpes_with_source, final_vpe, self)
    dialog.exec_()
1456
+
1457
def deactivate_model(self):
    """
    Unload the model and forget every VPE.

    The following state is cleared:
      - the model and its path;
      - the VPE path field and self.vpe_path;
      - self.vpe;
      - the imported and reference VPE lists.
    Afterwards the method runs garbage collection, empties the CUDA cache,
    untoggles all tools, resets the status bar and informs the user.
    """
    # Drop the model itself.
    for attribute in ('loaded_model', 'model_path'):
        setattr(self, attribute, None)

    # Drop every piece of VPE state, including the path shown in the UI.
    self.vpe_path_edit.clear()
    for attribute in ('vpe_path', 'vpe'):
        setattr(self, attribute, None)
    self.imported_vpes, self.reference_vpes = [], []

    # Release memory held by the objects dropped above.
    gc.collect()
    torch.cuda.empty_cache()

    self.main_window.untoggle_all_tools()

    self.status_bar.setText("No model loaded")
    QMessageBox.information(self, "Model Deactivated", "Model deactivated")
1481
+
1482
+
1483
+ class VPEVisualizationDialog(QDialog):
1484
+ """
1485
+ Dialog for visualizing VPE embeddings in 2D space using PCA.
1486
+ """
1487
def __init__(self, vpe_list_with_source, final_vpe=None, parent=None):
    """
    Build the dialog and draw the VPE plot.

    Args:
        vpe_list_with_source (list): (VPE tensor, source string) tuples.
        final_vpe (torch.Tensor, optional): The averaged VPE.
        parent (QWidget, optional): Parent widget.
    """
    super().__init__(parent)
    self.setWindowTitle("VPE Visualization")
    self.resize(1000, 1000)
    # Allow the dialog to be maximized from its title bar
    self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)

    self.vpe_list_with_source = vpe_list_with_source
    self.final_vpe = final_vpe

    main_layout = QVBoxLayout(self)

    # Scatter plot area: white background, light grid
    plot = pg.PlotWidget()
    plot.setBackground('w')
    plot.setTitle("PCA Visualization of Visual Prompt Embeddings", color="#000000", size="10pt")
    plot.showGrid(x=True, y=True, alpha=0.3)
    self.plot_widget = plot
    main_layout.addWidget(plot)

    # Breathing room between the plot and the summary text
    main_layout.addSpacing(20)

    # Summary text under the plot
    summary = QLabel()
    summary.setAlignment(Qt.AlignCenter)
    self.info_label = summary
    main_layout.addWidget(summary)

    # Single Close button, wired to reject()
    buttons = QDialogButtonBox(QDialogButtonBox.Close)
    buttons.rejected.connect(self.reject)
    main_layout.addWidget(buttons)

    self.visualize_vpes()
1534
+
1535
def visualize_vpes(self):
    """
    Project the VPEs to 2D with PCA and plot them.

    Each VPE becomes one colored point. Its legend entry is "VPE <n> (I)" for
    an "Import" source and "VPE <n> (R)" otherwise. The final VPE, when given,
    is drawn as a larger red star. The info label reports the original
    dimension and the variance explained by the two principal components.

    PCA with two components needs at least two samples. When fewer are
    available (a single VPE and no final VPE), a message is shown instead.
    """
    if not self.vpe_list_with_source:
        self.info_label.setText("No VPEs available to visualize.")
        return

    # Tensors -> flat numpy vectors; the source strings are used separately
    vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, _source in self.vpe_list_with_source]

    # The final VPE, when given, is appended as the last PCA sample
    final_vpe_array = None
    samples = list(vpe_arrays)
    if self.final_vpe is not None:
        final_vpe_array = self.final_vpe.detach().cpu().numpy().squeeze()
        samples.append(final_vpe_array)

    # PCA(n_components=2) raises ValueError with fewer than two samples
    if len(samples) < 2:
        self.info_label.setText("At least two VPEs are needed for a 2D PCA visualization.")
        return

    all_vpes = np.vstack(samples)
    pca = PCA(n_components=2)
    vpes_2d = pca.fit_transform(all_vpes)

    self.plot_widget.clear()

    num_vpes = len(vpe_arrays)
    colors = self.generate_distinct_colors(num_vpes)

    # 3-column legend keeps it compact; named items added below appear in it
    self.plot_widget.addLegend(colCount=3)

    for i, ((_vpe, source), vpe_2d) in enumerate(zip(self.vpe_list_with_source, vpes_2d[:num_vpes])):
        source_char = 'I' if source == 'Import' else 'R'
        scatter = pg.ScatterPlotItem(
            x=[vpe_2d[0]],
            y=[vpe_2d[1]],
            brush=pg.mkColor(colors[i]),
            size=15,
            name=f"VPE {i+1} ({source_char})"
        )
        self.plot_widget.addItem(scatter)

    # The final (averaged) VPE is the last projected row
    if final_vpe_array is not None:
        final_vpe_2d = vpes_2d[-1]
        scatter = pg.ScatterPlotItem(
            x=[final_vpe_2d[0]],
            y=[final_vpe_2d[1]],
            brush=pg.mkBrush(color='r'),
            size=20,
            symbol='star',
            name="Final VPE"
        )
        self.plot_widget.addItem(scatter)

    orig_dim = self.vpe_list_with_source[0][0].shape[-1]
    ratios = pca.explained_variance_ratio_
    explained_variance = sum(ratios)
    self.info_label.setText(
        f"Original dimension: {orig_dim} → Reduced to 2D\n"
        f"Total explained variance: {explained_variance:.2%}\n"
        f"PC1: {ratios[0]:.2%} variance, "
        f"PC2: {ratios[1]:.2%} variance"
    )
1603
+
1604
def generate_distinct_colors(self, num_colors, seed=None):
    """
    Generate visually distinct colors as hex strings.

    Hues are placed around the color wheel by stepping repeatedly by the
    golden-ratio conjugate (~0.618). This keeps consecutive hues far apart
    for any ``num_colors``, but the hues are not evenly spaced.
    Saturation is drawn from [0.6, 1.0] to avoid pale colors, and value
    from [0.7, 1.0] to avoid dark colors.

    A private ``random.Random`` instance is used, so the module-level
    random state shared with the rest of the program is left untouched.

    Args:
        num_colors (int): Number of colors to generate. Values <= 0 yield
            an empty list.
        seed (int, optional): Seed for the saturation/value draws. Pass a
            value to get reproducible colors. ``None`` (default) gives
            different colors on every call.

    Returns:
        list: List of ``"#rrggbb"`` hex strings, one per color.
    """
    import random
    from colorsys import hsv_to_rgb

    rng = random.Random(seed)
    golden_ratio_conjugate = 0.618033988749895

    colors = []
    for i in range(num_colors):
        hue = (i * golden_ratio_conjugate) % 1.0
        saturation = rng.uniform(0.6, 1.0)
        value = rng.uniform(0.7, 1.0)

        # HSV -> RGB in the 0-1 range, then to an 8-bit-per-channel hex string
        r, g, b = hsv_to_rgb(hue, saturation, value)
        colors.append(f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}")

    return colors