nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250) hide show
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +50 -22
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +54 -27
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +68 -17
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +2 -3
  53. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  54. nat/cli/commands/workflow/workflow_commands.py +62 -22
  55. nat/cli/entrypoint.py +8 -10
  56. nat/cli/main.py +3 -0
  57. nat/cli/register_workflow.py +38 -4
  58. nat/cli/type_registry.py +75 -6
  59. nat/control_flow/__init__.py +0 -0
  60. nat/control_flow/register.py +20 -0
  61. nat/control_flow/router_agent/__init__.py +0 -0
  62. nat/control_flow/router_agent/agent.py +329 -0
  63. nat/control_flow/router_agent/prompt.py +48 -0
  64. nat/control_flow/router_agent/register.py +91 -0
  65. nat/control_flow/sequential_executor.py +166 -0
  66. nat/data_models/agent.py +34 -0
  67. nat/data_models/api_server.py +74 -66
  68. nat/data_models/authentication.py +23 -9
  69. nat/data_models/common.py +1 -1
  70. nat/data_models/component.py +2 -0
  71. nat/data_models/component_ref.py +11 -0
  72. nat/data_models/config.py +41 -17
  73. nat/data_models/dataset_handler.py +1 -1
  74. nat/data_models/discovery_metadata.py +4 -4
  75. nat/data_models/evaluate.py +4 -1
  76. nat/data_models/function.py +34 -0
  77. nat/data_models/function_dependencies.py +14 -6
  78. nat/data_models/gated_field_mixin.py +242 -0
  79. nat/data_models/intermediate_step.py +3 -3
  80. nat/data_models/optimizable.py +119 -0
  81. nat/data_models/optimizer.py +149 -0
  82. nat/data_models/span.py +41 -3
  83. nat/data_models/swe_bench_model.py +1 -1
  84. nat/data_models/temperature_mixin.py +44 -0
  85. nat/data_models/thinking_mixin.py +86 -0
  86. nat/data_models/top_p_mixin.py +44 -0
  87. nat/embedder/nim_embedder.py +1 -1
  88. nat/embedder/openai_embedder.py +1 -1
  89. nat/embedder/register.py +0 -1
  90. nat/eval/config.py +3 -1
  91. nat/eval/dataset_handler/dataset_handler.py +71 -7
  92. nat/eval/evaluate.py +86 -31
  93. nat/eval/evaluator/base_evaluator.py +1 -1
  94. nat/eval/evaluator/evaluator_model.py +13 -0
  95. nat/eval/intermediate_step_adapter.py +1 -1
  96. nat/eval/rag_evaluator/evaluate.py +2 -2
  97. nat/eval/rag_evaluator/register.py +3 -3
  98. nat/eval/register.py +4 -1
  99. nat/eval/remote_workflow.py +3 -3
  100. nat/eval/runtime_evaluator/__init__.py +14 -0
  101. nat/eval/runtime_evaluator/evaluate.py +123 -0
  102. nat/eval/runtime_evaluator/register.py +100 -0
  103. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  104. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  105. nat/eval/trajectory_evaluator/register.py +1 -1
  106. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  107. nat/eval/utils/eval_trace_ctx.py +89 -0
  108. nat/eval/utils/weave_eval.py +18 -9
  109. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  110. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  111. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  112. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  113. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  114. nat/experimental/test_time_compute/register.py +0 -1
  115. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  116. nat/front_ends/console/authentication_flow_handler.py +82 -30
  117. nat/front_ends/console/console_front_end_plugin.py +8 -5
  118. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  119. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  120. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  121. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  122. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  123. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
  124. nat/front_ends/fastapi/job_store.py +518 -99
  125. nat/front_ends/fastapi/main.py +11 -19
  126. nat/front_ends/fastapi/message_handler.py +13 -14
  127. nat/front_ends/fastapi/message_validator.py +19 -19
  128. nat/front_ends/fastapi/response_helpers.py +4 -4
  129. nat/front_ends/fastapi/step_adaptor.py +2 -2
  130. nat/front_ends/fastapi/utils.py +57 -0
  131. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  132. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  133. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  134. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  135. nat/front_ends/mcp/tool_converter.py +44 -14
  136. nat/front_ends/register.py +0 -1
  137. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  138. nat/llm/aws_bedrock_llm.py +24 -12
  139. nat/llm/azure_openai_llm.py +13 -6
  140. nat/llm/litellm_llm.py +69 -0
  141. nat/llm/nim_llm.py +20 -8
  142. nat/llm/openai_llm.py +14 -6
  143. nat/llm/register.py +4 -1
  144. nat/llm/utils/env_config_value.py +2 -3
  145. nat/llm/utils/thinking.py +215 -0
  146. nat/meta/pypi.md +9 -9
  147. nat/object_store/register.py +0 -1
  148. nat/observability/exporter/base_exporter.py +3 -3
  149. nat/observability/exporter/file_exporter.py +1 -1
  150. nat/observability/exporter/processing_exporter.py +309 -81
  151. nat/observability/exporter/span_exporter.py +35 -15
  152. nat/observability/exporter_manager.py +7 -7
  153. nat/observability/mixin/file_mixin.py +7 -7
  154. nat/observability/mixin/redaction_config_mixin.py +42 -0
  155. nat/observability/mixin/tagging_config_mixin.py +62 -0
  156. nat/observability/mixin/type_introspection_mixin.py +420 -107
  157. nat/observability/processor/batching_processor.py +5 -7
  158. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  159. nat/observability/processor/processor.py +3 -0
  160. nat/observability/processor/processor_factory.py +70 -0
  161. nat/observability/processor/redaction/__init__.py +24 -0
  162. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  163. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  164. nat/observability/processor/redaction/redaction_processor.py +177 -0
  165. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  166. nat/observability/processor/span_tagging_processor.py +68 -0
  167. nat/observability/register.py +6 -4
  168. nat/profiler/calc/calc_runner.py +3 -4
  169. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  170. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  171. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  172. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  173. nat/profiler/data_frame_row.py +1 -1
  174. nat/profiler/decorators/framework_wrapper.py +62 -13
  175. nat/profiler/decorators/function_tracking.py +160 -3
  176. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  177. nat/profiler/forecasting/models/linear_model.py +1 -1
  178. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  179. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  180. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  181. nat/profiler/inference_optimization/data_models.py +3 -3
  182. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
  183. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  184. nat/profiler/parameter_optimization/__init__.py +0 -0
  185. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  186. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  187. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  188. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  189. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  190. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  191. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  192. nat/profiler/profile_runner.py +14 -9
  193. nat/profiler/utils.py +4 -2
  194. nat/registry_handlers/local/local_handler.py +2 -2
  195. nat/registry_handlers/package_utils.py +1 -2
  196. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  197. nat/registry_handlers/register.py +3 -4
  198. nat/registry_handlers/rest/rest_handler.py +12 -13
  199. nat/retriever/milvus/retriever.py +2 -2
  200. nat/retriever/nemo_retriever/retriever.py +1 -1
  201. nat/retriever/register.py +0 -1
  202. nat/runtime/loader.py +2 -2
  203. nat/runtime/runner.py +106 -8
  204. nat/runtime/session.py +69 -8
  205. nat/settings/global_settings.py +16 -5
  206. nat/tool/chat_completion.py +5 -2
  207. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  208. nat/tool/datetime_tools.py +49 -9
  209. nat/tool/document_search.py +2 -2
  210. nat/tool/github_tools.py +450 -0
  211. nat/tool/memory_tools/get_memory_tool.py +1 -1
  212. nat/tool/nvidia_rag.py +1 -1
  213. nat/tool/register.py +2 -9
  214. nat/tool/retriever.py +3 -2
  215. nat/utils/callable_utils.py +70 -0
  216. nat/utils/data_models/schema_validator.py +3 -3
  217. nat/utils/decorators.py +210 -0
  218. nat/utils/exception_handlers/automatic_retries.py +104 -51
  219. nat/utils/exception_handlers/schemas.py +1 -1
  220. nat/utils/io/yaml_tools.py +2 -2
  221. nat/utils/log_levels.py +25 -0
  222. nat/utils/reactive/base/observable_base.py +2 -2
  223. nat/utils/reactive/base/observer_base.py +1 -1
  224. nat/utils/reactive/observable.py +2 -2
  225. nat/utils/reactive/observer.py +4 -4
  226. nat/utils/reactive/subscription.py +1 -1
  227. nat/utils/settings/global_settings.py +6 -8
  228. nat/utils/type_converter.py +4 -3
  229. nat/utils/type_utils.py +9 -5
  230. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
  231. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
  232. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
  233. nat/cli/commands/info/list_mcp.py +0 -304
  234. nat/tool/github_tools/create_github_commit.py +0 -133
  235. nat/tool/github_tools/create_github_issue.py +0 -87
  236. nat/tool/github_tools/create_github_pr.py +0 -106
  237. nat/tool/github_tools/get_github_file.py +0 -106
  238. nat/tool/github_tools/get_github_issue.py +0 -166
  239. nat/tool/github_tools/get_github_pr.py +0 -256
  240. nat/tool/github_tools/update_github_issue.py +0 -100
  241. nat/tool/mcp/exceptions.py +0 -142
  242. nat/tool/mcp/mcp_client.py +0 -255
  243. nat/tool/mcp/mcp_tool.py +0 -96
  244. nat/utils/exception_handlers/mcp.py +0 -211
  245. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  246. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  247. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
  248. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  249. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  250. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
@@ -13,74 +13,104 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import logging
17
+ from collections.abc import Awaitable
18
+ from collections.abc import Callable
19
+ from datetime import UTC
16
20
  from datetime import datetime
17
- from datetime import timezone
18
21
 
22
+ import httpx
19
23
  from authlib.integrations.httpx_client import OAuth2Client as AuthlibOAuth2Client
20
24
  from pydantic import SecretStr
21
25
 
22
26
  from nat.authentication.interfaces import AuthProviderBase
23
27
  from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
24
28
  from nat.builder.context import Context
29
+ from nat.data_models.authentication import AuthenticatedContext
25
30
  from nat.data_models.authentication import AuthFlowType
26
31
  from nat.data_models.authentication import AuthResult
27
32
  from nat.data_models.authentication import BearerTokenCred
28
33
 
34
+ logger = logging.getLogger(__name__)
35
+
29
36
 
30
37
  class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]):
31
38
 
32
- def __init__(self, config: OAuth2AuthCodeFlowProviderConfig):
39
+ def __init__(self, config: OAuth2AuthCodeFlowProviderConfig, token_storage=None):
33
40
  super().__init__(config)
34
- self._authenticated_tokens: dict[str, AuthResult] = {}
35
- self._context = Context.get()
41
+ self._auth_callback = None
42
+ # Always use token storage - defaults to in-memory if not provided
43
+ if token_storage is None:
44
+ from nat.plugins.mcp.auth.token_storage import InMemoryTokenStorage
45
+ self._token_storage = InMemoryTokenStorage()
46
+ else:
47
+ self._token_storage = token_storage
36
48
 
37
49
  async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> AuthResult | None:
38
50
  refresh_token = auth_result.raw.get("refresh_token")
39
51
  if not isinstance(refresh_token, str):
40
52
  return None
41
53
 
42
- with AuthlibOAuth2Client(
43
- client_id=self.config.client_id,
44
- client_secret=self.config.client_secret,
45
- ) as client:
46
- try:
54
+ try:
55
+ with AuthlibOAuth2Client(
56
+ client_id=self.config.client_id,
57
+ client_secret=self.config.client_secret,
58
+ ) as client:
47
59
  new_token_data = client.refresh_token(self.config.token_url, refresh_token=refresh_token)
48
- except Exception:
49
- # On any failure, we'll fall back to the full auth flow.
50
- return None
51
60
 
52
- expires_at_ts = new_token_data.get("expires_at")
53
- new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=timezone.utc) if expires_at_ts else None
61
+ expires_at_ts = new_token_data.get("expires_at")
62
+ new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=UTC) if expires_at_ts else None
54
63
 
55
- new_auth_result = AuthResult(
56
- credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
57
- token_expires_at=new_expires_at,
58
- raw=new_token_data,
59
- )
64
+ new_auth_result = AuthResult(
65
+ credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
66
+ token_expires_at=new_expires_at,
67
+ raw=new_token_data,
68
+ )
60
69
 
61
- self._authenticated_tokens[user_id] = new_auth_result
70
+ await self._token_storage.store(user_id, new_auth_result)
71
+ except httpx.HTTPStatusError:
72
+ return None
73
+ except httpx.RequestError:
74
+ return None
75
+ except Exception:
76
+ # On any other failure, we'll fall back to the full auth flow.
77
+ return None
62
78
 
63
79
  return new_auth_result
64
80
 
65
- async def authenticate(self, user_id: str | None = None) -> AuthResult:
66
- if user_id is None and hasattr(Context.get(), "metadata") and hasattr(
67
- Context.get().metadata, "cookies") and Context.get().metadata.cookies is not None:
68
- session_id = Context.get().metadata.cookies.get("nat-session", None)
81
+ def _set_custom_auth_callback(self,
82
+ auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
83
+ Awaitable[AuthenticatedContext]]):
84
+ self._auth_callback = auth_callback
85
+
86
+ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
87
+ context = Context.get()
88
+ if user_id is None and hasattr(context, "metadata") and hasattr(
89
+ context.metadata, "cookies") and context.metadata.cookies is not None:
90
+ session_id = context.metadata.cookies.get("nat-session", None)
69
91
  if not session_id:
70
92
  raise RuntimeError("Authentication failed. No session ID found. Cannot identify user.")
71
93
 
72
94
  user_id = session_id
73
95
 
74
- if user_id and user_id in self._authenticated_tokens:
75
- auth_result = self._authenticated_tokens[user_id]
76
- if not auth_result.is_expired():
77
- return auth_result
96
+ if user_id:
97
+ # Try to retrieve from token storage
98
+ auth_result = await self._token_storage.retrieve(user_id)
99
+
100
+ if auth_result:
101
+ if not auth_result.is_expired():
102
+ return auth_result
78
103
 
79
- refreshed_auth_result = await self._attempt_token_refresh(user_id, auth_result)
80
- if refreshed_auth_result:
81
- return refreshed_auth_result
104
+ refreshed_auth_result = await self._attempt_token_refresh(user_id, auth_result)
105
+ if refreshed_auth_result:
106
+ return refreshed_auth_result
107
+
108
+ # Try getting callback from the context if that's not set, use the default callback
109
+ try:
110
+ auth_callback = Context.get().user_auth_callback
111
+ except RuntimeError:
112
+ auth_callback = self._auth_callback
82
113
 
83
- auth_callback = self._context.user_auth_callback
84
114
  if not auth_callback:
85
115
  raise RuntimeError("Authentication callback not set on Context.")
86
116
 
@@ -89,19 +119,22 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
89
119
  except Exception as e:
90
120
  raise RuntimeError(f"Authentication callback failed: {e}") from e
91
121
 
92
- auth_header = authenticated_context.headers.get("Authorization", "")
122
+ headers = authenticated_context.headers or {}
123
+ auth_header = headers.get("Authorization", "")
93
124
  if not auth_header.startswith("Bearer "):
94
125
  raise RuntimeError("Invalid Authorization header")
95
126
 
96
127
  token = auth_header.split(" ")[1]
97
128
 
129
+ # Safely access metadata
130
+ metadata = authenticated_context.metadata or {}
98
131
  auth_result = AuthResult(
99
132
  credentials=[BearerTokenCred(token=SecretStr(token))],
100
- token_expires_at=authenticated_context.metadata.get("expires_at"),
101
- raw=authenticated_context.metadata.get("raw_token"),
133
+ token_expires_at=metadata.get("expires_at"),
134
+ raw=metadata.get("raw_token") or {},
102
135
  )
103
136
 
104
137
  if user_id:
105
- self._authenticated_tokens[user_id] = auth_result
138
+ await self._token_storage.store(user_id, auth_result)
106
139
 
107
140
  return auth_result
@@ -0,0 +1,124 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from urllib.parse import urlparse
17
+
18
+ from pydantic import Field
19
+ from pydantic import field_validator
20
+ from pydantic import model_validator
21
+
22
+ from nat.data_models.authentication import AuthProviderBaseConfig
23
+
24
+
25
+ class OAuth2ResourceServerConfig(AuthProviderBaseConfig, name="oauth2_resource_server"):
26
+ """OAuth 2.0 Resource Server authentication configuration.
27
+
28
+ Supports:
29
+ • JWT access tokens via JWKS / OIDC Discovery / issuer fallback
30
+ • Opaque access tokens via RFC 7662 introspection
31
+ """
32
+
33
+ issuer_url: str = Field(
34
+ description=("The unique issuer identifier for an authorization server. "
35
+ "Required for validation and used to derive the default JWKS URI "
36
+ "(<issuer_url>/.well-known/jwks.json) if `jwks_uri` and `discovery_url` are not provided."), )
37
+ scopes: list[str] = Field(
38
+ default_factory=list,
39
+ description="Scopes required by this API. Validation ensures the token grants all listed scopes.",
40
+ )
41
+ audience: str | None = Field(
42
+ default=None,
43
+ description=(
44
+ "Expected audience (`aud`) claim for this API. If set, validation will reject tokens without this audience."
45
+ ),
46
+ )
47
+
48
+ # JWT verification params
49
+ jwks_uri: str | None = Field(
50
+ default=None,
51
+ description=("Direct JWKS endpoint URI for JWT signature verification. "
52
+ "Optional if discovery or issuer is provided."),
53
+ )
54
+ discovery_url: str | None = Field(
55
+ default=None,
56
+ description=("OIDC discovery metadata URL. Used to automatically resolve JWKS and introspection endpoints."),
57
+ )
58
+
59
+ # Opaque token (introspection) params
60
+ introspection_endpoint: str | None = Field(
61
+ default=None,
62
+ description=("RFC 7662 token introspection endpoint. "
63
+ "Required for opaque token validation and must be used with `client_id` and `client_secret`."),
64
+ )
65
+ client_id: str | None = Field(
66
+ default=None,
67
+ description="OAuth2 client ID for authenticating to the introspection endpoint (opaque token validation).",
68
+ )
69
+ client_secret: str | None = Field(
70
+ default=None,
71
+ description="OAuth2 client secret for authenticating to the introspection endpoint (opaque token validation).",
72
+ )
73
+
74
+ @staticmethod
75
+ def _is_https_or_localhost(url: str) -> bool:
76
+ try:
77
+ value = urlparse(url)
78
+ if not value.scheme or not value.netloc:
79
+ return False
80
+ if value.scheme == "https":
81
+ return True
82
+ return value.scheme == "http" and (value.hostname in {"localhost", "127.0.0.1", "::1"})
83
+ except Exception:
84
+ return False
85
+
86
+ @field_validator("issuer_url", "jwks_uri", "discovery_url", "introspection_endpoint")
87
+ @classmethod
88
+ def _require_valid_url(cls, value: str | None, info):
89
+ if value is None:
90
+ return value
91
+ if not cls._is_https_or_localhost(value):
92
+ raise ValueError(f"{info.field_name} must be HTTPS (http allowed only for localhost). Got: {value}")
93
+ return value
94
+
95
+ # ---------- Cross-field validation: ensure at least one viable path ----------
96
+
97
+ @model_validator(mode="after")
98
+ def _ensure_verification_path(self):
99
+ """
100
+ JWT path viable if any of: jwks_uri OR discovery_url OR issuer_url (fallback JWKS).
101
+ Opaque path viable if: introspection_endpoint AND client_id AND client_secret.
102
+ """
103
+ has_jwt_path = bool(self.jwks_uri or self.discovery_url or self.issuer_url)
104
+ has_opaque_path = bool(self.introspection_endpoint and self.client_id and self.client_secret)
105
+
106
+ # If introspection endpoint is set, enforce creds are present
107
+ if self.introspection_endpoint:
108
+ missing = []
109
+ if not self.client_id:
110
+ missing.append("client_id")
111
+ if not self.client_secret:
112
+ missing.append("client_secret")
113
+ if missing:
114
+ raise ValueError(
115
+ f"introspection_endpoint configured but missing required credentials: {', '.join(missing)}")
116
+
117
+ # Require at least one path
118
+ if not (has_jwt_path or has_opaque_path):
119
+ raise ValueError("Invalid configuration: no verification method available. "
120
+ "Configure one of the following:\n"
121
+ " • JWT path: set jwks_uri OR discovery_url OR issuer_url (for JWKS fallback)\n"
122
+ " • Opaque path: set introspection_endpoint + client_id + client_secret")
123
+
124
+ return self
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
 
19
18
  from nat.authentication.api_key import register as register_api_key
nat/builder/builder.py CHANGED
@@ -24,9 +24,11 @@ from nat.authentication.interfaces import AuthProviderBase
24
24
  from nat.builder.context import Context
25
25
  from nat.builder.framework_enum import LLMFrameworkEnum
26
26
  from nat.builder.function import Function
27
+ from nat.builder.function import FunctionGroup
27
28
  from nat.data_models.authentication import AuthProviderBaseConfig
28
29
  from nat.data_models.component_ref import AuthenticationRef
29
30
  from nat.data_models.component_ref import EmbedderRef
31
+ from nat.data_models.component_ref import FunctionGroupRef
30
32
  from nat.data_models.component_ref import FunctionRef
31
33
  from nat.data_models.component_ref import LLMRef
32
34
  from nat.data_models.component_ref import MemoryRef
@@ -36,20 +38,25 @@ from nat.data_models.component_ref import TTCStrategyRef
36
38
  from nat.data_models.embedder import EmbedderBaseConfig
37
39
  from nat.data_models.evaluator import EvaluatorBaseConfig
38
40
  from nat.data_models.function import FunctionBaseConfig
41
+ from nat.data_models.function import FunctionGroupBaseConfig
39
42
  from nat.data_models.function_dependencies import FunctionDependencies
40
43
  from nat.data_models.llm import LLMBaseConfig
41
44
  from nat.data_models.memory import MemoryBaseConfig
42
45
  from nat.data_models.object_store import ObjectStoreBaseConfig
43
46
  from nat.data_models.retriever import RetrieverBaseConfig
44
47
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
48
+ from nat.experimental.decorators.experimental_warning_decorator import experimental
45
49
  from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
46
50
  from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
47
51
  from nat.memory.interfaces import MemoryEditor
48
52
  from nat.object_store.interfaces import ObjectStore
49
53
  from nat.retriever.interface import Retriever
50
54
 
55
+ if typing.TYPE_CHECKING:
56
+ from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
51
57
 
52
- class UserManagerHolder():
58
+
59
+ class UserManagerHolder:
53
60
 
54
61
  def __init__(self, context: Context) -> None:
55
62
  self._context = context
@@ -58,24 +65,40 @@ class UserManagerHolder():
58
65
  return self._context.user_manager.get_id()
59
66
 
60
67
 
61
- class Builder(ABC): # pylint: disable=too-many-public-methods
68
+ class Builder(ABC):
62
69
 
63
70
  @abstractmethod
64
71
  async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
65
72
  pass
66
73
 
67
74
  @abstractmethod
68
- def get_function(self, name: str | FunctionRef) -> Function:
75
+ async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
76
+ pass
77
+
78
+ @abstractmethod
79
+ async def get_function(self, name: str | FunctionRef) -> Function:
69
80
  pass
70
81
 
71
- def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
82
+ @abstractmethod
83
+ async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
84
+ pass
72
85
 
73
- return [self.get_function(name) for name in function_names]
86
+ async def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
87
+ tasks = [self.get_function(name) for name in function_names]
88
+ return list(await asyncio.gather(*tasks, return_exceptions=False))
89
+
90
+ async def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
91
+ tasks = [self.get_function_group(name) for name in function_group_names]
92
+ return list(await asyncio.gather(*tasks, return_exceptions=False))
74
93
 
75
94
  @abstractmethod
76
95
  def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
77
96
  pass
78
97
 
98
+ @abstractmethod
99
+ def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
100
+ pass
101
+
79
102
  @abstractmethod
80
103
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
81
104
  pass
@@ -88,17 +111,18 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
88
111
  def get_workflow_config(self) -> FunctionBaseConfig:
89
112
  pass
90
113
 
91
- def get_tools(self, tool_names: Sequence[str | FunctionRef],
92
- wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
93
-
94
- return [self.get_tool(fn_name=n, wrapper_type=wrapper_type) for n in tool_names]
114
+ @abstractmethod
115
+ async def get_tools(self,
116
+ tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
117
+ wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
118
+ pass
95
119
 
96
120
  @abstractmethod
97
- def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
121
+ async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
98
122
  pass
99
123
 
100
124
  @abstractmethod
101
- async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
125
+ async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> typing.Any:
102
126
  pass
103
127
 
104
128
  @abstractmethod
@@ -119,7 +143,9 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
119
143
  pass
120
144
 
121
145
  @abstractmethod
122
- async def add_auth_provider(self, name: str | AuthenticationRef, config: AuthProviderBaseConfig):
146
+ @experimental(feature_name="Authentication")
147
+ async def add_auth_provider(self, name: str | AuthenticationRef,
148
+ config: AuthProviderBaseConfig) -> AuthProviderBase:
123
149
  pass
124
150
 
125
151
  @abstractmethod
@@ -135,7 +161,7 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
135
161
  return list(auth_providers)
136
162
 
137
163
  @abstractmethod
138
- async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig):
164
+ async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig) -> ObjectStore:
139
165
  pass
140
166
 
141
167
  async def get_object_store_clients(self, object_store_names: Sequence[str | ObjectStoreRef]) -> list[ObjectStore]:
@@ -153,7 +179,7 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
153
179
  pass
154
180
 
155
181
  @abstractmethod
156
- async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
182
+ async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
157
183
  pass
158
184
 
159
185
  async def get_embedders(self, embedder_names: Sequence[str | EmbedderRef],
@@ -174,17 +200,18 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
174
200
  pass
175
201
 
176
202
  @abstractmethod
177
- async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig):
203
+ async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig) -> MemoryEditor:
178
204
  pass
179
205
 
180
- def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
206
+ async def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
181
207
  """
182
208
  Return a list of memory clients for the specified names.
183
209
  """
184
- return [self.get_memory_client(n) for n in memory_names]
210
+ tasks = [self.get_memory_client(n) for n in memory_names]
211
+ return list(await asyncio.gather(*tasks, return_exceptions=False))
185
212
 
186
213
  @abstractmethod
187
- def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
214
+ async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
188
215
  """
189
216
  Return the instantiated memory client for the given name.
190
217
  """
@@ -195,12 +222,12 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
195
222
  pass
196
223
 
197
224
  @abstractmethod
198
- async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
225
+ async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
199
226
  pass
200
227
 
201
228
  async def get_retrievers(self,
202
229
  retriever_names: Sequence[str | RetrieverRef],
203
- wrapper_type: LLMFrameworkEnum | str | None = None):
230
+ wrapper_type: LLMFrameworkEnum | str | None = None) -> list[Retriever]:
204
231
 
205
232
  tasks = [self.get_retriever(n, wrapper_type=wrapper_type) for n in retriever_names]
206
233
 
@@ -232,14 +259,15 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
232
259
  pass
233
260
 
234
261
  @abstractmethod
235
- async def add_ttc_strategy(self, name: str | str, config: TTCStrategyBaseConfig):
262
+ @experimental(feature_name="TTC")
263
+ async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig):
236
264
  pass
237
265
 
238
266
  @abstractmethod
239
267
  async def get_ttc_strategy(self,
240
268
  strategy_name: str | TTCStrategyRef,
241
269
  pipeline_type: PipelineTypeEnum,
242
- stage_type: StageTypeEnum):
270
+ stage_type: StageTypeEnum) -> "StrategyBase":
243
271
  pass
244
272
 
245
273
  @abstractmethod
@@ -257,8 +285,12 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
257
285
  def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
258
286
  pass
259
287
 
288
+ @abstractmethod
289
+ def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
290
+ pass
291
+
260
292
 
261
- class EvalBuilder(Builder):
293
+ class EvalBuilder(ABC):
262
294
 
263
295
  @abstractmethod
264
296
  async def add_evaluator(self, name: str, config: EvaluatorBaseConfig):
@@ -281,5 +313,5 @@ class EvalBuilder(Builder):
281
313
  pass
282
314
 
283
315
  @abstractmethod
284
- def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
316
+ async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
285
317
  pass
@@ -30,6 +30,7 @@ from nat.data_models.component_ref import generate_instance_id
30
30
  from nat.data_models.config import Config
31
31
  from nat.data_models.embedder import EmbedderBaseConfig
32
32
  from nat.data_models.function import FunctionBaseConfig
33
+ from nat.data_models.function import FunctionGroupBaseConfig
33
34
  from nat.data_models.llm import LLMBaseConfig
34
35
  from nat.data_models.memory import MemoryBaseConfig
35
36
  from nat.data_models.object_store import ObjectStoreBaseConfig
@@ -48,6 +49,7 @@ _component_group_order = [
48
49
  ComponentGroup.OBJECT_STORES,
49
50
  ComponentGroup.RETRIEVERS,
50
51
  ComponentGroup.TTC_STRATEGIES,
52
+ ComponentGroup.FUNCTION_GROUPS,
51
53
  ComponentGroup.FUNCTIONS,
52
54
  ]
53
55
 
@@ -107,6 +109,8 @@ def group_from_component(component: TypedBaseModel) -> ComponentGroup | None:
107
109
  return ComponentGroup.EMBEDDERS
108
110
  if (isinstance(component, FunctionBaseConfig)):
109
111
  return ComponentGroup.FUNCTIONS
112
+ if (isinstance(component, FunctionGroupBaseConfig)):
113
+ return ComponentGroup.FUNCTION_GROUPS
110
114
  if (isinstance(component, LLMBaseConfig)):
111
115
  return ComponentGroup.LLMS
112
116
  if (isinstance(component, MemoryBaseConfig)):
@@ -154,7 +158,7 @@ def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
154
158
  yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
155
159
  if (decomposed_type.is_union):
156
160
  for arg in decomposed_type.args:
157
- if arg is typing.Any or (isinstance(value, DecomposedType(arg).root)):
161
+ if arg is typing.Any or DecomposedType(arg).is_instance(value):
158
162
  yield from recursive_componentref_discovery(cls, value, arg)
159
163
  else:
160
164
  for arg in decomposed_type.args:
@@ -174,7 +178,7 @@ def update_dependency_graph(config: "Config", instance_config: TypedBaseModel,
174
178
  nx.DiGraph: An dependency graph that has been updated with the provided runtime instance.
175
179
  """
176
180
 
177
- for field_name, field_info in instance_config.model_fields.items():
181
+ for field_name, field_info in type(instance_config).model_fields.items():
178
182
 
179
183
  for instance_id, value_node in recursive_componentref_discovery(
180
184
  instance_config,
@@ -254,9 +258,9 @@ def build_dependency_sequence(config: "Config") -> list[ComponentInstanceData]:
254
258
  runtime instance references.
255
259
  """
256
260
 
257
- total_node_count = len(config.embedders) + len(config.functions) + len(config.llms) + len(config.memory) + len(
258
- config.object_stores) + len(config.retrievers) + len(config.ttc_strategies) + len(
259
- config.authentication) + 1 # +1 for the workflow
261
+ total_node_count = (len(config.embedders) + len(config.functions) + len(config.function_groups) + len(config.llms) +
262
+ len(config.memory) + len(config.object_stores) + len(config.retrievers) +
263
+ len(config.ttc_strategies) + len(config.authentication) + 1) # +1 for the workflow
260
264
 
261
265
  dependency_map: dict
262
266
  dependency_graph: nx.DiGraph