nv-ingest-api 26.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nv-ingest-api might be problematic. Click here for more details.

Files changed (177)
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +218 -0
  3. nv_ingest_api/interface/extract.py +977 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +200 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +186 -0
  8. nv_ingest_api/internal/__init__.py +0 -0
  9. nv_ingest_api/internal/enums/__init__.py +3 -0
  10. nv_ingest_api/internal/enums/common.py +550 -0
  11. nv_ingest_api/internal/extract/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  13. nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
  14. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  15. nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
  16. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
  19. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
  20. nv_ingest_api/internal/extract/html/__init__.py +3 -0
  21. nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
  22. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
  24. nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
  25. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  26. nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
  27. nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
  28. nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
  29. nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
  30. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  31. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  32. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  33. nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
  34. nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
  35. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
  36. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
  37. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  38. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  39. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  40. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  41. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  42. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
  43. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
  44. nv_ingest_api/internal/meta/__init__.py +3 -0
  45. nv_ingest_api/internal/meta/udf.py +232 -0
  46. nv_ingest_api/internal/mutate/__init__.py +3 -0
  47. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  48. nv_ingest_api/internal/mutate/filter.py +133 -0
  49. nv_ingest_api/internal/primitives/__init__.py +0 -0
  50. nv_ingest_api/internal/primitives/control_message_task.py +16 -0
  51. nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
  52. nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
  53. nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
  59. nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
  60. nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
  61. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  62. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
  63. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
  64. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
  65. nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
  66. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
  67. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  68. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  69. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  70. nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
  71. nv_ingest_api/internal/schemas/__init__.py +3 -0
  72. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  73. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
  74. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
  75. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
  76. nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
  77. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
  78. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
  79. nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
  80. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
  81. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
  82. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
  83. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
  85. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  86. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  87. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  88. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  89. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
  90. nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
  91. nv_ingest_api/internal/schemas/meta/udf.py +23 -0
  92. nv_ingest_api/internal/schemas/mixins.py +39 -0
  93. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  94. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  95. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  96. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  97. nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
  98. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  99. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
  100. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  101. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
  102. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
  103. nv_ingest_api/internal/store/__init__.py +3 -0
  104. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  105. nv_ingest_api/internal/store/image_upload.py +251 -0
  106. nv_ingest_api/internal/transform/__init__.py +3 -0
  107. nv_ingest_api/internal/transform/caption_image.py +219 -0
  108. nv_ingest_api/internal/transform/embed_text.py +702 -0
  109. nv_ingest_api/internal/transform/split_text.py +182 -0
  110. nv_ingest_api/util/__init__.py +3 -0
  111. nv_ingest_api/util/control_message/__init__.py +0 -0
  112. nv_ingest_api/util/control_message/validators.py +47 -0
  113. nv_ingest_api/util/converters/__init__.py +0 -0
  114. nv_ingest_api/util/converters/bytetools.py +78 -0
  115. nv_ingest_api/util/converters/containers.py +65 -0
  116. nv_ingest_api/util/converters/datetools.py +90 -0
  117. nv_ingest_api/util/converters/dftools.py +127 -0
  118. nv_ingest_api/util/converters/formats.py +64 -0
  119. nv_ingest_api/util/converters/type_mappings.py +27 -0
  120. nv_ingest_api/util/dataloader/__init__.py +9 -0
  121. nv_ingest_api/util/dataloader/dataloader.py +409 -0
  122. nv_ingest_api/util/detectors/__init__.py +5 -0
  123. nv_ingest_api/util/detectors/language.py +38 -0
  124. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  125. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  126. nv_ingest_api/util/exception_handlers/decorators.py +429 -0
  127. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  128. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  129. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  130. nv_ingest_api/util/image_processing/__init__.py +5 -0
  131. nv_ingest_api/util/image_processing/clustering.py +260 -0
  132. nv_ingest_api/util/image_processing/processing.py +177 -0
  133. nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
  134. nv_ingest_api/util/image_processing/transforms.py +850 -0
  135. nv_ingest_api/util/imports/__init__.py +3 -0
  136. nv_ingest_api/util/imports/callable_signatures.py +108 -0
  137. nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
  138. nv_ingest_api/util/introspection/__init__.py +3 -0
  139. nv_ingest_api/util/introspection/class_inspect.py +145 -0
  140. nv_ingest_api/util/introspection/function_inspect.py +65 -0
  141. nv_ingest_api/util/logging/__init__.py +0 -0
  142. nv_ingest_api/util/logging/configuration.py +102 -0
  143. nv_ingest_api/util/logging/sanitize.py +84 -0
  144. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  145. nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
  146. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  147. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  148. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  149. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
  150. nv_ingest_api/util/metadata/__init__.py +5 -0
  151. nv_ingest_api/util/metadata/aggregators.py +516 -0
  152. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  153. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
  154. nv_ingest_api/util/nim/__init__.py +161 -0
  155. nv_ingest_api/util/pdf/__init__.py +3 -0
  156. nv_ingest_api/util/pdf/pdfium.py +428 -0
  157. nv_ingest_api/util/schema/__init__.py +3 -0
  158. nv_ingest_api/util/schema/schema_validator.py +10 -0
  159. nv_ingest_api/util/service_clients/__init__.py +3 -0
  160. nv_ingest_api/util/service_clients/client_base.py +86 -0
  161. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  162. nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
  163. nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
  164. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  165. nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
  166. nv_ingest_api/util/string_processing/__init__.py +51 -0
  167. nv_ingest_api/util/string_processing/configuration.py +682 -0
  168. nv_ingest_api/util/string_processing/yaml.py +109 -0
  169. nv_ingest_api/util/system/__init__.py +0 -0
  170. nv_ingest_api/util/system/hardware_info.py +594 -0
  171. nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
  172. nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
  173. nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
  174. nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
  175. nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
  176. udfs/__init__.py +5 -0
  177. udfs/llm_summarizer_udf.py +259 -0
@@ -0,0 +1,776 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ from typing import Any
9
+ from typing import Dict
10
+ from typing import List
11
+ from typing import Optional
12
+ from typing import Tuple
13
+
14
+ import backoff
15
+ import numpy as np
16
+ import tritonclient.grpc as grpcclient
17
+
18
+ from nv_ingest_api.internal.primitives.nim import ModelInterface
19
+ from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
20
+ from nv_ingest_api.internal.primitives.nim.model_interface.helpers import preprocess_image_for_paddle
21
+ from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
22
+
23
# Model names that `_extract_content_from_ocr_grpc_response` treats as NemoRetriever
# OCR models; their gRPC output is transposed before it is parsed.
NEMORETRIEVER_OCR_MODEL_NAME = "scene_text_wrapper"
NEMORETRIEVER_OCR_ENSEMBLE_MODEL_NAME = "scene_text_ensemble"
NEMORETRIEVER_OCR_BLS_MODEL_NAME = "scene_text_python"

# Model name assumed by the parsing methods when callers do not pass one; it is the
# same string as the ensemble model name above.
DEFAULT_OCR_MODEL_NAME = NEMORETRIEVER_OCR_ENSEMBLE_MODEL_NAME

logger = logging.getLogger(__name__)
30
+
31
+
32
class OCRModelInterfaceBase(ModelInterface):
    """
    Shared response-parsing and batch-sizing logic for the OCR model interfaces.

    Subclasses supply input preparation and formatting. This base class turns raw
    gRPC or HTTP OCR responses into one ``[bounding_boxes, text_predictions,
    conf_scores]`` list per image. It also estimates whether a padded batch of
    images fits within a memory budget.
    """

    NUM_CHANNELS = 3
    BYTES_PER_ELEMENT = 4  # For float32

    def parse_output(
        self,
        response: Any,
        protocol: str,
        data: Optional[Dict[str, Any]] = None,
        model_name: str = DEFAULT_OCR_MODEL_NAME,
        **kwargs: Any,
    ) -> Any:
        """
        Parse the model's inference response for the given protocol.

        The response may contain results for a batch of images.

        Parameters
        ----------
        response : Any
            The raw response from the OCR model: a NumPy array for "grpc", a decoded
            JSON dict for "http".
        protocol : str
            The protocol used for inference, "grpc" or "http".
        data : dict of str -> Any, optional
            Scratch-pad data for the batch. Its "image_dims" entry (a list with one
            dict per image) is used to map coordinates back onto the original images.
        model_name : str, optional
            Name of the model that produced a gRPC response. NemoRetriever OCR model
            names cause the gRPC output to be transposed before parsing.
        **kwargs : Any
            Accepted for interface compatibility; not used here.

        Returns
        -------
        list
            One ``[bounding_boxes, text_predictions, conf_scores]`` list per image.

        Raises
        ------
        ValueError
            If an invalid protocol is specified, or if "image_dims" is missing or empty
            while there is at least one image result to post-process.
        RuntimeError
            If an HTTP response has no (or an empty) "data" entry.
        """
        dims: Optional[List[Dict[str, Any]]] = data.get("image_dims") if data else None

        if protocol == "grpc":
            logger.debug("Parsing output from gRPC OCR model (batched).")
            return self._extract_content_from_ocr_grpc_response(response, dims, model_name=model_name)

        if protocol == "http":
            logger.debug("Parsing output from HTTP OCR model (batched).")
            return self._extract_content_from_ocr_http_response(response, dims)

        raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def process_inference_results(self, output: Any, **kwargs: Any) -> Any:
        """
        Process inference results for the OCR model.

        Parameters
        ----------
        output : Any
            The output already produced by ``parse_output``.
        **kwargs : Any
            Accepted for interface compatibility; not used here.

        Returns
        -------
        Any
            ``output`` unchanged.
        """
        return output

    def does_item_fit_in_batch(self, current_batch, next_request, memory_budget_bytes: int) -> bool:
        """
        Estimate whether adding ``next_request`` keeps a padded batch within budget.

        Every image in the batch is assumed to be padded to the largest height and
        width found in the batch, with ``NUM_CHANNELS`` float32 channels.

        Parameters
        ----------
        current_batch : iterable
            Requests already in the batch. Each exposes ``dims``, whose first two
            entries are (height, width).
        next_request : object
            Candidate request with the same ``dims`` attribute.
        memory_budget_bytes : int
            Maximum allowed estimated size of the padded batch, in bytes.

        Returns
        -------
        bool
            True if the estimated padded batch size does not exceed the budget.
        """
        all_dims = [req.dims for req in [*current_batch, next_request]]

        potential_max_h = max(d[0] for d in all_dims)
        potential_max_w = max(d[1] for d in all_dims)

        potential_memory_bytes = (
            len(all_dims) * potential_max_h * potential_max_w * self.NUM_CHANNELS * self.BYTES_PER_ELEMENT
        )
        return potential_memory_bytes <= memory_budget_bytes

    def _prepare_ocr_payload(self, base64_img: str) -> Dict[str, Any]:
        """
        Build a single-image HTTP payload.

        DEPRECATED: ``format_input`` builds batched payloads instead. This method is
        kept for direct single-image calls.

        Parameters
        ----------
        base64_img : str
            A single base64-encoded image string.

        Returns
        -------
        dict of str -> Any
            ``{"input": [{"type": "image_url", "url": "data:image/png;base64,..."}]}``.
        """
        image_url = f"data:image/png;base64,{base64_img}"

        image = {"type": "image_url", "url": image_url}
        payload = {"input": [image]}

        return payload

    def _extract_content_from_ocr_http_response(
        self,
        json_response: Dict[str, Any],
        dimensions: Optional[List[Dict[str, Any]]],
    ) -> List[List[Any]]:
        """
        Extract per-image OCR results from the JSON response of an HTTP request.

        Each item of ``json_response["data"]`` has a "text_detections" list. Each
        detection holds ``text_prediction.text``, ``text_prediction.confidence`` and
        ``bounding_box.points`` (a list of ``{"x", "y"}`` dicts).

        Parameters
        ----------
        json_response : dict of str -> Any
            The JSON response returned by the OCR endpoint.
        dimensions : list of dict, optional
            One dict per image, used to scale and shift bounding boxes.

        Returns
        -------
        list
            One ``[bounding_boxes, text_predictions, conf_scores]`` list per image.

        Raises
        ------
        RuntimeError
            If the "data" key is missing or empty.
        ValueError
            If ``dimensions`` is missing or empty (raised during post-processing).
        """
        if not json_response.get("data"):
            raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")

        results: List[List[Any]] = []
        for item_idx, item in enumerate(json_response["data"]):
            text_predictions = []
            bounding_boxes = []
            conf_scores = []
            for td in item.get("text_detections", []):
                text_predictions.append(td["text_prediction"]["text"])
                bounding_boxes.append([[pt["x"], pt["y"]] for pt in td["bounding_box"]["points"]])
                conf_scores.append(td["text_prediction"]["confidence"])

            bounding_boxes, text_predictions, conf_scores = self._postprocess_ocr_response(
                bounding_boxes,
                text_predictions,
                conf_scores,
                dimensions,
                img_index=item_idx,
            )
            results.append([bounding_boxes, text_predictions, conf_scores])

        return results

    def _extract_content_from_ocr_grpc_response(
        self,
        response: np.ndarray,
        dimensions: Optional[List[Dict[str, Any]]],
        model_name: str = DEFAULT_OCR_MODEL_NAME,
    ) -> List[List[Any]]:
        """
        Parse a gRPC response for one or more images.

        After any NemoRetriever transpose, the array must have shape (3,) for a
        single image or (3, n) for n images, where:

        - ``response[0, i]`` is JSON (bytes or str) holding the bounding boxes;
        - ``response[1, i]`` is JSON holding the text predictions;
        - ``response[2, i]`` is JSON holding the confidence scores.

        Parameters
        ----------
        response : np.ndarray
            The raw NumPy array from gRPC.
        dimensions : list of dict, optional
            One dict per image, used to scale and shift bounding boxes.
        model_name : str, optional
            For NemoRetriever OCR model names, a 2-D response is transposed first,
            turning (n, 3) into (3, n).

        Returns
        -------
        list
            One ``[bounding_boxes, text_predictions, conf_scores]`` list per image.

        Raises
        ------
        ValueError
            If the response is not a NumPy array or has an unexpected shape, or if
            ``dimensions`` is missing or empty.
        """
        if not isinstance(response, np.ndarray):
            raise ValueError("Unexpected response format: response is not a NumPy array.")

        # NemoRetriever models put one row per image, so a 2-D result is transposed
        # into the (3, n) layout parsed below. Only 2-D arrays are transposed: a 1-D
        # (3,) result is handled by the reshape below, and other ranks fall through
        # to the explicit shape error instead of failing inside NumPy.
        nemoretriever_names = (
            NEMORETRIEVER_OCR_MODEL_NAME,
            NEMORETRIEVER_OCR_ENSEMBLE_MODEL_NAME,
            NEMORETRIEVER_OCR_BLS_MODEL_NAME,
        )
        if model_name in nemoretriever_names and response.ndim == 2:
            response = response.transpose((1, 0))

        # If we have shape (3,), convert to (3, 1)
        if response.ndim == 1 and response.shape == (3,):
            response = response.reshape(3, 1)
        elif response.ndim != 2 or response.shape[0] != 3:
            raise ValueError(f"Unexpected response shape: {response.shape}. Expecting (3,) or (3, n).")

        results: List[List[Any]] = []
        for i in range(response.shape[1]):
            bounding_boxes = self._decode_json_field(response[0, i])
            text_predictions = self._decode_json_field(response[1, i])
            conf_scores = self._decode_json_field(response[2, i])

            # Some gRPC responses wrap each field in a one-element list; unwrap only
            # when all three fields are wrapped that way.
            if all(self._is_single_nested_list(v) for v in (bounding_boxes, text_predictions, conf_scores)):
                bounding_boxes = bounding_boxes[0]
                text_predictions = text_predictions[0]
                conf_scores = conf_scores[0]

            bounding_boxes, text_predictions, conf_scores = self._postprocess_ocr_response(
                bounding_boxes,
                text_predictions,
                conf_scores,
                dimensions,
                img_index=i,
                scale_coordinates=True,
            )
            results.append([bounding_boxes, text_predictions, conf_scores])

        return results

    @staticmethod
    def _decode_json_field(value: Any) -> Any:
        """Decode one JSON-encoded cell of a gRPC OCR response (bytes or str)."""
        if isinstance(value, (bytes, bytearray)):
            value = value.decode("utf8")
        return json.loads(value)

    @staticmethod
    def _is_single_nested_list(value: Any) -> bool:
        """Return True for a one-element list whose only element is itself a list."""
        return isinstance(value, list) and len(value) == 1 and isinstance(value[0], list)

    @staticmethod
    def _postprocess_ocr_response(
        bounding_boxes: List[Any],
        text_predictions: List[str],
        conf_scores: List[float],
        dims: Optional[List[Dict[str, Any]]] = None,
        img_index: int = 0,
        scale_coordinates: bool = True,
        shift_coordinates: bool = True,
    ) -> Tuple[List[List[List[float]]], List[str], List[float]]:
        """
        Map normalized bounding-box coordinates back to original-image pixel coordinates.

        Each point is computed as
        ``((coord * new_size) - pad) / scale_factor``, using the entry
        ``dims[img_index]``. Boxes given as the string "nan" are dropped together
        with their text and confidence.

        Parameters
        ----------
        bounding_boxes : list of Any
            One bounding box per text line, each a list of (x, y) points.
        text_predictions : list of str
            One text prediction per bounding box.
        conf_scores : list of float
            One confidence score per bounding box.
        dims : list of dict, optional
            One dict per image with:

            - "new_width" (int): the width of the image after processing.
            - "new_height" (int): the height of the image after processing.
            - "pad_width" (int, optional): padding added to the width (default 0).
            - "pad_height" (int, optional): padding added to the height (default 0).
            - "scale_factor" (float, optional): the resize factor applied (default 1.0).
        img_index : int, optional
            Index into ``dims`` for this image. If it is out of range, a warning is
            logged and index 0 is used. Default is 0.
        scale_coordinates : bool, optional
            If False, the width/height multiplication and the ``scale_factor``
            division are skipped.
        shift_coordinates : bool, optional
            If False, the padding offsets are not subtracted.

        Returns
        -------
        tuple of (list, list, list)
            The converted bounding boxes, texts and confidences, in input order. The
            output is truncated to the shortest of the three input lists.

        Raises
        ------
        ValueError
            If ``dims`` is None or empty.
        """
        if not dims:
            raise ValueError("No image_dims provided.")

        if img_index >= len(dims):
            logger.warning("Image index out of range for stored dimensions. Using first image dims by default.")
            img_index = 0

        img_dims = dims[img_index]
        max_width = img_dims["new_width"] if scale_coordinates else 1.0
        max_height = img_dims["new_height"] if scale_coordinates else 1.0
        pad_width = img_dims.get("pad_width", 0) if shift_coordinates else 0.0
        pad_height = img_dims.get("pad_height", 0) if shift_coordinates else 0.0
        scale_factor = img_dims.get("scale_factor", 1.0) if scale_coordinates else 1.0

        bboxes: List[List[List[float]]] = []
        texts: List[str] = []
        confs: List[float] = []

        for box, txt, conf in zip(bounding_boxes, text_predictions, conf_scores):
            # A "nan" string stands in for a box (presumably an empty detection).
            # Checking the type first avoids an ambiguous element-wise comparison if
            # a box is ever an ndarray.
            if isinstance(box, str) and box == "nan":
                continue
            # Undo the normalization, subtract the padding offset, then undo the resize.
            points = [
                [
                    (float(point[0]) * max_width - pad_width) / scale_factor,
                    (float(point[1]) * max_height - pad_height) / scale_factor,
                ]
                for point in box
            ]
            bboxes.append(points)
            texts.append(txt)
            confs.append(conf)

        return bboxes, texts, confs
370
+
371
+
372
class PaddleOCRModelInterface(OCRModelInterfaceBase):
    """
    An interface for handling inference with a legacy OCR model, supporting both gRPC and HTTP protocols.
    """

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface.
        """
        return "PaddleOCR"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Decode one or more base64-encoded images into NumPy arrays stored in ``data``.

        Parameters
        ----------
        data : dict of str -> Any
            The input data, containing either:
            - 'base64_images': a list of base64-encoded images (checked first), or
            - 'base64_image': a single base64-encoded image.

        Returns
        -------
        dict of str -> Any
            The same ``data`` dict, with "images" set to the list of decoded arrays.
            "image_dims" is not set here; ``format_input`` computes it.

        Raises
        ------
        KeyError
            If neither 'base64_image' nor 'base64_images' is found in `data`.
        ValueError
            If 'base64_images' is present but is not a list.
        """
        if "base64_images" in data:
            base64_list = data["base64_images"]
            if not isinstance(base64_list, list):
                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
            data["images"] = [base64_to_numpy(b64) for b64 in base64_list]

        elif "base64_image" in data:
            # Single-image fallback
            data["images"] = [base64_to_numpy(data["base64_image"])]

        else:
            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")

        return data

    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
        """
        Format input data for "grpc" or "http" inference, split into batches.

        Parameters
        ----------
        data : dict of str -> Any
            The input data dictionary. It must contain "images" (a list of
            np.ndarray), as produced by ``prepare_data_for_inference``. For "http",
            it must also still contain 'base64_images' or 'base64_image'. A fresh
            "image_dims" list is stored into ``data`` and filled by this call.
        protocol : str
            The inference protocol, either "grpc" or "http".
        max_batch_size : int
            The maximum number of images per batch; must be at least 1.
        **kwargs : Any
            Accepted for interface compatibility; not used here.

        Returns
        -------
        tuple
            A tuple (formatted_batches, formatted_batch_data) where:
            - formatted_batches is a list of batches ready for inference. For gRPC,
              each is a float32 array of preprocessed images concatenated along
              axis 0. For HTTP, each is an ``{"input": [...]}`` payload.
            - formatted_batch_data is a list of scratch-pad dicts, one per batch,
              holding that batch's "images" and "image_dims" for post-processing.

        Raises
        ------
        KeyError
            If "images" is not found in `data`, or, for "http", neither base64 key is.
        ValueError
            If ``max_batch_size`` is less than 1 or an invalid protocol is specified.
        """
        if "images" not in data:
            raise KeyError("Expected 'images' in data. Call prepare_data_for_inference first.")
        # A non-positive size would either crash range() (0) or silently yield no
        # batches at all (negative), dropping every image.
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be a positive integer, got {max_batch_size}.")

        images = data["images"]
        # Rebuilt on every call; the list stored in `data` is filled in place below.
        dims: List[Dict[str, Any]] = []
        data["image_dims"] = dims

        if protocol == "grpc":
            logger.debug("Formatting input for gRPC OCR model (batched).")
            per_image_inputs = self._prepare_grpc_inputs(images, dims)
        elif protocol == "http":
            logger.debug("Formatting input for HTTP OCR model (batched).")
            per_image_inputs = self._prepare_http_inputs(data, images, dims)
        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

        batches: List[Any] = []
        batch_data_list: List[Dict[str, Any]] = []
        for input_chunk, orig_chunk, dims_chunk in zip(
            self._chunk_list(per_image_inputs, max_batch_size),
            self._chunk_list(images, max_batch_size),
            self._chunk_list(dims, max_batch_size),
        ):
            if protocol == "grpc":
                batches.append(np.concatenate(input_chunk, axis=0))
            else:
                batches.append({"input": input_chunk})
            batch_data_list.append({"images": orig_chunk, "image_dims": dims_chunk})

        return batches, batch_data_list

    @staticmethod
    def _chunk_list(lst: List[Any], chunk_size: int) -> List[List[Any]]:
        """Split ``lst`` into consecutive chunks of at most ``chunk_size`` items."""
        return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

    @staticmethod
    def _prepare_grpc_inputs(images: List[np.ndarray], dims: List[Dict[str, Any]]) -> List[np.ndarray]:
        """Preprocess each image into a float32 array with a leading batch axis, appending its dims to ``dims``."""
        processed: List[np.ndarray] = []
        for img in images:
            arr, img_dims = preprocess_image_for_paddle(img)
            dims.append(img_dims)
            processed.append(np.expand_dims(arr.astype(np.float32), axis=0))
        return processed

    @staticmethod
    def _prepare_http_inputs(
        data: Dict[str, Any], images: List[np.ndarray], dims: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Build one image_url entry per base64 image, appending unpadded, unscaled dims to ``dims``."""
        base64_list = data["base64_images"] if "base64_images" in data else [data["base64_image"]]

        input_list: List[Dict[str, Any]] = []
        for b64, img in zip(base64_list, images):
            input_list.append({"type": "image_url", "url": f"data:image/png;base64,{b64}"})
            dims.append({"new_width": img.shape[1], "new_height": img.shape[0]})
        return input_list
534
+
535
+
536
class NemoRetrieverOCRModelInterface(OCRModelInterfaceBase):
    """
    An interface for handling inference with the NemoRetrieverOCR model, supporting both gRPC and HTTP protocols.

    The payload sent to the model carries the original base64-encoded image strings; the decoded arrays
    produced by `prepare_data_for_inference` are only used to record each image's dimensions.
    """

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface.
        """
        return "NemoRetrieverOCR"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Decode one or more base64-encoded images and store the decoded images in `data`.

        Parameters
        ----------
        data : dict of str -> Any
            The input data containing either:
            - 'base64_images': a list of base64-encoded images, or
            - 'base64_image': a single base64-encoded image.
            'base64_images' takes precedence when both are present.

        Returns
        -------
        dict of str -> Any
            The same dictionary (modified in place) with the key "images" added: a list of the
            decoded images (as returned by `base64_to_numpy`), one per input string, in input order.
            Presumably each is an (H, W, C) array — `format_input` reads `shape[:2]` as (height, width).

        Raises
        ------
        KeyError
            If neither 'base64_image' nor 'base64_images' is found in `data`.
        ValueError
            If 'base64_images' is present but is not a list.
        """
        if "base64_images" in data:
            base64_list = data["base64_images"]
            if not isinstance(base64_list, list):
                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
            data["images"] = [base64_to_numpy(b64) for b64 in base64_list]

        elif "base64_image" in data:
            # Single-image fallback.
            data["images"] = [base64_to_numpy(data["base64_image"])]

        else:
            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")

        return data

    def coalesce_requests_to_batch(
        self,
        requests: List[np.ndarray],
        original_image_shapes: List[Tuple[int, int]],
        protocol: str,
        **kwargs,
    ) -> Tuple[Any, Dict[str, Any]]:
        """
        Combine an already-formed batch from the dynamic batcher into a single formatted request.

        This mirrors the per-chunk logic of `format_input` but performs no chunking of its own.

        Parameters
        ----------
        requests : List[np.ndarray]
            The batch items. They are embedded into the request as-is (inside a base64 data URL for
            HTTP, inside an object array for gRPC), so they are expected to be base64-encoded image
            strings. NOTE(review): the annotation says np.ndarray — confirm what the batcher passes.
        original_image_shapes : List[Tuple[int, int]]
            (height, width) of each item, in the same order as `requests`.
        protocol : str
            The inference protocol, either "grpc" or "http".
        **kwargs : Any
            Forwarded to `_format_single_batch`; `merge_level` (default "paragraph") is read there.

        Returns
        -------
        Tuple[Any, Dict[str, Any]]
            `(formatted_batch, batch_data)` exactly as produced by `_format_single_batch`, or
            `(None, {})` when `requests` is empty.

        Raises
        ------
        ValueError
            If an invalid protocol is specified.
        """
        if not requests:
            return None, {}

        return self._format_single_batch(requests, original_image_shapes, protocol, **kwargs)

    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
        """
        Format input data for the specified protocol ("grpc" or "http"), splitting it into batches.

        Parameters
        ----------
        data : dict of str -> Any
            The input data dictionary as produced by `prepare_data_for_inference`. It must contain
            "images" (decoded images, used only for their (height, width)) plus the original
            base64 data — either the "base64_images" list or a single "base64_image" string —
            which is what is actually sent to the model.
        protocol : str
            The inference protocol, either "grpc" or "http".
        max_batch_size : int
            The maximum number of images per batch; must be at least 1.
        **kwargs : Any
            Forwarded to `_format_single_batch` (e.g. `merge_level`).

        Returns
        -------
        tuple
            A tuple (formatted_batches, formatted_batch_data) where:
            - formatted_batches is a list of batches ready for inference.
            - formatted_batch_data is a list of scratch-pad dictionaries, one per batch, each
              containing the key "image_dims" for later post-processing.
            As a side effect, `data["image_dims"]` is set to the concatenation of all batches' dims.

        Raises
        ------
        KeyError
            If "images" is missing from `data`, or neither "base64_images" nor "base64_image" is present.
        ValueError
            If `max_batch_size` is less than 1, or an invalid protocol is specified.
        """

        # Helper function to split a list into chunks of size up to chunk_size.
        def chunk_list(lst, chunk_size):
            return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

        if "images" not in data:
            raise KeyError("Expected 'images' in data. Call prepare_data_for_inference first.")
        if max_batch_size < 1:
            # A non-positive size would otherwise crash in range() (0) or silently emit no batches (<0).
            raise ValueError(f"max_batch_size must be at least 1, got {max_batch_size}.")

        # prepare_data_for_inference accepts a single 'base64_image' as well as a 'base64_images'
        # list; honor the same fallback here instead of assuming the list form is present.
        if "base64_images" in data:
            base64_list = data["base64_images"]
        else:
            base64_list = [data["base64_image"]]

        dims = [img.shape[:2] for img in data["images"]]

        formatted_batches = []
        formatted_batch_data = []

        image_chunks = chunk_list(base64_list, max_batch_size)
        dims_chunks = chunk_list(dims, max_batch_size)
        for image_chunk, dims_chunk in zip(image_chunks, dims_chunks):
            final_batch, batch_data = self._format_single_batch(image_chunk, dims_chunk, protocol, **kwargs)
            formatted_batches.append(final_batch)
            formatted_batch_data.append(batch_data)

        all_dims = [item for d in formatted_batch_data for item in d.get("image_dims", [])]
        data["image_dims"] = all_dims

        return formatted_batches, formatted_batch_data

    def _format_single_batch(
        self,
        batch_images: List[str],
        batch_dims: List[Tuple[int, int]],
        protocol: str,
        **kwargs,
    ) -> Tuple[Any, Dict[str, Any]]:
        """
        Build one inference request from a batch of base64-encoded images.

        Parameters
        ----------
        batch_images : List[str]
            Base64-encoded images.
        batch_dims : List[Tuple[int, int]]
            (height, width) of each image, in the same order as `batch_images`.
        protocol : str
            The inference protocol, either "grpc" or "http".
        **kwargs : Any
            `merge_level` (default "paragraph") is applied to every image in the batch.

        Returns
        -------
        Tuple[Any, Dict[str, Any]]
            - gRPC: `([images, merge_levels], {"image_dims": [...]})`, where both arrays are object
              arrays of shape (N, 1).
            - HTTP: `({"input": [...], "merge_levels": [...]}, {"image_dims": [...]})`.
            Each "image_dims" entry is `{"new_width": w, "new_height": h}`.

        Raises
        ------
        ValueError
            If an invalid protocol is specified, or (gRPC only) if the batch is empty, because there
            is nothing to concatenate.
        """
        dims: List[Dict[str, Any]] = []

        merge_level = kwargs.get("merge_level", "paragraph")

        if protocol == "grpc":
            logger.debug("Formatting input for gRPC OCR model (batched).")
            processed: List[np.ndarray] = []

            for img, shape in zip(batch_images, batch_dims):
                dims.append({"new_width": shape[1], "new_height": shape[0]})
                # Each image becomes a (1, 1) object array; concatenating along axis 0 gives (N, 1).
                processed.append(np.expand_dims(np.array([img], dtype=np.object_), axis=0))

            batched_input = np.concatenate(processed, axis=0)
            batch_size = batched_input.shape[0]

            # One merge level per image, shaped (N, 1) to match the image tensor.
            merge_levels = np.array([[merge_level] for _ in range(batch_size)], dtype="object")

            final_batch = [batched_input, merge_levels]
            batch_data = {"image_dims": dims}

            return final_batch, batch_data

        elif protocol == "http":
            logger.debug("Formatting input for HTTP OCR model (batched).")

            input_list: List[Dict[str, Any]] = []
            for b64, shape in zip(batch_images, batch_dims):
                input_list.append({"type": "image_url", "url": f"data:image/png;base64,{b64}"})
                dims.append({"new_width": shape[1], "new_height": shape[0]})

            payload = {
                "input": input_list,
                "merge_levels": [merge_level] * len(input_list),
            }

            batch_data = {"image_dims": dims}

            return payload, batch_data

        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
747
+
748
+
749
@multiprocessing_cache(max_calls=100)  # Cache results first to avoid redundant retries from backoff
@backoff.on_predicate(backoff.expo, max_time=30)
def get_ocr_model_name(ocr_grpc_endpoint=None, default_model_name=DEFAULT_OCR_MODEL_NAME):
    """
    Determine the OCR model name by checking the environment, querying the gRPC endpoint,
    or falling back to a default.

    Resolution order:

    1. A non-empty ``OCR_MODEL_NAME`` environment variable.
    2. ``default_model_name`` when no endpoint is given or the endpoint is an NVCF endpoint
       (contains ``grpc.nvcf.nvidia.com``).
    3. The first model listed in the endpoint's model repository index.
    4. ``default_model_name`` if the query fails or the index lists no models.

    Parameters
    ----------
    ocr_grpc_endpoint : str, optional
        gRPC endpoint of the OCR inference server.
    default_model_name : str
        Name returned when no other source yields a model name.

    Returns
    -------
    str
        The resolved OCR model name.

    Notes
    -----
    ``backoff.on_predicate`` retries (exponential backoff, for up to 30 seconds) whenever this
    function returns a falsy value. Errors while querying the endpoint are handled here and
    never propagate, so they do not trigger retries on their own.
    """
    # 1. Explicit override from the environment. An empty value is treated as unset: returning ""
    #    would only make the on_predicate decorator retry until max_time and then still yield "".
    ocr_model_name = os.getenv("OCR_MODEL_NAME", None)
    if ocr_model_name:
        return ocr_model_name

    # 2. If no gRPC endpoint is provided or the endpoint is a NVCF endpoint, fall back to the default immediately.
    if (not ocr_grpc_endpoint) or ("grpc.nvcf.nvidia.com" in ocr_grpc_endpoint):
        logger.debug(f"No OCR gRPC endpoint provided. Falling back to default model name '{default_model_name}'.")
        return default_model_name

    # 3. Attempt to query the gRPC endpoint to discover the model name.
    client = None
    try:
        client = grpcclient.InferenceServerClient(ocr_grpc_endpoint)
        model_index = client.get_model_repository_index(as_json=True)
        model_names = [x["name"] for x in model_index.get("models", [])]
        if not model_names:
            logger.warning(
                f"OCR gRPC endpoint '{ocr_grpc_endpoint}' reported no models. "
                f"Falling back to '{default_model_name}'."
            )
            return default_model_name
        return model_names[0]
    except Exception as e:
        logger.warning(
            f"Failed to get OCR model name from '{ocr_grpc_endpoint}': {e}. Falling back to '{default_model_name}'."
        )
        return default_model_name
    finally:
        # Release the gRPC channel; closing is best-effort and must not mask the result.
        if client is not None:
            try:
                client.close()
            except Exception:
                logger.debug("Failed to close OCR gRPC client.", exc_info=True)