inference-models 0.18.3 (inference_models-0.18.3-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0

inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py
@@ -0,0 +1,264 @@
+ from typing import List, Optional, Tuple, Union
+
+ import segmentation_models_pytorch as smp
+ import torch
+ from torchvision.transforms import functional
+
+ from inference_models import ColorFormat, SemanticSegmentationModel
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.errors import CorruptedModelPackageError
+ from inference_models.models.base.semantic_segmentation import (
+     SemanticSegmentationResult,
+ )
+ from inference_models.models.base.types import PreprocessingMetadata
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     PreProcessingMetadata,
+     ResizeMode,
+     parse_class_names_file,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+
+
+ class DeepLabV3PlusForSemanticSegmentationTorch(
+     SemanticSegmentationModel[torch.Tensor, PreProcessingMetadata, torch.Tensor]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "DeepLabV3PlusForSemanticSegmentationTorch":
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "weights.pt",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         try:
+             background_class_id = [c.lower() for c in class_names].index("background")
+         except ValueError:
+             background_class_id = -1
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         if inference_config.model_initialization is None:
+             raise CorruptedModelPackageError(
+                 message="Expected model initialization parameters not provided in inference config.",
+                 help_url="https://todo",
+             )
+         num_classes = inference_config.model_initialization.get("classes")
+         in_channels = inference_config.model_initialization.get("in_channels")
+         encoder_name = inference_config.model_initialization.get("encoder_name")
+         if not isinstance(num_classes, int) or num_classes < 1:
+             raise CorruptedModelPackageError(
+                 message="Expected model initialization parameter `num_classes` not provided or in invalid format.",
+                 help_url="https://todo",
+             )
+         if not isinstance(in_channels, int) or in_channels not in {1, 3}:
+             raise CorruptedModelPackageError(
+                 message="Expected model initialization parameter `in_channels` not provided or in invalid format.",
+                 help_url="https://todo",
+             )
+         if not isinstance(encoder_name, str):
+             raise CorruptedModelPackageError(
+                 message="Expected model initialization parameter `encoder_name` not provided or in invalid format.",
+                 help_url="https://todo",
+             )
+         model = (
+             smp.DeepLabV3Plus(
+                 encoder_name=encoder_name,
+                 in_channels=in_channels,
+                 classes=num_classes,
+             )
+             .to(device)
+             .eval()
+         )
+         state_dict = torch.load(
+             model_package_content["weights.pt"],
+             weights_only=True,
+             map_location=device,
+         )
+         if "state_dict" in state_dict:
+             state_dict = state_dict["state_dict"]
+         state_dict = {k[len("model.") :]: v for k, v in state_dict.items()}
+         model.load_state_dict(state_dict)
+         return cls(
+             model=model.eval(),
+             inference_config=inference_config,
+             class_names=class_names,
+             background_class_id=background_class_id,
+             device=device,
+         )
+
+     def __init__(
+         self,
+         model: smp.DeepLabV3Plus,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         background_class_id: int,
+         device: torch.device,
+     ):
+         self._model = model
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._background_class_id = background_class_id
+         self._device = device
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor]],
+         input_color_format: Optional[ColorFormat] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, PreprocessingMetadata]:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+             image_size_wh=image_size,
+         )
+
+     def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
+         with torch.inference_mode():
+             return self._model(pre_processed_images)
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         pre_processing_meta: List[PreProcessingMetadata],
+         confidence_threshold: float = 0.5,
+         **kwargs,
+     ) -> List[SemanticSegmentationResult]:
+         results = []
+         for image_results, image_metadata in zip(model_results, pre_processing_meta):
+             inference_size = image_metadata.inference_size
+             mask_h_scale = model_results.shape[2] / inference_size.height
+             mask_w_scale = model_results.shape[3] / inference_size.width
+             mask_pad_top, mask_pad_bottom, mask_pad_left, mask_pad_right = (
+                 round(mask_h_scale * image_metadata.pad_top),
+                 round(mask_h_scale * image_metadata.pad_bottom),
+                 round(mask_w_scale * image_metadata.pad_left),
+                 round(mask_w_scale * image_metadata.pad_right),
+             )
+             _, mh, mw = image_results.shape
+             if (
+                 mask_pad_top < 0
+                 or mask_pad_bottom < 0
+                 or mask_pad_left < 0
+                 or mask_pad_right < 0
+             ):
+                 image_results = torch.nn.functional.pad(
+                     image_results,
+                     (
+                         abs(min(mask_pad_left, 0)),
+                         abs(min(mask_pad_right, 0)),
+                         abs(min(mask_pad_top, 0)),
+                         abs(min(mask_pad_bottom, 0)),
+                     ),
+                     "constant",
+                     self._background_class_id,
+                 )
+                 padded_mask_offset_top = max(mask_pad_top, 0)
+                 padded_mask_offset_bottom = max(mask_pad_bottom, 0)
+                 padded_mask_offset_left = max(mask_pad_left, 0)
+                 padded_mask_offset_right = max(mask_pad_right, 0)
+                 image_results = image_results[
+                     :,
+                     padded_mask_offset_top : image_results.shape[1]
+                     - padded_mask_offset_bottom,
+                     padded_mask_offset_left : image_results.shape[1]
+                     - padded_mask_offset_right,
+                 ]
+             else:
+                 image_results = image_results[
+                     :,
+                     mask_pad_top : mh - mask_pad_bottom,
+                     mask_pad_left : mw - mask_pad_right,
+                 ]
+             if (
+                 image_results.shape[1]
+                 != image_metadata.size_after_pre_processing.height
+                 or image_results.shape[2]
+                 != image_metadata.size_after_pre_processing.width
+             ):
+                 image_results = functional.resize(
+                     image_results,
+                     [
+                         image_metadata.size_after_pre_processing.height,
+                         image_metadata.size_after_pre_processing.width,
+                     ],
+                     interpolation=functional.InterpolationMode.BILINEAR,
+                 )
+             image_results = torch.nn.functional.softmax(image_results, dim=0)
+             image_confidence, image_class_ids = torch.max(image_results, dim=0)
+             below_threshold = image_confidence < confidence_threshold
+             image_confidence[below_threshold] = 0.0
+             image_class_ids[below_threshold] = self._background_class_id
+             if (
+                 image_metadata.static_crop_offset.offset_x > 0
+                 or image_metadata.static_crop_offset.offset_y > 0
+             ):
+                 original_size_confidence_canvas = torch.zeros(
+                     (
+                         image_metadata.original_size.height,
+                         image_metadata.original_size.width,
+                     ),
+                     device=self._device,
+                     dtype=image_confidence.dtype,
+                 )
+                 original_size_confidence_canvas[
+                     image_metadata.static_crop_offset.offset_y : image_metadata.static_crop_offset.offset_y
+                     + image_confidence.shape[0],
+                     image_metadata.static_crop_offset.offset_x : image_metadata.static_crop_offset.offset_x
+                     + image_confidence.shape[1],
+                 ] = image_confidence
+                 original_size_confidence_class_id_canvas = (
+                     torch.ones(
+                         (
+                             image_metadata.original_size.height,
+                             image_metadata.original_size.width,
+                         ),
+                         device=self._device,
+                         dtype=image_class_ids.dtype,
+                     )
+                     * self._background_class_id
+                 )
+                 original_size_confidence_class_id_canvas[
+                     image_metadata.static_crop_offset.offset_y : image_metadata.static_crop_offset.offset_y
+                     + image_class_ids.shape[0],
+                     image_metadata.static_crop_offset.offset_x : image_metadata.static_crop_offset.offset_x
+                     + image_class_ids.shape[1],
+                 ] = image_class_ids
+                 image_class_ids = original_size_confidence_class_id_canvas
+                 image_confidence = original_size_confidence_canvas
+             results.append(
+                 SemanticSegmentationResult(
+                     segmentation_map=image_class_ids,
+                     confidence=image_confidence,
+                 )
+             )
+         return results
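
The class above follows the package's three-step inference contract: pre_process builds the network input and records padding and crop metadata, forward runs the smp.DeepLabV3Plus network under torch.inference_mode, and post_process removes padding, resizes, applies softmax with a confidence threshold, and re-projects static crops onto the original canvas. A minimal usage sketch against that surface; the package directory and the input tensor below are hypothetical placeholders, and the exact layout and color handling is delegated to pre_process_network_input, which is not shown in this diff:

import torch

from inference_models.models.deep_lab_v3_plus.deep_lab_v3_plus_segmentation_torch import (
    DeepLabV3PlusForSemanticSegmentationTorch,
)

# Hypothetical local model package containing class_names.txt,
# inference_config.json and weights.pt - the elements requested by from_pretrained above.
model = DeepLabV3PlusForSemanticSegmentationTorch.from_pretrained(
    "./deeplab_v3_plus_package", device=torch.device("cpu")
)

image = torch.zeros((3, 512, 512), dtype=torch.uint8)  # placeholder image tensor
pre_processed, metadata = model.pre_process(image)
raw_output = model.forward(pre_processed)
results = model.post_process(raw_output, metadata, confidence_threshold=0.5)

# Each SemanticSegmentationResult carries a per-pixel class-id map and a confidence map.
print(results[0].segmentation_map.shape, results[0].confidence.shape)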

inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py
@@ -0,0 +1,313 @@
+ from threading import Lock
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from torchvision.transforms import functional
+
+ from inference_models import ColorFormat, SemanticSegmentationModel
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.errors import (
+     CorruptedModelPackageError,
+     MissingDependencyError,
+     ModelRuntimeError,
+ )
+ from inference_models.models.base.semantic_segmentation import (
+     SemanticSegmentationResult,
+ )
+ from inference_models.models.base.types import PreprocessedInputs, PreprocessingMetadata
+ from inference_models.models.common.cuda import (
+     use_cuda_context,
+     use_primary_cuda_context,
+ )
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     PreProcessingMetadata,
+     ResizeMode,
+     TRTConfig,
+     parse_class_names_file,
+     parse_inference_config,
+     parse_trt_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.models.common.trt import (
+     get_engine_inputs_and_outputs,
+     infer_from_trt_engine,
+     load_model,
+ )
+
+ try:
+     import tensorrt as trt
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message=f"Could not import YOLOv8 model with TRT backend - this error means that some additional dependencies "
+         f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
+         f"program, make sure the following extras of the package are installed: `trt10` - installation can only "
+         f"succeed for Linux and Windows machines with Cuda 12 installed. Jetson devices, should have TRT 10.x "
+         f"installed for all builds with Jetpack 6. "
+         f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
+         f"You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
+
+ try:
+     import pycuda.driver as cuda
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message="TODO", help_url="https://todo"
+     ) from import_error
+
+
+ class DeepLabV3PlusForSemanticSegmentationTRT(
+     SemanticSegmentationModel[torch.Tensor, PreProcessingMetadata, torch.Tensor]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         engine_host_code_allowed: bool = False,
+         **kwargs,
+     ) -> "DeepLabV3PlusForSemanticSegmentationTRT":
+         if device.type != "cuda":
+             raise ModelRuntimeError(
+                 message=f"TRT engine only runs on CUDA device - {device} device detected.",
+                 help_url="https://todo",
+             )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "trt_config.json",
+                 "engine.plan",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         try:
+             background_class_id = [c.lower() for c in class_names].index("background")
+         except ValueError:
+             background_class_id = -1
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         trt_config = parse_trt_config(
+             config_path=model_package_content["trt_config.json"]
+         )
+         cuda.init()
+         cuda_device = cuda.Device(device.index or 0)
+         with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
+             engine = load_model(
+                 model_path=model_package_content["engine.plan"],
+                 engine_host_code_allowed=engine_host_code_allowed,
+             )
+             execution_context = engine.create_execution_context()
+         inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
+         if len(inputs) != 1:
+             raise CorruptedModelPackageError(
+                 message=f"Implementation assume single model input, found: {len(inputs)}.",
+                 help_url="https://todo",
+             )
+         if len(outputs) != 1:
+             raise CorruptedModelPackageError(
+                 message=f"Implementation assume single model output, found: {len(outputs)}.",
+                 help_url="https://todo",
+             )
+         return cls(
+             engine=engine,
+             input_name=inputs[0],
+             output_name=outputs[0],
+             class_names=class_names,
+             background_class_id=background_class_id,
+             inference_config=inference_config,
+             trt_config=trt_config,
+             device=device,
+             cuda_context=cuda_context,
+             execution_context=execution_context,
+         )
+
+     def __init__(
+         self,
+         engine: trt.ICudaEngine,
+         input_name: str,
+         output_name: str,
+         class_names: List[str],
+         background_class_id: int,
+         inference_config: InferenceConfig,
+         trt_config: TRTConfig,
+         device: torch.device,
+         cuda_context: cuda.Context,
+         execution_context: trt.IExecutionContext,
+     ):
+         self._engine = engine
+         self._input_name = input_name
+         self._output_names = [output_name]
+         self._class_names = class_names
+         self._background_class_id = background_class_id
+         self._inference_config = inference_config
+         self._trt_config = trt_config
+         self._device = device
+         self._cuda_context = cuda_context
+         self._execution_context = execution_context
+         self._lock = Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> Tuple[PreprocessedInputs, PreprocessingMetadata]:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+         )
+
+     def forward(
+         self, pre_processed_images: PreprocessedInputs, **kwargs
+     ) -> torch.Tensor:
+         with self._lock:
+             with use_cuda_context(context=self._cuda_context):
+                 return infer_from_trt_engine(
+                     pre_processed_images=pre_processed_images,
+                     trt_config=self._trt_config,
+                     engine=self._engine,
+                     context=self._execution_context,
+                     device=self._device,
+                     input_name=self._input_name,
+                     outputs=self._output_names,
+                 )[0]
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         pre_processing_meta: PreprocessedInputs,
+         confidence_threshold: float = 0.5,
+         **kwargs,
+     ) -> List[SemanticSegmentationResult]:
+         results = []
+         for image_results, image_metadata in zip(model_results, pre_processing_meta):
+             inference_size = image_metadata.inference_size
+             mask_h_scale = model_results.shape[2] / inference_size.height
+             mask_w_scale = model_results.shape[3] / inference_size.width
+             mask_pad_top, mask_pad_bottom, mask_pad_left, mask_pad_right = (
+                 round(mask_h_scale * image_metadata.pad_top),
+                 round(mask_h_scale * image_metadata.pad_bottom),
+                 round(mask_w_scale * image_metadata.pad_left),
+                 round(mask_w_scale * image_metadata.pad_right),
+             )
+             _, mh, mw = image_results.shape
+             if (
+                 mask_pad_top < 0
+                 or mask_pad_bottom < 0
+                 or mask_pad_left < 0
+                 or mask_pad_right < 0
+             ):
+                 image_results = torch.nn.functional.pad(
+                     image_results,
+                     (
+                         abs(min(mask_pad_left, 0)),
+                         abs(min(mask_pad_right, 0)),
+                         abs(min(mask_pad_top, 0)),
+                         abs(min(mask_pad_bottom, 0)),
+                     ),
+                     "constant",
+                     self._background_class_id,
+                 )
+                 padded_mask_offset_top = max(mask_pad_top, 0)
+                 padded_mask_offset_bottom = max(mask_pad_bottom, 0)
+                 padded_mask_offset_left = max(mask_pad_left, 0)
+                 padded_mask_offset_right = max(mask_pad_right, 0)
+                 image_results = image_results[
+                     :,
+                     padded_mask_offset_top : image_results.shape[1]
+                     - padded_mask_offset_bottom,
+                     padded_mask_offset_left : image_results.shape[1]
+                     - padded_mask_offset_right,
+                 ]
+             else:
+                 image_results = image_results[
+                     :,
+                     mask_pad_top : mh - mask_pad_bottom,
+                     mask_pad_left : mw - mask_pad_right,
+                 ]
+             if (
+                 image_results.shape[1]
+                 != image_metadata.size_after_pre_processing.height
+                 or image_results.shape[2]
+                 != image_metadata.size_after_pre_processing.width
+             ):
+                 image_results = functional.resize(
+                     image_results,
+                     [
+                         image_metadata.size_after_pre_processing.height,
+                         image_metadata.size_after_pre_processing.width,
+                     ],
+                     interpolation=functional.InterpolationMode.BILINEAR,
+                 )
+             image_results = torch.nn.functional.softmax(image_results, dim=0)
+             image_confidence, image_class_ids = torch.max(image_results, dim=0)
+             below_threshold = image_confidence < confidence_threshold
+             image_confidence[below_threshold] = 0.0
+             image_class_ids[below_threshold] = self._background_class_id
+             if (
+                 image_metadata.static_crop_offset.offset_x > 0
+                 or image_metadata.static_crop_offset.offset_y > 0
+             ):
+                 original_size_confidence_canvas = torch.zeros(
+                     (
+                         image_metadata.original_size.height,
+                         image_metadata.original_size.width,
+                     ),
+                     device=self._device,
+                     dtype=image_confidence.dtype,
+                 )
+                 original_size_confidence_canvas[
+                     image_metadata.static_crop_offset.offset_y : image_metadata.static_crop_offset.offset_y
+                     + image_confidence.shape[0],
+                     image_metadata.static_crop_offset.offset_x : image_metadata.static_crop_offset.offset_x
+                     + image_confidence.shape[1],
+                 ] = image_confidence
+                 original_size_confidence_class_id_canvas = (
+                     torch.ones(
+                         (
+                             image_metadata.original_size.height,
+                             image_metadata.original_size.width,
+                         ),
+                         device=self._device,
+                         dtype=image_class_ids.dtype,
+                     )
+                     * self._background_class_id
+                 )
+                 original_size_confidence_class_id_canvas[
+                     image_metadata.static_crop_offset.offset_y : image_metadata.static_crop_offset.offset_y
+                     + image_class_ids.shape[0],
+                     image_metadata.static_crop_offset.offset_x : image_metadata.static_crop_offset.offset_x
+                     + image_class_ids.shape[1],
+                 ] = image_class_ids
+                 image_class_ids = original_size_confidence_class_id_canvas
+                 image_confidence = original_size_confidence_canvas
+             results.append(
+                 SemanticSegmentationResult(
+                     segmentation_map=image_class_ids,
+                     confidence=image_confidence,
+                 )
+             )
+         return results
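
The TRT variant exposes the same pre_process/forward/post_process surface but loads only on CUDA devices, deserializes an engine.plan with TensorRT, and serializes inference through a Lock while entering the stored CUDA context. A minimal sketch under the assumption that the `trt10` extra and pycuda are installed and that ./deeplab_v3_plus_trt_package is a hypothetical local package containing the engine.plan and config files listed above:

import torch

from inference_models.models.deep_lab_v3_plus.deep_lab_v3_plus_segmentation_trt import (
    DeepLabV3PlusForSemanticSegmentationTRT,
)

# from_pretrained raises ModelRuntimeError for non-CUDA devices, so a GPU is required.
model = DeepLabV3PlusForSemanticSegmentationTRT.from_pretrained(
    "./deeplab_v3_plus_trt_package",
    device=torch.device("cuda:0"),
    engine_host_code_allowed=False,  # same default as the signature above
)

image = torch.zeros((3, 512, 512), dtype=torch.uint8)  # placeholder image tensor
pre_processed, metadata = model.pre_process(image)
raw_output = model.forward(pre_processed)  # guarded by the internal Lock and CUDA context
results = model.post_process(raw_output, metadata, confidence_threshold=0.5)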

inference_models/models/depth_anything_v2/__init__.py
File without changes

inference_models/models/depth_anything_v2/depth_anything_v2_hf.py
@@ -0,0 +1,77 @@
+ from typing import List, Tuple, Union
+
+ import numpy as np
+ import torch
+ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
+
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ImageDimensions
+ from inference_models.models.base.depth_estimation import DepthEstimationModel
+ from inference_models.models.common.roboflow.pre_processing import (
+     extract_input_images_dimensions,
+ )
+
+
+ class DepthAnythingV2HF(
+     DepthEstimationModel[torch.Tensor, List[ImageDimensions], torch.Tensor]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         local_files_only: bool = True,
+         **kwargs,
+     ) -> "DepthAnythingV2HF":
+         model = AutoModelForDepthEstimation.from_pretrained(
+             model_name_or_path,
+             local_files_only=local_files_only,
+         ).to(device)
+         processor = AutoImageProcessor.from_pretrained(
+             model_name_or_path, local_files_only=local_files_only, use_fast=True
+         )
+         return cls(model=model, processor=processor, device=device)
+
+     def __init__(
+         self,
+         model: AutoModelForDepthEstimation,
+         processor: AutoImageProcessor,
+         device: torch.device,
+     ):
+         self._model = model
+         self._processor = processor
+         self._device = device
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         **kwargs,
+     ) -> Tuple[torch.Tensor, List[ImageDimensions]]:
+         image_dimensions = extract_input_images_dimensions(images=images)
+         inputs = self._processor(images=images, return_tensors="pt")
+         return inputs["pixel_values"].to(self._device), image_dimensions
+
+     def forward(
+         self,
+         pre_processed_images: torch.Tensor,
+         **kwargs,
+     ) -> torch.Tensor:
+         with torch.inference_mode():
+             return self._model(pre_processed_images)
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         pre_processing_meta: List[ImageDimensions],
+         **kwargs,
+     ) -> List[torch.Tensor]:
+         target_sizes = [(dim.height, dim.width) for dim in pre_processing_meta]
+         post_processed_outputs = self._processor.post_process_depth_estimation(
+             model_results,
+             target_sizes=target_sizes,
+         )
+         return [
+             output["predicted_depth"].to(self._device)
+             for output in post_processed_outputs
+         ]
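
DepthAnythingV2HF wraps the Hugging Face transformers stack directly: the AutoImageProcessor produces pixel_values, forward runs the depth model under torch.inference_mode, and post_process_depth_estimation rescales predictions back to the original image dimensions recorded during pre-processing. A minimal sketch, assuming a hypothetical local checkpoint directory (with local_files_only=True, the default above, nothing is downloaded):

import numpy as np
import torch

from inference_models.models.depth_anything_v2.depth_anything_v2_hf import (
    DepthAnythingV2HF,
)

# Hypothetical local snapshot of a Depth Anything V2 checkpoint.
model = DepthAnythingV2HF.from_pretrained(
    "./depth-anything-v2-small", device=torch.device("cpu")
)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder HWC image
pixel_values, dimensions = model.pre_process(image)
outputs = model.forward(pixel_values)
depth_maps = model.post_process(outputs, dimensions)
print(depth_maps[0].shape)  # depth map at the original input resolution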

inference_models/models/dinov3/__init__.py
File without changes