coralnet-toolbox 0.0.74__py2.py3-none-any.whl → 0.0.76__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. coralnet_toolbox/Annotations/QtPolygonAnnotation.py +57 -12
  2. coralnet_toolbox/Annotations/QtRectangleAnnotation.py +44 -14
  3. coralnet_toolbox/Explorer/QtDataItem.py +52 -22
  4. coralnet_toolbox/Explorer/QtExplorer.py +277 -1600
  5. coralnet_toolbox/Explorer/QtSettingsWidgets.py +101 -15
  6. coralnet_toolbox/Explorer/QtViewers.py +1568 -0
  7. coralnet_toolbox/Explorer/transformer_models.py +70 -0
  8. coralnet_toolbox/Explorer/yolo_models.py +112 -0
  9. coralnet_toolbox/IO/QtExportMaskAnnotations.py +538 -403
  10. coralnet_toolbox/Icons/system_monitor.png +0 -0
  11. coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +239 -147
  12. coralnet_toolbox/MachineLearning/VideoInference/YOLO3D/run.py +102 -16
  13. coralnet_toolbox/QtAnnotationWindow.py +16 -10
  14. coralnet_toolbox/QtEventFilter.py +4 -4
  15. coralnet_toolbox/QtImageWindow.py +3 -7
  16. coralnet_toolbox/QtMainWindow.py +104 -64
  17. coralnet_toolbox/QtProgressBar.py +1 -0
  18. coralnet_toolbox/QtSystemMonitor.py +370 -0
  19. coralnet_toolbox/Rasters/RasterTableModel.py +20 -0
  20. coralnet_toolbox/Results/ConvertResults.py +14 -8
  21. coralnet_toolbox/Results/ResultsProcessor.py +3 -2
  22. coralnet_toolbox/SAM/QtDeployGenerator.py +2 -5
  23. coralnet_toolbox/SAM/QtDeployPredictor.py +11 -3
  24. coralnet_toolbox/SeeAnything/QtDeployGenerator.py +146 -116
  25. coralnet_toolbox/SeeAnything/QtDeployPredictor.py +55 -9
  26. coralnet_toolbox/Tile/QtTileBatchInference.py +4 -4
  27. coralnet_toolbox/Tools/QtPolygonTool.py +42 -3
  28. coralnet_toolbox/Tools/QtRectangleTool.py +30 -0
  29. coralnet_toolbox/Tools/QtSAMTool.py +140 -91
  30. coralnet_toolbox/Transformers/Models/GroundingDINO.py +72 -0
  31. coralnet_toolbox/Transformers/Models/OWLViT.py +72 -0
  32. coralnet_toolbox/Transformers/Models/OmDetTurbo.py +68 -0
  33. coralnet_toolbox/Transformers/Models/QtBase.py +120 -0
  34. coralnet_toolbox/{AutoDistill → Transformers}/Models/__init__.py +1 -1
  35. coralnet_toolbox/{AutoDistill → Transformers}/QtBatchInference.py +15 -15
  36. coralnet_toolbox/{AutoDistill → Transformers}/QtDeployModel.py +18 -16
  37. coralnet_toolbox/{AutoDistill → Transformers}/__init__.py +1 -1
  38. coralnet_toolbox/__init__.py +1 -1
  39. coralnet_toolbox/utilities.py +21 -15
  40. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/METADATA +13 -10
  41. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/RECORD +45 -40
  42. coralnet_toolbox/AutoDistill/Models/GroundingDINO.py +0 -81
  43. coralnet_toolbox/AutoDistill/Models/OWLViT.py +0 -76
  44. coralnet_toolbox/AutoDistill/Models/OmDetTurbo.py +0 -75
  45. coralnet_toolbox/AutoDistill/Models/QtBase.py +0 -112
  46. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/WHEEL +0 -0
  47. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/entry_points.txt +0 -0
  48. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/licenses/LICENSE.txt +0 -0
  49. {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/top_level.txt +0 -0
@@ -113,6 +113,36 @@ class RectangleTool(Tool):
         # Ensure top_left and bottom_right are correctly calculated
         top_left = QPointF(min(self.start_point.x(), end_point.x()), min(self.start_point.y(), end_point.y()))
         bottom_right = QPointF(max(self.start_point.x(), end_point.x()), max(self.start_point.y(), end_point.y()))
+
+        # Calculate width and height of the rectangle
+        width = bottom_right.x() - top_left.x()
+        height = bottom_right.y() - top_left.y()
+
+        # Define minimum dimensions for a valid rectangle (e.g., 3x3 pixels)
+        MIN_DIMENSION = 3.0
+
+        # If rectangle is too small and we're finalizing it, enforce minimum size
+        if finished and (width < MIN_DIMENSION or height < MIN_DIMENSION):
+            if width < MIN_DIMENSION:
+                # Expand width while maintaining center
+                center_x = (top_left.x() + bottom_right.x()) / 2
+                top_left.setX(center_x - MIN_DIMENSION / 2)
+                bottom_right.setX(center_x + MIN_DIMENSION / 2)
+
+            if height < MIN_DIMENSION:
+                # Expand height while maintaining center
+                center_y = (top_left.y() + bottom_right.y()) / 2
+                top_left.setY(center_y - MIN_DIMENSION / 2)
+                bottom_right.setY(center_y + MIN_DIMENSION / 2)
+
+            # Show a message if we had to adjust a very small rectangle
+            if width < 1 or height < 1:
+                QMessageBox.information(
+                    self.annotation_window,
+                    "Rectangle Adjusted",
+                    f"The rectangle was too small and has been adjusted to a minimum size of "
+                    f"{MIN_DIMENSION}x{MIN_DIMENSION} pixels."
+                )
 
         # Create the rectangle annotation
         annotation = RectangleAnnotation(top_left,
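Note: the block added above enforces a 3x3-pixel minimum when a rectangle is finalized, expanding it around its center. A minimal standalone sketch of that rule (plain Python, no Qt; clamp_rect and the tuple points are illustrative only, not toolbox API):

    def clamp_rect(top_left, bottom_right, min_dim=3.0):
        """Expand an (x, y) box to at least min_dim per axis, keeping its center fixed."""
        (x1, y1), (x2, y2) = top_left, bottom_right
        if x2 - x1 < min_dim:
            cx = (x1 + x2) / 2
            x1, x2 = cx - min_dim / 2, cx + min_dim / 2
        if y2 - y1 < min_dim:
            cy = (y1 + y2) / 2
            y1, y2 = cy - min_dim / 2, cy + min_dim / 2
        return (x1, y1), (x2, y2)

    # A click with almost no drag still yields a usable 3x3 box:
    print(clamp_rect((10.0, 10.0), (10.4, 10.4)))  # ((8.7, 8.7), (11.7, 11.7))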
@@ -6,6 +6,8 @@ from PyQt5.QtGui import QMouseEvent, QKeyEvent, QPen, QColor, QBrush, QPainterPa
 from PyQt5.QtWidgets import QMessageBox, QGraphicsEllipseItem, QGraphicsRectItem, QGraphicsPathItem, QApplication
 
 from coralnet_toolbox.Tools.QtTool import Tool
+
+from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
 from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation
 
 from coralnet_toolbox.QtWorkArea import WorkArea
@@ -374,45 +376,21 @@ class SAMTool(Tool):
         top1_index = np.argmax(results.boxes.conf)
         mask_tensor = results[top1_index].masks.data
 
-        # Check if holes are allowed from the SAM dialog
+        # Check which output type is selected and get allow_holes settings
+        output_type = self.sam_dialog.get_output_type()
         allow_holes = self.sam_dialog.get_allow_holes()
-
-        # Polygonize the mask to get the exterior and holes
-        exterior_coords, holes_coords_list = polygonize_mask_with_holes(mask_tensor)
-
-        # Safety check: need at least 3 points for a valid polygon
-        if len(exterior_coords) < 3:
+
+        # Create annotation using the helper method
+        self.temp_annotation = self.create_annotation_from_mask(
+            mask_tensor,
+            output_type,
+            allow_holes
+        )
+
+        if not self.temp_annotation:
             QApplication.restoreOverrideCursor()
             return
-
-        # --- Process and Clean the Polygon Points ---
-        working_area_top_left = self.working_area.rect.topLeft()
-        offset_x, offset_y = working_area_top_left.x(), working_area_top_left.y()
-
-        # Simplify, offset, and convert the exterior points
-        simplified_exterior = simplify_polygon(exterior_coords, 0.1)
-        self.points = [QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_exterior]
-
-        # Simplify, offset, and convert each hole only if allowed
-        final_holes = []
-        if allow_holes:
-            for hole_coords in holes_coords_list:
-                if len(hole_coords) >= 3:  # Ensure holes are also valid polygons
-                    simplified_hole = simplify_polygon(hole_coords, 0.1)
-                    final_holes.append([QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_hole])
-
-        # Create the temporary annotation, now with holes (or not)
-        self.temp_annotation = PolygonAnnotation(
-            points=self.points,
-            holes=final_holes,
-            short_label_code=self.annotation_window.selected_label.short_label_code,
-            long_label_code=self.annotation_window.selected_label.long_label_code,
-            color=self.annotation_window.selected_label.color,
-            image_path=self.annotation_window.current_image_path,
-            label_id=self.annotation_window.selected_label.id,
-            transparency=self.main_window.label_window.active_label.transparency
-        )
-
+
         # Create the graphics item for the temporary annotation
         self.temp_annotation.create_graphics_item(self.annotation_window.scene)
 
@@ -616,17 +594,31 @@ class SAMTool(Tool):
         elif self.has_active_prompts:
             # Create the final annotation
             if self.temp_annotation:
-                # Use existing temporary annotation
-                final_annotation = PolygonAnnotation(
-                    self.points,
-                    self.temp_annotation.label.short_label_code,
-                    self.temp_annotation.label.long_label_code,
-                    self.temp_annotation.label.color,
-                    self.temp_annotation.image_path,
-                    self.temp_annotation.label.id,
-                    self.temp_annotation.label.transparency,
-                    holes=self.temp_annotation.holes
-                )
+                # Check if temp_annotation is a PolygonAnnotation or RectangleAnnotation
+                if isinstance(self.temp_annotation, PolygonAnnotation):
+                    # For polygon annotations, use the points and holes
+                    final_annotation = PolygonAnnotation(
+                        self.points,
+                        self.temp_annotation.label.short_label_code,
+                        self.temp_annotation.label.long_label_code,
+                        self.temp_annotation.label.color,
+                        self.temp_annotation.image_path,
+                        self.temp_annotation.label.id,
+                        self.temp_annotation.label.transparency,
+                        holes=self.temp_annotation.holes
+                    )
+                elif isinstance(self.temp_annotation, RectangleAnnotation):
+                    # For rectangle annotations, use the top_left and bottom_right
+                    final_annotation = RectangleAnnotation(
+                        top_left=self.temp_annotation.top_left,
+                        bottom_right=self.temp_annotation.bottom_right,
+                        short_label_code=self.temp_annotation.label.short_label_code,
+                        long_label_code=self.temp_annotation.label.long_label_code,
+                        color=self.temp_annotation.label.color,
+                        image_path=self.temp_annotation.image_path,
+                        label_id=self.temp_annotation.label.id,
+                        transparency=self.temp_annotation.label.transparency
+                    )
 
                 # Copy confidence data
                 final_annotation.update_machine_confidence(
@@ -740,54 +732,23 @@ class SAMTool(Tool):
         top1_index = np.argmax(results.boxes.conf)
         mask_tensor = results[top1_index].masks.data
 
-        # Check if holes are allowed from the SAM dialog
+        # Check which output type is selected and get allow_holes settings
+        output_type = self.sam_dialog.get_output_type()
         allow_holes = self.sam_dialog.get_allow_holes()
-
-        # Polygonize the mask using the new method to get the exterior and holes
-        exterior_coords, holes_coords_list = polygonize_mask_with_holes(mask_tensor)
-
-        # Safety check for an empty result
-        if not exterior_coords:
-            QApplication.restoreOverrideCursor()
-            return None
-
-        # --- Process and Clean the Polygon Points ---
-        working_area_top_left = self.working_area.rect.topLeft()
-        offset_x, offset_y = working_area_top_left.x(), working_area_top_left.y()
-
-        # Simplify, offset, and convert the exterior points
-        simplified_exterior = simplify_polygon(exterior_coords, 0.1)
-        self.points = [QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_exterior]
-
-        # Simplify, offset, and convert each hole only if allowed
-        final_holes = []
-        if allow_holes:
-            for hole_coords in holes_coords_list:
-                if len(hole_coords) >= 3:
-                    simplified_hole = simplify_polygon(hole_coords, 0.1)
-                    final_holes.append([QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_hole])
-
-        # Require at least 3 points for valid polygon
-        if len(self.points) < 3:
+
+        # Create annotation using the helper method
+        annotation = self.create_annotation_from_mask(
+            mask_tensor,
+            output_type,
+            allow_holes
+        )
+
+        if not annotation:
             QApplication.restoreOverrideCursor()
             return None
 
-        # Get confidence score
-        confidence = results.boxes.conf[top1_index].item()
-
-        # Create final annotation, now passing the holes argument
-        annotation = PolygonAnnotation(
-            points=self.points,
-            holes=final_holes,
-            short_label_code=self.annotation_window.selected_label.short_label_code,
-            long_label_code=self.annotation_window.selected_label.long_label_code,
-            color=self.annotation_window.selected_label.color,
-            image_path=self.annotation_window.current_image_path,
-            label_id=self.annotation_window.selected_label.id,
-            transparency=self.main_window.label_window.active_label.transparency
-        )
-
-        # Update confidence
+        # Update confidence - make sure to extract confidence from results
+        confidence = float(results.boxes.conf[top1_index])
         annotation.update_machine_confidence({self.annotation_window.selected_label: confidence})
 
         # Create cropped image
@@ -799,6 +760,94 @@ class SAMTool(Tool):
 
         return annotation
 
+    def create_annotation_from_mask(self, mask_tensor, output_type, allow_holes=True):
+        """
+        Create annotation (Rectangle or Polygon) from a mask tensor.
+
+        Args:
+            mask_tensor: The tensor containing the mask data
+            output_type (str): "Rectangle" or "Polygon"
+            allow_holes (bool): Whether to include holes in polygon annotations
+
+        Returns:
+            Annotation object or None if creation fails
+        """
+        if not self.working_area:
+            return None
+
+        if output_type == "Rectangle":
+            # For rectangle output, just get the bounding box of the mask
+            # Find the bounding rectangle of the mask
+            y_indices, x_indices = np.where(mask_tensor.cpu().numpy()[0] > 0)
+            if len(y_indices) == 0 or len(x_indices) == 0:
+                return None
+
+            # Get the min/max coordinates
+            min_x, max_x = np.min(x_indices), np.max(x_indices)
+            min_y, max_y = np.min(y_indices), np.max(y_indices)
+
+            # Apply the offset from working area
+            working_area_top_left = self.working_area.rect.topLeft()
+            offset_x, offset_y = working_area_top_left.x(), working_area_top_left.y()
+
+            top_left = QPointF(min_x + offset_x, min_y + offset_y)
+            bottom_right = QPointF(max_x + offset_x, max_y + offset_y)
+
+            # Create a rectangle annotation
+            annotation = RectangleAnnotation(
+                top_left=top_left,
+                bottom_right=bottom_right,
+                short_label_code=self.annotation_window.selected_label.short_label_code,
+                long_label_code=self.annotation_window.selected_label.long_label_code,
+                color=self.annotation_window.selected_label.color,
+                image_path=self.annotation_window.current_image_path,
+                label_id=self.annotation_window.selected_label.id,
+                transparency=self.main_window.label_window.active_label.transparency
+            )
+        else:
+            # Original polygon code
+            # Polygonize the mask using the new method to get the exterior and holes
+            exterior_coords, holes_coords_list = polygonize_mask_with_holes(mask_tensor)
+
+            # Safety check for an empty result
+            if not exterior_coords:
+                return None
+
+            # --- Process and Clean the Polygon Points ---
+            working_area_top_left = self.working_area.rect.topLeft()
+            offset_x, offset_y = working_area_top_left.x(), working_area_top_left.y()
+
+            # Simplify, offset, and convert the exterior points
+            simplified_exterior = simplify_polygon(exterior_coords, 0.1)
+            self.points = [QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_exterior]
+
+            # Simplify, offset, and convert each hole only if allowed
+            final_holes = []
+            if allow_holes:
+                for hole_coords in holes_coords_list:
+                    simplified_hole = simplify_polygon(hole_coords, 0.1)
+                    if len(simplified_hole) >= 3:
+                        hole_points = [QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_hole]
+                        final_holes.append(hole_points)
+
+            # Require at least 3 points for valid polygon
+            if len(self.points) < 3:
+                return None
+
+            # Create final annotation, now passing the holes argument
+            annotation = PolygonAnnotation(
+                points=self.points,
+                holes=final_holes,
+                short_label_code=self.annotation_window.selected_label.short_label_code,
+                long_label_code=self.annotation_window.selected_label.long_label_code,
+                color=self.annotation_window.selected_label.color,
+                image_path=self.annotation_window.current_image_path,
+                label_id=self.annotation_window.selected_label.id,
+                transparency=self.main_window.label_window.active_label.transparency
+            )
+
+        return annotation
+
     def cancel_working_area(self):
         """
         Cancel the working area and clean up all associated resources.
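Note: create_annotation_from_mask now centralizes mask-to-annotation conversion for the SAM tool: the "Rectangle" branch reduces the mask to its bounding box, while the "Polygon" branch keeps the polygonize/simplify pipeline. A rough NumPy-only sketch of the bounding-box branch (mask_to_bbox and its offset arguments are illustrative, not toolbox API):

    import numpy as np

    def mask_to_bbox(mask, offset_x=0, offset_y=0):
        """Bounding box (x1, y1, x2, y2) of a binary mask, shifted by a working-area offset."""
        y_indices, x_indices = np.where(mask > 0)
        if len(x_indices) == 0:
            return None  # empty mask, mirrors the helper returning None
        return (int(x_indices.min()) + offset_x, int(y_indices.min()) + offset_y,
                int(x_indices.max()) + offset_x, int(y_indices.max()) + offset_y)

    mask = np.zeros((8, 8), dtype=np.uint8)
    mask[2:5, 3:7] = 1
    print(mask_to_bbox(mask, offset_x=100, offset_y=50))  # (103, 52, 106, 54)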
@@ -0,0 +1,72 @@
+from dataclasses import dataclass
+
+import torch
+
+from ultralytics.engine.results import Results
+
+from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+
+from autodistill.detection import CaptionOntology
+
+from coralnet_toolbox.Transformers.Models.QtBase import QtBaseModel
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# Classes
+# ----------------------------------------------------------------------------------------------------------------------
+
+
+@dataclass
+class GroundingDINOModel(QtBaseModel):
+    def __init__(self, ontology: CaptionOntology, model="SwinB", device: str = "cpu"):
+        super().__init__(ontology, device)
+
+        if model == "SwinB":
+            model_name = "IDEA-Research/grounding-dino-base"
+        else:
+            model_name = "IDEA-Research/grounding-dino-tiny"
+
+        self.processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_name).to(self.device)
+
+    def _process_predictions(self, image, texts, confidence):
+        """Process model predictions for a single image."""
+        inputs = self.processor(text=texts, images=image, return_tensors="pt").to(self.device)
+        outputs = self.model(**inputs)
+
+        results_processed = self.processor.post_process_grounded_object_detection(
+            outputs,
+            inputs.input_ids,
+            threshold=confidence,
+            target_sizes=[image.shape[:2]],
+        )[0]
+
+        boxes = results_processed["boxes"]
+        scores = results_processed["scores"]
+
+        # If no objects are detected, return an empty list to match the original behavior.
+        if scores.nelement() == 0:
+            return []
+
+        # Per original logic, assign all detections to class_id 0.
+        # TODO: We are only supporting a single class right now
+        class_ids = torch.zeros(scores.shape[0], 1, device=self.device)
+
+        # Combine boxes, scores, and class_ids into the (N, 6) tensor format
+        # required by the Results object: [x1, y1, x2, y2, confidence, class_id]
+        combined_data = torch.cat([
+            boxes,
+            scores.unsqueeze(1),
+            class_ids
+        ], dim=1)
+
+        # Create the dictionary mapping class indices to class names.
+        names = {idx: text for idx, text in enumerate(self.ontology.classes())}
+
+        # Create the Results object with a DETACHED tensor
+        result = Results(orig_img=image,
+                         path=None,
+                         names=names,
+                         boxes=combined_data.detach().cpu())
+
+        return result
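Note: this wrapper is driven through the predict() entry point defined on QtBaseModel (see the QtBase.py hunk below). A hedged usage sketch; the prompt/label pair and the image path are invented for illustration:

    from autodistill.detection import CaptionOntology
    from coralnet_toolbox.Transformers.Models.GroundingDINO import GroundingDINOModel

    # Single prompt mapped to a single class; multi-class support is still a TODO per the comment above.
    ontology = CaptionOntology({"a coral colony": "coral"})
    model = GroundingDINOModel(ontology, model="SwinB", device="cpu")

    # predict() accepts an image path, a list of paths, or numpy arrays (see QtBase below).
    results = model.predict("example_reef_tile.jpg", confidence=0.3)  # hypothetical file
    for r in results:
        print(r.boxes.xyxy, r.boxes.conf)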
@@ -0,0 +1,72 @@
+from dataclasses import dataclass
+
+import torch
+
+from ultralytics.engine.results import Results
+
+from transformers import OwlViTForObjectDetection, OwlViTProcessor
+
+from autodistill.detection import CaptionOntology
+
+from coralnet_toolbox.Transformers.Models.QtBase import QtBaseModel
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# Classes
+# ----------------------------------------------------------------------------------------------------------------------
+
+
+@dataclass
+class OWLViTModel(QtBaseModel):
+    def __init__(self, ontology: CaptionOntology, device: str = "cpu"):
+        super().__init__(ontology, device)
+
+        model_name = "google/owlvit-base-patch32"
+        self.processor = OwlViTProcessor.from_pretrained(model_name, use_fast=True)
+        self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device)
+
+    def _process_predictions(self, image, texts, confidence):
+        """
+        Process model predictions for a single image, converting directly
+        to an Ultralytics Results object without an intermediate Supervision object.
+        """
+        inputs = self.processor(text=texts, images=image, return_tensors="pt").to(self.device)
+        outputs = self.model(**inputs)
+
+        # Post-process the outputs to get detections.
+        # The confidence threshold is applied during this step.
+        results_processed = self.processor.post_process_object_detection(
+            outputs,
+            threshold=confidence,
+            target_sizes=[image.shape[:2]]
+        )[0]
+
+        boxes = results_processed["boxes"]
+        scores = results_processed["scores"]
+
+        # If no objects are detected, return an empty list to match the original behavior.
+        if scores.nelement() == 0:
+            return []
+
+        # Per original logic, assign all detections to class_id 0.
+        # TODO: We are only supporting a single class right now
+        class_ids = torch.zeros(scores.shape[0], 1, device=self.device)
+
+        # Combine boxes, scores, and class_ids into the (N, 6) tensor format
+        # required by the Results object: [x1, y1, x2, y2, confidence, class_id]
+        combined_data = torch.cat([
+            boxes,
+            scores.unsqueeze(1),
+            class_ids
+        ], dim=1)
+
+        # Create the dictionary mapping class indices to class names.
+        names = {idx: text for idx, text in enumerate(self.ontology.classes())}
+
+        # Create the Results object with a DETACHED tensor
+        result = Results(orig_img=image,
+                         path=None,
+                         names=names,
+                         boxes=combined_data.detach().cpu())
+
+        return result
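Note: all three wrappers hand detections to Ultralytics in the same (N, 6) layout shown above. A small sketch of just that packing step, using arbitrary dummy values:

    import numpy as np
    import torch
    from ultralytics.engine.results import Results

    # Two fake detections in [x1, y1, x2, y2] order
    boxes = torch.tensor([[10.0, 12.0, 80.0, 90.0],
                          [30.0, 35.0, 60.0, 70.0]])
    scores = torch.tensor([0.91, 0.42])
    class_ids = torch.zeros(scores.shape[0], 1)
    combined = torch.cat([boxes, scores.unsqueeze(1), class_ids], dim=1)  # shape (2, 6)

    dummy_image = np.zeros((128, 128, 3), dtype=np.uint8)
    result = Results(orig_img=dummy_image, path=None, names={0: "coral"}, boxes=combined)
    print(result.boxes.conf)  # tensor([0.9100, 0.4200])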
@@ -0,0 +1,68 @@
+from dataclasses import dataclass
+
+import torch
+
+from ultralytics.engine.results import Results
+
+from transformers import AutoProcessor, OmDetTurboForObjectDetection
+
+from autodistill.detection import CaptionOntology
+
+from coralnet_toolbox.Transformers.Models.QtBase import QtBaseModel
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# Classes
+# ----------------------------------------------------------------------------------------------------------------------
+
+
+@dataclass
+class OmDetTurboModel(QtBaseModel):
+    def __init__(self, ontology: CaptionOntology, device: str = "cpu"):
+        super().__init__(ontology, device)
+
+        model_name = "omlab/omdet-turbo-swin-tiny-hf"
+        self.processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+        self.model = OmDetTurboForObjectDetection.from_pretrained(model_name).to(self.device)
+
+    def _process_predictions(self, image, texts, confidence):
+        """Process model predictions for a single image."""
+        inputs = self.processor(text=texts, images=image, return_tensors="pt").to(self.device)
+        outputs = self.model(**inputs)
+
+        results_processed = self.processor.post_process_grounded_object_detection(
+            outputs,
+            threshold=confidence,
+            target_sizes=[image.shape[:2]],
+            text_labels=texts,
+        )[0]
+
+        boxes = results_processed["boxes"]
+        scores = results_processed["scores"]
+
+        # If no objects are detected, return an empty list to match the original behavior.
+        if scores.nelement() == 0:
+            return []
+
+        # Per original logic, assign all detections to class_id 0.
+        # TODO: We are only supporting a single class right now
+        class_ids = torch.zeros(scores.shape[0], 1, device=self.device)
+
+        # Combine boxes, scores, and class_ids into the (N, 6) tensor format
+        # required by the Results object: [x1, y1, x2, y2, confidence, class_id]
+        combined_data = torch.cat([
+            boxes,
+            scores.unsqueeze(1),
+            class_ids
+        ], dim=1)
+
+        # Create the dictionary mapping class indices to class names.
+        names = {idx: text for idx, text in enumerate(self.ontology.classes())}
+
+        # Create the Results object with a DETACHED tensor
+        result = Results(orig_img=image,
+                         path=None,
+                         names=names,
+                         boxes=combined_data.detach().cpu())
+
+        return result
@@ -0,0 +1,120 @@
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+import cv2
+import numpy as np
+
+from ultralytics.engine.results import Results
+
+from autodistill.detection import CaptionOntology, DetectionBaseModel
+from autodistill.helpers import load_image
+
+from coralnet_toolbox.Results import CombineResults
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# Classes
+# ----------------------------------------------------------------------------------------------------------------------
+
+
+@dataclass
+class QtBaseModel(DetectionBaseModel, ABC):
+    """
+    Base class for Transformer foundation models that provides common functionality for
+    handling inputs, processing image data, and formatting detection results.
+    """
+    ontology: CaptionOntology
+
+    def __init__(self, ontology: CaptionOntology, device: str = "cpu"):
+        """
+        Initialize the base model with ontology and device.
+
+        Args:
+            ontology: The CaptionOntology containing class labels
+            device: The compute device (cpu, cuda, etc.)
+        """
+        self.ontology = ontology
+        self.device = device
+        self.processor = None
+        self.model = None
+
+    def _normalize_input(self, input) -> list[np.ndarray]:
+        """
+        Normalizes various input types into a list of images in CV2 (BGR) format.
+
+        Args:
+            input: Can be an image path, a list of paths, a numpy array, or a list of numpy arrays.
+
+        Returns:
+            A list of images, each as a numpy array in CV2 (BGR) format.
+        """
+        images = []
+        if isinstance(input, str):
+            # Single image path
+            images = [load_image(input, return_format="cv2")]
+        elif isinstance(input, np.ndarray):
+            # Single image numpy array (RGB) or a batch of images (NHWC, RGB)
+            if input.ndim == 3:
+                images = [cv2.cvtColor(input, cv2.COLOR_RGB2BGR)]
+            elif input.ndim == 4:
+                images = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in input]
+            else:
+                raise ValueError(f"Unsupported numpy array dimensions: {input.ndim}")
+        elif isinstance(input, list):
+            if all(isinstance(i, str) for i in input):
+                # List of image paths
+                images = [load_image(path, return_format="cv2") for path in input]
+            elif all(isinstance(i, np.ndarray) for i in input):
+                # List of image arrays (RGB)
+                images = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in input]
+            else:
+                raise ValueError("A list input must contain either all image paths or all numpy arrays.")
+        else:
+            raise TypeError(f"Unsupported input type: {type(input)}")
+
+        return images
+
+    @abstractmethod
+    def _process_predictions(self, image: np.ndarray, texts: list[str], confidence: float) -> Results:
+        """
+        Process model predictions for a single image.
+
+        Args:
+            image: The input image in CV2 (BGR) format.
+            texts: The text prompts from the ontology.
+            confidence: Confidence threshold.
+
+        Returns:
+            A single Ultralytics Results object, which may be empty if no detections are found.
+        """
+        pass
+
+    def predict(self, inputs, confidence=0.01) -> list[Results]:
+        """
+        Run inference on input images.
+
+        Args:
+            inputs: Can be an image path, a list of image paths, a numpy array, or a list of numpy arrays.
+            confidence: Detection confidence threshold.
+
+        Returns:
+            A flat list of Ultralytics Results objects, one for each input image.
+        """
+        # Step 1: Normalize the input into a consistent list of images
+        normalized_inputs = self._normalize_input(inputs)
+
+        # Step 2: Prepare for inference
+        results = []
+        texts = self.ontology.prompts()
+
+        # Step 3: Loop through images and process predictions
+        for normalized_input in normalized_inputs:
+            result = self._process_predictions(normalized_input, texts, confidence)
+            if result:
+                results.append(result)
+
+        if len(results):
+            # Combine the results into one, then wrap in a list
+            results = CombineResults().combine_results(results)
+
+        return [results] if results else []
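Note: _normalize_input accepts a path, a list of paths, a single RGB array, or an NHWC batch, and always hands BGR images to the subclass; predict() then merges per-image results via CombineResults. A hedged sketch of batch inference (OWLViTModel and the random frames are just examples):

    import numpy as np
    from autodistill.detection import CaptionOntology
    from coralnet_toolbox.Transformers.Models.OWLViT import OWLViTModel

    ontology = CaptionOntology({"a sea urchin": "urchin"})
    model = OWLViTModel(ontology, device="cpu")

    # A batch of two RGB frames (NHWC); each is converted to BGR before inference.
    frames = np.random.randint(0, 255, size=(2, 480, 640, 3), dtype=np.uint8)

    # Non-empty detections are combined into a single Results object wrapped in a list;
    # a run with no detections returns [].
    results = model.predict(frames, confidence=0.2)
    print(len(results))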
@@ -1,4 +1,4 @@
-# coralnet_toolbox/AutoDistill/Models/__init__.py
+# coralnet_toolbox/Transformers/Models/__init__.py
 
 from .GroundingDINO import GroundingDINOModel
 from .OWLViT import OWLViTModel
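Note: since the AutoDistill package was renamed to Transformers in this release, downstream imports need the new path. For example, assuming code imported a model class directly:

    # 0.0.74 and earlier
    # from coralnet_toolbox.AutoDistill.Models.GroundingDINO import GroundingDINOModel

    # 0.0.76
    from coralnet_toolbox.Transformers.Models.GroundingDINO import GroundingDINOModel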