nv-ingest-api 2025.4.15.dev20250415__py3-none-any.whl → 2025.4.17.dev20250417__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: the registry flags this version of nv-ingest-api as potentially problematic.

Files changed (153)
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +215 -0
  3. nv_ingest_api/interface/extract.py +972 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +218 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +200 -0
  8. nv_ingest_api/internal/enums/__init__.py +3 -0
  9. nv_ingest_api/internal/enums/common.py +494 -0
  10. nv_ingest_api/internal/extract/__init__.py +3 -0
  11. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/audio_extraction.py +149 -0
  13. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  14. nv_ingest_api/internal/extract/docx/docx_extractor.py +205 -0
  15. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  16. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +122 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +895 -0
  19. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  20. nv_ingest_api/internal/extract/image/chart_extractor.py +353 -0
  21. nv_ingest_api/internal/extract/image/image_extractor.py +204 -0
  22. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/image_helpers/common.py +403 -0
  24. nv_ingest_api/internal/extract/image/infographic_extractor.py +253 -0
  25. nv_ingest_api/internal/extract/image/table_extractor.py +344 -0
  26. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  27. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  28. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  29. nv_ingest_api/internal/extract/pdf/engines/llama.py +243 -0
  30. nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +597 -0
  31. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +146 -0
  32. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +603 -0
  33. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  34. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  35. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  36. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  37. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  38. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +799 -0
  39. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +187 -0
  40. nv_ingest_api/internal/mutate/__init__.py +3 -0
  41. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  42. nv_ingest_api/internal/mutate/filter.py +133 -0
  43. nv_ingest_api/internal/primitives/__init__.py +0 -0
  44. nv_ingest_api/{primitives → internal/primitives}/control_message_task.py +4 -0
  45. nv_ingest_api/{primitives → internal/primitives}/ingest_control_message.py +5 -2
  46. nv_ingest_api/internal/primitives/nim/__init__.py +8 -0
  47. nv_ingest_api/internal/primitives/nim/default_values.py +15 -0
  48. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  49. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  50. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  51. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  52. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +275 -0
  53. nv_ingest_api/internal/primitives/nim/model_interface/nemoretriever_parse.py +238 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/paddle.py +462 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +132 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +152 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1400 -0
  59. nv_ingest_api/internal/primitives/nim/nim_client.py +344 -0
  60. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +81 -0
  61. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  62. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  63. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  64. nv_ingest_api/internal/primitives/tracing/tagging.py +197 -0
  65. nv_ingest_api/internal/schemas/__init__.py +3 -0
  66. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  67. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +130 -0
  68. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +135 -0
  69. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +124 -0
  70. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +124 -0
  71. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +128 -0
  72. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +218 -0
  73. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +124 -0
  74. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +129 -0
  75. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  76. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +23 -0
  77. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  78. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  79. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  80. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  81. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +237 -0
  82. nv_ingest_api/internal/schemas/meta/metadata_schema.py +221 -0
  83. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  85. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  86. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  87. nv_ingest_api/internal/schemas/store/store_image_schema.py +30 -0
  88. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  89. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +15 -0
  90. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  91. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +25 -0
  92. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +22 -0
  93. nv_ingest_api/internal/store/__init__.py +3 -0
  94. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  95. nv_ingest_api/internal/store/image_upload.py +232 -0
  96. nv_ingest_api/internal/transform/__init__.py +3 -0
  97. nv_ingest_api/internal/transform/caption_image.py +205 -0
  98. nv_ingest_api/internal/transform/embed_text.py +496 -0
  99. nv_ingest_api/internal/transform/split_text.py +157 -0
  100. nv_ingest_api/util/__init__.py +0 -0
  101. nv_ingest_api/util/control_message/__init__.py +0 -0
  102. nv_ingest_api/util/control_message/validators.py +47 -0
  103. nv_ingest_api/util/converters/__init__.py +0 -0
  104. nv_ingest_api/util/converters/bytetools.py +78 -0
  105. nv_ingest_api/util/converters/containers.py +65 -0
  106. nv_ingest_api/util/converters/datetools.py +90 -0
  107. nv_ingest_api/util/converters/dftools.py +127 -0
  108. nv_ingest_api/util/converters/formats.py +64 -0
  109. nv_ingest_api/util/converters/type_mappings.py +27 -0
  110. nv_ingest_api/util/detectors/__init__.py +5 -0
  111. nv_ingest_api/util/detectors/language.py +38 -0
  112. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  113. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  114. nv_ingest_api/util/exception_handlers/decorators.py +223 -0
  115. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  116. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  117. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  118. nv_ingest_api/util/image_processing/__init__.py +5 -0
  119. nv_ingest_api/util/image_processing/clustering.py +260 -0
  120. nv_ingest_api/util/image_processing/processing.py +179 -0
  121. nv_ingest_api/util/image_processing/table_and_chart.py +449 -0
  122. nv_ingest_api/util/image_processing/transforms.py +407 -0
  123. nv_ingest_api/util/logging/__init__.py +0 -0
  124. nv_ingest_api/util/logging/configuration.py +31 -0
  125. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  126. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  127. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  128. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  129. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +435 -0
  130. nv_ingest_api/util/metadata/__init__.py +5 -0
  131. nv_ingest_api/util/metadata/aggregators.py +469 -0
  132. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  133. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +194 -0
  134. nv_ingest_api/util/nim/__init__.py +56 -0
  135. nv_ingest_api/util/pdf/__init__.py +3 -0
  136. nv_ingest_api/util/pdf/pdfium.py +427 -0
  137. nv_ingest_api/util/schema/__init__.py +0 -0
  138. nv_ingest_api/util/schema/schema_validator.py +10 -0
  139. nv_ingest_api/util/service_clients/__init__.py +3 -0
  140. nv_ingest_api/util/service_clients/client_base.py +72 -0
  141. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  142. nv_ingest_api/util/service_clients/redis/__init__.py +0 -0
  143. nv_ingest_api/util/service_clients/redis/redis_client.py +334 -0
  144. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  145. nv_ingest_api/util/service_clients/rest/rest_client.py +398 -0
  146. nv_ingest_api/util/string_processing/__init__.py +51 -0
  147. {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/METADATA +1 -1
  148. nv_ingest_api-2025.4.17.dev20250417.dist-info/RECORD +152 -0
  149. nv_ingest_api-2025.4.15.dev20250415.dist-info/RECORD +0 -9
  150. /nv_ingest_api/{primitives → internal}/__init__.py +0 -0
  151. {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/WHEEL +0 -0
  152. {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/licenses/LICENSE +0 -0
  153. {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/top_level.txt +0 -0
nv_ingest_api/internal/primitives/nim/nim_client.py
@@ -0,0 +1,344 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+ # All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ import threading
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Any
+ from typing import Optional
+ from typing import Tuple
+
+ import numpy as np
+ import requests
+ import tritonclient.grpc as grpcclient
+
+ from nv_ingest_api.internal.primitives.tracing.tagging import traceable_func
+ from nv_ingest_api.util.string_processing import generate_url
+
+ logger = logging.getLogger(__name__)
+
+
+ class NimClient:
+     """
+     A client for interfacing with a model inference server using gRPC or HTTP protocols.
+     """
+
+     def __init__(
+         self,
+         model_interface,
+         protocol: str,
+         endpoints: Tuple[str, str],
+         auth_token: Optional[str] = None,
+         timeout: float = 120.0,
+         max_retries: int = 5,
+     ):
+         """
+         Initialize the NimClient with the specified model interface, protocol, and server endpoints.
+
+         Parameters
+         ----------
+         model_interface : ModelInterface
+             The model interface implementation to use.
+         protocol : str
+             The protocol to use ("grpc" or "http").
+         endpoints : tuple
+             A tuple containing the gRPC and HTTP endpoints.
+         auth_token : str, optional
+             Authorization token for HTTP requests (default: None).
+         timeout : float, optional
+             Timeout for HTTP requests in seconds (default: 120.0).
+         max_retries : int, optional
+             Maximum number of retry attempts for HTTP requests (default: 5).
+
+         Raises
+         ------
+         ValueError
+             If an invalid protocol is specified or if required endpoints are missing.
+         """
+
+         self.client = None
+         self.model_interface = model_interface
+         self.protocol = protocol.lower()
+         self.auth_token = auth_token
+         self.timeout = timeout  # Timeout for HTTP requests
+         self.max_retries = max_retries
+         self._grpc_endpoint, self._http_endpoint = endpoints
+         self._max_batch_sizes = {}
+         self._lock = threading.Lock()
+
+         if self.protocol == "grpc":
+             if not self._grpc_endpoint:
+                 raise ValueError("gRPC endpoint must be provided for gRPC protocol")
+             logger.debug(f"Creating gRPC client with {self._grpc_endpoint}")
+             self.client = grpcclient.InferenceServerClient(url=self._grpc_endpoint)
+         elif self.protocol == "http":
+             if not self._http_endpoint:
+                 raise ValueError("HTTP endpoint must be provided for HTTP protocol")
+             logger.debug(f"Creating HTTP client with {self._http_endpoint}")
+             self.endpoint_url = generate_url(self._http_endpoint)
+             self.headers = {"accept": "application/json", "content-type": "application/json"}
+             if self.auth_token:
+                 self.headers["Authorization"] = f"Bearer {self.auth_token}"
+         else:
+             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+     def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
+         """Fetch the maximum batch size from the Triton model configuration in a thread-safe manner."""
+         if model_name in self._max_batch_sizes:
+             return self._max_batch_sizes[model_name]
+
+         with self._lock:
+             # Double-check in case another thread set the value while we were waiting.
+             if model_name in self._max_batch_sizes:
+                 return self._max_batch_sizes[model_name]
+
+             if not self._grpc_endpoint:
+                 self._max_batch_sizes[model_name] = 1
+                 return 1
+
+             try:
+                 client = self.client if self.client else grpcclient.InferenceServerClient(url=self._grpc_endpoint)
+                 model_config = client.get_model_config(model_name=model_name, model_version=model_version)
+                 self._max_batch_sizes[model_name] = model_config.config.max_batch_size
+                 logger.debug(f"Max batch size for model '{model_name}': {self._max_batch_sizes[model_name]}")
+             except Exception as e:
+                 self._max_batch_sizes[model_name] = 1
+                 logger.warning(f"Failed to retrieve max batch size: {e}, defaulting to 1")
+
+             return self._max_batch_sizes[model_name]
+
+     def _process_batch(self, batch_input, *, batch_data, model_name, **kwargs):
+         """
+         Process a single batch input for inference using its corresponding batch_data.
+
+         Parameters
+         ----------
+         batch_input : Any
+             The input data for this batch.
+         batch_data : Any
+             The corresponding scratch-pad data for this batch as returned by format_input.
+         model_name : str
+             The model name for inference.
+         kwargs : dict
+             Additional parameters.
+
+         Returns
+         -------
+         tuple
+             A tuple (parsed_output, batch_data) for subsequent post-processing.
+         """
+         if self.protocol == "grpc":
+             logger.debug("Performing gRPC inference for a batch...")
+             response = self._grpc_infer(batch_input, model_name)
+             logger.debug("gRPC inference received response for a batch")
+         elif self.protocol == "http":
+             logger.debug("Performing HTTP inference for a batch...")
+             response = self._http_infer(batch_input)
+             logger.debug("HTTP inference received response for a batch")
+         else:
+             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+         parsed_output = self.model_interface.parse_output(response, protocol=self.protocol, data=batch_data, **kwargs)
+         return parsed_output, batch_data
+
+     def try_set_max_batch_size(self, model_name, model_version: str = ""):
+         """Attempt to set the max batch size for the model if it is not already set, ensuring thread safety."""
+         self._fetch_max_batch_size(model_name, model_version)
+
+     @traceable_func(trace_name="{stage_name}::{model_name}")
+     def infer(self, data: dict, model_name: str, **kwargs) -> Any:
+         """
+         Perform inference using the specified model and input data.
+
+         Parameters
+         ----------
+         data : dict
+             The input data for inference.
+         model_name : str
+             The model name.
+         kwargs : dict
+             Additional parameters for inference.
+
+         Returns
+         -------
+         Any
+             The processed inference results, coalesced in the same order as the input images.
+         """
+         try:
+             # 1. Retrieve or default to the model's maximum batch size.
+             batch_size = self._fetch_max_batch_size(model_name)
+             max_requested_batch_size = kwargs.get("max_batch_size", batch_size)
+             force_requested_batch_size = kwargs.get("force_max_batch_size", False)
+             max_batch_size = (
+                 min(batch_size, max_requested_batch_size)
+                 if not force_requested_batch_size
+                 else max_requested_batch_size
+             )
+
+             # 2. Prepare data for inference.
+             data = self.model_interface.prepare_data_for_inference(data)
+
+             # 3. Format the input based on protocol.
+             formatted_batches, formatted_batch_data = self.model_interface.format_input(
+                 data, protocol=self.protocol, max_batch_size=max_batch_size, model_name=model_name
+             )
+
+             # Check for a custom maximum pool worker count, and remove it from kwargs.
+             max_pool_workers = kwargs.pop("max_pool_workers", 16)
+
+             # 4. Process each batch concurrently using a thread pool.
+             # We enumerate the batches so that we can later reassemble results in order.
+             results = [None] * len(formatted_batches)
+             with ThreadPoolExecutor(max_workers=max_pool_workers) as executor:
+                 futures = []
+                 for idx, (batch, batch_data) in enumerate(zip(formatted_batches, formatted_batch_data)):
+                     future = executor.submit(
+                         self._process_batch, batch, batch_data=batch_data, model_name=model_name, **kwargs
+                     )
+                     futures.append((idx, future))
+                 for idx, future in futures:
+                     results[idx] = future.result()
+
+             # 5. Process the parsed outputs for each batch using its corresponding batch_data.
+             # As the batches are in order, we coalesce their outputs accordingly.
+             all_results = []
+             for parsed_output, batch_data in results:
+                 batch_results = self.model_interface.process_inference_results(
+                     parsed_output,
+                     original_image_shapes=batch_data.get("original_image_shapes"),
+                     protocol=self.protocol,
+                     **kwargs,
+                 )
+                 if isinstance(batch_results, list):
+                     all_results.extend(batch_results)
+                 else:
+                     all_results.append(batch_results)
+
+         except Exception as err:
+             error_str = f"Error during NimClient inference [{self.model_interface.name()}, {self.protocol}]: {err}"
+             logger.error(error_str)
+             raise RuntimeError(error_str) from err
+
+         return all_results
+
+     def _grpc_infer(self, formatted_input: np.ndarray, model_name: str) -> np.ndarray:
+         """
+         Perform inference using the gRPC protocol.
+
+         Parameters
+         ----------
+         formatted_input : np.ndarray
+             The input data formatted as a numpy array.
+         model_name : str
+             The name of the model to use for inference.
+
+         Returns
+         -------
+         np.ndarray
+             The output of the model as a numpy array.
+         """
+
+         input_tensors = [grpcclient.InferInput("input", formatted_input.shape, datatype="FP32")]
+         input_tensors[0].set_data_from_numpy(formatted_input)
+
+         outputs = [grpcclient.InferRequestedOutput("output")]
+         response = self.client.infer(model_name=model_name, inputs=input_tensors, outputs=outputs)
+         logger.debug(f"gRPC inference response: {response}")
+
+         # TODO(self.client.has_error(response)) => raise error
+
+         return response.as_numpy("output")
+
+     def _http_infer(self, formatted_input: dict) -> dict:
+         """
+         Perform inference using the HTTP protocol, retrying timeouts, 429s, and 5xx errors up to
+         self.max_retries times with exponential backoff.
+
+         Parameters
+         ----------
+         formatted_input : dict
+             The input data formatted as a dictionary.
+
+         Returns
+         -------
+         dict
+             The output of the model as a dictionary.
+
+         Raises
+         ------
+         TimeoutError
+             If the HTTP request times out repeatedly, up to the max retries.
+         requests.RequestException
+             For other HTTP-related errors that persist after max retries.
+         """
+
+         base_delay = 2.0
+         attempt = 0
+
+         while attempt < self.max_retries:
+             try:
+                 response = requests.post(
+                     self.endpoint_url, json=formatted_input, headers=self.headers, timeout=self.timeout
+                 )
+                 status_code = response.status_code
+
+                 # Check for server-side or rate-limit type errors,
+                 # e.g. 5xx => server error, 429 => too many requests.
+                 if status_code == 429 or status_code == 503 or (500 <= status_code < 600):
+                     logger.warning(
+                         f"Received HTTP {status_code} ({response.reason}) from "
+                         f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
+                     )
+                     if attempt == self.max_retries - 1:
+                         # No more retries left
+                         logger.error(f"Max retries exceeded after receiving HTTP {status_code}.")
+                         response.raise_for_status()  # raise the appropriate HTTPError
+                     else:
+                         # Exponential backoff
+                         backoff_time = base_delay * (2**attempt)
+                         time.sleep(backoff_time)
+                         attempt += 1
+                         continue
+                 else:
+                     # Not in our "retry" category => just raise_for_status or return
+                     response.raise_for_status()
+                     logger.debug(f"HTTP inference response: {response.json()}")
+                     return response.json()
+
+             except requests.Timeout:
+                 # Treat timeouts similarly to 5xx => attempt a retry
+                 logger.warning(
+                     f"HTTP request timed out after {self.timeout} seconds during {self.model_interface.name()} "
+                     f"inference. Attempt {attempt + 1} of {self.max_retries}."
+                 )
+                 if attempt == self.max_retries - 1:
+                     logger.error("Max retries exceeded after repeated timeouts.")
+                     raise TimeoutError(
+                         f"Repeated timeouts for {self.model_interface.name()} after {attempt + 1} attempts."
+                     )
+                 # Exponential backoff
+                 backoff_time = base_delay * (2**attempt)
+                 time.sleep(backoff_time)
+                 attempt += 1
+
+             except requests.HTTPError as http_err:
+                 # If we ended up here, it's a non-retryable 4xx or a final 5xx after the last attempt
+                 logger.error(f"HTTP request failed with status code {response.status_code}: {http_err}")
+                 raise
+
+             except requests.RequestException as e:
+                 # ConnectionError or other non-HTTPError
+                 logger.error(f"HTTP request encountered a network issue: {e}")
+                 if attempt == self.max_retries - 1:
+                     raise
+                 # Else retry on next loop iteration
+                 backoff_time = base_delay * (2**attempt)
+                 time.sleep(backoff_time)
+                 attempt += 1
+
+         # If we exit the loop without returning, we've exhausted all attempts
+         logger.error(f"Failed to get a successful response after {self.max_retries} retries.")
+         raise Exception(f"Failed to get a successful response after {self.max_retries} retries.")
+
+     def close(self):
+         if self.protocol == "grpc" and hasattr(self.client, "close"):
+             self.client.close()
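
For orientation, here is a minimal usage sketch of the client above. The concrete model interface variable and both endpoint addresses are illustrative placeholders, not values shipped in this package:

from nv_ingest_api.internal.primitives.nim.nim_client import NimClient

# `my_model_interface` stands in for any concrete ModelInterface implementation
# (see nim_model_interface.py below); the endpoint addresses are hypothetical.
client = NimClient(
    model_interface=my_model_interface,
    protocol="http",
    endpoints=("localhost:8001", "http://localhost:8000/v1/infer"),
    auth_token=None,
    timeout=120.0,
    max_retries=5,
)
try:
    # The shape of `data` and the model name depend entirely on the interface in use.
    results = client.infer({"base64_images": []}, model_name="my_model")
finally:
    client.close()
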
nv_ingest_api/internal/primitives/nim/nim_model_interface.py
@@ -0,0 +1,81 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+ # All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import logging
+ from typing import Optional
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelInterface:
+     """
+     Base class for defining a model interface that supports preparing input data, formatting it for
+     inference, parsing output, and processing inference results.
+     """
+
+     def format_input(self, data: dict, protocol: str, max_batch_size: int):
+         """
+         Format the input data for the specified protocol.
+
+         Parameters
+         ----------
+         data : dict
+             The input data to format.
+         protocol : str
+             The protocol to format the data for.
+         max_batch_size : int
+             The maximum number of items to include in a single batch.
+         """
+
+         raise NotImplementedError("Subclasses should implement this method")
+
+     def parse_output(self, response, protocol: str, data: Optional[dict] = None, **kwargs):
+         """
+         Parse the output data from the model's inference response.
+
+         Parameters
+         ----------
+         response : Any
+             The response from the model inference.
+         protocol : str
+             The protocol used ("grpc" or "http").
+         data : dict, optional
+             Additional input data passed to the function.
+         """
+
+         raise NotImplementedError("Subclasses should implement this method")
+
+     def prepare_data_for_inference(self, data: dict):
+         """
+         Prepare input data for inference by processing or transforming it as required.
+
+         Parameters
+         ----------
+         data : dict
+             The input data to prepare.
+         """
+         raise NotImplementedError("Subclasses should implement this method")
+
+     def process_inference_results(self, output_array, protocol: str, **kwargs):
+         """
+         Process the inference results from the model.
+
+         Parameters
+         ----------
+         output_array : Any
+             The raw output from the model.
+         protocol : str
+             The protocol used ("grpc" or "http").
+         kwargs : dict
+             Additional parameters for processing.
+         """
+         raise NotImplementedError("Subclasses should implement this method")
+
+     def name(self) -> str:
+         """
+         Get the name of the model interface.
+
+         Returns
+         -------
+         str
+             The name of the model interface.
+         """
+         raise NotImplementedError("Subclasses should implement this method")
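
As a sketch of the contract this base class defines (the subclass below is hypothetical, not one of the interfaces shipped under model_interface/; it only illustrates the batches/batch_data pairing that NimClient.infer expects from format_input):

from typing import Optional

from nv_ingest_api.internal.primitives.nim.nim_model_interface import ModelInterface


class PassthroughModelInterface(ModelInterface):
    """Illustrative subclass: slices input into batches and echoes server output."""

    def name(self) -> str:
        return "passthrough"

    def prepare_data_for_inference(self, data: dict) -> dict:
        return data

    def format_input(self, data: dict, protocol: str, max_batch_size: int, **kwargs):
        items = data["items"]
        # One entry per batch, plus matching scratch-pad data consumed by parse_output.
        batches = [items[i : i + max_batch_size] for i in range(0, len(items), max_batch_size)]
        return batches, [{"items": b} for b in batches]

    def parse_output(self, response, protocol: str, data: Optional[dict] = None, **kwargs):
        return response

    def process_inference_results(self, output_array, protocol: str, **kwargs):
        return output_array
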
nv_ingest_api/internal/primitives/tracing/latency.py
@@ -0,0 +1,69 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+ # All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ import logging
+ from datetime import datetime
+ from functools import wraps
+
+ logger = logging.getLogger(__name__)
+
+
+ # Define ANSI color codes
+ class ColorCodes:
+     RED = "\033[91m"
+     GREEN = "\033[92m"
+     YELLOW = "\033[93m"
+     BLUE = "\033[94m"
+     RESET = "\033[0m"
+
+
+ # Function to apply color to a message
+ def colorize(message, color_code):
+     return f"{color_code}{message}{ColorCodes.RESET}"
+
+
+ def latency_logger(name=None):
+     """
+     A decorator to log the elapsed time of function execution. If available, it also logs
+     the latency based on 'latency::ts_send' metadata in an IngestControlMessage object.
+
+     Parameters
+     ----------
+     name : str, optional
+         Custom name to use in the log message. Defaults to the function's name.
+     """
+
+     def decorator(func):
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             # Ensure there's at least one argument and it has timestamp handling capabilities
+             if args and hasattr(args[0], "get_timestamp"):
+                 message = args[0]
+                 start_time = datetime.now()
+
+                 result = func(*args, **kwargs)
+
+                 end_time = datetime.now()
+                 elapsed_time = end_time - start_time
+
+                 func_name = name if name else func.__name__
+
+                 # Log latency from ts_send if available
+                 if message.filter_timestamp("latency::ts_send"):
+                     ts_send = message.get_timestamp("latency::ts_send")
+                     latency_ms = (start_time - ts_send).total_seconds() * 1e3
+                     logger.debug(f"{func_name} since ts_send: {latency_ms} msec.")
+
+                 message.set_timestamp("latency::ts_send", datetime.now())
+                 message.set_timestamp(f"latency::{func_name}::elapsed_time", elapsed_time)
+                 return result
+             else:
+                 raise ValueError(
+                     "The first argument must be an IngestControlMessage object with metadata capabilities."
+                 )
+
+         return wrapper
+
+     return decorator
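
A short usage sketch for the decorator (the stage function and its name are illustrative; the first positional argument only needs the get_timestamp/set_timestamp/filter_timestamp methods the wrapper calls, e.g. an IngestControlMessage):

from nv_ingest_api.internal.primitives.tracing.latency import latency_logger


@latency_logger(name="example_stage")
def run_stage(message):
    # ... do stage work on the message ...
    return message
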
nv_ingest_api/internal/primitives/tracing/logging.py
@@ -0,0 +1,96 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+ # All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ import inspect
+ import uuid
+ from datetime import datetime
+ from enum import Enum
+
+ from nv_ingest_api.internal.primitives.ingest_control_message import IngestControlMessage
+
+
+ class TaskResultStatus(Enum):
+     SUCCESS = "SUCCESS"
+     FAILURE = "FAILURE"
+
+
+ def annotate_cm(control_message: IngestControlMessage, source_id=None, **kwargs):
+     """
+     Annotate an IngestControlMessage object with arbitrary metadata, a source ID, and a timestamp.
+     Each annotation is uniquely identified by a UUID.
+
+     Parameters:
+     - control_message: The IngestControlMessage object to be annotated.
+     - source_id: A unique identifier for the source of the annotation. If None, uses the caller's __name__.
+     - **kwargs: Arbitrary key-value pairs to be included in the annotation.
+     """
+     if source_id is None:
+         # Determine the __name__ of the parent caller's module
+         frame = inspect.currentframe()
+         caller_frame = inspect.getouterframes(frame)[2]
+         module = inspect.getmodule(caller_frame[0])
+         source_id = module.__name__ if module is not None else "UnknownModule"
+
+     # Ensure 'annotation_timestamp' is not overridden by kwargs
+     if "annotation_timestamp" in kwargs:
+         raise ValueError("'annotation_timestamp' is a reserved key and cannot be specified.")
+
+     message = kwargs.get("message")
+     annotation_key = f"annotation::{message}" if message else f"annotation::{uuid.uuid4()}"
+
+     annotation_timestamp = datetime.now()
+     try:
+         control_message.set_timestamp(annotation_key, annotation_timestamp)
+     except Exception as e:
+         print(f"Failed to set annotation timestamp: {e}")
+
+     # Construct the metadata key, uniquely identified by a UUID.
+     metadata_key = f"annotation::{uuid.uuid4()}"
+
+     # Construct the metadata value from the source_id and any provided kwargs.
+     metadata_value = {
+         "source_id": source_id,
+     }
+     metadata_value.update(kwargs)
+
+     try:
+         # Attempt to set the annotated metadata on the IngestControlMessage object.
+         control_message.set_metadata(metadata_key, metadata_value)
+     except Exception as e:
+         # Handle any exceptions that occur when setting metadata.
+         print(f"Failed to annotate IngestControlMessage: {e}")
+
+
+ def annotate_task_result(control_message, result, task_id, source_id=None, **kwargs):
+     """
+     Annotate an IngestControlMessage object with the result of a task, identified by a task_id,
+     and an arbitrary number of additional key-value pairs. The result can be a TaskResultStatus
+     enum or a string that will be converted to the corresponding enum.
+
+     Parameters:
+     - control_message: The IngestControlMessage object to be annotated.
+     - result: The result of the task, either SUCCESS or FAILURE, as an enum or string.
+     - task_id: A unique identifier for the task.
+     - **kwargs: Arbitrary additional key-value pairs to be included in the annotation.
+     """
+     # Convert result to TaskResultStatus enum if it's a string
+     if isinstance(result, str):
+         try:
+             result = TaskResultStatus[result.upper()]
+         except KeyError:
+             raise ValueError(
+                 f"Invalid result string: {result}. Must be one of {[status.name for status in TaskResultStatus]}."
+             )
+     elif not isinstance(result, TaskResultStatus):
+         raise ValueError("result must be an instance of TaskResultStatus Enum or a valid result string.")
+
+     # Annotate the control message with task-related information, including the result and task_id.
+     annotate_cm(
+         control_message,
+         source_id=source_id,
+         task_result=result.value,
+         task_id=task_id,
+         **kwargs,
+     )
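
A usage sketch for the annotation helpers above (the task and source identifiers are illustrative):

from nv_ingest_api.internal.primitives.tracing.logging import annotate_task_result

# `control_message` is assumed to be an existing IngestControlMessage instance;
# `result` may be a TaskResultStatus member or the strings "SUCCESS"/"FAILURE".
annotate_task_result(
    control_message,
    result="SUCCESS",
    task_id="example_task",
    source_id="example_stage",
)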