matrice-inference 0.1.2__py3-none-any.whl → 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of matrice-inference might be problematic.

@@ -1,19 +1,27 @@
 import asyncio
+import base64
 import logging
-import threading
 import queue
+import threading
 import time
-from typing import Any, Dict
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, Optional
 
 
 class PostProcessingWorker:
-    """Handles post-processing using threading."""
-
-    def __init__(self, worker_id: int, postproc_queue: queue.PriorityQueue,
-                 output_queue: queue.PriorityQueue, postprocessing_executor: ThreadPoolExecutor,
-                 message_timeout: float, inference_timeout: float, post_processor=None,
-                 frame_cache=None):
+    """Handles post-processing with clean resource management and error handling."""
+
+    def __init__(
+        self,
+        worker_id: int,
+        postproc_queue: queue.PriorityQueue,
+        output_queue: queue.PriorityQueue,
+        postprocessing_executor: ThreadPoolExecutor,
+        message_timeout: float,
+        inference_timeout: float,
+        post_processor: Optional[Any] = None,
+        frame_cache: Optional[Any] = None
+    ):
         self.worker_id = worker_id
         self.postproc_queue = postproc_queue
         self.output_queue = output_queue
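
For orientation, here is a minimal sketch of how the reworked constructor and start() method might be driven. The import path, queue sizes, timeouts, and the post_processor left as None are illustrative assumptions for this sketch, not values taken from the package:

    import queue
    from concurrent.futures import ThreadPoolExecutor

    # Hypothetical import path for illustration; the real module location may differ.
    from matrice_inference.post_processing_worker import PostProcessingWorker

    postproc_queue = queue.PriorityQueue()   # holds (priority, timestamp, task_data) tuples
    output_queue = queue.PriorityQueue()     # receives (priority, timestamp, output_task) tuples
    executor = ThreadPoolExecutor(max_workers=4)

    worker = PostProcessingWorker(
        worker_id=0,
        postproc_queue=postproc_queue,
        output_queue=output_queue,
        postprocessing_executor=executor,
        message_timeout=1.0,      # seconds to block on postproc_queue.get()
        inference_timeout=30.0,   # seconds to wait for one post-processing future
        post_processor=None,      # real deployments pass an object exposing an async process() method
    )

    thread = worker.start()       # non-daemon thread; keep the handle for shutdown
    worker.stop()                 # flips self.running; the loop exits after the next queue timeout
    thread.join()                 # join explicitly, otherwise the interpreter waits on it at exit

Because the thread is created with daemon=False, returning it from start() lets the caller own the stop()/join() sequence shown above.
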
@@ -25,10 +33,14 @@ class PostProcessingWorker:
         self.running = False
         self.logger = logging.getLogger(f"{__name__}.postproc.{worker_id}")
 
-    def start(self):
+    def start(self) -> threading.Thread:
         """Start the post-processing worker in a separate thread."""
         self.running = True
-        thread = threading.Thread(target=self._run, name=f"PostProcWorker-{self.worker_id}", daemon=False)
+        thread = threading.Thread(
+            target=self._run,
+            name=f"PostProcWorker-{self.worker_id}",
+            daemon=False
+        )
         thread.start()
         return thread
 
@@ -36,195 +48,255 @@ class PostProcessingWorker:
         """Stop the post-processing worker."""
         self.running = False
 
-    def _run(self):
-        """Main post-processing dispatcher loop."""
+    def _run(self) -> None:
+        """Main post-processing dispatcher loop with proper error handling."""
         self.logger.info(f"Started post-processing worker {self.worker_id}")
-
-        while self.running:
-            try:
-                # Get task from post-processing queue
-                try:
-                    priority, timestamp, task_data = self.postproc_queue.get(timeout=self.message_timeout)
-                except queue.Empty:
-                    continue
-
-                # Process post-processing task
-                self._process_postproc_task(priority, task_data)
-
-            except Exception as e:
-                self.logger.error(f"Post-processing worker error: {e}")
-
-        self.logger.info(f"Post-processing worker {self.worker_id} stopped")
+
+        try:
+            while self.running:
+                task = self._get_task_from_queue()
+                if task:
+                    self._process_postproc_task(*task)
+        except Exception as e:
+            self.logger.error(f"Fatal error in post-processing worker: {e}")
+        finally:
+            self.logger.info(f"Post-processing worker {self.worker_id} stopped")
+
+    def _get_task_from_queue(self) -> Optional[tuple]:
+        """Get task from post-processing queue with timeout handling."""
+        try:
+            return self.postproc_queue.get(timeout=self.message_timeout)
+        except queue.Empty:
+            return None
+        except Exception as e:
+            self.logger.error(f"Error getting task from queue: {e}")
+            return None
 
-    def _process_postproc_task(self, priority: int, task_data: Dict[str, Any]):
-        """Process a single post-processing task."""
+    def _process_postproc_task(self, priority: int, timestamp: float, task_data: Dict[str, Any]) -> None:
+        """Process a single post-processing task with proper error handling."""
         try:
-            # Submit to thread pool for async execution
-            future = self.postprocessing_executor.submit(self._run_post_processing, task_data)
-            result = future.result(timeout=self.inference_timeout)
-
+            if not self._validate_task_data(task_data):
+                return
+
+            result = self._execute_post_processing(task_data)
+
             if result["success"]:
-                # Cache disabled: preserving content in output and not pushing to Redis
-                # try:
-                #     orig_input = task_data.get("input_stream", {}) or {}
-                #     content = orig_input.get("content")
-                #     frame_id_for_cache = task_data.get("frame_id") or orig_input.get("frame_id")
-                #     if content and frame_id_for_cache and self.frame_cache:
-                #         if isinstance(content, bytes):
-                #             import base64
-                #             try:
-                #                 content = base64.b64encode(content).decode("ascii")
-                #             except Exception:
-                #                 content = None
-                #         if isinstance(content, str):
-                #             self.frame_cache.put(frame_id_for_cache, content)
-                # except Exception:
-                #     pass
-
-                # Create final output message
-                # Prepare input_stream for output: ensure frame_id is present and strip bulky content
-                safe_input_stream = {}
-                try:
-                    if isinstance(task_data.get("input_stream"), dict):
-                        safe_input_stream = dict(task_data["input_stream"])  # shallow copy
-                        # Ensure frame_id propagation
-                        if "frame_id" not in safe_input_stream and "frame_id" in task_data:
-                            safe_input_stream["frame_id"] = task_data["frame_id"]
-                        # Do not strip content; keep as-is in output
-                        # if "content" in safe_input_stream:
-                        #     safe_input_stream["content"] = ""
-                except Exception:
-                    safe_input_stream = task_data.get("input_stream", {})
-
-                # Determine frame_id for top-level convenience
-                frame_id = task_data.get("frame_id")
-                if not frame_id and isinstance(safe_input_stream, dict):
-                    frame_id = safe_input_stream.get("frame_id")
-
-                output_data = {
-                    "camera_id": task_data["original_message"].camera_id,
-                    "message_key": task_data["original_message"].message_key,
-                    "timestamp": task_data["original_message"].timestamp.isoformat(),
-                    "frame_id": frame_id,
-                    "model_result": task_data["model_result"],
-                    "input_stream": safe_input_stream,
-                    "post_processing_result": result["post_processing_result"],
-                    "processing_time_sec": task_data["processing_time"],
-                    "metadata": task_data.get("metadata", {})
-                }
-
-                # Add to output queue
-                output_task = {
-                    "camera_id": task_data["original_message"].camera_id,
-                    "message_key": task_data["original_message"].message_key,
-                    "data": output_data,
-                }
-                # Add to output queue with timestamp as tie-breaker
+                output_task = self._create_output_task(task_data, result)
                 self.output_queue.put((priority, time.time(), output_task))
             else:
                 self.logger.error(f"Post-processing failed: {result['error']}")
-
+
         except Exception as e:
             self.logger.error(f"Post-processing task error: {e}")
 
+    def _validate_task_data(self, task_data: Dict[str, Any]) -> bool:
+        """Validate that task data contains required fields."""
+        required_fields = ["original_message", "model_result", "input_stream"]
+        for field in required_fields:
+            if field not in task_data:
+                self.logger.error(f"Missing required field '{field}' in task data")
+                return False
+        return True
+
+    def _execute_post_processing(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Execute post-processing task in thread pool."""
+        future = self.postprocessing_executor.submit(self._run_post_processing, task_data)
+        return future.result(timeout=self.inference_timeout)
+
+    def _create_output_task(self, task_data: Dict[str, Any], result: Dict[str, Any]) -> Dict[str, Any]:
+        """Create output task from post-processing result."""
+        safe_input_stream = self._prepare_safe_input_stream(task_data)
+        frame_id = self._determine_frame_id(task_data, safe_input_stream)
+
+        # Strip content before publishing to output topic
+        try:
+            if isinstance(safe_input_stream, dict):
+                safe_input_stream["content"] = ""
+        except Exception:
+            pass
+
+        output_data = {
+            "camera_id": task_data["original_message"].camera_id,
+            "message_key": task_data["original_message"].message_key,
+            "timestamp": task_data["original_message"].timestamp.isoformat(),
+            "frame_id": frame_id,
+            "model_result": task_data["model_result"],
+            "input_stream": safe_input_stream,
+            "post_processing_result": result["post_processing_result"],
+            "processing_time_sec": task_data["processing_time"],
+            "metadata": task_data.get("metadata", {})
+        }
+
+        # Verify frame_id is present in output
+        if not frame_id:
+            self.logger.warning(
+                f"Output task missing frame_id for camera={task_data['original_message'].camera_id}, "
+                f"message_key={task_data['original_message'].message_key}"
+            )
+
+        return {
+            "camera_id": task_data["original_message"].camera_id,
+            "message_key": task_data["original_message"].message_key,
+            "data": output_data,
+        }
+
+    def _prepare_safe_input_stream(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Prepare input stream for output with proper frame_id propagation."""
+        try:
+            input_stream = task_data.get("input_stream")
+            if isinstance(input_stream, dict):
+                safe_input_stream = dict(input_stream)  # shallow copy
+                # Ensure frame_id propagation
+                if "frame_id" not in safe_input_stream and "frame_id" in task_data:
+                    safe_input_stream["frame_id"] = task_data["frame_id"]
+                return safe_input_stream
+        except Exception as e:
+            self.logger.warning(f"Error preparing input stream: {e}")
+
+        return task_data.get("input_stream", {})
+
+    def _determine_frame_id(self, task_data: Dict[str, Any], safe_input_stream: Dict[str, Any]) -> Optional[str]:
+        """Determine frame_id from task data or input stream."""
+        frame_id = task_data.get("frame_id")
+        if not frame_id and isinstance(safe_input_stream, dict):
+            frame_id = safe_input_stream.get("frame_id")
+        return frame_id
+
     def _run_post_processing(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
-        """Run post-processing in thread pool."""
+        """Run post-processing in thread pool with proper error handling."""
         try:
-            if self.post_processor is None:
+            if not self.post_processor:
                 raise ValueError("Post processor not initialized")
-
-            # Extract task data
-            model_result = task_data["model_result"]
-            input_stream_data = task_data.get("input_stream", {})
-            input_content = input_stream_data.get("content")
-
-            # Handle base64 encoded content
-            if input_content and isinstance(input_content, str):
-                import base64
-                try:
-                    input_content = base64.b64decode(input_content)
-                except Exception as e:
-                    logging.warning(f"Failed to decode base64 input: {str(e)}")
-                    input_content = None
-
-            stream_key = task_data.get("stream_key")
-            stream_info = input_stream_data.get("stream_info", {})
-            camera_config = task_data.get("camera_config", {})
-
-            # Create event loop for this thread if it doesn't exist
-            try:
-                loop = asyncio.get_event_loop()
-            except RuntimeError:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-
-            # Perform post-processing
+
+            processing_params = self._extract_processing_params(task_data)
+            loop = self._get_or_create_event_loop()
+
             result = loop.run_until_complete(
-                self.post_processor.process(
-                    data=model_result,
-                    input_bytes=input_content if isinstance(input_content, bytes) else None,
-                    stream_key=stream_key,
-                    stream_info=stream_info
-                )
+                self.post_processor.process(**processing_params)
             )
-
-            # For face recognition use case, return empty raw results
-            processed_raw_results = []
-            try:
-                if hasattr(result, 'usecase') and not result.usecase == 'face_recognition':
-                    processed_raw_results = model_result
-            except Exception as e:
-                logging.warning(f"Failed to get processed raw results: {str(e)}")
-
-            # Extract agg_summary from result data if available
-            agg_summary = {}
-            try:
-                if hasattr(result, 'data') and isinstance(result.data, dict):
-                    agg_summary = result.data.get("agg_summary", {})
-            except Exception as e:
-                logging.warning(f"Failed to get agg summary: {str(e)}")
-
-            # Format result similar to InferenceInterface
-            if result.is_success():
-                post_processing_result = {
-                    "status": "success",
-                    "processing_time": result.processing_time,
-                    "usecase": getattr(result, 'usecase', ''),
-                    "category": getattr(result, 'category', ''),
-                    "summary": getattr(result, 'summary', ''),
-                    "insights": getattr(result, 'insights', []),
-                    "metrics": getattr(result, 'metrics', {}),
-                    "predictions": getattr(result, 'predictions', []),
-                    "agg_summary": agg_summary,
-                    "raw_results": processed_raw_results,
-                    "stream_key": stream_key
-                }
-            else:
-                post_processing_result = {
-                    "status": "post_processing_failed",
-                    "error": result.error_message,
-                    "error_type": getattr(result, 'error_type', 'ProcessingError'),
-                    "processing_time": result.processing_time,
-                    "stream_key": stream_key,
-                    "agg_summary": agg_summary,
-                    "raw_results": model_result
-                }
-
+
+            post_processing_result = self._format_processing_result(
+                result, task_data, processing_params["stream_key"]
+            )
+
             return {
                 "post_processing_result": post_processing_result,
                 "success": True,
                 "error": None
             }
-
+
         except Exception as e:
-            logging.error(f"Post-processing worker error: {str(e)}", exc_info=True)
-            return {
-                "post_processing_result": {
-                    "status": "post_processing_failed",
-                    "error": str(e),
-                    "error_type": type(e).__name__,
-                    "stream_key": task_data.get("stream_key")
-                },
-                "success": False,
-                "error": str(e)
-            }
+            self.logger.error(f"Post-processing execution error: {e}", exc_info=True)
+            return self._create_error_result(str(e), task_data)
+
+    def _extract_processing_params(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Extract and validate post-processing parameters from task data."""
+        input_stream_data = task_data.get("input_stream", {})
+        input_content = self._decode_input_content(input_stream_data.get("content"))
+
+        # Extract stream_info and add frame_id to it
+        stream_info = input_stream_data.get("stream_info", {})
+        if isinstance(stream_info, dict):
+            # Add frame_id to stream_info if available
+            frame_id = task_data.get("frame_id")
+            if not frame_id and isinstance(input_stream_data, dict):
+                frame_id = input_stream_data.get("frame_id")
+            if frame_id:
+                stream_info["frame_id"] = frame_id
+
+        return {
+            "data": task_data["model_result"],
+            "input_bytes": input_content if isinstance(input_content, bytes) else None,
+            "stream_key": task_data.get("stream_key"),
+            "stream_info": stream_info
+        }
+
+    def _decode_input_content(self, content: Any) -> Any:
+        """Decode base64 content if it's a string."""
+        if content and isinstance(content, str):
+            try:
+                return base64.b64decode(content)
+            except Exception as e:
+                self.logger.warning(f"Failed to decode base64 input: {e}")
+                return None
+        return content
+
+    def _get_or_create_event_loop(self) -> asyncio.AbstractEventLoop:
+        """Get existing event loop or create a new one for this thread."""
+        try:
+            return asyncio.get_event_loop()
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            return loop
+
+    def _format_processing_result(self, result: Any, task_data: Dict[str, Any], stream_key: str) -> Dict[str, Any]:
+        """Format post-processing result based on success status."""
+        if result.is_success():
+            return self._create_success_result(result, task_data["model_result"], stream_key)
+        else:
+            return self._create_failure_result(result, task_data["model_result"], stream_key)
+
+    def _create_success_result(self, result: Any, model_result: Any, stream_key: str) -> Dict[str, Any]:
+        """Create successful post-processing result."""
+        processed_raw_results = self._get_processed_raw_results(result, model_result)
+        agg_summary = self._extract_agg_summary(result)
+
+        return {
+            "status": "success",
+            "processing_time": result.processing_time,
+            "usecase": getattr(result, 'usecase', ''),
+            "category": getattr(result, 'category', ''),
+            "summary": getattr(result, 'summary', ''),
+            "insights": getattr(result, 'insights', []),
+            "metrics": getattr(result, 'metrics', {}),
+            "predictions": getattr(result, 'predictions', []),
+            "agg_summary": agg_summary,
+            "raw_results": processed_raw_results,
+            "stream_key": stream_key
+        }
+
+    def _create_failure_result(self, result: Any, model_result: Any, stream_key: str) -> Dict[str, Any]:
+        """Create failed post-processing result."""
+        agg_summary = self._extract_agg_summary(result)
+
+        return {
+            "status": "post_processing_failed",
+            "error": result.error_message,
+            "error_type": getattr(result, 'error_type', 'ProcessingError'),
+            "processing_time": result.processing_time,
+            "stream_key": stream_key,
+            "agg_summary": agg_summary,
+            "raw_results": model_result
+        }
+
+    def _get_processed_raw_results(self, result: Any, model_result: Any) -> list:
+        """Get processed raw results, handling face recognition special case."""
+        try:
+            if hasattr(result, 'usecase') and result.usecase != 'face_recognition':
+                return model_result
+        except Exception as e:
+            self.logger.warning(f"Failed to get processed raw results: {e}")
+        return []
+
+    def _extract_agg_summary(self, result: Any) -> Dict[str, Any]:
+        """Extract aggregated summary from result data."""
+        try:
+            if hasattr(result, 'data') and isinstance(result.data, dict):
+                return result.data.get("agg_summary", {})
+        except Exception as e:
+            self.logger.warning(f"Failed to get agg summary: {e}")
+        return {}
+
+    def _create_error_result(self, error_message: str, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Create error post-processing result."""
+        return {
+            "post_processing_result": {
+                "status": "post_processing_failed",
+                "error": error_message,
+                "error_type": "ProcessingError",
+                "stream_key": task_data.get("stream_key")
+            },
+            "success": False,
+            "error": error_message
+        }
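
The _get_or_create_event_loop / run_until_complete pairing added above is a standard way to drive async code from a ThreadPoolExecutor worker thread. A self-contained sketch of the same pattern, with purely illustrative names (fake_post_process and run_in_worker_thread are not part of the package):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    async def fake_post_process(data):
        # Stand-in for an async post-processor; illustrative only.
        await asyncio.sleep(0.01)
        return {"processed": data}

    def run_in_worker_thread(data):
        # Pool threads have no event loop installed by default, so reuse one
        # if present or create and register a loop for this thread.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(fake_post_process(data))

    with ThreadPoolExecutor(max_workers=2) as executor:
        future = executor.submit(run_in_worker_thread, {"frame_id": "demo"})
        print(future.result(timeout=5.0))

Note that calling asyncio.get_event_loop() outside a running loop is deprecated in recent Python releases; creating one loop per worker thread up front (or using asyncio.run per task) achieves the same effect and is likely the longer-term direction.
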