kscale 0.0.11__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,314 @@
1
+ """Defines a base client for the K-Scale WWW API client."""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import secrets
7
+ import time
8
+ import webbrowser
9
+ from types import TracebackType
10
+ from typing import Any, Self, Type
11
+ from urllib.parse import urljoin
12
+
13
+ import aiohttp
14
+ import httpx
15
+ from aiohttp import web
16
+ from async_lru import alru_cache
17
+ from jwt import ExpiredSignatureError, PyJWKClient, decode as jwt_decode
18
+ from pydantic import BaseModel
19
+ from yarl import URL
20
+
21
+ from kscale.web.gen.api import OICDInfo
22
+ from kscale.web.utils import DEFAULT_UPLOAD_TIMEOUT, get_api_root, get_cache_dir
23
+
24
logger: logging.Logger = logging.getLogger(__name__)

# Fixed local port for the OAuth redirect listener (http://localhost:<port>/callback).
# NOTE(review): presumably fixed because the redirect URI must match one registered
# with the identity provider — confirm before changing it.
OAUTH_PORT: int = 16821
28
+
29
+
30
class OAuthCallback:
    """Minimal local web app that receives the OAuth implicit-flow redirect.

    The identity provider redirects the browser to ``/callback`` with the access
    token in the URL fragment. Browsers never send the fragment to the server,
    so :meth:`handle_callback` serves a page whose script reads the fragment and
    forwards the token to ``/token``. :meth:`handle_token` then stores it on
    ``access_token``, which the caller polls.
    """

    def __init__(self) -> None:
        # Set once the callback page forwards a token; ``None`` until then.
        self.access_token: str | None = None
        self.app = web.Application()
        self.app.router.add_get("/token", self.handle_token)
        self.app.router.add_get("/callback", self.handle_callback)

    async def handle_token(self, request: web.Request) -> web.Response:
        """Handle the token extraction.

        Stores the ``access_token`` query parameter. Requests without a non-empty
        token are acknowledged but ignored, so a stray request can neither end the
        caller's wait with an empty token nor clear a token that already arrived.
        """
        token = request.query.get("access_token")
        if token:
            self.access_token = token
        return web.Response(text="OK")

    async def handle_callback(self, request: web.Request) -> web.Response:
        """Handle the OAuth callback with token in URL fragment.

        The served page forwards the token (URL-encoded) to ``/token`` and closes
        itself after a short countdown. If the fragment carries no token (for
        example, the provider returned an ``error``), it reports the failure and
        stays open instead of claiming success.
        """
        return web.Response(
            text="""
    <!DOCTYPE html>
    <html lang="en">

    <head>
        <meta charset="UTF-8">
        <meta http-equiv="X-UA-Compatible" content="IE=edge">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Authentication successful</title>
        <style>
            body {
                display: flex;
                justify-content: center;
                align-items: center;
                min-height: 100vh;
                margin: 0;
                text-align: center;
            }
            #content {
                padding: 20px;
            }
            #closeNotification {
                display: none;
                padding: 10px 20px;
                margin-top: 20px;
                cursor: pointer;
                margin-left: auto;
                margin-right: auto;
            }
        </style>
    </head>

    <body>
        <div id="content">
            <h1>Authentication successful!</h1>
            <p>This window will close in <span id="countdown">3</span> seconds.</p>
            <p id="closeNotification" onclick="window.close()">Please close this window manually.</p>
        </div>
        <script>
            const params = new URLSearchParams(window.location.hash.substring(1));
            const token = params.get('access_token');
            const countdownElement = document.getElementById('countdown');
            const closeNotification = document.getElementById('closeNotification');

            if (token) {
                // Encode so the token always arrives as a single query parameter.
                fetch('/token?access_token=' + encodeURIComponent(token));

                let timeLeft = 3;
                const timer = setInterval(() => {
                    timeLeft--;
                    countdownElement.textContent = timeLeft;
                    if (timeLeft <= 0) {
                        clearInterval(timer);
                        window.close();
                        setTimeout(() => {
                            closeNotification.style.display = 'block';
                        }, 500);
                    }
                }, 1000);
            } else {
                // No token: report the failure (textContent, so no HTML injection)
                // and leave the page open so the user can read it.
                const error = params.get('error_description') || params.get('error')
                    || 'No access token was received.';
                document.title = 'Authentication failed';
                document.querySelector('h1').textContent = 'Authentication failed';
                countdownElement.parentElement.textContent = error;
                closeNotification.style.display = 'block';
            }
        </script>
    </body>
    </html>
    """,
            content_type="text/html",
        )
110
+
111
+
112
class BaseClient:
    """Shared HTTP plumbing and OpenID Connect login for the K-Scale WWW API clients.

    Endpoint-specific clients subclass this and go through :meth:`_request`.
    Up to two ``httpx.AsyncClient`` instances are created lazily (one sending a
    bearer token, one without auth). :meth:`close` releases both, and it also
    runs when leaving an ``async with`` block.

    NOTE: several methods are wrapped in ``alru_cache``. Because the cache key
    includes ``self``, results (including the bearer token) are reused for the
    life of the instance. The cache presumably also keeps instances alive —
    confirm against the ``async_lru`` docs before relying on instance lifetime.

    Args:
        base_url: API root URL. Defaults to ``get_api_root()``.
        upload_timeout: Upload timeout in seconds (``DEFAULT_UPLOAD_TIMEOUT`` is
            300 s). Stored for subclasses; this class does not use it itself.
        use_cache: When true, OIDC info, OIDC metadata and the bearer token are
            read from and written to ``get_cache_dir()``.
    """

    def __init__(
        self,
        base_url: str | None = None,
        upload_timeout: float = DEFAULT_UPLOAD_TIMEOUT,
        use_cache: bool = True,
    ) -> None:
        self.base_url = get_api_root() if base_url is None else base_url
        self.upload_timeout = upload_timeout
        self.use_cache = use_cache
        # Lazily created by ``get_client``; released by ``close``.
        self._client: httpx.AsyncClient | None = None
        self._client_no_auth: httpx.AsyncClient | None = None

    @alru_cache
    async def _get_oicd_info(self) -> OICDInfo:
        """Return the OIDC authority and client ID, from the disk cache or ``GET /auth/oicd``."""
        cache_path = get_cache_dir() / "oicd_info.json"
        if self.use_cache and cache_path.exists():
            with open(cache_path, "r") as f:
                return OICDInfo(**json.load(f))
        data = await self._request("GET", "/auth/oicd", auth=False)
        if self.use_cache:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(cache_path, "w") as f:
                json.dump(data, f)
        return OICDInfo(**data)

    @alru_cache
    async def _get_oicd_metadata(self) -> dict:
        """Returns the OpenID Connect server configuration.

        The discovery document is fetched from
        ``{authority}/.well-known/openid-configuration``. When caching is
        enabled, it is persisted to disk.

        Returns:
            The OpenID Connect server configuration.

        Raises:
            aiohttp.ClientResponseError: If the discovery endpoint returns an
                error status. The error body is never written to the cache.
        """
        cache_path = get_cache_dir() / "oicd_metadata.json"
        if self.use_cache and cache_path.exists():
            with open(cache_path, "r") as f:
                return json.load(f)
        oicd_info = await self._get_oicd_info()
        oicd_config_url = f"{oicd_info.authority}/.well-known/openid-configuration"
        async with aiohttp.ClientSession() as session:
            async with session.get(oicd_config_url) as response:
                # Fail loudly rather than caching an error payload forever.
                response.raise_for_status()
                metadata = await response.json()
        if self.use_cache:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(cache_path, "w") as f:
                json.dump(metadata, f, indent=2)
            logger.info("Cached OpenID Connect metadata to %s", cache_path)
        return metadata

    async def _get_bearer_token(self) -> str:
        """Get a bearer token using the OAuth2 implicit flow.

        Starts a local callback server on ``OAUTH_PORT``, opens the provider's
        authorization page in a browser, and waits up to 30 seconds for the
        callback page to forward the token.

        Returns:
            A bearer token to use with the K-Scale WWW API.

        Raises:
            OSError: If the local callback port cannot be bound.
            TimeoutError: If no token arrives within 30 seconds.
        """
        oicd_info = await self._get_oicd_info()
        metadata = await self._get_oicd_metadata()
        auth_endpoint = metadata["authorization_endpoint"]
        # NOTE(review): ``state`` is sent but never compared against anything the
        # callback receives, so it currently provides no CSRF protection.
        state = secrets.token_urlsafe(32)
        nonce = secrets.token_urlsafe(32)

        auth_url = str(
            URL(auth_endpoint).with_query(
                {
                    "response_type": "token",
                    "redirect_uri": f"http://localhost:{OAUTH_PORT}/callback",
                    "state": state,
                    "nonce": nonce,
                    "scope": "openid profile email",
                    "client_id": oicd_info.client_id,
                }
            )
        )

        # Start local server to receive callback
        callback_handler = OAuthCallback()
        runner = web.AppRunner(callback_handler.app)
        await runner.setup()
        site = web.TCPSite(runner, "localhost", OAUTH_PORT)

        try:
            await site.start()
        except OSError as e:
            # Release the runner set up above before surfacing the error.
            await runner.cleanup()
            raise OSError(
                f"The command line interface requires access to local port {OAUTH_PORT} in order to authenticate with "
                "OpenID Connect. Please ensure that no other application is using this port."
            ) from e

        try:
            # Open browser for user authentication; give headless users the URL.
            if not webbrowser.open(auth_url):
                logger.warning("Could not open a browser; open this URL to authenticate: %s", auth_url)

            # Wait for the callback with timeout (monotonic: immune to clock changes).
            deadline = time.monotonic() + 30
            while callback_handler.access_token is None:
                if time.monotonic() > deadline:
                    raise TimeoutError("Authentication timed out after 30 seconds")
                await asyncio.sleep(0.1)

            return callback_handler.access_token
        finally:
            await runner.cleanup()

    @alru_cache
    async def _get_jwk_client(self) -> PyJWKClient:
        """Returns a JWK client for the OpenID Connect server."""
        oicd_info = await self._get_oicd_info()
        # NOTE(review): built by convention rather than read from the discovery
        # document's ``jwks_uri`` — confirm it matches the provider.
        jwks_uri = f"{oicd_info.authority}/.well-known/jwks.json"
        return PyJWKClient(uri=jwks_uri)

    async def _is_token_expired(self, token: str) -> bool:
        """Check if a token is expired.

        A token that cannot be verified at all (malformed, unknown signing key,
        bad signature, JWKS fetch failure) is also reported as expired, so the
        caller re-authenticates instead of crashing on a bad cached value.
        """
        from jwt import PyJWTError  # base class of every PyJWT exception

        jwk_client = await self._get_jwk_client()
        try:
            # NOTE: ``PyJWKClient`` is synchronous; a JWKS fetch here blocks the loop.
            signing_key = jwk_client.get_signing_key_from_jwt(token)
            claims = jwt_decode(
                token,
                signing_key.key,
                algorithms=["RS256"],
                options={"verify_aud": False},
            )
        except ExpiredSignatureError:
            return True
        except PyJWTError as e:
            logger.warning("Cached bearer token could not be verified (%s); re-authenticating", e)
            return True

        exp = claims.get("exp")
        return exp is not None and exp < time.time()

    @alru_cache
    async def get_bearer_token(self) -> str:
        """Get a bearer token from OpenID Connect.

        Reuses the on-disk cached token when it is still valid; otherwise runs
        the browser login and (when caching) stores the new token with
        owner-only permissions. Because of ``alru_cache``, the token is resolved
        at most once per instance.

        Returns:
            A bearer token to use with the K-Scale WWW API.
        """
        cache_path = get_cache_dir() / "bearer_token.txt"
        if self.use_cache and cache_path.exists():
            token = cache_path.read_text().strip()
            if token and not await self._is_token_expired(token):
                return token
        token = await self._get_bearer_token()
        if self.use_cache:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            # Restrict permissions *before* writing so the token is never readable by others.
            cache_path.touch(mode=0o600, exist_ok=True)
            cache_path.chmod(0o600)
            cache_path.write_text(token)
        return token

    async def get_client(self, *, auth: bool = True) -> httpx.AsyncClient:
        """Return the shared authenticated (``auth=True``) or anonymous HTTP client, creating it on first use."""
        client = self._client if auth else self._client_no_auth
        if client is None:
            client = httpx.AsyncClient(
                base_url=self.base_url,
                headers={"Authorization": f"Bearer {await self.get_bearer_token()}"} if auth else None,
                timeout=httpx.Timeout(30.0),
            )
            if auth:
                self._client = client
            else:
                self._client_no_auth = client
        return client

    async def _request(
        self,
        method: str,
        endpoint: str,
        *,
        auth: bool = True,
        params: dict[str, Any] | None = None,
        data: BaseModel | dict[str, Any] | None = None,
        files: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Send a request to the API and return the decoded JSON body.

        Args:
            method: HTTP method, e.g. ``"GET"``.
            endpoint: Path joined onto ``base_url`` with ``urljoin``. NOTE: an
                endpoint starting with ``/`` replaces any path component of
                ``base_url``.
            auth: Whether to use the bearer-token client.
            params: Query parameters.
            data: JSON body. Pydantic models are dumped with ``exclude_unset``;
                falsy values (e.g. ``{}``) send no body.
            files: Multipart files, passed straight to httpx.

        Returns:
            The parsed JSON response. Some endpoints return a list despite the
            annotation.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response (the body is logged first).
        """
        url = urljoin(self.base_url, endpoint)
        kwargs: dict[str, Any] = {"params": params}

        if data:
            if isinstance(data, BaseModel):
                kwargs["json"] = data.model_dump(exclude_unset=True)
            else:
                kwargs["json"] = data
        if files:
            kwargs["files"] = files

        client = await self.get_client(auth=auth)
        response = await client.request(method, url, **kwargs)

        if response.is_error:
            logger.error("Error response from K-Scale: %s", response.text)
        response.raise_for_status()
        return response.json()

    async def close(self) -> None:
        """Close every HTTP client created by :meth:`get_client` (authenticated and anonymous)."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None
        if self._client_no_auth is not None:
            await self._client_no_auth.aclose()
            self._client_no_auth = None

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()
@@ -0,0 +1,11 @@
1
+ """Defines a unified client for the K-Scale WWW API."""
2
+
3
+ from kscale.web.clients.base import BaseClient
4
+ from kscale.web.clients.user import UserClient
5
+
6
+
7
class WWWClient(UserClient, BaseClient):
    """Unified K-Scale WWW API client.

    Gathers the endpoint mixins listed as base classes into one object. All HTTP
    and authentication plumbing comes from ``BaseClient``.
    """
@@ -0,0 +1,39 @@
1
+ """Defines the client for interacting with the K-Scale robot endpoints."""
2
+
3
+ from kscale.web.clients.base import BaseClient
4
+ from kscale.web.gen.api import Robot, RobotResponse
5
+
6
+
7
class RobotClient(BaseClient):
    """Client methods for individual robots (the ``/robot`` endpoints)."""

    async def get_all_robots(self) -> list[Robot]:
        """Return every robot listed by ``GET /robot/``."""
        payload = await self._request("GET", "/robot/", auth=True)
        return list(map(Robot.model_validate, payload))

    async def get_user_robots(self, user_id: str = "me") -> list[Robot]:
        """Return the robots listed by ``GET /robot/user/{user_id}`` (``"me"`` by default)."""
        endpoint = f"/robot/user/{user_id}"
        payload = await self._request("GET", endpoint, auth=True)
        return list(map(Robot.model_validate, payload))

    async def add_robot(
        self,
        robot_name: str,
        class_name: str,
        description: str | None = None,
    ) -> RobotResponse:
        """Add a robot via ``PUT /robot/{robot_name}``.

        ``class_name`` is always sent as a query parameter; ``description`` only
        when it is provided.
        """
        query: dict[str, str] = {"class_name": class_name}
        if description is not None:
            query["description"] = description
        endpoint = f"/robot/{robot_name}"
        payload = await self._request("PUT", endpoint, params=query, auth=True)
        return RobotResponse.model_validate(payload)

    async def get_robot_by_id(self, robot_id: str) -> RobotResponse:
        """Fetch one robot via ``GET /robot/id/{robot_id}``."""
        endpoint = f"/robot/id/{robot_id}"
        payload = await self._request("GET", endpoint, auth=True)
        return RobotResponse.model_validate(payload)

    async def get_robot_by_name(self, robot_name: str) -> RobotResponse:
        """Fetch one robot via ``GET /robot/name/{robot_name}``."""
        endpoint = f"/robot/name/{robot_name}"
        payload = await self._request("GET", endpoint, auth=True)
        return RobotResponse.model_validate(payload)
@@ -0,0 +1,114 @@
1
+ """Defines the client for interacting with the K-Scale robot class endpoints."""
2
+
3
+ import hashlib
4
+ import logging
5
+ from pathlib import Path
6
+
7
+ import httpx
8
+
9
+ from kscale.web.clients.base import BaseClient
10
+ from kscale.web.gen.api import RobotClass, RobotDownloadURDFResponse, RobotUploadURDFResponse
11
+ from kscale.web.utils import get_cache_dir
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class RobotClassClient(BaseClient):
17
+ async def get_robot_classes(self) -> list[RobotClass]:
18
+ data = await self._request(
19
+ "GET",
20
+ "/robots/",
21
+ auth=True,
22
+ )
23
+ return [RobotClass.model_validate(item) for item in data]
24
+
25
+ async def create_robot_class(self, class_name: str, description: str | None = None) -> RobotClass:
26
+ params = {}
27
+ if description is not None:
28
+ params["description"] = description
29
+ data = await self._request(
30
+ "PUT",
31
+ f"/robots/{class_name}",
32
+ params=params,
33
+ auth=True,
34
+ )
35
+ return RobotClass.model_validate(data)
36
+
37
+ async def update_robot_class(
38
+ self,
39
+ class_name: str,
40
+ new_class_name: str | None = None,
41
+ new_description: str | None = None,
42
+ ) -> RobotClass:
43
+ params = {}
44
+ if new_class_name is not None:
45
+ params["new_class_name"] = new_class_name
46
+ if new_description is not None:
47
+ params["new_description"] = new_description
48
+ if not params:
49
+ raise ValueError("No parameters to update")
50
+ data = await self._request(
51
+ "POST",
52
+ f"/robots/{class_name}",
53
+ params=params,
54
+ auth=True,
55
+ )
56
+ return RobotClass.model_validate(data)
57
+
58
+ async def delete_robot_class(self, class_name: str) -> None:
59
+ await self._request("DELETE", f"/robots/{class_name}", auth=True)
60
+
61
+ async def upload_robot_class_urdf(self, class_name: str, urdf_file: str | Path) -> RobotUploadURDFResponse:
62
+ if not (urdf_file := Path(urdf_file)).exists():
63
+ raise FileNotFoundError(f"URDF file not found: {urdf_file}")
64
+
65
+ # Gets the content type from the file extension.
66
+ ext = urdf_file.suffix.lower()
67
+ match ext:
68
+ case ".tgz":
69
+ content_type = "application/x-compressed-tar"
70
+ case _:
71
+ raise ValueError(f"Unsupported file extension: {ext}")
72
+
73
+ data = await self._request(
74
+ "PUT",
75
+ f"/robots/urdf/{class_name}",
76
+ params={"filename": urdf_file.name, "content_type": content_type},
77
+ auth=True,
78
+ )
79
+ response = RobotUploadURDFResponse.model_validate(data)
80
+ async with httpx.AsyncClient() as client:
81
+ async with client.stream(
82
+ "PUT",
83
+ response.url,
84
+ content=urdf_file.read_bytes(),
85
+ headers={"Content-Type": response.content_type},
86
+ ) as r:
87
+ r.raise_for_status()
88
+ return response
89
+
90
+ async def download_robot_class_urdf(self, class_name: str, *, cache: bool = True) -> Path:
91
+ cache_path = get_cache_dir() / class_name / "robot.tgz"
92
+ if cache and cache_path.exists():
93
+ return cache_path
94
+ data = await self._request("GET", f"/robots/urdf/{class_name}", auth=True)
95
+ response = RobotDownloadURDFResponse.model_validate(data)
96
+ expected_hash = response.md5_hash
97
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
98
+
99
+ logger.info("Downloading URDF file from %s", response.url)
100
+ async with httpx.AsyncClient() as client:
101
+ with open(cache_path, "wb") as file:
102
+ hash_value = hashlib.md5()
103
+ async with client.stream("GET", response.url) as r:
104
+ r.raise_for_status()
105
+ async for chunk in r.aiter_bytes():
106
+ file.write(chunk)
107
+ hash_value.update(chunk)
108
+
109
+ logger.info("Checking MD5 hash of downloaded file")
110
+ hash_value_hex = f'"{hash_value.hexdigest()}"'
111
+ if hash_value_hex != expected_hash:
112
+ raise ValueError(f"MD5 hash mismatch: {hash_value_hex} != {expected_hash}")
113
+
114
+ return cache_path
@@ -0,0 +1,10 @@
1
+ """Defines the client for interacting with the K-Scale authentication endpoints."""
2
+
3
+ from kscale.web.clients.base import BaseClient
4
+ from kscale.web.gen.api import ProfileResponse
5
+
6
+
7
class UserClient(BaseClient):
    """Client methods for the authenticated user's account endpoints."""

    async def get_profile_info(self) -> ProfileResponse:
        """Fetch the current user's profile from ``GET /auth/profile``.

        Returns:
            The parsed profile: email, verification flag, and the nested user
            record with its permission flags.
        """
        data = await self._request("GET", "/auth/profile", auth=True)
        # ``model_validate`` matches the other clients. On a non-object payload it
        # raises a pydantic ``ValidationError`` rather than a ``TypeError``.
        return ProfileResponse.model_validate(data)
File without changes
kscale/web/gen/api.py ADDED
@@ -0,0 +1,73 @@
1
+ """Auto-generated by generate.sh script."""
2
+
3
+ # generated by datamodel-codegen:
4
+ # filename: openapi.json
5
+ # timestamp: 2025-01-15T22:35:42+00:00
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import List, Optional, Union
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+
14
# Generated model (generate.sh): hand edits beyond comments are lost on regeneration.
# OpenID Connect client settings returned by ``GET /auth/oicd``.
# NOTE(review): "OICD" presumably means OIDC; the spelling comes from the API schema.
class OICDInfo(BaseModel):
    # Issuer base URL; the client appends ``/.well-known/openid-configuration``
    # and ``/.well-known/jwks.json`` to it.
    authority: str = Field(..., title="Authority")
    # OAuth client ID, sent as ``client_id`` in the authorization request.
    client_id: str = Field(..., title="Client Id")
17
+
18
+
19
# One robot as returned by the ``/robot/`` and ``/robot/user/{user_id}`` listing endpoints.
class Robot(BaseModel):
    id: str = Field(..., title="Id")
    robot_name: str = Field(..., title="Robot Name")
    description: str = Field(..., title="Description")
    # Presumably the owning user's ID — confirm against the API.
    user_id: str = Field(..., title="User Id")
    # Presumably a ``RobotClass.id``; ``RobotResponse`` carries ``class_name`` instead.
    class_id: str = Field(..., title="Class Id")
25
+
26
+
27
# A robot class as returned by the ``/robots/`` list, create and update endpoints.
class RobotClass(BaseModel):
    id: str = Field(..., title="Id")
    class_name: str = Field(..., title="Class Name")
    description: str = Field(..., title="Description")
    # Presumably the ID of the user who created the class — confirm against the API.
    user_id: str = Field(..., title="User Id")
32
+
33
+
34
# Response of ``GET /robots/urdf/{class_name}``: where to fetch the URDF archive.
class RobotDownloadURDFResponse(BaseModel):
    # URL the archive bytes are streamed from.
    url: str = Field(..., title="Url")
    # MD5 of the archive, checked against the downloaded bytes; it appears to be
    # wrapped in double quotes (ETag style).
    md5_hash: str = Field(..., title="Md5 Hash")
37
+
38
+
39
# A single robot as returned by the add / get-by-id / get-by-name robot endpoints.
# Unlike ``Robot``, it carries the class *name* rather than a class ID.
class RobotResponse(BaseModel):
    id: str = Field(..., title="Id")
    robot_name: str = Field(..., title="Robot Name")
    description: str = Field(..., title="Description")
    user_id: str = Field(..., title="User Id")
    class_name: str = Field(..., title="Class Name")
45
+
46
+
47
# Response of ``PUT /robots/urdf/{class_name}``: where and how to upload the archive.
class RobotUploadURDFResponse(BaseModel):
    # URL the client PUTs the archive bytes to.
    url: str = Field(..., title="Url")
    filename: str = Field(..., title="Filename")
    # Sent as the ``Content-Type`` header of that PUT.
    content_type: str = Field(..., title="Content Type")
51
+
52
+
53
# User record nested in ``ProfileResponse.user``.
class UserResponse(BaseModel):
    user_id: str = Field(..., title="User Id")
    # Permission flags; exact semantics are defined server-side.
    is_admin: bool = Field(..., title="Is Admin")
    can_upload: bool = Field(..., title="Can Upload")
    can_test: bool = Field(..., title="Can Test")
58
+
59
+
60
# One request-validation error entry; looks like FastAPI's default schema — presumably
# produced by the server framework.
class ValidationError(BaseModel):
    # Path of keys / indices leading to the offending input.
    loc: List[Union[str, int]] = Field(..., title="Location")
    msg: str = Field(..., title="Message")
    type: str = Field(..., title="Error Type")
64
+
65
+
66
# Error body wrapping an optional list of ``ValidationError`` entries.
class HTTPValidationError(BaseModel):
    detail: Optional[List[ValidationError]] = Field(None, title="Detail")
68
+
69
+
70
# Response of ``GET /auth/profile`` for the authenticated user.
class ProfileResponse(BaseModel):
    email: str = Field(..., title="Email")
    email_verified: bool = Field(..., title="Email Verified")
    # Account record with the user's ID and permission flags.
    user: UserResponse
kscale/web/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ """Utility functions for interacting with the K-Scale WWW API."""
2
+
3
+ import functools
4
+ import logging
5
+ from pathlib import Path
6
+
7
+ from kscale.conf import Settings
8
+
9
logger: logging.Logger = logging.getLogger(__name__)

# Default timeout for uploads, in seconds (5 minutes).
DEFAULT_UPLOAD_TIMEOUT: float = 5 * 60.0
12
+
13
+
14
@functools.lru_cache
def get_cache_dir() -> Path:
    """Return the artifact cache directory from the ``www.cache_dir`` setting.

    ``~`` is expanded and the path is made absolute. The result is memoized for
    the life of the process.
    """
    configured = Settings.load().www.cache_dir
    expanded = Path(configured).expanduser()
    return expanded.resolve()
18
+
19
+
20
def get_artifact_dir(artifact_id: str) -> Path:
    """Return the cache subdirectory for ``artifact_id``, creating it if needed.

    Deliberately not memoized: a cached result would skip the ``mkdir`` and return
    a path to a directory that no longer exists after the cache has been cleared.
    ``mkdir(exist_ok=True)`` is cheap, so it runs on every call.
    """
    artifact_dir = get_cache_dir() / artifact_id
    artifact_dir.mkdir(parents=True, exist_ok=True)
    return artifact_dir
26
+
27
+
28
@functools.lru_cache
def get_api_root() -> str:
    """Return the K-Scale WWW API root URL from the ``www.api_root`` setting (memoized)."""
    settings = Settings.load()
    return settings.www.api_root