nv-ingest-api 25.4.2__py3-none-any.whl → 25.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nv-ingest-api might be problematic. Click here for more details.
- nv_ingest_api/internal/extract/docx/docx_extractor.py +3 -3
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +142 -86
- nv_ingest_api/internal/extract/html/__init__.py +3 -0
- nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
- nv_ingest_api/internal/extract/image/chart_extractor.py +3 -3
- nv_ingest_api/internal/extract/image/image_extractor.py +5 -5
- nv_ingest_api/internal/extract/image/image_helpers/common.py +1 -1
- nv_ingest_api/internal/extract/image/infographic_extractor.py +1 -1
- nv_ingest_api/internal/extract/image/table_extractor.py +2 -2
- nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +2 -2
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +1 -1
- nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +213 -187
- nv_ingest_api/internal/extract/pptx/pptx_extractor.py +6 -9
- nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +35 -38
- nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +7 -1
- nv_ingest_api/internal/primitives/nim/nim_client.py +17 -9
- nv_ingest_api/internal/primitives/tracing/tagging.py +20 -16
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +1 -1
- nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +1 -1
- nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +1 -1
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +1 -1
- nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +26 -12
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +34 -23
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +11 -10
- nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +9 -7
- nv_ingest_api/internal/store/image_upload.py +1 -0
- nv_ingest_api/internal/transform/embed_text.py +75 -52
- nv_ingest_api/internal/transform/split_text.py +9 -3
- nv_ingest_api/util/__init__.py +3 -0
- nv_ingest_api/util/exception_handlers/converters.py +1 -1
- nv_ingest_api/util/exception_handlers/decorators.py +309 -51
- nv_ingest_api/util/image_processing/processing.py +1 -1
- nv_ingest_api/util/logging/configuration.py +15 -8
- nv_ingest_api/util/pdf/pdfium.py +2 -2
- nv_ingest_api/util/schema/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/redis_client.py +1 -1
- nv_ingest_api/util/service_clients/rest/rest_client.py +2 -2
- nv_ingest_api/util/system/__init__.py +0 -0
- nv_ingest_api/util/system/hardware_info.py +430 -0
- {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.0.dist-info}/METADATA +2 -1
- {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.0.dist-info}/RECORD +46 -41
- {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.0.dist-info}/WHEEL +1 -1
- {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.0.dist-info}/licenses/LICENSE +0 -0
- {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.0.dist-info}/top_level.txt +0 -0
|
@@ -709,7 +709,13 @@ def postprocess_results(
|
|
|
709
709
|
raise ValueError(f"Error in postprocessing {result.shape} and {original_image_shape}: {e}")
|
|
710
710
|
|
|
711
711
|
for box, score, label in zip(bboxes, scores, labels):
|
|
712
|
-
|
|
712
|
+
# TODO(Devin): Sometimes we get back unexpected class labels?
|
|
713
|
+
if (label < 0) or (label >= len(class_labels)):
|
|
714
|
+
logger.warning(f"Invalid class label {label} found in postprocessing")
|
|
715
|
+
continue
|
|
716
|
+
else:
|
|
717
|
+
class_name = class_labels[int(label)]
|
|
718
|
+
|
|
713
719
|
annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))])
|
|
714
720
|
|
|
715
721
|
out.append(annotation_dict)
|
|
@@ -129,7 +129,7 @@ class NimClient:
|
|
|
129
129
|
"""
|
|
130
130
|
if self.protocol == "grpc":
|
|
131
131
|
logger.debug("Performing gRPC inference for a batch...")
|
|
132
|
-
response = self._grpc_infer(batch_input, model_name)
|
|
132
|
+
response = self._grpc_infer(batch_input, model_name, **kwargs)
|
|
133
133
|
logger.debug("gRPC inference received response for a batch")
|
|
134
134
|
elif self.protocol == "http":
|
|
135
135
|
logger.debug("Performing HTTP inference for a batch...")
|
|
@@ -221,7 +221,7 @@ class NimClient:
|
|
|
221
221
|
|
|
222
222
|
return all_results
|
|
223
223
|
|
|
224
|
-
def _grpc_infer(self, formatted_input: np.ndarray, model_name: str) -> np.ndarray:
|
|
224
|
+
def _grpc_infer(self, formatted_input: np.ndarray, model_name: str, **kwargs) -> np.ndarray:
|
|
225
225
|
"""
|
|
226
226
|
Perform inference using the gRPC protocol.
|
|
227
227
|
|
|
@@ -238,16 +238,24 @@ class NimClient:
|
|
|
238
238
|
The output of the model as a numpy array.
|
|
239
239
|
"""
|
|
240
240
|
|
|
241
|
-
|
|
242
|
-
|
|
241
|
+
parameters = kwargs.get("parameters", {})
|
|
242
|
+
output_names = kwargs.get("outputs", ["output"])
|
|
243
|
+
dtype = kwargs.get("dtype", "FP32")
|
|
244
|
+
input_name = kwargs.get("input_name", "input")
|
|
243
245
|
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
logger.debug(f"gRPC inference response: {response}")
|
|
246
|
+
input_tensors = grpcclient.InferInput(input_name, formatted_input.shape, datatype=dtype)
|
|
247
|
+
input_tensors.set_data_from_numpy(formatted_input)
|
|
247
248
|
|
|
248
|
-
|
|
249
|
+
outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
|
|
250
|
+
response = self.client.infer(
|
|
251
|
+
model_name=model_name, parameters=parameters, inputs=[input_tensors], outputs=outputs
|
|
252
|
+
)
|
|
253
|
+
logger.debug(f"gRPC inference response: {response}")
|
|
249
254
|
|
|
250
|
-
|
|
255
|
+
if len(outputs) == 1:
|
|
256
|
+
return response.as_numpy(outputs[0].name())
|
|
257
|
+
else:
|
|
258
|
+
return [response.as_numpy(output.name()) for output in outputs]
|
|
251
259
|
|
|
252
260
|
def _http_infer(self, formatted_input: dict) -> dict:
|
|
253
261
|
"""
|
|
@@ -31,13 +31,15 @@ def traceable(trace_name=None):
|
|
|
31
31
|
|
|
32
32
|
Notes
|
|
33
33
|
-----
|
|
34
|
-
The decorated function must accept a IngestControlMessage object as its
|
|
35
|
-
|
|
36
|
-
|
|
34
|
+
The decorated function must accept a IngestControlMessage object as one of its arguments.
|
|
35
|
+
For a regular function, this is expected to be the first argument; for a class method,
|
|
36
|
+
this is expected to be the second argument (after 'self'). The IngestControlMessage object
|
|
37
|
+
must implement `has_metadata`, `get_metadata`, and `set_metadata` methods used by the decorator
|
|
38
|
+
to check for the trace tagging flag and to add trace metadata.
|
|
37
39
|
|
|
38
40
|
The trace metadata added by the decorator includes two entries:
|
|
39
|
-
- 'trace::entry::<trace_name>': The
|
|
40
|
-
- 'trace::exit::<trace_name>': The
|
|
41
|
+
- 'trace::entry::<trace_name>': The timestamp marking the function's entry.
|
|
42
|
+
- 'trace::exit::<trace_name>': The timestamp marking the function's exit.
|
|
41
43
|
|
|
42
44
|
Example
|
|
43
45
|
-------
|
|
@@ -47,23 +49,25 @@ def traceable(trace_name=None):
|
|
|
47
49
|
... def process_message(message):
|
|
48
50
|
... pass
|
|
49
51
|
|
|
50
|
-
Applying the decorator with a custom trace name:
|
|
51
|
-
|
|
52
|
-
>>> @traceable(custom_trace_name="CustomTraceName")
|
|
53
|
-
... def process_message(message):
|
|
54
|
-
... pass
|
|
55
|
-
|
|
56
|
-
In both examples, `process_message` will have entry and exit timestamps added to the
|
|
57
|
-
IngestControlMessage's metadata if 'config::add_trace_tagging' is True.
|
|
52
|
+
Applying the decorator with a custom trace name on a class method:
|
|
58
53
|
|
|
54
|
+
>>> class Processor:
|
|
55
|
+
... @traceable(trace_name="CustomTrace")
|
|
56
|
+
... def process(self, message):
|
|
57
|
+
... pass
|
|
59
58
|
"""
|
|
60
59
|
|
|
61
60
|
def decorator_trace_tagging(func):
|
|
62
61
|
@functools.wraps(func)
|
|
63
62
|
def wrapper_trace_tagging(*args, **kwargs):
|
|
64
|
-
# Assuming the first argument is always the message
|
|
65
63
|
ts_fetched = datetime.now()
|
|
66
|
-
|
|
64
|
+
# Determine which argument is the message.
|
|
65
|
+
if hasattr(args[0], "has_metadata"):
|
|
66
|
+
message = args[0]
|
|
67
|
+
elif len(args) > 1 and hasattr(args[1], "has_metadata"):
|
|
68
|
+
message = args[1]
|
|
69
|
+
else:
|
|
70
|
+
raise ValueError("traceable decorator could not find a message argument with 'has_metadata()'")
|
|
67
71
|
|
|
68
72
|
do_trace_tagging = (message.has_metadata("config::add_trace_tagging") is True) and (
|
|
69
73
|
message.get_metadata("config::add_trace_tagging") is True
|
|
@@ -79,7 +83,7 @@ def traceable(trace_name=None):
|
|
|
79
83
|
message.set_timestamp(f"trace::entry::{trace_prefix}_channel_in", ts_send)
|
|
80
84
|
message.set_timestamp(f"trace::exit::{trace_prefix}_channel_in", ts_fetched)
|
|
81
85
|
|
|
82
|
-
# Call the decorated function
|
|
86
|
+
# Call the decorated function.
|
|
83
87
|
result = func(*args, **kwargs)
|
|
84
88
|
|
|
85
89
|
if do_trace_tagging:
|
|
@@ -129,7 +129,7 @@ class ChartExtractorSchema(BaseModel):
|
|
|
129
129
|
@field_validator("max_queue_size", "n_workers")
|
|
130
130
|
def check_positive(cls, v, field):
|
|
131
131
|
if v <= 0:
|
|
132
|
-
raise ValueError(f"{field.field_name} must be greater than
|
|
132
|
+
raise ValueError(f"{field.field_name} must be greater than 0.")
|
|
133
133
|
return v
|
|
134
134
|
|
|
135
135
|
model_config = ConfigDict(extra="forbid")
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
from pydantic import ConfigDict, BaseModel
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class HtmlExtractorSchema(BaseModel):
|
|
14
|
+
"""
|
|
15
|
+
Configuration schema for the Html extractor settings.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
max_queue_size : int, default=1
|
|
20
|
+
The maximum number of items allowed in the processing queue.
|
|
21
|
+
|
|
22
|
+
n_workers : int, default=16
|
|
23
|
+
The number of worker threads to use for processing.
|
|
24
|
+
|
|
25
|
+
raise_on_failure : bool, default=False
|
|
26
|
+
A flag indicating whether to raise an exception on processing failure.
|
|
27
|
+
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
max_queue_size: int = 1
|
|
31
|
+
n_workers: int = 16
|
|
32
|
+
raise_on_failure: bool = False
|
|
33
|
+
|
|
34
|
+
model_config = ConfigDict(extra="forbid")
|
|
@@ -122,7 +122,7 @@ class InfographicExtractorSchema(BaseModel):
|
|
|
122
122
|
@field_validator("max_queue_size", "n_workers")
|
|
123
123
|
def check_positive(cls, v, field):
|
|
124
124
|
if v <= 0:
|
|
125
|
-
raise ValueError(f"{field.field_name} must be greater than
|
|
125
|
+
raise ValueError(f"{field.field_name} must be greater than 0.")
|
|
126
126
|
return v
|
|
127
127
|
|
|
128
128
|
model_config = ConfigDict(extra="forbid")
|
|
@@ -131,7 +131,7 @@ class NemoRetrieverParseConfigSchema(BaseModel):
|
|
|
131
131
|
nemoretriever_parse_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
|
|
132
132
|
nemoretriever_parse_infer_protocol: str = ""
|
|
133
133
|
|
|
134
|
-
|
|
134
|
+
nemoretriever_parse_model_name: str = "nvidia/nemoretriever-parse"
|
|
135
135
|
|
|
136
136
|
timeout: float = 300.0
|
|
137
137
|
|
|
@@ -122,7 +122,7 @@ class TableExtractorSchema(BaseModel):
|
|
|
122
122
|
@field_validator("max_queue_size", "n_workers")
|
|
123
123
|
def check_positive(cls, v, field):
|
|
124
124
|
if v <= 0:
|
|
125
|
-
raise ValueError(f"{field.field_name} must be greater than
|
|
125
|
+
raise ValueError(f"{field.field_name} must be greater than 0.")
|
|
126
126
|
return v
|
|
127
127
|
|
|
128
128
|
endpoint_config: Optional[TableExtractorConfigSchema] = None
|
|
@@ -2,22 +2,36 @@
|
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
from typing import Optional, Literal, Annotated
|
|
5
7
|
|
|
6
|
-
from typing import Optional, Literal
|
|
7
8
|
|
|
8
|
-
|
|
9
|
-
|
|
9
|
+
class MessageBrokerClientSchema(BaseModel):
|
|
10
|
+
"""
|
|
11
|
+
Configuration schema for message broker client connections.
|
|
12
|
+
Supports Redis or simple in-memory clients.
|
|
13
|
+
"""
|
|
10
14
|
|
|
15
|
+
host: str = Field(default="redis", description="Hostname of the broker service.")
|
|
11
16
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
17
|
+
port: Annotated[int, Field(gt=0, lt=65536)] = Field(
|
|
18
|
+
default=6379, description="Port to connect to. Must be between 1 and 65535."
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
client_type: Literal["redis", "simple"] = Field(
|
|
22
|
+
default="redis", description="Type of broker client. Supported values: 'redis', 'simple'."
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
broker_params: Optional[dict] = Field(
|
|
26
|
+
default_factory=dict, description="Optional parameters passed to the broker client."
|
|
27
|
+
)
|
|
15
28
|
|
|
16
|
-
|
|
17
|
-
|
|
29
|
+
connection_timeout: Annotated[int, Field(ge=0)] = Field(
|
|
30
|
+
default=300, description="Connection timeout in seconds. Must be >= 0."
|
|
31
|
+
)
|
|
18
32
|
|
|
19
|
-
|
|
33
|
+
max_backoff: Annotated[int, Field(ge=0)] = Field(
|
|
34
|
+
default=300, description="Maximum backoff time in seconds. Must be >= 0."
|
|
35
|
+
)
|
|
20
36
|
|
|
21
|
-
|
|
22
|
-
max_backoff: Optional[Annotated[int, Field(ge=0)]] = 300
|
|
23
|
-
max_retries: Optional[Annotated[int, Field(ge=0)]] = 0
|
|
37
|
+
max_retries: Annotated[int, Field(ge=0)] = Field(default=0, description="Maximum number of retries. Must be >= 0.")
|
|
@@ -160,29 +160,40 @@ class IngestTaskSchema(BaseModelNoExt):
|
|
|
160
160
|
@model_validator(mode="before")
|
|
161
161
|
@classmethod
|
|
162
162
|
def check_task_properties_type(cls, values):
|
|
163
|
-
task_type
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
TaskTypeEnum
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
163
|
+
task_type = values.get("type")
|
|
164
|
+
task_properties = values.get("task_properties", {})
|
|
165
|
+
|
|
166
|
+
# Ensure task_type is lowercased and converted to enum early
|
|
167
|
+
if isinstance(task_type, str):
|
|
168
|
+
task_type = task_type.lower()
|
|
169
|
+
try:
|
|
170
|
+
task_type = TaskTypeEnum(task_type)
|
|
171
|
+
except ValueError:
|
|
172
|
+
raise ValueError(f"{task_type} is not a valid TaskTypeEnum value")
|
|
173
|
+
|
|
174
|
+
task_type_to_schema = {
|
|
175
|
+
TaskTypeEnum.CAPTION: IngestTaskCaptionSchema,
|
|
176
|
+
TaskTypeEnum.DEDUP: IngestTaskDedupSchema,
|
|
177
|
+
TaskTypeEnum.EMBED: IngestTaskEmbedSchema,
|
|
178
|
+
TaskTypeEnum.EXTRACT: IngestTaskExtractSchema,
|
|
179
|
+
TaskTypeEnum.FILTER: IngestTaskFilterSchema,
|
|
180
|
+
TaskTypeEnum.SPLIT: IngestTaskSplitSchema,
|
|
181
|
+
TaskTypeEnum.STORE_EMBEDDING: IngestTaskStoreEmbedSchema,
|
|
182
|
+
TaskTypeEnum.STORE: IngestTaskStoreSchema,
|
|
183
|
+
TaskTypeEnum.VDB_UPLOAD: IngestTaskVdbUploadSchema,
|
|
184
|
+
TaskTypeEnum.AUDIO_DATA_EXTRACT: IngestTaskAudioExtraction,
|
|
185
|
+
TaskTypeEnum.TABLE_DATA_EXTRACT: IngestTaskTableExtraction,
|
|
186
|
+
TaskTypeEnum.CHART_DATA_EXTRACT: IngestTaskChartExtraction,
|
|
187
|
+
TaskTypeEnum.INFOGRAPHIC_DATA_EXTRACT: IngestTaskInfographicExtraction,
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
expected_schema_cls = task_type_to_schema.get(task_type)
|
|
191
|
+
if expected_schema_cls is None:
|
|
192
|
+
raise ValueError(f"Unsupported or missing task_type '{task_type}'")
|
|
193
|
+
|
|
194
|
+
validated_task_properties = expected_schema_cls(**task_properties)
|
|
195
|
+
values["type"] = task_type # ensure type is now always the enum
|
|
196
|
+
values["task_properties"] = validated_task_properties
|
|
186
197
|
return values
|
|
187
198
|
|
|
188
199
|
@field_validator("type", mode="before")
|
|
@@ -5,7 +5,7 @@
|
|
|
5
5
|
|
|
6
6
|
import logging
|
|
7
7
|
|
|
8
|
-
from pydantic import ConfigDict, BaseModel
|
|
8
|
+
from pydantic import ConfigDict, BaseModel, Field
|
|
9
9
|
|
|
10
10
|
from nv_ingest_api.util.logging.configuration import LogLevel
|
|
11
11
|
|
|
@@ -13,13 +13,14 @@ logger = logging.getLogger(__name__)
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class TextEmbeddingSchema(BaseModel):
|
|
16
|
-
api_key: str = "api_key"
|
|
17
|
-
batch_size: int = 4
|
|
18
|
-
embedding_model: str = "nvidia/nv-embedqa-
|
|
19
|
-
embedding_nim_endpoint: str = "http://embedding:8000/v1"
|
|
20
|
-
encoding_format: str = "float"
|
|
21
|
-
httpx_log_level: LogLevel = LogLevel.WARNING
|
|
22
|
-
input_type: str = "passage"
|
|
23
|
-
raise_on_failure: bool = False
|
|
24
|
-
truncate: str = "END"
|
|
16
|
+
api_key: str = Field(default="api_key")
|
|
17
|
+
batch_size: int = Field(default=4)
|
|
18
|
+
embedding_model: str = Field(default="nvidia/llama-3.2-nv-embedqa-1b-v2")
|
|
19
|
+
embedding_nim_endpoint: str = Field(default="http://embedding:8000/v1")
|
|
20
|
+
encoding_format: str = Field(default="float")
|
|
21
|
+
httpx_log_level: LogLevel = Field(default=LogLevel.WARNING)
|
|
22
|
+
input_type: str = Field(default="passage")
|
|
23
|
+
raise_on_failure: bool = Field(default=False)
|
|
24
|
+
truncate: str = Field(default="END")
|
|
25
|
+
|
|
25
26
|
model_config = ConfigDict(extra="forbid")
|
|
@@ -2,21 +2,23 @@
|
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
4
|
|
|
5
|
-
from pydantic import Field, BaseModel, field_validator
|
|
5
|
+
from pydantic import Field, BaseModel, field_validator, ConfigDict
|
|
6
6
|
|
|
7
7
|
from typing import Optional
|
|
8
8
|
|
|
9
|
-
from typing_extensions import Annotated
|
|
10
|
-
|
|
11
9
|
|
|
12
10
|
class TextSplitterSchema(BaseModel):
|
|
13
11
|
tokenizer: Optional[str] = None
|
|
14
|
-
chunk_size:
|
|
15
|
-
chunk_overlap:
|
|
12
|
+
chunk_size: int = Field(default=1024, gt=0)
|
|
13
|
+
chunk_overlap: int = Field(default=150, ge=0)
|
|
16
14
|
raise_on_failure: bool = False
|
|
17
15
|
|
|
18
16
|
@field_validator("chunk_overlap")
|
|
19
|
-
|
|
20
|
-
|
|
17
|
+
@classmethod
|
|
18
|
+
def check_chunk_overlap(cls, v, values):
|
|
19
|
+
chunk_size = values.data.get("chunk_size")
|
|
20
|
+
if chunk_size is not None and v >= chunk_size:
|
|
21
21
|
raise ValueError("chunk_overlap must be less than chunk_size")
|
|
22
22
|
return v
|
|
23
|
+
|
|
24
|
+
model_config = ConfigDict(extra="forbid")
|
|
@@ -116,6 +116,7 @@ def _upload_images_to_minio(df: pd.DataFrame, params: Dict[str, Any]) -> pd.Data
|
|
|
116
116
|
if "content" not in metadata:
|
|
117
117
|
logger.error("Row %s: missing 'content' in metadata", idx)
|
|
118
118
|
continue
|
|
119
|
+
|
|
119
120
|
if "source_metadata" not in metadata or not isinstance(metadata["source_metadata"], dict):
|
|
120
121
|
logger.error("Row %s: missing or invalid 'source_metadata' in metadata", idx)
|
|
121
122
|
continue
|
|
@@ -230,28 +230,35 @@ def _async_runner(
|
|
|
230
230
|
def _add_embeddings(row, embeddings, info_msgs):
|
|
231
231
|
"""
|
|
232
232
|
Updates a DataFrame row with embedding data and associated error info.
|
|
233
|
+
Ensures the 'embedding' field is always present, even if None.
|
|
233
234
|
|
|
234
235
|
Parameters
|
|
235
236
|
----------
|
|
236
237
|
row : pandas.Series
|
|
237
238
|
A row of the DataFrame.
|
|
238
|
-
embeddings :
|
|
239
|
-
|
|
240
|
-
info_msgs :
|
|
241
|
-
|
|
239
|
+
embeddings : dict
|
|
240
|
+
Dictionary mapping row indices to embeddings.
|
|
241
|
+
info_msgs : dict
|
|
242
|
+
Dictionary mapping row indices to info message dicts.
|
|
242
243
|
|
|
243
244
|
Returns
|
|
244
245
|
-------
|
|
245
246
|
pandas.Series
|
|
246
|
-
The updated row with embedding and
|
|
247
|
+
The updated row with 'embedding', 'info_message_metadata', and
|
|
248
|
+
'_contains_embeddings' appropriately set.
|
|
247
249
|
"""
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
250
|
+
embedding = embeddings.get(row.name, None)
|
|
251
|
+
info_msg = info_msgs.get(row.name, None)
|
|
252
|
+
|
|
253
|
+
# Always set embedding, even if None
|
|
254
|
+
row["metadata"]["embedding"] = embedding
|
|
255
|
+
|
|
256
|
+
if info_msg:
|
|
257
|
+
row["metadata"]["info_message_metadata"] = info_msg
|
|
251
258
|
row["document_type"] = ContentTypeEnum.INFO_MSG
|
|
252
259
|
row["_contains_embeddings"] = False
|
|
253
260
|
else:
|
|
254
|
-
row["_contains_embeddings"] =
|
|
261
|
+
row["_contains_embeddings"] = embedding is not None
|
|
255
262
|
|
|
256
263
|
return row
|
|
257
264
|
|
|
@@ -287,7 +294,7 @@ def _get_pandas_table_content(row):
|
|
|
287
294
|
str
|
|
288
295
|
The table/chart content from the row.
|
|
289
296
|
"""
|
|
290
|
-
return row
|
|
297
|
+
return row.get("table_metadata", {}).get("table_content")
|
|
291
298
|
|
|
292
299
|
|
|
293
300
|
def _get_pandas_image_content(row):
|
|
@@ -304,7 +311,14 @@ def _get_pandas_image_content(row):
|
|
|
304
311
|
str
|
|
305
312
|
The image caption from the row.
|
|
306
313
|
"""
|
|
307
|
-
return row
|
|
314
|
+
return row.get("image_metadata", {}).get("caption")
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def _get_pandas_audio_content(row):
|
|
318
|
+
"""
|
|
319
|
+
A pandas UDF used to select extracted audio transcription to be used to create embeddings.
|
|
320
|
+
"""
|
|
321
|
+
return row.get("audio_metadata", {}).get("audio_transcript")
|
|
308
322
|
|
|
309
323
|
|
|
310
324
|
# ------------------------------------------------------------------------------
|
|
@@ -352,13 +366,6 @@ def _generate_batches(prompts: List[str], batch_size: int = 100) -> List[str]:
|
|
|
352
366
|
return [batch for batch in _batch_generator(prompts, batch_size)]
|
|
353
367
|
|
|
354
368
|
|
|
355
|
-
def _get_pandas_audio_content(row):
|
|
356
|
-
"""
|
|
357
|
-
A pandas UDF used to select extracted audio transcription to be used to create embeddings.
|
|
358
|
-
"""
|
|
359
|
-
return row["audio_metadata"]["audio_transcript"]
|
|
360
|
-
|
|
361
|
-
|
|
362
369
|
# ------------------------------------------------------------------------------
|
|
363
370
|
# DataFrame Concatenation Utility
|
|
364
371
|
# ------------------------------------------------------------------------------
|
|
@@ -408,17 +415,20 @@ def transform_create_text_embeddings_internal(
|
|
|
408
415
|
execution_trace_log: Optional[Dict] = None,
|
|
409
416
|
) -> Tuple[pd.DataFrame, Dict]:
|
|
410
417
|
"""
|
|
411
|
-
Generates text embeddings for supported content types (TEXT, STRUCTURED, IMAGE)
|
|
418
|
+
Generates text embeddings for supported content types (TEXT, STRUCTURED, IMAGE, AUDIO)
|
|
412
419
|
from a pandas DataFrame using asynchronous requests.
|
|
413
420
|
|
|
421
|
+
This function ensures that even if the extracted content is empty or None,
|
|
422
|
+
the embedding field is explicitly created and set to None.
|
|
423
|
+
|
|
414
424
|
Parameters
|
|
415
425
|
----------
|
|
416
426
|
df_transform_ledger : pd.DataFrame
|
|
417
427
|
The DataFrame containing content for embedding extraction.
|
|
418
428
|
task_config : Dict[str, Any]
|
|
419
429
|
Dictionary containing task properties (e.g., filter error flag).
|
|
420
|
-
transform_config :
|
|
421
|
-
Validated configuration for text embedding extraction
|
|
430
|
+
transform_config : TextEmbeddingSchema, optional
|
|
431
|
+
Validated configuration for text embedding extraction.
|
|
422
432
|
execution_trace_log : Optional[Dict], optional
|
|
423
433
|
Optional trace information for debugging or logging (default is None).
|
|
424
434
|
|
|
@@ -429,20 +439,20 @@ def transform_create_text_embeddings_internal(
|
|
|
429
439
|
- The updated DataFrame with embeddings applied.
|
|
430
440
|
- A dictionary with trace information.
|
|
431
441
|
"""
|
|
432
|
-
|
|
442
|
+
api_key = task_config.get("api_key") or transform_config.api_key
|
|
443
|
+
endpoint_url = task_config.get("endpoint_url") or transform_config.embedding_nim_endpoint
|
|
444
|
+
model_name = task_config.get("model_name") or transform_config.embedding_model
|
|
433
445
|
|
|
434
446
|
if execution_trace_log is None:
|
|
435
447
|
execution_trace_log = {}
|
|
436
448
|
logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")
|
|
437
449
|
|
|
438
|
-
# TODO(Devin)
|
|
439
450
|
if df_transform_ledger.empty:
|
|
440
451
|
return df_transform_ledger, {"trace_info": execution_trace_log}
|
|
441
452
|
|
|
442
453
|
embedding_dataframes = []
|
|
443
|
-
content_masks = []
|
|
454
|
+
content_masks = []
|
|
444
455
|
|
|
445
|
-
# Define pandas content extractors for supported content types.
|
|
446
456
|
pandas_content_extractor = {
|
|
447
457
|
ContentTypeEnum.TEXT: _get_pandas_text_content,
|
|
448
458
|
ContentTypeEnum.STRUCTURED: _get_pandas_table_content,
|
|
@@ -451,49 +461,62 @@ def transform_create_text_embeddings_internal(
|
|
|
451
461
|
ContentTypeEnum.VIDEO: lambda x: None, # Not supported yet.
|
|
452
462
|
}
|
|
453
463
|
|
|
454
|
-
logger.debug("Generating text embeddings for supported content types: TEXT, STRUCTURED, IMAGE.")
|
|
455
|
-
|
|
456
464
|
def _content_type_getter(row):
|
|
457
465
|
return row["content_metadata"]["type"]
|
|
458
466
|
|
|
459
|
-
# Process each supported content type.
|
|
460
467
|
for content_type, content_getter in pandas_content_extractor.items():
|
|
461
468
|
if not content_getter:
|
|
462
469
|
logger.debug(f"Skipping unsupported content type: {content_type}")
|
|
463
470
|
continue
|
|
464
471
|
|
|
472
|
+
# Get rows matching the content type
|
|
465
473
|
content_mask = df_transform_ledger["metadata"].apply(_content_type_getter) == content_type.value
|
|
466
474
|
if not content_mask.any():
|
|
467
475
|
continue
|
|
468
476
|
|
|
469
|
-
#
|
|
470
|
-
|
|
471
|
-
non_empty_mask = extracted_content.notna() & (extracted_content.str.strip() != "")
|
|
472
|
-
final_mask = content_mask & non_empty_mask
|
|
473
|
-
if not final_mask.any():
|
|
474
|
-
continue
|
|
477
|
+
# Always include all content_mask rows and prepare them
|
|
478
|
+
df_content = df_transform_ledger.loc[content_mask].copy().reset_index(drop=True)
|
|
475
479
|
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
transform_config.api_key,
|
|
482
|
-
transform_config.embedding_nim_endpoint,
|
|
483
|
-
transform_config.embedding_model,
|
|
484
|
-
transform_config.encoding_format,
|
|
485
|
-
transform_config.input_type,
|
|
486
|
-
transform_config.truncate,
|
|
487
|
-
False,
|
|
480
|
+
# Extract content and normalize empty or non-str to None
|
|
481
|
+
extracted_content = (
|
|
482
|
+
df_content["metadata"]
|
|
483
|
+
.apply(content_getter)
|
|
484
|
+
.apply(lambda x: x.strip() if isinstance(x, str) and x.strip() else None)
|
|
488
485
|
)
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
486
|
+
df_content["_content"] = extracted_content
|
|
487
|
+
|
|
488
|
+
# Prepare batches for only valid (non-None) content
|
|
489
|
+
valid_content_mask = df_content["_content"].notna()
|
|
490
|
+
if valid_content_mask.any():
|
|
491
|
+
filtered_content_batches = _generate_batches(
|
|
492
|
+
df_content.loc[valid_content_mask, "_content"].tolist(), batch_size=transform_config.batch_size
|
|
493
|
+
)
|
|
494
|
+
content_embeddings = _async_runner(
|
|
495
|
+
filtered_content_batches,
|
|
496
|
+
api_key,
|
|
497
|
+
endpoint_url,
|
|
498
|
+
model_name,
|
|
499
|
+
transform_config.encoding_format,
|
|
500
|
+
transform_config.input_type,
|
|
501
|
+
transform_config.truncate,
|
|
502
|
+
False,
|
|
503
|
+
)
|
|
504
|
+
# Build a simple row index -> embedding map
|
|
505
|
+
embeddings_dict = dict(
|
|
506
|
+
zip(df_content.loc[valid_content_mask].index, content_embeddings.get("embeddings", []))
|
|
507
|
+
)
|
|
508
|
+
info_msgs_dict = dict(
|
|
509
|
+
zip(df_content.loc[valid_content_mask].index, content_embeddings.get("info_msgs", []))
|
|
510
|
+
)
|
|
511
|
+
else:
|
|
512
|
+
embeddings_dict = {}
|
|
513
|
+
info_msgs_dict = {}
|
|
514
|
+
|
|
515
|
+
# Apply embeddings or None to all rows
|
|
516
|
+
df_content = df_content.apply(_add_embeddings, embeddings=embeddings_dict, info_msgs=info_msgs_dict, axis=1)
|
|
494
517
|
|
|
495
518
|
embedding_dataframes.append(df_content)
|
|
496
|
-
content_masks.append(
|
|
519
|
+
content_masks.append(content_mask)
|
|
497
520
|
|
|
498
521
|
combined_df = _concatenate_extractions_pandas(df_transform_ledger, embedding_dataframes, content_masks)
|
|
499
522
|
return combined_df, {"trace_info": execution_trace_log}
|
|
@@ -118,9 +118,15 @@ def transform_text_split_and_tokenize_internal(
|
|
|
118
118
|
)
|
|
119
119
|
|
|
120
120
|
# Filter to documents with text content.
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
)
|
|
121
|
+
text_type_condition = df_transform_ledger["document_type"] == ContentTypeEnum.TEXT
|
|
122
|
+
|
|
123
|
+
normalized_meta_df = pd.json_normalize(df_transform_ledger["metadata"], errors="ignore")
|
|
124
|
+
if "source_metadata.source_type" in normalized_meta_df.columns:
|
|
125
|
+
source_type_condition = normalized_meta_df["source_metadata.source_type"].isin(split_source_types)
|
|
126
|
+
else:
|
|
127
|
+
source_type_condition = False
|
|
128
|
+
|
|
129
|
+
bool_index = text_type_condition & source_type_condition
|
|
124
130
|
df_filtered: pd.DataFrame = df_transform_ledger.loc[bool_index]
|
|
125
131
|
|
|
126
132
|
if df_filtered.empty:
|
nv_ingest_api/util/__init__.py
CHANGED