inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,33 @@
+ import torchvision.transforms as T
+
+ from inference_models.models.perception_encoder.vision_encoder.tokenizer import (
+     SimpleTokenizer,
+ )
+
+
+ def get_image_transform(
+     image_size: int,
+     center_crop: bool = False,
+     interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,  # We used bilinear during training
+ ):
+     if center_crop:
+         crop = [
+             T.Resize(image_size, interpolation=interpolation),
+             T.CenterCrop(image_size),
+         ]
+     else:
+         # "Squash": most versatile
+         crop = [T.Resize((image_size, image_size), interpolation=interpolation)]
+
+     return T.Compose(
+         crop
+         + [
+             T.Lambda(lambda x: x.convert("RGB")),
+             T.ToTensor(),
+             T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
+         ]
+     )
+
+
+ def get_text_tokenizer(context_length: int):
+     return SimpleTokenizer(context_length=context_length)
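The hunk above appears to correspond to inference_models/models/perception_encoder/vision_encoder/transforms.py (+33 in the file list). A minimal usage sketch, not taken from the package: it assumes Pillow is available, "example.jpg" is a placeholder image path, 336 is an arbitrary target size, and SimpleTokenizer is callable on a list of strings in the CLIP style.

from PIL import Image

# Placeholder image; get_image_transform expects a PIL image because the
# pipeline calls .convert("RGB") on its input.
image = Image.open("example.jpg")

# "Squash" resize to 336x336, convert to RGB, to tensor, normalize to [-1, 1].
transform = get_image_transform(image_size=336, center_crop=False)
pixel_values = transform(image)  # torch.Tensor of shape (3, 336, 336)

# Matching text tokenizer; 32 is an assumed context length.
tokenizer = get_text_tokenizer(context_length=32)
tokens = tokenizer(["a photo of a dog"])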
@@ -0,0 +1 @@
+ # This file makes the qwen25vl directory a Python package
@@ -0,0 +1,285 @@
+ import os
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from peft import PeftModel
+ from transformers import (
+     BitsAndBytesConfig,
+     Qwen2_5_VLForConditionalGeneration,
+     Qwen2_5_VLProcessor,
+ )
+
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     ResizeMode,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+
+
+ class Qwen25VLHF:
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         trust_remote_code: bool = False,
+         local_files_only: bool = True,
+         quantization_config: Optional[BitsAndBytesConfig] = None,
+         disable_quantization: bool = False,
+         **kwargs,
+     ) -> "Qwen25VLHF":
+         adapter_config_path = os.path.join(model_name_or_path, "adapter_config.json")
+         inference_config_path = os.path.join(
+             model_name_or_path, "inference_config.json"
+         )
+         inference_config = None
+         if os.path.exists(inference_config_path):
+             inference_config = parse_inference_config(
+                 config_path=inference_config_path,
+                 allowed_resize_modes={
+                     ResizeMode.STRETCH_TO,
+                     ResizeMode.LETTERBOX,
+                     ResizeMode.CENTER_CROP,
+                     ResizeMode.LETTERBOX_REFLECT_EDGES,
+                 },
+             )
+         if (
+             quantization_config is None
+             and device.type == "cuda"
+             and not disable_quantization
+         ):
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.float16,
+                 bnb_4bit_quant_type="nf4",
+             )
+         if os.path.exists(adapter_config_path):
+             base_model_path = os.path.join(model_name_or_path, "base")
+             model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                 base_model_path,
+                 dtype="auto",
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 quantization_config=quantization_config,
+             )
+             model = PeftModel.from_pretrained(model, model_name_or_path)
+             if quantization_config is None:
+                 model.merge_and_unload()
+                 model.to(device)
+             processor = Qwen2_5_VLProcessor.from_pretrained(
+                 model_name_or_path,
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 use_fast=True,
+             )
+         else:
+             model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                 model_name_or_path,
+                 dtype="auto",
+                 device_map=device,
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 quantization_config=quantization_config,
+             ).eval()
+             Qwen2_5_VLProcessor.image_processor_class = "Qwen2VLImageProcessor"
+             processor = Qwen2_5_VLProcessor.from_pretrained(
+                 model_name_or_path,
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 use_fast=True,
+             )
+         return cls(
+             model=model,
+             processor=processor,
+             inference_config=inference_config,
+             device=device,
+         )
+
+     def __init__(
+         self,
+         model: Qwen2_5_VLForConditionalGeneration,
+         processor: Qwen2_5_VLProcessor,
+         inference_config: Optional[InferenceConfig],
+         device: torch.device,
+     ):
+         self._model = model
+         self._processor = processor
+         self._inference_config = inference_config
+         self._device = device
+         self.default_system_prompt = (
+             "You are a Qwen2.5-VL model that can answer questions about any image."
+         )
+
+     def prompt(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         prompt: str = None,
+         input_color_format: ColorFormat = None,
+         max_new_tokens: int = 512,
+         do_sample: bool = False,
+         skip_special_tokens: bool = False,
+         **kwargs,
+     ) -> List[str]:
+         inputs = self.pre_process_generation(
+             images=images, prompt=prompt, input_color_format=input_color_format
+         )
+         generated_ids = self.generate(
+             inputs=inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=do_sample,
+         )
+         return self.post_process_generation(
+             generated_ids=generated_ids,
+             skip_special_tokens=skip_special_tokens,
+         )
+
+     def pre_process_generation(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         prompt: str = None,
+         input_color_format: ColorFormat = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ) -> dict:
+         def _to_tensor(image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+             is_numpy = isinstance(image, np.ndarray)
+             if is_numpy:
+                 tensor_image = torch.from_numpy(image.copy()).permute(2, 0, 1)
+             else:
+                 tensor_image = image
+             if input_color_format == "bgr" or (is_numpy and input_color_format is None):
+                 tensor_image = tensor_image[[2, 1, 0], :, :]
+             return tensor_image
+
+         if self._inference_config is None:
+             if isinstance(images, torch.Tensor) and images.ndim > 3:
+                 image_list = [_to_tensor(img) for img in images]
+             elif not isinstance(images, list):
+                 image_list = [_to_tensor(images)]
+             else:
+                 image_list = [_to_tensor(img) for img in images]
+         else:
+             images = pre_process_network_input(
+                 images=images,
+                 image_pre_processing=self._inference_config.image_pre_processing,
+                 network_input=self._inference_config.network_input,
+                 target_device=self._device,
+                 input_color_format=input_color_format,
+                 image_size_wh=image_size,
+             )[0]
+             image_list = [e[0] for e in torch.split(images, 1, dim=0)]
+         # Handle prompt and system prompt parsing logic from original implementation
+         if prompt is None:
+             prompt = ""
+             system_prompt = self.default_system_prompt
+         else:
+             split_prompt = prompt.split("<system_prompt>")
+             if len(split_prompt) == 1:
+                 prompt = split_prompt[0]
+                 system_prompt = self.default_system_prompt
+             else:
+                 prompt = split_prompt[0]
+                 system_prompt = split_prompt[1]
+
+         # Construct conversation following original implementation structure
+         conversation = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": system_prompt}],
+             },
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image"},  # Processor will handle the actual image
+                     {"type": "text", "text": prompt},
+                 ],
+             },
+         ]
+
+         # Apply chat template
+         text_input = self._processor.apply_chat_template(
+             conversation, tokenize=False, add_generation_prompt=True
+         )
+
+         # Process inputs - processor will handle tensor/array inputs directly
+         model_inputs = self._processor(
+             text=text_input,
+             images=image_list,
+             return_tensors="pt",
+             padding=True,
+         )
+
+         # Move inputs to device
+         model_inputs = {
+             k: v.to(self._device)
+             for k, v in model_inputs.items()
+             if isinstance(v, torch.Tensor)
+         }
+
+         return model_inputs
+
+     def generate(
+         self,
+         inputs: dict,
+         max_new_tokens: int = 512,
+         do_sample: bool = False,
+         **kwargs,
+     ) -> torch.Tensor:
+         input_len = inputs["input_ids"].shape[-1]
+
+         with torch.inference_mode():
+             generation = self._model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=do_sample,
+                 pad_token_id=self._processor.tokenizer.pad_token_id,
+                 eos_token_id=self._processor.tokenizer.eos_token_id,
+                 bos_token_id=self._processor.tokenizer.bos_token_id,
+             )
+
+         # Return only the newly generated tokens
+         return generation[:, input_len:]
+
+     def post_process_generation(
+         self,
+         generated_ids: torch.Tensor,
+         skip_special_tokens: bool = False,
+         **kwargs,
+     ) -> List[str]:
+         # Decode the generated tokens
+         decoded = self._processor.batch_decode(
+             generated_ids,
+             skip_special_tokens=skip_special_tokens,
+         )
+
+         # Apply the same post-processing as original implementation
+         result = []
+         for text in decoded:
+             text = text.replace("assistant\n", "")
+             text = text.replace(" addCriterion\n", "")
+             result.append(text.strip())
+
+         return result
+
+
+ def adjust_lora_model_state_dict(state_dict: dict) -> dict:
+     return {
+         refactor_adapter_weights_key(key=key): value
+         for key, value in state_dict.items()
+     }
+
+
+ def refactor_adapter_weights_key(key: str) -> str:
+     if ".language_model." in key:
+         return key
+     return (
+         key.replace("model.layers", "model.language_model.layers")
+         .replace(".weight", ".default.weight")
+         .replace(".lora_magnitude_vector", ".lora_magnitude_vector.default.weight")
+     )
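The hunk above matches inference_models/models/qwen25vl/qwen25vl_hf.py (+285 in the file list). A hedged usage sketch of the class, not part of the package itself: the model directory and image path are placeholders, and a CUDA device is assumed so that the default 4-bit BitsAndBytes quantization applies.

import cv2
import torch

# Placeholder local model package (HF weights, optionally adapter_config.json
# and inference_config.json for LoRA / pre-processing support).
model = Qwen25VLHF.from_pretrained(
    "/path/to/qwen25vl-model-package",
    device=torch.device("cuda:0"),
)

# OpenCV returns BGR numpy arrays; numpy inputs are treated as BGR by default.
image = cv2.imread("example.jpg")

# The optional "<system_prompt>" delimiter splits the user prompt from a custom
# system prompt, mirroring the parsing in pre_process_generation.
answers = model.prompt(
    images=image,
    prompt="What is in this picture?<system_prompt>You are a concise assistant.",
    max_new_tokens=128,
)
print(answers[0])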
File without changes
@@ -0,0 +1,330 @@
+ from threading import Lock
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models import (
+     ClassificationModel,
+     ClassificationPrediction,
+     MultiLabelClassificationModel,
+     MultiLabelClassificationPrediction,
+ )
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.errors import (
+     CorruptedModelPackageError,
+     EnvironmentConfigurationError,
+     MissingDependencyError,
+ )
+ from inference_models.models.base.types import PreprocessedInputs
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.models.common.onnx import (
+     run_session_with_batch_size_limit,
+     set_execution_provider_defaults,
+ )
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     ResizeMode,
+     parse_class_names_file,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+ from inference_models.utils.onnx_introspection import (
+     get_selected_onnx_execution_providers,
+ )
+
+ try:
+     import onnxruntime
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message=f"Could not import ResNet model with ONNX backend - this error means that some additional dependencies "
+         f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
+         f"program, make sure the following extras of the package are installed: \n"
+         f"\t* `onnx-cpu` - when you wish to use library with CPU support only\n"
+         f"\t* `onnx-cu12` - for running on GPU with Cuda 12 installed\n"
+         f"\t* `onnx-cu118` - for running on GPU with Cuda 11.8 installed\n"
+         f"\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
+         f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
+         f"You can also contact Roboflow to get support.",
+         help_url="https://todo",
+     ) from import_error
+
+
+ class ResNetForClassificationOnnx(ClassificationModel[torch.Tensor, torch.Tensor]):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "ResNetForClassificationOnnx":
+         if onnx_execution_providers is None:
+             onnx_execution_providers = get_selected_onnx_execution_providers()
+         if not onnx_execution_providers:
+             raise EnvironmentConfigurationError(
+                 message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
+                 f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
+                 f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
+                 f"contact the platform support.",
+                 help_url="https://todo",
+             )
+         onnx_execution_providers = set_execution_provider_defaults(
+             providers=onnx_execution_providers,
+             model_package_path=model_name_or_path,
+             device=device,
+             default_onnx_trt_options=default_onnx_trt_options,
+         )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "weights.onnx",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         if inference_config.post_processing.type != "softmax":
+             raise CorruptedModelPackageError(
+                 message="Expected Softmax to be the post-processing",
+                 help_url="https://todo",
+             )
+         session = onnxruntime.InferenceSession(
+             path_or_bytes=model_package_content["weights.onnx"],
+             providers=onnx_execution_providers,
+         )
+         input_shape = session.get_inputs()[0].shape
+         input_batch_size = input_shape[0]
+         if isinstance(input_batch_size, str):
+             input_batch_size = None
+         input_name = session.get_inputs()[0].name
+         return cls(
+             session=session,
+             input_name=input_name,
+             inference_config=inference_config,
+             class_names=class_names,
+             device=device,
+             input_batch_size=input_batch_size,
+         )
+
+     def __init__(
+         self,
+         session: onnxruntime.InferenceSession,
+         input_name: str,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         device: torch.device,
+         input_batch_size: Optional[int],
+     ):
+         self._session = session
+         self._input_name = input_name
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._device = device
+         self._input_batch_size = input_batch_size
+         self._session_thread_lock = Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+             image_size_wh=image_size,
+         )[0]
+
+     def forward(
+         self, pre_processed_images: PreprocessedInputs, **kwargs
+     ) -> torch.Tensor:
+         with self._session_thread_lock:
+             return run_session_with_batch_size_limit(
+                 session=self._session,
+                 inputs={self._input_name: pre_processed_images},
+                 min_batch_size=self._input_batch_size,
+                 max_batch_size=self._input_batch_size,
+             )[0]
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         **kwargs,
+     ) -> ClassificationPrediction:
+         if self._inference_config.post_processing.fused:
+             confidence = model_results
+         else:
+             confidence = torch.nn.functional.softmax(model_results, dim=-1)
+         return ClassificationPrediction(
+             class_id=confidence.argmax(dim=-1),
+             confidence=confidence,
+         )
+
+
+ class ResNetForMultiLabelClassificationOnnx(
+     MultiLabelClassificationModel[torch.Tensor, torch.Tensor]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
+         default_onnx_trt_options: bool = True,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "ResNetForMultiLabelClassificationOnnx":
+         if onnx_execution_providers is None:
+             onnx_execution_providers = get_selected_onnx_execution_providers()
+         if not onnx_execution_providers:
+             raise EnvironmentConfigurationError(
+                 message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
+                 f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
+                 f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
+                 f"contact the platform support.",
+                 help_url="https://todo",
+             )
+         onnx_execution_providers = set_execution_provider_defaults(
+             providers=onnx_execution_providers,
+             model_package_path=model_name_or_path,
+             device=device,
+             default_onnx_trt_options=default_onnx_trt_options,
+         )
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "class_names.txt",
+                 "inference_config.json",
+                 "weights.onnx",
+             ],
+         )
+         class_names = parse_class_names_file(
+             class_names_path=model_package_content["class_names.txt"]
+         )
+         inference_config = parse_inference_config(
+             config_path=model_package_content["inference_config.json"],
+             allowed_resize_modes={
+                 ResizeMode.STRETCH_TO,
+                 ResizeMode.LETTERBOX,
+                 ResizeMode.CENTER_CROP,
+                 ResizeMode.LETTERBOX_REFLECT_EDGES,
+             },
+         )
+         if inference_config.post_processing.type != "sigmoid":
+             raise CorruptedModelPackageError(
+                 message="Expected sigmoid to be the post-processing",
+                 help_url="https://todo",
+             )
+         session = onnxruntime.InferenceSession(
+             path_or_bytes=model_package_content["weights.onnx"],
+             providers=onnx_execution_providers,
+         )
+         input_shape = session.get_inputs()[0].shape
+         input_batch_size = input_shape[0]
+         if isinstance(input_batch_size, str):
+             input_batch_size = None
+         input_name = session.get_inputs()[0].name
+         return cls(
+             session=session,
+             input_name=input_name,
+             inference_config=inference_config,
+             class_names=class_names,
+             device=device,
+             input_batch_size=input_batch_size,
+         )
+
+     def __init__(
+         self,
+         session: onnxruntime.InferenceSession,
+         input_name: str,
+         inference_config: InferenceConfig,
+         class_names: List[str],
+         device: torch.device,
+         input_batch_size: Optional[int],
+     ):
+         self._session = session
+         self._input_name = input_name
+         self._inference_config = inference_config
+         self._class_names = class_names
+         self._device = device
+         self._input_batch_size = input_batch_size
+         self._session_thread_lock = Lock()
+
+     @property
+     def class_names(self) -> List[str]:
+         return self._class_names
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         return pre_process_network_input(
+             images=images,
+             image_pre_processing=self._inference_config.image_pre_processing,
+             network_input=self._inference_config.network_input,
+             target_device=self._device,
+             input_color_format=input_color_format,
+             image_size_wh=image_size,
+         )[0]
+
+     def forward(
+         self, pre_processed_images: PreprocessedInputs, **kwargs
+     ) -> torch.Tensor:
+         with self._session_thread_lock:
+             return run_session_with_batch_size_limit(
+                 session=self._session,
+                 inputs={self._input_name: pre_processed_images},
+                 min_batch_size=self._input_batch_size,
+                 max_batch_size=self._input_batch_size,
+             )[0]
+
+     def post_process(
+         self,
+         model_results: torch.Tensor,
+         confidence: float = 0.5,
+         **kwargs,
+     ) -> List[MultiLabelClassificationPrediction]:
+         if self._inference_config.post_processing.fused:
+             model_results = model_results
+         else:
+             model_results = torch.nn.functional.sigmoid(model_results)
+         results = []
+         for batch_element_confidence in model_results:
+             predicted_classes = torch.argwhere(
+                 batch_element_confidence >= confidence
+             ).squeeze(dim=-1)
+             results.append(
+                 MultiLabelClassificationPrediction(
+                     class_ids=predicted_classes,
+                     confidence=batch_element_confidence,
+                 )
+             )
+         return results
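This last hunk matches inference_models/models/resnet/resnet_classification_onnx.py (+330 in the file list). A usage sketch for the single-label class, under stated assumptions: the package directory is a placeholder that must contain class_names.txt, inference_config.json and weights.onnx (as enforced by from_pretrained), and the CPU execution provider is chosen explicitly.

import cv2
import torch

model = ResNetForClassificationOnnx.from_pretrained(
    "/path/to/resnet-model-package",  # placeholder model package directory
    onnx_execution_providers=["CPUExecutionProvider"],
    device=torch.device("cpu"),
)

image = cv2.imread("example.jpg")  # BGR numpy array (placeholder path)

# Explicit pre_process -> forward -> post_process pipeline, as exposed above.
pre_processed = model.pre_process(images=image)
logits = model.forward(pre_processed_images=pre_processed)
prediction = model.post_process(model_results=logits)

top_idx = int(prediction.class_id[0])
print(model.class_names[top_idx], float(prediction.confidence[0, top_idx]))

ResNetForMultiLabelClassificationOnnx follows the same three-step pattern, except that post_process takes a confidence threshold and returns one MultiLabelClassificationPrediction per image.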