ultralytics-opencv-headless 8.3.246__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1026 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1578 -0
- ultralytics/engine/model.py +1124 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +974 -0
- ultralytics/engine/tuner.py +448 -0
- ultralytics/engine/validator.py +384 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +958 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +313 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +1006 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +315 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +501 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1563 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1045 -0
- ultralytics/utils/tal.py +403 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +440 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +160 -0
- ultralytics_opencv_headless-8.3.246.dist-info/METADATA +374 -0
- ultralytics_opencv_headless-8.3.246.dist-info/RECORD +298 -0
- ultralytics_opencv_headless-8.3.246.dist-info/WHEEL +5 -0
- ultralytics_opencv_headless-8.3.246.dist-info/entry_points.txt +3 -0
- ultralytics_opencv_headless-8.3.246.dist-info/licenses/LICENSE +661 -0
- ultralytics_opencv_headless-8.3.246.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
|
+
|
|
3
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
|
4
|
+
|
|
5
|
+
"""Various utility models."""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import math
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
from torch import Tensor, nn
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DotProductScoring(torch.nn.Module):
    """Score query features against a mean-pooled prompt embedding with a scaled dot product.

    The prompt tokens are mean-pooled over their non-padded positions. The pooled prompt and the query features
    (``hs``) are then linearly projected to ``d_proj`` dimensions, and their dot product is scaled by
    ``1 / sqrt(d_proj)``. The resulting logits are optionally clamped to ``[-clamp_max_val, clamp_max_val]``.
    """

    def __init__(
        self,
        d_model,
        d_proj,
        prompt_mlp=None,
        clamp_logits=True,
        clamp_max_val=12.0,
    ):
        """Initialize the DotProductScoring module.

        Args:
            d_model (int): Input feature dimension of the query features and of the (optionally MLP-processed) prompt.
            d_proj (int): Dimension of the shared projection space in which dot products are computed.
            prompt_mlp (torch.nn.Module | None): Optional module applied to the prompt tokens before pooling.
            clamp_logits (bool): Whether to clamp the output scores.
            clamp_max_val (float): Absolute bound applied to the scores when ``clamp_logits`` is True.
        """
        super().__init__()
        self.d_proj = d_proj
        assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None
        self.prompt_mlp = prompt_mlp  # an optional MLP projection for prompt
        self.prompt_proj = torch.nn.Linear(d_model, d_proj)
        self.hs_proj = torch.nn.Linear(d_model, d_proj)
        self.scale = float(1.0 / np.sqrt(d_proj))
        self.clamp_logits = clamp_logits
        # Stored unconditionally so that enabling ``clamp_logits`` after construction cannot raise AttributeError.
        self.clamp_max_val = clamp_max_val

    @staticmethod
    def mean_pool_text(prompt, prompt_mask):
        """Mean-pool the prompt embeddings over the valid (non-padded) tokens only.

        Args:
            prompt (torch.Tensor): Prompt tokens of shape (seq, bs, dim).
            prompt_mask (torch.Tensor): Boolean mask of shape (bs, seq) where True marks padding.

        Returns:
            (torch.Tensor): Pooled prompt of shape (bs, dim). Rows whose tokens are all padding pool to zeros.
        """
        # is_valid has shape (seq, bs, 1): 1.0 for real tokens, 0.0 for padding (prompt_mask is True on padding)
        is_valid = (~prompt_mask).to(prompt.dtype).permute(1, 0)[..., None]
        # num_valid has shape (bs, 1); clamping to 1 avoids 0/0 when a row is entirely padding
        num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
        # mean pool over all the valid tokens -- pooled_prompt has shape (bs, dim)
        pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid
        return pooled_prompt

    def forward(self, hs, prompt, prompt_mask):
        """Compute dot-product scores between hs and prompt.

        Args:
            hs (torch.Tensor): Query features of shape (num_layer, bs, num_query, d_model).
            prompt (torch.Tensor): Prompt tokens of shape (seq, bs, d_model).
            prompt_mask (torch.Tensor): Boolean mask of shape (bs, seq) where True marks padding.

        Returns:
            (torch.Tensor): Scores of shape (num_layer, bs, num_query, 1).
        """
        assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2

        # apply MLP on prompt if specified
        if self.prompt_mlp is not None:
            prompt = self.prompt_mlp(prompt.to(hs.dtype))

        # first, get the mean-pooled version of the prompt
        pooled_prompt = self.mean_pool_text(prompt, prompt_mask)

        # then, project pooled_prompt and hs to d_proj dimensions
        proj_pooled_prompt = self.prompt_proj(pooled_prompt)  # (bs, d_proj)
        proj_hs = self.hs_proj(hs)  # (num_layer, bs, num_query, d_proj)

        # finally, get dot-product scores of shape (num_layer, bs, num_query, 1); (bs, d_proj, 1) broadcasts over layers
        scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1))
        scores *= self.scale

        # clamp scores to a max value to avoid numerical issues in loss or matcher
        if self.clamp_logits:
            scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val)

        return scores
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch (LayerScale, introduced in CaiT, Touvron et al. 2021).

    The input is multiplied element-wise by a learnable vector ``gamma`` of length ``dim``, broadcast over the last
    dimension.
    """

    def __init__(
        self,
        dim: int,
        init_values: float | Tensor = 1e-5,
        inplace: bool = False,
    ) -> None:
        """Initialize the LayerScale module.

        Args:
            dim (int): Size of the last (channel) dimension to scale.
            init_values (float | Tensor): Initial value(s) of ``gamma``; multiplied with ``torch.ones(dim)``.
            inplace (bool): If True, scale the input tensor in place with ``mul_`` and return it.
        """
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Apply LayerScale to the input tensor.

        Args:
            x (Tensor): Input whose last dimension has size ``dim``.

        Returns:
            (Tensor): ``x * gamma``; when ``inplace`` is True this is ``x`` itself, modified in place.
        """
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class TransformerWrapper(nn.Module):
    """Container that pairs a transformer encoder with a transformer decoder.

    On construction every parameter with more than one dimension is re-initialized with Xavier-uniform, except
    parameters whose name contains ``box_embed``, ``query_embed`` or ``reference_points``.
    """

    def __init__(
        self,
        encoder,
        decoder,
        d_model: int,
        two_stage_type="none",  # ["none"] only for now
        pos_enc_at_input_dec=True,
    ):
        """Initialize the TransformerWrapper.

        Args:
            encoder (nn.Module): Transformer encoder.
            decoder (nn.Module | None): Transformer decoder; when given it must expose ``num_queries``.
            d_model (int): Model (hidden) dimension, stored as ``self.d_model``.
            two_stage_type (str): Two-stage mode; only ``"none"`` is accepted.
            pos_enc_at_input_dec (bool): Flag stored as ``self.pos_enc_at_input_dec``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.num_queries = None if decoder is None else decoder.num_queries
        self.pos_enc_at_input_dec = pos_enc_at_input_dec

        # Only single-stage decoding is implemented.
        assert two_stage_type in ["none"], f"unknown param {two_stage_type} of two_stage_type"
        self.two_stage_type = two_stage_type

        self._reset_parameters()
        self.d_model = d_model

    def _reset_parameters(self):
        """Xavier-initialize every matrix-shaped parameter that is not an embedding/reference-point parameter."""
        protected = ("box_embed", "query_embed", "reference_points")
        for name, param in self.named_parameters():
            if param.dim() <= 1:
                continue  # biases, norms and other vectors keep their initialization
            if any(key in name for key in protected):
                continue
            nn.init.xavier_uniform_(param)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def get_valid_ratio(mask):
    """Compute the valid ratio of height and width from the mask.

    The valid height is counted along the first column and the valid width along the first row, where False
    entries of ``mask`` are treated as valid.

    Args:
        mask (torch.Tensor): Boolean padding mask of shape (bs, H, W), True on padded pixels.

    Returns:
        (torch.Tensor): Float tensor of shape (bs, 2) holding ``(valid_w / W, valid_h / H)`` per image.
    """
    _, height, width = mask.shape
    not_padded = ~mask
    ratio_w = not_padded[:, 0, :].sum(1).float() / width
    ratio_h = not_padded[:, :, 0].sum(1).float() / height
    return torch.stack((ratio_w, ratio_h), dim=-1)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def gen_sineembed_for_position(pos_tensor: torch.Tensor, num_feats: int = 256):
    """Generate sinusoidal position embeddings for 2D or 4D coordinate tensors.

    This function creates sinusoidal embeddings using sine and cosine functions at different frequencies, similar to the
    positional encoding used in Transformer models. It supports both 2D position tensors (x, y) and 4D tensors (x, y, w,
    h) for bounding box coordinates. Each coordinate is embedded into ``num_feats // 2`` channels, and the per-coordinate
    embeddings are concatenated in the order (y, x) or (y, x, w, h).

    Args:
        pos_tensor (torch.Tensor): Input position tensor of shape (n_query, bs, 2) for 2D coordinates or (n_query, bs,
            4) for 4D coordinates (bounding boxes).
        num_feats (int): Number of feature dimensions for the output embedding. Must be divisible by 4 so that each
            coordinate's ``num_feats // 2`` channels split evenly into sine and cosine halves. Defaults to 256.

    Returns:
        (torch.Tensor): Sinusoidal position embeddings of shape (n_query, bs, num_feats) for 2D input or (n_query, bs,
            num_feats * 2) for 4D input.

    Raises:
        AssertionError: If num_feats is not divisible by 4.
        ValueError: If pos_tensor.size(-1) is not 2 or 4.

    Examples:
        >>> pos_2d = torch.rand(100, 8, 2)  # 100 queries, batch size 8, 2D coordinates
        >>> embeddings_2d = gen_sineembed_for_position(pos_2d, num_feats=256)
        >>> embeddings_2d.shape
        torch.Size([100, 8, 256])
        >>> pos_4d = torch.rand(50, 4, 4)  # 50 queries, batch size 4, 4D coordinates
        >>> embeddings_4d = gen_sineembed_for_position(pos_4d, num_feats=128)
        >>> embeddings_4d.shape
        torch.Size([50, 4, 256])
    """
    # num_feats // 2 channels per coordinate are split into interleaved sin/cos halves, which only stack when
    # num_feats // 2 is even (an odd count used to fail later inside torch.stack).
    assert num_feats % 4 == 0, f"num_feats must be divisible by 4, got {num_feats}"
    # Validate before indexing so unsupported shapes (including size 1) raise the documented ValueError.
    n_coords = pos_tensor.size(-1)
    if n_coords not in (2, 4):
        raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")

    num_feats = num_feats // 2
    scale = 2 * math.pi
    dim_t = torch.arange(num_feats, dtype=pos_tensor.dtype, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)

    def _embed(coord: int) -> torch.Tensor:
        """Sine/cosine-encode channel ``coord`` of pos_tensor into shape (n_query, bs, num_feats)."""
        pos = (pos_tensor[:, :, coord] * scale)[:, :, None] / dim_t
        return torch.stack((pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3).flatten(2)

    pos_x = _embed(0)
    pos_y = _embed(1)
    if n_coords == 2:
        return torch.cat((pos_y, pos_x), dim=2)
    return torch.cat((pos_y, pos_x, _embed(2), _embed(3)), dim=2)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
|
+
|
|
3
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
|
4
|
+
|
|
5
|
+
"""Necks are the interface between a vision backbone and the rest of the detection model."""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from copy import deepcopy
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn as nn
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Sam3DualViTDetNeck(nn.Module):
    """A neck that implements a simple FPN as in ViTDet, with support for dual necks (for SAM3 and SAM2).

    The last feature map returned by the trunk is resampled once per entry of ``scale_factors`` (transposed convs for
    upsampling, max-pooling for downsampling), then projected to ``d_model`` channels by a 1x1 and a 3x3 convolution.
    With ``add_sam2_neck=True`` a second, identically-initialized copy of these branches is kept as a separate neck.
    """

    def __init__(
        self,
        trunk: nn.Module,
        position_encoding: nn.Module,
        d_model: int,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        add_sam2_neck: bool = False,
    ):
        """Build the SimpleFPN neck a la ViTDet (from detectron2, very lightly adapted).

        Args:
            trunk (nn.Module): The backbone. Its ``channel_list[-1]`` gives the channel count of the last feature map.
            position_encoding (nn.Module): Positional encoding applied to every output feature level.
            d_model (int): Output channel dimension of every feature level.
            scale_factors (tuple[float, ...]): Resampling factor per output level; each must be 4.0, 2.0, 1.0 or 0.5.
            add_sam2_neck (bool): If True, also build a SAM2 neck as a deep copy of the SAM3 convolution branches.

        Raises:
            NotImplementedError: If a scale factor other than 4.0, 2.0, 1.0 or 0.5 is requested.
        """
        super().__init__()
        self.trunk = trunk
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()

        self.scale_factors = scale_factors
        use_bias = True
        dim: int = self.trunk.channel_list[-1]

        for scale in scale_factors:
            current = nn.Sequential()

            if scale == 4.0:
                current.add_module(
                    "dconv_2x2_0",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                current.add_module(
                    "gelu",
                    nn.GELU(),
                )
                current.add_module(
                    "dconv_2x2_1",
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                )
                out_dim = dim // 4
            elif scale == 2.0:
                current.add_module(
                    "dconv_2x2",
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                )
                out_dim = dim // 2
            elif scale == 1.0:
                out_dim = dim
            elif scale == 0.5:
                current.add_module(
                    "maxpool_2x2",
                    nn.MaxPool2d(kernel_size=2, stride=2),
                )
                out_dim = dim
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            current.add_module(
                "conv_1x1",
                nn.Conv2d(
                    in_channels=out_dim,
                    out_channels=d_model,
                    kernel_size=1,
                    bias=use_bias,
                ),
            )
            current.add_module(
                "conv_3x3",
                nn.Conv2d(
                    in_channels=d_model,
                    out_channels=d_model,
                    kernel_size=3,
                    padding=1,
                    bias=use_bias,
                ),
            )
            self.convs.append(current)

        self.sam2_convs = None
        if add_sam2_neck:
            # Assumes sam2 neck is just a clone of the original neck (same initial weights, trained separately)
            self.sam2_convs = deepcopy(self.convs)

    def forward(
        self, tensor_list: list[torch.Tensor]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor] | None, list[torch.Tensor] | None]:
        """Get feature maps and positional encodings from the neck.

        Args:
            tensor_list (list[torch.Tensor]): Input passed unchanged to the trunk.

        Returns:
            sam3_out (list[torch.Tensor]): SAM3 neck features, one per scale factor.
            sam3_pos (list[torch.Tensor]): Positional encodings matching ``sam3_out``.
            sam2_out (list[torch.Tensor] | None): SAM2 neck features, or None without a SAM2 neck.
            sam2_pos (list[torch.Tensor] | None): Positional encodings matching ``sam2_out``, or None.
        """
        xs = self.trunk(tensor_list)
        x = xs[-1]  # simpleFPN: only the last trunk output feeds the neck
        sam3_out, sam3_pos = self.sam_forward_feature_levels(x, self.convs)
        if self.sam2_convs is None:
            return sam3_out, sam3_pos, None, None
        sam2_out, sam2_pos = self.sam_forward_feature_levels(x, self.sam2_convs)
        return sam3_out, sam3_pos, sam2_out, sam2_pos

    def sam_forward_feature_levels(
        self, x: torch.Tensor, convs: nn.ModuleList
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Run neck convolutions and compute positional encodings for each feature level.

        Args:
            x (torch.Tensor): Feature map fed to every branch.
            convs (nn.ModuleList): One branch per output level.

        Returns:
            outs (list[torch.Tensor]): Output feature map of each branch.
            poss (list[torch.Tensor]): Positional encoding of each output, cast to that output's dtype.
        """
        outs, poss = [], []
        for conv in convs:
            feat = conv(x)
            outs.append(feat)
            poss.append(self.position_encoding(feat).to(feat.dtype))
        return outs, poss

    def set_imgsz(self, imgsz: list[int] | None = None):
        """Set the image size for the trunk backbone.

        Args:
            imgsz (list[int] | None): Image size forwarded to ``trunk.set_imgsz``; defaults to ``[1008, 1008]``.
        """
        # A fresh default list per call: a shared mutable default could be stored and mutated by the trunk,
        # silently changing the default for every later call.
        self.trunk.set_imgsz([1008, 1008] if imgsz is None else imgsz)
|
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
|
+
|
|
3
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from copy import deepcopy
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from ultralytics.nn.modules.utils import inverse_sigmoid
|
|
12
|
+
from ultralytics.utils.ops import xywh2xyxy
|
|
13
|
+
|
|
14
|
+
from ..modules.sam import SAM2Model
|
|
15
|
+
from .geometry_encoders import Prompt
|
|
16
|
+
from .vl_combiner import SAM3VLBackbone
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
    """Write a (possibly per-layer) value into an output dict, optionally mirroring earlier layers into aux outputs.

    Args:
        out (dict): Output dictionary, modified in place.
        out_name (str): Key under which the value is stored.
        out_value: Value to store. With ``auxiliary=True`` it must be indexable and sized; its last element becomes
            the main output and the preceding elements go to ``out["aux_outputs"]``.
        auxiliary (bool): If False, store ``out_value`` as-is and do nothing else.
        update_aux (bool): If False, store only the last element and leave ``aux_outputs`` untouched.
    """
    if not auxiliary:
        out[out_name] = out_value
        return

    out[out_name] = out_value[-1]
    if not update_aux:
        return

    expected = len(out_value) - 1
    if "aux_outputs" not in out:
        out["aux_outputs"] = [{} for _ in range(expected)]
    aux_slots = out["aux_outputs"]
    assert len(aux_slots) == expected
    for slot, earlier in zip(aux_slots, out_value[:-1]):
        slot[out_name] = earlier
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SAM3SemanticModel(torch.nn.Module):
|
|
31
|
+
"""SAM3 model for semantic segmentation with vision-language backbone."""
|
|
32
|
+
|
|
33
|
+
def __init__(
    self,
    backbone: SAM3VLBackbone,
    transformer,
    input_geometry_encoder,
    segmentation_head=None,
    num_feature_levels=1,
    o2m_mask_predict=True,
    dot_prod_scoring=None,
    use_instance_query: bool = True,
    multimask_output: bool = True,
    use_act_checkpoint_seg_head: bool = True,
    matcher=None,
    use_dot_prod_scoring=True,
    supervise_joint_box_scores: bool = False,  # only relevant if using presence token/score
    detach_presence_in_joint_score: bool = False,  # only relevant if using presence token/score
    separate_scorer_for_instance: bool = False,
    num_interactive_steps_val: int = 0,
):
    """Initialize the SAM3SemanticModel.

    Stores the sub-modules and configuration flags and builds the query scoring head(s).

    Args:
        backbone (SAM3VLBackbone): Vision-language backbone.
        transformer: Encoder/decoder transformer. Must expose ``d_model`` and a ``decoder`` with ``num_queries``,
            ``num_o2m_queries`` and ``dac`` attributes.
        input_geometry_encoder: Geometric prompt encoder, stored as ``self.geometry_encoder``.
        segmentation_head: Optional segmentation head.
        num_feature_levels (int): Number of feature levels, stored as-is.
        o2m_mask_predict (bool): Flag stored as ``self.o2m_mask_predict``.
        dot_prod_scoring: Scoring module used when ``use_dot_prod_scoring`` is True (required in that case).
        use_instance_query (bool): Flag stored as ``self.use_instance_query``.
        multimask_output (bool): Flag stored as ``self.multimask_output``.
        use_act_checkpoint_seg_head (bool): Flag stored as ``self.use_act_checkpoint_seg_head``.
        matcher: Optional matcher, stored as-is.
        use_dot_prod_scoring (bool): If True, score queries with ``dot_prod_scoring``; otherwise build a
            ``Linear(hidden_dim, 1)`` head as ``self.class_embed``.
        supervise_joint_box_scores (bool): Flag stored as-is.
        detach_presence_in_joint_score (bool): Flag stored as-is.
        separate_scorer_for_instance (bool): If True, also create a deep copy of the chosen scorer for instance
            queries (``instance_dot_prod_scoring`` or ``instance_class_embed``).
        num_interactive_steps_val (int): Value stored as ``self.num_interactive_steps_val``.
    """
    super().__init__()
    self.backbone = backbone
    self.geometry_encoder = input_geometry_encoder
    self.transformer = transformer
    self.hidden_dim = transformer.d_model
    self.num_feature_levels = num_feature_levels
    self.segmentation_head = segmentation_head

    self.o2m_mask_predict = o2m_mask_predict

    # Assigned unconditionally (so it is registered even when dot-product scoring is disabled); re-assigned below.
    self.dot_prod_scoring = dot_prod_scoring
    self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
    self.matcher = matcher

    self.num_interactive_steps_val = num_interactive_steps_val
    self.use_dot_prod_scoring = use_dot_prod_scoring

    if self.use_dot_prod_scoring:
        assert dot_prod_scoring is not None
        self.dot_prod_scoring = dot_prod_scoring
        self.instance_dot_prod_scoring = None
        if separate_scorer_for_instance:
            # Same architecture and initial weights as the main scorer, but a distinct module.
            self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
    else:
        # Linear scoring head producing one logit per query feature.
        self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
        self.instance_class_embed = None
        if separate_scorer_for_instance:
            self.instance_class_embed = deepcopy(self.class_embed)

    self.supervise_joint_box_scores = supervise_joint_box_scores
    self.detach_presence_in_joint_score = detach_presence_in_joint_score

    # verify the number of queries for O2O and O2M
    # O2M query count must equal the O2O count when decoder.dac is set, and be zero otherwise.
    num_o2o_static = self.transformer.decoder.num_queries
    num_o2m_static = self.transformer.decoder.num_o2m_queries
    assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
    self.dac = self.transformer.decoder.dac

    self.use_instance_query = use_instance_query
    self.multimask_output = multimask_output

    # Start empty; NOTE(review): presumably populated elsewhere in the class (not visible in this chunk).
    self.text_embeddings = {}
    self.names = []
|
|
96
|
+
|
|
97
|
+
def _encode_prompt(
    self,
    img_feats,
    img_pos_embeds,
    vis_feat_sizes,
    geometric_prompt,
    visual_prompt_embed=None,
    visual_prompt_mask=None,
    prev_mask_pred=None,
):
    """Encode the geometric and visual prompts into one prompt token sequence and its mask.

    Args:
        img_feats (list): Image feature maps passed to the geometry encoder. When ``prev_mask_pred`` is given, only
            the last entry is used.
        img_pos_embeds (list): Positional embeddings matching ``img_feats``.
        vis_feat_sizes (list): Spatial sizes of the visual feature maps.
        geometric_prompt: Geometric prompt forwarded to ``self.geometry_encoder``.
        visual_prompt_embed (torch.Tensor | None): Optional visual prompt tokens, concatenated after the geometry
            tokens along dim 0.
        visual_prompt_mask (torch.Tensor | None): Mask for ``visual_prompt_embed``, concatenated after the geometry
            mask along dim 1. If omitted while ``visual_prompt_embed`` is given, an all-zero mask is used.
        prev_mask_pred (torch.Tensor | None): Optional previous mask prediction added to the last feature map.

    Returns:
        prompt (torch.Tensor): Geometry tokens followed by visual prompt tokens (concatenated on dim 0).
        prompt_mask (torch.Tensor): Matching mask (concatenated on dim 1).
    """
    if prev_mask_pred is not None:
        # Condition on the previous prediction by adding it to the last feature level only.
        img_feats = [img_feats[-1] + prev_mask_pred]
    # Encode geometry
    geo_feats, geo_masks = self.geometry_encoder(
        geo_prompt=geometric_prompt,
        img_feats=img_feats,
        img_sizes=vis_feat_sizes,
        img_pos_embeds=img_pos_embeds,
    )
    if visual_prompt_embed is None:
        # Empty placeholders keep the concatenation below uniform. Match geo_feats' dtype so a half-precision model
        # is not silently promoted to float32 by torch.cat.
        visual_prompt_embed = torch.zeros(
            (0, *geo_feats.shape[1:]),
            device=geo_feats.device,
            dtype=geo_feats.dtype,
        )
        visual_prompt_mask = torch.zeros(
            (*geo_masks.shape[:-1], 0),
            device=geo_masks.device,
            dtype=geo_masks.dtype,
        )
    elif visual_prompt_mask is None:
        # Visual tokens without a mask: build an all-zero mask with one entry per visual token (same zero-fill
        # convention as the empty placeholder above; presumably "not padded" -- confirm with the encoder).
        visual_prompt_mask = torch.zeros(
            (*geo_masks.shape[:-1], visual_prompt_embed.shape[0]),
            device=geo_masks.device,
            dtype=geo_masks.dtype,
        )
    prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
    prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
    return prompt, prompt_mask
|
|
127
|
+
|
|
128
|
+
def _run_encoder(
|
|
129
|
+
self,
|
|
130
|
+
img_feats,
|
|
131
|
+
img_pos_embeds,
|
|
132
|
+
vis_feat_sizes,
|
|
133
|
+
prompt,
|
|
134
|
+
prompt_mask,
|
|
135
|
+
encoder_extra_kwargs: dict | None = None,
|
|
136
|
+
):
|
|
137
|
+
"""Run the transformer encoder."""
|
|
138
|
+
# Run the encoder
|
|
139
|
+
# make a copy of the image feature lists since the encoder may modify these lists in-place
|
|
140
|
+
memory = self.transformer.encoder(
|
|
141
|
+
src=img_feats.copy(),
|
|
142
|
+
src_key_padding_mask=None,
|
|
143
|
+
src_pos=img_pos_embeds.copy(),
|
|
144
|
+
prompt=prompt,
|
|
145
|
+
prompt_key_padding_mask=prompt_mask,
|
|
146
|
+
feat_sizes=vis_feat_sizes,
|
|
147
|
+
encoder_extra_kwargs=encoder_extra_kwargs,
|
|
148
|
+
)
|
|
149
|
+
encoder_out = {
|
|
150
|
+
# encoded image features
|
|
151
|
+
"encoder_hidden_states": memory["memory"],
|
|
152
|
+
"pos_embed": memory["pos_embed"],
|
|
153
|
+
"padding_mask": memory["padding_mask"],
|
|
154
|
+
"spatial_shapes": memory["spatial_shapes"],
|
|
155
|
+
"valid_ratios": memory["valid_ratios"],
|
|
156
|
+
"vis_feat_sizes": vis_feat_sizes,
|
|
157
|
+
# encoded text features (or other prompts)
|
|
158
|
+
"prompt_before_enc": prompt,
|
|
159
|
+
"prompt_after_enc": memory.get("memory_text", prompt),
|
|
160
|
+
"prompt_mask": prompt_mask,
|
|
161
|
+
}
|
|
162
|
+
return encoder_out
|
|
163
|
+
|
|
164
|
+
def _run_decoder(
    self,
    pos_embed,
    memory,
    src_mask,
    out,
    prompt,
    prompt_mask,
    encoder_out,
):
    """Decode object queries against the encoded memory, then fill ``out`` with scores and boxes.

    Args:
        pos_embed: Positional embedding for ``memory``.
        memory (torch.Tensor): Encoded image features; dim 1 is read as the batch size.
        src_mask: Key padding mask for ``memory``.
        out (dict): Output dict, updated by ``self._update_scores_and_boxes`` and returned.
        prompt: Prompt tokens passed to the decoder as ``memory_text``.
        prompt_mask: Mask for ``prompt`` passed as ``text_attention_mask``.
        encoder_out (dict): Encoder outputs; ``spatial_shapes`` and ``valid_ratios`` are read.

    Returns:
        (tuple): ``(out, hs)``, where ``hs`` is the decoder hidden states with dims 1 and 2 swapped (batch-first).
    """
    decoder = self.transformer.decoder
    batch_size = memory.shape[1]
    # Tile the learned query embeddings across the batch along a new dim 1.
    queries = decoder.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)

    hs, reference_boxes, dec_presence_out, _ = decoder(
        tgt=queries,
        memory=memory,
        memory_key_padding_mask=src_mask,
        pos=pos_embed,
        reference_boxes=None,
        spatial_shapes=encoder_out["spatial_shapes"],
        valid_ratios=encoder_out["valid_ratios"],
        tgt_mask=None,
        memory_text=prompt,
        text_attention_mask=prompt_mask,
        apply_dac=False,
    )

    # Decoder outputs are sequence-first; swapping dims 1 and 2 makes them batch-first.
    hs, reference_boxes = hs.transpose(1, 2), reference_boxes.transpose(1, 2)
    dec_presence_out = None if dec_presence_out is None else dec_presence_out.transpose(1, 2)

    self._update_scores_and_boxes(
        out,
        hs,
        reference_boxes,
        prompt,
        prompt_mask,
        dec_presence_out=dec_presence_out,
    )
    return out, hs
|
|
206
|
+
|
|
207
|
+
def _update_scores_and_boxes(
    self,
    out,
    hs,
    reference_boxes,
    prompt,
    prompt_mask,
    dec_presence_out=None,
    is_instance_prompt=False,
):
    """Predict class logits and boxes from decoder hidden states and write them into ``out``.

    Args:
        out (dict): Output dict, updated in place through ``_update_out``.
        hs (torch.Tensor): Batch-first decoder hidden states; dim 2 is the query axis.
        reference_boxes (torch.Tensor): Reference boxes, expected in (0, 1) since they go through
            ``inverse_sigmoid`` before being refined.
        prompt: Prompt tokens, used only by the dot-product scoring head.
        prompt_mask: Mask for ``prompt``, used only by the dot-product scoring head.
        dec_presence_out (torch.Tensor | None): Presence logits. Stored as ``presence_logit_dec`` when given, and
            required when ``self.supervise_joint_box_scores`` is set.
        is_instance_prompt (bool): Use the instance-specific scoring/box heads when they exist.

    Writes ``pred_logits``, ``pred_boxes`` (xywh, before ``xywh2xyxy``) and ``pred_boxes_xyxy`` into ``out``.
    """
    num_o2o = hs.size(2)

    # Score prediction: a dot-product head against the prompt, or a linear head on the queries. Instance prompts
    # switch to the dedicated instance head when one was created.
    if self.use_dot_prod_scoring:
        dot_prod_scoring_head = self.dot_prod_scoring
        if is_instance_prompt and self.instance_dot_prod_scoring is not None:
            dot_prod_scoring_head = self.instance_dot_prod_scoring
        outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
    else:
        class_embed_head = self.class_embed
        if is_instance_prompt and self.instance_class_embed is not None:
            class_embed_head = self.instance_class_embed
        outputs_class = class_embed_head(hs)

    # Box prediction: the head outputs offsets in logit space that refine the reference boxes.
    box_head = self.transformer.decoder.bbox_embed
    if is_instance_prompt and self.transformer.decoder.instance_bbox_embed is not None:
        box_head = self.transformer.decoder.instance_bbox_embed
    anchor_box_offsets = box_head(hs)
    reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
    outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
    outputs_boxes_xyxy = xywh2xyxy(outputs_coord)

    if dec_presence_out is not None:
        _update_out(out, "presence_logit_dec", dec_presence_out, update_aux=False)

    if self.supervise_joint_box_scores:
        assert dec_presence_out is not None
        # sigmoid() is out-of-place and already returns a new tensor, so no clone() is needed.
        prob_dec_presence_out = dec_presence_out.sigmoid()
        if self.detach_presence_in_joint_score:
            # Keep gradients of the joint score from flowing into the presence branch.
            prob_dec_presence_out = prob_dec_presence_out.detach()

        # Joint score = per-query probability * presence probability, mapped back to logits and clamped to
        # [-10, 10]. unsqueeze(2) inserts a singleton at the query axis so presence broadcasts across queries.
        outputs_class = inverse_sigmoid(outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)).clamp(
            min=-10.0, max=10.0
        )

    _update_out(out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=False)
    _update_out(out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=False)
    _update_out(out, "pred_boxes_xyxy", outputs_boxes_xyxy[:, :, :num_o2o], update_aux=False)
|
|
256
|
+
|
|
257
|
+
def _run_segmentation_heads(
    self,
    out,
    backbone_out,
    encoder_hidden_states,
    prompt,
    prompt_mask,
    hs,
):
    """Run the mask head (if any) and merge its outputs into ``out``.

    Without a segmentation head, ``backbone_fpn`` is removed from ``backbone_out`` and nothing is predicted.
    Otherwise, head outputs whose key is in ``segmentation_head.instance_keys`` are sliced to the first
    ``hs.size(2)`` entries of dim 1 and written via ``_update_out``. All other outputs are stored in ``out`` as-is.

    Args:
        out (dict): Output dict updated in place.
        backbone_out (dict): Backbone outputs; ``backbone_fpn`` is read (or popped when there is no head).
        encoder_hidden_states: Encoder memory forwarded to the head.
        prompt: Prompt tokens forwarded to the head.
        prompt_mask: Prompt mask forwarded to the head.
        hs (torch.Tensor): Batch-first decoder hidden states; dim 2 is the query axis.
    """
    seg_head = self.segmentation_head
    if seg_head is None:
        backbone_out.pop("backbone_fpn", None)
        return

    num_o2o = hs.size(2)
    obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
    seg_head_outputs = seg_head(
        backbone_feats=backbone_out["backbone_fpn"],
        obj_queries=obj_queries,
        encoder_hidden_states=encoder_hidden_states,
        prompt=prompt,
        prompt_mask=prompt_mask,
    )
    for key, value in seg_head_outputs.items():
        if key not in seg_head.instance_keys:
            out[key] = value
            continue
        _update_out(out, key, value[:, :num_o2o], auxiliary=False)
|
|
284
|
+
|
|
285
|
+
def forward_grounding(
    self, backbone_out: dict[str, torch.Tensor], text_ids: torch.Tensor, geometric_prompt: Prompt | None = None
):
    """Forward pass for grounding (detection + segmentation) given input images and text.

    Args:
        backbone_out (dict[str, torch.Tensor]): Backbone outputs. After ``SAM2Model._prepare_backbone_features``,
            the returned dict is updated with the cached ``self.text_embeddings``.
        text_ids (torch.Tensor): Indices into the cached text features, one per prompt. ``len(text_ids)`` is used
            as the batch size.
        geometric_prompt (Prompt | None): Optional geometric prompt forwarded to the geometry encoder.

    Returns:
        (dict): ``{"backbone_out": ...}`` extended with decoder scores/boxes and segmentation head outputs.
    """
    backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = SAM2Model._prepare_backbone_features(
        self, backbone_out, batch=len(text_ids)
    )
    backbone_out.update(self.text_embeddings)
    with torch.profiler.record_function("SAM3Image._encode_prompt"):
        prompt, prompt_mask = self._encode_prompt(img_feats, img_pos_embeds, vis_feat_sizes, geometric_prompt)
    # index text features (note that regardless of early or late fusion, the batch size of
    # `txt_feats` is always the number of *prompts* in the encoder)
    txt_feats = backbone_out["language_features"][:, text_ids]
    txt_masks = backbone_out["language_mask"][text_ids]
    # Prepend the selected text tokens (and their mask) to the geometry/visual prompt tokens.
    prompt = torch.cat([txt_feats, prompt], dim=0)
    prompt_mask = torch.cat([txt_masks, prompt_mask], dim=1)

    # Run the encoder
    with torch.profiler.record_function("SAM3Image._run_encoder"):
        encoder_out = self._run_encoder(img_feats, img_pos_embeds, vis_feat_sizes, prompt, prompt_mask)
    out = {"backbone_out": backbone_out}

    # Run the decoder
    with torch.profiler.record_function("SAM3Image._run_decoder"):
        out, hs = self._run_decoder(
            memory=encoder_out["encoder_hidden_states"],
            pos_embed=encoder_out["pos_embed"],
            src_mask=encoder_out["padding_mask"],
            out=out,
            prompt=prompt,
            prompt_mask=prompt_mask,
            encoder_out=encoder_out,
        )

    # Run segmentation heads
    with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
        self._run_segmentation_heads(
            out=out,
            backbone_out=backbone_out,
            encoder_hidden_states=encoder_out["encoder_hidden_states"],
            prompt=prompt,
            prompt_mask=prompt_mask,
            hs=hs,
        )
    return out
|
|
331
|
+
|
|
332
|
+
def set_classes(self, text: list[str]):
    """Encode ``text`` with the backbone's text encoder and remember it as the class names.

    Args:
        text (list[str]): Class names. They are passed to ``self.backbone.forward_text``, whose result is cached
            in ``self.text_embeddings``, and the same list object is stored as ``self.names``.
    """
    embeddings = self.backbone.forward_text(text)
    self.text_embeddings, self.names = embeddings, text
|
|
336
|
+
|
|
337
|
+
def set_imgsz(self, imgsz: tuple[int, int]):
    """Propagate the inference image size to the backbone.

    Args:
        imgsz (tuple[int, int]): Image size, passed unchanged to ``self.backbone.set_imgsz``.
    """
    backbone = self.backbone
    backbone.set_imgsz(imgsz)
|