inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/sam2_rt/sam2_pytorch.py
@@ -0,0 +1,119 @@
+ from pathlib import Path
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.errors import MissingDependencyError, ModelRuntimeError
+ from inference_models.models.common.model_packages import get_model_package_contents
+
+ try:
+     import hydra
+     from sam2.build_sam import build_sam2_camera_predictor
+     from sam2.sam2_camera_predictor import SAM2CameraPredictor
+ except ImportError as import_error:
+     raise MissingDependencyError(
+         message=f"Could not import SAM2 model, please consult README for installation instructions.",
+     ) from import_error
+
+
+ class SAM2ForStream:
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "SAM2ForStream":
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=[
+                 "weights.pt",
+                 "sam2-rt.yaml",
+             ],
+         )
+         hydra.core.global_hydra.GlobalHydra.instance().clear()
+         hydra.initialize_config_dir(
+             config_dir=Path(model_package_content["sam2-rt.yaml"]).parent.as_posix(),
+             version_base=None,
+         )
+         predictor: SAM2CameraPredictor = build_sam2_camera_predictor(
+             config_file=Path(model_package_content["sam2-rt.yaml"]).name,
+             ckpt_path=model_package_content["weights.pt"],
+             device=device,
+         )
+         return cls(predictor=predictor, device=device)
+
+     def __init__(self, predictor: SAM2CameraPredictor, device: torch.device):
+         self._predictor = predictor
+         self._device = device
+
+     def prompt(
+         self,
+         image: Union[np.ndarray, torch.Tensor],
+         bboxes: Union[Tuple[int, int, int, int], List[Tuple[int, int, int, int]]],
+         state_dict: Optional[dict] = None,
+         clear_old_points: bool = True,
+         normalize_coords: bool = True,
+         frame_idx: int = 0,
+     ) -> tuple:
+         if isinstance(image, torch.Tensor):
+             image = image.detach().cpu().numpy()
+         if clear_old_points or not self._predictor.condition_state:
+             self._predictor.load_first_frame(image)
+         if state_dict is not None:
+             self._predictor.load_state_dict(state_dict)
+         obj_id = 0
+         if (
+             self._predictor.condition_state
+             and self._predictor.condition_state["obj_ids"]
+         ):
+             obj_id = max(self._predictor.condition_state["obj_ids"]) + 1
+         if not isinstance(bboxes, list):
+             bboxes = [bboxes]
+         for pts in bboxes:
+             if len(pts) < 4:
+                 continue
+             x1, y1, x2, y2 = pts[:4]
+             x_lt = int(round(min(x1, x2)))
+             y_lt = int(round(min(y1, y2)))
+             x_rb = int(round(max(x1, x2)))
+             y_rb = int(round(max(y1, y2)))
+             xyxy = np.array([[x_lt, y_lt, x_rb, y_rb]])
+
+             _, object_ids, mask_logits = self._predictor.add_new_prompt(
+                 frame_idx=frame_idx,
+                 obj_id=obj_id,
+                 bbox=xyxy,
+                 clear_old_points=clear_old_points,
+                 normalize_coords=normalize_coords,
+             )
+             obj_id += 1
+         masks = (mask_logits > 0.0).cpu().numpy()
+         masks = np.squeeze(masks).astype(bool)
+         if len(masks.shape) == 2:
+             masks = np.expand_dims(masks, axis=0)
+         object_ids = np.array(object_ids)
+         return masks, object_ids, self._predictor.state_dict()
+
+     def track(
+         self,
+         image: Union[np.ndarray, torch.Tensor],
+         state_dict: Optional[dict] = None,
+     ) -> tuple:
+         if isinstance(image, torch.Tensor):
+             image = image.detach().cpu().numpy()
+         if state_dict is not None:
+             self._predictor.load_state_dict(state_dict)
+         if not self._predictor.condition_state:
+             raise ModelRuntimeError(
+                 "Attempt to track with no prior call to prompt; prompt must be called first"
+             )
+         object_ids, mask_logits = self._predictor.track(image)
+         masks = (mask_logits > 0.0).cpu().numpy()
+         masks = np.squeeze(masks).astype(bool)
+         if len(masks.shape) == 2:
+             masks = np.expand_dims(masks, axis=0)
+         object_ids = np.array(object_ids)
+         return masks, object_ids, self._predictor.state_dict()
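The hunk above is the complete SAM2ForStream interface: prompt() seeds tracking with bounding boxes on a first frame and track() propagates the objects to later frames. The following is a minimal usage sketch, not part of the package; it assumes the optional sam2 and hydra dependencies are installed and uses a placeholder model package directory ("./sam2-rt-package") containing weights.pt and sam2-rt.yaml.

    import numpy as np
    from inference_models.models.sam2_rt.sam2_pytorch import SAM2ForStream

    # Placeholder path: a local model package directory with weights.pt and sam2-rt.yaml.
    model = SAM2ForStream.from_pretrained("./sam2-rt-package")

    first_frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a real video frame
    # Seed tracking with a bounding-box prompt; returns masks, object ids and a state dict.
    masks, object_ids, state = model.prompt(image=first_frame, bboxes=(100, 100, 200, 200))

    next_frame = np.zeros((480, 640, 3), dtype=np.uint8)
    # Propagate the prompted objects to the next frame, reusing the returned state.
    masks, object_ids, state = model.track(image=next_frame, state_dict=state)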
inference_models/models/smolvlm/smolvlm_hf.py
@@ -0,0 +1,245 @@
+ import os
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from peft import PeftModel
+ from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
+
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     ResizeMode,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     pre_process_network_input,
+ )
+
+
+ class SmolVLMHF:
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         trust_remote_code: bool = False,
+         local_files_only: bool = True,
+         quantization_config: Optional[BitsAndBytesConfig] = None,
+         disable_quantization: bool = False,
+         **kwargs,
+     ) -> "SmolVLMHF":
+         torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
+         inference_config_path = os.path.join(
+             model_name_or_path, "inference_config.json"
+         )
+         inference_config = None
+         if os.path.exists(inference_config_path):
+             inference_config = parse_inference_config(
+                 config_path=inference_config_path,
+                 allowed_resize_modes={
+                     ResizeMode.STRETCH_TO,
+                     ResizeMode.LETTERBOX,
+                     ResizeMode.CENTER_CROP,
+                     ResizeMode.LETTERBOX_REFLECT_EDGES,
+                 },
+             )
+         adapter_config_path = os.path.join(model_name_or_path, "adapter_config.json")
+         if (
+             quantization_config is None
+             and device.type == "cuda"
+             and not disable_quantization
+         ):
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.float16,
+                 bnb_4bit_quant_type="nf4",
+             )
+         if os.path.exists(adapter_config_path):
+
+             base_model_path = os.path.join(model_name_or_path, "base")
+             model = AutoModelForImageTextToText.from_pretrained(
+                 base_model_path,
+                 dtype=torch_dtype,
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 quantization_config=quantization_config,
+             )
+             model = PeftModel.from_pretrained(model, model_name_or_path)
+             if quantization_config is None:
+                 model.merge_and_unload()
+                 model.to(device)
+
+             processor = AutoProcessor.from_pretrained(
+                 base_model_path,
+                 padding_side="left",
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 use_fast=True,
+             )
+         else:
+             model = AutoModelForImageTextToText.from_pretrained(
+                 model_name_or_path,
+                 dtype=torch_dtype,
+                 device_map=device,
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 quantization_config=quantization_config,
+             ).eval()
+             processor = AutoProcessor.from_pretrained(
+                 model_name_or_path,
+                 padding_side="left",
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+                 use_fast=True,
+             )
+         return cls(
+             model=model,
+             processor=processor,
+             inference_config=inference_config,
+             device=device,
+             torch_dtype=torch_dtype,
+         )
+
+     def __init__(
+         self,
+         model: AutoModelForImageTextToText,
+         processor: AutoProcessor,
+         inference_config: Optional[InferenceConfig],
+         device: torch.device,
+         torch_dtype: torch.dtype,
+     ):
+         self._model = model
+         self._processor = processor
+         self._inference_config = inference_config
+         self._device = device
+         self._torch_dtype = torch_dtype
+
+     def prompt(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         prompt: str,
+         images_to_single_prompt: bool = True,
+         input_color_format: Optional[ColorFormat] = None,
+         max_new_tokens: int = 400,
+         do_sample: bool = False,
+         skip_special_tokens: bool = True,
+         **kwargs,
+     ) -> List[str]:
+         inputs = self.pre_process_generation(
+             images=images,
+             prompt=prompt,
+             images_to_single_prompt=images_to_single_prompt,
+             input_color_format=input_color_format,
+         )
+         generated_ids = self.generate(
+             inputs=inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=do_sample,
+         )
+         return self.post_process_generation(
+             generated_ids=generated_ids,
+             skip_special_tokens=skip_special_tokens,
+         )
+
+     def pre_process_generation(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         prompt: str,
+         images_to_single_prompt: bool = True,
+         input_color_format: Optional[ColorFormat] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ) -> dict:
+         def _to_tensor(image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+             is_numpy = isinstance(image, np.ndarray)
+             if is_numpy:
+                 tensor_image = torch.from_numpy(image.copy()).permute(2, 0, 1)
+             else:
+                 tensor_image = image
+             if input_color_format == "bgr" or (is_numpy and input_color_format is None):
+                 tensor_image = tensor_image[[2, 1, 0], :, :]
+             if image_size is not None:
+                 tensor_image = torch.nn.functional.interpolate(
+                     image,
+                     [image_size[1], image_size[0]],
+                     mode="bilinear",
+                 )
+             return tensor_image
+
+         if self._inference_config is None:
+             if isinstance(images, torch.Tensor) and images.ndim > 3:
+                 image_list = [_to_tensor(img) for img in images]
+             elif not isinstance(images, list):
+                 image_list = [_to_tensor(images)]
+             else:
+                 image_list = [_to_tensor(img) for img in images]
+         else:
+             images = pre_process_network_input(
+                 images=images,
+                 image_pre_processing=self._inference_config.image_pre_processing,
+                 network_input=self._inference_config.network_input,
+                 target_device=self._device,
+                 input_color_format=input_color_format,
+                 image_size_wh=image_size,
+             )[0]
+             image_list = [e[0] for e in torch.split(images, 1, dim=0)]
+         if images_to_single_prompt:
+             content = [{"type": "image"}] * len(image_list)
+             content.append({"type": "text", "text": prompt})
+             conversations = [[{"role": "user", "content": content}]]
+         else:
+             conversations = []
+             for _ in image_list:
+                 conversations.append(
+                     [
+                         {
+                             "role": "user",
+                             "content": [
+                                 {"type": "image"},
+                                 {"type": "text", "text": prompt},
+                             ],
+                         }
+                     ]
+                 )
+         text_prompts = self._processor.apply_chat_template(
+             conversations, add_generation_prompt=True
+         )
+         max_image_size = None
+         if image_size:
+             max_image_size = {"longest_edge": max(image_size[0], image_size[1])}
+
+         inputs = self._processor(
+             text=text_prompts,
+             images=image_list,
+             return_tensors="pt",
+             padding=True,
+             max_image_size=max_image_size,
+         )
+         return inputs.to(self._device, dtype=self._torch_dtype)
+
+     def generate(
+         self,
+         inputs: dict,
+         max_new_tokens: int = 400,
+         do_sample: bool = False,
+         **kwargs,
+     ) -> torch.Tensor:
+         generation = self._model.generate(
+             **inputs, do_sample=do_sample, max_new_tokens=max_new_tokens
+         )
+         input_len = inputs["input_ids"].shape[-1]
+         return generation[:, input_len:]
+
+     def post_process_generation(
+         self,
+         generated_ids: torch.Tensor,
+         skip_special_tokens: bool = False,
+         **kwargs,
+     ) -> List[str]:
+         decoded = self._processor.batch_decode(
+             generated_ids, skip_special_tokens=skip_special_tokens
+         )
+         return [result.strip() for result in decoded]
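SmolVLMHF wraps a Hugging Face image-text-to-text checkpoint and exposes prompt(), which chains pre_process_generation, generate and post_process_generation. Below is a minimal usage sketch, not part of the package; it assumes transformers, peft and (on GPU) bitsandbytes are installed and uses a placeholder local checkpoint directory ("./smolvlm-package").

    import numpy as np
    from inference_models.models.smolvlm.smolvlm_hf import SmolVLMHF

    # Placeholder path: a local SmolVLM checkpoint directory (not part of this diff).
    model = SmolVLMHF.from_pretrained("./smolvlm-package")

    image = np.zeros((512, 512, 3), dtype=np.uint8)  # numpy inputs are treated as BGR by default
    # prompt() handles chat templating, image packing, generation and decoding in one call.
    answers = model.prompt(images=image, prompt="Describe the image.", max_new_tokens=128)
    print(answers[0])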
inference_models/models/trocr/trocr_hf.py
@@ -0,0 +1,53 @@
+ from typing import List, Union
+
+ import numpy as np
+ import torch
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.models.base.documents_parsing import TextOnlyOCRModel
+
+
+ class TROcrHF(TextOnlyOCRModel[torch.Tensor, torch.Tensor]):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         local_files_only: bool = True,
+         **kwargs,
+     ) -> "TextOnlyOCRModel":
+         model = VisionEncoderDecoderModel.from_pretrained(
+             model_name_or_path,
+             local_files_only=local_files_only,
+         ).to(device)
+         processor = TrOCRProcessor.from_pretrained(
+             model_name_or_path, local_files_only=local_files_only
+         )
+         return cls(model=model, processor=processor, device=device)
+
+     def __init__(
+         self,
+         processor: TrOCRProcessor,
+         model: VisionEncoderDecoderModel,
+         device: torch.device,
+     ):
+         self._processor = processor
+         self._model = model
+         self._device = device
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         **kwargs,
+     ) -> torch.Tensor:
+         inputs = self._processor(images=images, return_tensors="pt")
+         return inputs["pixel_values"].to(self._device)
+
+     def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
+         with torch.inference_mode():
+             return self._model.generate(pre_processed_images)
+
+     def post_process(self, model_results: torch.Tensor, **kwargs) -> List[str]:
+         return self._processor.batch_decode(model_results, skip_special_tokens=True)
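TROcrHF splits OCR inference into the pre_process / forward / post_process steps of the TextOnlyOCRModel base class. A minimal sketch follows, assuming a placeholder local TrOCR checkpoint directory ("./trocr-package") that is not part of this diff.

    import numpy as np
    from inference_models.models.trocr.trocr_hf import TROcrHF

    # Placeholder path: a local TrOCR checkpoint directory.
    model = TROcrHF.from_pretrained("./trocr-package")

    text_line = np.zeros((64, 256, 3), dtype=np.uint8)  # stand-in for a cropped text-line image
    pixel_values = model.pre_process(images=text_line)  # image -> pixel_values tensor on device
    token_ids = model.forward(pixel_values)             # autoregressive decoding
    print(model.post_process(token_ids))                # list with one decoded string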