matrice_inference-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of matrice-inference might be problematic.

Files changed (37)
  1. matrice_inference/__init__.py +72 -0
  2. matrice_inference/py.typed +0 -0
  3. matrice_inference/server/__init__.py +23 -0
  4. matrice_inference/server/inference_interface.py +176 -0
  5. matrice_inference/server/model/__init__.py +1 -0
  6. matrice_inference/server/model/model_manager.py +274 -0
  7. matrice_inference/server/model/model_manager_wrapper.py +550 -0
  8. matrice_inference/server/model/triton_model_manager.py +290 -0
  9. matrice_inference/server/model/triton_server.py +1248 -0
  10. matrice_inference/server/proxy_interface.py +371 -0
  11. matrice_inference/server/server.py +1004 -0
  12. matrice_inference/server/stream/__init__.py +0 -0
  13. matrice_inference/server/stream/app_deployment.py +228 -0
  14. matrice_inference/server/stream/consumer_worker.py +201 -0
  15. matrice_inference/server/stream/frame_cache.py +127 -0
  16. matrice_inference/server/stream/inference_worker.py +163 -0
  17. matrice_inference/server/stream/post_processing_worker.py +230 -0
  18. matrice_inference/server/stream/producer_worker.py +147 -0
  19. matrice_inference/server/stream/stream_pipeline.py +451 -0
  20. matrice_inference/server/stream/utils.py +23 -0
  21. matrice_inference/tmp/abstract_model_manager.py +58 -0
  22. matrice_inference/tmp/aggregator/__init__.py +18 -0
  23. matrice_inference/tmp/aggregator/aggregator.py +330 -0
  24. matrice_inference/tmp/aggregator/analytics.py +906 -0
  25. matrice_inference/tmp/aggregator/ingestor.py +438 -0
  26. matrice_inference/tmp/aggregator/latency.py +597 -0
  27. matrice_inference/tmp/aggregator/pipeline.py +968 -0
  28. matrice_inference/tmp/aggregator/publisher.py +431 -0
  29. matrice_inference/tmp/aggregator/synchronizer.py +594 -0
  30. matrice_inference/tmp/batch_manager.py +239 -0
  31. matrice_inference/tmp/overall_inference_testing.py +338 -0
  32. matrice_inference/tmp/triton_utils.py +638 -0
  33. matrice_inference-0.1.2.dist-info/METADATA +28 -0
  34. matrice_inference-0.1.2.dist-info/RECORD +37 -0
  35. matrice_inference-0.1.2.dist-info/WHEEL +5 -0
  36. matrice_inference-0.1.2.dist-info/licenses/LICENSE.txt +21 -0
  37. matrice_inference-0.1.2.dist-info/top_level.txt +1 -0
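
For orientation only, the listing implies the importable layout sketched below. These paths mirror the import statements that appear later in this diff and are shown purely as an illustration; they are not part of the wheel's content, and only names that appear verbatim in the diff are used.

# Illustrative import paths inferred from the file listing and the diff body below.
from matrice_inference.server.model.model_manager import ModelManager
from matrice_inference.server.model.triton_model_manager import TritonModelManager
from matrice_inference.server.model.model_manager_wrapper import ModelManagerWrapper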
matrice_inference/server/model/model_manager_wrapper.py
@@ -0,0 +1,550 @@
+ from typing import Tuple, Dict, Any, Optional, List, Union, Callable
+ import logging
+ from matrice.action_tracker import ActionTracker
+ from matrice_inference.server.model.model_manager import ModelManager
+ from matrice_inference.server.model.triton_model_manager import TritonModelManager
+ import numpy as np
+
+ class ModelManagerWrapper:
+     """Wrapper class for ModelManager and TritonModelManager to provide a unified interface."""
+
+     def __init__(
+         self,
+         action_tracker: ActionTracker,
+         test_env: bool = False,
+         # "default" for ModelManager, "triton" for TritonModelManager
+         model_type: str = "default",
+         model_id: Optional[str] = None,
+         internal_server_type: Optional[str] = None,
+         internal_port: Optional[int] = None,
+         internal_host: Optional[str] = None,
+         num_model_instances: Optional[int] = None,
+         load_model: Optional[Callable] = None,
+         predict: Optional[Callable] = None,
+         batch_predict: Optional[Callable] = None,
+         model_name: Optional[str] = None,
+         model_path: Optional[str] = None,
+         runtime_framework: Optional[str] = None,
+         input_size: Optional[Union[int, List[int]]] = None,
+         num_classes: Optional[int] = None,
+         use_dynamic_batching: Optional[bool] = None,
+         max_batch_size: Optional[int] = None,
+         is_yolo: Optional[bool] = None,
+         is_ocr: Optional[bool] = None,
+         use_trt_accelerator: Optional[bool] = None,
+     ):
+         """
+         Initialize the ModelManagerWrapper.
+
+         Args:
+             action_tracker: Action tracker for category mapping and configuration.
+             test_env: If True, use provided parameters for testing; if False, extract from action_tracker.
+             model_type: Type of model manager ("default" for ModelManager, "triton" for TritonModelManager).
+             model_id: ID of the model (for ModelManager).
+             internal_server_type: Type of internal server (e.g., "rest", "grpc").
+             internal_port: Internal port number.
+             internal_host: Internal host address.
+             num_model_instances: Number of model instances to create.
+             load_model: Function to load the model (for ModelManager).
+             predict: Function to run predictions (for ModelManager).
+             batch_predict: Function to run batch predictions (for ModelManager).
+             model_name: Name of the model (for TritonModelManager).
+             model_path: Path to the model (for TritonModelManager).
+             runtime_framework: Runtime framework for the model (for TritonModelManager).
+             input_size: Input size for the model (for TritonModelManager).
+             num_classes: Number of classes for the model (for TritonModelManager).
+             use_dynamic_batching: Whether to use dynamic batching (for TritonModelManager).
+             max_batch_size: Maximum batch size (for TritonModelManager).
+             is_yolo: Whether the model is YOLO (for TritonModelManager).
+             is_ocr: Whether the model is OCR (for TritonModelManager).
+             use_trt_accelerator: Whether to use TensorRT accelerator (for TritonModelManager).
+         """
+         self.logger = logging.getLogger(__name__)
+         self.action_tracker = action_tracker
+         self.test_env = test_env
+         self.model_type = model_type.lower() if model_type else "default"
+
+         # Validate model_type
+         if self.model_type not in ["default", "triton"]:
+             raise ValueError(f"Invalid model_type '{self.model_type}'. Must be 'default' or 'triton'")
+
+         # Default configuration for production: TODO
+         default_config = {
+             "model_id": "",
+             "internal_server_type": "rest",
+             "internal_port": 8000,
+             "internal_host": "localhost",
+             "num_model_instances": 1,
+             "load_model": None,
+             "predict": None,
+             "batch_predict": None,
+             "model_name": "",
+             "model_path": "",
+             "runtime_framework": "onnx",
+             "input_size": 640,
+             "num_classes": 80,
+             "use_dynamic_batching": False,
+             "max_batch_size": 8,
+             "is_yolo": False,
+             "is_ocr": False,
+             "use_trt_accelerator": False,
+         }
+
+         # Extract configuration from action_tracker for production
+         if not test_env:
+             config = self._extract_config_from_action_tracker(
+                 model_type=model_type,
+                 model_id=model_id,
+                 internal_server_type=internal_server_type,
+                 internal_port=internal_port,
+                 internal_host=internal_host,
+                 num_model_instances=num_model_instances,
+                 load_model=load_model,
+                 predict=predict,
+                 batch_predict=batch_predict,
+                 model_name=model_name,
+                 model_path=model_path,
+                 runtime_framework=runtime_framework,
+                 input_size=input_size,
+                 num_classes=num_classes,
+                 use_dynamic_batching=use_dynamic_batching,
+                 max_batch_size=max_batch_size,
+                 is_yolo=is_yolo,
+                 is_ocr=is_ocr,
+                 use_trt_accelerator=use_trt_accelerator,
+             )
+             if not config:
+                 self.logger.warning("No valid configuration found in action_tracker, using defaults")
+                 config = default_config
+             else:
+                 for key, value in default_config.items():
+                     if key not in config or config[key] is None:
+                         self.logger.warning(f"Missing or None config key '{key}' in action_tracker, using default: {value}")
+                         config[key] = value
+         else:
+             # User provided args for testing
+             config = {
+                 "model_id": model_id,
+                 "internal_server_type": internal_server_type,
+                 "internal_port": internal_port,
+                 "internal_host": internal_host,
+                 "num_model_instances": num_model_instances,
+                 "load_model": load_model,
+                 "predict": predict,
+                 "batch_predict": batch_predict,
+                 "model_name": model_name,
+                 "model_path": model_path,
+                 "runtime_framework": runtime_framework,
+                 "input_size": input_size,
+                 "num_classes": num_classes,
+                 "use_dynamic_batching": use_dynamic_batching,
+                 "max_batch_size": max_batch_size,
+                 "is_yolo": is_yolo,
+                 "is_ocr": is_ocr,
+                 "use_trt_accelerator": use_trt_accelerator,
+             }
+             for key, value in default_config.items():
+                 if config[key] is None:
+                     self.logger.warning(f"Missing or None config key '{key}' in test environment, using default: {value}")
+                     config[key] = value
+
+         if self.model_type == "triton":
+             # Validate required parameters for TritonModelManager
+             required_triton_params = ["model_name", "model_path", "runtime_framework"]
+             for param in required_triton_params:
+                 if not config.get(param):
+                     raise ValueError(f"Required parameter '{param}' is missing or invalid for Triton model manager")
+
+             self.model_manager = TritonModelManager(
+                 model_name=config["model_name"],
+                 model_path=config["model_path"],
+                 runtime_framework=config["runtime_framework"],
+                 internal_server_type=config["internal_server_type"],
+                 internal_port=config["internal_port"],
+                 internal_host=config["internal_host"],
+                 input_size=config["input_size"],
+                 num_classes=config["num_classes"],
+                 num_model_instances=config["num_model_instances"],
+                 use_dynamic_batching=config["use_dynamic_batching"],
+                 max_batch_size=config["max_batch_size"],
+                 is_yolo=config["is_yolo"],
+                 is_ocr=config["is_ocr"],
+                 use_trt_accelerator=config["use_trt_accelerator"],
+             )
+         else:
+             # Validate required parameters for ModelManager
+             if not self.action_tracker:
+                 raise ValueError("action_tracker is required for default ModelManager")
+
+             # Warn if neither a loader nor a prediction function was provided
+             if not config.get("predict") and not config.get("load_model"):
+                 self.logger.warning("Neither 'load_model' nor 'predict' was provided for ModelManager; both should be supplied.")
+
+             self.model_manager = ModelManager(
+                 action_tracker=self.action_tracker,
+                 load_model=config["load_model"],
+                 predict=config["predict"],
+                 batch_predict=config.get("batch_predict"),
+                 num_model_instances=config["num_model_instances"],
+                 model_path=config.get("model_path"),
+             )
+
+         self.logger.info(f"Initialized ModelManagerWrapper with {self.model_type} model manager")
+
+     def _extract_config_from_action_tracker(
+         self,
+         model_type: Optional[str] = None,
+         model_id: Optional[str] = None,
+         internal_server_type: Optional[str] = None,
+         internal_port: Optional[int] = None,
+         internal_host: Optional[str] = None,
+         num_model_instances: Optional[int] = None,
+         load_model: Optional[Callable] = None,
+         predict: Optional[Callable] = None,
+         batch_predict: Optional[Callable] = None,
+         model_name: Optional[str] = None,
+         model_path: Optional[str] = None,
+         runtime_framework: Optional[str] = None,
+         input_size: Optional[Union[int, List[int]]] = None,
+         num_classes: Optional[int] = None,
+         use_dynamic_batching: Optional[bool] = None,
+         max_batch_size: Optional[int] = None,
+         is_yolo: Optional[bool] = None,
+         is_ocr: Optional[bool] = None,
+         use_trt_accelerator: Optional[bool] = None,
+     ) -> Dict[str, Any]:
+         """
+         Extract configuration from action_tracker for production use.
+
+         Prioritizes configuration from action_tracker, then user-provided arguments, and finally
+         defaults. Logs warnings when falling back to user arguments or defaults.
+
+         Args:
+             model_type: User-provided model type.
+             model_id: User-provided model ID.
+             internal_server_type: User-provided server type.
+             internal_port: User-provided port.
+             internal_host: User-provided host.
+             num_model_instances: User-provided number of model instances.
+             load_model: User-provided load model function.
+             predict: User-provided predict function.
+             batch_predict: User-provided batch predict function.
+             model_name: User-provided model name.
+             model_path: User-provided model path.
+             runtime_framework: User-provided runtime framework.
+             input_size: User-provided input size.
+             num_classes: User-provided number of classes.
+             use_dynamic_batching: User-provided dynamic batching flag.
+             max_batch_size: User-provided max batch size.
+             is_yolo: User-provided YOLO flag.
+             is_ocr: User-provided OCR flag.
+             use_trt_accelerator: User-provided TensorRT accelerator flag.
+
+         Returns:
+             Configuration dictionary extracted from action_tracker, user arguments, or defaults.
+         """
+         try:
+             config = {}
+             # Default configuration
+             default_config = {
+                 "model_id": "default_model",
+                 "internal_server_type": "rest",
+                 "internal_port": 8000,
+                 "internal_host": "localhost",
+                 "num_model_instances": 1,
+                 "load_model": None,
+                 "predict": None,
+                 "batch_predict": None,
+                 "model_name": "default_model",
+                 "model_path": "/models/default",
+                 "runtime_framework": "onnx",
+                 "input_size": 640,
+                 "num_classes": 10,
+                 "use_dynamic_batching": False,
+                 "max_batch_size": 8,
+                 "is_yolo": False,
+                 "is_ocr": False,
+                 "use_trt_accelerator": False,
+             }
+
+             job_params = getattr(self.action_tracker, "job_params", {})
+             action_details = getattr(self.action_tracker, "action_details", {})
+             action_tracker_sources: List[Dict[str, Any]] = [action_details, job_params]
+             # priority: action_tracker > user_arg > default
+             def get_param(
+                 key: str,
+                 action_tracker_sources: List[Any],
+                 user_value: Optional[Any],
+                 default_value: Any,
+             ) -> Any:
+                 value = None
+                 source = None
+                 # Try action_tracker sources first
+                 for src in action_tracker_sources:
+                     if isinstance(src, dict):
+                         value = src.get(key)
+                     else:
+                         value = src
+                     if value is not None:
+                         source = "action_tracker"
+                         break
+
+                 # If not found in action_tracker, try the user-provided argument
+                 if value is None and user_value is not None:
+                     value = user_value
+                     source = "user-provided"
+
+                 # Use default as fallback
+                 if value is None:
+                     value = default_value
+                     source = "default"
+
+                 if source != "action_tracker":
+                     self.logger.warning(
+                         f"Config key '{key}' not found in action_tracker, using {source} value: {value}"
+                     )
+                 return value
+
+             # Common params for both ModelManager and TritonModelManager
+             config["model_id"] = get_param(
+                 "model_id",
+                 [
+                     getattr(self.action_tracker, "_idModel_str", None),
+                     action_details.get("_idModelDeploy"),
+                     action_details.get("_idModel"),
+                 ],
+                 model_id,
+                 default_config["model_id"],
+             )
+
+             config["internal_server_type"] = get_param(
+                 "internal_server_type",
+                 [action_details.get("server_type", "rest").lower()],
+                 internal_server_type,
+                 default_config["internal_server_type"],
+             )
+
+             # Map protocol to default port
+             protocol2port = {"rest": 8000, "grpc": 8001}
+             action_server_type = action_details.get("server_type", "rest").lower()
+             config["internal_port"] = get_param(
+                 "internal_port",
+                 [protocol2port.get(action_server_type, 8000)],
+                 internal_port,
+                 protocol2port.get(config["internal_server_type"], default_config["internal_port"]),
+             )
+
+             config["internal_host"] = get_param(
+                 "internal_host",
+                 [action_details.get("host")],
+                 internal_host,
+                 default_config["internal_host"],
+             )
+
+             config["num_model_instances"] = get_param(
+                 "num_model_instances",
+                 [action_details.get("num_model_instances")],
+                 num_model_instances,
+                 default_config["num_model_instances"],
+             )
+
+             # ModelManager-specific parameters
+             config["load_model"] = load_model
+             config["predict"] = predict
+             config["batch_predict"] = batch_predict
+
+             if self.model_type != "triton":
+                 return config
+
+             # TritonModelManager-specific parameters
+             config["model_name"] = get_param(
+                 "model_name",
+                 [action_details.get("modelKey")],
+                 model_name,
+                 default_config["model_name"],
+             )
+
+             config["model_path"] = get_param(
+                 "model_path",
+                 [getattr(self.action_tracker, "checkpoint_path", None)],
+                 model_path,
+                 default_config["model_path"],
+             )
+
+             config["runtime_framework"] = get_param(
+                 "runtime_framework",
+                 [action_details.get("runtimeFramework"), action_details.get("exportFormat")],
+                 runtime_framework,
+                 default_config["runtime_framework"],
+             )
+
+             config["input_size"] = get_param(
+                 "input_size",
+                 # [self.action_tracker.get_input_size()], TODO: Enable after the API is working
+                 action_tracker_sources,
+                 input_size,
+                 default_config["input_size"],
+             )
+
+             index_to_category = self.action_tracker.get_index_to_category(self.action_tracker.is_exported)
+             num_classes_action = len(index_to_category) if index_to_category else None
+             config["num_classes"] = get_param(
+                 "num_classes",
+                 [num_classes_action],
+                 num_classes,
+                 default_config["num_classes"],
+             )
+
+             config["use_dynamic_batching"] = get_param(
+                 "use_dynamic_batching",
+                 [],
+                 use_dynamic_batching,
+                 default_config["use_dynamic_batching"],
+             )
+
+             config["max_batch_size"] = get_param(
+                 "max_batch_size",
+                 [],
+                 max_batch_size,
+                 default_config["max_batch_size"],
+             )
+
+             config["is_yolo"] = get_param(
+                 "is_yolo",
+                 [],
+                 is_yolo,
+                 default_config["is_yolo"],
+             )
+
+             config["is_ocr"] = get_param(
+                 "is_ocr",
+                 [],
+                 is_ocr,
+                 default_config["is_ocr"],
+             )
+
+             config["use_trt_accelerator"] = get_param(
+                 "use_trt_accelerator",
+                 [action_details.get("use_trt_accelerator"), job_params.get("use_trt_accelerator")],
+                 use_trt_accelerator,
+                 default_config["use_trt_accelerator"],
+             )
+
+             return config
+         except Exception as e:
+             self.logger.error(f"Failed to extract config from action_tracker: {str(e)}", exc_info=True)
+             return {}
+
+     def inference(
+         self,
+         input: Any,
+         extra_params: Optional[Dict[str, Any]] = None,
+         stream_key: Optional[str] = None,
+         stream_info: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[Any, bool]:
+         """
+         Perform synchronous single inference.
+         """
+         if input is None:
+             raise ValueError("input cannot be None")
+
+         try:
+             if self.model_type == "triton":
+                 # TritonModelManager only accepts input parameter
+                 return self.model_manager.inference(input=input)
+             else:
+                 # ModelManager accepts additional parameters
+                 return self.model_manager.inference(
+                     input=input,
+                     extra_params=extra_params,
+                     stream_key=stream_key,
+                     stream_info=stream_info,
+                 )
+         except Exception as e:
+             self.logger.error(f"Inference failed in ModelManagerWrapper: {str(e)}", exc_info=True)
+             return None, False
+
+     async def async_inference(
+         self,
+         input: Union[bytes, np.ndarray],
+         extra_params: Optional[Dict[str, Any]] = None,
+         stream_key: Optional[str] = None,
+         stream_info: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[Any, bool]:
+         """
+         Perform asynchronous single inference.
+         """
+         if input is None:
+             raise ValueError("input cannot be None")
+
+         try:
+             if self.model_type == "triton":
+                 return await self.model_manager.async_inference(input=input)
+             else:
+                 # ModelManager doesn't have async_inference, fall back to sync
+                 return self.model_manager.inference(
+                     input=input,
+                     extra_params=extra_params,
+                     stream_key=stream_key,
+                     stream_info=stream_info,
+                 )
+         except Exception as e:
+             self.logger.error(f"Async inference failed in ModelManagerWrapper: {str(e)}", exc_info=True)
+             return None, False
+
+     def batch_inference(
+         self,
+         input: List[Any],
+         extra_params: Optional[Dict[str, Any]] = None,
+         stream_key: Optional[str] = None,
+         stream_info: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[List[Any], bool]:
+         """
+         Perform synchronous batch inference.
+         """
+         if not input:
+             raise ValueError("input cannot be None or empty")
+
+         try:
+             if self.model_type == "triton":
+                 # TritonModelManager only accepts input parameter
+                 return self.model_manager.batch_inference(input=input)
+             else:
+                 # ModelManager accepts additional parameters
+                 return self.model_manager.batch_inference(
+                     input=input,
+                     extra_params=extra_params,
+                     stream_key=stream_key,
+                     stream_info=stream_info,
+                 )
+         except Exception as e:
+             self.logger.error(f"Batch inference failed in ModelManagerWrapper: {str(e)}", exc_info=True)
+             return [], False
+
+     async def async_batch_inference(
+         self,
+         input: List[Any],
+         extra_params: Optional[Dict[str, Any]] = None,
+         stream_key: Optional[str] = None,
+         stream_info: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[List[Any], bool]:
+         """
+         Perform asynchronous batch inference.
+         """
+         if not input:
+             raise ValueError("input cannot be None or empty")
+
+         try:
+             if self.model_type == "triton":
+                 return await self.model_manager.async_batch_inference(input=input)
+             else:
+                 # ModelManager doesn't have async_batch_inference, fall back to sync
+                 return self.model_manager.batch_inference(
+                     input=input,
+                     extra_params=extra_params,
+                     stream_key=stream_key,
+                     stream_info=stream_info,
+                 )
+         except Exception as e:
+             self.logger.error(f"Async batch inference failed in ModelManagerWrapper: {str(e)}", exc_info=True)
+             return [], False
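
As a closing orientation note, not part of the released diff above, the wrapper could be exercised roughly as in the minimal sketch below. It assumes test_env=True with user-supplied callables on the default ModelManager path; the load_model/predict signatures and the ActionTracker constructor arguments are assumptions, since neither is shown in this diff.

# Minimal usage sketch, assuming the wheel is installed.
# Everything marked "assumed" or "placeholder" is not confirmed by the diff above.
import numpy as np
from matrice.action_tracker import ActionTracker
from matrice_inference.server.model.model_manager_wrapper import ModelManagerWrapper

def load_model(model_path=None):                  # assumed signature for a user-supplied loader
    return object()                               # stand-in for a real model handle

def predict(model, frame, extra_params=None):     # assumed signature for a user-supplied predictor
    return {"detections": []}

wrapper = ModelManagerWrapper(
    action_tracker=ActionTracker(),               # placeholder; real constructor args are not shown in this diff
    test_env=True,                                 # take configuration from these arguments, not from action_tracker
    model_type="default",                          # route to ModelManager rather than TritonModelManager
    load_model=load_model,
    predict=predict,
    num_model_instances=1,
)

# inference() is typed as returning a (result, success) pair; on error it returns (None, False).
result, ok = wrapper.inference(input=np.zeros((640, 640, 3), dtype=np.uint8))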