inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/common/roboflow/model_packages.py
@@ -0,0 +1,361 @@
+ from collections import namedtuple
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Annotated, List, Literal, Optional, Set, Tuple, Union
+
+ from pydantic import BaseModel, BeforeValidator, Field, ValidationError
+
+ from inference_models.entities import ImageDimensions
+ from inference_models.errors import CorruptedModelPackageError
+ from inference_models.utils.file_system import read_json, stream_file_lines
+
+
+ def parse_class_names_file(class_names_path: str) -> List[str]:
+     try:
+         result = list(stream_file_lines(path=class_names_path))
+         if not result:
+             raise ValueError("Empty class list")
+         return result
+     except (OSError, ValueError) as error:
+         raise CorruptedModelPackageError(
+             message=f"Could not decode the file which is supposed to provide the list of model class names. Error: {error}. "
+             f"If you created the model package manually, please verify its consistency against the docs. If the "
+             f"weights are hosted on the Roboflow platform, contact support.",
+             help_url="https://todo",
+         ) from error
+
+
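Note: a minimal usage sketch for the parser above, assuming a plain-text file with one class name per line; the file name "classes.txt" and the printed values are illustrative, not part of the package.

# Hypothetical usage - "classes.txt" holds one class name per line.
class_names = parse_class_names_file(class_names_path="classes.txt")
print(class_names)  # e.g. ["car", "person", "truck"]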
+ PADDING_VALUES_MAPPING = {
+     "black edges": 0,
+     "grey edges": 127,
+     "white edges": 255,
+ }
+ StaticCropOffset = namedtuple(
+     "StaticCropOffset",
+     [
+         "offset_x",
+         "offset_y",
+         "crop_width",
+         "crop_height",
+     ],
+ )
+ PreProcessingMetadata = namedtuple(
+     "PreProcessingMetadata",
+     [
+         "pad_left",
+         "pad_top",
+         "pad_right",
+         "pad_bottom",
+         "original_size",
+         "size_after_pre_processing",
+         "inference_size",
+         "scale_width",
+         "scale_height",
+         "static_crop_offset",
+     ],
+ )
+
+
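Note: a sketch of how PreProcessingMetadata might be populated for a 640x640 letterbox of a 1280x720 image. The ImageDimensions(height=..., width=...) constructor is an assumption about inference_models.entities, and all values are illustrative.

# Illustrative only: letterboxing 1280x720 into 640x640.
# scale = 640 / 1280 = 0.5 -> resized to 640x360, padded 140 px top and bottom.
metadata = PreProcessingMetadata(
    pad_left=0,
    pad_top=140,
    pad_right=0,
    pad_bottom=140,
    original_size=ImageDimensions(height=720, width=1280),  # assumed signature
    size_after_pre_processing=ImageDimensions(height=720, width=1280),
    inference_size=ImageDimensions(height=640, width=640),
    scale_width=0.5,
    scale_height=0.5,
    static_crop_offset=StaticCropOffset(0, 0, 1280, 720),  # no static crop applied
)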
+ def parse_key_points_metadata(
+     key_points_metadata_path: str,
+ ) -> Tuple[List[List[str]], List[List[Tuple[int, int]]]]:
+     try:
+         parsed_config = read_json(path=key_points_metadata_path)
+         if not isinstance(parsed_config, list):
+             raise ValueError(
+                 "config should contain a list of key points descriptions, one for each instance class"
+             )
+         class_names: List[Optional[List[str]]] = [None] * len(parsed_config)
+         skeletons: List[Optional[List[Tuple[int, int]]]] = [None] * len(parsed_config)
+         for instance_key_point_description in parsed_config:
+             if "object_class_id" not in instance_key_point_description:
+                 raise ValueError(
+                     "instance key point description lacks the 'object_class_id' key"
+                 )
+             object_class_id: int = instance_key_point_description["object_class_id"]
+             if not 0 <= object_class_id < len(class_names):
+                 raise ValueError("`object_class_id` field points to an invalid class")
+             if "keypoints" not in instance_key_point_description:
+                 raise ValueError(
+                     f"`keypoints` field not available in config for class with id {object_class_id}"
+                 )
+             class_names[object_class_id] = _retrieve_key_points_names(
+                 key_points=instance_key_point_description["keypoints"],
+             )
+             key_points_count = len(class_names[object_class_id])
+             if "edges" not in instance_key_point_description:
+                 raise ValueError(
+                     f"`edges` field not available in config for class with id {object_class_id}"
+                 )
+             skeletons[object_class_id] = _retrieve_skeleton(
+                 edges=instance_key_point_description["edges"],
+                 key_points_count=key_points_count,
+             )
+         if any(e is None for e in class_names):
+             raise ValueError(
+                 "config does not provide key points metadata for every instance class"
+             )
+         if any(e is None for e in skeletons):
+             raise ValueError(
+                 "config does not provide a skeleton for every instance class"
+             )
+         return class_names, skeletons
+     except (IOError, OSError, ValueError) as error:
+         raise CorruptedModelPackageError(
+             message=f"Key points config file is malformed: "
+             f"{error}. If the package is "
+             f"hosted on the Roboflow platform, contact support. If you created the model package manually, please "
+             f"verify its consistency against the docs.",
+             help_url="https://todo",
+         ) from error
+
+
+ def _retrieve_key_points_names(key_points: dict) -> List[str]:
+     key_points_dump = sorted(
+         [(int(k), v) for k, v in key_points.items()],
+         key=lambda e: e[0],
+     )
+     return [e[1] for e in key_points_dump]
+
+
+ def _retrieve_skeleton(
+     edges: List[dict], key_points_count: int
+ ) -> List[Tuple[int, int]]:
+     result = []
+     for edge in edges:
+         if not isinstance(edge, dict) or "from" not in edge or "to" not in edge:
+             raise ValueError(
+                 "skeleton edge malformed - invalid format or lack of required keys"
+             )
+         start = edge["from"]
+         end = edge["to"]
+         if not 0 <= start < key_points_count or not 0 <= end < key_points_count:
+             raise ValueError(
+                 "skeleton edge malformed - identifier of skeleton edge end is out of the allowed range "
+                 "determined by the number of key points in the skeleton"
+             )
+         result.append((start, end))
+     return result
+
+
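Note: a hypothetical key-points metadata file and the resulting parse, matching the validation rules above; the file name and class content are illustrative.

# keypoints_metadata.json (hypothetical):
# [{"object_class_id": 0,
#   "keypoints": {"0": "nose", "1": "left_eye", "2": "right_eye"},
#   "edges": [{"from": 0, "to": 1}, {"from": 0, "to": 2}]}]
class_names, skeletons = parse_key_points_metadata(
    key_points_metadata_path="keypoints_metadata.json",
)
# class_names -> [["nose", "left_eye", "right_eye"]]
# skeletons   -> [[(0, 1), (0, 2)]]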
+ @dataclass
+ class TRTConfig:
+     static_batch_size: Optional[int]
+     dynamic_batch_size_min: Optional[int]
+     dynamic_batch_size_opt: Optional[int]
+     dynamic_batch_size_max: Optional[int]
+
+
+ def parse_trt_config(config_path: str) -> TRTConfig:
+     try:
+         parsed_config = read_json(path=config_path)
+         if not isinstance(parsed_config, dict):
+             raise ValueError(
+                 f"Expected config format is dict, found {type(parsed_config)} instead"
+             )
+         config = TRTConfig(
+             static_batch_size=parsed_config.get("static_batch_size"),
+             dynamic_batch_size_min=parsed_config.get("dynamic_batch_size_min"),
+             dynamic_batch_size_opt=parsed_config.get("dynamic_batch_size_opt"),
+             dynamic_batch_size_max=parsed_config.get("dynamic_batch_size_max"),
+         )
+         if config.static_batch_size is not None:
+             if config.static_batch_size <= 0:
+                 raise ValueError(
+                     f"invalid static batch size - {config.static_batch_size}"
+                 )
+             return config
+         if (
+             config.dynamic_batch_size_min is None
+             or config.dynamic_batch_size_opt is None
+             or config.dynamic_batch_size_max is None
+         ):
+             raise ValueError(
+                 "configuration does not provide information about boundaries for dynamic batch size"
+             )
+         if (
+             config.dynamic_batch_size_min <= 0
+             or config.dynamic_batch_size_max < config.dynamic_batch_size_min
+             or config.dynamic_batch_size_opt < config.dynamic_batch_size_min
+             or config.dynamic_batch_size_opt > config.dynamic_batch_size_max
+         ):
+             raise ValueError("invalid dynamic batch size boundaries")
+         return config
+     except (IOError, OSError, ValueError) as error:
+         raise CorruptedModelPackageError(
+             message=f"TRT config file of the model package is malformed: "
+             f"{error}. If the package is "
+             f"hosted on the Roboflow platform, contact support. If you created the model package manually, please "
+             f"verify its consistency against the docs.",
+             help_url="https://todo",
+         ) from error
+
+
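Note: a sketch of the two configuration shapes the parser accepts; the file name and batch sizes are illustrative.

# Static engine:  {"static_batch_size": 8}
# Dynamic engine: {"dynamic_batch_size_min": 1,
#                  "dynamic_batch_size_opt": 8,
#                  "dynamic_batch_size_max": 16}
trt_config = parse_trt_config(config_path="trt_config.json")
if trt_config.static_batch_size is None:
    # Dynamic case: min <= opt <= max is guaranteed by the validation above.
    print(trt_config.dynamic_batch_size_min, trt_config.dynamic_batch_size_max)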
+ class AutoOrient(BaseModel):
+     enabled: bool
+
+
+ class StaticCrop(BaseModel):
+     enabled: bool
+     x_min: int
+     x_max: int
+     y_min: int
+     y_max: int
+
+
+ class ContrastType(str, Enum):
+     ADAPTIVE_EQUALIZATION = "Adaptive Equalization"
+     CONTRAST_STRETCHING = "Contrast Stretching"
+     HISTOGRAM_EQUALIZATION = "Histogram Equalization"
+
+
+ class Contrast(BaseModel):
+     enabled: bool
+     type: ContrastType
+
+
+ class Grayscale(BaseModel):
+     enabled: bool
+
+
+ class ImagePreProcessing(BaseModel):
+     auto_orient: Optional[AutoOrient] = Field(alias="auto-orient", default=None)
+     static_crop: Optional[StaticCrop] = Field(alias="static-crop", default=None)
+     contrast: Optional[Contrast] = Field(default=None)
+     grayscale: Optional[Grayscale] = Field(default=None)
+
+
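Note: a minimal sketch of validating the pre-processing section; the hyphenated JSON keys come from the field aliases above, and the values are illustrative.

pre_processing = ImagePreProcessing.model_validate(
    {
        "auto-orient": {"enabled": True},
        "static-crop": {"enabled": True, "x_min": 10, "x_max": 90, "y_min": 10, "y_max": 90},
        "contrast": {"enabled": False, "type": "Contrast Stretching"},
    }
)
assert pre_processing.static_crop.x_max == 90
assert pre_processing.grayscale is None  # omitted sections default to None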
+ class TrainingInputSize(BaseModel):
+     height: int
+     width: int
+
+
+ class DivisiblePadding(BaseModel):
+     type: Literal["pad-to-be-divisible"]
+     value: int
+
+
+ class AnySizePadding(BaseModel):
+     type: Literal["any-size"]
+
+
+ class ColorMode(str, Enum):
+     BGR = "bgr"
+     RGB = "rgb"
+
+
+ class ResizeMode(str, Enum):
+     STRETCH_TO = "stretch"
+     LETTERBOX = "letterbox"
+     CENTER_CROP = "center-crop"
+     FIT_LONGER_EDGE = "fit-longer-edge"
+     LETTERBOX_REFLECT_EDGES = "letterbox-reflect-edges"
+
+
+ Number = Union[int, float]
+
+
+ class NetworkInputDefinition(BaseModel):
+     training_input_size: TrainingInputSize
+     dynamic_spatial_size_supported: bool
+     dynamic_spatial_size_mode: Optional[Union[DivisiblePadding, AnySizePadding]] = (
+         Field(discriminator="type", default=None)
+     )
+     color_mode: ColorMode
+     resize_mode: ResizeMode
+     padding_value: Optional[int] = Field(default=None)
+     input_channels: int
+     scaling_factor: Optional[Number] = Field(default=None)
+     normalization: Optional[Tuple[List[Number], List[Number]]] = Field(default=None)
+
+
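Note: a sketch of a network_input section; the "type" discriminator selects DivisiblePadding here, and all values are illustrative.

network_input = NetworkInputDefinition.model_validate(
    {
        "training_input_size": {"height": 640, "width": 640},
        "dynamic_spatial_size_supported": True,
        "dynamic_spatial_size_mode": {"type": "pad-to-be-divisible", "value": 32},
        "color_mode": "rgb",
        "resize_mode": "letterbox",
        "padding_value": 114,
        "input_channels": 3,
        "scaling_factor": 255,
        "normalization": [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]],
    }
)
assert isinstance(network_input.dynamic_spatial_size_mode, DivisiblePadding)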
+ class ForwardPassConfiguration(BaseModel):
+     static_batch_size: Optional[int] = Field(default=None)
+     max_dynamic_batch_size: Optional[int] = Field(default=None)
+
+
+ class FusedNMSParameters(BaseModel):
+     max_detections: int
+     confidence_threshold: float
+     iou_threshold: float
+     class_agnostic: int
+
+
+ class NMSPostProcessing(BaseModel):
+     type: Literal["nms"]
+     fused: bool
+     nms_parameters: Optional[FusedNMSParameters] = Field(default=None)
+
+
+ class SigmoidPostProcessing(BaseModel):
+     type: Literal["sigmoid"]
+     fused: bool
+
+
+ class SoftMaxPostProcessing(BaseModel):
+     type: Literal["softmax"]
+     fused: bool
+
+
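Note: a sketch of a post-processing entry for an engine with NMS fused into the graph; the parameter values are illustrative.

post_processing = NMSPostProcessing.model_validate(
    {
        "type": "nms",
        "fused": True,
        "nms_parameters": {
            "max_detections": 300,
            "confidence_threshold": 0.25,
            "iou_threshold": 0.45,
            "class_agnostic": 0,  # integer flag, per FusedNMSParameters above
        },
    }
)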
+ ImagePreProcessingValidator = BeforeValidator(
+     lambda value: value if value is not None else ImagePreProcessing()
+ )
+
+
+ class ClassNameRemoval(BaseModel):
+     type: Literal["class_name_removal"]
+     class_name: str
+
+
+ class InferenceConfig(BaseModel):
+     image_pre_processing: Annotated[ImagePreProcessing, ImagePreProcessingValidator] = (
+         Field(default_factory=lambda: ImagePreProcessing())
+     )
+     network_input: NetworkInputDefinition
+     forward_pass: ForwardPassConfiguration = Field(
+         default_factory=lambda: ForwardPassConfiguration()
+     )
+     post_processing: Optional[
+         Union[NMSPostProcessing, SoftMaxPostProcessing, SigmoidPostProcessing]
+     ] = Field(default=None, discriminator="type")
+     model_initialization: Optional[dict] = Field(default=None)
+     class_names_operations: Optional[
+         List[Annotated[Union[ClassNameRemoval], Field(discriminator="type")]]
+     ] = Field(default=None)
+
+
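Note: a minimal sketch showing how the BeforeValidator above coerces an explicit null pre-processing section into defaults; the network_input values are illustrative.

config = InferenceConfig.model_validate(
    {
        "image_pre_processing": None,  # turned into ImagePreProcessing() by the validator
        "network_input": {
            "training_input_size": {"height": 640, "width": 640},
            "dynamic_spatial_size_supported": False,
            "color_mode": "bgr",
            "resize_mode": "stretch",
            "input_channels": 3,
        },
    }
)
assert isinstance(config.image_pre_processing, ImagePreProcessing)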
+ def parse_inference_config(
+     config_path: str,
+     allowed_resize_modes: Set[ResizeMode],
+ ) -> InferenceConfig:
+     try:
+         decoded_config = read_json(path=config_path)
+         if not isinstance(decoded_config, dict):
+             raise ValueError(
+                 f"Expected config format is dict, found {type(decoded_config)} instead"
+             )
+     except (IOError, OSError, ValueError) as error:
+         raise CorruptedModelPackageError(
+             message=f"Inference config file of the model package is malformed: "
+             f"{error}. If the package is "
+             f"hosted on the Roboflow platform, contact support. If you created the model package manually, please "
+             f"verify its consistency against the docs.",
+             help_url="https://todo",
+         ) from error
+     try:
+         parsed_config = InferenceConfig.model_validate(decoded_config)
+     except ValidationError as error:
+         raise CorruptedModelPackageError(
+             message="Could not parse the inference config from the model package.",
+             help_url="https://todo",
+         ) from error
+     if parsed_config.network_input.resize_mode not in allowed_resize_modes:
+         allowed_resize_modes_str = ", ".join(e.value for e in allowed_resize_modes)
+         raise CorruptedModelPackageError(
+             message=f"Inference configuration shipped with the model package defines input resize mode "
+             f"{parsed_config.network_input.resize_mode.value}, which is not supported by the model "
+             f"implementation. The allowed values are: "
+             f"{allowed_resize_modes_str}.",
+             help_url="https://todo",
+         )
+     return parsed_config
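Note: an end-to-end sketch of loading a config from disk; the file name and allowed resize modes are illustrative.

# inference_config.json (hypothetical) combines the sections sketched above:
# {"image_pre_processing": null,
#  "network_input": {...},
#  "post_processing": {"type": "softmax", "fused": true}}
config = parse_inference_config(
    config_path="inference_config.json",
    allowed_resize_modes={ResizeMode.LETTERBOX, ResizeMode.STRETCH_TO},
)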