nv-ingest-api 26.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of nv-ingest-api might be problematic.

Files changed (177)
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +218 -0
  3. nv_ingest_api/interface/extract.py +977 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +200 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +186 -0
  8. nv_ingest_api/internal/__init__.py +0 -0
  9. nv_ingest_api/internal/enums/__init__.py +3 -0
  10. nv_ingest_api/internal/enums/common.py +550 -0
  11. nv_ingest_api/internal/extract/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  13. nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
  14. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  15. nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
  16. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
  19. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
  20. nv_ingest_api/internal/extract/html/__init__.py +3 -0
  21. nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
  22. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
  24. nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
  25. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  26. nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
  27. nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
  28. nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
  29. nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
  30. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  31. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  32. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  33. nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
  34. nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
  35. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
  36. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
  37. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  38. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  39. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  40. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  41. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  42. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
  43. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
  44. nv_ingest_api/internal/meta/__init__.py +3 -0
  45. nv_ingest_api/internal/meta/udf.py +232 -0
  46. nv_ingest_api/internal/mutate/__init__.py +3 -0
  47. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  48. nv_ingest_api/internal/mutate/filter.py +133 -0
  49. nv_ingest_api/internal/primitives/__init__.py +0 -0
  50. nv_ingest_api/internal/primitives/control_message_task.py +16 -0
  51. nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
  52. nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
  53. nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
  59. nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
  60. nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
  61. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  62. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
  63. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
  64. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
  65. nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
  66. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
  67. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  68. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  69. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  70. nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
  71. nv_ingest_api/internal/schemas/__init__.py +3 -0
  72. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  73. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
  74. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
  75. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
  76. nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
  77. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
  78. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
  79. nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
  80. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
  81. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
  82. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
  83. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
  85. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  86. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  87. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  88. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  89. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
  90. nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
  91. nv_ingest_api/internal/schemas/meta/udf.py +23 -0
  92. nv_ingest_api/internal/schemas/mixins.py +39 -0
  93. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  94. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  95. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  96. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  97. nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
  98. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  99. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
  100. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  101. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
  102. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
  103. nv_ingest_api/internal/store/__init__.py +3 -0
  104. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  105. nv_ingest_api/internal/store/image_upload.py +251 -0
  106. nv_ingest_api/internal/transform/__init__.py +3 -0
  107. nv_ingest_api/internal/transform/caption_image.py +219 -0
  108. nv_ingest_api/internal/transform/embed_text.py +702 -0
  109. nv_ingest_api/internal/transform/split_text.py +182 -0
  110. nv_ingest_api/util/__init__.py +3 -0
  111. nv_ingest_api/util/control_message/__init__.py +0 -0
  112. nv_ingest_api/util/control_message/validators.py +47 -0
  113. nv_ingest_api/util/converters/__init__.py +0 -0
  114. nv_ingest_api/util/converters/bytetools.py +78 -0
  115. nv_ingest_api/util/converters/containers.py +65 -0
  116. nv_ingest_api/util/converters/datetools.py +90 -0
  117. nv_ingest_api/util/converters/dftools.py +127 -0
  118. nv_ingest_api/util/converters/formats.py +64 -0
  119. nv_ingest_api/util/converters/type_mappings.py +27 -0
  120. nv_ingest_api/util/dataloader/__init__.py +9 -0
  121. nv_ingest_api/util/dataloader/dataloader.py +409 -0
  122. nv_ingest_api/util/detectors/__init__.py +5 -0
  123. nv_ingest_api/util/detectors/language.py +38 -0
  124. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  125. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  126. nv_ingest_api/util/exception_handlers/decorators.py +429 -0
  127. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  128. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  129. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  130. nv_ingest_api/util/image_processing/__init__.py +5 -0
  131. nv_ingest_api/util/image_processing/clustering.py +260 -0
  132. nv_ingest_api/util/image_processing/processing.py +177 -0
  133. nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
  134. nv_ingest_api/util/image_processing/transforms.py +850 -0
  135. nv_ingest_api/util/imports/__init__.py +3 -0
  136. nv_ingest_api/util/imports/callable_signatures.py +108 -0
  137. nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
  138. nv_ingest_api/util/introspection/__init__.py +3 -0
  139. nv_ingest_api/util/introspection/class_inspect.py +145 -0
  140. nv_ingest_api/util/introspection/function_inspect.py +65 -0
  141. nv_ingest_api/util/logging/__init__.py +0 -0
  142. nv_ingest_api/util/logging/configuration.py +102 -0
  143. nv_ingest_api/util/logging/sanitize.py +84 -0
  144. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  145. nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
  146. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  147. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  148. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  149. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
  150. nv_ingest_api/util/metadata/__init__.py +5 -0
  151. nv_ingest_api/util/metadata/aggregators.py +516 -0
  152. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  153. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
  154. nv_ingest_api/util/nim/__init__.py +161 -0
  155. nv_ingest_api/util/pdf/__init__.py +3 -0
  156. nv_ingest_api/util/pdf/pdfium.py +428 -0
  157. nv_ingest_api/util/schema/__init__.py +3 -0
  158. nv_ingest_api/util/schema/schema_validator.py +10 -0
  159. nv_ingest_api/util/service_clients/__init__.py +3 -0
  160. nv_ingest_api/util/service_clients/client_base.py +86 -0
  161. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  162. nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
  163. nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
  164. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  165. nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
  166. nv_ingest_api/util/string_processing/__init__.py +51 -0
  167. nv_ingest_api/util/string_processing/configuration.py +682 -0
  168. nv_ingest_api/util/string_processing/yaml.py +109 -0
  169. nv_ingest_api/util/system/__init__.py +0 -0
  170. nv_ingest_api/util/system/hardware_info.py +594 -0
  171. nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
  172. nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
  173. nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
  174. nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
  175. nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
  176. udfs/__init__.py +5 -0
  177. udfs/llm_summarizer_udf.py +259 -0
nv_ingest_api/internal/primitives/nim/model_interface/yolox.py
@@ -0,0 +1,1681 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+ # All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ import os
+ import logging
+ import warnings
+ from math import log
+ from typing import Any
+ from typing import Dict
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+
+ import backoff
+ import numpy as np
+ import json
+ import pandas as pd
+
+ from nv_ingest_api.internal.primitives.nim import ModelInterface
+ import tritonclient.grpc as grpcclient
+ from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
+ from nv_ingest_api.internal.primitives.nim.model_interface.helpers import get_model_name
+ from nv_ingest_api.util.image_processing import scale_image_to_encoding_size
+ from nv_ingest_api.util.image_processing.transforms import numpy_to_base64
+
+ logger = logging.getLogger(__name__)
+
+ YOLOX_PAGE_DEFAULT_VERSION = "nemoretriever-page-elements-v3"
+
+ # yolox-page-elements-v2 and v3 common constants
+ YOLOX_PAGE_CONF_THRESHOLD = 0.01
+ YOLOX_PAGE_IOU_THRESHOLD = 0.5
+ YOLOX_PAGE_MIN_SCORE = 0.1
+ YOLOX_PAGE_NIM_MAX_IMAGE_SIZE = 512_000
+ YOLOX_PAGE_IMAGE_PREPROC_HEIGHT = 1024
+ YOLOX_PAGE_IMAGE_PREPROC_WIDTH = 1024
+ YOLOX_PAGE_IMAGE_FORMAT = os.getenv("YOLOX_PAGE_IMAGE_FORMAT", "PNG")
+
+ # yolox-page-elements-v3 constants
+ YOLOX_PAGE_FINAL_SCORE = YOLOX_PAGE_V3_FINAL_SCORE = {
+     "table": 0.1,
+     "chart": 0.01,
+     "title": 0.1,
+     "infographic": 0.01,
+     "paragraph": 0.1,
+     "header_footer": 0.1,
+ }
+ YOLOX_PAGE_CLASS_LABELS = YOLOX_PAGE_V3_CLASS_LABELS = [
+     "table",
+     "chart",
+     "title",
+     "infographic",
+     "paragraph",
+     "header_footer",
+ ]
+
+ # yolox-page-elements-v2 constants
+ YOLOX_PAGE_V2_FINAL_SCORE = {"table": 0.1, "chart": 0.01, "infographic": 0.01}
+ YOLOX_PAGE_V2_CLASS_LABELS = [
+     "table",
+     "chart",
+     "title",
+     "infographic",
+ ]
+
+
+ # yolox-graphic-elements-v1 constants
+ YOLOX_GRAPHIC_CONF_THRESHOLD = 0.01
+ YOLOX_GRAPHIC_IOU_THRESHOLD = 0.25
+ YOLOX_GRAPHIC_MIN_SCORE = 0.1
+ YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE = 512_000
+
+
+ YOLOX_GRAPHIC_CLASS_LABELS = [
+     "chart_title",
+     "x_title",
+     "y_title",
+     "xlabel",
+     "ylabel",
+     "other",
+     "legend_label",
+     "legend_title",
+     "mark_label",
+     "value_label",
+ ]
+
+
+ # yolox-table-structure-v1 constants
+ YOLOX_TABLE_CONF_THRESHOLD = 0.01
+ YOLOX_TABLE_IOU_THRESHOLD = 0.25
+ YOLOX_TABLE_MIN_SCORE = 0.1
+ YOLOX_TABLE_NIM_MAX_IMAGE_SIZE = 512_000
+
+ YOLOX_TABLE_IMAGE_PREPROC_HEIGHT = 1024
+ YOLOX_TABLE_IMAGE_PREPROC_WIDTH = 1024
+
+ YOLOX_TABLE_CLASS_LABELS = [
+     "border",
+     "cell",
+     "row",
+     "column",
+     "header",
+ ]
+
+
+ # YoloxModelInterfaceBase implements methods that are common to yolox-page-elements and yolox-graphic-elements
+ class YoloxModelInterfaceBase(ModelInterface):
+     """
+     An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols.
+     """
+
+     def __init__(
+         self,
+         nim_max_image_size: Optional[int] = None,
+         conf_threshold: Optional[float] = None,
+         iou_threshold: Optional[float] = None,
+         min_score: Optional[float] = None,
+         class_labels: Optional[List[str]] = None,
+     ):
+         """
+         Initialize the YOLOX model interface.
+
+         Parameters
+         ----------
+         nim_max_image_size : int, optional
+             Maximum base64-encoded image size accepted by the NIM endpoint.
+         conf_threshold : float, optional
+             Confidence threshold applied during inference.
+         iou_threshold : float, optional
+             IoU threshold used when merging or suppressing boxes.
+         min_score : float, optional
+             Minimum score for a detection to be kept during postprocessing.
+         class_labels : list of str, optional
+             Names of the classes the model detects.
+         """
+         self.nim_max_image_size = nim_max_image_size
+         self.conf_threshold = conf_threshold
+         self.iou_threshold = iou_threshold
+         self.min_score = min_score
+         self.class_labels = class_labels
+
+     def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Prepare input data for inference by resizing images and storing their original shapes.
+
+         Parameters
+         ----------
+         data : dict
+             The input data containing a list of images.
+
+         Returns
+         -------
+         dict
+             The updated data dictionary with resized images and original image shapes.
+         """
+         if (not isinstance(data, dict)) or ("images" not in data):
+             raise KeyError("Input data must be a dictionary containing an 'images' key with a list of images.")
+
+         if not all(isinstance(x, np.ndarray) for x in data["images"]):
+             raise ValueError("All elements in the 'images' list must be numpy.ndarray objects.")
+
+         original_images = data["images"]
+         data["original_image_shapes"] = [image.shape for image in original_images]
+
+         return data
+
+     def format_input(
+         self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
+     ) -> Tuple[List[Any], List[Dict[str, Any]]]:
+         """
+         Format input data for the specified protocol, returning a tuple of:
+             (formatted_batches, formatted_batch_data)
+         where:
+           - For gRPC: formatted_batches is a list of [input_array, thresholds] pairs, where
+             input_array holds base64-encoded images with B <= max_batch_size per batch.
+           - For HTTP: formatted_batches is a list of JSON-serializable dict payloads.
+           - In both cases, formatted_batch_data is a list of dicts that coalesce the original
+             images and their original shapes in the same order as provided.
+
+         Parameters
+         ----------
+         data : dict
+             The input data to format. Must include:
+               - "images": a list of numpy.ndarray images.
+               - "original_image_shapes": a list of tuples with each image's (height, width),
+                 as set by prepare_data_for_inference.
+         protocol : str
+             The protocol to use ("grpc" or "http").
+         max_batch_size : int
+             The maximum number of images per batch.
+
+         Returns
+         -------
+         tuple
+             A tuple (formatted_batches, formatted_batch_data).
+
+         Raises
+         ------
+         ValueError
+             If the protocol is invalid.
+         """
+
+         # Helper functions to chunk a list into sublists of length up to chunk_size.
+         def chunk_list(lst: list, chunk_size: int) -> List[list]:
+             chunk_size = max(1, chunk_size)
+             return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+         def chunk_list_geometrically(lst: list, max_size: int) -> List[list]:
+             # TRT engine in Yolox NIM (gRPC) only allows a batch size in powers of 2.
+             chunks = []
+             i = 0
+             while i < len(lst):
+                 chunk_size = max(1, min(2 ** int(log(len(lst) - i, 2)), max_size))
+                 chunks.append(lst[i : i + chunk_size])
+                 i += chunk_size
+             return chunks
+
+         if protocol == "grpc":
+             logger.debug("Formatting input for gRPC Yolox Ensemble model")
+             b64_images = [numpy_to_base64(image, format=YOLOX_PAGE_IMAGE_FORMAT) for image in data["images"]]
+             b64_chunks = chunk_list_geometrically(b64_images, max_batch_size)
+             original_chunks = chunk_list_geometrically(data["images"], max_batch_size)
+             shape_chunks = chunk_list_geometrically(data["original_image_shapes"], max_batch_size)
+
+             batched_inputs = []
+             formatted_batch_data = []
+             for b64_chunk, orig_chunk, shapes in zip(b64_chunks, original_chunks, shape_chunks):
+                 input_array = np.array(b64_chunk, dtype=np.object_)
+
+                 if getattr(self, "_grpc_uses_bls", False):
+                     # For BLS with dynamic batching (max_batch_size > 0), we need to add explicit batch dimension
+                     # Shape [N] becomes [1, N] to indicate: batch of 1, containing N images
+                     input_array = input_array.reshape(1, -1)
+                     thresholds = np.array([[self.conf_threshold, self.iou_threshold]], dtype=np.float32)
+                 else:
+                     current_batch_size = input_array.shape[0]
+                     single_threshold_pair = [self.conf_threshold, self.iou_threshold]
+                     thresholds = np.tile(single_threshold_pair, (current_batch_size, 1)).astype(np.float32)
+
+                 batched_inputs.append([input_array, thresholds])
+                 formatted_batch_data.append({"images": orig_chunk, "original_image_shapes": shapes})
+
+             return batched_inputs, formatted_batch_data
+
+         elif protocol == "http":
+             logger.debug("Formatting input for HTTP Yolox model")
+             content_list: List[Dict[str, Any]] = []
+             for image in data["images"]:
+                 # Convert to uint8 if needed.
+                 if image.dtype != np.uint8:
+                     image = (image * 255).astype(np.uint8)
+
+                 # Get original size directly from numpy array (width, height)
+                 original_size = (image.shape[1], image.shape[0])
+                 # Convert numpy array directly to base64 using OpenCV
+                 image_b64 = numpy_to_base64(image, format=YOLOX_PAGE_IMAGE_FORMAT)
+                 # Scale the image if necessary.
+                 scaled_image_b64, new_size = scale_image_to_encoding_size(
+                     image_b64, max_base64_size=self.nim_max_image_size
+                 )
+                 if new_size != original_size:
+                     logger.debug(f"Image was scaled from {original_size} to {new_size}.")
+
+                 content_list.append({"type": "image_url", "url": f"data:image/png;base64,{scaled_image_b64}"})
+
+             # Chunk the payload content, the original images, and their shapes.
+             content_chunks = chunk_list(content_list, max_batch_size)
+             original_chunks = chunk_list(data["images"], max_batch_size)
+             shape_chunks = chunk_list(data["original_image_shapes"], max_batch_size)
+
+             payload_batches = []
+             formatted_batch_data = []
+             for chunk, orig_chunk, shapes in zip(content_chunks, original_chunks, shape_chunks):
+                 payload = {"input": chunk}
+                 payload_batches.append(payload)
+                 formatted_batch_data.append({"images": orig_chunk, "original_image_shapes": shapes})
+             return payload_batches, formatted_batch_data
+
+         else:
+             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
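
As a quick illustration of the power-of-2 chunking used on the gRPC path above, here is a standalone sketch of the nested helper's behavior (reimplemented for illustration; not part of the package):

from math import log

def chunk_sizes(n: int, max_size: int) -> list:
    # Mirrors chunk_list_geometrically: each chunk is the largest power of 2
    # that fits in both the remaining item count and max_size.
    sizes, i = [], 0
    while i < n:
        size = max(1, min(2 ** int(log(n - i, 2)), max_size))
        sizes.append(size)
        i += size
    return sizes

print(chunk_sizes(10, 8))  # [8, 2]
print(chunk_sizes(3, 8))   # [2, 1]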
+
+     def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+         """
+         Parse the output from the model's inference response.
+
+         Parameters
+         ----------
+         response : Any
+             The response from the model inference.
+         protocol : str
+             The protocol used ("grpc" or "http").
+         data : dict, optional
+             Additional input data passed to the function.
+
+         Returns
+         -------
+         Any
+             The parsed output data.
+
+         Raises
+         ------
+         ValueError
+             If an invalid protocol is specified or the response format is unexpected.
+         """
+
+         if protocol == "grpc":
+             logger.debug("Parsing output from gRPC Yolox model")
+             return response  # For gRPC, response is already a numpy array
+         elif protocol == "http":
+             logger.debug("Parsing output from HTTP Yolox model")
+
+             processed_outputs = []
+
+             batch_results = response.get("data", [])
+             for detections in batch_results:
+                 new_bounding_boxes = {label: [] for label in self.class_labels}
+
+                 bounding_boxes = detections.get("bounding_boxes", [])
+                 for obj_type, bboxes in bounding_boxes.items():
+                     for bbox in bboxes:
+                         xmin = bbox["x_min"]
+                         ymin = bbox["y_min"]
+                         xmax = bbox["x_max"]
+                         ymax = bbox["y_max"]
+                         confidence = bbox["confidence"]
+
+                         new_bounding_boxes[obj_type].append([xmin, ymin, xmax, ymax, confidence])
+
+                 processed_outputs.append(new_bounding_boxes)
+
+             return processed_outputs
+         else:
+             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
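
For reference, a hypothetical HTTP response body shaped the way the "http" branch above expects, and the structure parse_output would produce from it (values are illustrative):

response = {
    "data": [
        {
            "bounding_boxes": {
                "table": [{"x_min": 0.1, "y_min": 0.2, "x_max": 0.5, "y_max": 0.6, "confidence": 0.9}],
                "chart": [],
            }
        }
    ]
}
# parse_output(response, protocol="http") yields one dict per image, e.g.:
# [{"table": [[0.1, 0.2, 0.5, 0.6, 0.9]], "chart": [], ...remaining class labels: []}]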
+
+     def process_inference_results(self, output: Any, protocol: str, **kwargs) -> List[Dict[str, Any]]:
+         """
+         Process the results of the Yolox model inference and return the final annotations.
+
+         Parameters
+         ----------
+         output : Any
+             The raw (gRPC) or parsed (HTTP) output from the Yolox model.
+         protocol : str
+             The protocol used ("grpc" or "http").
+         kwargs : dict
+             Additional parameters for processing, including thresholds and number of classes.
+
+         Returns
+         -------
+         list[dict]
+             A list of annotation dictionaries for each image in the batch.
+         """
+         if protocol == "http":
+             # For http, the output already has postprocessing applied. Skip to table/chart expansion.
+             results = output
+
+         elif protocol == "grpc":
+             results = []
+             # For grpc, apply the same NIM postprocessing.
+             for out in output:
+                 if isinstance(out, bytes):
+                     out = out.decode("utf-8")
+                 if isinstance(out, dict):
+                     continue
+                 results.append(json.loads(out))
+
+         else:
+             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+         inference_results = self.postprocess_annotations(results, **kwargs)
+         return inference_results
+
+     def postprocess_annotations(self, annotation_dicts, **kwargs):
+         raise NotImplementedError()
+
+     def transform_normalized_coordinates_to_original(self, results, original_image_shapes):
+         """
+         Scale normalized (0-1) bounding boxes back into pixel coordinates using each image's original shape.
+         """
+         transformed_results = []
+
+         for annotation_dict, shape in zip(results, original_image_shapes):
+             new_dict = {}
+             for label, bboxes_and_scores in annotation_dict.items():
+                 new_dict[label] = []
+                 for bbox_and_score in bboxes_and_scores:
+                     bbox = bbox_and_score[:4]
+                     transformed_bbox = [
+                         bbox[0] * shape[1],
+                         bbox[1] * shape[0],
+                         bbox[2] * shape[1],
+                         bbox[3] * shape[0],
+                     ]
+                     transformed_bbox += bbox_and_score[4:]
+                     new_dict[label].append(transformed_bbox)
+             transformed_results.append(new_dict)
+
+         return transformed_results
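
A worked example of the coordinate transform, assuming one normalized table box on a 900x1600 page (shape is (height, width, channels)):

results = [{"table": [[0.25, 0.10, 0.75, 0.50, 0.93]]}]
shapes = [(900, 1600, 3)]
# x coordinates scale by shape[1] (width), y coordinates by shape[0] (height):
# transform_normalized_coordinates_to_original(results, shapes)
# -> [{"table": [[400.0, 90.0, 1200.0, 450.0, 0.93]]}]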
+
+
+ class YoloxPageElementsModelInterface(YoloxModelInterfaceBase):
+     """
+     An interface for handling inference with yolox-page-elements model, supporting both gRPC and HTTP protocols.
+     """
+
+     def __init__(self, version: str = YOLOX_PAGE_DEFAULT_VERSION):
+         """
+         Initialize the yolox-page-elements model interface.
+         """
+         self.version = version
+
+         if self.version.endswith("-v3"):
+             class_labels = YOLOX_PAGE_V3_CLASS_LABELS
+             self._grpc_uses_bls = True
+         else:
+             class_labels = YOLOX_PAGE_V2_CLASS_LABELS
+             self._grpc_uses_bls = False
+
+         super().__init__(
+             nim_max_image_size=YOLOX_PAGE_NIM_MAX_IMAGE_SIZE,
+             conf_threshold=YOLOX_PAGE_CONF_THRESHOLD,
+             iou_threshold=YOLOX_PAGE_IOU_THRESHOLD,
+             min_score=YOLOX_PAGE_MIN_SCORE,
+             class_labels=class_labels,
+         )
+
+     def name(
+         self,
+     ) -> str:
+         """
+         Returns the name of the Yolox model interface.
+
+         Returns
+         -------
+         str
+             The name of the model interface.
+         """
+
+         return "yolox-page-elements"
+
+     def postprocess_annotations(self, annotation_dicts, final_score=None, **kwargs):
+         original_image_shapes = kwargs.get("original_image_shapes", [])
+
+         running_v3 = annotation_dicts and set(YOLOX_PAGE_V3_CLASS_LABELS) <= annotation_dicts[0].keys()
+
+         if not final_score:
+             if running_v3:
+                 final_score = YOLOX_PAGE_V3_FINAL_SCORE
+             else:
+                 final_score = YOLOX_PAGE_V2_FINAL_SCORE
+
+         if running_v3:
+             expected_final_score_keys = list(YOLOX_PAGE_V3_FINAL_SCORE)
+         else:
+             expected_final_score_keys = [x for x in YOLOX_PAGE_V2_FINAL_SCORE if x != "title"]
+
+         if (not isinstance(final_score, dict)) or (sorted(final_score.keys()) != sorted(expected_final_score_keys)):
+             raise ValueError(
+                 "yolox-page-elements requires a dictionary of thresholds per each class: "
+                 f"{expected_final_score_keys}"
+             )
+
+         if annotation_dicts and running_v3:
+             annotation_dicts = [
+                 postprocess_page_elements_v3(annotation_dict, labels=YOLOX_PAGE_V3_CLASS_LABELS)
+                 for annotation_dict in annotation_dicts
+             ]
+         else:
+             # Table/chart expansion is "business logic" specific to nv-ingest
+             annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in annotation_dicts]
+             annotation_dicts = [expand_chart_bboxes(annotation_dict) for annotation_dict in annotation_dicts]
+
+         inference_results = []
+
+         # Filter out bounding boxes below the final threshold
+         # This final thresholding is "business logic" specific to nv-ingest
+         for annotation_dict in annotation_dicts:
+             new_dict = {}
+             if running_v3:
+                 for label in YOLOX_PAGE_V3_CLASS_LABELS:
+                     if label in annotation_dict and label in final_score:
+                         threshold = final_score[label]
+                         new_dict[label] = [bb for bb in annotation_dict[label] if bb[4] >= threshold]
+             else:
+                 for label in YOLOX_PAGE_V2_CLASS_LABELS:
+                     if label in annotation_dict:
+                         if label == "title":
+                             new_dict[label] = annotation_dict[label]
+                         elif label in final_score:
+                             threshold = final_score[label]
+                             new_dict[label] = [bb for bb in annotation_dict[label] if bb[4] >= threshold]
+
+             inference_results.append(new_dict)
+
+         inference_results = self.transform_normalized_coordinates_to_original(inference_results, original_image_shapes)
+
+         return inference_results
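
A minimal usage sketch for this interface, assuming it is exercised directly rather than through the NIM client that normally drives it:

import numpy as np

iface = YoloxPageElementsModelInterface()  # defaults to the v3 page-elements model
data = iface.prepare_data_for_inference({"images": [np.zeros((1024, 1024, 3), dtype=np.uint8)]})
batches, batch_data = iface.format_input(data, protocol="http", max_batch_size=8)
# Each entry in `batches` is a JSON-serializable payload of the form
# {"input": [{"type": "image_url", "url": "data:image/png;base64,..."}]}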
+
+
+ class YoloxGraphicElementsModelInterface(YoloxModelInterfaceBase):
+     """
+     An interface for handling inference with the yolox-graphic-elements model, supporting both gRPC and HTTP protocols.
+     """
+
+     def __init__(self):
+         """
+         Initialize the yolox-graphic-elements model interface.
+         """
+         super().__init__(
+             nim_max_image_size=YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE,
+             conf_threshold=YOLOX_GRAPHIC_CONF_THRESHOLD,
+             iou_threshold=YOLOX_GRAPHIC_IOU_THRESHOLD,
+             min_score=YOLOX_GRAPHIC_MIN_SCORE,
+             class_labels=YOLOX_GRAPHIC_CLASS_LABELS,
+         )
+
+     def name(
+         self,
+     ) -> str:
+         """
+         Returns the name of the Yolox model interface.
+
+         Returns
+         -------
+         str
+             The name of the model interface.
+         """
+
+         return "yolox-graphic-elements"
+
+     def postprocess_annotations(self, annotation_dicts, **kwargs):
+         original_image_shapes = kwargs.get("original_image_shapes", [])
+
+         annotation_dicts = self.transform_normalized_coordinates_to_original(annotation_dicts, original_image_shapes)
+
+         inference_results = []
+
+         # bbox extraction: additional postprocessing specific to nv-ingest
+         for pred, shape in zip(annotation_dicts, original_image_shapes):
+             bbox_dict = get_bbox_dict_yolox_graphic(
+                 pred,
+                 shape,
+                 self.class_labels,
+                 self.min_score,
+             )
+             # convert numpy arrays to list
+             bbox_dict = {
+                 label: array.tolist() if isinstance(array, np.ndarray) else array for label, array in bbox_dict.items()
+             }
+             inference_results.append(bbox_dict)
+
+         return inference_results
+
+
+ class YoloxTableStructureModelInterface(YoloxModelInterfaceBase):
+     """
+     An interface for handling inference with the yolox-table-structure model, supporting both gRPC and HTTP protocols.
+     """
+
+     def __init__(self):
+         """
+         Initialize the yolox-table-structure model interface.
+         """
+         super().__init__(
+             nim_max_image_size=YOLOX_TABLE_NIM_MAX_IMAGE_SIZE,
+             conf_threshold=YOLOX_TABLE_CONF_THRESHOLD,
+             iou_threshold=YOLOX_TABLE_IOU_THRESHOLD,
+             min_score=YOLOX_TABLE_MIN_SCORE,
+             class_labels=YOLOX_TABLE_CLASS_LABELS,
+         )
+
+     def name(
+         self,
+     ) -> str:
+         """
+         Returns the name of the Yolox model interface.
+
+         Returns
+         -------
+         str
+             The name of the model interface.
+         """
+
+         return "yolox-table-structure"
+
+     def postprocess_annotations(self, annotation_dicts, **kwargs):
+         original_image_shapes = kwargs.get("original_image_shapes", [])
+
+         annotation_dicts = self.transform_normalized_coordinates_to_original(annotation_dicts, original_image_shapes)
+
+         inference_results = []
+
+         # bbox extraction: additional postprocessing specific to nv-ingest
+         for pred, shape in zip(annotation_dicts, original_image_shapes):
+             bbox_dict = get_bbox_dict_yolox_table(
+                 pred,
+                 shape,
+                 self.class_labels,
+                 self.min_score,
+             )
+             # convert numpy arrays to list
+             bbox_dict = {
+                 label: array.tolist() if isinstance(array, np.ndarray) else array for label, array in bbox_dict.items()
+             }
+             inference_results.append(bbox_dict)
+
+         return inference_results
+
+
+ def expand_table_bboxes(annotation_dict, labels=None):
+     """
+     Additional preprocessing for tables: extend the upper bounds to capture titles if any.
+
+     Args:
+         annotation_dict: output of postprocess_results, a dictionary with keys "table", "chart", "title"
+
+     Returns:
+         annotation_dict: same as input, with expanded bboxes for tables
+     """
+     if not labels:
+         labels = list(annotation_dict.keys())
+
+     if not annotation_dict or len(annotation_dict["table"]) == 0:
+         return annotation_dict
+
+     new_annotation_dict = {label: [] for label in labels}
+
+     for label, bboxes in annotation_dict.items():
+         for bbox_and_score in bboxes:
+             bbox, score = bbox_and_score[:4], bbox_and_score[4]
+
+             if label == "table":
+                 height = bbox[3] - bbox[1]
+                 bbox[1] = max(0.0, min(1.0, bbox[1] - height * 0.2))
+
+             new_annotation_dict[label].append([round(float(x), 4) for x in bbox + [score]])
+
+     return new_annotation_dict
+
+
+ def expand_chart_bboxes(annotation_dict, labels=None):
+     """
+     Expand bounding boxes of charts and titles based on the bounding boxes of the other class.
+
+     Args:
+         annotation_dict: output of postprocess_results, a dictionary with keys "table", "chart", "title"
+
+     Returns:
+         annotation_dict: same as input, with expanded bboxes for charts
+     """
+     if not labels:
+         labels = list(annotation_dict.keys())
+
+     if not annotation_dict or len(annotation_dict["chart"]) == 0:
+         return annotation_dict
+
+     bboxes = []
+     confidences = []
+     label_idxs = []
+     for i, label in enumerate(labels):
+         label_annotations = np.array(annotation_dict[label])
+
+         if len(label_annotations) > 0:
+             bboxes.append(label_annotations[:, :4])
+             confidences.append(label_annotations[:, 4])
+             label_idxs.append(np.full(len(label_annotations), i))
+     bboxes = np.concatenate(bboxes)
+     confidences = np.concatenate(confidences)
+     label_idxs = np.concatenate(label_idxs)
+
+     pred_wbf, confidences_wbf, labels_wbf = weighted_boxes_fusion(
+         bboxes[:, None],
+         confidences[:, None],
+         label_idxs[:, None],
+         merge_type="biggest",
+         conf_type="max",
+         iou_thr=0.01,
+         class_agnostic=False,
+     )
+
+     chart_bboxes = pred_wbf[labels_wbf == 1]
+     chart_confidences = confidences_wbf[labels_wbf == 1]
+     title_bboxes = pred_wbf[labels_wbf == 2]
+
+     found_title_idxs, no_found_title_idxs = [], []
+     for i in range(len(chart_bboxes)):
+         match = match_with_title_v1(chart_bboxes[i], title_bboxes, iou_th=0.01)
+         if match is not None:
+             chart_bboxes[i] = match[0]
+             title_bboxes = match[1]
+             found_title_idxs.append(i)
+         else:
+             no_found_title_idxs.append(i)
+
+     chart_bboxes[found_title_idxs] = expand_boxes_v1(chart_bboxes[found_title_idxs], r_x=1.05, r_y=1.1)
+     chart_bboxes[no_found_title_idxs] = expand_boxes_v1(chart_bboxes[no_found_title_idxs], r_x=1.1, r_y=1.25)
+
+     annotation_dict = {
+         "table": annotation_dict["table"],
+         "chart": np.concatenate([chart_bboxes, chart_confidences[:, None]], axis=1).tolist(),
+         "title": annotation_dict["title"],
+     }
+
+     return annotation_dict
+
+
+ def postprocess_page_elements_v3(annotation_dict, labels=None):
+     """
+     Expand bounding boxes of tables/charts/infographics and titles based on the bounding boxes of the other class.
+
+     Args:
+         annotation_dict: output of postprocess_results, a dictionary with keys:
+             "table", "chart", "infographic", "title", "paragraph", "header_footer".
+
+     Returns:
+         annotation_dict: same as input, with expanded bboxes for page elements.
+     """
+     if not labels:
+         labels = list(annotation_dict.keys())
+
+     if not annotation_dict:
+         return annotation_dict
+
+     bboxes = []
+     confidences = []
+     label_idxs = []
+
+     for i, label in enumerate(labels):
+         if label not in annotation_dict:
+             continue
+
+         label_annotations = np.array(annotation_dict[label])
+
+         if len(label_annotations) > 0:
+             bboxes.append(label_annotations[:, :4])
+             confidences.append(label_annotations[:, 4])
+             label_idxs.append(np.full(len(label_annotations), i))
+
+     if not bboxes:
+         return annotation_dict
+
+     bboxes = np.concatenate(bboxes)
+     confidences = np.concatenate(confidences)
+     label_idxs = np.concatenate(label_idxs)
+
+     bboxes, confidences, label_idxs = remove_overlapping_boxes_using_wbf(bboxes, confidences, label_idxs)
+     bboxes, confidences, label_idxs, found_title = match_structured_boxes_with_title(
+         bboxes, confidences, label_idxs, labels
+     )
+     bboxes, confidences, label_idxs = expand_tables_and_charts(bboxes, confidences, label_idxs, labels, found_title)
+     bboxes, confidences, label_idxs = postprocess_included_texts(bboxes, confidences, label_idxs, labels)
+
+     order = np.argsort(bboxes[:, 1] * 10 + bboxes[:, 0])
+     bboxes, confidences, label_idxs = bboxes[order], confidences[order], label_idxs[order]
+
+     new_annotation_dict = {}
+     for i, label in enumerate(labels):
+         selected_bboxes = bboxes[label_idxs == i]
+         selected_confidences = confidences[label_idxs == i]
+         new_annotation_dict[label] = np.concatenate([selected_bboxes, selected_confidences[:, None]], axis=1).tolist()
+
+     return new_annotation_dict
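
The final argsort key above (y0 * 10 + x0) orders boxes roughly top-to-bottom, then left-to-right, for normalized coordinates. A small illustration:

import numpy as np

# (x0, y0) pairs for three boxes: two on the top row, one lower down.
tops = np.array([[0.8, 0.1], [0.1, 0.5], [0.2, 0.1]])
np.argsort(tops[:, 1] * 10 + tops[:, 0])
# -> array([2, 0, 1]): top row first (left box before right), then the lower box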
+
+
+ def weighted_boxes_fusion(
+     boxes_list,
+     scores_list,
+     labels_list,
+     iou_thr=0.5,
+     skip_box_thr=0.0,
+     conf_type="avg",
+     merge_type="weighted",
+     class_agnostic=False,
+ ):
+     """
+     Custom wbf implementation that supports a class_agnostic mode and a biggest box fusion.
+     Boxes are expected to be in normalized (x0, y0, x1, y1) format.
+
+     Args:
+         boxes_list (list[np array[n x 4]]): List of boxes. One list per model.
+         scores_list (list[np array[n]]): List of confidences.
+         labels_list (list[np array[n]]): List of labels.
+         iou_thr (float, optional): IoU threshold for matching. Defaults to 0.5.
+         skip_box_thr (float, optional): Exclude boxes with score < skip_box_thr. Defaults to 0.0.
+         conf_type (str, optional): Confidence merging type. Defaults to "avg".
+         merge_type (str, optional): Merge type "weighted" or "biggest". Defaults to "weighted".
+         class_agnostic (bool, optional): If True, merge boxes from different classes. Defaults to False.
+
+     Returns:
+         np array[N x 4]: Merged boxes,
+         np array[N]: Merged confidences,
+         np array[N]: Merged labels.
+     """
+     weights = np.ones(len(boxes_list))
+
+     assert conf_type in ["avg", "max"], 'Conf type must be "avg" or "max"'
+     assert merge_type in [
+         "weighted",
+         "biggest",
+     ], 'Merge type must be "weighted" or "biggest"'
+
+     filtered_boxes = prefilter_boxes(
+         boxes_list,
+         scores_list,
+         labels_list,
+         weights,
+         skip_box_thr,
+         class_agnostic=class_agnostic,
+     )
+     if len(filtered_boxes) == 0:
+         return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))
+
+     overall_boxes = []
+     for label in filtered_boxes:
+         boxes = filtered_boxes[label]
+
+         clusters = []
+
+         # Clusterize boxes
+         for j in range(len(boxes)):
+             ids = [i for i in range(len(boxes)) if i != j]
+             index, best_iou = find_matching_box_fast(boxes[ids], boxes[j], iou_thr)
+
+             if index != -1:
+                 index = ids[index]
+                 cluster_idx = [clust_idx for clust_idx, clust in enumerate(clusters) if (j in clust or index in clust)]
+                 if len(cluster_idx):
+                     cluster_idx = cluster_idx[0]
+                     clusters[cluster_idx] = list(set(clusters[cluster_idx] + [index, j]))
+                 else:
+                     clusters.append([index, j])
+             else:
+                 clusters.append([j])
+
+         for j, c in enumerate(clusters):
+             if merge_type == "weighted":
+                 weighted_box = get_weighted_box(boxes[c], conf_type)
+             elif merge_type == "biggest":
+                 weighted_box = get_biggest_box(boxes[c], conf_type)
+
+             if conf_type == "max":
+                 weighted_box[1] = weighted_box[1] / weights.max()
+             else:  # avg
+                 weighted_box[1] = weighted_box[1] * len(c) / weights.sum()
+             overall_boxes.append(weighted_box)
+
+     overall_boxes = np.array(overall_boxes)
+     overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
+     boxes = overall_boxes[:, 4:]
+     scores = overall_boxes[:, 1]
+     labels = overall_boxes[:, 0]
+     return boxes, scores, labels
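
A small usage example: two heavily overlapping single-class detections, fed in the one-box-per-"model" layout used by the callers above, collapse to one enclosing box under merge_type="biggest" (illustrative values):

import numpy as np

boxes = np.array([[[0.10, 0.10, 0.50, 0.50]], [[0.12, 0.12, 0.52, 0.52]]])
scores = np.array([[0.9], [0.8]])
labels = np.array([[0], [0]])
fused, confs, lbls = weighted_boxes_fusion(
    boxes, scores, labels, iou_thr=0.5, merge_type="biggest", conf_type="max"
)
# fused -> [[0.10, 0.10, 0.52, 0.52]]; confs -> [0.9]; lbls -> [0.]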
+
+
+ def prefilter_boxes(boxes, scores, labels, weights, thr, class_agnostic=False):
+     """
+     Reformats and filters boxes.
+     Output is a dict of boxes to merge separately.
+
+     Args:
+         boxes (list[np array[n x 4]]): List of boxes. One list per model.
+         scores (list[np array[n]]): List of confidences.
+         labels (list[np array[n]]): List of labels.
+         weights (list): Model weights.
+         thr (float): Confidence threshold.
+         class_agnostic (bool, optional): If True, merge boxes from different classes. Defaults to False.
+
+     Returns:
+         dict[np array [? x 8]]: Filtered boxes.
+     """
+     # Create dict with boxes stored by its label
+     new_boxes = dict()
+
+     for t in range(len(boxes)):
+         if len(boxes[t]) != len(scores[t]):
+             print(
+                 "Error. Length of boxes arrays not equal to length of scores array: {} != {}".format(
+                     len(boxes[t]), len(scores[t])
+                 )
+             )
+             exit()
+
+         if len(boxes[t]) != len(labels[t]):
+             print(
+                 "Error. Length of boxes arrays not equal to length of labels array: {} != {}".format(
+                     len(boxes[t]), len(labels[t])
+                 )
+             )
+             exit()
+
+         for j in range(len(boxes[t])):
+             score = scores[t][j]
+             if score < thr:
+                 continue
+             label = int(labels[t][j])
+             box_part = boxes[t][j]
+             x1 = float(box_part[0])
+             y1 = float(box_part[1])
+             x2 = float(box_part[2])
+             y2 = float(box_part[3])
+
+             # Box data checks
+             if x2 < x1:
+                 warnings.warn("X2 < X1 value in box. Swap them.")
+                 x1, x2 = x2, x1
+             if y2 < y1:
+                 warnings.warn("Y2 < Y1 value in box. Swap them.")
+                 y1, y2 = y2, y1
+             if x1 < 0:
+                 warnings.warn("X1 < 0 in box. Set it to 0.")
+                 x1 = 0
+             if x1 > 1:
+                 warnings.warn("X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+                 x1 = 1
+             if x2 < 0:
+                 warnings.warn("X2 < 0 in box. Set it to 0.")
+                 x2 = 0
+             if x2 > 1:
+                 warnings.warn("X2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+                 x2 = 1
+             if y1 < 0:
+                 warnings.warn("Y1 < 0 in box. Set it to 0.")
+                 y1 = 0
+             if y1 > 1:
+                 warnings.warn("Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+                 y1 = 1
+             if y2 < 0:
+                 warnings.warn("Y2 < 0 in box. Set it to 0.")
+                 y2 = 0
+             if y2 > 1:
+                 warnings.warn("Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+                 y2 = 1
+             if (x2 - x1) * (y2 - y1) == 0.0:
+                 warnings.warn("Zero area box skipped: {}.".format(box_part))
+                 continue
+
+             # [label, score, weight, model index, x1, y1, x2, y2]
+             b = [int(label), float(score) * weights[t], weights[t], t, x1, y1, x2, y2]
+
+             label_k = "*" if class_agnostic else label
+             if label_k not in new_boxes:
+                 new_boxes[label_k] = []
+             new_boxes[label_k].append(b)
+
+     # Sort each list in dict by score and transform it to numpy array
+     for k in new_boxes:
+         current_boxes = np.array(new_boxes[k])
+         new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
+
+     return new_boxes
+
+
+ def find_matching_box_fast(boxes_list, new_box, match_iou):
+     """
+     Reimplementation of find_matching_box with numpy instead of loops. Gives significant speed up for larger arrays
+     (~100x). This was previously the bottleneck since the function is called for every entry in the array.
+     """
+
+     def bb_iou_array(boxes, new_box):
+         # bb intersection over union
+         xA = np.maximum(boxes[:, 0], new_box[0])
+         yA = np.maximum(boxes[:, 1], new_box[1])
+         xB = np.minimum(boxes[:, 2], new_box[2])
+         yB = np.minimum(boxes[:, 3], new_box[3])
+
+         interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
+
+         # compute the area of both the prediction and ground-truth rectangles
+         boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+         boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
+
+         iou = interArea / (boxAArea + boxBArea - interArea)
+
+         return iou
+
+     if boxes_list.shape[0] == 0:
+         return -1, match_iou
+
+     ious = bb_iou_array(boxes_list[:, 4:], new_box[4:])
+     # ious[boxes[:, 0] != new_box[0]] = -1
+
+     best_idx = np.argmax(ious)
+     best_iou = ious[best_idx]
+
+     if best_iou <= match_iou:
+         best_iou = match_iou
+         best_idx = -1
+
+     return best_idx, best_iou
+
+
+ def get_biggest_box(boxes, conf_type="avg"):
+     """
+     Merges boxes by using the biggest box.
+
+     Args:
+         boxes (np array [n x 8]): Boxes to merge.
+         conf_type (str, optional): Confidence merging type. Defaults to "avg".
+
+     Returns:
+         np array [8]: Merged box.
+     """
+     box = np.zeros(8, dtype=np.float32)
+     box[4:] = boxes[0][4:]
+     conf_list = []
+     w = 0
+     for b in boxes:
+         box[4] = min(box[4], b[4])
+         box[5] = min(box[5], b[5])
+         box[6] = max(box[6], b[6])
+         box[7] = max(box[7], b[7])
+         conf_list.append(b[1])
+         w += b[2]
+
+     box[0] = merge_labels(np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes]))
+     # print(box[0], np.array([b[0] for b in boxes]))
+
+     box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
+     box[2] = w
+     box[3] = -1  # model index field is retained for consistency but is not used.
+     return box
+
+
+ def merge_labels(labels, confs):
+     """
+     Custom function for merging labels.
+     If all labels are the same, return the unique value.
+     Else, return the label of the most confident non-title (class 2) box.
+
+     Args:
+         labels (np array [n]): Labels.
+         confs (np array [n]): Confidence.
+
+     Returns:
+         int: Label.
+     """
+     if len(np.unique(labels)) == 1:
+         return labels[0]
+     else:  # Most confident and not a title
+         confs = confs[labels != 2]  # mask confidences by label so title boxes are excluded
+         labels = labels[labels != 2]
+         return labels[np.argmax(confs)]
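
A quick check of the tie-breaking rule (note the mask is applied by label, so the title class is dropped before taking the most confident remainder):

import numpy as np

merge_labels(np.array([2, 1]), np.array([0.95, 0.40]))  # -> 1 (title ignored)
merge_labels(np.array([0, 0]), np.array([0.30, 0.70]))  # -> 0 (unanimous)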
+
+
+ def match_with_title_v1(chart_bbox, title_bboxes, iou_th=0.01):
+     if not len(title_bboxes):
+         return None
+
+     dist_above = np.abs(title_bboxes[:, 3] - chart_bbox[1])
+     dist_below = np.abs(chart_bbox[3] - title_bboxes[:, 1])
+
+     dist_left = np.abs(title_bboxes[:, 0] - chart_bbox[0])
+
+     ious = bb_iou_array(title_bboxes, chart_bbox)
+
+     matches = None
+     if np.max(ious) > iou_th:
+         matches = np.where(ious > iou_th)[0]
+     else:
+         dists = np.min([dist_above, dist_below], 0)
+         dists += dist_left
+         # print(dists)
+         if np.min(dists) < 0.1:
+             matches = [np.argmin(dists)]
+
+     if matches is not None:
+         new_bbox = chart_bbox
+         for match in matches:
+             new_bbox = merge_boxes(new_bbox, title_bboxes[match])
+         title_bboxes = title_bboxes[[i for i in range(len(title_bboxes)) if i not in matches]]
+         return new_bbox, title_bboxes
+     else:
+         return None
+
+
+ def bb_iou_array(boxes, new_box):
+     # bb intersection over union
+     xA = np.maximum(boxes[:, 0], new_box[0])
+     yA = np.maximum(boxes[:, 1], new_box[1])
+     xB = np.minimum(boxes[:, 2], new_box[2])
+     yB = np.minimum(boxes[:, 3], new_box[3])
+
+     interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
+
+     # compute the area of both the prediction and ground-truth rectangles
+     boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+     boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
+
+     iou = interArea / (boxAArea + boxBArea - interArea)
+
+     return iou
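
A numeric sanity check: two unit squares offset by half a side in each axis intersect in a quarter box, giving IoU = 0.25 / 1.75:

import numpy as np

boxes = np.array([[0.0, 0.0, 1.0, 1.0]])
new_box = np.array([0.5, 0.5, 1.5, 1.5])
bb_iou_array(boxes, new_box)  # -> array([0.14285714])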
+
+
+ def merge_boxes(b1, b2):
+     b = b1.copy()
+     b[0] = min(b1[0], b2[0])
+     b[1] = min(b1[1], b2[1])
+     b[2] = max(b1[2], b2[2])
+     b[3] = max(b1[3], b2[3])
+     return b
+
+
+ def expand_boxes_v1(boxes, r_x=1, r_y=1):
+     dw = (boxes[:, 2] - boxes[:, 0]) / 2 * (r_x - 1)
+     boxes[:, 0] -= dw
+     boxes[:, 2] += dw
+
+     dh = (boxes[:, 3] - boxes[:, 1]) / 2 * (r_y - 1)
+     boxes[:, 1] -= dh
+     boxes[:, 3] += dh
+
+     boxes = np.clip(boxes, 0, 1)
+     return boxes
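
For example, growing a centered box by 10% in width and 25% in height (note the function mutates its input in place, hence the copy):

import numpy as np

boxes = np.array([[0.40, 0.40, 0.60, 0.60]])
expand_boxes_v1(boxes.copy(), r_x=1.1, r_y=1.25)
# -> [[0.39, 0.375, 0.61, 0.625]]  (width 0.20 -> 0.22, height 0.20 -> 0.25)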
+
+
+ def get_weighted_box(boxes, conf_type="avg"):
+     """
+     Merges boxes by using the weighted fusion.
+
+     Args:
+         boxes (np array [n x 8]): Boxes to merge.
+         conf_type (str, optional): Confidence merging type. Defaults to "avg".
+
+     Returns:
+         np array [8]: Merged box.
+     """
+     box = np.zeros(8, dtype=np.float32)
+     conf = 0
+     conf_list = []
+     w = 0
+     for b in boxes:
+         box[4:] += b[1] * b[4:]
+         conf += b[1]
+         conf_list.append(b[1])
+         w += b[2]
+
+     box[0] = merge_labels(np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes]))
+
+     box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
+     box[2] = w
+     box[3] = -1  # model index field is retained for consistency but is not used.
+     box[4:] /= conf
+     return box
+
+
+ def batched_overlaps(A, B):
+     """
+     Calculate the Intersection over Union (IoU) between
+     two sets of bounding boxes in a batched manner.
+     Normalization is modified to only use the area of A boxes, hence computing the overlaps.
+
+     Args:
+         A (ndarray): Array of bounding boxes of shape (N, 4) in format [x1, y1, x2, y2].
+         B (ndarray): Array of bounding boxes of shape (M, 4) in format [x1, y1, x2, y2].
+
+     Returns:
+         ndarray: Array of IoU values of shape (N, M) representing the overlaps
+             between each pair of bounding boxes.
+     """
+     A = A.copy()
+     B = B.copy()
+
+     A = A[None].repeat(B.shape[0], 0)
+     B = B[:, None].repeat(A.shape[1], 1)
+
+     low = np.s_[..., :2]
+     high = np.s_[..., 2:]
+
+     A, B = A.copy(), B.copy()
+     A[high] += 1
+     B[high] += 1
+
+     intrs = (np.maximum(0, np.minimum(A[high], B[high]) - np.maximum(A[low], B[low]))).prod(-1)
+     ious = intrs / (A[high] - A[low]).prod(-1)
+
+     return ious
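
Because the denominator uses only the A-box areas, a small box fully contained in a large one scores a full 1.0 overlap (the +1 on the high coordinates reflects an inclusive integer-pixel convention):

import numpy as np

A = np.array([[2, 2, 4, 4]])    # small box
B = np.array([[0, 0, 10, 10]])  # big box containing A
batched_overlaps(A, B)  # -> array([[1.]])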
+
+
+ def find_boxes_inside(boxes, boxes_to_check, threshold=0.9):
+     """
+     Find all boxes that are inside another box, based on
+     the intersection area divided by the area of the smaller box,
+     and remove them.
+     """
+     overlaps = batched_overlaps(boxes_to_check, boxes)
+     to_keep = (overlaps >= threshold).sum(0) <= 1
+     return boxes_to_check[to_keep]
+
+
+ def get_bbox_dict_yolox_graphic(preds, shape, class_labels, threshold_=0.1) -> Dict[str, np.ndarray]:
+     """
+     Extracts bounding boxes from YOLOX model predictions:
+     - Applies thresholding
+     - Reformats boxes
+     - Cleans the `other` detections: removes the ones that are included in other detections.
+     - If no title is found, the biggest `other` box is used if it is larger than 0.3*img_w.
+
+     Args:
+         preds (np.ndarray): YOLOX model predictions including bounding boxes, scores, and labels.
+         shape (tuple): Original image shape.
+         class_labels (list): Class labels to extract.
+         threshold_ (float): Score threshold to filter bounding boxes.
+
+     Returns:
+         Dict[str, np.ndarray]: Dictionary of bounding boxes, organized by class.
+     """
+     bbox_dict = {label: np.array([]) for label in class_labels}
+
+     for i, label in enumerate(class_labels):
+         if label not in preds:
+             continue
+
+         bboxes_class = np.array(preds[label])
+
+         if bboxes_class.size == 0:
+             continue
+
+         # Try to find a chart_title box
+         threshold = threshold_ if label != "chart_title" else min(threshold_, bboxes_class[:, -1].max())
+         bboxes_class = bboxes_class[bboxes_class[:, -1] >= threshold][:, :4].astype(int)
+
+         sort = ["x0", "y0"] if label != "ylabel" else ["y0", "x0"]
+         idxs = (
+             pd.DataFrame(
+                 {
+                     "y0": bboxes_class[:, 1],
+                     "x0": bboxes_class[:, 0],
+                 }
+             )
+             .sort_values(sort, ascending=label != "ylabel")
+             .index
+         )
+         bboxes_class = bboxes_class[idxs]
+         bbox_dict[label] = bboxes_class
+
+     # Remove other included
+     if len(bbox_dict.get("other", [])):
+         other = find_boxes_inside(
+             np.concatenate(list([v for v in bbox_dict.values() if len(v)])), bbox_dict["other"], threshold=0.7
+         )
+         del bbox_dict["other"]
+         if len(other):
+             bbox_dict["other"] = other
+
+     # Biggest other is title if no title
+     if not len(bbox_dict.get("chart_title", [])) and len(bbox_dict.get("other", [])):
+         boxes = bbox_dict["other"]
+         ws = boxes[:, 2] - boxes[:, 0]
+         if np.max(ws) > shape[1] * 0.3:
+             bbox_dict["chart_title"] = boxes[np.argmax(ws)][None].copy()
+             bbox_dict["other"] = np.delete(boxes, (np.argmax(ws)), axis=0)
+
+     # Make sure other key not lost
+     bbox_dict["other"] = bbox_dict.get("other", [])
+
+     return bbox_dict
+
+
+ def get_bbox_dict_yolox_table(preds, shape, class_labels, threshold=0.1, delta=0.0):
+     """
+     Extracts bounding boxes from YOLOX model predictions:
+     - Applies thresholding
+     - Reformats boxes
+     - Reorders predictions
+
+     Args:
+         preds (np.ndarray): YOLOX model predictions including bounding boxes, scores, and labels.
+         shape (tuple): Original image shape.
+         class_labels (list): Class labels to extract.
+         threshold (float): Score threshold to filter bounding boxes.
+         delta (float): How much the table was cropped upwards.
+
+     Returns:
+         dict[str, np.ndarray]: Dictionary of bounding boxes, organized by class.
+     """
+     bbox_dict = {label: np.array([]) for label in class_labels}
+
+     for i, label in enumerate(class_labels):
+         if label not in ["cell", "row", "column"]:
+             continue  # Ignore useless classes
+
+         bboxes_class = np.array(preds[label])
+
+         if bboxes_class.size == 0:
+             continue
+
+         # Threshold and clip
+         bboxes_class = bboxes_class[bboxes_class[:, -1] >= threshold][:, :4].astype(int)
+         bboxes_class[:, [0, 2]] = np.clip(bboxes_class[:, [0, 2]], 0, shape[1])
+         bboxes_class[:, [1, 3]] = np.clip(bboxes_class[:, [1, 3]], 0, shape[0])
+
+         # Reorder
+         sort = ["x0", "y0"] if label != "row" else ["y0", "x0"]
+         df = pd.DataFrame(
+             {
+                 "y0": (bboxes_class[:, 1] + bboxes_class[:, 3]) / 2,
+                 "x0": (bboxes_class[:, 0] + bboxes_class[:, 2]) / 2,
+             }
+         )
+         idxs = df.sort_values(sort).index
+         bboxes_class = bboxes_class[idxs]
+
+         bbox_dict[label] = bboxes_class
+
+     # Enforce spanning the entire table
+     if len(bbox_dict["row"]):
+         bbox_dict["row"][:, 0] = 0
+         bbox_dict["row"][:, 2] = shape[1]
+     if len(bbox_dict["column"]):
+         bbox_dict["column"][:, 1] = 0
+         bbox_dict["column"][:, 3] = shape[0]
+
+     # Shift back if cropped
+     for k in bbox_dict:
+         if len(bbox_dict[k]):
+             bbox_dict[k][:, [1, 3]] = np.add(bbox_dict[k][:, [1, 3]], delta, casting="unsafe")
+
+     return bbox_dict
+
+
+ def match_with_title_v3(bbox, title_bboxes, match_dist=0.1, delta=1.5, already_matched=[]):
+     """
+     Matches a bounding box with a title bounding box based on IoU or proximity.
+
+     Args:
+         bbox (numpy.ndarray): Bounding box to match with title [x_min, y_min, x_max, y_max].
+         title_bboxes (numpy.ndarray): Array of title bounding boxes with shape (N, 4).
+         match_dist (float, optional): Maximum distance for matching. Defaults to 0.1.
+         delta (float, optional): Multiplier for matching several titles. Defaults to 1.5.
+         already_matched (list, optional): List of already matched title indices. Defaults to [].
+
+     Returns:
+         tuple or None: If matched, returns a tuple of (merged_bbox, matched_title_indices).
+             If no match is found, returns None, None.
+     """
+     if not len(title_bboxes):
+         return None, None
+
+     dist_above = np.abs(title_bboxes[:, 3] - bbox[1])
+     dist_below = np.abs(bbox[3] - title_bboxes[:, 1])
+
+     dist_left = np.abs(title_bboxes[:, 0] - bbox[0])
+     dist_center = np.abs(title_bboxes[:, 0] + title_bboxes[:, 2] - bbox[0] - bbox[2]) / 2
+
+     dists = np.min([dist_above, dist_below], 0)
+     dists += np.min([dist_left, dist_center], 0) / 2
+
+     ious = bb_iou_array(title_bboxes, bbox)
+     dists = np.where(ious > 0, min(match_dist, np.min(dists)), dists)
+
+     if len(already_matched):
+         dists[already_matched] = match_dist * 10  # Remove already matched titles
+
+     # print(dists)
+     matches = None  # noqa
+     if np.min(dists) <= match_dist:
+         matches = np.where(dists <= min(match_dist, np.min(dists) * delta))[0]
+
+     if matches is not None:
+         new_bbox = bbox
+         for match in matches:
+             new_bbox = merge_boxes(new_bbox, title_bboxes[match])
+         return new_bbox, list(matches)
+     else:
+         return None, None
1345
+
1346
+
1347
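To see the matching criterion at work, here is a small sketch with invented coordinates (merge_boxes is assumed to return the union of the two boxes, as elsewhere in this module):

    import numpy as np

    chart = np.array([0.2, 0.3, 0.8, 0.9])       # normalized [x0, y0, x1, y1]
    titles = np.array([[0.2, 0.25, 0.6, 0.29]])  # sits just above the chart

    merged, matched = match_with_title_v3(chart, titles)
    # dist_above = 0.01 and dist_left = 0, well under match_dist=0.1, so they merge
    print(merged, matched)  # ~[0.2, 0.25, 0.8, 0.9] [0]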
+ def match_boxes_with_title(boxes, confs, labels, classes, to_match_labels=["chart"], remove_matched_titles=False):
+     """
+     Matches boxes (charts, tables, ...) with titles.
+
+     Args:
+         boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
+         confs (numpy.ndarray): Array of confidence scores with shape (N,).
+         labels (numpy.ndarray): Array of labels with shape (N,).
+         classes (list): List of class names.
+         to_match_labels (list): Class names to match with titles.
+         remove_matched_titles (bool): Whether to remove matched titles from the predictions.
+
+     Returns:
+         boxes (numpy.ndarray): Array of bounding boxes with shape (M, 4).
+         confs (numpy.ndarray): Array of confidence scores with shape (M,).
+         labels (numpy.ndarray): Array of labels with shape (M,).
+         found_title (list): Indices of the boxes that were matched with a title.
+     """
+     # Put titles at the end
+     title_ids = np.where(labels == classes.index("title"))[0]
+     order = np.concatenate([np.delete(np.arange(len(boxes)), title_ids), title_ids])
+     boxes = boxes[order]
+     confs = confs[order]
+     labels = labels[order]
+
+     # Ids
+     title_ids = np.where(labels == classes.index("title"))[0]
+     to_match = np.where(np.isin(labels, [classes.index(c) for c in to_match_labels]))[0]
+
+     # Matching
+     found_title, already_matched = [], []
+     for i in range(len(boxes)):
+         if i not in to_match:
+             continue
+         merged_box, matched_title_ids = match_with_title_v3(
+             boxes[i],
+             boxes[title_ids],
+             already_matched=already_matched,
+         )
+         if matched_title_ids is not None:
+             boxes[i] = merged_box
+             already_matched += matched_title_ids
+             found_title.append(i)
+
+     if remove_matched_titles and len(already_matched):
+         boxes = np.delete(boxes, title_ids[already_matched], axis=0)
+         confs = np.delete(confs, title_ids[already_matched], axis=0)
+         labels = np.delete(labels, title_ids[already_matched], axis=0)
+
+     return boxes, confs, labels, found_title
+
+
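A usage sketch for match_boxes_with_title, with a hypothetical class list and one chart/title pair:

    import numpy as np

    classes = ["table", "chart", "title", "infographic"]  # hypothetical ordering
    boxes = np.array([[0.1, 0.2, 0.5, 0.6], [0.1, 0.15, 0.3, 0.19]])
    confs = np.array([0.9, 0.8])
    labels = np.array([classes.index("chart"), classes.index("title")])

    boxes, confs, labels, found_title = match_boxes_with_title(
        boxes, confs, labels, classes, to_match_labels=["chart"], remove_matched_titles=True
    )
    print(len(boxes), found_title)  # 1 [0] -- the chart at index 0 absorbed the title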
+ def expand_boxes_v3(boxes, r_x=(1, 1), r_y=(1, 1), size_agnostic=True):
+     """
+     Expands bounding boxes by a specified ratio.
+     Expected box format is normalized [x_min, y_min, x_max, y_max].
+
+     Args:
+         boxes (numpy.ndarray): Array of bounding boxes with shape (N, 4).
+         r_x (tuple, optional): Left, right expansion ratios. Defaults to (1, 1) (no expansion).
+         r_y (tuple, optional): Up, down expansion ratios. Defaults to (1, 1) (no expansion).
+         size_agnostic (bool, optional): Expand independently of the bbox size. Defaults to True.
+
+     Returns:
+         numpy.ndarray: Adjusted bounding boxes clipped to the [0, 1] range.
+     """
+     old_boxes = boxes.copy()
+
+     if not size_agnostic:
+         h = boxes[:, 3] - boxes[:, 1]
+         w = boxes[:, 2] - boxes[:, 0]
+     else:
+         h, w = 1, 1
+
+     boxes[:, 0] -= w * (r_x[0] - 1)  # left
+     boxes[:, 2] += w * (r_x[1] - 1)  # right
+     boxes[:, 1] -= h * (r_y[0] - 1)  # up
+     boxes[:, 3] += h * (r_y[1] - 1)  # down
+
+     boxes = np.clip(boxes, 0, 1)
+
+     # Enforce non-overlapping boxes: if the expansion created an overlap that
+     # did not exist before, shrink the boxes back at their shared boundary
+     for i in range(len(boxes)):
+         for j in range(i + 1, len(boxes)):
+             iou = bb_iou_array(boxes[i][None], boxes[j])[0]
+             old_iou = bb_iou_array(old_boxes[i][None], old_boxes[j])[0]
+             if iou > 0.05 and old_iou < 0.1:
+                 if boxes[i, 1] < boxes[j, 1]:  # i above j
+                     boxes[j, 1] = min(old_boxes[j, 1], boxes[i, 3])
+                     if old_iou > 0:
+                         boxes[i, 3] = max(old_boxes[i, 3], boxes[j, 1])
+                 else:
+                     boxes[i, 1] = min(old_boxes[i, 1], boxes[j, 3])
+                     if old_iou > 0:
+                         boxes[j, 3] = max(old_boxes[j, 3], boxes[i, 1])
+
+     return boxes
+
+
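The two expansion modes differ in how the padding is computed. A quick sketch (expand_boxes_v3 modifies its input in place, hence the copies):

    import numpy as np

    boxes = np.array([[0.2, 0.2, 0.6, 0.5]])

    # Size-agnostic: absolute padding of 0.05 (x) and 0.1 (y) in normalized units
    print(expand_boxes_v3(boxes.copy(), r_x=(1.05, 1.05), r_y=(1.1, 1.1)))
    # [[0.15 0.1  0.65 0.6 ]]

    # Size-aware: padding scales with the box width (0.4) and height (0.3)
    print(expand_boxes_v3(boxes.copy(), r_x=(1.05, 1.05), r_y=(1.1, 1.1), size_agnostic=False))
    # [[0.18 0.17 0.62 0.53]]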
+ def get_overlaps(boxes, other_boxes, normalize="box_only"):
+     """
+     Checks if a box overlaps with any other box.
+     Boxes are expected in format (x0, y0, x1, y1).
+
+     Args:
+         boxes (np array [4] or [n x 4]): Boxes.
+         other_boxes (np array [m x 4]): Other boxes.
+         normalize (str): "box_only" to normalize by the first boxes' areas,
+             "all" to normalize by the smaller of the two areas.
+
+     Returns:
+         np array [n x m]: Overlaps.
+     """
+     if boxes.ndim == 1:
+         boxes = boxes[None, :]
+
+     x0, y0, x1, y1 = (boxes[:, 0][:, None], boxes[:, 1][:, None], boxes[:, 2][:, None], boxes[:, 3][:, None])
+     areas = (y1 - y0) * (x1 - x0)  # [n x 1]
+
+     x0_other, y0_other, x1_other, y1_other = (
+         other_boxes[:, 0][None, :],
+         other_boxes[:, 1][None, :],
+         other_boxes[:, 2][None, :],
+         other_boxes[:, 3][None, :],
+     )
+     areas_other = (y1_other - y0_other) * (x1_other - x0_other)  # [1 x m]
+
+     # Intersection
+     inter_y0 = np.maximum(y0, y0_other)
+     inter_y1 = np.minimum(y1, y1_other)
+     inter_x0 = np.maximum(x0, x0_other)
+     inter_x1 = np.minimum(x1, x1_other)
+     inter_area = np.maximum(0, inter_y1 - inter_y0) * np.maximum(0, inter_x1 - inter_x0)
+
+     # Overlap
+     if normalize == "box_only":  # Only consider box included in other box
+         overlaps = inter_area / areas
+     elif normalize == "all":  # Consider inclusion in either direction
+         # areas is [n x 1] and areas_other is [1 x m], so they broadcast to [n x m]
+         overlaps = inter_area / np.minimum(areas, areas_other)
+     else:
+         raise ValueError(f"Invalid normalization: {normalize}")
+     return overlaps
+
+
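A quick check of the two normalization modes, using a small box fully contained in a large one:

    import numpy as np

    big = np.array([0.0, 0.0, 1.0, 1.0])
    small = np.array([[0.2, 0.2, 0.4, 0.4]])  # area 0.04, fully inside

    print(get_overlaps(big, small, normalize="box_only"))  # [[0.04]] -- big is barely covered
    print(get_overlaps(big, small, normalize="all"))       # [[1.]]   -- small is fully included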
+ def postprocess_included(boxes, labels, confs, class_="title", classes=["table", "chart", "title", "infographic"]):
+     """
+     Post-processes predictions of a given class:
+     - Removes boxes that are included in boxes of other classes
+
+     Args:
+         boxes (numpy.ndarray [N, 4]): Array of bounding boxes.
+         labels (numpy.ndarray [N]): Array of labels.
+         confs (numpy.ndarray [N]): Array of confidences.
+         class_ (str, optional): Class to postprocess. Defaults to "title".
+         classes (list, optional): Classes. Defaults to ["table", "chart", "title", "infographic"].
+
+     Returns:
+         boxes (numpy.ndarray): Array of bounding boxes.
+         labels (numpy.ndarray): Array of labels.
+         confs (numpy.ndarray): Array of confidences.
+     """
+     boxes_to_pp = boxes[labels == classes.index(class_)]
+     confs_to_pp = confs[labels == classes.index(class_)]
+
+     order = np.argsort(confs_to_pp)  # least to most confident for NMS
+     boxes_to_pp, confs_to_pp = boxes_to_pp[order], confs_to_pp[order]
+
+     if len(boxes_to_pp) == 0:
+         return boxes, labels, confs
+
+     inclusion_classes = ["table", "infographic", "chart"]
+     if class_ in ["header_footer", "title"]:
+         inclusion_classes.append("paragraph")
+
+     # Guard against class lists that do not contain all inclusion classes
+     other_boxes = boxes[np.isin(labels, [classes.index(c) for c in inclusion_classes if c in classes])]
+
+     # Remove boxes included in other_boxes
+     kept_boxes, kept_confs = [], []
+     for i, b in enumerate(boxes_to_pp):
+         if len(other_boxes) > 0:
+             overlaps = get_overlaps(b, other_boxes, normalize="box_only")
+             if overlaps.max() > 0.9:
+                 continue
+
+         kept_boxes.append(b)
+         kept_confs.append(confs_to_pp[i])
+
+     # Aggregate
+     kept_boxes = np.stack(kept_boxes) if len(kept_boxes) else np.empty((0, 4))
+     kept_confs = np.stack(kept_confs) if len(kept_confs) else np.empty(0)
+
+     boxes_pp = np.concatenate([boxes[labels != classes.index(class_)], kept_boxes])
+     confs_pp = np.concatenate([confs[labels != classes.index(class_)], kept_confs])
+     labels_pp = np.concatenate(
+         [labels[labels != classes.index(class_)], np.ones(len(kept_boxes)) * classes.index(class_)]
+     )
+
+     return boxes_pp, labels_pp, confs_pp
+
+
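A sketch of postprocess_included dropping a title that is swallowed by a table (boxes are invented):

    import numpy as np

    classes = ["table", "chart", "title", "infographic"]
    boxes = np.array([[0.1, 0.1, 0.9, 0.9],    # table
                      [0.2, 0.2, 0.4, 0.25]])  # title fully inside the table
    labels = np.array([classes.index("table"), classes.index("title")])
    confs = np.array([0.9, 0.8])

    boxes, labels, confs = postprocess_included(boxes, labels, confs, class_="title", classes=classes)
    print(len(boxes))  # 1 -- the included title is removed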
+ def remove_overlapping_boxes_using_wbf(boxes, confs, labels):
+     """
+     Removes overlapping boxes using Weighted Boxes Fusion (WBF).
+     """
+     # Applied twice because one pass is not enough in some rare cases
+     for _ in range(2):
+         boxes, confs, labels = weighted_boxes_fusion(
+             boxes[:, None],
+             confs[:, None],
+             labels[:, None],
+             merge_type="biggest",
+             conf_type="max",
+             iou_thr=0.01,
+             class_agnostic=False,
+         )
+
+     return boxes, confs, labels
+
+
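Assuming weighted_boxes_fusion (defined earlier in this module) fuses same-class boxes above the IoU threshold into the biggest box while keeping the max confidence, a sketch of the duplicate-removal wrapper:

    import numpy as np

    # Two near-duplicate boxes of class 1, plus one distinct box of class 0
    boxes = np.array([[0.1, 0.1, 0.5, 0.5], [0.11, 0.1, 0.5, 0.51], [0.6, 0.6, 0.9, 0.9]])
    confs = np.array([0.9, 0.7, 0.8])
    labels = np.array([1, 1, 0])

    boxes, confs, labels = remove_overlapping_boxes_using_wbf(boxes, confs, labels)
    print(len(boxes))  # 2 -- the duplicate pair is fused into a single box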
+ def match_structured_boxes_with_title(boxes, confs, labels, classes):
+     """
+     Matches tables, charts and infographics with their titles.
+     """
+     # Reorder by y, x
+     order = np.argsort(boxes[:, 1] * 10 + boxes[:, 0])
+     boxes, confs, labels = boxes[order], confs[order], labels[order]
+
+     # Match with title
+     # Although the model should detect titles, additional post-processing helps retrieve FNs
+     boxes, confs, labels, found_title = match_boxes_with_title(
+         boxes,
+         confs,
+         labels,
+         classes,
+         to_match_labels=["chart", "table", "infographic"],
+         remove_matched_titles=True,
+     )
+
+     return boxes, confs, labels, found_title
+
+
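The y * 10 + x key used here gives an approximate reading order: with normalized coordinates, x stays in [0, 1], so the 10x weight on y sorts boxes top-to-bottom first and only uses x to order boxes at similar heights:

    import numpy as np

    boxes = np.array([[0.7, 0.1, 0.9, 0.2],   # top right
                      [0.1, 0.1, 0.3, 0.2],   # top left
                      [0.1, 0.5, 0.9, 0.8]])  # bottom
    print(np.argsort(boxes[:, 1] * 10 + boxes[:, 0]))  # [1 0 2]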
+ def expand_tables_and_charts(boxes, confs, labels, classes, found_title):
+     """
+     Expands table and chart boxes.
+     This is mostly to retrieve titles, but it also helps when YOLOX boxes are too tight.
+     Boxes that were already matched with a title are expanded less,
+     and expansion ratios differ for tables and charts.
+     """
+     no_found_title = [i for i in range(len(boxes)) if i not in found_title]
+     ids = np.arange(len(boxes))
+
+     if len(found_title):  # Boxes with a title matched are expanded less
+         ids_ = ids[found_title][labels[found_title] == classes.index("chart")]
+         boxes[ids_] = expand_boxes_v3(
+             boxes[ids_],
+             r_x=(1.025, 1.025),
+             r_y=(1.05, 1.05),
+             size_agnostic=False,
+         )
+         ids_ = ids[found_title][labels[found_title] == classes.index("table")]
+         boxes[ids_] = expand_boxes_v3(
+             boxes[ids_],
+             r_x=(1.01, 1.01),
+             r_y=(1.05, 1.01),
+         )
+
+     ids_ = ids[no_found_title][labels[no_found_title] == classes.index("chart")]
+     boxes[ids_] = expand_boxes_v3(
+         boxes[ids_],
+         r_x=(1.05, 1.05),
+         r_y=(1.125, 1.125),
+         size_agnostic=False,
+     )
+
+     ids_ = ids[no_found_title][labels[no_found_title] == classes.index("table")]
+     boxes[ids_] = expand_boxes_v3(
+         boxes[ids_],
+         r_x=(1.02, 1.02),
+         r_y=(1.05, 1.05),
+     )
+
+     # Reorder by y, x
+     order = np.argsort(boxes[:, 1] * 10 + boxes[:, 0])
+     boxes, labels, confs = boxes[order], labels[order], confs[order]
+
+     return boxes, labels, confs
+
+
+ def postprocess_included_texts(boxes, confs, labels, classes):
+     for c in ["title", "paragraph", "header_footer"]:
+         boxes, labels, confs = postprocess_included(boxes, labels, confs, c, classes)
+     return boxes, labels, confs
+
+
+ @multiprocessing_cache(max_calls=100)  # Cache results to avoid redundant retries from backoff
+ @backoff.on_predicate(backoff.expo, max_time=30)
+ def get_yolox_model_name(yolox_grpc_endpoint, default_model_name="yolox"):
+     """
+     Determines the YOLOX model name by querying the Triton model repository.
+     Falls back to a default name on failure.
+     """
+     try:
+         client = grpcclient.InferenceServerClient(yolox_grpc_endpoint)
+         model_index = client.get_model_repository_index(as_json=True)
+         model_names = [x["name"] for x in model_index.get("models", [])]
+         if "pipeline" in model_names:
+             yolox_model_name = "pipeline"
+         elif "yolox_ensemble" in model_names:
+             yolox_model_name = "yolox_ensemble"
+         else:
+             yolox_model_name = default_model_name
+     except Exception:
+         logger.warning(
+             f"Failed to get yolox-page-elements version after 30 seconds. Falling back to '{default_model_name}'."
+         )
+         yolox_model_name = default_model_name
+
+     return yolox_model_name
+
+
+ @multiprocessing_cache(max_calls=100)  # Cache results to avoid redundant retries from backoff
+ @backoff.on_predicate(backoff.expo, max_time=30)
+ def get_yolox_page_version(yolox_http_endpoint, default_version=YOLOX_PAGE_DEFAULT_VERSION):
+     """
+     Determines the YOLOX page elements model version by querying the endpoint.
+     Falls back to a default version on failure.
+     """
+     try:
+         yolox_version = get_model_name(yolox_http_endpoint, default_version)
+         if not yolox_version:
+             logger.warning(
+                 "Failed to obtain yolox-page-elements version from the endpoint. "
+                 f"Falling back to '{default_version}'."
+             )
+             return default_version
+
+         return yolox_version
+     except Exception:
+         logger.warning(
+             f"Failed to get yolox-page-elements version after 30 seconds. Falling back to '{default_version}'."
+         )
+         return default_version
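Both discovery helpers are typically called once per worker: the cache decorator memoizes the result, so the 30-second backoff loop only ever runs on the first call. A usage sketch with hypothetical endpoints:

    # gRPC model-repository lookup (hypothetical Triton endpoint)
    model_name = get_yolox_model_name("localhost:8001")
    # -> "pipeline", "yolox_ensemble", or the "yolox" fallback

    # HTTP version lookup (hypothetical endpoint)
    version = get_yolox_page_version("http://localhost:8000/v1")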