coralnet-toolbox 0.0.76__py2.py3-none-any.whl → 0.0.77__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Common/QtGraphicsUtility.py +18 -8
- coralnet_toolbox/IO/QtExportMaskAnnotations.py +45 -6
- coralnet_toolbox/IO/QtImportImages.py +7 -15
- coralnet_toolbox/IO/QtOpenProject.py +15 -19
- coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +33 -8
- coralnet_toolbox/QtAnnotationWindow.py +4 -0
- coralnet_toolbox/QtEventFilter.py +1 -1
- coralnet_toolbox/QtImageWindow.py +4 -0
- coralnet_toolbox/Rasters/RasterManager.py +5 -2
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +312 -170
- coralnet_toolbox/Tools/QtPatchTool.py +6 -2
- coralnet_toolbox/Tools/QtPolygonTool.py +5 -3
- coralnet_toolbox/Tools/QtRectangleTool.py +17 -9
- coralnet_toolbox/Tools/QtSAMTool.py +4 -0
- coralnet_toolbox/Tools/QtSeeAnythingTool.py +4 -0
- coralnet_toolbox/Tools/QtTool.py +79 -3
- coralnet_toolbox/Tools/QtWorkAreaTool.py +4 -0
- coralnet_toolbox/Transformers/Models/QtBase.py +2 -1
- coralnet_toolbox/__init__.py +1 -1
- {coralnet_toolbox-0.0.76.dist-info → coralnet_toolbox-0.0.77.dist-info}/METADATA +1 -1
- {coralnet_toolbox-0.0.76.dist-info → coralnet_toolbox-0.0.77.dist-info}/RECORD +25 -25
- {coralnet_toolbox-0.0.76.dist-info → coralnet_toolbox-0.0.77.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.76.dist-info → coralnet_toolbox-0.0.77.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.76.dist-info → coralnet_toolbox-0.0.77.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.76.dist-info → coralnet_toolbox-0.0.77.dist-info}/top_level.txt +0 -0
@@ -26,7 +26,6 @@ from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotati
|
|
26
26
|
|
27
27
|
from coralnet_toolbox.Results import ResultsProcessor
|
28
28
|
from coralnet_toolbox.Results import MapResults
|
29
|
-
from coralnet_toolbox.Results import CombineResults
|
30
29
|
|
31
30
|
from coralnet_toolbox.QtProgressBar import ProgressBar
|
32
31
|
from coralnet_toolbox.QtImageWindow import ImageWindow
|
@@ -539,8 +538,8 @@ class DeployGeneratorDialog(QDialog):
|
|
539
538
|
|
540
539
|
def setup_reference_layout(self):
|
541
540
|
"""
|
542
|
-
Set up the layout
|
543
|
-
|
541
|
+
Set up the layout for reference selection, including the output label,
|
542
|
+
reference method, and the number of prototype clusters (K).
|
544
543
|
"""
|
545
544
|
group_box = QGroupBox("Reference")
|
546
545
|
layout = QFormLayout()
|
@@ -555,8 +554,18 @@ class DeployGeneratorDialog(QDialog):
|
|
555
554
|
self.reference_method_combo_box.addItems(["VPE", "Images"])
|
556
555
|
layout.addRow("Reference Method:", self.reference_method_combo_box)
|
557
556
|
|
557
|
+
# Add a spinbox for the user to define the number of prototypes (K)
|
558
|
+
self.k_prototypes_spinbox = QSpinBox()
|
559
|
+
self.k_prototypes_spinbox.setRange(0, 1000)
|
560
|
+
self.k_prototypes_spinbox.setValue(0)
|
561
|
+
self.k_prototypes_spinbox.setToolTip(
|
562
|
+
"Set the number of prototype clusters (K) to generate from references.\n"
|
563
|
+
"Set to 0 to treat every unique reference image/VPE as its own prototype (K=N)."
|
564
|
+
)
|
565
|
+
layout.addRow("Number of Prototypes (K):", self.k_prototypes_spinbox)
|
566
|
+
|
558
567
|
group_box.setLayout(layout)
|
559
|
-
self.right_panel.addWidget(group_box)
|
568
|
+
self.right_panel.addWidget(group_box)
|
560
569
|
|
561
570
|
def setup_buttons_layout(self):
|
562
571
|
"""
|
@@ -1088,9 +1097,9 @@ class DeployGeneratorDialog(QDialog):
|
|
1088
1097
|
|
1089
1098
|
def _apply_model_using_images(self, inputs, reference_dict):
|
1090
1099
|
"""
|
1091
|
-
Apply the model using
|
1092
|
-
|
1093
|
-
|
1100
|
+
Apply the model using reference images. This method now generates VPEs
|
1101
|
+
from the reference annotations, clusters them into K prototypes, and
|
1102
|
+
runs a single, efficient multi-class prediction.
|
1094
1103
|
|
1095
1104
|
Args:
|
1096
1105
|
inputs (list): List of input images.
|
@@ -1099,67 +1108,70 @@ class DeployGeneratorDialog(QDialog):
|
|
1099
1108
|
Returns:
|
1100
1109
|
list: List of prediction results.
|
1101
1110
|
"""
|
1102
|
-
#
|
1103
|
-
QApplication.setOverrideCursor(Qt.WaitCursor)
|
1104
|
-
progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
|
1105
|
-
progress_bar.show()
|
1106
|
-
progress_bar.start_progress(len(reference_dict))
|
1111
|
+
# Note: This method may require 'from sklearn.cluster import KMeans'
|
1107
1112
|
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
input_image = inputs[0]
|
1113
|
+
# 1. Reload model and generate initial VPEs from the reference images
|
1114
|
+
self.reload_model()
|
1115
|
+
initial_vpes = self.references_to_vpe(reference_dict, update_reference_vpes=False)
|
1112
1116
|
|
1113
|
-
|
1114
|
-
|
1117
|
+
if not initial_vpes:
|
1118
|
+
QMessageBox.warning(self,
|
1119
|
+
"VPE Generation Failed",
|
1120
|
+
"Could not generate VPEs from the selected reference images.")
|
1121
|
+
return []
|
1115
1122
|
|
1116
|
-
#
|
1117
|
-
|
1118
|
-
|
1119
|
-
|
1120
|
-
visual_prompts = {
|
1121
|
-
'bboxes': ref_annotations['bboxes'],
|
1122
|
-
'cls': ref_annotations['cls'],
|
1123
|
-
}
|
1124
|
-
if self.task == 'segment':
|
1125
|
-
visual_prompts['masks'] = ref_annotations['masks']
|
1126
|
-
|
1127
|
-
# Make predictions on the target using the current reference
|
1128
|
-
results = self.loaded_model.predict(input_image,
|
1129
|
-
refer_image=ref_path,
|
1130
|
-
visual_prompts=visual_prompts,
|
1131
|
-
predictor=predictor,
|
1132
|
-
imgsz=self.imgsz_spinbox.value(),
|
1133
|
-
conf=self.main_window.get_uncertainty_thresh(),
|
1134
|
-
iou=self.main_window.get_iou_thresh(),
|
1135
|
-
max_det=self.get_max_detections(),
|
1136
|
-
retina_masks=self.task == "segment")
|
1137
|
-
|
1138
|
-
if not len(results[0].boxes):
|
1139
|
-
# If no boxes were detected, skip to the next reference
|
1140
|
-
progress_bar.update_progress()
|
1141
|
-
continue
|
1142
|
-
|
1143
|
-
# Update the name of the results and append to the list
|
1144
|
-
results[0].names = {0: self.class_mapping[0].short_label_code}
|
1145
|
-
results_list.extend(results[0])
|
1146
|
-
|
1147
|
-
progress_bar.update_progress()
|
1148
|
-
gc.collect()
|
1149
|
-
empty_cache()
|
1123
|
+
# 2. Generate K prototypes from the N generated VPEs
|
1124
|
+
k = self.k_prototypes_spinbox.value()
|
1125
|
+
num_available_vpes = len(initial_vpes)
|
1126
|
+
prototype_vpes = []
|
1150
1127
|
|
1151
|
-
#
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1159
|
-
|
1128
|
+
# If K is 0 (use all) or K >= N, no clustering is needed.
|
1129
|
+
if k == 0 or k >= num_available_vpes:
|
1130
|
+
prototype_vpes = initial_vpes
|
1131
|
+
else:
|
1132
|
+
# Perform K-Means clustering to find K prototypes
|
1133
|
+
try:
|
1134
|
+
# Prepare tensor for scikit-learn: shape (N, E)
|
1135
|
+
all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in initial_vpes], dim=0)
|
1136
|
+
vpes_np = all_vpes_tensor.cpu().numpy()
|
1137
|
+
|
1138
|
+
from sklearn.cluster import KMeans
|
1139
|
+
kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
|
1140
|
+
centroids_np = kmeans.cluster_centers_
|
1141
|
+
|
1142
|
+
# Convert centroids back to a list of normalized PyTorch tensors
|
1143
|
+
centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
|
1144
|
+
for i in range(k):
|
1145
|
+
# Reshape centroid to (1, 1, E)
|
1146
|
+
centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
|
1147
|
+
normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
|
1148
|
+
prototype_vpes.append(normalized_centroid)
|
1149
|
+
except Exception as e:
|
1150
|
+
QMessageBox.critical(self, "Clustering Error", f"Failed to perform K-Means clustering: {e}")
|
1151
|
+
return []
|
1152
|
+
|
1153
|
+
# 3. Configure the model with the K prototypes
|
1154
|
+
if not prototype_vpes:
|
1155
|
+
QMessageBox.warning(self, "Prototype Error", "Could not generate any prototypes for prediction.")
|
1160
1156
|
return []
|
1161
|
-
|
1162
|
-
|
1157
|
+
|
1158
|
+
num_prototypes = len(prototype_vpes)
|
1159
|
+
proto_class_names = [f"object{i}" for i in range(num_prototypes)]
|
1160
|
+
stacked_vpes = torch.cat(prototype_vpes, dim=1) # Shape: (1, K, E)
|
1161
|
+
|
1162
|
+
self.loaded_model.is_fused = lambda: False
|
1163
|
+
self.loaded_model.set_classes(proto_class_names, stacked_vpes)
|
1164
|
+
|
1165
|
+
# 4. Make a single prediction on the target using the K prototypes
|
1166
|
+
results = self.loaded_model.predict(inputs[0],
|
1167
|
+
visual_prompts=[],
|
1168
|
+
imgsz=self.imgsz_spinbox.value(),
|
1169
|
+
conf=self.main_window.get_uncertainty_thresh(),
|
1170
|
+
iou=self.main_window.get_iou_thresh(),
|
1171
|
+
max_det=self.get_max_detections(),
|
1172
|
+
retina_masks=self.task == "segment")
|
1173
|
+
|
1174
|
+
return [results]
|
1163
1175
|
|
1164
1176
|
def generate_vpes_from_references(self):
|
1165
1177
|
"""
|
@@ -1257,8 +1269,8 @@ class DeployGeneratorDialog(QDialog):
|
|
1257
1269
|
|
1258
1270
|
def _apply_model_using_vpe(self, inputs):
|
1259
1271
|
"""
|
1260
|
-
Apply the model
|
1261
|
-
|
1272
|
+
Apply the model using VPEs. This method now supports clustering N VPEs
|
1273
|
+
into K prototypes before running a single multi-class prediction.
|
1262
1274
|
|
1263
1275
|
Args:
|
1264
1276
|
inputs (list): List of input images.
|
@@ -1266,21 +1278,18 @@ class DeployGeneratorDialog(QDialog):
|
|
1266
1278
|
Returns:
|
1267
1279
|
list: List of prediction results.
|
1268
1280
|
"""
|
1281
|
+
# Note: This method may require 'from sklearn.cluster import KMeans'
|
1282
|
+
|
1269
1283
|
# First reload the model to clear any cached data
|
1270
1284
|
self.reload_model()
|
1271
1285
|
|
1272
|
-
#
|
1286
|
+
# 1. Gather all available VPEs from imported files and generated references
|
1273
1287
|
combined_vpes = []
|
1274
|
-
|
1275
|
-
# Add imported VPEs if available
|
1276
1288
|
if self.imported_vpes:
|
1277
1289
|
combined_vpes.extend(self.imported_vpes)
|
1278
|
-
|
1279
|
-
# Add pre-generated reference VPEs if available
|
1280
1290
|
if self.reference_vpes:
|
1281
1291
|
combined_vpes.extend(self.reference_vpes)
|
1282
1292
|
|
1283
|
-
# Check if we have any VPEs to use
|
1284
1293
|
if not combined_vpes:
|
1285
1294
|
QMessageBox.warning(
|
1286
1295
|
self,
|
@@ -1290,18 +1299,57 @@ class DeployGeneratorDialog(QDialog):
|
|
1290
1299
|
)
|
1291
1300
|
return []
|
1292
1301
|
|
1293
|
-
#
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1302
|
+
# 2. Generate K prototypes from the N available VPEs
|
1303
|
+
k = self.k_prototypes_spinbox.value()
|
1304
|
+
num_available_vpes = len(combined_vpes)
|
1305
|
+
prototype_vpes = []
|
1306
|
+
|
1307
|
+
# If K is 0 (use all) or K >= N, no clustering is needed. Each VPE is a prototype.
|
1308
|
+
if k == 0 or k >= num_available_vpes:
|
1309
|
+
prototype_vpes = combined_vpes
|
1310
|
+
else:
|
1311
|
+
# Perform K-Means clustering to find K prototypes
|
1312
|
+
try:
|
1313
|
+
# Prepare tensor for scikit-learn: shape (N, E)
|
1314
|
+
all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in combined_vpes], dim=0)
|
1315
|
+
vpes_np = all_vpes_tensor.cpu().numpy()
|
1316
|
+
|
1317
|
+
# Lazily import KMeans to avoid making sklearn a hard dependency if not used
|
1318
|
+
from sklearn.cluster import KMeans
|
1319
|
+
kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
|
1320
|
+
centroids_np = kmeans.cluster_centers_
|
1321
|
+
|
1322
|
+
# Convert centroids back to a list of normalized PyTorch tensors
|
1323
|
+
centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
|
1324
|
+
for i in range(k):
|
1325
|
+
# Reshape centroid to (1, 1, E) for model compatibility
|
1326
|
+
centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
|
1327
|
+
normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
|
1328
|
+
prototype_vpes.append(normalized_centroid)
|
1329
|
+
except Exception as e:
|
1330
|
+
QMessageBox.critical(self, "Clustering Error", f"Failed to perform K-Means clustering: {e}")
|
1331
|
+
return []
|
1332
|
+
|
1333
|
+
# 3. Configure the model with the K prototypes
|
1334
|
+
if not prototype_vpes:
|
1335
|
+
QMessageBox.warning(self, "Prototype Error", "Could not generate any prototypes for prediction.")
|
1336
|
+
return []
|
1337
|
+
|
1338
|
+
# For backward compatibility, set self.vpe to the average of the final prototypes
|
1339
|
+
averaged_prototype = torch.cat(prototype_vpes).mean(dim=0, keepdim=True)
|
1340
|
+
self.vpe = torch.nn.functional.normalize(averaged_prototype, p=2, dim=-1)
|
1341
|
+
|
1342
|
+
# Set up the model for multi-class detection with K proto-classes
|
1343
|
+
num_prototypes = len(prototype_vpes)
|
1344
|
+
proto_class_names = [f"object{i}" for i in range(num_prototypes)]
|
1299
1345
|
|
1300
|
-
#
|
1346
|
+
# Stack prototypes into a single tensor of shape (1, K, E) for set_classes
|
1347
|
+
stacked_vpes = torch.cat(prototype_vpes, dim=1)
|
1348
|
+
|
1301
1349
|
self.loaded_model.is_fused = lambda: False
|
1302
|
-
self.loaded_model.set_classes(
|
1350
|
+
self.loaded_model.set_classes(proto_class_names, stacked_vpes)
|
1303
1351
|
|
1304
|
-
# Make predictions on the target using the
|
1352
|
+
# 4. Make predictions on the target using the K prototypes
|
1305
1353
|
results = self.loaded_model.predict(inputs[0],
|
1306
1354
|
visual_prompts=[],
|
1307
1355
|
imgsz=self.imgsz_spinbox.value(),
|
@@ -1391,7 +1439,7 @@ class DeployGeneratorDialog(QDialog):
|
|
1391
1439
|
return updated_results
|
1392
1440
|
|
1393
1441
|
def _process_results(self, results_processor, results_list, image_path):
|
1394
|
-
"""Process the results
|
1442
|
+
"""Process the results, merging K proto-class detections into a single target class."""
|
1395
1443
|
# Get the raster object and number of work items
|
1396
1444
|
raster = self.image_window.raster_manager.get_raster(image_path)
|
1397
1445
|
total = raster.count_work_items()
|
@@ -1405,14 +1453,25 @@ class DeployGeneratorDialog(QDialog):
|
|
1405
1453
|
progress_bar.start_progress(total)
|
1406
1454
|
|
1407
1455
|
updated_results = []
|
1456
|
+
target_label_name = self.reference_label.short_label_code
|
1408
1457
|
|
1409
1458
|
for idx, results in enumerate(results_list):
|
1410
1459
|
# Each Results is a list (within the results_list, [[], ]
|
1411
|
-
if results:
|
1412
|
-
#
|
1460
|
+
if results and results[0].boxes is not None and len(results[0].boxes) > 0:
|
1461
|
+
# Clone the data tensor and set all classes to 0
|
1462
|
+
new_data = results[0].boxes.data.clone()
|
1463
|
+
new_data[:, 5] = 0 # The 6th column (index 5) is the class
|
1464
|
+
|
1465
|
+
# Create a new Boxes object of the same type
|
1466
|
+
new_boxes = type(results[0].boxes)(new_data, results[0].boxes.orig_shape)
|
1467
|
+
results[0].boxes = new_boxes
|
1468
|
+
|
1469
|
+
# Update the 'names' dictionary to map our single class ID (0)
|
1470
|
+
# to the final target label name chosen by the user.
|
1471
|
+
results[0].names = {0: target_label_name}
|
1472
|
+
|
1473
|
+
# Update path (original logic)
|
1413
1474
|
results[0].path = image_path
|
1414
|
-
results[0].names = {0: self.class_mapping[0].short_label_code}
|
1415
|
-
# This needs to be done again, in case SAM was used
|
1416
1475
|
|
1417
1476
|
# Check if the work area is valid, or the image path is being used
|
1418
1477
|
if work_areas and self.annotation_window.get_selected_tool() == "work_area":
|
@@ -1424,7 +1483,7 @@ class DeployGeneratorDialog(QDialog):
|
|
1424
1483
|
else:
|
1425
1484
|
results = results[0]
|
1426
1485
|
|
1427
|
-
# Append the result object
|
1486
|
+
# Append the result object to the updated results list
|
1428
1487
|
updated_results.append(results)
|
1429
1488
|
|
1430
1489
|
# Update the index for the next work area
|
@@ -1444,42 +1503,96 @@ class DeployGeneratorDialog(QDialog):
|
|
1444
1503
|
|
1445
1504
|
def show_vpe(self):
|
1446
1505
|
"""
|
1447
|
-
Show a visualization of the
|
1506
|
+
Show a visualization of the stored VPEs and their K-prototypes.
|
1448
1507
|
"""
|
1449
1508
|
try:
|
1509
|
+
# 1. Gather all raw VPEs from imports and references
|
1450
1510
|
vpes_with_source = []
|
1451
|
-
|
1452
|
-
# 1. Add any VPEs that were loaded from a file
|
1453
1511
|
if self.imported_vpes:
|
1454
1512
|
for vpe in self.imported_vpes:
|
1455
1513
|
vpes_with_source.append((vpe, "Import"))
|
1456
|
-
|
1457
|
-
# 2. Add any pre-generated VPEs from reference images
|
1458
1514
|
if self.reference_vpes:
|
1459
1515
|
for vpe in self.reference_vpes:
|
1460
1516
|
vpes_with_source.append((vpe, "Reference"))
|
1461
1517
|
|
1462
|
-
# 3. Check if there is anything to visualize
|
1463
1518
|
if not vpes_with_source:
|
1464
1519
|
QMessageBox.warning(
|
1465
1520
|
self,
|
1466
1521
|
"No VPEs Available",
|
1467
|
-
"No VPEs available to visualize. Please load
|
1522
|
+
"No VPEs available to visualize. Please load or generate VPEs first."
|
1468
1523
|
)
|
1469
1524
|
return
|
1470
1525
|
|
1471
|
-
|
1472
|
-
|
1473
|
-
averaged_vpe = torch.cat(all_vpe_tensors).mean(dim=0, keepdim=True)
|
1474
|
-
final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
|
1475
|
-
|
1476
|
-
QApplication.setOverrideCursor(Qt.WaitCursor)
|
1526
|
+
raw_vpes = [vpe for vpe, source in vpes_with_source]
|
1527
|
+
num_raw = len(raw_vpes)
|
1477
1528
|
|
1478
|
-
|
1479
|
-
|
1529
|
+
# 2. Get K and determine if we need to cluster
|
1530
|
+
k = self.k_prototypes_spinbox.value()
|
1531
|
+
prototypes = [] # This will be the centroids if clustering is performed
|
1532
|
+
final_vpe = None
|
1533
|
+
clustering_performed = False
|
1534
|
+
|
1535
|
+
# Case 1: We want to cluster (1 <= k < num_raw)
|
1536
|
+
if 1 <= k < num_raw:
|
1537
|
+
try:
|
1538
|
+
# Prepare tensor for scikit-learn: shape (N, E)
|
1539
|
+
all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in raw_vpes], dim=0)
|
1540
|
+
vpes_np = all_vpes_tensor.cpu().numpy()
|
1541
|
+
|
1542
|
+
from sklearn.cluster import KMeans
|
1543
|
+
kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
|
1544
|
+
centroids_np = kmeans.cluster_centers_
|
1545
|
+
|
1546
|
+
centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
|
1547
|
+
for i in range(k):
|
1548
|
+
centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
|
1549
|
+
normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
|
1550
|
+
prototypes.append(normalized_centroid)
|
1551
|
+
|
1552
|
+
# The final VPE is the average of the centroids
|
1553
|
+
stacked_prototypes = torch.cat(prototypes, dim=1) # Shape: (1, k, E)
|
1554
|
+
averaged_prototype = stacked_prototypes.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
|
1555
|
+
final_vpe = torch.nn.functional.normalize(averaged_prototype, p=2, dim=-1)
|
1556
|
+
clustering_performed = True
|
1557
|
+
|
1558
|
+
except Exception as e:
|
1559
|
+
QMessageBox.critical(self,
|
1560
|
+
"Clustering Error",
|
1561
|
+
f"Could not perform clustering for visualization: {e}")
|
1562
|
+
# If clustering fails, fall back to using all raw VPEs
|
1563
|
+
prototypes = []
|
1564
|
+
clustering_performed = False
|
1565
|
+
|
1566
|
+
# Case 2: k==0 -> use all raw VPEs as prototypes (no clustering)
|
1567
|
+
if k == 0 or not clustering_performed:
|
1568
|
+
# We are not clustering, so we use all raw VPEs as prototypes
|
1569
|
+
# For visualization purposes, we'll show the raw VPEs and their average
|
1570
|
+
stacked_raw = torch.cat(raw_vpes, dim=1) # Shape: (1, num_raw, E)
|
1571
|
+
averaged_raw = stacked_raw.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
|
1572
|
+
final_vpe = torch.nn.functional.normalize(averaged_raw, p=2, dim=-1)
|
1573
|
+
# Don't set prototypes here - we'll show raw VPEs separately in the visualization
|
1480
1574
|
|
1575
|
+
# Case 3: k >= num_raw -> use all raw VPEs as prototypes (no clustering needed)
|
1576
|
+
elif k >= num_raw:
|
1577
|
+
# We have more requested prototypes than available VPEs, so we use all VPEs
|
1578
|
+
stacked_raw = torch.cat(raw_vpes, dim=1) # Shape: (1, num_raw, E)
|
1579
|
+
averaged_raw = stacked_raw.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
|
1580
|
+
final_vpe = torch.nn.functional.normalize(averaged_raw, p=2, dim=-1)
|
1581
|
+
# Don't set prototypes here - we'll show raw VPEs separately in the visualization
|
1582
|
+
|
1583
|
+
# 3. Create and show the visualization dialog
|
1584
|
+
QApplication.setOverrideCursor(Qt.WaitCursor)
|
1585
|
+
dialog = VPEVisualizationDialog(
|
1586
|
+
vpes_with_source,
|
1587
|
+
final_vpe,
|
1588
|
+
prototypes=prototypes,
|
1589
|
+
clustering_performed=clustering_performed,
|
1590
|
+
k_value=k,
|
1591
|
+
parent=self
|
1592
|
+
)
|
1593
|
+
QApplication.restoreOverrideCursor()
|
1481
1594
|
dialog.exec_()
|
1482
|
-
|
1595
|
+
|
1483
1596
|
except Exception as e:
|
1484
1597
|
QApplication.restoreOverrideCursor()
|
1485
1598
|
QMessageBox.critical(self, "Error Visualizing VPE", f"An error occurred: {str(e)}")
|
@@ -1512,44 +1625,45 @@ class DeployGeneratorDialog(QDialog):
|
|
1512
1625
|
|
1513
1626
|
class VPEVisualizationDialog(QDialog):
|
1514
1627
|
"""
|
1515
|
-
Dialog for visualizing VPE embeddings
|
1628
|
+
Dialog for visualizing VPE embeddings, now including K-prototypes.
|
1516
1629
|
"""
|
1517
|
-
def __init__(self, vpe_list_with_source, final_vpe=None,
|
1630
|
+
def __init__(self, vpe_list_with_source, final_vpe=None, prototypes=None,
|
1631
|
+
clustering_performed=False, k_value=0, parent=None):
|
1518
1632
|
"""
|
1519
|
-
Initialize the dialog
|
1633
|
+
Initialize the dialog.
|
1520
1634
|
|
1521
1635
|
Args:
|
1522
|
-
vpe_list_with_source (list): List of (VPE tensor, source_str) tuples
|
1523
|
-
final_vpe (torch.Tensor, optional): The final (averaged) VPE
|
1524
|
-
|
1636
|
+
vpe_list_with_source (list): List of (VPE tensor, source_str) tuples for raw VPEs.
|
1637
|
+
final_vpe (torch.Tensor, optional): The final (averaged) VPE.
|
1638
|
+
prototypes (list, optional): List of K-prototype VPE tensors (cluster centroids).
|
1639
|
+
clustering_performed (bool): Whether clustering was performed.
|
1640
|
+
k_value (int): The K value used for clustering.
|
1641
|
+
parent (QWidget, optional): Parent widget.
|
1525
1642
|
"""
|
1526
1643
|
super().__init__(parent)
|
1527
1644
|
self.setWindowTitle("VPE Visualization")
|
1528
1645
|
self.resize(1000, 1000)
|
1529
|
-
|
1530
|
-
# Add a maximize button to the dialog's title bar
|
1531
1646
|
self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)
|
1532
1647
|
|
1533
|
-
# Store the VPEs and
|
1648
|
+
# Store the VPEs and clustering info
|
1534
1649
|
self.vpe_list_with_source = vpe_list_with_source
|
1535
1650
|
self.final_vpe = final_vpe
|
1651
|
+
self.prototypes = prototypes if prototypes else []
|
1652
|
+
self.clustering_performed = clustering_performed
|
1653
|
+
self.k_value = k_value
|
1536
1654
|
|
1537
1655
|
# Create the layout
|
1538
1656
|
layout = QVBoxLayout(self)
|
1539
1657
|
|
1540
1658
|
# Create the plot widget
|
1541
1659
|
self.plot_widget = pg.PlotWidget()
|
1542
|
-
self.plot_widget.setBackground('w')
|
1660
|
+
self.plot_widget.setBackground('w')
|
1543
1661
|
self.plot_widget.setTitle("PCA Visualization of Visual Prompt Embeddings", color="#000000", size="10pt")
|
1544
1662
|
self.plot_widget.showGrid(x=True, y=True, alpha=0.3)
|
1545
|
-
|
1546
|
-
# Add the plot widget to the layout
|
1547
1663
|
layout.addWidget(self.plot_widget)
|
1548
|
-
|
1549
|
-
# Add spacing between plot_widget and info_label
|
1550
1664
|
layout.addSpacing(20)
|
1551
1665
|
|
1552
|
-
# Add information label
|
1666
|
+
# Add information label
|
1553
1667
|
self.info_label = QLabel()
|
1554
1668
|
self.info_label.setAlignment(Qt.AlignCenter)
|
1555
1669
|
layout.addWidget(self.info_label)
|
@@ -1564,101 +1678,129 @@ class VPEVisualizationDialog(QDialog):
|
|
1564
1678
|
|
1565
1679
|
def visualize_vpes(self):
|
1566
1680
|
"""
|
1567
|
-
Apply PCA to
|
1681
|
+
Apply PCA to all VPEs (raw, prototypes, final) and visualize them.
|
1568
1682
|
"""
|
1569
1683
|
if not self.vpe_list_with_source:
|
1570
1684
|
self.info_label.setText("No VPEs available to visualize.")
|
1571
1685
|
return
|
1686
|
+
|
1687
|
+
# 1. Collect all numpy arrays for PCA transformation
|
1688
|
+
raw_vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, source in self.vpe_list_with_source]
|
1689
|
+
prototype_arrays = [p.detach().cpu().numpy().squeeze() for p in self.prototypes]
|
1572
1690
|
|
1573
|
-
|
1574
|
-
vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, source in self.vpe_list_with_source]
|
1691
|
+
all_arrays_for_pca = raw_vpe_arrays + prototype_arrays
|
1575
1692
|
|
1576
|
-
# If final VPE is provided, add it to the arrays
|
1577
1693
|
final_vpe_array = None
|
1578
1694
|
if self.final_vpe is not None:
|
1579
1695
|
final_vpe_array = self.final_vpe.detach().cpu().numpy().squeeze()
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1696
|
+
all_arrays_for_pca.append(final_vpe_array)
|
1697
|
+
|
1698
|
+
if len(all_arrays_for_pca) < 2:
|
1699
|
+
self.info_label.setText("At least 2 VPEs are needed for PCA visualization.")
|
1700
|
+
return
|
1701
|
+
|
1702
|
+
# 2. Apply PCA
|
1703
|
+
all_vpes_stacked = np.vstack(all_arrays_for_pca)
|
1585
1704
|
pca = PCA(n_components=2)
|
1586
|
-
vpes_2d = pca.fit_transform(
|
1705
|
+
vpes_2d = pca.fit_transform(all_vpes_stacked)
|
1587
1706
|
|
1588
|
-
#
|
1707
|
+
# 3. Plot the results
|
1589
1708
|
self.plot_widget.clear()
|
1590
|
-
|
1591
|
-
# Generate random colors for individual VPEs
|
1592
|
-
num_vpes = len(vpe_arrays)
|
1593
|
-
colors = self.generate_distinct_colors(num_vpes)
|
1594
|
-
|
1595
|
-
# Create a legend with 3 columns to keep it compact
|
1596
1709
|
legend = self.plot_widget.addLegend(colCount=3)
|
1597
1710
|
|
1598
|
-
#
|
1599
|
-
|
1711
|
+
# Slicing indices
|
1712
|
+
num_raw = len(raw_vpe_arrays)
|
1713
|
+
num_prototypes = len(prototype_arrays)
|
1714
|
+
|
1715
|
+
# Determine if each raw VPE is effectively a prototype (k==0 or k>=N)
|
1716
|
+
each_vpe_is_prototype = (self.k_value == 0 or self.k_value >= num_raw)
|
1717
|
+
|
1718
|
+
# Plot individual raw VPEs
|
1719
|
+
colors = self.generate_distinct_colors(num_raw)
|
1720
|
+
for i, (vpe_tuple, vpe_2d) in enumerate(zip(self.vpe_list_with_source, vpes_2d[:num_raw])):
|
1600
1721
|
source_char = 'I' if vpe_tuple[1] == 'Import' else 'R'
|
1601
|
-
|
1722
|
+
|
1723
|
+
# Use diamonds if each VPE is a prototype, circles otherwise
|
1724
|
+
symbol = 'd' if each_vpe_is_prototype else 'o'
|
1725
|
+
|
1726
|
+
# If it's a prototype, add a black border
|
1727
|
+
pen = pg.mkPen(color='k', width=1.5) if each_vpe_is_prototype else None
|
1728
|
+
|
1729
|
+
# Create label with prototype indicator if applicable
|
1730
|
+
name_suffix = " (Prototype)" if each_vpe_is_prototype else ""
|
1731
|
+
name = f"VPE {i+1} ({source_char}){name_suffix}"
|
1732
|
+
|
1602
1733
|
scatter = pg.ScatterPlotItem(
|
1603
1734
|
x=[vpe_2d[0]],
|
1604
1735
|
y=[vpe_2d[1]],
|
1605
|
-
brush=
|
1606
|
-
|
1607
|
-
|
1736
|
+
brush=pg.mkColor(colors[i]),
|
1737
|
+
pen=pen,
|
1738
|
+
size=15 if not each_vpe_is_prototype else 18,
|
1739
|
+
symbol=symbol,
|
1740
|
+
name=name
|
1608
1741
|
)
|
1609
1742
|
self.plot_widget.addItem(scatter)
|
1610
1743
|
|
1611
|
-
# Plot
|
1744
|
+
# Plot K-Prototypes (blue diamonds) if we have any and explicit clustering was performed
|
1745
|
+
if self.prototypes and self.clustering_performed:
|
1746
|
+
prototype_vpes_2d = vpes_2d[num_raw: num_raw + num_prototypes]
|
1747
|
+
scatter = pg.ScatterPlotItem(
|
1748
|
+
x=prototype_vpes_2d[:, 0],
|
1749
|
+
y=prototype_vpes_2d[:, 1],
|
1750
|
+
brush=pg.mkBrush(color=(0, 0, 255, 150)),
|
1751
|
+
pen=pg.mkPen(color='k', width=1.5),
|
1752
|
+
size=18,
|
1753
|
+
symbol='d',
|
1754
|
+
name=f"K-Prototypes (K={self.k_value})"
|
1755
|
+
)
|
1756
|
+
self.plot_widget.addItem(scatter)
|
1757
|
+
|
1758
|
+
# Plot the final (averaged) VPE (red star)
|
1612
1759
|
if final_vpe_array is not None:
|
1613
1760
|
final_vpe_2d = vpes_2d[-1]
|
1614
1761
|
scatter = pg.ScatterPlotItem(
|
1615
1762
|
x=[final_vpe_2d[0]],
|
1616
|
-
y=[final_vpe_2d[1]],
|
1763
|
+
y=[final_vpe_2d[1]],
|
1617
1764
|
brush=pg.mkBrush(color='r'),
|
1618
|
-
size=20,
|
1619
|
-
symbol='star',
|
1620
|
-
name="Final VPE"
|
1765
|
+
size=20,
|
1766
|
+
symbol='star',
|
1767
|
+
name="Final VPE (Avg)"
|
1621
1768
|
)
|
1622
1769
|
self.plot_widget.addItem(scatter)
|
1623
1770
|
|
1624
|
-
# Update the information label
|
1771
|
+
# 4. Update the information label
|
1625
1772
|
orig_dim = self.vpe_list_with_source[0][0].shape[-1]
|
1626
1773
|
explained_variance = sum(pca.explained_variance_ratio_)
|
1627
|
-
|
1628
|
-
|
1629
|
-
|
1630
|
-
|
1631
|
-
|
1632
|
-
|
1774
|
+
|
1775
|
+
info_text = (f"Original dimension: {orig_dim} → Reduced to 2D\n"
|
1776
|
+
f"Total explained variance: {explained_variance:.2%}\n"
|
1777
|
+
f"PC1: {pca.explained_variance_ratio_[0]:.2%} variance, "
|
1778
|
+
f"PC2: {pca.explained_variance_ratio_[1]:.2%} variance\n"
|
1779
|
+
f"Number of raw VPEs: {num_raw}\n")
|
1780
|
+
|
1781
|
+
if self.clustering_performed:
|
1782
|
+
info_text += f"Clustering performed with K={self.k_value}\n"
|
1783
|
+
info_text += f"Number of prototypes: {len(self.prototypes)}"
|
1784
|
+
else:
|
1785
|
+
if self.k_value == 0:
|
1786
|
+
info_text += f"No clustering (K=0): all {num_raw} raw VPEs used as prototypes"
|
1787
|
+
else:
|
1788
|
+
info_text += f"No clustering performed (K={self.k_value} >= {num_raw}): all raw VPEs used as prototypes"
|
1789
|
+
|
1790
|
+
self.info_label.setText(info_text)
|
1633
1791
|
|
1634
1792
|
def generate_distinct_colors(self, num_colors):
|
1635
|
-
"""
|
1636
|
-
Generate visually distinct colors by using evenly spaced hues
|
1637
|
-
with random saturation and value.
|
1638
|
-
|
1639
|
-
Args:
|
1640
|
-
num_colors (int): Number of colors to generate
|
1641
|
-
|
1642
|
-
Returns:
|
1643
|
-
list: List of color hex strings
|
1644
|
-
"""
|
1793
|
+
"""Generates visually distinct colors."""
|
1645
1794
|
import random
|
1646
1795
|
from colorsys import hsv_to_rgb
|
1647
1796
|
|
1648
1797
|
colors = []
|
1649
1798
|
for i in range(num_colors):
|
1650
|
-
# Use golden ratio to space hues evenly
|
1651
1799
|
hue = (i * 0.618033988749895) % 1.0
|
1652
|
-
# Random saturation between 0.6-1.0 (avoid too pale)
|
1653
1800
|
saturation = random.uniform(0.6, 1.0)
|
1654
|
-
# Random value between 0.7-1.0 (avoid too dark)
|
1655
1801
|
value = random.uniform(0.7, 1.0)
|
1656
|
-
|
1657
|
-
# Convert HSV to RGB (0-1 range)
|
1658
1802
|
r, g, b = hsv_to_rgb(hue, saturation, value)
|
1659
|
-
|
1660
|
-
# Convert RGB to hex string
|
1661
1803
|
hex_color = f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
|
1662
1804
|
colors.append(hex_color)
|
1663
|
-
|
1805
|
+
|
1664
1806
|
return colors
|