inference-models 0.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/common/roboflow/post_processing.py
@@ -0,0 +1,436 @@
+ from typing import List, Literal, Tuple
+
+ import torch
+ import torchvision
+ from torchvision.transforms import functional
+
+ from inference_models.entities import ImageDimensions
+ from inference_models.models.common.roboflow.model_packages import (
+     PreProcessingMetadata,
+     StaticCropOffset,
+ )
+
+
+ def run_nms_for_object_detection(
+     output: torch.Tensor,
+     conf_thresh: float = 0.25,
+     iou_thresh: float = 0.45,
+     max_detections: int = 100,
+     class_agnostic: bool = False,
+     box_format: Literal["xywh", "xyxy"] = "xywh",
+ ) -> List[torch.Tensor]:
+     bs = output.shape[0]
+     boxes = output[:, :4, :]
+     scores = output[:, 4:, :]
+     results = []
+     for b in range(bs):
+         # Combine transpose & max for efficiency
+         class_scores = scores[b]  # (80, 8400)
+         class_conf, class_ids = class_scores.max(0)  # (8400,), (8400,)
+         mask = class_conf > conf_thresh
+         if not torch.any(mask):
+             results.append(torch.zeros((0, 6), device=output.device))
+             continue
+         bboxes = boxes[b][:, mask].T  # (num, 4) -- selects and then transposes
+         class_conf = class_conf[mask]
+         class_ids = class_ids[mask]
+         if box_format == "xywh":
+             # Vectorized [x, y, w, h] -> [x1, y1, x2, y2]
+             xy = bboxes[:, :2]
+             wh = bboxes[:, 2:]
+             half_wh = wh / 2
+             xyxy = torch.cat((xy - half_wh, xy + half_wh), 1)
+         else:
+             xyxy = bboxes
+         # Class-agnostic NMS -> use dummy class ids
+         nms_class_ids = torch.zeros_like(class_ids) if class_agnostic else class_ids
+         # NMS and limiting max detections
+         keep = torchvision.ops.batched_nms(xyxy, class_conf, nms_class_ids, iou_thresh)
+         if keep.numel() > max_detections:
+             keep = keep[:max_detections]
+         detections = torch.cat(
+             (
+                 xyxy[keep],
+                 class_conf[keep, None],  # unsqueeze(1) is replaced with None
+                 class_ids[keep, None].float(),
+             ),
+             1,
+         )  # [x1, y1, x2, y2, conf, cls]
+
+         results.append(detections)
+     return results
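For orientation, a minimal usage sketch (editor's example, not part of the package): the (1, 84, 8400) tensor below assumes a YOLOv8-style COCO detector head, i.e. 4 box coordinates plus 80 class scores per anchor; any (batch, 4 + num_classes, num_anchors) layout works.

    raw = torch.rand(1, 84, 8400)  # hypothetical raw detector output
    detections = run_nms_for_object_detection(raw, conf_thresh=0.25, iou_thresh=0.45)
    print(detections[0].shape)  # (num_kept, 6): [x1, y1, x2, y2, conf, cls]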
+
+
+ def post_process_nms_fused_model_output(
+     output: torch.Tensor,
+     conf_thresh: float = 0.25,
+ ) -> List[torch.Tensor]:
+     bs = output.shape[0]
+     nms_results = []
+     for batch_element_id in range(bs):
+         batch_element_result = output[batch_element_id]
+         batch_element_result = batch_element_result[
+             batch_element_result[:, 4] >= conf_thresh
+         ]
+         nms_results.append(batch_element_result)
+     return nms_results
+
+
+ def run_nms_for_instance_segmentation(
+     output: torch.Tensor,
+     conf_thresh: float = 0.25,
+     iou_thresh: float = 0.45,
+     max_detections: int = 100,
+     class_agnostic: bool = False,
+     box_format: Literal["xywh", "xyxy"] = "xywh",
+ ) -> List[torch.Tensor]:
+     bs = output.shape[0]
+     boxes = output[:, :4, :]  # (N, 4, 8400)
+     scores = output[:, 4:-32, :]  # (N, 80, 8400)
+     masks = output[:, -32:, :]
+     results = []
+
+     for b in range(bs):
+         bboxes = boxes[b].T  # (8400, 4)
+         class_scores = scores[b].T  # (8400, 80)
+         box_masks = masks[b].T
+         class_conf, class_ids = class_scores.max(1)  # (8400,), (8400,)
+         mask = class_conf > conf_thresh
+         if mask.sum() == 0:
+             results.append(torch.zeros((0, 38), device=output.device))
+             continue
+         bboxes = bboxes[mask]
+         class_conf = class_conf[mask]
+         class_ids = class_ids[mask]
+         box_masks = box_masks[mask]
+         if box_format == "xywh":
+             # Vectorized [x, y, w, h] -> [x1, y1, x2, y2]
+             xy = bboxes[:, :2]
+             wh = bboxes[:, 2:]
+             half_wh = wh / 2
+             xyxy = torch.cat((xy - half_wh, xy + half_wh), 1)
+         else:
+             xyxy = bboxes
+         # Class-agnostic NMS -> use dummy class ids
+         nms_class_ids = torch.zeros_like(class_ids) if class_agnostic else class_ids
+         keep = torchvision.ops.batched_nms(xyxy, class_conf, nms_class_ids, iou_thresh)
+         keep = keep[:max_detections]
+         detections = torch.cat(
+             [
+                 xyxy[keep],
+                 class_conf[keep].unsqueeze(1),
+                 class_ids[keep].unsqueeze(1).float(),
+                 box_masks[keep],
+             ],
+             dim=1,
+         )  # [x1, y1, x2, y2, conf, cls, mask coefficients...]
+         results.append(detections)
+     return results
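Each returned row has 38 columns: 4 box coordinates, 1 confidence, 1 class id, and 32 mask-prototype coefficients. A shape-check sketch (editor's example), assuming a YOLOv8-seg-style head with 4 + 80 + 32 = 116 channels:

    raw = torch.rand(1, 116, 8400)  # hypothetical raw segmentation output
    dets = run_nms_for_instance_segmentation(raw)
    assert dets[0].shape[1] == 38  # [x1, y1, x2, y2, conf, cls] + 32 coefficients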
+
+
+ def run_nms_for_key_points_detection(
+     output: torch.Tensor,
+     num_classes: int,
+     key_points_slots_in_prediction: int,
+     conf_thresh: float = 0.25,
+     iou_thresh: float = 0.45,
+     max_detections: int = 100,
+     class_agnostic: bool = False,
+ ) -> List[torch.Tensor]:
+     bs = output.shape[0]
+     boxes = output[:, :4, :]
+     scores = output[:, 4 : 4 + num_classes, :]
+     key_points = output[:, 4 + num_classes :, :]
+     results = []
+     for b in range(bs):
+         class_scores = scores[b]
+         class_conf, class_ids = class_scores.max(0)
+         mask = class_conf > conf_thresh
+         if not torch.any(mask):
+             results.append(
+                 torch.zeros(
+                     (0, 6 + key_points_slots_in_prediction * 3), device=output.device
+                 )
+             )
+             continue
+         bboxes = boxes[b][:, mask].T
+         image_key_points = key_points[b, :, mask].T
+         class_conf = class_conf[mask]
+         class_ids = class_ids[mask]
+         xy = bboxes[:, :2]
+         wh = bboxes[:, 2:]
+         half_wh = wh / 2
+         xyxy = torch.cat((xy - half_wh, xy + half_wh), 1)
+         # Class-agnostic NMS -> use dummy class ids
+         nms_class_ids = torch.zeros_like(class_ids) if class_agnostic else class_ids
+         # NMS and limiting max detections
+         keep = torchvision.ops.batched_nms(xyxy, class_conf, nms_class_ids, iou_thresh)
+         if keep.numel() > max_detections:
+             keep = keep[:max_detections]
+         detections = torch.cat(
+             (
+                 xyxy[keep],
+                 class_conf[keep, None],  # unsqueeze(1) is replaced with None
+                 class_ids[keep, None].float(),
+                 image_key_points[keep],
+             ),
+             1,
+         )  # [x1, y1, x2, y2, conf, cls, keypoints...]
+         results.append(detections)
+     return results
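Each keypoint slot occupies three output columns (x, y, score), which is why the empty result above reserves 6 + 3 * K columns. A sketch (editor's example), assuming a hypothetical COCO-pose-style head with one class and 17 keypoints, i.e. 4 + 1 + 17 * 3 = 56 channels:

    raw = torch.rand(1, 56, 8400)
    dets = run_nms_for_key_points_detection(
        raw, num_classes=1, key_points_slots_in_prediction=17
    )
    assert dets[0].shape[1] == 6 + 17 * 3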
+
+
+ def rescale_detections(
+     detections: List[torch.Tensor], images_metadata: List[PreProcessingMetadata]
+ ) -> List[torch.Tensor]:
+     for image_detections, metadata in zip(detections, images_metadata):
+         _ = rescale_image_detections(
+             image_detections=image_detections, image_metadata=metadata
+         )
+     return detections
+
+
+ def rescale_image_detections(
+     image_detections: torch.Tensor,
+     image_metadata: PreProcessingMetadata,
+ ) -> torch.Tensor:
+     # in-place processing
+     offsets = torch.as_tensor(
+         [
+             image_metadata.pad_left,
+             image_metadata.pad_top,
+             image_metadata.pad_left,
+             image_metadata.pad_top,
+         ],
+         dtype=image_detections.dtype,
+         device=image_detections.device,
+     )
+     image_detections[:, :4].sub_(offsets)  # in-place subtraction for speed/memory
+     scale = torch.as_tensor(
+         [
+             image_metadata.scale_width,
+             image_metadata.scale_height,
+             image_metadata.scale_width,
+             image_metadata.scale_height,
+         ],
+         dtype=image_detections.dtype,
+         device=image_detections.device,
+     )
+     image_detections[:, :4].div_(scale)
+     if (
+         image_metadata.static_crop_offset.offset_x != 0
+         or image_metadata.static_crop_offset.offset_y != 0
+     ):
+         static_crop_offsets = torch.as_tensor(
+             [
+                 image_metadata.static_crop_offset.offset_x,
+                 image_metadata.static_crop_offset.offset_y,
+                 image_metadata.static_crop_offset.offset_x,
+                 image_metadata.static_crop_offset.offset_y,
+             ],
+             dtype=image_detections.dtype,
+             device=image_detections.device,
+         )
+         image_detections[:, :4].add_(static_crop_offsets)
+     return image_detections
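The order of operations implements the inverse of letterbox pre-processing: subtract the padding, divide by the resize scale, then shift by any static-crop offset to land back in source-image pixels. A worked one-coordinate example with hypothetical numbers:

    x_net, pad_left, scale_width, crop_offset_x = 500.0, 20.0, 0.5, 100.0
    x_src = (x_net - pad_left) / scale_width + crop_offset_x  # -> 1060.0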
+
+
+ def rescale_key_points_detections(
+     detections: List[torch.Tensor],
+     images_metadata: List[PreProcessingMetadata],
+     num_classes: int,
+     key_points_slots_in_prediction: int,
+ ) -> List[torch.Tensor]:
+     for image_detections, metadata in zip(detections, images_metadata):
+         offsets = torch.as_tensor(
+             [metadata.pad_left, metadata.pad_top, metadata.pad_left, metadata.pad_top],
+             dtype=image_detections.dtype,
+             device=image_detections.device,
+         )
+         image_detections[:, :4].sub_(offsets)  # in-place subtraction for speed/memory
+         scale = torch.as_tensor(
+             [
+                 metadata.scale_width,
+                 metadata.scale_height,
+                 metadata.scale_width,
+                 metadata.scale_height,
+             ],
+             dtype=image_detections.dtype,
+             device=image_detections.device,
+         )
+         image_detections[:, :4].div_(scale)
+         key_points_offsets = torch.as_tensor(
+             [metadata.pad_left, metadata.pad_top, 0],
+             dtype=image_detections.dtype,
+             device=image_detections.device,
+         ).repeat(key_points_slots_in_prediction)
+         image_detections[:, 6:].sub_(key_points_offsets)
+         key_points_scale = torch.as_tensor(
+             [metadata.scale_width, metadata.scale_height, 1.0],
+             dtype=image_detections.dtype,
+             device=image_detections.device,
+         ).repeat(key_points_slots_in_prediction)
+         image_detections[:, 6:].div_(key_points_scale)
+         if (
+             metadata.static_crop_offset.offset_x != 0
+             or metadata.static_crop_offset.offset_y != 0
+         ):
+             static_crop_offset_length = (image_detections.shape[1] - 6) // 3
+             static_crop_offsets = torch.as_tensor(
+                 [
+                     metadata.static_crop_offset.offset_x,
+                     metadata.static_crop_offset.offset_y,
+                     0,
+                 ]
+                 * static_crop_offset_length,
+                 dtype=image_detections.dtype,
+                 device=image_detections.device,
+             )
+             image_detections[:, 6:].add_(static_crop_offsets)
+             static_crop_offsets = torch.as_tensor(
+                 [
+                     metadata.static_crop_offset.offset_x,
+                     metadata.static_crop_offset.offset_y,
+                     metadata.static_crop_offset.offset_x,
+                     metadata.static_crop_offset.offset_y,
+                 ],
+                 dtype=image_detections.dtype,
+                 device=image_detections.device,
+             )
+             image_detections[:, :4].add_(static_crop_offsets)
+     return detections
+
+
+ def preprocess_segmentation_masks(
+     protos: torch.Tensor,
+     masks_in: torch.Tensor,
+ ) -> torch.Tensor:
+     return torch.einsum("chw,nc->nhw", protos, masks_in)
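The einsum blends the c prototype masks (c, h, w) with each detection's c coefficients (n, c) into one low-resolution mask per detection, which is equivalent to a matrix product against the flattened prototypes. An equivalence sketch (editor's example):

    protos = torch.rand(32, 160, 160)  # hypothetical prototype masks
    coeffs = torch.rand(5, 32)         # one coefficient vector per detection
    masks_a = preprocess_segmentation_masks(protos, coeffs)
    masks_b = (coeffs @ protos.reshape(32, -1)).reshape(5, 160, 160)
    assert torch.allclose(masks_a, masks_b, atol=1e-5)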
+
+
+ def crop_masks_to_boxes(
+     boxes: torch.Tensor,
+     masks: torch.Tensor,
+     scaling: float = 0.25,
+ ) -> torch.Tensor:
+     n, h, w = masks.shape
+     scaled_boxes = boxes * scaling
+     x1, y1, x2, y2 = (
+         scaled_boxes[:, 0][:, None, None],
+         scaled_boxes[:, 1][:, None, None],
+         scaled_boxes[:, 2][:, None, None],
+         scaled_boxes[:, 3][:, None, None],
+     )
+     rows = torch.arange(w, device=masks.device)[None, None, :]  # shape: [1, 1, w]
+     cols = torch.arange(h, device=masks.device)[None, :, None]  # shape: [1, h, 1]
+     crop_mask = (rows >= x1) & (rows < x2) & (cols >= y1) & (cols < y2)
+     return masks * crop_mask
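The default scaling of 0.25 presumably reflects boxes expressed at network input resolution while the masks live at prototype resolution, typically one quarter of it (e.g. 160x160 prototypes for a 640x640 input). Hypothetical usage:

    boxes = torch.tensor([[64.0, 64.0, 320.0, 320.0]])  # xyxy at 640x640
    masks = torch.rand(1, 160, 160)
    cropped = crop_masks_to_boxes(boxes, masks)  # zeroed outside the scaled box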
+
+
+ def align_instance_segmentation_results(
+     image_bboxes: torch.Tensor,
+     masks: torch.Tensor,
+     padding: Tuple[int, int, int, int],
+     scale_width: float,
+     scale_height: float,
+     original_size: ImageDimensions,
+     size_after_pre_processing: ImageDimensions,
+     inference_size: ImageDimensions,
+     static_crop_offset: StaticCropOffset,
+     binarization_threshold: float = 0.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     if image_bboxes.shape[0] == 0:
+         empty_masks = torch.empty(
+             size=(0, size_after_pre_processing.height, size_after_pre_processing.width),
+             dtype=torch.bool,
+             device=image_bboxes.device,
+         )
+         return image_bboxes, empty_masks
+     pad_left, pad_top, pad_right, pad_bottom = padding
+     offsets = torch.tensor(
+         [pad_left, pad_top, pad_left, pad_top],
+         device=image_bboxes.device,
+     )
+     image_bboxes[:, :4].sub_(offsets)
+     scale = torch.as_tensor(
+         [scale_width, scale_height, scale_width, scale_height],
+         dtype=image_bboxes.dtype,
+         device=image_bboxes.device,
+     )
+     image_bboxes[:, :4].div_(scale)
+     n, mh, mw = masks.shape
+     mask_h_scale = mh / inference_size.height
+     mask_w_scale = mw / inference_size.width
+     mask_pad_top, mask_pad_bottom, mask_pad_left, mask_pad_right = (
+         round(mask_h_scale * pad_top),
+         round(mask_h_scale * pad_bottom),
+         round(mask_w_scale * pad_left),
+         round(mask_w_scale * pad_right),
+     )
+     if (
+         mask_pad_top < 0
+         or mask_pad_bottom < 0
+         or mask_pad_left < 0
+         or mask_pad_right < 0
+     ):
+         masks = torch.nn.functional.pad(
+             masks,
+             (
+                 abs(min(mask_pad_left, 0)),
+                 abs(min(mask_pad_right, 0)),
+                 abs(min(mask_pad_top, 0)),
+                 abs(min(mask_pad_bottom, 0)),
+             ),
+             "constant",
+             0,
+         )
+         padded_mask_offset_top = max(mask_pad_top, 0)
+         padded_mask_offset_bottom = max(mask_pad_bottom, 0)
+         padded_mask_offset_left = max(mask_pad_left, 0)
+         padded_mask_offset_right = max(mask_pad_right, 0)
+         masks = masks[
+             :,
+             padded_mask_offset_top : masks.shape[1] - padded_mask_offset_bottom,
+             padded_mask_offset_left : masks.shape[2] - padded_mask_offset_right,
+         ]
+     else:
+         masks = masks[
+             :, mask_pad_top : mh - mask_pad_bottom, mask_pad_left : mw - mask_pad_right
+         ]
+     masks = (
+         functional.resize(
+             masks,
+             [size_after_pre_processing.height, size_after_pre_processing.width],
+             interpolation=functional.InterpolationMode.BILINEAR,
+         )
+         .gt_(binarization_threshold)
+         .to(dtype=torch.bool)
+     )
+     if static_crop_offset.offset_x > 0 or static_crop_offset.offset_y > 0:
+         mask_canvas = torch.zeros(
+             (
+                 masks.shape[0],
+                 original_size.height,
+                 original_size.width,
+             ),
+             dtype=torch.bool,
+             device=masks.device,
+         )
+         mask_canvas[
+             :,
+             static_crop_offset.offset_y : static_crop_offset.offset_y + masks.shape[1],
+             static_crop_offset.offset_x : static_crop_offset.offset_x + masks.shape[2],
+         ] = masks
+         static_crop_offsets = torch.as_tensor(
+             [
+                 static_crop_offset.offset_x,
+                 static_crop_offset.offset_y,
+                 static_crop_offset.offset_x,
+                 static_crop_offset.offset_y,
+             ],
+             dtype=image_bboxes.dtype,
+             device=image_bboxes.device,
+         )
+         image_bboxes[:, :4].add_(static_crop_offsets)
+         masks = mask_canvas
+     return image_bboxes, masks
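Taken together, a hypothetical end-to-end flow for one segmented image (editor's sketch, not package code; the final call to align_instance_segmentation_results is elided because it depends on the PreProcessingMetadata recorded during pre-processing):

    raw = torch.rand(1, 116, 8400)   # 4 box + 80 class + 32 mask channels
    protos = torch.rand(32, 160, 160)
    dets = run_nms_for_instance_segmentation(raw)[0]
    low_res = preprocess_segmentation_masks(protos, dets[:, 6:])
    low_res = crop_masks_to_boxes(dets[:, :4], low_res, scaling=160 / 640)
    # align_instance_segmentation_results(...) would then remove padding,
    # resize the masks to the pre-processed image size, and binarize them.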