coralnet-toolbox 0.0.73__py2.py3-none-any.whl → 0.0.74__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. coralnet_toolbox/Annotations/QtAnnotation.py +28 -69
  2. coralnet_toolbox/Annotations/QtMaskAnnotation.py +408 -0
  3. coralnet_toolbox/Annotations/QtMultiPolygonAnnotation.py +72 -56
  4. coralnet_toolbox/Annotations/QtPatchAnnotation.py +165 -216
  5. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +497 -353
  6. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +126 -116
  7. coralnet_toolbox/CoralNet/QtDownload.py +2 -1
  8. coralnet_toolbox/Explorer/QtExplorer.py +16 -14
  9. coralnet_toolbox/Explorer/QtSettingsWidgets.py +114 -82
  10. coralnet_toolbox/IO/QtExportTagLabAnnotations.py +30 -10
  11. coralnet_toolbox/IO/QtImportTagLabAnnotations.py +21 -15
  12. coralnet_toolbox/IO/QtOpenProject.py +46 -78
  13. coralnet_toolbox/IO/QtSaveProject.py +18 -43
  14. coralnet_toolbox/MachineLearning/ExportDataset/QtBase.py +1 -1
  15. coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +42 -22
  16. coralnet_toolbox/MachineLearning/VideoInference/QtBase.py +0 -4
  17. coralnet_toolbox/QtEventFilter.py +11 -0
  18. coralnet_toolbox/QtImageWindow.py +117 -68
  19. coralnet_toolbox/QtLabelWindow.py +13 -1
  20. coralnet_toolbox/QtMainWindow.py +5 -27
  21. coralnet_toolbox/QtProgressBar.py +52 -27
  22. coralnet_toolbox/Rasters/RasterTableModel.py +8 -8
  23. coralnet_toolbox/SAM/QtDeployPredictor.py +10 -0
  24. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +779 -161
  25. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +86 -149
  26. coralnet_toolbox/Tools/QtCutSubTool.py +18 -2
  27. coralnet_toolbox/Tools/QtResizeSubTool.py +19 -2
  28. coralnet_toolbox/Tools/QtSAMTool.py +72 -50
  29. coralnet_toolbox/Tools/QtSeeAnythingTool.py +8 -5
  30. coralnet_toolbox/Tools/QtSelectTool.py +27 -3
  31. coralnet_toolbox/Tools/QtSubtractSubTool.py +66 -0
  32. coralnet_toolbox/Tools/__init__.py +2 -0
  33. coralnet_toolbox/__init__.py +1 -1
  34. coralnet_toolbox/utilities.py +137 -47
  35. coralnet_toolbox-0.0.74.dist-info/METADATA +375 -0
  36. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/RECORD +40 -38
  37. coralnet_toolbox-0.0.73.dist-info/METADATA +0 -341
  38. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/WHEEL +0 -0
  39. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/entry_points.txt +0 -0
  40. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/licenses/LICENSE.txt +0 -0
  41. {coralnet_toolbox-0.0.73.dist-info → coralnet_toolbox-0.0.74.dist-info}/top_level.txt +0 -0
@@ -2,10 +2,10 @@ import warnings
2
2
 
3
3
  import os
4
4
  import gc
5
- import ujson as json
6
5
 
7
6
  import numpy as np
8
7
 
8
+ import torch
9
9
  from torch.cuda import empty_cache
10
10
  from ultralytics.utils import ops
11
11
 
@@ -17,7 +17,7 @@ from PyQt5.QtCore import Qt
17
17
  from PyQt5.QtGui import QColor
18
18
  from PyQt5.QtWidgets import (QApplication, QComboBox, QDialog, QFormLayout,
19
19
  QHBoxLayout, QLabel, QMessageBox, QPushButton,
20
- QSlider, QSpinBox, QVBoxLayout, QGroupBox, QTabWidget,
20
+ QSlider, QSpinBox, QVBoxLayout, QGroupBox,
21
21
  QWidget, QLineEdit, QFileDialog)
22
22
 
23
23
  from coralnet_toolbox.Results import ResultsProcessor
@@ -98,7 +98,10 @@ class DeployPredictorDialog(QDialog):
98
98
  layout = QVBoxLayout()
99
99
 
100
100
  # Create a QLabel with explanatory text and hyperlink
101
- info_label = QLabel("Choose a Predictor to deploy and use interactively with the See Anything tool.")
101
+ info_label = QLabel(
102
+ "Choose a Predictor to deploy and use interactively with the See Anything tool. "
103
+ "Optionally include a custom visual prompt encoding (VPE) file."
104
+ )
102
105
 
103
106
  info_label.setOpenExternalLinks(True)
104
107
  info_label.setWordWrap(True)
@@ -109,21 +112,15 @@ class DeployPredictorDialog(QDialog):
109
112
 
110
113
  def setup_models_layout(self):
111
114
  """
112
- Setup the models layout with tabs for standard and custom models.
115
+ Setup the models layout with standard models and file selection.
113
116
  """
114
117
  group_box = QGroupBox("Model Selection")
115
- layout = QVBoxLayout()
116
-
117
- # Create tabbed widget
118
- tab_widget = QTabWidget()
119
-
120
- # Tab 1: Standard models
121
- standard_tab = QWidget()
122
- standard_layout = QVBoxLayout(standard_tab)
123
-
118
+ layout = QFormLayout()
119
+
120
+ # Model dropdown
124
121
  self.model_combo = QComboBox()
125
122
  self.model_combo.setEditable(True)
126
-
123
+
127
124
  # Define available models
128
125
  standard_models = [
129
126
  'yoloe-v8s-seg.pt',
@@ -133,49 +130,15 @@ class DeployPredictorDialog(QDialog):
133
130
  'yoloe-11m-seg.pt',
134
131
  'yoloe-11l-seg.pt',
135
132
  ]
136
-
133
+
137
134
  # Add all models to combo box
138
135
  self.model_combo.addItems(standard_models)
136
+
139
137
  # Set the default model
140
138
  self.model_combo.setCurrentIndex(standard_models.index('yoloe-v8s-seg.pt'))
141
-
142
- standard_layout.addWidget(QLabel("Models"))
143
- standard_layout.addWidget(self.model_combo)
144
-
145
- tab_widget.addTab(standard_tab, "Use Existing Model")
146
-
147
- # Tab 2: Custom model
148
- custom_tab = QWidget()
149
- custom_layout = QFormLayout(custom_tab)
150
-
151
- # Custom model file selection
152
- self.model_path_edit = QLineEdit()
153
- browse_button = QPushButton("Browse...")
154
- browse_button.clicked.connect(self.browse_model_file)
155
-
156
- model_path_layout = QHBoxLayout()
157
- model_path_layout.addWidget(self.model_path_edit)
158
- model_path_layout.addWidget(browse_button)
159
- custom_layout.addRow("Custom Model:", model_path_layout)
160
-
161
- # Class Mapping
162
- self.mapping_edit = QLineEdit()
163
- self.mapping_button = QPushButton("Browse...")
164
- self.mapping_button.clicked.connect(self.browse_class_mapping_file)
165
-
166
- class_mapping_layout = QHBoxLayout()
167
- class_mapping_layout.addWidget(self.mapping_edit)
168
- class_mapping_layout.addWidget(self.mapping_button)
169
- custom_layout.addRow("Class Mapping:", class_mapping_layout)
170
-
171
- tab_widget.addTab(custom_tab, "Custom Model")
172
-
173
- # Add the tab widget to the main layout
174
- layout.addWidget(tab_widget)
175
-
176
- # Store the tab widget for later reference
177
- self.model_tab_widget = tab_widget
178
-
139
+ # Create a layout for the model selection
140
+ layout.addRow("Models:", self.model_combo)
141
+
179
142
  group_box.setLayout(layout)
180
143
  self.layout.addWidget(group_box)
181
144
 
@@ -304,37 +267,6 @@ class DeployPredictorDialog(QDialog):
304
267
 
305
268
  group_box.setLayout(layout)
306
269
  self.layout.addWidget(group_box)
307
-
308
- def browse_model_file(self):
309
- """
310
- Open a file dialog to browse for a model file.
311
- """
312
- file_path, _ = QFileDialog.getOpenFileName(self,
313
- "Select Model File",
314
- "",
315
- "Model Files (*.pt *.pth);;All Files (*)")
316
- if file_path:
317
- self.model_path_edit.setText(file_path)
318
-
319
- # Load the class mapping if it exists
320
- dir_path = os.path.dirname(os.path.dirname(file_path))
321
- class_mapping_path = f"{dir_path}/class_mapping.json"
322
- if os.path.exists(class_mapping_path):
323
- self.class_mapping = json.load(open(class_mapping_path, 'r'))
324
- self.mapping_edit.setText(class_mapping_path)
325
-
326
- def browse_class_mapping_file(self):
327
- """
328
- Browse and select a class mapping file.
329
- """
330
- file_path, _ = QFileDialog.getOpenFileName(self,
331
- "Select Class Mapping File",
332
- "",
333
- "JSON Files (*.json)")
334
- if file_path:
335
- # Load the class mapping
336
- self.class_mapping = json.load(open(file_path, 'r'))
337
- self.mapping_edit.setText(file_path)
338
270
 
339
271
  def initialize_uncertainty_threshold(self):
340
272
  """Initialize the uncertainty threshold slider with the current value"""
@@ -412,46 +344,43 @@ class DeployPredictorDialog(QDialog):
412
344
  QApplication.setOverrideCursor(Qt.WaitCursor)
413
345
  progress_bar = ProgressBar(self.annotation_window, title="Loading Model")
414
346
  progress_bar.show()
415
-
347
+
416
348
  try:
417
349
  # Get selected model path and download weights if needed
418
350
  self.model_path = self.model_combo.currentText()
419
-
351
+
420
352
  # Load model using registry
421
353
  self.loaded_model = YOLOE(self.model_path).to(self.main_window.device)
422
-
423
- # Create a dummy visual dictionary
354
+
355
+ # Create a dummy visual dictionary for standard model loading
424
356
  visuals = dict(
425
357
  bboxes=np.array(
426
358
  [
427
- [120, 425, 160, 445],
359
+ [120, 425, 160, 445], # Random box
428
360
  ],
429
361
  ),
430
362
  cls=np.array(
431
363
  np.zeros(1),
432
364
  ),
433
365
  )
434
-
366
+
435
367
  # Run a dummy prediction to load the model
436
368
  self.loaded_model.predict(
437
369
  np.zeros((640, 640, 3), dtype=np.uint8),
438
- visual_prompts=visuals.copy(),
439
- predictor=YOLOEVPDetectPredictor,
370
+ visual_prompts=visuals.copy(), # This needs to happen to properly initialize the predictor
371
+ predictor=YOLOEVPSegPredictor, # This also needs to be SegPredictor, no matter what
440
372
  imgsz=640,
441
373
  conf=0.99,
442
374
  )
443
375
 
444
- # Load the model class names if available
445
- if self.class_mapping:
446
- self.add_labels_to_label_window()
447
-
448
- progress_bar.finish_progress()
449
- self.status_bar.setText("Model loaded")
376
+ self.status_bar.setText(f"Loaded ({self.model_path}")
450
377
  QMessageBox.information(self.annotation_window, "Model Loaded", "Model loaded successfully")
451
378
 
452
379
  except Exception as e:
380
+ self.loaded_model = None
381
+ self.status_bar.setText(f"Error loading model: {self.model_path}")
453
382
  QMessageBox.critical(self.annotation_window, "Error Loading Model", f"Error loading model: {e}")
454
-
383
+
455
384
  finally:
456
385
  # Restore cursor
457
386
  QApplication.restoreOverrideCursor()
@@ -460,18 +389,6 @@ class DeployPredictorDialog(QDialog):
460
389
  progress_bar.close()
461
390
  progress_bar = None
462
391
 
463
- self.accept()
464
-
465
- def add_labels_to_label_window(self):
466
- """
467
- Add labels to the label window based on the class mapping.
468
- """
469
- if self.class_mapping:
470
- for label in self.class_mapping.values():
471
- self.main_window.label_window.add_label_if_not_exists(label['short_label_code'],
472
- label['long_label_code'],
473
- QColor(*label['color']))
474
-
475
392
  def resize_image(self, image):
476
393
  """
477
394
  Resize the image to the specified size.
@@ -526,26 +443,11 @@ class DeployPredictorDialog(QDialog):
526
443
  self.resized_image = self.resize_image(image)
527
444
  else:
528
445
  self.resized_image = image
529
-
530
- def predict_from_prompts(self, bboxes, masks=None):
446
+
447
+ def scale_prompts(self, bboxes, masks=None):
531
448
  """
532
- Make predictions using the currently loaded model using prompts.
533
-
534
- Args:
535
- bboxes (np.ndarray): The bounding boxes to use as prompts.
536
- masks (list, optional): A list of polygons to use as prompts for segmentation.
537
-
538
- Returns:
539
- results (Results): Ultralytics Results object
449
+ Scale the bounding boxes and masks to the resized image.
540
450
  """
541
- if not self.loaded_model:
542
- QMessageBox.critical(self.annotation_window, "Model Not Loaded",
543
- "Model not loaded, cannot make predictions")
544
- return None
545
-
546
- if not len(bboxes):
547
- return None
548
-
549
451
  # Update the bbox coordinates to be relative to the resized image
550
452
  bboxes = np.array(bboxes)
551
453
  bboxes[:, 0] = (bboxes[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
@@ -557,7 +459,7 @@ class DeployPredictorDialog(QDialog):
557
459
  self.task = self.task_dropdown.currentText()
558
460
 
559
461
  # Create a visual dictionary
560
- visuals = {
462
+ visual_prompts = {
561
463
  'bboxes': np.array(bboxes),
562
464
  'cls': np.zeros(len(bboxes))
563
465
  }
@@ -569,21 +471,44 @@ class DeployPredictorDialog(QDialog):
569
471
  scaled_mask[:, 0] = (scaled_mask[:, 0] / self.original_image.shape[1]) * self.resized_image.shape[1]
570
472
  scaled_mask[:, 1] = (scaled_mask[:, 1] / self.original_image.shape[0]) * self.resized_image.shape[0]
571
473
  scaled_masks.append(scaled_mask)
572
- visuals['masks'] = scaled_masks
474
+ visual_prompts['masks'] = scaled_masks
573
475
  else: # Fallback to creating masks from bboxes if no masks are provided
574
476
  fallback_masks = []
575
477
  for bbox in bboxes:
576
478
  x1, y1, x2, y2 = bbox
577
479
  fallback_masks.append(np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]))
578
- visuals['masks'] = fallback_masks
480
+ visual_prompts['masks'] = fallback_masks
481
+
482
+ return visual_prompts
483
+
484
+ def predict_from_prompts(self, bboxes, masks=None):
485
+ """
486
+ Make predictions using the currently loaded model using prompts.
487
+
488
+ Args:
489
+ bboxes (np.ndarray): The bounding boxes to use as prompts.
490
+ masks (list, optional): A list of polygons to use as prompts for segmentation.
491
+
492
+ Returns:
493
+ results (Results): Ultralytics Results object
494
+ """
495
+ if not self.loaded_model:
496
+ QMessageBox.critical(self.annotation_window,
497
+ "Model Not Loaded",
498
+ "Model not loaded, cannot make predictions")
499
+ return None
500
+
501
+ if not len(bboxes):
502
+ return None
579
503
 
580
- predictor = YOLOEVPSegPredictor if self.task == "segment" else YOLOEVPDetectPredictor
504
+ # Get the scaled visual prompts
505
+ visual_prompts = self.scale_prompts(bboxes, masks)
581
506
 
582
507
  try:
583
508
  # Make predictions
584
509
  results = self.loaded_model.predict(self.resized_image,
585
- visual_prompts=visuals.copy(),
586
- predictor=predictor,
510
+ visual_prompts=visual_prompts.copy(),
511
+ predictor=YOLOEVPSegPredictor,
587
512
  imgsz=max(self.resized_image.shape[:2]),
588
513
  conf=self.main_window.get_uncertainty_thresh(),
589
514
  iou=self.main_window.get_iou_thresh(),
@@ -618,18 +543,30 @@ class DeployPredictorDialog(QDialog):
618
543
  max_area_thresh=self.main_window.get_area_thresh_max()
619
544
  )
620
545
 
621
- # Set the predictor
622
- self.task = self.task_dropdown.currentText()
623
-
624
- # Create a visual dictionary
625
- visuals = {
626
- 'bboxes': np.array(refer_bboxes),
627
- 'cls': np.zeros(len(refer_bboxes))
628
- }
629
- if self.task == 'segment':
630
- visuals['masks'] = refer_masks
631
-
632
- predictor = YOLOEVPSegPredictor if self.task == "segment" else YOLOEVPDetectPredictor
546
+ # Get the scaled visual prompts
547
+ visual_prompts = self.scale_prompts(refer_bboxes, refer_masks)
548
+
549
+ # If VPEs are being used
550
+ if self.vpe is not None:
551
+ # Generate a new VPE from the current visual prompts
552
+ new_vpe = self.prompts_to_vpes(visual_prompts, self.resized_image)
553
+
554
+ if new_vpe is not None:
555
+ # If we already have a VPE, average with the existing one
556
+ if self.vpe.shape == new_vpe.shape:
557
+ self.vpe = (self.vpe + new_vpe) / 2
558
+ # Re-normalize
559
+ self.vpe = torch.nn.functional.normalize(self.vpe, p=2, dim=-1)
560
+ else:
561
+ # Replace with the new VPE if shapes don't match
562
+ self.vpe = new_vpe
563
+
564
+ # Set the updated VPE in the model
565
+ self.loaded_model.is_fused = lambda: False
566
+ self.loaded_model.set_classes(["object0"], self.vpe)
567
+
568
+ # Clear visual prompts since we're using VPE
569
+ visual_prompts = {} # this is okay with a fused model
633
570
 
634
571
  # Create a progress bar
635
572
  QApplication.setOverrideCursor(Qt.WaitCursor)
@@ -643,8 +580,8 @@ class DeployPredictorDialog(QDialog):
643
580
  # Make predictions
644
581
  results = self.loaded_model.predict(target_image,
645
582
  refer_image=refer_image,
646
- visual_prompts=visuals.copy(),
647
- predictor=predictor,
583
+ visual_prompts=visual_prompts.copy(),
584
+ predictor=YOLOEVPSegPredictor,
648
585
  imgsz=self.imgsz_spinbox.value(),
649
586
  conf=self.main_window.get_uncertainty_thresh(),
650
587
  iou=self.main_window.get_iou_thresh(),
@@ -87,9 +87,16 @@ class CutSubTool(SubTool):
87
87
  self._update_cut_line_path(position)
88
88
 
89
89
  def keyPressEvent(self, event):
90
- """Handle key press events for canceling the cut."""
91
- if event.key() in (Qt.Key_Backspace, Qt.Key_Escape):
90
+ """Handle key press events for cutting operations."""
91
+ # Check for Ctrl+X to toggle cutting mode off
92
+ if event.modifiers() & Qt.ControlModifier and event.key() == Qt.Key_X:
92
93
  self.parent_tool.deactivate_subtool()
94
+ return
95
+
96
+ # Handle Backspace to clear the current cutting line but stay in cutting mode
97
+ if event.key() == Qt.Key_Backspace:
98
+ self._clear_cutting_line()
99
+ return
93
100
 
94
101
  def _start_drawing_cut_line(self, position):
95
102
  """Start drawing the cut line from the given position."""
@@ -115,6 +122,15 @@ class CutSubTool(SubTool):
115
122
  path.lineTo(point)
116
123
  self.cutting_path_item.setPath(path)
117
124
 
125
+ def _clear_cutting_line(self):
126
+ """Clear the current cutting line but remain in cutting mode."""
127
+ self.cutting_points = []
128
+ self.drawing_in_progress = False
129
+ if self.cutting_path_item:
130
+ self.annotation_window.scene.removeItem(self.cutting_path_item)
131
+ self.cutting_path_item = None
132
+ self.annotation_window.scene.update()
133
+
118
134
  def _break_apart_multipolygon(self):
119
135
  """Handle the special case of 'cutting' a MultiPolygonAnnotation."""
120
136
  new_annotations = self.target_annotation.cut()
@@ -124,5 +124,22 @@ class ResizeSubTool(SubTool):
124
124
  }
125
125
 
126
126
  def _get_polygon_handles(self, annotation):
127
- """Return resize handles for a polygon annotation."""
128
- return {f"point_{i}": QPointF(p.x(), p.y()) for i, p in enumerate(annotation.points)}
127
+ """
128
+ Return resize handles for a polygon, including its outer boundary and all holes.
129
+ Uses the new handle format: 'point_{poly_index}_{vertex_index}'.
130
+ """
131
+ handles = {}
132
+
133
+ # 1. Create handles for the outer boundary using the 'outer' keyword.
134
+ for i, p in enumerate(annotation.points):
135
+ handle_name = f"point_outer_{i}"
136
+ handles[handle_name] = QPointF(p.x(), p.y())
137
+
138
+ # 2. Create handles for each of the inner holes using their index.
139
+ if hasattr(annotation, 'holes'):
140
+ for hole_index, hole in enumerate(annotation.holes):
141
+ for vertex_index, p in enumerate(hole):
142
+ handle_name = f"point_{hole_index}_{vertex_index}"
143
+ handles[handle_name] = QPointF(p.x(), p.y())
144
+
145
+ return handles
@@ -1,7 +1,7 @@
1
1
  import warnings
2
2
  import numpy as np
3
3
 
4
- from PyQt5.QtCore import Qt, QPointF, QRectF, QTimer
4
+ from PyQt5.QtCore import Qt, QPointF, QRectF
5
5
  from PyQt5.QtGui import QMouseEvent, QKeyEvent, QPen, QColor, QBrush, QPainterPath
6
6
  from PyQt5.QtWidgets import QMessageBox, QGraphicsEllipseItem, QGraphicsRectItem, QGraphicsPathItem, QApplication
7
7
 
@@ -12,6 +12,7 @@ from coralnet_toolbox.QtWorkArea import WorkArea
12
12
 
13
13
  from coralnet_toolbox.utilities import pixmap_to_numpy
14
14
  from coralnet_toolbox.utilities import simplify_polygon
15
+ from coralnet_toolbox.utilities import polygonize_mask_with_holes
15
16
 
16
17
  warnings.filterwarnings("ignore", category=DeprecationWarning)
17
18
 
@@ -369,40 +370,47 @@ class SAMTool(Tool):
369
370
  QApplication.restoreOverrideCursor()
370
371
  return
371
372
 
372
- # Get the points of the top1 mask
373
+ # Get the top confidence prediction's mask tensor
373
374
  top1_index = np.argmax(results.boxes.conf)
374
- predictions = results[top1_index].masks.xy[0]
375
+ mask_tensor = results[top1_index].masks.data
375
376
 
376
- # Safety check: make sure we have predicted points
377
- if len(predictions) == 0:
378
- QApplication.restoreOverrideCursor()
379
- return
377
+ # Check if holes are allowed from the SAM dialog
378
+ allow_holes = self.sam_dialog.get_allow_holes()
380
379
 
381
- # Clean the polygon using Ramer-Douglas-Peucker algorithm
382
- predictions = simplify_polygon(predictions, 0.1)
380
+ # Polygonize the mask to get the exterior and holes
381
+ exterior_coords, holes_coords_list = polygonize_mask_with_holes(mask_tensor)
383
382
 
384
383
  # Safety check: need at least 3 points for a valid polygon
385
- if len(predictions) < 3:
384
+ if len(exterior_coords) < 3:
386
385
  QApplication.restoreOverrideCursor()
387
386
  return
388
387
 
389
- # Move the points back to the original image space
388
+ # --- Process and Clean the Polygon Points ---
390
389
  working_area_top_left = self.working_area.rect.topLeft()
391
- points = [(point[0] + working_area_top_left.x(),
392
- point[1] + working_area_top_left.y()) for point in predictions]
390
+ offset_x, offset_y = working_area_top_left.x(), working_area_top_left.y()
393
391
 
394
- # Convert to QPointF for graphics
395
- self.points = [QPointF(*point) for point in points]
392
+ # Simplify, offset, and convert the exterior points
393
+ simplified_exterior = simplify_polygon(exterior_coords, 0.1)
394
+ self.points = [QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_exterior]
396
395
 
397
- # Create the temporary annotation
396
+ # Simplify, offset, and convert each hole only if allowed
397
+ final_holes = []
398
+ if allow_holes:
399
+ for hole_coords in holes_coords_list:
400
+ if len(hole_coords) >= 3: # Ensure holes are also valid polygons
401
+ simplified_hole = simplify_polygon(hole_coords, 0.1)
402
+ final_holes.append([QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_hole])
403
+
404
+ # Create the temporary annotation, now with holes (or not)
398
405
  self.temp_annotation = PolygonAnnotation(
399
- self.points,
400
- self.annotation_window.selected_label.short_label_code,
401
- self.annotation_window.selected_label.long_label_code,
402
- self.annotation_window.selected_label.color,
403
- self.annotation_window.current_image_path,
404
- self.annotation_window.selected_label.id,
405
- self.main_window.label_window.active_label.transparency
406
+ points=self.points,
407
+ holes=final_holes,
408
+ short_label_code=self.annotation_window.selected_label.short_label_code,
409
+ long_label_code=self.annotation_window.selected_label.long_label_code,
410
+ color=self.annotation_window.selected_label.color,
411
+ image_path=self.annotation_window.current_image_path,
412
+ label_id=self.annotation_window.selected_label.id,
413
+ transparency=self.main_window.label_window.active_label.transparency
406
414
  )
407
415
 
408
416
  # Create the graphics item for the temporary annotation
@@ -611,12 +619,13 @@ class SAMTool(Tool):
611
619
  # Use existing temporary annotation
612
620
  final_annotation = PolygonAnnotation(
613
621
  self.points,
614
- self.annotation_window.selected_label.short_label_code,
615
- self.annotation_window.selected_label.long_label_code,
616
- self.annotation_window.selected_label.color,
617
- self.annotation_window.current_image_path,
618
- self.annotation_window.selected_label.id,
619
- self.main_window.label_window.active_label.transparency
622
+ self.temp_annotation.label.short_label_code,
623
+ self.temp_annotation.label.long_label_code,
624
+ self.temp_annotation.label.color,
625
+ self.temp_annotation.image_path,
626
+ self.temp_annotation.label.id,
627
+ self.temp_annotation.label.transparency,
628
+ holes=self.temp_annotation.holes
620
629
  )
621
630
 
622
631
  # Copy confidence data
@@ -637,7 +646,7 @@ class SAMTool(Tool):
637
646
  final_annotation = self.create_annotation(True)
638
647
  if final_annotation:
639
648
  self.annotation_window.add_annotation_from_tool(final_annotation)
640
- self.clear_prompt_graphics()
649
+ self.clear_prompt_graphics()
641
650
  # If no active prompts, cancel the working area
642
651
  else:
643
652
  self.cancel_working_area()
@@ -727,24 +736,36 @@ class SAMTool(Tool):
727
736
  QApplication.restoreOverrideCursor()
728
737
  return None
729
738
 
730
- # Get the top confidence prediction
739
+ # Get the top confidence prediction's mask tensor
731
740
  top1_index = np.argmax(results.boxes.conf)
732
- predictions = results[top1_index].masks.xy[0]
741
+ mask_tensor = results[top1_index].masks.data
742
+
743
+ # Check if holes are allowed from the SAM dialog
744
+ allow_holes = self.sam_dialog.get_allow_holes()
733
745
 
734
- # Safety check for predictions
735
- if len(predictions) == 0:
746
+ # Polygonize the mask using the new method to get the exterior and holes
747
+ exterior_coords, holes_coords_list = polygonize_mask_with_holes(mask_tensor)
748
+
749
+ # Safety check for an empty result
750
+ if not exterior_coords:
736
751
  QApplication.restoreOverrideCursor()
737
752
  return None
738
753
 
739
- # Clean polygon points
740
- predictions = simplify_polygon(predictions, 0.1)
741
-
742
- # Move points back to original image space
754
+ # --- Process and Clean the Polygon Points ---
743
755
  working_area_top_left = self.working_area.rect.topLeft()
744
- points = [(point[0] + working_area_top_left.x(),
745
- point[1] + working_area_top_left.y()) for point in predictions]
746
- # Convert to QPointF for graphics
747
- self.points = [QPointF(*point) for point in points]
756
+ offset_x, offset_y = working_area_top_left.x(), working_area_top_left.y()
757
+
758
+ # Simplify, offset, and convert the exterior points
759
+ simplified_exterior = simplify_polygon(exterior_coords, 0.1)
760
+ self.points = [QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_exterior]
761
+
762
+ # Simplify, offset, and convert each hole only if allowed
763
+ final_holes = []
764
+ if allow_holes:
765
+ for hole_coords in holes_coords_list:
766
+ if len(hole_coords) >= 3:
767
+ simplified_hole = simplify_polygon(hole_coords, 0.1)
768
+ final_holes.append([QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_hole])
748
769
 
749
770
  # Require at least 3 points for valid polygon
750
771
  if len(self.points) < 3:
@@ -754,15 +775,16 @@ class SAMTool(Tool):
754
775
  # Get confidence score
755
776
  confidence = results.boxes.conf[top1_index].item()
756
777
 
757
- # Create final annotation
778
+ # Create final annotation, now passing the holes argument
758
779
  annotation = PolygonAnnotation(
759
- self.points,
760
- self.annotation_window.selected_label.short_label_code,
761
- self.annotation_window.selected_label.long_label_code,
762
- self.annotation_window.selected_label.color,
763
- self.annotation_window.current_image_path,
764
- self.annotation_window.selected_label.id,
765
- self.main_window.label_window.active_label.transparency
780
+ points=self.points,
781
+ holes=final_holes,
782
+ short_label_code=self.annotation_window.selected_label.short_label_code,
783
+ long_label_code=self.annotation_window.selected_label.long_label_code,
784
+ color=self.annotation_window.selected_label.color,
785
+ image_path=self.annotation_window.current_image_path,
786
+ label_id=self.annotation_window.selected_label.id,
787
+ transparency=self.main_window.label_window.active_label.transparency
766
788
  )
767
789
 
768
790
  # Update confidence
@@ -173,6 +173,7 @@ class SeeAnythingTool(Tool):
173
173
 
174
174
  # Set the image in the SeeAnything dialog
175
175
  self.see_anything_dialog.set_image(self.work_area_image, self.image_path)
176
+ # self.see_anything_dialog.reload_model()
176
177
 
177
178
  self.annotation_window.setCursor(Qt.CrossCursor)
178
179
  self.annotation_window.scene.update()
@@ -552,9 +553,9 @@ class SeeAnythingTool(Tool):
552
553
  # Move the points back to the original image space
553
554
  working_area_top_left = self.working_area.rect.topLeft()
554
555
 
555
- task = self.see_anything_dialog.task_dropdown.currentText()
556
556
  masks = None
557
- if task == 'segment':
557
+ # Create masks from the rectangles (these are not polygons)
558
+ if self.see_anything_dialog.task_dropdown.currentText() == 'segment':
558
559
  masks = []
559
560
  for r in self.rectangles:
560
561
  x1, y1, x2, y2 = r
@@ -587,8 +588,8 @@ class SeeAnythingTool(Tool):
587
588
  # Clear previous annotations if any
588
589
  self.clear_annotations()
589
590
 
590
- # Process results based on the task type
591
- if self.see_anything_dialog.task == "segment":
591
+ # Process results based on the task type (creates polygons or rectangle annotations)
592
+ if self.see_anything_dialog.task_dropdown.currentText() == "segment":
592
593
  if self.results.masks:
593
594
  for i, polygon in enumerate(self.results.masks.xyn):
594
595
  confidence = self.results.boxes.conf[i].item()
@@ -624,7 +625,9 @@ class SeeAnythingTool(Tool):
624
625
  box_abs_work_area = box_norm.detach().cpu().numpy() * np.array(
625
626
  [self.work_area_image.shape[1], self.work_area_image.shape[0],
626
627
  self.work_area_image.shape[1], self.work_area_image.shape[0]])
627
- box_area = (box_abs_work_area[2] - box_abs_work_area[0]) * (box_abs_work_area[3] - box_abs_work_area[1])
628
+ # Calculate the area of the bounding box
629
+ box_area = (box_abs_work_area[2] - box_abs_work_area[0]) * \
630
+ (box_abs_work_area[3] - box_abs_work_area[1])
628
631
 
629
632
  # Area filtering
630
633
  min_area = self.main_window.get_area_thresh_min() * image_area