nv-ingest-api 26.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nv-ingest-api might be problematic. Click here for more details.
- nv_ingest_api/__init__.py +3 -0
- nv_ingest_api/interface/__init__.py +218 -0
- nv_ingest_api/interface/extract.py +977 -0
- nv_ingest_api/interface/mutate.py +154 -0
- nv_ingest_api/interface/store.py +200 -0
- nv_ingest_api/interface/transform.py +382 -0
- nv_ingest_api/interface/utility.py +186 -0
- nv_ingest_api/internal/__init__.py +0 -0
- nv_ingest_api/internal/enums/__init__.py +3 -0
- nv_ingest_api/internal/enums/common.py +550 -0
- nv_ingest_api/internal/extract/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
- nv_ingest_api/internal/extract/docx/__init__.py +5 -0
- nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
- nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
- nv_ingest_api/internal/extract/html/__init__.py +3 -0
- nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
- nv_ingest_api/internal/extract/image/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
- nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
- nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
- nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
- nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
- nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
- nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
- nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
- nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
- nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
- nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
- nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
- nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
- nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
- nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
- nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
- nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
- nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
- nv_ingest_api/internal/meta/__init__.py +3 -0
- nv_ingest_api/internal/meta/udf.py +232 -0
- nv_ingest_api/internal/mutate/__init__.py +3 -0
- nv_ingest_api/internal/mutate/deduplicate.py +110 -0
- nv_ingest_api/internal/mutate/filter.py +133 -0
- nv_ingest_api/internal/primitives/__init__.py +0 -0
- nv_ingest_api/internal/primitives/control_message_task.py +16 -0
- nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
- nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
- nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
- nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
- nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
- nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
- nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
- nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
- nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
- nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
- nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
- nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
- nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
- nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
- nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
- nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
- nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
- nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
- nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
- nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
- nv_ingest_api/internal/schemas/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
- nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
- nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
- nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
- nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
- nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
- nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
- nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
- nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
- nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
- nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
- nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
- nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
- nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
- nv_ingest_api/internal/schemas/meta/udf.py +23 -0
- nv_ingest_api/internal/schemas/mixins.py +39 -0
- nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
- nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
- nv_ingest_api/internal/schemas/store/__init__.py +3 -0
- nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
- nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
- nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
- nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
- nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
- nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
- nv_ingest_api/internal/store/__init__.py +3 -0
- nv_ingest_api/internal/store/embed_text_upload.py +236 -0
- nv_ingest_api/internal/store/image_upload.py +251 -0
- nv_ingest_api/internal/transform/__init__.py +3 -0
- nv_ingest_api/internal/transform/caption_image.py +219 -0
- nv_ingest_api/internal/transform/embed_text.py +702 -0
- nv_ingest_api/internal/transform/split_text.py +182 -0
- nv_ingest_api/util/__init__.py +3 -0
- nv_ingest_api/util/control_message/__init__.py +0 -0
- nv_ingest_api/util/control_message/validators.py +47 -0
- nv_ingest_api/util/converters/__init__.py +0 -0
- nv_ingest_api/util/converters/bytetools.py +78 -0
- nv_ingest_api/util/converters/containers.py +65 -0
- nv_ingest_api/util/converters/datetools.py +90 -0
- nv_ingest_api/util/converters/dftools.py +127 -0
- nv_ingest_api/util/converters/formats.py +64 -0
- nv_ingest_api/util/converters/type_mappings.py +27 -0
- nv_ingest_api/util/dataloader/__init__.py +9 -0
- nv_ingest_api/util/dataloader/dataloader.py +409 -0
- nv_ingest_api/util/detectors/__init__.py +5 -0
- nv_ingest_api/util/detectors/language.py +38 -0
- nv_ingest_api/util/exception_handlers/__init__.py +0 -0
- nv_ingest_api/util/exception_handlers/converters.py +72 -0
- nv_ingest_api/util/exception_handlers/decorators.py +429 -0
- nv_ingest_api/util/exception_handlers/detectors.py +74 -0
- nv_ingest_api/util/exception_handlers/pdf.py +116 -0
- nv_ingest_api/util/exception_handlers/schemas.py +68 -0
- nv_ingest_api/util/image_processing/__init__.py +5 -0
- nv_ingest_api/util/image_processing/clustering.py +260 -0
- nv_ingest_api/util/image_processing/processing.py +177 -0
- nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
- nv_ingest_api/util/image_processing/transforms.py +850 -0
- nv_ingest_api/util/imports/__init__.py +3 -0
- nv_ingest_api/util/imports/callable_signatures.py +108 -0
- nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
- nv_ingest_api/util/introspection/__init__.py +3 -0
- nv_ingest_api/util/introspection/class_inspect.py +145 -0
- nv_ingest_api/util/introspection/function_inspect.py +65 -0
- nv_ingest_api/util/logging/__init__.py +0 -0
- nv_ingest_api/util/logging/configuration.py +102 -0
- nv_ingest_api/util/logging/sanitize.py +84 -0
- nv_ingest_api/util/message_brokers/__init__.py +3 -0
- nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
- nv_ingest_api/util/metadata/__init__.py +5 -0
- nv_ingest_api/util/metadata/aggregators.py +516 -0
- nv_ingest_api/util/multi_processing/__init__.py +8 -0
- nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
- nv_ingest_api/util/nim/__init__.py +161 -0
- nv_ingest_api/util/pdf/__init__.py +3 -0
- nv_ingest_api/util/pdf/pdfium.py +428 -0
- nv_ingest_api/util/schema/__init__.py +3 -0
- nv_ingest_api/util/schema/schema_validator.py +10 -0
- nv_ingest_api/util/service_clients/__init__.py +3 -0
- nv_ingest_api/util/service_clients/client_base.py +86 -0
- nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
- nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
- nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
- nv_ingest_api/util/string_processing/__init__.py +51 -0
- nv_ingest_api/util/string_processing/configuration.py +682 -0
- nv_ingest_api/util/string_processing/yaml.py +109 -0
- nv_ingest_api/util/system/__init__.py +0 -0
- nv_ingest_api/util/system/hardware_info.py +594 -0
- nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
- nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
- nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
- nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
- nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
- udfs/__init__.py +5 -0
- udfs/llm_summarizer_udf.py +259 -0
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from typing import Dict, Any, Optional, List
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
from nv_ingest_api.internal.primitives.nim import ModelInterface
|
|
11
|
+
from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DeplotModelInterface(ModelInterface):
    """
    An interface for handling inference with a Deplot model, supporting both gRPC and HTTP protocols,
    now updated to handle multiple base64 images ('base64_images').

    Expected flow: ``prepare_data_for_inference`` decodes the base64 input(s) into NumPy arrays,
    ``format_input`` splits them into protocol-specific batches, and ``parse_output`` converts a raw
    response back into text.
    """

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface ("Deplot").
        """
        return "Deplot"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Prepare input data by decoding one or more base64-encoded images into NumPy arrays.

        The decoded arrays are stored in ``data["image_arrays"]``. The dictionary is modified in
        place, and the same object is returned.

        Parameters
        ----------
        data : dict
            The input data containing either 'base64_image' (single image)
            or 'base64_images' (multiple images). 'base64_images' takes precedence when both exist.

        Returns
        -------
        dict
            The updated data dictionary with 'image_arrays': a list of decoded NumPy arrays.

        Raises
        ------
        ValueError
            If 'base64_images' is present but is not a list.
        KeyError
            If neither 'base64_image' nor 'base64_images' is present.
        """
        # Handle a single base64_image or multiple base64_images
        if "base64_images" in data:
            base64_list = data["base64_images"]
            if not isinstance(base64_list, list):
                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
            image_arrays = [base64_to_numpy(b64) for b64 in base64_list]

        elif "base64_image" in data:
            # Fallback for single image
            image_arrays = [base64_to_numpy(data["base64_image"])]
        else:
            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")

        data["image_arrays"] = image_arrays

        return data

    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
        """
        Format input data for the specified protocol (gRPC or HTTP) for Deplot.
        For HTTP, we now construct multiple messages—one per image batch—along with
        corresponding batch data carrying the original image arrays and their dimensions.

        Parameters
        ----------
        data : dict of str -> Any
            The input data dictionary, expected to contain "image_arrays" (a list of np.ndarray).
            For HTTP, the original 'base64_images' / 'base64_image' entries are also read.
        protocol : str
            The protocol to use, "grpc" or "http".
        max_batch_size : int
            The maximum number of images per batch. Must be >= 1.
        kwargs : dict
            Additional parameters for the HTTP payload: ``max_tokens`` (default 500),
            ``temperature`` (default 0.5) and ``top_p`` (default 0.9).

        Returns
        -------
        tuple
            (formatted_batches, formatted_batch_data) where:
              - For gRPC: formatted_batches is a list of float32 NumPy arrays, each of shape (B, H, W, C)
                with B <= max_batch_size and pixel values divided by 255.
              - For HTTP: formatted_batches is a list of JSON-serializable payload dicts.
              - In both cases, formatted_batch_data is a list of dicts containing:
                  "image_arrays": the list of original np.ndarray images for that batch, and
                  "image_dims": a list of (height, width) tuples for each image in the batch.

        Raises
        ------
        KeyError
            If "image_arrays" is missing in the data dictionary.
        ValueError
            If the protocol is invalid, if max_batch_size is not positive, or if no valid images
            are found (gRPC).

        Notes
        -----
        For gRPC, the images in one batch are joined with ``np.concatenate`` along axis 0. They must
        therefore share the same height, width and channel count.
        """
        if "image_arrays" not in data:
            raise KeyError("Expected 'image_arrays' in data. Call prepare_data_for_inference first.")

        # range() with a zero step raises an opaque error. A negative step yields no chunks at all,
        # which would silently drop every image, so reject non-positive sizes up front.
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be a positive integer, got {max_batch_size!r}.")

        image_arrays = data["image_arrays"]
        # Compute image dimensions from each image array.
        image_dims = [(img.shape[0], img.shape[1]) for img in image_arrays]

        # Helper function: chunk a list into sublists of length <= chunk_size.
        def chunk_list(lst: list, chunk_size: int) -> List[list]:
            return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

        if protocol == "grpc":
            logger.debug("Formatting input for gRPC Deplot model (potentially batched).")
            processed = []
            for arr in image_arrays:
                # Ensure each image has shape (1, H, W, C)
                if arr.ndim == 3:
                    arr = np.expand_dims(arr, axis=0)
                # astype() returns a copy, so the in-place division below never touches the caller's array.
                arr = arr.astype(np.float32)
                arr /= 255.0  # Normalize to [0,1], assuming uint8 input in [0, 255] (TODO confirm).
                processed.append(arr)

            if not processed:
                raise ValueError("No valid images found for gRPC formatting.")

            formatted_batches = []
            formatted_batch_data = []
            proc_chunks = chunk_list(processed, max_batch_size)
            orig_chunks = chunk_list(image_arrays, max_batch_size)
            dims_chunks = chunk_list(image_dims, max_batch_size)

            for proc_chunk, orig_chunk, dims_chunk in zip(proc_chunks, orig_chunks, dims_chunks):
                # Concatenate along the batch dimension to form a single input.
                batched_input = np.concatenate(proc_chunk, axis=0)
                formatted_batches.append(batched_input)
                formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
            return formatted_batches, formatted_batch_data

        elif protocol == "http":
            logger.debug("Formatting input for HTTP Deplot model (multiple messages).")
            # HTTP sends the original base64 strings rather than the decoded arrays.
            if "base64_images" in data:
                base64_list = data["base64_images"]
            else:
                base64_list = [data["base64_image"]]

            formatted_batches = []
            formatted_batch_data = []
            b64_chunks = chunk_list(base64_list, max_batch_size)
            orig_chunks = chunk_list(image_arrays, max_batch_size)
            dims_chunks = chunk_list(image_dims, max_batch_size)

            for b64_chunk, orig_chunk, dims_chunk in zip(b64_chunks, orig_chunks, dims_chunks):
                payload = self._prepare_deplot_payload(
                    base64_list=b64_chunk,
                    max_tokens=kwargs.get("max_tokens", 500),
                    temperature=kwargs.get("temperature", 0.5),
                    top_p=kwargs.get("top_p", 0.9),
                )
                formatted_batches.append(payload)
                formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
            return formatted_batches, formatted_batch_data

        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
        """
        Parse the model's inference response.

        Parameters
        ----------
        response : Any
            For gRPC, an iterable with one element per image. Each element is a sequence
            (list, tuple or non-scalar ndarray) of text chunks, or a single bytes/str value.
            For HTTP, the decoded JSON response dict.
        protocol : str
            "grpc" or "http".
        data : dict, optional
            Unused; accepted for interface compatibility.

        Returns
        -------
        list of str or str
            gRPC: a list of strings, one per batch element (multi-chunk elements are joined with
            single spaces). HTTP: the content string of the first choice only.

        Raises
        ------
        ValueError
            If the protocol is invalid.
        RuntimeError
            If an HTTP response has no 'choices'.
        """
        if protocol == "grpc":
            logger.debug("Parsing output from gRPC Deplot model (batched).")
            results = []
            for item in response:
                # Triton-style clients commonly return NumPy object arrays rather than lists.
                # Treat any non-scalar sequence of chunks the same way as a list.
                is_sequence = isinstance(item, (list, tuple)) or (isinstance(item, np.ndarray) and item.ndim > 0)
                if is_sequence:
                    joined_str = " ".join(self._to_text(o) for o in item)
                    results.append(joined_str)
                else:
                    # single bytes or str
                    results.append(self._to_text(item))
            return results  # Return a list of strings, one per image.

        elif protocol == "http":
            logger.debug("Parsing output from HTTP Deplot model.")
            return self._extract_content_from_deplot_response(response)
        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
        """
        Process inference results for the Deplot model.

        Parameters
        ----------
        output : Any
            The raw output from the model.
        protocol : str
            The protocol used for inference (gRPC or HTTP).

        Returns
        -------
        Any
            The processed inference results. Currently the output is returned unchanged.
        """

        # For Deplot, the output is the chart content as a string
        return output

    @staticmethod
    def _to_text(value: Any) -> str:
        """Decode a bytes-like value as UTF-8, or fall back to ``str()`` for anything else."""
        # np.bytes_ is a subclass of bytes, so it is covered here as well.
        if isinstance(value, (bytes, bytearray)):
            return bytes(value).decode("utf-8")
        return str(value)

    @staticmethod
    def _prepare_deplot_payload(
        base64_list: list,
        max_tokens: int = 500,
        temperature: float = 0.5,
        top_p: float = 0.9,
    ) -> Dict[str, Any]:
        """
        Prepare an HTTP chat-style payload for Deplot with one user message per image:

            messages = [
                {
                    "role": "user",
                    "content": "Generate ... <img src=\"data:image/png;base64,...\" />"
                },
                ...
            ]

        Note: the Deplot NIM is noted to support only a single message per request, and only the
        first choice of a response is parsed. Passing more than one image per payload (that is,
        ``max_batch_size > 1`` over HTTP) is therefore likely to lose results. TODO confirm against
        the NIM API.

        Parameters
        ----------
        base64_list : list of str
            Base64-encoded images, embedded as PNG data URIs.
        max_tokens, temperature, top_p
            Generation parameters copied verbatim into the payload.

        Returns
        -------
        dict
            JSON-serializable payload targeting model "google/deplot" with ``stream`` disabled.
        """
        messages = []
        # Note: deplot NIM currently only supports a single message per request
        for b64_img in base64_list:
            messages.append(
                {
                    "role": "user",
                    "content": (
                        "Generate the underlying data table of the figure below: "
                        f'<img src="data:image/png;base64,{b64_img}" />'
                    ),
                }
            )

        payload = {
            "model": "google/deplot",
            "messages": messages,  # multiple user messages now
            "max_tokens": max_tokens,
            "stream": False,
            "temperature": temperature,
            "top_p": top_p,
        }
        return payload

    @staticmethod
    def _extract_content_from_deplot_response(json_response: Dict[str, Any]) -> Any:
        """
        Extract content from the JSON response of a Deplot HTTP API request.

        Only ``choices[0]["message"]["content"]`` is returned. Any further choices are ignored.

        Raises
        ------
        RuntimeError
            If 'choices' is missing or empty.
        KeyError
            If the first choice lacks 'message' or 'content'.
        """
        if "choices" not in json_response or not json_response["choices"]:
            raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")

        # If the service only returns one textual result, we return that one.
        return json_response["choices"][0]["message"]["content"]
|
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import logging
from typing import Any, Dict, Optional, Tuple

import backoff
import cv2
import numpy as np
import requests

from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
from nv_ingest_api.util.image_processing.transforms import pad_image, normalize_image
from nv_ingest_api.util.string_processing import generate_url, remove_url_endpoints
|
|
16
|
+
|
|
17
|
+
cv2.setNumThreads(1)
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def preprocess_image_for_paddle(
    array: np.ndarray, image_max_dimension: int = 960
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """
    Preprocess an input image for PaddleOCR by resizing, normalizing, padding and transposing it
    into the required format.

    This function is intended for preprocessing images to be passed as input to PaddleOCR using GRPC.
    It is not necessary when using the HTTP endpoint.

    Steps:
    -----
    1. Resizes the image, preserving aspect ratio, so that its largest dimension equals
       ``image_max_dimension`` pixels (960 by default). Each side is clamped to at least 1 pixel.
    2. Normalizes the image using the `normalize_image` function.
    3. Pads the image (background value 0, dtype float32) so that both its height and width are
       multiples of 32, as required by the PaddleOCR NIM.
    4. Transposes the image from (height, width, channel) to (channel, height, width).

    Parameters:
    ----------
    array : np.ndarray
        The input image array of shape (height, width, channels). It should have pixel values in the range [0, 255].
    image_max_dimension : int, optional
        Target size, in pixels, of the image's largest side after resizing. Defaults to 960.

    Returns:
    -------
    tuple of (np.ndarray, dict)
        - The preprocessed image with shape (channels, padded_height, padded_width).
        - A metadata dict with keys "original_height", "original_width", "scale_factor",
          "new_height", "new_width", "pad_height" and "pad_width". It can be used to invert the
          transformations on resulting bounding boxes.

    Raises:
    ------
    ValueError
        If the input image has a zero-sized height or width.

    Notes:
    -----
    - The output value range depends on `normalize_image`; see that function for details.
    """
    height, width = array.shape[:2]
    # Without this guard, an empty image fails deeper down with ZeroDivisionError or a cv2 assertion.
    if height == 0 or width == 0:
        raise ValueError(f"Cannot preprocess an image with a zero-sized dimension (shape={array.shape}).")

    scale_factor = image_max_dimension / max(height, width)
    # Clamp to at least one pixel. A very elongated image (e.g. 1x5000) would otherwise round its
    # short side down to 0, and cv2.resize rejects an empty destination size.
    resized_height = max(1, int(height * scale_factor))
    resized_width = max(1, int(width * scale_factor))
    resized = cv2.resize(array, (resized_width, resized_height))

    normalized = normalize_image(resized)

    # PaddleOCR NIM (GRPC) requires input shapes to be multiples of 32.
    padded_height = (normalized.shape[0] + 31) // 32 * 32
    padded_width = (normalized.shape[1] + 31) // 32 * 32
    padded, (pad_width, pad_height) = pad_image(
        normalized, target_height=padded_height, target_width=padded_width, background_color=0, dtype=np.float32
    )

    # PaddleOCR NIM (GRPC) requires input to be (channel, height, width).
    transposed = padded.transpose((2, 0, 1))

    # Metadata can be used for inverting transformations on the resulting bounding boxes.
    metadata = {
        "original_height": height,
        "original_width": width,
        "scale_factor": scale_factor,
        "new_height": transposed.shape[1],
        "new_width": transposed.shape[2],
        "pad_height": pad_height,
        "pad_width": pad_width,
    }

    return transposed, metadata
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def preprocess_image_for_ocr(
    array: np.ndarray,
    target_height: Optional[int] = None,
    target_width: Optional[int] = None,
    pad_how: str = "bottom_right",
    normalize: bool = False,
    channel_first: bool = False,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """
    Preprocess an input image to be suitable for use with NemoRetriever-OCR.

    This function is intended for preprocessing images to be passed as input to NemoRetriever-OCR using GRPC.
    It is not necessary when using the HTTP endpoint.

    Parameters:
    ----------
    array : np.ndarray
        The input image array of shape (height, width, channels). It should have pixel values in the range [0, 255].
    target_height : int, optional
        Height to pad to. Defaults to the input height.
    target_width : int, optional
        Width to pad to. Defaults to the input width.
    pad_how : str, optional
        Padding placement, passed through to `pad_image` as ``how``. Defaults to "bottom_right".
    normalize : bool, optional
        If True, divide pixel values by 255.0 after padding. Defaults to False.
    channel_first : bool, optional
        If True, transpose the result to (channels, height, width). Defaults to False.

    Returns:
    -------
    tuple of (np.ndarray, dict)
        - The padded image, as produced by `pad_image` with dtype float32 and background value 255.
          Its layout is (height, width, channels), or (channels, height, width) when ``channel_first``
          is True.
        - A metadata dict with keys "original_height", "original_width", "new_height", "new_width",
          "pad_height" and "pad_width". "new_height"/"new_width" record the requested target size.
          NOTE(review): if `pad_image` does not always produce exactly that size (e.g. for inputs
          larger than the target), these may differ from the array's shape. TODO confirm.
    """
    height, width = array.shape[:2]

    if target_height is None:
        target_height = height

    if target_width is None:
        target_width = width

    padded, (pad_width, pad_height) = pad_image(
        array,
        target_height=target_height,
        target_width=target_width,
        background_color=255,
        dtype=np.float32,
        how=pad_how,
    )

    if normalize:
        padded = padded / 255.0

    if channel_first:
        # NemoRetriever-OCR NIM (GRPC) requires input to be (channel, height, width).
        padded = padded.transpose((2, 0, 1))

    # Metadata can be used for inverting transformations on the resulting bounding boxes.
    metadata = {
        "original_height": height,
        "original_width": width,
        "new_height": target_height,
        "new_width": target_width,
        "pad_height": pad_height,
        "pad_width": pad_width,
    }

    return padded, metadata
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def is_ready(http_endpoint: str, ready_endpoint: str) -> bool:
    """
    Check if the server at the given endpoint is ready.

    Parameters
    ----------
    http_endpoint : str
        The HTTP endpoint of the server. ``None`` or an empty string means the service was not
        configured, and it is treated as ready.
    ready_endpoint : str
        The specific ready-check endpoint. It is appended to the endpoint's base URL, and a "/"
        separator is inserted when neither side provides one.

    Returns
    -------
    bool
        True if the server answered 200, or if it is unconfigured or hosted on ai.api.nvidia.com.
        False on a 503, on any other status, or when the HTTP request fails. Request failures are
        logged as warnings rather than raised.
    """

    # If the url is empty or None that means the service was not configured
    # and is therefore automatically marked as "ready"
    if http_endpoint is None or http_endpoint == "":
        return True

    # If the url is for build.nvidia.com, it is automatically assumed "ready"
    if "ai.api.nvidia.com" in http_endpoint:
        return True

    url = generate_url(http_endpoint)
    url = remove_url_endpoints(url)

    if not ready_endpoint.startswith("/") and not url.endswith("/"):
        ready_endpoint = "/" + ready_endpoint

    url = url + ready_endpoint

    # Call the ready endpoint of the NIM
    try:
        # Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable
        resp = requests.get(url, timeout=5)
        if resp.status_code == 200:
            # The NIM is saying it is ready to serve
            return True
        elif resp.status_code == 503:
            # NIM is explicitly saying it is not ready.
            return False
        else:
            # Any other code is confusing. We should log it with a warning
            # as it could be something that might hold up ready state.
            # Use resp.text rather than resp.json(): error bodies are often not JSON, and a decode
            # failure here would hide the status code behind a generic error message.
            logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.text}")
            return False
    except requests.HTTPError as http_err:
        logger.warning(f"'{url}' produced a HTTP error: {http_err}")
        return False
    except requests.Timeout:
        logger.warning(f"'{url}' request timed out")
        return False
    except requests.ConnectionError:
        # requests' ConnectionError is not a subclass of the builtin ConnectionError, so it must be
        # named explicitly for this branch to be reachable.
        logger.warning(f"A connection error for '{url}' occurred")
        return False
    except requests.RequestException as err:
        logger.warning(f"An error occurred: {err} for '{url}'")
        return False
    except Exception as ex:
        # Don't let anything squeeze by
        logger.warning(f"Exception: {ex}")
        return False
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _query_metadata(
|
|
215
|
+
http_endpoint: str,
|
|
216
|
+
field_name: str,
|
|
217
|
+
default_value: str,
|
|
218
|
+
retry_value: str = "",
|
|
219
|
+
metadata_endpoint: str = "/v1/metadata",
|
|
220
|
+
) -> str:
|
|
221
|
+
if (http_endpoint is None) or (http_endpoint == ""):
|
|
222
|
+
return default_value
|
|
223
|
+
|
|
224
|
+
url = generate_url(http_endpoint)
|
|
225
|
+
url = remove_url_endpoints(url)
|
|
226
|
+
|
|
227
|
+
if not metadata_endpoint.startswith("/") and not url.endswith("/"):
|
|
228
|
+
metadata_endpoint = "/" + metadata_endpoint
|
|
229
|
+
|
|
230
|
+
url = url + metadata_endpoint
|
|
231
|
+
|
|
232
|
+
# Call the metadata endpoint of the NIM
|
|
233
|
+
try:
|
|
234
|
+
# Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable
|
|
235
|
+
resp = requests.get(url, timeout=5)
|
|
236
|
+
if resp.status_code == 200:
|
|
237
|
+
field_value = resp.json().get(field_name, "")
|
|
238
|
+
if field_value:
|
|
239
|
+
return field_value
|
|
240
|
+
else:
|
|
241
|
+
# If the field is empty, retry
|
|
242
|
+
logger.warning(f"No {field_name} field in response from '{url}'. Retrying.")
|
|
243
|
+
return retry_value
|
|
244
|
+
else:
|
|
245
|
+
# Any other code is confusing. We should log it with a warning
|
|
246
|
+
logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.text}")
|
|
247
|
+
return retry_value
|
|
248
|
+
except requests.HTTPError as http_err:
|
|
249
|
+
logger.warning(f"'{url}' produced a HTTP error: {http_err}")
|
|
250
|
+
return retry_value
|
|
251
|
+
except requests.Timeout:
|
|
252
|
+
logger.warning(f"'{url}' request timed out")
|
|
253
|
+
return retry_value
|
|
254
|
+
except ConnectionError:
|
|
255
|
+
logger.warning(f"A connection error for '{url}' occurred")
|
|
256
|
+
return retry_value
|
|
257
|
+
except requests.RequestException as err:
|
|
258
|
+
logger.warning(f"An error occurred: {err} for '{url}'")
|
|
259
|
+
return retry_value
|
|
260
|
+
except Exception as ex:
|
|
261
|
+
# Don't let anything squeeze by
|
|
262
|
+
logger.warning(f"Exception: {ex}")
|
|
263
|
+
return retry_value
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
@multiprocessing_cache(max_calls=100)  # Cache results first to avoid redundant retries from backoff
@backoff.on_predicate(backoff.expo, max_time=30)
def get_version(http_endpoint: str, metadata_endpoint: str = "/v1/metadata", version_field: str = "version") -> str:
    """
    Get the version of the server from its metadata endpoint.

    Parameters
    ----------
    http_endpoint : str
        The HTTP endpoint of the server. ``None`` or empty yields the default version.
    metadata_endpoint : str, optional
        The metadata endpoint to query (default: "/v1/metadata").
    version_field : str, optional
        The field containing the version in the response (default: "version").

    Returns
    -------
    str
        The version of the server, "1.0.0" for hosted or unconfigured endpoints, or an
        empty string if the metadata stays unavailable after backoff retries.
    """
    default_version = "1.0.0"

    # Guard before the substring checks below, which would raise TypeError on None.
    if not http_endpoint:
        return default_version

    # TODO: Need a way to match NIM version to API versions.
    if "ai.api.nvidia.com" in http_endpoint or "api.nvcf.nvidia.com" in http_endpoint:
        return default_version

    # A falsy result ("" from _query_metadata on failure) makes backoff retry.
    return _query_metadata(
        http_endpoint,
        field_name=version_field,
        default_value=default_version,
        metadata_endpoint=metadata_endpoint,
    )
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
@multiprocessing_cache(max_calls=100)  # Cache results first to avoid redundant retries from backoff
@backoff.on_predicate(backoff.expo, max_time=30)
def get_model_name(
    http_endpoint: str,
    default_model_name,
    metadata_endpoint: str = "/v1/metadata",
    model_info_field: str = "modelInfo",
) -> str:
    """
    Get the model name of the server from its metadata endpoint.

    Parameters
    ----------
    http_endpoint : str
        The HTTP endpoint of the server. ``None`` or empty yields ``default_model_name``.
    default_model_name : str
        Name returned for NVCF-hosted or unconfigured endpoints, and used when the
        metadata entry has no "shortName".
    metadata_endpoint : str, optional
        The metadata endpoint to query (default: "/v1/metadata").
    model_info_field : str, optional
        The field containing the model info in the response (default: "modelInfo").

    Returns
    -------
    str
        The model name (the "shortName" with any ":tag" suffix removed), or an empty
        string if the metadata stays unavailable after backoff retries.
    """
    # Guard before the substring checks below, which would raise TypeError on None.
    if not http_endpoint:
        return default_model_name

    if "ai.api.nvidia.com" in http_endpoint:
        # The model name is the last path segment, ignoring an optional "/chat/completions".
        # (str.strip would remove a *character set*, corrupting names such as "...-instruct".)
        path = http_endpoint.rstrip("/")
        suffix = "/chat/completions"
        if path.endswith(suffix):
            path = path[: -len(suffix)]
        return path.rstrip("/").split("/")[-1]

    if "api.nvcf.nvidia.com" in http_endpoint:
        return default_model_name

    model_info = _query_metadata(
        http_endpoint,
        field_name=model_info_field,
        default_value=[{"shortName": default_model_name}],
        metadata_endpoint=metadata_endpoint,
    )

    if not model_info:
        # Metadata unavailable: return a falsy value so backoff.on_predicate retries
        # instead of crashing on model_info[0].
        return ""

    # NOTE(review): the field is expected to be a list of model entries; a single dict is
    # also accepted defensively — confirm against the NIM metadata schema.
    entry = model_info[0] if isinstance(model_info, (list, tuple)) else model_info
    short_name = entry.get("shortName", default_model_name) if isinstance(entry, dict) else default_model_name
    model_name = short_name.split(":")[0]

    return model_name
|