matrice_inference-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of matrice-inference has been flagged as potentially problematic.

Files changed (37)
  1. matrice_inference/__init__.py +72 -0
  2. matrice_inference/py.typed +0 -0
  3. matrice_inference/server/__init__.py +23 -0
  4. matrice_inference/server/inference_interface.py +176 -0
  5. matrice_inference/server/model/__init__.py +1 -0
  6. matrice_inference/server/model/model_manager.py +274 -0
  7. matrice_inference/server/model/model_manager_wrapper.py +550 -0
  8. matrice_inference/server/model/triton_model_manager.py +290 -0
  9. matrice_inference/server/model/triton_server.py +1248 -0
  10. matrice_inference/server/proxy_interface.py +371 -0
  11. matrice_inference/server/server.py +1004 -0
  12. matrice_inference/server/stream/__init__.py +0 -0
  13. matrice_inference/server/stream/app_deployment.py +228 -0
  14. matrice_inference/server/stream/consumer_worker.py +201 -0
  15. matrice_inference/server/stream/frame_cache.py +127 -0
  16. matrice_inference/server/stream/inference_worker.py +163 -0
  17. matrice_inference/server/stream/post_processing_worker.py +230 -0
  18. matrice_inference/server/stream/producer_worker.py +147 -0
  19. matrice_inference/server/stream/stream_pipeline.py +451 -0
  20. matrice_inference/server/stream/utils.py +23 -0
  21. matrice_inference/tmp/abstract_model_manager.py +58 -0
  22. matrice_inference/tmp/aggregator/__init__.py +18 -0
  23. matrice_inference/tmp/aggregator/aggregator.py +330 -0
  24. matrice_inference/tmp/aggregator/analytics.py +906 -0
  25. matrice_inference/tmp/aggregator/ingestor.py +438 -0
  26. matrice_inference/tmp/aggregator/latency.py +597 -0
  27. matrice_inference/tmp/aggregator/pipeline.py +968 -0
  28. matrice_inference/tmp/aggregator/publisher.py +431 -0
  29. matrice_inference/tmp/aggregator/synchronizer.py +594 -0
  30. matrice_inference/tmp/batch_manager.py +239 -0
  31. matrice_inference/tmp/overall_inference_testing.py +338 -0
  32. matrice_inference/tmp/triton_utils.py +638 -0
  33. matrice_inference-0.1.2.dist-info/METADATA +28 -0
  34. matrice_inference-0.1.2.dist-info/RECORD +37 -0
  35. matrice_inference-0.1.2.dist-info/WHEEL +5 -0
  36. matrice_inference-0.1.2.dist-info/licenses/LICENSE.txt +21 -0
  37. matrice_inference-0.1.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1248 @@
1
+ """Module providing triton_server functionality."""
2
+
3
+ import os
4
+ import zipfile
5
+ import subprocess
6
+ import tempfile
7
+ import asyncio
8
+ import shutil
9
+ import logging
10
+ import threading
11
+ import shlex
12
+ from typing import Tuple, Optional, Any, Dict, Union, List
13
+ import importlib.util
14
+ from matrice_common.utils import dependencies_check
15
+
16
+ # TRITON_DOCKER_IMAGE = "nvcr.io/nvidia/tritonserver:24.08-py3"
17
+ TRITON_DOCKER_IMAGE = "nvcr.io/nvidia/tritonserver:23.08-py3"
18
+ BASE_PATH = "./model_repository"
19
+
20
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
21
+
22
+ class TritonServer:
23
+ def __init__(
24
+ self,
25
+ model_name: str,
26
+ model_path: str,
27
+ runtime_framework: str,
28
+ input_size: Union[int, List[int]] = 224,
29
+ num_classes: int = 10,
30
+ dynamic_batching: bool = False,
31
+ num_model_instances: int = 1,
32
+ max_batch_size: int = 8,
33
+ connection_protocol: str = "rest",
34
+ is_yolo: bool = False,
35
+ is_ocr: bool = False,
36
+ use_trt_accelerator: bool = False,
37
+ **kwargs,
38
+ ):
39
+ """Initialize the Triton server.
40
+
41
+ Args:
42
+ model_name: Name of the model (used for Triton model repository).
43
+ model_path: Path to the model file on the local filesystem.
44
+ runtime_framework: Framework of the model ('onnx', 'pytorch', 'torchscript', 'yolo', 'tensorrt', 'openvino').
45
+ input_size: Input size for the model (int for square images or [height, width]).
46
+ num_classes: Number of output classes.
47
+ dynamic_batching: Enable dynamic batching for the model.
48
+ num_model_instances: Number of model instances to deploy.
49
+ max_batch_size: Maximum batch size for inference.
50
+ connection_protocol: Protocol for Triton server ('rest' or 'grpc').
51
+ use_trt_accelerator: Enable TensorRT acceleration for inference.
52
+
53
+ is_yolo: Boolean indicating if the model is a YOLO model.
54
+ is_ocr: Boolean indicating if the model is an OCR model.
55
+ """
56
+ if not dependencies_check("torch"):
57
+ raise ImportError("PyTorch is required but not installed")
58
+ import torch
59
+
60
+ if not model_name:
61
+ raise ValueError("model_name must be provided")
62
+ if not model_path:
63
+ raise ValueError("model_path must be provided")
64
+ if not os.path.exists(model_path):
65
+ raise FileNotFoundError(f"Model file not found at: {model_path}")
66
+
67
+ logging.info("Initializing TritonServer")
68
+
69
+ self.model_name = model_name
70
+ self.model_path = os.path.abspath(model_path)
71
+ self.runtime_framework = runtime_framework.lower()
72
+ self.connection_protocol = connection_protocol.lower()
73
+
74
+ if isinstance(input_size, (int, float)):
75
+ self.input_size = [int(input_size), int(input_size)]
76
+ elif isinstance(input_size, (list, tuple)):
77
+ if len(input_size) == 2: # (H, W)
78
+ self.input_size = [int(input_size[0]), int(input_size[1])]
79
+ elif len(input_size) == 3: # (H, W, C)
80
+ self.input_size = [int(input_size[0]), int(input_size[1])]
81
+ elif len(input_size) == 4: # (N, C, H, W)
82
+ self.input_size = [int(input_size[-2]), int(input_size[-1])]
83
+ else:
84
+ logging.warning("Unexpected input_size length: %s, using default [224, 224]", input_size)
85
+ self.input_size = [224, 224]
86
+
87
+ self.num_classes = num_classes
88
+ self.dynamic_batching = dynamic_batching
89
+ self.num_model_instances = num_model_instances
90
+ self.max_batch_size = max_batch_size
91
+ self.use_trt_accelerator = use_trt_accelerator
92
+
93
+ self.is_yolo = is_yolo or self.runtime_framework == "yolo"
94
+ self.is_ocr = is_ocr
95
+ self.input_name = "images" if self.is_yolo else "input"
96
+ self.output_name = "output0" if self.is_yolo else "output"
97
+
98
+ self.gpus_count = torch.cuda.device_count()
99
+ self.config_params = {}
100
+
101
+ logging.info("Model name: %s", self.model_name)
102
+ logging.info("Model path: %s", self.model_path)
103
+ logging.info("Runtime framework: %s", self.runtime_framework)
104
+ logging.info("Using connection protocol: %s", self.connection_protocol)
105
+ logging.info("Input size: %s", self.input_size)
106
+ logging.info("Number of classes: %s", self.num_classes)
107
+ logging.info("Found %s GPUs available for inference", self.gpus_count)
108
+
109
+ self.docker_pull_process = subprocess.Popen(
110
+ [
111
+ shutil.which("docker"),
112
+ "pull",
113
+ TRITON_DOCKER_IMAGE,
114
+ ],
115
+ stdout=subprocess.PIPE,
116
+ stderr=subprocess.PIPE,
117
+ text=True,
118
+ )
119
+
120
+ def check_triton_docker_image(self):
121
+ """Check if docker image download is complete and wait for it to finish"""
122
+ logging.info("Checking docker image download status")
123
+ stdout, stderr = self.docker_pull_process.communicate()
124
+ if self.docker_pull_process.returncode == 0:
125
+ logging.info(f"Docker image {TRITON_DOCKER_IMAGE} downloaded successfully")
126
+ else:
127
+ error_msg = stderr  # Popen was created with text=True, so stderr is already a str
128
+ logging.error(f"Docker pull failed with return code {self.docker_pull_process.returncode}")
129
+ logging.error("Error message: %s", error_msg)
130
+ raise RuntimeError(f"Docker pull failed: {error_msg}")
131
+
132
+ def prepare_model(self, model_version_dir: str) -> None:
133
+ """Prepare the model file for Triton Inference Server.
134
+
135
+ Copies model from self.model_path to model_version_dir and converts if necessary
136
+ to the format expected by Triton (model.onnx, model.xml, model.plan).
137
+
138
+ Args:
139
+ model_version_dir: Directory to store the model file
140
+ (e.g., '/models/<model_name>/1').
141
+ """
142
+ try:
143
+ runtime_framework = self.runtime_framework
144
+ logging.info("Preparing model with runtime framework: %s", runtime_framework)
145
+ logging.info("Source model path: %s", self.model_path)
146
+
147
+ if runtime_framework not in ["onnx", "pytorch", "torchscript", "yolo", "tensorrt", "openvino"]:
148
+ logging.error("Runtime framework '%s' not supported. Supported: %s",
149
+ runtime_framework, ["onnx", "pytorch", "torchscript", "yolo", "tensorrt", "openvino"])
150
+ raise ValueError(f"Unsupported runtime framework: {runtime_framework}")
151
+
152
+ os.makedirs(model_version_dir, exist_ok=True)
153
+
154
+ # 1. ONNX - copy to model.onnx
155
+ if runtime_framework == "onnx":
156
+ model_file = os.path.join(model_version_dir, "model.onnx")
157
+ if os.path.abspath(self.model_path) != os.path.abspath(model_file):
158
+ shutil.copy2(self.model_path, model_file)
159
+ logging.info("Copied ONNX model to: %s", model_file)
160
+ else:
161
+ logging.info("Model path is already correct: %s", model_file)
162
+ self._verify_onnx_model(model_file)
163
+
164
+ # 2. PyTorch/TorchScript/Yolo - export to ONNX
165
+ elif runtime_framework in ["pytorch", "torchscript", "yolo"]:
166
+ model_file = os.path.join(model_version_dir, "model.onnx")
167
+ img_chw = (3, self.input_size[0], self.input_size[1])
168
+ logging.info("Converting %s model to ONNX with input shape: %s", runtime_framework, img_chw)
169
+ self.to_onnx(self.model_path, model_file, (1, *img_chw))
170
+ logging.info("Exported ONNX model to: %s", model_file)
171
+ self._verify_onnx_model(model_file)
172
+
173
+ # 3. TensorRT - copy to model.plan
174
+ elif runtime_framework == "tensorrt":
175
+ model_file = os.path.join(model_version_dir, "model.plan")
176
+ shutil.copy2(self.model_path, model_file)
177
+ logging.info("Copied TensorRT model to: %s", model_file)
178
+ self.runtime_framework = "tensorrt"
179
+
180
+ # 4. OpenVINO - extract ZIP or copy files
181
+ elif runtime_framework == "openvino":
182
+ if self.model_path.endswith('.zip'):
183
+ logging.info("Extracting OpenVINO ZIP to: %s", model_version_dir)
184
+ with zipfile.ZipFile(self.model_path, "r") as zip_ref:
185
+ zip_ref.extractall(model_version_dir)
186
+ model_file = os.path.join(model_version_dir, "model.xml")
187
+ model_bin_file = os.path.join(model_version_dir, "model.bin")
188
+ else:
189
+ model_file = os.path.join(model_version_dir, "model.xml")
190
+ model_bin_file = os.path.join(model_version_dir, "model.bin")
191
+
192
+ shutil.copy2(self.model_path, model_file)
193
+
194
+ source_bin = self.model_path.replace('.xml', '.bin')
195
+ if os.path.exists(source_bin):
196
+ shutil.copy2(source_bin, model_bin_file)
197
+ else:
198
+ raise RuntimeError(f"OpenVINO model.bin not found at {source_bin}")
199
+
200
+ if not os.path.exists(model_file):
201
+ logging.error("OpenVINO model.xml not found at %s", model_file)
202
+ raise RuntimeError(f"OpenVINO model.xml not found at {model_file}")
203
+ if not os.path.exists(model_bin_file):
204
+ logging.error("OpenVINO model.bin not found at %s", model_bin_file)
205
+ raise RuntimeError(f"OpenVINO model.bin not found at {model_bin_file}")
206
+ logging.info("Prepared OpenVINO model: %s", model_file)
207
+
208
+ logging.info("Model preparation completed successfully")
209
+ except Exception as e:
210
+ logging.error("Model preparation failed: %s", str(e), exc_info=True)
211
+ raise
212
+
213
+ def to_onnx(self, checkpoint_path: str, onnx_path: str, input_shape: Tuple[int, int, int, int]) -> None:
214
+ """Export PyTorch or YOLO checkpoint to ONNX."""
215
+ try:
216
+ runtime_framework = self.runtime_framework.lower()
217
+ logging.info("Exporting %s model to ONNX on CPU", runtime_framework)
218
+
219
+ if runtime_framework == "yolo":
220
+ try:
221
+ if dependencies_check("ultralytics"):
222
+ from ultralytics import YOLO
223
+ logging.info("Using Ultralytics YOLO for ONNX export")
224
+ model = YOLO(checkpoint_path)
225
+ # Use opset 12 for broad YOLOv8 ONNX export compatibility
226
+ export_path = model.export(format="onnx", imgsz=input_shape[2], dynamic=True, opset=12)
227
+ if export_path != onnx_path:
228
+ shutil.move(export_path, onnx_path)
229
+ logging.info("Exported YOLO model to ONNX: %s", onnx_path)
230
+ return
231
+ else:
232
+ logging.warning("Ultralytics not available; falling back to PyTorch export for YOLO")
233
+ except Exception as e:
234
+ logging.warning("Ultralytics YOLO export failed: %s; trying PyTorch export", str(e))
235
+
236
+ import torch
237
+ model = torch.load(checkpoint_path, map_location="cpu")
238
+ # TODO: Add support for model_state_dict
239
+ model.eval()
240
+ dummy_input = torch.randn(*input_shape)
241
+ torch.onnx.export(
242
+ model,
243
+ dummy_input,
244
+ onnx_path,
245
+ opset_version=17,
246
+ input_names=["input__0"],
247
+ output_names=["output__0"],
248
+ dynamic_axes={"input__0": {0: "batch"}, "output__0": {0: "batch"}},
249
+ )
250
+ logging.info("Exported PyTorch model to ONNX: %s", onnx_path)
251
+ except Exception as e:
252
+ logging.error("Failed to export to ONNX: %s", str(e), exc_info=True)
253
+ raise
254
+
255
+ def _verify_onnx_model(self, onnx_path: str):
256
+ """Verify that the ONNX model is valid"""
257
+ try:
258
+ if dependencies_check("onnx"):
259
+ import onnx
260
+ model = onnx.load(onnx_path)
261
+ onnx.checker.check_model(model)
262
+ logging.info(f"ONNX model verification successful: {onnx_path}")
263
+ else:
264
+ logging.warning("ONNX library not available for model verification")
265
+ except Exception as e:
266
+ logging.error(f"ONNX model verification failed: %s", str(e))
267
+ raise ValueError(f"Invalid ONNX model at {onnx_path}: {str(e)}")
268
+
269
+ def create_model_repository(self):
270
+ """Create the model repository directory structure"""
271
+ try:
272
+ model_version = "1"
273
+ model_dir = os.path.join(BASE_PATH, self.model_name)
274
+ version_dir = os.path.join(model_dir, str(model_version))
275
+ logging.info("Creating model repository structure:")
276
+ logging.info("Base path: %s", BASE_PATH)
277
+ logging.info("Model directory: %s", model_dir)
278
+ logging.info("Version directory: %s", version_dir)
279
+ os.makedirs(version_dir, exist_ok=True)
280
+ logging.info("Model repository directories created successfully")
281
+ return model_dir, version_dir
282
+ except Exception as e:
283
+ logging.error(
284
+ "Failed to create model repository: %s",
285
+ str(e),
286
+ exc_info=True,
287
+ )
288
+ raise
289
+
290
+ def write_config_file(
291
+ self,
292
+ model_dir: str,
293
+ max_batch_size: int = 8,
294
+ num_model_instances: int = 1,
295
+ image_size: List[int] = [224, 224],
296
+ num_classes: int = 10,
297
+ input_data_type: str = "TYPE_FP32",
298
+ output_data_type: str = "TYPE_FP32",
299
+ dynamic_batching: bool = False,
300
+ preferred_batch_size: list = [2, 4, 8],
301
+ max_queue_delay_microseconds: int = 100,
302
+ input_pinned_memory: bool = True,
303
+ output_pinned_memory: bool = True,
304
+ **kwargs,
305
+ ):
306
+ """Write the model configuration file for Triton Inference Server."""
307
+ try:
308
+ runtime_framework = self.runtime_framework.lower()
309
+ logging.info("Starting to write Triton config file for framework: %s", runtime_framework)
310
+
311
+ if runtime_framework == "tensorrt":
312
+ platform = "tensorrt_plan"
313
+ model_filename = "model.plan"
314
+ elif runtime_framework in ["pytorch", "torchscript", "yolo", "onnx"]:
315
+ platform = "onnxruntime_onnx"
316
+ model_filename = "model.onnx"
317
+ else:
318
+ platform = "openvino"
319
+ model_filename = "model.xml"
320
+ logging.info("Using %s backend with model file: %s", platform, model_filename)
321
+
322
+ config_path = os.path.join(model_dir, "config.pbtxt")
323
+ logging.info("Writing config to: %s", config_path)
324
+
325
+ # Map ONNX TensorProto element types to Triton data types
326
+ onnx_to_triton_dtype = {
327
+ 1: "TYPE_FP32", # FLOAT
328
+ 2: "TYPE_UINT8", # UINT8
329
+ 3: "TYPE_INT8", # INT8
330
+ 4: "TYPE_UINT16", # UINT16
331
+ 5: "TYPE_INT16", # INT16
332
+ 6: "TYPE_INT32", # INT32
333
+ 7: "TYPE_INT64", # INT64
334
+ 8: "TYPE_STRING", # STRING
335
+ 9: "TYPE_BOOL", # BOOL
336
+ 10: "TYPE_FP16", # HALF
337
+ 11: "TYPE_FP64", # DOUBLE
338
+ 12: "TYPE_UINT32",# UINT32
339
+ 13: "TYPE_UINT64",# UINT64
340
+ }
341
+
342
+ if platform == "onnxruntime_onnx":
343
+ model_file = os.path.join(model_dir, "1", "model.onnx")
344
+ if os.path.exists(model_file) and dependencies_check("onnx"):
345
+ import onnx
346
+ import onnx.numpy_helper
347
+ model = onnx.load(model_file)
348
+ graph = model.graph
349
+
350
+ inputs = []
351
+ for inp in graph.input:
352
+ shape = [d.dim_value if d.HasField("dim_value") else -1 for d in inp.type.tensor_type.shape.dim]
353
+ dtype_id = inp.type.tensor_type.elem_type
354
+ dtype = onnx_to_triton_dtype.get(dtype_id, "TYPE_FP32") # Fallback to FP32 if unknown
355
+ input_dims = shape[1:] if max_batch_size > 0 and len(shape) > 1 else shape  # Triton dims exclude the batch dim when max_batch_size > 0
+ inputs.append((inp.name, dtype, input_dims))
356
+
357
+ outputs = []
358
+ for out in graph.output:
359
+ shape = [d.dim_value if d.HasField("dim_value") else -1 for d in out.type.tensor_type.shape.dim]
360
+ if self.is_yolo and out.name not in ("output0", "output__0"):  # keep only the final YOLO head; its name depends on the export path
361
+ continue # Skip intermediate YOLO outputs
362
+ output_shape = shape[1:] if max_batch_size > 0 and len(shape) > 1 else shape
363
+ dtype_id = out.type.tensor_type.elem_type
364
+ dtype = onnx_to_triton_dtype.get(dtype_id, "TYPE_FP32") # Fallback to FP32 if unknown
365
+ outputs.append((out.name, dtype, output_shape))
366
+
367
+ logging.info("ONNX inputs: %s", inputs)
368
+ logging.info("ONNX outputs: %s", outputs)
369
+ else:
370
+ # Fallback when the onnx package or the exported model file is unavailable
371
+ inputs = [(self.input_name, input_data_type, [3, image_size[0], image_size[1]])]
372
+ outputs = [(self.output_name, output_data_type, [-1, -1] if self.is_yolo else [num_classes])]
373
+ elif platform == "tensorrt_plan":
374
+ # TensorRT: Use OCR-specific configuration if is_ocr is True
375
+ if self.is_ocr:
376
+ # NOTE: input/output shapes are hardcoded for the OCR model
377
+ inputs = [("input", "TYPE_UINT8", [64, 128, 3])]
378
+ outputs = [("Identity:0", "TYPE_FP32", [9, 37])]
379
+ else:
380
+ # Fallback for YOLO or other TensorRT models
381
+ inputs = [(self.input_name, input_data_type, [3, image_size[0], image_size[1]])]
382
+ outputs = [(self.output_name, output_data_type, [-1, -1] if self.is_yolo else [num_classes])]
383
+ else:
384
+ # OpenVINO fallback
385
+ inputs = [(self.input_name, input_data_type, [3, image_size[0], image_size[1]])]
386
+ outputs = [(self.output_name, output_data_type, [num_classes])]
387
+
388
+ logging.info("Final inputs for config: %s", inputs)
389
+ logging.info("Final outputs for config: %s", outputs)
390
+
391
+ config_content = f'name: "{self.model_name}"\n'
392
+ config_content += f'platform: "{platform}"\n'
393
+ config_content += f'max_batch_size: {max_batch_size}\n'
394
+
395
+ # Input section
396
+ config_content += 'input [\n'
397
+ for name, dtype, shape in inputs:
398
+ config_content += ' {\n'
399
+ config_content += f' name: "{name}"\n'
400
+ config_content += f' data_type: {dtype}\n'
401
+ config_content += f' dims: [{", ".join(str(dim) for dim in shape)}]\n'
402
+ config_content += ' }\n'
403
+ config_content += ']\n'
404
+
405
+ # Output section
406
+ config_content += 'output [\n'
407
+ for name, dtype, shape in outputs:
408
+ config_content += ' {\n'
409
+ config_content += f' name: "{name}"\n'
410
+ config_content += f' data_type: {dtype}\n'
411
+ config_content += f' dims: [{", ".join(str(dim) for dim in shape)}]\n'
412
+ config_content += ' }\n'
413
+ config_content += ']\n'
414
+
415
+ # Instance group
416
+ if num_model_instances > 1 or self.gpus_count > 0:
417
+ device_type = "KIND_GPU" if self.gpus_count > 0 else "KIND_CPU"
418
+ logging.info("Adding instance group configuration for %s %s instances", num_model_instances, device_type)
419
+ config_content += 'instance_group [\n'
420
+ config_content += ' {\n'
421
+ config_content += f' count: {num_model_instances}\n'
422
+ config_content += f' kind: {device_type}\n'
423
+ config_content += ' }\n'
424
+ config_content += ']\n'
425
+
426
+ # Dynamic batching
427
+ if dynamic_batching:
428
+ logging.info("Adding dynamic batching configuration")
429
+ valid_pref_sizes = [bs for bs in preferred_batch_size if bs <= max_batch_size]
430
+ if valid_pref_sizes:
431
+ config_content += 'dynamic_batching {\n'
432
+ config_content += f' preferred_batch_size: [{", ".join(str(bs) for bs in valid_pref_sizes)}]\n'
433
+ config_content += f' max_queue_delay_microseconds: {max_queue_delay_microseconds}\n'
434
+ config_content += '}\n'
435
+
436
+ # Optimization settings
437
+ if input_pinned_memory or output_pinned_memory or self.gpus_count > 0:
438
+ config_content += 'optimization {\n'
439
+ if input_pinned_memory:
440
+ config_content += ' input_pinned_memory {\n'
441
+ config_content += ' enable: true\n'
442
+ config_content += ' }\n'
443
+ if output_pinned_memory:
444
+ config_content += ' output_pinned_memory {\n'
445
+ config_content += ' enable: true\n'
446
+ config_content += ' }\n'
447
+ if self.gpus_count > 0 and self.use_trt_accelerator:
448
+ config_content += ' execution_accelerators {\n'
449
+ config_content += ' gpu_execution_accelerator {\n'
450
+ config_content += ' name: "tensorrt"\n'
451
+ config_content += ' parameters {\n'
452
+ config_content += ' key: "precision_mode"\n'
453
+ config_content += ' value: "FP16"\n'
454
+ config_content += ' }\n'
455
+ config_content += ' parameters {\n'
456
+ config_content += ' key: "max_workspace_size_bytes"\n'
457
+ config_content += ' value: "1073741824"\n'
458
+ config_content += ' }\n'
459
+ config_content += ' parameters {\n'
460
+ config_content += ' key: "trt_engine_cache_enable"\n'
461
+ config_content += ' value: "1"\n'
462
+ config_content += ' }\n'
463
+ config_content += ' parameters {\n'
464
+ config_content += ' key: "trt_engine_cache_path"\n'
465
+ config_content += f' value: "/models/{self.model_name}/1"\n'
466
+ config_content += ' }\n'
467
+ config_content += ' }\n'
468
+ config_content += ' }\n'
469
+ config_content += '}\n'
470
+
471
+ with open(config_path, "w") as f:
472
+ f.write(config_content)
473
+
474
+ logging.info("Config file written successfully")
475
+ logging.info("Config content:\n%s", config_content)
476
+
477
+ except Exception as e:
478
+ logging.error("Failed to write config file: %s", str(e), exc_info=True)
479
+ raise
480
+
481
+
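For illustration, with the defaults above (an ONNX classifier, max_batch_size 8, a single instance, no GPU, dynamic batching off), the generated config.pbtxt would look roughly like this; the model name is a placeholder:

  name: "my_model"
  platform: "onnxruntime_onnx"
  max_batch_size: 8
  input [
    {
      name: "input"
      data_type: TYPE_FP32
      dims: [3, 224, 224]
    }
  ]
  output [
    {
      name: "output"
      data_type: TYPE_FP32
      dims: [10]
    }
  ]
  optimization {
    input_pinned_memory {
      enable: true
    }
    output_pinned_memory {
      enable: true
    }
  }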
482
+ def get_config_params(self):
483
+ """Get configuration parameters for Triton config file"""
484
+ try:
485
+ logging.info("Retrieving configuration parameters")
486
+
487
+ logging.info("Using input size: %s", self.input_size)
488
+ logging.info("Using number of classes: %s", self.num_classes)
489
+
490
+ params = {
491
+ "max_batch_size": self.max_batch_size,
492
+ "num_instances": self.num_model_instances,
493
+ "image_size": self.input_size,
494
+ "num_classes": self.num_classes,
495
+ "input_data_type": "TYPE_FP32",
496
+ "output_data_type": "TYPE_FP32",
497
+ "dynamic_batching": self.dynamic_batching,
498
+ "preferred_batch_size": [1, 2, 4, 8],
499
+ "max_queue_delay_microseconds": 100,
500
+ "input_pinned_memory": True,
501
+ "output_pinned_memory": True,
502
+ }
503
+
504
+ logging.debug("Final configuration parameters: %s", params)
505
+ return params
506
+
507
+ except Exception as e:
508
+ logging.error(
509
+ "Failed to get configuration parameters: %s",
510
+ str(e),
511
+ exc_info=True,
512
+ )
513
+ raise
514
+
515
+ def start_server(self, internal_port: int = 8000):
516
+ """Start the Triton Inference Server
517
+
518
+ Args:
519
+ internal_port: Port to expose the server on
520
+ """
521
+ gpu_option = "--gpus=all " if self.gpus_count > 0 else ""
522
+ logging.debug("Starting Triton server with GPU option: '%s'", gpu_option.strip())
523
+ triton_port = 8000 if self.connection_protocol == 'rest' else 8001
524
+ port_mapping = f"-p{internal_port}:{triton_port}"
525
+ start_triton_server = f"docker run {gpu_option}--rm {port_mapping} -v {os.path.abspath(BASE_PATH)}:/models --label model_name={self.model_name} {TRITON_DOCKER_IMAGE} tritonserver --model-repository=/models "
526
+ logging.info("Checking docker image download status before starting server")
527
+ self.check_triton_docker_image()
528
+ try:
529
+ logging.info(
530
+ "Starting Triton server with command: %s",
531
+ start_triton_server,
532
+ )
533
+ self.process = subprocess.Popen(
534
+ shlex.split(start_triton_server),
535
+ stdout=subprocess.PIPE,
536
+ stderr=subprocess.PIPE,
537
+ text=True,
538
+ )
539
+
540
+ def log_output():
541
+ while True:
542
+ stdout_line = self.process.stdout.readline()
543
+ stderr_line = self.process.stderr.readline()
544
+ if stdout_line:
545
+ logging.info(stdout_line.strip())
546
+ if stderr_line:
547
+ logging.info(stderr_line.strip())
548
+ if stdout_line == "" and stderr_line == "" and self.process.poll() is not None:
549
+ break
550
+
551
+ threading.Thread(target=log_output, daemon=False).start()
552
+ logging.info(
553
+ "Triton server started successfully on port %s",
554
+ internal_port,
555
+ )
556
+ return self.process
557
+ except Exception as e:
558
+ logging.error(
559
+ "Failed to start Triton server: %s",
560
+ str(e),
561
+ exc_info=True,
562
+ )
563
+ raise
564
+
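For reference, with the defaults above and no GPU, the command string assembled by start_server resolves to roughly the following (the host path and model name are placeholders):

  docker run --rm -p8000:8000 -v /abs/path/to/model_repository:/models --label model_name=my_model nvcr.io/nvidia/tritonserver:23.08-py3 tritonserver --model-repository=/models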
565
+ def setup(self, internal_port: int = 8000):
566
+ """Setup the Triton server with the provided model.
567
+
568
+ Args:
569
+ internal_port: Port to expose the server on
570
+ """
571
+ try:
572
+ logging.info("Beginning Triton server setup")
573
+ logging.info("Step 1: Creating model repository")
574
+ self.model_dir, self.version_dir = self.create_model_repository()
575
+ logging.info("Step 2: Preparing model")
576
+ self.prepare_model(self.version_dir)
577
+ logging.info("Step 3: Getting configuration parameters")
578
+ self.config_params = self.get_config_params()
579
+ logging.info("Step 4: Writing configuration file")
580
+ self.write_config_file(
581
+ self.model_dir,
582
+ **self.config_params,
583
+ )
584
+ logging.info("Step 5: Starting Triton server")
585
+ self.process = self.start_server(internal_port)
586
+ logging.info("Triton server setup completed successfully")
587
+ return self.process
588
+ except Exception as e:
589
+ logging.error(
590
+ "Triton server setup failed: %s",
591
+ str(e),
592
+ exc_info=True,
593
+ )
594
+ raise
595
+
596
+
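A minimal usage sketch of the server half above; the model name, path, and port are hypothetical, not taken from the package documentation:

  # Hypothetical example: serve a local ONNX classifier through Triton.
  server = TritonServer(
      model_name="my_classifier",            # assumed model name
      model_path="/tmp/my_classifier.onnx",  # assumed local ONNX file
      runtime_framework="onnx",
      input_size=224,
      num_classes=10,
      connection_protocol="rest",
  )
  process = server.setup(internal_port=8000)  # builds the repo, writes config.pbtxt, starts the Docker container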
597
+ """Module providing inference_utils functionality for FastAPI and Triton inference."""
598
+
599
+ from PIL import Image
600
+ import httpx
601
+ import logging
602
+ from typing import Optional, Dict, Union, Any, Tuple, List
603
+ from datetime import datetime, timezone
604
+ from io import BytesIO
605
+ import numpy as np
606
+ import cv2
607
+ import torch
608
+ import torchvision
609
+ import os
615
+
616
+ class TritonInference:
617
+ """Class for making Triton inference requests."""
618
+
619
+ def __init__(
620
+ self,
621
+ server_type: str,
622
+ model_name: str,
623
+ internal_port: int = 80,
624
+ internal_host: str = "localhost",
625
+ task_type: str = "detection",
626
+ runtime_framework: str = "onnx",
627
+ is_yolo: bool = False,
628
+ is_ocr: bool = False,
629
+ input_size: Union[int, List[int]] = (224, 224)
630
+ ):
631
+ """Initialize Triton inference client.
632
+
633
+ Args:
634
+ server_type: Type of server (grpc/rest)
635
+ model_name: Name of model to use
636
+ internal_port: Port number for internal API
637
+ internal_host: Hostname for internal API
638
+ task_type: Type of task (e.g., detection)
639
+ runtime_framework: Framework used for the model (e.g., onnx)
640
+ is_yolo: Boolean indicating if the model is YOLO
641
+ is_ocr: Boolean indicating if the model is an OCR model
642
+ input_size: Input size for the model (int or [height, width])
643
+ """
644
+ self.model_name = model_name
645
+ self.task_type = task_type
646
+ self.runtime_framework = runtime_framework
647
+ self.is_yolo = is_yolo
648
+ self.is_ocr = is_ocr
649
+ self.input_size = [input_size, input_size] if isinstance(input_size, int) else input_size
650
+ self.ocr_config = {
651
+ "color_mode": "rgba",
652
+ "keep_aspect_ratio": True,
653
+ "interpolation": "linear",
654
+ "padding_color": (114, 114, 114, 255),
655
+ }
656
+ self.data_type_mapping = {
657
+ 2: "TYPE_UINT8",
658
+ 6: "TYPE_INT8",
659
+ 7: "TYPE_INT16",
660
+ 8: "TYPE_INT32",
661
+ 9: "TYPE_INT64",
662
+ 10: "TYPE_FP16",
663
+ 11: "TYPE_FP32",
664
+ 12: "TYPE_FP64",
665
+ }
666
+ self.numpy_data_type_mapping = {
667
+ "INT8": np.int8,
668
+ "INT16": np.int16,
669
+ "INT32": np.int32,
670
+ "INT64": np.int64,
671
+ "FP16": np.float16,
672
+ "FP32": np.float32,
673
+ "FP64": np.float64,
674
+ "UINT8": np.uint8,
675
+ }
676
+ self.setup_client_funcs = {
677
+ "grpc": self._setup_grpc_client,
678
+ "rest": self._setup_rest_client,
679
+ }
680
+ self.url = f"{internal_host}:{internal_port}"
681
+ self.connection_protocol = "grpc" if "grpc" in server_type else "rest"
682
+ self.tritonclientclass = None
683
+ self._dependencies_check()
684
+ self.client_info = self.setup_client_funcs[self.connection_protocol]()
685
+ logging.info(
686
+ "Initialized TritonInference with %s protocol",
687
+ self.connection_protocol,
688
+ )
689
+
690
+ def _dependencies_check(self):
691
+ """Check and import required Triton dependencies."""
692
+ try:
693
+ if self.connection_protocol == "rest":
694
+ import tritonclient.http as tritonclientclass
695
+ else:
696
+ import tritonclient.grpc as tritonclientclass
697
+ self.tritonclientclass = tritonclientclass
698
+ except ImportError as err:
699
+ package_name = "tritonclient[http]" if self.connection_protocol == "rest" else "tritonclient[grpc]"
700
+ logging.error(
701
+ "Failed to import tritonclient (%s): %s. Please install with: pip install %s",
702
+ package_name, err, package_name
703
+ )
704
+ raise ImportError(f"Required package {package_name} not installed: {err}")
705
+ except Exception as err:
706
+ logging.error(
707
+ "Failed to import tritonclient: %s",
708
+ err,
709
+ )
710
+ raise
711
+
712
+ def _setup_rest_client(self):
713
+ """Setup REST client and model configuration.
714
+
715
+ Returns:
716
+ Dictionary containing client configuration
717
+ """
718
+ client = self.tritonclientclass.InferenceServerClient(url=self.url)
719
+ model_config = client.get_model_config(model_name=self.model_name, model_version="1")
720
+ input_config = model_config["input"][0]
721
+ input_shape = [1] + input_config["dims"] # Prepend batch dimension
722
+ input_obj = self.tritonclientclass.InferInput(
723
+ input_config["name"],
724
+ input_shape,
725
+ input_config["data_type"].split("_")[-1],
726
+ )
727
+ output = self.tritonclientclass.InferRequestedOutput(model_config["output"][0]["name"])
728
+ return {
729
+ "client": client,
730
+ "input": input_obj,
731
+ "output": output,
732
+ }
733
+
734
+ def _setup_grpc_client(self):
735
+ """Setup gRPC client and model configuration.
736
+
737
+ Returns:
738
+ Dictionary containing client configuration
739
+ """
740
+ client = self.tritonclientclass.InferenceServerClient(url=self.url)
741
+ model_config = client.get_model_config(model_name=self.model_name, model_version="1")
742
+ input_config = model_config.config.input[0]
743
+ input_shape = [1] + list(input_config.dims) # Prepend batch dimension
744
+ input_obj = self.tritonclientclass.InferInput(
745
+ input_config.name,
746
+ input_shape,
747
+ self.data_type_mapping[input_config.data_type].split("_")[-1],
748
+ )
749
+ output = self.tritonclientclass.InferRequestedOutput(model_config.config.output[0].name)
750
+ return {
751
+ "client": client,
752
+ "input": input_obj,
753
+ "output": output,
754
+ }
755
+
756
+ def inference(self, input_data: Union[bytes, np.ndarray]) -> np.ndarray:
757
+ """Make a synchronous inference request.
758
+
759
+ Args:
760
+ input_data: Input data as bytes or stacked numpy array
761
+
762
+ Returns:
763
+ Model prediction as numpy array
764
+
765
+ Raises:
766
+ Exception: If inference fails
767
+ """
768
+ try:
769
+ # If already preprocessed ndarray, make it C-contiguous FP32.
770
+ if isinstance(input_data, np.ndarray):
771
+ input_array = np.ascontiguousarray(input_data, dtype=np.float32)
772
+ if input_array.ndim == 5 and input_array.shape[1] == 1:
773
+ # [B, 1, C, H, W] -> [B, C, H, W]
774
+ input_array = np.ascontiguousarray(
775
+ input_array.reshape(input_array.shape[0],
776
+ input_array.shape[2],
777
+ input_array.shape[3],
778
+ input_array.shape[4]),
779
+ dtype=np.float32
780
+ )
781
+ else:
782
+ # -> [1, C, H, W], FP32, contiguous
783
+ input_array = self._preprocess_input(input_data)
784
+
785
+ # Update InferInput shape to match batch (N,C,H,W)
786
+ self.client_info["input"].set_shape(list(input_array.shape))
787
+ self.client_info["input"].set_data_from_numpy(input_array)
788
+
789
+ # Both the HTTP and gRPC clients expose the same blocking infer() signature, so one call covers either protocol.
790
+ resp = self.client_info["client"].infer(
791
+ model_name=self.model_name,
792
+ model_version="1",
793
+ inputs=[self.client_info["input"]],
794
+ outputs=[self.client_info["output"]],
795
+ )
803
+
804
+ return resp.as_numpy(self.client_info["output"].name())
805
+
806
+ except Exception as err:
807
+ logging.error("Triton inference failed: %s", err, exc_info=True)
808
+ raise Exception(f"Triton inference failed: {err}") from err
809
+
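A minimal client-side sketch for the class above; the host, port, model name, and image path are assumptions for illustration:

  # Hypothetical usage of TritonInference against a locally running Triton server.
  client = TritonInference(
      server_type="rest",
      model_name="my_classifier",   # must match the name the model was served under
      internal_port=8000,
      internal_host="localhost",
      input_size=224,
  )
  with open("sample.jpg", "rb") as f:        # hypothetical test image
      output = client.inference(f.read())    # numpy array of raw model outputs
  print(client.format_response(output))      # wraps predictions with model_id and timestamp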
810
+ async def async_inference(self, input_data: Union[bytes, np.ndarray]) -> np.ndarray:
811
+ """Make an asynchronous inference request (REST + gRPC)."""
812
+ try:
813
+ logging.debug("Making async inference request")
814
+
815
+ if isinstance(input_data, np.ndarray):
816
+ input_array = input_data
817
+ else:
818
+ input_array = self._preprocess_input(input_data)
819
+
820
+ # Ensure C-contiguous
821
+ if not input_array.flags.c_contiguous:
822
+ input_array = np.ascontiguousarray(input_array)
823
+
824
+ self.client_info["input"].set_shape(list(input_array.shape))
825
+ self.client_info["input"].set_data_from_numpy(input_array)
826
+
827
+ if self.connection_protocol == "rest":
828
+ # REST: async_infer -> InferAsyncRequest, then block with get_result()
829
+ resp = self.client_info["client"].async_infer(
830
+ model_name=self.model_name,
831
+ model_version="1",
832
+ inputs=[self.client_info["input"]],
833
+ outputs=[self.client_info["output"]],
834
+ )
835
+ result = resp.get_result()
836
+
837
+ else:
838
+ # gRPC: async_infer uses callback; wrap it into an awaitable Future
839
+ loop = asyncio.get_running_loop()
840
+ fut: asyncio.Future = loop.create_future()
841
+
842
+ def _callback(result, error):
843
+ if error is not None:
844
+ loop.call_soon_threadsafe(fut.set_exception, error)
845
+ else:
846
+ loop.call_soon_threadsafe(fut.set_result, result)
847
+
848
+ self.client_info["client"].async_infer(
849
+ model_name=self.model_name,
850
+ model_version="1",
851
+ inputs=[self.client_info["input"]],
852
+ outputs=[self.client_info["output"]],
853
+ callback=_callback,
854
+ )
855
+ result = await fut
856
+
857
+ logging.debug(f"Async inference response type: {type(result)}")
858
+ logging.info("Successfully got async inference result")
859
+
860
+ output_array = result.as_numpy(self.client_info["output"].name())
861
+ logging.info(f"Output shape: {output_array.shape}")
862
+ return output_array
863
+
864
+ except Exception as err:
865
+ logging.error(f"Async Triton inference failed: {err}")
866
+ raise Exception(f"Async Triton inference failed: {err}") from err
867
+
868
+ def _preprocess_input(self, input_data) -> np.ndarray:
869
+ """Preprocess input data for YOLOv8 or OCR inference.
870
+
871
+ Args:
872
+ input_data: Raw input bytes or string (file path)
873
+
874
+ Returns:
875
+ Preprocessed numpy array ready for inference
876
+ """
877
+ if isinstance(self.input_size, int):
878
+ resize_shape = (self.input_size, self.input_size)
879
+ elif isinstance(self.input_size, (list, tuple)) and len(self.input_size) == 2:
880
+ resize_shape = (self.input_size[0], self.input_size[1])
881
+ input_shape = [1, 3, resize_shape[0], resize_shape[1]] # Default for compatibility
882
+
883
+ if isinstance(input_data, str) and os.path.exists(input_data):
884
+ with open(input_data, "rb") as f:
885
+ input_data = f.read()
886
+
887
+ if isinstance(input_data, bytes):
888
+ try:
889
+ image = Image.open(BytesIO(input_data)).convert("RGB")
890
+ except Exception as e:
891
+ arr = np.frombuffer(input_data, dtype=np.uint8).reshape(640, 640, 3)
892
+ image = Image.fromarray(arr, mode="RGB")
893
+ elif isinstance(input_data, np.ndarray):
894
+ image = Image.fromarray(input_data, mode="RGB")
895
+ else:
896
+ raise ValueError(f"Unsupported input_data type: {type(input_data)}")
897
+
898
+ if self.is_yolo:
899
+ logging.debug("Preprocessing input for YOLO model")
900
+ image, ratio, (dw, dh) = self._letterbox_resize(image, resize_shape)
901
+ arr = np.array(image).astype(np.float32) / 255.0
902
+ arr = arr.transpose(2, 0, 1)
903
+ arr = np.expand_dims(arr, axis=0)
904
+ self.client_info["padding_info"] = {"ratio": ratio, "dw": dw, "dh": dh}
905
+ elif self.is_ocr:
906
+ logging.debug("Preprocessing input for OCR model")
907
+ config = getattr(self, "ocr_config", {})
908
+ arr = self._preprocess_ocr(
909
+ image=image,
910
+ resize_shape=resize_shape,
911
+ image_color_mode=config.get("color_mode", "rgb"),
912
+ keep_aspect_ratio=config.get("keep_aspect_ratio", False),
913
+ interpolation_method=config.get("interpolation", "linear"),
914
+ padding_color=config.get("padding_color", (114, 114, 114)),
915
+ use_grayscale=config.get("use_grayscale", False),
916
+ apply_contrast=config.get("apply_contrast", False),
917
+ )
918
+ else:
919
+ # Classifier preprocessing: resize directly, ImageNet normalization
920
+ image = image.resize(resize_shape)
921
+ arr = np.array(image).astype(np.float32) / 255.0
922
+ arr = arr.transpose(2, 0, 1) # Convert to CHW
923
+ arr = np.expand_dims(arr, axis=0) # Add batch dimension
924
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
925
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)
926
+ arr = (arr - mean) / std
927
+
928
+ arr = arr.astype(self.numpy_data_type_mapping[self.client_info["input"].datatype()])
929
+ return arr
930
+
931
+
932
+ def _preprocess_ocr(
933
+ self,
934
+ image: Image.Image,
935
+ resize_shape: tuple,
936
+ image_color_mode: str = "rgb",
937
+ keep_aspect_ratio: bool = False,
938
+ interpolation_method: str = "linear",
939
+ padding_color: tuple = (114, 114, 114),
940
+ use_grayscale: bool = False,
941
+ apply_contrast: bool = False,
942
+ ) -> np.ndarray:
943
+ """Preprocess an input PIL Image for OCR model inference.
944
+
945
+ Args:
946
+ image: PIL Image in RGB format.
947
+ resize_shape: (height, width) tuple.
948
+ image_color_mode: "rgb" or "grayscale" (affects output channels).
949
+ keep_aspect_ratio: Whether to preserve aspect ratio with padding.
950
+ interpolation_method: One of ["linear", "nearest", "cubic", "area"].
951
+ padding_color: Padding color for aspect ratio preservation (RGB or scalar for grayscale).
952
+ use_grayscale: Convert to grayscale before processing (default: False).
953
+ apply_contrast: Apply CLAHE contrast enhancement (default: False).
954
+
955
+ Returns:
956
+ Preprocessed numpy array (batch, height, width, C) ready for OCR inference.
957
+ """
958
+ import cv2
959
+ import numpy as np
960
+
961
+ img_height, img_width = resize_shape
962
+ img = np.array(image)
963
+
964
+ if img.shape[-1] == 4:
965
+ img = img[:, :, :3]
966
+
967
+ if use_grayscale:
968
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
969
+ if image_color_mode == "rgb":
970
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
971
+ elif image_color_mode == "grayscale":
972
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
973
+
974
+ if apply_contrast:
975
+ if image_color_mode == "grayscale" or use_grayscale:
976
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
977
+ img = clahe.apply(img if img.ndim == 2 else img[:, :, 0])
978
+ if image_color_mode == "rgb":
979
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
980
+ else:
981
+ img_yuv = cv2.cvtColor(img, cv2.COLOR_RGB2YUV)
982
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
983
+ img_yuv[:, :, 0] = clahe.apply(img_yuv[:, :, 0])
984
+ img = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2RGB)
985
+
986
+ INTERPOLATION_MAP = {
987
+ "linear": cv2.INTER_LINEAR,
988
+ "nearest": cv2.INTER_NEAREST,
989
+ "cubic": cv2.INTER_CUBIC,
990
+ "area": cv2.INTER_AREA,
991
+ }
992
+ interpolation = INTERPOLATION_MAP.get(interpolation_method, cv2.INTER_LINEAR)
993
+
994
+ if not keep_aspect_ratio:
995
+ img = cv2.resize(img, (img_width, img_height), interpolation=interpolation)
996
+ else:
997
+ orig_h, orig_w = img.shape[:2]
998
+ r = min(img_height / orig_h, img_width / orig_w)
999
+ new_unpad_w, new_unpad_h = round(orig_w * r), round(orig_h * r)
1000
+ img = cv2.resize(img, (new_unpad_w, new_unpad_h), interpolation=interpolation)
1001
+ dw, dh = (img_width - new_unpad_w) / 2, (img_height - new_unpad_h) / 2
1002
+ top, bottom, left, right = (
1003
+ round(dh - 0.1),
1004
+ round(dh + 0.1),
1005
+ round(dw - 0.1),
1006
+ round(dw + 0.1),
1007
+ )
1008
+ border_color = padding_color[0] if image_color_mode == "grayscale" else padding_color
1009
+ img = cv2.copyMakeBorder(
1010
+ img,
1011
+ top,
1012
+ bottom,
1013
+ left,
1014
+ right,
1015
+ borderType=cv2.BORDER_CONSTANT,
1016
+ value=border_color,
1017
+ )
1018
+
1019
+ if image_color_mode == "grayscale" and img.ndim == 2:
1020
+ img = np.expand_dims(img, axis=-1)
1021
+
1022
+ dtype = self.numpy_data_type_mapping[self.client_info["input"].datatype()]
1023
+ arr = img.astype(np.float32)
1024
+ if dtype == np.uint8:
1025
+ arr = arr.astype(np.uint8)  # pixel values are already in the 0-255 range
1026
+ else:
1027
+ arr = arr / 255.0
1028
+
1029
+ arr = np.expand_dims(arr, axis=0)
1030
+
1031
+ logging.debug(f"Preprocessed OCR input shape: {arr.shape}, dtype: {arr.dtype}")
1032
+
1033
+ # Debug dump of the preprocessed image; rescale only when the array holds normalized floats.
1034
+ debug_img = arr[0] if arr.dtype == np.uint8 else (arr[0] * 255).astype(np.uint8)
1035
+ if debug_img.shape[-1] == 3:
1036
+ cv2.imwrite("preprocessed_ocr_image.png", cv2.cvtColor(debug_img, cv2.COLOR_RGB2BGR))
+ else:
+ cv2.imwrite("preprocessed_ocr_image.png", debug_img[:, :, 0])
1037
+
1038
+ return arr.astype(dtype)
1039
+
1040
+
1041
+
1042
+ def _letterbox_resize(self, image, target_size):
1043
+ """Resize image with letterbox padding to maintain aspect ratio."""
1044
+ target_h, target_w = target_size
1045
+ img_w, img_h = image.size
1046
+ ratio = min(target_w / img_w, target_h / img_h)
1047
+ new_w, new_h = int(img_w * ratio), int(img_h * ratio)
1048
+ image = image.resize((new_w, new_h), Image.LANCZOS)
1049
+ padded_image = Image.new("RGB", (target_w, target_h), (114, 114, 114))
1050
+ dw, dh = (target_w - new_w) // 2, (target_h - new_h) // 2
1051
+ padded_image.paste(image, (dw, dh))
1052
+ logging.debug("Letterbox resize: original size %s, new size %s, padding (dw, dh) (%d, %d)", (img_w, img_h), (new_w, new_h), dw, dh)
1053
+ logging.debug("Letterbox resize completed for target size: %s", target_size)
1054
+ return padded_image, ratio, (dw, dh)
1055
+
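As a quick worked example of the letterbox math above: for a 1920x1080 source image and a 640x640 target, ratio = min(640/1920, 640/1080) = 1/3, so the image is resized to 640x360 and padded with dw = 0, dh = 140; _postprocess_yolo later undoes this by computing (x - dw) / ratio and (y - dh) / ratio.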
1056
+ def _postprocess_yolo(
1057
+ self,
1058
+ outputs: np.ndarray,
1059
+ conf_thres: float = 0.25,
1060
+ iou_thres: float = 0.45,
1061
+ max_det: int = 300
1062
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
1063
+ """Postprocess YOLOv8 outputs (Torch or ONNX/Triton) with pipeline compatibility.
1064
+
1065
+ Args:
1066
+ outputs: Raw model output as NumPy array, expected shape [batch, num_boxes, num_classes + 4] for YOLO
1067
+ conf_thres: Confidence threshold for filtering detections
1068
+ iou_thres: IoU threshold for Non-Maximum Suppression
1069
+ max_det: Maximum number of detections to keep
1070
+
1071
+ Returns:
1072
+ Tuple of (boxes, scores, class_ids) as NumPy arrays:
1073
+ - boxes: [N, 4] array of bounding boxes in xyxy format
1074
+ - scores: [N] array of confidence scores
1075
+ - class_ids: [N] array of class IDs
1076
+ """
1077
+ if not self.is_yolo:
1078
+ return outputs, np.array([]), np.array([])
1079
+
1080
+ try:
1081
+ if isinstance(outputs, np.ndarray):
1082
+ outputs = torch.from_numpy(outputs)
1083
+ elif not isinstance(outputs, torch.Tensor):
1084
+ outputs = torch.tensor(outputs, dtype=torch.float32)
1085
+
1086
+ if outputs.ndim == 2:
1087
+ outputs = outputs.unsqueeze(0)
1088
+
1089
+ if outputs.shape[1] < outputs.shape[2]:
1090
+ # Format [batch, num_classes + 4, num_boxes] -> [batch, num_boxes, num_classes + 4]
1091
+ outputs = outputs.transpose(1, 2)
1092
+
1093
+ boxes = outputs[..., :4] # xywh, [batch, num_boxes, 4]
1094
+ scores_all = outputs[..., 4:] # class scores, [batch, num_boxes, num_classes]
1095
+
1096
+ all_boxes, all_scores, all_class_ids = [], [], []
1097
+ for batch_idx in range(outputs.shape[0]):
1098
+ batch_boxes = boxes[batch_idx] # [num_boxes, 4]
1099
+ batch_scores_all = scores_all[batch_idx] # [num_boxes, num_classes]
1100
+
1101
+ # Get best class per box
1102
+ scores, class_ids = batch_scores_all.max(dim=-1)
1103
+
1104
+ # Confidence filter
1105
+ mask = scores > conf_thres
1106
+ batch_boxes = batch_boxes[mask]
1107
+ batch_scores = scores[mask]
1108
+ batch_class_ids = class_ids[mask]
1109
+
1110
+ # Convert xywh -> xyxy
1111
+ if batch_boxes.shape[0] > 0:
1112
+ xyxy = torch.zeros_like(batch_boxes)
1113
+ xyxy[:, 0] = batch_boxes[:, 0] - batch_boxes[:, 2] / 2 # x1
1114
+ xyxy[:, 1] = batch_boxes[:, 1] - batch_boxes[:, 3] / 2 # y1
1115
+ xyxy[:, 2] = batch_boxes[:, 0] + batch_boxes[:, 2] / 2 # x2
1116
+ xyxy[:, 3] = batch_boxes[:, 1] + batch_boxes[:, 3] / 2 # y2
1117
+ batch_boxes = xyxy
1118
+
1119
+ # Adjust for letterbox padding if provided
1120
+ padding_info = self.client_info.get("padding_info", {})
1121
+ if padding_info:
1122
+ ratio = padding_info.get("ratio", 1.0)
1123
+ dw = padding_info.get("dw", 0)
1124
+ dh = padding_info.get("dh", 0)
1125
+ batch_boxes[:, 0] = (batch_boxes[:, 0] - dw) / ratio # x1
1126
+ batch_boxes[:, 1] = (batch_boxes[:, 1] - dh) / ratio # y1
1127
+ batch_boxes[:, 2] = (batch_boxes[:, 2] - dw) / ratio # x2
1128
+ batch_boxes[:, 3] = (batch_boxes[:, 3] - dh) / ratio # y2
1129
+
1130
+ # NMS
1131
+ if batch_boxes.shape[0] > 0:
1132
+ keep = torchvision.ops.nms(batch_boxes, batch_scores, iou_thres)
1133
+ batch_boxes = batch_boxes[keep]
1134
+ batch_scores = batch_scores[keep]
1135
+ batch_class_ids = batch_class_ids[keep]
1136
+
1137
+ if batch_boxes.shape[0] > max_det:
1138
+ topk = batch_scores.topk(max_det).indices
1139
+ batch_boxes = batch_boxes[topk]
1140
+ batch_scores = batch_scores[topk]
1141
+ batch_class_ids = batch_class_ids[topk]
1142
+
1143
+ all_boxes.append(batch_boxes.cpu().numpy())
1144
+ all_scores.append(batch_scores.cpu().numpy())
1145
+ all_class_ids.append(batch_class_ids.cpu().numpy())
1146
+
1147
+ boxes = np.concatenate(all_boxes, axis=0) if all_boxes else np.empty((0, 4))
1148
+ scores = np.concatenate(all_scores, axis=0) if all_scores else np.empty(0)
1149
+ class_ids = np.concatenate(all_class_ids, axis=0) if all_class_ids else np.empty(0)
1150
+
1151
+ return boxes, scores, class_ids
1152
+
1153
+ except Exception as e:
1154
+ logging.error("YOLO post-processing failed: %s", str(e), exc_info=True)
1155
+ return np.empty((0, 4)), np.empty(0), np.empty(0)
1156
+
1157
+ def _postprocess_ocr(
1158
+ self,
1159
+ model_output: np.ndarray,
1160
+ max_plate_slots: int = 9,
1161
+ model_alphabet: str = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_",
1162
+ return_confidence: bool = True,
1163
+ confidence_threshold: float = 0.0, # Disabled threshold to match ONNX
1164
+ ) -> Tuple[list[str], np.ndarray] | list[str]:
1165
+ """Postprocess OCR model outputs into license plate strings.
1166
+
1167
+ Args:
1168
+ model_output: Raw output tensor from the model.
1169
+ max_plate_slots: Maximum number of character positions. Defaults to 9.
1170
+ model_alphabet: Alphabet used by the OCR model. Defaults to alphanumeric.
1171
+ return_confidence: If True, also return per-character confidence scores. Defaults to True.
1172
+ confidence_threshold: Minimum confidence for a character to be considered valid. Defaults to 0.0.
1173
+
1174
+ Returns:
1175
+ If return_confidence is False: a list of decoded plate strings.
1176
+ If True: a two-tuple (plates, probs) where plates is the list of decoded strings,
1177
+ and probs is an array of shape (N, max_plate_slots) with confidence scores.
1178
+ """
1179
+ try:
1180
+ logging.debug(f"OCR model output shape: {model_output.shape}")
1181
+
1182
+ predictions = model_output.reshape((-1, max_plate_slots, len(model_alphabet)))
1183
+ probs = np.max(predictions, axis=-1)
1184
+ prediction_indices = np.argmax(predictions, axis=-1)
1185
+
1186
+ alphabet_array = np.array(list(model_alphabet))
1187
+ if confidence_threshold > 0:
1188
+ pad_char_index = model_alphabet.index('_')
1189
+ prediction_indices[probs < confidence_threshold] = pad_char_index
1190
+
1191
+ plate_chars = alphabet_array[prediction_indices]
1192
+ plates = np.apply_along_axis("".join, 1, plate_chars).tolist()
1193
+
1194
+ if return_confidence:
1195
+ return plates, probs
1196
+ return plates
1197
+ except Exception as e:
1198
+ logging.error("OCR post-processing failed: %s", str(e), exc_info=True)
1199
+ return [], np.array([])
1200
+
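A toy illustration of the decoding above, using the default 9 slots and 37-character alphabet; the logits are fabricated for the example:

  import numpy as np

  alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_"
  toy = np.zeros((1, 9, len(alphabet)), dtype=np.float32)
  for slot, ch in enumerate("AB12" + "_" * 5):   # spell "AB12", pad the rest with '_'
      toy[0, slot, alphabet.index(ch)] = 1.0
  # _postprocess_ocr(toy) would return (["AB12_____"], probs) with probs of shape (1, 9).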
1201
+ def format_response(self, response: np.ndarray) -> Dict[str, Any]:
1202
+ """Format model response for consistent logging.
1203
+
1204
+ Args:
1205
+ response: Raw model output
1206
+
1207
+ Returns:
1208
+ Formatted response dictionary
1209
+ """
1210
+ if self.is_yolo:
1211
+ boxes, scores, class_ids = self._postprocess_yolo(
1212
+ response,
1213
+ conf_thres=0.25,
1214
+ iou_thres=0.45,
1215
+ max_det=300
1216
+ )
1217
+ predictions = {
1218
+ "boxes": boxes.tolist(),
1219
+ "scores": scores.tolist(),
1220
+ "class_ids": class_ids.tolist()
1221
+ }
1222
+ elif self.is_ocr:
1223
+ plates, probs = self._postprocess_ocr(
1224
+ response,
1225
+ max_plate_slots=9,
1226
+ model_alphabet="0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_",
1227
+ return_confidence=True,
1228
+ )
1229
+ predictions = {
1230
+ "plates" : plates,
1231
+ "prob" : probs
1232
+ }
1233
+
1234
+ else:
1235
+ predictions = response.tolist() if isinstance(response, np.ndarray) else response
1236
+
1237
+ return {
1238
+ "predictions": predictions,
1239
+ "model_id": self.model_name,
1240
+ "timestamp": datetime.now(timezone.utc).isoformat(),
1241
+ }
1242
+
1243
+ # TODO: Bifurcate Triton server and inference utils into separate files
1244
+ # TODO: import and use postprocess functions for diff use-cases (yolo, ocr, cls)
1245
+ # TODO: Verify and Generalize for Obj Det models
1246
+ # TODO: Implement a unified interface for model post-processing
1247
+ # TODO: Define a Standardized template for custom model configs support (ocr)
1248
+ # TODO: Remove hardcoded versions and provide user-defined version control support