matrice-inference 0.1.2__py3-none-any.whl → 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of matrice-inference might be problematic.

@@ -1,18 +1,25 @@
 import asyncio
-import json
-import time
+import base64
 import logging
-import threading
 import queue
-from typing import Any, Dict
+import threading
+import time
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, Optional
 
 class InferenceWorker:
-    """Handles inference processing using threading."""
-
-    def __init__(self, worker_id: int, inference_queue: queue.PriorityQueue,
-                 postproc_queue: queue.PriorityQueue, inference_executor: ThreadPoolExecutor,
-                 message_timeout: float, inference_timeout: float, inference_interface=None):
+    """Handles inference processing with clean resource management and error handling."""
+
+    def __init__(
+        self,
+        worker_id: int,
+        inference_queue: queue.PriorityQueue,
+        postproc_queue: queue.PriorityQueue,
+        inference_executor: ThreadPoolExecutor,
+        message_timeout: float,
+        inference_timeout: float,
+        inference_interface: Optional[Any] = None
+    ):
         self.worker_id = worker_id
         self.inference_queue = inference_queue
         self.postproc_queue = postproc_queue
@@ -23,10 +30,14 @@ class InferenceWorker:
         self.running = False
         self.logger = logging.getLogger(f"{__name__}.inference.{worker_id}")
 
-    def start(self):
+    def start(self) -> threading.Thread:
         """Start the inference worker in a separate thread."""
         self.running = True
-        thread = threading.Thread(target=self._run, name=f"InferenceWorker-{self.worker_id}", daemon=False)
+        thread = threading.Thread(
+            target=self._run,
+            name=f"InferenceWorker-{self.worker_id}",
+            daemon=False
+        )
        thread.start()
        return thread
 
@@ -34,72 +45,93 @@ class InferenceWorker:
         """Stop the inference worker."""
         self.running = False
 
-    def _run(self):
-        """Main inference dispatcher loop."""
+    def _run(self) -> None:
+        """Main inference dispatcher loop with proper error handling."""
         self.logger.info(f"Started inference worker {self.worker_id}")
-
-        while self.running:
-            try:
-                # Get task from inference queue
-                try:
-                    priority, timestamp, task_data = self.inference_queue.get(timeout=self.message_timeout)
-                except queue.Empty:
-                    continue
-
-                # Process inference task
-                self._process_inference_task(priority, task_data)
-
-            except Exception as e:
-                self.logger.error(f"Inference worker error: {e}")
-
-        self.logger.info(f"Inference worker {self.worker_id} stopped")
+
+        try:
+            while self.running:
+                task = self._get_task_from_queue()
+                if task:
+                    self._process_inference_task(*task)
+        except Exception as e:
+            self.logger.error(f"Fatal error in inference worker: {e}")
+        finally:
+            self.logger.info(f"Inference worker {self.worker_id} stopped")
+
+    def _get_task_from_queue(self) -> Optional[tuple]:
+        """Get task from inference queue with timeout handling."""
+        try:
+            return self.inference_queue.get(timeout=self.message_timeout)
+        except queue.Empty:
+            return None
+        except Exception as e:
+            self.logger.error(f"Error getting task from queue: {e}")
+            return None
 
-    def _process_inference_task(self, priority: int, task_data: Dict[str, Any]):
-        """Process a single inference task."""
+    def _process_inference_task(self, priority: int, timestamp: float, task_data: Dict[str, Any]) -> None:
+        """Process a single inference task with proper error handling."""
         try:
-            message = task_data["message"]
-
-            # Submit to thread pool for async execution
+            if not self._validate_task_data(task_data):
+                return
+
            start_time = time.time()
-            future = self.inference_executor.submit(self._run_inference, task_data)
-            result = future.result(timeout=self.inference_timeout)
+            result = self._execute_inference(task_data)
            processing_time = time.time() - start_time
-
+
            if result["success"]:
-                # Create post-processing task
-                postproc_task = {
-                    "original_message": message,
-                    "model_result": result["model_result"],
-                    "metadata": result["metadata"],
-                    "processing_time": processing_time,
-                    "input_stream": task_data["input_stream"],
-                    "stream_key": task_data["stream_key"],
-                    "camera_config": task_data["camera_config"]
-                }
-
-                # Add to post-processing queue with timestamp as tie-breaker
+                postproc_task = self._create_postprocessing_task(
+                    task_data, result, processing_time
+                )
                self.postproc_queue.put((priority, time.time(), postproc_task))
            else:
                self.logger.error(f"Inference failed: {result['error']}")
-
+
        except Exception as e:
            self.logger.error(f"Inference task error: {e}")
 
+    def _validate_task_data(self, task_data: Dict[str, Any]) -> bool:
+        """Validate that task data contains required fields."""
+        required_fields = ["message", "input_stream", "stream_key", "camera_config"]
+        for field in required_fields:
+            if field not in task_data:
+                self.logger.error(f"Missing required field '{field}' in task data")
+                return False
+        return True
+
+    def _execute_inference(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Execute inference task in thread pool."""
+        future = self.inference_executor.submit(self._run_inference, task_data)
+        return future.result(timeout=self.inference_timeout)
+
+    def _create_postprocessing_task(
+        self, task_data: Dict[str, Any], result: Dict[str, Any], processing_time: float
+    ) -> Dict[str, Any]:
+        """Create post-processing task from inference result, preserving frame_id."""
+        postproc_task = {
+            "original_message": task_data["message"],
+            "model_result": result["model_result"],
+            "metadata": result["metadata"],
+            "processing_time": processing_time,
+            "input_stream": task_data["input_stream"],
+            "stream_key": task_data["stream_key"],
+            "camera_config": task_data["camera_config"]
+        }
+
+        # Preserve frame_id from task_data (critical for cache retrieval)
+        if "frame_id" in task_data:
+            postproc_task["frame_id"] = task_data["frame_id"]
+            self.logger.debug(f"Preserved frame_id in postproc task: {task_data['frame_id']}")
+        else:
+            self.logger.warning("No frame_id in task_data to preserve")
+
+        return postproc_task
+
     def _run_inference(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
-        """Run inference in thread pool."""
+        """Run inference in thread pool with proper error handling and cleanup."""
         try:
             # Extract task data - handle camera streamer format
             input_stream_data = task_data.get("input_stream", {})
-            input_content = input_stream_data.get("content")
-
-            # Handle base64 encoded content from camera streamer
-            if input_content and isinstance(input_content, str):
-                import base64
-                try:
-                    input_content = base64.b64decode(input_content)
-                except Exception as e:
-                    logging.warning(f"Failed to decode base64 input: {str(e)}")
-
             stream_key = task_data.get("stream_key")
             stream_info = input_stream_data.get("stream_info", {})
             camera_info = input_stream_data.get("camera_info", {})
@@ -123,41 +155,107 @@ class InferenceWorker:
             else:
                 extra_params = {}
 
-            if self.inference_interface is None:
+            if not self.inference_interface:
                 raise ValueError("Inference interface not initialized")
-
-            # Create event loop for this thread if it doesn't exist
-            try:
-                loop = asyncio.get_event_loop()
-            except RuntimeError:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-
-            # Perform inference
+
+            inference_params = self._extract_inference_params(task_data)
+            loop = self._get_or_create_event_loop()
+
             model_result, metadata = loop.run_until_complete(
-                self.inference_interface.inference(
-                    input=input_content,
-                    extra_params=extra_params,
-                    apply_post_processing=False,  # Inference only
-                    stream_key=stream_key,
-                    stream_info=stream_info,
-                    camera_info=camera_info
-                )
+                self.inference_interface.inference(**inference_params)
             )
-
-            return {
-                "model_result": model_result,
-                "metadata": metadata,
-                "success": True,
-                "error": None
-            }
-
+
+            return self._create_success_result(model_result, metadata)
+
         except Exception as e:
-            logging.error(f"Inference worker error: {str(e)}", exc_info=True)
-            return {
-                "model_result": None,
-                "metadata": None,
-                "success": False,
-                "error": str(e)
-            }
+            self.logger.error(f"Inference execution error: {e}", exc_info=True)
+            return self._create_error_result(str(e))
+
+    def _extract_inference_params(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Extract and validate inference parameters from task data."""
+        input_stream_data = task_data.get("input_stream", {})
+        # Prefer decoded bytes if provided by upstream stages
+        input_bytes = task_data.get("decoded_input_bytes")
+        if not isinstance(input_bytes, (bytes, bytearray)):
+            content = input_stream_data.get("content")
+            if isinstance(content, str) and content:
+                try:
+                    input_bytes = base64.b64decode(content)
+                except Exception as e:
+                    self.logger.warning(f"Failed to decode base64 content for inference: {e}")
+                    input_bytes = None
+            elif isinstance(content, (bytes, bytearray)):
+                input_bytes = content
+            else:
+                input_bytes = None
+
+        extra_params = self._normalize_extra_params(task_data.get("extra_params", {}))
+
+        return {
+            "input": input_bytes,
+            "extra_params": extra_params,
+            "apply_post_processing": False,
+            "stream_key": task_data.get("stream_key"),
+            "stream_info": input_stream_data.get("stream_info", {}),
+            "camera_info": input_stream_data.get("camera_info", {})
+        }
+
+    def _decode_input_content(self, content: Any) -> Any:
+        """Decode base64 content if it's a string."""
+        if content and isinstance(content, str):
+            try:
+                return base64.b64decode(content)
+            except Exception as e:
+                self.logger.warning(f"Failed to decode base64 input: {e}")
+        return content
+
+    def _normalize_extra_params(self, extra_params: Any) -> Dict[str, Any]:
+        """Normalize extra_params to ensure it's a dictionary."""
+        if isinstance(extra_params, dict):
+            return extra_params
+        elif isinstance(extra_params, list):
+            return self._merge_list_params(extra_params)
+        else:
+            self.logger.warning(f"Invalid extra_params type {type(extra_params)}, using empty dict")
+            return {}
+
+    def _merge_list_params(self, params_list: list) -> Dict[str, Any]:
+        """Merge list of dictionaries into single dictionary."""
+        if not params_list:
+            return {}
+
+        if all(isinstance(item, dict) for item in params_list):
+            merged = {}
+            for item in params_list:
+                merged.update(item)
+            return merged
+
+        return {}
+
+    def _get_or_create_event_loop(self) -> asyncio.AbstractEventLoop:
+        """Get existing event loop or create a new one for this thread."""
+        try:
+            return asyncio.get_event_loop()
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            return loop
+
+    def _create_success_result(self, model_result: Any, metadata: Any) -> Dict[str, Any]:
+        """Create successful inference result."""
+        return {
+            "model_result": model_result,
+            "metadata": metadata,
+            "success": True,
+            "error": None
+        }
+
+    def _create_error_result(self, error_message: str) -> Dict[str, Any]:
+        """Create error inference result."""
+        return {
+            "model_result": None,
+            "metadata": None,
+            "success": False,
+            "error": error_message
+        }
 
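
For reviewers who want to exercise the reworked worker, a minimal usage sketch follows. It assumes the InferenceWorker class from this diff is in scope (the module's import path is not shown in the diff); FakeInferenceInterface is a hypothetical stub standing in for the real async inference interface, and all queue sizes, timeouts, and task values are illustrative only.

import queue
import time
from concurrent.futures import ThreadPoolExecutor

class FakeInferenceInterface:
    """Hypothetical stub mimicking only the async inference() call used by _run_inference."""
    async def inference(self, input=None, extra_params=None, apply_post_processing=False,
                        stream_key=None, stream_info=None, camera_info=None):
        # Return a (model_result, metadata) pair, as the worker expects
        return {"detections": []}, {"model": "stub"}

inference_queue = queue.PriorityQueue()
postproc_queue = queue.PriorityQueue()
executor = ThreadPoolExecutor(max_workers=2)

worker = InferenceWorker(
    worker_id=0,
    inference_queue=inference_queue,
    postproc_queue=postproc_queue,
    inference_executor=executor,
    message_timeout=1.0,
    inference_timeout=30.0,
    inference_interface=FakeInferenceInterface(),
)
thread = worker.start()

# Queue items are (priority, timestamp, task_data) tuples, matching what
# _get_task_from_queue returns and _process_inference_task unpacks.
# _validate_task_data requires message, input_stream, stream_key, and
# camera_config; frame_id, when present, is carried through to the
# post-processing task for cache retrieval.
task_data = {
    "message": {"id": "msg-1"},
    "input_stream": {"content": b"raw-frame-bytes", "stream_info": {}, "camera_info": {}},
    "stream_key": "camera-1",
    "camera_config": {},
    "frame_id": "frame-0001",
}
inference_queue.put((0, time.time(), task_data))

print(postproc_queue.get(timeout=5))  # (priority, ts, postproc_task) with frame_id preserved
worker.stop()   # the _run loop exits after the next message_timeout tick
thread.join()
executor.shutdown()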