nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +17 -14
- nat/agent/reasoning_agent/reasoning_agent.py +9 -7
- nat/agent/register.py +1 -0
- nat/agent/rewoo_agent/agent.py +9 -2
- nat/agent/rewoo_agent/register.py +16 -12
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +14 -13
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/context.py +28 -6
- nat/builder/function.py +313 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +215 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -7
- nat/cli/entrypoint.py +4 -9
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +167 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/authentication.py +38 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +18 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +134 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +5 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +111 -3
- nat/front_ends/mcp/tool_converter.py +3 -0
- nat/llm/aws_bedrock_llm.py +14 -3
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +4 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/log_levels.py +25 -0
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/METADATA +10 -1
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/RECORD +105 -76
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/entry_points.txt +1 -0
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from urllib.parse import urlparse
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
from pydantic import field_validator
|
|
20
|
+
from pydantic import model_validator
|
|
21
|
+
|
|
22
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OAuth2ResourceServerConfig(AuthProviderBaseConfig, name="oauth2_resource_server"):
    """OAuth 2.0 Resource Server authentication configuration.

    Supports:
      • JWT access tokens via JWKS / OIDC Discovery / issuer fallback
      • Opaque access tokens via RFC 7662 introspection

    Every configured endpoint URL (``issuer_url``, ``jwks_uri``, ``discovery_url``,
    ``introspection_endpoint``) must be absolute and use HTTPS; plain HTTP is accepted only for the
    loopback hosts ``localhost``, ``127.0.0.1`` and ``::1``.
    """

    issuer_url: str = Field(
        description=("The unique issuer identifier for an authorization server. "
                     "Required for validation and used to derive the default JWKS URI "
                     "(<issuer_url>/.well-known/jwks.json) if `jwks_uri` and `discovery_url` are not provided."), )
    scopes: list[str] = Field(
        default_factory=list,
        description="Scopes required by this API. Validation ensures the token grants all listed scopes.",
    )
    audience: str | None = Field(
        default=None,
        description=(
            "Expected audience (`aud`) claim for this API. If set, validation will reject tokens without this audience."
        ),
    )

    # JWT verification params
    jwks_uri: str | None = Field(
        default=None,
        description=("Direct JWKS endpoint URI for JWT signature verification. "
                     "Optional if discovery or issuer is provided."),
    )
    discovery_url: str | None = Field(
        default=None,
        description=("OIDC discovery metadata URL. Used to automatically resolve JWKS and introspection endpoints."),
    )

    # Opaque token (introspection) params
    introspection_endpoint: str | None = Field(
        default=None,
        description=("RFC 7662 token introspection endpoint. "
                     "Required for opaque token validation and must be used with `client_id` and `client_secret`."),
    )
    client_id: str | None = Field(
        default=None,
        description="OAuth2 client ID for authenticating to the introspection endpoint (opaque token validation).",
    )
    client_secret: str | None = Field(
        default=None,
        description="OAuth2 client secret for authenticating to the introspection endpoint (opaque token validation).",
    )

    @staticmethod
    def _is_https_or_localhost(url: str) -> bool:
        """Return True if ``url`` is absolute and uses HTTPS, or uses HTTP on a loopback host.

        ``urlparse`` raises ``ValueError`` for malformed input (e.g. an unbalanced IPv6 bracket);
        such URLs are reported as invalid instead of propagating the error. Only the parsing step is
        guarded so that genuine programming errors are not silently hidden.
        """
        try:
            parsed = urlparse(url)
            # ``hostname`` is lower-cased and stripped of IPv6 brackets by urllib.
            hostname = parsed.hostname
        except ValueError:
            return False
        if not parsed.scheme or not parsed.netloc:
            return False
        if parsed.scheme == "https":
            return True
        return parsed.scheme == "http" and hostname in {"localhost", "127.0.0.1", "::1"}

    @field_validator("issuer_url", "jwks_uri", "discovery_url", "introspection_endpoint")
    @classmethod
    def _require_valid_url(cls, value: str | None, info):
        """Reject endpoint URLs that are not HTTPS (HTTP is tolerated only on loopback hosts).

        ``None`` is passed through unchanged so that optional endpoints may be left unset.

        Raises
        ------
        ValueError
            If ``value`` is set but fails the HTTPS/loopback check.
        """
        if value is None:
            return value
        if not cls._is_https_or_localhost(value):
            raise ValueError(f"{info.field_name} must be HTTPS (http allowed only for localhost). Got: {value}")
        return value

    # ---------- Cross-field validation: ensure at least one viable path ----------

    @model_validator(mode="after")
    def _ensure_verification_path(self):
        """Check that the configuration enables at least one token-verification path.

        JWT path viable if any of: jwks_uri OR discovery_url OR issuer_url (fallback JWKS).
        Opaque path viable if: introspection_endpoint AND client_id AND client_secret.

        Raises
        ------
        ValueError
            If ``introspection_endpoint`` is set without both client credentials, or if no
            verification path is available.
        """
        has_jwt_path = bool(self.jwks_uri or self.discovery_url or self.issuer_url)
        has_opaque_path = bool(self.introspection_endpoint and self.client_id and self.client_secret)

        # If introspection endpoint is set, enforce creds are present
        if self.introspection_endpoint:
            missing = []
            if not self.client_id:
                missing.append("client_id")
            if not self.client_secret:
                missing.append("client_secret")
            if missing:
                raise ValueError(
                    f"introspection_endpoint configured but missing required credentials: {', '.join(missing)}")

        # NOTE: ``issuer_url`` has no default (so it is required) and must pass ``_require_valid_url``,
        # which rejects empty strings. ``has_jwt_path`` is therefore always true for an instance that
        # reaches this point. The check below is a defensive guard in case ``issuer_url`` ever becomes
        # optional.
        if not (has_jwt_path or has_opaque_path):
            raise ValueError("Invalid configuration: no verification method available. "
                             "Configure one of the following:\n"
                             "  • JWT path: set jwks_uri OR discovery_url OR issuer_url (for JWKS fallback)\n"
                             "  • Opaque path: set introspection_endpoint + client_id + client_secret")

        return self
|
nat/builder/builder.py
CHANGED
|
@@ -24,9 +24,11 @@ from nat.authentication.interfaces import AuthProviderBase
|
|
|
24
24
|
from nat.builder.context import Context
|
|
25
25
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
26
26
|
from nat.builder.function import Function
|
|
27
|
+
from nat.builder.function import FunctionGroup
|
|
27
28
|
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
28
29
|
from nat.data_models.component_ref import AuthenticationRef
|
|
29
30
|
from nat.data_models.component_ref import EmbedderRef
|
|
31
|
+
from nat.data_models.component_ref import FunctionGroupRef
|
|
30
32
|
from nat.data_models.component_ref import FunctionRef
|
|
31
33
|
from nat.data_models.component_ref import LLMRef
|
|
32
34
|
from nat.data_models.component_ref import MemoryRef
|
|
@@ -36,6 +38,7 @@ from nat.data_models.component_ref import TTCStrategyRef
|
|
|
36
38
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
37
39
|
from nat.data_models.evaluator import EvaluatorBaseConfig
|
|
38
40
|
from nat.data_models.function import FunctionBaseConfig
|
|
41
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
39
42
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
40
43
|
from nat.data_models.llm import LLMBaseConfig
|
|
41
44
|
from nat.data_models.memory import MemoryBaseConfig
|
|
@@ -64,18 +67,33 @@ class Builder(ABC):
|
|
|
64
67
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
    """Add a function described by ``config`` to the builder.

    Interface hook: this base definition has no behavior and returns ``None``; concrete builders
    override it.

    Parameters
    ----------
    name : str | FunctionRef
        Name or reference identifying the function.
    config : FunctionBaseConfig
        Configuration describing the function.

    Returns
    -------
    Function
        The resulting function (as produced by the overriding implementation).
    """
|
|
66
69
|
|
|
70
|
+
@abstractmethod
async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
    """Add a function group described by ``config`` to the builder.

    Abstract: concrete builders must provide the implementation.

    Parameters
    ----------
    name : str | FunctionGroupRef
        Name or reference identifying the function group.
    config : FunctionGroupBaseConfig
        Configuration describing the function group.

    Returns
    -------
    FunctionGroup
        The resulting function group.
    """
|
|
73
|
+
|
|
67
74
|
@abstractmethod
def get_function(self, name: str | FunctionRef) -> Function:
    """Look up a function by name or reference.

    Abstract: concrete builders must provide the implementation.
    """
|
|
70
77
|
|
|
78
|
+
@abstractmethod
def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
    """Look up a function group by name or reference.

    Abstract: concrete builders must provide the implementation.
    """
|
|
81
|
+
|
|
71
82
|
def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
    """Resolve several functions at once.

    Calls ``get_function`` for each entry of ``function_names``, in order, and collects the results
    into a new list.
    """
    return list(map(self.get_function, function_names))
|
|
74
85
|
|
|
86
|
+
def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
    """Resolve several function groups at once.

    Calls ``get_function_group`` for each entry of ``function_group_names``, in order, and collects
    the results into a new list.
    """
    return list(map(self.get_function_group, function_group_names))
|
|
88
|
+
|
|
75
89
|
@abstractmethod
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
    """Return the configuration associated with the named function.

    Abstract: concrete builders must provide the implementation.
    """
|
|
78
92
|
|
|
93
|
+
@abstractmethod
def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
    """Return the configuration associated with the named function group.

    Abstract: concrete builders must provide the implementation.
    """
|
|
96
|
+
|
|
79
97
|
@abstractmethod
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
    """Set the workflow function from ``config`` and return it.

    Abstract: concrete builders must provide the implementation.
    """
|
|
@@ -88,10 +106,11 @@ class Builder(ABC):
|
|
|
88
106
|
def get_workflow_config(self) -> FunctionBaseConfig:
    """Return the workflow's configuration.

    Interface hook: this base definition has no behavior and returns ``None``; concrete builders
    override it.
    """
|
|
90
108
|
|
|
91
|
-
|
|
109
|
+
@abstractmethod
def get_tools(self,
              tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
              wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
    """Return framework tool objects for every entry in ``tool_names``, wrapped for ``wrapper_type``.

    Abstract: concrete builders must provide the implementation. That includes deciding how
    ``FunctionGroupRef`` entries are handled.
    """
|
|
95
114
|
|
|
96
115
|
@abstractmethod
|
|
97
116
|
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
@@ -257,8 +276,12 @@ class Builder(ABC):
|
|
|
257
276
|
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
    """Return the dependencies recorded for the function named ``fn_name``.

    Interface hook: this base definition has no behavior and returns ``None``; concrete builders
    override it.
    """
|
|
259
278
|
|
|
279
|
+
@abstractmethod
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
    """Return the dependencies recorded for the function group named ``fn_name``.

    The parameter is called ``fn_name`` for symmetry with ``get_function_dependencies``; it refers to
    a function-group name here. Abstract: concrete builders must provide the implementation.
    """
|
|
282
|
+
|
|
260
283
|
|
|
261
|
-
class EvalBuilder(
|
|
284
|
+
class EvalBuilder(ABC):
|
|
262
285
|
|
|
263
286
|
@abstractmethod
|
|
264
287
|
async def add_evaluator(self, name: str, config: EvaluatorBaseConfig):
|
nat/builder/component_utils.py
CHANGED
|
@@ -30,6 +30,7 @@ from nat.data_models.component_ref import generate_instance_id
|
|
|
30
30
|
from nat.data_models.config import Config
|
|
31
31
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
32
32
|
from nat.data_models.function import FunctionBaseConfig
|
|
33
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
33
34
|
from nat.data_models.llm import LLMBaseConfig
|
|
34
35
|
from nat.data_models.memory import MemoryBaseConfig
|
|
35
36
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
@@ -48,6 +49,7 @@ _component_group_order = [
|
|
|
48
49
|
ComponentGroup.OBJECT_STORES,
|
|
49
50
|
ComponentGroup.RETRIEVERS,
|
|
50
51
|
ComponentGroup.TTC_STRATEGIES,
|
|
52
|
+
ComponentGroup.FUNCTION_GROUPS,
|
|
51
53
|
ComponentGroup.FUNCTIONS,
|
|
52
54
|
]
|
|
53
55
|
|
|
@@ -107,6 +109,8 @@ def group_from_component(component: TypedBaseModel) -> ComponentGroup | None:
|
|
|
107
109
|
return ComponentGroup.EMBEDDERS
|
|
108
110
|
if (isinstance(component, FunctionBaseConfig)):
|
|
109
111
|
return ComponentGroup.FUNCTIONS
|
|
112
|
+
if (isinstance(component, FunctionGroupBaseConfig)):
|
|
113
|
+
return ComponentGroup.FUNCTION_GROUPS
|
|
110
114
|
if (isinstance(component, LLMBaseConfig)):
|
|
111
115
|
return ComponentGroup.LLMS
|
|
112
116
|
if (isinstance(component, MemoryBaseConfig)):
|
|
@@ -254,9 +258,9 @@ def build_dependency_sequence(config: "Config") -> list[ComponentInstanceData]:
|
|
|
254
258
|
runtime instance references.
|
|
255
259
|
"""
|
|
256
260
|
|
|
257
|
-
total_node_count = len(config.embedders) + len(config.functions) + len(config.
|
|
258
|
-
|
|
259
|
-
|
|
261
|
+
total_node_count = (len(config.embedders) + len(config.functions) + len(config.function_groups) + len(config.llms) +
|
|
262
|
+
len(config.memory) + len(config.object_stores) + len(config.retrievers) +
|
|
263
|
+
len(config.ttc_strategies) + len(config.authentication) + 1) # +1 for the workflow
|
|
260
264
|
|
|
261
265
|
dependency_map: dict
|
|
262
266
|
dependency_graph: nx.DiGraph
|
nat/builder/context.py
CHANGED
|
@@ -69,12 +69,10 @@ class ContextState(metaclass=Singleton):
|
|
|
69
69
|
self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
|
|
70
70
|
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
|
|
71
71
|
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
|
|
72
|
-
self.
|
|
73
|
-
self.
|
|
74
|
-
self.
|
|
75
|
-
|
|
76
|
-
function_name="root"))
|
|
77
|
-
self.active_span_id_stack: ContextVar[list[str]] = ContextVar("active_span_id_stack", default=["root"])
|
|
72
|
+
self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
|
|
73
|
+
self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None)
|
|
74
|
+
self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None)
|
|
75
|
+
self._active_span_id_stack: ContextVar[list[str] | None] = ContextVar("active_span_id_stack", default=None)
|
|
78
76
|
|
|
79
77
|
# Default is a lambda no-op which returns NoneType
|
|
80
78
|
self.user_input_callback: ContextVar[Callable[[InteractionPrompt], Awaitable[HumanResponse | None]]
|
|
@@ -85,6 +83,30 @@ class ContextState(metaclass=Singleton):
|
|
|
85
83
|
Awaitable[AuthenticatedContext]]
|
|
86
84
|
| None] = ContextVar("user_auth_callback", default=None)
|
|
87
85
|
|
|
86
|
+
@property
def metadata(self) -> ContextVar[RequestAttributes]:
    """Context variable holding the current ``RequestAttributes``.

    The backing variable starts out unset (``None``). The first read in a context where it is still
    ``None`` stores a fresh ``RequestAttributes()``, so callers always see a populated value.
    """
    var = self._metadata
    if var.get() is None:
        var.set(RequestAttributes())
    return typing.cast(ContextVar[RequestAttributes], var)
|
|
91
|
+
|
|
92
|
+
@property
def active_function(self) -> ContextVar[InvocationNode]:
    """Context variable tracking the currently active function invocation.

    If the backing variable is still ``None`` in the current context, it is seeded with an
    ``InvocationNode`` whose id and name are both ``"root"``.
    """
    var = self._active_function
    if var.get() is None:
        var.set(InvocationNode(function_id="root", function_name="root"))
    return typing.cast(ContextVar[InvocationNode], var)
|
|
97
|
+
|
|
98
|
+
@property
def event_stream(self) -> ContextVar[Subject[IntermediateStep]]:
    """Context variable holding the ``Subject`` that intermediate steps are published to.

    A new ``Subject()`` is created the first time this is read in a context where the backing
    variable is still ``None``.
    """
    var = self._event_stream
    if var.get() is None:
        var.set(Subject())
    return typing.cast(ContextVar[Subject[IntermediateStep]], var)
|
|
103
|
+
|
|
104
|
+
@property
def active_span_id_stack(self) -> ContextVar[list[str]]:
    """Context variable holding the stack of active span ids.

    If the backing variable is still ``None`` in the current context, it is initialised with a new
    list containing only ``"root"``.
    """
    var = self._active_span_id_stack
    if var.get() is None:
        var.set(["root"])
    return typing.cast(ContextVar[list[str]], var)
|
|
109
|
+
|
|
88
110
|
@staticmethod
|
|
89
111
|
def get() -> "ContextState":
|
|
90
112
|
return ContextState()
|
nat/builder/function.py
CHANGED
|
@@ -14,12 +14,14 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
|
+
import re
|
|
17
18
|
import typing
|
|
18
19
|
from abc import ABC
|
|
19
20
|
from abc import abstractmethod
|
|
20
21
|
from collections.abc import AsyncGenerator
|
|
21
22
|
from collections.abc import Awaitable
|
|
22
23
|
from collections.abc import Callable
|
|
24
|
+
from collections.abc import Sequence
|
|
23
25
|
|
|
24
26
|
from pydantic import BaseModel
|
|
25
27
|
|
|
@@ -29,7 +31,9 @@ from nat.builder.function_base import InputT
|
|
|
29
31
|
from nat.builder.function_base import SingleOutputT
|
|
30
32
|
from nat.builder.function_base import StreamingOutputT
|
|
31
33
|
from nat.builder.function_info import FunctionInfo
|
|
34
|
+
from nat.data_models.function import EmptyFunctionConfig
|
|
32
35
|
from nat.data_models.function import FunctionBaseConfig
|
|
36
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
33
37
|
|
|
34
38
|
_InvokeFnT = Callable[[InputT], Awaitable[SingleOutputT]]
|
|
35
39
|
_StreamFnT = Callable[[InputT], AsyncGenerator[StreamingOutputT]]
|
|
@@ -342,3 +346,312 @@ class LambdaFunction(Function[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
342
346
|
pass
|
|
343
347
|
|
|
344
348
|
return FunctionImpl(config=config, info=info, instance_name=instance_name)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
class FunctionGroup:
|
|
352
|
+
"""
|
|
353
|
+
A group of functions that can be used together, sharing the same configuration, context, and resources.
|
|
354
|
+
"""
|
|
355
|
+
|
|
356
|
+
def __init__(self,
             *,
             config: FunctionGroupBaseConfig,
             instance_name: str | None = None,
             filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None):
    """
    Initialise an empty function group.

    Parameters
    ----------
    config : FunctionGroupBaseConfig
        The configuration shared by the group. Its ``include``/``exclude`` lists are consulted by the
        accessor methods.
    instance_name : str | None, optional
        Name of this group instance. It prefixes every member's full name (``"<instance_name>.<fn>"``).
        When omitted or empty, ``config.type`` is used instead.
    filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
        Group-level callback that receives the candidate function names and returns the names to keep.
        It is applied dynamically by the accessor methods unless a call-specific filter is supplied.
    """
    self._config = config
    # Any falsy instance name (None or "") falls back to the config's type name.
    self._instance_name = instance_name if instance_name else config.type
    # Short function name -> wrapped Function.
    self._functions: dict[str, Function] = {}
    self._filter_fn = filter_fn
    # Short function name -> per-function inclusion predicate (registered via add_function).
    self._per_function_filter_fn: dict[str, Callable[[str], bool]] = {}
|
|
379
|
+
|
|
380
|
+
def add_function(self,
                 name: str,
                 fn: Callable,
                 *,
                 input_schema: type[BaseModel] | None = None,
                 description: str | None = None,
                 converters: list[Callable] | None = None,
                 filter_fn: Callable[[str], bool] | None = None):
    """
    Adds a function to the function group.

    The callable is wrapped (via ``FunctionInfo.from_fn``) into a ``LambdaFunction``. Its instance name
    is ``"<group instance name>.<name>"``, and it is stored under the short ``name``.

    Parameters
    ----------
    name : str
        The short name of the function within the group. It may contain only ASCII letters, digits,
        underscores and hyphens.
    fn : Callable
        The function to add to the function group.
    input_schema : type[BaseModel] | None, optional
        The input schema for the function.
    description : str | None, optional
        The description of the function.
    converters : list[Callable] | None, optional
        The converters to use for the function.
    filter_fn : Callable[[str], bool] | None, optional
        A callback to determine if the function should be included in the function group. The
        callback will be called with the function name. The callback is invoked dynamically when
        the functions are accessed via any accessor method such as `get_accessible_functions`,
        `get_included_functions`, `get_excluded_functions`, `get_all_functions`.

    Raises
    ------
    ValueError
        When the function name is empty or blank.
        When the function name contains invalid characters (including a trailing newline).
        When the function already exists in the function group.
    """
    if not name.strip():
        raise ValueError("Function name cannot be empty or blank")
    # ``re.fullmatch`` is used instead of ``re.match(r"^...$")``: ``$`` also matches just before a
    # trailing newline, which would let names such as ``"tool\n"`` through.
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", name):
        raise ValueError(f"Function name can only contain letters, numbers, underscores, and hyphens: {name}")
    if name in self._functions:
        raise ValueError(f"Function {name} already exists in function group {self._instance_name}")

    info = FunctionInfo.from_fn(fn, input_schema=input_schema, description=description, converters=converters)
    full_name = self._get_fn_name(name)
    lambda_fn = LambdaFunction.from_info(config=EmptyFunctionConfig(), info=info, instance_name=full_name)
    self._functions[name] = lambda_fn
    if filter_fn:
        self._per_function_filter_fn[name] = filter_fn
|
|
429
|
+
|
|
430
|
+
def get_config(self) -> FunctionGroupBaseConfig:
    """
    Return the configuration object this group was constructed with.

    Returns
    -------
    FunctionGroupBaseConfig
        The very ``config`` instance passed to ``__init__`` (not a copy).
    """
    config = self._config
    return config
|
|
440
|
+
|
|
441
|
+
def _get_fn_name(self, name: str) -> str:
|
|
442
|
+
return f"{self._instance_name}.{name}"
|
|
443
|
+
|
|
444
|
+
def _fn_should_be_included(self, name: str) -> bool:
|
|
445
|
+
return (name not in self._per_function_filter_fn or self._per_function_filter_fn[name](name))
|
|
446
|
+
|
|
447
|
+
def _get_all_but_excluded_functions(
    self,
    filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
) -> dict[str, Function]:
    """
    Collect every function that is neither excluded by configuration nor filtered out.

    A function is kept when it is not listed in ``config.exclude``, passes its per-function filter,
    and is among the names returned by the active filter. The active filter is ``filter_fn``, else the
    group's filter, else identity. Results are keyed by full name in registration order.

    Raises
    ------
    ValueError
        When ``config.exclude`` names a function that is not registered in the group.
    """
    unknown = set(self._config.exclude).difference(self._functions)
    if unknown:
        raise ValueError(f"Unknown excluded functions: {sorted(unknown)}")
    active_filter = filter_fn or self._filter_fn or (lambda names: names)
    configured_exclusions = set(self._config.exclude)
    kept = set(active_filter(list(self._functions.keys())))

    result = {}
    for fn_name, fn in self._functions.items():
        if fn_name in configured_exclusions or not self._fn_should_be_included(fn_name) or fn_name not in kept:
            continue
        result[self._get_fn_name(fn_name)] = fn
    return result
|
|
469
|
+
|
|
470
|
+
def get_accessible_functions(
    self,
    filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
) -> dict[str, Function]:
    """
    Return the functions of this group that callers are allowed to use.

    The group's configuration selects the strategy:
    - ``include`` is non-empty: delegate to ``get_included_functions`` (``include`` takes precedence).
    - otherwise ``exclude`` is non-empty: return every function except the excluded ones.
    - otherwise: delegate to ``get_all_functions``.

    The chosen strategy then applies the dynamic filter function and the per-function filters.

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
        Call-specific filter over function names. When omitted, the group's own filter function is
        used. If neither is set, names are not filtered by this step.

    Returns
    -------
    dict[str, Function]
        Accessible functions keyed by their full (group-qualified) name.

    Raises
    ------
    ValueError
        When the configuration references functions that are not found in the group.
    """
    config = self._config
    if config.include:
        accessor = self.get_included_functions
    elif config.exclude:
        accessor = self._get_all_but_excluded_functions
    else:
        accessor = self.get_all_functions
    return accessor(filter_fn=filter_fn)
|
|
507
|
+
|
|
508
|
+
def get_excluded_functions(
    self,
    filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
) -> dict[str, Function]:
    """
    Return the functions that are *not* accessible.

    A function is reported as excluded if any of the following holds:
    - it is listed in ``config.exclude``;
    - its per-function filter rejects it;
    - it is not among the names returned by the active filter.

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
        Call-specific filter over function names. When omitted, the group's own filter function is
        used. If neither is set, no function is dropped by this step.

    Returns
    -------
    dict[str, Function]
        Excluded functions keyed by their full (group-qualified) name, in registration order.

    Raises
    ------
    ValueError
        When the function group is configured to exclude functions that are not found in the group.
    """
    unknown = set(self._config.exclude).difference(self._functions)
    if unknown:
        raise ValueError(f"Unknown excluded functions: {sorted(unknown)}")
    active_filter = filter_fn or self._filter_fn or (lambda names: names)
    configured_exclusions = set(self._config.exclude)
    kept = set(active_filter(list(self._functions.keys())))

    result = {}
    for fn_name, fn in self._functions.items():
        dropped = (fn_name in configured_exclusions or not self._fn_should_be_included(fn_name)
                   or fn_name not in kept)
        if dropped:
            result[self._get_fn_name(fn_name)] = fn
    return result
|
|
548
|
+
|
|
549
|
+
def get_included_functions(
    self,
    filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
) -> dict[str, Function]:
    """
    Returns a dictionary of all functions in the function group which are:
    - configured to be included and added to the global function registry
    - not configured to be excluded.
    - not filtered out by a filter function or by per-function filtering (``_fn_should_be_included``).

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
        A callback function to additionally filter the functions in the function group dynamically. It receives
        the configured ``include`` names and returns the names to keep; it can only narrow the selection, and
        any returned name that is not in ``include`` is ignored. If not provided then fall back to the function
        group's filter function. If no filter function is set for the function group all included functions
        will be returned.

    Returns
    -------
    dict[str, Function]
        A dictionary of all included functions in the function group, keyed by ``_get_fn_name(name)`` and
        ordered like the ``include`` configuration.

    Raises
    ------
    ValueError
        When the function group is configured to include functions that are not found in the group.
    """
    missing = set(self._config.include) - set(self._functions.keys())
    if missing:
        raise ValueError(f"Unknown included functions: {sorted(missing)}")

    filter_fn = filter_fn or self._filter_fn or (lambda x: x)
    allowed = set(filter_fn(list(self._config.include)))
    excluded = set(self._config.exclude)

    # Walk the configured ``include`` list rather than a set. This keeps the result order deterministic, since
    # set iteration order depends on string hashing. It also means the filter can never pull in a function
    # that is not in ``include``.
    return {
        self._get_fn_name(name): self._functions[name]
        for name in self._config.include
        if name in allowed and name not in excluded and self._fn_should_be_included(name)
    }
|
|
583
|
+
|
|
584
|
+
def get_all_functions(
    self,
    filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
) -> dict[str, Function]:
    """
    Returns a dictionary of all functions in the function group, regardless if they are included or excluded.

    If a filter function has been set, the returned functions will additionally be filtered by the callback.
    Functions rejected by per-function filtering (``_fn_should_be_included``) are always left out.

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
        A callback function to additionally filter the functions in the function group dynamically. It receives
        the names of all functions in the group and returns the names to keep; returned names that are not in
        the group are ignored. If not provided then fall back to the function group's filter function. If no
        filter function is set for the function group all functions will be returned.

    Returns
    -------
    dict[str, Function]
        A dictionary of all functions in the function group, keyed by ``_get_fn_name(name)`` and ordered as the
        functions appear in the group.
    """
    filter_fn = filter_fn or self._filter_fn or (lambda x: x)
    allowed = set(filter_fn(list(self._functions.keys())))
    # Iterate the group's mapping (not the filter's output) for a deterministic order and so that unknown
    # names returned by the filter cannot cause a KeyError.
    return {
        self._get_fn_name(name): fn
        for name, fn in self._functions.items()
        if name in allowed and self._fn_should_be_included(name)
    }
|
|
609
|
+
|
|
610
|
+
def set_filter_fn(self, filter_fn: Callable[[Sequence[str]], Sequence[str]]):
    """
    Install the group-level filter used to narrow function names.

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Sequence[str]]
        Callback that receives a sequence of function names and returns the names to keep. It replaces any
        previously stored filter for this function group.
    """
    new_filter = filter_fn
    self._filter_fn = new_filter
|
|
620
|
+
|
|
621
|
+
def set_per_function_filter_fn(self, name: str, filter_fn: Callable[[str], bool]):
    """
    Register a per-function filter for a single function of this function group.

    Parameters
    ----------
    name : str
        The name of a function that belongs to this function group.
    filter_fn : Callable[[str], bool]
        The per-function filter to store for ``name``; it replaces any filter previously stored for that name.

    Raises
    ------
    ValueError
        When ``name`` does not belong to this function group.
    """
    if name in self._functions:
        self._per_function_filter_fn[name] = filter_fn
        return
    raise ValueError(f"Function {name} not found in function group {self._instance_name}")
|
|
640
|
+
|
|
641
|
+
def set_instance_name(self, instance_name: str):
    """
    Assign the instance name of this function group.

    Parameters
    ----------
    instance_name : str
        The name to store; it overwrites any previously assigned instance name.
    """
    new_name = instance_name
    self._instance_name = new_name
|
|
651
|
+
|
|
652
|
+
@property
def instance_name(self) -> str:
    """
    The instance name currently stored for this function group (the value last passed to
    ``set_instance_name``).
    """
    current = self._instance_name
    return current
|
nat/builder/function_info.py
CHANGED
|
@@ -233,7 +233,7 @@ class FunctionDescriptor:
|
|
|
233
233
|
|
|
234
234
|
is_input_typed = all([a != sig.empty for a in annotations])
|
|
235
235
|
|
|
236
|
-
input_type = tuple[*annotations] if is_input_typed else None
|
|
236
|
+
input_type = tuple[*annotations] if is_input_typed else None
|
|
237
237
|
|
|
238
238
|
# Get the base type here removing all annotations and async generators
|
|
239
239
|
output_annotation_decomp = DecomposedType(sig.return_annotation).get_base_type()
|
nat/builder/workflow.py
CHANGED
|
@@ -20,6 +20,7 @@ from typing import Any
|
|
|
20
20
|
from nat.builder.context import ContextState
|
|
21
21
|
from nat.builder.embedder import EmbedderProviderInfo
|
|
22
22
|
from nat.builder.function import Function
|
|
23
|
+
from nat.builder.function import FunctionGroup
|
|
23
24
|
from nat.builder.function_base import FunctionBase
|
|
24
25
|
from nat.builder.function_base import InputT
|
|
25
26
|
from nat.builder.function_base import SingleOutputT
|
|
@@ -44,6 +45,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
44
45
|
config: Config,
|
|
45
46
|
entry_fn: Function[InputT, StreamingOutputT, SingleOutputT],
|
|
46
47
|
functions: dict[str, Function] | None = None,
|
|
48
|
+
function_groups: dict[str, FunctionGroup] | None = None,
|
|
47
49
|
llms: dict[str, LLMProviderInfo] | None = None,
|
|
48
50
|
embeddings: dict[str, EmbedderProviderInfo] | None = None,
|
|
49
51
|
memory: dict[str, MemoryEditor] | None = None,
|
|
@@ -59,6 +61,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
59
61
|
|
|
60
62
|
self.config = config
|
|
61
63
|
self.functions = functions or {}
|
|
64
|
+
self.function_groups = function_groups or {}
|
|
62
65
|
self.llms = llms or {}
|
|
63
66
|
self.embeddings = embeddings or {}
|
|
64
67
|
self.memory = memory or {}
|
|
@@ -126,6 +129,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
126
129
|
config: Config,
|
|
127
130
|
entry_fn: Function[InputT, StreamingOutputT, SingleOutputT],
|
|
128
131
|
functions: dict[str, Function] | None = None,
|
|
132
|
+
function_groups: dict[str, FunctionGroup] | None = None,
|
|
129
133
|
llms: dict[str, LLMProviderInfo] | None = None,
|
|
130
134
|
embeddings: dict[str, EmbedderProviderInfo] | None = None,
|
|
131
135
|
memory: dict[str, MemoryEditor] | None = None,
|
|
@@ -145,6 +149,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
145
149
|
return WorkflowImpl(config=config,
|
|
146
150
|
entry_fn=entry_fn,
|
|
147
151
|
functions=functions,
|
|
152
|
+
function_groups=function_groups,
|
|
148
153
|
llms=llms,
|
|
149
154
|
embeddings=embeddings,
|
|
150
155
|
memory=memory,
|