matrice-inference 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of matrice-inference might be problematic. Click here for more details.
- matrice_inference/__init__.py +72 -0
- matrice_inference/py.typed +0 -0
- matrice_inference/server/__init__.py +23 -0
- matrice_inference/server/inference_interface.py +176 -0
- matrice_inference/server/model/__init__.py +1 -0
- matrice_inference/server/model/model_manager.py +274 -0
- matrice_inference/server/model/model_manager_wrapper.py +550 -0
- matrice_inference/server/model/triton_model_manager.py +290 -0
- matrice_inference/server/model/triton_server.py +1248 -0
- matrice_inference/server/proxy_interface.py +371 -0
- matrice_inference/server/server.py +1004 -0
- matrice_inference/server/stream/__init__.py +0 -0
- matrice_inference/server/stream/app_deployment.py +228 -0
- matrice_inference/server/stream/consumer_worker.py +201 -0
- matrice_inference/server/stream/frame_cache.py +127 -0
- matrice_inference/server/stream/inference_worker.py +163 -0
- matrice_inference/server/stream/post_processing_worker.py +230 -0
- matrice_inference/server/stream/producer_worker.py +147 -0
- matrice_inference/server/stream/stream_pipeline.py +451 -0
- matrice_inference/server/stream/utils.py +23 -0
- matrice_inference/tmp/abstract_model_manager.py +58 -0
- matrice_inference/tmp/aggregator/__init__.py +18 -0
- matrice_inference/tmp/aggregator/aggregator.py +330 -0
- matrice_inference/tmp/aggregator/analytics.py +906 -0
- matrice_inference/tmp/aggregator/ingestor.py +438 -0
- matrice_inference/tmp/aggregator/latency.py +597 -0
- matrice_inference/tmp/aggregator/pipeline.py +968 -0
- matrice_inference/tmp/aggregator/publisher.py +431 -0
- matrice_inference/tmp/aggregator/synchronizer.py +594 -0
- matrice_inference/tmp/batch_manager.py +239 -0
- matrice_inference/tmp/overall_inference_testing.py +338 -0
- matrice_inference/tmp/triton_utils.py +638 -0
- matrice_inference-0.1.2.dist-info/METADATA +28 -0
- matrice_inference-0.1.2.dist-info/RECORD +37 -0
- matrice_inference-0.1.2.dist-info/WHEEL +5 -0
- matrice_inference-0.1.2.dist-info/licenses/LICENSE.txt +21 -0
- matrice_inference-0.1.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,550 @@
|
|
|
1
|
+
from typing import Tuple, Dict, Any, Optional, List, Union, Callable
|
|
2
|
+
import logging
|
|
3
|
+
from matrice.action_tracker import ActionTracker
|
|
4
|
+
from matrice_inference.server.model.model_manager import ModelManager
|
|
5
|
+
from matrice_inference.server.model.triton_model_manager import TritonModelManager
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
class ModelManagerWrapper:
    """Wrapper class for ModelManager and TritonModelManager to provide a unified interface.

    The configuration is resolved either from the supplied ``action_tracker``
    (``test_env=False``) or directly from the constructor arguments
    (``test_env=True``). The wrapper then builds a ``ModelManager``
    (``model_type="default"``) or a ``TritonModelManager`` (``model_type="triton"``)
    and forwards the inference calls to it.
    """

    # Fallback values applied by ``__init__`` to any key that is missing or None.
    # TODO: finalize the production defaults.
    # NOTE(review): model_id / model_name / model_path / num_classes differ from
    # ``_EXTRACTION_DEFAULTS``; both sets are kept unchanged to preserve behavior.
    _INIT_DEFAULTS: Dict[str, Any] = {
        "model_id": "",
        "internal_server_type": "rest",
        "internal_port": 8000,
        "internal_host": "localhost",
        "num_model_instances": 1,
        "load_model": None,
        "predict": None,
        "batch_predict": None,
        "model_name": "",
        "model_path": "",
        "runtime_framework": "onnx",
        "input_size": 640,
        "num_classes": 80,
        "use_dynamic_batching": False,
        "max_batch_size": 8,
        "is_yolo": False,
        "is_ocr": False,
        "use_trt_accelerator": False,
    }

    # Lowest-priority fallbacks used by ``_extract_config_from_action_tracker``.
    _EXTRACTION_DEFAULTS: Dict[str, Any] = {
        "model_id": "default_model",
        "internal_server_type": "rest",
        "internal_port": 8000,
        "internal_host": "localhost",
        "num_model_instances": 1,
        "load_model": None,
        "predict": None,
        "batch_predict": None,
        "model_name": "default_model",
        "model_path": "/models/default",
        "runtime_framework": "onnx",
        "input_size": 640,
        "num_classes": 10,
        "use_dynamic_batching": False,
        "max_batch_size": 8,
        "is_yolo": False,
        "is_ocr": False,
        "use_trt_accelerator": False,
    }

    # Default internal port for each supported server protocol.
    _PROTOCOL_TO_PORT: Dict[str, int] = {"rest": 8000, "grpc": 8001}

    def __init__(
        self,
        action_tracker: ActionTracker,
        test_env: bool = False,
        # "default" for ModelManager, "triton" for TritonModelManager
        model_type: str = "default",
        model_id: Optional[str] = None,
        internal_server_type: Optional[str] = None,
        internal_port: Optional[int] = None,
        internal_host: Optional[str] = None,
        num_model_instances: Optional[int] = None,
        load_model: Optional[Callable] = None,
        predict: Optional[Callable] = None,
        batch_predict: Optional[Callable] = None,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
        runtime_framework: Optional[str] = None,
        input_size: Optional[Union[int, List[int]]] = None,
        num_classes: Optional[int] = None,
        use_dynamic_batching: Optional[bool] = None,
        max_batch_size: Optional[int] = None,
        is_yolo: Optional[bool] = None,
        is_ocr: Optional[bool] = None,
        use_trt_accelerator: Optional[bool] = None,
    ):
        """
        Initialize the ModelManagerWrapper.

        Args:
            action_tracker: Action tracker for category mapping and configuration.
            test_env: If True, use provided parameters for testing; if False, extract from action_tracker.
            model_type: Type of model manager ("default" for ModelManager, "triton" for TritonModelManager).
                Matched case-insensitively; a falsy value means "default".
            model_id: ID of the model (for ModelManager).
            internal_server_type: Type of internal server (e.g., "rest", "grpc").
            internal_port: Internal port number.
            internal_host: Internal host address.
            num_model_instances: Number of model instances to create.
            load_model: Function to load the model (for ModelManager).
            predict: Function to run predictions (for ModelManager).
            batch_predict: Function to run batch predictions (for ModelManager).
            model_name: Name of the model (for TritonModelManager).
            model_path: Path to the model (for TritonModelManager).
            runtime_framework: Runtime framework for the model (for TritonModelManager).
            input_size: Input size for the model (for TritonModelManager).
            num_classes: Number of classes for the model (for TritonModelManager).
            use_dynamic_batching: Whether to use dynamic batching (for TritonModelManager).
            max_batch_size: Maximum batch size (for TritonModelManager).
            is_yolo: Whether the model is YOLO (for TritonModelManager).
            is_ocr: Whether the model is OCR (for TritonModelManager).
            use_trt_accelerator: Whether to use TensorRT accelerator (for TritonModelManager).

        Raises:
            ValueError: If ``model_type`` is not "default" or "triton"; if a required
                Triton parameter (model_name, model_path, runtime_framework) resolves
                to an empty value; or if ``action_tracker`` is falsy for the default
                ModelManager.
        """
        self.logger = logging.getLogger(__name__)
        self.action_tracker = action_tracker
        self.test_env = test_env
        self.model_type = model_type.lower() if model_type else "default"

        # Validate model_type
        if self.model_type not in ("default", "triton"):
            raise ValueError(f"Invalid model_type '{self.model_type}'. Must be 'default' or 'triton'")

        default_config = self._INIT_DEFAULTS

        if not test_env:
            # Production: resolve from action_tracker, then user args, then defaults.
            config = self._extract_config_from_action_tracker(
                model_type=model_type,
                model_id=model_id,
                internal_server_type=internal_server_type,
                internal_port=internal_port,
                internal_host=internal_host,
                num_model_instances=num_model_instances,
                load_model=load_model,
                predict=predict,
                batch_predict=batch_predict,
                model_name=model_name,
                model_path=model_path,
                runtime_framework=runtime_framework,
                input_size=input_size,
                num_classes=num_classes,
                use_dynamic_batching=use_dynamic_batching,
                max_batch_size=max_batch_size,
                is_yolo=is_yolo,
                is_ocr=is_ocr,
                use_trt_accelerator=use_trt_accelerator,
            )
            if not config:
                self.logger.warning("No valid configuration found in action_tracker, using defaults")
                # Copy so the class-level defaults can never be mutated through `config`.
                config = dict(default_config)
            else:
                for key, value in default_config.items():
                    if key not in config or config[key] is None:
                        self.logger.warning(f"Missing or None config key '{key}' in action_tracker, using default: {value}")
                        config[key] = value
        else:
            # User provided args for testing
            config = {
                "model_id": model_id,
                "internal_server_type": internal_server_type,
                "internal_port": internal_port,
                "internal_host": internal_host,
                "num_model_instances": num_model_instances,
                "load_model": load_model,
                "predict": predict,
                "batch_predict": batch_predict,
                "model_name": model_name,
                "model_path": model_path,
                "runtime_framework": runtime_framework,
                "input_size": input_size,
                "num_classes": num_classes,
                "use_dynamic_batching": use_dynamic_batching,
                "max_batch_size": max_batch_size,
                "is_yolo": is_yolo,
                "is_ocr": is_ocr,
                "use_trt_accelerator": use_trt_accelerator,
            }
            for key, value in default_config.items():
                if config[key] is None:
                    self.logger.warning(f"Missing or None config key '{key}' in test environment, using default: {value}")
                    config[key] = value

        if self.model_type == "triton":
            # Validate required parameters for TritonModelManager
            for param in ("model_name", "model_path", "runtime_framework"):
                if not config.get(param):
                    raise ValueError(f"Required parameter '{param}' is missing or invalid for Triton model manager")

            self.model_manager = TritonModelManager(
                model_name=config["model_name"],
                model_path=config["model_path"],
                runtime_framework=config["runtime_framework"],
                internal_server_type=config["internal_server_type"],
                internal_port=config["internal_port"],
                internal_host=config["internal_host"],
                input_size=config["input_size"],
                num_classes=config["num_classes"],
                num_model_instances=config["num_model_instances"],
                use_dynamic_batching=config["use_dynamic_batching"],
                max_batch_size=config["max_batch_size"],
                is_yolo=config["is_yolo"],
                is_ocr=config["is_ocr"],
                use_trt_accelerator=config["use_trt_accelerator"],
            )
        else:
            # Validate required parameters for ModelManager
            if not self.action_tracker:
                raise ValueError("action_tracker is required for default ModelManager")

            # Only warns (does not raise) when both callables are missing.
            if not config.get("predict") and not config.get("load_model"):
                self.logger.warning("No prediction functions provided for ModelManager. At least 'predict' and 'load_model' should be provided.")

            self.model_manager = ModelManager(
                action_tracker=self.action_tracker,
                load_model=config["load_model"],
                predict=config["predict"],
                batch_predict=config.get("batch_predict"),
                num_model_instances=config["num_model_instances"],
                model_path=config.get("model_path"),
            )

        self.logger.info(f"Initialized ModelManagerWrapper with {self.model_type} model manager")

    def _resolve_param(
        self,
        key: str,
        action_tracker_sources: List[Any],
        user_value: Optional[Any],
        default_value: Any,
    ) -> Any:
        """Resolve one config value with priority action_tracker > user argument > default.

        Args:
            key: Config key. It is looked up in dict sources and used in log messages.
            action_tracker_sources: Candidates in priority order. A dict entry is
                looked up by ``key``. Any other entry is used as the candidate value
                itself. The first non-None candidate wins.
            user_value: Value passed by the caller. It is used if no action_tracker
                candidate is non-None.
            default_value: The final fallback.

        Returns:
            The resolved value. A warning is logged whenever it does not come
            from the action_tracker.
        """
        value = None
        source = None
        for src in action_tracker_sources:
            value = src.get(key) if isinstance(src, dict) else src
            if value is not None:
                source = "action_tracker"
                break

        # Nothing usable in action_tracker: fall back to the caller's argument.
        if value is None and user_value is not None:
            value = user_value
            source = "user-provided"

        if value is None:
            value = default_value
            source = "default"

        if source != "action_tracker":
            self.logger.warning(
                f"Config key '{key}' not found in action_tracker, using {source} value: {value}"
            )
        return value

    def _extract_config_from_action_tracker(
        self,
        model_type: Optional[str] = None,
        model_id: Optional[str] = None,
        internal_server_type: Optional[str] = None,
        internal_port: Optional[int] = None,
        internal_host: Optional[str] = None,
        num_model_instances: Optional[int] = None,
        load_model: Optional[Callable] = None,
        predict: Optional[Callable] = None,
        batch_predict: Optional[Callable] = None,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
        runtime_framework: Optional[str] = None,
        input_size: Optional[Union[int, List[int]]] = None,
        num_classes: Optional[int] = None,
        use_dynamic_batching: Optional[bool] = None,
        max_batch_size: Optional[int] = None,
        is_yolo: Optional[bool] = None,
        is_ocr: Optional[bool] = None,
        use_trt_accelerator: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        Extract configuration from action_tracker for production use.

        Each value is taken from the action_tracker first, then from the
        user-provided argument, and finally from ``_EXTRACTION_DEFAULTS``. A
        warning is logged whenever the method falls back to user arguments or
        defaults. Only the common and ModelManager keys are returned unless
        ``self.model_type == "triton"``.

        Args:
            model_type: Accepted for interface compatibility. The branching uses ``self.model_type``.
            model_id: User-provided model ID.
            internal_server_type: User-provided server type.
            internal_port: User-provided port.
            internal_host: User-provided host.
            num_model_instances: User-provided number of model instances.
            load_model: User-provided load model function (passed through as-is).
            predict: User-provided predict function (passed through as-is).
            batch_predict: User-provided batch predict function (passed through as-is).
            model_name: User-provided model name.
            model_path: User-provided model path.
            runtime_framework: User-provided runtime framework.
            input_size: User-provided input size.
            num_classes: User-provided number of classes.
            use_dynamic_batching: User-provided dynamic batching flag.
            max_batch_size: User-provided max batch size.
            is_yolo: User-provided YOLO flag.
            is_ocr: User-provided OCR flag.
            use_trt_accelerator: User-provided TensorRT accelerator flag.

        Returns:
            The configuration dictionary. It is empty if extraction failed unexpectedly.
        """
        try:
            defaults = self._EXTRACTION_DEFAULTS
            # `or {}` also covers attributes that exist but are None.
            job_params = getattr(self.action_tracker, "job_params", None) or {}
            action_details = getattr(self.action_tracker, "action_details", None) or {}
            action_tracker_sources: List[Dict[str, Any]] = [action_details, job_params]

            config: Dict[str, Any] = {}

            # Common params for both ModelManager and TritonModelManager
            config["model_id"] = self._resolve_param(
                "model_id",
                [
                    getattr(self.action_tracker, "_idModel_str", None),
                    action_details.get("_idModelDeploy"),
                    action_details.get("_idModel"),
                ],
                model_id,
                defaults["model_id"],
            )

            # Only a server_type actually present in action_details counts as an
            # action_tracker value. Otherwise the user argument / default must apply.
            action_server_type = action_details.get("server_type")
            if isinstance(action_server_type, str):
                action_server_type = action_server_type.lower()
            config["internal_server_type"] = self._resolve_param(
                "internal_server_type",
                [action_server_type],
                internal_server_type,
                defaults["internal_server_type"],
            )

            # action_tracker implies a port only when it specifies a server type.
            # Otherwise the user port wins, then the port for the resolved protocol.
            action_port = (
                self._PROTOCOL_TO_PORT.get(action_server_type, defaults["internal_port"])
                if action_server_type is not None
                else None
            )
            config["internal_port"] = self._resolve_param(
                "internal_port",
                [action_port],
                internal_port,
                self._PROTOCOL_TO_PORT.get(config["internal_server_type"], defaults["internal_port"]),
            )

            config["internal_host"] = self._resolve_param(
                "internal_host",
                [action_details.get("host")],
                internal_host,
                defaults["internal_host"],
            )

            config["num_model_instances"] = self._resolve_param(
                "num_model_instances",
                [action_details.get("num_model_instances")],
                num_model_instances,
                defaults["num_model_instances"],
            )

            # ModelManager-specific parameters: callables only come from the caller.
            config["load_model"] = load_model
            config["predict"] = predict
            config["batch_predict"] = batch_predict

            if self.model_type != "triton":
                return config

            # TritonModelManager-specific parameters
            config["model_name"] = self._resolve_param(
                "model_name",
                [action_details.get("modelKey")],
                model_name,
                defaults["model_name"],
            )

            config["model_path"] = self._resolve_param(
                "model_path",
                [getattr(self.action_tracker, "checkpoint_path", None)],
                model_path,
                defaults["model_path"],
            )

            config["runtime_framework"] = self._resolve_param(
                "runtime_framework",
                [action_details.get("runtimeFramework"), action_details.get("exportFormat")],
                runtime_framework,
                defaults["runtime_framework"],
            )

            config["input_size"] = self._resolve_param(
                "input_size",
                # [self.action_tracker.get_input_size()], TODO: Enable after the API is working
                action_tracker_sources,
                input_size,
                defaults["input_size"],
            )

            # A failing category lookup should only cost us num_classes, not the
            # whole configuration.
            num_classes_action = None
            try:
                index_to_category = self.action_tracker.get_index_to_category(self.action_tracker.is_exported)
                num_classes_action = len(index_to_category) if index_to_category else None
            except Exception as e:
                self.logger.warning(f"Could not read index_to_category from action_tracker: {str(e)}")
            config["num_classes"] = self._resolve_param(
                "num_classes",
                [num_classes_action],
                num_classes,
                defaults["num_classes"],
            )

            # No action_tracker source exists yet for these flags.
            config["use_dynamic_batching"] = self._resolve_param(
                "use_dynamic_batching", [], use_dynamic_batching, defaults["use_dynamic_batching"]
            )
            config["max_batch_size"] = self._resolve_param(
                "max_batch_size", [], max_batch_size, defaults["max_batch_size"]
            )
            config["is_yolo"] = self._resolve_param("is_yolo", [], is_yolo, defaults["is_yolo"])
            config["is_ocr"] = self._resolve_param("is_ocr", [], is_ocr, defaults["is_ocr"])

            config["use_trt_accelerator"] = self._resolve_param(
                "use_trt_accelerator",
                [action_details.get("use_trt_accelerator"), job_params.get("use_trt_accelerator")],
                use_trt_accelerator,
                defaults["use_trt_accelerator"],
            )

            return config
        except Exception as e:
            self.logger.error(f"Failed to extract config from action_tracker: {str(e)}", exc_info=True)
            return {}

    def inference(
        self,
        input: Any,
        extra_params: Optional[Dict[str, Any]] = None,
        stream_key: Optional[str] = None,
        stream_info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Any, bool]:
        """
        Perform synchronous single inference.

        For the Triton manager, only ``input`` is forwarded. The other arguments
        are ignored.

        Returns:
            Whatever the underlying manager's ``inference`` returns, or
            ``(None, False)`` if that call raised.

        Raises:
            ValueError: If ``input`` is None.
        """
        if input is None:
            raise ValueError("input cannot be None")

        try:
            if self.model_type == "triton":
                # TritonModelManager only accepts input parameter
                return self.model_manager.inference(input=input)
            # ModelManager accepts additional parameters
            return self.model_manager.inference(
                input=input,
                extra_params=extra_params,
                stream_key=stream_key,
                stream_info=stream_info,
            )
        except Exception as e:
            self.logger.error(f"Inference failed in ModelManagerWrapper: {str(e)}", exc_info=True)
            return None, False

    async def async_inference(
        self,
        input: Union[bytes, np.ndarray],
        extra_params: Optional[Dict[str, Any]] = None,
        stream_key: Optional[str] = None,
        stream_info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Any, bool]:
        """
        Perform asynchronous single inference.

        Triton awaits the manager's ``async_inference``. The default manager falls
        back to its synchronous ``inference``.

        NOTE(review): that call runs directly inside this coroutine, so it blocks
        the event loop for its duration.

        Returns:
            The underlying manager's result, or ``(None, False)`` if it raised.

        Raises:
            ValueError: If ``input`` is None.
        """
        if input is None:
            raise ValueError("input cannot be None")

        try:
            if self.model_type == "triton":
                return await self.model_manager.async_inference(input=input)
            # ModelManager doesn't have async_inference, fallback to sync
            return self.model_manager.inference(
                input=input,
                extra_params=extra_params,
                stream_key=stream_key,
                stream_info=stream_info,
            )
        except Exception as e:
            self.logger.error(f"Async inference failed in ModelManagerWrapper: {str(e)}", exc_info=True)
            return None, False

    def batch_inference(
        self,
        input: List[Any],
        extra_params: Optional[Dict[str, Any]] = None,
        stream_key: Optional[str] = None,
        stream_info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[Any], bool]:
        """
        Perform synchronous batch inference.

        Returns:
            The underlying manager's ``batch_inference`` result, or ``([], False)``
            if that call raised.

        Raises:
            ValueError: If ``input`` is None or empty.
        """
        if not input:
            raise ValueError("input cannot be None or empty")

        try:
            if self.model_type == "triton":
                # TritonModelManager only accepts input parameter
                return self.model_manager.batch_inference(input=input)
            # ModelManager accepts additional parameters
            return self.model_manager.batch_inference(
                input=input,
                extra_params=extra_params,
                stream_key=stream_key,
                stream_info=stream_info,
            )
        except Exception as e:
            self.logger.error(f"Batch inference failed in ModelManagerWrapper: {str(e)}", exc_info=True)
            return [], False

    async def async_batch_inference(
        self,
        input: List[Any],
        extra_params: Optional[Dict[str, Any]] = None,
        stream_key: Optional[str] = None,
        stream_info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[Any], bool]:
        """
        Perform asynchronous batch inference.

        Triton awaits ``async_batch_inference``. The default manager falls back to
        its synchronous ``batch_inference``.

        NOTE(review): that call blocks the event loop while it runs.

        Returns:
            The underlying manager's result, or ``([], False)`` if it raised.

        Raises:
            ValueError: If ``input`` is None or empty.
        """
        if not input:
            raise ValueError("input cannot be None or empty")

        try:
            if self.model_type == "triton":
                return await self.model_manager.async_batch_inference(input=input)
            # ModelManager doesn't have async_batch_inference, fallback to sync
            return self.model_manager.batch_inference(
                input=input,
                extra_params=extra_params,
                stream_key=stream_key,
                stream_info=stream_info,
            )
        except Exception as e:
            self.logger.error(f"Async batch inference failed in ModelManagerWrapper: {str(e)}", exc_info=True)
            return [], False
|