nv-ingest-api 26.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nv-ingest-api has been flagged as potentially problematic; see the package's registry page for more details.
- nv_ingest_api/__init__.py +3 -0
- nv_ingest_api/interface/__init__.py +218 -0
- nv_ingest_api/interface/extract.py +977 -0
- nv_ingest_api/interface/mutate.py +154 -0
- nv_ingest_api/interface/store.py +200 -0
- nv_ingest_api/interface/transform.py +382 -0
- nv_ingest_api/interface/utility.py +186 -0
- nv_ingest_api/internal/__init__.py +0 -0
- nv_ingest_api/internal/enums/__init__.py +3 -0
- nv_ingest_api/internal/enums/common.py +550 -0
- nv_ingest_api/internal/extract/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
- nv_ingest_api/internal/extract/docx/__init__.py +5 -0
- nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
- nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
- nv_ingest_api/internal/extract/html/__init__.py +3 -0
- nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
- nv_ingest_api/internal/extract/image/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
- nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
- nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
- nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
- nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
- nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
- nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
- nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
- nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
- nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
- nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
- nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
- nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
- nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
- nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
- nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
- nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
- nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
- nv_ingest_api/internal/meta/__init__.py +3 -0
- nv_ingest_api/internal/meta/udf.py +232 -0
- nv_ingest_api/internal/mutate/__init__.py +3 -0
- nv_ingest_api/internal/mutate/deduplicate.py +110 -0
- nv_ingest_api/internal/mutate/filter.py +133 -0
- nv_ingest_api/internal/primitives/__init__.py +0 -0
- nv_ingest_api/internal/primitives/control_message_task.py +16 -0
- nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
- nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
- nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
- nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
- nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
- nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
- nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
- nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
- nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
- nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
- nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
- nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
- nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
- nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
- nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
- nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
- nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
- nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
- nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
- nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
- nv_ingest_api/internal/schemas/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
- nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
- nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
- nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
- nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
- nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
- nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
- nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
- nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
- nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
- nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
- nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
- nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
- nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
- nv_ingest_api/internal/schemas/meta/udf.py +23 -0
- nv_ingest_api/internal/schemas/mixins.py +39 -0
- nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
- nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
- nv_ingest_api/internal/schemas/store/__init__.py +3 -0
- nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
- nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
- nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
- nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
- nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
- nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
- nv_ingest_api/internal/store/__init__.py +3 -0
- nv_ingest_api/internal/store/embed_text_upload.py +236 -0
- nv_ingest_api/internal/store/image_upload.py +251 -0
- nv_ingest_api/internal/transform/__init__.py +3 -0
- nv_ingest_api/internal/transform/caption_image.py +219 -0
- nv_ingest_api/internal/transform/embed_text.py +702 -0
- nv_ingest_api/internal/transform/split_text.py +182 -0
- nv_ingest_api/util/__init__.py +3 -0
- nv_ingest_api/util/control_message/__init__.py +0 -0
- nv_ingest_api/util/control_message/validators.py +47 -0
- nv_ingest_api/util/converters/__init__.py +0 -0
- nv_ingest_api/util/converters/bytetools.py +78 -0
- nv_ingest_api/util/converters/containers.py +65 -0
- nv_ingest_api/util/converters/datetools.py +90 -0
- nv_ingest_api/util/converters/dftools.py +127 -0
- nv_ingest_api/util/converters/formats.py +64 -0
- nv_ingest_api/util/converters/type_mappings.py +27 -0
- nv_ingest_api/util/dataloader/__init__.py +9 -0
- nv_ingest_api/util/dataloader/dataloader.py +409 -0
- nv_ingest_api/util/detectors/__init__.py +5 -0
- nv_ingest_api/util/detectors/language.py +38 -0
- nv_ingest_api/util/exception_handlers/__init__.py +0 -0
- nv_ingest_api/util/exception_handlers/converters.py +72 -0
- nv_ingest_api/util/exception_handlers/decorators.py +429 -0
- nv_ingest_api/util/exception_handlers/detectors.py +74 -0
- nv_ingest_api/util/exception_handlers/pdf.py +116 -0
- nv_ingest_api/util/exception_handlers/schemas.py +68 -0
- nv_ingest_api/util/image_processing/__init__.py +5 -0
- nv_ingest_api/util/image_processing/clustering.py +260 -0
- nv_ingest_api/util/image_processing/processing.py +177 -0
- nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
- nv_ingest_api/util/image_processing/transforms.py +850 -0
- nv_ingest_api/util/imports/__init__.py +3 -0
- nv_ingest_api/util/imports/callable_signatures.py +108 -0
- nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
- nv_ingest_api/util/introspection/__init__.py +3 -0
- nv_ingest_api/util/introspection/class_inspect.py +145 -0
- nv_ingest_api/util/introspection/function_inspect.py +65 -0
- nv_ingest_api/util/logging/__init__.py +0 -0
- nv_ingest_api/util/logging/configuration.py +102 -0
- nv_ingest_api/util/logging/sanitize.py +84 -0
- nv_ingest_api/util/message_brokers/__init__.py +3 -0
- nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
- nv_ingest_api/util/metadata/__init__.py +5 -0
- nv_ingest_api/util/metadata/aggregators.py +516 -0
- nv_ingest_api/util/multi_processing/__init__.py +8 -0
- nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
- nv_ingest_api/util/nim/__init__.py +161 -0
- nv_ingest_api/util/pdf/__init__.py +3 -0
- nv_ingest_api/util/pdf/pdfium.py +428 -0
- nv_ingest_api/util/schema/__init__.py +3 -0
- nv_ingest_api/util/schema/schema_validator.py +10 -0
- nv_ingest_api/util/service_clients/__init__.py +3 -0
- nv_ingest_api/util/service_clients/client_base.py +86 -0
- nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
- nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
- nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
- nv_ingest_api/util/string_processing/__init__.py +51 -0
- nv_ingest_api/util/string_processing/configuration.py +682 -0
- nv_ingest_api/util/string_processing/yaml.py +109 -0
- nv_ingest_api/util/system/__init__.py +0 -0
- nv_ingest_api/util/system/hardware_info.py +594 -0
- nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
- nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
- nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
- nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
- nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
- udfs/__init__.py +5 -0
- udfs/llm_summarizer_udf.py +259 -0
|
@@ -0,0 +1,801 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
import re
|
|
9
|
+
import threading
|
|
10
|
+
import time
|
|
11
|
+
import queue
|
|
12
|
+
from collections import namedtuple
|
|
13
|
+
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
|
14
|
+
from typing import Any
|
|
15
|
+
from typing import Optional
|
|
16
|
+
from typing import Tuple, Union
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import requests
|
|
20
|
+
import tritonclient.grpc as grpcclient
|
|
21
|
+
|
|
22
|
+
from nv_ingest_api.internal.primitives.tracing.tagging import traceable_func
|
|
23
|
+
from nv_ingest_api.util.string_processing import generate_url
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)

# Case-insensitive pattern used to recognise CUDA-related failures in the message
# text of Triton gRPC INTERNAL errors (illegal memory access, failed copies/loads,
# model reload notices, ...).
CUDA_ERROR_REGEX = re.compile(
    r"(model reload|illegal memory access|illegal instruction|invalid argument|failed to (copy|load|perform) .*: .*|TritonModelException: failed to copy data: .*)",  # noqa: E501
    flags=re.IGNORECASE,
)

# Lightweight record pairing one request's payload and routing information with the
# Future that will eventually receive its result.
InferenceRequest = namedtuple("InferenceRequest", "data future model_name dims kwargs")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class NimClient:
|
|
39
|
+
"""
|
|
40
|
+
A client for interfacing with a model inference server using gRPC or HTTP protocols.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
def __init__(
    self,
    model_interface,
    protocol: str,
    endpoints: Tuple[str, str],
    auth_token: Optional[str] = None,
    timeout: float = 120.0,
    max_retries: int = 10,
    max_429_retries: int = 5,
    enable_dynamic_batching: bool = False,
    dynamic_batch_timeout: float = 0.1,  # 100 milliseconds
    dynamic_batch_memory_budget_mb: Optional[float] = None,
):
    """
    Set up a client bound to one model interface and one transport protocol.

    Parameters
    ----------
    model_interface : ModelInterface
        Object that prepares inputs for, and parses outputs from, the target model.
    protocol : str
        Transport to use, ``"grpc"`` or ``"http"`` (compared case-insensitively).
    endpoints : tuple
        ``(grpc_endpoint, http_endpoint)``; only the one matching ``protocol`` must be set.
    auth_token : str, optional
        Bearer token placed in the HTTP ``Authorization`` header (default: None).
    timeout : float, optional
        Timeout in seconds for HTTP requests; also passed to model reloads on the
        gRPC retry path (default: 120.0).
    max_retries : int, optional
        Maximum number of retries for non-429 server-side errors (default: 10).
    max_429_retries : int, optional
        Maximum number of retries for 429 / queue-full errors (default: 5).
    enable_dynamic_batching : bool, optional
        When True, also create the request queue, stop event and batcher thread.
        The thread is not launched here; call ``start()`` (default: False).
    dynamic_batch_timeout : float, optional
        Batching timeout in seconds, stored only when dynamic batching is enabled
        (default: 0.1).
    dynamic_batch_memory_budget_mb : float, optional
        Optional memory budget in MiB, converted to bytes (``* 1024 * 1024``) when
        dynamic batching is enabled (default: None, meaning no budget).

    Raises
    ------
    ValueError
        If ``protocol`` is not ``"grpc"``/``"http"``, or its endpoint is missing/empty.
    """
    self.client = None
    self.model_interface = model_interface
    self.protocol = protocol.lower()
    self.auth_token = auth_token
    self.timeout = timeout  # Timeout for HTTP requests
    self.max_retries = max_retries
    self.max_429_retries = max_429_retries
    self._grpc_endpoint, self._http_endpoint = endpoints
    self._max_batch_sizes = {}  # model name -> max batch size, filled lazily
    self._lock = threading.Lock()  # serialises population of _max_batch_sizes

    # Reject unknown transports up front; the two branches below are then exhaustive.
    if self.protocol not in ("grpc", "http"):
        raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    if self.protocol == "grpc":
        if not self._grpc_endpoint:
            raise ValueError("gRPC endpoint must be provided for gRPC protocol")
        logger.debug(f"Creating gRPC client with {self._grpc_endpoint}")
        self.client = grpcclient.InferenceServerClient(url=self._grpc_endpoint)
    else:
        if not self._http_endpoint:
            raise ValueError("HTTP endpoint must be provided for HTTP protocol")
        logger.debug(f"Creating HTTP client with {self._http_endpoint}")
        self.endpoint_url = generate_url(self._http_endpoint)
        request_headers = {"accept": "application/json", "content-type": "application/json"}
        if self.auth_token:
            request_headers["Authorization"] = f"Bearer {self.auth_token}"
        self.headers = request_headers

    self.dynamic_batching_enabled = enable_dynamic_batching
    if not enable_dynamic_batching:
        return

    self._batch_timeout = dynamic_batch_timeout
    self._batch_memory_budget_bytes = (
        None if dynamic_batch_memory_budget_mb is None else dynamic_batch_memory_budget_mb * 1024 * 1024
    )
    self._request_queue = queue.Queue()
    self._stop_event = threading.Event()
    # Constructed here, launched later by start().
    self._batcher_thread = threading.Thread(target=self._batcher_loop, daemon=True)
|
|
119
|
+
|
|
120
|
+
def start(self):
    """
    Launch the background batcher thread.

    Does nothing when dynamic batching is disabled or when the thread is already alive.
    """
    if not self.dynamic_batching_enabled:
        return
    if self._batcher_thread.is_alive():
        return
    self._batcher_thread.start()
|
|
124
|
+
|
|
125
|
+
def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
    """
    Return the maximum batch size for ``model_name``, querying Triton at most once per model.

    ``"yolox_ensemble"`` is looked up (and cached) under ``"yolox"``. The value is cached
    in ``self._max_batch_sizes``.

    If there is no gRPC endpoint or client, 1 is cached instead of querying; on the HTTP
    protocol the client is left as None. If the config query raises, 1 is also cached
    and a warning is logged.

    Cache population happens under ``self._lock`` with a re-check, so concurrent callers
    for the same model issue a single query.

    Parameters
    ----------
    model_name : str
        Model to look up.
    model_version : str, optional
        Version forwarded to ``get_model_config`` (default: "").

    Returns
    -------
    int
        The ``max_batch_size`` reported by the model config, stored as-is, or 1 on fallback.
    """
    lookup_name = "yolox" if model_name == "yolox_ensemble" else model_name

    # Fast path: no locking once the value is known.
    if lookup_name in self._max_batch_sizes:
        return self._max_batch_sizes[lookup_name]

    with self._lock:
        # Another thread may have filled the entry while we waited for the lock.
        if lookup_name in self._max_batch_sizes:
            return self._max_batch_sizes[lookup_name]

        if self._grpc_endpoint and self.client:
            try:
                config_response = self.client.get_model_config(model_name=lookup_name, model_version=model_version)
                self._max_batch_sizes[lookup_name] = config_response.config.max_batch_size
                logger.debug(f"Max batch size for model '{lookup_name}': {self._max_batch_sizes[lookup_name]}")
            except Exception as e:
                self._max_batch_sizes[lookup_name] = 1
                logger.warning(f"Failed to retrieve max batch size: {e}, defaulting to 1")
        else:
            self._max_batch_sizes[lookup_name] = 1

        return self._max_batch_sizes[lookup_name]
|
|
152
|
+
|
|
153
|
+
def _process_batch(self, batch_input, *, batch_data, model_name, **kwargs):
    """
    Run inference on one pre-formatted batch and parse the raw response.

    Parameters
    ----------
    batch_input : Any
        One batch as produced by ``model_interface.format_input``.
    batch_data : Any
        The scratch-pad data ``format_input`` returned for this same batch.
    model_name : str
        Name of the model to query.
    kwargs : dict
        Extra options forwarded to ``_grpc_infer`` (gRPC only) and to ``parse_output``.

    Returns
    -------
    tuple
        ``(parsed_output, batch_data)``, so each batch can later be post-processed with its own data.

    Raises
    ------
    ValueError
        If ``self.protocol`` is neither ``"grpc"`` nor ``"http"``.
    """
    use_grpc = self.protocol == "grpc"
    if not use_grpc and self.protocol != "http":
        raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    if use_grpc:
        logger.debug("Performing gRPC inference for a batch...")
        raw_response = self._grpc_infer(batch_input, model_name, **kwargs)
        logger.debug("gRPC inference received response for a batch")
    else:
        # The HTTP transport takes only the payload; kwargs are not forwarded to it.
        logger.debug("Performing HTTP inference for a batch...")
        raw_response = self._http_infer(batch_input)
        logger.debug("HTTP inference received response for a batch")

    parsed = self.model_interface.parse_output(
        raw_response,
        protocol=self.protocol,
        data=batch_data,
        model_name=model_name,
        **kwargs,
    )
    return parsed, batch_data
|
|
188
|
+
|
|
189
|
+
def try_set_max_batch_size(self, model_name, model_version: str = ""):
    """
    Populate the per-model max-batch-size cache without returning the value.

    Delegates to ``_fetch_max_batch_size``, which does the locked lookup. The fetched
    size is deliberately discarded.
    """
    _ = self._fetch_max_batch_size(model_name, model_version)
    return None
|
|
192
|
+
|
|
193
|
+
@traceable_func(trace_name="{stage_name}::{model_name}")
def infer(self, data: dict, model_name: str, **kwargs) -> Any:
    """
    Perform inference using the specified model and input data.

    Parameters
    ----------
    data : dict
        The input data for inference.
    model_name : str
        The model name.
    kwargs : dict
        Additional parameters for inference. These keys are handled here:

        - ``max_batch_size``: requested upper bound on the batch size. It is removed
          from kwargs and clamped to ``[1, model max]`` unless ``force_max_batch_size``
          is truthy, in which case it is used verbatim.
        - ``force_max_batch_size``: see above; removed from kwargs (default False).
        - ``max_pool_workers``: thread-pool size for the offline path (default 16).
          It is removed from kwargs only after ``format_input`` has been called.

        All remaining keys are forwarded to the model interface and the transport.

    Returns
    -------
    Any
        The processed inference results, coalesced in the same order as the input images.

    Raises
    ------
    RuntimeError
        If any step of inference fails; the original exception is chained as ``__cause__``.
    """
    # 1. Retrieve or default to the model's maximum batch size.
    batch_size = self._fetch_max_batch_size(model_name)
    max_requested_batch_size = kwargs.pop("max_batch_size", batch_size)
    force_requested_batch_size = kwargs.pop("force_max_batch_size", False)
    max_batch_size = (
        max(1, min(batch_size, max_requested_batch_size))
        if not force_requested_batch_size
        else max_requested_batch_size
    )
    self._batch_size = max_batch_size

    if self.dynamic_batching_enabled:
        # DYNAMIC BATCHING PATH: each image is submitted individually via submit();
        # presumably the background batcher thread groups them into server batches.
        try:
            data = self.model_interface.prepare_data_for_inference(data)

            futures = []
            for base64_image, image_array in zip(data["base64_images"], data["images"]):
                # First two dimensions of the decoded array; presumably (height, width) — confirm.
                dims = image_array.shape[:2]
                futures.append(self.submit(base64_image, model_name, dims, **kwargs))

            # Block until every image has a result; list order follows submission order.
            return [future.result() for future in futures]

        except Exception as err:
            error_str = f"Error during synchronous infer with dynamic batching [{self.model_interface.name()}]: {err}"
            logger.error(error_str)
            raise RuntimeError(error_str) from err

    # OFFLINE BATCHING PATH
    try:
        # 2. Prepare data for inference.
        data = self.model_interface.prepare_data_for_inference(data)

        # 3. Format the input based on protocol.
        formatted_batches, formatted_batch_data = self.model_interface.format_input(
            data,
            protocol=self.protocol,
            max_batch_size=max_batch_size,
            model_name=model_name,
            **kwargs,
        )

        # Popped after format_input (which therefore still receives it) so that it is
        # not forwarded to _process_batch / process_inference_results.
        max_pool_workers = kwargs.pop("max_pool_workers", 16)

        # 4. Process each batch concurrently using a thread pool.
        # Batches are enumerated so results can be reassembled in input order.
        results = [None] * len(formatted_batches)
        with ThreadPoolExecutor(max_workers=max_pool_workers) as executor:
            future_to_idx = {}
            for idx, (batch, batch_data) in enumerate(zip(formatted_batches, formatted_batch_data)):
                future = executor.submit(
                    self._process_batch, batch, batch_data=batch_data, model_name=model_name, **kwargs
                )
                future_to_idx[future] = idx

            try:
                for future in as_completed(future_to_idx):
                    results[future_to_idx[future]] = future.result()
            except BaseException:
                # Cancel batches that have not started yet. Otherwise leaving the `with`
                # block would wait for every remaining batch (and all of its retries and
                # backoff sleeps) before the first failure is reported.
                for pending in future_to_idx:
                    pending.cancel()
                raise

        # 5. Process the parsed outputs for each batch using its corresponding batch_data.
        # The batches are in input order, so their outputs are coalesced in that order.
        all_results = []
        for parsed_output, batch_data in results:
            batch_results = self.model_interface.process_inference_results(
                parsed_output,
                original_image_shapes=batch_data.get("original_image_shapes"),
                protocol=self.protocol,
                **kwargs,
            )
            if isinstance(batch_results, list):
                all_results.extend(batch_results)
            else:
                all_results.append(batch_results)

    except Exception as err:
        error_str = f"Error during NimClient inference [{self.model_interface.name()}, {self.protocol}]: {err}"
        logger.error(error_str)
        raise RuntimeError(error_str) from err

    return all_results
|
|
297
|
+
|
|
298
|
+
def _grpc_infer(
    self, formatted_input: Union[list, list[np.ndarray]], model_name: str, **kwargs
) -> Union[list, list[np.ndarray]]:
    """
    Perform inference using the gRPC protocol, retrying selected transient server errors.

    Parameters
    ----------
    formatted_input : np.ndarray or list of np.ndarray
        One input array, or a list with one array per model input. A non-list value is
        wrapped into a single-element list.
    model_name : str
        The name of the model to use for inference.
    kwargs : dict
        Optional request settings:

        - ``parameters``: request parameters passed to ``client.infer`` (default ``{}``).
        - ``input_names``: input tensor names, matched positionally (default ``["input"]``).
        - ``dtypes``: Triton datatypes, matched positionally (default ``["FP32"]``).
        - ``output_names``: outputs to request (default ``["output"]``).

    Returns
    -------
    np.ndarray or list of np.ndarray
        A single array when exactly one output is requested. Otherwise, one array per
        entry of ``output_names``, in that order.

    Raises
    ------
    grpcclient.InferenceServerException
        For non-retryable gRPC errors, or once the retry budget for a retryable one
        (INTERNAL, NOT_FOUND, UNAVAILABLE queue-full) is exhausted.
    Exception
        Any other error is logged and re-raised unchanged.
    """
    if not isinstance(formatted_input, list):
        formatted_input = [formatted_input]

    parameters = kwargs.get("parameters", {})
    output_names = kwargs.get("output_names", ["output"])
    dtypes = kwargs.get("dtypes", ["FP32"])
    input_names = kwargs.get("input_names", ["input"])

    input_tensors = []
    for input_name, input_data, dtype in zip(input_names, formatted_input, dtypes):
        input_tensors.append(grpcclient.InferInput(input_name, input_data.shape, datatype=dtype))

    for idx, input_data in enumerate(formatted_input):
        input_tensors[idx].set_data_from_numpy(input_data)

    outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]

    base_delay = 2.0
    attempt = 0
    retries_429 = 0
    max_grpc_retries = self.max_429_retries
    # Always make at least one request. With max_retries <= 0 the loop would otherwise
    # never run and the method would silently return None. The per-branch
    # `attempt >= self.max_retries - 1` checks then stop after that single attempt.
    max_attempts = max(1, self.max_retries)

    while attempt < max_attempts:
        try:
            response = self.client.infer(
                model_name=model_name, parameters=parameters, inputs=input_tensors, outputs=outputs
            )

            # Lazy %-formatting: the response repr is only built if debug logging is enabled.
            logger.debug("gRPC inference response: %s", response)

            if len(outputs) == 1:
                return response.as_numpy(outputs[0].name())
            return [response.as_numpy(output.name()) for output in outputs]

        except grpcclient.InferenceServerException as e:
            status = str(e.status())
            # NOTE(review): message() may be None depending on how the exception was built;
            # normalise so the regex search and slicing below cannot raise TypeError.
            message = e.message() or ""

            # Handle CUDA memory errors
            if status == "StatusCode.INTERNAL":
                if CUDA_ERROR_REGEX.search(message):
                    logger.warning(
                        f"Received gRPC INTERNAL error with CUDA-related message for model '{model_name}'. "
                        f"Attempt {attempt + 1} of {self.max_retries}. Message (truncated): {message[:500]}"
                    )
                    if attempt >= self.max_retries - 1:
                        logger.error(f"Max retries exceeded for CUDA errors on model '{model_name}'.")
                        raise
                    # Try to reload models before retrying
                    model_reload_succeeded = reload_models(client=self.client, client_timeout=self.timeout)
                    if not model_reload_succeeded:
                        logger.error(f"Failed to reload models for model '{model_name}'.")
                else:
                    logger.warning(
                        f"Received gRPC INTERNAL error for model '{model_name}'. "
                        f"Attempt {attempt + 1} of {self.max_retries}. Message (truncated): {message[:500]}"
                    )
                    if attempt >= self.max_retries - 1:
                        logger.error(f"Max retries exceeded for INTERNAL error on model '{model_name}'.")
                        raise

                # Common retry logic for both CUDA and non-CUDA INTERNAL errors
                backoff_time = base_delay * (2**attempt)
                time.sleep(backoff_time)
                attempt += 1
                continue

            # Handle errors that can occur after model reload (NOT_FOUND, model not loaded)
            if status == "StatusCode.NOT_FOUND":
                logger.warning(
                    f"Received gRPC {status} error for model '{model_name}'. "
                    f"Attempt {attempt + 1} of {self.max_retries}. Message: {message[:500]}"
                )
                if attempt >= self.max_retries - 1:
                    logger.error(f"Max retries exceeded for model not found errors on model '{model_name}'.")
                    raise

                # Retry with exponential backoff WITHOUT reloading
                backoff_time = base_delay * (2**attempt)
                logger.info(
                    f"Retrying after {backoff_time}s backoff for model not found error on model '{model_name}'."
                )
                time.sleep(backoff_time)
                attempt += 1
                continue

            # Server queue full: counted against the separate 429-style budget, not `attempt`.
            if status == "StatusCode.UNAVAILABLE" and "exceeds maximum queue size" in message.lower():
                retries_429 += 1
                logger.warning(
                    f"Received gRPC {status} for model '{model_name}'. "
                    f"Attempt {retries_429} of {max_grpc_retries}."
                )
                if retries_429 >= max_grpc_retries:
                    logger.error(f"Max retries for gRPC {status} exceeded for model '{model_name}'.")
                    raise

                backoff_time = base_delay * (2**retries_429)
                time.sleep(backoff_time)
                continue

            # For other server-side errors (e.g., INVALID_ARGUMENT, etc.),
            # fail fast as retrying will not help
            logger.error(
                f"Received non-retryable gRPC error {status} from Triton for model '{model_name}': {message}"
            )
            raise

        except Exception as e:
            # Catch any other unexpected exceptions (e.g., network issues not caught by Triton client)
            logger.error(f"An unexpected error occurred during gRPC inference for model '{model_name}': {e}")
            raise
|
|
428
|
+
|
|
429
|
+
def _http_infer(self, formatted_input: dict) -> dict:
    """
    Perform inference using the HTTP protocol, retrying transient failures with exponential backoff.

    Parameters
    ----------
    formatted_input : dict
        The input data formatted as a dictionary; sent as the JSON request body.

    Returns
    -------
    dict
        The output of the model as a dictionary (the parsed JSON response body).

    Raises
    ------
    TimeoutError
        If the HTTP request times out on every one of ``self.max_retries`` attempts.
    requests.HTTPError
        For a non-retryable 4xx response, or once the 429 or 5xx retries are exhausted.
    requests.RequestException
        For other HTTP-related errors that persist after max retries.
    Exception
        If ``self.max_retries`` is less than 1, so no request is attempted.

    Notes
    -----
    * HTTP 429 (Too Many Requests): the request is retried after ``2.0 * 2**n`` seconds, where
      ``n`` is the number of 429 responses received so far. These retries do not consume the
      ``self.max_retries`` budget. Once ``self.max_429_retries`` 429 responses have been
      received, the ``HTTPError`` is raised.
    * HTTP 5xx, timeouts, and other ``requests`` exceptions (other than ``HTTPError``): the
      request is retried after ``2.0 * 2**attempt`` seconds, for at most ``self.max_retries``
      attempts in total.
    * Every other status goes through ``raise_for_status()``. A 4xx raises immediately; anything
      below 400 returns the parsed JSON body.
    """

    base_delay = 2.0
    attempt = 0
    retries_429 = 0

    while attempt < self.max_retries:
        try:
            # Log system prompt for VLM requests
            if isinstance(formatted_input, dict) and "messages" in formatted_input:
                messages = formatted_input.get("messages", [])
                if messages and messages[0].get("role") == "system":
                    system_content = messages[0].get("content", "")
                    model_name = self.model_interface.name()
                    logger.debug(f"{model_name}: Sending HTTP request with system prompt: '{system_content}'")

            response = requests.post(
                self.endpoint_url, json=formatted_input, headers=self.headers, timeout=self.timeout
            )
            status_code = response.status_code

            if status_code == 429:
                # Rate limited: track on a separate counter so throttling does not use up the
                # attempt budget reserved for genuine failures.
                retries_429 += 1
                logger.warning(
                    f"Received HTTP 429 (Too Many Requests) from {self.model_interface.name()}. "
                    f"Attempt {retries_429} of {self.max_429_retries}."
                )
                if retries_429 >= self.max_429_retries:
                    logger.error("Max retries for HTTP 429 exceeded.")
                    response.raise_for_status()  # raises HTTPError, handled below
                time.sleep(base_delay * (2**retries_429))
                continue  # Retry without incrementing the main attempt counter

            if 500 <= status_code < 600:
                # Server-side error (503 included): retry with exponential backoff.
                logger.warning(
                    f"Received HTTP {status_code} ({response.reason}) from "
                    f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
                )
                if attempt == self.max_retries - 1:
                    # No more retries left
                    logger.error(f"Max retries exceeded after receiving HTTP {status_code}.")
                    response.raise_for_status()  # raises HTTPError, handled below
                time.sleep(base_delay * (2**attempt))
                attempt += 1
                continue

            # Not in a retry category: 4xx raises here, anything else is a usable response.
            response.raise_for_status()
            result = response.json()  # parse once; reused for the debug log and the return value
            logger.debug("HTTP inference response: %s", result)
            return result

        except requests.Timeout:
            # Treat timeouts similarly to 5xx => attempt a retry
            logger.warning(
                f"HTTP request timed out after {self.timeout} seconds during {self.model_interface.name()} "
                f"inference. Attempt {attempt + 1} of {self.max_retries}."
            )
            if attempt == self.max_retries - 1:
                logger.error("Max retries exceeded after repeated timeouts.")
                raise TimeoutError(
                    f"Repeated timeouts for {self.model_interface.name()} after {attempt + 1} attempts."
                )
            time.sleep(base_delay * (2**attempt))
            attempt += 1

        except requests.HTTPError as http_err:
            # Non-retryable 4xx, or a 429/5xx whose retries were exhausted above.
            logger.error(f"HTTP request failed with status code {response.status_code}: {http_err}")
            raise

        except requests.RequestException as e:
            # ConnectionError or other non-HTTPError.
            # NOTE(review): with requests >= 2.27 a JSON decode failure in response.json() is also a
            # RequestException subclass, so it is retried here too — confirm that is intended.
            if attempt == self.max_retries - 1:
                logger.error(f"HTTP request encountered a network issue: {e}")
                raise
            logger.warning(
                f"HTTP request encountered a network issue: {e}. "
                f"Attempt {attempt + 1} of {self.max_retries}."
            )
            time.sleep(base_delay * (2**attempt))
            attempt += 1

    # Only reachable when the loop body never runs or ends without returning/raising
    # (e.g. self.max_retries < 1).
    logger.error(f"Failed to get a successful response after {self.max_retries} retries.")
    raise Exception(f"Failed to get a successful response after {self.max_retries} retries.")
|
|
541
|
+
|
|
542
|
+
def _batcher_loop(self):
    """
    Run the background batching worker until ``self._stop_event`` is set.

    Each cycle waits up to ``self._batch_timeout`` seconds for a first request. It then greedily
    appends requests that are *already* queued, never blocking for more. Batch formation stops
    as soon as any of the following holds:

    * the batch holds ``self._batch_size`` items;
    * ``self._batch_timeout`` seconds have passed since the batch was started;
    * the queue is empty, or its head is the ``None`` sentinel (which is left in the queue);
    * a memory budget is configured and ``model_interface.does_item_fit_in_batch`` rejects the
      next request (which stays queued for the next cycle).

    The finished batch is passed to ``self._process_dynamic_batch``.
    """
    while not self._stop_event.is_set():
        pending = self._request_queue
        try:
            head = pending.get(timeout=self._batch_timeout)
        except queue.Empty:
            continue  # nothing arrived in this window; re-check the stop flag
        if head is None:
            continue  # sentinel (close() enqueues one); re-check the stop flag

        batch = [head]
        window_start = time.monotonic()
        while len(batch) < self._batch_size and time.monotonic() - window_start < self._batch_timeout:
            if pending.empty():
                break
            # NOTE(review): peeking at the underlying deque assumes this thread is the only
            # consumer of the queue, so the head cannot change before get_nowait() below.
            candidate = pending.queue[0]
            if candidate is None:
                break
            budget = self._batch_memory_budget_bytes
            if budget and not self.model_interface.does_item_fit_in_batch(batch, candidate, budget):
                break
            try:
                taken = pending.get_nowait()
            except queue.Empty:
                break
            if taken is None:
                break
            batch.append(taken)

        self._process_dynamic_batch(batch)
|
|
586
|
+
|
|
587
|
+
def _process_dynamic_batch(self, requests: "list[InferenceRequest]"):
    """
    Coalesce, infer, and distribute results for one dynamic batch.

    Each request's future is first claimed with ``Future.set_running_or_notify_cancel()``.
    Requests whose futures the caller already cancelled are dropped, so they are neither
    inferred nor resolved. Resolving a cancelled future would raise ``InvalidStateError``;
    skipping them keeps that error from escaping into the batcher thread. The ``model_name``
    and ``kwargs`` of the first remaining request are used for the whole batch.

    Parameters
    ----------
    requests : list[InferenceRequest]
        Requests collected by the batcher loop. Each is expected to expose ``data``, ``dims``,
        ``model_name``, ``kwargs`` and ``future``.

    Notes
    -----
    Exceptions raised while coalescing, inferring, post-processing, or from a result-count
    mismatch are not raised to the caller. Instead they are set on every still-pending future
    in the batch.
    """
    # Claim each future. This returns False for futures that were already cancelled, and a
    # running future can no longer be cancelled, so set_result/set_exception below are safe.
    requests = [req for req in requests if req.future.set_running_or_notify_cancel()]
    if not requests:
        return

    first_req = requests[0]
    model_name = first_req.model_name
    kwargs = first_req.kwargs

    try:
        # 1. Coalesce individual data items into a single batch input
        batch_input, batch_data = self.model_interface.coalesce_requests_to_batch(
            [req.data for req in requests],
            [req.dims for req in requests],
            protocol=self.protocol,
            model_name=model_name,
            **kwargs,
        )

        # 2. Perform inference using the existing _process_batch logic
        parsed_output, _ = self._process_batch(batch_input, batch_data=batch_data, model_name=model_name, **kwargs)

        # 3. Process the batched output to get final results
        all_results = self.model_interface.process_inference_results(
            parsed_output,
            original_image_shapes=batch_data.get("original_image_shapes"),
            protocol=self.protocol,
            **kwargs,
        )

        # 4. Distribute the individual results back to the correct Future
        if len(all_results) != len(requests):
            raise ValueError("Mismatch between result count and request count.")

        for req, result in zip(requests, all_results):
            req.future.set_result(result)

    except Exception as e:
        # If anything fails, propagate the exception to all futures in the batch
        logger.error(f"Error processing dynamic batch: {e}")
        for req in requests:
            # Skip futures that were already resolved before the failure.
            if not req.future.done():
                req.future.set_exception(e)
|
|
629
|
+
|
|
630
|
+
def submit(self, data: Any, model_name: str, dims: Tuple[int, int], **kwargs) -> Future:
    """
    Enqueue one inference request for the dynamic batcher and return without waiting.

    Parameters
    ----------
    data : Any
        A single item to run inference on (e.g. one image or one text prompt).
    model_name : str
        Name of the model the request targets; stored on the queued request.
    dims : Tuple[int, int]
        Dimensions stored on the queued request.
    **kwargs
        Extra keyword arguments stored on the queued request.

    Returns
    -------
    concurrent.futures.Future
        A future that will be fulfilled with the inference result.

    Raises
    ------
    RuntimeError
        If dynamic batching is not enabled on this client.
    """
    if self.dynamic_batching_enabled:
        result_future = Future()
        self._request_queue.put(
            InferenceRequest(
                data=data,
                future=result_future,
                model_name=model_name,
                dims=dims,
                kwargs=kwargs,
            )
        )
        return result_future

    raise RuntimeError(
        "Dynamic batching is not enabled. Please initialize NimClient with " "enable_dynamic_batching=True."
    )
|
|
656
|
+
|
|
657
|
+
def close(self):
    """
    Shut down this client.

    When dynamic batching is enabled, this signals the batcher thread to stop. It then enqueues
    a ``None`` sentinel so a worker blocked in ``Queue.get()`` wakes up, and waits for the thread
    to exit if it is alive. Finally, the underlying ``self.client`` (if truthy) is closed.
    """
    if self.dynamic_batching_enabled:
        self._stop_event.set()
        self._request_queue.put(None)  # wake a worker blocked waiting on the queue
        worker = self._batcher_thread
        if worker.is_alive():
            worker.join()

    underlying = self.client
    if underlying:
        underlying.close()
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
class NimClientManager:
    """
    A thread-safe, singleton manager for creating and sharing NimClient instances.

    This manager ensures that only one NimClient is created per unique configuration. The
    configuration key is derived from ``model_interface.name()`` plus the keyword arguments
    passed to :meth:`get_client`, so those arguments must be JSON-serializable.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Singleton pattern: double-checked so the lock is only taken until the instance exists.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super(NimClientManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every NimClientManager() call; initialize shared state only once.
        if not hasattr(self, "_initialized"):
            with self._lock:
                if not hasattr(self, "_initialized"):
                    self._clients = {}  # Key: config_hash, Value: NimClient instance
                    self._client_lock = threading.Lock()
                    self._initialized = True

    def _generate_config_key(self, **kwargs) -> str:
        """
        Create a stable, hashable key from client configuration.

        Keyword order does not matter because items are sorted before serialization. Values
        must be JSON-serializable; otherwise ``json.dumps`` raises ``TypeError``.
        """
        sorted_config = sorted(kwargs.items())
        config_str = json.dumps(sorted_config)
        # The digest is only a cache key, not a security primitive. Declaring that keeps md5
        # usable on FIPS-enabled builds, where a plain hashlib.md5() call is rejected.
        return hashlib.md5(config_str.encode("utf-8"), usedforsecurity=False).hexdigest()

    def get_client(self, model_interface, **kwargs) -> "NimClient":
        """
        Get or create a NimClient for the given configuration.

        Parameters
        ----------
        model_interface
            Passed through to ``NimClient``; its ``name()`` is part of the cache key.
        **kwargs
            Remaining ``NimClient`` constructor arguments. They also form the cache key, so
            they must be JSON-serializable.

        Returns
        -------
        NimClient
            The cached client for this configuration, or a newly created one. A new client with
            ``dynamic_batching_enabled`` set is started before it is cached and returned.
        """
        config_key = self._generate_config_key(model_interface_name=model_interface.name(), **kwargs)

        # Lock-free fast path. A single .get() (rather than `in` followed by `[]`) cannot raise
        # KeyError if shutdown() clears the dict concurrently.
        client = self._clients.get(config_key)
        if client is not None:
            return client

        with self._client_lock:
            client = self._clients.get(config_key)
            if client is not None:
                return client

            logger.debug(f"Creating new NimClient for config hash: {config_key}")

            new_client = NimClient(model_interface=model_interface, **kwargs)

            if new_client.dynamic_batching_enabled:
                new_client.start()

            self._clients[config_key] = new_client

            return new_client

    def shutdown(self):
        """
        Gracefully close all managed NimClient instances and forget them.

        If one client's ``close()`` raises, the error is logged and the remaining clients are
        still closed. Intended to run on application exit (e.g. via an ``atexit`` hook).
        NOTE(review): no such registration is visible in this class; confirm it exists elsewhere.
        """
        with self._client_lock:
            logger.debug(f"Shutting down NimClientManager and {len(self._clients)} client(s)...")
            for config_key, client in self._clients.items():
                logger.debug(f"Closing client for config: {config_key}")
                try:
                    client.close()
                except Exception as e:
                    logger.error(f"Error closing client for config {config_key}: {e}")
            self._clients.clear()
        logger.debug("NimClientManager shutdown complete.")
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
# Module-level convenience accessor for the shared NimClientManager.
def get_nim_client_manager(*args, **kwargs) -> NimClientManager:
    """
    Return the process-wide NimClientManager singleton.

    Positional and keyword arguments are forwarded unchanged to the ``NimClientManager``
    constructor. NOTE(review): ``NimClientManager.__new__``/``__init__`` as defined in this
    file accept no arguments, so passing any raises ``TypeError``.
    """
    singleton = NimClientManager(*args, **kwargs)
    return singleton
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
def reload_models(
    client: "grpcclient.InferenceServerClient",
    exclude: "list[str] | None" = None,
    client_timeout: int = 120,
) -> bool:
    """
    Reload all models in the Triton server except for the models in the exclude list.

    The selected models are unloaded, loaded again, and then checked for readiness.

    * If an unload fails partway through, the models this call already unloaded are loaded
      back (best effort), so the server is not left with models missing.
    * A failure to load one model does not stop the remaining models from being loaded.

    Parameters
    ----------
    client : grpcclient.InferenceServerClient
        The gRPC client connected to the Triton server.
    exclude : list[str], optional
        Model names to exclude from reloading. Defaults to excluding nothing.
    client_timeout : int, optional
        Timeout in seconds passed to each readiness check (default: 120).

    Returns
    -------
    bool
        True if all models were reloaded and report ready. False if an unload failed
        (including when the server does not allow explicit model load/unload) or a model is
        not ready afterwards.

    Raises
    ------
    grpcclient.InferenceServerException
        The first load error, re-raised after every model has had a load attempt.
    """
    model_index = client.get_model_repository_index()
    excluded = set(exclude or ())
    names = [m.name for m in model_index.models if m.name not in excluded]

    logger.info(f"Reloading {len(names)} model(s): {', '.join(names) if names else '(none)'}")

    # 1) Unload, remembering what was unloaded so a partial failure can be rolled back.
    unloaded = []
    for name in names:
        try:
            client.unload_model(name)
        except grpcclient.InferenceServerException as e:
            msg = e.message()
            if "explicit model load / unload" in msg.lower():
                status = e.status()
                logger.warning(
                    f"[SKIP Model Reload] Explicit model control disabled; cannot unload '{name}'. Status: {status}."
                )
            else:
                logger.error(f"[ERROR] Failed to unload '{name}': {msg}")
            # Best effort: bring back the models this call already unloaded.
            for restore_name in unloaded:
                try:
                    client.load_model(restore_name)
                except grpcclient.InferenceServerException as restore_err:
                    logger.error(
                        f"[ERROR] Failed to restore '{restore_name}' after unload failure: {restore_err.message()}"
                    )
            return False
        unloaded.append(name)

    # 2) Load every model, even if one fails, so as many as possible come back.
    first_load_error = None
    for name in names:
        try:
            client.load_model(name)
        except grpcclient.InferenceServerException as e:
            logger.error(f"[ERROR] Failed to load '{name}': {e.message()}")
            if first_load_error is None:
                first_load_error = e
    if first_load_error is not None:
        raise first_load_error

    # 3) Readiness check
    for name in names:
        ready = client.is_model_ready(model_name=name, client_timeout=client_timeout)
        if not ready:
            logger.warning(f"[Warning] Triton Not ready: {name}")
            return False

    logger.info("✅ Reload of models complete.")
    return True
|