matrice-inference 0.1.2-py3-none-any.whl → 0.1.22-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of matrice-inference has been flagged as possibly problematic by the registry.
- matrice_inference/__init__.py +40 -23
- matrice_inference/server/__init__.py +17 -11
- matrice_inference/server/model/triton_server.py +1 -3
- matrice_inference/server/server.py +3 -4
- matrice_inference/server/stream/consumer_worker.py +398 -141
- matrice_inference/server/stream/frame_cache.py +149 -54
- matrice_inference/server/stream/inference_worker.py +183 -94
- matrice_inference/server/stream/post_processing_worker.py +246 -181
- matrice_inference/server/stream/producer_worker.py +155 -98
- matrice_inference/server/stream/stream_pipeline.py +220 -248
- matrice_inference/tmp/aggregator/analytics.py +1 -1
- matrice_inference/tmp/overall_inference_testing.py +0 -4
- {matrice_inference-0.1.2.dist-info → matrice_inference-0.1.22.dist-info}/METADATA +1 -1
- {matrice_inference-0.1.2.dist-info → matrice_inference-0.1.22.dist-info}/RECORD +17 -17
- {matrice_inference-0.1.2.dist-info → matrice_inference-0.1.22.dist-info}/WHEEL +0 -0
- {matrice_inference-0.1.2.dist-info → matrice_inference-0.1.22.dist-info}/licenses/LICENSE.txt +0 -0
- {matrice_inference-0.1.2.dist-info → matrice_inference-0.1.22.dist-info}/top_level.txt +0 -0
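To reproduce this comparison locally, the sketch below diffs a single module between the two wheels using only the Python standard library. The wheel file names and the member path are assumptions based on the file list above, and the wheels are assumed to have been downloaded beforehand (for example with "pip download matrice-inference==0.1.2 --no-deps", and the same for 0.1.22).

# Hypothetical reproduction helper; the wheel file names and member path are
# assumptions, not something published in this diff.
import difflib
import zipfile


def diff_wheel_member(old_whl: str, new_whl: str, member: str) -> str:
    """Return a unified diff of one file as it appears in two wheel archives."""

    def read_lines(whl_path: str) -> list:
        # Wheels are plain zip archives, so the member can be read directly.
        with zipfile.ZipFile(whl_path) as wheel:
            return wheel.read(member).decode("utf-8").splitlines(keepends=True)

    return "".join(
        difflib.unified_diff(
            read_lines(old_whl),
            read_lines(new_whl),
            fromfile=f"{old_whl}:{member}",
            tofile=f"{new_whl}:{member}",
        )
    )


if __name__ == "__main__":
    print(
        diff_wheel_member(
            "matrice_inference-0.1.2-py3-none-any.whl",
            "matrice_inference-0.1.22-py3-none-any.whl",
            "matrice_inference/server/stream/frame_cache.py",
        )
    )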
matrice_inference/server/stream/frame_cache.py

@@ -1,50 +1,80 @@
 import logging
-import threading
 import queue
-
+import threading
+from typing import Optional
 
 try:
     import redis  # type: ignore
-except
+except ImportError:  # pragma: no cover
     redis = None  # type: ignore
 
 
 class RedisFrameCache:
-    """Non-blocking Redis cache for frames
+    """Non-blocking Redis cache for frames with optimized resource management.
 
     Stores base64 string content under key 'stream:frames:{frame_id}' with field 'frame'.
     Each insert sets or refreshes the TTL.
     """
 
+    DEFAULT_TTL_SECONDS = 300
+    DEFAULT_MAX_QUEUE = 10000
+    DEFAULT_WORKER_THREADS = 2
+    DEFAULT_CONNECT_TIMEOUT = 2.0
+    DEFAULT_SOCKET_TIMEOUT = 0.5
+    DEFAULT_HEALTH_CHECK_INTERVAL = 30
+    DEFAULT_PREFIX = "stream:frames:"
+
     def __init__(
         self,
         host: str = "localhost",
         port: int = 6379,
         db: int = 0,
-        password: str = None,
-        username: str = None,
-        ttl_seconds: int =
-        prefix: str =
-        max_queue: int =
-        worker_threads: int =
-        connect_timeout: float =
-        socket_timeout: float =
+        password: Optional[str] = None,
+        username: Optional[str] = None,
+        ttl_seconds: int = DEFAULT_TTL_SECONDS,
+        prefix: str = DEFAULT_PREFIX,
+        max_queue: int = DEFAULT_MAX_QUEUE,
+        worker_threads: int = DEFAULT_WORKER_THREADS,
+        connect_timeout: float = DEFAULT_CONNECT_TIMEOUT,
+        socket_timeout: float = DEFAULT_SOCKET_TIMEOUT,
     ) -> None:
-        self.logger = logging.getLogger(__name__
-        self.ttl_seconds = int(ttl_seconds)
+        self.logger = logging.getLogger(f"{__name__}.frame_cache")
+        self.ttl_seconds = max(1, int(ttl_seconds))
         self.prefix = prefix
-        self.queue: "queue.Queue" = queue.Queue(maxsize=max_queue)
-        self.threads = []
         self.running = False
-        self._client = None
         self._worker_threads = max(1, int(worker_threads))
 
+        self.queue: queue.Queue = queue.Queue(maxsize=max_queue)
+        self.threads: list = []
+        self._client: Optional[redis.Redis] = None
+
+        if not self._is_redis_available():
+            return
+
+        self._client = self._create_redis_client(
+            host, port, db, password, username, connect_timeout, socket_timeout
+        )
+
+    def _is_redis_available(self) -> bool:
+        """Check if Redis package is available."""
         if redis is None:
             self.logger.warning("redis package not installed; frame caching disabled")
-            return
+            return False
+        return True
 
+    def _create_redis_client(
+        self,
+        host: str,
+        port: int,
+        db: int,
+        password: Optional[str],
+        username: Optional[str],
+        connect_timeout: float,
+        socket_timeout: float
+    ) -> Optional[redis.Redis]:
+        """Create Redis client with proper error handling."""
         try:
-
+            return redis.Redis(
                 host=host,
                 port=port,
                 db=db,

@@ -52,76 +82,141 @@ class RedisFrameCache:
                 username=username,
                 socket_connect_timeout=connect_timeout,
                 socket_timeout=socket_timeout,
-                health_check_interval=
+                health_check_interval=self.DEFAULT_HEALTH_CHECK_INTERVAL,
                 retry_on_timeout=True,
-                decode_responses=True,
+                decode_responses=True,
             )
         except Exception as e:
-            self.logger.warning("Failed to
-
+            self.logger.warning(f"Failed to initialize Redis client: {e}")
+            return None
 
     def start(self) -> None:
+        """Start the frame cache with worker threads."""
         if not self._client or self.running:
             return
+
         self.running = True
+        self._start_worker_threads()
+
+    def _start_worker_threads(self) -> None:
+        """Start worker threads for processing cache operations."""
         for i in range(self._worker_threads):
-
-
-
+            thread = threading.Thread(
+                target=self._worker,
+                name=f"FrameCache-{i}",
+                daemon=True
+            )
+            thread.start()
+            self.threads.append(thread)
 
     def stop(self) -> None:
+        """Stop the frame cache and cleanup resources."""
         if not self.running:
             return
+
         self.running = False
+        self._stop_worker_threads()
+        self.threads.clear()
+
+    def _stop_worker_threads(self) -> None:
+        """Stop all worker threads gracefully."""
+        # Signal threads to stop
         for _ in self.threads:
             try:
                 self.queue.put_nowait(None)
-            except
+            except queue.Full:
                 pass
-
+
+        # Wait for threads to finish
+        for thread in self.threads:
             try:
-
-            except Exception:
-
-        self.threads.clear()
+                thread.join(timeout=2.0)
+            except Exception as e:
+                self.logger.warning(f"Error joining thread {thread.name}: {e}")
 
     def put(self, frame_id: str, base64_content: str) -> None:
         """Enqueue a cache write for the given frame.
 
-
-
+        Args:
+            frame_id: unique identifier for the frame
+            base64_content: base64-encoded image string
         """
-        if not self.
+        if not self._is_cache_ready():
             return
-
+
+        if not self._validate_input(frame_id, base64_content):
             return
+
         try:
             key = f"{self.prefix}{frame_id}"
             self.queue.put_nowait((key, base64_content))
         except queue.Full:
-
-
+            self._handle_queue_full(frame_id)
+
+    def _is_cache_ready(self) -> bool:
+        """Check if cache is ready for operations."""
+        return bool(self._client and self.running)
+
+    def _validate_input(self, frame_id: str, base64_content: str) -> bool:
+        """Validate input parameters."""
+        if not frame_id:
+            self.logger.warning("Empty frame_id provided")
+            return False
+        if not base64_content:
+            self.logger.warning("Empty base64_content provided")
+            return False
+        return True
+
+    def _handle_queue_full(self, frame_id: str) -> None:
+        """Handle queue full condition."""
+        self.logger.debug(f"Frame cache queue full; dropping frame_id={frame_id}")
 
     def _worker(self) -> None:
+        """Worker thread for processing cache operations."""
         while self.running:
-
-                item = self.queue.get(timeout=0.5)
-            except queue.Empty:
-                continue
+            item = self._get_work_item()
             if item is None:
+                continue
+            if self._is_stop_signal(item):
                 break
+
+            self._process_cache_item(item)
+
+    def _get_work_item(self) -> Optional[tuple]:
+        """Get work item from queue with timeout."""
+        try:
+            return self.queue.get(timeout=0.5)
+        except queue.Empty:
+            return None
+
+    def _is_stop_signal(self, item: tuple) -> bool:
+        """Check if item is a stop signal."""
+        return item is None
+
+    def _process_cache_item(self, item: tuple) -> None:
+        """Process a single cache item."""
+        try:
             key, base64_content = item
-
-
-
-
-
-
-
-
-
-
-
-
+            self._store_frame_data(key, base64_content)
+        except Exception as e:
+            self.logger.debug(f"Failed to process cache item: {e}")
+        finally:
+            self._mark_task_done()
+
+    def _store_frame_data(self, key: str, base64_content: str) -> None:
+        """Store frame data in Redis with TTL."""
+        try:
+            # Store base64 string in Redis hash field 'frame', then set TTL
+            self._client.hset(key, "frame", base64_content)
+            self._client.expire(key, self.ttl_seconds)
+        except Exception as e:
+            self.logger.debug(f"Failed to cache frame {key}: {e}")
+
+    def _mark_task_done(self) -> None:
+        """Mark queue task as done."""
+        try:
+            self.queue.task_done()
+        except Exception:
+            pass
 
 
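Taken together, the reworked RedisFrameCache keeps a small start/put/stop surface and the documented storage scheme: each frame is written to a Redis hash at 'stream:frames:{frame_id}' under the field 'frame', with the TTL refreshed on every insert. Below is a minimal usage sketch; the import path is inferred from the file list above, and the reader side is an assumption that simply mirrors that key/field layout with plain redis-py.

# Usage sketch only; the import path is assumed from the package layout above.
import base64
import time

import redis

from matrice_inference.server.stream.frame_cache import RedisFrameCache

cache = RedisFrameCache(host="localhost", port=6379, ttl_seconds=60)
cache.start()  # silently does nothing if redis-py or the server is unavailable

# Producer side: enqueue a frame; the write is flushed asynchronously by the
# worker threads and is dropped (with a debug log) if the queue is full.
frame_id = "camera-1-000042"
cache.put(frame_id, base64.b64encode(b"<jpeg bytes>").decode("ascii"))
time.sleep(0.2)  # give a worker thread a moment to flush the write

# Consumer side: read the frame back using the documented key/field layout.
client = redis.Redis(host="localhost", port=6379, decode_responses=True)
encoded = client.hget(f"stream:frames:{frame_id}", "frame")
if encoded is not None:
    frame_bytes = base64.b64decode(encoded)
    print(f"recovered {len(frame_bytes)} bytes before the TTL expires")

cache.stop()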
matrice_inference/server/stream/inference_worker.py

@@ -1,18 +1,25 @@
 import asyncio
-import
-import time
+import base64
 import logging
-import threading
 import queue
-
+import threading
+import time
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, Optional
 
 class InferenceWorker:
-    """Handles inference processing
-
-    def __init__(
-
-
+    """Handles inference processing with clean resource management and error handling."""
+
+    def __init__(
+        self,
+        worker_id: int,
+        inference_queue: queue.PriorityQueue,
+        postproc_queue: queue.PriorityQueue,
+        inference_executor: ThreadPoolExecutor,
+        message_timeout: float,
+        inference_timeout: float,
+        inference_interface: Optional[Any] = None
+    ):
         self.worker_id = worker_id
         self.inference_queue = inference_queue
         self.postproc_queue = postproc_queue

@@ -23,10 +30,14 @@ class InferenceWorker:
         self.running = False
         self.logger = logging.getLogger(f"{__name__}.inference.{worker_id}")
 
-    def start(self):
+    def start(self) -> threading.Thread:
         """Start the inference worker in a separate thread."""
         self.running = True
-        thread = threading.Thread(
+        thread = threading.Thread(
+            target=self._run,
+            name=f"InferenceWorker-{self.worker_id}",
+            daemon=False
+        )
         thread.start()
         return thread
 

@@ -34,72 +45,84 @@
         """Stop the inference worker."""
         self.running = False
 
-    def _run(self):
-        """Main inference dispatcher loop."""
+    def _run(self) -> None:
+        """Main inference dispatcher loop with proper error handling."""
         self.logger.info(f"Started inference worker {self.worker_id}")
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        try:
+            while self.running:
+                task = self._get_task_from_queue()
+                if task:
+                    self._process_inference_task(*task)
+        except Exception as e:
+            self.logger.error(f"Fatal error in inference worker: {e}")
+        finally:
+            self.logger.info(f"Inference worker {self.worker_id} stopped")
+
+    def _get_task_from_queue(self) -> Optional[tuple]:
+        """Get task from inference queue with timeout handling."""
+        try:
+            return self.inference_queue.get(timeout=self.message_timeout)
+        except queue.Empty:
+            return None
+        except Exception as e:
+            self.logger.error(f"Error getting task from queue: {e}")
+            return None
 
-    def _process_inference_task(self, priority: int, task_data: Dict[str, Any]):
-        """Process a single inference task."""
+    def _process_inference_task(self, priority: int, timestamp: float, task_data: Dict[str, Any]) -> None:
+        """Process a single inference task with proper error handling."""
         try:
-
-
-
+            if not self._validate_task_data(task_data):
+                return
+
             start_time = time.time()
-
-            result = future.result(timeout=self.inference_timeout)
+            result = self._execute_inference(task_data)
             processing_time = time.time() - start_time
-
+
             if result["success"]:
-
-
-
-                    "model_result": result["model_result"],
-                    "metadata": result["metadata"],
-                    "processing_time": processing_time,
-                    "input_stream": task_data["input_stream"],
-                    "stream_key": task_data["stream_key"],
-                    "camera_config": task_data["camera_config"]
-                }
-
-                # Add to post-processing queue with timestamp as tie-breaker
+                postproc_task = self._create_postprocessing_task(
+                    task_data, result, processing_time
+                )
                 self.postproc_queue.put((priority, time.time(), postproc_task))
             else:
                 self.logger.error(f"Inference failed: {result['error']}")
-
+
         except Exception as e:
             self.logger.error(f"Inference task error: {e}")
 
+    def _validate_task_data(self, task_data: Dict[str, Any]) -> bool:
+        """Validate that task data contains required fields."""
+        required_fields = ["message", "input_stream", "stream_key", "camera_config"]
+        for field in required_fields:
+            if field not in task_data:
+                self.logger.error(f"Missing required field '{field}' in task data")
+                return False
+        return True
+
+    def _execute_inference(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Execute inference task in thread pool."""
+        future = self.inference_executor.submit(self._run_inference, task_data)
+        return future.result(timeout=self.inference_timeout)
+
+    def _create_postprocessing_task(
+        self, task_data: Dict[str, Any], result: Dict[str, Any], processing_time: float
+    ) -> Dict[str, Any]:
+        """Create post-processing task from inference result."""
+        return {
+            "original_message": task_data["message"],
+            "model_result": result["model_result"],
+            "metadata": result["metadata"],
+            "processing_time": processing_time,
+            "input_stream": task_data["input_stream"],
+            "stream_key": task_data["stream_key"],
+            "camera_config": task_data["camera_config"]
+        }
+
     def _run_inference(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
-        """Run inference in thread pool."""
+        """Run inference in thread pool with proper error handling and cleanup."""
         try:
             # Extract task data - handle camera streamer format
             input_stream_data = task_data.get("input_stream", {})
-            input_content = input_stream_data.get("content")
-
-            # Handle base64 encoded content from camera streamer
-            if input_content and isinstance(input_content, str):
-                import base64
-                try:
-                    input_content = base64.b64decode(input_content)
-                except Exception as e:
-                    logging.warning(f"Failed to decode base64 input: {str(e)}")
-
             stream_key = task_data.get("stream_key")
             stream_info = input_stream_data.get("stream_info", {})
             camera_info = input_stream_data.get("camera_info", {})

@@ -123,41 +146,107 @@ class InferenceWorker:
             else:
                 extra_params = {}
 
-            if self.inference_interface
+            if not self.inference_interface:
                 raise ValueError("Inference interface not initialized")
-
-
-
-
-            except RuntimeError:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-
-            # Perform inference
+
+            inference_params = self._extract_inference_params(task_data)
+            loop = self._get_or_create_event_loop()
+
             model_result, metadata = loop.run_until_complete(
-                self.inference_interface.inference(
-                    input=input_content,
-                    extra_params=extra_params,
-                    apply_post_processing=False,  # Inference only
-                    stream_key=stream_key,
-                    stream_info=stream_info,
-                    camera_info=camera_info
-                )
+                self.inference_interface.inference(**inference_params)
             )
-
-            return
-
-                "metadata": metadata,
-                "success": True,
-                "error": None
-            }
-
+
+            return self._create_success_result(model_result, metadata)
+
         except Exception as e:
-
-            return
-
-
-
-
-
+            self.logger.error(f"Inference execution error: {e}", exc_info=True)
+            return self._create_error_result(str(e))
+
+    def _extract_inference_params(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Extract and validate inference parameters from task data."""
+        input_stream_data = task_data.get("input_stream", {})
+        # Prefer decoded bytes if provided by upstream stages
+        input_bytes = task_data.get("decoded_input_bytes")
+        if not isinstance(input_bytes, (bytes, bytearray)):
+            content = input_stream_data.get("content")
+            if isinstance(content, str) and content:
+                try:
+                    input_bytes = base64.b64decode(content)
+                except Exception as e:
+                    self.logger.warning(f"Failed to decode base64 content for inference: {e}")
+                    input_bytes = None
+            elif isinstance(content, (bytes, bytearray)):
+                input_bytes = content
+            else:
+                input_bytes = None
+
+        extra_params = self._normalize_extra_params(task_data.get("extra_params", {}))
+
+        return {
+            "input": input_bytes,
+            "extra_params": extra_params,
+            "apply_post_processing": False,
+            "stream_key": task_data.get("stream_key"),
+            "stream_info": input_stream_data.get("stream_info", {}),
+            "camera_info": input_stream_data.get("camera_info", {})
+        }
+
+    def _decode_input_content(self, content: Any) -> Any:
+        """Decode base64 content if it's a string."""
+        if content and isinstance(content, str):
+            try:
+                return base64.b64decode(content)
+            except Exception as e:
+                self.logger.warning(f"Failed to decode base64 input: {e}")
+        return content
+
+    def _normalize_extra_params(self, extra_params: Any) -> Dict[str, Any]:
+        """Normalize extra_params to ensure it's a dictionary."""
+        if isinstance(extra_params, dict):
+            return extra_params
+        elif isinstance(extra_params, list):
+            return self._merge_list_params(extra_params)
+        else:
+            self.logger.warning(f"Invalid extra_params type {type(extra_params)}, using empty dict")
+            return {}
+
+    def _merge_list_params(self, params_list: list) -> Dict[str, Any]:
+        """Merge list of dictionaries into single dictionary."""
+        if not params_list:
+            return {}
+
+        if all(isinstance(item, dict) for item in params_list):
+            merged = {}
+            for item in params_list:
+                merged.update(item)
+            return merged
+
+        return {}
+
+    def _get_or_create_event_loop(self) -> asyncio.AbstractEventLoop:
+        """Get existing event loop or create a new one for this thread."""
+        try:
+            return asyncio.get_event_loop()
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            return loop
+
+    def _create_success_result(self, model_result: Any, metadata: Any) -> Dict[str, Any]:
+        """Create successful inference result."""
+        return {
+            "model_result": model_result,
+            "metadata": metadata,
+            "success": True,
+            "error": None
+        }
+
+    def _create_error_result(self, error_message: str) -> Dict[str, Any]:
+        """Create error inference result."""
+        return {
+            "model_result": None,
+            "metadata": None,
+            "success": False,
+            "error": error_message
+        }
 
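For orientation, here is a wiring sketch for the reworked InferenceWorker, following the constructor and the (priority, timestamp, task_data) queue format visible in this diff. The import path is inferred from the file list, and DummyInterface plus the task payload are placeholders: the real pipeline builds these objects in stream_pipeline.py, and whether a post-processing task actually appears depends on pipeline internals not shown in this hunk.

# Wiring sketch only; DummyInterface and the payload below are illustrative
# stand-ins, not the package's real pipeline objects.
import queue
import time
from concurrent.futures import ThreadPoolExecutor

from matrice_inference.server.stream.inference_worker import InferenceWorker


class DummyInterface:
    async def inference(self, **params):
        # Must return (model_result, metadata), as unpacked in _run_inference.
        return {"detections": []}, {"stream_key": params.get("stream_key")}


inference_queue: queue.PriorityQueue = queue.PriorityQueue()
postproc_queue: queue.PriorityQueue = queue.PriorityQueue()

worker = InferenceWorker(
    worker_id=0,
    inference_queue=inference_queue,
    postproc_queue=postproc_queue,
    inference_executor=ThreadPoolExecutor(max_workers=2),
    message_timeout=1.0,
    inference_timeout=30.0,
    inference_interface=DummyInterface(),
)
thread = worker.start()

# Tasks are (priority, timestamp, task_data); task_data must carry the fields
# checked by _validate_task_data.
task_data = {
    "message": {"frame_id": "camera-1-000042"},
    "input_stream": {"content": "", "stream_info": {}, "camera_info": {}},
    "stream_key": "camera-1",
    "camera_config": {},
    "extra_params": {},
}
inference_queue.put((0, time.time(), task_data))

try:
    priority, ts, postproc_task = postproc_queue.get(timeout=10)
    print("post-processing payload keys:", sorted(postproc_task))
except queue.Empty:
    print("no result within the timeout (inference failed or was skipped)")
finally:
    worker.stop()
    thread.join()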