kscale 0.0.10__py3-none-any.whl → 0.1.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,314 @@
1
+ """Defines a base client for the K-Scale WWW API client."""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import secrets
7
+ import time
8
+ import webbrowser
9
+ from types import TracebackType
10
+ from typing import Any, Self, Type
11
+ from urllib.parse import urljoin
12
+
13
+ import aiohttp
14
+ import httpx
15
+ from aiohttp import web
16
+ from async_lru import alru_cache
17
+ from jwt import ExpiredSignatureError, PyJWKClient, decode as jwt_decode
18
+ from pydantic import BaseModel
19
+ from yarl import URL
20
+
21
+ from kscale.web.gen.api import OICDInfo
22
+ from kscale.web.utils import DEFAULT_UPLOAD_TIMEOUT, get_api_root, get_cache_dir
23
+
24
logger = logging.getLogger(__name__)

# This port matches the available port for the OAuth callback.
OAUTH_PORT = 16821


class OAuthCallback:
    """Small local web app that receives the OAuth2 implicit-flow redirect.

    The identity provider redirects the browser to ``/callback`` with the access
    token in the URL *fragment*, which browsers never send to the server. The page
    served at ``/callback`` therefore reads the fragment with JavaScript and
    forwards ``access_token`` and ``state`` to ``/token``, where the token is
    stored on ``self.access_token``.
    """

    def __init__(self, expected_state: str | None = None) -> None:
        """Builds the aiohttp application and registers its two routes.

        Args:
            expected_state: The ``state`` value sent with the authorization
                request. When set, ``/token`` only accepts a token whose
                accompanying ``state`` matches it, so another web page or local
                process cannot inject its own token by calling the local port.
                ``None`` disables the check.
        """
        self.access_token: str | None = None
        self.expected_state = expected_state
        self.app = web.Application()
        self.app.router.add_get("/token", self.handle_token)
        self.app.router.add_get("/callback", self.handle_callback)

    async def handle_token(self, request: web.Request) -> web.Response:
        """Stores the access token forwarded by the callback page.

        Responds with 400 and stores nothing when the ``state`` does not match
        ``expected_state`` or when no access token is supplied.
        """
        if self.expected_state is not None:
            received_state = request.query.get("state", "")
            # Compare as bytes: compare_digest raises TypeError on non-ASCII str input.
            if not secrets.compare_digest(received_state.encode(), self.expected_state.encode()):
                logger.warning("Ignoring OAuth token with a missing or mismatched state parameter")
                return web.Response(status=400, text="Invalid state")
        token = request.query.get("access_token")
        if not token:
            return web.Response(status=400, text="Missing access token")
        self.access_token = token
        return web.Response(text="OK")

    async def handle_callback(self, request: web.Request) -> web.Response:
        """Handle the OAuth callback with token in URL fragment.

        Serves a page whose script forwards the fragment's ``access_token`` and
        ``state`` to ``/token`` and then tries to close the window.
        """
        return web.Response(
            text="""
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Authentication successful</title>
    <style>
        body {
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
            margin: 0;
            text-align: center;
        }
        #content {
            padding: 20px;
        }
        #closeNotification {
            display: none;
            padding: 10px 20px;
            margin-top: 20px;
            cursor: pointer;
            margin-left: auto;
            margin-right: auto;
        }
    </style>
</head>

<body>
    <div id="content">
        <h1>Authentication successful!</h1>
        <p>This window will close in <span id="countdown">3</span> seconds.</p>
        <p id="closeNotification" onclick="window.close()">Please close this window manually.</p>
    </div>
    <script>
        const params = new URLSearchParams(window.location.hash.substring(1));
        const token = params.get('access_token');
        if (token) {
            // Forward the state as well so the CLI can reject forged tokens; URLSearchParams also URL-encodes both values.
            const query = new URLSearchParams({ access_token: token, state: params.get('state') || '' });
            fetch('/token?' + query.toString());
        }

        let timeLeft = 3;
        const countdownElement = document.getElementById('countdown');
        const closeNotification = document.getElementById('closeNotification');
        const timer = setInterval(() => {
            timeLeft--;
            countdownElement.textContent = timeLeft;
            if (timeLeft <= 0) {
                clearInterval(timer);
                window.close();
                setTimeout(() => {
                    closeNotification.style.display = 'block';
                }, 500);
            }
        }, 1000);
    </script>
</body>
</html>
""",
            content_type="text/html",
        )


class BaseClient:
    """Shared HTTP plumbing and OpenID Connect login for the K-Scale WWW API."""

    def __init__(
        self,
        base_url: str | None = None,
        upload_timeout: float = DEFAULT_UPLOAD_TIMEOUT,
        use_cache: bool = True,
    ) -> None:
        """Configures the client; no network traffic happens here.

        Args:
            base_url: API root URL; defaults to ``get_api_root()``.
            upload_timeout: Stored as ``self.upload_timeout`` (defaults to
                ``DEFAULT_UPLOAD_TIMEOUT``).
            use_cache: Whether OIDC info, OIDC metadata and the bearer token are
                read from / written to the cache directory.
        """
        self.base_url = get_api_root() if base_url is None else base_url
        self.upload_timeout = upload_timeout
        self.use_cache = use_cache
        self._client: httpx.AsyncClient | None = None
        self._client_no_auth: httpx.AsyncClient | None = None

    @alru_cache
    async def _get_oicd_info(self) -> OICDInfo:
        """Returns the OIDC authority and client ID from ``GET /auth/oicd``.

        With ``use_cache`` the response is stored in ``oicd_info.json`` in the
        cache directory and reused on later runs.
        """
        cache_path = get_cache_dir() / "oicd_info.json"
        if self.use_cache and cache_path.exists():
            with open(cache_path, "r") as f:
                return OICDInfo(**json.load(f))
        data = await self._request("GET", "/auth/oicd", auth=False)
        # Validate before caching so a malformed response is never persisted.
        info = OICDInfo(**data)
        if self.use_cache:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(cache_path, "w") as f:
                json.dump(data, f)
        return info

    @alru_cache
    async def _get_oicd_metadata(self) -> dict:
        """Returns the OpenID Connect server configuration.

        Fetched from ``{authority}/.well-known/openid-configuration`` and, with
        ``use_cache``, persisted to ``oicd_metadata.json``.

        Returns:
            The OpenID Connect server configuration.

        Raises:
            aiohttp.ClientResponseError: If the discovery endpoint returns an
                error status (the error body is then not cached).
        """
        cache_path = get_cache_dir() / "oicd_metadata.json"
        if self.use_cache and cache_path.exists():
            with open(cache_path, "r") as f:
                return json.load(f)
        oicd_info = await self._get_oicd_info()
        oicd_config_url = f"{oicd_info.authority}/.well-known/openid-configuration"
        async with aiohttp.ClientSession() as session:
            async with session.get(oicd_config_url) as response:
                response.raise_for_status()
                metadata = await response.json()
        if self.use_cache:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            with open(cache_path, "w") as f:
                json.dump(metadata, f, indent=2)
            logger.info("Cached OpenID Connect metadata to %s", cache_path)
        return metadata

    async def _get_bearer_token(self, timeout: float = 30.0) -> str:
        """Get a bearer token using the OAuth2 implicit flow.

        Starts a local callback server on ``OAUTH_PORT``, opens the provider's
        authorization page in a browser and waits for the token to arrive.

        Args:
            timeout: Seconds to wait for the browser callback.

        Returns:
            A bearer token to use with the K-Scale WWW API.

        Raises:
            OSError: If the local callback port cannot be bound.
            TimeoutError: If no valid token arrives within ``timeout`` seconds.
        """
        oicd_info = await self._get_oicd_info()
        metadata = await self._get_oicd_metadata()
        auth_endpoint = metadata["authorization_endpoint"]
        state = secrets.token_urlsafe(32)
        nonce = secrets.token_urlsafe(32)

        auth_url = str(
            URL(auth_endpoint).with_query(
                {
                    "response_type": "token",
                    "redirect_uri": f"http://localhost:{OAUTH_PORT}/callback",
                    "state": state,
                    "nonce": nonce,
                    "scope": "openid profile email",
                    "client_id": oicd_info.client_id,
                }
            )
        )

        # Start local server to receive callback; only tokens carrying our state are accepted.
        callback_handler = OAuthCallback(expected_state=state)
        runner = web.AppRunner(callback_handler.app)
        await runner.setup()

        # Everything after setup() is inside try/finally so the runner is
        # cleaned up even when binding the port fails.
        try:
            site = web.TCPSite(runner, "localhost", OAUTH_PORT)
            try:
                await site.start()
            except OSError as e:
                raise OSError(
                    f"The command line interface requires access to local port {OAUTH_PORT} in order to authenticate with "
                    "OpenID Connect. Please ensure that no other application is using this port."
                ) from e

            # Open browser for user authentication; fall back to printing the URL.
            if not webbrowser.open(auth_url):
                logger.warning("Could not open a browser; please visit this URL to log in: %s", auth_url)

            # Wait for the callback; monotonic time is immune to wall-clock jumps.
            deadline = time.monotonic() + timeout
            while callback_handler.access_token is None:
                if time.monotonic() > deadline:
                    raise TimeoutError(f"Authentication timed out after {timeout:g} seconds")
                await asyncio.sleep(0.1)

            return callback_handler.access_token
        finally:
            await runner.cleanup()

    @alru_cache
    async def _get_jwk_client(self) -> PyJWKClient:
        """Returns a JWK client for the OpenID Connect server."""
        oicd_info = await self._get_oicd_info()
        jwks_uri = f"{oicd_info.authority}/.well-known/jwks.json"
        return PyJWKClient(uri=jwks_uri)

    async def _is_token_expired(self, token: str) -> bool:
        """Check if a token is expired or otherwise unusable.

        Any token that cannot be verified (malformed, unknown signing key,
        bad signature) is reported as expired so the caller re-authenticates
        instead of crashing on a corrupt cached token.
        """
        from jwt import PyJWTError  # base class of every PyJWT / PyJWKClient error

        jwk_client = await self._get_jwk_client()
        try:
            # PyJWKClient fetches the JWKS with blocking I/O; keep it off the event loop.
            signing_key = await asyncio.to_thread(jwk_client.get_signing_key_from_jwt, token)
            claims = jwt_decode(
                token,
                signing_key.key,
                algorithms=["RS256"],
                options={"verify_aud": False},
            )
        except ExpiredSignatureError:
            return True
        except PyJWTError as e:
            logger.warning("Cached bearer token could not be verified (%s); re-authenticating", e)
            return True

        exp = claims.get("exp")
        return exp is not None and exp < time.time()

    @alru_cache
    async def get_bearer_token(self) -> str:
        """Get a bearer token from OpenID Connect.

        Reuses the cached ``bearer_token.txt`` while it is still valid; otherwise
        runs the browser login and (with ``use_cache``) stores the new token in a
        file readable only by the owner.

        Returns:
            A bearer token to use with the K-Scale WWW API.
        """
        cache_path = get_cache_dir() / "bearer_token.txt"
        if self.use_cache and cache_path.exists():
            token = cache_path.read_text().strip()
            if token and not await self._is_token_expired(token):
                return token
        token = await self._get_bearer_token()
        if self.use_cache:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            # Restrict permissions *before* writing so the token is never briefly world-readable.
            cache_path.touch(mode=0o600, exist_ok=True)
            cache_path.chmod(0o600)
            cache_path.write_text(token)
        return token

    async def get_client(self, *, auth: bool = True) -> httpx.AsyncClient:
        """Returns the lazily created, reused httpx client.

        Args:
            auth: Whether the client sends the ``Authorization: Bearer`` header.
                Authenticated and anonymous clients are cached separately.
        """
        client = self._client if auth else self._client_no_auth
        if client is None:
            client = httpx.AsyncClient(
                base_url=self.base_url,
                headers={"Authorization": f"Bearer {await self.get_bearer_token()}"} if auth else None,
                timeout=httpx.Timeout(30.0),
            )
            if auth:
                self._client = client
            else:
                self._client_no_auth = client
        return client

    async def _request(
        self,
        method: str,
        endpoint: str,
        *,
        auth: bool = True,
        params: dict[str, Any] | None = None,
        data: BaseModel | dict[str, Any] | None = None,
        files: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Sends a request to the K-Scale API and returns the decoded JSON body.

        Args:
            method: HTTP method.
            endpoint: Path joined onto ``base_url``.
            auth: Whether to use the authenticated client.
            params: Query parameters.
            data: JSON body; pydantic models are dumped with ``exclude_unset``.
            files: Multipart files passed through to httpx.

        Returns:
            The decoded JSON body, or ``{}`` when the response has no body
            (e.g. ``204 No Content``).

        Raises:
            httpx.HTTPStatusError: If the server responds with an error status.
        """
        # NOTE(review): urljoin drops any path component of base_url when
        # endpoint starts with "/" — confirm the API root never has a path prefix.
        url = urljoin(self.base_url, endpoint)
        kwargs: dict[str, Any] = {"params": params}

        # Compare with None so an explicitly empty payload ({}) is still sent.
        if data is not None:
            kwargs["json"] = data.model_dump(exclude_unset=True) if isinstance(data, BaseModel) else data
        if files is not None:
            kwargs["files"] = files

        client = await self.get_client(auth=auth)
        response = await client.request(method, url, **kwargs)

        if response.is_error:
            logger.error("Error response from K-Scale: %s", response.text)
        response.raise_for_status()
        if not response.content:
            return {}
        return response.json()

    async def close(self) -> None:
        """Closes both the authenticated and the anonymous httpx clients."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None
        if self._client_no_auth is not None:
            await self._client_no_auth.aclose()
            self._client_no_auth = None

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()
@@ -0,0 +1,11 @@
1
+ """Defines a unified client for the K-Scale WWW API."""
2
+
3
+ from kscale.web.clients.base import BaseClient
4
+ from kscale.web.clients.user import UserClient
5
+
6
+
7
class WWWClient(UserClient, BaseClient):
    """Single entry point that mixes the user endpoints into the shared base client."""
@@ -0,0 +1,39 @@
1
+ """Defines the client for interacting with the K-Scale robot endpoints."""
2
+
3
+ from kscale.web.clients.base import BaseClient
4
+ from kscale.web.gen.api import Robot, RobotResponse
5
+
6
+
7
class RobotClient(BaseClient):
    """Client methods for the K-Scale robot (``/robot``) endpoints."""

    async def get_all_robots(self) -> list[Robot]:
        """Lists the robots returned by ``GET /robot/``."""
        robots = await self._request("GET", "/robot/", auth=True)
        return list(map(Robot.model_validate, robots))

    async def get_user_robots(self, user_id: str = "me") -> list[Robot]:
        """Lists the robots of ``user_id`` (``"me"`` by default) via ``GET /robot/user/{user_id}``."""
        endpoint = f"/robot/user/{user_id}"
        robots = await self._request("GET", endpoint, auth=True)
        return list(map(Robot.model_validate, robots))

    async def add_robot(
        self,
        robot_name: str,
        class_name: str,
        description: str | None = None,
    ) -> RobotResponse:
        """Registers a robot via ``PUT /robot/{robot_name}``.

        ``class_name`` and, when given, ``description`` are sent as query parameters.
        """
        query: dict[str, str] = {"class_name": class_name}
        if description is not None:
            query["description"] = description
        payload = await self._request("PUT", f"/robot/{robot_name}", params=query, auth=True)
        return RobotResponse.model_validate(payload)

    async def get_robot_by_id(self, robot_id: str) -> RobotResponse:
        """Fetches one robot via ``GET /robot/id/{robot_id}``."""
        payload = await self._request("GET", f"/robot/id/{robot_id}", auth=True)
        return RobotResponse.model_validate(payload)

    async def get_robot_by_name(self, robot_name: str) -> RobotResponse:
        """Fetches one robot via ``GET /robot/name/{robot_name}``."""
        payload = await self._request("GET", f"/robot/name/{robot_name}", auth=True)
        return RobotResponse.model_validate(payload)
@@ -0,0 +1,114 @@
1
+ """Defines the client for interacting with the K-Scale robot class endpoints."""
2
+
3
+ import hashlib
4
+ import logging
5
+ from pathlib import Path
6
+
7
+ import httpx
8
+
9
+ from kscale.web.clients.base import BaseClient
10
+ from kscale.web.gen.api import RobotClass, RobotDownloadURDFResponse, RobotUploadURDFResponse
11
+ from kscale.web.utils import get_cache_dir
12
+
13
logger = logging.getLogger(__name__)


class RobotClassClient(BaseClient):
    """Client methods for the K-Scale robot class (``/robots``) endpoints."""

    async def get_robot_classes(self) -> list[RobotClass]:
        """Lists the robot classes returned by ``GET /robots/``."""
        data = await self._request(
            "GET",
            "/robots/",
            auth=True,
        )
        return [RobotClass.model_validate(item) for item in data]

    async def create_robot_class(self, class_name: str, description: str | None = None) -> RobotClass:
        """Creates a robot class via ``PUT /robots/{class_name}``.

        Args:
            class_name: Name of the new class (used as a URL path segment).
            description: Optional description, sent as a query parameter.
        """
        params: dict[str, str] = {}
        if description is not None:
            params["description"] = description
        data = await self._request(
            "PUT",
            f"/robots/{class_name}",
            params=params,
            auth=True,
        )
        return RobotClass.model_validate(data)

    async def update_robot_class(
        self,
        class_name: str,
        new_class_name: str | None = None,
        new_description: str | None = None,
    ) -> RobotClass:
        """Renames and/or re-describes a robot class via ``POST /robots/{class_name}``.

        Raises:
            ValueError: If neither ``new_class_name`` nor ``new_description`` is given.
        """
        params: dict[str, str] = {}
        if new_class_name is not None:
            params["new_class_name"] = new_class_name
        if new_description is not None:
            params["new_description"] = new_description
        if not params:
            raise ValueError("No parameters to update")
        data = await self._request(
            "POST",
            f"/robots/{class_name}",
            params=params,
            auth=True,
        )
        return RobotClass.model_validate(data)

    async def delete_robot_class(self, class_name: str) -> None:
        """Deletes a robot class via ``DELETE /robots/{class_name}``."""
        await self._request("DELETE", f"/robots/{class_name}", auth=True)

    async def upload_robot_class_urdf(self, class_name: str, urdf_file: str | Path) -> RobotUploadURDFResponse:
        """Uploads a URDF archive for a robot class.

        Asks the API (``PUT /robots/urdf/{class_name}``) for an upload URL, then
        PUTs the file bytes to that URL using ``self.upload_timeout``.

        Args:
            class_name: Robot class to attach the URDF to.
            urdf_file: Path to a ``.tgz`` or ``.tar.gz`` archive.

        Raises:
            FileNotFoundError: If ``urdf_file`` does not exist.
            ValueError: If the file is not a ``.tgz`` / ``.tar.gz`` archive.
            httpx.HTTPStatusError: If the upload request fails.
        """
        if not (urdf_file := Path(urdf_file)).exists():
            raise FileNotFoundError(f"URDF file not found: {urdf_file}")

        # Gets the content type from the file extension; ".tar.gz" has suffix ".gz",
        # so check the full name rather than only the last suffix.
        ext = urdf_file.suffix.lower()
        if urdf_file.name.lower().endswith((".tgz", ".tar.gz")):
            content_type = "application/x-compressed-tar"
        else:
            raise ValueError(f"Unsupported file extension: {ext}")

        data = await self._request(
            "PUT",
            f"/robots/urdf/{class_name}",
            params={"filename": urdf_file.name, "content_type": content_type},
            auth=True,
        )
        response = RobotUploadURDFResponse.model_validate(data)
        # httpx's default 5 s timeout is too short for large archives; use the configured upload timeout.
        async with httpx.AsyncClient(timeout=httpx.Timeout(self.upload_timeout)) as client:
            r = await client.put(
                response.url,
                content=urdf_file.read_bytes(),
                headers={"Content-Type": response.content_type},
            )
            r.raise_for_status()
        return response

    async def download_robot_class_urdf(self, class_name: str, *, cache: bool = True) -> Path:
        """Downloads (or reuses) the URDF archive of a robot class.

        The archive is streamed to a temporary file, its MD5 is checked against the
        value reported by the API, and only then is it moved to the cache path. A
        failed or corrupt download therefore never leaves a file that a later
        ``cache=True`` call would return.

        Args:
            class_name: Robot class whose URDF to fetch.
            cache: Return the cached archive without contacting the API if present.

        Returns:
            Path to ``<cache_dir>/<class_name>/robot.tgz``.

        Raises:
            ValueError: If the MD5 hash of the download does not match.
            httpx.HTTPStatusError: If the download request fails.
        """
        cache_path = get_cache_dir() / class_name / "robot.tgz"
        if cache and cache_path.exists():
            return cache_path
        data = await self._request("GET", f"/robots/urdf/{class_name}", auth=True)
        response = RobotDownloadURDFResponse.model_validate(data)
        # The hash was previously compared in quoted (ETag-style) form; strip the
        # quotes so both quoted and bare values match.
        expected_hash = response.md5_hash.strip('"')
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = cache_path.with_name(cache_path.name + ".part")

        logger.info("Downloading URDF file from %s", response.url)
        hash_value = hashlib.md5(usedforsecurity=False)
        try:
            async with httpx.AsyncClient() as client:
                async with client.stream("GET", response.url) as r:
                    r.raise_for_status()
                    with open(tmp_path, "wb") as file:
                        async for chunk in r.aiter_bytes():
                            file.write(chunk)
                            hash_value.update(chunk)

            logger.info("Checking MD5 hash of downloaded file")
            actual_hash = hash_value.hexdigest()
            if actual_hash != expected_hash:
                raise ValueError(f"MD5 hash mismatch: {actual_hash} != {expected_hash}")
            tmp_path.replace(cache_path)
        except BaseException:
            # Also covers cancellation: never leave a partial file behind.
            tmp_path.unlink(missing_ok=True)
            raise

        return cache_path
@@ -0,0 +1,10 @@
1
+ """Defines the client for interacting with the K-Scale authentication endpoints."""
2
+
3
+ from kscale.web.clients.base import BaseClient
4
+ from kscale.web.gen.api import ProfileResponse
5
+
6
+
7
class UserClient(BaseClient):
    """Client methods for the K-Scale authentication/profile endpoints."""

    async def get_profile_info(self) -> ProfileResponse:
        """Fetches the current user's profile via ``GET /auth/profile``."""
        data = await self._request("GET", "/auth/profile", auth=True)
        # model_validate matches how the other clients parse responses and reports
        # a pydantic ValidationError (not a TypeError) if data is not a mapping.
        return ProfileResponse.model_validate(data)
File without changes
kscale/web/gen/api.py ADDED
@@ -0,0 +1,73 @@
1
+ """Auto-generated by generate.sh script."""
2
+
3
+ # generated by datamodel-codegen:
4
+ # filename: openapi.json
5
+ # timestamp: 2025-01-15T22:35:42+00:00
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import List, Optional, Union
10
+
11
+ from pydantic import BaseModel, Field
12
+
13
+
14
# OpenID Connect authority URL and OAuth client ID.
# NOTE: this module is generated (generate.sh / datamodel-codegen); hand edits are lost on regeneration.
class OICDInfo(BaseModel):
    authority: str = Field(title="Authority")
    client_id: str = Field(title="Client Id")
17
+
18
+
19
# A robot record; references its class by ID.
# NOTE: generated model — hand edits are lost on regeneration.
class Robot(BaseModel):
    id: str = Field(title="Id")
    robot_name: str = Field(title="Robot Name")
    description: str = Field(title="Description")
    user_id: str = Field(title="User Id")
    class_id: str = Field(title="Class Id")
25
+
26
+
27
# A robot class record.
# NOTE: generated model — hand edits are lost on regeneration.
class RobotClass(BaseModel):
    id: str = Field(title="Id")
    class_name: str = Field(title="Class Name")
    description: str = Field(title="Description")
    user_id: str = Field(title="User Id")
32
+
33
+
34
# Download location of a robot class URDF archive plus its MD5 hash.
# NOTE: generated model — hand edits are lost on regeneration.
class RobotDownloadURDFResponse(BaseModel):
    url: str = Field(title="Url")
    md5_hash: str = Field(title="Md5 Hash")
37
+
38
+
39
# A robot record; references its class by name.
# NOTE: generated model — hand edits are lost on regeneration.
class RobotResponse(BaseModel):
    id: str = Field(title="Id")
    robot_name: str = Field(title="Robot Name")
    description: str = Field(title="Description")
    user_id: str = Field(title="User Id")
    class_name: str = Field(title="Class Name")
45
+
46
+
47
# Upload target for a robot class URDF archive and the content type to send with it.
# NOTE: generated model — hand edits are lost on regeneration.
class RobotUploadURDFResponse(BaseModel):
    url: str = Field(title="Url")
    filename: str = Field(title="Filename")
    content_type: str = Field(title="Content Type")
51
+
52
+
53
# A user's ID and permission flags.
# NOTE: generated model — hand edits are lost on regeneration.
class UserResponse(BaseModel):
    user_id: str = Field(title="User Id")
    is_admin: bool = Field(title="Is Admin")
    can_upload: bool = Field(title="Can Upload")
    can_test: bool = Field(title="Can Test")
58
+
59
+
60
# One entry of an API validation-error response.
# NOTE: generated model — hand edits are lost on regeneration.
class ValidationError(BaseModel):
    loc: list[str | int] = Field(title="Location")
    msg: str = Field(title="Message")
    type: str = Field(title="Error Type")
64
+
65
+
66
# Envelope for API validation errors; ``detail`` may be absent.
# NOTE: generated model — hand edits are lost on regeneration.
class HTTPValidationError(BaseModel):
    detail: list[ValidationError] | None = Field(default=None, title="Detail")
68
+
69
+
70
# Profile of the authenticated user: email info plus the nested user record.
# NOTE: generated model — hand edits are lost on regeneration.
class ProfileResponse(BaseModel):
    email: str = Field(title="Email")
    email_verified: bool = Field(title="Email Verified")
    user: UserResponse
kscale/web/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ """Utility functions for interacting with the K-Scale WWW API."""
2
+
3
+ import functools
4
+ import logging
5
+ from pathlib import Path
6
+
7
+ from kscale.conf import Settings
8
+
9
logger = logging.getLogger(name=__name__)

# Upload timeout in seconds: five minutes.
DEFAULT_UPLOAD_TIMEOUT = 5 * 60.0
12
+
13
+
14
@functools.lru_cache
def get_cache_dir() -> Path:
    """Resolve the configured artifact cache directory to an absolute path.

    Memoised: the settings are only loaded on the first call.
    """
    configured_dir = Settings.load().www.cache_dir
    expanded = Path(configured_dir).expanduser()
    return expanded.resolve()
18
+
19
+
20
def get_artifact_dir(artifact_id: str) -> Path:
    """Returns the directory for a specific artifact, creating it if needed.

    Deliberately not memoised: the ``mkdir`` side effect must run on every call
    so the directory is recreated if it was deleted (e.g. the cache was cleared)
    while the process is running. ``mkdir(exist_ok=True)`` on an existing
    directory is cheap, so caching bought nothing.

    Args:
        artifact_id: Name of the subdirectory under the cache directory.

    Returns:
        ``get_cache_dir() / artifact_id``, guaranteed to exist.
    """
    cache_dir = get_cache_dir() / artifact_id
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
26
+
27
+
28
@functools.lru_cache
def get_api_root() -> str:
    """Look up the K-Scale WWW API root URL in the loaded settings (memoised)."""
    settings = Settings.load()
    return settings.www.api_root