nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242) hide show
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +41 -21
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +46 -26
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +46 -11
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  53. nat/cli/commands/workflow/workflow_commands.py +9 -13
  54. nat/cli/entrypoint.py +8 -10
  55. nat/cli/register_workflow.py +38 -4
  56. nat/cli/type_registry.py +75 -6
  57. nat/control_flow/__init__.py +0 -0
  58. nat/control_flow/register.py +20 -0
  59. nat/control_flow/router_agent/__init__.py +0 -0
  60. nat/control_flow/router_agent/agent.py +329 -0
  61. nat/control_flow/router_agent/prompt.py +48 -0
  62. nat/control_flow/router_agent/register.py +91 -0
  63. nat/control_flow/sequential_executor.py +166 -0
  64. nat/data_models/agent.py +34 -0
  65. nat/data_models/api_server.py +10 -10
  66. nat/data_models/authentication.py +23 -9
  67. nat/data_models/common.py +1 -1
  68. nat/data_models/component.py +2 -0
  69. nat/data_models/component_ref.py +11 -0
  70. nat/data_models/config.py +41 -17
  71. nat/data_models/dataset_handler.py +1 -1
  72. nat/data_models/discovery_metadata.py +4 -4
  73. nat/data_models/evaluate.py +4 -1
  74. nat/data_models/function.py +34 -0
  75. nat/data_models/function_dependencies.py +14 -6
  76. nat/data_models/gated_field_mixin.py +242 -0
  77. nat/data_models/intermediate_step.py +3 -3
  78. nat/data_models/optimizable.py +119 -0
  79. nat/data_models/optimizer.py +149 -0
  80. nat/data_models/swe_bench_model.py +1 -1
  81. nat/data_models/temperature_mixin.py +44 -0
  82. nat/data_models/thinking_mixin.py +86 -0
  83. nat/data_models/top_p_mixin.py +44 -0
  84. nat/embedder/nim_embedder.py +1 -1
  85. nat/embedder/openai_embedder.py +1 -1
  86. nat/embedder/register.py +0 -1
  87. nat/eval/config.py +3 -1
  88. nat/eval/dataset_handler/dataset_handler.py +71 -7
  89. nat/eval/evaluate.py +86 -31
  90. nat/eval/evaluator/base_evaluator.py +1 -1
  91. nat/eval/evaluator/evaluator_model.py +13 -0
  92. nat/eval/intermediate_step_adapter.py +1 -1
  93. nat/eval/rag_evaluator/evaluate.py +2 -2
  94. nat/eval/rag_evaluator/register.py +3 -3
  95. nat/eval/register.py +4 -1
  96. nat/eval/remote_workflow.py +3 -3
  97. nat/eval/runtime_evaluator/__init__.py +14 -0
  98. nat/eval/runtime_evaluator/evaluate.py +123 -0
  99. nat/eval/runtime_evaluator/register.py +100 -0
  100. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  101. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  102. nat/eval/trajectory_evaluator/register.py +1 -1
  103. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  104. nat/eval/utils/eval_trace_ctx.py +89 -0
  105. nat/eval/utils/weave_eval.py +18 -9
  106. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  107. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  108. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  109. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  110. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  111. nat/experimental/test_time_compute/register.py +0 -1
  112. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  113. nat/front_ends/console/authentication_flow_handler.py +82 -30
  114. nat/front_ends/console/console_front_end_plugin.py +8 -5
  115. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  116. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  117. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  118. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  119. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  120. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
  121. nat/front_ends/fastapi/job_store.py +518 -99
  122. nat/front_ends/fastapi/main.py +11 -19
  123. nat/front_ends/fastapi/message_handler.py +13 -14
  124. nat/front_ends/fastapi/message_validator.py +17 -19
  125. nat/front_ends/fastapi/response_helpers.py +4 -4
  126. nat/front_ends/fastapi/step_adaptor.py +2 -2
  127. nat/front_ends/fastapi/utils.py +57 -0
  128. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  129. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  130. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  131. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  132. nat/front_ends/mcp/tool_converter.py +44 -14
  133. nat/front_ends/register.py +0 -1
  134. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  135. nat/llm/aws_bedrock_llm.py +24 -12
  136. nat/llm/azure_openai_llm.py +13 -6
  137. nat/llm/litellm_llm.py +69 -0
  138. nat/llm/nim_llm.py +20 -8
  139. nat/llm/openai_llm.py +14 -6
  140. nat/llm/register.py +4 -1
  141. nat/llm/utils/env_config_value.py +2 -3
  142. nat/llm/utils/thinking.py +215 -0
  143. nat/meta/pypi.md +9 -9
  144. nat/object_store/register.py +0 -1
  145. nat/observability/exporter/base_exporter.py +3 -3
  146. nat/observability/exporter/file_exporter.py +1 -1
  147. nat/observability/exporter/processing_exporter.py +309 -81
  148. nat/observability/exporter/span_exporter.py +1 -1
  149. nat/observability/exporter_manager.py +7 -7
  150. nat/observability/mixin/file_mixin.py +7 -7
  151. nat/observability/mixin/redaction_config_mixin.py +42 -0
  152. nat/observability/mixin/tagging_config_mixin.py +62 -0
  153. nat/observability/mixin/type_introspection_mixin.py +420 -107
  154. nat/observability/processor/batching_processor.py +5 -7
  155. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  156. nat/observability/processor/processor.py +3 -0
  157. nat/observability/processor/processor_factory.py +70 -0
  158. nat/observability/processor/redaction/__init__.py +24 -0
  159. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  160. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  161. nat/observability/processor/redaction/redaction_processor.py +177 -0
  162. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  163. nat/observability/processor/span_tagging_processor.py +68 -0
  164. nat/observability/register.py +6 -4
  165. nat/profiler/calc/calc_runner.py +3 -4
  166. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  167. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  168. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  169. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  170. nat/profiler/data_frame_row.py +1 -1
  171. nat/profiler/decorators/framework_wrapper.py +62 -13
  172. nat/profiler/decorators/function_tracking.py +160 -3
  173. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  174. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  175. nat/profiler/inference_optimization/data_models.py +3 -3
  176. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  177. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  178. nat/profiler/parameter_optimization/__init__.py +0 -0
  179. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  180. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  181. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  182. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  183. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  184. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  185. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  186. nat/profiler/profile_runner.py +14 -9
  187. nat/profiler/utils.py +4 -2
  188. nat/registry_handlers/local/local_handler.py +2 -2
  189. nat/registry_handlers/package_utils.py +1 -2
  190. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  191. nat/registry_handlers/register.py +3 -4
  192. nat/registry_handlers/rest/rest_handler.py +12 -13
  193. nat/retriever/milvus/retriever.py +2 -2
  194. nat/retriever/nemo_retriever/retriever.py +1 -1
  195. nat/retriever/register.py +0 -1
  196. nat/runtime/loader.py +2 -2
  197. nat/runtime/runner.py +3 -2
  198. nat/runtime/session.py +43 -8
  199. nat/settings/global_settings.py +16 -5
  200. nat/tool/chat_completion.py +5 -2
  201. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  202. nat/tool/datetime_tools.py +49 -9
  203. nat/tool/document_search.py +2 -2
  204. nat/tool/github_tools.py +450 -0
  205. nat/tool/nvidia_rag.py +1 -1
  206. nat/tool/register.py +2 -9
  207. nat/tool/retriever.py +3 -2
  208. nat/utils/callable_utils.py +70 -0
  209. nat/utils/data_models/schema_validator.py +3 -3
  210. nat/utils/exception_handlers/automatic_retries.py +104 -51
  211. nat/utils/exception_handlers/schemas.py +1 -1
  212. nat/utils/io/yaml_tools.py +2 -2
  213. nat/utils/log_levels.py +25 -0
  214. nat/utils/reactive/base/observable_base.py +2 -2
  215. nat/utils/reactive/base/observer_base.py +1 -1
  216. nat/utils/reactive/observable.py +2 -2
  217. nat/utils/reactive/observer.py +4 -4
  218. nat/utils/reactive/subscription.py +1 -1
  219. nat/utils/settings/global_settings.py +6 -8
  220. nat/utils/type_converter.py +4 -3
  221. nat/utils/type_utils.py +9 -5
  222. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
  223. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
  224. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
  225. nat/cli/commands/info/list_mcp.py +0 -304
  226. nat/tool/github_tools/create_github_commit.py +0 -133
  227. nat/tool/github_tools/create_github_issue.py +0 -87
  228. nat/tool/github_tools/create_github_pr.py +0 -106
  229. nat/tool/github_tools/get_github_file.py +0 -106
  230. nat/tool/github_tools/get_github_issue.py +0 -166
  231. nat/tool/github_tools/get_github_pr.py +0 -256
  232. nat/tool/github_tools/update_github_issue.py +0 -100
  233. nat/tool/mcp/exceptions.py +0 -142
  234. nat/tool/mcp/mcp_client.py +0 -255
  235. nat/tool/mcp/mcp_tool.py +0 -96
  236. nat/utils/exception_handlers/mcp.py +0 -211
  237. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  238. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  239. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
  240. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  241. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
  242. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
@@ -13,56 +13,71 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import logging
17
+ from collections.abc import Callable
18
+ from datetime import UTC
16
19
  from datetime import datetime
17
- from datetime import timezone
18
20
 
21
+ import httpx
19
22
  from authlib.integrations.httpx_client import OAuth2Client as AuthlibOAuth2Client
20
23
  from pydantic import SecretStr
21
24
 
22
25
  from nat.authentication.interfaces import AuthProviderBase
23
26
  from nat.authentication.oauth2.oauth2_auth_code_flow_provider_config import OAuth2AuthCodeFlowProviderConfig
24
27
  from nat.builder.context import Context
28
+ from nat.data_models.authentication import AuthenticatedContext
25
29
  from nat.data_models.authentication import AuthFlowType
26
30
  from nat.data_models.authentication import AuthResult
27
31
  from nat.data_models.authentication import BearerTokenCred
28
32
 
33
+ logger = logging.getLogger(__name__)
34
+
29
35
 
30
36
  class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]):
31
37
 
32
38
  def __init__(self, config: OAuth2AuthCodeFlowProviderConfig):
33
39
  super().__init__(config)
34
40
  self._authenticated_tokens: dict[str, AuthResult] = {}
35
- self._context = Context.get()
41
+ self._auth_callback = None
36
42
 
37
43
  async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> AuthResult | None:
38
44
  refresh_token = auth_result.raw.get("refresh_token")
39
45
  if not isinstance(refresh_token, str):
40
46
  return None
41
47
 
42
- with AuthlibOAuth2Client(
43
- client_id=self.config.client_id,
44
- client_secret=self.config.client_secret,
45
- ) as client:
46
- try:
48
+ try:
49
+ with AuthlibOAuth2Client(
50
+ client_id=self.config.client_id,
51
+ client_secret=self.config.client_secret,
52
+ ) as client:
47
53
  new_token_data = client.refresh_token(self.config.token_url, refresh_token=refresh_token)
48
- except Exception:
49
- # On any failure, we'll fall back to the full auth flow.
50
- return None
51
54
 
52
- expires_at_ts = new_token_data.get("expires_at")
53
- new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=timezone.utc) if expires_at_ts else None
55
+ expires_at_ts = new_token_data.get("expires_at")
56
+ new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=UTC) if expires_at_ts else None
54
57
 
55
- new_auth_result = AuthResult(
56
- credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
57
- token_expires_at=new_expires_at,
58
- raw=new_token_data,
59
- )
58
+ new_auth_result = AuthResult(
59
+ credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
60
+ token_expires_at=new_expires_at,
61
+ raw=new_token_data,
62
+ )
60
63
 
61
- self._authenticated_tokens[user_id] = new_auth_result
64
+ self._authenticated_tokens[user_id] = new_auth_result
65
+ except httpx.HTTPStatusError:
66
+ return None
67
+ except httpx.RequestError:
68
+ return None
69
+ except Exception:
70
+ # On any other failure, we'll fall back to the full auth flow.
71
+ return None
62
72
 
63
73
  return new_auth_result
64
74
 
65
- async def authenticate(self, user_id: str | None = None) -> AuthResult:
75
+ def _set_custom_auth_callback(self,
76
+ auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
77
+ AuthenticatedContext]):
78
+ self._auth_callback = auth_callback
79
+
80
+ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
66
81
  if user_id is None and hasattr(Context.get(), "metadata") and hasattr(
67
82
  Context.get().metadata, "cookies") and Context.get().metadata.cookies is not None:
68
83
  session_id = Context.get().metadata.cookies.get("nat-session", None)
@@ -80,7 +95,12 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
80
95
  if refreshed_auth_result:
81
96
  return refreshed_auth_result
82
97
 
83
- auth_callback = self._context.user_auth_callback
98
+ # Try getting callback from the context if that's not set, use the default callback
99
+ try:
100
+ auth_callback = Context.get().user_auth_callback
101
+ except RuntimeError:
102
+ auth_callback = self._auth_callback
103
+
84
104
  if not auth_callback:
85
105
  raise RuntimeError("Authentication callback not set on Context.")
86
106
 
@@ -0,0 +1,124 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from urllib.parse import urlparse
17
+
18
+ from pydantic import Field
19
+ from pydantic import field_validator
20
+ from pydantic import model_validator
21
+
22
+ from nat.data_models.authentication import AuthProviderBaseConfig
23
+
24
+
25
+ class OAuth2ResourceServerConfig(AuthProviderBaseConfig, name="oauth2_resource_server"):
26
+ """OAuth 2.0 Resource Server authentication configuration.
27
+
28
+ Supports:
29
+ • JWT access tokens via JWKS / OIDC Discovery / issuer fallback
30
+ • Opaque access tokens via RFC 7662 introspection
31
+ """
32
+
33
+ issuer_url: str = Field(
34
+ description=("The unique issuer identifier for an authorization server. "
35
+ "Required for validation and used to derive the default JWKS URI "
36
+ "(<issuer_url>/.well-known/jwks.json) if `jwks_uri` and `discovery_url` are not provided."), )
37
+ scopes: list[str] = Field(
38
+ default_factory=list,
39
+ description="Scopes required by this API. Validation ensures the token grants all listed scopes.",
40
+ )
41
+ audience: str | None = Field(
42
+ default=None,
43
+ description=(
44
+ "Expected audience (`aud`) claim for this API. If set, validation will reject tokens without this audience."
45
+ ),
46
+ )
47
+
48
+ # JWT verification params
49
+ jwks_uri: str | None = Field(
50
+ default=None,
51
+ description=("Direct JWKS endpoint URI for JWT signature verification. "
52
+ "Optional if discovery or issuer is provided."),
53
+ )
54
+ discovery_url: str | None = Field(
55
+ default=None,
56
+ description=("OIDC discovery metadata URL. Used to automatically resolve JWKS and introspection endpoints."),
57
+ )
58
+
59
+ # Opaque token (introspection) params
60
+ introspection_endpoint: str | None = Field(
61
+ default=None,
62
+ description=("RFC 7662 token introspection endpoint. "
63
+ "Required for opaque token validation and must be used with `client_id` and `client_secret`."),
64
+ )
65
+ client_id: str | None = Field(
66
+ default=None,
67
+ description="OAuth2 client ID for authenticating to the introspection endpoint (opaque token validation).",
68
+ )
69
+ client_secret: str | None = Field(
70
+ default=None,
71
+ description="OAuth2 client secret for authenticating to the introspection endpoint (opaque token validation).",
72
+ )
73
+
74
+ @staticmethod
75
+ def _is_https_or_localhost(url: str) -> bool:
76
+ try:
77
+ value = urlparse(url)
78
+ if not value.scheme or not value.netloc:
79
+ return False
80
+ if value.scheme == "https":
81
+ return True
82
+ return value.scheme == "http" and (value.hostname in {"localhost", "127.0.0.1", "::1"})
83
+ except Exception:
84
+ return False
85
+
86
+ @field_validator("issuer_url", "jwks_uri", "discovery_url", "introspection_endpoint")
87
+ @classmethod
88
+ def _require_valid_url(cls, value: str | None, info):
89
+ if value is None:
90
+ return value
91
+ if not cls._is_https_or_localhost(value):
92
+ raise ValueError(f"{info.field_name} must be HTTPS (http allowed only for localhost). Got: {value}")
93
+ return value
94
+
95
+ # ---------- Cross-field validation: ensure at least one viable path ----------
96
+
97
+ @model_validator(mode="after")
98
+ def _ensure_verification_path(self):
99
+ """
100
+ JWT path viable if any of: jwks_uri OR discovery_url OR issuer_url (fallback JWKS).
101
+ Opaque path viable if: introspection_endpoint AND client_id AND client_secret.
102
+ """
103
+ has_jwt_path = bool(self.jwks_uri or self.discovery_url or self.issuer_url)
104
+ has_opaque_path = bool(self.introspection_endpoint and self.client_id and self.client_secret)
105
+
106
+ # If introspection endpoint is set, enforce creds are present
107
+ if self.introspection_endpoint:
108
+ missing = []
109
+ if not self.client_id:
110
+ missing.append("client_id")
111
+ if not self.client_secret:
112
+ missing.append("client_secret")
113
+ if missing:
114
+ raise ValueError(
115
+ f"introspection_endpoint configured but missing required credentials: {', '.join(missing)}")
116
+
117
+ # Require at least one path
118
+ if not (has_jwt_path or has_opaque_path):
119
+ raise ValueError("Invalid configuration: no verification method available. "
120
+ "Configure one of the following:\n"
121
+ " • JWT path: set jwks_uri OR discovery_url OR issuer_url (for JWKS fallback)\n"
122
+ " • Opaque path: set introspection_endpoint + client_id + client_secret")
123
+
124
+ return self
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
 
19
18
  from nat.authentication.api_key import register as register_api_key
nat/builder/builder.py CHANGED
@@ -24,9 +24,11 @@ from nat.authentication.interfaces import AuthProviderBase
24
24
  from nat.builder.context import Context
25
25
  from nat.builder.framework_enum import LLMFrameworkEnum
26
26
  from nat.builder.function import Function
27
+ from nat.builder.function import FunctionGroup
27
28
  from nat.data_models.authentication import AuthProviderBaseConfig
28
29
  from nat.data_models.component_ref import AuthenticationRef
29
30
  from nat.data_models.component_ref import EmbedderRef
31
+ from nat.data_models.component_ref import FunctionGroupRef
30
32
  from nat.data_models.component_ref import FunctionRef
31
33
  from nat.data_models.component_ref import LLMRef
32
34
  from nat.data_models.component_ref import MemoryRef
@@ -36,20 +38,25 @@ from nat.data_models.component_ref import TTCStrategyRef
36
38
  from nat.data_models.embedder import EmbedderBaseConfig
37
39
  from nat.data_models.evaluator import EvaluatorBaseConfig
38
40
  from nat.data_models.function import FunctionBaseConfig
41
+ from nat.data_models.function import FunctionGroupBaseConfig
39
42
  from nat.data_models.function_dependencies import FunctionDependencies
40
43
  from nat.data_models.llm import LLMBaseConfig
41
44
  from nat.data_models.memory import MemoryBaseConfig
42
45
  from nat.data_models.object_store import ObjectStoreBaseConfig
43
46
  from nat.data_models.retriever import RetrieverBaseConfig
44
47
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
48
+ from nat.experimental.decorators.experimental_warning_decorator import experimental
45
49
  from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
46
50
  from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
47
51
  from nat.memory.interfaces import MemoryEditor
48
52
  from nat.object_store.interfaces import ObjectStore
49
53
  from nat.retriever.interface import Retriever
50
54
 
55
+ if typing.TYPE_CHECKING:
56
+ from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
51
57
 
52
- class UserManagerHolder():
58
+
59
+ class UserManagerHolder:
53
60
 
54
61
  def __init__(self, context: Context) -> None:
55
62
  self._context = context
@@ -58,24 +65,40 @@ class UserManagerHolder():
58
65
  return self._context.user_manager.get_id()
59
66
 
60
67
 
61
- class Builder(ABC): # pylint: disable=too-many-public-methods
68
+ class Builder(ABC):
62
69
 
63
70
  @abstractmethod
64
71
  async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
65
72
  pass
66
73
 
67
74
  @abstractmethod
68
- def get_function(self, name: str | FunctionRef) -> Function:
75
+ async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
76
+ pass
77
+
78
+ @abstractmethod
79
+ async def get_function(self, name: str | FunctionRef) -> Function:
69
80
  pass
70
81
 
71
- def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
82
+ @abstractmethod
83
+ async def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
84
+ pass
72
85
 
73
- return [self.get_function(name) for name in function_names]
86
+ async def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
87
+ tasks = [self.get_function(name) for name in function_names]
88
+ return list(await asyncio.gather(*tasks, return_exceptions=False))
89
+
90
+ async def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
91
+ tasks = [self.get_function_group(name) for name in function_group_names]
92
+ return list(await asyncio.gather(*tasks, return_exceptions=False))
74
93
 
75
94
  @abstractmethod
76
95
  def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
77
96
  pass
78
97
 
98
+ @abstractmethod
99
+ def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
100
+ pass
101
+
79
102
  @abstractmethod
80
103
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
81
104
  pass
@@ -88,17 +111,18 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
88
111
  def get_workflow_config(self) -> FunctionBaseConfig:
89
112
  pass
90
113
 
91
- def get_tools(self, tool_names: Sequence[str | FunctionRef],
92
- wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
93
-
94
- return [self.get_tool(fn_name=n, wrapper_type=wrapper_type) for n in tool_names]
114
+ @abstractmethod
115
+ async def get_tools(self,
116
+ tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
117
+ wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
118
+ pass
95
119
 
96
120
  @abstractmethod
97
- def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
121
+ async def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
98
122
  pass
99
123
 
100
124
  @abstractmethod
101
- async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
125
+ async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig) -> typing.Any:
102
126
  pass
103
127
 
104
128
  @abstractmethod
@@ -119,7 +143,9 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
119
143
  pass
120
144
 
121
145
  @abstractmethod
122
- async def add_auth_provider(self, name: str | AuthenticationRef, config: AuthProviderBaseConfig):
146
+ @experimental(feature_name="Authentication")
147
+ async def add_auth_provider(self, name: str | AuthenticationRef,
148
+ config: AuthProviderBaseConfig) -> AuthProviderBase:
123
149
  pass
124
150
 
125
151
  @abstractmethod
@@ -135,7 +161,7 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
135
161
  return list(auth_providers)
136
162
 
137
163
  @abstractmethod
138
- async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig):
164
+ async def add_object_store(self, name: str | ObjectStoreRef, config: ObjectStoreBaseConfig) -> ObjectStore:
139
165
  pass
140
166
 
141
167
  async def get_object_store_clients(self, object_store_names: Sequence[str | ObjectStoreRef]) -> list[ObjectStore]:
@@ -153,7 +179,7 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
153
179
  pass
154
180
 
155
181
  @abstractmethod
156
- async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig):
182
+ async def add_embedder(self, name: str | EmbedderRef, config: EmbedderBaseConfig) -> None:
157
183
  pass
158
184
 
159
185
  async def get_embedders(self, embedder_names: Sequence[str | EmbedderRef],
@@ -174,17 +200,18 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
174
200
  pass
175
201
 
176
202
  @abstractmethod
177
- async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig):
203
+ async def add_memory_client(self, name: str | MemoryRef, config: MemoryBaseConfig) -> MemoryEditor:
178
204
  pass
179
205
 
180
- def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
206
+ async def get_memory_clients(self, memory_names: Sequence[str | MemoryRef]) -> list[MemoryEditor]:
181
207
  """
182
208
  Return a list of memory clients for the specified names.
183
209
  """
184
- return [self.get_memory_client(n) for n in memory_names]
210
+ tasks = [self.get_memory_client(n) for n in memory_names]
211
+ return list(await asyncio.gather(*tasks, return_exceptions=False))
185
212
 
186
213
  @abstractmethod
187
- def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
214
+ async def get_memory_client(self, memory_name: str | MemoryRef) -> MemoryEditor:
188
215
  """
189
216
  Return the instantiated memory client for the given name.
190
217
  """
@@ -195,12 +222,12 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
195
222
  pass
196
223
 
197
224
  @abstractmethod
198
- async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig):
225
+ async def add_retriever(self, name: str | RetrieverRef, config: RetrieverBaseConfig) -> None:
199
226
  pass
200
227
 
201
228
  async def get_retrievers(self,
202
229
  retriever_names: Sequence[str | RetrieverRef],
203
- wrapper_type: LLMFrameworkEnum | str | None = None):
230
+ wrapper_type: LLMFrameworkEnum | str | None = None) -> list[Retriever]:
204
231
 
205
232
  tasks = [self.get_retriever(n, wrapper_type=wrapper_type) for n in retriever_names]
206
233
 
@@ -232,14 +259,15 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
232
259
  pass
233
260
 
234
261
  @abstractmethod
235
- async def add_ttc_strategy(self, name: str | str, config: TTCStrategyBaseConfig):
262
+ @experimental(feature_name="TTC")
263
+ async def add_ttc_strategy(self, name: str | TTCStrategyRef, config: TTCStrategyBaseConfig):
236
264
  pass
237
265
 
238
266
  @abstractmethod
239
267
  async def get_ttc_strategy(self,
240
268
  strategy_name: str | TTCStrategyRef,
241
269
  pipeline_type: PipelineTypeEnum,
242
- stage_type: StageTypeEnum):
270
+ stage_type: StageTypeEnum) -> "StrategyBase":
243
271
  pass
244
272
 
245
273
  @abstractmethod
@@ -257,8 +285,12 @@ class Builder(ABC): # pylint: disable=too-many-public-methods
257
285
  def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
258
286
  pass
259
287
 
288
+ @abstractmethod
289
+ def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
290
+ pass
291
+
260
292
 
261
- class EvalBuilder(Builder):
293
+ class EvalBuilder(ABC):
262
294
 
263
295
  @abstractmethod
264
296
  async def add_evaluator(self, name: str, config: EvaluatorBaseConfig):
@@ -281,5 +313,5 @@ class EvalBuilder(Builder):
281
313
  pass
282
314
 
283
315
  @abstractmethod
284
- def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
316
+ async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
285
317
  pass
@@ -30,6 +30,7 @@ from nat.data_models.component_ref import generate_instance_id
30
30
  from nat.data_models.config import Config
31
31
  from nat.data_models.embedder import EmbedderBaseConfig
32
32
  from nat.data_models.function import FunctionBaseConfig
33
+ from nat.data_models.function import FunctionGroupBaseConfig
33
34
  from nat.data_models.llm import LLMBaseConfig
34
35
  from nat.data_models.memory import MemoryBaseConfig
35
36
  from nat.data_models.object_store import ObjectStoreBaseConfig
@@ -48,6 +49,7 @@ _component_group_order = [
48
49
  ComponentGroup.OBJECT_STORES,
49
50
  ComponentGroup.RETRIEVERS,
50
51
  ComponentGroup.TTC_STRATEGIES,
52
+ ComponentGroup.FUNCTION_GROUPS,
51
53
  ComponentGroup.FUNCTIONS,
52
54
  ]
53
55
 
@@ -107,6 +109,8 @@ def group_from_component(component: TypedBaseModel) -> ComponentGroup | None:
107
109
  return ComponentGroup.EMBEDDERS
108
110
  if (isinstance(component, FunctionBaseConfig)):
109
111
  return ComponentGroup.FUNCTIONS
112
+ if (isinstance(component, FunctionGroupBaseConfig)):
113
+ return ComponentGroup.FUNCTION_GROUPS
110
114
  if (isinstance(component, LLMBaseConfig)):
111
115
  return ComponentGroup.LLMS
112
116
  if (isinstance(component, MemoryBaseConfig)):
@@ -154,7 +158,7 @@ def recursive_componentref_discovery(cls: TypedBaseModel, value: typing.Any,
154
158
  yield from recursive_componentref_discovery(cls, field_data, field_info.annotation)
155
159
  if (decomposed_type.is_union):
156
160
  for arg in decomposed_type.args:
157
- if arg is typing.Any or (isinstance(value, DecomposedType(arg).root)):
161
+ if arg is typing.Any or DecomposedType(arg).is_instance(value):
158
162
  yield from recursive_componentref_discovery(cls, value, arg)
159
163
  else:
160
164
  for arg in decomposed_type.args:
@@ -174,7 +178,7 @@ def update_dependency_graph(config: "Config", instance_config: TypedBaseModel,
174
178
  nx.DiGraph: An dependency graph that has been updated with the provided runtime instance.
175
179
  """
176
180
 
177
- for field_name, field_info in instance_config.model_fields.items():
181
+ for field_name, field_info in type(instance_config).model_fields.items():
178
182
 
179
183
  for instance_id, value_node in recursive_componentref_discovery(
180
184
  instance_config,
@@ -254,9 +258,9 @@ def build_dependency_sequence(config: "Config") -> list[ComponentInstanceData]:
254
258
  runtime instance references.
255
259
  """
256
260
 
257
- total_node_count = len(config.embedders) + len(config.functions) + len(config.llms) + len(config.memory) + len(
258
- config.object_stores) + len(config.retrievers) + len(config.ttc_strategies) + len(
259
- config.authentication) + 1 # +1 for the workflow
261
+ total_node_count = (len(config.embedders) + len(config.functions) + len(config.function_groups) + len(config.llms) +
262
+ len(config.memory) + len(config.object_stores) + len(config.retrievers) +
263
+ len(config.ttc_strategies) + len(config.authentication) + 1) # +1 for the workflow
260
264
 
261
265
  dependency_map: dict
262
266
  dependency_graph: nx.DiGraph
nat/builder/context.py CHANGED
@@ -31,6 +31,7 @@ from nat.data_models.intermediate_step import IntermediateStep
31
31
  from nat.data_models.intermediate_step import IntermediateStepPayload
32
32
  from nat.data_models.intermediate_step import IntermediateStepType
33
33
  from nat.data_models.intermediate_step import StreamEventData
34
+ from nat.data_models.intermediate_step import TraceMetadata
34
35
  from nat.data_models.invocation_node import InvocationNode
35
36
  from nat.runtime.user_metadata import RequestAttributes
36
37
  from nat.utils.reactive.subject import Subject
@@ -38,13 +39,13 @@ from nat.utils.reactive.subject import Subject
38
39
 
39
40
  class Singleton(type):
40
41
 
41
- def __init__(cls, name, bases, dict): # pylint: disable=W0622
42
- super(Singleton, cls).__init__(name, bases, dict)
42
+ def __init__(cls, name, bases, dict):
43
+ super().__init__(name, bases, dict)
43
44
  cls.instance = None
44
45
 
45
46
  def __call__(cls, *args, **kw):
46
47
  if cls.instance is None:
47
- cls.instance = super(Singleton, cls).__call__(*args, **kw)
48
+ cls.instance = super().__call__(*args, **kw)
48
49
  return cls.instance
49
50
 
50
51
 
@@ -65,14 +66,13 @@ class ContextState(metaclass=Singleton):
65
66
 
66
67
  def __init__(self):
67
68
  self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
69
+ self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
68
70
  self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
69
71
  self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
70
- self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes())
71
- self.event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=Subject())
72
- self.active_function: ContextVar[InvocationNode] = ContextVar("active_function",
73
- default=InvocationNode(function_id="root",
74
- function_name="root"))
75
- self.active_span_id_stack: ContextVar[list[str]] = ContextVar("active_span_id_stack", default=["root"])
72
+ self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
73
+ self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None)
74
+ self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None)
75
+ self._active_span_id_stack: ContextVar[list[str] | None] = ContextVar("active_span_id_stack", default=None)
76
76
 
77
77
  # Default is a lambda no-op which returns NoneType
78
78
  self.user_input_callback: ContextVar[Callable[[InteractionPrompt], Awaitable[HumanResponse | None]]
@@ -83,6 +83,30 @@ class ContextState(metaclass=Singleton):
83
83
  Awaitable[AuthenticatedContext]]
84
84
  | None] = ContextVar("user_auth_callback", default=None)
85
85
 
86
+ @property
87
+ def metadata(self) -> ContextVar[RequestAttributes]:
88
+ if self._metadata.get() is None:
89
+ self._metadata.set(RequestAttributes())
90
+ return typing.cast(ContextVar[RequestAttributes], self._metadata)
91
+
92
+ @property
93
+ def active_function(self) -> ContextVar[InvocationNode]:
94
+ if self._active_function.get() is None:
95
+ self._active_function.set(InvocationNode(function_id="root", function_name="root"))
96
+ return typing.cast(ContextVar[InvocationNode], self._active_function)
97
+
98
+ @property
99
+ def event_stream(self) -> ContextVar[Subject[IntermediateStep]]:
100
+ if self._event_stream.get() is None:
101
+ self._event_stream.set(Subject())
102
+ return typing.cast(ContextVar[Subject[IntermediateStep]], self._event_stream)
103
+
104
+ @property
105
+ def active_span_id_stack(self) -> ContextVar[list[str]]:
106
+ if self._active_span_id_stack.get() is None:
107
+ self._active_span_id_stack.set(["root"])
108
+ return typing.cast(ContextVar[list[str]], self._active_span_id_stack)
109
+
86
110
  @staticmethod
87
111
  def get() -> "ContextState":
88
112
  return ContextState()
@@ -165,8 +189,18 @@ class Context:
165
189
  """
166
190
  return self._context_state.conversation_id.get()
167
191
 
192
+ @property
193
+ def user_message_id(self) -> str | None:
194
+ """
195
+ This property retrieves the user message ID which is the unique identifier for the current user message.
196
+ """
197
+ return self._context_state.user_message_id.get()
198
+
168
199
  @contextmanager
169
- def push_active_function(self, function_name: str, input_data: typing.Any | None):
200
+ def push_active_function(self,
201
+ function_name: str,
202
+ input_data: typing.Any | None,
203
+ metadata: dict[str, typing.Any] | TraceMetadata | None = None):
170
204
  """
171
205
  Set the 'active_function' in context, push an invocation node,
172
206
  AND create an OTel child span for that function call.
@@ -187,7 +221,8 @@ class Context:
187
221
  IntermediateStepPayload(UUID=current_function_id,
188
222
  event_type=IntermediateStepType.FUNCTION_START,
189
223
  name=function_name,
190
- data=StreamEventData(input=input_data)))
224
+ data=StreamEventData(input=input_data),
225
+ metadata=metadata))
191
226
 
192
227
  manager = ActiveFunctionContextManager()
193
228
 
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import asyncio
16
17
  import dataclasses
17
18
  import logging
18
19
  from contextlib import asynccontextmanager
@@ -61,7 +62,7 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
61
62
  # Store the evaluator
62
63
  self._evaluators[name] = ConfiguredEvaluator(config=config, instance=info_obj)
63
64
  except Exception as e:
64
- logger.error("Error %s adding evaluator `%s` with config `%s`", e, name, config, exc_info=True)
65
+ logger.error("Error %s adding evaluator `%s` with config `%s`", e, name, config)
65
66
  raise
66
67
 
67
68
  @override
@@ -90,17 +91,20 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
90
91
  return self.eval_general_config.output_dir
91
92
 
92
93
  @override
93
- def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
94
- tools = []
94
+ async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
95
95
  tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
96
- for fn_name in self._functions:
97
- fn = self.get_function(fn_name)
96
+
97
+ async def get_tool(fn_name: str):
98
+ fn = await self.get_function(fn_name)
98
99
  try:
99
- tools.append(tool_wrapper_reg.build_fn(fn_name, fn, self))
100
+ return tool_wrapper_reg.build_fn(fn_name, fn, self)
100
101
  except Exception:
101
- logger.exception("Error fetching tool `%s`", fn_name, exc_info=True)
102
+ logger.exception("Error fetching tool `%s`", fn_name)
103
+ return None
102
104
 
103
- return tools
105
+ tasks = [get_tool(fn_name) for fn_name in self._functions]
106
+ tools = await asyncio.gather(*tasks, return_exceptions=False)
107
+ return [tool for tool in tools if tool is not None]
104
108
 
105
109
  def _log_build_failure_evaluator(self,
106
110
  failing_evaluator_name: str,
@@ -127,11 +131,12 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
127
131
  remaining_components,
128
132
  original_error)
129
133
 
130
- async def populate_builder(self, config: Config):
134
+ @override
135
+ async def populate_builder(self, config: Config, skip_workflow: bool = False):
131
136
  # Skip setting workflow if workflow config is EmptyFunctionConfig
132
- skip_workflow = isinstance(config.workflow, EmptyFunctionConfig)
137
+ skip_workflow = skip_workflow or isinstance(config.workflow, EmptyFunctionConfig)
133
138
 
134
- await super().populate_builder(config, skip_workflow)
139
+ await super().populate_builder(config, skip_workflow=skip_workflow)
135
140
 
136
141
  # Initialize progress tracking for evaluators
137
142
  completed_evaluators = []