nv-ingest-api 26.1.0rc4__py3-none-any.whl


Files changed (177)
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +218 -0
  3. nv_ingest_api/interface/extract.py +977 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +200 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +186 -0
  8. nv_ingest_api/internal/__init__.py +0 -0
  9. nv_ingest_api/internal/enums/__init__.py +3 -0
  10. nv_ingest_api/internal/enums/common.py +550 -0
  11. nv_ingest_api/internal/extract/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  13. nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
  14. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  15. nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
  16. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
  19. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
  20. nv_ingest_api/internal/extract/html/__init__.py +3 -0
  21. nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
  22. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
  24. nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
  25. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  26. nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
  27. nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
  28. nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
  29. nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
  30. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  31. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  32. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  33. nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
  34. nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
  35. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
  36. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
  37. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  38. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  39. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  40. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  41. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  42. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
  43. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
  44. nv_ingest_api/internal/meta/__init__.py +3 -0
  45. nv_ingest_api/internal/meta/udf.py +232 -0
  46. nv_ingest_api/internal/mutate/__init__.py +3 -0
  47. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  48. nv_ingest_api/internal/mutate/filter.py +133 -0
  49. nv_ingest_api/internal/primitives/__init__.py +0 -0
  50. nv_ingest_api/internal/primitives/control_message_task.py +16 -0
  51. nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
  52. nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
  53. nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
  59. nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
  60. nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
  61. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  62. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
  63. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
  64. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
  65. nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
  66. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
  67. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  68. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  69. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  70. nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
  71. nv_ingest_api/internal/schemas/__init__.py +3 -0
  72. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  73. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
  74. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
  75. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
  76. nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
  77. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
  78. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
  79. nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
  80. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
  81. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
  82. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
  83. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
  85. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  86. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  87. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  88. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  89. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
  90. nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
  91. nv_ingest_api/internal/schemas/meta/udf.py +23 -0
  92. nv_ingest_api/internal/schemas/mixins.py +39 -0
  93. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  94. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  95. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  96. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  97. nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
  98. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  99. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
  100. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  101. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
  102. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
  103. nv_ingest_api/internal/store/__init__.py +3 -0
  104. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  105. nv_ingest_api/internal/store/image_upload.py +251 -0
  106. nv_ingest_api/internal/transform/__init__.py +3 -0
  107. nv_ingest_api/internal/transform/caption_image.py +219 -0
  108. nv_ingest_api/internal/transform/embed_text.py +702 -0
  109. nv_ingest_api/internal/transform/split_text.py +182 -0
  110. nv_ingest_api/util/__init__.py +3 -0
  111. nv_ingest_api/util/control_message/__init__.py +0 -0
  112. nv_ingest_api/util/control_message/validators.py +47 -0
  113. nv_ingest_api/util/converters/__init__.py +0 -0
  114. nv_ingest_api/util/converters/bytetools.py +78 -0
  115. nv_ingest_api/util/converters/containers.py +65 -0
  116. nv_ingest_api/util/converters/datetools.py +90 -0
  117. nv_ingest_api/util/converters/dftools.py +127 -0
  118. nv_ingest_api/util/converters/formats.py +64 -0
  119. nv_ingest_api/util/converters/type_mappings.py +27 -0
  120. nv_ingest_api/util/dataloader/__init__.py +9 -0
  121. nv_ingest_api/util/dataloader/dataloader.py +409 -0
  122. nv_ingest_api/util/detectors/__init__.py +5 -0
  123. nv_ingest_api/util/detectors/language.py +38 -0
  124. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  125. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  126. nv_ingest_api/util/exception_handlers/decorators.py +429 -0
  127. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  128. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  129. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  130. nv_ingest_api/util/image_processing/__init__.py +5 -0
  131. nv_ingest_api/util/image_processing/clustering.py +260 -0
  132. nv_ingest_api/util/image_processing/processing.py +177 -0
  133. nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
  134. nv_ingest_api/util/image_processing/transforms.py +850 -0
  135. nv_ingest_api/util/imports/__init__.py +3 -0
  136. nv_ingest_api/util/imports/callable_signatures.py +108 -0
  137. nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
  138. nv_ingest_api/util/introspection/__init__.py +3 -0
  139. nv_ingest_api/util/introspection/class_inspect.py +145 -0
  140. nv_ingest_api/util/introspection/function_inspect.py +65 -0
  141. nv_ingest_api/util/logging/__init__.py +0 -0
  142. nv_ingest_api/util/logging/configuration.py +102 -0
  143. nv_ingest_api/util/logging/sanitize.py +84 -0
  144. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  145. nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
  146. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  147. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  148. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  149. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
  150. nv_ingest_api/util/metadata/__init__.py +5 -0
  151. nv_ingest_api/util/metadata/aggregators.py +516 -0
  152. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  153. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
  154. nv_ingest_api/util/nim/__init__.py +161 -0
  155. nv_ingest_api/util/pdf/__init__.py +3 -0
  156. nv_ingest_api/util/pdf/pdfium.py +428 -0
  157. nv_ingest_api/util/schema/__init__.py +3 -0
  158. nv_ingest_api/util/schema/schema_validator.py +10 -0
  159. nv_ingest_api/util/service_clients/__init__.py +3 -0
  160. nv_ingest_api/util/service_clients/client_base.py +86 -0
  161. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  162. nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
  163. nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
  164. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  165. nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
  166. nv_ingest_api/util/string_processing/__init__.py +51 -0
  167. nv_ingest_api/util/string_processing/configuration.py +682 -0
  168. nv_ingest_api/util/string_processing/yaml.py +109 -0
  169. nv_ingest_api/util/system/__init__.py +0 -0
  170. nv_ingest_api/util/system/hardware_info.py +594 -0
  171. nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
  172. nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
  173. nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
  174. nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
  175. nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
  176. udfs/__init__.py +5 -0
  177. udfs/llm_summarizer_udf.py +259 -0
nv_ingest_api/internal/primitives/nim/nim_client.py (new file)
@@ -0,0 +1,801 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+ # All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import hashlib
+ import json
+ import logging
+ import re
+ import threading
+ import time
+ import queue
+ from collections import namedtuple
+ from concurrent.futures import Future, ThreadPoolExecutor, as_completed
+ from typing import Any
+ from typing import Optional
+ from typing import Tuple, Union
+
+ import numpy as np
+ import requests
+ import tritonclient.grpc as grpcclient
+
+ from nv_ingest_api.internal.primitives.tracing.tagging import traceable_func
+ from nv_ingest_api.util.string_processing import generate_url
+
+
+ logger = logging.getLogger(__name__)
+
+ # Regex pattern to detect CUDA-related errors in Triton gRPC responses
+ CUDA_ERROR_REGEX = re.compile(
+     r"(model reload|illegal memory access|illegal instruction|invalid argument|failed to (copy|load|perform) .*: .*|TritonModelException: failed to copy data: .*)",  # noqa: E501
+     re.IGNORECASE,
+ )
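
The pattern is deliberately broad: any single alternative is enough to route a failure into the reload-and-retry path below. A standalone check, assuming this module's namespace:

    sample = "TritonModelException: failed to copy data: an illegal memory access was encountered"
    assert CUDA_ERROR_REGEX.search(sample) is not None  # matches on "illegal memory access"
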
+
+ # A simple structure to hold a request's data and its Future for the result
+ InferenceRequest = namedtuple("InferenceRequest", ["data", "future", "model_name", "dims", "kwargs"])
+
+
+ class NimClient:
+     """
+     A client for interfacing with a model inference server using gRPC or HTTP protocols.
+     """
+
+     def __init__(
+         self,
+         model_interface,
+         protocol: str,
+         endpoints: Tuple[str, str],
+         auth_token: Optional[str] = None,
+         timeout: float = 120.0,
+         max_retries: int = 10,
+         max_429_retries: int = 5,
+         enable_dynamic_batching: bool = False,
+         dynamic_batch_timeout: float = 0.1,  # 100 milliseconds
+         dynamic_batch_memory_budget_mb: Optional[float] = None,
+     ):
+         """
+         Initialize the NimClient with the specified model interface, protocol, and server endpoints.
+
+         Parameters
+         ----------
+         model_interface : ModelInterface
+             The model interface implementation to use.
+         protocol : str
+             The protocol to use ("grpc" or "http").
+         endpoints : tuple
+             A tuple containing the gRPC and HTTP endpoints.
+         auth_token : str, optional
+             Authorization token for HTTP requests (default: None).
+         timeout : float, optional
+             Timeout for HTTP requests in seconds (default: 120.0).
+         max_retries : int, optional
+             The maximum number of retries for non-429 server-side errors (default: 10).
+         max_429_retries : int, optional
+             The maximum number of retries specifically for 429 errors (default: 5).
+         enable_dynamic_batching : bool, optional
+             Whether to run the background dynamic batching worker (default: False).
+         dynamic_batch_timeout : float, optional
+             Time in seconds the batcher waits for a batch to fill (default: 0.1).
+         dynamic_batch_memory_budget_mb : float, optional
+             Per-batch memory budget in megabytes; unlimited if None (default: None).
+
+         Raises
+         ------
+         ValueError
+             If an invalid protocol is specified or if required endpoints are missing.
+         """
+         self.client = None
+         self.model_interface = model_interface
+         self.protocol = protocol.lower()
+         self.auth_token = auth_token
+         self.timeout = timeout  # Timeout for HTTP requests
+         self.max_retries = max_retries
+         self.max_429_retries = max_429_retries
+         self._grpc_endpoint, self._http_endpoint = endpoints
+         self._max_batch_sizes = {}
+         self._lock = threading.Lock()
+
+         if self.protocol == "grpc":
+             if not self._grpc_endpoint:
+                 raise ValueError("gRPC endpoint must be provided for gRPC protocol")
+             logger.debug(f"Creating gRPC client with {self._grpc_endpoint}")
+             self.client = grpcclient.InferenceServerClient(url=self._grpc_endpoint)
+         elif self.protocol == "http":
+             if not self._http_endpoint:
+                 raise ValueError("HTTP endpoint must be provided for HTTP protocol")
+             logger.debug(f"Creating HTTP client with {self._http_endpoint}")
+             self.endpoint_url = generate_url(self._http_endpoint)
+             self.headers = {"accept": "application/json", "content-type": "application/json"}
+             if self.auth_token:
+                 self.headers["Authorization"] = f"Bearer {self.auth_token}"
+         else:
+             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+         self.dynamic_batching_enabled = enable_dynamic_batching
+         if self.dynamic_batching_enabled:
+             self._batch_timeout = dynamic_batch_timeout
+             if dynamic_batch_memory_budget_mb is not None:
+                 self._batch_memory_budget_bytes = dynamic_batch_memory_budget_mb * 1024 * 1024
+             else:
+                 self._batch_memory_budget_bytes = None
+
+             # Default for the batcher loop; infer() refines this from the model configuration.
+             self._batch_size = 1
+             self._request_queue = queue.Queue()
+             self._stop_event = threading.Event()
+             self._batcher_thread = threading.Thread(target=self._batcher_loop, daemon=True)
+
+     def start(self):
+         """Starts the dynamic batching worker thread if enabled."""
+         if self.dynamic_batching_enabled and not self._batcher_thread.is_alive():
+             self._batcher_thread.start()
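
For orientation, a minimal construction sketch. The StubInterface below is hypothetical (real callers pass one of the package's ModelInterface implementations, such as the OCR or yolox interfaces), and the endpoint addresses are placeholders:

    from nv_ingest_api.internal.primitives.nim.nim_client import NimClient

    class StubInterface:
        # Hypothetical stand-in for a ModelInterface implementation.
        def name(self) -> str:
            return "stub"

    client = NimClient(
        model_interface=StubInterface(),
        protocol="grpc",
        endpoints=("localhost:8001", "http://localhost:8000/v1/infer"),  # (grpc, http) placeholders
        enable_dynamic_batching=True,
        dynamic_batch_timeout=0.05,          # wait up to 50 ms for a batch to fill
        dynamic_batch_memory_budget_mb=256,  # optional per-batch memory cap
    )
    client.start()  # launches the background batcher thread; a no-op for the offline path
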
+
+     def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
+         """Fetch the maximum batch size from the Triton model configuration in a thread-safe manner."""
+
+         if model_name == "yolox_ensemble":
+             model_name = "yolox"
+
+         if model_name in self._max_batch_sizes:
+             return self._max_batch_sizes[model_name]
+
+         with self._lock:
+             # Double check, just in case another thread set the value while we were waiting
+             if model_name in self._max_batch_sizes:
+                 return self._max_batch_sizes[model_name]
+
+             if not self._grpc_endpoint or not self.client:
+                 self._max_batch_sizes[model_name] = 1
+                 return 1
+
+             try:
+                 model_config = self.client.get_model_config(model_name=model_name, model_version=model_version)
+                 self._max_batch_sizes[model_name] = model_config.config.max_batch_size
+                 logger.debug(f"Max batch size for model '{model_name}': {self._max_batch_sizes[model_name]}")
+             except Exception as e:
+                 self._max_batch_sizes[model_name] = 1
+                 logger.warning(f"Failed to retrieve max batch size: {e}, defaulting to 1")
+
+             return self._max_batch_sizes[model_name]
+
+     def _process_batch(self, batch_input, *, batch_data, model_name, **kwargs):
+         """
+         Process a single batch input for inference using its corresponding batch_data.
+
+         Parameters
+         ----------
+         batch_input : Any
+             The input data for this batch.
+         batch_data : Any
+             The corresponding scratch-pad data for this batch as returned by format_input.
+         model_name : str
+             The model name for inference.
+         kwargs : dict
+             Additional parameters.
+
+         Returns
+         -------
+         tuple
+             A tuple (parsed_output, batch_data) for subsequent post-processing.
+         """
+         if self.protocol == "grpc":
+             logger.debug("Performing gRPC inference for a batch...")
+             response = self._grpc_infer(batch_input, model_name, **kwargs)
+             logger.debug("gRPC inference received response for a batch")
+         elif self.protocol == "http":
+             logger.debug("Performing HTTP inference for a batch...")
+             response = self._http_infer(batch_input)
+             logger.debug("HTTP inference received response for a batch")
+         else:
+             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+         parsed_output = self.model_interface.parse_output(
+             response, protocol=self.protocol, data=batch_data, model_name=model_name, **kwargs
+         )
+         return parsed_output, batch_data
+
+     def try_set_max_batch_size(self, model_name, model_version: str = ""):
+         """Attempt to set the max batch size for the model if it is not already set, ensuring thread safety."""
+         self._fetch_max_batch_size(model_name, model_version)
+
+     @traceable_func(trace_name="{stage_name}::{model_name}")
+     def infer(self, data: dict, model_name: str, **kwargs) -> Any:
+         """
+         Perform inference using the specified model and input data.
+
+         Parameters
+         ----------
+         data : dict
+             The input data for inference.
+         model_name : str
+             The model name.
+         kwargs : dict
+             Additional parameters for inference.
+
+         Returns
+         -------
+         Any
+             The processed inference results, coalesced in the same order as the input images.
+         """
+         # 1. Retrieve or default to the model's maximum batch size.
+         batch_size = self._fetch_max_batch_size(model_name)
+         max_requested_batch_size = kwargs.pop("max_batch_size", batch_size)
+         force_requested_batch_size = kwargs.pop("force_max_batch_size", False)
+         max_batch_size = (
+             max(1, min(batch_size, max_requested_batch_size))
+             if not force_requested_batch_size
+             else max_requested_batch_size
+         )
+         self._batch_size = max_batch_size
+
+         if self.dynamic_batching_enabled:
+             # DYNAMIC BATCHING PATH
+             try:
+                 data = self.model_interface.prepare_data_for_inference(data)
+
+                 futures = []
+                 for base64_image, image_array in zip(data["base64_images"], data["images"]):
+                     dims = image_array.shape[:2]
+                     futures.append(self.submit(base64_image, model_name, dims, **kwargs))
+
+                 results = [future.result() for future in futures]
+
+                 return results
+
+             except Exception as err:
+                 error_str = (
+                     f"Error during synchronous infer with dynamic batching [{self.model_interface.name()}]: {err}"
+                 )
+                 logger.error(error_str)
+                 raise RuntimeError(error_str) from err
+
+         # OFFLINE BATCHING PATH
+         try:
+             # 2. Prepare data for inference.
+             data = self.model_interface.prepare_data_for_inference(data)
+
+             # 3. Format the input based on protocol.
+             formatted_batches, formatted_batch_data = self.model_interface.format_input(
+                 data,
+                 protocol=self.protocol,
+                 max_batch_size=max_batch_size,
+                 model_name=model_name,
+                 **kwargs,
+             )
+
+             # Check for a custom maximum pool worker count, and remove it from kwargs.
+             max_pool_workers = kwargs.pop("max_pool_workers", 16)
+
+             # 4. Process each batch concurrently using a thread pool.
+             # We enumerate the batches so that we can later reassemble results in order.
+             results = [None] * len(formatted_batches)
+             with ThreadPoolExecutor(max_workers=max_pool_workers) as executor:
+                 future_to_idx = {}
+                 for idx, (batch, batch_data) in enumerate(zip(formatted_batches, formatted_batch_data)):
+                     future = executor.submit(
+                         self._process_batch, batch, batch_data=batch_data, model_name=model_name, **kwargs
+                     )
+                     future_to_idx[future] = idx
+
+                 for future in as_completed(future_to_idx.keys()):
+                     idx = future_to_idx[future]
+                     results[idx] = future.result()
+
+             # 5. Process the parsed outputs for each batch using its corresponding batch_data.
+             # As the batches are in order, we coalesce their outputs accordingly.
+             all_results = []
+             for parsed_output, batch_data in results:
+                 batch_results = self.model_interface.process_inference_results(
+                     parsed_output,
+                     original_image_shapes=batch_data.get("original_image_shapes"),
+                     protocol=self.protocol,
+                     **kwargs,
+                 )
+                 if isinstance(batch_results, list):
+                     all_results.extend(batch_results)
+                 else:
+                     all_results.append(batch_results)
+
+         except Exception as err:
+             error_str = f"Error during NimClient inference [{self.model_interface.name()}, {self.protocol}]: {err}"
+             logger.error(error_str)
+             raise RuntimeError(error_str) from err
+
+         return all_results
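
The batch-size clamp at the top of infer() reduces to the following standalone arithmetic, shown here as a sketch that mirrors the expression above:

    def effective_batch_size(server_max: int, requested: int, force: bool = False) -> int:
        # Mirrors: max(1, min(batch_size, max_requested_batch_size)) unless forced.
        return requested if force else max(1, min(server_max, requested))

    assert effective_batch_size(server_max=32, requested=8) == 8             # within server limit
    assert effective_batch_size(server_max=4, requested=8) == 4              # clamped to server limit
    assert effective_batch_size(server_max=4, requested=8, force=True) == 8  # force bypasses the clamp
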
+
+     def _grpc_infer(
+         self, formatted_input: Union[list, list[np.ndarray]], model_name: str, **kwargs
+     ) -> Union[list, list[np.ndarray]]:
+         """
+         Perform inference using the gRPC protocol.
+
+         Parameters
+         ----------
+         formatted_input : np.ndarray or list of np.ndarray
+             The input data, one numpy array per input tensor.
+         model_name : str
+             The name of the model to use for inference.
+
+         Returns
+         -------
+         np.ndarray or list of np.ndarray
+             The model output, one numpy array per requested output tensor.
+         """
+         if not isinstance(formatted_input, list):
+             formatted_input = [formatted_input]
+
+         parameters = kwargs.get("parameters", {})
+         output_names = kwargs.get("output_names", ["output"])
+         dtypes = kwargs.get("dtypes", ["FP32"])
+         input_names = kwargs.get("input_names", ["input"])
+
+         input_tensors = []
+         for input_name, input_data, dtype in zip(input_names, formatted_input, dtypes):
+             input_tensors.append(grpcclient.InferInput(input_name, input_data.shape, datatype=dtype))
+
+         for idx, input_data in enumerate(formatted_input):
+             input_tensors[idx].set_data_from_numpy(input_data)
+
+         outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
+
+         base_delay = 2.0
+         attempt = 0
+         retries_429 = 0
+         max_grpc_retries = self.max_429_retries
+
+         while attempt < self.max_retries:
+             try:
+                 response = self.client.infer(
+                     model_name=model_name, parameters=parameters, inputs=input_tensors, outputs=outputs
+                 )
+
+                 logger.debug(f"gRPC inference response: {response}")
+
+                 if len(outputs) == 1:
+                     return response.as_numpy(outputs[0].name())
+                 else:
+                     return [response.as_numpy(output.name()) for output in outputs]
+
+             except grpcclient.InferenceServerException as e:
+                 status = str(e.status())
+                 message = e.message()
+
+                 # Handle CUDA memory errors
+                 if status == "StatusCode.INTERNAL":
+                     if CUDA_ERROR_REGEX.search(message):
+                         logger.warning(
+                             f"Received gRPC INTERNAL error with CUDA-related message for model '{model_name}'. "
+                             f"Attempt {attempt + 1} of {self.max_retries}. Message (truncated): {message[:500]}"
+                         )
+                         if attempt >= self.max_retries - 1:
+                             logger.error(f"Max retries exceeded for CUDA errors on model '{model_name}'.")
+                             raise e
+                         # Try to reload models before retrying
+                         model_reload_succeeded = reload_models(client=self.client, client_timeout=self.timeout)
+                         if not model_reload_succeeded:
+                             logger.error(f"Failed to reload models for model '{model_name}'.")
+                     else:
+                         logger.warning(
+                             f"Received gRPC INTERNAL error for model '{model_name}'. "
+                             f"Attempt {attempt + 1} of {self.max_retries}. Message (truncated): {message[:500]}"
+                         )
+                         if attempt >= self.max_retries - 1:
+                             logger.error(f"Max retries exceeded for INTERNAL error on model '{model_name}'.")
+                             raise e
+
+                     # Common retry logic for both CUDA and non-CUDA INTERNAL errors
+                     backoff_time = base_delay * (2**attempt)
+                     time.sleep(backoff_time)
+                     attempt += 1
+                     continue
+
+                 # Handle errors that can occur after model reload (NOT_FOUND, model not loaded)
+                 if status == "StatusCode.NOT_FOUND":
+                     logger.warning(
+                         f"Received gRPC {status} error for model '{model_name}'. "
+                         f"Attempt {attempt + 1} of {self.max_retries}. Message: {message[:500]}"
+                     )
+                     if attempt >= self.max_retries - 1:
+                         logger.error(f"Max retries exceeded for model not found errors on model '{model_name}'.")
+                         raise e
+
+                     # Retry with exponential backoff WITHOUT reloading
+                     backoff_time = base_delay * (2**attempt)
+                     logger.info(
+                         f"Retrying after {backoff_time}s backoff for model not found error on model '{model_name}'."
+                     )
+                     time.sleep(backoff_time)
+                     attempt += 1
+                     continue
+
+                 if status == "StatusCode.UNAVAILABLE" and "exceeds maximum queue size" in message.lower():
+                     retries_429 += 1
+                     logger.warning(
+                         f"Received gRPC {status} for model '{model_name}'. "
+                         f"Attempt {retries_429} of {max_grpc_retries}."
+                     )
+                     if retries_429 >= max_grpc_retries:
+                         logger.error(f"Max retries for gRPC {status} exceeded for model '{model_name}'.")
+                         raise
+
+                     backoff_time = base_delay * (2**retries_429)
+                     time.sleep(backoff_time)
+                     continue
+
+                 # For other server-side errors (e.g., INVALID_ARGUMENT, etc.),
+                 # fail fast as retrying will not help
+                 logger.error(
+                     f"Received non-retryable gRPC error {status} from Triton for model '{model_name}': {message}"
+                 )
+                 raise
+
+             except Exception as e:
+                 # Catch any other unexpected exceptions (e.g., network issues not caught by Triton client)
+                 logger.error(f"An unexpected error occurred during gRPC inference for model '{model_name}': {e}")
+                 raise
+
+     def _http_infer(self, formatted_input: dict) -> dict:
+         """
+         Perform inference using the HTTP protocol, retrying timeouts and 5xx errors up to
+         ``max_retries`` times (and 429 responses up to ``max_429_retries`` times).
+
+         Parameters
+         ----------
+         formatted_input : dict
+             The input data formatted as a dictionary.
+
+         Returns
+         -------
+         dict
+             The output of the model as a dictionary.
+
+         Raises
+         ------
+         TimeoutError
+             If the HTTP request times out repeatedly, up to the max retries.
+         requests.RequestException
+             For other HTTP-related errors that persist after max retries.
+         """
+
+         base_delay = 2.0
+         attempt = 0
+         retries_429 = 0
+
+         while attempt < self.max_retries:
+             try:
+                 # Log system prompt for VLM requests
+                 if isinstance(formatted_input, dict) and "messages" in formatted_input:
+                     messages = formatted_input.get("messages", [])
+                     if messages and messages[0].get("role") == "system":
+                         system_content = messages[0].get("content", "")
+                         model_name = self.model_interface.name()
+                         logger.debug(f"{model_name}: Sending HTTP request with system prompt: '{system_content}'")
+
+                 response = requests.post(
+                     self.endpoint_url, json=formatted_input, headers=self.headers, timeout=self.timeout
+                 )
+                 status_code = response.status_code
+
+                 # Check for server-side or rate-limit type errors
+                 # e.g. 5xx => server error, 429 => too many requests
+                 if status_code == 429:
+                     retries_429 += 1
+                     logger.warning(
+                         f"Received HTTP 429 (Too Many Requests) from {self.model_interface.name()}. "
+                         f"Attempt {retries_429} of {self.max_429_retries}."
+                     )
+                     if retries_429 >= self.max_429_retries:
+                         logger.error("Max retries for HTTP 429 exceeded.")
+                         response.raise_for_status()
+                     else:
+                         backoff_time = base_delay * (2**retries_429)
+                         time.sleep(backoff_time)
+                         continue  # Retry without incrementing the main attempt counter
+
+                 if 500 <= status_code < 600:
+                     logger.warning(
+                         f"Received HTTP {status_code} ({response.reason}) from "
+                         f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
+                     )
+                     if attempt == self.max_retries - 1:
+                         # No more retries left
+                         logger.error(f"Max retries exceeded after receiving HTTP {status_code}.")
+                         response.raise_for_status()  # raise the appropriate HTTPError
+                     else:
+                         # Exponential backoff
+                         backoff_time = base_delay * (2**attempt)
+                         time.sleep(backoff_time)
+                         attempt += 1
+                         continue
+                 else:
+                     # Not in our "retry" category => just raise_for_status or return
+                     response.raise_for_status()
+                     logger.debug(f"HTTP inference response: {response.json()}")
+                     return response.json()
+
+             except requests.Timeout:
+                 # Treat timeouts similarly to 5xx => attempt a retry
+                 logger.warning(
+                     f"HTTP request timed out after {self.timeout} seconds during {self.model_interface.name()} "
+                     f"inference. Attempt {attempt + 1} of {self.max_retries}."
+                 )
+                 if attempt == self.max_retries - 1:
+                     logger.error("Max retries exceeded after repeated timeouts.")
+                     raise TimeoutError(
+                         f"Repeated timeouts for {self.model_interface.name()} after {attempt + 1} attempts."
+                     )
+                 # Exponential backoff
+                 backoff_time = base_delay * (2**attempt)
+                 time.sleep(backoff_time)
+                 attempt += 1
+
+             except requests.HTTPError as http_err:
+                 # If we ended up here, it's a non-retryable 4xx or final 5xx after final attempt
+                 logger.error(f"HTTP request failed with status code {response.status_code}: {http_err}")
+                 raise
+
+             except requests.RequestException as e:
+                 # ConnectionError or other non-HTTPError
+                 logger.error(f"HTTP request encountered a network issue: {e}")
+                 if attempt == self.max_retries - 1:
+                     raise
+                 # Else retry on next loop iteration
+                 backoff_time = base_delay * (2**attempt)
+                 time.sleep(backoff_time)
+                 attempt += 1
+
+         # If we exit the loop without returning, we've exhausted all attempts
+         logger.error(f"Failed to get a successful response after {self.max_retries} retries.")
+         raise Exception(f"Failed to get a successful response after {self.max_retries} retries.")
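
Both transport paths share the same exponential backoff, base_delay * 2**attempt, so the sleep schedule under the defaults can be verified standalone:

    base_delay = 2.0
    delays = [base_delay * (2**attempt) for attempt in range(5)]
    print(delays)  # [2.0, 4.0, 8.0, 16.0, 32.0] seconds before successive retries
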
+
+     def _batcher_loop(self):
+         """The main loop for the background thread to form and process batches."""
+         while not self._stop_event.is_set():
+             requests_batch = []
+             try:
+                 first_req = self._request_queue.get(timeout=self._batch_timeout)
+                 if first_req is None:
+                     continue
+                 requests_batch.append(first_req)
+
+                 start_time = time.monotonic()
+
+                 while len(requests_batch) < self._batch_size:
+                     if (time.monotonic() - start_time) >= self._batch_timeout:
+                         break
+
+                     if self._request_queue.empty():
+                         break
+
+                     next_req_peek = self._request_queue.queue[0]
+                     if next_req_peek is None:
+                         break
+
+                     if self._batch_memory_budget_bytes:
+                         if not self.model_interface.does_item_fit_in_batch(
+                             requests_batch,
+                             next_req_peek,
+                             self._batch_memory_budget_bytes,
+                         ):
+                             break
+
+                     try:
+                         next_req = self._request_queue.get_nowait()
+                         if next_req is None:
+                             break
+                         requests_batch.append(next_req)
+                     except queue.Empty:
+                         break
+
+             except queue.Empty:
+                 continue
+
+             if requests_batch:
+                 self._process_dynamic_batch(requests_batch)
+
+     def _process_dynamic_batch(self, requests: list[InferenceRequest]):
+         """Coalesces, infers, and distributes results for a dynamic batch."""
+         if not requests:
+             return
+
+         first_req = requests[0]
+         model_name = first_req.model_name
+         kwargs = first_req.kwargs
+
+         try:
+             # 1. Coalesce individual data items into a single batch input
+             batch_input, batch_data = self.model_interface.coalesce_requests_to_batch(
+                 [req.data for req in requests],
+                 [req.dims for req in requests],
+                 protocol=self.protocol,
+                 model_name=model_name,
+                 **kwargs,
+             )
+
+             # 2. Perform inference using the existing _process_batch logic
+             parsed_output, _ = self._process_batch(batch_input, batch_data=batch_data, model_name=model_name, **kwargs)
+
+             # 3. Process the batched output to get final results
+             all_results = self.model_interface.process_inference_results(
+                 parsed_output,
+                 original_image_shapes=batch_data.get("original_image_shapes"),
+                 protocol=self.protocol,
+                 **kwargs,
+             )
+
+             # 4. Distribute the individual results back to the correct Future
+             if len(all_results) != len(requests):
+                 raise ValueError("Mismatch between result count and request count.")
+
+             for i, req in enumerate(requests):
+                 req.future.set_result(all_results[i])
+
+         except Exception as e:
+             # If anything fails, propagate the exception to all futures in the batch
+             logger.error(f"Error processing dynamic batch: {e}")
+             for req in requests:
+                 req.future.set_exception(e)
+
+     def submit(self, data: Any, model_name: str, dims: Tuple[int, int], **kwargs) -> Future:
+         """
+         Submits a single inference request to the dynamic batcher.
+
+         This method is non-blocking and returns a Future object that will
+         eventually contain the inference result.
+
+         Parameters
+         ----------
+         data : Any
+             The single data item for inference (e.g., one image, one text prompt).
+         model_name : str
+             The model name for inference.
+         dims : tuple of int
+             The (height, width) dimensions of the item, used when coalescing the batch.
+         kwargs : dict
+             Additional parameters forwarded to the model interface.
+
+         Returns
+         -------
+         concurrent.futures.Future
+             A future that will be fulfilled with the inference result.
+         """
+         if not self.dynamic_batching_enabled:
+             raise RuntimeError(
+                 "Dynamic batching is not enabled. Please initialize NimClient with enable_dynamic_batching=True."
+             )
+
+         future = Future()
+         request = InferenceRequest(data=data, future=future, model_name=model_name, dims=dims, kwargs=kwargs)
+         self._request_queue.put(request)
+         return future
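
A usage sketch for the dynamic path, continuing the construction sketch above; b64_images is a placeholder list of base64-encoded images and the model name is hypothetical:

    futures = [
        client.submit(b64, model_name="yolox", dims=(1024, 1024))  # dims = (height, width)
        for b64 in b64_images
    ]
    results = [f.result() for f in futures]  # blocks; results arrive in submission order
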
+
+     def close(self):
+         """Stops the dynamic batching worker and closes client connections."""
+
+         if self.dynamic_batching_enabled:
+             self._stop_event.set()
+             # Unblock the queue in case the thread is waiting on get()
+             self._request_queue.put(None)
+             if self._batcher_thread.is_alive():
+                 self._batcher_thread.join()
+
+         if self.client:
+             self.client.close()
+
+
+ class NimClientManager:
+     """
+     A thread-safe, singleton manager for creating and sharing NimClient instances.
+
+     This manager ensures that only one NimClient is created per unique configuration.
+     """
+
+     _instance = None
+     _lock = threading.Lock()
+
+     def __new__(cls):
+         # Singleton pattern
+         if cls._instance is None:
+             with cls._lock:
+                 if cls._instance is None:
+                     cls._instance = super(NimClientManager, cls).__new__(cls)
+         return cls._instance
+
+     def __init__(self):
+         if not hasattr(self, "_initialized"):
+             with self._lock:
+                 if not hasattr(self, "_initialized"):
+                     self._clients = {}  # Key: config_hash, Value: NimClient instance
+                     self._client_lock = threading.Lock()
+                     self._initialized = True
+
+     def _generate_config_key(self, **kwargs) -> str:
+         """Creates a stable, hashable key from client configuration."""
+         sorted_config = sorted(kwargs.items())
+         config_str = json.dumps(sorted_config)
+         return hashlib.md5(config_str.encode("utf-8")).hexdigest()
+
+     def get_client(self, model_interface, **kwargs) -> "NimClient":
+         """
+         Gets or creates a NimClient for the given configuration.
+         """
+         config_key = self._generate_config_key(model_interface_name=model_interface.name(), **kwargs)
+
+         if config_key in self._clients:
+             return self._clients[config_key]
+
+         with self._client_lock:
+             if config_key in self._clients:
+                 return self._clients[config_key]
+
+             logger.debug(f"Creating new NimClient for config hash: {config_key}")
+
+             new_client = NimClient(model_interface=model_interface, **kwargs)
+
+             if new_client.dynamic_batching_enabled:
+                 new_client.start()
+
+             self._clients[config_key] = new_client
+
+             return new_client
+
+     def shutdown(self):
+         """
+         Gracefully closes all managed NimClient instances.
+         This is called automatically on application exit by `atexit`.
+         """
+         logger.debug(f"Shutting down NimClientManager and {len(self._clients)} client(s)...")
+         with self._client_lock:
+             for config_key, client in self._clients.items():
+                 logger.debug(f"Closing client for config: {config_key}")
+                 try:
+                     client.close()
+                 except Exception as e:
+                     logger.error(f"Error closing client for config {config_key}: {e}")
+             self._clients.clear()
+         logger.debug("NimClientManager shutdown complete.")
+
+
+ # A global helper function to make access even easier
+ def get_nim_client_manager() -> NimClientManager:
+     """Returns the singleton instance of the NimClientManager."""
+     return NimClientManager()
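
A sketch of going through the manager rather than constructing clients directly. The shutdown() docstring mentions atexit, so the registration below is one way a caller might wire that up; the interface and endpoints are the placeholders from the earlier sketch:

    import atexit

    manager = get_nim_client_manager()
    client = manager.get_client(
        StubInterface(),                     # placeholder interface
        protocol="grpc",
        endpoints=("localhost:8001", None),  # placeholder endpoints
    )
    atexit.register(manager.shutdown)  # close all shared clients at interpreter exit
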
+
+
+ def reload_models(
+     client: grpcclient.InferenceServerClient, exclude: Optional[list[str]] = None, client_timeout: int = 120
+ ) -> bool:
+     """
+     Reloads all models in the Triton server except for the models in the exclude list.
+
+     Parameters
+     ----------
+     client : grpcclient.InferenceServerClient
+         The gRPC client connected to the Triton server.
+     exclude : list[str], optional
+         A list of model names to exclude from reloading.
+     client_timeout : int, optional
+         Timeout for client operations in seconds (default: 120).
+
+     Returns
+     -------
+     bool
+         True if all models were successfully reloaded, False otherwise.
+     """
+     model_index = client.get_model_repository_index()
+     exclude = set(exclude or [])
+     names = [m.name for m in model_index.models if m.name not in exclude]
+
+     logger.info(f"Reloading {len(names)} model(s): {', '.join(names) if names else '(none)'}")
+
+     # 1) Unload
+     for name in names:
+         try:
+             client.unload_model(name)
+         except grpcclient.InferenceServerException as e:
+             msg = e.message()
+             if "explicit model load / unload" in msg.lower():
+                 status = e.status()
+                 logger.warning(
+                     f"[SKIP Model Reload] Explicit model control disabled; cannot unload '{name}'. Status: {status}."
+                 )
+                 return False
+             logger.error(f"[ERROR] Failed to unload '{name}': {msg}")
+             return False
+
+     # 2) Load
+     for name in names:
+         client.load_model(name)
+
+     # 3) Readiness check
+     for name in names:
+         ready = client.is_model_ready(model_name=name, client_timeout=client_timeout)
+         if not ready:
+             logger.warning(f"[Warning] Triton Not ready: {name}")
+             return False
+
+     logger.info("✅ Reload of models complete.")
+     return True
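
Finally, a sketch of calling reload_models directly. It assumes a Triton server started with explicit model control (--model-control-mode=explicit), which is the capability the unload error handler above checks for; the address and excluded model name are placeholders:

    triton = grpcclient.InferenceServerClient(url="localhost:8001")  # placeholder address
    if not reload_models(triton, exclude=["text_embedding"], client_timeout=60):
        logger.warning("Model reload failed or was skipped.")
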