cradle-sdk 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cradle/sdk/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .client import (
2
+ Client as Client,
3
+ DataClient as DataClient,
4
+ DataLoadClient as DataLoadClient,
5
+ ProjectClient as ProjectClient,
6
+ RoundClient as RoundClient,
7
+ TaskClient as TaskClient,
8
+ )
9
+ from .exceptions import ClientError as ClientError
File without changes
@@ -0,0 +1,478 @@
1
+ import contextlib
2
+ import json
3
+ import logging
4
+ import time
5
+ from datetime import UTC, datetime, timedelta
6
+ from typing import Annotated, Literal, Self
7
+ from urllib.parse import urlencode, urljoin
8
+
9
+ import httpx
10
+ import jwt
11
+ import keyring
12
+ import keyring.errors
13
+ from keyring.backends.chainer import ChainerBackend
14
+ from pydantic import BaseModel, Field, TypeAdapter, ValidationError
15
+ from typing_extensions import Generator
16
+
17
+ from cradle.sdk.exceptions import ClientError
18
+
19
# Module-level logger; used by DeviceAuth to warn about missing or insecure keyring backends.
logger = logging.getLogger(__name__)
20
+
21
+
22
class ClientAuthError(Exception):
    """Error raised by this module's authentication flow when no usable token can be obtained.

    The raw reason is kept on ``message``; ``str(err)`` carries an
    ``"Authentication Error: "`` prefix.
    """

    def __init__(self, message: str):
        formatted = f"Authentication Error: {message}"
        super().__init__(formatted)
        # Unprefixed reason, for callers that want to render their own message.
        self.message = message
26
+
27
+
28
class ReauthenticationRequiredError(ClientAuthError):
    """Raised when the current session cannot be switched to the requested organization.

    The user has to log out via ``logout_url`` and then retry, which starts a fresh
    device-flow login.
    """

    def __init__(self, logout_url):
        super().__init__("Reauthentication required")
        # URL the user must visit to end the current auth session.
        self.logout_url = logout_url
32
+
33
+
34
class DeviceAuthResponse(BaseModel):
    """Response body of the auth server's ``/user_management/authorize/device`` endpoint.

    Parsed in ``DeviceAuth._start_device_flow`` to kick off the device authorization flow.
    """

    # Opaque code that is exchanged for tokens while polling ``/user_management/authenticate``.
    device_code: str
    # Code printed for the user so they can check it matches the one shown in the browser.
    user_code: str
    # Verification page without the code embedded (not referenced elsewhere in this file).
    verification_uri: str
    # Verification page with the user code embedded; this is the URL printed for the user.
    verification_uri_complete: str
    # Passed to _poll_for_tokens as the polling timeout, in seconds.
    expires_in: int
    # Passed to _poll_for_tokens as the initial polling interval, in seconds.
    interval: int
41
+
42
+
43
class AuthenticateSuccessResponse(BaseModel):
    """Successful response of the auth server's ``/user_management/authenticate`` endpoint.

    This is the token object stored in ``DeviceAuth.token`` and serialized as JSON
    into the keyring.
    """

    # JWT sent to the Cradle API as ``Bearer ca_<access_token>``; its claims
    # (exp, sid, org_id, ...) are decoded client-side without signature verification.
    access_token: str
    # Sent with ``grant_type=refresh_token`` to obtain a new access token.
    refresh_token: str
    organization_id: str | None = None
    authentication_method: str
    # User profile returned by the auth server; its schema is not inspected in this file.
    user: dict
49
+
50
+
51
class AuthenticateAuthenticationErrorResponse(BaseModel):
    """OAuth-style error body returned by ``/user_management/authenticate``.

    ``error`` codes handled explicitly in this file: ``authorization_pending``,
    ``slow_down``, ``invalid_grant``, ``sso_required`` and ``mfa_enrollment``.
    """

    # Machine-readable error code.
    error: str
    # Human-readable explanation; used as the ClientAuthError message for unhandled codes.
    error_description: str
54
+
55
+
56
class UnknownErrorResponse(BaseModel):
    """Fallback for a failed auth response whose body does not match the OAuth error shape."""

    # Decoded JSON body of the failed response, kept verbatim for the error message.
    response: dict
58
+
59
+
60
# Every error shape that DeviceAuth.parse_response can return for a non-2xx response.
AuthenticateErrorResponse = AuthenticateAuthenticationErrorResponse | UnknownErrorResponse
61
+
62
+
63
class UnauthorizedResponse(BaseModel):
    """``auth:checkAuth`` result when the token is not authorized for the requested workspace."""

    # Discriminator for CheckAuthResponse.
    authorized: Literal[False] = False
    # True when the user belongs to the workspace and only needs the token re-issued for its org.
    can_reauth: bool
    # Organization the token should be re-issued for (becomes DeviceAuth.force_org_id).
    org_id: str | None
67
+
68
+
69
class AuthorizedResponse(BaseModel):
    """``auth:checkAuth`` result when the current token already grants access to the workspace."""

    # Discriminator for CheckAuthResponse.
    authorized: Literal[True] = True
71
+
72
+
73
# Response of the Cradle API's ``auth:checkAuth`` endpoint, discriminated on the
# ``authorized`` literal so pydantic selects the matching model.
CheckAuthResponse = Annotated[UnauthorizedResponse | AuthorizedResponse, Field(discriminator="authorized")]
# Module-level adapter so the union can be validated from decoded JSON (used in DeviceAuth._check_auth).
CheckAuthResponseAdapter = TypeAdapter(CheckAuthResponse)
75
+
76
+
77
class _DeviceAuthStrategy(BaseModel):
    """Payload of the Cradle API's ``auth:apiStrategy`` endpoint, used by ``DeviceAuth.from_strategy``."""

    # OAuth client ID for the device flow.
    client_id: str
    # Base URL of the Cradle web app (used to build the post-logout redirect).
    app_url: str
    # NOTE(review): parsed but not used by from_strategy, which passes its own auth base URL — confirm intent.
    auth_api_url: str
81
+
82
+
83
+ class DeviceAuth(httpx.Auth):
84
+ def __init__(
85
+ self,
86
+ client_id: str,
87
+ cradle_app_url: str,
88
+ cradle_api_base_url: str,
89
+ auth_base_url: str,
90
+ workspace: str | None = None,
91
+ use_keyring: bool = True,
92
+ user_agent: str | None = None,
93
+ ):
94
+ self.client_id = client_id
95
+ self.cradle_app_url = cradle_app_url
96
+ self.use_keyring = use_keyring
97
+ self._keyring_servicename: str = __name__
98
+ self.workspace = workspace
99
+ self.cradle_api_base_url = cradle_api_base_url
100
+ self.token: AuthenticateSuccessResponse | None = None
101
+
102
+ self.checked_for_reauth = False
103
+ self.force_org_id = None
104
+
105
+ self.auth_base_url = auth_base_url
106
+ self.user_agent = user_agent
107
+
108
+ # Try to load token from keyring on initialization
109
+ if self.use_keyring:
110
+ self._load_token_from_keyring()
111
+
112
+ @classmethod
113
+ def from_strategy(
114
+ cls,
115
+ client: httpx.Client,
116
+ base_url: str,
117
+ workspace: str | None,
118
+ use_keyring: bool = True,
119
+ ) -> Self:
120
+ """Create device auth with parameters looked up in the API."""
121
+ api_strategy_url = f"{base_url}/auth:apiStrategy"
122
+
123
+ try:
124
+ response = client.get(api_strategy_url)
125
+ response.raise_for_status()
126
+ strategy = _DeviceAuthStrategy.model_validate(response.json())
127
+ except Exception as exc:
128
+ raise RuntimeError(f"Failed to retreive auth strategy from {api_strategy_url}") from exc
129
+
130
+ return cls(
131
+ workspace=workspace,
132
+ cradle_api_base_url=base_url,
133
+ client_id=strategy.client_id,
134
+ cradle_app_url=strategy.app_url,
135
+ auth_base_url="https://auth.cradle.bio",
136
+ use_keyring=use_keyring,
137
+ user_agent=client.headers.get("User-Agent"),
138
+ )
139
+
140
+ @property
141
+ def access_token(self) -> str | None:
142
+ return self.token.access_token if self.token else None
143
+
144
+ @property
145
+ def refresh_token(self) -> str | None:
146
+ return self.token.refresh_token if self.token else None
147
+
148
+ @property
149
+ def session_id(self) -> str | None:
150
+ return self.jwt_payload.get("sid") if self.jwt_payload else None
151
+
152
+ @property
153
+ def authorized_workspace_id(self) -> str | None:
154
+ return self.jwt_payload.get("urn:cradle:workspace_id") if self.jwt_payload else None
155
+
156
+ @property
157
+ def jwks_url(self) -> str:
158
+ return f"{self.auth_base_url}/sso/jwks/{self.client_id}"
159
+
160
+ @property
161
+ def jwt_payload(self) -> dict | None:
162
+ if self.token is None:
163
+ return None
164
+
165
+ # Per https://github.com/jpadilla/pyjwt/issues/939 - we should not be verifying the `iat` time here
166
+ # as it could be slightly in the future.
167
+ return jwt.decode(self.token.access_token, options={"verify_signature": False, "verify_iat": False})
168
+
169
+ @property
170
+ def expires_at(self) -> datetime | None:
171
+ if self.jwt_payload is None:
172
+ return None
173
+ return datetime.fromtimestamp(self.jwt_payload["exp"], tz=UTC)
174
+
175
+ @property
176
+ def org_id(self) -> str | None:
177
+ if self.jwt_payload is None:
178
+ return None
179
+ return self.jwt_payload.get("org_id")
180
+
181
+ @property
182
+ def logout_url(self) -> str | None:
183
+ session_id = self.session_id
184
+ if session_id is None:
185
+ return None
186
+
187
+ logout_path = "/user_management/sessions/logout"
188
+ return_to = f"{self.cradle_app_url}/_/sdk/post-logout"
189
+ query = urlencode({"session_id": session_id, "return_to": return_to})
190
+ return urljoin(self.auth_base_url, logout_path) + f"?{query}"
191
+
192
+ def _needs_refresh(self) -> bool:
193
+ if self.access_token is None or self.expires_at is None:
194
+ return False
195
+
196
+ if self.force_org_id is not None and self.org_id != self.force_org_id:
197
+ return True
198
+
199
+ # access_token is valid until expires_at (which I think is currently issue time + 5 min)
200
+ # We refresh the access_token 10 seconds before expiration to avoid race conditions
201
+ return datetime.now(tz=UTC) > self.expires_at - timedelta(seconds=10)
202
+
203
+ def _save_token_to_keyring(self, token: AuthenticateSuccessResponse) -> None:
204
+ """Save token data to keyring."""
205
+ if not self.use_keyring:
206
+ return
207
+
208
+ # The backends built into keyring are secure.
209
+ # They will only be in the list of chainer backends if they are "viable"
210
+ # If this list is empty, it likely means we have no secure keyrings available.
211
+ has_keyring_backend = any(x.__class__.__module__.startswith("keyring.") for x in ChainerBackend.backends)
212
+ # The backends built into keyrings.alt are insecure.
213
+ has_alt_backend = any(x.__class__.__module__.startswith("keyrings.alt") for x in ChainerBackend.backends)
214
+
215
+ if not has_keyring_backend and has_alt_backend:
216
+ logger.warning(
217
+ "Keyring is likely using a backend from `keyrings.alt`. This is an *insecure* way to save your authentication for a limited period of time."
218
+ )
219
+
220
+ # Store the token response, don't worry if saving to keyring fails - maybe this system doesn't support it
221
+ #
222
+ # Some systems don't support keyring, such as a headless linux machine (because there is no GUI to pop up the
223
+ # "do you want to allow this application to access your keyring?" dialog). This can be a problem when people run
224
+ # code on a remote machine.
225
+ #
226
+ # In those cases, a user can install `keyrings.alt` into their notebook or environment, and a new backend called
227
+ # `keyrings.alt.file.PlaintextKeyring` will be available. Because this is insecure, we don't do it by default.
228
+ # However, all a user needs to do is install `keyrings.alt` into their notebook or environment, and it will automatically
229
+ # be added to the resolution order of keyring backends.
230
+ #
231
+ # If the system does not support any secure keyrings, and no alternate backends are installed, we will get a NoKeyringError.
232
+ #
233
+ # If there is a keyring that does not support set_password in the resolution order - such as keyrings.gauth.GooglePythonAuth
234
+ # aka `keyrings.google-artifactregistry-auth`, we will get a NotImplementedError. If there is another keyring below it in the
235
+ # order, like PlaintextKeyring that does, then the ChainerBackend will fall through to the PlaintextKeyring.
236
+ try:
237
+ keyring.set_password(self._keyring_servicename, self.client_id, token.model_dump_json())
238
+ except (NotImplementedError, keyring.errors.NoKeyringError):
239
+ logger.warning("""No keyring available. You will need to re-authenticate every time you run this code.
240
+
241
+ To save your authentication for a limited period of time in an *insecure* way, install python package `keyrings.alt` into your notebook or environment and re-run this code.
242
+ """)
243
+
244
+ def _load_token_from_keyring(self) -> None:
245
+ """Load token data from keyring."""
246
+ if not self.use_keyring:
247
+ return
248
+
249
+ with contextlib.suppress(keyring.errors.KeyringError):
250
+ token_json = keyring.get_password(self._keyring_servicename, self.client_id)
251
+ if token_json is not None:
252
+ self._update_token(AuthenticateSuccessResponse.model_validate(json.loads(token_json)))
253
+
254
+ def _clear_token_from_keyring(self) -> None:
255
+ """Clear token data from keyring."""
256
+ if not self.use_keyring:
257
+ return
258
+
259
+ # Don't worry if deleting from keyring fails - maybe this system doesn't support it
260
+ with contextlib.suppress(keyring.errors.KeyringError):
261
+ keyring.delete_password(self._keyring_servicename, self.client_id)
262
+
263
+ def _update_token(self, success_response: AuthenticateSuccessResponse) -> None:
264
+ self.token = success_response
265
+ # Save to keyring for persistence
266
+ self._save_token_to_keyring(success_response)
267
+
268
+ @staticmethod
269
+ def parse_response(response: httpx.Response) -> AuthenticateSuccessResponse | AuthenticateErrorResponse:
270
+ if response.is_success:
271
+ return AuthenticateSuccessResponse.model_validate(response.json())
272
+ try:
273
+ return AuthenticateAuthenticationErrorResponse.model_validate(response.json())
274
+ except ValidationError:
275
+ return UnknownErrorResponse(response=response.json())
276
+
277
+ def _authenticate(
278
+ self, data: dict[str, str]
279
+ ) -> Generator[httpx.Request, httpx.Response, AuthenticateSuccessResponse | AuthenticateErrorResponse]:
280
+ response = yield httpx.Request(
281
+ method="POST",
282
+ url=f"{self.auth_base_url}/user_management/authenticate",
283
+ headers={"Content-Type": "application/x-www-form-urlencoded", "User-Agent": self.user_agent or "unknown"},
284
+ data={**data, "client_id": self.client_id},
285
+ )
286
+ response.read()
287
+
288
+ return self.parse_response(response)
289
+
290
+ def _refresh_token(self) -> Generator[httpx.Request, httpx.Response, None]:
291
+ if self.refresh_token is None:
292
+ raise ClientAuthError("No refresh token available")
293
+
294
+ refresh_data = {
295
+ "grant_type": "refresh_token",
296
+ "refresh_token": self.refresh_token,
297
+ }
298
+ if self.force_org_id is not None:
299
+ refresh_data["organization_id"] = self.force_org_id
300
+
301
+ response = yield from self._authenticate(refresh_data)
302
+
303
+ match response:
304
+ case AuthenticateAuthenticationErrorResponse(error="invalid_grant"):
305
+ # This might mean the refresh token has expired, so let's clear keyring and re-run the device flow
306
+ self._clear_token_from_keyring()
307
+ self.token = None
308
+ yield from self._start_device_flow()
309
+ case AuthenticateAuthenticationErrorResponse(error="sso_required"):
310
+ # This means the token cannot be refreshed for a different org that requires SSO,
311
+ # so let's clear keyring and re-run the device flow
312
+ self._clear_token_from_keyring()
313
+ self.token = None
314
+ yield from self._start_device_flow()
315
+ case AuthenticateAuthenticationErrorResponse(error="mfa_enrollment"):
316
+ # It might not be possible to refresh a token to a different organization if the new
317
+ # organization has a different security policy in place, e.g. regarding MFA.
318
+ # In this case we need to log out and re-authenticate to the new workspace.
319
+ # Unfortunately this requires the user to manually click a logout link.
320
+ logout_url = self.logout_url
321
+ if logout_url is None:
322
+ raise ClientAuthError("No session ID available to log out.")
323
+
324
+ print(
325
+ "Access to the requested workspace requires re-authentication. "
326
+ "Please log out under the following link and then retry the request.\n\n"
327
+ f"{logout_url}"
328
+ )
329
+
330
+ raise ReauthenticationRequiredError(logout_url)
331
+ case AuthenticateAuthenticationErrorResponse():
332
+ raise ClientAuthError(response.error_description)
333
+ case UnknownErrorResponse():
334
+ raise ClientAuthError(json.dumps(response.response))
335
+ case AuthenticateSuccessResponse():
336
+ self._update_token(response)
337
+
338
+ def _poll_for_tokens(
339
+ self, device_code: str, expires_in: int = 300, interval: int = 5
340
+ ) -> Generator[httpx.Request, httpx.Response, AuthenticateSuccessResponse]:
341
+ """Poll for authentication tokens using device code flow.
342
+
343
+ Args:
344
+ device_code: The device code from the initial auth request
345
+ expires_in: Timeout in seconds (default 300)
346
+ interval: Polling interval in seconds (default 5)
347
+
348
+ Returns:
349
+ AuthenticateSuccessResponse: access and refresh tokens
350
+
351
+ Raises:
352
+ Exception: If polling for access/refresh tokens fails or times out
353
+ """
354
+ start_time = time.monotonic()
355
+
356
+ while True:
357
+ # Check timeout
358
+ if time.monotonic() - start_time > expires_in:
359
+ raise ClientAuthError("Authentication timed out")
360
+
361
+ try:
362
+ refresh_data = {
363
+ "device_code": device_code,
364
+ "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
365
+ "organization_id": self.force_org_id,
366
+ }
367
+ response = yield from self._authenticate(refresh_data)
368
+
369
+ except httpx.TimeoutException as e:
370
+ raise ClientAuthError("Authentication timed out") from e
371
+
372
+ match response:
373
+ case AuthenticateSuccessResponse():
374
+ return response
375
+ case AuthenticateAuthenticationErrorResponse(error="authorization_pending"):
376
+ time.sleep(interval)
377
+ case AuthenticateAuthenticationErrorResponse(error="slow_down"):
378
+ interval += 1
379
+ case AuthenticateAuthenticationErrorResponse():
380
+ raise ClientAuthError(response.error_description)
381
+ case UnknownErrorResponse():
382
+ raise ClientAuthError(json.dumps(response.response))
383
+
384
+ def _start_device_flow(self) -> Generator[httpx.Request, httpx.Response, None]:
385
+ device_auth_http_response: httpx.Response = yield httpx.Request(
386
+ method="POST",
387
+ url=f"{self.auth_base_url}/user_management/authorize/device",
388
+ json={
389
+ "client_id": self.client_id,
390
+ },
391
+ headers={"user-agent": self.user_agent or "unknown"},
392
+ )
393
+ device_auth_http_response.read()
394
+ device_auth_http_response.raise_for_status()
395
+ device_auth_response = DeviceAuthResponse.model_validate(device_auth_http_response.json())
396
+
397
+ print(f"""
398
+ Please click the following URL to complete authentication
399
+
400
+ {device_auth_response.verification_uri_complete}
401
+
402
+ and verify that the code matches: {device_auth_response.user_code}
403
+ """)
404
+
405
+ # Poll for tokens using the device code
406
+ token_data = yield from self._poll_for_tokens(
407
+ device_code=device_auth_response.device_code,
408
+ expires_in=device_auth_response.expires_in,
409
+ interval=device_auth_response.interval,
410
+ )
411
+ self._update_token(token_data)
412
+
413
+ def _check_auth(self) -> Generator[httpx.Request, httpx.Response, None | CheckAuthResponse]:
414
+ if self.cradle_api_base_url is None or self.workspace is None or self.access_token is None:
415
+ return None
416
+
417
+ needs_reauth_response: httpx.Response = yield httpx.Request(
418
+ method="GET",
419
+ url=f"{self.cradle_api_base_url}/auth:checkAuth",
420
+ params={"workspace": self.workspace},
421
+ headers={"Authorization": f"Bearer ca_{self.access_token}", "User-Agent": self.user_agent or "unknown"},
422
+ )
423
+
424
+ needs_reauth_response.read()
425
+ needs_reauth_response.raise_for_status()
426
+ return CheckAuthResponseAdapter.validate_python(needs_reauth_response.json())
427
+
428
+ def auth_flow(self, request):
429
+ if not self.access_token:
430
+ yield from self._start_device_flow()
431
+
432
+ # If we loaded the token from keyring, the access_token is almost certainly expired,
433
+ # so we need to refresh it *before* hitting the checkAuth endpoint, which is authenticated
434
+ # otherwise we get a 403.
435
+ if self._needs_refresh():
436
+ yield from self._refresh_token()
437
+
438
+ # Only check for reauth once - the requested workspace that this client is constructed with
439
+ # will not change, so we won't need to check if we need to reauth to a different workspace/org
440
+ # more than once
441
+ if not self.checked_for_reauth:
442
+ check_auth_response = yield from self._check_auth()
443
+ # can_reauth=True means that we are a member of the workspace we want to access, but need to reauthorize
444
+ # our access_token to the correct org
445
+ if check_auth_response and check_auth_response.authorized == False and check_auth_response.can_reauth: # noqa: E712
446
+ # This will force _refresh_token to request access to the correct org
447
+ self.force_org_id = check_auth_response.org_id
448
+
449
+ self.checked_for_reauth = True
450
+
451
+ # After checking if we need to reauth, this will fire if we need to re-auth the token to a different org
452
+ # from the one it was issued for.
453
+ if self._needs_refresh():
454
+ yield from self._refresh_token()
455
+
456
+ request.headers["Authorization"] = f"Bearer ca_{self.access_token}"
457
+ yield request
458
+
459
+ def suggest_logout(self, workspace: str) -> None: # noqa: ARG002 - might use workspace in the future
460
+ """Print logout URL."""
461
+ logout_url = self.logout_url
462
+ if logout_url is None:
463
+ raise ClientAuthError("No session ID available to log out.")
464
+
465
+ # Don't clear token in case this is caused by fetching a workspace immediately after
466
+ # creating it.
467
+ #
468
+ # Don't clear keyring in case the user decides to recreate the client with a different
469
+ # workspace name.
470
+
471
+ # Reset this check in case we just created a workspace and now need to auth to it
472
+ self.checked_for_reauth = False
473
+
474
+ message = (
475
+ f'Authorized workspace "{self.authorized_workspace_id}" does not match requested workspace "{self.workspace}"\n'
476
+ f"If you believe this is an error, you can log out of this session at {logout_url} and then try the request again."
477
+ )
478
+ raise ClientError(403, message, [])