coralnet-toolbox 0.0.73__py2.py3-none-any.whl → 0.0.75__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
- coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
- coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
- coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
- coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
- coralnet_toolbox/CoralNet/QtDownload.py +2 -1
- coralnet_toolbox/Explorer/QtDataItem.py +52 -22
- coralnet_toolbox/Explorer/QtExplorer.py +293 -1614
- coralnet_toolbox/Explorer/QtSettingsWidgets.py +203 -85
- coralnet_toolbox/Explorer/QtViewers.py +1568 -0
- coralnet_toolbox/Explorer/transformer_models.py +59 -0
- coralnet_toolbox/Explorer/yolo_models.py +112 -0
- coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
- coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
- coralnet_toolbox/IO/QtOpenProject.py +46 -78
- coralnet_toolbox/IO/QtSaveProject.py +18 -43
- coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +1 -1
- coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +253 -141
- coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
- coralnet_toolbox/MachineLearning/VideoInference/YOLO3D/run.py +102 -16
- coralnet_toolbox/QtAnnotationWindow.py +16 -10
- coralnet_toolbox/QtEventFilter.py +11 -0
- coralnet_toolbox/QtImageWindow.py +120 -75
- coralnet_toolbox/QtLabelWindow.py +13 -1
- coralnet_toolbox/QtMainWindow.py +5 -27
- coralnet_toolbox/QtProgressBar.py +52 -27
- coralnet_toolbox/Rasters/RasterTableModel.py +28 -8
- coralnet_toolbox/SAM/QtDeployGenerator.py +1 -4
- coralnet_toolbox/SAM/QtDeployPredictor.py +11 -3
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +805 -162
- coralnet_toolbox/SeeAnything/QtDeployPredictor.py +130 -151
- coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
- coralnet_toolbox/Tools/QtPolygonTool.py +42 -3
- coralnet_toolbox/Tools/QtRectangleTool.py +30 -0
- coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
- coralnet_toolbox/Tools/QtSAMTool.py +72 -50
- coralnet_toolbox/Tools/QtSeeAnythingTool.py +8 -5
- coralnet_toolbox/Tools/QtSelectTool.py +27 -3
- coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
- coralnet_toolbox/Tools/__init__.py +2 -0
- coralnet_toolbox/__init__.py +1 -1
- coralnet_toolbox/utilities.py +158 -47
- coralnet_toolbox-0.0.75.dist-info/METADATA +378 -0
- {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/RECORD +49 -44
- coralnet_toolbox-0.0.73.dist-info/METADATA +0 -341
- {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/top_level.txt +0 -0
@@ -2,24 +2,24 @@ import warnings
|
|
2
2
|
|
3
3
|
import os
|
4
4
|
import gc
|
5
|
-
import json
|
6
|
-
import copy
|
7
5
|
|
8
6
|
import numpy as np
|
7
|
+
from sklearn.decomposition import PCA
|
9
8
|
|
10
9
|
import torch
|
11
10
|
from torch.cuda import empty_cache
|
12
11
|
|
12
|
+
import pyqtgraph as pg
|
13
|
+
from pyqtgraph.Qt import QtGui
|
14
|
+
|
13
15
|
from ultralytics import YOLOE
|
14
16
|
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
|
15
|
-
from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
|
16
17
|
|
17
18
|
from PyQt5.QtCore import Qt
|
18
|
-
from PyQt5.
|
19
|
-
from PyQt5.QtWidgets import (QMessageBox, QCheckBox, QVBoxLayout, QApplication,
|
19
|
+
from PyQt5.QtWidgets import (QMessageBox, QVBoxLayout, QApplication, QFileDialog,
|
20
20
|
QLabel, QDialog, QDialogButtonBox, QGroupBox, QLineEdit,
|
21
21
|
QFormLayout, QComboBox, QSpinBox, QSlider, QPushButton,
|
22
|
-
QHBoxLayout
|
22
|
+
QHBoxLayout)
|
23
23
|
|
24
24
|
from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
|
25
25
|
from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
|
@@ -79,10 +79,18 @@ class DeployGeneratorDialog(QDialog):
|
|
79
79
|
self.class_mapping = {}
|
80
80
|
|
81
81
|
# Reference image and label
|
82
|
-
self.
|
83
|
-
self.
|
84
|
-
|
85
|
-
|
82
|
+
self.reference_label = None
|
83
|
+
self.reference_image_paths = []
|
84
|
+
|
85
|
+
# Visual Prompting Encoding (VPE) - legacy single tensor variable
|
86
|
+
self.vpe_path = None
|
87
|
+
self.vpe = None
|
88
|
+
|
89
|
+
# New separate VPE collections
|
90
|
+
self.imported_vpes = [] # VPEs loaded from file
|
91
|
+
self.reference_vpes = [] # VPEs created from reference images
|
92
|
+
|
93
|
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
86
94
|
|
87
95
|
# Main vertical layout for the dialog
|
88
96
|
self.layout = QVBoxLayout(self)
|
@@ -110,7 +118,7 @@ class DeployGeneratorDialog(QDialog):
|
|
110
118
|
self.setup_status_layout()
|
111
119
|
|
112
120
|
# Add layouts to the right panel
|
113
|
-
self.
|
121
|
+
self.setup_reference_layout()
|
114
122
|
|
115
123
|
# # Add a full ImageWindow instance for target image selection
|
116
124
|
self.image_selection_window = ImageWindow(self.main_window)
|
@@ -158,6 +166,9 @@ class DeployGeneratorDialog(QDialog):
|
|
158
166
|
iw.search_bar_images.setEnabled(False)
|
159
167
|
iw.search_bar_labels.setEnabled(False)
|
160
168
|
iw.top_k_combo.setEnabled(False)
|
169
|
+
|
170
|
+
# Hide the "Current" label as it is not applicable in this dialog
|
171
|
+
iw.current_image_index_label.hide()
|
161
172
|
|
162
173
|
# Set Top-K to Top1
|
163
174
|
iw.top_k_combo.setCurrentText("Top1")
|
@@ -190,7 +201,7 @@ class DeployGeneratorDialog(QDialog):
|
|
190
201
|
self.sync_image_window()
|
191
202
|
# This now populates the dropdown, restores the last selection,
|
192
203
|
# and then manually triggers the image filtering.
|
193
|
-
self.
|
204
|
+
self.update_reference_labels()
|
194
205
|
|
195
206
|
def sync_image_window(self):
|
196
207
|
"""
|
@@ -219,14 +230,23 @@ class DeployGeneratorDialog(QDialog):
|
|
219
230
|
annotation that has BOTH the selected label AND a valid type (Polygon or Rectangle).
|
220
231
|
This uses the fast, pre-computed cache for performance.
|
221
232
|
"""
|
222
|
-
|
223
|
-
|
233
|
+
# Persist the user's current highlights from the table model before filtering.
|
234
|
+
# This ensures that if the user highlights items and then changes the filter,
|
235
|
+
# their selection is not lost.
|
236
|
+
current_highlights = self.image_selection_window.table_model.get_highlighted_paths()
|
237
|
+
if current_highlights:
|
238
|
+
self.reference_image_paths = current_highlights
|
239
|
+
|
240
|
+
reference_label = self.reference_label_combo_box.currentData()
|
241
|
+
reference_label_text = self.reference_label_combo_box.currentText()
|
224
242
|
|
225
243
|
# Store the last selected label for a better user experience on re-opening.
|
226
|
-
if
|
227
|
-
self.last_selected_label_code =
|
244
|
+
if reference_label_text:
|
245
|
+
self.last_selected_label_code = reference_label_text
|
246
|
+
# Also store the reference label object itself
|
247
|
+
self.reference_label = reference_label
|
228
248
|
|
229
|
-
if not
|
249
|
+
if not reference_label:
|
230
250
|
# If no label is selected (e.g., during initialization), show an empty list.
|
231
251
|
self.image_selection_window.table_model.set_filtered_paths([])
|
232
252
|
return
|
@@ -235,7 +255,7 @@ class DeployGeneratorDialog(QDialog):
|
|
235
255
|
final_filtered_paths = []
|
236
256
|
|
237
257
|
valid_types = {"RectangleAnnotation", "PolygonAnnotation"}
|
238
|
-
selected_label_code =
|
258
|
+
selected_label_code = reference_label.short_label_code
|
239
259
|
|
240
260
|
# Loop through paths and check the pre-computed map on each raster
|
241
261
|
for path in all_paths:
|
@@ -257,40 +277,54 @@ class DeployGeneratorDialog(QDialog):
|
|
257
277
|
# Directly set the filtered list in the table model.
|
258
278
|
self.image_selection_window.table_model.set_filtered_paths(final_filtered_paths)
|
259
279
|
|
280
|
+
# Try to preserve any previous selections
|
281
|
+
if hasattr(self, 'reference_image_paths') and self.reference_image_paths:
|
282
|
+
# Find which of our previously selected paths are still in the filtered list
|
283
|
+
valid_selections = [p for p in self.reference_image_paths if p in final_filtered_paths]
|
284
|
+
if valid_selections:
|
285
|
+
# Highlight previously selected paths that are still valid
|
286
|
+
self.image_selection_window.table_model.set_highlighted_paths(valid_selections)
|
287
|
+
|
288
|
+
# After filtering, update all labels with the correct counts.
|
289
|
+
dialog_iw = self.image_selection_window
|
290
|
+
dialog_iw.update_image_count_label(len(final_filtered_paths)) # Set "Total" to filtered count
|
291
|
+
dialog_iw.update_current_image_index_label()
|
292
|
+
dialog_iw.update_highlighted_count_label()
|
293
|
+
|
260
294
|
def accept(self):
|
261
295
|
"""
|
262
296
|
Validate selections and store them before closing the dialog.
|
297
|
+
A prediction is valid if a model and label are selected, and the user
|
298
|
+
has provided either reference images or an imported VPE file.
|
263
299
|
"""
|
264
300
|
if not self.loaded_model:
|
265
301
|
QMessageBox.warning(self,
|
266
302
|
"No Model",
|
267
303
|
"A model must be loaded before running predictions.")
|
268
|
-
super().reject()
|
269
304
|
return
|
270
305
|
|
271
|
-
|
272
|
-
|
306
|
+
# Set reference label from combo box
|
307
|
+
self.reference_label = self.reference_label_combo_box.currentData()
|
308
|
+
if not self.reference_label:
|
273
309
|
QMessageBox.warning(self,
|
274
|
-
"No
|
275
|
-
"A
|
276
|
-
super().reject()
|
310
|
+
"No Reference Label",
|
311
|
+
"A reference label must be selected.")
|
277
312
|
return
|
278
313
|
|
279
|
-
#
|
280
|
-
|
314
|
+
# Stash the current UI selection before validating.
|
315
|
+
self.update_stashed_references_from_ui()
|
316
|
+
|
317
|
+
# Check for a valid VPE source using the now-stashed list.
|
318
|
+
has_reference_images = bool(self.reference_image_paths)
|
319
|
+
has_imported_vpes = bool(self.imported_vpes)
|
281
320
|
|
282
|
-
if not
|
321
|
+
if not has_reference_images and not has_imported_vpes:
|
283
322
|
QMessageBox.warning(self,
|
284
|
-
"No
|
285
|
-
"You must highlight at least one image
|
286
|
-
super().reject()
|
323
|
+
"No VPE Source Provided",
|
324
|
+
"You must highlight at least one reference image or load a VPE file to proceed.")
|
287
325
|
return
|
288
326
|
|
289
|
-
#
|
290
|
-
self.source_label = current_label
|
291
|
-
self.target_images = highlighted_images
|
292
|
-
|
293
|
-
# Do not call self.predict here; just close the dialog and let the caller handle prediction
|
327
|
+
# If validation passes, close the dialog.
|
294
328
|
super().accept()
|
295
329
|
|
296
330
|
def setup_info_layout(self):
|
@@ -317,19 +351,19 @@ class DeployGeneratorDialog(QDialog):
|
|
317
351
|
Setup the models layout with a simple model selection combo box (no tabs).
|
318
352
|
"""
|
319
353
|
group_box = QGroupBox("Model Selection")
|
320
|
-
layout =
|
354
|
+
layout = QFormLayout()
|
321
355
|
|
322
356
|
self.model_combo = QComboBox()
|
323
357
|
self.model_combo.setEditable(True)
|
324
358
|
|
325
359
|
# Define available models (keep the existing dictionary)
|
326
360
|
self.models = [
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
361
|
+
'yoloe-v8s-seg.pt',
|
362
|
+
'yoloe-v8m-seg.pt',
|
363
|
+
'yoloe-v8l-seg.pt',
|
364
|
+
'yoloe-11s-seg.pt',
|
365
|
+
'yoloe-11m-seg.pt',
|
366
|
+
'yoloe-11l-seg.pt',
|
333
367
|
]
|
334
368
|
|
335
369
|
# Add all models to combo box
|
@@ -337,10 +371,19 @@ class DeployGeneratorDialog(QDialog):
|
|
337
371
|
self.model_combo.addItem(model_name)
|
338
372
|
|
339
373
|
# Set the default model
|
340
|
-
self.model_combo.
|
374
|
+
self.model_combo.setCurrentIndex(self.models.index('yoloe-v8s-seg.pt'))
|
375
|
+
# Create a layout for the model selection
|
376
|
+
layout.addRow(QLabel("Models:"), self.model_combo)
|
377
|
+
|
378
|
+
# Add custom vpe file selection
|
379
|
+
self.vpe_path_edit = QLineEdit()
|
380
|
+
browse_button = QPushButton("Browse...")
|
381
|
+
browse_button.clicked.connect(self.browse_vpe_file)
|
341
382
|
|
342
|
-
|
343
|
-
|
383
|
+
vpe_path_layout = QHBoxLayout()
|
384
|
+
vpe_path_layout.addWidget(self.vpe_path_edit)
|
385
|
+
vpe_path_layout.addWidget(browse_button)
|
386
|
+
layout.addRow("Custom VPE:", vpe_path_layout)
|
344
387
|
|
345
388
|
group_box.setLayout(layout)
|
346
389
|
self.left_panel.addWidget(group_box) # Add to left panel
|
@@ -373,7 +416,7 @@ class DeployGeneratorDialog(QDialog):
|
|
373
416
|
|
374
417
|
# Image size control
|
375
418
|
self.imgsz_spinbox = QSpinBox()
|
376
|
-
self.imgsz_spinbox.setRange(
|
419
|
+
self.imgsz_spinbox.setRange(1024, 65536)
|
377
420
|
self.imgsz_spinbox.setSingleStep(1024)
|
378
421
|
self.imgsz_spinbox.setValue(self.imgsz)
|
379
422
|
layout.addRow("Image Size (imgsz):", self.imgsz_spinbox)
|
@@ -445,17 +488,38 @@ class DeployGeneratorDialog(QDialog):
|
|
445
488
|
Setup action buttons in a group box.
|
446
489
|
"""
|
447
490
|
group_box = QGroupBox("Actions")
|
448
|
-
|
491
|
+
main_layout = QVBoxLayout()
|
449
492
|
|
493
|
+
# First row: Load and Deactivate buttons side by side
|
494
|
+
button_row = QHBoxLayout()
|
450
495
|
load_button = QPushButton("Load Model")
|
451
496
|
load_button.clicked.connect(self.load_model)
|
452
|
-
|
497
|
+
button_row.addWidget(load_button)
|
453
498
|
|
454
499
|
deactivate_button = QPushButton("Deactivate Model")
|
455
500
|
deactivate_button.clicked.connect(self.deactivate_model)
|
456
|
-
|
501
|
+
button_row.addWidget(deactivate_button)
|
457
502
|
|
458
|
-
|
503
|
+
main_layout.addLayout(button_row)
|
504
|
+
|
505
|
+
# Second row: VPE action buttons
|
506
|
+
vpe_row = QHBoxLayout()
|
507
|
+
|
508
|
+
generate_vpe_button = QPushButton("Generate VPEs")
|
509
|
+
generate_vpe_button.clicked.connect(self.generate_vpes_from_references)
|
510
|
+
vpe_row.addWidget(generate_vpe_button)
|
511
|
+
|
512
|
+
save_vpe_button = QPushButton("Save VPE")
|
513
|
+
save_vpe_button.clicked.connect(self.save_vpe)
|
514
|
+
vpe_row.addWidget(save_vpe_button)
|
515
|
+
|
516
|
+
show_vpe_button = QPushButton("Show VPE")
|
517
|
+
show_vpe_button.clicked.connect(self.show_vpe)
|
518
|
+
vpe_row.addWidget(show_vpe_button)
|
519
|
+
|
520
|
+
main_layout.addLayout(vpe_row)
|
521
|
+
|
522
|
+
group_box.setLayout(main_layout)
|
459
523
|
self.left_panel.addWidget(group_box) # Add to left panel
|
460
524
|
|
461
525
|
def setup_status_layout(self):
|
@@ -471,18 +535,23 @@ class DeployGeneratorDialog(QDialog):
|
|
471
535
|
group_box.setLayout(layout)
|
472
536
|
self.left_panel.addWidget(group_box) # Add to left panel
|
473
537
|
|
474
|
-
def
|
538
|
+
def setup_reference_layout(self):
|
475
539
|
"""
|
476
|
-
Set up the layout with
|
477
|
-
The
|
540
|
+
Set up the layout with reference label selection.
|
541
|
+
The reference image is implicitly the currently active image.
|
478
542
|
"""
|
479
|
-
group_box = QGroupBox("Reference
|
543
|
+
group_box = QGroupBox("Reference")
|
480
544
|
layout = QFormLayout()
|
481
545
|
|
482
|
-
# Create the
|
483
|
-
self.
|
484
|
-
self.
|
485
|
-
layout.addRow("Reference Label:", self.
|
546
|
+
# Create the reference label combo box
|
547
|
+
self.reference_label_combo_box = QComboBox()
|
548
|
+
self.reference_label_combo_box.currentIndexChanged.connect(self.filter_images_by_label_and_type)
|
549
|
+
layout.addRow("Reference Label:", self.reference_label_combo_box)
|
550
|
+
|
551
|
+
# Create a Reference model combobox (VPE, Images)
|
552
|
+
self.reference_method_combo_box = QComboBox()
|
553
|
+
self.reference_method_combo_box.addItems(["VPE", "Images"])
|
554
|
+
layout.addRow("Reference Method:", self.reference_method_combo_box)
|
486
555
|
|
487
556
|
group_box.setLayout(layout)
|
488
557
|
self.right_panel.addWidget(group_box) # Add to right panel
|
@@ -543,6 +612,10 @@ class DeployGeneratorDialog(QDialog):
|
|
543
612
|
self.area_thresh_max = max_val / 100.0
|
544
613
|
self.main_window.update_area_thresh(self.area_thresh_min, self.area_thresh_max)
|
545
614
|
self.area_threshold_label.setText(f"{self.area_thresh_min:.2f} - {self.area_thresh_max:.2f}")
|
615
|
+
|
616
|
+
def update_stashed_references_from_ui(self):
|
617
|
+
"""Updates the internal reference path list from the current UI selection."""
|
618
|
+
self.reference_image_paths = self.image_selection_window.table_model.get_highlighted_paths()
|
546
619
|
|
547
620
|
def get_max_detections(self):
|
548
621
|
"""Get the maximum number of detections to return."""
|
@@ -603,54 +676,44 @@ class DeployGeneratorDialog(QDialog):
|
|
603
676
|
if self.loaded_model:
|
604
677
|
self.deactivate_model()
|
605
678
|
|
606
|
-
def
|
679
|
+
def update_reference_labels(self):
|
607
680
|
"""
|
608
|
-
Updates the
|
609
|
-
|
681
|
+
Updates the reference label combo box with ALL available project labels.
|
682
|
+
This dropdown now serves as the "Output Label" for all predictions.
|
683
|
+
The "Review" label with id "-1" is excluded.
|
610
684
|
"""
|
611
|
-
self.
|
685
|
+
self.reference_label_combo_box.blockSignals(True)
|
612
686
|
|
613
687
|
try:
|
614
|
-
self.
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
for raster in dialog_manager.rasters.values():
|
625
|
-
# raster.label_to_types_map is like: {'coral': {'Point'}, 'rock': {'Polygon'}}
|
626
|
-
for label_code, types_for_label in raster.label_to_types_map.items():
|
627
|
-
# Check if the set of types for this specific label
|
628
|
-
# has any overlap with our valid reference types.
|
629
|
-
if not valid_types.isdisjoint(types_for_label):
|
630
|
-
# This label is a valid reference label.
|
631
|
-
# Add its full Label object to our set of valid labels.
|
632
|
-
if label_code in all_project_labels:
|
633
|
-
valid_labels.add(all_project_labels[label_code])
|
688
|
+
self.reference_label_combo_box.clear()
|
689
|
+
|
690
|
+
# Get all labels from the main label window
|
691
|
+
all_project_labels = self.main_window.label_window.labels
|
692
|
+
|
693
|
+
# Filter out the special "Review" label and create a list of valid labels
|
694
|
+
valid_labels = [
|
695
|
+
label_obj for label_obj in all_project_labels
|
696
|
+
if not (label_obj.short_label_code == "Review" and str(label_obj.id) == "-1")
|
697
|
+
]
|
634
698
|
|
635
699
|
# Add the valid labels to the combo box, sorted alphabetically.
|
636
|
-
sorted_valid_labels = sorted(
|
700
|
+
sorted_valid_labels = sorted(valid_labels, key=lambda x: x.short_label_code)
|
637
701
|
for label_obj in sorted_valid_labels:
|
638
|
-
self.
|
702
|
+
self.reference_label_combo_box.addItem(label_obj.short_label_code, label_obj)
|
639
703
|
|
640
704
|
# Restore the last selected label if it's still present in the list.
|
641
705
|
if self.last_selected_label_code:
|
642
|
-
index = self.
|
706
|
+
index = self.reference_label_combo_box.findText(self.last_selected_label_code)
|
643
707
|
if index != -1:
|
644
|
-
self.
|
708
|
+
self.reference_label_combo_box.setCurrentIndex(index)
|
645
709
|
finally:
|
646
|
-
self.
|
710
|
+
self.reference_label_combo_box.blockSignals(False)
|
647
711
|
|
648
|
-
# Manually trigger the filtering now that the combo box is stable.
|
712
|
+
# Manually trigger the image filtering now that the combo box is stable.
|
713
|
+
# This will still filter the image list to help find references if needed.
|
649
714
|
self.filter_images_by_label_and_type()
|
650
715
|
|
651
|
-
|
652
|
-
|
653
|
-
def get_source_annotations(self, reference_label, reference_image_path):
|
716
|
+
def get_reference_annotations(self, reference_label, reference_image_path):
|
654
717
|
"""
|
655
718
|
Return a list of bboxes and masks for a specific image
|
656
719
|
belonging to the selected label.
|
@@ -666,22 +729,161 @@ class DeployGeneratorDialog(QDialog):
|
|
666
729
|
annotations = self.annotation_window.get_image_annotations(reference_image_path)
|
667
730
|
|
668
731
|
# Filter annotations by the provided label
|
669
|
-
|
670
|
-
|
732
|
+
reference_bboxes = []
|
733
|
+
reference_masks = []
|
671
734
|
for annotation in annotations:
|
672
735
|
if annotation.label.short_label_code == reference_label.short_label_code:
|
673
736
|
if isinstance(annotation, (PolygonAnnotation, RectangleAnnotation)):
|
674
737
|
bbox = annotation.cropped_bbox
|
675
|
-
|
738
|
+
reference_bboxes.append(bbox)
|
676
739
|
if isinstance(annotation, PolygonAnnotation):
|
677
740
|
points = np.array([[p.x(), p.y()] for p in annotation.points])
|
678
|
-
|
741
|
+
reference_masks.append(points)
|
679
742
|
elif isinstance(annotation, RectangleAnnotation):
|
680
743
|
x1, y1, x2, y2 = bbox
|
681
744
|
rect_points = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
|
682
|
-
|
745
|
+
reference_masks.append(rect_points)
|
683
746
|
|
684
|
-
return np.array(
|
747
|
+
return np.array(reference_bboxes), reference_masks
|
748
|
+
|
749
|
+
def browse_vpe_file(self):
|
750
|
+
"""
|
751
|
+
Open a file dialog to browse for a VPE file and load it.
|
752
|
+
Stores imported VPEs separately from reference-generated VPEs.
|
753
|
+
"""
|
754
|
+
file_path, _ = QFileDialog.getOpenFileName(
|
755
|
+
self,
|
756
|
+
"Select Visual Prompt Encoding (VPE) File",
|
757
|
+
"",
|
758
|
+
"VPE Files (*.pt);;All Files (*)"
|
759
|
+
)
|
760
|
+
|
761
|
+
if not file_path:
|
762
|
+
return
|
763
|
+
|
764
|
+
self.vpe_path_edit.setText(file_path)
|
765
|
+
self.vpe_path = file_path
|
766
|
+
|
767
|
+
try:
|
768
|
+
# Load the VPE file
|
769
|
+
loaded_data = torch.load(file_path)
|
770
|
+
|
771
|
+
# TODO Move tensors to the appropriate device
|
772
|
+
# device = self.main_window.device
|
773
|
+
|
774
|
+
# Check format type and handle appropriately
|
775
|
+
if isinstance(loaded_data, list):
|
776
|
+
# New format: list of VPE tensors
|
777
|
+
self.imported_vpes = [vpe.to(self.device) for vpe in loaded_data]
|
778
|
+
vpe_count = len(self.imported_vpes)
|
779
|
+
self.status_bar.setText(f"Loaded {vpe_count} VPE tensors from file")
|
780
|
+
|
781
|
+
elif isinstance(loaded_data, torch.Tensor):
|
782
|
+
# Legacy format: single tensor - convert to list for consistency
|
783
|
+
loaded_vpe = loaded_data.to(self.device)
|
784
|
+
# Store as a single-item list
|
785
|
+
self.imported_vpes = [loaded_vpe]
|
786
|
+
self.status_bar.setText("Loaded 1 VPE tensor from file (legacy format)")
|
787
|
+
|
788
|
+
else:
|
789
|
+
# Invalid format
|
790
|
+
self.imported_vpes = []
|
791
|
+
self.status_bar.setText("Invalid VPE file format")
|
792
|
+
QMessageBox.warning(
|
793
|
+
self,
|
794
|
+
"Invalid VPE",
|
795
|
+
"The file does not appear to be a valid VPE format."
|
796
|
+
)
|
797
|
+
# Clear the VPE path edit field
|
798
|
+
self.vpe_path_edit.clear()
|
799
|
+
|
800
|
+
# For backward compatibility - set self.vpe to the average of imported VPEs
|
801
|
+
# This ensures older code paths still work
|
802
|
+
if self.imported_vpes:
|
803
|
+
combined_vpe = torch.cat(self.imported_vpes).mean(dim=0, keepdim=True)
|
804
|
+
self.vpe = torch.nn.functional.normalize(combined_vpe, p=2, dim=-1)
|
805
|
+
|
806
|
+
except Exception as e:
|
807
|
+
self.imported_vpes = []
|
808
|
+
self.vpe = None
|
809
|
+
self.status_bar.setText(f"Error loading VPE: {str(e)}")
|
810
|
+
QMessageBox.critical(
|
811
|
+
self,
|
812
|
+
"Error Loading VPE",
|
813
|
+
f"Failed to load VPE file: {str(e)}"
|
814
|
+
)
|
815
|
+
|
816
|
+
def save_vpe(self):
|
817
|
+
"""
|
818
|
+
Saves the combined collection of VPEs (imported and pre-generated from references) to disk.
|
819
|
+
"""
|
820
|
+
QApplication.setOverrideCursor(Qt.WaitCursor)
|
821
|
+
|
822
|
+
try:
|
823
|
+
# Create a list to hold all VPEs to be saved
|
824
|
+
all_vpes = []
|
825
|
+
|
826
|
+
# Add imported VPEs if available
|
827
|
+
if self.imported_vpes:
|
828
|
+
all_vpes.extend(self.imported_vpes)
|
829
|
+
|
830
|
+
# Add pre-generated reference VPEs if available
|
831
|
+
if self.reference_vpes:
|
832
|
+
all_vpes.extend(self.reference_vpes)
|
833
|
+
|
834
|
+
# Check if we have any VPEs to save
|
835
|
+
if not all_vpes:
|
836
|
+
QApplication.restoreOverrideCursor()
|
837
|
+
QMessageBox.warning(
|
838
|
+
self,
|
839
|
+
"No VPEs Available",
|
840
|
+
"No VPEs available to save. "
|
841
|
+
"Please either load a VPE file or generate VPEs from reference images first."
|
842
|
+
)
|
843
|
+
return
|
844
|
+
|
845
|
+
QApplication.restoreOverrideCursor()
|
846
|
+
|
847
|
+
file_path, _ = QFileDialog.getSaveFileName(
|
848
|
+
self,
|
849
|
+
"Save VPE Collection",
|
850
|
+
"",
|
851
|
+
"PyTorch Tensor (*.pt);;All Files (*)"
|
852
|
+
)
|
853
|
+
|
854
|
+
if not file_path:
|
855
|
+
return
|
856
|
+
|
857
|
+
QApplication.setOverrideCursor(Qt.WaitCursor)
|
858
|
+
|
859
|
+
if not file_path.endswith('.pt'):
|
860
|
+
file_path += '.pt'
|
861
|
+
|
862
|
+
vpe_list_cpu = [vpe.cpu() for vpe in all_vpes]
|
863
|
+
|
864
|
+
torch.save(vpe_list_cpu, file_path)
|
865
|
+
|
866
|
+
self.status_bar.setText(f"Saved {len(all_vpes)} VPE tensors to {os.path.basename(file_path)}")
|
867
|
+
|
868
|
+
QApplication.restoreOverrideCursor()
|
869
|
+
QMessageBox.information(
|
870
|
+
self,
|
871
|
+
"VPE Saved",
|
872
|
+
f"Saved {len(all_vpes)} VPE tensors to {file_path}"
|
873
|
+
)
|
874
|
+
|
875
|
+
except Exception as e:
|
876
|
+
QApplication.restoreOverrideCursor()
|
877
|
+
QMessageBox.critical(
|
878
|
+
self,
|
879
|
+
"Error Saving VPE",
|
880
|
+
f"Failed to save VPE: {str(e)}"
|
881
|
+
)
|
882
|
+
finally:
|
883
|
+
try:
|
884
|
+
QApplication.restoreOverrideCursor()
|
885
|
+
except:
|
886
|
+
pass
|
685
887
|
|
686
888
|
def load_model(self):
|
687
889
|
"""
|
@@ -692,40 +894,32 @@ class DeployGeneratorDialog(QDialog):
|
|
692
894
|
progress_bar.show()
|
693
895
|
|
694
896
|
try:
|
695
|
-
#
|
696
|
-
self.
|
697
|
-
|
698
|
-
# Load model using registry
|
699
|
-
self.loaded_model = YOLOE(self.model_path).to(self.main_window.device)
|
700
|
-
|
701
|
-
# Create a dummy visual dictionary
|
702
|
-
visuals = dict(
|
703
|
-
bboxes=np.array(
|
704
|
-
[
|
705
|
-
[120, 425, 160, 445],
|
706
|
-
],
|
707
|
-
),
|
708
|
-
cls=np.array(
|
709
|
-
np.zeros(1),
|
710
|
-
),
|
711
|
-
)
|
897
|
+
# Load the model using reload_model method
|
898
|
+
self.reload_model()
|
712
899
|
|
713
|
-
#
|
714
|
-
self.
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
900
|
+
# Calculate total number of VPEs from both sources
|
901
|
+
total_vpes = len(self.imported_vpes) + len(self.reference_vpes)
|
902
|
+
|
903
|
+
if total_vpes > 0:
|
904
|
+
if self.imported_vpes and self.reference_vpes:
|
905
|
+
message = f"Model loaded with {len(self.imported_vpes)} imported VPEs "
|
906
|
+
message += f"and {len(self.reference_vpes)} reference VPEs"
|
907
|
+
elif self.imported_vpes:
|
908
|
+
message = f"Model loaded with {len(self.imported_vpes)} imported VPEs"
|
909
|
+
else:
|
910
|
+
message = f"Model loaded with {len(self.reference_vpes)} reference VPEs"
|
911
|
+
|
912
|
+
self.status_bar.setText(message)
|
913
|
+
else:
|
914
|
+
message = "Model loaded with default VPE"
|
915
|
+
self.status_bar.setText("Model loaded with default VPE")
|
721
916
|
|
917
|
+
# Finish progress bar
|
722
918
|
progress_bar.finish_progress()
|
723
|
-
self.
|
724
|
-
QMessageBox.information(self.annotation_window,
|
725
|
-
"Model Loaded",
|
726
|
-
"Model loaded successfully")
|
919
|
+
QMessageBox.information(self.annotation_window, "Model Loaded", message)
|
727
920
|
|
728
921
|
except Exception as e:
|
922
|
+
self.loaded_model = None
|
729
923
|
QMessageBox.critical(self.annotation_window,
|
730
924
|
"Error Loading Model",
|
731
925
|
f"Error loading model: {e}")
|
@@ -737,7 +931,52 @@ class DeployGeneratorDialog(QDialog):
|
|
737
931
|
progress_bar.stop_progress()
|
738
932
|
progress_bar.close()
|
739
933
|
progress_bar = None
|
934
|
+
|
935
|
+
def reload_model(self):
|
936
|
+
"""
|
937
|
+
Subset of the load_model method. This is needed when additional
|
938
|
+
reference images and annotations (i.e., VPEs) are added (we have
|
939
|
+
to re-load the model each time).
|
940
|
+
|
941
|
+
This method also ensures that we stash the currently highlighted reference
|
942
|
+
image paths before reloading, so they're available for predictions
|
943
|
+
even if the user switches the active image in the main window.
|
944
|
+
"""
|
945
|
+
self.loaded_model = None
|
946
|
+
|
947
|
+
# Get selected model path and download weights if needed
|
948
|
+
self.model_path = self.model_combo.currentText()
|
949
|
+
|
950
|
+
# Load model using registry
|
951
|
+
self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device) # TODO
|
952
|
+
|
953
|
+
# Create a dummy visual dictionary for standard model loading
|
954
|
+
visual_prompts = dict(
|
955
|
+
bboxes=np.array(
|
956
|
+
[
|
957
|
+
[120, 425, 160, 445], # Random box
|
958
|
+
],
|
959
|
+
),
|
960
|
+
cls=np.array(
|
961
|
+
np.zeros(1),
|
962
|
+
),
|
963
|
+
)
|
964
|
+
|
965
|
+
# Run a dummy prediction to load the model
|
966
|
+
self.loaded_model.predict(
|
967
|
+
np.zeros((640, 640, 3), dtype=np.uint8),
|
968
|
+
visual_prompts=visual_prompts.copy(), # This needs to happen to properly initialize the predictor
|
969
|
+
predictor=YOLOEVPSegPredictor, # This also needs to be SegPredictor, no matter what
|
970
|
+
imgsz=640,
|
971
|
+
conf=0.99,
|
972
|
+
)
|
740
973
|
|
974
|
+
# If a VPE file was loaded, use it with the model after the dummy prediction
|
975
|
+
if self.vpe is not None and isinstance(self.vpe, torch.Tensor):
|
976
|
+
# Directly set the final tensor as the prompt for the predictor
|
977
|
+
self.loaded_model.is_fused = lambda: False
|
978
|
+
self.loaded_model.set_classes(["object0"], self.vpe)
|
979
|
+
|
741
980
|
def predict(self, image_paths=None):
|
742
981
|
"""
|
743
982
|
Make predictions on the given image paths using the loaded model.
|
@@ -745,11 +984,11 @@ class DeployGeneratorDialog(QDialog):
|
|
745
984
|
Args:
|
746
985
|
image_paths: List of image paths to process. If None, uses the current image.
|
747
986
|
"""
|
748
|
-
if not self.loaded_model or not self.
|
987
|
+
if not self.loaded_model or not self.reference_label:
|
749
988
|
return
|
750
989
|
|
751
990
|
# Update class mapping with the selected reference label
|
752
|
-
self.class_mapping = {0: self.
|
991
|
+
self.class_mapping = {0: self.reference_label}
|
753
992
|
|
754
993
|
# Create a results processor
|
755
994
|
results_processor = ResultsProcessor(
|
@@ -809,32 +1048,33 @@ class DeployGeneratorDialog(QDialog):
|
|
809
1048
|
|
810
1049
|
return work_areas_data
|
811
1050
|
|
812
|
-
def
|
813
|
-
"""
|
814
|
-
Apply the model to the target inputs, using each highlighted source
|
815
|
-
image as an individual reference for a separate prediction run.
|
1051
|
+
def _get_references(self):
|
816
1052
|
"""
|
817
|
-
|
818
|
-
|
819
|
-
self.loaded_model.iou = self.main_window.get_iou_thresh()
|
820
|
-
self.loaded_model.max_det = self.get_max_detections()
|
1053
|
+
Get the reference annotations using the stashed list of reference images
|
1054
|
+
that was saved when the user accepted the dialog.
|
821
1055
|
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
1056
|
+
Returns:
|
1057
|
+
dict: Dictionary mapping image paths to annotation data, or empty dict if no valid references.
|
1058
|
+
"""
|
1059
|
+
# Use the "stashed" list of paths. Do NOT query the table_model again,
|
1060
|
+
# as the UI's highlight state may have been cleared by other actions.
|
1061
|
+
reference_paths = self.reference_image_paths
|
1062
|
+
|
1063
|
+
if not reference_paths:
|
1064
|
+
print("No reference image paths were stashed to use for prediction.")
|
1065
|
+
return {}
|
1066
|
+
|
1067
|
+
# Get the reference label that was also stashed
|
1068
|
+
reference_label = self.reference_label
|
1069
|
+
if not reference_label:
|
1070
|
+
# This check is a safeguard; the accept() method should prevent this.
|
1071
|
+
print("No reference label was selected.")
|
1072
|
+
return {}
|
833
1073
|
|
834
|
-
# Create a dictionary of reference annotations
|
1074
|
+
# Create a dictionary of reference annotations from the stashed paths
|
835
1075
|
reference_annotations_dict = {}
|
836
|
-
for path in
|
837
|
-
bboxes, masks = self.
|
1076
|
+
for path in reference_paths:
|
1077
|
+
bboxes, masks = self.get_reference_annotations(reference_label, path)
|
838
1078
|
if bboxes.size > 0:
|
839
1079
|
reference_annotations_dict[path] = {
|
840
1080
|
'bboxes': bboxes,
|
@@ -842,37 +1082,48 @@ class DeployGeneratorDialog(QDialog):
|
|
842
1082
|
'cls': np.zeros(len(bboxes))
|
843
1083
|
}
|
844
1084
|
|
845
|
-
|
846
|
-
|
847
|
-
|
1085
|
+
return reference_annotations_dict
|
1086
|
+
|
1087
|
+
def _apply_model_using_images(self, inputs, reference_dict):
|
1088
|
+
"""
|
1089
|
+
Apply the model using the provided images and reference annotations (dict). This method
|
1090
|
+
loops through each reference image using its annotations; we then aggregate
|
1091
|
+
all the results together. Less efficient, but potentially more accurate.
|
1092
|
+
|
1093
|
+
Args:
|
1094
|
+
inputs (list): List of input images.
|
1095
|
+
reference_dict (dict): Dictionary containing reference annotations for each image.
|
848
1096
|
|
1097
|
+
Returns:
|
1098
|
+
list: List of prediction results.
|
1099
|
+
"""
|
849
1100
|
# Create a progress bar for iterating through reference images
|
850
1101
|
QApplication.setOverrideCursor(Qt.WaitCursor)
|
851
1102
|
progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
|
852
1103
|
progress_bar.show()
|
853
|
-
progress_bar.start_progress(len(
|
854
|
-
|
1104
|
+
progress_bar.start_progress(len(reference_dict))
|
1105
|
+
|
855
1106
|
results_list = []
|
856
1107
|
# The 'inputs' list contains work areas from the single target image.
|
857
1108
|
# We will predict on the first work area/full image.
|
858
1109
|
input_image = inputs[0]
|
859
1110
|
|
860
1111
|
# Iterate through each reference image and its annotations
|
861
|
-
for ref_path, ref_annotations in
|
1112
|
+
for ref_path, ref_annotations in reference_dict.items():
|
862
1113
|
# The 'refer_image' parameter is the path to the current reference image
|
863
1114
|
# The 'visual_prompts' are the annotations from that same reference image
|
864
|
-
|
1115
|
+
visual_prompts = {
|
865
1116
|
'bboxes': ref_annotations['bboxes'],
|
866
1117
|
'cls': ref_annotations['cls'],
|
867
1118
|
}
|
868
1119
|
if self.task == 'segment':
|
869
|
-
|
1120
|
+
visual_prompts['masks'] = ref_annotations['masks']
|
870
1121
|
|
871
1122
|
# Make predictions on the target using the current reference
|
872
1123
|
results = self.loaded_model.predict(input_image,
|
873
1124
|
refer_image=ref_path,
|
874
|
-
visual_prompts=
|
875
|
-
predictor=
|
1125
|
+
visual_prompts=visual_prompts,
|
1126
|
+
predictor=YOLOEVPSegPredictor, # TODO This is necessary here?
|
876
1127
|
imgsz=self.imgsz_spinbox.value(),
|
877
1128
|
conf=self.main_window.get_uncertainty_thresh(),
|
878
1129
|
iou=self.main_window.get_iou_thresh(),
|
@@ -904,6 +1155,192 @@ class DeployGeneratorDialog(QDialog):
|
|
904
1155
|
return []
|
905
1156
|
|
906
1157
|
return [[combined_results]]
|
1158
|
+
|
1159
|
+
def generate_vpes_from_references(self):
|
1160
|
+
"""
|
1161
|
+
Calculates VPEs from the currently highlighted reference images and
|
1162
|
+
stores them in self.reference_vpes, overwriting any previous ones.
|
1163
|
+
"""
|
1164
|
+
if not self.loaded_model:
|
1165
|
+
QMessageBox.warning(self, "No Model Loaded", "A model must be loaded before generating VPEs.")
|
1166
|
+
return
|
1167
|
+
|
1168
|
+
# Always sync with the live UI selection before generating.
|
1169
|
+
self.update_stashed_references_from_ui()
|
1170
|
+
references_dict = self._get_references()
|
1171
|
+
|
1172
|
+
if not references_dict:
|
1173
|
+
QMessageBox.information(
|
1174
|
+
self,
|
1175
|
+
"No References Selected",
|
1176
|
+
"Please highlight one or more reference images in the table to generate VPEs."
|
1177
|
+
)
|
1178
|
+
return
|
1179
|
+
|
1180
|
+
QApplication.setOverrideCursor(Qt.WaitCursor)
|
1181
|
+
progress_bar = ProgressBar(self, title="Generating VPEs")
|
1182
|
+
progress_bar.show()
|
1183
|
+
|
1184
|
+
try:
|
1185
|
+
# Make progress bar busy
|
1186
|
+
progress_bar.set_busy_mode("Generating VPEs...")
|
1187
|
+
# Reload the model to ensure a clean state for VPE generation
|
1188
|
+
self.reload_model()
|
1189
|
+
|
1190
|
+
# The references_to_vpe method will calculate and update self.reference_vpes
|
1191
|
+
new_vpes = self.references_to_vpe(references_dict, update_reference_vpes=True)
|
1192
|
+
|
1193
|
+
if new_vpes:
|
1194
|
+
num_vpes = len(new_vpes)
|
1195
|
+
num_images = len(references_dict)
|
1196
|
+
message = f"Successfully generated {num_vpes} VPEs from {num_images} reference image(s)."
|
1197
|
+
self.status_bar.setText(message)
|
1198
|
+
QMessageBox.information(self, "VPEs Generated", message)
|
1199
|
+
else:
|
1200
|
+
message = "Could not generate VPEs. Ensure annotations are valid."
|
1201
|
+
self.status_bar.setText(message)
|
1202
|
+
QMessageBox.warning(self, "Generation Failed", message)
|
1203
|
+
|
1204
|
+
except Exception as e:
|
1205
|
+
QMessageBox.critical(self, "Error Generating VPEs", f"An unexpected error occurred: {str(e)}")
|
1206
|
+
self.status_bar.setText("Error during VPE generation.")
|
1207
|
+
finally:
|
1208
|
+
QApplication.restoreOverrideCursor()
|
1209
|
+
progress_bar.stop_progress()
|
1210
|
+
progress_bar.close()
|
1211
|
+
|
1212
|
+
def references_to_vpe(self, reference_dict, update_reference_vpes=True):
|
1213
|
+
"""
|
1214
|
+
Converts the contents of a reference dictionary to VPEs (Visual Prompt Embeddings).
|
1215
|
+
Reference dictionaries contain information about the visual prompts for each reference image:
|
1216
|
+
dict[image_path]: {bboxes, masks, cls}
|
1217
|
+
|
1218
|
+
Args:
|
1219
|
+
reference_dict (dict): The reference dictionary containing visual prompts for each image.
|
1220
|
+
update_reference_vpes (bool): Whether to update self.reference_vpes with the results.
|
1221
|
+
|
1222
|
+
Returns:
|
1223
|
+
list: List of individual VPE tensors (normalized), or None if empty reference_dict
|
1224
|
+
"""
|
1225
|
+
# Check if the reference dictionary is empty
|
1226
|
+
if not reference_dict:
|
1227
|
+
return None
|
1228
|
+
|
1229
|
+
# Create a list to hold the individual VPE tensors
|
1230
|
+
vpe_list = []
|
1231
|
+
|
1232
|
+
for ref_path, ref_annotations in reference_dict.items():
|
1233
|
+
# Set the prompts to the model predictor
|
1234
|
+
self.loaded_model.predictor.set_prompts(ref_annotations)
|
1235
|
+
|
1236
|
+
# Get the VPE from the model
|
1237
|
+
vpe = self.loaded_model.predictor.get_vpe(ref_path)
|
1238
|
+
|
1239
|
+
# Normalize individual VPE
|
1240
|
+
vpe_normalized = torch.nn.functional.normalize(vpe, p=2, dim=-1)
|
1241
|
+
vpe_list.append(vpe_normalized)
|
1242
|
+
|
1243
|
+
# Check if we have any valid VPEs
|
1244
|
+
if not vpe_list:
|
1245
|
+
return None
|
1246
|
+
|
1247
|
+
# Update the reference_vpes list if requested
|
1248
|
+
if update_reference_vpes:
|
1249
|
+
self.reference_vpes = vpe_list
|
1250
|
+
|
1251
|
+
return vpe_list
|
1252
|
+
|
1253
|
+
def _apply_model_using_vpe(self, inputs):
|
1254
|
+
"""
|
1255
|
+
Apply the model to the inputs using pre-calculated VPEs from imported files
|
1256
|
+
and/or generated from reference annotations.
|
1257
|
+
|
1258
|
+
Args:
|
1259
|
+
inputs (list): List of input images.
|
1260
|
+
|
1261
|
+
Returns:
|
1262
|
+
list: List of prediction results.
|
1263
|
+
"""
|
1264
|
+
# First reload the model to clear any cached data
|
1265
|
+
self.reload_model()
|
1266
|
+
|
1267
|
+
# Initialize combined_vpes list
|
1268
|
+
combined_vpes = []
|
1269
|
+
|
1270
|
+
# Add imported VPEs if available
|
1271
|
+
if self.imported_vpes:
|
1272
|
+
combined_vpes.extend(self.imported_vpes)
|
1273
|
+
|
1274
|
+
# Add pre-generated reference VPEs if available
|
1275
|
+
if self.reference_vpes:
|
1276
|
+
combined_vpes.extend(self.reference_vpes)
|
1277
|
+
|
1278
|
+
# Check if we have any VPEs to use
|
1279
|
+
if not combined_vpes:
|
1280
|
+
QMessageBox.warning(
|
1281
|
+
self,
|
1282
|
+
"No VPEs Available",
|
1283
|
+
"No VPEs are available for prediction. "
|
1284
|
+
"Please either load a VPE file or generate VPEs from reference images."
|
1285
|
+
)
|
1286
|
+
return []
|
1287
|
+
|
1288
|
+
# Average all the VPEs together to create a final VPE tensor
|
1289
|
+
averaged_vpe = torch.cat(combined_vpes).mean(dim=0, keepdim=True)
|
1290
|
+
final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
|
1291
|
+
|
1292
|
+
# For backward compatibility, update self.vpe
|
1293
|
+
self.vpe = final_vpe
|
1294
|
+
|
1295
|
+
# Set the final VPE to the model
|
1296
|
+
self.loaded_model.is_fused = lambda: False
|
1297
|
+
self.loaded_model.set_classes(["object0"], final_vpe)
|
1298
|
+
|
1299
|
+
# Make predictions on the target using the averaged VPE
|
1300
|
+
results = self.loaded_model.predict(inputs[0],
|
1301
|
+
visual_prompts=[],
|
1302
|
+
imgsz=self.imgsz_spinbox.value(),
|
1303
|
+
conf=self.main_window.get_uncertainty_thresh(),
|
1304
|
+
iou=self.main_window.get_iou_thresh(),
|
1305
|
+
max_det=self.get_max_detections(),
|
1306
|
+
retina_masks=self.task == "segment")
|
1307
|
+
|
1308
|
+
return [results]
|
1309
|
+
|
1310
|
+
def _apply_model(self, inputs):
|
1311
|
+
"""
|
1312
|
+
Apply the model to the target inputs. This method handles both image-based
|
1313
|
+
references and VPE-based references.
|
1314
|
+
"""
|
1315
|
+
# Update the model with user parameters
|
1316
|
+
self.task = self.use_task_dropdown.currentText()
|
1317
|
+
|
1318
|
+
self.loaded_model.conf = self.main_window.get_uncertainty_thresh()
|
1319
|
+
self.loaded_model.iou = self.main_window.get_iou_thresh()
|
1320
|
+
self.loaded_model.max_det = self.get_max_detections()
|
1321
|
+
|
1322
|
+
# Get the reference information for the currently selected rows
|
1323
|
+
references_dict = self._get_references()
|
1324
|
+
|
1325
|
+
# Check if the user is using VPE or Reference Images
|
1326
|
+
if self.reference_method_combo_box.currentText() == "VPE":
|
1327
|
+
# The VPE method will use pre-loaded/pre-generated VPEs.
|
1328
|
+
# The internal checks for whether any VPEs exist are now inside _apply_model_using_vpe.
|
1329
|
+
results = self._apply_model_using_vpe(inputs)
|
1330
|
+
else:
|
1331
|
+
# Use Reference Images method - requires reference images
|
1332
|
+
if not references_dict:
|
1333
|
+
QMessageBox.warning(
|
1334
|
+
self,
|
1335
|
+
"No References Selected",
|
1336
|
+
"No reference images with valid annotations were selected. "
|
1337
|
+
"Please select at least one reference image."
|
1338
|
+
)
|
1339
|
+
return []
|
1340
|
+
|
1341
|
+
results = self._apply_model_using_images(inputs, references_dict)
|
1342
|
+
|
1343
|
+
return results
|
907
1344
|
|
908
1345
|
def _apply_sam(self, results_list, image_path):
|
909
1346
|
"""Apply SAM to the results if needed."""
|
@@ -1000,17 +1437,223 @@ class DeployGeneratorDialog(QDialog):
|
|
1000
1437
|
progress_bar.stop_progress()
|
1001
1438
|
progress_bar.close()
|
1002
1439
|
|
1440
|
+
def show_vpe(self):
|
1441
|
+
"""
|
1442
|
+
Show a visualization of the currently stored VPEs using PyQtGraph.
|
1443
|
+
"""
|
1444
|
+
try:
|
1445
|
+
vpes_with_source = []
|
1446
|
+
|
1447
|
+
# 1. Add any VPEs that were loaded from a file
|
1448
|
+
if self.imported_vpes:
|
1449
|
+
for vpe in self.imported_vpes:
|
1450
|
+
vpes_with_source.append((vpe, "Import"))
|
1451
|
+
|
1452
|
+
# 2. Add any pre-generated VPEs from reference images
|
1453
|
+
if self.reference_vpes:
|
1454
|
+
for vpe in self.reference_vpes:
|
1455
|
+
vpes_with_source.append((vpe, "Reference"))
|
1456
|
+
|
1457
|
+
# 3. Check if there is anything to visualize
|
1458
|
+
if not vpes_with_source:
|
1459
|
+
QMessageBox.warning(
|
1460
|
+
self,
|
1461
|
+
"No VPEs Available",
|
1462
|
+
"No VPEs available to visualize. Please load a VPE file or generate VPEs from references first."
|
1463
|
+
)
|
1464
|
+
return
|
1465
|
+
|
1466
|
+
# 4. Create the visualization dialog
|
1467
|
+
all_vpe_tensors = [vpe for vpe, source in vpes_with_source]
|
1468
|
+
averaged_vpe = torch.cat(all_vpe_tensors).mean(dim=0, keepdim=True)
|
1469
|
+
final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
|
1470
|
+
|
1471
|
+
QApplication.setOverrideCursor(Qt.WaitCursor)
|
1472
|
+
|
1473
|
+
dialog = VPEVisualizationDialog(vpes_with_source, final_vpe, self)
|
1474
|
+
QApplication.restoreOverrideCursor()
|
1475
|
+
|
1476
|
+
dialog.exec_()
|
1477
|
+
|
1478
|
+
except Exception as e:
|
1479
|
+
QApplication.restoreOverrideCursor()
|
1480
|
+
QMessageBox.critical(self, "Error Visualizing VPE", f"An error occurred: {str(e)}")
|
1481
|
+
|
1003
1482
|
def deactivate_model(self):
|
1004
1483
|
"""
|
1005
1484
|
Deactivate the currently loaded model and clean up resources.
|
1006
1485
|
"""
|
1007
1486
|
self.loaded_model = None
|
1008
1487
|
self.model_path = None
|
1009
|
-
|
1488
|
+
|
1489
|
+
# Clear all VPE-related data
|
1490
|
+
self.vpe_path_edit.clear()
|
1491
|
+
self.vpe_path = None
|
1492
|
+
self.vpe = None
|
1493
|
+
self.imported_vpes = []
|
1494
|
+
self.reference_vpes = []
|
1495
|
+
|
1496
|
+
# Clean up references
|
1010
1497
|
gc.collect()
|
1011
1498
|
torch.cuda.empty_cache()
|
1499
|
+
|
1012
1500
|
# Untoggle all tools
|
1013
1501
|
self.main_window.untoggle_all_tools()
|
1502
|
+
|
1014
1503
|
# Update status bar
|
1015
1504
|
self.status_bar.setText("No model loaded")
|
1016
|
-
QMessageBox.information(self, "Model Deactivated", "Model deactivated")
|
1505
|
+
QMessageBox.information(self, "Model Deactivated", "Model deactivated")
|
1506
|
+
|
1507
|
+
|
1508
|
+
class VPEVisualizationDialog(QDialog):
|
1509
|
+
"""
|
1510
|
+
Dialog for visualizing VPE embeddings in 2D space using PCA.
|
1511
|
+
"""
|
1512
|
+
def __init__(self, vpe_list_with_source, final_vpe=None, parent=None):
|
1513
|
+
"""
|
1514
|
+
Initialize the dialog with a list of VPE tensors and their sources.
|
1515
|
+
|
1516
|
+
Args:
|
1517
|
+
vpe_list_with_source (list): List of (VPE tensor, source_str) tuples
|
1518
|
+
final_vpe (torch.Tensor, optional): The final (averaged) VPE
|
1519
|
+
parent (QWidget, optional): Parent widget
|
1520
|
+
"""
|
1521
|
+
super().__init__(parent)
|
1522
|
+
self.setWindowTitle("VPE Visualization")
|
1523
|
+
self.resize(1000, 1000)
|
1524
|
+
|
1525
|
+
# Add a maximize button to the dialog's title bar
|
1526
|
+
self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)
|
1527
|
+
|
1528
|
+
# Store the VPEs and their sources
|
1529
|
+
self.vpe_list_with_source = vpe_list_with_source
|
1530
|
+
self.final_vpe = final_vpe
|
1531
|
+
|
1532
|
+
# Create the layout
|
1533
|
+
layout = QVBoxLayout(self)
|
1534
|
+
|
1535
|
+
# Create the plot widget
|
1536
|
+
self.plot_widget = pg.PlotWidget()
|
1537
|
+
self.plot_widget.setBackground('w') # White background
|
1538
|
+
self.plot_widget.setTitle("PCA Visualization of Visual Prompt Embeddings", color="#000000", size="10pt")
|
1539
|
+
self.plot_widget.showGrid(x=True, y=True, alpha=0.3)
|
1540
|
+
|
1541
|
+
# Add the plot widget to the layout
|
1542
|
+
layout.addWidget(self.plot_widget)
|
1543
|
+
|
1544
|
+
# Add spacing between plot_widget and info_label
|
1545
|
+
layout.addSpacing(20)
|
1546
|
+
|
1547
|
+
# Add information label at the bottom
|
1548
|
+
self.info_label = QLabel()
|
1549
|
+
self.info_label.setAlignment(Qt.AlignCenter)
|
1550
|
+
layout.addWidget(self.info_label)
|
1551
|
+
|
1552
|
+
# Create the button box
|
1553
|
+
button_box = QDialogButtonBox(QDialogButtonBox.Close)
|
1554
|
+
button_box.rejected.connect(self.reject)
|
1555
|
+
layout.addWidget(button_box)
|
1556
|
+
|
1557
|
+
# Visualize the VPEs
|
1558
|
+
self.visualize_vpes()
|
1559
|
+
|
1560
|
+
def visualize_vpes(self):
|
1561
|
+
"""
|
1562
|
+
Apply PCA to the VPE tensors and visualize them in 2D space.
|
1563
|
+
"""
|
1564
|
+
if not self.vpe_list_with_source:
|
1565
|
+
self.info_label.setText("No VPEs available to visualize.")
|
1566
|
+
return
|
1567
|
+
|
1568
|
+
# Convert tensors to numpy arrays for PCA, separating them from the source string
|
1569
|
+
vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, source in self.vpe_list_with_source]
|
1570
|
+
|
1571
|
+
# If final VPE is provided, add it to the arrays
|
1572
|
+
final_vpe_array = None
|
1573
|
+
if self.final_vpe is not None:
|
1574
|
+
final_vpe_array = self.final_vpe.detach().cpu().numpy().squeeze()
|
1575
|
+
all_vpes = np.vstack(vpe_arrays + [final_vpe_array])
|
1576
|
+
else:
|
1577
|
+
all_vpes = np.vstack(vpe_arrays)
|
1578
|
+
|
1579
|
+
# Apply PCA to reduce to 2 dimensions
|
1580
|
+
pca = PCA(n_components=2)
|
1581
|
+
vpes_2d = pca.fit_transform(all_vpes)
|
1582
|
+
|
1583
|
+
# Clear the plot
|
1584
|
+
self.plot_widget.clear()
|
1585
|
+
|
1586
|
+
# Generate random colors for individual VPEs
|
1587
|
+
num_vpes = len(vpe_arrays)
|
1588
|
+
colors = self.generate_distinct_colors(num_vpes)
|
1589
|
+
|
1590
|
+
# Create a legend with 3 columns to keep it compact
|
1591
|
+
legend = self.plot_widget.addLegend(colCount=3)
|
1592
|
+
|
1593
|
+
# Plot individual VPEs
|
1594
|
+
for i, (vpe_tuple, vpe_2d) in enumerate(zip(self.vpe_list_with_source, vpes_2d[:num_vpes])):
|
1595
|
+
source_char = 'I' if vpe_tuple[1] == 'Import' else 'R'
|
1596
|
+
color = pg.mkColor(colors[i])
|
1597
|
+
scatter = pg.ScatterPlotItem(
|
1598
|
+
x=[vpe_2d[0]],
|
1599
|
+
y=[vpe_2d[1]],
|
1600
|
+
brush=color,
|
1601
|
+
size=15,
|
1602
|
+
name=f"VPE {i+1} ({source_char})"
|
1603
|
+
)
|
1604
|
+
self.plot_widget.addItem(scatter)
|
1605
|
+
|
1606
|
+
# Plot the final (averaged) VPE if available
|
1607
|
+
if final_vpe_array is not None:
|
1608
|
+
final_vpe_2d = vpes_2d[-1]
|
1609
|
+
scatter = pg.ScatterPlotItem(
|
1610
|
+
x=[final_vpe_2d[0]],
|
1611
|
+
y=[final_vpe_2d[1]],
|
1612
|
+
brush=pg.mkBrush(color='r'),
|
1613
|
+
size=20,
|
1614
|
+
symbol='star',
|
1615
|
+
name="Final VPE"
|
1616
|
+
)
|
1617
|
+
self.plot_widget.addItem(scatter)
|
1618
|
+
|
1619
|
+
# Update the information label
|
1620
|
+
orig_dim = self.vpe_list_with_source[0][0].shape[-1]
|
1621
|
+
explained_variance = sum(pca.explained_variance_ratio_)
|
1622
|
+
self.info_label.setText(
|
1623
|
+
f"Original dimension: {orig_dim} → Reduced to 2D\n"
|
1624
|
+
f"Total explained variance: {explained_variance:.2%}\n"
|
1625
|
+
f"PC1: {pca.explained_variance_ratio_[0]:.2%} variance, "
|
1626
|
+
f"PC2: {pca.explained_variance_ratio_[1]:.2%} variance"
|
1627
|
+
)
|
1628
|
+
|
1629
|
+
def generate_distinct_colors(self, num_colors):
|
1630
|
+
"""
|
1631
|
+
Generate visually distinct colors by using evenly spaced hues
|
1632
|
+
with random saturation and value.
|
1633
|
+
|
1634
|
+
Args:
|
1635
|
+
num_colors (int): Number of colors to generate
|
1636
|
+
|
1637
|
+
Returns:
|
1638
|
+
list: List of color hex strings
|
1639
|
+
"""
|
1640
|
+
import random
|
1641
|
+
from colorsys import hsv_to_rgb
|
1642
|
+
|
1643
|
+
colors = []
|
1644
|
+
for i in range(num_colors):
|
1645
|
+
# Use golden ratio to space hues evenly
|
1646
|
+
hue = (i * 0.618033988749895) % 1.0
|
1647
|
+
# Random saturation between 0.6-1.0 (avoid too pale)
|
1648
|
+
saturation = random.uniform(0.6, 1.0)
|
1649
|
+
# Random value between 0.7-1.0 (avoid too dark)
|
1650
|
+
value = random.uniform(0.7, 1.0)
|
1651
|
+
|
1652
|
+
# Convert HSV to RGB (0-1 range)
|
1653
|
+
r, g, b = hsv_to_rgb(hue, saturation, value)
|
1654
|
+
|
1655
|
+
# Convert RGB to hex string
|
1656
|
+
hex_color = f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
|
1657
|
+
colors.append(hex_color)
|
1658
|
+
|
1659
|
+
return colors
|