nv-ingest-api 2025.4.21.dev20250421__py3-none-any.whl → 2025.4.23.dev20250423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nv-ingest-api might be problematic. Click here for more details.

Files changed (153) hide show
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +215 -0
  3. nv_ingest_api/interface/extract.py +972 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +218 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +200 -0
  8. nv_ingest_api/internal/enums/__init__.py +3 -0
  9. nv_ingest_api/internal/enums/common.py +494 -0
  10. nv_ingest_api/internal/extract/__init__.py +3 -0
  11. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/audio_extraction.py +149 -0
  13. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  14. nv_ingest_api/internal/extract/docx/docx_extractor.py +205 -0
  15. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  16. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +122 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +895 -0
  19. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  20. nv_ingest_api/internal/extract/image/chart_extractor.py +353 -0
  21. nv_ingest_api/internal/extract/image/image_extractor.py +204 -0
  22. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/image_helpers/common.py +403 -0
  24. nv_ingest_api/internal/extract/image/infographic_extractor.py +253 -0
  25. nv_ingest_api/internal/extract/image/table_extractor.py +344 -0
  26. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  27. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  28. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  29. nv_ingest_api/internal/extract/pdf/engines/llama.py +243 -0
  30. nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +597 -0
  31. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +146 -0
  32. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +603 -0
  33. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  34. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  35. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  36. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  37. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  38. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +799 -0
  39. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +187 -0
  40. nv_ingest_api/internal/mutate/__init__.py +3 -0
  41. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  42. nv_ingest_api/internal/mutate/filter.py +133 -0
  43. nv_ingest_api/internal/primitives/__init__.py +0 -0
  44. nv_ingest_api/{primitives → internal/primitives}/control_message_task.py +4 -0
  45. nv_ingest_api/{primitives → internal/primitives}/ingest_control_message.py +5 -2
  46. nv_ingest_api/internal/primitives/nim/__init__.py +8 -0
  47. nv_ingest_api/internal/primitives/nim/default_values.py +15 -0
  48. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  49. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  50. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  51. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  52. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +275 -0
  53. nv_ingest_api/internal/primitives/nim/model_interface/nemoretriever_parse.py +238 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/paddle.py +462 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +132 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +152 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1400 -0
  59. nv_ingest_api/internal/primitives/nim/nim_client.py +344 -0
  60. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +81 -0
  61. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  62. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  63. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  64. nv_ingest_api/internal/primitives/tracing/tagging.py +197 -0
  65. nv_ingest_api/internal/schemas/__init__.py +3 -0
  66. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  67. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +130 -0
  68. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +135 -0
  69. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +124 -0
  70. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +124 -0
  71. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +128 -0
  72. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +218 -0
  73. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +124 -0
  74. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +129 -0
  75. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  76. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +23 -0
  77. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  78. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  79. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  80. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  81. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +237 -0
  82. nv_ingest_api/internal/schemas/meta/metadata_schema.py +221 -0
  83. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  85. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  86. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  87. nv_ingest_api/internal/schemas/store/store_image_schema.py +30 -0
  88. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  89. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +15 -0
  90. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  91. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +25 -0
  92. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +22 -0
  93. nv_ingest_api/internal/store/__init__.py +3 -0
  94. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  95. nv_ingest_api/internal/store/image_upload.py +232 -0
  96. nv_ingest_api/internal/transform/__init__.py +3 -0
  97. nv_ingest_api/internal/transform/caption_image.py +205 -0
  98. nv_ingest_api/internal/transform/embed_text.py +496 -0
  99. nv_ingest_api/internal/transform/split_text.py +157 -0
  100. nv_ingest_api/util/__init__.py +0 -0
  101. nv_ingest_api/util/control_message/__init__.py +0 -0
  102. nv_ingest_api/util/control_message/validators.py +47 -0
  103. nv_ingest_api/util/converters/__init__.py +0 -0
  104. nv_ingest_api/util/converters/bytetools.py +78 -0
  105. nv_ingest_api/util/converters/containers.py +65 -0
  106. nv_ingest_api/util/converters/datetools.py +90 -0
  107. nv_ingest_api/util/converters/dftools.py +127 -0
  108. nv_ingest_api/util/converters/formats.py +64 -0
  109. nv_ingest_api/util/converters/type_mappings.py +27 -0
  110. nv_ingest_api/util/detectors/__init__.py +5 -0
  111. nv_ingest_api/util/detectors/language.py +38 -0
  112. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  113. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  114. nv_ingest_api/util/exception_handlers/decorators.py +223 -0
  115. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  116. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  117. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  118. nv_ingest_api/util/image_processing/__init__.py +5 -0
  119. nv_ingest_api/util/image_processing/clustering.py +260 -0
  120. nv_ingest_api/util/image_processing/processing.py +179 -0
  121. nv_ingest_api/util/image_processing/table_and_chart.py +449 -0
  122. nv_ingest_api/util/image_processing/transforms.py +407 -0
  123. nv_ingest_api/util/logging/__init__.py +0 -0
  124. nv_ingest_api/util/logging/configuration.py +31 -0
  125. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  126. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  127. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  128. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  129. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +451 -0
  130. nv_ingest_api/util/metadata/__init__.py +5 -0
  131. nv_ingest_api/util/metadata/aggregators.py +469 -0
  132. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  133. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +194 -0
  134. nv_ingest_api/util/nim/__init__.py +56 -0
  135. nv_ingest_api/util/pdf/__init__.py +3 -0
  136. nv_ingest_api/util/pdf/pdfium.py +427 -0
  137. nv_ingest_api/util/schema/__init__.py +0 -0
  138. nv_ingest_api/util/schema/schema_validator.py +10 -0
  139. nv_ingest_api/util/service_clients/__init__.py +3 -0
  140. nv_ingest_api/util/service_clients/client_base.py +86 -0
  141. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  142. nv_ingest_api/util/service_clients/redis/__init__.py +0 -0
  143. nv_ingest_api/util/service_clients/redis/redis_client.py +823 -0
  144. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  145. nv_ingest_api/util/service_clients/rest/rest_client.py +531 -0
  146. nv_ingest_api/util/string_processing/__init__.py +51 -0
  147. {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/METADATA +1 -1
  148. nv_ingest_api-2025.4.23.dev20250423.dist-info/RECORD +152 -0
  149. {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/WHEEL +1 -1
  150. nv_ingest_api-2025.4.21.dev20250421.dist-info/RECORD +0 -9
  151. /nv_ingest_api/{primitives → internal}/__init__.py +0 -0
  152. {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/licenses/LICENSE +0 -0
  153. {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,275 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import logging
6
+
7
+ import backoff
8
+ import cv2
9
+ import numpy as np
10
+ import requests
11
+
12
+ from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
13
+ from nv_ingest_api.util.image_processing.transforms import pad_image, normalize_image
14
+ from nv_ingest_api.util.string_processing import generate_url, remove_url_endpoints
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
def preprocess_image_for_paddle(array: np.ndarray, image_max_dimension: int = 960) -> tuple:
    """
    Preprocess an image for PaddleOCR by resizing, normalizing, padding and transposing it.

    This function is intended for preprocessing images to be passed as input to PaddleOCR using GRPC.
    It is not necessary when using the HTTP endpoint.

    Steps
    -----
    1. Resize the image, keeping its aspect ratio, so that its largest dimension is scaled to
       `image_max_dimension` pixels.
    2. Normalize the image with `normalize_image`.
    3. Pad height and width up to the next multiple of 32 (background value 0, dtype float32).
    4. Transpose from (height, width, channel) to (channel, height, width).

    Parameters
    ----------
    array : np.ndarray
        The input image. The final transpose requires a 3-D array, i.e. (height, width, channels).
    image_max_dimension : int, optional
        Target size, in pixels, of the largest image side after resizing (default: 960).

    Returns
    -------
    tuple
        A pair ``(image, metadata)``:

        - ``image`` is the preprocessed array in (channels, height, width) layout whose height and
          width are multiples of 32.
        - ``metadata`` is a dict with the keys ``original_height``, ``original_width``,
          ``scale_factor``, ``new_height``, ``new_width``, ``pad_height`` and ``pad_width``. It can be
          used to invert the transformations on resulting bounding boxes.
    """
    height, width = array.shape[:2]
    scale_factor = image_max_dimension / max(height, width)

    # int() truncates, so a very elongated image could otherwise end up with a 0-pixel side,
    # which cv2.resize rejects. Clamp each side to at least one pixel.
    resized_height = max(1, int(height * scale_factor))
    resized_width = max(1, int(width * scale_factor))
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(array, (resized_width, resized_height))

    normalized = normalize_image(resized)

    # PaddleOCR NIM (GRPC) requires input shapes to be multiples of 32.
    padded_height = (normalized.shape[0] + 31) // 32 * 32
    padded_width = (normalized.shape[1] + 31) // 32 * 32
    padded, (pad_width, pad_height) = pad_image(
        normalized,
        target_height=padded_height,
        target_width=padded_width,
        background_color=0,
        dtype=np.float32,
    )

    # PaddleOCR NIM (GRPC) requires input to be (channel, height, width).
    transposed = padded.transpose((2, 0, 1))

    # Metadata can be used for inverting transformations on the resulting bounding boxes.
    metadata = {
        "original_height": height,
        "original_width": width,
        "scale_factor": scale_factor,
        "new_height": transposed.shape[1],
        "new_width": transposed.shape[2],
        "pad_height": pad_height,
        "pad_width": pad_width,
    }

    return transposed, metadata
82
+
83
+
84
def is_ready(http_endpoint: str, ready_endpoint: str) -> bool:
    """
    Check if the server at the given endpoint is ready.

    Parameters
    ----------
    http_endpoint : str
        The HTTP endpoint of the server. ``None`` or an empty string means the service was not
        configured, in which case it is reported as ready.
    ready_endpoint : str
        The specific ready-check endpoint, appended to the base URL derived from `http_endpoint`.

    Returns
    -------
    bool
        True if the server is ready, False otherwise. Only an HTTP 200 response counts as ready;
        every other status code and every exception yields False.
    """

    # If the url is empty or None that means the service was not configured
    # and is therefore automatically marked as "ready"
    if http_endpoint is None or http_endpoint == "":
        return True

    # If the url is for build.nvidia.com, it is automatically assumed "ready"
    if "ai.api.nvidia.com" in http_endpoint:
        return True

    url = generate_url(http_endpoint)
    url = remove_url_endpoints(url)

    if not ready_endpoint.startswith("/") and not url.endswith("/"):
        ready_endpoint = "/" + ready_endpoint

    url = url + ready_endpoint

    # Call the ready endpoint of the NIM
    try:
        # Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable
        resp = requests.get(url, timeout=5)
        if resp.status_code == 200:
            # The NIM is saying it is ready to serve
            return True
        if resp.status_code == 503:
            # NIM is explicitly saying it is not ready.
            return False

        # Any other code is confusing. We should log it with a warning
        # as it could be something that might hold up ready state.
        # Log the raw text rather than resp.json(): an error page is frequently not JSON, and a
        # decoding failure here would replace this message with a less useful one.
        logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.text}")
        return False
    except requests.HTTPError as http_err:
        logger.warning(f"'{url}' produced a HTTP error: {http_err}")
        return False
    except requests.Timeout:
        logger.warning(f"'{url}' request timed out")
        return False
    except requests.ConnectionError:
        # Must be requests.ConnectionError: the builtin ConnectionError is an unrelated class
        # and is never raised by requests for a failed connection.
        logger.warning(f"A connection error for '{url}' occurred")
        return False
    except requests.RequestException as err:
        logger.warning(f"An error occurred: {err} for '{url}'")
        return False
    except Exception as ex:
        # Don't let anything squeeze by
        logger.warning(f"Exception: {ex}")
        return False
149
+
150
+
151
def _query_metadata(
    http_endpoint: str,
    field_name: str,
    default_value: str,
    retry_value: str = "",
    metadata_endpoint: str = "/v1/metadata",
) -> str:
    """
    Fetch a single field from a server's metadata endpoint.

    Parameters
    ----------
    http_endpoint : str
        The HTTP endpoint of the server. ``None`` or an empty string short-circuits to
        `default_value` without any network call.
    field_name : str
        The key to read from the JSON body of the metadata response.
    default_value : str
        The value returned, unchanged, when no endpoint is configured.
    retry_value : str, optional
        The value returned when the request fails, the status is not 200, or the field is missing
        or empty (default: ""). This function does not retry by itself; the value is a signal to
        the caller.
    metadata_endpoint : str, optional
        The metadata endpoint to query (default: "/v1/metadata").

    Returns
    -------
    str
        The field value on success, otherwise `default_value` or `retry_value` as described above.
    """
    if (http_endpoint is None) or (http_endpoint == ""):
        return default_value

    url = generate_url(http_endpoint)
    url = remove_url_endpoints(url)

    if not metadata_endpoint.startswith("/") and not url.endswith("/"):
        metadata_endpoint = "/" + metadata_endpoint

    url = url + metadata_endpoint

    # Call the metadata endpoint of the NIM
    try:
        # Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable
        resp = requests.get(url, timeout=5)
        if resp.status_code != 200:
            # Any other code is confusing. We should log it with a warning
            logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.text}")
            return retry_value

        field_value = resp.json().get(field_name, "")
        if field_value:
            return field_value

        # If the field is empty, signal the caller to retry
        logger.warning(f"No {field_name} field in response from '{url}'. Retrying.")
        return retry_value
    except requests.HTTPError as http_err:
        logger.warning(f"'{url}' produced a HTTP error: {http_err}")
        return retry_value
    except requests.Timeout:
        logger.warning(f"'{url}' request timed out")
        return retry_value
    except requests.ConnectionError:
        # Must be requests.ConnectionError: the builtin ConnectionError is an unrelated class
        # and is never raised by requests for a failed connection.
        logger.warning(f"A connection error for '{url}' occurred")
        return retry_value
    except requests.RequestException as err:
        logger.warning(f"An error occurred: {err} for '{url}'")
        return retry_value
    except Exception as ex:
        # Don't let anything squeeze by
        logger.warning(f"Exception: {ex}")
        return retry_value
201
+
202
+
203
@multiprocessing_cache(max_calls=100)  # Cache results first to avoid redundant retries from backoff
@backoff.on_predicate(backoff.expo, max_time=30)
def get_version(http_endpoint: str, metadata_endpoint: str = "/v1/metadata", version_field: str = "version") -> str:
    """
    Get the version of the server from its metadata endpoint.

    Parameters
    ----------
    http_endpoint : str
        The HTTP endpoint of the server. If it is ``None`` or empty, the default version is
        returned.
    metadata_endpoint : str, optional
        The metadata endpoint to query (default: "/v1/metadata").
    version_field : str, optional
        The field containing the version in the response (default: "version").

    Returns
    -------
    str
        The version of the server, "1.0.0" for hosted NVIDIA endpoints or when no endpoint is
        configured, or an empty string if the version could not be retrieved.
    """
    default_version = "1.0.0"

    # TODO: Need a way to match NIM version to API versions.
    # The truthiness guard lets a None endpoint fall through to _query_metadata, which treats it
    # as "not configured", instead of failing on the substring test.
    if http_endpoint and ("ai.api.nvidia.com" in http_endpoint or "api.nvcf.nvidia.com" in http_endpoint):
        return default_version

    return _query_metadata(
        http_endpoint,
        field_name=version_field,
        default_value=default_version,
        # Forward the caller's endpoint; previously the parameter was accepted but ignored.
        metadata_endpoint=metadata_endpoint,
    )
234
+
235
+
236
@multiprocessing_cache(max_calls=100)  # Cache results first to avoid redundant retries from backoff
@backoff.on_predicate(backoff.expo, max_time=30)
def get_model_name(
    http_endpoint: str,
    default_model_name,
    metadata_endpoint: str = "/v1/metadata",
    model_info_field: str = "modelInfo",
) -> str:
    """
    Get the model name of the server from its metadata endpoint.

    Parameters
    ----------
    http_endpoint : str
        The HTTP endpoint of the server.
    default_model_name
        The name to fall back on when the endpoint is not configured, is an
        "api.nvcf.nvidia.com" endpoint, or the metadata entry has no "shortName".
    metadata_endpoint : str, optional
        The metadata endpoint to query (default: "/v1/metadata").
    model_info_field : str, optional
        The field containing the model info in the response (default: "modelInfo").

    Returns
    -------
    str
        The model name of the server, or an empty string if unavailable.
    """
    if http_endpoint and "ai.api.nvidia.com" in http_endpoint:
        # The model name is the last path segment once an optional "/chat/completions" suffix is
        # removed. str.strip() must not be used for this: it strips a *set of characters*, which
        # would also eat trailing letters of the model name itself.
        path = http_endpoint.strip("/")
        suffix = "/chat/completions"
        if path.endswith(suffix):
            path = path[: -len(suffix)]
        return path.split("/")[-1]

    if http_endpoint and "api.nvcf.nvidia.com" in http_endpoint:
        return default_model_name

    model_info = _query_metadata(
        http_endpoint,
        field_name=model_info_field,
        # Same shape as the value indexed below: a list of model-info dicts.
        default_value=[{"shortName": default_model_name}],
        # Forward the caller's endpoint; previously the parameter was accepted but ignored.
        metadata_endpoint=metadata_endpoint,
    )

    # An empty result means the lookup failed. Returning a falsy value (instead of raising on
    # model_info[0]) is what lets the backoff.on_predicate decorator retry.
    if not model_info:
        return ""

    first_entry = model_info[0] if isinstance(model_info, (list, tuple)) else model_info
    short_name = first_entry.get("shortName", default_model_name)

    # Keep only the part before the first ":" of the short name.
    return short_name.split(":")[0]
@@ -0,0 +1,238 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import json
6
+ import logging
7
+ from typing import Any
8
+ from typing import Dict
9
+ from typing import List
10
+ from typing import Optional
11
+
12
+ from nv_ingest_api.internal.primitives.nim import ModelInterface
13
+ from nv_ingest_api.util.image_processing.transforms import numpy_to_base64
14
+
15
# Class labels grouped by content kind. ACCEPTED_CLASSES is the union of the three groups.
ACCEPTED_TEXT_CLASSES = {
    "Text",
    "Title",
    "Section-header",
    "List-item",
    "TOC",
    "Bibliography",
    "Formula",
    "Page-header",
    "Page-footer",
    "Caption",
    "Footnote",
    "Floating-text",
}
ACCEPTED_TABLE_CLASSES = {"Table"}
ACCEPTED_IMAGE_CLASSES = {"Picture"}
ACCEPTED_CLASSES = ACCEPTED_TEXT_CLASSES | ACCEPTED_TABLE_CLASSES | ACCEPTED_IMAGE_CLASSES
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
class NemoRetrieverParseModelInterface(ModelInterface):
    """
    An interface for handling inference with a NemoRetrieverParse model.

    Only the HTTP protocol is supported; every protocol-aware method raises ``ValueError`` for
    "grpc".
    """

    def __init__(self, model_name: str = "nvidia/nemoretriever-parse"):
        """
        Initialize the instance with a specified model name.

        Parameters
        ----------
        model_name : str, optional
            The name of the model to be used, by default "nvidia/nemoretriever-parse".
        """
        self.model_name = model_name

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface.
        """
        return "nemoretriever_parse"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Prepare input data for inference.

        No preparation is needed for this model, so the input is returned as-is.

        Parameters
        ----------
        data : dict
            The input data containing a list of images.

        Returns
        -------
        dict
            The same `data` object, unmodified.
        """

        return data

    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
        """
        Format input data for the specified protocol.

        Parameters
        ----------
        data : dict
            The input data to format. Images are read from ``data["images"]`` when that key is
            present, otherwise a single image is read from ``data["image"]``.
        protocol : str
            The protocol to use. Only "http" is supported.
        max_batch_size : int
            The maximum number of images placed in one request payload. Must be positive.
        **kwargs : dict
            Additional parameters for HTTP payload formatting (currently unused).

        Returns
        -------
        Any
            A pair ``(payloads, batch_data)``: one request payload per batch, and a list of empty
            dicts of the same length.

        Raises
        ------
        ValueError
            If the protocol is "grpc" (unsupported) or unknown, or if `max_batch_size` is not
            positive.
        """

        if protocol == "grpc":
            raise ValueError("gRPC protocol is not supported for NemoRetrieverParse.")
        elif protocol == "http":
            logger.debug("Formatting input for HTTP NemoRetrieverParse model")

            # A zero step makes range() fail with an unrelated message and a negative step
            # silently yields no batches at all, so reject both up front.
            if max_batch_size < 1:
                raise ValueError(f"max_batch_size must be a positive integer, got {max_batch_size}.")

            # Prepare payload for HTTP request
            if "images" in data:
                base64_list = [numpy_to_base64(img) for img in data["images"]]
            else:
                base64_list = [numpy_to_base64(data["image"])]

            # Split the encoded images into chunks of at most max_batch_size, one payload each.
            formatted_batches = [
                self._prepare_nemoretriever_parse_payload(base64_list[start : start + max_batch_size])
                for start in range(0, len(base64_list), max_batch_size)
            ]
            # No per-batch side data is needed; keep one empty dict per payload.
            formatted_batch_data = [{} for _ in formatted_batches]
            return formatted_batches, formatted_batch_data

        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
        """
        Parse the output from the model's inference response.

        Parameters
        ----------
        response : Any
            The response from the model inference.
        protocol : str
            The protocol used. Only "http" is supported.
        data : dict, optional
            Additional input data passed to the function (currently unused).

        Returns
        -------
        Any
            The parsed output data.

        Raises
        ------
        ValueError
            If the protocol is "grpc" (unsupported) or unknown.
        """

        if protocol == "grpc":
            raise ValueError("gRPC protocol is not supported for NemoRetrieverParse.")
        elif protocol == "http":
            logger.debug("Parsing output from HTTP NemoRetrieverParse model")
            return self._extract_content_from_nemoretriever_parse_response(response)
        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def process_inference_results(self, output: Any, **kwargs) -> Any:
        """
        Process inference results for the NemoRetrieverParse model.

        Parameters
        ----------
        output : Any
            The raw output from the model.

        Returns
        -------
        Any
            The `output` object, returned unchanged.
        """

        return output

    def _prepare_nemoretriever_parse_payload(self, base64_list: List[str]) -> Dict[str, Any]:
        """
        Build the HTTP request payload for a batch of base64-encoded images.

        Parameters
        ----------
        base64_list : list of str
            The base64-encoded images; each becomes one "user" message holding a single
            ``image_url`` entry with a ``data:image/png;base64,`` URL.

        Returns
        -------
        dict
            A payload with the keys "model" (this instance's model name) and "messages".
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{b64_img}",
                        },
                    }
                ],
            }
            for b64_img in base64_list
        ]

        return {
            "model": self.model_name,
            "messages": messages,
        }

    def _extract_content_from_nemoretriever_parse_response(self, json_response: Dict[str, Any]) -> Any:
        """
        Extract content from the JSON response of a NemoRetrieverParse HTTP API request.

        Only the first tool call of the first choice is read; its ``function.arguments`` string
        is decoded as JSON.

        Parameters
        ----------
        json_response : dict
            The JSON response from the NemoRetrieverParse API.

        Returns
        -------
        Any
            The decoded ``arguments`` of the first tool call.

        Raises
        ------
        RuntimeError
            If the response does not contain the expected "choices" key or if it is empty.
        KeyError, IndexError
            If the first choice lacks the expected "message" / "tool_calls" / "function" structure.
        json.JSONDecodeError
            If the tool call arguments are not valid JSON.
        """

        if "choices" not in json_response or not json_response["choices"]:
            raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")

        tool_call = json_response["choices"][0]["message"]["tool_calls"][0]
        return json.loads(tool_call["function"]["arguments"])