matrice-inference 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of matrice-inference might be problematic.
- matrice_inference/__init__.py +72 -0
- matrice_inference/py.typed +0 -0
- matrice_inference/server/__init__.py +23 -0
- matrice_inference/server/inference_interface.py +176 -0
- matrice_inference/server/model/__init__.py +1 -0
- matrice_inference/server/model/model_manager.py +274 -0
- matrice_inference/server/model/model_manager_wrapper.py +550 -0
- matrice_inference/server/model/triton_model_manager.py +290 -0
- matrice_inference/server/model/triton_server.py +1248 -0
- matrice_inference/server/proxy_interface.py +371 -0
- matrice_inference/server/server.py +1004 -0
- matrice_inference/server/stream/__init__.py +0 -0
- matrice_inference/server/stream/app_deployment.py +228 -0
- matrice_inference/server/stream/consumer_worker.py +201 -0
- matrice_inference/server/stream/frame_cache.py +127 -0
- matrice_inference/server/stream/inference_worker.py +163 -0
- matrice_inference/server/stream/post_processing_worker.py +230 -0
- matrice_inference/server/stream/producer_worker.py +147 -0
- matrice_inference/server/stream/stream_pipeline.py +451 -0
- matrice_inference/server/stream/utils.py +23 -0
- matrice_inference/tmp/abstract_model_manager.py +58 -0
- matrice_inference/tmp/aggregator/__init__.py +18 -0
- matrice_inference/tmp/aggregator/aggregator.py +330 -0
- matrice_inference/tmp/aggregator/analytics.py +906 -0
- matrice_inference/tmp/aggregator/ingestor.py +438 -0
- matrice_inference/tmp/aggregator/latency.py +597 -0
- matrice_inference/tmp/aggregator/pipeline.py +968 -0
- matrice_inference/tmp/aggregator/publisher.py +431 -0
- matrice_inference/tmp/aggregator/synchronizer.py +594 -0
- matrice_inference/tmp/batch_manager.py +239 -0
- matrice_inference/tmp/overall_inference_testing.py +338 -0
- matrice_inference/tmp/triton_utils.py +638 -0
- matrice_inference-0.1.2.dist-info/METADATA +28 -0
- matrice_inference-0.1.2.dist-info/RECORD +37 -0
- matrice_inference-0.1.2.dist-info/WHEEL +5 -0
- matrice_inference-0.1.2.dist-info/licenses/LICENSE.txt +21 -0
- matrice_inference-0.1.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1248 @@
"""Module providing triton_server functionality."""

import os
import zipfile
import subprocess
import tempfile
import asyncio
import shutil
import logging
import threading
import shlex
from typing import Tuple, Optional, Any, Dict, Union, List
import importlib.util
from matrice_common.utils import dependencies_check

# TRITON_DOCKER_IMAGE = "nvcr.io/nvidia/tritonserver:24.08-py3"
TRITON_DOCKER_IMAGE = "nvcr.io/nvidia/tritonserver:23.08-py3"
BASE_PATH = "./model_repository"

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")


class TritonServer:
    def __init__(
        self,
        model_name: str,
        model_path: str,
        runtime_framework: str,
        input_size: Union[int, List[int]] = 224,
        num_classes: int = 10,
        dynamic_batching: bool = False,
        num_model_instances: int = 1,
        max_batch_size: int = 8,
        connection_protocol: str = "rest",
        is_yolo: bool = False,
        is_ocr: bool = False,
        use_trt_accelerator: bool = False,
        **kwargs,
    ):
        """Initialize the Triton server.

        Args:
            model_name: Name of the model (used for Triton model repository).
            model_path: Path to the model file on the local filesystem.
            runtime_framework: Framework of the model ('onnx', 'pytorch', 'torchscript', 'yolo', 'tensorrt', 'openvino').
            input_size: Input size for the model (int for square images or [height, width]).
            num_classes: Number of output classes.
            dynamic_batching: Enable dynamic batching for the model.
            num_model_instances: Number of model instances to deploy.
            max_batch_size: Maximum batch size for inference.
            connection_protocol: Protocol for Triton server ('rest' or 'grpc').
            use_trt_accelerator: Enable TensorRT acceleration for inference.
            is_yolo: Boolean indicating if the model is a YOLO model.
            is_ocr: Boolean indicating if the model is an OCR model.
        """
        if not dependencies_check("torch"):
            raise ImportError("PyTorch is required but not installed")
        import torch

        if not model_name:
            raise ValueError("model_name must be provided")
        if not model_path:
            raise ValueError("model_path must be provided")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at: {model_path}")

        logging.info("Initializing TritonServer")

        self.model_name = model_name
        self.model_path = os.path.abspath(model_path)
        self.runtime_framework = runtime_framework.lower()
        self.connection_protocol = connection_protocol.lower()

        if isinstance(input_size, (int, float)):
            self.input_size = [int(input_size), int(input_size)]
        elif isinstance(input_size, (list, tuple)):
            if len(input_size) == 2:  # (H, W)
                self.input_size = [int(input_size[0]), int(input_size[1])]
            elif len(input_size) == 3:  # (H, W, C)
                self.input_size = [int(input_size[0]), int(input_size[1])]
            elif len(input_size) == 4:  # (N, C, H, W)
                self.input_size = [int(input_size[-2]), int(input_size[-1])]
            else:
                logging.warning("Unexpected input_size length: %s, using default [224, 224]", input_size)
                self.input_size = [224, 224]

        self.num_classes = num_classes
        self.dynamic_batching = dynamic_batching
        self.num_model_instances = num_model_instances
        self.max_batch_size = max_batch_size
        self.use_trt_accelerator = use_trt_accelerator

        self.is_yolo = is_yolo or self.runtime_framework == "yolo"
        self.is_ocr = is_ocr
        self.input_name = "images" if self.is_yolo else "input"
        self.output_name = "output0" if self.is_yolo else "output"

        self.gpus_count = torch.cuda.device_count()
        self.config_params = {}

        logging.info("Model name: %s", self.model_name)
        logging.info("Model path: %s", self.model_path)
        logging.info("Runtime framework: %s", self.runtime_framework)
        logging.info("Using connection protocol: %s", self.connection_protocol)
        logging.info("Input size: %s", self.input_size)
        logging.info("Number of classes: %s", self.num_classes)
        logging.info("Found %s GPUs available for inference", self.gpus_count)

        self.docker_pull_process = subprocess.Popen(
            [
                shutil.which("docker"),
                "pull",
                TRITON_DOCKER_IMAGE,
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

    def check_triton_docker_image(self):
        """Check if the docker image download is complete and wait for it to finish."""
        logging.info("Checking docker image download status")
        stdout, stderr = self.docker_pull_process.communicate()
        if self.docker_pull_process.returncode == 0:
            logging.info(f"Docker image {TRITON_DOCKER_IMAGE} downloaded successfully")
        else:
            # Popen was created with text=True, so stderr is already a str.
            error_msg = stderr
            logging.error(f"Docker pull failed with return code {self.docker_pull_process.returncode}")
            logging.error("Error message: %s", error_msg)
            raise RuntimeError(f"Docker pull failed: {error_msg}")

    def prepare_model(self, model_version_dir: str) -> None:
        """Prepare the model file for Triton Inference Server.

        Copies model from self.model_path to model_version_dir and converts if necessary
        to the format expected by Triton (model.onnx, model.xml, model.plan).

        Args:
            model_version_dir: Directory to store the model file
                (e.g., '/models/<model_name>/1').
        """
        try:
            runtime_framework = self.runtime_framework
            logging.info("Preparing model with runtime framework: %s", runtime_framework)
            logging.info("Source model path: %s", self.model_path)

            if runtime_framework not in ["onnx", "pytorch", "torchscript", "yolo", "tensorrt", "openvino"]:
                logging.error("Runtime framework '%s' not supported. Supported: %s",
                              runtime_framework, ["onnx", "pytorch", "torchscript", "yolo", "tensorrt", "openvino"])
                raise ValueError(f"Unsupported runtime framework: {runtime_framework}")

            os.makedirs(model_version_dir, exist_ok=True)

            # 1. ONNX - copy to model.onnx
            if runtime_framework == "onnx":
                model_file = os.path.join(model_version_dir, "model.onnx")
                if os.path.abspath(self.model_path) != os.path.abspath(model_file):
                    shutil.copy2(self.model_path, model_file)
                    logging.info("Copied ONNX model to: %s", model_file)
                else:
                    logging.info("Model path is already correct: %s", model_file)
                self._verify_onnx_model(model_file)

            # 2. PyTorch/TorchScript/YOLO - export to ONNX
            elif runtime_framework in ["pytorch", "torchscript", "yolo"]:
                model_file = os.path.join(model_version_dir, "model.onnx")
                img_chw = (3, self.input_size[0], self.input_size[1])
                logging.info("Converting %s model to ONNX with input shape: %s", runtime_framework, img_chw)
                self.to_onnx(self.model_path, model_file, (1, *img_chw))
                logging.info("Exported ONNX model to: %s", model_file)
                self._verify_onnx_model(model_file)

            # 3. TensorRT - copy to model.plan
            elif runtime_framework == "tensorrt":
                model_file = os.path.join(model_version_dir, "model.plan")
                shutil.copy2(self.model_path, model_file)
                logging.info("Copied TensorRT model to: %s", model_file)
                self.runtime_framework = "tensorrt"

            # 4. OpenVINO - extract ZIP or copy files
            elif runtime_framework == "openvino":
                if self.model_path.endswith('.zip'):
                    logging.info("Extracting OpenVINO ZIP to: %s", model_version_dir)
                    with zipfile.ZipFile(self.model_path, "r") as zip_ref:
                        zip_ref.extractall(model_version_dir)
                    model_file = os.path.join(model_version_dir, "model.xml")
                    model_bin_file = os.path.join(model_version_dir, "model.bin")
                else:
                    model_file = os.path.join(model_version_dir, "model.xml")
                    model_bin_file = os.path.join(model_version_dir, "model.bin")

                    shutil.copy2(self.model_path, model_file)

                    source_bin = self.model_path.replace('.xml', '.bin')
                    if os.path.exists(source_bin):
                        shutil.copy2(source_bin, model_bin_file)
                    else:
                        raise RuntimeError(f"OpenVINO model.bin not found at {source_bin}")

                if not os.path.exists(model_file):
                    logging.error("OpenVINO model.xml not found at %s", model_file)
                    raise RuntimeError(f"OpenVINO model.xml not found at {model_file}")
                if not os.path.exists(model_bin_file):
                    logging.error("OpenVINO model.bin not found at %s", model_bin_file)
                    raise RuntimeError(f"OpenVINO model.bin not found at {model_bin_file}")
                logging.info("Prepared OpenVINO model: %s", model_file)

            logging.info("Model preparation completed successfully")
        except Exception as e:
            logging.error("Model preparation failed: %s", str(e), exc_info=True)
            raise

    def to_onnx(self, checkpoint_path: str, onnx_path: str, input_shape: Tuple[int, int, int, int]) -> None:
        """Export PyTorch or YOLO checkpoint to ONNX."""
        try:
            runtime_framework = self.runtime_framework.lower()
            logging.info("Exporting %s model to ONNX on CPU", runtime_framework)

            if runtime_framework == "yolo":
                try:
                    if dependencies_check("ultralytics"):
                        from ultralytics import YOLO
                        logging.info("Using Ultralytics YOLO for ONNX export")
                        model = YOLO(checkpoint_path)
                        # NOTE: Update 4 -- opset=12 for YOLOv8 compatibility
                        export_path = model.export(format="onnx", imgsz=input_shape[2], dynamic=True, opset=12)
                        if export_path != onnx_path:
                            shutil.move(export_path, onnx_path)
                        logging.info("Exported YOLO model to ONNX: %s", onnx_path)
                        return
                    else:
                        logging.warning("Ultralytics not available; falling back to PyTorch export for YOLO")
                except Exception as e:
                    logging.warning("Ultralytics YOLO export failed: %s; trying PyTorch export", str(e))

            import torch
            model = torch.load(checkpoint_path, map_location="cpu")
            # TODO: Add support for model_state_dict
            model.eval()
            dummy_input = torch.randn(*input_shape)
            torch.onnx.export(
                model,
                dummy_input,
                onnx_path,
                opset_version=17,
                input_names=["input__0"],
                output_names=["output__0"],
                dynamic_axes={"input__0": {0: "batch"}, "output__0": {0: "batch"}},
            )
            logging.info("Exported PyTorch model to ONNX: %s", onnx_path)
        except Exception as e:
            logging.error("Failed to export to ONNX: %s", str(e), exc_info=True)
            raise

    def _verify_onnx_model(self, onnx_path: str):
        """Verify that the ONNX model is valid."""
        try:
            if dependencies_check("onnx"):
                import onnx
                model = onnx.load(onnx_path)
                onnx.checker.check_model(model)
                logging.info(f"ONNX model verification successful: {onnx_path}")
            else:
                logging.warning("ONNX library not available for model verification")
        except Exception as e:
            logging.error("ONNX model verification failed: %s", str(e))
            raise ValueError(f"Invalid ONNX model at {onnx_path}: {str(e)}")

    def create_model_repository(self):
        """Create the model repository directory structure."""
        try:
            model_version = "1"
            model_dir = os.path.join(BASE_PATH, self.model_name)
            version_dir = os.path.join(model_dir, str(model_version))
            logging.info("Creating model repository structure:")
            logging.info("Base path: %s", BASE_PATH)
            logging.info("Model directory: %s", model_dir)
            logging.info("Version directory: %s", version_dir)
            os.makedirs(version_dir, exist_ok=True)
            logging.info("Model repository directories created successfully")
            return model_dir, version_dir
        except Exception as e:
            logging.error(
                "Failed to create model repository: %s",
                str(e),
                exc_info=True,
            )
            raise

    def write_config_file(
        self,
        model_dir: str,
        max_batch_size: int = 8,
        num_model_instances: int = 1,
        image_size: List[int] = [224, 224],
        num_classes: int = 10,
        input_data_type: str = "TYPE_FP32",
        output_data_type: str = "TYPE_FP32",
        dynamic_batching: bool = False,
        preferred_batch_size: list = [2, 4, 8],
        max_queue_delay_microseconds: int = 100,
        input_pinned_memory: bool = True,
        output_pinned_memory: bool = True,
        **kwargs,
    ):
        """Write the model configuration file for Triton Inference Server."""
        try:
            runtime_framework = self.runtime_framework.lower()
            logging.info("Starting to write Triton config file for framework: %s", runtime_framework)

            if runtime_framework == "tensorrt":
                platform = "tensorrt_plan"
                model_filename = "model.plan"
            elif runtime_framework in ["pytorch", "torchscript", "yolo", "onnx"]:
                platform = "onnxruntime_onnx"
                model_filename = "model.onnx"
            else:
                platform = "openvino"
                model_filename = "model.xml"
            logging.info("Using %s backend with model file: %s", platform, model_filename)

            config_path = os.path.join(model_dir, "config.pbtxt")
            logging.info("Writing config to: %s", config_path)

            # NOTE: Update X0
            onnx_to_triton_dtype = {
                1: "TYPE_FP32",    # FLOAT
                2: "TYPE_UINT8",   # UINT8
                3: "TYPE_INT8",    # INT8
                4: "TYPE_UINT16",  # UINT16
                5: "TYPE_INT16",   # INT16
                6: "TYPE_INT32",   # INT32
                7: "TYPE_INT64",   # INT64
                8: "TYPE_STRING",  # STRING
                9: "TYPE_BOOL",    # BOOL
                10: "TYPE_FP16",   # HALF
                11: "TYPE_FP64",   # DOUBLE
                12: "TYPE_UINT32", # UINT32
                13: "TYPE_UINT64", # UINT64
            }

            if platform == "onnxruntime_onnx":
                model_file = os.path.join(model_dir, "1", "model.onnx")
                if os.path.exists(model_file) and dependencies_check("onnx"):
                    import onnx
                    import onnx.numpy_helper
                    model = onnx.load(model_file)
                    graph = model.graph

                    inputs = []
                    for inp in graph.input:
                        shape = [d.dim_value if d.HasField("dim_value") else -1 for d in inp.type.tensor_type.shape.dim]
                        dtype_id = inp.type.tensor_type.elem_type
                        dtype = onnx_to_triton_dtype.get(dtype_id, "TYPE_FP32")  # Fallback to FP32 if unknown
                        inputs.append((inp.name, dtype, shape))

                    outputs = []
                    for out in graph.output:
                        shape = [d.dim_value if d.HasField("dim_value") else -1 for d in out.type.tensor_type.shape.dim]
                        if self.is_yolo and out.name != "output__0":
                            continue  # Skip intermediate YOLO outputs
                        output_shape = shape[1:] if max_batch_size > 0 and len(shape) > 1 else shape
                        dtype_id = out.type.tensor_type.elem_type
                        dtype = onnx_to_triton_dtype.get(dtype_id, "TYPE_FP32")  # Fallback to FP32 if unknown
                        outputs.append((out.name, dtype, output_shape))

                    logging.info("ONNX inputs: %s", inputs)
                    logging.info("ONNX outputs: %s", outputs)
                else:
                    # Fallback when the onnx package or the exported model file is unavailable
                    inputs = [(self.input_name, input_data_type, [3, image_size[0], image_size[1]])]
                    outputs = [(self.output_name, output_data_type, [-1, -1] if self.is_yolo else [num_classes])]
            elif platform == "tensorrt_plan":
                # TensorRT: use OCR-specific configuration if is_ocr is True
                if self.is_ocr:
                    # NOTE: update Y1 -- hardcoded for OCR model
                    inputs = [("input", "TYPE_UINT8", [64, 128, 3])]
                    outputs = [("Identity:0", "TYPE_FP32", [9, 37])]
                else:
                    # Fallback for YOLO or other TensorRT models
                    inputs = [(self.input_name, input_data_type, [3, image_size[0], image_size[1]])]
                    outputs = [(self.output_name, output_data_type, [-1, -1] if self.is_yolo else [num_classes])]
            else:
                # OpenVINO fallback
                inputs = [(self.input_name, input_data_type, [3, image_size[0], image_size[1]])]
                outputs = [(self.output_name, output_data_type, [num_classes])]

            logging.info("Final inputs for config: %s", inputs)
            logging.info("Final outputs for config: %s", outputs)

            config_content = f'name: "{self.model_name}"\n'
            config_content += f'platform: "{platform}"\n'
            config_content += f'max_batch_size: {max_batch_size}\n'

            # Input section
            config_content += 'input [\n'
            for name, dtype, shape in inputs:
                config_content += ' {\n'
                config_content += f' name: "{name}"\n'
                config_content += f' data_type: {dtype}\n'
                config_content += f' dims: [{", ".join(str(dim) for dim in shape)}]\n'
                config_content += ' }\n'
            config_content += ']\n'

            # Output section
            config_content += 'output [\n'
            for name, dtype, shape in outputs:
                config_content += ' {\n'
                config_content += f' name: "{name}"\n'
                config_content += f' data_type: {dtype}\n'
                config_content += f' dims: [{", ".join(str(dim) for dim in shape)}]\n'
                config_content += ' }\n'
            config_content += ']\n'

            # Instance group
            if num_model_instances > 1 or self.gpus_count > 0:
                device_type = "KIND_GPU" if self.gpus_count > 0 else "KIND_CPU"
                logging.info("Adding instance group configuration for %s %s instances", num_model_instances, device_type)
                config_content += 'instance_group [\n'
                config_content += ' {\n'
                config_content += f' count: {num_model_instances}\n'
                config_content += f' kind: {device_type}\n'
                config_content += ' }\n'
                config_content += ']\n'

            # Dynamic batching
            if dynamic_batching:
                logging.info("Adding dynamic batching configuration")
                valid_pref_sizes = [bs for bs in preferred_batch_size if bs <= max_batch_size]
                if valid_pref_sizes:
                    config_content += 'dynamic_batching {\n'
                    config_content += f' preferred_batch_size: [{", ".join(str(bs) for bs in valid_pref_sizes)}]\n'
                    config_content += f' max_queue_delay_microseconds: {max_queue_delay_microseconds}\n'
                    config_content += '}\n'

            # Optimization settings
            if input_pinned_memory or output_pinned_memory or self.gpus_count > 0:
                config_content += 'optimization {\n'
                if input_pinned_memory:
                    config_content += ' input_pinned_memory {\n'
                    config_content += ' enable: true\n'
                    config_content += ' }\n'
                if output_pinned_memory:
                    config_content += ' output_pinned_memory {\n'
                    config_content += ' enable: true\n'
                    config_content += ' }\n'
                if self.gpus_count > 0 and self.use_trt_accelerator:
                    config_content += ' execution_accelerators {\n'
                    config_content += ' gpu_execution_accelerator {\n'
                    config_content += ' name: "tensorrt"\n'
                    config_content += ' parameters {\n'
                    config_content += ' key: "precision_mode"\n'
                    config_content += ' value: "FP16"\n'
                    config_content += ' }\n'
                    config_content += ' parameters {\n'
                    config_content += ' key: "max_workspace_size_bytes"\n'
                    config_content += ' value: "1073741824"\n'
                    config_content += ' }\n'
                    config_content += ' parameters {\n'
                    config_content += ' key: "trt_engine_cache_enable"\n'
                    config_content += ' value: "1"\n'
                    config_content += ' }\n'
                    config_content += ' parameters {\n'
                    config_content += ' key: "trt_engine_cache_path"\n'
                    config_content += f' value: "/models/{self.model_name}/1"\n'
                    config_content += ' }\n'
                    config_content += ' }\n'
                    config_content += ' }\n'
                config_content += '}\n'

            with open(config_path, "w") as f:
                f.write(config_content)

            logging.info("Config file written successfully")
            logging.info("Config content:\n%s", config_content)

        except Exception as e:
            logging.error("Failed to write config file: %s", str(e), exc_info=True)
            raise
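
    # For orientation, the config.pbtxt emitted above for a hypothetical ONNX classifier
    # ("demo_classifier", two GPU-backed instances, dynamic batching on) comes out roughly
    # like the sketch below. Exact names, dims, and dtypes are read from the ONNX graph at
    # runtime, so treat every value here as illustrative rather than part of the package:
    #
    #   name: "demo_classifier"
    #   platform: "onnxruntime_onnx"
    #   max_batch_size: 8
    #   input [ { name: "input" data_type: TYPE_FP32 dims: [3, 224, 224] } ]
    #   output [ { name: "output" data_type: TYPE_FP32 dims: [10] } ]
    #   instance_group [ { count: 2 kind: KIND_GPU } ]
    #   dynamic_batching { preferred_batch_size: [2, 4, 8] max_queue_delay_microseconds: 100 }
    #   optimization { input_pinned_memory { enable: true } output_pinned_memory { enable: true } }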

    def get_config_params(self):
        """Get configuration parameters for the Triton config file."""
        try:
            logging.info("Retrieving configuration parameters")

            logging.info("Using input size: %s", self.input_size)
            logging.info("Using number of classes: %s", self.num_classes)

            params = {
                "max_batch_size": self.max_batch_size,
                # Must match the num_model_instances parameter of write_config_file().
                "num_model_instances": self.num_model_instances,
                "image_size": self.input_size,
                "num_classes": self.num_classes,
                "input_data_type": "TYPE_FP32",
                "output_data_type": "TYPE_FP32",
                "dynamic_batching": self.dynamic_batching,
                "preferred_batch_size": [1, 2, 4, 8],
                "max_queue_delay_microseconds": 100,
                "input_pinned_memory": True,
                "output_pinned_memory": True,
            }

            logging.debug("Final configuration parameters: %s", params)
            return params

        except Exception as e:
            logging.error(
                "Failed to get configuration parameters: %s",
                str(e),
                exc_info=True,
            )
            raise

    def start_server(self, internal_port: int = 8000):
        """Start the Triton Inference Server.

        Args:
            internal_port: Port to expose the server on
        """
        gpu_option = "--gpus=all " if self.gpus_count > 0 else ""
        logging.debug("Starting Triton server with GPU option: '%s'", gpu_option.strip())
        triton_port = 8000 if self.connection_protocol == 'rest' else 8001
        port_mapping = f"-p{internal_port}:{triton_port}"
        start_triton_server = f"docker run {gpu_option}--rm {port_mapping} -v {os.path.abspath(BASE_PATH)}:/models --label model_name={self.model_name} {TRITON_DOCKER_IMAGE} tritonserver --model-repository=/models "
        logging.info("Checking docker image download status before starting server")
        self.check_triton_docker_image()
        try:
            logging.info(
                "Starting Triton server with command: %s",
                start_triton_server,
            )
            self.process = subprocess.Popen(
                shlex.split(start_triton_server),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )

            def log_output():
                while True:
                    stdout_line = self.process.stdout.readline()
                    stderr_line = self.process.stderr.readline()
                    if stdout_line:
                        logging.info(stdout_line.strip())
                    if stderr_line:
                        logging.info(stderr_line.strip())
                    if stdout_line == "" and stderr_line == "" and self.process.poll() is not None:
                        break

            threading.Thread(target=log_output, daemon=False).start()
            logging.info(
                "Triton server started successfully on port %s",
                internal_port,
            )
            return self.process
        except Exception as e:
            logging.error(
                "Failed to start Triton server: %s",
                str(e),
                exc_info=True,
            )
            raise
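
    # For reference, the command assembled above expands to roughly the following
    # (illustrative values only; the real string is built at runtime from BASE_PATH,
    # the model name, the chosen protocol's port, and whether any GPUs were detected):
    #
    #   docker run --gpus=all --rm -p8000:8000 \
    #       -v /abs/path/to/model_repository:/models \
    #       --label model_name=<model_name> \
    #       nvcr.io/nvidia/tritonserver:23.08-py3 \
    #       tritonserver --model-repository=/models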

    def setup(self, internal_port: int = 8000):
        """Setup the Triton server with the provided model.

        Args:
            internal_port: Port to expose the server on
        """
        try:
            logging.info("Beginning Triton server setup")
            logging.info("Step 1: Creating model repository")
            self.model_dir, self.version_dir = self.create_model_repository()
            logging.info("Step 2: Preparing model")
            self.prepare_model(self.version_dir)
            logging.info("Step 3: Getting configuration parameters")
            self.config_params = self.get_config_params()
            logging.info("Step 4: Writing configuration file")
            self.write_config_file(
                self.model_dir,
                **self.config_params,
            )
            logging.info("Step 5: Starting Triton server")
            self.process = self.start_server(internal_port)
            logging.info("Triton server setup completed successfully")
            return self.process
        except Exception as e:
            logging.error(
                "Triton server setup failed: %s",
                str(e),
                exc_info=True,
            )
            raise


"""Module providing inference_utils functionality for FastAPI and Triton inference."""
|
|
598
|
+
|
|
599
|
+
from PIL import Image
|
|
600
|
+
import httpx
|
|
601
|
+
import logging
|
|
602
|
+
from typing import Optional, Dict, Union, Any
|
|
603
|
+
from datetime import datetime, timezone
|
|
604
|
+
from io import BytesIO
|
|
605
|
+
import numpy as np
|
|
606
|
+
import cv2
|
|
607
|
+
import torch
|
|
608
|
+
import torchvision
|
|
609
|
+
from typing import Tuple, Dict, Any, Optional, Union
|
|
610
|
+
import logging
|
|
611
|
+
from PIL import Image
|
|
612
|
+
from io import BytesIO
|
|
613
|
+
import os
|
|
614
|
+
from datetime import datetime, timezone
|
|
615
|
+
|
|
616
|
+
class TritonInference:
    """Class for making Triton inference requests."""

    def __init__(
        self,
        server_type: str,
        model_name: str,
        internal_port: int = 80,
        internal_host: str = "localhost",
        task_type: str = "detection",
        runtime_framework: str = "onnx",
        is_yolo: bool = False,
        is_ocr: bool = False,
        input_size: Union[int, List[int]] = (224, 224),
    ):
        """Initialize Triton inference client.

        Args:
            server_type: Type of server (grpc/rest)
            model_name: Name of model to use
            internal_port: Port number for internal API
            internal_host: Hostname for internal API
            task_type: Type of task (e.g., detection)
            runtime_framework: Framework used for the model (e.g., onnx)
            is_yolo: Boolean indicating if the model is YOLO
            is_ocr: Boolean indicating if the model is an OCR model
            input_size: Input size for the model (int or [height, width])
        """
        self.model_name = model_name
        self.task_type = task_type
        self.runtime_framework = runtime_framework
        self.is_yolo = is_yolo
        self.is_ocr = is_ocr
        self.input_size = [input_size, input_size] if isinstance(input_size, int) else input_size
        self.ocr_config = {
            "color_mode": "rgba",
            "keep_aspect_ratio": True,
            "interpolation": "linear",
            "padding_color": (114, 114, 114, 255),
        }
        self.data_type_mapping = {
            2: "TYPE_UINT8",
            6: "TYPE_INT8",
            7: "TYPE_INT16",
            8: "TYPE_INT32",
            9: "TYPE_INT64",
            10: "TYPE_FP16",
            11: "TYPE_FP32",
            12: "TYPE_FP64",
        }
        self.numpy_data_type_mapping = {
            "INT8": np.int8,
            "INT16": np.int16,
            "INT32": np.int32,
            "INT64": np.int64,
            "FP16": np.float16,
            "FP32": np.float32,
            "FP64": np.float64,
            "UINT8": np.uint8,
        }
        self.setup_client_funcs = {
            "grpc": self._setup_grpc_client,
            "rest": self._setup_rest_client,
        }
        self.url = f"{internal_host}:{internal_port}"
        self.connection_protocol = "grpc" if "grpc" in server_type else "rest"
        self.tritonclientclass = None
        self._dependencies_check()
        self.client_info = self.setup_client_funcs[self.connection_protocol]()
        logging.info(
            "Initialized TritonInference with %s protocol",
            self.connection_protocol,
        )

    def _dependencies_check(self):
        """Check and import required Triton dependencies."""
        try:
            if self.connection_protocol == "rest":
                import tritonclient.http as tritonclientclass
            else:
                import tritonclient.grpc as tritonclientclass
            self.tritonclientclass = tritonclientclass
        except ImportError as err:
            package_name = "tritonclient[http]" if self.connection_protocol == "rest" else "tritonclient[grpc]"
            logging.error(
                "Failed to import tritonclient (%s): %s. Please install with: pip install %s",
                package_name, err, package_name,
            )
            raise ImportError(f"Required package {package_name} not installed: {err}")
        except Exception as err:
            logging.error(
                "Failed to import tritonclient: %s",
                err,
            )
            raise

    def _setup_rest_client(self):
        """Setup REST client and model configuration.

        Returns:
            Dictionary containing client configuration
        """
        client = self.tritonclientclass.InferenceServerClient(url=self.url)
        model_config = client.get_model_config(model_name=self.model_name, model_version="1")
        input_config = model_config["input"][0]
        input_shape = [1] + input_config["dims"]  # Prepend batch dimension
        input_obj = self.tritonclientclass.InferInput(
            input_config["name"],
            input_shape,
            input_config["data_type"].split("_")[-1],
        )
        output = self.tritonclientclass.InferRequestedOutput(model_config["output"][0]["name"])
        return {
            "client": client,
            "input": input_obj,
            "output": output,
        }

    def _setup_grpc_client(self):
        """Setup gRPC client and model configuration.

        Returns:
            Dictionary containing client configuration
        """
        client = self.tritonclientclass.InferenceServerClient(url=self.url)
        model_config = client.get_model_config(model_name=self.model_name, model_version="1")
        input_config = model_config.config.input[0]
        input_shape = [1] + list(input_config.dims)  # Prepend batch dimension
        input_obj = self.tritonclientclass.InferInput(
            input_config.name,
            input_shape,
            self.data_type_mapping[input_config.data_type].split("_")[-1],
        )
        output = self.tritonclientclass.InferRequestedOutput(model_config.config.output[0].name)
        return {
            "client": client,
            "input": input_obj,
            "output": output,
        }

    def inference(self, input_data: Union[bytes, np.ndarray]) -> np.ndarray:
        """Make a synchronous inference request.

        Args:
            input_data: Input data as bytes or stacked numpy array

        Returns:
            Model prediction as numpy array

        Raises:
            Exception: If inference fails
        """
        try:
            # If already a preprocessed ndarray, make it C-contiguous FP32.
            if isinstance(input_data, np.ndarray):
                input_array = np.ascontiguousarray(input_data, dtype=np.float32)
                if input_array.ndim == 5 and input_array.shape[1] == 1:
                    # [B, 1, C, H, W] -> [B, C, H, W]
                    input_array = np.ascontiguousarray(
                        input_array.reshape(input_array.shape[0],
                                            input_array.shape[2],
                                            input_array.shape[3],
                                            input_array.shape[4]),
                        dtype=np.float32,
                    )
            else:
                # -> [1, C, H, W], FP32, contiguous
                input_array = self._preprocess_input(input_data)

            # Update InferInput shape to match the batch (N, C, H, W)
            self.client_info["input"].set_shape(list(input_array.shape))
            self.client_info["input"].set_data_from_numpy(input_array)

            # The REST and gRPC clients expose the same infer() signature.
            resp = self.client_info["client"].infer(
                model_name=self.model_name,
                model_version="1",
                inputs=[self.client_info["input"]],
                outputs=[self.client_info["output"]],
            )

            return resp.as_numpy(self.client_info["output"].name())

        except Exception as err:
            logging.error("Triton inference failed: %s", err, exc_info=True)
            raise Exception(f"Triton inference failed: {err}") from err

    async def async_inference(self, input_data: Union[bytes, np.ndarray]) -> np.ndarray:
        """Make an asynchronous inference request (REST + gRPC)."""
        try:
            logging.debug("Making async inference request")

            if isinstance(input_data, np.ndarray):
                input_array = input_data
            else:
                input_array = self._preprocess_input(input_data)

            # Ensure C-contiguous
            if not input_array.flags.c_contiguous:
                input_array = np.ascontiguousarray(input_array)

            self.client_info["input"].set_shape(list(input_array.shape))
            self.client_info["input"].set_data_from_numpy(input_array)

            if self.connection_protocol == "rest":
                # REST: async_infer -> InferAsyncRequest, then block with get_result()
                resp = self.client_info["client"].async_infer(
                    model_name=self.model_name,
                    model_version="1",
                    inputs=[self.client_info["input"]],
                    outputs=[self.client_info["output"]],
                )
                result = resp.get_result()

            else:
                # gRPC: async_infer uses a callback; wrap it into an awaitable Future
                loop = asyncio.get_running_loop()
                fut: asyncio.Future = loop.create_future()

                def _callback(result, error):
                    if error is not None:
                        loop.call_soon_threadsafe(fut.set_exception, error)
                    else:
                        loop.call_soon_threadsafe(fut.set_result, result)

                self.client_info["client"].async_infer(
                    model_name=self.model_name,
                    model_version="1",
                    inputs=[self.client_info["input"]],
                    outputs=[self.client_info["output"]],
                    callback=_callback,
                )
                result = await fut

            logging.debug(f"Async inference response type: {type(result)}")
            logging.info("Successfully got async inference result")

            output_array = result.as_numpy(self.client_info["output"].name())
            logging.info(f"Output shape: {output_array.shape}")
            return output_array

        except Exception as err:
            logging.error(f"Async Triton inference failed: {err}")
            raise Exception(f"Async Triton inference failed: {err}") from err

    def _preprocess_input(self, input_data) -> np.ndarray:
        """Preprocess input data for YOLOv8 or OCR inference.

        Args:
            input_data: Raw input bytes or string (file path)

        Returns:
            Preprocessed numpy array ready for inference
        """
        if isinstance(self.input_size, int):
            resize_shape = (self.input_size, self.input_size)
        elif isinstance(self.input_size, (list, tuple)) and len(self.input_size) == 2:
            resize_shape = (self.input_size[0], self.input_size[1])
        else:
            resize_shape = (224, 224)  # Fallback when input_size is malformed
        input_shape = [1, 3, resize_shape[0], resize_shape[1]]  # Default for compatibility

        if isinstance(input_data, str) and os.path.exists(input_data):
            with open(input_data, "rb") as f:
                input_data = f.read()

        if isinstance(input_data, bytes):
            try:
                image = Image.open(BytesIO(input_data)).convert("RGB")
            except Exception as e:
                arr = np.frombuffer(input_data, dtype=np.uint8).reshape(640, 640, 3)
                image = Image.fromarray(arr, mode="RGB")
        elif isinstance(input_data, np.ndarray):
            image = Image.fromarray(input_data, mode="RGB")
        else:
            raise ValueError(f"Unsupported input_data type: {type(input_data)}")

        if self.is_yolo:
            logging.debug("Preprocessing input for YOLO model")
            image, ratio, (dw, dh) = self._letterbox_resize(image, resize_shape)
            arr = np.array(image).astype(np.float32) / 255.0
            arr = arr.transpose(2, 0, 1)
            arr = np.expand_dims(arr, axis=0)
            self.client_info["padding_info"] = {"ratio": ratio, "dw": dw, "dh": dh}
        elif self.is_ocr:
            logging.debug("Preprocessing input for OCR model")
            config = getattr(self, "ocr_config", {})
            arr = self._preprocess_ocr(
                image=image,
                resize_shape=resize_shape,
                image_color_mode=config.get("color_mode", "rgb"),
                keep_aspect_ratio=config.get("keep_aspect_ratio", False),
                interpolation_method=config.get("interpolation", "linear"),
                padding_color=config.get("padding_color", (114, 114, 114)),
                use_grayscale=config.get("use_grayscale", False),
                apply_contrast=config.get("apply_contrast", False),
            )
        else:
            # Classifier preprocessing: resize directly, ImageNet normalization
            image = image.resize(resize_shape)
            arr = np.array(image).astype(np.float32) / 255.0
            arr = arr.transpose(2, 0, 1)  # Convert to CHW
            arr = np.expand_dims(arr, axis=0)  # Add batch dimension
            mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
            std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)
            arr = (arr - mean) / std

        arr = arr.astype(self.numpy_data_type_mapping[self.client_info["input"].datatype()])
        return arr

    def _preprocess_ocr(
        self,
        image: Image.Image,
        resize_shape: tuple,
        image_color_mode: str = "rgb",
        keep_aspect_ratio: bool = False,
        interpolation_method: str = "linear",
        padding_color: tuple = (114, 114, 114),
        use_grayscale: bool = False,
        apply_contrast: bool = False,
    ) -> np.ndarray:
        """Preprocess an input PIL Image for OCR model inference.

        Args:
            image: PIL Image in RGB format.
            resize_shape: (height, width) tuple.
            image_color_mode: "rgb" or "grayscale" (affects output channels).
            keep_aspect_ratio: Whether to preserve aspect ratio with padding.
            interpolation_method: One of ["linear", "nearest", "cubic", "area"].
            padding_color: Padding color for aspect ratio preservation (RGB or scalar for grayscale).
            use_grayscale: Convert to grayscale before processing (default: False).
            apply_contrast: Apply CLAHE contrast enhancement (default: False).

        Returns:
            Preprocessed numpy array (batch, height, width, C) ready for OCR inference.
        """
        import cv2
        import numpy as np

        img_height, img_width = resize_shape
        img = np.array(image)

        if img.shape[-1] == 4:
            img = img[:, :, :3]

        if use_grayscale:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            if image_color_mode == "rgb":
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif image_color_mode == "grayscale":
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        if apply_contrast:
            if image_color_mode == "grayscale" or use_grayscale:
                clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
                img = clahe.apply(img if img.ndim == 2 else img[:, :, 0])
                if image_color_mode == "rgb":
                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            else:
                img_yuv = cv2.cvtColor(img, cv2.COLOR_RGB2YUV)
                clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
                img_yuv[:, :, 0] = clahe.apply(img_yuv[:, :, 0])
                img = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2RGB)

        INTERPOLATION_MAP = {
            "linear": cv2.INTER_LINEAR,
            "nearest": cv2.INTER_NEAREST,
            "cubic": cv2.INTER_CUBIC,
            "area": cv2.INTER_AREA,
        }
        interpolation = INTERPOLATION_MAP.get(interpolation_method, cv2.INTER_LINEAR)

        if not keep_aspect_ratio:
            img = cv2.resize(img, (img_width, img_height), interpolation=interpolation)
        else:
            orig_h, orig_w = img.shape[:2]
            r = min(img_height / orig_h, img_width / orig_w)
            new_unpad_w, new_unpad_h = round(orig_w * r), round(orig_h * r)
            img = cv2.resize(img, (new_unpad_w, new_unpad_h), interpolation=interpolation)
            dw, dh = (img_width - new_unpad_w) / 2, (img_height - new_unpad_h) / 2
            top, bottom, left, right = (
                round(dh - 0.1),
                round(dh + 0.1),
                round(dw - 0.1),
                round(dw + 0.1),
            )
            border_color = padding_color[0] if image_color_mode == "grayscale" else padding_color
            img = cv2.copyMakeBorder(
                img,
                top,
                bottom,
                left,
                right,
                borderType=cv2.BORDER_CONSTANT,
                value=border_color,
            )

        if image_color_mode == "grayscale" and img.ndim == 2:
            img = np.expand_dims(img, axis=-1)

        dtype = self.numpy_data_type_mapping[self.client_info["input"].datatype()]
        arr = img.astype(np.float32)
        if dtype == np.uint8:
            arr = (arr * 255.0).astype(np.uint8)
        else:
            arr = arr / 255.0

        arr = np.expand_dims(arr, axis=0)

        logging.debug(f"Preprocessed OCR input shape: {arr.shape}, dtype: {arr.dtype}")

        if arr.shape[-1] == 3:
            cv2.imwrite("preprocessed_ocr_image.png", cv2.cvtColor((arr[0] * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
        else:
            cv2.imwrite("preprocessed_ocr_image.png", (arr[0, :, :, 0] * 255).astype(np.uint8))

        return arr.astype(dtype)

    def _letterbox_resize(self, image, target_size):
        """Resize image with letterbox padding to maintain aspect ratio."""
        target_h, target_w = target_size
        img_w, img_h = image.size
        ratio = min(target_w / img_w, target_h / img_h)
        new_w, new_h = int(img_w * ratio), int(img_h * ratio)
        image = image.resize((new_w, new_h), Image.LANCZOS)
        padded_image = Image.new("RGB", (target_w, target_h), (114, 114, 114))
        dw, dh = (target_w - new_w) // 2, (target_h - new_h) // 2
        padded_image.paste(image, (dw, dh))
        logging.debug("Letterbox resize: original size %s, new size %s, padding (dw, dh) (%d, %d)", (img_w, img_h), (new_w, new_h), dw, dh)
        logging.debug("Letterbox resize completed for target size: %s", target_size)
        return padded_image, ratio, (dw, dh)

    def _postprocess_yolo(
        self,
        outputs: np.ndarray,
        conf_thres: float = 0.25,
        iou_thres: float = 0.45,
        max_det: int = 300,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Postprocess YOLOv8 outputs (Torch or ONNX/Triton) with pipeline compatibility.

        Args:
            outputs: Raw model output as NumPy array, expected shape [batch, num_boxes, num_classes + 4] for YOLO
            conf_thres: Confidence threshold for filtering detections
            iou_thres: IoU threshold for Non-Maximum Suppression
            max_det: Maximum number of detections to keep

        Returns:
            Tuple of (boxes, scores, class_ids) as NumPy arrays:
                - boxes: [N, 4] array of bounding boxes in xyxy format
                - scores: [N] array of confidence scores
                - class_ids: [N] array of class IDs
        """
        if not self.is_yolo:
            return outputs, np.array([]), np.array([])

        try:
            if isinstance(outputs, np.ndarray):
                outputs = torch.from_numpy(outputs)
            elif not isinstance(outputs, torch.Tensor):
                outputs = torch.tensor(outputs, dtype=torch.float32)

            if outputs.ndim == 2:
                outputs = outputs.unsqueeze(0)

            if outputs.shape[1] < outputs.shape[2]:
                # Format [batch, num_classes + 4, num_boxes] -> [batch, num_boxes, num_classes + 4]
                outputs = outputs.transpose(1, 2)

            boxes = outputs[..., :4]       # xywh, [batch, num_boxes, 4]
            scores_all = outputs[..., 4:]  # class scores, [batch, num_boxes, num_classes]

            all_boxes, all_scores, all_class_ids = [], [], []
            for batch_idx in range(outputs.shape[0]):
                batch_boxes = boxes[batch_idx]            # [num_boxes, 4]
                batch_scores_all = scores_all[batch_idx]  # [num_boxes, num_classes]

                # Get best class per box
                scores, class_ids = batch_scores_all.max(dim=-1)

                # Confidence filter
                mask = scores > conf_thres
                batch_boxes = batch_boxes[mask]
                batch_scores = scores[mask]
                batch_class_ids = class_ids[mask]

                # Convert xywh -> xyxy
                if batch_boxes.shape[0] > 0:
                    xyxy = torch.zeros_like(batch_boxes)
                    xyxy[:, 0] = batch_boxes[:, 0] - batch_boxes[:, 2] / 2  # x1
                    xyxy[:, 1] = batch_boxes[:, 1] - batch_boxes[:, 3] / 2  # y1
                    xyxy[:, 2] = batch_boxes[:, 0] + batch_boxes[:, 2] / 2  # x2
                    xyxy[:, 3] = batch_boxes[:, 1] + batch_boxes[:, 3] / 2  # y2
                    batch_boxes = xyxy

                # Adjust for letterbox padding if provided
                padding_info = self.client_info.get("padding_info", {})
                if padding_info:
                    ratio = padding_info.get("ratio", 1.0)
                    dw = padding_info.get("dw", 0)
                    dh = padding_info.get("dh", 0)
                    batch_boxes[:, 0] = (batch_boxes[:, 0] - dw) / ratio  # x1
                    batch_boxes[:, 1] = (batch_boxes[:, 1] - dh) / ratio  # y1
                    batch_boxes[:, 2] = (batch_boxes[:, 2] - dw) / ratio  # x2
                    batch_boxes[:, 3] = (batch_boxes[:, 3] - dh) / ratio  # y2

                # NMS
                if batch_boxes.shape[0] > 0:
                    keep = torchvision.ops.nms(batch_boxes, batch_scores, iou_thres)
                    batch_boxes = batch_boxes[keep]
                    batch_scores = batch_scores[keep]
                    batch_class_ids = batch_class_ids[keep]

                if batch_boxes.shape[0] > max_det:
                    topk = batch_scores.topk(max_det).indices
                    batch_boxes = batch_boxes[topk]
                    batch_scores = batch_scores[topk]
                    batch_class_ids = batch_class_ids[topk]

                all_boxes.append(batch_boxes.cpu().numpy())
                all_scores.append(batch_scores.cpu().numpy())
                all_class_ids.append(batch_class_ids.cpu().numpy())

            boxes = np.concatenate(all_boxes, axis=0) if all_boxes else np.empty((0, 4))
            scores = np.concatenate(all_scores, axis=0) if all_scores else np.empty(0)
            class_ids = np.concatenate(all_class_ids, axis=0) if all_class_ids else np.empty(0)

            return boxes, scores, class_ids

        except Exception as e:
            logging.error("YOLO post-processing failed: %s", str(e), exc_info=True)
            return np.empty((0, 4)), np.empty(0), np.empty(0)

    def _postprocess_ocr(
        self,
        model_output: np.ndarray,
        max_plate_slots: int = 9,
        model_alphabet: str = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_",
        return_confidence: bool = True,
        confidence_threshold: float = 0.0,  # Disabled threshold to match ONNX
    ) -> Tuple[list[str], np.ndarray] | list[str]:
        """Postprocess OCR model outputs into license plate strings.

        Args:
            model_output: Raw output tensor from the model.
            max_plate_slots: Maximum number of character positions. Defaults to 9.
            model_alphabet: Alphabet used by the OCR model. Defaults to alphanumeric.
            return_confidence: If True, also return per-character confidence scores. Defaults to True.
            confidence_threshold: Minimum confidence for a character to be considered valid. Defaults to 0.0.

        Returns:
            If return_confidence is False: a list of decoded plate strings.
            If True: a two-tuple (plates, probs) where plates is the list of decoded strings,
            and probs is an array of shape (N, max_plate_slots) with confidence scores.
        """
        try:
            logging.debug(f"OCR model output shape: {model_output.shape}")

            predictions = model_output.reshape((-1, max_plate_slots, len(model_alphabet)))
            probs = np.max(predictions, axis=-1)
            prediction_indices = np.argmax(predictions, axis=-1)

            alphabet_array = np.array(list(model_alphabet))
            if confidence_threshold > 0:
                pad_char_index = model_alphabet.index('_')
                prediction_indices[probs < confidence_threshold] = pad_char_index

            plate_chars = alphabet_array[prediction_indices]
            plates = np.apply_along_axis("".join, 1, plate_chars).tolist()

            if return_confidence:
                return plates, probs
            return plates
        except Exception as e:
            logging.error("OCR post-processing failed: %s", str(e), exc_info=True)
            return [], np.array([])

    def format_response(self, response: np.ndarray) -> Dict[str, Any]:
        """Format model response for consistent logging.

        Args:
            response: Raw model output

        Returns:
            Formatted response dictionary
        """
        if self.is_yolo:
            boxes, scores, class_ids = self._postprocess_yolo(
                response,
                conf_thres=0.25,
                iou_thres=0.45,
                max_det=300,
            )
            predictions = {
                "boxes": boxes.tolist(),
                "scores": scores.tolist(),
                "class_ids": class_ids.tolist(),
            }
        elif self.is_ocr:
            plates, probs = self._postprocess_ocr(
                response,
                max_plate_slots=9,
                model_alphabet="0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_",
                return_confidence=True,
            )
            predictions = {
                "plates": plates,
                "prob": probs,
            }
        else:
            predictions = response.tolist() if isinstance(response, np.ndarray) else response

        return {
            "predictions": predictions,
            "model_id": self.model_name,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }


# TODO: Bifurcate Triton server and inference utils into separate files
# TODO: Import and use postprocess functions for different use-cases (yolo, ocr, cls)
# TODO: Verify and generalize for object detection models
# TODO: Implement a unified interface for model post-processing
# TODO: Define a standardized template for custom model configs support (ocr)
# TODO: Remove hardcoded versions and provide user-defined version control support