nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -22,6 +22,7 @@ from dataclasses import dataclass
|
|
|
22
22
|
from dataclasses import field
|
|
23
23
|
|
|
24
24
|
import pkce
|
|
25
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
25
26
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
26
27
|
|
|
27
28
|
from nat.authentication.interfaces import FlowHandlerBase
|
|
@@ -61,14 +62,50 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
61
62
|
|
|
62
63
|
raise NotImplementedError(f"Authentication method '{method}' is not supported by the websocket frontend.")
|
|
63
64
|
|
|
64
|
-
def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig):
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
65
|
+
def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:
|
|
66
|
+
try:
|
|
67
|
+
return AsyncOAuth2Client(client_id=config.client_id,
|
|
68
|
+
client_secret=config.client_secret,
|
|
69
|
+
redirect_uri=config.redirect_uri,
|
|
70
|
+
scope=" ".join(config.scopes) if config.scopes else None,
|
|
71
|
+
token_endpoint=config.token_url,
|
|
72
|
+
code_challenge_method='S256' if config.use_pkce else None,
|
|
73
|
+
token_endpoint_auth_method=config.token_endpoint_auth_method)
|
|
74
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
75
|
+
raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
|
|
76
|
+
except Exception as e:
|
|
77
|
+
raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
|
|
78
|
+
|
|
79
|
+
def _create_authorization_url(self,
|
|
80
|
+
client: AsyncOAuth2Client,
|
|
81
|
+
config: OAuth2AuthCodeFlowProviderConfig,
|
|
82
|
+
state: str,
|
|
83
|
+
verifier: str = None,
|
|
84
|
+
challenge: str = None) -> str:
|
|
85
|
+
"""
|
|
86
|
+
Create OAuth authorization URL with proper error handling.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
client: The OAuth2 client instance
|
|
90
|
+
config: OAuth2 configuration
|
|
91
|
+
state: OAuth state parameter
|
|
92
|
+
verifier: PKCE verifier (if using PKCE)
|
|
93
|
+
challenge: PKCE challenge (if using PKCE)
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
The authorization URL
|
|
97
|
+
"""
|
|
98
|
+
try:
|
|
99
|
+
authorization_url, _ = client.create_authorization_url(
|
|
100
|
+
config.authorization_url,
|
|
101
|
+
state=state,
|
|
102
|
+
code_verifier=verifier if config.use_pkce else None,
|
|
103
|
+
code_challenge=challenge if config.use_pkce else None,
|
|
104
|
+
**(config.authorization_kwargs or {})
|
|
105
|
+
)
|
|
106
|
+
return authorization_url
|
|
107
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
108
|
+
raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
|
|
72
109
|
|
|
73
110
|
async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
|
|
74
111
|
|
|
@@ -82,21 +119,19 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
82
119
|
flow_state.verifier = verifier
|
|
83
120
|
flow_state.challenge = challenge
|
|
84
121
|
|
|
85
|
-
authorization_url
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
**(config.authorization_kwargs or {})
|
|
91
|
-
)
|
|
122
|
+
authorization_url = self._create_authorization_url(client=flow_state.client,
|
|
123
|
+
config=config,
|
|
124
|
+
state=state,
|
|
125
|
+
verifier=flow_state.verifier,
|
|
126
|
+
challenge=flow_state.challenge)
|
|
92
127
|
|
|
93
128
|
await self._add_flow_cb(state, flow_state)
|
|
94
129
|
await self._web_socket_message_handler.create_websocket_message(_HumanPromptOAuthConsent(text=authorization_url)
|
|
95
130
|
)
|
|
96
131
|
try:
|
|
97
132
|
token = await asyncio.wait_for(flow_state.future, timeout=300)
|
|
98
|
-
except
|
|
99
|
-
raise RuntimeError("Authentication flow timed out after 5 minutes.")
|
|
133
|
+
except TimeoutError as exc:
|
|
134
|
+
raise RuntimeError("Authentication flow timed out after 5 minutes.") from exc
|
|
100
135
|
finally:
|
|
101
136
|
|
|
102
137
|
await self._remove_flow_cb(state)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
from abc import ABC
|
|
18
|
+
from collections.abc import AsyncGenerator
|
|
19
|
+
from collections.abc import Generator
|
|
20
|
+
from contextlib import asynccontextmanager
|
|
21
|
+
from contextlib import contextmanager
|
|
22
|
+
|
|
23
|
+
if typing.TYPE_CHECKING:
|
|
24
|
+
from dask.distributed import Client
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DaskClientMixin(ABC):
|
|
28
|
+
|
|
29
|
+
@asynccontextmanager
|
|
30
|
+
async def client(self, address: str) -> AsyncGenerator["Client"]:
|
|
31
|
+
"""
|
|
32
|
+
Async context manager for obtaining a Dask client.
|
|
33
|
+
|
|
34
|
+
Yields
|
|
35
|
+
------
|
|
36
|
+
Client
|
|
37
|
+
An async Dask client connected to the scheduler. The client is automatically closed when exiting the
|
|
38
|
+
context manager.
|
|
39
|
+
"""
|
|
40
|
+
from dask.distributed import Client
|
|
41
|
+
client = await Client(address=address, asynchronous=True)
|
|
42
|
+
|
|
43
|
+
try:
|
|
44
|
+
yield client
|
|
45
|
+
finally:
|
|
46
|
+
await client.close()
|
|
47
|
+
|
|
48
|
+
@contextmanager
|
|
49
|
+
def blocking_client(self, address: str) -> Generator["Client"]:
|
|
50
|
+
"""
|
|
51
|
+
context manager for obtaining a blocking Dask client.
|
|
52
|
+
|
|
53
|
+
Yields
|
|
54
|
+
------
|
|
55
|
+
Client
|
|
56
|
+
A blocking Dask client connected to the scheduler. The client is automatically closed when exiting the
|
|
57
|
+
context manager.
|
|
58
|
+
"""
|
|
59
|
+
from dask.distributed import Client
|
|
60
|
+
client = Client(address=address)
|
|
61
|
+
|
|
62
|
+
try:
|
|
63
|
+
yield client
|
|
64
|
+
finally:
|
|
65
|
+
client.close()
|
|
@@ -14,6 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
|
+
import os
|
|
18
|
+
import sys
|
|
17
19
|
import typing
|
|
18
20
|
from datetime import datetime
|
|
19
21
|
from pathlib import Path
|
|
@@ -31,6 +33,20 @@ logger = logging.getLogger(__name__)
|
|
|
31
33
|
YAML_EXTENSIONS = (".yaml", ".yml")
|
|
32
34
|
|
|
33
35
|
|
|
36
|
+
def _is_reserved(path: Path) -> bool:
|
|
37
|
+
"""
|
|
38
|
+
Check if a path is reserved in the current Python version and platform.
|
|
39
|
+
|
|
40
|
+
On Windows, this function checks if the path is reserved in the current Python version.
|
|
41
|
+
On other platforms, returns False
|
|
42
|
+
"""
|
|
43
|
+
if sys.platform != "win32":
|
|
44
|
+
return False
|
|
45
|
+
if sys.version_info >= (3, 13):
|
|
46
|
+
return os.path.isreserved(path)
|
|
47
|
+
return path.is_reserved()
|
|
48
|
+
|
|
49
|
+
|
|
34
50
|
class EvaluateRequest(BaseModel):
|
|
35
51
|
"""Request model for the evaluate endpoint."""
|
|
36
52
|
config_file: str = Field(description="Path to the configuration file for evaluation")
|
|
@@ -51,7 +67,7 @@ class EvaluateRequest(BaseModel):
|
|
|
51
67
|
f"Job ID '{job_id}' contains invalid characters. Only alphanumeric characters and underscores are"
|
|
52
68
|
" allowed.")
|
|
53
69
|
|
|
54
|
-
if job_id_path
|
|
70
|
+
if _is_reserved(job_id_path):
|
|
55
71
|
# reserved names is Windows specific
|
|
56
72
|
raise ValueError(f"Job ID '{job_id}' is a reserved name. Please choose a different name.")
|
|
57
73
|
|
|
@@ -68,7 +84,7 @@ class EvaluateRequest(BaseModel):
|
|
|
68
84
|
raise ValueError(f"Config file '{config_file}' must be a YAML file with one of the following extensions: "
|
|
69
85
|
f"{', '.join(YAML_EXTENSIONS)}")
|
|
70
86
|
|
|
71
|
-
if config_file_path
|
|
87
|
+
if _is_reserved(config_file_path):
|
|
72
88
|
# reserved names is Windows specific
|
|
73
89
|
raise ValueError(f"Config file '{config_file}' is a reserved name. Please choose a different name.")
|
|
74
90
|
|
|
@@ -181,9 +197,24 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
|
|
|
181
197
|
port: int = Field(default=8000, description="Port to bind the server to", ge=0, le=65535)
|
|
182
198
|
reload: bool = Field(default=False, description="Enable auto-reload for development")
|
|
183
199
|
workers: int = Field(default=1, description="Number of workers to run", ge=1)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
200
|
+
scheduler_address: str | None = Field(
|
|
201
|
+
default=None,
|
|
202
|
+
description=("Address of the Dask scheduler to use for async jobs. If None, a Dask local cluster is created. "
|
|
203
|
+
"Note: This requires the optional dask dependency to be installed."))
|
|
204
|
+
db_url: str | None = Field(
|
|
205
|
+
default=None,
|
|
206
|
+
description=
|
|
207
|
+
"SQLAlchemy database URL for storing async job metadata, if unset a temporary SQLite database is used.")
|
|
208
|
+
max_running_async_jobs: int = Field(
|
|
209
|
+
default=10,
|
|
210
|
+
description=(
|
|
211
|
+
"Maximum number of async jobs to run concurrently, this controls the number of dask workers created. "
|
|
212
|
+
"This parameter is only used when scheduler_address is `None` and a Dask local cluster is created."),
|
|
213
|
+
ge=1)
|
|
214
|
+
dask_log_level: str = Field(
|
|
215
|
+
default="WARNING",
|
|
216
|
+
description="Logging level for Dask.",
|
|
217
|
+
)
|
|
187
218
|
step_adaptor: StepAdaptorConfig = StepAdaptorConfig()
|
|
188
219
|
|
|
189
220
|
workflow: typing.Annotated[EndpointBase, Field(description="Endpoint for the default workflow.")] = EndpointBase(
|
|
@@ -47,11 +47,11 @@ class _FastApiFrontEndController:
|
|
|
47
47
|
self._server_background_task = asyncio.create_task(self._server.serve())
|
|
48
48
|
except asyncio.CancelledError as e:
|
|
49
49
|
error_message = f"Task error occurred while starting API server: {str(e)}"
|
|
50
|
-
logger.error(error_message
|
|
50
|
+
logger.error(error_message)
|
|
51
51
|
raise RuntimeError(error_message) from e
|
|
52
52
|
except Exception as e:
|
|
53
53
|
error_message = f"Unexpected error occurred while starting API server: {str(e)}"
|
|
54
|
-
logger.
|
|
54
|
+
logger.exception(error_message)
|
|
55
55
|
raise RuntimeError(error_message) from e
|
|
56
56
|
|
|
57
57
|
async def stop_server(self) -> None:
|
|
@@ -63,6 +63,6 @@ class _FastApiFrontEndController:
|
|
|
63
63
|
self._server.should_exit = True
|
|
64
64
|
await self._server_background_task
|
|
65
65
|
except asyncio.CancelledError as e:
|
|
66
|
-
logger.
|
|
66
|
+
logger.exception("Server shutdown failed: %s", str(e))
|
|
67
67
|
except Exception as e:
|
|
68
|
-
logger.
|
|
68
|
+
logger.exception("Unexpected error occurred: %s", str(e))
|
|
@@ -13,21 +13,37 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import logging
|
|
17
18
|
import os
|
|
19
|
+
import sys
|
|
18
20
|
import tempfile
|
|
19
21
|
import typing
|
|
20
22
|
|
|
21
23
|
from nat.builder.front_end import FrontEndBase
|
|
24
|
+
from nat.front_ends.fastapi.dask_client_mixin import DaskClientMixin
|
|
22
25
|
from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
23
26
|
from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
|
|
24
27
|
from nat.front_ends.fastapi.main import get_app
|
|
28
|
+
from nat.front_ends.fastapi.utils import get_class_name
|
|
25
29
|
from nat.utils.io.yaml_tools import yaml_dump
|
|
30
|
+
from nat.utils.log_levels import LOG_LEVELS
|
|
31
|
+
|
|
32
|
+
if (typing.TYPE_CHECKING):
|
|
33
|
+
from nat.data_models.config import Config
|
|
26
34
|
|
|
27
35
|
logger = logging.getLogger(__name__)
|
|
28
36
|
|
|
29
37
|
|
|
30
|
-
class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
|
|
38
|
+
class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]):
|
|
39
|
+
|
|
40
|
+
def __init__(self, full_config: "Config"):
|
|
41
|
+
super().__init__(full_config)
|
|
42
|
+
|
|
43
|
+
# This attribute is set if dask is installed, and an external cluster is not used (scheduler_address is None)
|
|
44
|
+
self._cluster = None
|
|
45
|
+
self._periodic_cleanup_future = None
|
|
46
|
+
self._scheduler_address = None
|
|
31
47
|
|
|
32
48
|
def get_worker_class(self) -> type[FastApiFrontEndPluginWorkerBase]:
|
|
33
49
|
from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker
|
|
@@ -42,7 +58,45 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
|
|
|
42
58
|
|
|
43
59
|
worker_class = self.get_worker_class()
|
|
44
60
|
|
|
45
|
-
return
|
|
61
|
+
return get_class_name(worker_class)
|
|
62
|
+
|
|
63
|
+
@staticmethod
|
|
64
|
+
async def _periodic_cleanup(scheduler_address: str,
|
|
65
|
+
db_url: str,
|
|
66
|
+
sleep_time_sec: int = 300,
|
|
67
|
+
log_level: int = logging.INFO):
|
|
68
|
+
from nat.front_ends.fastapi.job_store import JobStore
|
|
69
|
+
|
|
70
|
+
job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
|
|
71
|
+
|
|
72
|
+
logging.basicConfig(level=log_level)
|
|
73
|
+
logger.info("Starting periodic cleanup of expired jobs every %d seconds", sleep_time_sec)
|
|
74
|
+
while True:
|
|
75
|
+
await asyncio.sleep(sleep_time_sec)
|
|
76
|
+
|
|
77
|
+
try:
|
|
78
|
+
await job_store.cleanup_expired_jobs()
|
|
79
|
+
logger.debug("Expired jobs cleaned up")
|
|
80
|
+
except: # noqa: E722
|
|
81
|
+
logger.exception("Error during job cleanup")
|
|
82
|
+
|
|
83
|
+
async def _submit_cleanup_task(self, scheduler_address: str, db_url: str, log_level: int = logging.INFO):
|
|
84
|
+
"""Submit a cleanup task to the cluster to remove the job after expiry."""
|
|
85
|
+
logger.debug("Submitting periodic cleanup task to Dask cluster at %s", scheduler_address)
|
|
86
|
+
async with self.client(self._scheduler_address) as client:
|
|
87
|
+
self._periodic_cleanup_future = client.submit(self._periodic_cleanup,
|
|
88
|
+
scheduler_address=self._scheduler_address,
|
|
89
|
+
db_url=db_url,
|
|
90
|
+
log_level=log_level)
|
|
91
|
+
|
|
92
|
+
@staticmethod
|
|
93
|
+
def _setup_worker():
|
|
94
|
+
"""
|
|
95
|
+
Setup function to be run in each worker process. This moves each worker into it's own process group.
|
|
96
|
+
This fixes an issue where a Ctrl-C in the terminal sends a SIGINT to all workers, which then causes the
|
|
97
|
+
workers to exit before the main process can shutdown the cluster gracefully.
|
|
98
|
+
"""
|
|
99
|
+
os.setsid()
|
|
46
100
|
|
|
47
101
|
async def run(self):
|
|
48
102
|
|
|
@@ -52,6 +106,59 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
|
|
|
52
106
|
# Get as dict
|
|
53
107
|
config_dict = self.full_config.model_dump(mode="json", by_alias=True, round_trip=True)
|
|
54
108
|
|
|
109
|
+
# Three possible cases:
|
|
110
|
+
# 1. Dask is installed and scheduler_address is None, we create a LocalCluster
|
|
111
|
+
# 2. Dask is installed and scheduler_address is set, we use the existing cluster
|
|
112
|
+
# 3. Dask is not installed, we skip the cluster setup
|
|
113
|
+
dask_log_level = LOG_LEVELS.get(self.front_end_config.dask_log_level.upper(), logging.WARNING)
|
|
114
|
+
dask_logger = logging.getLogger("distributed")
|
|
115
|
+
dask_logger.setLevel(dask_log_level)
|
|
116
|
+
|
|
117
|
+
self._scheduler_address = self.front_end_config.scheduler_address
|
|
118
|
+
if self._scheduler_address is None:
|
|
119
|
+
try:
|
|
120
|
+
|
|
121
|
+
from dask.distributed import LocalCluster
|
|
122
|
+
|
|
123
|
+
self._cluster = LocalCluster(processes=True,
|
|
124
|
+
silence_logs=dask_log_level,
|
|
125
|
+
n_workers=self.front_end_config.max_running_async_jobs,
|
|
126
|
+
threads_per_worker=1)
|
|
127
|
+
|
|
128
|
+
self._scheduler_address = self._cluster.scheduler.address
|
|
129
|
+
|
|
130
|
+
with self.blocking_client(self._scheduler_address) as client:
|
|
131
|
+
# Client.run submits a function to be run on each worker
|
|
132
|
+
client.run(self._setup_worker)
|
|
133
|
+
|
|
134
|
+
logger.info("Created local Dask cluster with scheduler at %s", self._scheduler_address)
|
|
135
|
+
|
|
136
|
+
except ImportError:
|
|
137
|
+
logger.warning("Dask is not installed, async execution and evaluation will not be available.")
|
|
138
|
+
|
|
139
|
+
if self._scheduler_address is not None:
|
|
140
|
+
# If we are here then either the user provided a scheduler address, or we created a LocalCluster
|
|
141
|
+
|
|
142
|
+
from nat.front_ends.fastapi.job_store import Base
|
|
143
|
+
from nat.front_ends.fastapi.job_store import get_db_engine
|
|
144
|
+
|
|
145
|
+
db_engine = get_db_engine(self.front_end_config.db_url, use_async=True)
|
|
146
|
+
async with db_engine.begin() as conn:
|
|
147
|
+
await conn.run_sync(Base.metadata.create_all, checkfirst=True) # create tables if they do not exist
|
|
148
|
+
|
|
149
|
+
# If self.front_end_config.db_url is None, then we need to get the actual url from the engine
|
|
150
|
+
db_url = str(db_engine.url)
|
|
151
|
+
await self._submit_cleanup_task(scheduler_address=self._scheduler_address,
|
|
152
|
+
db_url=db_url,
|
|
153
|
+
log_level=dask_log_level)
|
|
154
|
+
|
|
155
|
+
# Set environment variabls such that the worker subprocesses will know how to connect to dask and to
|
|
156
|
+
# the database
|
|
157
|
+
os.environ.update({
|
|
158
|
+
"NAT_DASK_SCHEDULER_ADDRESS": self._scheduler_address,
|
|
159
|
+
"NAT_JOB_STORE_DB_URL": db_url,
|
|
160
|
+
})
|
|
161
|
+
|
|
55
162
|
# Write to YAML file
|
|
56
163
|
yaml_dump(config_dict, config_file)
|
|
57
164
|
|
|
@@ -70,13 +177,25 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
|
|
|
70
177
|
|
|
71
178
|
reload_excludes = ["./.*"]
|
|
72
179
|
|
|
180
|
+
# By default, Uvicorn uses "auto" event loop policy, which prefers `uvloop` if installed. However,
|
|
181
|
+
# uvloop’s event loop policy for macOS doesn’t provide a child watcher (which is needed for MCP server),
|
|
182
|
+
# so setting loop="asyncio" forces Uvicorn to use the standard event loop, which includes child-watcher
|
|
183
|
+
# support.
|
|
184
|
+
if sys.platform == "darwin" or sys.platform.startswith("linux"):
|
|
185
|
+
# For macOS
|
|
186
|
+
event_loop_policy = "asyncio"
|
|
187
|
+
else:
|
|
188
|
+
# For non-macOS platforms
|
|
189
|
+
event_loop_policy = "auto"
|
|
190
|
+
|
|
73
191
|
uvicorn.run("nat.front_ends.fastapi.main:get_app",
|
|
74
192
|
host=self.front_end_config.host,
|
|
75
193
|
port=self.front_end_config.port,
|
|
76
194
|
workers=self.front_end_config.workers,
|
|
77
195
|
reload=self.front_end_config.reload,
|
|
78
196
|
factory=True,
|
|
79
|
-
reload_excludes=reload_excludes
|
|
197
|
+
reload_excludes=reload_excludes,
|
|
198
|
+
loop=event_loop_policy)
|
|
80
199
|
|
|
81
200
|
else:
|
|
82
201
|
app = get_app()
|
|
@@ -110,7 +229,19 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
|
|
|
110
229
|
StandaloneApplication(app, options=options).run()
|
|
111
230
|
|
|
112
231
|
finally:
|
|
232
|
+
logger.debug("Shutting down")
|
|
233
|
+
if self._periodic_cleanup_future is not None:
|
|
234
|
+
logger.info("Cancelling periodic cleanup task.")
|
|
235
|
+
# Use the scheduler address, because self._cluster is None if an external cluster is used
|
|
236
|
+
async with self.client(self._scheduler_address) as client:
|
|
237
|
+
await client.cancel([self._periodic_cleanup_future], asynchronous=True, force=True)
|
|
238
|
+
|
|
239
|
+
if self._cluster is not None:
|
|
240
|
+
# Only shut down the cluster if we created it
|
|
241
|
+
logger.debug("Closing Local Dask cluster.")
|
|
242
|
+
self._cluster.close()
|
|
243
|
+
|
|
113
244
|
try:
|
|
114
245
|
os.remove(config_file_name)
|
|
115
246
|
except OSError as e:
|
|
116
|
-
logger.
|
|
247
|
+
logger.exception(f"Warning: Failed to delete temp file {config_file_name}: {e}")
|