nv-ingest-api 2025.7.15.dev20250715__py3-none-any.whl → 2025.7.17.dev20250717__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nv-ingest-api has been flagged as possibly problematic by the registry.
- nv_ingest_api/interface/extract.py +18 -18
- nv_ingest_api/internal/enums/common.py +6 -0
- nv_ingest_api/internal/extract/image/chart_extractor.py +75 -55
- nv_ingest_api/internal/extract/image/infographic_extractor.py +59 -35
- nv_ingest_api/internal/extract/image/table_extractor.py +81 -63
- nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +7 -7
- nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +32 -20
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +32 -9
- nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +58 -0
- nv_ingest_api/internal/primitives/nim/model_interface/{paddle.py → ocr.py} +132 -39
- nv_ingest_api/internal/primitives/nim/nim_client.py +46 -11
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +6 -6
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +6 -6
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +5 -5
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +5 -0
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +4 -0
- nv_ingest_api/internal/transform/embed_text.py +103 -12
- nv_ingest_api/internal/transform/split_text.py +13 -8
- nv_ingest_api/util/image_processing/table_and_chart.py +97 -42
- nv_ingest_api/util/image_processing/transforms.py +19 -5
- nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +1 -1
- nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +51 -48
- nv_ingest_api/util/metadata/aggregators.py +4 -1
- {nv_ingest_api-2025.7.15.dev20250715.dist-info → nv_ingest_api-2025.7.17.dev20250717.dist-info}/METADATA +1 -1
- {nv_ingest_api-2025.7.15.dev20250715.dist-info → nv_ingest_api-2025.7.17.dev20250717.dist-info}/RECORD +28 -28
- {nv_ingest_api-2025.7.15.dev20250715.dist-info → nv_ingest_api-2025.7.17.dev20250717.dist-info}/WHEEL +0 -0
- {nv_ingest_api-2025.7.15.dev20250715.dist-info → nv_ingest_api-2025.7.17.dev20250717.dist-info}/licenses/LICENSE +0 -0
- {nv_ingest_api-2025.7.15.dev20250715.dist-info → nv_ingest_api-2025.7.17.dev20250717.dist-info}/top_level.txt +0 -0
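The hunks shown below are dominated by one refactor: the PaddleOCR-specific plumbing (paddle_endpoints, paddle_protocol, paddle_client, and the paddle.py model interface) is renamed to backend-neutral OCR equivalents (ocr_endpoints, ocr_protocol, ocr_client, ocr.py), with new runtime dispatch between the legacy "paddle" model and a "scene_text" model, plus new PAGE_IMAGE / PDF_PAGE_IMAGE enum members for full-page images rendered from documents.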
nv_ingest_api/interface/extract.py:

@@ -781,9 +781,9 @@ def extract_chart_data_from_image(
     *,
     df_ledger: pd.DataFrame,
     yolox_endpoints: Tuple[str, str],
-    paddle_endpoints: Tuple[str, str],
+    ocr_endpoints: Tuple[str, str],
     yolox_protocol: str = "grpc",
-    paddle_protocol: str = "grpc",
+    ocr_protocol: str = "grpc",
     auth_token: str = "",
 ) -> DataFrame:
     """
@@ -795,11 +795,11 @@ def extract_chart_data_from_image(
         DataFrame containing metadata required for chart extraction.
     yolox_endpoints : Tuple[str, str]
         YOLOX inference server endpoints.
-    paddle_endpoints : Tuple[str, str]
+    ocr_endpoints : Tuple[str, str]
         PaddleOCR inference server endpoints.
     yolox_protocol : str, optional
        Protocol for YOLOX inference (default "grpc").
-    paddle_protocol : str, optional
+    ocr_protocol : str, optional
         Protocol for PaddleOCR inference (default "grpc").
     auth_token : str, optional
         Authentication token for inference services.
@@ -821,9 +821,9 @@ def extract_chart_data_from_image(
         **{
             "endpoint_config": {
                 "yolox_endpoints": yolox_endpoints,
-                "paddle_endpoints": paddle_endpoints,
+                "ocr_endpoints": ocr_endpoints,
                 "yolox_infer_protocol": yolox_protocol,
-                "paddle_infer_protocol": paddle_protocol,
+                "ocr_infer_protocol": ocr_protocol,
                 "auth_token": auth_token,
             }
         }
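The three hunks above amount to a keyword-argument rename at the public interface. Below is a minimal migration sketch, assuming the import path shown in the file list; the endpoint URLs and the ledger contents are placeholders, not values from this package:

```python
import pandas as pd

from nv_ingest_api.interface.extract import extract_chart_data_from_image

# Placeholder ledger; in practice this comes from an upstream PDF extraction stage.
df_ledger = pd.DataFrame({"metadata": [{}]})

# 2025.7.15 and earlier: paddle_endpoints=..., paddle_protocol=...
# 2025.7.17 and later:   ocr_endpoints=...,    ocr_protocol=...
df_out = extract_chart_data_from_image(
    df_ledger=df_ledger,
    yolox_endpoints=("localhost:8001", "http://localhost:8000/v1/infer"),  # placeholder URLs
    ocr_endpoints=("localhost:8011", "http://localhost:8010/v1/infer"),    # placeholder URLs
    yolox_protocol="grpc",
    ocr_protocol="grpc",
    auth_token="",
)
```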
@@ -844,9 +844,9 @@ def extract_table_data_from_image(
     *,
     df_ledger: pd.DataFrame,
     yolox_endpoints: Optional[Tuple[str, str]] = None,
-    paddle_endpoints: Optional[Tuple[str, str]] = None,
+    ocr_endpoints: Optional[Tuple[str, str]] = None,
     yolox_protocol: Optional[str] = None,
-    paddle_protocol: Optional[str] = None,
+    ocr_protocol: Optional[str] = None,
     auth_token: Optional[str] = None,
 ) -> pd.DataFrame:
     """
@@ -858,11 +858,11 @@ def extract_table_data_from_image(
         DataFrame containing metadata required for chart extraction.
     yolox_endpoints : Optional[Tuple[str, str]], default=None
         YOLOX inference server endpoints. If None, the default defined in ChartExtractorConfigSchema is used.
-    paddle_endpoints : Optional[Tuple[str, str]], default=None
+    ocr_endpoints : Optional[Tuple[str, str]], default=None
         PaddleOCR inference server endpoints. If None, the default defined in ChartExtractorConfigSchema is used.
     yolox_protocol : Optional[str], default=None
         Protocol for YOLOX inference. If None, the default defined in ChartExtractorConfigSchema is used.
-    paddle_protocol : Optional[str], default=None
+    ocr_protocol : Optional[str], default=None
         Protocol for PaddleOCR inference. If None, the default defined in ChartExtractorConfigSchema is used.
     auth_token : Optional[str], default=None
         Authentication token for inference services. If None, the default defined in ChartExtractorConfigSchema is used.
@@ -882,9 +882,9 @@ def extract_table_data_from_image(
     config_kwargs = {
         "endpoint_config": {
             "yolox_endpoints": yolox_endpoints,
-            "paddle_endpoints": paddle_endpoints,
+            "ocr_endpoints": ocr_endpoints,
             "yolox_infer_protocol": yolox_protocol,
-            "paddle_infer_protocol": paddle_protocol,
+            "ocr_infer_protocol": ocr_protocol,
             "auth_token": auth_token,
         }
     }
@@ -907,8 +907,8 @@ def extract_table_data_from_image(
 def extract_infographic_data_from_image(
     *,
     df_ledger: pd.DataFrame,
-    paddle_endpoints: Optional[Tuple[str, str]] = None,
-    paddle_protocol: Optional[str] = None,
+    ocr_endpoints: Optional[Tuple[str, str]] = None,
+    ocr_protocol: Optional[str] = None,
     auth_token: Optional[str] = None,
 ) -> pd.DataFrame:
     """
@@ -924,10 +924,10 @@ def extract_infographic_data_from_image(
     ----------
     df_extraction_ledger : pd.DataFrame
         DataFrame containing the images and associated metadata from which infographic data is to be extracted.
-    paddle_endpoints : Optional[Tuple[str, str]], default=None
+    ocr_endpoints : Optional[Tuple[str, str]], default=None
         A tuple of PaddleOCR endpoint addresses (e.g., (gRPC_endpoint, HTTP_endpoint)) used for inference.
         If None, the default endpoints from InfographicExtractorConfigSchema are used.
-    paddle_protocol : Optional[str], default=None
+    ocr_protocol : Optional[str], default=None
         The protocol (e.g., "grpc" or "http") for PaddleOCR inference.
         If None, the default protocol from InfographicExtractorConfigSchema is used.
     auth_token : Optional[str], default=None
@@ -951,8 +951,8 @@ def extract_infographic_data_from_image(
     extractor_config_kwargs = {
         "endpoint_config": InfographicExtractorConfigSchema(
             **{
-                "paddle_endpoints": paddle_endpoints,
-                "paddle_infer_protocol": paddle_protocol,
+                "ocr_endpoints": ocr_endpoints,
+                "ocr_infer_protocol": ocr_protocol,
                 "auth_token": auth_token,
             }
         )
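The rename flows into the config dictionaries as well: the schema keys are now ocr_endpoints and ocr_infer_protocol. A sketch of building the infographic endpoint config directly, assuming InfographicExtractorConfigSchema is exported from the same schema module as InfographicExtractorSchema (imported later in this diff); endpoint values are placeholders:

```python
from nv_ingest_api.internal.schemas.extract.extract_infographic_schema import (
    InfographicExtractorConfigSchema,  # assumed to live alongside InfographicExtractorSchema
)

# Mirrors the **{...} kwargs built in the hunk above; endpoints are placeholders.
endpoint_config = InfographicExtractorConfigSchema(
    ocr_endpoints=("localhost:8011", "http://localhost:8010/v1/infer"),
    ocr_infer_protocol="grpc",
    auth_token="",
)
```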
nv_ingest_api/internal/enums/common.py:

@@ -52,6 +52,8 @@ class ContentDescriptionEnum(str, Enum):
         Description for image extracted from PDF document.
     PDF_INFOGRAPHIC : str
         Description for structured infographic extracted from PDF document.
+    PDF_PAGE_IMAGE : str
+        Description for a full-page image rendered from a PDF document.
     PDF_TABLE : str
         Description for structured table extracted from PDF document.
     PDF_TEXT : str
@@ -70,6 +72,7 @@ class ContentDescriptionEnum(str, Enum):
     PDF_CHART: str = "Structured chart extracted from PDF document."
     PDF_IMAGE: str = "Image extracted from PDF document."
     PDF_INFOGRAPHIC: str = "Structured infographic extracted from PDF document."
+    PDF_PAGE_IMAGE: str = "Full-page image rendered from a PDF document."
     PDF_TABLE: str = "Structured table extracted from PDF document."
     PDF_TEXT: str = "Unstructured text from PDF document."
     PPTX_IMAGE: str = "Image extracted from PPTX presentation."
@@ -94,6 +97,8 @@ class ContentTypeEnum(str, Enum):
         Represents image content.
     INFO_MSG : str
         Represents an informational message.
+    PAGE_IMAGE : str
+        Represents a full-page image rendered from a document.
     STRUCTURED : str
         Represents structured content.
     TEXT : str
@@ -111,6 +116,7 @@ class ContentTypeEnum(str, Enum):
     INFOGRAPHIC: str = "infographic"
     INFO_MSG: str = "info_message"
     NONE: str = "none"
+    PAGE_IMAGE: str = "page_image"
     STRUCTURED: str = "structured"
     TABLE: str = "table"
     TEXT: str = "text"
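Since both enums subclass str, the new members compare equal to their raw values, which is how downstream filters typically use them. A small illustrative check:

```python
from nv_ingest_api.internal.enums.common import ContentDescriptionEnum, ContentTypeEnum

# str-subclassed enums compare equal to their underlying string values.
assert ContentTypeEnum.PAGE_IMAGE == "page_image"
assert ContentDescriptionEnum.PDF_PAGE_IMAGE == "Full-page image rendered from a PDF document."
```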
nv_ingest_api/internal/extract/image/chart_extractor.py:

@@ -16,9 +16,10 @@ import pandas as pd
 from nv_ingest_api.internal.primitives.nim.model_interface.helpers import get_version
 from nv_ingest_api.internal.schemas.extract.extract_chart_schema import ChartExtractorSchema
 from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskChartExtraction
-from nv_ingest_api.util.image_processing.table_and_chart import join_yolox_graphic_elements_and_paddle_output
+from nv_ingest_api.util.image_processing.table_and_chart import join_yolox_graphic_elements_and_ocr_output
 from nv_ingest_api.util.image_processing.table_and_chart import process_yolox_graphic_elements
-from nv_ingest_api.internal.primitives.nim.model_interface.paddle import PaddleOCRModelInterface
+from nv_ingest_api.internal.primitives.nim.model_interface.ocr import OCRModelInterface
+from nv_ingest_api.internal.primitives.nim.model_interface.ocr import get_ocr_model_name
 from nv_ingest_api.internal.primitives.nim import NimClient
 from nv_ingest_api.internal.primitives.nim.model_interface.yolox import YoloxGraphicElementsModelInterface
 from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
@@ -62,7 +63,8 @@ def _filter_valid_chart_images(

 def _run_chart_inference(
     yolox_client: Any,
-    paddle_client: Any,
+    ocr_client: Any,
+    ocr_model_name: str,
     valid_arrays: List[np.ndarray],
     valid_images: List[str],
     trace_info: Dict,
@@ -70,29 +72,40 @@ def _run_chart_inference(
     """
     Run concurrent inference for chart extraction using YOLOX and Paddle.

-    Returns a tuple of (yolox_results, paddle_results).
+    Returns a tuple of (yolox_results, ocr_results).
     """
     data_yolox = {"images": valid_arrays}
-    data_paddle = {"base64_images": valid_images}
+    data_ocr = {"base64_images": valid_images}

-    with ThreadPoolExecutor(max_workers=2) as executor:
-        future_yolox = executor.submit(
-            yolox_client.infer,
-            data=data_yolox,
-            model_name="yolox",
-            stage_name="chart_extraction",
-            max_batch_size=8,
-            trace_info=trace_info,
-        )
-        future_paddle = executor.submit(
-            paddle_client.infer,
-            data=data_paddle,
+    future_yolox_kwargs = dict(
+        data=data_yolox,
+        model_name="yolox",
+        stage_name="chart_extraction",
+        max_batch_size=8,
+        trace_info=trace_info,
+    )
+    future_ocr_kwargs = dict(
+        data=data_ocr,
+        stage_name="chart_extraction",
+        max_batch_size=1 if ocr_client.protocol == "grpc" else 2,
+        trace_info=trace_info,
+    )
+    if ocr_model_name == "paddle":
+        future_ocr_kwargs.update(
             model_name="paddle",
-            stage_name="chart_extraction",
-            max_batch_size=1 if paddle_client.protocol == "grpc" else 2,
-            trace_info=trace_info,
+        )
+    else:
+        future_ocr_kwargs.update(
+            model_name="scene_text",
+            input_names=["input", "merge_levels"],
+            dtypes=["FP32", "BYTES"],
+            merge_level="paragraph",
         )

+    with ThreadPoolExecutor(max_workers=2) as executor:
+        future_yolox = executor.submit(yolox_client.infer, **future_yolox_kwargs)
+        future_ocr = executor.submit(ocr_client.infer, **future_ocr_kwargs)
+
     try:
         yolox_results = future_yolox.result()
     except Exception as e:
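The refactor above separates what to send (per-model kwargs built up front, with a conditional update for "paddle" vs "scene_text") from how to send it (two concurrent infer calls). A self-contained sketch of that pattern; fake_infer is a stand-in for NimClient.infer, and the kwargs are illustrative:

```python
from concurrent.futures import ThreadPoolExecutor

def fake_infer(**kwargs):
    # Stand-in for NimClient.infer: just echo the kwargs it was given.
    return kwargs

yolox_kwargs = dict(model_name="yolox", stage_name="chart_extraction", max_batch_size=8)
ocr_kwargs = dict(stage_name="chart_extraction", max_batch_size=1)

use_paddle = True  # in the real code this comes from get_ocr_model_name(...)
if use_paddle:
    ocr_kwargs.update(model_name="paddle")
else:
    ocr_kwargs.update(model_name="scene_text", merge_level="paragraph")

with ThreadPoolExecutor(max_workers=2) as executor:
    future_yolox = executor.submit(fake_infer, **yolox_kwargs)
    future_ocr = executor.submit(fake_infer, **ocr_kwargs)

print(future_yolox.result()["model_name"], future_ocr.result()["model_name"])
```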
@@ -100,16 +113,16 @@ def _run_chart_inference(
         raise

     try:
-        paddle_results = future_paddle.result()
+        ocr_results = future_ocr.result()
     except Exception as e:
-        logger.error(f"Error calling paddle_client.infer: {e}", exc_info=True)
+        logger.error(f"Error calling ocr_client.infer: {e}", exc_info=True)
         raise

-    return yolox_results, paddle_results
+    return yolox_results, ocr_results


 def _validate_chart_inference_results(
-    yolox_results: Any, paddle_results: Any, valid_arrays: List[Any], valid_images: List[str]
+    yolox_results: Any, ocr_results: Any, valid_arrays: List[Any], valid_images: List[str]
 ) -> Tuple[List[Any], List[Any]]:
     """
     Ensure inference results are lists and have expected lengths.
@@ -117,21 +130,21 @@ def _validate_chart_inference_results(
     Raises:
         ValueError if results do not match expected types or lengths.
     """
-    if not (isinstance(yolox_results, list) and isinstance(paddle_results, list)):
-        raise ValueError("Expected list results from both yolox_client and paddle_client infer calls.")
+    if not (isinstance(yolox_results, list) and isinstance(ocr_results, list)):
+        raise ValueError("Expected list results from both yolox_client and ocr_client infer calls.")

     if len(yolox_results) != len(valid_arrays):
         raise ValueError(f"Expected {len(valid_arrays)} yolox results, got {len(yolox_results)}")
-    if len(paddle_results) != len(valid_images):
-        raise ValueError(f"Expected {len(valid_images)} paddle results, got {len(paddle_results)}")
-    return yolox_results, paddle_results
+    if len(ocr_results) != len(valid_images):
+        raise ValueError(f"Expected {len(valid_images)} ocr results, got {len(ocr_results)}")
+    return yolox_results, ocr_results


 def _merge_chart_results(
     base64_images: List[str],
     valid_indices: List[int],
     yolox_results: List[Any],
-    paddle_results: List[Any],
+    ocr_results: List[Any],
     initial_results: List[Tuple[str, Optional[Dict]]],
 ) -> List[Tuple[str, Optional[Dict]]]:
     """
@@ -140,10 +153,10 @@ def _merge_chart_results(
     For each valid image, processes the results from both inference calls and updates the
     corresponding entry in the results list.
     """
-    for idx, (yolox_res, paddle_res) in enumerate(zip(yolox_results, paddle_results)):
-        # Unpack paddle result into bounding boxes and text predictions.
-        bounding_boxes, text_predictions = paddle_res
-        yolox_elements = join_yolox_graphic_elements_and_paddle_output(yolox_res, bounding_boxes, text_predictions)
+    for idx, (yolox_res, ocr_res) in enumerate(zip(yolox_results, ocr_results)):
+        # Unpack ocr result into bounding boxes and text predictions.
+        bounding_boxes, text_predictions, _ = ocr_res
+        yolox_elements = join_yolox_graphic_elements_and_ocr_output(yolox_res, bounding_boxes, text_predictions)
         chart_content = process_yolox_graphic_elements(yolox_elements)
         original_index = valid_indices[idx]
         initial_results[original_index] = (base64_images[original_index], chart_content)
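Note the unpack change in _merge_chart_results: OCR results now carry a third element, which appears to be confidence scores judging by the conf_scores comment in the infographic extractor later in this diff. Callers written against the old 2-tuple need the extra slot; the values here are illustrative only:

```python
# Old paddle result shape: (bounding_boxes, text_predictions)
# New ocr result shape:    (bounding_boxes, text_predictions, conf_scores)
ocr_res = ([[0, 0, 10, 10]], ["Q1 revenue"], [0.97])  # illustrative values only
bounding_boxes, text_predictions, _ = ocr_res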
@@ -153,7 +166,8 @@ def _merge_chart_results(
 def _update_chart_metadata(
     base64_images: List[str],
     yolox_client: Any,
-    paddle_client: Any,
+    ocr_client: Any,
+    ocr_model_name: str,
     trace_info: Dict,
     worker_pool_size: int = 8,  # Not currently used.
 ) -> List[Tuple[str, Optional[Dict]]]:
@@ -172,28 +186,29 @@ def _update_chart_metadata(
     valid_images, valid_arrays, valid_indices, results = _filter_valid_chart_images(base64_images)

     # Run concurrent inference only for valid images.
-    yolox_results, paddle_results = _run_chart_inference(
+    yolox_results, ocr_results = _run_chart_inference(
         yolox_client=yolox_client,
-        paddle_client=paddle_client,
+        ocr_client=ocr_client,
+        ocr_model_name=ocr_model_name,
         valid_arrays=valid_arrays,
         valid_images=valid_images,
         trace_info=trace_info,
     )

     # Validate that the returned inference results are lists of the expected length.
-    yolox_results, paddle_results = _validate_chart_inference_results(
-        yolox_results, paddle_results, valid_arrays, valid_images
+    yolox_results, ocr_results = _validate_chart_inference_results(
+        yolox_results, ocr_results, valid_arrays, valid_images
     )

     # Merge the inference results into the results list.
-    return _merge_chart_results(base64_images, valid_indices, yolox_results, paddle_results, results)
+    return _merge_chart_results(base64_images, valid_indices, yolox_results, ocr_results, results)


 def _create_clients(
     yolox_endpoints: Tuple[str, str],
     yolox_protocol: str,
-    paddle_endpoints: Tuple[str, str],
-    paddle_protocol: str,
+    ocr_endpoints: Tuple[str, str],
+    ocr_protocol: str,
     auth_token: str,
 ) -> Tuple[NimClient, NimClient]:
     # Obtain yolox_version
@@ -214,9 +229,9 @@ def _create_clients(
         yolox_version = None  # Default to the latest version

     yolox_model_interface = YoloxGraphicElementsModelInterface(yolox_version=yolox_version)
-    paddle_model_interface = PaddleOCRModelInterface()
+    ocr_model_interface = OCRModelInterface()

-    logger.debug(f"Inference protocols: yolox={yolox_protocol}, paddle={paddle_protocol}")
+    logger.debug(f"Inference protocols: yolox={yolox_protocol}, ocr={ocr_protocol}")

     yolox_client = create_inference_client(
         endpoints=yolox_endpoints,
@@ -225,14 +240,14 @@ def _create_clients(
         infer_protocol=yolox_protocol,
     )

-    paddle_client = create_inference_client(
-        endpoints=paddle_endpoints,
-        model_interface=paddle_model_interface,
+    ocr_client = create_inference_client(
+        endpoints=ocr_endpoints,
+        model_interface=ocr_model_interface,
         auth_token=auth_token,
-        infer_protocol=paddle_protocol,
+        infer_protocol=ocr_protocol,
     )

-    return yolox_client, paddle_client
+    return yolox_client, ocr_client


 def extract_chart_data_from_image_internal(
@@ -275,14 +290,18 @@ def extract_chart_data_from_image_internal(
         return df_extraction_ledger, execution_trace_log

     endpoint_config = extraction_config.endpoint_config
-    yolox_client, paddle_client = _create_clients(
+    yolox_client, ocr_client = _create_clients(
         endpoint_config.yolox_endpoints,
         endpoint_config.yolox_infer_protocol,
-        endpoint_config.paddle_endpoints,
-        endpoint_config.paddle_infer_protocol,
+        endpoint_config.ocr_endpoints,
+        endpoint_config.ocr_infer_protocol,
         endpoint_config.auth_token,
     )

+    # Get the grpc endpoint to determine the model if needed
+    ocr_grpc_endpoint = endpoint_config.ocr_endpoints[0]
+    ocr_model_name = get_ocr_model_name(ocr_grpc_endpoint)
+
     try:
         # 1) Identify rows that meet criteria in a single pass
         #  - metadata exists
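The new get_ocr_model_name helper is consulted once per job to decide which payload shape to build. A usage sketch follows; the assumption (inferred from the hunks above, not from documented API) is that it inspects the gRPC endpoint and returns "paddle" or "scene_text", and the endpoint string is a placeholder:

```python
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import get_ocr_model_name

ocr_model_name = get_ocr_model_name("localhost:8011")  # placeholder gRPC endpoint
if ocr_model_name == "paddle":
    pass  # legacy PaddleOCR payload: model_name="paddle"
else:
    pass  # scene_text payload: input_names=["input", "merge_levels"], dtypes=["FP32", "BYTES"]
```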
@@ -323,7 +342,8 @@ def extract_chart_data_from_image_internal(
         bulk_results = _update_chart_metadata(
             base64_images=base64_images,
             yolox_client=yolox_client,
-            paddle_client=paddle_client,
+            ocr_client=ocr_client,
+            ocr_model_name=ocr_model_name,
             worker_pool_size=endpoint_config.workers_per_progress_engine,
             trace_info=execution_trace_log,
         )
@@ -344,8 +364,8 @@ def extract_chart_data_from_image_internal(

     finally:
         try:
-            if paddle_client is not None:
-                paddle_client.close()
+            if ocr_client is not None:
+                ocr_client.close()
             if yolox_client is not None:
                 yolox_client.close()

nv_ingest_api/internal/extract/image/infographic_extractor.py:

@@ -12,12 +12,14 @@ from typing import Tuple
 import pandas as pd

 from nv_ingest_api.internal.primitives.nim import NimClient
-from nv_ingest_api.internal.primitives.nim.model_interface.paddle import PaddleOCRModelInterface
+from nv_ingest_api.internal.primitives.nim.model_interface.ocr import OCRModelInterface
+from nv_ingest_api.internal.primitives.nim.model_interface.ocr import get_ocr_model_name
 from nv_ingest_api.internal.schemas.extract.extract_infographic_schema import (
     InfographicExtractorSchema,
 )
 from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
 from nv_ingest_api.util.nim import create_inference_client
+from nv_ingest_api.util.image_processing.table_and_chart import reorder_boxes

 logger = logging.getLogger(__name__)

@@ -61,22 +63,23 @@ def _filter_infographic_images(

 def _update_infographic_metadata(
     base64_images: List[str],
-    paddle_client: NimClient,
+    ocr_client: NimClient,
+    ocr_model_name: str,
     worker_pool_size: int = 8,  # Not currently used
     trace_info: Optional[Dict] = None,
 ) -> List[Tuple[str, Optional[Any], Optional[Any]]]:
     """
-    Filters base64-encoded images and uses PaddleOCR to extract infographic data.
+    Filters base64-encoded images and uses OCR to extract infographic data.

-    For each image that meets the minimum size, calls paddle_client.infer to obtain
+    For each image that meets the minimum size, calls ocr_client.infer to obtain
     (text_predictions, bounding_boxes). Invalid images are marked as skipped.

     Parameters
     ----------
     base64_images : List[str]
         List of base64-encoded images.
-    paddle_client : NimClient
-        Client instance for PaddleOCR inference.
+    ocr_client : NimClient
+        Client instance for OCR inference.
     worker_pool_size : int, optional
         Worker pool size (currently not used), by default 8.
     trace_info : Optional[Dict], optional
@@ -88,54 +91,70 @@ def _update_infographic_metadata(
     List of tuples in the same order as base64_images, where each tuple contains:
     (base64_image, text_predictions, bounding_boxes).
     """
-    logger.debug(f"Running infographic extraction using protocol {paddle_client.protocol}")
+    logger.debug(f"Running infographic extraction using protocol {ocr_client.protocol}")

     valid_images, valid_indices, results = _filter_infographic_images(base64_images)
-    data_paddle = {"base64_images": valid_images}
+    data_ocr = {"base64_images": valid_images}

     # worker_pool_size is not used in current implementation.
     _ = worker_pool_size

-    try:
-        paddle_results = paddle_client.infer(
-            data_paddle,
+    infer_kwargs = dict(
+        stage_name="infographic_extraction",
+        max_batch_size=1 if ocr_client.protocol == "grpc" else 2,
+        trace_info=trace_info,
+    )
+    if ocr_model_name == "paddle":
+        infer_kwargs.update(
             model_name="paddle",
-            stage_name="infographic_extraction",
-            max_batch_size=1 if paddle_client.protocol == "grpc" else 2,
-            trace_info=trace_info,
         )
+    else:
+        infer_kwargs.update(
+            model_name="scene_text",
+            input_names=["input", "merge_levels"],
+            dtypes=["FP32", "BYTES"],
+            merge_level="paragraph",
+        )
+    try:
+        ocr_results = ocr_client.infer(data_ocr, **infer_kwargs)
     except Exception as e:
-        logger.error(f"Error calling paddle_client.infer: {e}", exc_info=True)
+        logger.error(f"Error calling ocr_client.infer: {e}", exc_info=True)
         raise

-    if len(paddle_results) != len(valid_images):
-        raise ValueError(f"Expected {len(valid_images)} paddle results, got {len(paddle_results)}")
+    if len(ocr_results) != len(valid_images):
+        raise ValueError(f"Expected {len(valid_images)} ocr results, got {len(ocr_results)}")

-    for idx, paddle_res in enumerate(paddle_results):
+    for idx, ocr_res in enumerate(ocr_results):
         original_index = valid_indices[idx]
-        # Each paddle_res is expected to be a tuple (text_predictions, bounding_boxes).
-        results[original_index] = (base64_images[original_index], paddle_res[0], paddle_res[1])
+
+        if ocr_model_name == "paddle":
+            logger.debug(f"OCR results for image {base64_images[original_index]}: {ocr_res}")
+        else:
+            # Each ocr_res is expected to be a tuple (text_predictions, bounding_boxes, conf_scores).
+            ocr_res = reorder_boxes(*ocr_res)
+
+        results[original_index] = (base64_images[original_index], ocr_res[0], ocr_res[1])

     return results


 def _create_clients(
-    paddle_endpoints: Tuple[str, str],
-    paddle_protocol: str,
+    ocr_endpoints: Tuple[str, str],
+    ocr_protocol: str,
     auth_token: str,
 ) -> NimClient:
-    paddle_model_interface = PaddleOCRModelInterface()
+    ocr_model_interface = OCRModelInterface()

-    logger.debug(f"Inference protocols: paddle={paddle_protocol}")
+    logger.debug(f"Inference protocols: ocr={ocr_protocol}")

-    paddle_client = create_inference_client(
-        endpoints=paddle_endpoints,
-        model_interface=paddle_model_interface,
+    ocr_client = create_inference_client(
+        endpoints=ocr_endpoints,
+        model_interface=ocr_model_interface,
         auth_token=auth_token,
-        infer_protocol=paddle_protocol,
+        infer_protocol=ocr_protocol,
     )

-    return paddle_client
+    return ocr_client


 def _meets_infographic_criteria(row: pd.Series) -> bool:
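For scene_text output, reorder_boxes(*ocr_res) puts detections into reading order before the text is joined downstream. A toy illustration of that kind of step (not the library implementation), assuming [x0, y0, x1, y1] boxes:

```python
def toy_reading_order(boxes, texts, confs):
    # Sort detections top-to-bottom, then left-to-right, by each box's top-left corner.
    order = sorted(range(len(boxes)), key=lambda i: (boxes[i][1], boxes[i][0]))
    return [boxes[i] for i in order], [texts[i] for i in order], [confs[i] for i in order]

boxes = [[10, 50, 90, 70], [12, 10, 95, 30]]
texts = ["second line", "first line"]
confs = [0.98, 0.99]
_, ordered_texts, _ = toy_reading_order(boxes, texts, confs)
print(" ".join(ordered_texts))  # "first line second line"
```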
@@ -209,12 +228,16 @@ def extract_infographic_data_from_image_internal(
         return df_extraction_ledger, execution_trace_log

     endpoint_config = extraction_config.endpoint_config
-    paddle_client = _create_clients(
-        endpoint_config.paddle_endpoints,
-        endpoint_config.paddle_infer_protocol,
+    ocr_client = _create_clients(
+        endpoint_config.ocr_endpoints,
+        endpoint_config.ocr_infer_protocol,
         endpoint_config.auth_token,
     )

+    # Get the grpc endpoint to determine the model if needed
+    ocr_grpc_endpoint = endpoint_config.ocr_endpoints[0]
+    ocr_model_name = get_ocr_model_name(ocr_grpc_endpoint)
+
     try:
         # Identify rows that meet the infographic criteria.
         mask = df_extraction_ledger.apply(_meets_infographic_criteria, axis=1)
@@ -230,14 +253,15 @@ def extract_infographic_data_from_image_internal(
         # Call bulk update to extract infographic data.
         bulk_results = _update_infographic_metadata(
             base64_images=base64_images,
-            paddle_client=paddle_client,
+            ocr_client=ocr_client,
+            ocr_model_name=ocr_model_name,
             worker_pool_size=endpoint_config.workers_per_progress_engine,
             trace_info=execution_trace_log,
         )

         # Write the extracted results back into the DataFrame.
         for result_idx, df_idx in enumerate(valid_indices):
-            # Unpack result: (base64_image, paddle_bounding_boxes, paddle_text_predictions)
+            # Unpack result: (base64_image, ocr_bounding_boxes, ocr_text_predictions)
             _, _, text_predictions = bulk_results[result_idx]
             table_content = " ".join(text_predictions) if text_predictions else None
             df_extraction_ledger.at[df_idx, "metadata"]["table_metadata"]["table_content"] = table_content
@@ -250,4 +274,4 @@ def extract_infographic_data_from_image_internal(
         raise

     finally:
-        paddle_client.close()
+        ocr_client.close()