coralnet-toolbox 0.0.73__py2.py3-none-any.whl → 0.0.75__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
  2. coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
  3. coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
  4. coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
  5. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
  6. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
  7. coralnet_toolbox/CoralNet/QtDownload.py +2 -1
  8. coralnet_toolbox/Explorer/QtDataItem.py +52 -22
  9. coralnet_toolbox/Explorer/QtExplorer.py +293 -1614
  10. coralnet_toolbox/Explorer/QtSettingsWidgets.py +203 -85
  11. coralnet_toolbox/Explorer/QtViewers.py +1568 -0
  12. coralnet_toolbox/Explorer/transformer_models.py +59 -0
  13. coralnet_toolbox/Explorer/yolo_models.py +112 -0
  14. coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
  15. coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
  16. coralnet_toolbox/IO/QtOpenProject.py +46 -78
  17. coralnet_toolbox/IO/QtSaveProject.py +18 -43
  18. coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +1 -1
  19. coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +253 -141
  20. coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
  21. coralnet_toolbox/MachineLearning/VideoInference/YOLO3D/run.py +102 -16
  22. coralnet_toolbox/QtAnnotationWindow.py +16 -10
  23. coralnet_toolbox/QtEventFilter.py +11 -0
  24. coralnet_toolbox/QtImageWindow.py +120 -75
  25. coralnet_toolbox/QtLabelWindow.py +13 -1
  26. coralnet_toolbox/QtMainWindow.py +5 -27
  27. coralnet_toolbox/QtProgressBar.py +52 -27
  28. coralnet_toolbox/Rasters/RasterTableModel.py +28 -8
  29. coralnet_toolbox/SAM/QtDeployGenerator.py +1 -4
  30. coralnet_toolbox/SAM/QtDeployPredictor.py +11 -3
  31. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +805 -162
  32. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +130 -151
  33. coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
  34. coralnet_toolbox/Tools/QtPolygonTool.py +42 -3
  35. coralnet_toolbox/Tools/QtRectangleTool.py +30 -0
  36. coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
  37. coralnet_toolbox/Tools/QtSAMTool.py +72 -50
  38. coralnet_toolbox/Tools/QtSeeAnythingTool.py +8 -5
  39. coralnet_toolbox/Tools/QtSelectTool.py +27 -3
  40. coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
  41. coralnet_toolbox/Tools/__init__.py +2 -0
  42. coralnet_toolbox/__init__.py +1 -1
  43. coralnet_toolbox/utilities.py +158 -47
  44. coralnet_toolbox-0.0.75.dist-info/METADATA +378 -0
  45. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/RECORD +49 -44
  46. coralnet_toolbox-0.0.73.dist-info/METADATA +0 -341
  47. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/WHEEL +0 -0
  48. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/entry_points.txt +0 -0
  49. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/licenses/LICENSE.txt +0 -0
  50. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/top_level.txt +0 -0
@@ -2,24 +2,24 @@ import warnings
2
2
 
3
3
  import os
4
4
  import gc
5
- import json
6
- import copy
7
5
 
8
6
  import numpy as np
7
+ from sklearn.decomposition import PCA
9
8
 
10
9
  import torch
11
10
  from torch.cuda import empty_cache
12
11
 
12
+ import pyqtgraph as pg
13
+ from pyqtgraph.Qt import QtGui
14
+
13
15
  from ultralytics import YOLOE
14
16
  from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
15
- from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
16
17
 
17
18
  from PyQt5.QtCore import Qt
18
- from PyQt5.QtGui import QColor
19
- from PyQt5.QtWidgets import (QMessageBox, QCheckBox, QVBoxLayout, QApplication,
19
+ from PyQt5.QtWidgets import (QMessageBox, QVBoxLayout, QApplication, QFileDialog,
20
20
  QLabel, QDialog, QDialogButtonBox, QGroupBox, QLineEdit,
21
21
  QFormLayout, QComboBox, QSpinBox, QSlider, QPushButton,
22
- QHBoxLayout, QWidget, QFileDialog)
22
+ QHBoxLayout)
23
23
 
24
24
  from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
25
25
  from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
@@ -79,10 +79,18 @@ class DeployGeneratorDialog(QDialog):
79
79
  self.class_mapping = {}
80
80
 
81
81
  # Reference image and label
82
- self.source_images = []
83
- self.source_label = None
84
- # Target images
85
- self.target_images = []
82
+ self.reference_label = None
83
+ self.reference_image_paths = []
84
+
85
+ # Visual Prompting Encoding (VPE) - legacy single tensor variable
86
+ self.vpe_path = None
87
+ self.vpe = None
88
+
89
+ # New separate VPE collections
90
+ self.imported_vpes = [] # VPEs loaded from file
91
+ self.reference_vpes = [] # VPEs created from reference images
92
+
93
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
86
94
 
87
95
  # Main vertical layout for the dialog
88
96
  self.layout = QVBoxLayout(self)
@@ -110,7 +118,7 @@ class DeployGeneratorDialog(QDialog):
110
118
  self.setup_status_layout()
111
119
 
112
120
  # Add layouts to the right panel
113
- self.setup_source_layout()
121
+ self.setup_reference_layout()
114
122
 
115
123
  # # Add a full ImageWindow instance for target image selection
116
124
  self.image_selection_window = ImageWindow(self.main_window)
@@ -158,6 +166,9 @@ class DeployGeneratorDialog(QDialog):
158
166
  iw.search_bar_images.setEnabled(False)
159
167
  iw.search_bar_labels.setEnabled(False)
160
168
  iw.top_k_combo.setEnabled(False)
169
+
170
+ # Hide the "Current" label as it is not applicable in this dialog
171
+ iw.current_image_index_label.hide()
161
172
 
162
173
  # Set Top-K to Top1
163
174
  iw.top_k_combo.setCurrentText("Top1")
@@ -190,7 +201,7 @@ class DeployGeneratorDialog(QDialog):
190
201
  self.sync_image_window()
191
202
  # This now populates the dropdown, restores the last selection,
192
203
  # and then manually triggers the image filtering.
193
- self.update_source_labels()
204
+ self.update_reference_labels()
194
205
 
195
206
  def sync_image_window(self):
196
207
  """
@@ -219,14 +230,23 @@ class DeployGeneratorDialog(QDialog):
219
230
  annotation that has BOTH the selected label AND a valid type (Polygon or Rectangle).
220
231
  This uses the fast, pre-computed cache for performance.
221
232
  """
222
- source_label = self.source_label_combo_box.currentData()
223
- source_label_text = self.source_label_combo_box.currentText()
233
+ # Persist the user's current highlights from the table model before filtering.
234
+ # This ensures that if the user highlights items and then changes the filter,
235
+ # their selection is not lost.
236
+ current_highlights = self.image_selection_window.table_model.get_highlighted_paths()
237
+ if current_highlights:
238
+ self.reference_image_paths = current_highlights
239
+
240
+ reference_label = self.reference_label_combo_box.currentData()
241
+ reference_label_text = self.reference_label_combo_box.currentText()
224
242
 
225
243
  # Store the last selected label for a better user experience on re-opening.
226
- if source_label_text:
227
- self.last_selected_label_code = source_label_text
244
+ if reference_label_text:
245
+ self.last_selected_label_code = reference_label_text
246
+ # Also store the reference label object itself
247
+ self.reference_label = reference_label
228
248
 
229
- if not source_label:
249
+ if not reference_label:
230
250
  # If no label is selected (e.g., during initialization), show an empty list.
231
251
  self.image_selection_window.table_model.set_filtered_paths([])
232
252
  return
@@ -235,7 +255,7 @@ class DeployGeneratorDialog(QDialog):
235
255
  final_filtered_paths = []
236
256
 
237
257
  valid_types = {"RectangleAnnotation", "PolygonAnnotation"}
238
- selected_label_code = source_label.short_label_code
258
+ selected_label_code = reference_label.short_label_code
239
259
 
240
260
  # Loop through paths and check the pre-computed map on each raster
241
261
  for path in all_paths:
@@ -257,40 +277,54 @@ class DeployGeneratorDialog(QDialog):
257
277
  # Directly set the filtered list in the table model.
258
278
  self.image_selection_window.table_model.set_filtered_paths(final_filtered_paths)
259
279
 
280
+ # Try to preserve any previous selections
281
+ if hasattr(self, 'reference_image_paths') and self.reference_image_paths:
282
+ # Find which of our previously selected paths are still in the filtered list
283
+ valid_selections = [p for p in self.reference_image_paths if p in final_filtered_paths]
284
+ if valid_selections:
285
+ # Highlight previously selected paths that are still valid
286
+ self.image_selection_window.table_model.set_highlighted_paths(valid_selections)
287
+
288
+ # After filtering, update all labels with the correct counts.
289
+ dialog_iw = self.image_selection_window
290
+ dialog_iw.update_image_count_label(len(final_filtered_paths)) # Set "Total" to filtered count
291
+ dialog_iw.update_current_image_index_label()
292
+ dialog_iw.update_highlighted_count_label()
293
+
260
294
  def accept(self):
261
295
  """
262
296
  Validate selections and store them before closing the dialog.
297
+ A prediction is valid if a model and label are selected, and the user
298
+ has provided either reference images or an imported VPE file.
263
299
  """
264
300
  if not self.loaded_model:
265
301
  QMessageBox.warning(self,
266
302
  "No Model",
267
303
  "A model must be loaded before running predictions.")
268
- super().reject()
269
304
  return
270
305
 
271
- current_label = self.source_label_combo_box.currentData()
272
- if not current_label:
306
+ # Set reference label from combo box
307
+ self.reference_label = self.reference_label_combo_box.currentData()
308
+ if not self.reference_label:
273
309
  QMessageBox.warning(self,
274
- "No Source Label",
275
- "A source label must be selected.")
276
- super().reject()
310
+ "No Reference Label",
311
+ "A reference label must be selected.")
277
312
  return
278
313
 
279
- # Get highlighted paths from our internal image window to use as targets
280
- highlighted_images = self.image_selection_window.table_model.get_highlighted_paths()
314
+ # Stash the current UI selection before validating.
315
+ self.update_stashed_references_from_ui()
316
+
317
+ # Check for a valid VPE source using the now-stashed list.
318
+ has_reference_images = bool(self.reference_image_paths)
319
+ has_imported_vpes = bool(self.imported_vpes)
281
320
 
282
- if not highlighted_images:
321
+ if not has_reference_images and not has_imported_vpes:
283
322
  QMessageBox.warning(self,
284
- "No Target Images",
285
- "You must highlight at least one image in the list to process.")
286
- super().reject()
323
+ "No VPE Source Provided",
324
+ "You must highlight at least one reference image or load a VPE file to proceed.")
287
325
  return
288
326
 
289
- # Store the selections for the caller to use after the dialog closes.
290
- self.source_label = current_label
291
- self.target_images = highlighted_images
292
-
293
- # Do not call self.predict here; just close the dialog and let the caller handle prediction
327
+ # If validation passes, close the dialog.
294
328
  super().accept()
295
329
 
296
330
  def setup_info_layout(self):
@@ -317,19 +351,19 @@ class DeployGeneratorDialog(QDialog):
317
351
  Setup the models layout with a simple model selection combo box (no tabs).
318
352
  """
319
353
  group_box = QGroupBox("Model Selection")
320
- layout = QVBoxLayout()
354
+ layout = QFormLayout()
321
355
 
322
356
  self.model_combo = QComboBox()
323
357
  self.model_combo.setEditable(True)
324
358
 
325
359
  # Define available models (keep the existing dictionary)
326
360
  self.models = [
327
- "yoloe-v8s-seg.pt",
328
- "yoloe-v8m-seg.pt",
329
- "yoloe-v8l-seg.pt",
330
- "yoloe-11s-seg.pt",
331
- "yoloe-11m-seg.pt",
332
- "yoloe-11l-seg.pt",
361
+ 'yoloe-v8s-seg.pt',
362
+ 'yoloe-v8m-seg.pt',
363
+ 'yoloe-v8l-seg.pt',
364
+ 'yoloe-11s-seg.pt',
365
+ 'yoloe-11m-seg.pt',
366
+ 'yoloe-11l-seg.pt',
333
367
  ]
334
368
 
335
369
  # Add all models to combo box
@@ -337,10 +371,19 @@ class DeployGeneratorDialog(QDialog):
337
371
  self.model_combo.addItem(model_name)
338
372
 
339
373
  # Set the default model
340
- self.model_combo.setCurrentText("yoloe-v8s-seg.pt")
374
+ self.model_combo.setCurrentIndex(self.models.index('yoloe-v8s-seg.pt'))
375
+ # Create a layout for the model selection
376
+ layout.addRow(QLabel("Models:"), self.model_combo)
377
+
378
+ # Add custom vpe file selection
379
+ self.vpe_path_edit = QLineEdit()
380
+ browse_button = QPushButton("Browse...")
381
+ browse_button.clicked.connect(self.browse_vpe_file)
341
382
 
342
- layout.addWidget(QLabel("Select Model:"))
343
- layout.addWidget(self.model_combo)
383
+ vpe_path_layout = QHBoxLayout()
384
+ vpe_path_layout.addWidget(self.vpe_path_edit)
385
+ vpe_path_layout.addWidget(browse_button)
386
+ layout.addRow("Custom VPE:", vpe_path_layout)
344
387
 
345
388
  group_box.setLayout(layout)
346
389
  self.left_panel.addWidget(group_box) # Add to left panel
@@ -373,7 +416,7 @@ class DeployGeneratorDialog(QDialog):
373
416
 
374
417
  # Image size control
375
418
  self.imgsz_spinbox = QSpinBox()
376
- self.imgsz_spinbox.setRange(512, 65536)
419
+ self.imgsz_spinbox.setRange(1024, 65536)
377
420
  self.imgsz_spinbox.setSingleStep(1024)
378
421
  self.imgsz_spinbox.setValue(self.imgsz)
379
422
  layout.addRow("Image Size (imgsz):", self.imgsz_spinbox)
@@ -445,17 +488,38 @@ class DeployGeneratorDialog(QDialog):
445
488
  Setup action buttons in a group box.
446
489
  """
447
490
  group_box = QGroupBox("Actions")
448
- layout = QHBoxLayout()
491
+ main_layout = QVBoxLayout()
449
492
 
493
+ # First row: Load and Deactivate buttons side by side
494
+ button_row = QHBoxLayout()
450
495
  load_button = QPushButton("Load Model")
451
496
  load_button.clicked.connect(self.load_model)
452
- layout.addWidget(load_button)
497
+ button_row.addWidget(load_button)
453
498
 
454
499
  deactivate_button = QPushButton("Deactivate Model")
455
500
  deactivate_button.clicked.connect(self.deactivate_model)
456
- layout.addWidget(deactivate_button)
501
+ button_row.addWidget(deactivate_button)
457
502
 
458
- group_box.setLayout(layout)
503
+ main_layout.addLayout(button_row)
504
+
505
+ # Second row: VPE action buttons
506
+ vpe_row = QHBoxLayout()
507
+
508
+ generate_vpe_button = QPushButton("Generate VPEs")
509
+ generate_vpe_button.clicked.connect(self.generate_vpes_from_references)
510
+ vpe_row.addWidget(generate_vpe_button)
511
+
512
+ save_vpe_button = QPushButton("Save VPE")
513
+ save_vpe_button.clicked.connect(self.save_vpe)
514
+ vpe_row.addWidget(save_vpe_button)
515
+
516
+ show_vpe_button = QPushButton("Show VPE")
517
+ show_vpe_button.clicked.connect(self.show_vpe)
518
+ vpe_row.addWidget(show_vpe_button)
519
+
520
+ main_layout.addLayout(vpe_row)
521
+
522
+ group_box.setLayout(main_layout)
459
523
  self.left_panel.addWidget(group_box) # Add to left panel
460
524
 
461
525
  def setup_status_layout(self):
@@ -471,18 +535,23 @@ class DeployGeneratorDialog(QDialog):
471
535
  group_box.setLayout(layout)
472
536
  self.left_panel.addWidget(group_box) # Add to left panel
473
537
 
474
- def setup_source_layout(self):
538
+ def setup_reference_layout(self):
475
539
  """
476
- Set up the layout with source label selection.
477
- The source image is implicitly the currently active image.
540
+ Set up the layout with reference label selection.
541
+ The reference image is implicitly the currently active image.
478
542
  """
479
- group_box = QGroupBox("Reference Label")
543
+ group_box = QGroupBox("Reference")
480
544
  layout = QFormLayout()
481
545
 
482
- # Create the source label combo box
483
- self.source_label_combo_box = QComboBox()
484
- self.source_label_combo_box.currentIndexChanged.connect(self.filter_images_by_label_and_type)
485
- layout.addRow("Reference Label:", self.source_label_combo_box)
546
+ # Create the reference label combo box
547
+ self.reference_label_combo_box = QComboBox()
548
+ self.reference_label_combo_box.currentIndexChanged.connect(self.filter_images_by_label_and_type)
549
+ layout.addRow("Reference Label:", self.reference_label_combo_box)
550
+
551
+ # Create a Reference model combobox (VPE, Images)
552
+ self.reference_method_combo_box = QComboBox()
553
+ self.reference_method_combo_box.addItems(["VPE", "Images"])
554
+ layout.addRow("Reference Method:", self.reference_method_combo_box)
486
555
 
487
556
  group_box.setLayout(layout)
488
557
  self.right_panel.addWidget(group_box) # Add to right panel
@@ -543,6 +612,10 @@ class DeployGeneratorDialog(QDialog):
543
612
  self.area_thresh_max = max_val / 100.0
544
613
  self.main_window.update_area_thresh(self.area_thresh_min, self.area_thresh_max)
545
614
  self.area_threshold_label.setText(f"{self.area_thresh_min:.2f} - {self.area_thresh_max:.2f}")
615
+
616
+ def update_stashed_references_from_ui(self):
617
+ """Updates the internal reference path list from the current UI selection."""
618
+ self.reference_image_paths = self.image_selection_window.table_model.get_highlighted_paths()
546
619
 
547
620
  def get_max_detections(self):
548
621
  """Get the maximum number of detections to return."""
@@ -603,54 +676,44 @@ class DeployGeneratorDialog(QDialog):
603
676
  if self.loaded_model:
604
677
  self.deactivate_model()
605
678
 
606
- def update_source_labels(self):
679
+ def update_reference_labels(self):
607
680
  """
608
- Updates the source label combo box with labels that are associated with
609
- valid reference annotations (Polygons or Rectangles), using the fast cache.
681
+ Updates the reference label combo box with ALL available project labels.
682
+ This dropdown now serves as the "Output Label" for all predictions.
683
+ The "Review" label with id "-1" is excluded.
610
684
  """
611
- self.source_label_combo_box.blockSignals(True)
685
+ self.reference_label_combo_box.blockSignals(True)
612
686
 
613
687
  try:
614
- self.source_label_combo_box.clear()
615
-
616
- dialog_manager = self.image_selection_window.raster_manager
617
- valid_types = {"RectangleAnnotation", "PolygonAnnotation"}
618
- valid_labels = set() # This will store the full Label objects
619
-
620
- # Create a lookup map to get full label objects from their codes
621
- all_project_labels = {lbl.short_label_code: lbl for lbl in self.main_window.label_window.labels}
622
-
623
- # Use the cached data to find all labels that have valid reference types.
624
- for raster in dialog_manager.rasters.values():
625
- # raster.label_to_types_map is like: {'coral': {'Point'}, 'rock': {'Polygon'}}
626
- for label_code, types_for_label in raster.label_to_types_map.items():
627
- # Check if the set of types for this specific label
628
- # has any overlap with our valid reference types.
629
- if not valid_types.isdisjoint(types_for_label):
630
- # This label is a valid reference label.
631
- # Add its full Label object to our set of valid labels.
632
- if label_code in all_project_labels:
633
- valid_labels.add(all_project_labels[label_code])
688
+ self.reference_label_combo_box.clear()
689
+
690
+ # Get all labels from the main label window
691
+ all_project_labels = self.main_window.label_window.labels
692
+
693
+ # Filter out the special "Review" label and create a list of valid labels
694
+ valid_labels = [
695
+ label_obj for label_obj in all_project_labels
696
+ if not (label_obj.short_label_code == "Review" and str(label_obj.id) == "-1")
697
+ ]
634
698
 
635
699
  # Add the valid labels to the combo box, sorted alphabetically.
636
- sorted_valid_labels = sorted(list(valid_labels), key=lambda x: x.short_label_code)
700
+ sorted_valid_labels = sorted(valid_labels, key=lambda x: x.short_label_code)
637
701
  for label_obj in sorted_valid_labels:
638
- self.source_label_combo_box.addItem(label_obj.short_label_code, label_obj)
702
+ self.reference_label_combo_box.addItem(label_obj.short_label_code, label_obj)
639
703
 
640
704
  # Restore the last selected label if it's still present in the list.
641
705
  if self.last_selected_label_code:
642
- index = self.source_label_combo_box.findText(self.last_selected_label_code)
706
+ index = self.reference_label_combo_box.findText(self.last_selected_label_code)
643
707
  if index != -1:
644
- self.source_label_combo_box.setCurrentIndex(index)
708
+ self.reference_label_combo_box.setCurrentIndex(index)
645
709
  finally:
646
- self.source_label_combo_box.blockSignals(False)
710
+ self.reference_label_combo_box.blockSignals(False)
647
711
 
648
- # Manually trigger the filtering now that the combo box is stable.
712
+ # Manually trigger the image filtering now that the combo box is stable.
713
+ # This will still filter the image list to help find references if needed.
649
714
  self.filter_images_by_label_and_type()
650
715
 
651
- return True
652
-
653
- def get_source_annotations(self, reference_label, reference_image_path):
716
+ def get_reference_annotations(self, reference_label, reference_image_path):
654
717
  """
655
718
  Return a list of bboxes and masks for a specific image
656
719
  belonging to the selected label.
@@ -666,22 +729,161 @@ class DeployGeneratorDialog(QDialog):
666
729
  annotations = self.annotation_window.get_image_annotations(reference_image_path)
667
730
 
668
731
  # Filter annotations by the provided label
669
- source_bboxes = []
670
- source_masks = []
732
+ reference_bboxes = []
733
+ reference_masks = []
671
734
  for annotation in annotations:
672
735
  if annotation.label.short_label_code == reference_label.short_label_code:
673
736
  if isinstance(annotation, (PolygonAnnotation, RectangleAnnotation)):
674
737
  bbox = annotation.cropped_bbox
675
- source_bboxes.append(bbox)
738
+ reference_bboxes.append(bbox)
676
739
  if isinstance(annotation, PolygonAnnotation):
677
740
  points = np.array([[p.x(), p.y()] for p in annotation.points])
678
- source_masks.append(points)
741
+ reference_masks.append(points)
679
742
  elif isinstance(annotation, RectangleAnnotation):
680
743
  x1, y1, x2, y2 = bbox
681
744
  rect_points = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
682
- source_masks.append(rect_points)
745
+ reference_masks.append(rect_points)
683
746
 
684
- return np.array(source_bboxes), source_masks
747
+ return np.array(reference_bboxes), reference_masks
748
+
749
+ def browse_vpe_file(self):
750
+ """
751
+ Open a file dialog to browse for a VPE file and load it.
752
+ Stores imported VPEs separately from reference-generated VPEs.
753
+ """
754
+ file_path, _ = QFileDialog.getOpenFileName(
755
+ self,
756
+ "Select Visual Prompt Encoding (VPE) File",
757
+ "",
758
+ "VPE Files (*.pt);;All Files (*)"
759
+ )
760
+
761
+ if not file_path:
762
+ return
763
+
764
+ self.vpe_path_edit.setText(file_path)
765
+ self.vpe_path = file_path
766
+
767
+ try:
768
+ # Load the VPE file
769
+ loaded_data = torch.load(file_path)
770
+
771
+ # TODO Move tensors to the appropriate device
772
+ # device = self.main_window.device
773
+
774
+ # Check format type and handle appropriately
775
+ if isinstance(loaded_data, list):
776
+ # New format: list of VPE tensors
777
+ self.imported_vpes = [vpe.to(self.device) for vpe in loaded_data]
778
+ vpe_count = len(self.imported_vpes)
779
+ self.status_bar.setText(f"Loaded {vpe_count} VPE tensors from file")
780
+
781
+ elif isinstance(loaded_data, torch.Tensor):
782
+ # Legacy format: single tensor - convert to list for consistency
783
+ loaded_vpe = loaded_data.to(self.device)
784
+ # Store as a single-item list
785
+ self.imported_vpes = [loaded_vpe]
786
+ self.status_bar.setText("Loaded 1 VPE tensor from file (legacy format)")
787
+
788
+ else:
789
+ # Invalid format
790
+ self.imported_vpes = []
791
+ self.status_bar.setText("Invalid VPE file format")
792
+ QMessageBox.warning(
793
+ self,
794
+ "Invalid VPE",
795
+ "The file does not appear to be a valid VPE format."
796
+ )
797
+ # Clear the VPE path edit field
798
+ self.vpe_path_edit.clear()
799
+
800
+ # For backward compatibility - set self.vpe to the average of imported VPEs
801
+ # This ensures older code paths still work
802
+ if self.imported_vpes:
803
+ combined_vpe = torch.cat(self.imported_vpes).mean(dim=0, keepdim=True)
804
+ self.vpe = torch.nn.functional.normalize(combined_vpe, p=2, dim=-1)
805
+
806
+ except Exception as e:
807
+ self.imported_vpes = []
808
+ self.vpe = None
809
+ self.status_bar.setText(f"Error loading VPE: {str(e)}")
810
+ QMessageBox.critical(
811
+ self,
812
+ "Error Loading VPE",
813
+ f"Failed to load VPE file: {str(e)}"
814
+ )
815
+
816
+ def save_vpe(self):
817
+ """
818
+ Saves the combined collection of VPEs (imported and pre-generated from references) to disk.
819
+ """
820
+ QApplication.setOverrideCursor(Qt.WaitCursor)
821
+
822
+ try:
823
+ # Create a list to hold all VPEs to be saved
824
+ all_vpes = []
825
+
826
+ # Add imported VPEs if available
827
+ if self.imported_vpes:
828
+ all_vpes.extend(self.imported_vpes)
829
+
830
+ # Add pre-generated reference VPEs if available
831
+ if self.reference_vpes:
832
+ all_vpes.extend(self.reference_vpes)
833
+
834
+ # Check if we have any VPEs to save
835
+ if not all_vpes:
836
+ QApplication.restoreOverrideCursor()
837
+ QMessageBox.warning(
838
+ self,
839
+ "No VPEs Available",
840
+ "No VPEs available to save. "
841
+ "Please either load a VPE file or generate VPEs from reference images first."
842
+ )
843
+ return
844
+
845
+ QApplication.restoreOverrideCursor()
846
+
847
+ file_path, _ = QFileDialog.getSaveFileName(
848
+ self,
849
+ "Save VPE Collection",
850
+ "",
851
+ "PyTorch Tensor (*.pt);;All Files (*)"
852
+ )
853
+
854
+ if not file_path:
855
+ return
856
+
857
+ QApplication.setOverrideCursor(Qt.WaitCursor)
858
+
859
+ if not file_path.endswith('.pt'):
860
+ file_path += '.pt'
861
+
862
+ vpe_list_cpu = [vpe.cpu() for vpe in all_vpes]
863
+
864
+ torch.save(vpe_list_cpu, file_path)
865
+
866
+ self.status_bar.setText(f"Saved {len(all_vpes)} VPE tensors to {os.path.basename(file_path)}")
867
+
868
+ QApplication.restoreOverrideCursor()
869
+ QMessageBox.information(
870
+ self,
871
+ "VPE Saved",
872
+ f"Saved {len(all_vpes)} VPE tensors to {file_path}"
873
+ )
874
+
875
+ except Exception as e:
876
+ QApplication.restoreOverrideCursor()
877
+ QMessageBox.critical(
878
+ self,
879
+ "Error Saving VPE",
880
+ f"Failed to save VPE: {str(e)}"
881
+ )
882
+ finally:
883
+ try:
884
+ QApplication.restoreOverrideCursor()
885
+ except:
886
+ pass
685
887
 
686
888
  def load_model(self):
687
889
  """
@@ -692,40 +894,32 @@ class DeployGeneratorDialog(QDialog):
692
894
  progress_bar.show()
693
895
 
694
896
  try:
695
- # Get selected model path and download weights if needed
696
- self.model_path = self.model_combo.currentText()
697
-
698
- # Load model using registry
699
- self.loaded_model = YOLOE(self.model_path).to(self.main_window.device)
700
-
701
- # Create a dummy visual dictionary
702
- visuals = dict(
703
- bboxes=np.array(
704
- [
705
- [120, 425, 160, 445],
706
- ],
707
- ),
708
- cls=np.array(
709
- np.zeros(1),
710
- ),
711
- )
897
+ # Load the model using reload_model method
898
+ self.reload_model()
712
899
 
713
- # Run a dummy prediction to load the model
714
- self.loaded_model.predict(
715
- np.zeros((640, 640, 3), dtype=np.uint8),
716
- visual_prompts=visuals.copy(),
717
- predictor=YOLOEVPDetectPredictor,
718
- imgsz=640,
719
- conf=0.99,
720
- )
900
+ # Calculate total number of VPEs from both sources
901
+ total_vpes = len(self.imported_vpes) + len(self.reference_vpes)
902
+
903
+ if total_vpes > 0:
904
+ if self.imported_vpes and self.reference_vpes:
905
+ message = f"Model loaded with {len(self.imported_vpes)} imported VPEs "
906
+ message += f"and {len(self.reference_vpes)} reference VPEs"
907
+ elif self.imported_vpes:
908
+ message = f"Model loaded with {len(self.imported_vpes)} imported VPEs"
909
+ else:
910
+ message = f"Model loaded with {len(self.reference_vpes)} reference VPEs"
911
+
912
+ self.status_bar.setText(message)
913
+ else:
914
+ message = "Model loaded with default VPE"
915
+ self.status_bar.setText("Model loaded with default VPE")
721
916
 
917
+ # Finish progress bar
722
918
  progress_bar.finish_progress()
723
- self.status_bar.setText("Model loaded")
724
- QMessageBox.information(self.annotation_window,
725
- "Model Loaded",
726
- "Model loaded successfully")
919
+ QMessageBox.information(self.annotation_window, "Model Loaded", message)
727
920
 
728
921
  except Exception as e:
922
+ self.loaded_model = None
729
923
  QMessageBox.critical(self.annotation_window,
730
924
  "Error Loading Model",
731
925
  f"Error loading model: {e}")
@@ -737,7 +931,52 @@ class DeployGeneratorDialog(QDialog):
737
931
  progress_bar.stop_progress()
738
932
  progress_bar.close()
739
933
  progress_bar = None
934
+
935
+ def reload_model(self):
936
+ """
937
+ Subset of the load_model method. This is needed when additional
938
+ reference images and annotations (i.e., VPEs) are added (we have
939
+ to re-load the model each time).
940
+
941
+ This method also ensures that we stash the currently highlighted reference
942
+ image paths before reloading, so they're available for predictions
943
+ even if the user switches the active image in the main window.
944
+ """
945
+ self.loaded_model = None
946
+
947
+ # Get selected model path and download weights if needed
948
+ self.model_path = self.model_combo.currentText()
949
+
950
+ # Load model using registry
951
+ self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device) # TODO
952
+
953
+ # Create a dummy visual dictionary for standard model loading
954
+ visual_prompts = dict(
955
+ bboxes=np.array(
956
+ [
957
+ [120, 425, 160, 445], # Random box
958
+ ],
959
+ ),
960
+ cls=np.array(
961
+ np.zeros(1),
962
+ ),
963
+ )
964
+
965
+ # Run a dummy prediction to load the model
966
+ self.loaded_model.predict(
967
+ np.zeros((640, 640, 3), dtype=np.uint8),
968
+ visual_prompts=visual_prompts.copy(), # This needs to happen to properly initialize the predictor
969
+ predictor=YOLOEVPSegPredictor, # This also needs to be SegPredictor, no matter what
970
+ imgsz=640,
971
+ conf=0.99,
972
+ )
740
973
 
974
+ # If a VPE file was loaded, use it with the model after the dummy prediction
975
+ if self.vpe is not None and isinstance(self.vpe, torch.Tensor):
976
+ # Directly set the final tensor as the prompt for the predictor
977
+ self.loaded_model.is_fused = lambda: False
978
+ self.loaded_model.set_classes(["object0"], self.vpe)
979
+
741
980
  def predict(self, image_paths=None):
742
981
  """
743
982
  Make predictions on the given image paths using the loaded model.
@@ -745,11 +984,11 @@ class DeployGeneratorDialog(QDialog):
745
984
  Args:
746
985
  image_paths: List of image paths to process. If None, uses the current image.
747
986
  """
748
- if not self.loaded_model or not self.source_label:
987
+ if not self.loaded_model or not self.reference_label:
749
988
  return
750
989
 
751
990
  # Update class mapping with the selected reference label
752
- self.class_mapping = {0: self.source_label}
991
+ self.class_mapping = {0: self.reference_label}
753
992
 
754
993
  # Create a results processor
755
994
  results_processor = ResultsProcessor(
@@ -809,32 +1048,33 @@ class DeployGeneratorDialog(QDialog):
809
1048
 
810
1049
  return work_areas_data
811
1050
 
812
- def _apply_model(self, inputs):
813
- """
814
- Apply the model to the target inputs, using each highlighted source
815
- image as an individual reference for a separate prediction run.
1051
+ def _get_references(self):
816
1052
  """
817
- # Update the model with user parameters
818
- self.loaded_model.conf = self.main_window.get_uncertainty_thresh()
819
- self.loaded_model.iou = self.main_window.get_iou_thresh()
820
- self.loaded_model.max_det = self.get_max_detections()
1053
+ Get the reference annotations using the stashed list of reference images
1054
+ that was saved when the user accepted the dialog.
821
1055
 
822
- # NOTE: self.target_images contains the reference images highlighted in the dialog
823
- reference_image_paths = self.target_images
824
-
825
- if not reference_image_paths:
826
- QMessageBox.warning(self,
827
- "No Reference Images",
828
- "You must highlight at least one reference image.")
829
- return []
830
-
831
- # Get the selected reference label from the stored variable
832
- source_label = self.source_label
1056
+ Returns:
1057
+ dict: Dictionary mapping image paths to annotation data, or empty dict if no valid references.
1058
+ """
1059
+ # Use the "stashed" list of paths. Do NOT query the table_model again,
1060
+ # as the UI's highlight state may have been cleared by other actions.
1061
+ reference_paths = self.reference_image_paths
1062
+
1063
+ if not reference_paths:
1064
+ print("No reference image paths were stashed to use for prediction.")
1065
+ return {}
1066
+
1067
+ # Get the reference label that was also stashed
1068
+ reference_label = self.reference_label
1069
+ if not reference_label:
1070
+ # This check is a safeguard; the accept() method should prevent this.
1071
+ print("No reference label was selected.")
1072
+ return {}
833
1073
 
834
- # Create a dictionary of reference annotations, with image path as the key.
1074
+ # Create a dictionary of reference annotations from the stashed paths
835
1075
  reference_annotations_dict = {}
836
- for path in reference_image_paths:
837
- bboxes, masks = self.get_source_annotations(source_label, path)
1076
+ for path in reference_paths:
1077
+ bboxes, masks = self.get_reference_annotations(reference_label, path)
838
1078
  if bboxes.size > 0:
839
1079
  reference_annotations_dict[path] = {
840
1080
  'bboxes': bboxes,
@@ -842,37 +1082,48 @@ class DeployGeneratorDialog(QDialog):
842
1082
  'cls': np.zeros(len(bboxes))
843
1083
  }
844
1084
 
845
- # Set the task
846
- self.task = self.use_task_dropdown.currentText()
847
- predictor = YOLOEVPSegPredictor if self.task == "segment" else YOLOEVPDetectPredictor
1085
+ return reference_annotations_dict
1086
+
1087
+ def _apply_model_using_images(self, inputs, reference_dict):
1088
+ """
1089
+ Apply the model using the provided images and reference annotations (dict). This method
1090
+ loops through each reference image using its annotations; we then aggregate
1091
+ all the results together. Less efficient, but potentially more accurate.
1092
+
1093
+ Args:
1094
+ inputs (list): List of input images.
1095
+ reference_dict (dict): Dictionary containing reference annotations for each image.
848
1096
 
1097
+ Returns:
1098
+ list: List of prediction results.
1099
+ """
849
1100
  # Create a progress bar for iterating through reference images
850
1101
  QApplication.setOverrideCursor(Qt.WaitCursor)
851
1102
  progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
852
1103
  progress_bar.show()
853
- progress_bar.start_progress(len(reference_annotations_dict))
854
-
1104
+ progress_bar.start_progress(len(reference_dict))
1105
+
855
1106
  results_list = []
856
1107
  # The 'inputs' list contains work areas from the single target image.
857
1108
  # We will predict on the first work area/full image.
858
1109
  input_image = inputs[0]
859
1110
 
860
1111
  # Iterate through each reference image and its annotations
861
- for ref_path, ref_annotations in reference_annotations_dict.items():
1112
+ for ref_path, ref_annotations in reference_dict.items():
862
1113
  # The 'refer_image' parameter is the path to the current reference image
863
1114
  # The 'visual_prompts' are the annotations from that same reference image
864
- visuals = {
1115
+ visual_prompts = {
865
1116
  'bboxes': ref_annotations['bboxes'],
866
1117
  'cls': ref_annotations['cls'],
867
1118
  }
868
1119
  if self.task == 'segment':
869
- visuals['masks'] = ref_annotations['masks']
1120
+ visual_prompts['masks'] = ref_annotations['masks']
870
1121
 
871
1122
  # Make predictions on the target using the current reference
872
1123
  results = self.loaded_model.predict(input_image,
873
1124
  refer_image=ref_path,
874
- visual_prompts=visuals,
875
- predictor=predictor,
1125
+ visual_prompts=visual_prompts,
1126
+ predictor=YOLOEVPSegPredictor, # TODO: is this necessary here?
876
1127
  imgsz=self.imgsz_spinbox.value(),
877
1128
  conf=self.main_window.get_uncertainty_thresh(),
878
1129
  iou=self.main_window.get_iou_thresh(),
@@ -904,6 +1155,192 @@ class DeployGeneratorDialog(QDialog):
904
1155
  return []
905
1156
 
906
1157
  return [[combined_results]]
1158
+
1159
+ def generate_vpes_from_references(self):
1160
+ """
1161
+ Calculates VPEs from the currently highlighted reference images and
1162
+ stores them in self.reference_vpes, overwriting any previous ones.
1163
+ """
1164
+ if not self.loaded_model:
1165
+ QMessageBox.warning(self, "No Model Loaded", "A model must be loaded before generating VPEs.")
1166
+ return
1167
+
1168
+ # Always sync with the live UI selection before generating.
1169
+ self.update_stashed_references_from_ui()
1170
+ references_dict = self._get_references()
1171
+
1172
+ if not references_dict:
1173
+ QMessageBox.information(
1174
+ self,
1175
+ "No References Selected",
1176
+ "Please highlight one or more reference images in the table to generate VPEs."
1177
+ )
1178
+ return
1179
+
1180
+ QApplication.setOverrideCursor(Qt.WaitCursor)
1181
+ progress_bar = ProgressBar(self, title="Generating VPEs")
1182
+ progress_bar.show()
1183
+
1184
+ try:
1185
+ # Make progress bar busy
1186
+ progress_bar.set_busy_mode("Generating VPEs...")
1187
+ # Reload the model to ensure a clean state for VPE generation
1188
+ self.reload_model()
1189
+
1190
+ # The references_to_vpe method will calculate and update self.reference_vpes
1191
+ new_vpes = self.references_to_vpe(references_dict, update_reference_vpes=True)
1192
+
1193
+ if new_vpes:
1194
+ num_vpes = len(new_vpes)
1195
+ num_images = len(references_dict)
1196
+ message = f"Successfully generated {num_vpes} VPEs from {num_images} reference image(s)."
1197
+ self.status_bar.setText(message)
1198
+ QMessageBox.information(self, "VPEs Generated", message)
1199
+ else:
1200
+ message = "Could not generate VPEs. Ensure annotations are valid."
1201
+ self.status_bar.setText(message)
1202
+ QMessageBox.warning(self, "Generation Failed", message)
1203
+
1204
+ except Exception as e:
1205
+ QMessageBox.critical(self, "Error Generating VPEs", f"An unexpected error occurred: {str(e)}")
1206
+ self.status_bar.setText("Error during VPE generation.")
1207
+ finally:
1208
+ QApplication.restoreOverrideCursor()
1209
+ progress_bar.stop_progress()
1210
+ progress_bar.close()
1211
+
1212
+ def references_to_vpe(self, reference_dict, update_reference_vpes=True):
1213
+ """
1214
+ Converts the contents of a reference dictionary to VPEs (Visual Prompt Embeddings).
1215
+ Reference dictionaries contain information about the visual prompts for each reference image:
1216
+ dict[image_path]: {bboxes, masks, cls}
1217
+
1218
+ Args:
1219
+ reference_dict (dict): The reference dictionary containing visual prompts for each image.
1220
+ update_reference_vpes (bool): Whether to update self.reference_vpes with the results.
1221
+
1222
+ Returns:
1223
+ list: List of individual VPE tensors (normalized), or None if empty reference_dict
1224
+ """
1225
+ # Check if the reference dictionary is empty
1226
+ if not reference_dict:
1227
+ return None
1228
+
1229
+ # Create a list to hold the individual VPE tensors
1230
+ vpe_list = []
1231
+
1232
+ for ref_path, ref_annotations in reference_dict.items():
1233
+ # Set the prompts to the model predictor
1234
+ self.loaded_model.predictor.set_prompts(ref_annotations)
1235
+
1236
+ # Get the VPE from the model
1237
+ vpe = self.loaded_model.predictor.get_vpe(ref_path)
1238
+
1239
+ # Normalize individual VPE
1240
+ vpe_normalized = torch.nn.functional.normalize(vpe, p=2, dim=-1)
1241
+ vpe_list.append(vpe_normalized)
1242
+
1243
+ # Check if we have any valid VPEs
1244
+ if not vpe_list:
1245
+ return None
1246
+
1247
+ # Update the reference_vpes list if requested
1248
+ if update_reference_vpes:
1249
+ self.reference_vpes = vpe_list
1250
+
1251
+ return vpe_list
1252
+
1253
+ def _apply_model_using_vpe(self, inputs):
1254
+ """
1255
+ Apply the model to the inputs using pre-calculated VPEs from imported files
1256
+ and/or generated from reference annotations.
1257
+
1258
+ Args:
1259
+ inputs (list): List of input images.
1260
+
1261
+ Returns:
1262
+ list: List of prediction results.
1263
+ """
1264
+ # First reload the model to clear any cached data
1265
+ self.reload_model()
1266
+
1267
+ # Initialize combined_vpes list
1268
+ combined_vpes = []
1269
+
1270
+ # Add imported VPEs if available
1271
+ if self.imported_vpes:
1272
+ combined_vpes.extend(self.imported_vpes)
1273
+
1274
+ # Add pre-generated reference VPEs if available
1275
+ if self.reference_vpes:
1276
+ combined_vpes.extend(self.reference_vpes)
1277
+
1278
+ # Check if we have any VPEs to use
1279
+ if not combined_vpes:
1280
+ QMessageBox.warning(
1281
+ self,
1282
+ "No VPEs Available",
1283
+ "No VPEs are available for prediction. "
1284
+ "Please either load a VPE file or generate VPEs from reference images."
1285
+ )
1286
+ return []
1287
+
1288
+ # Average all the VPEs together to create a final VPE tensor
1289
+ averaged_vpe = torch.cat(combined_vpes).mean(dim=0, keepdim=True)
1290
+ final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
1291
+
1292
+ # For backward compatibility, update self.vpe
1293
+ self.vpe = final_vpe
1294
+
1295
+ # Set the final VPE to the model
1296
+ self.loaded_model.is_fused = lambda: False
1297
+ self.loaded_model.set_classes(["object0"], final_vpe)
1298
+
1299
+ # Make predictions on the target using the averaged VPE
1300
+ results = self.loaded_model.predict(inputs[0],
1301
+ visual_prompts=[],
1302
+ imgsz=self.imgsz_spinbox.value(),
1303
+ conf=self.main_window.get_uncertainty_thresh(),
1304
+ iou=self.main_window.get_iou_thresh(),
1305
+ max_det=self.get_max_detections(),
1306
+ retina_masks=self.task == "segment")
1307
+
1308
+ return [results]
1309
+
1310
+ def _apply_model(self, inputs):
1311
+ """
1312
+ Apply the model to the target inputs. This method handles both image-based
1313
+ references and VPE-based references.
1314
+ """
1315
+ # Update the model with user parameters
1316
+ self.task = self.use_task_dropdown.currentText()
1317
+
1318
+ self.loaded_model.conf = self.main_window.get_uncertainty_thresh()
1319
+ self.loaded_model.iou = self.main_window.get_iou_thresh()
1320
+ self.loaded_model.max_det = self.get_max_detections()
1321
+
1322
+ # Get the reference information for the currently selected rows
1323
+ references_dict = self._get_references()
1324
+
1325
+ # Check if the user is using VPE or Reference Images
1326
+ if self.reference_method_combo_box.currentText() == "VPE":
1327
+ # The VPE method will use pre-loaded/pre-generated VPEs.
1328
+ # The internal checks for whether any VPEs exist are now inside _apply_model_using_vpe.
1329
+ results = self._apply_model_using_vpe(inputs)
1330
+ else:
1331
+ # Use Reference Images method - requires reference images
1332
+ if not references_dict:
1333
+ QMessageBox.warning(
1334
+ self,
1335
+ "No References Selected",
1336
+ "No reference images with valid annotations were selected. "
1337
+ "Please select at least one reference image."
1338
+ )
1339
+ return []
1340
+
1341
+ results = self._apply_model_using_images(inputs, references_dict)
1342
+
1343
+ return results
907
1344
 
908
1345
  def _apply_sam(self, results_list, image_path):
909
1346
  """Apply SAM to the results if needed."""
@@ -1000,17 +1437,223 @@ class DeployGeneratorDialog(QDialog):
1000
1437
  progress_bar.stop_progress()
1001
1438
  progress_bar.close()
1002
1439
 
1440
+ def show_vpe(self):
1441
+ """
1442
+ Show a visualization of the currently stored VPEs using PyQtGraph.
1443
+ """
1444
+ try:
1445
+ vpes_with_source = []
1446
+
1447
+ # 1. Add any VPEs that were loaded from a file
1448
+ if self.imported_vpes:
1449
+ for vpe in self.imported_vpes:
1450
+ vpes_with_source.append((vpe, "Import"))
1451
+
1452
+ # 2. Add any pre-generated VPEs from reference images
1453
+ if self.reference_vpes:
1454
+ for vpe in self.reference_vpes:
1455
+ vpes_with_source.append((vpe, "Reference"))
1456
+
1457
+ # 3. Check if there is anything to visualize
1458
+ if not vpes_with_source:
1459
+ QMessageBox.warning(
1460
+ self,
1461
+ "No VPEs Available",
1462
+ "No VPEs available to visualize. Please load a VPE file or generate VPEs from references first."
1463
+ )
1464
+ return
1465
+
1466
+ # 4. Create the visualization dialog
1467
+ all_vpe_tensors = [vpe for vpe, source in vpes_with_source]
1468
+ averaged_vpe = torch.cat(all_vpe_tensors).mean(dim=0, keepdim=True)
1469
+ final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
1470
+
1471
+ QApplication.setOverrideCursor(Qt.WaitCursor)
1472
+
1473
+ dialog = VPEVisualizationDialog(vpes_with_source, final_vpe, self)
1474
+ QApplication.restoreOverrideCursor()
1475
+
1476
+ dialog.exec_()
1477
+
1478
+ except Exception as e:
1479
+ QApplication.restoreOverrideCursor()
1480
+ QMessageBox.critical(self, "Error Visualizing VPE", f"An error occurred: {str(e)}")
1481
+
1003
1482
  def deactivate_model(self):
1004
1483
  """
1005
1484
  Deactivate the currently loaded model and clean up resources.
1006
1485
  """
1007
1486
  self.loaded_model = None
1008
1487
  self.model_path = None
1009
- # Clean up resources
1488
+
1489
+ # Clear all VPE-related data
1490
+ self.vpe_path_edit.clear()
1491
+ self.vpe_path = None
1492
+ self.vpe = None
1493
+ self.imported_vpes = []
1494
+ self.reference_vpes = []
1495
+
1496
+ # Clean up references
1010
1497
  gc.collect()
1011
1498
  torch.cuda.empty_cache()
1499
+
1012
1500
  # Untoggle all tools
1013
1501
  self.main_window.untoggle_all_tools()
1502
+
1014
1503
  # Update status bar
1015
1504
  self.status_bar.setText("No model loaded")
1016
- QMessageBox.information(self, "Model Deactivated", "Model deactivated")
1505
+ QMessageBox.information(self, "Model Deactivated", "Model deactivated")
1506
+
1507
+
1508
+ class VPEVisualizationDialog(QDialog):
1509
+ """
1510
+ Dialog for visualizing VPE embeddings in 2D space using PCA.
1511
+ """
1512
+ def __init__(self, vpe_list_with_source, final_vpe=None, parent=None):
1513
+ """
1514
+ Initialize the dialog with a list of VPE tensors and their sources.
1515
+
1516
+ Args:
1517
+ vpe_list_with_source (list): List of (VPE tensor, source_str) tuples
1518
+ final_vpe (torch.Tensor, optional): The final (averaged) VPE
1519
+ parent (QWidget, optional): Parent widget
1520
+ """
1521
+ super().__init__(parent)
1522
+ self.setWindowTitle("VPE Visualization")
1523
+ self.resize(1000, 1000)
1524
+
1525
+ # Add a maximize button to the dialog's title bar
1526
+ self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)
1527
+
1528
+ # Store the VPEs and their sources
1529
+ self.vpe_list_with_source = vpe_list_with_source
1530
+ self.final_vpe = final_vpe
1531
+
1532
+ # Create the layout
1533
+ layout = QVBoxLayout(self)
1534
+
1535
+ # Create the plot widget
1536
+ self.plot_widget = pg.PlotWidget()
1537
+ self.plot_widget.setBackground('w') # White background
1538
+ self.plot_widget.setTitle("PCA Visualization of Visual Prompt Embeddings", color="#000000", size="10pt")
1539
+ self.plot_widget.showGrid(x=True, y=True, alpha=0.3)
1540
+
1541
+ # Add the plot widget to the layout
1542
+ layout.addWidget(self.plot_widget)
1543
+
1544
+ # Add spacing between plot_widget and info_label
1545
+ layout.addSpacing(20)
1546
+
1547
+ # Add information label at the bottom
1548
+ self.info_label = QLabel()
1549
+ self.info_label.setAlignment(Qt.AlignCenter)
1550
+ layout.addWidget(self.info_label)
1551
+
1552
+ # Create the button box
1553
+ button_box = QDialogButtonBox(QDialogButtonBox.Close)
1554
+ button_box.rejected.connect(self.reject)
1555
+ layout.addWidget(button_box)
1556
+
1557
+ # Visualize the VPEs
1558
+ self.visualize_vpes()
1559
+
1560
+ def visualize_vpes(self):
1561
+ """
1562
+ Apply PCA to the VPE tensors and visualize them in 2D space.
1563
+ """
1564
+ if not self.vpe_list_with_source:
1565
+ self.info_label.setText("No VPEs available to visualize.")
1566
+ return
1567
+
1568
+ # Convert tensors to numpy arrays for PCA, separating them from the source string
1569
+ vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, source in self.vpe_list_with_source]
1570
+
1571
+ # If final VPE is provided, add it to the arrays
1572
+ final_vpe_array = None
1573
+ if self.final_vpe is not None:
1574
+ final_vpe_array = self.final_vpe.detach().cpu().numpy().squeeze()
1575
+ all_vpes = np.vstack(vpe_arrays + [final_vpe_array])
1576
+ else:
1577
+ all_vpes = np.vstack(vpe_arrays)
1578
+
1579
+ # Apply PCA to reduce to 2 dimensions
1580
+ pca = PCA(n_components=2)
1581
+ vpes_2d = pca.fit_transform(all_vpes)
1582
+
1583
+ # Clear the plot
1584
+ self.plot_widget.clear()
1585
+
1586
+ # Generate random colors for individual VPEs
1587
+ num_vpes = len(vpe_arrays)
1588
+ colors = self.generate_distinct_colors(num_vpes)
1589
+
1590
+ # Create a legend with 3 columns to keep it compact
1591
+ legend = self.plot_widget.addLegend(colCount=3)
1592
+
1593
+ # Plot individual VPEs
1594
+ for i, (vpe_tuple, vpe_2d) in enumerate(zip(self.vpe_list_with_source, vpes_2d[:num_vpes])):
1595
+ source_char = 'I' if vpe_tuple[1] == 'Import' else 'R'
1596
+ color = pg.mkColor(colors[i])
1597
+ scatter = pg.ScatterPlotItem(
1598
+ x=[vpe_2d[0]],
1599
+ y=[vpe_2d[1]],
1600
+ brush=color,
1601
+ size=15,
1602
+ name=f"VPE {i+1} ({source_char})"
1603
+ )
1604
+ self.plot_widget.addItem(scatter)
1605
+
1606
+ # Plot the final (averaged) VPE if available
1607
+ if final_vpe_array is not None:
1608
+ final_vpe_2d = vpes_2d[-1]
1609
+ scatter = pg.ScatterPlotItem(
1610
+ x=[final_vpe_2d[0]],
1611
+ y=[final_vpe_2d[1]],
1612
+ brush=pg.mkBrush(color='r'),
1613
+ size=20,
1614
+ symbol='star',
1615
+ name="Final VPE"
1616
+ )
1617
+ self.plot_widget.addItem(scatter)
1618
+
1619
+ # Update the information label
1620
+ orig_dim = self.vpe_list_with_source[0][0].shape[-1]
1621
+ explained_variance = sum(pca.explained_variance_ratio_)
1622
+ self.info_label.setText(
1623
+ f"Original dimension: {orig_dim} → Reduced to 2D\n"
1624
+ f"Total explained variance: {explained_variance:.2%}\n"
1625
+ f"PC1: {pca.explained_variance_ratio_[0]:.2%} variance, "
1626
+ f"PC2: {pca.explained_variance_ratio_[1]:.2%} variance"
1627
+ )
1628
+
1629
+ def generate_distinct_colors(self, num_colors):
1630
+ """
1631
+ Generate visually distinct colors by using evenly spaced hues
1632
+ with random saturation and value.
1633
+
1634
+ Args:
1635
+ num_colors (int): Number of colors to generate
1636
+
1637
+ Returns:
1638
+ list: List of color hex strings
1639
+ """
1640
+ import random
1641
+ from colorsys import hsv_to_rgb
1642
+
1643
+ colors = []
1644
+ for i in range(num_colors):
1645
+ # Use golden ratio to space hues evenly
1646
+ hue = (i * 0.618033988749895) % 1.0
1647
+ # Random saturation between 0.6-1.0 (avoid too pale)
1648
+ saturation = random.uniform(0.6, 1.0)
1649
+ # Random value between 0.7-1.0 (avoid too dark)
1650
+ value = random.uniform(0.7, 1.0)
1651
+
1652
+ # Convert HSV to RGB (0-1 range)
1653
+ r, g, b = hsv_to_rgb(hue, saturation, value)
1654
+
1655
+ # Convert RGB to hex string
1656
+ hex_color = f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
1657
+ colors.append(hex_color)
1658
+
1659
+ return colors