nv-ingest-api 2025.9.22.dev20250922__py3-none-any.whl → 2025.9.23.dev20250923__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nv-ingest-api might be problematic. Click here for more details.
- nv_ingest_api/internal/extract/image/chart_extractor.py +38 -35
- nv_ingest_api/internal/extract/image/image_helpers/common.py +0 -4
- nv_ingest_api/internal/extract/image/infographic_extractor.py +23 -25
- nv_ingest_api/internal/extract/image/table_extractor.py +43 -30
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +0 -7
- nv_ingest_api/internal/primitives/nim/__init__.py +2 -1
- nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +8 -4
- nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +402 -211
- nv_ingest_api/internal/primitives/nim/nim_client.py +271 -16
- nv_ingest_api/internal/primitives/nim/nim_model_interface.py +45 -0
- nv_ingest_api/util/nim/__init__.py +19 -6
- {nv_ingest_api-2025.9.22.dev20250922.dist-info → nv_ingest_api-2025.9.23.dev20250923.dist-info}/METADATA +1 -1
- {nv_ingest_api-2025.9.22.dev20250922.dist-info → nv_ingest_api-2025.9.23.dev20250923.dist-info}/RECORD +16 -16
- {nv_ingest_api-2025.9.22.dev20250922.dist-info → nv_ingest_api-2025.9.23.dev20250923.dist-info}/WHEEL +0 -0
- {nv_ingest_api-2025.9.22.dev20250922.dist-info → nv_ingest_api-2025.9.23.dev20250923.dist-info}/licenses/LICENSE +0 -0
- {nv_ingest_api-2025.9.22.dev20250922.dist-info → nv_ingest_api-2025.9.23.dev20250923.dist-info}/top_level.txt +0 -0
|
@@ -16,225 +16,20 @@ import numpy as np
|
|
|
16
16
|
import tritonclient.grpc as grpcclient
|
|
17
17
|
|
|
18
18
|
from nv_ingest_api.internal.primitives.nim import ModelInterface
|
|
19
|
-
from nv_ingest_api.internal.primitives.nim.model_interface.decorators import
|
|
20
|
-
|
|
21
|
-
)
|
|
22
|
-
from nv_ingest_api.internal.primitives.nim.model_interface.helpers import (
|
|
23
|
-
preprocess_image_for_ocr,
|
|
24
|
-
)
|
|
25
|
-
from nv_ingest_api.internal.primitives.nim.model_interface.helpers import (
|
|
26
|
-
preprocess_image_for_paddle,
|
|
27
|
-
)
|
|
19
|
+
from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
|
|
20
|
+
from nv_ingest_api.internal.primitives.nim.model_interface.helpers import preprocess_image_for_paddle
|
|
28
21
|
from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
|
|
29
|
-
from nv_ingest_api.util.image_processing.transforms import numpy_to_base64
|
|
30
22
|
|
|
31
23
|
DEFAULT_OCR_MODEL_NAME = "paddle"
|
|
32
|
-
NEMORETRIEVER_OCR_EA_MODEL_NAME = "scene_text"
|
|
33
24
|
NEMORETRIEVER_OCR_MODEL_NAME = "scene_text_ensemble"
|
|
34
25
|
|
|
35
26
|
logger = logging.getLogger(__name__)
|
|
36
27
|
|
|
37
28
|
|
|
38
|
-
class
|
|
39
|
-
"""
|
|
40
|
-
An interface for handling inference with a OCR model, supporting both gRPC and HTTP protocols.
|
|
41
|
-
"""
|
|
42
|
-
|
|
43
|
-
def name(self) -> str:
|
|
44
|
-
"""
|
|
45
|
-
Get the name of the model interface.
|
|
46
|
-
|
|
47
|
-
Returns
|
|
48
|
-
-------
|
|
49
|
-
str
|
|
50
|
-
The name of the model interface.
|
|
51
|
-
"""
|
|
52
|
-
return "OCR"
|
|
53
|
-
|
|
54
|
-
def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
55
|
-
"""
|
|
56
|
-
Decode one or more base64-encoded images into NumPy arrays, storing them
|
|
57
|
-
alongside their dimensions in `data`.
|
|
58
|
-
|
|
59
|
-
Parameters
|
|
60
|
-
----------
|
|
61
|
-
data : dict of str -> Any
|
|
62
|
-
The input data containing either:
|
|
63
|
-
- 'base64_image': a single base64-encoded image, or
|
|
64
|
-
- 'base64_images': a list of base64-encoded images.
|
|
65
|
-
|
|
66
|
-
Returns
|
|
67
|
-
-------
|
|
68
|
-
dict of str -> Any
|
|
69
|
-
The updated data dictionary with the following keys added:
|
|
70
|
-
- "image_arrays": List of decoded NumPy arrays of shape (H, W, C).
|
|
71
|
-
- "image_dims": List of (height, width) tuples for each decoded image.
|
|
72
|
-
|
|
73
|
-
Raises
|
|
74
|
-
------
|
|
75
|
-
KeyError
|
|
76
|
-
If neither 'base64_image' nor 'base64_images' is found in `data`.
|
|
77
|
-
ValueError
|
|
78
|
-
If 'base64_images' is present but is not a list.
|
|
79
|
-
"""
|
|
80
|
-
if "base64_images" in data:
|
|
81
|
-
base64_list = data["base64_images"]
|
|
82
|
-
if not isinstance(base64_list, list):
|
|
83
|
-
raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
|
|
84
|
-
|
|
85
|
-
image_arrays: List[np.ndarray] = []
|
|
86
|
-
for b64 in base64_list:
|
|
87
|
-
img = base64_to_numpy(b64)
|
|
88
|
-
image_arrays.append(img)
|
|
89
|
-
|
|
90
|
-
data["image_arrays"] = image_arrays
|
|
91
|
-
|
|
92
|
-
elif "base64_image" in data:
|
|
93
|
-
# Single-image fallback
|
|
94
|
-
img = base64_to_numpy(data["base64_image"])
|
|
95
|
-
data["image_arrays"] = [img]
|
|
96
|
-
|
|
97
|
-
else:
|
|
98
|
-
raise KeyError("Input data must include 'base64_image' or 'base64_images'.")
|
|
99
|
-
|
|
100
|
-
return data
|
|
101
|
-
|
|
102
|
-
def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
|
|
103
|
-
"""
|
|
104
|
-
Format input data for the specified protocol ("grpc" or "http"), supporting batched data.
|
|
105
|
-
|
|
106
|
-
Parameters
|
|
107
|
-
----------
|
|
108
|
-
data : dict of str -> Any
|
|
109
|
-
The input data dictionary, expected to contain "image_arrays" (list of np.ndarray)
|
|
110
|
-
and "image_dims" (list of (height, width) tuples), as produced by prepare_data_for_inference.
|
|
111
|
-
protocol : str
|
|
112
|
-
The inference protocol, either "grpc" or "http".
|
|
113
|
-
max_batch_size : int
|
|
114
|
-
The maximum batch size for batching.
|
|
115
|
-
|
|
116
|
-
Returns
|
|
117
|
-
-------
|
|
118
|
-
tuple
|
|
119
|
-
A tuple (formatted_batches, formatted_batch_data) where:
|
|
120
|
-
- formatted_batches is a list of batches ready for inference.
|
|
121
|
-
- formatted_batch_data is a list of scratch-pad dictionaries corresponding to each batch,
|
|
122
|
-
containing the keys "image_arrays" and "image_dims" for later post-processing.
|
|
123
|
-
|
|
124
|
-
Raises
|
|
125
|
-
------
|
|
126
|
-
KeyError
|
|
127
|
-
If either "image_arrays" or "image_dims" is not found in `data`.
|
|
128
|
-
ValueError
|
|
129
|
-
If an invalid protocol is specified.
|
|
130
|
-
"""
|
|
131
|
-
|
|
132
|
-
images = data["image_arrays"]
|
|
133
|
-
|
|
134
|
-
dims: List[Dict[str, Any]] = []
|
|
135
|
-
data["image_dims"] = dims
|
|
136
|
-
|
|
137
|
-
# Helper function to split a list into chunks of size up to chunk_size.
|
|
138
|
-
def chunk_list(lst, chunk_size):
|
|
139
|
-
return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
|
|
140
|
-
|
|
141
|
-
if "image_arrays" not in data or "image_dims" not in data:
|
|
142
|
-
raise KeyError("Expected 'image_arrays' and 'image_dims' in data. Call prepare_data_for_inference first.")
|
|
143
|
-
|
|
144
|
-
images = data["image_arrays"]
|
|
145
|
-
dims = data["image_dims"]
|
|
146
|
-
|
|
147
|
-
model_name = kwargs.get("model_name", DEFAULT_OCR_MODEL_NAME)
|
|
148
|
-
merge_level = kwargs.get("merge_level", "paragraph")
|
|
149
|
-
|
|
150
|
-
if protocol == "grpc":
|
|
151
|
-
logger.debug("Formatting input for gRPC OCR model (batched).")
|
|
152
|
-
processed: List[np.ndarray] = []
|
|
153
|
-
|
|
154
|
-
max_length = max(max(img.shape[:2]) for img in images)
|
|
155
|
-
max_length = min(max_length, 65500) # Maximum supported image dimension for JPEG is 65500 pixels.
|
|
156
|
-
|
|
157
|
-
for img in images:
|
|
158
|
-
if model_name == DEFAULT_OCR_MODEL_NAME:
|
|
159
|
-
arr, _dims = preprocess_image_for_paddle(img)
|
|
160
|
-
elif model_name == NEMORETRIEVER_OCR_EA_MODEL_NAME:
|
|
161
|
-
arr, _dims = preprocess_image_for_ocr(
|
|
162
|
-
img,
|
|
163
|
-
target_height=max_length,
|
|
164
|
-
target_width=max_length,
|
|
165
|
-
pad_how="bottom_right",
|
|
166
|
-
)
|
|
167
|
-
elif model_name == NEMORETRIEVER_OCR_MODEL_NAME:
|
|
168
|
-
arr = img
|
|
169
|
-
_dims = {"new_width": img.shape[1], "new_height": img.shape[0]}
|
|
170
|
-
else:
|
|
171
|
-
raise ValueError(f"Unknown model name: {model_name}")
|
|
172
|
-
|
|
173
|
-
dims.append(_dims)
|
|
174
|
-
|
|
175
|
-
if model_name == NEMORETRIEVER_OCR_MODEL_NAME:
|
|
176
|
-
arr = np.array([numpy_to_base64(arr, format="JPEG")], dtype=np.object_)
|
|
177
|
-
else:
|
|
178
|
-
arr = arr.astype(np.float32)
|
|
179
|
-
|
|
180
|
-
arr = np.expand_dims(arr, axis=0)
|
|
181
|
-
|
|
182
|
-
processed.append(arr)
|
|
183
|
-
|
|
184
|
-
batches = []
|
|
185
|
-
batch_data_list = []
|
|
186
|
-
for proc_chunk, orig_chunk, dims_chunk in zip(
|
|
187
|
-
chunk_list(processed, max_batch_size),
|
|
188
|
-
chunk_list(images, max_batch_size),
|
|
189
|
-
chunk_list(dims, max_batch_size),
|
|
190
|
-
):
|
|
191
|
-
batched_input = np.concatenate(proc_chunk, axis=0)
|
|
192
|
-
|
|
193
|
-
if model_name == DEFAULT_OCR_MODEL_NAME:
|
|
194
|
-
batches.append(batched_input)
|
|
195
|
-
else:
|
|
196
|
-
merge_levels = np.array([[merge_level] * len(batched_input)], dtype="object")
|
|
197
|
-
batches.append([batched_input, merge_levels])
|
|
198
|
-
|
|
199
|
-
batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
|
|
200
|
-
return batches, batch_data_list
|
|
201
|
-
|
|
202
|
-
elif protocol == "http":
|
|
203
|
-
logger.debug("Formatting input for HTTP OCR model (batched).")
|
|
204
|
-
if "base64_images" in data:
|
|
205
|
-
base64_list = data["base64_images"]
|
|
206
|
-
else:
|
|
207
|
-
base64_list = [data["base64_image"]]
|
|
208
|
-
|
|
209
|
-
input_list: List[Dict[str, Any]] = []
|
|
210
|
-
for b64, img in zip(base64_list, images):
|
|
211
|
-
image_url = f"data:image/png;base64,{b64}"
|
|
212
|
-
image_obj = {"type": "image_url", "url": image_url}
|
|
213
|
-
input_list.append(image_obj)
|
|
214
|
-
_dims = {"new_width": img.shape[1], "new_height": img.shape[0]}
|
|
215
|
-
dims.append(_dims)
|
|
216
|
-
|
|
217
|
-
batches = []
|
|
218
|
-
batch_data_list = []
|
|
219
|
-
for input_chunk, orig_chunk, dims_chunk in zip(
|
|
220
|
-
chunk_list(input_list, max_batch_size),
|
|
221
|
-
chunk_list(images, max_batch_size),
|
|
222
|
-
chunk_list(dims, max_batch_size),
|
|
223
|
-
):
|
|
224
|
-
if model_name == DEFAULT_OCR_MODEL_NAME:
|
|
225
|
-
payload = {"input": input_chunk}
|
|
226
|
-
else:
|
|
227
|
-
payload = {
|
|
228
|
-
"input": input_chunk,
|
|
229
|
-
"merge_levels": [merge_level] * len(input_chunk),
|
|
230
|
-
}
|
|
231
|
-
batches.append(payload)
|
|
232
|
-
batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
|
|
233
|
-
|
|
234
|
-
return batches, batch_data_list
|
|
29
|
+
class OCRModelInterfaceBase(ModelInterface):
|
|
235
30
|
|
|
236
|
-
|
|
237
|
-
|
|
31
|
+
NUM_CHANNELS = 3
|
|
32
|
+
BYTES_PER_ELEMENT = 4 # For float32
|
|
238
33
|
|
|
239
34
|
def parse_output(
|
|
240
35
|
self,
|
|
@@ -302,6 +97,25 @@ class OCRModelInterface(ModelInterface):
|
|
|
302
97
|
"""
|
|
303
98
|
return output
|
|
304
99
|
|
|
100
|
+
def does_item_fit_in_batch(self, current_batch, next_request, memory_budget_bytes: int) -> bool:
    """
    Check whether adding `next_request` keeps a padded batch within the memory budget.

    The estimate assumes every image in the batch is padded to the largest
    height and the largest width found across the batch, so the cost is
    ``batch_size * max_height * max_width * NUM_CHANNELS * BYTES_PER_ELEMENT``.

    Parameters
    ----------
    current_batch : iterable
        Requests already accepted into the batch. Each request must expose a
        ``dims`` attribute whose first two entries are read as (height, width).
        Any iterable is accepted (list, tuple, deque, ...).
    next_request : object
        The candidate request, exposing the same ``dims`` attribute.
    memory_budget_bytes : int
        The maximum number of bytes the padded batch may occupy.

    Returns
    -------
    bool
        True if the estimated size of the batch including `next_request` is
        less than or equal to `memory_budget_bytes`, False otherwise.
    """
    # Unpack instead of ``current_batch + [next_request]`` so that non-list
    # sequences do not raise a TypeError on concatenation.
    candidate_dims = [request.dims for request in (*current_batch, next_request)]

    max_height = max(d[0] for d in candidate_dims)
    max_width = max(d[1] for d in candidate_dims)

    estimated_bytes = len(candidate_dims) * max_height * max_width * self.NUM_CHANNELS * self.BYTES_PER_ELEMENT

    return estimated_bytes <= memory_budget_bytes
|
|
118
|
+
|
|
305
119
|
def _prepare_ocr_payload(self, base64_img: str) -> Dict[str, Any]:
|
|
306
120
|
"""
|
|
307
121
|
DEPRECATED by batch logic in format_input. Kept here if you need single-image direct calls.
|
|
@@ -462,7 +276,7 @@ class OCRModelInterface(ModelInterface):
|
|
|
462
276
|
conf_scores,
|
|
463
277
|
dimensions,
|
|
464
278
|
img_index=i,
|
|
465
|
-
scale_coordinates=
|
|
279
|
+
scale_coordinates=True,
|
|
466
280
|
)
|
|
467
281
|
|
|
468
282
|
results.append([bounding_boxes, text_predictions, conf_scores])
|
|
@@ -548,6 +362,383 @@ class OCRModelInterface(ModelInterface):
|
|
|
548
362
|
return bboxes, texts, confs
|
|
549
363
|
|
|
550
364
|
|
|
365
|
+
class PaddleOCRModelInterface(OCRModelInterfaceBase):
    """
    An interface for handling inference with a legacy OCR model, supporting both gRPC and HTTP protocols.
    """

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface.
        """
        return "PaddleOCR"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Decode one or more base64-encoded images into NumPy arrays and store them in `data`.

        Parameters
        ----------
        data : dict of str -> Any
            The input data containing either:
             - 'base64_images': a list of base64-encoded images (checked first), or
             - 'base64_image': a single base64-encoded image.

        Returns
        -------
        dict of str -> Any
            The same dictionary, updated in place with:
             - "images": list of arrays returned by `base64_to_numpy`, one per input image.
            Note that "image_dims" is not set here; it is populated by `format_input`.

        Raises
        ------
        KeyError
            If neither 'base64_image' nor 'base64_images' is found in `data`.
        ValueError
            If 'base64_images' is present but is not a list.
        """
        if "base64_images" in data:
            base64_list = data["base64_images"]
            if not isinstance(base64_list, list):
                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
            data["images"] = [base64_to_numpy(b64) for b64 in base64_list]

        elif "base64_image" in data:
            # Single-image fallback
            data["images"] = [base64_to_numpy(data["base64_image"])]

        else:
            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")

        return data

    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
        """
        Format input data for the specified protocol ("grpc" or "http"), supporting batched data.

        Parameters
        ----------
        data : dict of str -> Any
            The input data dictionary, expected to contain "images" (list of np.ndarray)
            as produced by prepare_data_for_inference. For the "http" protocol the
            original 'base64_images' (or 'base64_image') entry is also read.
            `data["image_dims"]` is overwritten with a fresh list that this method fills.
        protocol : str
            The inference protocol, either "grpc" or "http".
        max_batch_size : int
            The maximum batch size for batching.
        **kwargs : Any
            Accepted for interface compatibility; not used by this implementation.

        Returns
        -------
        tuple
            A tuple (formatted_batches, formatted_batch_data) where:
              - formatted_batches is a list of batches ready for inference: concatenated
                float32 arrays for "grpc", ``{"input": [...]}`` payloads for "http".
              - formatted_batch_data is a list of scratch-pad dictionaries corresponding to each batch,
                containing the keys "images" and "image_dims" for later post-processing.

        Raises
        ------
        KeyError
            If "images" is not found in `data`.
        ValueError
            If an invalid protocol is specified.
        """
        # Validate before touching `data`, so a missing key produces the
        # descriptive message instead of a bare KeyError('images').
        if "images" not in data:
            raise KeyError("Expected 'images' in data. Call prepare_data_for_inference first.")

        images = data["images"]

        # The per-image dimension records are always rebuilt here and shared with
        # `data` by reference, so appends below are visible through `data["image_dims"]`.
        dims: List[Dict[str, Any]] = []
        data["image_dims"] = dims

        # Helper function to split a list into chunks of size up to chunk_size.
        def chunk_list(lst, chunk_size):
            return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

        if protocol == "grpc":
            logger.debug("Formatting input for gRPC OCR model (batched).")
            processed: List[np.ndarray] = []

            for img in images:
                arr, _dims = preprocess_image_for_paddle(img)
                dims.append(_dims)
                # Add a leading batch axis so the per-image arrays can be concatenated.
                # NOTE(review): concatenation requires the preprocessed arrays in one
                # chunk to share a shape -- confirm preprocess_image_for_paddle guarantees this.
                processed.append(np.expand_dims(arr.astype(np.float32), axis=0))

            batches = []
            batch_data_list = []
            for proc_chunk, orig_chunk, dims_chunk in zip(
                chunk_list(processed, max_batch_size),
                chunk_list(images, max_batch_size),
                chunk_list(dims, max_batch_size),
            ):
                batches.append(np.concatenate(proc_chunk, axis=0))
                batch_data_list.append({"images": orig_chunk, "image_dims": dims_chunk})
            return batches, batch_data_list

        if protocol == "http":
            logger.debug("Formatting input for HTTP OCR model (batched).")
            if "base64_images" in data:
                base64_list = data["base64_images"]
            else:
                base64_list = [data["base64_image"]]

            input_list: List[Dict[str, Any]] = []
            for b64, img in zip(base64_list, images):
                image_url = f"data:image/png;base64,{b64}"
                input_list.append({"type": "image_url", "url": image_url})
                # No resizing happens on the HTTP path: record the decoded image's own size.
                dims.append({"new_width": img.shape[1], "new_height": img.shape[0]})

            batches = []
            batch_data_list = []
            for input_chunk, orig_chunk, dims_chunk in zip(
                chunk_list(input_list, max_batch_size),
                chunk_list(images, max_batch_size),
                chunk_list(dims, max_batch_size),
            ):
                batches.append({"input": input_chunk})
                batch_data_list.append({"images": orig_chunk, "image_dims": dims_chunk})

            return batches, batch_data_list

        raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
class NemoRetrieverOCRModelInterface(OCRModelInterfaceBase):
    """
    An interface for handling inference with NemoRetrieverOCR model, supporting both gRPC and HTTP protocols.
    """

    def name(self) -> str:
        """
        Get the name of the model interface.

        Returns
        -------
        str
            The name of the model interface.
        """
        return "NemoRetrieverOCR"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Decode one or more base64-encoded images into NumPy arrays and store them in `data`.

        Parameters
        ----------
        data : dict of str -> Any
            The input data containing either:
             - 'base64_images': a list of base64-encoded images (checked first), or
             - 'base64_image': a single base64-encoded image.

        Returns
        -------
        dict of str -> Any
            The same dictionary, updated in place with:
             - "images": list of arrays returned by `base64_to_numpy`, one per input image.
            Note that "image_dims" is not set here; it is populated by `format_input`.

        Raises
        ------
        KeyError
            If neither 'base64_image' nor 'base64_images' is found in `data`.
        ValueError
            If 'base64_images' is present but is not a list.
        """
        if "base64_images" in data:
            base64_list = data["base64_images"]
            if not isinstance(base64_list, list):
                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
            data["images"] = [base64_to_numpy(b64) for b64 in base64_list]

        elif "base64_image" in data:
            # Single-image fallback
            data["images"] = [base64_to_numpy(data["base64_image"])]

        else:
            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")

        return data

    def coalesce_requests_to_batch(
        self,
        requests: List[np.ndarray],
        original_image_shapes: List[Tuple[int, int]],
        protocol: str,
        **kwargs,
    ) -> Tuple[List[Any], List[Dict[str, Any]]]:
        """
        Combine a list of individual data items into a single formatted batch ready for inference.

        This method mirrors the logic of `format_input` but operates on an already-formed
        batch from the dynamic batcher, so it does not perform any chunking.

        Parameters
        ----------
        requests : list
            The individual data items. They are forwarded unchanged to
            `_format_single_batch` as `batch_images`.
            NOTE(review): annotated here as List[np.ndarray], while
            `_format_single_batch` annotates the same argument as List[str] and its
            HTTP path embeds each item in a base64 data URL -- confirm the expected type.
        original_image_shapes : List[Tuple[int, int]]
            One (height, width) pair per request, recorded as "new_height"/"new_width".
        protocol : str
            The inference protocol, either "grpc" or "http".
        **kwargs : Any
            Additional keyword arguments; `merge_level` is the one read downstream.

        Returns
        -------
        tuple
            ``(formatted_batch, batch_data)`` as returned by `_format_single_batch`,
            or ``(None, {})`` when `requests` is empty.

        Raises
        ------
        ValueError
            If `requests` is non-empty and an invalid protocol is specified.
        """
        if not requests:
            return None, {}

        return self._format_single_batch(requests, original_image_shapes, protocol, **kwargs)

    def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
        """
        Format input data for the specified protocol ("grpc" or "http"), supporting batched data.

        Parameters
        ----------
        data : dict of str -> Any
            The input data dictionary, expected to contain "images" (list of np.ndarray,
            as produced by prepare_data_for_inference) together with the original
            'base64_images' list or single 'base64_image' string. The base64 strings are
            what is sent to the model; the decoded arrays are only used for their shapes.
            `data["image_dims"]` is overwritten with the dimension records of all batches.
        protocol : str
            The inference protocol, either "grpc" or "http".
        max_batch_size : int
            The maximum batch size for batching.
        **kwargs : Any
            Forwarded to `_format_single_batch` (e.g. `merge_level`).

        Returns
        -------
        tuple
            A tuple (formatted_batches, formatted_batch_data) where:
              - formatted_batches is a list of batches ready for inference.
              - formatted_batch_data is a list of scratch-pad dictionaries corresponding to each batch,
                each containing the key "image_dims" for later post-processing.

        Raises
        ------
        KeyError
            If "images" is not found in `data`, or if neither 'base64_images' nor
            'base64_image' is present.
        ValueError
            If an invalid protocol is specified.
        """

        # Helper function to split a list into chunks of size up to chunk_size.
        def chunk_list(lst, chunk_size):
            return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

        if "images" not in data:
            raise KeyError("Expected 'images' in data. Call prepare_data_for_inference first.")

        # prepare_data_for_inference accepts a single 'base64_image' as well as a
        # 'base64_images' list, so both spellings must be honoured here too.
        if "base64_images" in data:
            base64_images = data["base64_images"]
        else:
            base64_images = [data["base64_image"]]

        # (height, width) of every decoded image, in input order.
        shapes = [img.shape[:2] for img in data["images"]]

        formatted_batches = []
        formatted_batch_data = []

        for image_chunk, shape_chunk in zip(
            chunk_list(base64_images, max_batch_size),
            chunk_list(shapes, max_batch_size),
        ):
            final_batch, batch_data = self._format_single_batch(image_chunk, shape_chunk, protocol, **kwargs)
            formatted_batches.append(final_batch)
            formatted_batch_data.append(batch_data)

        # Flatten the per-batch dimension records back into one list on `data`.
        data["image_dims"] = [item for d in formatted_batch_data for item in d.get("image_dims", [])]

        return formatted_batches, formatted_batch_data

    def _format_single_batch(
        self,
        batch_images: List[str],
        batch_dims: List[Tuple[int, int]],
        protocol: str,
        **kwargs,
    ) -> Tuple[Any, Dict[str, Any]]:
        """
        Build one protocol-specific batch and its scratch-pad from already-chunked inputs.

        Parameters
        ----------
        batch_images : List[str]
            The images of this batch; for "http" each item is embedded in a
            ``data:image/png;base64,...`` URL.
        batch_dims : List[Tuple[int, int]]
            One (height, width) pair per image, in the same order as `batch_images`.
        protocol : str
            The inference protocol, either "grpc" or "http".
        **kwargs : Any
            `merge_level` (default "paragraph") is read; other keys are ignored.

        Returns
        -------
        tuple
            ``(batch, batch_data)`` where `batch` is ``[batched_input, merge_levels]``
            (two object arrays) for "grpc" or a dict with "input" and "merge_levels"
            for "http", and `batch_data` is ``{"image_dims": [...]}`` with one
            ``{"new_width", "new_height"}`` record per image.

        Raises
        ------
        ValueError
            If an invalid protocol is specified.
        """
        dims: List[Dict[str, Any]] = []

        merge_level = kwargs.get("merge_level", "paragraph")

        if protocol == "grpc":
            logger.debug("Formatting input for gRPC OCR model (batched).")
            processed: List[np.ndarray] = []

            for img, shape in zip(batch_images, batch_dims):
                dims.append({"new_width": shape[1], "new_height": shape[0]})

                # Wrap each item in its own object array with a leading batch axis
                # so the items can be stacked by concatenation below.
                arr = np.array([img], dtype=np.object_)
                processed.append(np.expand_dims(arr, axis=0))

            batched_input = np.concatenate(processed, axis=0)

            # One merge level per batch row, shaped as a column to match the input.
            batch_size = batched_input.shape[0]
            merge_levels = np.array([[merge_level] for _ in range(batch_size)], dtype="object")

            return [batched_input, merge_levels], {"image_dims": dims}

        if protocol == "http":
            logger.debug("Formatting input for HTTP OCR model (batched).")

            input_list: List[Dict[str, Any]] = []
            for b64, shape in zip(batch_images, batch_dims):
                image_url = f"data:image/png;base64,{b64}"
                input_list.append({"type": "image_url", "url": image_url})
                dims.append({"new_width": shape[1], "new_height": shape[0]})

            payload = {
                "input": input_list,
                "merge_levels": [merge_level] * len(input_list),
            }

            return payload, {"image_dims": dims}

        raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
|
|
740
|
+
|
|
741
|
+
|
|
551
742
|
@multiprocessing_cache(max_calls=100) # Cache results first to avoid redundant retries from backoff
|
|
552
743
|
@backoff.on_predicate(backoff.expo, max_time=30)
|
|
553
744
|
def get_ocr_model_name(ocr_grpc_endpoint=None, default_model_name=DEFAULT_OCR_MODEL_NAME):
|