inference_models-0.18.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/inference_models/models/rfdetr/rfdetr_backbone_pytorch.py
@@ -0,0 +1,394 @@
+import json
+import math
+import os
+import types
+
+import torch
+import torch.nn.functional as F
+from peft import PeftModel
+from torch import nn
+from transformers import AutoBackbone
+
+from inference_models.logger import LOGGER
+from inference_models.models.rfdetr.dinov2_with_windowed_attn import (
+    WindowedDinov2WithRegistersBackbone,
+    WindowedDinov2WithRegistersConfig,
+)
+from inference_models.models.rfdetr.misc import NestedTensor
+from inference_models.models.rfdetr.projector import MultiScaleProjector
+
+size_to_width = {
+    "tiny": 192,
+    "small": 384,
+    "base": 768,
+    "large": 1024,
+}
+
+size_to_config = {
+    "small": "dinov2_small.json",
+    "base": "dinov2_base.json",
+    "large": "dinov2_large.json",
+}
+
+size_to_config_with_registers = {
+    "small": "dinov2_with_registers_small.json",
+    "base": "dinov2_with_registers_base.json",
+    "large": "dinov2_with_registers_large.json",
+}
+
+
+def get_config(size, use_registers):
+    config_dict = size_to_config_with_registers if use_registers else size_to_config
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    configs_dir = os.path.join(current_dir, "dinov2_configs")
+    config_path = os.path.join(configs_dir, config_dict[size])
+    with open(config_path, "r") as f:
+        dino_config = json.load(f)
+    return dino_config
+
+
+class DinoV2(nn.Module):
+    def __init__(
+        self,
+        shape=(640, 640),
+        out_feature_indexes=[2, 4, 5, 9],
+        size="base",
+        use_registers=True,
+        use_windowed_attn=True,
+        gradient_checkpointing=False,
+        load_dinov2_weights=True,
+        patch_size=14,
+        num_windows=4,
+        positional_encoding_size=37,
+    ):
+        super().__init__()
+
+        name = (
+            f"facebook/dinov2-with-registers-{size}"
+            if use_registers
+            else f"facebook/dinov2-{size}"
+        )
+
+        self.shape = shape
+        self.patch_size = patch_size
+        self.num_windows = num_windows
+
+        # Create the encoder
+
+        if not use_windowed_attn:
+            assert (
+                not gradient_checkpointing
+            ), "Gradient checkpointing is not supported for non-windowed attention"
+            assert (
+                load_dinov2_weights
+            ), "Using non-windowed attention requires loading dinov2 weights from hub"
+            self.encoder = AutoBackbone.from_pretrained(
+                name,
+                out_features=[f"stage{i}" for i in out_feature_indexes],
+                return_dict=False,
+            )
+        else:
+            window_block_indexes = set(range(out_feature_indexes[-1] + 1))
+            window_block_indexes.difference_update(out_feature_indexes)
+            window_block_indexes = list(window_block_indexes)
+
+            dino_config = get_config(size, use_registers)
+
+            dino_config["return_dict"] = False
+            dino_config["out_features"] = [f"stage{i}" for i in out_feature_indexes]
+            implied_resolution = positional_encoding_size * patch_size
+
+            if implied_resolution != dino_config["image_size"]:
+                LOGGER.warning(
+                    "Using a different number of positional encodings than DINOv2, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model."
+                )
+                dino_config["image_size"] = implied_resolution
+                load_dinov2_weights = False
+
+            if patch_size != 14:
+                LOGGER.warning(
+                    f"Using patch size {patch_size} instead of 14, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model."
+                )
+                dino_config["patch_size"] = patch_size
+                load_dinov2_weights = False
+
+            if use_registers:
+                windowed_dino_config = WindowedDinov2WithRegistersConfig(
+                    **dino_config,
+                    num_windows=num_windows,
+                    window_block_indexes=window_block_indexes,
+                    gradient_checkpointing=gradient_checkpointing,
+                )
+            else:
+                windowed_dino_config = WindowedDinov2WithRegistersConfig(
+                    **dino_config,
+                    num_windows=num_windows,
+                    window_block_indexes=window_block_indexes,
+                    num_register_tokens=0,
+                    gradient_checkpointing=gradient_checkpointing,
+                )
+            self.encoder = (
+                WindowedDinov2WithRegistersBackbone.from_pretrained(
+                    name,
+                    config=windowed_dino_config,
+                )
+                if load_dinov2_weights
+                else WindowedDinov2WithRegistersBackbone(windowed_dino_config)
+            )
+
+        self._out_feature_channels = [size_to_width[size]] * len(out_feature_indexes)
+        self._export = False
+
+    def export(self):
+        if self._export:
+            return
+        self._export = True
+        shape = self.shape
+
+        def make_new_interpolated_pos_encoding(
+            position_embeddings, patch_size, height, width
+        ):
+
+            num_positions = position_embeddings.shape[1] - 1
+            dim = position_embeddings.shape[-1]
+            height = height // patch_size
+            width = width // patch_size
+
+            class_pos_embed = position_embeddings[:, 0]
+            patch_pos_embed = position_embeddings[:, 1:]
+
+            # Reshape and permute
+            patch_pos_embed = patch_pos_embed.reshape(
+                1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
+            )
+            patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+            # Use bicubic interpolation with antialiasing
+            patch_pos_embed = F.interpolate(
+                patch_pos_embed,
+                size=(height, width),
+                mode="bicubic",
+                align_corners=False,
+                antialias=True,
+            )
+
+            # Reshape back
+            patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
+            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        # Precompute position embeddings interpolated to the fixed export shape
+        # so the exported graph does not need dynamic interpolation.
+        with torch.no_grad():
+            new_positions = make_new_interpolated_pos_encoding(
+                self.encoder.embeddings.position_embeddings,
+                self.encoder.config.patch_size,
+                shape[0],
+                shape[1],
+            )
+        # Keep the original interpolation as a fallback for mismatched shapes
+        old_interpolate_pos_encoding = self.encoder.embeddings.interpolate_pos_encoding
+
+        def new_interpolate_pos_encoding(self_mod, embeddings, height, width):
+            num_patches = embeddings.shape[1] - 1
+            num_positions = self_mod.position_embeddings.shape[1] - 1
+            if num_patches == num_positions and height == width:
+                return self_mod.position_embeddings
+            return old_interpolate_pos_encoding(embeddings, height, width)
+
+        self.encoder.embeddings.position_embeddings = nn.Parameter(new_positions)
+        self.encoder.embeddings.interpolate_pos_encoding = types.MethodType(
+            new_interpolate_pos_encoding, self.encoder.embeddings
+        )
+
+    def forward(self, x):
+        block_size = self.patch_size * self.num_windows
+        assert (
+            x.shape[2] % block_size == 0 and x.shape[3] % block_size == 0
+        ), f"Backbone requires input shape to be divisible by {block_size}, but got {x.shape}"
+        x = self.encoder(x)
+        return list(x[0])
+
+
+class BackboneBase(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def get_named_param_lr_pairs(self, args, prefix: str):
+        raise NotImplementedError
+
+
+class Backbone(BackboneBase):
+    """DINOv2 encoder with a multi-scale feature projector."""
+
+    def __init__(
+        self,
+        name: str,
+        pretrained_encoder: str = None,
+        window_block_indexes: list = None,
+        drop_path=0.0,
+        out_channels=256,
+        out_feature_indexes: list = None,
+        projector_scale: list = None,
+        use_cls_token: bool = False,
+        freeze_encoder: bool = False,
+        layer_norm: bool = False,
+        target_shape: tuple[int, int] = (640, 640),
+        rms_norm: bool = False,
+        backbone_lora: bool = False,
+        gradient_checkpointing: bool = False,
+        load_dinov2_weights: bool = True,
+        patch_size: int = 14,
+        num_windows: int = 4,
+        positional_encoding_size: int = 37,
+    ):
+        super().__init__()
+        # an example name here would be "dinov2_base" or "dinov2_registers_windowed_base"
+        # if "registers" is in the name, then use_registers is set to True, otherwise it is set to False
+        # similarly, if "windowed" is in the name, then use_windowed_attn is set to True, otherwise it is set to False
+        # the last part of the name should be the size
+        # and the start should be dinov2
+        name_parts = name.split("_")
+        assert name_parts[0] == "dinov2"
+        size = name_parts[-1]
+        use_registers = False
+        if "registers" in name_parts:
+            use_registers = True
+            name_parts.remove("registers")
+        use_windowed_attn = False
+        if "windowed" in name_parts:
+            use_windowed_attn = True
+            name_parts.remove("windowed")
+        assert (
+            len(name_parts) == 2
+        ), "name should be dinov2, then either registers, windowed, both, or none, then the size"
+        self.encoder = DinoV2(
+            size=name_parts[-1],
+            out_feature_indexes=out_feature_indexes,
+            shape=target_shape,
+            use_registers=use_registers,
+            use_windowed_attn=use_windowed_attn,
+            gradient_checkpointing=gradient_checkpointing,
+            load_dinov2_weights=load_dinov2_weights,
+            patch_size=patch_size,
+            num_windows=num_windows,
+            positional_encoding_size=positional_encoding_size,
+        )
+        # build encoder + projector as backbone module
+        if freeze_encoder:
+            for param in self.encoder.parameters():
+                param.requires_grad = False
+
+        self.projector_scale = projector_scale
+        assert len(self.projector_scale) > 0
+        # projector scales must be requested in ascending order (P3 -> P6)
+        assert (
+            sorted(self.projector_scale) == self.projector_scale
+        ), "only support projector scale P3/P4/P5/P6 in ascending order."
+        level2scalefactor = dict(P3=2.0, P4=1.0, P5=0.5, P6=0.25)
+        scale_factors = [level2scalefactor[lvl] for lvl in self.projector_scale]
+
+        self.projector = MultiScaleProjector(
+            in_channels=self.encoder._out_feature_channels,
+            out_channels=out_channels,
+            scale_factors=scale_factors,
+            layer_norm=layer_norm,
+            rms_norm=rms_norm,
+        )
+
+        self._export = False
+
+    def export(self):
+        self._export = True
+        self._forward_origin = self.forward
+        self.forward = self.forward_export
+
+        if isinstance(self.encoder, PeftModel):
+            LOGGER.info("Merging and unloading LoRA weights")
+            self.encoder.merge_and_unload()
+
+    def forward(self, tensor_list: NestedTensor):
+        """Project encoder features and carry the padding mask to each level."""
+        # tensor_list.tensors: (B, C, H, W)
+        feats = self.encoder(tensor_list.tensors)
+        feats = self.projector(feats)
+        # feats: [(B, C, H, W)]
+        out = []
+        for feat in feats:
+            m = tensor_list.mask
+            assert m is not None
+            mask = F.interpolate(m[None].float(), size=feat.shape[-2:]).to(torch.bool)[
+                0
+            ]
+            out.append(NestedTensor(feat, mask))
+        return out
+
+    def forward_export(self, tensors: torch.Tensor):
+        feats = self.encoder(tensors)
+        feats = self.projector(feats)
+        out_feats = []
+        out_masks = []
+        for feat in feats:
+            # feats: [(B, C, H, W)]
+            b, _, h, w = feat.shape
+            out_masks.append(
+                torch.zeros((b, h, w), dtype=torch.bool, device=feat.device)
+            )
+            out_feats.append(feat)
+        return out_feats, out_masks
+
+    def get_named_param_lr_pairs(self, args, prefix: str = "backbone.0"):
+        num_layers = args.out_feature_indexes[-1] + 1
+        backbone_key = "backbone.0.encoder"
+        named_param_lr_pairs = {}
+        for n, p in self.named_parameters():
+            n = prefix + "." + n
+            if backbone_key in n and p.requires_grad:
+                lr = (
+                    args.lr_encoder
+                    * get_dinov2_lr_decay_rate(
+                        n,
+                        lr_decay_rate=args.lr_vit_layer_decay,
+                        num_layers=num_layers,
+                    )
+                    * args.lr_component_decay**2
+                )
+                wd = args.weight_decay * get_dinov2_weight_decay_rate(n)
+                named_param_lr_pairs[n] = {
+                    "params": p,
+                    "lr": lr,
+                    "weight_decay": wd,
+                }
+        return named_param_lr_pairs
+
+
+def get_dinov2_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
+    """
+    Calculate lr decay rate for different ViT blocks.
+
+    Args:
+        name (string): parameter name.
+        lr_decay_rate (float): base lr decay rate.
+        num_layers (int): number of ViT blocks.
+    Returns:
+        lr decay rate for the given parameter.
+    """
+    layer_id = num_layers + 1
+    if name.startswith("backbone"):
+        if "embeddings" in name:
+            layer_id = 0
+        elif ".layer." in name and ".residual." not in name:
+            layer_id = int(name[name.find(".layer.") :].split(".")[2]) + 1
+    return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+
+def get_dinov2_weight_decay_rate(name, weight_decay_rate=1.0):
+    if (
+        ("gamma" in name)
+        or ("pos_embed" in name)
+        or ("rel_pos" in name)
+        or ("bias" in name)
+        or ("norm" in name)
+        or ("embeddings" in name)
+    ):
+        weight_decay_rate = 0.0
+    return weight_decay_rate