matrice-inference 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of matrice-inference might be problematic. Click here for more details.
- matrice_inference/__init__.py +72 -0
- matrice_inference/py.typed +0 -0
- matrice_inference/server/__init__.py +23 -0
- matrice_inference/server/inference_interface.py +176 -0
- matrice_inference/server/model/__init__.py +1 -0
- matrice_inference/server/model/model_manager.py +274 -0
- matrice_inference/server/model/model_manager_wrapper.py +550 -0
- matrice_inference/server/model/triton_model_manager.py +290 -0
- matrice_inference/server/model/triton_server.py +1248 -0
- matrice_inference/server/proxy_interface.py +371 -0
- matrice_inference/server/server.py +1004 -0
- matrice_inference/server/stream/__init__.py +0 -0
- matrice_inference/server/stream/app_deployment.py +228 -0
- matrice_inference/server/stream/consumer_worker.py +201 -0
- matrice_inference/server/stream/frame_cache.py +127 -0
- matrice_inference/server/stream/inference_worker.py +163 -0
- matrice_inference/server/stream/post_processing_worker.py +230 -0
- matrice_inference/server/stream/producer_worker.py +147 -0
- matrice_inference/server/stream/stream_pipeline.py +451 -0
- matrice_inference/server/stream/utils.py +23 -0
- matrice_inference/tmp/abstract_model_manager.py +58 -0
- matrice_inference/tmp/aggregator/__init__.py +18 -0
- matrice_inference/tmp/aggregator/aggregator.py +330 -0
- matrice_inference/tmp/aggregator/analytics.py +906 -0
- matrice_inference/tmp/aggregator/ingestor.py +438 -0
- matrice_inference/tmp/aggregator/latency.py +597 -0
- matrice_inference/tmp/aggregator/pipeline.py +968 -0
- matrice_inference/tmp/aggregator/publisher.py +431 -0
- matrice_inference/tmp/aggregator/synchronizer.py +594 -0
- matrice_inference/tmp/batch_manager.py +239 -0
- matrice_inference/tmp/overall_inference_testing.py +338 -0
- matrice_inference/tmp/triton_utils.py +638 -0
- matrice_inference-0.1.2.dist-info/METADATA +28 -0
- matrice_inference-0.1.2.dist-info/RECORD +37 -0
- matrice_inference-0.1.2.dist-info/WHEEL +5 -0
- matrice_inference-0.1.2.dist-info/licenses/LICENSE.txt +21 -0
- matrice_inference-0.1.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""Module providing __init__ functionality."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
from matrice_common.utils import dependencies_check
|
|
6
|
+
|
|
7
|
+
# Core runtime dependencies verified at import time via
# matrice_common.utils.dependencies_check (presumably installing any that are
# missing — confirm against matrice_common).
# NOTE: PyPI project names are case-insensitive, so "pillow" is listed only once
# (the previous list also contained "Pillow", which only repeated the check).
base = [
    "httpx",
    "fastapi",
    "uvicorn",
    "pillow",
    "confluent_kafka[snappy]",
    "aiokafka",
    "aiohttp",
    "filterpy",
    "scipy",
    "scikit-learn",
    "matplotlib",
    "scikit-image",
    "python-snappy",
    "pyyaml",
    "imagehash",
    "transformers",
]

# Install base dependencies first
dependencies_check(base)
|
|
29
|
+
|
|
30
|
+
# Helper to attempt installation and verify importability
|
|
31
|
+
def _install_and_verify(pkg: str, import_name: str) -> bool:
    """Check/install a package expression and report whether its module imports.

    Args:
        pkg: Requirement string handed to ``dependencies_check`` (may carry an
            extras tag, e.g. ``"fast-plate-ocr[onnx]"``).
        import_name: Module name to import once ``dependencies_check`` succeeds.

    Returns:
        True only if ``dependencies_check`` reported success AND ``import_name``
        could be imported; False otherwise.
    """
    import importlib

    if not dependencies_check([pkg]):
        return False
    # A package installed after interpreter start-up can be invisible to the
    # import system's cached path finders until their caches are invalidated.
    importlib.invalidate_caches()
    try:
        importlib.import_module(import_name)
    except ImportError:
        return False
    return True
|
|
40
|
+
|
|
41
|
+
# OpenCV: try the full wheel first; the headless build is only checked when the
# full wheel's check does not succeed.
for _opencv_pkg in ("opencv-python", "opencv-python-headless"):
    if dependencies_check([_opencv_pkg]):
        break

# ONNX Runtime + fast-plate-ocr: the GPU pair is attempted first. A pair counts as
# usable only if every package in it installs and imports; all() stops at the
# first failure, exactly like a chain of ``and``.
_gpu_ok = all(
    _install_and_verify(package, module)
    for package, module in (
        ("onnxruntime-gpu", "onnxruntime"),
        ("fast-plate-ocr[onnx-gpu]", "fast_plate_ocr"),
    )
)

if not _gpu_ok:
    # CPU fallback.
    # NOTE(review): if onnxruntime-gpu succeeded but the GPU OCR extra did not,
    # this also installs the CPU onnxruntime wheel next to it; both are verified
    # via the same "onnxruntime" module name — confirm that mix is intended.
    _cpu_ok = all(
        _install_and_verify(package, module)
        for package, module in (
            ("onnxruntime", "onnxruntime"),
            ("fast-plate-ocr[onnx]", "fast_plate_ocr"),
        )
    )
    if not _cpu_ok:
        # Last-chance fallback without extras tag (PyPI sometimes lacks them)
        _install_and_verify("fast-plate-ocr", "fast_plate_ocr")
|
|
57
|
+
|
|
58
|
+
# The package directory stays on sys.path for backward compatibility with code
# that imports the bare top-level ``server`` package.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Package-relative imports bind these names to the ``matrice_inference.server``
# modules. Importing them through the bare ``server`` package instead would load
# the server modules a second time under a different module name
# (inference_interface itself imports ``matrice_inference.server...``), running
# server/__init__.py's root-logger setup twice and producing distinct class objects.
from .server.server import MatriceDeployServer  # noqa: E402
from .server.server import MatriceDeployServer as MatriceDeploy  # noqa: E402  # Keep this for backwards compatibility
from .server.inference_interface import InferenceInterface  # noqa: E402
from .server.proxy_interface import MatriceProxyInterface  # noqa: E402

__all__ = [
    "MatriceDeploy",
    "MatriceDeployServer",
    "InferenceInterface",
    "MatriceProxyInterface",
]
|
|
File without changes
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
# Handler names identify the handlers this module installs, so executing the module
# more than once (e.g. when imported under two package names) does not attach
# duplicate handlers to the root logger.
_CONSOLE_HANDLER_NAME = "matrice_deploy_console"
_FILE_HANDLER_NAME = "matrice_deploy_file"

# Root logger
_root_logger = logging.getLogger()

# Same level semantics as logging.basicConfig(level=logging.DEBUG): the root level
# is only changed when nobody has configured the root logger yet. basicConfig itself
# is not called because it attaches an extra stderr handler, which would print every
# console message a second time next to console_handler below.
if not _root_logger.handlers:
    _root_logger.setLevel(logging.DEBUG)


def _existing_root_handler(name):
    """Return the root-logger handler registered under *name*, or None."""
    for handler in _root_logger.handlers:
        if handler.get_name() == name:
            return handler
    return None


# Formatter shared by both handlers
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

# Console handler (INFO+)
console_handler = _existing_root_handler(_CONSOLE_HANDLER_NAME)
if console_handler is None:
    console_handler = logging.StreamHandler()
    console_handler.set_name(_CONSOLE_HANDLER_NAME)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    _root_logger.addHandler(console_handler)

# File handler (DEBUG+), writing deploy_server.log in the current working directory
log_path = os.path.join(os.getcwd(), "deploy_server.log")
file_handler = _existing_root_handler(_FILE_HANDLER_NAME)
if file_handler is None:
    file_handler = logging.FileHandler(log_path)
    file_handler.set_name(_FILE_HANDLER_NAME)
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)
    _root_logger.addHandler(file_handler)
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from matrice_inference.server.model.model_manager_wrapper import ModelManagerWrapper
|
|
2
|
+
from typing import Dict, Any, Optional, Tuple, Union
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
import logging
|
|
5
|
+
import time
|
|
6
|
+
from matrice_analytics.post_processing.post_processor import PostProcessor
|
|
7
|
+
|
|
8
|
+
class InferenceInterface:
    """Interface for proxying requests to model servers with optional post-processing."""

    def __init__(
        self,
        model_manager: "ModelManagerWrapper",
        post_processor: "Optional[PostProcessor]" = None,
    ):
        """
        Initialize the inference interface.

        Args:
            model_manager: Object whose ``inference(...)`` method runs the model and
                returns a ``(results, success)`` tuple.
            post_processor: Optional post-processor whose ``process(...)`` coroutine
                is awaited when post-processing is requested.
        """
        self.logger = logging.getLogger(__name__)
        self.model_manager = model_manager
        self.post_processor = post_processor
        # UTC wall-clock time of the most recent inference request
        # (initialised to construction time).
        self.latest_inference_time = datetime.now(timezone.utc)

    def get_latest_inference_time(self) -> datetime:
        """Return the UTC timestamp of the most recent inference request."""
        return self.latest_inference_time

    @staticmethod
    def _timing_metadata(
        model_inference_time: float, post_processing_time: float
    ) -> Dict[str, float]:
        """Build the ``timing_metadata`` dict attached to every inference response."""
        return {
            "model_inference_time_sec": model_inference_time,
            "post_processing_time_sec": post_processing_time,
            "total_time_sec": model_inference_time + post_processing_time,
        }

    def _format_post_processing_result(
        self,
        result: Any,
        raw_results: Any,
        stream_key: Optional[str],
        model_inference_time: float,
        post_processing_time: float,
    ) -> Tuple[Any, Dict[str, Any]]:
        """Turn a PostProcessor result object into the ``(results, metadata)`` response."""
        timing = self._timing_metadata(model_inference_time, post_processing_time)
        stream_label = stream_key or "default_stream"

        if not result.is_success():
            self.logger.error(f"Post-processing failed: {result.error_message}")
            return raw_results, {
                "status": "post_processing_failed",
                "error": result.error_message,
                "error_type": getattr(result, 'error_type', 'ProcessingError'),
                "processing_time": result.processing_time,
                "processed_data": raw_results,
                "stream_key": stream_label,
                "timing_metadata": timing,
            }

        usecase = getattr(result, 'usecase', '')
        # For the face recognition use case the raw model output is not returned.
        processed_raw_results = [] if usecase == 'face_recognition' else raw_results

        # agg_summary is taken from result.data when that is a dict.
        data = getattr(result, 'data', None)
        agg_summary = data.get("agg_summary", {}) if isinstance(data, dict) else {}

        return processed_raw_results, {
            "status": "success",
            "processing_time": result.processing_time,
            "usecase": usecase,
            "category": getattr(result, 'category', ''),
            "summary": getattr(result, 'summary', ''),
            "insights": getattr(result, 'insights', []),
            "metrics": getattr(result, 'metrics', {}),
            "predictions": getattr(result, 'predictions', []),
            "agg_summary": agg_summary,
            "stream_key": stream_label,
            "timing_metadata": timing,
        }

    async def inference(
        self,
        input: Any,
        extra_params: Optional[Dict[str, Any]] = None,
        apply_post_processing: bool = False,
        post_processing_config: Optional[Union[Dict[str, Any], str]] = None,
        stream_key: Optional[str] = None,
        stream_info: Optional[Dict[str, Any]] = None,
        camera_info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Any, Optional[Dict[str, Any]]]:
        """Perform inference using the appropriate client with optional post-processing.

        Args:
            input: Primary input data (e.g., image bytes, numpy array)
            extra_params: Additional parameters for inference (optional)
            apply_post_processing: Whether to apply post-processing
            post_processing_config: Configuration passed to the post-processor as-is
            stream_key: Unique identifier for the input stream
            stream_info: Additional metadata about the stream (optional)
            camera_info: Accepted for interface compatibility; currently not
                forwarded to the model manager or the post-processor.

        Returns:
            A tuple containing:
            - The inference results (raw, or ``[]`` for the face_recognition use case)
            - Metadata with ``timing_metadata``; when post-processing ran it also has
              a ``status`` of ``"success"`` or ``"post_processing_failed"``.

        Raises:
            ValueError: If ``input`` is None.
            RuntimeError: If model inference raises or reports failure.
        """
        if input is None:
            raise ValueError("Input cannot be None")

        # perf_counter is monotonic, so measured durations are unaffected by
        # wall-clock adjustments.
        model_start_time = time.perf_counter()

        # Update latest inference time
        self.latest_inference_time = datetime.now(timezone.utc)

        try:
            raw_results, success = self.model_manager.inference(
                input=input,
                extra_params=extra_params,
                stream_key=stream_key,
                stream_info=stream_info,
            )
        except Exception as exc:
            self.logger.error(f"Model inference failed: {str(exc)}", exc_info=True)
            raise RuntimeError(f"Model inference failed: {str(exc)}") from exc
        model_inference_time = time.perf_counter() - model_start_time

        # Checked outside the try above so the error is raised (and logged) once,
        # instead of being caught and re-wrapped with a doubled message.
        if not success:
            self.logger.error("Model inference failed")
            raise RuntimeError("Model inference failed")

        self.logger.debug(
            f"Model inference executed stream_key={stream_key} time={model_inference_time:.4f}s"
        )

        # If no post-processing requested, return raw results
        if not apply_post_processing or not self.post_processor:
            return raw_results, {
                "timing_metadata": self._timing_metadata(model_inference_time, 0.0)
            }

        post_processing_start_time = time.perf_counter()
        try:
            result = await self.post_processor.process(
                data=raw_results,
                config=post_processing_config,
                input_bytes=input if isinstance(input, bytes) else None,
                stream_key=stream_key,
                stream_info=stream_info,
            )
            post_processing_time = time.perf_counter() - post_processing_start_time
            # Formatting stays inside the try: a malformed result object is
            # reported as a post-processing failure rather than raised.
            return self._format_post_processing_result(
                result, raw_results, stream_key, model_inference_time, post_processing_time
            )
        except Exception as e:
            post_processing_time = time.perf_counter() - post_processing_start_time
            self.logger.error(f"Post-processing exception: {str(e)}", exc_info=True)
            return raw_results, {
                "status": "post_processing_failed",
                "error": str(e),
                "error_type": type(e).__name__,
                "processed_data": raw_results,
                "stream_key": stream_key or "default_stream",
                "timing_metadata": self._timing_metadata(
                    model_inference_time, post_processing_time
                ),
            }
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import gc
|
|
3
|
+
from typing import Tuple, Any, Optional, List, Callable, Dict
|
|
4
|
+
|
|
5
|
+
class ModelManager:
    """Minimal ModelManager that focuses on model lifecycle and prediction calls.

    Model instances come from a user-supplied ``load_model`` callable and are handed
    out round-robin to user-supplied ``predict`` / ``batch_predict`` callables. Each
    callable is wrapped so it only receives the arguments it declares.
    """

    def __init__(
        self,
        action_tracker: Any,
        load_model: Optional[Callable] = None,
        predict: Optional[Callable] = None,
        batch_predict: Optional[Callable] = None,
        num_model_instances: int = 1,
        model_path: Optional[str] = None,  # For local model loading testing
    ):
        """Initialize the ModelManager.

        Args:
            action_tracker: Tracker object. It is passed to ``load_model`` when that
                callable asks for it. If truthy, its ``update_prediction_results`` is
                applied to inference results.
            load_model: Callable returning a model instance.
            predict: Callable ``(model, input, ...)`` used by :meth:`inference`.
            batch_predict: Callable ``(model, inputs, ...)`` used by :meth:`batch_inference`.
            num_model_instances: Number of model instances to load up front.
            model_path: Path handed to ``load_model`` when it declares a
                ``model_path`` (or single ``path``) parameter.

        Raises:
            Exception: Anything raised while building the manager is logged and re-raised.
        """
        try:
            self.load_model = self._create_load_model_wrapper(load_model)
            self.predict = self._create_prediction_wrapper(predict)
            self.batch_predict = self._create_prediction_wrapper(batch_predict)
            self.action_tracker = action_tracker

            # Model instances
            self.model_instances = []
            self._round_robin_counter = 0
            self.model_path = model_path

            # scale_up() logs and swallows load failures, so fewer than
            # num_model_instances models may end up loaded.
            for _ in range(num_model_instances):
                self.scale_up()
        except Exception as e:
            logging.error(f"Failed to initialize ModelManager: {str(e)}")
            raise

    @staticmethod
    def _positional_param_names(func: Callable) -> Tuple[str, ...]:
        """Return the names of *func*'s parameters that can be passed positionally.

        ``inspect.signature`` is used so that bound methods (``self`` excluded) and
        ``functools.partial`` objects are described correctly. When no signature can
        be obtained, the raw code object is used, as before.
        """
        import inspect

        try:
            parameters = inspect.signature(func).parameters.values()
        except (TypeError, ValueError):
            code = func.__code__
            return tuple(code.co_varnames[: code.co_argcount])
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        return tuple(p.name for p in parameters if p.kind in positional_kinds)

    def _create_load_model_wrapper(self, load_model_func: Optional[Callable]):
        """Create a wrapper function that handles parameter passing to the load model function.

        Args:
            load_model_func: The load model function to wrap

        Returns:
            A zero-argument wrapper, or ``load_model_func`` itself when it is falsy
            (e.g. None).
        """
        if not load_model_func:
            return load_model_func

        def wrapper():
            """Call the load model function with the arguments it declares.

            Raises:
                RuntimeError: If the load model function raises.
            """
            try:
                param_names = self._positional_param_names(load_model_func)

                if len(param_names) == 1:
                    param_name = param_names[0]
                    if param_name in ('_', 'arg', 'args'):
                        # Anonymous single argument: pass the tracker if there is
                        # one, otherwise call with no arguments.
                        if self.action_tracker is not None:
                            return load_model_func(self.action_tracker)
                        return load_model_func()
                    if param_name in ("model_path", "path") and self.model_path is not None:
                        return load_model_func(self.model_path)
                    # Tracker-named parameters ("action_tracker", "actionTracker",
                    # "tracker") and any other single parameter get the tracker.
                    return load_model_func(self.action_tracker)

                # Zero or several parameters: pass by keyword only what is declared.
                filtered_params = {}
                if self.action_tracker is not None:
                    if "action_tracker" in param_names:
                        filtered_params["action_tracker"] = self.action_tracker
                    elif "actionTracker" in param_names:
                        filtered_params["actionTracker"] = self.action_tracker
                if "model_path" in param_names and self.model_path is not None:
                    filtered_params["model_path"] = self.model_path
                return load_model_func(**filtered_params)

            except Exception as e:
                error_msg = f"Load model function execution failed: {str(e)}"
                logging.error(error_msg, exc_info=True)
                raise RuntimeError(error_msg) from e

        return wrapper

    def scale_up(self):
        """Load one more model instance; return True on success, False (logged) on failure."""
        try:
            self.model_instances.append(self.load_model())
            return True
        except Exception as e:
            logging.error(f"Failed to scale up model: {str(e)}")
            return False

    @staticmethod
    def _empty_cuda_cache():
        """Release cached CUDA memory when torch with CUDA is available; no-op without torch."""
        try:
            import torch
        except ImportError:
            return
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def scale_down(self):
        """Drop the most recently loaded model instance (scale down).

        Returns:
            True if an instance was removed (or none existed), False if an error occurred.
        """
        if not self.model_instances:
            return True
        try:
            del self.model_instances[-1]
            gc.collect()
            # A missing torch install is not a scale-down failure.
            self._empty_cuda_cache()
            return True
        except Exception as e:
            logging.error(f"Failed to scale down model: {str(e)}")
            return False

    def get_model(self):
        """Return the next model instance in round-robin order, or None if none are loaded.

        A slot holding None (e.g. a loader that returned None) is reloaded first.
        """
        if not self.model_instances:
            logging.warning("No model instances available")
            return None

        order = self._round_robin_counter % len(self.model_instances)
        model = self.model_instances[order]
        # Identity check: a valid model object may define __bool__/__len__ and be falsy.
        if model is None:
            logging.error("No model instance found, will try to load model")
            model = self.load_model()
            self.model_instances[order] = model

        # Increment counter for next call
        self._round_robin_counter = (self._round_robin_counter + 1) % len(
            self.model_instances
        )
        return model

    @staticmethod
    def _normalize_extra_params(extra_params: Any) -> Dict[str, Any]:
        """Coerce ``extra_params`` into a dict.

        None becomes ``{}``. A list of dicts is merged in order, and any other list
        becomes ``{}``. Any other non-dict type also becomes ``{}``. Each coercion is logged.
        """
        if extra_params is None:
            return {}
        if isinstance(extra_params, list):
            logging.warning(f"extra_params received as list instead of dict, converting: {extra_params}")
            if len(extra_params) == 0:
                return {}
            if all(isinstance(item, dict) for item in extra_params):
                merged_params = {}
                for item in extra_params:
                    merged_params.update(item)
                return merged_params
            logging.error(f"Cannot convert extra_params list to dict: {extra_params}")
            return {}
        if not isinstance(extra_params, dict):
            logging.warning(f"extra_params is not a dict, using empty dict instead. Received: {type(extra_params)}")
            return {}
        return extra_params

    def _create_prediction_wrapper(self, predict_func: Optional[Callable]):
        """Create a wrapper function that handles parameter passing to the prediction function.

        Args:
            predict_func: The prediction function to wrap

        Returns:
            A wrapper ``(model, input, extra_params, stream_key, stream_info)``, or
            None when ``predict_func`` is falsy. Returning None lets callers detect
            a missing function.
        """
        if not predict_func:
            return None

        def wrapper(
            model,
            input: Any,
            extra_params: Optional[Dict[str, Any]] = None,
            stream_key: Optional[str] = None,
            stream_info: Optional[Dict[str, Any]] = None,
        ) -> Any:
            """Call the prediction function with ``model``, ``input`` and the keywords it declares.

            Raises:
                RuntimeError: If the prediction function raises.
            """
            try:
                params = self._normalize_extra_params(extra_params)
                param_names = self._positional_param_names(predict_func)
                # The first two parameters are filled positionally with model/input;
                # forwarding them again by keyword would raise a TypeError.
                positional = set(param_names[:2])
                filtered_params = {
                    k: v for k, v in params.items()
                    if k in param_names and k not in positional
                }

                # Add stream_key / stream_info if the function accepts them (regardless of value)
                if "stream_key" in param_names:
                    filtered_params["stream_key"] = stream_key
                if "stream_info" in param_names:
                    filtered_params["stream_info"] = stream_info

                return predict_func(model, input, **filtered_params)

            except Exception as e:
                error_msg = f"Prediction function execution failed: {str(e)}"
                logging.error(error_msg, exc_info=True)
                raise RuntimeError(error_msg) from e

        return wrapper

    def inference(
        self,
        input: Any,
        extra_params: Optional[Dict[str, Any]] = None,
        stream_key: Optional[str] = None,
        stream_info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Any, bool]:
        """Run inference on the provided input data.

        Args:
            input: Primary input data (can be image bytes or numpy array)
            extra_params: Additional parameters for inference (optional)
            stream_key: Stream key for the inference
            stream_info: Stream info for the inference

        Returns:
            Tuple of (results, success_flag); ``(None, False)`` on any failure.

        Raises:
            ValueError: If input data is None
        """
        if input is None:
            raise ValueError("Input data cannot be None")
        if self.predict is None:
            logging.error("Prediction function not found")
            return None, False

        try:
            model = self.get_model()
            results = self.predict(model, input, extra_params, stream_key, stream_info)
            if self.action_tracker:
                results = self.action_tracker.update_prediction_results(results)
            return results, True
        except Exception as e:
            logging.error(f"Inference failed: {str(e)}")
            return None, False

    def batch_inference(
        self,
        input: List[Any],
        extra_params: Optional[Dict[str, Any]] = None,
        stream_key: Optional[str] = None,
        stream_info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Any, bool]:
        """Run batch inference on the provided input data.

        Args:
            input: Sequence of inputs passed to ``batch_predict`` in one call
            extra_params: Additional parameters for inference (optional)
            stream_key: Stream key for the inference
            stream_info: Stream info for the inference

        Returns:
            Tuple of (results, success_flag); ``(None, False)`` on any failure or
            when no batch prediction function was configured.

        Raises:
            ValueError: If input data is None
        """
        if input is None:
            raise ValueError("Input data cannot be None")
        if self.batch_predict is None:
            logging.error("Batch prediction function not found")
            return None, False

        try:
            model = self.get_model()
            results = self.batch_predict(model, input, extra_params, stream_key, stream_info)
            if self.action_tracker:
                # NOTE(review): unlike inference(), the tracker's return value is
                # discarded here — confirm update_prediction_results mutates in place.
                for result in results:
                    self.action_tracker.update_prediction_results(result)
            return results, True
        except Exception as e:
            logging.error(f"Batch inference failed: {str(e)}")
            return None, False

    # TODO: Add multi model execution with torch.cuda.stream()
|