coralnet-toolbox 0.0.74__py2.py3-none-any.whl → 0.0.76__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coralnet_toolbox/Annotations/QtPolygonAnnotation.py +57 -12
- coralnet_toolbox/Annotations/QtRectangleAnnotation.py +44 -14
- coralnet_toolbox/Explorer/QtDataItem.py +52 -22
- coralnet_toolbox/Explorer/QtExplorer.py +277 -1600
- coralnet_toolbox/Explorer/QtSettingsWidgets.py +101 -15
- coralnet_toolbox/Explorer/QtViewers.py +1568 -0
- coralnet_toolbox/Explorer/transformer_models.py +70 -0
- coralnet_toolbox/Explorer/yolo_models.py +112 -0
- coralnet_toolbox/IO/QtExportMaskAnnotations.py +538 -403
- coralnet_toolbox/Icons/system_monitor.png +0 -0
- coralnet_toolbox/MachineLearning/ImportDataset/QtBase.py +239 -147
- coralnet_toolbox/MachineLearning/VideoInference/YOLO3D/run.py +102 -16
- coralnet_toolbox/QtAnnotationWindow.py +16 -10
- coralnet_toolbox/QtEventFilter.py +4 -4
- coralnet_toolbox/QtImageWindow.py +3 -7
- coralnet_toolbox/QtMainWindow.py +104 -64
- coralnet_toolbox/QtProgressBar.py +1 -0
- coralnet_toolbox/QtSystemMonitor.py +370 -0
- coralnet_toolbox/Rasters/RasterTableModel.py +20 -0
- coralnet_toolbox/Results/ConvertResults.py +14 -8
- coralnet_toolbox/Results/ResultsProcessor.py +3 -2
- coralnet_toolbox/SAM/QtDeployGenerator.py +2 -5
- coralnet_toolbox/SAM/QtDeployPredictor.py +11 -3
- coralnet_toolbox/SeeAnything/QtDeployGenerator.py +146 -116
- coralnet_toolbox/SeeAnything/QtDeployPredictor.py +55 -9
- coralnet_toolbox/Tile/QtTileBatchInference.py +4 -4
- coralnet_toolbox/Tools/QtPolygonTool.py +42 -3
- coralnet_toolbox/Tools/QtRectangleTool.py +30 -0
- coralnet_toolbox/Tools/QtSAMTool.py +140 -91
- coralnet_toolbox/Transformers/Models/GroundingDINO.py +72 -0
- coralnet_toolbox/Transformers/Models/OWLViT.py +72 -0
- coralnet_toolbox/Transformers/Models/OmDetTurbo.py +68 -0
- coralnet_toolbox/Transformers/Models/QtBase.py +120 -0
- coralnet_toolbox/{AutoDistill → Transformers}/Models/__init__.py +1 -1
- coralnet_toolbox/{AutoDistill → Transformers}/QtBatchInference.py +15 -15
- coralnet_toolbox/{AutoDistill → Transformers}/QtDeployModel.py +18 -16
- coralnet_toolbox/{AutoDistill → Transformers}/__init__.py +1 -1
- coralnet_toolbox/__init__.py +1 -1
- coralnet_toolbox/utilities.py +21 -15
- {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/METADATA +13 -10
- {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/RECORD +45 -40
- coralnet_toolbox/AutoDistill/Models/GroundingDINO.py +0 -81
- coralnet_toolbox/AutoDistill/Models/OWLViT.py +0 -76
- coralnet_toolbox/AutoDistill/Models/OmDetTurbo.py +0 -75
- coralnet_toolbox/AutoDistill/Models/QtBase.py +0 -112
- {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/WHEEL +0 -0
- {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/entry_points.txt +0 -0
- {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/licenses/LICENSE.txt +0 -0
- {coralnet_toolbox-0.0.74.dist-info → coralnet_toolbox-0.0.76.dist-info}/top_level.txt +0 -0
coralnet_toolbox/Tools/QtRectangleTool.py
@@ -113,6 +113,36 @@ class RectangleTool(Tool):
         # Ensure top_left and bottom_right are correctly calculated
         top_left = QPointF(min(self.start_point.x(), end_point.x()), min(self.start_point.y(), end_point.y()))
         bottom_right = QPointF(max(self.start_point.x(), end_point.x()), max(self.start_point.y(), end_point.y()))
+
+        # Calculate width and height of the rectangle
+        width = bottom_right.x() - top_left.x()
+        height = bottom_right.y() - top_left.y()
+
+        # Define minimum dimensions for a valid rectangle (e.g., 3x3 pixels)
+        MIN_DIMENSION = 3.0
+
+        # If rectangle is too small and we're finalizing it, enforce minimum size
+        if finished and (width < MIN_DIMENSION or height < MIN_DIMENSION):
+            if width < MIN_DIMENSION:
+                # Expand width while maintaining center
+                center_x = (top_left.x() + bottom_right.x()) / 2
+                top_left.setX(center_x - MIN_DIMENSION / 2)
+                bottom_right.setX(center_x + MIN_DIMENSION / 2)
+
+            if height < MIN_DIMENSION:
+                # Expand height while maintaining center
+                center_y = (top_left.y() + bottom_right.y()) / 2
+                top_left.setY(center_y - MIN_DIMENSION / 2)
+                bottom_right.setY(center_y + MIN_DIMENSION / 2)
+
+            # Show a message if we had to adjust a very small rectangle
+            if width < 1 or height < 1:
+                QMessageBox.information(
+                    self.annotation_window,
+                    "Rectangle Adjusted",
+                    f"The rectangle was too small and has been adjusted to a minimum size of "
+                    f"{MIN_DIMENSION}x{MIN_DIMENSION} pixels."
+                )

         # Create the rectangle annotation
         annotation = RectangleAnnotation(top_left,
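For reference, the added QtRectangleTool logic amounts to clamping a finalized drag rectangle to a 3x3-pixel minimum while keeping its center fixed. A minimal standalone sketch of that rule (illustrative only; clamp_min_size is a hypothetical name, not toolbox code):

from PyQt5.QtCore import QPointF

def clamp_min_size(top_left: QPointF, bottom_right: QPointF, min_dim: float = 3.0):
    """Expand the box about its center so each side is at least min_dim pixels."""
    if bottom_right.x() - top_left.x() < min_dim:
        center_x = (top_left.x() + bottom_right.x()) / 2
        top_left.setX(center_x - min_dim / 2)
        bottom_right.setX(center_x + min_dim / 2)
    if bottom_right.y() - top_left.y() < min_dim:
        center_y = (top_left.y() + bottom_right.y()) / 2
        top_left.setY(center_y - min_dim / 2)
        bottom_right.setY(center_y + min_dim / 2)
    return top_left, bottom_right

The diff applies the same per-axis expansion before constructing the RectangleAnnotation, and only warns the user when the original rectangle was under 1 pixel on a side.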
coralnet_toolbox/Tools/QtSAMTool.py
@@ -6,6 +6,8 @@ from PyQt5.QtGui import QMouseEvent, QKeyEvent, QPen, QColor, QBrush, QPainterPa
 from PyQt5.QtWidgets import QMessageBox, QGraphicsEllipseItem, QGraphicsRectItem, QGraphicsPathItem, QApplication

 from coralnet_toolbox.Tools.QtTool import Tool
+
+from coralnet_toolbox.Annotations.QtRectangleAnnotation import RectangleAnnotation
 from coralnet_toolbox.Annotations.QtPolygonAnnotation import PolygonAnnotation

 from coralnet_toolbox.QtWorkArea import WorkArea
@@ -374,45 +376,21 @@ class SAMTool(Tool):
         top1_index = np.argmax(results.boxes.conf)
         mask_tensor = results[top1_index].masks.data

-        # Check
+        # Check which output type is selected and get allow_holes settings
+        output_type = self.sam_dialog.get_output_type()
         allow_holes = self.sam_dialog.get_allow_holes()
-
-        #
-
-
-
-
+
+        # Create annotation using the helper method
+        self.temp_annotation = self.create_annotation_from_mask(
+            mask_tensor,
+            output_type,
+            allow_holes
+        )
+
+        if not self.temp_annotation:
             QApplication.restoreOverrideCursor()
             return
-
-        # --- Process and Clean the Polygon Points ---
-        working_area_top_left = self.working_area.rect.topLeft()
-        offset_x, offset_y = working_area_top_left.x(), working_area_top_left.y()
-
-        # Simplify, offset, and convert the exterior points
-        simplified_exterior = simplify_polygon(exterior_coords, 0.1)
-        self.points = [QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_exterior]
-
-        # Simplify, offset, and convert each hole only if allowed
-        final_holes = []
-        if allow_holes:
-            for hole_coords in holes_coords_list:
-                if len(hole_coords) >= 3:  # Ensure holes are also valid polygons
-                    simplified_hole = simplify_polygon(hole_coords, 0.1)
-                    final_holes.append([QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_hole])
-
-        # Create the temporary annotation, now with holes (or not)
-        self.temp_annotation = PolygonAnnotation(
-            points=self.points,
-            holes=final_holes,
-            short_label_code=self.annotation_window.selected_label.short_label_code,
-            long_label_code=self.annotation_window.selected_label.long_label_code,
-            color=self.annotation_window.selected_label.color,
-            image_path=self.annotation_window.current_image_path,
-            label_id=self.annotation_window.selected_label.id,
-            transparency=self.main_window.label_window.active_label.transparency
-        )
-
+
         # Create the graphics item for the temporary annotation
         self.temp_annotation.create_graphics_item(self.annotation_window.scene)

@@ -616,17 +594,31 @@ class SAMTool(Tool):
         elif self.has_active_prompts:
             # Create the final annotation
             if self.temp_annotation:
-                #
-
-
-
-
-
-
-
-
-
-
+                # Check if temp_annotation is a PolygonAnnotation or RectangleAnnotation
+                if isinstance(self.temp_annotation, PolygonAnnotation):
+                    # For polygon annotations, use the points and holes
+                    final_annotation = PolygonAnnotation(
+                        self.points,
+                        self.temp_annotation.label.short_label_code,
+                        self.temp_annotation.label.long_label_code,
+                        self.temp_annotation.label.color,
+                        self.temp_annotation.image_path,
+                        self.temp_annotation.label.id,
+                        self.temp_annotation.label.transparency,
+                        holes=self.temp_annotation.holes
+                    )
+                elif isinstance(self.temp_annotation, RectangleAnnotation):
+                    # For rectangle annotations, use the top_left and bottom_right
+                    final_annotation = RectangleAnnotation(
+                        top_left=self.temp_annotation.top_left,
+                        bottom_right=self.temp_annotation.bottom_right,
+                        short_label_code=self.temp_annotation.label.short_label_code,
+                        long_label_code=self.temp_annotation.label.long_label_code,
+                        color=self.temp_annotation.label.color,
+                        image_path=self.temp_annotation.image_path,
+                        label_id=self.temp_annotation.label.id,
+                        transparency=self.temp_annotation.label.transparency
+                    )

                 # Copy confidence data
                 final_annotation.update_machine_confidence(
@@ -740,54 +732,23 @@ class SAMTool(Tool):
         top1_index = np.argmax(results.boxes.conf)
         mask_tensor = results[top1_index].masks.data

-        # Check
+        # Check which output type is selected and get allow_holes settings
+        output_type = self.sam_dialog.get_output_type()
         allow_holes = self.sam_dialog.get_allow_holes()
-
-        #
-
-
-
-
-
-
-
-        # --- Process and Clean the Polygon Points ---
-        working_area_top_left = self.working_area.rect.topLeft()
-        offset_x, offset_y = working_area_top_left.x(), working_area_top_left.y()
-
-        # Simplify, offset, and convert the exterior points
-        simplified_exterior = simplify_polygon(exterior_coords, 0.1)
-        self.points = [QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_exterior]
-
-        # Simplify, offset, and convert each hole only if allowed
-        final_holes = []
-        if allow_holes:
-            for hole_coords in holes_coords_list:
-                if len(hole_coords) >= 3:
-                    simplified_hole = simplify_polygon(hole_coords, 0.1)
-                    final_holes.append([QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_hole])
-
-        # Require at least 3 points for valid polygon
-        if len(self.points) < 3:
+
+        # Create annotation using the helper method
+        annotation = self.create_annotation_from_mask(
+            mask_tensor,
+            output_type,
+            allow_holes
+        )
+
+        if not annotation:
             QApplication.restoreOverrideCursor()
             return None

-        #
-        confidence = results.boxes.conf[top1_index]
-
-        # Create final annotation, now passing the holes argument
-        annotation = PolygonAnnotation(
-            points=self.points,
-            holes=final_holes,
-            short_label_code=self.annotation_window.selected_label.short_label_code,
-            long_label_code=self.annotation_window.selected_label.long_label_code,
-            color=self.annotation_window.selected_label.color,
-            image_path=self.annotation_window.current_image_path,
-            label_id=self.annotation_window.selected_label.id,
-            transparency=self.main_window.label_window.active_label.transparency
-        )
-
-        # Update confidence
+        # Update confidence - make sure to extract confidence from results
+        confidence = float(results.boxes.conf[top1_index])
         annotation.update_machine_confidence({self.annotation_window.selected_label: confidence})

         # Create cropped image
@@ -799,6 +760,94 @@ class SAMTool(Tool):

         return annotation

+    def create_annotation_from_mask(self, mask_tensor, output_type, allow_holes=True):
+        """
+        Create annotation (Rectangle or Polygon) from a mask tensor.
+
+        Args:
+            mask_tensor: The tensor containing the mask data
+            output_type (str): "Rectangle" or "Polygon"
+            allow_holes (bool): Whether to include holes in polygon annotations
+
+        Returns:
+            Annotation object or None if creation fails
+        """
+        if not self.working_area:
+            return None
+
+        if output_type == "Rectangle":
+            # For rectangle output, just get the bounding box of the mask
+            # Find the bounding rectangle of the mask
+            y_indices, x_indices = np.where(mask_tensor.cpu().numpy()[0] > 0)
+            if len(y_indices) == 0 or len(x_indices) == 0:
+                return None
+
+            # Get the min/max coordinates
+            min_x, max_x = np.min(x_indices), np.max(x_indices)
+            min_y, max_y = np.min(y_indices), np.max(y_indices)
+
+            # Apply the offset from working area
+            working_area_top_left = self.working_area.rect.topLeft()
+            offset_x, offset_y = working_area_top_left.x(), working_area_top_left.y()
+
+            top_left = QPointF(min_x + offset_x, min_y + offset_y)
+            bottom_right = QPointF(max_x + offset_x, max_y + offset_y)
+
+            # Create a rectangle annotation
+            annotation = RectangleAnnotation(
+                top_left=top_left,
+                bottom_right=bottom_right,
+                short_label_code=self.annotation_window.selected_label.short_label_code,
+                long_label_code=self.annotation_window.selected_label.long_label_code,
+                color=self.annotation_window.selected_label.color,
+                image_path=self.annotation_window.current_image_path,
+                label_id=self.annotation_window.selected_label.id,
+                transparency=self.main_window.label_window.active_label.transparency
+            )
+        else:
+            # Original polygon code
+            # Polygonize the mask using the new method to get the exterior and holes
+            exterior_coords, holes_coords_list = polygonize_mask_with_holes(mask_tensor)
+
+            # Safety check for an empty result
+            if not exterior_coords:
+                return None
+
+            # --- Process and Clean the Polygon Points ---
+            working_area_top_left = self.working_area.rect.topLeft()
+            offset_x, offset_y = working_area_top_left.x(), working_area_top_left.y()
+
+            # Simplify, offset, and convert the exterior points
+            simplified_exterior = simplify_polygon(exterior_coords, 0.1)
+            self.points = [QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_exterior]
+
+            # Simplify, offset, and convert each hole only if allowed
+            final_holes = []
+            if allow_holes:
+                for hole_coords in holes_coords_list:
+                    simplified_hole = simplify_polygon(hole_coords, 0.1)
+                    if len(simplified_hole) >= 3:
+                        hole_points = [QPointF(p[0] + offset_x, p[1] + offset_y) for p in simplified_hole]
+                        final_holes.append(hole_points)
+
+            # Require at least 3 points for valid polygon
+            if len(self.points) < 3:
+                return None
+
+            # Create final annotation, now passing the holes argument
+            annotation = PolygonAnnotation(
+                points=self.points,
+                holes=final_holes,
+                short_label_code=self.annotation_window.selected_label.short_label_code,
+                long_label_code=self.annotation_window.selected_label.long_label_code,
+                color=self.annotation_window.selected_label.color,
+                image_path=self.annotation_window.current_image_path,
+                label_id=self.annotation_window.selected_label.id,
+                transparency=self.main_window.label_window.active_label.transparency
+            )
+
+        return annotation
+
     def cancel_working_area(self):
         """
         Cancel the working area and clean up all associated resources.
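The Rectangle branch of create_annotation_from_mask reduces to taking the bounding box of the mask's nonzero pixels and shifting it by the working-area origin. A minimal NumPy sketch of that step (illustrative names only, no toolbox dependencies):

import numpy as np

def mask_to_bbox(mask, offset_x=0.0, offset_y=0.0):
    """Return (x1, y1, x2, y2) of a binary mask's nonzero region, shifted by an offset, or None if empty."""
    y_indices, x_indices = np.where(mask > 0)
    if len(y_indices) == 0 or len(x_indices) == 0:
        return None
    return (float(x_indices.min()) + offset_x, float(y_indices.min()) + offset_y,
            float(x_indices.max()) + offset_x, float(y_indices.max()) + offset_y)

# Example: a 2x2 blob at rows 3-4, cols 5-6, inside a work area whose top-left is (100, 200)
mask = np.zeros((10, 10), dtype=np.uint8)
mask[3:5, 5:7] = 1
print(mask_to_bbox(mask, offset_x=100, offset_y=200))  # (105.0, 203.0, 106.0, 204.0)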
coralnet_toolbox/Transformers/Models/GroundingDINO.py
@@ -0,0 +1,72 @@
+from dataclasses import dataclass
+
+import torch
+
+from ultralytics.engine.results import Results
+
+from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+
+from autodistill.detection import CaptionOntology
+
+from coralnet_toolbox.Transformers.Models.QtBase import QtBaseModel
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# Classes
+# ----------------------------------------------------------------------------------------------------------------------
+
+
+@dataclass
+class GroundingDINOModel(QtBaseModel):
+    def __init__(self, ontology: CaptionOntology, model="SwinB", device: str = "cpu"):
+        super().__init__(ontology, device)
+
+        if model == "SwinB":
+            model_name = "IDEA-Research/grounding-dino-base"
+        else:
+            model_name = "IDEA-Research/grounding-dino-tiny"
+
+        self.processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_name).to(self.device)
+
+    def _process_predictions(self, image, texts, confidence):
+        """Process model predictions for a single image."""
+        inputs = self.processor(text=texts, images=image, return_tensors="pt").to(self.device)
+        outputs = self.model(**inputs)
+
+        results_processed = self.processor.post_process_grounded_object_detection(
+            outputs,
+            inputs.input_ids,
+            threshold=confidence,
+            target_sizes=[image.shape[:2]],
+        )[0]
+
+        boxes = results_processed["boxes"]
+        scores = results_processed["scores"]
+
+        # If no objects are detected, return an empty list to match the original behavior.
+        if scores.nelement() == 0:
+            return []
+
+        # Per original logic, assign all detections to class_id 0.
+        # TODO: We are only supporting a single class right now
+        class_ids = torch.zeros(scores.shape[0], 1, device=self.device)
+
+        # Combine boxes, scores, and class_ids into the (N, 6) tensor format
+        # required by the Results object: [x1, y1, x2, y2, confidence, class_id]
+        combined_data = torch.cat([
+            boxes,
+            scores.unsqueeze(1),
+            class_ids
+        ], dim=1)
+
+        # Create the dictionary mapping class indices to class names.
+        names = {idx: text for idx, text in enumerate(self.ontology.classes())}
+
+        # Create the Results object with a DETACHED tensor
+        result = Results(orig_img=image,
+                         path=None,
+                         names=names,
+                         boxes=combined_data.detach().cpu())
+
+        return result
coralnet_toolbox/Transformers/Models/OWLViT.py
@@ -0,0 +1,72 @@
+from dataclasses import dataclass
+
+import torch
+
+from ultralytics.engine.results import Results
+
+from transformers import OwlViTForObjectDetection, OwlViTProcessor
+
+from autodistill.detection import CaptionOntology
+
+from coralnet_toolbox.Transformers.Models.QtBase import QtBaseModel
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# Classes
+# ----------------------------------------------------------------------------------------------------------------------
+
+
+@dataclass
+class OWLViTModel(QtBaseModel):
+    def __init__(self, ontology: CaptionOntology, device: str = "cpu"):
+        super().__init__(ontology, device)
+
+        model_name = "google/owlvit-base-patch32"
+        self.processor = OwlViTProcessor.from_pretrained(model_name, use_fast=True)
+        self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device)
+
+    def _process_predictions(self, image, texts, confidence):
+        """
+        Process model predictions for a single image, converting directly
+        to an Ultralytics Results object without an intermediate Supervision object.
+        """
+        inputs = self.processor(text=texts, images=image, return_tensors="pt").to(self.device)
+        outputs = self.model(**inputs)
+
+        # Post-process the outputs to get detections.
+        # The confidence threshold is applied during this step.
+        results_processed = self.processor.post_process_object_detection(
+            outputs,
+            threshold=confidence,
+            target_sizes=[image.shape[:2]]
+        )[0]
+
+        boxes = results_processed["boxes"]
+        scores = results_processed["scores"]
+
+        # If no objects are detected, return an empty list to match the original behavior.
+        if scores.nelement() == 0:
+            return []
+
+        # Per original logic, assign all detections to class_id 0.
+        # TODO: We are only supporting a single class right now
+        class_ids = torch.zeros(scores.shape[0], 1, device=self.device)
+
+        # Combine boxes, scores, and class_ids into the (N, 6) tensor format
+        # required by the Results object: [x1, y1, x2, y2, confidence, class_id]
+        combined_data = torch.cat([
+            boxes,
+            scores.unsqueeze(1),
+            class_ids
+        ], dim=1)
+
+        # Create the dictionary mapping class indices to class names.
+        names = {idx: text for idx, text in enumerate(self.ontology.classes())}
+
+        # Create the Results object with a DETACHED tensor
+        result = Results(orig_img=image,
+                         path=None,
+                         names=names,
+                         boxes=combined_data.detach().cpu())
+
+        return result
coralnet_toolbox/Transformers/Models/OmDetTurbo.py
@@ -0,0 +1,68 @@
+from dataclasses import dataclass
+
+import torch
+
+from ultralytics.engine.results import Results
+
+from transformers import AutoProcessor, OmDetTurboForObjectDetection
+
+from autodistill.detection import CaptionOntology
+
+from coralnet_toolbox.Transformers.Models.QtBase import QtBaseModel
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# Classes
+# ----------------------------------------------------------------------------------------------------------------------
+
+
+@dataclass
+class OmDetTurboModel(QtBaseModel):
+    def __init__(self, ontology: CaptionOntology, device: str = "cpu"):
+        super().__init__(ontology, device)
+
+        model_name = "omlab/omdet-turbo-swin-tiny-hf"
+        self.processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
+        self.model = OmDetTurboForObjectDetection.from_pretrained(model_name).to(self.device)
+
+    def _process_predictions(self, image, texts, confidence):
+        """Process model predictions for a single image."""
+        inputs = self.processor(text=texts, images=image, return_tensors="pt").to(self.device)
+        outputs = self.model(**inputs)
+
+        results_processed = self.processor.post_process_grounded_object_detection(
+            outputs,
+            threshold=confidence,
+            target_sizes=[image.shape[:2]],
+            text_labels=texts,
+        )[0]
+
+        boxes = results_processed["boxes"]
+        scores = results_processed["scores"]
+
+        # If no objects are detected, return an empty list to match the original behavior.
+        if scores.nelement() == 0:
+            return []
+
+        # Per original logic, assign all detections to class_id 0.
+        # TODO: We are only supporting a single class right now
+        class_ids = torch.zeros(scores.shape[0], 1, device=self.device)
+
+        # Combine boxes, scores, and class_ids into the (N, 6) tensor format
+        # required by the Results object: [x1, y1, x2, y2, confidence, class_id]
+        combined_data = torch.cat([
+            boxes,
+            scores.unsqueeze(1),
+            class_ids
+        ], dim=1)
+
+        # Create the dictionary mapping class indices to class names.
+        names = {idx: text for idx, text in enumerate(self.ontology.classes())}
+
+        # Create the Results object with a DETACHED tensor
+        result = Results(orig_img=image,
+                         path=None,
+                         names=names,
+                         boxes=combined_data.detach().cpu())
+
+        return result
coralnet_toolbox/Transformers/Models/QtBase.py
@@ -0,0 +1,120 @@
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+import cv2
+import numpy as np
+
+from ultralytics.engine.results import Results
+
+from autodistill.detection import CaptionOntology, DetectionBaseModel
+from autodistill.helpers import load_image
+
+from coralnet_toolbox.Results import CombineResults
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+# Classes
+# ----------------------------------------------------------------------------------------------------------------------
+
+
+@dataclass
+class QtBaseModel(DetectionBaseModel, ABC):
+    """
+    Base class for Transformer foundation models that provides common functionality for
+    handling inputs, processing image data, and formatting detection results.
+    """
+    ontology: CaptionOntology
+
+    def __init__(self, ontology: CaptionOntology, device: str = "cpu"):
+        """
+        Initialize the base model with ontology and device.
+
+        Args:
+            ontology: The CaptionOntology containing class labels
+            device: The compute device (cpu, cuda, etc.)
+        """
+        self.ontology = ontology
+        self.device = device
+        self.processor = None
+        self.model = None
+
+    def _normalize_input(self, input) -> list[np.ndarray]:
+        """
+        Normalizes various input types into a list of images in CV2 (BGR) format.
+
+        Args:
+            input: Can be an image path, a list of paths, a numpy array, or a list of numpy arrays.
+
+        Returns:
+            A list of images, each as a numpy array in CV2 (BGR) format.
+        """
+        images = []
+        if isinstance(input, str):
+            # Single image path
+            images = [load_image(input, return_format="cv2")]
+        elif isinstance(input, np.ndarray):
+            # Single image numpy array (RGB) or a batch of images (NHWC, RGB)
+            if input.ndim == 3:
+                images = [cv2.cvtColor(input, cv2.COLOR_RGB2BGR)]
+            elif input.ndim == 4:
+                images = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in input]
+            else:
+                raise ValueError(f"Unsupported numpy array dimensions: {input.ndim}")
+        elif isinstance(input, list):
+            if all(isinstance(i, str) for i in input):
+                # List of image paths
+                images = [load_image(path, return_format="cv2") for path in input]
+            elif all(isinstance(i, np.ndarray) for i in input):
+                # List of image arrays (RGB)
+                images = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in input]
+            else:
+                raise ValueError("A list input must contain either all image paths or all numpy arrays.")
+        else:
+            raise TypeError(f"Unsupported input type: {type(input)}")
+
+        return images
+
+    @abstractmethod
+    def _process_predictions(self, image: np.ndarray, texts: list[str], confidence: float) -> Results:
+        """
+        Process model predictions for a single image.
+
+        Args:
+            image: The input image in CV2 (BGR) format.
+            texts: The text prompts from the ontology.
+            confidence: Confidence threshold.
+
+        Returns:
+            A single Ultralytics Results object, which may be empty if no detections are found.
+        """
+        pass
+
+    def predict(self, inputs, confidence=0.01) -> list[Results]:
+        """
+        Run inference on input images.
+
+        Args:
+            inputs: Can be an image path, a list of image paths, a numpy array, or a list of numpy arrays.
+            confidence: Detection confidence threshold.
+
+        Returns:
+            A flat list of Ultralytics Results objects, one for each input image.
+        """
+        # Step 1: Normalize the input into a consistent list of images
+        normalized_inputs = self._normalize_input(inputs)
+
+        # Step 2: Prepare for inference
+        results = []
+        texts = self.ontology.prompts()
+
+        # Step 3: Loop through images and process predictions
+        for normalized_input in normalized_inputs:
+            result = self._process_predictions(normalized_input, texts, confidence)
+            if result:
+                results.append(result)
+
+        if len(results):
+            # Combine the results into one, then wrap in a list
+            results = CombineResults().combine_results(results)
+
+        return [results] if results else []
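The three new Transformers wrappers (GroundingDINO, OWLViT, OmDetTurbo) all inherit the predict() entry point from QtBaseModel above. A rough usage sketch, assuming autodistill's CaptionOntology is built from a prompt-to-label mapping and the listed Hugging Face checkpoints can be downloaded; the image path is a placeholder:

from autodistill.detection import CaptionOntology

from coralnet_toolbox.Transformers.Models.GroundingDINO import GroundingDINOModel

# Map a text prompt to a single class label (the wrappers currently assign every detection to class 0).
ontology = CaptionOntology({"coral": "coral"})

# "SwinB" selects grounding-dino-base; any other value falls back to grounding-dino-tiny.
model = GroundingDINOModel(ontology, model="SwinB", device="cpu")

# predict() accepts an image path, a list of paths, or RGB numpy arrays, and returns a list
# containing the combined Ultralytics Results across all inputs (or an empty list).
results = model.predict("example_image.jpg", confidence=0.25)
for r in results:
    print(r.boxes)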