dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/sam3/geometry_encoders.py
@@ -0,0 +1,415 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ import torch
+ import torch.nn as nn
+ import torchvision
+
+ from ultralytics.nn.modules.utils import _get_clones
+ from ultralytics.utils.ops import xywh2xyxy
+
+
+ def is_right_padded(mask: torch.Tensor):
+     """Given a padding mask (following the PyTorch convention, 1s for padded values), return whether the padding is on the right."""
+     return (mask.long() == torch.sort(mask.long(), dim=-1)[0]).all()
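
A minimal sanity check of the helper above (tensor values are illustrative only, assuming the definitions in this file are in scope):

    import torch

    is_right_padded(torch.tensor([[False, True, True]]))   # tensor(True): padding trails the real tokens
    is_right_padded(torch.tensor([[True, False, False]]))  # tensor(False): padding precedes real tokens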
+
+
+ def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
+     """
+     Concatenate two right-padded sequences such that the resulting sequence is contiguous and also right-padded.
+
+     Following PyTorch's convention, tensors are sequence first and masks are batch first, with 1s for padded values.
+
+     :param seq1: A tensor of shape (seq1_length, batch_size, hidden_size).
+     :param mask1: A tensor of shape (batch_size, seq1_length).
+     :param seq2: A tensor of shape (seq2_length, batch_size, hidden_size).
+     :param mask2: A tensor of shape (batch_size, seq2_length).
+     :param return_index: If True, also returns the indices of the elements of seq2 in the concatenated
+         sequence. These can be used to retrieve the elements of seq2 afterwards.
+     :return: A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
+         otherwise (concatenated_sequence, concatenated_mask, index).
+     """
+     seq1_length, batch_size, hidden_size = seq1.shape
+     seq2_length, batch_size, hidden_size = seq2.shape
+
+     assert batch_size == seq1.size(1) == seq2.size(1) == mask1.size(0) == mask2.size(0)
+     assert hidden_size == seq1.size(2) == seq2.size(2)
+     assert seq1_length == mask1.size(1)
+     assert seq2_length == mask2.size(1)
+
+     torch._assert(is_right_padded(mask1), "Mask is not right padded")
+     torch._assert(is_right_padded(mask2), "Mask is not right padded")
+
+     actual_seq1_lengths = (~mask1).sum(dim=-1)
+     actual_seq2_lengths = (~mask2).sum(dim=-1)
+
+     final_lengths = actual_seq1_lengths + actual_seq2_lengths
+     max_length = seq1_length + seq2_length
+     concatenated_mask = (
+         torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) >= final_lengths[:, None]
+     )
+
+     # (max_len, batch_size, hidden_size)
+     concatenated_sequence = torch.zeros((max_length, batch_size, hidden_size), device=seq2.device, dtype=seq2.dtype)
+     concatenated_sequence[:seq1_length, :, :] = seq1
+
+     # At this point, the elements of seq1 are in the right place
+     # We just need to shift the elements of seq2
+
+     index = torch.arange(seq2_length, device=seq2.device)[:, None].repeat(1, batch_size)
+     index = index + actual_seq1_lengths[None]
+
+     concatenated_sequence = concatenated_sequence.scatter(0, index[:, :, None].expand(-1, -1, hidden_size), seq2)
+
+     if return_index:
+         return concatenated_sequence, concatenated_mask, index
+
+     return concatenated_sequence, concatenated_mask
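
The docstring's convention can be made concrete with a small worked example (shapes and values invented for illustration): the single real element of seq2 lands immediately after the single real element of seq1, and the result stays right-padded.

    seq1 = torch.arange(4.0).view(2, 1, 2)   # (seq1_length=2, batch=1, hidden=2)
    mask1 = torch.tensor([[False, True]])    # second position of seq1 is padding
    seq2 = torch.ones(1, 1, 2)               # (seq2_length=1, batch=1, hidden=2)
    mask2 = torch.tensor([[False]])
    out, out_mask = concat_padded_sequences(seq1, mask1, seq2, mask2)
    # out.shape == (3, 1, 2); out_mask == tensor([[False, False, True]])
    # out[1, 0] == tensor([1., 1.]), i.e. seq2's element sits right after seq1's real element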
+
+
+ class Prompt:
+     """Utility class to manipulate geometric prompts.
+
+     We expect the sequences in the PyTorch convention, that is, sequence first, batch second. The dimensions are
+     expected as follows:
+         box_embeddings shape: N_boxes x B x C_box
+         box_mask shape: B x N_boxes. Can be None if nothing is masked out
+         point_embeddings shape: N_points x B x C_point
+         point_mask shape: B x N_points. Can be None if nothing is masked out
+         mask_embeddings shape: N_masks x B x 1 x H_mask x W_mask
+         mask_mask shape: B x N_masks. Can be None if nothing is masked out
+
+     We also store positive/negative labels, in the same sequence-first layout. If they are None, we'll assume
+     positive labels everywhere:
+         box_labels: long tensor of shape N_boxes x B
+         point_labels: long tensor of shape N_points x B
+         mask_labels: long tensor of shape N_masks x B
+     """
+
+     def __init__(self, box_embeddings=None, box_mask=None, box_labels=None):
+         """Initialize the Prompt object."""
+         # Check for null prompt
+         if box_embeddings is None:
+             self.box_embeddings = None
+             self.box_labels = None
+             self.box_mask = None
+             return
+
+         # Get sequence length, batch size, and device
+         box_seq_len = box_embeddings.shape[0]
+         bs = box_embeddings.shape[1]
+         device = box_embeddings.device
+
+         # Initialize labels and attention mask if not provided
+         if box_labels is None:
+             box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long)
+         if box_mask is None:
+             box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool)
+
+         # Dimension checks
+         assert list(box_embeddings.shape[:2]) == [box_seq_len, bs], (
+             f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
+         )
+         assert box_embeddings.shape[-1] == 4, (
+             f"Expected box embeddings to have 4 coordinates, got {box_embeddings.shape[-1]}"
+         )
+         assert list(box_mask.shape) == [bs, box_seq_len], (
+             f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
+         )
+         assert list(box_labels.shape) == [box_seq_len, bs], (
+             f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
+         )
+
+         # Device checks
+         assert box_embeddings.device == device, (
+             f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
+         )
+         assert box_mask.device == device, f"Expected box mask to be on device {device}, got {box_mask.device}"
+         assert box_labels.device == device, f"Expected box labels to be on device {device}, got {box_labels.device}"
+
+         self.box_embeddings = box_embeddings
+         self.box_mask = box_mask
+         self.box_labels = box_labels
+
+     def append_boxes(self, boxes, labels=None, mask=None):
+         """Append box prompts to existing prompts.
+
+         Args:
+             boxes: Tensor of shape (N_new_boxes, B, 4) with normalized box coordinates
+             labels: Optional tensor of shape (N_new_boxes, B) with positive/negative labels
+             mask: Optional tensor of shape (B, N_new_boxes) for attention mask
+         """
+         if self.box_embeddings is None:
+             # First boxes - initialize
+             self.box_embeddings = boxes
+             bs = boxes.shape[1]
+             box_seq_len = boxes.shape[0]
+
+             if labels is None:
+                 labels = torch.ones(box_seq_len, bs, device=boxes.device, dtype=torch.long)
+             if mask is None:
+                 mask = torch.zeros(bs, box_seq_len, device=boxes.device, dtype=torch.bool)
+
+             self.box_labels = labels
+             self.box_mask = mask
+             return
+
+         # Append to existing boxes
+         bs = self.box_embeddings.shape[1]
+         assert boxes.shape[1] == bs, f"Batch size mismatch: expected {bs}, got {boxes.shape[1]}"
+
+         if labels is None:
+             labels = torch.ones(boxes.shape[0], bs, device=boxes.device, dtype=torch.long)
+         if mask is None:
+             mask = torch.zeros(bs, boxes.shape[0], dtype=torch.bool, device=boxes.device)
+
+         assert list(boxes.shape[:2]) == list(labels.shape[:2]), (
+             f"Shape mismatch between boxes {boxes.shape} and labels {labels.shape}"
+         )
+
+         # Concatenate using the helper function
+         self.box_labels, _ = concat_padded_sequences(
+             self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask
+         )
+         self.box_labels = self.box_labels.squeeze(-1)
+         self.box_embeddings, self.box_mask = concat_padded_sequences(self.box_embeddings, self.box_mask, boxes, mask)
+
+
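A hypothetical usage sketch of the class above (sizes arbitrary; labels and masks fall back to the positive/all-real defaults):

    prompt = Prompt(box_embeddings=torch.rand(3, 2, 4))  # 3 boxes, batch of 2, cxcywh
    prompt.append_boxes(torch.rand(1, 2, 4))             # box_embeddings is now (4, 2, 4), still right-padded
    prompt.box_labels.shape                              # torch.Size([4, 2]), all positive (1)
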
+ class SequenceGeometryEncoder(nn.Module):
+     """Encoder for geometric box prompts. Assumes boxes are passed in the "normalized CxCyWH" format.
+
+     Boxes can be encoded with any of the three possibilities:
+     - direct projection: linear projection from coordinate space to d_model
+     - pooling: RoI align features from the backbone
+     - pos encoder: position encoding of the box center
+
+     These three options are mutually compatible and will be summed if multiple are selected.
+
+     As an alternative, boxes can be encoded as two corner points (top-left and bottom-right).
+
+     The encoded sequence can be further processed with a transformer.
+     """
+
+     def __init__(
+         self,
+         encode_boxes_as_points: bool,
+         boxes_direct_project: bool,
+         boxes_pool: bool,
+         boxes_pos_enc: bool,
+         d_model: int,
+         pos_enc,
+         num_layers: int,
+         layer: nn.Module,
+         roi_size: int = 7,
+         add_cls: bool = True,
+         add_post_encode_proj: bool = True,
+         use_act_ckpt: bool = False,
+     ):
+         """Initialize the SequenceGeometryEncoder."""
+         super().__init__()
+
+         self.d_model = d_model
+         self.pos_enc = pos_enc
+         self.encode_boxes_as_points = encode_boxes_as_points
+         self.roi_size = roi_size
+
+         # Label embeddings: 2 labels if encoding as boxes (pos/neg),
+         # 6 labels if encoding as points (regular pos/neg, top-left pos/neg, bottom-right pos/neg)
+         num_labels = 6 if self.encode_boxes_as_points else 2
+         self.label_embed = torch.nn.Embedding(num_labels, self.d_model)
+
+         # CLS token for pooling
+         self.cls_embed = None
+         if add_cls:
+             self.cls_embed = torch.nn.Embedding(1, self.d_model)
+
+         # Point encoding (used when encode_boxes_as_points is True)
+         if encode_boxes_as_points:
+             self.points_direct_project = nn.Linear(2, self.d_model)
+             self.points_pool_project = None
+             self.points_pos_enc_project = None
+         else:
+             # Box encoding modules
+             assert boxes_direct_project or boxes_pos_enc or boxes_pool, "Error: need at least one way to encode boxes"
+             self.points_direct_project = None
+             self.points_pool_project = None
+             self.points_pos_enc_project = None
+
+         self.boxes_direct_project = None
+         self.boxes_pool_project = None
+         self.boxes_pos_enc_project = None
+
+         if boxes_direct_project:
+             self.boxes_direct_project = nn.Linear(4, self.d_model)
+         if boxes_pool:
+             self.boxes_pool_project = nn.Conv2d(self.d_model, self.d_model, self.roi_size)
+         if boxes_pos_enc:
+             self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model)
+
+         self.final_proj = None
+         if add_post_encode_proj:
+             self.final_proj = nn.Linear(self.d_model, self.d_model)
+             self.norm = nn.LayerNorm(self.d_model)
+
+         self.img_pre_norm = nn.Identity()
+         if self.points_pool_project is not None or self.boxes_pool_project is not None:
+             self.img_pre_norm = nn.LayerNorm(self.d_model)
+
+         self.encode = None
+         if num_layers > 0:
+             assert add_cls, "It's currently highly recommended to add a CLS when using a transformer"
+             self.encode = _get_clones(layer, num_layers)
+             self.encode_norm = nn.LayerNorm(self.d_model)
+
+         self.use_act_ckpt = use_act_ckpt
+
+     def _encode_points(self, points, points_mask, points_labels, img_feats):
+         """Encode points (used when boxes are converted to corner points)."""
+         # Direct projection of coordinates
+         points_embed = self.points_direct_project(points.to(img_feats.dtype))
+
+         # Add label embeddings
+         type_embed = self.label_embed(points_labels.long())
+         return type_embed + points_embed, points_mask
+
+     def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats: torch.Tensor):
+         """Encode boxes using the configured encoding methods."""
+         boxes_embed = None
+         n_boxes, bs = boxes.shape[:2]
+
+         if self.boxes_direct_project is not None:
+             proj = self.boxes_direct_project(boxes.to(img_feats.dtype))
+             boxes_embed = proj
+
+         if self.boxes_pool_project is not None:
+             H, W = img_feats.shape[-2:]
+
+             # Convert boxes to xyxy format and denormalize
+             boxes_xyxy = xywh2xyxy(boxes.to(img_feats.dtype))
+             scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
+             scale = scale.to(device=boxes_xyxy.device, non_blocking=True)
+             scale = scale.view(1, 1, 4)
+             boxes_xyxy = boxes_xyxy * scale
+
+             # RoI align
+             sampled = torchvision.ops.roi_align(img_feats, boxes_xyxy.transpose(0, 1).unbind(0), self.roi_size)
+             assert list(sampled.shape) == [
+                 bs * n_boxes,
+                 self.d_model,
+                 self.roi_size,
+                 self.roi_size,
+             ]
+             proj = self.boxes_pool_project(sampled)
+             proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1)
+
+             if boxes_embed is None:
+                 boxes_embed = proj
+             else:
+                 boxes_embed = boxes_embed + proj
+
+         if self.boxes_pos_enc_project is not None:
+             cx, cy, w, h = boxes.unbind(-1)
+             enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
+             enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1])
+
+             proj = self.boxes_pos_enc_project(enc.to(img_feats.dtype))
+             if boxes_embed is None:
+                 boxes_embed = proj
+             else:
+                 boxes_embed = boxes_embed + proj
+
+         # Add label embeddings
+         type_embed = self.label_embed(boxes_labels.long())
+         return type_embed + boxes_embed, boxes_mask
+
+     def forward(self, geo_prompt: Prompt, img_feats, img_sizes, img_pos_embeds=None):
+         """Encode geometric box prompts.
+
+         Args:
+             geo_prompt: Prompt object containing box embeddings, masks, and labels
+             img_feats: List of image features from the backbone
+             img_sizes: List of (H, W) tuples for each feature level
+             img_pos_embeds: Optional position embeddings for image features
+
+         Returns:
+             Tuple of (encoded_embeddings, attention_mask)
+         """
+         boxes = geo_prompt.box_embeddings
+         boxes_mask = geo_prompt.box_mask
+         boxes_labels = geo_prompt.box_labels
+
+         seq_first_img_feats = img_feats[-1]  # [H*W, B, C]
+         seq_first_img_pos_embeds = (
+             img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(seq_first_img_feats)
+         )
+
+         # Prepare image features for pooling if needed
+         if self.points_pool_project or self.boxes_pool_project:
+             assert len(img_feats) == len(img_sizes)
+             cur_img_feat = img_feats[-1]
+             cur_img_feat = self.img_pre_norm(cur_img_feat)
+             H, W = img_sizes[-1]
+             assert cur_img_feat.shape[0] == H * W
+             N, C = cur_img_feat.shape[-2:]
+             # Reshape to NxCxHxW
+             cur_img_feat = cur_img_feat.permute(1, 2, 0)
+             cur_img_feat = cur_img_feat.view(N, C, H, W)
+             img_feats = cur_img_feat
+
+         if self.encode_boxes_as_points:
+             # Convert boxes to corner points
+             assert boxes is not None and boxes.shape[-1] == 4
+
+             boxes_xyxy = xywh2xyxy(boxes)
+             top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1)
+
+             # Adjust labels for corner points (offset by 2 and 4)
+             labels_tl = boxes_labels + 2
+             labels_br = boxes_labels + 4
+
+             # Concatenate top-left and bottom-right points
+             points = torch.cat([top_left, bottom_right], dim=0)
+             points_labels = torch.cat([labels_tl, labels_br], dim=0)
+             points_mask = torch.cat([boxes_mask, boxes_mask], dim=1)
+
+             final_embeds, final_mask = self._encode_points(
+                 points=points,
+                 points_mask=points_mask,
+                 points_labels=points_labels,
+                 img_feats=img_feats,
+             )
+         else:
+             # Encode boxes directly
+             final_embeds, final_mask = self._encode_boxes(
+                 boxes=boxes,
+                 boxes_mask=boxes_mask,
+                 boxes_labels=boxes_labels,
+                 img_feats=img_feats,
+             )
+
+         bs = final_embeds.shape[1]
+         assert final_mask.shape[0] == bs
+
+         # Add CLS token if configured
+         if self.cls_embed is not None:
+             cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1)
+             cls_mask = torch.zeros(bs, 1, dtype=final_mask.dtype, device=final_mask.device)
+             final_embeds, final_mask = concat_padded_sequences(final_embeds, final_mask, cls, cls_mask)
+
+         # Final projection
+         if self.final_proj is not None:
+             final_embeds = self.norm(self.final_proj(final_embeds))
+
+         # Transformer encoding layers
+         if self.encode is not None:
+             for lay in self.encode:
+                 final_embeds = lay(
+                     tgt=final_embeds,
+                     memory=seq_first_img_feats,
+                     tgt_key_padding_mask=final_mask,
+                     pos=seq_first_img_pos_embeds,
+                 )
+             final_embeds = self.encode_norm(final_embeds)
+
+         return final_embeds, final_mask
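
When encode_boxes_as_points is set, the forward pass above doubles the sequence: each box contributes its top-left and bottom-right corners, with labels offset into the 6-entry embedding table (offsets of 2 for top-left and 4 for bottom-right, on top of the 0/1 base labels). A standalone sketch of that conversion, with invented tensors:

    from ultralytics.utils.ops import xywh2xyxy

    boxes = torch.rand(2, 1, 4)                      # (N_boxes, B, cxcywh), normalized
    labels = torch.ones(2, 1, dtype=torch.long)      # 1 = positive, 0 = negative
    top_left, bottom_right = xywh2xyxy(boxes).split(split_size=2, dim=-1)
    points = torch.cat([top_left, bottom_right], dim=0)        # (2 * N_boxes, B, 2)
    point_labels = torch.cat([labels + 2, labels + 4], dim=0)  # (2 * N_boxes, B)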
ultralytics/models/sam/sam3/maskformer_segmentation.py
@@ -0,0 +1,286 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ from __future__ import annotations
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+
+ from ultralytics.nn.modules.transformer import MLP
+
+
+ class LinearPresenceHead(nn.Sequential):
+     """Linear presence head for predicting the presence of classes in an image."""
+
+     def __init__(self, d_model):
+         """Initializes the LinearPresenceHead."""
+         # a hack to make `LinearPresenceHead` compatible with old checkpoints
+         super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1))
+
+     def forward(self, hs, prompt, prompt_mask):
+         """Forward pass of the presence head."""
+         return super().forward(hs)
+
+
+ class MaskPredictor(nn.Module):
+     """Predicts masks from object queries and pixel embeddings."""
+
+     def __init__(self, hidden_dim, mask_dim):
+         """Initializes the MaskPredictor."""
+         super().__init__()
+         self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+     def forward(self, obj_queries, pixel_embed):
+         """Predicts masks from object queries and pixel embeddings."""
+         if len(obj_queries.shape) == 3:
+             if pixel_embed.ndim == 3:
+                 # batch size was omitted
+                 mask_preds = torch.einsum("bqc,chw->bqhw", self.mask_embed(obj_queries), pixel_embed)
+             else:
+                 mask_preds = torch.einsum("bqc,bchw->bqhw", self.mask_embed(obj_queries), pixel_embed)
+         else:
+             # Assumed to have aux masks
+             if pixel_embed.ndim == 3:
+                 # batch size was omitted
+                 mask_preds = torch.einsum("lbqc,chw->lbqhw", self.mask_embed(obj_queries), pixel_embed)
+             else:
+                 mask_preds = torch.einsum("lbqc,bchw->lbqhw", self.mask_embed(obj_queries), pixel_embed)
+
+         return mask_preds
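
A quick shape check of the einsum contraction above (all dimensions invented, assuming this module and its MLP import are in scope):

    mp = MaskPredictor(hidden_dim=256, mask_dim=256)
    queries = torch.randn(2, 5, 256)        # (B, num_queries, C)
    pixels = torch.randn(2, 256, 64, 64)    # (B, C, H, W)
    mp(queries, pixels).shape               # torch.Size([2, 5, 64, 64])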
+
+
+ class SegmentationHead(nn.Module):
+     """Segmentation head that predicts masks from backbone features and object queries."""
+
+     def __init__(
+         self,
+         hidden_dim,
+         upsampling_stages,
+         use_encoder_inputs=False,
+         aux_masks=False,
+         no_dec=False,
+         pixel_decoder=None,
+         act_ckpt=False,
+         shared_conv=False,
+         compile_mode_pixel_decoder=None,
+     ):
+         """Initializes the SegmentationHead."""
+         super().__init__()
+         self.use_encoder_inputs = use_encoder_inputs
+         self.aux_masks = aux_masks
+         if pixel_decoder is not None:
+             self.pixel_decoder = pixel_decoder
+         else:
+             self.pixel_decoder = PixelDecoder(
+                 hidden_dim,
+                 upsampling_stages,
+                 shared_conv=shared_conv,
+                 compile_mode=compile_mode_pixel_decoder,
+             )
+         self.no_dec = no_dec
+         if no_dec:
+             self.mask_predictor = nn.Conv2d(hidden_dim, 1, kernel_size=3, stride=1, padding=1)
+         else:
+             self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)
+
+         self.act_ckpt = act_ckpt
+
+         # used to update the output dictionary
+         self.instance_keys = ["pred_masks"]
+
+     def _embed_pixels(self, backbone_feats: list[torch.Tensor], encoder_hidden_states) -> torch.Tensor:
+         """Embeds pixels using the pixel decoder."""
+         if self.use_encoder_inputs:
+             backbone_visual_feats = [bb_feat.clone() for bb_feat in backbone_feats]
+             # Extract visual embeddings
+             encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0)
+             spatial_dim = math.prod(backbone_feats[-1].shape[-2:])
+             encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape(-1, *backbone_feats[-1].shape[1:])
+
+             backbone_visual_feats[-1] = encoder_visual_embed
+             if self.act_ckpt:
+                 pixel_embed = checkpoint.checkpoint(self.pixel_decoder, backbone_visual_feats, use_reentrant=False)
+             else:
+                 pixel_embed = self.pixel_decoder(backbone_visual_feats)
+         else:
+             backbone_feats = [x for x in backbone_feats]
+             pixel_embed = self.pixel_decoder(backbone_feats)
+             if pixel_embed.shape[0] == 1:
+                 # For batch_size=1 training, we can avoid the indexing to save memory
+                 pixel_embed = pixel_embed.squeeze(0)
+             else:
+                 pixel_embed = pixel_embed[[0], ...]
+         return pixel_embed
+
+     def forward(
+         self,
+         backbone_feats: list[torch.Tensor],
+         obj_queries: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         **kwargs,
+     ) -> dict[str, torch.Tensor]:
+         """Forward pass of the SegmentationHead."""
+         if self.use_encoder_inputs:
+             assert encoder_hidden_states is not None
+
+         pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states)
+
+         if self.no_dec:
+             mask_pred = self.mask_predictor(pixel_embed)
+         elif self.aux_masks:
+             mask_pred = self.mask_predictor(obj_queries, pixel_embed)
+         else:
+             mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed)
+
+         return {"pred_masks": mask_pred}
+
+
+ class PixelDecoder(nn.Module):
+     """Pixel decoder module that upsamples backbone features."""
+
+     def __init__(
+         self,
+         hidden_dim,
+         num_upsampling_stages,
+         interpolation_mode="nearest",
+         shared_conv=False,
+         compile_mode=None,
+     ):
+         """Initializes the PixelDecoder."""
+         super().__init__()
+         self.hidden_dim = hidden_dim
+         self.num_upsampling_stages = num_upsampling_stages
+         self.interpolation_mode = interpolation_mode
+         conv_layers = []
+         norms = []
+         num_convs = 1 if shared_conv else num_upsampling_stages
+         for _ in range(num_convs):
+             conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1))
+             norms.append(nn.GroupNorm(8, self.hidden_dim))
+
+         self.conv_layers = nn.ModuleList(conv_layers)
+         self.norms = nn.ModuleList(norms)
+         self.shared_conv = shared_conv
+         self.out_dim = self.conv_layers[-1].out_channels
+         if compile_mode is not None:
+             self.forward = torch.compile(self.forward, mode=compile_mode, dynamic=True, fullgraph=True)
+             # Needed to make checkpointing happy. But we don't know if the module is checkpointed, so we disable it by default.
+             torch._dynamo.config.optimize_ddp = False
+
+     def forward(self, backbone_feats: list[torch.Tensor]):
+         """Forward pass of the PixelDecoder."""
+         prev_fpn = backbone_feats[-1]
+         fpn_feats = backbone_feats[:-1]
+         for layer_idx, bb_feat in enumerate(fpn_feats[::-1]):
+             curr_fpn = bb_feat
+             prev_fpn = curr_fpn + F.interpolate(prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode)
+             if self.shared_conv:
+                 # only one conv layer
+                 layer_idx = 0
+             prev_fpn = self.conv_layers[layer_idx](prev_fpn)
+             prev_fpn = F.relu(self.norms[layer_idx](prev_fpn))
+
+         return prev_fpn
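
The loop above runs coarsest-to-finest: each finer backbone map is summed with the upsampled running map, then refined by a conv, GroupNorm, and ReLU, so the output keeps the finest input's resolution. For example (sizes invented):

    dec = PixelDecoder(hidden_dim=256, num_upsampling_stages=2)
    feats = [torch.randn(1, 256, 64, 64), torch.randn(1, 256, 32, 32), torch.randn(1, 256, 16, 16)]
    dec(feats).shape   # torch.Size([1, 256, 64, 64]), the resolution of the finest input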
+
+
+ class UniversalSegmentationHead(SegmentationHead):
+     """This module handles semantic + instance segmentation."""
+
+     def __init__(
+         self,
+         hidden_dim,
+         upsampling_stages,
+         pixel_decoder,
+         aux_masks=False,
+         no_dec=False,
+         act_ckpt=False,
+         presence_head: bool = False,
+         dot_product_scorer=None,
+         cross_attend_prompt=None,
+     ):
+         """Initializes the UniversalSegmentationHead."""
+         super().__init__(
+             hidden_dim=hidden_dim,
+             upsampling_stages=upsampling_stages,
+             use_encoder_inputs=True,
+             aux_masks=aux_masks,
+             no_dec=no_dec,
+             pixel_decoder=pixel_decoder,
+             act_ckpt=act_ckpt,
+         )
+         self.d_model = hidden_dim
+
+         if dot_product_scorer is not None:
+             assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"
+
+         self.presence_head = None
+         if presence_head:
+             self.presence_head = (
+                 dot_product_scorer if dot_product_scorer is not None else LinearPresenceHead(self.d_model)
+             )
+
+         self.cross_attend_prompt = cross_attend_prompt
+         if self.cross_attend_prompt is not None:
+             self.cross_attn_norm = nn.LayerNorm(self.d_model)
+
+         self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1)
+         self.instance_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, self.d_model, kernel_size=1)
+
+     def forward(
+         self,
+         backbone_feats: list[torch.Tensor],
+         obj_queries: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         prompt: torch.Tensor = None,
+         prompt_mask: torch.Tensor = None,
+         **kwargs,
+     ) -> dict[str, torch.Tensor]:
+         """Forward pass of the UniversalSegmentationHead."""
+         assert encoder_hidden_states is not None
+         bs = encoder_hidden_states.shape[1]
+
+         if self.cross_attend_prompt is not None:
+             tgt2 = self.cross_attn_norm(encoder_hidden_states)
+             tgt2 = self.cross_attend_prompt(
+                 query=tgt2,
+                 key=prompt.to(tgt2.dtype),
+                 value=prompt.to(tgt2.dtype),
+                 key_padding_mask=prompt_mask,
+                 need_weights=False,
+             )[0]
+             encoder_hidden_states = tgt2 + encoder_hidden_states
+
+         presence_logit = None
+         if self.presence_head is not None:
+             pooled_enc = encoder_hidden_states.mean(0)
+             presence_logit = (
+                 self.presence_head(
+                     pooled_enc.view(1, bs, 1, self.d_model),
+                     prompt=prompt,
+                     prompt_mask=prompt_mask,
+                 )
+                 .squeeze(0)
+                 .squeeze(1)
+             )
+
+         pixel_embed = self._embed_pixels(backbone_feats=backbone_feats, encoder_hidden_states=encoder_hidden_states)
+
+         instance_embeds = self.instance_seg_head(pixel_embed)
+
+         if self.no_dec:
+             mask_pred = self.mask_predictor(instance_embeds)
+         elif self.aux_masks:
+             mask_pred = self.mask_predictor(obj_queries, instance_embeds)
+         else:
+             mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds)
+
+         return {
+             "pred_masks": mask_pred,
+             "semantic_seg": self.semantic_seg_head(pixel_embed),
+             "presence_logit": presence_logit,
+         }
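
A sketch of the universal head with a presence head enabled. The encoder sequence is sequence-first and must start with the H*W tokens of the last feature level; every size below is invented for illustration:

    pd = PixelDecoder(hidden_dim=256, num_upsampling_stages=2)
    head = UniversalSegmentationHead(256, 2, pixel_decoder=pd, presence_head=True)
    feats = [torch.randn(1, 256, 64, 64), torch.randn(1, 256, 32, 32), torch.randn(1, 256, 16, 16)]
    enc = torch.randn(16 * 16, 1, 256)    # (H*W of the last level, B, C)
    out = head(feats, torch.randn(3, 1, 5, 256), encoder_hidden_states=enc)
    # out["pred_masks"]: (1, 5, 64, 64); out["semantic_seg"]: (1, 1, 64, 64); out["presence_logit"]: (1, 1)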