nv-ingest-api 26.1.0rc4 (nv_ingest_api-26.1.0rc4-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nv-ingest-api might be problematic.

Files changed (177)
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +218 -0
  3. nv_ingest_api/interface/extract.py +977 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +200 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +186 -0
  8. nv_ingest_api/internal/__init__.py +0 -0
  9. nv_ingest_api/internal/enums/__init__.py +3 -0
  10. nv_ingest_api/internal/enums/common.py +550 -0
  11. nv_ingest_api/internal/extract/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  13. nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
  14. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  15. nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
  16. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
  19. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
  20. nv_ingest_api/internal/extract/html/__init__.py +3 -0
  21. nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
  22. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
  24. nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
  25. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  26. nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
  27. nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
  28. nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
  29. nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
  30. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  31. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  32. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  33. nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
  34. nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
  35. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
  36. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
  37. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  38. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  39. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  40. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  41. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  42. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
  43. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
  44. nv_ingest_api/internal/meta/__init__.py +3 -0
  45. nv_ingest_api/internal/meta/udf.py +232 -0
  46. nv_ingest_api/internal/mutate/__init__.py +3 -0
  47. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  48. nv_ingest_api/internal/mutate/filter.py +133 -0
  49. nv_ingest_api/internal/primitives/__init__.py +0 -0
  50. nv_ingest_api/internal/primitives/control_message_task.py +16 -0
  51. nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
  52. nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
  53. nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
  59. nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
  60. nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
  61. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  62. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
  63. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
  64. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
  65. nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
  66. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
  67. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  68. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  69. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  70. nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
  71. nv_ingest_api/internal/schemas/__init__.py +3 -0
  72. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  73. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
  74. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
  75. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
  76. nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
  77. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
  78. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
  79. nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
  80. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
  81. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
  82. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
  83. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
  85. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  86. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  87. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  88. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  89. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
  90. nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
  91. nv_ingest_api/internal/schemas/meta/udf.py +23 -0
  92. nv_ingest_api/internal/schemas/mixins.py +39 -0
  93. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  94. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  95. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  96. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  97. nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
  98. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  99. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
  100. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  101. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
  102. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
  103. nv_ingest_api/internal/store/__init__.py +3 -0
  104. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  105. nv_ingest_api/internal/store/image_upload.py +251 -0
  106. nv_ingest_api/internal/transform/__init__.py +3 -0
  107. nv_ingest_api/internal/transform/caption_image.py +219 -0
  108. nv_ingest_api/internal/transform/embed_text.py +702 -0
  109. nv_ingest_api/internal/transform/split_text.py +182 -0
  110. nv_ingest_api/util/__init__.py +3 -0
  111. nv_ingest_api/util/control_message/__init__.py +0 -0
  112. nv_ingest_api/util/control_message/validators.py +47 -0
  113. nv_ingest_api/util/converters/__init__.py +0 -0
  114. nv_ingest_api/util/converters/bytetools.py +78 -0
  115. nv_ingest_api/util/converters/containers.py +65 -0
  116. nv_ingest_api/util/converters/datetools.py +90 -0
  117. nv_ingest_api/util/converters/dftools.py +127 -0
  118. nv_ingest_api/util/converters/formats.py +64 -0
  119. nv_ingest_api/util/converters/type_mappings.py +27 -0
  120. nv_ingest_api/util/dataloader/__init__.py +9 -0
  121. nv_ingest_api/util/dataloader/dataloader.py +409 -0
  122. nv_ingest_api/util/detectors/__init__.py +5 -0
  123. nv_ingest_api/util/detectors/language.py +38 -0
  124. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  125. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  126. nv_ingest_api/util/exception_handlers/decorators.py +429 -0
  127. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  128. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  129. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  130. nv_ingest_api/util/image_processing/__init__.py +5 -0
  131. nv_ingest_api/util/image_processing/clustering.py +260 -0
  132. nv_ingest_api/util/image_processing/processing.py +177 -0
  133. nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
  134. nv_ingest_api/util/image_processing/transforms.py +850 -0
  135. nv_ingest_api/util/imports/__init__.py +3 -0
  136. nv_ingest_api/util/imports/callable_signatures.py +108 -0
  137. nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
  138. nv_ingest_api/util/introspection/__init__.py +3 -0
  139. nv_ingest_api/util/introspection/class_inspect.py +145 -0
  140. nv_ingest_api/util/introspection/function_inspect.py +65 -0
  141. nv_ingest_api/util/logging/__init__.py +0 -0
  142. nv_ingest_api/util/logging/configuration.py +102 -0
  143. nv_ingest_api/util/logging/sanitize.py +84 -0
  144. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  145. nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
  146. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  147. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  148. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  149. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
  150. nv_ingest_api/util/metadata/__init__.py +5 -0
  151. nv_ingest_api/util/metadata/aggregators.py +516 -0
  152. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  153. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
  154. nv_ingest_api/util/nim/__init__.py +161 -0
  155. nv_ingest_api/util/pdf/__init__.py +3 -0
  156. nv_ingest_api/util/pdf/pdfium.py +428 -0
  157. nv_ingest_api/util/schema/__init__.py +3 -0
  158. nv_ingest_api/util/schema/schema_validator.py +10 -0
  159. nv_ingest_api/util/service_clients/__init__.py +3 -0
  160. nv_ingest_api/util/service_clients/client_base.py +86 -0
  161. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  162. nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
  163. nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
  164. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  165. nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
  166. nv_ingest_api/util/string_processing/__init__.py +51 -0
  167. nv_ingest_api/util/string_processing/configuration.py +682 -0
  168. nv_ingest_api/util/string_processing/yaml.py +109 -0
  169. nv_ingest_api/util/system/__init__.py +0 -0
  170. nv_ingest_api/util/system/hardware_info.py +594 -0
  171. nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
  172. nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
  173. nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
  174. nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
  175. nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
  176. udfs/__init__.py +5 -0
  177. udfs/llm_summarizer_udf.py +259 -0
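
The file list above can be reproduced locally, if desired. A minimal sketch using only the standard library is shown below; it assumes the wheel has already been downloaded (for example with "pip download nv-ingest-api==26.1.0rc4 --no-deps") and that the file name follows the usual wheel naming convention.

import zipfile

# Path to the locally downloaded wheel; adjust to wherever pip placed it.
WHEEL_PATH = "nv_ingest_api-26.1.0rc4-py3-none-any.whl"

# A wheel is a zip archive, so the standard zipfile module can enumerate its contents.
with zipfile.ZipFile(WHEEL_PATH) as wheel:
    for info in wheel.infolist():
        print(f"{info.file_size:>10}  {info.filename}")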
nv_ingest_api/internal/primitives/nim/model_interface/deplot.py
@@ -0,0 +1,270 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Dict, Any, Optional, List
+
+import numpy as np
+import logging
+
+from nv_ingest_api.internal.primitives.nim import ModelInterface
+from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
+
+logger = logging.getLogger(__name__)
+
+
+class DeplotModelInterface(ModelInterface):
+    """
+    An interface for handling inference with a Deplot model, supporting both gRPC and HTTP protocols,
+    now updated to handle multiple base64 images ('base64_images').
+    """
+
+    def name(self) -> str:
+        """
+        Get the name of the model interface.
+
+        Returns
+        -------
+        str
+            The name of the model interface ("Deplot").
+        """
+        return "Deplot"
+
+    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Prepare input data by decoding one or more base64-encoded images into NumPy arrays.
+
+        Parameters
+        ----------
+        data : dict
+            The input data containing either 'base64_image' (single image)
+            or 'base64_images' (multiple images).
+
+        Returns
+        -------
+        dict
+            The updated data dictionary with 'image_arrays': a list of decoded NumPy arrays.
+        """
+
+        # Handle a single base64_image or multiple base64_images
+        if "base64_images" in data:
+            base64_list = data["base64_images"]
+            if not isinstance(base64_list, list):
+                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
+            image_arrays = [base64_to_numpy(b64) for b64 in base64_list]
+
+        elif "base64_image" in data:
+            # Fallback for single image
+            image_arrays = [base64_to_numpy(data["base64_image"])]
+        else:
+            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")
+
+        data["image_arrays"] = image_arrays
+
+        return data
+
+    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
+        """
+        Format input data for the specified protocol (gRPC or HTTP) for Deplot.
+        For HTTP, we now construct multiple messages—one per image batch—along with
+        corresponding batch data carrying the original image arrays and their dimensions.
+
+        Parameters
+        ----------
+        data : dict of str -> Any
+            The input data dictionary, expected to contain "image_arrays" (a list of np.ndarray).
+        protocol : str
+            The protocol to use, "grpc" or "http".
+        max_batch_size : int
+            The maximum number of images per batch.
+        kwargs : dict
+            Additional parameters to pass to the payload preparation (for HTTP).
+
+        Returns
+        -------
+        tuple
+            (formatted_batches, formatted_batch_data) where:
+            - For gRPC: formatted_batches is a list of NumPy arrays, each of shape (B, H, W, C)
+              with B <= max_batch_size.
+            - For HTTP: formatted_batches is a list of JSON-serializable payload dicts.
+            - In both cases, formatted_batch_data is a list of dicts containing:
+              "image_arrays": the list of original np.ndarray images for that batch, and
+              "image_dims": a list of (height, width) tuples for each image in the batch.
+
+        Raises
+        ------
+        KeyError
+            If "image_arrays" is missing in the data dictionary.
+        ValueError
+            If the protocol is invalid, or if no valid images are found.
+        """
+        if "image_arrays" not in data:
+            raise KeyError("Expected 'image_arrays' in data. Call prepare_data_for_inference first.")
+
+        image_arrays = data["image_arrays"]
+        # Compute image dimensions from each image array.
+        image_dims = [(img.shape[0], img.shape[1]) for img in image_arrays]
+
+        # Helper function: chunk a list into sublists of length <= chunk_size.
+        def chunk_list(lst: list, chunk_size: int) -> List[list]:
+            return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+        if protocol == "grpc":
+            logger.debug("Formatting input for gRPC Deplot model (potentially batched).")
+            processed = []
+            for arr in image_arrays:
+                # Ensure each image has shape (1, H, W, C)
+                if arr.ndim == 3:
+                    arr = np.expand_dims(arr, axis=0)
+                arr = arr.astype(np.float32)
+                arr /= 255.0  # Normalize to [0,1]
+                processed.append(arr)
+
+            if not processed:
+                raise ValueError("No valid images found for gRPC formatting.")
+
+            formatted_batches = []
+            formatted_batch_data = []
+            proc_chunks = chunk_list(processed, max_batch_size)
+            orig_chunks = chunk_list(image_arrays, max_batch_size)
+            dims_chunks = chunk_list(image_dims, max_batch_size)
+
+            for proc_chunk, orig_chunk, dims_chunk in zip(proc_chunks, orig_chunks, dims_chunks):
+                # Concatenate along the batch dimension to form a single input.
+                batched_input = np.concatenate(proc_chunk, axis=0)
+                formatted_batches.append(batched_input)
+                formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
+            return formatted_batches, formatted_batch_data
+
+        elif protocol == "http":
+            logger.debug("Formatting input for HTTP Deplot model (multiple messages).")
+            if "base64_images" in data:
+                base64_list = data["base64_images"]
+            else:
+                base64_list = [data["base64_image"]]
+
+            formatted_batches = []
+            formatted_batch_data = []
+            b64_chunks = chunk_list(base64_list, max_batch_size)
+            orig_chunks = chunk_list(image_arrays, max_batch_size)
+            dims_chunks = chunk_list(image_dims, max_batch_size)
+
+            for b64_chunk, orig_chunk, dims_chunk in zip(b64_chunks, orig_chunks, dims_chunks):
+                payload = self._prepare_deplot_payload(
+                    base64_list=b64_chunk,
+                    max_tokens=kwargs.get("max_tokens", 500),
+                    temperature=kwargs.get("temperature", 0.5),
+                    top_p=kwargs.get("top_p", 0.9),
+                )
+                formatted_batches.append(payload)
+                formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
+            return formatted_batches, formatted_batch_data
+
+        else:
+            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+        """
+        Parse the model's inference response.
+        """
+        if protocol == "grpc":
+            logger.debug("Parsing output from gRPC Deplot model (batched).")
+            # Each batch element might be returned as a list of bytes. Combine or keep separate as needed.
+            results = []
+            for item in response:
+                # If item is [b'...'], decode and join
+                if isinstance(item, list):
+                    joined_str = " ".join(o.decode("utf-8") for o in item)
+                    results.append(joined_str)
+                else:
+                    # single bytes or str
+                    val = item.decode("utf-8") if isinstance(item, bytes) else str(item)
+                    results.append(val)
+            return results  # Return a list of strings, one per image.
+
+        elif protocol == "http":
+            logger.debug("Parsing output from HTTP Deplot model.")
+            return self._extract_content_from_deplot_response(response)
+        else:
+            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
+        """
+        Process inference results for the Deplot model.
+
+        Parameters
+        ----------
+        output : Any
+            The raw output from the model.
+        protocol : str
+            The protocol used for inference (gRPC or HTTP).
+
+        Returns
+        -------
+        Any
+            The processed inference results.
+        """
+
+        # For Deplot, the output is the chart content as a string
+        return output
+
+    @staticmethod
+    def _prepare_deplot_payload(
+        base64_list: list,
+        max_tokens: int = 500,
+        temperature: float = 0.5,
+        top_p: float = 0.9,
+    ) -> Dict[str, Any]:
+        """
+        Prepare an HTTP payload for Deplot that includes one message per image,
+        matching the original single-image style:
+
+        messages = [
+            {
+                "role": "user",
+                "content": "Generate ... <img src=\"data:image/png;base64,...\" />"
+            },
+            {
+                "role": "user",
+                "content": "Generate ... <img src=\"data:image/png;base64,...\" />"
+            },
+            ...
+        ]
+
+        If your backend expects multiple messages in a single request, this keeps
+        the same structure as the single-image code repeated N times.
+        """
+        messages = []
+        # Note: deplot NIM currently only supports a single message per request
+        for b64_img in base64_list:
+            messages.append(
+                {
+                    "role": "user",
+                    "content": (
+                        "Generate the underlying data table of the figure below: "
+                        f'<img src="data:image/png;base64,{b64_img}" />'
+                    ),
+                }
+            )
+
+        payload = {
+            "model": "google/deplot",
+            "messages": messages,  # multiple user messages now
+            "max_tokens": max_tokens,
+            "stream": False,
+            "temperature": temperature,
+            "top_p": top_p,
+        }
+        return payload
+
+    @staticmethod
+    def _extract_content_from_deplot_response(json_response: Dict[str, Any]) -> Any:
+        """
+        Extract content from the JSON response of a Deplot HTTP API request.
+        The original code expected a single choice with a single textual content.
+        """
+        if "choices" not in json_response or not json_response["choices"]:
+            raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")
+
+        # If the service only returns one textual result, we return that one.
+        return json_response["choices"][0]["message"]["content"]
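
As a rough orientation for how this interface is used, the sketch below builds an HTTP payload from two base64-encoded chart images using only the methods shown in this file. The image file names are placeholders, and actually sending the request (normally handled by the surrounding NIM client code, which is not part of this file) is omitted.

import base64

from nv_ingest_api.internal.primitives.nim.model_interface.deplot import DeplotModelInterface

# Placeholder chart images; any PNG files on disk will do for this sketch.
def to_b64(path: str) -> str:
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

deplot = DeplotModelInterface()

# Decode the images into NumPy arrays ("image_arrays" is added to the dict in place).
data = deplot.prepare_data_for_inference(
    {"base64_images": [to_b64("chart_a.png"), to_b64("chart_b.png")]}
)

# Build one JSON payload per batch of at most two images for the HTTP protocol.
payloads, batch_data = deplot.format_input(data, protocol="http", max_batch_size=2, max_tokens=500)

# Each payload is an OpenAI-style chat request for the "google/deplot" model; the
# surrounding client would POST it and hand the JSON response to parse_output().
print(payloads[0]["model"], len(batch_data[0]["image_dims"]))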
nv_ingest_api/internal/primitives/nim/model_interface/helpers.py
@@ -0,0 +1,338 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Optional
+
+import backoff
+import cv2
+import numpy as np
+import requests
+
+from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
+from nv_ingest_api.util.image_processing.transforms import pad_image, normalize_image
+from nv_ingest_api.util.string_processing import generate_url, remove_url_endpoints
+
+cv2.setNumThreads(1)
+logger = logging.getLogger(__name__)
+
+
+def preprocess_image_for_paddle(array: np.ndarray, image_max_dimension: int = 960) -> np.ndarray:
+    """
+    Preprocesses an input image to be suitable for use with PaddleOCR by resizing, normalizing, padding,
+    and transposing it into the required format.
+
+    This function is intended for preprocessing images to be passed as input to PaddleOCR using GRPC.
+    It is not necessary when using the HTTP endpoint.
+
+    Steps:
+    -----
+    1. Resizes the image while maintaining aspect ratio such that its largest dimension is scaled to 960 pixels.
+    2. Normalizes the image using the `normalize_image` function.
+    3. Pads the image to ensure both its height and width are multiples of 32, as required by PaddleOCR.
+    4. Transposes the image from (height, width, channel) to (channel, height, width), the format expected by PaddleOCR.
+
+    Parameters:
+    ----------
+    array : np.ndarray
+        The input image array of shape (height, width, channels). It should have pixel values in the range [0, 255].
+
+    Returns:
+    -------
+    np.ndarray
+        A preprocessed image with the shape (channels, height, width) and normalized pixel values.
+        The image will be padded to have dimensions that are multiples of 32, with the padding color set to 0.
+
+    Notes:
+    -----
+    - The image is resized so that its largest dimension becomes 960 pixels, maintaining the aspect ratio.
+    - After normalization, the image is padded to the nearest multiple of 32 in both dimensions, which is
+      a requirement for PaddleOCR.
+    - The normalized pixel values are scaled between 0 and 1 before padding and transposing the image.
+    """
+    height, width = array.shape[:2]
+    scale_factor = image_max_dimension / max(height, width)
+    new_height = int(height * scale_factor)
+    new_width = int(width * scale_factor)
+    resized = cv2.resize(array, (new_width, new_height))
+
+    normalized = normalize_image(resized)
+
+    # PaddleOCR NIM (GRPC) requires input shapes to be multiples of 32.
+    new_height = (normalized.shape[0] + 31) // 32 * 32
+    new_width = (normalized.shape[1] + 31) // 32 * 32
+    padded, (pad_width, pad_height) = pad_image(
+        normalized, target_height=new_height, target_width=new_width, background_color=0, dtype=np.float32
+    )
+
+    # PaddleOCR NIM (GRPC) requires input to be (channel, height, width).
+    transposed = padded.transpose((2, 0, 1))
+
+    # Metadata can be used for inverting transformations on the resulting bounding boxes.
+    metadata = {
+        "original_height": height,
+        "original_width": width,
+        "scale_factor": scale_factor,
+        "new_height": transposed.shape[1],
+        "new_width": transposed.shape[2],
+        "pad_height": pad_height,
+        "pad_width": pad_width,
+    }
+
+    return transposed, metadata
+
+
+def preprocess_image_for_ocr(
+    array: np.ndarray,
+    target_height: Optional[int] = None,
+    target_width: Optional[int] = None,
+    pad_how: str = "bottom_right",
+    normalize: bool = False,
+    channel_first: bool = False,
+) -> np.ndarray:
+    """
+    Preprocesses an input image to be suitable for use with NemoRetriever-OCR.
+
+    This function is intended for preprocessing images to be passed as input to NemoRetriever-OCR using GRPC.
+    It is not necessary when using the HTTP endpoint.
+
+    Parameters:
+    ----------
+    array : np.ndarray
+        The input image array of shape (height, width, channels). It should have pixel values in the range [0, 255].
+
+    Returns:
+    -------
+    np.ndarray
+        A preprocessed image with the shape (channels, height, width).
+    """
+    height, width = array.shape[:2]
+
+    if target_height is None:
+        target_height = height
+
+    if target_width is None:
+        target_width = width
+
+    padded, (pad_width, pad_height) = pad_image(
+        array,
+        target_height=target_height,
+        target_width=target_width,
+        background_color=255,
+        dtype=np.float32,
+        how=pad_how,
+    )
+
+    if normalize:
+        padded = padded / 255.0
+
+    if channel_first:
+        # NemoRetriever-OCR NIM (GRPC) requires input to be (channel, height, width).
+        padded = padded.transpose((2, 0, 1))
+
+    # Metadata can be used for inverting transformations on the resulting bounding boxes.
+    metadata = {
+        "original_height": height,
+        "original_width": width,
+        "new_height": target_height,
+        "new_width": target_width,
+        "pad_height": pad_height,
+        "pad_width": pad_width,
+    }
+
+    return padded, metadata
+
+
+def is_ready(http_endpoint: str, ready_endpoint: str) -> bool:
+    """
+    Check if the server at the given endpoint is ready.
+
+    Parameters
+    ----------
+    http_endpoint : str
+        The HTTP endpoint of the server.
+    ready_endpoint : str
+        The specific ready-check endpoint.
+
+    Returns
+    -------
+    bool
+        True if the server is ready, False otherwise.
+    """
+
+    # If the url is empty or None that means the service was not configured
+    # and is therefore automatically marked as "ready"
+    if http_endpoint is None or http_endpoint == "":
+        return True
+
+    # If the url is for build.nvidia.com, it is automatically assumed "ready"
+    if "ai.api.nvidia.com" in http_endpoint:
+        return True
+
+    url = generate_url(http_endpoint)
+    url = remove_url_endpoints(url)
+
+    if not ready_endpoint.startswith("/") and not url.endswith("/"):
+        ready_endpoint = "/" + ready_endpoint
+
+    url = url + ready_endpoint
+
+    # Call the ready endpoint of the NIM
+    try:
+        # Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable
+        resp = requests.get(url, timeout=5)
+        if resp.status_code == 200:
+            # The NIM is saying it is ready to serve
+            return True
+        elif resp.status_code == 503:
+            # NIM is explicitly saying it is not ready.
+            return False
+        else:
+            # Any other code is confusing. We should log it with a warning
+            # as it could be something that might hold up ready state
+            logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.json()}")
+            return False
+    except requests.HTTPError as http_err:
+        logger.warning(f"'{url}' produced a HTTP error: {http_err}")
+        return False
+    except requests.Timeout:
+        logger.warning(f"'{url}' request timed out")
+        return False
+    except ConnectionError:
+        logger.warning(f"A connection error for '{url}' occurred")
+        return False
+    except requests.RequestException as err:
+        logger.warning(f"An error occurred: {err} for '{url}'")
+        return False
+    except Exception as ex:
+        # Don't let anything squeeze by
+        logger.warning(f"Exception: {ex}")
+        return False
+
+
+def _query_metadata(
+    http_endpoint: str,
+    field_name: str,
+    default_value: str,
+    retry_value: str = "",
+    metadata_endpoint: str = "/v1/metadata",
+) -> str:
+    if (http_endpoint is None) or (http_endpoint == ""):
+        return default_value
+
+    url = generate_url(http_endpoint)
+    url = remove_url_endpoints(url)
+
+    if not metadata_endpoint.startswith("/") and not url.endswith("/"):
+        metadata_endpoint = "/" + metadata_endpoint
+
+    url = url + metadata_endpoint
+
+    # Call the metadata endpoint of the NIM
+    try:
+        # Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable
+        resp = requests.get(url, timeout=5)
+        if resp.status_code == 200:
+            field_value = resp.json().get(field_name, "")
+            if field_value:
+                return field_value
+            else:
+                # If the field is empty, retry
+                logger.warning(f"No {field_name} field in response from '{url}'. Retrying.")
+                return retry_value
+        else:
+            # Any other code is confusing. We should log it with a warning
+            logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.text}")
+            return retry_value
+    except requests.HTTPError as http_err:
+        logger.warning(f"'{url}' produced a HTTP error: {http_err}")
+        return retry_value
+    except requests.Timeout:
+        logger.warning(f"'{url}' request timed out")
+        return retry_value
+    except ConnectionError:
+        logger.warning(f"A connection error for '{url}' occurred")
+        return retry_value
+    except requests.RequestException as err:
+        logger.warning(f"An error occurred: {err} for '{url}'")
+        return retry_value
+    except Exception as ex:
+        # Don't let anything squeeze by
+        logger.warning(f"Exception: {ex}")
+        return retry_value
+
+
+@multiprocessing_cache(max_calls=100)  # Cache results first to avoid redundant retries from backoff
+@backoff.on_predicate(backoff.expo, max_time=30)
+def get_version(http_endpoint: str, metadata_endpoint: str = "/v1/metadata", version_field: str = "version") -> str:
+    """
+    Get the version of the server from its metadata endpoint.
+
+    Parameters
+    ----------
+    http_endpoint : str
+        The HTTP endpoint of the server.
+    metadata_endpoint : str, optional
+        The metadata endpoint to query (default: "/v1/metadata").
+    version_field : str, optional
+        The field containing the version in the response (default: "version").
+
+    Returns
+    -------
+    str
+        The version of the server, or an empty string if unavailable.
+    """
+    default_version = "1.0.0"
+
+    # TODO: Need a way to match NIM version to API versions.
+    if "ai.api.nvidia.com" in http_endpoint or "api.nvcf.nvidia.com" in http_endpoint:
+        return default_version
+
+    return _query_metadata(
+        http_endpoint,
+        field_name=version_field,
+        default_value=default_version,
+    )
+
+
+@multiprocessing_cache(max_calls=100)  # Cache results first to avoid redundant retries from backoff
+@backoff.on_predicate(backoff.expo, max_time=30)
+def get_model_name(
+    http_endpoint: str,
+    default_model_name,
+    metadata_endpoint: str = "/v1/metadata",
+    model_info_field: str = "modelInfo",
+) -> str:
+    """
+    Get the model name of the server from its metadata endpoint.
+
+    Parameters
+    ----------
+    http_endpoint : str
+        The HTTP endpoint of the server.
+    metadata_endpoint : str, optional
+        The metadata endpoint to query (default: "/v1/metadata").
+    model_info_field : str, optional
+        The field containing the model info in the response (default: "modelInfo").
+
+    Returns
+    -------
+    str
+        The model name of the server, or an empty string if unavailable.
+    """
+    if "ai.api.nvidia.com" in http_endpoint:
+        return http_endpoint.strip("/").strip("/chat/completions").split("/")[-1]
+
+    if "api.nvcf.nvidia.com" in http_endpoint:
+        return default_model_name
+
+    model_info = _query_metadata(
+        http_endpoint,
+        field_name=model_info_field,
+        default_value={"shortName": default_model_name},
+    )
+    short_name = model_info[0].get("shortName", default_model_name)
+    model_name = short_name.split(":")[0]
+
+    return model_name
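
For context on how these helpers are typically exercised, the sketch below preprocesses a synthetic page image for the gRPC OCR path and probes a locally running NIM for readiness. The endpoint URL and ready path are placeholders, not values defined in this file.

import numpy as np

from nv_ingest_api.internal.primitives.nim.model_interface.helpers import (
    is_ready,
    preprocess_image_for_paddle,
)

# Synthetic RGB "page" image with values in [0, 255], as the helper expects.
page = np.random.randint(0, 256, size=(720, 1280, 3), dtype=np.uint8)

# Resize, normalize, pad to multiples of 32, and transpose to (C, H, W); the metadata
# records the scale factor and padding so boxes can be mapped back to the original image.
tensor, meta = preprocess_image_for_paddle(page)
print(tensor.shape, meta["scale_factor"], meta["pad_width"], meta["pad_height"])

# Readiness probe against a local NIM; URL and ready path are placeholder values.
print(is_ready("http://localhost:8000", "/v1/health/ready"))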