nv-ingest-api 2025.4.21.dev20250421__py3-none-any.whl → 2025.4.23.dev20250423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nv-ingest-api might be problematic. Click here for more details.
- nv_ingest_api/__init__.py +3 -0
- nv_ingest_api/interface/__init__.py +215 -0
- nv_ingest_api/interface/extract.py +972 -0
- nv_ingest_api/interface/mutate.py +154 -0
- nv_ingest_api/interface/store.py +218 -0
- nv_ingest_api/interface/transform.py +382 -0
- nv_ingest_api/interface/utility.py +200 -0
- nv_ingest_api/internal/enums/__init__.py +3 -0
- nv_ingest_api/internal/enums/common.py +494 -0
- nv_ingest_api/internal/extract/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/audio_extraction.py +149 -0
- nv_ingest_api/internal/extract/docx/__init__.py +5 -0
- nv_ingest_api/internal/extract/docx/docx_extractor.py +205 -0
- nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +122 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +895 -0
- nv_ingest_api/internal/extract/image/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/chart_extractor.py +353 -0
- nv_ingest_api/internal/extract/image/image_extractor.py +204 -0
- nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/image_helpers/common.py +403 -0
- nv_ingest_api/internal/extract/image/infographic_extractor.py +253 -0
- nv_ingest_api/internal/extract/image/table_extractor.py +344 -0
- nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
- nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
- nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
- nv_ingest_api/internal/extract/pdf/engines/llama.py +243 -0
- nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +597 -0
- nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +146 -0
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +603 -0
- nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
- nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
- nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
- nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
- nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +799 -0
- nv_ingest_api/internal/extract/pptx/pptx_extractor.py +187 -0
- nv_ingest_api/internal/mutate/__init__.py +3 -0
- nv_ingest_api/internal/mutate/deduplicate.py +110 -0
- nv_ingest_api/internal/mutate/filter.py +133 -0
- nv_ingest_api/internal/primitives/__init__.py +0 -0
- nv_ingest_api/{primitives → internal/primitives}/control_message_task.py +4 -0
- nv_ingest_api/{primitives → internal/primitives}/ingest_control_message.py +5 -2
- nv_ingest_api/internal/primitives/nim/__init__.py +8 -0
- nv_ingest_api/internal/primitives/nim/default_values.py +15 -0
- nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
- nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
- nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
- nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
- nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +275 -0
- nv_ingest_api/internal/primitives/nim/model_interface/nemoretriever_parse.py +238 -0
- nv_ingest_api/internal/primitives/nim/model_interface/paddle.py +462 -0
- nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
- nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +132 -0
- nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +152 -0
- nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1400 -0
- nv_ingest_api/internal/primitives/nim/nim_client.py +344 -0
- nv_ingest_api/internal/primitives/nim/nim_model_interface.py +81 -0
- nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
- nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
- nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
- nv_ingest_api/internal/primitives/tracing/tagging.py +197 -0
- nv_ingest_api/internal/schemas/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +130 -0
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +135 -0
- nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +124 -0
- nv_ingest_api/internal/schemas/extract/extract_image_schema.py +124 -0
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +128 -0
- nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +218 -0
- nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +124 -0
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +129 -0
- nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
- nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +23 -0
- nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
- nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
- nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
- nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +237 -0
- nv_ingest_api/internal/schemas/meta/metadata_schema.py +221 -0
- nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
- nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
- nv_ingest_api/internal/schemas/store/__init__.py +3 -0
- nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
- nv_ingest_api/internal/schemas/store/store_image_schema.py +30 -0
- nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
- nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +15 -0
- nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +25 -0
- nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +22 -0
- nv_ingest_api/internal/store/__init__.py +3 -0
- nv_ingest_api/internal/store/embed_text_upload.py +236 -0
- nv_ingest_api/internal/store/image_upload.py +232 -0
- nv_ingest_api/internal/transform/__init__.py +3 -0
- nv_ingest_api/internal/transform/caption_image.py +205 -0
- nv_ingest_api/internal/transform/embed_text.py +496 -0
- nv_ingest_api/internal/transform/split_text.py +157 -0
- nv_ingest_api/util/__init__.py +0 -0
- nv_ingest_api/util/control_message/__init__.py +0 -0
- nv_ingest_api/util/control_message/validators.py +47 -0
- nv_ingest_api/util/converters/__init__.py +0 -0
- nv_ingest_api/util/converters/bytetools.py +78 -0
- nv_ingest_api/util/converters/containers.py +65 -0
- nv_ingest_api/util/converters/datetools.py +90 -0
- nv_ingest_api/util/converters/dftools.py +127 -0
- nv_ingest_api/util/converters/formats.py +64 -0
- nv_ingest_api/util/converters/type_mappings.py +27 -0
- nv_ingest_api/util/detectors/__init__.py +5 -0
- nv_ingest_api/util/detectors/language.py +38 -0
- nv_ingest_api/util/exception_handlers/__init__.py +0 -0
- nv_ingest_api/util/exception_handlers/converters.py +72 -0
- nv_ingest_api/util/exception_handlers/decorators.py +223 -0
- nv_ingest_api/util/exception_handlers/detectors.py +74 -0
- nv_ingest_api/util/exception_handlers/pdf.py +116 -0
- nv_ingest_api/util/exception_handlers/schemas.py +68 -0
- nv_ingest_api/util/image_processing/__init__.py +5 -0
- nv_ingest_api/util/image_processing/clustering.py +260 -0
- nv_ingest_api/util/image_processing/processing.py +179 -0
- nv_ingest_api/util/image_processing/table_and_chart.py +449 -0
- nv_ingest_api/util/image_processing/transforms.py +407 -0
- nv_ingest_api/util/logging/__init__.py +0 -0
- nv_ingest_api/util/logging/configuration.py +31 -0
- nv_ingest_api/util/message_brokers/__init__.py +3 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +451 -0
- nv_ingest_api/util/metadata/__init__.py +5 -0
- nv_ingest_api/util/metadata/aggregators.py +469 -0
- nv_ingest_api/util/multi_processing/__init__.py +8 -0
- nv_ingest_api/util/multi_processing/mp_pool_singleton.py +194 -0
- nv_ingest_api/util/nim/__init__.py +56 -0
- nv_ingest_api/util/pdf/__init__.py +3 -0
- nv_ingest_api/util/pdf/pdfium.py +427 -0
- nv_ingest_api/util/schema/__init__.py +0 -0
- nv_ingest_api/util/schema/schema_validator.py +10 -0
- nv_ingest_api/util/service_clients/__init__.py +3 -0
- nv_ingest_api/util/service_clients/client_base.py +86 -0
- nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/__init__.py +0 -0
- nv_ingest_api/util/service_clients/redis/redis_client.py +823 -0
- nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
- nv_ingest_api/util/service_clients/rest/rest_client.py +531 -0
- nv_ingest_api/util/string_processing/__init__.py +51 -0
- {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/METADATA +1 -1
- nv_ingest_api-2025.4.23.dev20250423.dist-info/RECORD +152 -0
- {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/WHEEL +1 -1
- nv_ingest_api-2025.4.21.dev20250421.dist-info/RECORD +0 -9
- /nv_ingest_api/{primitives → internal}/__init__.py +0 -0
- {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/licenses/LICENSE +0 -0
- {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import threading
|
|
7
|
+
import time
|
|
8
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
9
|
+
from typing import Any
|
|
10
|
+
from typing import Optional
|
|
11
|
+
from typing import Tuple
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import requests
|
|
15
|
+
import tritonclient.grpc as grpcclient
|
|
16
|
+
|
|
17
|
+
from nv_ingest_api.internal.primitives.tracing.tagging import traceable_func
|
|
18
|
+
from nv_ingest_api.util.string_processing import generate_url
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class NimClient:
    """
    A client for interfacing with a model inference server using gRPC or HTTP protocols.
    """

    def __init__(
        self,
        model_interface,
        protocol: str,
        endpoints: Tuple[str, str],
        auth_token: Optional[str] = None,
        timeout: float = 120.0,
        max_retries: int = 5,
    ):
        """
        Initialize the NimClient with the specified model interface, protocol, and server endpoints.

        Parameters
        ----------
        model_interface : ModelInterface
            The model interface implementation to use.
        protocol : str
            The protocol to use ("grpc" or "http"); matched case-insensitively.
        endpoints : tuple
            A tuple ``(grpc_endpoint, http_endpoint)``. The endpoint for the selected protocol is
            required. When a gRPC endpoint is present it is also used to query the model's max
            batch size, even in HTTP mode.
        auth_token : str, optional
            Bearer token sent in the HTTP ``Authorization`` header (default: None).
        timeout : float, optional
            Timeout for HTTP requests in seconds (default: 120.0).
        max_retries : int, optional
            Maximum number of HTTP attempts for retryable failures (default: 5).

        Raises
        ------
        ValueError
            If an invalid protocol is specified or if required endpoints are missing.
        """

        self.client = None
        self.model_interface = model_interface
        self.protocol = protocol.lower()
        self.auth_token = auth_token
        self.timeout = timeout  # Timeout for HTTP requests
        self.max_retries = max_retries
        self._grpc_endpoint, self._http_endpoint = endpoints
        self._max_batch_sizes = {}  # model_name -> cached max batch size
        self._lock = threading.Lock()  # guards population of _max_batch_sizes

        if self.protocol == "grpc":
            if not self._grpc_endpoint:
                raise ValueError("gRPC endpoint must be provided for gRPC protocol")
            logger.debug(f"Creating gRPC client with {self._grpc_endpoint}")
            self.client = grpcclient.InferenceServerClient(url=self._grpc_endpoint)
        elif self.protocol == "http":
            if not self._http_endpoint:
                raise ValueError("HTTP endpoint must be provided for HTTP protocol")
            logger.debug(f"Creating HTTP client with {self._http_endpoint}")
            self.endpoint_url = generate_url(self._http_endpoint)
            self.headers = {"accept": "application/json", "content-type": "application/json"}
            if self.auth_token:
                self.headers["Authorization"] = f"Bearer {self.auth_token}"
        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
        """
        Fetch the maximum batch size from the Triton model configuration in a thread-safe manner.

        The value is cached per ``model_name``. The method falls back to 1 when no gRPC endpoint
        is configured or when the configuration query fails.
        """
        # Fast path: the value is already cached.
        if model_name in self._max_batch_sizes:
            return self._max_batch_sizes[model_name]

        with self._lock:
            # Double check, just in case another thread set the value while we were waiting
            if model_name in self._max_batch_sizes:
                return self._max_batch_sizes[model_name]

            if not self._grpc_endpoint:
                self._max_batch_sizes[model_name] = 1
                return 1

            # Reuse the long-lived gRPC client when one exists (gRPC protocol). Otherwise open a
            # short-lived client just for this query and close it afterwards so it is not leaked.
            temp_client = None
            try:
                if self.client is not None:
                    client = self.client
                else:
                    temp_client = grpcclient.InferenceServerClient(url=self._grpc_endpoint)
                    client = temp_client
                model_config = client.get_model_config(model_name=model_name, model_version=model_version)
                self._max_batch_sizes[model_name] = model_config.config.max_batch_size
                logger.debug(f"Max batch size for model '{model_name}': {self._max_batch_sizes[model_name]}")
            except Exception as e:
                self._max_batch_sizes[model_name] = 1
                logger.warning(f"Failed to retrieve max batch size: {e}, defaulting to 1")
            finally:
                if temp_client is not None:
                    try:
                        temp_client.close()
                    except Exception as close_err:
                        logger.debug(f"Failed to close temporary gRPC client: {close_err}")

            return self._max_batch_sizes[model_name]

    def _process_batch(self, batch_input, *, batch_data, model_name, **kwargs):
        """
        Process a single batch input for inference using its corresponding batch_data.

        Parameters
        ----------
        batch_input : Any
            The input data for this batch.
        batch_data : Any
            The corresponding scratch-pad data for this batch as returned by format_input.
        model_name : str
            The model name for inference.
        kwargs : dict
            Additional parameters, forwarded to ``model_interface.parse_output``.

        Returns
        -------
        tuple
            A tuple (parsed_output, batch_data) for subsequent post-processing.
        """
        if self.protocol == "grpc":
            logger.debug("Performing gRPC inference for a batch...")
            response = self._grpc_infer(batch_input, model_name)
            logger.debug("gRPC inference received response for a batch")
        elif self.protocol == "http":
            logger.debug("Performing HTTP inference for a batch...")
            response = self._http_infer(batch_input)
            logger.debug("HTTP inference received response for a batch")
        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

        parsed_output = self.model_interface.parse_output(response, protocol=self.protocol, data=batch_data, **kwargs)
        return parsed_output, batch_data

    def try_set_max_batch_size(self, model_name, model_version: str = ""):
        """Attempt to set the max batch size for the model if it is not already set, ensuring thread safety."""
        self._fetch_max_batch_size(model_name, model_version)

    @traceable_func(trace_name="{stage_name}::{model_name}")
    def infer(self, data: dict, model_name: str, **kwargs) -> Any:
        """
        Perform inference using the specified model and input data.

        Parameters
        ----------
        data : dict
            The input data for inference.
        model_name : str
            The model name.
        kwargs : dict
            Additional parameters for inference. Keys read here:

            - ``max_batch_size``: requested batch size. It is capped at the model's configured
              max batch size unless ``force_max_batch_size`` is truthy.
            - ``force_max_batch_size``: when truthy, ``max_batch_size`` is used as-is.
            - ``max_pool_workers``: thread-pool size for concurrent batches (default 16). It is
              removed from kwargs and not forwarded.

            All remaining keys are forwarded to ``parse_output`` and ``process_inference_results``.

        Returns
        -------
        Any
            The processed inference results, coalesced in the same order as the input images.

        Raises
        ------
        RuntimeError
            If any stage fails. The original exception is chained as ``__cause__``.
        """
        try:
            # 1. Retrieve or default to the model's maximum batch size.
            batch_size = self._fetch_max_batch_size(model_name)
            max_requested_batch_size = kwargs.get("max_batch_size", batch_size)
            force_requested_batch_size = kwargs.get("force_max_batch_size", False)
            max_batch_size = (
                min(batch_size, max_requested_batch_size)
                if not force_requested_batch_size
                else max_requested_batch_size
            )

            # 2. Prepare data for inference.
            data = self.model_interface.prepare_data_for_inference(data)

            # 3. Format the input based on protocol.
            formatted_batches, formatted_batch_data = self.model_interface.format_input(
                data, protocol=self.protocol, max_batch_size=max_batch_size, model_name=model_name
            )

            # Check for a custom maximum pool worker count, and remove it from kwargs.
            max_pool_workers = kwargs.pop("max_pool_workers", 16)

            # 4. Process each batch concurrently using a thread pool.
            # We enumerate the batches so that we can later reassemble results in order.
            results = [None] * len(formatted_batches)
            with ThreadPoolExecutor(max_workers=max_pool_workers) as executor:
                futures = []
                for idx, (batch, batch_data) in enumerate(zip(formatted_batches, formatted_batch_data)):
                    future = executor.submit(
                        self._process_batch, batch, batch_data=batch_data, model_name=model_name, **kwargs
                    )
                    futures.append((idx, future))
                for idx, future in futures:
                    results[idx] = future.result()

            # 5. Process the parsed outputs for each batch using its corresponding batch_data.
            # As the batches are in order, we coalesce their outputs accordingly.
            all_results = []
            for parsed_output, batch_data in results:
                batch_results = self.model_interface.process_inference_results(
                    parsed_output,
                    original_image_shapes=batch_data.get("original_image_shapes"),
                    protocol=self.protocol,
                    **kwargs,
                )
                if isinstance(batch_results, list):
                    all_results.extend(batch_results)
                else:
                    all_results.append(batch_results)

        except Exception as err:
            error_str = f"Error during NimClient inference [{self.model_interface.name()}, {self.protocol}]: {err}"
            logger.error(error_str)
            # Chain the original exception so its traceback is preserved for debugging.
            raise RuntimeError(error_str) from err

        return all_results

    def _grpc_infer(self, formatted_input: np.ndarray, model_name: str) -> np.ndarray:
        """
        Perform inference using the gRPC protocol.

        The array is sent as a single FP32 tensor named "input", and the tensor named "output"
        is requested back.

        Parameters
        ----------
        formatted_input : np.ndarray
            The input data formatted as a numpy array.
        model_name : str
            The name of the model to use for inference.

        Returns
        -------
        np.ndarray
            The output of the model as a numpy array.
        """

        input_tensors = [grpcclient.InferInput("input", formatted_input.shape, datatype="FP32")]
        input_tensors[0].set_data_from_numpy(formatted_input)

        outputs = [grpcclient.InferRequestedOutput("output")]
        response = self.client.infer(model_name=model_name, inputs=input_tensors, outputs=outputs)
        logger.debug(f"gRPC inference response: {response}")

        # TODO(self.client.has_error(response)) => raise error

        return response.as_numpy("output")

    def _http_infer(self, formatted_input: dict) -> dict:
        """
        Perform inference using the HTTP protocol, with retries and exponential backoff.

        Timeouts, other ``requests`` network errors, HTTP 429, and HTTP 5xx responses are retried
        for up to ``self.max_retries`` attempts in total. The delay between attempts is
        ``2 * 2**attempt`` seconds. Other error statuses (4xx) are raised immediately.

        Parameters
        ----------
        formatted_input : dict
            The input data formatted as a dictionary.

        Returns
        -------
        dict
            The output of the model as a dictionary.

        Raises
        ------
        TimeoutError
            If the HTTP request times out on every attempt.
        requests.HTTPError
            For a non-retryable error status, or a 429/5xx status on the final attempt.
        requests.RequestException
            For other HTTP-related errors that persist after max retries.
        Exception
            If ``self.max_retries`` is less than 1, so no attempt is made.
        """

        base_delay = 2.0
        attempt = 0

        while attempt < self.max_retries:
            try:
                response = requests.post(
                    self.endpoint_url, json=formatted_input, headers=self.headers, timeout=self.timeout
                )
                status_code = response.status_code

                # Retry on rate limiting (429) and on any server-side error (5xx, which includes 503).
                if status_code == 429 or 500 <= status_code < 600:
                    logger.warning(
                        f"Received HTTP {status_code} ({response.reason}) from "
                        f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
                    )
                    if attempt == self.max_retries - 1:
                        # No more retries left
                        logger.error(f"Max retries exceeded after receiving HTTP {status_code}.")
                        response.raise_for_status()  # raise the appropriate HTTPError
                    else:
                        # Exponential backoff
                        backoff_time = base_delay * (2**attempt)
                        time.sleep(backoff_time)
                        attempt += 1
                        continue
                else:
                    # Not in our "retry" category => raise for error statuses, otherwise return.
                    response.raise_for_status()
                    # Parse the body once and reuse it; lazy %-formatting avoids rendering large
                    # payloads when DEBUG logging is disabled.
                    result = response.json()
                    logger.debug("HTTP inference response: %s", result)
                    return result

            except requests.Timeout:
                # Treat timeouts similarly to 5xx => attempt a retry
                logger.warning(
                    f"HTTP request timed out after {self.timeout} seconds during {self.model_interface.name()} "
                    f"inference. Attempt {attempt + 1} of {self.max_retries}."
                )
                if attempt == self.max_retries - 1:
                    logger.error("Max retries exceeded after repeated timeouts.")
                    raise TimeoutError(
                        f"Repeated timeouts for {self.model_interface.name()} after {attempt + 1} attempts."
                    )
                # Exponential backoff
                backoff_time = base_delay * (2**attempt)
                time.sleep(backoff_time)
                attempt += 1

            except requests.HTTPError as http_err:
                # If we ended up here, it's a non-retryable 4xx or final 5xx after final attempt.
                # Read the status from the exception's own response rather than the loop variable.
                failed_status = getattr(http_err.response, "status_code", None)
                logger.error(f"HTTP request failed with status code {failed_status}: {http_err}")
                raise

            except requests.RequestException as e:
                # ConnectionError or other non-HTTPError
                logger.error(f"HTTP request encountered a network issue: {e}")
                if attempt == self.max_retries - 1:
                    raise
                # Else retry on next loop iteration
                backoff_time = base_delay * (2**attempt)
                time.sleep(backoff_time)
                attempt += 1

        # Only reachable when max_retries < 1, since every final-attempt path above raises.
        logger.error(f"Failed to get a successful response after {self.max_retries} retries.")
        raise Exception(f"Failed to get a successful response after {self.max_retries} retries.")

    def close(self):
        """Close the underlying gRPC client, if any. HTTP mode keeps no persistent client."""
        if self.protocol == "grpc" and hasattr(self.client, "close"):
            self.client.close()
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ModelInterface:
    """
    Base class for defining a model interface that supports preparing input data, formatting it for
    inference, parsing output, and processing inference results.

    Every method in this base class raises NotImplementedError; concrete interfaces must override
    them.
    """

    def format_input(self, data: dict, protocol: str, max_batch_size: int, **kwargs):
        """
        Format the input data for the specified protocol.

        Parameters
        ----------
        data : dict
            The input data to format.
        protocol : str
            The protocol to format the data for ("grpc" or "http").
        max_batch_size : int
            The maximum number of items to place in a single batch.
        kwargs : dict
            Additional keyword arguments from the caller (for example ``model_name``).

        Returns
        -------
        tuple
            Implementations are expected to return ``(formatted_batches, formatted_batch_data)``:
            the per-batch inputs and matching per-batch scratch-pad data.
        """

        raise NotImplementedError("Subclasses should implement this method")

    def parse_output(self, response, protocol: str, data: Optional[dict] = None, **kwargs):
        """
        Parse the output data from the model's inference response.

        Parameters
        ----------
        response : Any
            The response from the model inference.
        protocol : str
            The protocol used ("grpc" or "http").
        data : dict, optional
            Additional input data passed to the function.
        kwargs : dict
            Additional parameters for parsing.
        """

        raise NotImplementedError("Subclasses should implement this method")

    def prepare_data_for_inference(self, data: dict):
        """
        Prepare input data for inference by processing or transforming it as required.

        Parameters
        ----------
        data : dict
            The input data to prepare.
        """
        raise NotImplementedError("Subclasses should implement this method")

    def process_inference_results(self, output_array, protocol: str, **kwargs):
        """
        Process the inference results from the model.

        Parameters
        ----------
        output_array : Any
            The raw output from the model.
        protocol : str
            The protocol used ("grpc" or "http").
        kwargs : dict
            Additional parameters for processing (for example ``original_image_shapes``).
        """
        raise NotImplementedError("Subclasses should implement this method")

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface.
        """
        raise NotImplementedError("Subclasses should implement this method")
|
|
File without changes
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from functools import wraps
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Define ANSI color codes
|
|
14
|
+
class ColorCodes:
    """ANSI escape sequences for colorizing terminal text."""

    RED = "\033[91m"  # red foreground
    GREEN = "\033[92m"  # green foreground
    YELLOW = "\033[93m"  # yellow foreground
    BLUE = "\033[94m"  # blue foreground
    RESET = "\033[0m"  # restore default attributes
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Function to apply color to a message
|
|
23
|
+
def colorize(message, color_code):
    """
    Return ``message`` prefixed with ``color_code`` and followed by the ANSI reset sequence.

    ``message`` is formatted with ``format(message, "")``, the same as an f-string placeholder.
    """
    return "{}{}{}".format(color_code, message, ColorCodes.RESET)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def latency_logger(name=None):
    """
    Build a decorator that records execution latency on a control message.

    The wrapped function's first positional argument must expose ``get_timestamp``. The wrapper
    also calls its ``filter_timestamp`` and ``set_timestamp`` methods. On each call the wrapper:

    1. times the wrapped call using ``datetime.now()``;
    2. if a ``latency::ts_send`` timestamp is present, logs at DEBUG level the milliseconds
       between it and the start of this call;
    3. resets ``latency::ts_send`` to the current time and stores the elapsed ``timedelta``
       under ``latency::<name>::elapsed_time``.

    Parameters
    ----------
    name : str, optional
        Label used in the timestamp key and log message. Falls back to the wrapped function's
        ``__name__`` when empty or None.

    Raises
    ------
    ValueError
        Raised by the wrapper, before calling the function, if the first positional argument is
        missing or has no ``get_timestamp`` attribute.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Guard clause: refuse to run without a timestamp-capable first argument.
            if not (args and hasattr(args[0], "get_timestamp")):
                raise ValueError(
                    "The first argument must be a IngestControlMessage object with metadata capabilities."
                )

            control_message = args[0]
            started = datetime.now()
            outcome = func(*args, **kwargs)
            elapsed = datetime.now() - started
            label = name or func.__name__

            if control_message.filter_timestamp("latency::ts_send"):
                sent_at = control_message.get_timestamp("latency::ts_send")
                since_send_ms = (started - sent_at).total_seconds() * 1e3
                logger.debug(f"{label} since ts_send: {since_send_ms} msec.")

            control_message.set_timestamp("latency::ts_send", datetime.now())
            control_message.set_timestamp(f"latency::{label}::elapsed_time", elapsed)
            return outcome

        return wrapper

    return decorator
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
import inspect
|
|
7
|
+
import uuid
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from enum import Enum
|
|
10
|
+
|
|
11
|
+
from nv_ingest_api.internal.primitives.ingest_control_message import IngestControlMessage
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TaskResultStatus(Enum):
    """Outcome of a task; each member's value is the same as its name."""

    SUCCESS = "SUCCESS"  # task completed successfully
    FAILURE = "FAILURE"  # task failed
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def annotate_cm(control_message: IngestControlMessage, source_id=None, **kwargs):
    """
    Annotate an IngestControlMessage object with arbitrary metadata, a source ID, and a timestamp.

    Two entries are written to the control message:

    - a timestamp set to ``datetime.now()``, under ``annotation::<message>`` when a truthy
      ``message`` kwarg is given, or under ``annotation::<uuid4>`` otherwise;
    - a metadata entry under a fresh ``annotation::<uuid4>`` key whose value is
      ``{"source_id": source_id, **kwargs}``.

    A failure to write either entry is reported with ``print`` and otherwise ignored.

    Parameters:
    - control_message: The IngestControlMessage object to be annotated.
    - source_id: An identifier for the source of the annotation. If None, the ``__name__`` of the
      module two frames up the call stack is used. That is the caller of whoever called this
      function, e.g. the caller of ``annotate_task_result``. "UnknownModule" is used when that
      frame does not exist or its module cannot be determined.
    - **kwargs: Arbitrary key-value pairs to be included in the annotation. The key
      'annotation_timestamp' is reserved.

    Raises:
    - ValueError: If 'annotation_timestamp' is supplied in kwargs.
    """
    # Validate before doing any stack inspection.
    if "annotation_timestamp" in kwargs:
        raise ValueError("'annotation_timestamp' is a reserved key and cannot be specified.")

    if source_id is None:
        # Walk two frames up (this frame -> caller -> caller's caller) directly via f_back.
        # This avoids building the full outer-frame list (which loads source context for every
        # frame) and avoids an IndexError when the stack is shallower than three frames.
        frame = inspect.currentframe()
        target_frame = None
        try:
            if frame is not None and frame.f_back is not None:
                target_frame = frame.f_back.f_back
            module = inspect.getmodule(target_frame) if target_frame is not None else None
        finally:
            # Drop frame references promptly to avoid keeping frames alive via reference cycles.
            del frame, target_frame
        source_id = module.__name__ if module is not None else "UnknownModule"

    message = kwargs.get("message")
    annotation_key = f"annotation::{message}" if message else f"annotation::{uuid.uuid4()}"

    annotation_timestamp = datetime.now()
    try:
        control_message.set_timestamp(annotation_key, annotation_timestamp)
    except Exception as e:
        print(f"Failed to set annotation timestamp: {e}")

    # Construct the metadata key uniquely identified by a UUID.
    metadata_key = f"annotation::{uuid.uuid4()}"

    # The metadata value holds the source_id plus any provided kwargs; the timestamp is stored
    # separately via set_timestamp above.
    metadata_value = {"source_id": source_id, **kwargs}

    try:
        # Attempt to set the annotated metadata on the IngestControlMessage object.
        control_message.set_metadata(metadata_key, metadata_value)
    except Exception as e:
        # Best-effort: report and continue rather than failing the caller.
        print(f"Failed to annotate IngestControlMessage: {e}")
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def annotate_task_result(control_message, result, task_id, source_id=None, **kwargs):
    """
    Record a task outcome on an IngestControlMessage via ``annotate_cm``.

    The annotation carries ``task_result`` (the status value string) and ``task_id``, plus any
    extra key-value pairs.

    Parameters:
    - control_message: The IngestControlMessage object to be annotated.
    - result: A TaskResultStatus member, or its name as a string in any letter case
      (e.g. "success").
    - task_id: A unique identifier for the task.
    - source_id: Optional source identifier. When None, ``annotate_cm`` derives it from the
      module that called this function.
    - **kwargs: Arbitrary additional key-value pairs to be included in the annotation.

    Raises:
    - ValueError: If ``result`` is an unknown status name or neither a string nor a
      TaskResultStatus.
    """
    if isinstance(result, TaskResultStatus):
        status = result
    elif isinstance(result, str):
        try:
            status = TaskResultStatus[result.upper()]
        except KeyError:
            valid_names = [member.name for member in TaskResultStatus]
            raise ValueError(f"Invalid result string: {result}. Must be one of {valid_names}.")
    else:
        raise ValueError("result must be an instance of TaskResultStatus Enum or a valid result string.")

    # Call annotate_cm directly (no intermediate helper) so its stack-depth-based source_id
    # lookup still resolves to this function's caller.
    annotate_cm(
        control_message,
        source_id=source_id,
        task_result=status.value,
        task_id=task_id,
        **kwargs,
    )
|