inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py
@@ -0,0 +1,201 @@
+from threading import Lock
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torchvision  # DO NOT REMOVE, THIS IMPORT ENABLES NMS OPERATION
+
+from inference_models import InstanceDetections, InstanceSegmentationModel
+from inference_models.configuration import DEFAULT_DEVICE
+from inference_models.entities import ColorFormat
+from inference_models.errors import CorruptedModelPackageError
+from inference_models.models.common.model_packages import get_model_package_contents
+from inference_models.models.common.roboflow.model_packages import (
+    InferenceConfig,
+    PreProcessingMetadata,
+    ResizeMode,
+    parse_class_names_file,
+    parse_inference_config,
+)
+from inference_models.models.common.roboflow.post_processing import (
+    align_instance_segmentation_results,
+    crop_masks_to_boxes,
+    post_process_nms_fused_model_output,
+    preprocess_segmentation_masks,
+    run_nms_for_instance_segmentation,
+)
+from inference_models.models.common.roboflow.pre_processing import (
+    pre_process_network_input,
+)
+from inference_models.models.common.torch import generate_batch_chunks
+
+
+class YOLOv8ForInstanceSegmentationTorchScript(
+    InstanceSegmentationModel[
+        torch.Tensor, PreProcessingMetadata, Tuple[torch.Tensor, torch.Tensor]
+    ]
+):
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        device: torch.device = DEFAULT_DEVICE,
+        **kwargs,
+    ) -> "YOLOv8ForInstanceSegmentationTorchScript":
+        model_package_content = get_model_package_contents(
+            model_package_dir=model_name_or_path,
+            elements=[
+                "class_names.txt",
+                "inference_config.json",
+                "weights.torchscript",
+            ],
+        )
+        class_names = parse_class_names_file(
+            class_names_path=model_package_content["class_names.txt"]
+        )
+        inference_config = parse_inference_config(
+            config_path=model_package_content["inference_config.json"],
+            allowed_resize_modes={
+                ResizeMode.STRETCH_TO,
+                ResizeMode.LETTERBOX,
+                ResizeMode.CENTER_CROP,
+                ResizeMode.LETTERBOX_REFLECT_EDGES,
+            },
+        )
+        if inference_config.post_processing.type != "nms":
+            raise CorruptedModelPackageError(
+                message="Expected NMS to be the post-processing",
+                help_url="https://todo",
+            )
+        if inference_config.forward_pass.static_batch_size is None:
+            raise CorruptedModelPackageError(
+                message="Expected static batch size to be registered in the inference configuration.",
+                help_url="https://todo",
+            )
+        model = torch.jit.load(
+            model_package_content["weights.torchscript"], map_location=device
+        ).eval()
+        return cls(
+            model=model,
+            class_names=class_names,
+            inference_config=inference_config,
+            device=device,
+        )
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        inference_config: InferenceConfig,
+        class_names: List[str],
+        device: torch.device,
+    ):
+        self._model = model
+        self._inference_config = inference_config
+        self._class_names = class_names
+        self._device = device
+        self._session_thread_lock = Lock()
+
+    @property
+    def class_names(self) -> List[str]:
+        return self._class_names
+
+    def pre_process(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
+        return pre_process_network_input(
+            images=images,
+            image_pre_processing=self._inference_config.image_pre_processing,
+            network_input=self._inference_config.network_input,
+            target_device=self._device,
+            input_color_format=input_color_format,
+        )
+
+    def forward(
+        self, pre_processed_images: torch.Tensor, **kwargs
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        with torch.inference_mode():
+            if (
+                pre_processed_images.shape[0]
+                == self._inference_config.forward_pass.static_batch_size
+            ):
+                instances, protos = self._model(pre_processed_images)
+                return instances.to(self._device), protos.to(self._device)
+            instances, protos = [], []
+            for input_tensor, padding_size in generate_batch_chunks(
+                input_batch=pre_processed_images,
+                chunk_size=self._inference_config.forward_pass.static_batch_size,
+            ):
+                instances_for_chunk, protos_for_chunk = self._model(input_tensor)
+                if padding_size > 0:
+                    instances_for_chunk = instances_for_chunk[:-padding_size]
+                    protos_for_chunk = protos_for_chunk[:-padding_size]
+                instances.append(instances_for_chunk)
+                protos.append(protos_for_chunk)
+            return torch.cat(instances, dim=0).to(self._device), torch.cat(
+                protos, dim=0
+            ).to(self._device)
+
+    def post_process(
+        self,
+        model_results: Tuple[torch.Tensor, torch.Tensor],
+        pre_processing_meta: List[PreProcessingMetadata],
+        conf_thresh: float = 0.25,
+        iou_thresh: float = 0.45,
+        max_detections: int = 100,
+        class_agnostic: bool = False,
+        **kwargs,
+    ) -> List[InstanceDetections]:
+        instances, protos = model_results
+        if self._inference_config.post_processing.fused:
+            nms_results = post_process_nms_fused_model_output(
+                output=instances, conf_thresh=conf_thresh
+            )
+        else:
+            nms_results = run_nms_for_instance_segmentation(
+                output=instances,
+                conf_thresh=conf_thresh,
+                iou_thresh=iou_thresh,
+                max_detections=max_detections,
+                class_agnostic=class_agnostic,
+            )
+        final_results = []
+        for image_bboxes, image_protos, image_meta in zip(
+            nms_results, protos, pre_processing_meta
+        ):
+            pre_processed_masks = preprocess_segmentation_masks(
+                protos=image_protos,
+                masks_in=image_bboxes[:, 6:],
+            )
+            cropped_masks = crop_masks_to_boxes(
+                image_bboxes[:, :4], pre_processed_masks
+            )
+            padding = (
+                image_meta.pad_left,
+                image_meta.pad_top,
+                image_meta.pad_right,
+                image_meta.pad_bottom,
+            )
+            aligned_boxes, aligned_masks = align_instance_segmentation_results(
+                image_bboxes=image_bboxes,
+                masks=cropped_masks,
+                padding=padding,
+                scale_height=image_meta.scale_height,
+                scale_width=image_meta.scale_width,
+                original_size=image_meta.original_size,
+                size_after_pre_processing=image_meta.size_after_pre_processing,
+                inference_size=image_meta.inference_size,
+                static_crop_offset=image_meta.static_crop_offset,
+            )
+            final_results.append(
+                InstanceDetections(
+                    xyxy=aligned_boxes[:, :4].round().int(),
+                    class_id=aligned_boxes[:, 5].int(),
+                    confidence=aligned_boxes[:, 4],
+                    mask=aligned_masks,
+                )
+            )
+        return final_results
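
The TorchScript class above exposes the package's three-stage model API: pre_process, forward, post_process. Below is a minimal usage sketch, not part of the wheel itself: the package directory and image paths are hypothetical, and only methods and fields visible in the hunk are assumed.

    import cv2

    from inference_models.models.yolov8.yolov8_instance_segmentation_torch_script import (
        YOLOv8ForInstanceSegmentationTorchScript,
    )

    # Hypothetical model package directory holding class_names.txt,
    # inference_config.json and weights.torchscript (see from_pretrained above).
    model = YOLOv8ForInstanceSegmentationTorchScript.from_pretrained(
        "/path/to/model-package"
    )

    image = cv2.imread("example.jpg")  # numpy array input is accepted by pre_process
    tensor, metadata = model.pre_process(image)  # resize + move to target device
    raw_output = model.forward(tensor)           # (instances, protos) tuple
    detections = model.post_process(
        raw_output, metadata, conf_thresh=0.25, iou_thresh=0.45
    )
    print(detections[0].xyxy, detections[0].class_id)

Note the chunked path in forward: when the input batch does not match the engine's static batch size, the model is run over fixed-size chunks and any padded rows are trimmed before concatenation.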
inference_models/models/yolov8/yolov8_instance_segmentation_trt.py
@@ -0,0 +1,268 @@
+from threading import Lock
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from inference_models import InstanceDetections, InstanceSegmentationModel
+from inference_models.configuration import DEFAULT_DEVICE
+from inference_models.entities import ColorFormat
+from inference_models.errors import (
+    CorruptedModelPackageError,
+    MissingDependencyError,
+    ModelRuntimeError,
+)
+from inference_models.models.common.cuda import (
+    use_cuda_context,
+    use_primary_cuda_context,
+)
+from inference_models.models.common.model_packages import get_model_package_contents
+from inference_models.models.common.roboflow.model_packages import (
+    InferenceConfig,
+    PreProcessingMetadata,
+    ResizeMode,
+    TRTConfig,
+    parse_class_names_file,
+    parse_inference_config,
+    parse_trt_config,
+)
+from inference_models.models.common.roboflow.post_processing import (
+    align_instance_segmentation_results,
+    crop_masks_to_boxes,
+    post_process_nms_fused_model_output,
+    preprocess_segmentation_masks,
+    run_nms_for_instance_segmentation,
+)
+from inference_models.models.common.roboflow.pre_processing import (
+    pre_process_network_input,
+)
+from inference_models.models.common.trt import (
+    get_engine_inputs_and_outputs,
+    infer_from_trt_engine,
+    load_model,
+)
+
+try:
+    import tensorrt as trt
+except ImportError as import_error:
+    raise MissingDependencyError(
+        message=f"Could not import YOLOv8 model with TRT backend - this error means that some additional dependencies "
+        f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
+        f"program, make sure the following extras of the package are installed: `trt10` - installation can only "
+        f"succeed on Linux and Windows machines with CUDA 12 installed. Jetson devices should have TRT 10.x "
+        f"installed for all builds with Jetpack 6. "
+        f"If you see this error using Roboflow infrastructure, make sure the service you use supports the model. "
+        f"You can also contact Roboflow to get support.",
+        help_url="https://todo",
+    ) from import_error
+
+try:
+    import pycuda.driver as cuda
+except ImportError as import_error:
+    raise MissingDependencyError(
+        message="TODO",
+        help_url="https://todo",
+    ) from import_error
+
+
+class YOLOv8ForInstanceSegmentationTRT(
+    InstanceSegmentationModel[
+        torch.Tensor, PreProcessingMetadata, Tuple[torch.Tensor, torch.Tensor]
+    ]
+):
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        device: torch.device = DEFAULT_DEVICE,
+        engine_host_code_allowed: bool = False,
+        **kwargs,
+    ) -> "YOLOv8ForInstanceSegmentationTRT":
+        if device.type != "cuda":
+            raise ModelRuntimeError(
+                message=f"TRT engine only runs on CUDA device - {device} device detected.",
+                help_url="https://todo",
+            )
+        model_package_content = get_model_package_contents(
+            model_package_dir=model_name_or_path,
+            elements=[
+                "class_names.txt",
+                "inference_config.json",
+                "trt_config.json",
+                "engine.plan",
+            ],
+        )
+        class_names = parse_class_names_file(
+            class_names_path=model_package_content["class_names.txt"]
+        )
+        inference_config = parse_inference_config(
+            config_path=model_package_content["inference_config.json"],
+            allowed_resize_modes={
+                ResizeMode.STRETCH_TO,
+                ResizeMode.LETTERBOX,
+                ResizeMode.CENTER_CROP,
+                ResizeMode.LETTERBOX_REFLECT_EDGES,
+            },
+        )
+        if inference_config.post_processing.type != "nms":
+            raise CorruptedModelPackageError(
+                message="Expected NMS to be the post-processing",
+                help_url="https://todo",
+            )
+        trt_config = parse_trt_config(
+            config_path=model_package_content["trt_config.json"]
+        )
+        cuda.init()
+        cuda_device = cuda.Device(device.index or 0)
+        with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
+            engine = load_model(
+                model_path=model_package_content["engine.plan"],
+                engine_host_code_allowed=engine_host_code_allowed,
+            )
+            execution_context = engine.create_execution_context()
+        inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
+        if len(inputs) != 1:
+            raise CorruptedModelPackageError(
+                message=f"Implementation assumes a single model input, found: {len(inputs)}.",
+                help_url="https://todo",
+            )
+        if len(outputs) != 2:
+            raise CorruptedModelPackageError(
+                message=f"Implementation assumes 2 model outputs, found: {len(outputs)}.",
+                help_url="https://todo",
+            )
+        if "output0" not in outputs or "output1" not in outputs:
+            raise CorruptedModelPackageError(
+                message=f"Expected model outputs to be named `output0` and `output1`, but found: {outputs}.",
+                help_url="https://todo",
+            )
+        return cls(
+            engine=engine,
+            input_name=inputs[0],
+            output_names=["output0", "output1"],
+            class_names=class_names,
+            inference_config=inference_config,
+            trt_config=trt_config,
+            device=device,
+            execution_context=execution_context,
+            cuda_context=cuda_context,
+        )
+
+    def __init__(
+        self,
+        engine: trt.ICudaEngine,
+        input_name: str,
+        output_names: List[str],
+        class_names: List[str],
+        inference_config: InferenceConfig,
+        trt_config: TRTConfig,
+        device: torch.device,
+        cuda_context: cuda.Context,
+        execution_context: trt.IExecutionContext,
+    ):
+        self._engine = engine
+        self._input_name = input_name
+        self._output_names = output_names
+        self._class_names = class_names
+        self._inference_config = inference_config
+        self._trt_config = trt_config
+        self._device = device
+        self._cuda_context = cuda_context
+        self._execution_context = execution_context
+        self._session_thread_lock = Lock()
+
+    @property
+    def class_names(self) -> List[str]:
+        return self._class_names
+
+    def pre_process(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
+        return pre_process_network_input(
+            images=images,
+            image_pre_processing=self._inference_config.image_pre_processing,
+            network_input=self._inference_config.network_input,
+            target_device=self._device,
+            input_color_format=input_color_format,
+        )
+
+    def forward(
+        self, pre_processed_images: torch.Tensor, **kwargs
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        with self._session_thread_lock:
+            with use_cuda_context(context=self._cuda_context):
+                instances, protos = infer_from_trt_engine(
+                    pre_processed_images=pre_processed_images,
+                    trt_config=self._trt_config,
+                    engine=self._engine,
+                    context=self._execution_context,
+                    device=self._device,
+                    input_name=self._input_name,
+                    outputs=self._output_names,
+                )
+                return instances, protos
+
+    def post_process(
+        self,
+        model_results: Tuple[torch.Tensor, torch.Tensor],
+        pre_processing_meta: List[PreProcessingMetadata],
+        conf_thresh: float = 0.25,
+        iou_thresh: float = 0.45,
+        max_detections: int = 100,
+        class_agnostic: bool = False,
+        **kwargs,
+    ) -> List[InstanceDetections]:
+        instances, protos = model_results
+        if self._inference_config.post_processing.fused:
+            nms_results = post_process_nms_fused_model_output(
+                output=instances, conf_thresh=conf_thresh
+            )
+        else:
+            nms_results = run_nms_for_instance_segmentation(
+                output=instances,
+                conf_thresh=conf_thresh,
+                iou_thresh=iou_thresh,
+                max_detections=max_detections,
+                class_agnostic=class_agnostic,
+            )
+        final_results = []
+        for image_bboxes, image_protos, image_meta in zip(
+            nms_results, protos, pre_processing_meta
+        ):
+            pre_processed_masks = preprocess_segmentation_masks(
+                protos=image_protos,
+                masks_in=image_bboxes[:, 6:],
+            )
+            cropped_masks = crop_masks_to_boxes(
+                image_bboxes[:, :4], pre_processed_masks
+            )
+            padding = (
+                image_meta.pad_left,
+                image_meta.pad_top,
+                image_meta.pad_right,
+                image_meta.pad_bottom,
+            )
+            aligned_boxes, aligned_masks = align_instance_segmentation_results(
+                image_bboxes=image_bboxes,
+                masks=cropped_masks,
+                padding=padding,
+                scale_height=image_meta.scale_height,
+                scale_width=image_meta.scale_width,
+                original_size=image_meta.original_size,
+                size_after_pre_processing=image_meta.size_after_pre_processing,
+                inference_size=image_meta.inference_size,
+                static_crop_offset=image_meta.static_crop_offset,
+            )
+            final_results.append(
+                InstanceDetections(
+                    xyxy=aligned_boxes[:, :4].round().int(),
+                    class_id=aligned_boxes[:, 5].int(),
+                    confidence=aligned_boxes[:, 4],
+                    mask=aligned_masks,
+                )
+            )
+        return final_results
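
An analogous hedged sketch for the TRT variant: from_pretrained rejects non-CUDA devices, and forward serializes engine calls behind the session lock while entering the stored CUDA context. Paths below are illustrative, not part of the package.

    import cv2
    import torch

    from inference_models.models.yolov8.yolov8_instance_segmentation_trt import (
        YOLOv8ForInstanceSegmentationTRT,
    )

    # Hypothetical package directory with class_names.txt, inference_config.json,
    # trt_config.json and engine.plan; a CUDA device is mandatory for this backend.
    model = YOLOv8ForInstanceSegmentationTRT.from_pretrained(
        "/path/to/trt-model-package",
        device=torch.device("cuda:0"),
    )

    images = [cv2.imread("a.jpg"), cv2.imread("b.jpg")]  # list input is accepted
    tensors, metadata = model.pre_process(images)
    raw_output = model.forward(tensors)  # runs the TRT engine under the lock
    detections = model.post_process(raw_output, metadata, max_detections=100)
    for image_detections in detections:
        print(image_detections.class_id, image_detections.confidence)

Post-processing is identical to the TorchScript variant: NMS (or the fused-NMS path), prototype-mask expansion, cropping masks to boxes, then realigning boxes and masks to the original image geometry recorded in PreProcessingMetadata.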