nv-ingest-api 2025.4.21.dev20250421__py3-none-any.whl → 2025.4.23.dev20250423__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (153)
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +215 -0
  3. nv_ingest_api/interface/extract.py +972 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +218 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +200 -0
  8. nv_ingest_api/internal/enums/__init__.py +3 -0
  9. nv_ingest_api/internal/enums/common.py +494 -0
  10. nv_ingest_api/internal/extract/__init__.py +3 -0
  11. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/audio_extraction.py +149 -0
  13. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  14. nv_ingest_api/internal/extract/docx/docx_extractor.py +205 -0
  15. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  16. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +122 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +895 -0
  19. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  20. nv_ingest_api/internal/extract/image/chart_extractor.py +353 -0
  21. nv_ingest_api/internal/extract/image/image_extractor.py +204 -0
  22. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/image_helpers/common.py +403 -0
  24. nv_ingest_api/internal/extract/image/infographic_extractor.py +253 -0
  25. nv_ingest_api/internal/extract/image/table_extractor.py +344 -0
  26. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  27. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  28. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  29. nv_ingest_api/internal/extract/pdf/engines/llama.py +243 -0
  30. nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +597 -0
  31. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +146 -0
  32. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +603 -0
  33. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  34. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  35. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  36. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  37. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  38. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +799 -0
  39. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +187 -0
  40. nv_ingest_api/internal/mutate/__init__.py +3 -0
  41. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  42. nv_ingest_api/internal/mutate/filter.py +133 -0
  43. nv_ingest_api/internal/primitives/__init__.py +0 -0
  44. nv_ingest_api/{primitives → internal/primitives}/control_message_task.py +4 -0
  45. nv_ingest_api/{primitives → internal/primitives}/ingest_control_message.py +5 -2
  46. nv_ingest_api/internal/primitives/nim/__init__.py +8 -0
  47. nv_ingest_api/internal/primitives/nim/default_values.py +15 -0
  48. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  49. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  50. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  51. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  52. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +275 -0
  53. nv_ingest_api/internal/primitives/nim/model_interface/nemoretriever_parse.py +238 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/paddle.py +462 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +132 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +152 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1400 -0
  59. nv_ingest_api/internal/primitives/nim/nim_client.py +344 -0
  60. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +81 -0
  61. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  62. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  63. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  64. nv_ingest_api/internal/primitives/tracing/tagging.py +197 -0
  65. nv_ingest_api/internal/schemas/__init__.py +3 -0
  66. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  67. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +130 -0
  68. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +135 -0
  69. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +124 -0
  70. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +124 -0
  71. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +128 -0
  72. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +218 -0
  73. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +124 -0
  74. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +129 -0
  75. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  76. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +23 -0
  77. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  78. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  79. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  80. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  81. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +237 -0
  82. nv_ingest_api/internal/schemas/meta/metadata_schema.py +221 -0
  83. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  85. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  86. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  87. nv_ingest_api/internal/schemas/store/store_image_schema.py +30 -0
  88. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  89. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +15 -0
  90. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  91. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +25 -0
  92. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +22 -0
  93. nv_ingest_api/internal/store/__init__.py +3 -0
  94. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  95. nv_ingest_api/internal/store/image_upload.py +232 -0
  96. nv_ingest_api/internal/transform/__init__.py +3 -0
  97. nv_ingest_api/internal/transform/caption_image.py +205 -0
  98. nv_ingest_api/internal/transform/embed_text.py +496 -0
  99. nv_ingest_api/internal/transform/split_text.py +157 -0
  100. nv_ingest_api/util/__init__.py +0 -0
  101. nv_ingest_api/util/control_message/__init__.py +0 -0
  102. nv_ingest_api/util/control_message/validators.py +47 -0
  103. nv_ingest_api/util/converters/__init__.py +0 -0
  104. nv_ingest_api/util/converters/bytetools.py +78 -0
  105. nv_ingest_api/util/converters/containers.py +65 -0
  106. nv_ingest_api/util/converters/datetools.py +90 -0
  107. nv_ingest_api/util/converters/dftools.py +127 -0
  108. nv_ingest_api/util/converters/formats.py +64 -0
  109. nv_ingest_api/util/converters/type_mappings.py +27 -0
  110. nv_ingest_api/util/detectors/__init__.py +5 -0
  111. nv_ingest_api/util/detectors/language.py +38 -0
  112. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  113. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  114. nv_ingest_api/util/exception_handlers/decorators.py +223 -0
  115. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  116. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  117. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  118. nv_ingest_api/util/image_processing/__init__.py +5 -0
  119. nv_ingest_api/util/image_processing/clustering.py +260 -0
  120. nv_ingest_api/util/image_processing/processing.py +179 -0
  121. nv_ingest_api/util/image_processing/table_and_chart.py +449 -0
  122. nv_ingest_api/util/image_processing/transforms.py +407 -0
  123. nv_ingest_api/util/logging/__init__.py +0 -0
  124. nv_ingest_api/util/logging/configuration.py +31 -0
  125. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  126. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  127. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  128. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  129. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +451 -0
  130. nv_ingest_api/util/metadata/__init__.py +5 -0
  131. nv_ingest_api/util/metadata/aggregators.py +469 -0
  132. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  133. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +194 -0
  134. nv_ingest_api/util/nim/__init__.py +56 -0
  135. nv_ingest_api/util/pdf/__init__.py +3 -0
  136. nv_ingest_api/util/pdf/pdfium.py +427 -0
  137. nv_ingest_api/util/schema/__init__.py +0 -0
  138. nv_ingest_api/util/schema/schema_validator.py +10 -0
  139. nv_ingest_api/util/service_clients/__init__.py +3 -0
  140. nv_ingest_api/util/service_clients/client_base.py +86 -0
  141. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  142. nv_ingest_api/util/service_clients/redis/__init__.py +0 -0
  143. nv_ingest_api/util/service_clients/redis/redis_client.py +823 -0
  144. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  145. nv_ingest_api/util/service_clients/rest/rest_client.py +531 -0
  146. nv_ingest_api/util/string_processing/__init__.py +51 -0
  147. {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/METADATA +1 -1
  148. nv_ingest_api-2025.4.23.dev20250423.dist-info/RECORD +152 -0
  149. {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/WHEEL +1 -1
  150. nv_ingest_api-2025.4.21.dev20250421.dist-info/RECORD +0 -9
  151. /nv_ingest_api/{primitives → internal}/__init__.py +0 -0
  152. {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/licenses/LICENSE +0 -0
  153. {nv_ingest_api-2025.4.21.dev20250421.dist-info → nv_ingest_api-2025.4.23.dev20250423.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1400 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+ # All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+
+ import base64
+ import io
+ import logging
+ import warnings
+ from math import log
+ from typing import Any
+ from typing import Dict
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+
+ import cv2
+ import numpy as np
+ import packaging.version
+ import pandas as pd
+ import torch
+ import torchvision
+ from PIL import Image
+
+ from nv_ingest_api.internal.primitives.nim import ModelInterface
+ from nv_ingest_api.internal.primitives.nim.model_interface.helpers import get_model_name
+ from nv_ingest_api.util.image_processing import scale_image_to_encoding_size
+
+ logger = logging.getLogger(__name__)
+
+ # yolox-page-elements-v1 and v2 common constants
+ YOLOX_PAGE_CONF_THRESHOLD = 0.01
+ YOLOX_PAGE_IOU_THRESHOLD = 0.5
+ YOLOX_PAGE_MIN_SCORE = 0.1
+ YOLOX_PAGE_NIM_MAX_IMAGE_SIZE = 512_000
+ YOLOX_PAGE_IMAGE_PREPROC_HEIGHT = 1024
+ YOLOX_PAGE_IMAGE_PREPROC_WIDTH = 1024
+
+ # yolox-page-elements-v1 constants
+ YOLOX_PAGE_V1_NUM_CLASSES = 4
+ YOLOX_PAGE_V1_FINAL_SCORE = {"table": 0.48, "chart": 0.48}
+ YOLOX_PAGE_V1_CLASS_LABELS = [
+     "table",
+     "chart",
+     "title",
+ ]
+
+ # yolox-page-elements-v2 constants
+ YOLOX_PAGE_V2_NUM_CLASSES = 4
+ YOLOX_PAGE_V2_FINAL_SCORE = {"table": 0.1, "chart": 0.01, "infographic": 0.01}
+ YOLOX_PAGE_V2_CLASS_LABELS = [
+     "table",
+     "chart",
+     "title",
+     "infographic",
+ ]
+
+
+ # yolox-graphic-elements-v1 constants
+ YOLOX_GRAPHIC_NUM_CLASSES = 10
+ YOLOX_GRAPHIC_CONF_THRESHOLD = 0.01
+ YOLOX_GRAPHIC_IOU_THRESHOLD = 0.25
+ YOLOX_GRAPHIC_MIN_SCORE = 0.1
+ YOLOX_GRAPHIC_FINAL_SCORE = 0.0
+ YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE = 512_000
+
+ # TODO(Devin): Legacy items aren't working right for me. Double check these.
+ LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT = 1024
+ LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH = 1024
+ YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT = 1024
+ YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH = 1024
+
+ YOLOX_GRAPHIC_CLASS_LABELS = [
+     "chart_title",
+     "x_title",
+     "y_title",
+     "xlabel",
+     "ylabel",
+     "other",
+     "legend_label",
+     "legend_title",
+     "mark_label",
+     "value_label",
+ ]
+
+
+ # yolox-table-structure-v1 constants
+ YOLOX_TABLE_NUM_CLASSES = 5
+ YOLOX_TABLE_CONF_THRESHOLD = 0.01
+ YOLOX_TABLE_IOU_THRESHOLD = 0.25
+ YOLOX_TABLE_MIN_SCORE = 0.1
+ YOLOX_TABLE_FINAL_SCORE = 0.0
+ YOLOX_TABLE_NIM_MAX_IMAGE_SIZE = 512_000
+
+ YOLOX_TABLE_IMAGE_PREPROC_HEIGHT = 1024
+ YOLOX_TABLE_IMAGE_PREPROC_WIDTH = 1024
+
+ YOLOX_TABLE_CLASS_LABELS = [
+     "border",
+     "cell",
+     "row",
+     "column",
+     "header",
+ ]
+
+
+ # YoloxModelInterfaceBase implements methods that are common to yolox-page-elements and yolox-graphic-elements
+ class YoloxModelInterfaceBase(ModelInterface):
+     """
+     An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols.
+     """
+
+     def __init__(
+         self,
+         image_preproc_width: Optional[int] = None,
+         image_preproc_height: Optional[int] = None,
+         nim_max_image_size: Optional[int] = None,
+         num_classes: Optional[int] = None,
+         conf_threshold: Optional[float] = None,
+         iou_threshold: Optional[float] = None,
+         min_score: Optional[float] = None,
+         final_score: Optional[float] = None,
+         class_labels: Optional[List[str]] = None,
+     ):
+         """
+         Initialize the YOLOX model interface.
+
+         Parameters
+         ----------
+         image_preproc_width : int, optional
+             Width images are resized to before inference.
+         image_preproc_height : int, optional
+             Height images are resized to before inference.
+         nim_max_image_size : int, optional
+             Maximum base64-encoded image size accepted by the NIM endpoint.
+         num_classes : int, optional
+             Number of classes predicted by the model.
+         conf_threshold : float, optional
+             Confidence threshold applied during postprocessing.
+         iou_threshold : float, optional
+             IoU threshold used for non-maximum suppression.
+         min_score : float, optional
+             Minimum score for a detection to be kept.
+         final_score : float or dict, optional
+             Final score threshold(s) applied after postprocessing.
+         class_labels : list of str, optional
+             Labels of the classes predicted by the model.
+         """
+         self.image_preproc_width = image_preproc_width
+         self.image_preproc_height = image_preproc_height
+         self.nim_max_image_size = nim_max_image_size
+         self.num_classes = num_classes
+         self.conf_threshold = conf_threshold
+         self.iou_threshold = iou_threshold
+         self.min_score = min_score
+         self.final_score = final_score
+         self.class_labels = class_labels
+
+     def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Prepare input data for inference by resizing images and storing their original shapes.
+
+         Parameters
+         ----------
+         data : dict
+             The input data containing a list of images.
+
+         Returns
+         -------
+         dict
+             The updated data dictionary with resized images and original image shapes.
+         """
+         if (not isinstance(data, dict)) or ("images" not in data):
+             raise KeyError("Input data must be a dictionary containing an 'images' key with a list of images.")
+
+         if not all(isinstance(x, np.ndarray) for x in data["images"]):
+             raise ValueError("All elements in the 'images' list must be numpy.ndarray objects.")
+
+         original_images = data["images"]
+         data["original_image_shapes"] = [image.shape for image in original_images]
+
+         return data
+
+     def format_input(
+         self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
+     ) -> Tuple[List[Any], List[Dict[str, Any]]]:
+         """
+         Format input data for the specified protocol, returning a tuple of:
+           (formatted_batches, formatted_batch_data)
+         where:
+           - For gRPC: formatted_batches is a list of NumPy arrays, each of shape (B, H, W, C)
+             with B <= max_batch_size.
+           - For HTTP: formatted_batches is a list of JSON-serializable dict payloads.
+           - In both cases, formatted_batch_data is a list of dicts that coalesce the original
+             images and their original shapes in the same order as provided.
+
+         Parameters
+         ----------
+         data : dict
+             The input data to format. Must include:
+               - "images": a list of numpy.ndarray images.
+               - "original_image_shapes": a list of tuples with each image's (height, width),
+                 as set by prepare_data_for_inference.
+         protocol : str
+             The protocol to use ("grpc" or "http").
+         max_batch_size : int
+             The maximum number of images per batch.
+
+         Returns
+         -------
+         tuple
+             A tuple (formatted_batches, formatted_batch_data).
+
+         Raises
+         ------
+         ValueError
+             If the protocol is invalid.
+         """
+
+         # Helper functions to chunk a list into sublists of length up to chunk_size.
+         def chunk_list(lst: list, chunk_size: int) -> List[list]:
+             return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+         def chunk_list_geometrically(lst: list, max_size: int) -> List[list]:
+             # TRT engine in Yolox NIM (gRPC) only allows a batch size in powers of 2.
+             chunks = []
+             i = 0
+             while i < len(lst):
+                 chunk_size = min(2 ** int(log(len(lst) - i, 2)), max_size)
+                 chunks.append(lst[i : i + chunk_size])
+                 i += chunk_size
+             return chunks
+
+         if protocol == "grpc":
+             logger.debug("Formatting input for gRPC Yolox model")
+             # Resize images for model input (Yolox expects 1024x1024).
+             resized_images = [
+                 resize_image(image, (self.image_preproc_width, self.image_preproc_height)) for image in data["images"]
+             ]
+             # Chunk the resized images, the original images, and their shapes.
+             resized_chunks = chunk_list_geometrically(resized_images, max_batch_size)
+             original_chunks = chunk_list_geometrically(data["images"], max_batch_size)
+             shape_chunks = chunk_list_geometrically(data["original_image_shapes"], max_batch_size)
+
+             batched_inputs = []
+             formatted_batch_data = []
+             for r_chunk, orig_chunk, shapes in zip(resized_chunks, original_chunks, shape_chunks):
+                 # Reorder axes from (B, H, W, C) to (B, C, H, W) as expected by the model.
+                 input_array = np.einsum("bijk->bkij", r_chunk).astype(np.float32)
+                 batched_inputs.append(input_array)
+                 formatted_batch_data.append({"images": orig_chunk, "original_image_shapes": shapes})
+             return batched_inputs, formatted_batch_data
+
+         elif protocol == "http":
+             logger.debug("Formatting input for HTTP Yolox model")
+             content_list: List[Dict[str, Any]] = []
+             for image in data["images"]:
+                 # Convert to uint8 if needed.
+                 if image.dtype != np.uint8:
+                     image = (image * 255).astype(np.uint8)
+                 # Convert the numpy array to a PIL Image.
+                 image_pil = Image.fromarray(image)
+                 original_size = image_pil.size
+
+                 # Save the image to a buffer and encode to base64.
+                 buffered = io.BytesIO()
+                 image_pil.save(buffered, format="PNG")
+                 image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+                 # Scale the image if necessary.
+                 scaled_image_b64, new_size = scale_image_to_encoding_size(
+                     image_b64, max_base64_size=self.nim_max_image_size
+                 )
+                 if new_size != original_size:
+                     logger.debug(f"Image was scaled from {original_size} to {new_size}.")
+
+                 content_list.append({"type": "image_url", "url": f"data:image/png;base64,{scaled_image_b64}"})
+
+             # Chunk the payload content, the original images, and their shapes.
+             content_chunks = chunk_list(content_list, max_batch_size)
+             original_chunks = chunk_list(data["images"], max_batch_size)
+             shape_chunks = chunk_list(data["original_image_shapes"], max_batch_size)
+
+             payload_batches = []
+             formatted_batch_data = []
+             for chunk, orig_chunk, shapes in zip(content_chunks, original_chunks, shape_chunks):
+                 payload = {"input": chunk}
+                 payload_batches.append(payload)
+                 formatted_batch_data.append({"images": orig_chunk, "original_image_shapes": shapes})
+             return payload_batches, formatted_batch_data
+
+         else:
+             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
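For reference, the nested chunk_list_geometrically helper above produces power-of-two batch sizes for the TRT engine. A minimal sketch of its behavior (illustrative, not part of the package source): ten images with max_size=8 split into a batch of 8 and a batch of 2.

    items = list(range(10))
    chunk_list_geometrically(items, max_size=8)
    # -> [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9]]   (2**3 = 8 items, then 2**1 = 2)
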
+     def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+         """
+         Parse the output from the model's inference response.
+
+         Parameters
+         ----------
+         response : Any
+             The response from the model inference.
+         protocol : str
+             The protocol used ("grpc" or "http").
+         data : dict, optional
+             Additional input data passed to the function.
+
+         Returns
+         -------
+         Any
+             The parsed output data.
+
+         Raises
+         ------
+         ValueError
+             If an invalid protocol is specified or the response format is unexpected.
+         """
+
+         if protocol == "grpc":
+             logger.debug("Parsing output from gRPC Yolox model")
+             return response  # For gRPC, response is already a numpy array
+         elif protocol == "http":
+             logger.debug("Parsing output from HTTP Yolox model")
+
+             processed_outputs = []
+
+             batch_results = response.get("data", [])
+             for detections in batch_results:
+                 new_bounding_boxes = {label: [] for label in self.class_labels}
+
+                 bounding_boxes = detections.get("bounding_boxes", [])
+                 for obj_type, bboxes in bounding_boxes.items():
+                     for bbox in bboxes:
+                         xmin = bbox["x_min"]
+                         ymin = bbox["y_min"]
+                         xmax = bbox["x_max"]
+                         ymax = bbox["y_max"]
+                         confidence = bbox["confidence"]
+
+                         new_bounding_boxes[obj_type].append([xmin, ymin, xmax, ymax, confidence])
+
+                 processed_outputs.append(new_bounding_boxes)
+
+             return processed_outputs
+         else:
+             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
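For reference, the HTTP branch above consumes a response shaped like the following; this sketch is inferred from the parsing code itself, not from any NIM schema documentation:

    response = {
        "data": [
            {
                "bounding_boxes": {
                    "table": [
                        {"x_min": 0.01, "y_min": 0.09, "x_max": 0.75, "y_max": 0.12, "confidence": 0.99},
                    ],
                },
            },
        ],
    }
    # parse_output(response, protocol="http") then yields one dict per image, e.g.
    # {"table": [[0.01, 0.09, 0.75, 0.12, 0.99]], "chart": [], "title": []}
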
+     def process_inference_results(self, output: Any, protocol: str, **kwargs) -> List[Dict[str, Any]]:
+         """
+         Process the results of the Yolox model inference and return the final annotations.
+
+         Parameters
+         ----------
+         output : Any
+             The raw output from the Yolox model.
+         protocol : str
+             The protocol used ("grpc" or "http").
+         kwargs : dict
+             Additional parameters for processing, including thresholds and number of classes.
+
+         Returns
+         -------
+         list[dict]
+             A list of annotation dictionaries for each image in the batch.
+         """
+         original_image_shapes = kwargs.get("original_image_shapes", [])
+
+         if protocol == "http":
+             # For http, the output already has postprocessing applied. Skip to table/chart expansion.
+             results = output
+
+         elif protocol == "grpc":
+             # For grpc, apply the same NIM postprocessing.
+             pred = postprocess_model_prediction(
+                 output,
+                 self.num_classes,
+                 self.conf_threshold,
+                 self.iou_threshold,
+                 class_agnostic=False,
+             )
+             results = postprocess_results(
+                 pred,
+                 original_image_shapes,
+                 self.image_preproc_width,
+                 self.image_preproc_height,
+                 self.class_labels,
+                 min_score=self.min_score,
+             )
+
+         inference_results = self.postprocess_annotations(results, **kwargs)
+
+         return inference_results
+
+     def postprocess_annotations(self, annotation_dicts, **kwargs):
+         raise NotImplementedError()
+
+     def transform_normalized_coordinates_to_original(self, results, original_image_shapes):
+         """
+         Convert bounding boxes from normalized [0, 1] coordinates back to pixel
+         coordinates of the original images.
+         """
+         transformed_results = []
+
+         for annotation_dict, shape in zip(results, original_image_shapes):
+             new_dict = {}
+             for label, bboxes_and_scores in annotation_dict.items():
+                 new_dict[label] = []
+                 for bbox_and_score in bboxes_and_scores:
+                     bbox = bbox_and_score[:4]
+                     transformed_bbox = [
+                         bbox[0] * shape[1],
+                         bbox[1] * shape[0],
+                         bbox[2] * shape[1],
+                         bbox[3] * shape[0],
+                     ]
+                     transformed_bbox += bbox_and_score[4:]
+                     new_dict[label].append(transformed_bbox)
+             transformed_results.append(new_dict)
+
+         return transformed_results
+
+
+ class YoloxPageElementsModelInterface(YoloxModelInterfaceBase):
+     """
+     An interface for handling inference with the yolox-page-elements model, supporting both gRPC and HTTP protocols.
+     """
+
+     def __init__(self, yolox_model_name: str = "nemoretriever-page-elements-v2"):
+         """
+         Initialize the yolox-page-elements model interface.
+         """
+         if yolox_model_name.endswith("-v1"):
+             num_classes = YOLOX_PAGE_V1_NUM_CLASSES
+             final_score = YOLOX_PAGE_V1_FINAL_SCORE
+             class_labels = YOLOX_PAGE_V1_CLASS_LABELS
+         else:
+             num_classes = YOLOX_PAGE_V2_NUM_CLASSES
+             final_score = YOLOX_PAGE_V2_FINAL_SCORE
+             class_labels = YOLOX_PAGE_V2_CLASS_LABELS
+
+         super().__init__(
+             image_preproc_width=YOLOX_PAGE_IMAGE_PREPROC_WIDTH,
+             image_preproc_height=YOLOX_PAGE_IMAGE_PREPROC_HEIGHT,
+             nim_max_image_size=YOLOX_PAGE_NIM_MAX_IMAGE_SIZE,
+             num_classes=num_classes,
+             conf_threshold=YOLOX_PAGE_CONF_THRESHOLD,
+             iou_threshold=YOLOX_PAGE_IOU_THRESHOLD,
+             min_score=YOLOX_PAGE_MIN_SCORE,
+             final_score=final_score,
+             class_labels=class_labels,
+         )
+
+     def name(
+         self,
+     ) -> str:
+         """
+         Returns the name of the Yolox model interface.
+
+         Returns
+         -------
+         str
+             The name of the model interface.
+         """
+
+         return "yolox-page-elements"
+
+     def postprocess_annotations(self, annotation_dicts, **kwargs):
+         original_image_shapes = kwargs.get("original_image_shapes", [])
+
+         expected_final_score_keys = [x for x in self.class_labels if x != "title"]
+         if (not isinstance(self.final_score, dict)) or (
+             sorted(self.final_score.keys()) != sorted(expected_final_score_keys)
+         ):
+             raise ValueError(
+                 "yolox-page-elements-v2 requires a dictionary of thresholds for each class: "
+                 f"{expected_final_score_keys}"
+             )
+
+         # Table/chart expansion is "business logic" specific to nv-ingest
+         annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in annotation_dicts]
+         annotation_dicts = [expand_chart_bboxes(annotation_dict) for annotation_dict in annotation_dicts]
+         inference_results = []
+
+         # Filter out bounding boxes below the final threshold
+         # This final thresholding is "business logic" specific to nv-ingest
+         for annotation_dict in annotation_dicts:
+             new_dict = {}
+             if "table" in annotation_dict:
+                 new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= self.final_score["table"]]
+             if "chart" in annotation_dict:
+                 new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= self.final_score["chart"]]
+             if "infographic" in annotation_dict:
+                 new_dict["infographic"] = [
+                     bb for bb in annotation_dict["infographic"] if bb[4] >= self.final_score["infographic"]
+                 ]
+             if "title" in annotation_dict:
+                 new_dict["title"] = annotation_dict["title"]
+             inference_results.append(new_dict)
+
+         inference_results = self.transform_normalized_coordinates_to_original(inference_results, original_image_shapes)
+
+         return inference_results
+
+
+ class YoloxGraphicElementsModelInterface(YoloxModelInterfaceBase):
+     """
+     An interface for handling inference with the yolox-graphic-elements model, supporting both gRPC and HTTP protocols.
+     """
+
+     def __init__(self, yolox_version: Optional[str] = None):
+         """
+         Initialize the yolox-graphic-elements model interface.
+         """
+         if yolox_version and (
+             packaging.version.Version(yolox_version) >= packaging.version.Version("1.2.0-rc5")  # gtc release
+         ):
+             image_preproc_width = YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH
+             image_preproc_height = YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT
+         else:
+             image_preproc_width = LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH
+             image_preproc_height = LEGACY_YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT
+
+         super().__init__(
+             image_preproc_width=image_preproc_width,
+             image_preproc_height=image_preproc_height,
+             nim_max_image_size=YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE,
+             num_classes=YOLOX_GRAPHIC_NUM_CLASSES,
+             conf_threshold=YOLOX_GRAPHIC_CONF_THRESHOLD,
+             iou_threshold=YOLOX_GRAPHIC_IOU_THRESHOLD,
+             min_score=YOLOX_GRAPHIC_MIN_SCORE,
+             final_score=YOLOX_GRAPHIC_FINAL_SCORE,
+             class_labels=YOLOX_GRAPHIC_CLASS_LABELS,
+         )
+
+     def name(
+         self,
+     ) -> str:
+         """
+         Returns the name of the Yolox model interface.
+
+         Returns
+         -------
+         str
+             The name of the model interface.
+         """
+
+         return "yolox-graphic-elements"
+
+     def postprocess_annotations(self, annotation_dicts, **kwargs):
+         original_image_shapes = kwargs.get("original_image_shapes", [])
+
+         annotation_dicts = self.transform_normalized_coordinates_to_original(annotation_dicts, original_image_shapes)
+
+         inference_results = []
+
+         # bbox extraction: additional postprocessing specific to nv-ingest
+         for pred, shape in zip(annotation_dicts, original_image_shapes):
+             bbox_dict = get_bbox_dict_yolox_graphic(
+                 pred,
+                 shape,
+                 self.class_labels,
+                 self.min_score,
+             )
+             # convert numpy arrays to list
+             bbox_dict = {
+                 label: array.tolist() if isinstance(array, np.ndarray) else array for label, array in bbox_dict.items()
+             }
+             inference_results.append(bbox_dict)
+
+         return inference_results
+
+
+ class YoloxTableStructureModelInterface(YoloxModelInterfaceBase):
+     """
+     An interface for handling inference with the yolox-table-structure model, supporting both gRPC and HTTP protocols.
+     """
+
+     def __init__(self):
+         """
+         Initialize the yolox-table-structure model interface.
+         """
+         super().__init__(
+             image_preproc_width=YOLOX_TABLE_IMAGE_PREPROC_WIDTH,
+             image_preproc_height=YOLOX_TABLE_IMAGE_PREPROC_HEIGHT,
+             nim_max_image_size=YOLOX_TABLE_NIM_MAX_IMAGE_SIZE,
+             num_classes=YOLOX_TABLE_NUM_CLASSES,
+             conf_threshold=YOLOX_TABLE_CONF_THRESHOLD,
+             iou_threshold=YOLOX_TABLE_IOU_THRESHOLD,
+             min_score=YOLOX_TABLE_MIN_SCORE,
+             final_score=YOLOX_TABLE_FINAL_SCORE,
+             class_labels=YOLOX_TABLE_CLASS_LABELS,
+         )
+
+     def name(
+         self,
+     ) -> str:
+         """
+         Returns the name of the Yolox model interface.
+
+         Returns
+         -------
+         str
+             The name of the model interface.
+         """
+
+         return "yolox-table-structure"
+
+     def postprocess_annotations(self, annotation_dicts, **kwargs):
+         original_image_shapes = kwargs.get("original_image_shapes", [])
+
+         annotation_dicts = self.transform_normalized_coordinates_to_original(annotation_dicts, original_image_shapes)
+
+         inference_results = []
+
+         # bbox extraction: additional postprocessing specific to nv-ingest
+         for pred, shape in zip(annotation_dicts, original_image_shapes):
+             bbox_dict = get_bbox_dict_yolox_table(
+                 pred,
+                 shape,
+                 self.class_labels,
+                 self.min_score,
+             )
+             # convert numpy arrays to list
+             bbox_dict = {
+                 label: array.tolist() if isinstance(array, np.ndarray) else array for label, array in bbox_dict.items()
+             }
+             inference_results.append(bbox_dict)
+
+         return inference_results
+
+
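The three interfaces above share one prepare/format/parse/process flow. A minimal usage sketch (illustrative only; run_inference is a hypothetical stand-in for the transport client, which lives elsewhere in the package):

    import numpy as np

    iface = YoloxPageElementsModelInterface()  # defaults to nemoretriever-page-elements-v2
    page = np.zeros((1536, 1024, 3), dtype=np.uint8)  # one rasterized page

    data = iface.prepare_data_for_inference({"images": [page]})
    batches, batch_data = iface.format_input(data, protocol="http", max_batch_size=8)

    for payload, meta in zip(batches, batch_data):
        raw = run_inference(payload)  # hypothetical call to the NIM endpoint
        parsed = iface.parse_output(raw, protocol="http")
        annotations = iface.process_inference_results(
            parsed,
            protocol="http",
            original_image_shapes=meta["original_image_shapes"],
        )
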
+ def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
+     # Convert numpy array to torch tensor
+     prediction = torch.from_numpy(prediction.copy())
+
+     # Compute box corners
+     box_corner = prediction.new(prediction.shape)
+     box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
+     box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
+     box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
+     box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
+     prediction[:, :, :4] = box_corner[:, :, :4]
+
+     output = [None for _ in range(len(prediction))]
+
+     for i, image_pred in enumerate(prediction):
+         # If no detections, continue to the next image
+         if not image_pred.size(0):
+             continue
+
+         # Ensure image_pred is 2D
+         if image_pred.ndim == 1:
+             image_pred = image_pred.unsqueeze(0)
+
+         # Get score and class with highest confidence
+         class_conf, class_pred = torch.max(image_pred[:, 5 : 5 + num_classes], 1, keepdim=True)
+
+         # Confidence mask
+         squeezed_conf = class_conf.squeeze(dim=1)
+         conf_mask = image_pred[:, 4] * squeezed_conf >= conf_thre
+
+         # Apply confidence mask
+         detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
+         detections = detections[conf_mask]
+
+         if not detections.size(0):
+             continue
+
+         # Apply Non-Maximum Suppression (NMS)
+         if class_agnostic:
+             nms_out_index = torchvision.ops.nms(
+                 detections[:, :4],
+                 detections[:, 4] * detections[:, 5],
+                 nms_thre,
+             )
+         else:
+             nms_out_index = torchvision.ops.batched_nms(
+                 detections[:, :4],
+                 detections[:, 4] * detections[:, 5],
+                 detections[:, 6],
+                 nms_thre,
+             )
+         detections = detections[nms_out_index]
+
+         # Append detections to output
+         output[i] = detections
+
+     return output
+
+
+ def postprocess_results(
+     results, original_image_shapes, image_preproc_width, image_preproc_height, class_labels, min_score=0.0
+ ):
+     """
+     For each item (i.e., each image) in results, computes annotations in the form
+
+     {"table": [[0.0107, 0.0859, 0.7537, 0.1219, 0.9861], ...],
+      "figure": [...],
+      "title": [...]
+     }
+     where each list of 5 floats represents a bounding box in the format [x1, y1, x2, y2, confidence].
+
+     Keep only bboxes with high enough confidence.
+     """
+     out = []
+
+     for original_image_shape, result in zip(original_image_shapes, results):
+         annotation_dict = {label: [] for label in class_labels}
+
+         if result is None:
+             out.append(annotation_dict)
+             continue
+
+         try:
+             result = result.cpu().numpy()
+             scores = result[:, 4] * result[:, 5]
+             result = result[scores > min_score]
+
+             # ratio is used when image was padded
+             ratio = min(
+                 image_preproc_width / original_image_shape[0],
+                 image_preproc_height / original_image_shape[1],
+             )
+             bboxes = result[:, :4] / ratio
+
+             bboxes[:, [0, 2]] /= original_image_shape[1]
+             bboxes[:, [1, 3]] /= original_image_shape[0]
+             bboxes = np.clip(bboxes, 0.0, 1.0)
+
+             labels = result[:, 6]
+             scores = scores[scores > min_score]
+         except Exception as e:
+             raise ValueError(f"Error in postprocessing {result.shape} and {original_image_shape}: {e}")
+
+         for box, score, label in zip(bboxes, scores, labels):
+             class_name = class_labels[int(label)]
+             annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))])
+
+         out.append(annotation_dict)
+
+     return out
+
+
+ def resize_image(image, target_img_size):
+     w, h, _ = np.array(image).shape
+
+     if target_img_size is not None:  # Resize + Pad
+         r = min(target_img_size[0] / w, target_img_size[1] / h)
+         image = cv2.resize(
+             image,
+             (int(h * r), int(w * r)),
+             interpolation=cv2.INTER_LINEAR,
+         ).astype(np.uint8)
+         image = np.pad(
+             image,
+             ((0, target_img_size[0] - image.shape[0]), (0, target_img_size[1] - image.shape[1]), (0, 0)),
+             mode="constant",
+             constant_values=114,
+         )
+
+     return image
+
+
+ def expand_table_bboxes(annotation_dict, labels=None):
+     """
+     Additional preprocessing for tables: extend the upper bounds to capture titles if any.
+     Args:
+         annotation_dict: output of postprocess_results, a dictionary with keys "table", "figure", "title"
+
+     Returns:
+         annotation_dict: same as input, with expanded bboxes for tables
+
+     """
+     if not labels:
+         labels = list(annotation_dict.keys())
+
+     if not annotation_dict or len(annotation_dict["table"]) == 0:
+         return annotation_dict
+
+     new_annotation_dict = {label: [] for label in labels}
+
+     for label, bboxes in annotation_dict.items():
+         for bbox_and_score in bboxes:
+             bbox, score = bbox_and_score[:4], bbox_and_score[4]
+
+             if label == "table":
+                 height = bbox[3] - bbox[1]
+                 bbox[1] = max(0.0, min(1.0, bbox[1] - height * 0.2))
+
+             new_annotation_dict[label].append([round(float(x), 4) for x in bbox + [score]])
+
+     return new_annotation_dict
+
+
+ def expand_chart_bboxes(annotation_dict, labels=None):
+     """
+     Expand bounding boxes of charts and titles based on the bounding boxes of the other class.
+     Args:
+         annotation_dict: output of postprocess_results, a dictionary with keys "table", "figure", "title"
+
+     Returns:
+         annotation_dict: same as input, with expanded bboxes for charts
+
+     """
+     if not labels:
+         labels = list(annotation_dict.keys())
+
+     if not annotation_dict or len(annotation_dict["chart"]) == 0:
+         return annotation_dict
+
+     bboxes = []
+     confidences = []
+     label_idxs = []
+     for i, label in enumerate(labels):
+         label_annotations = np.array(annotation_dict[label])
+
+         if len(label_annotations) > 0:
+             bboxes.append(label_annotations[:, :4])
+             confidences.append(label_annotations[:, 4])
+             label_idxs.append(np.full(len(label_annotations), i))
+     bboxes = np.concatenate(bboxes)
+     confidences = np.concatenate(confidences)
+     label_idxs = np.concatenate(label_idxs)
+
+     pred_wbf, confidences_wbf, labels_wbf = weighted_boxes_fusion(
+         bboxes[:, None],
+         confidences[:, None],
+         label_idxs[:, None],
+         merge_type="biggest",
+         conf_type="max",
+         iou_thr=0.01,
+         class_agnostic=False,
+     )
+     chart_bboxes = pred_wbf[labels_wbf == 1]
+     chart_confidences = confidences_wbf[labels_wbf == 1]
+     title_bboxes = pred_wbf[labels_wbf == 2]
+
+     found_title_idxs, no_found_title_idxs = [], []
+     for i in range(len(chart_bboxes)):
+         match = match_with_title(chart_bboxes[i], title_bboxes, iou_th=0.01)
+         if match is not None:
+             chart_bboxes[i] = match[0]
+             title_bboxes = match[1]
+             found_title_idxs.append(i)
+         else:
+             no_found_title_idxs.append(i)
+
+     chart_bboxes[found_title_idxs] = expand_boxes(chart_bboxes[found_title_idxs], r_x=1.05, r_y=1.1)
+     chart_bboxes[no_found_title_idxs] = expand_boxes(chart_bboxes[no_found_title_idxs], r_x=1.1, r_y=1.25)
+
+     annotation_dict = {
+         "table": annotation_dict["table"],
+         "chart": np.concatenate([chart_bboxes, chart_confidences[:, None]], axis=1).tolist(),
+         "title": annotation_dict["title"],
+     }
+     return annotation_dict
+
+
+ def weighted_boxes_fusion(
+     boxes_list,
+     scores_list,
+     labels_list,
+     iou_thr=0.5,
+     skip_box_thr=0.0,
+     conf_type="avg",
+     merge_type="weighted",
+     class_agnostic=False,
+ ):
+     """
+     Custom wbf implementation that supports a class_agnostic mode and a biggest box fusion.
+     Boxes are expected to be in normalized (x0, y0, x1, y1) format.
+
+     Args:
+         boxes_list (list[np array[n x 4]]): List of boxes. One list per model.
+         scores_list (list[np array[n]]): List of confidences.
+         labels_list (list[np array[n]]): List of labels.
+         iou_thr (float, optional): IoU threshold for matching. Defaults to 0.5.
+         skip_box_thr (float, optional): Exclude boxes with score < skip_box_thr. Defaults to 0.0.
+         conf_type (str, optional): Confidence merging type. Defaults to "avg".
+         merge_type (str, optional): Merge type "weighted" or "biggest". Defaults to "weighted".
+         class_agnostic (bool, optional): If True, merge boxes from different classes. Defaults to False.
+
+     Returns:
+         np array[N x 4]: Merged boxes,
+         np array[N]: Merged confidences,
+         np array[N]: Merged labels.
+     """
+     weights = np.ones(len(boxes_list))
+
+     assert conf_type in ["avg", "max"], 'Conf type must be "avg" or "max"'
+     assert merge_type in [
+         "weighted",
+         "biggest",
+     ], 'Merge type must be "weighted" or "biggest"'
+
+     filtered_boxes = prefilter_boxes(
+         boxes_list,
+         scores_list,
+         labels_list,
+         weights,
+         skip_box_thr,
+         class_agnostic=class_agnostic,
+     )
+     if len(filtered_boxes) == 0:
+         return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))
+
+     overall_boxes = []
+     for label in filtered_boxes:
+         boxes = filtered_boxes[label]
+
+         clusters = []
+
+         # Clusterize boxes
+         for j in range(len(boxes)):
+             ids = [i for i in range(len(boxes)) if i != j]
+             index, best_iou = find_matching_box_fast(boxes[ids], boxes[j], iou_thr)
+
+             if index != -1:
+                 index = ids[index]
+                 cluster_idx = [clust_idx for clust_idx, clust in enumerate(clusters) if (j in clust or index in clust)]
+                 if len(cluster_idx):
+                     cluster_idx = cluster_idx[0]
+                     clusters[cluster_idx] = list(set(clusters[cluster_idx] + [index, j]))
+                 else:
+                     clusters.append([index, j])
+             else:
+                 clusters.append([j])
+
+         for j, c in enumerate(clusters):
+             if merge_type == "weighted":
+                 weighted_box = get_weighted_box(boxes[c], conf_type)
+             elif merge_type == "biggest":
+                 weighted_box = get_biggest_box(boxes[c], conf_type)
+
+             if conf_type == "max":
+                 weighted_box[1] = weighted_box[1] / weights.max()
+             else:  # avg
+                 weighted_box[1] = weighted_box[1] * len(c) / weights.sum()
+             overall_boxes.append(weighted_box)
+
+     overall_boxes = np.array(overall_boxes)
+     overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
+     boxes = overall_boxes[:, 4:]
+     scores = overall_boxes[:, 1]
+     labels = overall_boxes[:, 0]
+     return boxes, scores, labels
+
+
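A small worked example of the fusion above (illustrative, with made-up boxes): two same-class boxes with IoU of about 0.82 fall into one cluster, and merge_type="biggest" keeps their union; this mirrors how expand_chart_bboxes calls it with the extra model axis added via [:, None].

    import numpy as np

    boxes = np.array([[0.10, 0.10, 0.50, 0.50],
                      [0.12, 0.08, 0.52, 0.48]])
    scores = np.array([0.9, 0.6])
    labels = np.array([0, 0])

    merged, confs, lbls = weighted_boxes_fusion(
        boxes[:, None], scores[:, None], labels[:, None],
        iou_thr=0.5, merge_type="biggest", conf_type="max",
    )
    # merged -> [[0.10, 0.08, 0.52, 0.50]]; confs -> [0.9]; lbls -> [0.]
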
+ def prefilter_boxes(boxes, scores, labels, weights, thr, class_agnostic=False):
+     """
+     Reformats and filters boxes.
+     Output is a dict of boxes to merge separately.
+
+     Args:
+         boxes (list[np array[n x 4]]): List of boxes. One list per model.
+         scores (list[np array[n]]): List of confidences.
+         labels (list[np array[n]]): List of labels.
+         weights (list): Model weights.
+         thr (float): Confidence threshold
+         class_agnostic (bool, optional): If True, merge boxes from different classes. Defaults to False.
+
+     Returns:
+         dict[np array [? x 8]]: Filtered boxes.
+     """
+     # Create dict with boxes stored by its label
+     new_boxes = dict()
+
+     for t in range(len(boxes)):
+         if len(boxes[t]) != len(scores[t]):
+             print(
+                 "Error. Length of boxes arrays not equal to length of scores array: {} != {}".format(
+                     len(boxes[t]), len(scores[t])
+                 )
+             )
+             exit()
+
+         if len(boxes[t]) != len(labels[t]):
+             print(
+                 "Error. Length of boxes arrays not equal to length of labels array: {} != {}".format(
+                     len(boxes[t]), len(labels[t])
+                 )
+             )
+             exit()
+
+         for j in range(len(boxes[t])):
+             score = scores[t][j]
+             if score < thr:
+                 continue
+             label = int(labels[t][j])
+             box_part = boxes[t][j]
+             x1 = float(box_part[0])
+             y1 = float(box_part[1])
+             x2 = float(box_part[2])
+             y2 = float(box_part[3])
+
+             # Box data checks
+             if x2 < x1:
+                 warnings.warn("X2 < X1 value in box. Swap them.")
+                 x1, x2 = x2, x1
+             if y2 < y1:
+                 warnings.warn("Y2 < Y1 value in box. Swap them.")
+                 y1, y2 = y2, y1
+             if x1 < 0:
+                 warnings.warn("X1 < 0 in box. Set it to 0.")
+                 x1 = 0
+             if x1 > 1:
+                 warnings.warn("X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+                 x1 = 1
+             if x2 < 0:
+                 warnings.warn("X2 < 0 in box. Set it to 0.")
+                 x2 = 0
+             if x2 > 1:
+                 warnings.warn("X2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+                 x2 = 1
+             if y1 < 0:
+                 warnings.warn("Y1 < 0 in box. Set it to 0.")
+                 y1 = 0
+             if y1 > 1:
+                 warnings.warn("Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+                 y1 = 1
+             if y2 < 0:
+                 warnings.warn("Y2 < 0 in box. Set it to 0.")
+                 y2 = 0
+             if y2 > 1:
+                 warnings.warn("Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+                 y2 = 1
+             if (x2 - x1) * (y2 - y1) == 0.0:
+                 warnings.warn("Zero area box skipped: {}.".format(box_part))
+                 continue
+
+             # [label, score, weight, model index, x1, y1, x2, y2]
+             b = [int(label), float(score) * weights[t], weights[t], t, x1, y1, x2, y2]
+
+             label_k = "*" if class_agnostic else label
+             if label_k not in new_boxes:
+                 new_boxes[label_k] = []
+             new_boxes[label_k].append(b)
+
+     # Sort each list in dict by score and transform it to numpy array
+     for k in new_boxes:
+         current_boxes = np.array(new_boxes[k])
+         new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
+
+     return new_boxes
+
+
+ def find_matching_box_fast(boxes_list, new_box, match_iou):
+     """
+     Reimplementation of find_matching_box with numpy instead of loops. Gives significant speed up for larger arrays
+     (~100x). This was previously the bottleneck since the function is called for every entry in the array.
+     """
+
+     def bb_iou_array(boxes, new_box):
+         # bounding-box intersection over union
+         xA = np.maximum(boxes[:, 0], new_box[0])
+         yA = np.maximum(boxes[:, 1], new_box[1])
+         xB = np.minimum(boxes[:, 2], new_box[2])
+         yB = np.minimum(boxes[:, 3], new_box[3])
+
+         interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
+
+         # compute the area of both the prediction and ground-truth rectangles
+         boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+         boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
+
+         iou = interArea / (boxAArea + boxBArea - interArea)
+
+         return iou
+
+     if boxes_list.shape[0] == 0:
+         return -1, match_iou
+
+     ious = bb_iou_array(boxes_list[:, 4:], new_box[4:])
+     # ious[boxes[:, 0] != new_box[0]] = -1
+
+     best_idx = np.argmax(ious)
+     best_iou = ious[best_idx]
+
+     if best_iou <= match_iou:
+         best_iou = match_iou
+         best_idx = -1
+
+     return best_idx, best_iou
+
+
+ def get_biggest_box(boxes, conf_type="avg"):
+     """
+     Merges boxes by using the biggest box.
+
+     Args:
+         boxes (np array [n x 8]): Boxes to merge.
+         conf_type (str, optional): Confidence merging type. Defaults to "avg".
+
+     Returns:
+         np array [8]: Merged box.
+     """
+     box = np.zeros(8, dtype=np.float32)
+     box[4:] = boxes[0][4:]
+     conf_list = []
+     w = 0
+     for b in boxes:
+         box[4] = min(box[4], b[4])
+         box[5] = min(box[5], b[5])
+         box[6] = max(box[6], b[6])
+         box[7] = max(box[7], b[7])
+         conf_list.append(b[1])
+         w += b[2]
+
+     box[0] = merge_labels(np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes]))
+     # print(box[0], np.array([b[0] for b in boxes]))
+
+     box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
+     box[2] = w
+     box[3] = -1  # model index field is retained for consistency but is not used.
+     return box
+
+
+ def merge_labels(labels, confs):
+     """
+     Custom function for merging labels.
+     If all labels are the same, return the unique value.
+     Else, return the label of the most confident non-title (class 2) box.
+
+     Args:
+         labels (np array [n]): Labels.
+         confs (np array [n]): Confidence.
+
+     Returns:
+         int: Label.
+     """
+     if len(np.unique(labels)) == 1:
+         return labels[0]
+     else:  # Most confident and not a title
+         confs = confs[labels != 2]
+         labels = labels[labels != 2]
+         return labels[np.argmax(confs)]
+
+
+ def match_with_title(chart_bbox, title_bboxes, iou_th=0.01):
+     if not len(title_bboxes):
+         return None
+
+     dist_above = np.abs(title_bboxes[:, 3] - chart_bbox[1])
+     dist_below = np.abs(chart_bbox[3] - title_bboxes[:, 1])
+
+     dist_left = np.abs(title_bboxes[:, 0] - chart_bbox[0])
+
+     ious = bb_iou_array(title_bboxes, chart_bbox)
+
+     matches = None
+     if np.max(ious) > iou_th:
+         matches = np.where(ious > iou_th)[0]
+     else:
+         dists = np.min([dist_above, dist_below], 0)
+         dists += dist_left
+         # print(dists)
+         if np.min(dists) < 0.1:
+             matches = [np.argmin(dists)]
+
+     if matches is not None:
+         new_bbox = chart_bbox
+         for match in matches:
+             new_bbox = merge_boxes(new_bbox, title_bboxes[match])
+         title_bboxes = title_bboxes[[i for i in range(len(title_bboxes)) if i not in matches]]
+         return new_bbox, title_bboxes
+
+     else:
+         return None
+
+
+ def bb_iou_array(boxes, new_box):
+     # bounding-box intersection over union
+     xA = np.maximum(boxes[:, 0], new_box[0])
+     yA = np.maximum(boxes[:, 1], new_box[1])
+     xB = np.minimum(boxes[:, 2], new_box[2])
+     yB = np.minimum(boxes[:, 3], new_box[3])
+
+     interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
+
+     # compute the area of both the prediction and ground-truth rectangles
+     boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+     boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
+
+     iou = interArea / (boxAArea + boxBArea - interArea)
+
+     return iou
+
+
+ def merge_boxes(b1, b2):
+     b = b1.copy()
+     b[0] = min(b1[0], b2[0])
+     b[1] = min(b1[1], b2[1])
+     b[2] = max(b1[2], b2[2])
+     b[3] = max(b1[3], b2[3])
+     return b
+
+
+ def expand_boxes(boxes, r_x=1, r_y=1):
+     dw = (boxes[:, 2] - boxes[:, 0]) / 2 * (r_x - 1)
+     boxes[:, 0] -= dw
+     boxes[:, 2] += dw
+
+     dh = (boxes[:, 3] - boxes[:, 1]) / 2 * (r_y - 1)
+     boxes[:, 1] -= dh
+     boxes[:, 3] += dh
+
+     boxes = np.clip(boxes, 0, 1)
+     return boxes
+
+
+ def get_weighted_box(boxes, conf_type="avg"):
+     """
+     Merges boxes by using the weighted fusion.
+
+     Args:
+         boxes (np array [n x 8]): Boxes to merge.
+         conf_type (str, optional): Confidence merging type. Defaults to "avg".
+
+     Returns:
+         np array [8]: Merged box.
+     """
+     box = np.zeros(8, dtype=np.float32)
+     conf = 0
+     conf_list = []
+     w = 0
+     for b in boxes:
+         box[4:] += b[1] * b[4:]
+         conf += b[1]
+         conf_list.append(b[1])
+         w += b[2]
+
+     box[0] = merge_labels(np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes]))
+
+     box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
+     box[2] = w
+     box[3] = -1  # model index field is retained for consistency but is not used.
+     box[4:] /= conf
+     return box
+
+
+ def batched_overlaps(A, B):
+     """
+     Calculate the Intersection over Union (IoU) between
+     two sets of bounding boxes in a batched manner.
+     Normalization is modified to only use the area of A boxes, hence computing the overlaps.
+     Args:
+         A (ndarray): Array of bounding boxes of shape (N, 4) in format [x1, y1, x2, y2].
+         B (ndarray): Array of bounding boxes of shape (M, 4) in format [x1, y1, x2, y2].
+     Returns:
+         ndarray: Array of IoU values of shape (N, M) representing the overlaps
+         between each pair of bounding boxes.
+     """
+     A = A.copy()
+     B = B.copy()
+
+     A = A[None].repeat(B.shape[0], 0)
+     B = B[:, None].repeat(A.shape[1], 1)
+
+     low = np.s_[..., :2]
+     high = np.s_[..., 2:]
+
+     A, B = A.copy(), B.copy()
+     A[high] += 1
+     B[high] += 1
+
+     intrs = (np.maximum(0, np.minimum(A[high], B[high]) - np.maximum(A[low], B[low]))).prod(-1)
+     ious = intrs / (A[high] - A[low]).prod(-1)
+
+     return ious
+
+
+ def find_boxes_inside(boxes, boxes_to_check, threshold=0.9):
+     """
+     Find all boxes that are inside another box based on
+     the intersection area divided by the area of the smaller box,
+     and removes them.
+     """
+     overlaps = batched_overlaps(boxes_to_check, boxes)
+     to_keep = (overlaps >= threshold).sum(0) <= 1
+     return boxes_to_check[to_keep]
+
+
+ def get_bbox_dict_yolox_graphic(preds, shape, class_labels, threshold_=0.1) -> Dict[str, np.ndarray]:
+     """
+     Extracts bounding boxes from YOLOX model predictions:
+     - Applies thresholding
+     - Reformats boxes
+     - Cleans the `other` detections: removes the ones that are included in other detections.
+     - If no title is found, the biggest `other` box is used if it is larger than 0.3*img_w.
+     Args:
+         preds (np.ndarray): YOLOX model predictions including bounding boxes, scores, and labels.
+         shape (tuple): Original image shape.
+         class_labels (list[str]): Class labels the model predicts.
+         threshold_ (float): Score threshold to filter bounding boxes.
+     Returns:
+         Dict[str, np.ndarray]: Dictionary of bounding boxes, organized by class.
+     """
+     bbox_dict = {label: np.array([]) for label in class_labels}
+
+     for i, label in enumerate(class_labels):
+         bboxes_class = np.array(preds[label])
+
+         if bboxes_class.size == 0:
+             continue
+
+         # Try to find a chart_title box
+         threshold = threshold_ if label != "chart_title" else min(threshold_, bboxes_class[:, -1].max())
+         bboxes_class = bboxes_class[bboxes_class[:, -1] >= threshold][:, :4].astype(int)
+
+         sort = ["x0", "y0"] if label != "ylabel" else ["y0", "x0"]
+         idxs = (
+             pd.DataFrame(
+                 {
+                     "y0": bboxes_class[:, 1],
+                     "x0": bboxes_class[:, 0],
+                 }
+             )
+             .sort_values(sort, ascending=label != "ylabel")
+             .index
+         )
+         bboxes_class = bboxes_class[idxs]
+         bbox_dict[label] = bboxes_class
+
+     # Remove other included
+     if len(bbox_dict.get("other", [])):
+         other = find_boxes_inside(
+             np.concatenate(list([v for v in bbox_dict.values() if len(v)])), bbox_dict["other"], threshold=0.7
+         )
+         del bbox_dict["other"]
+         if len(other):
+             bbox_dict["other"] = other
+
+     # Biggest other is title if no title
+     if not len(bbox_dict.get("chart_title", [])) and len(bbox_dict.get("other", [])):
+         boxes = bbox_dict["other"]
+         ws = boxes[:, 2] - boxes[:, 0]
+         if np.max(ws) > shape[1] * 0.3:
+             bbox_dict["chart_title"] = boxes[np.argmax(ws)][None].copy()
+             bbox_dict["other"] = np.delete(boxes, (np.argmax(ws)), axis=0)
+
+     # Make sure other key not lost
+     bbox_dict["other"] = bbox_dict.get("other", [])
+
+     return bbox_dict
+
+
+ def get_bbox_dict_yolox_table(preds, shape, class_labels, threshold=0.1, delta=0.0):
+     """
+     Extracts bounding boxes from YOLOX model predictions:
+     - Applies thresholding
+     - Reformats boxes
+     - Reorders predictions
+
+     Args:
+         preds (np.ndarray): YOLOX model predictions including bounding boxes, scores, and labels.
+         shape (tuple): Original image shape.
+         class_labels (list[str]): Class labels the model predicts.
+         threshold (float): Score threshold to filter bounding boxes.
+         delta (float): How much the table was cropped upwards.
+
+     Returns:
+         dict[str, np.ndarray]: Dictionary of bounding boxes, organized by class.
+     """
+     bbox_dict = {label: np.array([]) for label in class_labels}
+
+     for i, label in enumerate(class_labels):
+         if label not in ["cell", "row", "column"]:
+             continue  # Ignore useless classes
+
+         bboxes_class = np.array(preds[label])
+
+         if bboxes_class.size == 0:
+             continue
+
+         # Threshold and clip
+         bboxes_class = bboxes_class[bboxes_class[:, -1] >= threshold][:, :4].astype(int)
+         bboxes_class[:, [0, 2]] = np.clip(bboxes_class[:, [0, 2]], 0, shape[1])
+         bboxes_class[:, [1, 3]] = np.clip(bboxes_class[:, [1, 3]], 0, shape[0])
+
+         # Reorder
+         sort = ["x0", "y0"] if label != "row" else ["y0", "x0"]
+         df = pd.DataFrame(
+             {
+                 "y0": (bboxes_class[:, 1] + bboxes_class[:, 3]) / 2,
+                 "x0": (bboxes_class[:, 0] + bboxes_class[:, 2]) / 2,
+             }
+         )
+         idxs = df.sort_values(sort).index
+         bboxes_class = bboxes_class[idxs]
+
+         bbox_dict[label] = bboxes_class
+
+     # Enforce spanning the entire table
+     if len(bbox_dict["row"]):
+         bbox_dict["row"][:, 0] = 0
+         bbox_dict["row"][:, 2] = shape[1]
+     if len(bbox_dict["column"]):
+         bbox_dict["column"][:, 1] = 0
+         bbox_dict["column"][:, 3] = shape[0]
+
+     # Shift back if cropped
+     for k in bbox_dict:
+         if len(bbox_dict[k]):
+             bbox_dict[k][:, [1, 3]] = np.add(bbox_dict[k][:, [1, 3]], delta, casting="unsafe")
+
+     return bbox_dict
+
+
+ def get_yolox_model_name(yolox_http_endpoint, default_model_name="nemoretriever-page-elements-v2"):
+     try:
+         yolox_model_name = get_model_name(yolox_http_endpoint, default_model_name)
+         if not yolox_model_name:
+             logger.warning(
+                 "Failed to obtain yolox-page-elements model name from the endpoint. "
+                 f"Falling back to '{default_model_name}'."
+             )
+             yolox_model_name = default_model_name
+     except Exception:
+         logger.warning(
+             "Failed to get yolox-page-elements version after 30 seconds. " f"Falling back to '{default_model_name}'."
+         )
+         yolox_model_name = default_model_name
+
+     return yolox_model_name
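
Closing the loop, get_yolox_model_name lets callers pick the right page-elements interface at runtime. A final sketch (illustrative; the endpoint URL is a placeholder, not a documented default):

    model_name = get_yolox_model_name("http://localhost:8000/v1")  # placeholder endpoint
    iface = YoloxPageElementsModelInterface(yolox_model_name=model_name)
    print(iface.name())  # "yolox-page-elements"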