matrice-inference 0.1.2 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of matrice-inference might be problematic.
- matrice_inference/__init__.py +72 -0
- matrice_inference/py.typed +0 -0
- matrice_inference/server/__init__.py +23 -0
- matrice_inference/server/inference_interface.py +176 -0
- matrice_inference/server/model/__init__.py +1 -0
- matrice_inference/server/model/model_manager.py +274 -0
- matrice_inference/server/model/model_manager_wrapper.py +550 -0
- matrice_inference/server/model/triton_model_manager.py +290 -0
- matrice_inference/server/model/triton_server.py +1248 -0
- matrice_inference/server/proxy_interface.py +371 -0
- matrice_inference/server/server.py +1004 -0
- matrice_inference/server/stream/__init__.py +0 -0
- matrice_inference/server/stream/app_deployment.py +228 -0
- matrice_inference/server/stream/consumer_worker.py +201 -0
- matrice_inference/server/stream/frame_cache.py +127 -0
- matrice_inference/server/stream/inference_worker.py +163 -0
- matrice_inference/server/stream/post_processing_worker.py +230 -0
- matrice_inference/server/stream/producer_worker.py +147 -0
- matrice_inference/server/stream/stream_pipeline.py +451 -0
- matrice_inference/server/stream/utils.py +23 -0
- matrice_inference/tmp/abstract_model_manager.py +58 -0
- matrice_inference/tmp/aggregator/__init__.py +18 -0
- matrice_inference/tmp/aggregator/aggregator.py +330 -0
- matrice_inference/tmp/aggregator/analytics.py +906 -0
- matrice_inference/tmp/aggregator/ingestor.py +438 -0
- matrice_inference/tmp/aggregator/latency.py +597 -0
- matrice_inference/tmp/aggregator/pipeline.py +968 -0
- matrice_inference/tmp/aggregator/publisher.py +431 -0
- matrice_inference/tmp/aggregator/synchronizer.py +594 -0
- matrice_inference/tmp/batch_manager.py +239 -0
- matrice_inference/tmp/overall_inference_testing.py +338 -0
- matrice_inference/tmp/triton_utils.py +638 -0
- matrice_inference-0.1.2.dist-info/METADATA +28 -0
- matrice_inference-0.1.2.dist-info/RECORD +37 -0
- matrice_inference-0.1.2.dist-info/WHEEL +5 -0
- matrice_inference-0.1.2.dist-info/licenses/LICENSE.txt +21 -0
- matrice_inference-0.1.2.dist-info/top_level.txt +1 -0

@@ -0,0 +1,638 @@
"""Module providing triton_server functionality."""

import os
import zipfile
import subprocess
import logging
import threading
import shlex
import shutil
from matrice_common.utils import dependencies_check
from matrice.docker_utils import pull_docker_image

TRITON_DOCKER_IMAGE = "nvcr.io/nvidia/tritonserver:23.08-py3"
BASE_PATH = "./model_repository"


class MatriceTritonServer:

    def __init__(self, action_tracker):
        dependencies_check("torch")
        import torch

        logging.info("Initializing MatriceTritonServer (v0)")
        self.action_tracker = action_tracker
        self.action_details = action_tracker.action_details
        self.model_id = self.action_details["_idModelDeploy"]
        self.deployment_id = self.action_details["_idDeployment"]
        self.deployment_instance_id = self.action_details["_idModelDeployInstance"]
        logging.info("Model ID: %s", self.model_id)
        logging.info(
            "Deployment ID: %s",
            self.deployment_id,
        )
        logging.info(
            "Deployment Instance ID: %s",
            self.deployment_instance_id,
        )
        self.connection_protocol = (
            "grpc" if "grpc" in self.action_details.get("server_type", "rest").lower() else "rest"
        )
        logging.info(
            "Using connection protocol: %s",
            self.connection_protocol,
        )
        self.job_params = self.action_tracker.get_job_params()
        logging.debug("Job parameters: %s", self.job_params)
        self.gpus_count = torch.cuda.device_count()
        logging.info(
            "Found %s GPUs available for inference",
            self.gpus_count,
        )
        self.docker_pull_process = subprocess.Popen(
            [
                shutil.which("docker"),
                "pull",
                TRITON_DOCKER_IMAGE,
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
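        # The pull runs in the background; check_triton_docker_image() waits on this
        # process, and start_server() calls it before launching the Triton container.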

    def check_triton_docker_image(self):
        """Check if the docker image download is complete and wait for it to finish."""
        logging.info("Checking docker image download status")
        stdout, stderr = self.docker_pull_process.communicate()
        if self.docker_pull_process.returncode == 0:
            logging.info(
                "Docker image %s downloaded successfully",
                TRITON_DOCKER_IMAGE,
            )
        else:
            # communicate() already returns str because the process was opened with text=True
            error_msg = stderr
            logging.error(
                "Docker pull failed with return code %s",
                self.docker_pull_process.returncode,
            )
            logging.error("Error message: %s", error_msg)
            raise RuntimeError(f"Docker pull failed: {error_msg}")

    def download_model(self, model_version_dir):
        """Download and extract the model files."""
        try:
            runtime_framework = self.action_tracker.export_format.lower()
            logging.info(
                "Downloading model with runtime framework: %s",
                runtime_framework,
            )
            model_map = {
                "onnx": "model.onnx",
                "torchscript": "model.pt",
                "pytorch": "model.pt",
                "tensorrt": "model.engine",
                "openvino": "model_openvino.zip",
            }
            if runtime_framework not in model_map:
                logging.error(
                    "Runtime framework '%s' not supported. Supported frameworks: %s",
                    runtime_framework,
                    list(model_map.keys()),
                )
                raise ValueError(f"Unsupported runtime framework: {runtime_framework}")
            model_file = os.path.join(
                model_version_dir,
                model_map[runtime_framework],
            )
            logging.info(
                "Downloading model to path: %s",
                model_file,
            )
            model_type = "exported" if self.action_tracker.is_exported else "trained"
            logging.info("Model type: %s", model_type)
            self.action_tracker.download_model(model_file, model_type=model_type)
            logging.info("Model download completed successfully")
            if runtime_framework == "pytorch":

                def compile_torch_model(model_path: str):
                    import torch

                    logging.info("Compiling PyTorch model")
                    if self.gpus_count > 0:
                        model = torch.load(model_path)
                    else:
                        model = torch.load(
                            model_path,
                            map_location=torch.device("cpu"),
                        )
                    model.eval()
                    compiled_model = torch.jit.script(model)
                    compiled_model.save(model_path)
                    logging.info("PyTorch model compiled successfully")

                compile_torch_model(model_file)
            if runtime_framework == "openvino":
                logging.info("Starting OpenVINO model extraction")
                # Extract the archive that was just downloaded into the version directory
                with zipfile.ZipFile(model_file, "r") as zip_ref:
                    zip_ref.extractall(model_version_dir)
                logging.info("OpenVINO model extracted successfully")
        except Exception as e:
            logging.error(
                "Model download failed: %s",
                str(e),
                exc_info=True,
            )
            raise

    def create_model_repository(self):
        """Create the model repository directory structure."""
        try:
            model_version = "1"
            model_dir = os.path.join(BASE_PATH, self.model_id)
            version_dir = os.path.join(model_dir, str(model_version))
            logging.info("Creating model repository structure:")
            logging.info("Base path: %s", BASE_PATH)
            logging.info("Model directory: %s", model_dir)
            logging.info(
                "Version directory: %s",
                version_dir,
            )
            os.makedirs(version_dir, exist_ok=True)
            logging.info("Model repository directories created successfully")
            return model_dir, version_dir
        except Exception as e:
            logging.error(
                "Failed to create model repository: %s",
                str(e),
                exc_info=True,
            )
            raise
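
    # For orientation: with an ONNX export, the helpers above and below lay out the
    # Triton model repository as follows (model id shown as a placeholder):
    #
    #   ./model_repository/<model_id>/config.pbtxt   <- write_config_file()
    #   ./model_repository/<model_id>/1/model.onnx   <- download_model()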

    def write_config_file(
        self,
        model_dir,
        max_batch_size=0,
        num_model_instances=1,
        image_size=[224, 224],
        num_classes=10,
        input_data_type: str = "TYPE_FP32",
        output_data_type: str = "TYPE_FP32",
        dynamic_batching: bool = False,
        preferred_batch_size: list = [2, 4, 8],
        max_queue_delay_microseconds: int = 100,
        input_pinned_memory: bool = True,
        output_pinned_memory: bool = True,
        **kwargs,
    ):
        """Write the model configuration file for Triton Inference Server."""
        try:
            runtime_framework = self.action_tracker.export_format.lower()
            logging.info("Starting to write Triton config file")
            platform_map = {
                "onnx": "onnxruntime_onnx",
                "tensorrt": "tensorrt_plan",
                "pytorch": "pytorch_libtorch",
                "torchscript": "pytorch_libtorch",
                "openvino": "openvino",
            }
            platform = platform_map.get(runtime_framework)
            if not platform:
                logging.error(
                    "Runtime framework '%s' not found in platform map",
                    runtime_framework,
                )
                raise ValueError(f"Unsupported runtime framework: {runtime_framework}")
            config_path = os.path.join(model_dir, "config.pbtxt")
            logging.info(
                "Writing config to: %s",
                config_path,
            )
            # f-strings so the placeholders are actually interpolated into the config
            config_str = f"""
name: "{self.model_id}"
platform: "{platform}"
max_batch_size: {max_batch_size}
"""
            if platform == "pytorch_libtorch":
                logging.info("Adding PyTorch-specific configuration")
                config_str += f"""
# Input configuration
input [
  {{
    name: "input__0"
    data_type: {input_data_type}
    dims: [ 3, {image_size[0]}, {image_size[1]} ]
  }}
]

# Output configuration
output [
  {{
    name: "output__0"
    data_type: {output_data_type}
    dims: [ {num_classes} ]
  }}
]
"""
            if num_model_instances > 1:
                device_type = "KIND_GPU" if self.gpus_count > 0 else "KIND_CPU"
                logging.info(
                    "Adding instance group configuration for %s %s instances",
                    num_model_instances,
                    device_type,
                )
                config_str += f"""
# Instance groups for GPU/CPU execution
instance_group [
  {{
    count: {num_model_instances}
    kind: {device_type}
  }}
]
"""
            if dynamic_batching:
                logging.info("Adding dynamic batching configuration")
                config_str += f"""
# Dynamic batching config
dynamic_batching {{
  preferred_batch_size: {preferred_batch_size}
  max_queue_delay_microseconds: {max_queue_delay_microseconds}
}}
"""
            if not input_pinned_memory or not output_pinned_memory:
                logging.info("Adding pinned memory configuration")
                config_str += f"""
optimization {{
  input_pinned_memory {{
    enable: {input_pinned_memory}
  }}
  output_pinned_memory {{
    enable: {output_pinned_memory}
  }}
}}
"""
            with open(config_path, "w") as f:
                f.write(config_str)
            logging.info("Config file written successfully")
            logging.info("Config content:\n%s", config_str)
        except Exception as e:
            logging.error(
                "Failed to write config file: %s",
                str(e),
                exc_info=True,
            )
            raise
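
    # Illustrative config.pbtxt assembled by write_config_file() for a TorchScript
    # model with the signature defaults above (the model id is a placeholder):
    #
    #   name: "<model_id>"
    #   platform: "pytorch_libtorch"
    #   max_batch_size: 0
    #
    #   # Input configuration
    #   input [
    #     {
    #       name: "input__0"
    #       data_type: TYPE_FP32
    #       dims: [ 3, 224, 224 ]
    #     }
    #   ]
    #
    #   # Output configuration
    #   output [
    #     {
    #       name: "output__0"
    #       data_type: TYPE_FP32
    #       dims: [ 10 ]
    #     }
    #   ]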

    def get_config_params(self):
        try:
            logging.info("Retrieving configuration parameters")
            input_size = self.action_tracker.get_input_size()
            num_classes = len(
                self.action_tracker.get_index_to_category(self.action_tracker.is_exported)
            )
            logging.info(
                "Retrieved input size: %s",
                input_size,
            )
            logging.info(
                "Retrieved number of classes: %s",
                num_classes,
            )
            params = {
                "max_batch_size": 8,
                "num_model_instances": 1,
                "image_size": [
                    input_size,
                    input_size,
                ],
                "num_classes": num_classes,
                "input_data_type": "TYPE_FP32",
                "output_data_type": "TYPE_FP32",
                "dynamic_batching": False,
                "preferred_batch_size": [2, 4, 8],
                "max_queue_delay_microseconds": 100,
                "input_pinned_memory": True,
                "output_pinned_memory": True,
            }
            params.update(self.job_params)
            logging.debug(
                "Final configuration parameters: %s",
                params,
            )
            return params
        except Exception as e:
            logging.error(
                "Failed to get configuration parameters: %s",
                str(e),
                exc_info=True,
            )
            raise

    def start_server(self):
        """Start the Triton Inference Server."""
        gpu_option = "--gpus=all " if self.gpus_count > 0 else ""
        port_mapping = f"-p{os.environ['INTERNAL_PORT']}:{8000 if self.connection_protocol == 'rest' else 8001}"
        start_triton_server = f"docker run {gpu_option}--rm {port_mapping} -v {os.path.abspath(BASE_PATH)}:/models --label action_id={self.action_tracker.action_id_str} {TRITON_DOCKER_IMAGE} tritonserver --model-repository=/models "
        logging.info("Checking docker image download status before starting server")
        self.check_triton_docker_image()
        try:
            logging.info(
                "Starting Triton server with command: %s",
                start_triton_server,
            )
            self.process = subprocess.Popen(
                shlex.split(start_triton_server),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )

            def log_output():
                while True:
                    stdout_line = self.process.stdout.readline()
                    stderr_line = self.process.stderr.readline()
                    if stdout_line:
                        logging.info(stdout_line.strip())
                    if stderr_line:
                        logging.info(stderr_line.strip())
                    if stdout_line == "" and stderr_line == "" and self.process.poll() is not None:
                        break

            threading.Thread(target=log_output, daemon=True).start()
            logging.info(
                "Triton server started successfully on port %s",
                os.environ.get("INTERNAL_PORT"),
            )
            return self.process
        except Exception as e:
            logging.error(
                "Failed to start Triton server: %s",
                str(e),
                exc_info=True,
            )
            raise

    def setup(self):
        try:
            logging.info("Beginning Triton server setup")
            logging.info("Step 1: Creating model repository")
            self.model_dir, self.version_dir = self.create_model_repository()
            logging.info("Step 2: Downloading model")
            self.download_model(self.version_dir)
            logging.info("Step 3: Getting configuration parameters")
            self.config_params = self.get_config_params()
            logging.info("Step 4: Writing configuration file")
            self.write_config_file(
                self.model_dir,
                **self.config_params,
            )
            logging.info("Step 5: Starting Triton server")
            self.process = self.start_server()
            logging.info("Triton server setup completed successfully")
            return self.process
        except Exception as e:
            logging.error(
                "Triton server setup failed: %s",
                str(e),
                exc_info=True,
            )
            raise

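
# Illustrative driver sketch (placeholder names, not shipped in the package): the
# class above expects an action_tracker from the Matrice SDK that exposes
# action_details, get_job_params(), export_format, is_exported, download_model(),
# get_input_size(), get_index_to_category() and action_id_str, plus an
# INTERNAL_PORT environment variable for start_server().
#
#     server = MatriceTritonServer(action_tracker)
#     triton_process = server.setup()   # repo -> model -> config.pbtxt -> docker run
#     triton_process.wait()
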
"""Module providing inference_utils functionality for FastAPI and Triton inference."""

from PIL import Image
import httpx
import logging
from typing import Optional, Dict, Union, Any
from datetime import datetime, timezone
from io import BytesIO
import numpy as np
from matrice_common.utils import dependencies_check


class TritonInference:
    """Class for making Triton inference requests."""

    def __init__(
        self,
        server_type: str,
        model_id: str,
        internal_port: int = 80,
        internal_host: str = "localhost",
    ):
        """Initialize Triton inference client.

        Args:
            server_type: Type of server (grpc/rest)
            model_id: ID of model to use
            internal_port: Port number for internal API
            internal_host: Hostname for internal API
        """
        self.model_id = model_id
        self.data_type_mapping = {
            6: "TYPE_INT8",
            7: "TYPE_INT16",
            8: "TYPE_INT32",
            9: "TYPE_INT64",
            10: "TYPE_FP16",
            11: "TYPE_FP32",
            12: "TYPE_FP64",
        }
        self.numpy_data_type_mapping = {
            "INT8": np.int8,
            "INT16": np.int16,
            "INT32": np.int32,
            "INT64": np.int64,
            "FP16": np.float16,
            "FP32": np.float32,
            "FP64": np.float64,
        }
        self.setup_client_funcs = {
            "grpc": self._setup_grpc_client,
            "rest": self._setup_rest_client,
        }
        self.url = f"{internal_host}:{internal_port}"
        self.connection_protocol = "grpc" if "grpc" in server_type else "rest"
        self.tritonclientclass = None
        self._dependencies_check()
        self.client_info = self.setup_client_funcs[self.connection_protocol]()
        logging.info(
            "Initialized TritonInference with %s protocol",
            self.connection_protocol,
        )

    def _dependencies_check(self):
        """Check and import required Triton dependencies."""
        try:
            if self.connection_protocol == "rest":
                dependencies_check(["tritonclient[http]"])
                import tritonclient.http as tritonclientclass
            else:
                dependencies_check(["tritonclient[grpc]"])
                import tritonclient.grpc as tritonclientclass
            self.tritonclientclass = tritonclientclass
        except Exception as err:
            logging.error(
                "Failed to import tritonclient: %s",
                err,
            )
            raise

    def _setup_rest_client(self):
        """Setup REST client and model configuration.

        Returns:
            Dictionary containing client configuration
        """
        client = self.tritonclientclass.InferenceServerClient(url=self.url)
        model_config = client.get_model_config(
            model_name=self.model_id,
            model_version="1",
        )
        input_shape = [1, 3, 244, 244][: 4 - len(model_config["input"][0]["dims"])] + model_config["input"][0]["dims"]
        input_obj = self.tritonclientclass.InferInput(
            model_config["input"][0]["name"],
            input_shape,
            model_config["input"][0]["data_type"].split("_")[-1],
        )
        output = self.tritonclientclass.InferRequestedOutput(model_config["output"][0]["name"])
        return {
            "client": client,
            "input": input_obj,
            "output": output,
        }

    def _setup_grpc_client(self):
        """Setup gRPC client and model configuration.

        Returns:
            Dictionary containing client configuration
        """
        client = self.tritonclientclass.InferenceServerClient(url=self.url)
        model_config = client.get_model_config(
            model_name=self.model_id,
            model_version="1",
        )
        input_shape = [1, 3, 244, 244][: 4 - len(model_config.config.input[0].dims)] + list(
            model_config.config.input[0].dims
        )
        input_obj = self.tritonclientclass.InferInput(
            model_config.config.input[0].name,
            input_shape,
            self.data_type_mapping[model_config.config.input[0].data_type].split("_")[-1],
        )
        output = self.tritonclientclass.InferRequestedOutput(model_config.config.output[0].name)
        return {
            "client": client,
            "input": input_obj,
            "output": output,
        }

    def inference(self, input_data: bytes) -> np.ndarray:
        """Make a synchronous inference request.

        Args:
            input_data: Input data as bytes

        Returns:
            Model prediction as numpy array

        Raises:
            Exception: If inference fails
        """
        try:
            logging.debug(
                "Making inference request for instance %s",
                self.url,
            )
            input_array = self._preprocess_input(input_data)
            self.client_info["input"].set_data_from_numpy(input_array)
            resp = self.client_info["client"].infer(
                model_name=self.model_id,
                model_version="1",
                inputs=[self.client_info["input"]],
                outputs=[self.client_info["output"]],
            )
            logging.debug("Successfully got inference result")
            return resp.as_numpy(self.client_info["output"].name())
        except Exception as err:
            logging.error("Triton inference failed: %s", err)
            raise Exception(f"Triton inference failed: {err}") from err

    async def async_inference(self, input_data: bytes) -> np.ndarray:
        """Make an asynchronous inference request.

        Args:
            input_data: Input data as bytes

        Returns:
            Model prediction as numpy array

        Raises:
            Exception: If inference fails
        """
        try:
            logging.debug(
                "Making async inference request for instance %s",
                self.url,
            )
            input_array = self._preprocess_input(input_data)
            self.client_info["input"].set_data_from_numpy(input_array)
            if self.connection_protocol == "rest":
                resp = await self.client_info["client"].async_infer(
                    model_name=self.model_id,
                    model_version="1",
                    inputs=[self.client_info["input"]],
                    outputs=[self.client_info["output"]],
                )
            else:
                resp = await self.client_info["client"].infer_async(
                    model_name=self.model_id,
                    model_version="1",
                    inputs=[self.client_info["input"]],
                    outputs=[self.client_info["output"]],
                )
            logging.debug("Successfully got async inference result")
            return resp.as_numpy(self.client_info["output"].name())
        except Exception as err:
            logging.error(
                "Async Triton inference failed: %s",
                err,
            )
            raise Exception(f"Async Triton inference failed: {err}") from err

    def _preprocess_input(self, input_data: bytes) -> np.ndarray:
        """Preprocess input data for model inference.

        Args:
            input_data: Raw input bytes

        Returns:
            Preprocessed numpy array ready for inference
        """
        image = Image.open(BytesIO(input_data)).convert("RGB")
        image = image.resize(self.client_info["input"].shape()[2:])
        array = np.array(image).astype(
            self.numpy_data_type_mapping[self.client_info["input"].datatype()]
        )
        array = array.transpose(2, 0, 1)
        array = np.expand_dims(array, axis=0)
        return array

    def format_response(self, response: np.ndarray) -> Dict[str, Any]:
        """Format model response for consistent logging.

        Args:
            response: Raw model output

        Returns:
            Formatted response dictionary
        """
        return {
            "predictions": (response.tolist() if isinstance(response, np.ndarray) else response),
            "model_id": self.model_id,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
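
# Illustrative usage sketch for TritonInference (file name, model id and port are
# placeholder assumptions, not values shipped in the package):
#
#     client = TritonInference(server_type="rest", model_id="my-model", internal_port=8000)
#     with open("sample.jpg", "rb") as f:
#         predictions = client.inference(f.read())
#     print(client.format_response(predictions))
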
@@ -0,0 +1,28 @@
Metadata-Version: 2.4
Name: matrice_inference
Version: 0.1.2
Summary: Common server utilities for Matrice.ai services
Author-email: "Matrice.ai" <dipendra@matrice.ai>
License-Expression: MIT
Keywords: matrice,common,utilities,pyarmor,obfuscated
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Operating System :: OS Independent
Classifier: Operating System :: POSIX :: Linux
Classifier: Operating System :: Microsoft :: Windows
Classifier: Operating System :: MacOS
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Typing :: Typed
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE.txt
Dynamic: license-file
Dynamic: requires-python


# matrice_inference
|