llamactl 0.3.0a18__py3-none-any.whl → 0.3.0a20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,362 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ from types import TracebackType
6
+ from typing import Any, AsyncContextManager, AsyncGenerator, Awaitable, Callable, Self
7
+
8
+ import httpx
9
+ import jwt
10
+ from jwt.algorithms import RSAAlgorithm # type: ignore[possibly-unbound-import]
11
+ from llama_deploy.cli.config.schema import DeviceOIDC
12
+ from pydantic import BaseModel
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class OidcDiscoveryResponse(BaseModel):
    """Body returned by the platform's ``/api/v1/auth/oidc/discovery`` endpoint."""

    # Required URL string; presumably the provider's OIDC discovery document
    # (verify against the platform API).
    discovery_url: str
    # Optional string-to-string mapping of client ids; the key semantics are
    # not visible in this module.
    client_ids: dict[str, str] | None = None
20
+
21
+
22
class OidcProviderConfiguration(BaseModel):
    """Subset of an OIDC provider's discovery document used by this module.

    Every field is optional; callers check for the endpoint they need and
    raise if it is absent.
    """

    # Endpoint that starts a device authorization flow.
    device_authorization_endpoint: str | None = None
    # Endpoint used for device-code and refresh-token grants.
    token_endpoint: str | None = None
    # Scopes advertised by the provider.
    scopes_supported: list[str] | None = None
    # Location of the provider's JSON Web Key Set.
    jwks_uri: str | None = None
27
+
28
+
29
class JsonWebKey(BaseModel):
    """A single entry of a JSON Web Key Set.

    Only ``kty`` is required. ``n`` and ``e`` carry the RSA public key
    material consumed when the key is converted for signature verification.
    """

    # Key type; the decoder in this module only accepts "RSA".
    kty: str
    # Key id, matched against the ``kid`` in a token header.
    kid: str | None = None
    use: str | None = None
    alg: str | None = None
    # RSA modulus and public exponent.
    n: str | None = None
    e: str | None = None
    # X.509 certificate chain and thumbprints.
    x5c: list[str] | None = None
    x5t: str | None = None
    # NOTE(review): RFC 7517 names this member "x5t#S256"; without an alias it
    # is unlikely to be populated from a standard JWKS document - confirm.
    x5t_s256: str | None = None
39
+
40
+
41
class JsonWebKeySet(BaseModel):
    """Container for the ``keys`` array of a JWKS document."""

    # Keys in the order the provider listed them.
    keys: list[JsonWebKey]
43
+
44
+
45
class AuthMeResponse(BaseModel):
    """Body returned by ``GET /api/v1/auth/me``.

    Only ``id`` is required; every other attribute defaults to ``None``.
    """

    id: str
    email: str | None = None
    last_login_provider: str | None = None
    name: str | None = None
    first_name: str | None = None
    last_name: str | None = None
    # Free-form claim mapping; its schema is not defined in this module.
    claims: dict[str, Any] | None = None
    # Untyped field; its meaning is not visible here.
    restrict: Any | None = None
    # Kept as the raw string sent by the server (not parsed to a datetime).
    created_at: str | None = None
55
+
56
+
57
class ClientContextManager(AsyncContextManager):
    """Base class that owns an ``httpx.AsyncClient`` and closes it on exit.

    Subclasses issue requests through ``self.client``. Use an instance with
    ``async with`` so the underlying client is always closed.
    """

    def __init__(self, base_url: str | None, auth: httpx.Auth | None = None) -> None:
        """Create the underlying HTTP client.

        Args:
            base_url: Root URL for relative request paths. Trailing slashes
                are stripped. ``None`` or an empty value creates a client
                with no base URL, so callers must pass absolute URLs.
            auth: Optional httpx auth handler applied to every request.
        """
        self.base_url = base_url.rstrip("/") if base_url else None
        if self.base_url:
            self.client = httpx.AsyncClient(base_url=self.base_url, auth=auth)
        else:
            self.client = httpx.AsyncClient(auth=auth)

    async def close(self) -> None:
        """Close the HTTP client on a best-effort basis.

        Errors raised while closing are deliberately not propagated so that
        cleanup never masks the caller's own exception; they are logged at
        debug level instead of being silently discarded.
        """
        try:
            await self.client.aclose()
        except Exception:
            logger.debug("Ignoring error while closing HTTP client", exc_info=True)

    async def __aenter__(self) -> Self:
        """Return this instance; the client is already usable after ``__init__``."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Close the client. Returns ``None`` so exceptions are never suppressed."""
        await self.close()
81
+
82
+
83
class PlatformAuthDiscoveryClient(ClientContextManager):
    """Client for ad hoc auth endpoints under /api/v1/auth."""

    def __init__(self, base_url: str) -> None:
        """Build a client rooted at ``base_url`` with no auth handler."""
        super().__init__(base_url)

    async def oidc_discovery(self) -> OidcDiscoveryResponse:
        """Fetch and validate the platform's OIDC discovery payload.

        Raises whatever ``raise_for_status`` raises on a non-success response.
        """
        response = await self.client.get("/api/v1/auth/oidc/discovery", timeout=10.0)
        response.raise_for_status()
        body = response.json()
        return OidcDiscoveryResponse.model_validate(body)
93
+
94
+
95
class APIToken(BaseModel):
    """An API key as returned by the api-keys endpoint: its value and its id."""

    # Key value taken from the creation response.
    token: str
    # Server-side identifier, usable for later deletion.
    id: str
98
+
99
+
100
class PlatformAuthClient(ClientContextManager):
    """Client for user introspection under /api/v1/auth/me.

    Also creates and deletes API keys under /api/v1/api-keys.
    """

    def __init__(
        self, base_url: str, id_token: str | None = None, auth: httpx.Auth | None = None
    ) -> None:
        """Create the client.

        Args:
            base_url: Root URL of the platform API.
            id_token: Optional token sent as a bearer header by ``me()``.
            auth: Optional httpx auth handler applied to every request.
        """
        self.id_token = id_token
        super().__init__(base_url, auth=auth)

    async def me(self) -> AuthMeResponse:
        """Return the current user as reported by ``/api/v1/auth/me``.

        The bearer header is only attached when an ``id_token`` was supplied;
        otherwise authentication is left to the configured ``auth`` handler.
        """
        headers = (
            {"Authorization": f"Bearer {self.id_token}"} if self.id_token else None
        )
        resp = await self.client.get("/api/v1/auth/me", headers=headers, timeout=10.0)
        resp.raise_for_status()
        return AuthMeResponse.model_validate(resp.json())

    async def create_agent_api_key(self, name: str) -> APIToken:
        """Create an API key called ``name`` and return its value and id.

        Raises:
            httpx.HTTPStatusError: On a non-success response.
            KeyError: If the response lacks ``redacted_api_key`` or ``id``.
        """
        resp = await self.client.post(
            "/api/v1/api-keys",
            json={"name": name, "project_id": None},
            # Same explicit timeout as the other requests in this module.
            timeout=10.0,
        )
        resp.raise_for_status()
        # Named "payload" so the local does not shadow the ``json`` name or
        # the ``id`` builtin.
        payload = resp.json()
        return APIToken(token=payload["redacted_api_key"], id=payload["id"])

    async def delete_api_key(self, id: str) -> None:
        """Delete the API key identified by ``id``.

        The parameter name shadows the builtin but is kept because callers
        may pass it by keyword.
        """
        response = await self.client.delete(f"/api/v1/api-keys/{id}", timeout=10.0)
        response.raise_for_status()
131
+
132
+
133
class RefreshMiddleware(httpx.Auth):
    """httpx auth handler that attaches the device access token and, after a
    401 response, refreshes the token and retries the request once."""

    def __init__(
        self,
        device_oidc: DeviceOIDC,
        on_refresh: Callable[[DeviceOIDC], Awaitable[None]],
    ) -> None:
        """Store the current credentials and the post-refresh callback.

        Args:
            device_oidc: Credentials whose access token is sent with requests.
            on_refresh: Awaited with the new credentials after each refresh.
        """
        self.device_oidc = device_oidc
        self.on_refresh = on_refresh
        # Ensures that only one refresh runs at a time.
        self.lock = asyncio.Lock()

    async def _refresh_and_update(self) -> None:
        """Refresh the credentials, store them, then notify ``on_refresh``.

        A failing callback is logged and otherwise ignored; a failing refresh
        propagates to the caller.
        """
        refreshed = await refresh(self.device_oidc)
        self.device_oidc = refreshed
        try:
            await self.on_refresh(refreshed)
        except Exception:
            logger.exception("Error in on_refresh callback")

    async def async_auth_flow(
        self, request: httpx.Request
    ) -> AsyncGenerator[httpx.Request, httpx.Response]:
        """Send ``request`` with a bearer token; on 401, refresh and resend."""
        sent_token = self.device_oidc.device_access_token
        request.headers["Authorization"] = f"Bearer {sent_token}"

        response = yield request
        if response.status_code != 401:
            return

        async with self.lock:
            # Skip the refresh if the stored token changed while this
            # request was waiting, i.e. a refresh has already happened.
            if sent_token == self.device_oidc.device_access_token:
                await self._refresh_and_update()
        current_token = self.device_oidc.device_access_token
        request.headers["Authorization"] = f"Bearer {current_token}"
        yield request
166
+
167
+
168
class DeviceAuthorizationRequest(BaseModel):
    """Form fields posted to the provider's device authorization endpoint."""

    client_id: str
    # Single string; presumably space-separated scopes (verify with callers).
    scope: str
171
+
172
+
173
class DeviceAuthorizationResponse(BaseModel):
    """Validated body of the provider's device authorization response."""

    # Code later exchanged at the token endpoint.
    device_code: str
    # Code and URL shown to the user.
    user_code: str
    verification_uri: str
    # Optional variant of the verification URL.
    verification_uri_complete: str | None = None
    # Integer lifetime; presumably seconds (confirm against the provider).
    expires_in: int
    # Optional polling interval; presumably seconds (confirm against the provider).
    interval: int | None = None
180
+
181
+
182
class TokenRequestDeviceCode(BaseModel):
    """Form fields for exchanging a device code at the token endpoint."""

    # Fixed grant-type identifier; callers normally leave the default.
    grant_type: str = "urn:ietf:params:oauth:grant-type:device_code"
    device_code: str
    client_id: str
186
+
187
+
188
class TokenResponse(BaseModel):
    """Token endpoint reply, covering both success and error payloads.

    All fields are optional because one model represents either outcome;
    callers inspect ``error`` or ``access_token`` to tell them apart.
    """

    # Fields present on success
    id_token: str | None = None
    access_token: str | None = None
    refresh_token: str | None = None
    expires_in: int | None = None
    token_type: str | None = None
    scope: str | None = None
    # Fields present on failure
    error: str | None = None
    error_description: str | None = None
199
+
200
+
201
class TokenRequestRefresh(BaseModel):
    """Form fields for a refresh-token grant at the token endpoint."""

    # Fixed grant-type identifier; callers normally leave the default.
    grant_type: str = "refresh_token"
    refresh_token: str
    client_id: str
205
+
206
+
207
class OIDCClient(ClientContextManager):
    """HTTP client for talking directly to an OIDC provider.

    It has no base URL, so every method takes the absolute endpoint URL.
    """

    # Headers shared by all form-encoded POSTs to the provider.
    _FORM_HEADERS = {
        "Accept": "application/json",
        "Content-Type": "application/x-www-form-urlencoded",
    }
    # Timeout applied to every provider request.
    _TIMEOUT = 10.0

    def __init__(self) -> None:
        super().__init__(None)

    async def _post_form(self, url: str, form: BaseModel) -> httpx.Response:
        """POST ``form`` to ``url`` as application/x-www-form-urlencoded."""
        return await self.client.post(
            url,
            data=form.model_dump(),
            headers=self._FORM_HEADERS,
            timeout=self._TIMEOUT,
        )

    @staticmethod
    def _parse_token_response(resp: httpx.Response) -> TokenResponse:
        """Convert a token endpoint reply into a ``TokenResponse``.

        The status code is intentionally not checked: providers report
        errors such as a pending authorization in the JSON body of a
        non-2xx reply, and callers inspect those fields.
        """
        try:
            payload = resp.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError
            # subclasses; fall back to minimal error information.
            return TokenResponse(error="invalid_response", error_description=resp.text)
        return TokenResponse.model_validate(payload)

    async def fetch_provider_configuration(
        self, discovery_url: str
    ) -> OidcProviderConfiguration:
        """Download and validate the provider's discovery document.

        Raises whatever ``raise_for_status`` raises on a non-success response.
        """
        resp = await self.client.get(discovery_url, timeout=self._TIMEOUT)
        resp.raise_for_status()
        return OidcProviderConfiguration.model_validate(resp.json())

    async def device_authorization(
        self, device_endpoint: str, request: DeviceAuthorizationRequest
    ) -> DeviceAuthorizationResponse:
        """Start a device authorization flow and return the issued codes.

        Raises whatever ``raise_for_status`` raises on a non-success response.
        """
        resp = await self._post_form(device_endpoint, request)
        resp.raise_for_status()
        return DeviceAuthorizationResponse.model_validate(resp.json())

    async def token_with_device_code(
        self, token_endpoint: str, request: TokenRequestDeviceCode
    ) -> TokenResponse:
        """Exchange a device code for tokens.

        Does not raise for HTTP error statuses; callers inspect the error
        fields of the result while polling.
        """
        resp = await self._post_form(token_endpoint, request)
        return self._parse_token_response(resp)

    async def token_with_refresh(
        self, token_endpoint: str, request: TokenRequestRefresh
    ) -> TokenResponse:
        """Exchange a refresh token for new tokens.

        Does not raise for HTTP error statuses; failures are reported through
        the error fields of the result.
        """
        resp = await self._post_form(token_endpoint, request)
        return self._parse_token_response(resp)

    async def get_jwks(self, jwks_uri: str) -> JsonWebKeySet:
        """Download and validate the provider's JSON Web Key Set.

        Raises whatever ``raise_for_status`` raises on a non-success response.
        """
        resp = await self.client.get(jwks_uri, timeout=self._TIMEOUT)
        resp.raise_for_status()
        return JsonWebKeySet.model_validate(resp.json())
275
+
276
+
277
async def decode_jwt_claims_from_device_oidc(
    oidc_device: DeviceOIDC,
    verify_audience: bool = False,
    verify_expiration: bool = False,
    audience: str | None = None,
) -> dict[str, Any]:
    """Return the claims of the device's ID token, verified against the provider's JWKS.

    The provider is discovered from ``oidc_device.discovery_url`` and the
    token is decoded with the keys published at its ``jwks_uri``. RSA signing
    is assumed. Audience and expiration checks are off unless enabled; when
    the audience check is enabled, ``audience`` supplies the expected value.

    Raises:
        ValueError: If the device has no ID token or the provider publishes
            no ``jwks_uri``.
    """
    id_token = oidc_device.device_id_token
    if not id_token:
        raise ValueError("Device ID token is missing. Cannot decode claims.")

    async with OIDCClient() as oidc:
        provider = await oidc.fetch_provider_configuration(oidc_device.discovery_url)

    if not provider.jwks_uri:
        raise ValueError("Provider does not expose jwks_uri")

    return await decode_jwt_claims(
        id_token,
        provider.jwks_uri,
        verify_audience,
        verify_expiration,
        audience,
    )
302
+
303
+
304
async def decode_jwt_claims(
    token: str,
    jwks_uri: str,
    verify_audience: bool = False,
    verify_expiration: bool = False,
    audience: str | None = None,
) -> dict[str, Any]:
    """Verify ``token`` against the JWKS at ``jwks_uri`` and return its claims.

    Args:
        token: Encoded JWT to verify.
        jwks_uri: Absolute URL of the provider's JSON Web Key Set.
        verify_audience: Whether to check the ``aud`` claim.
        verify_expiration: Whether to check the ``exp`` claim.
        audience: Expected audience, used when ``verify_audience`` is true.

    Raises:
        ValueError: If the JWKS is empty or the selected key is not RSA.
        jwt.InvalidAlgorithmError: If the token header names a non-RSA
            signing algorithm.
        jwt.PyJWTError: If signature or claim verification fails.
    """
    async with OIDCClient() as oidc:
        jwks = await oidc.get_jwks(jwks_uri)

    # Select key: prefer the one whose kid matches the token header, else
    # fall back to the first published key (covers tokens without a kid).
    header = jwt.get_unverified_header(token)
    kid = header.get("kid")
    alg = header.get("alg", "RS256")
    keys = jwks.keys
    key = next((k for k in keys if k.kid == kid), None) or next(iter(keys), None)
    if not key:
        raise ValueError("Signing key not found in JWKS")

    # Build public key (RSA-only)
    if key.kty != "RSA":
        raise ValueError("Unsupported JWK kty; only RSA is supported")

    # The header is attacker-controlled until the signature is verified, so
    # never let it pick an arbitrary algorithm: accept RSA-based ones only.
    allowed_algorithms = ("RS256", "RS384", "RS512", "PS256", "PS384", "PS512")
    if alg not in allowed_algorithms:
        raise jwt.InvalidAlgorithmError(
            f"Unsupported JWT alg {alg!r}; only RSA algorithms are supported"
        )

    # Drop unset members so the JWK parser is not handed JSON nulls.
    key_json = key.model_dump_json(exclude_none=True)
    public_key = RSAAlgorithm.from_jwk(key_json)

    return jwt.decode(
        token,
        public_key,
        algorithms=[alg],
        options={"verify_aud": verify_audience, "verify_exp": verify_expiration},
        audience=audience,
    )
336
+
337
+
338
async def refresh(device_oidc: DeviceOIDC) -> DeviceOIDC:
    """
    Run a refresh on the access token, storing updated tokens in a new DeviceOIDC.

    The input object is not modified. The refresh and ID tokens are carried
    over from the input when the provider does not return new ones.

    Raises:
        ValueError: If the device has no refresh token, the provider
            publishes no ``token_endpoint``, or the token response contains
            no access token.
    """
    # Fail fast before any network round trip.
    if not device_oidc.device_refresh_token:
        raise ValueError("Device refresh token is missing. Cannot refresh.")

    async with OIDCClient() as oidc:
        provider = await oidc.fetch_provider_configuration(device_oidc.discovery_url)
        token_endpoint = provider.token_endpoint
        if not token_endpoint:
            raise ValueError("Provider does not expose token_endpoint")
        token = await oidc.token_with_refresh(
            token_endpoint,
            TokenRequestRefresh(
                refresh_token=device_oidc.device_refresh_token,
                client_id=device_oidc.client_id,
            ),
        )

    if not token.access_token:
        message = "Refresh failed: token response missing access_token"
        # Surface the provider's own error so the failure can be diagnosed.
        if token.error:
            detail = token.error
            if token.error_description:
                detail = f"{detail}: {token.error_description}"
            message = f"{message} ({detail})"
        raise ValueError(message)

    updated = device_oidc.model_copy()
    updated.device_access_token = token.access_token
    updated.device_refresh_token = token.refresh_token or updated.device_refresh_token
    updated.device_id_token = token.id_token or updated.device_id_token
    return updated
@@ -7,26 +7,35 @@ from rich import print as rprint
7
7
 
8
8
 
9
9
def get_control_plane_client() -> ControlPlaneClient:
    """Build a control plane client for the current profile or environment.

    With a profile configured, the client uses the profile's API URL and API
    key together with the auth service's middleware. Without one, it falls
    back to an unauthenticated client for the current environment's API URL.
    """
    auth_svc = service.current_auth_service()
    # Reuse the service fetched above rather than looking it up a second time.
    profile = auth_svc.get_current_profile()
    if profile:
        resolved_base_url = profile.api_url.rstrip("/")
        resolved_api_key = profile.api_key
        return ControlPlaneClient(
            resolved_base_url, resolved_api_key, auth_svc.auth_middleware()
        )

    # Fallback: allow env-scoped client construction for env operations
    env = service.get_current_environment()
    resolved_base_url = env.api_url.rstrip("/")
    return ControlPlaneClient(resolved_base_url)
20
23
 
21
24
 
22
25
def get_project_client() -> ProjectClient:
    """Build a project client for the current profile.

    If no profile is configured, print a hint naming the command that
    creates one (``auth login`` when the environment requires auth,
    ``auth token`` otherwise) and exit with status 1.
    """
    auth_svc = service.current_auth_service()
    profile = auth_svc.get_current_profile()
    if profile:
        return ProjectClient(
            profile.api_url,
            profile.project_id,
            profile.api_key,
            auth_svc.auth_middleware(),
        )

    rprint("\n[bold red]No profile configured![/bold red]")
    rprint("\nTo get started, create a profile with:")
    hint = (
        "[cyan]llamactl auth login[/cyan]"
        if auth_svc.env.requires_auth
        else "[cyan]llamactl auth token[/cyan]"
    )
    rprint(hint)
    raise SystemExit(1)
30
39
 
31
40
 
32
41
  @asynccontextmanager