inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/rfdetr/dinov2_with_windowed_attn.py
@@ -0,0 +1,1330 @@
1
+ # ------------------------------------------------------------------------
2
+ # RF-DETR
3
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from HuggingFace Dinov2 (https://github.com/huggingface/transformers)
7
+ # Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
8
+ # ------------------------------------------------------------------------
9
+
10
+ import collections.abc
11
+ import math
12
+ from typing import Dict, List, Optional, Set, Tuple, Union
13
+
14
+ import torch
15
+ from torch import nn
16
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
17
+ from transformers.activations import ACT2FN
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.modeling_outputs import (
20
+ BackboneOutput,
21
+ BaseModelOutput,
22
+ BaseModelOutputWithPooling,
23
+ ImageClassifierOutput,
24
+ )
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.pytorch_utils import (
27
+ find_pruneable_heads_and_indices,
28
+ prune_linear_layer,
29
+ )
30
+ from transformers.utils import (
31
+ add_code_sample_docstrings,
32
+ add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ logging,
35
+ replace_return_docstrings,
36
+ torch_int,
37
+ )
38
+ from transformers.utils.backbone_utils import (
39
+ BackboneConfigMixin,
40
+ BackboneMixin,
41
+ get_aligned_output_features_output_indices,
42
+ )
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ # Base docstring
47
+ _CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base"
48
+
49
+ # General docstring
50
+ _CONFIG_FOR_DOC = "WindowedDinov2WithRegistersConfig"
51
+
52
+
53
+ class WindowedDinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
54
+ r"""
55
+ This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an
56
+ Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration
57
+ with the defaults will yield a similar configuration to that of the DINOv2 with Registers
58
+ [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture.
59
+
60
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
61
+ documentation from [`PretrainedConfig`] for more information.
62
+
63
+ Args:
64
+ hidden_size (`int`, *optional*, defaults to 768):
65
+ Dimensionality of the encoder layers and the pooler layer.
66
+ num_hidden_layers (`int`, *optional*, defaults to 12):
67
+ Number of hidden layers in the Transformer encoder.
68
+ num_attention_heads (`int`, *optional*, defaults to 12):
69
+ Number of attention heads for each attention layer in the Transformer encoder.
70
+ mlp_ratio (`int`, *optional*, defaults to 4):
71
+ Ratio of the hidden size of the MLPs relative to the `hidden_size`.
72
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
73
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
74
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
75
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
76
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
77
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
78
+ The dropout ratio for the attention probabilities.
79
+ initializer_range (`float`, *optional*, defaults to 0.02):
80
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
81
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
82
+ The epsilon used by the layer normalization layers.
83
+ image_size (`int`, *optional*, defaults to 224):
84
+ The size (resolution) of each image.
85
+ patch_size (`int`, *optional*, defaults to 16):
86
+ The size (resolution) of each patch.
87
+ num_channels (`int`, *optional*, defaults to 3):
88
+ The number of input channels.
89
+ qkv_bias (`bool`, *optional*, defaults to `True`):
90
+ Whether to add a bias to the queries, keys and values.
91
+ layerscale_value (`float`, *optional*, defaults to 1.0):
92
+ Initial value to use for layer scale.
93
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
94
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
95
+ use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
96
+ Whether to use the SwiGLU feedforward neural network.
97
+ num_register_tokens (`int`, *optional*, defaults to 4):
98
+ Number of register tokens to use.
99
+ out_features (`List[str]`, *optional*):
100
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
101
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
102
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
103
+ same order as defined in the `stage_names` attribute.
104
+ out_indices (`List[int]`, *optional*):
105
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
106
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
107
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
108
+ same order as defined in the `stage_names` attribute.
109
+ apply_layernorm (`bool`, *optional*, defaults to `True`):
110
+ Whether to apply layer normalization to the feature maps in case the model is used as backbone.
111
+ reshape_hidden_states (`bool`, *optional*, defaults to `True`):
112
+ Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
113
+ case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
114
+ seq_len, hidden_size)`.
115
+
116
+ Example:
117
+
118
+ ```python
119
+ >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel
120
+
121
+ >>> # Initializing a Dinov2WithRegisters base style configuration
122
+ >>> configuration = Dinov2WithRegistersConfig()
123
+
124
+ >>> # Initializing a model (with random weights) from the base style configuration
125
+ >>> model = Dinov2WithRegistersModel(configuration)
126
+
127
+ >>> # Accessing the model configuration
128
+ >>> configuration = model.config
129
+ ```"""
130
+
131
+ model_type = "dinov2_with_registers"
132
+
133
+ def __init__(
134
+ self,
135
+ hidden_size=768,
136
+ num_hidden_layers=12,
137
+ num_attention_heads=12,
138
+ mlp_ratio=4,
139
+ hidden_act="gelu",
140
+ hidden_dropout_prob=0.0,
141
+ attention_probs_dropout_prob=0.0,
142
+ initializer_range=0.02,
143
+ layer_norm_eps=1e-6,
144
+ image_size=224,
145
+ patch_size=16,
146
+ num_channels=3,
147
+ qkv_bias=True,
148
+ layerscale_value=1.0,
149
+ drop_path_rate=0.0,
150
+ use_swiglu_ffn=False,
151
+ num_register_tokens=4,
152
+ out_features=None,
153
+ out_indices=None,
154
+ apply_layernorm=True,
155
+ reshape_hidden_states=True,
156
+ num_windows=1,
157
+ window_block_indexes=None,
158
+ gradient_checkpointing=False,
159
+ **kwargs,
160
+ ):
161
+ super().__init__(**kwargs)
162
+
163
+ self.hidden_size = hidden_size
164
+ self.num_hidden_layers = num_hidden_layers
165
+ self.num_attention_heads = num_attention_heads
166
+ self.mlp_ratio = mlp_ratio
167
+ self.hidden_act = hidden_act
168
+ self.hidden_dropout_prob = hidden_dropout_prob
169
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
170
+ self.initializer_range = initializer_range
171
+ self.layer_norm_eps = layer_norm_eps
172
+ self.image_size = image_size
173
+ self.patch_size = patch_size
174
+ self.num_channels = num_channels
175
+ self.qkv_bias = qkv_bias
176
+ self.layerscale_value = layerscale_value
177
+ self.drop_path_rate = drop_path_rate
178
+ self.use_swiglu_ffn = use_swiglu_ffn
179
+ self.num_register_tokens = num_register_tokens
180
+ self.stage_names = ["stem"] + [
181
+ f"stage{idx}" for idx in range(1, num_hidden_layers + 1)
182
+ ]
183
+ self._out_features, self._out_indices = (
184
+ get_aligned_output_features_output_indices(
185
+ out_features=out_features,
186
+ out_indices=out_indices,
187
+ stage_names=self.stage_names,
188
+ )
189
+ )
190
+ self.apply_layernorm = apply_layernorm
191
+ self.reshape_hidden_states = reshape_hidden_states
192
+ self.num_windows = num_windows
193
+ self.window_block_indexes = (
194
+ list(range(num_hidden_layers))
195
+ if window_block_indexes is None
196
+ else window_block_indexes
197
+ )
198
+ self.gradient_checkpointing = gradient_checkpointing
199
+
200
+
201
+ class Dinov2WithRegistersPatchEmbeddings(nn.Module):
202
+ """
203
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
204
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
205
+ Transformer.
206
+ """
207
+
208
+ def __init__(self, config):
209
+ super().__init__()
210
+ image_size, patch_size = config.image_size, config.patch_size
211
+ num_channels, hidden_size = config.num_channels, config.hidden_size
212
+
213
+ image_size = (
214
+ image_size
215
+ if isinstance(image_size, collections.abc.Iterable)
216
+ else (image_size, image_size)
217
+ )
218
+ patch_size = (
219
+ patch_size
220
+ if isinstance(patch_size, collections.abc.Iterable)
221
+ else (patch_size, patch_size)
222
+ )
223
+ num_patches = (image_size[1] // patch_size[1]) * (
224
+ image_size[0] // patch_size[0]
225
+ )
226
+ self.image_size = image_size
227
+ self.patch_size = patch_size
228
+ self.num_channels = num_channels
229
+ self.num_patches = num_patches
230
+
231
+ self.projection = nn.Conv2d(
232
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
233
+ )
234
+
235
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
236
+ num_channels = pixel_values.shape[1]
237
+ if num_channels != self.num_channels:
238
+ raise ValueError(
239
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
240
+ f" Expected {self.num_channels} but got {num_channels}."
241
+ )
242
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
243
+ return embeddings
244
+
245
+
246
+ class WindowedDinov2WithRegistersEmbeddings(nn.Module):
247
+ """
248
+ Construct the CLS token, mask token, register tokens, position and patch embeddings.
249
+ """
250
+
251
+ def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None:
252
+ super().__init__()
253
+
254
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
255
+ self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
256
+ self.register_tokens = (
257
+ nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
258
+ if config.num_register_tokens > 0
259
+ else None
260
+ )
261
+ self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
262
+ num_patches = self.patch_embeddings.num_patches
263
+ self.position_embeddings = nn.Parameter(
264
+ torch.randn(1, num_patches + 1, config.hidden_size)
265
+ )
266
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
267
+ self.patch_size = config.patch_size
268
+ self.config = config
269
+
270
+ def interpolate_pos_encoding(
271
+ self, embeddings: torch.Tensor, height: int, width: int
272
+ ) -> torch.Tensor:
273
+ """
274
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
275
+ resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
276
+ with the original implementation.
277
+
278
+ Adapted from:
279
+ - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
280
+ - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
281
+ """
282
+ num_patches = embeddings.shape[1] - 1
283
+ num_positions = self.position_embeddings.shape[1] - 1
284
+
285
+ # Skip interpolation for matching dimensions (unless tracing)
286
+ if (
287
+ not torch.jit.is_tracing()
288
+ and num_patches == num_positions
289
+ and height == width
290
+ ):
291
+ return self.position_embeddings
292
+
293
+ # Handle class token and patch embeddings separately
294
+ class_pos_embed = self.position_embeddings[:, 0]
295
+ patch_pos_embed = self.position_embeddings[:, 1:]
296
+ dim = embeddings.shape[-1]
297
+
298
+ # Calculate new dimensions
299
+ height = height // self.config.patch_size
300
+ width = width // self.config.patch_size
301
+
302
+ # Reshape for interpolation
303
+ sqrt_num_positions = torch_int(num_positions**0.5)
304
+ patch_pos_embed = patch_pos_embed.reshape(
305
+ 1, sqrt_num_positions, sqrt_num_positions, dim
306
+ )
307
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
308
+
309
+ # Store original dtype for restoration after interpolation
310
+ target_dtype = patch_pos_embed.dtype
311
+
312
+ # Interpolate at float32 precision
313
+ patch_pos_embed = nn.functional.interpolate(
314
+ patch_pos_embed.to(dtype=torch.float32),
315
+ size=(
316
+ torch_int(height),
317
+ torch_int(width),
318
+ ), # Explicit size instead of scale_factor
319
+ mode="bicubic",
320
+ align_corners=False,
321
+ antialias=True,
322
+ ).to(dtype=target_dtype)
323
+
324
+ # Validate output dimensions if not tracing
325
+ if not torch.jit.is_tracing():
326
+ if (
327
+ int(height) != patch_pos_embed.shape[-2]
328
+ or int(width) != patch_pos_embed.shape[-1]
329
+ ):
330
+ raise ValueError(
331
+ "Width or height does not match with the interpolated position embeddings"
332
+ )
333
+
334
+ # Reshape back to original format
335
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
336
+
337
+ # Combine class and patch embeddings
338
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
339
+
340
+ def forward(
341
+ self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None
342
+ ) -> torch.Tensor:
343
+ batch_size, _, height, width = pixel_values.shape
344
+ target_dtype = self.patch_embeddings.projection.weight.dtype
345
+ embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
346
+
347
+ if bool_masked_pos is not None:
348
+ embeddings = torch.where(
349
+ bool_masked_pos.unsqueeze(-1),
350
+ self.mask_token.to(embeddings.dtype).unsqueeze(0),
351
+ embeddings,
352
+ )
353
+
354
+ # add the [CLS] token to the embedded patch tokens
355
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
356
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
357
+
358
+ # add positional encoding to each token
359
+ embeddings = embeddings + self.interpolate_pos_encoding(
360
+ embeddings, height, width
361
+ )
362
+
363
+ if self.config.num_windows > 1:
364
+ # reshape for windows
365
+ num_h_patches = height // self.config.patch_size
366
+ num_w_patches = width // self.config.patch_size
367
+ cls_token_with_pos_embed = embeddings[:, :1]
368
+ pixel_tokens_with_pos_embed = embeddings[:, 1:]
369
+ pixel_tokens_with_pos_embed = pixel_tokens_with_pos_embed.view(
370
+ batch_size, num_h_patches, num_w_patches, -1
371
+ )
372
+ num_w_patches_per_window = num_w_patches // self.config.num_windows
373
+ num_h_patches_per_window = num_h_patches // self.config.num_windows
374
+ num_windows = self.config.num_windows
375
+ windowed_pixel_tokens = pixel_tokens_with_pos_embed.reshape(
376
+ batch_size * num_windows,
377
+ num_h_patches_per_window,
378
+ num_windows,
379
+ num_h_patches_per_window,
380
+ -1,
381
+ )
382
+ windowed_pixel_tokens = windowed_pixel_tokens.permute(0, 2, 1, 3, 4)
383
+ windowed_pixel_tokens = windowed_pixel_tokens.reshape(
384
+ batch_size * num_windows**2,
385
+ num_h_patches_per_window * num_w_patches_per_window,
386
+ -1,
387
+ )
388
+ windowed_cls_token_with_pos_embed = cls_token_with_pos_embed.repeat(
389
+ num_windows**2, 1, 1
390
+ )
391
+ embeddings = torch.cat(
392
+ (windowed_cls_token_with_pos_embed, windowed_pixel_tokens), dim=1
393
+ )
394
+
395
+ # add register tokens
396
+ embeddings = (
397
+ torch.cat(
398
+ (
399
+ embeddings[:, :1],
400
+ self.register_tokens.expand(embeddings.shape[0], -1, -1),
401
+ embeddings[:, 1:],
402
+ ),
403
+ dim=1,
404
+ )
405
+ if self.config.num_register_tokens > 0
406
+ else embeddings
407
+ )
408
+
409
+ embeddings = self.dropout(embeddings)
410
+
411
+ return embeddings
412
+
413
+
414
+ class Dinov2WithRegistersSelfAttention(nn.Module):
415
+ def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None:
416
+ super().__init__()
417
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
418
+ config, "embedding_size"
419
+ ):
420
+ raise ValueError(
421
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
422
+ f"heads {config.num_attention_heads}."
423
+ )
424
+
425
+ self.num_attention_heads = config.num_attention_heads
426
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
427
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
428
+
429
+ self.query = nn.Linear(
430
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
431
+ )
432
+ self.key = nn.Linear(
433
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
434
+ )
435
+ self.value = nn.Linear(
436
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
437
+ )
438
+
439
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
440
+
441
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
442
+ new_x_shape = x.size()[:-1] + (
443
+ self.num_attention_heads,
444
+ self.attention_head_size,
445
+ )
446
+ x = x.view(new_x_shape)
447
+ return x.permute(0, 2, 1, 3)
448
+
449
+ def forward(
450
+ self,
451
+ hidden_states,
452
+ head_mask: Optional[torch.Tensor] = None,
453
+ output_attentions: bool = False,
454
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
455
+ mixed_query_layer = self.query(hidden_states)
456
+
457
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
458
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
459
+ query_layer = self.transpose_for_scores(mixed_query_layer)
460
+
461
+ # Take the dot product between "query" and "key" to get the raw attention scores.
462
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
463
+
464
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
465
+
466
+ # Normalize the attention scores to probabilities.
467
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
468
+
469
+ # This is actually dropping out entire tokens to attend to, which might
470
+ # seem a bit unusual, but is taken from the original Transformer paper.
471
+ attention_probs = self.dropout(attention_probs)
472
+
473
+ # Mask heads if we want to
474
+ if head_mask is not None:
475
+ attention_probs = attention_probs * head_mask
476
+
477
+ context_layer = torch.matmul(attention_probs, value_layer)
478
+
479
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
480
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
481
+ context_layer = context_layer.view(new_context_layer_shape)
482
+
483
+ outputs = (
484
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
485
+ )
486
+
487
+ return outputs
488
+
489
+
490
+ class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention):
491
+ def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None:
492
+ super().__init__(config)
493
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states,
498
+ head_mask: Optional[torch.Tensor] = None,
499
+ output_attentions: bool = False,
500
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
501
+ if output_attentions:
502
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
503
+ logger.warning_once(
504
+ "Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
505
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
506
+ )
507
+ return super().forward(
508
+ hidden_states=hidden_states,
509
+ head_mask=head_mask,
510
+ output_attentions=output_attentions,
511
+ )
512
+
513
+ mixed_query_layer = self.query(hidden_states)
514
+
515
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
516
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
517
+ query_layer = self.transpose_for_scores(mixed_query_layer)
518
+
519
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
520
+ query_layer,
521
+ key_layer,
522
+ value_layer,
523
+ head_mask,
524
+ self.attention_probs_dropout_prob if self.training else 0.0,
525
+ is_causal=False,
526
+ scale=None,
527
+ )
528
+
529
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
530
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
531
+ context_layer = context_layer.view(new_context_layer_shape)
532
+
533
+ return context_layer, None
534
+
535
+
536
+ class Dinov2WithRegistersSelfOutput(nn.Module):
537
+ """
538
+ The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the
539
+ layernorm applied before each block.
540
+ """
541
+
542
+ def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None:
543
+ super().__init__()
544
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
545
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
546
+
547
+ def forward(
548
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
549
+ ) -> torch.Tensor:
550
+ hidden_states = self.dense(hidden_states)
551
+ hidden_states = self.dropout(hidden_states)
552
+
553
+ return hidden_states
554
+
555
+
556
+ class Dinov2WithRegistersAttention(nn.Module):
557
+ def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None:
558
+ super().__init__()
559
+ self.attention = Dinov2WithRegistersSelfAttention(config)
560
+ self.output = Dinov2WithRegistersSelfOutput(config)
561
+ self.pruned_heads = set()
562
+
563
+ def prune_heads(self, heads: Set[int]) -> None:
564
+ if len(heads) == 0:
565
+ return
566
+ heads, index = find_pruneable_heads_and_indices(
567
+ heads,
568
+ self.attention.num_attention_heads,
569
+ self.attention.attention_head_size,
570
+ self.pruned_heads,
571
+ )
572
+
573
+ # Prune linear layers
574
+ self.attention.query = prune_linear_layer(self.attention.query, index)
575
+ self.attention.key = prune_linear_layer(self.attention.key, index)
576
+ self.attention.value = prune_linear_layer(self.attention.value, index)
577
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
578
+
579
+ # Update hyper params and store pruned heads
580
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(
581
+ heads
582
+ )
583
+ self.attention.all_head_size = (
584
+ self.attention.attention_head_size * self.attention.num_attention_heads
585
+ )
586
+ self.pruned_heads = self.pruned_heads.union(heads)
587
+
588
+ def forward(
589
+ self,
590
+ hidden_states: torch.Tensor,
591
+ head_mask: Optional[torch.Tensor] = None,
592
+ output_attentions: bool = False,
593
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
594
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
595
+
596
+ attention_output = self.output(self_outputs[0], hidden_states)
597
+
598
+ outputs = (attention_output,) + self_outputs[
599
+ 1:
600
+ ] # add attentions if we output them
601
+ return outputs
602
+
603
+
604
+ class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention):
605
+ def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None:
606
+ super().__init__(config)
607
+ self.attention = Dinov2WithRegistersSdpaSelfAttention(config)
608
+
609
+
610
+ class Dinov2WithRegistersLayerScale(nn.Module):
611
+ def __init__(self, config) -> None:
612
+ super().__init__()
613
+ self.lambda1 = nn.Parameter(
614
+ config.layerscale_value * torch.ones(config.hidden_size)
615
+ )
616
+
617
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
618
+ return hidden_state * self.lambda1
619
+
620
+
621
+ def drop_path(
622
+ input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
623
+ ) -> torch.Tensor:
624
+ """
625
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
626
+
627
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
628
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
629
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
630
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
631
+ argument.
632
+ """
633
+ if drop_prob == 0.0 or not training:
634
+ return input
635
+ keep_prob = 1 - drop_prob
636
+ shape = (input.shape[0],) + (1,) * (
637
+ input.ndim - 1
638
+ ) # work with diff dim tensors, not just 2D ConvNets
639
+ random_tensor = keep_prob + torch.rand(
640
+ shape, dtype=input.dtype, device=input.device
641
+ )
642
+ random_tensor.floor_() # binarize
643
+ output = input.div(keep_prob) * random_tensor
644
+ return output
645
+
646
+
647
+ class Dinov2WithRegistersDropPath(nn.Module):
648
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
649
+
650
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
651
+ super().__init__()
652
+ self.drop_prob = drop_prob
653
+
654
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
655
+ return drop_path(hidden_states, self.drop_prob, self.training)
656
+
657
+ def extra_repr(self) -> str:
658
+ return "p={}".format(self.drop_prob)
659
+
660
+
661
+ class Dinov2WithRegistersMLP(nn.Module):
662
+ def __init__(self, config) -> None:
663
+ super().__init__()
664
+ in_features = out_features = config.hidden_size
665
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
666
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
667
+ if isinstance(config.hidden_act, str):
668
+ self.activation = ACT2FN[config.hidden_act]
669
+ else:
670
+ self.activation = config.hidden_act
671
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
672
+
673
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
674
+ hidden_state = self.fc1(hidden_state)
675
+ hidden_state = self.activation(hidden_state)
676
+ hidden_state = self.fc2(hidden_state)
677
+ return hidden_state
678
+
679
+
680
+ class Dinov2WithRegistersSwiGLUFFN(nn.Module):
681
+ def __init__(self, config) -> None:
682
+ super().__init__()
683
+ in_features = out_features = config.hidden_size
684
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
685
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
686
+
687
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
688
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
689
+
690
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
691
+ hidden_state = self.weights_in(hidden_state)
692
+ x1, x2 = hidden_state.chunk(2, dim=-1)
693
+ hidden = nn.functional.silu(x1) * x2
694
+ return self.weights_out(hidden)
695
+
696
+
697
+ DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = {
698
+ "eager": Dinov2WithRegistersAttention,
699
+ "sdpa": Dinov2WithRegistersSdpaAttention,
700
+ }
701
+
702
+
703
+ class WindowedDinov2WithRegistersLayer(nn.Module):
704
+ """This corresponds to the Block class in the original implementation."""
705
+
706
+ def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None:
707
+ super().__init__()
708
+
709
+ self.num_windows = config.num_windows
710
+
711
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
712
+ self.attention = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[
713
+ config._attn_implementation
714
+ ](config)
715
+ self.layer_scale1 = Dinov2WithRegistersLayerScale(config)
716
+ self.drop_path = (
717
+ Dinov2WithRegistersDropPath(config.drop_path_rate)
718
+ if config.drop_path_rate > 0.0
719
+ else nn.Identity()
720
+ )
721
+
722
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
723
+
724
+ if config.use_swiglu_ffn:
725
+ self.mlp = Dinov2WithRegistersSwiGLUFFN(config)
726
+ else:
727
+ self.mlp = Dinov2WithRegistersMLP(config)
728
+ self.layer_scale2 = Dinov2WithRegistersLayerScale(config)
729
+
730
+ def forward(
731
+ self,
732
+ hidden_states: torch.Tensor,
733
+ head_mask: Optional[torch.Tensor] = None,
734
+ output_attentions: bool = False,
735
+ run_full_attention: bool = False,
736
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
737
+ assert head_mask is None, "head_mask is not supported for windowed attention"
738
+ assert (
739
+ not output_attentions
740
+ ), "output_attentions is not supported for windowed attention"
741
+ shortcut = hidden_states
742
+ if run_full_attention:
743
+ # reshape x to remove windows
744
+ B, HW, C = hidden_states.shape
745
+ num_windows_squared = self.num_windows**2
746
+ hidden_states = hidden_states.view(
747
+ B // num_windows_squared, num_windows_squared * HW, C
748
+ )
749
+
750
+ self_attention_outputs = self.attention(
751
+ self.norm1(
752
+ hidden_states
753
+ ), # in Dinov2WithRegisters, layernorm is applied before self-attention
754
+ head_mask,
755
+ output_attentions=output_attentions,
756
+ )
757
+ attention_output = self_attention_outputs[0]
758
+
759
+ if run_full_attention:
760
+ # reshape x to add windows back
761
+ B, HW, C = hidden_states.shape
762
+ num_windows_squared = self.num_windows**2
763
+ # hidden_states = hidden_states.view(B * num_windows_squared, HW // num_windows_squared, C)
764
+ attention_output = attention_output.view(
765
+ B * num_windows_squared, HW // num_windows_squared, C
766
+ )
767
+
768
+ attention_output = self.layer_scale1(attention_output)
769
+ outputs = self_attention_outputs[
770
+ 1:
771
+ ] # add self attentions if we output attention weights
772
+
773
+ # first residual connection
774
+ hidden_states = self.drop_path(attention_output) + shortcut
775
+
776
+ # in Dinov2WithRegisters, layernorm is also applied after self-attention
777
+ layer_output = self.norm2(hidden_states)
778
+ layer_output = self.mlp(layer_output)
779
+ layer_output = self.layer_scale2(layer_output)
780
+
781
+ # second residual connection
782
+ layer_output = self.drop_path(layer_output) + hidden_states
783
+
784
+ outputs = (layer_output,) + outputs
785
+
786
+ return outputs
787
+
788
+
789
+ class WindowedDinov2WithRegistersEncoder(nn.Module):
790
+ def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None:
791
+ super().__init__()
792
+ self.config = config
793
+ self.layer = nn.ModuleList(
794
+ [
795
+ WindowedDinov2WithRegistersLayer(config)
796
+ for _ in range(config.num_hidden_layers)
797
+ ]
798
+ )
799
+ self.gradient_checkpointing = config.gradient_checkpointing
800
+
801
+ def forward(
802
+ self,
803
+ hidden_states: torch.Tensor,
804
+ head_mask: Optional[torch.Tensor] = None,
805
+ output_attentions: bool = False,
806
+ output_hidden_states: bool = False,
807
+ return_dict: bool = True,
808
+ ) -> Union[tuple, BaseModelOutput]:
809
+ all_hidden_states = () if output_hidden_states else None
810
+ all_self_attentions = () if output_attentions else None
811
+
812
+ for i, layer_module in enumerate(self.layer):
813
+ if output_hidden_states:
814
+ all_hidden_states = all_hidden_states + (hidden_states,)
815
+
816
+ if i > int(self.config.out_features[-1][5:]):
817
+ # early stop if we have reached the last output feature
818
+ break
819
+
820
+ run_full_attention = i not in self.config.window_block_indexes
821
+
822
+ layer_head_mask = head_mask[i] if head_mask is not None else None
823
+
824
+ if self.gradient_checkpointing and self.training:
825
+ layer_outputs = self._gradient_checkpointing_func(
826
+ layer_module.__call__,
827
+ hidden_states,
828
+ layer_head_mask,
829
+ output_attentions,
830
+ run_full_attention,
831
+ )
832
+ else:
833
+ layer_outputs = layer_module(
834
+ hidden_states,
835
+ layer_head_mask,
836
+ output_attentions,
837
+ run_full_attention,
838
+ )
839
+
840
+ hidden_states = layer_outputs[0]
841
+
842
+ if output_attentions:
843
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
844
+
845
+ if output_hidden_states:
846
+ all_hidden_states = all_hidden_states + (hidden_states,)
847
+
848
+ if not return_dict:
849
+ return tuple(
850
+ v
851
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
852
+ if v is not None
853
+ )
854
+ return BaseModelOutput(
855
+ last_hidden_state=hidden_states,
856
+ hidden_states=all_hidden_states,
857
+ attentions=all_self_attentions,
858
+ )
859
+
860
+
861
+ class WindowedDinov2WithRegistersPreTrainedModel(PreTrainedModel):
862
+ """
863
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
864
+ models.
865
+ """
866
+
867
+ config_class = WindowedDinov2WithRegistersConfig
868
+ base_model_prefix = "dinov2_with_registers"
869
+ main_input_name = "pixel_values"
870
+ supports_gradient_checkpointing = True
871
+ _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"]
872
+ _supports_sdpa = True
873
+
874
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
875
+ """Initialize the weights"""
876
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
877
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
878
+ # `trunc_normal_cpu` not implemented in `half` issues
879
+ module.weight.data = nn.init.trunc_normal_(
880
+ module.weight.data.to(torch.float32),
881
+ mean=0.0,
882
+ std=self.config.initializer_range,
883
+ ).to(module.weight.dtype)
884
+ if module.bias is not None:
885
+ module.bias.data.zero_()
886
+ elif isinstance(module, nn.LayerNorm):
887
+ module.bias.data.zero_()
888
+ module.weight.data.fill_(1.0)
889
+ elif isinstance(module, WindowedDinov2WithRegistersEmbeddings):
890
+ module.position_embeddings.data = nn.init.trunc_normal_(
891
+ module.position_embeddings.data.to(torch.float32),
892
+ mean=0.0,
893
+ std=self.config.initializer_range,
894
+ ).to(module.position_embeddings.dtype)
895
+
896
+ module.cls_token.data = nn.init.trunc_normal_(
897
+ module.cls_token.data.to(torch.float32),
898
+ mean=0.0,
899
+ std=self.config.initializer_range,
900
+ ).to(module.cls_token.dtype)
901
+
902
+
903
+ _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
904
+
905
+
906
+ DINOV2_WITH_REGISTERS_START_DOCSTRING = r"""
907
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
908
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
909
+ behavior.
910
+
911
+ Parameters:
912
+ config ([`Dinov2WithRegistersConfig`]): Model configuration class with all the parameters of the model.
913
+ Initializing with a config file does not load the weights associated with the model, only the
914
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
915
+ """
916
+
917
+ DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r"""
918
+ Args:
919
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
920
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
921
+ [`BitImageProcessor.preprocess`] for details.
922
+
923
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
924
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
925
+ pre-training.
926
+
927
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
928
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
929
+
930
+ - 1 indicates the head is **not masked**,
931
+ - 0 indicates the head is **masked**.
932
+
933
+ output_attentions (`bool`, *optional*):
934
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
935
+ tensors for more detail.
936
+ output_hidden_states (`bool`, *optional*):
937
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
938
+ more detail.
939
+ return_dict (`bool`, *optional*):
940
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
941
+ """
942
+
943
+
944
+ @add_start_docstrings(
945
+ "The bare Dinov2WithRegisters Model transformer outputting raw hidden-states without any specific head on top.",
946
+ DINOV2_WITH_REGISTERS_START_DOCSTRING,
947
+ )
948
+ class WindowedDinov2WithRegistersModel(WindowedDinov2WithRegistersPreTrainedModel):
949
+ def __init__(self, config: WindowedDinov2WithRegistersConfig):
950
+ super().__init__(config)
951
+ self.config = config
952
+
953
+ self.embeddings = WindowedDinov2WithRegistersEmbeddings(config)
954
+ self.encoder = WindowedDinov2WithRegistersEncoder(config)
955
+
956
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
957
+
958
+ # Initialize weights and apply final processing
959
+ self.post_init()
960
+
961
+ def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
962
+ return self.embeddings.patch_embeddings
963
+
964
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
965
+ """
966
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
967
+ class PreTrainedModel
968
+ """
969
+ for layer, heads in heads_to_prune.items():
970
+ self.encoder.layer[layer].attention.prune_heads(heads)
971
+
972
+ @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING)
973
+ @add_code_sample_docstrings(
974
+ checkpoint=_CHECKPOINT_FOR_DOC,
975
+ output_type=BaseModelOutputWithPooling,
976
+ config_class=_CONFIG_FOR_DOC,
977
+ modality="vision",
978
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
979
+ )
980
+ def forward(
981
+ self,
982
+ pixel_values: Optional[torch.Tensor] = None,
983
+ bool_masked_pos: Optional[torch.Tensor] = None,
984
+ head_mask: Optional[torch.Tensor] = None,
985
+ output_attentions: Optional[bool] = None,
986
+ output_hidden_states: Optional[bool] = None,
987
+ return_dict: Optional[bool] = None,
988
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
989
+ output_attentions = (
990
+ output_attentions
991
+ if output_attentions is not None
992
+ else self.config.output_attentions
993
+ )
994
+ output_hidden_states = (
995
+ output_hidden_states
996
+ if output_hidden_states is not None
997
+ else self.config.output_hidden_states
998
+ )
999
+ return_dict = (
1000
+ return_dict if return_dict is not None else self.config.use_return_dict
1001
+ )
1002
+
1003
+ if pixel_values is None:
1004
+ raise ValueError("You have to specify pixel_values")
1005
+
1006
+ # Prepare head mask if needed
1007
+ # 1.0 in head_mask indicate we keep the head
1008
+ # attention_probs has shape bsz x n_heads x N x N
1009
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1010
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1011
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1012
+
1013
+ embedding_output = self.embeddings(
1014
+ pixel_values, bool_masked_pos=bool_masked_pos
1015
+ )
1016
+
1017
+ encoder_outputs = self.encoder(
1018
+ embedding_output,
1019
+ head_mask=head_mask,
1020
+ output_attentions=output_attentions,
1021
+ output_hidden_states=output_hidden_states,
1022
+ return_dict=return_dict,
1023
+ )
1024
+ sequence_output = encoder_outputs[0]
1025
+ sequence_output = self.layernorm(sequence_output)
1026
+ pooled_output = sequence_output[:, 0, :]
1027
+
1028
+ if not return_dict:
1029
+ head_outputs = (sequence_output, pooled_output)
1030
+ return head_outputs + encoder_outputs[1:]
1031
+
1032
+ return BaseModelOutputWithPooling(
1033
+ last_hidden_state=sequence_output,
1034
+ pooler_output=pooled_output,
1035
+ hidden_states=encoder_outputs.hidden_states,
1036
+ attentions=encoder_outputs.attentions,
1037
+ )
1038
+
1039
+
1040
+ # Image classification docstring
1041
+ _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer"
1042
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
1043
+
1044
+ DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r"""
1045
+ Args:
1046
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1047
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
1048
+ [`BitImageProcessor.preprocess`] for details.
1049
+
1050
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
1051
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
1052
+
1053
+ - 1 indicates the head is **not masked**,
1054
+ - 0 indicates the head is **masked**.
1055
+
1056
+ output_attentions (`bool`, *optional*):
1057
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1058
+ tensors for more detail.
1059
+ output_hidden_states (`bool`, *optional*):
1060
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1061
+ more detail.
1062
+ return_dict (`bool`, *optional*):
1063
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1064
+ """
1065
+
1066
+
1067
+ @add_start_docstrings(
+     """
+     Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state
+     of the [CLS] token) e.g. for ImageNet.
+     """,
+     DINOV2_WITH_REGISTERS_START_DOCSTRING,
+ )
+ class WindowedDinov2WithRegistersForImageClassification(
+     WindowedDinov2WithRegistersPreTrainedModel
+ ):
+     def __init__(self, config: WindowedDinov2WithRegistersConfig) -> None:
+         super().__init__(config)
+ 
+         self.num_labels = config.num_labels
+         self.dinov2_with_registers = WindowedDinov2WithRegistersModel(config)
+ 
+         # Classifier head
+         self.classifier = (
+             nn.Linear(config.hidden_size * 2, config.num_labels)
+             if config.num_labels > 0
+             else nn.Identity()
+         )
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_IMAGE_CLASS_CHECKPOINT,
+         output_type=ImageClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+         expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+     )
+     def forward(
+         self,
+         pixel_values: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[tuple, ImageClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+ 
+         outputs = self.dinov2_with_registers(
+             pixel_values,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+ 
+         sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size
+ 
+         cls_token = sequence_output[:, 0]
+         patch_tokens = sequence_output[:, 1:]
+ 
+         linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+ 
+         logits = self.classifier(linear_input)
+ 
+         loss = None
+         if labels is not None:
+             # move labels to correct device to enable model parallelism
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (
+                     labels.dtype == torch.long or labels.dtype == torch.int
+                 ):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+ 
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+ 
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+ 
+         return ImageClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+ 
+ 
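A minimal end-to-end sketch of the classification head defined above: the classifier takes the concatenation of the `[CLS]` token and the mean of the remaining tokens (hence `in_features = hidden_size * 2`), and integer labels route the loss through the single-label cross-entropy branch. The config keyword arguments are illustrative assumptions and the weights are random.

```python
import torch

config = WindowedDinov2WithRegistersConfig(
    image_size=224, patch_size=14, num_windows=1, num_labels=10
)
model = WindowedDinov2WithRegistersForImageClassification(config).eval()

pixel_values = torch.randn(2, 3, 224, 224)
labels = torch.tensor([1, 7])  # dtype torch.long -> "single_label_classification"

with torch.no_grad():
    out = model(pixel_values, labels=labels)

print(out.logits.shape)  # torch.Size([2, 10])
print(out.loss)          # cross-entropy over the 10 classes
```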
+ @add_start_docstrings(
+     """
+     Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer.
+     """,
+     DINOV2_WITH_REGISTERS_START_DOCSTRING,
+ )
+ class WindowedDinov2WithRegistersBackbone(
+     WindowedDinov2WithRegistersPreTrainedModel, BackboneMixin
+ ):
+     def __init__(self, config: WindowedDinov2WithRegistersConfig):
+         super().__init__(config)
+         super()._init_backbone(config)
+         self.num_features = [
+             config.hidden_size for _ in range(config.num_hidden_layers + 1)
+         ]
+         self.embeddings = WindowedDinov2WithRegistersEmbeddings(config)
+         self.encoder = WindowedDinov2WithRegistersEncoder(config)
+ 
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ 
+         self.num_register_tokens = config.num_register_tokens
+ 
+         # Initialize weights and apply final processing
+         self.post_init()
+ 
+     def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
+         return self.embeddings.patch_embeddings
+ 
+     @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> BackboneOutput:
+ """
1213
+ Returns:
1214
+
1215
+ Examples:
1216
+ 
+         ```python
+         >>> from transformers import AutoImageProcessor, AutoBackbone
+         >>> import torch
+         >>> from PIL import Image
+         >>> import requests
+ 
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+ 
+         >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
+         >>> model = AutoBackbone.from_pretrained(
+         ...     "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+         ... )
+ 
+         >>> inputs = processor(image, return_tensors="pt")
+ 
+         >>> outputs = model(**inputs)
+         >>> feature_maps = outputs.feature_maps
+         >>> list(feature_maps[-1].shape)
+         [1, 768, 16, 16]
+         ```"""
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+ 
+         embedding_output = self.embeddings(pixel_values)
+ 
+         outputs = self.encoder(
+             embedding_output,
+             output_hidden_states=True,
+             output_attentions=output_attentions,
+             return_dict=return_dict,
+         )
+ 
+         hidden_states = outputs.hidden_states if return_dict else outputs[1]
+ 
+         feature_maps = ()
+         for stage, hidden_state in zip(self.stage_names, hidden_states):
+             if stage in self.out_features:
+                 if self.config.apply_layernorm:
+                     hidden_state = self.layernorm(hidden_state)
+                 if self.config.reshape_hidden_states:
+                     hidden_state = hidden_state[:, self.num_register_tokens + 1 :]
+                     # this was actually a bug in the original implementation that we copied here,
+                     # because normally the order is height, width
+                     batch_size, _, height, width = pixel_values.shape
+                     patch_size = self.config.patch_size
+ 
+                     num_h_patches = height // patch_size
+                     num_w_patches = width // patch_size
+ 
+                     if self.config.num_windows > 1:
+                         # undo windowing
+                         num_windows_squared = self.config.num_windows**2
+                         B, HW, C = hidden_state.shape
+                         num_h_patches_per_window = (
+                             num_h_patches // self.config.num_windows
+                         )
+                         num_w_patches_per_window = (
+                             num_w_patches // self.config.num_windows
+                         )
+                         hidden_state = hidden_state.reshape(
+                             B // num_windows_squared, num_windows_squared * HW, C
+                         )
+                         hidden_state = hidden_state.reshape(
+                             (B // num_windows_squared) * self.config.num_windows,
+                             self.config.num_windows,
+                             num_h_patches_per_window,
+                             num_w_patches_per_window,
+                             C,
+                         )
+                         hidden_state = hidden_state.permute(0, 2, 1, 3, 4)
+ 
+                     hidden_state = hidden_state.reshape(
+                         batch_size, num_h_patches, num_w_patches, -1
+                     )
+                     hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+ 
+                 feature_maps += (hidden_state,)
+ 
+         if not return_dict:
+             if output_hidden_states:
+                 output = (feature_maps,) + outputs[1:]
+             else:
+                 output = (feature_maps,) + outputs[2:]
+             return output
+ 
+         return BackboneOutput(
+             feature_maps=feature_maps,
+             hidden_states=outputs.hidden_states if output_hidden_states else None,
+             attentions=outputs.attentions if output_attentions else None,
+         )
+ 
+ 
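When `num_windows > 1`, the backbone forward above first stitches the per-window token rows back into a single `(batch, H_patches, W_patches, C)` grid before emitting the channels-first feature map. The sketch below replays that reshape/permute sequence on a dummy tensor; the shapes are chosen purely for illustration, and register/`[CLS]` tokens are assumed to be stripped already, as they are in the forward.

```python
import torch

# Example sizes only: 2 images, a 2x2 window grid, 4x4 patches per window, C=16.
batch_size, num_windows, C = 2, 2, 16
num_h_patches = num_w_patches = 8
h_per_win, w_per_win = num_h_patches // num_windows, num_w_patches // num_windows
num_windows_squared = num_windows**2

# Windowed patch tokens as the encoder emits them: one batch row per window.
hidden_state = torch.randn(batch_size * num_windows_squared, h_per_win * w_per_win, C)

B, HW, C = hidden_state.shape
hidden_state = hidden_state.reshape(B // num_windows_squared, num_windows_squared * HW, C)
hidden_state = hidden_state.reshape(
    (B // num_windows_squared) * num_windows, num_windows, h_per_win, w_per_win, C
)
hidden_state = hidden_state.permute(0, 2, 1, 3, 4)  # same order as the forward above
hidden_state = hidden_state.reshape(batch_size, num_h_patches, num_w_patches, -1)
feature_map = hidden_state.permute(0, 3, 1, 2).contiguous()  # (batch, C, H, W)
print(feature_map.shape)  # torch.Size([2, 16, 8, 8])
```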
+ __all__ = [
+     "WindowedDinov2WithRegistersPreTrainedModel",
+     "WindowedDinov2WithRegistersModel",
+     "WindowedDinov2WithRegistersForImageClassification",
+     "WindowedDinov2WithRegistersBackbone",
+ ]