nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100):
  1. nat/agent/base.py +9 -4
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +1 -1
  5. nat/agent/react_agent/register.py +15 -5
  6. nat/agent/reasoning_agent/reasoning_agent.py +6 -1
  7. nat/agent/register.py +2 -0
  8. nat/agent/rewoo_agent/agent.py +4 -2
  9. nat/agent/rewoo_agent/register.py +8 -3
  10. nat/agent/router_agent/__init__.py +0 -0
  11. nat/agent/router_agent/agent.py +329 -0
  12. nat/agent/router_agent/prompt.py +48 -0
  13. nat/agent/router_agent/register.py +97 -0
  14. nat/agent/tool_calling_agent/agent.py +69 -7
  15. nat/agent/tool_calling_agent/register.py +11 -3
  16. nat/builder/builder.py +27 -4
  17. nat/builder/component_utils.py +7 -3
  18. nat/builder/function.py +167 -0
  19. nat/builder/function_info.py +1 -1
  20. nat/builder/workflow.py +5 -0
  21. nat/builder/workflow_builder.py +213 -16
  22. nat/cli/commands/optimize.py +90 -0
  23. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  24. nat/cli/commands/workflow/workflow_commands.py +4 -7
  25. nat/cli/entrypoint.py +2 -0
  26. nat/cli/register_workflow.py +38 -4
  27. nat/cli/type_registry.py +71 -0
  28. nat/data_models/component.py +2 -0
  29. nat/data_models/component_ref.py +11 -0
  30. nat/data_models/config.py +40 -16
  31. nat/data_models/function.py +34 -0
  32. nat/data_models/function_dependencies.py +8 -0
  33. nat/data_models/optimizable.py +119 -0
  34. nat/data_models/optimizer.py +149 -0
  35. nat/data_models/temperature_mixin.py +4 -3
  36. nat/data_models/top_p_mixin.py +4 -3
  37. nat/embedder/nim_embedder.py +1 -1
  38. nat/embedder/openai_embedder.py +1 -1
  39. nat/eval/config.py +1 -1
  40. nat/eval/evaluate.py +5 -1
  41. nat/eval/register.py +4 -0
  42. nat/eval/runtime_evaluator/__init__.py +14 -0
  43. nat/eval/runtime_evaluator/evaluate.py +123 -0
  44. nat/eval/runtime_evaluator/register.py +100 -0
  45. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
  46. nat/front_ends/fastapi/dask_client_mixin.py +43 -0
  47. nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
  48. nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
  49. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
  50. nat/front_ends/fastapi/job_store.py +518 -99
  51. nat/front_ends/fastapi/main.py +11 -19
  52. nat/front_ends/fastapi/utils.py +57 -0
  53. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
  54. nat/llm/aws_bedrock_llm.py +14 -3
  55. nat/llm/nim_llm.py +14 -3
  56. nat/llm/openai_llm.py +8 -1
  57. nat/observability/exporter/processing_exporter.py +29 -55
  58. nat/observability/mixin/redaction_config_mixin.py +5 -4
  59. nat/observability/mixin/tagging_config_mixin.py +26 -14
  60. nat/observability/mixin/type_introspection_mixin.py +401 -107
  61. nat/observability/processor/processor.py +3 -0
  62. nat/observability/processor/redaction/__init__.py +24 -0
  63. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  64. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  65. nat/observability/processor/redaction/redaction_processor.py +177 -0
  66. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  67. nat/observability/processor/span_tagging_processor.py +21 -14
  68. nat/profiler/decorators/framework_wrapper.py +9 -6
  69. nat/profiler/parameter_optimization/__init__.py +0 -0
  70. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  71. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  72. nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
  73. nat/profiler/parameter_optimization/parameter_selection.py +108 -0
  74. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  75. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  76. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  77. nat/profiler/utils.py +3 -1
  78. nat/tool/chat_completion.py +4 -1
  79. nat/tool/github_tools.py +450 -0
  80. nat/tool/register.py +2 -7
  81. nat/utils/callable_utils.py +70 -0
  82. nat/utils/exception_handlers/automatic_retries.py +103 -48
  83. nat/utils/type_utils.py +4 -0
  84. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
  85. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +91 -71
  86. nat/observability/processor/header_redaction_processor.py +0 -123
  87. nat/observability/processor/redaction_processor.py +0 -77
  88. nat/tool/github_tools/create_github_commit.py +0 -133
  89. nat/tool/github_tools/create_github_issue.py +0 -87
  90. nat/tool/github_tools/create_github_pr.py +0 -106
  91. nat/tool/github_tools/get_github_file.py +0 -106
  92. nat/tool/github_tools/get_github_issue.py +0 -166
  93. nat/tool/github_tools/get_github_pr.py +0 -256
  94. nat/tool/github_tools/update_github_issue.py +0 -100
  95. nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  96. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
  97. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
  98. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  99. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
  100. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
@@ -13,171 +13,465 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import inspect
17
+ import logging
16
18
  from functools import lru_cache
17
19
  from typing import Any
20
+ from typing import TypeVar
18
21
  from typing import get_args
19
22
  from typing import get_origin
20
23
 
24
+ from pydantic import BaseModel
25
+ from pydantic import ValidationError
26
+ from pydantic import create_model
27
+ from pydantic.fields import FieldInfo
28
+
29
+ from nat.utils.type_utils import DecomposedType
30
+
31
+ logger = logging.getLogger(__name__)
32
+
21
33
 
22
34
  class TypeIntrospectionMixin:
23
- """Mixin class providing type introspection capabilities for generic classes.
35
+ """Hybrid mixin class providing type introspection capabilities for generic classes.
24
36
 
25
- This mixin extracts type information from generic class definitions,
26
- allowing classes to determine their InputT and OutputT types at runtime.
37
+ This mixin combines the DecomposedType class utilities with MRO traversal
38
+ to properly handle complex inheritance chains like HeaderRedactionProcessor or ProcessingExporter.
27
39
  """
28
40
 
29
- def _find_generic_types(self) -> tuple[type[Any], type[Any]] | None:
41
+ def _extract_types_from_signature_method(self) -> tuple[type[Any], type[Any]] | None:
42
+ """Extract input/output types from the signature method.
43
+
44
+ This method looks for a signature method (either defined via _signature_method class
45
+ attribute or discovered generically) and extracts input/output types from
46
+ its method signature.
47
+
48
+ Returns:
49
+ tuple[type[Any], type[Any]] | None: (input_type, output_type) or None if not found.
30
50
  """
31
- Recursively search through the inheritance hierarchy to find generic type parameters.
51
+ # First, try to get the signature method name from the class
52
+ signature_method_name = getattr(self.__class__, '_signature_method', None)
53
+
54
+ # If not defined, try to discover it generically
55
+ if not signature_method_name:
56
+ signature_method_name = self._discover_signature_method()
57
+
58
+ if not signature_method_name:
59
+ return None
60
+
61
+ # Get the method and inspect its signature
62
+ try:
63
+ method = getattr(self, signature_method_name)
64
+ sig = inspect.signature(method)
65
+
66
+ # Find the first parameter that's not 'self'
67
+ params = list(sig.parameters.values())
68
+ input_param = None
69
+ for param in params:
70
+ if param.name != 'self':
71
+ input_param = param
72
+ break
73
+
74
+ if not input_param or input_param.annotation == inspect.Parameter.empty:
75
+ return None
76
+
77
+ # Get return type
78
+ return_annotation = sig.return_annotation
79
+ if return_annotation == inspect.Signature.empty:
80
+ return None
81
+
82
+ input_type = input_param.annotation
83
+ output_type = return_annotation
84
+
85
+ # Resolve any TypeVars if needed (including nested ones)
86
+ if isinstance(input_type, TypeVar) or isinstance(
87
+ output_type, TypeVar) or self._contains_typevar(input_type) or self._contains_typevar(output_type):
88
+ # Try to resolve using the MRO approach as fallback
89
+ typevar_mapping = self._build_typevar_mapping()
90
+ input_type = self._resolve_typevar_recursively(input_type, typevar_mapping)
91
+ output_type = self._resolve_typevar_recursively(output_type, typevar_mapping)
92
+
93
+ # Only return if we have concrete types
94
+ if not isinstance(input_type, TypeVar) and not isinstance(output_type, TypeVar):
95
+ return input_type, output_type
96
+
97
+ except (AttributeError, TypeError) as e:
98
+ logger.debug("Failed to extract types from signature method '%s': %s", signature_method_name, e)
32
99
 
33
- This method handles cases where a class inherits from a generic parent class,
34
- resolving the concrete types through the inheritance chain.
100
+ return None
101
+
102
+ def _discover_signature_method(self) -> str | None:
103
+ """Discover any method suitable for type introspection.
104
+
105
+ Looks for any method with the signature pattern: method(self, param: Type) -> ReturnType
106
+ Any method matching this pattern is functionally equivalent for type introspection purposes.
35
107
 
36
108
  Returns:
37
- tuple[type[Any], type[Any]] | None: (input_type, output_type) if found, None otherwise
109
+ str | None: Method name or None if not found
38
110
  """
39
- # First, try to find types directly in this class's __orig_bases__
40
- for base_cls in getattr(self.__class__, '__orig_bases__', []):
41
- base_cls_args = get_args(base_cls)
111
+ # Look through all methods to find ones that match the input/output pattern
112
+ candidates = []
42
113
 
43
- # Direct case: MyClass[InputT, OutputT]
44
- if len(base_cls_args) >= 2:
45
- return base_cls_args[0], base_cls_args[1]
114
+ for cls in self.__class__.__mro__:
115
+ for name, method in inspect.getmembers(cls, inspect.isfunction):
116
+ # Skip private methods except dunder methods
117
+ if name.startswith('_') and not name.startswith('__'):
118
+ continue
46
119
 
47
- # Indirect case: MyClass[SomeGeneric[ConcreteType]]
48
- # Need to resolve the generic parent's types
49
- if len(base_cls_args) == 1:
50
- base_origin = get_origin(base_cls)
51
- if base_origin and hasattr(base_origin, '__orig_bases__'):
52
- # Look at the parent's generic definition
53
- for parent_base in getattr(base_origin, '__orig_bases__', []):
54
- parent_args = get_args(parent_base)
55
- if len(parent_args) >= 2:
56
- # Found the pattern: ParentClass[T, list[T]]
57
- # Substitute T with our concrete type
58
- concrete_type = base_cls_args[0]
59
- input_type = self._substitute_type_var(parent_args[0], concrete_type)
60
- output_type = self._substitute_type_var(parent_args[1], concrete_type)
61
- return input_type, output_type
120
+ # Skip methods that were defined in TypeIntrospectionMixin
121
+ if hasattr(method, '__qualname__') and 'TypeIntrospectionMixin' in method.__qualname__:
122
+ logger.debug("Skipping method '%s' defined in TypeIntrospectionMixin", name)
123
+ continue
62
124
 
63
- return None
125
+ # Let signature analysis determine suitability - method names don't matter
126
+ try:
127
+ sig = inspect.signature(method)
128
+ params = list(sig.parameters.values())
129
+
130
+ # Look for methods with exactly one non-self parameter and a return annotation
131
+ non_self_params = [p for p in params if p.name != 'self']
132
+ if (len(non_self_params) == 1 and non_self_params[0].annotation != inspect.Parameter.empty
133
+ and sig.return_annotation != inspect.Signature.empty):
134
+
135
+ # Prioritize abstract methods
136
+ is_abstract = getattr(method, '__isabstractmethod__', False)
137
+ candidates.append((name, is_abstract, cls))
138
+
139
+ except (TypeError, ValueError) as e:
140
+ logger.debug("Failed to inspect signature of method '%s': %s", name, e)
64
141
 
65
- def _substitute_type_var(self, type_expr: Any, concrete_type: type) -> type[Any]:
142
+ if not candidates:
143
+ logger.debug("No candidates found for signature method")
144
+ return None
145
+
146
+ # Any method with the right signature will work for type introspection
147
+ # Prioritize abstract methods if available, otherwise use the first valid one
148
+ candidates.sort(key=lambda x: not x[1]) # Abstract methods first
149
+ return candidates[0][0]
150
+
151
+ def _resolve_typevar_recursively(self, type_arg: Any, typevar_mapping: dict[TypeVar, type[Any]]) -> Any:
152
+ """Recursively resolve TypeVars within complex types.
153
+
154
+ Args:
155
+ type_arg (Any): The type argument to resolve (could be a TypeVar, generic type, etc.)
156
+ typevar_mapping (dict[TypeVar, type[Any]]): Current mapping of TypeVars to concrete types
157
+
158
+ Returns:
159
+ Any: The resolved type with all TypeVars substituted
66
160
  """
67
- Substitute TypeVar in a type expression with a concrete type.
161
+ # If it's a TypeVar, resolve it
162
+ if isinstance(type_arg, TypeVar):
163
+ return typevar_mapping.get(type_arg, type_arg)
164
+
165
+ # If it's a generic type, decompose and resolve its arguments
166
+ try:
167
+ decomposed = DecomposedType(type_arg)
168
+ if decomposed.is_generic and decomposed.args:
169
+ # Recursively resolve all type arguments
170
+ resolved_args = []
171
+ for arg in decomposed.args:
172
+ resolved_arg = self._resolve_typevar_recursively(arg, typevar_mapping)
173
+ resolved_args.append(resolved_arg)
174
+
175
+ # Reconstruct the generic type with resolved arguments
176
+ if decomposed.origin:
177
+ return decomposed.origin[tuple(resolved_args)]
178
+
179
+ except (TypeError, AttributeError) as e:
180
+ # If we can't decompose or reconstruct, return as-is
181
+ logger.debug("Failed to decompose or reconstruct type '%s': %s", type_arg, e)
182
+
183
+ return type_arg
184
+
185
+ def _contains_typevar(self, type_arg: Any) -> bool:
186
+ """Check if a type contains any TypeVars (including nested ones).
68
187
 
69
188
  Args:
70
- type_expr: The type expression potentially containing TypeVars
71
- concrete_type: The concrete type to substitute
189
+ type_arg (Any): The type to check
72
190
 
73
191
  Returns:
74
- The type expression with TypeVars substituted
192
+ bool: True if the type contains any TypeVars
75
193
  """
76
- from typing import TypeVar
194
+ if isinstance(type_arg, TypeVar):
195
+ return True
77
196
 
78
- # If it's a TypeVar, substitute it
79
- if isinstance(type_expr, TypeVar):
80
- return concrete_type
197
+ try:
198
+ decomposed = DecomposedType(type_arg)
199
+ if decomposed.is_generic and decomposed.args:
200
+ return any(self._contains_typevar(arg) for arg in decomposed.args)
201
+ except (TypeError, AttributeError) as e:
202
+ logger.debug("Failed to decompose or reconstruct type '%s': %s", type_arg, e)
81
203
 
82
- # If it's a generic type like list[T], substitute the args
83
- origin = get_origin(type_expr)
84
- args = get_args(type_expr)
204
+ return False
85
205
 
86
- if origin and args:
87
- # Recursively substitute in the arguments
88
- new_args = tuple(self._substitute_type_var(arg, concrete_type) for arg in args)
89
- # Reconstruct the generic type
90
- return origin[new_args]
206
+ def _build_typevar_mapping(self) -> dict[TypeVar, type[Any]]:
207
+ """Build TypeVar to concrete type mapping from MRO traversal.
91
208
 
92
- # Otherwise, return as-is
93
- return type_expr
209
+ Returns:
210
+ dict[TypeVar, type[Any]]: Mapping of TypeVars to concrete types
211
+ """
212
+ typevar_mapping = {}
213
+
214
+ # First, check if the instance has concrete type arguments from __orig_class__
215
+ # This handles cases like BatchingProcessor[str]() where we need to map T -> str
216
+ orig_class = getattr(self, '__orig_class__', None)
217
+ if orig_class:
218
+ class_origin = get_origin(orig_class)
219
+ class_args = get_args(orig_class)
220
+ class_params = getattr(class_origin, '__parameters__', None)
221
+
222
+ if class_args and class_params:
223
+ # Map class-level TypeVars to their concrete arguments
224
+ for param, arg in zip(class_params, class_args):
225
+ typevar_mapping[param] = arg
226
+
227
+ # Then traverse the MRO to build the complete mapping
228
+ for cls in self.__class__.__mro__:
229
+ for base in getattr(cls, '__orig_bases__', []):
230
+ decomposed_base = DecomposedType(base)
231
+
232
+ if (decomposed_base.is_generic and decomposed_base.origin
233
+ and hasattr(decomposed_base.origin, '__parameters__')):
234
+ type_params = decomposed_base.origin.__parameters__
235
+ # Map each TypeVar to its concrete argument
236
+ for param, arg in zip(type_params, decomposed_base.args):
237
+ if param not in typevar_mapping: # Keep the most specific mapping
238
+ # If arg is also a TypeVar, try to resolve it
239
+ if isinstance(arg, TypeVar) and arg in typevar_mapping:
240
+ typevar_mapping[param] = typevar_mapping[arg]
241
+ else:
242
+ typevar_mapping[param] = arg
243
+
244
+ return typevar_mapping
245
+
246
+ def _extract_instance_types_from_mro(self) -> tuple[type[Any], type[Any]] | None:
247
+ """Extract Generic[InputT, OutputT] types by traversing the MRO.
248
+
249
+ This handles complex inheritance chains by looking for the base
250
+ class and resolving TypeVars through the inheritance hierarchy.
94
251
 
95
- @property
96
- @lru_cache
97
- def input_type(self) -> type[Any]:
252
+ Returns:
253
+ tuple[type[Any], type[Any]] | None: (input_type, output_type) or None if not found
98
254
  """
99
- Get the input type of the class. The input type is determined by the generic parameters of the class.
255
+ # Use the centralized TypeVar mapping
256
+ typevar_mapping = self._build_typevar_mapping()
100
257
 
101
- For example, if a class is defined as `MyClass[list[int], str]`, the `input_type` is `list[int]`.
258
+ # Now find the first generic base with exactly 2 parameters, starting from the base classes
259
+ # This ensures we get the fundamental input/output types rather than specialized ones
260
+ for cls in reversed(self.__class__.__mro__):
261
+ for base in getattr(cls, '__orig_bases__', []):
262
+ decomposed_base = DecomposedType(base)
102
263
 
103
- Returns
104
- -------
105
- type[Any]
106
- The input type specified in the generic parameters
264
+ # Look for any generic with exactly 2 parameters (likely InputT, OutputT pattern)
265
+ if decomposed_base.is_generic and len(decomposed_base.args) == 2:
266
+ input_type = decomposed_base.args[0]
267
+ output_type = decomposed_base.args[1]
107
268
 
108
- Raises
109
- ------
110
- ValueError
111
- If the input type cannot be determined from the class definition
269
+ # Resolve TypeVars to concrete types using recursive resolution
270
+ input_type = self._resolve_typevar_recursively(input_type, typevar_mapping)
271
+ output_type = self._resolve_typevar_recursively(output_type, typevar_mapping)
272
+
273
+ # Only return if we have concrete types (not TypeVars)
274
+ if not isinstance(input_type, TypeVar) and not isinstance(output_type, TypeVar):
275
+ return input_type, output_type
276
+
277
+ return None
278
+
279
+ @lru_cache
280
+ def _extract_input_output_types(self) -> tuple[type[Any], type[Any]]:
281
+ """Extract both input and output types using available approaches.
282
+
283
+ Returns:
284
+ tuple[type[Any], type[Any]]: (input_type, output_type)
285
+
286
+ Raises:
287
+ ValueError: If types cannot be extracted
112
288
  """
113
- types = self._find_generic_types()
114
- if types:
115
- return types[0]
289
+ # First try the signature-based approach
290
+ result = self._extract_types_from_signature_method()
291
+ if result:
292
+ return result
293
+
294
+ # Fallback to MRO-based approach for complex inheritance
295
+ result = self._extract_instance_types_from_mro()
296
+ if result:
297
+ return result
116
298
 
117
- raise ValueError(f"Could not find input type for {self.__class__.__name__}")
299
+ raise ValueError(f"Could not extract input/output types from {self.__class__.__name__}. "
300
+ f"Ensure class inherits from a generic like Processor[InputT, OutputT] "
301
+ f"or has a signature method with type annotations")
302
+
303
+ @property
304
+ def input_type(self) -> type[Any]:
305
+ """Get the input type of the instance.
306
+
307
+ Returns:
308
+ type[Any]: The input type
309
+ """
310
+ return self._extract_input_output_types()[0]
118
311
 
119
312
  @property
120
- @lru_cache
121
313
  def output_type(self) -> type[Any]:
314
+ """Get the output type of the instance.
315
+
316
+ Returns:
317
+ type[Any]: The output type
122
318
  """
123
- Get the output type of the class. The output type is determined by the generic parameters of the class.
319
+ return self._extract_input_output_types()[1]
124
320
 
125
- For example, if a class is defined as `MyClass[list[int], str]`, the `output_type` is `str`.
321
+ @lru_cache
322
+ def _get_union_info(self, type_obj: type[Any]) -> tuple[bool, tuple[type, ...] | None]:
323
+ """Get union information for a type.
126
324
 
127
- Returns
128
- -------
129
- type[Any]
130
- The output type specified in the generic parameters
325
+ Args:
326
+ type_obj (type[Any]): The type to analyze
131
327
 
132
- Raises
133
- ------
134
- ValueError
135
- If the output type cannot be determined from the class definition
328
+ Returns:
329
+ tuple[bool, tuple[type, ...] | None]: (is_union, union_types_or_none)
136
330
  """
137
- types = self._find_generic_types()
138
- if types:
139
- return types[1]
140
-
141
- raise ValueError(f"Could not find output type for {self.__class__.__name__}")
331
+ decomposed = DecomposedType(type_obj)
332
+ return decomposed.is_union, decomposed.args if decomposed.is_union else None
142
333
 
143
334
  @property
144
- @lru_cache
145
- def input_class(self) -> type:
335
+ def has_union_input(self) -> bool:
336
+ """Check if the input type is a union type.
337
+
338
+ Returns:
339
+ bool: True if the input type is a union type, False otherwise
146
340
  """
147
- Get the python class of the input type. This is the class that can be used to check if a value is an
148
- instance of the input type. It removes any generic or annotation information from the input type.
341
+ return self._get_union_info(self.input_type)[0]
149
342
 
150
- For example, if the input type is `list[int]`, the `input_class` is `list`.
343
+ @property
344
+ def has_union_output(self) -> bool:
345
+ """Check if the output type is a union type.
151
346
 
152
- Returns
153
- -------
154
- type
155
- The python type of the input type
347
+ Returns:
348
+ bool: True if the output type is a union type, False otherwise
156
349
  """
157
- input_origin = get_origin(self.input_type)
350
+ return self._get_union_info(self.output_type)[0]
158
351
 
159
- if input_origin is None:
160
- return self.input_type
352
+ @property
353
+ def input_union_types(self) -> tuple[type, ...] | None:
354
+ """Get the individual types in an input union.
161
355
 
162
- return input_origin
356
+ Returns:
357
+ tuple[type, ...] | None: The individual types in an input union or None if not found
358
+ """
359
+ return self._get_union_info(self.input_type)[1]
163
360
 
164
361
  @property
362
+ def output_union_types(self) -> tuple[type, ...] | None:
363
+ """Get the individual types in an output union.
364
+
365
+ Returns:
366
+ tuple[type, ...] | None: The individual types in an output union or None if not found
367
+ """
368
+ return self._get_union_info(self.output_type)[1]
369
+
370
+ def is_compatible_with_input(self, source_type: type) -> bool:
371
+ """Check if a source type is compatible with this instance's input type.
372
+
373
+ Uses Pydantic-based type compatibility checking for strict type matching.
374
+ This focuses on proper type relationships rather than batch compatibility.
375
+
376
+ Args:
377
+ source_type (type): The source type to check
378
+
379
+ Returns:
380
+ bool: True if the source type is compatible with the input type, False otherwise
381
+ """
382
+ return self._is_pydantic_type_compatible(source_type, self.input_type)
383
+
384
+ def is_output_compatible_with(self, target_type: type) -> bool:
385
+ """Check if this instance's output type is compatible with a target type.
386
+
387
+ Uses Pydantic-based type compatibility checking for strict type matching.
388
+ This focuses on proper type relationships rather than batch compatibility.
389
+
390
+ Args:
391
+ target_type (type): The target type to check
392
+
393
+ Returns:
394
+ bool: True if the output type is compatible with the target type, False otherwise
395
+ """
396
+ return self._is_pydantic_type_compatible(self.output_type, target_type)
397
+
398
+ def _is_pydantic_type_compatible(self, source_type: type, target_type: type) -> bool:
399
+ """Check strict type compatibility without batch compatibility hacks.
400
+
401
+ This focuses on proper type relationships: exact matches and subclass relationships.
402
+
403
+ Args:
404
+ source_type (type): The source type to check
405
+ target_type (type): The target type to check compatibility with
406
+
407
+ Returns:
408
+ bool: True if types are compatible, False otherwise
409
+ """
410
+ # Direct equality check (most common case)
411
+ if source_type == target_type:
412
+ return True
413
+
414
+ # Subclass relationship check
415
+ try:
416
+ if issubclass(source_type, target_type):
417
+ return True
418
+ except TypeError:
419
+ # Generic types can't use issubclass, they're only compatible if equal
420
+ logger.debug("Generic type %s cannot be used with issubclass, they're only compatible if equal",
421
+ source_type)
422
+
423
+ return False
424
+
165
425
  @lru_cache
166
- def output_class(self) -> type:
426
+ def _get_input_validator(self) -> type[BaseModel]:
427
+ """Create a Pydantic model for validating input types.
428
+
429
+ Returns:
430
+ type[BaseModel]: The Pydantic model for validating input types
167
431
  """
168
- Get the python class of the output type. This is the class that can be used to check if a value is an
169
- instance of the output type. It removes any generic or annotation information from the output type.
432
+ input_type = self.input_type
433
+ return create_model(f"{self.__class__.__name__}InputValidator", input=(input_type, FieldInfo()))
170
434
 
171
- For example, if the output type is `list[int]`, the `output_class` is `list`.
435
+ @lru_cache
436
+ def _get_output_validator(self) -> type[BaseModel]:
437
+ """Create a Pydantic model for validating output types.
438
+
439
+ Returns:
440
+ type[BaseModel]: The Pydantic model for validating output types
441
+ """
442
+ output_type = self.output_type
443
+ return create_model(f"{self.__class__.__name__}OutputValidator", output=(output_type, FieldInfo()))
444
+
445
+ def validate_input_type(self, item: Any) -> bool:
446
+ """Validate that an item matches the expected input type using Pydantic.
447
+
448
+ Args:
449
+ item (Any): The item to validate
172
450
 
173
- Returns
174
- -------
175
- type
176
- The python type of the output type
451
+ Returns:
452
+ bool: True if the item matches the input type, False otherwise
177
453
  """
178
- output_origin = get_origin(self.output_type)
454
+ try:
455
+ validator = self._get_input_validator()
456
+ validator(input=item)
457
+ return True
458
+ except ValidationError:
459
+ logger.warning("Item %s is not compatible with input type %s", item, self.input_type)
460
+ return False
461
+
462
+ def validate_output_type(self, item: Any) -> bool:
463
+ """Validate that an item matches the expected output type using Pydantic.
179
464
 
180
- if output_origin is None:
181
- return self.output_type
465
+ Args:
466
+ item (Any): The item to validate
182
467
 
183
- return output_origin
468
+ Returns:
469
+ bool: True if the item matches the output type, False otherwise
470
+ """
471
+ try:
472
+ validator = self._get_output_validator()
473
+ validator(output=item)
474
+ return True
475
+ except ValidationError:
476
+ logger.warning("Item %s is not compatible with output type %s", item, self.output_type)
477
+ return False
@@ -58,6 +58,9 @@ class Processor(Generic[InputT, OutputT], TypeIntrospectionMixin, ABC):
58
58
  chained processors.
59
59
  """
60
60
 
61
+ # All processors automatically use this for signature checking
62
+ _signature_method = 'process'
63
+
61
64
  @abstractmethod
62
65
  async def process(self, item: InputT) -> OutputT:
63
66
  """Process an item and return a potentially different type.
@@ -0,0 +1,24 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .redaction_processor import RedactionContext
17
+ from .redaction_processor import RedactionContextState
18
+ from .span_header_redaction_processor import SpanHeaderRedactionProcessor
19
+
20
+ __all__ = [
21
+ "SpanHeaderRedactionProcessor",
22
+ "RedactionContext",
23
+ "RedactionContextState",
24
+ ]