nv-ingest-api 2025.7.16.dev20250716__py3-none-any.whl → 2025.7.17.dev20250717__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nv_ingest_api/interface/extract.py +18 -18
- nv_ingest_api/internal/extract/image/chart_extractor.py +75 -55
- nv_ingest_api/internal/extract/image/infographic_extractor.py +59 -35
- nv_ingest_api/internal/extract/image/table_extractor.py +81 -63
- nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +7 -7
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +9 -9
- nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +58 -0
- nv_ingest_api/internal/primitives/nim/model_interface/{paddle.py → ocr.py} +132 -39
- nv_ingest_api/internal/primitives/nim/nim_client.py +46 -11
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +6 -6
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +6 -6
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +5 -5
- nv_ingest_api/internal/transform/split_text.py +13 -8
- nv_ingest_api/util/image_processing/table_and_chart.py +97 -42
- nv_ingest_api/util/image_processing/transforms.py +16 -5
- nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +1 -1
- nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +51 -48
- {nv_ingest_api-2025.7.16.dev20250716.dist-info → nv_ingest_api-2025.7.17.dev20250717.dist-info}/METADATA +1 -1
- {nv_ingest_api-2025.7.16.dev20250716.dist-info → nv_ingest_api-2025.7.17.dev20250717.dist-info}/RECORD +22 -22
- {nv_ingest_api-2025.7.16.dev20250716.dist-info → nv_ingest_api-2025.7.17.dev20250717.dist-info}/WHEEL +0 -0
- {nv_ingest_api-2025.7.16.dev20250716.dist-info → nv_ingest_api-2025.7.17.dev20250717.dist-info}/licenses/LICENSE +0 -0
- {nv_ingest_api-2025.7.16.dev20250716.dist-info → nv_ingest_api-2025.7.17.dev20250717.dist-info}/top_level.txt +0 -0
nv_ingest_api/internal/primitives/nim/model_interface/{paddle.py → ocr.py}

```diff
@@ -4,22 +4,37 @@

 import json
 import logging
-
+import os
+from typing import Any
 from typing import Dict
+from typing import List
 from typing import Optional
+from typing import Tuple

+import backoff
 import numpy as np
+import tritonclient.grpc as grpcclient

 from nv_ingest_api.internal.primitives.nim import ModelInterface
-from nv_ingest_api.internal.primitives.nim.model_interface.
+from nv_ingest_api.internal.primitives.nim.model_interface.decorators import (
+    multiprocessing_cache,
+)
+from nv_ingest_api.internal.primitives.nim.model_interface.helpers import (
+    preprocess_image_for_ocr,
+)
+from nv_ingest_api.internal.primitives.nim.model_interface.helpers import (
+    preprocess_image_for_paddle,
+)
 from nv_ingest_api.util.image_processing.transforms import base64_to_numpy

+DEFAULT_OCR_MODEL_NAME = "paddle"
+
 logger = logging.getLogger(__name__)


-class PaddleOCRModelInterface(ModelInterface):
+class OCRModelInterface(ModelInterface):
     """
-    An interface for handling inference with a
+    An interface for handling inference with a OCR model, supporting both gRPC and HTTP protocols.
     """

     def name(self) -> str:
@@ -31,7 +46,7 @@ class PaddleOCRModelInterface(ModelInterface):
         str
             The name of the model interface.
         """
-        return "
+        return "OCR"

     def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
```
```diff
@@ -126,11 +141,26 @@ class PaddleOCRModelInterface(ModelInterface):
         images = data["image_arrays"]
         dims = data["image_dims"]

+        model_name = kwargs.get("model_name", "paddle")
+        merge_level = kwargs.get("merge_level", "paragraph")
+
         if protocol == "grpc":
-            logger.debug("Formatting input for gRPC
+            logger.debug("Formatting input for gRPC OCR model (batched).")
             processed: List[np.ndarray] = []
+
+            max_length = max(max(img.shape[:2]) for img in images)
+
             for img in images:
-
+                if model_name == "paddle":
+                    arr, _dims = preprocess_image_for_paddle(img)
+                else:
+                    arr, _dims = preprocess_image_for_ocr(
+                        img,
+                        target_height=max_length,
+                        target_width=max_length,
+                        pad_how="bottom_right",
+                    )
+
                 dims.append(_dims)
                 arr = arr.astype(np.float32)
                 arr = np.expand_dims(arr, axis=0)  # => shape (1, H, W, C)
```
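The gRPC branch now sizes every image in a batch to the same square (`max_length` × `max_length`) before stacking, since the `np.concatenate` call in the next hunk requires uniform shapes. The `preprocess_image_for_ocr` helper itself lives in `helpers.py` and is not shown in this diff; the following is only a minimal sketch of the assumed pad-to-square behavior, including the `dims` bookkeeping that the response postprocessor relies on later (the helper name below is a hypothetical stand-in):

```python
import numpy as np

def pad_to_square_bottom_right(img: np.ndarray, target: int):
    """Hypothetical stand-in for preprocess_image_for_ocr: pad an (H, W, C)
    image to (target, target) on the bottom/right edges, recording the
    metadata used later to map OCR boxes back to original pixels."""
    h, w = img.shape[:2]
    padded = np.zeros((target, target, img.shape[2]), dtype=img.dtype)
    padded[:h, :w] = img  # content stays anchored at the top-left corner
    dims = {"new_width": w, "new_height": h, "pad_width": 0, "pad_height": 0, "scale_factor": 1.0}
    return padded, dims
```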
```diff
@@ -144,12 +174,18 @@ class PaddleOCRModelInterface(ModelInterface):
                 chunk_list(dims, max_batch_size),
             ):
                 batched_input = np.concatenate(proc_chunk, axis=0)
-
+
+                if model_name == "paddle":
+                    batches.append(batched_input)
+                else:
+                    merge_levels = np.array([[merge_level] * len(batched_input)], dtype="object")
+                    batches.append([batched_input, merge_levels])
+
                 batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
             return batches, batch_data_list

         elif protocol == "http":
-            logger.debug("Formatting input for HTTP
+            logger.debug("Formatting input for HTTP OCR model (batched).")
             if "base64_images" in data:
                 base64_list = data["base64_images"]
             else:
@@ -170,7 +206,13 @@ class PaddleOCRModelInterface(ModelInterface):
                 chunk_list(images, max_batch_size),
                 chunk_list(dims, max_batch_size),
             ):
-
+                if model_name == "paddle":
+                    payload = {"input": input_chunk}
+                else:
+                    payload = {
+                        "input": input_chunk,
+                        "merge_levels": [merge_level] * len(input_chunk),
+                    }
                 batches.append(payload)
                 batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})

```
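For HTTP, the two request bodies differ only by the new `merge_levels` field, sent once per image in the chunk. Assuming the entries in `input_chunk` are the `data:image/png;base64,...` records built from each image (the single-image helper later in this file constructs exactly that URL form; the per-record schema here is an assumption), the payloads look roughly like:

```python
# Paddle-style body: one record per image in the chunk.
paddle_payload = {
    "input": [{"type": "image_url", "url": "data:image/png;base64,<...>"}],
}

# Other OCR models additionally receive one merge level per image.
ocr_payload = {
    "input": [{"type": "image_url", "url": "data:image/png;base64,<...>"}],
    "merge_levels": ["paragraph"],
}
```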
```diff
@@ -179,7 +221,14 @@ class PaddleOCRModelInterface(ModelInterface):
         else:
             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

-    def parse_output(
+    def parse_output(
+        self,
+        response: Any,
+        protocol: str,
+        data: Optional[Dict[str, Any]] = None,
+        model_name: str = "paddle",
+        **kwargs: Any,
+    ) -> Any:
         """
         Parse the model's inference response for the given protocol. The parsing
         may handle batched outputs for multiple images.
@@ -187,7 +236,7 @@ class PaddleOCRModelInterface(ModelInterface):
         Parameters
         ----------
         response : Any
-            The raw response from the
+            The raw response from the OCR model.
         protocol : str
             The protocol used for inference, "grpc" or "http".
         data : dict of str -> Any, optional
@@ -209,24 +258,24 @@ class PaddleOCRModelInterface(ModelInterface):
         dims: Optional[List[Tuple[int, int]]] = data.get("image_dims") if data else None

         if protocol == "grpc":
-            logger.debug("Parsing output from gRPC
-            return self.
+            logger.debug("Parsing output from gRPC OCR model (batched).")
+            return self._extract_content_from_ocr_grpc_response(response, dims, model_name=model_name)

         elif protocol == "http":
-            logger.debug("Parsing output from HTTP
-            return self.
+            logger.debug("Parsing output from HTTP OCR model (batched).")
+            return self._extract_content_from_ocr_http_response(response, dims)

         else:
             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

     def process_inference_results(self, output: Any, **kwargs: Any) -> Any:
         """
-        Process inference results for the
+        Process inference results for the OCR model.

         Parameters
         ----------
         output : Any
-            The raw output parsed from the
+            The raw output parsed from the OCR model.
         **kwargs : Any
             Additional keyword arguments for customization.

@@ -238,7 +287,7 @@ class PaddleOCRModelInterface(ModelInterface):
         """
         return output

-    def
+    def _prepare_ocr_payload(self, base64_img: str) -> Dict[str, Any]:
         """
         DEPRECATED by batch logic in format_input. Kept here if you need single-image direct calls.

@@ -250,7 +299,7 @@ class PaddleOCRModelInterface(ModelInterface):
         Returns
         -------
         dict of str -> Any
-            The payload in either legacy or new format for
+            The payload in either legacy or new format for OCR's HTTP endpoint.
         """
         image_url = f"data:image/png;base64,{base64_img}"

@@ -259,18 +308,18 @@ class PaddleOCRModelInterface(ModelInterface):

         return payload

-    def
+    def _extract_content_from_ocr_http_response(
         self,
         json_response: Dict[str, Any],
         dimensions: List[Dict[str, Any]],
     ) -> List[Tuple[str, str]]:
         """
-        Extract content from the JSON response of a
+        Extract content from the JSON response of a OCR HTTP API request.

         Parameters
         ----------
         json_response : dict of str -> Any
-            The JSON response returned by the
+            The JSON response returned by the OCR endpoint.
         table_content_format : str or None
             The specified format for table content (e.g., 'simple' or 'pseudo_markdown').
         dimensions : list of dict, optional
@@ -296,25 +345,29 @@ class PaddleOCRModelInterface(ModelInterface):
             text_detections = item.get("text_detections", [])
             text_predictions = []
             bounding_boxes = []
+            conf_scores = []
             for td in text_detections:
                 text_predictions.append(td["text_prediction"]["text"])
                 bounding_boxes.append([[pt["x"], pt["y"]] for pt in td["bounding_box"]["points"]])
+                conf_scores.append(td["text_prediction"]["confidence"])

-            bounding_boxes, text_predictions = self.
+            bounding_boxes, text_predictions, conf_scores = self._postprocess_ocr_response(
                 bounding_boxes,
                 text_predictions,
+                conf_scores,
                 dimensions,
                 img_index=item_idx,
             )

-            results.append([bounding_boxes, text_predictions])
+            results.append([bounding_boxes, text_predictions, conf_scores])

         return results

-    def
+    def _extract_content_from_ocr_grpc_response(
         self,
         response: np.ndarray,
         dimensions: List[Dict[str, Any]],
+        model_name: str = "paddle",
     ) -> List[Tuple[str, str]]:
         """
         Parse a gRPC response for one or more images. The response can have two possible shapes:
```
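The HTTP parser walks `text_detections` and now also collects a per-detection confidence. The field names below are taken directly from the parsing code above; the surrounding response envelope is an assumption:

```python
# One item as consumed by the loop above (item/item_idx), with normalized box points:
item = {
    "text_detections": [
        {
            "text_prediction": {"text": "Revenue", "confidence": 0.98},
            "bounding_box": {
                "points": [
                    {"x": 0.10, "y": 0.20},
                    {"x": 0.31, "y": 0.20},
                    {"x": 0.31, "y": 0.26},
                    {"x": 0.10, "y": 0.26},
                ]
            },
        }
    ]
}
```

Each image now yields a `[bounding_boxes, text_predictions, conf_scores]` triple instead of the former pair, which is presumably the shape the updated chart, table, and infographic extractors in this release consume.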
```diff
@@ -367,33 +420,41 @@ class PaddleOCRModelInterface(ModelInterface):
             texts_bytestr: bytes = response[1, i]
             text_predictions = json.loads(texts_bytestr.decode("utf8"))

-            # 3)
-
-
+            # 3) Parse confidence scores
+            confs_bytestr: bytes = response[2, i]
+            conf_scores = json.loads(confs_bytestr.decode("utf8"))

             # Some gRPC responses nest single-item lists; flatten them if needed
             if isinstance(bounding_boxes, list) and len(bounding_boxes) == 1:
                 bounding_boxes = bounding_boxes[0]
             if isinstance(text_predictions, list) and len(text_predictions) == 1:
                 text_predictions = text_predictions[0]
+            if isinstance(conf_scores, list) and len(conf_scores) == 1:
+                conf_scores = conf_scores[0]

-
+            # 4) Postprocess
+            bounding_boxes, text_predictions, conf_scores = self._postprocess_ocr_response(
                 bounding_boxes,
                 text_predictions,
+                conf_scores,
                 dimensions,
                 img_index=i,
+                scale_coordinates=True if model_name == "paddle" else False,
             )

-            results.append([bounding_boxes, text_predictions])
+            results.append([bounding_boxes, text_predictions, conf_scores])

         return results

     @staticmethod
-    def
+    def _postprocess_ocr_response(
         bounding_boxes: List[Any],
         text_predictions: List[str],
+        conf_scores: List[float],
         dims: Optional[List[Dict[str, Any]]] = None,
         img_index: int = 0,
+        scale_coordinates: bool = True,
+        shift_coordinates: bool = True,
     ) -> Tuple[List[Any], List[str]]:
         """
         Convert bounding boxes with normalized coordinates to pixel cooridnates by using
```
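On the gRPC side the parser indexes `response[0, i]`, `response[1, i]`, and `response[2, i]`, i.e. a `(3, N)` object array holding one JSON-encoded string per image for boxes, texts, and confidences (row 0 carrying the boxes is inferred from the pattern of rows 1 and 2 shown above). A tiny fabricated example of that layout:

```python
import json
import numpy as np

boxes = json.dumps([[[0.10, 0.20], [0.31, 0.20], [0.31, 0.26], [0.10, 0.26]]]).encode("utf8")
texts = json.dumps(["Revenue"]).encode("utf8")
confs = json.dumps([0.98]).encode("utf8")

response = np.array([[boxes], [texts], [confs]], dtype=object)  # shape (3, N) with N = 1
assert json.loads(response[2, 0].decode("utf8")) == [0.98]
```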
```diff
@@ -434,17 +495,18 @@ class PaddleOCRModelInterface(ModelInterface):
             logger.warning("Image index out of range for stored dimensions. Using first image dims by default.")
             img_index = 0

-        max_width = dims[img_index]["new_width"]
-        max_height = dims[img_index]["new_height"]
-        pad_width = dims[img_index].get("pad_width", 0)
-        pad_height = dims[img_index].get("pad_height", 0)
-        scale_factor = dims[img_index].get("scale_factor", 1.0)
+        max_width = dims[img_index]["new_width"] if scale_coordinates else 1.0
+        max_height = dims[img_index]["new_height"] if scale_coordinates else 1.0
+        pad_width = dims[img_index].get("pad_width", 0) if shift_coordinates else 0.0
+        pad_height = dims[img_index].get("pad_height", 0) if shift_coordinates else 0.0
+        scale_factor = dims[img_index].get("scale_factor", 1.0) if scale_coordinates else 1.0

         bboxes: List[List[float]] = []
         texts: List[str] = []
+        confs: List[float] = []

         # Convert normalized coords back to actual pixel coords
-        for box, txt in zip(bounding_boxes, text_predictions):
+        for box, txt, conf in zip(bounding_boxes, text_predictions, conf_scores):
             if box == "nan":
                 continue
             points: List[List[float]] = []
```
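The new flags let the non-paddle path skip coordinate rescaling: with `scale_coordinates=False`, `max_width`, `max_height`, and `scale_factor` all collapse to 1.0 so boxes already in pixel space pass through unchanged, while `shift_coordinates` independently controls the padding offset. The per-point expression itself sits in lines not shown in this hunk; a plausible form consistent with these variables, offered only as an assumption:

```python
# Assumed mapping from a normalized point (x, y) back to original pixels.
x_original = (x * max_width - pad_width) / scale_factor
y_original = (y * max_height - pad_height) / scale_factor
```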
```diff
@@ -458,5 +520,36 @@ class PaddleOCRModelInterface(ModelInterface):
                 points.append([x_original, y_original])
             bboxes.append(points)
             texts.append(txt)
+            confs.append(conf)
+
+        return bboxes, texts, confs

-
+
+@multiprocessing_cache(max_calls=100)  # Cache results first to avoid redundant retries from backoff
+@backoff.on_predicate(backoff.expo, max_time=30)
+def get_ocr_model_name(ocr_grpc_endpoint=None, default_model_name=DEFAULT_OCR_MODEL_NAME):
+    """
+    Determines the OCR model name by checking the environment, querying the gRPC endpoint,
+    or falling back to a default.
+    """
+    # 1. Check for an explicit override from the environment variable first.
+    ocr_model_name = os.getenv("OCR_MODEL_NAME", None)
+    if ocr_model_name is not None:
+        return ocr_model_name
+
+    # 2. If no gRPC endpoint is provided, fall back to the default immediately.
+    if not ocr_grpc_endpoint:
+        logger.debug(f"No OCR gRPC endpoint provided. Falling back to default model name '{default_model_name}'.")
+        return default_model_name
+
+    # 3. Attempt to query the gRPC endpoint to discover the model name.
+    try:
+        client = grpcclient.InferenceServerClient(ocr_grpc_endpoint)
+        model_index = client.get_model_repository_index(as_json=True)
+        model_names = [x["name"] for x in model_index.get("models", [])]
+        ocr_model_name = model_names[0]
+    except Exception:
+        logger.warning(f"Failed to get ocr model name after 30 seconds. Falling back to '{default_model_name}'.")
+        ocr_model_name = default_model_name
+
+    return ocr_model_name
```
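Decorator order matters here: `multiprocessing_cache` is listed first, so it is the outermost wrapper and a cached name short-circuits any further retries, exactly as the inline comment says. `backoff.on_predicate(backoff.expo, max_time=30)` re-invokes the function with exponentially growing delays while it returns a falsy value, giving a slow endpoint up to 30 seconds to come up. Typical use, with a placeholder endpoint:

```python
model_name = get_ocr_model_name("localhost:8001")  # queries Triton's model repository index
offline = get_ocr_model_name()                     # no endpoint -> "paddle" default
```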
nv_ingest_api/internal/primitives/nim/nim_client.py

```diff
@@ -33,6 +33,7 @@ class NimClient:
         auth_token: Optional[str] = None,
         timeout: float = 120.0,
         max_retries: int = 5,
+        max_429_retries: int = 5,
     ):
         """
         Initialize the NimClient with the specified model interface, protocol, and server endpoints.
@@ -49,6 +50,10 @@ class NimClient:
             Authorization token for HTTP requests (default: None).
         timeout : float, optional
             Timeout for HTTP requests in seconds (default: 30.0).
+        max_retries : int, optional
+            The maximum number of retries for non-429 server-side errors (default: 5).
+        max_429_retries : int, optional
+            The maximum number of retries specifically for 429 errors (default: 10).

         Raises
         ------
@@ -62,6 +67,7 @@ class NimClient:
         self.auth_token = auth_token
         self.timeout = timeout  # Timeout for HTTP requests
         self.max_retries = max_retries
+        self.max_429_retries = max_429_retries
         self._grpc_endpoint, self._http_endpoint = endpoints
         self._max_batch_sizes = {}
         self._lock = threading.Lock()
@@ -138,7 +144,9 @@ class NimClient:
         else:
             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

-        parsed_output = self.model_interface.parse_output(
+        parsed_output = self.model_interface.parse_output(
+            response, protocol=self.protocol, data=batch_data, model_name=model_name, **kwargs
+        )
         return parsed_output, batch_data

     def try_set_max_batch_size(self, model_name, model_version: str = ""):
@@ -167,8 +175,8 @@ class NimClient:
         try:
             # 1. Retrieve or default to the model's maximum batch size.
             batch_size = self._fetch_max_batch_size(model_name)
-            max_requested_batch_size = kwargs.
-            force_requested_batch_size = kwargs.
+            max_requested_batch_size = kwargs.pop("max_batch_size", batch_size)
+            force_requested_batch_size = kwargs.pop("force_max_batch_size", False)
             max_batch_size = (
                 min(batch_size, max_requested_batch_size)
                 if not force_requested_batch_size
@@ -180,7 +188,11 @@ class NimClient:

             # 3. Format the input based on protocol.
             formatted_batches, formatted_batch_data = self.model_interface.format_input(
-                data,
+                data,
+                protocol=self.protocol,
+                max_batch_size=max_batch_size,
+                model_name=model_name,
+                **kwargs,
             )

             # Check for a custom maximum pool worker count, and remove it from kwargs.
@@ -237,19 +249,27 @@ class NimClient:
         np.ndarray
             The output of the model as a numpy array.
         """
+        if not isinstance(formatted_input, list):
+            formatted_input = [formatted_input]

         parameters = kwargs.get("parameters", {})
-        output_names = kwargs.get("
-
-
+        output_names = kwargs.get("output_names", ["output"])
+        dtypes = kwargs.get("dtypes", ["FP32"])
+        input_names = kwargs.get("input_names", ["input"])
+
+        input_tensors = []
+        for input_name, input_data, dtype in zip(input_names, formatted_input, dtypes):
+            input_tensors.append(grpcclient.InferInput(input_name, input_data.shape, datatype=dtype))

-
-
+        for idx, input_data in enumerate(formatted_input):
+            input_tensors[idx].set_data_from_numpy(input_data)

         outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
+
         response = self.client.infer(
-            model_name=model_name, parameters=parameters, inputs=
+            model_name=model_name, parameters=parameters, inputs=input_tensors, outputs=outputs
         )
+
         logger.debug(f"gRPC inference response: {response}")

         if len(outputs) == 1:
```
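The gRPC inference path now accepts a list of tensors with parallel `input_names`/`dtypes`, which is what lets the OCR interface ship its `[batched_input, merge_levels]` pair in a single request. A hedged sketch of the equivalent raw tritonclient calls; the endpoint, tensor names, and model name are illustrative stand-ins for the values that arrive via kwargs above:

```python
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")  # placeholder endpoint

images = np.random.rand(2, 1024, 1024, 3).astype(np.float32)
merge_levels = np.array([["paragraph", "paragraph"]], dtype=object)

# One InferInput per tensor, with a matching dtype (defaults above: ["input"], ["FP32"]).
inputs = [
    grpcclient.InferInput("input", images.shape, datatype="FP32"),
    grpcclient.InferInput("merge_levels", merge_levels.shape, datatype="BYTES"),
]
inputs[0].set_data_from_numpy(images)
inputs[1].set_data_from_numpy(merge_levels)

outputs = [grpcclient.InferRequestedOutput("output")]
response = client.infer(model_name="ocr", inputs=inputs, outputs=outputs)
print(response.as_numpy("output"))
```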
```diff
@@ -281,6 +301,7 @@ class NimClient:

         base_delay = 2.0
         attempt = 0
+        retries_429 = 0

         while attempt < self.max_retries:
             try:
@@ -291,7 +312,21 @@ class NimClient:

                 # Check for server-side or rate-limit type errors
                 # e.g. 5xx => server error, 429 => too many requests
-                if status_code == 429
+                if status_code == 429:
+                    retries_429 += 1
+                    logger.warning(
+                        f"Received HTTP 429 (Too Many Requests) from {self.model_interface.name()}. "
+                        f"Attempt {retries_429} of {self.max_429_retries}."
+                    )
+                    if retries_429 >= self.max_429_retries:
+                        logger.error("Max retries for HTTP 429 exceeded.")
+                        response.raise_for_status()
+                    else:
+                        backoff_time = base_delay * (2**retries_429)
+                        time.sleep(backoff_time)
+                        continue  # Retry without incrementing the main attempt counter
+
+                if status_code == 503 or (500 <= status_code < 600):
                     logger.warning(
                         f"Received HTTP {status_code} ({response.reason}) from "
                         f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
```
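Because a 429 now hits `continue` without touching `attempt`, rate limiting can no longer burn through the retry budget reserved for ordinary 5xx errors. The sleep schedule the arithmetic above produces:

```python
base_delay = 2.0
delays = [base_delay * (2**n) for n in range(1, 5)]  # after 429 retries 1..4: [4.0, 8.0, 16.0, 32.0] seconds
```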
nv_ingest_api/internal/schemas/extract/extract_chart_schema.py

```diff
@@ -24,8 +24,8 @@ class ChartExtractorConfigSchema(BaseModel):
         A tuple containing the gRPC and HTTP services for the yolox endpoint.
         Either the gRPC or HTTP service can be empty, but not both.

-
-        A tuple containing the gRPC and HTTP services for the
+    ocr_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
+        A tuple containing the gRPC and HTTP services for the ocr endpoint.
         Either the gRPC or HTTP service can be empty, but not both.

     Methods
@@ -49,8 +49,8 @@ class ChartExtractorConfigSchema(BaseModel):
     yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
     yolox_infer_protocol: str = ""

-
-
+    ocr_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
+    ocr_infer_protocol: str = ""

     nim_batch_size: int = 2
     workers_per_progress_engine: int = 5
@@ -86,7 +86,7 @@ class ChartExtractorConfigSchema(BaseModel):
                 return None
             return service

-        for endpoint_name in ["yolox_endpoints", "
+        for endpoint_name in ["yolox_endpoints", "ocr_endpoints"]:
             grpc_service, http_service = values.get(endpoint_name, (None, None))
             grpc_service = clean_service(grpc_service)
             http_service = clean_service(http_service)
@@ -117,7 +117,7 @@ class ChartExtractorSchema(BaseModel):
         A flag indicating whether to raise an exception if a failure occurs during chart extraction.

     extraction_config: Optional[ChartExtractorConfigSchema], default=None
-        Configuration for the chart extraction stage, including yolox and
+        Configuration for the chart extraction stage, including yolox and ocr service endpoints.
     """

     max_queue_size: int = 1
```
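The same rename (the removed field names are truncated in this rendering) repeats across the infographic and table schemas below: configurations that previously pointed at a paddle-specific OCR service now use the generic `ocr_endpoints` / `ocr_infer_protocol` fields. A minimal sketch, with placeholder endpoints:

```python
config = ChartExtractorConfigSchema(
    yolox_endpoints=("localhost:8001", "http://localhost:8000/v1/infer"),
    yolox_infer_protocol="grpc",
    ocr_endpoints=("localhost:8010", "http://localhost:8011/v1/infer"),
    ocr_infer_protocol="grpc",
)
```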
nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py

```diff
@@ -20,8 +20,8 @@ class InfographicExtractorConfigSchema(BaseModel):
     auth_token : Optional[str], default=None
         Authentication token required for secure services.

-
-        A tuple containing the gRPC and HTTP services for the
+    ocr_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
+        A tuple containing the gRPC and HTTP services for the ocr endpoint.
         Either the gRPC or HTTP service can be empty, but not both.

     Methods
@@ -42,8 +42,8 @@ class InfographicExtractorConfigSchema(BaseModel):

     auth_token: Optional[str] = None

-
-
+    ocr_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
+    ocr_infer_protocol: str = ""

     nim_batch_size: int = 2
     workers_per_progress_engine: int = 5
@@ -79,7 +79,7 @@ class InfographicExtractorConfigSchema(BaseModel):
                 return None
             return service

-        for endpoint_name in ["
+        for endpoint_name in ["ocr_endpoints"]:
             grpc_service, http_service = values.get(endpoint_name, (None, None))
             grpc_service = clean_service(grpc_service)
             http_service = clean_service(http_service)
@@ -110,7 +110,7 @@ class InfographicExtractorSchema(BaseModel):
         A flag indicating whether to raise an exception if a failure occurs during infographic extraction.

     stage_config : Optional[InfographicExtractorConfigSchema], default=None
-        Configuration for the infographic extraction stage, including yolox and
+        Configuration for the infographic extraction stage, including yolox and ocr service endpoints.
     """

     max_queue_size: int = 1
```
nv_ingest_api/internal/schemas/extract/extract_table_schema.py

```diff
@@ -22,8 +22,8 @@ class TableExtractorConfigSchema(BaseModel):
     auth_token : Optional[str], default=None
         Authentication token required for secure services.

-
-        A tuple containing the gRPC and HTTP services for the
+    ocr_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
+        A tuple containing the gRPC and HTTP services for the ocr endpoint.
         Either the gRPC or HTTP service can be empty, but not both.

     Methods
@@ -47,8 +47,8 @@ class TableExtractorConfigSchema(BaseModel):
     yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
     yolox_infer_protocol: str = ""

-
-
+    ocr_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
+    ocr_infer_protocol: str = ""

     nim_batch_size: int = 2
     workers_per_progress_engine: int = 5
@@ -81,7 +81,7 @@ class TableExtractorConfigSchema(BaseModel):
                 return None
             return service

-        for endpoint_name in ["yolox_endpoints", "
+        for endpoint_name in ["yolox_endpoints", "ocr_endpoints"]:
             grpc_service, http_service = values.get(endpoint_name, (None, None))
             grpc_service = clean_service(grpc_service)
             http_service = clean_service(http_service)
```
nv_ingest_api/internal/transform/split_text.py

```diff
@@ -141,14 +141,19 @@ def transform_text_split_and_tokenize_internal(

     model_predownload_path = os.environ.get("MODEL_PREDOWNLOAD_PATH")

-    if
-
-
-
-
-
-
-
+    if model_predownload_path is not None:
+        if os.path.exists(os.path.join(model_predownload_path, "llama-3.2-1b/tokenizer/tokenizer.json")) and (
+            tokenizer_identifier is None or tokenizer_identifier == "meta-llama/Llama-3.2-1B"
+        ):
+            tokenizer_identifier = os.path.join(model_predownload_path, "llama-3.2-1b/tokenizer/")
+        elif os.path.exists(
+            os.path.join(model_predownload_path, "e5-large-unsupervised/tokenizer/tokenizer.json")
+        ) and (tokenizer_identifier is None or tokenizer_identifier == "intfloat/e5-large-unsupervised"):
+            tokenizer_identifier = os.path.join(model_predownload_path, "e5-large-unsupervised/tokenizer/")
+
+    # Defaulto to intfloat/e5-large-unsupervised if no tokenizer predownloaded or specified
+    if tokenizer_identifier is None:
+        tokenizer_identifier = "intfloat/e5-large-unsupervised"

     tokenizer_model = AutoTokenizer.from_pretrained(tokenizer_identifier, token=hf_access_token)

```
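Condensed, the new branching resolves the tokenizer in a fixed order: a matching predownloaded copy wins, then the caller's identifier, then the e5 default. A sketch mirroring that order (the helper name is hypothetical):

```python
import os
from typing import Optional

def resolve_tokenizer(identifier: Optional[str]) -> str:
    """Mirrors the fallback order above for an assumed predownload root."""
    root = os.environ.get("MODEL_PREDOWNLOAD_PATH")  # e.g. "/models"
    if root is not None:
        llama = os.path.join(root, "llama-3.2-1b/tokenizer/")
        if os.path.exists(os.path.join(llama, "tokenizer.json")) and identifier in (None, "meta-llama/Llama-3.2-1B"):
            return llama
        e5 = os.path.join(root, "e5-large-unsupervised/tokenizer/")
        if os.path.exists(os.path.join(e5, "tokenizer.json")) and identifier in (None, "intfloat/e5-large-unsupervised"):
            return e5
    return identifier or "intfloat/e5-large-unsupervised"
```

The net effect is that the gated `meta-llama/Llama-3.2-1B` tokenizer is only fetched from the Hub when explicitly requested without a predownload, while the unspecified default quietly becomes `intfloat/e5-large-unsupervised`.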