nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. nat/agent/base.py +9 -4
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +1 -1
  5. nat/agent/react_agent/register.py +17 -14
  6. nat/agent/reasoning_agent/reasoning_agent.py +9 -7
  7. nat/agent/register.py +1 -0
  8. nat/agent/rewoo_agent/agent.py +9 -2
  9. nat/agent/rewoo_agent/register.py +16 -12
  10. nat/agent/tool_calling_agent/agent.py +69 -7
  11. nat/agent/tool_calling_agent/register.py +14 -13
  12. nat/authentication/credential_validator/__init__.py +14 -0
  13. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  14. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  15. nat/builder/builder.py +27 -4
  16. nat/builder/component_utils.py +7 -3
  17. nat/builder/context.py +28 -6
  18. nat/builder/function.py +313 -0
  19. nat/builder/function_info.py +1 -1
  20. nat/builder/workflow.py +5 -0
  21. nat/builder/workflow_builder.py +215 -16
  22. nat/cli/commands/optimize.py +90 -0
  23. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  24. nat/cli/commands/workflow/workflow_commands.py +4 -7
  25. nat/cli/entrypoint.py +4 -9
  26. nat/cli/register_workflow.py +38 -4
  27. nat/cli/type_registry.py +71 -0
  28. nat/control_flow/__init__.py +0 -0
  29. nat/control_flow/register.py +20 -0
  30. nat/control_flow/router_agent/__init__.py +0 -0
  31. nat/control_flow/router_agent/agent.py +329 -0
  32. nat/control_flow/router_agent/prompt.py +48 -0
  33. nat/control_flow/router_agent/register.py +91 -0
  34. nat/control_flow/sequential_executor.py +167 -0
  35. nat/data_models/agent.py +34 -0
  36. nat/data_models/authentication.py +38 -0
  37. nat/data_models/component.py +2 -0
  38. nat/data_models/component_ref.py +11 -0
  39. nat/data_models/config.py +40 -16
  40. nat/data_models/function.py +34 -0
  41. nat/data_models/function_dependencies.py +8 -0
  42. nat/data_models/optimizable.py +119 -0
  43. nat/data_models/optimizer.py +149 -0
  44. nat/data_models/temperature_mixin.py +4 -3
  45. nat/data_models/top_p_mixin.py +4 -3
  46. nat/embedder/nim_embedder.py +1 -1
  47. nat/embedder/openai_embedder.py +1 -1
  48. nat/eval/config.py +1 -1
  49. nat/eval/evaluate.py +5 -1
  50. nat/eval/register.py +4 -0
  51. nat/eval/runtime_evaluator/__init__.py +14 -0
  52. nat/eval/runtime_evaluator/evaluate.py +123 -0
  53. nat/eval/runtime_evaluator/register.py +100 -0
  54. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
  55. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  56. nat/front_ends/fastapi/fastapi_front_end_config.py +18 -3
  57. nat/front_ends/fastapi/fastapi_front_end_plugin.py +134 -3
  58. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
  59. nat/front_ends/fastapi/job_store.py +518 -99
  60. nat/front_ends/fastapi/main.py +11 -19
  61. nat/front_ends/fastapi/utils.py +57 -0
  62. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  63. nat/front_ends/mcp/mcp_front_end_config.py +5 -1
  64. nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
  65. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +111 -3
  66. nat/front_ends/mcp/tool_converter.py +3 -0
  67. nat/llm/aws_bedrock_llm.py +14 -3
  68. nat/llm/nim_llm.py +14 -3
  69. nat/llm/openai_llm.py +8 -1
  70. nat/observability/exporter/processing_exporter.py +29 -55
  71. nat/observability/mixin/redaction_config_mixin.py +5 -4
  72. nat/observability/mixin/tagging_config_mixin.py +26 -14
  73. nat/observability/mixin/type_introspection_mixin.py +420 -107
  74. nat/observability/processor/processor.py +3 -0
  75. nat/observability/processor/redaction/__init__.py +24 -0
  76. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  77. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  78. nat/observability/processor/redaction/redaction_processor.py +177 -0
  79. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  80. nat/observability/processor/span_tagging_processor.py +21 -14
  81. nat/profiler/decorators/framework_wrapper.py +9 -6
  82. nat/profiler/parameter_optimization/__init__.py +0 -0
  83. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  84. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  85. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  86. nat/profiler/parameter_optimization/parameter_selection.py +108 -0
  87. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  88. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  89. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  90. nat/profiler/utils.py +3 -1
  91. nat/tool/chat_completion.py +4 -1
  92. nat/tool/github_tools.py +450 -0
  93. nat/tool/register.py +2 -7
  94. nat/utils/callable_utils.py +70 -0
  95. nat/utils/exception_handlers/automatic_retries.py +103 -48
  96. nat/utils/log_levels.py +25 -0
  97. nat/utils/type_utils.py +4 -0
  98. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/METADATA +10 -1
  99. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/RECORD +105 -76
  100. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/entry_points.txt +1 -0
  101. nat/observability/processor/header_redaction_processor.py +0 -123
  102. nat/observability/processor/redaction_processor.py +0 -77
  103. nat/tool/github_tools/create_github_commit.py +0 -133
  104. nat/tool/github_tools/create_github_issue.py +0 -87
  105. nat/tool/github_tools/create_github_pr.py +0 -106
  106. nat/tool/github_tools/get_github_file.py +0 -106
  107. nat/tool/github_tools/get_github_issue.py +0 -166
  108. nat/tool/github_tools/get_github_pr.py +0 -256
  109. nat/tool/github_tools/update_github_issue.py +0 -100
  110. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  111. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/WHEEL +0 -0
  112. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  113. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE.md +0 -0
  114. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,124 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from urllib.parse import urlparse
17
+
18
+ from pydantic import Field
19
+ from pydantic import field_validator
20
+ from pydantic import model_validator
21
+
22
+ from nat.data_models.authentication import AuthProviderBaseConfig
23
+
24
+
25
+ class OAuth2ResourceServerConfig(AuthProviderBaseConfig, name="oauth2_resource_server"):
26
+ """OAuth 2.0 Resource Server authentication configuration.
27
+
28
+ Supports:
29
+ • JWT access tokens via JWKS / OIDC Discovery / issuer fallback
30
+ • Opaque access tokens via RFC 7662 introspection
31
+ """
32
+
33
+ issuer_url: str = Field(
34
+ description=("The unique issuer identifier for an authorization server. "
35
+ "Required for validation and used to derive the default JWKS URI "
36
+ "(<issuer_url>/.well-known/jwks.json) if `jwks_uri` and `discovery_url` are not provided."), )
37
+ scopes: list[str] = Field(
38
+ default_factory=list,
39
+ description="Scopes required by this API. Validation ensures the token grants all listed scopes.",
40
+ )
41
+ audience: str | None = Field(
42
+ default=None,
43
+ description=(
44
+ "Expected audience (`aud`) claim for this API. If set, validation will reject tokens without this audience."
45
+ ),
46
+ )
47
+
48
+ # JWT verification params
49
+ jwks_uri: str | None = Field(
50
+ default=None,
51
+ description=("Direct JWKS endpoint URI for JWT signature verification. "
52
+ "Optional if discovery or issuer is provided."),
53
+ )
54
+ discovery_url: str | None = Field(
55
+ default=None,
56
+ description=("OIDC discovery metadata URL. Used to automatically resolve JWKS and introspection endpoints."),
57
+ )
58
+
59
+ # Opaque token (introspection) params
60
+ introspection_endpoint: str | None = Field(
61
+ default=None,
62
+ description=("RFC 7662 token introspection endpoint. "
63
+ "Required for opaque token validation and must be used with `client_id` and `client_secret`."),
64
+ )
65
+ client_id: str | None = Field(
66
+ default=None,
67
+ description="OAuth2 client ID for authenticating to the introspection endpoint (opaque token validation).",
68
+ )
69
+ client_secret: str | None = Field(
70
+ default=None,
71
+ description="OAuth2 client secret for authenticating to the introspection endpoint (opaque token validation).",
72
+ )
73
+
74
+ @staticmethod
75
+ def _is_https_or_localhost(url: str) -> bool:
76
+ try:
77
+ value = urlparse(url)
78
+ if not value.scheme or not value.netloc:
79
+ return False
80
+ if value.scheme == "https":
81
+ return True
82
+ return value.scheme == "http" and (value.hostname in {"localhost", "127.0.0.1", "::1"})
83
+ except Exception:
84
+ return False
85
+
86
+ @field_validator("issuer_url", "jwks_uri", "discovery_url", "introspection_endpoint")
87
+ @classmethod
88
+ def _require_valid_url(cls, value: str | None, info):
89
+ if value is None:
90
+ return value
91
+ if not cls._is_https_or_localhost(value):
92
+ raise ValueError(f"{info.field_name} must be HTTPS (http allowed only for localhost). Got: {value}")
93
+ return value
94
+
95
+ # ---------- Cross-field validation: ensure at least one viable path ----------
96
+
97
+ @model_validator(mode="after")
98
+ def _ensure_verification_path(self):
99
+ """
100
+ JWT path viable if any of: jwks_uri OR discovery_url OR issuer_url (fallback JWKS).
101
+ Opaque path viable if: introspection_endpoint AND client_id AND client_secret.
102
+ """
103
+ has_jwt_path = bool(self.jwks_uri or self.discovery_url or self.issuer_url)
104
+ has_opaque_path = bool(self.introspection_endpoint and self.client_id and self.client_secret)
105
+
106
+ # If introspection endpoint is set, enforce creds are present
107
+ if self.introspection_endpoint:
108
+ missing = []
109
+ if not self.client_id:
110
+ missing.append("client_id")
111
+ if not self.client_secret:
112
+ missing.append("client_secret")
113
+ if missing:
114
+ raise ValueError(
115
+ f"introspection_endpoint configured but missing required credentials: {', '.join(missing)}")
116
+
117
+ # Require at least one path
118
+ if not (has_jwt_path or has_opaque_path):
119
+ raise ValueError("Invalid configuration: no verification method available. "
120
+ "Configure one of the following:\n"
121
+ " • JWT path: set jwks_uri OR discovery_url OR issuer_url (for JWKS fallback)\n"
122
+ " • Opaque path: set introspection_endpoint + client_id + client_secret")
123
+
124
+ return self
nat/builder/builder.py CHANGED
@@ -24,9 +24,11 @@ from nat.authentication.interfaces import AuthProviderBase
24
24
  from nat.builder.context import Context
25
25
  from nat.builder.framework_enum import LLMFrameworkEnum
26
26
  from nat.builder.function import Function
27
+ from nat.builder.function import FunctionGroup
27
28
  from nat.data_models.authentication import AuthProviderBaseConfig
28
29
  from nat.data_models.component_ref import AuthenticationRef
29
30
  from nat.data_models.component_ref import EmbedderRef
31
+ from nat.data_models.component_ref import FunctionGroupRef
30
32
  from nat.data_models.component_ref import FunctionRef
31
33
  from nat.data_models.component_ref import LLMRef
32
34
  from nat.data_models.component_ref import MemoryRef
@@ -36,6 +38,7 @@ from nat.data_models.component_ref import TTCStrategyRef
36
38
  from nat.data_models.embedder import EmbedderBaseConfig
37
39
  from nat.data_models.evaluator import EvaluatorBaseConfig
38
40
  from nat.data_models.function import FunctionBaseConfig
41
+ from nat.data_models.function import FunctionGroupBaseConfig
39
42
  from nat.data_models.function_dependencies import FunctionDependencies
40
43
  from nat.data_models.llm import LLMBaseConfig
41
44
  from nat.data_models.memory import MemoryBaseConfig
@@ -64,18 +67,33 @@ class Builder(ABC):
64
67
  async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
65
68
  pass
66
69
 
70
+ @abstractmethod
71
+ async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
72
+ pass
73
+
67
74
  @abstractmethod
68
75
  def get_function(self, name: str | FunctionRef) -> Function:
69
76
  pass
70
77
 
78
+ @abstractmethod
79
+ def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
80
+ pass
81
+
71
82
  def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
72
83
 
73
84
  return [self.get_function(name) for name in function_names]
74
85
 
86
+ def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
87
+ return [self.get_function_group(name) for name in function_group_names]
88
+
75
89
  @abstractmethod
76
90
  def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
77
91
  pass
78
92
 
93
+ @abstractmethod
94
+ def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
95
+ pass
96
+
79
97
  @abstractmethod
80
98
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
81
99
  pass
@@ -88,10 +106,11 @@ class Builder(ABC):
88
106
  def get_workflow_config(self) -> FunctionBaseConfig:
89
107
  pass
90
108
 
91
- def get_tools(self, tool_names: Sequence[str | FunctionRef],
109
+ @abstractmethod
110
+ def get_tools(self,
111
+ tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
92
112
  wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
93
-
94
- return [self.get_tool(fn_name=n, wrapper_type=wrapper_type) for n in tool_names]
113
+ pass
95
114
 
96
115
  @abstractmethod
97
116
  def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
@@ -257,8 +276,12 @@ class Builder(ABC):
257
276
  def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
258
277
  pass
259
278
 
279
+ @abstractmethod
280
+ def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
281
+ pass
282
+
260
283
 
261
- class EvalBuilder(Builder):
284
+ class EvalBuilder(ABC):
262
285
 
263
286
  @abstractmethod
264
287
  async def add_evaluator(self, name: str, config: EvaluatorBaseConfig):
@@ -30,6 +30,7 @@ from nat.data_models.component_ref import generate_instance_id
30
30
  from nat.data_models.config import Config
31
31
  from nat.data_models.embedder import EmbedderBaseConfig
32
32
  from nat.data_models.function import FunctionBaseConfig
33
+ from nat.data_models.function import FunctionGroupBaseConfig
33
34
  from nat.data_models.llm import LLMBaseConfig
34
35
  from nat.data_models.memory import MemoryBaseConfig
35
36
  from nat.data_models.object_store import ObjectStoreBaseConfig
@@ -48,6 +49,7 @@ _component_group_order = [
48
49
  ComponentGroup.OBJECT_STORES,
49
50
  ComponentGroup.RETRIEVERS,
50
51
  ComponentGroup.TTC_STRATEGIES,
52
+ ComponentGroup.FUNCTION_GROUPS,
51
53
  ComponentGroup.FUNCTIONS,
52
54
  ]
53
55
 
@@ -107,6 +109,8 @@ def group_from_component(component: TypedBaseModel) -> ComponentGroup | None:
107
109
  return ComponentGroup.EMBEDDERS
108
110
  if (isinstance(component, FunctionBaseConfig)):
109
111
  return ComponentGroup.FUNCTIONS
112
+ if (isinstance(component, FunctionGroupBaseConfig)):
113
+ return ComponentGroup.FUNCTION_GROUPS
110
114
  if (isinstance(component, LLMBaseConfig)):
111
115
  return ComponentGroup.LLMS
112
116
  if (isinstance(component, MemoryBaseConfig)):
@@ -254,9 +258,9 @@ def build_dependency_sequence(config: "Config") -> list[ComponentInstanceData]:
254
258
  runtime instance references.
255
259
  """
256
260
 
257
- total_node_count = len(config.embedders) + len(config.functions) + len(config.llms) + len(config.memory) + len(
258
- config.object_stores) + len(config.retrievers) + len(config.ttc_strategies) + len(
259
- config.authentication) + 1 # +1 for the workflow
261
+ total_node_count = (len(config.embedders) + len(config.functions) + len(config.function_groups) + len(config.llms) +
262
+ len(config.memory) + len(config.object_stores) + len(config.retrievers) +
263
+ len(config.ttc_strategies) + len(config.authentication) + 1) # +1 for the workflow
260
264
 
261
265
  dependency_map: dict
262
266
  dependency_graph: nx.DiGraph
nat/builder/context.py CHANGED
@@ -69,12 +69,10 @@ class ContextState(metaclass=Singleton):
69
69
  self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
70
70
  self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
71
71
  self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
72
- self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes())
73
- self.event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=Subject())
74
- self.active_function: ContextVar[InvocationNode] = ContextVar("active_function",
75
- default=InvocationNode(function_id="root",
76
- function_name="root"))
77
- self.active_span_id_stack: ContextVar[list[str]] = ContextVar("active_span_id_stack", default=["root"])
72
+ self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
73
+ self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None)
74
+ self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None)
75
+ self._active_span_id_stack: ContextVar[list[str] | None] = ContextVar("active_span_id_stack", default=None)
78
76
 
79
77
  # Default is a lambda no-op which returns NoneType
80
78
  self.user_input_callback: ContextVar[Callable[[InteractionPrompt], Awaitable[HumanResponse | None]]
@@ -85,6 +83,30 @@ class ContextState(metaclass=Singleton):
85
83
  Awaitable[AuthenticatedContext]]
86
84
  | None] = ContextVar("user_auth_callback", default=None)
87
85
 
86
+ @property
87
+ def metadata(self) -> ContextVar[RequestAttributes]:
88
+ if self._metadata.get() is None:
89
+ self._metadata.set(RequestAttributes())
90
+ return typing.cast(ContextVar[RequestAttributes], self._metadata)
91
+
92
+ @property
93
+ def active_function(self) -> ContextVar[InvocationNode]:
94
+ if self._active_function.get() is None:
95
+ self._active_function.set(InvocationNode(function_id="root", function_name="root"))
96
+ return typing.cast(ContextVar[InvocationNode], self._active_function)
97
+
98
+ @property
99
+ def event_stream(self) -> ContextVar[Subject[IntermediateStep]]:
100
+ if self._event_stream.get() is None:
101
+ self._event_stream.set(Subject())
102
+ return typing.cast(ContextVar[Subject[IntermediateStep]], self._event_stream)
103
+
104
+ @property
105
+ def active_span_id_stack(self) -> ContextVar[list[str]]:
106
+ if self._active_span_id_stack.get() is None:
107
+ self._active_span_id_stack.set(["root"])
108
+ return typing.cast(ContextVar[list[str]], self._active_span_id_stack)
109
+
88
110
  @staticmethod
89
111
  def get() -> "ContextState":
90
112
  return ContextState()
nat/builder/function.py CHANGED
@@ -14,12 +14,14 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
+ import re
17
18
  import typing
18
19
  from abc import ABC
19
20
  from abc import abstractmethod
20
21
  from collections.abc import AsyncGenerator
21
22
  from collections.abc import Awaitable
22
23
  from collections.abc import Callable
24
+ from collections.abc import Sequence
23
25
 
24
26
  from pydantic import BaseModel
25
27
 
@@ -29,7 +31,9 @@ from nat.builder.function_base import InputT
29
31
  from nat.builder.function_base import SingleOutputT
30
32
  from nat.builder.function_base import StreamingOutputT
31
33
  from nat.builder.function_info import FunctionInfo
34
+ from nat.data_models.function import EmptyFunctionConfig
32
35
  from nat.data_models.function import FunctionBaseConfig
36
+ from nat.data_models.function import FunctionGroupBaseConfig
33
37
 
34
38
  _InvokeFnT = Callable[[InputT], Awaitable[SingleOutputT]]
35
39
  _StreamFnT = Callable[[InputT], AsyncGenerator[StreamingOutputT]]
@@ -342,3 +346,312 @@ class LambdaFunction(Function[InputT, StreamingOutputT, SingleOutputT]):
342
346
  pass
343
347
 
344
348
  return FunctionImpl(config=config, info=info, instance_name=instance_name)
349
+
350
+
351
+ class FunctionGroup:
352
+ """
353
+ A group of functions that can be used together, sharing the same configuration, context, and resources.
354
+ """
355
+
356
+ def __init__(self,
357
+ *,
358
+ config: FunctionGroupBaseConfig,
359
+ instance_name: str | None = None,
360
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None):
361
+ """
362
+ Creates a new function group.
363
+
364
+ Parameters
365
+ ----------
366
+ config : FunctionGroupBaseConfig
367
+ The configuration for the function group.
368
+ instance_name : str | None, optional
369
+ The name of the function group. If not provided, the type of the function group will be used.
370
+ filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
371
+ A callback function to additionally filter the functions in the function group dynamically when
372
+ the functions are accessed via any accessor method.
373
+ """
374
+ self._config = config
375
+ self._instance_name = instance_name or config.type
376
+ self._functions: dict[str, Function] = dict()
377
+ self._filter_fn = filter_fn
378
+ self._per_function_filter_fn: dict[str, Callable[[str], bool]] = dict()
379
+
380
+ def add_function(self,
381
+ name: str,
382
+ fn: Callable,
383
+ *,
384
+ input_schema: type[BaseModel] | None = None,
385
+ description: str | None = None,
386
+ converters: list[Callable] | None = None,
387
+ filter_fn: Callable[[str], bool] | None = None):
388
+ """
389
+ Adds a function to the function group.
390
+
391
+ Parameters
392
+ ----------
393
+ name : str
394
+ The name of the function.
395
+ fn : Callable
396
+ The function to add to the function group.
397
+ input_schema : type[BaseModel] | None, optional
398
+ The input schema for the function.
399
+ description : str | None, optional
400
+ The description of the function.
401
+ converters : list[Callable] | None, optional
402
+ The converters to use for the function.
403
+ filter_fn : Callable[[str], bool] | None, optional
404
+ A callback to determine if the function should be included in the function group. The
405
+ callback will be called with the function name. The callback is invoked dynamically when
406
+ the functions are accessed via any accessor method such as `get_accessible_functions`,
407
+ `get_included_functions`, `get_excluded_functions`, `get_all_functions`.
408
+
409
+ Raises
410
+ ------
411
+ ValueError
412
+ When the function name is empty or blank.
413
+ When the function name contains invalid characters.
414
+ When the function already exists in the function group.
415
+ """
416
+ if not name.strip():
417
+ raise ValueError("Function name cannot be empty or blank")
418
+ if not re.match(r"^[a-zA-Z0-9_-]+$", name):
419
+ raise ValueError(f"Function name can only contain letters, numbers, underscores, and hyphens: {name}")
420
+ if name in self._functions:
421
+ raise ValueError(f"Function {name} already exists in function group {self._instance_name}")
422
+
423
+ info = FunctionInfo.from_fn(fn, input_schema=input_schema, description=description, converters=converters)
424
+ full_name = self._get_fn_name(name)
425
+ lambda_fn = LambdaFunction.from_info(config=EmptyFunctionConfig(), info=info, instance_name=full_name)
426
+ self._functions[name] = lambda_fn
427
+ if filter_fn:
428
+ self._per_function_filter_fn[name] = filter_fn
429
+
430
+ def get_config(self) -> FunctionGroupBaseConfig:
431
+ """
432
+ Returns the configuration for the function group.
433
+
434
+ Returns
435
+ -------
436
+ FunctionGroupBaseConfig
437
+ The configuration for the function group.
438
+ """
439
+ return self._config
440
+
441
+ def _get_fn_name(self, name: str) -> str:
442
+ return f"{self._instance_name}.{name}"
443
+
444
+ def _fn_should_be_included(self, name: str) -> bool:
445
+ return (name not in self._per_function_filter_fn or self._per_function_filter_fn[name](name))
446
+
447
+ def _get_all_but_excluded_functions(
448
+ self,
449
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
450
+ ) -> dict[str, Function]:
451
+ """
452
+ Returns a dictionary of all functions in the function group except the excluded functions.
453
+ """
454
+ missing = set(self._config.exclude) - set(self._functions.keys())
455
+ if missing:
456
+ raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
457
+ filter_fn = filter_fn or self._filter_fn or (lambda x: x)
458
+ excluded = set(self._config.exclude)
459
+ included = set(filter_fn(list(self._functions.keys())))
460
+
461
+ def predicate(name: str) -> bool:
462
+ if name in excluded:
463
+ return False
464
+ if not self._fn_should_be_included(name):
465
+ return False
466
+ return name in included
467
+
468
+ return {self._get_fn_name(name): self._functions[name] for name in self._functions if predicate(name)}
469
+
470
+ def get_accessible_functions(
471
+ self,
472
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
473
+ ) -> dict[str, Function]:
474
+ """
475
+ Returns a dictionary of all accessible functions in the function group.
476
+
477
+ First, the functions are filtered by the function group's configuration.
478
+ If the function group is configured to:
479
+ - include some functions, this will return only the included functions.
480
+ - not include or exclude any function, this will return all functions in the group.
481
+ - exclude some functions, this will return all functions in the group except the excluded functions.
482
+
483
+ Then, the functions are filtered by filter function and per-function filter functions.
484
+
485
+ Parameters
486
+ ----------
487
+ filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
488
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
489
+ then fall back to the function group's filter function. If no filter function is set for the function group
490
+ all functions will be returned.
491
+
492
+ Returns
493
+ -------
494
+ dict[str, Function]
495
+ A dictionary of all accessible functions in the function group.
496
+
497
+ Raises
498
+ ------
499
+ ValueError
500
+ When the function group is configured to include functions that are not found in the group.
501
+ """
502
+ if self._config.include:
503
+ return self.get_included_functions(filter_fn=filter_fn)
504
+ if self._config.exclude:
505
+ return self._get_all_but_excluded_functions(filter_fn=filter_fn)
506
+ return self.get_all_functions(filter_fn=filter_fn)
507
+
508
+ def get_excluded_functions(
509
+ self,
510
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
511
+ ) -> dict[str, Function]:
512
+ """
513
+ Returns a dictionary of all functions in the function group which are configured to be excluded or filtered
514
+ out by a filter function or per-function filter function.
515
+
516
+ Parameters
517
+ ----------
518
+ filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
519
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
520
+ then fall back to the function group's filter function. If no filter function is set for the function group
521
+ then no functions will be added to the returned dictionary.
522
+
523
+ Returns
524
+ -------
525
+ dict[str, Function]
526
+ A dictionary of all excluded functions in the function group.
527
+
528
+ Raises
529
+ ------
530
+ ValueError
531
+ When the function group is configured to exclude functions that are not found in the group.
532
+ """
533
+ missing = set(self._config.exclude) - set(self._functions.keys())
534
+ if missing:
535
+ raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
536
+ filter_fn = filter_fn or self._filter_fn or (lambda x: x)
537
+ excluded = set(self._config.exclude)
538
+ included = set(filter_fn(list(self._functions.keys())))
539
+
540
+ def predicate(name: str) -> bool:
541
+ if name in excluded:
542
+ return True
543
+ if not self._fn_should_be_included(name):
544
+ return True
545
+ return name not in included
546
+
547
+ return {self._get_fn_name(name): self._functions[name] for name in self._functions if predicate(name)}
548
+
549
+ def get_included_functions(
550
+ self,
551
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
552
+ ) -> dict[str, Function]:
553
+ """
554
+ Returns a dictionary of all functions in the function group which are:
555
+ - configured to be included and added to the global function registry
556
+ - not configured to be excluded.
557
+ - not filtered out by a filter function.
558
+
559
+ Parameters
560
+ ----------
561
+ filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
562
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
563
+ then fall back to the function group's filter function. If no filter function is set for the function group
564
+ all functions will be returned.
565
+
566
+ Returns
567
+ -------
568
+ dict[str, Function]
569
+ A dictionary of all included functions in the function group.
570
+
571
+ Raises
572
+ ------
573
+ ValueError
574
+ When the function group is configured to include functions that are not found in the group.
575
+ """
576
+ missing = set(self._config.include) - set(self._functions.keys())
577
+ if missing:
578
+ raise ValueError(f"Unknown included functions: {sorted(missing)}")
579
+ filter_fn = filter_fn or self._filter_fn or (lambda x: x)
580
+ included = set(filter_fn(list(self._config.include)))
581
+ included = {name for name in included if self._fn_should_be_included(name)}
582
+ return {self._get_fn_name(name): self._functions[name] for name in included}
583
+
584
+ def get_all_functions(
585
+ self,
586
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
587
+ ) -> dict[str, Function]:
588
+ """
589
+ Returns a dictionary of all functions in the function group, regardless if they are included or excluded.
590
+
591
+ If a filter function has been set, the returned functions will additionally be filtered by the callback.
592
+
593
+ Parameters
594
+ ----------
595
+ filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
596
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
597
+ then fall back to the function group's filter function. If no filter function is set for the function group
598
+ all functions will be returned.
599
+
600
+ Returns
601
+ -------
602
+ dict[str, Function]
603
+ A dictionary of all functions in the function group.
604
+ """
605
+ filter_fn = filter_fn or self._filter_fn or (lambda x: x)
606
+ included = set(filter_fn(list(self._functions.keys())))
607
+ included = {name for name in included if self._fn_should_be_included(name)}
608
+ return {self._get_fn_name(name): self._functions[name] for name in included}
609
+
610
+ def set_filter_fn(self, filter_fn: Callable[[Sequence[str]], Sequence[str]]):
611
+ """
612
+ Sets the filter function for the function group.
613
+
614
+ Parameters
615
+ ----------
616
+ filter_fn : Callable[[Sequence[str]], Sequence[str]]
617
+ The filter function to set for the function group.
618
+ """
619
+ self._filter_fn = filter_fn
620
+
621
+ def set_per_function_filter_fn(self, name: str, filter_fn: Callable[[str], bool]):
622
+ """
623
+ Sets the a per-function filter function for the a function within the function group.
624
+
625
+ Parameters
626
+ ----------
627
+ name : str
628
+ The name of the function.
629
+ filter_fn : Callable[[str], bool]
630
+ The per-function filter function to set for the function group.
631
+
632
+ Raises
633
+ ------
634
+ ValueError
635
+ When the function is not found in the function group.
636
+ """
637
+ if name not in self._functions:
638
+ raise ValueError(f"Function {name} not found in function group {self._instance_name}")
639
+ self._per_function_filter_fn[name] = filter_fn
640
+
641
+ def set_instance_name(self, instance_name: str):
642
+ """
643
+ Sets the instance name for the function group.
644
+
645
+ Parameters
646
+ ----------
647
+ instance_name : str
648
+ The instance name to set for the function group.
649
+ """
650
+ self._instance_name = instance_name
651
+
652
+ @property
653
+ def instance_name(self) -> str:
654
+ """
655
+ Returns the instance name for the function group.
656
+ """
657
+ return self._instance_name
@@ -233,7 +233,7 @@ class FunctionDescriptor:
233
233
 
234
234
  is_input_typed = all([a != sig.empty for a in annotations])
235
235
 
236
- input_type = tuple[*annotations] if is_input_typed else None # noqa: syntax-error
236
+ input_type = tuple[*annotations] if is_input_typed else None
237
237
 
238
238
  # Get the base type here removing all annotations and async generators
239
239
  output_annotation_decomp = DecomposedType(sig.return_annotation).get_base_type()
nat/builder/workflow.py CHANGED
@@ -20,6 +20,7 @@ from typing import Any
20
20
  from nat.builder.context import ContextState
21
21
  from nat.builder.embedder import EmbedderProviderInfo
22
22
  from nat.builder.function import Function
23
+ from nat.builder.function import FunctionGroup
23
24
  from nat.builder.function_base import FunctionBase
24
25
  from nat.builder.function_base import InputT
25
26
  from nat.builder.function_base import SingleOutputT
@@ -44,6 +45,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
44
45
  config: Config,
45
46
  entry_fn: Function[InputT, StreamingOutputT, SingleOutputT],
46
47
  functions: dict[str, Function] | None = None,
48
+ function_groups: dict[str, FunctionGroup] | None = None,
47
49
  llms: dict[str, LLMProviderInfo] | None = None,
48
50
  embeddings: dict[str, EmbedderProviderInfo] | None = None,
49
51
  memory: dict[str, MemoryEditor] | None = None,
@@ -59,6 +61,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
59
61
 
60
62
  self.config = config
61
63
  self.functions = functions or {}
64
+ self.function_groups = function_groups or {}
62
65
  self.llms = llms or {}
63
66
  self.embeddings = embeddings or {}
64
67
  self.memory = memory or {}
@@ -126,6 +129,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
126
129
  config: Config,
127
130
  entry_fn: Function[InputT, StreamingOutputT, SingleOutputT],
128
131
  functions: dict[str, Function] | None = None,
132
+ function_groups: dict[str, FunctionGroup] | None = None,
129
133
  llms: dict[str, LLMProviderInfo] | None = None,
130
134
  embeddings: dict[str, EmbedderProviderInfo] | None = None,
131
135
  memory: dict[str, MemoryEditor] | None = None,
@@ -145,6 +149,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
145
149
  return WorkflowImpl(config=config,
146
150
  entry_fn=entry_fn,
147
151
  functions=functions,
152
+ function_groups=function_groups,
148
153
  llms=llms,
149
154
  embeddings=embeddings,
150
155
  memory=memory,