inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py
@@ -0,0 +1,470 @@
+ import os.path
+ from copy import deepcopy
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import Detections, ObjectDetectionModel
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     CorruptedModelPackageError,
+     ModelLoadingError,
+     ModelRuntimeError,
+ )
+ from inference_models.logger import LOGGER
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.roboflow.model_packages import (
+     ColorMode,
+     DivisiblePadding,
+     InferenceConfig,
+     NetworkInputDefinition,
+     PreProcessingMetadata,
+     ResizeMode,
+     TrainingInputSize,
+     parse_class_names_file,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.models.rfdetr.class_remapping import (
+     ClassesReMapping,
+     prepare_class_remapping,
+ )
+ from inference_models.models.rfdetr.common import parse_model_type
+ from inference_models.models.rfdetr.default_labels import resolve_labels
+ from inference_models.models.rfdetr.post_processor import PostProcess
+ from inference_models.models.rfdetr.rfdetr_base_pytorch import (
+     LWDETR,
+     RFDETRBaseConfig,
+     RFDETRLargeConfig,
+     RFDETRMediumConfig,
+     RFDETRNanoConfig,
+     RFDETRSmallConfig,
+     build_model,
+ )
+
+ try:
+     torch.set_float32_matmul_precision("high")
+ except Exception:
+     # set_float32_matmul_precision is unavailable in older torch builds;
+     # the precision hint is optional, so failures are ignored.
+     pass
+
+ CONFIG_FOR_MODEL_TYPE = {
+     "rfdetr-nano": RFDETRNanoConfig,
+     "rfdetr-small": RFDETRSmallConfig,
+     "rfdetr-medium": RFDETRMediumConfig,
+     "rfdetr-base": RFDETRBaseConfig,
+     "rfdetr-large": RFDETRLargeConfig,
+ }
+
+ RESIZE_MODES_TO_REVERT_PADDING = {
+     ResizeMode.LETTERBOX,
+     ResizeMode.LETTERBOX_REFLECT_EDGES,
+ }
+
+
+ class RFDetrForObjectDetectionTorch(
+     ObjectDetectionModel[torch.Tensor, PreProcessingMetadata, dict]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         model_type: Optional[str] = None,
+         labels: Optional[Union[str, List[str]]] = None,
+         resolution: Optional[int] = None,
+         **kwargs,
+     ) -> "RFDetrForObjectDetectionTorch":
+         if os.path.isfile(model_name_or_path):
+             return cls.from_checkpoint_file(
+                 checkpoint_path=model_name_or_path,
+                 model_type=model_type,
+                 labels=labels,
+                 resolution=resolution,
+                 device=device,
+             )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "model_type.json",
+                 "weights.pth",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         classes_re_mapping = None
+         if inference_config.class_names_operations:
+             class_names, classes_re_mapping = prepare_class_remapping(
+                 class_names=class_names,
+                 class_names_operations=inference_config.class_names_operations,
+                 device=device,
+             )
+         weights_dict = torch.load(
+             model_package_content["weights.pth"],
+             map_location=device,
+             weights_only=False,
+         )["model"]
+         model_type = parse_model_type(
+             config_path=model_package_content["model_type.json"]
+         )
+         if model_type not in CONFIG_FOR_MODEL_TYPE:
+             raise CorruptedModelPackageError(
+                 message=f"Model package describes model_type as '{model_type}' which is not supported. "
+                 f"Supported model types: {list(CONFIG_FOR_MODEL_TYPE.keys())}.",
+                 help_url="https://todo",
+             )
+         model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
+         checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
+         model_config.num_classes = checkpoint_num_classes - 1
+         model_config.resolution = (
+             inference_config.network_input.training_input_size.height
+         )
+         model = build_model(config=model_config)
+         model.load_state_dict(weights_dict)
+         model = model.eval().to(device)
+         post_processor = PostProcess()
+         return cls(
+             model=model,
+             class_names=class_names,
+             classes_re_mapping=classes_re_mapping,
+             device=device,
+             inference_config=inference_config,
+             post_processor=post_processor,
+             resolution=model_config.resolution,
+         )
+
+     @classmethod
+     def from_checkpoint_file(
+         cls,
+         checkpoint_path: str,
+         model_type: Optional[str] = None,
+         labels: Optional[Union[str, List[str]]] = None,
+         resolution: Optional[int] = None,
+         device: torch.device = DEFAULT_DEVICE,
+     ) -> "RFDetrForObjectDetectionTorch":
+         if model_type is None:
+             raise ModelLoadingError(
+                 message="Could not determine `model_type` while loading RFDetr model (using torch backend). "
+                 "If you used `RFDetrForObjectDetectionTorch` directly imported in your code, please pass "
+                 f"one of the values {list(CONFIG_FOR_MODEL_TYPE.keys())} as the parameter. If you see this "
+                 "error while using `AutoModel.from_pretrained(...)`, or it is thrown from a managed Roboflow "
+                 "service, this is a bug - please raise an issue at "
+                 "https://github.com/roboflow/inference/issues providing full context.",
+                 help_url="https://todo",
+             )
+         weights_dict = torch.load(
+             checkpoint_path,
+             map_location=device,
+             weights_only=False,
+         )["model"]
+         if model_type not in CONFIG_FOR_MODEL_TYPE:
+             raise ModelLoadingError(
+                 message=f"Checkpoint describes model_type as '{model_type}' which is not supported. "
+                 f"Supported model types: {list(CONFIG_FOR_MODEL_TYPE.keys())}.",
+                 help_url="https://todo",
+             )
+         model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
+         divisibility = model_config.num_windows * model_config.patch_size
+         if resolution is not None:
+             if resolution <= 0 or resolution % divisibility != 0:
+                 raise ModelLoadingError(
+                     message=f"Attempted to load RFDetr model (using torch backend) with an invalid `resolution` "
+                     f"parameter - the model requires a positive value divisible by {divisibility}. Make sure "
+                     f"the value corresponds to the one used to train the model.",
+                     help_url="https://todo",
+                 )
+             model_config.resolution = resolution
+         inference_config = InferenceConfig(
+             network_input=NetworkInputDefinition(
+                 training_input_size=TrainingInputSize(
+                     height=model_config.resolution,
+                     width=model_config.resolution,
+                 ),
+                 dynamic_spatial_size_supported=True,
+                 dynamic_spatial_size_mode=DivisiblePadding(
+                     type="pad-to-be-divisible",
+                     value=divisibility,
+                 ),
+                 color_mode=ColorMode.BGR,
+                 resize_mode=ResizeMode.STRETCH_TO,
+                 input_channels=3,
+                 scaling_factor=255,
+                 normalization=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+             )
+         )
+         checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
+         model_config.num_classes = checkpoint_num_classes - 1
+         model = build_model(config=model_config)
+         if labels is None:
+             class_names = [f"class_{i}" for i in range(checkpoint_num_classes)]
+         elif isinstance(labels, str):
+             class_names = resolve_labels(labels=labels)
+         else:
+             class_names = labels
+         if checkpoint_num_classes != len(class_names):
+             raise ModelLoadingError(
+                 message=f"The RFDetr checkpoint defines {checkpoint_num_classes} output classes, but the "
+                 f"loaded labels define {len(class_names)} classes - fix the value of the `labels` parameter.",
+                 help_url="https://todo",
+             )
+         model.load_state_dict(weights_dict)
+         model = model.eval().to(device)
+         post_processor = PostProcess()
+         return cls(
+             model=model,
+             class_names=class_names,
+             classes_re_mapping=None,
+             device=device,
+             inference_config=inference_config,
+             post_processor=post_processor,
+             resolution=model_config.resolution,
+         )
+
+     def __init__(
+         self,
+         model: LWDETR,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         classes_re_mapping: Optional[ClassesReMapping],
+         device: torch.device,
+         post_processor: PostProcess,
+         resolution: int,
+     ):
+         self._model = model
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._classes_re_mapping = classes_re_mapping
+         self._post_processor = post_processor
+         self._device = device
+         self._resolution = resolution
+         self._has_warned_about_not_being_optimized_for_inference = False
+         self._inference_model: Optional[LWDETR] = None
+         self._optimized_has_been_compiled = False
+         self._optimized_batch_size = None
+         self._optimized_dtype = None
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def optimize_for_inference(
+         self,
+         compile: bool = True,
+         batch_size: int = 1,
+         dtype: torch.dtype = torch.float32,
+     ) -> None:
+         self.remove_optimized_model()
+         # Work on a copy so the original weights remain available for re-optimization.
+         self._inference_model = deepcopy(self._model)
+         self._inference_model.eval()
+         self._inference_model.export()
+         self._inference_model = self._inference_model.to(dtype=dtype)
+         self._optimized_dtype = dtype
+         if compile:
+             # torch.jit.trace specializes the model to a fixed batch size and resolution.
+             self._inference_model = torch.jit.trace(
+                 self._inference_model,
+                 torch.randn(
+                     batch_size,
+                     3,
+                     self._resolution,
+                     self._resolution,
+                     device=self._device,
+                     dtype=dtype,
+                 ),
+             )
+             self._optimized_has_been_compiled = True
+             self._optimized_batch_size = batch_size
+
+     def remove_optimized_model(self) -> None:
+         self._has_warned_about_not_being_optimized_for_inference = False
+         self._inference_model = None
+         self._optimized_has_been_compiled = False
+         self._optimized_batch_size = None
+         self._optimized_dtype = None
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+             image_size_wh=image_size,
+         )
+
+     def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> dict:
+         if (
+             self._inference_model is None
+             and not self._has_warned_about_not_being_optimized_for_inference
+         ):
+             LOGGER.warning(
+                 "Model is not optimized for inference. "
+                 "Latency may be higher than expected. "
+                 "You can optimize the model for inference by calling model.optimize_for_inference()."
+             )
+             self._has_warned_about_not_being_optimized_for_inference = True
+         if self._inference_model is not None:
+             if (self._resolution, self._resolution) != tuple(
+                 pre_processed_images.shape[2:]
+             ):
+                 raise ModelRuntimeError(
+                     message=f"Resolution mismatch. Model was optimized for resolution {self._resolution}, "
+                     f"but got {tuple(pre_processed_images.shape[2:])}. "
+                     "You can explicitly remove the optimized model by calling model.remove_optimized_model().",
+                     help_url="https://todo",
+                 )
+             if self._optimized_has_been_compiled:
+                 if self._optimized_batch_size != pre_processed_images.shape[0]:
+                     raise ModelRuntimeError(
+                         message="Batch size mismatch. Optimized model was compiled for batch size "
+                         f"{self._optimized_batch_size}, but got {pre_processed_images.shape[0]}. "
+                         "You can explicitly remove the optimized model by calling model.remove_optimized_model(). "
+                         "Alternatively, you can recompile the optimized model for a different batch size "
+                         "by calling model.optimize_for_inference(batch_size=<new_batch_size>).",
+                         help_url="https://todo",
+                     )
+         with torch.inference_mode():
+             if self._inference_model is not None:
+                 predictions = self._inference_model(
+                     pre_processed_images.to(dtype=self._optimized_dtype)
+                 )
+             else:
+                 predictions = self._model(pre_processed_images)
+             if isinstance(predictions, tuple):
+                 # The exported/traced model returns a (boxes, logits) tuple;
+                 # normalize it to the dict layout used downstream.
+                 predictions = {
+                     "pred_logits": predictions[1],
+                     "pred_boxes": predictions[0],
+                 }
+             return predictions
+
+     def post_process(
+         self,
+         model_results: dict,
+         pre_processing_meta: List[PreProcessingMetadata],
+         threshold: float = 0.5,
+         **kwargs,
+     ) -> List[Detections]:
+         if (
+             self._inference_config.network_input.resize_mode
+             in RESIZE_MODES_TO_REVERT_PADDING
+         ):
+             un_padding_results = []
+             for out_box_tensor, image_metadata in zip(
+                 model_results["pred_boxes"], pre_processing_meta
+             ):
+                 # Boxes are in cxcywh format, so only cx and cy need to be shifted.
+                 box_center_offsets = torch.as_tensor(
+                     [
+                         image_metadata.pad_left / image_metadata.inference_size.width,
+                         image_metadata.pad_top / image_metadata.inference_size.height,
+                         0.0,
+                         0.0,
+                     ],
+                     dtype=out_box_tensor.dtype,
+                     device=out_box_tensor.device,
+                 )
+                 ox_padding = (
+                     image_metadata.pad_left + image_metadata.pad_right
+                 ) / image_metadata.inference_size.width
+                 oy_padding = (
+                     image_metadata.pad_top + image_metadata.pad_bottom
+                 ) / image_metadata.inference_size.height
+                 # Rescale all four cxcywh components to the un-padded image extent.
+                 box_wh_offsets = torch.as_tensor(
+                     [
+                         1.0 - ox_padding,
+                         1.0 - oy_padding,
+                         1.0 - ox_padding,
+                         1.0 - oy_padding,
+                     ],
+                     dtype=out_box_tensor.dtype,
+                     device=out_box_tensor.device,
+                 )
+                 out_box_tensor = (out_box_tensor - box_center_offsets) / box_wh_offsets
+                 un_padding_results.append(out_box_tensor)
+             model_results["pred_boxes"] = torch.stack(un_padding_results, dim=0)
+         if self._inference_config.network_input.resize_mode is ResizeMode.CENTER_CROP:
+             orig_sizes = [
+                 (
+                     round(e.inference_size.height / e.scale_height),
+                     round(e.inference_size.width / e.scale_width),
+                 )
+                 for e in pre_processing_meta
+             ]
+         else:
+             orig_sizes = [
+                 (e.size_after_pre_processing.height, e.size_after_pre_processing.width)
+                 for e in pre_processing_meta
+             ]
+         target_sizes = torch.tensor(orig_sizes, device=self._device)
+         results = self._post_processor(model_results, target_sizes=target_sizes)
+         detections_list = []
+         for image_result, image_metadata in zip(results, pre_processing_meta):
+             scores = image_result["scores"]
+             labels = image_result["labels"]
+             boxes = image_result["boxes"]
+             if self._classes_re_mapping is not None:
+                 remapping_mask = torch.isin(
+                     labels, self._classes_re_mapping.remaining_class_ids
+                 )
+                 scores = scores[remapping_mask]
+                 labels = self._classes_re_mapping.class_mapping[labels[remapping_mask]]
+                 boxes = boxes[remapping_mask]
+             keep = scores > threshold
+             scores = scores[keep]
+             labels = labels[keep]
+             boxes = boxes[keep]
+             if (
+                 self._inference_config.network_input.resize_mode
+                 is ResizeMode.CENTER_CROP
+             ):
+                 offsets = torch.as_tensor(
+                     [
+                         image_metadata.pad_left,
+                         image_metadata.pad_top,
+                         image_metadata.pad_left,
+                         image_metadata.pad_top,
+                     ],
+                     dtype=boxes.dtype,
+                     device=boxes.device,
+                 )
+                 boxes[:, :4].sub_(offsets)
+             if (
+                 image_metadata.static_crop_offset.offset_x != 0
+                 or image_metadata.static_crop_offset.offset_y != 0
+             ):
+                 static_crop_offsets = torch.as_tensor(
+                     [
+                         image_metadata.static_crop_offset.offset_x,
+                         image_metadata.static_crop_offset.offset_y,
+                         image_metadata.static_crop_offset.offset_x,
+                         image_metadata.static_crop_offset.offset_y,
+                     ],
+                     dtype=boxes.dtype,
+                     device=boxes.device,
+                 )
+                 boxes[:, :4].add_(static_crop_offsets)
+             detections = Detections(
+                 xyxy=boxes.round().int(),
+                 confidence=scores,
+                 class_id=labels.int(),
+             )
+             detections_list.append(detections)
+         return detections_list
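
For reference, the class above exposes the complete inference path: `from_pretrained` (or `from_checkpoint_file`), then `pre_process` -> `forward` -> `post_process`, with `optimize_for_inference` as an optional latency optimization. The following is a minimal usage sketch, not taken from the package itself: it assumes the wheel is installed, that `./my-rfdetr-package` is a hypothetical local model package directory containing the four files `from_pretrained` requires (class_names.txt, inference_config.json, model_type.json, weights.pth), and that `input_color_format` accepts the string "bgr" (the `ColorFormat` type is defined elsewhere in the package and not shown in this diff).

import cv2
import torch

from inference_models.models.rfdetr.rfdetr_object_detection_pytorch import (
    RFDetrForObjectDetectionTorch,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load from a local model package directory (hypothetical path).
model = RFDetrForObjectDetectionTorch.from_pretrained(
    "./my-rfdetr-package",
    device=device,
)

# Optional: trace the model for a fixed batch size and dtype to reduce latency.
model.optimize_for_inference(batch_size=1)

image = cv2.imread("example.jpg")  # np.ndarray in BGR order, as produced by OpenCV

# pre_process returns the network input tensor plus per-image metadata;
# forward returns the raw predictions dict; post_process maps the results
# back to original image coordinates as Detections objects.
pre_processed, metadata = model.pre_process(image, input_color_format="bgr")
raw_predictions = model.forward(pre_processed)
detections = model.post_process(raw_predictions, metadata, threshold=0.5)

for image_detections in detections:
    print(image_detections.xyxy, image_detections.confidence, image_detections.class_id)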