inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
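
The diff below is the one file of this wheel rendered inline; judging by the +373 line count and the RF-DETR header, it corresponds to entry 105 above, inference_models/models/rfdetr/projector.py.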
@@ -0,0 +1,373 @@
+ # ------------------------------------------------------------------------
+ # RF-DETR
+ # Copyright (c) 2025 Roboflow. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
+ # Copyright (c) 2024 Baidu. All Rights Reserved.
+ # ------------------------------------------------------------------------
+ # Modified from ViTDet (https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ # ------------------------------------------------------------------------
+
+ """
+ Projector
+ """
+ import math
+ import random
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class LayerNorm(nn.Module):
+     """
+     A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
+     variance normalization over the channel dimension for inputs that have shape
+     (batch_size, channels, height, width).
+     https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
+     """
+
+     def __init__(self, normalized_shape, eps=1e-6):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.eps = eps
+         self.normalized_shape = (normalized_shape,)
+
+     def forward(self, x):
+         """
+         LayerNorm forward
+         TODO: this is a hack to avoid overflow when using fp16
+         """
+         x = x.permute(0, 2, 3, 1)
+         x = F.layer_norm(x, (x.size(3),), self.weight, self.bias, self.eps)
+         x = x.permute(0, 3, 1, 2)
+         return x
+
+
+ def get_norm(norm, out_channels):
+     """
+     Args:
+         norm (str or callable): the string "LN" (the only norm registered here);
+             or a callable that takes a channel number and returns
+             the normalization layer as an nn.Module.
+     Returns:
+         nn.Module or None: the normalization layer
+     """
+     if norm is None:
+         return None
+     if isinstance(norm, str):
+         if len(norm) == 0:
+             return None
+         norm = {
+             "LN": lambda channels: LayerNorm(channels),
+         }[norm]
+     return norm(out_channels)
+
+
+ def get_activation(name, inplace=False):
+     """get activation"""
+     if name == "silu":
+         module = nn.SiLU(inplace=inplace)
+     elif name == "relu":
+         module = nn.ReLU(inplace=inplace)
+     elif name in ["LeakyReLU", "leakyrelu", "lrelu"]:
+         module = nn.LeakyReLU(0.1, inplace=inplace)
+     elif name is None:
+         module = nn.Identity()
+     else:
+         raise AttributeError("Unsupported act type: {}".format(name))
+     return module
+
+
+ class ConvX(nn.Module):
+     """Conv-bn module"""
+
+     def __init__(
+         self,
+         in_planes,
+         out_planes,
+         kernel=3,
+         stride=1,
+         groups=1,
+         dilation=1,
+         act="relu",
+         layer_norm=False,
+         rms_norm=False,
+     ):
+         super(ConvX, self).__init__()
+         if not isinstance(kernel, tuple):
+             kernel = (kernel, kernel)
+         padding = (kernel[0] // 2, kernel[1] // 2)
+         self.conv = nn.Conv2d(
+             in_planes,
+             out_planes,
+             kernel_size=kernel,
+             stride=stride,
+             padding=padding,
+             groups=groups,
+             dilation=dilation,
+             bias=False,
+         )
+         if rms_norm:
+             self.bn = nn.RMSNorm(out_planes)
+         else:
+             self.bn = (
+                 get_norm("LN", out_planes) if layer_norm else nn.BatchNorm2d(out_planes)
+             )
+         self.act = get_activation(act, inplace=True)
+
+     def forward(self, x):
+         """forward"""
+         out = self.act(self.bn(self.conv(x.contiguous())))
+         return out
+
+
+ class Bottleneck(nn.Module):
+     """Standard bottleneck."""
+
+     def __init__(
+         self,
+         c1,
+         c2,
+         shortcut=True,
+         g=1,
+         k=(3, 3),
+         e=0.5,
+         act="silu",
+         layer_norm=False,
+         rms_norm=False,
+     ):
+         """ch_in, ch_out, shortcut, groups, kernels, expand"""
+         super().__init__()
+         c_ = int(c2 * e)  # hidden channels
+         self.cv1 = ConvX(
+             c1, c_, k[0], 1, act=act, layer_norm=layer_norm, rms_norm=rms_norm
+         )
+         self.cv2 = ConvX(
+             c_, c2, k[1], 1, groups=g, act=act, layer_norm=layer_norm, rms_norm=rms_norm
+         )
+         self.add = shortcut and c1 == c2
+
+     def forward(self, x):
+         """Apply the two convolutions, with a residual shortcut when enabled."""
+         return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+ class C2f(nn.Module):
+     """Faster Implementation of CSP Bottleneck with 2 convolutions."""
+
+     def __init__(
+         self,
+         c1,
+         c2,
+         n=1,
+         shortcut=False,
+         g=1,
+         e=0.5,
+         act="silu",
+         layer_norm=False,
+         rms_norm=False,
+     ):
+         """ch_in, ch_out, number, shortcut, groups, expansion"""
+         super().__init__()
+         self.c = int(c2 * e)  # hidden channels
+         self.cv1 = ConvX(
+             c1, 2 * self.c, 1, 1, act=act, layer_norm=layer_norm, rms_norm=rms_norm
+         )
+         self.cv2 = ConvX(
+             (2 + n) * self.c, c2, 1, act=act, layer_norm=layer_norm, rms_norm=rms_norm
+         )  # optional act=FReLU(c2)
+         self.m = nn.ModuleList(
+             Bottleneck(
+                 self.c,
+                 self.c,
+                 shortcut,
+                 g,
+                 k=(3, 3),
+                 e=1.0,
+                 act=act,
+                 layer_norm=layer_norm,
+                 rms_norm=rms_norm,
+             )
+             for _ in range(n)
+         )
+
+     def forward(self, x):
+         """Forward pass using split() instead of chunk()."""
+         y = list(self.cv1(x).split((self.c, self.c), 1))
+         y.extend(m(y[-1]) for m in self.m)
+         return self.cv2(torch.cat(y, 1))
+
+
+ class MultiScaleProjector(nn.Module):
+     """
+     This module implements MultiScaleProjector in :paper:`lwdetr`.
+     It creates pyramid features built on top of the input feature map.
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         scale_factors,
+         num_blocks=3,
+         layer_norm=False,
+         rms_norm=False,
+         survival_prob=1.0,
+         force_drop_last_n_features=0,
+     ):
+         """
+         Args:
+             in_channels (list[int]): number of channels of each input feature map.
+             out_channels (int): number of channels in the output feature maps.
+             scale_factors (list[float]): list of scaling factors to upsample or downsample
+                 the input features for creating pyramid features.
+         """
+         super(MultiScaleProjector, self).__init__()
+
+         self.scale_factors = scale_factors
+         self.survival_prob = survival_prob
+         self.force_drop_last_n_features = force_drop_last_n_features
+
+         stages_sampling = []
+         stages = []
+         # use_bias = norm == ""
+         use_bias = False
+         self.use_extra_pool = False
+         for scale in scale_factors:
+             stages_sampling.append([])
+             for in_dim in in_channels:
+                 out_dim = in_dim
+                 layers = []
+
+                 # if in_dim > 512:
+                 #     layers.append(ConvX(in_dim, in_dim // 2, kernel=1))
+                 #     in_dim = in_dim // 2
+
+                 if scale == 4.0:
+                     layers.extend(
+                         [
+                             nn.ConvTranspose2d(
+                                 in_dim, in_dim // 2, kernel_size=2, stride=2
+                             ),
+                             get_norm("LN", in_dim // 2),
+                             nn.GELU(),
+                             nn.ConvTranspose2d(
+                                 in_dim // 2, in_dim // 4, kernel_size=2, stride=2
+                             ),
+                         ]
+                     )
+                     out_dim = in_dim // 4
+                 elif scale == 2.0:
+                     # a hack to reduce the FLOPs and Params when the dimension of output feature is too large
+                     # if in_dim > 512:
+                     #     layers = [
+                     #         ConvX(in_dim, in_dim // 2, kernel=1),
+                     #         nn.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2),
+                     #     ]
+                     #     out_dim = in_dim // 4
+                     # else:
+                     layers.extend(
+                         [
+                             nn.ConvTranspose2d(
+                                 in_dim, in_dim // 2, kernel_size=2, stride=2
+                             ),
+                         ]
+                     )
+                     out_dim = in_dim // 2
+                 elif scale == 1.0:
+                     pass
+                 elif scale == 0.5:
+                     layers.extend(
+                         [
+                             ConvX(in_dim, in_dim, 3, 2, layer_norm=layer_norm),
+                         ]
+                     )
+                 elif scale == 0.25:
+                     self.use_extra_pool = True
+                     continue
+                 else:
+                     raise NotImplementedError(
+                         "Unsupported scale_factor:{}".format(scale)
+                     )
+                 layers = nn.Sequential(*layers)
+                 stages_sampling[-1].append(layers)
+             stages_sampling[-1] = nn.ModuleList(stages_sampling[-1])
+
+             in_dim = int(sum(in_channel // max(1, scale) for in_channel in in_channels))
+             layers = [
+                 C2f(in_dim, out_channels, num_blocks, layer_norm=layer_norm),
+                 get_norm("LN", out_channels),
+             ]
+             layers = nn.Sequential(*layers)
+             stages.append(layers)
+
+         self.stages_sampling = nn.ModuleList(stages_sampling)
+         self.stages = nn.ModuleList(stages)
+
+     def forward(self, x):
+         """
+         Args:
+             x (list[Tensor]): backbone feature maps, each of shape (N, C, H, W).
+         Returns:
+             list[Tensor]: pyramid feature map tensors in high to low resolution
+             order, one per scale factor, plus an extra max-pooled level when a
+             0.25 scale factor was requested.
+         """
+         num_features = len(x)
+         if self.survival_prob < 1.0 and self.training:
+             final_drop_prob = 1 - self.survival_prob
+             drop_p = np.random.uniform()
+             for i in range(1, num_features):
+                 critical_drop_prob = i * (final_drop_prob / (num_features - 1))
+                 if drop_p < critical_drop_prob:
+                     x[i][:] = 0
+         elif self.force_drop_last_n_features > 0:
+             for i in range(self.force_drop_last_n_features):
+                 # don't do it inplace to ensure the compiler can optimize out the backbone layers
+                 x[-(i + 1)] = torch.zeros_like(x[-(i + 1)])
+
+         results = []
+         # x is a list of len(out_features_indexes)
+         for i, stage in enumerate(self.stages):
+             feat_fuse = []
+             for j, stage_sampling in enumerate(self.stages_sampling[i]):
+                 feat_fuse.append(stage_sampling(x[j]))
+             if len(feat_fuse) > 1:
+                 feat_fuse = torch.cat(feat_fuse, dim=1)
+             else:
+                 feat_fuse = feat_fuse[0]
+             results.append(stage(feat_fuse))
+         if self.use_extra_pool:
+             results.append(
+                 F.max_pool2d(results[-1], kernel_size=1, stride=2, padding=0)
+             )
+         return results
+
+
+ class SimpleProjector(nn.Module):
+     def __init__(self, in_dim, out_dim, factor_kernel=False):
+         super(SimpleProjector, self).__init__()
+         if not factor_kernel:
+             self.convx1 = ConvX(in_dim, in_dim * 2, layer_norm=True, act="silu")
+             self.convx2 = ConvX(in_dim * 2, out_dim, layer_norm=True, act="silu")
+         else:
+             self.convx1 = ConvX(
+                 in_dim, out_dim, kernel=(3, 1), layer_norm=True, act="silu"
+             )
+             self.convx2 = ConvX(
+                 out_dim, out_dim, kernel=(1, 3), layer_norm=True, act="silu"
+             )
+         self.ln = get_norm("LN", out_dim)
+
+     def forward(self, x):
+         """forward"""
+         out = self.ln(self.convx2(self.convx1(x[0])))
+         return [out]
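
For orientation, here is a minimal standalone sketch (not part of the package) of what the permute-based LayerNorm above computes: per-position mean and variance normalization over the channel axis of an NCHW tensor. Tensor sizes are illustrative.

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 64, 8, 8)  # (N, C, H, W)
    w, b = torch.ones(64), torch.zeros(64)

    # Permute to NHWC, normalize over the trailing channel axis, permute back,
    # exactly as LayerNorm.forward does in the file above.
    y = F.layer_norm(x.permute(0, 2, 3, 1), (64,), w, b, 1e-6).permute(0, 3, 1, 2)

    # Equivalent direct computation over dim=1 (the channel dimension).
    mean = x.mean(dim=1, keepdim=True)
    var = x.var(dim=1, unbiased=False, keepdim=True)
    y_ref = (x - mean) / torch.sqrt(var + 1e-6)
    y_ref = y_ref * w[None, :, None, None] + b[None, :, None, None]
    assert torch.allclose(y, y_ref, atol=1e-5)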
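
A worked example of the stochastic feature-drop schedule in MultiScaleProjector.forward: during training with survival_prob < 1, a single uniform sample drop_p is drawn, and input feature map i (for i >= 1) is zeroed when drop_p < i * (1 - survival_prob) / (num_features - 1). With num_features = 3 and survival_prob = 0.7, feature 1 is zeroed when drop_p < 0.15 and feature 2 when drop_p < 0.3, so later features are dropped more often and feature 0 is never dropped.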
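
Finally, a hypothetical usage sketch for MultiScaleProjector, assuming the import path shown in the file listing; the channel counts, scale factors, and shapes are illustrative values, not anything the package mandates. Each input feature map is resampled per scale factor, concatenated along channels, and fused by a C2f block into a fixed-width pyramid.

    import torch
    from inference_models.models.rfdetr.projector import MultiScaleProjector

    proj = MultiScaleProjector(
        in_channels=[256, 256],          # channels of each input feature map
        out_channels=128,                # width of every pyramid level
        scale_factors=[2.0, 1.0, 0.5],   # upsample x2, keep, downsample x2
        num_blocks=2,
    )
    feats = [torch.randn(1, 256, 32, 32), torch.randn(1, 256, 32, 32)]
    pyramid = proj(feats)
    for level in pyramid:
        print(level.shape)
    # expected: (1, 128, 64, 64), (1, 128, 32, 32), (1, 128, 16, 16)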