inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/rfdetr/rfdetr_base_pytorch.py
@@ -0,0 +1,807 @@
+ import argparse
+ import copy
+ import math
+ from typing import Callable, List, Literal, Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+ import torchvision
+ from pydantic import BaseModel, ConfigDict
+ from torch import Tensor, nn
+
+ from inference_models.models.rfdetr.backbone_builder import build_backbone
+ from inference_models.models.rfdetr.misc import NestedTensor
+ from inference_models.models.rfdetr.segmentation_head import SegmentationHead
+ from inference_models.models.rfdetr.transformer import build_transformer
+
+
+ class ModelConfig(BaseModel):
+     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
+     out_feature_indexes: List[int]
+     dec_layers: int
+     two_stage: bool = True
+     projector_scale: List[Literal["P3", "P4", "P5"]]
+     hidden_dim: int
+     patch_size: int
+     num_windows: int
+     sa_nheads: int
+     ca_nheads: int
+     dec_n_points: int
+     bbox_reparam: bool = True
+     lite_refpoint_refine: bool = True
+     layer_norm: bool = True
+     amp: bool = True
+     num_classes: int = 90
+     pretrain_weights: Optional[str] = None
+     device: torch.device
+     resolution: int
+     group_detr: int = 13
+     gradient_checkpointing: bool = False
+     positional_encoding_size: int
+     ia_bce_loss: bool = True
+     cls_loss_coef: float = 1.0
+     segmentation_head: bool = False
+     mask_downsample_ratio: int = 4
+
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+ class RFDETRBaseConfig(ModelConfig):
+     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = (
+         "dinov2_windowed_small"
+     )
+     hidden_dim: int = 256
+     patch_size: int = 14
+     num_windows: int = 4
+     dec_layers: int = 3
+     sa_nheads: int = 8
+     ca_nheads: int = 16
+     dec_n_points: int = 2
+     num_queries: int = 300
+     num_select: int = 300
+     projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
+     out_feature_indexes: List[int] = [2, 5, 8, 11]
+     pretrain_weights: Optional[str] = "rf-detr-base.pth"
+     resolution: int = 560
+     positional_encoding_size: int = 37
+
+
+ class RFDETRLargeConfig(RFDETRBaseConfig):
+     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = (
+         "dinov2_windowed_base"
+     )
+     hidden_dim: int = 384
+     sa_nheads: int = 12
+     ca_nheads: int = 24
+     dec_n_points: int = 4
+     projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
+     pretrain_weights: Optional[str] = "rf-detr-large.pth"
+
+
+ class RFDETRNanoConfig(RFDETRBaseConfig):
+     out_feature_indexes: List[int] = [3, 6, 9, 12]
+     num_windows: int = 2
+     dec_layers: int = 2
+     patch_size: int = 16
+     resolution: int = 384
+     positional_encoding_size: int = 24
+     pretrain_weights: Optional[str] = "rf-detr-nano.pth"
+
+
+ class RFDETRSmallConfig(RFDETRBaseConfig):
+     out_feature_indexes: List[int] = [3, 6, 9, 12]
+     num_windows: int = 2
+     dec_layers: int = 3
+     patch_size: int = 16
+     resolution: int = 512
+     positional_encoding_size: int = 32
+     pretrain_weights: Optional[str] = "rf-detr-small.pth"
+
+
+ class RFDETRMediumConfig(RFDETRBaseConfig):
+     out_feature_indexes: List[int] = [3, 6, 9, 12]
+     num_windows: int = 2
+     dec_layers: int = 4
+     patch_size: int = 16
+     resolution: int = 576
+     positional_encoding_size: int = 36
+     pretrain_weights: Optional[str] = "rf-detr-medium.pth"
+
+
+ class RFDETRSegPreviewConfig(RFDETRBaseConfig):
+     segmentation_head: bool = True
+     out_feature_indexes: List[int] = [3, 6, 9, 12]
+     num_windows: int = 2
+     dec_layers: int = 4
+     patch_size: int = 12
+     resolution: int = 432
+     positional_encoding_size: int = 36
+     num_queries: int = 200
+     num_select: int = 200
+     pretrain_weights: Optional[str] = "rf-detr-seg-preview.pt"
+     num_classes: int = 90
+
+
+ class LWDETR(nn.Module):
+     """This is the Group DETR v3 module that performs object detection"""
+
+     def __init__(
+         self,
+         backbone,
+         transformer,
+         segmentation_head,
+         num_classes,
+         num_queries,
+         aux_loss=False,
+         group_detr=1,
+         two_stage=False,
+         lite_refpoint_refine=False,
+         bbox_reparam=False,
+     ):
+         """Initializes the model.
+         Parameters:
+             backbone: torch module of the backbone to be used. See backbone.py
+             transformer: torch module of the transformer architecture. See transformer.py
+             num_classes: number of object classes
+             num_queries: number of object queries, i.e. detection slots. This is the maximal number of objects
+                          Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
+             aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+             group_detr: Number of groups to speed detr training. Default is 1.
+             lite_refpoint_refine: TODO
+         """
+         super().__init__()
+         self.num_queries = num_queries
+         self.transformer = transformer
+         hidden_dim = transformer.d_model
+         self.class_embed = nn.Linear(hidden_dim, num_classes)
+         self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+         self.segmentation_head = segmentation_head
+         query_dim = 4
+         self.refpoint_embed = nn.Embedding(num_queries * group_detr, query_dim)
+         self.query_feat = nn.Embedding(num_queries * group_detr, hidden_dim)
+         nn.init.constant_(self.refpoint_embed.weight.data, 0)
+
+         self.backbone = backbone
+         self.aux_loss = aux_loss
+         self.group_detr = group_detr
+
+         # iter update
+         self.lite_refpoint_refine = lite_refpoint_refine
+         if not self.lite_refpoint_refine:
+             self.transformer.decoder.bbox_embed = self.bbox_embed
+         else:
+             self.transformer.decoder.bbox_embed = None
+
+         self.bbox_reparam = bbox_reparam
+
+         # init prior_prob setting for focal loss
+         prior_prob = 0.01
+         bias_value = -math.log((1 - prior_prob) / prior_prob)
+         self.class_embed.bias.data = torch.ones(num_classes) * bias_value
+
+         # init bbox_embed
+         nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
+         nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
+
+         # two_stage
+         self.two_stage = two_stage
+         if self.two_stage:
+             self.transformer.enc_out_bbox_embed = nn.ModuleList(
+                 [copy.deepcopy(self.bbox_embed) for _ in range(group_detr)]
+             )
+             self.transformer.enc_out_class_embed = nn.ModuleList(
+                 [copy.deepcopy(self.class_embed) for _ in range(group_detr)]
+             )
+
+         self._export = False
+
+     def reinitialize_detection_head(self, num_classes):
+         base = self.class_embed.weight.shape[0]
+         num_repeats = int(math.ceil(num_classes / base))
+         self.class_embed.weight.data = self.class_embed.weight.data.repeat(
+             num_repeats, 1
+         )
+         self.class_embed.weight.data = self.class_embed.weight.data[:num_classes]
+         self.class_embed.bias.data = self.class_embed.bias.data.repeat(num_repeats)
+         self.class_embed.bias.data = self.class_embed.bias.data[:num_classes]
+
+         if self.two_stage:
+             for enc_out_class_embed in self.transformer.enc_out_class_embed:
+                 enc_out_class_embed.weight.data = (
+                     enc_out_class_embed.weight.data.repeat(num_repeats, 1)
+                 )
+                 enc_out_class_embed.weight.data = enc_out_class_embed.weight.data[
+                     :num_classes
+                 ]
+                 enc_out_class_embed.bias.data = enc_out_class_embed.bias.data.repeat(
+                     num_repeats
+                 )
+                 enc_out_class_embed.bias.data = enc_out_class_embed.bias.data[
+                     :num_classes
+                 ]
+
+     def export(self):
+         self._export = True
+         self._forward_origin = self.forward
+         self.forward = self.forward_export
+         for name, m in self.named_modules():
+             if (
+                 hasattr(m, "export")
+                 and isinstance(m.export, Callable)
+                 and hasattr(m, "_export")
+                 and not m._export
+             ):
+                 m.export()
+
+     def forward(self, samples: NestedTensor, targets=None):
+         """The forward expects a NestedTensor, which consists of:
+             - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+             - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+
+         It returns a dict with the following elements:
+             - "pred_logits": the classification logits (including no-object) for all queries.
+                              Shape = [batch_size x num_queries x num_classes]
+             - "pred_boxes": The normalized box coordinates for all queries, represented as
+                             (center_x, center_y, width, height). These values are normalized in [0, 1],
+                             relative to the size of each individual image (disregarding possible padding).
+                             See PostProcess for information on how to retrieve the unnormalized bounding box.
+             - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
+                              dictionaries containing the two above keys for each decoder layer.
+         """
+         if isinstance(samples, (list, torch.Tensor)):
+             samples = nested_tensor_from_tensor_list(samples)
+         features, poss = self.backbone(samples)
+
+         srcs = []
+         masks = []
+         for l, feat in enumerate(features):
+             src, mask = feat.decompose()
+             srcs.append(src)
+             masks.append(mask)
+             assert mask is not None
+
+         if self.training:
+             refpoint_embed_weight = self.refpoint_embed.weight
+             query_feat_weight = self.query_feat.weight
+         else:
+             # only use one group in inference
+             refpoint_embed_weight = self.refpoint_embed.weight[: self.num_queries]
+             query_feat_weight = self.query_feat.weight[: self.num_queries]
+
+         hs, ref_unsigmoid, hs_enc, ref_enc = self.transformer(
+             srcs, masks, poss, refpoint_embed_weight, query_feat_weight
+         )
+
+         if hs is not None:
+             if self.bbox_reparam:
+                 outputs_coord_delta = self.bbox_embed(hs)
+                 outputs_coord_cxcy = (
+                     outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
+                     + ref_unsigmoid[..., :2]
+                 )
+                 outputs_coord_wh = (
+                     outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
+                 )
+                 outputs_coord = torch.concat(
+                     [outputs_coord_cxcy, outputs_coord_wh], dim=-1
+                 )
+             else:
+                 outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
+
+             outputs_class = self.class_embed(hs)
+
+             if self.segmentation_head is not None:
+                 outputs_masks = self.segmentation_head(
+                     features[0].tensors, hs, samples.tensors.shape[-2:]
+                 )
+
+             out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
+             if self.segmentation_head is not None:
+                 out["pred_masks"] = outputs_masks[-1]
+             if self.aux_loss:
+                 out["aux_outputs"] = self._set_aux_loss(
+                     outputs_class,
+                     outputs_coord,
+                     outputs_masks if self.segmentation_head is not None else None,
+                 )
+
+         if self.two_stage:
+             group_detr = self.group_detr if self.training else 1
+             hs_enc_list = hs_enc.chunk(group_detr, dim=1)
+             cls_enc = []
+             for g_idx in range(group_detr):
+                 cls_enc_gidx = self.transformer.enc_out_class_embed[g_idx](
+                     hs_enc_list[g_idx]
+                 )
+                 cls_enc.append(cls_enc_gidx)
+
+             cls_enc = torch.cat(cls_enc, dim=1)
+
+             if self.segmentation_head is not None:
+                 masks_enc = self.segmentation_head(
+                     features[0].tensors,
+                     [
+                         hs_enc,
+                     ],
+                     samples.tensors.shape[-2:],
+                     skip_blocks=True,
+                 )
+                 masks_enc = torch.cat(masks_enc, dim=1)
+
+             if hs is not None:
+                 out["enc_outputs"] = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
+                 if self.segmentation_head is not None:
+                     out["enc_outputs"]["pred_masks"] = masks_enc
+             else:
+                 out = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
+                 if self.segmentation_head is not None:
+                     out["pred_masks"] = masks_enc
+
+         return out
+
+     def forward_export(self, tensors):
+         srcs, _, poss = self.backbone(tensors)
+         # only use one group in inference
+         refpoint_embed_weight = self.refpoint_embed.weight[: self.num_queries]
+         query_feat_weight = self.query_feat.weight[: self.num_queries]
+
+         hs, ref_unsigmoid, hs_enc, ref_enc = self.transformer(
+             srcs, None, poss, refpoint_embed_weight, query_feat_weight
+         )
+
+         outputs_masks = None
+
+         if hs is not None:
+             if self.bbox_reparam:
+                 outputs_coord_delta = self.bbox_embed(hs)
+                 outputs_coord_cxcy = (
+                     outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
+                     + ref_unsigmoid[..., :2]
+                 )
+                 outputs_coord_wh = (
+                     outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
+                 )
+                 outputs_coord = torch.concat(
+                     [outputs_coord_cxcy, outputs_coord_wh], dim=-1
+                 )
+             else:
+                 outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()
+             outputs_class = self.class_embed(hs)
+             if self.segmentation_head is not None:
+                 outputs_masks = self.segmentation_head(
+                     srcs[0],
+                     [
+                         hs,
+                     ],
+                     tensors.shape[-2:],
+                 )[0]
+         else:
+             assert self.two_stage, "if not using decoder, two_stage must be True"
+             outputs_class = self.transformer.enc_out_class_embed[0](hs_enc)
+             outputs_coord = ref_enc
+             if self.segmentation_head is not None:
+                 outputs_masks = self.segmentation_head(
+                     srcs[0],
+                     [
+                         hs_enc,
+                     ],
+                     tensors.shape[-2:],
+                     skip_blocks=True,
+                 )[0]
+
+         if outputs_masks is not None:
+             return outputs_coord, outputs_class, outputs_masks
+         else:
+             return outputs_coord, outputs_class
+
+     @torch.jit.unused
+     def _set_aux_loss(self, outputs_class, outputs_coord, outputs_masks):
+         # this is a workaround to make torchscript happy, as torchscript
+         # doesn't support dictionary with non-homogeneous values, such
+         # as a dict having both a Tensor and a list.
+         if outputs_masks is not None:
+             return [
+                 {"pred_logits": a, "pred_boxes": b, "pred_masks": c}
+                 for a, b, c in zip(
+                     outputs_class[:-1], outputs_coord[:-1], outputs_masks[:-1]
+                 )
+             ]
+         else:
+             return [
+                 {"pred_logits": a, "pred_boxes": b}
+                 for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
+             ]
+
+     def update_drop_path(self, drop_path_rate, vit_encoder_num_layers):
+         """ """
+         dp_rates = [
+             x.item() for x in torch.linspace(0, drop_path_rate, vit_encoder_num_layers)
+         ]
+         for i in range(vit_encoder_num_layers):
+             if hasattr(self.backbone[0].encoder, "blocks"):  # Not aimv2
+                 if hasattr(self.backbone[0].encoder.blocks[i].drop_path, "drop_prob"):
+                     self.backbone[0].encoder.blocks[i].drop_path.drop_prob = dp_rates[i]
+             else:  # aimv2
+                 if hasattr(
+                     self.backbone[0].encoder.trunk.blocks[i].drop_path, "drop_prob"
+                 ):
+                     self.backbone[0].encoder.trunk.blocks[i].drop_path.drop_prob = (
+                         dp_rates[i]
+                     )
+
+     def update_dropout(self, drop_rate):
+         for module in self.transformer.modules():
+             if isinstance(module, nn.Dropout):
+                 module.p = drop_rate
+
+
+ class MLP(nn.Module):
+     """Very simple multi-layer perceptron (also called FFN)"""
+
+     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+         super().__init__()
+         self.num_layers = num_layers
+         h = [hidden_dim] * (num_layers - 1)
+         self.layers = nn.ModuleList(
+             nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+         )
+
+     def forward(self, x):
+         for i, layer in enumerate(self.layers):
+             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+         return x
+
+
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+     # TODO make this more general
+     if tensor_list[0].ndim == 3:
+         if torchvision._is_tracing():
+             # nested_tensor_from_tensor_list() does not export well to ONNX
+             # call _onnx_nested_tensor_from_tensor_list() instead
+             return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+         # TODO make it support different-sized images
+         max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+         # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+         batch_shape = [len(tensor_list)] + max_size
+         b, c, h, w = batch_shape
+         dtype = tensor_list[0].dtype
+         device = tensor_list[0].device
+         tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+         mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+         for img, pad_img, m in zip(tensor_list, tensor, mask):
+             pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+             m[: img.shape[1], : img.shape[2]] = False
+     else:
+         raise ValueError("not supported")
+     return NestedTensor(tensor, mask)
+
+
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+ @torch.jit.unused
+ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+     max_size = []
+     for i in range(tensor_list[0].dim()):
+         max_size_i = torch.max(
+             torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
+         ).to(torch.int64)
+         max_size.append(max_size_i)
+     max_size = tuple(max_size)
+
+     # work around for
+     # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+     # m[: img.shape[1], :img.shape[2]] = False
+     # which is not yet supported in onnx
+     padded_imgs = []
+     padded_masks = []
+     for img in tensor_list:
+         padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+         padded_img = torch.nn.functional.pad(
+             img, (0, padding[2], 0, padding[1], 0, padding[0])
+         )
+         padded_imgs.append(padded_img)
+
+         m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+         padded_mask = torch.nn.functional.pad(
+             m, (0, padding[2], 0, padding[1]), "constant", 1
+         )
+         padded_masks.append(padded_mask.to(torch.bool))
+
+     tensor = torch.stack(padded_imgs)
+     mask = torch.stack(padded_masks)
+
+     return NestedTensor(tensor, mask=mask)
+
+
+ def _max_by_axis(the_list):
+     # type: (List[List[int]]) -> List[int]
+     maxes = the_list[0]
+     for sublist in the_list[1:]:
+         for index, item in enumerate(sublist):
+             maxes[index] = max(maxes[index], item)
+     return maxes
+
+
+ def build_model(config: ModelConfig) -> LWDETR:
+     # the `num_classes` naming here is somewhat misleading.
+     # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
+     # is the maximum id for a class in your dataset. For example,
+     # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
+     # As another example, for a dataset that has a single class with id 1,
+     # you should pass `num_classes` to be 2 (max_obj_id + 1).
+     # For more details on this, check the following discussion
+     # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
+     args = populate_args(**config.dict())
+     num_classes = args.num_classes + 1
+     backbone = build_backbone(
+         encoder=args.encoder,
+         vit_encoder_num_layers=args.vit_encoder_num_layers,
+         pretrained_encoder=args.pretrained_encoder,
+         window_block_indexes=args.window_block_indexes,
+         drop_path=args.drop_path,
+         out_channels=args.hidden_dim,
+         out_feature_indexes=args.out_feature_indexes,
+         projector_scale=args.projector_scale,
+         use_cls_token=args.use_cls_token,
+         hidden_dim=args.hidden_dim,
+         position_embedding=args.position_embedding,
+         freeze_encoder=args.freeze_encoder,
+         layer_norm=args.layer_norm,
+         target_shape=(
+             args.shape
+             if hasattr(args, "shape")
+             else (
+                 (args.resolution, args.resolution)
+                 if hasattr(args, "resolution")
+                 else (640, 640)
+             )
+         ),
+         rms_norm=args.rms_norm,
+         backbone_lora=args.backbone_lora,
+         force_no_pretrain=args.force_no_pretrain,
+         gradient_checkpointing=args.gradient_checkpointing,
+         load_dinov2_weights=args.pretrain_weights is None,
+         patch_size=config.patch_size,
+         num_windows=config.num_windows,
+         positional_encoding_size=config.positional_encoding_size,
+     )
+     if args.encoder_only:
+         return backbone[0].encoder, None, None
+     if args.backbone_only:
+         return backbone, None, None
+     args.num_feature_levels = len(args.projector_scale)
+     transformer = build_transformer(args)
+     segmentation_head = (
+         SegmentationHead(
+             args.hidden_dim,
+             args.dec_layers,
+             downsample_ratio=args.mask_downsample_ratio,
+         )
+         if args.segmentation_head
+         else None
+     )
+     return LWDETR(
+         backbone,
+         transformer,
+         segmentation_head,
+         num_classes=num_classes,
+         num_queries=args.num_queries,
+         aux_loss=args.aux_loss,
+         group_detr=args.group_detr,
+         two_stage=args.two_stage,
+         lite_refpoint_refine=args.lite_refpoint_refine,
+         bbox_reparam=args.bbox_reparam,
+     )
+
+
+ def populate_args(
+     # Basic training parameters
+     num_classes=2,
+     grad_accum_steps=1,
+     amp=False,
+     lr=1e-4,
+     lr_encoder=1.5e-4,
+     batch_size=2,
+     weight_decay=1e-4,
+     epochs=12,
+     lr_drop=11,
+     clip_max_norm=0.1,
+     lr_vit_layer_decay=0.8,
+     lr_component_decay=1.0,
+     do_benchmark=False,
+     # Drop parameters
+     dropout=0,
+     drop_path=0,
+     drop_mode="standard",
+     drop_schedule="constant",
+     cutoff_epoch=0,
+     # Model parameters
+     pretrained_encoder=None,
+     pretrain_weights=None,
+     pretrain_exclude_keys=None,
+     pretrain_keys_modify_to_load=None,
+     pretrained_distiller=None,
+     # Backbone parameters
+     encoder="vit_tiny",
+     vit_encoder_num_layers=12,
+     window_block_indexes=None,
+     position_embedding="sine",
+     out_feature_indexes=[-1],
+     freeze_encoder=False,
+     layer_norm=False,
+     rms_norm=False,
+     backbone_lora=False,
+     force_no_pretrain=False,
+     # Transformer parameters
+     dec_layers=3,
+     dim_feedforward=2048,
+     hidden_dim=256,
+     sa_nheads=8,
+     ca_nheads=8,
+     num_queries=300,
+     group_detr=13,
+     two_stage=False,
+     projector_scale="P4",
+     lite_refpoint_refine=False,
+     num_select=100,
+     dec_n_points=4,
+     decoder_norm="LN",
+     bbox_reparam=False,
+     freeze_batch_norm=False,
+     # Matcher parameters
+     set_cost_class=2,
+     set_cost_bbox=5,
+     set_cost_giou=2,
+     # Loss coefficients
+     cls_loss_coef=2,
+     bbox_loss_coef=5,
+     giou_loss_coef=2,
+     focal_alpha=0.25,
+     aux_loss=True,
+     sum_group_losses=False,
+     use_varifocal_loss=False,
+     use_position_supervised_loss=False,
+     ia_bce_loss=False,
+     # Dataset parameters
+     dataset_file="coco",
+     coco_path=None,
+     dataset_dir=None,
+     square_resize_div_64=False,
+     # Output parameters
+     output_dir="output",
+     dont_save_weights=False,
+     checkpoint_interval=10,
+     seed=42,
+     resume="",
+     start_epoch=0,
+     eval=False,
+     use_ema=False,
+     ema_decay=0.9997,
+     ema_tau=0,
+     num_workers=2,
+     # Distributed training parameters
+     device="cuda",
+     world_size=1,
+     dist_url="env://",
+     sync_bn=True,
+     # FP16
+     fp16_eval=False,
+     # Custom args
+     encoder_only=False,
+     backbone_only=False,
+     resolution=640,
+     use_cls_token=False,
+     multi_scale=False,
+     expanded_scales=False,
+     warmup_epochs=1,
+     lr_scheduler="step",
+     lr_min_factor=0.0,
+     # Early stopping parameters
+     early_stopping=True,
+     early_stopping_patience=10,
+     early_stopping_min_delta=0.001,
+     early_stopping_use_ema=False,
+     gradient_checkpointing=False,
+     # Additional
+     subcommand=None,
+     **extra_kwargs,  # To handle any unexpected arguments
+ ):
+     args = argparse.Namespace(
+         num_classes=num_classes,
+         grad_accum_steps=grad_accum_steps,
+         amp=amp,
+         lr=lr,
+         lr_encoder=lr_encoder,
+         batch_size=batch_size,
+         weight_decay=weight_decay,
+         epochs=epochs,
+         lr_drop=lr_drop,
+         clip_max_norm=clip_max_norm,
+         lr_vit_layer_decay=lr_vit_layer_decay,
+         lr_component_decay=lr_component_decay,
+         do_benchmark=do_benchmark,
+         dropout=dropout,
+         drop_path=drop_path,
+         drop_mode=drop_mode,
+         drop_schedule=drop_schedule,
+         cutoff_epoch=cutoff_epoch,
+         pretrained_encoder=pretrained_encoder,
+         pretrain_weights=pretrain_weights,
+         pretrain_exclude_keys=pretrain_exclude_keys,
+         pretrain_keys_modify_to_load=pretrain_keys_modify_to_load,
+         pretrained_distiller=pretrained_distiller,
+         encoder=encoder,
+         vit_encoder_num_layers=vit_encoder_num_layers,
+         window_block_indexes=window_block_indexes,
+         position_embedding=position_embedding,
+         out_feature_indexes=out_feature_indexes,
+         freeze_encoder=freeze_encoder,
+         layer_norm=layer_norm,
+         rms_norm=rms_norm,
+         backbone_lora=backbone_lora,
+         force_no_pretrain=force_no_pretrain,
+         dec_layers=dec_layers,
+         dim_feedforward=dim_feedforward,
+         hidden_dim=hidden_dim,
+         sa_nheads=sa_nheads,
+         ca_nheads=ca_nheads,
+         num_queries=num_queries,
+         group_detr=group_detr,
+         two_stage=two_stage,
+         projector_scale=projector_scale,
+         lite_refpoint_refine=lite_refpoint_refine,
+         num_select=num_select,
+         dec_n_points=dec_n_points,
+         decoder_norm=decoder_norm,
+         bbox_reparam=bbox_reparam,
+         freeze_batch_norm=freeze_batch_norm,
+         set_cost_class=set_cost_class,
+         set_cost_bbox=set_cost_bbox,
+         set_cost_giou=set_cost_giou,
+         cls_loss_coef=cls_loss_coef,
+         bbox_loss_coef=bbox_loss_coef,
+         giou_loss_coef=giou_loss_coef,
+         focal_alpha=focal_alpha,
+         aux_loss=aux_loss,
+         sum_group_losses=sum_group_losses,
+         use_varifocal_loss=use_varifocal_loss,
+         use_position_supervised_loss=use_position_supervised_loss,
+         ia_bce_loss=ia_bce_loss,
+         dataset_file=dataset_file,
+         coco_path=coco_path,
+         dataset_dir=dataset_dir,
+         square_resize_div_64=square_resize_div_64,
+         output_dir=output_dir,
+         dont_save_weights=dont_save_weights,
+         checkpoint_interval=checkpoint_interval,
+         seed=seed,
+         resume=resume,
+         start_epoch=start_epoch,
+         eval=eval,
+         use_ema=use_ema,
+         ema_decay=ema_decay,
+         ema_tau=ema_tau,
+         num_workers=num_workers,
+         device=device,
+         world_size=world_size,
+         dist_url=dist_url,
+         sync_bn=sync_bn,
+         fp16_eval=fp16_eval,
+         encoder_only=encoder_only,
+         backbone_only=backbone_only,
+         resolution=resolution,
+         use_cls_token=use_cls_token,
+         multi_scale=multi_scale,
+         expanded_scales=expanded_scales,
+         warmup_epochs=warmup_epochs,
+         lr_scheduler=lr_scheduler,
+         lr_min_factor=lr_min_factor,
+         early_stopping=early_stopping,
+         early_stopping_patience=early_stopping_patience,
+         early_stopping_min_delta=early_stopping_min_delta,
+         early_stopping_use_ema=early_stopping_use_ema,
+         gradient_checkpointing=gradient_checkpointing,
+         **extra_kwargs,
+     )
+     return args
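
The bbox_reparam branch that appears twice above (in LWDETR.forward and forward_export) decodes boxes relative to the decoder's reference points rather than through a sigmoid. The snippet below is a minimal standalone sketch of that arithmetic with dummy tensors; it is not part of the package, and the names and shapes are chosen only for illustration.

import torch

# Stand-ins for the decoder outputs: reference boxes in (cx, cy, w, h) form
# and the per-query deltas produced by bbox_embed (shapes picked arbitrarily).
ref_boxes = torch.rand(1, 300, 4)
deltas = torch.randn(1, 300, 4) * 0.1

# Centers are shifted proportionally to the reference box size ...
pred_cxcy = deltas[..., :2] * ref_boxes[..., 2:] + ref_boxes[..., :2]
# ... and widths/heights are rescaled multiplicatively, kept positive by exp().
pred_wh = deltas[..., 2:].exp() * ref_boxes[..., 2:]
pred_boxes = torch.cat([pred_cxcy, pred_wh], dim=-1)  # (1, 300, 4), still (cx, cy, w, h)

Because the size update is multiplicative, predicted widths and heights stay positive without squashing, which is why the non-reparameterized branch in the same methods applies .sigmoid() to its raw predictions instead.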