nv-ingest-api 25.4.2__py3-none-any.whl → 25.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nv-ingest-api might be problematic. Click here for more details.

Files changed (46) hide show
  1. nv_ingest_api/internal/extract/docx/docx_extractor.py +3 -3
  2. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +142 -86
  3. nv_ingest_api/internal/extract/html/__init__.py +3 -0
  4. nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
  5. nv_ingest_api/internal/extract/image/chart_extractor.py +3 -3
  6. nv_ingest_api/internal/extract/image/image_extractor.py +5 -5
  7. nv_ingest_api/internal/extract/image/image_helpers/common.py +1 -1
  8. nv_ingest_api/internal/extract/image/infographic_extractor.py +1 -1
  9. nv_ingest_api/internal/extract/image/table_extractor.py +2 -2
  10. nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +2 -2
  11. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +1 -1
  12. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +214 -188
  13. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +6 -9
  14. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +35 -38
  15. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +7 -1
  16. nv_ingest_api/internal/primitives/nim/nim_client.py +17 -9
  17. nv_ingest_api/internal/primitives/tracing/tagging.py +20 -16
  18. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +1 -1
  19. nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
  20. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +1 -1
  21. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +1 -1
  22. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +1 -1
  23. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +26 -12
  24. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +34 -23
  25. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +11 -10
  26. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +9 -7
  27. nv_ingest_api/internal/store/image_upload.py +1 -0
  28. nv_ingest_api/internal/transform/embed_text.py +75 -52
  29. nv_ingest_api/internal/transform/split_text.py +9 -3
  30. nv_ingest_api/util/__init__.py +3 -0
  31. nv_ingest_api/util/exception_handlers/converters.py +1 -1
  32. nv_ingest_api/util/exception_handlers/decorators.py +309 -51
  33. nv_ingest_api/util/image_processing/processing.py +1 -1
  34. nv_ingest_api/util/logging/configuration.py +15 -8
  35. nv_ingest_api/util/pdf/pdfium.py +2 -2
  36. nv_ingest_api/util/schema/__init__.py +3 -0
  37. nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
  38. nv_ingest_api/util/service_clients/redis/redis_client.py +1 -1
  39. nv_ingest_api/util/service_clients/rest/rest_client.py +2 -2
  40. nv_ingest_api/util/system/__init__.py +0 -0
  41. nv_ingest_api/util/system/hardware_info.py +430 -0
  42. {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.1.dist-info}/METADATA +2 -1
  43. {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.1.dist-info}/RECORD +46 -41
  44. {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.1.dist-info}/WHEEL +1 -1
  45. {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.1.dist-info}/licenses/LICENSE +0 -0
  46. {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.1.dist-info}/top_level.txt +0 -0
@@ -5,9 +5,9 @@
5
5
  from typing import Any, Dict, List, Optional, Tuple
6
6
 
7
7
  from nv_ingest_api.internal.primitives.nim import ModelInterface
8
+ import numpy as np
8
9
 
9
10
 
10
- # Assume ModelInterface is defined elsewhere in the project.
11
11
  class EmbeddingModelInterface(ModelInterface):
12
12
  """
13
13
  An interface for handling inference with an embedding model endpoint.
@@ -22,20 +22,13 @@ class EmbeddingModelInterface(ModelInterface):
22
22
 
23
23
  def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
24
24
  """
25
- Prepare input data for embedding inference. Ensures that a 'prompts' key is provided
26
- and that its value is a list.
27
-
28
- Raises
29
- ------
30
- KeyError
31
- If the 'prompts' key is missing.
25
+ Prepare input data for embedding inference. Returns a list of strings representing the text to be embedded.
32
26
  """
33
27
  if "prompts" not in data:
34
28
  raise KeyError("Input data must include 'prompts'.")
35
- # Ensure the prompts are in list format.
36
29
  if not isinstance(data["prompts"], list):
37
30
  data["prompts"] = [data["prompts"]]
38
- return data
31
+ return {"prompts": data["prompts"]}
39
32
 
40
33
  def format_input(
41
34
  self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
@@ -63,29 +56,32 @@ class EmbeddingModelInterface(ModelInterface):
63
56
  - payloads is a list of JSON-serializable payload dictionaries.
64
57
  - batch_data_list is a list of dictionaries containing the key "prompts" corresponding to each batch.
65
58
  """
66
- if protocol != "http":
67
- raise ValueError("EmbeddingModelInterface only supports HTTP protocol.")
68
-
69
- prompts = data.get("prompts", [])
70
59
 
71
60
  def chunk_list(lst, chunk_size):
61
+ lst = lst["prompts"]
72
62
  return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
73
63
 
74
- batches = chunk_list(prompts, max_batch_size)
75
- payloads = []
76
- batch_data_list = []
77
- for batch in batches:
78
- payload = {
79
- "model": kwargs.get("model_name"),
80
- "input": batch,
81
- "encoding_format": kwargs.get("encoding_format", "float"),
82
- "extra_body": {
83
- "input_type": kwargs.get("input_type", "query"),
64
+ batches = chunk_list(data, max_batch_size)
65
+ if protocol == "http":
66
+ payloads = []
67
+ batch_data_list = []
68
+ for batch in batches:
69
+ payload = {
70
+ "model": kwargs.get("model_name"),
71
+ "input": batch,
72
+ "encoding_format": kwargs.get("encoding_format", "float"),
73
+ "input_type": kwargs.get("input_type", "passage"),
84
74
  "truncate": kwargs.get("truncate", "NONE"),
85
- },
86
- }
87
- payloads.append(payload)
88
- batch_data_list.append({"prompts": batch})
75
+ }
76
+ payloads.append(payload)
77
+ batch_data_list.append({"prompts": batch})
78
+ elif protocol == "grpc":
79
+ payloads = []
80
+ batch_data_list = []
81
+ for batch in batches:
82
+ text_np = np.array([[text.encode("utf-8")] for text in batch], dtype=np.object_)
83
+ payloads.append(text_np)
84
+ batch_data_list.append({"prompts": batch})
89
85
  return payloads, batch_data_list
90
86
 
91
87
  def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
@@ -108,16 +104,17 @@ class EmbeddingModelInterface(ModelInterface):
108
104
  list
109
105
  A list of generated embeddings extracted from the response.
110
106
  """
111
- if protocol != "http":
112
- raise ValueError("EmbeddingModelInterface only supports HTTP protocol.")
113
- if isinstance(response, dict):
114
- embeddings = response.get("data")
115
- if not embeddings:
116
- raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")
117
- # Each item in embeddings is expected to have an 'embedding' field.
118
- return [item.get("embedding", None) for item in embeddings]
119
- else:
120
- return [str(response)]
107
+ if protocol == "http":
108
+ if isinstance(response, dict):
109
+ embeddings = response.get("data")
110
+ if not embeddings:
111
+ raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")
112
+ # Each item in embeddings is expected to have an 'embedding' field.
113
+ return [item.get("embedding", None) for item in embeddings]
114
+ else:
115
+ return [str(response)]
116
+ elif protocol == "grpc":
117
+ return [res.flatten() for res in response]
121
118
 
122
119
  def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
123
120
  """
@@ -709,7 +709,13 @@ def postprocess_results(
709
709
  raise ValueError(f"Error in postprocessing {result.shape} and {original_image_shape}: {e}")
710
710
 
711
711
  for box, score, label in zip(bboxes, scores, labels):
712
- class_name = class_labels[int(label)]
712
+ # TODO(Devin): Sometimes we get back unexpected class labels?
713
+ if (label < 0) or (label >= len(class_labels)):
714
+ logger.warning(f"Invalid class label {label} found in postprocessing")
715
+ continue
716
+ else:
717
+ class_name = class_labels[int(label)]
718
+
713
719
  annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))])
714
720
 
715
721
  out.append(annotation_dict)
@@ -129,7 +129,7 @@ class NimClient:
129
129
  """
130
130
  if self.protocol == "grpc":
131
131
  logger.debug("Performing gRPC inference for a batch...")
132
- response = self._grpc_infer(batch_input, model_name)
132
+ response = self._grpc_infer(batch_input, model_name, **kwargs)
133
133
  logger.debug("gRPC inference received response for a batch")
134
134
  elif self.protocol == "http":
135
135
  logger.debug("Performing HTTP inference for a batch...")
@@ -221,7 +221,7 @@ class NimClient:
221
221
 
222
222
  return all_results
223
223
 
224
- def _grpc_infer(self, formatted_input: np.ndarray, model_name: str) -> np.ndarray:
224
+ def _grpc_infer(self, formatted_input: np.ndarray, model_name: str, **kwargs) -> np.ndarray:
225
225
  """
226
226
  Perform inference using the gRPC protocol.
227
227
 
@@ -238,16 +238,24 @@ class NimClient:
238
238
  The output of the model as a numpy array.
239
239
  """
240
240
 
241
- input_tensors = [grpcclient.InferInput("input", formatted_input.shape, datatype="FP32")]
242
- input_tensors[0].set_data_from_numpy(formatted_input)
241
+ parameters = kwargs.get("parameters", {})
242
+ output_names = kwargs.get("outputs", ["output"])
243
+ dtype = kwargs.get("dtype", "FP32")
244
+ input_name = kwargs.get("input_name", "input")
243
245
 
244
- outputs = [grpcclient.InferRequestedOutput("output")]
245
- response = self.client.infer(model_name=model_name, inputs=input_tensors, outputs=outputs)
246
- logger.debug(f"gRPC inference response: {response}")
246
+ input_tensors = grpcclient.InferInput(input_name, formatted_input.shape, datatype=dtype)
247
+ input_tensors.set_data_from_numpy(formatted_input)
247
248
 
248
- # TODO(self.client.has_error(response)) => raise error
249
+ outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
250
+ response = self.client.infer(
251
+ model_name=model_name, parameters=parameters, inputs=[input_tensors], outputs=outputs
252
+ )
253
+ logger.debug(f"gRPC inference response: {response}")
249
254
 
250
- return response.as_numpy("output")
255
+ if len(outputs) == 1:
256
+ return response.as_numpy(outputs[0].name())
257
+ else:
258
+ return [response.as_numpy(output.name()) for output in outputs]
251
259
 
252
260
  def _http_infer(self, formatted_input: dict) -> dict:
253
261
  """
@@ -31,13 +31,15 @@ def traceable(trace_name=None):
31
31
 
32
32
  Notes
33
33
  -----
34
- The decorated function must accept a IngestControlMessage object as its first argument. The
35
- IngestControlMessage object must implement `has_metadata`, `get_metadata`, and `set_metadata`
36
- methods used by the decorator to check for the trace tagging flag and to add trace metadata.
34
+ The decorated function must accept a IngestControlMessage object as one of its arguments.
35
+ For a regular function, this is expected to be the first argument; for a class method,
36
+ this is expected to be the second argument (after 'self'). The IngestControlMessage object
37
+ must implement `has_metadata`, `get_metadata`, and `set_metadata` methods used by the decorator
38
+ to check for the trace tagging flag and to add trace metadata.
37
39
 
38
40
  The trace metadata added by the decorator includes two entries:
39
- - 'trace::entry::<trace_name>': The monotonic timestamp marking the function's entry.
40
- - 'trace::exit::<trace_name>': The monotonic timestamp marking the function's exit.
41
+ - 'trace::entry::<trace_name>': The timestamp marking the function's entry.
42
+ - 'trace::exit::<trace_name>': The timestamp marking the function's exit.
41
43
 
42
44
  Example
43
45
  -------
@@ -47,23 +49,25 @@ def traceable(trace_name=None):
47
49
  ... def process_message(message):
48
50
  ... pass
49
51
 
50
- Applying the decorator with a custom trace name:
51
-
52
- >>> @traceable(custom_trace_name="CustomTraceName")
53
- ... def process_message(message):
54
- ... pass
55
-
56
- In both examples, `process_message` will have entry and exit timestamps added to the
57
- IngestControlMessage's metadata if 'config::add_trace_tagging' is True.
52
+ Applying the decorator with a custom trace name on a class method:
58
53
 
54
+ >>> class Processor:
55
+ ... @traceable(trace_name="CustomTrace")
56
+ ... def process(self, message):
57
+ ... pass
59
58
  """
60
59
 
61
60
  def decorator_trace_tagging(func):
62
61
  @functools.wraps(func)
63
62
  def wrapper_trace_tagging(*args, **kwargs):
64
- # Assuming the first argument is always the message
65
63
  ts_fetched = datetime.now()
66
- message = args[0]
64
+ # Determine which argument is the message.
65
+ if hasattr(args[0], "has_metadata"):
66
+ message = args[0]
67
+ elif len(args) > 1 and hasattr(args[1], "has_metadata"):
68
+ message = args[1]
69
+ else:
70
+ raise ValueError("traceable decorator could not find a message argument with 'has_metadata()'")
67
71
 
68
72
  do_trace_tagging = (message.has_metadata("config::add_trace_tagging") is True) and (
69
73
  message.get_metadata("config::add_trace_tagging") is True
@@ -79,7 +83,7 @@ def traceable(trace_name=None):
79
83
  message.set_timestamp(f"trace::entry::{trace_prefix}_channel_in", ts_send)
80
84
  message.set_timestamp(f"trace::exit::{trace_prefix}_channel_in", ts_fetched)
81
85
 
82
- # Call the decorated function
86
+ # Call the decorated function.
83
87
  result = func(*args, **kwargs)
84
88
 
85
89
  if do_trace_tagging:
@@ -129,7 +129,7 @@ class ChartExtractorSchema(BaseModel):
129
129
  @field_validator("max_queue_size", "n_workers")
130
130
  def check_positive(cls, v, field):
131
131
  if v <= 0:
132
- raise ValueError(f"{field.field_name} must be greater than 10.")
132
+ raise ValueError(f"{field.field_name} must be greater than 0.")
133
133
  return v
134
134
 
135
135
  model_config = ConfigDict(extra="forbid")
@@ -0,0 +1,34 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+
6
+ import logging
7
+
8
+ from pydantic import ConfigDict, BaseModel
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class HtmlExtractorSchema(BaseModel):
14
+ """
15
+ Configuration schema for the Html extractor settings.
16
+
17
+ Parameters
18
+ ----------
19
+ max_queue_size : int, default=1
20
+ The maximum number of items allowed in the processing queue.
21
+
22
+ n_workers : int, default=16
23
+ The number of worker threads to use for processing.
24
+
25
+ raise_on_failure : bool, default=False
26
+ A flag indicating whether to raise an exception on processing failure.
27
+
28
+ """
29
+
30
+ max_queue_size: int = 1
31
+ n_workers: int = 16
32
+ raise_on_failure: bool = False
33
+
34
+ model_config = ConfigDict(extra="forbid")
@@ -122,7 +122,7 @@ class InfographicExtractorSchema(BaseModel):
122
122
  @field_validator("max_queue_size", "n_workers")
123
123
  def check_positive(cls, v, field):
124
124
  if v <= 0:
125
- raise ValueError(f"{field.field_name} must be greater than 10.")
125
+ raise ValueError(f"{field.field_name} must be greater than 0.")
126
126
  return v
127
127
 
128
128
  model_config = ConfigDict(extra="forbid")
@@ -131,7 +131,7 @@ class NemoRetrieverParseConfigSchema(BaseModel):
131
131
  nemoretriever_parse_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
132
132
  nemoretriever_parse_infer_protocol: str = ""
133
133
 
134
- model_name: str = "nvidia/nemoretriever-parse"
134
+ nemoretriever_parse_model_name: str = "nvidia/nemoretriever-parse"
135
135
 
136
136
  timeout: float = 300.0
137
137
 
@@ -122,7 +122,7 @@ class TableExtractorSchema(BaseModel):
122
122
  @field_validator("max_queue_size", "n_workers")
123
123
  def check_positive(cls, v, field):
124
124
  if v <= 0:
125
- raise ValueError(f"{field.field_name} must be greater than 10.")
125
+ raise ValueError(f"{field.field_name} must be greater than 0.")
126
126
  return v
127
127
 
128
128
  endpoint_config: Optional[TableExtractorConfigSchema] = None
@@ -2,22 +2,36 @@
2
2
  # All rights reserved.
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
+ from pydantic import BaseModel, Field
6
+ from typing import Optional, Literal, Annotated
5
7
 
6
- from typing import Optional, Literal
7
8
 
8
- from pydantic import Field, BaseModel
9
- from typing_extensions import Annotated
9
+ class MessageBrokerClientSchema(BaseModel):
10
+ """
11
+ Configuration schema for message broker client connections.
12
+ Supports Redis or simple in-memory clients.
13
+ """
10
14
 
15
+ host: str = Field(default="redis", description="Hostname of the broker service.")
11
16
 
12
- class MessageBrokerClientSchema(BaseModel):
13
- host: str = "redis"
14
- port: Annotated[int, Field(gt=0, lt=65536)] = 6379
17
+ port: Annotated[int, Field(gt=0, lt=65536)] = Field(
18
+ default=6379, description="Port to connect to. Must be between 1 and 65535."
19
+ )
20
+
21
+ client_type: Literal["redis", "simple"] = Field(
22
+ default="redis", description="Type of broker client. Supported values: 'redis', 'simple'."
23
+ )
24
+
25
+ broker_params: Optional[dict] = Field(
26
+ default_factory=dict, description="Optional parameters passed to the broker client."
27
+ )
15
28
 
16
- # Update this for new broker types
17
- client_type: Literal["redis", "simple"] = "redis" # Restrict to 'redis' or 'simple'
29
+ connection_timeout: Annotated[int, Field(ge=0)] = Field(
30
+ default=300, description="Connection timeout in seconds. Must be >= 0."
31
+ )
18
32
 
19
- broker_params: Optional[dict] = Field(default_factory=dict)
33
+ max_backoff: Annotated[int, Field(ge=0)] = Field(
34
+ default=300, description="Maximum backoff time in seconds. Must be >= 0."
35
+ )
20
36
 
21
- connection_timeout: Optional[Annotated[int, Field(ge=0)]] = 300
22
- max_backoff: Optional[Annotated[int, Field(ge=0)]] = 300
23
- max_retries: Optional[Annotated[int, Field(ge=0)]] = 0
37
+ max_retries: Annotated[int, Field(ge=0)] = Field(default=0, description="Maximum number of retries. Must be >= 0.")
@@ -160,29 +160,40 @@ class IngestTaskSchema(BaseModelNoExt):
160
160
  @model_validator(mode="before")
161
161
  @classmethod
162
162
  def check_task_properties_type(cls, values):
163
- task_type, task_properties = values.get("type"), values.get("task_properties", {})
164
- if task_type and task_properties:
165
- expected_type = {
166
- TaskTypeEnum.CAPTION: IngestTaskCaptionSchema,
167
- TaskTypeEnum.DEDUP: IngestTaskDedupSchema,
168
- TaskTypeEnum.EMBED: IngestTaskEmbedSchema,
169
- TaskTypeEnum.EXTRACT: IngestTaskExtractSchema,
170
- TaskTypeEnum.FILTER: IngestTaskFilterSchema, # Extend mapping as necessary
171
- TaskTypeEnum.SPLIT: IngestTaskSplitSchema,
172
- TaskTypeEnum.STORE_EMBEDDING: IngestTaskStoreEmbedSchema,
173
- TaskTypeEnum.STORE: IngestTaskStoreSchema,
174
- TaskTypeEnum.VDB_UPLOAD: IngestTaskVdbUploadSchema,
175
- TaskTypeEnum.AUDIO_DATA_EXTRACT: IngestTaskAudioExtraction,
176
- TaskTypeEnum.TABLE_DATA_EXTRACT: IngestTaskTableExtraction,
177
- TaskTypeEnum.CHART_DATA_EXTRACT: IngestTaskChartExtraction,
178
- TaskTypeEnum.INFOGRAPHIC_DATA_EXTRACT: IngestTaskInfographicExtraction,
179
- }.get(
180
- task_type
181
- ) # Removed .upper()
182
-
183
- # Validate task_properties against the expected schema.
184
- validated_task_properties = expected_type(**task_properties)
185
- values["task_properties"] = validated_task_properties
163
+ task_type = values.get("type")
164
+ task_properties = values.get("task_properties", {})
165
+
166
+ # Ensure task_type is lowercased and converted to enum early
167
+ if isinstance(task_type, str):
168
+ task_type = task_type.lower()
169
+ try:
170
+ task_type = TaskTypeEnum(task_type)
171
+ except ValueError:
172
+ raise ValueError(f"{task_type} is not a valid TaskTypeEnum value")
173
+
174
+ task_type_to_schema = {
175
+ TaskTypeEnum.CAPTION: IngestTaskCaptionSchema,
176
+ TaskTypeEnum.DEDUP: IngestTaskDedupSchema,
177
+ TaskTypeEnum.EMBED: IngestTaskEmbedSchema,
178
+ TaskTypeEnum.EXTRACT: IngestTaskExtractSchema,
179
+ TaskTypeEnum.FILTER: IngestTaskFilterSchema,
180
+ TaskTypeEnum.SPLIT: IngestTaskSplitSchema,
181
+ TaskTypeEnum.STORE_EMBEDDING: IngestTaskStoreEmbedSchema,
182
+ TaskTypeEnum.STORE: IngestTaskStoreSchema,
183
+ TaskTypeEnum.VDB_UPLOAD: IngestTaskVdbUploadSchema,
184
+ TaskTypeEnum.AUDIO_DATA_EXTRACT: IngestTaskAudioExtraction,
185
+ TaskTypeEnum.TABLE_DATA_EXTRACT: IngestTaskTableExtraction,
186
+ TaskTypeEnum.CHART_DATA_EXTRACT: IngestTaskChartExtraction,
187
+ TaskTypeEnum.INFOGRAPHIC_DATA_EXTRACT: IngestTaskInfographicExtraction,
188
+ }
189
+
190
+ expected_schema_cls = task_type_to_schema.get(task_type)
191
+ if expected_schema_cls is None:
192
+ raise ValueError(f"Unsupported or missing task_type '{task_type}'")
193
+
194
+ validated_task_properties = expected_schema_cls(**task_properties)
195
+ values["type"] = task_type # ensure type is now always the enum
196
+ values["task_properties"] = validated_task_properties
186
197
  return values
187
198
 
188
199
  @field_validator("type", mode="before")
@@ -5,7 +5,7 @@
5
5
 
6
6
  import logging
7
7
 
8
- from pydantic import ConfigDict, BaseModel
8
+ from pydantic import ConfigDict, BaseModel, Field
9
9
 
10
10
  from nv_ingest_api.util.logging.configuration import LogLevel
11
11
 
@@ -13,13 +13,14 @@ logger = logging.getLogger(__name__)
13
13
 
14
14
 
15
15
  class TextEmbeddingSchema(BaseModel):
16
- api_key: str = "api_key"
17
- batch_size: int = 4
18
- embedding_model: str = "nvidia/nv-embedqa-e5-v5"
19
- embedding_nim_endpoint: str = "http://embedding:8000/v1"
20
- encoding_format: str = "float"
21
- httpx_log_level: LogLevel = LogLevel.WARNING
22
- input_type: str = "passage"
23
- raise_on_failure: bool = False
24
- truncate: str = "END"
16
+ api_key: str = Field(default="api_key")
17
+ batch_size: int = Field(default=4)
18
+ embedding_model: str = Field(default="nvidia/llama-3.2-nv-embedqa-1b-v2")
19
+ embedding_nim_endpoint: str = Field(default="http://embedding:8000/v1")
20
+ encoding_format: str = Field(default="float")
21
+ httpx_log_level: LogLevel = Field(default=LogLevel.WARNING)
22
+ input_type: str = Field(default="passage")
23
+ raise_on_failure: bool = Field(default=False)
24
+ truncate: str = Field(default="END")
25
+
25
26
  model_config = ConfigDict(extra="forbid")
@@ -2,21 +2,23 @@
2
2
  # All rights reserved.
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
- from pydantic import Field, BaseModel, field_validator
5
+ from pydantic import Field, BaseModel, field_validator, ConfigDict
6
6
 
7
7
  from typing import Optional
8
8
 
9
- from typing_extensions import Annotated
10
-
11
9
 
12
10
  class TextSplitterSchema(BaseModel):
13
11
  tokenizer: Optional[str] = None
14
- chunk_size: Annotated[int, Field(gt=0)] = 1024
15
- chunk_overlap: Annotated[int, Field(ge=0)] = 150
12
+ chunk_size: int = Field(default=1024, gt=0)
13
+ chunk_overlap: int = Field(default=150, ge=0)
16
14
  raise_on_failure: bool = False
17
15
 
18
16
  @field_validator("chunk_overlap")
19
- def check_chunk_overlap(cls, v, values, **kwargs):
20
- if v is not None and "chunk_size" in values.data and v >= values.data["chunk_size"]:
17
+ @classmethod
18
+ def check_chunk_overlap(cls, v, values):
19
+ chunk_size = values.data.get("chunk_size")
20
+ if chunk_size is not None and v >= chunk_size:
21
21
  raise ValueError("chunk_overlap must be less than chunk_size")
22
22
  return v
23
+
24
+ model_config = ConfigDict(extra="forbid")
@@ -116,6 +116,7 @@ def _upload_images_to_minio(df: pd.DataFrame, params: Dict[str, Any]) -> pd.Data
116
116
  if "content" not in metadata:
117
117
  logger.error("Row %s: missing 'content' in metadata", idx)
118
118
  continue
119
+
119
120
  if "source_metadata" not in metadata or not isinstance(metadata["source_metadata"], dict):
120
121
  logger.error("Row %s: missing or invalid 'source_metadata' in metadata", idx)
121
122
  continue