nv-ingest-api 25.7.7.dev20250707__py3-none-any.whl → 25.8.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic.

Files changed (33)
  1. nv_ingest_api/interface/extract.py +18 -18
  2. nv_ingest_api/internal/enums/common.py +6 -0
  3. nv_ingest_api/internal/extract/image/chart_extractor.py +80 -75
  4. nv_ingest_api/internal/extract/image/image_helpers/common.py +5 -6
  5. nv_ingest_api/internal/extract/image/infographic_extractor.py +59 -35
  6. nv_ingest_api/internal/extract/image/table_extractor.py +84 -64
  7. nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +9 -8
  8. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +32 -20
  9. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +40 -29
  10. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +59 -0
  11. nv_ingest_api/internal/primitives/nim/model_interface/nemoretriever_parse.py +1 -0
  12. nv_ingest_api/internal/primitives/nim/model_interface/{paddle.py → ocr.py} +132 -39
  13. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +44 -236
  14. nv_ingest_api/internal/primitives/nim/nim_client.py +61 -18
  15. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +6 -6
  16. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +6 -6
  17. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +5 -5
  18. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +5 -0
  19. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +1 -1
  20. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +4 -0
  21. nv_ingest_api/internal/transform/embed_text.py +103 -12
  22. nv_ingest_api/internal/transform/split_text.py +13 -8
  23. nv_ingest_api/util/image_processing/table_and_chart.py +97 -42
  24. nv_ingest_api/util/image_processing/transforms.py +351 -87
  25. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +1 -1
  26. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +51 -48
  27. nv_ingest_api/util/metadata/aggregators.py +4 -1
  28. nv_ingest_api/util/pdf/pdfium.py +6 -14
  29. {nv_ingest_api-25.7.7.dev20250707.dist-info → nv_ingest_api-25.8.0rc1.dist-info}/METADATA +2 -1
  30. {nv_ingest_api-25.7.7.dev20250707.dist-info → nv_ingest_api-25.8.0rc1.dist-info}/RECORD +33 -33
  31. {nv_ingest_api-25.7.7.dev20250707.dist-info → nv_ingest_api-25.8.0rc1.dist-info}/WHEEL +0 -0
  32. {nv_ingest_api-25.7.7.dev20250707.dist-info → nv_ingest_api-25.8.0rc1.dist-info}/licenses/LICENSE +0 -0
  33. {nv_ingest_api-25.7.7.dev20250707.dist-info → nv_ingest_api-25.8.0rc1.dist-info}/top_level.txt +0 -0
nv_ingest_api/internal/primitives/nim/model_interface/yolox.py

@@ -2,9 +2,7 @@
 # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-
-import base64
-import io
+import os
 import logging
 import warnings
 from math import log
@@ -14,17 +12,16 @@ from typing import List
 from typing import Optional
 from typing import Tuple
 
-import cv2
+import backoff
 import numpy as np
-import packaging
+import json
 import pandas as pd
-import torch
-import torchvision
-from PIL import Image
 
 from nv_ingest_api.internal.primitives.nim import ModelInterface
-from nv_ingest_api.internal.primitives.nim.model_interface.helpers import get_model_name
+import tritonclient.grpc as grpcclient
+from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
 from nv_ingest_api.util.image_processing import scale_image_to_encoding_size
+from nv_ingest_api.util.image_processing.transforms import numpy_to_base64
 
 logger = logging.getLogger(__name__)
 
@@ -35,15 +32,7 @@ YOLOX_PAGE_MIN_SCORE = 0.1
 YOLOX_PAGE_NIM_MAX_IMAGE_SIZE = 512_000
 YOLOX_PAGE_IMAGE_PREPROC_HEIGHT = 1024
 YOLOX_PAGE_IMAGE_PREPROC_WIDTH = 1024
-
-# yolox-page-elements-v1 contants
-YOLOX_PAGE_V1_NUM_CLASSES = 4
-YOLOX_PAGE_V1_FINAL_SCORE = {"table": 0.48, "chart": 0.48}
-YOLOX_PAGE_V1_CLASS_LABELS = [
-    "table",
-    "chart",
-    "title",
-]
+YOLOX_PAGE_IMAGE_FORMAT = os.getenv("YOLOX_PAGE_IMAGE_FORMAT", "PNG")
 
 # yolox-page-elements-v2 contants
 YOLOX_PAGE_V2_NUM_CLASSES = 4
@@ -64,11 +53,6 @@ YOLOX_GRAPHIC_MIN_SCORE = 0.1
 YOLOX_GRAPHIC_FINAL_SCORE = 0.0
 YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE = 512_000
 
-# TODO(Devin): Legacy items aren't working right for me. Double check these.
-LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT = 1024
-LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH = 1024
-YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT = 1024
-YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH = 1024
 
 YOLOX_GRAPHIC_CLASS_LABELS = [
     "chart_title",
@@ -112,8 +96,6 @@ class YoloxModelInterfaceBase(ModelInterface):
 
     def __init__(
         self,
-        image_preproc_width: Optional[int] = None,
-        image_preproc_height: Optional[int] = None,
         nim_max_image_size: Optional[int] = None,
         num_classes: Optional[int] = None,
         conf_threshold: Optional[float] = None,
@@ -127,8 +109,6 @@ class YoloxModelInterfaceBase(ModelInterface):
         Parameters
         ----------
         """
-        self.image_preproc_width = image_preproc_width
-        self.image_preproc_height = image_preproc_height
         self.nim_max_image_size = nim_max_image_size
         self.num_classes = num_classes
         self.conf_threshold = conf_threshold
@@ -200,6 +180,7 @@ class YoloxModelInterfaceBase(ModelInterface):
 
         # Helper functions to chunk a list into sublists of length up to chunk_size.
         def chunk_list(lst: list, chunk_size: int) -> List[list]:
+            chunk_size = max(1, chunk_size)
             return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
 
         def chunk_list_geometrically(lst: list, max_size: int) -> List[list]:
@@ -207,29 +188,28 @@ class YoloxModelInterfaceBase(ModelInterface):
             chunks = []
             i = 0
             while i < len(lst):
-                chunk_size = min(2 ** int(log(len(lst) - i, 2)), max_size)
+                chunk_size = max(1, min(2 ** int(log(len(lst) - i, 2)), max_size))
                 chunks.append(lst[i : i + chunk_size])
                 i += chunk_size
             return chunks
 
         if protocol == "grpc":
-            logger.debug("Formatting input for gRPC Yolox model")
-            # Resize images for model input (Yolox expects 1024x1024).
-            resized_images = [
-                resize_image(image, (self.image_preproc_width, self.image_preproc_height)) for image in data["images"]
-            ]
-            # Chunk the resized images, the original images, and their shapes.
-            resized_chunks = chunk_list_geometrically(resized_images, max_batch_size)
+            logger.debug("Formatting input for gRPC Yolox Ensemble model")
+            b64_images = [numpy_to_base64(image, format=YOLOX_PAGE_IMAGE_FORMAT) for image in data["images"]]
+            b64_chunks = chunk_list_geometrically(b64_images, max_batch_size)
             original_chunks = chunk_list_geometrically(data["images"], max_batch_size)
             shape_chunks = chunk_list_geometrically(data["original_image_shapes"], max_batch_size)
 
             batched_inputs = []
             formatted_batch_data = []
-            for r_chunk, orig_chunk, shapes in zip(resized_chunks, original_chunks, shape_chunks):
-                # Reorder axes from (B, H, W, C) to (B, C, H, W) as expected by the model.
-                input_array = np.einsum("bijk->bkij", r_chunk).astype(np.float32)
-                batched_inputs.append(input_array)
+            for b64_chunk, orig_chunk, shapes in zip(b64_chunks, original_chunks, shape_chunks):
+                input_array = np.array(b64_chunk, dtype=np.object_)
+                current_batch_size = input_array.shape[0]
+                single_threshold_pair = [self.conf_threshold, self.iou_threshold]
+                thresholds = np.tile(single_threshold_pair, (current_batch_size, 1)).astype(np.float32)
+                batched_inputs.append([input_array, thresholds])
                 formatted_batch_data.append({"images": orig_chunk, "original_image_shapes": shapes})
+
             return batched_inputs, formatted_batch_data
 
         elif protocol == "http":
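
Note: with this hunk the gRPC path stops shipping resized pixel tensors altogether. Each batch becomes a pair of Triton inputs: a 1-D object array of base64-encoded images plus a (batch, 2) float32 array of [conf_threshold, iou_threshold] pairs. A minimal sketch of the resulting layout (the base64 strings and threshold values below are illustrative placeholders, not values from the source):

    import numpy as np

    b64_chunk = ["iVBORw0KGgo...", "iVBORw0KGgo..."]  # base64-encoded page images
    conf_threshold, iou_threshold = 0.01, 0.5  # placeholder thresholds

    # First input: one base64 string per image, sent as a Triton BYTES tensor.
    images_input = np.array(b64_chunk, dtype=np.object_)  # shape (B,)
    # Second input: per-image [conf, iou] threshold pairs.
    thresholds = np.tile(
        [conf_threshold, iou_threshold], (images_input.shape[0], 1)
    ).astype(np.float32)  # shape (B, 2)

    batched_input = [images_input, thresholds]  # one entry per named model input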
@@ -239,15 +219,11 @@ class YoloxModelInterfaceBase(ModelInterface):
                 # Convert to uint8 if needed.
                 if image.dtype != np.uint8:
                     image = (image * 255).astype(np.uint8)
-                # Convert the numpy array to a PIL Image.
-                image_pil = Image.fromarray(image)
-                original_size = image_pil.size
-
-                # Save the image to a buffer and encode to base64.
-                buffered = io.BytesIO()
-                image_pil.save(buffered, format="PNG")
-                image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
+                # Get original size directly from numpy array (width, height)
+                original_size = (image.shape[1], image.shape[0])
+                # Convert numpy array directly to base64 using OpenCV
+                image_b64 = numpy_to_base64(image, format=YOLOX_PAGE_IMAGE_FORMAT)
                 # Scale the image if necessary.
                 scaled_image_b64, new_size = scale_image_to_encoding_size(
                     image_b64, max_base64_size=self.nim_max_image_size
@@ -342,32 +318,20 @@ class YoloxModelInterfaceBase(ModelInterface):
         list[dict]
             A list of annotation dictionaries for each image in the batch.
         """
-        original_image_shapes = kwargs.get("original_image_shapes", [])
-
         if protocol == "http":
             # For http, the output already has postprocessing applied. Skip to table/chart expansion.
             results = output
 
         elif protocol == "grpc":
+            results = []
             # For grpc, apply the same NIM postprocessing.
-            pred = postprocess_model_prediction(
-                output,
-                self.num_classes,
-                self.conf_threshold,
-                self.iou_threshold,
-                class_agnostic=False,
-            )
-            results = postprocess_results(
-                pred,
-                original_image_shapes,
-                self.image_preproc_width,
-                self.image_preproc_height,
-                self.class_labels,
-                min_score=self.min_score,
-            )
-
+            for out in output:
+                if isinstance(out, bytes):
+                    out = out.decode("utf-8")
+                if isinstance(out, dict):
+                    continue
+                results.append(json.loads(out))
         inference_results = self.postprocess_annotations(results, **kwargs)
-
         return inference_results
 
     def postprocess_annotations(self, annotation_dicts, **kwargs):
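
Note: with the ensemble, the gRPC response carries JSON-serialized annotation dictionaries rather than raw detection tensors, which is why the torch-based postprocessing below can be deleted wholesale. A minimal sketch of the new decode path (the payload is an illustrative example, not captured output):

    import json

    raw_output = [b'{"table": [[0.0107, 0.0859, 0.7537, 0.1219, 0.9861]], "chart": []}']

    results = []
    for out in raw_output:
        if isinstance(out, bytes):  # Triton BYTES outputs arrive as raw bytes
            out = out.decode("utf-8")
        if isinstance(out, dict):  # entries that are already dicts are skipped
            continue
        results.append(json.loads(out))  # one annotation dict per image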
@@ -401,22 +365,15 @@ class YoloxPageElementsModelInterface(YoloxModelInterfaceBase):
     An interface for handling inference with yolox-page-elements model, supporting both gRPC and HTTP protocols.
     """
 
-    def __init__(self, yolox_model_name: str = "nemoretriever-page-elements-v2"):
+    def __init__(self):
         """
         Initialize the yolox-page-elements model interface.
         """
-        if yolox_model_name.endswith("-v1"):
-            num_classes = YOLOX_PAGE_V1_NUM_CLASSES
-            final_score = YOLOX_PAGE_V1_FINAL_SCORE
-            class_labels = YOLOX_PAGE_V1_CLASS_LABELS
-        else:
-            num_classes = YOLOX_PAGE_V2_NUM_CLASSES
-            final_score = YOLOX_PAGE_V2_FINAL_SCORE
-            class_labels = YOLOX_PAGE_V2_CLASS_LABELS
+        num_classes = YOLOX_PAGE_V2_NUM_CLASSES
+        final_score = YOLOX_PAGE_V2_FINAL_SCORE
+        class_labels = YOLOX_PAGE_V2_CLASS_LABELS
 
         super().__init__(
-            image_preproc_width=YOLOX_PAGE_IMAGE_PREPROC_WIDTH,
-            image_preproc_height=YOLOX_PAGE_IMAGE_PREPROC_HEIGHT,
             nim_max_image_size=YOLOX_PAGE_NIM_MAX_IMAGE_SIZE,
             num_classes=num_classes,
             conf_threshold=YOLOX_PAGE_CONF_THRESHOLD,
@@ -483,22 +440,11 @@ class YoloxGraphicElementsModelInterface(YoloxModelInterfaceBase):
     An interface for handling inference with yolox-graphic-elemenents model, supporting both gRPC and HTTP protocols.
     """
 
-    def __init__(self, yolox_version: Optional[str] = None):
+    def __init__(self):
         """
         Initialize the yolox-graphic-elements model interface.
         """
-        if yolox_version and (
-            packaging.version.Version(yolox_version) >= packaging.version.Version("1.2.0-rc5")  # gtc release
-        ):
-            image_preproc_width = YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH
-            image_preproc_height = YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT
-        else:
-            image_preproc_width = LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH
-            image_preproc_height = LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT
-
         super().__init__(
-            image_preproc_width=image_preproc_width,
-            image_preproc_height=image_preproc_height,
             nim_max_image_size=YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE,
             num_classes=YOLOX_GRAPHIC_NUM_CLASSES,
             conf_threshold=YOLOX_GRAPHIC_CONF_THRESHOLD,
@@ -556,8 +502,6 @@ class YoloxTableStructureModelInterface(YoloxModelInterfaceBase):
         Initialize the yolox-graphic-elements model interface.
         """
         super().__init__(
-            image_preproc_width=YOLOX_TABLE_IMAGE_PREPROC_HEIGHT,
-            image_preproc_height=YOLOX_TABLE_IMAGE_PREPROC_HEIGHT,
             nim_max_image_size=YOLOX_TABLE_NIM_MAX_IMAGE_SIZE,
             num_classes=YOLOX_TABLE_NUM_CLASSES,
             conf_threshold=YOLOX_TABLE_CONF_THRESHOLD,
@@ -605,144 +549,6 @@ class YoloxTableStructureModelInterface(YoloxModelInterfaceBase):
         return inference_results
 
 
-def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
-    # Convert numpy array to torch tensor
-    prediction = torch.from_numpy(prediction.copy())
-
-    # Compute box corners
-    box_corner = prediction.new(prediction.shape)
-    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
-    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
-    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
-    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
-    prediction[:, :, :4] = box_corner[:, :, :4]
-
-    output = [None for _ in range(len(prediction))]
-
-    for i, image_pred in enumerate(prediction):
-        # If no detections, continue to the next image
-        if not image_pred.size(0):
-            continue
-
-        # Ensure image_pred is 2D
-        if image_pred.ndim == 1:
-            image_pred = image_pred.unsqueeze(0)
-
-        # Get score and class with highest confidence
-        class_conf, class_pred = torch.max(image_pred[:, 5 : 5 + num_classes], 1, keepdim=True)
-
-        # Confidence mask
-        squeezed_conf = class_conf.squeeze(dim=1)
-        conf_mask = image_pred[:, 4] * squeezed_conf >= conf_thre
-
-        # Apply confidence mask
-        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
-        detections = detections[conf_mask]
-
-        if not detections.size(0):
-            continue
-
-        # Apply Non-Maximum Suppression (NMS)
-        if class_agnostic:
-            nms_out_index = torchvision.ops.nms(
-                detections[:, :4],
-                detections[:, 4] * detections[:, 5],
-                nms_thre,
-            )
-        else:
-            nms_out_index = torchvision.ops.batched_nms(
-                detections[:, :4],
-                detections[:, 4] * detections[:, 5],
-                detections[:, 6],
-                nms_thre,
-            )
-        detections = detections[nms_out_index]
-
-        # Append detections to output
-        output[i] = detections
-
-    return output
-
-
-def postprocess_results(
-    results, original_image_shapes, image_preproc_width, image_preproc_height, class_labels, min_score=0.0
-):
-    """
-    For each item (==image) in results, computes annotations in the form
-
-    {"table": [[0.0107, 0.0859, 0.7537, 0.1219, 0.9861], ...],
-     "figure": [...],
-     "title": [...]
-    }
-    where each list of 5 floats represents a bounding box in the format [x1, y1, x2, y2, confidence]
-
-    Keep only bboxes with high enough confidence.
-    """
-    out = []
-
-    for original_image_shape, result in zip(original_image_shapes, results):
-        annotation_dict = {label: [] for label in class_labels}
-
-        if result is None:
-            out.append(annotation_dict)
-            continue
-
-        try:
-            result = result.cpu().numpy()
-            scores = result[:, 4] * result[:, 5]
-            result = result[scores > min_score]
-
-            # ratio is used when image was padded
-            ratio = min(
-                image_preproc_width / original_image_shape[0],
-                image_preproc_height / original_image_shape[1],
-            )
-            bboxes = result[:, :4] / ratio
-
-            bboxes[:, [0, 2]] /= original_image_shape[1]
-            bboxes[:, [1, 3]] /= original_image_shape[0]
-            bboxes = np.clip(bboxes, 0.0, 1.0)
-
-            labels = result[:, 6]
-            scores = scores[scores > min_score]
-        except Exception as e:
-            raise ValueError(f"Error in postprocessing {result.shape} and {original_image_shape}: {e}")
-
-        for box, score, label in zip(bboxes, scores, labels):
-            # TODO(Devin): Sometimes we get back unexpected class labels?
-            if (label < 0) or (label >= len(class_labels)):
-                logger.warning(f"Invalid class label {label} found in postprocessing")
-                continue
-            else:
-                class_name = class_labels[int(label)]
-
-            annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))])
-
-        out.append(annotation_dict)
-
-    return out
-
-
-def resize_image(image, target_img_size):
-    w, h, _ = np.array(image).shape
-
-    if target_img_size is not None:  # Resize + Pad
-        r = min(target_img_size[0] / w, target_img_size[1] / h)
-        image = cv2.resize(
-            image,
-            (int(h * r), int(w * r)),
-            interpolation=cv2.INTER_LINEAR,
-        ).astype(np.uint8)
-        image = np.pad(
-            image,
-            ((0, target_img_size[0] - image.shape[0]), (0, target_img_size[1] - image.shape[1]), (0, 0)),
-            mode="constant",
-            constant_values=114,
-        )
-
-    return image
-
-
 def expand_table_bboxes(annotation_dict, labels=None):
     """
     Additional preprocessing for tables: extend the upper bounds to capture titles if any.
@@ -1388,14 +1194,16 @@ def get_bbox_dict_yolox_table(preds, shape, class_labels, threshold=0.1, delta=0
     return bbox_dict
 
 
-def get_yolox_model_name(yolox_http_endpoint, default_model_name="nemoretriever-page-elements-v2"):
+@multiprocessing_cache(max_calls=100)  # Cache results first to avoid redundant retries from backoff
+@backoff.on_predicate(backoff.expo, max_time=30)
+def get_yolox_model_name(yolox_grpc_endpoint, default_model_name="yolox"):
     try:
-        yolox_model_name = get_model_name(yolox_http_endpoint, default_model_name)
-        if not yolox_model_name:
-            logger.warning(
-                "Failed to obtain yolox-page-elements model name from the endpoint. "
-                f"Falling back to '{default_model_name}'."
-            )
+        client = grpcclient.InferenceServerClient(yolox_grpc_endpoint)
+        model_index = client.get_model_repository_index(as_json=True)
+        model_names = [x["name"] for x in model_index.get("models", [])]
+        if "yolox_ensemble" in model_names:
+            yolox_model_name = "yolox_ensemble"
+        else:
             yolox_model_name = default_model_name
     except Exception:
         logger.warning(
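
Note: model-name discovery now probes the Triton gRPC model repository instead of an HTTP metadata endpoint, with a multiprocessing cache sitting in front of an exponential-backoff retry. A standalone sketch of the same probe (the endpoint value is a placeholder; InferenceServerClient and get_model_repository_index are standard tritonclient APIs):

    import tritonclient.grpc as grpcclient

    def probe_yolox_model_name(endpoint="localhost:8001", default="yolox"):
        # List the models Triton has loaded; prefer the ensemble when present.
        client = grpcclient.InferenceServerClient(endpoint)
        index = client.get_model_repository_index(as_json=True)
        names = [m["name"] for m in index.get("models", [])]
        return "yolox_ensemble" if "yolox_ensemble" in names else default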
nv_ingest_api/internal/primitives/nim/nim_client.py

@@ -5,10 +5,10 @@
 import logging
 import threading
 import time
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any
 from typing import Optional
-from typing import Tuple
+from typing import Tuple, Union
 
 import numpy as np
 import requests
@@ -33,6 +33,7 @@ class NimClient:
         auth_token: Optional[str] = None,
         timeout: float = 120.0,
         max_retries: int = 5,
+        max_429_retries: int = 5,
     ):
         """
         Initialize the NimClient with the specified model interface, protocol, and server endpoints.
@@ -49,6 +50,10 @@ class NimClient:
             Authorization token for HTTP requests (default: None).
         timeout : float, optional
             Timeout for HTTP requests in seconds (default: 30.0).
+        max_retries : int, optional
+            The maximum number of retries for non-429 server-side errors (default: 5).
+        max_429_retries : int, optional
+            The maximum number of retries specifically for 429 errors (default: 10).
 
         Raises
         ------
@@ -62,6 +67,7 @@ class NimClient:
         self.auth_token = auth_token
         self.timeout = timeout  # Timeout for HTTP requests
         self.max_retries = max_retries
+        self.max_429_retries = max_429_retries
         self._grpc_endpoint, self._http_endpoint = endpoints
         self._max_batch_sizes = {}
         self._lock = threading.Lock()
@@ -84,6 +90,10 @@ class NimClient:
 
     def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
         """Fetch the maximum batch size from the Triton model configuration in a thread-safe manner."""
+
+        if model_name == "yolox_ensemble":
+            model_name = "yolox"
+
         if model_name in self._max_batch_sizes:
             return self._max_batch_sizes[model_name]
 
@@ -138,7 +148,9 @@ class NimClient:
         else:
             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
 
-        parsed_output = self.model_interface.parse_output(response, protocol=self.protocol, data=batch_data, **kwargs)
+        parsed_output = self.model_interface.parse_output(
+            response, protocol=self.protocol, data=batch_data, model_name=model_name, **kwargs
+        )
         return parsed_output, batch_data
 
     def try_set_max_batch_size(self, model_name, model_version: str = ""):
@@ -167,10 +179,10 @@ class NimClient:
         try:
             # 1. Retrieve or default to the model's maximum batch size.
             batch_size = self._fetch_max_batch_size(model_name)
-            max_requested_batch_size = kwargs.get("max_batch_size", batch_size)
-            force_requested_batch_size = kwargs.get("force_max_batch_size", False)
+            max_requested_batch_size = kwargs.pop("max_batch_size", batch_size)
+            force_requested_batch_size = kwargs.pop("force_max_batch_size", False)
             max_batch_size = (
-                min(batch_size, max_requested_batch_size)
+                max(1, min(batch_size, max_requested_batch_size))
                 if not force_requested_batch_size
                 else max_requested_batch_size
             )
@@ -180,7 +192,11 @@ class NimClient:
 
             # 3. Format the input based on protocol.
             formatted_batches, formatted_batch_data = self.model_interface.format_input(
-                data, protocol=self.protocol, max_batch_size=max_batch_size, model_name=model_name
+                data,
+                protocol=self.protocol,
+                max_batch_size=max_batch_size,
+                model_name=model_name,
+                **kwargs,
             )
 
             # Check for a custom maximum pool worker count, and remove it from kwargs.
@@ -190,13 +206,15 @@ class NimClient:
             # We enumerate the batches so that we can later reassemble results in order.
             results = [None] * len(formatted_batches)
             with ThreadPoolExecutor(max_workers=max_pool_workers) as executor:
-                futures = []
+                future_to_idx = {}
                 for idx, (batch, batch_data) in enumerate(zip(formatted_batches, formatted_batch_data)):
                     future = executor.submit(
                         self._process_batch, batch, batch_data=batch_data, model_name=model_name, **kwargs
                     )
-                    futures.append((idx, future))
-                for idx, future in futures:
+                    future_to_idx[future] = idx
+
+                for future in as_completed(future_to_idx.keys()):
+                    idx = future_to_idx[future]
                     results[idx] = future.result()
 
             # 5. Process the parsed outputs for each batch using its corresponding batch_data.
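
Note: switching from an ordered future list to as_completed consumes batches as soon as they finish (and surfaces a failing batch immediately), while the future-to-index map still writes each result back into its submission slot. A self-contained sketch of the pattern:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def run_ordered(fn, batches, max_workers=4):
        results = [None] * len(batches)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Map each future back to the index of the batch it was submitted for.
            future_to_idx = {executor.submit(fn, b): i for i, b in enumerate(batches)}
            for future in as_completed(future_to_idx):
                results[future_to_idx[future]] = future.result()  # submission order preserved
        return results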
@@ -221,7 +239,9 @@ class NimClient:
 
         return all_results
 
-    def _grpc_infer(self, formatted_input: np.ndarray, model_name: str, **kwargs) -> np.ndarray:
+    def _grpc_infer(
+        self, formatted_input: Union[list, list[np.ndarray]], model_name: str, **kwargs
+    ) -> Union[list, list[np.ndarray]]:
         """
         Perform inference using the gRPC protocol.
 
@@ -237,19 +257,27 @@ class NimClient:
         np.ndarray
             The output of the model as a numpy array.
         """
+        if not isinstance(formatted_input, list):
+            formatted_input = [formatted_input]
 
         parameters = kwargs.get("parameters", {})
-        output_names = kwargs.get("outputs", ["output"])
-        dtype = kwargs.get("dtype", "FP32")
-        input_name = kwargs.get("input_name", "input")
+        output_names = kwargs.get("output_names", ["output"])
+        dtypes = kwargs.get("dtypes", ["FP32"])
+        input_names = kwargs.get("input_names", ["input"])
+
+        input_tensors = []
+        for input_name, input_data, dtype in zip(input_names, formatted_input, dtypes):
+            input_tensors.append(grpcclient.InferInput(input_name, input_data.shape, datatype=dtype))
 
-        input_tensors = grpcclient.InferInput(input_name, formatted_input.shape, datatype=dtype)
-        input_tensors.set_data_from_numpy(formatted_input)
+        for idx, input_data in enumerate(formatted_input):
+            input_tensors[idx].set_data_from_numpy(input_data)
 
         outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
+
         response = self.client.infer(
-            model_name=model_name, parameters=parameters, inputs=[input_tensors], outputs=outputs
+            model_name=model_name, parameters=parameters, inputs=input_tensors, outputs=outputs
         )
+
         logger.debug(f"gRPC inference response: {response}")
 
         if len(outputs) == 1:
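
Note: _grpc_infer now accepts a list of arrays and builds one InferInput per (name, tensor, dtype) triple, which the two-input yolox ensemble requires. A trimmed sketch of a multi-input Triton call (the argument values are caller-supplied; the tritonclient calls themselves are the standard API):

    import tritonclient.grpc as grpcclient

    def infer_multi(client, model_name, tensors, input_names, dtypes, output_names):
        inputs = []
        for name, data, dtype in zip(input_names, tensors, dtypes):
            tensor = grpcclient.InferInput(name, list(data.shape), datatype=dtype)
            tensor.set_data_from_numpy(data)  # attach the numpy payload
            inputs.append(tensor)
        outputs = [grpcclient.InferRequestedOutput(n) for n in output_names]
        response = client.infer(model_name=model_name, inputs=inputs, outputs=outputs)
        return [response.as_numpy(n) for n in output_names]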
@@ -281,6 +309,7 @@ class NimClient:
 
         base_delay = 2.0
         attempt = 0
+        retries_429 = 0
 
         while attempt < self.max_retries:
             try:
@@ -291,7 +320,21 @@ class NimClient:
 
                 # Check for server-side or rate-limit type errors
                 # e.g. 5xx => server error, 429 => too many requests
-                if status_code == 429 or status_code == 503 or (500 <= status_code < 600):
+                if status_code == 429:
+                    retries_429 += 1
+                    logger.warning(
+                        f"Received HTTP 429 (Too Many Requests) from {self.model_interface.name()}. "
+                        f"Attempt {retries_429} of {self.max_429_retries}."
+                    )
+                    if retries_429 >= self.max_429_retries:
+                        logger.error("Max retries for HTTP 429 exceeded.")
+                        response.raise_for_status()
+                    else:
+                        backoff_time = base_delay * (2**retries_429)
+                        time.sleep(backoff_time)
+                        continue  # Retry without incrementing the main attempt counter
+
+                if status_code == 503 or (500 <= status_code < 600):
                     logger.warning(
                         f"Received HTTP {status_code} ({response.reason}) from "
                         f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
nv_ingest_api/internal/schemas/extract/extract_chart_schema.py

@@ -24,8 +24,8 @@ class ChartExtractorConfigSchema(BaseModel):
         A tuple containing the gRPC and HTTP services for the yolox endpoint.
         Either the gRPC or HTTP service can be empty, but not both.
 
-    paddle_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
-        A tuple containing the gRPC and HTTP services for the paddle endpoint.
+    ocr_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
+        A tuple containing the gRPC and HTTP services for the ocr endpoint.
         Either the gRPC or HTTP service can be empty, but not both.
 
     Methods
@@ -49,8 +49,8 @@ class ChartExtractorConfigSchema(BaseModel):
     yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
     yolox_infer_protocol: str = ""
 
-    paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
-    paddle_infer_protocol: str = ""
+    ocr_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
+    ocr_infer_protocol: str = ""
 
     nim_batch_size: int = 2
     workers_per_progress_engine: int = 5
@@ -86,7 +86,7 @@ class ChartExtractorConfigSchema(BaseModel):
                 return None
             return service
 
-        for endpoint_name in ["yolox_endpoints", "paddle_endpoints"]:
+        for endpoint_name in ["yolox_endpoints", "ocr_endpoints"]:
             grpc_service, http_service = values.get(endpoint_name, (None, None))
             grpc_service = clean_service(grpc_service)
             http_service = clean_service(http_service)
@@ -117,7 +117,7 @@ class ChartExtractorSchema(BaseModel):
         A flag indicating whether to raise an exception if a failure occurs during chart extraction.
 
     extraction_config: Optional[ChartExtractorConfigSchema], default=None
-        Configuration for the chart extraction stage, including yolox and paddle service endpoints.
+        Configuration for the chart extraction stage, including yolox and ocr service endpoints.
     """
 
     max_queue_size: int = 1
nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py

@@ -20,8 +20,8 @@ class InfographicExtractorConfigSchema(BaseModel):
     auth_token : Optional[str], default=None
         Authentication token required for secure services.
 
-    paddle_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
-        A tuple containing the gRPC and HTTP services for the paddle endpoint.
+    ocr_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None)
+        A tuple containing the gRPC and HTTP services for the ocr endpoint.
         Either the gRPC or HTTP service can be empty, but not both.
 
     Methods
@@ -42,8 +42,8 @@ class InfographicExtractorConfigSchema(BaseModel):
 
     auth_token: Optional[str] = None
 
-    paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
-    paddle_infer_protocol: str = ""
+    ocr_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
+    ocr_infer_protocol: str = ""
 
     nim_batch_size: int = 2
     workers_per_progress_engine: int = 5
@@ -79,7 +79,7 @@ class InfographicExtractorConfigSchema(BaseModel):
                 return None
             return service
 
-        for endpoint_name in ["paddle_endpoints"]:
+        for endpoint_name in ["ocr_endpoints"]:
             grpc_service, http_service = values.get(endpoint_name, (None, None))
             grpc_service = clean_service(grpc_service)
             http_service = clean_service(http_service)
@@ -110,7 +110,7 @@ class InfographicExtractorSchema(BaseModel):
         A flag indicating whether to raise an exception if a failure occurs during infographic extraction.
 
     stage_config : Optional[InfographicExtractorConfigSchema], default=None
-        Configuration for the infographic extraction stage, including yolox and paddle service endpoints.
+        Configuration for the infographic extraction stage, including yolox and ocr service endpoints.
     """
 
     max_queue_size: int = 1