nv_ingest_api-26.1.0rc4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- nv_ingest_api/__init__.py +3 -0
- nv_ingest_api/interface/__init__.py +218 -0
- nv_ingest_api/interface/extract.py +977 -0
- nv_ingest_api/interface/mutate.py +154 -0
- nv_ingest_api/interface/store.py +200 -0
- nv_ingest_api/interface/transform.py +382 -0
- nv_ingest_api/interface/utility.py +186 -0
- nv_ingest_api/internal/__init__.py +0 -0
- nv_ingest_api/internal/enums/__init__.py +3 -0
- nv_ingest_api/internal/enums/common.py +550 -0
- nv_ingest_api/internal/extract/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
- nv_ingest_api/internal/extract/docx/__init__.py +5 -0
- nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
- nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
- nv_ingest_api/internal/extract/html/__init__.py +3 -0
- nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
- nv_ingest_api/internal/extract/image/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
- nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
- nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
- nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
- nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
- nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
- nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
- nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
- nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
- nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
- nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
- nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
- nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
- nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
- nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
- nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
- nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
- nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
- nv_ingest_api/internal/meta/__init__.py +3 -0
- nv_ingest_api/internal/meta/udf.py +232 -0
- nv_ingest_api/internal/mutate/__init__.py +3 -0
- nv_ingest_api/internal/mutate/deduplicate.py +110 -0
- nv_ingest_api/internal/mutate/filter.py +133 -0
- nv_ingest_api/internal/primitives/__init__.py +0 -0
- nv_ingest_api/internal/primitives/control_message_task.py +16 -0
- nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
- nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
- nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
- nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
- nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
- nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
- nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
- nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
- nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
- nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
- nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
- nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
- nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
- nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
- nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
- nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
- nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
- nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
- nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
- nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
- nv_ingest_api/internal/schemas/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
- nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
- nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
- nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
- nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
- nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
- nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
- nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
- nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
- nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
- nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
- nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
- nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
- nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
- nv_ingest_api/internal/schemas/meta/udf.py +23 -0
- nv_ingest_api/internal/schemas/mixins.py +39 -0
- nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
- nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
- nv_ingest_api/internal/schemas/store/__init__.py +3 -0
- nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
- nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
- nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
- nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
- nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
- nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
- nv_ingest_api/internal/store/__init__.py +3 -0
- nv_ingest_api/internal/store/embed_text_upload.py +236 -0
- nv_ingest_api/internal/store/image_upload.py +251 -0
- nv_ingest_api/internal/transform/__init__.py +3 -0
- nv_ingest_api/internal/transform/caption_image.py +219 -0
- nv_ingest_api/internal/transform/embed_text.py +702 -0
- nv_ingest_api/internal/transform/split_text.py +182 -0
- nv_ingest_api/util/__init__.py +3 -0
- nv_ingest_api/util/control_message/__init__.py +0 -0
- nv_ingest_api/util/control_message/validators.py +47 -0
- nv_ingest_api/util/converters/__init__.py +0 -0
- nv_ingest_api/util/converters/bytetools.py +78 -0
- nv_ingest_api/util/converters/containers.py +65 -0
- nv_ingest_api/util/converters/datetools.py +90 -0
- nv_ingest_api/util/converters/dftools.py +127 -0
- nv_ingest_api/util/converters/formats.py +64 -0
- nv_ingest_api/util/converters/type_mappings.py +27 -0
- nv_ingest_api/util/dataloader/__init__.py +9 -0
- nv_ingest_api/util/dataloader/dataloader.py +409 -0
- nv_ingest_api/util/detectors/__init__.py +5 -0
- nv_ingest_api/util/detectors/language.py +38 -0
- nv_ingest_api/util/exception_handlers/__init__.py +0 -0
- nv_ingest_api/util/exception_handlers/converters.py +72 -0
- nv_ingest_api/util/exception_handlers/decorators.py +429 -0
- nv_ingest_api/util/exception_handlers/detectors.py +74 -0
- nv_ingest_api/util/exception_handlers/pdf.py +116 -0
- nv_ingest_api/util/exception_handlers/schemas.py +68 -0
- nv_ingest_api/util/image_processing/__init__.py +5 -0
- nv_ingest_api/util/image_processing/clustering.py +260 -0
- nv_ingest_api/util/image_processing/processing.py +177 -0
- nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
- nv_ingest_api/util/image_processing/transforms.py +850 -0
- nv_ingest_api/util/imports/__init__.py +3 -0
- nv_ingest_api/util/imports/callable_signatures.py +108 -0
- nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
- nv_ingest_api/util/introspection/__init__.py +3 -0
- nv_ingest_api/util/introspection/class_inspect.py +145 -0
- nv_ingest_api/util/introspection/function_inspect.py +65 -0
- nv_ingest_api/util/logging/__init__.py +0 -0
- nv_ingest_api/util/logging/configuration.py +102 -0
- nv_ingest_api/util/logging/sanitize.py +84 -0
- nv_ingest_api/util/message_brokers/__init__.py +3 -0
- nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
- nv_ingest_api/util/metadata/__init__.py +5 -0
- nv_ingest_api/util/metadata/aggregators.py +516 -0
- nv_ingest_api/util/multi_processing/__init__.py +8 -0
- nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
- nv_ingest_api/util/nim/__init__.py +161 -0
- nv_ingest_api/util/pdf/__init__.py +3 -0
- nv_ingest_api/util/pdf/pdfium.py +428 -0
- nv_ingest_api/util/schema/__init__.py +3 -0
- nv_ingest_api/util/schema/schema_validator.py +10 -0
- nv_ingest_api/util/service_clients/__init__.py +3 -0
- nv_ingest_api/util/service_clients/client_base.py +86 -0
- nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
- nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
- nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
- nv_ingest_api/util/string_processing/__init__.py +51 -0
- nv_ingest_api/util/string_processing/configuration.py +682 -0
- nv_ingest_api/util/string_processing/yaml.py +109 -0
- nv_ingest_api/util/system/__init__.py +0 -0
- nv_ingest_api/util/system/hardware_info.py +594 -0
- nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
- nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
- nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
- nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
- nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
- udfs/__init__.py +5 -0
- udfs/llm_summarizer_udf.py +259 -0
nv_ingest_api/internal/primitives/nim/model_interface/ocr.py
@@ -0,0 +1,776 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
import os
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import backoff
import numpy as np
import tritonclient.grpc as grpcclient

from nv_ingest_api.internal.primitives.nim import ModelInterface
from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
from nv_ingest_api.internal.primitives.nim.model_interface.helpers import preprocess_image_for_paddle
from nv_ingest_api.util.image_processing.transforms import base64_to_numpy

DEFAULT_OCR_MODEL_NAME = "scene_text_ensemble"
NEMORETRIEVER_OCR_MODEL_NAME = "scene_text_wrapper"
NEMORETRIEVER_OCR_ENSEMBLE_MODEL_NAME = "scene_text_ensemble"
NEMORETRIEVER_OCR_BLS_MODEL_NAME = "scene_text_python"


logger = logging.getLogger(__name__)


class OCRModelInterfaceBase(ModelInterface):

    NUM_CHANNELS = 3
    BYTES_PER_ELEMENT = 4  # For float32

    def parse_output(
        self,
        response: Any,
        protocol: str,
        data: Optional[Dict[str, Any]] = None,
        model_name: str = DEFAULT_OCR_MODEL_NAME,
        **kwargs: Any,
    ) -> Any:
        """
        Parse the model's inference response for the given protocol. The parsing
        may handle batched outputs for multiple images.

        Parameters
        ----------
        response : Any
            The raw response from the OCR model.
        protocol : str
            The protocol used for inference, "grpc" or "http".
        data : dict of str -> Any, optional
            Additional data dictionary that may include "image_dims" for bounding box scaling.
        model_name : str, optional
            The model name, used to select the gRPC output layout.
        **kwargs : Any
            Additional keyword arguments, such as a custom `table_content_format`.

        Returns
        -------
        Any
            The parsed output, typically a list of (content, table_content_format) tuples.

        Raises
        ------
        ValueError
            If an invalid protocol is specified.
        """
        # Retrieve image dimensions if available
        dims: Optional[List[Dict[str, Any]]] = data.get("image_dims") if data else None

        if protocol == "grpc":
            logger.debug("Parsing output from gRPC OCR model (batched).")
            return self._extract_content_from_ocr_grpc_response(response, dims, model_name=model_name)

        elif protocol == "http":
            logger.debug("Parsing output from HTTP OCR model (batched).")
            return self._extract_content_from_ocr_http_response(response, dims)

        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def process_inference_results(self, output: Any, **kwargs: Any) -> Any:
        """
        Process inference results for the OCR model.

        Parameters
        ----------
        output : Any
            The raw output parsed from the OCR model.
        **kwargs : Any
            Additional keyword arguments for customization.

        Returns
        -------
        Any
            The post-processed inference results. By default, this simply returns the output
            as the table content (or content list).
        """
        return output

    def does_item_fit_in_batch(self, current_batch, next_request, memory_budget_bytes: int) -> bool:
        """
        Estimates the memory of a potential batch of padded images and checks it
        against the configured budget.
        """
        all_requests = current_batch + [next_request]
        all_dims = [req.dims for req in all_requests]

        potential_max_h = max(d[0] for d in all_dims)
        potential_max_w = max(d[1] for d in all_dims)

        potential_batch_size = len(all_requests)

        potential_memory_bytes = (
            potential_batch_size * potential_max_h * potential_max_w * self.NUM_CHANNELS * self.BYTES_PER_ELEMENT
        )

        return potential_memory_bytes <= memory_budget_bytes

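Editor's note: the batch-fitting check above is a worst-case estimate, since every image in the candidate batch is costed as if padded up to the batch's maximum height and width. A minimal sketch of the arithmetic, with invented sizes:

# Hypothetical illustration (not part of the package): seven 640x480 images
# plus one incoming 1024x768 image would all pad to 1024x768 before stacking.
batch_size = 8
max_h, max_w = 1024, 768
num_channels, bytes_per_element = 3, 4  # RGB float32, matching the class constants
estimate = batch_size * max_h * max_w * num_channels * bytes_per_element
print(estimate)  # 75497472 bytes (~72 MiB): fits a 100 MB budget, exceeds 64 MiB
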
    def _prepare_ocr_payload(self, base64_img: str) -> Dict[str, Any]:
        """
        DEPRECATED by batch logic in format_input. Kept here if you need single-image direct calls.

        Parameters
        ----------
        base64_img : str
            A single base64-encoded image string.

        Returns
        -------
        dict of str -> Any
            The payload in either legacy or new format for OCR's HTTP endpoint.
        """
        image_url = f"data:image/png;base64,{base64_img}"

        image = {"type": "image_url", "url": image_url}
        payload = {"input": [image]}

        return payload

    def _extract_content_from_ocr_http_response(
        self,
        json_response: Dict[str, Any],
        dimensions: List[Dict[str, Any]],
    ) -> List[List[Any]]:
        """
        Extract content from the JSON response of an OCR HTTP API request.

        Parameters
        ----------
        json_response : dict of str -> Any
            The JSON response returned by the OCR endpoint.
        dimensions : list of dict
            A list of dicts, one per image, used for bounding box scaling.

        Returns
        -------
        list of [list, list, list]
            A [bounding_boxes, text_predictions, confidence_scores] triple for each image result.

        Raises
        ------
        RuntimeError
            If the response format is missing or invalid.
        """
        if "data" not in json_response or not json_response["data"]:
            raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")

        results: List[List[Any]] = []
        for item_idx, item in enumerate(json_response["data"]):
            text_detections = item.get("text_detections", [])
            text_predictions = []
            bounding_boxes = []
            conf_scores = []
            for td in text_detections:
                text_predictions.append(td["text_prediction"]["text"])
                bounding_boxes.append([[pt["x"], pt["y"]] for pt in td["bounding_box"]["points"]])
                conf_scores.append(td["text_prediction"]["confidence"])

            bounding_boxes, text_predictions, conf_scores = self._postprocess_ocr_response(
                bounding_boxes,
                text_predictions,
                conf_scores,
                dimensions,
                img_index=item_idx,
            )

            results.append([bounding_boxes, text_predictions, conf_scores])

        return results

    def _extract_content_from_ocr_grpc_response(
        self,
        response: np.ndarray,
        dimensions: List[Dict[str, Any]],
        model_name: str = DEFAULT_OCR_MODEL_NAME,
    ) -> List[List[Any]]:
        """
        Parse a gRPC response for one or more images. The response can have two possible shapes:
          - (3,) for batch_size=1
          - (3, n) for batch_size=n

        In either case:
          response[0, i]: byte string containing bounding box data
          response[1, i]: byte string containing text prediction data
          response[2, i]: (Optional) additional data/metadata (ignored here)

        Parameters
        ----------
        response : np.ndarray
            The raw NumPy array from gRPC. Expected shape: (3,) or (3, n).
        dimensions : list of dict
            A list of dicts, one per image, used for bounding box scaling.
        model_name : str, optional
            The model name; NemoRetriever OCR models return a transposed layout.

        Returns
        -------
        list of [list, list, list]
            A [bounding_boxes, text_predictions, confidence_scores] triple for each image.

        Raises
        ------
        ValueError
            If the response is not a NumPy array or has an unexpected shape.
        """
        if not isinstance(response, np.ndarray):
            raise ValueError("Unexpected response format: response is not a NumPy array.")

        if model_name in [
            NEMORETRIEVER_OCR_MODEL_NAME,
            NEMORETRIEVER_OCR_ENSEMBLE_MODEL_NAME,
            NEMORETRIEVER_OCR_BLS_MODEL_NAME,
        ]:
            response = response.transpose((1, 0))

        # If we have shape (3,), convert to (3, 1)
        if response.ndim == 1 and response.shape == (3,):
            response = response.reshape(3, 1)
        elif response.ndim != 2 or response.shape[0] != 3:
            raise ValueError(f"Unexpected response shape: {response.shape}. Expecting (3,) or (3, n).")
        batch_size = response.shape[1]
        results: List[List[Any]] = []

        for i in range(batch_size):
            # 1) Parse bounding boxes
            bboxes_bytestr: bytes = response[0, i]
            bounding_boxes = json.loads(bboxes_bytestr.decode("utf8"))

            # 2) Parse text predictions
            texts_bytestr: bytes = response[1, i]
            text_predictions = json.loads(texts_bytestr.decode("utf8"))

            # 3) Parse confidence scores
            confs_bytestr: bytes = response[2, i]
            conf_scores = json.loads(confs_bytestr.decode("utf8"))

            # Some gRPC responses nest single-item lists; flatten them if needed
            if (
                (isinstance(bounding_boxes, list) and len(bounding_boxes) == 1 and isinstance(bounding_boxes[0], list))
                and (
                    isinstance(text_predictions, list)
                    and len(text_predictions) == 1
                    and isinstance(text_predictions[0], list)
                )
                and (isinstance(conf_scores, list) and len(conf_scores) == 1 and isinstance(conf_scores[0], list))
            ):
                bounding_boxes = bounding_boxes[0]
                text_predictions = text_predictions[0]
                conf_scores = conf_scores[0]

            # 4) Postprocess
            bounding_boxes, text_predictions, conf_scores = self._postprocess_ocr_response(
                bounding_boxes,
                text_predictions,
                conf_scores,
                dimensions,
                img_index=i,
                scale_coordinates=True,
            )

            results.append([bounding_boxes, text_predictions, conf_scores])

        return results

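Editor's note: a minimal sketch of the gRPC payload layout this parser expects, built from synthetic data and assuming the legacy (non-transposed) layout:

import json
import numpy as np

# One image (batch_size=1): three JSON byte strings for boxes, texts, confidences.
boxes = json.dumps([[[0.1, 0.2], [0.3, 0.2], [0.3, 0.4], [0.1, 0.4]]]).encode("utf8")
texts = json.dumps(["hello"]).encode("utf8")
confs = json.dumps([0.97]).encode("utf8")
response = np.array([boxes, texts, confs], dtype=np.object_)  # shape (3,)
# The parser reshapes (3,) to (3, 1), json-decodes each cell, flattens nested
# single-item lists, and hands the result to _postprocess_ocr_response.
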
    @staticmethod
    def _postprocess_ocr_response(
        bounding_boxes: List[Any],
        text_predictions: List[str],
        conf_scores: List[float],
        dims: Optional[List[Dict[str, Any]]] = None,
        img_index: int = 0,
        scale_coordinates: bool = True,
        shift_coordinates: bool = True,
    ) -> Tuple[List[Any], List[str], List[float]]:
        """
        Convert bounding boxes with normalized coordinates to pixel coordinates using
        the image dimensions, and shift the coordinates back if the inputs were padded.
        For multiple images, the correct image dimensions (height, width) are retrieved
        from `dims[img_index]`.

        Parameters
        ----------
        bounding_boxes : list of Any
            A list (per line of text) of bounding boxes, each a list of (x, y) points.
        text_predictions : list of str
            A list of text predictions, one for each bounding box.
        conf_scores : list of float
            A list of confidence scores, one for each bounding box.
        dims : list of dict, optional
            A list of dictionaries, where each dictionary contains image-specific dimensions
            and scaling information:
            - "new_width" (int): The width of the image after processing.
            - "new_height" (int): The height of the image after processing.
            - "pad_width" (int, optional): The width of padding added to the image.
            - "pad_height" (int, optional): The height of padding added to the image.
            - "scale_factor" (float, optional): The scaling factor applied to the image.
        img_index : int, optional
            The index of the image for which bounding boxes are being converted. Default is 0.
        scale_coordinates : bool, optional
            Whether to rescale normalized coordinates to pixel coordinates. Default is True.
        shift_coordinates : bool, optional
            Whether to subtract padding offsets from the coordinates. Default is True.

        Returns
        -------
        Tuple[List[Any], List[str], List[float]]
            Bounding boxes scaled back to the original dimensions, the detected text lines,
            and their confidence scores.

        Notes
        -----
        - Raises ValueError if `dims` is missing; if `img_index` is out of range, the first
          image's dimensions are used as a fallback.
        """
        # Require dims; fall back to the first image's dims if the index is out of range.
        if not dims:
            raise ValueError("No image_dims provided.")
        else:
            if img_index >= len(dims):
                logger.warning("Image index out of range for stored dimensions. Using first image dims by default.")
                img_index = 0

        max_width = dims[img_index]["new_width"] if scale_coordinates else 1.0
        max_height = dims[img_index]["new_height"] if scale_coordinates else 1.0
        pad_width = dims[img_index].get("pad_width", 0) if shift_coordinates else 0.0
        pad_height = dims[img_index].get("pad_height", 0) if shift_coordinates else 0.0
        scale_factor = dims[img_index].get("scale_factor", 1.0) if scale_coordinates else 1.0

        bboxes: List[List[float]] = []
        texts: List[str] = []
        confs: List[float] = []

        # Convert normalized coords back to actual pixel coords
        for box, txt, conf in zip(bounding_boxes, text_predictions, conf_scores):
            if box == "nan":
                continue
            points: List[List[float]] = []
            for point in box:
                # Convert normalized coords back to actual pixel coords,
                # and shift them back to their original positions if padded.
                x_pixels = float(point[0]) * max_width - pad_width
                y_pixels = float(point[1]) * max_height - pad_height
                x_original = x_pixels / scale_factor
                y_original = y_pixels / scale_factor
                points.append([x_original, y_original])
            bboxes.append(points)
            texts.append(txt)
            confs.append(conf)

        return bboxes, texts, confs


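Editor's note: the inverse mapping applied above is original = (normalized * new_size - pad) / scale_factor. A short worked example with invented dimensions:

# Hypothetical: an image upscaled by 2.0, then padded by 12 px on each axis.
dims = [{"new_width": 812, "new_height": 612, "pad_width": 12, "pad_height": 12, "scale_factor": 2.0}]
x_norm, y_norm = 0.5, 0.5  # a point at the center of the padded input
x_original = (x_norm * 812 - 12) / 2.0  # (406 - 12) / 2 = 197.0
y_original = (y_norm * 612 - 12) / 2.0  # (306 - 12) / 2 = 147.0
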
class PaddleOCRModelInterface(OCRModelInterfaceBase):
    """
    An interface for handling inference with a legacy OCR model, supporting both gRPC and HTTP protocols.
    """

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface.
        """
        return "PaddleOCR"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Decode one or more base64-encoded images into NumPy arrays, storing them in `data`.

        Parameters
        ----------
        data : dict of str -> Any
            The input data containing either:
            - 'base64_image': a single base64-encoded image, or
            - 'base64_images': a list of base64-encoded images.

        Returns
        -------
        dict of str -> Any
            The updated data dictionary with the key "images" added: a list of decoded
            NumPy arrays of shape (H, W, C). Per-image dimensions are recorded later,
            under "image_dims", by format_input.

        Raises
        ------
        KeyError
            If neither 'base64_image' nor 'base64_images' is found in `data`.
        ValueError
            If 'base64_images' is present but is not a list.
        """
        if "base64_images" in data:
            base64_list = data["base64_images"]
            if not isinstance(base64_list, list):
                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")

            images: List[np.ndarray] = []
            for b64 in base64_list:
                img = base64_to_numpy(b64)
                images.append(img)

            data["images"] = images

        elif "base64_image" in data:
            # Single-image fallback
            img = base64_to_numpy(data["base64_image"])
            data["images"] = [img]

        else:
            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")

        return data

    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
        """
        Format input data for the specified protocol ("grpc" or "http"), supporting batched data.

        Parameters
        ----------
        data : dict of str -> Any
            The input data dictionary, expected to contain "images" (list of np.ndarray)
            as produced by prepare_data_for_inference.
        protocol : str
            The inference protocol, either "grpc" or "http".
        max_batch_size : int
            The maximum batch size for batching.

        Returns
        -------
        tuple
            A tuple (formatted_batches, formatted_batch_data) where:
            - formatted_batches is a list of batches ready for inference.
            - formatted_batch_data is a list of scratch-pad dictionaries corresponding to each batch,
              containing the keys "images" and "image_dims" for later post-processing.

        Raises
        ------
        KeyError
            If "images" is not found in `data`.
        ValueError
            If an invalid protocol is specified.
        """
        if "images" not in data:
            raise KeyError("Expected 'images' in data. Call prepare_data_for_inference first.")

        images = data["images"]

        # Per-image dims are collected here and shared back into `data` for parse_output.
        dims: List[Dict[str, Any]] = []
        data["image_dims"] = dims

        # Helper function to split a list into chunks of size up to chunk_size.
        def chunk_list(lst, chunk_size):
            return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

        if protocol == "grpc":
            logger.debug("Formatting input for gRPC OCR model (batched).")
            processed: List[np.ndarray] = []

            for img in images:
                arr, _dims = preprocess_image_for_paddle(img)
                dims.append(_dims)
                arr = arr.astype(np.float32)
                arr = np.expand_dims(arr, axis=0)
                processed.append(arr)

            batches = []
            batch_data_list = []
            for proc_chunk, orig_chunk, dims_chunk in zip(
                chunk_list(processed, max_batch_size),
                chunk_list(images, max_batch_size),
                chunk_list(dims, max_batch_size),
            ):
                batched_input = np.concatenate(proc_chunk, axis=0)
                batches.append(batched_input)
                batch_data_list.append({"images": orig_chunk, "image_dims": dims_chunk})
            return batches, batch_data_list

        elif protocol == "http":
            logger.debug("Formatting input for HTTP OCR model (batched).")
            if "base64_images" in data:
                base64_list = data["base64_images"]
            else:
                base64_list = [data["base64_image"]]

            input_list: List[Dict[str, Any]] = []
            for b64, img in zip(base64_list, images):
                image_url = f"data:image/png;base64,{b64}"
                image_obj = {"type": "image_url", "url": image_url}
                input_list.append(image_obj)
                _dims = {"new_width": img.shape[1], "new_height": img.shape[0]}
                dims.append(_dims)

            batches = []
            batch_data_list = []
            for input_chunk, orig_chunk, dims_chunk in zip(
                chunk_list(input_list, max_batch_size),
                chunk_list(images, max_batch_size),
                chunk_list(dims, max_batch_size),
            ):
                payload = {"input": input_chunk}
                batches.append(payload)
                batch_data_list.append({"images": orig_chunk, "image_dims": dims_chunk})

            return batches, batch_data_list

        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")


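Editor's note: a hedged end-to-end sketch of the HTTP path through PaddleOCRModelInterface. The file name is hypothetical, and the POST itself is presumably handled by the package's NimClient rather than called directly like this:

import base64

from nv_ingest_api.internal.primitives.nim.model_interface.ocr import PaddleOCRModelInterface

with open("page.png", "rb") as f:  # hypothetical input image
    b64 = base64.b64encode(f.read()).decode("utf-8")

iface = PaddleOCRModelInterface()
data = iface.prepare_data_for_inference({"base64_image": b64})
batches, batch_data = iface.format_input(data, protocol="http", max_batch_size=8)
# batches[0] == {"input": [{"type": "image_url", "url": "data:image/png;base64,..."}]}
# After the HTTP call returns json_response:
# results = iface.parse_output(json_response, protocol="http", data=batch_data[0])
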
class NemoRetrieverOCRModelInterface(OCRModelInterfaceBase):
    """
    An interface for handling inference with the NemoRetrieverOCR model, supporting both gRPC and HTTP protocols.
    """

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface.
        """
        return "NemoRetrieverOCR"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Decode one or more base64-encoded images into NumPy arrays, storing them in `data`.

        Parameters
        ----------
        data : dict of str -> Any
            The input data containing either:
            - 'base64_image': a single base64-encoded image, or
            - 'base64_images': a list of base64-encoded images.

        Returns
        -------
        dict of str -> Any
            The updated data dictionary with the key "images" added: a list of decoded
            NumPy arrays of shape (H, W, C). Per-image dimensions are recorded later,
            under "image_dims", by format_input.

        Raises
        ------
        KeyError
            If neither 'base64_image' nor 'base64_images' is found in `data`.
        ValueError
            If 'base64_images' is present but is not a list.
        """
        if "base64_images" in data:
            base64_list = data["base64_images"]
            if not isinstance(base64_list, list):
                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")

            images: List[np.ndarray] = []
            for b64 in base64_list:
                img = base64_to_numpy(b64)
                images.append(img)

            data["images"] = images

        elif "base64_image" in data:
            # Single-image fallback
            img = base64_to_numpy(data["base64_image"])
            data["images"] = [img]

        else:
            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")

        return data

    def coalesce_requests_to_batch(
        self,
        requests: List[np.ndarray],
        original_image_shapes: List[Tuple[int, int]],
        protocol: str,
        **kwargs,
    ) -> Tuple[List[Any], List[Dict[str, Any]]]:
        """
        Takes a list of individual data items (NumPy image arrays) and combines them
        into a single formatted batch ready for inference.

        This method mirrors the logic of `format_input` but operates on an already-formed
        batch from the dynamic batcher, so it does not perform any chunking.

        Parameters
        ----------
        requests : List[np.ndarray]
            A list of single data items, which are NumPy arrays representing images.
        original_image_shapes : List[Tuple[int, int]]
            The original (height, width) of each image, used for bounding box scaling.
        protocol : str
            The inference protocol, either "grpc" or "http".
        **kwargs : Any
            Additional keyword arguments, such as `model_name` and `merge_level`.

        Returns
        -------
        Tuple[List[Any], List[Dict[str, Any]]]
            A tuple containing two lists, each with a single element:
            - The first list contains the single formatted batch.
            - The second list contains the single scratch-pad dictionary for that batch.
        """
        if not requests:
            return None, {}

        return self._format_single_batch(requests, original_image_shapes, protocol, **kwargs)

    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
        """
        Format input data for the specified protocol ("grpc" or "http"), supporting batched data.

        Parameters
        ----------
        data : dict of str -> Any
            The input data dictionary, expected to contain "base64_images" (list of str)
            and "images" (list of np.ndarray), as produced by prepare_data_for_inference.
        protocol : str
            The inference protocol, either "grpc" or "http".
        max_batch_size : int
            The maximum batch size for batching.

        Returns
        -------
        tuple
            A tuple (formatted_batches, formatted_batch_data) where:
            - formatted_batches is a list of batches ready for inference.
            - formatted_batch_data is a list of scratch-pad dictionaries corresponding to each batch,
              containing the key "image_dims" for later post-processing.

        Raises
        ------
        KeyError
            If "images" or "base64_images" is not found in `data`.
        ValueError
            If an invalid protocol is specified.
        """

        # Helper function to split a list into chunks of size up to chunk_size.
        def chunk_list(lst, chunk_size):
            return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

        if "images" not in data:
            raise KeyError("Expected 'images' in data. Call prepare_data_for_inference first.")

        images = data["base64_images"]
        dims = [img.shape[:2] for img in data["images"]]

        formatted_batches = []
        formatted_batch_data = []

        image_chunks = chunk_list(images, max_batch_size)
        dims_chunks = chunk_list(dims, max_batch_size)
        for image_chunk, dims_chunk in zip(image_chunks, dims_chunks):
            final_batch, batch_data = self._format_single_batch(image_chunk, dims_chunk, protocol, **kwargs)
            formatted_batches.append(final_batch)
            formatted_batch_data.append(batch_data)

        all_dims = [item for d in formatted_batch_data for item in d.get("image_dims", [])]
        data["image_dims"] = all_dims

        return formatted_batches, formatted_batch_data

    def _format_single_batch(
        self,
        batch_images: List[str],
        batch_dims: List[Tuple[int, int]],
        protocol: str,
        **kwargs,
    ) -> Tuple[Any, Dict[str, Any]]:
        dims: List[Dict[str, Any]] = []

        merge_level = kwargs.get("merge_level", "paragraph")

        if protocol == "grpc":
            logger.debug("Formatting input for gRPC OCR model (batched).")
            processed: List[np.ndarray] = []

            for img, shape in zip(batch_images, batch_dims):
                _dims = {"new_width": shape[1], "new_height": shape[0]}
                dims.append(_dims)

                arr = np.array([img], dtype=np.object_)
                arr = np.expand_dims(arr, axis=0)
                processed.append(arr)

            batched_input = np.concatenate(processed, axis=0)

            batch_size = batched_input.shape[0]

            merge_levels_list = [[merge_level] for _ in range(batch_size)]
            merge_levels = np.array(merge_levels_list, dtype="object")

            final_batch = [batched_input, merge_levels]
            batch_data = {"image_dims": dims}

            return final_batch, batch_data

        elif protocol == "http":
            logger.debug("Formatting input for HTTP OCR model (batched).")

            input_list: List[Dict[str, Any]] = []
            for b64, shape in zip(batch_images, batch_dims):
                image_url = f"data:image/png;base64,{b64}"
                image_obj = {"type": "image_url", "url": image_url}
                input_list.append(image_obj)
                _dims = {"new_width": shape[1], "new_height": shape[0]}
                dims.append(_dims)

            payload = {
                "input": input_list,
                "merge_levels": [merge_level] * len(input_list),
            }

            batch_data = {"image_dims": dims}

            return payload, batch_data

        else:
            raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")


@multiprocessing_cache(max_calls=100)  # Cache results first to avoid redundant retries from backoff
@backoff.on_predicate(backoff.expo, max_time=30)
def get_ocr_model_name(ocr_grpc_endpoint=None, default_model_name=DEFAULT_OCR_MODEL_NAME):
    """
    Determines the OCR model name by checking the environment, querying the gRPC endpoint,
    or falling back to a default.
    """
    # 1. Check for an explicit override from the environment variable first.
    ocr_model_name = os.getenv("OCR_MODEL_NAME", None)
    if ocr_model_name is not None:
        return ocr_model_name

    # 2. If no gRPC endpoint is provided or the endpoint is an NVCF endpoint, fall back to the default immediately.
    if (not ocr_grpc_endpoint) or ("grpc.nvcf.nvidia.com" in ocr_grpc_endpoint):
        logger.debug(f"No OCR gRPC endpoint provided. Falling back to default model name '{default_model_name}'.")
        return default_model_name

    # 3. Attempt to query the gRPC endpoint to discover the model name.
    try:
        client = grpcclient.InferenceServerClient(ocr_grpc_endpoint)
        model_index = client.get_model_repository_index(as_json=True)
        model_names = [x["name"] for x in model_index.get("models", [])]
        ocr_model_name = model_names[0]
    except Exception:
        logger.warning(f"Failed to get OCR model name after 30 seconds. Falling back to '{default_model_name}'.")
        ocr_model_name = default_model_name

    return ocr_model_name
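Editor's note: a minimal usage sketch for the discovery helper above. The endpoint value is hypothetical, and because @multiprocessing_cache memoizes results, the environment override should be set before the first call:

import os

from nv_ingest_api.internal.primitives.nim.model_interface.ocr import get_ocr_model_name

os.environ["OCR_MODEL_NAME"] = "scene_text_python"  # explicit override; no gRPC query is made
print(get_ocr_model_name("localhost:8001"))  # hypothetical endpoint -> "scene_text_python"

# Without the override: a missing or NVCF endpoint returns the default
# "scene_text_ensemble"; otherwise the first model in the Triton model
# repository index is used, with backoff retries for up to 30 seconds.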