nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +50 -22
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +54 -27
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +68 -17
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +62 -22
- nat/cli/entrypoint.py +8 -10
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +74 -66
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +19 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +106 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -82,7 +82,7 @@ async def register_ttc_tool_orchestration_function(
|
|
|
82
82
|
function_map = {}
|
|
83
83
|
for fn_ref in config.augmented_fns:
|
|
84
84
|
# Retrieve the actual function from the builder
|
|
85
|
-
fn_obj = builder.get_function(fn_ref)
|
|
85
|
+
fn_obj = await builder.get_function(fn_ref)
|
|
86
86
|
function_map[fn_ref] = fn_obj
|
|
87
87
|
|
|
88
88
|
# 2) Instantiate search, editing, scoring, selection strategies (if any)
|
|
@@ -148,13 +148,13 @@ async def register_ttc_tool_orchestration_function(
|
|
|
148
148
|
result = await fn.acall_invoke(item.output)
|
|
149
149
|
return item, result, None
|
|
150
150
|
except Exception as e:
|
|
151
|
-
logger.
|
|
151
|
+
logger.exception(f"Error invoking function '{item.name}': {e}")
|
|
152
152
|
return item, None, str(e)
|
|
153
153
|
|
|
154
154
|
tasks = []
|
|
155
155
|
for item in ttc_items:
|
|
156
156
|
if item.name not in function_map:
|
|
157
|
-
logger.error(f"Function '{item.name}' not found in function map.")
|
|
157
|
+
logger.error(f"Function '{item.name}' not found in function map.", exc_info=True)
|
|
158
158
|
item.output = f"Error: Function '{item.name}' not found in function map. Check your input"
|
|
159
159
|
else:
|
|
160
160
|
fn = function_map[item.name]
|
|
@@ -80,7 +80,7 @@ async def register_ttc_tool_wrapper_function(
|
|
|
80
80
|
raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
|
|
81
81
|
"This error can be resolved by installing nvidia-nat-langchain.")
|
|
82
82
|
|
|
83
|
-
augmented_function: Function = builder.get_function(config.augmented_fn)
|
|
83
|
+
augmented_function: Function = await builder.get_function(config.augmented_fn)
|
|
84
84
|
input_llm: BaseChatModel = await builder.get_llm(config.input_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
85
85
|
|
|
86
86
|
if not augmented_function.has_single_output:
|
|
@@ -17,9 +17,10 @@ from abc import ABC
|
|
|
17
17
|
from abc import abstractmethod
|
|
18
18
|
|
|
19
19
|
from nat.builder.builder import Builder
|
|
20
|
-
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
21
|
-
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum, PipelineTypeEnum
|
|
22
20
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
21
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
22
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
23
|
+
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class StrategyBase(ABC):
|
|
@@ -45,11 +46,11 @@ class StrategyBase(ABC):
|
|
|
45
46
|
items: list[TTCItem],
|
|
46
47
|
original_prompt: str | None = None,
|
|
47
48
|
agent_context: str | None = None,
|
|
48
|
-
**kwargs) -> [TTCItem]:
|
|
49
|
+
**kwargs) -> list[TTCItem]:
|
|
49
50
|
pass
|
|
50
51
|
|
|
51
52
|
@abstractmethod
|
|
52
|
-
def supported_pipeline_types(self) -> [PipelineTypeEnum]:
|
|
53
|
+
def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
|
|
53
54
|
"""Return the stage types supported by this selector."""
|
|
54
55
|
pass
|
|
55
56
|
|
|
@@ -71,7 +71,7 @@ class LLMBasedOutputMergingSelector(StrategyBase):
|
|
|
71
71
|
raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
|
|
72
72
|
"This error can be resolved by installing nvidia-nat-langchain.")
|
|
73
73
|
|
|
74
|
-
from
|
|
74
|
+
from collections.abc import Callable
|
|
75
75
|
|
|
76
76
|
from pydantic import BaseModel
|
|
77
77
|
|
|
@@ -135,8 +135,6 @@ class LLMBasedOutputMergingSelector(StrategyBase):
|
|
|
135
135
|
except Exception as e:
|
|
136
136
|
logger.error(f"Error parsing merged output: {e}")
|
|
137
137
|
raise ValueError("Failed to parse merged output.")
|
|
138
|
-
else:
|
|
139
|
-
merged_output = merged_output
|
|
140
138
|
|
|
141
139
|
logger.info("Merged output: %s", str(merged_output))
|
|
142
140
|
|
|
@@ -14,13 +14,16 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import asyncio
|
|
17
|
+
import logging
|
|
17
18
|
import secrets
|
|
18
19
|
import webbrowser
|
|
19
20
|
from dataclasses import dataclass
|
|
20
21
|
from dataclasses import field
|
|
21
22
|
|
|
22
23
|
import click
|
|
24
|
+
import httpx
|
|
23
25
|
import pkce
|
|
26
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
24
27
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
25
28
|
from fastapi import FastAPI
|
|
26
29
|
from fastapi import Request
|
|
@@ -32,6 +35,8 @@ from nat.data_models.authentication import AuthFlowType
|
|
|
32
35
|
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
33
36
|
from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
|
|
34
37
|
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
35
40
|
|
|
36
41
|
# --------------------------------------------------------------------------- #
|
|
37
42
|
# Helpers #
|
|
@@ -87,17 +92,53 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
87
92
|
"""
|
|
88
93
|
Separated for easy overriding in tests (to inject ASGITransport).
|
|
89
94
|
"""
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
95
|
+
try:
|
|
96
|
+
client = AsyncOAuth2Client(
|
|
97
|
+
client_id=cfg.client_id,
|
|
98
|
+
client_secret=cfg.client_secret,
|
|
99
|
+
redirect_uri=cfg.redirect_uri,
|
|
100
|
+
scope=" ".join(cfg.scopes) if cfg.scopes else None,
|
|
101
|
+
token_endpoint=cfg.token_url,
|
|
102
|
+
token_endpoint_auth_method=cfg.token_endpoint_auth_method,
|
|
103
|
+
code_challenge_method="S256" if cfg.use_pkce else None,
|
|
104
|
+
)
|
|
105
|
+
self._oauth_client = client
|
|
106
|
+
return client
|
|
107
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
108
|
+
raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
|
|
109
|
+
except Exception as e:
|
|
110
|
+
raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
|
|
111
|
+
|
|
112
|
+
def _create_authorization_url(self,
|
|
113
|
+
client: AsyncOAuth2Client,
|
|
114
|
+
config: OAuth2AuthCodeFlowProviderConfig,
|
|
115
|
+
state: str,
|
|
116
|
+
verifier: str | None = None,
|
|
117
|
+
challenge: str | None = None) -> str:
|
|
118
|
+
"""
|
|
119
|
+
Create OAuth authorization URL with proper error handling.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
client: The OAuth2 client instance
|
|
123
|
+
config: OAuth2 configuration
|
|
124
|
+
state: OAuth state parameter
|
|
125
|
+
verifier: PKCE verifier (if using PKCE)
|
|
126
|
+
challenge: PKCE challenge (if using PKCE)
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
The authorization URL
|
|
130
|
+
"""
|
|
131
|
+
try:
|
|
132
|
+
auth_url, _ = client.create_authorization_url(
|
|
133
|
+
config.authorization_url,
|
|
134
|
+
state=state,
|
|
135
|
+
code_verifier=verifier if config.use_pkce else None,
|
|
136
|
+
code_challenge=challenge if config.use_pkce else None,
|
|
137
|
+
**(config.authorization_kwargs or {})
|
|
138
|
+
)
|
|
139
|
+
return auth_url
|
|
140
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
141
|
+
raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
|
|
101
142
|
|
|
102
143
|
# --------------------------- HTTP Basic ------------------------------ #
|
|
103
144
|
@staticmethod
|
|
@@ -131,13 +172,12 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
131
172
|
flow_state.verifier = verifier
|
|
132
173
|
flow_state.challenge = challenge
|
|
133
174
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
)
|
|
175
|
+
# Create authorization URL using helper function
|
|
176
|
+
auth_url = self._create_authorization_url(client=client,
|
|
177
|
+
config=cfg,
|
|
178
|
+
state=state,
|
|
179
|
+
verifier=flow_state.verifier,
|
|
180
|
+
challenge=flow_state.challenge)
|
|
141
181
|
|
|
142
182
|
# Register flow + maybe spin up redirect handler
|
|
143
183
|
async with self._server_lock:
|
|
@@ -149,14 +189,18 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
149
189
|
self._flows[state] = flow_state
|
|
150
190
|
self._active_flows += 1
|
|
151
191
|
|
|
152
|
-
|
|
153
|
-
|
|
192
|
+
try:
|
|
193
|
+
webbrowser.open(auth_url)
|
|
194
|
+
click.echo("Your browser has been opened for authentication.")
|
|
195
|
+
except Exception as e:
|
|
196
|
+
logger.error("Browser open failed: %s", e)
|
|
197
|
+
raise RuntimeError(f"Browser open failed: {e}") from e
|
|
154
198
|
|
|
155
199
|
# Wait for the redirect to land
|
|
156
200
|
try:
|
|
157
201
|
token = await asyncio.wait_for(flow_state.future, timeout=300)
|
|
158
|
-
except
|
|
159
|
-
raise RuntimeError("Authentication timed out (5 min).")
|
|
202
|
+
except TimeoutError as exc:
|
|
203
|
+
raise RuntimeError("Authentication timed out (5 min).") from exc
|
|
160
204
|
finally:
|
|
161
205
|
async with self._server_lock:
|
|
162
206
|
self._flows.pop(state, None)
|
|
@@ -175,9 +219,9 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
175
219
|
# --------------- redirect server / in‑process app -------------------- #
|
|
176
220
|
async def _build_redirect_app(self) -> FastAPI:
|
|
177
221
|
"""
|
|
178
|
-
* If cfg.run_redirect_local_server == True → start a
|
|
179
|
-
* Else → only build the
|
|
180
|
-
for in‑process testing
|
|
222
|
+
* If cfg.run_redirect_local_server == True → start a local server.
|
|
223
|
+
* Else → only build the redirect app and save it to `self._redirect_app`
|
|
224
|
+
for in‑process testing.
|
|
181
225
|
"""
|
|
182
226
|
app = FastAPI()
|
|
183
227
|
|
|
@@ -195,8 +239,16 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
195
239
|
state=state,
|
|
196
240
|
)
|
|
197
241
|
flow_state.future.set_result(token)
|
|
198
|
-
except
|
|
199
|
-
flow_state.future.set_exception(
|
|
242
|
+
except OAuthError as e:
|
|
243
|
+
flow_state.future.set_exception(
|
|
244
|
+
RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
|
|
245
|
+
return "Authentication failed: Authorization server rejected the request. You may close this tab."
|
|
246
|
+
except httpx.HTTPError as e:
|
|
247
|
+
flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
|
|
248
|
+
return "Authentication failed: Network error occurred. You may close this tab."
|
|
249
|
+
except Exception as e:
|
|
250
|
+
flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
|
|
251
|
+
return "Authentication failed: An unexpected error occurred. You may close this tab."
|
|
200
252
|
return "Authentication successful – you may close this tab."
|
|
201
253
|
|
|
202
254
|
return app
|
|
@@ -213,7 +265,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
213
265
|
|
|
214
266
|
asyncio.create_task(self._server_controller.start_server(host="localhost", port=8000))
|
|
215
267
|
|
|
216
|
-
# Give
|
|
268
|
+
# Give the server a moment to bind sockets before we return
|
|
217
269
|
await asyncio.sleep(0.3)
|
|
218
270
|
except Exception as exc: # noqa: BLE001
|
|
219
271
|
raise RuntimeError(f"Failed to start redirect server: {exc}") from exc
|
|
@@ -227,7 +279,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
227
279
|
@property
|
|
228
280
|
def redirect_app(self) -> FastAPI | None:
|
|
229
281
|
"""
|
|
230
|
-
In
|
|
231
|
-
app is exposed
|
|
282
|
+
In test mode (run_redirect_local_server=False) the in‑memory redirect
|
|
283
|
+
app is exposed for testing purposes.
|
|
232
284
|
"""
|
|
233
285
|
return self._redirect_app
|
|
@@ -55,9 +55,10 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
|
|
|
55
55
|
self.auth_flow_handler = ConsoleAuthenticationFlowHandler()
|
|
56
56
|
|
|
57
57
|
async def pre_run(self):
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
58
|
+
if (self.front_end_config.input_query is not None and self.front_end_config.input_file is not None):
|
|
59
|
+
raise click.UsageError("Must specify either --input or --input_file, not both")
|
|
60
|
+
if (self.front_end_config.input_query is None and self.front_end_config.input_file is None):
|
|
61
|
+
raise click.UsageError("Must specify either --input or --input_file")
|
|
61
62
|
|
|
62
63
|
async def run_workflow(self, session_manager: SessionManager):
|
|
63
64
|
|
|
@@ -80,12 +81,14 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
|
|
|
80
81
|
input_list = list(self.front_end_config.input_query)
|
|
81
82
|
logger.debug("Processing input: %s", self.front_end_config.input_query)
|
|
82
83
|
|
|
83
|
-
|
|
84
|
+
# Make `return_exceptions=False` explicit; all exceptions are raised instead of being silenced
|
|
85
|
+
runner_outputs = await asyncio.gather(*[run_single_query(query) for query in input_list],
|
|
86
|
+
return_exceptions=False)
|
|
84
87
|
|
|
85
88
|
elif (self.front_end_config.input_file):
|
|
86
89
|
|
|
87
90
|
# Run the workflow
|
|
88
|
-
with open(self.front_end_config.input_file,
|
|
91
|
+
with open(self.front_end_config.input_file, encoding="utf-8") as f:
|
|
89
92
|
|
|
90
93
|
async with session_manager.workflow.run(f) as runner:
|
|
91
94
|
runner_outputs = await runner.result(to_type=str)
|
|
@@ -22,6 +22,7 @@ from dataclasses import dataclass
|
|
|
22
22
|
from dataclasses import field
|
|
23
23
|
|
|
24
24
|
import pkce
|
|
25
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
25
26
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
26
27
|
|
|
27
28
|
from nat.authentication.interfaces import FlowHandlerBase
|
|
@@ -61,14 +62,50 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
61
62
|
|
|
62
63
|
raise NotImplementedError(f"Authentication method '{method}' is not supported by the websocket frontend.")
|
|
63
64
|
|
|
64
|
-
def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig):
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
65
|
+
def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:
|
|
66
|
+
try:
|
|
67
|
+
return AsyncOAuth2Client(client_id=config.client_id,
|
|
68
|
+
client_secret=config.client_secret,
|
|
69
|
+
redirect_uri=config.redirect_uri,
|
|
70
|
+
scope=" ".join(config.scopes) if config.scopes else None,
|
|
71
|
+
token_endpoint=config.token_url,
|
|
72
|
+
code_challenge_method='S256' if config.use_pkce else None,
|
|
73
|
+
token_endpoint_auth_method=config.token_endpoint_auth_method)
|
|
74
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
75
|
+
raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
|
|
76
|
+
except Exception as e:
|
|
77
|
+
raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
|
|
78
|
+
|
|
79
|
+
def _create_authorization_url(self,
|
|
80
|
+
client: AsyncOAuth2Client,
|
|
81
|
+
config: OAuth2AuthCodeFlowProviderConfig,
|
|
82
|
+
state: str,
|
|
83
|
+
verifier: str = None,
|
|
84
|
+
challenge: str = None) -> str:
|
|
85
|
+
"""
|
|
86
|
+
Create OAuth authorization URL with proper error handling.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
client: The OAuth2 client instance
|
|
90
|
+
config: OAuth2 configuration
|
|
91
|
+
state: OAuth state parameter
|
|
92
|
+
verifier: PKCE verifier (if using PKCE)
|
|
93
|
+
challenge: PKCE challenge (if using PKCE)
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
The authorization URL
|
|
97
|
+
"""
|
|
98
|
+
try:
|
|
99
|
+
authorization_url, _ = client.create_authorization_url(
|
|
100
|
+
config.authorization_url,
|
|
101
|
+
state=state,
|
|
102
|
+
code_verifier=verifier if config.use_pkce else None,
|
|
103
|
+
code_challenge=challenge if config.use_pkce else None,
|
|
104
|
+
**(config.authorization_kwargs or {})
|
|
105
|
+
)
|
|
106
|
+
return authorization_url
|
|
107
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
108
|
+
raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
|
|
72
109
|
|
|
73
110
|
async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
|
|
74
111
|
|
|
@@ -82,21 +119,19 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
82
119
|
flow_state.verifier = verifier
|
|
83
120
|
flow_state.challenge = challenge
|
|
84
121
|
|
|
85
|
-
authorization_url
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
**(config.authorization_kwargs or {})
|
|
91
|
-
)
|
|
122
|
+
authorization_url = self._create_authorization_url(client=flow_state.client,
|
|
123
|
+
config=config,
|
|
124
|
+
state=state,
|
|
125
|
+
verifier=flow_state.verifier,
|
|
126
|
+
challenge=flow_state.challenge)
|
|
92
127
|
|
|
93
128
|
await self._add_flow_cb(state, flow_state)
|
|
94
129
|
await self._web_socket_message_handler.create_websocket_message(_HumanPromptOAuthConsent(text=authorization_url)
|
|
95
130
|
)
|
|
96
131
|
try:
|
|
97
132
|
token = await asyncio.wait_for(flow_state.future, timeout=300)
|
|
98
|
-
except
|
|
99
|
-
raise RuntimeError("Authentication flow timed out after 5 minutes.")
|
|
133
|
+
except TimeoutError as exc:
|
|
134
|
+
raise RuntimeError("Authentication flow timed out after 5 minutes.") from exc
|
|
100
135
|
finally:
|
|
101
136
|
|
|
102
137
|
await self._remove_flow_cb(state)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import typing
|
|
17
|
+
from abc import ABC
|
|
18
|
+
from collections.abc import AsyncGenerator
|
|
19
|
+
from collections.abc import Generator
|
|
20
|
+
from contextlib import asynccontextmanager
|
|
21
|
+
from contextlib import contextmanager
|
|
22
|
+
|
|
23
|
+
if typing.TYPE_CHECKING:
|
|
24
|
+
from dask.distributed import Client
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DaskClientMixin(ABC):
|
|
28
|
+
|
|
29
|
+
@asynccontextmanager
|
|
30
|
+
async def client(self, address: str) -> AsyncGenerator["Client"]:
|
|
31
|
+
"""
|
|
32
|
+
Async context manager for obtaining a Dask client.
|
|
33
|
+
|
|
34
|
+
Yields
|
|
35
|
+
------
|
|
36
|
+
Client
|
|
37
|
+
An async Dask client connected to the scheduler. The client is automatically closed when exiting the
|
|
38
|
+
context manager.
|
|
39
|
+
"""
|
|
40
|
+
from dask.distributed import Client
|
|
41
|
+
client = await Client(address=address, asynchronous=True)
|
|
42
|
+
|
|
43
|
+
try:
|
|
44
|
+
yield client
|
|
45
|
+
finally:
|
|
46
|
+
await client.close()
|
|
47
|
+
|
|
48
|
+
@contextmanager
|
|
49
|
+
def blocking_client(self, address: str) -> Generator["Client"]:
|
|
50
|
+
"""
|
|
51
|
+
context manager for obtaining a blocking Dask client.
|
|
52
|
+
|
|
53
|
+
Yields
|
|
54
|
+
------
|
|
55
|
+
Client
|
|
56
|
+
A blocking Dask client connected to the scheduler. The client is automatically closed when exiting the
|
|
57
|
+
context manager.
|
|
58
|
+
"""
|
|
59
|
+
from dask.distributed import Client
|
|
60
|
+
client = Client(address=address)
|
|
61
|
+
|
|
62
|
+
try:
|
|
63
|
+
yield client
|
|
64
|
+
finally:
|
|
65
|
+
client.close()
|
|
@@ -14,6 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
|
+
import os
|
|
18
|
+
import sys
|
|
17
19
|
import typing
|
|
18
20
|
from datetime import datetime
|
|
19
21
|
from pathlib import Path
|
|
@@ -31,6 +33,20 @@ logger = logging.getLogger(__name__)
|
|
|
31
33
|
YAML_EXTENSIONS = (".yaml", ".yml")
|
|
32
34
|
|
|
33
35
|
|
|
36
|
+
def _is_reserved(path: Path) -> bool:
|
|
37
|
+
"""
|
|
38
|
+
Check if a path is reserved in the current Python version and platform.
|
|
39
|
+
|
|
40
|
+
On Windows, this function checks if the path is reserved in the current Python version.
|
|
41
|
+
On other platforms, returns False
|
|
42
|
+
"""
|
|
43
|
+
if sys.platform != "win32":
|
|
44
|
+
return False
|
|
45
|
+
if sys.version_info >= (3, 13):
|
|
46
|
+
return os.path.isreserved(path)
|
|
47
|
+
return path.is_reserved()
|
|
48
|
+
|
|
49
|
+
|
|
34
50
|
class EvaluateRequest(BaseModel):
|
|
35
51
|
"""Request model for the evaluate endpoint."""
|
|
36
52
|
config_file: str = Field(description="Path to the configuration file for evaluation")
|
|
@@ -51,7 +67,7 @@ class EvaluateRequest(BaseModel):
|
|
|
51
67
|
f"Job ID '{job_id}' contains invalid characters. Only alphanumeric characters and underscores are"
|
|
52
68
|
" allowed.")
|
|
53
69
|
|
|
54
|
-
if job_id_path
|
|
70
|
+
if _is_reserved(job_id_path):
|
|
55
71
|
# reserved names is Windows specific
|
|
56
72
|
raise ValueError(f"Job ID '{job_id}' is a reserved name. Please choose a different name.")
|
|
57
73
|
|
|
@@ -68,7 +84,7 @@ class EvaluateRequest(BaseModel):
|
|
|
68
84
|
raise ValueError(f"Config file '{config_file}' must be a YAML file with one of the following extensions: "
|
|
69
85
|
f"{', '.join(YAML_EXTENSIONS)}")
|
|
70
86
|
|
|
71
|
-
if config_file_path
|
|
87
|
+
if _is_reserved(config_file_path):
|
|
72
88
|
# reserved names is Windows specific
|
|
73
89
|
raise ValueError(f"Config file '{config_file}' is a reserved name. Please choose a different name.")
|
|
74
90
|
|
|
@@ -181,9 +197,24 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
|
|
|
181
197
|
port: int = Field(default=8000, description="Port to bind the server to", ge=0, le=65535)
|
|
182
198
|
reload: bool = Field(default=False, description="Enable auto-reload for development")
|
|
183
199
|
workers: int = Field(default=1, description="Number of workers to run", ge=1)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
200
|
+
scheduler_address: str | None = Field(
|
|
201
|
+
default=None,
|
|
202
|
+
description=("Address of the Dask scheduler to use for async jobs. If None, a Dask local cluster is created. "
|
|
203
|
+
"Note: This requires the optional dask dependency to be installed."))
|
|
204
|
+
db_url: str | None = Field(
|
|
205
|
+
default=None,
|
|
206
|
+
description=
|
|
207
|
+
"SQLAlchemy database URL for storing async job metadata, if unset a temporary SQLite database is used.")
|
|
208
|
+
max_running_async_jobs: int = Field(
|
|
209
|
+
default=10,
|
|
210
|
+
description=(
|
|
211
|
+
"Maximum number of async jobs to run concurrently, this controls the number of dask workers created. "
|
|
212
|
+
"This parameter is only used when scheduler_address is `None` and a Dask local cluster is created."),
|
|
213
|
+
ge=1)
|
|
214
|
+
dask_log_level: str = Field(
|
|
215
|
+
default="WARNING",
|
|
216
|
+
description="Logging level for Dask.",
|
|
217
|
+
)
|
|
187
218
|
step_adaptor: StepAdaptorConfig = StepAdaptorConfig()
|
|
188
219
|
|
|
189
220
|
workflow: typing.Annotated[EndpointBase, Field(description="Endpoint for the default workflow.")] = EndpointBase(
|
|
@@ -47,11 +47,11 @@ class _FastApiFrontEndController:
|
|
|
47
47
|
self._server_background_task = asyncio.create_task(self._server.serve())
|
|
48
48
|
except asyncio.CancelledError as e:
|
|
49
49
|
error_message = f"Task error occurred while starting API server: {str(e)}"
|
|
50
|
-
logger.error(error_message
|
|
50
|
+
logger.error(error_message)
|
|
51
51
|
raise RuntimeError(error_message) from e
|
|
52
52
|
except Exception as e:
|
|
53
53
|
error_message = f"Unexpected error occurred while starting API server: {str(e)}"
|
|
54
|
-
logger.
|
|
54
|
+
logger.exception(error_message)
|
|
55
55
|
raise RuntimeError(error_message) from e
|
|
56
56
|
|
|
57
57
|
async def stop_server(self) -> None:
|
|
@@ -63,6 +63,6 @@ class _FastApiFrontEndController:
|
|
|
63
63
|
self._server.should_exit = True
|
|
64
64
|
await self._server_background_task
|
|
65
65
|
except asyncio.CancelledError as e:
|
|
66
|
-
logger.
|
|
66
|
+
logger.exception("Server shutdown failed: %s", str(e))
|
|
67
67
|
except Exception as e:
|
|
68
|
-
logger.
|
|
68
|
+
logger.exception("Unexpected error occurred: %s", str(e))
|