wandb 0.18.0rc1__py3-none-macosx_11_0_arm64.whl → 0.18.2__py3-none-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +4 -4
- wandb/__init__.pyi +67 -12
- wandb/apis/internal.py +3 -0
- wandb/apis/public/api.py +128 -2
- wandb/apis/public/artifacts.py +11 -7
- wandb/apis/public/jobs.py +8 -0
- wandb/apis/public/runs.py +18 -5
- wandb/bin/apple_gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +0 -5
- wandb/data_types.py +9 -2019
- wandb/env.py +0 -5
- wandb/errors/__init__.py +11 -40
- wandb/errors/errors.py +37 -0
- wandb/errors/warnings.py +2 -0
- wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
- wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
- wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
- wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
- wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
- wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
- wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
- wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
- wandb/{sklearn → integration/sklearn}/utils.py +8 -8
- wandb/integration/tensorboard/log.py +1 -1
- wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
- wandb/old/core.py +2 -80
- wandb/plot/bar.py +7 -4
- wandb/plot/confusion_matrix.py +5 -4
- wandb/plot/histogram.py +7 -4
- wandb/plot/line.py +7 -4
- wandb/proto/v3/wandb_base_pb2.py +2 -1
- wandb/proto/v3/wandb_internal_pb2.py +2 -1
- wandb/proto/v3/wandb_server_pb2.py +2 -1
- wandb/proto/v3/wandb_settings_pb2.py +3 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v4/wandb_base_pb2.py +2 -1
- wandb/proto/v4/wandb_internal_pb2.py +2 -1
- wandb/proto/v4/wandb_server_pb2.py +2 -1
- wandb/proto/v4/wandb_settings_pb2.py +3 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v5/wandb_base_pb2.py +3 -2
- wandb/proto/v5/wandb_internal_pb2.py +3 -2
- wandb/proto/v5/wandb_server_pb2.py +3 -2
- wandb/proto/v5/wandb_settings_pb2.py +4 -3
- wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
- wandb/sdk/artifacts/_validators.py +48 -3
- wandb/sdk/artifacts/artifact.py +157 -183
- wandb/sdk/artifacts/artifact_file_cache.py +13 -11
- wandb/sdk/artifacts/artifact_instance_cache.py +4 -2
- wandb/sdk/artifacts/artifact_manifest.py +13 -11
- wandb/sdk/artifacts/artifact_manifest_entry.py +24 -22
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +9 -7
- wandb/sdk/artifacts/artifact_saver.py +27 -25
- wandb/sdk/artifacts/exceptions.py +26 -25
- wandb/sdk/artifacts/storage_handler.py +11 -9
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -14
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +15 -13
- wandb/sdk/artifacts/storage_handlers/http_handler.py +15 -14
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -8
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +14 -12
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +19 -19
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +10 -8
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +12 -10
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +9 -7
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +31 -29
- wandb/sdk/artifacts/storage_policy.py +20 -20
- wandb/sdk/backend/backend.py +8 -26
- wandb/sdk/data_types/audio.py +165 -0
- wandb/sdk/data_types/base_types/wb_value.py +1 -3
- wandb/sdk/data_types/bokeh.py +70 -0
- wandb/sdk/data_types/graph.py +405 -0
- wandb/sdk/data_types/image.py +156 -0
- wandb/sdk/data_types/table.py +1204 -0
- wandb/sdk/data_types/trace_tree.py +2 -2
- wandb/sdk/data_types/utils.py +49 -0
- wandb/sdk/data_types/video.py +2 -2
- wandb/sdk/interface/interface.py +0 -24
- wandb/sdk/interface/interface_shared.py +0 -12
- wandb/sdk/internal/handler.py +0 -10
- wandb/sdk/internal/internal_api.py +71 -0
- wandb/sdk/internal/sender.py +0 -43
- wandb/sdk/internal/tb_watcher.py +1 -1
- wandb/sdk/lib/_settings_toposort_generated.py +1 -0
- wandb/sdk/lib/hashutil.py +34 -12
- wandb/sdk/lib/service_connection.py +216 -0
- wandb/sdk/lib/service_token.py +94 -0
- wandb/sdk/lib/sock_client.py +7 -3
- wandb/sdk/service/server.py +2 -5
- wandb/sdk/service/service.py +2 -31
- wandb/sdk/service/streams.py +0 -7
- wandb/sdk/wandb_init.py +42 -25
- wandb/sdk/wandb_run.py +18 -159
- wandb/sdk/wandb_settings.py +2 -0
- wandb/sdk/wandb_setup.py +25 -16
- wandb/sdk/wandb_sync.py +9 -3
- wandb/sdk/wandb_watch.py +31 -15
- wandb/sklearn.py +35 -0
- wandb/util.py +14 -3
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/METADATA +6 -5
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/RECORD +115 -111
- wandb/sdk/internal/update.py +0 -113
- wandb/sdk/lib/console.py +0 -39
- wandb/sdk/service/service_base.py +0 -50
- wandb/sdk/service/service_sock.py +0 -70
- wandb/sdk/wandb_manager.py +0 -232
- /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
- /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
- /wandb/{sdk/lib → plot}/viz.py +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/WHEEL +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,7 @@
|
|
1
1
|
"""Artifact cache."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import contextlib
|
4
6
|
import errno
|
5
7
|
import hashlib
|
@@ -9,7 +11,7 @@ import subprocess
|
|
9
11
|
import sys
|
10
12
|
from pathlib import Path
|
11
13
|
from tempfile import NamedTemporaryFile
|
12
|
-
from typing import IO, TYPE_CHECKING, ContextManager, Iterator
|
14
|
+
from typing import IO, TYPE_CHECKING, ContextManager, Iterator
|
13
15
|
|
14
16
|
import wandb
|
15
17
|
from wandb import env, util
|
@@ -49,11 +51,11 @@ class ArtifactFileCache:
|
|
49
51
|
# [1] https://stackoverflow.com/questions/10541760/can-i-set-the-umask-for-tempfile-namedtemporaryfile-in-python
|
50
52
|
self._sys_umask = _get_sys_umask_threadsafe()
|
51
53
|
|
52
|
-
self._override_cache_path:
|
54
|
+
self._override_cache_path: StrPath | None = None
|
53
55
|
|
54
56
|
def check_md5_obj_path(
|
55
57
|
self, b64_md5: B64MD5, size: int
|
56
|
-
) ->
|
58
|
+
) -> tuple[FilePathStr, bool, Opener]:
|
57
59
|
# Check if we're using vs skipping the cache
|
58
60
|
if self._override_cache_path is not None:
|
59
61
|
skip_cache = True
|
@@ -71,7 +73,7 @@ class ArtifactFileCache:
|
|
71
73
|
url: URIStr,
|
72
74
|
etag: ETag,
|
73
75
|
size: int,
|
74
|
-
) ->
|
76
|
+
) -> tuple[FilePathStr, bool, Opener]:
|
75
77
|
# Check if we're using vs skipping the cache
|
76
78
|
if self._override_cache_path is not None:
|
77
79
|
skip_cache = True
|
@@ -87,16 +89,16 @@ class ArtifactFileCache:
|
|
87
89
|
|
88
90
|
def _check_or_create(
|
89
91
|
self, path: Path, size: int, skip_cache: bool = False
|
90
|
-
) ->
|
92
|
+
) -> tuple[FilePathStr, bool, Opener]:
|
91
93
|
opener = self._opener(path, size, skip_cache=skip_cache)
|
92
94
|
hit = path.is_file() and path.stat().st_size == size
|
93
95
|
return FilePathStr(str(path)), hit, opener
|
94
96
|
|
95
97
|
def cleanup(
|
96
98
|
self,
|
97
|
-
target_size:
|
99
|
+
target_size: int | None = None,
|
98
100
|
remove_temp: bool = False,
|
99
|
-
target_fraction:
|
101
|
+
target_fraction: float | None = None,
|
100
102
|
) -> int:
|
101
103
|
"""Clean up the cache, removing the least recently used files first.
|
102
104
|
|
@@ -121,9 +123,9 @@ class ArtifactFileCache:
|
|
121
123
|
target_size = 0
|
122
124
|
if target_size is not None and target_fraction is not None:
|
123
125
|
raise ValueError("Cannot specify both target_size and target_fraction")
|
124
|
-
if target_size and target_size < 0:
|
126
|
+
if target_size is not None and target_size < 0:
|
125
127
|
raise ValueError("target_size must be non-negative")
|
126
|
-
if target_fraction and (target_fraction < 0 or target_fraction > 1):
|
128
|
+
if target_fraction is not None and (target_fraction < 0 or target_fraction > 1):
|
127
129
|
raise ValueError("target_fraction must be between 0 and 1")
|
128
130
|
|
129
131
|
bytes_reclaimed = 0
|
@@ -198,7 +200,7 @@ class ArtifactFileCache:
|
|
198
200
|
if size > self._free_space():
|
199
201
|
raise OSError(errno.ENOSPC, f"Insufficient free space in {self._cache_dir}")
|
200
202
|
|
201
|
-
def _opener(self, path: Path, size: int, skip_cache: bool = False) ->
|
203
|
+
def _opener(self, path: Path, size: int, skip_cache: bool = False) -> Opener:
|
202
204
|
@contextlib.contextmanager
|
203
205
|
def atomic_open(mode: str = "w") -> Iterator[IO]:
|
204
206
|
if "a" in mode:
|
@@ -240,7 +242,7 @@ class ArtifactFileCache:
|
|
240
242
|
) from e
|
241
243
|
|
242
244
|
|
243
|
-
_artifact_file_cache:
|
245
|
+
_artifact_file_cache: ArtifactFileCache | None = None
|
244
246
|
|
245
247
|
|
246
248
|
def get_artifact_file_cache() -> ArtifactFileCache:
|
@@ -4,7 +4,9 @@ Artifacts are registered in the cache to ensure they won't be immediately garbag
|
|
4
4
|
collected and can be retrieved by their ID.
|
5
5
|
"""
|
6
6
|
|
7
|
-
from
|
7
|
+
from __future__ import annotations
|
8
|
+
|
9
|
+
from typing import TYPE_CHECKING
|
8
10
|
|
9
11
|
from wandb.sdk.lib.capped_dict import CappedDict
|
10
12
|
|
@@ -12,4 +14,4 @@ if TYPE_CHECKING:
|
|
12
14
|
from wandb.sdk.artifacts.artifact import Artifact
|
13
15
|
|
14
16
|
# There is nothing special about the artifact cache, it's just a global capped dict.
|
15
|
-
artifact_instance_cache:
|
17
|
+
artifact_instance_cache: dict[str, Artifact] = CappedDict(100)
|
@@ -1,6 +1,8 @@
|
|
1
1
|
"""Artifact manifest."""
|
2
2
|
|
3
|
-
from
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING, Mapping
|
4
6
|
|
5
7
|
from wandb.sdk.internal.internal_api import Api as InternalApi
|
6
8
|
from wandb.sdk.lib.hashutil import HexMD5
|
@@ -11,12 +13,12 @@ if TYPE_CHECKING:
|
|
11
13
|
|
12
14
|
|
13
15
|
class ArtifactManifest:
|
14
|
-
entries:
|
16
|
+
entries: dict[str, ArtifactManifestEntry]
|
15
17
|
|
16
18
|
@classmethod
|
17
19
|
def from_manifest_json(
|
18
|
-
cls, manifest_json:
|
19
|
-
) ->
|
20
|
+
cls, manifest_json: dict, api: InternalApi | None = None
|
21
|
+
) -> ArtifactManifest:
|
20
22
|
if "version" not in manifest_json:
|
21
23
|
raise ValueError("Invalid manifest format. Must contain version field.")
|
22
24
|
version = manifest_json["version"]
|
@@ -31,8 +33,8 @@ class ArtifactManifest:
|
|
31
33
|
|
32
34
|
def __init__(
|
33
35
|
self,
|
34
|
-
storage_policy:
|
35
|
-
entries:
|
36
|
+
storage_policy: StoragePolicy,
|
37
|
+
entries: Mapping[str, ArtifactManifestEntry] | None = None,
|
36
38
|
) -> None:
|
37
39
|
self.storage_policy = storage_policy
|
38
40
|
self.entries = dict(entries) if entries else {}
|
@@ -40,13 +42,13 @@ class ArtifactManifest:
|
|
40
42
|
def __len__(self) -> int:
|
41
43
|
return len(self.entries)
|
42
44
|
|
43
|
-
def to_manifest_json(self) ->
|
45
|
+
def to_manifest_json(self) -> dict:
|
44
46
|
raise NotImplementedError
|
45
47
|
|
46
48
|
def digest(self) -> HexMD5:
|
47
49
|
raise NotImplementedError
|
48
50
|
|
49
|
-
def add_entry(self, entry:
|
51
|
+
def add_entry(self, entry: ArtifactManifestEntry) -> None:
|
50
52
|
if (
|
51
53
|
entry.path in self.entries
|
52
54
|
and entry.digest != self.entries[entry.path].digest
|
@@ -54,15 +56,15 @@ class ArtifactManifest:
|
|
54
56
|
raise ValueError("Cannot add the same path twice: {}".format(entry.path))
|
55
57
|
self.entries[entry.path] = entry
|
56
58
|
|
57
|
-
def remove_entry(self, entry:
|
59
|
+
def remove_entry(self, entry: ArtifactManifestEntry) -> None:
|
58
60
|
if entry.path not in self.entries:
|
59
61
|
raise FileNotFoundError(f"Cannot remove missing entry: '{entry.path}'")
|
60
62
|
del self.entries[entry.path]
|
61
63
|
|
62
|
-
def get_entry_by_path(self, path: str) ->
|
64
|
+
def get_entry_by_path(self, path: str) -> ArtifactManifestEntry | None:
|
63
65
|
return self.entries.get(path)
|
64
66
|
|
65
|
-
def get_entries_in_directory(self, directory: str) ->
|
67
|
+
def get_entries_in_directory(self, directory: str) -> list[ArtifactManifestEntry]:
|
66
68
|
return [
|
67
69
|
self.entries[entry_key]
|
68
70
|
for entry_key in self.entries
|
@@ -1,10 +1,12 @@
|
|
1
1
|
"""Artifact manifest entry."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import json
|
4
6
|
import logging
|
5
7
|
import os
|
6
8
|
from pathlib import Path
|
7
|
-
from typing import TYPE_CHECKING
|
9
|
+
from typing import TYPE_CHECKING
|
8
10
|
from urllib.parse import urlparse
|
9
11
|
|
10
12
|
from wandb.sdk.lib import filesystem
|
@@ -32,7 +34,7 @@ if TYPE_CHECKING:
|
|
32
34
|
ref: str
|
33
35
|
birthArtifactID: str
|
34
36
|
size: int
|
35
|
-
extra:
|
37
|
+
extra: dict
|
36
38
|
local_path: str
|
37
39
|
|
38
40
|
|
@@ -40,27 +42,27 @@ class ArtifactManifestEntry:
|
|
40
42
|
"""A single entry in an artifact manifest."""
|
41
43
|
|
42
44
|
path: LogicalPath
|
43
|
-
digest:
|
45
|
+
digest: B64MD5 | URIStr | FilePathStr | ETag
|
44
46
|
skip_cache: bool
|
45
|
-
ref:
|
46
|
-
birth_artifact_id:
|
47
|
-
size:
|
48
|
-
extra:
|
49
|
-
local_path:
|
47
|
+
ref: FilePathStr | URIStr | None
|
48
|
+
birth_artifact_id: str | None
|
49
|
+
size: int | None
|
50
|
+
extra: dict
|
51
|
+
local_path: str | None
|
50
52
|
|
51
|
-
_parent_artifact:
|
52
|
-
_download_url:
|
53
|
+
_parent_artifact: Artifact | None = None
|
54
|
+
_download_url: str | None = None
|
53
55
|
|
54
56
|
def __init__(
|
55
57
|
self,
|
56
58
|
path: StrPath,
|
57
|
-
digest:
|
58
|
-
skip_cache:
|
59
|
-
ref:
|
60
|
-
birth_artifact_id:
|
61
|
-
size:
|
62
|
-
extra:
|
63
|
-
local_path:
|
59
|
+
digest: B64MD5 | URIStr | FilePathStr | ETag,
|
60
|
+
skip_cache: bool | None = False,
|
61
|
+
ref: FilePathStr | URIStr | None = None,
|
62
|
+
birth_artifact_id: str | None = None,
|
63
|
+
size: int | None = None,
|
64
|
+
extra: dict | None = None,
|
65
|
+
local_path: StrPath | None = None,
|
64
66
|
) -> None:
|
65
67
|
self.path = LogicalPath(path)
|
66
68
|
self.digest = digest
|
@@ -116,7 +118,7 @@ class ArtifactManifestEntry:
|
|
116
118
|
)
|
117
119
|
return self.path
|
118
120
|
|
119
|
-
def parent_artifact(self) ->
|
121
|
+
def parent_artifact(self) -> Artifact:
|
120
122
|
"""Get the artifact to which this artifact entry belongs.
|
121
123
|
|
122
124
|
Returns:
|
@@ -127,7 +129,7 @@ class ArtifactManifestEntry:
|
|
127
129
|
return self._parent_artifact
|
128
130
|
|
129
131
|
def download(
|
130
|
-
self, root:
|
132
|
+
self, root: str | None = None, skip_cache: bool | None = None
|
131
133
|
) -> FilePathStr:
|
132
134
|
"""Download this artifact entry to the specified root path.
|
133
135
|
|
@@ -177,7 +179,7 @@ class ArtifactManifestEntry:
|
|
177
179
|
str(filesystem.copy_or_overwrite_changed(cache_path, dest_path))
|
178
180
|
)
|
179
181
|
|
180
|
-
def ref_target(self) ->
|
182
|
+
def ref_target(self) -> FilePathStr | URIStr:
|
181
183
|
"""Get the reference URL that is targeted by this artifact entry.
|
182
184
|
|
183
185
|
Returns:
|
@@ -219,7 +221,7 @@ class ArtifactManifestEntry:
|
|
219
221
|
+ self.path
|
220
222
|
)
|
221
223
|
|
222
|
-
def to_json(self) ->
|
224
|
+
def to_json(self) -> ArtifactManifestEntryDict:
|
223
225
|
contents: ArtifactManifestEntryDict = {
|
224
226
|
"path": self.path,
|
225
227
|
"digest": self.digest,
|
@@ -241,7 +243,7 @@ class ArtifactManifestEntry:
|
|
241
243
|
def _is_artifact_reference(self) -> bool:
|
242
244
|
return self.ref is not None and urlparse(self.ref).scheme == "wandb-artifact"
|
243
245
|
|
244
|
-
def _referenced_artifact_id(self) ->
|
246
|
+
def _referenced_artifact_id(self) -> str | None:
|
245
247
|
if not self._is_artifact_reference():
|
246
248
|
return None
|
247
249
|
return hex_to_b64_id(urlparse(self.ref).netloc)
|
@@ -1,6 +1,8 @@
|
|
1
1
|
"""Artifact manifest v1."""
|
2
2
|
|
3
|
-
from
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import Any, Mapping
|
4
6
|
|
5
7
|
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
|
6
8
|
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
@@ -16,8 +18,8 @@ class ArtifactManifestV1(ArtifactManifest):
|
|
16
18
|
|
17
19
|
@classmethod
|
18
20
|
def from_manifest_json(
|
19
|
-
cls, manifest_json:
|
20
|
-
) ->
|
21
|
+
cls, manifest_json: dict, api: InternalApi | None = None
|
22
|
+
) -> ArtifactManifestV1:
|
21
23
|
if manifest_json["version"] != cls.version():
|
22
24
|
raise ValueError(
|
23
25
|
"Expected manifest version 1, got {}".format(manifest_json["version"])
|
@@ -48,12 +50,12 @@ class ArtifactManifestV1(ArtifactManifest):
|
|
48
50
|
|
49
51
|
def __init__(
|
50
52
|
self,
|
51
|
-
storage_policy:
|
52
|
-
entries:
|
53
|
+
storage_policy: StoragePolicy,
|
54
|
+
entries: Mapping[str, ArtifactManifestEntry] | None = None,
|
53
55
|
) -> None:
|
54
56
|
super().__init__(storage_policy, entries=entries)
|
55
57
|
|
56
|
-
def to_manifest_json(self) ->
|
58
|
+
def to_manifest_json(self) -> dict:
|
57
59
|
"""This is the JSON that's stored in wandb_manifest.json.
|
58
60
|
|
59
61
|
If include_local is True we also include the local paths to files. This is
|
@@ -63,7 +65,7 @@ class ArtifactManifestV1(ArtifactManifest):
|
|
63
65
|
"""
|
64
66
|
contents = {}
|
65
67
|
for entry in sorted(self.entries.values(), key=lambda k: k.path):
|
66
|
-
json_entry:
|
68
|
+
json_entry: dict[str, Any] = {
|
67
69
|
"digest": entry.digest,
|
68
70
|
}
|
69
71
|
if entry.birth_artifact_id:
|
@@ -1,11 +1,13 @@
|
|
1
1
|
"""Artifact saver."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import concurrent.futures
|
4
6
|
import json
|
5
7
|
import os
|
6
8
|
import sys
|
7
9
|
import tempfile
|
8
|
-
from typing import TYPE_CHECKING, Awaitable,
|
10
|
+
from typing import TYPE_CHECKING, Awaitable, Sequence
|
9
11
|
|
10
12
|
import wandb
|
11
13
|
import wandb.filesync.step_prepare
|
@@ -27,26 +29,26 @@ if TYPE_CHECKING:
|
|
27
29
|
|
28
30
|
class SaveFn(Protocol):
|
29
31
|
def __call__(
|
30
|
-
self, entry:
|
32
|
+
self, entry: ArtifactManifestEntry, progress_callback: ProgressFn
|
31
33
|
) -> bool:
|
32
34
|
pass
|
33
35
|
|
34
36
|
class SaveFnAsync(Protocol):
|
35
37
|
def __call__(
|
36
|
-
self, entry:
|
38
|
+
self, entry: ArtifactManifestEntry, progress_callback: ProgressFn
|
37
39
|
) -> Awaitable[bool]:
|
38
40
|
pass
|
39
41
|
|
40
42
|
|
41
43
|
class ArtifactSaver:
|
42
|
-
_server_artifact:
|
44
|
+
_server_artifact: dict | None # TODO better define this dict
|
43
45
|
|
44
46
|
def __init__(
|
45
47
|
self,
|
46
|
-
api:
|
48
|
+
api: InternalApi,
|
47
49
|
digest: str,
|
48
|
-
manifest_json:
|
49
|
-
file_pusher:
|
50
|
+
manifest_json: dict,
|
51
|
+
file_pusher: FilePusher,
|
50
52
|
is_user_created: bool = False,
|
51
53
|
) -> None:
|
52
54
|
self._api = api
|
@@ -65,18 +67,18 @@ class ArtifactSaver:
|
|
65
67
|
name: str,
|
66
68
|
client_id: str,
|
67
69
|
sequence_client_id: str,
|
68
|
-
distributed_id:
|
70
|
+
distributed_id: str | None = None,
|
69
71
|
finalize: bool = True,
|
70
|
-
metadata:
|
71
|
-
ttl_duration_seconds:
|
72
|
-
description:
|
73
|
-
aliases:
|
74
|
-
tags:
|
72
|
+
metadata: dict | None = None,
|
73
|
+
ttl_duration_seconds: int | None = None,
|
74
|
+
description: str | None = None,
|
75
|
+
aliases: Sequence[str] | None = None,
|
76
|
+
tags: Sequence[str] | None = None,
|
75
77
|
use_after_commit: bool = False,
|
76
78
|
incremental: bool = False,
|
77
|
-
history_step:
|
78
|
-
base_id:
|
79
|
-
) ->
|
79
|
+
history_step: int | None = None,
|
80
|
+
base_id: str | None = None,
|
81
|
+
) -> dict | None:
|
80
82
|
return self._save_internal(
|
81
83
|
type,
|
82
84
|
name,
|
@@ -101,18 +103,18 @@ class ArtifactSaver:
|
|
101
103
|
name: str,
|
102
104
|
client_id: str,
|
103
105
|
sequence_client_id: str,
|
104
|
-
distributed_id:
|
106
|
+
distributed_id: str | None = None,
|
105
107
|
finalize: bool = True,
|
106
|
-
metadata:
|
107
|
-
ttl_duration_seconds:
|
108
|
-
description:
|
109
|
-
aliases:
|
110
|
-
tags:
|
108
|
+
metadata: dict | None = None,
|
109
|
+
ttl_duration_seconds: int | None = None,
|
110
|
+
description: str | None = None,
|
111
|
+
aliases: Sequence[str] | None = None,
|
112
|
+
tags: Sequence[str] | None = None,
|
111
113
|
use_after_commit: bool = False,
|
112
114
|
incremental: bool = False,
|
113
|
-
history_step:
|
114
|
-
base_id:
|
115
|
-
) ->
|
115
|
+
history_step: int | None = None,
|
116
|
+
base_id: str | None = None,
|
117
|
+
) -> dict | None:
|
116
118
|
alias_specs = []
|
117
119
|
for alias in aliases or []:
|
118
120
|
alias_specs.append({"artifactCollectionName": name, "alias": alias})
|
@@ -1,55 +1,56 @@
|
|
1
1
|
"""Artifact exceptions."""
|
2
2
|
|
3
|
-
from
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING, TypeVar
|
4
6
|
|
5
7
|
from wandb import errors
|
6
8
|
|
7
9
|
if TYPE_CHECKING:
|
8
10
|
from wandb.sdk.artifacts.artifact import Artifact
|
9
11
|
|
12
|
+
ArtifactT = TypeVar("ArtifactT", bound=Artifact)
|
13
|
+
|
10
14
|
|
11
15
|
class ArtifactStatusError(AttributeError):
|
12
16
|
"""Raised when an artifact is in an invalid state for the requested operation."""
|
13
17
|
|
14
18
|
def __init__(
|
15
19
|
self,
|
16
|
-
artifact: Optional["Artifact"] = None,
|
17
|
-
attr: Optional[str] = None,
|
18
20
|
msg: str = "Artifact is in an invalid state for the requested operation.",
|
21
|
+
name: str | None = None,
|
22
|
+
obj: ArtifactT | None = None,
|
19
23
|
):
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
24
|
+
# Follow the same pattern as AttributeError in python 3.10+ by `name/obj` attributes
|
25
|
+
# See: https://docs.python.org/3/library/exceptions.html#AttributeError
|
26
|
+
try:
|
27
|
+
super().__init__(msg, name=name, obj=obj)
|
28
|
+
except TypeError:
|
29
|
+
# The `name`/`obj` keyword args and attributes were only added in python >= 3.10
|
30
|
+
super().__init__(msg)
|
31
|
+
self.name = name or ""
|
32
|
+
self.obj = obj
|
26
33
|
|
27
34
|
|
28
35
|
class ArtifactNotLoggedError(ArtifactStatusError):
|
29
36
|
"""Raised for Artifact methods or attributes only available after logging."""
|
30
37
|
|
31
|
-
def __init__(
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
artifact
|
36
|
-
attr,
|
37
|
-
"'{method_id}' used prior to logging artifact or while in offline mode. "
|
38
|
-
"Call wait() before accessing logged artifact properties.",
|
38
|
+
def __init__(self, fullname: str, obj: ArtifactT):
|
39
|
+
*_, name = fullname.split(".")
|
40
|
+
msg = (
|
41
|
+
f"{fullname!r} used prior to logging artifact or while in offline mode. "
|
42
|
+
f"Call {type(obj).wait.__qualname__}() before accessing logged artifact properties."
|
39
43
|
)
|
44
|
+
super().__init__(msg=msg, name=name, obj=obj)
|
40
45
|
|
41
46
|
|
42
47
|
class ArtifactFinalizedError(ArtifactStatusError):
|
43
48
|
"""Raised for Artifact methods or attributes that can't be changed after logging."""
|
44
49
|
|
45
|
-
def __init__(
|
46
|
-
|
47
|
-
|
48
|
-
super().__init__(
|
49
|
-
artifact,
|
50
|
-
attr,
|
51
|
-
"'{method_id}' used on logged artifact. Can't modify finalized artifact.",
|
52
|
-
)
|
50
|
+
def __init__(self, fullname: str, obj: ArtifactT):
|
51
|
+
*_, name = fullname.split(".")
|
52
|
+
msg = f"{fullname!r} used on logged artifact. Can't modify finalized artifact."
|
53
|
+
super().__init__(msg=msg, name=name, obj=obj)
|
53
54
|
|
54
55
|
|
55
56
|
class WaitTimeoutError(errors.Error):
|
@@ -1,6 +1,8 @@
|
|
1
1
|
"""Storage handler."""
|
2
2
|
|
3
|
-
from
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING, Sequence
|
4
6
|
|
5
7
|
from wandb.sdk.lib.paths import FilePathStr, URIStr
|
6
8
|
|
@@ -14,7 +16,7 @@ DEFAULT_MAX_OBJECTS = 10**7
|
|
14
16
|
|
15
17
|
|
16
18
|
class StorageHandler:
|
17
|
-
def can_handle(self, parsed_url:
|
19
|
+
def can_handle(self, parsed_url: ParseResult) -> bool:
|
18
20
|
"""Checks whether this handler can handle the given url.
|
19
21
|
|
20
22
|
Returns:
|
@@ -24,9 +26,9 @@ class StorageHandler:
|
|
24
26
|
|
25
27
|
def load_path(
|
26
28
|
self,
|
27
|
-
manifest_entry:
|
29
|
+
manifest_entry: ArtifactManifestEntry,
|
28
30
|
local: bool = False,
|
29
|
-
) ->
|
31
|
+
) -> URIStr | FilePathStr:
|
30
32
|
"""Load a file or directory given the corresponding index entry.
|
31
33
|
|
32
34
|
Args:
|
@@ -40,12 +42,12 @@ class StorageHandler:
|
|
40
42
|
|
41
43
|
def store_path(
|
42
44
|
self,
|
43
|
-
artifact:
|
44
|
-
path:
|
45
|
-
name:
|
45
|
+
artifact: Artifact,
|
46
|
+
path: URIStr | FilePathStr,
|
47
|
+
name: str | None = None,
|
46
48
|
checksum: bool = True,
|
47
|
-
max_objects:
|
48
|
-
) -> Sequence[
|
49
|
+
max_objects: int | None = None,
|
50
|
+
) -> Sequence[ArtifactManifestEntry]:
|
49
51
|
"""Store the file or directory at the given path to the specified artifact.
|
50
52
|
|
51
53
|
Args:
|
@@ -1,8 +1,10 @@
|
|
1
1
|
"""Azure storage handler."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
from pathlib import PurePosixPath
|
4
6
|
from types import ModuleType
|
5
|
-
from typing import TYPE_CHECKING,
|
7
|
+
from typing import TYPE_CHECKING, Sequence
|
6
8
|
from urllib.parse import ParseResult, parse_qsl, urlparse
|
7
9
|
|
8
10
|
import wandb
|
@@ -21,16 +23,16 @@ if TYPE_CHECKING:
|
|
21
23
|
|
22
24
|
|
23
25
|
class AzureHandler(StorageHandler):
|
24
|
-
def can_handle(self, parsed_url:
|
26
|
+
def can_handle(self, parsed_url: ParseResult) -> bool:
|
25
27
|
return parsed_url.scheme == "https" and parsed_url.netloc.endswith(
|
26
28
|
".blob.core.windows.net"
|
27
29
|
)
|
28
30
|
|
29
31
|
def load_path(
|
30
32
|
self,
|
31
|
-
manifest_entry:
|
33
|
+
manifest_entry: ArtifactManifestEntry,
|
32
34
|
local: bool = False,
|
33
|
-
) ->
|
35
|
+
) -> URIStr | FilePathStr:
|
34
36
|
assert manifest_entry.ref is not None
|
35
37
|
if not local:
|
36
38
|
return manifest_entry.ref
|
@@ -91,12 +93,12 @@ class AzureHandler(StorageHandler):
|
|
91
93
|
|
92
94
|
def store_path(
|
93
95
|
self,
|
94
|
-
artifact:
|
95
|
-
path:
|
96
|
-
name:
|
96
|
+
artifact: Artifact,
|
97
|
+
path: URIStr | FilePathStr,
|
98
|
+
name: StrPath | None = None,
|
97
99
|
checksum: bool = True,
|
98
|
-
max_objects:
|
99
|
-
) -> Sequence[
|
100
|
+
max_objects: int | None = None,
|
101
|
+
) -> Sequence[ArtifactManifestEntry]:
|
100
102
|
account_url, container_name, blob_name, query = self._parse_uri(path)
|
101
103
|
path = URIStr(f"{account_url}/{container_name}/{blob_name}")
|
102
104
|
|
@@ -127,7 +129,7 @@ class AzureHandler(StorageHandler):
|
|
127
129
|
)
|
128
130
|
]
|
129
131
|
|
130
|
-
entries:
|
132
|
+
entries: list[ArtifactManifestEntry] = []
|
131
133
|
container_client = blob_service_client.get_container_client(container_name)
|
132
134
|
max_objects = max_objects or DEFAULT_MAX_OBJECTS
|
133
135
|
for blob_properties in container_client.list_blobs(
|
@@ -163,7 +165,7 @@ class AzureHandler(StorageHandler):
|
|
163
165
|
|
164
166
|
def _get_credential(
|
165
167
|
self, account_url: str
|
166
|
-
) ->
|
168
|
+
) -> azure.identity.DefaultAzureCredential | str:
|
167
169
|
if (
|
168
170
|
wandb.run
|
169
171
|
and wandb.run.settings.azure_account_url_to_access_key is not None
|
@@ -172,7 +174,7 @@ class AzureHandler(StorageHandler):
|
|
172
174
|
return wandb.run.settings.azure_account_url_to_access_key[account_url]
|
173
175
|
return self._get_module("azure.identity").DefaultAzureCredential()
|
174
176
|
|
175
|
-
def _parse_uri(self, uri: str) ->
|
177
|
+
def _parse_uri(self, uri: str) -> tuple[str, str, str, dict[str, str]]:
|
176
178
|
parsed_url = urlparse(uri)
|
177
179
|
query = dict(parse_qsl(parsed_url.query))
|
178
180
|
account_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
@@ -181,7 +183,7 @@ class AzureHandler(StorageHandler):
|
|
181
183
|
|
182
184
|
def _create_entry(
|
183
185
|
self,
|
184
|
-
blob_properties:
|
186
|
+
blob_properties: azure.storage.blob.BlobProperties,
|
185
187
|
path: StrPath,
|
186
188
|
ref: URIStr,
|
187
189
|
) -> ArtifactManifestEntry:
|
@@ -197,7 +199,7 @@ class AzureHandler(StorageHandler):
|
|
197
199
|
)
|
198
200
|
|
199
201
|
def _is_directory_stub(
|
200
|
-
self, blob_properties:
|
202
|
+
self, blob_properties: azure.storage.blob.BlobProperties
|
201
203
|
) -> bool:
|
202
204
|
return (
|
203
205
|
blob_properties.has_key("metadata")
|