kscale 0.3.15.tar.gz → 0.3.17.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {kscale-0.3.15/kscale.egg-info → kscale-0.3.17}/PKG-INFO +1 -1
  2. {kscale-0.3.15 → kscale-0.3.17}/kscale/__init__.py +1 -1
  3. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/cli/robot_class.py +4 -4
  4. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/cli/user.py +0 -9
  5. kscale-0.3.17/kscale/web/clients/base.py +124 -0
  6. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/clients/user.py +0 -4
  7. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/gen/api.py +78 -2
  8. {kscale-0.3.15 → kscale-0.3.17/kscale.egg-info}/PKG-INFO +1 -1
  9. kscale-0.3.15/kscale/web/clients/base.py +0 -423
  10. {kscale-0.3.15 → kscale-0.3.17}/LICENSE +0 -0
  11. {kscale-0.3.15 → kscale-0.3.17}/MANIFEST.in +0 -0
  12. {kscale-0.3.15 → kscale-0.3.17}/README.md +0 -0
  13. {kscale-0.3.15 → kscale-0.3.17}/kscale/artifacts/__init__.py +0 -0
  14. {kscale-0.3.15 → kscale-0.3.17}/kscale/artifacts/plane.obj +0 -0
  15. {kscale-0.3.15 → kscale-0.3.17}/kscale/artifacts/plane.urdf +0 -0
  16. {kscale-0.3.15 → kscale-0.3.17}/kscale/cli.py +0 -0
  17. {kscale-0.3.15 → kscale-0.3.17}/kscale/conf.py +0 -0
  18. {kscale-0.3.15 → kscale-0.3.17}/kscale/py.typed +0 -0
  19. {kscale-0.3.15 → kscale-0.3.17}/kscale/requirements-dev.txt +0 -0
  20. {kscale-0.3.15 → kscale-0.3.17}/kscale/requirements.txt +0 -0
  21. {kscale-0.3.15 → kscale-0.3.17}/kscale/utils/__init__.py +0 -0
  22. {kscale-0.3.15 → kscale-0.3.17}/kscale/utils/api_base.py +0 -0
  23. {kscale-0.3.15 → kscale-0.3.17}/kscale/utils/checksum.py +0 -0
  24. {kscale-0.3.15 → kscale-0.3.17}/kscale/utils/cli.py +0 -0
  25. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/__init__.py +0 -0
  26. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/cli/__init__.py +0 -0
  27. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/cli/robot.py +0 -0
  28. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/clients/__init__.py +0 -0
  29. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/clients/client.py +0 -0
  30. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/clients/robot.py +0 -0
  31. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/clients/robot_class.py +0 -0
  32. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/gen/__init__.py +0 -0
  33. {kscale-0.3.15 → kscale-0.3.17}/kscale/web/utils.py +0 -0
  34. {kscale-0.3.15 → kscale-0.3.17}/kscale.egg-info/SOURCES.txt +0 -0
  35. {kscale-0.3.15 → kscale-0.3.17}/kscale.egg-info/dependency_links.txt +0 -0
  36. {kscale-0.3.15 → kscale-0.3.17}/kscale.egg-info/entry_points.txt +0 -0
  37. {kscale-0.3.15 → kscale-0.3.17}/kscale.egg-info/not-zip-safe +0 -0
  38. {kscale-0.3.15 → kscale-0.3.17}/kscale.egg-info/requires.txt +0 -0
  39. {kscale-0.3.15 → kscale-0.3.17}/kscale.egg-info/top_level.txt +0 -0
  40. {kscale-0.3.15 → kscale-0.3.17}/pyproject.toml +0 -0
  41. {kscale-0.3.15 → kscale-0.3.17}/setup.cfg +0 -0
  42. {kscale-0.3.15 → kscale-0.3.17}/setup.py +0 -0
  43. {kscale-0.3.15 → kscale-0.3.17}/tests/test_dummy.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kscale
3
- Version: 0.3.15
3
+ Version: 0.3.17
4
4
  Summary: The kscale project
5
5
  Home-page: https://github.com/kscalelabs/kscale
6
6
  Author: Benjamin Bolte
@@ -1,6 +1,6 @@
1
1
  """Defines the common interface for the K-Scale Python API."""
2
2
 
3
- __version__ = "0.3.15"
3
+ __version__ = "0.3.17"
4
4
 
5
5
  from pathlib import Path
6
6
 
@@ -192,7 +192,7 @@ async def run_pybullet(
192
192
  ) -> None:
193
193
  """Shows the URDF file for a robot class in PyBullet."""
194
194
  try:
195
- import pybullet as p
195
+ import pybullet as p # noqa: PLC0415
196
196
  except ImportError:
197
197
  click.echo(click.style("PyBullet is not installed; install it with `pip install pybullet`", fg="red"))
198
198
  return
@@ -453,14 +453,14 @@ async def run_mujoco(class_name: str, scene: str, no_cache: bool) -> None:
453
453
  launches the Mujoco viewer using the provided MJCF file.
454
454
  """
455
455
  try:
456
- from mujoco_scenes.errors import TemplateNotFoundError
457
- from mujoco_scenes.mjcf import list_scenes, load_mjmodel
456
+ from mujoco_scenes.errors import TemplateNotFoundError # noqa: PLC0415
457
+ from mujoco_scenes.mjcf import list_scenes, load_mjmodel # noqa: PLC0415
458
458
  except ImportError:
459
459
  click.echo(click.style("Mujoco Scenes is required; install with `pip install mujoco-scenes`", fg="red"))
460
460
  return
461
461
 
462
462
  try:
463
- import mujoco.viewer
463
+ import mujoco.viewer # noqa: PLC0415
464
464
  except ImportError:
465
465
  click.echo(click.style("Mujoco is required; install with `pip install mujoco`", fg="red"))
466
466
  return
@@ -39,14 +39,5 @@ async def me() -> None:
39
39
  )
40
40
 
41
41
 
42
- @cli.command()
43
- @coro
44
- async def key() -> None:
45
- """Get an API key for the currently-authenticated user."""
46
- client = UserClient()
47
- api_key = await client.get_api_key()
48
- click.echo(f"API key: {click.style(api_key, fg='green')}")
49
-
50
-
51
42
  if __name__ == "__main__":
52
43
  cli()
@@ -0,0 +1,124 @@
1
+ """Defines a base client for the K-Scale WWW API client."""
2
+
3
+ import logging
4
+ import os
5
+ import sys
6
+ from types import TracebackType
7
+ from typing import Any, Mapping, Self, Type
8
+ from urllib.parse import urljoin
9
+
10
+ import httpx
11
+ from pydantic import BaseModel
12
+
13
+ from kscale.web.utils import DEFAULT_UPLOAD_TIMEOUT, get_api_root
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # This is the name of the API key header for the K-Scale WWW API.
18
+ HEADER_NAME = "x-kscale-api-key"
19
+
20
+
21
+ def verbose_error() -> bool:
22
+ return os.environ.get("KSCALE_VERBOSE_ERROR", "0") == "1"
23
+
24
+
25
+ class BaseClient:
26
+ def __init__(
27
+ self,
28
+ base_url: str | None = None,
29
+ upload_timeout: float = DEFAULT_UPLOAD_TIMEOUT,
30
+ use_cache: bool = True,
31
+ ) -> None:
32
+ self.base_url = get_api_root() if base_url is None else base_url
33
+ self.upload_timeout = upload_timeout
34
+ self.use_cache = use_cache
35
+ self._client: httpx.AsyncClient | None = None
36
+ self._client_no_auth: httpx.AsyncClient | None = None
37
+
38
+ async def get_client(self, *, auth: bool = True) -> httpx.AsyncClient:
39
+ client = self._client if auth else self._client_no_auth
40
+ if client is None:
41
+ headers: dict[str, str] = {}
42
+ if auth:
43
+ if "KSCALE_API_KEY" not in os.environ:
44
+ raise ValueError("KSCALE_API_KEY is not set! Obtain one here: https://kscale.dev/dashboard/keys")
45
+ headers[HEADER_NAME] = os.environ["KSCALE_API_KEY"]
46
+
47
+ client = httpx.AsyncClient(
48
+ base_url=self.base_url,
49
+ headers=headers,
50
+ timeout=httpx.Timeout(30.0),
51
+ )
52
+ if auth:
53
+ self._client = client
54
+ else:
55
+ self._client_no_auth = client
56
+ return client
57
+
58
+ async def _request(
59
+ self,
60
+ method: str,
61
+ endpoint: str,
62
+ *,
63
+ auth: bool = True,
64
+ params: dict[str, Any] | None = None,
65
+ data: BaseModel | dict[str, Any] | None = None,
66
+ files: dict[str, Any] | None = None,
67
+ error_code_suggestions: dict[int, str] | None = None,
68
+ ) -> dict[str, Any]:
69
+ url = urljoin(self.base_url, endpoint)
70
+ kwargs: dict[str, Any] = {}
71
+ if params is not None:
72
+ kwargs["params"] = params
73
+ if data is not None:
74
+ if isinstance(data, BaseModel):
75
+ kwargs["json"] = data.model_dump(exclude_unset=True)
76
+ else:
77
+ kwargs["json"] = data
78
+ if files:
79
+ kwargs["files"] = files
80
+
81
+ client = await self.get_client(auth=auth)
82
+ response = await client.request(method, url, **kwargs)
83
+
84
+ if response.is_error:
85
+ error_code = response.status_code
86
+ error_json = response.json()
87
+ use_verbose_error = verbose_error()
88
+
89
+ if not use_verbose_error:
90
+ logger.info("Use KSCALE_VERBOSE_ERROR=1 to see the full error message")
91
+ logger.info("If this persists, please create an issue here: https://github.com/kscalelabs/kscale")
92
+
93
+ logger.error("Got error %d from the K-Scale API", error_code)
94
+ if isinstance(error_json, Mapping):
95
+ for key, value in error_json.items():
96
+ logger.error(" [%s] %s", key, value)
97
+ else:
98
+ logger.error(" %s", error_json)
99
+
100
+ if error_code_suggestions is not None and error_code in error_code_suggestions:
101
+ logger.error("Hint: %s", error_code_suggestions[error_code])
102
+
103
+ if use_verbose_error:
104
+ response.raise_for_status()
105
+ else:
106
+ sys.exit(1)
107
+
108
+ return response.json()
109
+
110
+ async def close(self) -> None:
111
+ if self._client is not None:
112
+ await self._client.aclose()
113
+ self._client = None
114
+
115
+ async def __aenter__(self) -> Self:
116
+ return self
117
+
118
+ async def __aexit__(
119
+ self,
120
+ exc_type: Type[BaseException] | None,
121
+ exc_val: BaseException | None,
122
+ exc_tb: TracebackType | None,
123
+ ) -> None:
124
+ await self.close()
@@ -8,7 +8,3 @@ class UserClient(BaseClient):
8
8
  async def get_profile_info(self) -> ProfileResponse:
9
9
  data = await self._request("GET", "/auth/profile", auth=True)
10
10
  return ProfileResponse(**data)
11
-
12
- async def get_api_key(self, num_hours: int = 24) -> str:
13
- data = await self._request("POST", "/auth/key", auth=True, data={"num_hours": num_hours})
14
- return data["api_key"]
@@ -2,15 +2,30 @@
2
2
 
3
3
  # generated by datamodel-codegen:
4
4
  # filename: openapi.json
5
- # timestamp: 2025-05-04T22:45:26+00:00
5
+ # timestamp: 2025-07-10T20:02:21+00:00
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
+ from datetime import datetime
9
10
  from typing import Dict, List, Optional, Union
10
11
 
11
12
  from pydantic import BaseModel, Field
12
13
 
13
14
 
15
+ class APIKey(BaseModel):
16
+ id: str = Field(..., title="Id")
17
+ user_id: str = Field(..., title="User Id")
18
+ name: str = Field(..., title="Name")
19
+ key_hash: str = Field(..., title="Key Hash")
20
+ permissions: List[str] = Field(..., title="Permissions")
21
+ email: str = Field(..., title="Email")
22
+ email_verified: bool = Field(..., title="Email Verified")
23
+ created_at: datetime = Field(..., title="Created At")
24
+ last_used_at: Optional[datetime] = Field(None, title="Last Used At")
25
+ expires_at: Optional[datetime] = Field(None, title="Expires At")
26
+ is_active: Optional[bool] = Field(True, title="Is Active")
27
+
28
+
14
29
  class APIKeyRequest(BaseModel):
15
30
  num_hours: Optional[int] = Field(24, title="Num Hours")
16
31
 
@@ -19,6 +34,16 @@ class APIKeyResponse(BaseModel):
19
34
  api_key: str = Field(..., title="Api Key")
20
35
 
21
36
 
37
+ class APIKeySummaryResponse(BaseModel):
38
+ id: str = Field(..., title="Id")
39
+ name: str = Field(..., title="Name")
40
+ permissions: List[str] = Field(..., title="Permissions")
41
+ created_at: datetime = Field(..., title="Created At")
42
+ last_used_at: Optional[datetime] = Field(..., title="Last Used At")
43
+ expires_at: Optional[datetime] = Field(..., title="Expires At")
44
+ is_active: bool = Field(..., title="Is Active")
45
+
46
+
22
47
  class ActuatorMetadataInput(BaseModel):
23
48
  actuator_type: Optional[str] = Field(None, title="Actuator Type")
24
49
  sysid: Optional[str] = Field(None, title="Sysid")
@@ -62,6 +87,45 @@ class AddRobotRequest(BaseModel):
62
87
  class_name: str = Field(..., title="Class Name")
63
88
 
64
89
 
90
+ class Agent(BaseModel):
91
+ id: str = Field(..., title="Id")
92
+ user_id: str = Field(..., title="User Id")
93
+ upload_time: str = Field(..., title="Upload Time")
94
+ description: Optional[str] = Field(None, title="Description")
95
+ num_downloads: Optional[int] = Field(0, title="Num Downloads")
96
+
97
+
98
+ class AgentDownloadResponse(BaseModel):
99
+ url: str = Field(..., title="Url")
100
+ md5_hash: str = Field(..., title="Md5 Hash")
101
+
102
+
103
+ class AgentUploadRequest(BaseModel):
104
+ filename: str = Field(..., title="Filename")
105
+ content_type: str = Field(..., title="Content Type")
106
+
107
+
108
+ class AgentUploadResponse(BaseModel):
109
+ url: str = Field(..., title="Url")
110
+ filename: str = Field(..., title="Filename")
111
+ content_type: str = Field(..., title="Content Type")
112
+
113
+
114
+ class CreateAPIKeyRequest(BaseModel):
115
+ name: str = Field(..., title="Name")
116
+ permissions: List[str] = Field(..., title="Permissions")
117
+ expires_at: Optional[datetime] = Field(None, title="Expires At")
118
+
119
+
120
+ class CreateAPIKeyResponse(BaseModel):
121
+ api_key: APIKey
122
+ plain_key: str = Field(..., title="Plain Key")
123
+
124
+
125
+ class CreateAgentRequest(BaseModel):
126
+ description: Optional[str] = Field(None, title="Description")
127
+
128
+
65
129
  class JointMetadataInput(BaseModel):
66
130
  id: Optional[int] = Field(None, title="Id")
67
131
  kp: Optional[Union[float, str]] = Field(None, title="Kp")
@@ -73,6 +137,8 @@ class JointMetadataInput(BaseModel):
73
137
  actuator_type: Optional[str] = Field(None, title="Actuator Type")
74
138
  nn_id: Optional[int] = Field(None, title="Nn Id")
75
139
  soft_torque_limit: Optional[Union[float, str]] = Field(None, title="Soft Torque Limit")
140
+ min_angle_deg: Optional[Union[float, str]] = Field(None, title="Min Angle Deg")
141
+ max_angle_deg: Optional[Union[float, str]] = Field(None, title="Max Angle Deg")
76
142
 
77
143
 
78
144
  class JointMetadataOutput(BaseModel):
@@ -86,9 +152,11 @@ class JointMetadataOutput(BaseModel):
86
152
  actuator_type: Optional[str] = Field(None, title="Actuator Type")
87
153
  nn_id: Optional[int] = Field(None, title="Nn Id")
88
154
  soft_torque_limit: Optional[str] = Field(None, title="Soft Torque Limit")
155
+ min_angle_deg: Optional[str] = Field(None, title="Min Angle Deg")
156
+ max_angle_deg: Optional[str] = Field(None, title="Max Angle Deg")
89
157
 
90
158
 
91
- class OICDInfo(BaseModel):
159
+ class OIDCInfo(BaseModel):
92
160
  authority: str = Field(..., title="Authority")
93
161
  client_id: str = Field(..., title="Client Id")
94
162
 
@@ -141,6 +209,14 @@ class RobotUploadURDFResponse(BaseModel):
141
209
  content_type: str = Field(..., title="Content Type")
142
210
 
143
211
 
212
+ class UpdateAPIKeyPermissionsRequest(BaseModel):
213
+ permissions: List[str] = Field(..., title="Permissions")
214
+
215
+
216
+ class UpdateAgentRequest(BaseModel):
217
+ description: Optional[str] = Field(None, title="Description")
218
+
219
+
144
220
  class UpdateRobotClassRequest(BaseModel):
145
221
  new_class_name: Optional[str] = Field(None, title="New Class Name")
146
222
  new_description: Optional[str] = Field(None, title="New Description")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kscale
3
- Version: 0.3.15
3
+ Version: 0.3.17
4
4
  Summary: The kscale project
5
5
  Home-page: https://github.com/kscalelabs/kscale
6
6
  Author: Benjamin Bolte
@@ -1,423 +0,0 @@
1
- """Defines a base client for the K-Scale WWW API client."""
2
-
3
- import asyncio
4
- import json
5
- import logging
6
- import os
7
- import secrets
8
- import sys
9
- import time
10
- import webbrowser
11
- from types import TracebackType
12
- from typing import Any, Mapping, Self, Type
13
- from urllib.parse import urljoin
14
-
15
- import aiohttp
16
- import httpx
17
- from aiohttp import web
18
- from async_lru import alru_cache
19
- from jwt import ExpiredSignatureError, PyJWKClient, decode as jwt_decode
20
- from pydantic import BaseModel
21
- from yarl import URL
22
-
23
- from kscale.web.gen.api import OICDInfo
24
- from kscale.web.utils import DEFAULT_UPLOAD_TIMEOUT, get_api_root, get_auth_dir
25
-
26
- logger = logging.getLogger(__name__)
27
-
28
- # This port matches the available port for the OAuth callback.
29
- OAUTH_PORT = 16821
30
-
31
- # This is the name of the API key header for the K-Scale WWW API.
32
- HEADER_NAME = "x-kscale-api-key"
33
-
34
-
35
- def verbose_error() -> bool:
36
- return os.environ.get("KSCALE_VERBOSE_ERROR", "0") == "1"
37
-
38
-
39
- class OAuthCallback:
40
- def __init__(self) -> None:
41
- self.token_type: str | None = None
42
- self.access_token: str | None = None
43
- self.id_token: str | None = None
44
- self.state: str | None = None
45
- self.expires_in: str | None = None
46
- self.app = web.Application()
47
- self.app.router.add_get("/token", self.handle_token)
48
- self.app.router.add_get("/callback", self.handle_callback)
49
-
50
- async def handle_token(self, request: web.Request) -> web.Response:
51
- """Handle the token extraction."""
52
- self.token_type = request.query.get("token_type")
53
- self.access_token = request.query.get("access_token")
54
- self.id_token = request.query.get("id_token")
55
- self.state = request.query.get("state")
56
- self.expires_in = request.query.get("expires_in")
57
- return web.Response(text="OK")
58
-
59
- async def handle_callback(self, request: web.Request) -> web.Response:
60
- """Handle the OAuth callback with token in URL fragment."""
61
- return web.Response(
62
- text="""
63
- <!DOCTYPE html>
64
- <html lang="en">
65
- <head>
66
- <meta charset="UTF-8">
67
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
68
- <title>Authentication successful</title>
69
- <style>
70
- body {
71
- font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
72
- display: flex;
73
- justify-content: center;
74
- align-items: center;
75
- min-height: 100vh;
76
- margin: 0;
77
- background: #f5f5f5;
78
- color: #333;
79
- }
80
- .container {
81
- background: white;
82
- padding: 2rem;
83
- border-radius: 8px;
84
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
85
- max-width: 600px;
86
- width: 90%;
87
- }
88
- h1 {
89
- color: #2c3e50;
90
- margin-bottom: 1rem;
91
- }
92
- .token-info {
93
- background: #f8f9fa;
94
- border: 1px solid #dee2e6;
95
- border-radius: 4px;
96
- padding: 1rem;
97
- margin: 1rem 0;
98
- word-break: break-all;
99
- }
100
- .token-label {
101
- font-weight: bold;
102
- color: #6c757d;
103
- margin-bottom: 0.5rem;
104
- }
105
- .success-icon {
106
- color: #28a745;
107
- font-size: 48px;
108
- margin-bottom: 1rem;
109
- }
110
- </style>
111
- </head>
112
- <body>
113
- <div class="container">
114
- <div class="success-icon">✓</div>
115
- <h1>Authentication successful!</h1>
116
- <p>Your authentication tokens are shown below. You can now close this window.</p>
117
-
118
- <div class="token-info">
119
- <div class="token-label">Access Token:</div>
120
- <div id="accessTokenDisplay"></div>
121
- </div>
122
-
123
- <div class="token-info">
124
- <div class="token-label">ID Token:</div>
125
- <div id="idTokenDisplay"></div>
126
- </div>
127
- </div>
128
-
129
- <script>
130
- const params = new URLSearchParams(window.location.hash.substring(1));
131
- const tokenType = params.get('token_type');
132
- const accessToken = params.get('access_token');
133
- const idToken = params.get('id_token');
134
- const state = params.get('state');
135
- const expiresIn = params.get('expires_in');
136
-
137
- // Display tokens
138
- document.getElementById('accessTokenDisplay').textContent = accessToken || 'Not provided';
139
- document.getElementById('idTokenDisplay').textContent = idToken || 'Not provided';
140
-
141
- if (accessToken) {
142
- const tokenUrl = new URL(window.location.href);
143
- tokenUrl.pathname = '/token';
144
- tokenUrl.searchParams.set('access_token', accessToken);
145
- tokenUrl.searchParams.set('token_type', tokenType);
146
- tokenUrl.searchParams.set('id_token', idToken);
147
- tokenUrl.searchParams.set('state', state);
148
- tokenUrl.searchParams.set('expires_in', expiresIn);
149
- fetch(tokenUrl.toString());
150
- }
151
- </script>
152
- </body>
153
- </html>
154
- """,
155
- content_type="text/html",
156
- )
157
-
158
-
159
- class BaseClient:
160
- def __init__(
161
- self,
162
- base_url: str | None = None,
163
- upload_timeout: float = DEFAULT_UPLOAD_TIMEOUT,
164
- use_cache: bool = True,
165
- ) -> None:
166
- self.base_url = get_api_root() if base_url is None else base_url
167
- self.upload_timeout = upload_timeout
168
- self.use_cache = use_cache
169
- self._client: httpx.AsyncClient | None = None
170
- self._client_no_auth: httpx.AsyncClient | None = None
171
-
172
- @alru_cache
173
- async def _get_oicd_info(self) -> OICDInfo:
174
- cache_path = get_auth_dir() / "oicd_info.json"
175
- if self.use_cache and cache_path.exists():
176
- with open(cache_path, "r") as f:
177
- return OICDInfo(**json.load(f))
178
- data = await self._request("GET", "/auth/oicd", auth=False)
179
- if self.use_cache:
180
- cache_path.parent.mkdir(parents=True, exist_ok=True)
181
- with open(cache_path, "w") as f:
182
- json.dump(data, f)
183
- return OICDInfo(**data)
184
-
185
- @alru_cache
186
- async def _get_oicd_metadata(self) -> dict:
187
- """Returns the OpenID Connect server configuration.
188
-
189
- Returns:
190
- The OpenID Connect server configuration.
191
- """
192
- cache_path = get_auth_dir() / "oicd_metadata.json"
193
- if self.use_cache and cache_path.exists():
194
- with open(cache_path, "r") as f:
195
- return json.load(f)
196
- oicd_info = await self._get_oicd_info()
197
- oicd_config_url = f"{oicd_info.authority}/.well-known/openid-configuration"
198
- async with aiohttp.ClientSession() as session:
199
- async with session.get(oicd_config_url) as response:
200
- metadata = await response.json()
201
- if self.use_cache:
202
- cache_path.parent.mkdir(parents=True, exist_ok=True)
203
- with open(cache_path, "w") as f:
204
- json.dump(metadata, f, indent=2)
205
- logger.info("Cached OpenID Connect metadata to %s", cache_path)
206
- return metadata
207
-
208
- async def _get_bearer_token(self) -> str:
209
- """Get a bearer token using the OAuth2 implicit flow.
210
-
211
- Returns:
212
- A bearer token to use with the K-Scale WWW API.
213
- """
214
- # Check if we are in a headless environment.
215
- error_message = (
216
- "Cannot perform browser-based authentication in a headless environment. "
217
- "Please use 'kscale user key' to generate an API key locally and set "
218
- "the KSCALE_API_KEY environment variable instead."
219
- )
220
- try:
221
- if not webbrowser.get().name != "null":
222
- raise RuntimeError(error_message)
223
- except webbrowser.Error:
224
- raise RuntimeError(error_message)
225
-
226
- oicd_info = await self._get_oicd_info()
227
- metadata = await self._get_oicd_metadata()
228
- auth_endpoint = metadata["authorization_endpoint"]
229
-
230
- # Use the cached state and nonce if available, otherwise generate.
231
- state_file = get_auth_dir() / "oauth_state.json"
232
- state: str | None = None
233
- nonce: str | None = None
234
- if state_file.exists():
235
- with open(state_file, "r") as f:
236
- state_data = json.load(f)
237
- state = state_data.get("state")
238
- nonce = state_data.get("nonce")
239
- if state is None:
240
- state = secrets.token_urlsafe(32)
241
- if nonce is None:
242
- nonce = secrets.token_urlsafe(32)
243
-
244
- # Change /oauth2/authorize to /login to use the login endpoint.
245
- auth_endpoint = auth_endpoint.replace("/oauth2/authorize", "/login")
246
-
247
- auth_url = str(
248
- URL(auth_endpoint).with_query(
249
- {
250
- "response_type": "token",
251
- "redirect_uri": f"http://localhost:{OAUTH_PORT}/callback",
252
- "state": state,
253
- "nonce": nonce,
254
- "scope": "openid profile email",
255
- "client_id": oicd_info.client_id,
256
- }
257
- )
258
- )
259
-
260
- # Start local server to receive callback
261
- callback_handler = OAuthCallback()
262
- runner = web.AppRunner(callback_handler.app)
263
- await runner.setup()
264
- site = web.TCPSite(runner, "localhost", OAUTH_PORT)
265
-
266
- try:
267
- await site.start()
268
- except OSError as e:
269
- raise OSError(
270
- f"The command line interface requires access to local port {OAUTH_PORT} in order to authenticate with "
271
- "OpenID Connect. Please ensure that no other application is using this port."
272
- ) from e
273
-
274
- # Open browser for user authentication
275
- webbrowser.open(auth_url)
276
-
277
- # Wait for the callback with timeout
278
- try:
279
- start_time = time.time()
280
- while callback_handler.access_token is None:
281
- if time.time() - start_time > 30:
282
- raise TimeoutError("Authentication timed out after 30 seconds")
283
- await asyncio.sleep(0.1)
284
-
285
- # Save the state and nonce to the cache.
286
- state = callback_handler.state
287
- state_file.parent.mkdir(parents=True, exist_ok=True)
288
- state_file.write_text(json.dumps({"state": state, "nonce": nonce}))
289
-
290
- return callback_handler.access_token
291
- finally:
292
- await runner.cleanup()
293
-
294
- @alru_cache
295
- async def _get_jwk_client(self) -> PyJWKClient:
296
- """Returns a JWK client for the OpenID Connect server."""
297
- oicd_info = await self._get_oicd_info()
298
- jwks_uri = f"{oicd_info.authority}/.well-known/jwks.json"
299
- return PyJWKClient(uri=jwks_uri)
300
-
301
- async def _is_token_expired(self, token: str) -> bool:
302
- """Check if a token is expired."""
303
- jwk_client = await self._get_jwk_client()
304
- signing_key = jwk_client.get_signing_key_from_jwt(token)
305
-
306
- try:
307
- claims = jwt_decode(
308
- token,
309
- signing_key.key,
310
- algorithms=["RS256"],
311
- options={"verify_aud": False},
312
- )
313
- except ExpiredSignatureError:
314
- return True
315
-
316
- return claims["exp"] < time.time()
317
-
318
- @alru_cache
319
- async def get_bearer_token(self) -> str:
320
- """Get a bearer token from OpenID Connect.
321
-
322
- Returns:
323
- A bearer token to use with the K-Scale WWW API.
324
- """
325
- cache_path = get_auth_dir() / "bearer_token.txt"
326
- if self.use_cache and cache_path.exists():
327
- token = cache_path.read_text()
328
- if not await self._is_token_expired(token):
329
- return token
330
- token = await self._get_bearer_token()
331
- if self.use_cache:
332
- cache_path.write_text(token)
333
- cache_path.chmod(0o600)
334
- return token
335
-
336
- async def get_client(self, *, auth: bool = True) -> httpx.AsyncClient:
337
- client = self._client if auth else self._client_no_auth
338
- if client is None:
339
- headers: dict[str, str] = {}
340
- if auth:
341
- if "KSCALE_API_KEY" in os.environ:
342
- headers[HEADER_NAME] = os.environ["KSCALE_API_KEY"]
343
- else:
344
- headers["Authorization"] = f"Bearer {await self.get_bearer_token()}"
345
-
346
- client = httpx.AsyncClient(
347
- base_url=self.base_url,
348
- headers=headers,
349
- timeout=httpx.Timeout(30.0),
350
- )
351
- if auth:
352
- self._client = client
353
- else:
354
- self._client_no_auth = client
355
- return client
356
-
357
- async def _request(
358
- self,
359
- method: str,
360
- endpoint: str,
361
- *,
362
- auth: bool = True,
363
- params: dict[str, Any] | None = None,
364
- data: BaseModel | dict[str, Any] | None = None,
365
- files: dict[str, Any] | None = None,
366
- error_code_suggestions: dict[int, str] | None = None,
367
- ) -> dict[str, Any]:
368
- url = urljoin(self.base_url, endpoint)
369
- kwargs: dict[str, Any] = {}
370
- if params is not None:
371
- kwargs["params"] = params
372
- if data is not None:
373
- if isinstance(data, BaseModel):
374
- kwargs["json"] = data.model_dump(exclude_unset=True)
375
- else:
376
- kwargs["json"] = data
377
- if files:
378
- kwargs["files"] = files
379
-
380
- client = await self.get_client(auth=auth)
381
- response = await client.request(method, url, **kwargs)
382
-
383
- if response.is_error:
384
- error_code = response.status_code
385
- error_json = response.json()
386
- use_verbose_error = verbose_error()
387
-
388
- if not use_verbose_error:
389
- logger.info("Use KSCALE_VERBOSE_ERROR=1 to see the full error message")
390
- logger.info("If this persists, please create an issue here: https://github.com/kscalelabs/kscale")
391
-
392
- logger.error("Got error %d from the K-Scale API", error_code)
393
- if isinstance(error_json, Mapping):
394
- for key, value in error_json.items():
395
- logger.error(" [%s] %s", key, value)
396
- else:
397
- logger.error(" %s", error_json)
398
-
399
- if error_code_suggestions is not None and error_code in error_code_suggestions:
400
- logger.error("Hint: %s", error_code_suggestions[error_code])
401
-
402
- if use_verbose_error:
403
- response.raise_for_status()
404
- else:
405
- sys.exit(1)
406
-
407
- return response.json()
408
-
409
- async def close(self) -> None:
410
- if self._client is not None:
411
- await self._client.aclose()
412
- self._client = None
413
-
414
- async def __aenter__(self) -> Self:
415
- return self
416
-
417
- async def __aexit__(
418
- self,
419
- exc_type: Type[BaseException] | None,
420
- exc_val: BaseException | None,
421
- exc_tb: TracebackType | None,
422
- ) -> None:
423
- await self.close()
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes