inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/rfdetr/rfdetr_backbone_pytorch.py
@@ -0,0 +1,394 @@
+ import json
+ import math
+ import os
+ import types
+
+ import torch
+ import torch.nn.functional as F
+ from peft import PeftModel
+ from torch import nn
+ from transformers import AutoBackbone
+
+ from inference_models.logger import LOGGER
+ from inference_models.models.rfdetr.dinov2_with_windowed_attn import (
+     WindowedDinov2WithRegistersBackbone,
+     WindowedDinov2WithRegistersConfig,
+ )
+ from inference_models.models.rfdetr.misc import NestedTensor
+ from inference_models.models.rfdetr.projector import MultiScaleProjector
+
+ size_to_width = {
+     "tiny": 192,
+     "small": 384,
+     "base": 768,
+     "large": 1024,
+ }
+
+ size_to_config = {
+     "small": "dinov2_small.json",
+     "base": "dinov2_base.json",
+     "large": "dinov2_large.json",
+ }
+
+ size_to_config_with_registers = {
+     "small": "dinov2_with_registers_small.json",
+     "base": "dinov2_with_registers_base.json",
+     "large": "dinov2_with_registers_large.json",
+ }
+
+
+ def get_config(size, use_registers):
+     config_dict = size_to_config_with_registers if use_registers else size_to_config
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+     configs_dir = os.path.join(current_dir, "dinov2_configs")
+     config_path = os.path.join(configs_dir, config_dict[size])
+     with open(config_path, "r") as f:
+         dino_config = json.load(f)
+     return dino_config
+
+
+ class DinoV2(nn.Module):
+     def __init__(
+         self,
+         shape=(640, 640),
+         out_feature_indexes=[2, 4, 5, 9],
+         size="base",
+         use_registers=True,
+         use_windowed_attn=True,
+         gradient_checkpointing=False,
+         load_dinov2_weights=True,
+         patch_size=14,
+         num_windows=4,
+         positional_encoding_size=37,
+     ):
+         super().__init__()
+
+         name = (
+             f"facebook/dinov2-with-registers-{size}"
+             if use_registers
+             else f"facebook/dinov2-{size}"
+         )
+
+         self.shape = shape
+         self.patch_size = patch_size
+         self.num_windows = num_windows
+
+         # Create the encoder
+
+         if not use_windowed_attn:
+             assert (
+                 not gradient_checkpointing
+             ), "Gradient checkpointing is not supported for non-windowed attention"
+             assert (
+                 load_dinov2_weights
+             ), "Using non-windowed attention requires loading dinov2 weights from hub"
+             self.encoder = AutoBackbone.from_pretrained(
+                 name,
+                 out_features=[f"stage{i}" for i in out_feature_indexes],
+                 return_dict=False,
+             )
+         else:
+             window_block_indexes = set(range(out_feature_indexes[-1] + 1))
+             window_block_indexes.difference_update(out_feature_indexes)
+             window_block_indexes = list(window_block_indexes)
+
+             dino_config = get_config(size, use_registers)
+
+             dino_config["return_dict"] = False
+             dino_config["out_features"] = [f"stage{i}" for i in out_feature_indexes]
+             implied_resolution = positional_encoding_size * patch_size
+
+             if implied_resolution != dino_config["image_size"]:
+                 LOGGER.warning(
+                     f"Using a different number of positional encodings than DINOv2, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model."
+                 )
+                 dino_config["image_size"] = implied_resolution
+                 load_dinov2_weights = False
+
+             if patch_size != 14:
+                 LOGGER.warning(
+                     f"Using patch size {patch_size} instead of 14, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model."
+                 )
+                 dino_config["patch_size"] = patch_size
+                 load_dinov2_weights = False
+
+             if use_registers:
+                 windowed_dino_config = WindowedDinov2WithRegistersConfig(
+                     **dino_config,
+                     num_windows=num_windows,
+                     window_block_indexes=window_block_indexes,
+                     gradient_checkpointing=gradient_checkpointing,
+                 )
+             else:
+                 windowed_dino_config = WindowedDinov2WithRegistersConfig(
+                     **dino_config,
+                     num_windows=num_windows,
+                     window_block_indexes=window_block_indexes,
+                     num_register_tokens=0,
+                     gradient_checkpointing=gradient_checkpointing,
+                 )
+             self.encoder = (
+                 WindowedDinov2WithRegistersBackbone.from_pretrained(
+                     name,
+                     config=windowed_dino_config,
+                 )
+                 if load_dinov2_weights
+                 else WindowedDinov2WithRegistersBackbone(windowed_dino_config)
+             )
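+             # Note: when load_dinov2_weights was flipped to False above, the
+             # encoder is built from the modified config with randomly
+             # initialized weights, since the pretrained checkpoint no longer
+             # matches the positional grid or patch size.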
+
+         self._out_feature_channels = [size_to_width[size]] * len(out_feature_indexes)
+         self._export = False
+
+     def export(self):
+         if self._export:
+             return
+         self._export = True
+         shape = self.shape
+
+         def make_new_interpolated_pos_encoding(
+             position_embeddings, patch_size, height, width
+         ):
+
+             num_positions = position_embeddings.shape[1] - 1
+             dim = position_embeddings.shape[-1]
+             height = height // patch_size
+             width = width // patch_size
+
+             class_pos_embed = position_embeddings[:, 0]
+             patch_pos_embed = position_embeddings[:, 1:]
+
+             # Reshape and permute
+             patch_pos_embed = patch_pos_embed.reshape(
+                 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
+             )
+             patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+             # Bicubic interpolation with antialiasing
+             patch_pos_embed = F.interpolate(
+                 patch_pos_embed,
+                 size=(height, width),
+                 mode="bicubic",
+                 align_corners=False,
+                 antialias=True,
+             )
+
+             # Reshape back
+             patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
+             return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+         # Precompute positional encodings interpolated to the export shape.
+         with torch.no_grad():
+             new_positions = make_new_interpolated_pos_encoding(
+                 self.encoder.embeddings.position_embeddings,
+                 self.encoder.config.patch_size,
+                 shape[0],
+                 shape[1],
+             )
+         old_interpolate_pos_encoding = self.encoder.embeddings.interpolate_pos_encoding
+
+         # Skip runtime interpolation when the precomputed table already matches.
+         def new_interpolate_pos_encoding(self_mod, embeddings, height, width):
+             num_patches = embeddings.shape[1] - 1
+             num_positions = self_mod.position_embeddings.shape[1] - 1
+             if num_patches == num_positions and height == width:
+                 return self_mod.position_embeddings
+             return old_interpolate_pos_encoding(embeddings, height, width)
+
+         # Create a new Parameter with the resized table and patch the method in.
+         self.encoder.embeddings.position_embeddings = nn.Parameter(new_positions)
+         self.encoder.embeddings.interpolate_pos_encoding = types.MethodType(
+             new_interpolate_pos_encoding, self.encoder.embeddings
+         )
+
+     def forward(self, x):
+         block_size = self.patch_size * self.num_windows
+         assert (
+             x.shape[2] % block_size == 0 and x.shape[3] % block_size == 0
+         ), f"Backbone requires input shape to be divisible by {block_size}, but got {x.shape}"
+         x = self.encoder(x)
+         # x[0] holds the selected stage feature maps
+         return list(x[0])
+
+
+ class BackboneBase(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def get_named_param_lr_pairs(self, args, prefix: str):
+         raise NotImplementedError
+
+
+ class Backbone(BackboneBase):
+     """DINOv2 encoder combined with a multi-scale feature projector."""
+
+     def __init__(
+         self,
+         name: str,
+         pretrained_encoder: str = None,
+         window_block_indexes: list = None,
+         drop_path=0.0,
+         out_channels=256,
+         out_feature_indexes: list = None,
+         projector_scale: list = None,
+         use_cls_token: bool = False,
+         freeze_encoder: bool = False,
+         layer_norm: bool = False,
+         target_shape: tuple[int, int] = (640, 640),
+         rms_norm: bool = False,
+         backbone_lora: bool = False,
+         gradient_checkpointing: bool = False,
+         load_dinov2_weights: bool = True,
+         patch_size: int = 14,
+         num_windows: int = 4,
+         positional_encoding_size: int = 37,
+     ):
+         super().__init__()
+         # An example name here would be "dinov2_base" or "dinov2_registers_windowed_base".
+         # If "registers" is in the name, use_registers is set to True; otherwise False.
+         # Similarly, "windowed" in the name sets use_windowed_attn to True.
+         # The name should start with "dinov2", and its last part is the size.
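+         # e.g. name="dinov2_windowed_small" parses to size="small",
+         #      use_registers=False, use_windowed_attn=True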
+         name_parts = name.split("_")
+         assert name_parts[0] == "dinov2"
+         size = name_parts[-1]
+         use_registers = False
+         if "registers" in name_parts:
+             use_registers = True
+             name_parts.remove("registers")
+         use_windowed_attn = False
+         if "windowed" in name_parts:
+             use_windowed_attn = True
+             name_parts.remove("windowed")
+         assert (
+             len(name_parts) == 2
+         ), "name should be dinov2, then either registers, windowed, both, or none, then the size"
+         self.encoder = DinoV2(
+             size=name_parts[-1],
+             out_feature_indexes=out_feature_indexes,
+             shape=target_shape,
+             use_registers=use_registers,
+             use_windowed_attn=use_windowed_attn,
+             gradient_checkpointing=gradient_checkpointing,
+             load_dinov2_weights=load_dinov2_weights,
+             patch_size=patch_size,
+             num_windows=num_windows,
+             positional_encoding_size=positional_encoding_size,
+         )
+         # build encoder + projector as backbone module
+         if freeze_encoder:
+             for param in self.encoder.parameters():
+                 param.requires_grad = False
+
+         self.projector_scale = projector_scale
+         assert len(self.projector_scale) > 0
+         assert (
+             sorted(self.projector_scale) == self.projector_scale
+         ), "only support projector scale P3/P4/P5/P6 in ascending order."
+         level2scalefactor = dict(P3=2.0, P4=1.0, P5=0.5, P6=0.25)
+         scale_factors = [level2scalefactor[lvl] for lvl in self.projector_scale]
+
+         self.projector = MultiScaleProjector(
+             in_channels=self.encoder._out_feature_channels,
+             out_channels=out_channels,
+             scale_factors=scale_factors,
+             layer_norm=layer_norm,
+             rms_norm=rms_norm,
+         )
+
+         self._export = False
+
+     def export(self):
+         self._export = True
+         self._forward_origin = self.forward
+         self.forward = self.forward_export
+
+         if isinstance(self.encoder, PeftModel):
+             LOGGER.info("Merging and unloading LoRA weights")
+             self.encoder.merge_and_unload()
+
+     def forward(self, tensor_list: NestedTensor):
+         """Encode, project, and re-attach masks interpolated to each feature map."""
+         feats = self.encoder(tensor_list.tensors)
+         feats = self.projector(feats)
+         # feats: list of (B, C, H, W) feature maps
+         out = []
+         for feat in feats:
+             m = tensor_list.mask
+             assert m is not None
+             mask = F.interpolate(m[None].float(), size=feat.shape[-2:]).to(torch.bool)[
+                 0
+             ]
+             out.append(NestedTensor(feat, mask))
+         return out
+
+     def forward_export(self, tensors: torch.Tensor):
+         feats = self.encoder(tensors)
+         feats = self.projector(feats)
+         out_feats = []
+         out_masks = []
+         for feat in feats:
+             # feat: (B, C, H, W)
+             b, _, h, w = feat.shape
+             out_masks.append(
+                 torch.zeros((b, h, w), dtype=torch.bool, device=feat.device)
+             )
+             out_feats.append(feat)
+         return out_feats, out_masks
+
+     def get_named_param_lr_pairs(self, args, prefix: str = "backbone.0"):
+         num_layers = args.out_feature_indexes[-1] + 1
+         backbone_key = "backbone.0.encoder"
+         named_param_lr_pairs = {}
+         for n, p in self.named_parameters():
+             n = prefix + "." + n
+             if backbone_key in n and p.requires_grad:
+                 lr = (
+                     args.lr_encoder
+                     * get_dinov2_lr_decay_rate(
+                         n,
+                         lr_decay_rate=args.lr_vit_layer_decay,
+                         num_layers=num_layers,
+                     )
+                     * args.lr_component_decay**2
+                 )
+                 wd = args.weight_decay * get_dinov2_weight_decay_rate(n)
+                 named_param_lr_pairs[n] = {
+                     "params": p,
+                     "lr": lr,
+                     "weight_decay": wd,
+                 }
+         return named_param_lr_pairs
+
+
+ def get_dinov2_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
+     """
+     Calculate lr decay rate for different ViT blocks.
+
+     Args:
+         name (string): parameter name.
+         lr_decay_rate (float): base lr decay rate.
+         num_layers (int): number of ViT blocks.
+
+     Returns:
+         lr decay rate for the given parameter.
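+
+     Example:
+         With num_layers=12 and lr_decay_rate=0.8, a "backbone..." name
+         containing "embeddings" maps to layer_id 0 and gets 0.8 ** 13,
+         while one containing ".layer.11." maps to layer_id 12 and gets
+         0.8 ** 1.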
+     """
+     layer_id = num_layers + 1
+     if name.startswith("backbone"):
+         if "embeddings" in name:
+             layer_id = 0
+         elif ".layer." in name and ".residual." not in name:
+             layer_id = int(name[name.find(".layer.") :].split(".")[2]) + 1
+     return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+
+ def get_dinov2_weight_decay_rate(name, weight_decay_rate=1.0):
+     if (
+         ("gamma" in name)
+         or ("pos_embed" in name)
+         or ("rel_pos" in name)
+         or ("bias" in name)
+         or ("norm" in name)
+         or ("embeddings" in name)
+     ):
+         weight_decay_rate = 0.0
+     return weight_decay_rate
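
For orientation, below is a minimal sketch of how this backbone might be exercised end to end. It is hypothetical usage, not part of the package: it assumes the bundled dinov2_configs JSON files are present, hub access for the pretrained DINOv2 checkpoint, and the DETR-style NestedTensor(tensors, mask) constructor from the misc module. The input resolution must be divisible by patch_size * num_windows = 14 * 4 = 56, hence 560 rather than the (640, 640) default.

import torch

from inference_models.models.rfdetr.misc import NestedTensor
from inference_models.models.rfdetr.rfdetr_backbone_pytorch import Backbone

backbone = Backbone(
    name="dinov2_registers_windowed_base",
    out_feature_indexes=[2, 4, 5, 9],
    projector_scale=["P3", "P5"],      # ascending order is asserted
    target_shape=(560, 560),
    positional_encoding_size=37,       # 37 * 14 = 518, DINOv2's native grid
)
images = torch.randn(1, 3, 560, 560)   # (B, C, H, W), H and W divisible by 56
masks = torch.zeros(1, 560, 560, dtype=torch.bool)
features = backbone(NestedTensor(images, masks))
for level in features:
    print(level.tensors.shape, level.mask.shape)

Each returned level carries a projected (B, 256, H', W') feature map plus a mask resized to match, which is what the downstream RF-DETR transformer consumes.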