inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,361 @@
|
|
|
1
|
+
from collections import namedtuple
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Annotated, List, Literal, Optional, Set, Tuple, Union
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, BeforeValidator, Field, ValidationError
|
|
7
|
+
|
|
8
|
+
from inference_models.entities import ImageDimensions
|
|
9
|
+
from inference_models.errors import CorruptedModelPackageError
|
|
10
|
+
from inference_models.utils.file_system import read_json, stream_file_lines
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def parse_class_names_file(class_names_path: str) -> List[str]:
    """Read the list of model class names from ``class_names_path``.

    Lines are taken as yielded by ``stream_file_lines`` and returned in file order.

    Args:
        class_names_path: Path to the file listing the model class names.

    Returns:
        Non-empty list of class names.

    Raises:
        CorruptedModelPackageError: If the file cannot be read or decoded, or
            yields no class names at all.
    """
    try:
        result = list(stream_file_lines(path=class_names_path))
        if not result:
            raise ValueError("Empty class list")
        return result
    except (OSError, ValueError) as error:
        # UnicodeDecodeError is a ValueError subclass, so decoding failures land here too.
        raise CorruptedModelPackageError(
            message=f"Could not decode file which is supposed to provide list of model class names. "
            f"Error: {error}. If you created model package manually, please verify its consistency "
            "in docs. In case that the weights are hosted on the Roboflow platform - contact support.",
            help_url="https://todo",
        ) from error
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Lookup from a padding-style label to an integer value (0, 127 or 255).
# NOTE(review): presumably the fill value used for padded image borders —
# confirm against the pre-processing code that consumes this mapping.
PADDING_VALUES_MAPPING = dict(
    (
        ("black edges", 0),
        ("grey edges", 127),
        ("white edges", 255),
    )
)
|
|
33
|
+
# Immutable record describing a static crop: the crop origin (offset_x, offset_y)
# and the crop dimensions (crop_width, crop_height).
StaticCropOffset = namedtuple(
    "StaticCropOffset", "offset_x offset_y crop_width crop_height"
)
|
|
42
|
+
# Immutable record of what image pre-processing did to an input, as a tuple of
# padding amounts, image sizes, scale factors and the static-crop offset.
# NOTE(review): field semantics (e.g. which size each *_size field refers to)
# are inferred from names only — confirm with the pre-processing code.
PreProcessingMetadata = namedtuple(
    "PreProcessingMetadata",
    "pad_left pad_top pad_right pad_bottom "
    "original_size size_after_pre_processing inference_size "
    "scale_width scale_height "
    "static_crop_offset",
)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def parse_key_points_metadata(
    key_points_metadata_path: str,
) -> Tuple[List[List[str]], List[List[Tuple[int, int]]]]:
    """Parse the key points metadata file shipped with a model package.

    The file must decode to a JSON list. Each element must be a dict with:

    * ``object_class_id``: an integer index into that list (each used exactly once),
    * ``keypoints``: a dict mapping key point id to key point name,
    * ``edges``: skeleton edges, each a dict with ``from`` / ``to`` key point indices.

    Args:
        key_points_metadata_path: Path to the JSON metadata file.

    Returns:
        ``(class_names, skeletons)``. ``class_names[i]`` holds the key point names of
        class ``i``, ordered by key point id. ``skeletons[i]`` holds the ``(from, to)``
        edges of that class.

    Raises:
        CorruptedModelPackageError: If the file cannot be read or its content is malformed.
    """
    try:
        parsed_config = read_json(path=key_points_metadata_path)
        if not isinstance(parsed_config, list):
            raise ValueError(
                "config should contain list of key points descriptions for each instance"
            )
        class_names: List[Optional[List[str]]] = [None] * len(parsed_config)
        skeletons: List[Optional[List[Tuple[int, int]]]] = [None] * len(parsed_config)
        for instance_key_point_description in parsed_config:
            if not isinstance(instance_key_point_description, dict):
                raise ValueError(
                    "instance key point description should be a dictionary, "
                    f"found {type(instance_key_point_description)} instead"
                )
            if "object_class_id" not in instance_key_point_description:
                raise ValueError(
                    "instance key point description lack 'object_class_id' key"
                )
            object_class_id = instance_key_point_description["object_class_id"]
            # bool subclasses int, but JSON true/false is not a meaningful class id.
            if isinstance(object_class_id, bool) or not isinstance(object_class_id, int):
                raise ValueError(
                    f"`object_class_id` field must be an integer, found {object_class_id!r}"
                )
            if not 0 <= object_class_id < len(class_names):
                raise ValueError("`object_class_id` field point invalid class")
            if class_names[object_class_id] is not None:
                raise ValueError(
                    f"config describes key points for class with id {object_class_id} more than once"
                )
            if "keypoints" not in instance_key_point_description:
                raise ValueError(
                    f"`keypoints` field not available in config for class with id {object_class_id}"
                )
            key_points = instance_key_point_description["keypoints"]
            if not isinstance(key_points, dict):
                raise ValueError(
                    f"`keypoints` field for class with id {object_class_id} should be a dictionary"
                )
            class_names[object_class_id] = _retrieve_key_points_names(
                key_points=key_points,
            )
            key_points_count = len(class_names[object_class_id])
            if "edges" not in instance_key_point_description:
                raise ValueError(
                    f"`edges` field not available in config for class with id {object_class_id}"
                )
            skeletons[object_class_id] = _retrieve_skeleton(
                edges=instance_key_point_description["edges"],
                key_points_count=key_points_count,
            )
        if any(e is None for e in class_names):
            raise ValueError(
                "config does not provide metadata describing each instance key points"
            )
        if any(e is None for e in skeletons):
            raise ValueError(
                "config does not provide metadata describing each instance skeleton"
            )
        return class_names, skeletons
    except (OSError, ValueError, TypeError) as error:
        # TypeError covers remaining type mismatches in the file content, e.g. a
        # non-iterable `edges` value or non-numeric edge ends.
        raise CorruptedModelPackageError(
            message=f"Key points config file is malformed: "
            f"{error}. In case that the package is "
            f"hosted on the Roboflow platform - contact support. If you created model package manually, please "
            f"verify its consistency in docs.",
            help_url="https://todo",
        ) from error
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _retrieve_key_points_names(key_points: dict) -> List[str]:
|
|
114
|
+
key_points_dump = sorted(
|
|
115
|
+
[(int(k), v) for k, v in key_points.items()],
|
|
116
|
+
key=lambda e: e[0],
|
|
117
|
+
)
|
|
118
|
+
return [e[1] for e in key_points_dump]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _retrieve_skeleton(
|
|
122
|
+
edges: List[dict], key_points_count: int
|
|
123
|
+
) -> List[Tuple[int, int]]:
|
|
124
|
+
result = []
|
|
125
|
+
for edge in edges:
|
|
126
|
+
if not isinstance(edge, dict) or "from" not in edge or "to" not in edge:
|
|
127
|
+
raise ValueError(
|
|
128
|
+
"skeleton edge malformed - invalid format or lack of required keys"
|
|
129
|
+
)
|
|
130
|
+
start = edge["from"]
|
|
131
|
+
end = edge["to"]
|
|
132
|
+
if not 0 <= start < key_points_count or not 0 <= end < key_points_count:
|
|
133
|
+
raise ValueError(
|
|
134
|
+
"skeleton edge malformed - identifier of skeleton edge end is out of allowed range determined by "
|
|
135
|
+
"the number of key points in the skeleton"
|
|
136
|
+
)
|
|
137
|
+
result.append((edge["from"], edge["to"]))
|
|
138
|
+
return result
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@dataclass
class TRTConfig:
    """Batch-size settings read from a model package's TRT config file.

    ``parse_trt_config`` accepts either a positive ``static_batch_size``, or all
    three ``dynamic_batch_size_*`` values forming a valid min <= opt <= max range.
    """

    # Fixed batch size, when configured.
    static_batch_size: Optional[int]
    # Lower bound, optimal value and upper bound of a dynamic batch size.
    dynamic_batch_size_min: Optional[int]
    dynamic_batch_size_opt: Optional[int]
    dynamic_batch_size_max: Optional[int]
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def parse_trt_config(config_path: str) -> TRTConfig:
    """Load and validate the TRT batch-size configuration of a model package.

    The file must decode to a JSON object.

    * If ``static_batch_size`` is present, it must be a positive integer. The
      dynamic bounds are then ignored.
    * Otherwise ``dynamic_batch_size_min``, ``dynamic_batch_size_opt`` and
      ``dynamic_batch_size_max`` must all be present integers satisfying
      ``0 < min <= opt <= max``.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        The validated ``TRTConfig``.

    Raises:
        CorruptedModelPackageError: If the file cannot be read or its content is invalid.
    """

    def _ensure_integer(name: str, value: object) -> None:
        # JSON true/false decode to bool, which subclasses int - reject them explicitly,
        # as well as floats/strings that would otherwise break (or skew) the comparisons below.
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValueError(f"`{name}` must be an integer, found {value!r}")

    try:
        parsed_config = read_json(path=config_path)
        if not isinstance(parsed_config, dict):
            raise ValueError(
                f"Expected config format is dict, found {type(parsed_config)} instead"
            )
        config = TRTConfig(
            static_batch_size=parsed_config.get("static_batch_size"),
            dynamic_batch_size_min=parsed_config.get("dynamic_batch_size_min"),
            dynamic_batch_size_opt=parsed_config.get("dynamic_batch_size_opt"),
            dynamic_batch_size_max=parsed_config.get("dynamic_batch_size_max"),
        )
        if config.static_batch_size is not None:
            _ensure_integer("static_batch_size", config.static_batch_size)
            if config.static_batch_size <= 0:
                raise ValueError(
                    f"invalid static batch size - {config.static_batch_size}"
                )
            return config
        if (
            config.dynamic_batch_size_min is None
            or config.dynamic_batch_size_opt is None
            or config.dynamic_batch_size_max is None
        ):
            raise ValueError(
                "configuration does not provide information about boundaries for dynamic batch size"
            )
        _ensure_integer("dynamic_batch_size_min", config.dynamic_batch_size_min)
        _ensure_integer("dynamic_batch_size_opt", config.dynamic_batch_size_opt)
        _ensure_integer("dynamic_batch_size_max", config.dynamic_batch_size_max)
        if not (
            0
            < config.dynamic_batch_size_min
            <= config.dynamic_batch_size_opt
            <= config.dynamic_batch_size_max
        ):
            raise ValueError(
                f"invalid dynamic batch size - min: {config.dynamic_batch_size_min}, "
                f"opt: {config.dynamic_batch_size_opt}, max: {config.dynamic_batch_size_max}"
            )
        return config
    except (OSError, ValueError) as error:
        raise CorruptedModelPackageError(
            message=f"TRT config file of the model package is malformed: "
            f"{error}. In case that the package is "
            f"hosted on the Roboflow platform - contact support. If you created model package manually, please "
            f"verify its consistency in docs.",
            help_url="https://todo",
        ) from error
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# On/off switch for the "auto-orient" image pre-processing step.
class AutoOrient(BaseModel):
    # whether the step is active
    enabled: bool
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
# "static-crop" pre-processing step: a fixed crop window given by x/y bounds.
# NOTE(review): units of the bounds (pixels vs. percent of image size) are not
# visible here - confirm against the pre-processing implementation.
class StaticCrop(BaseModel):
    # whether the step is active
    enabled: bool
    # horizontal bounds of the crop window
    x_min: int
    x_max: int
    # vertical bounds of the crop window
    y_min: int
    y_max: int
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# Contrast-enhancement variants; member values are the labels stored in configs.
class ContrastType(str, Enum):
    ADAPTIVE_EQUALIZATION = "Adaptive Equalization"  # first member
    CONTRAST_STRETCHING = "Contrast Stretching"
    HISTOGRAM_EQUALIZATION = "Histogram Equalization"
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# Contrast pre-processing step: an on/off switch plus the enhancement variant.
class Contrast(BaseModel):
    # whether the step is active
    enabled: bool
    # which ContrastType to apply
    type: ContrastType
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
# On/off switch for the grayscale image pre-processing step.
class Grayscale(BaseModel):
    # whether the step is active
    enabled: bool
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# Collection of optional image pre-processing steps declared by the model package.
# Every step may be absent (None). The JSON keys of the first two use dashes,
# hence the aliases.
class ImagePreProcessing(BaseModel):
    auto_orient: Optional[AutoOrient] = Field(default=None, alias="auto-orient")
    static_crop: Optional[StaticCrop] = Field(default=None, alias="static-crop")
    contrast: Optional[Contrast] = None
    grayscale: Optional[Grayscale] = None
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# Spatial input size (height x width) recorded for the trained network.
class TrainingInputSize(BaseModel):
    # declared before width - field order is preserved in dumps
    height: int
    width: int
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# Dynamic-size mode tagged "pad-to-be-divisible". `value` holds the divisor.
# NOTE(review): presumably the input is padded so that each side is a multiple of `value`.
class DivisiblePadding(BaseModel):
    # discriminator tag used by NetworkInputDefinition.dynamic_spatial_size_mode
    type: Literal["pad-to-be-divisible"]
    value: int
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
# Dynamic-size mode tagged "any-size". It carries no parameters besides its tag.
class AnySizePadding(BaseModel):
    # discriminator tag used by NetworkInputDefinition.dynamic_spatial_size_mode
    type: Literal["any-size"]
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# Channel order expected by the network input; values are the config labels.
class ColorMode(str, Enum):
    BGR = "bgr"  # first member
    RGB = "rgb"
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# Resize strategies a model package can declare for its network input.
# Member values are the labels used in inference config files.
class ResizeMode(str, Enum):
    STRETCH_TO = "stretch"
    LETTERBOX = "letterbox"
    CENTER_CROP = "center-crop"
    FIT_LONGER_EDGE = "fit-longer-edge"
    LETTERBOX_REFLECT_EDGES = "letterbox-reflect-edges"  # last member
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
# Numeric config value: int or float (order kept int-first on purpose).
Number = Union[int, float]
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
# Describes the network input: training resolution, spatial-size flexibility,
# channel order, resize strategy, padding value, channel count and optional
# scaling / normalization settings.
class NetworkInputDefinition(BaseModel):
    training_input_size: TrainingInputSize
    dynamic_spatial_size_supported: bool
    # "pad-to-be-divisible" or "any-size", selected by the `type` key; may be absent.
    dynamic_spatial_size_mode: Optional[Union[DivisiblePadding, AnySizePadding]] = (
        Field(default=None, discriminator="type")
    )
    color_mode: ColorMode
    resize_mode: ResizeMode
    padding_value: Optional[int] = None
    input_channels: int
    scaling_factor: Optional[Number] = None
    # NOTE(review): presumably (per-channel means, per-channel stds) - confirm with pre-processing code.
    normalization: Optional[Tuple[List[Number], List[Number]]] = None
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
# Batch-size hints for the forward pass; both are optional and default to None.
class ForwardPassConfiguration(BaseModel):
    static_batch_size: Optional[int] = None
    max_dynamic_batch_size: Optional[int] = None
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
# Parameters of an NMS step fused into the exported model (all required).
class FusedNMSParameters(BaseModel):
    max_detections: int
    confidence_threshold: float
    iou_threshold: float
    # NOTE(review): typed as int rather than bool - presumably a 0/1 flag; confirm before changing.
    class_agnostic: int
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
# NMS post-processing, tagged "nms". When `fused` is set, `nms_parameters` may
# describe the built-in NMS. It defaults to None.
class NMSPostProcessing(BaseModel):
    type: Literal["nms"]
    fused: bool
    nms_parameters: Optional[FusedNMSParameters] = None
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
# Sigmoid post-processing, tagged "sigmoid". `fused` marks whether the model
# already includes it.
class SigmoidPostProcessing(BaseModel):
    type: Literal["sigmoid"]
    fused: bool
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
# Softmax post-processing, tagged "softmax". `fused` marks whether the model
# already includes it.
class SoftMaxPostProcessing(BaseModel):
    type: Literal["softmax"]
    fused: bool
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _none_as_empty_image_pre_processing(value):
    # An explicit null `image_pre_processing` section behaves like an empty one.
    if value is None:
        return ImagePreProcessing()
    return value


# Pydantic "before" validator that substitutes an empty ImagePreProcessing for None.
ImagePreProcessingValidator = BeforeValidator(_none_as_empty_image_pre_processing)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
# Class-names operation tagged "class_name_removal", naming one class to remove.
class ClassNameRemoval(BaseModel):
    # discriminator tag used by InferenceConfig.class_names_operations
    type: Literal["class_name_removal"]
    class_name: str
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
# Root schema of a model package's inference config. Only `network_input` is
# required; every other section has a default.
class InferenceConfig(BaseModel):
    # Missing section -> default factory; explicit null -> ImagePreProcessingValidator.
    # Either way an empty ImagePreProcessing results.
    image_pre_processing: Annotated[ImagePreProcessing, ImagePreProcessingValidator] = (
        Field(default_factory=lambda: ImagePreProcessing())
    )
    network_input: NetworkInputDefinition
    forward_pass: ForwardPassConfiguration = Field(
        default_factory=lambda: ForwardPassConfiguration()
    )
    # The variant is picked by its `type` field: "nms", "softmax" or "sigmoid".
    post_processing: Optional[
        Union[NMSPostProcessing, SoftMaxPostProcessing, SigmoidPostProcessing]
    ] = Field(discriminator="type", default=None)
    # Free-form dict; not interpreted by this schema.
    model_initialization: Optional[dict] = None
    # Tagged operations on class names; only "class_name_removal" is defined so far.
    class_names_operations: Optional[
        List[Annotated[Union[ClassNameRemoval], Field(discriminator="type")]]
    ] = None
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def parse_inference_config(
    config_path: str,
    allowed_resize_modes: Set[ResizeMode],
) -> InferenceConfig:
    """Load, validate and return the inference config of a model package.

    Args:
        config_path: Path to the JSON inference config file.
        allowed_resize_modes: Resize modes supported by the calling model
            implementation.

    Returns:
        The parsed ``InferenceConfig``.

    Raises:
        CorruptedModelPackageError: If the file cannot be read or is not a JSON
            object, if it does not match the ``InferenceConfig`` schema, or if it
            declares a resize mode outside ``allowed_resize_modes``.
    """
    try:
        decoded_config = read_json(path=config_path)
        if not isinstance(decoded_config, dict):
            raise ValueError(
                f"Expected config format is dict, found {type(decoded_config)} instead"
            )
    except (OSError, ValueError) as error:
        raise CorruptedModelPackageError(
            message=f"Inference config file of the model package is malformed: "
            f"{error}. In case that the package is "
            f"hosted on the Roboflow platform - contact support. If you created model package manually, please "
            f"verify its consistency in docs.",
            help_url="https://todo",
        ) from error
    try:
        parsed_config = InferenceConfig.model_validate(decoded_config)
    except ValidationError as error:
        # Surface pydantic's details, consistent with the other config parsers in this module.
        raise CorruptedModelPackageError(
            message=f"Could not parse the inference config from the model package: {error}. "
            f"In case that the package is hosted on the Roboflow platform - contact support. "
            f"If you created model package manually, please verify its consistency in docs.",
            help_url="https://todo",
        ) from error
    resize_mode = parsed_config.network_input.resize_mode
    if resize_mode not in allowed_resize_modes:
        # Sort so the message is deterministic regardless of set iteration order.
        allowed_resize_modes_str = ", ".join(
            sorted(mode.value for mode in allowed_resize_modes)
        )
        raise CorruptedModelPackageError(
            message=f"Inference configuration shipped with model package defines input resize mode "
            f"{resize_mode.value} which is not supported by the model implementation. "
            f"Allowed values are: {allowed_resize_modes_str}.",
            help_url="https://todo",
        )
    return parsed_config
|