coralnet-toolbox 0.0.76__py2.py3-none-any.whl → 0.0.77__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,7 +26,6 @@ from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotati
26
26
 
27
27
  from coralnet_toolbox.Results import ResultsProcessor
28
28
  from coralnet_toolbox.Results import MapResults
29
- from coralnet_toolbox.Results import CombineResults
30
29
 
31
30
  from coralnet_toolbox.QtProgressBar import ProgressBar
32
31
  from coralnet_toolbox.QtImageWindow import ImageWindow
@@ -539,8 +538,8 @@ class DeployGeneratorDialog(QDialog):
539
538
 
540
539
  def setup_reference_layout(self):
541
540
  """
542
- Set up the layout with reference label selection.
543
- The reference image is implicitly the currently active image.
541
+ Set up the layout for reference selection, including the output label,
542
+ reference method, and the number of prototype clusters (K).
544
543
  """
545
544
  group_box = QGroupBox("Reference")
546
545
  layout = QFormLayout()
@@ -555,8 +554,18 @@ class DeployGeneratorDialog(QDialog):
555
554
  self.reference_method_combo_box.addItems(["VPE", "Images"])
556
555
  layout.addRow("Reference Method:", self.reference_method_combo_box)
557
556
 
557
+ # Add a spinbox for the user to define the number of prototypes (K)
558
+ self.k_prototypes_spinbox = QSpinBox()
559
+ self.k_prototypes_spinbox.setRange(0, 1000)
560
+ self.k_prototypes_spinbox.setValue(0)
561
+ self.k_prototypes_spinbox.setToolTip(
562
+ "Set the number of prototype clusters (K) to generate from references.\n"
563
+ "Set to 0 to treat every unique reference image/VPE as its own prototype (K=N)."
564
+ )
565
+ layout.addRow("Number of Prototypes (K):", self.k_prototypes_spinbox)
566
+
558
567
  group_box.setLayout(layout)
559
- self.right_panel.addWidget(group_box) # Add to right panel
568
+ self.right_panel.addWidget(group_box)
560
569
 
561
570
  def setup_buttons_layout(self):
562
571
  """
@@ -1088,9 +1097,9 @@ class DeployGeneratorDialog(QDialog):
1088
1097
 
1089
1098
  def _apply_model_using_images(self, inputs, reference_dict):
1090
1099
  """
1091
- Apply the model using the provided images and reference annotations (dict). This method
1092
- loops through each reference image using its annotations; we then aggregate
1093
- all the results together. Less efficient, but potentially more accurate.
1100
+ Apply the model using reference images. This method now generates VPEs
1101
+ from the reference annotations, clusters them into K prototypes, and
1102
+ runs a single, efficient multi-class prediction.
1094
1103
 
1095
1104
  Args:
1096
1105
  inputs (list): List of input images.
@@ -1099,67 +1108,70 @@ class DeployGeneratorDialog(QDialog):
1099
1108
  Returns:
1100
1109
  list: List of prediction results.
1101
1110
  """
1102
- # Create a progress bar for iterating through reference images
1103
- QApplication.setOverrideCursor(Qt.WaitCursor)
1104
- progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
1105
- progress_bar.show()
1106
- progress_bar.start_progress(len(reference_dict))
1111
+ # Note: This method may require 'from sklearn.cluster import KMeans'
1107
1112
 
1108
- results_list = []
1109
- # The 'inputs' list contains work areas from the single target image.
1110
- # We will predict on the first work area/full image.
1111
- input_image = inputs[0]
1113
+ # 1. Reload model and generate initial VPEs from the reference images
1114
+ self.reload_model()
1115
+ initial_vpes = self.references_to_vpe(reference_dict, update_reference_vpes=False)
1112
1116
 
1113
- # Set the predictor
1114
- predictor = YOLOEVPDetectPredictor if self.task == "detect" else YOLOEVPSegPredictor
1117
+ if not initial_vpes:
1118
+ QMessageBox.warning(self,
1119
+ "VPE Generation Failed",
1120
+ "Could not generate VPEs from the selected reference images.")
1121
+ return []
1115
1122
 
1116
- # Iterate through each reference image and its annotations
1117
- for ref_path, ref_annotations in reference_dict.items():
1118
- # The 'refer_image' parameter is the path to the current reference image
1119
- # The 'visual_prompts' are the annotations from that same reference image
1120
- visual_prompts = {
1121
- 'bboxes': ref_annotations['bboxes'],
1122
- 'cls': ref_annotations['cls'],
1123
- }
1124
- if self.task == 'segment':
1125
- visual_prompts['masks'] = ref_annotations['masks']
1126
-
1127
- # Make predictions on the target using the current reference
1128
- results = self.loaded_model.predict(input_image,
1129
- refer_image=ref_path,
1130
- visual_prompts=visual_prompts,
1131
- predictor=predictor,
1132
- imgsz=self.imgsz_spinbox.value(),
1133
- conf=self.main_window.get_uncertainty_thresh(),
1134
- iou=self.main_window.get_iou_thresh(),
1135
- max_det=self.get_max_detections(),
1136
- retina_masks=self.task == "segment")
1137
-
1138
- if not len(results[0].boxes):
1139
- # If no boxes were detected, skip to the next reference
1140
- progress_bar.update_progress()
1141
- continue
1142
-
1143
- # Update the name of the results and append to the list
1144
- results[0].names = {0: self.class_mapping[0].short_label_code}
1145
- results_list.extend(results[0])
1146
-
1147
- progress_bar.update_progress()
1148
- gc.collect()
1149
- empty_cache()
1123
+ # 2. Generate K prototypes from the N generated VPEs
1124
+ k = self.k_prototypes_spinbox.value()
1125
+ num_available_vpes = len(initial_vpes)
1126
+ prototype_vpes = []
1150
1127
 
1151
- # Clean up
1152
- QApplication.restoreOverrideCursor()
1153
- progress_bar.finish_progress()
1154
- progress_bar.stop_progress()
1155
- progress_bar.close()
1156
-
1157
- # Combine results if there are any
1158
- combined_results = CombineResults().combine_results(results_list)
1159
- if combined_results is None:
1128
+ # If K is 0 (use all) or K >= N, no clustering is needed.
1129
+ if k == 0 or k >= num_available_vpes:
1130
+ prototype_vpes = initial_vpes
1131
+ else:
1132
+ # Perform K-Means clustering to find K prototypes
1133
+ try:
1134
+ # Prepare tensor for scikit-learn: shape (N, E)
1135
+ all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in initial_vpes], dim=0)
1136
+ vpes_np = all_vpes_tensor.cpu().numpy()
1137
+
1138
+ from sklearn.cluster import KMeans
1139
+ kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
1140
+ centroids_np = kmeans.cluster_centers_
1141
+
1142
+ # Convert centroids back to a list of normalized PyTorch tensors
1143
+ centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
1144
+ for i in range(k):
1145
+ # Reshape centroid to (1, 1, E)
1146
+ centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
1147
+ normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
1148
+ prototype_vpes.append(normalized_centroid)
1149
+ except Exception as e:
1150
+ QMessageBox.critical(self, "Clustering Error", f"Failed to perform K-Means clustering: {e}")
1151
+ return []
1152
+
1153
+ # 3. Configure the model with the K prototypes
1154
+ if not prototype_vpes:
1155
+ QMessageBox.warning(self, "Prototype Error", "Could not generate any prototypes for prediction.")
1160
1156
  return []
1161
-
1162
- return [[combined_results]]
1157
+
1158
+ num_prototypes = len(prototype_vpes)
1159
+ proto_class_names = [f"object{i}" for i in range(num_prototypes)]
1160
+ stacked_vpes = torch.cat(prototype_vpes, dim=1) # Shape: (1, K, E)
1161
+
1162
+ self.loaded_model.is_fused = lambda: False
1163
+ self.loaded_model.set_classes(proto_class_names, stacked_vpes)
1164
+
1165
+ # 4. Make a single prediction on the target using the K prototypes
1166
+ results = self.loaded_model.predict(inputs[0],
1167
+ visual_prompts=[],
1168
+ imgsz=self.imgsz_spinbox.value(),
1169
+ conf=self.main_window.get_uncertainty_thresh(),
1170
+ iou=self.main_window.get_iou_thresh(),
1171
+ max_det=self.get_max_detections(),
1172
+ retina_masks=self.task == "segment")
1173
+
1174
+ return [results]
1163
1175
 
1164
1176
  def generate_vpes_from_references(self):
1165
1177
  """
@@ -1257,8 +1269,8 @@ class DeployGeneratorDialog(QDialog):
1257
1269
 
1258
1270
  def _apply_model_using_vpe(self, inputs):
1259
1271
  """
1260
- Apply the model to the inputs using pre-calculated VPEs from imported files
1261
- and/or generated from reference annotations.
1272
+ Apply the model using VPEs. This method now supports clustering N VPEs
1273
+ into K prototypes before running a single multi-class prediction.
1262
1274
 
1263
1275
  Args:
1264
1276
  inputs (list): List of input images.
@@ -1266,21 +1278,18 @@ class DeployGeneratorDialog(QDialog):
1266
1278
  Returns:
1267
1279
  list: List of prediction results.
1268
1280
  """
1281
+ # Note: This method may require 'from sklearn.cluster import KMeans'
1282
+
1269
1283
  # First reload the model to clear any cached data
1270
1284
  self.reload_model()
1271
1285
 
1272
- # Initialize combined_vpes list
1286
+ # 1. Gather all available VPEs from imported files and generated references
1273
1287
  combined_vpes = []
1274
-
1275
- # Add imported VPEs if available
1276
1288
  if self.imported_vpes:
1277
1289
  combined_vpes.extend(self.imported_vpes)
1278
-
1279
- # Add pre-generated reference VPEs if available
1280
1290
  if self.reference_vpes:
1281
1291
  combined_vpes.extend(self.reference_vpes)
1282
1292
 
1283
- # Check if we have any VPEs to use
1284
1293
  if not combined_vpes:
1285
1294
  QMessageBox.warning(
1286
1295
  self,
@@ -1290,18 +1299,57 @@ class DeployGeneratorDialog(QDialog):
1290
1299
  )
1291
1300
  return []
1292
1301
 
1293
- # Average all the VPEs together to create a final VPE tensor
1294
- averaged_vpe = torch.cat(combined_vpes).mean(dim=0, keepdim=True)
1295
- final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
1296
-
1297
- # For backward compatibility, update self.vpe
1298
- self.vpe = final_vpe
1302
+ # 2. Generate K prototypes from the N available VPEs
1303
+ k = self.k_prototypes_spinbox.value()
1304
+ num_available_vpes = len(combined_vpes)
1305
+ prototype_vpes = []
1306
+
1307
+ # If K is 0 (use all) or K >= N, no clustering is needed. Each VPE is a prototype.
1308
+ if k == 0 or k >= num_available_vpes:
1309
+ prototype_vpes = combined_vpes
1310
+ else:
1311
+ # Perform K-Means clustering to find K prototypes
1312
+ try:
1313
+ # Prepare tensor for scikit-learn: shape (N, E)
1314
+ all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in combined_vpes], dim=0)
1315
+ vpes_np = all_vpes_tensor.cpu().numpy()
1316
+
1317
+ # Lazily import KMeans to avoid making sklearn a hard dependency if not used
1318
+ from sklearn.cluster import KMeans
1319
+ kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
1320
+ centroids_np = kmeans.cluster_centers_
1321
+
1322
+ # Convert centroids back to a list of normalized PyTorch tensors
1323
+ centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
1324
+ for i in range(k):
1325
+ # Reshape centroid to (1, 1, E) for model compatibility
1326
+ centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
1327
+ normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
1328
+ prototype_vpes.append(normalized_centroid)
1329
+ except Exception as e:
1330
+ QMessageBox.critical(self, "Clustering Error", f"Failed to perform K-Means clustering: {e}")
1331
+ return []
1332
+
1333
+ # 3. Configure the model with the K prototypes
1334
+ if not prototype_vpes:
1335
+ QMessageBox.warning(self, "Prototype Error", "Could not generate any prototypes for prediction.")
1336
+ return []
1337
+
1338
+ # For backward compatibility, set self.vpe to the average of the final prototypes
1339
+ averaged_prototype = torch.cat(prototype_vpes).mean(dim=0, keepdim=True)
1340
+ self.vpe = torch.nn.functional.normalize(averaged_prototype, p=2, dim=-1)
1341
+
1342
+ # Set up the model for multi-class detection with K proto-classes
1343
+ num_prototypes = len(prototype_vpes)
1344
+ proto_class_names = [f"object{i}" for i in range(num_prototypes)]
1299
1345
 
1300
- # Set the final VPE to the model
1346
+ # Stack prototypes into a single tensor of shape (1, K, E) for set_classes
1347
+ stacked_vpes = torch.cat(prototype_vpes, dim=1)
1348
+
1301
1349
  self.loaded_model.is_fused = lambda: False
1302
- self.loaded_model.set_classes(["object0"], final_vpe)
1350
+ self.loaded_model.set_classes(proto_class_names, stacked_vpes)
1303
1351
 
1304
- # Make predictions on the target using the averaged VPE
1352
+ # 4. Make predictions on the target using the K prototypes
1305
1353
  results = self.loaded_model.predict(inputs[0],
1306
1354
  visual_prompts=[],
1307
1355
  imgsz=self.imgsz_spinbox.value(),
@@ -1391,7 +1439,7 @@ class DeployGeneratorDialog(QDialog):
1391
1439
  return updated_results
1392
1440
 
1393
1441
  def _process_results(self, results_processor, results_list, image_path):
1394
- """Process the results using the result processor."""
1442
+ """Process the results, merging K proto-class detections into a single target class."""
1395
1443
  # Get the raster object and number of work items
1396
1444
  raster = self.image_window.raster_manager.get_raster(image_path)
1397
1445
  total = raster.count_work_items()
@@ -1405,14 +1453,25 @@ class DeployGeneratorDialog(QDialog):
1405
1453
  progress_bar.start_progress(total)
1406
1454
 
1407
1455
  updated_results = []
1456
+ target_label_name = self.reference_label.short_label_code
1408
1457
 
1409
1458
  for idx, results in enumerate(results_list):
1410
1459
  # Each Results is a list (within the results_list, [[], ]
1411
- if results:
1412
- # Update path and names
1460
+ if results and results[0].boxes is not None and len(results[0].boxes) > 0:
1461
+ # Clone the data tensor and set all classes to 0
1462
+ new_data = results[0].boxes.data.clone()
1463
+ new_data[:, 5] = 0 # The 6th column (index 5) is the class
1464
+
1465
+ # Create a new Boxes object of the same type
1466
+ new_boxes = type(results[0].boxes)(new_data, results[0].boxes.orig_shape)
1467
+ results[0].boxes = new_boxes
1468
+
1469
+ # Update the 'names' dictionary to map our single class ID (0)
1470
+ # to the final target label name chosen by the user.
1471
+ results[0].names = {0: target_label_name}
1472
+
1473
+ # Update path (original logic)
1413
1474
  results[0].path = image_path
1414
- results[0].names = {0: self.class_mapping[0].short_label_code}
1415
- # This needs to be done again, in case SAM was used
1416
1475
 
1417
1476
  # Check if the work area is valid, or the image path is being used
1418
1477
  if work_areas and self.annotation_window.get_selected_tool() == "work_area":
@@ -1424,7 +1483,7 @@ class DeployGeneratorDialog(QDialog):
1424
1483
  else:
1425
1484
  results = results[0]
1426
1485
 
1427
- # Append the result object (not a list) to the updated results list
1486
+ # Append the result object to the updated results list
1428
1487
  updated_results.append(results)
1429
1488
 
1430
1489
  # Update the index for the next work area
@@ -1444,42 +1503,96 @@ class DeployGeneratorDialog(QDialog):
1444
1503
 
1445
1504
  def show_vpe(self):
1446
1505
  """
1447
- Show a visualization of the currently stored VPEs using PyQtGraph.
1506
+ Show a visualization of the stored VPEs and their K-prototypes.
1448
1507
  """
1449
1508
  try:
1509
+ # 1. Gather all raw VPEs from imports and references
1450
1510
  vpes_with_source = []
1451
-
1452
- # 1. Add any VPEs that were loaded from a file
1453
1511
  if self.imported_vpes:
1454
1512
  for vpe in self.imported_vpes:
1455
1513
  vpes_with_source.append((vpe, "Import"))
1456
-
1457
- # 2. Add any pre-generated VPEs from reference images
1458
1514
  if self.reference_vpes:
1459
1515
  for vpe in self.reference_vpes:
1460
1516
  vpes_with_source.append((vpe, "Reference"))
1461
1517
 
1462
- # 3. Check if there is anything to visualize
1463
1518
  if not vpes_with_source:
1464
1519
  QMessageBox.warning(
1465
1520
  self,
1466
1521
  "No VPEs Available",
1467
- "No VPEs available to visualize. Please load a VPE file or generate VPEs from references first."
1522
+ "No VPEs available to visualize. Please load or generate VPEs first."
1468
1523
  )
1469
1524
  return
1470
1525
 
1471
- # 4. Create the visualization dialog
1472
- all_vpe_tensors = [vpe for vpe, source in vpes_with_source]
1473
- averaged_vpe = torch.cat(all_vpe_tensors).mean(dim=0, keepdim=True)
1474
- final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
1475
-
1476
- QApplication.setOverrideCursor(Qt.WaitCursor)
1526
+ raw_vpes = [vpe for vpe, source in vpes_with_source]
1527
+ num_raw = len(raw_vpes)
1477
1528
 
1478
- dialog = VPEVisualizationDialog(vpes_with_source, final_vpe, self)
1479
- QApplication.restoreOverrideCursor()
1529
+ # 2. Get K and determine if we need to cluster
1530
+ k = self.k_prototypes_spinbox.value()
1531
+ prototypes = [] # This will be the centroids if clustering is performed
1532
+ final_vpe = None
1533
+ clustering_performed = False
1534
+
1535
+ # Case 1: We want to cluster (1 <= k < num_raw)
1536
+ if 1 <= k < num_raw:
1537
+ try:
1538
+ # Prepare tensor for scikit-learn: shape (N, E)
1539
+ all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in raw_vpes], dim=0)
1540
+ vpes_np = all_vpes_tensor.cpu().numpy()
1541
+
1542
+ from sklearn.cluster import KMeans
1543
+ kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
1544
+ centroids_np = kmeans.cluster_centers_
1545
+
1546
+ centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
1547
+ for i in range(k):
1548
+ centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
1549
+ normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
1550
+ prototypes.append(normalized_centroid)
1551
+
1552
+ # The final VPE is the average of the centroids
1553
+ stacked_prototypes = torch.cat(prototypes, dim=1) # Shape: (1, k, E)
1554
+ averaged_prototype = stacked_prototypes.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
1555
+ final_vpe = torch.nn.functional.normalize(averaged_prototype, p=2, dim=-1)
1556
+ clustering_performed = True
1557
+
1558
+ except Exception as e:
1559
+ QMessageBox.critical(self,
1560
+ "Clustering Error",
1561
+ f"Could not perform clustering for visualization: {e}")
1562
+ # If clustering fails, fall back to using all raw VPEs
1563
+ prototypes = []
1564
+ clustering_performed = False
1565
+
1566
+ # Case 2: k==0 -> use all raw VPEs as prototypes (no clustering)
1567
+ if k == 0 or not clustering_performed:
1568
+ # We are not clustering, so we use all raw VPEs as prototypes
1569
+ # For visualization purposes, we'll show the raw VPEs and their average
1570
+ stacked_raw = torch.cat(raw_vpes, dim=1) # Shape: (1, num_raw, E)
1571
+ averaged_raw = stacked_raw.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
1572
+ final_vpe = torch.nn.functional.normalize(averaged_raw, p=2, dim=-1)
1573
+ # Don't set prototypes here - we'll show raw VPEs separately in the visualization
1480
1574
 
1575
+ # Case 3: k >= num_raw -> use all raw VPEs as prototypes (no clustering needed)
1576
+ elif k >= num_raw:
1577
+ # We have more requested prototypes than available VPEs, so we use all VPEs
1578
+ stacked_raw = torch.cat(raw_vpes, dim=1) # Shape: (1, num_raw, E)
1579
+ averaged_raw = stacked_raw.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
1580
+ final_vpe = torch.nn.functional.normalize(averaged_raw, p=2, dim=-1)
1581
+ # Don't set prototypes here - we'll show raw VPEs separately in the visualization
1582
+
1583
+ # 3. Create and show the visualization dialog
1584
+ QApplication.setOverrideCursor(Qt.WaitCursor)
1585
+ dialog = VPEVisualizationDialog(
1586
+ vpes_with_source,
1587
+ final_vpe,
1588
+ prototypes=prototypes,
1589
+ clustering_performed=clustering_performed,
1590
+ k_value=k,
1591
+ parent=self
1592
+ )
1593
+ QApplication.restoreOverrideCursor()
1481
1594
  dialog.exec_()
1482
-
1595
+
1483
1596
  except Exception as e:
1484
1597
  QApplication.restoreOverrideCursor()
1485
1598
  QMessageBox.critical(self, "Error Visualizing VPE", f"An error occurred: {str(e)}")
@@ -1512,44 +1625,45 @@ class DeployGeneratorDialog(QDialog):
1512
1625
 
1513
1626
  class VPEVisualizationDialog(QDialog):
1514
1627
  """
1515
- Dialog for visualizing VPE embeddings in 2D space using PCA.
1628
+ Dialog for visualizing VPE embeddings, now including K-prototypes.
1516
1629
  """
1517
- def __init__(self, vpe_list_with_source, final_vpe=None, parent=None):
1630
+ def __init__(self, vpe_list_with_source, final_vpe=None, prototypes=None,
1631
+ clustering_performed=False, k_value=0, parent=None):
1518
1632
  """
1519
- Initialize the dialog with a list of VPE tensors and their sources.
1633
+ Initialize the dialog.
1520
1634
 
1521
1635
  Args:
1522
- vpe_list_with_source (list): List of (VPE tensor, source_str) tuples
1523
- final_vpe (torch.Tensor, optional): The final (averaged) VPE
1524
- parent (QWidget, optional): Parent widget
1636
+ vpe_list_with_source (list): List of (VPE tensor, source_str) tuples for raw VPEs.
1637
+ final_vpe (torch.Tensor, optional): The final (averaged) VPE.
1638
+ prototypes (list, optional): List of K-prototype VPE tensors (cluster centroids).
1639
+ clustering_performed (bool): Whether clustering was performed.
1640
+ k_value (int): The K value used for clustering.
1641
+ parent (QWidget, optional): Parent widget.
1525
1642
  """
1526
1643
  super().__init__(parent)
1527
1644
  self.setWindowTitle("VPE Visualization")
1528
1645
  self.resize(1000, 1000)
1529
-
1530
- # Add a maximize button to the dialog's title bar
1531
1646
  self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)
1532
1647
 
1533
- # Store the VPEs and their sources
1648
+ # Store the VPEs and clustering info
1534
1649
  self.vpe_list_with_source = vpe_list_with_source
1535
1650
  self.final_vpe = final_vpe
1651
+ self.prototypes = prototypes if prototypes else []
1652
+ self.clustering_performed = clustering_performed
1653
+ self.k_value = k_value
1536
1654
 
1537
1655
  # Create the layout
1538
1656
  layout = QVBoxLayout(self)
1539
1657
 
1540
1658
  # Create the plot widget
1541
1659
  self.plot_widget = pg.PlotWidget()
1542
- self.plot_widget.setBackground('w') # White background
1660
+ self.plot_widget.setBackground('w')
1543
1661
  self.plot_widget.setTitle("PCA Visualization of Visual Prompt Embeddings", color="#000000", size="10pt")
1544
1662
  self.plot_widget.showGrid(x=True, y=True, alpha=0.3)
1545
-
1546
- # Add the plot widget to the layout
1547
1663
  layout.addWidget(self.plot_widget)
1548
-
1549
- # Add spacing between plot_widget and info_label
1550
1664
  layout.addSpacing(20)
1551
1665
 
1552
- # Add information label at the bottom
1666
+ # Add information label
1553
1667
  self.info_label = QLabel()
1554
1668
  self.info_label.setAlignment(Qt.AlignCenter)
1555
1669
  layout.addWidget(self.info_label)
@@ -1564,101 +1678,129 @@ class VPEVisualizationDialog(QDialog):
1564
1678
 
1565
1679
  def visualize_vpes(self):
1566
1680
  """
1567
- Apply PCA to the VPE tensors and visualize them in 2D space.
1681
+ Apply PCA to all VPEs (raw, prototypes, final) and visualize them.
1568
1682
  """
1569
1683
  if not self.vpe_list_with_source:
1570
1684
  self.info_label.setText("No VPEs available to visualize.")
1571
1685
  return
1686
+
1687
+ # 1. Collect all numpy arrays for PCA transformation
1688
+ raw_vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, source in self.vpe_list_with_source]
1689
+ prototype_arrays = [p.detach().cpu().numpy().squeeze() for p in self.prototypes]
1572
1690
 
1573
- # Convert tensors to numpy arrays for PCA, separating them from the source string
1574
- vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, source in self.vpe_list_with_source]
1691
+ all_arrays_for_pca = raw_vpe_arrays + prototype_arrays
1575
1692
 
1576
- # If final VPE is provided, add it to the arrays
1577
1693
  final_vpe_array = None
1578
1694
  if self.final_vpe is not None:
1579
1695
  final_vpe_array = self.final_vpe.detach().cpu().numpy().squeeze()
1580
- all_vpes = np.vstack(vpe_arrays + [final_vpe_array])
1581
- else:
1582
- all_vpes = np.vstack(vpe_arrays)
1583
-
1584
- # Apply PCA to reduce to 2 dimensions
1696
+ all_arrays_for_pca.append(final_vpe_array)
1697
+
1698
+ if len(all_arrays_for_pca) < 2:
1699
+ self.info_label.setText("At least 2 VPEs are needed for PCA visualization.")
1700
+ return
1701
+
1702
+ # 2. Apply PCA
1703
+ all_vpes_stacked = np.vstack(all_arrays_for_pca)
1585
1704
  pca = PCA(n_components=2)
1586
- vpes_2d = pca.fit_transform(all_vpes)
1705
+ vpes_2d = pca.fit_transform(all_vpes_stacked)
1587
1706
 
1588
- # Clear the plot
1707
+ # 3. Plot the results
1589
1708
  self.plot_widget.clear()
1590
-
1591
- # Generate random colors for individual VPEs
1592
- num_vpes = len(vpe_arrays)
1593
- colors = self.generate_distinct_colors(num_vpes)
1594
-
1595
- # Create a legend with 3 columns to keep it compact
1596
1709
  legend = self.plot_widget.addLegend(colCount=3)
1597
1710
 
1598
- # Plot individual VPEs
1599
- for i, (vpe_tuple, vpe_2d) in enumerate(zip(self.vpe_list_with_source, vpes_2d[:num_vpes])):
1711
+ # Slicing indices
1712
+ num_raw = len(raw_vpe_arrays)
1713
+ num_prototypes = len(prototype_arrays)
1714
+
1715
+ # Determine if each raw VPE is effectively a prototype (k==0 or k>=N)
1716
+ each_vpe_is_prototype = (self.k_value == 0 or self.k_value >= num_raw)
1717
+
1718
+ # Plot individual raw VPEs
1719
+ colors = self.generate_distinct_colors(num_raw)
1720
+ for i, (vpe_tuple, vpe_2d) in enumerate(zip(self.vpe_list_with_source, vpes_2d[:num_raw])):
1600
1721
  source_char = 'I' if vpe_tuple[1] == 'Import' else 'R'
1601
- color = pg.mkColor(colors[i])
1722
+
1723
+ # Use diamonds if each VPE is a prototype, circles otherwise
1724
+ symbol = 'd' if each_vpe_is_prototype else 'o'
1725
+
1726
+ # If it's a prototype, add a black border
1727
+ pen = pg.mkPen(color='k', width=1.5) if each_vpe_is_prototype else None
1728
+
1729
+ # Create label with prototype indicator if applicable
1730
+ name_suffix = " (Prototype)" if each_vpe_is_prototype else ""
1731
+ name = f"VPE {i+1} ({source_char}){name_suffix}"
1732
+
1602
1733
  scatter = pg.ScatterPlotItem(
1603
1734
  x=[vpe_2d[0]],
1604
1735
  y=[vpe_2d[1]],
1605
- brush=color,
1606
- size=15,
1607
- name=f"VPE {i+1} ({source_char})"
1736
+ brush=pg.mkColor(colors[i]),
1737
+ pen=pen,
1738
+ size=15 if not each_vpe_is_prototype else 18,
1739
+ symbol=symbol,
1740
+ name=name
1608
1741
  )
1609
1742
  self.plot_widget.addItem(scatter)
1610
1743
 
1611
- # Plot the final (averaged) VPE if available
1744
+ # Plot K-Prototypes (blue diamonds) if we have any and explicit clustering was performed
1745
+ if self.prototypes and self.clustering_performed:
1746
+ prototype_vpes_2d = vpes_2d[num_raw: num_raw + num_prototypes]
1747
+ scatter = pg.ScatterPlotItem(
1748
+ x=prototype_vpes_2d[:, 0],
1749
+ y=prototype_vpes_2d[:, 1],
1750
+ brush=pg.mkBrush(color=(0, 0, 255, 150)),
1751
+ pen=pg.mkPen(color='k', width=1.5),
1752
+ size=18,
1753
+ symbol='d',
1754
+ name=f"K-Prototypes (K={self.k_value})"
1755
+ )
1756
+ self.plot_widget.addItem(scatter)
1757
+
1758
+ # Plot the final (averaged) VPE (red star)
1612
1759
  if final_vpe_array is not None:
1613
1760
  final_vpe_2d = vpes_2d[-1]
1614
1761
  scatter = pg.ScatterPlotItem(
1615
1762
  x=[final_vpe_2d[0]],
1616
- y=[final_vpe_2d[1]],
1763
+ y=[final_vpe_2d[1]],
1617
1764
  brush=pg.mkBrush(color='r'),
1618
- size=20,
1619
- symbol='star',
1620
- name="Final VPE"
1765
+ size=20,
1766
+ symbol='star',
1767
+ name="Final VPE (Avg)"
1621
1768
  )
1622
1769
  self.plot_widget.addItem(scatter)
1623
1770
 
1624
- # Update the information label
1771
+ # 4. Update the information label
1625
1772
  orig_dim = self.vpe_list_with_source[0][0].shape[-1]
1626
1773
  explained_variance = sum(pca.explained_variance_ratio_)
1627
- self.info_label.setText(
1628
- f"Original dimension: {orig_dim} → Reduced to 2D\n"
1629
- f"Total explained variance: {explained_variance:.2%}\n"
1630
- f"PC1: {pca.explained_variance_ratio_[0]:.2%} variance, "
1631
- f"PC2: {pca.explained_variance_ratio_[1]:.2%} variance"
1632
- )
1774
+
1775
+ info_text = (f"Original dimension: {orig_dim} → Reduced to 2D\n"
1776
+ f"Total explained variance: {explained_variance:.2%}\n"
1777
+ f"PC1: {pca.explained_variance_ratio_[0]:.2%} variance, "
1778
+ f"PC2: {pca.explained_variance_ratio_[1]:.2%} variance\n"
1779
+ f"Number of raw VPEs: {num_raw}\n")
1780
+
1781
+ if self.clustering_performed:
1782
+ info_text += f"Clustering performed with K={self.k_value}\n"
1783
+ info_text += f"Number of prototypes: {len(self.prototypes)}"
1784
+ else:
1785
+ if self.k_value == 0:
1786
+ info_text += f"No clustering (K=0): all {num_raw} raw VPEs used as prototypes"
1787
+ else:
1788
+ info_text += f"No clustering performed (K={self.k_value} >= {num_raw}): all raw VPEs used as prototypes"
1789
+
1790
+ self.info_label.setText(info_text)
1633
1791
 
1634
1792
  def generate_distinct_colors(self, num_colors):
1635
- """
1636
- Generate visually distinct colors by using evenly spaced hues
1637
- with random saturation and value.
1638
-
1639
- Args:
1640
- num_colors (int): Number of colors to generate
1641
-
1642
- Returns:
1643
- list: List of color hex strings
1644
- """
1793
+ """Generates visually distinct colors."""
1645
1794
  import random
1646
1795
  from colorsys import hsv_to_rgb
1647
1796
 
1648
1797
  colors = []
1649
1798
  for i in range(num_colors):
1650
- # Use golden ratio to space hues evenly
1651
1799
  hue = (i * 0.618033988749895) % 1.0
1652
- # Random saturation between 0.6-1.0 (avoid too pale)
1653
1800
  saturation = random.uniform(0.6, 1.0)
1654
- # Random value between 0.7-1.0 (avoid too dark)
1655
1801
  value = random.uniform(0.7, 1.0)
1656
-
1657
- # Convert HSV to RGB (0-1 range)
1658
1802
  r, g, b = hsv_to_rgb(hue, saturation, value)
1659
-
1660
- # Convert RGB to hex string
1661
1803
  hex_color = f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
1662
1804
  colors.append(hex_color)
1663
-
1805
+
1664
1806
  return colors