inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
# ------------------------------------------------------------------------
|
|
2
|
+
# RF-DETR
|
|
3
|
+
# Copyright (c) 2025 Roboflow. All Rights Reserved.
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
5
|
+
# ------------------------------------------------------------------------
|
|
6
|
+
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
|
|
7
|
+
# Copyright (c) 2024 Baidu. All Rights Reserved.
|
|
8
|
+
# ------------------------------------------------------------------------
|
|
9
|
+
# Modified from ViTDet (https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet)
|
|
10
|
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
11
|
+
# ------------------------------------------------------------------------
|
|
12
|
+
|
|
13
|
+
"""
|
|
14
|
+
Projector
|
|
15
|
+
"""
|
|
16
|
+
import math
|
|
17
|
+
import random
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import torch
|
|
21
|
+
import torch.nn as nn
|
|
22
|
+
import torch.nn.functional as F
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class LayerNorm(nn.Module):
    """
    Channel-wise LayerNorm for channels-first tensors of shape
    (batch_size, channels, height, width).

    Every spatial position is normalized independently: mean and variance are
    computed over the channel axis only, then the learnable per-channel
    ``weight``/``bias`` are applied.
    Reference: https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        # Parameter names are part of the checkpoint (state_dict) format.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """
        Normalize ``x`` over its channel dimension (dim 1).

        Implemented by moving channels last and delegating to
        ``F.layer_norm``. TODO (carried over from the original): this is a
        hack to avoid overflow when using fp16.
        """
        channels_last = x.permute(0, 2, 3, 1)
        normalized = F.layer_norm(
            channels_last,
            (channels_last.shape[3],),
            self.weight,
            self.bias,
            self.eps,
        )
        return normalized.permute(0, 3, 1, 2)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_norm(norm, out_channels):
    """
    Build a normalization layer for ``out_channels`` channels.

    Args:
        norm (str or callable or None): either the name of a supported norm or
            a callable that takes a channel count and returns the
            normalization layer as an ``nn.Module``. The only supported name
            is ``"LN"``, which builds this module's channels-first
            ``LayerNorm``. ``None`` or an empty string disables normalization.
        out_channels (int): number of channels the layer normalizes.

    Returns:
        nn.Module or None: the normalization layer, or ``None`` when
        normalization is disabled.

    Raises:
        KeyError: if ``norm`` is a string that names an unsupported norm.
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        # The lambda defers the LayerNorm lookup until the factory is called.
        factories = {
            "LN": lambda channels: LayerNorm(channels),
        }
        if norm not in factories:
            raise KeyError(
                "Unsupported norm type: {!r}. Supported: {}".format(
                    norm, sorted(factories)
                )
            )
        norm = factories[norm]
    return norm(out_channels)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_activation(name, inplace=False):
    """
    Build an activation module from its short name.

    Args:
        name (str or None): "silu", "relu", one of "LeakyReLU"/"leakyrelu"/
            "lrelu" (negative slope 0.1), or ``None`` for identity.
        inplace (bool): forwarded to the activation that supports it.

    Returns:
        nn.Module: the activation layer.

    Raises:
        AttributeError: if ``name`` is not a supported activation.
    """
    if name is None:
        return nn.Identity()
    if name == "silu":
        return nn.SiLU(inplace=inplace)
    if name == "relu":
        return nn.ReLU(inplace=inplace)
    if name in ("LeakyReLU", "leakyrelu", "lrelu"):
        return nn.LeakyReLU(0.1, inplace=inplace)
    raise AttributeError("Unsupported act type: {}".format(name))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ConvX(nn.Module):
    """
    Conv2d (no bias) -> normalization -> activation.

    The normalization is one of:
      * ``nn.RMSNorm`` over the channel axis when ``rms_norm`` is set,
      * this module's channels-first LayerNorm when ``layer_norm`` is set,
      * ``nn.BatchNorm2d`` otherwise.
    It is stored as ``self.bn`` regardless of type, to keep state_dict keys
    stable.
    """

    def __init__(
        self,
        in_planes,
        out_planes,
        kernel=3,
        stride=1,
        groups=1,
        dilation=1,
        act="relu",
        layer_norm=False,
        rms_norm=False,
    ):
        """
        Args:
            in_planes (int): input channels.
            out_planes (int): output channels.
            kernel (int or tuple[int, int]): kernel size.
            stride (int or tuple): convolution stride.
            groups (int): convolution groups.
            dilation (int or tuple[int, int]): convolution dilation.
            act (str or None): activation name understood by ``get_activation``.
            layer_norm (bool): use channel-wise LayerNorm instead of BatchNorm2d.
            rms_norm (bool): use channel-wise RMSNorm (takes precedence over
                ``layer_norm``).
        """
        super(ConvX, self).__init__()
        if not isinstance(kernel, tuple):
            kernel = (kernel, kernel)
        if not isinstance(dilation, tuple):
            dilation = (dilation, dilation)
        # A dilated kernel spans dilation * (k - 1) + 1 pixels, so padding must
        # scale with dilation to keep H/W unchanged at stride 1 (odd kernels).
        # For dilation == 1 this is the plain k // 2 padding.
        padding = (dilation[0] * (kernel[0] // 2), dilation[1] * (kernel[1] // 2))
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=False,
        )
        # nn.RMSNorm normalizes over the trailing dimension(s). The conv output
        # is channels-first, so forward() moves channels last around it.
        self.channels_last_norm = rms_norm
        if rms_norm:
            self.bn = nn.RMSNorm(out_planes)
        else:
            self.bn = (
                get_norm("LN", out_planes) if layer_norm else nn.BatchNorm2d(out_planes)
            )
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        """Apply conv, normalization and activation to a (N, C, H, W) tensor."""
        out = self.conv(x.contiguous())
        if self.channels_last_norm:
            out = self.bn(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        else:
            out = self.bn(out)
        return self.act(out)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class Bottleneck(nn.Module):
    """
    Two stacked ConvX layers with an optional identity shortcut.

    ``cv1`` maps ``c1`` to ``int(c2 * e)`` hidden channels with kernel
    ``k[0]``. ``cv2`` maps those back to ``c2`` with kernel ``k[1]`` and
    ``g`` groups. The input is added to the result only when ``shortcut`` is
    requested and ``c1 == c2``.
    """

    def __init__(
        self,
        c1,
        c2,
        shortcut=True,
        g=1,
        k=(3, 3),
        e=0.5,
        act="silu",
        layer_norm=False,
        rms_norm=False,
    ):
        """Args: c1 in-channels, c2 out-channels, g groups, k kernels, e expansion."""
        super().__init__()
        hidden_channels = int(c2 * e)
        norm_options = dict(layer_norm=layer_norm, rms_norm=rms_norm)
        self.cv1 = ConvX(c1, hidden_channels, k[0], 1, act=act, **norm_options)
        self.cv2 = ConvX(
            hidden_channels, c2, k[1], 1, groups=g, act=act, **norm_options
        )
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Run both convs and add the residual when enabled."""
        transformed = self.cv2(self.cv1(x))
        if not self.add:
            return transformed
        return x + transformed
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class C2f(nn.Module):
    """
    CSP bottleneck with two convolutions (the "C2f" variant).

    ``cv1`` produces ``2 * c`` channels, which are split into two halves. Each
    of the ``n`` bottlenecks consumes the most recent feature and appends its
    output. All ``2 + n`` features are concatenated and fused to ``c2``
    channels by ``cv2``.
    """

    def __init__(
        self,
        c1,
        c2,
        n=1,
        shortcut=False,
        g=1,
        e=0.5,
        act="silu",
        layer_norm=False,
        rms_norm=False,
    ):
        """Args: c1 in-channels, c2 out-channels, n bottlenecks, g groups, e expansion."""
        super().__init__()
        self.c = int(c2 * e)  # hidden channels per branch
        norm_options = dict(layer_norm=layer_norm, rms_norm=rms_norm)
        self.cv1 = ConvX(c1, 2 * self.c, 1, 1, act=act, **norm_options)
        self.cv2 = ConvX((2 + n) * self.c, c2, 1, act=act, **norm_options)
        self.m = nn.ModuleList(
            [
                Bottleneck(
                    self.c,
                    self.c,
                    shortcut,
                    g,
                    k=(3, 3),
                    e=1.0,
                    act=act,
                    **norm_options,
                )
                for _ in range(n)
            ]
        )

    def forward(self, x):
        """Split, chain the bottlenecks, concatenate and fuse."""
        first_half, second_half = self.cv1(x).split((self.c, self.c), 1)
        features = [first_half, second_half]
        for bottleneck in self.m:
            features.append(bottleneck(features[-1]))
        return self.cv2(torch.cat(features, 1))
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class MultiScaleProjector(nn.Module):
    """
    MultiScaleProjector from :paper:`lwdetr`.

    Builds pyramid features on top of the input feature maps. For every
    requested scale factor:
      1. each input map is resampled (transposed convs for 4.0 / 2.0,
         identity for 1.0, a stride-2 ConvX for 0.5);
      2. the resampled maps are concatenated along channels;
      3. the result is fused by a C2f block followed by a channel-wise
         LayerNorm.

    A 0.25 scale factor gets no resampling branch. Instead, forward()
    appends a stride-2 max-pool of the last fused output.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        scale_factors,
        num_blocks=3,
        layer_norm=False,
        rms_norm=False,
        survival_prob=1.0,
        force_drop_last_n_features=0,
    ):
        """
        Args:
            in_channels (list[int]): channel count of each input feature map,
                in the order the maps are passed to ``forward``.
            out_channels (int): number of channels in the output feature maps.
            scale_factors (list[float]): scaling factors (each one of 4.0, 2.0,
                1.0, 0.5, 0.25) used to upsample or downsample the input
                features when creating pyramid features.
            num_blocks (int): number of bottlenecks in each fusion C2f.
            layer_norm (bool): use LayerNorm instead of BatchNorm2d in the
                fusion C2f blocks and in the 0.5 downsampling ConvX.
            rms_norm (bool): accepted for API compatibility; not forwarded to
                any sub-layer by this implementation.
            survival_prob (float): when < 1.0, during training, input features
                after the first are randomly zeroed (see ``forward``).
            force_drop_last_n_features (int): when > 0 (and the random drop is
                not active), the last N input features are replaced by zeros.

        Raises:
            NotImplementedError: if a scale factor is not supported.
        """
        super(MultiScaleProjector, self).__init__()

        self.scale_factors = scale_factors
        self.survival_prob = survival_prob
        self.force_drop_last_n_features = force_drop_last_n_features

        stages_sampling = []
        stages = []
        self.use_extra_pool = False
        for scale in scale_factors:
            sampling_branches = []
            for in_dim in in_channels:
                if scale == 0.25:
                    # No dedicated resampling: forward() max-pools the previous
                    # fused output instead.
                    self.use_extra_pool = True
                    continue
                sampling_branches.append(
                    self._build_sampling(scale, in_dim, layer_norm)
                )
            stages_sampling.append(nn.ModuleList(sampling_branches))

            # Channels after concatenating all resampled maps: the 4.0 / 2.0
            # branches divide channels by the scale, the others keep them.
            fused_dim = int(
                sum(in_channel // max(1, scale) for in_channel in in_channels)
            )
            # The fusion stage is built even for 0.25 (where it is never run)
            # to keep the parameter layout / checkpoint keys unchanged.
            stages.append(
                nn.Sequential(
                    C2f(fused_dim, out_channels, num_blocks, layer_norm=layer_norm),
                    get_norm("LN", out_channels),
                )
            )

        self.stages_sampling = nn.ModuleList(stages_sampling)
        self.stages = nn.ModuleList(stages)

    @staticmethod
    def _build_sampling(scale, in_dim, layer_norm):
        """Return the nn.Sequential that resamples an ``in_dim``-channel map by ``scale``."""
        if scale == 4.0:
            # Two 2x transposed convs: spatial x4, channels C -> C/2 -> C/4.
            layers = [
                nn.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2),
                get_norm("LN", in_dim // 2),
                nn.GELU(),
                nn.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2),
            ]
        elif scale == 2.0:
            # One 2x transposed conv: spatial x2, channels C -> C/2.
            layers = [
                nn.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2),
            ]
        elif scale == 1.0:
            layers = []
        elif scale == 0.5:
            # Stride-2 3x3 conv: spatial /2, channels unchanged.
            layers = [ConvX(in_dim, in_dim, 3, 2, layer_norm=layer_norm)]
        else:
            raise NotImplementedError("Unsupported scale_factor:{}".format(scale))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x (list[Tensor]): input feature maps of shape (N, C_j, H, W), one
                per entry of ``in_channels``. After resampling they are
                concatenated on dim 1, so they must share spatial size.
        Returns:
            list[Tensor]: one fused feature map per scale factor, in the order
            of ``scale_factors``. When 0.25 is requested, the stride-2 pooled
            map is appended last instead of being produced in place.
        """
        # Shallow copy so neither the caller's list nor its tensors are modified.
        x = list(x)
        num_features = len(x)
        if self.survival_prob < 1.0 and self.training:
            final_drop_prob = 1 - self.survival_prob
            drop_p = np.random.uniform()
            for i in range(1, num_features):
                # Later features have a linearly higher chance of being dropped.
                critical_drop_prob = i * (final_drop_prob / (num_features - 1))
                if drop_p < critical_drop_prob:
                    x[i] = torch.zeros_like(x[i])
        elif self.force_drop_last_n_features > 0:
            for i in range(self.force_drop_last_n_features):
                # don't do it inplace to ensure the compiler can optimize out the backbone layers
                x[-(i + 1)] = torch.zeros_like(x[-(i + 1)])

        results = []
        for stage, stage_sampling in zip(self.stages, self.stages_sampling):
            if len(stage_sampling) == 0:
                # Scale 0.25 (or no inputs): nothing to fuse. The extra pool
                # below provides this level.
                continue
            resampled = [
                sampling(x[j]) for j, sampling in enumerate(stage_sampling)
            ]
            fused = torch.cat(resampled, dim=1) if len(resampled) > 1 else resampled[0]
            results.append(stage(fused))
        if self.use_extra_pool:
            results.append(
                F.max_pool2d(results[-1], kernel_size=1, stride=2, padding=0)
            )
        return results
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
class SimpleProjector(nn.Module):
    """
    Single-scale projector: two ConvX layers (LayerNorm + SiLU) followed by a
    channel-wise LayerNorm.

    By default the first conv expands to ``2 * in_dim`` channels and the second
    projects to ``out_dim``. With ``factor_kernel`` the pair is a 3x1 conv then
    a 1x3 conv, both producing ``out_dim`` channels.
    """

    def __init__(self, in_dim, out_dim, factor_kernel=False):
        super(SimpleProjector, self).__init__()
        conv_options = dict(layer_norm=True, act="silu")
        if factor_kernel:
            self.convx1 = ConvX(in_dim, out_dim, kernel=(3, 1), **conv_options)
            self.convx2 = ConvX(out_dim, out_dim, kernel=(1, 3), **conv_options)
        else:
            self.convx1 = ConvX(in_dim, in_dim * 2, **conv_options)
            self.convx2 = ConvX(in_dim * 2, out_dim, **conv_options)
        self.ln = get_norm("LN", out_dim)

    def forward(self, x):
        """Project ``x[0]`` (other entries of ``x`` are ignored); returns a one-element list."""
        hidden = self.convx1(x[0])
        projected = self.convx2(hidden)
        return [self.ln(projected)]
|