nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ from dataclasses import dataclass
22
22
  from dataclasses import field
23
23
 
24
24
  import pkce
25
+ from authlib.common.errors import AuthlibBaseError as OAuthError
25
26
  from authlib.integrations.httpx_client import AsyncOAuth2Client
26
27
 
27
28
  from nat.authentication.interfaces import FlowHandlerBase
@@ -61,14 +62,50 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
61
62
 
62
63
  raise NotImplementedError(f"Authentication method '{method}' is not supported by the websocket frontend.")
63
64
 
64
- def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig):
65
- return AsyncOAuth2Client(client_id=config.client_id,
66
- client_secret=config.client_secret,
67
- redirect_uri=config.redirect_uri,
68
- scope=" ".join(config.scopes) if config.scopes else None,
69
- token_endpoint=config.token_url,
70
- code_challenge_method='S256' if config.use_pkce else None,
71
- token_endpoint_auth_method=config.token_endpoint_auth_method)
65
+ def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:
66
+ try:
67
+ return AsyncOAuth2Client(client_id=config.client_id,
68
+ client_secret=config.client_secret,
69
+ redirect_uri=config.redirect_uri,
70
+ scope=" ".join(config.scopes) if config.scopes else None,
71
+ token_endpoint=config.token_url,
72
+ code_challenge_method='S256' if config.use_pkce else None,
73
+ token_endpoint_auth_method=config.token_endpoint_auth_method)
74
+ except (OAuthError, ValueError, TypeError) as e:
75
+ raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
76
+ except Exception as e:
77
+ raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
78
+
79
+ def _create_authorization_url(self,
80
+ client: AsyncOAuth2Client,
81
+ config: OAuth2AuthCodeFlowProviderConfig,
82
+ state: str,
83
+ verifier: str = None,
84
+ challenge: str = None) -> str:
85
+ """
86
+ Create OAuth authorization URL with proper error handling.
87
+
88
+ Args:
89
+ client: The OAuth2 client instance
90
+ config: OAuth2 configuration
91
+ state: OAuth state parameter
92
+ verifier: PKCE verifier (if using PKCE)
93
+ challenge: PKCE challenge (if using PKCE)
94
+
95
+ Returns:
96
+ The authorization URL
97
+ """
98
+ try:
99
+ authorization_url, _ = client.create_authorization_url(
100
+ config.authorization_url,
101
+ state=state,
102
+ code_verifier=verifier if config.use_pkce else None,
103
+ code_challenge=challenge if config.use_pkce else None,
104
+ **(config.authorization_kwargs or {})
105
+ )
106
+ return authorization_url
107
+ except (OAuthError, ValueError, TypeError) as e:
108
+ raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
72
109
 
73
110
  async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
74
111
 
@@ -82,21 +119,19 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
82
119
  flow_state.verifier = verifier
83
120
  flow_state.challenge = challenge
84
121
 
85
- authorization_url, _ = flow_state.client.create_authorization_url(
86
- config.authorization_url,
87
- state=state,
88
- code_verifier=flow_state.verifier if config.use_pkce else None,
89
- code_challenge=flow_state.challenge if config.use_pkce else None,
90
- **(config.authorization_kwargs or {})
91
- )
122
+ authorization_url = self._create_authorization_url(client=flow_state.client,
123
+ config=config,
124
+ state=state,
125
+ verifier=flow_state.verifier,
126
+ challenge=flow_state.challenge)
92
127
 
93
128
  await self._add_flow_cb(state, flow_state)
94
129
  await self._web_socket_message_handler.create_websocket_message(_HumanPromptOAuthConsent(text=authorization_url)
95
130
  )
96
131
  try:
97
132
  token = await asyncio.wait_for(flow_state.future, timeout=300)
98
- except asyncio.TimeoutError:
99
- raise RuntimeError("Authentication flow timed out after 5 minutes.")
133
+ except TimeoutError as exc:
134
+ raise RuntimeError("Authentication flow timed out after 5 minutes.") from exc
100
135
  finally:
101
136
 
102
137
  await self._remove_flow_cb(state)
@@ -0,0 +1,65 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import typing
17
+ from abc import ABC
18
+ from collections.abc import AsyncGenerator
19
+ from collections.abc import Generator
20
+ from contextlib import asynccontextmanager
21
+ from contextlib import contextmanager
22
+
23
+ if typing.TYPE_CHECKING:
24
+ from dask.distributed import Client
25
+
26
+
27
+ class DaskClientMixin(ABC):
28
+
29
+ @asynccontextmanager
30
+ async def client(self, address: str) -> AsyncGenerator["Client"]:
31
+ """
32
+ Async context manager for obtaining a Dask client.
33
+
34
+ Yields
35
+ ------
36
+ Client
37
+ An async Dask client connected to the scheduler. The client is automatically closed when exiting the
38
+ context manager.
39
+ """
40
+ from dask.distributed import Client
41
+ client = await Client(address=address, asynchronous=True)
42
+
43
+ try:
44
+ yield client
45
+ finally:
46
+ await client.close()
47
+
48
+ @contextmanager
49
+ def blocking_client(self, address: str) -> Generator["Client"]:
50
+ """
51
+ Context manager for obtaining a blocking Dask client.
52
+
53
+ Yields
54
+ ------
55
+ Client
56
+ A blocking Dask client connected to the scheduler. The client is automatically closed when exiting the
57
+ context manager.
58
+ """
59
+ from dask.distributed import Client
60
+ client = Client(address=address)
61
+
62
+ try:
63
+ yield client
64
+ finally:
65
+ client.close()
@@ -197,9 +197,31 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
197
197
  port: int = Field(default=8000, description="Port to bind the server to", ge=0, le=65535)
198
198
  reload: bool = Field(default=False, description="Enable auto-reload for development")
199
199
  workers: int = Field(default=1, description="Number of workers to run", ge=1)
200
- max_running_async_jobs: int = Field(default=10,
201
- description="Maximum number of async jobs to run concurrently",
202
- ge=1)
200
+ scheduler_address: str | None = Field(
201
+ default=None,
202
+ description=("Address of the Dask scheduler to use for async jobs. If None, a Dask local cluster is created. "
203
+ "Note: This requires the optional dask dependency to be installed."))
204
+ db_url: str | None = Field(
205
+ default=None,
206
+ description=
207
+ "SQLAlchemy database URL for storing async job metadata, if unset a temporary SQLite database is used.")
208
+ max_running_async_jobs: int = Field(
209
+ default=10,
210
+ description=(
211
+ "Maximum number of async jobs to run concurrently, this controls the number of dask workers created. "
212
+ "This parameter is only used when scheduler_address is `None` and a Dask local cluster is created."),
213
+ ge=1)
214
+ dask_workers: typing.Literal["threads", "processes"] = Field(
215
+ default="processes",
216
+ description=(
217
+ "Type of Dask workers to use. Options are 'threads' for Threaded Dask workers or 'processes' for "
218
+ "Process based Dask workers. This parameter is only used when scheduler_address is `None` and a local Dask "
219
+ "cluster is created."),
220
+ )
221
+ dask_log_level: str = Field(
222
+ default="WARNING",
223
+ description="Logging level for Dask.",
224
+ )
203
225
  step_adaptor: StepAdaptorConfig = StepAdaptorConfig()
204
226
 
205
227
  workflow: typing.Annotated[EndpointBase, Field(description="Endpoint for the default workflow.")] = EndpointBase(
@@ -13,21 +13,37 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import asyncio
16
17
  import logging
17
18
  import os
19
+ import sys
18
20
  import tempfile
19
21
  import typing
20
22
 
21
23
  from nat.builder.front_end import FrontEndBase
24
+ from nat.front_ends.fastapi.dask_client_mixin import DaskClientMixin
22
25
  from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
23
26
  from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
24
27
  from nat.front_ends.fastapi.main import get_app
28
+ from nat.front_ends.fastapi.utils import get_class_name
25
29
  from nat.utils.io.yaml_tools import yaml_dump
30
+ from nat.utils.log_levels import LOG_LEVELS
31
+
32
+ if (typing.TYPE_CHECKING):
33
+ from nat.data_models.config import Config
26
34
 
27
35
  logger = logging.getLogger(__name__)
28
36
 
29
37
 
30
- class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
38
+ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]):
39
+
40
+ def __init__(self, full_config: "Config"):
41
+ super().__init__(full_config)
42
+
43
+ # This attribute is set if dask is installed, and an external cluster is not used (scheduler_address is None)
44
+ self._cluster = None
45
+ self._periodic_cleanup_future = None
46
+ self._scheduler_address = None
31
47
 
32
48
  def get_worker_class(self) -> type[FastApiFrontEndPluginWorkerBase]:
33
49
  from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker
@@ -42,7 +58,45 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
42
58
 
43
59
  worker_class = self.get_worker_class()
44
60
 
45
- return f"{worker_class.__module__}.{worker_class.__qualname__}"
61
+ return get_class_name(worker_class)
62
+
63
+ @staticmethod
64
+ async def _periodic_cleanup(scheduler_address: str,
65
+ db_url: str,
66
+ sleep_time_sec: int = 300,
67
+ log_level: int = logging.INFO):
68
+ from nat.front_ends.fastapi.job_store import JobStore
69
+
70
+ job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
71
+
72
+ logging.basicConfig(level=log_level)
73
+ logger.info("Starting periodic cleanup of expired jobs every %d seconds", sleep_time_sec)
74
+ while True:
75
+ await asyncio.sleep(sleep_time_sec)
76
+
77
+ try:
78
+ await job_store.cleanup_expired_jobs()
79
+ logger.debug("Expired jobs cleaned up")
80
+ except: # noqa: E722
81
+ logger.exception("Error during job cleanup")
82
+
83
+ async def _submit_cleanup_task(self, scheduler_address: str, db_url: str, log_level: int = logging.INFO):
84
+ """Submit a cleanup task to the cluster to remove the job after expiry."""
85
+ logger.debug("Submitting periodic cleanup task to Dask cluster at %s", scheduler_address)
86
+ async with self.client(self._scheduler_address) as client:
87
+ self._periodic_cleanup_future = client.submit(self._periodic_cleanup,
88
+ scheduler_address=self._scheduler_address,
89
+ db_url=db_url,
90
+ log_level=log_level)
91
+
92
+ @staticmethod
93
+ def _setup_worker():
94
+ """
95
+ Setup function to be run in each worker process. This moves each worker into its own process group.
96
+ This fixes an issue where a Ctrl-C in the terminal sends a SIGINT to all workers, which then causes the
97
+ workers to exit before the main process can shutdown the cluster gracefully.
98
+ """
99
+ os.setsid()
46
100
 
47
101
  async def run(self):
48
102
 
@@ -52,6 +106,65 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
52
106
  # Get as dict
53
107
  config_dict = self.full_config.model_dump(mode="json", by_alias=True, round_trip=True)
54
108
 
109
+ # Three possible cases:
110
+ # 1. Dask is installed and scheduler_address is None, we create a LocalCluster
111
+ # 2. Dask is installed and scheduler_address is set, we use the existing cluster
112
+ # 3. Dask is not installed, we skip the cluster setup
113
+ dask_log_level = LOG_LEVELS.get(self.front_end_config.dask_log_level.upper(), logging.WARNING)
114
+ dask_logger = logging.getLogger("distributed")
115
+ dask_logger.setLevel(dask_log_level)
116
+
117
+ self._scheduler_address = self.front_end_config.scheduler_address
118
+ if self._scheduler_address is None:
119
+ try:
120
+
121
+ from dask.distributed import LocalCluster
122
+
123
+ use_threads = self.front_end_config.dask_workers == 'threads'
124
+
125
+ # set n_workers to max_running_async_jobs + 1 to allow for one worker to handle the cleanup task
126
+ self._cluster = LocalCluster(processes=not use_threads,
127
+ silence_logs=dask_log_level,
128
+ protocol="tcp",
129
+ n_workers=self.front_end_config.max_running_async_jobs + 1)
130
+
131
+ self._scheduler_address = self._cluster.scheduler.address
132
+
133
+ if not use_threads and sys.platform != "win32":
134
+ with self.blocking_client(self._scheduler_address) as client:
135
+ # Client.run submits a function to be run on each worker
136
+ client.run(self._setup_worker)
137
+
138
+ logger.info("Created local Dask cluster with scheduler at %s using %s workers",
139
+ self._scheduler_address,
140
+ self.front_end_config.dask_workers)
141
+
142
+ except ImportError:
143
+ logger.warning("Dask is not installed, async execution and evaluation will not be available.")
144
+
145
+ if self._scheduler_address is not None:
146
+ # If we are here then either the user provided a scheduler address, or we created a LocalCluster
147
+
148
+ from nat.front_ends.fastapi.job_store import Base
149
+ from nat.front_ends.fastapi.job_store import get_db_engine
150
+
151
+ db_engine = get_db_engine(self.front_end_config.db_url, use_async=True)
152
+ async with db_engine.begin() as conn:
153
+ await conn.run_sync(Base.metadata.create_all, checkfirst=True) # create tables if they do not exist
154
+
155
+ # If self.front_end_config.db_url is None, then we need to get the actual url from the engine
156
+ db_url = str(db_engine.url)
157
+ await self._submit_cleanup_task(scheduler_address=self._scheduler_address,
158
+ db_url=db_url,
159
+ log_level=dask_log_level)
160
+
161
+ # Set environment variables such that the worker subprocesses will know how to connect to dask and to
162
+ # the database
163
+ os.environ.update({
164
+ "NAT_DASK_SCHEDULER_ADDRESS": self._scheduler_address,
165
+ "NAT_JOB_STORE_DB_URL": db_url,
166
+ })
167
+
55
168
  # Write to YAML file
56
169
  yaml_dump(config_dict, config_file)
57
170
 
@@ -70,13 +183,25 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
70
183
 
71
184
  reload_excludes = ["./.*"]
72
185
 
186
+ # By default, Uvicorn uses "auto" event loop policy, which prefers `uvloop` if installed. However,
187
+ # uvloop’s event loop policy for macOS doesn’t provide a child watcher (which is needed for MCP server),
188
+ # so setting loop="asyncio" forces Uvicorn to use the standard event loop, which includes child-watcher
189
+ # support.
190
+ if sys.platform == "darwin" or sys.platform.startswith("linux"):
191
+ # For macOS and Linux
192
+ event_loop_policy = "asyncio"
193
+ else:
194
+ # For all other platforms (neither macOS nor Linux)
195
+ event_loop_policy = "auto"
196
+
73
197
  uvicorn.run("nat.front_ends.fastapi.main:get_app",
74
198
  host=self.front_end_config.host,
75
199
  port=self.front_end_config.port,
76
200
  workers=self.front_end_config.workers,
77
201
  reload=self.front_end_config.reload,
78
202
  factory=True,
79
- reload_excludes=reload_excludes)
203
+ reload_excludes=reload_excludes,
204
+ loop=event_loop_policy)
80
205
 
81
206
  else:
82
207
  app = get_app()
@@ -110,6 +235,18 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
110
235
  StandaloneApplication(app, options=options).run()
111
236
 
112
237
  finally:
238
+ logger.debug("Shutting down")
239
+ if self._periodic_cleanup_future is not None:
240
+ logger.info("Cancelling periodic cleanup task.")
241
+ # Use the scheduler address, because self._cluster is None if an external cluster is used
242
+ async with self.client(self._scheduler_address) as client:
243
+ await client.cancel([self._periodic_cleanup_future], asynchronous=True, force=True)
244
+
245
+ if self._cluster is not None:
246
+ # Only shut down the cluster if we created it
247
+ logger.debug("Closing Local Dask cluster.")
248
+ self._cluster.close()
249
+
113
250
  try:
114
251
  os.remove(config_file_name)
115
252
  except OSError as e: