agentstack-cli 0.6.0rc4__py3-none-any.whl → 0.6.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,8 +7,13 @@ from collections import defaultdict
7
7
  from typing import Any
8
8
 
9
9
  import httpx
10
+ from authlib.common.errors import AuthlibBaseError
11
+ from authlib.integrations.httpx_client import AsyncOAuth2Client
12
+ from authlib.oauth2.rfc6749.errors import InvalidGrantError, OAuth2Error
10
13
  from pydantic import BaseModel, Field
11
14
 
15
+ TOKEN_EXPIRY_LEEWAY = 60 # seconds: a stored token is treated as expired (and refreshed) this long before its expires_at
16
+
12
17
 
13
18
  class AuthToken(BaseModel):
14
19
  access_token: str
@@ -44,6 +49,7 @@ class AuthManager:
44
49
  def __init__(self, config_path: pathlib.Path):
45
50
  self._auth_path = config_path
46
51
  self._auth = self._load()
52
+ self._oidc_cache: dict[str, dict[str, Any]] = {}
47
53
 
48
54
  def _load(self) -> Auth:
49
55
  if not self._auth_path.exists():
@@ -53,7 +59,63 @@ class AuthManager:
53
59
  def _save(self) -> None:
54
60
  self._auth_path.write_text(self._auth.model_dump_json(indent=2))
55
61
 
56
- def save_auth_token(
62
async def _get_oidc_metadata(self, auth_server: str) -> dict[str, Any]:
    """Fetch the OIDC discovery document for ``auth_server``, memoized per instance.

    The issuer URL is normalized by stripping trailing slashes, so ``https://idp/`` and
    ``https://idp`` share one cache entry and never produce a ``//.well-known`` path.

    Raises:
        RuntimeError: If the discovery request fails, returns invalid JSON, or the JSON
            document is not an object.
    """
    issuer = auth_server.rstrip("/")
    cached = self._oidc_cache.get(issuer)
    if cached is not None:
        return cached

    async with httpx.AsyncClient(headers={"Accept": "application/json"}) as client:
        try:
            resp = await client.get(f"{issuer}/.well-known/openid-configuration")
            resp.raise_for_status()
            metadata = resp.json()
        except Exception as e:
            raise RuntimeError(f"OIDC discovery failed: {e}") from e

    # Callers index the result like a mapping (e.g. metadata["token_endpoint"]); never cache junk.
    if not isinstance(metadata, dict):
        raise RuntimeError(
            f"OIDC discovery failed: expected a JSON object from {issuer}, got {type(metadata).__name__}"
        )

    self._oidc_cache[issuer] = metadata
    return metadata
76
+
77
+ def _create_token_update_callback(self, server: str, auth_server: str):
78
+ """Create a callback that saves tokens when they're refreshed."""
79
+
80
+ def update_token(token: dict[str, Any]):
81
+ # Authlib calls this automatically when tokens are refreshed
82
+ # kwargs may include refresh_token and access_token but we don't need them
83
+ auth_config = self._auth.servers[server].authorization_servers[auth_server]
84
+ self.save_auth_info(
85
+ server=server,
86
+ auth_server=auth_server,
87
+ client_id=auth_config.client_id,
88
+ client_secret=auth_config.client_secret,
89
+ token=token,
90
+ registration_token=auth_config.registration_token,
91
+ )
92
+
93
+ return update_token
94
+
95
async def _get_oauth_client(self, server: str, auth_server: str) -> AsyncOAuth2Client:
    """Create an ``AsyncOAuth2Client`` preloaded with the stored credentials and token.

    Raises:
        ValueError: If no configuration or no token is stored for ``server`` / ``auth_server``.
        RuntimeError: If OIDC discovery fails or the metadata has no ``token_endpoint``.
    """
    # Look up with .get() instead of indexing: self._auth.servers can be a defaultdict, where
    # indexing would insert an empty Server entry, and a missing key would surface as a
    # KeyError instead of the explicit "No token found" error below.
    server_config = self._auth.servers.get(server)
    auth_config = server_config.authorization_servers.get(auth_server) if server_config is not None else None
    if not auth_config or not auth_config.token:
        raise ValueError(f"No token found for {auth_server}")

    metadata = await self._get_oidc_metadata(auth_server)
    token_endpoint = metadata.get("token_endpoint")
    if not token_endpoint:
        raise RuntimeError(f"OIDC metadata for {auth_server} does not advertise a token_endpoint")

    # authlib expects the token as a plain dict
    token_dict = auth_config.token.model_dump(exclude_none=True)

    return AsyncOAuth2Client(
        client_id=auth_config.client_id,
        client_secret=auth_config.client_secret,
        token_endpoint=token_endpoint,
        token=token_dict,
        scope=token_dict.get("scope"),
        update_token=self._create_token_update_callback(server, auth_server),
    )
117
+
118
+ def save_auth_info(
57
119
  self,
58
120
  server: str,
59
121
  auth_server: str | None = None,
@@ -63,7 +125,7 @@ class AuthManager:
63
125
  registration_token: str | None = None,
64
126
  ) -> None:
65
127
  if auth_server is not None and client_id is not None and token is not None:
66
- if token["access_token"]:
128
+ if token["access_token"] and token.get("expires_in") is not None:
67
129
  usetimestamp = int(time.time()) + int(token["expires_in"])
68
130
  token["expires_at"] = usetimestamp
69
131
  self._auth.servers[server].authorization_servers[auth_server] = AuthServer(
@@ -78,57 +140,56 @@ class AuthManager:
78
140
 
79
141
  async def exchange_refresh_token(self, auth_server: str, token: AuthToken) -> dict[str, Any] | None:
80
142
  """
81
- This method exchanges a refresh token for a new access token.
82
- """
143
+ Exchange a refresh token for a new access token using authlib.
83
144
 
84
- async with httpx.AsyncClient(headers={"Accept": "application/json"}) as client:
85
- resp = None
86
- try:
87
- resp = await client.get(f"{auth_server}/.well-known/openid-configuration")
88
- resp.raise_for_status()
89
- oidc = resp.json()
90
- except Exception as e:
91
- if resp:
92
- error_details = resp.json()
93
- print(f"error: {error_details['error']} error description: {error_details['error_description']}")
94
- raise RuntimeError(f"OIDC discovery failed: {e}") from e
95
-
96
- token_endpoint = oidc["token_endpoint"]
97
- try:
98
- client_id = (
99
- self._auth.servers[self._auth.active_server or ""].authorization_servers[auth_server].client_id
100
- )
101
- client_secret = (
102
- self._auth.servers[self._auth.active_server or ""].authorization_servers[auth_server].client_secret
103
- )
104
- resp = await client.post(
105
- f"{token_endpoint}",
106
- data={
107
- "grant_type": "refresh_token",
108
- "refresh_token": token.refresh_token,
109
- "scope": token.scope,
110
- "client_id": client_id,
111
- }
112
- | ({"client_secret": client_secret} if client_secret else {}),
145
+ Raises:
146
+ InvalidGrantError: If the refresh token is invalid or expired (4xx auth errors)
147
+ OAuth2Error: For other OAuth2 protocol errors
148
+ RuntimeError: For network errors or OIDC discovery failures
149
+ """
150
+ if not self._auth.active_server:
151
+ raise ValueError("No active server configured")
152
+
153
+ try:
154
+ metadata = await self._get_oidc_metadata(auth_server)
155
+ token_endpoint = metadata["token_endpoint"]
156
+
157
+ async with await self._get_oauth_client(self._auth.active_server, auth_server) as client:
158
+ # Authlib's fetch_token with refresh_token grant automatically handles the refresh
159
+ # and calls update_token callback to save the new token
160
+ new_token = await client.fetch_token(
161
+ url=token_endpoint,
162
+ grant_type="refresh_token",
163
+ refresh_token=token.refresh_token,
113
164
  )
114
- resp.raise_for_status()
115
- new_token = resp.json()
116
- except Exception as e:
117
- if resp:
118
- error_details = resp.json()
119
- print(f"error: {error_details['error']} error description: {error_details['error_description']}")
120
- raise RuntimeError(f"Failed to refresh token: {e}") from e
121
- self.save_auth_token(
122
- self._auth.active_server or "",
123
- self._auth.active_auth_server or "",
124
- self._auth.servers[self._auth.active_server or ""].authorization_servers[auth_server].client_id or "",
125
- self._auth.servers[self._auth.active_server or ""].authorization_servers[auth_server].client_secret
126
- or "",
127
- token=new_token,
128
- )
129
- return new_token
165
+ return new_token
166
+ except InvalidGrantError as e:
167
+ # 400-level OAuth errors: invalid/expired refresh token
168
+ raise InvalidGrantError(
169
+ description=f"Token refresh failed - invalid or expired refresh token: {e.description}"
170
+ ) from e
171
+ except OAuth2Error as e:
172
+ # Other OAuth2 protocol errors
173
+ raise OAuth2Error(description=f"OAuth2 error during token refresh: {e.description}") from e
174
+ except AuthlibBaseError as e:
175
+ # Other authlib errors
176
+ raise RuntimeError(f"Token refresh failed: {e}") from e
177
+ except Exception as e:
178
+ # Network errors, OIDC discovery failures, etc.
179
+ raise RuntimeError(f"Failed to refresh token: {e}") from e
130
180
 
131
181
  async def load_auth_token(self) -> str | None:
182
+ """
183
+ Load and refresh auth token if needed using authlib.
184
+
185
+ Returns:
186
+ Access token string, or None if no auth configured
187
+
188
+ Raises:
189
+ InvalidGrantError: If token is expired and refresh fails due to auth issues (4xx)
190
+ OAuth2Error: For other OAuth2 protocol errors
191
+ RuntimeError: For network or other errors
192
+ """
132
193
  active_res = self._auth.active_server
133
194
  active_auth_server = self._auth.active_auth_server
134
195
  if not active_res or not active_auth_server:
@@ -136,12 +197,12 @@ class AuthManager:
136
197
  server = self._auth.servers.get(active_res)
137
198
  if not server:
138
199
  return None
139
-
140
200
  auth_server = server.authorization_servers.get(active_auth_server)
141
201
  if not auth_server or not auth_server.token:
142
202
  return None
143
203
 
144
- if (auth_server.token.expires_at or 0) - 60 < time.time():
204
+ if (auth_server.token.expires_at or 0) - TOKEN_EXPIRY_LEEWAY < time.time():
205
+ # Token expired, try to refresh - this may raise TokenRefreshError
145
206
  new_token = await self.exchange_refresh_token(active_auth_server, auth_server.token)
146
207
  if new_token:
147
208
  return new_token["access_token"]
@@ -149,62 +210,53 @@ class AuthManager:
149
210
 
150
211
  return auth_server.token.access_token
151
212
 
152
- async def deregister_client(self, auth_server, client_id, registration_token) -> None:
153
- async with httpx.AsyncClient(headers={"Accept": "application/json"}) as client:
154
- resp = None
155
- try:
156
- resp = await client.get(f"{auth_server}/.well-known/openid-configuration")
157
- resp.raise_for_status()
158
- oidc = resp.json()
159
- registration_endpoint = oidc["registration_endpoint"]
160
- except Exception as e:
161
- if resp:
162
- error_details = resp.json()
163
- print(f"error: {error_details['error']} error description: {error_details['error_description']}")
164
- raise RuntimeError(f"OIDC discovery failed: {e}") from e
213
+ async def deregister_client(self, auth_server: str, client_id: str | None, registration_token: str | None) -> None:
214
+ """Deregister a dynamically registered OAuth2 client."""
215
+ if not client_id or not registration_token:
216
+ return # Nothing to deregister
165
217
 
166
- try:
167
- if client_id is not None and client_id != "" and registration_token is not None:
168
- headers = {"authorization": f"bearer {registration_token}"}
169
- resp = await client.delete(f"{registration_endpoint}/{client_id}", headers=headers)
170
- resp.raise_for_status()
218
+ try:
219
+ metadata = await self._get_oidc_metadata(auth_server)
220
+ registration_endpoint = metadata.get("registration_endpoint")
171
221
 
172
- except Exception as e:
173
- if resp:
174
- error_details = resp.json()
175
- print(f"error: {error_details['error']} error description: {error_details['error_description']}")
176
- raise RuntimeError(f"Dynamic client de-registration failed. {e}") from e
222
+ if not registration_endpoint:
223
+ raise RuntimeError("Registration endpoint not found in OIDC metadata")
224
+
225
+ async with AsyncOAuth2Client() as client:
226
+ headers = {"Authorization": f"Bearer {registration_token}"}
227
+ resp = await client.delete(f"{registration_endpoint}/{client_id}", headers=headers)
228
+ resp.raise_for_status()
229
+
230
+ except Exception as e:
231
+ raise RuntimeError(f"Dynamic client de-registration failed: {e}") from e
177
232
 
178
233
  async def clear_auth_token(self, all: bool = False) -> None:
179
234
  if all:
180
235
  for server in self._auth.servers:
181
236
  for auth_server in self._auth.servers[server].authorization_servers:
237
+ auth_config = self._auth.servers[server].authorization_servers[auth_server]
182
238
  await self.deregister_client(
183
239
  auth_server,
184
- self._auth.servers[server].authorization_servers[auth_server].client_id,
185
- self._auth.servers[server].authorization_servers[auth_server].registration_token,
240
+ auth_config.client_id,
241
+ auth_config.registration_token,
186
242
  )
187
243
 
188
244
  self._auth.servers = defaultdict(Server)
189
245
  else:
190
246
  if self._auth.active_server and self._auth.active_auth_server:
191
- if (
192
- self._auth.servers[self._auth.active_server]
193
- .authorization_servers[self._auth.active_auth_server]
194
- .client_id
195
- ):
196
- await self.deregister_client(
197
- self._auth.active_auth_server,
198
- self._auth.servers[self._auth.active_server]
199
- .authorization_servers[self._auth.active_auth_server]
200
- .client_id,
201
- self._auth.servers[self._auth.active_server]
202
- .authorization_servers[self._auth.active_auth_server]
203
- .registration_token,
204
- )
247
+ auth_config = self._auth.servers[self._auth.active_server].authorization_servers[
248
+ self._auth.active_auth_server
249
+ ]
250
+ await self.deregister_client(
251
+ self._auth.active_auth_server,
252
+ auth_config.client_id,
253
+ auth_config.registration_token,
254
+ )
205
255
  del self._auth.servers[self._auth.active_server].authorization_servers[self._auth.active_auth_server]
256
+
206
257
  if self._auth.active_server and not self._auth.servers[self._auth.active_server].authorization_servers:
207
258
  del self._auth.servers[self._auth.active_server]
259
+
208
260
  self._auth.active_server = None
209
261
  self._auth.active_auth_server = None
210
262
  self._save()
@@ -112,9 +112,8 @@ from rich.table import Column
112
112
 
113
113
  from agentstack_cli.api import a2a_client
114
114
  from agentstack_cli.async_typer import AsyncTyper, console, create_table, err_console
115
+ from agentstack_cli.server_utils import announce_server_action, confirm_server_action
115
116
  from agentstack_cli.utils import (
116
- announce_server_action,
117
- confirm_server_action,
118
117
  generate_schema_example,
119
118
  is_github_url,
120
119
  parse_env_var,
@@ -181,6 +180,36 @@ processing_messages = [
181
180
 
182
181
  configuration = Configuration()
183
182
 
183
DISCOVERY_TIMEOUT_SEC = 180
DISCOVERY_POLL_INTERVAL_SEC = 2


async def _discover_agent_card(
    docker_image: str,
    timeout_sec: float = DISCOVERY_TIMEOUT_SEC,
    poll_interval_sec: float = DISCOVERY_POLL_INTERVAL_SEC,
) -> AgentCard:
    """Run platform-side provider discovery for ``docker_image`` and return its agent card.

    Used when the image has no agent-card label. Creates a discovery task and polls it
    every ``poll_interval_sec`` seconds while it is PENDING or IN_PROGRESS.

    Raises:
        RuntimeError: If discovery does not finish within ``timeout_sec`` seconds, ends in
            the FAILED state, or finishes without returning an agent card.
    """
    from agentstack_sdk.platform.provider_discovery import DiscoveryState, ProviderDiscovery

    console.info("Image missing agent card label, starting discovery...")

    loop = asyncio.get_running_loop()
    async with configuration.use_platform_client():
        with status("Creating discovery task"):
            discovery = await ProviderDiscovery.create(docker_image=docker_image)

        deadline = loop.time() + timeout_sec
        with status("Discovering agent card (this may take a while)"):
            while discovery.status in (DiscoveryState.PENDING, DiscoveryState.IN_PROGRESS):
                if loop.time() > deadline:
                    # Derived from the actual timeout so the message cannot drift from the setting
                    # (the default of 180 s still reads "3 minutes").
                    raise RuntimeError(f"Discovery timed out after {timeout_sec / 60:g} minutes")
                await asyncio.sleep(poll_interval_sec)
                await discovery.get()

        if discovery.status == DiscoveryState.FAILED:
            raise RuntimeError(f"Discovery failed: {discovery.error_message}")

        card = discovery.agent_card
        if not card:
            raise RuntimeError("Discovery completed but no agent card was returned")

        return card
212
+
184
213
 
185
214
  @app.command("add")
186
215
  async def add_agent(
@@ -257,9 +286,18 @@ async def add_agent(
257
286
  if dockerfile:
258
287
  raise ValueError("Dockerfile can be specified only if location is a GitHub url")
259
288
  console.info(f"Assuming public docker image or network address, attempting to add {location}")
260
- with status("Registering agent to platform"):
261
- async with configuration.use_platform_client():
262
- await Provider.create(location=location)
289
+ try:
290
+ with status("Registering agent to platform"):
291
+ async with configuration.use_platform_client():
292
+ await Provider.create(location=location)
293
+ except httpx.HTTPStatusError as e:
294
+ if e.response.status_code == 422:
295
+ agent_card = await _discover_agent_card(location)
296
+ with status("Registering agent with discovered card"):
297
+ async with configuration.use_platform_client():
298
+ await Provider.create(location=location, agent_card=agent_card)
299
+ else:
300
+ raise
263
301
  console.success(f"Agent [bold]{location}[/bold] added to platform")
264
302
  await list_agents()
265
303
 
@@ -25,10 +25,9 @@ from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_delay, w
25
25
 
26
26
  from agentstack_cli.async_typer import AsyncTyper
27
27
  from agentstack_cli.console import console, err_console
28
+ from agentstack_cli.server_utils import announce_server_action, confirm_server_action
28
29
  from agentstack_cli.utils import (
29
- announce_server_action,
30
30
  capture_output,
31
- confirm_server_action,
32
31
  extract_messages,
33
32
  print_log,
34
33
  run_command,
@@ -1,6 +1,7 @@
1
1
  # Copyright 2025 © BeeAI a Series of LF Projects, LLC
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  import asyncio
4
+ import sys
4
5
  import typing
5
6
 
6
7
  import pydantic
@@ -14,7 +15,7 @@ from agentstack_cli import configuration
14
15
  from agentstack_cli.async_typer import AsyncTyper
15
16
  from agentstack_cli.configuration import Configuration
16
17
  from agentstack_cli.console import console
17
- from agentstack_cli.utils import (
18
+ from agentstack_cli.server_utils import (
18
19
  announce_server_action,
19
20
  confirm_server_action,
20
21
  )
@@ -99,7 +100,7 @@ async def remove_connector(
99
100
  console.error(
100
101
  "[red]Cannot specify both --all and a search path. Use --all to remove all connectors, or provide a search path for specific connectors.[/red]"
101
102
  )
102
- raise typer.Exit(1)
103
+ sys.exit(1)
103
104
 
104
105
  async with configuration.use_platform_client():
105
106
  connectors_list = await Connector.list()
@@ -191,7 +192,7 @@ async def select_connector(search_path: str) -> Connector | None:
191
192
  except ValueError as e:
192
193
  console.error(e.__str__())
193
194
  console.hint("Please refine your input to match exactly one connector id or url.")
194
- raise typer.Exit(code=1) from None
195
+ sys.exit(1)
195
196
 
196
197
 
197
198
  @app.command("get")
@@ -258,7 +259,7 @@ async def disconnect(
258
259
  console.error(
259
260
  "[red]Cannot specify both --all and a search path. Use --all to remove all connectors, or provide a search path for specific connectors.[/red]"
260
261
  )
261
- raise typer.Exit(1)
262
+ sys.exit(1)
262
263
 
263
264
  async with configuration.use_platform_client():
264
265
  connectors_list = await Connector.list()
@@ -25,7 +25,8 @@ from rich.table import Column
25
25
  from agentstack_cli.api import openai_client
26
26
  from agentstack_cli.async_typer import AsyncTyper, console, create_table
27
27
  from agentstack_cli.configuration import Configuration
28
- from agentstack_cli.utils import announce_server_action, confirm_server_action, run_command, verbosity
28
+ from agentstack_cli.server_utils import announce_server_action, confirm_server_action
29
+ from agentstack_cli.utils import run_command, verbosity
29
30
 
30
31
  app = AsyncTyper()
31
32
  configuration = Configuration()
@@ -17,7 +17,7 @@ import typer
17
17
  from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_delay, wait_fixed
18
18
 
19
19
  from agentstack_cli.async_typer import AsyncTyper
20
- from agentstack_cli.commands.platform.base_driver import BaseDriver
20
+ from agentstack_cli.commands.platform.base_driver import BaseDriver, ImagePullMode
21
21
  from agentstack_cli.commands.platform.lima_driver import LimaDriver
22
22
  from agentstack_cli.commands.platform.wsl_driver import WSLDriver
23
23
  from agentstack_cli.configuration import Configuration
@@ -64,19 +64,20 @@ async def start(
64
64
  set_values_list: typing.Annotated[
65
65
  list[str], typer.Option("--set", help="Set Helm chart values using <key>=<value> syntax", default_factory=list)
66
66
  ],
67
- import_images: typing.Annotated[
68
- list[str],
67
+ image_pull_mode: typing.Annotated[
68
+ ImagePullMode,
69
69
  typer.Option(
70
- "--import", help="Import an image from a local Docker CLI into Agent Stack platform", default_factory=list
70
+ "--image-pull-mode",
71
+ help=textwrap.dedent(
72
+ """\
73
+ guest = pull all images inside VM
74
+ host = pull unavailable images on host, then import all
75
+ hybrid = import available images from host, pull the rest in VM
76
+ skip = skip explicit pull step (Kubernetes will attempt to pull missing images)
77
+ """
78
+ ),
71
79
  ),
72
- ],
73
- pull_on_host: typing.Annotated[
74
- bool,
75
- typer.Option(
76
- "--pull-on-host",
77
- help="Pull images on host Docker daemon and import them instead of pulling inside the VM. Acts as a pull cache layer.",
78
- ),
79
- ] = False,
80
+ ] = ImagePullMode.guest,
80
81
  values_file: typing.Annotated[
81
82
  pathlib.Path | None, typer.Option("-f", help="Set Helm chart values using yaml values file")
82
83
  ] = None,
@@ -101,10 +102,7 @@ async def start(
101
102
  await driver.deploy(
102
103
  set_values_list=set_values_list,
103
104
  values_file=values_file_path,
104
- import_images=import_images,
105
- pull_on_host=pull_on_host,
106
- skip_pull=skip_pull,
107
- skip_restart_deployments=skip_restart_deployments,
105
+ image_pull_mode=image_pull_mode,
108
106
  )
109
107
 
110
108
  if not no_wait_for_platform: