coralnet-toolbox 0.0.75__py2.py3-none-any.whl → 0.0.77__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtPolygonAnnotation.py +57 -12
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +44 -14
- coralnet_toolbox/Common/QtGraphicsUtility.py +18 -8
- coralnet_toolbox/Explorer/transformer_models.py +13 -2
- coralnet_toolbox/IO/QtExportMaskAnnotations.py +576 -402
- coralnet_toolbox/IO/QtImportImages.py +7 -15
- coralnet_toolbox/IO/QtOpenProject.py +15 -19
- coralnet_toolbox/Icons/system_monitor.png +0 -0
- coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +33 -8
- coralnet_toolbox/QtAnnotationWindow.py +4 -0
- coralnet_toolbox/QtEventFilter.py +5 -5
- coralnet_toolbox/QtImageWindow.py +4 -0
- coralnet_toolbox/QtMainWindow.py +104 -64
- coralnet_toolbox/QtProgressBar.py +1 -0
- coralnet_toolbox/QtSystemMonitor.py +370 -0
- coralnet_toolbox/Rasters/RasterManager.py +5 -2
- coralnet_toolbox/Results/ConvertResults.py +14 -8
- coralnet_toolbox/Results/ResultsProcessor.py +3 -2
- coralnet_toolbox/SAM/QtDeployGenerator.py +1 -1
- coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +324 -177
- coralnet_toolbox/SeeAnything/QtDeployPredictor.py +10 -6
- coralnet_toolbox/Tile/QtTileBatchInference.py +4 -4
- coralnet_toolbox/Tools/QtPatchTool.py +6 -2
- coralnet_toolbox/Tools/QtPolygonTool.py +5 -3
- coralnet_toolbox/Tools/QtRectangleTool.py +17 -9
- coralnet_toolbox/Tools/QtSAMTool.py +144 -91
- coralnet_toolbox/Tools/QtSeeAnythingTool.py +4 -0
- coralnet_toolbox/Tools/QtTool.py +79 -3
- coralnet_toolbox/Tools/QtWorkAreaTool.py +4 -0
- coralnet_toolbox/Transformers/Models/GroundingDINO.py +72 -0
- coralnet_toolbox/Transformers/Models/OWLViT.py +72 -0
- coralnet_toolbox/Transformers/Models/OmDetTurbo.py +68 -0
- coralnet_toolbox/Transformers/Models/QtBase.py +121 -0
- coralnet_toolbox/{AutoDistill → Transformers}/Models/__init__.py +1 -1
- coralnet_toolbox/{AutoDistill → Transformers}/QtBatchInference.py +15 -15
- coralnet_toolbox/{AutoDistill → Transformers}/QtDeployModel.py +18 -16
- coralnet_toolbox/{AutoDistill → Transformers}/__init__.py +1 -1
- coralnet_toolbox/__init__.py +1 -1
- coralnet_toolbox/utilities.py +0 -15
- {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/METADATA +9 -9
- {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/RECORD +46 -44
- coralnet_toolbox/AutoDistill/Models/GroundingDINO.py +0 -81
- coralnet_toolbox/AutoDistill/Models/OWLViT.py +0 -76
- coralnet_toolbox/AutoDistill/Models/OmDetTurbo.py +0 -75
- coralnet_toolbox/AutoDistill/Models/QtBase.py +0 -112
- {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.75.dist-info → coralnet_toolbox-0.0.77.dist-info}/top_level.txt +0 -0
@@ -10,10 +10,10 @@ import torch
|
|
10
10
|
from torch.cuda import empty_cache
|
11
11
|
|
12
12
|
import pyqtgraph as pg
|
13
|
-
from pyqtgraph.Qt import QtGui
|
14
13
|
|
15
14
|
from ultralytics import YOLOE
|
16
15
|
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
|
16
|
+
from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
|
17
17
|
|
18
18
|
from PyQt5.QtCore import Qt
|
19
19
|
from PyQt5.QtWidgets import (QMessageBox, QVBoxLayout, QApplication, QFileDialog,
|
@@ -26,7 +26,6 @@ from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotati
|
|
26
26
|
|
27
27
|
from coralnet_toolbox.Results import ResultsProcessor
|
28
28
|
from coralnet_toolbox.Results import MapResults
|
29
|
-
from coralnet_toolbox.Results import CombineResults
|
30
29
|
|
31
30
|
from coralnet_toolbox.QtProgressBar import ProgressBar
|
32
31
|
from coralnet_toolbox.QtImageWindow import ImageWindow
|
@@ -90,7 +89,7 @@ class DeployGeneratorDialog(QDialog):
|
|
90
89
|
self.imported_vpes = [] # VPEs loaded from file
|
91
90
|
self.reference_vpes = [] # VPEs created from reference images
|
92
91
|
|
93
|
-
self.device =
|
92
|
+
self.device = None # Will be set in showEvent
|
94
93
|
|
95
94
|
# Main vertical layout for the dialog
|
96
95
|
self.layout = QVBoxLayout(self)
|
@@ -195,6 +194,8 @@ class DeployGeneratorDialog(QDialog):
|
|
195
194
|
self.initialize_iou_threshold()
|
196
195
|
self.initialize_area_threshold()
|
197
196
|
|
197
|
+
# Update the device
|
198
|
+
self.device = self.main_window.device
|
198
199
|
# Configure the image window's UI elements for this specific dialog
|
199
200
|
self.configure_image_window_for_dialog()
|
200
201
|
# Sync with main window's images BEFORE updating labels
|
@@ -537,8 +538,8 @@ class DeployGeneratorDialog(QDialog):
|
|
537
538
|
|
538
539
|
def setup_reference_layout(self):
|
539
540
|
"""
|
540
|
-
Set up the layout
|
541
|
-
|
541
|
+
Set up the layout for reference selection, including the output label,
|
542
|
+
reference method, and the number of prototype clusters (K).
|
542
543
|
"""
|
543
544
|
group_box = QGroupBox("Reference")
|
544
545
|
layout = QFormLayout()
|
@@ -553,8 +554,18 @@ class DeployGeneratorDialog(QDialog):
|
|
553
554
|
self.reference_method_combo_box.addItems(["VPE", "Images"])
|
554
555
|
layout.addRow("Reference Method:", self.reference_method_combo_box)
|
555
556
|
|
557
|
+
# Add a spinbox for the user to define the number of prototypes (K)
|
558
|
+
self.k_prototypes_spinbox = QSpinBox()
|
559
|
+
self.k_prototypes_spinbox.setRange(0, 1000)
|
560
|
+
self.k_prototypes_spinbox.setValue(0)
|
561
|
+
self.k_prototypes_spinbox.setToolTip(
|
562
|
+
"Set the number of prototype clusters (K) to generate from references.\n"
|
563
|
+
"Set to 0 to treat every unique reference image/VPE as its own prototype (K=N)."
|
564
|
+
)
|
565
|
+
layout.addRow("Number of Prototypes (K):", self.k_prototypes_spinbox)
|
566
|
+
|
556
567
|
group_box.setLayout(layout)
|
557
|
-
self.right_panel.addWidget(group_box)
|
568
|
+
self.right_panel.addWidget(group_box)
|
558
569
|
|
559
570
|
def setup_buttons_layout(self):
|
560
571
|
"""
|
@@ -767,20 +778,20 @@ class DeployGeneratorDialog(QDialog):
|
|
767
778
|
try:
|
768
779
|
# Load the VPE file
|
769
780
|
loaded_data = torch.load(file_path)
|
770
|
-
|
771
|
-
#
|
772
|
-
|
781
|
+
|
782
|
+
# Move tensors to the appropriate device
|
783
|
+
device = self.main_window.device
|
773
784
|
|
774
785
|
# Check format type and handle appropriately
|
775
786
|
if isinstance(loaded_data, list):
|
776
787
|
# New format: list of VPE tensors
|
777
|
-
self.imported_vpes = [vpe.to(
|
788
|
+
self.imported_vpes = [vpe.to(device) for vpe in loaded_data]
|
778
789
|
vpe_count = len(self.imported_vpes)
|
779
790
|
self.status_bar.setText(f"Loaded {vpe_count} VPE tensors from file")
|
780
791
|
|
781
792
|
elif isinstance(loaded_data, torch.Tensor):
|
782
793
|
# Legacy format: single tensor - convert to list for consistency
|
783
|
-
loaded_vpe = loaded_data.to(
|
794
|
+
loaded_vpe = loaded_data.to(device)
|
784
795
|
# Store as a single-item list
|
785
796
|
self.imported_vpes = [loaded_vpe]
|
786
797
|
self.status_bar.setText("Loaded 1 VPE tensor from file (legacy format)")
|
@@ -948,7 +959,7 @@ class DeployGeneratorDialog(QDialog):
|
|
948
959
|
self.model_path = self.model_combo.currentText()
|
949
960
|
|
950
961
|
# Load model using registry
|
951
|
-
self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device)
|
962
|
+
self.loaded_model = YOLOE(self.model_path, verbose=False).to(self.device)
|
952
963
|
|
953
964
|
# Create a dummy visual dictionary for standard model loading
|
954
965
|
visual_prompts = dict(
|
@@ -966,7 +977,7 @@ class DeployGeneratorDialog(QDialog):
|
|
966
977
|
self.loaded_model.predict(
|
967
978
|
np.zeros((640, 640, 3), dtype=np.uint8),
|
968
979
|
visual_prompts=visual_prompts.copy(), # This needs to happen to properly initialize the predictor
|
969
|
-
predictor=
|
980
|
+
predictor=YOLOEVPDetectPredictor if self.task == "detect" else YOLOEVPSegPredictor,
|
970
981
|
imgsz=640,
|
971
982
|
conf=0.99,
|
972
983
|
)
|
@@ -1086,9 +1097,9 @@ class DeployGeneratorDialog(QDialog):
|
|
1086
1097
|
|
1087
1098
|
def _apply_model_using_images(self, inputs, reference_dict):
|
1088
1099
|
"""
|
1089
|
-
Apply the model using
|
1090
|
-
|
1091
|
-
|
1100
|
+
Apply the model using reference images. This method now generates VPEs
|
1101
|
+
from the reference annotations, clusters them into K prototypes, and
|
1102
|
+
runs a single, efficient multi-class prediction.
|
1092
1103
|
|
1093
1104
|
Args:
|
1094
1105
|
inputs (list): List of input images.
|
@@ -1097,64 +1108,70 @@ class DeployGeneratorDialog(QDialog):
|
|
1097
1108
|
Returns:
|
1098
1109
|
list: List of prediction results.
|
1099
1110
|
"""
|
1100
|
-
#
|
1101
|
-
QApplication.setOverrideCursor(Qt.WaitCursor)
|
1102
|
-
progress_bar = ProgressBar(self.annotation_window, title="Making Predictions per Reference")
|
1103
|
-
progress_bar.show()
|
1104
|
-
progress_bar.start_progress(len(reference_dict))
|
1111
|
+
# Note: This method may require 'from sklearn.cluster import KMeans'
|
1105
1112
|
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
input_image = inputs[0]
|
1113
|
+
# 1. Reload model and generate initial VPEs from the reference images
|
1114
|
+
self.reload_model()
|
1115
|
+
initial_vpes = self.references_to_vpe(reference_dict, update_reference_vpes=False)
|
1110
1116
|
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
'bboxes': ref_annotations['bboxes'],
|
1117
|
-
'cls': ref_annotations['cls'],
|
1118
|
-
}
|
1119
|
-
if self.task == 'segment':
|
1120
|
-
visual_prompts['masks'] = ref_annotations['masks']
|
1121
|
-
|
1122
|
-
# Make predictions on the target using the current reference
|
1123
|
-
results = self.loaded_model.predict(input_image,
|
1124
|
-
refer_image=ref_path,
|
1125
|
-
visual_prompts=visual_prompts,
|
1126
|
-
predictor=YOLOEVPSegPredictor, # TODO This is necessary here?
|
1127
|
-
imgsz=self.imgsz_spinbox.value(),
|
1128
|
-
conf=self.main_window.get_uncertainty_thresh(),
|
1129
|
-
iou=self.main_window.get_iou_thresh(),
|
1130
|
-
max_det=self.get_max_detections(),
|
1131
|
-
retina_masks=self.task == "segment")
|
1132
|
-
|
1133
|
-
if not len(results[0].boxes):
|
1134
|
-
# If no boxes were detected, skip to the next reference
|
1135
|
-
progress_bar.update_progress()
|
1136
|
-
continue
|
1137
|
-
|
1138
|
-
# Update the name of the results and append to the list
|
1139
|
-
results[0].names = {0: self.class_mapping[0].short_label_code}
|
1140
|
-
results_list.extend(results[0])
|
1141
|
-
|
1142
|
-
progress_bar.update_progress()
|
1143
|
-
gc.collect()
|
1144
|
-
empty_cache()
|
1117
|
+
if not initial_vpes:
|
1118
|
+
QMessageBox.warning(self,
|
1119
|
+
"VPE Generation Failed",
|
1120
|
+
"Could not generate VPEs from the selected reference images.")
|
1121
|
+
return []
|
1145
1122
|
|
1146
|
-
#
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1123
|
+
# 2. Generate K prototypes from the N generated VPEs
|
1124
|
+
k = self.k_prototypes_spinbox.value()
|
1125
|
+
num_available_vpes = len(initial_vpes)
|
1126
|
+
prototype_vpes = []
|
1127
|
+
|
1128
|
+
# If K is 0 (use all) or K >= N, no clustering is needed.
|
1129
|
+
if k == 0 or k >= num_available_vpes:
|
1130
|
+
prototype_vpes = initial_vpes
|
1131
|
+
else:
|
1132
|
+
# Perform K-Means clustering to find K prototypes
|
1133
|
+
try:
|
1134
|
+
# Prepare tensor for scikit-learn: shape (N, E)
|
1135
|
+
all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in initial_vpes], dim=0)
|
1136
|
+
vpes_np = all_vpes_tensor.cpu().numpy()
|
1137
|
+
|
1138
|
+
from sklearn.cluster import KMeans
|
1139
|
+
kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
|
1140
|
+
centroids_np = kmeans.cluster_centers_
|
1141
|
+
|
1142
|
+
# Convert centroids back to a list of normalized PyTorch tensors
|
1143
|
+
centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
|
1144
|
+
for i in range(k):
|
1145
|
+
# Reshape centroid to (1, 1, E)
|
1146
|
+
centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
|
1147
|
+
normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
|
1148
|
+
prototype_vpes.append(normalized_centroid)
|
1149
|
+
except Exception as e:
|
1150
|
+
QMessageBox.critical(self, "Clustering Error", f"Failed to perform K-Means clustering: {e}")
|
1151
|
+
return []
|
1152
|
+
|
1153
|
+
# 3. Configure the model with the K prototypes
|
1154
|
+
if not prototype_vpes:
|
1155
|
+
QMessageBox.warning(self, "Prototype Error", "Could not generate any prototypes for prediction.")
|
1155
1156
|
return []
|
1156
|
-
|
1157
|
-
|
1157
|
+
|
1158
|
+
num_prototypes = len(prototype_vpes)
|
1159
|
+
proto_class_names = [f"object{i}" for i in range(num_prototypes)]
|
1160
|
+
stacked_vpes = torch.cat(prototype_vpes, dim=1) # Shape: (1, K, E)
|
1161
|
+
|
1162
|
+
self.loaded_model.is_fused = lambda: False
|
1163
|
+
self.loaded_model.set_classes(proto_class_names, stacked_vpes)
|
1164
|
+
|
1165
|
+
# 4. Make a single prediction on the target using the K prototypes
|
1166
|
+
results = self.loaded_model.predict(inputs[0],
|
1167
|
+
visual_prompts=[],
|
1168
|
+
imgsz=self.imgsz_spinbox.value(),
|
1169
|
+
conf=self.main_window.get_uncertainty_thresh(),
|
1170
|
+
iou=self.main_window.get_iou_thresh(),
|
1171
|
+
max_det=self.get_max_detections(),
|
1172
|
+
retina_masks=self.task == "segment")
|
1173
|
+
|
1174
|
+
return [results]
|
1158
1175
|
|
1159
1176
|
def generate_vpes_from_references(self):
|
1160
1177
|
"""
|
@@ -1252,8 +1269,8 @@ class DeployGeneratorDialog(QDialog):
|
|
1252
1269
|
|
1253
1270
|
def _apply_model_using_vpe(self, inputs):
|
1254
1271
|
"""
|
1255
|
-
Apply the model
|
1256
|
-
|
1272
|
+
Apply the model using VPEs. This method now supports clustering N VPEs
|
1273
|
+
into K prototypes before running a single multi-class prediction.
|
1257
1274
|
|
1258
1275
|
Args:
|
1259
1276
|
inputs (list): List of input images.
|
@@ -1261,21 +1278,18 @@ class DeployGeneratorDialog(QDialog):
|
|
1261
1278
|
Returns:
|
1262
1279
|
list: List of prediction results.
|
1263
1280
|
"""
|
1281
|
+
# Note: This method may require 'from sklearn.cluster import KMeans'
|
1282
|
+
|
1264
1283
|
# First reload the model to clear any cached data
|
1265
1284
|
self.reload_model()
|
1266
1285
|
|
1267
|
-
#
|
1286
|
+
# 1. Gather all available VPEs from imported files and generated references
|
1268
1287
|
combined_vpes = []
|
1269
|
-
|
1270
|
-
# Add imported VPEs if available
|
1271
1288
|
if self.imported_vpes:
|
1272
1289
|
combined_vpes.extend(self.imported_vpes)
|
1273
|
-
|
1274
|
-
# Add pre-generated reference VPEs if available
|
1275
1290
|
if self.reference_vpes:
|
1276
1291
|
combined_vpes.extend(self.reference_vpes)
|
1277
1292
|
|
1278
|
-
# Check if we have any VPEs to use
|
1279
1293
|
if not combined_vpes:
|
1280
1294
|
QMessageBox.warning(
|
1281
1295
|
self,
|
@@ -1285,18 +1299,57 @@ class DeployGeneratorDialog(QDialog):
|
|
1285
1299
|
)
|
1286
1300
|
return []
|
1287
1301
|
|
1288
|
-
#
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1302
|
+
# 2. Generate K prototypes from the N available VPEs
|
1303
|
+
k = self.k_prototypes_spinbox.value()
|
1304
|
+
num_available_vpes = len(combined_vpes)
|
1305
|
+
prototype_vpes = []
|
1306
|
+
|
1307
|
+
# If K is 0 (use all) or K >= N, no clustering is needed. Each VPE is a prototype.
|
1308
|
+
if k == 0 or k >= num_available_vpes:
|
1309
|
+
prototype_vpes = combined_vpes
|
1310
|
+
else:
|
1311
|
+
# Perform K-Means clustering to find K prototypes
|
1312
|
+
try:
|
1313
|
+
# Prepare tensor for scikit-learn: shape (N, E)
|
1314
|
+
all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in combined_vpes], dim=0)
|
1315
|
+
vpes_np = all_vpes_tensor.cpu().numpy()
|
1316
|
+
|
1317
|
+
# Lazily import KMeans to avoid making sklearn a hard dependency if not used
|
1318
|
+
from sklearn.cluster import KMeans
|
1319
|
+
kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
|
1320
|
+
centroids_np = kmeans.cluster_centers_
|
1321
|
+
|
1322
|
+
# Convert centroids back to a list of normalized PyTorch tensors
|
1323
|
+
centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
|
1324
|
+
for i in range(k):
|
1325
|
+
# Reshape centroid to (1, 1, E) for model compatibility
|
1326
|
+
centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
|
1327
|
+
normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
|
1328
|
+
prototype_vpes.append(normalized_centroid)
|
1329
|
+
except Exception as e:
|
1330
|
+
QMessageBox.critical(self, "Clustering Error", f"Failed to perform K-Means clustering: {e}")
|
1331
|
+
return []
|
1332
|
+
|
1333
|
+
# 3. Configure the model with the K prototypes
|
1334
|
+
if not prototype_vpes:
|
1335
|
+
QMessageBox.warning(self, "Prototype Error", "Could not generate any prototypes for prediction.")
|
1336
|
+
return []
|
1337
|
+
|
1338
|
+
# For backward compatibility, set self.vpe to the average of the final prototypes
|
1339
|
+
averaged_prototype = torch.cat(prototype_vpes).mean(dim=0, keepdim=True)
|
1340
|
+
self.vpe = torch.nn.functional.normalize(averaged_prototype, p=2, dim=-1)
|
1341
|
+
|
1342
|
+
# Set up the model for multi-class detection with K proto-classes
|
1343
|
+
num_prototypes = len(prototype_vpes)
|
1344
|
+
proto_class_names = [f"object{i}" for i in range(num_prototypes)]
|
1294
1345
|
|
1295
|
-
#
|
1346
|
+
# Stack prototypes into a single tensor of shape (1, K, E) for set_classes
|
1347
|
+
stacked_vpes = torch.cat(prototype_vpes, dim=1)
|
1348
|
+
|
1296
1349
|
self.loaded_model.is_fused = lambda: False
|
1297
|
-
self.loaded_model.set_classes(
|
1350
|
+
self.loaded_model.set_classes(proto_class_names, stacked_vpes)
|
1298
1351
|
|
1299
|
-
# Make predictions on the target using the
|
1352
|
+
# 4. Make predictions on the target using the K prototypes
|
1300
1353
|
results = self.loaded_model.predict(inputs[0],
|
1301
1354
|
visual_prompts=[],
|
1302
1355
|
imgsz=self.imgsz_spinbox.value(),
|
@@ -1386,7 +1439,7 @@ class DeployGeneratorDialog(QDialog):
|
|
1386
1439
|
return updated_results
|
1387
1440
|
|
1388
1441
|
def _process_results(self, results_processor, results_list, image_path):
|
1389
|
-
"""Process the results
|
1442
|
+
"""Process the results, merging K proto-class detections into a single target class."""
|
1390
1443
|
# Get the raster object and number of work items
|
1391
1444
|
raster = self.image_window.raster_manager.get_raster(image_path)
|
1392
1445
|
total = raster.count_work_items()
|
@@ -1400,14 +1453,25 @@ class DeployGeneratorDialog(QDialog):
|
|
1400
1453
|
progress_bar.start_progress(total)
|
1401
1454
|
|
1402
1455
|
updated_results = []
|
1456
|
+
target_label_name = self.reference_label.short_label_code
|
1403
1457
|
|
1404
1458
|
for idx, results in enumerate(results_list):
|
1405
1459
|
# Each Results is a list (within the results_list, [[], ]
|
1406
|
-
if results:
|
1407
|
-
#
|
1460
|
+
if results and results[0].boxes is not None and len(results[0].boxes) > 0:
|
1461
|
+
# Clone the data tensor and set all classes to 0
|
1462
|
+
new_data = results[0].boxes.data.clone()
|
1463
|
+
new_data[:, 5] = 0 # The 6th column (index 5) is the class
|
1464
|
+
|
1465
|
+
# Create a new Boxes object of the same type
|
1466
|
+
new_boxes = type(results[0].boxes)(new_data, results[0].boxes.orig_shape)
|
1467
|
+
results[0].boxes = new_boxes
|
1468
|
+
|
1469
|
+
# Update the 'names' dictionary to map our single class ID (0)
|
1470
|
+
# to the final target label name chosen by the user.
|
1471
|
+
results[0].names = {0: target_label_name}
|
1472
|
+
|
1473
|
+
# Update path (original logic)
|
1408
1474
|
results[0].path = image_path
|
1409
|
-
results[0].names = {0: self.class_mapping[0].short_label_code}
|
1410
|
-
# This needs to be done again, in case SAM was used
|
1411
1475
|
|
1412
1476
|
# Check if the work area is valid, or the image path is being used
|
1413
1477
|
if work_areas and self.annotation_window.get_selected_tool() == "work_area":
|
@@ -1419,7 +1483,7 @@ class DeployGeneratorDialog(QDialog):
|
|
1419
1483
|
else:
|
1420
1484
|
results = results[0]
|
1421
1485
|
|
1422
|
-
# Append the result object
|
1486
|
+
# Append the result object to the updated results list
|
1423
1487
|
updated_results.append(results)
|
1424
1488
|
|
1425
1489
|
# Update the index for the next work area
|
@@ -1439,42 +1503,96 @@ class DeployGeneratorDialog(QDialog):
|
|
1439
1503
|
|
1440
1504
|
def show_vpe(self):
|
1441
1505
|
"""
|
1442
|
-
Show a visualization of the
|
1506
|
+
Show a visualization of the stored VPEs and their K-prototypes.
|
1443
1507
|
"""
|
1444
1508
|
try:
|
1509
|
+
# 1. Gather all raw VPEs from imports and references
|
1445
1510
|
vpes_with_source = []
|
1446
|
-
|
1447
|
-
# 1. Add any VPEs that were loaded from a file
|
1448
1511
|
if self.imported_vpes:
|
1449
1512
|
for vpe in self.imported_vpes:
|
1450
1513
|
vpes_with_source.append((vpe, "Import"))
|
1451
|
-
|
1452
|
-
# 2. Add any pre-generated VPEs from reference images
|
1453
1514
|
if self.reference_vpes:
|
1454
1515
|
for vpe in self.reference_vpes:
|
1455
1516
|
vpes_with_source.append((vpe, "Reference"))
|
1456
1517
|
|
1457
|
-
# 3. Check if there is anything to visualize
|
1458
1518
|
if not vpes_with_source:
|
1459
1519
|
QMessageBox.warning(
|
1460
1520
|
self,
|
1461
1521
|
"No VPEs Available",
|
1462
|
-
"No VPEs available to visualize. Please load
|
1522
|
+
"No VPEs available to visualize. Please load or generate VPEs first."
|
1463
1523
|
)
|
1464
1524
|
return
|
1465
1525
|
|
1466
|
-
|
1467
|
-
|
1468
|
-
averaged_vpe = torch.cat(all_vpe_tensors).mean(dim=0, keepdim=True)
|
1469
|
-
final_vpe = torch.nn.functional.normalize(averaged_vpe, p=2, dim=-1)
|
1470
|
-
|
1471
|
-
QApplication.setOverrideCursor(Qt.WaitCursor)
|
1526
|
+
raw_vpes = [vpe for vpe, source in vpes_with_source]
|
1527
|
+
num_raw = len(raw_vpes)
|
1472
1528
|
|
1473
|
-
|
1474
|
-
|
1529
|
+
# 2. Get K and determine if we need to cluster
|
1530
|
+
k = self.k_prototypes_spinbox.value()
|
1531
|
+
prototypes = [] # This will be the centroids if clustering is performed
|
1532
|
+
final_vpe = None
|
1533
|
+
clustering_performed = False
|
1534
|
+
|
1535
|
+
# Case 1: We want to cluster (1 <= k < num_raw)
|
1536
|
+
if 1 <= k < num_raw:
|
1537
|
+
try:
|
1538
|
+
# Prepare tensor for scikit-learn: shape (N, E)
|
1539
|
+
all_vpes_tensor = torch.cat([vpe.squeeze(1) for vpe in raw_vpes], dim=0)
|
1540
|
+
vpes_np = all_vpes_tensor.cpu().numpy()
|
1541
|
+
|
1542
|
+
from sklearn.cluster import KMeans
|
1543
|
+
kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(vpes_np)
|
1544
|
+
centroids_np = kmeans.cluster_centers_
|
1545
|
+
|
1546
|
+
centroids_tensor = torch.from_numpy(centroids_np).to(self.device)
|
1547
|
+
for i in range(k):
|
1548
|
+
centroid = centroids_tensor[i].unsqueeze(0).unsqueeze(0)
|
1549
|
+
normalized_centroid = torch.nn.functional.normalize(centroid, p=2, dim=-1)
|
1550
|
+
prototypes.append(normalized_centroid)
|
1551
|
+
|
1552
|
+
# The final VPE is the average of the centroids
|
1553
|
+
stacked_prototypes = torch.cat(prototypes, dim=1) # Shape: (1, k, E)
|
1554
|
+
averaged_prototype = stacked_prototypes.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
|
1555
|
+
final_vpe = torch.nn.functional.normalize(averaged_prototype, p=2, dim=-1)
|
1556
|
+
clustering_performed = True
|
1557
|
+
|
1558
|
+
except Exception as e:
|
1559
|
+
QMessageBox.critical(self,
|
1560
|
+
"Clustering Error",
|
1561
|
+
f"Could not perform clustering for visualization: {e}")
|
1562
|
+
# If clustering fails, fall back to using all raw VPEs
|
1563
|
+
prototypes = []
|
1564
|
+
clustering_performed = False
|
1565
|
+
|
1566
|
+
# Case 2: k==0 -> use all raw VPEs as prototypes (no clustering)
|
1567
|
+
if k == 0 or not clustering_performed:
|
1568
|
+
# We are not clustering, so we use all raw VPEs as prototypes
|
1569
|
+
# For visualization purposes, we'll show the raw VPEs and their average
|
1570
|
+
stacked_raw = torch.cat(raw_vpes, dim=1) # Shape: (1, num_raw, E)
|
1571
|
+
averaged_raw = stacked_raw.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
|
1572
|
+
final_vpe = torch.nn.functional.normalize(averaged_raw, p=2, dim=-1)
|
1573
|
+
# Don't set prototypes here - we'll show raw VPEs separately in the visualization
|
1475
1574
|
|
1575
|
+
# Case 3: k >= num_raw -> use all raw VPEs as prototypes (no clustering needed)
|
1576
|
+
elif k >= num_raw:
|
1577
|
+
# We have more requested prototypes than available VPEs, so we use all VPEs
|
1578
|
+
stacked_raw = torch.cat(raw_vpes, dim=1) # Shape: (1, num_raw, E)
|
1579
|
+
averaged_raw = stacked_raw.mean(dim=1, keepdim=True) # Shape: (1, 1, E)
|
1580
|
+
final_vpe = torch.nn.functional.normalize(averaged_raw, p=2, dim=-1)
|
1581
|
+
# Don't set prototypes here - we'll show raw VPEs separately in the visualization
|
1582
|
+
|
1583
|
+
# 3. Create and show the visualization dialog
|
1584
|
+
QApplication.setOverrideCursor(Qt.WaitCursor)
|
1585
|
+
dialog = VPEVisualizationDialog(
|
1586
|
+
vpes_with_source,
|
1587
|
+
final_vpe,
|
1588
|
+
prototypes=prototypes,
|
1589
|
+
clustering_performed=clustering_performed,
|
1590
|
+
k_value=k,
|
1591
|
+
parent=self
|
1592
|
+
)
|
1593
|
+
QApplication.restoreOverrideCursor()
|
1476
1594
|
dialog.exec_()
|
1477
|
-
|
1595
|
+
|
1478
1596
|
except Exception as e:
|
1479
1597
|
QApplication.restoreOverrideCursor()
|
1480
1598
|
QMessageBox.critical(self, "Error Visualizing VPE", f"An error occurred: {str(e)}")
|
@@ -1507,44 +1625,45 @@ class DeployGeneratorDialog(QDialog):
|
|
1507
1625
|
|
1508
1626
|
class VPEVisualizationDialog(QDialog):
|
1509
1627
|
"""
|
1510
|
-
Dialog for visualizing VPE embeddings
|
1628
|
+
Dialog for visualizing VPE embeddings, now including K-prototypes.
|
1511
1629
|
"""
|
1512
|
-
def __init__(self, vpe_list_with_source, final_vpe=None,
|
1630
|
+
def __init__(self, vpe_list_with_source, final_vpe=None, prototypes=None,
|
1631
|
+
clustering_performed=False, k_value=0, parent=None):
|
1513
1632
|
"""
|
1514
|
-
Initialize the dialog
|
1633
|
+
Initialize the dialog.
|
1515
1634
|
|
1516
1635
|
Args:
|
1517
|
-
vpe_list_with_source (list): List of (VPE tensor, source_str) tuples
|
1518
|
-
final_vpe (torch.Tensor, optional): The final (averaged) VPE
|
1519
|
-
|
1636
|
+
vpe_list_with_source (list): List of (VPE tensor, source_str) tuples for raw VPEs.
|
1637
|
+
final_vpe (torch.Tensor, optional): The final (averaged) VPE.
|
1638
|
+
prototypes (list, optional): List of K-prototype VPE tensors (cluster centroids).
|
1639
|
+
clustering_performed (bool): Whether clustering was performed.
|
1640
|
+
k_value (int): The K value used for clustering.
|
1641
|
+
parent (QWidget, optional): Parent widget.
|
1520
1642
|
"""
|
1521
1643
|
super().__init__(parent)
|
1522
1644
|
self.setWindowTitle("VPE Visualization")
|
1523
1645
|
self.resize(1000, 1000)
|
1524
|
-
|
1525
|
-
# Add a maximize button to the dialog's title bar
|
1526
1646
|
self.setWindowFlags(self.windowFlags() | Qt.WindowMaximizeButtonHint)
|
1527
1647
|
|
1528
|
-
# Store the VPEs and
|
1648
|
+
# Store the VPEs and clustering info
|
1529
1649
|
self.vpe_list_with_source = vpe_list_with_source
|
1530
1650
|
self.final_vpe = final_vpe
|
1651
|
+
self.prototypes = prototypes if prototypes else []
|
1652
|
+
self.clustering_performed = clustering_performed
|
1653
|
+
self.k_value = k_value
|
1531
1654
|
|
1532
1655
|
# Create the layout
|
1533
1656
|
layout = QVBoxLayout(self)
|
1534
1657
|
|
1535
1658
|
# Create the plot widget
|
1536
1659
|
self.plot_widget = pg.PlotWidget()
|
1537
|
-
self.plot_widget.setBackground('w')
|
1660
|
+
self.plot_widget.setBackground('w')
|
1538
1661
|
self.plot_widget.setTitle("PCA Visualization of Visual Prompt Embeddings", color="#000000", size="10pt")
|
1539
1662
|
self.plot_widget.showGrid(x=True, y=True, alpha=0.3)
|
1540
|
-
|
1541
|
-
# Add the plot widget to the layout
|
1542
1663
|
layout.addWidget(self.plot_widget)
|
1543
|
-
|
1544
|
-
# Add spacing between plot_widget and info_label
|
1545
1664
|
layout.addSpacing(20)
|
1546
1665
|
|
1547
|
-
# Add information label
|
1666
|
+
# Add information label
|
1548
1667
|
self.info_label = QLabel()
|
1549
1668
|
self.info_label.setAlignment(Qt.AlignCenter)
|
1550
1669
|
layout.addWidget(self.info_label)
|
@@ -1559,101 +1678,129 @@ class VPEVisualizationDialog(QDialog):
|
|
1559
1678
|
|
1560
1679
|
def visualize_vpes(self):
    """
    Apply PCA to all VPEs (raw, prototypes, final) and visualize them.

    The raw VPEs in ``self.vpe_list_with_source``, the tensors in
    ``self.prototypes`` and, when present, ``self.final_vpe`` are stacked and
    fitted with a single PCA. This puts every point in the same 2D coordinate
    system. Points are drawn on ``self.plot_widget`` and a textual summary is
    written to ``self.info_label``.

    If there are no raw VPEs, or fewer than two vectors in total, only an
    explanatory message is shown and nothing is plotted.
    """
    if not self.vpe_list_with_source:
        self.info_label.setText("No VPEs available to visualize.")
        return

    # 1. Collect all numpy arrays for PCA transformation
    raw_vpe_arrays = [vpe.detach().cpu().numpy().squeeze() for vpe, _source in self.vpe_list_with_source]
    prototype_arrays = [p.detach().cpu().numpy().squeeze() for p in self.prototypes]

    all_arrays_for_pca = raw_vpe_arrays + prototype_arrays

    final_vpe_array = None
    if self.final_vpe is not None:
        final_vpe_array = self.final_vpe.detach().cpu().numpy().squeeze()
        # Appended last so that vpes_2d[-1] is the final VPE's projection.
        all_arrays_for_pca.append(final_vpe_array)

    # PCA(n_components=2) needs at least two samples.
    if len(all_arrays_for_pca) < 2:
        self.info_label.setText("At least 2 VPEs are needed for PCA visualization.")
        return

    # 2. Apply PCA
    all_vpes_stacked = np.vstack(all_arrays_for_pca)
    pca = PCA(n_components=2)
    vpes_2d = pca.fit_transform(all_vpes_stacked)

    # 3. Plot the results
    self.plot_widget.clear()
    # Scatter items register themselves in the legend through their `name`,
    # so the returned legend object itself is not needed.
    self.plot_widget.addLegend(colCount=3)

    # Slicing indices into vpes_2d: [raw..., prototypes..., (final)]
    num_raw = len(raw_vpe_arrays)
    num_prototypes = len(prototype_arrays)

    # Determine if each raw VPE is effectively a prototype (k==0 or k>=N)
    each_vpe_is_prototype = (self.k_value == 0 or self.k_value >= num_raw)

    # Plot individual raw VPEs
    colors = self.generate_distinct_colors(num_raw)
    for i, (vpe_tuple, vpe_2d) in enumerate(zip(self.vpe_list_with_source, vpes_2d[:num_raw])):
        source_char = 'I' if vpe_tuple[1] == 'Import' else 'R'

        # Use diamonds with a black border if each VPE is a prototype, plain circles otherwise
        symbol = 'd' if each_vpe_is_prototype else 'o'
        pen = pg.mkPen(color='k', width=1.5) if each_vpe_is_prototype else None

        # Create label with prototype indicator if applicable
        name_suffix = " (Prototype)" if each_vpe_is_prototype else ""
        name = f"VPE {i+1} ({source_char}){name_suffix}"

        scatter = pg.ScatterPlotItem(
            x=[vpe_2d[0]],
            y=[vpe_2d[1]],
            brush=pg.mkColor(colors[i]),
            pen=pen,
            size=15 if not each_vpe_is_prototype else 18,
            symbol=symbol,
            name=name
        )
        self.plot_widget.addItem(scatter)

    # Plot K-Prototypes (blue diamonds) if we have any and explicit clustering was performed
    if self.prototypes and self.clustering_performed:
        prototype_vpes_2d = vpes_2d[num_raw: num_raw + num_prototypes]
        scatter = pg.ScatterPlotItem(
            x=prototype_vpes_2d[:, 0],
            y=prototype_vpes_2d[:, 1],
            brush=pg.mkBrush(color=(0, 0, 255, 150)),
            pen=pg.mkPen(color='k', width=1.5),
            size=18,
            symbol='d',
            name=f"K-Prototypes (K={self.k_value})"
        )
        self.plot_widget.addItem(scatter)

    # Plot the final (averaged) VPE (red star)
    if final_vpe_array is not None:
        final_vpe_2d = vpes_2d[-1]
        scatter = pg.ScatterPlotItem(
            x=[final_vpe_2d[0]],
            y=[final_vpe_2d[1]],
            brush=pg.mkBrush(color='r'),
            size=20,
            symbol='star',
            name="Final VPE (Avg)"
        )
        self.plot_widget.addItem(scatter)

    # 4. Update the information label
    orig_dim = self.vpe_list_with_source[0][0].shape[-1]
    # If every stacked vector is identical, the total variance is zero and the
    # ratios come out as NaN; show them as 0% instead of "nan%".
    variance_ratios = np.nan_to_num(pca.explained_variance_ratio_)
    explained_variance = sum(variance_ratios)

    info_text = (f"Original dimension: {orig_dim} → Reduced to 2D\n"
                 f"Total explained variance: {explained_variance:.2%}\n"
                 f"PC1: {variance_ratios[0]:.2%} variance, "
                 f"PC2: {variance_ratios[1]:.2%} variance\n"
                 f"Number of raw VPEs: {num_raw}\n")

    if self.clustering_performed:
        info_text += f"Clustering performed with K={self.k_value}\n"
        info_text += f"Number of prototypes: {len(self.prototypes)}"
    elif self.k_value == 0:
        info_text += f"No clustering (K=0): all {num_raw} raw VPEs used as prototypes"
    elif self.k_value >= num_raw:
        info_text += f"No clustering performed (K={self.k_value} >= {num_raw}): all raw VPEs used as prototypes"
    else:
        # Clustering was skipped even though K < number of raw VPEs; do not
        # claim K >= N. This matches the plot, where raw VPEs are not marked
        # as prototypes in this case.
        info_text += f"No clustering performed (K={self.k_value})"

    self.info_label.setText(info_text)
|
1628
1791
|
|
1629
1792
|
def generate_distinct_colors(self, num_colors):
    """
    Build ``num_colors`` hex color strings with well-separated hues.

    Successive hues advance by the golden-ratio conjugate (mod 1). This keeps
    neighbouring entries far apart on the color wheel. Saturation (0.6-1.0)
    and value (0.7-1.0) are drawn from the global ``random`` module, so the
    exact shades vary from call to call.

    Args:
        num_colors (int): How many colors to produce.

    Returns:
        list: Colors formatted as ``"#rrggbb"`` strings.
    """
    import random
    from colorsys import hsv_to_rgb

    golden_step = 0.618033988749895

    def _as_hex(channels):
        # Each channel is in [0, 1]; truncate to 0-255 and format as two hex digits.
        return "#" + "".join(f"{int(channel * 255):02x}" for channel in channels)

    palette = []
    for index in range(num_colors):
        hue = (index * golden_step) % 1.0
        # Keep this order: saturation is drawn before value for every color.
        sat = random.uniform(0.6, 1.0)
        val = random.uniform(0.7, 1.0)
        palette.append(_as_hex(hsv_to_rgb(hue, sat, val)))

    return palette
|