coralnet-toolbox 0.0.73__py2.py3-none-any.whl → 0.0.74__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
  2. coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
  3. coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
  4. coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
  5. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
  6. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
  7. coralnet_toolbox/CoralNet/QtDownload.py +2 -1
  8. coralnet_toolbox/Explorer/QtExplorer.py +16 -14
  9. coralnet_toolbox/Explorer/QtSettingsWidgets.py +114 -82
  10. coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
  11. coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
  12. coralnet_toolbox/IO/QtOpenProject.py +46 -78
  13. coralnet_toolbox/IO/QtSaveProject.py +18 -43
  14. coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +1 -1
  15. coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +42 -22
  16. coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
  17. coralnet_toolbox/QtEventFilter.py +11 -0
  18. coralnet_toolbox/QtImageWindow.py +117 -68
  19. coralnet_toolbox/QtLabelWindow.py +13 -1
  20. coralnet_toolbox/QtMainWindow.py +5 -27
  21. coralnet_toolbox/QtProgressBar.py +52 -27
  22. coralnet_toolbox/Rasters/RasterTableModel.py +8 -8
  23. coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
  24. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +779 -161
  25. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +86 -149
  26. coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
  27. coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
  28. coralnet_toolbox/Tools/QtSAMTool.py +72 -50
  29. coralnet_toolbox/Tools/QtSeeAnythingTool.py +8 -5
  30. coralnet_toolbox/Tools/QtSelectTool.py +27 -3
  31. coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
  32. coralnet_toolbox/Tools/__init__.py +2 -0
  33. coralnet_toolbox/__init__.py +1 -1
  34. coralnet_toolbox/utilities.py +137 -47
  35. coralnet_toolbox-0.0.74.dist-info/METADATA +375 -0
  36. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/RECORD +40 -38
  37. coralnet_toolbox-0.0.73.dist-info/METADATA +0 -341
  38. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/WHEEL +0 -0
  39. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/entry_points.txt +0 -0
  40. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/licenses/LICENSE.txt +0 -0
  41. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/top_level.txt +0 -0
@@ -2,24 +2,24 @@ import warnings
2
2
 
3
3
  import os
4
4
  import gc
5
- import json
6
- import copy
7
5
 
8
6
  import numpy as np
7
+ from sklearn.decomposition import PCA
9
8
 
10
9
  import torch
11
10
  from torch.cuda import empty_cache
12
11
 
12
+ import pyqtgraph as pg
13
+ from pyqtgraph.Qt import QtGui
14
+
13
15
  from ultralytics import YOLOE
14
16
  from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
15
- from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
16
17
 
17
18
  from PyQt5.QtCore import Qt
18
- from PyQt5.QtGui import QColor
19
- from PyQt5.QtWidgets import (QMessageBox, QCheckBox, QVBoxLayout, QApplication,
19
+ from PyQt5.QtWidgets import (QMessageBox, QVBoxLayout, QApplication, QFileDialog,
20
20
  QLabel, QDialog, QDialogButtonBox, QGroupBox, QLineEdit,
21
21
  QFormLayout, QComboBox, QSpinBox, QSlider, QPushButton,
22
- QHBoxLayout, QWidget, QFileDialog)
22
+ QHBoxLayout)
23
23
 
24
24
  from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
25
25
  from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
@@ -79,10 +79,18 @@ class DeployGeneratorDialog(QDialog):
79
79
  self.class_mapping = {}
80
80
 
81
81
  # Reference image and label
82
- self.source_images = []
83
- self.source_label = None
84
- # Target images
85
- self.target_images = []
82
+ self.reference_label = None
83
+ self.reference_image_paths = []
84
+
85
+ # Visual Prompting Encoding (VPE) - legacy single tensor variable
86
+ self.vpe_path = None
87
+ self.vpe = None
88
+
89
+ # New separate VPE collections
90
+ self.imported_vpes = [] # VPEs loaded from file
91
+ self.reference_vpes = [] # VPEs created from reference images
92
+
93
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
86
94
 
87
95
  # Main vertical layout for the dialog
88
96
  self.layout = QVBoxLayout(self)
@@ -110,7 +118,7 @@ class DeployGeneratorDialog(QDialog):
110
118
  self.setup_status_layout()
111
119
 
112
120
  # Add layouts to the right panel
113
- self.setup_source_layout()
121
+ self.setup_reference_layout()
114
122
 
115
123
  # # Add a full ImageWindow instance for target image selection
116
124
  self.image_selection_window = ImageWindow(self.main_window)
@@ -158,6 +166,9 @@ class DeployGeneratorDialog(QDialog):
158
166
  iw.search_bar_images.setEnabled(False)
159
167
  iw.search_bar_labels.setEnabled(False)
160
168
  iw.top_k_combo.setEnabled(False)
169
+
170
+ # Hide the "Current" label as it is not applicable in this dialog
171
+ iw.current_image_index_label.hide()
161
172
 
162
173
  # Set Top-K to Top1
163
174
  iw.top_k_combo.setCurrentText("Top1")
@@ -190,7 +201,7 @@ class DeployGeneratorDialog(QDialog):
190
201
  self.sync_image_window()
191
202
  # This now populates the dropdown, restores the last selection,
192
203
  # and then manually triggers the image filtering.
193
- self.update_source_labels()
204
+ self.update_reference_labels()
194
205
 
195
206
  def sync_image_window(self):
196
207
  """
@@ -219,14 +230,23 @@ class DeployGeneratorDialog(QDialog):
219
230
  annotation that has BOTH the selected label AND a valid type (Polygon or Rectangle).
220
231
  This uses the fast, pre-computed cache for performance.
221
232
  """
222
- source_label = self.source_label_combo_box.currentData()
223
- source_label_text = self.source_label_combo_box.currentText()
233
+ # Persist the user's current highlights from the table model before filtering.
234
+ # This ensures that if the user highlights items and then changes the filter,
235
+ # their selection is not lost.
236
+ current_highlights = self.image_selection_window.table_model.get_highlighted_paths()
237
+ if current_highlights:
238
+ self.reference_image_paths = current_highlights
239
+
240
+ reference_label = self.reference_label_combo_box.currentData()
241
+ reference_label_text = self.reference_label_combo_box.currentText()
224
242
 
225
243
  # Store the last selected label for a better user experience on re-opening.
226
- if source_label_text:
227
- self.last_selected_label_code = source_label_text
244
+ if reference_label_text:
245
+ self.last_selected_label_code = reference_label_text
246
+ # Also store the reference label object itself
247
+ self.reference_label = reference_label
228
248
 
229
- if not source_label:
249
+ if not reference_label:
230
250
  # If no label is selected (e.g., during initialization), show an empty list.
231
251
  self.image_selection_window.table_model.set_filtered_paths([])
232
252
  return
@@ -235,7 +255,7 @@ class DeployGeneratorDialog(QDialog):
235
255
  final_filtered_paths = []
236
256
 
237
257
  valid_types = {"RectangleAnnotation", "PolygonAnnotation"}
238
- selected_label_code = source_label.short_label_code
258
+ selected_label_code = reference_label.short_label_code
239
259
 
240
260
  # Loop through paths and check the pre-computed map on each raster
241
261
  for path in all_paths:
@@ -257,40 +277,54 @@ class DeployGeneratorDialog(QDialog):
257
277
  # Directly set the filtered list in the table model.
258
278
  self.image_selection_window.table_model.set_filtered_paths(final_filtered_paths)
259
279
 
280
+ # Try to preserve any previous selections
281
+ if hasattr(self, 'reference_image_paths') and self.reference_image_paths:
282
+ # Find which of our previously selected paths are still in the filtered list
283
+ valid_selections = [p for p in self.reference_image_paths if p in final_filtered_paths]
284
+ if valid_selections:
285
+ # Highlight previously selected paths that are still valid
286
+ self.image_selection_window.table_model.set_highlighted_paths(valid_selections)
287
+
288
+ # After filtering, update all labels with the correct counts.
289
+ dialog_iw = self.image_selection_window
290
+ dialog_iw.update_image_count_label(len(final_filtered_paths)) # Set "Total" to filtered count
291
+ dialog_iw.update_current_image_index_label()
292
+ dialog_iw.update_highlighted_count_label()
293
+
260
294
  def accept(self):
261
295
  """
262
296
  Validate selections and store them before closing the dialog.
297
+ A prediction is valid if a model and label are selected, and the user
298
+ has provided either reference images or an imported VPE file.
263
299
  """
264
300
  if not self.loaded_model:
265
301
  QMessageBox.warning(self,
266
302
  "No Model",
267
303
  "A model must be loaded before running predictions.")
268
- super().reject()
269
304
  return
270
305
 
271
- current_label = self.source_label_combo_box.currentData()
272
- if not current_label:
306
+ # Set reference label from combo box
307
+ self.reference_label = self.reference_label_combo_box.currentData()
308
+ if not self.reference_label:
273
309
  QMessageBox.warning(self,
274
- "No Source Label",
275
- "A source label must be selected.")
276
- super().reject()
310
+ "No Reference Label",
311
+ "A reference label must be selected.")
277
312
  return
278
313
 
279
- # Get highlighted paths from our internal image window to use as targets
280
- highlighted_images = self.image_selection_window.table_model.get_highlighted_paths()
314
+ # Stash the current UI selection before validating.
315
+ self.update_stashed_references_from_ui()
281
316
 
282
- if not highlighted_images:
317
+ # Check for a valid VPE source using the now-stashed list.
318
+ has_reference_images = bool(self.reference_image_paths)
319
+ has_imported_vpes = bool(self.imported_vpes)
320
+
321
+ if not has_reference_images and not has_imported_vpes:
283
322
  QMessageBox.warning(self,
284
- "No Target Images",
285
- "You must highlight at least one image in the list to process.")
286
- super().reject()
323
+ "No VPE Source Provided",
324
+ "You must highlight at least one reference image or load a VPE file to proceed.")
287
325
  return
288
326
 
289
- # Store the selections for the caller to use after the dialog closes.
290
- self.source_label = current_label
291
- self.target_images = highlighted_images
292
-
293
- # Do not call self.predict here; just close the dialog and let the caller handle prediction
327
+ # If validation passes, close the dialog.
294
328
  super().accept()
295
329
 
296
330
  def setup_info_layout(self):
@@ -317,19 +351,19 @@ class DeployGeneratorDialog(QDialog):
317
351
  Setup the models layout with a simple model selection combo box (no tabs).
318
352
  """
319
353
  group_box = QGroupBox("Model Selection")
320
- layout = QVBoxLayout()
354
+ layout = QFormLayout()
321
355
 
322
356
  self.model_combo = QComboBox()
323
357
  self.model_combo.setEditable(True)
324
358
 
325
359
  # Define available models (keep the existing dictionary)
326
360
  self.models = [
327
- "yoloe-v8s-seg.pt",
328
- "yoloe-v8m-seg.pt",
329
- "yoloe-v8l-seg.pt",
330
- "yoloe-11s-seg.pt",
331
- "yoloe-11m-seg.pt",
332
- "yoloe-11l-seg.pt",
361
+ 'yoloe-v8s-seg.pt',
362
+ 'yoloe-v8m-seg.pt',
363
+ 'yoloe-v8l-seg.pt',
364
+ 'yoloe-11s-seg.pt',
365
+ 'yoloe-11m-seg.pt',
366
+ 'yoloe-11l-seg.pt',
333
367
  ]
334
368
 
335
369
  # Add all models to combo box
@@ -337,10 +371,19 @@ class DeployGeneratorDialog(QDialog):
337
371
  self.model_combo.addItem(model_name)
338
372
 
339
373
  # Set the default model
340
- self.model_combo.setCurrentText("yoloe-v8s-seg.pt")
374
+ self.model_combo.setCurrentIndex(self.models.index('yoloe-v8s-seg.pt'))
375
+ # Create a layout for the model selection
376
+ layout.addRow(QLabel("Models:"), self.model_combo)
377
+
378
+ # Add custom vpe file selection
379
+ self.vpe_path_edit = QLineEdit()
380
+ browse_button = QPushButton("Browse...")
381
+ browse_button.clicked.connect(self.browse_vpe_file)
341
382
 
342
- layout.addWidget(QLabel("Select Model:"))
343
- layout.addWidget(self.model_combo)
383
+ vpe_path_layout = QHBoxLayout()
384
+ vpe_path_layout.addWidget(self.vpe_path_edit)
385
+ vpe_path_layout.addWidget(browse_button)
386
+ layout.addRow("Custom VPE:", vpe_path_layout)
344
387
 
345
388
  group_box.setLayout(layout)
346
389
  self.left_panel.addWidget(group_box) # Add to left panel
@@ -445,17 +488,33 @@ class DeployGeneratorDialog(QDialog):
445
488
  Setup action buttons in a group box.
446
489
  """
447
490
  group_box = QGroupBox("Actions")
448
- layout = QHBoxLayout()
491
+ main_layout = QVBoxLayout()
449
492
 
493
+ # First row: Load and Deactivate buttons side by side
494
+ button_row = QHBoxLayout()
450
495
  load_button = QPushButton("Load Model")
451
496
  load_button.clicked.connect(self.load_model)
452
- layout.addWidget(load_button)
497
+ button_row.addWidget(load_button)
453
498
 
454
499
  deactivate_button = QPushButton("Deactivate Model")
455
500
  deactivate_button.clicked.connect(self.deactivate_model)
456
- layout.addWidget(deactivate_button)
501
+ button_row.addWidget(deactivate_button)
457
502
 
458
- group_box.setLayout(layout)
503
+ main_layout.addLayout(button_row)
504
+
505
+ # Second row: Save VPE button + Show VPE button side by side
506
+ vpe_row = QHBoxLayout()
507
+ save_vpe_button = QPushButton("Save VPE")
508
+ save_vpe_button.clicked.connect(self.save_vpe)
509
+ vpe_row.addWidget(save_vpe_button)
510
+
511
+ show_vpe_button = QPushButton("Show VPE")
512
+ show_vpe_button.clicked.connect(self.show_vpe)
513
+ vpe_row.addWidget(show_vpe_button)
514
+
515
+ main_layout.addLayout(vpe_row)
516
+
517
+ group_box.setLayout(main_layout)
459
518
  self.left_panel.addWidget(group_box) # Add to left panel
460
519
 
461
520
  def setup_status_layout(self):
@@ -471,18 +530,23 @@ class DeployGeneratorDialog(QDialog):
471
530
  group_box.setLayout(layout)
472
531
  self.left_panel.addWidget(group_box) # Add to left panel
473
532
 
474
- def setup_source_layout(self):
533
+ def setup_reference_layout(self):
475
534
  """
476
- Set up the layout with source label selection.
477
- The source image is implicitly the currently active image.
535
+ Set up the layout with reference label selection.
536
+ The reference image is implicitly the currently active image.
478
537
  """
479
- group_box = QGroupBox("Reference Label")
538
+ group_box = QGroupBox("Reference")
480
539
  layout = QFormLayout()
481
540
 
482
- # Create the source label combo box
483
- self.source_label_combo_box = QComboBox()
484
- self.source_label_combo_box.currentIndexChanged.connect(self.filter_images_by_label_and_type)
485
- layout.addRow("Reference Label:", self.source_label_combo_box)
541
+ # Create the reference label combo box
542
+ self.reference_label_combo_box = QComboBox()
543
+ self.reference_label_combo_box.currentIndexChanged.connect(self.filter_images_by_label_and_type)
544
+ layout.addRow("Reference Label:", self.reference_label_combo_box)
545
+
546
+ # Create a Reference model combobox (VPE, Images)
547
+ self.reference_method_combo_box = QComboBox()
548
+ self.reference_method_combo_box.addItems(["VPE", "Images"])
549
+ layout.addRow("Reference Method:", self.reference_method_combo_box)
486
550
 
487
551
  group_box.setLayout(layout)
488
552
  self.right_panel.addWidget(group_box) # Add to right panel
@@ -543,6 +607,10 @@ class DeployGeneratorDialog(QDialog):
543
607
  self.area_thresh_max = max_val / 100.0
544
608
  self.main_window.update_area_thresh(self.area_thresh_min, self.area_thresh_max)
545
609
  self.area_threshold_label.setText(f"{self.area_thresh_min:.2f} - {self.area_thresh_max:.2f}")
610
+
611
+ def update_stashed_references_from_ui(self):
612
+ """Updates the internal reference path list from the current UI selection."""
613
+ self.reference_image_paths = self.image_selection_window.table_model.get_highlighted_paths()
546
614
 
547
615
  def get_max_detections(self):
548
616
  """Get the maximum number of detections to return."""
@@ -603,54 +671,44 @@ class DeployGeneratorDialog(QDialog):
603
671
  if self.loaded_model:
604
672
  self.deactivate_model()
605
673
 
606
- def update_source_labels(self):
674
+ def update_reference_labels(self):
607
675
  """
608
- Updates the source label combo box with labels that are associated with
609
- valid reference annotations (Polygons or Rectangles), using the fast cache.
676
+ Updates the reference label combo box with ALL available project labels.
677
+ This dropdown now serves as the "Output Label" for all predictions.
678
+ The "Review" label with id "-1" is excluded.
610
679
  """
611
- self.source_label_combo_box.blockSignals(True)
680
+ self.reference_label_combo_box.blockSignals(True)
612
681
 
613
682
  try:
614
- self.source_label_combo_box.clear()
615
-
616
- dialog_manager = self.image_selection_window.raster_manager
617
- valid_types = {"RectangleAnnotation", "PolygonAnnotation"}
618
- valid_labels = set() # This will store the full Label objects
619
-
620
- # Create a lookup map to get full label objects from their codes
621
- all_project_labels = {lbl.short_label_code: lbl for lbl in self.main_window.label_window.labels}
622
-
623
- # Use the cached data to find all labels that have valid reference types.
624
- for raster in dialog_manager.rasters.values():
625
- # raster.label_to_types_map is like: {'coral': {'Point'}, 'rock': {'Polygon'}}
626
- for label_code, types_for_label in raster.label_to_types_map.items():
627
- # Check if the set of types for this specific label
628
- # has any overlap with our valid reference types.
629
- if not valid_types.isdisjoint(types_for_label):
630
- # This label is a valid reference label.
631
- # Add its full Label object to our set of valid labels.
632
- if label_code in all_project_labels:
633
- valid_labels.add(all_project_labels[label_code])
683
+ self.reference_label_combo_box.clear()
684
+
685
+ # Get all labels from the main label window
686
+ all_project_labels = self.main_window.label_window.labels
687
+
688
+ # Filter out the special "Review" label and create a list of valid labels
689
+ valid_labels = [
690
+ label_obj for label_obj in all_project_labels
691
+ if not (label_obj.short_label_code == "Review" and str(label_obj.id) == "-1")
692
+ ]
634
693
 
635
694
  # Add the valid labels to the combo box, sorted alphabetically.
636
- sorted_valid_labels = sorted(list(valid_labels), key=lambda x: x.short_label_code)
695
+ sorted_valid_labels = sorted(valid_labels, key=lambda x: x.short_label_code)
637
696
  for label_obj in sorted_valid_labels:
638
- self.source_label_combo_box.addItem(label_obj.short_label_code, label_obj)
697
+ self.reference_label_combo_box.addItem(label_obj.short_label_code, label_obj)
639
698
 
640
699
  # Restore the last selected label if it's still present in the list.
641
700
  if self.last_selected_label_code:
642
- index = self.source_label_combo_box.findText(self.last_selected_label_code)
701
+ index = self.reference_label_combo_box.findText(self.last_selected_label_code)
643
702
  if index != -1:
644
- self.source_label_combo_box.setCurrentIndex(index)
703
+ self.reference_label_combo_box.setCurrentIndex(index)
645
704
  finally:
646
- self.source_label_combo_box.blockSignals(False)
705
+ self.reference_label_combo_box.blockSignals(False)
647
706
 
648
- # Manually trigger the filtering now that the combo box is stable.
707
+ # Manually trigger the image filtering now that the combo box is stable.
708
+ # This will still filter the image list to help find references if needed.
649
709
  self.filter_images_by_label_and_type()
650
710
 
651
- return True
652
-
653
- def get_source_annotations(self, reference_label, reference_image_path):
711
+ def get_reference_annotations(self, reference_label, reference_image_path):
654
712
  """
655
713
  Return a list of bboxes and masks for a specific image
656
714
  belonging to the selected label.
@@ -666,22 +724,166 @@ class DeployGeneratorDialog(QDialog):
666
724
  annotations = self.annotation_window.get_image_annotations(reference_image_path)
667
725
 
668
726
  # Filter annotations by the provided label
669
- source_bboxes = []
670
- source_masks = []
727
+ reference_bboxes = []
728
+ reference_masks = []
671
729
  for annotation in annotations:
672
730
  if annotation.label.short_label_code == reference_label.short_label_code:
673
731
  if isinstance(annotation, (PolygonAnnotation, RectangleAnnotation)):
674
732
  bbox = annotation.cropped_bbox
675
- source_bboxes.append(bbox)
733
+ reference_bboxes.append(bbox)
676
734
  if isinstance(annotation, PolygonAnnotation):
677
735
  points = np.array([[p.x(), p.y()] for p in annotation.points])
678
- source_masks.append(points)
736
+ reference_masks.append(points)
679
737
  elif isinstance(annotation, RectangleAnnotation):
680
738
  x1, y1, x2, y2 = bbox
681
739
  rect_points = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
682
- source_masks.append(rect_points)
740
+ reference_masks.append(rect_points)
741
+
742
+ return np.array(reference_bboxes), reference_masks
743
+
744
+ def browse_vpe_file(self):
745
+ """
746
+ Open a file dialog to browse for a VPE file and load it.
747
+ Stores imported VPEs separately from reference-generated VPEs.
748
+ """
749
+ file_path, _ = QFileDialog.getOpenFileName(
750
+ self,
751
+ "Select Visual Prompt Encoding (VPE) File",
752
+ "",
753
+ "VPE Files (*.pt);;All Files (*)"
754
+ )
755
+
756
+ if not file_path:
757
+ return
758
+
759
+ self.vpe_path_edit.setText(file_path)
760
+ self.vpe_path = file_path
761
+
762
+ try:
763
+ # Load the VPE file
764
+ loaded_data = torch.load(file_path)
765
+
766
+ # TODO Move tensors to the appropriate device
767
+ # device = self.main_window.device
768
+
769
+ # Check format type and handle appropriately
770
+ if isinstance(loaded_data, list):
771
+ # New format: list of VPE tensors
772
+ self.imported_vpes = [vpe.to(self.device) for vpe in loaded_data]
773
+ vpe_count = len(self.imported_vpes)
774
+ self.status_bar.setText(f"Loaded {vpe_count} VPE tensors from file")
775
+
776
+ elif isinstance(loaded_data, torch.Tensor):
777
+ # Legacy format: single tensor - convert to list for consistency
778
+ loaded_vpe = loaded_data.to(self.device)
779
+ # Store as a single-item list
780
+ self.imported_vpes = [loaded_vpe]
781
+ self.status_bar.setText("Loaded 1 VPE tensor from file (legacy format)")
782
+
783
+ else:
784
+ # Invalid format
785
+ self.imported_vpes = []
786
+ self.status_bar.setText("Invalid VPE file format")
787
+ QMessageBox.warning(
788
+ self,
789
+ "Invalid VPE",
790
+ "The file does not appear to be a valid VPE format."
791
+ )
792
+ # Clear the VPE path edit field
793
+ self.vpe_path_edit.clear()
794
+
795
+ # For backward compatibility - set self.vpe to the average of imported VPEs
796
+ # This ensures older code paths still work
797
+ if self.imported_vpes:
798
+ combined_vpe = torch.cat(self.imported_vpes).mean(dim=0, keepdim=True)
799
+ self.vpe = torch.nn.functional.normalize(combined_vpe, p=2, dim=-1)
800
+
801
+ except Exception as e:
802
+ self.imported_vpes = []
803
+ self.vpe = None
804
+ self.status_bar.setText(f"Error loading VPE: {str(e)}")
805
+ QMessageBox.critical(
806
+ self,
807
+ "Error Loading VPE",
808
+ f"Failed to load VPE file: {str(e)}"
809
+ )
810
+
811
+ def save_vpe(self):
812
+ """
813
+ Save the combined collection of VPEs (imported and reference-generated) to disk.
814
+ """
815
+ # Always sync with the live UI selection before saving.
816
+ self.update_stashed_references_from_ui()
683
817
 
684
- return np.array(source_bboxes), source_masks
818
+ # Create a list to hold all VPEs
819
+ all_vpes = []
820
+
821
+ # Add imported VPEs if available
822
+ if self.imported_vpes:
823
+ all_vpes.extend(self.imported_vpes)
824
+
825
+ # Check if we should generate new VPEs from reference images
826
+ references_dict = self._get_references()
827
+ if references_dict:
828
+ # Reload the model to ensure clean state
829
+ self.reload_model()
830
+
831
+ # Convert references to VPEs without updating self.reference_vpes yet
832
+ new_vpes = self.references_to_vpe(references_dict, update_reference_vpes=False)
833
+
834
+ if new_vpes:
835
+ # Add new VPEs to collection
836
+ all_vpes.extend(new_vpes)
837
+ # Update reference_vpes with the new ones
838
+ self.reference_vpes = new_vpes
839
+ else:
840
+ # Include existing reference VPEs if we have them
841
+ if self.reference_vpes:
842
+ all_vpes.extend(self.reference_vpes)
843
+
844
+ # Check if we have any VPEs to save
845
+ if not all_vpes:
846
+ QMessageBox.warning(
847
+ self,
848
+ "No VPEs Available",
849
+ "No VPEs available to save. Please either load a VPE file or select reference images."
850
+ )
851
+ return
852
+
853
+ # Open file dialog to select save location
854
+ file_path, _ = QFileDialog.getSaveFileName(
855
+ self,
856
+ "Save VPE Collection",
857
+ "",
858
+ "PyTorch Tensor (*.pt);;All Files (*)"
859
+ )
860
+
861
+ if not file_path:
862
+ return # User canceled the dialog
863
+
864
+ # Add .pt extension if not present
865
+ if not file_path.endswith('.pt'):
866
+ file_path += '.pt'
867
+
868
+ try:
869
+ # Move tensors to CPU before saving
870
+ vpe_list_cpu = [vpe.cpu() for vpe in all_vpes]
871
+
872
+ # Save the list of tensors
873
+ torch.save(vpe_list_cpu, file_path)
874
+
875
+ self.status_bar.setText(f"Saved {len(all_vpes)} VPE tensors to {os.path.basename(file_path)}")
876
+ QMessageBox.information(
877
+ self,
878
+ "VPE Saved",
879
+ f"Saved {len(all_vpes)} VPE tensors to {file_path}"
880
+ )
881
+ except Exception as e:
882
+ QMessageBox.critical(
883
+ self,
884
+ "Error Saving VPE",
885
+ f"Failed to save VPE: {str(e)}"
886
+ )
685
887
 
686
888
  def load_model(self):
687
889
  """
@@ -692,40 +894,32 @@ class DeployGeneratorDialog(QDialog):
692
894
  progress_bar.show()
693
895
 
694
896
  try:
695
- # Get selected model path and download weights if needed
696
- self.model_path = self.model_combo.currentText()
697
-
698
- # Load model using registry
699
- self.loaded_model = YOLOE(self.model_path).to(self.main_window.device)
700
-
701
- # Create a dummy visual dictionary
702
- visuals = dict(
703
- bboxes=np.array(
704
- [
705
- [120, 425, 160, 445],
706
- ],
707
- ),
708
- cls=np.array(
709
- np.zeros(1),
710
- ),
711
- )
897
+ # Load the model using reload_model method
898
+ self.reload_model()
712
899
 
713
- # Run a dummy prediction to load the model
714
- self.loaded_model.predict(
715
- np.zeros((640, 640, 3), dtype=np.uint8),
716
- visual_prompts=visuals.copy(),
717
- predictor=YOLOEVPDetectPredictor,
718
- imgsz=640,
719
- conf=0.99,
720
- )
900
+ # Calculate total number of VPEs from both sources
901
+ total_vpes = len(self.imported_vpes) + len(self.reference_vpes)
902
+
903
+ if total_vpes > 0:
904
+ if self.imported_vpes and self.reference_vpes:
905
+ message = f"Model loaded with {len(self.imported_vpes)} imported VPEs "
906
+ message += f"and {len(self.reference_vpes)} reference VPEs"
907
+ elif self.imported_vpes:
908
+ message = f"Model loaded with {len(self.imported_vpes)} imported VPEs"
909
+ else:
910
+ message = f"Model loaded with {len(self.reference_vpes)} reference VPEs"
911
+
912
+ self.status_bar.setText(message)
913
+ else:
914
+ message = "Model loaded with default VPE"
915
+ self.status_bar.setText("Model loaded with default VPE")
721
916
 
917
+ # Finish progress bar
722
918
  progress_bar.finish_progress()
723
- self.status_bar.setText("Model loaded")
724
- QMessageBox.information(self.annotation_window,
725
- "Model Loaded",
726
- "Model loaded successfully")
919
+ QMessageBox.information(self.annotation_window, "Model Loaded", message)
727
920
 
728
921
  except Exception as e:
922
+ self.loaded_model = None
729
923
  QMessageBox.critical(self.annotation_window,
730
924
  "Error Loading Model",
731
925
  f"Error loading model: {e}")
@@ -737,7 +931,52 @@ class DeployGeneratorDialog(QDialog):
737
931
  progress_bar.stop_progress()
738
932
  progress_bar.close()
739
933
  progress_bar = None
934
+
935
+ def reload_model(self):
936
+ """
937
+ Subset of the load_model method. This is needed when additional
938
+ reference images and annotations (i.e., VPEs) are added (we have
939
+ to re-load the model each time).
940
+
941
+ This method also ensures that we stash the currently highlighted reference
942
+ image paths before reloading, so they're available for predictions
943
+ even if the user switches the active image in the main window.
944
+ """
945
+ self.loaded_model = None
946
+
947
+ # Get selected model path and download weights if needed
948
+ self.model_path = self.model_combo.currentText()
949
+
950
+ # Load model using registry
951
+ self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device) # TODO
952
+
953
+ # Create a dummy visual dictionary for standard model loading
954
+ visual_prompts = dict(
955
+ bboxes=np.array(
956
+ [
957
+ [120, 425, 160, 445], # Random box
958
+ ],
959
+ ),
960
+ cls=np.array(
961
+ np.zeros(1),
962
+ ),
963
+ )
964
+
965
+ # Run a dummy prediction to load the model
966
+ self.loaded_model.predict(
967
+ np.zeros((640, 640, 3), dtype=np.uint8),
968
+ visual_prompts=visual_prompts.copy(), # This needs to happen to properly initialize the predictor
969
+ predictor=YOLOEVPSegPredictor, # This also needs to be SegPredictor, no matter what
970
+ imgsz=640,
971
+ conf=0.99,
972
+ )
740
973
 
974
+ # If a VPE file was loaded, use it with the model after the dummy prediction
975
+ if self.vpe is not None and isinstance(self.vpe, torch.Tensor):
976
+ # Directly set the final tensor as the prompt for the predictor
977
+ self.loaded_model.is_fused = lambda: False
978
+ self.loaded_model.set_classes(["object0"], self.vpe)
979
+
741
980
  def predict(self, image_paths=None):
742
981
  """
743
982
  Make predictions on the given image paths using the loaded model.
@@ -745,11 +984,11 @@ class DeployGeneratorDialog(QDialog):
745
984
  Args:
746
985
  image_paths: List of image paths to process. If None, uses the current image.
747
986
  """
748
- if not self.loaded_model or not self.source_label:
987
+ if not self.loaded_model or not self.reference_label:
749
988
  return
750
989
 
751
990
  # Update class mapping with the selected reference label
752
- self.class_mapping = {0: self.source_label}
991
+ self.class_mapping = {0: self.reference_label}
753
992
 
754
993
  # Create a results processor
755
994
  results_processor = ResultsProcessor(
@@ -809,32 +1048,33 @@ class DeployGeneratorDialog(QDialog):
809
1048
 
810
1049
  return work_areas_data
811
1050
 
812
- def _apply_model(self, inputs):
813
- """
814
- Apply the model to the target inputs, using each highlighted source
815
- image as an individual reference for a separate prediction run.
1051
def _get_references(self):
    """
    Build the reference-annotation lookup from the stashed reference state.

    The image paths and the reference label were stashed when the user
    accepted the dialog; we deliberately avoid re-querying the table model
    here because the UI highlight state may have been cleared by other
    actions since then.

    Returns:
        dict: Maps image path -> {'bboxes', 'masks', 'cls'} arrays; empty
        dict when no usable references exist.
    """
    stashed_paths = self.reference_image_paths
    if not stashed_paths:
        print("No reference image paths were stashed to use for prediction.")
        return {}

    stashed_label = self.reference_label
    if not stashed_label:
        # Safeguard only; accept() is expected to enforce a label selection.
        print("No reference label was selected.")
        return {}

    # Collect annotations per image, skipping images that contributed
    # no boxes for the chosen label.
    references = {}
    for image_path in stashed_paths:
        bboxes, masks = self.get_reference_annotations(stashed_label, image_path)
        if bboxes.size > 0:
            references[image_path] = {
                'bboxes': bboxes,
                'masks': masks,
                'cls': np.zeros(len(bboxes)),
            }

    return references
1086
+
1087
+ def _apply_model_using_images(self, inputs, reference_dict):
1088
+ """
1089
+ Apply the model using the provided images and reference annotations (dict). This method
1090
+ loops through each reference image using its annotations; we then aggregate
1091
+ all the results together. Less efficient, but potentially more accurate.
848
1092
 
1093
+ Args:
1094
+ inputs (list): List of input images.
1095
+ reference_dict (dict): Dictionary containing reference annotations for each image.
1096
+
1097
+ Returns:
1098
+ list: List of prediction results.
1099
+ """
849
1100
  # Create a progress bar for iterating through reference images
850
1101
  QApplication.setOverrideCursor(Qt.WaitCursor)
851
1102
  progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
852
1103
  progress_bar.show()
853
- progress_bar.start_progress(len(reference_annotations_dict))
854
-
1104
+ progress_bar.start_progress(len(reference_dict))
1105
+
855
1106
  results_list = []
856
1107
  # The 'inputs' list contains work areas from the single target image.
857
1108
  # We will predict on the first work area/full image.
858
1109
  input_image = inputs[0]
859
1110
 
860
1111
  # Iterate through each reference image and its annotations
861
- for ref_path, ref_annotations in reference_annotations_dict.items():
1112
+ for ref_path, ref_annotations in reference_dict.items():
862
1113
  # The 'refer_image' parameter is the path to the current reference image
863
1114
  # The 'visual_prompts' are the annotations from that same reference image
864
- visuals = {
1115
+ visual_prompts = {
865
1116
  'bboxes': ref_annotations['bboxes'],
866
1117
  'cls': ref_annotations['cls'],
867
1118
  }
868
1119
  if self.task == 'segment':
869
- visuals['masks'] = ref_annotations['masks']
1120
+ visual_prompts['masks'] = ref_annotations['masks']
870
1121
 
871
1122
  # Make predictions on the target using the current reference
872
1123
  results = self.loaded_model.predict(input_image,
873
1124
  refer_image=ref_path,
874
- visual_prompts=visuals,
875
- predictor=predictor,
1125
+ visual_prompts=visual_prompts,
1126
+ predictor=YOLOEVPSegPredictor, # TODO This is necessary here?
876
1127
  imgsz=self.imgsz_spinbox.value(),
877
1128
  conf=self.main_window.get_uncertainty_thresh(),
878
1129
  iou=self.main_window.get_iou_thresh(),
@@ -904,6 +1155,158 @@ class DeployGeneratorDialog(QDialog):
904
1155
  return []
905
1156
 
906
1157
  return [[combined_results]]
1158
+
1159
def references_to_vpe(self, reference_dict, update_reference_vpes=True):
    """
    Convert a reference dictionary into a list of normalized VPEs
    (Visual Prompt Embeddings).

    A reference dictionary maps each reference image path to its visual
    prompts: dict[image_path] -> {bboxes, masks, cls}.

    Args:
        reference_dict (dict): Visual prompts keyed by reference image path.
        update_reference_vpes (bool): When True, also store the resulting
            list on self.reference_vpes.

    Returns:
        list: L2-normalized VPE tensors (one per reference image), or None
        when reference_dict is empty.
    """
    # Nothing to convert
    if not reference_dict:
        return None

    vpe_list = []
    for image_path, prompts in reference_dict.items():
        # Load the prompts into the predictor, then extract the embedding
        self.loaded_model.predictor.set_prompts(prompts)
        raw_vpe = self.loaded_model.predictor.get_vpe(image_path)

        # Unit-normalize along the feature dimension before collecting
        vpe_list.append(torch.nn.functional.normalize(raw_vpe, p=2, dim=-1))

    # Defensive: an empty result means no usable VPEs were produced
    if not vpe_list:
        return None

    if update_reference_vpes:
        self.reference_vpes = vpe_list

    return vpe_list
1199
+
1200
def _apply_model_using_vpe(self, inputs, references_dict):
    """
    Apply the model to the inputs using combined VPEs from both imported files
    and reference annotations.

    VPE sources are merged as follows: imported VPEs are always included;
    reference VPEs are recomputed from `references_dict` when it is non-empty
    (also refreshing self.reference_vpes), otherwise any previously computed
    reference VPEs are reused. All collected VPEs are averaged and
    re-normalized into a single prompt embedding before predicting.

    Args:
        inputs (list): List of input images; only the first entry is predicted on.
        references_dict (dict): Dictionary containing reference annotations for each image.

    Returns:
        list: Single-element list with the prediction results, or [] when no
        VPEs are available.
    """
    # First reload the model to clear any cached data
    self.reload_model()

    # Accumulates every VPE tensor that will be averaged below
    combined_vpes = []

    # Add imported VPEs if available
    if self.imported_vpes:
        combined_vpes.extend(self.imported_vpes)

    # Process reference images to VPEs if any exist
    if references_dict:
        # Only update reference_vpes if references_dict is not empty
        reference_vpes = self.references_to_vpe(references_dict, update_reference_vpes=True)
        if reference_vpes:
            combined_vpes.extend(reference_vpes)
    else:
        # Use existing reference_vpes if we have them
        if self.reference_vpes:
            combined_vpes.extend(self.reference_vpes)

    # Check if we have any VPEs to use
    if not combined_vpes:
        QMessageBox.warning(
            self,
            "No VPEs Available",
            "No VPEs available for prediction. Please either load a VPE file or select reference images."
        )
        return []

    # Average all the VPEs together to create a final VPE tensor, then
    # L2-normalize so it matches the scale of the individual VPEs
    averaged_vpe = torch.cat(combined_vpes).mean(dim=0, keepdim=True)
    final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)

    # For backward compatibility, update self.vpe
    self.vpe = final_vpe

    # Set the final VPE to the model; is_fused is stubbed to report False —
    # presumably so set_classes accepts the embedding on a loaded model
    # (mirrors the same pattern used after model loading) — TODO confirm
    self.loaded_model.is_fused = lambda: False
    self.loaded_model.set_classes(["object0"], final_vpe)

    # Make predictions on the target using the averaged VPE
    results = self.loaded_model.predict(inputs[0],
                                        visual_prompts=[],
                                        imgsz=self.imgsz_spinbox.value(),
                                        conf=self.main_window.get_uncertainty_thresh(),
                                        iou=self.main_window.get_iou_thresh(),
                                        max_det=self.get_max_detections(),
                                        retina_masks=self.task == "segment")

    return [results]
1263
+
1264
def _apply_model(self, inputs):
    """
    Run prediction on the target inputs, dispatching to the VPE-based or
    image-based reference workflow depending on the UI selection.
    """
    # Sync task and thresholds from the UI before predicting
    self.task = self.use_task_dropdown.currentText()
    self.loaded_model.conf = self.main_window.get_uncertainty_thresh()
    self.loaded_model.iou = self.main_window.get_iou_thresh()
    self.loaded_model.max_det = self.get_max_detections()

    # Reference annotations for the currently stashed selection
    references_dict = self._get_references()

    if self.reference_method_combo_box.currentText() == "VPE":
        # VPE mode can draw on imported VPEs, previously computed reference
        # VPEs, or fresh reference annotations — warn only if none exist.
        if not (self.imported_vpes or self.reference_vpes or references_dict):
            QMessageBox.warning(
                self,
                "No VPEs Available",
                "No VPEs available for prediction. Please either load a VPE file or select reference images."
            )
            return []
        return self._apply_model_using_vpe(inputs, references_dict)

    # Image-reference mode strictly requires usable reference annotations
    if not references_dict:
        QMessageBox.warning(
            self,
            "No References Selected",
            "No reference images with valid annotations were selected. "
            "Please select at least one reference image."
        )
        return []

    return self._apply_model_using_images(inputs, references_dict)
907
1310
 
908
1311
  def _apply_sam(self, results_list, image_path):
909
1312
  """Apply SAM to the results if needed."""
@@ -1000,17 +1403,232 @@ class DeployGeneratorDialog(QDialog):
1000
1403
  progress_bar.stop_progress()
1001
1404
  progress_bar.close()
1002
1405
 
1406
def show_vpe(self):
    """
    Visualize the current VPEs with PyQtGraph.

    Reference VPEs are always recalculated from the currently highlighted
    reference images, then shown together with any VPEs imported from file,
    plus their normalized average.
    """
    # Busy cursor while embeddings are computed
    QApplication.setOverrideCursor(Qt.WaitCursor)

    try:
        # Pull the live UI selection into the stashed reference state first
        self.update_stashed_references_from_ui()

        tagged_vpes = []

        # 1) VPEs loaded from a file
        for imported in self.imported_vpes or []:
            tagged_vpes.append((imported, "Import"))

        # 2) Reference annotations from the stashed selection
        references_dict = self._get_references()

        # 3) Recompute reference VPEs when any references are selected
        if references_dict:
            self.reload_model()
            fresh_vpes = self.references_to_vpe(references_dict, update_reference_vpes=True)
            for vpe in fresh_vpes or []:
                tagged_vpes.append((vpe, "Reference"))

        # 4) Bail out when there is nothing to show
        if not tagged_vpes:
            QMessageBox.warning(
                self,
                "No VPEs Available",
                "No VPEs available to visualize. Please either load a VPE file or select reference images."
            )
            return

        # 5) Average everything into the final VPE and open the dialog
        tensors_only = [vpe for vpe, _ in tagged_vpes]
        final_vpe = torch.nn.functional.normalize(
            torch.cat(tensors_only).mean(dim=0, keepdim=True), p=2, dim=-1
        )

        dialog = VPEVisualizationDialog(tagged_vpes, final_vpe, self)
        dialog.exec_()

    finally:
        # Restore the cursor no matter how we exit
        QApplication.restoreOverrideCursor()
1456
+
1003
1457
def deactivate_model(self):
    """
    Unload the current model and reset all VPE-related state.

    Clears the VPE path widget, drops imported/reference VPEs, frees
    host and GPU memory, resets tool state, and notifies the user.
    """
    # Forget the model itself
    self.loaded_model = None
    self.model_path = None

    # Wipe VPE state: UI path box, cached tensors, and both VPE lists
    self.vpe_path_edit.clear()
    self.vpe_path = None
    self.vpe = None
    self.imported_vpes = []
    self.reference_vpes = []

    # Reclaim memory held by the dropped references
    gc.collect()
    torch.cuda.empty_cache()

    # Leave the toolbar in a neutral state
    self.main_window.untoggle_all_tools()

    # Inform the user via status bar and dialog
    self.status_bar.setText("No model loaded")
    QMessageBox.information(self, "Model Deactivated", "Model deactivated")
1481
+
1482
+
1483
+ class VPEVisualizationDialog(QDialog):
1484
+ """
1485
+ Dialog for visualizing VPE embeddings in 2D space using PCA.
1486
+ """
1487
def __init__(self, vpe_list_with_source, final_vpe=None, parent=None):
    """
    Build the VPE visualization dialog and render the initial plot.

    Args:
        vpe_list_with_source (list): (VPE tensor, source string) tuples.
        final_vpe (torch.Tensor, optional): The final (averaged) VPE.
        parent (QWidget, optional): Parent widget.
    """
    super().__init__(parent)
    self.setWindowTitle("VPE Visualization")
    self.resize(1000, 1000)

    # Allow maximizing from the dialog's title bar
    self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)

    # Keep the data to plot
    self.vpe_list_with_source = vpe_list_with_source
    self.final_vpe = final_vpe

    root_layout = QVBoxLayout(self)

    # Plot area: white background, grid, fixed title
    self.plot_widget = pg.PlotWidget()
    self.plot_widget.setBackground('w')
    self.plot_widget.setTitle("PCA Visualization of Visual Prompt Embeddings", color="#000000", size="10pt")
    self.plot_widget.showGrid(x=True, y=True, alpha=0.3)
    root_layout.addWidget(self.plot_widget)

    # Breathing room between the plot and the summary text
    root_layout.addSpacing(20)

    # Centered summary label under the plot
    self.info_label = QLabel()
    self.info_label.setAlignment(Qt.AlignCenter)
    root_layout.addWidget(self.info_label)

    # Close-only button row
    close_box = QDialogButtonBox(QDialogButtonBox.Close)
    close_box.rejected.connect(self.reject)
    root_layout.addWidget(close_box)

    # Populate the plot immediately
    self.visualize_vpes()
1534
+
1535
def visualize_vpes(self):
    """
    Project all VPE tensors to 2D with PCA and draw them on the plot widget.
    """
    if not self.vpe_list_with_source:
        self.info_label.setText("No VPEs available to visualize.")
        return

    # Strip the source tags and flatten each tensor to a 1-D numpy vector
    point_arrays = [tensor.detach().cpu().numpy().squeeze()
                    for tensor, _ in self.vpe_list_with_source]

    # Append the averaged VPE (if any) so it shares the same projection
    has_final = self.final_vpe is not None
    if has_final:
        final_array = self.final_vpe.detach().cpu().numpy().squeeze()
        stacked = np.vstack(point_arrays + [final_array])
    else:
        stacked = np.vstack(point_arrays)

    # Reduce to two principal components
    pca = PCA(n_components=2)
    projected = pca.fit_transform(stacked)

    self.plot_widget.clear()

    # One distinct color per individual VPE
    count = len(point_arrays)
    palette = self.generate_distinct_colors(count)

    # Compact three-column legend
    self.plot_widget.addLegend(colCount=3)

    # One scatter point per individual VPE, labelled by its source
    for idx in range(count):
        source = self.vpe_list_with_source[idx][1]
        marker = 'I' if source == 'Import' else 'R'
        point = projected[idx]
        self.plot_widget.addItem(pg.ScatterPlotItem(
            x=[point[0]],
            y=[point[1]],
            brush=pg.mkColor(palette[idx]),
            size=15,
            name=f"VPE {idx+1} ({marker})"
        ))

    # The averaged VPE (last projected row) is drawn as a red star
    if has_final:
        star = projected[-1]
        self.plot_widget.addItem(pg.ScatterPlotItem(
            x=[star[0]],
            y=[star[1]],
            brush=pg.mkBrush(color='r'),
            size=20,
            symbol='star',
            name="Final VPE"
        ))

    # Summarize the dimensionality reduction quality below the plot
    feature_dim = self.vpe_list_with_source[0][0].shape[-1]
    total_explained = sum(pca.explained_variance_ratio_)
    self.info_label.setText(
        f"Original dimension: {feature_dim} → Reduced to 2D\n"
        f"Total explained variance: {total_explained:.2%}\n"
        f"PC1: {pca.explained_variance_ratio_[0]:.2%} variance, "
        f"PC2: {pca.explained_variance_ratio_[1]:.2%} variance"
    )
1603
+
1604
def generate_distinct_colors(self, num_colors):
    """
    Generate visually distinct, reproducible colors.

    Hues are spaced by the golden ratio so consecutive indices differ
    strongly; saturation and value are jittered per index with a seeded RNG
    so the same index always yields the same color. The previous
    implementation used the unseeded module RNG, so legend colors changed on
    every repaint/run for the same VPE index.

    Args:
        num_colors (int): Number of colors to generate

    Returns:
        list: List of color hex strings (e.g. "#a1b2c3")
    """
    import random
    from colorsys import hsv_to_rgb

    colors = []
    for i in range(num_colors):
        # Golden-ratio spacing keeps consecutive hues far apart
        hue = (i * 0.618033988749895) % 1.0

        # Per-index seed: deterministic across calls, varied across indices
        rng = random.Random(i)
        # Saturation 0.6-1.0 (avoid too pale), value 0.7-1.0 (avoid too dark)
        saturation = rng.uniform(0.6, 1.0)
        value = rng.uniform(0.7, 1.0)

        # Convert HSV (0-1 range) to an RGB hex string
        r, g, b = hsv_to_rgb(hue, saturation, value)
        colors.append(f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}")

    return colors