matrice-inference 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of matrice-inference might be problematic. Click here for more details.

Files changed (37)
  1. matrice_inference/__init__.py +72 -0
  2. matrice_inference/py.typed +0 -0
  3. matrice_inference/server/__init__.py +23 -0
  4. matrice_inference/server/inference_interface.py +176 -0
  5. matrice_inference/server/model/__init__.py +1 -0
  6. matrice_inference/server/model/model_manager.py +274 -0
  7. matrice_inference/server/model/model_manager_wrapper.py +550 -0
  8. matrice_inference/server/model/triton_model_manager.py +290 -0
  9. matrice_inference/server/model/triton_server.py +1248 -0
  10. matrice_inference/server/proxy_interface.py +371 -0
  11. matrice_inference/server/server.py +1004 -0
  12. matrice_inference/server/stream/__init__.py +0 -0
  13. matrice_inference/server/stream/app_deployment.py +228 -0
  14. matrice_inference/server/stream/consumer_worker.py +201 -0
  15. matrice_inference/server/stream/frame_cache.py +127 -0
  16. matrice_inference/server/stream/inference_worker.py +163 -0
  17. matrice_inference/server/stream/post_processing_worker.py +230 -0
  18. matrice_inference/server/stream/producer_worker.py +147 -0
  19. matrice_inference/server/stream/stream_pipeline.py +451 -0
  20. matrice_inference/server/stream/utils.py +23 -0
  21. matrice_inference/tmp/abstract_model_manager.py +58 -0
  22. matrice_inference/tmp/aggregator/__init__.py +18 -0
  23. matrice_inference/tmp/aggregator/aggregator.py +330 -0
  24. matrice_inference/tmp/aggregator/analytics.py +906 -0
  25. matrice_inference/tmp/aggregator/ingestor.py +438 -0
  26. matrice_inference/tmp/aggregator/latency.py +597 -0
  27. matrice_inference/tmp/aggregator/pipeline.py +968 -0
  28. matrice_inference/tmp/aggregator/publisher.py +431 -0
  29. matrice_inference/tmp/aggregator/synchronizer.py +594 -0
  30. matrice_inference/tmp/batch_manager.py +239 -0
  31. matrice_inference/tmp/overall_inference_testing.py +338 -0
  32. matrice_inference/tmp/triton_utils.py +638 -0
  33. matrice_inference-0.1.2.dist-info/METADATA +28 -0
  34. matrice_inference-0.1.2.dist-info/RECORD +37 -0
  35. matrice_inference-0.1.2.dist-info/WHEEL +5 -0
  36. matrice_inference-0.1.2.dist-info/licenses/LICENSE.txt +21 -0
  37. matrice_inference-0.1.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,371 @@
1
+ """Module providing proxy_interface functionality."""
2
+
3
+ import logging
4
+ import time
5
+ import threading
6
+ from datetime import datetime, timezone
7
+ from typing import Optional, Set
8
+ import httpx
9
+ import uvicorn
10
+ import asyncio
11
+
12
+ from fastapi import (
13
+ FastAPI,
14
+ HTTPException,
15
+ Request,
16
+ )
17
+ from fastapi.encoders import jsonable_encoder
18
+ from fastapi.responses import JSONResponse
19
+ from matrice_inference.server.inference_interface import InferenceInterface
20
+
21
+
22
class MatriceProxyInterface:
    """FastAPI-based proxy that forwards inference requests to a model server.

    Validates per-deployment auth keys (refreshed periodically over RPC) and
    runs its own uvicorn server on a background thread.
    """

    def __init__(
        self,
        session,
        deployment_id: str,
        deployment_instance_id: str,
        external_port: int,
        inference_interface: InferenceInterface,
        auth_refresh_interval_minutes: int = 1,
    ):
        """Initialize proxy server.

        Args:
            session: Session object for authentication and RPC.
            deployment_id: ID of the deployment.
            deployment_instance_id: ID of the deployment instance.
            external_port: Port to expose externally.
            inference_interface: Interface for model inference.
            auth_refresh_interval_minutes: Minimum minutes between auth key refreshes.
        """
        self.logger = logging.getLogger(__name__)
        self.session = session
        self.rpc = session.rpc
        self.deployment_id = deployment_id
        self.deployment_instance_id = deployment_instance_id
        self.external_port = external_port
        self.inference_interface = inference_interface
        self.app = FastAPI()

        # Server lifecycle state.
        self._shutdown_event = threading.Event()
        self._server = None
        self._server_thread = None

        # Auth key management state, guarded by _auth_lock.
        self.auth_keys: Set[str] = set()
        self.auth_keys_info = []
        self.auth_refresh_interval_minutes = auth_refresh_interval_minutes
        self.last_auth_refresh_time = 0.0
        self._auth_lock = threading.Lock()

        # Load auth keys once up front, then expose the HTTP routes.
        self._initialize_auth_keys()
        self._register_routes()
66
+
67
+ def _initialize_auth_keys(self):
68
+ """Initialize auth keys on startup."""
69
+ try:
70
+ self.update_auth_keys()
71
+ self.logger.info("Auth keys initialized successfully")
72
+ except Exception as exc:
73
+ self.logger.error("Failed to initialize auth keys: %s", str(exc))
74
+ # Continue without auth keys - they will be retried on first request
75
+
76
+ def _should_refresh_auth_keys(self) -> bool:
77
+ """Check if auth keys should be refreshed based on time interval."""
78
+ current_time = time.time()
79
+ time_since_last_refresh = current_time - self.last_auth_refresh_time
80
+ return time_since_last_refresh >= (self.auth_refresh_interval_minutes * 60)
81
+
82
+ def _refresh_auth_keys_if_needed(self):
83
+ """Refresh auth keys if the refresh interval has passed."""
84
+ if self._should_refresh_auth_keys():
85
+ with self._auth_lock:
86
+ # Double-check after acquiring lock
87
+ if self._should_refresh_auth_keys():
88
+ try:
89
+ self.update_auth_keys()
90
+ self.logger.debug("Auth keys refreshed successfully")
91
+ except Exception as exc:
92
+ self.logger.error("Failed to refresh auth keys: %s", str(exc))
93
+
94
+ def validate_auth_key(self, auth_key):
95
+ """Validate auth key.
96
+
97
+ Args:
98
+ auth_key: Authentication key to validate
99
+
100
+ Returns:
101
+ bool: True if valid, False otherwise
102
+ """
103
+ if not auth_key:
104
+ return False
105
+
106
+ # Refresh auth keys if needed before validation
107
+ self._refresh_auth_keys_if_needed()
108
+
109
+ with self._auth_lock:
110
+ return auth_key in self.auth_keys
111
+
112
+ def _parse_expiry_time(self, expiry_time_str: str) -> float:
113
+ """Parse expiry time string to timestamp."""
114
+ # Handle different ISO format variations
115
+ try:
116
+ # Replace Z with timezone if needed
117
+ if expiry_time_str.endswith("Z"):
118
+ expiry_time_str = expiry_time_str.replace("Z", "+00:00")
119
+
120
+ # Normalize ISO format for proper parsing
121
+ if '.' in expiry_time_str:
122
+ main, rest = expiry_time_str.split('.', 1)
123
+ if '+' in rest:
124
+ frac, tz = rest.split('+', 1)
125
+ frac = (frac + '000000')[:6] # pad/truncate to 6 digits
126
+ expiry_time_str = f"{main}.{frac}+{tz}"
127
+ elif '-' in rest:
128
+ frac, tz = rest.split('-', 1)
129
+ frac = (frac + '000000')[:6] # pad/truncate to 6 digits
130
+ expiry_time_str = f"{main}.{frac}-{tz}"
131
+ except Exception as err:
132
+ self.logger.error("Error parsing expiry time: %s", str(err))
133
+ expiry_time_str = expiry_time_str.replace("Z", "+00:00")
134
+ return datetime.fromisoformat(expiry_time_str).timestamp()
135
+
136
+ def update_auth_keys(self) -> None:
137
+ """Fetch and validate auth keys for the deployment."""
138
+ try:
139
+ response = self.rpc.get(f"/v1/inference/{self.deployment_id}", raise_exception=False)
140
+ if not response.get("success"):
141
+ self.logger.error("Failed to fetch auth keys")
142
+ return
143
+
144
+ if response["data"]["authKeys"]:
145
+ self.auth_keys_info = response["data"]["authKeys"]
146
+ else:
147
+ self.auth_keys_info = []
148
+
149
+ if not self.auth_keys_info:
150
+ self.logger.debug("No auth keys found for deployment")
151
+ return
152
+
153
+ current_time = time.time()
154
+ self.auth_keys.clear()
155
+
156
+ for auth_key_info in self.auth_keys_info:
157
+ try:
158
+ expiry_time = self._parse_expiry_time(auth_key_info["expiryTime"])
159
+ if expiry_time > current_time:
160
+ self.auth_keys.add(auth_key_info["key"])
161
+ else:
162
+ self.logger.debug("Skipping expired auth key")
163
+ except (ValueError, KeyError) as err:
164
+ self.logger.error("Invalid auth key data: %s", err)
165
+ continue
166
+
167
+ # Update last refresh time
168
+ self.last_auth_refresh_time = current_time
169
+
170
+ self.logger.debug(
171
+ "Successfully loaded %d valid auth keys",
172
+ len(self.auth_keys),
173
+ )
174
+ except Exception as err:
175
+ self.logger.error("Error fetching auth keys: %s", str(err))
176
+ raise
177
+
178
+ def _register_routes(self):
179
+ """Register proxy routes."""
180
+
181
+ @self.app.post("/inference")
182
+ async def proxy_request(request: Request):
183
+ # Check if server is shutting down
184
+ if self._shutdown_event.is_set():
185
+ raise HTTPException(
186
+ status_code=503,
187
+ detail="Server is shutting down",
188
+ )
189
+
190
+ # Parse form data manually
191
+ try:
192
+ form_data = await request.form()
193
+ except Exception as e:
194
+ raise HTTPException(
195
+ status_code=400,
196
+ detail=f"Failed to parse form data: {str(e)}",
197
+ ) from e
198
+
199
+ # Extract parameters from form data
200
+ auth_key = form_data.get("auth_key") or form_data.get("authKey")
201
+ input_file = form_data.get("input")
202
+ input_url_value = form_data.get("input_url") or form_data.get("inputUrl")
203
+ extra_params = form_data.get("extra_params")
204
+ apply_post_processing = form_data.get("apply_post_processing", "false")
205
+
206
+ # if not self.validate_auth_key(auth_key): # TODO: enable once fixed to send the external auth key for FR server
207
+ # raise HTTPException(
208
+ # status_code=401,
209
+ # detail="Invalid auth key",
210
+ # )
211
+
212
+ # Handle file input
213
+ input_data = None
214
+ if input_file and hasattr(input_file, 'read'):
215
+ input_data = await input_file.read()
216
+ elif isinstance(input_file, bytes):
217
+ input_data = input_file
218
+
219
+ if input_url_value:
220
+ try:
221
+ # Use timeout and error handling for URL downloads
222
+ timeout = httpx.Timeout(60.0) # 60 second timeout
223
+ async with httpx.AsyncClient(timeout=timeout) as client:
224
+ response = await client.get(input_url_value)
225
+ response.raise_for_status() # Raise exception for HTTP errors
226
+ input_data = response.content
227
+ except asyncio.CancelledError:
228
+ # Handle shutdown during request
229
+ raise HTTPException(
230
+ status_code=503,
231
+ detail="Request cancelled due to server shutdown",
232
+ )
233
+ except httpx.TimeoutException:
234
+ raise HTTPException(
235
+ status_code=408,
236
+ detail="Timeout fetching input URL",
237
+ )
238
+ except httpx.HTTPStatusError as e:
239
+ raise HTTPException(
240
+ status_code=400,
241
+ detail=f"HTTP error fetching input URL: {e.response.status_code}",
242
+ )
243
+ except Exception as e:
244
+ raise HTTPException(
245
+ status_code=400,
246
+ detail=f"Error fetching input URL: {str(e)}",
247
+ )
248
+
249
+ if not input_data:
250
+ raise HTTPException(
251
+ status_code=400,
252
+ detail="No input provided",
253
+ )
254
+
255
+ # Parse apply_post_processing parameter
256
+ apply_post_processing_flag = False
257
+ if apply_post_processing:
258
+ apply_post_processing_flag = apply_post_processing.lower() in ("true", "1", "yes")
259
+
260
+ try:
261
+ # Check shutdown again before inference
262
+ if self._shutdown_event.is_set():
263
+ raise HTTPException(
264
+ status_code=503,
265
+ detail="Server is shutting down",
266
+ )
267
+
268
+ result, post_processing_result = await self.inference_interface.inference(
269
+ input=input_data,
270
+ extra_params=extra_params,
271
+ apply_post_processing=apply_post_processing_flag
272
+ )
273
+
274
+
275
+ response_data = {
276
+ "status": 1,
277
+ "message": "Request success",
278
+ "result": result,
279
+ }
280
+
281
+ # Include post-processing results if available
282
+ if post_processing_result is not None:
283
+ response_data["post_processing_result"] = post_processing_result
284
+ response_data["post_processing_applied"] = True
285
+ else:
286
+ response_data["post_processing_applied"] = False
287
+
288
+ return JSONResponse(
289
+ content=jsonable_encoder(response_data)
290
+ )
291
+ except asyncio.CancelledError:
292
+ # Handle shutdown during inference
293
+ self.logger.info("Request cancelled during inference due to shutdown")
294
+ raise HTTPException(
295
+ status_code=503,
296
+ detail="Request cancelled due to server shutdown",
297
+ )
298
+ except Exception as exc:
299
+ self.logger.error("Proxy error: %s", str(exc))
300
+ raise HTTPException(
301
+ status_code=500,
302
+ detail=str(exc),
303
+ ) from exc
304
+
305
+
306
+ def start(self):
307
+ """Start the proxy server in a background thread."""
308
+ def run_server():
309
+ """Run the uvicorn server."""
310
+ try:
311
+ self.logger.info(
312
+ "Starting proxy server on port %d",
313
+ self.external_port,
314
+ )
315
+ config = uvicorn.Config(
316
+ app=self.app,
317
+ host="0.0.0.0",
318
+ port=self.external_port,
319
+ log_level="info",
320
+ )
321
+ self._server = uvicorn.Server(config)
322
+ self._server.run()
323
+
324
+ except Exception as exc:
325
+ if not self._shutdown_event.is_set():
326
+ self.logger.error(
327
+ "Failed to start proxy server: %s",
328
+ str(exc),
329
+ )
330
+ else:
331
+ self.logger.info("Proxy server stopped during shutdown")
332
+
333
+ # Start the server in a background thread
334
+ self._server_thread = threading.Thread(target=run_server, daemon=False, name="ProxyServer")
335
+ self._server_thread.start()
336
+
337
+ # Wait a moment for the server to start
338
+ time.sleep(0.5)
339
+ self.logger.info("Proxy server thread started successfully")
340
+
341
+ def stop(self):
342
+ """Stop the proxy server gracefully."""
343
+ try:
344
+ self.logger.info("Stopping proxy server...")
345
+
346
+ # Signal shutdown to prevent new requests
347
+ self._shutdown_event.set()
348
+
349
+ # Stop the uvicorn server if it exists
350
+ if self._server:
351
+ try:
352
+ # Force shutdown the server
353
+ if hasattr(self._server, 'should_exit'):
354
+ self._server.should_exit = True
355
+ if hasattr(self._server, 'force_exit'):
356
+ self._server.force_exit = True
357
+ except Exception as exc:
358
+ self.logger.warning("Error stopping uvicorn server: %s", str(exc))
359
+
360
+ # Wait for the server thread to finish
361
+ if self._server_thread and self._server_thread.is_alive():
362
+ self.logger.info("Waiting for proxy server thread to stop...")
363
+ self._server_thread.join(timeout=5.0)
364
+ if self._server_thread.is_alive():
365
+ self.logger.warning("Proxy server thread did not stop within timeout")
366
+ else:
367
+ self.logger.info("Proxy server thread stopped successfully")
368
+
369
+ self.logger.info("Proxy server stopped")
370
+ except Exception as exc:
371
+ self.logger.error("Error stopping proxy server: %s", str(exc))