nv-ingest-api 26.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nv-ingest-api might be problematic.
- nv_ingest_api/__init__.py +3 -0
- nv_ingest_api/interface/__init__.py +218 -0
- nv_ingest_api/interface/extract.py +977 -0
- nv_ingest_api/interface/mutate.py +154 -0
- nv_ingest_api/interface/store.py +200 -0
- nv_ingest_api/interface/transform.py +382 -0
- nv_ingest_api/interface/utility.py +186 -0
- nv_ingest_api/internal/__init__.py +0 -0
- nv_ingest_api/internal/enums/__init__.py +3 -0
- nv_ingest_api/internal/enums/common.py +550 -0
- nv_ingest_api/internal/extract/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
- nv_ingest_api/internal/extract/docx/__init__.py +5 -0
- nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
- nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
- nv_ingest_api/internal/extract/html/__init__.py +3 -0
- nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
- nv_ingest_api/internal/extract/image/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
- nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
- nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
- nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
- nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
- nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
- nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
- nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
- nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
- nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
- nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
- nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
- nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
- nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
- nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
- nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
- nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
- nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
- nv_ingest_api/internal/meta/__init__.py +3 -0
- nv_ingest_api/internal/meta/udf.py +232 -0
- nv_ingest_api/internal/mutate/__init__.py +3 -0
- nv_ingest_api/internal/mutate/deduplicate.py +110 -0
- nv_ingest_api/internal/mutate/filter.py +133 -0
- nv_ingest_api/internal/primitives/__init__.py +0 -0
- nv_ingest_api/internal/primitives/control_message_task.py +16 -0
- nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
- nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
- nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
- nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
- nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
- nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
- nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
- nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
- nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
- nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
- nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
- nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
- nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
- nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
- nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
- nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
- nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
- nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
- nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
- nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
- nv_ingest_api/internal/schemas/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
- nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
- nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
- nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
- nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
- nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
- nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
- nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
- nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
- nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
- nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
- nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
- nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
- nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
- nv_ingest_api/internal/schemas/meta/udf.py +23 -0
- nv_ingest_api/internal/schemas/mixins.py +39 -0
- nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
- nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
- nv_ingest_api/internal/schemas/store/__init__.py +3 -0
- nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
- nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
- nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
- nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
- nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
- nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
- nv_ingest_api/internal/store/__init__.py +3 -0
- nv_ingest_api/internal/store/embed_text_upload.py +236 -0
- nv_ingest_api/internal/store/image_upload.py +251 -0
- nv_ingest_api/internal/transform/__init__.py +3 -0
- nv_ingest_api/internal/transform/caption_image.py +219 -0
- nv_ingest_api/internal/transform/embed_text.py +702 -0
- nv_ingest_api/internal/transform/split_text.py +182 -0
- nv_ingest_api/util/__init__.py +3 -0
- nv_ingest_api/util/control_message/__init__.py +0 -0
- nv_ingest_api/util/control_message/validators.py +47 -0
- nv_ingest_api/util/converters/__init__.py +0 -0
- nv_ingest_api/util/converters/bytetools.py +78 -0
- nv_ingest_api/util/converters/containers.py +65 -0
- nv_ingest_api/util/converters/datetools.py +90 -0
- nv_ingest_api/util/converters/dftools.py +127 -0
- nv_ingest_api/util/converters/formats.py +64 -0
- nv_ingest_api/util/converters/type_mappings.py +27 -0
- nv_ingest_api/util/dataloader/__init__.py +9 -0
- nv_ingest_api/util/dataloader/dataloader.py +409 -0
- nv_ingest_api/util/detectors/__init__.py +5 -0
- nv_ingest_api/util/detectors/language.py +38 -0
- nv_ingest_api/util/exception_handlers/__init__.py +0 -0
- nv_ingest_api/util/exception_handlers/converters.py +72 -0
- nv_ingest_api/util/exception_handlers/decorators.py +429 -0
- nv_ingest_api/util/exception_handlers/detectors.py +74 -0
- nv_ingest_api/util/exception_handlers/pdf.py +116 -0
- nv_ingest_api/util/exception_handlers/schemas.py +68 -0
- nv_ingest_api/util/image_processing/__init__.py +5 -0
- nv_ingest_api/util/image_processing/clustering.py +260 -0
- nv_ingest_api/util/image_processing/processing.py +177 -0
- nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
- nv_ingest_api/util/image_processing/transforms.py +850 -0
- nv_ingest_api/util/imports/__init__.py +3 -0
- nv_ingest_api/util/imports/callable_signatures.py +108 -0
- nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
- nv_ingest_api/util/introspection/__init__.py +3 -0
- nv_ingest_api/util/introspection/class_inspect.py +145 -0
- nv_ingest_api/util/introspection/function_inspect.py +65 -0
- nv_ingest_api/util/logging/__init__.py +0 -0
- nv_ingest_api/util/logging/configuration.py +102 -0
- nv_ingest_api/util/logging/sanitize.py +84 -0
- nv_ingest_api/util/message_brokers/__init__.py +3 -0
- nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
- nv_ingest_api/util/metadata/__init__.py +5 -0
- nv_ingest_api/util/metadata/aggregators.py +516 -0
- nv_ingest_api/util/multi_processing/__init__.py +8 -0
- nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
- nv_ingest_api/util/nim/__init__.py +161 -0
- nv_ingest_api/util/pdf/__init__.py +3 -0
- nv_ingest_api/util/pdf/pdfium.py +428 -0
- nv_ingest_api/util/schema/__init__.py +3 -0
- nv_ingest_api/util/schema/schema_validator.py +10 -0
- nv_ingest_api/util/service_clients/__init__.py +3 -0
- nv_ingest_api/util/service_clients/client_base.py +86 -0
- nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
- nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
- nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
- nv_ingest_api/util/string_processing/__init__.py +51 -0
- nv_ingest_api/util/string_processing/configuration.py +682 -0
- nv_ingest_api/util/string_processing/yaml.py +109 -0
- nv_ingest_api/util/system/__init__.py +0 -0
- nv_ingest_api/util/system/hardware_info.py +594 -0
- nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
- nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
- nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
- nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
- nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
- udfs/__init__.py +5 -0
- udfs/llm_summarizer_udf.py +259 -0
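Once the wheel is installed (e.g. via `pip install nv-ingest-api==26.1.0rc4`), the manifest above can be read back from the dist-info RECORD with the standard library. A minimal sketch (assuming a modern Python where `importlib.metadata` normalizes the `nv-ingest-api` distribution name, and that `dist.files` is populated, which it is for wheel installs):

from importlib.metadata import distribution

# Read the installed distribution's metadata; `files` mirrors the RECORD entries above.
dist = distribution("nv-ingest-api")
print(dist.version)  # 26.1.0rc4
for path in list(dist.files)[:5]:
    print(path)  # e.g. nv_ingest_api/__init__.py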
nv_ingest_api/internal/extract/image/infographic_extractor.py
@@ -0,0 +1,290 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import pandas as pd

from nv_ingest_api.internal.primitives.nim import NimClient
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import PaddleOCRModelInterface
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import NemoRetrieverOCRModelInterface
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import get_ocr_model_name
from nv_ingest_api.internal.schemas.extract.extract_infographic_schema import InfographicExtractorSchema
from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
from nv_ingest_api.util.nim import create_inference_client
from nv_ingest_api.util.image_processing.table_and_chart import reorder_boxes

logger = logging.getLogger(__name__)

PADDLE_MIN_WIDTH = 32
PADDLE_MIN_HEIGHT = 32


def _filter_infographic_images(
    base64_images: List[str],
) -> Tuple[List[str], List[int], List[Tuple[str, Optional[Any], Optional[Any]]]]:
    """
    Filters base64-encoded images based on minimum size requirements.

    Parameters
    ----------
    base64_images : List[str]
        List of base64-encoded image strings.

    Returns
    -------
    Tuple[List[str], List[int], List[Tuple[str, Optional[Any], Optional[Any]]]]
        - valid_images: List of images that meet the size requirements.
        - valid_indices: Original indices of valid images.
        - results: Initialized results list, with invalid images marked as (img, None, None).
    """
    results: List[Tuple[str, Optional[Any], Optional[Any]]] = [("", None, None)] * len(base64_images)
    valid_images: List[str] = []
    valid_indices: List[int] = []

    for i, img in enumerate(base64_images):
        array = base64_to_numpy(img)
        height, width = array.shape[0], array.shape[1]
        if width >= PADDLE_MIN_WIDTH and height >= PADDLE_MIN_HEIGHT:
            valid_images.append(img)
            valid_indices.append(i)
        else:
            # Mark image as skipped if it does not meet size requirements.
            results[i] = (img, None, None)
    return valid_images, valid_indices, results


def _update_infographic_metadata(
    base64_images: List[str],
    ocr_client: NimClient,
    ocr_model_name: str,
    worker_pool_size: int = 8,  # Not currently used
    trace_info: Optional[Dict] = None,
) -> List[Tuple[str, Optional[Any], Optional[Any]]]:
    """
    Filters base64-encoded images and uses OCR to extract infographic data.

    For each image that meets the minimum size, calls ocr_client.infer to obtain
    (bounding_boxes, text_predictions). Invalid images are marked as skipped.

    Parameters
    ----------
    base64_images : List[str]
        List of base64-encoded images.
    ocr_client : NimClient
        Client instance for OCR inference.
    worker_pool_size : int, optional
        Worker pool size (currently not used), by default 8.
    trace_info : Optional[Dict], optional
        Optional trace information for debugging.

    Returns
    -------
    List[Tuple[str, Optional[Any], Optional[Any]]]
        List of tuples in the same order as base64_images, where each tuple contains:
        (base64_image, bounding_boxes, text_predictions).
    """
    logger.debug(f"Running infographic extraction using protocol {ocr_client.protocol}")

    valid_images, valid_indices, results = _filter_infographic_images(base64_images)
    data_ocr = {"base64_images": valid_images}

    # worker_pool_size is not used in current implementation.
    _ = worker_pool_size

    infer_kwargs = dict(
        stage_name="infographic_extraction",
        trace_info=trace_info,
    )
    if ocr_model_name == "paddle":
        infer_kwargs.update(
            model_name="paddle",
            max_batch_size=1 if ocr_client.protocol == "grpc" else 2,
        )
    elif ocr_model_name in {"scene_text_ensemble", "scene_text_wrapper", "scene_text_python"}:
        infer_kwargs.update(
            model_name=ocr_model_name,
            input_names=["INPUT_IMAGE_URLS", "MERGE_LEVELS"],
            output_names=["OUTPUT"],
            dtypes=["BYTES", "BYTES"],
            merge_level="paragraph",
        )
    else:
        raise ValueError(f"Unknown OCR model name: {ocr_model_name}")

    try:
        ocr_results = ocr_client.infer(data_ocr, **infer_kwargs)
    except Exception as e:
        logger.error(f"Error calling ocr_client.infer: {e}", exc_info=True)
        raise

    if len(ocr_results) != len(valid_images):
        raise ValueError(f"Expected {len(valid_images)} ocr results, got {len(ocr_results)}")

    for idx, ocr_res in enumerate(ocr_results):
        original_index = valid_indices[idx]

        if ocr_model_name == "paddle":
            logger.debug(f"OCR results for image {base64_images[original_index]}: {ocr_res}")
        else:
            # Each ocr_res is expected to be a tuple (bounding_boxes, text_predictions, conf_scores).
            ocr_res = reorder_boxes(*ocr_res)

        results[original_index] = (
            base64_images[original_index],
            ocr_res[0],
            ocr_res[1],
        )

    return results


def _create_ocr_client(
    ocr_endpoints: Tuple[str, str],
    ocr_protocol: str,
    ocr_model_name: str,
    auth_token: str,
) -> NimClient:
    ocr_model_interface = (
        NemoRetrieverOCRModelInterface()
        if ocr_model_name in {"scene_text_ensemble", "scene_text_wrapper", "scene_text_python"}
        else PaddleOCRModelInterface()
    )

    ocr_client = create_inference_client(
        endpoints=ocr_endpoints,
        model_interface=ocr_model_interface,
        auth_token=auth_token,
        infer_protocol=ocr_protocol,
        enable_dynamic_batching=(
            True if ocr_model_name in {"scene_text_ensemble", "scene_text_wrapper", "scene_text_python"} else False
        ),
        dynamic_batch_memory_budget_mb=32,
    )

    return ocr_client


def _meets_infographic_criteria(row: pd.Series) -> bool:
    """
    Determines if a DataFrame row meets the criteria for infographic extraction.

    A row qualifies if:
    - It contains a 'metadata' dictionary.
    - The 'content_metadata' in metadata has type "structured" and subtype "infographic".
    - The 'table_metadata' is not None.
    - The 'content' is not None or an empty string.

    Parameters
    ----------
    row : pd.Series
        A row from the DataFrame.

    Returns
    -------
    bool
        True if the row meets all criteria; False otherwise.
    """
    metadata = row.get("metadata", {})
    if not metadata:
        return False

    content_md = metadata.get("content_metadata", {})
    if (
        content_md.get("type") == "structured"
        and content_md.get("subtype") == "infographic"
        and metadata.get("table_metadata") is not None
        and metadata.get("content") not in [None, ""]
    ):
        return True

    return False


def extract_infographic_data_from_image_internal(
    df_extraction_ledger: pd.DataFrame,
    task_config: Dict[str, Any],
    extraction_config: InfographicExtractorSchema,
    execution_trace_log: Optional[Dict] = None,
) -> Tuple[pd.DataFrame, Dict]:
    """
    Extracts infographic data from a DataFrame in bulk, following the chart extraction pattern.

    Parameters
    ----------
    df_extraction_ledger : pd.DataFrame
        DataFrame containing the content from which infographic data is to be extracted.
    task_config : Dict[str, Any]
        Dictionary containing task properties and configurations.
    extraction_config : InfographicExtractorSchema
        The validated configuration object for infographic extraction.
    execution_trace_log : Optional[Dict], optional
        Optional trace information for debugging or logging. Defaults to None.

    Returns
    -------
    Tuple[pd.DataFrame, Dict]
        A tuple containing the updated DataFrame and the trace information.
    """
    _ = task_config  # Unused

    if execution_trace_log is None:
        execution_trace_log = {}
        logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")

    if df_extraction_ledger.empty:
        return df_extraction_ledger, execution_trace_log

    endpoint_config = extraction_config.endpoint_config

    # Get the grpc endpoint to determine the model if needed
    ocr_grpc_endpoint = endpoint_config.ocr_endpoints[0]
    ocr_model_name = get_ocr_model_name(ocr_grpc_endpoint)

    try:
        # Identify rows that meet the infographic criteria.
        mask = df_extraction_ledger.apply(_meets_infographic_criteria, axis=1)
        valid_indices = df_extraction_ledger[mask].index.tolist()

        # If no rows meet the criteria, return early.
        if not valid_indices:
            return df_extraction_ledger, {"trace_info": execution_trace_log}

        # Extract base64 images from valid rows.
        base64_images = [df_extraction_ledger.at[idx, "metadata"]["content"] for idx in valid_indices]

        # Call bulk update to extract infographic data.
        ocr_client = _create_ocr_client(
            endpoint_config.ocr_endpoints,
            endpoint_config.ocr_infer_protocol,
            ocr_model_name,
            endpoint_config.auth_token,
        )

        bulk_results = _update_infographic_metadata(
            base64_images=base64_images,
            ocr_client=ocr_client,
            ocr_model_name=ocr_model_name,
            worker_pool_size=endpoint_config.workers_per_progress_engine,
            trace_info=execution_trace_log,
        )

        # Write the extracted results back into the DataFrame.
        for result_idx, df_idx in enumerate(valid_indices):
            # Unpack result: (base64_image, ocr_bounding_boxes, ocr_text_predictions)
            _, _, text_predictions = bulk_results[result_idx]
            table_content = " ".join(text_predictions) if text_predictions else None
            df_extraction_ledger.at[df_idx, "metadata"]["table_metadata"]["table_content"] = table_content

        return df_extraction_ledger, {"trace_info": execution_trace_log}

    except Exception:
        err_msg = "Error occurred while extracting infographic data."
        logger.exception(err_msg)
        raise
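To make the row criteria above concrete, here is a minimal sketch of a ledger row that `_meets_infographic_criteria` accepts. The metadata shape is inferred from the checks in that function; the content payload is a placeholder rather than a real base64 image, and importing a private helper like this is for illustration only:

import pandas as pd

from nv_ingest_api.internal.extract.image.infographic_extractor import _meets_infographic_criteria

row = pd.Series(
    {
        "metadata": {
            "content": "<base64-image-placeholder>",  # hypothetical payload, not a real image
            "content_metadata": {"type": "structured", "subtype": "infographic"},
            "table_metadata": {},  # must be non-None; table_content is filled in by the extractor
        }
    }
)
assert _meets_infographic_criteria(row)

# Any other subtype (e.g. a chart) is skipped by the infographic stage.
row["metadata"]["content_metadata"]["subtype"] = "chart"
assert not _meets_infographic_criteria(row)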
nv_ingest_api/internal/extract/image/ocr_extractor.py
@@ -0,0 +1,407 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import pandas as pd

from nv_ingest_api.internal.enums.common import ContentTypeEnum
from nv_ingest_api.internal.primitives.nim import NimClient
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import PaddleOCRModelInterface
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import NemoRetrieverOCRModelInterface
from nv_ingest_api.internal.primitives.nim.model_interface.ocr import get_ocr_model_name
from nv_ingest_api.internal.schemas.extract.extract_ocr_schema import OCRExtractorSchema
from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
from nv_ingest_api.util.nim import create_inference_client

logger = logging.getLogger(__name__)

PADDLE_MIN_WIDTH = 32
PADDLE_MIN_HEIGHT = 32


def _filter_text_images(
    base64_images: List[str],
    min_width: int = PADDLE_MIN_WIDTH,
    min_height: int = PADDLE_MIN_HEIGHT,
) -> Tuple[List[str], List[int]]:
    """
    Filters base64-encoded images based on minimum size requirements.

    Parameters
    ----------
    base64_images : List[str]
        List of base64-encoded image strings.

    Returns
    -------
    Tuple[List[str], List[int]]
        - valid_images: List of images that meet the size requirements.
        - valid_indices: Original indices of valid images.
    """
    valid_images: List[str] = []
    valid_indices: List[int] = []

    for i, img in enumerate(base64_images):
        array = base64_to_numpy(img)
        height, width = array.shape[0], array.shape[1]
        if width >= min_width and height >= min_height:
            valid_images.append(img)
            valid_indices.append(i)
    return valid_images, valid_indices


def _update_text_metadata(
    base64_images: List[str],
    ocr_client: NimClient,
    ocr_model_name: str,
    worker_pool_size: int = 8,  # Not currently used
    trace_info: Optional[Dict] = None,
) -> List[Tuple[str, Optional[Any], Optional[Any]]]:
    """
    Filters base64-encoded images and uses OCR to extract text data.

    For each image that meets the minimum size, calls ocr_client.infer to obtain
    (bounding_boxes, text_predictions). Invalid images are marked as skipped.

    Parameters
    ----------
    base64_images : List[str]
        List of base64-encoded images.
    ocr_client : NimClient
        Client instance for OCR inference.
    worker_pool_size : int, optional
        Worker pool size (currently not used), by default 8.
    trace_info : Optional[Dict], optional
        Optional trace information for debugging.

    Returns
    -------
    List[Tuple[str, Optional[Any], Optional[Any]]]
        List of tuples in the same order as base64_images, where each tuple contains:
        (bounding_boxes, text_predictions, confidence_scores).
    """
    logger.debug(f"Running text extraction using protocol {ocr_client.protocol}")

    if ocr_model_name == "paddle":
        valid_images, valid_indices = _filter_text_images(base64_images)
    else:
        valid_images, valid_indices = _filter_text_images(base64_images, min_width=1, min_height=1)
    data_ocr = {"base64_images": valid_images}

    # worker_pool_size is not used in current implementation.
    _ = worker_pool_size

    infer_kwargs = dict(
        stage_name="ocr_extraction",
        trace_info=trace_info,
    )
    if ocr_model_name == "paddle":
        infer_kwargs.update(
            model_name="paddle",
            max_batch_size=1 if ocr_client.protocol == "grpc" else 2,
        )
    elif ocr_model_name in {"scene_text_ensemble", "scene_text_wrapper", "scene_text_python"}:
        infer_kwargs.update(
            model_name=ocr_model_name,
            input_names=["INPUT_IMAGE_URLS", "MERGE_LEVELS"],
            output_names=["OUTPUT"],
            dtypes=["BYTES", "BYTES"],
            merge_level="paragraph",
        )
    else:
        raise ValueError(f"Unknown OCR model name: {ocr_model_name}")

    try:
        ocr_results = ocr_client.infer(data_ocr, **infer_kwargs)
    except Exception as e:
        logger.error(f"Error calling ocr_client.infer: {e}", exc_info=True)
        raise

    if len(ocr_results) != len(valid_images):
        raise ValueError(f"Expected {len(valid_images)} ocr results, got {len(ocr_results)}")

    results = [(None, None, None)] * len(base64_images)
    for idx, ocr_res in enumerate(ocr_results):
        original_index = valid_indices[idx]
        results[original_index] = ocr_res

    return results


def _create_ocr_client(
    ocr_endpoints: Tuple[str, str],
    ocr_protocol: str,
    ocr_model_name: str,
    auth_token: str,
) -> NimClient:
    ocr_model_interface = (
        NemoRetrieverOCRModelInterface()
        if ocr_model_name in {"scene_text_ensemble", "scene_text_wrapper", "scene_text_python"}
        else PaddleOCRModelInterface()
    )

    ocr_client = create_inference_client(
        endpoints=ocr_endpoints,
        model_interface=ocr_model_interface,
        auth_token=auth_token,
        infer_protocol=ocr_protocol,
        enable_dynamic_batching=(
            True if ocr_model_name in {"scene_text_ensemble", "scene_text_wrapper", "scene_text_python"} else False
        ),
        dynamic_batch_memory_budget_mb=32,
    )

    return ocr_client


def _meets_page_elements_text_criteria(row: pd.Series) -> bool:
    """
    Determines if a DataFrame row meets the criteria for text extraction.

    A row qualifies if:
    - It contains a 'metadata' dictionary.
    - The 'content_metadata' in metadata has type "text" and one of subtype:
      "title", "paragraph", "header_footer".
    - The 'content' is not None or an empty string.

    Parameters
    ----------
    row : pd.Series
        A row from the DataFrame.

    Returns
    -------
    bool
        True if the row meets all criteria; False otherwise.
    """
    page_element_subtypes = {"paragraph", "title", "header_footer"}

    metadata = row.get("metadata", {})
    if not metadata:
        return False

    content_md = metadata.get("content_metadata", {})

    if (
        content_md.get("type") == ContentTypeEnum.TEXT
        and content_md.get("subtype") in page_element_subtypes
        and metadata.get("content") not in {None, ""}
    ):
        return True

    return False


def _meets_page_image_criteria(row: pd.Series) -> bool:
    """
    Determines if a DataFrame row meets the criteria for page-image text extraction.

    A row qualifies if:
    - It contains a 'metadata' dictionary.
    - The 'content_metadata' in metadata has type "image" and subtype "page_image".
    - The 'content' is not None or an empty string.

    Parameters
    ----------
    row : pd.Series
        A row from the DataFrame.

    Returns
    -------
    bool
        True if the row meets all criteria; False otherwise.
    """
    page_image_subtypes = {ContentTypeEnum.PAGE_IMAGE}

    metadata = row.get("metadata", {})
    if not metadata:
        return False

    content_md = metadata.get("content_metadata", {})

    if (
        content_md.get("type") == ContentTypeEnum.IMAGE
        and content_md.get("subtype") in page_image_subtypes
        and metadata.get("content") not in {None, ""}
    ):
        return True

    return False


def _process_page_images(df_to_process: pd.DataFrame, ocr_results: List[Tuple]):
    valid_indices = df_to_process.index.tolist()

    for result_idx, df_idx in enumerate(valid_indices):
        # Unpack result: (bounding_boxes, text_predictions, confidence_scores)
        bboxes, texts, _ = ocr_results[result_idx]
        if not bboxes or not texts:
            df_to_process.loc[df_idx, "metadata"]["image_metadata"]["text"] = ""
            continue

        df_to_process.loc[df_idx, "metadata"]["image_metadata"]["text"] = " ".join([t for t in texts])

    return df_to_process


def _process_page_elements(df_to_process: pd.DataFrame, ocr_results: List[Tuple]):
    valid_indices = df_to_process.index.tolist()
    if not valid_indices:
        return df_to_process

    for result_idx, df_idx in enumerate(valid_indices):
        # Unpack result: (bounding_boxes, text_predictions, confidence_scores)
        bboxes, texts, _ = ocr_results[result_idx]
        if not bboxes or not texts:
            df_to_process.loc[df_idx, "_x0"] = None
            df_to_process.loc[df_idx, "_y0"] = None
            df_to_process.loc[df_idx, "metadata"]["content"] = ""
            continue

        combined_data = list(zip(bboxes, texts))
        try:
            # Sort by reading order (y_min, then x_min)
            combined_data.sort(key=lambda item: (min(p[1] for p in item[0]), min(p[0] for p in item[0])))
        except (ValueError, IndexError):
            logger.warning("Could not sort OCR results due to malformed bounding box.")
        df_to_process.loc[df_idx, "_x0"] = min(point[0] for item in combined_data for point in item[0])
        df_to_process.loc[df_idx, "_y0"] = min(point[1] for item in combined_data for point in item[0])
        df_to_process.loc[df_idx, "metadata"]["content"] = " ".join([item[1] for item in combined_data])

    df_to_process = df_to_process.drop(["_x0", "_y0"], axis=1)

    df_to_process.loc[:, "_page_number"] = df_to_process["metadata"].apply(
        lambda meta: meta["content_metadata"]["page_number"]
    )

    # Group by page number to aggregate all text blocks on each page
    grouped = df_to_process.groupby("_page_number")

    new_text = {}
    for page_num, group_df in grouped:
        if group_df.empty:
            continue
        # Sort text blocks by their original position for correct reading order
        group_df.loc[:, "_x0"] = group_df["metadata"].apply(lambda meta: meta["text_metadata"]["text_location"][0])
        group_df.loc[:, "_y0"] = group_df["metadata"].apply(lambda meta: meta["text_metadata"]["text_location"][1])

        loc_mask = group_df[["_y0", "_x0"]].notna().all(axis=1)
        sorted_group = group_df.loc[loc_mask].sort_values(by=["_y0", "_x0"], ascending=[True, True])
        page_text = " ".join(sorted_group["metadata"].apply(lambda meta: meta["content"]).tolist())

        if page_text.strip():
            new_text[page_num] = page_text

    df_text = df_to_process[df_to_process["document_type"] == "text"].drop_duplicates(
        subset=["_page_number"], keep="first"
    )

    for page_num, page_text in new_text.items():
        page_num_mask = df_text["_page_number"] == page_num
        df_text.loc[page_num_mask, "metadata"] = df_text.loc[page_num_mask, "metadata"].apply(
            lambda meta: {**meta, "content": page_text}
        )

    df_non_text = df_to_process[df_to_process["document_type"] != "text"]
    df_to_process = pd.concat([df_text, df_non_text])

    for col in {"_y0", "_x0", "_page_number"}:
        if col in df_to_process:
            df_to_process = df_to_process.drop(col, axis=1)

    return df_to_process


def extract_text_data_from_image_internal(
    df_extraction_ledger: pd.DataFrame,
    task_config: Dict[str, Any],
    extraction_config: OCRExtractorSchema,
    execution_trace_log: Optional[Dict] = None,
) -> Tuple[pd.DataFrame, Dict]:
    """
    Extracts text data from a DataFrame in bulk, following the chart extraction pattern.

    Parameters
    ----------
    df_extraction_ledger : pd.DataFrame
        DataFrame containing the content from which text data is to be extracted.
    task_config : Dict[str, Any]
        Dictionary containing task properties and configurations.
    extraction_config : OCRExtractorSchema
        The validated configuration object for text extraction.
    execution_trace_log : Optional[Dict], optional
        Optional trace information for debugging or logging. Defaults to None.

    Returns
    -------
    Tuple[pd.DataFrame, Dict]
        A tuple containing the updated DataFrame and the trace information.
    """
    _ = task_config  # Unused

    if execution_trace_log is None:
        execution_trace_log = {}
        logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")

    if df_extraction_ledger.empty:
        return df_extraction_ledger, execution_trace_log

    endpoint_config = extraction_config.endpoint_config

    # Get the grpc endpoint to determine the model if needed
    ocr_grpc_endpoint = endpoint_config.ocr_endpoints[0]
    ocr_model_name = get_ocr_model_name(ocr_grpc_endpoint)

    try:
        # Identify rows that meet the text criteria.
        page_images_mask = df_extraction_ledger.apply(_meets_page_image_criteria, axis=1)
        page_elements_mask = df_extraction_ledger.apply(_meets_page_elements_text_criteria, axis=1)

        df_to_process = df_extraction_ledger[page_images_mask | page_elements_mask].copy()
        df_unprocessed = df_extraction_ledger[~page_images_mask & ~page_elements_mask].copy()

        valid_indices = df_to_process.index.tolist()
        # If no rows meet the criteria, return early.
        if not valid_indices:
            return df_extraction_ledger, {"trace_info": execution_trace_log}

        # Extract base64 images from valid rows.
        base64_images = [row["metadata"]["content"] for _, row in df_to_process.iterrows()]

        # Call bulk update to extract text data.
        ocr_client = _create_ocr_client(
            endpoint_config.ocr_endpoints,
            endpoint_config.ocr_infer_protocol,
            ocr_model_name,
            endpoint_config.auth_token,
        )

        bulk_results = _update_text_metadata(
            base64_images=base64_images,
            ocr_client=ocr_client,
            ocr_model_name=ocr_model_name,
            worker_pool_size=endpoint_config.workers_per_progress_engine,
            trace_info=execution_trace_log,
        )

        df_page_images = df_to_process[df_to_process.apply(_meets_page_image_criteria, axis=1)]
        df_page_images = _process_page_images(df_page_images, bulk_results)

        df_page_elements = df_to_process[df_to_process.apply(_meets_page_elements_text_criteria, axis=1)]
        df_page_elements = _process_page_elements(df_page_elements, bulk_results)

        df_final = pd.concat([df_unprocessed, df_page_images, df_page_elements], ignore_index=True)

        return df_final, {"trace_info": execution_trace_log}

    except Exception:
        err_msg = "Error occurred while extracting text data."
        logger.exception(err_msg)
        raise
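The reading-order sort in `_process_page_elements` keys each OCR box by its top-most, then left-most corner. A self-contained sketch of that sort key on synthetic quadrilateral boxes (the box and text values here are made up for illustration):

# Each box is a list of (x, y) corner points, paired with its recognized text.
boxes = [
    [(100, 50), (180, 50), (180, 70), (100, 70)],    # same top band, further right
    [(10, 50), (90, 50), (90, 70), (10, 70)],        # top-left box
    [(10, 120), (200, 120), (200, 140), (10, 140)],  # lower body line
]
texts = ["world", "Hello", "Body text"]

combined = list(zip(boxes, texts))
# Same key as the extractor: sort by (y_min, x_min) of each box.
combined.sort(key=lambda item: (min(p[1] for p in item[0]), min(p[0] for p in item[0])))
print(" ".join(text for _, text in combined))  # -> "Hello world Body text"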