coralnet-toolbox 0.0.73__py2.py3-none-any.whl → 0.0.75__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
- coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
- coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
- coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
- coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
- coralnet_toolbox/CoralNet/QtDownload.py +2 -1
- coralnet_toolbox/Explorer/QtDataItem.py +52 -22
- coralnet_toolbox/Explorer/QtExplorer.py +293 -1614
- coralnet_toolbox/Explorer/QtSettingsWidgets.py +203 -85
- coralnet_toolbox/Explorer/QtViewers.py +1568 -0
- coralnet_toolbox/Explorer/transformer_models.py +59 -0
- coralnet_toolbox/Explorer/yolo_models.py +112 -0
- coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
- coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
- coralnet_toolbox/IO/QtOpenProject.py +46 -78
- coralnet_toolbox/IO/QtSaveProject.py +18 -43
- coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +1 -1
- coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +253 -141
- coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
- coralnet_toolbox/MachineLearning/VideoInference/YOLO3D/run.py +102 -16
- coralnet_toolbox/QtAnnotationWindow.py +16 -10
- coralnet_toolbox/QtEventFilter.py +11 -0
- coralnet_toolbox/QtImageWindow.py +120 -75
- coralnet_toolbox/QtLabelWindow.py +13 -1
- coralnet_toolbox/QtMainWindow.py +5 -27
- coralnet_toolbox/QtProgressBar.py +52 -27
- coralnet_toolbox/Rasters/RasterTableModel.py +28 -8
- coralnet_toolbox/SAM/QtDeployGenerator.py +1 -4
- coralnet_toolbox/SAM/QtDeployPredictor.py +11 -3
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +805 -162
- coralnet_toolbox/SeeAnything/QtDeployPredictor.py +130 -151
- coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
- coralnet_toolbox/Tools/QtPolygonTool.py +42 -3
- coralnet_toolbox/Tools/QtRectangleTool.py +30 -0
- coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
- coralnet_toolbox/Tools/QtSAMTool.py +72 -50
- coralnet_toolbox/Tools/QtSeeAnythingTool.py +8 -5
- coralnet_toolbox/Tools/QtSelectTool.py +27 -3
- coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
- coralnet_toolbox/Tools/__init__.py +2 -0
- coralnet_toolbox/__init__.py +1 -1
- coralnet_toolbox/utilities.py +158 -47
- coralnet_toolbox-0.0.75.dist-info/METADATA +378 -0
- {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/RECORD +49 -44
- coralnet_toolbox-0.0.73.dist-info/METADATA +0 -341
- {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.75.dist-info}/top_level.txt +0 -0
@@ -2,10 +2,10 @@ import warnings
|
|
2
2
|
|
3
3
|
import os
|
4
4
|
import gc
|
5
|
-
import ujson as json
|
6
5
|
|
7
6
|
import numpy as np
|
8
7
|
|
8
|
+
import torch
|
9
9
|
from torch.cuda import empty_cache
|
10
10
|
from ultralytics.utils import ops
|
11
11
|
|
@@ -17,7 +17,7 @@ from PyQt5.QtCore import Qt
|
|
17
17
|
from PyQt5.QtGui import QColor
|
18
18
|
from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog, QFormLayout,
|
19
19
|
QHBoxLayout, QLabel, QMessageBox, QPushButton,
|
20
|
-
QSlider, QSpinBox, QVBoxLayout, QGroupBox,
|
20
|
+
QSlider, QSpinBox, QVBoxLayout, QGroupBox,
|
21
21
|
QWidget, QLineEdit, QFileDialog)
|
22
22
|
|
23
23
|
from coralnet_toolbox.Results import ResultsProcessor
|
@@ -98,7 +98,10 @@ class DeployPredictorDialog(QDialog):
|
|
98
98
|
layout = QVBoxLayout()
|
99
99
|
|
100
100
|
# Create a QLabel with explanatory text and hyperlink
|
101
|
-
info_label = QLabel(
|
101
|
+
info_label = QLabel(
|
102
|
+
"Choose a Predictor to deploy and use interactively with the See Anything tool. "
|
103
|
+
"Optionally include a custom visual prompt encoding (VPE) file."
|
104
|
+
)
|
102
105
|
|
103
106
|
info_label.setOpenExternalLinks(True)
|
104
107
|
info_label.setWordWrap(True)
|
@@ -109,21 +112,15 @@ class DeployPredictorDialog(QDialog):
|
|
109
112
|
|
110
113
|
def setup_models_layout(self):
|
111
114
|
"""
|
112
|
-
Setup the models layout with
|
115
|
+
Setup the models layout with standard models and file selection.
|
113
116
|
"""
|
114
117
|
group_box = QGroupBox("Model Selection")
|
115
|
-
layout =
|
116
|
-
|
117
|
-
#
|
118
|
-
tab_widget = QTabWidget()
|
119
|
-
|
120
|
-
# Tab 1: Standard models
|
121
|
-
standard_tab = QWidget()
|
122
|
-
standard_layout = QVBoxLayout(standard_tab)
|
123
|
-
|
118
|
+
layout = QFormLayout()
|
119
|
+
|
120
|
+
# Model dropdown
|
124
121
|
self.model_combo = QComboBox()
|
125
122
|
self.model_combo.setEditable(True)
|
126
|
-
|
123
|
+
|
127
124
|
# Define available models
|
128
125
|
standard_models = [
|
129
126
|
'yoloe-v8s-seg.pt',
|
@@ -133,49 +130,15 @@ class DeployPredictorDialog(QDialog):
|
|
133
130
|
'yoloe-11m-seg.pt',
|
134
131
|
'yoloe-11l-seg.pt',
|
135
132
|
]
|
136
|
-
|
133
|
+
|
137
134
|
# Add all models to combo box
|
138
135
|
self.model_combo.addItems(standard_models)
|
136
|
+
|
139
137
|
# Set the default model
|
140
138
|
self.model_combo.setCurrentIndex(standard_models.index('yoloe-v8s-seg.pt'))
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
tab_widget.addTab(standard_tab, "Use Existing Model")
|
146
|
-
|
147
|
-
# Tab 2: Custom model
|
148
|
-
custom_tab = QWidget()
|
149
|
-
custom_layout = QFormLayout(custom_tab)
|
150
|
-
|
151
|
-
# Custom model file selection
|
152
|
-
self.model_path_edit = QLineEdit()
|
153
|
-
browse_button = QPushButton("Browse...")
|
154
|
-
browse_button.clicked.connect(self.browse_model_file)
|
155
|
-
|
156
|
-
model_path_layout = QHBoxLayout()
|
157
|
-
model_path_layout.addWidget(self.model_path_edit)
|
158
|
-
model_path_layout.addWidget(browse_button)
|
159
|
-
custom_layout.addRow("Custom Model:", model_path_layout)
|
160
|
-
|
161
|
-
# Class Mapping
|
162
|
-
self.mapping_edit = QLineEdit()
|
163
|
-
self.mapping_button = QPushButton("Browse...")
|
164
|
-
self.mapping_button.clicked.connect(self.browse_class_mapping_file)
|
165
|
-
|
166
|
-
class_mapping_layout = QHBoxLayout()
|
167
|
-
class_mapping_layout.addWidget(self.mapping_edit)
|
168
|
-
class_mapping_layout.addWidget(self.mapping_button)
|
169
|
-
custom_layout.addRow("Class Mapping:", class_mapping_layout)
|
170
|
-
|
171
|
-
tab_widget.addTab(custom_tab, "Custom Model")
|
172
|
-
|
173
|
-
# Add the tab widget to the main layout
|
174
|
-
layout.addWidget(tab_widget)
|
175
|
-
|
176
|
-
# Store the tab widget for later reference
|
177
|
-
self.model_tab_widget = tab_widget
|
178
|
-
|
139
|
+
# Create a layout for the model selection
|
140
|
+
layout.addRow("Models:", self.model_combo)
|
141
|
+
|
179
142
|
group_box.setLayout(layout)
|
180
143
|
self.layout.addWidget(group_box)
|
181
144
|
|
@@ -208,7 +171,7 @@ class DeployPredictorDialog(QDialog):
|
|
208
171
|
# Image size control
|
209
172
|
self.imgsz_spinbox = QSpinBox()
|
210
173
|
self.imgsz_spinbox.setRange(512, 65536)
|
211
|
-
self.imgsz_spinbox.setSingleStep(
|
174
|
+
self.imgsz_spinbox.setSingleStep(1024)
|
212
175
|
self.imgsz_spinbox.setValue(self.imgsz)
|
213
176
|
layout.addRow("Image Size (imgsz)", self.imgsz_spinbox)
|
214
177
|
|
@@ -304,37 +267,6 @@ class DeployPredictorDialog(QDialog):
|
|
304
267
|
|
305
268
|
group_box.setLayout(layout)
|
306
269
|
self.layout.addWidget(group_box)
|
307
|
-
|
308
|
-
def browse_model_file(self):
|
309
|
-
"""
|
310
|
-
Open a file dialog to browse for a model file.
|
311
|
-
"""
|
312
|
-
file_path, _ = QFileDialog.getOpenFileName(self,
|
313
|
-
"Select Model File",
|
314
|
-
"",
|
315
|
-
"Model Files (*.pt *.pth);;All Files (*)")
|
316
|
-
if file_path:
|
317
|
-
self.model_path_edit.setText(file_path)
|
318
|
-
|
319
|
-
# Load the class mapping if it exists
|
320
|
-
dir_path = os.path.dirname(os.path.dirname(file_path))
|
321
|
-
class_mapping_path = f"{dir_path}/class_mapping.json"
|
322
|
-
if os.path.exists(class_mapping_path):
|
323
|
-
self.class_mapping = json.load(open(class_mapping_path, 'r'))
|
324
|
-
self.mapping_edit.setText(class_mapping_path)
|
325
|
-
|
326
|
-
def browse_class_mapping_file(self):
|
327
|
-
"""
|
328
|
-
Browse and select a class mapping file.
|
329
|
-
"""
|
330
|
-
file_path, _ = QFileDialog.getOpenFileName(self,
|
331
|
-
"Select Class Mapping File",
|
332
|
-
"",
|
333
|
-
"JSON Files (*.json)")
|
334
|
-
if file_path:
|
335
|
-
# Load the class mapping
|
336
|
-
self.class_mapping = json.load(open(file_path, 'r'))
|
337
|
-
self.mapping_edit.setText(file_path)
|
338
270
|
|
339
271
|
def initialize_uncertainty_threshold(self):
|
340
272
|
"""Initialize the uncertainty threshold slider with the current value"""
|
@@ -390,6 +322,7 @@ class DeployPredictorDialog(QDialog):
|
|
390
322
|
def is_sam_model_deployed(self):
|
391
323
|
"""
|
392
324
|
Check if the SAM model is deployed and update the checkbox state accordingly.
|
325
|
+
If SAM is enabled for polygons, sync and disable the imgsz spinbox.
|
393
326
|
|
394
327
|
:return: Boolean indicating whether the SAM model is deployed
|
395
328
|
"""
|
@@ -402,9 +335,48 @@ class DeployPredictorDialog(QDialog):
|
|
402
335
|
self.use_sam_dropdown.setCurrentText("False")
|
403
336
|
QMessageBox.critical(self, "Error", "Please deploy the SAM model first.")
|
404
337
|
return False
|
338
|
+
|
339
|
+
# Check if SAM polygons are enabled
|
340
|
+
if self.use_sam_dropdown.currentText() == "True":
|
341
|
+
# Sync the imgsz spinbox with SAM's value
|
342
|
+
self.imgsz_spinbox.setValue(self.sam_dialog.imgsz_spinbox.value())
|
343
|
+
# Disable the spinbox
|
344
|
+
self.imgsz_spinbox.setEnabled(False)
|
345
|
+
|
346
|
+
# Connect SAM's imgsz_spinbox valueChanged signal to update our value
|
347
|
+
# First disconnect any existing connection to avoid duplicates
|
348
|
+
try:
|
349
|
+
self.sam_dialog.imgsz_spinbox.valueChanged.disconnect(self.update_from_sam_imgsz)
|
350
|
+
except TypeError:
|
351
|
+
# No connection exists yet
|
352
|
+
pass
|
353
|
+
|
354
|
+
# Connect the signal
|
355
|
+
self.sam_dialog.imgsz_spinbox.valueChanged.connect(self.update_from_sam_imgsz)
|
356
|
+
else:
|
357
|
+
# Re-enable the spinbox when SAM polygons are disabled
|
358
|
+
self.imgsz_spinbox.setEnabled(True)
|
359
|
+
|
360
|
+
# Disconnect the signal when SAM is disabled
|
361
|
+
try:
|
362
|
+
self.sam_dialog.imgsz_spinbox.valueChanged.disconnect(self.update_from_sam_imgsz)
|
363
|
+
except TypeError:
|
364
|
+
# No connection exists
|
365
|
+
pass
|
405
366
|
|
406
367
|
return True
|
407
368
|
|
369
|
+
def update_from_sam_imgsz(self, value):
|
370
|
+
"""
|
371
|
+
Update the SeeAnything image size when SAM's image size changes.
|
372
|
+
Only takes effect when SAM polygons are enabled.
|
373
|
+
|
374
|
+
Args:
|
375
|
+
value (int): The new image size value from SAM dialog
|
376
|
+
"""
|
377
|
+
if self.use_sam_dropdown.currentText() == "True":
|
378
|
+
self.imgsz_spinbox.setValue(value)
|
379
|
+
|
408
380
|
def load_model(self):
|
409
381
|
"""
|
410
382
|
Load the selected model.
|
@@ -412,46 +384,45 @@ class DeployPredictorDialog(QDialog):
|
|
412
384
|
QApplication.setOverrideCursor(Qt.WaitCursor)
|
413
385
|
progress_bar = ProgressBar(self.annotation_window, title="Loading Model")
|
414
386
|
progress_bar.show()
|
415
|
-
|
387
|
+
|
416
388
|
try:
|
417
389
|
# Get selected model path and download weights if needed
|
418
390
|
self.model_path = self.model_combo.currentText()
|
419
|
-
|
391
|
+
|
420
392
|
# Load model using registry
|
421
393
|
self.loaded_model = YOLOE(self.model_path).to(self.main_window.device)
|
422
|
-
|
423
|
-
# Create a dummy visual dictionary
|
394
|
+
|
395
|
+
# Create a dummy visual dictionary for standard model loading
|
424
396
|
visuals = dict(
|
425
397
|
bboxes=np.array(
|
426
398
|
[
|
427
|
-
[120, 425, 160, 445],
|
399
|
+
[120, 425, 160, 445], # Random box
|
428
400
|
],
|
429
401
|
),
|
430
402
|
cls=np.array(
|
431
403
|
np.zeros(1),
|
432
404
|
),
|
433
405
|
)
|
434
|
-
|
406
|
+
|
435
407
|
# Run a dummy prediction to load the model
|
436
408
|
self.loaded_model.predict(
|
437
409
|
np.zeros((640, 640, 3), dtype=np.uint8),
|
438
|
-
visual_prompts=visuals.copy(),
|
439
|
-
predictor=
|
410
|
+
visual_prompts=visuals.copy(), # This needs to happen to properly initialize the predictor
|
411
|
+
predictor=YOLOEVPSegPredictor, # This also needs to be SegPredictor, no matter what
|
440
412
|
imgsz=640,
|
441
413
|
conf=0.99,
|
442
414
|
)
|
443
|
-
|
444
|
-
# Load the model class names if available
|
445
|
-
if self.class_mapping:
|
446
|
-
self.add_labels_to_label_window()
|
447
|
-
|
415
|
+
# Finish the progress bar
|
448
416
|
progress_bar.finish_progress()
|
449
|
-
|
417
|
+
# Update the status bar
|
418
|
+
self.status_bar.setText(f"Loaded ({self.model_path}")
|
450
419
|
QMessageBox.information(self.annotation_window, "Model Loaded", "Model loaded successfully")
|
451
420
|
|
452
421
|
except Exception as e:
|
422
|
+
self.loaded_model = None
|
423
|
+
self.status_bar.setText(f"Error loading model: {self.model_path}")
|
453
424
|
QMessageBox.critical(self.annotation_window, "Error Loading Model", f"Error loading model: {e}")
|
454
|
-
|
425
|
+
|
455
426
|
finally:
|
456
427
|
# Restore cursor
|
457
428
|
QApplication.restoreOverrideCursor()
|
@@ -459,19 +430,7 @@ class DeployPredictorDialog(QDialog):
|
|
459
430
|
progress_bar.stop_progress()
|
460
431
|
progress_bar.close()
|
461
432
|
progress_bar = None
|
462
|
-
|
463
|
-
self.accept()
|
464
|
-
|
465
|
-
def add_labels_to_label_window(self):
|
466
|
-
"""
|
467
|
-
Add labels to the label window based on the class mapping.
|
468
|
-
"""
|
469
|
-
if self.class_mapping:
|
470
|
-
for label in self.class_mapping.values():
|
471
|
-
self.main_window.label_window.add_label_if_not_exists(label['short_label_code'],
|
472
|
-
label['long_label_code'],
|
473
|
-
QColor(*label['color']))
|
474
|
-
|
433
|
+
|
475
434
|
def resize_image(self, image):
|
476
435
|
"""
|
477
436
|
Resize the image to the specified size.
|
@@ -526,26 +485,11 @@ class DeployPredictorDialog(QDialog):
|
|
526
485
|
self.resized_image = self.resize_image(image)
|
527
486
|
else:
|
528
487
|
self.resized_image = image
|
529
|
-
|
530
|
-
def
|
488
|
+
|
489
|
+
def scale_prompts(self, bboxes, masks=None):
|
531
490
|
"""
|
532
|
-
|
533
|
-
|
534
|
-
Args:
|
535
|
-
bboxes (np.ndarray): The bounding boxes to use as prompts.
|
536
|
-
masks (list, optional): A list of polygons to use as prompts for segmentation.
|
537
|
-
|
538
|
-
Returns:
|
539
|
-
results (Results): Ultralytics Results object
|
491
|
+
Scale the bounding boxes and masks to the resized image.
|
540
492
|
"""
|
541
|
-
if not self.loaded_model:
|
542
|
-
QMessageBox.critical(self.annotation_window, "Model Not Loaded",
|
543
|
-
"Model not loaded, cannot make predictions")
|
544
|
-
return None
|
545
|
-
|
546
|
-
if not len(bboxes):
|
547
|
-
return None
|
548
|
-
|
549
493
|
# Update the bbox coordinates to be relative to the resized image
|
550
494
|
bboxes = np.array(bboxes)
|
551
495
|
bboxes[:, 0] = (bboxes[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
|
@@ -557,7 +501,7 @@ class DeployPredictorDialog(QDialog):
|
|
557
501
|
self.task = self.task_dropdown.currentText()
|
558
502
|
|
559
503
|
# Create a visual dictionary
|
560
|
-
|
504
|
+
visual_prompts = {
|
561
505
|
'bboxes': np.array(bboxes),
|
562
506
|
'cls': np.zeros(len(bboxes))
|
563
507
|
}
|
@@ -569,21 +513,44 @@ class DeployPredictorDialog(QDialog):
|
|
569
513
|
scaled_mask[:, 0] = (scaled_mask[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
|
570
514
|
scaled_mask[:, 1] = (scaled_mask[:, 1] / self.original_image.shape[0]) * self.resized_image.shape[0]
|
571
515
|
scaled_masks.append(scaled_mask)
|
572
|
-
|
516
|
+
visual_prompts['masks'] = scaled_masks
|
573
517
|
else: # Fallback to creating masks from bboxes if no masks are provided
|
574
518
|
fallback_masks = []
|
575
519
|
for bbox in bboxes:
|
576
520
|
x1, y1, x2, y2 = bbox
|
577
521
|
fallback_masks.append(np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]))
|
578
|
-
|
522
|
+
visual_prompts['masks'] = fallback_masks
|
523
|
+
|
524
|
+
return visual_prompts
|
579
525
|
|
580
|
-
|
526
|
+
def predict_from_prompts(self, bboxes, masks=None):
|
527
|
+
"""
|
528
|
+
Make predictions using the currently loaded model using prompts.
|
529
|
+
|
530
|
+
Args:
|
531
|
+
bboxes (np.ndarray): The bounding boxes to use as prompts.
|
532
|
+
masks (list, optional): A list of polygons to use as prompts for segmentation.
|
533
|
+
|
534
|
+
Returns:
|
535
|
+
results (Results): Ultralytics Results object
|
536
|
+
"""
|
537
|
+
if not self.loaded_model:
|
538
|
+
QMessageBox.critical(self.annotation_window,
|
539
|
+
"Model Not Loaded",
|
540
|
+
"Model not loaded, cannot make predictions")
|
541
|
+
return None
|
542
|
+
|
543
|
+
if not len(bboxes):
|
544
|
+
return None
|
545
|
+
|
546
|
+
# Get the scaled visual prompts
|
547
|
+
visual_prompts = self.scale_prompts(bboxes, masks)
|
581
548
|
|
582
549
|
try:
|
583
550
|
# Make predictions
|
584
551
|
results = self.loaded_model.predict(self.resized_image,
|
585
|
-
visual_prompts=
|
586
|
-
predictor=
|
552
|
+
visual_prompts=visual_prompts.copy(),
|
553
|
+
predictor=YOLOEVPSegPredictor,
|
587
554
|
imgsz=max(self.resized_image.shape[:2]),
|
588
555
|
conf=self.main_window.get_uncertainty_thresh(),
|
589
556
|
iou=self.main_window.get_iou_thresh(),
|
@@ -618,18 +585,30 @@ class DeployPredictorDialog(QDialog):
|
|
618
585
|
max_area_thresh=self.main_window.get_area_thresh_max()
|
619
586
|
)
|
620
587
|
|
621
|
-
#
|
622
|
-
|
623
|
-
|
624
|
-
#
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
588
|
+
# Get the scaled visual prompts
|
589
|
+
visual_prompts = self.scale_prompts(refer_bboxes, refer_masks)
|
590
|
+
|
591
|
+
# If VPEs are being used
|
592
|
+
if self.vpe is not None:
|
593
|
+
# Generate a new VPE from the current visual prompts
|
594
|
+
new_vpe = self.prompts_to_vpes(visual_prompts, self.resized_image)
|
595
|
+
|
596
|
+
if new_vpe is not None:
|
597
|
+
# If we already have a VPE, average with the existing one
|
598
|
+
if self.vpe.shape == new_vpe.shape:
|
599
|
+
self.vpe = (self.vpe + new_vpe) / 2
|
600
|
+
# Re-normalize
|
601
|
+
self.vpe = torch.nn.functional.normalize(self.vpe, p=2, dim=-1)
|
602
|
+
else:
|
603
|
+
# Replace with the new VPE if shapes don't match
|
604
|
+
self.vpe = new_vpe
|
605
|
+
|
606
|
+
# Set the updated VPE in the model
|
607
|
+
self.loaded_model.is_fused = lambda: False
|
608
|
+
self.loaded_model.set_classes(["object0"], self.vpe)
|
609
|
+
|
610
|
+
# Clear visual prompts since we're using VPE
|
611
|
+
visual_prompts = {} # this is okay with a fused model
|
633
612
|
|
634
613
|
# Create a progress bar
|
635
614
|
QApplication.setOverrideCursor(Qt.WaitCursor)
|
@@ -643,8 +622,8 @@ class DeployPredictorDialog(QDialog):
|
|
643
622
|
# Make predictions
|
644
623
|
results = self.loaded_model.predict(target_image,
|
645
624
|
refer_image=refer_image,
|
646
|
-
visual_prompts=
|
647
|
-
predictor=
|
625
|
+
visual_prompts=visual_prompts.copy(),
|
626
|
+
predictor=YOLOEVPSegPredictor,
|
648
627
|
imgsz=self.imgsz_spinbox.value(),
|
649
628
|
conf=self.main_window.get_uncertainty_thresh(),
|
650
629
|
iou=self.main_window.get_iou_thresh(),
|
@@ -87,9 +87,16 @@ class CutSubTool(SubTool):
|
|
87
87
|
self._update_cut_line_path(position)
|
88
88
|
|
89
89
|
def keyPressEvent(self, event):
|
90
|
-
"""Handle key press events for
|
91
|
-
|
90
|
+
"""Handle key press events for cutting operations."""
|
91
|
+
# Check for Ctrl+X to toggle cutting mode off
|
92
|
+
if event.modifiers() & Qt.ControlModifier and event.key() == Qt.Key_X:
|
92
93
|
self.parent_tool.deactivate_subtool()
|
94
|
+
return
|
95
|
+
|
96
|
+
# Handle Backspace to clear the current cutting line but stay in cutting mode
|
97
|
+
if event.key() == Qt.Key_Backspace:
|
98
|
+
self._clear_cutting_line()
|
99
|
+
return
|
93
100
|
|
94
101
|
def _start_drawing_cut_line(self, position):
|
95
102
|
"""Start drawing the cut line from the given position."""
|
@@ -115,6 +122,15 @@ class CutSubTool(SubTool):
|
|
115
122
|
path.lineTo(point)
|
116
123
|
self.cutting_path_item.setPath(path)
|
117
124
|
|
125
|
+
def _clear_cutting_line(self):
|
126
|
+
"""Clear the current cutting line but remain in cutting mode."""
|
127
|
+
self.cutting_points = []
|
128
|
+
self.drawing_in_progress = False
|
129
|
+
if self.cutting_path_item:
|
130
|
+
self.annotation_window.scene.removeItem(self.cutting_path_item)
|
131
|
+
self.cutting_path_item = None
|
132
|
+
self.annotation_window.scene.update()
|
133
|
+
|
118
134
|
def _break_apart_multipolygon(self):
|
119
135
|
"""Handle the special case of 'cutting' a MultiPolygonAnnotation."""
|
120
136
|
new_annotations = self.target_annotation.cut()
|
@@ -133,12 +133,51 @@ class PolygonTool(Tool):
|
|
133
133
|
return None
|
134
134
|
|
135
135
|
# Create the annotation with current points
|
136
|
-
# The polygon simplification is now handled inside the PolygonAnnotation class
|
137
136
|
if finished and len(self.points) > 2:
|
138
137
|
# Close the polygon
|
139
138
|
self.points.append(self.points[0])
|
140
|
-
|
141
|
-
|
139
|
+
|
140
|
+
# --- Validation for polygon size and shape ---
|
141
|
+
# Step 1: Remove duplicate or near-duplicate points
|
142
|
+
filtered_points = []
|
143
|
+
MIN_DISTANCE = 2.0 # Minimum distance between points in pixels
|
144
|
+
|
145
|
+
for i, point in enumerate(self.points):
|
146
|
+
# Skip if this point is too close to the previous one
|
147
|
+
if i > 0:
|
148
|
+
prev_point = filtered_points[-1]
|
149
|
+
distance = ((point.x() - prev_point.x())**2 + (point.y() - prev_point.y())**2)**0.5
|
150
|
+
if distance < MIN_DISTANCE:
|
151
|
+
continue
|
152
|
+
filtered_points.append(point)
|
153
|
+
|
154
|
+
# Step 2: Ensure we have enough points for a valid polygon
|
155
|
+
if len(filtered_points) < 4: # Need at least 3 + 1 closing point
|
156
|
+
# Create a small triangle/square if we don't have enough points
|
157
|
+
if len(filtered_points) > 0:
|
158
|
+
center_x = sum(p.x() for p in filtered_points) / len(filtered_points)
|
159
|
+
center_y = sum(p.y() for p in filtered_points) / len(filtered_points)
|
160
|
+
|
161
|
+
# Create a small polygon centered on the average of existing points
|
162
|
+
MIN_SIZE = 5.0
|
163
|
+
filtered_points = [
|
164
|
+
QPointF(center_x - MIN_SIZE, center_y - MIN_SIZE),
|
165
|
+
QPointF(center_x + MIN_SIZE, center_y - MIN_SIZE),
|
166
|
+
QPointF(center_x + MIN_SIZE, center_y + MIN_SIZE),
|
167
|
+
QPointF(center_x - MIN_SIZE, center_y + MIN_SIZE),
|
168
|
+
QPointF(center_x - MIN_SIZE, center_y - MIN_SIZE) # Close the polygon
|
169
|
+
]
|
170
|
+
|
171
|
+
QMessageBox.information(
|
172
|
+
self.annotation_window,
|
173
|
+
"Polygon Adjusted",
|
174
|
+
"The polygon had too few unique points and has been adjusted to a minimum size."
|
175
|
+
)
|
176
|
+
|
177
|
+
# Use the filtered points list instead of the original
|
178
|
+
self.points = filtered_points
|
179
|
+
|
180
|
+
# Create the annotation with validated points
|
142
181
|
annotation = PolygonAnnotation(self.points,
|
143
182
|
self.annotation_window.selected_label.short_label_code,
|
144
183
|
self.annotation_window.selected_label.long_label_code,
|
@@ -113,6 +113,36 @@ class RectangleTool(Tool):
|
|
113
113
|
# Ensure top_left and bottom_right are correctly calculated
|
114
114
|
top_left = QPointF(min(self.start_point.x(), end_point.x()), min(self.start_point.y(), end_point.y()))
|
115
115
|
bottom_right = QPointF(max(self.start_point.x(), end_point.x()), max(self.start_point.y(), end_point.y()))
|
116
|
+
|
117
|
+
# Calculate width and height of the rectangle
|
118
|
+
width = bottom_right.x() - top_left.x()
|
119
|
+
height = bottom_right.y() - top_left.y()
|
120
|
+
|
121
|
+
# Define minimum dimensions for a valid rectangle (e.g., 3x3 pixels)
|
122
|
+
MIN_DIMENSION = 3.0
|
123
|
+
|
124
|
+
# If rectangle is too small and we're finalizing it, enforce minimum size
|
125
|
+
if finished and (width < MIN_DIMENSION or height < MIN_DIMENSION):
|
126
|
+
if width < MIN_DIMENSION:
|
127
|
+
# Expand width while maintaining center
|
128
|
+
center_x = (top_left.x() + bottom_right.x()) / 2
|
129
|
+
top_left.setX(center_x - MIN_DIMENSION / 2)
|
130
|
+
bottom_right.setX(center_x + MIN_DIMENSION / 2)
|
131
|
+
|
132
|
+
if height < MIN_DIMENSION:
|
133
|
+
# Expand height while maintaining center
|
134
|
+
center_y = (top_left.y() + bottom_right.y()) / 2
|
135
|
+
top_left.setY(center_y - MIN_DIMENSION / 2)
|
136
|
+
bottom_right.setY(center_y + MIN_DIMENSION / 2)
|
137
|
+
|
138
|
+
# Show a message if we had to adjust a very small rectangle
|
139
|
+
if width < 1 or height < 1:
|
140
|
+
QMessageBox.information(
|
141
|
+
self.annotation_window,
|
142
|
+
"Rectangle Adjusted",
|
143
|
+
f"The rectangle was too small and has been adjusted to a minimum size of "
|
144
|
+
f"{MIN_DIMENSION}x{MIN_DIMENSION} pixels."
|
145
|
+
)
|
116
146
|
|
117
147
|
# Create the rectangle annotation
|
118
148
|
annotation = RectangleAnnotation(top_left,
|
@@ -124,5 +124,22 @@ class ResizeSubTool(SubTool):
|
|
124
124
|
}
|
125
125
|
|
126
126
|
def _get_polygon_handles(self, annotation):
|
127
|
-
"""
|
128
|
-
|
127
|
+
"""
|
128
|
+
Return resize handles for a polygon, including its outer boundary and all holes.
|
129
|
+
Uses the new handle format: 'point_{poly_index}_{vertex_index}'.
|
130
|
+
"""
|
131
|
+
handles = {}
|
132
|
+
|
133
|
+
# 1. Create handles for the outer boundary using the 'outer' keyword.
|
134
|
+
for i, p in enumerate(annotation.points):
|
135
|
+
handle_name = f"point_outer_{i}"
|
136
|
+
handles[handle_name] = QPointF(p.x(), p.y())
|
137
|
+
|
138
|
+
# 2. Create handles for each of the inner holes using their index.
|
139
|
+
if hasattr(annotation, 'holes'):
|
140
|
+
for hole_index, hole in enumerate(annotation.holes):
|
141
|
+
for vertex_index, p in enumerate(hole):
|
142
|
+
handle_name = f"point_{hole_index}_{vertex_index}"
|
143
|
+
handles[handle_name] = QPointF(p.x(), p.y())
|
144
|
+
|
145
|
+
return handles
|