matrice-inference 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of matrice-inference might be problematic. Click here for more details.
- matrice_inference/__init__.py +72 -0
- matrice_inference/py.typed +0 -0
- matrice_inference/server/__init__.py +23 -0
- matrice_inference/server/inference_interface.py +176 -0
- matrice_inference/server/model/__init__.py +1 -0
- matrice_inference/server/model/model_manager.py +274 -0
- matrice_inference/server/model/model_manager_wrapper.py +550 -0
- matrice_inference/server/model/triton_model_manager.py +290 -0
- matrice_inference/server/model/triton_server.py +1248 -0
- matrice_inference/server/proxy_interface.py +371 -0
- matrice_inference/server/server.py +1004 -0
- matrice_inference/server/stream/__init__.py +0 -0
- matrice_inference/server/stream/app_deployment.py +228 -0
- matrice_inference/server/stream/consumer_worker.py +201 -0
- matrice_inference/server/stream/frame_cache.py +127 -0
- matrice_inference/server/stream/inference_worker.py +163 -0
- matrice_inference/server/stream/post_processing_worker.py +230 -0
- matrice_inference/server/stream/producer_worker.py +147 -0
- matrice_inference/server/stream/stream_pipeline.py +451 -0
- matrice_inference/server/stream/utils.py +23 -0
- matrice_inference/tmp/abstract_model_manager.py +58 -0
- matrice_inference/tmp/aggregator/__init__.py +18 -0
- matrice_inference/tmp/aggregator/aggregator.py +330 -0
- matrice_inference/tmp/aggregator/analytics.py +906 -0
- matrice_inference/tmp/aggregator/ingestor.py +438 -0
- matrice_inference/tmp/aggregator/latency.py +597 -0
- matrice_inference/tmp/aggregator/pipeline.py +968 -0
- matrice_inference/tmp/aggregator/publisher.py +431 -0
- matrice_inference/tmp/aggregator/synchronizer.py +594 -0
- matrice_inference/tmp/batch_manager.py +239 -0
- matrice_inference/tmp/overall_inference_testing.py +338 -0
- matrice_inference/tmp/triton_utils.py +638 -0
- matrice_inference-0.1.2.dist-info/METADATA +28 -0
- matrice_inference-0.1.2.dist-info/RECORD +37 -0
- matrice_inference-0.1.2.dist-info/WHEEL +5 -0
- matrice_inference-0.1.2.dist-info/licenses/LICENSE.txt +21 -0
- matrice_inference-0.1.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
"""Module providing proxy_interface functionality."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
import threading
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from typing import Optional, Set
|
|
8
|
+
import httpx
|
|
9
|
+
import uvicorn
|
|
10
|
+
import asyncio
|
|
11
|
+
|
|
12
|
+
from fastapi import (
|
|
13
|
+
FastAPI,
|
|
14
|
+
HTTPException,
|
|
15
|
+
Request,
|
|
16
|
+
)
|
|
17
|
+
from fastapi.encoders import jsonable_encoder
|
|
18
|
+
from fastapi.responses import JSONResponse
|
|
19
|
+
from matrice_inference.server.inference_interface import InferenceInterface
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MatriceProxyInterface:
    """Interface for proxying requests to model servers.

    Exposes a single ``POST /inference`` FastAPI route on ``external_port``,
    validates per-deployment auth keys fetched over RPC, and forwards the
    request payload to the wrapped :class:`InferenceInterface`.
    """

    def __init__(
        self,
        session,
        deployment_id: str,
        deployment_instance_id: str,
        external_port: int,
        inference_interface: InferenceInterface,
        auth_refresh_interval_minutes: int = 1,
    ):
        """Initialize proxy server.

        Args:
            session: Session object for authentication and RPC
            deployment_id: ID of the deployment
            deployment_instance_id: ID of the deployment instance
            external_port: Port to expose externally
            inference_interface: Interface for model inference
            auth_refresh_interval_minutes: Minimum minutes between auth key refreshes
        """
        self.session = session
        self.rpc = self.session.rpc
        self.deployment_id = deployment_id
        self.deployment_instance_id = deployment_instance_id
        self.external_port = external_port
        self.app = FastAPI()
        self.logger = logging.getLogger(__name__)
        self.inference_interface = inference_interface
        # Set when stop() is called; checked by the route handler so that
        # in-flight and new requests are rejected with 503 during shutdown.
        self._shutdown_event = threading.Event()
        self._server = None          # uvicorn.Server, created in start()
        self._server_thread = None   # threading.Thread running the server

        # Auth key management
        self.auth_keys: Set[str] = set()         # currently-valid keys
        self.auth_keys_info = []                 # raw key records from the backend
        self.auth_refresh_interval_minutes = auth_refresh_interval_minutes
        self.last_auth_refresh_time = 0.0        # epoch seconds of last successful fetch
        self._auth_lock = threading.Lock()       # guards refresh + membership checks

        # Initialize auth keys on startup
        self._initialize_auth_keys()
        self._register_routes()

    def _initialize_auth_keys(self):
        """Fetch auth keys once at startup; tolerate failure (retried lazily)."""
        try:
            self.update_auth_keys()
            self.logger.info("Auth keys initialized successfully")
        except Exception as exc:
            self.logger.error("Failed to initialize auth keys: %s", str(exc))
            # Continue without auth keys - they will be retried on first request

    def _should_refresh_auth_keys(self) -> bool:
        """Check if auth keys should be refreshed based on time interval."""
        current_time = time.time()
        time_since_last_refresh = current_time - self.last_auth_refresh_time
        return time_since_last_refresh >= (self.auth_refresh_interval_minutes * 60)

    def _refresh_auth_keys_if_needed(self):
        """Refresh auth keys if the refresh interval has passed.

        Uses double-checked locking so only one thread performs the fetch
        while concurrent callers skip straight through.
        """
        if self._should_refresh_auth_keys():
            with self._auth_lock:
                # Double-check after acquiring lock
                if self._should_refresh_auth_keys():
                    try:
                        self.update_auth_keys()
                        self.logger.debug("Auth keys refreshed successfully")
                    except Exception as exc:
                        self.logger.error("Failed to refresh auth keys: %s", str(exc))

    def validate_auth_key(self, auth_key):
        """Validate auth key.

        Args:
            auth_key: Authentication key to validate

        Returns:
            bool: True if valid, False otherwise
        """
        if not auth_key:
            return False

        # Refresh auth keys if needed before validation
        self._refresh_auth_keys_if_needed()

        with self._auth_lock:
            return auth_key in self.auth_keys

    def _parse_expiry_time(self, expiry_time_str: str) -> float:
        """Parse an ISO-8601 expiry time string to a POSIX timestamp.

        Normalizes a trailing ``Z`` to ``+00:00`` and pads/truncates the
        fractional-seconds field to 6 digits so that
        :func:`datetime.fromisoformat` (strict before Python 3.11) accepts it.

        Raises:
            ValueError: if the string is not parseable after normalization.
        """
        # Handle different ISO format variations
        try:
            # Replace Z with timezone if needed
            if expiry_time_str.endswith("Z"):
                expiry_time_str = expiry_time_str.replace("Z", "+00:00")

            # Normalize ISO format for proper parsing
            if '.' in expiry_time_str:
                main, rest = expiry_time_str.split('.', 1)
                if '+' in rest:
                    frac, tz = rest.split('+', 1)
                    frac = (frac + '000000')[:6]  # pad/truncate to 6 digits
                    expiry_time_str = f"{main}.{frac}+{tz}"
                elif '-' in rest:
                    frac, tz = rest.split('-', 1)
                    frac = (frac + '000000')[:6]  # pad/truncate to 6 digits
                    expiry_time_str = f"{main}.{frac}-{tz}"
        except Exception as err:
            self.logger.error("Error parsing expiry time: %s", str(err))
            expiry_time_str = expiry_time_str.replace("Z", "+00:00")
        return datetime.fromisoformat(expiry_time_str).timestamp()

    def update_auth_keys(self) -> None:
        """Fetch and validate auth keys for the deployment.

        On any successful backend response the refresh timestamp is updated
        (even when the deployment has zero keys) so we do not re-query the
        backend on every incoming request. ``self.auth_keys`` is replaced
        atomically so concurrent readers never see a half-populated set.

        Raises:
            Exception: re-raised after logging on unexpected errors.
        """
        try:
            response = self.rpc.get(f"/v1/inference/{self.deployment_id}", raise_exception=False)
            if not response.get("success"):
                # Leave last_auth_refresh_time unchanged so the next request retries.
                self.logger.error("Failed to fetch auth keys")
                return

            # Guard against a null/missing "data" payload in addition to a
            # null "authKeys" list.
            self.auth_keys_info = (response.get("data") or {}).get("authKeys") or []

            # Successful fetch: record refresh time before any early return,
            # otherwise an empty key list would trigger a refetch per request.
            current_time = time.time()
            self.last_auth_refresh_time = current_time

            if not self.auth_keys_info:
                self.logger.debug("No auth keys found for deployment")
                # Drop previously-cached keys: an empty backend response means
                # every formerly-known key has been revoked.
                self.auth_keys = set()
                return

            valid_keys: Set[str] = set()
            for auth_key_info in self.auth_keys_info:
                try:
                    expiry_time = self._parse_expiry_time(auth_key_info["expiryTime"])
                    if expiry_time > current_time:
                        valid_keys.add(auth_key_info["key"])
                    else:
                        self.logger.debug("Skipping expired auth key")
                except (ValueError, KeyError) as err:
                    self.logger.error("Invalid auth key data: %s", err)
                    continue

            # Atomic swap instead of clear()+add: readers holding only a
            # reference never observe an intermediate state.
            self.auth_keys = valid_keys

            self.logger.debug(
                "Successfully loaded %d valid auth keys",
                len(self.auth_keys),
            )
        except Exception as err:
            self.logger.error("Error fetching auth keys: %s", str(err))
            raise

    def _register_routes(self):
        """Register proxy routes."""

        @self.app.post("/inference")
        async def proxy_request(request: Request):
            # Check if server is shutting down
            if self._shutdown_event.is_set():
                raise HTTPException(
                    status_code=503,
                    detail="Server is shutting down",
                )

            # Parse form data manually
            try:
                form_data = await request.form()
            except Exception as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Failed to parse form data: {str(e)}",
                ) from e

            # Extract parameters from form data (both snake_case and
            # camelCase field names are accepted for compatibility).
            auth_key = form_data.get("auth_key") or form_data.get("authKey")
            input_file = form_data.get("input")
            input_url_value = form_data.get("input_url") or form_data.get("inputUrl")
            extra_params = form_data.get("extra_params")
            apply_post_processing = form_data.get("apply_post_processing", "false")

            # if not self.validate_auth_key(auth_key): # TODO: enable once fixed to send the external auth key for FR server
            #     raise HTTPException(
            #         status_code=401,
            #         detail="Invalid auth key",
            #     )

            # Handle file input: an UploadFile exposes an async read();
            # a raw bytes field is used verbatim.
            input_data = None
            if input_file and hasattr(input_file, 'read'):
                input_data = await input_file.read()
            elif isinstance(input_file, bytes):
                input_data = input_file

            # A URL input (if given) takes precedence over the file field.
            if input_url_value:
                try:
                    # Use timeout and error handling for URL downloads
                    timeout = httpx.Timeout(60.0)  # 60 second timeout
                    async with httpx.AsyncClient(timeout=timeout) as client:
                        response = await client.get(input_url_value)
                        response.raise_for_status()  # Raise exception for HTTP errors
                        input_data = response.content
                except asyncio.CancelledError:
                    # Handle shutdown during request
                    raise HTTPException(
                        status_code=503,
                        detail="Request cancelled due to server shutdown",
                    )
                except httpx.TimeoutException:
                    raise HTTPException(
                        status_code=408,
                        detail="Timeout fetching input URL",
                    )
                except httpx.HTTPStatusError as e:
                    raise HTTPException(
                        status_code=400,
                        detail=f"HTTP error fetching input URL: {e.response.status_code}",
                    )
                except Exception as e:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Error fetching input URL: {str(e)}",
                    )

            if not input_data:
                raise HTTPException(
                    status_code=400,
                    detail="No input provided",
                )

            # Parse apply_post_processing parameter. Multipart form fields may
            # arrive as non-string parts (e.g. an uploaded file), so guard with
            # isinstance before calling .lower() to avoid an AttributeError.
            apply_post_processing_flag = (
                isinstance(apply_post_processing, str)
                and apply_post_processing.lower() in ("true", "1", "yes")
            )

            try:
                # Check shutdown again before inference
                if self._shutdown_event.is_set():
                    raise HTTPException(
                        status_code=503,
                        detail="Server is shutting down",
                    )

                result, post_processing_result = await self.inference_interface.inference(
                    input=input_data,
                    extra_params=extra_params,
                    apply_post_processing=apply_post_processing_flag
                )

                response_data = {
                    "status": 1,
                    "message": "Request success",
                    "result": result,
                }

                # Include post-processing results if available
                if post_processing_result is not None:
                    response_data["post_processing_result"] = post_processing_result
                    response_data["post_processing_applied"] = True
                else:
                    response_data["post_processing_applied"] = False

                return JSONResponse(
                    content=jsonable_encoder(response_data)
                )
            except asyncio.CancelledError:
                # Handle shutdown during inference
                self.logger.info("Request cancelled during inference due to shutdown")
                raise HTTPException(
                    status_code=503,
                    detail="Request cancelled due to server shutdown",
                )
            except Exception as exc:
                self.logger.error("Proxy error: %s", str(exc))
                raise HTTPException(
                    status_code=500,
                    detail=str(exc),
                ) from exc

    def start(self):
        """Start the proxy server in a background thread."""
        def run_server():
            """Run the uvicorn server."""
            try:
                self.logger.info(
                    "Starting proxy server on port %d",
                    self.external_port,
                )
                config = uvicorn.Config(
                    app=self.app,
                    host="0.0.0.0",
                    port=self.external_port,
                    log_level="info",
                )
                self._server = uvicorn.Server(config)
                self._server.run()  # blocks until should_exit/force_exit

            except Exception as exc:
                # During shutdown a server error is expected noise, not a failure.
                if not self._shutdown_event.is_set():
                    self.logger.error(
                        "Failed to start proxy server: %s",
                        str(exc),
                    )
                else:
                    self.logger.info("Proxy server stopped during shutdown")

        # Start the server in a background thread (non-daemon so stop() can
        # join it for a clean shutdown).
        self._server_thread = threading.Thread(target=run_server, daemon=False, name="ProxyServer")
        self._server_thread.start()

        # Wait a moment for the server to start
        time.sleep(0.5)
        self.logger.info("Proxy server thread started successfully")

    def stop(self):
        """Stop the proxy server gracefully."""
        try:
            self.logger.info("Stopping proxy server...")

            # Signal shutdown to prevent new requests
            self._shutdown_event.set()

            # Stop the uvicorn server if it exists
            if self._server:
                try:
                    # Force shutdown the server; hasattr-guarded for
                    # compatibility across uvicorn versions.
                    if hasattr(self._server, 'should_exit'):
                        self._server.should_exit = True
                    if hasattr(self._server, 'force_exit'):
                        self._server.force_exit = True
                except Exception as exc:
                    self.logger.warning("Error stopping uvicorn server: %s", str(exc))

            # Wait for the server thread to finish
            if self._server_thread and self._server_thread.is_alive():
                self.logger.info("Waiting for proxy server thread to stop...")
                self._server_thread.join(timeout=5.0)
                if self._server_thread.is_alive():
                    self.logger.warning("Proxy server thread did not stop within timeout")
                else:
                    self.logger.info("Proxy server thread stopped successfully")

            self.logger.info("Proxy server stopped")
        except Exception as exc:
            self.logger.error("Error stopping proxy server: %s", str(exc))