coralnet-toolbox 0.0.72__py2.py3-none-any.whl → 0.0.74__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
- coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
- coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
- coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
- coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
- coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
- coralnet_toolbox/CoralNet/QtDownload.py +2 -1
- coralnet_toolbox/Explorer/QtDataItem.py +1 -1
- coralnet_toolbox/Explorer/QtExplorer.py +159 -17
- coralnet_toolbox/Explorer/QtSettingsWidgets.py +160 -86
- coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
- coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
- coralnet_toolbox/IO/QtOpenProject.py +46 -78
- coralnet_toolbox/IO/QtSaveProject.py +18 -43
- coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
- coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
- coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
- coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
- coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
- coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
- coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +42 -22
- coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
- coralnet_toolbox/QtAnnotationWindow.py +42 -14
- coralnet_toolbox/QtEventFilter.py +19 -2
- coralnet_toolbox/QtImageWindow.py +134 -86
- coralnet_toolbox/QtLabelWindow.py +14 -2
- coralnet_toolbox/QtMainWindow.py +122 -9
- coralnet_toolbox/QtProgressBar.py +52 -27
- coralnet_toolbox/Rasters/QtRaster.py +59 -7
- coralnet_toolbox/Rasters/RasterTableModel.py +42 -14
- coralnet_toolbox/SAM/QtBatchInference.py +0 -2
- coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
- coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
- coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1634 -0
- coralnet_toolbox/SeeAnything/QtDeployPredictor.py +107 -154
- coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
- coralnet_toolbox/SeeAnything/__init__.py +2 -0
- coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
- coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
- coralnet_toolbox/Tools/QtSAMTool.py +222 -57
- coralnet_toolbox/Tools/QtSeeAnythingTool.py +223 -55
- coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
- coralnet_toolbox/Tools/QtSelectTool.py +27 -3
- coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
- coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
- coralnet_toolbox/Tools/__init__.py +2 -0
- coralnet_toolbox/__init__.py +1 -1
- coralnet_toolbox/utilities.py +137 -47
- coralnet_toolbox-0.0.74.dist-info/METADATA +375 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/RECORD +56 -53
- coralnet_toolbox-0.0.72.dist-info/METADATA +0 -341
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/top_level.txt +0 -0
@@ -2,23 +2,23 @@ import warnings
|
|
2
2
|
|
3
3
|
import os
|
4
4
|
import gc
|
5
|
-
import ujson as json
|
6
5
|
|
7
6
|
import numpy as np
|
8
7
|
|
9
|
-
|
10
|
-
from
|
11
|
-
from
|
12
|
-
QHBoxLayout, QLabel, QMessageBox, QPushButton,
|
13
|
-
QSlider, QSpinBox, QVBoxLayout, QGroupBox, QTabWidget,
|
14
|
-
QWidget, QLineEdit, QFileDialog)
|
8
|
+
import torch
|
9
|
+
from torch.cuda import empty_cache
|
10
|
+
from ultralytics.utils import ops
|
15
11
|
|
16
12
|
from ultralytics import YOLOE
|
17
13
|
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
|
18
14
|
from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
|
19
15
|
|
20
|
-
from
|
21
|
-
from
|
16
|
+
from PyQt5.QtCore import Qt
|
17
|
+
from PyQt5.QtGui import QColor
|
18
|
+
from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog, QFormLayout,
|
19
|
+
QHBoxLayout, QLabel, QMessageBox, QPushButton,
|
20
|
+
QSlider, QSpinBox, QVBoxLayout, QGroupBox,
|
21
|
+
QWidget, QLineEdit, QFileDialog)
|
22
22
|
|
23
23
|
from coralnet_toolbox.Results import ResultsProcessor
|
24
24
|
|
@@ -98,7 +98,10 @@ class DeployPredictorDialog(QDialog):
|
|
98
98
|
layout = QVBoxLayout()
|
99
99
|
|
100
100
|
# Create a QLabel with explanatory text and hyperlink
|
101
|
-
info_label = QLabel(
|
101
|
+
info_label = QLabel(
|
102
|
+
"Choose a Predictor to deploy and use interactively with the See Anything tool. "
|
103
|
+
"Optionally include a custom visual prompt encoding (VPE) file."
|
104
|
+
)
|
102
105
|
|
103
106
|
info_label.setOpenExternalLinks(True)
|
104
107
|
info_label.setWordWrap(True)
|
@@ -109,21 +112,15 @@ class DeployPredictorDialog(QDialog):
|
|
109
112
|
|
110
113
|
def setup_models_layout(self):
|
111
114
|
"""
|
112
|
-
Setup the models layout with
|
115
|
+
Setup the models layout with standard models and file selection.
|
113
116
|
"""
|
114
117
|
group_box = QGroupBox("Model Selection")
|
115
|
-
layout =
|
116
|
-
|
117
|
-
#
|
118
|
-
tab_widget = QTabWidget()
|
119
|
-
|
120
|
-
# Tab 1: Standard models
|
121
|
-
standard_tab = QWidget()
|
122
|
-
standard_layout = QVBoxLayout(standard_tab)
|
123
|
-
|
118
|
+
layout = QFormLayout()
|
119
|
+
|
120
|
+
# Model dropdown
|
124
121
|
self.model_combo = QComboBox()
|
125
122
|
self.model_combo.setEditable(True)
|
126
|
-
|
123
|
+
|
127
124
|
# Define available models
|
128
125
|
standard_models = [
|
129
126
|
'yoloe-v8s-seg.pt',
|
@@ -133,83 +130,18 @@ class DeployPredictorDialog(QDialog):
|
|
133
130
|
'yoloe-11m-seg.pt',
|
134
131
|
'yoloe-11l-seg.pt',
|
135
132
|
]
|
136
|
-
|
133
|
+
|
137
134
|
# Add all models to combo box
|
138
135
|
self.model_combo.addItems(standard_models)
|
136
|
+
|
139
137
|
# Set the default model
|
140
138
|
self.model_combo.setCurrentIndex(standard_models.index('yoloe-v8s-seg.pt'))
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
tab_widget.addTab(standard_tab, "Use Existing Model")
|
146
|
-
|
147
|
-
# Tab 2: Custom model
|
148
|
-
custom_tab = QWidget()
|
149
|
-
custom_layout = QFormLayout(custom_tab)
|
150
|
-
|
151
|
-
# Custom model file selection
|
152
|
-
self.model_path_edit = QLineEdit()
|
153
|
-
browse_button = QPushButton("Browse...")
|
154
|
-
browse_button.clicked.connect(self.browse_model_file)
|
155
|
-
|
156
|
-
model_path_layout = QHBoxLayout()
|
157
|
-
model_path_layout.addWidget(self.model_path_edit)
|
158
|
-
model_path_layout.addWidget(browse_button)
|
159
|
-
custom_layout.addRow("Custom Model:", model_path_layout)
|
160
|
-
|
161
|
-
# Class Mapping
|
162
|
-
self.mapping_edit = QLineEdit()
|
163
|
-
self.mapping_button = QPushButton("Browse...")
|
164
|
-
self.mapping_button.clicked.connect(self.browse_class_mapping_file)
|
165
|
-
|
166
|
-
class_mapping_layout = QHBoxLayout()
|
167
|
-
class_mapping_layout.addWidget(self.mapping_edit)
|
168
|
-
class_mapping_layout.addWidget(self.mapping_button)
|
169
|
-
custom_layout.addRow("Class Mapping:", class_mapping_layout)
|
170
|
-
|
171
|
-
tab_widget.addTab(custom_tab, "Custom Model")
|
172
|
-
|
173
|
-
# Add the tab widget to the main layout
|
174
|
-
layout.addWidget(tab_widget)
|
175
|
-
|
176
|
-
# Store the tab widget for later reference
|
177
|
-
self.model_tab_widget = tab_widget
|
178
|
-
|
139
|
+
# Create a layout for the model selection
|
140
|
+
layout.addRow("Models:", self.model_combo)
|
141
|
+
|
179
142
|
group_box.setLayout(layout)
|
180
143
|
self.layout.addWidget(group_box)
|
181
144
|
|
182
|
-
def browse_model_file(self):
|
183
|
-
"""
|
184
|
-
Open a file dialog to browse for a model file.
|
185
|
-
"""
|
186
|
-
file_path, _ = QFileDialog.getOpenFileName(self,
|
187
|
-
"Select Model File",
|
188
|
-
"",
|
189
|
-
"Model Files (*.pt *.pth);;All Files (*)")
|
190
|
-
if file_path:
|
191
|
-
self.model_path_edit.setText(file_path)
|
192
|
-
|
193
|
-
# Load the class mapping if it exists
|
194
|
-
dir_path = os.path.dirname(os.path.dirname(file_path))
|
195
|
-
class_mapping_path = f"{dir_path}/class_mapping.json"
|
196
|
-
if os.path.exists(class_mapping_path):
|
197
|
-
self.class_mapping = json.load(open(class_mapping_path, 'r'))
|
198
|
-
self.mapping_edit.setText(class_mapping_path)
|
199
|
-
|
200
|
-
def browse_class_mapping_file(self):
|
201
|
-
"""
|
202
|
-
Browse and select a class mapping file.
|
203
|
-
"""
|
204
|
-
file_path, _ = QFileDialog.getOpenFileName(self,
|
205
|
-
"Select Class Mapping File",
|
206
|
-
"",
|
207
|
-
"JSON Files (*.json)")
|
208
|
-
if file_path:
|
209
|
-
# Load the class mapping
|
210
|
-
self.class_mapping = json.load(open(file_path, 'r'))
|
211
|
-
self.mapping_edit.setText(file_path)
|
212
|
-
|
213
145
|
def setup_parameters_layout(self):
|
214
146
|
"""
|
215
147
|
Setup parameter control section in a group box.
|
@@ -412,46 +344,43 @@ class DeployPredictorDialog(QDialog):
|
|
412
344
|
QApplication.setOverrideCursor(Qt.WaitCursor)
|
413
345
|
progress_bar = ProgressBar(self.annotation_window, title="Loading Model")
|
414
346
|
progress_bar.show()
|
415
|
-
|
347
|
+
|
416
348
|
try:
|
417
349
|
# Get selected model path and download weights if needed
|
418
350
|
self.model_path = self.model_combo.currentText()
|
419
|
-
|
351
|
+
|
420
352
|
# Load model using registry
|
421
353
|
self.loaded_model = YOLOE(self.model_path).to(self.main_window.device)
|
422
|
-
|
423
|
-
# Create a dummy visual dictionary
|
354
|
+
|
355
|
+
# Create a dummy visual dictionary for standard model loading
|
424
356
|
visuals = dict(
|
425
357
|
bboxes=np.array(
|
426
358
|
[
|
427
|
-
[120, 425, 160, 445],
|
359
|
+
[120, 425, 160, 445], # Random box
|
428
360
|
],
|
429
361
|
),
|
430
362
|
cls=np.array(
|
431
363
|
np.zeros(1),
|
432
364
|
),
|
433
365
|
)
|
434
|
-
|
366
|
+
|
435
367
|
# Run a dummy prediction to load the model
|
436
368
|
self.loaded_model.predict(
|
437
369
|
np.zeros((640, 640, 3), dtype=np.uint8),
|
438
|
-
visual_prompts=visuals.copy(),
|
439
|
-
predictor=
|
370
|
+
visual_prompts=visuals.copy(), # This needs to happen to properly initialize the predictor
|
371
|
+
predictor=YOLOEVPSegPredictor, # This also needs to be SegPredictor, no matter what
|
440
372
|
imgsz=640,
|
441
373
|
conf=0.99,
|
442
374
|
)
|
443
375
|
|
444
|
-
|
445
|
-
if self.class_mapping:
|
446
|
-
self.add_labels_to_label_window()
|
447
|
-
|
448
|
-
progress_bar.finish_progress()
|
449
|
-
self.status_bar.setText("Model loaded")
|
376
|
+
self.status_bar.setText(f"Loaded ({self.model_path}")
|
450
377
|
QMessageBox.information(self.annotation_window, "Model Loaded", "Model loaded successfully")
|
451
378
|
|
452
379
|
except Exception as e:
|
380
|
+
self.loaded_model = None
|
381
|
+
self.status_bar.setText(f"Error loading model: {self.model_path}")
|
453
382
|
QMessageBox.critical(self.annotation_window, "Error Loading Model", f"Error loading model: {e}")
|
454
|
-
|
383
|
+
|
455
384
|
finally:
|
456
385
|
# Restore cursor
|
457
386
|
QApplication.restoreOverrideCursor()
|
@@ -460,18 +389,6 @@ class DeployPredictorDialog(QDialog):
|
|
460
389
|
progress_bar.close()
|
461
390
|
progress_bar = None
|
462
391
|
|
463
|
-
self.accept()
|
464
|
-
|
465
|
-
def add_labels_to_label_window(self):
|
466
|
-
"""
|
467
|
-
Add labels to the label window based on the class mapping.
|
468
|
-
"""
|
469
|
-
if self.class_mapping:
|
470
|
-
for label in self.class_mapping.values():
|
471
|
-
self.main_window.label_window.add_label_if_not_exists(label['short_label_code'],
|
472
|
-
label['long_label_code'],
|
473
|
-
QColor(*label['color']))
|
474
|
-
|
475
392
|
def resize_image(self, image):
|
476
393
|
"""
|
477
394
|
Resize the image to the specified size.
|
@@ -517,9 +434,6 @@ class DeployPredictorDialog(QDialog):
|
|
517
434
|
# Open the image using rasterio
|
518
435
|
image = rasterio_to_numpy(self.main_window.image_window.rasterio_images[image_path])
|
519
436
|
|
520
|
-
# Preprocess the image
|
521
|
-
# image = preprocess_image(image)
|
522
|
-
|
523
437
|
# Save the original image
|
524
438
|
self.original_image = image
|
525
439
|
self.image_path = image_path
|
@@ -529,19 +443,57 @@ class DeployPredictorDialog(QDialog):
|
|
529
443
|
self.resized_image = self.resize_image(image)
|
530
444
|
else:
|
531
445
|
self.resized_image = image
|
446
|
+
|
447
|
+
def scale_prompts(self, bboxes, masks=None):
|
448
|
+
"""
|
449
|
+
Scale the bounding boxes and masks to the resized image.
|
450
|
+
"""
|
451
|
+
# Update the bbox coordinates to be relative to the resized image
|
452
|
+
bboxes = np.array(bboxes)
|
453
|
+
bboxes[:, 0] = (bboxes[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
|
454
|
+
bboxes[:, 1] = (bboxes[:, 1] / self.original_image.shape[0]) * self.resized_image.shape[0]
|
455
|
+
bboxes[:, 2] = (bboxes[:, 2] / self.original_image.shape[1]) * self.resized_image.shape[1]
|
456
|
+
bboxes[:, 3] = (bboxes[:, 3] / self.original_image.shape[0]) * self.resized_image.shape[0]
|
457
|
+
|
458
|
+
# Set the predictor
|
459
|
+
self.task = self.task_dropdown.currentText()
|
532
460
|
|
533
|
-
|
461
|
+
# Create a visual dictionary
|
462
|
+
visual_prompts = {
|
463
|
+
'bboxes': np.array(bboxes),
|
464
|
+
'cls': np.zeros(len(bboxes))
|
465
|
+
}
|
466
|
+
if self.task == 'segment':
|
467
|
+
if masks:
|
468
|
+
scaled_masks = []
|
469
|
+
for mask in masks:
|
470
|
+
scaled_mask = np.array(mask, dtype=np.float32)
|
471
|
+
scaled_mask[:, 0] = (scaled_mask[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
|
472
|
+
scaled_mask[:, 1] = (scaled_mask[:, 1] / self.original_image.shape[0]) * self.resized_image.shape[0]
|
473
|
+
scaled_masks.append(scaled_mask)
|
474
|
+
visual_prompts['masks'] = scaled_masks
|
475
|
+
else: # Fallback to creating masks from bboxes if no masks are provided
|
476
|
+
fallback_masks = []
|
477
|
+
for bbox in bboxes:
|
478
|
+
x1, y1, x2, y2 = bbox
|
479
|
+
fallback_masks.append(np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]))
|
480
|
+
visual_prompts['masks'] = fallback_masks
|
481
|
+
|
482
|
+
return visual_prompts
|
483
|
+
|
484
|
+
def predict_from_prompts(self, bboxes, masks=None):
|
534
485
|
"""
|
535
486
|
Make predictions using the currently loaded model using prompts.
|
536
487
|
|
537
488
|
Args:
|
538
|
-
|
489
|
+
bboxes (np.ndarray): The bounding boxes to use as prompts.
|
490
|
+
masks (list, optional): A list of polygons to use as prompts for segmentation.
|
539
491
|
|
540
492
|
Returns:
|
541
493
|
results (Results): Ultralytics Results object
|
542
494
|
"""
|
543
495
|
if not self.loaded_model:
|
544
|
-
QMessageBox.critical(self.annotation_window,
|
496
|
+
QMessageBox.critical(self.annotation_window,
|
545
497
|
"Model Not Loaded",
|
546
498
|
"Model not loaded, cannot make predictions")
|
547
499
|
return None
|
@@ -549,28 +501,14 @@ class DeployPredictorDialog(QDialog):
|
|
549
501
|
if not len(bboxes):
|
550
502
|
return None
|
551
503
|
|
552
|
-
#
|
553
|
-
|
554
|
-
bboxes[:, 0] = (bboxes[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
|
555
|
-
bboxes[:, 1] = (bboxes[:, 1] / self.original_image.shape[0]) * self.resized_image.shape[0]
|
556
|
-
bboxes[:, 2] = (bboxes[:, 2] / self.original_image.shape[1]) * self.resized_image.shape[1]
|
557
|
-
bboxes[:, 3] = (bboxes[:, 3] / self.original_image.shape[0]) * self.resized_image.shape[0]
|
558
|
-
|
559
|
-
# Create a visual dictionary
|
560
|
-
visuals = {
|
561
|
-
'bboxes': np.array(bboxes),
|
562
|
-
'cls': np.zeros(len(bboxes)) # TODO figure this out
|
563
|
-
}
|
564
|
-
|
565
|
-
# Set the predictor
|
566
|
-
self.task = self.task_dropdown.currentText()
|
567
|
-
predictor = YOLOEVPSegPredictor if self.task == "segment" else YOLOEVPDetectPredictor
|
504
|
+
# Get the scaled visual prompts
|
505
|
+
visual_prompts = self.scale_prompts(bboxes, masks)
|
568
506
|
|
569
507
|
try:
|
570
508
|
# Make predictions
|
571
509
|
results = self.loaded_model.predict(self.resized_image,
|
572
|
-
visual_prompts=
|
573
|
-
predictor=
|
510
|
+
visual_prompts=visual_prompts.copy(),
|
511
|
+
predictor=YOLOEVPSegPredictor,
|
574
512
|
imgsz=max(self.resized_image.shape[:2]),
|
575
513
|
conf=self.main_window.get_uncertainty_thresh(),
|
576
514
|
iou=self.main_window.get_iou_thresh(),
|
@@ -590,7 +528,7 @@ class DeployPredictorDialog(QDialog):
|
|
590
528
|
|
591
529
|
return results
|
592
530
|
|
593
|
-
def predict_from_annotations(self, refer_image, refer_label,
|
531
|
+
def predict_from_annotations(self, refer_image, refer_label, refer_bboxes, refer_masks, target_images):
|
594
532
|
""""""
|
595
533
|
# Create a class mapping
|
596
534
|
class_mapping = {0: refer_label}
|
@@ -605,15 +543,30 @@ class DeployPredictorDialog(QDialog):
|
|
605
543
|
max_area_thresh=self.main_window.get_area_thresh_max()
|
606
544
|
)
|
607
545
|
|
608
|
-
#
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
546
|
+
# Get the scaled visual prompts
|
547
|
+
visual_prompts = self.scale_prompts(refer_bboxes, refer_masks)
|
548
|
+
|
549
|
+
# If VPEs are being used
|
550
|
+
if self.vpe is not None:
|
551
|
+
# Generate a new VPE from the current visual prompts
|
552
|
+
new_vpe = self.prompts_to_vpes(visual_prompts, self.resized_image)
|
553
|
+
|
554
|
+
if new_vpe is not None:
|
555
|
+
# If we already have a VPE, average with the existing one
|
556
|
+
if self.vpe.shape == new_vpe.shape:
|
557
|
+
self.vpe = (self.vpe + new_vpe) / 2
|
558
|
+
# Re-normalize
|
559
|
+
self.vpe = torch.nn.functional.normalize(self.vpe, p=2, dim=-1)
|
560
|
+
else:
|
561
|
+
# Replace with the new VPE if shapes don't match
|
562
|
+
self.vpe = new_vpe
|
563
|
+
|
564
|
+
# Set the updated VPE in the model
|
565
|
+
self.loaded_model.is_fused = lambda: False
|
566
|
+
self.loaded_model.set_classes(["object0"], self.vpe)
|
567
|
+
|
568
|
+
# Clear visual prompts since we're using VPE
|
569
|
+
visual_prompts = {} # this is okay with a fused model
|
617
570
|
|
618
571
|
# Create a progress bar
|
619
572
|
QApplication.setOverrideCursor(Qt.WaitCursor)
|
@@ -627,8 +580,8 @@ class DeployPredictorDialog(QDialog):
|
|
627
580
|
# Make predictions
|
628
581
|
results = self.loaded_model.predict(target_image,
|
629
582
|
refer_image=refer_image,
|
630
|
-
visual_prompts=
|
631
|
-
predictor=
|
583
|
+
visual_prompts=visual_prompts.copy(),
|
584
|
+
predictor=YOLOEVPSegPredictor,
|
632
585
|
imgsz=self.imgsz_spinbox.value(),
|
633
586
|
conf=self.main_window.get_uncertainty_thresh(),
|
634
587
|
iou=self.main_window.get_iou_thresh(),
|
@@ -13,7 +13,7 @@ from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer
|
|
13
13
|
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
14
14
|
from PyQt5.QtWidgets import (QFileDialog, QScrollArea, QMessageBox, QCheckBox, QWidget, QVBoxLayout,
|
15
15
|
QLabel, QLineEdit, QDialog, QHBoxLayout, QPushButton, QComboBox, QSpinBox,
|
16
|
-
QFormLayout, QTabWidget, QDoubleSpinBox, QGroupBox
|
16
|
+
QFormLayout, QTabWidget, QDoubleSpinBox, QGroupBox)
|
17
17
|
|
18
18
|
from torch.cuda import empty_cache
|
19
19
|
|
@@ -109,7 +109,7 @@ class TrainModelWorker(QThread):
|
|
109
109
|
|
110
110
|
# Train the model
|
111
111
|
self.model.train(**self.params,
|
112
|
-
trainer=YOLOEPESegTrainer,
|
112
|
+
trainer=YOLOEVPTrainer , #YOLOEPESegTrainer,
|
113
113
|
device=self.device)
|
114
114
|
|
115
115
|
# Post-run cleanup
|
@@ -222,6 +222,13 @@ class TrainModelDialog(QDialog):
|
|
222
222
|
# Task specific parameters
|
223
223
|
self.imgsz = 640
|
224
224
|
self.batch = 4
|
225
|
+
|
226
|
+
# Training parameters with defaults for linear-probing
|
227
|
+
self._lr0 = 1e-3
|
228
|
+
self._warmup_bias_lr = 0.0
|
229
|
+
self._weight_decay = 0.025
|
230
|
+
self._momentum = 0.9
|
231
|
+
self._close_mosaic = 0 # Default for linear-probing
|
225
232
|
|
226
233
|
# Create the layout
|
227
234
|
self.layout = QVBoxLayout(self)
|
@@ -232,6 +239,8 @@ class TrainModelDialog(QDialog):
|
|
232
239
|
self.setup_dataset_layout()
|
233
240
|
# Create the model layout (new)
|
234
241
|
self.setup_model_layout()
|
242
|
+
# Create the output parameters layout
|
243
|
+
self.setup_output_parameters_layout()
|
235
244
|
# Create and set up the parameters layout
|
236
245
|
self.setup_parameters_layout()
|
237
246
|
# Create the buttons layout
|
@@ -327,6 +336,29 @@ class TrainModelDialog(QDialog):
|
|
327
336
|
layout.addWidget(tab_widget)
|
328
337
|
group_box.setLayout(layout)
|
329
338
|
self.layout.addWidget(group_box)
|
339
|
+
|
340
|
+
def setup_output_parameters_layout(self):
|
341
|
+
"""
|
342
|
+
Set up the layout and widgets for output parameters (project directory and name).
|
343
|
+
"""
|
344
|
+
group_box = QGroupBox("Output")
|
345
|
+
layout = QFormLayout()
|
346
|
+
|
347
|
+
# Project
|
348
|
+
self.project_edit = QLineEdit()
|
349
|
+
self.project_button = QPushButton("Browse...")
|
350
|
+
self.project_button.clicked.connect(self.browse_project_dir)
|
351
|
+
project_layout = QHBoxLayout()
|
352
|
+
project_layout.addWidget(self.project_edit)
|
353
|
+
project_layout.addWidget(self.project_button)
|
354
|
+
layout.addRow("Project:", project_layout)
|
355
|
+
|
356
|
+
# Name
|
357
|
+
self.name_edit = QLineEdit()
|
358
|
+
layout.addRow("Name:", self.name_edit)
|
359
|
+
|
360
|
+
group_box.setLayout(layout)
|
361
|
+
self.layout.addWidget(group_box)
|
330
362
|
|
331
363
|
def setup_parameters_layout(self):
|
332
364
|
"""
|
@@ -352,19 +384,6 @@ class TrainModelDialog(QDialog):
|
|
352
384
|
group_layout = QVBoxLayout(group_box)
|
353
385
|
group_layout.addWidget(scroll_area)
|
354
386
|
|
355
|
-
# Project
|
356
|
-
self.project_edit = QLineEdit()
|
357
|
-
self.project_button = QPushButton("Browse...")
|
358
|
-
self.project_button.clicked.connect(self.browse_project_dir)
|
359
|
-
project_layout = QHBoxLayout()
|
360
|
-
project_layout.addWidget(self.project_edit)
|
361
|
-
project_layout.addWidget(self.project_button)
|
362
|
-
form_layout.addRow("Project:", project_layout)
|
363
|
-
|
364
|
-
# Name
|
365
|
-
self.name_edit = QLineEdit()
|
366
|
-
form_layout.addRow("Name:", self.name_edit)
|
367
|
-
|
368
387
|
# Fine-tune or Linear Probing
|
369
388
|
self.training_mode = QComboBox()
|
370
389
|
self.training_mode.addItems(["linear-probe", "fine-tune"])
|
@@ -381,11 +400,60 @@ class TrainModelDialog(QDialog):
|
|
381
400
|
|
382
401
|
# Patience
|
383
402
|
self.patience_spinbox = QSpinBox()
|
384
|
-
self.patience_spinbox.setMinimum(
|
403
|
+
self.patience_spinbox.setMinimum(0) # Changed minimum to 0 to allow for 0 patience
|
385
404
|
self.patience_spinbox.setMaximum(1000)
|
386
|
-
self.patience_spinbox.setValue(
|
405
|
+
self.patience_spinbox.setValue(0) # Default for linear-probing
|
387
406
|
form_layout.addRow("Patience:", self.patience_spinbox)
|
388
407
|
|
408
|
+
# Close Mosaic
|
409
|
+
self.close_mosaic_spinbox = QSpinBox()
|
410
|
+
self.close_mosaic_spinbox.setMinimum(0)
|
411
|
+
self.close_mosaic_spinbox.setMaximum(1000)
|
412
|
+
self.close_mosaic_spinbox.setValue(0) # Default for linear-probing
|
413
|
+
form_layout.addRow("Close Mosaic:", self.close_mosaic_spinbox)
|
414
|
+
|
415
|
+
# Optimizer
|
416
|
+
self.optimizer_combo = QComboBox()
|
417
|
+
self.optimizer_combo.addItems(["auto", "SGD", "Adam", "AdamW", "NAdam", "RAdam", "RMSProp"])
|
418
|
+
self.optimizer_combo.setCurrentText("AdamW")
|
419
|
+
form_layout.addRow("Optimizer:", self.optimizer_combo)
|
420
|
+
|
421
|
+
# Learning Rate (lr0)
|
422
|
+
self.lr0_spinbox = QDoubleSpinBox()
|
423
|
+
self.lr0_spinbox.setDecimals(6)
|
424
|
+
self.lr0_spinbox.setMinimum(0.000001)
|
425
|
+
self.lr0_spinbox.setMaximum(1.0)
|
426
|
+
self.lr0_spinbox.setSingleStep(0.0001)
|
427
|
+
self.lr0_spinbox.setValue(self._lr0)
|
428
|
+
form_layout.addRow("Learning Rate:", self.lr0_spinbox)
|
429
|
+
|
430
|
+
# Warmup Bias Learning Rate
|
431
|
+
self.warmup_bias_lr_spinbox = QDoubleSpinBox()
|
432
|
+
self.warmup_bias_lr_spinbox.setDecimals(6)
|
433
|
+
self.warmup_bias_lr_spinbox.setMinimum(0.0)
|
434
|
+
self.warmup_bias_lr_spinbox.setMaximum(1.0)
|
435
|
+
self.warmup_bias_lr_spinbox.setSingleStep(0.0001)
|
436
|
+
self.warmup_bias_lr_spinbox.setValue(self._warmup_bias_lr)
|
437
|
+
form_layout.addRow("Warmup Bias LR:", self.warmup_bias_lr_spinbox)
|
438
|
+
|
439
|
+
# Weight Decay
|
440
|
+
self.weight_decay_spinbox = QDoubleSpinBox()
|
441
|
+
self.weight_decay_spinbox.setDecimals(6)
|
442
|
+
self.weight_decay_spinbox.setMinimum(0.0)
|
443
|
+
self.weight_decay_spinbox.setMaximum(1.0)
|
444
|
+
self.weight_decay_spinbox.setSingleStep(0.001)
|
445
|
+
self.weight_decay_spinbox.setValue(self._weight_decay)
|
446
|
+
form_layout.addRow("Weight Decay:", self.weight_decay_spinbox)
|
447
|
+
|
448
|
+
# Momentum
|
449
|
+
self.momentum_spinbox = QDoubleSpinBox()
|
450
|
+
self.momentum_spinbox.setDecimals(2)
|
451
|
+
self.momentum_spinbox.setMinimum(0.0)
|
452
|
+
self.momentum_spinbox.setMaximum(1.0)
|
453
|
+
self.momentum_spinbox.setSingleStep(0.01)
|
454
|
+
self.momentum_spinbox.setValue(self._momentum)
|
455
|
+
form_layout.addRow("Momentum:", self.momentum_spinbox)
|
456
|
+
|
389
457
|
# Imgsz
|
390
458
|
self.imgsz_spinbox = QSpinBox()
|
391
459
|
self.imgsz_spinbox.setMinimum(16)
|
@@ -393,10 +461,6 @@ class TrainModelDialog(QDialog):
|
|
393
461
|
self.imgsz_spinbox.setValue(self.imgsz)
|
394
462
|
form_layout.addRow("Image Size:", self.imgsz_spinbox)
|
395
463
|
|
396
|
-
# Multi Scale
|
397
|
-
self.multi_scale_combo = create_bool_combo()
|
398
|
-
form_layout.addRow("Multi-Scale:", self.multi_scale_combo)
|
399
|
-
|
400
464
|
# Batch
|
401
465
|
self.batch_spinbox = QSpinBox()
|
402
466
|
self.batch_spinbox.setMinimum(1)
|
@@ -421,20 +485,6 @@ class TrainModelDialog(QDialog):
|
|
421
485
|
self.save_period_spinbox.setMaximum(1000)
|
422
486
|
self.save_period_spinbox.setValue(-1)
|
423
487
|
form_layout.addRow("Save Period:", self.save_period_spinbox)
|
424
|
-
|
425
|
-
# Dropout
|
426
|
-
self.dropout_spinbox = QDoubleSpinBox()
|
427
|
-
self.dropout_spinbox.setMinimum(0.0)
|
428
|
-
self.dropout_spinbox.setMaximum(1.0)
|
429
|
-
self.dropout_spinbox.setValue(0.0)
|
430
|
-
form_layout.addRow("Dropout:", self.dropout_spinbox)
|
431
|
-
|
432
|
-
# Optimizer
|
433
|
-
self.optimizer_combo = QComboBox()
|
434
|
-
self.optimizer_combo.addItems(["auto", "SGD", "Adam", "AdamW", "NAdam", "RAdam", "RMSProp"])
|
435
|
-
self.optimizer_combo.setCurrentText("AdamW")
|
436
|
-
form_layout.addRow("Optimizer:", self.optimizer_combo)
|
437
|
-
|
438
488
|
# Val
|
439
489
|
self.val_combo = create_bool_combo()
|
440
490
|
form_layout.addRow("Validation:", self.val_combo)
|
@@ -489,9 +539,32 @@ class TrainModelDialog(QDialog):
|
|
489
539
|
# Fine-tune mode
|
490
540
|
self.epochs_spinbox.setValue(80)
|
491
541
|
self.patience_spinbox.setValue(20)
|
542
|
+
self.close_mosaic_spinbox.setValue(10)
|
543
|
+
self._close_mosaic = 10
|
544
|
+
|
545
|
+
# These parameters stay the same for both modes
|
546
|
+
self.lr0_spinbox.setValue(1e-3)
|
547
|
+
self.warmup_bias_lr_spinbox.setValue(0.0)
|
548
|
+
self.weight_decay_spinbox.setValue(0.025)
|
549
|
+
self.momentum_spinbox.setValue(0.9)
|
550
|
+
|
551
|
+
# Ensure optimizer is set to AdamW
|
552
|
+
self.optimizer_combo.setCurrentText("AdamW")
|
492
553
|
else:
|
554
|
+
# Linear-probing mode
|
493
555
|
self.epochs_spinbox.setValue(2)
|
494
|
-
self.patience_spinbox.setValue(
|
556
|
+
self.patience_spinbox.setValue(0)
|
557
|
+
self.close_mosaic_spinbox.setValue(0)
|
558
|
+
self._close_mosaic = 0
|
559
|
+
|
560
|
+
# These parameters stay the same for both modes
|
561
|
+
self.lr0_spinbox.setValue(1e-3)
|
562
|
+
self.warmup_bias_lr_spinbox.setValue(0.0)
|
563
|
+
self.weight_decay_spinbox.setValue(0.025)
|
564
|
+
self.momentum_spinbox.setValue(0.9)
|
565
|
+
|
566
|
+
# Ensure optimizer is set to AdamW
|
567
|
+
self.optimizer_combo.setCurrentText("AdamW")
|
495
568
|
|
496
569
|
def load_model_combobox(self):
|
497
570
|
"""Load the model combobox with the available models."""
|
@@ -605,13 +678,16 @@ class TrainModelDialog(QDialog):
|
|
605
678
|
'patience': self.patience_spinbox.value(),
|
606
679
|
'batch': self.batch_spinbox.value(),
|
607
680
|
'imgsz': self.imgsz_spinbox.value(),
|
608
|
-
'
|
681
|
+
'optimizer': self.optimizer_combo.currentText(),
|
682
|
+
'lr0': self.lr0_spinbox.value(),
|
683
|
+
'warmup_bias_lr': self.warmup_bias_lr_spinbox.value(),
|
684
|
+
'weight_decay': self.weight_decay_spinbox.value(),
|
685
|
+
'momentum': self.momentum_spinbox.value(),
|
686
|
+
'close_mosaic': self.close_mosaic_spinbox.value(),
|
609
687
|
'save': self.save_combo.currentText() == "True",
|
610
688
|
'save_period': self.save_period_spinbox.value(),
|
611
689
|
'workers': self.workers_spinbox.value(),
|
612
|
-
'optimizer': self.optimizer_combo.currentText(),
|
613
690
|
'verbose': self.verbose_combo.currentText() == "True",
|
614
|
-
'dropout': self.dropout_spinbox.value(),
|
615
691
|
'val': self.val_combo.currentText() == "True",
|
616
692
|
'exist_ok': True,
|
617
693
|
'plots': True,
|
@@ -644,12 +720,6 @@ class TrainModelDialog(QDialog):
|
|
644
720
|
except ValueError:
|
645
721
|
params[name] = value
|
646
722
|
|
647
|
-
params['lr0'] = 1e-3 if 'lr0' not in params else params['lr0']
|
648
|
-
params['warmup_bias_lr'] = 0.0 if 'warmup_bias_lr' not in params else params['warmup_bias_lr']
|
649
|
-
params['weight_decay'] = 0.025 if 'weight_decay' not in params else params['weight_decay']
|
650
|
-
params['momentum'] = 0.9 if 'momentum' not in params else params['momentum']
|
651
|
-
params['close_mosaic'] = 10 if 'close_mosaic' not in params else params['close_mosaic']
|
652
|
-
|
653
723
|
# Return the dictionary of parameters
|
654
724
|
return params
|
655
725
|
|
@@ -2,9 +2,11 @@
|
|
2
2
|
from .QtTrainModel import TrainModelDialog
|
3
3
|
from .QtBatchInference import BatchInferenceDialog
|
4
4
|
from .QtDeployPredictor import DeployPredictorDialog
|
5
|
+
from .QtDeployGenerator import DeployGeneratorDialog
|
5
6
|
|
6
7
|
__all__ = [
|
7
8
|
'TrainModelDialog',
|
8
9
|
'BatchInferenceDialog'
|
9
10
|
'DeployPredictorDialog',
|
11
|
+
'DeployGeneratorDialog',
|
10
12
|
]
|