inference-models 0.18.3 (inference_models-0.18.3-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/perception_encoder/vision_encoder/pe.py
@@ -0,0 +1,742 @@
+ import copy
+ import math
+ import random
+ from collections import OrderedDict
+ from dataclasses import asdict
+ from functools import partial
+ from logging import getLogger
+ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from timm.layers import DropPath
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
+ from torch.nn.parameter import Parameter
+ from torch.utils.checkpoint import checkpoint
+
+ from inference_models.logger import LOGGER
+ from inference_models.models.perception_encoder.vision_encoder.config import (
+     PE_TEXT_CONFIG,
+     PE_VISION_CONFIG,
+     PEConfig,
+     PETextConfig,
+     fetch_pe_checkpoint,
+ )
+ from inference_models.models.perception_encoder.vision_encoder.rope import Rope2D
+
+ logger = getLogger()
+
+
+ class LayerScale(nn.Module):
+     def __init__(self, dim, init_values=1e-5, inplace=False):
+         super().__init__()
+         self.inplace = inplace
+         self.dim = dim
+         self.init_values = init_values
+
+     def forward(self, x):
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+     def init_tensors(self):
+         self.gamma = nn.Parameter(self.init_values * torch.ones(self.dim))
+
+
+ class AttentionPooling(nn.Module):
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         num_probe: int = 1,
+         mlp_ratio: int = 4,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+     ):
+         super().__init__()
+
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+
+         assert (
+             self.embed_dim % num_heads == 0
+         ), "embed_dim must be divisible by num_heads"
+
+         self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
+         self.attn = nn.MultiheadAttention(
+             self.embed_dim, self.num_heads, batch_first=True
+         )
+
+         self.layernorm = norm_layer(embed_dim)
+         self.mlp_width = int(embed_dim * mlp_ratio)
+         self.mlp = nn.Sequential(
+             OrderedDict(
+                 [
+                     ("c_fc", nn.Linear(self.embed_dim, self.mlp_width)),
+                     ("gelu", act_layer()),
+                     ("c_proj", nn.Linear(self.mlp_width, self.embed_dim)),
+                 ]
+             )
+         )
+
+     def forward(self, x: torch.Tensor):
+         batch, _, _ = x.shape
+
+         q = self.probe.repeat((batch, 1, 1)).to(x.dtype)
+         x = self.attn(q, x, x, need_weights=False)[0]
+         x = x + self.mlp(self.layernorm(x))
+
+         return x
+
+
+ class SelfAttention(nn.Module):
+     r"""
+     Implements sequence packed attention and RoPe
+     """
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         rope: Optional[nn.Module] = None,
+     ):
+         super(SelfAttention, self).__init__()
+         self.embed_dim = embed_dim
+
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         assert (
+             self.head_dim * num_heads == self.embed_dim
+         ), "embed_dim must be divisible by num_heads"
+
+         # To make this compatibile with nn.MultiHeadAttention
+         self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
+         self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+         self.rope = rope
+         self.scale = self.head_dim ** (-0.5)
+
+     def init_tensors(self):
+         xavier_uniform_(self.in_proj_weight)
+         constant_(self.in_proj_bias, 0.0)
+         constant_(self.out_proj.bias, 0.0)
+
+     def forward(self, x, attn_mask=None):
+         batch, seq, embed_dim = x.shape
+         proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
+
+         # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
+         proj = (
+             proj.unflatten(-1, (3, embed_dim))
+             .unsqueeze(0)
+             .transpose(0, -2)
+             .squeeze(-2)
+             .contiguous()
+         )
+         q, k, v = proj[0], proj[1], proj[2]
+
+         # Use "q_" so that we don't accidentally quit in pdb :)
+         q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
+         k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
+         v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
+
+         if self.rope:
+             q, k = self.rope(q, k)
+
+         attn = F.scaled_dot_product_attention(
+             q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
+         )
+         attn = rearrange(attn, "b h s d -> b s (h d)")
+
+         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         n_head: int,
+         mlp_ratio: float = 4.0,
+         ls_init_value: float = None,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+         drop_path: float = 0.0,
+         rope: Optional[nn.Module] = None,
+     ):
+         super().__init__()
+
+         if rope:
+             self.attn = SelfAttention(d_model, n_head, rope=rope)
+         else:
+             self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
+
+         self.ls_1 = (
+             LayerScale(d_model, ls_init_value)
+             if ls_init_value is not None
+             else nn.Identity()
+         )
+         self.ls_2 = (
+             LayerScale(d_model, ls_init_value)
+             if ls_init_value is not None
+             else nn.Identity()
+         )
+
+         self.ln_1 = norm_layer(d_model)
+         self.ln_2 = norm_layer(d_model)
+
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         mlp_width = int(d_model * mlp_ratio)
+         self.mlp = nn.Sequential(
+             OrderedDict(
+                 [
+                     ("c_fc", nn.Linear(d_model, mlp_width)),
+                     ("gelu", act_layer()),
+                     ("c_proj", nn.Linear(mlp_width, d_model)),
+                 ]
+             )
+         )
+
+     def _call_attn(
+         self,
+         q_x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+     ):
+
+         if attn_mask is not None:
+             # Leave boolean masks as is
+             if not attn_mask.dtype == torch.bool:
+                 attn_mask = attn_mask.to(q_x.dtype)
+
+         if isinstance(self.attn, SelfAttention):
+             return self.attn(q_x, attn_mask=attn_mask)
+         else:
+             return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+     ):
+         x = x + self.drop_path1(
+             self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))
+         )
+         x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         width: int,
+         layers: int,
+         heads: int,
+         mlp_ratio: float = 4.0,
+         ls_init_value: float = None,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+         drop_path: float = 0.0,
+         rope: Optional[nn.Module] = None,
+     ):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.grad_checkpointing = False
+
+         self.resblocks = nn.ModuleList(
+             [
+                 ResidualAttentionBlock(
+                     width,
+                     heads,
+                     mlp_ratio,
+                     ls_init_value=ls_init_value,
+                     act_layer=act_layer,
+                     norm_layer=norm_layer,
+                     drop_path=drop_path,
+                     rope=rope,
+                 )
+                 for _ in range(layers)
+             ]
+         )
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         self.grad_checkpointing = enable
+
+     @torch.jit.ignore
+     def truncate(self, layer_idx: int):
+         """Delete layers so the last layer is the given layer index."""
+         self.layers = ((self.layers + layer_idx) % self.layers) + 1
+         self.resblocks = nn.ModuleList(self.resblocks[: self.layers])
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+         layer_idx: int = -1,
+     ):
+         stop_idx = (self.layers + layer_idx) % self.layers
+
+         for i, r in enumerate(self.resblocks):
+             if self.grad_checkpointing and not torch.jit.is_scripting():
+                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
+                 x = checkpoint(r, x, None, None, attn_mask)
+             else:
+                 x = r(x, attn_mask=attn_mask)
+
+             if i == stop_idx:
+                 break
+
+         return x
+
+
+ class VisionTransformer(nn.Module):
+     def __init__(
+         self,
+         patch_size: int,
+         width: int,
+         layers: int,
+         heads: int,
+         mlp_ratio: float,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
+         use_ln_pre: bool = True,
+         use_ln_post: bool = True,
+         ls_init_value: float = None,
+         drop_path: float = 0.0,
+         image_size: int = 448,  # Pretrain image size only; you can pass in any image size
+         use_abs_posemb: bool = True,
+         use_rope2d: bool = True,
+         use_cls_token: bool = False,
+         output_dim: Optional[int] = 1280,
+         attn_pooler_heads: int = 8,
+         pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
+     ):
+         super().__init__()
+         assert pool_type in ("attn", "tok", "avg", "none")
+         self.pool_type = pool_type
+         self.patch_size = patch_size
+
+         self.output_dim = output_dim or width
+         self.proj_dim = output_dim
+         self.heads = heads
+         self.width = width
+         self.layers = layers
+
+         self.use_abs_posemb = use_abs_posemb
+         self.use_cls_token = use_cls_token
+         self.use_rope2d = use_rope2d
+         self.image_size = image_size
+
+         self.conv1 = nn.Conv2d(
+             in_channels=3,
+             out_channels=width,
+             kernel_size=patch_size,
+             stride=patch_size,
+             bias=False,
+         )
+         self.rope = (
+             Rope2D(
+                 dim=width // heads,
+                 use_cls_token=self.use_cls_token,
+             )
+             if self.use_rope2d
+             else None
+         )
+
+         self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity()
+         self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity()
+
+         self.transformer = Transformer(
+             width,
+             layers,
+             heads,
+             mlp_ratio,
+             ls_init_value=ls_init_value,
+             act_layer=act_layer,
+             norm_layer=norm_layer,
+             drop_path=drop_path,
+             rope=self.rope,
+         )
+
+         if pool_type == "attn":
+             self.attn_pool = AttentionPooling(
+                 embed_dim=width,
+                 num_heads=attn_pooler_heads,
+                 act_layer=act_layer,
+                 norm_layer=norm_layer,
+             )
+         else:
+             self.attn_pool = None
+
+         self.init_tensors()
+
+     def init_tensors(self):
+         def init_submodule_tensors(module):
+             for name, child in module.named_children():
+                 if hasattr(child, "init_tensors"):
+                     logger.debug(f"Initializing tensors for submodule: {name}")
+                     child.init_tensors()
+                 init_submodule_tensors(child)
+
+         init_submodule_tensors(self)
+         self.rope.init_tensors()
+
+         # class embeddings and positional embeddings
+         init_scale = self.width**-0.5
+
+         if self.use_cls_token:
+             self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))
+
+         if self.use_abs_posemb:
+             self.posemb_grid_size = self.image_size // self.patch_size
+             self.positional_embedding = nn.Parameter(
+                 init_scale
+                 * torch.randn(
+                     int(self.use_cls_token) + self.posemb_grid_size**2, self.width
+                 )
+             )
+
+         if self.proj_dim is not None:
+             self.proj = nn.Parameter(
+                 init_scale * torch.randn(self.width, self.proj_dim)
+             )
+
+     def load_ckpt(self, ckpt_path: str):
+         _sd = torch.load(ckpt_path, weights_only=True)
+         if "state_dict" in _sd:
+             _sd = _sd["state_dict"]
+         elif "weights" in _sd:
+             _sd = _sd["weights"]
+
+         # for backwards compatibility
+         _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
+         if any(k.startswith("visual.") for k in _sd):
+             _sd = {k.replace("visual.", ""): v for k, v in _sd.items() if "visual" in k}
+
+         m, u = self.load_state_dict(_sd, strict=False)
+         LOGGER.warning(f"Missing keys for loading vision encoder: {m}")
+         LOGGER.info(f"Unexpected keys for loading vision encoder: {u}")
+
+     def truncate(self, layer_idx: int):
+         """Delete layers so the last layer is the given layer index."""
+         self.transformer.truncate(layer_idx)
+         self.layers = self.transformer.layers
+
+     @classmethod
+     def from_config(
+         cls,
+         name: str,
+         pretrained: bool = False,
+         checkpoint_path: Optional[str] = None,
+         **kwdargs,
+     ):
+         if name not in PE_VISION_CONFIG:
+             raise RuntimeError(f"{name} not found in configs.")
+
+         args = asdict(PE_VISION_CONFIG[name])
+         args.update(kwdargs)
+
+         model = cls(**args)
+         if pretrained:
+             model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))
+
+         return model
+
+     @classmethod
+     def available_configs(cls):
+         return list(PE_VISION_CONFIG.keys())
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         self.transformer.set_grad_checkpointing(enable=enable)
+
+     def _sample_abs_posemb(self, grid_h: int, grid_w: int):
+         """Interpolates the absolute position embedding if necessary."""
+         if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
+             return self.positional_embedding[None, ...]
+
+         pos_embed = self.positional_embedding
+         if self.use_cls_token:
+             cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
+
+         pos_embed = (
+             pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
+             .permute(0, 3, 1, 2)
+             .contiguous()
+         )
+         pos_embed = F.interpolate(
+             pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
+         )
+         pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()
+
+         if self.use_cls_token:
+             pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
+
+         return pos_embed[None, ...]
+
+     def _pool(self, x: torch.Tensor):
+         if self.pool_type == "tok":
+             return x[:, 0]
+         elif self.pool_type == "avg":
+             return x.mean(dim=1)
+         elif self.pool_type == "attn":
+             return self.attn_pool(x).squeeze(1)
+         elif self.pool_type == "none":
+             return x
+         else:
+             raise NotImplementedError
+
+     def forward_features(
+         self,
+         x: torch.Tensor,
+         norm: bool = False,
+         layer_idx: int = -1,
+         strip_cls_token: bool = False,
+     ):
+         batch, _, h, w = x.shape
+         grid_h, grid_w = h // self.patch_size, w // self.patch_size
+
+         x = self.conv1(x)
+         x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)
+
+         if self.use_cls_token:
+             x = torch.cat(
+                 [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
+                 dim=1,
+             )
+
+         if self.use_abs_posemb:
+             x = x + self._sample_abs_posemb(grid_h, grid_w)
+
+         if self.use_rope2d:
+             self.rope.update_grid(x.device, grid_h, grid_w)
+
+         x = self.ln_pre(x)
+         x = self.transformer(x, layer_idx=layer_idx)
+
+         if norm:
+             x = self.ln_post(x)
+
+         if strip_cls_token and self.use_cls_token:
+             x = x[:, 1:, :]
+
+         return x
+
+     def forward(self, x: torch.Tensor, **kwargs):
+         x = self.forward_features(x, norm=True, **kwargs)
+         x = self._pool(x)
+
+         if self.proj_dim is not None:
+             x = x @ self.proj
+
+         return x
+
+
+ class TextTransformer(nn.Module):
+     def __init__(
+         self,
+         context_length: int = 72,
+         vocab_size: int = 49408,
+         width: int = 512,
+         heads: int = 8,
+         layers: int = 12,
+         mlp_ratio: float = 4.0,
+         ls_init_value: float = None,
+         output_dim: int = 1280,
+         no_causal_mask: bool = False,
+         pad_id: int = 0,
+         pool_type: str = "argmax",
+         proj_bias: bool = False,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
+         output_tokens: bool = False,
+         use_ln_post: bool = True,
+     ):
+         super().__init__()
+         assert pool_type in ("first", "last", "argmax", "none")
+         self.pool_type = pool_type
+         self.output_tokens = output_tokens
+         self.num_pos = self.context_length = context_length
+         self.vocab_size = vocab_size
+         self.width = width
+         self.output_dim = output_dim
+         self.heads = heads
+         self.pad_id = pad_id
+         self.layers = layers
+
+         self.token_embedding = nn.Embedding(vocab_size, width)
+         self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
+
+         self.transformer = Transformer(
+             width=width,
+             layers=layers,
+             heads=heads,
+             mlp_ratio=mlp_ratio,
+             ls_init_value=ls_init_value,
+             act_layer=act_layer,
+             norm_layer=norm_layer,
+         )
+
+         self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
+
+         if no_causal_mask:
+             self.attn_mask = None
+         else:
+             self.register_buffer(
+                 "attn_mask", self.build_causal_mask(), persistent=False
+             )
+
+         if pool_type == "attn" or pool_type == "attn_eos":
+             self.attn_pool = AttentionPooling(
+                 embed_dim=width,
+                 num_heads=heads,
+                 act_layer=act_layer,
+                 norm_layer=norm_layer,
+             )
+         else:  # argmax
+             self.attn_pool = None
+
+         if proj_bias:
+             self.text_projection = nn.Linear(width, output_dim)
+         else:
+             self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+
+     def build_causal_mask(self):
+         # lazily create causal attention mask, with full attention between the tokens
+         # pytorch uses additive attention mask; fill with -inf
+         mask = torch.empty(self.num_pos, self.num_pos)
+         mask.fill_(float("-inf"))
+         mask.triu_(1)  # zero out the lower diagonal
+         return mask
+
+     def load_ckpt(self, ckpt_path: str):
+         _sd = torch.load(ckpt_path, weights_only=True)
+         if "state_dict" in _sd:
+             _sd = _sd["state_dict"]
+         elif "weights" in _sd:
+             _sd = _sd["weights"]
+
+         _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
+
+         m, u = self.load_state_dict(_sd, strict=False)
+
+         if m:
+             LOGGER.warning(f"Missing keys for loading model: {m}")
+         if u:
+             LOGGER.warning(f"Unexpected keys for loading model: {u}")
+
+     def build_cls_mask(self, text):
+         cls_mask = (text != self.pad_id).unsqueeze(1)
+         cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
+         additive_mask = torch.empty(cls_mask.shape, device=cls_mask.device)
+         additive_mask.fill_(0)
+         additive_mask.masked_fill_(~cls_mask, float("-inf"))
+         additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
+         return additive_mask
+
+     def text_global_pool(
+         self, x, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
+     ):
+         if pool_type == "first":
+             pooled, tokens = x[:, 0], x[:, 1:]
+         elif pool_type == "last":
+             pooled, tokens = x[:, -1], x[:, :-1]
+         elif pool_type == "argmax":
+             # take features from the eot embedding (eot_token is the highest number in each sequence)
+             assert text is not None
+             pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
+         else:
+             pooled = tokens = x
+
+         return pooled, tokens
+
+     def forward(self, text):
+         seq_len = text.shape[1]
+         x = self.token_embedding(text)
+         attn_mask = self.attn_mask
+         if attn_mask is not None:
+             attn_mask = attn_mask[:seq_len, :seq_len]
+
+         x = x + self.positional_embedding[:seq_len]
+         x = self.transformer(x, attn_mask=attn_mask)
+
+         x = self.ln_final(x)
+         pooled, tokens = self.text_global_pool(x, text, pool_type=self.pool_type)
+
+         if self.text_projection is not None:
+             if isinstance(self.text_projection, nn.Linear):
+                 pooled = self.text_projection(pooled)
+             else:
+                 pooled = pooled @ self.text_projection
+
+         if self.output_tokens:
+             return pooled, tokens
+
+         return pooled
+
+
+ class CLIP(TextTransformer):
+     def __init__(
+         self,
+         vision_cfg: PEConfig,
+         text_cfg: PETextConfig,
+         init_logit_scale: float = np.log(1 / 0.07),
+     ):
+         super(CLIP, self).__init__(**asdict(text_cfg))
+         self.visual = VisionTransformer(**asdict(vision_cfg))
+         self.image_size = self.visual.image_size  # For ease of use
+         self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
+
+     def encode_image(self, image, normalize: bool = False):
+         x = self.visual(image)
+         return F.normalize(x, dim=-1) if normalize else x
+
+     def encode_video(self, video, normalize: bool = False):  # b n c h w
+         b, n, c, h, w = video.shape
+         frms = video.reshape(b * n, c, h, w)
+         frm_feats = self.encode_image(frms, normalize=normalize)
+         video_feats = frm_feats.reshape(b, n, -1)
+         video_feats = video_feats.mean(dim=1)
+         return video_feats
+
+     def encode_text(self, text, normalize: bool = False):
+         x = super().forward(text)
+         return F.normalize(x, dim=-1) if normalize else x
+
+     def forward(
+         self,
+         image: Optional[torch.Tensor] = None,
+         text: Optional[torch.Tensor] = None,
+     ):
+         image_features = (
+             self.encode_image(image, normalize=True) if image is not None else None
+         )
+         text_features = (
+             self.encode_text(text, normalize=True) if text is not None else None
+         )
+         return image_features, text_features, self.logit_scale.exp()
+
+     @classmethod
+     def from_config(
+         cls,
+         name: str,
+         pretrained: bool = False,
+         checkpoint_path: Optional[str] = None,  # To load your own
+     ):
+         if name not in PE_VISION_CONFIG or name not in PE_TEXT_CONFIG:
+             raise RuntimeError(f"{name} not found in configs.")
+
+         model = cls(PE_VISION_CONFIG[name], PE_TEXT_CONFIG[name])
+         if pretrained:
+             model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))
+
+         return model
+
+     @classmethod
+     def available_configs(cls):
+         return [k for k in PE_VISION_CONFIG if k in PE_TEXT_CONFIG]
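
The file above defines the Perception Encoder stack vendored into this wheel: VisionTransformer (patch embedding, optional absolute and RoPE-2D position handling, attention pooling), TextTransformer (a causal text encoder), and a CLIP wrapper that ties them together behind from_config. The sketch below is not part of the package; it is a minimal illustration that assumes a matching checkpoint is reachable through fetch_pe_checkpoint when pretrained=True, and it feeds dummy tensors in place of the real image transforms and tokenizer shipped alongside this module (transforms.py, tokenizer.py).

    # Illustrative sketch only -- exercises the CLIP class defined in pe.py.
    import torch

    from inference_models.models.perception_encoder.vision_encoder.pe import CLIP

    # Pick a config name present in both PE_VISION_CONFIG and PE_TEXT_CONFIG
    # instead of hard-coding one (the names live in config.py).
    name = CLIP.available_configs()[0]
    model = CLIP.from_config(name, pretrained=True).eval()  # weights via fetch_pe_checkpoint

    # Dummy inputs; real usage would run the package's preprocessing and tokenizer.
    images = torch.randn(2, 3, model.image_size, model.image_size)          # (b, c, h, w)
    tokens = torch.randint(0, model.vocab_size, (3, model.context_length))  # (b, seq)

    with torch.no_grad():
        image_features, text_features, logit_scale = model(images, tokens)

    # forward() returns L2-normalized features, so scaled dot products are CLIP-style logits.
    logits = logit_scale * image_features @ text_features.T  # shape (2, 3)
    print(logits.softmax(dim=-1))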