dreadnode 1.13.0__tar.gz → 1.13.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dreadnode-1.13.0 → dreadnode-1.13.1}/PKG-INFO +1 -1
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/api/client.py +12 -3
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/artifact/storage.py +15 -1
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/constants.py +5 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/main.py +68 -14
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/serialization.py +1 -1
- dreadnode-1.13.1/dreadnode/storage_utils.py +37 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/tracing/span.py +13 -1
- {dreadnode-1.13.0 → dreadnode-1.13.1}/pyproject.toml +1 -1
- {dreadnode-1.13.0 → dreadnode-1.13.1}/README.md +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/__init__.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/__main__.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/api/__init__.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/api/models.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/api/util.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/artifact/__init__.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/artifact/merger.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/artifact/tree_builder.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/cli/__init__.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/cli/api.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/cli/github.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/cli/main.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/cli/profile/__init__.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/cli/profile/cli.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/config.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/convert.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/data_types/__init__.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/data_types/audio.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/data_types/base.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/data_types/image.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/data_types/object_3d.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/data_types/table.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/data_types/text.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/data_types/video.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/integrations/__init__.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/integrations/transformers.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/lookup.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/metric.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/object.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/py.typed +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/__init__.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/classification.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/consistency.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/contains.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/format.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/harm.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/judge.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/length.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/lexical.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/operators.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/pii.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/readability.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/rigging.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/sentiment.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/similarity.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/scorers/util.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/task.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/tracing/__init__.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/tracing/constants.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/tracing/exporters.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/types.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/util.py +0 -0
- {dreadnode-1.13.0 → dreadnode-1.13.1}/dreadnode/version.py +0 -0
|
@@ -36,7 +36,11 @@ from dreadnode.api.util import (
|
|
|
36
36
|
process_run,
|
|
37
37
|
process_task,
|
|
38
38
|
)
|
|
39
|
-
from dreadnode.constants import
|
|
39
|
+
from dreadnode.constants import (
|
|
40
|
+
DEFAULT_FS_CREDENTIAL_DURATION,
|
|
41
|
+
DEFAULT_MAX_POLL_TIME,
|
|
42
|
+
DEFAULT_POLL_INTERVAL,
|
|
43
|
+
)
|
|
40
44
|
from dreadnode.util import logger
|
|
41
45
|
from dreadnode.version import VERSION
|
|
42
46
|
|
|
@@ -517,12 +521,17 @@ class ApiClient:
|
|
|
517
521
|
|
|
518
522
|
# User data access
|
|
519
523
|
|
|
520
|
-
def get_user_data_credentials(
|
|
524
|
+
def get_user_data_credentials(
|
|
525
|
+
self, duration: int = DEFAULT_FS_CREDENTIAL_DURATION
|
|
526
|
+
) -> UserDataCredentials:
|
|
521
527
|
"""
|
|
522
528
|
Retrieves user data credentials for secondary storage access.
|
|
523
529
|
|
|
530
|
+
Args:
|
|
531
|
+
duration: Credential lifetime in seconds (default: 4 hours)
|
|
532
|
+
|
|
524
533
|
Returns:
|
|
525
534
|
The user data credentials object.
|
|
526
535
|
"""
|
|
527
|
-
response = self.
|
|
536
|
+
response = self._request("GET", "/user-data/credentials", params={"duration": duration})
|
|
528
537
|
return UserDataCredentials(**response.json())
|
|
@@ -4,10 +4,12 @@ Provides efficient uploading of files and directories with deduplication.
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
import hashlib
|
|
7
|
+
import typing as t
|
|
7
8
|
from pathlib import Path
|
|
8
9
|
|
|
9
10
|
import fsspec # type: ignore[import-untyped]
|
|
10
11
|
|
|
12
|
+
from dreadnode.storage_utils import with_credential_refresh
|
|
11
13
|
from dreadnode.util import logger
|
|
12
14
|
|
|
13
15
|
CHUNK_SIZE = 8 * 1024 * 1024 # 8MB
|
|
@@ -22,15 +24,27 @@ class ArtifactStorage:
|
|
|
22
24
|
- Batch uploads for directories handled by fsspec
|
|
23
25
|
"""
|
|
24
26
|
|
|
25
|
-
def __init__(
|
|
27
|
+
def __init__(
|
|
28
|
+
self,
|
|
29
|
+
file_system: fsspec.AbstractFileSystem,
|
|
30
|
+
credential_refresher: t.Callable[[], bool] | None = None,
|
|
31
|
+
):
|
|
26
32
|
"""
|
|
27
33
|
Initialize artifact storage with a file system and prefix path.
|
|
28
34
|
|
|
29
35
|
Args:
|
|
30
36
|
file_system: FSSpec-compatible file system
|
|
37
|
+
credential_refresher: Optional function to refresh credentials when it's about to expire
|
|
31
38
|
"""
|
|
32
39
|
self._file_system = file_system
|
|
40
|
+
self._credential_refresher = credential_refresher
|
|
33
41
|
|
|
42
|
+
def _refresh_credentials_if_needed(self) -> None:
|
|
43
|
+
"""Refresh credentials if refresher is available."""
|
|
44
|
+
if self._credential_refresher:
|
|
45
|
+
self._credential_refresher()
|
|
46
|
+
|
|
47
|
+
@with_credential_refresh
|
|
34
48
|
def store_file(self, file_path: Path, target_key: str) -> str:
|
|
35
49
|
"""
|
|
36
50
|
Store a file in the storage system, using multipart upload for large files.
|
|
@@ -39,6 +39,7 @@ ENV_API_KEY = "DREADNODE_API_KEY" # pragma: allowlist secret (alternative to AP
|
|
|
39
39
|
ENV_LOCAL_DIR = "DREADNODE_LOCAL_DIR"
|
|
40
40
|
ENV_PROJECT = "DREADNODE_PROJECT"
|
|
41
41
|
ENV_PROFILE = "DREADNODE_PROFILE"
|
|
42
|
+
ENV_CONSOLE = "DREADNODE_CONSOLE"
|
|
42
43
|
|
|
43
44
|
#
|
|
44
45
|
# Environment
|
|
@@ -55,3 +56,7 @@ USER_CONFIG_PATH = pathlib.Path(
|
|
|
55
56
|
# allow overriding the user config file via env variable
|
|
56
57
|
os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config"
|
|
57
58
|
)
|
|
59
|
+
|
|
60
|
+
# Default values for the file system credential management
|
|
61
|
+
DEFAULT_FS_CREDENTIAL_DURATION = 14400 # 4 hours in seconds
|
|
62
|
+
FS_CREDENTIAL_REFRESH_BUFFER = 300 # 5 minutes in seconds
|
|
@@ -26,14 +26,17 @@ from s3fs import S3FileSystem # type: ignore [import-untyped]
|
|
|
26
26
|
from dreadnode.api.client import ApiClient
|
|
27
27
|
from dreadnode.config import UserConfig
|
|
28
28
|
from dreadnode.constants import (
|
|
29
|
+
DEFAULT_FS_CREDENTIAL_DURATION,
|
|
29
30
|
DEFAULT_SERVER_URL,
|
|
30
31
|
ENV_API_KEY,
|
|
31
32
|
ENV_API_TOKEN,
|
|
33
|
+
ENV_CONSOLE,
|
|
32
34
|
ENV_LOCAL_DIR,
|
|
33
35
|
ENV_PROFILE,
|
|
34
36
|
ENV_PROJECT,
|
|
35
37
|
ENV_SERVER,
|
|
36
38
|
ENV_SERVER_URL,
|
|
39
|
+
FS_CREDENTIAL_REFRESH_BUFFER,
|
|
37
40
|
)
|
|
38
41
|
from dreadnode.metric import (
|
|
39
42
|
Metric,
|
|
@@ -63,7 +66,7 @@ from dreadnode.types import (
|
|
|
63
66
|
Inherited,
|
|
64
67
|
JsonValue,
|
|
65
68
|
)
|
|
66
|
-
from dreadnode.util import clean_str, handle_internal_errors, resolve_endpoint
|
|
69
|
+
from dreadnode.util import clean_str, handle_internal_errors, logger, resolve_endpoint
|
|
67
70
|
from dreadnode.version import VERSION
|
|
68
71
|
|
|
69
72
|
if t.TYPE_CHECKING:
|
|
@@ -72,6 +75,8 @@ if t.TYPE_CHECKING:
|
|
|
72
75
|
from opentelemetry.sdk.trace import SpanProcessor
|
|
73
76
|
from opentelemetry.trace import Tracer
|
|
74
77
|
|
|
78
|
+
from dreadnode.api.models import UserDataCredentials
|
|
79
|
+
|
|
75
80
|
|
|
76
81
|
ToObject = t.Literal["task-or-run", "run"]
|
|
77
82
|
|
|
@@ -100,7 +105,7 @@ class Dreadnode:
|
|
|
100
105
|
project: str | None
|
|
101
106
|
service_name: str | None
|
|
102
107
|
service_version: str | None
|
|
103
|
-
console: logfire.ConsoleOptions |
|
|
108
|
+
console: logfire.ConsoleOptions | bool
|
|
104
109
|
send_to_logfire: bool | t.Literal["if-token-present"]
|
|
105
110
|
otel_scope: str
|
|
106
111
|
|
|
@@ -113,7 +118,7 @@ class Dreadnode:
|
|
|
113
118
|
project: str | None = None,
|
|
114
119
|
service_name: str | None = None,
|
|
115
120
|
service_version: str | None = None,
|
|
116
|
-
console: logfire.ConsoleOptions |
|
|
121
|
+
console: logfire.ConsoleOptions | bool = True,
|
|
117
122
|
send_to_logfire: bool | t.Literal["if-token-present"] = False,
|
|
118
123
|
otel_scope: str = "dreadnode",
|
|
119
124
|
) -> None:
|
|
@@ -136,6 +141,8 @@ class Dreadnode:
|
|
|
136
141
|
self._fs_prefix: str = ".dreadnode/storage/"
|
|
137
142
|
|
|
138
143
|
self._initialized = False
|
|
144
|
+
self._credentials: UserDataCredentials | None = None
|
|
145
|
+
self._credentials_expiry: datetime | None = None
|
|
139
146
|
|
|
140
147
|
def _get_profile_server(self, profile: str | None = None) -> str | None:
|
|
141
148
|
with contextlib.suppress(Exception):
|
|
@@ -167,7 +174,7 @@ class Dreadnode:
|
|
|
167
174
|
project: str | None = None,
|
|
168
175
|
service_name: str | None = None,
|
|
169
176
|
service_version: str | None = None,
|
|
170
|
-
console: logfire.ConsoleOptions |
|
|
177
|
+
console: logfire.ConsoleOptions | bool | None = None,
|
|
171
178
|
send_to_logfire: bool | t.Literal["if-token-present"] = False,
|
|
172
179
|
otel_scope: str = "dreadnode",
|
|
173
180
|
) -> None:
|
|
@@ -195,7 +202,7 @@ class Dreadnode:
|
|
|
195
202
|
project: The default project name to associate all runs with.
|
|
196
203
|
service_name: The service name to use for OpenTelemetry.
|
|
197
204
|
service_version: The service version to use for OpenTelemetry.
|
|
198
|
-
console: Whether to log span information to the console.
|
|
205
|
+
console: Whether to log span information to the console (`DREADNODE_CONSOLE` or the default is True).
|
|
199
206
|
send_to_logfire: Whether to send data to Logfire.
|
|
200
207
|
otel_scope: The OpenTelemetry scope name.
|
|
201
208
|
"""
|
|
@@ -252,7 +259,11 @@ class Dreadnode:
|
|
|
252
259
|
self.project = project or os.environ.get(ENV_PROJECT)
|
|
253
260
|
self.service_name = service_name
|
|
254
261
|
self.service_version = service_version
|
|
255
|
-
self.console = console
|
|
262
|
+
self.console = console or os.environ.get(ENV_CONSOLE, "true").lower() in [
|
|
263
|
+
"true",
|
|
264
|
+
"1",
|
|
265
|
+
"yes",
|
|
266
|
+
]
|
|
256
267
|
self.send_to_logfire = send_to_logfire
|
|
257
268
|
self.otel_scope = otel_scope
|
|
258
269
|
|
|
@@ -342,19 +353,21 @@ class Dreadnode:
|
|
|
342
353
|
# )
|
|
343
354
|
# )
|
|
344
355
|
# )
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
356
|
+
self._credentials = self._api.get_user_data_credentials(
|
|
357
|
+
duration=DEFAULT_FS_CREDENTIAL_DURATION
|
|
358
|
+
)
|
|
359
|
+
self._credentials_expiry = self._credentials.expiration
|
|
360
|
+
resolved_endpoint = resolve_endpoint(self._credentials.endpoint)
|
|
348
361
|
self._fs = S3FileSystem(
|
|
349
|
-
key=
|
|
350
|
-
secret=
|
|
351
|
-
token=
|
|
362
|
+
key=self._credentials.access_key_id,
|
|
363
|
+
secret=self._credentials.secret_access_key,
|
|
364
|
+
token=self._credentials.session_token,
|
|
352
365
|
client_kwargs={
|
|
353
366
|
"endpoint_url": resolved_endpoint,
|
|
354
|
-
"region_name":
|
|
367
|
+
"region_name": self._credentials.region,
|
|
355
368
|
},
|
|
356
369
|
)
|
|
357
|
-
self._fs_prefix = f"{
|
|
370
|
+
self._fs_prefix = f"{self._credentials.bucket}/{self._credentials.prefix}/"
|
|
358
371
|
|
|
359
372
|
self._logfire = logfire.configure(
|
|
360
373
|
local=not self.is_default,
|
|
@@ -401,6 +414,45 @@ class Dreadnode:
|
|
|
401
414
|
|
|
402
415
|
return self._api
|
|
403
416
|
|
|
417
|
+
def _refresh_storage_credentials(self) -> bool:
|
|
418
|
+
"""Refresh storage credentials if they are about to expire."""
|
|
419
|
+
if not self._api or not self._credentials:
|
|
420
|
+
return False
|
|
421
|
+
|
|
422
|
+
now = datetime.now(timezone.utc)
|
|
423
|
+
|
|
424
|
+
if (
|
|
425
|
+
self._credentials_expiry is None
|
|
426
|
+
or (self._credentials_expiry - now).total_seconds() < FS_CREDENTIAL_REFRESH_BUFFER
|
|
427
|
+
):
|
|
428
|
+
try:
|
|
429
|
+
logger.info("Refreshing storage credentials")
|
|
430
|
+
self._credentials = self._api.get_user_data_credentials(
|
|
431
|
+
duration=DEFAULT_FS_CREDENTIAL_DURATION
|
|
432
|
+
)
|
|
433
|
+
self._credentials_expiry = self._credentials.expiration
|
|
434
|
+
|
|
435
|
+
resolved_endpoint = resolve_endpoint(self._credentials.endpoint)
|
|
436
|
+
self._fs = S3FileSystem(
|
|
437
|
+
key=self._credentials.access_key_id,
|
|
438
|
+
secret=self._credentials.secret_access_key,
|
|
439
|
+
token=self._credentials.session_token,
|
|
440
|
+
client_kwargs={
|
|
441
|
+
"endpoint_url": resolved_endpoint,
|
|
442
|
+
"region_name": self._credentials.region,
|
|
443
|
+
},
|
|
444
|
+
)
|
|
445
|
+
logger.info(
|
|
446
|
+
f"Storage credentials refreshed, valid until {self._credentials_expiry}"
|
|
447
|
+
)
|
|
448
|
+
return True # noqa: TRY300
|
|
449
|
+
|
|
450
|
+
except Exception as e: # noqa: BLE001
|
|
451
|
+
logger.error(f"Failed to refresh storage credentials: {e}")
|
|
452
|
+
return False
|
|
453
|
+
|
|
454
|
+
return True
|
|
455
|
+
|
|
404
456
|
def _get_tracer(self, *, is_span_tracer: bool = True) -> "Tracer":
|
|
405
457
|
return self._logfire._tracer_provider.get_tracer( # noqa: SLF001
|
|
406
458
|
self.otel_scope,
|
|
@@ -773,6 +825,7 @@ class Dreadnode:
|
|
|
773
825
|
file_system=self._fs,
|
|
774
826
|
prefix_path=self._fs_prefix,
|
|
775
827
|
autolog=autolog,
|
|
828
|
+
credential_refresher=self._refresh_storage_credentials if self._credentials else None,
|
|
776
829
|
)
|
|
777
830
|
|
|
778
831
|
def get_run_context(self) -> RunContext:
|
|
@@ -819,6 +872,7 @@ class Dreadnode:
|
|
|
819
872
|
tracer=self._get_tracer(),
|
|
820
873
|
file_system=self._fs,
|
|
821
874
|
prefix_path=self._fs_prefix,
|
|
875
|
+
credential_refresher=self._refresh_storage_credentials if self._credentials else None,
|
|
822
876
|
)
|
|
823
877
|
|
|
824
878
|
def tag(self, *tag: str, to: ToObject = "task-or-run") -> None:
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import typing as t
|
|
3
|
+
|
|
4
|
+
from dreadnode.util import logger
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def with_credential_refresh(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
|
|
8
|
+
"""Decorator that automatically handles credential refresh on storage errors."""
|
|
9
|
+
|
|
10
|
+
@functools.wraps(func)
|
|
11
|
+
def wrapper(self: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
|
12
|
+
# Try to refresh credentials before operation
|
|
13
|
+
if hasattr(self, "_refresh_credentials_if_needed"):
|
|
14
|
+
self._refresh_credentials_if_needed()
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
return func(self, *args, **kwargs)
|
|
18
|
+
except Exception as e:
|
|
19
|
+
error_str = str(e)
|
|
20
|
+
if any(
|
|
21
|
+
error in error_str
|
|
22
|
+
for error in [
|
|
23
|
+
"ExpiredToken",
|
|
24
|
+
"TokenRefreshRequired",
|
|
25
|
+
"InvalidAccessKeyId",
|
|
26
|
+
"The Access Key Id you provided does not exist",
|
|
27
|
+
]
|
|
28
|
+
):
|
|
29
|
+
logger.info("Storage credential error, forcing refresh and retrying")
|
|
30
|
+
|
|
31
|
+
if hasattr(self, "_refresh_credentials_if_needed"):
|
|
32
|
+
self._refresh_credentials_if_needed()
|
|
33
|
+
|
|
34
|
+
return func(self, *args, **kwargs)
|
|
35
|
+
raise
|
|
36
|
+
|
|
37
|
+
return wrapper
|
|
@@ -36,6 +36,7 @@ from dreadnode.convert import run_span_to_graph
|
|
|
36
36
|
from dreadnode.metric import Metric, MetricAggMode, MetricsDict
|
|
37
37
|
from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal
|
|
38
38
|
from dreadnode.serialization import Serialized, serialize
|
|
39
|
+
from dreadnode.storage_utils import with_credential_refresh
|
|
39
40
|
from dreadnode.tracing.constants import (
|
|
40
41
|
EVENT_ATTRIBUTE_LINK_HASH,
|
|
41
42
|
EVENT_ATTRIBUTE_OBJECT_HASH,
|
|
@@ -365,6 +366,7 @@ class RunSpan(Span):
|
|
|
365
366
|
update_frequency: int = 5,
|
|
366
367
|
run_id: str | ULID | None = None,
|
|
367
368
|
type: SpanType = "run",
|
|
369
|
+
credential_refresher: t.Callable[[], bool] | None = None,
|
|
368
370
|
) -> None:
|
|
369
371
|
self.autolog = autolog
|
|
370
372
|
self.project = project
|
|
@@ -375,7 +377,9 @@ class RunSpan(Span):
|
|
|
375
377
|
self._object_schemas: dict[str, JsonDict] = {}
|
|
376
378
|
self._inputs: list[ObjectRef] = []
|
|
377
379
|
self._outputs: list[ObjectRef] = []
|
|
378
|
-
self._artifact_storage = ArtifactStorage(
|
|
380
|
+
self._artifact_storage = ArtifactStorage(
|
|
381
|
+
file_system=file_system, credential_refresher=credential_refresher
|
|
382
|
+
)
|
|
379
383
|
self._artifacts: list[DirectoryNode] = []
|
|
380
384
|
self._artifact_merger = ArtifactMerger()
|
|
381
385
|
self._artifact_tree_builder = ArtifactTreeBuilder(
|
|
@@ -406,6 +410,7 @@ class RunSpan(Span):
|
|
|
406
410
|
SPAN_ATTRIBUTE_PROJECT: project,
|
|
407
411
|
**(attributes or {}),
|
|
408
412
|
}
|
|
413
|
+
self._credential_refresher = credential_refresher
|
|
409
414
|
super().__init__(name, tracer, attributes=attributes, type=type, tags=tags)
|
|
410
415
|
|
|
411
416
|
@classmethod
|
|
@@ -415,6 +420,7 @@ class RunSpan(Span):
|
|
|
415
420
|
tracer: Tracer,
|
|
416
421
|
file_system: AbstractFileSystem,
|
|
417
422
|
prefix_path: str,
|
|
423
|
+
credential_refresher: t.Callable[[], bool] | None = None,
|
|
418
424
|
) -> "RunSpan":
|
|
419
425
|
self = RunSpan(
|
|
420
426
|
name=f"run.{context['run_id']}.fragment",
|
|
@@ -425,6 +431,7 @@ class RunSpan(Span):
|
|
|
425
431
|
prefix_path=prefix_path,
|
|
426
432
|
type="run_fragment",
|
|
427
433
|
run_id=context["run_id"],
|
|
434
|
+
credential_refresher=credential_refresher,
|
|
428
435
|
)
|
|
429
436
|
|
|
430
437
|
self._remote_context = context["trace_context"]
|
|
@@ -500,6 +507,10 @@ class RunSpan(Span):
|
|
|
500
507
|
if self._context_token is not None:
|
|
501
508
|
current_run_span.reset(self._context_token)
|
|
502
509
|
|
|
510
|
+
def _refresh_credentials_if_needed(self) -> None:
|
|
511
|
+
if self._credential_refresher:
|
|
512
|
+
self._credential_refresher()
|
|
513
|
+
|
|
503
514
|
def push_update(self, *, force: bool = False) -> None:
|
|
504
515
|
if self._span is None:
|
|
505
516
|
return
|
|
@@ -604,6 +615,7 @@ class RunSpan(Span):
|
|
|
604
615
|
|
|
605
616
|
return composite_hash
|
|
606
617
|
|
|
618
|
+
@with_credential_refresh
|
|
607
619
|
def _store_file_by_hash(self, data: bytes, full_path: str) -> str:
|
|
608
620
|
"""
|
|
609
621
|
Writes data to the given full_path in the object store if it doesn't already exist.
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|