dreadnode 1.12.2__tar.gz → 1.13.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. {dreadnode-1.12.2 → dreadnode-1.13.1}/PKG-INFO +2 -1
  2. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/__init__.py +7 -2
  3. dreadnode-1.13.1/dreadnode/__main__.py +10 -0
  4. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/api/client.py +124 -28
  5. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/api/models.py +34 -10
  6. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/api/util.py +1 -1
  7. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/artifact/storage.py +15 -1
  8. dreadnode-1.13.1/dreadnode/cli/__init__.py +3 -0
  9. dreadnode-1.13.1/dreadnode/cli/api.py +79 -0
  10. dreadnode-1.13.1/dreadnode/cli/github.py +273 -0
  11. dreadnode-1.13.1/dreadnode/cli/main.py +200 -0
  12. dreadnode-1.13.1/dreadnode/cli/profile/__init__.py +3 -0
  13. dreadnode-1.13.1/dreadnode/cli/profile/cli.py +101 -0
  14. dreadnode-1.13.1/dreadnode/config.py +108 -0
  15. dreadnode-1.13.1/dreadnode/constants.py +62 -0
  16. dreadnode-1.13.1/dreadnode/data_types/__init__.py +9 -0
  17. dreadnode-1.13.1/dreadnode/lookup.py +146 -0
  18. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/main.py +153 -114
  19. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/object.py +1 -2
  20. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/__init__.py +27 -3
  21. dreadnode-1.13.1/dreadnode/scorers/classification.py +102 -0
  22. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/consistency.py +16 -18
  23. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/contains.py +50 -31
  24. dreadnode-1.13.1/dreadnode/scorers/format.py +61 -0
  25. dreadnode-1.13.1/dreadnode/scorers/harm.py +55 -0
  26. dreadnode-1.12.2/dreadnode/scorers/llm_judge.py → dreadnode-1.13.1/dreadnode/scorers/judge.py +18 -20
  27. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/length.py +28 -23
  28. dreadnode-1.13.1/dreadnode/scorers/lexical.py +67 -0
  29. dreadnode-1.13.1/dreadnode/scorers/operators.py +122 -0
  30. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/pii.py +1 -2
  31. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/readability.py +6 -2
  32. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/sentiment.py +14 -2
  33. dreadnode-1.13.1/dreadnode/scorers/similarity.py +293 -0
  34. dreadnode-1.13.1/dreadnode/scorers/util.py +33 -0
  35. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/serialization.py +1 -1
  36. dreadnode-1.13.1/dreadnode/storage_utils.py +37 -0
  37. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/task.py +0 -84
  38. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/tracing/span.py +19 -8
  39. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/util.py +119 -0
  40. {dreadnode-1.12.2 → dreadnode-1.13.1}/pyproject.toml +13 -2
  41. dreadnode-1.12.2/dreadnode/constants.py +0 -16
  42. dreadnode-1.12.2/dreadnode/data_types/__init__.py +0 -9
  43. dreadnode-1.12.2/dreadnode/scorers/similarity.py +0 -180
  44. {dreadnode-1.12.2 → dreadnode-1.13.1}/README.md +0 -0
  45. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/api/__init__.py +0 -0
  46. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/artifact/__init__.py +0 -0
  47. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/artifact/merger.py +0 -0
  48. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/artifact/tree_builder.py +0 -0
  49. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/convert.py +0 -0
  50. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/audio.py +0 -0
  51. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/base.py +0 -0
  52. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/image.py +0 -0
  53. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/object_3d.py +0 -0
  54. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/table.py +0 -0
  55. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/text.py +0 -0
  56. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/video.py +0 -0
  57. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/integrations/__init__.py +0 -0
  58. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/integrations/transformers.py +0 -0
  59. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/metric.py +0 -0
  60. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/py.typed +0 -0
  61. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/rigging.py +0 -0
  62. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/tracing/__init__.py +0 -0
  63. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/tracing/constants.py +0 -0
  64. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/tracing/exporters.py +0 -0
  65. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/types.py +0 -0
  66. {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/version.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: dreadnode
3
- Version: 1.12.2
3
+ Version: 1.13.1
4
4
  Summary: Dreadnode SDK
5
5
  Author: Nick Landers
6
6
  Author-email: monoxgas@gmail.com
@@ -14,6 +14,7 @@ Provides-Extra: all
14
14
  Provides-Extra: multimodal
15
15
  Provides-Extra: training
16
16
  Requires-Dist: coolname (>=2.2.0,<3.0.0)
17
+ Requires-Dist: cyclopts (>=3.22.2,<4.0.0)
17
18
  Requires-Dist: fsspec[s3] (>=2023.1.0,<=2025.3.0)
18
19
  Requires-Dist: httpx (>=0.28.0,<0.29.0)
19
20
  Requires-Dist: logfire (>=3.5.3,<=3.20.0)
@@ -1,9 +1,10 @@
1
1
  from dreadnode import convert, data_types, scorers
2
2
  from dreadnode.data_types import Audio, Code, Image, Markdown, Object3D, Table, Text, Video
3
+ from dreadnode.lookup import Lookup, lookup_input, lookup_output, lookup_param, resolve_lookup
3
4
  from dreadnode.main import DEFAULT_INSTANCE, Dreadnode
4
5
  from dreadnode.metric import Metric, MetricDict, Scorer
5
6
  from dreadnode.object import Object
6
- from dreadnode.task import Task, TaskInput
7
+ from dreadnode.task import Task
7
8
  from dreadnode.tracing.span import RunSpan, Span, TaskSpan
8
9
  from dreadnode.version import VERSION
9
10
 
@@ -39,6 +40,7 @@ __all__ = [
39
40
  "Code",
40
41
  "Dreadnode",
41
42
  "Image",
43
+ "Lookup",
42
44
  "Markdown",
43
45
  "Metric",
44
46
  "MetricDict",
@@ -50,7 +52,6 @@ __all__ = [
50
52
  "Span",
51
53
  "Table",
52
54
  "Task",
53
- "TaskInput",
54
55
  "TaskSpan",
55
56
  "Text",
56
57
  "Video",
@@ -69,7 +70,11 @@ __all__ = [
69
70
  "log_output",
70
71
  "log_param",
71
72
  "log_params",
73
+ "lookup_input",
74
+ "lookup_output",
75
+ "lookup_param",
72
76
  "push_update",
77
+ "resolve_lookup",
73
78
  "run",
74
79
  "scorer",
75
80
  "scorers",
@@ -0,0 +1,10 @@
1
+ from dreadnode.cli import cli
2
+
3
+
4
def run() -> None:
    """Entry point for ``python -m dreadnode``: hand control to the CLI app's meta runner."""
    cli.meta()


if __name__ == "__main__":  # only when executed as a script / via ``python -m``
    run()
@@ -1,22 +1,19 @@
1
1
  import io
2
2
  import json
3
+ import time
3
4
  import typing as t
5
+ from datetime import datetime, timezone
6
+ from urllib.parse import urlparse
4
7
 
5
8
  import httpx
6
9
  import pandas as pd
7
10
  from pydantic import BaseModel
8
11
  from ulid import ULID
9
12
 
10
- from dreadnode.api.util import (
11
- convert_flat_tasks_to_tree,
12
- convert_flat_trace_to_tree,
13
- process_run,
14
- process_task,
15
- )
16
- from dreadnode.util import logger
17
- from dreadnode.version import VERSION
18
-
19
- from .models import (
13
+ from dreadnode.api.models import (
14
+ AccessRefreshTokenResponse,
15
+ DeviceCodeResponse,
16
+ GithubTokenResponse,
20
17
  MetricAggregationType,
21
18
  Project,
22
19
  RawRun,
@@ -31,7 +28,21 @@ from .models import (
31
28
  TraceSpan,
32
29
  TraceTree,
33
30
  UserDataCredentials,
31
+ UserResponse,
34
32
  )
33
+ from dreadnode.api.util import (
34
+ convert_flat_tasks_to_tree,
35
+ convert_flat_trace_to_tree,
36
+ process_run,
37
+ process_task,
38
+ )
39
+ from dreadnode.constants import (
40
+ DEFAULT_FS_CREDENTIAL_DURATION,
41
+ DEFAULT_MAX_POLL_TIME,
42
+ DEFAULT_POLL_INTERVAL,
43
+ )
44
+ from dreadnode.util import logger
45
+ from dreadnode.version import VERSION
35
46
 
36
47
# Generic type variable standing for "some pydantic model class" (bound to BaseModel).
ModelT = t.TypeVar("ModelT", bound=BaseModel)
37
48
 
@@ -47,11 +58,13 @@ class ApiClient:
47
58
  def __init__(
48
59
  self,
49
60
  base_url: str,
50
- api_key: str,
51
61
  *,
62
+ api_key: str | None = None,
63
+ cookies: dict[str, str] | None = None,
52
64
  debug: bool = False,
53
65
  ):
54
- """Initializes the API client.
66
+ """
67
+ Initializes the API client.
55
68
 
56
69
  Args:
57
70
  base_url (str): The base URL of the Dreadnode API.
@@ -62,12 +75,28 @@ class ApiClient:
62
75
  if not self._base_url.endswith("/api"):
63
76
  self._base_url += "/api"
64
77
 
78
+ _cookies = httpx.Cookies()
79
+ cookie_domain = urlparse(base_url).hostname
80
+ if cookie_domain is None:
81
+ raise ValueError(f"Invalid URL: {base_url}")
82
+
83
+ if cookie_domain == "localhost":
84
+ cookie_domain = "localhost.local"
85
+
86
+ for key, value in (cookies or {}).items():
87
+ _cookies.set(key, value, domain=cookie_domain)
88
+
89
+ headers = {
90
+ "User-Agent": f"dreadnode-sdk/{VERSION}",
91
+ "Accept": "application/json",
92
+ }
93
+
94
+ if api_key:
95
+ headers["X-Api-Key"] = api_key
96
+
65
97
  self._client = httpx.Client(
66
- headers={
67
- "User-Agent": f"dreadnode-sdk/{VERSION}",
68
- "Accept": "application/json",
69
- "X-API-Key": api_key,
70
- },
98
+ headers=headers,
99
+ cookies=_cookies,
71
100
  base_url=self._base_url,
72
101
  timeout=30,
73
102
  )
@@ -77,7 +106,8 @@ class ApiClient:
77
106
  self._client.event_hooks["response"].append(self._log_response)
78
107
 
79
108
  def _log_request(self, request: httpx.Request) -> None:
80
- """Logs HTTP requests if debug mode is enabled.
109
+ """
110
+ Logs HTTP requests if debug mode is enabled.
81
111
 
82
112
  Args:
83
113
  request (httpx.Request): The HTTP request object.
@@ -90,7 +120,8 @@ class ApiClient:
90
120
  logger.debug("-------------------------------------------")
91
121
 
92
122
  def _log_response(self, response: httpx.Response) -> None:
93
- """Logs HTTP responses if debug mode is enabled.
123
+ """
124
+ Logs HTTP responses if debug mode is enabled.
94
125
 
95
126
  Args:
96
127
  response (httpx.Response): The HTTP response object.
@@ -103,7 +134,8 @@ class ApiClient:
103
134
  logger.debug("--------------------------------------------")
104
135
 
105
136
  def _get_error_message(self, response: httpx.Response) -> str:
106
- """Extracts the error message from an HTTP response.
137
+ """
138
+ Extracts the error message from an HTTP response.
107
139
 
108
140
  Args:
109
141
  response (httpx.Response): The HTTP response object.
@@ -125,7 +157,8 @@ class ApiClient:
125
157
  params: dict[str, t.Any] | None = None,
126
158
  json_data: dict[str, t.Any] | None = None,
127
159
  ) -> httpx.Response:
128
- """Makes a raw HTTP request to the API.
160
+ """
161
+ Makes a raw HTTP request to the API.
129
162
 
130
163
  Args:
131
164
  method (str): The HTTP method (e.g., "GET", "POST").
@@ -146,7 +179,8 @@ class ApiClient:
146
179
  params: dict[str, t.Any] | None = None,
147
180
  json_data: dict[str, t.Any] | None = None,
148
181
  ) -> httpx.Response:
149
- """Makes an HTTP request to the API and raises exceptions for errors.
182
+ """
183
+ Makes an HTTP request to the API and raises exceptions for errors.
150
184
 
151
185
  Args:
152
186
  method (str): The HTTP method (e.g., "GET", "POST").
@@ -170,6 +204,59 @@ class ApiClient:
170
204
 
171
205
  return response
172
206
 
207
+ # Auth
208
+
209
def url_for_user_code(self, user_code: str) -> str:
    """
    Build the browser URL where a user confirms ``user_code`` during device login.

    The API base URL carries an ``/api`` suffix; the verification page lives on
    the site root, so that suffix is stripped first.
    """
    site_root = self._base_url.removesuffix("/api")
    return f"{site_root}/account/device?code={user_code}"
213
+
214
def get_device_codes(self) -> DeviceCodeResponse:
    """
    Begin the device-code login flow.

    Sends ``POST /auth/device/code`` through the error-checked ``request`` helper
    and parses the JSON body into a ``DeviceCodeResponse``.
    """
    payload = self.request("POST", "/auth/device/code").json()
    return DeviceCodeResponse(**payload)
219
+
220
def poll_for_token(
    self,
    device_code: str,
    interval: int = DEFAULT_POLL_INTERVAL,
    max_poll_time: int = DEFAULT_MAX_POLL_TIME,
) -> AccessRefreshTokenResponse:
    """
    Poll the device-token endpoint until the device code is approved.

    Args:
        device_code: The device code to exchange, sent as ``{"device_code": ...}``
            to ``POST /auth/device/token``.
        interval: Seconds to wait between polling attempts.
        max_poll_time: Total seconds to keep polling before giving up.

    Returns:
        The access/refresh token pair parsed from the first HTTP 200 response.

    Raises:
        RuntimeError: If the endpoint answers with a status other than 200 or
            401 (message taken from ``_get_error_message``), or if
            ``max_poll_time`` elapses without approval.
    """
    # Measure the timeout with a monotonic clock: wall-clock time can jump
    # (NTP sync, manual changes), which would cut the wait short or stretch it.
    deadline = time.monotonic() + max_poll_time

    while time.monotonic() < deadline:
        # Raw ``_request`` (not ``request``) because non-2xx statuses are
        # inspected here rather than raised.
        response = self._request(
            "POST", "/auth/device/token", json_data={"device_code": device_code}
        )

        if response.status_code == 200:  # noqa: PLR2004
            return AccessRefreshTokenResponse(**response.json())
        # 401 means "not approved yet" -> keep polling; anything else is fatal.
        if response.status_code != 401:  # noqa: PLR2004
            raise RuntimeError(self._get_error_message(response))

        # Don't sleep past the deadline only to time out afterwards.
        time.sleep(max(0.0, min(interval, deadline - time.monotonic())))

    raise RuntimeError("Polling for token timed out")
242
+
243
+ # User
244
+
245
def get_user(self) -> UserResponse:
    """Fetch the currently authenticated user via ``GET /user`` and parse it into a ``UserResponse``."""
    payload = self.request("GET", "/user").json()
    return UserResponse(**payload)
250
+
251
+ # Github
252
+
253
def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse:
    """
    Request a GitHub access token scoped to ``repos``.

    Sends ``POST /github/token`` with body ``{"repos": repos}`` through the
    error-checked ``request`` helper and parses the reply.
    """
    body = {"repos": repos}
    payload = self.request("POST", "/github/token", json_data=body).json()
    return GithubTokenResponse(**payload)
257
+
258
+ # Strikes
259
+
173
260
  def list_projects(self) -> list[Project]:
174
261
  """Retrieves a list of projects.
175
262
 
@@ -294,7 +381,8 @@ class ApiClient:
294
381
  status: StatusFilter = "completed",
295
382
  aggregations: list[MetricAggregationType] | None = None,
296
383
  ) -> pd.DataFrame:
297
- """Exports run data for a specific project.
384
+ """
385
+ Exports run data for a specific project.
298
386
 
299
387
  Args:
300
388
  project: The project identifier.
@@ -327,7 +415,8 @@ class ApiClient:
327
415
  metrics: list[str] | None = None,
328
416
  aggregations: list[MetricAggregationType] | None = None,
329
417
  ) -> pd.DataFrame:
330
- """Exports metric data for a specific project.
418
+ """
419
+ Exports metric data for a specific project.
331
420
 
332
421
  Args:
333
422
  project: The project identifier.
@@ -363,7 +452,8 @@ class ApiClient:
363
452
  metrics: list[str] | None = None,
364
453
  aggregations: list[MetricAggregationType] | None = None,
365
454
  ) -> pd.DataFrame:
366
- """Exports parameter data for a specific project.
455
+ """
456
+ Exports parameter data for a specific project.
367
457
 
368
458
  Args:
369
459
  project: The project identifier.
@@ -401,7 +491,8 @@ class ApiClient:
401
491
  time_axis: TimeAxisType = "relative",
402
492
  aggregations: list[TimeAggregationType] | None = None,
403
493
  ) -> pd.DataFrame:
404
- """Exports timeseries data for a specific project.
494
+ """
495
+ Exports timeseries data for a specific project.
405
496
 
406
497
  Args:
407
498
  project: The project identifier.
@@ -430,12 +521,17 @@ class ApiClient:
430
521
 
431
522
  # User data access
432
523
 
433
def get_user_data_credentials(
    self, duration: int = DEFAULT_FS_CREDENTIAL_DURATION
) -> UserDataCredentials:
    """
    Retrieves user data credentials for secondary storage access.

    Args:
        duration: Requested credential lifetime in seconds, sent as the
            ``duration`` query parameter (defaults to
            ``DEFAULT_FS_CREDENTIAL_DURATION``).

    Returns:
        The user data credentials object.

    Raises:
        Whatever ``request`` raises for an error HTTP status.
    """
    # Use the checked ``request`` (not raw ``_request``): an error response must
    # surface as an HTTP error, not be parsed as if it were credentials and
    # fail later with an unrelated model-validation error.
    response = self.request("GET", "/user-data/credentials", params={"duration": duration})
    return UserDataCredentials(**response.json())
@@ -32,6 +32,35 @@ class UserResponse(BaseModel):
32
32
  api_key: UserAPIKey
33
33
 
34
34
 
35
class UserDataCredentials(BaseModel):
    # Storage-access credentials returned by ApiClient.get_user_data_credentials
    # (GET /user-data/credentials).
    # NOTE(review): the key id / secret / session token trio looks like
    # S3-style temporary credentials (fsspec[s3] is a dependency) — confirm.
    access_key_id: str
    secret_access_key: str
    session_token: str
    # When these credentials stop being valid.
    expiration: datetime
    region: str
    bucket: str
    # Key prefix within ``bucket``; presumably the user's data root — confirm.
    prefix: str
    # Custom storage endpoint, or None (presumably meaning the provider default — confirm).
    endpoint: str | None
44
+
45
+
46
+ # Auth
47
+
48
+
49
class DeviceCodeResponse(BaseModel):
    # Parsed reply of POST /auth/device/code, built by ApiClient.get_device_codes
    # to start the device-code login flow.
    id: UUID
    completed: bool
    # The value ApiClient.poll_for_token sends to POST /auth/device/token.
    device_code: str
    expires_at: datetime
    # NOTE(review): presumably the lifetime in seconds matching expires_at — confirm.
    expires_in: int
    # Code the user confirms in the browser (see ApiClient.url_for_user_code).
    user_code: str
    verification_url: str
57
+
58
+
59
class AccessRefreshTokenResponse(BaseModel):
    # Token pair returned by ApiClient.poll_for_token once POST /auth/device/token
    # answers HTTP 200 (i.e. the device code was approved).
    access_token: str
    refresh_token: str
62
+
63
+
35
64
  # Strikes
36
65
 
37
66
  SpanStatus = t.Literal[
@@ -406,15 +435,10 @@ class TraceTree(BaseModel):
406
435
  """Children of this span, representing nested spans or tasks."""
407
436
 
408
437
 
409
- # User data credentials
438
+ # Github
410
439
 
411
440
 
412
- class UserDataCredentials(BaseModel):
413
- access_key_id: str
414
- secret_access_key: str
415
- session_token: str
416
- expiration: datetime
417
- region: str
418
- bucket: str
419
- prefix: str
420
- endpoint: str | None
441
class GithubTokenResponse(BaseModel):
    # Parsed reply of POST /github/token, built by ApiClient.get_github_access_token.
    token: str
    expires_at: datetime
    # NOTE(review): presumably the repositories the token actually grants access
    # to, which may differ from those requested — confirm with the API.
    repos: list[str]
@@ -1,6 +1,6 @@
1
1
  from logging import getLogger
2
2
 
3
- from .models import (
3
+ from dreadnode.api.models import (
4
4
  Object,
5
5
  ObjectUri,
6
6
  ObjectVal,
@@ -4,10 +4,12 @@ Provides efficient uploading of files and directories with deduplication.
4
4
  """
5
5
 
6
6
  import hashlib
7
+ import typing as t
7
8
  from pathlib import Path
8
9
 
9
10
  import fsspec # type: ignore[import-untyped]
10
11
 
12
+ from dreadnode.storage_utils import with_credential_refresh
11
13
  from dreadnode.util import logger
12
14
 
13
15
# NOTE(review): presumably the part size used for multipart uploads of large
# files (see store_file) — confirm at the use site.
CHUNK_SIZE = 8 * 1024 * 1024  # 8MB
@@ -22,15 +24,27 @@ class ArtifactStorage:
22
24
  - Batch uploads for directories handled by fsspec
23
25
  """
24
26
 
25
def __init__(
    self,
    file_system: fsspec.AbstractFileSystem,
    credential_refresher: t.Callable[[], bool] | None = None,
):
    """
    Initialize artifact storage on top of an fsspec file system.

    Args:
        file_system: FSSpec-compatible file system that artifacts are stored in.
        credential_refresher: Optional zero-argument callable that refreshes the
            file system's credentials when they are about to expire. When None,
            no refresh is attempted.
    """
    self._file_system = file_system
    self._credential_refresher = credential_refresher
33
41
 
42
def _refresh_credentials_if_needed(self) -> None:
    """Call the configured credential refresher, if any; its return value is discarded."""
    refresher = self._credential_refresher
    if not refresher:
        return
    refresher()
46
+
47
+ @with_credential_refresh
34
48
  def store_file(self, file_path: Path, target_key: str) -> str:
35
49
  """
36
50
  Store a file in the storage system, using multipart upload for large files.
@@ -0,0 +1,3 @@
1
+ from dreadnode.cli.main import cli
2
+
3
# The package's public surface is just the CLI application object.
__all__ = ["cli"]
@@ -0,0 +1,79 @@
1
+ import atexit
2
+ import base64
3
+ import json
4
+ from datetime import datetime, timezone
5
+
6
+ from dreadnode.api.client import ApiClient
7
+ from dreadnode.config import UserConfig
8
+ from dreadnode.constants import (
9
+ DEFAULT_TOKEN_MAX_TTL,
10
+ )
11
+
12
+
13
class Token:
    """A JWT together with the expiry time decoded from its ``exp`` claim."""

    data: str
    expires_at: datetime

    @staticmethod
    def parse_jwt_token_expiration(token: str) -> datetime:
        """
        Return the expiration time stored in a JWT's ``exp`` claim.

        The signature is NOT verified; the payload is only decoded to read ``exp``.

        Args:
            token: A compact-serialized JWT (``header.payload.signature``).

        Returns:
            The ``exp`` claim as a timezone-aware UTC datetime.

        Raises:
            ValueError: If the token does not have three dot-separated segments,
                the payload is not base64url-encoded JSON, or ``exp`` is missing
                or not a number.
        """
        parts = token.split(".")
        if len(parts) != 3:  # noqa: PLR2004
            raise ValueError("Malformed JWT: expected three dot-separated segments")

        b64payload = parts[1]
        # JWT segments are base64url without padding; restore exactly what's missing.
        padded = b64payload + "=" * (-len(b64payload) % 4)
        try:
            # binascii.Error, UnicodeDecodeError and JSONDecodeError all subclass ValueError.
            claims = json.loads(base64.urlsafe_b64decode(padded).decode("utf-8"))
        except ValueError as e:
            raise ValueError("Malformed JWT: payload is not base64url-encoded JSON") from e

        exp = claims.get("exp") if isinstance(claims, dict) else None
        # bool is an int subclass, so reject it explicitly.
        if isinstance(exp, bool) or not isinstance(exp, (int, float)):
            raise ValueError("Malformed JWT: missing or non-numeric 'exp' claim")

        return datetime.fromtimestamp(exp, tz=timezone.utc)

    def __init__(self, token: str):
        self.data = token
        self.expires_at = Token.parse_jwt_token_expiration(token)

    def ttl(self) -> int:
        """Get the number of whole seconds left until the token expires (negative once expired)."""
        return int((self.expires_at - datetime.now(tz=timezone.utc)).total_seconds())

    def is_expired(self) -> bool:
        """Return True if the token is expired."""
        return self.ttl() <= 0

    def is_close_to_expiry(self) -> bool:
        """Return True if at most DEFAULT_TOKEN_MAX_TTL seconds remain before expiry."""
        return self.ttl() <= DEFAULT_TOKEN_MAX_TTL
42
+
43
+
44
def create_api_client(*, profile: str | None = None) -> ApiClient:
    """
    Build an ApiClient authenticated with the tokens stored for a profile.

    Reads the user configuration, takes the server config for ``profile``
    (forwarded unchanged to ``get_server_config``), and seeds the client's
    cookie jar with the stored access and refresh tokens. An ``atexit`` hook is
    registered that writes back any token in the client's cookie jar that is
    non-empty and differs from the stored value.

    Raises:
        RuntimeError: If the stored refresh token has already expired.
    """
    user_config = UserConfig.read()
    config = user_config.get_server_config(profile)

    seed_cookies = {
        "access_token": config.access_token,
        "refresh_token": config.refresh_token,
    }
    client = ApiClient(config.url, cookies=seed_cookies)

    # Fail fast here rather than on the first request the client makes.
    refresh = Token(config.refresh_token)
    if refresh.is_expired():
        raise RuntimeError("Authentication expired, use [bold]dreadnode login[/]")

    def _flush_auth_changes() -> None:
        """Persist tokens rotated during this process back to the user config."""
        jar = client._client.cookies  # noqa: SLF001
        dirty = False
        for name in ("access_token", "refresh_token"):
            current = jar.get(name)
            if current and current != getattr(config, name):
                setattr(config, name, current)
                dirty = True
        if dirty:
            user_config.set_server_config(config, profile).write()

    atexit.register(_flush_auth_changes)
    return client