nv-ingest-api 25.4.2__py3-none-any.whl → 25.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nv-ingest-api might be problematic.
Files changed (46)
  1. nv_ingest_api/internal/extract/docx/docx_extractor.py +3 -3
  2. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +142 -86
  3. nv_ingest_api/internal/extract/html/__init__.py +3 -0
  4. nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
  5. nv_ingest_api/internal/extract/image/chart_extractor.py +3 -3
  6. nv_ingest_api/internal/extract/image/image_extractor.py +5 -5
  7. nv_ingest_api/internal/extract/image/image_helpers/common.py +1 -1
  8. nv_ingest_api/internal/extract/image/infographic_extractor.py +1 -1
  9. nv_ingest_api/internal/extract/image/table_extractor.py +2 -2
  10. nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +2 -2
  11. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +1 -1
  12. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +213 -187
  13. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +6 -9
  14. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +35 -38
  15. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +7 -1
  16. nv_ingest_api/internal/primitives/nim/nim_client.py +17 -9
  17. nv_ingest_api/internal/primitives/tracing/tagging.py +20 -16
  18. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +1 -1
  19. nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
  20. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +1 -1
  21. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +1 -1
  22. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +1 -1
  23. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +26 -12
  24. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +34 -23
  25. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +11 -10
  26. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +9 -7
  27. nv_ingest_api/internal/store/image_upload.py +1 -0
  28. nv_ingest_api/internal/transform/embed_text.py +75 -52
  29. nv_ingest_api/internal/transform/split_text.py +9 -3
  30. nv_ingest_api/util/__init__.py +3 -0
  31. nv_ingest_api/util/exception_handlers/converters.py +1 -1
  32. nv_ingest_api/util/exception_handlers/decorators.py +309 -51
  33. nv_ingest_api/util/image_processing/processing.py +1 -1
  34. nv_ingest_api/util/logging/configuration.py +15 -8
  35. nv_ingest_api/util/pdf/pdfium.py +2 -2
  36. nv_ingest_api/util/schema/__init__.py +3 -0
  37. nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
  38. nv_ingest_api/util/service_clients/redis/redis_client.py +1 -1
  39. nv_ingest_api/util/service_clients/rest/rest_client.py +2 -2
  40. nv_ingest_api/util/system/__init__.py +0 -0
  41. nv_ingest_api/util/system/hardware_info.py +430 -0
  42. {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.0.dist-info}/METADATA +2 -1
  43. {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.0.dist-info}/RECORD +46 -41
  44. {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.0.dist-info}/WHEEL +1 -1
  45. {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.0.dist-info}/licenses/LICENSE +0 -0
  46. {nv_ingest_api-25.4.2.dist-info → nv_ingest_api-25.6.0.dist-info}/top_level.txt +0 -0
@@ -709,7 +709,13 @@ def postprocess_results(
         raise ValueError(f"Error in postprocessing {result.shape} and {original_image_shape}: {e}")

     for box, score, label in zip(bboxes, scores, labels):
-        class_name = class_labels[int(label)]
+        # TODO(Devin): Sometimes we get back unexpected class labels?
+        if (label < 0) or (label >= len(class_labels)):
+            logger.warning(f"Invalid class label {label} found in postprocessing")
+            continue
+        else:
+            class_name = class_labels[int(label)]
+
         annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))])

     out.append(annotation_dict)
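The guard is easy to exercise in isolation. A minimal sketch with made-up class labels and detection triples standing in for YOLOX output; the out-of-range label is skipped with a warning instead of raising an IndexError:

```python
import logging

import numpy as np

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class_labels = ["table", "chart", "title"]
detections = [
    (np.array([0.1, 0.2, 0.5, 0.6]), 0.91, 1),  # valid -> recorded under "chart"
    (np.array([0.3, 0.3, 0.7, 0.8]), 0.42, 7),  # out of range -> warned and skipped
]

annotation_dict = {name: [] for name in class_labels}
for box, score, label in detections:
    if (label < 0) or (label >= len(class_labels)):
        logger.warning(f"Invalid class label {label} found in postprocessing")
        continue
    class_name = class_labels[int(label)]
    annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))])

print(annotation_dict["chart"])  # [[0.1, 0.2, 0.5, 0.6, 0.91]]
```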
@@ -129,7 +129,7 @@ class NimClient:
         """
         if self.protocol == "grpc":
             logger.debug("Performing gRPC inference for a batch...")
-            response = self._grpc_infer(batch_input, model_name)
+            response = self._grpc_infer(batch_input, model_name, **kwargs)
             logger.debug("gRPC inference received response for a batch")
         elif self.protocol == "http":
             logger.debug("Performing HTTP inference for a batch...")
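Together with the `_grpc_infer` changes in the next two hunks, the forwarded `**kwargs` let a caller pick the input tensor name, dtype, requested outputs, and extra request parameters per call. A sketch, not runnable standalone: `client` is assumed to be an already-constructed `NimClient` in gRPC mode against a live Triton endpoint, and the model and tensor names are illustrative, not taken from the package:

```python
import numpy as np

batch = np.random.rand(2, 3, 1024, 1024).astype(np.float32)

# Legacy behavior is unchanged: one FP32 "input" tensor, one "output" tensor.
result = client._grpc_infer(batch, "yolox")

# New: override tensor names and request several outputs; a list of arrays
# comes back, one per requested output, in order.
boxes, scores = client._grpc_infer(
    batch,
    "yolox_ensemble",  # hypothetical model name
    input_name="images",
    dtype="FP32",
    outputs=["output_boxes", "output_scores"],
    parameters={"threshold": "0.1"},
)
```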
@@ -221,7 +221,7 @@ class NimClient:
 
         return all_results
 
-    def _grpc_infer(self, formatted_input: np.ndarray, model_name: str) -> np.ndarray:
+    def _grpc_infer(self, formatted_input: np.ndarray, model_name: str, **kwargs) -> np.ndarray:
         """
         Perform inference using the gRPC protocol.
 
@@ -238,16 +238,24 @@ class NimClient:
             The output of the model as a numpy array.
         """
 
-        input_tensors = [grpcclient.InferInput("input", formatted_input.shape, datatype="FP32")]
-        input_tensors[0].set_data_from_numpy(formatted_input)
+        parameters = kwargs.get("parameters", {})
+        output_names = kwargs.get("outputs", ["output"])
+        dtype = kwargs.get("dtype", "FP32")
+        input_name = kwargs.get("input_name", "input")
 
-        outputs = [grpcclient.InferRequestedOutput("output")]
-        response = self.client.infer(model_name=model_name, inputs=input_tensors, outputs=outputs)
-        logger.debug(f"gRPC inference response: {response}")
+        input_tensors = grpcclient.InferInput(input_name, formatted_input.shape, datatype=dtype)
+        input_tensors.set_data_from_numpy(formatted_input)
 
-        # TODO(self.client.has_error(response)) => raise error
+        outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
+        response = self.client.infer(
+            model_name=model_name, parameters=parameters, inputs=[input_tensors], outputs=outputs
+        )
+        logger.debug(f"gRPC inference response: {response}")
 
-        return response.as_numpy("output")
+        if len(outputs) == 1:
+            return response.as_numpy(outputs[0].name())
+        else:
+            return [response.as_numpy(output.name()) for output in outputs]
 
     def _http_infer(self, formatted_input: dict) -> dict:
         """
@@ -31,13 +31,15 @@ def traceable(trace_name=None):
 
     Notes
     -----
-    The decorated function must accept a IngestControlMessage object as its first argument. The
-    IngestControlMessage object must implement `has_metadata`, `get_metadata`, and `set_metadata`
-    methods used by the decorator to check for the trace tagging flag and to add trace metadata.
+    The decorated function must accept an IngestControlMessage object as one of its arguments.
+    For a regular function, this is expected to be the first argument; for a class method,
+    this is expected to be the second argument (after 'self'). The IngestControlMessage object
+    must implement `has_metadata`, `get_metadata`, and `set_metadata` methods used by the decorator
+    to check for the trace tagging flag and to add trace metadata.
 
     The trace metadata added by the decorator includes two entries:
-    - 'trace::entry::<trace_name>': The monotonic timestamp marking the function's entry.
-    - 'trace::exit::<trace_name>': The monotonic timestamp marking the function's exit.
+    - 'trace::entry::<trace_name>': The timestamp marking the function's entry.
+    - 'trace::exit::<trace_name>': The timestamp marking the function's exit.
 
     Example
     -------
@@ -47,23 +49,25 @@ def traceable(trace_name=None):
     ... def process_message(message):
     ...     pass
 
-    Applying the decorator with a custom trace name:
-
-    >>> @traceable(custom_trace_name="CustomTraceName")
-    ... def process_message(message):
-    ...     pass
-
-    In both examples, `process_message` will have entry and exit timestamps added to the
-    IngestControlMessage's metadata if 'config::add_trace_tagging' is True.
+    Applying the decorator with a custom trace name on a class method:
 
+    >>> class Processor:
+    ...     @traceable(trace_name="CustomTrace")
+    ...     def process(self, message):
+    ...         pass
     """
 
     def decorator_trace_tagging(func):
         @functools.wraps(func)
         def wrapper_trace_tagging(*args, **kwargs):
-            # Assuming the first argument is always the message
             ts_fetched = datetime.now()
-            message = args[0]
+            # Determine which argument is the message.
+            if hasattr(args[0], "has_metadata"):
+                message = args[0]
+            elif len(args) > 1 and hasattr(args[1], "has_metadata"):
+                message = args[1]
+            else:
+                raise ValueError("traceable decorator could not find a message argument with 'has_metadata()'")
 
             do_trace_tagging = (message.has_metadata("config::add_trace_tagging") is True) and (
                 message.get_metadata("config::add_trace_tagging") is True
@@ -79,7 +83,7 @@ def traceable(trace_name=None):
                 message.set_timestamp(f"trace::entry::{trace_prefix}_channel_in", ts_send)
                 message.set_timestamp(f"trace::exit::{trace_prefix}_channel_in", ts_fetched)
 
-            # Call the decorated function
+            # Call the decorated function.
             result = func(*args, **kwargs)
 
             if do_trace_tagging:
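The new lookup can be exercised end to end with a stub message. A minimal sketch, assuming `traceable` is importable from the tagging module above; the hypothetical stub implements only the metadata calls the disabled-tagging path appears to touch in this diff:

```python
from nv_ingest_api.internal.primitives.tracing.tagging import traceable


class StubMessage:
    """Hypothetical stand-in for IngestControlMessage (tagging disabled)."""

    def has_metadata(self, key):
        return False  # 'config::add_trace_tagging' unset -> no timestamps written

    def get_metadata(self, key):
        return None


@traceable(trace_name="CustomTrace")
def process_message(message):
    return "free function ok"


class Processor:
    @traceable(trace_name="CustomTrace")
    def process(self, message):
        return "method ok"


msg = StubMessage()
print(process_message(msg))      # message found at args[0]
print(Processor().process(msg))  # message found at args[1], after self
```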
@@ -129,7 +129,7 @@ class ChartExtractorSchema(BaseModel):
     @field_validator("max_queue_size", "n_workers")
     def check_positive(cls, v, field):
         if v <= 0:
-            raise ValueError(f"{field.field_name} must be greater than 10.")
+            raise ValueError(f"{field.field_name} must be greater than 0.")
         return v
 
     model_config = ConfigDict(extra="forbid")
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+
+from pydantic import ConfigDict, BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class HtmlExtractorSchema(BaseModel):
+    """
+    Configuration schema for the Html extractor settings.
+
+    Parameters
+    ----------
+    max_queue_size : int, default=1
+        The maximum number of items allowed in the processing queue.
+
+    n_workers : int, default=16
+        The number of worker threads to use for processing.
+
+    raise_on_failure : bool, default=False
+        A flag indicating whether to raise an exception on processing failure.
+
+    """
+
+    max_queue_size: int = 1
+    n_workers: int = 16
+    raise_on_failure: bool = False
+
+    model_config = ConfigDict(extra="forbid")
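Since the whole new file is shown above, its defaults and the extra="forbid" guard are easy to spot-check. A sketch assuming the installed package and the module path from the file list:

```python
from pydantic import ValidationError

from nv_ingest_api.internal.schemas.extract.extract_html_schema import HtmlExtractorSchema

config = HtmlExtractorSchema()
print(config.max_queue_size, config.n_workers, config.raise_on_failure)  # 1 16 False

try:
    HtmlExtractorSchema(n_workers=8, unknown_option=True)
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # extra_forbidden
```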
@@ -122,7 +122,7 @@ class InfographicExtractorSchema(BaseModel):
     @field_validator("max_queue_size", "n_workers")
     def check_positive(cls, v, field):
         if v <= 0:
-            raise ValueError(f"{field.field_name} must be greater than 10.")
+            raise ValueError(f"{field.field_name} must be greater than 0.")
         return v
 
     model_config = ConfigDict(extra="forbid")
@@ -131,7 +131,7 @@ class NemoRetrieverParseConfigSchema(BaseModel):
     nemoretriever_parse_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
     nemoretriever_parse_infer_protocol: str = ""
 
-    model_name: str = "nvidia/nemoretriever-parse"
+    nemoretriever_parse_model_name: str = "nvidia/nemoretriever-parse"
 
     timeout: float = 300.0
 
@@ -122,7 +122,7 @@ class TableExtractorSchema(BaseModel):
     @field_validator("max_queue_size", "n_workers")
     def check_positive(cls, v, field):
         if v <= 0:
-            raise ValueError(f"{field.field_name} must be greater than 10.")
+            raise ValueError(f"{field.field_name} must be greater than 0.")
         return v
 
     endpoint_config: Optional[TableExtractorConfigSchema] = None
@@ -2,22 +2,36 @@
 # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+from pydantic import BaseModel, Field
+from typing import Optional, Literal, Annotated
 
-from typing import Optional, Literal
 
-from pydantic import Field, BaseModel
-from typing_extensions import Annotated
+class MessageBrokerClientSchema(BaseModel):
+    """
+    Configuration schema for message broker client connections.
+    Supports Redis or simple in-memory clients.
+    """
 
+    host: str = Field(default="redis", description="Hostname of the broker service.")
 
-class MessageBrokerClientSchema(BaseModel):
-    host: str = "redis"
-    port: Annotated[int, Field(gt=0, lt=65536)] = 6379
+    port: Annotated[int, Field(gt=0, lt=65536)] = Field(
+        default=6379, description="Port to connect to. Must be between 1 and 65535."
+    )
+
+    client_type: Literal["redis", "simple"] = Field(
+        default="redis", description="Type of broker client. Supported values: 'redis', 'simple'."
+    )
+
+    broker_params: Optional[dict] = Field(
+        default_factory=dict, description="Optional parameters passed to the broker client."
+    )
 
-    # Update this for new broker types
-    client_type: Literal["redis", "simple"] = "redis"  # Restrict to 'redis' or 'simple'
+    connection_timeout: Annotated[int, Field(ge=0)] = Field(
+        default=300, description="Connection timeout in seconds. Must be >= 0."
+    )
 
-    broker_params: Optional[dict] = Field(default_factory=dict)
+    max_backoff: Annotated[int, Field(ge=0)] = Field(
+        default=300, description="Maximum backoff time in seconds. Must be >= 0."
+    )
 
-    connection_timeout: Optional[Annotated[int, Field(ge=0)]] = 300
-    max_backoff: Optional[Annotated[int, Field(ge=0)]] = 300
-    max_retries: Optional[Annotated[int, Field(ge=0)]] = 0
+    max_retries: Annotated[int, Field(ge=0)] = Field(default=0, description="Maximum number of retries. Must be >= 0.")
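A quick sanity check of the reworked schema, assuming the installed package and the module path from the file list above:

```python
from pydantic import ValidationError

from nv_ingest_api.internal.schemas.message_brokers.message_broker_client_schema import (
    MessageBrokerClientSchema,
)

cfg = MessageBrokerClientSchema(host="localhost", client_type="simple")
print(cfg.port, cfg.connection_timeout, cfg.max_backoff, cfg.max_retries)  # 6379 300 300 0

try:
    MessageBrokerClientSchema(port=70000)  # outside the 1-65535 range
except ValidationError as exc:
    print(exc.errors()[0]["loc"])  # ('port',)
```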
@@ -160,29 +160,40 @@ class IngestTaskSchema(BaseModelNoExt):
     @model_validator(mode="before")
     @classmethod
     def check_task_properties_type(cls, values):
-        task_type, task_properties = values.get("type"), values.get("task_properties", {})
-        if task_type and task_properties:
-            expected_type = {
-                TaskTypeEnum.CAPTION: IngestTaskCaptionSchema,
-                TaskTypeEnum.DEDUP: IngestTaskDedupSchema,
-                TaskTypeEnum.EMBED: IngestTaskEmbedSchema,
-                TaskTypeEnum.EXTRACT: IngestTaskExtractSchema,
-                TaskTypeEnum.FILTER: IngestTaskFilterSchema,  # Extend mapping as necessary
-                TaskTypeEnum.SPLIT: IngestTaskSplitSchema,
-                TaskTypeEnum.STORE_EMBEDDING: IngestTaskStoreEmbedSchema,
-                TaskTypeEnum.STORE: IngestTaskStoreSchema,
-                TaskTypeEnum.VDB_UPLOAD: IngestTaskVdbUploadSchema,
-                TaskTypeEnum.AUDIO_DATA_EXTRACT: IngestTaskAudioExtraction,
-                TaskTypeEnum.TABLE_DATA_EXTRACT: IngestTaskTableExtraction,
-                TaskTypeEnum.CHART_DATA_EXTRACT: IngestTaskChartExtraction,
-                TaskTypeEnum.INFOGRAPHIC_DATA_EXTRACT: IngestTaskInfographicExtraction,
-            }.get(
-                task_type
-            )  # Removed .upper()
-
-            # Validate task_properties against the expected schema.
-            validated_task_properties = expected_type(**task_properties)
-            values["task_properties"] = validated_task_properties
+        task_type = values.get("type")
+        task_properties = values.get("task_properties", {})
+
+        # Ensure task_type is lowercased and converted to enum early
+        if isinstance(task_type, str):
+            task_type = task_type.lower()
+            try:
+                task_type = TaskTypeEnum(task_type)
+            except ValueError:
+                raise ValueError(f"{task_type} is not a valid TaskTypeEnum value")
+
+        task_type_to_schema = {
+            TaskTypeEnum.CAPTION: IngestTaskCaptionSchema,
+            TaskTypeEnum.DEDUP: IngestTaskDedupSchema,
+            TaskTypeEnum.EMBED: IngestTaskEmbedSchema,
+            TaskTypeEnum.EXTRACT: IngestTaskExtractSchema,
+            TaskTypeEnum.FILTER: IngestTaskFilterSchema,
+            TaskTypeEnum.SPLIT: IngestTaskSplitSchema,
+            TaskTypeEnum.STORE_EMBEDDING: IngestTaskStoreEmbedSchema,
+            TaskTypeEnum.STORE: IngestTaskStoreSchema,
+            TaskTypeEnum.VDB_UPLOAD: IngestTaskVdbUploadSchema,
+            TaskTypeEnum.AUDIO_DATA_EXTRACT: IngestTaskAudioExtraction,
+            TaskTypeEnum.TABLE_DATA_EXTRACT: IngestTaskTableExtraction,
+            TaskTypeEnum.CHART_DATA_EXTRACT: IngestTaskChartExtraction,
+            TaskTypeEnum.INFOGRAPHIC_DATA_EXTRACT: IngestTaskInfographicExtraction,
+        }
+
+        expected_schema_cls = task_type_to_schema.get(task_type)
+        if expected_schema_cls is None:
+            raise ValueError(f"Unsupported or missing task_type '{task_type}'")
+
+        validated_task_properties = expected_schema_cls(**task_properties)
+        values["type"] = task_type  # ensure type is now always the enum
+        values["task_properties"] = validated_task_properties
         return values
 
     @field_validator("type", mode="before")
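The package's individual task schemas aren't shown in this diff, so here is a standalone sketch of the same normalize-then-dispatch pattern with hypothetical stand-ins (`MiniTaskType`, `SplitProps`, `EmbedProps`) in place of `TaskTypeEnum` and the real task schemas:

```python
from enum import Enum
from typing import Union

from pydantic import BaseModel, model_validator


class MiniTaskType(str, Enum):
    SPLIT = "split"
    EMBED = "embed"


class SplitProps(BaseModel):
    chunk_size: int = 1024


class EmbedProps(BaseModel):
    embed_model: str = "example-model"


SCHEMA_MAP = {MiniTaskType.SPLIT: SplitProps, MiniTaskType.EMBED: EmbedProps}


class MiniTask(BaseModel):
    type: MiniTaskType
    task_properties: Union[SplitProps, EmbedProps]

    @model_validator(mode="before")
    @classmethod
    def check_task_properties_type(cls, values):
        task_type = values.get("type")
        if isinstance(task_type, str):
            # Lowercase early so "SPLIT" and "split" dispatch identically.
            try:
                task_type = MiniTaskType(task_type.lower())
            except ValueError:
                raise ValueError(f"{task_type} is not a valid MiniTaskType value")
        schema_cls = SCHEMA_MAP.get(task_type)
        if schema_cls is None:
            raise ValueError(f"Unsupported or missing task_type '{task_type}'")
        values["type"] = task_type
        values["task_properties"] = schema_cls(**values.get("task_properties", {}))
        return values


task = MiniTask(type="SPLIT", task_properties={"chunk_size": 256})
print(task.type, task.task_properties.chunk_size)  # MiniTaskType.SPLIT 256
```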
@@ -5,7 +5,7 @@
 
 import logging
 
-from pydantic import ConfigDict, BaseModel
+from pydantic import ConfigDict, BaseModel, Field
 
 from nv_ingest_api.util.logging.configuration import LogLevel
 
 
@@ -13,13 +13,14 @@ logger = logging.getLogger(__name__)
 
 
 class TextEmbeddingSchema(BaseModel):
-    api_key: str = "api_key"
-    batch_size: int = 4
-    embedding_model: str = "nvidia/nv-embedqa-e5-v5"
-    embedding_nim_endpoint: str = "http://embedding:8000/v1"
-    encoding_format: str = "float"
-    httpx_log_level: LogLevel = LogLevel.WARNING
-    input_type: str = "passage"
-    raise_on_failure: bool = False
-    truncate: str = "END"
+    api_key: str = Field(default="api_key")
+    batch_size: int = Field(default=4)
+    embedding_model: str = Field(default="nvidia/llama-3.2-nv-embedqa-1b-v2")
+    embedding_nim_endpoint: str = Field(default="http://embedding:8000/v1")
+    encoding_format: str = Field(default="float")
+    httpx_log_level: LogLevel = Field(default=LogLevel.WARNING)
+    input_type: str = Field(default="passage")
+    raise_on_failure: bool = Field(default=False)
+    truncate: str = Field(default="END")
+
     model_config = ConfigDict(extra="forbid")
@@ -2,21 +2,23 @@
 # All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-from pydantic import Field, BaseModel, field_validator
+from pydantic import Field, BaseModel, field_validator, ConfigDict
 
 from typing import Optional
 
-from typing_extensions import Annotated
-
 
 class TextSplitterSchema(BaseModel):
     tokenizer: Optional[str] = None
-    chunk_size: Annotated[int, Field(gt=0)] = 1024
-    chunk_overlap: Annotated[int, Field(ge=0)] = 150
+    chunk_size: int = Field(default=1024, gt=0)
+    chunk_overlap: int = Field(default=150, ge=0)
     raise_on_failure: bool = False
 
     @field_validator("chunk_overlap")
-    def check_chunk_overlap(cls, v, values, **kwargs):
-        if v is not None and "chunk_size" in values.data and v >= values.data["chunk_size"]:
+    @classmethod
+    def check_chunk_overlap(cls, v, values):
+        chunk_size = values.data.get("chunk_size")
+        if chunk_size is not None and v >= chunk_size:
             raise ValueError("chunk_overlap must be less than chunk_size")
         return v
+
+    model_config = ConfigDict(extra="forbid")
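A sketch of the stricter splitter schema in action, assuming the installed package and the module path from the file list above; fields validate in definition order, so `chunk_size` is available when `chunk_overlap` is checked:

```python
from pydantic import ValidationError

from nv_ingest_api.internal.schemas.transform.transform_text_splitter_schema import TextSplitterSchema

print(TextSplitterSchema(chunk_size=512, chunk_overlap=64).chunk_overlap)  # 64

try:
    TextSplitterSchema(chunk_size=128, chunk_overlap=256)  # overlap >= size
except ValidationError as exc:
    print("chunk_overlap must be less than chunk_size" in str(exc))  # True

try:
    TextSplitterSchema(tokenizers="some/model")  # typo, now rejected by extra="forbid"
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # extra_forbidden
```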
@@ -116,6 +116,7 @@ def _upload_images_to_minio(df: pd.DataFrame, params: Dict[str, Any]) -> pd.Data
         if "content" not in metadata:
             logger.error("Row %s: missing 'content' in metadata", idx)
             continue
+
         if "source_metadata" not in metadata or not isinstance(metadata["source_metadata"], dict):
             logger.error("Row %s: missing or invalid 'source_metadata' in metadata", idx)
             continue
@@ -230,28 +230,35 @@ def _async_runner(
 def _add_embeddings(row, embeddings, info_msgs):
     """
     Updates a DataFrame row with embedding data and associated error info.
+    Ensures the 'embedding' field is always present, even if None.
 
     Parameters
     ----------
     row : pandas.Series
         A row of the DataFrame.
-    embeddings : list
-        List of embeddings corresponding to DataFrame rows.
-    info_msgs : list
-        List of info message dictionaries corresponding to DataFrame rows.
+    embeddings : dict
+        Dictionary mapping row indices to embeddings.
+    info_msgs : dict
+        Dictionary mapping row indices to info message dicts.
 
     Returns
     -------
     pandas.Series
-        The updated row with embedding and info message metadata added.
+        The updated row with 'embedding', 'info_message_metadata', and
+        '_contains_embeddings' appropriately set.
     """
-    row["metadata"]["embedding"] = embeddings[row.name]
-    if info_msgs[row.name] is not None:
-        row["metadata"]["info_message_metadata"] = info_msgs[row.name]
+    embedding = embeddings.get(row.name, None)
+    info_msg = info_msgs.get(row.name, None)
+
+    # Always set embedding, even if None
+    row["metadata"]["embedding"] = embedding
+
+    if info_msg:
+        row["metadata"]["info_message_metadata"] = info_msg
         row["document_type"] = ContentTypeEnum.INFO_MSG
         row["_contains_embeddings"] = False
     else:
-        row["_contains_embeddings"] = True
+        row["_contains_embeddings"] = embedding is not None
 
     return row
 
@@ -287,7 +294,7 @@ def _get_pandas_table_content(row):
     str
         The table/chart content from the row.
     """
-    return row["table_metadata"]["table_content"]
+    return row.get("table_metadata", {}).get("table_content")
 
 
 def _get_pandas_image_content(row):
@@ -304,7 +311,14 @@ def _get_pandas_image_content(row):
     str
         The image caption from the row.
     """
-    return row["image_metadata"]["caption"]
+    return row.get("image_metadata", {}).get("caption")
+
+
+def _get_pandas_audio_content(row):
+    """
+    A pandas UDF used to select extracted audio transcription to be used to create embeddings.
+    """
+    return row.get("audio_metadata", {}).get("audio_transcript")
 
 
 # ------------------------------------------------------------------------------
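The `.get` chains turn the getters into total functions over partial metadata; rows missing the expected keys now yield None instead of raising. A two-line illustration:

```python
rows = [
    {"table_metadata": {"table_content": "| a | b |"}},
    {},  # no table_metadata at all
]
for row in rows:
    print(row.get("table_metadata", {}).get("table_content"))
# prints '| a | b |' then None, where row["table_metadata"] would raise KeyError
```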
@@ -352,13 +366,6 @@ def _generate_batches(prompts: List[str], batch_size: int = 100) -> List[str]:
     return [batch for batch in _batch_generator(prompts, batch_size)]
 
 
-def _get_pandas_audio_content(row):
-    """
-    A pandas UDF used to select extracted audio transcription to be used to create embeddings.
-    """
-    return row["audio_metadata"]["audio_transcript"]
-
-
 # ------------------------------------------------------------------------------
 # DataFrame Concatenation Utility
 # ------------------------------------------------------------------------------
@@ -408,17 +415,20 @@ def transform_create_text_embeddings_internal(
     execution_trace_log: Optional[Dict] = None,
 ) -> Tuple[pd.DataFrame, Dict]:
     """
-    Generates text embeddings for supported content types (TEXT, STRUCTURED, IMAGE)
+    Generates text embeddings for supported content types (TEXT, STRUCTURED, IMAGE, AUDIO)
     from a pandas DataFrame using asynchronous requests.
 
+    This function ensures that even if the extracted content is empty or None,
+    the embedding field is explicitly created and set to None.
+
     Parameters
     ----------
     df_transform_ledger : pd.DataFrame
         The DataFrame containing content for embedding extraction.
     task_config : Dict[str, Any]
         Dictionary containing task properties (e.g., filter error flag).
-    transform_config : Any
-        Validated configuration for text embedding extraction (EmbedExtractionsSchema).
+    transform_config : TextEmbeddingSchema, optional
+        Validated configuration for text embedding extraction.
     execution_trace_log : Optional[Dict], optional
         Optional trace information for debugging or logging (default is None).
 
@@ -429,20 +439,20 @@ def transform_create_text_embeddings_internal(
         - The updated DataFrame with embeddings applied.
         - A dictionary with trace information.
     """
-    _ = task_config  # Currently unused.
+    api_key = task_config.get("api_key") or transform_config.api_key
+    endpoint_url = task_config.get("endpoint_url") or transform_config.embedding_nim_endpoint
+    model_name = task_config.get("model_name") or transform_config.embedding_model
 
     if execution_trace_log is None:
         execution_trace_log = {}
         logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")
 
-    # TODO(Devin)
     if df_transform_ledger.empty:
         return df_transform_ledger, {"trace_info": execution_trace_log}
 
     embedding_dataframes = []
-    content_masks = []  # List of pandas boolean Series
+    content_masks = []
 
-    # Define pandas content extractors for supported content types.
     pandas_content_extractor = {
         ContentTypeEnum.TEXT: _get_pandas_text_content,
         ContentTypeEnum.STRUCTURED: _get_pandas_table_content,
@@ -451,49 +461,62 @@ def transform_create_text_embeddings_internal(
         ContentTypeEnum.VIDEO: lambda x: None,  # Not supported yet.
     }
 
-    logger.debug("Generating text embeddings for supported content types: TEXT, STRUCTURED, IMAGE.")
-
     def _content_type_getter(row):
         return row["content_metadata"]["type"]
 
-    # Process each supported content type.
     for content_type, content_getter in pandas_content_extractor.items():
         if not content_getter:
             logger.debug(f"Skipping unsupported content type: {content_type}")
             continue
 
+        # Get rows matching the content type
         content_mask = df_transform_ledger["metadata"].apply(_content_type_getter) == content_type.value
         if not content_mask.any():
             continue
 
-        # Extract content from metadata and filter out rows with empty content.
-        extracted_content = df_transform_ledger.loc[content_mask, "metadata"].apply(content_getter)
-        non_empty_mask = extracted_content.notna() & (extracted_content.str.strip() != "")
-        final_mask = content_mask & non_empty_mask
-        if not final_mask.any():
-            continue
+        # Always include all content_mask rows and prepare them
+        df_content = df_transform_ledger.loc[content_mask].copy().reset_index(drop=True)
 
-        df_content = df_transform_ledger.loc[final_mask].copy().reset_index(drop=True)
-        filtered_content = df_content["metadata"].apply(content_getter)
-        filtered_content_batches = _generate_batches(filtered_content.tolist(), batch_size=transform_config.batch_size)
-        content_embeddings = _async_runner(
-            filtered_content_batches,
-            transform_config.api_key,
-            transform_config.embedding_nim_endpoint,
-            transform_config.embedding_model,
-            transform_config.encoding_format,
-            transform_config.input_type,
-            transform_config.truncate,
-            False,
+        # Extract content and normalize empty or non-str to None
+        extracted_content = (
+            df_content["metadata"]
+            .apply(content_getter)
+            .apply(lambda x: x.strip() if isinstance(x, str) and x.strip() else None)
         )
-        # Apply the embeddings (and any error info) to each row.
-        df_content[["metadata", "document_type", "_contains_embeddings"]] = df_content.apply(
-            _add_embeddings, **content_embeddings, axis=1
-        )[["metadata", "document_type", "_contains_embeddings"]]
-        df_content["_content"] = filtered_content
+        df_content["_content"] = extracted_content
+
+        # Prepare batches for only valid (non-None) content
+        valid_content_mask = df_content["_content"].notna()
+        if valid_content_mask.any():
+            filtered_content_batches = _generate_batches(
+                df_content.loc[valid_content_mask, "_content"].tolist(), batch_size=transform_config.batch_size
+            )
+            content_embeddings = _async_runner(
+                filtered_content_batches,
+                api_key,
+                endpoint_url,
+                model_name,
+                transform_config.encoding_format,
+                transform_config.input_type,
+                transform_config.truncate,
+                False,
+            )
+            # Build a simple row index -> embedding map
+            embeddings_dict = dict(
+                zip(df_content.loc[valid_content_mask].index, content_embeddings.get("embeddings", []))
+            )
+            info_msgs_dict = dict(
+                zip(df_content.loc[valid_content_mask].index, content_embeddings.get("info_msgs", []))
+            )
+        else:
+            embeddings_dict = {}
+            info_msgs_dict = {}
+
+        # Apply embeddings or None to all rows
+        df_content = df_content.apply(_add_embeddings, embeddings=embeddings_dict, info_msgs=info_msgs_dict, axis=1)
 
         embedding_dataframes.append(df_content)
-        content_masks.append(final_mask)
+        content_masks.append(content_mask)
 
     combined_df = _concatenate_extractions_pandas(df_transform_ledger, embedding_dataframes, content_masks)
     return combined_df, {"trace_info": execution_trace_log}
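The normalization step at the heart of this rework is easy to check in isolation; blank strings and non-strings collapse to None, so they skip the embedding request but still reach `_add_embeddings`:

```python
def normalize(x):
    return x.strip() if isinstance(x, str) and x.strip() else None

print([normalize(v) for v in ["  hello ", "", "   ", None, 42]])
# ['hello', None, None, None, None]
```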
@@ -118,9 +118,15 @@ def transform_text_split_and_tokenize_internal(
     )
 
     # Filter to documents with text content.
-    bool_index = (df_transform_ledger["document_type"] == ContentTypeEnum.TEXT) & (
-        pd.json_normalize(df_transform_ledger["metadata"])["source_metadata.source_type"].isin(split_source_types)
-    )
+    text_type_condition = df_transform_ledger["document_type"] == ContentTypeEnum.TEXT
+
+    normalized_meta_df = pd.json_normalize(df_transform_ledger["metadata"], errors="ignore")
+    if "source_metadata.source_type" in normalized_meta_df.columns:
+        source_type_condition = normalized_meta_df["source_metadata.source_type"].isin(split_source_types)
+    else:
+        source_type_condition = False
+
+    bool_index = text_type_condition & source_type_condition
     df_filtered: pd.DataFrame = df_transform_ledger.loc[bool_index]
 
     if df_filtered.empty:
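Why the column guard matters: when no row carries `source_metadata.source_type`, `pd.json_normalize` simply omits that column, and the old direct indexing raised a KeyError. A small demonstration with made-up metadata:

```python
import pandas as pd

metadata = pd.Series([{"content": "hello"}, {"content": "world"}])
normalized = pd.json_normalize(metadata)
print("source_metadata.source_type" in normalized.columns)  # False

if "source_metadata.source_type" in normalized.columns:
    cond = normalized["source_metadata.source_type"].isin(["text"])
else:
    cond = False
print(cond)  # False, so bool_index ends up all-False instead of crashing
```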
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0