nv-ingest-api 2025.8.14.dev20250814__py3-none-any.whl → 2025.8.15.dev20250815__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nv-ingest-api might be problematic. Click here for more details.

Files changed (24)
  1. nv_ingest_api/internal/enums/common.py +37 -0
  2. nv_ingest_api/internal/extract/image/image_extractor.py +5 -1
  3. nv_ingest_api/internal/meta/__init__.py +3 -0
  4. nv_ingest_api/internal/meta/udf.py +232 -0
  5. nv_ingest_api/internal/primitives/ingest_control_message.py +63 -22
  6. nv_ingest_api/internal/primitives/tracing/tagging.py +102 -15
  7. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +40 -4
  8. nv_ingest_api/internal/schemas/meta/udf.py +23 -0
  9. nv_ingest_api/internal/transform/embed_text.py +5 -0
  10. nv_ingest_api/util/exception_handlers/decorators.py +104 -156
  11. nv_ingest_api/util/imports/callable_signatures.py +59 -1
  12. nv_ingest_api/util/imports/dynamic_resolvers.py +53 -5
  13. nv_ingest_api/util/introspection/__init__.py +3 -0
  14. nv_ingest_api/util/introspection/class_inspect.py +145 -0
  15. nv_ingest_api/util/introspection/function_inspect.py +65 -0
  16. nv_ingest_api/util/logging/configuration.py +71 -7
  17. nv_ingest_api/util/string_processing/configuration.py +682 -0
  18. nv_ingest_api/util/string_processing/yaml.py +45 -0
  19. nv_ingest_api/util/system/hardware_info.py +178 -13
  20. {nv_ingest_api-2025.8.14.dev20250814.dist-info → nv_ingest_api-2025.8.15.dev20250815.dist-info}/METADATA +1 -1
  21. {nv_ingest_api-2025.8.14.dev20250814.dist-info → nv_ingest_api-2025.8.15.dev20250815.dist-info}/RECORD +24 -16
  22. {nv_ingest_api-2025.8.14.dev20250814.dist-info → nv_ingest_api-2025.8.15.dev20250815.dist-info}/WHEEL +0 -0
  23. {nv_ingest_api-2025.8.14.dev20250814.dist-info → nv_ingest_api-2025.8.15.dev20250815.dist-info}/licenses/LICENSE +0 -0
  24. {nv_ingest_api-2025.8.14.dev20250814.dist-info → nv_ingest_api-2025.8.15.dev20250815.dist-info}/top_level.txt +0 -0
@@ -35,7 +35,7 @@ class IngestTaskSplitSchema(BaseModelNoExt):
35
35
  tokenizer: Optional[str] = None
36
36
  chunk_size: Annotated[int, Field(gt=0)] = 1024
37
37
  chunk_overlap: Annotated[int, Field(ge=0)] = 150
38
- params: dict
38
+ params: dict = Field(default_factory=dict)
39
39
 
40
40
  @field_validator("chunk_overlap")
41
41
  def check_chunk_overlap(cls, v, values, **kwargs):
@@ -47,7 +47,7 @@ class IngestTaskSplitSchema(BaseModelNoExt):
47
47
  class IngestTaskExtractSchema(BaseModelNoExt):
48
48
  document_type: DocumentTypeEnum
49
49
  method: str
50
- params: dict
50
+ params: dict = Field(default_factory=dict)
51
51
 
52
52
  @field_validator("document_type", mode="before")
53
53
  @classmethod
@@ -61,14 +61,14 @@ class IngestTaskExtractSchema(BaseModelNoExt):
61
61
 
62
62
 
63
63
  class IngestTaskStoreEmbedSchema(BaseModelNoExt):
64
- params: dict
64
+ params: dict = Field(default_factory=dict)
65
65
 
66
66
 
67
67
  class IngestTaskStoreSchema(BaseModelNoExt):
68
68
  structured: bool = True
69
69
  images: bool = False
70
70
  method: str
71
- params: dict
71
+ params: dict = Field(default_factory=dict)
72
72
 
73
73
 
74
74
  # Captioning: All fields are optional and override default parameters.
@@ -143,6 +143,40 @@ class IngestTaskInfographicExtraction(BaseModelNoExt):
143
143
  params: dict = Field(default_factory=dict)
144
144
 
145
145
 
146
class IngestTaskUDFSchema(BaseModelNoExt):
    """
    Task properties for a user-defined-function (UDF) task.

    Exactly one placement mode must be configured:

    * ``phase`` alone, or
    * ``target_stage`` (a non-blank name) together with at least one of
      ``run_before`` / ``run_after``.

    Inconsistent combinations are rejected by ``validate_stage_targeting``, which
    raises ``ValueError`` (reported by pydantic as a validation error).
    """

    udf_function: str
    udf_function_name: str
    # Alternative to target_stage; when set it must lie in the range 1..5.
    phase: Optional[int] = Field(default=None, ge=1, le=5)
    run_before: bool = Field(default=False, description="Execute UDF before the target stage")
    run_after: bool = Field(default=False, description="Execute UDF after the target stage")
    target_stage: Optional[str] = Field(
        default=None, description="Name of the stage to target (e.g., 'image_dedup', 'text_extract')"
    )

    @model_validator(mode="after")
    def validate_stage_targeting(self):
        """
        Validate that the stage-targeting fields form a consistent configuration.

        Returns
        -------
        IngestTaskUDFSchema
            ``self``, unchanged, when the configuration is valid.

        Raises
        ------
        ValueError
            If ``target_stage`` is blank, if both or neither of ``phase`` and
            ``target_stage`` are given, if a timing flag is set without a target
            stage, or if a target stage is given without any timing flag.
        """
        # Reject blank stage names up front. A blank string is "not None" but falsy,
        # so it would count as present for the exclusivity check yet absent for the
        # timing checks, letting an unusable configuration through.
        if self.target_stage is not None and not self.target_stage.strip():
            raise ValueError("target_stage must be a non-empty stage name when specified.")

        has_phase = self.phase is not None
        has_target_stage = self.target_stage is not None
        has_timing = self.run_before or self.run_after

        # phase and target_stage are mutually exclusive, and one of them is required.
        if has_phase and has_target_stage:
            raise ValueError("Cannot specify both 'phase' and 'target_stage'. Please specify only one.")
        if not has_phase and not has_target_stage:
            raise ValueError("Must specify either 'phase' or 'target_stage'.")

        # run_before / run_after are relative to a named target stage.
        if has_timing and not has_target_stage:
            raise ValueError("target_stage must be specified when using run_before or run_after")

        # A target stage needs at least one timing flag.
        if has_target_stage and not has_timing:
            raise ValueError("At least one of run_before or run_after must be True when target_stage is specified")

        return self
178
+
179
+
146
180
  class IngestTaskSchema(BaseModelNoExt):
147
181
  type: TaskTypeEnum
148
182
  task_properties: Union[
@@ -159,6 +193,7 @@ class IngestTaskSchema(BaseModelNoExt):
159
193
  IngestTaskTableExtraction,
160
194
  IngestTaskChartExtraction,
161
195
  IngestTaskInfographicExtraction,
196
+ IngestTaskUDFSchema,
162
197
  ]
163
198
  raise_on_failure: bool = False
164
199
 
@@ -190,6 +225,7 @@ class IngestTaskSchema(BaseModelNoExt):
190
225
  TaskTypeEnum.TABLE_DATA_EXTRACT: IngestTaskTableExtraction,
191
226
  TaskTypeEnum.CHART_DATA_EXTRACT: IngestTaskChartExtraction,
192
227
  TaskTypeEnum.INFOGRAPHIC_DATA_EXTRACT: IngestTaskInfographicExtraction,
228
+ TaskTypeEnum.UDF: IngestTaskUDFSchema,
193
229
  }
194
230
 
195
231
  expected_schema_cls = task_type_to_schema.get(task_type)
@@ -0,0 +1,23 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from pydantic import BaseModel, Field, ConfigDict
6
+
7
+
8
class UDFStageSchema(BaseModel):
    """
    Configuration schema for the UDF stage.

    The UDF function source is expected to arrive through the task config. When a
    UDF task carries no UDF function, ``ignore_empty_udf`` selects the outcome:
    ``True`` passes the message through unchanged, ``False`` raises an error.
    """

    # Unknown configuration keys are rejected rather than silently ignored.
    model_config = ConfigDict(extra="forbid")

    ignore_empty_udf: bool = Field(
        default=False,
        description=(
            "If True, ignore UDF tasks without udf_function and return message unchanged. "
            "If False, raise error."
        ),
    )
@@ -15,6 +15,11 @@ from nv_ingest_api.internal.schemas.transform.transform_text_embedding_schema im
15
15
 
16
16
  logger = logging.getLogger(__name__)
17
17
 
18
# Quiet the HTTP client SDK loggers so their per-request/response records are not
# emitted: only ERROR and above from these loggers (and children left at NOTSET) pass.
# NOTE(review): this runs at import time and overrides any level the host
# application set on these loggers before importing this module.
for _sdk_logger_name in ("openai", "httpx", "httpcore"):
    logging.getLogger(_sdk_logger_name).setLevel(logging.ERROR)
del _sdk_logger_name
22
+
18
23
 
19
24
  MULTI_MODAL_MODELS = ["llama-3.2-nemoretriever-1b-vlm-embed-v1"]
20
25
 
@@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
18
18
 
19
19
 
20
20
  def nv_ingest_node_failure_try_except( # New name to distinguish
21
- annotation_id: str,
21
+ annotation_id: Optional[str] = None,
22
22
  payload_can_be_empty: bool = False,
23
23
  raise_on_failure: bool = False,
24
24
  skip_processing_if_failed: bool = True,
@@ -29,7 +29,19 @@ def nv_ingest_node_failure_try_except( # New name to distinguish
29
29
  failures by annotating an IngestControlMessage. Replaces the context
30
30
  manager approach for potentially simpler interaction with frameworks like Ray.
31
31
 
32
- Parameters are the same as nv_ingest_node_failure_context_manager.
32
+ Parameters
33
+ ----------
34
+ annotation_id : Optional[str]
35
+ A unique identifier for annotation. If None, attempts to auto-detect
36
+ from the stage instance's stage_name property.
37
+ payload_can_be_empty : bool, optional
38
+ If False, the message payload must not be null.
39
+ raise_on_failure : bool, optional
40
+ If True, exceptions are raised; otherwise, they are annotated.
41
+ skip_processing_if_failed : bool, optional
42
+ If True, skip processing if the message is already marked as failed.
43
+ forward_func : Optional[Callable[[Any], Any]]
44
+ If provided, a function to forward the message when processing is skipped.
33
45
  """
34
46
 
35
47
  def extract_message_and_prefix(args: Tuple) -> Tuple[Any, Tuple]:
@@ -47,170 +59,106 @@ def nv_ingest_node_failure_try_except( # New name to distinguish
47
59
  def decorator(func: Callable) -> Callable:
48
60
  func_name = func.__name__ # Get function name for logging/errors
49
61
 
50
- # --- ASYNC WRAPPER ---
51
- if asyncio.iscoroutinefunction(func):
52
-
53
- @functools.wraps(func)
54
- async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
55
- logger.debug(f"async_wrapper for {func_name}: Entering.")
56
- try:
57
- control_message, prefix = extract_message_and_prefix(args)
58
- except ValueError as e:
59
- logger.error(f"async_wrapper for {func_name}: Failed to extract control message. Error: {e}")
60
- raise # Cannot proceed without the message
61
-
62
- # --- Skip logic ---
63
- is_failed = control_message.get_metadata("cm_failed", False)
64
- if is_failed and skip_processing_if_failed:
65
- logger.debug(f"async_wrapper for {func_name}: Skipping processing, message already marked failed.")
66
- if forward_func:
67
- logger.debug("async_wrapper: Forwarding skipped message.")
68
- # Await forward_func if it's async
69
- if asyncio.iscoroutinefunction(forward_func):
70
- return await forward_func(control_message)
71
- else:
72
- return forward_func(control_message)
73
- else:
74
- logger.debug("async_wrapper: Returning skipped message as is.")
75
- return control_message
76
-
77
- # --- Main execution block ---
78
- result = None
79
- try:
80
- # Payload check
81
- if not payload_can_be_empty:
82
- cm_ensure_payload_not_null(control_message)
83
-
84
- # Rebuild args and call original async function
85
- new_args = prefix + (control_message,) + args[len(prefix) + 1 :]
86
- logger.debug(f"async_wrapper for {func_name}: Calling await func...")
87
- result = await func(*new_args, **kwargs)
88
- logger.debug(f"async_wrapper for {func_name}: func call completed.")
89
-
90
- # Success annotation
91
- logger.debug(f"async_wrapper for {func_name}: Annotating success.")
92
- annotate_task_result(
93
- control_message=result if result is not None else control_message,
94
- # Annotate result if func returns it, else original message
95
- result=TaskResultStatus.SUCCESS,
96
- task_id=annotation_id,
62
+ @functools.wraps(func)
63
+ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
64
+ logger.debug(f"sync_wrapper for {func_name}: Entering.")
65
+
66
+ # Determine the annotation_id to use
67
+ resolved_annotation_id = annotation_id
68
+
69
+ # If no explicit annotation_id provided, try to get it from self.stage_name
70
+ if resolved_annotation_id is None and len(args) >= 1:
71
+ stage_instance = args[0] # 'self' in method calls
72
+ if hasattr(stage_instance, "stage_name") and stage_instance.stage_name:
73
+ resolved_annotation_id = stage_instance.stage_name
74
+ logger.debug("Using auto-detected annotation_id from stage_name: " f"'{resolved_annotation_id}'")
75
+ else:
76
+ # Fallback to function name if no stage_name available
77
+ resolved_annotation_id = func_name
78
+ logger.debug(
79
+ "No stage_name available, using function name as annotation_id: " f"'{resolved_annotation_id}'"
97
80
  )
98
- logger.debug(f"async_wrapper for {func_name}: Success annotation done. Returning result.")
99
- return result
100
-
101
- except Exception as e:
102
- # --- Failure Handling ---
103
- error_message = f"Error in {func_name}: {e}"
104
- logger.error(f"async_wrapper for {func_name}: Caught exception: {error_message}", exc_info=True)
105
-
106
- # Annotate failure on the original message object
107
- try:
108
- cm_set_failure(control_message, error_message)
109
- annotate_task_result(
110
- control_message=control_message,
111
- result=TaskResultStatus.FAILURE,
112
- task_id=annotation_id,
113
- message=error_message,
114
- )
115
- logger.debug(f"async_wrapper for {func_name}: Failure annotation complete.")
116
- except Exception as anno_err:
117
- # Log error during annotation but proceed based on raise_on_failure
118
- logger.exception(
119
- f"async_wrapper for {func_name}: CRITICAL - Error during failure annotation: {anno_err}"
120
- )
121
-
122
- # Decide whether to raise or return annotated message
123
- if raise_on_failure:
124
- logger.debug(f"async_wrapper for {func_name}: Re-raising exception as configured.")
125
- raise e # Re-raise the original exception
126
- else:
127
- logger.debug(
128
- f"async_wrapper for {func_name}: Suppressing exception and returning annotated message."
129
- )
130
- # Return the original control_message, now annotated with failure
131
- return control_message
132
-
133
- return async_wrapper
81
+ elif resolved_annotation_id is None:
82
+ # Fallback to function name if no annotation_id and no instance
83
+ resolved_annotation_id = func_name
84
+ logger.debug(
85
+ "No annotation_id provided and no instance available, using function name: "
86
+ f"'{resolved_annotation_id}'"
87
+ )
134
88
 
135
- # --- SYNC WRAPPER ---
136
- else:
89
+ try:
90
+ control_message, prefix = extract_message_and_prefix(args)
91
+ except ValueError as e:
92
+ logger.error(f"sync_wrapper for {func_name}: Failed to extract control message. Error: {e}")
93
+ raise
94
+
95
+ # --- Skip logic ---
96
+ is_failed = control_message.get_metadata("cm_failed", False)
97
+ if is_failed and skip_processing_if_failed:
98
+ logger.warning(f"sync_wrapper for {func_name}: Skipping processing, message already marked failed.")
99
+ if forward_func:
100
+ logger.debug("sync_wrapper: Forwarding skipped message.")
101
+ return forward_func(control_message) # Assume forward_func is sync here
102
+ else:
103
+ logger.debug("sync_wrapper: Returning skipped message as is.")
104
+ return control_message
137
105
 
138
- @functools.wraps(func)
139
- def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
140
- logger.debug(f"sync_wrapper for {func_name}: Entering.")
141
- try:
142
- control_message, prefix = extract_message_and_prefix(args)
143
- except ValueError as e:
144
- logger.error(f"sync_wrapper for {func_name}: Failed to extract control message. Error: {e}")
145
- raise
106
+ # --- Main execution block ---
107
+ result = None
108
+ try:
109
+ # Payload check
110
+ if not payload_can_be_empty:
111
+ cm_ensure_payload_not_null(control_message)
112
+
113
+ # Rebuild args and call original sync function
114
+ new_args = prefix + (control_message,) + args[len(prefix) + 1 :]
115
+ logger.debug(f"sync_wrapper for {func_name}: Calling func...")
116
+ result = func(*new_args, **kwargs)
117
+ logger.debug(f"sync_wrapper for {func_name}: func call completed.")
118
+
119
+ # Success annotation
120
+ logger.debug(f"sync_wrapper for {func_name}: Annotating success.")
121
+ annotate_task_result(
122
+ control_message=result if result is not None else control_message,
123
+ # Annotate result or original message
124
+ result=TaskResultStatus.SUCCESS,
125
+ task_id=resolved_annotation_id,
126
+ )
127
+ logger.debug(f"sync_wrapper for {func_name}: Success annotation done. Returning result.")
128
+ return result
146
129
 
147
- # --- Skip logic ---
148
- is_failed = control_message.get_metadata("cm_failed", False)
149
- if is_failed and skip_processing_if_failed:
150
- logger.warning(f"sync_wrapper for {func_name}: Skipping processing, message already marked failed.")
151
- if forward_func:
152
- logger.debug("sync_wrapper: Forwarding skipped message.")
153
- return forward_func(control_message) # Assume forward_func is sync here
154
- else:
155
- logger.debug("sync_wrapper: Returning skipped message as is.")
156
- return control_message
130
+ except Exception as e:
131
+ # --- Failure Handling ---
132
+ error_message = f"Error in {func_name}: {e}"
133
+ logger.error(f"sync_wrapper for {func_name}: Caught exception: {error_message}", exc_info=True)
157
134
 
158
- # --- Main execution block ---
159
- result = None
135
+ # Annotate failure on the original message object
160
136
  try:
161
- # Payload check
162
- if not payload_can_be_empty:
163
- cm_ensure_payload_not_null(control_message)
164
-
165
- # Rebuild args and call original sync function
166
- new_args = prefix + (control_message,) + args[len(prefix) + 1 :]
167
- logger.debug(f"sync_wrapper for {func_name}: Calling func...")
168
- result = func(*new_args, **kwargs)
169
- logger.debug(f"sync_wrapper for {func_name}: func call completed.")
170
-
171
- # Success annotation
172
- logger.debug(f"sync_wrapper for {func_name}: Annotating success.")
137
+ cm_set_failure(control_message, error_message)
173
138
  annotate_task_result(
174
- control_message=result if result is not None else control_message,
175
- # Annotate result or original message
176
- result=TaskResultStatus.SUCCESS,
177
- task_id=annotation_id,
139
+ control_message=control_message,
140
+ result=TaskResultStatus.FAILURE,
141
+ task_id=resolved_annotation_id,
142
+ message=error_message,
143
+ )
144
+ logger.debug(f"sync_wrapper for {func_name}: Failure annotation complete.")
145
+ except Exception as anno_err:
146
+ logger.exception(
147
+ f"sync_wrapper for {func_name}: CRITICAL - Error during failure annotation: {anno_err}"
178
148
  )
179
- logger.debug(f"sync_wrapper for {func_name}: Success annotation done. Returning result.")
180
- return result
181
-
182
- except Exception as e:
183
- # --- Failure Handling ---
184
- error_message = f"Error in {func_name}: {e}"
185
- logger.error(f"sync_wrapper for {func_name}: Caught exception: {error_message}", exc_info=True)
186
149
 
187
- # Annotate failure on the original message object
188
- try:
189
- cm_set_failure(control_message, error_message)
190
- annotate_task_result(
191
- control_message=control_message,
192
- result=TaskResultStatus.FAILURE,
193
- task_id=annotation_id,
194
- message=error_message,
195
- )
196
- logger.debug(f"sync_wrapper for {func_name}: Failure annotation complete.")
197
- except Exception as anno_err:
198
- logger.exception(
199
- f"sync_wrapper for {func_name}: CRITICAL - Error during failure annotation: {anno_err}"
200
- )
201
-
202
- # Decide whether to raise or return annotated message
203
- if raise_on_failure:
204
- logger.debug(f"sync_wrapper for {func_name}: Re-raising exception as configured.")
205
- raise e # Re-raise the original exception
206
- else:
207
- logger.debug(
208
- f"sync_wrapper for {func_name}: Suppressing exception and returning annotated message."
209
- )
210
- # Return the original control_message, now annotated with failure
211
- return control_message
150
+ # Decide whether to raise or return annotated message
151
+ if raise_on_failure:
152
+ logger.debug(f"sync_wrapper for {func_name}: Re-raising exception as configured.")
153
+ raise e # Re-raise the original exception
154
+ else:
155
+ logger.debug(
156
+ f"sync_wrapper for {func_name}: Suppressing exception and returning annotated message."
157
+ )
158
+ # Return the original control_message, now annotated with failure
159
+ return control_message
212
160
 
213
- return sync_wrapper
161
+ return sync_wrapper
214
162
 
215
163
  return decorator
216
164
 
@@ -14,6 +14,8 @@ def ingest_stage_callable_signature(sig: inspect.Signature):
14
14
  Validates that a callable has the signature:
15
15
  (IngestControlMessage, BaseModel) -> IngestControlMessage
16
16
 
17
+ Also allows for generic (*args, **kwargs) signatures for flexibility with class constructors.
18
+
17
19
  Raises
18
20
  ------
19
21
  TypeError
@@ -21,11 +23,15 @@ def ingest_stage_callable_signature(sig: inspect.Signature):
21
23
  """
22
24
  params = list(sig.parameters.values())
23
25
 
26
+ # If the signature accepts arbitrary keyword arguments, it's flexible enough.
27
+ if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
28
+ return
29
+
24
30
  if len(params) != 2:
25
31
  raise TypeError(f"Expected exactly 2 parameters, got {len(params)}")
26
32
 
27
33
  if params[0].name != "control_message" or params[1].name != "stage_config":
28
- raise TypeError("Expected parameter names: 'control_message', 'config'")
34
+ raise TypeError("Expected parameter names: 'control_message', 'stage_config'")
29
35
 
30
36
  first_param = params[0].annotation
31
37
  second_param = params[1].annotation
@@ -48,3 +54,55 @@ def ingest_stage_callable_signature(sig: inspect.Signature):
48
54
 
49
55
  if not issubclass(return_type, IngestControlMessage):
50
56
  raise TypeError(f"Return type must be IngestControlMessage, got {return_type}")
57
+
58
+
59
def ingest_callable_signature(sig: inspect.Signature):
    """
    Validate that a callable has the signature ``(IngestControlMessage) -> IngestControlMessage``.

    Any signature that declares a ``**kwargs`` parameter is accepted without further
    checks. This allows generic ``(*args, **kwargs)`` wrappers such as class constructors.

    Parameters
    ----------
    sig : inspect.Signature
        The signature to validate.

    Raises
    ------
    TypeError
        If the signature does not match the expected pattern. This includes
        annotations that are neither the string ``"IngestControlMessage"`` nor a
        subclass of ``IngestControlMessage`` (e.g. ``Optional[...]`` or ``Any``).
    """
    params = list(sig.parameters.values())

    # A **kwargs parameter makes the callable flexible enough; skip the strict checks.
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
        return

    if len(params) != 1:
        raise TypeError(f"Expected exactly 1 parameter, got {len(params)}")

    if params[0].name != "control_message":
        raise TypeError("Expected parameter name: 'control_message'")

    first_param = params[0].annotation
    return_type = sig.return_annotation

    if first_param is inspect.Parameter.empty:
        raise TypeError("Parameter must be annotated with IngestControlMessage")

    if return_type is inspect.Signature.empty:
        raise TypeError("Return type must be annotated with IngestControlMessage")

    def _is_ingest_control_message(annotation) -> bool:
        # String annotations (forward references) must name the class exactly.
        if isinstance(annotation, str):
            return annotation == "IngestControlMessage"
        # Non-class annotations would make issubclass() raise its own opaque
        # TypeError; treat them as a mismatch so the caller gets a clear message.
        return inspect.isclass(annotation) and issubclass(annotation, IngestControlMessage)

    if not _is_ingest_control_message(first_param):
        raise TypeError(f"Parameter must be IngestControlMessage, got {first_param}")

    if not _is_ingest_control_message(return_type):
        raise TypeError(f"Return type must be IngestControlMessage, got {return_type}")
@@ -6,6 +6,8 @@ import importlib
6
6
  import inspect
7
7
  from typing import Callable, Union, List, Optional
8
8
 
9
+ from nv_ingest.framework.orchestration.ray.stages.meta.ray_actor_stage_base import RayActorStage
10
+
9
11
 
10
12
  def resolve_obj_from_path(path: str, allowed_base_paths: Optional[List[str]] = None) -> object:
11
13
  """
@@ -99,12 +101,58 @@ def resolve_callable_from_path(
99
101
  try:
100
102
  schema_checker(sig)
101
103
  except Exception as e:
102
- raise TypeError(
103
- f"Callable at '{callable_path}' failed custom signature validation:\n"
104
- f" Signature: {sig}\n"
105
- f" Error: {e}"
106
- ) from e
104
+ raise TypeError(f"Signature validation for '{callable_path}' failed: {e}") from e
107
105
  else:
108
106
  raise TypeError(f"Invalid signature_schema: expected list, callable, or str, got {type(signature_schema)}")
109
107
 
110
108
  return obj
109
+
110
+
111
def resolve_actor_class_from_path(
    path: str, expected_base_class: type, allowed_base_paths: Optional[List[str]] = None
) -> type:
    """
    Resolve an actor class from an import path and validate its base class.

    The resolved object may be a plain class, or a non-class wrapper such as a
    decorated Ray actor factory. For a wrapper, the MRO of the wrapper's own type
    is searched for a subclass of ``expected_base_class``.

    Parameters
    ----------
    path : str
        The full import path to the actor class.
    expected_base_class : type
        The base class that the resolved class must inherit from.
    allowed_base_paths : Optional[List[str]]
        An optional list of base module paths from which imports are allowed.

    Returns
    -------
    type
        The resolved object, unchanged (the class itself or the actor factory).

    Raises
    ------
    TypeError
        If no class deriving from ``expected_base_class`` can be identified for
        the resolved object.
    """
    obj = resolve_obj_from_path(path, allowed_base_paths=allowed_base_paths)

    if inspect.isclass(obj):
        cls_to_validate = obj
    else:
        # Non-class objects (e.g. decorated actor factories) are matched through the
        # MRO of their own type. The search keys on the caller's expected_base_class
        # rather than a hard-coded framework base class, so the documented contract
        # holds for any base; the base itself is skipped so only a subclass matches.
        # NOTE(review): any instance whose type derives from expected_base_class
        # (not only actor factories) also passes here — confirm that is intended.
        cls_to_validate = next(
            (
                base
                for base in obj.__class__.__mro__
                if issubclass(base, expected_base_class) and base is not expected_base_class
            ),
            None,
        )

    if cls_to_validate is None:
        raise TypeError(
            f"Could not resolve a valid actor class from path '{path}'. "
            "The object is not a class and not a recognized actor factory."
        )

    if not issubclass(cls_to_validate, expected_base_class):
        raise TypeError(
            f"Actor class '{cls_to_validate.__name__}' at '{path}' must inherit from '{expected_base_class.__name__}'."
        )

    return obj
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0