dreadnode 1.12.1__tar.gz → 1.13.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dreadnode-1.12.1 → dreadnode-1.13.0}/PKG-INFO +2 -1
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/__init__.py +7 -2
- dreadnode-1.13.0/dreadnode/__main__.py +10 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/api/client.py +113 -26
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/api/models.py +34 -10
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/api/util.py +1 -1
- dreadnode-1.13.0/dreadnode/cli/__init__.py +3 -0
- dreadnode-1.13.0/dreadnode/cli/api.py +79 -0
- dreadnode-1.13.0/dreadnode/cli/github.py +273 -0
- dreadnode-1.13.0/dreadnode/cli/main.py +200 -0
- dreadnode-1.13.0/dreadnode/cli/profile/__init__.py +3 -0
- dreadnode-1.13.0/dreadnode/cli/profile/cli.py +101 -0
- dreadnode-1.13.0/dreadnode/config.py +108 -0
- dreadnode-1.13.0/dreadnode/constants.py +57 -0
- dreadnode-1.13.0/dreadnode/data_types/__init__.py +9 -0
- dreadnode-1.13.0/dreadnode/lookup.py +146 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/main.py +87 -102
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/object.py +3 -1
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/__init__.py +27 -3
- dreadnode-1.13.0/dreadnode/scorers/classification.py +102 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/consistency.py +16 -18
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/contains.py +50 -31
- dreadnode-1.13.0/dreadnode/scorers/format.py +61 -0
- dreadnode-1.13.0/dreadnode/scorers/harm.py +55 -0
- dreadnode-1.12.1/dreadnode/scorers/llm_judge.py → dreadnode-1.13.0/dreadnode/scorers/judge.py +18 -20
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/length.py +28 -23
- dreadnode-1.13.0/dreadnode/scorers/lexical.py +67 -0
- dreadnode-1.13.0/dreadnode/scorers/operators.py +122 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/pii.py +1 -2
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/readability.py +6 -2
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/sentiment.py +14 -2
- dreadnode-1.13.0/dreadnode/scorers/similarity.py +293 -0
- dreadnode-1.13.0/dreadnode/scorers/util.py +33 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/task.py +0 -84
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/tracing/span.py +6 -7
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/util.py +119 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/pyproject.toml +13 -2
- dreadnode-1.12.1/dreadnode/constants.py +0 -16
- dreadnode-1.12.1/dreadnode/data_types/__init__.py +0 -9
- dreadnode-1.12.1/dreadnode/scorers/similarity.py +0 -180
- {dreadnode-1.12.1 → dreadnode-1.13.0}/README.md +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/api/__init__.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/artifact/__init__.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/artifact/merger.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/artifact/storage.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/artifact/tree_builder.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/convert.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/audio.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/base.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/image.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/object_3d.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/table.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/text.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/data_types/video.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/integrations/__init__.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/integrations/transformers.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/metric.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/py.typed +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/scorers/rigging.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/serialization.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/tracing/__init__.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/tracing/constants.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/tracing/exporters.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/types.py +0 -0
- {dreadnode-1.12.1 → dreadnode-1.13.0}/dreadnode/version.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: dreadnode
|
|
3
|
-
Version: 1.12.1
|
|
3
|
+
Version: 1.13.0
|
|
4
4
|
Summary: Dreadnode SDK
|
|
5
5
|
Author: Nick Landers
|
|
6
6
|
Author-email: monoxgas@gmail.com
|
|
@@ -14,6 +14,7 @@ Provides-Extra: all
|
|
|
14
14
|
Provides-Extra: multimodal
|
|
15
15
|
Provides-Extra: training
|
|
16
16
|
Requires-Dist: coolname (>=2.2.0,<3.0.0)
|
|
17
|
+
Requires-Dist: cyclopts (>=3.22.2,<4.0.0)
|
|
17
18
|
Requires-Dist: fsspec[s3] (>=2023.1.0,<=2025.3.0)
|
|
18
19
|
Requires-Dist: httpx (>=0.28.0,<0.29.0)
|
|
19
20
|
Requires-Dist: logfire (>=3.5.3,<=3.20.0)
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from dreadnode import convert, data_types, scorers
|
|
2
2
|
from dreadnode.data_types import Audio, Code, Image, Markdown, Object3D, Table, Text, Video
|
|
3
|
+
from dreadnode.lookup import Lookup, lookup_input, lookup_output, lookup_param, resolve_lookup
|
|
3
4
|
from dreadnode.main import DEFAULT_INSTANCE, Dreadnode
|
|
4
5
|
from dreadnode.metric import Metric, MetricDict, Scorer
|
|
5
6
|
from dreadnode.object import Object
|
|
6
|
-
from dreadnode.task import Task
|
|
7
|
+
from dreadnode.task import Task
|
|
7
8
|
from dreadnode.tracing.span import RunSpan, Span, TaskSpan
|
|
8
9
|
from dreadnode.version import VERSION
|
|
9
10
|
|
|
@@ -39,6 +40,7 @@ __all__ = [
|
|
|
39
40
|
"Code",
|
|
40
41
|
"Dreadnode",
|
|
41
42
|
"Image",
|
|
43
|
+
"Lookup",
|
|
42
44
|
"Markdown",
|
|
43
45
|
"Metric",
|
|
44
46
|
"MetricDict",
|
|
@@ -50,7 +52,6 @@ __all__ = [
|
|
|
50
52
|
"Span",
|
|
51
53
|
"Table",
|
|
52
54
|
"Task",
|
|
53
|
-
"TaskInput",
|
|
54
55
|
"TaskSpan",
|
|
55
56
|
"Text",
|
|
56
57
|
"Video",
|
|
@@ -69,7 +70,11 @@ __all__ = [
|
|
|
69
70
|
"log_output",
|
|
70
71
|
"log_param",
|
|
71
72
|
"log_params",
|
|
73
|
+
"lookup_input",
|
|
74
|
+
"lookup_output",
|
|
75
|
+
"lookup_param",
|
|
72
76
|
"push_update",
|
|
77
|
+
"resolve_lookup",
|
|
73
78
|
"run",
|
|
74
79
|
"scorer",
|
|
75
80
|
"scorers",
|
|
@@ -1,22 +1,19 @@
|
|
|
1
1
|
import io
|
|
2
2
|
import json
|
|
3
|
+
import time
|
|
3
4
|
import typing as t
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from urllib.parse import urlparse
|
|
4
7
|
|
|
5
8
|
import httpx
|
|
6
9
|
import pandas as pd
|
|
7
10
|
from pydantic import BaseModel
|
|
8
11
|
from ulid import ULID
|
|
9
12
|
|
|
10
|
-
from dreadnode.api.
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
process_task,
|
|
15
|
-
)
|
|
16
|
-
from dreadnode.util import logger
|
|
17
|
-
from dreadnode.version import VERSION
|
|
18
|
-
|
|
19
|
-
from .models import (
|
|
13
|
+
from dreadnode.api.models import (
|
|
14
|
+
AccessRefreshTokenResponse,
|
|
15
|
+
DeviceCodeResponse,
|
|
16
|
+
GithubTokenResponse,
|
|
20
17
|
MetricAggregationType,
|
|
21
18
|
Project,
|
|
22
19
|
RawRun,
|
|
@@ -31,7 +28,17 @@ from .models import (
|
|
|
31
28
|
TraceSpan,
|
|
32
29
|
TraceTree,
|
|
33
30
|
UserDataCredentials,
|
|
31
|
+
UserResponse,
|
|
34
32
|
)
|
|
33
|
+
from dreadnode.api.util import (
|
|
34
|
+
convert_flat_tasks_to_tree,
|
|
35
|
+
convert_flat_trace_to_tree,
|
|
36
|
+
process_run,
|
|
37
|
+
process_task,
|
|
38
|
+
)
|
|
39
|
+
from dreadnode.constants import DEFAULT_MAX_POLL_TIME, DEFAULT_POLL_INTERVAL
|
|
40
|
+
from dreadnode.util import logger
|
|
41
|
+
from dreadnode.version import VERSION
|
|
35
42
|
|
|
36
43
|
ModelT = t.TypeVar("ModelT", bound=BaseModel)
|
|
37
44
|
|
|
@@ -47,11 +54,13 @@ class ApiClient:
|
|
|
47
54
|
def __init__(
|
|
48
55
|
self,
|
|
49
56
|
base_url: str,
|
|
50
|
-
api_key: str,
|
|
51
57
|
*,
|
|
58
|
+
api_key: str | None = None,
|
|
59
|
+
cookies: dict[str, str] | None = None,
|
|
52
60
|
debug: bool = False,
|
|
53
61
|
):
|
|
54
|
-
"""
|
|
62
|
+
"""
|
|
63
|
+
Initializes the API client.
|
|
55
64
|
|
|
56
65
|
Args:
|
|
57
66
|
base_url (str): The base URL of the Dreadnode API.
|
|
@@ -62,12 +71,28 @@ class ApiClient:
|
|
|
62
71
|
if not self._base_url.endswith("/api"):
|
|
63
72
|
self._base_url += "/api"
|
|
64
73
|
|
|
74
|
+
_cookies = httpx.Cookies()
|
|
75
|
+
cookie_domain = urlparse(base_url).hostname
|
|
76
|
+
if cookie_domain is None:
|
|
77
|
+
raise ValueError(f"Invalid URL: {base_url}")
|
|
78
|
+
|
|
79
|
+
if cookie_domain == "localhost":
|
|
80
|
+
cookie_domain = "localhost.local"
|
|
81
|
+
|
|
82
|
+
for key, value in (cookies or {}).items():
|
|
83
|
+
_cookies.set(key, value, domain=cookie_domain)
|
|
84
|
+
|
|
85
|
+
headers = {
|
|
86
|
+
"User-Agent": f"dreadnode-sdk/{VERSION}",
|
|
87
|
+
"Accept": "application/json",
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
if api_key:
|
|
91
|
+
headers["X-Api-Key"] = api_key
|
|
92
|
+
|
|
65
93
|
self._client = httpx.Client(
|
|
66
|
-
headers=
|
|
67
|
-
|
|
68
|
-
"Accept": "application/json",
|
|
69
|
-
"X-API-Key": api_key,
|
|
70
|
-
},
|
|
94
|
+
headers=headers,
|
|
95
|
+
cookies=_cookies,
|
|
71
96
|
base_url=self._base_url,
|
|
72
97
|
timeout=30,
|
|
73
98
|
)
|
|
@@ -77,7 +102,8 @@ class ApiClient:
|
|
|
77
102
|
self._client.event_hooks["response"].append(self._log_response)
|
|
78
103
|
|
|
79
104
|
def _log_request(self, request: httpx.Request) -> None:
|
|
80
|
-
"""
|
|
105
|
+
"""
|
|
106
|
+
Logs HTTP requests if debug mode is enabled.
|
|
81
107
|
|
|
82
108
|
Args:
|
|
83
109
|
request (httpx.Request): The HTTP request object.
|
|
@@ -90,7 +116,8 @@ class ApiClient:
|
|
|
90
116
|
logger.debug("-------------------------------------------")
|
|
91
117
|
|
|
92
118
|
def _log_response(self, response: httpx.Response) -> None:
|
|
93
|
-
"""
|
|
119
|
+
"""
|
|
120
|
+
Logs HTTP responses if debug mode is enabled.
|
|
94
121
|
|
|
95
122
|
Args:
|
|
96
123
|
response (httpx.Response): The HTTP response object.
|
|
@@ -103,7 +130,8 @@ class ApiClient:
|
|
|
103
130
|
logger.debug("--------------------------------------------")
|
|
104
131
|
|
|
105
132
|
def _get_error_message(self, response: httpx.Response) -> str:
|
|
106
|
-
"""
|
|
133
|
+
"""
|
|
134
|
+
Extracts the error message from an HTTP response.
|
|
107
135
|
|
|
108
136
|
Args:
|
|
109
137
|
response (httpx.Response): The HTTP response object.
|
|
@@ -125,7 +153,8 @@ class ApiClient:
|
|
|
125
153
|
params: dict[str, t.Any] | None = None,
|
|
126
154
|
json_data: dict[str, t.Any] | None = None,
|
|
127
155
|
) -> httpx.Response:
|
|
128
|
-
"""
|
|
156
|
+
"""
|
|
157
|
+
Makes a raw HTTP request to the API.
|
|
129
158
|
|
|
130
159
|
Args:
|
|
131
160
|
method (str): The HTTP method (e.g., "GET", "POST").
|
|
@@ -146,7 +175,8 @@ class ApiClient:
|
|
|
146
175
|
params: dict[str, t.Any] | None = None,
|
|
147
176
|
json_data: dict[str, t.Any] | None = None,
|
|
148
177
|
) -> httpx.Response:
|
|
149
|
-
"""
|
|
178
|
+
"""
|
|
179
|
+
Makes an HTTP request to the API and raises exceptions for errors.
|
|
150
180
|
|
|
151
181
|
Args:
|
|
152
182
|
method (str): The HTTP method (e.g., "GET", "POST").
|
|
@@ -170,6 +200,59 @@ class ApiClient:
|
|
|
170
200
|
|
|
171
201
|
return response
|
|
172
202
|
|
|
203
|
+
# Auth
|
|
204
|
+
|
|
205
|
+
def url_for_user_code(self, user_code: str) -> str:
    """
    Build the browser URL where a user confirms a device-flow user code.

    Args:
        user_code: The code to embed as the `code` query parameter.

    Returns:
        `<site root>/account/device?code=<user_code>`, where the site root is
        this client's base URL with a trailing `/api` removed.
    """
    site_root = self._base_url.removesuffix("/api")
    return f"{site_root}/account/device?code={user_code}"
|
|
209
|
+
|
|
210
|
+
def get_device_codes(self) -> DeviceCodeResponse:
    """
    Begin the device authorization flow.

    Returns:
        The device code / user code pair issued by `POST /auth/device/code`.
    """
    payload = self.request("POST", "/auth/device/code").json()
    return DeviceCodeResponse(**payload)
|
|
215
|
+
|
|
216
|
+
def poll_for_token(
    self,
    device_code: str,
    interval: int = DEFAULT_POLL_INTERVAL,
    max_poll_time: int = DEFAULT_MAX_POLL_TIME,
) -> AccessRefreshTokenResponse:
    """
    Poll for the access token with the given device code.

    Repeatedly calls `POST /auth/device/token` until one of three things
    happens: the server returns HTTP 200, the server returns an unexpected
    status, or `max_poll_time` seconds have elapsed.

    Args:
        device_code: The device code to exchange (e.g. the `device_code`
            field of the `get_device_codes()` response).
        interval: Seconds to wait between polls.
        max_poll_time: Maximum number of seconds to keep polling.

    Returns:
        The access/refresh token pair from the successful response.

    Raises:
        RuntimeError: If the server answers with a status other than 200 or
            401, or if the polling window expires first.
    """
    # A monotonic clock is immune to wall-clock adjustments (NTP, manual
    # changes) that could otherwise cut the polling window short or stretch it.
    deadline = time.monotonic() + max_poll_time

    while time.monotonic() < deadline:
        response = self._request(
            "POST", "/auth/device/token", json_data={"device_code": device_code}
        )

        if response.status_code == 200:  # noqa: PLR2004
            return AccessRefreshTokenResponse(**response.json())
        # 401 is treated as "not authorized yet" and polling continues;
        # any other status is fatal.
        if response.status_code != 401:  # noqa: PLR2004
            raise RuntimeError(self._get_error_message(response))

        # Never sleep past the deadline: a slow request may already have
        # consumed part of the window. time.sleep rejects negative values.
        time.sleep(max(0.0, min(interval, deadline - time.monotonic())))

    raise RuntimeError("Polling for token timed out")
|
|
238
|
+
|
|
239
|
+
# User
|
|
240
|
+
|
|
241
|
+
def get_user(self) -> UserResponse:
    """
    Fetch the record of the currently authenticated user.

    Returns:
        The user data returned by `GET /user`.
    """
    payload = self.request("GET", "/user").json()
    return UserResponse(**payload)
|
|
246
|
+
|
|
247
|
+
# Github
|
|
248
|
+
|
|
249
|
+
def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse:
    """
    Request a GitHub access token scoped to the given repositories.

    Args:
        repos: Repository identifiers to request access for.

    Returns:
        The token, its expiry time, and the repositories it covers.
    """
    body = {"repos": repos}
    payload = self.request("POST", "/github/token", json_data=body).json()
    return GithubTokenResponse(**payload)
|
|
253
|
+
|
|
254
|
+
# Strikes
|
|
255
|
+
|
|
173
256
|
def list_projects(self) -> list[Project]:
|
|
174
257
|
"""Retrieves a list of projects.
|
|
175
258
|
|
|
@@ -294,7 +377,8 @@ class ApiClient:
|
|
|
294
377
|
status: StatusFilter = "completed",
|
|
295
378
|
aggregations: list[MetricAggregationType] | None = None,
|
|
296
379
|
) -> pd.DataFrame:
|
|
297
|
-
"""
|
|
380
|
+
"""
|
|
381
|
+
Exports run data for a specific project.
|
|
298
382
|
|
|
299
383
|
Args:
|
|
300
384
|
project: The project identifier.
|
|
@@ -327,7 +411,8 @@ class ApiClient:
|
|
|
327
411
|
metrics: list[str] | None = None,
|
|
328
412
|
aggregations: list[MetricAggregationType] | None = None,
|
|
329
413
|
) -> pd.DataFrame:
|
|
330
|
-
"""
|
|
414
|
+
"""
|
|
415
|
+
Exports metric data for a specific project.
|
|
331
416
|
|
|
332
417
|
Args:
|
|
333
418
|
project: The project identifier.
|
|
@@ -363,7 +448,8 @@ class ApiClient:
|
|
|
363
448
|
metrics: list[str] | None = None,
|
|
364
449
|
aggregations: list[MetricAggregationType] | None = None,
|
|
365
450
|
) -> pd.DataFrame:
|
|
366
|
-
"""
|
|
451
|
+
"""
|
|
452
|
+
Exports parameter data for a specific project.
|
|
367
453
|
|
|
368
454
|
Args:
|
|
369
455
|
project: The project identifier.
|
|
@@ -401,7 +487,8 @@ class ApiClient:
|
|
|
401
487
|
time_axis: TimeAxisType = "relative",
|
|
402
488
|
aggregations: list[TimeAggregationType] | None = None,
|
|
403
489
|
) -> pd.DataFrame:
|
|
404
|
-
"""
|
|
490
|
+
"""
|
|
491
|
+
Exports timeseries data for a specific project.
|
|
405
492
|
|
|
406
493
|
Args:
|
|
407
494
|
project: The project identifier.
|
|
@@ -32,6 +32,35 @@ class UserResponse(BaseModel):
|
|
|
32
32
|
api_key: UserAPIKey
|
|
33
33
|
|
|
34
34
|
|
|
35
|
+
class UserDataCredentials(BaseModel):
    """
    Credentials and location details for the user's data storage.

    NOTE(review): the field set (key id, secret, session token, region,
    bucket, prefix, optional endpoint) looks like temporary S3-compatible
    credentials; confirm against the server API.
    """

    # Field order is kept as declared: pydantic uses declaration order for
    # validation and serialization output.
    access_key_id: str
    secret_access_key: str
    session_token: str
    expiration: datetime  # presumably when the credentials stop working — TODO confirm
    region: str
    bucket: str
    prefix: str
    endpoint: str | None  # may be None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Auth
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class DeviceCodeResponse(BaseModel):
    """
    Response body of `POST /auth/device/code`, which starts the device login flow.

    `device_code` is the value intended for `ApiClient.poll_for_token`, and
    `user_code` the value intended for `ApiClient.url_for_user_code`.
    """

    # Field order is kept as declared: pydantic uses declaration order for
    # validation and serialization output.
    id: UUID
    completed: bool
    device_code: str
    expires_at: datetime
    expires_in: int  # presumably seconds until expiry — TODO confirm
    user_code: str
    verification_url: str
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class AccessRefreshTokenResponse(BaseModel):
    """Token pair returned with HTTP 200 by `POST /auth/device/token`."""

    # Field order is kept as declared (pydantic serialization order).
    access_token: str
    refresh_token: str
|
|
62
|
+
|
|
63
|
+
|
|
35
64
|
# Strikes
|
|
36
65
|
|
|
37
66
|
SpanStatus = t.Literal[
|
|
@@ -406,15 +435,10 @@ class TraceTree(BaseModel):
|
|
|
406
435
|
"""Children of this span, representing nested spans or tasks."""
|
|
407
436
|
|
|
408
437
|
|
|
409
|
-
#
|
|
438
|
+
# Github
|
|
410
439
|
|
|
411
440
|
|
|
412
|
-
class
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
expiration: datetime
|
|
417
|
-
region: str
|
|
418
|
-
bucket: str
|
|
419
|
-
prefix: str
|
|
420
|
-
endpoint: str | None
|
|
441
|
+
class GithubTokenResponse(BaseModel):
    """Response of `POST /github/token`: a GitHub access token for specific repositories."""

    # Field order is kept as declared (pydantic serialization order).
    token: str
    expires_at: datetime
    repos: list[str]  # repositories the token was issued for
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import atexit
|
|
2
|
+
import base64
|
|
3
|
+
import json
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
|
|
6
|
+
from dreadnode.api.client import ApiClient
|
|
7
|
+
from dreadnode.config import UserConfig
|
|
8
|
+
from dreadnode.constants import (
|
|
9
|
+
DEFAULT_TOKEN_MAX_TTL,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Token:
    """
    A JWT together with the expiration time read from its payload.

    Only the payload's `exp` claim is decoded. The signature is NOT verified,
    so this class must not be used to decide whether a token is trustworthy.
    """

    data: str  # the raw, encoded token string
    expires_at: datetime  # timezone-aware UTC time taken from the `exp` claim

    @staticmethod
    def parse_jwt_token_expiration(token: str) -> datetime:
        """
        Return the expiration time encoded in a JWT's `exp` claim.

        Args:
            token: A compact-serialized JWT (`header.payload.signature`).

        Returns:
            The `exp` claim as a timezone-aware UTC datetime.

        Raises:
            ValueError: If any of the following holds:
                - the token does not have exactly three dot-separated segments;
                - the payload is not base64url-encoded JSON;
                - the payload is not a JSON object;
                - the payload lacks a numeric `exp` claim.
        """
        try:
            _, b64payload, _ = token.split(".")
        except ValueError:
            raise ValueError("Malformed JWT: expected three dot-separated segments") from None

        # JWT segments omit base64 padding; restore exactly what is required.
        padded = b64payload + "=" * (-len(b64payload) % 4)
        # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all
        # ValueError subclasses, so decoding failures surface as ValueError.
        claims = json.loads(base64.urlsafe_b64decode(padded).decode("utf-8"))

        exp = claims.get("exp") if isinstance(claims, dict) else None
        if isinstance(exp, bool) or not isinstance(exp, (int, float)):
            raise ValueError("JWT payload has no numeric 'exp' claim")

        return datetime.fromtimestamp(exp, tz=timezone.utc)

    def __init__(self, token: str):
        """
        Wrap a raw JWT string and parse its expiration.

        Raises:
            ValueError: If the token's expiration cannot be parsed.
        """
        self.data = token
        self.expires_at = Token.parse_jwt_token_expiration(token)

    def ttl(self) -> int:
        """Return the whole number of seconds until expiry (negative once expired)."""
        return int((self.expires_at - datetime.now(tz=timezone.utc)).total_seconds())

    def is_expired(self) -> bool:
        """Return True if no whole second of validity remains."""
        return self.ttl() <= 0

    def is_close_to_expiry(self) -> bool:
        """Return True if the remaining lifetime is at most DEFAULT_TOKEN_MAX_TTL seconds."""
        return self.ttl() <= DEFAULT_TOKEN_MAX_TTL
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def create_api_client(*, profile: str | None = None) -> ApiClient:
    """
    Build an ApiClient authenticated with the tokens stored for a profile.

    The stored access and refresh tokens are attached to the client as
    cookies. An exit hook writes back to the user config any of those
    cookies whose value differs from the stored one when the process ends.

    Args:
        profile: Server profile to load; `None` lets
            `UserConfig.get_server_config` choose.

    Returns:
        A configured `ApiClient`.

    Raises:
        RuntimeError: If the stored refresh token has already expired.
    """
    user_config = UserConfig.read()
    server = user_config.get_server_config(profile)

    client = ApiClient(
        server.url,
        cookies={
            "access_token": server.access_token,
            "refresh_token": server.refresh_token,
        },
    )

    # Fail fast here rather than on the first rejected request.
    refresh = Token(server.refresh_token)
    if refresh.is_expired():
        raise RuntimeError("Authentication expired, use [bold]dreadnode login[/]")

    def _flush_auth_changes() -> None:
        """Persist tokens from the client's cookie jar that differ from the stored ones."""
        jar = client._client.cookies  # noqa: SLF001
        current = {name: jar.get(name) for name in ("access_token", "refresh_token")}

        updated = {
            name: value
            for name, value in current.items()
            if value and value != getattr(server, name)
        }
        if not updated:
            return

        for name, value in updated.items():
            setattr(server, name, value)
        user_config.set_server_config(server, profile).write()

    atexit.register(_flush_auth_changes)
    return client
|