nvidia-nat 1.3.0a20250928__py3-none-any.whl → 1.3.0a20250930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. nat/agent/base.py +1 -1
  2. nat/agent/rewoo_agent/agent.py +298 -118
  3. nat/agent/rewoo_agent/prompt.py +19 -22
  4. nat/agent/rewoo_agent/register.py +4 -1
  5. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +26 -18
  6. nat/builder/builder.py +1 -1
  7. nat/builder/context.py +2 -2
  8. nat/builder/front_end.py +1 -1
  9. nat/cli/cli_utils/config_override.py +1 -1
  10. nat/cli/commands/mcp/mcp.py +2 -2
  11. nat/cli/commands/start.py +1 -1
  12. nat/cli/type_registry.py +1 -1
  13. nat/control_flow/router_agent/register.py +1 -1
  14. nat/data_models/api_server.py +9 -9
  15. nat/data_models/authentication.py +3 -9
  16. nat/data_models/dataset_handler.py +1 -1
  17. nat/eval/evaluator/base_evaluator.py +1 -1
  18. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  19. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  20. nat/experimental/decorators/experimental_warning_decorator.py +1 -2
  21. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  22. nat/front_ends/console/authentication_flow_handler.py +82 -30
  23. nat/front_ends/console/console_front_end_plugin.py +1 -1
  24. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  25. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +188 -2
  26. nat/front_ends/fastapi/job_store.py +2 -2
  27. nat/front_ends/fastapi/message_handler.py +4 -4
  28. nat/front_ends/fastapi/message_validator.py +5 -5
  29. nat/front_ends/mcp/tool_converter.py +1 -1
  30. nat/llm/utils/thinking.py +1 -1
  31. nat/observability/exporter/base_exporter.py +1 -1
  32. nat/observability/exporter/span_exporter.py +1 -1
  33. nat/observability/exporter_manager.py +2 -2
  34. nat/observability/processor/batching_processor.py +1 -1
  35. nat/profiler/decorators/function_tracking.py +2 -2
  36. nat/profiler/parameter_optimization/parameter_selection.py +3 -4
  37. nat/profiler/parameter_optimization/pareto_visualizer.py +1 -1
  38. nat/retriever/milvus/retriever.py +1 -1
  39. nat/settings/global_settings.py +2 -2
  40. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  41. nat/tool/datetime_tools.py +1 -1
  42. nat/utils/data_models/schema_validator.py +1 -1
  43. nat/utils/exception_handlers/automatic_retries.py +1 -1
  44. nat/utils/io/yaml_tools.py +1 -1
  45. nat/utils/type_utils.py +1 -1
  46. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/METADATA +2 -1
  47. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/RECORD +52 -52
  48. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/WHEEL +0 -0
  49. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/entry_points.txt +0 -0
  50. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  51. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/licenses/LICENSE.md +0 -0
  52. {nvidia_nat-1.3.0a20250928.dist-info → nvidia_nat-1.3.0a20250930.dist-info}/top_level.txt +0 -0
nat/authentication/oauth2/oauth2_auth_code_flow_provider.py CHANGED
@@ -13,10 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import logging
17
+ from collections.abc import Callable
18
+ from datetime import UTC
16
19
  from datetime import datetime
17
- from datetime import timezone
18
- from typing import Callable
19
20
 
21
+ import httpx
20
22
  from authlib.integrations.httpx_client import OAuth2Client as AuthlibOAuth2Client
21
23
  from pydantic import SecretStr
22
24
 
@@ -28,6 +30,8 @@ from nat.data_models.authentication import AuthFlowType
28
30
  from nat.data_models.authentication import AuthResult
29
31
  from nat.data_models.authentication import BearerTokenCred
30
32
 
33
+ logger = logging.getLogger(__name__)
34
+
31
35
 
32
36
  class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]):
33
37
 
@@ -41,26 +45,30 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
41
45
  if not isinstance(refresh_token, str):
42
46
  return None
43
47
 
44
- with AuthlibOAuth2Client(
45
- client_id=self.config.client_id,
46
- client_secret=self.config.client_secret,
47
- ) as client:
48
- try:
48
+ try:
49
+ with AuthlibOAuth2Client(
50
+ client_id=self.config.client_id,
51
+ client_secret=self.config.client_secret,
52
+ ) as client:
49
53
  new_token_data = client.refresh_token(self.config.token_url, refresh_token=refresh_token)
50
- except Exception:
51
- # On any failure, we'll fall back to the full auth flow.
52
- return None
53
54
 
54
- expires_at_ts = new_token_data.get("expires_at")
55
- new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=timezone.utc) if expires_at_ts else None
55
+ expires_at_ts = new_token_data.get("expires_at")
56
+ new_expires_at = datetime.fromtimestamp(expires_at_ts, tz=UTC) if expires_at_ts else None
56
57
 
57
- new_auth_result = AuthResult(
58
- credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
59
- token_expires_at=new_expires_at,
60
- raw=new_token_data,
61
- )
58
+ new_auth_result = AuthResult(
59
+ credentials=[BearerTokenCred(token=SecretStr(new_token_data["access_token"]))],
60
+ token_expires_at=new_expires_at,
61
+ raw=new_token_data,
62
+ )
62
63
 
63
- self._authenticated_tokens[user_id] = new_auth_result
64
+ self._authenticated_tokens[user_id] = new_auth_result
65
+ except httpx.HTTPStatusError:
66
+ return None
67
+ except httpx.RequestError:
68
+ return None
69
+ except Exception:
70
+ # On any other failure, we'll fall back to the full auth flow.
71
+ return None
64
72
 
65
73
  return new_auth_result
66
74
 
nat/builder/builder.py CHANGED
@@ -56,7 +56,7 @@ if typing.TYPE_CHECKING:
56
56
  from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
57
57
 
58
58
 
59
- class UserManagerHolder():
59
+ class UserManagerHolder:
60
60
 
61
61
  def __init__(self, context: Context) -> None:
62
62
  self._context = context
nat/builder/context.py CHANGED
@@ -40,12 +40,12 @@ from nat.utils.reactive.subject import Subject
40
40
  class Singleton(type):
41
41
 
42
42
  def __init__(cls, name, bases, dict):
43
- super(Singleton, cls).__init__(name, bases, dict)
43
+ super().__init__(name, bases, dict)
44
44
  cls.instance = None
45
45
 
46
46
  def __call__(cls, *args, **kw):
47
47
  if cls.instance is None:
48
- cls.instance = super(Singleton, cls).__call__(*args, **kw)
48
+ cls.instance = super().__call__(*args, **kw)
49
49
  return cls.instance
50
50
 
51
51
 
nat/builder/front_end.py CHANGED
@@ -37,7 +37,7 @@ class FrontEndBase(typing.Generic[FrontEndConfigT], ABC):
37
37
 
38
38
  super().__init__()
39
39
 
40
- self._full_config: "Config" = full_config
40
+ self._full_config: Config = full_config
41
41
  self._front_end_config: FrontEndConfigT = typing.cast(FrontEndConfigT, full_config.general.front_end)
42
42
 
43
43
  @property
nat/cli/cli_utils/config_override.py CHANGED
@@ -84,7 +84,7 @@ class LayeredConfig:
84
84
  if lower_value not in ['true', 'false']:
85
85
  raise ValueError(f"Boolean value must be 'true' or 'false', got '{value}'")
86
86
  value = lower_value == 'true'
87
- elif isinstance(original_value, (int, float)):
87
+ elif isinstance(original_value, int | float):
88
88
  value = type(original_value)(value)
89
89
  elif isinstance(original_value, list):
90
90
  value = [v.strip() for v in value.split(',')]
nat/cli/commands/mcp/mcp.py CHANGED
@@ -297,7 +297,7 @@ async def list_tools_via_function_group(
297
297
  if fn is not None:
298
298
  tools.append(to_tool_entry(full, fn))
299
299
  else:
300
- for full, fn in fns.items():
300
+ for full, fn in (await fns).items():
301
301
  tools.append(to_tool_entry(full, fn))
302
302
 
303
303
  return tools
@@ -443,7 +443,7 @@ async def ping_mcp_server(url: str,
443
443
  # Apply timeout to the entire ping operation
444
444
  return await asyncio.wait_for(_ping_operation(), timeout=timeout)
445
445
 
446
- except asyncio.TimeoutError:
446
+ except TimeoutError:
447
447
  return MCPPingResult(url=url,
448
448
  status="unhealthy",
449
449
  response_time_ms=None,
nat/cli/commands/start.py CHANGED
@@ -111,7 +111,7 @@ class StartCommandGroup(click.Group):
111
111
  elif (issubclass(decomposed_type.root, Path)):
112
112
  param_type = click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path)
113
113
 
114
- elif (issubclass(decomposed_type.root, (list, tuple, set))):
114
+ elif (issubclass(decomposed_type.root, list | tuple | set)):
115
115
  if (len(decomposed_type.args) == 1):
116
116
  inner = DecomposedType(decomposed_type.args[0])
117
117
  # Support containers of Literal values -> multiple Choice
nat/cli/type_registry.py CHANGED
@@ -992,7 +992,7 @@ class TypeRegistry:
992
992
  if (short_names[key.local_name] == 1):
993
993
  type_list.append((key.local_name, key.config_type))
994
994
 
995
- return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
995
+ return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
996
996
 
997
997
  def compute_annotation(self, cls: type[TypedBaseModelT]):
998
998
 
nat/control_flow/router_agent/register.py CHANGED
@@ -81,7 +81,7 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
81
81
  logger.exception("%s Router Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
82
82
  if config.verbose:
83
83
  return str(ex)
84
- return "Router agent failed with exception: %s" % ex
84
+ return f"Router agent failed with exception: {ex}"
85
85
 
86
86
  try:
87
87
  yield FunctionInfo.from_fn(_response_fn, description=config.description)
nat/data_models/api_server.py CHANGED
@@ -273,7 +273,7 @@ class ChatResponse(ResponseBaseModelOutput):
273
273
  if model is None:
274
274
  model = ""
275
275
  if created is None:
276
- created = datetime.datetime.now(datetime.timezone.utc)
276
+ created = datetime.datetime.now(datetime.UTC)
277
277
 
278
278
  return ChatResponse(id=id_,
279
279
  object=object_,
@@ -317,7 +317,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
317
317
  if id_ is None:
318
318
  id_ = str(uuid.uuid4())
319
319
  if created is None:
320
- created = datetime.datetime.now(datetime.timezone.utc)
320
+ created = datetime.datetime.now(datetime.UTC)
321
321
  if model is None:
322
322
  model = ""
323
323
  if object_ is None:
@@ -343,7 +343,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
343
343
  if id_ is None:
344
344
  id_ = str(uuid.uuid4())
345
345
  if created is None:
346
- created = datetime.datetime.now(datetime.timezone.utc)
346
+ created = datetime.datetime.now(datetime.UTC)
347
347
  if model is None:
348
348
  model = ""
349
349
 
@@ -485,7 +485,7 @@ class WebSocketUserMessage(BaseModel):
485
485
  security: Security = Security()
486
486
  error: Error = Error()
487
487
  schema_version: str = "1.0.0"
488
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
488
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
489
489
 
490
490
 
491
491
  class WebSocketUserInteractionResponseMessage(BaseModel):
@@ -501,7 +501,7 @@ class WebSocketUserInteractionResponseMessage(BaseModel):
501
501
  security: Security = Security()
502
502
  error: Error = Error()
503
503
  schema_version: str = "1.0.0"
504
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
504
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
505
505
 
506
506
 
507
507
  class SystemIntermediateStepContent(BaseModel):
@@ -527,7 +527,7 @@ class WebSocketSystemIntermediateStepMessage(BaseModel):
527
527
  conversation_id: str | None = None
528
528
  content: SystemIntermediateStepContent
529
529
  status: WebSocketMessageStatus
530
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
530
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
531
531
 
532
532
 
533
533
  class SystemResponseContent(BaseModel):
@@ -551,7 +551,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
551
551
  conversation_id: str | None = None
552
552
  content: SystemResponseContent | Error | GenerateResponse
553
553
  status: WebSocketMessageStatus
554
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
554
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
555
555
 
556
556
  @field_validator("content")
557
557
  @classmethod
@@ -560,7 +560,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
560
560
  raise ValueError(f"Field: content must be 'Error' when type is {WebSocketMessageType.ERROR_MESSAGE}")
561
561
 
562
562
  if info.data.get("type") == WebSocketMessageType.RESPONSE_MESSAGE and not isinstance(
563
- value, (SystemResponseContent, GenerateResponse)):
563
+ value, SystemResponseContent | GenerateResponse):
564
564
  raise ValueError(
565
565
  f"Field: content must be 'SystemResponseContent' when type is {WebSocketMessageType.RESPONSE_MESSAGE}")
566
566
  return value
@@ -582,7 +582,7 @@ class WebSocketSystemInteractionMessage(BaseModel):
582
582
  conversation_id: str | None = None
583
583
  content: HumanPrompt
584
584
  status: WebSocketMessageStatus
585
- timestamp: str = str(datetime.datetime.now(datetime.timezone.utc))
585
+ timestamp: str = str(datetime.datetime.now(datetime.UTC))
586
586
 
587
587
 
588
588
  # ======== GenerateResponse Converters ========
nat/data_models/authentication.py CHANGED
@@ -14,8 +14,8 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import typing
17
+ from datetime import UTC
17
18
  from datetime import datetime
18
- from datetime import timezone
19
19
  from enum import Enum
20
20
 
21
21
  import httpx
@@ -166,13 +166,7 @@ class BearerTokenCred(_CredBase):
166
166
 
167
167
 
168
168
  Credential = typing.Annotated[
169
- typing.Union[
170
- HeaderCred,
171
- QueryCred,
172
- CookieCred,
173
- BasicAuthCred,
174
- BearerTokenCred,
175
- ],
169
+ HeaderCred | QueryCred | CookieCred | BasicAuthCred | BearerTokenCred,
176
170
  Field(discriminator="kind"),
177
171
  ]
178
172
 
@@ -213,7 +207,7 @@ class AuthResult(BaseModel):
213
207
  """
214
208
  Checks if the authentication token has expired.
215
209
  """
216
- return bool(self.token_expires_at and datetime.now(timezone.utc) >= self.token_expires_at)
210
+ return bool(self.token_expires_at and datetime.now(UTC) >= self.token_expires_at)
217
211
 
218
212
  def as_requests_kwargs(self) -> dict[str, typing.Any]:
219
213
  """
nat/data_models/dataset_handler.py CHANGED
@@ -80,7 +80,7 @@ class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):
80
80
 
81
81
 
82
82
  def read_jsonl(file_path: FilePath):
83
- with open(file_path, 'r', encoding='utf-8') as f:
83
+ with open(file_path, encoding='utf-8') as f:
84
84
  data = [json.loads(line) for line in f]
85
85
  return pd.DataFrame(data)
86
86
 
nat/eval/evaluator/base_evaluator.py CHANGED
@@ -71,7 +71,7 @@ class BaseEvaluator(ABC):
71
71
  TqdmPositionRegistry.release(tqdm_position)
72
72
 
73
73
  # Compute average if possible
74
- numeric_scores = [item.score for item in output_items if isinstance(item.score, (int, float))]
74
+ numeric_scores = [item.score for item in output_items if isinstance(item.score, int | float)]
75
75
  avg_score = round(sum(numeric_scores) / len(numeric_scores), 2) if numeric_scores else None
76
76
 
77
77
  return EvalOutput(average_score=avg_score, eval_output_items=output_items)
nat/eval/swe_bench_evaluator/evaluate.py CHANGED
@@ -204,7 +204,7 @@ class SweBenchEvaluator:
204
204
  # if report file is not present, return empty EvalOutput
205
205
  avg_score = 0.0
206
206
  if report_file.exists():
207
- with open(report_file, "r", encoding="utf-8") as f:
207
+ with open(report_file, encoding="utf-8") as f:
208
208
  report = json.load(f)
209
209
  resolved_instances = report.get("resolved_instances", 0)
210
210
  total_instances = report.get("total_instances", 0)
nat/eval/tunable_rag_evaluator/evaluate.py CHANGED
@@ -14,7 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import logging
17
- from typing import Callable
17
+ from collections.abc import Callable
18
18
 
19
19
  from langchain.output_parsers import ResponseSchema
20
20
  from langchain.output_parsers import StructuredOutputParser
nat/experimental/decorators/experimental_warning_decorator.py CHANGED
@@ -137,8 +137,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
137
137
  @functools.wraps(func)
138
138
  def sync_gen_wrapper(*args, **kwargs) -> Generator[Any, Any, Any]:
139
139
  issue_experimental_warning(function_name, feature_name, metadata)
140
- for item in func(*args, **kwargs):
141
- yield item # yield the original item
140
+ yield from func(*args, **kwargs) # yield the original item
142
141
 
143
142
  return sync_gen_wrapper
144
143
 
nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py CHANGED
@@ -71,7 +71,7 @@ class LLMBasedOutputMergingSelector(StrategyBase):
71
71
  raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
72
72
  "This error can be resolved by installing nvidia-nat-langchain.")
73
73
 
74
- from typing import Callable
74
+ from collections.abc import Callable
75
75
 
76
76
  from pydantic import BaseModel
77
77
 
nat/front_ends/console/authentication_flow_handler.py CHANGED
@@ -14,13 +14,16 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import asyncio
17
+ import logging
17
18
  import secrets
18
19
  import webbrowser
19
20
  from dataclasses import dataclass
20
21
  from dataclasses import field
21
22
 
22
23
  import click
24
+ import httpx
23
25
  import pkce
26
+ from authlib.common.errors import AuthlibBaseError as OAuthError
24
27
  from authlib.integrations.httpx_client import AsyncOAuth2Client
25
28
  from fastapi import FastAPI
26
29
  from fastapi import Request
@@ -32,6 +35,8 @@ from nat.data_models.authentication import AuthFlowType
32
35
  from nat.data_models.authentication import AuthProviderBaseConfig
33
36
  from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
34
37
 
38
+ logger = logging.getLogger(__name__)
39
+
35
40
 
36
41
  # --------------------------------------------------------------------------- #
37
42
  # Helpers #
@@ -87,17 +92,53 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
87
92
  """
88
93
  Separated for easy overriding in tests (to inject ASGITransport).
89
94
  """
90
- client = AsyncOAuth2Client(
91
- client_id=cfg.client_id,
92
- client_secret=cfg.client_secret,
93
- redirect_uri=cfg.redirect_uri,
94
- scope=" ".join(cfg.scopes) if cfg.scopes else None,
95
- token_endpoint=cfg.token_url,
96
- token_endpoint_auth_method=cfg.token_endpoint_auth_method,
97
- code_challenge_method="S256" if cfg.use_pkce else None,
98
- )
99
- self._oauth_client = client
100
- return client
95
+ try:
96
+ client = AsyncOAuth2Client(
97
+ client_id=cfg.client_id,
98
+ client_secret=cfg.client_secret,
99
+ redirect_uri=cfg.redirect_uri,
100
+ scope=" ".join(cfg.scopes) if cfg.scopes else None,
101
+ token_endpoint=cfg.token_url,
102
+ token_endpoint_auth_method=cfg.token_endpoint_auth_method,
103
+ code_challenge_method="S256" if cfg.use_pkce else None,
104
+ )
105
+ self._oauth_client = client
106
+ return client
107
+ except (OAuthError, ValueError, TypeError) as e:
108
+ raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
109
+ except Exception as e:
110
+ raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
111
+
112
+ def _create_authorization_url(self,
113
+ client: AsyncOAuth2Client,
114
+ config: OAuth2AuthCodeFlowProviderConfig,
115
+ state: str,
116
+ verifier: str | None = None,
117
+ challenge: str | None = None) -> str:
118
+ """
119
+ Create OAuth authorization URL with proper error handling.
120
+
121
+ Args:
122
+ client: The OAuth2 client instance
123
+ config: OAuth2 configuration
124
+ state: OAuth state parameter
125
+ verifier: PKCE verifier (if using PKCE)
126
+ challenge: PKCE challenge (if using PKCE)
127
+
128
+ Returns:
129
+ The authorization URL
130
+ """
131
+ try:
132
+ auth_url, _ = client.create_authorization_url(
133
+ config.authorization_url,
134
+ state=state,
135
+ code_verifier=verifier if config.use_pkce else None,
136
+ code_challenge=challenge if config.use_pkce else None,
137
+ **(config.authorization_kwargs or {})
138
+ )
139
+ return auth_url
140
+ except (OAuthError, ValueError, TypeError) as e:
141
+ raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
101
142
 
102
143
  # --------------------------- HTTP Basic ------------------------------ #
103
144
  @staticmethod
@@ -131,13 +172,12 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
131
172
  flow_state.verifier = verifier
132
173
  flow_state.challenge = challenge
133
174
 
134
- auth_url, _ = client.create_authorization_url(
135
- cfg.authorization_url,
136
- state=state,
137
- code_verifier=flow_state.verifier if cfg.use_pkce else None,
138
- code_challenge=flow_state.challenge if cfg.use_pkce else None,
139
- **(cfg.authorization_kwargs or {})
140
- )
175
+ # Create authorization URL using helper function
176
+ auth_url = self._create_authorization_url(client=client,
177
+ config=cfg,
178
+ state=state,
179
+ verifier=flow_state.verifier,
180
+ challenge=flow_state.challenge)
141
181
 
142
182
  # Register flow + maybe spin up redirect handler
143
183
  async with self._server_lock:
@@ -149,14 +189,18 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
149
189
  self._flows[state] = flow_state
150
190
  self._active_flows += 1
151
191
 
152
- click.echo("Your browser has been opened for authentication.")
153
- webbrowser.open(auth_url)
192
+ try:
193
+ webbrowser.open(auth_url)
194
+ click.echo("Your browser has been opened for authentication.")
195
+ except Exception as e:
196
+ logger.error("Browser open failed: %s", e)
197
+ raise RuntimeError(f"Browser open failed: {e}") from e
154
198
 
155
199
  # Wait for the redirect to land
156
200
  try:
157
201
  token = await asyncio.wait_for(flow_state.future, timeout=300)
158
- except asyncio.TimeoutError:
159
- raise RuntimeError("Authentication timed out (5 min).")
202
+ except TimeoutError as exc:
203
+ raise RuntimeError("Authentication timed out (5 min).") from exc
160
204
  finally:
161
205
  async with self._server_lock:
162
206
  self._flows.pop(state, None)
@@ -175,9 +219,9 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
175
219
  # --------------- redirect server / in‑process app -------------------- #
176
220
  async def _build_redirect_app(self) -> FastAPI:
177
221
  """
178
- * If cfg.run_redirect_local_server == True → start a uvicorn server (old behaviour).
179
- * Else → only build the FastAPI app and save it to `self._redirect_app`
180
- for in‑process testing with ASGITransport.
222
+ * If cfg.run_redirect_local_server == True → start a local server.
223
+ * Else → only build the redirect app and save it to `self._redirect_app`
224
+ for in‑process testing.
181
225
  """
182
226
  app = FastAPI()
183
227
 
@@ -195,8 +239,16 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
195
239
  state=state,
196
240
  )
197
241
  flow_state.future.set_result(token)
198
- except Exception as exc: # noqa: BLE001
199
- flow_state.future.set_exception(exc)
242
+ except OAuthError as e:
243
+ flow_state.future.set_exception(
244
+ RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
245
+ return "Authentication failed: Authorization server rejected the request. You may close this tab."
246
+ except httpx.HTTPError as e:
247
+ flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
248
+ return "Authentication failed: Network error occurred. You may close this tab."
249
+ except Exception as e:
250
+ flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
251
+ return "Authentication failed: An unexpected error occurred. You may close this tab."
200
252
  return "Authentication successful – you may close this tab."
201
253
 
202
254
  return app
@@ -213,7 +265,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
213
265
 
214
266
  asyncio.create_task(self._server_controller.start_server(host="localhost", port=8000))
215
267
 
216
- # Give uvicorn a moment to bind sockets before we return
268
+ # Give the server a moment to bind sockets before we return
217
269
  await asyncio.sleep(0.3)
218
270
  except Exception as exc: # noqa: BLE001
219
271
  raise RuntimeError(f"Failed to start redirect server: {exc}") from exc
@@ -227,7 +279,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
227
279
  @property
228
280
  def redirect_app(self) -> FastAPI | None:
229
281
  """
230
- In testmode (run_redirect_local_server=False) the in‑memory FastAPI
231
- app is exposed so you can mount it on `httpx.ASGITransport`.
282
+ In test mode (run_redirect_local_server=False) the in‑memory redirect
283
+ app is exposed for testing purposes.
232
284
  """
233
285
  return self._redirect_app
nat/front_ends/console/console_front_end_plugin.py CHANGED
@@ -88,7 +88,7 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
88
88
  elif (self.front_end_config.input_file):
89
89
 
90
90
  # Run the workflow
91
- with open(self.front_end_config.input_file, "r", encoding="utf-8") as f:
91
+ with open(self.front_end_config.input_file, encoding="utf-8") as f:
92
92
 
93
93
  async with session_manager.workflow.run(f) as runner:
94
94
  runner_outputs = await runner.result(to_type=str)
nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py CHANGED
@@ -22,6 +22,7 @@ from dataclasses import dataclass
22
22
  from dataclasses import field
23
23
 
24
24
  import pkce
25
+ from authlib.common.errors import AuthlibBaseError as OAuthError
25
26
  from authlib.integrations.httpx_client import AsyncOAuth2Client
26
27
 
27
28
  from nat.authentication.interfaces import FlowHandlerBase
@@ -61,14 +62,50 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
61
62
 
62
63
  raise NotImplementedError(f"Authentication method '{method}' is not supported by the websocket frontend.")
63
64
 
64
- def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig):
65
- return AsyncOAuth2Client(client_id=config.client_id,
66
- client_secret=config.client_secret,
67
- redirect_uri=config.redirect_uri,
68
- scope=" ".join(config.scopes) if config.scopes else None,
69
- token_endpoint=config.token_url,
70
- code_challenge_method='S256' if config.use_pkce else None,
71
- token_endpoint_auth_method=config.token_endpoint_auth_method)
65
+ def create_oauth_client(self, config: OAuth2AuthCodeFlowProviderConfig) -> AsyncOAuth2Client:
66
+ try:
67
+ return AsyncOAuth2Client(client_id=config.client_id,
68
+ client_secret=config.client_secret,
69
+ redirect_uri=config.redirect_uri,
70
+ scope=" ".join(config.scopes) if config.scopes else None,
71
+ token_endpoint=config.token_url,
72
+ code_challenge_method='S256' if config.use_pkce else None,
73
+ token_endpoint_auth_method=config.token_endpoint_auth_method)
74
+ except (OAuthError, ValueError, TypeError) as e:
75
+ raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
76
+ except Exception as e:
77
+ raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
78
+
79
+ def _create_authorization_url(self,
80
+ client: AsyncOAuth2Client,
81
+ config: OAuth2AuthCodeFlowProviderConfig,
82
+ state: str,
83
+ verifier: str = None,
84
+ challenge: str = None) -> str:
85
+ """
86
+ Create OAuth authorization URL with proper error handling.
87
+
88
+ Args:
89
+ client: The OAuth2 client instance
90
+ config: OAuth2 configuration
91
+ state: OAuth state parameter
92
+ verifier: PKCE verifier (if using PKCE)
93
+ challenge: PKCE challenge (if using PKCE)
94
+
95
+ Returns:
96
+ The authorization URL
97
+ """
98
+ try:
99
+ authorization_url, _ = client.create_authorization_url(
100
+ config.authorization_url,
101
+ state=state,
102
+ code_verifier=verifier if config.use_pkce else None,
103
+ code_challenge=challenge if config.use_pkce else None,
104
+ **(config.authorization_kwargs or {})
105
+ )
106
+ return authorization_url
107
+ except (OAuthError, ValueError, TypeError) as e:
108
+ raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
72
109
 
73
110
  async def _handle_oauth2_auth_code_flow(self, config: OAuth2AuthCodeFlowProviderConfig) -> AuthenticatedContext:
74
111
 
@@ -82,21 +119,19 @@ class WebSocketAuthenticationFlowHandler(FlowHandlerBase):
82
119
  flow_state.verifier = verifier
83
120
  flow_state.challenge = challenge
84
121
 
85
- authorization_url, _ = flow_state.client.create_authorization_url(
86
- config.authorization_url,
87
- state=state,
88
- code_verifier=flow_state.verifier if config.use_pkce else None,
89
- code_challenge=flow_state.challenge if config.use_pkce else None,
90
- **(config.authorization_kwargs or {})
91
- )
122
+ authorization_url = self._create_authorization_url(client=flow_state.client,
123
+ config=config,
124
+ state=state,
125
+ verifier=flow_state.verifier,
126
+ challenge=flow_state.challenge)
92
127
 
93
128
  await self._add_flow_cb(state, flow_state)
94
129
  await self._web_socket_message_handler.create_websocket_message(_HumanPromptOAuthConsent(text=authorization_url)
95
130
  )
96
131
  try:
97
132
  token = await asyncio.wait_for(flow_state.future, timeout=300)
98
- except asyncio.TimeoutError:
99
- raise RuntimeError("Authentication flow timed out after 5 minutes.")
133
+ except TimeoutError as exc:
134
+ raise RuntimeError("Authentication flow timed out after 5 minutes.") from exc
100
135
  finally:
101
136
 
102
137
  await self._remove_flow_cb(state)