nv-ingest-api: nv_ingest_api-25.4.2-py3-none-any.whl → nv_ingest_api-25.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nv-ingest-api might be problematic.
- nv_ingest_api/internal/extract/docx/docx_extractor.py +3 -3
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +142 -86
- nv_ingest_api/internal/extract/html/__init__.py +3 -0
- nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
- nv_ingest_api/internal/extract/image/chart_extractor.py +3 -3
- nv_ingest_api/internal/extract/image/image_extractor.py +5 -5
- nv_ingest_api/internal/extract/image/image_helpers/common.py +1 -1
- nv_ingest_api/internal/extract/image/infographic_extractor.py +1 -1
- nv_ingest_api/internal/extract/image/table_extractor.py +2 -2
- nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +2 -2
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +1 -1
- nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +214 -188
- nv_ingest_api/internal/extract/pptx/pptx_extractor.py +6 -9
- nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +35 -38
- nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +7 -1
- nv_ingest_api/internal/primitives/nim/nim_client.py +17 -9
- nv_ingest_api/internal/primitives/tracing/tagging.py +20 -16
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +1 -1
- nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +1 -1
- nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +1 -1
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +1 -1
- nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +26 -12
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +34 -23
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +11 -10
- nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +9 -7
- nv_ingest_api/internal/store/image_upload.py +1 -0
- nv_ingest_api/internal/transform/embed_text.py +75 -52
- nv_ingest_api/internal/transform/split_text.py +9 -3
- nv_ingest_api/util/__init__.py +3 -0
- nv_ingest_api/util/exception_handlers/converters.py +1 -1
- nv_ingest_api/util/exception_handlers/decorators.py +309 -51
- nv_ingest_api/util/image_processing/processing.py +1 -1
- nv_ingest_api/util/logging/configuration.py +15 -8
- nv_ingest_api/util/pdf/pdfium.py +2 -2
- nv_ingest_api/util/schema/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/redis_client.py +1 -1
- nv_ingest_api/util/service_clients/rest/rest_client.py +2 -2
- nv_ingest_api/util/system/__init__.py +0 -0
- nv_ingest_api/util/system/hardware_info.py +430 -0
- {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.1.dist-info}/METADATA +2 -1
- {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.1.dist-info}/RECORD +46 -41
- {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.1.dist-info}/WHEEL +1 -1
- {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.1.dist-info}/licenses/LICENSE +0 -0
- {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.1.dist-info}/top_level.txt +0 -0
nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py:

@@ -5,9 +5,9 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 from nv_ingest_api.internal.primitives.nim import ModelInterface
+import numpy as np
 
 
-# Assume ModelInterface is defined elsewhere in the project.
 class EmbeddingModelInterface(ModelInterface):
     """
     An interface for handling inference with an embedding model endpoint.
@@ -22,20 +22,13 @@ class EmbeddingModelInterface(ModelInterface):
 
     def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
-        Prepare input data for embedding inference. …
-        and that its value is a list.
-
-        Raises
-        ------
-        KeyError
-            If the 'prompts' key is missing.
+        Prepare input data for embedding inference. Returns a list of strings representing the text to be embedded.
         """
         if "prompts" not in data:
             raise KeyError("Input data must include 'prompts'.")
-        # Ensure the prompts are in list format.
         if not isinstance(data["prompts"], list):
             data["prompts"] = [data["prompts"]]
-        return data
+        return {"prompts": data["prompts"]}
 
     def format_input(
         self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
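
Note on behavior: the rewritten prepare_data_for_inference no longer echoes the caller's dict back; it wraps a bare string into a list and returns a fresh dict containing only "prompts". A minimal sketch of the new contract (module path taken from the file list above; treat this as illustrative, not canonical usage):

from nv_ingest_api.internal.primitives.nim.model_interface.text_embedding import EmbeddingModelInterface

iface = EmbeddingModelInterface()

# A bare string is normalized to a one-element list; extra keys are dropped from the result.
out = iface.prepare_data_for_inference({"prompts": "hello world", "extra": 1})
assert out == {"prompts": ["hello world"]}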
@@ -63,29 +56,32 @@ class EmbeddingModelInterface(ModelInterface):
           - payloads is a list of JSON-serializable payload dictionaries.
           - batch_data_list is a list of dictionaries containing the key "prompts" corresponding to each batch.
         """
-        if protocol != "http":
-            raise ValueError("EmbeddingModelInterface only supports HTTP protocol.")
-
-        prompts = data.get("prompts", [])
 
         def chunk_list(lst, chunk_size):
+            lst = lst["prompts"]
             return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
 
-        batches = chunk_list(…
-        …
-                    "input_type": kwargs.get("input_type", "…
+        batches = chunk_list(data, max_batch_size)
+        if protocol == "http":
+            payloads = []
+            batch_data_list = []
+            for batch in batches:
+                payload = {
+                    "model": kwargs.get("model_name"),
+                    "input": batch,
+                    "encoding_format": kwargs.get("encoding_format", "float"),
+                    "input_type": kwargs.get("input_type", "passage"),
                     "truncate": kwargs.get("truncate", "NONE"),
-        }
-        …
+                }
+                payloads.append(payload)
+                batch_data_list.append({"prompts": batch})
+        elif protocol == "grpc":
+            payloads = []
+            batch_data_list = []
+            for batch in batches:
+                text_np = np.array([[text.encode("utf-8")] for text in batch], dtype=np.object_)
+                payloads.append(text_np)
+                batch_data_list.append({"prompts": batch})
         return payloads, batch_data_list
 
     def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
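
The new format_input batches the dict itself (chunk_list now pulls out data["prompts"]) and gains a gRPC path that packs each batch into a (batch, 1) numpy object array of UTF-8 bytes. A rough sketch of the shapes produced; the model name is a placeholder:

import numpy as np
from nv_ingest_api.internal.primitives.nim.model_interface.text_embedding import EmbeddingModelInterface

iface = EmbeddingModelInterface()
data = {"prompts": ["a", "b", "c"]}

payloads, batches = iface.format_input(data, protocol="http", max_batch_size=2, model_name="my-embedder")
# payloads[0] == {"model": "my-embedder", "input": ["a", "b"], "encoding_format": "float",
#                 "input_type": "passage", "truncate": "NONE"}
# batches == [{"prompts": ["a", "b"]}, {"prompts": ["c"]}]

payloads, batches = iface.format_input(data, protocol="grpc", max_batch_size=2)
assert payloads[0].dtype == np.object_ and payloads[0].shape == (2, 1)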
@@ -108,16 +104,17 @@ class EmbeddingModelInterface(ModelInterface):
         list
             A list of generated embeddings extracted from the response.
         """
-        if protocol …
-        …
+        if protocol == "http":
+            if isinstance(response, dict):
+                embeddings = response.get("data")
+                if not embeddings:
+                    raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")
+                # Each item in embeddings is expected to have an 'embedding' field.
+                return [item.get("embedding", None) for item in embeddings]
+            else:
+                return [str(response)]
+        elif protocol == "grpc":
+            return [res.flatten() for res in response]
 
     def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
         """
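
parse_output now branches on protocol: HTTP responses are read in the OpenAI-compatible embeddings shape ({"data": [{"embedding": [...]}, ...]}), and gRPC responses are flattened per-batch numpy arrays. A small illustration with a hand-built response (values are hypothetical):

from nv_ingest_api.internal.primitives.nim.model_interface.text_embedding import EmbeddingModelInterface

iface = EmbeddingModelInterface()
response = {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}
vectors = iface.parse_output(response, protocol="http")
# vectors == [[0.1, 0.2], [0.3, 0.4]]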
nv_ingest_api/internal/primitives/nim/model_interface/yolox.py:

@@ -709,7 +709,13 @@ def postprocess_results(
         raise ValueError(f"Error in postprocessing {result.shape} and {original_image_shape}: {e}")
 
     for box, score, label in zip(bboxes, scores, labels):
-        …
+        # TODO(Devin): Sometimes we get back unexpected class labels?
+        if (label < 0) or (label >= len(class_labels)):
+            logger.warning(f"Invalid class label {label} found in postprocessing")
+            continue
+        else:
+            class_name = class_labels[int(label)]
+
         annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))])
 
     out.append(annotation_dict)
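
The added guard skips detections whose class id falls outside class_labels instead of indexing blindly; before this change an out-of-range label would most likely have raised an IndexError (or silently mis-indexed on a negative value). The check in isolation, with placeholder class names:

class_labels = ["table", "chart", "infographic"]

def safe_class_name(label: int):
    # Mirrors the new bounds check: reject anything outside [0, len(class_labels)).
    if label < 0 or label >= len(class_labels):
        return None
    return class_labels[int(label)]

assert safe_class_name(1) == "chart"
assert safe_class_name(7) is None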
nv_ingest_api/internal/primitives/nim/nim_client.py:

@@ -129,7 +129,7 @@ class NimClient:
         """
         if self.protocol == "grpc":
             logger.debug("Performing gRPC inference for a batch...")
-            response = self._grpc_infer(batch_input, model_name)
+            response = self._grpc_infer(batch_input, model_name, **kwargs)
             logger.debug("gRPC inference received response for a batch")
         elif self.protocol == "http":
             logger.debug("Performing HTTP inference for a batch...")
@@ -221,7 +221,7 @@ class NimClient:
 
         return all_results
 
-    def _grpc_infer(self, formatted_input: np.ndarray, model_name: str) -> np.ndarray:
+    def _grpc_infer(self, formatted_input: np.ndarray, model_name: str, **kwargs) -> np.ndarray:
         """
         Perform inference using the gRPC protocol.
 
@@ -238,16 +238,24 @@ class NimClient:
             The output of the model as a numpy array.
         """
 
-        …
+        parameters = kwargs.get("parameters", {})
+        output_names = kwargs.get("outputs", ["output"])
+        dtype = kwargs.get("dtype", "FP32")
+        input_name = kwargs.get("input_name", "input")
 
-        …
-        logger.debug(f"gRPC inference response: {response}")
+        input_tensors = grpcclient.InferInput(input_name, formatted_input.shape, datatype=dtype)
+        input_tensors.set_data_from_numpy(formatted_input)
 
-        …
+        outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
+        response = self.client.infer(
+            model_name=model_name, parameters=parameters, inputs=[input_tensors], outputs=outputs
+        )
+        logger.debug(f"gRPC inference response: {response}")
 
-        …
+        if len(outputs) == 1:
+            return response.as_numpy(outputs[0].name())
+        else:
+            return [response.as_numpy(output.name()) for output in outputs]
 
     def _http_infer(self, formatted_input: dict) -> dict:
         """
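
_grpc_infer now threads **kwargs through to Triton, so the input tensor name, datatype, requested outputs, and request parameters are configurable rather than hard-coded. A standalone sketch of the equivalent tritonclient.grpc calls; the server URL, tensor names, and model name here are assumptions for illustration, not part of the diff:

import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")  # assumed endpoint

batch = np.array([[b"hello"], [b"world"]], dtype=np.object_)
inp = grpcclient.InferInput("text", batch.shape, datatype="BYTES")  # BYTES for object arrays
inp.set_data_from_numpy(batch)
outputs = [grpcclient.InferRequestedOutput("embeddings")]

result = client.infer(model_name="my_embedder", inputs=[inp], outputs=outputs)
embeddings = result.as_numpy("embeddings")  # single requested output, as in the len(outputs) == 1 branch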
nv_ingest_api/internal/primitives/tracing/tagging.py:

@@ -31,13 +31,15 @@ def traceable(trace_name=None):
 
     Notes
     -----
-    The decorated function must accept a IngestControlMessage object as its …
-    …
+    The decorated function must accept a IngestControlMessage object as one of its arguments.
+    For a regular function, this is expected to be the first argument; for a class method,
+    this is expected to be the second argument (after 'self'). The IngestControlMessage object
+    must implement `has_metadata`, `get_metadata`, and `set_metadata` methods used by the decorator
+    to check for the trace tagging flag and to add trace metadata.
 
     The trace metadata added by the decorator includes two entries:
-    - 'trace::entry::<trace_name>': The …
-    - 'trace::exit::<trace_name>': The …
+    - 'trace::entry::<trace_name>': The timestamp marking the function's entry.
+    - 'trace::exit::<trace_name>': The timestamp marking the function's exit.
 
     Example
     -------
@@ -47,23 +49,25 @@ def traceable(trace_name=None):
     ... def process_message(message):
     ...     pass
 
-    Applying the decorator with a custom trace name:
-
-    >>> @traceable(custom_trace_name="CustomTraceName")
-    ... def process_message(message):
-    ...     pass
-
-    In both examples, `process_message` will have entry and exit timestamps added to the
-    IngestControlMessage's metadata if 'config::add_trace_tagging' is True.
+    Applying the decorator with a custom trace name on a class method:
 
+    >>> class Processor:
+    ...     @traceable(trace_name="CustomTrace")
+    ...     def process(self, message):
+    ...         pass
     """
 
     def decorator_trace_tagging(func):
         @functools.wraps(func)
         def wrapper_trace_tagging(*args, **kwargs):
-            # Assuming the first argument is always the message
             ts_fetched = datetime.now()
-
+            # Determine which argument is the message.
+            if hasattr(args[0], "has_metadata"):
+                message = args[0]
+            elif len(args) > 1 and hasattr(args[1], "has_metadata"):
+                message = args[1]
+            else:
+                raise ValueError("traceable decorator could not find a message argument with 'has_metadata()'")
 
             do_trace_tagging = (message.has_metadata("config::add_trace_tagging") is True) and (
                 message.get_metadata("config::add_trace_tagging") is True
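
Because the wrapper now probes args[0] and then args[1] for has_metadata, the same decorator serves free functions and bound methods. A hypothetical test double showing both call shapes; FakeMessage is not a real IngestControlMessage, it just satisfies the duck-typed interface the decorator checks:

from nv_ingest_api.internal.primitives.tracing.tagging import traceable

class FakeMessage:
    def __init__(self):
        self._meta = {}

    def has_metadata(self, key):
        return key in self._meta

    def get_metadata(self, key):
        return self._meta.get(key)

    def set_metadata(self, key, value):
        self._meta[key] = value

@traceable()
def process_message(message):  # message found at args[0]
    return message

class Processor:
    @traceable(trace_name="CustomTrace")
    def process(self, message):  # message found at args[1], after self
        return message

process_message(FakeMessage())
Processor().process(FakeMessage())  # neither call raises the new ValueError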
@@ -79,7 +83,7 @@ def traceable(trace_name=None):
                 message.set_timestamp(f"trace::entry::{trace_prefix}_channel_in", ts_send)
                 message.set_timestamp(f"trace::exit::{trace_prefix}_channel_in", ts_fetched)
 
-            # Call the decorated function
+            # Call the decorated function.
             result = func(*args, **kwargs)
 
             if do_trace_tagging:
nv_ingest_api/internal/schemas/extract/extract_chart_schema.py:

@@ -129,7 +129,7 @@ class ChartExtractorSchema(BaseModel):
     @field_validator("max_queue_size", "n_workers")
     def check_positive(cls, v, field):
         if v <= 0:
-            raise ValueError(f"{field.field_name} must be greater than …
+            raise ValueError(f"{field.field_name} must be greater than 0.")
         return v
 
     model_config = ConfigDict(extra="forbid")
nv_ingest_api/internal/schemas/extract/extract_html_schema.py (new file):

@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+
+from pydantic import ConfigDict, BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class HtmlExtractorSchema(BaseModel):
+    """
+    Configuration schema for the Html extractor settings.
+
+    Parameters
+    ----------
+    max_queue_size : int, default=1
+        The maximum number of items allowed in the processing queue.
+
+    n_workers : int, default=16
+        The number of worker threads to use for processing.
+
+    raise_on_failure : bool, default=False
+        A flag indicating whether to raise an exception on processing failure.
+
+    """
+
+    max_queue_size: int = 1
+    n_workers: int = 16
+    raise_on_failure: bool = False
+
+    model_config = ConfigDict(extra="forbid")
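
A quick validation sketch for the new schema; extra="forbid" means misspelled fields fail loudly:

from nv_ingest_api.internal.schemas.extract.extract_html_schema import HtmlExtractorSchema

cfg = HtmlExtractorSchema(n_workers=8)
assert (cfg.max_queue_size, cfg.n_workers, cfg.raise_on_failure) == (1, 8, False)

try:
    HtmlExtractorSchema(workers=8)  # typo'd field name is rejected
except Exception as exc:
    print(type(exc).__name__)  # ValidationError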
nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py:

@@ -122,7 +122,7 @@ class InfographicExtractorSchema(BaseModel):
     @field_validator("max_queue_size", "n_workers")
     def check_positive(cls, v, field):
         if v <= 0:
-            raise ValueError(f"{field.field_name} must be greater than …
+            raise ValueError(f"{field.field_name} must be greater than 0.")
         return v
 
     model_config = ConfigDict(extra="forbid")
nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py:

@@ -131,7 +131,7 @@ class NemoRetrieverParseConfigSchema(BaseModel):
     nemoretriever_parse_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
     nemoretriever_parse_infer_protocol: str = ""
 
-    …
+    nemoretriever_parse_model_name: str = "nvidia/nemoretriever-parse"
 
     timeout: float = 300.0
 
nv_ingest_api/internal/schemas/extract/extract_table_schema.py:

@@ -122,7 +122,7 @@ class TableExtractorSchema(BaseModel):
     @field_validator("max_queue_size", "n_workers")
     def check_positive(cls, v, field):
         if v <= 0:
-            raise ValueError(f"{field.field_name} must be greater than …
+            raise ValueError(f"{field.field_name} must be greater than 0.")
         return v
 
     endpoint_config: Optional[TableExtractorConfigSchema] = None
nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py:

@@ -2,22 +2,36 @@
 # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+from pydantic import BaseModel, Field
+from typing import Optional, Literal, Annotated
 
-from typing import Optional, Literal
 
-…
+class MessageBrokerClientSchema(BaseModel):
+    """
+    Configuration schema for message broker client connections.
+    Supports Redis or simple in-memory clients.
+    """
 
+    host: str = Field(default="redis", description="Hostname of the broker service.")
 
-    …
+    port: Annotated[int, Field(gt=0, lt=65536)] = Field(
+        default=6379, description="Port to connect to. Must be between 1 and 65535."
+    )
+
+    client_type: Literal["redis", "simple"] = Field(
+        default="redis", description="Type of broker client. Supported values: 'redis', 'simple'."
+    )
+
+    broker_params: Optional[dict] = Field(
+        default_factory=dict, description="Optional parameters passed to the broker client."
+    )
 
-    …
+    connection_timeout: Annotated[int, Field(ge=0)] = Field(
+        default=300, description="Connection timeout in seconds. Must be >= 0."
+    )
 
-    …
+    max_backoff: Annotated[int, Field(ge=0)] = Field(
+        default=300, description="Maximum backoff time in seconds. Must be >= 0."
+    )
 
-    max_backoff: Optional[Annotated[int, Field(ge=0)]] = 300
-    max_retries: Optional[Annotated[int, Field(ge=0)]] = 0
+    max_retries: Annotated[int, Field(ge=0)] = Field(default=0, description="Maximum number of retries. Must be >= 0.")
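
The rewrite moves the field constraints and descriptions into Field(...). A sketch of the resulting validation behavior, with values chosen for illustration:

from nv_ingest_api.internal.schemas.message_brokers.message_broker_client_schema import (
    MessageBrokerClientSchema,
)

ok = MessageBrokerClientSchema(host="localhost", port=6380, client_type="simple")

try:
    MessageBrokerClientSchema(port=70000)  # outside the (0, 65536) bounds
except Exception as exc:
    print(type(exc).__name__)  # ValidationError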
nv_ingest_api/internal/schemas/meta/ingest_job_schema.py:

@@ -160,29 +160,40 @@ class IngestTaskSchema(BaseModelNoExt):
     @model_validator(mode="before")
     @classmethod
     def check_task_properties_type(cls, values):
-        task_type …
-        …
-        TaskTypeEnum …
-        …
+        task_type = values.get("type")
+        task_properties = values.get("task_properties", {})
+
+        # Ensure task_type is lowercased and converted to enum early
+        if isinstance(task_type, str):
+            task_type = task_type.lower()
+            try:
+                task_type = TaskTypeEnum(task_type)
+            except ValueError:
+                raise ValueError(f"{task_type} is not a valid TaskTypeEnum value")
+
+        task_type_to_schema = {
+            TaskTypeEnum.CAPTION: IngestTaskCaptionSchema,
+            TaskTypeEnum.DEDUP: IngestTaskDedupSchema,
+            TaskTypeEnum.EMBED: IngestTaskEmbedSchema,
+            TaskTypeEnum.EXTRACT: IngestTaskExtractSchema,
+            TaskTypeEnum.FILTER: IngestTaskFilterSchema,
+            TaskTypeEnum.SPLIT: IngestTaskSplitSchema,
+            TaskTypeEnum.STORE_EMBEDDING: IngestTaskStoreEmbedSchema,
+            TaskTypeEnum.STORE: IngestTaskStoreSchema,
+            TaskTypeEnum.VDB_UPLOAD: IngestTaskVdbUploadSchema,
+            TaskTypeEnum.AUDIO_DATA_EXTRACT: IngestTaskAudioExtraction,
+            TaskTypeEnum.TABLE_DATA_EXTRACT: IngestTaskTableExtraction,
+            TaskTypeEnum.CHART_DATA_EXTRACT: IngestTaskChartExtraction,
+            TaskTypeEnum.INFOGRAPHIC_DATA_EXTRACT: IngestTaskInfographicExtraction,
+        }
+
+        expected_schema_cls = task_type_to_schema.get(task_type)
+        if expected_schema_cls is None:
+            raise ValueError(f"Unsupported or missing task_type '{task_type}'")
+
+        validated_task_properties = expected_schema_cls(**task_properties)
+        values["type"] = task_type  # ensure type is now always the enum
+        values["task_properties"] = validated_task_properties
         return values
 
     @field_validator("type", mode="before")
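
Net effect: the validator normalizes the task type to TaskTypeEnum before field validation and eagerly validates task_properties against the matching per-task schema. A sketch of the flow; this assumes IngestTaskSplitSchema's fields all have defaults, so adjust the properties dict otherwise:

from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskSchema

task = IngestTaskSchema(type="SPLIT", task_properties={})
# "SPLIT" is lowercased and coerced to TaskTypeEnum.SPLIT, and task_properties
# arrives as a validated IngestTaskSplitSchema instance rather than a raw dict.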
nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py:

@@ -5,7 +5,7 @@
 
 import logging
 
-from pydantic import ConfigDict, BaseModel
+from pydantic import ConfigDict, BaseModel, Field
 
 from nv_ingest_api.util.logging.configuration import LogLevel
 
@@ -13,13 +13,14 @@ logger = logging.getLogger(__name__)
 
 
 class TextEmbeddingSchema(BaseModel):
-    api_key: str = "api_key"
-    batch_size: int = 4
-    embedding_model: str = "nvidia/nv-embedqa-…
-    embedding_nim_endpoint: str = "http://embedding:8000/v1"
-    encoding_format: str = "float"
-    httpx_log_level: LogLevel = LogLevel.WARNING
-    input_type: str = "passage"
-    raise_on_failure: bool = False
-    truncate: str = "END"
+    api_key: str = Field(default="api_key")
+    batch_size: int = Field(default=4)
+    embedding_model: str = Field(default="nvidia/llama-3.2-nv-embedqa-1b-v2")
+    embedding_nim_endpoint: str = Field(default="http://embedding:8000/v1")
+    encoding_format: str = Field(default="float")
+    httpx_log_level: LogLevel = Field(default=LogLevel.WARNING)
+    input_type: str = Field(default="passage")
+    raise_on_failure: bool = Field(default=False)
+    truncate: str = Field(default="END")
+
     model_config = ConfigDict(extra="forbid")
nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py:

@@ -2,21 +2,23 @@
 # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-from pydantic import Field, BaseModel, field_validator
+from pydantic import Field, BaseModel, field_validator, ConfigDict
 
 from typing import Optional
 
-from typing_extensions import Annotated
-
 
 class TextSplitterSchema(BaseModel):
     tokenizer: Optional[str] = None
-    chunk_size: …
-    chunk_overlap: …
+    chunk_size: int = Field(default=1024, gt=0)
+    chunk_overlap: int = Field(default=150, ge=0)
     raise_on_failure: bool = False
 
     @field_validator("chunk_overlap")
-    …
+    @classmethod
+    def check_chunk_overlap(cls, v, values):
+        chunk_size = values.data.get("chunk_size")
+        if chunk_size is not None and v >= chunk_size:
             raise ValueError("chunk_overlap must be less than chunk_size")
         return v
+
+    model_config = ConfigDict(extra="forbid")
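
The splitter schema now declares its numeric bounds on the fields and rejects an overlap that is not strictly smaller than the chunk size. For example:

from nv_ingest_api.internal.schemas.transform.transform_text_splitter_schema import TextSplitterSchema

TextSplitterSchema()  # defaults: chunk_size=1024, chunk_overlap=150

try:
    TextSplitterSchema(chunk_size=512, chunk_overlap=600)
except Exception as exc:
    print(type(exc).__name__)  # ValidationError: chunk_overlap must be less than chunk_size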
nv_ingest_api/internal/store/image_upload.py:

@@ -116,6 +116,7 @@ def _upload_images_to_minio(df: pd.DataFrame, params: Dict[str, Any]) -> pd.DataFrame:
         if "content" not in metadata:
             logger.error("Row %s: missing 'content' in metadata", idx)
             continue
+
         if "source_metadata" not in metadata or not isinstance(metadata["source_metadata"], dict):
             logger.error("Row %s: missing or invalid 'source_metadata' in metadata", idx)
             continue