coralnet-toolbox 0.0.72__py2.py3-none-any.whl → 0.0.74__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
  2. coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
  3. coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
  4. coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
  5. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
  6. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
  7. coralnet_toolbox/AutoDistill/QtDeployModel.py +23 -12
  8. coralnet_toolbox/CoralNet/QtDownload.py +2 -1
  9. coralnet_toolbox/Explorer/QtDataItem.py +1 -1
  10. coralnet_toolbox/Explorer/QtExplorer.py +159 -17
  11. coralnet_toolbox/Explorer/QtSettingsWidgets.py +160 -86
  12. coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
  13. coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
  14. coralnet_toolbox/IO/QtOpenProject.py +46 -78
  15. coralnet_toolbox/IO/QtSaveProject.py +18 -43
  16. coralnet_toolbox/MachineLearning/DeployModel/QtDetect.py +22 -11
  17. coralnet_toolbox/MachineLearning/DeployModel/QtSegment.py +22 -10
  18. coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +61 -24
  19. coralnet_toolbox/MachineLearning/ExportDataset/QtClassify.py +5 -1
  20. coralnet_toolbox/MachineLearning/ExportDataset/QtDetect.py +19 -6
  21. coralnet_toolbox/MachineLearning/ExportDataset/QtSegment.py +21 -8
  22. coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +42 -22
  23. coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
  24. coralnet_toolbox/QtAnnotationWindow.py +42 -14
  25. coralnet_toolbox/QtEventFilter.py +19 -2
  26. coralnet_toolbox/QtImageWindow.py +134 -86
  27. coralnet_toolbox/QtLabelWindow.py +14 -2
  28. coralnet_toolbox/QtMainWindow.py +122 -9
  29. coralnet_toolbox/QtProgressBar.py +52 -27
  30. coralnet_toolbox/Rasters/QtRaster.py +59 -7
  31. coralnet_toolbox/Rasters/RasterTableModel.py +42 -14
  32. coralnet_toolbox/SAM/QtBatchInference.py +0 -2
  33. coralnet_toolbox/SAM/QtDeployGenerator.py +22 -11
  34. coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
  35. coralnet_toolbox/SeeAnything/QtBatchInference.py +19 -221
  36. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +1634 -0
  37. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +107 -154
  38. coralnet_toolbox/SeeAnything/QtTrainModel.py +115 -45
  39. coralnet_toolbox/SeeAnything/__init__.py +2 -0
  40. coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
  41. coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
  42. coralnet_toolbox/Tools/QtSAMTool.py +222 -57
  43. coralnet_toolbox/Tools/QtSeeAnythingTool.py +223 -55
  44. coralnet_toolbox/Tools/QtSelectSubTool.py +6 -4
  45. coralnet_toolbox/Tools/QtSelectTool.py +27 -3
  46. coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
  47. coralnet_toolbox/Tools/QtWorkAreaTool.py +25 -13
  48. coralnet_toolbox/Tools/__init__.py +2 -0
  49. coralnet_toolbox/__init__.py +1 -1
  50. coralnet_toolbox/utilities.py +137 -47
  51. coralnet_toolbox-0.0.74.dist-info/METADATA +375 -0
  52. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/RECORD +56 -53
  53. coralnet_toolbox-0.0.72.dist-info/METADATA +0 -341
  54. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/WHEEL +0 -0
  55. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/entry_points.txt +0 -0
  56. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/licenses/LICENSE.txt +0 -0
  57. {coralnet_toolbox-0.0.72.dist-info → coralnet_toolbox-0.0.74.dist-info}/top_level.txt +0 -0
@@ -2,23 +2,23 @@ import warnings
2
2
 
3
3
  import os
4
4
  import gc
5
- import ujson as json
6
5
 
7
6
  import numpy as np
8
7
 
9
- from PyQt5.QtCore import Qt
10
- from PyQt5.QtGui import QColor
11
- from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog, QFormLayout,
12
- QHBoxLayout, QLabel, QMessageBox, QPushButton,
13
- QSlider, QSpinBox, QVBoxLayout, QGroupBox, QTabWidget,
14
- QWidget, QLineEdit, QFileDialog)
8
+ import torch
9
+ from torch.cuda import empty_cache
10
+ from ultralytics.utils import ops
15
11
 
16
12
  from ultralytics import YOLOE
17
13
  from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor
18
14
  from ultralytics.models.yolo.yoloe import YOLOEVPDetectPredictor
19
15
 
20
- from torch.cuda import empty_cache
21
- from ultralytics.utils import ops
16
+ from PyQt5.QtCore import Qt
17
+ from PyQt5.QtGui import QColor
18
+ from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog, QFormLayout,
19
+ QHBoxLayout, QLabel, QMessageBox, QPushButton,
20
+ QSlider, QSpinBox, QVBoxLayout, QGroupBox,
21
+ QWidget, QLineEdit, QFileDialog)
22
22
 
23
23
  from coralnet_toolbox.Results import ResultsProcessor
24
24
 
@@ -98,7 +98,10 @@ class DeployPredictorDialog(QDialog):
98
98
  layout = QVBoxLayout()
99
99
 
100
100
  # Create a QLabel with explanatory text and hyperlink
101
- info_label = QLabel("Choose a Predictor to deploy and use interactively with the See Anything tool.")
101
+ info_label = QLabel(
102
+ "Choose a Predictor to deploy and use interactively with the See Anything tool. "
103
+ "Optionally include a custom visual prompt encoding (VPE) file."
104
+ )
102
105
 
103
106
  info_label.setOpenExternalLinks(True)
104
107
  info_label.setWordWrap(True)
@@ -109,21 +112,15 @@ class DeployPredictorDialog(QDialog):
109
112
 
110
113
  def setup_models_layout(self):
111
114
  """
112
- Setup the models layout with tabs for standard and custom models.
115
+ Setup the models layout with standard models and file selection.
113
116
  """
114
117
  group_box = QGroupBox("Model Selection")
115
- layout = QVBoxLayout()
116
-
117
- # Create tabbed widget
118
- tab_widget = QTabWidget()
119
-
120
- # Tab 1: Standard models
121
- standard_tab = QWidget()
122
- standard_layout = QVBoxLayout(standard_tab)
123
-
118
+ layout = QFormLayout()
119
+
120
+ # Model dropdown
124
121
  self.model_combo = QComboBox()
125
122
  self.model_combo.setEditable(True)
126
-
123
+
127
124
  # Define available models
128
125
  standard_models = [
129
126
  'yoloe-v8s-seg.pt',
@@ -133,83 +130,18 @@ class DeployPredictorDialog(QDialog):
133
130
  'yoloe-11m-seg.pt',
134
131
  'yoloe-11l-seg.pt',
135
132
  ]
136
-
133
+
137
134
  # Add all models to combo box
138
135
  self.model_combo.addItems(standard_models)
136
+
139
137
  # Set the default model
140
138
  self.model_combo.setCurrentIndex(standard_models.index('yoloe-v8s-seg.pt'))
141
-
142
- standard_layout.addWidget(QLabel("Models"))
143
- standard_layout.addWidget(self.model_combo)
144
-
145
- tab_widget.addTab(standard_tab, "Use Existing Model")
146
-
147
- # Tab 2: Custom model
148
- custom_tab = QWidget()
149
- custom_layout = QFormLayout(custom_tab)
150
-
151
- # Custom model file selection
152
- self.model_path_edit = QLineEdit()
153
- browse_button = QPushButton("Browse...")
154
- browse_button.clicked.connect(self.browse_model_file)
155
-
156
- model_path_layout = QHBoxLayout()
157
- model_path_layout.addWidget(self.model_path_edit)
158
- model_path_layout.addWidget(browse_button)
159
- custom_layout.addRow("Custom Model:", model_path_layout)
160
-
161
- # Class Mapping
162
- self.mapping_edit = QLineEdit()
163
- self.mapping_button = QPushButton("Browse...")
164
- self.mapping_button.clicked.connect(self.browse_class_mapping_file)
165
-
166
- class_mapping_layout = QHBoxLayout()
167
- class_mapping_layout.addWidget(self.mapping_edit)
168
- class_mapping_layout.addWidget(self.mapping_button)
169
- custom_layout.addRow("Class Mapping:", class_mapping_layout)
170
-
171
- tab_widget.addTab(custom_tab, "Custom Model")
172
-
173
- # Add the tab widget to the main layout
174
- layout.addWidget(tab_widget)
175
-
176
- # Store the tab widget for later reference
177
- self.model_tab_widget = tab_widget
178
-
139
+ # Create a layout for the model selection
140
+ layout.addRow("Models:", self.model_combo)
141
+
179
142
  group_box.setLayout(layout)
180
143
  self.layout.addWidget(group_box)
181
144
 
182
- def browse_model_file(self):
183
- """
184
- Open a file dialog to browse for a model file.
185
- """
186
- file_path, _ = QFileDialog.getOpenFileName(self,
187
- "Select Model File",
188
- "",
189
- "Model Files (*.pt *.pth);;All Files (*)")
190
- if file_path:
191
- self.model_path_edit.setText(file_path)
192
-
193
- # Load the class mapping if it exists
194
- dir_path = os.path.dirname(os.path.dirname(file_path))
195
- class_mapping_path = f"{dir_path}/class_mapping.json"
196
- if os.path.exists(class_mapping_path):
197
- self.class_mapping = json.load(open(class_mapping_path, 'r'))
198
- self.mapping_edit.setText(class_mapping_path)
199
-
200
- def browse_class_mapping_file(self):
201
- """
202
- Browse and select a class mapping file.
203
- """
204
- file_path, _ = QFileDialog.getOpenFileName(self,
205
- "Select Class Mapping File",
206
- "",
207
- "JSON Files (*.json)")
208
- if file_path:
209
- # Load the class mapping
210
- self.class_mapping = json.load(open(file_path, 'r'))
211
- self.mapping_edit.setText(file_path)
212
-
213
145
  def setup_parameters_layout(self):
214
146
  """
215
147
  Setup parameter control section in a group box.
@@ -412,46 +344,43 @@ class DeployPredictorDialog(QDialog):
412
344
  QApplication.setOverrideCursor(Qt.WaitCursor)
413
345
  progress_bar = ProgressBar(self.annotation_window, title="Loading Model")
414
346
  progress_bar.show()
415
-
347
+
416
348
  try:
417
349
  # Get selected model path and download weights if needed
418
350
  self.model_path = self.model_combo.currentText()
419
-
351
+
420
352
  # Load model using registry
421
353
  self.loaded_model = YOLOE(self.model_path).to(self.main_window.device)
422
-
423
- # Create a dummy visual dictionary
354
+
355
+ # Create a dummy visual dictionary for standard model loading
424
356
  visuals = dict(
425
357
  bboxes=np.array(
426
358
  [
427
- [120, 425, 160, 445],
359
+ [120, 425, 160, 445], # Random box
428
360
  ],
429
361
  ),
430
362
  cls=np.array(
431
363
  np.zeros(1),
432
364
  ),
433
365
  )
434
-
366
+
435
367
  # Run a dummy prediction to load the model
436
368
  self.loaded_model.predict(
437
369
  np.zeros((640, 640, 3), dtype=np.uint8),
438
- visual_prompts=visuals.copy(),
439
- predictor=YOLOEVPDetectPredictor,
370
+ visual_prompts=visuals.copy(), # This needs to happen to properly initialize the predictor
371
+ predictor=YOLOEVPSegPredictor, # This also needs to be SegPredictor, no matter what
440
372
  imgsz=640,
441
373
  conf=0.99,
442
374
  )
443
375
 
444
- # Load the model class names if available
445
- if self.class_mapping:
446
- self.add_labels_to_label_window()
447
-
448
- progress_bar.finish_progress()
449
- self.status_bar.setText("Model loaded")
376
+ self.status_bar.setText(f"Loaded ({self.model_path}")
450
377
  QMessageBox.information(self.annotation_window, "Model Loaded", "Model loaded successfully")
451
378
 
452
379
  except Exception as e:
380
+ self.loaded_model = None
381
+ self.status_bar.setText(f"Error loading model: {self.model_path}")
453
382
  QMessageBox.critical(self.annotation_window, "Error Loading Model", f"Error loading model: {e}")
454
-
383
+
455
384
  finally:
456
385
  # Restore cursor
457
386
  QApplication.restoreOverrideCursor()
@@ -460,18 +389,6 @@ class DeployPredictorDialog(QDialog):
460
389
  progress_bar.close()
461
390
  progress_bar = None
462
391
 
463
- self.accept()
464
-
465
- def add_labels_to_label_window(self):
466
- """
467
- Add labels to the label window based on the class mapping.
468
- """
469
- if self.class_mapping:
470
- for label in self.class_mapping.values():
471
- self.main_window.label_window.add_label_if_not_exists(label['short_label_code'],
472
- label['long_label_code'],
473
- QColor(*label['color']))
474
-
475
392
  def resize_image(self, image):
476
393
  """
477
394
  Resize the image to the specified size.
@@ -517,9 +434,6 @@ class DeployPredictorDialog(QDialog):
517
434
  # Open the image using rasterio
518
435
  image = rasterio_to_numpy(self.main_window.image_window.rasterio_images[image_path])
519
436
 
520
- # Preprocess the image
521
- # image = preprocess_image(image)
522
-
523
437
  # Save the original image
524
438
  self.original_image = image
525
439
  self.image_path = image_path
@@ -529,19 +443,57 @@ class DeployPredictorDialog(QDialog):
529
443
  self.resized_image = self.resize_image(image)
530
444
  else:
531
445
  self.resized_image = image
446
+
447
+ def scale_prompts(self, bboxes, masks=None):
448
+ """
449
+ Scale the bounding boxes and masks to the resized image.
450
+ """
451
+ # Update the bbox coordinates to be relative to the resized image
452
+ bboxes = np.array(bboxes)
453
+ bboxes[:, 0] = (bboxes[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
454
+ bboxes[:, 1] = (bboxes[:, 1] / self.original_image.shape[0]) * self.resized_image.shape[0]
455
+ bboxes[:, 2] = (bboxes[:, 2] / self.original_image.shape[1]) * self.resized_image.shape[1]
456
+ bboxes[:, 3] = (bboxes[:, 3] / self.original_image.shape[0]) * self.resized_image.shape[0]
457
+
458
+ # Set the predictor
459
+ self.task = self.task_dropdown.currentText()
532
460
 
533
- def predict_from_prompts(self, bboxes):
461
+ # Create a visual dictionary
462
+ visual_prompts = {
463
+ 'bboxes': np.array(bboxes),
464
+ 'cls': np.zeros(len(bboxes))
465
+ }
466
+ if self.task == 'segment':
467
+ if masks:
468
+ scaled_masks = []
469
+ for mask in masks:
470
+ scaled_mask = np.array(mask, dtype=np.float32)
471
+ scaled_mask[:, 0] = (scaled_mask[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
472
+ scaled_mask[:, 1] = (scaled_mask[:, 1] / self.original_image.shape[0]) * self.resized_image.shape[0]
473
+ scaled_masks.append(scaled_mask)
474
+ visual_prompts['masks'] = scaled_masks
475
+ else: # Fallback to creating masks from bboxes if no masks are provided
476
+ fallback_masks = []
477
+ for bbox in bboxes:
478
+ x1, y1, x2, y2 = bbox
479
+ fallback_masks.append(np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]))
480
+ visual_prompts['masks'] = fallback_masks
481
+
482
+ return visual_prompts
483
+
484
+ def predict_from_prompts(self, bboxes, masks=None):
534
485
  """
535
486
  Make predictions using the currently loaded model using prompts.
536
487
 
537
488
  Args:
538
- bbox (np.ndarray): The bounding boxes to use as prompts.
489
+ bboxes (np.ndarray): The bounding boxes to use as prompts.
490
+ masks (list, optional): A list of polygons to use as prompts for segmentation.
539
491
 
540
492
  Returns:
541
493
  results (Results): Ultralytics Results object
542
494
  """
543
495
  if not self.loaded_model:
544
- QMessageBox.critical(self.annotation_window,
496
+ QMessageBox.critical(self.annotation_window,
545
497
  "Model Not Loaded",
546
498
  "Model not loaded, cannot make predictions")
547
499
  return None
@@ -549,28 +501,14 @@ class DeployPredictorDialog(QDialog):
549
501
  if not len(bboxes):
550
502
  return None
551
503
 
552
- # Update the bbox coordinates to be relative to the resized image
553
- bboxes = np.array(bboxes)
554
- bboxes[:, 0] = (bboxes[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
555
- bboxes[:, 1] = (bboxes[:, 1] / self.original_image.shape[0]) * self.resized_image.shape[0]
556
- bboxes[:, 2] = (bboxes[:, 2] / self.original_image.shape[1]) * self.resized_image.shape[1]
557
- bboxes[:, 3] = (bboxes[:, 3] / self.original_image.shape[0]) * self.resized_image.shape[0]
558
-
559
- # Create a visual dictionary
560
- visuals = {
561
- 'bboxes': np.array(bboxes),
562
- 'cls': np.zeros(len(bboxes)) # TODO figure this out
563
- }
564
-
565
- # Set the predictor
566
- self.task = self.task_dropdown.currentText()
567
- predictor = YOLOEVPSegPredictor if self.task == "segment" else YOLOEVPDetectPredictor
504
+ # Get the scaled visual prompts
505
+ visual_prompts = self.scale_prompts(bboxes, masks)
568
506
 
569
507
  try:
570
508
  # Make predictions
571
509
  results = self.loaded_model.predict(self.resized_image,
572
- visual_prompts=visuals.copy(),
573
- predictor=predictor,
510
+ visual_prompts=visual_prompts.copy(),
511
+ predictor=YOLOEVPSegPredictor,
574
512
  imgsz=max(self.resized_image.shape[:2]),
575
513
  conf=self.main_window.get_uncertainty_thresh(),
576
514
  iou=self.main_window.get_iou_thresh(),
@@ -590,7 +528,7 @@ class DeployPredictorDialog(QDialog):
590
528
 
591
529
  return results
592
530
 
593
- def predict_from_annotations(self, refer_image, refer_label, refer_annotations, target_images):
531
+ def predict_from_annotations(self, refer_image, refer_label, refer_bboxes, refer_masks, target_images):
594
532
  """"""
595
533
  # Create a class mapping
596
534
  class_mapping = {0: refer_label}
@@ -605,15 +543,30 @@ class DeployPredictorDialog(QDialog):
605
543
  max_area_thresh=self.main_window.get_area_thresh_max()
606
544
  )
607
545
 
608
- # Create a visual dictionary
609
- visuals = {
610
- 'bboxes': np.array(refer_annotations),
611
- 'cls': np.zeros(len(refer_annotations))
612
- }
613
-
614
- # Set the predictor
615
- self.task = self.task_dropdown.currentText()
616
- predictor = YOLOEVPSegPredictor if self.task == "segment" else YOLOEVPDetectPredictor
546
+ # Get the scaled visual prompts
547
+ visual_prompts = self.scale_prompts(refer_bboxes, refer_masks)
548
+
549
+ # If VPEs are being used
550
+ if self.vpe is not None:
551
+ # Generate a new VPE from the current visual prompts
552
+ new_vpe = self.prompts_to_vpes(visual_prompts, self.resized_image)
553
+
554
+ if new_vpe is not None:
555
+ # If we already have a VPE, average with the existing one
556
+ if self.vpe.shape == new_vpe.shape:
557
+ self.vpe = (self.vpe + new_vpe) / 2
558
+ # Re-normalize
559
+ self.vpe = torch.nn.functional.normalize(self.vpe, p=2, dim=-1)
560
+ else:
561
+ # Replace with the new VPE if shapes don't match
562
+ self.vpe = new_vpe
563
+
564
+ # Set the updated VPE in the model
565
+ self.loaded_model.is_fused = lambda: False
566
+ self.loaded_model.set_classes(["object0"], self.vpe)
567
+
568
+ # Clear visual prompts since we're using VPE
569
+ visual_prompts = {} # this is okay with a fused model
617
570
 
618
571
  # Create a progress bar
619
572
  QApplication.setOverrideCursor(Qt.WaitCursor)
@@ -627,8 +580,8 @@ class DeployPredictorDialog(QDialog):
627
580
  # Make predictions
628
581
  results = self.loaded_model.predict(target_image,
629
582
  refer_image=refer_image,
630
- visual_prompts=visuals.copy(),
631
- predictor=predictor,
583
+ visual_prompts=visual_prompts.copy(),
584
+ predictor=YOLOEVPSegPredictor,
632
585
  imgsz=self.imgsz_spinbox.value(),
633
586
  conf=self.main_window.get_uncertainty_thresh(),
634
587
  iou=self.main_window.get_iou_thresh(),
@@ -13,7 +13,7 @@ from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer
13
13
  from PyQt5.QtCore import Qt, QThread, pyqtSignal
14
14
  from PyQt5.QtWidgets import (QFileDialog, QScrollArea, QMessageBox, QCheckBox, QWidget, QVBoxLayout,
15
15
  QLabel, QLineEdit, QDialog, QHBoxLayout, QPushButton, QComboBox, QSpinBox,
16
- QFormLayout, QTabWidget, QDoubleSpinBox, QGroupBox, )
16
+ QFormLayout, QTabWidget, QDoubleSpinBox, QGroupBox)
17
17
 
18
18
  from torch.cuda import empty_cache
19
19
 
@@ -109,7 +109,7 @@ class TrainModelWorker(QThread):
109
109
 
110
110
  # Train the model
111
111
  self.model.train(**self.params,
112
- trainer=YOLOEPESegTrainer,
112
+ trainer=YOLOEVPTrainer , #YOLOEPESegTrainer,
113
113
  device=self.device)
114
114
 
115
115
  # Post-run cleanup
@@ -222,6 +222,13 @@ class TrainModelDialog(QDialog):
222
222
  # Task specific parameters
223
223
  self.imgsz = 640
224
224
  self.batch = 4
225
+
226
+ # Training parameters with defaults for linear-probing
227
+ self._lr0 = 1e-3
228
+ self._warmup_bias_lr = 0.0
229
+ self._weight_decay = 0.025
230
+ self._momentum = 0.9
231
+ self._close_mosaic = 0 # Default for linear-probing
225
232
 
226
233
  # Create the layout
227
234
  self.layout = QVBoxLayout(self)
@@ -232,6 +239,8 @@ class TrainModelDialog(QDialog):
232
239
  self.setup_dataset_layout()
233
240
  # Create the model layout (new)
234
241
  self.setup_model_layout()
242
+ # Create the output parameters layout
243
+ self.setup_output_parameters_layout()
235
244
  # Create and set up the parameters layout
236
245
  self.setup_parameters_layout()
237
246
  # Create the buttons layout
@@ -327,6 +336,29 @@ class TrainModelDialog(QDialog):
327
336
  layout.addWidget(tab_widget)
328
337
  group_box.setLayout(layout)
329
338
  self.layout.addWidget(group_box)
339
+
340
+ def setup_output_parameters_layout(self):
341
+ """
342
+ Set up the layout and widgets for output parameters (project directory and name).
343
+ """
344
+ group_box = QGroupBox("Output")
345
+ layout = QFormLayout()
346
+
347
+ # Project
348
+ self.project_edit = QLineEdit()
349
+ self.project_button = QPushButton("Browse...")
350
+ self.project_button.clicked.connect(self.browse_project_dir)
351
+ project_layout = QHBoxLayout()
352
+ project_layout.addWidget(self.project_edit)
353
+ project_layout.addWidget(self.project_button)
354
+ layout.addRow("Project:", project_layout)
355
+
356
+ # Name
357
+ self.name_edit = QLineEdit()
358
+ layout.addRow("Name:", self.name_edit)
359
+
360
+ group_box.setLayout(layout)
361
+ self.layout.addWidget(group_box)
330
362
 
331
363
  def setup_parameters_layout(self):
332
364
  """
@@ -352,19 +384,6 @@ class TrainModelDialog(QDialog):
352
384
  group_layout = QVBoxLayout(group_box)
353
385
  group_layout.addWidget(scroll_area)
354
386
 
355
- # Project
356
- self.project_edit = QLineEdit()
357
- self.project_button = QPushButton("Browse...")
358
- self.project_button.clicked.connect(self.browse_project_dir)
359
- project_layout = QHBoxLayout()
360
- project_layout.addWidget(self.project_edit)
361
- project_layout.addWidget(self.project_button)
362
- form_layout.addRow("Project:", project_layout)
363
-
364
- # Name
365
- self.name_edit = QLineEdit()
366
- form_layout.addRow("Name:", self.name_edit)
367
-
368
387
  # Fine-tune or Linear Probing
369
388
  self.training_mode = QComboBox()
370
389
  self.training_mode.addItems(["linear-probe", "fine-tune"])
@@ -381,11 +400,60 @@ class TrainModelDialog(QDialog):
381
400
 
382
401
  # Patience
383
402
  self.patience_spinbox = QSpinBox()
384
- self.patience_spinbox.setMinimum(1)
403
+ self.patience_spinbox.setMinimum(0) # Changed minimum to 0 to allow for 0 patience
385
404
  self.patience_spinbox.setMaximum(1000)
386
- self.patience_spinbox.setValue(30)
405
+ self.patience_spinbox.setValue(0) # Default for linear-probing
387
406
  form_layout.addRow("Patience:", self.patience_spinbox)
388
407
 
408
+ # Close Mosaic
409
+ self.close_mosaic_spinbox = QSpinBox()
410
+ self.close_mosaic_spinbox.setMinimum(0)
411
+ self.close_mosaic_spinbox.setMaximum(1000)
412
+ self.close_mosaic_spinbox.setValue(0) # Default for linear-probing
413
+ form_layout.addRow("Close Mosaic:", self.close_mosaic_spinbox)
414
+
415
+ # Optimizer
416
+ self.optimizer_combo = QComboBox()
417
+ self.optimizer_combo.addItems(["auto", "SGD", "Adam", "AdamW", "NAdam", "RAdam", "RMSProp"])
418
+ self.optimizer_combo.setCurrentText("AdamW")
419
+ form_layout.addRow("Optimizer:", self.optimizer_combo)
420
+
421
+ # Learning Rate (lr0)
422
+ self.lr0_spinbox = QDoubleSpinBox()
423
+ self.lr0_spinbox.setDecimals(6)
424
+ self.lr0_spinbox.setMinimum(0.000001)
425
+ self.lr0_spinbox.setMaximum(1.0)
426
+ self.lr0_spinbox.setSingleStep(0.0001)
427
+ self.lr0_spinbox.setValue(self._lr0)
428
+ form_layout.addRow("Learning Rate:", self.lr0_spinbox)
429
+
430
+ # Warmup Bias Learning Rate
431
+ self.warmup_bias_lr_spinbox = QDoubleSpinBox()
432
+ self.warmup_bias_lr_spinbox.setDecimals(6)
433
+ self.warmup_bias_lr_spinbox.setMinimum(0.0)
434
+ self.warmup_bias_lr_spinbox.setMaximum(1.0)
435
+ self.warmup_bias_lr_spinbox.setSingleStep(0.0001)
436
+ self.warmup_bias_lr_spinbox.setValue(self._warmup_bias_lr)
437
+ form_layout.addRow("Warmup Bias LR:", self.warmup_bias_lr_spinbox)
438
+
439
+ # Weight Decay
440
+ self.weight_decay_spinbox = QDoubleSpinBox()
441
+ self.weight_decay_spinbox.setDecimals(6)
442
+ self.weight_decay_spinbox.setMinimum(0.0)
443
+ self.weight_decay_spinbox.setMaximum(1.0)
444
+ self.weight_decay_spinbox.setSingleStep(0.001)
445
+ self.weight_decay_spinbox.setValue(self._weight_decay)
446
+ form_layout.addRow("Weight Decay:", self.weight_decay_spinbox)
447
+
448
+ # Momentum
449
+ self.momentum_spinbox = QDoubleSpinBox()
450
+ self.momentum_spinbox.setDecimals(2)
451
+ self.momentum_spinbox.setMinimum(0.0)
452
+ self.momentum_spinbox.setMaximum(1.0)
453
+ self.momentum_spinbox.setSingleStep(0.01)
454
+ self.momentum_spinbox.setValue(self._momentum)
455
+ form_layout.addRow("Momentum:", self.momentum_spinbox)
456
+
389
457
  # Imgsz
390
458
  self.imgsz_spinbox = QSpinBox()
391
459
  self.imgsz_spinbox.setMinimum(16)
@@ -393,10 +461,6 @@ class TrainModelDialog(QDialog):
393
461
  self.imgsz_spinbox.setValue(self.imgsz)
394
462
  form_layout.addRow("Image Size:", self.imgsz_spinbox)
395
463
 
396
- # Multi Scale
397
- self.multi_scale_combo = create_bool_combo()
398
- form_layout.addRow("Multi-Scale:", self.multi_scale_combo)
399
-
400
464
  # Batch
401
465
  self.batch_spinbox = QSpinBox()
402
466
  self.batch_spinbox.setMinimum(1)
@@ -421,20 +485,6 @@ class TrainModelDialog(QDialog):
421
485
  self.save_period_spinbox.setMaximum(1000)
422
486
  self.save_period_spinbox.setValue(-1)
423
487
  form_layout.addRow("Save Period:", self.save_period_spinbox)
424
-
425
- # Dropout
426
- self.dropout_spinbox = QDoubleSpinBox()
427
- self.dropout_spinbox.setMinimum(0.0)
428
- self.dropout_spinbox.setMaximum(1.0)
429
- self.dropout_spinbox.setValue(0.0)
430
- form_layout.addRow("Dropout:", self.dropout_spinbox)
431
-
432
- # Optimizer
433
- self.optimizer_combo = QComboBox()
434
- self.optimizer_combo.addItems(["auto", "SGD", "Adam", "AdamW", "NAdam", "RAdam", "RMSProp"])
435
- self.optimizer_combo.setCurrentText("AdamW")
436
- form_layout.addRow("Optimizer:", self.optimizer_combo)
437
-
438
488
  # Val
439
489
  self.val_combo = create_bool_combo()
440
490
  form_layout.addRow("Validation:", self.val_combo)
@@ -489,9 +539,32 @@ class TrainModelDialog(QDialog):
489
539
  # Fine-tune mode
490
540
  self.epochs_spinbox.setValue(80)
491
541
  self.patience_spinbox.setValue(20)
542
+ self.close_mosaic_spinbox.setValue(10)
543
+ self._close_mosaic = 10
544
+
545
+ # These parameters stay the same for both modes
546
+ self.lr0_spinbox.setValue(1e-3)
547
+ self.warmup_bias_lr_spinbox.setValue(0.0)
548
+ self.weight_decay_spinbox.setValue(0.025)
549
+ self.momentum_spinbox.setValue(0.9)
550
+
551
+ # Ensure optimizer is set to AdamW
552
+ self.optimizer_combo.setCurrentText("AdamW")
492
553
  else:
554
+ # Linear-probing mode
493
555
  self.epochs_spinbox.setValue(2)
494
- self.patience_spinbox.setValue(1)
556
+ self.patience_spinbox.setValue(0)
557
+ self.close_mosaic_spinbox.setValue(0)
558
+ self._close_mosaic = 0
559
+
560
+ # These parameters stay the same for both modes
561
+ self.lr0_spinbox.setValue(1e-3)
562
+ self.warmup_bias_lr_spinbox.setValue(0.0)
563
+ self.weight_decay_spinbox.setValue(0.025)
564
+ self.momentum_spinbox.setValue(0.9)
565
+
566
+ # Ensure optimizer is set to AdamW
567
+ self.optimizer_combo.setCurrentText("AdamW")
495
568
 
496
569
  def load_model_combobox(self):
497
570
  """Load the model combobox with the available models."""
@@ -605,13 +678,16 @@ class TrainModelDialog(QDialog):
605
678
  'patience': self.patience_spinbox.value(),
606
679
  'batch': self.batch_spinbox.value(),
607
680
  'imgsz': self.imgsz_spinbox.value(),
608
- 'multi_scale': self.multi_scale_combo.currentText() == "True",
681
+ 'optimizer': self.optimizer_combo.currentText(),
682
+ 'lr0': self.lr0_spinbox.value(),
683
+ 'warmup_bias_lr': self.warmup_bias_lr_spinbox.value(),
684
+ 'weight_decay': self.weight_decay_spinbox.value(),
685
+ 'momentum': self.momentum_spinbox.value(),
686
+ 'close_mosaic': self.close_mosaic_spinbox.value(),
609
687
  'save': self.save_combo.currentText() == "True",
610
688
  'save_period': self.save_period_spinbox.value(),
611
689
  'workers': self.workers_spinbox.value(),
612
- 'optimizer': self.optimizer_combo.currentText(),
613
690
  'verbose': self.verbose_combo.currentText() == "True",
614
- 'dropout': self.dropout_spinbox.value(),
615
691
  'val': self.val_combo.currentText() == "True",
616
692
  'exist_ok': True,
617
693
  'plots': True,
@@ -644,12 +720,6 @@ class TrainModelDialog(QDialog):
644
720
  except ValueError:
645
721
  params[name] = value
646
722
 
647
- params['lr0'] = 1e-3 if 'lr0' not in params else params['lr0']
648
- params['warmup_bias_lr'] = 0.0 if 'warmup_bias_lr' not in params else params['warmup_bias_lr']
649
- params['weight_decay'] = 0.025 if 'weight_decay' not in params else params['weight_decay']
650
- params['momentum'] = 0.9 if 'momentum' not in params else params['momentum']
651
- params['close_mosaic'] = 10 if 'close_mosaic' not in params else params['close_mosaic']
652
-
653
723
  # Return the dictionary of parameters
654
724
  return params
655
725
 
@@ -2,9 +2,11 @@
2
2
  from .QtTrainModel import TrainModelDialog
3
3
  from .QtBatchInference import BatchInferenceDialog
4
4
  from .QtDeployPredictor import DeployPredictorDialog
5
+ from .QtDeployGenerator import DeployGeneratorDialog
5
6
 
6
7
  __all__ = [
7
8
  'TrainModelDialog',
8
9
  'BatchInferenceDialog'
9
10
  'DeployPredictorDialog',
11
+ 'DeployGeneratorDialog',
10
12
  ]