coralnet-toolbox 0.0.75__py2.py3-none-any.whl → 0.0.77__py2.py3-none-any.whl

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (50)
  1. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +57 -12
  2. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +44 -14
  3. coralnet_toolbox/Common/QtGraphicsUtility.py +18 -8
  4. coralnet_toolbox/Explorer/transformer_models.py +13 -2
  5. coralnet_toolbox/IO/QtExportMaskAnnotations.py +576 -402
  6. coralnet_toolbox/IO/QtImportImages.py +7 -15
  7. coralnet_toolbox/IO/QtOpenProject.py +15 -19
  8. coralnet_toolbox/Icons/system_monitor.png +0 -0
  9. coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +33 -8
  10. coralnet_toolbox/QtAnnotationWindow.py +4 -0
  11. coralnet_toolbox/QtEventFilter.py +5 -5
  12. coralnet_toolbox/QtImageWindow.py +4 -0
  13. coralnet_toolbox/QtMainWindow.py +104 -64
  14. coralnet_toolbox/QtProgressBar.py +1 -0
  15. coralnet_toolbox/QtSystemMonitor.py +370 -0
  16. coralnet_toolbox/Rasters/RasterManager.py +5 -2
  17. coralnet_toolbox/Results/ConvertResults.py +14 -8
  18. coralnet_toolbox/Results/ResultsProcessor.py +3 -2
  19. coralnet_toolbox/SAM/QtDeployGenerator.py +1 -1
  20. coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
  21. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +324 -177
  22. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +10 -6
  23. coralnet_toolbox/Tile/QtTileBatchInference.py +4 -4
  24. coralnet_toolbox/Tools/QtPatchTool.py +6 -2
  25. coralnet_toolbox/Tools/QtPolygonTool.py +5 -3
  26. coralnet_toolbox/Tools/QtRectangleTool.py +17 -9
  27. coralnet_toolbox/Tools/QtSAMTool.py +144 -91
  28. coralnet_toolbox/Tools/QtSeeAnythingTool.py +4 -0
  29. coralnet_toolbox/Tools/QtTool.py +79 -3
  30. coralnet_toolbox/Tools/QtWorkAreaTool.py +4 -0
  31. coralnet_toolbox/Transformers/Models/GroundingDINO.py +72 -0
  32. coralnet_toolbox/Transformers/Models/OWLViT.py +72 -0
  33. coralnet_toolbox/Transformers/Models/OmDetTurbo.py +68 -0
  34. coralnet_toolbox/Transformers/Models/QtBase.py +121 -0
  35. coralnet_toolbox/{AutoDistill → Transformers}/Models/__init__.py +1 -1
  36. coralnet_toolbox/{AutoDistill → Transformers}/QtBatchInference.py +15 -15
  37. coralnet_toolbox/{AutoDistill → Transformers}/QtDeployModel.py +18 -16
  38. coralnet_toolbox/{AutoDistill → Transformers}/__init__.py +1 -1
  39. coralnet_toolbox/__init__.py +1 -1
  40. coralnet_toolbox/utilities.py +0 -15
  41. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/METADATA +9 -9
  42. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/RECORD +46 -44
  43. coralnet_toolbox/AutoDistill/Models/GroundingDINO.py +0 -81
  44. coralnet_toolbox/AutoDistill/Models/OWLViT.py +0 -76
  45. coralnet_toolbox/AutoDistill/Models/OmDetTurbo.py +0 -75
  46. coralnet_toolbox/AutoDistill/Models/QtBase.py +0 -112
  47. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/WHEEL +0 -0
  48. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/entry_points.txt +0 -0
  49. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/licenses/LICENSE.txt +0 -0
  50. {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/top_level.txt +0 -0
@@ -10,10 +10,10 @@ import torch
10
10
  from torch.cuda import empty_cache
11
11
 
12
12
  import pyqtgraph as pg
13
- from pyqtgraph.Qt import QtGui
14
13
 
15
14
  from ultralytics import YOLOE
16
15
  from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
16
+ from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
17
17
 
18
18
  from PyQt5.QtCore import Qt
19
19
  from PyQt5.QtWidgets import (QMessageBox, QVBoxLayout, QApplication, QFileDialog,
@@ -26,7 +26,6 @@ from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotati
26
26
 
27
27
  from coralnet_toolbox.Results import ResultsProcessor
28
28
  from coralnet_toolbox.Results import MapResults
29
- from coralnet_toolbox.Results import CombineResults
30
29
 
31
30
  from coralnet_toolbox.QtProgressBar import ProgressBar
32
31
  from coralnet_toolbox.QtImageWindow import ImageWindow
@@ -90,7 +89,7 @@ class DeployGeneratorDialog(QDialog):
90
89
  self.imported_vpes = [] # VPEs loaded from file
91
90
  self.reference_vpes = [] # VPEs created from reference images
92
91
 
93
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
+ self.device = None # Will be set in showEvent
94
93
 
95
94
  # Main vertical layout for the dialog
96
95
  self.layout = QVBoxLayout(self)
@@ -195,6 +194,8 @@ class DeployGeneratorDialog(QDialog):
195
194
  self.initialize_iou_threshold()
196
195
  self.initialize_area_threshold()
197
196
 
197
+ # Update the device
198
+ self.device = self.main_window.device
198
199
  # Configure the image window's UI elements for this specific dialog
199
200
  self.configure_image_window_for_dialog()
200
201
  # Sync with main window's images BEFORE updating labels
@@ -537,8 +538,8 @@ class DeployGeneratorDialog(QDialog):
537
538
 
538
539
  def setup_reference_layout(self):
539
540
  """
540
- Set up the layout with reference label selection.
541
- The reference image is implicitly the currently active image.
541
+ Set up the layout for reference selection, including the output label,
542
+ reference method, and the number of prototype clusters (K).
542
543
  """
543
544
  group_box = QGroupBox("Reference")
544
545
  layout = QFormLayout()
@@ -553,8 +554,18 @@ class DeployGeneratorDialog(QDialog):
553
554
  self.reference_method_combo_box.addItems(["VPE", "Images"])
554
555
  layout.addRow("Reference Method:", self.reference_method_combo_box)
555
556
 
557
+ # Add a spinbox for the user to define the number of prototypes (K)
558
+ self.k_prototypes_spinbox = QSpinBox()
559
+ self.k_prototypes_spinbox.setRange(0, 1000)
560
+ self.k_prototypes_spinbox.setValue(0)
561
+ self.k_prototypes_spinbox.setToolTip(
562
+ "Set the number of prototype clusters (K) to generate from references.\n"
563
+ "Set to 0 to treat every unique reference image/VPE as its own prototype (K=N)."
564
+ )
565
+ layout.addRow("Number of Prototypes (K):", self.k_prototypes_spinbox)
566
+
556
567
  group_box.setLayout(layout)
557
- self.right_panel.addWidget(group_box) # Add to right panel
568
+ self.right_panel.addWidget(group_box)
558
569
 
559
570
  def setup_buttons_layout(self):
560
571
  """
@@ -767,20 +778,20 @@ class DeployGeneratorDialog(QDialog):
767
778
  try:
768
779
  # Load the VPE file
769
780
  loaded_data = torch.load(file_path)
770
-
771
- # TODO Move tensors to the appropriate device
772
- # device = self.main_window.device
781
+
782
+ # Move tensors to the appropriate device
783
+ device = self.main_window.device
773
784
 
774
785
  # Check format type and handle appropriately
775
786
  if isinstance(loaded_data, list):
776
787
  # New format: list of VPE tensors
777
- self.imported_vpes = [vpe.to(self.device) for vpe in loaded_data]
788
+ self.imported_vpes = [vpe.to(device) for vpe in loaded_data]
778
789
  vpe_count = len(self.imported_vpes)
779
790
  self.status_bar.setText(f"Loaded {vpe_count} VPE tensors from file")
780
791
 
781
792
  elif isinstance(loaded_data, torch.Tensor):
782
793
  # Legacy format: single tensor - convert to list for consistency
783
- loaded_vpe = loaded_data.to(self.device)
794
+ loaded_vpe = loaded_data.to(device)
784
795
  # Store as a single-item list
785
796
  self.imported_vpes = [loaded_vpe]
786
797
  self.status_bar.setText("Loaded 1 VPE tensor from file (legacy format)")
@@ -948,7 +959,7 @@ class DeployGeneratorDialog(QDialog):
948
959
  self.model_path = self.model_combo.currentText()
949
960
 
950
961
  # Load model using registry
951
- self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device) # TODO
962
+ self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device)
952
963
 
953
964
  # Create a dummy visual dictionary for standard model loading
954
965
  visual_prompts = dict(
@@ -966,7 +977,7 @@ class DeployGeneratorDialog(QDialog):
966
977
  self.loaded_model.predict(
967
978
  np.zeros((640, 640, 3), dtype=np.uint8),
968
979
  visual_prompts=visual_prompts.copy(), # This needs to happen to properly initialize the predictor
969
- predictor=YOLOEVPSegPredictor, # This also needs to be SegPredictor, no matter what
980
+ predictor=YOLOEVPDetectPredictor if self.task == "detect" else YOLOEVPSegPredictor,
970
981
  imgsz=640,
971
982
  conf=0.99,
972
983
  )
@@ -1086,9 +1097,9 @@ class DeployGeneratorDialog(QDialog):
1086
1097
 
1087
1098
  def _apply_model_using_images(self, inputs, reference_dict):
1088
1099
  """
1089
- Apply the model using the provided images and reference annotations (dict). This method
1090
- loops through each reference image using its annotations; we then aggregate
1091
- all the results together. Less efficient, but potentially more accurate.
1100
+ Apply the model using reference images. This method now generates VPEs
1101
+ from the reference annotations, clusters them into K prototypes, and
1102
+ runs a single, efficient multi-class prediction.
1092
1103
 
1093
1104
  Args:
1094
1105
  inputs (list): List of input images.
@@ -1097,64 +1108,70 @@ class DeployGeneratorDialog(QDialog):
1097
1108
  Returns:
1098
1109
  list: List of prediction results.
1099
1110
  """
1100
- # Create a progress bar for iterating through reference images
1101
- QApplication.setOverrideCursor(Qt.WaitCursor)
1102
- progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
1103
- progress_bar.show()
1104
- progress_bar.start_progress(len(reference_dict))
1111
+ # Note: This method may require 'from sklearn.cluster import KMeans'
1105
1112
 
1106
- results_list = []
1107
- # The 'inputs' list contains work areas from the single target image.
1108
- # We will predict on the first work area/full image.
1109
- input_image = inputs[0]
1113
+ # 1. Reload model and generate initial VPEs from the reference images
1114
+ self.reload_model()
1115
+ initial_vpes = self.references_to_vpe(reference_dict, update_reference_vpes=False)
1110
1116
 
1111
- # Iterate through each reference image and its annotations
1112
- for ref_path, ref_annotations in reference_dict.items():
1113
- # The 'refer_image' parameter is the path to the current reference image
1114
- # The 'visual_prompts' are the annotations from that same reference image
1115
- visual_prompts = {
1116
- 'bboxes': ref_annotations['bboxes'],
1117
- 'cls': ref_annotations['cls'],
1118
- }
1119
- if self.task == 'segment':
1120
- visual_prompts['masks'] = ref_annotations['masks']
1121
-
1122
- # Make predictions on the target using the current reference
1123
- results = self.loaded_model.predict(input_image,
1124
- refer_image=ref_path,
1125
- visual_prompts=visual_prompts,
1126
- predictor=YOLOEVPSegPredictor, # TODO This is necessary here?
1127
- imgsz=self.imgsz_spinbox.value(),
1128
- conf=self.main_window.get_uncertainty_thresh(),
1129
- iou=self.main_window.get_iou_thresh(),
1130
- max_det=self.get_max_detections(),
1131
- retina_masks=self.task == "segment")
1132
-
1133
- if not len(results[0].boxes):
1134
- # If no boxes were detected, skip to the next reference
1135
- progress_bar.update_progress()
1136
- continue
1137
-
1138
- # Update the name of the results and append to the list
1139
- results[0].names = {0: self.class_mapping[0].short_label_code}
1140
- results_list.extend(results[0])
1141
-
1142
- progress_bar.update_progress()
1143
- gc.collect()
1144
- empty_cache()
1117
+ if not initial_vpes:
1118
+ QMessageBox.warning(self,
1119
+ "VPE Generation Failed",
1120
+ "Could not generate VPEs from the selected reference images.")
1121
+ return []
1145
1122
 
1146
- # Clean up
1147
- QApplication.restoreOverrideCursor()
1148
- progress_bar.finish_progress()
1149
- progress_bar.stop_progress()
1150
- progress_bar.close()
1151
-
1152
- # Combine results if there are any
1153
- combined_results = CombineResults().combine_results(results_list)
1154
- if combined_results is None:
1123
+ # 2. Generate K prototypes from the N generated VPEs
1124
+ k = self.k_prototypes_spinbox.value()
1125
+ num_available_vpes = len(initial_vpes)
1126
+ prototype_vpes = []
1127
+
1128
+ # If K is 0 (use all) or K >= N, no clustering is needed.
1129
+ if k == 0 or k >= num_available_vpes:
1130
+ prototype_vpes = initial_vpes
1131
+ else:
1132
+ # Perform K-Means clustering to find K prototypes
1133
+ try:
1134
+ # Prepare tensor for scikit-learn: shape (N, E)
1135
+ all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in initial_vpes], dim=0)
1136
+ vpes_np = all_vpes_tensor.cpu().numpy()
1137
+
1138
+ from sklearn.cluster import KMeans
1139
+ kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
1140
+ centroids_np = kmeans.cluster_centers_
1141
+
1142
+ # Convert centroids back to a list of normalized PyTorch tensors
1143
+ centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
1144
+ for i in range(k):
1145
+ # Reshape centroid to (1, 1, E)
1146
+ centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
1147
+ normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
1148
+ prototype_vpes.append(normalized_centroid)
1149
+ except Exception as e:
1150
+ QMessageBox.critical(self, "Clustering Error", f"Failed to perform K-Means clustering: {e}")
1151
+ return []
1152
+
1153
+ # 3. Configure the model with the K prototypes
1154
+ if not prototype_vpes:
1155
+ QMessageBox.warning(self, "Prototype Error", "Could not generate any prototypes for prediction.")
1155
1156
  return []
1156
-
1157
- return [[combined_results]]
1157
+
1158
+ num_prototypes = len(prototype_vpes)
1159
+ proto_class_names = [f"object{i}" for i in range(num_prototypes)]
1160
+ stacked_vpes = torch.cat(prototype_vpes, dim=1) # Shape: (1, K, E)
1161
+
1162
+ self.loaded_model.is_fused = lambda: False
1163
+ self.loaded_model.set_classes(proto_class_names, stacked_vpes)
1164
+
1165
+ # 4. Make a single prediction on the target using the K prototypes
1166
+ results = self.loaded_model.predict(inputs[0],
1167
+ visual_prompts=[],
1168
+ imgsz=self.imgsz_spinbox.value(),
1169
+ conf=self.main_window.get_uncertainty_thresh(),
1170
+ iou=self.main_window.get_iou_thresh(),
1171
+ max_det=self.get_max_detections(),
1172
+ retina_masks=self.task == "segment")
1173
+
1174
+ return [results]
1158
1175
 
1159
1176
  def generate_vpes_from_references(self):
1160
1177
  """
@@ -1252,8 +1269,8 @@ class DeployGeneratorDialog(QDialog):
1252
1269
 
1253
1270
  def _apply_model_using_vpe(self, inputs):
1254
1271
  """
1255
- Apply the model to the inputs using pre-calculated VPEs from imported files
1256
- and/or generated from reference annotations.
1272
+ Apply the model using VPEs. This method now supports clustering N VPEs
1273
+ into K prototypes before running a single multi-class prediction.
1257
1274
 
1258
1275
  Args:
1259
1276
  inputs (list): List of input images.
@@ -1261,21 +1278,18 @@ class DeployGeneratorDialog(QDialog):
1261
1278
  Returns:
1262
1279
  list: List of prediction results.
1263
1280
  """
1281
+ # Note: This method may require 'from sklearn.cluster import KMeans'
1282
+
1264
1283
  # First reload the model to clear any cached data
1265
1284
  self.reload_model()
1266
1285
 
1267
- # Initialize combined_vpes list
1286
+ # 1. Gather all available VPEs from imported files and generated references
1268
1287
  combined_vpes = []
1269
-
1270
- # Add imported VPEs if available
1271
1288
  if self.imported_vpes:
1272
1289
  combined_vpes.extend(self.imported_vpes)
1273
-
1274
- # Add pre-generated reference VPEs if available
1275
1290
  if self.reference_vpes:
1276
1291
  combined_vpes.extend(self.reference_vpes)
1277
1292
 
1278
- # Check if we have any VPEs to use
1279
1293
  if not combined_vpes:
1280
1294
  QMessageBox.warning(
1281
1295
  self,
@@ -1285,18 +1299,57 @@ class DeployGeneratorDialog(QDialog):
1285
1299
  )
1286
1300
  return []
1287
1301
 
1288
- # Average all the VPEs together to create a final VPE tensor
1289
- averaged_vpe = torch.cat(combined_vpes).mean(dim=0, keepdim=True)
1290
- final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
1291
-
1292
- # For backward compatibility, update self.vpe
1293
- self.vpe = final_vpe
1302
+ # 2. Generate K prototypes from the N available VPEs
1303
+ k = self.k_prototypes_spinbox.value()
1304
+ num_available_vpes = len(combined_vpes)
1305
+ prototype_vpes = []
1306
+
1307
+ # If K is 0 (use all) or K >= N, no clustering is needed. Each VPE is a prototype.
1308
+ if k == 0 or k >= num_available_vpes:
1309
+ prototype_vpes = combined_vpes
1310
+ else:
1311
+ # Perform K-Means clustering to find K prototypes
1312
+ try:
1313
+ # Prepare tensor for scikit-learn: shape (N, E)
1314
+ all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in combined_vpes], dim=0)
1315
+ vpes_np = all_vpes_tensor.cpu().numpy()
1316
+
1317
+ # Lazily import KMeans to avoid making sklearn a hard dependency if not used
1318
+ from sklearn.cluster import KMeans
1319
+ kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
1320
+ centroids_np = kmeans.cluster_centers_
1321
+
1322
+ # Convert centroids back to a list of normalized PyTorch tensors
1323
+ centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
1324
+ for i in range(k):
1325
+ # Reshape centroid to (1, 1, E) for model compatibility
1326
+ centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
1327
+ normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
1328
+ prototype_vpes.append(normalized_centroid)
1329
+ except Exception as e:
1330
+ QMessageBox.critical(self, "Clustering Error", f"Failed to perform K-Means clustering: {e}")
1331
+ return []
1332
+
1333
+ # 3. Configure the model with the K prototypes
1334
+ if not prototype_vpes:
1335
+ QMessageBox.warning(self, "Prototype Error", "Could not generate any prototypes for prediction.")
1336
+ return []
1337
+
1338
+ # For backward compatibility, set self.vpe to the average of the final prototypes
1339
+ averaged_prototype = torch.cat(prototype_vpes).mean(dim=0, keepdim=True)
1340
+ self.vpe = torch.nn.functional.normalize(averaged_prototype, p=2, dim=-1)
1341
+
1342
+ # Set up the model for multi-class detection with K proto-classes
1343
+ num_prototypes = len(prototype_vpes)
1344
+ proto_class_names = [f"object{i}" for i in range(num_prototypes)]
1294
1345
 
1295
- # Set the final VPE to the model
1346
+ # Stack prototypes into a single tensor of shape (1, K, E) for set_classes
1347
+ stacked_vpes = torch.cat(prototype_vpes, dim=1)
1348
+
1296
1349
  self.loaded_model.is_fused = lambda: False
1297
- self.loaded_model.set_classes(["object0"], final_vpe)
1350
+ self.loaded_model.set_classes(proto_class_names, stacked_vpes)
1298
1351
 
1299
- # Make predictions on the target using the averaged VPE
1352
+ # 4. Make predictions on the target using the K prototypes
1300
1353
  results = self.loaded_model.predict(inputs[0],
1301
1354
  visual_prompts=[],
1302
1355
  imgsz=self.imgsz_spinbox.value(),
@@ -1386,7 +1439,7 @@ class DeployGeneratorDialog(QDialog):
1386
1439
  return updated_results
1387
1440
 
1388
1441
  def _process_results(self, results_processor, results_list, image_path):
1389
- """Process the results using the result processor."""
1442
+ """Process the results, merging K proto-class detections into a single target class."""
1390
1443
  # Get the raster object and number of work items
1391
1444
  raster = self.image_window.raster_manager.get_raster(image_path)
1392
1445
  total = raster.count_work_items()
@@ -1400,14 +1453,25 @@ class DeployGeneratorDialog(QDialog):
1400
1453
  progress_bar.start_progress(total)
1401
1454
 
1402
1455
  updated_results = []
1456
+ target_label_name = self.reference_label.short_label_code
1403
1457
 
1404
1458
  for idx, results in enumerate(results_list):
1405
1459
  # Each Results is a list (within the results_list, [[], ]
1406
- if results:
1407
- # Update path and names
1460
+ if results and results[0].boxes is not None and len(results[0].boxes) > 0:
1461
+ # Clone the data tensor and set all classes to 0
1462
+ new_data = results[0].boxes.data.clone()
1463
+ new_data[:, 5] = 0 # The 6th column (index 5) is the class
1464
+
1465
+ # Create a new Boxes object of the same type
1466
+ new_boxes = type(results[0].boxes)(new_data, results[0].boxes.orig_shape)
1467
+ results[0].boxes = new_boxes
1468
+
1469
+ # Update the 'names' dictionary to map our single class ID (0)
1470
+ # to the final target label name chosen by the user.
1471
+ results[0].names = {0: target_label_name}
1472
+
1473
+ # Update path (original logic)
1408
1474
  results[0].path = image_path
1409
- results[0].names = {0: self.class_mapping[0].short_label_code}
1410
- # This needs to be done again, in case SAM was used
1411
1475
 
1412
1476
  # Check if the work area is valid, or the image path is being used
1413
1477
  if work_areas and self.annotation_window.get_selected_tool() == "work_area":
@@ -1419,7 +1483,7 @@ class DeployGeneratorDialog(QDialog):
1419
1483
  else:
1420
1484
  results = results[0]
1421
1485
 
1422
- # Append the result object (not a list) to the updated results list
1486
+ # Append the result object to the updated results list
1423
1487
  updated_results.append(results)
1424
1488
 
1425
1489
  # Update the index for the next work area
@@ -1439,42 +1503,96 @@ class DeployGeneratorDialog(QDialog):
1439
1503
 
1440
1504
  def show_vpe(self):
1441
1505
  """
1442
- Show a visualization of the currently stored VPEs using PyQtGraph.
1506
+ Show a visualization of the stored VPEs and their K-prototypes.
1443
1507
  """
1444
1508
  try:
1509
+ # 1. Gather all raw VPEs from imports and references
1445
1510
  vpes_with_source = []
1446
-
1447
- # 1. Add any VPEs that were loaded from a file
1448
1511
  if self.imported_vpes:
1449
1512
  for vpe in self.imported_vpes:
1450
1513
  vpes_with_source.append((vpe, "Import"))
1451
-
1452
- # 2. Add any pre-generated VPEs from reference images
1453
1514
  if self.reference_vpes:
1454
1515
  for vpe in self.reference_vpes:
1455
1516
  vpes_with_source.append((vpe, "Reference"))
1456
1517
 
1457
- # 3. Check if there is anything to visualize
1458
1518
  if not vpes_with_source:
1459
1519
  QMessageBox.warning(
1460
1520
  self,
1461
1521
  "No VPEs Available",
1462
- "No VPEs available to visualize. Please load a VPE file or generate VPEs from references first."
1522
+ "No VPEs available to visualize. Please load or generate VPEs first."
1463
1523
  )
1464
1524
  return
1465
1525
 
1466
- # 4. Create the visualization dialog
1467
- all_vpe_tensors = [vpe for vpe, source in vpes_with_source]
1468
- averaged_vpe = torch.cat(all_vpe_tensors).mean(dim=0, keepdim=True)
1469
- final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
1470
-
1471
- QApplication.setOverrideCursor(Qt.WaitCursor)
1526
+ raw_vpes = [vpe for vpe, source in vpes_with_source]
1527
+ num_raw = len(raw_vpes)
1472
1528
 
1473
- dialog = VPEVisualizationDialog(vpes_with_source, final_vpe, self)
1474
- QApplication.restoreOverrideCursor()
1529
+ # 2. Get K and determine if we need to cluster
1530
+ k = self.k_prototypes_spinbox.value()
1531
+ prototypes = [] # This will be the centroids if clustering is performed
1532
+ final_vpe = None
1533
+ clustering_performed = False
1534
+
1535
+ # Case 1: We want to cluster (1 <= k < num_raw)
1536
+ if 1 <= k < num_raw:
1537
+ try:
1538
+ # Prepare tensor for scikit-learn: shape (N, E)
1539
+ all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in raw_vpes], dim=0)
1540
+ vpes_np = all_vpes_tensor.cpu().numpy()
1541
+
1542
+ from sklearn.cluster import KMeans
1543
+ kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
1544
+ centroids_np = kmeans.cluster_centers_
1545
+
1546
+ centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
1547
+ for i in range(k):
1548
+ centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
1549
+ normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
1550
+ prototypes.append(normalized_centroid)
1551
+
1552
+ # The final VPE is the average of the centroids
1553
+ stacked_prototypes = torch.cat(prototypes, dim=1) # Shape: (1, k, E)
1554
+ averaged_prototype = stacked_prototypes.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
1555
+ final_vpe = torch.nn.functional.normalize(averaged_prototype, p=2, dim=-1)
1556
+ clustering_performed = True
1557
+
1558
+ except Exception as e:
1559
+ QMessageBox.critical(self,
1560
+ "Clustering Error",
1561
+ f"Could not perform clustering for visualization: {e}")
1562
+ # If clustering fails, fall back to using all raw VPEs
1563
+ prototypes = []
1564
+ clustering_performed = False
1565
+
1566
+ # Case 2: k==0 -> use all raw VPEs as prototypes (no clustering)
1567
+ if k == 0 or not clustering_performed:
1568
+ # We are not clustering, so we use all raw VPEs as prototypes
1569
+ # For visualization purposes, we'll show the raw VPEs and their average
1570
+ stacked_raw = torch.cat(raw_vpes, dim=1) # Shape: (1, num_raw, E)
1571
+ averaged_raw = stacked_raw.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
1572
+ final_vpe = torch.nn.functional.normalize(averaged_raw, p=2, dim=-1)
1573
+ # Don't set prototypes here - we'll show raw VPEs separately in the visualization
1475
1574
 
1575
+ # Case 3: k >= num_raw -> use all raw VPEs as prototypes (no clustering needed)
1576
+ elif k >= num_raw:
1577
+ # We have more requested prototypes than available VPEs, so we use all VPEs
1578
+ stacked_raw = torch.cat(raw_vpes, dim=1) # Shape: (1, num_raw, E)
1579
+ averaged_raw = stacked_raw.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
1580
+ final_vpe = torch.nn.functional.normalize(averaged_raw, p=2, dim=-1)
1581
+ # Don't set prototypes here - we'll show raw VPEs separately in the visualization
1582
+
1583
+ # 3. Create and show the visualization dialog
1584
+ QApplication.setOverrideCursor(Qt.WaitCursor)
1585
+ dialog = VPEVisualizationDialog(
1586
+ vpes_with_source,
1587
+ final_vpe,
1588
+ prototypes=prototypes,
1589
+ clustering_performed=clustering_performed,
1590
+ k_value=k,
1591
+ parent=self
1592
+ )
1593
+ QApplication.restoreOverrideCursor()
1476
1594
  dialog.exec_()
1477
-
1595
+
1478
1596
  except Exception as e:
1479
1597
  QApplication.restoreOverrideCursor()
1480
1598
  QMessageBox.critical(self, "Error Visualizing VPE", f"An error occurred: {str(e)}")
@@ -1507,44 +1625,45 @@ class DeployGeneratorDialog(QDialog):
1507
1625
 
1508
1626
  class VPEVisualizationDialog(QDialog):
1509
1627
  """
1510
- Dialog for visualizing VPE embeddings in 2D space using PCA.
1628
+ Dialog for visualizing VPE embeddings, now including K-prototypes.
1511
1629
  """
1512
- def __init__(self, vpe_list_with_source, final_vpe=None, parent=None):
1630
+ def __init__(self, vpe_list_with_source, final_vpe=None, prototypes=None,
1631
+ clustering_performed=False, k_value=0, parent=None):
1513
1632
  """
1514
- Initialize the dialog with a list of VPE tensors and their sources.
1633
+ Initialize the dialog.
1515
1634
 
1516
1635
  Args:
1517
- vpe_list_with_source (list): List of (VPE tensor, source_str) tuples
1518
- final_vpe (torch.Tensor, optional): The final (averaged) VPE
1519
- parent (QWidget, optional): Parent widget
1636
+ vpe_list_with_source (list): List of (VPE tensor, source_str) tuples for raw VPEs.
1637
+ final_vpe (torch.Tensor, optional): The final (averaged) VPE.
1638
+ prototypes (list, optional): List of K-prototype VPE tensors (cluster centroids).
1639
+ clustering_performed (bool): Whether clustering was performed.
1640
+ k_value (int): The K value used for clustering.
1641
+ parent (QWidget, optional): Parent widget.
1520
1642
  """
1521
1643
  super().__init__(parent)
1522
1644
  self.setWindowTitle("VPE Visualization")
1523
1645
  self.resize(1000, 1000)
1524
-
1525
- # Add a maximize button to the dialog's title bar
1526
1646
  self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)
1527
1647
 
1528
- # Store the VPEs and their sources
1648
+ # Store the VPEs and clustering info
1529
1649
  self.vpe_list_with_source = vpe_list_with_source
1530
1650
  self.final_vpe = final_vpe
1651
+ self.prototypes = prototypes if prototypes else []
1652
+ self.clustering_performed = clustering_performed
1653
+ self.k_value = k_value
1531
1654
 
1532
1655
  # Create the layout
1533
1656
  layout = QVBoxLayout(self)
1534
1657
 
1535
1658
  # Create the plot widget
1536
1659
  self.plot_widget = pg.PlotWidget()
1537
- self.plot_widget.setBackground('w') # White background
1660
+ self.plot_widget.setBackground('w')
1538
1661
  self.plot_widget.setTitle("PCA Visualization of Visual Prompt Embeddings", color="#000000", size="10pt")
1539
1662
  self.plot_widget.showGrid(x=True, y=True, alpha=0.3)
1540
-
1541
- # Add the plot widget to the layout
1542
1663
  layout.addWidget(self.plot_widget)
1543
-
1544
- # Add spacing between plot_widget and info_label
1545
1664
  layout.addSpacing(20)
1546
1665
 
1547
- # Add information label at the bottom
1666
+ # Add information label
1548
1667
  self.info_label = QLabel()
1549
1668
  self.info_label.setAlignment(Qt.AlignCenter)
1550
1669
  layout.addWidget(self.info_label)
@@ -1559,101 +1678,129 @@ class VPEVisualizationDialog(QDialog):
1559
1678
 
1560
1679
  def visualize_vpes(self):
1561
1680
  """
1562
- Apply PCA to the VPE tensors and visualize them in 2D space.
1681
+ Apply PCA to all VPEs (raw, prototypes, final) and visualize them.
1563
1682
  """
1564
1683
  if not self.vpe_list_with_source:
1565
1684
  self.info_label.setText("No VPEs available to visualize.")
1566
1685
  return
1686
+
1687
+ # 1. Collect all numpy arrays for PCA transformation
1688
+ raw_vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, source in self.vpe_list_with_source]
1689
+ prototype_arrays = [p.detach().cpu().numpy().squeeze() for p in self.prototypes]
1567
1690
 
1568
- # Convert tensors to numpy arrays for PCA, separating them from the source string
1569
- vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, source in self.vpe_list_with_source]
1691
+ all_arrays_for_pca = raw_vpe_arrays + prototype_arrays
1570
1692
 
1571
- # If final VPE is provided, add it to the arrays
1572
1693
  final_vpe_array = None
1573
1694
  if self.final_vpe is not None:
1574
1695
  final_vpe_array = self.final_vpe.detach().cpu().numpy().squeeze()
1575
- all_vpes = np.vstack(vpe_arrays + [final_vpe_array])
1576
- else:
1577
- all_vpes = np.vstack(vpe_arrays)
1578
-
1579
- # Apply PCA to reduce to 2 dimensions
1696
+ all_arrays_for_pca.append(final_vpe_array)
1697
+
1698
+ if len(all_arrays_for_pca) < 2:
1699
+ self.info_label.setText("At least 2 VPEs are needed for PCA visualization.")
1700
+ return
1701
+
1702
+ # 2. Apply PCA
1703
+ all_vpes_stacked = np.vstack(all_arrays_for_pca)
1580
1704
  pca = PCA(n_components=2)
1581
- vpes_2d = pca.fit_transform(all_vpes)
1705
+ vpes_2d = pca.fit_transform(all_vpes_stacked)
1582
1706
 
1583
- # Clear the plot
1707
+ # 3. Plot the results
1584
1708
  self.plot_widget.clear()
1585
-
1586
- # Generate random colors for individual VPEs
1587
- num_vpes = len(vpe_arrays)
1588
- colors = self.generate_distinct_colors(num_vpes)
1589
-
1590
- # Create a legend with 3 columns to keep it compact
1591
1709
  legend = self.plot_widget.addLegend(colCount=3)
1592
1710
 
1593
- # Plot individual VPEs
1594
- for i, (vpe_tuple, vpe_2d) in enumerate(zip(self.vpe_list_with_source, vpes_2d[:num_vpes])):
1711
+ # Slicing indices
1712
+ num_raw = len(raw_vpe_arrays)
1713
+ num_prototypes = len(prototype_arrays)
1714
+
1715
+ # Determine if each raw VPE is effectively a prototype (k==0 or k>=N)
1716
+ each_vpe_is_prototype = (self.k_value == 0 or self.k_value >= num_raw)
1717
+
1718
+ # Plot individual raw VPEs
1719
+ colors = self.generate_distinct_colors(num_raw)
1720
+ for i, (vpe_tuple, vpe_2d) in enumerate(zip(self.vpe_list_with_source, vpes_2d[:num_raw])):
1595
1721
  source_char = 'I' if vpe_tuple[1] == 'Import' else 'R'
1596
- color = pg.mkColor(colors[i])
1722
+
1723
+ # Use diamonds if each VPE is a prototype, circles otherwise
1724
+ symbol = 'd' if each_vpe_is_prototype else 'o'
1725
+
1726
+ # If it's a prototype, add a black border
1727
+ pen = pg.mkPen(color='k', width=1.5) if each_vpe_is_prototype else None
1728
+
1729
+ # Create label with prototype indicator if applicable
1730
+ name_suffix = " (Prototype)" if each_vpe_is_prototype else ""
1731
+ name = f"VPE {i+1} ({source_char}){name_suffix}"
1732
+
1597
1733
  scatter = pg.ScatterPlotItem(
1598
1734
  x=[vpe_2d[0]],
1599
1735
  y=[vpe_2d[1]],
1600
- brush=color,
1601
- size=15,
1602
- name=f"VPE {i+1} ({source_char})"
1736
+ brush=pg.mkColor(colors[i]),
1737
+ pen=pen,
1738
+ size=15 if not each_vpe_is_prototype else 18,
1739
+ symbol=symbol,
1740
+ name=name
1603
1741
  )
1604
1742
  self.plot_widget.addItem(scatter)
1605
1743
 
1606
- # Plot the final (averaged) VPE if available
1744
+ # Plot K-Prototypes (blue diamonds) if we have any and explicit clustering was performed
1745
+ if self.prototypes and self.clustering_performed:
1746
+ prototype_vpes_2d = vpes_2d[num_raw: num_raw + num_prototypes]
1747
+ scatter = pg.ScatterPlotItem(
1748
+ x=prototype_vpes_2d[:, 0],
1749
+ y=prototype_vpes_2d[:, 1],
1750
+ brush=pg.mkBrush(color=(0, 0, 255, 150)),
1751
+ pen=pg.mkPen(color='k', width=1.5),
1752
+ size=18,
1753
+ symbol='d',
1754
+ name=f"K-Prototypes (K={self.k_value})"
1755
+ )
1756
+ self.plot_widget.addItem(scatter)
1757
+
1758
+ # Plot the final (averaged) VPE (red star)
1607
1759
  if final_vpe_array is not None:
1608
1760
  final_vpe_2d = vpes_2d[-1]
1609
1761
  scatter = pg.ScatterPlotItem(
1610
1762
  x=[final_vpe_2d[0]],
1611
- y=[final_vpe_2d[1]],
1763
+ y=[final_vpe_2d[1]],
1612
1764
  brush=pg.mkBrush(color='r'),
1613
- size=20,
1614
- symbol='star',
1615
- name="Final VPE"
1765
+ size=20,
1766
+ symbol='star',
1767
+ name="Final VPE (Avg)"
1616
1768
  )
1617
1769
  self.plot_widget.addItem(scatter)
1618
1770
 
1619
- # Update the information label
1771
+ # 4. Update the information label
1620
1772
  orig_dim = self.vpe_list_with_source[0][0].shape[-1]
1621
1773
  explained_variance = sum(pca.explained_variance_ratio_)
1622
- self.info_label.setText(
1623
- f"Original dimension: {orig_dim} → Reduced to 2D\n"
1624
- f"Total explained variance: {explained_variance:.2%}\n"
1625
- f"PC1: {pca.explained_variance_ratio_[0]:.2%} variance, "
1626
- f"PC2: {pca.explained_variance_ratio_[1]:.2%} variance"
1627
- )
1774
+
1775
+ info_text = (f"Original dimension: {orig_dim} → Reduced to 2D\n"
1776
+ f"Total explained variance: {explained_variance:.2%}\n"
1777
+ f"PC1: {pca.explained_variance_ratio_[0]:.2%} variance, "
1778
+ f"PC2: {pca.explained_variance_ratio_[1]:.2%} variance\n"
1779
+ f"Number of raw VPEs: {num_raw}\n")
1780
+
1781
+ if self.clustering_performed:
1782
+ info_text += f"Clustering performed with K={self.k_value}\n"
1783
+ info_text += f"Number of prototypes: {len(self.prototypes)}"
1784
+ else:
1785
+ if self.k_value == 0:
1786
+ info_text += f"No clustering (K=0): all {num_raw} raw VPEs used as prototypes"
1787
+ else:
1788
+ info_text += f"No clustering performed (K={self.k_value} >= {num_raw}): all raw VPEs used as prototypes"
1789
+
1790
+ self.info_label.setText(info_text)
1628
1791
 
1629
1792
  def generate_distinct_colors(self, num_colors):
1630
- """
1631
- Generate visually distinct colors by using evenly spaced hues
1632
- with random saturation and value.
1633
-
1634
- Args:
1635
- num_colors (int): Number of colors to generate
1636
-
1637
- Returns:
1638
- list: List of color hex strings
1639
- """
1793
+ """Generates visually distinct colors."""
1640
1794
  import random
1641
1795
  from colorsys import hsv_to_rgb
1642
1796
 
1643
1797
  colors = []
1644
1798
  for i in range(num_colors):
1645
- # Use golden ratio to space hues evenly
1646
1799
  hue = (i * 0.618033988749895) % 1.0
1647
- # Random saturation between 0.6-1.0 (avoid too pale)
1648
1800
  saturation = random.uniform(0.6, 1.0)
1649
- # Random value between 0.7-1.0 (avoid too dark)
1650
1801
  value = random.uniform(0.7, 1.0)
1651
-
1652
- # Convert HSV to RGB (0-1 range)
1653
1802
  r, g, b = hsv_to_rgb(hue, saturation, value)
1654
-
1655
- # Convert RGB to hex string
1656
1803
  hex_color = f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
1657
1804
  colors.append(hex_color)
1658
-
1805
+
1659
1806
  return colors