dreadnode 1.12.2__tar.gz → 1.13.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dreadnode-1.12.2 → dreadnode-1.13.1}/PKG-INFO +2 -1
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/__init__.py +7 -2
- dreadnode-1.13.1/dreadnode/__main__.py +10 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/api/client.py +124 -28
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/api/models.py +34 -10
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/api/util.py +1 -1
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/artifact/storage.py +15 -1
- dreadnode-1.13.1/dreadnode/cli/__init__.py +3 -0
- dreadnode-1.13.1/dreadnode/cli/api.py +79 -0
- dreadnode-1.13.1/dreadnode/cli/github.py +273 -0
- dreadnode-1.13.1/dreadnode/cli/main.py +200 -0
- dreadnode-1.13.1/dreadnode/cli/profile/__init__.py +3 -0
- dreadnode-1.13.1/dreadnode/cli/profile/cli.py +101 -0
- dreadnode-1.13.1/dreadnode/config.py +108 -0
- dreadnode-1.13.1/dreadnode/constants.py +62 -0
- dreadnode-1.13.1/dreadnode/data_types/__init__.py +9 -0
- dreadnode-1.13.1/dreadnode/lookup.py +146 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/main.py +153 -114
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/object.py +1 -2
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/__init__.py +27 -3
- dreadnode-1.13.1/dreadnode/scorers/classification.py +102 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/consistency.py +16 -18
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/contains.py +50 -31
- dreadnode-1.13.1/dreadnode/scorers/format.py +61 -0
- dreadnode-1.13.1/dreadnode/scorers/harm.py +55 -0
- dreadnode-1.12.2/dreadnode/scorers/llm_judge.py → dreadnode-1.13.1/dreadnode/scorers/judge.py +18 -20
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/length.py +28 -23
- dreadnode-1.13.1/dreadnode/scorers/lexical.py +67 -0
- dreadnode-1.13.1/dreadnode/scorers/operators.py +122 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/pii.py +1 -2
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/readability.py +6 -2
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/sentiment.py +14 -2
- dreadnode-1.13.1/dreadnode/scorers/similarity.py +293 -0
- dreadnode-1.13.1/dreadnode/scorers/util.py +33 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/serialization.py +1 -1
- dreadnode-1.13.1/dreadnode/storage_utils.py +37 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/task.py +0 -84
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/tracing/span.py +19 -8
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/util.py +119 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/pyproject.toml +13 -2
- dreadnode-1.12.2/dreadnode/constants.py +0 -16
- dreadnode-1.12.2/dreadnode/data_types/__init__.py +0 -9
- dreadnode-1.12.2/dreadnode/scorers/similarity.py +0 -180
- {dreadnode-1.12.2 → dreadnode-1.13.1}/README.md +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/api/__init__.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/artifact/__init__.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/artifact/merger.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/artifact/tree_builder.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/convert.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/audio.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/base.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/image.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/object_3d.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/table.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/text.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/data_types/video.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/integrations/__init__.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/integrations/transformers.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/metric.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/py.typed +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/scorers/rigging.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/tracing/__init__.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/tracing/constants.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/tracing/exporters.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/types.py +0 -0
- {dreadnode-1.12.2 → dreadnode-1.13.1}/dreadnode/version.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: dreadnode
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.13.1
|
|
4
4
|
Summary: Dreadnode SDK
|
|
5
5
|
Author: Nick Landers
|
|
6
6
|
Author-email: monoxgas@gmail.com
|
|
@@ -14,6 +14,7 @@ Provides-Extra: all
|
|
|
14
14
|
Provides-Extra: multimodal
|
|
15
15
|
Provides-Extra: training
|
|
16
16
|
Requires-Dist: coolname (>=2.2.0,<3.0.0)
|
|
17
|
+
Requires-Dist: cyclopts (>=3.22.2,<4.0.0)
|
|
17
18
|
Requires-Dist: fsspec[s3] (>=2023.1.0,<=2025.3.0)
|
|
18
19
|
Requires-Dist: httpx (>=0.28.0,<0.29.0)
|
|
19
20
|
Requires-Dist: logfire (>=3.5.3,<=3.20.0)
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from dreadnode import convert, data_types, scorers
|
|
2
2
|
from dreadnode.data_types import Audio, Code, Image, Markdown, Object3D, Table, Text, Video
|
|
3
|
+
from dreadnode.lookup import Lookup, lookup_input, lookup_output, lookup_param, resolve_lookup
|
|
3
4
|
from dreadnode.main import DEFAULT_INSTANCE, Dreadnode
|
|
4
5
|
from dreadnode.metric import Metric, MetricDict, Scorer
|
|
5
6
|
from dreadnode.object import Object
|
|
6
|
-
from dreadnode.task import Task
|
|
7
|
+
from dreadnode.task import Task
|
|
7
8
|
from dreadnode.tracing.span import RunSpan, Span, TaskSpan
|
|
8
9
|
from dreadnode.version import VERSION
|
|
9
10
|
|
|
@@ -39,6 +40,7 @@ __all__ = [
|
|
|
39
40
|
"Code",
|
|
40
41
|
"Dreadnode",
|
|
41
42
|
"Image",
|
|
43
|
+
"Lookup",
|
|
42
44
|
"Markdown",
|
|
43
45
|
"Metric",
|
|
44
46
|
"MetricDict",
|
|
@@ -50,7 +52,6 @@ __all__ = [
|
|
|
50
52
|
"Span",
|
|
51
53
|
"Table",
|
|
52
54
|
"Task",
|
|
53
|
-
"TaskInput",
|
|
54
55
|
"TaskSpan",
|
|
55
56
|
"Text",
|
|
56
57
|
"Video",
|
|
@@ -69,7 +70,11 @@ __all__ = [
|
|
|
69
70
|
"log_output",
|
|
70
71
|
"log_param",
|
|
71
72
|
"log_params",
|
|
73
|
+
"lookup_input",
|
|
74
|
+
"lookup_output",
|
|
75
|
+
"lookup_param",
|
|
72
76
|
"push_update",
|
|
77
|
+
"resolve_lookup",
|
|
73
78
|
"run",
|
|
74
79
|
"scorer",
|
|
75
80
|
"scorers",
|
|
@@ -1,22 +1,19 @@
|
|
|
1
1
|
import io
|
|
2
2
|
import json
|
|
3
|
+
import time
|
|
3
4
|
import typing as t
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from urllib.parse import urlparse
|
|
4
7
|
|
|
5
8
|
import httpx
|
|
6
9
|
import pandas as pd
|
|
7
10
|
from pydantic import BaseModel
|
|
8
11
|
from ulid import ULID
|
|
9
12
|
|
|
10
|
-
from dreadnode.api.
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
process_task,
|
|
15
|
-
)
|
|
16
|
-
from dreadnode.util import logger
|
|
17
|
-
from dreadnode.version import VERSION
|
|
18
|
-
|
|
19
|
-
from .models import (
|
|
13
|
+
from dreadnode.api.models import (
|
|
14
|
+
AccessRefreshTokenResponse,
|
|
15
|
+
DeviceCodeResponse,
|
|
16
|
+
GithubTokenResponse,
|
|
20
17
|
MetricAggregationType,
|
|
21
18
|
Project,
|
|
22
19
|
RawRun,
|
|
@@ -31,7 +28,21 @@ from .models import (
|
|
|
31
28
|
TraceSpan,
|
|
32
29
|
TraceTree,
|
|
33
30
|
UserDataCredentials,
|
|
31
|
+
UserResponse,
|
|
34
32
|
)
|
|
33
|
+
from dreadnode.api.util import (
|
|
34
|
+
convert_flat_tasks_to_tree,
|
|
35
|
+
convert_flat_trace_to_tree,
|
|
36
|
+
process_run,
|
|
37
|
+
process_task,
|
|
38
|
+
)
|
|
39
|
+
from dreadnode.constants import (
|
|
40
|
+
DEFAULT_FS_CREDENTIAL_DURATION,
|
|
41
|
+
DEFAULT_MAX_POLL_TIME,
|
|
42
|
+
DEFAULT_POLL_INTERVAL,
|
|
43
|
+
)
|
|
44
|
+
from dreadnode.util import logger
|
|
45
|
+
from dreadnode.version import VERSION
|
|
35
46
|
|
|
36
47
|
ModelT = t.TypeVar("ModelT", bound=BaseModel)
|
|
37
48
|
|
|
@@ -47,11 +58,13 @@ class ApiClient:
|
|
|
47
58
|
def __init__(
|
|
48
59
|
self,
|
|
49
60
|
base_url: str,
|
|
50
|
-
api_key: str,
|
|
51
61
|
*,
|
|
62
|
+
api_key: str | None = None,
|
|
63
|
+
cookies: dict[str, str] | None = None,
|
|
52
64
|
debug: bool = False,
|
|
53
65
|
):
|
|
54
|
-
"""
|
|
66
|
+
"""
|
|
67
|
+
Initializes the API client.
|
|
55
68
|
|
|
56
69
|
Args:
|
|
57
70
|
base_url (str): The base URL of the Dreadnode API.
|
|
@@ -62,12 +75,28 @@ class ApiClient:
|
|
|
62
75
|
if not self._base_url.endswith("/api"):
|
|
63
76
|
self._base_url += "/api"
|
|
64
77
|
|
|
78
|
+
_cookies = httpx.Cookies()
|
|
79
|
+
cookie_domain = urlparse(base_url).hostname
|
|
80
|
+
if cookie_domain is None:
|
|
81
|
+
raise ValueError(f"Invalid URL: {base_url}")
|
|
82
|
+
|
|
83
|
+
if cookie_domain == "localhost":
|
|
84
|
+
cookie_domain = "localhost.local"
|
|
85
|
+
|
|
86
|
+
for key, value in (cookies or {}).items():
|
|
87
|
+
_cookies.set(key, value, domain=cookie_domain)
|
|
88
|
+
|
|
89
|
+
headers = {
|
|
90
|
+
"User-Agent": f"dreadnode-sdk/{VERSION}",
|
|
91
|
+
"Accept": "application/json",
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
if api_key:
|
|
95
|
+
headers["X-Api-Key"] = api_key
|
|
96
|
+
|
|
65
97
|
self._client = httpx.Client(
|
|
66
|
-
headers=
|
|
67
|
-
|
|
68
|
-
"Accept": "application/json",
|
|
69
|
-
"X-API-Key": api_key,
|
|
70
|
-
},
|
|
98
|
+
headers=headers,
|
|
99
|
+
cookies=_cookies,
|
|
71
100
|
base_url=self._base_url,
|
|
72
101
|
timeout=30,
|
|
73
102
|
)
|
|
@@ -77,7 +106,8 @@ class ApiClient:
|
|
|
77
106
|
self._client.event_hooks["response"].append(self._log_response)
|
|
78
107
|
|
|
79
108
|
def _log_request(self, request: httpx.Request) -> None:
|
|
80
|
-
"""
|
|
109
|
+
"""
|
|
110
|
+
Logs HTTP requests if debug mode is enabled.
|
|
81
111
|
|
|
82
112
|
Args:
|
|
83
113
|
request (httpx.Request): The HTTP request object.
|
|
@@ -90,7 +120,8 @@ class ApiClient:
|
|
|
90
120
|
logger.debug("-------------------------------------------")
|
|
91
121
|
|
|
92
122
|
def _log_response(self, response: httpx.Response) -> None:
|
|
93
|
-
"""
|
|
123
|
+
"""
|
|
124
|
+
Logs HTTP responses if debug mode is enabled.
|
|
94
125
|
|
|
95
126
|
Args:
|
|
96
127
|
response (httpx.Response): The HTTP response object.
|
|
@@ -103,7 +134,8 @@ class ApiClient:
|
|
|
103
134
|
logger.debug("--------------------------------------------")
|
|
104
135
|
|
|
105
136
|
def _get_error_message(self, response: httpx.Response) -> str:
|
|
106
|
-
"""
|
|
137
|
+
"""
|
|
138
|
+
Extracts the error message from an HTTP response.
|
|
107
139
|
|
|
108
140
|
Args:
|
|
109
141
|
response (httpx.Response): The HTTP response object.
|
|
@@ -125,7 +157,8 @@ class ApiClient:
|
|
|
125
157
|
params: dict[str, t.Any] | None = None,
|
|
126
158
|
json_data: dict[str, t.Any] | None = None,
|
|
127
159
|
) -> httpx.Response:
|
|
128
|
-
"""
|
|
160
|
+
"""
|
|
161
|
+
Makes a raw HTTP request to the API.
|
|
129
162
|
|
|
130
163
|
Args:
|
|
131
164
|
method (str): The HTTP method (e.g., "GET", "POST").
|
|
@@ -146,7 +179,8 @@ class ApiClient:
|
|
|
146
179
|
params: dict[str, t.Any] | None = None,
|
|
147
180
|
json_data: dict[str, t.Any] | None = None,
|
|
148
181
|
) -> httpx.Response:
|
|
149
|
-
"""
|
|
182
|
+
"""
|
|
183
|
+
Makes an HTTP request to the API and raises exceptions for errors.
|
|
150
184
|
|
|
151
185
|
Args:
|
|
152
186
|
method (str): The HTTP method (e.g., "GET", "POST").
|
|
@@ -170,6 +204,59 @@ class ApiClient:
|
|
|
170
204
|
|
|
171
205
|
return response
|
|
172
206
|
|
|
207
|
+
# Auth
|
|
208
|
+
|
|
209
|
+
def url_for_user_code(self, user_code: str) -> str:
    """
    Build the browser URL where a user approves a device-login user code.

    Args:
        user_code: The user code to embed as the ``code`` query parameter.

    Returns:
        The site root (the API base URL with one trailing ``/api`` removed)
        followed by ``/account/device?code=<user_code>``.
    """
    site_root = self._base_url.removesuffix("/api")
    return f"{site_root}/account/device?code={user_code}"
|
|
213
|
+
|
|
214
|
+
def get_device_codes(self) -> DeviceCodeResponse:
    """
    Begin the device-authorization login flow.

    Sends ``POST /auth/device/code`` through the error-checking ``request``
    helper and parses the body into a ``DeviceCodeResponse`` (user and device codes).
    """
    payload = self.request("POST", "/auth/device/code").json()
    return DeviceCodeResponse(**payload)
|
|
219
|
+
|
|
220
|
+
def poll_for_token(
    self,
    device_code: str,
    interval: int = DEFAULT_POLL_INTERVAL,
    max_poll_time: int = DEFAULT_MAX_POLL_TIME,
) -> AccessRefreshTokenResponse:
    """
    Poll the device-token endpoint until the device code is approved.

    Args:
        device_code: The device code obtained from ``get_device_codes``.
        interval: Seconds to sleep between polling attempts.
        max_poll_time: Total seconds to keep polling before giving up.

    Returns:
        The access/refresh token pair returned once the server answers 200.

    Raises:
        RuntimeError: If the server answers with a status other than 200 or 401,
            or if ``max_poll_time`` seconds elapse without a token.
    """
    # Measure the polling window on a monotonic clock: wall-clock time can jump
    # (NTP step, manual change) and silently stretch or shorten the timeout.
    deadline = time.monotonic() + max_poll_time
    while time.monotonic() < deadline:
        # Raw `_request` (not `request`) so non-2xx statuses can be inspected here
        # instead of being raised.
        response = self._request(
            "POST", "/auth/device/token", json_data={"device_code": device_code}
        )

        if response.status_code == 200:  # noqa: PLR2004
            return AccessRefreshTokenResponse(**response.json())
        # 401 means "keep polling"; any other status is a hard failure.
        if response.status_code != 401:  # noqa: PLR2004
            raise RuntimeError(self._get_error_message(response))

        time.sleep(interval)

    raise RuntimeError("Polling for token timed out")
|
|
242
|
+
|
|
243
|
+
# User
|
|
244
|
+
|
|
245
|
+
def get_user(self) -> UserResponse:
    """Fetch the authenticated user's profile via ``GET /user``."""
    payload = self.request("GET", "/user").json()
    return UserResponse(**payload)
|
|
250
|
+
|
|
251
|
+
# Github
|
|
252
|
+
|
|
253
|
+
def get_github_access_token(self, repos: list[str]) -> GithubTokenResponse:
    """
    Request a GitHub access token for the given repositories.

    Args:
        repos: Repository names sent as the ``repos`` field of the request body.

    Returns:
        The parsed ``POST /github/token`` response.
    """
    body = {"repos": repos}
    payload = self.request("POST", "/github/token", json_data=body).json()
    return GithubTokenResponse(**payload)
|
|
257
|
+
|
|
258
|
+
# Strikes
|
|
259
|
+
|
|
173
260
|
def list_projects(self) -> list[Project]:
|
|
174
261
|
"""Retrieves a list of projects.
|
|
175
262
|
|
|
@@ -294,7 +381,8 @@ class ApiClient:
|
|
|
294
381
|
status: StatusFilter = "completed",
|
|
295
382
|
aggregations: list[MetricAggregationType] | None = None,
|
|
296
383
|
) -> pd.DataFrame:
|
|
297
|
-
"""
|
|
384
|
+
"""
|
|
385
|
+
Exports run data for a specific project.
|
|
298
386
|
|
|
299
387
|
Args:
|
|
300
388
|
project: The project identifier.
|
|
@@ -327,7 +415,8 @@ class ApiClient:
|
|
|
327
415
|
metrics: list[str] | None = None,
|
|
328
416
|
aggregations: list[MetricAggregationType] | None = None,
|
|
329
417
|
) -> pd.DataFrame:
|
|
330
|
-
"""
|
|
418
|
+
"""
|
|
419
|
+
Exports metric data for a specific project.
|
|
331
420
|
|
|
332
421
|
Args:
|
|
333
422
|
project: The project identifier.
|
|
@@ -363,7 +452,8 @@ class ApiClient:
|
|
|
363
452
|
metrics: list[str] | None = None,
|
|
364
453
|
aggregations: list[MetricAggregationType] | None = None,
|
|
365
454
|
) -> pd.DataFrame:
|
|
366
|
-
"""
|
|
455
|
+
"""
|
|
456
|
+
Exports parameter data for a specific project.
|
|
367
457
|
|
|
368
458
|
Args:
|
|
369
459
|
project: The project identifier.
|
|
@@ -401,7 +491,8 @@ class ApiClient:
|
|
|
401
491
|
time_axis: TimeAxisType = "relative",
|
|
402
492
|
aggregations: list[TimeAggregationType] | None = None,
|
|
403
493
|
) -> pd.DataFrame:
|
|
404
|
-
"""
|
|
494
|
+
"""
|
|
495
|
+
Exports timeseries data for a specific project.
|
|
405
496
|
|
|
406
497
|
Args:
|
|
407
498
|
project: The project identifier.
|
|
@@ -430,12 +521,17 @@ class ApiClient:
|
|
|
430
521
|
|
|
431
522
|
# User data access
|
|
432
523
|
|
|
433
|
-
def get_user_data_credentials(
|
|
524
|
+
def get_user_data_credentials(
    self, duration: int = DEFAULT_FS_CREDENTIAL_DURATION
) -> UserDataCredentials:
    """
    Retrieves user data credentials for secondary storage access.

    Args:
        duration: Requested credential lifetime in seconds, sent as the
            ``duration`` query parameter (defaults to ``DEFAULT_FS_CREDENTIAL_DURATION``).

    Returns:
        The user data credentials object.

    Raises:
        Whatever ``request`` raises for an error response from the API.
    """
    # Use the error-checking `request` rather than the raw `_request`: with the raw
    # call an HTTP error body would be fed into UserDataCredentials(...) and surface
    # as a confusing validation failure instead of the API's own error.
    response = self.request("GET", "/user-data/credentials", params={"duration": duration})
    return UserDataCredentials(**response.json())
|
|
@@ -32,6 +32,35 @@ class UserResponse(BaseModel):
|
|
|
32
32
|
api_key: UserAPIKey
|
|
33
33
|
|
|
34
34
|
|
|
35
|
+
class UserDataCredentials(BaseModel):
    """
    Storage credentials returned by ``GET /user-data/credentials``
    (parsed in ``ApiClient.get_user_data_credentials``).
    """

    # NOTE(review): the field names suggest temporary S3-style credentials (the
    # package depends on fsspec[s3]) — confirm against the server schema.
    access_key_id: str
    secret_access_key: str
    session_token: str
    expiration: datetime  # presumably when these credentials stop working — verify
    region: str
    bucket: str
    prefix: str  # presumably the key prefix the user's data lives under — verify
    endpoint: str | None  # annotated as possibly None; meaning of None not visible here — TODO confirm
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Auth
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class DeviceCodeResponse(BaseModel):
    """
    Response to ``POST /auth/device/code``, the first step of the device-login
    flow (parsed in ``ApiClient.get_device_codes``).
    """

    id: UUID
    completed: bool
    device_code: str  # the value ``ApiClient.poll_for_token`` sends to obtain tokens
    expires_at: datetime
    expires_in: int  # presumably seconds until the codes expire — confirm with the API
    user_code: str  # embedded in the approval URL built by ``ApiClient.url_for_user_code``
    verification_url: str
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class AccessRefreshTokenResponse(BaseModel):
    """
    Token pair returned by ``POST /auth/device/token`` when it answers 200
    (parsed in ``ApiClient.poll_for_token``).
    """

    access_token: str
    refresh_token: str
|
|
62
|
+
|
|
63
|
+
|
|
35
64
|
# Strikes
|
|
36
65
|
|
|
37
66
|
SpanStatus = t.Literal[
|
|
@@ -406,15 +435,10 @@ class TraceTree(BaseModel):
|
|
|
406
435
|
"""Children of this span, representing nested spans or tasks."""
|
|
407
436
|
|
|
408
437
|
|
|
409
|
-
#
|
|
438
|
+
# Github
|
|
410
439
|
|
|
411
440
|
|
|
412
|
-
class
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
expiration: datetime
|
|
417
|
-
region: str
|
|
418
|
-
bucket: str
|
|
419
|
-
prefix: str
|
|
420
|
-
endpoint: str | None
|
|
441
|
+
class GithubTokenResponse(BaseModel):
    """
    Response to ``POST /github/token`` (parsed in
    ``ApiClient.get_github_access_token``).
    """

    token: str
    expires_at: datetime
    repos: list[str]  # presumably the repositories the token was granted for — confirm
|
|
@@ -4,10 +4,12 @@ Provides efficient uploading of files and directories with deduplication.
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
import hashlib
|
|
7
|
+
import typing as t
|
|
7
8
|
from pathlib import Path
|
|
8
9
|
|
|
9
10
|
import fsspec # type: ignore[import-untyped]
|
|
10
11
|
|
|
12
|
+
from dreadnode.storage_utils import with_credential_refresh
|
|
11
13
|
from dreadnode.util import logger
|
|
12
14
|
|
|
13
15
|
CHUNK_SIZE = 8 * 1024 * 1024 # 8MB
|
|
@@ -22,15 +24,27 @@ class ArtifactStorage:
|
|
|
22
24
|
- Batch uploads for directories handled by fsspec
|
|
23
25
|
"""
|
|
24
26
|
|
|
25
|
-
def __init__(
|
|
27
|
+
def __init__(
    self,
    file_system: fsspec.AbstractFileSystem,
    credential_refresher: t.Callable[[], bool] | None = None,
):
    """
    Set up artifact storage on top of an fsspec file system.

    Args:
        file_system: FSSpec-compatible file system used for storage operations.
        credential_refresher: Optional zero-argument callable that refreshes
            credentials when they are about to expire; ``None`` disables refreshing.
    """
    self._credential_refresher = credential_refresher
    self._file_system = file_system
|
|
33
41
|
|
|
42
|
+
def _refresh_credentials_if_needed(self) -> None:
    """
    Invoke the configured credential refresher, if there is one.

    The refresher's return value is discarded; with no (or a falsy)
    refresher configured this does nothing.
    """
    refresher = self._credential_refresher
    if not refresher:
        return
    refresher()
|
|
46
|
+
|
|
47
|
+
@with_credential_refresh
|
|
34
48
|
def store_file(self, file_path: Path, target_key: str) -> str:
|
|
35
49
|
"""
|
|
36
50
|
Store a file in the storage system, using multipart upload for large files.
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import atexit
|
|
2
|
+
import base64
|
|
3
|
+
import json
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
|
|
6
|
+
from dreadnode.api.client import ApiClient
|
|
7
|
+
from dreadnode.config import UserConfig
|
|
8
|
+
from dreadnode.constants import (
|
|
9
|
+
DEFAULT_TOKEN_MAX_TTL,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Token:
    """A JWT token together with the expiration time taken from its ``exp`` claim."""

    data: str
    expires_at: datetime

    @staticmethod
    def parse_jwt_token_expiration(token: str) -> datetime:
        """
        Return the expiration time encoded in a JWT's ``exp`` claim.

        The signature is not verified; only the payload segment is decoded.

        Args:
            token: A compact JWT of the form ``header.payload.signature``.

        Returns:
            The ``exp`` claim as a timezone-aware UTC ``datetime``.

        Raises:
            ValueError: If the token does not have exactly three dot-separated
                segments, the payload is not base64url-encoded JSON object, or
                the ``exp`` claim is missing or not numeric.
        """
        try:
            _, b64payload, _ = token.split(".")
        except ValueError:
            raise ValueError("Malformed JWT: expected 3 dot-separated segments") from None

        # JWTs strip base64 padding; restore exactly the amount the decoder needs.
        padded = b64payload + "=" * (-len(b64payload) % 4)
        # binascii.Error, UnicodeDecodeError and JSONDecodeError all subclass ValueError.
        claims = json.loads(base64.urlsafe_b64decode(padded).decode("utf-8"))

        exp = claims.get("exp") if isinstance(claims, dict) else None
        # bool is an int subclass but is never a meaningful timestamp.
        if isinstance(exp, bool) or not isinstance(exp, (int, float)):
            raise ValueError("JWT payload has no numeric 'exp' claim")
        return datetime.fromtimestamp(exp, tz=timezone.utc)

    def __init__(self, token: str):
        self.data = token
        self.expires_at = Token.parse_jwt_token_expiration(token)

    def ttl(self) -> int:
        """Get number of whole seconds left until the token expires (negative once expired)."""
        return int((self.expires_at - datetime.now(tz=timezone.utc)).total_seconds())

    def is_expired(self) -> bool:
        """Return True if the token is expired."""
        return self.ttl() <= 0

    def is_close_to_expiry(self) -> bool:
        """Return True if the token expires within ``DEFAULT_TOKEN_MAX_TTL`` seconds."""
        return self.ttl() <= DEFAULT_TOKEN_MAX_TTL
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def create_api_client(*, profile: str | None = None) -> ApiClient:
    """
    Create an authenticated API client using stored configuration data.

    Args:
        profile: Server profile name, passed unchanged to
            ``UserConfig.get_server_config``.

    Returns:
        An ``ApiClient`` whose cookie jar is seeded with the stored access and
        refresh tokens.

    Raises:
        RuntimeError: If the stored refresh token has already expired.

    Side effects:
        Registers an ``atexit`` hook that writes any changed access/refresh
        token cookies back to the user configuration.
    """
    user_config = UserConfig.read()
    config = user_config.get_server_config(profile)

    # Check expiry before building the client so we don't construct an HTTP
    # client only to abandon it when authentication has lapsed.
    if Token(config.refresh_token).is_expired():
        raise RuntimeError("Authentication expired, use [bold]dreadnode login[/]")

    client = ApiClient(
        config.url,
        cookies={"access_token": config.access_token, "refresh_token": config.refresh_token},
    )

    def _flush_auth_changes() -> None:
        """Flush the authentication data to disk if it has been updated."""

        # Read the tokens back from the client's cookie jar, which may have been
        # updated by the server during this process's lifetime.
        access_token = client._client.cookies.get("access_token")  # noqa: SLF001
        refresh_token = client._client.cookies.get("refresh_token")  # noqa: SLF001

        changed: bool = False
        if access_token and access_token != config.access_token:
            changed = True
            config.access_token = access_token

        if refresh_token and refresh_token != config.refresh_token:
            changed = True
            config.refresh_token = refresh_token

        if changed:
            user_config.set_server_config(config, profile).write()

    atexit.register(_flush_auth_changes)

    return client
|