nv-ingest-api 2025.5.11.dev20250511__py3-none-any.whl → 2025.5.13.dev20250513__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nv-ingest-api might be problematic. Click here for more details.

Files changed (28)
  1. nv_ingest_api/interface/transform.py +1 -1
  2. nv_ingest_api/internal/extract/docx/docx_extractor.py +3 -3
  3. nv_ingest_api/internal/extract/image/image_extractor.py +5 -5
  4. nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +1 -1
  5. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +44 -17
  6. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +1 -1
  7. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +35 -38
  8. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +7 -1
  9. nv_ingest_api/internal/primitives/nim/nim_client.py +17 -9
  10. nv_ingest_api/internal/primitives/tracing/tagging.py +20 -16
  11. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +1 -1
  12. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +2 -2
  13. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +1 -1
  14. nv_ingest_api/internal/transform/caption_image.py +1 -1
  15. nv_ingest_api/internal/transform/embed_text.py +75 -56
  16. nv_ingest_api/util/exception_handlers/converters.py +1 -1
  17. nv_ingest_api/util/exception_handlers/decorators.py +309 -51
  18. nv_ingest_api/util/logging/configuration.py +15 -8
  19. nv_ingest_api/util/pdf/pdfium.py +1 -1
  20. nv_ingest_api/util/service_clients/redis/redis_client.py +1 -1
  21. nv_ingest_api/util/service_clients/rest/rest_client.py +1 -1
  22. nv_ingest_api/util/system/__init__.py +0 -0
  23. nv_ingest_api/util/system/hardware_info.py +426 -0
  24. {nv_ingest_api-2025.5.11.dev20250511.dist-info → nv_ingest_api-2025.5.13.dev20250513.dist-info}/METADATA +1 -1
  25. {nv_ingest_api-2025.5.11.dev20250511.dist-info → nv_ingest_api-2025.5.13.dev20250513.dist-info}/RECORD +28 -26
  26. {nv_ingest_api-2025.5.11.dev20250511.dist-info → nv_ingest_api-2025.5.13.dev20250513.dist-info}/WHEEL +0 -0
  27. {nv_ingest_api-2025.5.11.dev20250511.dist-info → nv_ingest_api-2025.5.13.dev20250513.dist-info}/licenses/LICENSE +0 -0
  28. {nv_ingest_api-2025.5.11.dev20250511.dist-info → nv_ingest_api-2025.5.13.dev20250513.dist-info}/top_level.txt +0 -0
@@ -230,28 +230,35 @@ def _async_runner(
230
230
  def _add_embeddings(row, embeddings, info_msgs):
231
231
  """
232
232
  Updates a DataFrame row with embedding data and associated error info.
233
+ Ensures the 'embedding' field is always present, even if None.
233
234
 
234
235
  Parameters
235
236
  ----------
236
237
  row : pandas.Series
237
238
  A row of the DataFrame.
238
- embeddings : list
239
- List of embeddings corresponding to DataFrame rows.
240
- info_msgs : list
241
- List of info message dictionaries corresponding to DataFrame rows.
239
+ embeddings : dict
240
+ Dictionary mapping row indices to embeddings.
241
+ info_msgs : dict
242
+ Dictionary mapping row indices to info message dicts.
242
243
 
243
244
  Returns
244
245
  -------
245
246
  pandas.Series
246
- The updated row with embedding and info message metadata added.
247
+ The updated row with 'embedding', 'info_message_metadata', and
248
+ '_contains_embeddings' appropriately set.
247
249
  """
248
- row["metadata"]["embedding"] = embeddings[row.name]
249
- if info_msgs[row.name] is not None:
250
- row["metadata"]["info_message_metadata"] = info_msgs[row.name]
250
+ embedding = embeddings.get(row.name, None)
251
+ info_msg = info_msgs.get(row.name, None)
252
+
253
+ # Always set embedding, even if None
254
+ row["metadata"]["embedding"] = embedding
255
+
256
+ if info_msg:
257
+ row["metadata"]["info_message_metadata"] = info_msg
251
258
  row["document_type"] = ContentTypeEnum.INFO_MSG
252
259
  row["_contains_embeddings"] = False
253
260
  else:
254
- row["_contains_embeddings"] = True
261
+ row["_contains_embeddings"] = embedding is not None
255
262
 
256
263
  return row
257
264
 
@@ -287,7 +294,7 @@ def _get_pandas_table_content(row):
287
294
  str
288
295
  The table/chart content from the row.
289
296
  """
290
- return row["table_metadata"]["table_content"]
297
+ return row.get("table_metadata", {}).get("table_content")
291
298
 
292
299
 
293
300
  def _get_pandas_image_content(row):
@@ -304,7 +311,14 @@ def _get_pandas_image_content(row):
304
311
  str
305
312
  The image caption from the row.
306
313
  """
307
- return row["image_metadata"]["caption"]
314
+ return row.get("image_metadata", {}).get("caption")
315
+
316
+
317
+ def _get_pandas_audio_content(row):
318
+ """
319
+ A pandas UDF used to select extracted audio transcription to be used to create embeddings.
320
+ """
321
+ return row.get("audio_metadata", {}).get("audio_transcript")
308
322
 
309
323
 
310
324
  # ------------------------------------------------------------------------------
@@ -352,13 +366,6 @@ def _generate_batches(prompts: List[str], batch_size: int = 100) -> List[str]:
352
366
  return [batch for batch in _batch_generator(prompts, batch_size)]
353
367
 
354
368
 
355
- def _get_pandas_audio_content(row):
356
- """
357
- A pandas UDF used to select extracted audio transcription to be used to create embeddings.
358
- """
359
- return row["audio_metadata"]["audio_transcript"]
360
-
361
-
362
369
  # ------------------------------------------------------------------------------
363
370
  # DataFrame Concatenation Utility
364
371
  # ------------------------------------------------------------------------------
@@ -408,17 +415,20 @@ def transform_create_text_embeddings_internal(
408
415
  execution_trace_log: Optional[Dict] = None,
409
416
  ) -> Tuple[pd.DataFrame, Dict]:
410
417
  """
411
- Generates text embeddings for supported content types (TEXT, STRUCTURED, IMAGE)
418
+ Generates text embeddings for supported content types (TEXT, STRUCTURED, IMAGE, AUDIO)
412
419
  from a pandas DataFrame using asynchronous requests.
413
420
 
421
+ This function ensures that even if the extracted content is empty or None,
422
+ the embedding field is explicitly created and set to None.
423
+
414
424
  Parameters
415
425
  ----------
416
426
  df_transform_ledger : pd.DataFrame
417
427
  The DataFrame containing content for embedding extraction.
418
428
  task_config : Dict[str, Any]
419
429
  Dictionary containing task properties (e.g., filter error flag).
420
- transform_config : Any
421
- Validated configuration for text embedding extraction (EmbedExtractionsSchema).
430
+ transform_config : TextEmbeddingSchema, optional
431
+ Validated configuration for text embedding extraction.
422
432
  execution_trace_log : Optional[Dict], optional
423
433
  Optional trace information for debugging or logging (default is None).
424
434
 
@@ -429,24 +439,20 @@ def transform_create_text_embeddings_internal(
429
439
  - The updated DataFrame with embeddings applied.
430
440
  - A dictionary with trace information.
431
441
  """
432
-
433
- # Retrieve configuration values with fallback to transform_config defaults.
434
- api_key: str = task_config.get("api_key") or transform_config.api_key
435
- endpoint_url: str = task_config.get("endpoint_url") or transform_config.embedding_nim_endpoint
436
- model_name: str = task_config.get("model_name") or transform_config.embedding_model
442
+ api_key = task_config.get("api_key") or transform_config.api_key
443
+ endpoint_url = task_config.get("endpoint_url") or transform_config.embedding_nim_endpoint
444
+ model_name = task_config.get("model_name") or transform_config.embedding_model
437
445
 
438
446
  if execution_trace_log is None:
439
447
  execution_trace_log = {}
440
448
  logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")
441
449
 
442
- # TODO(Devin)
443
450
  if df_transform_ledger.empty:
444
451
  return df_transform_ledger, {"trace_info": execution_trace_log}
445
452
 
446
453
  embedding_dataframes = []
447
- content_masks = [] # List of pandas boolean Series
454
+ content_masks = []
448
455
 
449
- # Define pandas content extractors for supported content types.
450
456
  pandas_content_extractor = {
451
457
  ContentTypeEnum.TEXT: _get_pandas_text_content,
452
458
  ContentTypeEnum.STRUCTURED: _get_pandas_table_content,
@@ -455,49 +461,62 @@ def transform_create_text_embeddings_internal(
455
461
  ContentTypeEnum.VIDEO: lambda x: None, # Not supported yet.
456
462
  }
457
463
 
458
- logger.debug("Generating text embeddings for supported content types: TEXT, STRUCTURED, IMAGE.")
459
-
460
464
  def _content_type_getter(row):
461
465
  return row["content_metadata"]["type"]
462
466
 
463
- # Process each supported content type.
464
467
  for content_type, content_getter in pandas_content_extractor.items():
465
468
  if not content_getter:
466
469
  logger.debug(f"Skipping unsupported content type: {content_type}")
467
470
  continue
468
471
 
472
+ # Get rows matching the content type
469
473
  content_mask = df_transform_ledger["metadata"].apply(_content_type_getter) == content_type.value
470
474
  if not content_mask.any():
471
475
  continue
472
476
 
473
- # Extract content from metadata and filter out rows with empty content.
474
- extracted_content = df_transform_ledger.loc[content_mask, "metadata"].apply(content_getter)
475
- non_empty_mask = extracted_content.notna() & (extracted_content.str.strip() != "")
476
- final_mask = content_mask & non_empty_mask
477
- if not final_mask.any():
478
- continue
477
+ # Always include all content_mask rows and prepare them
478
+ df_content = df_transform_ledger.loc[content_mask].copy().reset_index(drop=True)
479
479
 
480
- df_content = df_transform_ledger.loc[final_mask].copy().reset_index(drop=True)
481
- filtered_content = df_content["metadata"].apply(content_getter)
482
- filtered_content_batches = _generate_batches(filtered_content.tolist(), batch_size=transform_config.batch_size)
483
- content_embeddings = _async_runner(
484
- filtered_content_batches,
485
- api_key,
486
- endpoint_url,
487
- model_name,
488
- transform_config.encoding_format,
489
- transform_config.input_type,
490
- transform_config.truncate,
491
- False,
480
+ # Extract content and normalize empty or non-str to None
481
+ extracted_content = (
482
+ df_content["metadata"]
483
+ .apply(content_getter)
484
+ .apply(lambda x: x.strip() if isinstance(x, str) and x.strip() else None)
492
485
  )
493
- # Apply the embeddings (and any error info) to each row.
494
- df_content[["metadata", "document_type", "_contains_embeddings"]] = df_content.apply(
495
- _add_embeddings, **content_embeddings, axis=1
496
- )[["metadata", "document_type", "_contains_embeddings"]]
497
- df_content["_content"] = filtered_content
486
+ df_content["_content"] = extracted_content
487
+
488
+ # Prepare batches for only valid (non-None) content
489
+ valid_content_mask = df_content["_content"].notna()
490
+ if valid_content_mask.any():
491
+ filtered_content_batches = _generate_batches(
492
+ df_content.loc[valid_content_mask, "_content"].tolist(), batch_size=transform_config.batch_size
493
+ )
494
+ content_embeddings = _async_runner(
495
+ filtered_content_batches,
496
+ api_key,
497
+ endpoint_url,
498
+ model_name,
499
+ transform_config.encoding_format,
500
+ transform_config.input_type,
501
+ transform_config.truncate,
502
+ False,
503
+ )
504
+ # Build a simple row index -> embedding map
505
+ embeddings_dict = dict(
506
+ zip(df_content.loc[valid_content_mask].index, content_embeddings.get("embeddings", []))
507
+ )
508
+ info_msgs_dict = dict(
509
+ zip(df_content.loc[valid_content_mask].index, content_embeddings.get("info_msgs", []))
510
+ )
511
+ else:
512
+ embeddings_dict = {}
513
+ info_msgs_dict = {}
514
+
515
+ # Apply embeddings or None to all rows
516
+ df_content = df_content.apply(_add_embeddings, embeddings=embeddings_dict, info_msgs=info_msgs_dict, axis=1)
498
517
 
499
518
  embedding_dataframes.append(df_content)
500
- content_masks.append(final_mask)
519
+ content_masks.append(content_mask)
501
520
 
502
521
  combined_df = _concatenate_extractions_pandas(df_transform_ledger, embedding_dataframes, content_masks)
503
522
  return combined_df, {"trace_info": execution_trace_log}
@@ -66,7 +66,7 @@ def datetools_exception_handler(func: Callable, **kwargs: Dict[str, Any]) -> Cal
66
66
  return func(*args, **kwargs)
67
67
  except Exception as e:
68
68
  log_error_message = f"Invalid date format: {e}"
69
- logger.warning(log_error_message)
69
+ logger.debug(log_error_message)
70
70
  return datetools.remove_tz(datetime.now(timezone.utc)).isoformat()
71
71
 
72
72
  return inner_function
@@ -2,77 +2,321 @@
2
2
  # All rights reserved.
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
+ import asyncio
5
6
  import logging
6
7
  import functools
7
8
  import inspect
8
9
  import re
9
- import typing
10
+ from typing import Any, Optional, Callable, Tuple
10
11
  from functools import wraps
11
12
 
12
13
  from nv_ingest_api.internal.primitives.ingest_control_message import IngestControlMessage
13
14
  from nv_ingest_api.internal.primitives.tracing.logging import TaskResultStatus, annotate_task_result
14
15
  from nv_ingest_api.util.control_message.validators import cm_ensure_payload_not_null, cm_set_failure
15
16
 
16
-
17
17
  logger = logging.getLogger(__name__)
18
18
 
19
19
 
20
- # TODO(Devin): move back to framework
20
+ def nv_ingest_node_failure_try_except( # New name to distinguish
21
+ annotation_id: str,
22
+ payload_can_be_empty: bool = False,
23
+ raise_on_failure: bool = False,
24
+ skip_processing_if_failed: bool = True,
25
+ forward_func: Optional[Callable[[Any], Any]] = None,
26
+ ) -> Callable:
27
+ """
28
+ Decorator that wraps function execution in a try/except block to handle
29
+ failures by annotating an IngestControlMessage. Replaces the context
30
+ manager approach for potentially simpler interaction with frameworks like Ray.
31
+
32
+ Parameters are the same as nv_ingest_node_failure_context_manager.
33
+ """
34
+
35
+ def extract_message_and_prefix(args: Tuple) -> Tuple[Any, Tuple]:
36
+ """Extracts control_message and potential 'self' prefix."""
37
+ # (Keep the implementation from the original decorator)
38
+ if args and hasattr(args[0], "get_metadata"):
39
+ return args[0], ()
40
+ elif len(args) >= 2 and hasattr(args[1], "get_metadata"):
41
+ return args[1], (args[0],)
42
+ else:
43
+ # Be more specific in error if possible
44
+ arg_types = [type(arg).__name__ for arg in args]
45
+ raise ValueError(f"No IngestControlMessage found in first or second argument. Got types: {arg_types}")
46
+
47
+ def decorator(func: Callable) -> Callable:
48
+ func_name = func.__name__ # Get function name for logging/errors
49
+
50
+ # --- ASYNC WRAPPER ---
51
+ if asyncio.iscoroutinefunction(func):
52
+
53
+ @functools.wraps(func)
54
+ async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
55
+ logger.debug(f"async_wrapper for {func_name}: Entering.")
56
+ try:
57
+ control_message, prefix = extract_message_and_prefix(args)
58
+ except ValueError as e:
59
+ logger.error(f"async_wrapper for {func_name}: Failed to extract control message. Error: {e}")
60
+ raise # Cannot proceed without the message
61
+
62
+ # --- Skip logic ---
63
+ is_failed = control_message.get_metadata("cm_failed", False)
64
+ if is_failed and skip_processing_if_failed:
65
+ logger.debug(f"async_wrapper for {func_name}: Skipping processing, message already marked failed.")
66
+ if forward_func:
67
+ logger.debug("async_wrapper: Forwarding skipped message.")
68
+ # Await forward_func if it's async
69
+ if asyncio.iscoroutinefunction(forward_func):
70
+ return await forward_func(control_message)
71
+ else:
72
+ return forward_func(control_message)
73
+ else:
74
+ logger.debug("async_wrapper: Returning skipped message as is.")
75
+ return control_message
76
+
77
+ # --- Main execution block ---
78
+ result = None
79
+ try:
80
+ # Payload check
81
+ if not payload_can_be_empty:
82
+ cm_ensure_payload_not_null(control_message)
83
+
84
+ # Rebuild args and call original async function
85
+ new_args = prefix + (control_message,) + args[len(prefix) + 1 :]
86
+ logger.debug(f"async_wrapper for {func_name}: Calling await func...")
87
+ result = await func(*new_args, **kwargs)
88
+ logger.debug(f"async_wrapper for {func_name}: func call completed.")
89
+
90
+ # Success annotation
91
+ logger.debug(f"async_wrapper for {func_name}: Annotating success.")
92
+ annotate_task_result(
93
+ control_message=result if result is not None else control_message,
94
+ # Annotate result if func returns it, else original message
95
+ result=TaskResultStatus.SUCCESS,
96
+ task_id=annotation_id,
97
+ )
98
+ logger.debug(f"async_wrapper for {func_name}: Success annotation done. Returning result.")
99
+ return result
100
+
101
+ except Exception as e:
102
+ # --- Failure Handling ---
103
+ error_message = f"Error in {func_name}: {e}"
104
+ logger.error(f"async_wrapper for {func_name}: Caught exception: {error_message}", exc_info=True)
105
+
106
+ # Annotate failure on the original message object
107
+ try:
108
+ cm_set_failure(control_message, error_message)
109
+ annotate_task_result(
110
+ control_message=control_message,
111
+ result=TaskResultStatus.FAILURE,
112
+ task_id=annotation_id,
113
+ message=error_message,
114
+ )
115
+ logger.debug(f"async_wrapper for {func_name}: Failure annotation complete.")
116
+ except Exception as anno_err:
117
+ # Log error during annotation but proceed based on raise_on_failure
118
+ logger.exception(
119
+ f"async_wrapper for {func_name}: CRITICAL - Error during failure annotation: {anno_err}"
120
+ )
121
+
122
+ # Decide whether to raise or return annotated message
123
+ if raise_on_failure:
124
+ logger.debug(f"async_wrapper for {func_name}: Re-raising exception as configured.")
125
+ raise e # Re-raise the original exception
126
+ else:
127
+ logger.debug(
128
+ f"async_wrapper for {func_name}: Suppressing exception and returning annotated message."
129
+ )
130
+ # Return the original control_message, now annotated with failure
131
+ return control_message
132
+
133
+ return async_wrapper
134
+
135
+ # --- SYNC WRAPPER ---
136
+ else:
137
+
138
+ @functools.wraps(func)
139
+ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
140
+ logger.debug(f"sync_wrapper for {func_name}: Entering.")
141
+ try:
142
+ control_message, prefix = extract_message_and_prefix(args)
143
+ except ValueError as e:
144
+ logger.error(f"sync_wrapper for {func_name}: Failed to extract control message. Error: {e}")
145
+ raise
146
+
147
+ # --- Skip logic ---
148
+ is_failed = control_message.get_metadata("cm_failed", False)
149
+ if is_failed and skip_processing_if_failed:
150
+ logger.warning(f"sync_wrapper for {func_name}: Skipping processing, message already marked failed.")
151
+ if forward_func:
152
+ logger.debug("sync_wrapper: Forwarding skipped message.")
153
+ return forward_func(control_message) # Assume forward_func is sync here
154
+ else:
155
+ logger.debug("sync_wrapper: Returning skipped message as is.")
156
+ return control_message
157
+
158
+ # --- Main execution block ---
159
+ result = None
160
+ try:
161
+ # Payload check
162
+ if not payload_can_be_empty:
163
+ cm_ensure_payload_not_null(control_message)
164
+
165
+ # Rebuild args and call original sync function
166
+ new_args = prefix + (control_message,) + args[len(prefix) + 1 :]
167
+ logger.debug(f"sync_wrapper for {func_name}: Calling func...")
168
+ result = func(*new_args, **kwargs)
169
+ logger.debug(f"sync_wrapper for {func_name}: func call completed.")
170
+
171
+ # Success annotation
172
+ logger.debug(f"sync_wrapper for {func_name}: Annotating success.")
173
+ annotate_task_result(
174
+ control_message=result if result is not None else control_message,
175
+ # Annotate result or original message
176
+ result=TaskResultStatus.SUCCESS,
177
+ task_id=annotation_id,
178
+ )
179
+ logger.debug(f"sync_wrapper for {func_name}: Success annotation done. Returning result.")
180
+ return result
181
+
182
+ except Exception as e:
183
+ # --- Failure Handling ---
184
+ error_message = f"Error in {func_name}: {e}"
185
+ logger.error(f"sync_wrapper for {func_name}: Caught exception: {error_message}", exc_info=True)
186
+
187
+ # Annotate failure on the original message object
188
+ try:
189
+ cm_set_failure(control_message, error_message)
190
+ annotate_task_result(
191
+ control_message=control_message,
192
+ result=TaskResultStatus.FAILURE,
193
+ task_id=annotation_id,
194
+ message=error_message,
195
+ )
196
+ logger.debug(f"sync_wrapper for {func_name}: Failure annotation complete.")
197
+ except Exception as anno_err:
198
+ logger.exception(
199
+ f"sync_wrapper for {func_name}: CRITICAL - Error during failure annotation: {anno_err}"
200
+ )
201
+
202
+ # Decide whether to raise or return annotated message
203
+ if raise_on_failure:
204
+ logger.debug(f"sync_wrapper for {func_name}: Re-raising exception as configured.")
205
+ raise e # Re-raise the original exception
206
+ else:
207
+ logger.debug(
208
+ f"sync_wrapper for {func_name}: Suppressing exception and returning annotated message."
209
+ )
210
+ # Return the original control_message, now annotated with failure
211
+ return control_message
212
+
213
+ return sync_wrapper
214
+
215
+ return decorator
216
+
217
+
21
218
  def nv_ingest_node_failure_context_manager(
22
219
  annotation_id: str,
23
220
  payload_can_be_empty: bool = False,
24
221
  raise_on_failure: bool = False,
25
222
  skip_processing_if_failed: bool = True,
26
- forward_func=None,
27
- ) -> typing.Callable:
223
+ forward_func: Optional[Callable[[Any], Any]] = None,
224
+ ) -> Callable:
28
225
  """
29
- A decorator that applies a default failure context manager around a function to manage
30
- the execution and potential failure of operations involving IngestControlMessages.
226
+ Decorator that applies a failure context manager around a function processing an IngestControlMessage.
227
+ Works with both synchronous and asynchronous functions, and supports class methods (with 'self').
31
228
 
32
229
  Parameters
33
230
  ----------
34
231
  annotation_id : str
35
- A unique identifier used for annotating the task's result.
232
+ A unique identifier for annotation.
36
233
  payload_can_be_empty : bool, optional
37
- If False, the payload of the IngestControlMessage will be checked to ensure it's not null,
38
- raising an exception if it is null. Defaults to False, enforcing payload presence.
234
+ If False, the message payload must not be null.
39
235
  raise_on_failure : bool, optional
40
- If True, an exception is raised if the decorated function encounters an error.
41
- Otherwise, the error is handled silently by annotating the IngestControlMessage. Defaults to False.
236
+ If True, exceptions are raised; otherwise, they are annotated.
42
237
  skip_processing_if_failed : bool, optional
43
- If True, skips the processing of the decorated function if the control message has already
44
- been marked as failed. If False, the function will be processed regardless of the failure
45
- status of the IngestControlMessage. Defaults to True.
46
- forward_func : callable, optional
47
- A function to forward the IngestControlMessage if it has already been marked as failed.
238
+ If True, skip processing if the message is already marked as failed.
239
+ forward_func : Optional[Callable[[Any], Any]]
240
+ If provided, a function to forward the message when processing is skipped.
48
241
 
49
242
  Returns
50
243
  -------
51
244
  Callable
52
- A decorator that wraps the given function with failure handling logic.
245
+ The decorated function.
53
246
  """
54
247
 
55
- def decorator(func):
56
- @wraps(func)
57
- def wrapper(control_message: IngestControlMessage, *args, **kwargs):
58
- # Quick return if the IngestControlMessage has already failed
59
- is_failed = control_message.get_metadata("cm_failed", False)
60
- if not is_failed or not skip_processing_if_failed:
61
- with CMNVIngestFailureContextManager(
62
- control_message=control_message,
63
- annotation_id=annotation_id,
64
- raise_on_failure=raise_on_failure,
65
- func_name=func.__name__,
66
- ) as ctx_mgr:
67
- if not payload_can_be_empty:
68
- cm_ensure_payload_not_null(control_message=control_message)
69
- control_message = func(ctx_mgr.control_message, *args, **kwargs)
70
- else:
71
- if forward_func:
72
- control_message = forward_func(control_message)
73
- return control_message
248
+ def extract_message_and_prefix(args: Tuple) -> Tuple[Any, Tuple]:
249
+ """
250
+ Determines if the function is a method (first argument is self) or a standalone function.
251
+ Returns a tuple (control_message, prefix) where prefix is a tuple of preceding arguments to be preserved.
252
+ """
253
+ if args and hasattr(args[0], "get_metadata"):
254
+ # Standalone function: first argument is the message.
255
+ return args[0], ()
256
+ elif len(args) >= 2 and hasattr(args[1], "get_metadata"):
257
+ # Method: first argument is self, second is the message.
258
+ return args[1], (args[0],)
259
+ else:
260
+ raise ValueError("No IngestControlMessage found in the first or second argument.")
74
261
 
75
- return wrapper
262
+ def decorator(func: Callable) -> Callable:
263
+ if asyncio.iscoroutinefunction(func):
264
+
265
+ @functools.wraps(func)
266
+ async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
267
+ control_message, prefix = extract_message_and_prefix(args)
268
+ is_failed = control_message.get_metadata("cm_failed", False)
269
+ if not is_failed or not skip_processing_if_failed:
270
+ ctx_mgr = CMNVIngestFailureContextManager(
271
+ control_message=control_message,
272
+ annotation_id=annotation_id,
273
+ raise_on_failure=raise_on_failure,
274
+ func_name=func.__name__,
275
+ )
276
+ try:
277
+ ctx_mgr.__enter__()
278
+ if not payload_can_be_empty:
279
+ cm_ensure_payload_not_null(control_message)
280
+ # Rebuild argument list preserving any prefix (e.g. self).
281
+ new_args = prefix + (ctx_mgr.control_message,) + args[len(prefix) + 1 :]
282
+ result = await func(*new_args, **kwargs)
283
+ except Exception as e:
284
+ ctx_mgr.__exit__(type(e), e, e.__traceback__)
285
+ raise
286
+ else:
287
+ ctx_mgr.__exit__(None, None, None)
288
+ return result
289
+ else:
290
+ if forward_func:
291
+ return await forward_func(control_message)
292
+ else:
293
+ return control_message
294
+
295
+ return async_wrapper
296
+ else:
297
+
298
+ @functools.wraps(func)
299
+ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
300
+ control_message, prefix = extract_message_and_prefix(args)
301
+ is_failed = control_message.get_metadata("cm_failed", False)
302
+ if not is_failed or not skip_processing_if_failed:
303
+ with CMNVIngestFailureContextManager(
304
+ control_message=control_message,
305
+ annotation_id=annotation_id,
306
+ raise_on_failure=raise_on_failure,
307
+ func_name=func.__name__,
308
+ ) as ctx_mgr:
309
+ if not payload_can_be_empty:
310
+ cm_ensure_payload_not_null(control_message)
311
+ new_args = prefix + (ctx_mgr.control_message,) + args[len(prefix) + 1 :]
312
+ return func(*new_args, **kwargs)
313
+ else:
314
+ if forward_func:
315
+ return forward_func(control_message)
316
+ else:
317
+ return control_message
318
+
319
+ return sync_wrapper
76
320
 
77
321
  return decorator
78
322
 
@@ -81,7 +325,7 @@ def nv_ingest_source_failure_context_manager(
81
325
  annotation_id: str,
82
326
  payload_can_be_empty: bool = False,
83
327
  raise_on_failure: bool = False,
84
- ) -> typing.Callable:
328
+ ) -> Callable:
85
329
  """
86
330
  A decorator that ensures any function's output is treated as a IngestControlMessage for annotation.
87
331
  It applies a context manager to handle success and failure annotations based on the function's execution.
@@ -209,15 +453,29 @@ class CMNVIngestFailureContextManager:
209
453
 
210
454
 
211
455
  def unified_exception_handler(func):
212
- @functools.wraps(func)
213
- def wrapper(*args, **kwargs):
214
- try:
215
- return func(*args, **kwargs)
216
- except Exception as e:
217
- # Use the function's name in the error message
218
- func_name = func.__name__
219
- err_msg = f"{func_name}: error: {e}"
220
- logger.exception(err_msg, exc_info=True)
221
- raise type(e)(err_msg) from e
222
-
223
- return wrapper
456
+ if asyncio.iscoroutinefunction(func):
457
+
458
+ @functools.wraps(func)
459
+ async def async_wrapper(*args, **kwargs):
460
+ try:
461
+ return await func(*args, **kwargs)
462
+ except Exception as e:
463
+ func_name = func.__name__
464
+ err_msg = f"{func_name}: error: {e}"
465
+ logger.exception(err_msg, exc_info=True)
466
+ raise type(e)(err_msg) from e
467
+
468
+ return async_wrapper
469
+ else:
470
+
471
+ @functools.wraps(func)
472
+ def sync_wrapper(*args, **kwargs):
473
+ try:
474
+ return func(*args, **kwargs)
475
+ except Exception as e:
476
+ func_name = func.__name__
477
+ err_msg = f"{func_name}: error: {e}"
478
+ logger.exception(err_msg, exc_info=True)
479
+ raise type(e)(err_msg) from e
480
+
481
+ return sync_wrapper