inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/florence2/florence2_hf.py
@@ -0,0 +1,897 @@
+ import json
+ import os
+ import re
+ from typing import List, Literal, Optional, Tuple, Union
+
+ import cv2
+ import numpy as np
+ import torch
+ from peft import LoraConfig, get_peft_model
+ from peft.mapping import PEFT_TYPE_TO_PREFIX_MAPPING
+ from peft.utils.save_and_load import load_peft_weights, set_peft_model_state_dict
+ from transformers import (
+     BitsAndBytesConfig,
+     Florence2ForConditionalGeneration,
+     Florence2Processor,
+ )
+
+ from inference_models import Detections, InstanceDetections
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat, ImageDimensions
+ from inference_models.errors import CorruptedModelPackageError, ModelRuntimeError
+ from inference_models.models.common.roboflow.model_packages import (
+     InferenceConfig,
+     PreProcessingMetadata,
+     ResizeMode,
+     parse_inference_config,
+ )
+ from inference_models.models.common.roboflow.pre_processing import (
+     extract_input_images_dimensions,
+     pre_process_network_input,
+ )
+
+ GRANULARITY_2TASK = {
+     "normal": "<CAPTION>",
+     "detailed": "<DETAILED_CAPTION>",
+     "very_detailed": "<MORE_DETAILED_CAPTION>",
+ }
+ LABEL_MODE2TASK = {
+     "rois": "<REGION_PROPOSAL>",
+     "classes": "<OD>",
+     "captions": "<DENSE_REGION_CAPTION>",
+ }
+ LOC_BINS = 1000
+
+
+ class Florence2HF:
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         trust_remote_code: bool = False,
+         local_files_only: bool = True,
+         quantization_config: Optional[BitsAndBytesConfig] = None,
+         disable_quantization: bool = False,
+         **kwargs,
+     ) -> "Florence2HF":
+         torch_dtype = torch.float16 if device.type == "cuda" else torch.bfloat16
+         inference_config_path = os.path.join(
+             model_name_or_path, "inference_config.json"
+         )
+         inference_config = None
+         if os.path.exists(inference_config_path):
+             inference_config = parse_inference_config(
+                 config_path=inference_config_path,
+                 allowed_resize_modes={
+                     ResizeMode.STRETCH_TO,
+                     ResizeMode.LETTERBOX,
+                     ResizeMode.CENTER_CROP,
+                     ResizeMode.LETTERBOX_REFLECT_EDGES,
+                 },
+             )
+
+         adapter_config_path = os.path.join(model_name_or_path, "adapter_config.json")
+         is_adapter_package = os.path.exists(adapter_config_path)
+
+         base_model_path = (
+             os.path.join(model_name_or_path, "base")
+             if is_adapter_package
+             else model_name_or_path
+         )
+         if not os.path.isdir(base_model_path):
+             raise ModelRuntimeError(
+                 message=f"Provided model path does not exist or is not a directory: {base_model_path}",
+                 help_url="https://todo",
+             )
+         if not os.path.isfile(os.path.join(base_model_path, "config.json")):
+             raise ModelRuntimeError(
+                 message=(
+                     "Provided model directory does not look like a valid HF Florence-2 checkpoint (missing config.json). "
+                     "If you used the official converter, point to its output directory."
+                 ),
+                 help_url="https://todo",
+             )
+         if (
+             quantization_config is None
+             and device.type == "cuda"
+             and not disable_quantization
+         ):
+             quantization_config = BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_compute_dtype=torch.float16,
+                 bnb_4bit_quant_type="nf4",
+             )
+         # Native HF Florence2 path only (require transformers >= 4.56)
+         model = Florence2ForConditionalGeneration.from_pretrained(  # type: ignore[arg-type]
+             pretrained_model_name_or_path=base_model_path,
+             dtype=torch_dtype,
+             local_files_only=local_files_only,
+             trust_remote_code=trust_remote_code,
+             quantization_config=quantization_config,
+         )
+         if is_adapter_package:
+             # Custom LoRA attach to also cover vision modules
+             adapter_cfg_path = os.path.join(model_name_or_path, "adapter_config.json")
+             with open(adapter_cfg_path, "r") as f:
+                 adapter_cfg = json.load(f)
+
+             requested_target_modules = adapter_cfg.get("target_modules") or []
+             adapter_task_type = adapter_cfg.get("task_type") or "SEQ_2_SEQ_LM"
+             lora_config = LoraConfig(
+                 r=adapter_cfg.get("r", 8),
+                 lora_alpha=adapter_cfg.get("lora_alpha", 8),
+                 lora_dropout=adapter_cfg.get("lora_dropout", 0.0),
+                 bias="none",
+                 target_modules=sorted(requested_target_modules),
+                 use_dora=bool(adapter_cfg.get("use_dora", False)),
+                 use_rslora=bool(adapter_cfg.get("use_rslora", False)),
+                 task_type=adapter_task_type,
+             )
+
+             model = get_peft_model(model, lora_config)
+             # Load adapter weights
+             adapter_state = load_peft_weights(model_name_or_path, device=device.type)
+             adapter_state = normalize_adapter_state_dict(adapter_state)
+             load_result = set_peft_model_state_dict(
+                 model, adapter_state, adapter_name="default"
+             )
+             tuner = lora_config.peft_type
+             tuner_prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(tuner, "")
+             adapter_missing_keys = []
+             # Filter missing keys specific to the current adapter and tuner prefix.
+             for key in load_result.missing_keys:
+                 if tuner_prefix in key and "default" in key:
+                     adapter_missing_keys.append(key)
+             load_result.missing_keys.clear()
+             load_result.missing_keys.extend(adapter_missing_keys)
+             if len(load_result.missing_keys) > 0:
+                 raise CorruptedModelPackageError(
+                     message="Could not load LoRA weights for the model - found missing checkpoint keys "
+                     f"({len(load_result.missing_keys)}): {load_result.missing_keys}",
+                     help_url="https://todo",
+                 )
+             if quantization_config is None:
+                 model.merge_and_unload()
+                 # Ensure global dtype consistency (handles CPU bfloat16 vs fp32 mismatches)
+                 model = model.to(dtype=torch_dtype)
+                 model = model.to(device)
+
+         processor = Florence2Processor.from_pretrained(  # type: ignore[arg-type]
+             pretrained_model_name_or_path=base_model_path,
+             local_files_only=local_files_only,
+             trust_remote_code=trust_remote_code,
+             use_fast=True,
+         )
+
+         return cls(
+             model=model,
+             processor=processor,
+             inference_config=inference_config,
+             device=device,
+             torch_dtype=torch_dtype,
+         )
+
+     def __init__(
+         self,
+         model: Florence2ForConditionalGeneration,
+         processor: Florence2Processor,
+         inference_config: Optional[InferenceConfig],
+         device: torch.device,
+         torch_dtype: torch.dtype,
+     ):
+         self._model = model
+         self._processor = processor
+         self._inference_config = inference_config
+         self._device = device
+         self._torch_dtype = torch_dtype
+
+     def classify_image_region(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         xyxy: Union[
+             torch.Tensor,
+             List[List[Union[float, int]]],
+             List[Union[float, int]],
+             np.ndarray,
+         ],
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[str]:
+         loc_phrases = region_to_loc_phrase(images=images, xyxy=xyxy)
+         prompt = [f"<REGION_TO_CATEGORY>{phrase}" for phrase in loc_phrases]
+         task = "<REGION_TO_CATEGORY>"
+         result = self.prompt(
+             images=images,
+             prompt=prompt,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+             task=task,
+             input_color_format=input_color_format,
+         )
+         return [deduce_localisation(r[task]) for r in result]
+
+     def caption_image_region(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         xyxy: Union[
+             torch.Tensor,
+             List[List[Union[float, int]]],
+             List[Union[float, int]],
+             np.ndarray,
+         ],
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[str]:
+         loc_phrases = region_to_loc_phrase(images=images, xyxy=xyxy)
+         prompt = [f"<REGION_TO_DESCRIPTION>{phrase}" for phrase in loc_phrases]
+         task = "<REGION_TO_DESCRIPTION>"
+         result = self.prompt(
+             images=images,
+             prompt=prompt,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+             task=task,
+             input_color_format=input_color_format,
+         )
+         return [deduce_localisation(r[task]) for r in result]
+
+     def ocr_image_region(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         xyxy: Union[
+             torch.Tensor,
+             List[List[Union[float, int]]],
+             List[Union[float, int]],
+             np.ndarray,
+         ],
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[str]:
+         loc_phrases = region_to_loc_phrase(images=images, xyxy=xyxy)
+         prompt = [f"<REGION_TO_OCR>{phrase}" for phrase in loc_phrases]
+         task = "<REGION_TO_OCR>"
+         result = self.prompt(
+             images=images,
+             prompt=prompt,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+             task=task,
+             input_color_format=input_color_format,
+         )
+         return [deduce_localisation(r[task]) for r in result]
+
+     def segment_region(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         xyxy: Union[
+             torch.Tensor,
+             List[List[Union[float, int]]],
+             List[Union[float, int]],
+             np.ndarray,
+         ],
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[InstanceDetections]:
+         loc_phrases = region_to_loc_phrase(images=images, xyxy=xyxy)
+         prompt = [f"<REGION_TO_SEGMENTATION>{phrase}" for phrase in loc_phrases]
+         task = "<REGION_TO_SEGMENTATION>"
+         inputs, image_dimensions, pre_processing_metadata = self.pre_process_generation(
+             images=images, prompt=prompt, input_color_format=input_color_format
+         )
+         generated_ids = self.generate(
+             inputs=inputs,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+         )
+         result = self.post_process_generation(
+             generated_ids=generated_ids,
+             image_dimensions=image_dimensions,
+             task=task,
+         )
+         if pre_processing_metadata is None:
+             pre_processing_metadata = [None] * len(image_dimensions)
+         return [
+             parse_instance_segmentation_prediction(
+                 prediction=r[task],
+                 input_image_dimensions=i,
+                 image_metadata=image_metadata,
+                 device=self._device,
+             )
+             for r, i, image_metadata in zip(
+                 result, image_dimensions, pre_processing_metadata
+             )
+         ]
+
+     def segment_phrase(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         phrase: str,
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[InstanceDetections]:
+         prompt = f"<REFERRING_EXPRESSION_SEGMENTATION>{phrase}"
+         task = "<REFERRING_EXPRESSION_SEGMENTATION>"
+         inputs, image_dimensions, pre_processing_metadata = self.pre_process_generation(
+             images=images, prompt=prompt, input_color_format=input_color_format
+         )
+         generated_ids = self.generate(
+             inputs=inputs,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+         )
+         result = self.post_process_generation(
+             generated_ids=generated_ids,
+             image_dimensions=image_dimensions,
+             task=task,
+         )
+         if pre_processing_metadata is None:
+             pre_processing_metadata = [None] * len(image_dimensions)
+         image_dimensions = extract_input_images_dimensions(images=images)
+         return [
+             parse_instance_segmentation_prediction(
+                 prediction=r[task],
+                 input_image_dimensions=i,
+                 image_metadata=image_metadata,
+                 device=self._device,
+             )
+             for r, i, image_metadata in zip(
+                 result, image_dimensions, pre_processing_metadata
+             )
+         ]
+
+     def ground_phrase(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         phrase: str,
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[Detections]:
+         prompt = f"<CAPTION_TO_PHRASE_GROUNDING>{phrase}"
+         task = "<CAPTION_TO_PHRASE_GROUNDING>"
+         inputs, image_dimensions, pre_processing_metadata = self.pre_process_generation(
+             images=images, prompt=prompt, input_color_format=input_color_format
+         )
+         generated_ids = self.generate(
+             inputs=inputs,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+         )
+         result = self.post_process_generation(
+             generated_ids=generated_ids,
+             image_dimensions=image_dimensions,
+             task=task,
+         )
+         if pre_processing_metadata is None:
+             pre_processing_metadata = [None] * len(image_dimensions)
+         return [
+             parse_object_detection_prediction(
+                 prediction=r[task],
+                 image_metadata=image_metadata,
+                 device=self._device,
+             )
+             for r, image_metadata in zip(result, pre_processing_metadata)
+         ]
+
+     def detect_objects(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         labels_mode: Literal["classes", "captions", "rois"] = "classes",
+         classes: Optional[List[str]] = None,
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[Detections]:
+         if classes:
+             classes_str = "<and>".join(classes)
+             # not using <OPEN_VOCABULARY_DETECTION> as it associates number of objects with phrases
+             prompt = f"<CAPTION_TO_PHRASE_GROUNDING>{classes_str}"
+             task = "<CAPTION_TO_PHRASE_GROUNDING>"
+         else:
+             task = LABEL_MODE2TASK[labels_mode]
+             prompt = task
+         inputs, image_dimensions, pre_processing_metadata = self.pre_process_generation(
+             images=images, prompt=prompt, input_color_format=input_color_format
+         )
+         generated_ids = self.generate(
+             inputs=inputs,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+         )
+         result = self.post_process_generation(
+             generated_ids=generated_ids,
+             image_dimensions=image_dimensions,
+             task=task,
+         )
+         if pre_processing_metadata is None:
+             pre_processing_metadata = [None] * len(image_dimensions)
+         return [
+             parse_object_detection_prediction(
+                 prediction=r[task],
+                 image_metadata=image_metadata,
+                 expected_classes=classes,
+                 device=self._device,
+             )
+             for r, image_metadata in zip(result, pre_processing_metadata)
+         ]
+
+     def caption_image(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         granularity: Literal["normal", "detailed", "very_detailed"] = "normal",
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[str]:
+         task = GRANULARITY_2TASK[granularity]
+         result = self.prompt(
+             images=images,
+             prompt=task,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+             task=task,
+             input_color_format=input_color_format,
+         )
+         return [r[task] for r in result]
+
+     def parse_document(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[Detections]:
+         task = "<OCR_WITH_REGION>"
+         inputs, image_dimensions, pre_processing_metadata = self.pre_process_generation(
+             images=images, prompt=task, input_color_format=input_color_format
+         )
+         generated_ids = self.generate(
+             inputs=inputs,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+         )
+         result = self.post_process_generation(
+             generated_ids=generated_ids,
+             image_dimensions=image_dimensions,
+             task=task,
+         )
+         if pre_processing_metadata is None:
+             pre_processing_metadata = [None] * len(image_dimensions)
+         return [
+             parse_dense_ocr_prediction(
+                 prediction=r[task], image_metadata=image_metadata, device=self._device
+             )
+             for r, image_metadata in zip(result, pre_processing_metadata)
+         ]
+
+     def ocr_image(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         input_color_format: Optional[ColorFormat] = None,
+     ) -> List[str]:
+         task = "<OCR>"
+         result = self.prompt(
+             images=images,
+             prompt=task,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+             task=task,
+             input_color_format=input_color_format,
+         )
+         return [r[task] for r in result]
+
+     def prompt(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         prompt: Union[str, List[str]],
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         skip_special_tokens: bool = False,
+         task: Optional[str] = None,
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> List[str]:
+         inputs, image_dimensions, _ = self.pre_process_generation(
+             images=images, prompt=prompt, input_color_format=input_color_format
+         )
+         generated_ids = self.generate(
+             inputs=inputs,
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+         )
+         return self.post_process_generation(
+             generated_ids=generated_ids,
+             skip_special_tokens=skip_special_tokens,
+             image_dimensions=image_dimensions,
+             task=task,
+         )
+
+     def pre_process_generation(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         prompt: Union[str, List[str]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> Tuple[dict, List[ImageDimensions], Optional[List[PreProcessingMetadata]]]:
+         # # maybe don't need to convert to tensor here, since processor also accepts numpy arrays
+         # # but need to handle input_color_format here and this is consistent with how we do it in other models
+         def _to_tensor(image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+             is_numpy = isinstance(image, np.ndarray)
+             if is_numpy:
+                 tensor_image = torch.from_numpy(image.copy()).permute(2, 0, 1)
+             else:
+                 tensor_image = image
+             if input_color_format == "bgr" or (is_numpy and input_color_format is None):
+                 tensor_image = tensor_image[[2, 1, 0], :, :]
+             return tensor_image
+
+         if self._inference_config is None:
+             if isinstance(images, torch.Tensor) and images.ndim > 3:
+                 image_list = [_to_tensor(img) for img in images]
+             elif not isinstance(images, list):
+                 image_list = [_to_tensor(images)]
+             else:
+                 image_list = [_to_tensor(img) for img in images]
+             image_dimensions = extract_input_images_dimensions(images=image_list)
+             pre_processing_metadata = None
+         else:
+             images, pre_processing_metadata = pre_process_network_input(
+                 images=images,
+                 image_pre_processing=self._inference_config.image_pre_processing,
+                 network_input=self._inference_config.network_input,
+                 target_device=self._device,
+                 input_color_format=input_color_format,
+             )
+             image_list = [e[0] for e in torch.split(images, 1, dim=0)]
+             image_dimensions = [
+                 e.size_after_pre_processing for e in pre_processing_metadata
+             ]
+
+         if isinstance(prompt, list):
+             if len(prompt) != len(image_dimensions):
+                 raise ModelRuntimeError(
+                     message="Provided prompt as list, but the number of prompt elements does not match number of input images.",
+                     help_url="https://todo",
+                 )
+         else:
+             prompt = [prompt] * len(image_dimensions)
+
+         inputs = self._processor(
+             text=prompt, images=image_list, return_tensors="pt"
+         ).to(self._device, self._torch_dtype)
+         return inputs, image_dimensions, pre_processing_metadata
+
+     def generate(
+         self,
+         inputs: dict,
+         max_new_tokens: int = 4096,
+         num_beams: int = 3,
+         do_sample: bool = False,
+         **kwargs,
+     ) -> torch.Tensor:
+         return self._model.generate(
+             input_ids=inputs["input_ids"],
+             pixel_values=inputs["pixel_values"],
+             max_new_tokens=max_new_tokens,
+             num_beams=num_beams,
+             do_sample=do_sample,
+             **kwargs,
+         )
+
+     def post_process_generation(
+         self,
+         generated_ids: torch.Tensor,
+         skip_special_tokens: bool = False,
+         image_dimensions: Optional[List[ImageDimensions]] = None,
+         task: Optional[str] = None,
+         **kwargs,
+     ) -> Union[List[dict], List[str]]:
+         generated_texts = self._processor.batch_decode(
+             generated_ids, skip_special_tokens=skip_special_tokens
+         )
+         if image_dimensions is None or task is None:
+             return generated_texts
+         results = []
+         for single_image_text, single_image_dimensions in zip(
+             generated_texts, image_dimensions
+         ):
+             post_processed = self._processor.post_process_generation(
+                 single_image_text,
+                 task=task,
+                 image_size=(
+                     single_image_dimensions.width,
+                     single_image_dimensions.height,
+                 ),
+             )
+             results.append(post_processed)
+         return results
+
+
+ def region_to_loc_phrase(
+     images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+     xyxy: Union[
+         torch.Tensor, List[List[Union[float, int]]], List[Union[float, int]], np.ndarray
+     ],
+ ) -> List[str]:
+     if isinstance(xyxy, torch.Tensor):
+         xyxy = xyxy.cpu().numpy()
+     if isinstance(xyxy, np.ndarray):
+         xyxy = xyxy.tolist()
+     image_dimensions = extract_input_images_dimensions(images=images)
+     if not xyxy:
+         raise ModelRuntimeError(
+             message="Provided empty region grounding.", help_url="https://todo"
+         )
+     nested = isinstance(xyxy[0], list)
+     if not nested:
+         xyxy = [xyxy] * len(image_dimensions)
+     if len(xyxy) != len(image_dimensions):
+         raise ModelRuntimeError(
+             message="Provided multiple regions - it is expected to provide a single region for each image, but number "
+             "of regions does not match number of input images.",
+             help_url="https://todo",
+         )
+     result = []
+     for image_xyxy, single_image_dimensions in zip(xyxy, image_dimensions):
+         if _coordinates_are_relative(xyxy=image_xyxy):
+             left_top_x = _coordinate_to_loc(value=image_xyxy[0])
+             left_top_y = _coordinate_to_loc(value=image_xyxy[1])
+             right_bottom_x = _coordinate_to_loc(value=image_xyxy[2])
+             right_bottom_y = _coordinate_to_loc(value=image_xyxy[3])
+             loc_string = f"<loc_{left_top_x}><loc_{left_top_y}><loc_{right_bottom_x}><loc_{right_bottom_y}>"
+             result.append(loc_string)
+         else:
+             left_top_x = _coordinate_to_loc(
+                 value=image_xyxy[0] / single_image_dimensions.width
+             )
+             left_top_y = _coordinate_to_loc(
+                 value=image_xyxy[1] / single_image_dimensions.height
+             )
+             right_bottom_x = _coordinate_to_loc(
+                 value=image_xyxy[2] / single_image_dimensions.width
+             )
+             right_bottom_y = _coordinate_to_loc(
+                 value=image_xyxy[3] / single_image_dimensions.height
+             )
+             loc_string = f"<loc_{left_top_x}><loc_{left_top_y}><loc_{right_bottom_x}><loc_{right_bottom_y}>"
+             result.append(loc_string)
+     return result
+
+
+ def _coordinates_are_relative(xyxy: List[Union[float, int]]) -> bool:
+     return all(0 <= c <= 1 for c in xyxy)
+
+
+ def _coordinate_to_loc(value: float) -> int:
+     loc_bin = round(_scale_value(value=value, min_value=0.0, max_value=1.0) * LOC_BINS)
+     return _scale_value(  # to make sure 0-999 cutting out 1000 on 1.0
+         value=loc_bin,
+         min_value=0,
+         max_value=LOC_BINS - 1,
+     )
+
+
+ def _scale_value(
+     value: Union[int, float],
+     min_value: Union[int, float],
+     max_value: Union[int, float],
+ ) -> Union[int, float]:
+     return max(min(value, max_value), min_value)
+
+
+ def parse_dense_ocr_prediction(
+     prediction: dict,
+     image_metadata: Optional[PreProcessingMetadata],
+     device: torch.device,
+ ) -> Detections:
+     bboxes = prediction["quad_boxes"]
+     labels = prediction.get("labels", [""] * len(bboxes))
+     class_ids = [0] * len(bboxes)
+     xyxy = []
+     for box in bboxes:
+         np_box = np.array(box).reshape(-1, 2).round().astype(np.int32)
+         min_x, min_y = np_box[:, 0].min(), np_box[:, 1].min()
+         max_x, max_y = np_box[:, 0].max(), np_box[:, 1].max()
+         xyxy.append([min_x, min_y, max_x, max_y])
+     xyxy = torch.tensor(xyxy, device=device).round().int()
+     if image_metadata is not None and (
+         image_metadata.static_crop_offset.offset_x > 0
+         or image_metadata.static_crop_offset.offset_y > 0
+     ):
+         static_crop_offsets = torch.as_tensor(
+             [
+                 image_metadata.static_crop_offset.offset_x,
+                 image_metadata.static_crop_offset.offset_y,
+                 image_metadata.static_crop_offset.offset_x,
+                 image_metadata.static_crop_offset.offset_y,
+             ],
+             dtype=xyxy.dtype,
+             device=xyxy.device,
+         )
+         xyxy.add_(static_crop_offsets).round_()
+     class_ids = torch.tensor(class_ids, device=device).int()
+     confidence = torch.tensor([1.0] * len(labels), device=device)
+     bboxes_metadata = [{"class_name": label} for label in labels]
+     return Detections(
+         xyxy=xyxy,
+         class_id=class_ids,
+         confidence=confidence,
+         bboxes_metadata=bboxes_metadata,
+     )
+
+
+ def parse_object_detection_prediction(
+     prediction: dict,
+     image_metadata: Optional[PreProcessingMetadata],
+     device: torch.device,
+     expected_classes: Optional[List[int]] = None,
+ ) -> Detections:
+     bboxes = prediction["bboxes"]
+     labels = prediction.get(
+         "labels", prediction.get("bboxes_labels", [""] * len(bboxes))
+     )
+     if not expected_classes:
+         class_ids = [0] * len(bboxes)
+     else:
+         class_name2idx = {c: i for i, c in enumerate(expected_classes)}
+         unknown_class_id = len(expected_classes)
+         class_ids = []
+         for label in labels:
+             class_ids.append(class_name2idx.get(label, unknown_class_id))
+     xyxy = torch.tensor(bboxes, device=device).round().int()
+     if image_metadata is not None and (
+         image_metadata.static_crop_offset.offset_x > 0
+         or image_metadata.static_crop_offset.offset_y > 0
+     ):
+         static_crop_offsets = torch.as_tensor(
+             [
+                 image_metadata.static_crop_offset.offset_x,
+                 image_metadata.static_crop_offset.offset_y,
+                 image_metadata.static_crop_offset.offset_x,
+                 image_metadata.static_crop_offset.offset_y,
+             ],
+             dtype=xyxy.dtype,
+             device=xyxy.device,
+         )
+         xyxy.add_(static_crop_offsets).round_()
+     class_ids = torch.tensor(class_ids, device=device).int()
+     confidence = torch.tensor([1.0] * len(labels), device=device)
+     bboxes_metadata = None
+     if not expected_classes:
+         bboxes_metadata = [{"class_name": label} for label in labels]
+     return Detections(
+         xyxy=xyxy,
+         class_id=class_ids,
+         confidence=confidence,
+         bboxes_metadata=bboxes_metadata,
+     )
+
+
+ def deduce_localisation(result: str) -> str:
+     if "<loc" not in result:
+         return result
+     return result[: result.index("<loc")]
+
+
+ def parse_instance_segmentation_prediction(
+     prediction: dict,
+     input_image_dimensions: ImageDimensions,
+     image_metadata: Optional[PreProcessingMetadata],
+     device: torch.device,
+ ) -> InstanceDetections:
+     xyxy = []
+     masks = []
+     for polygons in prediction["polygons"]:
+         for polygon in polygons:
+             mask = np.zeros(
+                 (input_image_dimensions.height, input_image_dimensions.width),
+                 dtype=np.uint8,
+             )
+             np_polygon = np.array(polygon).reshape(-1, 2).round().astype(np.int32)
+             if len(np_polygon) < 3:
+                 continue
+             mask = cv2.fillPoly(mask, pts=[np_polygon], color=255)
+             mask = mask > 0
+             masks.append(mask)
+             min_x, min_y = np_polygon[:, 0].min(), np_polygon[:, 1].min()
+             max_x, max_y = np_polygon[:, 0].max(), np_polygon[:, 1].max()
+             xyxy.append([min_x, min_y, max_x, max_y])
+     class_ids = [0] * len(xyxy)
+     confidence = [1.0] * len(xyxy)
+     xyxy = torch.tensor(xyxy, device=device).round().int()
+     mask = torch.from_numpy(np.stack(masks, axis=0)).to(device)
+     if image_metadata is not None and (
+         image_metadata.static_crop_offset.offset_x > 0
+         or image_metadata.static_crop_offset.offset_y > 0
+     ):
+         static_crop_offsets = torch.as_tensor(
+             [
+                 image_metadata.static_crop_offset.offset_x,
+                 image_metadata.static_crop_offset.offset_y,
+                 image_metadata.static_crop_offset.offset_x,
+                 image_metadata.static_crop_offset.offset_y,
+             ],
+             dtype=xyxy.dtype,
+             device=device,
+         )
+         xyxy.add_(static_crop_offsets).round_()
+         mask_canvas = torch.zeros(
+             (
+                 mask.shape[0],
+                 image_metadata.original_size.height,
+                 image_metadata.original_size.width,
+             ),
+             dtype=torch.bool,
+             device=device,
+         )
+         mask_canvas[
+             :,
+             image_metadata.static_crop_offset.offset_y : image_metadata.static_crop_offset.offset_y
+             + mask.shape[1],
+             image_metadata.static_crop_offset.offset_x : image_metadata.static_crop_offset.offset_x
+             + mask.shape[2],
+         ] = mask
+     return InstanceDetections(
+         xyxy=xyxy,
+         class_id=torch.tensor(class_ids, device=device).int(),
+         confidence=torch.tensor(confidence, device=device),
+         mask=mask,
+     )
+
+
+ def normalize_adapter_state_dict(adapter_state: dict) -> dict:
+     normalized = {}
+     for key, value in adapter_state.items():
+         new_key = key
+         # Ensure Florence-2 PEFT prefix matches injected structure
+         if (
+             "base_model.model.vision_tower." in new_key
+             and "base_model.model.model.vision_tower." not in new_key
+         ):
+             new_key = new_key.replace(
+                 "base_model.model.vision_tower.",
+                 "base_model.model.model.vision_tower.",
+             )
+         # Normalize original repo FFN path to HF-native
+         if ".ffn.fn.net.fc1" in new_key:
+             new_key = new_key.replace(".ffn.fn.net.fc1", ".ffn.fc1")
+         if ".ffn.fn.net.fc2" in new_key:
+             new_key = new_key.replace(".ffn.fn.net.fc2", ".ffn.fc2")
+         # Normalize language path if needed
+         if ".language_model.model." in new_key:
+             new_key = new_key.replace(
+                 ".language_model.model.", ".model.language_model."
+             )
+         normalized[new_key] = value
+     return normalized
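
Example usage: a minimal sketch (not part of the package diff) showing how the Florence2HF class added above could be driven. The checkpoint path and image file are placeholders, and the .xyxy attribute access on the returned Detections is assumed from how the objects are constructed in the code; argument names follow the signatures shown in the diff.

    import cv2
    import torch

    from inference_models.models.florence2.florence2_hf import Florence2HF

    # Hypothetical local Florence-2 model package (HF-format checkpoint, or an
    # adapter package with adapter_config.json plus a "base/" sub-directory).
    model = Florence2HF.from_pretrained(
        "/path/to/florence2-package", device=torch.device("cuda")
    )

    image = cv2.imread("example.jpg")  # BGR numpy array; numpy inputs default to BGR handling
    captions = model.caption_image(images=image, granularity="detailed")
    detections = model.detect_objects(images=image, classes=["dog", "cat"])
    print(captions[0])
    print(detections[0].xyxy)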