inference-models 0.18.3 (inference_models-0.18.3-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/common/roboflow/pre_processing.py
@@ -0,0 +1,1332 @@
import math
from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import PIL
import torch
from PIL.Image import Image
from skimage import exposure
from torchvision.transforms import Grayscale, functional

from inference_models.entities import ColorFormat, ImageDimensions
from inference_models.errors import ModelRuntimeError
from inference_models.logger import LOGGER
from inference_models.models.common.roboflow.model_packages import (
    AnySizePadding,
    ColorMode,
    ContrastType,
    DivisiblePadding,
    ImagePreProcessing,
    NetworkInputDefinition,
    PreProcessingMetadata,
    ResizeMode,
    StaticCrop,
    StaticCropOffset,
)


def pre_process_network_input(
    images: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
    image_pre_processing: ImagePreProcessing,
    network_input: NetworkInputDefinition,
    target_device: torch.device,
    input_color_format: Optional[ColorFormat] = None,
    image_size_wh: Optional[Union[int, Tuple[int, int]]] = None,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    if network_input.input_channels != 3:
        raise ModelRuntimeError(
            message="`inference` currently does not support Roboflow pre-processing for model inputs with "
            "a number of channels different than 3. Let us know if you need this feature.",
            help_url="https://todo",
        )
    input_color_mode = None
    if input_color_format is not None:
        input_color_mode = ColorMode(input_color_format)
    if isinstance(image_size_wh, (int, float)):
        image_size_wh = int(image_size_wh), int(image_size_wh)
    if isinstance(images, np.ndarray):
        return pre_process_numpy_image(
            image=images,
            image_pre_processing=image_pre_processing,
            network_input=network_input,
            target_device=target_device,
            input_color_mode=input_color_mode,
            image_size_wh=image_size_wh,
        )
    if isinstance(images, torch.Tensor):
        return pre_process_images_tensor(
            images=images,
            image_pre_processing=image_pre_processing,
            network_input=network_input,
            input_color_mode=input_color_mode,
            target_device=target_device,
            image_size_wh=image_size_wh,
        )
    if not isinstance(images, list):
        raise ModelRuntimeError(
            message="Pre-processing supports only np.array or torch.Tensor or list of above.",
            help_url="https://todo",
        )
    if not len(images):
        raise ModelRuntimeError(
            message="Detected empty input to the model", help_url="https://todo"
        )
    if network_input.resize_mode is ResizeMode.FIT_LONGER_EDGE:
        raise ModelRuntimeError(
            message="Model input resize type (fit-longer-edge) cannot be applied uniformly to "
            "all input batch elements - this type of model does not support batched inputs.",
            help_url="https://todo",
        )
    if isinstance(images[0], np.ndarray):
        return pre_process_numpy_images_list(
            images=images,
            image_pre_processing=image_pre_processing,
            network_input=network_input,
            input_color_mode=input_color_mode,
            target_device=target_device,
            image_size_wh=image_size_wh,
        )
    if isinstance(images[0], torch.Tensor):
        return pre_process_images_tensor_list(
            images=images,
            image_pre_processing=image_pre_processing,
            network_input=network_input,
            input_color_mode=input_color_mode,
            target_device=target_device,
            image_size_wh=image_size_wh,
        )
    raise ModelRuntimeError(
        message=f"Detected unknown input batch element: {type(images[0])}",
        help_url="https://todo",
    )

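# Illustrative sketch (not part of the package): assuming `image_pre_processing`
# and `network_input` have already been parsed from a Roboflow model package
# (see model_packages.py; the loading helpers are not shown in this diff), a
# single OpenCV frame could be prepared roughly like this:
#
#     tensor, metadata = pre_process_network_input(
#         images=cv2.imread("frame.jpg"),
#         image_pre_processing=image_pre_processing,
#         network_input=network_input,
#         target_device=torch.device("cuda:0"),
#         input_color_format="bgr",  # assumed literal accepted by ColorMode()
#     )
#
# `metadata` carries the padding, scale and static-crop offsets that downstream
# post-processing can use to map predictions back onto the original frame.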
@torch.inference_mode()
def pre_process_images_tensor(
    images: torch.Tensor,
    image_pre_processing: ImagePreProcessing,
    network_input: NetworkInputDefinition,
    target_device: torch.device,
    input_color_mode: Optional[ColorMode] = None,
    image_size_wh: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    if input_color_mode is None:
        input_color_mode = ColorMode.RGB
    target_dimensions = (
        network_input.training_input_size.width,
        network_input.training_input_size.height,
    )
    if image_size_wh is not None and image_size_wh != target_dimensions:
        if not network_input.dynamic_spatial_size_supported:
            LOGGER.warning(
                f"Requested image size: {image_size_wh} cannot be applied for model input, as the model was trained "
                f"with a fixed input resolution and does not support inputs of a different shape. `image_size_wh` gets ignored."
            )
        elif isinstance(network_input.dynamic_spatial_size_mode, DivisiblePadding):
            target_dimensions = (
                make_the_value_divisible(
                    x=image_size_wh[0], by=network_input.dynamic_spatial_size_mode.value
                ),
                make_the_value_divisible(
                    x=image_size_wh[1], by=network_input.dynamic_spatial_size_mode.value
                ),
            )
        elif isinstance(network_input.dynamic_spatial_size_mode, AnySizePadding):
            target_dimensions = image_size_wh
        else:
            raise ModelRuntimeError(
                message=f"Handler for dynamic spatial mode of type {type(network_input.dynamic_spatial_size_mode)} "
                f"is not implemented.",
                help_url="",
            )
    if images.device != target_device:
        images = images.to(target_device)
    if len(images.shape) == 3:
        images = torch.unsqueeze(images, 0)
    if (
        images.shape[1] != network_input.input_channels
        and images.shape[3] == network_input.input_channels
    ):
        images = images.permute(0, 3, 1, 2)
    original_size = ImageDimensions(width=images.shape[3], height=images.shape[2])
    image, static_crop_offset = apply_pre_processing_to_torch_image(
        image=images,
        image_pre_processing=image_pre_processing,
        network_input_channels=network_input.input_channels,
    )
    if network_input.resize_mode not in TORCH_IMAGES_PREPARATION_HANDLERS:
        raise ModelRuntimeError(
            message=f"Unsupported model input resize mode: {network_input.resize_mode}",
            help_url="https://todo",
        )
    return TORCH_IMAGES_PREPARATION_HANDLERS[network_input.resize_mode](
        image,
        network_input,
        input_color_mode,
        original_size,
        ImageDimensions(width=target_dimensions[0], height=target_dimensions[1]),
        static_crop_offset,
    )

def apply_pre_processing_to_torch_image(
    image: torch.Tensor,
    image_pre_processing: ImagePreProcessing,
    network_input_channels: int,
) -> Tuple[torch.Tensor, StaticCropOffset]:
    static_crop_offset = StaticCropOffset(
        offset_x=0,
        offset_y=0,
        crop_width=image.shape[3],
        crop_height=image.shape[2],
    )
    if image_pre_processing.static_crop and image_pre_processing.static_crop.enabled:
        image, static_crop_offset = apply_static_crop_to_torch_image(
            image=image,
            config=image_pre_processing.static_crop,
        )
    if image_pre_processing.grayscale and image_pre_processing.grayscale.enabled:
        image = Grayscale(num_output_channels=network_input_channels)(image)
    if image_pre_processing.contrast and image_pre_processing.contrast.enabled:
        if (
            image_pre_processing.contrast.type
            not in CONTRAST_ADJUSTMENT_METHODS_FOR_TORCH
        ):
            raise ModelRuntimeError(
                message=f"Unsupported image contrast adjustment type: {image_pre_processing.contrast.type.value}",
                help_url="https://todo",
            )
        image = CONTRAST_ADJUSTMENT_METHODS_FOR_TORCH[
            image_pre_processing.contrast.type
        ](image)
    return image, static_crop_offset


def apply_static_crop_to_torch_image(
    image: torch.Tensor, config: StaticCrop
) -> Tuple[torch.Tensor, StaticCropOffset]:
    width, height = image.shape[3], image.shape[2]
    x_min = int(config.x_min / 100 * width)
    y_min = int(config.y_min / 100 * height)
    x_max = int(config.x_max / 100 * width)
    y_max = int(config.y_max / 100 * height)
    cropped_tensor = image[:, :, y_min:y_max, x_min:x_max]
    offset = StaticCropOffset(
        offset_x=x_min,
        offset_y=y_min,
        crop_width=cropped_tensor.shape[3],
        crop_height=cropped_tensor.shape[2],
    )
    return cropped_tensor, offset


def apply_adaptive_equalization_to_torch_image(image: torch.Tensor) -> torch.Tensor:
    original_device = image.device
    results = []
    for single_image in image:
        single_image_numpy = np.transpose(single_image.cpu().numpy(), (1, 2, 0))
        image = single_image_numpy.astype(np.float32) / 255
        image_adapted = (
            exposure.equalize_adapthist(image, clip_limit=0.03) * 255
        ).astype(np.uint8)
        results.append(torch.from_numpy(image_adapted).to(original_device))
    return torch.stack(results, dim=0).permute(0, 3, 1, 2)


def apply_contrast_stretching_to_torch_image(image: torch.Tensor) -> torch.Tensor:
    original_device = image.device
    results = []
    for single_image in image:
        single_image_numpy = np.transpose(single_image.cpu().numpy(), (1, 2, 0))
        p2 = np.percentile(single_image_numpy, 2)
        p98 = np.percentile(single_image_numpy, 98)
        rescaled_image = exposure.rescale_intensity(
            single_image_numpy, in_range=(p2, p98)
        )
        results.append(torch.from_numpy(rescaled_image).to(original_device))
    return torch.stack(results, dim=0).permute(0, 3, 1, 2)


def apply_histogram_equalization_to_torch_image(image: torch.Tensor) -> torch.Tensor:
    original_device = image.device
    results = []
    for single_image in image:
        single_image_numpy = np.transpose(single_image.cpu().numpy(), (1, 2, 0))
        single_image_numpy = single_image_numpy.astype(np.float32) / 255
        image_equalized = exposure.equalize_hist(single_image_numpy) * 255
        results.append(torch.from_numpy(image_equalized).to(original_device))
    return torch.stack(results, dim=0).permute(0, 3, 1, 2)


CONTRAST_ADJUSTMENT_METHODS_FOR_TORCH = {
    ContrastType.ADAPTIVE_EQUALIZATION: apply_adaptive_equalization_to_torch_image,
    ContrastType.CONTRAST_STRETCHING: apply_contrast_stretching_to_torch_image,
    ContrastType.HISTOGRAM_EQUALIZATION: apply_histogram_equalization_to_torch_image,
}


def handle_tensor_input_preparation_with_stretch(
    image: torch.Tensor,
    network_input: NetworkInputDefinition,
    input_color_mode: ColorMode,
    original_size: ImageDimensions,
    target_size: ImageDimensions,
    static_crop_offset: StaticCropOffset,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    size_after_pre_processing = ImageDimensions(
        height=image.shape[2], width=image.shape[3]
    )
    if image.device.type == "cuda":
        image = image.float()
    image = torch.nn.functional.interpolate(
        image,
        size=[target_size.height, target_size.width],
        mode="bilinear",
    )
    if input_color_mode != network_input.color_mode:
        image = image[:, [2, 1, 0], :, :]
    if network_input.scaling_factor is not None:
        image = image / network_input.scaling_factor
    if network_input.normalization is not None:
        if not image.is_floating_point():
            image = image.to(dtype=torch.float32)
        image = functional.normalize(
            image,
            mean=network_input.normalization[0],
            std=network_input.normalization[1],
        )
    metadata = PreProcessingMetadata(
        pad_left=0,
        pad_top=0,
        pad_right=0,
        pad_bottom=0,
        original_size=original_size,
        size_after_pre_processing=size_after_pre_processing,
        inference_size=target_size,
        scale_width=target_size.width / size_after_pre_processing.width,
        scale_height=target_size.height / size_after_pre_processing.height,
        static_crop_offset=static_crop_offset,
    )
    return image.contiguous(), [metadata] * image.shape[0]


def handle_torch_input_preparation_with_letterbox(
    image: torch.Tensor,
    network_input: NetworkInputDefinition,
    input_color_mode: ColorMode,
    original_size: ImageDimensions,
    target_size: ImageDimensions,
    static_crop_offset: StaticCropOffset,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    original_height, original_width = image.shape[2], image.shape[3]
    size_after_pre_processing = ImageDimensions(
        height=original_height, width=original_width
    )
    scale_w = target_size.width / original_width
    scale_h = target_size.height / original_height
    scale = min(scale_w, scale_h)
    new_width = int(original_width * scale)
    new_height = int(original_height * scale)
    pad_top = int((target_size.height - new_height) / 2)
    pad_left = int((target_size.width - new_width) / 2)
    if image.device.type == "cuda":
        image = image.float()
    image = torch.nn.functional.interpolate(
        image,
        [new_height, new_width],
        mode="bilinear",
    )
    if input_color_mode != network_input.color_mode:
        image = image[:, [2, 1, 0], :, :]
    final_batch = torch.full(
        (
            image.shape[0],
            image.shape[1],
            target_size.height,
            target_size.width,
        ),
        network_input.padding_value or 0,
        dtype=torch.float32,
        device=image.device,
    )
    final_batch[
        :, :, pad_top : pad_top + new_height, pad_left : pad_left + new_width
    ] = image
    pad_right = target_size.width - pad_left - new_width
    pad_bottom = target_size.height - pad_top - new_height
    metadata = PreProcessingMetadata(
        pad_left=pad_left,
        pad_top=pad_top,
        pad_right=pad_right,
        pad_bottom=pad_bottom,
        original_size=original_size,
        size_after_pre_processing=size_after_pre_processing,
        inference_size=target_size,
        scale_width=scale,
        scale_height=scale,
        static_crop_offset=static_crop_offset,
    )
    if network_input.scaling_factor is not None:
        final_batch = final_batch / network_input.scaling_factor
    if network_input.normalization is not None:
        if not final_batch.is_floating_point():
            final_batch = final_batch.to(dtype=torch.float32)
        final_batch = functional.normalize(
            final_batch,
            mean=network_input.normalization[0],
            std=network_input.normalization[1],
        )
    return final_batch.contiguous(), [metadata] * final_batch.shape[0]


def handle_torch_input_preparation_with_center_crop(
    image: torch.Tensor,
    network_input: NetworkInputDefinition,
    input_color_mode: ColorMode,
    original_size: ImageDimensions,
    target_size: ImageDimensions,
    static_crop_offset: StaticCropOffset,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    if input_color_mode != network_input.color_mode:
        image = image[:, [2, 1, 0], :, :]
    size_after_pre_processing = ImageDimensions(
        height=image.shape[2], width=image.shape[3]
    )
    padding_ltrb = [0, 0, 0, 0]
    if (
        target_size.width > size_after_pre_processing.width
        or target_size.height > size_after_pre_processing.height
    ):
        padding_ltrb = [
            (
                (target_size.width - size_after_pre_processing.width) // 2
                if target_size.width > size_after_pre_processing.width
                else 0
            ),
            (
                (target_size.height - size_after_pre_processing.height) // 2
                if target_size.height > size_after_pre_processing.height
                else 0
            ),
            (
                (target_size.width - size_after_pre_processing.width + 1) // 2
                if target_size.width > size_after_pre_processing.width
                else 0
            ),
            (
                (target_size.height - size_after_pre_processing.height + 1) // 2
                if target_size.height > size_after_pre_processing.height
                else 0
            ),
        ]
        image = functional.pad(image, padding_ltrb, fill=0)
    crop_ltrb = [0, 0, 0, 0]
    if target_size.width != image.shape[3] or target_size.height != image.shape[2]:
        crop_top = int(round((image.shape[2] - target_size.height) / 2.0))
        crop_bottom = image.shape[2] - target_size.height - crop_top
        crop_left = int(round((image.shape[3] - target_size.width) / 2.0))
        crop_right = image.shape[3] - target_size.width - crop_left
        crop_ltrb = [crop_left, crop_top, crop_right, crop_bottom]
        image = functional.crop(
            image, crop_top, crop_left, target_size.height, target_size.width
        )
    if target_size.height > size_after_pre_processing.height:
        reported_padding_top = padding_ltrb[1]
        reported_padding_bottom = padding_ltrb[3]
    else:
        reported_padding_top = -crop_ltrb[1]
        reported_padding_bottom = -crop_ltrb[3]
    if target_size.width > size_after_pre_processing.width:
        reported_padding_left = padding_ltrb[0]
        reported_padding_right = padding_ltrb[2]
    else:
        reported_padding_left = -crop_ltrb[0]
        reported_padding_right = -crop_ltrb[2]
    image_metadata = PreProcessingMetadata(
        pad_left=reported_padding_left,
        pad_top=reported_padding_top,
        pad_right=reported_padding_right,
        pad_bottom=reported_padding_bottom,
        original_size=original_size,
        size_after_pre_processing=size_after_pre_processing,
        inference_size=target_size,
        scale_width=1.0,
        scale_height=1.0,
        static_crop_offset=static_crop_offset,
    )
    if network_input.scaling_factor is not None:
        image = image / network_input.scaling_factor
    if network_input.normalization is not None:
        if not image.is_floating_point():
            image = image.to(dtype=torch.float32)
        image = functional.normalize(
            image,
            mean=network_input.normalization[0],
            std=network_input.normalization[1],
        )
    return image.contiguous(), [image_metadata] * image.shape[0]


def handle_torch_input_preparation_fitting_longer_edge(
    image: torch.Tensor,
    network_input: NetworkInputDefinition,
    input_color_mode: ColorMode,
    original_size: ImageDimensions,
    target_size: ImageDimensions,
    static_crop_offset: StaticCropOffset,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    original_height, original_width = image.shape[2], image.shape[3]
    size_after_pre_processing = ImageDimensions(
        height=original_height, width=original_width
    )
    scale_ox = target_size.width / size_after_pre_processing.width
    scale_oy = target_size.height / size_after_pre_processing.height
    if scale_ox < scale_oy:
        actual_target_width = target_size.width
        actual_target_height = round(scale_ox * size_after_pre_processing.height)
    else:
        actual_target_width = round(scale_oy * size_after_pre_processing.width)
        actual_target_height = target_size.height
    actual_target_size = ImageDimensions(
        height=actual_target_height,
        width=actual_target_width,
    )
    if image.device.type == "cuda":
        image = image.float()
    image = torch.nn.functional.interpolate(
        image,
        [actual_target_size.height, actual_target_size.width],
        mode="bilinear",
    )
    if input_color_mode != network_input.color_mode:
        image = image[:, [2, 1, 0], :, :]
    image_metadata = PreProcessingMetadata(
        pad_left=0,
        pad_top=0,
        pad_right=0,
        pad_bottom=0,
        original_size=original_size,
        size_after_pre_processing=size_after_pre_processing,
        inference_size=actual_target_size,
        scale_width=actual_target_size.width / size_after_pre_processing.width,
        scale_height=actual_target_size.height / size_after_pre_processing.height,
        static_crop_offset=static_crop_offset,
    )
    if network_input.scaling_factor is not None:
        image = image / network_input.scaling_factor
    if network_input.normalization is not None:
        if not image.is_floating_point():
            image = image.to(dtype=torch.float32)
        image = functional.normalize(
            image,
            mean=network_input.normalization[0],
            std=network_input.normalization[1],
        )
    return image.contiguous(), [image_metadata] * image.shape[0]


TORCH_IMAGES_PREPARATION_HANDLERS = {
    ResizeMode.STRETCH_TO: handle_tensor_input_preparation_with_stretch,
    ResizeMode.LETTERBOX: handle_torch_input_preparation_with_letterbox,
    ResizeMode.CENTER_CROP: handle_torch_input_preparation_with_center_crop,
    ResizeMode.FIT_LONGER_EDGE: handle_torch_input_preparation_fitting_longer_edge,
    ResizeMode.LETTERBOX_REFLECT_EDGES: handle_torch_input_preparation_with_letterbox,
}

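# Worked example of the letterbox arithmetic used above (illustrative only):
# fitting a 1280x720 image into a 640x640 input gives
#     scale      = min(640 / 1280, 640 / 720) = 0.5
#     new size   = 640x360
#     pad_top    = int((640 - 360) / 2) = 140, pad_left = 0
#     pad_bottom = 640 - 140 - 360 = 140, pad_right = 0
# and PreProcessingMetadata records exactly these values, with
# scale_width == scale_height == 0.5.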
@torch.inference_mode()
def pre_process_images_tensor_list(
    images: List[torch.Tensor],
    image_pre_processing: ImagePreProcessing,
    network_input: NetworkInputDefinition,
    target_device: torch.device,
    input_color_mode: Optional[ColorMode] = None,
    image_size_wh: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    if network_input.resize_mode not in TORCH_LIST_IMAGES_PREPARATION_HANDLERS:
        raise ModelRuntimeError(
            message=f"Unsupported model input resize mode: {network_input.resize_mode}",
            help_url="https://todo",
        )
    if input_color_mode is None:
        input_color_mode = ColorMode.RGB
    target_dimensions = (
        network_input.training_input_size.width,
        network_input.training_input_size.height,
    )
    if image_size_wh is not None and image_size_wh != target_dimensions:
        if not network_input.dynamic_spatial_size_supported:
            LOGGER.warning(
                f"Requested image size: {image_size_wh} cannot be applied for model input, as the model was trained "
                f"with a fixed input resolution and does not support inputs of a different shape. `image_size_wh` gets ignored."
            )
        elif isinstance(network_input.dynamic_spatial_size_mode, DivisiblePadding):
            target_dimensions = (
                make_the_value_divisible(
                    x=image_size_wh[0], by=network_input.dynamic_spatial_size_mode.value
                ),
                make_the_value_divisible(
                    x=image_size_wh[1], by=network_input.dynamic_spatial_size_mode.value
                ),
            )
        elif isinstance(network_input.dynamic_spatial_size_mode, AnySizePadding):
            target_dimensions = image_size_wh
        else:
            raise ModelRuntimeError(
                message=f"Handler for dynamic spatial mode of type {type(network_input.dynamic_spatial_size_mode)} "
                f"is not implemented.",
                help_url="",
            )
    images, static_crop_offsets, original_sizes = (
        apply_pre_processing_to_list_of_torch_image(
            images=images,
            image_pre_processing=image_pre_processing,
            network_input_channels=network_input.input_channels,
            target_device=target_device,
        )
    )
    return TORCH_LIST_IMAGES_PREPARATION_HANDLERS[network_input.resize_mode](
        images,
        network_input,
        input_color_mode,
        original_sizes,
        ImageDimensions(width=target_dimensions[0], height=target_dimensions[1]),
        static_crop_offsets,
        target_device,
    )


def apply_pre_processing_to_list_of_torch_image(
    images: List[torch.Tensor],
    image_pre_processing: ImagePreProcessing,
    network_input_channels: int,
    target_device: torch.device,
) -> Tuple[List[torch.Tensor], List[StaticCropOffset], List[ImageDimensions]]:
    result_images, result_offsets, original_sizes = [], [], []
    for image in images:
        if len(image.shape) != 3:
            raise ModelRuntimeError(
                message="When providing List[torch.Tensor] as input, model requires tensors to have 3 dimensions.",
                help_url="https://todo",
            )
        image = image.to(target_device)
        if image.shape[0] != 3 and image.shape[-1] == 3:
            image = image.permute(2, 0, 1)
        original_sizes.append(
            ImageDimensions(height=image.shape[1], width=image.shape[2])
        )
        result_image, result_offset = apply_pre_processing_to_torch_image(
            image=image.unsqueeze(0),
            image_pre_processing=image_pre_processing,
            network_input_channels=network_input_channels,
        )
        result_images.append(result_image)
        result_offsets.append(result_offset)
    return result_images, result_offsets, original_sizes


def handle_tensor_list_input_preparation_with_stretch(
    images: List[torch.Tensor],
    network_input: NetworkInputDefinition,
    input_color_mode: ColorMode,
    original_sizes: List[ImageDimensions],
    target_size: ImageDimensions,
    static_crop_offsets: List[StaticCropOffset],
    target_device: torch.device,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    processed = []
    images_metadata = []
    for img, offset, original_size in zip(images, static_crop_offsets, original_sizes):
        size_after_pre_processing = ImageDimensions(
            height=img.shape[2], width=img.shape[3]
        )
        if input_color_mode != network_input.color_mode:
            img = img[:, [2, 1, 0], :, :]
        if img.device.type == "cuda":
            img = img.float()
        img = torch.nn.functional.interpolate(
            img,
            size=[target_size.height, target_size.width],
            mode="bilinear",
        )
        if network_input.scaling_factor is not None:
            img = img / network_input.scaling_factor
        if network_input.normalization is not None:
            if not img.is_floating_point():
                img = img.to(dtype=torch.float32)
            img = functional.normalize(
                img,
                mean=network_input.normalization[0],
                std=network_input.normalization[1],
            )
        processed.append(img.contiguous())
        image_metadata = PreProcessingMetadata(
            pad_left=0,
            pad_top=0,
            pad_right=0,
            pad_bottom=0,
            original_size=original_size,
            size_after_pre_processing=size_after_pre_processing,
            inference_size=target_size,
            scale_width=target_size.width / size_after_pre_processing.width,
            scale_height=target_size.height / size_after_pre_processing.height,
            static_crop_offset=offset,
        )
        images_metadata.append(image_metadata)
    return torch.concat(processed, dim=0).contiguous(), images_metadata


def handle_tensor_list_input_preparation_with_letterbox(
    images: List[torch.Tensor],
    network_input: NetworkInputDefinition,
    input_color_mode: ColorMode,
    original_sizes: List[ImageDimensions],
    target_size: ImageDimensions,
    static_crop_offsets: List[StaticCropOffset],
    target_device: torch.device,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    num_images = len(images)
    final_batch = torch.full(
        (num_images, 3, target_size.height, target_size.width),
        network_input.padding_value or 0,
        dtype=torch.float32,
        device=target_device,
    )
    original_shapes = torch.tensor(
        [[img.shape[2], img.shape[3]] for img in images], dtype=torch.float32
    )
    scale_w = target_size.width / original_shapes[:, 1]
    scale_h = target_size.height / original_shapes[:, 0]
    scales = torch.minimum(scale_w, scale_h)
    new_ws = (original_shapes[:, 1] * scales).int()
    new_hs = (original_shapes[:, 0] * scales).int()
    pad_tops = ((target_size.height - new_hs) / 2).int()
    pad_lefts = ((target_size.width - new_ws) / 2).int()
    images_metadata = []
    for i in range(num_images):
        img = images[i]
        if len(img.shape) != 4:
            raise ModelRuntimeError(
                message="When providing List[torch.Tensor] as input, model requires tensors to have 3 dimensions.",
                help_url="https://todo",
            )
        original_size = original_sizes[i]
        size_after_pre_processing = ImageDimensions(
            height=img.shape[2], width=img.shape[3]
        )
        if input_color_mode != network_input.color_mode:
            img = img[:, [2, 1, 0], :, :]
        new_h_i, new_w_i = new_hs[i].item(), new_ws[i].item()
        if img.device.type == "cuda":
            img = img.float()
        img = torch.nn.functional.interpolate(
            img,
            size=[new_h_i, new_w_i],
            mode="bilinear",
        )
        pad_top_i, pad_left_i = pad_tops[i].item(), pad_lefts[i].item()
        final_batch[
            i, :, pad_top_i : pad_top_i + new_h_i, pad_left_i : pad_left_i + new_w_i
        ] = img
        pad_right = target_size.width - pad_left_i - new_w_i
        pad_bottom = target_size.height - pad_top_i - new_h_i
        image_metadata = PreProcessingMetadata(
            pad_left=pad_left_i,
            pad_top=pad_top_i,
            pad_right=pad_right,
            pad_bottom=pad_bottom,
            original_size=original_size,
            size_after_pre_processing=size_after_pre_processing,
            inference_size=target_size,
            scale_width=scales[i].item(),
            scale_height=scales[i].item(),
            static_crop_offset=static_crop_offsets[i],
        )
        images_metadata.append(image_metadata)
    if network_input.scaling_factor is not None:
        final_batch = final_batch / network_input.scaling_factor
    if network_input.normalization:
        if not final_batch.is_floating_point():
            final_batch = final_batch.to(dtype=torch.float32)
        final_batch = functional.normalize(
            final_batch,
            mean=network_input.normalization[0],
            std=network_input.normalization[1],
        )
    return final_batch.contiguous(), images_metadata


def handle_tensor_list_input_preparation_with_center_crop(
    images: List[torch.Tensor],
    network_input: NetworkInputDefinition,
    input_color_mode: ColorMode,
    original_sizes: List[ImageDimensions],
    target_size: ImageDimensions,
    static_crop_offsets: List[StaticCropOffset],
    target_device: torch.device,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    result_tensors, result_metadata = [], []
    for image, offset, original_size in zip(
        images, static_crop_offsets, original_sizes
    ):
        if len(image.shape) != 4:
            # TODO!
            raise ModelRuntimeError(
                message="When providing List[torch.Tensor] as input, model requires tensors to have 3 dimensions.",
                help_url="https://todo",
            )
        image = image.to(target_device)
        if (
            image.shape[1] != network_input.input_channels
            and image.shape[3] == network_input.input_channels
        ):
            image = image.permute(0, 3, 1, 2)
        tensor, metadata = handle_torch_input_preparation_with_center_crop(
            image=image,
            network_input=network_input,
            input_color_mode=input_color_mode,
            original_size=original_size,
            target_size=target_size,
            static_crop_offset=offset,
        )
        result_tensors.append(tensor)
        result_metadata.append(metadata[0])
    return torch.concat(result_tensors, dim=0), result_metadata


TORCH_LIST_IMAGES_PREPARATION_HANDLERS = {
    ResizeMode.STRETCH_TO: handle_tensor_list_input_preparation_with_stretch,
    ResizeMode.LETTERBOX: handle_tensor_list_input_preparation_with_letterbox,
    ResizeMode.CENTER_CROP: handle_tensor_list_input_preparation_with_center_crop,
    ResizeMode.LETTERBOX_REFLECT_EDGES: handle_tensor_list_input_preparation_with_letterbox,
}

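# Note: FIT_LONGER_EDGE is absent from the list-input handlers above.
# pre_process_network_input() rejects batched (list) inputs for that resize
# mode before dispatching here, since the target resolution would differ per
# image and the results could not be stacked into a single batch tensor.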
def pre_process_numpy_images_list(
    images: List[np.ndarray],
    image_pre_processing: ImagePreProcessing,
    network_input: NetworkInputDefinition,
    target_device: torch.device,
    input_color_mode: Optional[ColorMode] = None,
    image_size_wh: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    result_tensors, result_metadata = [], []
    for image in images:
        tensor, metadata = pre_process_numpy_image(
            image=image,
            image_pre_processing=image_pre_processing,
            network_input=network_input,
            target_device=target_device,
            input_color_mode=input_color_mode,
            image_size_wh=image_size_wh,
        )
        result_tensors.append(tensor)
        result_metadata.extend(metadata)
    return torch.concat(result_tensors, dim=0).contiguous(), result_metadata


@torch.inference_mode()
def pre_process_numpy_image(
    image: np.ndarray,
    image_pre_processing: ImagePreProcessing,
    network_input: NetworkInputDefinition,
    target_device: torch.device,
    input_color_mode: Optional[ColorMode] = None,
    image_size_wh: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    if input_color_mode is None:
        input_color_mode = ColorMode.BGR
    target_dimensions = (
        network_input.training_input_size.width,
        network_input.training_input_size.height,
    )
    if image_size_wh is not None and image_size_wh != target_dimensions:
        if not network_input.dynamic_spatial_size_supported:
            LOGGER.warning(
                f"Requested image size: {image_size_wh} cannot be applied for model input, as the model was trained "
                f"with a fixed input resolution and does not support inputs of a different shape. `image_size_wh` gets ignored."
            )
        elif isinstance(network_input.dynamic_spatial_size_mode, DivisiblePadding):
            target_dimensions = (
                make_the_value_divisible(
                    x=image_size_wh[0], by=network_input.dynamic_spatial_size_mode.value
                ),
                make_the_value_divisible(
                    x=image_size_wh[1], by=network_input.dynamic_spatial_size_mode.value
                ),
            )
        elif isinstance(network_input.dynamic_spatial_size_mode, AnySizePadding):
            target_dimensions = image_size_wh
        else:
            raise ModelRuntimeError(
                message=f"Handler for dynamic spatial mode of type {type(network_input.dynamic_spatial_size_mode)} "
                f"is not implemented.",
                help_url="",
            )
    original_size = ImageDimensions(width=image.shape[1], height=image.shape[0])
    image, static_crop_offset = apply_pre_processing_to_numpy_image(
        image=image,
        image_pre_processing=image_pre_processing,
        network_input_channels=network_input.input_channels,
        input_color_mode=input_color_mode,
    )
    if network_input.resize_mode not in NUMPY_IMAGES_PREPARATION_HANDLERS:
        raise ModelRuntimeError(
            message=f"Unsupported model input resize mode: {network_input.resize_mode}",
            help_url="https://todo",
        )
    return NUMPY_IMAGES_PREPARATION_HANDLERS[network_input.resize_mode](
        image,
        network_input,
        target_device,
        input_color_mode,
        original_size,
        ImageDimensions(width=target_dimensions[0], height=target_dimensions[1]),
        static_crop_offset,
    )


def apply_pre_processing_to_numpy_image(
    image: np.ndarray,
    image_pre_processing: ImagePreProcessing,
    network_input_channels: int,
    input_color_mode: Optional[ColorMode] = None,
) -> Tuple[np.ndarray, StaticCropOffset]:
    if input_color_mode is None:
        input_color_mode = ColorMode.BGR
    static_crop_offset = StaticCropOffset(
        offset_x=0,
        offset_y=0,
        crop_width=image.shape[1],
        crop_height=image.shape[0],
    )
    if image_pre_processing.static_crop and image_pre_processing.static_crop.enabled:
        image, static_crop_offset = apply_static_crop_to_numpy_image(
            image=image,
            config=image_pre_processing.static_crop,
        )
    if image_pre_processing.grayscale and image_pre_processing.grayscale.enabled:
        mode = (
            cv2.COLOR_BGR2GRAY
            if input_color_mode is ColorMode.BGR
            else cv2.COLOR_RGB2GRAY
        )
        image = cv2.cvtColor(image, mode)
        image = np.stack([image] * network_input_channels, axis=2)
    if image_pre_processing.contrast and image_pre_processing.contrast.enabled:
        if (
            image_pre_processing.contrast.type
            not in CONTRAST_ADJUSTMENT_METHODS_FOR_NUMPY
        ):
            raise ModelRuntimeError(
                message=f"Unsupported image contrast adjustment type: {image_pre_processing.contrast.type.value}",
                help_url="https://todo",
            )
        image = CONTRAST_ADJUSTMENT_METHODS_FOR_NUMPY[
            image_pre_processing.contrast.type
        ](image)
    return image, static_crop_offset


def apply_static_crop_to_numpy_image(
    image: np.ndarray, config: StaticCrop
) -> Tuple[np.ndarray, StaticCropOffset]:
    width, height = image.shape[1], image.shape[0]
    x_min = int(config.x_min / 100 * width)
    y_min = int(config.y_min / 100 * height)
    x_max = int(config.x_max / 100 * width)
    y_max = int(config.y_max / 100 * height)
    result_image = image[y_min:y_max, x_min:x_max]
    return result_image, StaticCropOffset(
        offset_x=x_min,
        offset_y=y_min,
        crop_width=result_image.shape[1],
        crop_height=result_image.shape[0],
    )

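# Example of the static-crop convention (illustrative only): StaticCrop bounds
# are expressed as percentages of the image size, so x_min=25, x_max=75,
# y_min=0, y_max=100 on a 640x480 frame selects pixels [160:480] in x and
# [0:480] in y, and the returned StaticCropOffset(offset_x=160, offset_y=0,
# crop_width=320, crop_height=480) lets later stages shift coordinates back.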
def apply_adaptive_equalization_to_numpy_image(image: np.ndarray) -> np.ndarray:
    image = image.astype(np.float32) / 255
    image_adapted = exposure.equalize_adapthist(image, clip_limit=0.03) * 255
    return image_adapted.astype(np.uint8)


def apply_contrast_stretching_to_numpy_image(image: np.ndarray) -> np.ndarray:
    p2 = np.percentile(image, 2)
    p98 = np.percentile(image, 98)
    return exposure.rescale_intensity(image, in_range=(p2, p98))


def apply_histogram_equalization_to_numpy_image(image: np.ndarray) -> np.ndarray:
    image = image.astype(np.float32) / 255
    image_equalized = exposure.equalize_hist(image) * 255
    return image_equalized.astype(np.uint8)


CONTRAST_ADJUSTMENT_METHODS_FOR_NUMPY = {
    ContrastType.ADAPTIVE_EQUALIZATION: apply_adaptive_equalization_to_numpy_image,
    ContrastType.CONTRAST_STRETCHING: apply_contrast_stretching_to_numpy_image,
    ContrastType.HISTOGRAM_EQUALIZATION: apply_histogram_equalization_to_numpy_image,
}


def handle_numpy_input_preparation_with_stretch(
    image: np.ndarray,
    network_input: NetworkInputDefinition,
    target_device: torch.device,
    input_color_mode: ColorMode,
    original_size: ImageDimensions,
    target_size: ImageDimensions,
    static_crop_offset: StaticCropOffset,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    size_after_pre_processing = ImageDimensions(
        height=image.shape[0], width=image.shape[1]
    )
    resized_image = cv2.resize(image, (target_size.width, target_size.height))
    tensor = torch.from_numpy(resized_image).to(device=target_device)
    tensor = torch.unsqueeze(tensor, 0)
    tensor = tensor.permute(0, 3, 1, 2)
    if input_color_mode != network_input.color_mode:
        tensor = tensor[:, [2, 1, 0], :, :]
    if network_input.scaling_factor is not None:
        tensor = tensor / network_input.scaling_factor
    if network_input.normalization:
        if not tensor.is_floating_point():
            tensor = tensor.to(dtype=torch.float32)
        tensor = functional.normalize(
            tensor,
            mean=network_input.normalization[0],
            std=network_input.normalization[1],
        )
    image_metadata = PreProcessingMetadata(
        pad_left=0,
        pad_top=0,
        pad_right=0,
        pad_bottom=0,
        original_size=original_size,
        size_after_pre_processing=size_after_pre_processing,
        inference_size=target_size,
        scale_width=target_size.width / size_after_pre_processing.width,
        scale_height=target_size.height / size_after_pre_processing.height,
        static_crop_offset=static_crop_offset,
    )
    return tensor.contiguous(), [image_metadata]


def handle_numpy_input_preparation_with_letterbox(
    image: np.ndarray,
    network_input: NetworkInputDefinition,
    target_device: torch.device,
    input_color_mode: ColorMode,
    original_size: ImageDimensions,
    target_size: ImageDimensions,
    static_crop_offset: StaticCropOffset,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    padding_value = network_input.padding_value or 0
    original_height, original_width = image.shape[0], image.shape[1]
    size_after_pre_processing = ImageDimensions(
        height=original_height, width=original_width
    )
    scale_w = target_size.width / original_width
    scale_h = target_size.height / original_height
    scale = min(scale_w, scale_h)
    new_width = int(original_width * scale)
    new_height = int(original_height * scale)
    pad_top = int((target_size.height - new_height) / 2)
    pad_left = int((target_size.width - new_width) / 2)
    scaled_image = cv2.resize(image, (new_width, new_height))
    scaled_image_tensor = torch.from_numpy(scaled_image).to(target_device)
    scaled_image_tensor = scaled_image_tensor.permute(2, 0, 1)
    final_batch = torch.full(
        (
            1,
            image.shape[2],
            target_size.height,
            target_size.width,
        ),
        padding_value,
        dtype=torch.float32,
        device=target_device,
    )
    final_batch[
        0, :, pad_top : pad_top + new_height, pad_left : pad_left + new_width
    ] = scaled_image_tensor
    if input_color_mode != network_input.color_mode:
        final_batch = final_batch[:, [2, 1, 0], :, :]
    pad_right = target_size.width - pad_left - new_width
    pad_bottom = target_size.height - pad_top - new_height
    image_metadata = PreProcessingMetadata(
        pad_left=pad_left,
        pad_top=pad_top,
        pad_right=pad_right,
        pad_bottom=pad_bottom,
        original_size=original_size,
        size_after_pre_processing=size_after_pre_processing,
        inference_size=target_size,
        scale_width=scale,
        scale_height=scale,
        static_crop_offset=static_crop_offset,
    )
    if network_input.scaling_factor is not None:
        final_batch = final_batch / network_input.scaling_factor
    if network_input.normalization is not None:
        if not final_batch.is_floating_point():
            final_batch = final_batch.to(dtype=torch.float32)
        final_batch = functional.normalize(
            final_batch,
            mean=network_input.normalization[0],
            std=network_input.normalization[1],
        )
    return final_batch.contiguous(), [image_metadata]


def handle_numpy_input_preparation_with_center_crop(
    image: np.ndarray,
    network_input: NetworkInputDefinition,
    target_device: torch.device,
    input_color_mode: ColorMode,
    original_size: ImageDimensions,
    target_size: ImageDimensions,
    static_crop_offset: StaticCropOffset,
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
    original_height, original_width = image.shape[0], image.shape[1]
    size_after_pre_processing = ImageDimensions(
        height=original_height, width=original_width
    )
    canvas = np.zeros((target_size.height, target_size.width, 3), dtype=np.uint8)
    canvas_ox_padding = max(target_size.width - image.shape[1], 0)
    canvas_padding_left = canvas_ox_padding // 2
    canvas_padding_right = canvas_ox_padding - canvas_padding_left
    canvas_oy_padding = max(target_size.height - image.shape[0], 0)
    canvas_padding_top = canvas_oy_padding // 2
    canvas_padding_bottom = canvas_oy_padding - canvas_padding_top
    original_image_ox_padding = max(image.shape[1] - target_size.width, 0)
    original_image_padding_left = original_image_ox_padding // 2
    original_image_padding_right = (
        original_image_ox_padding - original_image_padding_left
    )
    original_image_oy_padding = max(image.shape[0] - target_size.height, 0)
    original_image_padding_top = original_image_oy_padding // 2
    original_image_padding_bottom = (
        original_image_oy_padding - original_image_padding_top
    )
    canvas[
        canvas_padding_top : canvas.shape[0] - canvas_padding_bottom,
        canvas_padding_left : canvas.shape[1] - canvas_padding_right,
    ] = image[
        original_image_padding_top : image.shape[0] - original_image_padding_bottom,
        original_image_padding_left : image.shape[1] - original_image_padding_right,
    ]
    if canvas.shape[0] > image.shape[0]:
        reported_padding_top = canvas_padding_top
        reported_padding_bottom = canvas_padding_bottom
    else:
        reported_padding_top = -original_image_padding_top
        reported_padding_bottom = -original_image_padding_bottom
    if canvas.shape[1] > image.shape[1]:
        reported_padding_left = canvas_padding_left
        reported_padding_right = canvas_padding_right
    else:
        reported_padding_left = -original_image_padding_left
        reported_padding_right = -original_image_padding_right
    image_metadata = PreProcessingMetadata(
        pad_left=reported_padding_left,
        pad_top=reported_padding_top,
        pad_right=reported_padding_right,
        pad_bottom=reported_padding_bottom,
        original_size=original_size,
        size_after_pre_processing=size_after_pre_processing,
        inference_size=target_size,
        scale_width=1.0,
        scale_height=1.0,
        static_crop_offset=static_crop_offset,
    )
    tensor = torch.from_numpy(canvas).to(device=target_device)
    tensor = torch.unsqueeze(tensor, 0)
    tensor = tensor.permute(0, 3, 1, 2)
    if input_color_mode != network_input.color_mode:
        tensor = tensor[:, [2, 1, 0], :, :]
    if network_input.scaling_factor is not None:
        tensor = tensor / network_input.scaling_factor
    if network_input.normalization:
        if not tensor.is_floating_point():
            tensor = tensor.to(dtype=torch.float32)
        tensor = functional.normalize(
            tensor,
            mean=network_input.normalization[0],
            std=network_input.normalization[1],
        )
    return tensor.contiguous(), [image_metadata]

1163
+ def handle_numpy_input_preparation_fitting_longer_edge(
+     image: np.ndarray,
+     network_input: NetworkInputDefinition,
+     target_device: torch.device,
+     input_color_mode: ColorMode,
+     original_size: ImageDimensions,
+     target_size: ImageDimensions,
+     static_crop_offset: StaticCropOffset,
+ ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
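+     """
+     Rescales a numpy image so that it fits inside `target_size` while keeping the
+     aspect ratio (the binding, i.e. relatively longer, edge matches the target
+     exactly). No padding is applied, so the reported inference size may be smaller
+     than `target_size` along one axis.
+     """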
+     original_height, original_width = image.shape[0], image.shape[1]
+     size_after_pre_processing = ImageDimensions(
+         height=original_height, width=original_width
+     )
+     scale_ox = target_size.width / size_after_pre_processing.width
+     scale_oy = target_size.height / size_after_pre_processing.height
+     if scale_ox < scale_oy:
+         actual_target_width = target_size.width
+         actual_target_height = round(scale_ox * size_after_pre_processing.height)
+     else:
+         actual_target_width = round(scale_oy * size_after_pre_processing.width)
+         actual_target_height = target_size.height
+     actual_target_size = ImageDimensions(
+         height=actual_target_height,
+         width=actual_target_width,
+     )
+     scaled_image = cv2.resize(
+         image, (actual_target_size.width, actual_target_size.height)
+     )
+     image_metadata = PreProcessingMetadata(
+         pad_left=0,
+         pad_top=0,
+         pad_right=0,
+         pad_bottom=0,
+         original_size=original_size,
+         size_after_pre_processing=size_after_pre_processing,
+         inference_size=actual_target_size,
+         scale_width=actual_target_size.width / size_after_pre_processing.width,
+         scale_height=actual_target_size.height / size_after_pre_processing.height,
+         static_crop_offset=static_crop_offset,
+     )
+     tensor = torch.from_numpy(scaled_image).to(device=target_device)
+     tensor = torch.unsqueeze(tensor, 0)
+     tensor = tensor.permute(0, 3, 1, 2)
+     if input_color_mode != network_input.color_mode:
+         tensor = tensor[:, [2, 1, 0], :, :]
+     if network_input.scaling_factor is not None:
+         tensor = tensor / network_input.scaling_factor
+     if network_input.normalization:
+         if not tensor.is_floating_point():
+             tensor = tensor.to(dtype=torch.float32)
+         tensor = functional.normalize(
+             tensor,
+             mean=network_input.normalization[0],
+             std=network_input.normalization[1],
+         )
+     return tensor.contiguous(), [image_metadata]
+
+
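+ # Dispatch table mapping each ResizeMode to the matching numpy pre-processing handler
+ # defined above; note that LETTERBOX_REFLECT_EDGES currently resolves to the same
+ # handler as LETTERBOX. Callers presumably look up the handler by the resize mode
+ # declared in the network input definition and invoke it with the image plus the
+ # remaining pre-processing context (the exact call site is not part of this module).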
+ NUMPY_IMAGES_PREPARATION_HANDLERS = {
+     ResizeMode.STRETCH_TO: handle_numpy_input_preparation_with_stretch,
+     ResizeMode.LETTERBOX: handle_numpy_input_preparation_with_letterbox,
+     ResizeMode.CENTER_CROP: handle_numpy_input_preparation_with_center_crop,
+     ResizeMode.FIT_LONGER_EDGE: handle_numpy_input_preparation_fitting_longer_edge,
+     ResizeMode.LETTERBOX_REFLECT_EDGES: handle_numpy_input_preparation_with_letterbox,
+ }
+
+
+ def extract_input_images_dimensions(
+     images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+ ) -> List[ImageDimensions]:
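+     """
+     Returns the height/width of every input image. numpy arrays are treated as HWC,
+     torch tensors as CHW (a single 3D tensor is promoted to a batch of one); any
+     other input type raises ModelRuntimeError.
+     """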
+     if isinstance(images, np.ndarray):
+         return [ImageDimensions(height=images.shape[0], width=images.shape[1])]
+     if isinstance(images, torch.Tensor):
+         if len(images.shape) == 3:
+             images = torch.unsqueeze(images, dim=0)
+         image_dimensions = []
+         for image in images:
+             image_dimensions.append(
+                 ImageDimensions(height=image.shape[1], width=image.shape[2])
+             )
+         return image_dimensions
+     if not isinstance(images, list):
+         raise ModelRuntimeError(
+             message="Pre-processing supports only np.array or torch.Tensor or list of above.",
+             help_url="https://todo",
+         )
+     if not len(images):
+         raise ModelRuntimeError(
+             message="Detected empty input to the model", help_url="https://todo"
+         )
+     if isinstance(images[0], np.ndarray):
+         return [ImageDimensions(height=i.shape[0], width=i.shape[1]) for i in images]
+     if isinstance(images[0], torch.Tensor):
+         image_dimensions = []
+         for image in images:
+             image_dimensions.append(
+                 ImageDimensions(height=image.shape[1], width=image.shape[2])
+             )
+         return image_dimensions
+     raise ModelRuntimeError(
+         message=f"Detected unknown input batch element: {type(images[0])}",
+         help_url="https://todo",
+     )
+
+
+ def images_to_pillow(
+     images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+     input_color_format: Optional[ColorFormat] = None,
+     model_color_format: ColorFormat = "rgb",
+ ) -> Tuple[List[Image], List[ImageDimensions]]:
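+     """
+     Converts numpy inputs (assumed HWC, default colour format "bgr") or torch inputs
+     (assumed CHW, default "rgb") into PIL images in `model_color_format`, returning
+     the images together with their original dimensions.
+     """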
+     if isinstance(images, np.ndarray):
+         input_color_format = input_color_format or "bgr"
+         if input_color_format != model_color_format:
+             images = images[:, :, ::-1]
+         h, w = images.shape[:2]
+         return [PIL.Image.fromarray(images)], [ImageDimensions(height=h, width=w)]
+     if isinstance(images, torch.Tensor):
+         input_color_format = input_color_format or "rgb"
+         if len(images.shape) == 3:
+             images = torch.unsqueeze(images, dim=0)
+         if input_color_format != model_color_format:
+             images = images[:, [2, 1, 0], :, :]
+         result = []
+         dimensions = []
+         for image in images:
+             np_image = image.permute(1, 2, 0).cpu().numpy()
+             result.append(PIL.Image.fromarray(np_image))
+             dimensions.append(
+                 ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+             )
+         return result, dimensions
+     if not isinstance(images, list):
+         raise ModelRuntimeError(
+             message="Pre-processing supports only np.array or torch.Tensor or list of above.",
+             help_url="https://todo",
+         )
+     if not len(images):
+         raise ModelRuntimeError(
+             message="Detected empty input to the model", help_url="https://todo"
+         )
+     if isinstance(images[0], np.ndarray):
+         input_color_format = input_color_format or "bgr"
+         if input_color_format != model_color_format:
+             images = [i[:, :, ::-1] for i in images]
+         dimensions = [
+             ImageDimensions(height=i.shape[0], width=i.shape[1]) for i in images
+         ]
+         images = [PIL.Image.fromarray(i) for i in images]
+         return images, dimensions
+     if isinstance(images[0], torch.Tensor):
+         result = []
+         dimensions = []
+         input_color_format = input_color_format or "rgb"
+         for image in images:
+             if input_color_format != model_color_format:
+                 image = image[[2, 1, 0], :, :]
+             np_image = image.permute(1, 2, 0).cpu().numpy()
+             result.append(PIL.Image.fromarray(np_image))
+             dimensions.append(
+                 ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+             )
+         return result, dimensions
+     raise ModelRuntimeError(
+         message=f"Detected unknown input batch element: {type(images[0])}",
+         help_url="https://todo",
+     )
+
+
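+ # Rounds `x` up to the nearest multiple of `by`; e.g. make_the_value_divisible(1023, by=32)
+ # returns 1024. Presumably used to snap dynamic input dimensions to a stride the network
+ # accepts (assumed typical usage, not shown in this file).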
+ def make_the_value_divisible(x: int, by: int) -> int:
+     return math.ceil(x / by) * by