nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -13,56 +13,71 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import logging
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from datetime import UTC
|
|
16
19
|
from datetime import datetime
|
|
17
|
-
from datetime import timezone
|
|
18
20
|
|
|
21
|
+
import httpx
|
|
19
22
|
from authlib.integrations.httpx_client import OAuth2Client as AuthlibOAuth2Client
|
|
20
23
|
from pydantic import SecretStr
|
|
21
24
|
|
|
22
25
|
from nat.authentication.interfaces import AuthProviderBase
|
|
23
26
|
from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
|
|
24
27
|
from nat.builder.context import Context
|
|
28
|
+
from nat.data_models.authentication import AuthenticatedContext
|
|
25
29
|
from nat.data_models.authentication import AuthFlowType
|
|
26
30
|
from nat.data_models.authentication import AuthResult
|
|
27
31
|
from nat.data_models.authentication import BearerTokenCred
|
|
28
32
|
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
29
35
|
|
|
30
36
|
class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]):
|
|
31
37
|
|
|
32
38
|
def __init__(self, config: OAuth2AuthCodeFlowProviderConfig):
|
|
33
39
|
super().__init__(config)
|
|
34
40
|
self._authenticated_tokens: dict[str, AuthResult] = {}
|
|
35
|
-
self.
|
|
41
|
+
self._auth_callback = None
|
|
36
42
|
|
|
37
43
|
async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> AuthResult | None:
|
|
38
44
|
refresh_token = auth_result.raw.get("refresh_token")
|
|
39
45
|
if not isinstance(refresh_token, str):
|
|
40
46
|
return None
|
|
41
47
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
48
|
+
try:
|
|
49
|
+
with AuthlibOAuth2Client(
|
|
50
|
+
client_id=self.config.client_id,
|
|
51
|
+
client_secret=self.config.client_secret,
|
|
52
|
+
) as client:
|
|
47
53
|
new_token_data = client.refresh_token(self.config.token_url, refresh_token=refresh_token)
|
|
48
|
-
except Exception:
|
|
49
|
-
# On any failure, we'll fall back to the full auth flow.
|
|
50
|
-
return None
|
|
51
54
|
|
|
52
|
-
|
|
53
|
-
|
|
55
|
+
expires_at_ts = new_token_data.get("expires_at")
|
|
56
|
+
new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=UTC) if expires_at_ts else None
|
|
54
57
|
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
58
|
+
new_auth_result = AuthResult(
|
|
59
|
+
credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
|
|
60
|
+
token_expires_at=new_expires_at,
|
|
61
|
+
raw=new_token_data,
|
|
62
|
+
)
|
|
60
63
|
|
|
61
|
-
|
|
64
|
+
self._authenticated_tokens[user_id] = new_auth_result
|
|
65
|
+
except httpx.HTTPStatusError:
|
|
66
|
+
return None
|
|
67
|
+
except httpx.RequestError:
|
|
68
|
+
return None
|
|
69
|
+
except Exception:
|
|
70
|
+
# On any other failure, we'll fall back to the full auth flow.
|
|
71
|
+
return None
|
|
62
72
|
|
|
63
73
|
return new_auth_result
|
|
64
74
|
|
|
65
|
-
|
|
75
|
+
def _set_custom_auth_callback(self,
|
|
76
|
+
auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
|
|
77
|
+
AuthenticatedContext]):
|
|
78
|
+
self._auth_callback = auth_callback
|
|
79
|
+
|
|
80
|
+
async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
|
|
66
81
|
if user_id is None and hasattr(Context.get(), "metadata") and hasattr(
|
|
67
82
|
Context.get().metadata, "cookies") and Context.get().metadata.cookies is not None:
|
|
68
83
|
session_id = Context.get().metadata.cookies.get("nat-session", None)
|
|
@@ -80,7 +95,12 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
|
|
|
80
95
|
if refreshed_auth_result:
|
|
81
96
|
return refreshed_auth_result
|
|
82
97
|
|
|
83
|
-
|
|
98
|
+
# Try getting callback from the context if that's not set, use the default callback
|
|
99
|
+
try:
|
|
100
|
+
auth_callback = Context.get().user_auth_callback
|
|
101
|
+
except RuntimeError:
|
|
102
|
+
auth_callback = self._auth_callback
|
|
103
|
+
|
|
84
104
|
if not auth_callback:
|
|
85
105
|
raise RuntimeError("Authentication callback not set on Context.")
|
|
86
106
|
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from urllib.parse import urlparse
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
from pydantic import field_validator
|
|
20
|
+
from pydantic import model_validator
|
|
21
|
+
|
|
22
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OAuth2ResourceServerConfig(AuthProviderBaseConfig, name="oauth2_resource_server"):
|
|
26
|
+
"""OAuth 2.0 Resource Server authentication configuration.
|
|
27
|
+
|
|
28
|
+
Supports:
|
|
29
|
+
• JWT access tokens via JWKS / OIDC Discovery / issuer fallback
|
|
30
|
+
• Opaque access tokens via RFC 7662 introspection
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
issuer_url: str = Field(
|
|
34
|
+
description=("The unique issuer identifier for an authorization server. "
|
|
35
|
+
"Required for validation and used to derive the default JWKS URI "
|
|
36
|
+
"(<issuer_url>/.well-known/jwks.json) if `jwks_uri` and `discovery_url` are not provided."), )
|
|
37
|
+
scopes: list[str] = Field(
|
|
38
|
+
default_factory=list,
|
|
39
|
+
description="Scopes required by this API. Validation ensures the token grants all listed scopes.",
|
|
40
|
+
)
|
|
41
|
+
audience: str | None = Field(
|
|
42
|
+
default=None,
|
|
43
|
+
description=(
|
|
44
|
+
"Expected audience (`aud`) claim for this API. If set, validation will reject tokens without this audience."
|
|
45
|
+
),
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
# JWT verification params
|
|
49
|
+
jwks_uri: str | None = Field(
|
|
50
|
+
default=None,
|
|
51
|
+
description=("Direct JWKS endpoint URI for JWT signature verification. "
|
|
52
|
+
"Optional if discovery or issuer is provided."),
|
|
53
|
+
)
|
|
54
|
+
discovery_url: str | None = Field(
|
|
55
|
+
default=None,
|
|
56
|
+
description=("OIDC discovery metadata URL. Used to automatically resolve JWKS and introspection endpoints."),
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
# Opaque token (introspection) params
|
|
60
|
+
introspection_endpoint: str | None = Field(
|
|
61
|
+
default=None,
|
|
62
|
+
description=("RFC 7662 token introspection endpoint. "
|
|
63
|
+
"Required for opaque token validation and must be used with `client_id` and `client_secret`."),
|
|
64
|
+
)
|
|
65
|
+
client_id: str | None = Field(
|
|
66
|
+
default=None,
|
|
67
|
+
description="OAuth2 client ID for authenticating to the introspection endpoint (opaque token validation).",
|
|
68
|
+
)
|
|
69
|
+
client_secret: str | None = Field(
|
|
70
|
+
default=None,
|
|
71
|
+
description="OAuth2 client secret for authenticating to the introspection endpoint (opaque token validation).",
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
@staticmethod
|
|
75
|
+
def _is_https_or_localhost(url: str) -> bool:
|
|
76
|
+
try:
|
|
77
|
+
value = urlparse(url)
|
|
78
|
+
if not value.scheme or not value.netloc:
|
|
79
|
+
return False
|
|
80
|
+
if value.scheme == "https":
|
|
81
|
+
return True
|
|
82
|
+
return value.scheme == "http" and (value.hostname in {"localhost", "127.0.0.1", "::1"})
|
|
83
|
+
except Exception:
|
|
84
|
+
return False
|
|
85
|
+
|
|
86
|
+
@field_validator("issuer_url", "jwks_uri", "discovery_url", "introspection_endpoint")
|
|
87
|
+
@classmethod
|
|
88
|
+
def _require_valid_url(cls, value: str | None, info):
|
|
89
|
+
if value is None:
|
|
90
|
+
return value
|
|
91
|
+
if not cls._is_https_or_localhost(value):
|
|
92
|
+
raise ValueError(f"{info.field_name} must be HTTPS (http allowed only for localhost). Got: {value}")
|
|
93
|
+
return value
|
|
94
|
+
|
|
95
|
+
# ---------- Cross-field validation: ensure at least one viable path ----------
|
|
96
|
+
|
|
97
|
+
@model_validator(mode="after")
|
|
98
|
+
def _ensure_verification_path(self):
|
|
99
|
+
"""
|
|
100
|
+
JWT path viable if any of: jwks_uri OR discovery_url OR issuer_url (fallback JWKS).
|
|
101
|
+
Opaque path viable if: introspection_endpoint AND client_id AND client_secret.
|
|
102
|
+
"""
|
|
103
|
+
has_jwt_path = bool(self.jwks_uri or self.discovery_url or self.issuer_url)
|
|
104
|
+
has_opaque_path = bool(self.introspection_endpoint and self.client_id and self.client_secret)
|
|
105
|
+
|
|
106
|
+
# If introspection endpoint is set, enforce creds are present
|
|
107
|
+
if self.introspection_endpoint:
|
|
108
|
+
missing = []
|
|
109
|
+
if not self.client_id:
|
|
110
|
+
missing.append("client_id")
|
|
111
|
+
if not self.client_secret:
|
|
112
|
+
missing.append("client_secret")
|
|
113
|
+
if missing:
|
|
114
|
+
raise ValueError(
|
|
115
|
+
f"introspection_endpoint configured but missing required credentials: {', '.join(missing)}")
|
|
116
|
+
|
|
117
|
+
# Require at least one path
|
|
118
|
+
if not (has_jwt_path or has_opaque_path):
|
|
119
|
+
raise ValueError("Invalid configuration: no verification method available. "
|
|
120
|
+
"Configure one of the following:\n"
|
|
121
|
+
" • JWT path: set jwks_uri OR discovery_url OR issuer_url (for JWKS fallback)\n"
|
|
122
|
+
" • Opaque path: set introspection_endpoint + client_id + client_secret")
|
|
123
|
+
|
|
124
|
+
return self
|
nat/authentication/register.py
CHANGED
nat/builder/builder.py
CHANGED
|
@@ -24,9 +24,11 @@ from nat.authentication.interfaces import AuthProviderBase
|
|
|
24
24
|
from nat.builder.context import Context
|
|
25
25
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
26
26
|
from nat.builder.function import Function
|
|
27
|
+
from nat.builder.function import FunctionGroup
|
|
27
28
|
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
28
29
|
from nat.data_models.component_ref import AuthenticationRef
|
|
29
30
|
from nat.data_models.component_ref import EmbedderRef
|
|
31
|
+
from nat.data_models.component_ref import FunctionGroupRef
|
|
30
32
|
from nat.data_models.component_ref import FunctionRef
|
|
31
33
|
from nat.data_models.component_ref import LLMRef
|
|
32
34
|
from nat.data_models.component_ref import MemoryRef
|
|
@@ -36,20 +38,25 @@ from nat.data_models.component_ref import TTCStrategyRef
|
|
|
36
38
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
37
39
|
from nat.data_models.evaluator import EvaluatorBaseConfig
|
|
38
40
|
from nat.data_models.function import FunctionBaseConfig
|
|
41
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
39
42
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
40
43
|
from nat.data_models.llm import LLMBaseConfig
|
|
41
44
|
from nat.data_models.memory import MemoryBaseConfig
|
|
42
45
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
43
46
|
from nat.data_models.retriever import RetrieverBaseConfig
|
|
44
47
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
48
|
+
from nat.experimental.decorators.experimental_warning_decorator import experimental
|
|
45
49
|
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
46
50
|
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
47
51
|
from nat.memory.interfaces import MemoryEditor
|
|
48
52
|
from nat.object_store.interfaces import ObjectStore
|
|
49
53
|
from nat.retriever.interface import Retriever
|
|
50
54
|
|
|
55
|
+
if typing.TYPE_CHECKING:
|
|
56
|
+
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
|
|
51
57
|
|
|
52
|
-
|
|
58
|
+
|
|
59
|
+
class UserManagerHolder:
|
|
53
60
|
|
|
54
61
|
def __init__(self, context: Context) -> None:
|
|
55
62
|
self._context = context
|
|
@@ -58,24 +65,40 @@ class UserManagerHolder():
|
|
|
58
65
|
return self._context.user_manager.get_id()
|
|
59
66
|
|
|
60
67
|
|
|
61
|
-
class Builder(ABC):
|
|
68
|
+
class Builder(ABC):
|
|
62
69
|
|
|
63
70
|
@abstractmethod
|
|
64
71
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
|
65
72
|
pass
|
|
66
73
|
|
|
67
74
|
@abstractmethod
|
|
68
|
-
def
|
|
75
|
+
async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
76
|
+
pass
|
|
77
|
+
|
|
78
|
+
@abstractmethod
|
|
79
|
+
async def get_function(self, name: str | FunctionRef) -> Function:
|
|
69
80
|
pass
|
|
70
81
|
|
|
71
|
-
|
|
82
|
+
@abstractmethod
|
|
83
|
+
async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
84
|
+
pass
|
|
72
85
|
|
|
73
|
-
|
|
86
|
+
async def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
|
|
87
|
+
tasks = [self.get_function(name) for name in function_names]
|
|
88
|
+
return list(await asyncio.gather(*tasks, return_exceptions=False))
|
|
89
|
+
|
|
90
|
+
async def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
|
|
91
|
+
tasks = [self.get_function_group(name) for name in function_group_names]
|
|
92
|
+
return list(await asyncio.gather(*tasks, return_exceptions=False))
|
|
74
93
|
|
|
75
94
|
@abstractmethod
|
|
76
95
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
77
96
|
pass
|
|
78
97
|
|
|
98
|
+
@abstractmethod
|
|
99
|
+
def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
|
|
100
|
+
pass
|
|
101
|
+
|
|
79
102
|
@abstractmethod
|
|
80
103
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
81
104
|
pass
|
|
@@ -88,17 +111,18 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
|
|
|
88
111
|
def get_workflow_config(self) -> FunctionBaseConfig:
|
|
89
112
|
pass
|
|
90
113
|
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
114
|
+
@abstractmethod
|
|
115
|
+
async def get_tools(self,
|
|
116
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
117
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
118
|
+
pass
|
|
95
119
|
|
|
96
120
|
@abstractmethod
|
|
97
|
-
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
121
|
+
async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
98
122
|
pass
|
|
99
123
|
|
|
100
124
|
@abstractmethod
|
|
101
|
-
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
|
|
125
|
+
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> typing.Any:
|
|
102
126
|
pass
|
|
103
127
|
|
|
104
128
|
@abstractmethod
|
|
@@ -119,7 +143,9 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
|
|
|
119
143
|
pass
|
|
120
144
|
|
|
121
145
|
@abstractmethod
|
|
122
|
-
|
|
146
|
+
@experimental(feature_name="Authentication")
|
|
147
|
+
async def add_auth_provider(self, name: str | AuthenticationRef,
|
|
148
|
+
config: AuthProviderBaseConfig) -> AuthProviderBase:
|
|
123
149
|
pass
|
|
124
150
|
|
|
125
151
|
@abstractmethod
|
|
@@ -135,7 +161,7 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
|
|
|
135
161
|
return list(auth_providers)
|
|
136
162
|
|
|
137
163
|
@abstractmethod
|
|
138
|
-
async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig):
|
|
164
|
+
async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig) -> ObjectStore:
|
|
139
165
|
pass
|
|
140
166
|
|
|
141
167
|
async def get_object_store_clients(self, object_store_names: Sequence[str | ObjectStoreRef]) -> list[ObjectStore]:
|
|
@@ -153,7 +179,7 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
|
|
|
153
179
|
pass
|
|
154
180
|
|
|
155
181
|
@abstractmethod
|
|
156
|
-
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
|
|
182
|
+
async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
|
|
157
183
|
pass
|
|
158
184
|
|
|
159
185
|
async def get_embedders(self, embedder_names: Sequence[str | EmbedderRef],
|
|
@@ -174,17 +200,18 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
|
|
|
174
200
|
pass
|
|
175
201
|
|
|
176
202
|
@abstractmethod
|
|
177
|
-
async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig):
|
|
203
|
+
async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig) -> MemoryEditor:
|
|
178
204
|
pass
|
|
179
205
|
|
|
180
|
-
def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
|
|
206
|
+
async def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
|
|
181
207
|
"""
|
|
182
208
|
Return a list of memory clients for the specified names.
|
|
183
209
|
"""
|
|
184
|
-
|
|
210
|
+
tasks = [self.get_memory_client(n) for n in memory_names]
|
|
211
|
+
return list(await asyncio.gather(*tasks, return_exceptions=False))
|
|
185
212
|
|
|
186
213
|
@abstractmethod
|
|
187
|
-
def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
214
|
+
async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
|
|
188
215
|
"""
|
|
189
216
|
Return the instantiated memory client for the given name.
|
|
190
217
|
"""
|
|
@@ -195,12 +222,12 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
|
|
|
195
222
|
pass
|
|
196
223
|
|
|
197
224
|
@abstractmethod
|
|
198
|
-
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
|
|
225
|
+
async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
|
|
199
226
|
pass
|
|
200
227
|
|
|
201
228
|
async def get_retrievers(self,
|
|
202
229
|
retriever_names: Sequence[str | RetrieverRef],
|
|
203
|
-
wrapper_type: LLMFrameworkEnum | str | None = None):
|
|
230
|
+
wrapper_type: LLMFrameworkEnum | str | None = None) -> list[Retriever]:
|
|
204
231
|
|
|
205
232
|
tasks = [self.get_retriever(n, wrapper_type=wrapper_type) for n in retriever_names]
|
|
206
233
|
|
|
@@ -232,14 +259,15 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
|
|
|
232
259
|
pass
|
|
233
260
|
|
|
234
261
|
@abstractmethod
|
|
235
|
-
|
|
262
|
+
@experimental(feature_name="TTC")
|
|
263
|
+
async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig):
|
|
236
264
|
pass
|
|
237
265
|
|
|
238
266
|
@abstractmethod
|
|
239
267
|
async def get_ttc_strategy(self,
|
|
240
268
|
strategy_name: str | TTCStrategyRef,
|
|
241
269
|
pipeline_type: PipelineTypeEnum,
|
|
242
|
-
stage_type: StageTypeEnum):
|
|
270
|
+
stage_type: StageTypeEnum) -> "StrategyBase":
|
|
243
271
|
pass
|
|
244
272
|
|
|
245
273
|
@abstractmethod
|
|
@@ -257,8 +285,12 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
|
|
|
257
285
|
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
258
286
|
pass
|
|
259
287
|
|
|
288
|
+
@abstractmethod
|
|
289
|
+
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
290
|
+
pass
|
|
291
|
+
|
|
260
292
|
|
|
261
|
-
class EvalBuilder(
|
|
293
|
+
class EvalBuilder(ABC):
|
|
262
294
|
|
|
263
295
|
@abstractmethod
|
|
264
296
|
async def add_evaluator(self, name: str, config: EvaluatorBaseConfig):
|
|
@@ -281,5 +313,5 @@ class EvalBuilder(Builder):
|
|
|
281
313
|
pass
|
|
282
314
|
|
|
283
315
|
@abstractmethod
|
|
284
|
-
def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
316
|
+
async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
285
317
|
pass
|
nat/builder/component_utils.py
CHANGED
|
@@ -30,6 +30,7 @@ from nat.data_models.component_ref import generate_instance_id
|
|
|
30
30
|
from nat.data_models.config import Config
|
|
31
31
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
32
32
|
from nat.data_models.function import FunctionBaseConfig
|
|
33
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
33
34
|
from nat.data_models.llm import LLMBaseConfig
|
|
34
35
|
from nat.data_models.memory import MemoryBaseConfig
|
|
35
36
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
@@ -48,6 +49,7 @@ _component_group_order = [
|
|
|
48
49
|
ComponentGroup.OBJECT_STORES,
|
|
49
50
|
ComponentGroup.RETRIEVERS,
|
|
50
51
|
ComponentGroup.TTC_STRATEGIES,
|
|
52
|
+
ComponentGroup.FUNCTION_GROUPS,
|
|
51
53
|
ComponentGroup.FUNCTIONS,
|
|
52
54
|
]
|
|
53
55
|
|
|
@@ -107,6 +109,8 @@ def group_from_component(component: TypedBaseModel) -> ComponentGroup | None:
|
|
|
107
109
|
return ComponentGroup.EMBEDDERS
|
|
108
110
|
if (isinstance(component, FunctionBaseConfig)):
|
|
109
111
|
return ComponentGroup.FUNCTIONS
|
|
112
|
+
if (isinstance(component, FunctionGroupBaseConfig)):
|
|
113
|
+
return ComponentGroup.FUNCTION_GROUPS
|
|
110
114
|
if (isinstance(component, LLMBaseConfig)):
|
|
111
115
|
return ComponentGroup.LLMS
|
|
112
116
|
if (isinstance(component, MemoryBaseConfig)):
|
|
@@ -154,7 +158,7 @@ def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
|
|
|
154
158
|
yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
|
|
155
159
|
if (decomposed_type.is_union):
|
|
156
160
|
for arg in decomposed_type.args:
|
|
157
|
-
if arg is typing.Any or
|
|
161
|
+
if arg is typing.Any or DecomposedType(arg).is_instance(value):
|
|
158
162
|
yield from recursive_componentref_discovery(cls, value, arg)
|
|
159
163
|
else:
|
|
160
164
|
for arg in decomposed_type.args:
|
|
@@ -174,7 +178,7 @@ def update_dependency_graph(config: "Config", instance_config: TypedBaseModel,
|
|
|
174
178
|
nx.DiGraph: An dependency graph that has been updated with the provided runtime instance.
|
|
175
179
|
"""
|
|
176
180
|
|
|
177
|
-
for field_name, field_info in instance_config.model_fields.items():
|
|
181
|
+
for field_name, field_info in type(instance_config).model_fields.items():
|
|
178
182
|
|
|
179
183
|
for instance_id, value_node in recursive_componentref_discovery(
|
|
180
184
|
instance_config,
|
|
@@ -254,9 +258,9 @@ def build_dependency_sequence(config: "Config") -> list[ComponentInstanceData]:
|
|
|
254
258
|
runtime instance references.
|
|
255
259
|
"""
|
|
256
260
|
|
|
257
|
-
total_node_count = len(config.embedders) + len(config.functions) + len(config.
|
|
258
|
-
|
|
259
|
-
|
|
261
|
+
total_node_count = (len(config.embedders) + len(config.functions) + len(config.function_groups) + len(config.llms) +
|
|
262
|
+
len(config.memory) + len(config.object_stores) + len(config.retrievers) +
|
|
263
|
+
len(config.ttc_strategies) + len(config.authentication) + 1) # +1 for the workflow
|
|
260
264
|
|
|
261
265
|
dependency_map: dict
|
|
262
266
|
dependency_graph: nx.DiGraph
|
nat/builder/context.py
CHANGED
|
@@ -31,6 +31,7 @@ from nat.data_models.intermediate_step import IntermediateStep
|
|
|
31
31
|
from nat.data_models.intermediate_step import IntermediateStepPayload
|
|
32
32
|
from nat.data_models.intermediate_step import IntermediateStepType
|
|
33
33
|
from nat.data_models.intermediate_step import StreamEventData
|
|
34
|
+
from nat.data_models.intermediate_step import TraceMetadata
|
|
34
35
|
from nat.data_models.invocation_node import InvocationNode
|
|
35
36
|
from nat.runtime.user_metadata import RequestAttributes
|
|
36
37
|
from nat.utils.reactive.subject import Subject
|
|
@@ -38,13 +39,13 @@ from nat.utils.reactive.subject import Subject
|
|
|
38
39
|
|
|
39
40
|
class Singleton(type):
|
|
40
41
|
|
|
41
|
-
def __init__(cls, name, bases, dict):
|
|
42
|
-
super(
|
|
42
|
+
def __init__(cls, name, bases, dict):
|
|
43
|
+
super().__init__(name, bases, dict)
|
|
43
44
|
cls.instance = None
|
|
44
45
|
|
|
45
46
|
def __call__(cls, *args, **kw):
|
|
46
47
|
if cls.instance is None:
|
|
47
|
-
cls.instance = super(
|
|
48
|
+
cls.instance = super().__call__(*args, **kw)
|
|
48
49
|
return cls.instance
|
|
49
50
|
|
|
50
51
|
|
|
@@ -65,14 +66,13 @@ class ContextState(metaclass=Singleton):
|
|
|
65
66
|
|
|
66
67
|
def __init__(self):
|
|
67
68
|
self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
|
|
69
|
+
self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
|
|
68
70
|
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
|
|
69
71
|
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
|
|
70
|
-
self.
|
|
71
|
-
self.
|
|
72
|
-
self.
|
|
73
|
-
|
|
74
|
-
function_name="root"))
|
|
75
|
-
self.active_span_id_stack: ContextVar[list[str]] = ContextVar("active_span_id_stack", default=["root"])
|
|
72
|
+
self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
|
|
73
|
+
self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None)
|
|
74
|
+
self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None)
|
|
75
|
+
self._active_span_id_stack: ContextVar[list[str] | None] = ContextVar("active_span_id_stack", default=None)
|
|
76
76
|
|
|
77
77
|
# Default is a lambda no-op which returns NoneType
|
|
78
78
|
self.user_input_callback: ContextVar[Callable[[InteractionPrompt], Awaitable[HumanResponse | None]]
|
|
@@ -83,6 +83,30 @@ class ContextState(metaclass=Singleton):
|
|
|
83
83
|
Awaitable[AuthenticatedContext]]
|
|
84
84
|
| None] = ContextVar("user_auth_callback", default=None)
|
|
85
85
|
|
|
86
|
+
@property
|
|
87
|
+
def metadata(self) -> ContextVar[RequestAttributes]:
|
|
88
|
+
if self._metadata.get() is None:
|
|
89
|
+
self._metadata.set(RequestAttributes())
|
|
90
|
+
return typing.cast(ContextVar[RequestAttributes], self._metadata)
|
|
91
|
+
|
|
92
|
+
@property
|
|
93
|
+
def active_function(self) -> ContextVar[InvocationNode]:
|
|
94
|
+
if self._active_function.get() is None:
|
|
95
|
+
self._active_function.set(InvocationNode(function_id="root", function_name="root"))
|
|
96
|
+
return typing.cast(ContextVar[InvocationNode], self._active_function)
|
|
97
|
+
|
|
98
|
+
@property
|
|
99
|
+
def event_stream(self) -> ContextVar[Subject[IntermediateStep]]:
|
|
100
|
+
if self._event_stream.get() is None:
|
|
101
|
+
self._event_stream.set(Subject())
|
|
102
|
+
return typing.cast(ContextVar[Subject[IntermediateStep]], self._event_stream)
|
|
103
|
+
|
|
104
|
+
@property
|
|
105
|
+
def active_span_id_stack(self) -> ContextVar[list[str]]:
|
|
106
|
+
if self._active_span_id_stack.get() is None:
|
|
107
|
+
self._active_span_id_stack.set(["root"])
|
|
108
|
+
return typing.cast(ContextVar[list[str]], self._active_span_id_stack)
|
|
109
|
+
|
|
86
110
|
@staticmethod
|
|
87
111
|
def get() -> "ContextState":
|
|
88
112
|
return ContextState()
|
|
@@ -165,8 +189,18 @@ class Context:
|
|
|
165
189
|
"""
|
|
166
190
|
return self._context_state.conversation_id.get()
|
|
167
191
|
|
|
192
|
+
@property
|
|
193
|
+
def user_message_id(self) -> str | None:
|
|
194
|
+
"""
|
|
195
|
+
This property retrieves the user message ID which is the unique identifier for the current user message.
|
|
196
|
+
"""
|
|
197
|
+
return self._context_state.user_message_id.get()
|
|
198
|
+
|
|
168
199
|
@contextmanager
|
|
169
|
-
def push_active_function(self,
|
|
200
|
+
def push_active_function(self,
|
|
201
|
+
function_name: str,
|
|
202
|
+
input_data: typing.Any | None,
|
|
203
|
+
metadata: dict[str, typing.Any] | TraceMetadata | None = None):
|
|
170
204
|
"""
|
|
171
205
|
Set the 'active_function' in context, push an invocation node,
|
|
172
206
|
AND create an OTel child span for that function call.
|
|
@@ -187,7 +221,8 @@ class Context:
|
|
|
187
221
|
IntermediateStepPayload(UUID=current_function_id,
|
|
188
222
|
event_type=IntermediateStepType.FUNCTION_START,
|
|
189
223
|
name=function_name,
|
|
190
|
-
data=StreamEventData(input=input_data)
|
|
224
|
+
data=StreamEventData(input=input_data),
|
|
225
|
+
metadata=metadata))
|
|
191
226
|
|
|
192
227
|
manager = ActiveFunctionContextManager()
|
|
193
228
|
|
nat/builder/eval_builder.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import dataclasses
|
|
17
18
|
import logging
|
|
18
19
|
from contextlib import asynccontextmanager
|
|
@@ -61,7 +62,7 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
|
|
|
61
62
|
# Store the evaluator
|
|
62
63
|
self._evaluators[name] = ConfiguredEvaluator(config=config, instance=info_obj)
|
|
63
64
|
except Exception as e:
|
|
64
|
-
logger.error("Error %s adding evaluator `%s` with config `%s`", e, name, config
|
|
65
|
+
logger.error("Error %s adding evaluator `%s` with config `%s`", e, name, config)
|
|
65
66
|
raise
|
|
66
67
|
|
|
67
68
|
@override
|
|
@@ -90,17 +91,20 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
|
|
|
90
91
|
return self.eval_general_config.output_dir
|
|
91
92
|
|
|
92
93
|
@override
|
|
93
|
-
def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
|
|
94
|
-
tools = []
|
|
94
|
+
async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
|
|
95
95
|
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
96
|
-
|
|
97
|
-
|
|
96
|
+
|
|
97
|
+
async def get_tool(fn_name: str):
|
|
98
|
+
fn = await self.get_function(fn_name)
|
|
98
99
|
try:
|
|
99
|
-
|
|
100
|
+
return tool_wrapper_reg.build_fn(fn_name, fn, self)
|
|
100
101
|
except Exception:
|
|
101
|
-
logger.exception("Error fetching tool `%s`", fn_name
|
|
102
|
+
logger.exception("Error fetching tool `%s`", fn_name)
|
|
103
|
+
return None
|
|
102
104
|
|
|
103
|
-
|
|
105
|
+
tasks = [get_tool(fn_name) for fn_name in self._functions]
|
|
106
|
+
tools = await asyncio.gather(*tasks, return_exceptions=False)
|
|
107
|
+
return [tool for tool in tools if tool is not None]
|
|
104
108
|
|
|
105
109
|
def _log_build_failure_evaluator(self,
|
|
106
110
|
failing_evaluator_name: str,
|
|
@@ -127,11 +131,12 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
|
|
|
127
131
|
remaining_components,
|
|
128
132
|
original_error)
|
|
129
133
|
|
|
130
|
-
|
|
134
|
+
@override
|
|
135
|
+
async def populate_builder(self, config: Config, skip_workflow: bool = False):
|
|
131
136
|
# Skip setting workflow if workflow config is EmptyFunctionConfig
|
|
132
|
-
skip_workflow = isinstance(config.workflow, EmptyFunctionConfig)
|
|
137
|
+
skip_workflow = skip_workflow or isinstance(config.workflow, EmptyFunctionConfig)
|
|
133
138
|
|
|
134
|
-
await super().populate_builder(config, skip_workflow)
|
|
139
|
+
await super().populate_builder(config, skip_workflow=skip_workflow)
|
|
135
140
|
|
|
136
141
|
# Initialize progress tracking for evaluators
|
|
137
142
|
completed_evaluators = []
|