nv-ingest-api 2025.7.16.dev20250716__py3-none-any.whl → 2025.7.18.dev20250718__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nv-ingest-api might be problematic. See the registry's notice for this release for more details.

Files changed (24)
  1. nv_ingest_api/interface/extract.py +18 -18
  2. nv_ingest_api/internal/extract/image/chart_extractor.py +80 -75
  3. nv_ingest_api/internal/extract/image/image_helpers/common.py +5 -6
  4. nv_ingest_api/internal/extract/image/infographic_extractor.py +59 -35
  5. nv_ingest_api/internal/extract/image/table_extractor.py +84 -64
  6. nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +10 -7
  7. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +16 -29
  8. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +59 -0
  9. nv_ingest_api/internal/primitives/nim/model_interface/{paddle.py → ocr.py} +132 -39
  10. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +37 -224
  11. nv_ingest_api/internal/primitives/nim/nim_client.py +55 -14
  12. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +6 -6
  13. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +6 -6
  14. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +5 -5
  15. nv_ingest_api/internal/transform/split_text.py +13 -8
  16. nv_ingest_api/util/image_processing/table_and_chart.py +97 -42
  17. nv_ingest_api/util/image_processing/transforms.py +16 -5
  18. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +1 -1
  19. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +51 -48
  20. {nv_ingest_api-2025.7.16.dev20250716.dist-info → nv_ingest_api-2025.7.18.dev20250718.dist-info}/METADATA +1 -1
  21. {nv_ingest_api-2025.7.16.dev20250716.dist-info → nv_ingest_api-2025.7.18.dev20250718.dist-info}/RECORD +24 -24
  22. {nv_ingest_api-2025.7.16.dev20250716.dist-info → nv_ingest_api-2025.7.18.dev20250718.dist-info}/WHEEL +0 -0
  23. {nv_ingest_api-2025.7.16.dev20250716.dist-info → nv_ingest_api-2025.7.18.dev20250718.dist-info}/licenses/LICENSE +0 -0
  24. {nv_ingest_api-2025.7.16.dev20250716.dist-info → nv_ingest_api-2025.7.18.dev20250718.dist-info}/top_level.txt +0 -0
@@ -12,15 +12,14 @@ from typing import List
12
12
  from typing import Optional
13
13
  from typing import Tuple
14
14
 
15
- import cv2
15
+ import backoff
16
16
  import numpy as np
17
- import packaging
17
+ import json
18
18
  import pandas as pd
19
- import torch
20
- import torchvision
21
19
 
22
20
  from nv_ingest_api.internal.primitives.nim import ModelInterface
23
- from nv_ingest_api.internal.primitives.nim.model_interface.helpers import get_model_name
21
+ import tritonclient.grpc as grpcclient
22
+ from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
24
23
  from nv_ingest_api.util.image_processing import scale_image_to_encoding_size
25
24
  from nv_ingest_api.util.image_processing.transforms import numpy_to_base64
26
25
 
@@ -35,15 +34,6 @@ YOLOX_PAGE_IMAGE_PREPROC_HEIGHT = 1024
35
34
  YOLOX_PAGE_IMAGE_PREPROC_WIDTH = 1024
36
35
  YOLOX_PAGE_IMAGE_FORMAT = os.getenv("YOLOX_PAGE_IMAGE_FORMAT", "PNG")
37
36
 
38
- # yolox-page-elements-v1 contants
39
- YOLOX_PAGE_V1_NUM_CLASSES = 4
40
- YOLOX_PAGE_V1_FINAL_SCORE = {"table": 0.48, "chart": 0.48}
41
- YOLOX_PAGE_V1_CLASS_LABELS = [
42
- "table",
43
- "chart",
44
- "title",
45
- ]
46
-
47
37
  # yolox-page-elements-v2 contants
48
38
  YOLOX_PAGE_V2_NUM_CLASSES = 4
49
39
  YOLOX_PAGE_V2_FINAL_SCORE = {"table": 0.1, "chart": 0.01, "infographic": 0.01}
@@ -63,11 +53,6 @@ YOLOX_GRAPHIC_MIN_SCORE = 0.1
63
53
  YOLOX_GRAPHIC_FINAL_SCORE = 0.0
64
54
  YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE = 512_000
65
55
 
66
- # TODO(Devin): Legacy items aren't working right for me. Double check these.
67
- LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT = 1024
68
- LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH = 1024
69
- YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT = 1024
70
- YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH = 1024
71
56
 
72
57
  YOLOX_GRAPHIC_CLASS_LABELS = [
73
58
  "chart_title",
@@ -111,8 +96,6 @@ class YoloxModelInterfaceBase(ModelInterface):
111
96
 
112
97
  def __init__(
113
98
  self,
114
- image_preproc_width: Optional[int] = None,
115
- image_preproc_height: Optional[int] = None,
116
99
  nim_max_image_size: Optional[int] = None,
117
100
  num_classes: Optional[int] = None,
118
101
  conf_threshold: Optional[float] = None,
@@ -126,8 +109,6 @@ class YoloxModelInterfaceBase(ModelInterface):
126
109
  Parameters
127
110
  ----------
128
111
  """
129
- self.image_preproc_width = image_preproc_width
130
- self.image_preproc_height = image_preproc_height
131
112
  self.nim_max_image_size = nim_max_image_size
132
113
  self.num_classes = num_classes
133
114
  self.conf_threshold = conf_threshold
@@ -199,6 +180,7 @@ class YoloxModelInterfaceBase(ModelInterface):
199
180
 
200
181
  # Helper functions to chunk a list into sublists of length up to chunk_size.
201
182
  def chunk_list(lst: list, chunk_size: int) -> List[list]:
183
+ chunk_size = max(1, chunk_size)
202
184
  return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
203
185
 
204
186
  def chunk_list_geometrically(lst: list, max_size: int) -> List[list]:
@@ -206,29 +188,28 @@ class YoloxModelInterfaceBase(ModelInterface):
206
188
  chunks = []
207
189
  i = 0
208
190
  while i < len(lst):
209
- chunk_size = min(2 ** int(log(len(lst) - i, 2)), max_size)
191
+ chunk_size = max(1, min(2 ** int(log(len(lst) - i, 2)), max_size))
210
192
  chunks.append(lst[i : i + chunk_size])
211
193
  i += chunk_size
212
194
  return chunks
213
195
 
214
196
  if protocol == "grpc":
215
- logger.debug("Formatting input for gRPC Yolox model")
216
- # Resize images for model input (Yolox expects 1024x1024).
217
- resized_images = [
218
- resize_image(image, (self.image_preproc_width, self.image_preproc_height)) for image in data["images"]
219
- ]
220
- # Chunk the resized images, the original images, and their shapes.
221
- resized_chunks = chunk_list_geometrically(resized_images, max_batch_size)
197
+ logger.debug("Formatting input for gRPC Yolox Ensemble model")
198
+ b64_images = [numpy_to_base64(image, format=YOLOX_PAGE_IMAGE_FORMAT) for image in data["images"]]
199
+ b64_chunks = chunk_list_geometrically(b64_images, max_batch_size)
222
200
  original_chunks = chunk_list_geometrically(data["images"], max_batch_size)
223
201
  shape_chunks = chunk_list_geometrically(data["original_image_shapes"], max_batch_size)
224
202
 
225
203
  batched_inputs = []
226
204
  formatted_batch_data = []
227
- for r_chunk, orig_chunk, shapes in zip(resized_chunks, original_chunks, shape_chunks):
228
- # Reorder axes from (B, H, W, C) to (B, C, H, W) as expected by the model.
229
- input_array = np.einsum("bijk->bkij", r_chunk).astype(np.float32)
230
- batched_inputs.append(input_array)
205
+ for b64_chunk, orig_chunk, shapes in zip(b64_chunks, original_chunks, shape_chunks):
206
+ input_array = np.array(b64_chunk, dtype=np.object_)
207
+ current_batch_size = input_array.shape[0]
208
+ single_threshold_pair = [self.conf_threshold, self.iou_threshold]
209
+ thresholds = np.tile(single_threshold_pair, (current_batch_size, 1)).astype(np.float32)
210
+ batched_inputs.append([input_array, thresholds])
231
211
  formatted_batch_data.append({"images": orig_chunk, "original_image_shapes": shapes})
212
+
232
213
  return batched_inputs, formatted_batch_data
233
214
 
234
215
  elif protocol == "http":
@@ -337,32 +318,20 @@ class YoloxModelInterfaceBase(ModelInterface):
337
318
  list[dict]
338
319
  A list of annotation dictionaries for each image in the batch.
339
320
  """
340
- original_image_shapes = kwargs.get("original_image_shapes", [])
341
-
342
321
  if protocol == "http":
343
322
  # For http, the output already has postprocessing applied. Skip to table/chart expansion.
344
323
  results = output
345
324
 
346
325
  elif protocol == "grpc":
326
+ results = []
347
327
  # For grpc, apply the same NIM postprocessing.
348
- pred = postprocess_model_prediction(
349
- output,
350
- self.num_classes,
351
- self.conf_threshold,
352
- self.iou_threshold,
353
- class_agnostic=False,
354
- )
355
- results = postprocess_results(
356
- pred,
357
- original_image_shapes,
358
- self.image_preproc_width,
359
- self.image_preproc_height,
360
- self.class_labels,
361
- min_score=self.min_score,
362
- )
363
-
328
+ for out in output:
329
+ if isinstance(out, bytes):
330
+ out = out.decode("utf-8")
331
+ if isinstance(out, dict):
332
+ continue
333
+ results.append(json.loads(out))
364
334
  inference_results = self.postprocess_annotations(results, **kwargs)
365
-
366
335
  return inference_results
367
336
 
368
337
  def postprocess_annotations(self, annotation_dicts, **kwargs):
@@ -396,22 +365,15 @@ class YoloxPageElementsModelInterface(YoloxModelInterfaceBase):
396
365
  An interface for handling inference with yolox-page-elements model, supporting both gRPC and HTTP protocols.
397
366
  """
398
367
 
399
- def __init__(self, yolox_model_name: str = "nemoretriever-page-elements-v2"):
368
+ def __init__(self):
400
369
  """
401
370
  Initialize the yolox-page-elements model interface.
402
371
  """
403
- if yolox_model_name.endswith("-v1"):
404
- num_classes = YOLOX_PAGE_V1_NUM_CLASSES
405
- final_score = YOLOX_PAGE_V1_FINAL_SCORE
406
- class_labels = YOLOX_PAGE_V1_CLASS_LABELS
407
- else:
408
- num_classes = YOLOX_PAGE_V2_NUM_CLASSES
409
- final_score = YOLOX_PAGE_V2_FINAL_SCORE
410
- class_labels = YOLOX_PAGE_V2_CLASS_LABELS
372
+ num_classes = YOLOX_PAGE_V2_NUM_CLASSES
373
+ final_score = YOLOX_PAGE_V2_FINAL_SCORE
374
+ class_labels = YOLOX_PAGE_V2_CLASS_LABELS
411
375
 
412
376
  super().__init__(
413
- image_preproc_width=YOLOX_PAGE_IMAGE_PREPROC_WIDTH,
414
- image_preproc_height=YOLOX_PAGE_IMAGE_PREPROC_HEIGHT,
415
377
  nim_max_image_size=YOLOX_PAGE_NIM_MAX_IMAGE_SIZE,
416
378
  num_classes=num_classes,
417
379
  conf_threshold=YOLOX_PAGE_CONF_THRESHOLD,
@@ -478,22 +440,11 @@ class YoloxGraphicElementsModelInterface(YoloxModelInterfaceBase):
478
440
  An interface for handling inference with yolox-graphic-elemenents model, supporting both gRPC and HTTP protocols.
479
441
  """
480
442
 
481
- def __init__(self, yolox_version: Optional[str] = None):
443
+ def __init__(self):
482
444
  """
483
445
  Initialize the yolox-graphic-elements model interface.
484
446
  """
485
- if yolox_version and (
486
- packaging.version.Version(yolox_version) >= packaging.version.Version("1.2.0-rc5") # gtc release
487
- ):
488
- image_preproc_width = YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH
489
- image_preproc_height = YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT
490
- else:
491
- image_preproc_width = LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH
492
- image_preproc_height = LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT
493
-
494
447
  super().__init__(
495
- image_preproc_width=image_preproc_width,
496
- image_preproc_height=image_preproc_height,
497
448
  nim_max_image_size=YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE,
498
449
  num_classes=YOLOX_GRAPHIC_NUM_CLASSES,
499
450
  conf_threshold=YOLOX_GRAPHIC_CONF_THRESHOLD,
@@ -551,8 +502,6 @@ class YoloxTableStructureModelInterface(YoloxModelInterfaceBase):
551
502
  Initialize the yolox-graphic-elements model interface.
552
503
  """
553
504
  super().__init__(
554
- image_preproc_width=YOLOX_TABLE_IMAGE_PREPROC_HEIGHT,
555
- image_preproc_height=YOLOX_TABLE_IMAGE_PREPROC_HEIGHT,
556
505
  nim_max_image_size=YOLOX_TABLE_NIM_MAX_IMAGE_SIZE,
557
506
  num_classes=YOLOX_TABLE_NUM_CLASSES,
558
507
  conf_threshold=YOLOX_TABLE_CONF_THRESHOLD,
@@ -600,144 +549,6 @@ class YoloxTableStructureModelInterface(YoloxModelInterfaceBase):
600
549
  return inference_results
601
550
 
602
551
 
603
- def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
604
- # Convert numpy array to torch tensor
605
- prediction = torch.from_numpy(prediction.copy())
606
-
607
- # Compute box corners
608
- box_corner = prediction.new(prediction.shape)
609
- box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
610
- box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
611
- box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
612
- box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
613
- prediction[:, :, :4] = box_corner[:, :, :4]
614
-
615
- output = [None for _ in range(len(prediction))]
616
-
617
- for i, image_pred in enumerate(prediction):
618
- # If no detections, continue to the next image
619
- if not image_pred.size(0):
620
- continue
621
-
622
- # Ensure image_pred is 2D
623
- if image_pred.ndim == 1:
624
- image_pred = image_pred.unsqueeze(0)
625
-
626
- # Get score and class with highest confidence
627
- class_conf, class_pred = torch.max(image_pred[:, 5 : 5 + num_classes], 1, keepdim=True)
628
-
629
- # Confidence mask
630
- squeezed_conf = class_conf.squeeze(dim=1)
631
- conf_mask = image_pred[:, 4] * squeezed_conf >= conf_thre
632
-
633
- # Apply confidence mask
634
- detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
635
- detections = detections[conf_mask]
636
-
637
- if not detections.size(0):
638
- continue
639
-
640
- # Apply Non-Maximum Suppression (NMS)
641
- if class_agnostic:
642
- nms_out_index = torchvision.ops.nms(
643
- detections[:, :4],
644
- detections[:, 4] * detections[:, 5],
645
- nms_thre,
646
- )
647
- else:
648
- nms_out_index = torchvision.ops.batched_nms(
649
- detections[:, :4],
650
- detections[:, 4] * detections[:, 5],
651
- detections[:, 6],
652
- nms_thre,
653
- )
654
- detections = detections[nms_out_index]
655
-
656
- # Append detections to output
657
- output[i] = detections
658
-
659
- return output
660
-
661
-
662
- def postprocess_results(
663
- results, original_image_shapes, image_preproc_width, image_preproc_height, class_labels, min_score=0.0
664
- ):
665
- """
666
- For each item (==image) in results, computes annotations in the form
667
-
668
- {"table": [[0.0107, 0.0859, 0.7537, 0.1219, 0.9861], ...],
669
- "figure": [...],
670
- "title": [...]
671
- }
672
- where each list of 5 floats represents a bounding box in the format [x1, y1, x2, y2, confidence]
673
-
674
- Keep only bboxes with high enough confidence.
675
- """
676
- out = []
677
-
678
- for original_image_shape, result in zip(original_image_shapes, results):
679
- annotation_dict = {label: [] for label in class_labels}
680
-
681
- if result is None:
682
- out.append(annotation_dict)
683
- continue
684
-
685
- try:
686
- result = result.cpu().numpy()
687
- scores = result[:, 4] * result[:, 5]
688
- result = result[scores > min_score]
689
-
690
- # ratio is used when image was padded
691
- ratio = min(
692
- image_preproc_width / original_image_shape[0],
693
- image_preproc_height / original_image_shape[1],
694
- )
695
- bboxes = result[:, :4] / ratio
696
-
697
- bboxes[:, [0, 2]] /= original_image_shape[1]
698
- bboxes[:, [1, 3]] /= original_image_shape[0]
699
- bboxes = np.clip(bboxes, 0.0, 1.0)
700
-
701
- labels = result[:, 6]
702
- scores = scores[scores > min_score]
703
- except Exception as e:
704
- raise ValueError(f"Error in postprocessing {result.shape} and {original_image_shape}: {e}")
705
-
706
- for box, score, label in zip(bboxes, scores, labels):
707
- # TODO(Devin): Sometimes we get back unexpected class labels?
708
- if (label < 0) or (label >= len(class_labels)):
709
- logger.warning(f"Invalid class label {label} found in postprocessing")
710
- continue
711
- else:
712
- class_name = class_labels[int(label)]
713
-
714
- annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))])
715
-
716
- out.append(annotation_dict)
717
-
718
- return out
719
-
720
-
721
- def resize_image(image, target_img_size):
722
- w, h, _ = np.array(image).shape
723
-
724
- if target_img_size is not None: # Resize + Pad
725
- r = min(target_img_size[0] / w, target_img_size[1] / h)
726
- image = cv2.resize(
727
- image,
728
- (int(h * r), int(w * r)),
729
- interpolation=cv2.INTER_LINEAR,
730
- ).astype(np.uint8)
731
- image = np.pad(
732
- image,
733
- ((0, target_img_size[0] - image.shape[0]), (0, target_img_size[1] - image.shape[1]), (0, 0)),
734
- mode="constant",
735
- constant_values=114,
736
- )
737
-
738
- return image
739
-
740
-
741
552
  def expand_table_bboxes(annotation_dict, labels=None):
742
553
  """
743
554
  Additional preprocessing for tables: extend the upper bounds to capture titles if any.
@@ -1383,14 +1194,16 @@ def get_bbox_dict_yolox_table(preds, shape, class_labels, threshold=0.1, delta=0
1383
1194
  return bbox_dict
1384
1195
 
1385
1196
 
1386
- def get_yolox_model_name(yolox_http_endpoint, default_model_name="nemoretriever-page-elements-v2"):
1197
+ @multiprocessing_cache(max_calls=100) # Cache results first to avoid redundant retries from backoff
1198
+ @backoff.on_predicate(backoff.expo, max_time=30)
1199
+ def get_yolox_model_name(yolox_grpc_endpoint, default_model_name="yolox"):
1387
1200
  try:
1388
- yolox_model_name = get_model_name(yolox_http_endpoint, default_model_name)
1389
- if not yolox_model_name:
1390
- logger.warning(
1391
- "Failed to obtain yolox-page-elements model name from the endpoint. "
1392
- f"Falling back to '{default_model_name}'."
1393
- )
1201
+ client = grpcclient.InferenceServerClient(yolox_grpc_endpoint)
1202
+ model_index = client.get_model_repository_index(as_json=True)
1203
+ model_names = [x["name"] for x in model_index.get("models", [])]
1204
+ if "yolox_ensemble" in model_names:
1205
+ yolox_model_name = "yolox_ensemble"
1206
+ else:
1394
1207
  yolox_model_name = default_model_name
1395
1208
  except Exception:
1396
1209
  logger.warning(
@@ -8,7 +8,7 @@ import time
8
8
  from concurrent.futures import ThreadPoolExecutor
9
9
  from typing import Any
10
10
  from typing import Optional
11
- from typing import Tuple
11
+ from typing import Tuple, Union
12
12
 
13
13
  import numpy as np
14
14
  import requests
@@ -33,6 +33,7 @@ class NimClient:
33
33
  auth_token: Optional[str] = None,
34
34
  timeout: float = 120.0,
35
35
  max_retries: int = 5,
36
+ max_429_retries: int = 5,
36
37
  ):
37
38
  """
38
39
  Initialize the NimClient with the specified model interface, protocol, and server endpoints.
@@ -49,6 +50,10 @@ class NimClient:
49
50
  Authorization token for HTTP requests (default: None).
50
51
  timeout : float, optional
51
52
  Timeout for HTTP requests in seconds (default: 30.0).
53
+ max_retries : int, optional
54
+ The maximum number of retries for non-429 server-side errors (default: 5).
55
+ max_429_retries : int, optional
56
+ The maximum number of retries specifically for 429 errors (default: 10).
52
57
 
53
58
  Raises
54
59
  ------
@@ -62,6 +67,7 @@ class NimClient:
62
67
  self.auth_token = auth_token
63
68
  self.timeout = timeout # Timeout for HTTP requests
64
69
  self.max_retries = max_retries
70
+ self.max_429_retries = max_429_retries
65
71
  self._grpc_endpoint, self._http_endpoint = endpoints
66
72
  self._max_batch_sizes = {}
67
73
  self._lock = threading.Lock()
@@ -84,6 +90,10 @@ class NimClient:
84
90
 
85
91
  def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
86
92
  """Fetch the maximum batch size from the Triton model configuration in a thread-safe manner."""
93
+
94
+ if model_name == "yolox_ensemble":
95
+ model_name = "yolox"
96
+
87
97
  if model_name in self._max_batch_sizes:
88
98
  return self._max_batch_sizes[model_name]
89
99
 
@@ -138,7 +148,9 @@ class NimClient:
138
148
  else:
139
149
  raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
140
150
 
141
- parsed_output = self.model_interface.parse_output(response, protocol=self.protocol, data=batch_data, **kwargs)
151
+ parsed_output = self.model_interface.parse_output(
152
+ response, protocol=self.protocol, data=batch_data, model_name=model_name, **kwargs
153
+ )
142
154
  return parsed_output, batch_data
143
155
 
144
156
  def try_set_max_batch_size(self, model_name, model_version: str = ""):
@@ -167,10 +179,10 @@ class NimClient:
167
179
  try:
168
180
  # 1. Retrieve or default to the model's maximum batch size.
169
181
  batch_size = self._fetch_max_batch_size(model_name)
170
- max_requested_batch_size = kwargs.get("max_batch_size", batch_size)
171
- force_requested_batch_size = kwargs.get("force_max_batch_size", False)
182
+ max_requested_batch_size = kwargs.pop("max_batch_size", batch_size)
183
+ force_requested_batch_size = kwargs.pop("force_max_batch_size", False)
172
184
  max_batch_size = (
173
- min(batch_size, max_requested_batch_size)
185
+ max(1, min(batch_size, max_requested_batch_size))
174
186
  if not force_requested_batch_size
175
187
  else max_requested_batch_size
176
188
  )
@@ -180,7 +192,11 @@ class NimClient:
180
192
 
181
193
  # 3. Format the input based on protocol.
182
194
  formatted_batches, formatted_batch_data = self.model_interface.format_input(
183
- data, protocol=self.protocol, max_batch_size=max_batch_size, model_name=model_name
195
+ data,
196
+ protocol=self.protocol,
197
+ max_batch_size=max_batch_size,
198
+ model_name=model_name,
199
+ **kwargs,
184
200
  )
185
201
 
186
202
  # Check for a custom maximum pool worker count, and remove it from kwargs.
@@ -221,7 +237,9 @@ class NimClient:
221
237
 
222
238
  return all_results
223
239
 
224
- def _grpc_infer(self, formatted_input: np.ndarray, model_name: str, **kwargs) -> np.ndarray:
240
+ def _grpc_infer(
241
+ self, formatted_input: Union[list, list[np.ndarray]], model_name: str, **kwargs
242
+ ) -> Union[list, list[np.ndarray]]:
225
243
  """
226
244
  Perform inference using the gRPC protocol.
227
245
 
@@ -237,19 +255,27 @@ class NimClient:
237
255
  np.ndarray
238
256
  The output of the model as a numpy array.
239
257
  """
258
+ if not isinstance(formatted_input, list):
259
+ formatted_input = [formatted_input]
240
260
 
241
261
  parameters = kwargs.get("parameters", {})
242
- output_names = kwargs.get("outputs", ["output"])
243
- dtype = kwargs.get("dtype", "FP32")
244
- input_name = kwargs.get("input_name", "input")
262
+ output_names = kwargs.get("output_names", ["output"])
263
+ dtypes = kwargs.get("dtypes", ["FP32"])
264
+ input_names = kwargs.get("input_names", ["input"])
245
265
 
246
- input_tensors = grpcclient.InferInput(input_name, formatted_input.shape, datatype=dtype)
247
- input_tensors.set_data_from_numpy(formatted_input)
266
+ input_tensors = []
267
+ for input_name, input_data, dtype in zip(input_names, formatted_input, dtypes):
268
+ input_tensors.append(grpcclient.InferInput(input_name, input_data.shape, datatype=dtype))
269
+
270
+ for idx, input_data in enumerate(formatted_input):
271
+ input_tensors[idx].set_data_from_numpy(input_data)
248
272
 
249
273
  outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
274
+
250
275
  response = self.client.infer(
251
- model_name=model_name, parameters=parameters, inputs=[input_tensors], outputs=outputs
276
+ model_name=model_name, parameters=parameters, inputs=input_tensors, outputs=outputs
252
277
  )
278
+
253
279
  logger.debug(f"gRPC inference response: {response}")
254
280
 
255
281
  if len(outputs) == 1:
@@ -281,6 +307,7 @@ class NimClient:
281
307
 
282
308
  base_delay = 2.0
283
309
  attempt = 0
310
+ retries_429 = 0
284
311
 
285
312
  while attempt < self.max_retries:
286
313
  try:
@@ -291,7 +318,21 @@ class NimClient:
291
318
 
292
319
  # Check for server-side or rate-limit type errors
293
320
  # e.g. 5xx => server error, 429 => too many requests
294
- if status_code == 429 or status_code == 503 or (500 <= status_code < 600):
321
+ if status_code == 429:
322
+ retries_429 += 1
323
+ logger.warning(
324
+ f"Received HTTP 429 (Too Many Requests) from {self.model_interface.name()}. "
325
+ f"Attempt {retries_429} of {self.max_429_retries}."
326
+ )
327
+ if retries_429 >= self.max_429_retries:
328
+ logger.error("Max retries for HTTP 429 exceeded.")
329
+ response.raise_for_status()
330
+ else:
331
+ backoff_time = base_delay * (2**retries_429)
332
+ time.sleep(backoff_time)
333
+ continue # Retry without incrementing the main attempt counter
334
+
335
+ if status_code == 503 or (500 <= status_code < 600):
295
336
  logger.warning(
296
337
  f"Received HTTP {status_code} ({response.reason}) from "
297
338
  f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
@@ -24,8 +24,8 @@ class ChartExtractorConfigSchema(BaseModel):
24
24
  A tuple containing the gRPC and HTTP services for the yolox endpoint.
25
25
  Either the gRPC or HTTP service can be empty, but not both.
26
26
 
27
- paddle_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
28
- A tuple containing the gRPC and HTTP services for the paddle endpoint.
27
+ ocr_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
28
+ A tuple containing the gRPC and HTTP services for the ocr endpoint.
29
29
  Either the gRPC or HTTP service can be empty, but not both.
30
30
 
31
31
  Methods
@@ -49,8 +49,8 @@ class ChartExtractorConfigSchema(BaseModel):
49
49
  yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
50
50
  yolox_infer_protocol: str = ""
51
51
 
52
- paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
53
- paddle_infer_protocol: str = ""
52
+ ocr_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
53
+ ocr_infer_protocol: str = ""
54
54
 
55
55
  nim_batch_size: int = 2
56
56
  workers_per_progress_engine: int = 5
@@ -86,7 +86,7 @@ class ChartExtractorConfigSchema(BaseModel):
86
86
  return None
87
87
  return service
88
88
 
89
- for endpoint_name in ["yolox_endpoints", "paddle_endpoints"]:
89
+ for endpoint_name in ["yolox_endpoints", "ocr_endpoints"]:
90
90
  grpc_service, http_service = values.get(endpoint_name, (None, None))
91
91
  grpc_service = clean_service(grpc_service)
92
92
  http_service = clean_service(http_service)
@@ -117,7 +117,7 @@ class ChartExtractorSchema(BaseModel):
117
117
  A flag indicating whether to raise an exception if a failure occurs during chart extraction.
118
118
 
119
119
  extraction_config: Optional[ChartExtractorConfigSchema], default=None
120
- Configuration for the chart extraction stage, including yolox and paddle service endpoints.
120
+ Configuration for the chart extraction stage, including yolox and ocr service endpoints.
121
121
  """
122
122
 
123
123
  max_queue_size: int = 1
@@ -20,8 +20,8 @@ class InfographicExtractorConfigSchema(BaseModel):
20
20
  auth_token : Optional[str], default=None
21
21
  Authentication token required for secure services.
22
22
 
23
- paddle_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
24
- A tuple containing the gRPC and HTTP services for the paddle endpoint.
23
+ ocr_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
24
+ A tuple containing the gRPC and HTTP services for the ocr endpoint.
25
25
  Either the gRPC or HTTP service can be empty, but not both.
26
26
 
27
27
  Methods
@@ -42,8 +42,8 @@ class InfographicExtractorConfigSchema(BaseModel):
42
42
 
43
43
  auth_token: Optional[str] = None
44
44
 
45
- paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
46
- paddle_infer_protocol: str = ""
45
+ ocr_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
46
+ ocr_infer_protocol: str = ""
47
47
 
48
48
  nim_batch_size: int = 2
49
49
  workers_per_progress_engine: int = 5
@@ -79,7 +79,7 @@ class InfographicExtractorConfigSchema(BaseModel):
79
79
  return None
80
80
  return service
81
81
 
82
- for endpoint_name in ["paddle_endpoints"]:
82
+ for endpoint_name in ["ocr_endpoints"]:
83
83
  grpc_service, http_service = values.get(endpoint_name, (None, None))
84
84
  grpc_service = clean_service(grpc_service)
85
85
  http_service = clean_service(http_service)
@@ -110,7 +110,7 @@ class InfographicExtractorSchema(BaseModel):
110
110
  A flag indicating whether to raise an exception if a failure occurs during infographic extraction.
111
111
 
112
112
  stage_config : Optional[InfographicExtractorConfigSchema], default=None
113
- Configuration for the infographic extraction stage, including yolox and paddle service endpoints.
113
+ Configuration for the infographic extraction stage, including yolox and ocr service endpoints.
114
114
  """
115
115
 
116
116
  max_queue_size: int = 1
@@ -22,8 +22,8 @@ class TableExtractorConfigSchema(BaseModel):
22
22
  auth_token : Optional[str], default=None
23
23
  Authentication token required for secure services.
24
24
 
25
- paddle_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
26
- A tuple containing the gRPC and HTTP services for the paddle endpoint.
25
+ ocr_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
26
+ A tuple containing the gRPC and HTTP services for the ocr endpoint.
27
27
  Either the gRPC or HTTP service can be empty, but not both.
28
28
 
29
29
  Methods
@@ -47,8 +47,8 @@ class TableExtractorConfigSchema(BaseModel):
47
47
  yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
48
48
  yolox_infer_protocol: str = ""
49
49
 
50
- paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
51
- paddle_infer_protocol: str = ""
50
+ ocr_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
51
+ ocr_infer_protocol: str = ""
52
52
 
53
53
  nim_batch_size: int = 2
54
54
  workers_per_progress_engine: int = 5
@@ -81,7 +81,7 @@ class TableExtractorConfigSchema(BaseModel):
81
81
  return None
82
82
  return service
83
83
 
84
- for endpoint_name in ["yolox_endpoints", "paddle_endpoints"]:
84
+ for endpoint_name in ["yolox_endpoints", "ocr_endpoints"]:
85
85
  grpc_service, http_service = values.get(endpoint_name, (None, None))
86
86
  grpc_service = clean_service(grpc_service)
87
87
  http_service = clean_service(http_service)