dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/ultralytics/models/sam/sam3/geometry_encoders.py
@@ -0,0 +1,415 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import torch
import torch.nn as nn
import torchvision

from ultralytics.nn.modules.utils import _get_clones
from ultralytics.utils.ops import xywh2xyxy


def is_right_padded(mask: torch.Tensor):
    """Given a padding mask (following pytorch convention, 1s for padded values), returns whether the padding is on the
    right or not.
    """
    return (mask.long() == torch.sort(mask.long(), dim=-1)[0]).all()


def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
    """
    Concatenates two right-padded sequences, such that the resulting sequence
    is contiguous and also right-padded.

    Following pytorch's convention, tensors are sequence first, and the mask are
    batch first, with 1s for padded values.

    :param seq1: A tensor of shape (seq1_length, batch_size, hidden_size).
    :param mask1: A tensor of shape (batch_size, seq1_length).
    :param seq2: A tensor of shape (seq2_length, batch_size, hidden_size).
    :param mask2: A tensor of shape (batch_size, seq2_length).
    :param return_index: If True, also returns the index of the ids of the element of seq2
        in the concatenated sequence. This can be used to retrieve the elements of seq2
    :return: A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
        otherwise (concatenated_sequence, concatenated_mask, index).
    """
    seq1_length, batch_size, hidden_size = seq1.shape
    seq2_length, batch_size, hidden_size = seq2.shape

    assert batch_size == seq1.size(1) == seq2.size(1) == mask1.size(0) == mask2.size(0)
    assert hidden_size == seq1.size(2) == seq2.size(2)
    assert seq1_length == mask1.size(1)
    assert seq2_length == mask2.size(1)

    torch._assert(is_right_padded(mask1), "Mask is not right padded")
    torch._assert(is_right_padded(mask2), "Mask is not right padded")

    actual_seq1_lengths = (~mask1).sum(dim=-1)
    actual_seq2_lengths = (~mask2).sum(dim=-1)

    final_lengths = actual_seq1_lengths + actual_seq2_lengths
    max_length = seq1_length + seq2_length
    concatenated_mask = (
        torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) >= final_lengths[:, None]
    )

    # (max_len, batch_size, hidden_size)
    concatenated_sequence = torch.zeros((max_length, batch_size, hidden_size), device=seq2.device, dtype=seq2.dtype)
    concatenated_sequence[:seq1_length, :, :] = seq1

    # At this point, the element of seq1 are in the right place
    # We just need to shift the elements of seq2

    index = torch.arange(seq2_length, device=seq2.device)[:, None].repeat(1, batch_size)
    index = index + actual_seq1_lengths[None]

    concatenated_sequence = concatenated_sequence.scatter(0, index[:, :, None].expand(-1, -1, hidden_size), seq2)

    if return_index:
        return concatenated_sequence, concatenated_mask, index

    return concatenated_sequence, concatenated_mask


class Prompt:
    """Utility class to manipulate geometric prompts.

    We expect the sequences in pytorch convention, that is sequence first, batch second The dimensions are expected as
    follows: box_embeddings shape: N_boxes x B x C_box box_mask shape: B x N_boxes. Can be None if nothing is masked out
    point_embeddings shape: N_points x B x C_point point_mask shape: B x N_points. Can be None if nothing is masked out
    mask_embeddings shape: N_masks x B x 1 x H_mask x W_mask mask_mask shape: B x N_masks. Can be None if nothing is
    masked out

    We also store positive/negative labels. These tensors are also stored batch-first If they are None, we'll assume
    positive labels everywhere box_labels: long tensor of shape N_boxes x B point_labels: long tensor of shape N_points
    x B mask_labels: long tensor of shape N_masks x B
    """

    def __init__(self, box_embeddings=None, box_mask=None, box_labels=None):
        """Initialize the Prompt object."""
        # Check for null prompt
        # Check for null prompt
        if box_embeddings is None:
            self.box_embeddings = None
            self.box_labels = None
            self.box_mask = None
            return

        # Get sequence length, batch size, and device
        box_seq_len = box_embeddings.shape[0]
        bs = box_embeddings.shape[1]
        device = box_embeddings.device

        # Initialize labels and attention mask if not provided
        if box_labels is None:
            box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long)
        if box_mask is None:
            box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool)

        # Dimension checks
        assert list(box_embeddings.shape[:2]) == [box_seq_len, bs], (
            f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
        )
        assert box_embeddings.shape[-1] == 4, (
            f"Expected box embeddings to have 4 coordinates, got {box_embeddings.shape[-1]}"
        )
        assert list(box_mask.shape) == [bs, box_seq_len], (
            f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
        )
        assert list(box_labels.shape) == [box_seq_len, bs], (
            f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
        )

        # Device checks
        assert box_embeddings.device == device, (
            f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
        )
        assert box_mask.device == device, f"Expected box mask to be on device {device}, got {box_mask.device}"
        assert box_labels.device == device, f"Expected box labels to be on device {device}, got {box_labels.device}"

        self.box_embeddings = box_embeddings
        self.box_mask = box_mask
        self.box_labels = box_labels

    def append_boxes(self, boxes, labels=None, mask=None):
        """Append box prompts to existing prompts.

        Args:
            boxes: Tensor of shape (N_new_boxes, B, 4) with normalized box coordinates
            labels: Optional tensor of shape (N_new_boxes, B) with positive/negative labels
            mask: Optional tensor of shape (B, N_new_boxes) for attention mask
        """
        if self.box_embeddings is None:
            # First boxes - initialize
            self.box_embeddings = boxes
            bs = boxes.shape[1]
            box_seq_len = boxes.shape[0]

            if labels is None:
                labels = torch.ones(box_seq_len, bs, device=boxes.device, dtype=torch.long)
            if mask is None:
                mask = torch.zeros(bs, box_seq_len, device=boxes.device, dtype=torch.bool)

            self.box_labels = labels
            self.box_mask = mask
            return

        # Append to existing boxes
        bs = self.box_embeddings.shape[1]
        assert boxes.shape[1] == bs, f"Batch size mismatch: expected {bs}, got {boxes.shape[1]}"

        if labels is None:
            labels = torch.ones(boxes.shape[0], bs, device=boxes.device, dtype=torch.long)
        if mask is None:
            mask = torch.zeros(bs, boxes.shape[0], dtype=torch.bool, device=boxes.device)

        assert list(boxes.shape[:2]) == list(labels.shape[:2]), (
            f"Shape mismatch between boxes {boxes.shape} and labels {labels.shape}"
        )

        # Concatenate using the helper function
        self.box_labels, _ = concat_padded_sequences(
            self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask
        )
        self.box_labels = self.box_labels.squeeze(-1)
        self.box_embeddings, self.box_mask = concat_padded_sequences(self.box_embeddings, self.box_mask, boxes, mask)


class SequenceGeometryEncoder(nn.Module):
    """Encoder for geometric box prompts. Assumes boxes are passed in the "normalized CxCyWH" format.

    Boxes can be encoded with any of the three possibilities:
    - direct projection: linear projection from coordinate space to d_model
    - pooling: RoI align features from the backbone
    - pos encoder: position encoding of the box center

    These three options are mutually compatible and will be summed if multiple are selected.

    As an alternative, boxes can be encoded as two corner points (top-left and bottom-right).

    The encoded sequence can be further processed with a transformer.
    """

    def __init__(
        self,
        encode_boxes_as_points: bool,
        boxes_direct_project: bool,
        boxes_pool: bool,
        boxes_pos_enc: bool,
        d_model: int,
        pos_enc,
        num_layers: int,
        layer: nn.Module,
        roi_size: int = 7,
        add_cls: bool = True,
        add_post_encode_proj: bool = True,
        use_act_ckpt: bool = False,
    ):
        """Initialize the SequenceGeometryEncoder."""
        super().__init__()

        self.d_model = d_model
        self.pos_enc = pos_enc
        self.encode_boxes_as_points = encode_boxes_as_points
        self.roi_size = roi_size

        # Label embeddings: 2 labels if encoding as boxes (pos/neg)
        # 6 labels if encoding as points (regular pos/neg, top-left pos/neg, bottom-right pos/neg)
        num_labels = 6 if self.encode_boxes_as_points else 2
        self.label_embed = torch.nn.Embedding(num_labels, self.d_model)

        # CLS token for pooling
        self.cls_embed = None
        if add_cls:
            self.cls_embed = torch.nn.Embedding(1, self.d_model)

        # Point encoding (used when encode_boxes_as_points is True)
        if encode_boxes_as_points:
            self.points_direct_project = nn.Linear(2, self.d_model)
            self.points_pool_project = None
            self.points_pos_enc_project = None
        else:
            # Box encoding modules
            assert boxes_direct_project or boxes_pos_enc or boxes_pool, "Error: need at least one way to encode boxes"
            self.points_direct_project = None
            self.points_pool_project = None
            self.points_pos_enc_project = None

        self.boxes_direct_project = None
        self.boxes_pool_project = None
        self.boxes_pos_enc_project = None

        if boxes_direct_project:
            self.boxes_direct_project = nn.Linear(4, self.d_model)
        if boxes_pool:
            self.boxes_pool_project = nn.Conv2d(self.d_model, self.d_model, self.roi_size)
        if boxes_pos_enc:
            self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model)

        self.final_proj = None
        if add_post_encode_proj:
            self.final_proj = nn.Linear(self.d_model, self.d_model)
            self.norm = nn.LayerNorm(self.d_model)

        self.img_pre_norm = nn.Identity()
        if self.points_pool_project is not None or self.boxes_pool_project is not None:
            self.img_pre_norm = nn.LayerNorm(self.d_model)

        self.encode = None
        if num_layers > 0:
            assert add_cls, "It's currently highly recommended to add a CLS when using a transformer"
            self.encode = _get_clones(layer, num_layers)
            self.encode_norm = nn.LayerNorm(self.d_model)

        self.use_act_ckpt = use_act_ckpt

    def _encode_points(self, points, points_mask, points_labels, img_feats):
        """Encode points (used when boxes are converted to corner points)."""
        # Direct projection of coordinates
        points_embed = self.points_direct_project(points.to(img_feats.dtype))

        # Add label embeddings
        type_embed = self.label_embed(points_labels.long())
        return type_embed + points_embed, points_mask

    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats: torch.Tensor):
        """Encode boxes using configured encoding methods."""
        boxes_embed = None
        n_boxes, bs = boxes.shape[:2]

        if self.boxes_direct_project is not None:
            proj = self.boxes_direct_project(boxes.to(img_feats.dtype))
            boxes_embed = proj

        if self.boxes_pool_project is not None:
            H, W = img_feats.shape[-2:]

            # Convert boxes to xyxy format and denormalize
            boxes_xyxy = xywh2xyxy(boxes.to(img_feats.dtype))
            scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
            scale = scale.to(device=boxes_xyxy.device, non_blocking=True)
            scale = scale.view(1, 1, 4)
            boxes_xyxy = boxes_xyxy * scale

            # RoI align
            sampled = torchvision.ops.roi_align(img_feats, boxes_xyxy.transpose(0, 1).unbind(0), self.roi_size)
            assert list(sampled.shape) == [
                bs * n_boxes,
                self.d_model,
                self.roi_size,
                self.roi_size,
            ]
            proj = self.boxes_pool_project(sampled)
            proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1)

            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj

        if self.boxes_pos_enc_project is not None:
            cx, cy, w, h = boxes.unbind(-1)
            enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
            enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1])

            proj = self.boxes_pos_enc_project(enc.to(img_feats.dtype))
            if boxes_embed is None:
                boxes_embed = proj
            else:
                boxes_embed = boxes_embed + proj

        # Add label embeddings
        type_embed = self.label_embed(boxes_labels.long())
        return type_embed + boxes_embed, boxes_mask

    def forward(self, geo_prompt: Prompt, img_feats, img_sizes, img_pos_embeds=None):
        """Encode geometric box prompts.

        Args:
            geo_prompt: Prompt object containing box embeddings, masks, and labels
            img_feats: List of image features from backbone
            img_sizes: List of (H, W) tuples for each feature level
            img_pos_embeds: Optional position embeddings for image features

        Returns:
            Tuple of (encoded_embeddings, attention_mask)
        """
        boxes = geo_prompt.box_embeddings
        boxes_mask = geo_prompt.box_mask
        boxes_labels = geo_prompt.box_labels

        seq_first_img_feats = img_feats[-1]  # [H*W, B, C]
        seq_first_img_pos_embeds = (
            img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(seq_first_img_feats)
        )

        # Prepare image features for pooling if needed
        if self.points_pool_project or self.boxes_pool_project:
            assert len(img_feats) == len(img_sizes)
            cur_img_feat = img_feats[-1]
            cur_img_feat = self.img_pre_norm(cur_img_feat)
            H, W = img_sizes[-1]
            assert cur_img_feat.shape[0] == H * W
            N, C = cur_img_feat.shape[-2:]
            # Reshape to NxCxHxW
            cur_img_feat = cur_img_feat.permute(1, 2, 0)
            cur_img_feat = cur_img_feat.view(N, C, H, W)
            img_feats = cur_img_feat

        if self.encode_boxes_as_points:
            # Convert boxes to corner points
            assert boxes is not None and boxes.shape[-1] == 4

            boxes_xyxy = xywh2xyxy(boxes)
            top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1)

            # Adjust labels for corner points (offset by 2 and 4)
            labels_tl = boxes_labels + 2
            labels_br = boxes_labels + 4

            # Concatenate top-left and bottom-right points
            points = torch.cat([top_left, bottom_right], dim=0)
            points_labels = torch.cat([labels_tl, labels_br], dim=0)
            points_mask = torch.cat([boxes_mask, boxes_mask], dim=1)

            final_embeds, final_mask = self._encode_points(
                points=points,
                points_mask=points_mask,
                points_labels=points_labels,
                img_feats=img_feats,
            )
        else:
            # Encode boxes directly
            final_embeds, final_mask = self._encode_boxes(
                boxes=boxes,
                boxes_mask=boxes_mask,
                boxes_labels=boxes_labels,
                img_feats=img_feats,
            )

        bs = final_embeds.shape[1]
        assert final_mask.shape[0] == bs

        # Add CLS token if configured
        if self.cls_embed is not None:
            cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1)
            cls_mask = torch.zeros(bs, 1, dtype=final_mask.dtype, device=final_mask.device)
            final_embeds, final_mask = concat_padded_sequences(final_embeds, final_mask, cls, cls_mask)

        # Final projection
        if self.final_proj is not None:
            final_embeds = self.norm(self.final_proj(final_embeds))

        # Transformer encoding layers
        if self.encode is not None:
            for lay in self.encode:
                final_embeds = lay(
                    tgt=final_embeds,
                    memory=seq_first_img_feats,
                    tgt_key_padding_mask=final_mask,
                    pos=seq_first_img_pos_embeds,
                )
            final_embeds = self.encode_norm(final_embeds)

        return final_embeds, final_mask
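The building block that the geometry encoder leans on most is the right-padded sequence concatenation above. Below is a minimal sketch, not taken from the package, of how `concat_padded_sequences` is expected to behave; the tensor sizes and values are made up for illustration, and only the import path comes from the file listing above.

import torch

from ultralytics.models.sam.sam3.geometry_encoders import concat_padded_sequences

# Two right-padded, sequence-first batches: seq1 has real lengths (2, 1), seq2 has (1, 2).
seq1 = torch.arange(6, dtype=torch.float32).view(3, 2, 1)       # (seq1_len=3, batch=2, hidden=1)
mask1 = torch.tensor([[False, False, True], [False, True, True]])
seq2 = torch.arange(10, 14, dtype=torch.float32).view(2, 2, 1)  # (seq2_len=2, batch=2, hidden=1)
mask2 = torch.tensor([[False, True], [False, False]])

out, out_mask = concat_padded_sequences(seq1, mask1, seq2, mask2)
print(out.shape)   # torch.Size([5, 2, 1]) -> seq1_len + seq2_len
print(out_mask)    # each row is True only past that sample's combined real length (3 here)

This is the same helper the encoder uses to splice the CLS token onto the prompt sequence and that `Prompt.append_boxes` uses to grow a prompt in place.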
--- /dev/null
+++ b/ultralytics/models/sam/sam3/maskformer_segmentation.py
@@ -0,0 +1,286 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from ultralytics.nn.modules.transformer import MLP


class LinearPresenceHead(nn.Sequential):
    """Linear presence head for predicting the presence of classes in an image."""

    def __init__(self, d_model):
        """Initializes the LinearPresenceHead."""
        # a hack to make `LinearPresenceHead` compatible with old checkpoints
        super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1))

    def forward(self, hs, prompt, prompt_mask):
        """Forward pass of the presence head."""
        return super().forward(hs)


class MaskPredictor(nn.Module):
    """Predicts masks from object queries and pixel embeddings."""

    def __init__(self, hidden_dim, mask_dim):
        """Initializes the MaskPredictor."""
        super().__init__()
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

    def forward(self, obj_queries, pixel_embed):
        """Predicts masks from object queries and pixel embeddings."""
        if len(obj_queries.shape) == 3:
            if pixel_embed.ndim == 3:
                # batch size was omitted
                mask_preds = torch.einsum("bqc,chw->bqhw", self.mask_embed(obj_queries), pixel_embed)
            else:
                mask_preds = torch.einsum("bqc,bchw->bqhw", self.mask_embed(obj_queries), pixel_embed)
        else:
            # Assumed to have aux masks
            if pixel_embed.ndim == 3:
                # batch size was omitted
                mask_preds = torch.einsum("lbqc,chw->lbqhw", self.mask_embed(obj_queries), pixel_embed)
            else:
                mask_preds = torch.einsum("lbqc,bchw->lbqhw", self.mask_embed(obj_queries), pixel_embed)

        return mask_preds


class SegmentationHead(nn.Module):
    """Segmentation head that predicts masks from backbone features and object queries."""

    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        use_encoder_inputs=False,
        aux_masks=False,
        no_dec=False,
        pixel_decoder=None,
        act_ckpt=False,
        shared_conv=False,
        compile_mode_pixel_decoder=None,
    ):
        """Initializes the SegmentationHead."""
        super().__init__()
        self.use_encoder_inputs = use_encoder_inputs
        self.aux_masks = aux_masks
        if pixel_decoder is not None:
            self.pixel_decoder = pixel_decoder
        else:
            self.pixel_decoder = PixelDecoder(
                hidden_dim,
                upsampling_stages,
                shared_conv=shared_conv,
                compile_mode=compile_mode_pixel_decoder,
            )
        self.no_dec = no_dec
        if no_dec:
            self.mask_predictor = nn.Conv2d(hidden_dim, 1, kernel_size=3, stride=1, padding=1)
        else:
            self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)

        self.act_ckpt = act_ckpt

        # used to update the output dictionary
        self.instance_keys = ["pred_masks"]

    def _embed_pixels(self, backbone_feats: list[torch.Tensor], encoder_hidden_states) -> torch.Tensor:
        """Embeds pixels using the pixel decoder."""
        if self.use_encoder_inputs:
            backbone_visual_feats = [bb_feat.clone() for bb_feat in backbone_feats]
            # Extract visual embeddings
            encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0)
            spatial_dim = math.prod(backbone_feats[-1].shape[-2:])
            encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape(-1, *backbone_feats[-1].shape[1:])

            backbone_visual_feats[-1] = encoder_visual_embed
            if self.act_ckpt:
                pixel_embed = checkpoint.checkpoint(self.pixel_decoder, backbone_visual_feats, use_reentrant=False)
            else:
                pixel_embed = self.pixel_decoder(backbone_visual_feats)
        else:
            backbone_feats = [x for x in backbone_feats]
            pixel_embed = self.pixel_decoder(backbone_feats)
            if pixel_embed.shape[0] == 1:
                # For batch_size=1 training, we can avoid the indexing to save memory
                pixel_embed = pixel_embed.squeeze(0)
            else:
                pixel_embed = pixel_embed[[0], ...]
        return pixel_embed

    def forward(
        self,
        backbone_feats: list[torch.Tensor],
        obj_queries: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Forward pass of the SegmentationHead."""
        if self.use_encoder_inputs:
            assert encoder_hidden_states is not None

        pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states)

        if self.no_dec:
            mask_pred = self.mask_predictor(pixel_embed)
        elif self.aux_masks:
            mask_pred = self.mask_predictor(obj_queries, pixel_embed)
        else:
            mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed)

        return {"pred_masks": mask_pred}


class PixelDecoder(nn.Module):
    """Pixel decoder module that upsamples backbone features."""

    def __init__(
        self,
        hidden_dim,
        num_upsampling_stages,
        interpolation_mode="nearest",
        shared_conv=False,
        compile_mode=None,
    ):
        """Initializes the PixelDecoder."""
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_upsampling_stages = num_upsampling_stages
        self.interpolation_mode = interpolation_mode
        conv_layers = []
        norms = []
        num_convs = 1 if shared_conv else num_upsampling_stages
        for _ in range(num_convs):
            conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1))
            norms.append(nn.GroupNorm(8, self.hidden_dim))

        self.conv_layers = nn.ModuleList(conv_layers)
        self.norms = nn.ModuleList(norms)
        self.shared_conv = shared_conv
        self.out_dim = self.conv_layers[-1].out_channels
        if compile_mode is not None:
            self.forward = torch.compile(self.forward, mode=compile_mode, dynamic=True, fullgraph=True)
            # Needed to make checkpointing happy. But we don't know if the module is checkpointed, so we disable it by default.
            torch._dynamo.config.optimize_ddp = False

    def forward(self, backbone_feats: list[torch.Tensor]):
        """Forward pass of the PixelDecoder."""
        prev_fpn = backbone_feats[-1]
        fpn_feats = backbone_feats[:-1]
        for layer_idx, bb_feat in enumerate(fpn_feats[::-1]):
            curr_fpn = bb_feat
            prev_fpn = curr_fpn + F.interpolate(prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode)
            if self.shared_conv:
                # only one conv layer
                layer_idx = 0
            prev_fpn = self.conv_layers[layer_idx](prev_fpn)
            prev_fpn = F.relu(self.norms[layer_idx](prev_fpn))

        return prev_fpn


class UniversalSegmentationHead(SegmentationHead):
    """This module handles semantic+instance segmentation."""

    def __init__(
        self,
        hidden_dim,
        upsampling_stages,
        pixel_decoder,
        aux_masks=False,
        no_dec=False,
        act_ckpt=False,
        presence_head: bool = False,
        dot_product_scorer=None,
        cross_attend_prompt=None,
    ):
        """Initializes the UniversalSegmentationHead."""
        super().__init__(
            hidden_dim=hidden_dim,
            upsampling_stages=upsampling_stages,
            use_encoder_inputs=True,
            aux_masks=aux_masks,
            no_dec=no_dec,
            pixel_decoder=pixel_decoder,
            act_ckpt=act_ckpt,
        )
        self.d_model = hidden_dim

        if dot_product_scorer is not None:
            assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"

        self.presence_head = None
        if presence_head:
            self.presence_head = (
                dot_product_scorer if dot_product_scorer is not None else LinearPresenceHead(self.d_model)
            )

        self.cross_attend_prompt = cross_attend_prompt
        if self.cross_attend_prompt is not None:
            self.cross_attn_norm = nn.LayerNorm(self.d_model)

        self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1)
        self.instance_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, self.d_model, kernel_size=1)

    def forward(
        self,
        backbone_feats: list[torch.Tensor],
        obj_queries: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        prompt: torch.Tensor = None,
        prompt_mask: torch.Tensor = None,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Forward pass of the UniversalSegmentationHead."""
        assert encoder_hidden_states is not None
        bs = encoder_hidden_states.shape[1]

        if self.cross_attend_prompt is not None:
            tgt2 = self.cross_attn_norm(encoder_hidden_states)
            tgt2 = self.cross_attend_prompt(
                query=tgt2,
                key=prompt.to(tgt2.dtype),
                value=prompt.to(tgt2.dtype),
                key_padding_mask=prompt_mask,
                need_weights=False,
            )[0]
            encoder_hidden_states = tgt2 + encoder_hidden_states

        presence_logit = None
        if self.presence_head is not None:
            pooled_enc = encoder_hidden_states.mean(0)
            presence_logit = (
                self.presence_head(
                    pooled_enc.view(1, bs, 1, self.d_model),
                    prompt=prompt,
                    prompt_mask=prompt_mask,
                )
                .squeeze(0)
                .squeeze(1)
            )

        pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states)

        instance_embeds = self.instance_seg_head(pixel_embed)

        if self.no_dec:
            mask_pred = self.mask_predictor(instance_embeds)
        elif self.aux_masks:
            mask_pred = self.mask_predictor(obj_queries, instance_embeds)
        else:
            mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds)

        return {
            "pred_masks": mask_pred,
            "semantic_seg": self.semantic_seg_head(pixel_embed),
            "presence_logit": presence_logit,
        }
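As a rough illustration of how the pieces above fit together, the sketch below runs a `PixelDecoder` over a small FPN-style feature pyramid and turns object queries into mask logits with `MaskPredictor`. The feature sizes, channel count, and number of queries are made up for the example and are not values used by the package.

import torch

from ultralytics.models.sam.sam3.maskformer_segmentation import MaskPredictor, PixelDecoder

hidden_dim = 32
# Three pyramid levels, finest first, all already projected to hidden_dim channels (illustrative shapes).
feats = [
    torch.randn(1, hidden_dim, 64, 64),
    torch.randn(1, hidden_dim, 32, 32),
    torch.randn(1, hidden_dim, 16, 16),
]

decoder = PixelDecoder(hidden_dim, num_upsampling_stages=2)  # one conv + GroupNorm per upsampling step
pixel_embed = decoder(feats)                                 # coarsest level upsampled and fused into the finest
print(pixel_embed.shape)                                     # torch.Size([1, 32, 64, 64])

predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)
obj_queries = torch.randn(1, 10, hidden_dim)                 # (batch, queries, channels)
masks = predictor(obj_queries, pixel_embed)                  # einsum "bqc,bchw->bqhw"
print(masks.shape)                                           # torch.Size([1, 10, 64, 64])

`UniversalSegmentationHead` wraps the same two modules, adding the semantic and instance 1x1 conv heads, the optional prompt cross-attention, and the presence logit on top of this mask-prediction path.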