nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +50 -22
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +54 -27
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +68 -17
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +62 -22
- nat/cli/entrypoint.py +8 -10
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +74 -66
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +19 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +106 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
nat/observability/mixin/type_introspection_mixin.py
@@ -13,171 +13,484 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
+import logging
+import types
 from functools import lru_cache
 from typing import Any
+from typing import TypeVar
 from typing import get_args
 from typing import get_origin

+from pydantic import BaseModel
+from pydantic import ValidationError
+from pydantic import create_model
+from pydantic.fields import FieldInfo
+
+from nat.utils.type_utils import DecomposedType
+
+logger = logging.getLogger(__name__)
+

 class TypeIntrospectionMixin:
-    """
+    """Hybrid mixin class providing type introspection capabilities for generic classes.

-    This mixin
-
+    This mixin combines the DecomposedType class utilities with MRO traversal
+    to properly handle complex inheritance chains like HeaderRedactionProcessor or ProcessingExporter.
     """

-    def
+    def _extract_types_from_signature_method(self) -> tuple[type[Any], type[Any]] | None:
+        """Extract input/output types from the signature method.
+
+        This method looks for a signature method (either defined via _signature_method class
+        attribute or discovered generically) and extracts input/output types from
+        its method signature.
+
+        Returns:
+            tuple[type[Any], type[Any]] | None: (input_type, output_type) or None if not found.
         """
-
+        # First, try to get the signature method name from the class
+        signature_method_name = getattr(self.__class__, '_signature_method', None)
+
+        # If not defined, try to discover it generically
+        if not signature_method_name:
+            signature_method_name = self._discover_signature_method()
+
+        if not signature_method_name:
+            return None
+
+        # Get the method and inspect its signature
+        try:
+            method = getattr(self, signature_method_name)
+            sig = inspect.signature(method)
+
+            # Find the first parameter that's not 'self'
+            params = list(sig.parameters.values())
+            input_param = None
+            for param in params:
+                if param.name != 'self':
+                    input_param = param
+                    break
+
+            if not input_param or input_param.annotation == inspect.Parameter.empty:
+                return None
+
+            # Get return type
+            return_annotation = sig.return_annotation
+            if return_annotation == inspect.Signature.empty:
+                return None
+
+            input_type = input_param.annotation
+            output_type = return_annotation
+
+            # Resolve any TypeVars if needed (including nested ones)
+            if isinstance(input_type, TypeVar) or isinstance(
+                    output_type, TypeVar) or self._contains_typevar(input_type) or self._contains_typevar(output_type):
+                # Try to resolve using the MRO approach as fallback
+                typevar_mapping = self._build_typevar_mapping()
+                input_type = self._resolve_typevar_recursively(input_type, typevar_mapping)
+                output_type = self._resolve_typevar_recursively(output_type, typevar_mapping)
+
+            # Only return if we have concrete types
+            if not isinstance(input_type, TypeVar) and not isinstance(output_type, TypeVar):
+                return input_type, output_type
+
+        except (AttributeError, TypeError) as e:
+            logger.debug("Failed to extract types from signature method '%s': %s", signature_method_name, e)
+
+        return None
+
+    def _discover_signature_method(self) -> str | None:
+        """Discover any method suitable for type introspection.

-
-
+        Looks for any method with the signature pattern: method(self, param: Type) -> ReturnType
+        Any method matching this pattern is functionally equivalent for type introspection purposes.

         Returns:
-
+            str | None: Method name or None if not found
         """
-        #
-
-        base_cls_args = get_args(base_cls)
+        # Look through all methods to find ones that match the input/output pattern
+        candidates = []

-
-
-
+        for cls in self.__class__.__mro__:
+            for name, method in inspect.getmembers(cls, inspect.isfunction):
+                # Skip private methods except dunder methods
+                if name.startswith('_') and not name.startswith('__'):
+                    continue

-
-
-
-
-        if base_origin and hasattr(base_origin, '__orig_bases__'):
-            # Look at the parent's generic definition
-            for parent_base in getattr(base_origin, '__orig_bases__', []):
-                parent_args = get_args(parent_base)
-                if len(parent_args) >= 2:
-                    # Found the pattern: ParentClass[T, list[T]]
-                    # Substitute T with our concrete type
-                    concrete_type = base_cls_args[0]
-                    input_type = self._substitute_type_var(parent_args[0], concrete_type)
-                    output_type = self._substitute_type_var(parent_args[1], concrete_type)
-                    return input_type, output_type
+                # Skip methods that were defined in TypeIntrospectionMixin
+                if hasattr(method, '__qualname__') and 'TypeIntrospectionMixin' in method.__qualname__:
+                    logger.debug("Skipping method '%s' defined in TypeIntrospectionMixin", name)
+                    continue

-
+                # Let signature analysis determine suitability - method names don't matter
+                try:
+                    sig = inspect.signature(method)
+                    params = list(sig.parameters.values())
+
+                    # Look for methods with exactly one non-self parameter and a return annotation
+                    non_self_params = [p for p in params if p.name != 'self']
+                    if (len(non_self_params) == 1 and non_self_params[0].annotation != inspect.Parameter.empty
+                            and sig.return_annotation != inspect.Signature.empty):
+
+                        # Prioritize abstract methods
+                        is_abstract = getattr(method, '__isabstractmethod__', False)
+                        candidates.append((name, is_abstract, cls))
+
+                except (TypeError, ValueError) as e:
+                    logger.debug("Failed to inspect signature of method '%s': %s", name, e)
+
+        if not candidates:
+            logger.debug("No candidates found for signature method")
+            return None

-
+        # Any method with the right signature will work for type introspection
+        # Prioritize abstract methods if available, otherwise use the first valid one
+        candidates.sort(key=lambda x: not x[1])  # Abstract methods first
+        return candidates[0][0]
+
+    def _resolve_typevar_recursively(self, type_arg: Any, typevar_mapping: dict[TypeVar, type[Any]]) -> Any:
+        """Recursively resolve TypeVars within complex types.
+
+        Args:
+            type_arg (Any): The type argument to resolve (could be a TypeVar, generic type, etc.)
+            typevar_mapping (dict[TypeVar, type[Any]]): Current mapping of TypeVars to concrete types
+
+        Returns:
+            Any: The resolved type with all TypeVars substituted
         """
-
+        # If it's a TypeVar, resolve it
+        if isinstance(type_arg, TypeVar):
+            return typevar_mapping.get(type_arg, type_arg)
+
+        # If it's a generic type, decompose and resolve its arguments
+        try:
+            decomposed = DecomposedType(type_arg)
+            if decomposed.is_generic and decomposed.args:
+                # Recursively resolve all type arguments
+                resolved_args = []
+                for arg in decomposed.args:
+                    resolved_arg = self._resolve_typevar_recursively(arg, typevar_mapping)
+                    resolved_args.append(resolved_arg)
+
+                # Reconstruct the generic type with resolved arguments
+                if decomposed.origin:
+                    return decomposed.origin[tuple(resolved_args)]
+
+        except (TypeError, AttributeError) as e:
+            # If we can't decompose or reconstruct, return as-is
+            logger.debug("Failed to decompose or reconstruct type '%s': %s", type_arg, e)
+
+        return type_arg
+
+    def _contains_typevar(self, type_arg: Any) -> bool:
+        """Check if a type contains any TypeVars (including nested ones).

         Args:
-
-            concrete_type: The concrete type to substitute
+            type_arg (Any): The type to check

         Returns:
-
+            bool: True if the type contains any TypeVars
         """
-
+        if isinstance(type_arg, TypeVar):
+            return True

-
-
-
+        try:
+            decomposed = DecomposedType(type_arg)
+            if decomposed.is_generic and decomposed.args:
+                return any(self._contains_typevar(arg) for arg in decomposed.args)
+        except (TypeError, AttributeError) as e:
+            logger.debug("Failed to decompose or reconstruct type '%s': %s", type_arg, e)

-
-        origin = get_origin(type_expr)
-        args = get_args(type_expr)
+        return False

-
-
-        new_args = tuple(self._substitute_type_var(arg, concrete_type) for arg in args)
-        # Reconstruct the generic type
-        return origin[new_args]
+    def _build_typevar_mapping(self) -> dict[TypeVar, type[Any]]:
+        """Build TypeVar to concrete type mapping from MRO traversal.

-
-
+        Returns:
+            dict[TypeVar, type[Any]]: Mapping of TypeVars to concrete types
+        """
+        typevar_mapping = {}
+
+        # First, check if the instance has concrete type arguments from __orig_class__
+        # This handles cases like BatchingProcessor[str]() where we need to map T -> str
+        orig_class = getattr(self, '__orig_class__', None)
+        if orig_class:
+            class_origin = get_origin(orig_class)
+            class_args = get_args(orig_class)
+            class_params = getattr(class_origin, '__parameters__', None)
+
+            if class_args and class_params:
+                # Map class-level TypeVars to their concrete arguments
+                for param, arg in zip(class_params, class_args):
+                    typevar_mapping[param] = arg
+
+        # Then traverse the MRO to build the complete mapping
+        for cls in self.__class__.__mro__:
+            for base in getattr(cls, '__orig_bases__', []):
+                decomposed_base = DecomposedType(base)
+
+                if (decomposed_base.is_generic and decomposed_base.origin
+                        and hasattr(decomposed_base.origin, '__parameters__')):
+                    type_params = decomposed_base.origin.__parameters__
+                    # Map each TypeVar to its concrete argument
+                    for param, arg in zip(type_params, decomposed_base.args):
+                        if param not in typevar_mapping:  # Keep the most specific mapping
+                            # If arg is also a TypeVar, try to resolve it
+                            if isinstance(arg, TypeVar) and arg in typevar_mapping:
+                                typevar_mapping[param] = typevar_mapping[arg]
+                            else:
+                                typevar_mapping[param] = arg
+
+        return typevar_mapping
+
+    def _extract_instance_types_from_mro(self) -> tuple[type[Any], type[Any]] | None:
+        """Extract Generic[InputT, OutputT] types by traversing the MRO.
+
+        This handles complex inheritance chains by looking for the base
+        class and resolving TypeVars through the inheritance hierarchy.

-
-
-    def input_type(self) -> type[Any]:
+        Returns:
+            tuple[type[Any], type[Any]] | None: (input_type, output_type) or None if not found
         """
-
+        # Use the centralized TypeVar mapping
+        typevar_mapping = self._build_typevar_mapping()
+
+        # Now find the first generic base with exactly 2 parameters, starting from the base classes
+        # This ensures we get the fundamental input/output types rather than specialized ones
+        for cls in reversed(self.__class__.__mro__):
+            for base in getattr(cls, '__orig_bases__', []):
+                decomposed_base = DecomposedType(base)
+
+                # Look for any generic with exactly 2 parameters (likely InputT, OutputT pattern)
+                if decomposed_base.is_generic and len(decomposed_base.args) == 2:
+                    input_type = decomposed_base.args[0]
+                    output_type = decomposed_base.args[1]
+
+                    # Resolve TypeVars to concrete types using recursive resolution
+                    input_type = self._resolve_typevar_recursively(input_type, typevar_mapping)
+                    output_type = self._resolve_typevar_recursively(output_type, typevar_mapping)
+
+                    # Only return if we have concrete types (not TypeVars)
+                    if not isinstance(input_type, TypeVar) and not isinstance(output_type, TypeVar):
+                        return input_type, output_type
+
+        return None

-
+    @lru_cache
+    def _extract_input_output_types(self) -> tuple[type[Any], type[Any]]:
+        """Extract both input and output types using available approaches.

-        Returns
-
-        type[Any]
-            The input type specified in the generic parameters
+        Returns:
+            tuple[type[Any], type[Any]]: (input_type, output_type)

-        Raises
-
-        ValueError
-            If the input type cannot be determined from the class definition
+        Raises:
+            ValueError: If types cannot be extracted
         """
-
-
-
+        # First try the signature-based approach
+        result = self._extract_types_from_signature_method()
+        if result:
+            return result
+
+        # Fallback to MRO-based approach for complex inheritance
+        result = self._extract_instance_types_from_mro()
+        if result:
+            return result

-        raise ValueError(f"Could not
+        raise ValueError(f"Could not extract input/output types from {self.__class__.__name__}. "
+                         f"Ensure class inherits from a generic like Processor[InputT, OutputT] "
+                         f"or has a signature method with type annotations")
+
+    @property
+    def input_type(self) -> type[Any]:
+        """Get the input type of the instance.
+
+        Returns:
+            type[Any]: The input type
+        """
+        return self._extract_input_output_types()[0]

     @property
-    @lru_cache
     def output_type(self) -> type[Any]:
+        """Get the output type of the instance.
+
+        Returns:
+            type[Any]: The output type
         """
-
+        return self._extract_input_output_types()[1]

-
+    @lru_cache
+    def _get_union_info(self, type_obj: type[Any]) -> tuple[bool, tuple[type, ...] | None]:
+        """Get union information for a type.

-
-
-        type[Any]
-            The output type specified in the generic parameters
+        Args:
+            type_obj (type[Any]): The type to analyze

-
-
-        ValueError
-            If the output type cannot be determined from the class definition
+        Returns:
+            tuple[bool, tuple[type, ...] | None]: (is_union, union_types_or_none)
         """
-
-        if
-            return types[1]
-
-        raise ValueError(f"Could not find output type for {self.__class__.__name__}")
+        decomposed = DecomposedType(type_obj)
+        return decomposed.is_union, decomposed.args if decomposed.is_union else None

     @property
-
-
+    def has_union_input(self) -> bool:
+        """Check if the input type is a union type.
+
+        Returns:
+            bool: True if the input type is a union type, False otherwise
         """
-
-        instance of the input type. It removes any generic or annotation information from the input type.
+        return self._get_union_info(self.input_type)[0]

-
+    @property
+    def has_union_output(self) -> bool:
+        """Check if the output type is a union type.

-        Returns
-
-        type
-            The python type of the input type
+        Returns:
+            bool: True if the output type is a union type, False otherwise
         """
-
+        return self._get_union_info(self.output_type)[0]

-
-
+    @property
+    def input_union_types(self) -> tuple[type, ...] | None:
+        """Get the individual types in an input union.

-
+        Returns:
+            tuple[type, ...] | None: The individual types in an input union or None if not found
+        """
+        return self._get_union_info(self.input_type)[1]

     @property
+    def output_union_types(self) -> tuple[type, ...] | None:
+        """Get the individual types in an output union.
+
+        Returns:
+            tuple[type, ...] | None: The individual types in an output union or None if not found
+        """
+        return self._get_union_info(self.output_type)[1]
+
+    def is_compatible_with_input(self, source_type: type) -> bool:
+        """Check if a source type is compatible with this instance's input type.
+
+        Uses Pydantic-based type compatibility checking for strict type matching.
+        This focuses on proper type relationships rather than batch compatibility.
+
+        Args:
+            source_type (type): The source type to check
+
+        Returns:
+            bool: True if the source type is compatible with the input type, False otherwise
+        """
+        return self._is_pydantic_type_compatible(source_type, self.input_type)
+
+    def is_output_compatible_with(self, target_type: type) -> bool:
+        """Check if this instance's output type is compatible with a target type.
+
+        Uses Pydantic-based type compatibility checking for strict type matching.
+        This focuses on proper type relationships rather than batch compatibility.
+
+        Args:
+            target_type (type): The target type to check
+
+        Returns:
+            bool: True if the output type is compatible with the target type, False otherwise
+        """
+        return self._is_pydantic_type_compatible(self.output_type, target_type)
+
+    def _is_pydantic_type_compatible(self, source_type: type, target_type: type) -> bool:
+        """Check strict type compatibility without batch compatibility hacks.
+
+        This focuses on proper type relationships: exact matches and subclass relationships.
+
+        Args:
+            source_type (type): The source type to check
+            target_type (type): The target type to check compatibility with
+
+        Returns:
+            bool: True if types are compatible, False otherwise
+        """
+        # Direct equality check (most common case)
+        if source_type == target_type:
+            return True
+
+        # Subclass relationship check
+        try:
+            if issubclass(source_type, target_type):
+                return True
+        except TypeError:
+            # Generic types can't use issubclass, they're only compatible if equal
+            logger.debug("Generic type %s cannot be used with issubclass, they're only compatible if equal",
+                         source_type)
+
+        return False
+
     @lru_cache
-    def
+    def _get_input_validator(self) -> type[BaseModel]:
+        """Create a Pydantic model for validating input types.
+
+        Returns:
+            type[BaseModel]: The Pydantic model for validating input types
         """
-
-
+        input_type = self.input_type
+        return create_model(f"{self.__class__.__name__}InputValidator", input=(input_type, FieldInfo()))

-
+    @lru_cache
+    def _get_output_validator(self) -> type[BaseModel]:
+        """Create a Pydantic model for validating output types.

-        Returns
-
-        type
-            The python type of the output type
+        Returns:
+            type[BaseModel]: The Pydantic model for validating output types
         """
-
+        output_type = self.output_type
+        return create_model(f"{self.__class__.__name__}OutputValidator", output=(output_type, FieldInfo()))

-
-
+    def validate_input_type(self, item: Any) -> bool:
+        """Validate that an item matches the expected input type using Pydantic.
+
+        Args:
+            item (Any): The item to validate
+
+        Returns:
+            bool: True if the item matches the input type, False otherwise
+        """
+        try:
+            validator = self._get_input_validator()
+            validator(input=item)
+            return True
+        except ValidationError:
+            logger.warning("Item %s is not compatible with input type %s", item, self.input_type)
+            return False

-
+    def validate_output_type(self, item: Any) -> bool:
+        """Validate that an item matches the expected output type using Pydantic.
+
+        Args:
+            item (Any): The item to validate
+
+        Returns:
+            bool: True if the item matches the output type, False otherwise
+        """
+        try:
+            validator = self._get_output_validator()
+            validator(output=item)
+            return True
+        except ValidationError:
+            logger.warning("Item %s is not compatible with output type %s", item, self.output_type)
+            return False
+
+    @lru_cache
+    def extract_non_optional_type(self, type_obj: type | types.UnionType) -> Any:
+        """Extract the non-None type from Optional[T] or Union[T, None] types.
+
+        This is useful when you need to pass a type to a system that doesn't
+        understand Optional types (like registries that expect concrete types).
+
+        Args:
+            type_obj (type | types.UnionType): The type to extract from (could be Optional[T] or Union[T, None])
+
+        Returns:
+            Any: The actual type without None, or the original type if not a union with None
+        """
+        decomposed = DecomposedType(type_obj)  # type: ignore[arg-type]
+        if decomposed.is_optional:
+            return decomposed.get_optional_type().type
+        return type_obj
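The rewritten TypeIntrospectionMixin above replaces the old TypeVar-substitution logic with signature-based discovery plus an MRO fallback, and adds Pydantic-backed runtime checks. The sketch below is illustrative only: `Processor` and `IntToStrProcessor` are hypothetical stand-ins rather than classes shipped in this wheel, and the expected results are inferred from the method bodies visible in the diff, not from package documentation.

```python
# Illustrative sketch only; behaviour inferred from the diff above.
from typing import Generic, TypeVar

from nat.observability.mixin.type_introspection_mixin import TypeIntrospectionMixin

InputT = TypeVar("InputT")
OutputT = TypeVar("OutputT")


class Processor(TypeIntrospectionMixin, Generic[InputT, OutputT]):
    """Hypothetical generic base whose single annotated method drives type discovery."""

    def process(self, item: InputT) -> OutputT:
        raise NotImplementedError


class IntToStrProcessor(Processor[int, str]):

    def process(self, item: int) -> str:
        return str(item)


proc = IntToStrProcessor()
print(proc.input_type)                 # expected: <class 'int'>, taken from process()'s signature
print(proc.output_type)                # expected: <class 'str'>
print(proc.has_union_input)            # expected: False (int is not a union)
print(proc.validate_input_type(42))    # expected: True  (Pydantic validation succeeds)
print(proc.validate_input_type(None))  # expected: False (ValidationError caught, warning logged)
```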
nat/observability/processor/batching_processor.py
@@ -193,14 +193,14 @@ class BatchingProcessor(CallbackProcessor[T, list[T]], Generic[T]):
                         await self._done_callback(batch)
                         logger.debug("Scheduled flush routed batch of %d items through pipeline", len(batch))
                     except Exception as e:
-                        logger.
+                        logger.exception("Error routing scheduled batch through pipeline: %s", e)
                 else:
                     logger.warning("Scheduled flush created batch of %d items but no pipeline callback set",
                                    len(batch))
         except asyncio.CancelledError:
             pass
         except Exception as e:
-            logger.
+            logger.exception("Error in scheduled flush: %s", e)

     async def _create_batch(self) -> list[T]:
         """Create a batch from the current queue."""
@@ -241,7 +241,7 @@ class BatchingProcessor(CallbackProcessor[T, list[T]], Generic[T]):
         try:
             await asyncio.wait_for(self._shutdown_complete_event.wait(), timeout=self._shutdown_timeout)
             logger.debug("Shutdown completion detected via event")
-        except
+        except TimeoutError:
             logger.warning("Shutdown completion timeout exceeded (%s seconds)", self._shutdown_timeout)
             return

@@ -271,9 +271,7 @@ class BatchingProcessor(CallbackProcessor[T, list[T]], Generic[T]):
                         "Successfully flushed final batch of %d items through pipeline during shutdown",
                         len(final_batch))
                 except Exception as e:
-                    logger.
-                        e,
-                        exc_info=True)
+                    logger.exception("Error routing final batch through pipeline during shutdown: %s", e)
             else:
                 logger.warning("Final batch of %d items created during shutdown but no pipeline callback set",
                                len(final_batch))
@@ -285,7 +283,7 @@ class BatchingProcessor(CallbackProcessor[T, list[T]], Generic[T]):
             logger.debug("BatchingProcessor shutdown completed successfully")

         except Exception as e:
-            logger.
+            logger.exception("Error during BatchingProcessor shutdown: %s", e)
         self._shutdown_complete = True
         self._shutdown_complete_event.set()

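The BatchingProcessor hunks above swap the truncated `logger.` error calls (the removed lines in the third hunk still show the trailing `e,` and `exc_info=True)` arguments) for `logger.exception(...)`, and narrow the shutdown wait to `except TimeoutError:`, the builtin that `asyncio.wait_for` raises on current Python versions. As a reference point only, not code from the package, the two logging forms record the same ERROR-level message with the active traceback attached:

```python
import logging

logger = logging.getLogger(__name__)

try:
    raise RuntimeError("boom")
except Exception as e:
    # Older style, as suggested by the removed lines: explicit exc_info flag.
    logger.error("Error in scheduled flush: %s", e, exc_info=True)
    # New style, as in the added lines: logger.exception attaches the traceback automatically
    # and is intended to be called from inside an exception handler.
    logger.exception("Error in scheduled flush: %s", e)
```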