coralnet-toolbox 0.0.71__py2.py3-none-any.whl → 0.0.73__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +31 -2
- coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
- coralnet_toolbox/Explorer/QtDataItem.py +53 -21
- coralnet_toolbox/Explorer/QtExplorer.py +581 -276
- coralnet_toolbox/Explorer/QtFeatureStore.py +15 -0
- coralnet_toolbox/Explorer/QtSettingsWidgets.py +49 -7
- coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
- coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
- coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
- coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
- coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
- coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
- coralnet_toolbox/QtAnnotationWindow.py +52 -16
- coralnet_toolbox/QtEventFilter.py +8 -2
- coralnet_toolbox/QtImageWindow.py +17 -18
- coralnet_toolbox/QtLabelWindow.py +1 -1
- coralnet_toolbox/QtMainWindow.py +203 -8
- coralnet_toolbox/Rasters/QtRaster.py +59 -7
- coralnet_toolbox/Rasters/RasterTableModel.py +34 -6
- coralnet_toolbox/SAM/QtBatchInference.py +0 -2
- coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
- coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1016 -0
- coralnet_toolbox/SeeAnything/QtDeployPredictor.py +69 -53
- coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
- coralnet_toolbox/SeeAnything/__init__.py +2 -0
- coralnet_toolbox/Tools/QtResizeSubTool.py +6 -1
- coralnet_toolbox/Tools/QtSAMTool.py +150 -7
- coralnet_toolbox/Tools/QtSeeAnythingTool.py +220 -55
- coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
- coralnet_toolbox/Tools/QtSelectTool.py +48 -6
- coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
- coralnet_toolbox/__init__.py +1 -1
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/METADATA +1 -1
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/RECORD +39 -38
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.71.dist-info → coralnet_toolbox-0.0.73.dist-info}/top_level.txt +0 -0
@@ -6,6 +6,13 @@ import ujson as json
|
|
6
6
|
|
7
7
|
import numpy as np
|
8
8
|
|
9
|
+
from torch.cuda import empty_cache
|
10
|
+
from ultralytics.utils import ops
|
11
|
+
|
12
|
+
from ultralytics import YOLOE
|
13
|
+
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
|
14
|
+
from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
|
15
|
+
|
9
16
|
from PyQt5.QtCore import Qt
|
10
17
|
from PyQt5.QtGui import QColor
|
11
18
|
from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog, QFormLayout,
|
@@ -13,13 +20,6 @@ from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog, QFormLayout,
|
|
13
20
|
QSlider, QSpinBox, QVBoxLayout, QGroupBox, QTabWidget,
|
14
21
|
QWidget, QLineEdit, QFileDialog)
|
15
22
|
|
16
|
-
from ultralytics import YOLOE
|
17
|
-
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
|
18
|
-
from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
|
19
|
-
|
20
|
-
from torch.cuda import empty_cache
|
21
|
-
from ultralytics.utils import ops
|
22
|
-
|
23
23
|
from coralnet_toolbox.Results import ResultsProcessor
|
24
24
|
|
25
25
|
from coralnet_toolbox.QtProgressBar import ProgressBar
|
@@ -179,37 +179,6 @@ class DeployPredictorDialog(QDialog):
|
|
179
179
|
group_box.setLayout(layout)
|
180
180
|
self.layout.addWidget(group_box)
|
181
181
|
|
182
|
-
def browse_model_file(self):
|
183
|
-
"""
|
184
|
-
Open a file dialog to browse for a model file.
|
185
|
-
"""
|
186
|
-
file_path, _ = QFileDialog.getOpenFileName(self,
|
187
|
-
"Select Model File",
|
188
|
-
"",
|
189
|
-
"Model Files (*.pt *.pth);;All Files (*)")
|
190
|
-
if file_path:
|
191
|
-
self.model_path_edit.setText(file_path)
|
192
|
-
|
193
|
-
# Load the class mapping if it exists
|
194
|
-
dir_path = os.path.dirname(os.path.dirname(file_path))
|
195
|
-
class_mapping_path = f"{dir_path}/class_mapping.json"
|
196
|
-
if os.path.exists(class_mapping_path):
|
197
|
-
self.class_mapping = json.load(open(class_mapping_path, 'r'))
|
198
|
-
self.mapping_edit.setText(class_mapping_path)
|
199
|
-
|
200
|
-
def browse_class_mapping_file(self):
|
201
|
-
"""
|
202
|
-
Browse and select a class mapping file.
|
203
|
-
"""
|
204
|
-
file_path, _ = QFileDialog.getOpenFileName(self,
|
205
|
-
"Select Class Mapping File",
|
206
|
-
"",
|
207
|
-
"JSON Files (*.json)")
|
208
|
-
if file_path:
|
209
|
-
# Load the class mapping
|
210
|
-
self.class_mapping = json.load(open(file_path, 'r'))
|
211
|
-
self.mapping_edit.setText(file_path)
|
212
|
-
|
213
182
|
def setup_parameters_layout(self):
|
214
183
|
"""
|
215
184
|
Setup parameter control section in a group box.
|
@@ -335,6 +304,37 @@ class DeployPredictorDialog(QDialog):
|
|
335
304
|
|
336
305
|
group_box.setLayout(layout)
|
337
306
|
self.layout.addWidget(group_box)
|
307
|
+
|
308
|
+
def browse_model_file(self):
|
309
|
+
"""
|
310
|
+
Open a file dialog to browse for a model file.
|
311
|
+
"""
|
312
|
+
file_path, _ = QFileDialog.getOpenFileName(self,
|
313
|
+
"Select Model File",
|
314
|
+
"",
|
315
|
+
"Model Files (*.pt *.pth);;All Files (*)")
|
316
|
+
if file_path:
|
317
|
+
self.model_path_edit.setText(file_path)
|
318
|
+
|
319
|
+
# Load the class mapping if it exists
|
320
|
+
dir_path = os.path.dirname(os.path.dirname(file_path))
|
321
|
+
class_mapping_path = f"{dir_path}/class_mapping.json"
|
322
|
+
if os.path.exists(class_mapping_path):
|
323
|
+
self.class_mapping = json.load(open(class_mapping_path, 'r'))
|
324
|
+
self.mapping_edit.setText(class_mapping_path)
|
325
|
+
|
326
|
+
def browse_class_mapping_file(self):
|
327
|
+
"""
|
328
|
+
Browse and select a class mapping file.
|
329
|
+
"""
|
330
|
+
file_path, _ = QFileDialog.getOpenFileName(self,
|
331
|
+
"Select Class Mapping File",
|
332
|
+
"",
|
333
|
+
"JSON Files (*.json)")
|
334
|
+
if file_path:
|
335
|
+
# Load the class mapping
|
336
|
+
self.class_mapping = json.load(open(file_path, 'r'))
|
337
|
+
self.mapping_edit.setText(file_path)
|
338
338
|
|
339
339
|
def initialize_uncertainty_threshold(self):
|
340
340
|
"""Initialize the uncertainty threshold slider with the current value"""
|
@@ -517,9 +517,6 @@ class DeployPredictorDialog(QDialog):
|
|
517
517
|
# Open the image using rasterio
|
518
518
|
image = rasterio_to_numpy(self.main_window.image_window.rasterio_images[image_path])
|
519
519
|
|
520
|
-
# Preprocess the image
|
521
|
-
# image = preprocess_image(image)
|
522
|
-
|
523
520
|
# Save the original image
|
524
521
|
self.original_image = image
|
525
522
|
self.image_path = image_path
|
@@ -530,19 +527,19 @@ class DeployPredictorDialog(QDialog):
|
|
530
527
|
else:
|
531
528
|
self.resized_image = image
|
532
529
|
|
533
|
-
def predict_from_prompts(self, bboxes):
|
530
|
+
def predict_from_prompts(self, bboxes, masks=None):
|
534
531
|
"""
|
535
532
|
Make predictions using the currently loaded model using prompts.
|
536
533
|
|
537
534
|
Args:
|
538
|
-
|
535
|
+
bboxes (np.ndarray): The bounding boxes to use as prompts.
|
536
|
+
masks (list, optional): A list of polygons to use as prompts for segmentation.
|
539
537
|
|
540
538
|
Returns:
|
541
539
|
results (Results): Ultralytics Results object
|
542
540
|
"""
|
543
541
|
if not self.loaded_model:
|
544
|
-
QMessageBox.critical(self.annotation_window,
|
545
|
-
"Model Not Loaded",
|
542
|
+
QMessageBox.critical(self.annotation_window, "Model Not Loaded",
|
546
543
|
"Model not loaded, cannot make predictions")
|
547
544
|
return None
|
548
545
|
|
@@ -556,14 +553,30 @@ class DeployPredictorDialog(QDialog):
|
|
556
553
|
bboxes[:, 2] = (bboxes[:, 2] / self.original_image.shape[1]) * self.resized_image.shape[1]
|
557
554
|
bboxes[:, 3] = (bboxes[:, 3] / self.original_image.shape[0]) * self.resized_image.shape[0]
|
558
555
|
|
556
|
+
# Set the predictor
|
557
|
+
self.task = self.task_dropdown.currentText()
|
558
|
+
|
559
559
|
# Create a visual dictionary
|
560
560
|
visuals = {
|
561
561
|
'bboxes': np.array(bboxes),
|
562
|
-
'cls': np.zeros(len(bboxes))
|
562
|
+
'cls': np.zeros(len(bboxes))
|
563
563
|
}
|
564
|
+
if self.task == 'segment':
|
565
|
+
if masks:
|
566
|
+
scaled_masks = []
|
567
|
+
for mask in masks:
|
568
|
+
scaled_mask = np.array(mask, dtype=np.float32)
|
569
|
+
scaled_mask[:, 0] = (scaled_mask[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
|
570
|
+
scaled_mask[:, 1] = (scaled_mask[:, 1] / self.original_image.shape[0]) * self.resized_image.shape[0]
|
571
|
+
scaled_masks.append(scaled_mask)
|
572
|
+
visuals['masks'] = scaled_masks
|
573
|
+
else: # Fallback to creating masks from bboxes if no masks are provided
|
574
|
+
fallback_masks = []
|
575
|
+
for bbox in bboxes:
|
576
|
+
x1, y1, x2, y2 = bbox
|
577
|
+
fallback_masks.append(np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]))
|
578
|
+
visuals['masks'] = fallback_masks
|
564
579
|
|
565
|
-
# Set the predictor
|
566
|
-
self.task = self.task_dropdown.currentText()
|
567
580
|
predictor = YOLOEVPSegPredictor if self.task == "segment" else YOLOEVPDetectPredictor
|
568
581
|
|
569
582
|
try:
|
@@ -590,7 +603,7 @@ class DeployPredictorDialog(QDialog):
|
|
590
603
|
|
591
604
|
return results
|
592
605
|
|
593
|
-
def predict_from_annotations(self, refer_image, refer_label,
|
606
|
+
def predict_from_annotations(self, refer_image, refer_label, refer_bboxes, refer_masks, target_images):
|
594
607
|
""""""
|
595
608
|
# Create a class mapping
|
596
609
|
class_mapping = {0: refer_label}
|
@@ -605,14 +618,17 @@ class DeployPredictorDialog(QDialog):
|
|
605
618
|
max_area_thresh=self.main_window.get_area_thresh_max()
|
606
619
|
)
|
607
620
|
|
621
|
+
# Set the predictor
|
622
|
+
self.task = self.task_dropdown.currentText()
|
623
|
+
|
608
624
|
# Create a visual dictionary
|
609
625
|
visuals = {
|
610
|
-
'bboxes': np.array(
|
611
|
-
'cls': np.zeros(len(
|
626
|
+
'bboxes': np.array(refer_bboxes),
|
627
|
+
'cls': np.zeros(len(refer_bboxes))
|
612
628
|
}
|
629
|
+
if self.task == 'segment':
|
630
|
+
visuals['masks'] = refer_masks
|
613
631
|
|
614
|
-
# Set the predictor
|
615
|
-
self.task = self.task_dropdown.currentText()
|
616
632
|
predictor = YOLOEVPSegPredictor if self.task == "segment" else YOLOEVPDetectPredictor
|
617
633
|
|
618
634
|
# Create a progress bar
|
@@ -13,7 +13,7 @@ from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer
|
|
13
13
|
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
14
14
|
from PyQt5.QtWidgets import (QFileDialog, QScrollArea, QMessageBox, QCheckBox, QWidget, QVBoxLayout,
|
15
15
|
QLabel, QLineEdit, QDialog, QHBoxLayout, QPushButton, QComboBox, QSpinBox,
|
16
|
-
QFormLayout, QTabWidget, QDoubleSpinBox, QGroupBox
|
16
|
+
QFormLayout, QTabWidget, QDoubleSpinBox, QGroupBox)
|
17
17
|
|
18
18
|
from torch.cuda import empty_cache
|
19
19
|
|
@@ -109,7 +109,7 @@ class TrainModelWorker(QThread):
|
|
109
109
|
|
110
110
|
# Train the model
|
111
111
|
self.model.train(**self.params,
|
112
|
-
trainer=YOLOEPESegTrainer,
|
112
|
+
trainer=YOLOEVPTrainer , #YOLOEPESegTrainer,
|
113
113
|
device=self.device)
|
114
114
|
|
115
115
|
# Post-run cleanup
|
@@ -222,6 +222,13 @@ class TrainModelDialog(QDialog):
|
|
222
222
|
# Task specific parameters
|
223
223
|
self.imgsz = 640
|
224
224
|
self.batch = 4
|
225
|
+
|
226
|
+
# Training parameters with defaults for linear-probing
|
227
|
+
self._lr0 = 1e-3
|
228
|
+
self._warmup_bias_lr = 0.0
|
229
|
+
self._weight_decay = 0.025
|
230
|
+
self._momentum = 0.9
|
231
|
+
self._close_mosaic = 0 # Default for linear-probing
|
225
232
|
|
226
233
|
# Create the layout
|
227
234
|
self.layout = QVBoxLayout(self)
|
@@ -232,6 +239,8 @@ class TrainModelDialog(QDialog):
|
|
232
239
|
self.setup_dataset_layout()
|
233
240
|
# Create the model layout (new)
|
234
241
|
self.setup_model_layout()
|
242
|
+
# Create the output parameters layout
|
243
|
+
self.setup_output_parameters_layout()
|
235
244
|
# Create and set up the parameters layout
|
236
245
|
self.setup_parameters_layout()
|
237
246
|
# Create the buttons layout
|
@@ -327,6 +336,29 @@ class TrainModelDialog(QDialog):
|
|
327
336
|
layout.addWidget(tab_widget)
|
328
337
|
group_box.setLayout(layout)
|
329
338
|
self.layout.addWidget(group_box)
|
339
|
+
|
340
|
+
def setup_output_parameters_layout(self):
|
341
|
+
"""
|
342
|
+
Set up the layout and widgets for output parameters (project directory and name).
|
343
|
+
"""
|
344
|
+
group_box = QGroupBox("Output")
|
345
|
+
layout = QFormLayout()
|
346
|
+
|
347
|
+
# Project
|
348
|
+
self.project_edit = QLineEdit()
|
349
|
+
self.project_button = QPushButton("Browse...")
|
350
|
+
self.project_button.clicked.connect(self.browse_project_dir)
|
351
|
+
project_layout = QHBoxLayout()
|
352
|
+
project_layout.addWidget(self.project_edit)
|
353
|
+
project_layout.addWidget(self.project_button)
|
354
|
+
layout.addRow("Project:", project_layout)
|
355
|
+
|
356
|
+
# Name
|
357
|
+
self.name_edit = QLineEdit()
|
358
|
+
layout.addRow("Name:", self.name_edit)
|
359
|
+
|
360
|
+
group_box.setLayout(layout)
|
361
|
+
self.layout.addWidget(group_box)
|
330
362
|
|
331
363
|
def setup_parameters_layout(self):
|
332
364
|
"""
|
@@ -352,19 +384,6 @@ class TrainModelDialog(QDialog):
|
|
352
384
|
group_layout = QVBoxLayout(group_box)
|
353
385
|
group_layout.addWidget(scroll_area)
|
354
386
|
|
355
|
-
# Project
|
356
|
-
self.project_edit = QLineEdit()
|
357
|
-
self.project_button = QPushButton("Browse...")
|
358
|
-
self.project_button.clicked.connect(self.browse_project_dir)
|
359
|
-
project_layout = QHBoxLayout()
|
360
|
-
project_layout.addWidget(self.project_edit)
|
361
|
-
project_layout.addWidget(self.project_button)
|
362
|
-
form_layout.addRow("Project:", project_layout)
|
363
|
-
|
364
|
-
# Name
|
365
|
-
self.name_edit = QLineEdit()
|
366
|
-
form_layout.addRow("Name:", self.name_edit)
|
367
|
-
|
368
387
|
# Fine-tune or Linear Probing
|
369
388
|
self.training_mode = QComboBox()
|
370
389
|
self.training_mode.addItems(["linear-probe", "fine-tune"])
|
@@ -381,11 +400,60 @@ class TrainModelDialog(QDialog):
|
|
381
400
|
|
382
401
|
# Patience
|
383
402
|
self.patience_spinbox = QSpinBox()
|
384
|
-
self.patience_spinbox.setMinimum(
|
403
|
+
self.patience_spinbox.setMinimum(0) # Changed minimum to 0 to allow for 0 patience
|
385
404
|
self.patience_spinbox.setMaximum(1000)
|
386
|
-
self.patience_spinbox.setValue(
|
405
|
+
self.patience_spinbox.setValue(0) # Default for linear-probing
|
387
406
|
form_layout.addRow("Patience:", self.patience_spinbox)
|
388
407
|
|
408
|
+
# Close Mosaic
|
409
|
+
self.close_mosaic_spinbox = QSpinBox()
|
410
|
+
self.close_mosaic_spinbox.setMinimum(0)
|
411
|
+
self.close_mosaic_spinbox.setMaximum(1000)
|
412
|
+
self.close_mosaic_spinbox.setValue(0) # Default for linear-probing
|
413
|
+
form_layout.addRow("Close Mosaic:", self.close_mosaic_spinbox)
|
414
|
+
|
415
|
+
# Optimizer
|
416
|
+
self.optimizer_combo = QComboBox()
|
417
|
+
self.optimizer_combo.addItems(["auto", "SGD", "Adam", "AdamW", "NAdam", "RAdam", "RMSProp"])
|
418
|
+
self.optimizer_combo.setCurrentText("AdamW")
|
419
|
+
form_layout.addRow("Optimizer:", self.optimizer_combo)
|
420
|
+
|
421
|
+
# Learning Rate (lr0)
|
422
|
+
self.lr0_spinbox = QDoubleSpinBox()
|
423
|
+
self.lr0_spinbox.setDecimals(6)
|
424
|
+
self.lr0_spinbox.setMinimum(0.000001)
|
425
|
+
self.lr0_spinbox.setMaximum(1.0)
|
426
|
+
self.lr0_spinbox.setSingleStep(0.0001)
|
427
|
+
self.lr0_spinbox.setValue(self._lr0)
|
428
|
+
form_layout.addRow("Learning Rate:", self.lr0_spinbox)
|
429
|
+
|
430
|
+
# Warmup Bias Learning Rate
|
431
|
+
self.warmup_bias_lr_spinbox = QDoubleSpinBox()
|
432
|
+
self.warmup_bias_lr_spinbox.setDecimals(6)
|
433
|
+
self.warmup_bias_lr_spinbox.setMinimum(0.0)
|
434
|
+
self.warmup_bias_lr_spinbox.setMaximum(1.0)
|
435
|
+
self.warmup_bias_lr_spinbox.setSingleStep(0.0001)
|
436
|
+
self.warmup_bias_lr_spinbox.setValue(self._warmup_bias_lr)
|
437
|
+
form_layout.addRow("Warmup Bias LR:", self.warmup_bias_lr_spinbox)
|
438
|
+
|
439
|
+
# Weight Decay
|
440
|
+
self.weight_decay_spinbox = QDoubleSpinBox()
|
441
|
+
self.weight_decay_spinbox.setDecimals(6)
|
442
|
+
self.weight_decay_spinbox.setMinimum(0.0)
|
443
|
+
self.weight_decay_spinbox.setMaximum(1.0)
|
444
|
+
self.weight_decay_spinbox.setSingleStep(0.001)
|
445
|
+
self.weight_decay_spinbox.setValue(self._weight_decay)
|
446
|
+
form_layout.addRow("Weight Decay:", self.weight_decay_spinbox)
|
447
|
+
|
448
|
+
# Momentum
|
449
|
+
self.momentum_spinbox = QDoubleSpinBox()
|
450
|
+
self.momentum_spinbox.setDecimals(2)
|
451
|
+
self.momentum_spinbox.setMinimum(0.0)
|
452
|
+
self.momentum_spinbox.setMaximum(1.0)
|
453
|
+
self.momentum_spinbox.setSingleStep(0.01)
|
454
|
+
self.momentum_spinbox.setValue(self._momentum)
|
455
|
+
form_layout.addRow("Momentum:", self.momentum_spinbox)
|
456
|
+
|
389
457
|
# Imgsz
|
390
458
|
self.imgsz_spinbox = QSpinBox()
|
391
459
|
self.imgsz_spinbox.setMinimum(16)
|
@@ -393,10 +461,6 @@ class TrainModelDialog(QDialog):
|
|
393
461
|
self.imgsz_spinbox.setValue(self.imgsz)
|
394
462
|
form_layout.addRow("Image Size:", self.imgsz_spinbox)
|
395
463
|
|
396
|
-
# Multi Scale
|
397
|
-
self.multi_scale_combo = create_bool_combo()
|
398
|
-
form_layout.addRow("Multi-Scale:", self.multi_scale_combo)
|
399
|
-
|
400
464
|
# Batch
|
401
465
|
self.batch_spinbox = QSpinBox()
|
402
466
|
self.batch_spinbox.setMinimum(1)
|
@@ -421,20 +485,6 @@ class TrainModelDialog(QDialog):
|
|
421
485
|
self.save_period_spinbox.setMaximum(1000)
|
422
486
|
self.save_period_spinbox.setValue(-1)
|
423
487
|
form_layout.addRow("Save Period:", self.save_period_spinbox)
|
424
|
-
|
425
|
-
# Dropout
|
426
|
-
self.dropout_spinbox = QDoubleSpinBox()
|
427
|
-
self.dropout_spinbox.setMinimum(0.0)
|
428
|
-
self.dropout_spinbox.setMaximum(1.0)
|
429
|
-
self.dropout_spinbox.setValue(0.0)
|
430
|
-
form_layout.addRow("Dropout:", self.dropout_spinbox)
|
431
|
-
|
432
|
-
# Optimizer
|
433
|
-
self.optimizer_combo = QComboBox()
|
434
|
-
self.optimizer_combo.addItems(["auto", "SGD", "Adam", "AdamW", "NAdam", "RAdam", "RMSProp"])
|
435
|
-
self.optimizer_combo.setCurrentText("AdamW")
|
436
|
-
form_layout.addRow("Optimizer:", self.optimizer_combo)
|
437
|
-
|
438
488
|
# Val
|
439
489
|
self.val_combo = create_bool_combo()
|
440
490
|
form_layout.addRow("Validation:", self.val_combo)
|
@@ -489,9 +539,32 @@ class TrainModelDialog(QDialog):
|
|
489
539
|
# Fine-tune mode
|
490
540
|
self.epochs_spinbox.setValue(80)
|
491
541
|
self.patience_spinbox.setValue(20)
|
542
|
+
self.close_mosaic_spinbox.setValue(10)
|
543
|
+
self._close_mosaic = 10
|
544
|
+
|
545
|
+
# These parameters stay the same for both modes
|
546
|
+
self.lr0_spinbox.setValue(1e-3)
|
547
|
+
self.warmup_bias_lr_spinbox.setValue(0.0)
|
548
|
+
self.weight_decay_spinbox.setValue(0.025)
|
549
|
+
self.momentum_spinbox.setValue(0.9)
|
550
|
+
|
551
|
+
# Ensure optimizer is set to AdamW
|
552
|
+
self.optimizer_combo.setCurrentText("AdamW")
|
492
553
|
else:
|
554
|
+
# Linear-probing mode
|
493
555
|
self.epochs_spinbox.setValue(2)
|
494
|
-
self.patience_spinbox.setValue(
|
556
|
+
self.patience_spinbox.setValue(0)
|
557
|
+
self.close_mosaic_spinbox.setValue(0)
|
558
|
+
self._close_mosaic = 0
|
559
|
+
|
560
|
+
# These parameters stay the same for both modes
|
561
|
+
self.lr0_spinbox.setValue(1e-3)
|
562
|
+
self.warmup_bias_lr_spinbox.setValue(0.0)
|
563
|
+
self.weight_decay_spinbox.setValue(0.025)
|
564
|
+
self.momentum_spinbox.setValue(0.9)
|
565
|
+
|
566
|
+
# Ensure optimizer is set to AdamW
|
567
|
+
self.optimizer_combo.setCurrentText("AdamW")
|
495
568
|
|
496
569
|
def load_model_combobox(self):
|
497
570
|
"""Load the model combobox with the available models."""
|
@@ -605,13 +678,16 @@ class TrainModelDialog(QDialog):
|
|
605
678
|
'patience': self.patience_spinbox.value(),
|
606
679
|
'batch': self.batch_spinbox.value(),
|
607
680
|
'imgsz': self.imgsz_spinbox.value(),
|
608
|
-
'
|
681
|
+
'optimizer': self.optimizer_combo.currentText(),
|
682
|
+
'lr0': self.lr0_spinbox.value(),
|
683
|
+
'warmup_bias_lr': self.warmup_bias_lr_spinbox.value(),
|
684
|
+
'weight_decay': self.weight_decay_spinbox.value(),
|
685
|
+
'momentum': self.momentum_spinbox.value(),
|
686
|
+
'close_mosaic': self.close_mosaic_spinbox.value(),
|
609
687
|
'save': self.save_combo.currentText() == "True",
|
610
688
|
'save_period': self.save_period_spinbox.value(),
|
611
689
|
'workers': self.workers_spinbox.value(),
|
612
|
-
'optimizer': self.optimizer_combo.currentText(),
|
613
690
|
'verbose': self.verbose_combo.currentText() == "True",
|
614
|
-
'dropout': self.dropout_spinbox.value(),
|
615
691
|
'val': self.val_combo.currentText() == "True",
|
616
692
|
'exist_ok': True,
|
617
693
|
'plots': True,
|
@@ -644,12 +720,6 @@ class TrainModelDialog(QDialog):
|
|
644
720
|
except ValueError:
|
645
721
|
params[name] = value
|
646
722
|
|
647
|
-
params['lr0'] = 1e-3 if 'lr0' not in params else params['lr0']
|
648
|
-
params['warmup_bias_lr'] = 0.0 if 'warmup_bias_lr' not in params else params['warmup_bias_lr']
|
649
|
-
params['weight_decay'] = 0.025 if 'weight_decay' not in params else params['weight_decay']
|
650
|
-
params['momentum'] = 0.9 if 'momentum' not in params else params['momentum']
|
651
|
-
params['close_mosaic'] = 10 if 'close_mosaic' not in params else params['close_mosaic']
|
652
|
-
|
653
723
|
# Return the dictionary of parameters
|
654
724
|
return params
|
655
725
|
|
@@ -2,9 +2,11 @@
|
|
2
2
|
from .QtTrainModel import TrainModelDialog
|
3
3
|
from .QtBatchInference import BatchInferenceDialog
|
4
4
|
from .QtDeployPredictor import DeployPredictorDialog
|
5
|
+
from .QtDeployGenerator import DeployGeneratorDialog
|
5
6
|
|
6
7
|
__all__ = [
|
7
8
|
'TrainModelDialog',
|
8
9
|
'BatchInferenceDialog'
|
9
10
|
'DeployPredictorDialog',
|
11
|
+
'DeployGeneratorDialog',
|
10
12
|
]
|
@@ -59,9 +59,14 @@ class ResizeSubTool(SubTool):
|
|
59
59
|
def mouseReleaseEvent(self, event):
|
60
60
|
"""Finalize the resize, update related windows, and deactivate."""
|
61
61
|
if self.target_annotation:
|
62
|
+
# Normalize the coordinates after resize is complete
|
63
|
+
if hasattr(self.target_annotation, 'normalize_coordinates'):
|
64
|
+
self.target_annotation.normalize_coordinates()
|
65
|
+
|
62
66
|
self.target_annotation.create_cropped_image(self.annotation_window.rasterio_image)
|
63
67
|
self.parent_tool.main_window.confidence_window.display_cropped_image(self.target_annotation)
|
64
|
-
|
68
|
+
self.annotation_window.annotationModified.emit(self.target_annotation.id) # Emit modified signal
|
69
|
+
|
65
70
|
self.parent_tool.deactivate_subtool()
|
66
71
|
|
67
72
|
# --- Handle Management Logic (moved from original class) ---
|
@@ -64,6 +64,11 @@ class SAMTool(Tool):
|
|
64
64
|
|
65
65
|
# Flag to track if we have active prompts
|
66
66
|
self.has_active_prompts = False
|
67
|
+
|
68
|
+
# Add state variables for custom working area creation
|
69
|
+
self.creating_working_area = False
|
70
|
+
self.working_area_start = None
|
71
|
+
self.working_area_temp_graphics = None
|
67
72
|
|
68
73
|
def activate(self):
|
69
74
|
"""
|
@@ -82,6 +87,7 @@ class SAMTool(Tool):
|
|
82
87
|
self.sam_dialog = None
|
83
88
|
self.cancel_working_area()
|
84
89
|
self.has_active_prompts = False
|
90
|
+
self.cancel_working_area_creation()
|
85
91
|
|
86
92
|
def set_working_area(self):
|
87
93
|
"""
|
@@ -141,7 +147,116 @@ class SAMTool(Tool):
|
|
141
147
|
|
142
148
|
self.annotation_window.setCursor(Qt.CrossCursor)
|
143
149
|
self.annotation_window.scene.update()
|
144
|
-
|
150
|
+
|
151
|
+
def set_custom_working_area(self, start_point, end_point):
|
152
|
+
"""
|
153
|
+
Create a working area from custom points selected by the user.
|
154
|
+
|
155
|
+
Args:
|
156
|
+
start_point (QPointF): First corner of the working area
|
157
|
+
end_point (QPointF): Opposite corner of the working area
|
158
|
+
"""
|
159
|
+
self.annotation_window.setCursor(Qt.WaitCursor)
|
160
|
+
|
161
|
+
# Cancel any existing working area
|
162
|
+
self.cancel_working_area()
|
163
|
+
|
164
|
+
# Calculate the rectangle bounds
|
165
|
+
left = max(0, int(min(start_point.x(), end_point.x())))
|
166
|
+
top = max(0, int(min(start_point.y(), end_point.y())))
|
167
|
+
right = min(int(self.annotation_window.pixmap_image.size().width()),
|
168
|
+
int(max(start_point.x(), end_point.x())))
|
169
|
+
bottom = min(int(self.annotation_window.pixmap_image.size().height()),
|
170
|
+
int(max(start_point.y(), end_point.y())))
|
171
|
+
|
172
|
+
# Ensure minimum size (at least 10x10 pixels)
|
173
|
+
if right - left < 10:
|
174
|
+
right = min(left + 10, int(self.annotation_window.pixmap_image.size().width()))
|
175
|
+
if bottom - top < 10:
|
176
|
+
bottom = min(top + 10, int(self.annotation_window.pixmap_image.size().height()))
|
177
|
+
|
178
|
+
# Original image information
|
179
|
+
self.image_path = self.annotation_window.current_image_path
|
180
|
+
self.original_image = pixmap_to_numpy(self.annotation_window.pixmap_image)
|
181
|
+
self.original_width = self.annotation_window.pixmap_image.size().width()
|
182
|
+
self.original_height = self.annotation_window.pixmap_image.size().height()
|
183
|
+
|
184
|
+
# Create the WorkArea instance
|
185
|
+
self.working_area = WorkArea(left, top, right - left, bottom - top, self.image_path)
|
186
|
+
|
187
|
+
# Get the thickness for the working area graphics
|
188
|
+
pen_width = self.graphics_utility.get_workarea_thickness(self.annotation_window)
|
189
|
+
|
190
|
+
# Create and add the working area graphics
|
191
|
+
self.working_area.create_graphics(self.annotation_window.scene, pen_width)
|
192
|
+
self.working_area.set_remove_button_visibility(False)
|
193
|
+
self.working_area.removed.connect(self.on_working_area_removed)
|
194
|
+
|
195
|
+
# Create shadow overlay
|
196
|
+
shadow_brush = QBrush(QColor(0, 0, 0, 150))
|
197
|
+
shadow_path = QPainterPath()
|
198
|
+
shadow_path.addRect(self.annotation_window.scene.sceneRect())
|
199
|
+
shadow_path.addRect(self.working_area.rect)
|
200
|
+
shadow_path = shadow_path.simplified()
|
201
|
+
|
202
|
+
self.shadow_area = QGraphicsPathItem(shadow_path)
|
203
|
+
self.shadow_area.setBrush(shadow_brush)
|
204
|
+
self.shadow_area.setPen(QPen(Qt.NoPen))
|
205
|
+
self.annotation_window.scene.addItem(self.shadow_area)
|
206
|
+
|
207
|
+
# Update the working area image in the SAM model
|
208
|
+
self.image = self.original_image[top:bottom, left:right]
|
209
|
+
self.sam_dialog.set_image(self.image, self.image_path)
|
210
|
+
|
211
|
+
self.annotation_window.setCursor(Qt.CrossCursor)
|
212
|
+
self.annotation_window.scene.update()
|
213
|
+
|
214
|
+
def display_working_area_preview(self, current_pos):
|
215
|
+
"""
|
216
|
+
Display a preview rectangle for the working area being created.
|
217
|
+
|
218
|
+
Args:
|
219
|
+
current_pos (QPointF): Current mouse position
|
220
|
+
"""
|
221
|
+
if not self.working_area_start:
|
222
|
+
return
|
223
|
+
|
224
|
+
# Remove previous preview if it exists
|
225
|
+
if self.working_area_temp_graphics:
|
226
|
+
self.annotation_window.scene.removeItem(self.working_area_temp_graphics)
|
227
|
+
self.working_area_temp_graphics = None
|
228
|
+
|
229
|
+
# Create preview rectangle
|
230
|
+
rect = QRectF(
|
231
|
+
min(self.working_area_start.x(), current_pos.x()),
|
232
|
+
min(self.working_area_start.y(), current_pos.y()),
|
233
|
+
abs(current_pos.x() - self.working_area_start.x()),
|
234
|
+
abs(current_pos.y() - self.working_area_start.y())
|
235
|
+
)
|
236
|
+
|
237
|
+
# Create a dashed blue pen for the working area preview
|
238
|
+
pen = QPen(QColor(0, 120, 215))
|
239
|
+
pen.setStyle(Qt.DashLine)
|
240
|
+
pen.setWidth(2)
|
241
|
+
|
242
|
+
self.working_area_temp_graphics = QGraphicsRectItem(rect)
|
243
|
+
self.working_area_temp_graphics.setPen(pen)
|
244
|
+
self.working_area_temp_graphics.setBrush(QBrush(QColor(0, 120, 215, 30))) # Light blue transparent fill
|
245
|
+
self.annotation_window.scene.addItem(self.working_area_temp_graphics)
|
246
|
+
|
247
|
+
def cancel_working_area_creation(self):
|
248
|
+
"""
|
249
|
+
Cancel the process of creating a working area.
|
250
|
+
"""
|
251
|
+
self.creating_working_area = False
|
252
|
+
self.working_area_start = None
|
253
|
+
|
254
|
+
if self.working_area_temp_graphics:
|
255
|
+
self.annotation_window.scene.removeItem(self.working_area_temp_graphics)
|
256
|
+
self.working_area_temp_graphics = None
|
257
|
+
|
258
|
+
self.annotation_window.scene.update()
|
259
|
+
|
145
260
|
def on_working_area_removed(self, work_area):
|
146
261
|
"""
|
147
262
|
Handle when the work area is removed via its internal mechanism.
|
@@ -357,11 +472,24 @@ class SAMTool(Tool):
|
|
357
472
|
"A label must be selected before adding an annotation.")
|
358
473
|
return
|
359
474
|
|
360
|
-
if not self.working_area:
|
361
|
-
return
|
362
|
-
|
363
475
|
# Get position in scene coordinates
|
364
476
|
scene_pos = self.annotation_window.mapToScene(event.pos())
|
477
|
+
|
478
|
+
# Handle working area creation mode
|
479
|
+
if not self.working_area and event.button() == Qt.LeftButton:
|
480
|
+
if not self.creating_working_area:
|
481
|
+
# Start working area creation
|
482
|
+
self.creating_working_area = True
|
483
|
+
self.working_area_start = scene_pos
|
484
|
+
return
|
485
|
+
elif self.creating_working_area and self.working_area_start:
|
486
|
+
# Finish working area creation
|
487
|
+
self.set_custom_working_area(self.working_area_start, scene_pos)
|
488
|
+
self.cancel_working_area_creation()
|
489
|
+
return
|
490
|
+
|
491
|
+
if not self.working_area:
|
492
|
+
return
|
365
493
|
|
366
494
|
# Check if position is within working area
|
367
495
|
if not self.working_area.contains_point(scene_pos):
|
@@ -430,12 +558,17 @@ class SAMTool(Tool):
|
|
430
558
|
"""
|
431
559
|
Handle mouse move events.
|
432
560
|
"""
|
433
|
-
if not self.working_area:
|
434
|
-
return
|
435
|
-
|
436
561
|
scene_pos = self.annotation_window.mapToScene(event.pos())
|
437
562
|
self.hover_pos = scene_pos
|
438
563
|
|
564
|
+
# Update working area preview during creation
|
565
|
+
if self.creating_working_area and self.working_area_start:
|
566
|
+
self.display_working_area_preview(scene_pos)
|
567
|
+
return
|
568
|
+
|
569
|
+
if not self.working_area:
|
570
|
+
return
|
571
|
+
|
439
572
|
# Update rectangle during drawing
|
440
573
|
if self.drawing_rectangle and self.start_point:
|
441
574
|
self.end_point = scene_pos
|
@@ -461,6 +594,11 @@ class SAMTool(Tool):
|
|
461
594
|
Handle key press events.
|
462
595
|
"""
|
463
596
|
if event.key() == Qt.Key_Space:
|
597
|
+
# If creating working area, confirm it
|
598
|
+
if self.creating_working_area and self.working_area_start and self.hover_pos:
|
599
|
+
self.set_custom_working_area(self.working_area_start, self.hover_pos)
|
600
|
+
self.cancel_working_area_creation()
|
601
|
+
return
|
464
602
|
|
465
603
|
# If no working area, set it up
|
466
604
|
if not self.working_area:
|
@@ -507,6 +645,11 @@ class SAMTool(Tool):
|
|
507
645
|
self.annotation_window.scene.update()
|
508
646
|
|
509
647
|
elif event.key() == Qt.Key_Backspace:
|
648
|
+
# If creating working area, cancel it
|
649
|
+
if self.creating_working_area:
|
650
|
+
self.cancel_working_area_creation()
|
651
|
+
return
|
652
|
+
|
510
653
|
# If drawing rectangle, cancel it
|
511
654
|
if self.drawing_rectangle:
|
512
655
|
self.cancel_rectangle_drawing()
|