matrice-inference 0.1.2__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of matrice-inference might be problematic.

@@ -1,50 +1,80 @@
 import logging
-import threading
 import queue
-
+import threading
+from typing import Optional
 
 try:
     import redis  # type: ignore
-except Exception:  # pragma: no cover
+except ImportError:  # pragma: no cover
     redis = None  # type: ignore
 
 
 class RedisFrameCache:
-    """Non-blocking Redis cache for frames keyed by frame_id.
+    """Non-blocking Redis cache for frames with optimized resource management.
 
     Stores base64 string content under key 'stream:frames:{frame_id}' with field 'frame'.
     Each insert sets or refreshes the TTL.
     """
 
+    DEFAULT_TTL_SECONDS = 300
+    DEFAULT_MAX_QUEUE = 10000
+    DEFAULT_WORKER_THREADS = 2
+    DEFAULT_CONNECT_TIMEOUT = 2.0
+    DEFAULT_SOCKET_TIMEOUT = 0.5
+    DEFAULT_HEALTH_CHECK_INTERVAL = 30
+    DEFAULT_PREFIX = "stream:frames:"
+
     def __init__(
         self,
         host: str = "localhost",
         port: int = 6379,
         db: int = 0,
-        password: str = None,
-        username: str = None,
-        ttl_seconds: int = 300,
-        prefix: str = "stream:frames:",
-        max_queue: int = 10000,
-        worker_threads: int = 2,
-        connect_timeout: float = 2.0,
-        socket_timeout: float = 0.5,
+        password: Optional[str] = None,
+        username: Optional[str] = None,
+        ttl_seconds: int = DEFAULT_TTL_SECONDS,
+        prefix: str = DEFAULT_PREFIX,
+        max_queue: int = DEFAULT_MAX_QUEUE,
+        worker_threads: int = DEFAULT_WORKER_THREADS,
+        connect_timeout: float = DEFAULT_CONNECT_TIMEOUT,
+        socket_timeout: float = DEFAULT_SOCKET_TIMEOUT,
     ) -> None:
-        self.logger = logging.getLogger(__name__ + ".frame_cache")
-        self.ttl_seconds = int(ttl_seconds)
+        self.logger = logging.getLogger(f"{__name__}.frame_cache")
+        self.ttl_seconds = max(1, int(ttl_seconds))
         self.prefix = prefix
-        self.queue: "queue.Queue" = queue.Queue(maxsize=max_queue)
-        self.threads = []
         self.running = False
-        self._client = None
         self._worker_threads = max(1, int(worker_threads))
 
+        self.queue: queue.Queue = queue.Queue(maxsize=max_queue)
+        self.threads: list = []
+        self._client: Optional[redis.Redis] = None
+
+        if not self._is_redis_available():
+            return
+
+        self._client = self._create_redis_client(
+            host, port, db, password, username, connect_timeout, socket_timeout
+        )
+
+    def _is_redis_available(self) -> bool:
+        """Check if Redis package is available."""
         if redis is None:
             self.logger.warning("redis package not installed; frame caching disabled")
-            return
+            return False
+        return True
 
+    def _create_redis_client(
+        self,
+        host: str,
+        port: int,
+        db: int,
+        password: Optional[str],
+        username: Optional[str],
+        connect_timeout: float,
+        socket_timeout: float
+    ) -> Optional[redis.Redis]:
+        """Create Redis client with proper error handling."""
         try:
-            self._client = redis.Redis(
+            return redis.Redis(
                 host=host,
                 port=port,
                 db=db,
@@ -52,76 +82,141 @@ class RedisFrameCache:
                 username=username,
                 socket_connect_timeout=connect_timeout,
                 socket_timeout=socket_timeout,
-                health_check_interval=30,
+                health_check_interval=self.DEFAULT_HEALTH_CHECK_INTERVAL,
                 retry_on_timeout=True,
-                decode_responses=True,  # store strings directly
+                decode_responses=True,
             )
         except Exception as e:
-            self.logger.warning("Failed to init Redis client: %s", e)
-            self._client = None
+            self.logger.warning(f"Failed to initialize Redis client: {e}")
+            return None
 
     def start(self) -> None:
+        """Start the frame cache with worker threads."""
         if not self._client or self.running:
             return
+
         self.running = True
+        self._start_worker_threads()
+
+    def _start_worker_threads(self) -> None:
+        """Start worker threads for processing cache operations."""
         for i in range(self._worker_threads):
-            t = threading.Thread(target=self._worker, name=f"FrameCache-{i}", daemon=True)
-            t.start()
-            self.threads.append(t)
+            thread = threading.Thread(
+                target=self._worker,
+                name=f"FrameCache-{i}",
+                daemon=True
+            )
+            thread.start()
+            self.threads.append(thread)
 
     def stop(self) -> None:
+        """Stop the frame cache and cleanup resources."""
         if not self.running:
             return
+
         self.running = False
+        self._stop_worker_threads()
+        self.threads.clear()
+
+    def _stop_worker_threads(self) -> None:
+        """Stop all worker threads gracefully."""
+        # Signal threads to stop
         for _ in self.threads:
             try:
                 self.queue.put_nowait(None)
-            except Exception:
+            except queue.Full:
                 pass
-        for t in self.threads:
+
+        # Wait for threads to finish
+        for thread in self.threads:
             try:
-                t.join(timeout=2.0)
-            except Exception:
-                pass
-        self.threads.clear()
+                thread.join(timeout=2.0)
+            except Exception as e:
+                self.logger.warning(f"Error joining thread {thread.name}: {e}")
 
     def put(self, frame_id: str, base64_content: str) -> None:
         """Enqueue a cache write for the given frame.
 
-        - frame_id: unique identifier
-        - base64_content: base64-encoded image string
+        Args:
+            frame_id: unique identifier for the frame
+            base64_content: base64-encoded image string
         """
-        if not self._client or not self.running:
+        if not self._is_cache_ready():
             return
-        if not frame_id or not base64_content:
+
+        if not self._validate_input(frame_id, base64_content):
             return
+
         try:
             key = f"{self.prefix}{frame_id}"
             self.queue.put_nowait((key, base64_content))
         except queue.Full:
-            # Drop silently; never block pipeline
-            self.logger.debug("Frame cache queue full; dropping frame_id=%s", frame_id)
+            self._handle_queue_full(frame_id)
+
+    def _is_cache_ready(self) -> bool:
+        """Check if cache is ready for operations."""
+        return bool(self._client and self.running)
+
+    def _validate_input(self, frame_id: str, base64_content: str) -> bool:
+        """Validate input parameters."""
+        if not frame_id:
+            self.logger.warning("Empty frame_id provided")
+            return False
+        if not base64_content:
+            self.logger.warning("Empty base64_content provided")
+            return False
+        return True
+
+    def _handle_queue_full(self, frame_id: str) -> None:
+        """Handle queue full condition."""
+        self.logger.debug(f"Frame cache queue full; dropping frame_id={frame_id}")
 
     def _worker(self) -> None:
+        """Worker thread for processing cache operations."""
         while self.running:
-            try:
-                item = self.queue.get(timeout=0.5)
-            except queue.Empty:
-                continue
+            item = self._get_work_item()
             if item is None:
+                continue
+            if self._is_stop_signal(item):
                 break
+
+            self._process_cache_item(item)
+
+    def _get_work_item(self) -> Optional[tuple]:
+        """Get work item from queue with timeout."""
+        try:
+            return self.queue.get(timeout=0.5)
+        except queue.Empty:
+            return None
+
+    def _is_stop_signal(self, item: tuple) -> bool:
+        """Check if item is a stop signal."""
+        return item is None
+
+    def _process_cache_item(self, item: tuple) -> None:
+        """Process a single cache item."""
+        try:
             key, base64_content = item
-            try:
-                # Store base64 string in a Redis hash field 'frame', then set TTL
-                # Mimics the Go backend behavior
-                self._client.hset(key, "frame", base64_content)
-                self._client.expire(key, self.ttl_seconds)
-            except Exception as e:
-                self.logger.debug("Failed to cache frame %s: %s", key, e)
-            finally:
-                try:
-                    self.queue.task_done()
-                except Exception:
-                    pass
+            self._store_frame_data(key, base64_content)
+        except Exception as e:
+            self.logger.debug(f"Failed to process cache item: {e}")
+        finally:
+            self._mark_task_done()
+
+    def _store_frame_data(self, key: str, base64_content: str) -> None:
+        """Store frame data in Redis with TTL."""
+        try:
+            # Store base64 string in Redis hash field 'frame', then set TTL
+            self._client.hset(key, "frame", base64_content)
+            self._client.expire(key, self.ttl_seconds)
+        except Exception as e:
+            self.logger.debug(f"Failed to cache frame {key}: {e}")
+
+    def _mark_task_done(self) -> None:
+        """Mark queue task as done."""
+        try:
+            self.queue.task_done()
+        except Exception:
+            pass
 
@@ -1,18 +1,25 @@
 import asyncio
-import json
-import time
+import base64
 import logging
-import threading
 import queue
-from typing import Any, Dict
+import threading
+import time
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, Optional
 
 class InferenceWorker:
-    """Handles inference processing using threading."""
-
-    def __init__(self, worker_id: int, inference_queue: queue.PriorityQueue,
-                 postproc_queue: queue.PriorityQueue, inference_executor: ThreadPoolExecutor,
-                 message_timeout: float, inference_timeout: float, inference_interface=None):
+    """Handles inference processing with clean resource management and error handling."""
+
+    def __init__(
+        self,
+        worker_id: int,
+        inference_queue: queue.PriorityQueue,
+        postproc_queue: queue.PriorityQueue,
+        inference_executor: ThreadPoolExecutor,
+        message_timeout: float,
+        inference_timeout: float,
+        inference_interface: Optional[Any] = None
+    ):
         self.worker_id = worker_id
         self.inference_queue = inference_queue
         self.postproc_queue = postproc_queue
@@ -23,10 +30,14 @@ class InferenceWorker:
         self.running = False
         self.logger = logging.getLogger(f"{__name__}.inference.{worker_id}")
 
-    def start(self):
+    def start(self) -> threading.Thread:
         """Start the inference worker in a separate thread."""
         self.running = True
-        thread = threading.Thread(target=self._run, name=f"InferenceWorker-{self.worker_id}", daemon=False)
+        thread = threading.Thread(
+            target=self._run,
+            name=f"InferenceWorker-{self.worker_id}",
+            daemon=False
+        )
         thread.start()
         return thread
 
@@ -34,72 +45,84 @@
         """Stop the inference worker."""
         self.running = False
 
-    def _run(self):
-        """Main inference dispatcher loop."""
+    def _run(self) -> None:
+        """Main inference dispatcher loop with proper error handling."""
         self.logger.info(f"Started inference worker {self.worker_id}")
-
-        while self.running:
-            try:
-                # Get task from inference queue
-                try:
-                    priority, timestamp, task_data = self.inference_queue.get(timeout=self.message_timeout)
-                except queue.Empty:
-                    continue
-
-                # Process inference task
-                self._process_inference_task(priority, task_data)
-
-            except Exception as e:
-                self.logger.error(f"Inference worker error: {e}")
-
-        self.logger.info(f"Inference worker {self.worker_id} stopped")
+
+        try:
+            while self.running:
+                task = self._get_task_from_queue()
+                if task:
+                    self._process_inference_task(*task)
+        except Exception as e:
+            self.logger.error(f"Fatal error in inference worker: {e}")
+        finally:
+            self.logger.info(f"Inference worker {self.worker_id} stopped")
+
+    def _get_task_from_queue(self) -> Optional[tuple]:
+        """Get task from inference queue with timeout handling."""
+        try:
+            return self.inference_queue.get(timeout=self.message_timeout)
+        except queue.Empty:
+            return None
+        except Exception as e:
+            self.logger.error(f"Error getting task from queue: {e}")
+            return None
 
-    def _process_inference_task(self, priority: int, task_data: Dict[str, Any]):
-        """Process a single inference task."""
+    def _process_inference_task(self, priority: int, timestamp: float, task_data: Dict[str, Any]) -> None:
+        """Process a single inference task with proper error handling."""
         try:
-            message = task_data["message"]
-
-            # Submit to thread pool for async execution
+            if not self._validate_task_data(task_data):
+                return
+
             start_time = time.time()
-            future = self.inference_executor.submit(self._run_inference, task_data)
-            result = future.result(timeout=self.inference_timeout)
+            result = self._execute_inference(task_data)
             processing_time = time.time() - start_time
-
+
             if result["success"]:
-                # Create post-processing task
-                postproc_task = {
-                    "original_message": message,
-                    "model_result": result["model_result"],
-                    "metadata": result["metadata"],
-                    "processing_time": processing_time,
-                    "input_stream": task_data["input_stream"],
-                    "stream_key": task_data["stream_key"],
-                    "camera_config": task_data["camera_config"]
-                }
-
-                # Add to post-processing queue with timestamp as tie-breaker
+                postproc_task = self._create_postprocessing_task(
                    task_data, result, processing_time
+                )
                 self.postproc_queue.put((priority, time.time(), postproc_task))
             else:
                 self.logger.error(f"Inference failed: {result['error']}")
-
+
         except Exception as e:
             self.logger.error(f"Inference task error: {e}")
 
+    def _validate_task_data(self, task_data: Dict[str, Any]) -> bool:
+        """Validate that task data contains required fields."""
+        required_fields = ["message", "input_stream", "stream_key", "camera_config"]
+        for field in required_fields:
+            if field not in task_data:
+                self.logger.error(f"Missing required field '{field}' in task data")
+                return False
+        return True
+
+    def _execute_inference(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Execute inference task in thread pool."""
+        future = self.inference_executor.submit(self._run_inference, task_data)
+        return future.result(timeout=self.inference_timeout)
+
+    def _create_postprocessing_task(
+        self, task_data: Dict[str, Any], result: Dict[str, Any], processing_time: float
+    ) -> Dict[str, Any]:
+        """Create post-processing task from inference result."""
+        return {
+            "original_message": task_data["message"],
+            "model_result": result["model_result"],
+            "metadata": result["metadata"],
+            "processing_time": processing_time,
+            "input_stream": task_data["input_stream"],
+            "stream_key": task_data["stream_key"],
+            "camera_config": task_data["camera_config"]
+        }
+
     def _run_inference(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
-        """Run inference in thread pool."""
+        """Run inference in thread pool with proper error handling and cleanup."""
         try:
             # Extract task data - handle camera streamer format
             input_stream_data = task_data.get("input_stream", {})
-            input_content = input_stream_data.get("content")
-
-            # Handle base64 encoded content from camera streamer
-            if input_content and isinstance(input_content, str):
-                import base64
-                try:
-                    input_content = base64.b64decode(input_content)
-                except Exception as e:
-                    logging.warning(f"Failed to decode base64 input: {str(e)}")
-
             stream_key = task_data.get("stream_key")
             stream_info = input_stream_data.get("stream_info", {})
             camera_info = input_stream_data.get("camera_info", {})
@@ -123,41 +146,107 @@ class InferenceWorker:
             else:
                 extra_params = {}
 
-            if self.inference_interface is None:
+            if not self.inference_interface:
                 raise ValueError("Inference interface not initialized")
-
-            # Create event loop for this thread if it doesn't exist
-            try:
-                loop = asyncio.get_event_loop()
-            except RuntimeError:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-
-            # Perform inference
+
+            inference_params = self._extract_inference_params(task_data)
+            loop = self._get_or_create_event_loop()
+
             model_result, metadata = loop.run_until_complete(
-                self.inference_interface.inference(
-                    input=input_content,
-                    extra_params=extra_params,
-                    apply_post_processing=False,  # Inference only
-                    stream_key=stream_key,
-                    stream_info=stream_info,
-                    camera_info=camera_info
-                )
+                self.inference_interface.inference(**inference_params)
             )
-
-            return {
-                "model_result": model_result,
-                "metadata": metadata,
-                "success": True,
-                "error": None
-            }
-
+
+            return self._create_success_result(model_result, metadata)
+
         except Exception as e:
-            logging.error(f"Inference worker error: {str(e)}", exc_info=True)
-            return {
-                "model_result": None,
-                "metadata": None,
-                "success": False,
-                "error": str(e)
-            }
+            self.logger.error(f"Inference execution error: {e}", exc_info=True)
+            return self._create_error_result(str(e))
+
+    def _extract_inference_params(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        """Extract and validate inference parameters from task data."""
+        input_stream_data = task_data.get("input_stream", {})
+        # Prefer decoded bytes if provided by upstream stages
+        input_bytes = task_data.get("decoded_input_bytes")
+        if not isinstance(input_bytes, (bytes, bytearray)):
+            content = input_stream_data.get("content")
+            if isinstance(content, str) and content:
+                try:
+                    input_bytes = base64.b64decode(content)
+                except Exception as e:
+                    self.logger.warning(f"Failed to decode base64 content for inference: {e}")
+                    input_bytes = None
+            elif isinstance(content, (bytes, bytearray)):
+                input_bytes = content
+            else:
+                input_bytes = None
+
+        extra_params = self._normalize_extra_params(task_data.get("extra_params", {}))
+
+        return {
+            "input": input_bytes,
+            "extra_params": extra_params,
+            "apply_post_processing": False,
+            "stream_key": task_data.get("stream_key"),
+            "stream_info": input_stream_data.get("stream_info", {}),
+            "camera_info": input_stream_data.get("camera_info", {})
+        }
+
+    def _decode_input_content(self, content: Any) -> Any:
+        """Decode base64 content if it's a string."""
+        if content and isinstance(content, str):
+            try:
+                return base64.b64decode(content)
+            except Exception as e:
+                self.logger.warning(f"Failed to decode base64 input: {e}")
+        return content
+
+    def _normalize_extra_params(self, extra_params: Any) -> Dict[str, Any]:
+        """Normalize extra_params to ensure it's a dictionary."""
+        if isinstance(extra_params, dict):
+            return extra_params
+        elif isinstance(extra_params, list):
+            return self._merge_list_params(extra_params)
+        else:
+            self.logger.warning(f"Invalid extra_params type {type(extra_params)}, using empty dict")
+            return {}
+
+    def _merge_list_params(self, params_list: list) -> Dict[str, Any]:
+        """Merge list of dictionaries into single dictionary."""
+        if not params_list:
+            return {}
+
+        if all(isinstance(item, dict) for item in params_list):
+            merged = {}
+            for item in params_list:
+                merged.update(item)
+            return merged
+
+        return {}
+
+    def _get_or_create_event_loop(self) -> asyncio.AbstractEventLoop:
+        """Get existing event loop or create a new one for this thread."""
+        try:
+            return asyncio.get_event_loop()
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            return loop
+
+    def _create_success_result(self, model_result: Any, metadata: Any) -> Dict[str, Any]:
+        """Create successful inference result."""
+        return {
+            "model_result": model_result,
+            "metadata": metadata,
+            "success": True,
+            "error": None
+        }
+
+    def _create_error_result(self, error_message: str) -> Dict[str, Any]:
+        """Create error inference result."""
+        return {
+            "model_result": None,
+            "metadata": None,
+            "success": False,
+            "error": error_message
+        }
 
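The second changed module is the InferenceWorker, which consumes (priority, timestamp, task_data) tuples from a PriorityQueue, runs the model through a ThreadPoolExecutor, and forwards successful results to a post-processing queue. The sketch below wires one up; the import path is an assumption, and StubInterface merely stands in for the real inference interface, of which the worker uses only the async inference(...) method.

import queue
import time
from concurrent.futures import ThreadPoolExecutor

from matrice_inference.inference_worker import InferenceWorker  # hypothetical import path


class StubInterface:
    """Stand-in for the real inference interface (assumed async API)."""

    async def inference(self, **kwargs):
        return {"detections": []}, {"latency_ms": 0.0}  # (model_result, metadata)


inference_queue: queue.PriorityQueue = queue.PriorityQueue()
postproc_queue: queue.PriorityQueue = queue.PriorityQueue()

worker = InferenceWorker(
    worker_id=0,
    inference_queue=inference_queue,
    postproc_queue=postproc_queue,
    inference_executor=ThreadPoolExecutor(max_workers=2),
    message_timeout=1.0,
    inference_timeout=30.0,
    inference_interface=StubInterface(),
)
thread = worker.start()  # start() now returns the Thread so callers can join it

# _validate_task_data requires these four keys before a task is dispatched
task_data = {
    "message": {"frame_id": "frame_0001"},
    "input_stream": {"content": "", "stream_info": {}, "camera_info": {}},
    "stream_key": "camera-1",
    "camera_config": {},
}
inference_queue.put((0, time.time(), task_data))

time.sleep(2.0)  # give the worker time to drain the queue
worker.stop()
thread.join(timeout=5.0)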