nvidia-nat 1.3.0a20250928__py3-none-any.whl → 1.3.0a20250930__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +1 -1
- nat/agent/rewoo_agent/agent.py +298 -118
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +4 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +26 -18
- nat/builder/builder.py +1 -1
- nat/builder/context.py +2 -2
- nat/builder/front_end.py +1 -1
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/mcp/mcp.py +2 -2
- nat/cli/commands/start.py +1 -1
- nat/cli/type_registry.py +1 -1
- nat/control_flow/router_agent/register.py +1 -1
- nat/data_models/api_server.py +9 -9
- nat/data_models/authentication.py +3 -9
- nat/data_models/dataset_handler.py +1 -1
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
- nat/experimental/decorators/experimental_warning_decorator.py +1 -2
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +188 -2
- nat/front_ends/fastapi/job_store.py +2 -2
- nat/front_ends/fastapi/message_handler.py +4 -4
- nat/front_ends/fastapi/message_validator.py +5 -5
- nat/front_ends/mcp/tool_converter.py +1 -1
- nat/llm/utils/thinking.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +2 -2
- nat/observability/processor/batching_processor.py +1 -1
- nat/profiler/decorators/function_tracking.py +2 -2
- nat/profiler/parameter_optimization/parameter_selection.py +3 -4
- nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
- nat/retriever/milvus/retriever.py +1 -1
- nat/settings/global_settings.py +2 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/datetime_tools.py +1 -1
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/exception_handlers/automatic_retries.py +1 -1
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/type_utils.py +1 -1
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/METADATA +2 -1
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/RECORD +52 -52
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/top_level.txt +0 -0
|
@@ -13,10 +13,12 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import logging
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from datetime import UTC
|
|
16
19
|
from datetime import datetime
|
|
17
|
-
from datetime import timezone
|
|
18
|
-
from typing import Callable
|
|
19
20
|
|
|
21
|
+
import httpx
|
|
20
22
|
from authlib.integrations.httpx_client import OAuth2Client as AuthlibOAuth2Client
|
|
21
23
|
from pydantic import SecretStr
|
|
22
24
|
|
|
@@ -28,6 +30,8 @@ from nat.data_models.authentication import AuthFlowType
|
|
|
28
30
|
from nat.data_models.authentication import AuthResult
|
|
29
31
|
from nat.data_models.authentication import BearerTokenCred
|
|
30
32
|
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
31
35
|
|
|
32
36
|
class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]):
|
|
33
37
|
|
|
@@ -41,26 +45,30 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
|
|
|
41
45
|
if not isinstance(refresh_token, str):
|
|
42
46
|
return None
|
|
43
47
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
48
|
+
try:
|
|
49
|
+
with AuthlibOAuth2Client(
|
|
50
|
+
client_id=self.config.client_id,
|
|
51
|
+
client_secret=self.config.client_secret,
|
|
52
|
+
) as client:
|
|
49
53
|
new_token_data = client.refresh_token(self.config.token_url, refresh_token=refresh_token)
|
|
50
|
-
except Exception:
|
|
51
|
-
# On any failure, we'll fall back to the full auth flow.
|
|
52
|
-
return None
|
|
53
54
|
|
|
54
|
-
|
|
55
|
-
|
|
55
|
+
expires_at_ts = new_token_data.get("expires_at")
|
|
56
|
+
new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=UTC) if expires_at_ts else None
|
|
56
57
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
58
|
+
new_auth_result = AuthResult(
|
|
59
|
+
credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
|
|
60
|
+
token_expires_at=new_expires_at,
|
|
61
|
+
raw=new_token_data,
|
|
62
|
+
)
|
|
62
63
|
|
|
63
|
-
|
|
64
|
+
self._authenticated_tokens[user_id] = new_auth_result
|
|
65
|
+
except httpx.HTTPStatusError:
|
|
66
|
+
return None
|
|
67
|
+
except httpx.RequestError:
|
|
68
|
+
return None
|
|
69
|
+
except Exception:
|
|
70
|
+
# On any other failure, we'll fall back to the full auth flow.
|
|
71
|
+
return None
|
|
64
72
|
|
|
65
73
|
return new_auth_result
|
|
66
74
|
|
nat/builder/builder.py
CHANGED
nat/builder/context.py
CHANGED
|
@@ -40,12 +40,12 @@ from nat.utils.reactive.subject import Subject
|
|
|
40
40
|
class Singleton(type):
|
|
41
41
|
|
|
42
42
|
def __init__(cls, name, bases, dict):
|
|
43
|
-
super(
|
|
43
|
+
super().__init__(name, bases, dict)
|
|
44
44
|
cls.instance = None
|
|
45
45
|
|
|
46
46
|
def __call__(cls, *args, **kw):
|
|
47
47
|
if cls.instance is None:
|
|
48
|
-
cls.instance = super(
|
|
48
|
+
cls.instance = super().__call__(*args, **kw)
|
|
49
49
|
return cls.instance
|
|
50
50
|
|
|
51
51
|
|
nat/builder/front_end.py
CHANGED
|
@@ -37,7 +37,7 @@ class FrontEndBase(typing.Generic[FrontEndConfigT], ABC):
|
|
|
37
37
|
|
|
38
38
|
super().__init__()
|
|
39
39
|
|
|
40
|
-
self._full_config:
|
|
40
|
+
self._full_config: Config = full_config
|
|
41
41
|
self._front_end_config: FrontEndConfigT = typing.cast(FrontEndConfigT, full_config.general.front_end)
|
|
42
42
|
|
|
43
43
|
@property
|
|
@@ -84,7 +84,7 @@ class LayeredConfig:
|
|
|
84
84
|
if lower_value not in ['true', 'false']:
|
|
85
85
|
raise ValueError(f"Boolean value must be 'true' or 'false', got '{value}'")
|
|
86
86
|
value = lower_value == 'true'
|
|
87
|
-
elif isinstance(original_value,
|
|
87
|
+
elif isinstance(original_value, int | float):
|
|
88
88
|
value = type(original_value)(value)
|
|
89
89
|
elif isinstance(original_value, list):
|
|
90
90
|
value = [v.strip() for v in value.split(',')]
|
nat/cli/commands/mcp/mcp.py
CHANGED
|
@@ -297,7 +297,7 @@ async def list_tools_via_function_group(
|
|
|
297
297
|
if fn is not None:
|
|
298
298
|
tools.append(to_tool_entry(full, fn))
|
|
299
299
|
else:
|
|
300
|
-
for full, fn in fns.items():
|
|
300
|
+
for full, fn in (await fns).items():
|
|
301
301
|
tools.append(to_tool_entry(full, fn))
|
|
302
302
|
|
|
303
303
|
return tools
|
|
@@ -443,7 +443,7 @@ async def ping_mcp_server(url: str,
|
|
|
443
443
|
# Apply timeout to the entire ping operation
|
|
444
444
|
return await asyncio.wait_for(_ping_operation(), timeout=timeout)
|
|
445
445
|
|
|
446
|
-
except
|
|
446
|
+
except TimeoutError:
|
|
447
447
|
return MCPPingResult(url=url,
|
|
448
448
|
status="unhealthy",
|
|
449
449
|
response_time_ms=None,
|
nat/cli/commands/start.py
CHANGED
|
@@ -111,7 +111,7 @@ class StartCommandGroup(click.Group):
|
|
|
111
111
|
elif (issubclass(decomposed_type.root, Path)):
|
|
112
112
|
param_type = click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path)
|
|
113
113
|
|
|
114
|
-
elif (issubclass(decomposed_type.root,
|
|
114
|
+
elif (issubclass(decomposed_type.root, list | tuple | set)):
|
|
115
115
|
if (len(decomposed_type.args) == 1):
|
|
116
116
|
inner = DecomposedType(decomposed_type.args[0])
|
|
117
117
|
# Support containers of Literal values -> multiple Choice
|
nat/cli/type_registry.py
CHANGED
|
@@ -992,7 +992,7 @@ class TypeRegistry:
|
|
|
992
992
|
if (short_names[key.local_name] == 1):
|
|
993
993
|
type_list.append((key.local_name, key.config_type))
|
|
994
994
|
|
|
995
|
-
return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
995
|
+
return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
996
996
|
|
|
997
997
|
def compute_annotation(self, cls: type[TypedBaseModelT]):
|
|
998
998
|
|
|
@@ -81,7 +81,7 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
|
|
|
81
81
|
logger.exception("%s Router Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
|
|
82
82
|
if config.verbose:
|
|
83
83
|
return str(ex)
|
|
84
|
-
return "Router agent failed with exception:
|
|
84
|
+
return f"Router agent failed with exception: {ex}"
|
|
85
85
|
|
|
86
86
|
try:
|
|
87
87
|
yield FunctionInfo.from_fn(_response_fn, description=config.description)
|
nat/data_models/api_server.py
CHANGED
|
@@ -273,7 +273,7 @@ class ChatResponse(ResponseBaseModelOutput):
|
|
|
273
273
|
if model is None:
|
|
274
274
|
model = ""
|
|
275
275
|
if created is None:
|
|
276
|
-
created = datetime.datetime.now(datetime.
|
|
276
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
277
277
|
|
|
278
278
|
return ChatResponse(id=id_,
|
|
279
279
|
object=object_,
|
|
@@ -317,7 +317,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
317
317
|
if id_ is None:
|
|
318
318
|
id_ = str(uuid.uuid4())
|
|
319
319
|
if created is None:
|
|
320
|
-
created = datetime.datetime.now(datetime.
|
|
320
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
321
321
|
if model is None:
|
|
322
322
|
model = ""
|
|
323
323
|
if object_ is None:
|
|
@@ -343,7 +343,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
343
343
|
if id_ is None:
|
|
344
344
|
id_ = str(uuid.uuid4())
|
|
345
345
|
if created is None:
|
|
346
|
-
created = datetime.datetime.now(datetime.
|
|
346
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
347
347
|
if model is None:
|
|
348
348
|
model = ""
|
|
349
349
|
|
|
@@ -485,7 +485,7 @@ class WebSocketUserMessage(BaseModel):
|
|
|
485
485
|
security: Security = Security()
|
|
486
486
|
error: Error = Error()
|
|
487
487
|
schema_version: str = "1.0.0"
|
|
488
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
488
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
489
489
|
|
|
490
490
|
|
|
491
491
|
class WebSocketUserInteractionResponseMessage(BaseModel):
|
|
@@ -501,7 +501,7 @@ class WebSocketUserInteractionResponseMessage(BaseModel):
|
|
|
501
501
|
security: Security = Security()
|
|
502
502
|
error: Error = Error()
|
|
503
503
|
schema_version: str = "1.0.0"
|
|
504
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
504
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
505
505
|
|
|
506
506
|
|
|
507
507
|
class SystemIntermediateStepContent(BaseModel):
|
|
@@ -527,7 +527,7 @@ class WebSocketSystemIntermediateStepMessage(BaseModel):
|
|
|
527
527
|
conversation_id: str | None = None
|
|
528
528
|
content: SystemIntermediateStepContent
|
|
529
529
|
status: WebSocketMessageStatus
|
|
530
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
530
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
531
531
|
|
|
532
532
|
|
|
533
533
|
class SystemResponseContent(BaseModel):
|
|
@@ -551,7 +551,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
|
|
|
551
551
|
conversation_id: str | None = None
|
|
552
552
|
content: SystemResponseContent | Error | GenerateResponse
|
|
553
553
|
status: WebSocketMessageStatus
|
|
554
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
554
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
555
555
|
|
|
556
556
|
@field_validator("content")
|
|
557
557
|
@classmethod
|
|
@@ -560,7 +560,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
|
|
|
560
560
|
raise ValueError(f"Field: content must be 'Error' when type is {WebSocketMessageType.ERROR_MESSAGE}")
|
|
561
561
|
|
|
562
562
|
if info.data.get("type") == WebSocketMessageType.RESPONSE_MESSAGE and not isinstance(
|
|
563
|
-
value,
|
|
563
|
+
value, SystemResponseContent | GenerateResponse):
|
|
564
564
|
raise ValueError(
|
|
565
565
|
f"Field: content must be 'SystemResponseContent' when type is {WebSocketMessageType.RESPONSE_MESSAGE}")
|
|
566
566
|
return value
|
|
@@ -582,7 +582,7 @@ class WebSocketSystemInteractionMessage(BaseModel):
|
|
|
582
582
|
conversation_id: str | None = None
|
|
583
583
|
content: HumanPrompt
|
|
584
584
|
status: WebSocketMessageStatus
|
|
585
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
585
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
586
586
|
|
|
587
587
|
|
|
588
588
|
# ======== GenerateResponse Converters ========
|
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import typing
|
|
17
|
+
from datetime import UTC
|
|
17
18
|
from datetime import datetime
|
|
18
|
-
from datetime import timezone
|
|
19
19
|
from enum import Enum
|
|
20
20
|
|
|
21
21
|
import httpx
|
|
@@ -166,13 +166,7 @@ class BearerTokenCred(_CredBase):
|
|
|
166
166
|
|
|
167
167
|
|
|
168
168
|
Credential = typing.Annotated[
|
|
169
|
-
|
|
170
|
-
HeaderCred,
|
|
171
|
-
QueryCred,
|
|
172
|
-
CookieCred,
|
|
173
|
-
BasicAuthCred,
|
|
174
|
-
BearerTokenCred,
|
|
175
|
-
],
|
|
169
|
+
HeaderCred | QueryCred | CookieCred | BasicAuthCred | BearerTokenCred,
|
|
176
170
|
Field(discriminator="kind"),
|
|
177
171
|
]
|
|
178
172
|
|
|
@@ -213,7 +207,7 @@ class AuthResult(BaseModel):
|
|
|
213
207
|
"""
|
|
214
208
|
Checks if the authentication token has expired.
|
|
215
209
|
"""
|
|
216
|
-
return bool(self.token_expires_at and datetime.now(
|
|
210
|
+
return bool(self.token_expires_at and datetime.now(UTC) >= self.token_expires_at)
|
|
217
211
|
|
|
218
212
|
def as_requests_kwargs(self) -> dict[str, typing.Any]:
|
|
219
213
|
"""
|
|
@@ -80,7 +80,7 @@ class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):
|
|
|
80
80
|
|
|
81
81
|
|
|
82
82
|
def read_jsonl(file_path: FilePath):
|
|
83
|
-
with open(file_path,
|
|
83
|
+
with open(file_path, encoding='utf-8') as f:
|
|
84
84
|
data = [json.loads(line) for line in f]
|
|
85
85
|
return pd.DataFrame(data)
|
|
86
86
|
|
|
@@ -71,7 +71,7 @@ class BaseEvaluator(ABC):
|
|
|
71
71
|
TqdmPositionRegistry.release(tqdm_position)
|
|
72
72
|
|
|
73
73
|
# Compute average if possible
|
|
74
|
-
numeric_scores = [item.score for item in output_items if isinstance(item.score,
|
|
74
|
+
numeric_scores = [item.score for item in output_items if isinstance(item.score, int | float)]
|
|
75
75
|
avg_score = round(sum(numeric_scores) / len(numeric_scores), 2) if numeric_scores else None
|
|
76
76
|
|
|
77
77
|
return EvalOutput(average_score=avg_score, eval_output_items=output_items)
|
|
@@ -204,7 +204,7 @@ class SweBenchEvaluator:
|
|
|
204
204
|
# if report file is not present, return empty EvalOutput
|
|
205
205
|
avg_score = 0.0
|
|
206
206
|
if report_file.exists():
|
|
207
|
-
with open(report_file,
|
|
207
|
+
with open(report_file, encoding="utf-8") as f:
|
|
208
208
|
report = json.load(f)
|
|
209
209
|
resolved_instances = report.get("resolved_instances", 0)
|
|
210
210
|
total_instances = report.get("total_instances", 0)
|
|
@@ -137,8 +137,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
|
|
|
137
137
|
@functools.wraps(func)
|
|
138
138
|
def sync_gen_wrapper(*args, **kwargs) -> Generator[Any, Any, Any]:
|
|
139
139
|
issue_experimental_warning(function_name, feature_name, metadata)
|
|
140
|
-
|
|
141
|
-
yield item # yield the original item
|
|
140
|
+
yield from func(*args, **kwargs) # yield the original item
|
|
142
141
|
|
|
143
142
|
return sync_gen_wrapper
|
|
144
143
|
|
|
@@ -71,7 +71,7 @@ class LLMBasedOutputMergingSelector(StrategyBase):
|
|
|
71
71
|
raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
|
|
72
72
|
"This error can be resolved by installing nvidia-nat-langchain.")
|
|
73
73
|
|
|
74
|
-
from
|
|
74
|
+
from collections.abc import Callable
|
|
75
75
|
|
|
76
76
|
from pydantic import BaseModel
|
|
77
77
|
|
|
@@ -14,13 +14,16 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import asyncio
|
|
17
|
+
import logging
|
|
17
18
|
import secrets
|
|
18
19
|
import webbrowser
|
|
19
20
|
from dataclasses import dataclass
|
|
20
21
|
from dataclasses import field
|
|
21
22
|
|
|
22
23
|
import click
|
|
24
|
+
import httpx
|
|
23
25
|
import pkce
|
|
26
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
24
27
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
25
28
|
from fastapi import FastAPI
|
|
26
29
|
from fastapi import Request
|
|
@@ -32,6 +35,8 @@ from nat.data_models.authentication import AuthFlowType
|
|
|
32
35
|
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
33
36
|
from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
|
|
34
37
|
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
35
40
|
|
|
36
41
|
# --------------------------------------------------------------------------- #
|
|
37
42
|
# Helpers #
|
|
@@ -87,17 +92,53 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
87
92
|
"""
|
|
88
93
|
Separated for easy overriding in tests (to inject ASGITransport).
|
|
89
94
|
"""
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
95
|
+
try:
|
|
96
|
+
client = AsyncOAuth2Client(
|
|
97
|
+
client_id=cfg.client_id,
|
|
98
|
+
client_secret=cfg.client_secret,
|
|
99
|
+
redirect_uri=cfg.redirect_uri,
|
|
100
|
+
scope=" ".join(cfg.scopes) if cfg.scopes else None,
|
|
101
|
+
token_endpoint=cfg.token_url,
|
|
102
|
+
token_endpoint_auth_method=cfg.token_endpoint_auth_method,
|
|
103
|
+
code_challenge_method="S256" if cfg.use_pkce else None,
|
|
104
|
+
)
|
|
105
|
+
self._oauth_client = client
|
|
106
|
+
return client
|
|
107
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
108
|
+
raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
|
|
109
|
+
except Exception as e:
|
|
110
|
+
raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
|
|
111
|
+
|
|
112
|
+
def _create_authorization_url(self,
|
|
113
|
+
client: AsyncOAuth2Client,
|
|
114
|
+
config: OAuth2AuthCodeFlowProviderConfig,
|
|
115
|
+
state: str,
|
|
116
|
+
verifier: str | None = None,
|
|
117
|
+
challenge: str | None = None) -> str:
|
|
118
|
+
"""
|
|
119
|
+
Create OAuth authorization URL with proper error handling.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
client: The OAuth2 client instance
|
|
123
|
+
config: OAuth2 configuration
|
|
124
|
+
state: OAuth state parameter
|
|
125
|
+
verifier: PKCE verifier (if using PKCE)
|
|
126
|
+
challenge: PKCE challenge (if using PKCE)
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
The authorization URL
|
|
130
|
+
"""
|
|
131
|
+
try:
|
|
132
|
+
auth_url, _ = client.create_authorization_url(
|
|
133
|
+
config.authorization_url,
|
|
134
|
+
state=state,
|
|
135
|
+
code_verifier=verifier if config.use_pkce else None,
|
|
136
|
+
code_challenge=challenge if config.use_pkce else None,
|
|
137
|
+
**(config.authorization_kwargs or {})
|
|
138
|
+
)
|
|
139
|
+
return auth_url
|
|
140
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
141
|
+
raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
|
|
101
142
|
|
|
102
143
|
# --------------------------- HTTP Basic ------------------------------ #
|
|
103
144
|
@staticmethod
|
|
@@ -131,13 +172,12 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
131
172
|
flow_state.verifier = verifier
|
|
132
173
|
flow_state.challenge = challenge
|
|
133
174
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
)
|
|
175
|
+
# Create authorization URL using helper function
|
|
176
|
+
auth_url = self._create_authorization_url(client=client,
|
|
177
|
+
config=cfg,
|
|
178
|
+
state=state,
|
|
179
|
+
verifier=flow_state.verifier,
|
|
180
|
+
challenge=flow_state.challenge)
|
|
141
181
|
|
|
142
182
|
# Register flow + maybe spin up redirect handler
|
|
143
183
|
async with self._server_lock:
|
|
@@ -149,14 +189,18 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
149
189
|
self._flows[state] = flow_state
|
|
150
190
|
self._active_flows += 1
|
|
151
191
|
|
|
152
|
-
|
|
153
|
-
|
|
192
|
+
try:
|
|
193
|
+
webbrowser.open(auth_url)
|
|
194
|
+
click.echo("Your browser has been opened for authentication.")
|
|
195
|
+
except Exception as e:
|
|
196
|
+
logger.error("Browser open failed: %s", e)
|
|
197
|
+
raise RuntimeError(f"Browser open failed: {e}") from e
|
|
154
198
|
|
|
155
199
|
# Wait for the redirect to land
|
|
156
200
|
try:
|
|
157
201
|
token = await asyncio.wait_for(flow_state.future, timeout=300)
|
|
158
|
-
except
|
|
159
|
-
raise RuntimeError("Authentication timed out (5 min).")
|
|
202
|
+
except TimeoutError as exc:
|
|
203
|
+
raise RuntimeError("Authentication timed out (5 min).") from exc
|
|
160
204
|
finally:
|
|
161
205
|
async with self._server_lock:
|
|
162
206
|
self._flows.pop(state, None)
|
|
@@ -175,9 +219,9 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
175
219
|
# --------------- redirect server / in‑process app -------------------- #
|
|
176
220
|
async def _build_redirect_app(self) -> FastAPI:
|
|
177
221
|
"""
|
|
178
|
-
* If cfg.run_redirect_local_server == True → start a
|
|
179
|
-
* Else → only build the
|
|
180
|
-
for in‑process testing
|
|
222
|
+
* If cfg.run_redirect_local_server == True → start a local server.
|
|
223
|
+
* Else → only build the redirect app and save it to `self._redirect_app`
|
|
224
|
+
for in‑process testing.
|
|
181
225
|
"""
|
|
182
226
|
app = FastAPI()
|
|
183
227
|
|
|
@@ -195,8 +239,16 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
195
239
|
state=state,
|
|
196
240
|
)
|
|
197
241
|
flow_state.future.set_result(token)
|
|
198
|
-
except
|
|
199
|
-
flow_state.future.set_exception(
|
|
242
|
+
except OAuthError as e:
|
|
243
|
+
flow_state.future.set_exception(
|
|
244
|
+
RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
|
|
245
|
+
return "Authentication failed: Authorization server rejected the request. You may close this tab."
|
|
246
|
+
except httpx.HTTPError as e:
|
|
247
|
+
flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
|
|
248
|
+
return "Authentication failed: Network error occurred. You may close this tab."
|
|
249
|
+
except Exception as e:
|
|
250
|
+
flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
|
|
251
|
+
return "Authentication failed: An unexpected error occurred. You may close this tab."
|
|
200
252
|
return "Authentication successful – you may close this tab."
|
|
201
253
|
|
|
202
254
|
return app
|
|
@@ -213,7 +265,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
213
265
|
|
|
214
266
|
asyncio.create_task(self._server_controller.start_server(host="localhost", port=8000))
|
|
215
267
|
|
|
216
|
-
# Give
|
|
268
|
+
# Give the server a moment to bind sockets before we return
|
|
217
269
|
await asyncio.sleep(0.3)
|
|
218
270
|
except Exception as exc: # noqa: BLE001
|
|
219
271
|
raise RuntimeError(f"Failed to start redirect server: {exc}") from exc
|
|
@@ -227,7 +279,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
227
279
|
@property
|
|
228
280
|
def redirect_app(self) -> FastAPI | None:
|
|
229
281
|
"""
|
|
230
|
-
In
|
|
231
|
-
app is exposed
|
|
282
|
+
In test mode (run_redirect_local_server=False) the in‑memory redirect
|
|
283
|
+
app is exposed for testing purposes.
|
|
232
284
|
"""
|
|
233
285
|
return self._redirect_app
|
|
@@ -88,7 +88,7 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
|
|
|
88
88
|
elif (self.front_end_config.input_file):
|
|
89
89
|
|
|
90
90
|
# Run the workflow
|
|
91
|
-
with open(self.front_end_config.input_file,
|
|
91
|
+
with open(self.front_end_config.input_file, encoding="utf-8") as f:
|
|
92
92
|
|
|
93
93
|
async with session_manager.workflow.run(f) as runner:
|
|
94
94
|
runner_outputs = await runner.result(to_type=str)
|
|
@@ -22,6 +22,7 @@ from dataclasses import dataclass
|
|
|
22
22
|
from dataclasses import field
|
|
23
23
|
|
|
24
24
|
import pkce
|
|
25
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
25
26
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
26
27
|
|
|
27
28
|
from nat.authentication.interfaces import FlowHandlerBase
|
|
@@ -61,14 +62,50 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
61
62
|
|
|
62
63
|
raise NotImplementedError(f"Authentication method '{method}' is not supported by the websocket frontend.")
|
|
63
64
|
|
|
64
|
-
def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig):
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
65
|
+
def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:
|
|
66
|
+
try:
|
|
67
|
+
return AsyncOAuth2Client(client_id=config.client_id,
|
|
68
|
+
client_secret=config.client_secret,
|
|
69
|
+
redirect_uri=config.redirect_uri,
|
|
70
|
+
scope=" ".join(config.scopes) if config.scopes else None,
|
|
71
|
+
token_endpoint=config.token_url,
|
|
72
|
+
code_challenge_method='S256' if config.use_pkce else None,
|
|
73
|
+
token_endpoint_auth_method=config.token_endpoint_auth_method)
|
|
74
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
75
|
+
raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
|
|
76
|
+
except Exception as e:
|
|
77
|
+
raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
|
|
78
|
+
|
|
79
|
+
def _create_authorization_url(self,
|
|
80
|
+
client: AsyncOAuth2Client,
|
|
81
|
+
config: OAuth2AuthCodeFlowProviderConfig,
|
|
82
|
+
state: str,
|
|
83
|
+
verifier: str = None,
|
|
84
|
+
challenge: str = None) -> str:
|
|
85
|
+
"""
|
|
86
|
+
Create OAuth authorization URL with proper error handling.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
client: The OAuth2 client instance
|
|
90
|
+
config: OAuth2 configuration
|
|
91
|
+
state: OAuth state parameter
|
|
92
|
+
verifier: PKCE verifier (if using PKCE)
|
|
93
|
+
challenge: PKCE challenge (if using PKCE)
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
The authorization URL
|
|
97
|
+
"""
|
|
98
|
+
try:
|
|
99
|
+
authorization_url, _ = client.create_authorization_url(
|
|
100
|
+
config.authorization_url,
|
|
101
|
+
state=state,
|
|
102
|
+
code_verifier=verifier if config.use_pkce else None,
|
|
103
|
+
code_challenge=challenge if config.use_pkce else None,
|
|
104
|
+
**(config.authorization_kwargs or {})
|
|
105
|
+
)
|
|
106
|
+
return authorization_url
|
|
107
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
108
|
+
raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
|
|
72
109
|
|
|
73
110
|
async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
|
|
74
111
|
|
|
@@ -82,21 +119,19 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
82
119
|
flow_state.verifier = verifier
|
|
83
120
|
flow_state.challenge = challenge
|
|
84
121
|
|
|
85
|
-
authorization_url
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
**(config.authorization_kwargs or {})
|
|
91
|
-
)
|
|
122
|
+
authorization_url = self._create_authorization_url(client=flow_state.client,
|
|
123
|
+
config=config,
|
|
124
|
+
state=state,
|
|
125
|
+
verifier=flow_state.verifier,
|
|
126
|
+
challenge=flow_state.challenge)
|
|
92
127
|
|
|
93
128
|
await self._add_flow_cb(state, flow_state)
|
|
94
129
|
await self._web_socket_message_handler.create_websocket_message(_HumanPromptOAuthConsent(text=authorization_url)
|
|
95
130
|
)
|
|
96
131
|
try:
|
|
97
132
|
token = await asyncio.wait_for(flow_state.future, timeout=300)
|
|
98
|
-
except
|
|
99
|
-
raise RuntimeError("Authentication flow timed out after 5 minutes.")
|
|
133
|
+
except TimeoutError as exc:
|
|
134
|
+
raise RuntimeError("Authentication flow timed out after 5 minutes.") from exc
|
|
100
135
|
finally:
|
|
101
136
|
|
|
102
137
|
await self._remove_flow_cb(state)
|