wandb-0.15.3-py3-none-any.whl → wandb-0.15.5-py3-none-any.whl
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
wandb/sdk/wandb_artifacts.py
DELETED
@@ -1,2226 +0,0 @@
|
|
1
|
-
import base64
|
2
|
-
import contextlib
|
3
|
-
import json
|
4
|
-
import os
|
5
|
-
import pathlib
|
6
|
-
import re
|
7
|
-
import shutil
|
8
|
-
import tempfile
|
9
|
-
import time
|
10
|
-
from types import ModuleType
|
11
|
-
from typing import (
|
12
|
-
IO,
|
13
|
-
TYPE_CHECKING,
|
14
|
-
Any,
|
15
|
-
Dict,
|
16
|
-
Generator,
|
17
|
-
List,
|
18
|
-
Mapping,
|
19
|
-
Optional,
|
20
|
-
Sequence,
|
21
|
-
Tuple,
|
22
|
-
Union,
|
23
|
-
cast,
|
24
|
-
)
|
25
|
-
from urllib.parse import parse_qsl, quote, urlparse
|
26
|
-
|
27
|
-
import requests
|
28
|
-
import urllib3
|
29
|
-
|
30
|
-
import wandb
|
31
|
-
import wandb.data_types as data_types
|
32
|
-
from wandb import env, util
|
33
|
-
from wandb.apis import InternalApi, PublicApi
|
34
|
-
from wandb.apis.public import Artifact as PublicArtifact
|
35
|
-
from wandb.errors import CommError
|
36
|
-
from wandb.errors.term import termlog, termwarn
|
37
|
-
from wandb.sdk import lib as wandb_lib
|
38
|
-
from wandb.sdk.data_types._dtypes import Type, TypeRegistry
|
39
|
-
from wandb.sdk.interface.artifacts import Artifact as ArtifactInterface
|
40
|
-
from wandb.sdk.interface.artifacts import (
|
41
|
-
ArtifactFinalizedError,
|
42
|
-
ArtifactManifest,
|
43
|
-
ArtifactManifestEntry,
|
44
|
-
ArtifactNotLoggedError,
|
45
|
-
ArtifactsCache,
|
46
|
-
StorageHandler,
|
47
|
-
StorageLayout,
|
48
|
-
StoragePolicy,
|
49
|
-
get_artifacts_cache,
|
50
|
-
)
|
51
|
-
from wandb.sdk.internal import progress
|
52
|
-
from wandb.sdk.internal.artifact_saver import get_staging_dir
|
53
|
-
from wandb.sdk.lib import filesystem, runid
|
54
|
-
from wandb.sdk.lib.hashutil import (
|
55
|
-
B64MD5,
|
56
|
-
ETag,
|
57
|
-
HexMD5,
|
58
|
-
_md5,
|
59
|
-
b64_to_hex_id,
|
60
|
-
hex_to_b64_id,
|
61
|
-
md5_file_b64,
|
62
|
-
md5_string,
|
63
|
-
)
|
64
|
-
from wandb.sdk.lib.paths import FilePathStr, LogicalFilePathStr, URIStr
|
65
|
-
|
66
|
-
if TYPE_CHECKING:
|
67
|
-
from urllib.parse import ParseResult
|
68
|
-
|
69
|
-
import azure.storage.blob # type: ignore
|
70
|
-
|
71
|
-
# We could probably use https://pypi.org/project/boto3-stubs/ or something
|
72
|
-
# instead of `type:ignore`ing these boto imports, but it's nontrivial:
|
73
|
-
# for some reason, despite being actively maintained as of 2022-09-30,
|
74
|
-
# the latest release of boto3-stubs doesn't include all the features we use.
|
75
|
-
import boto3 # type: ignore
|
76
|
-
import boto3.resources.base # type: ignore
|
77
|
-
import boto3.s3 # type: ignore
|
78
|
-
import boto3.session # type: ignore
|
79
|
-
import google.cloud.storage as gcs_module # type: ignore
|
80
|
-
|
81
|
-
import wandb.apis.public
|
82
|
-
from wandb.filesync.step_prepare import StepPrepare
|
83
|
-
|
84
|
-
# This makes the first sleep 1s, and then doubles it up to total times,
|
85
|
-
# which makes for ~18 hours.
|
86
|
-
_REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
|
87
|
-
backoff_factor=1,
|
88
|
-
total=16,
|
89
|
-
status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
|
90
|
-
)
|
91
|
-
|
92
|
-
_REQUEST_POOL_CONNECTIONS = 64
|
93
|
-
|
94
|
-
_REQUEST_POOL_MAXSIZE = 64
|
95
|
-
|
96
|
-
ARTIFACT_TMP = tempfile.TemporaryDirectory("wandb-artifacts")
|
97
|
-
|
98
|
-
|
99
|
-
class _AddedObj:
|
100
|
-
def __init__(self, entry: ArtifactManifestEntry, obj: data_types.WBValue):
|
101
|
-
self.entry = entry
|
102
|
-
self.obj = obj
|
103
|
-
|
104
|
-
|
105
|
-
def _normalize_metadata(metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
106
|
-
if metadata is None:
|
107
|
-
return {}
|
108
|
-
if not isinstance(metadata, dict):
|
109
|
-
raise TypeError(f"metadata must be dict, not {type(metadata)}")
|
110
|
-
return cast(
|
111
|
-
Dict[str, Any], json.loads(json.dumps(util.json_friendly_val(metadata)))
|
112
|
-
)
|
113
|
-
|
114
|
-
|
115
|
-
class Artifact(ArtifactInterface):
|
116
|
-
"""Flexible and lightweight building block for dataset and model versioning.
|
117
|
-
|
118
|
-
Constructs an empty artifact whose contents can be populated using its
|
119
|
-
`add` family of functions. Once the artifact has all the desired files,
|
120
|
-
you can call `wandb.log_artifact()` to log it.
|
121
|
-
|
122
|
-
Arguments:
|
123
|
-
name: (str) A human-readable name for this artifact, which is how you
|
124
|
-
can identify this artifact in the UI or reference it in `use_artifact`
|
125
|
-
calls. Names can contain letters, numbers, underscores, hyphens, and
|
126
|
-
dots. The name must be unique across a project.
|
127
|
-
type: (str) The type of the artifact, which is used to organize and differentiate
|
128
|
-
artifacts. Common types include `dataset` or `model`, but you can use any string
|
129
|
-
containing letters, numbers, underscores, hyphens, and dots.
|
130
|
-
description: (str, optional) Free text that offers a description of the artifact. The
|
131
|
-
description is markdown rendered in the UI, so this is a good place to place tables,
|
132
|
-
links, etc.
|
133
|
-
metadata: (dict, optional) Structured data associated with the artifact,
|
134
|
-
for example class distribution of a dataset. This will eventually be queryable
|
135
|
-
and plottable in the UI. There is a hard limit of 100 total keys.
|
136
|
-
|
137
|
-
Examples:
|
138
|
-
Basic usage
|
139
|
-
```
|
140
|
-
wandb.init()
|
141
|
-
|
142
|
-
artifact = wandb.Artifact('mnist', type='dataset')
|
143
|
-
artifact.add_dir('mnist/')
|
144
|
-
wandb.log_artifact(artifact)
|
145
|
-
```
|
146
|
-
|
147
|
-
Returns:
|
148
|
-
An `Artifact` object.
|
149
|
-
"""
|
150
|
-
|
151
|
-
_added_objs: Dict[int, _AddedObj]
|
152
|
-
_added_local_paths: Dict[str, ArtifactManifestEntry]
|
153
|
-
_distributed_id: Optional[str]
|
154
|
-
_metadata: dict
|
155
|
-
_logged_artifact: Optional[ArtifactInterface]
|
156
|
-
_incremental: bool
|
157
|
-
_client_id: str
|
158
|
-
|
159
|
-
def __init__(
|
160
|
-
self,
|
161
|
-
name: str,
|
162
|
-
type: str,
|
163
|
-
description: Optional[str] = None,
|
164
|
-
metadata: Optional[dict] = None,
|
165
|
-
incremental: Optional[bool] = None,
|
166
|
-
use_as: Optional[str] = None,
|
167
|
-
) -> None:
|
168
|
-
if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
|
169
|
-
raise ValueError(
|
170
|
-
"Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. "
|
171
|
-
'Invalid name: "%s"' % name
|
172
|
-
)
|
173
|
-
if type == "job" or type.startswith("wandb-"):
|
174
|
-
raise ValueError(
|
175
|
-
"Artifact types 'job' and 'wandb-*' are reserved for internal use. "
|
176
|
-
"Please use a different type."
|
177
|
-
)
|
178
|
-
|
179
|
-
metadata = _normalize_metadata(metadata)
|
180
|
-
# TODO: this shouldn't be a property of the artifact. It's a more like an
|
181
|
-
# argument to log_artifact.
|
182
|
-
storage_layout = StorageLayout.V2
|
183
|
-
if env.get_use_v1_artifacts():
|
184
|
-
storage_layout = StorageLayout.V1
|
185
|
-
|
186
|
-
self._storage_policy = WandbStoragePolicy(
|
187
|
-
config={
|
188
|
-
"storageLayout": storage_layout,
|
189
|
-
# TODO: storage region
|
190
|
-
}
|
191
|
-
)
|
192
|
-
self._api = InternalApi()
|
193
|
-
self._final = False
|
194
|
-
self._digest = ""
|
195
|
-
self._file_entries = None
|
196
|
-
self._manifest = ArtifactManifestV1(self._storage_policy)
|
197
|
-
self._cache = get_artifacts_cache()
|
198
|
-
self._added_objs = {}
|
199
|
-
self._added_local_paths = {}
|
200
|
-
# You can write into this directory when creating artifact files
|
201
|
-
self._artifact_dir = tempfile.TemporaryDirectory()
|
202
|
-
self._type = type
|
203
|
-
self._name = name
|
204
|
-
self._description = description
|
205
|
-
self._metadata = metadata
|
206
|
-
self._distributed_id = None
|
207
|
-
self._logged_artifact = None
|
208
|
-
self._incremental = False
|
209
|
-
self._client_id = runid.generate_id(128)
|
210
|
-
self._sequence_client_id = runid.generate_id(128)
|
211
|
-
self._cache.store_client_artifact(self)
|
212
|
-
self._use_as = use_as
|
213
|
-
|
214
|
-
if incremental:
|
215
|
-
self._incremental = incremental
|
216
|
-
wandb.termwarn("Using experimental arg `incremental`")
|
217
|
-
|
218
|
-
@property
|
219
|
-
def id(self) -> Optional[str]:
|
220
|
-
if self._logged_artifact:
|
221
|
-
return self._logged_artifact.id
|
222
|
-
|
223
|
-
# The artifact hasn't been saved so an ID doesn't exist yet.
|
224
|
-
return None
|
225
|
-
|
226
|
-
@property
|
227
|
-
def source_version(self) -> Optional[str]:
|
228
|
-
if self._logged_artifact:
|
229
|
-
return self._logged_artifact.source_version
|
230
|
-
|
231
|
-
return None
|
232
|
-
|
233
|
-
@property
|
234
|
-
def version(self) -> str:
|
235
|
-
if self._logged_artifact:
|
236
|
-
return self._logged_artifact.version
|
237
|
-
|
238
|
-
raise ArtifactNotLoggedError(self, "version")
|
239
|
-
|
240
|
-
@property
|
241
|
-
def entity(self) -> str:
|
242
|
-
if self._logged_artifact:
|
243
|
-
return self._logged_artifact.entity
|
244
|
-
return self._api.settings("entity") or self._api.viewer().get("entity") # type: ignore
|
245
|
-
|
246
|
-
@property
|
247
|
-
def project(self) -> str:
|
248
|
-
if self._logged_artifact:
|
249
|
-
return self._logged_artifact.project
|
250
|
-
|
251
|
-
return self._api.settings("project") # type: ignore
|
252
|
-
|
253
|
-
@property
|
254
|
-
def manifest(self) -> ArtifactManifest:
|
255
|
-
if self._logged_artifact:
|
256
|
-
return self._logged_artifact.manifest
|
257
|
-
|
258
|
-
self.finalize()
|
259
|
-
return self._manifest
|
260
|
-
|
261
|
-
@property
|
262
|
-
def digest(self) -> str:
|
263
|
-
if self._logged_artifact:
|
264
|
-
return self._logged_artifact.digest
|
265
|
-
|
266
|
-
self.finalize()
|
267
|
-
# Digest will be none if the artifact hasn't been saved yet.
|
268
|
-
return self._digest
|
269
|
-
|
270
|
-
@property
|
271
|
-
def type(self) -> str:
|
272
|
-
if self._logged_artifact:
|
273
|
-
return self._logged_artifact.type
|
274
|
-
|
275
|
-
return self._type
|
276
|
-
|
277
|
-
@property
|
278
|
-
def name(self) -> str:
|
279
|
-
if self._logged_artifact:
|
280
|
-
return self._logged_artifact.name
|
281
|
-
|
282
|
-
return self._name
|
283
|
-
|
284
|
-
@property
|
285
|
-
def full_name(self) -> str:
|
286
|
-
if self._logged_artifact:
|
287
|
-
return self._logged_artifact.full_name
|
288
|
-
|
289
|
-
return super().full_name
|
290
|
-
|
291
|
-
@property
|
292
|
-
def state(self) -> str:
|
293
|
-
if self._logged_artifact:
|
294
|
-
return self._logged_artifact.state
|
295
|
-
|
296
|
-
return "PENDING"
|
297
|
-
|
298
|
-
@property
|
299
|
-
def size(self) -> int:
|
300
|
-
if self._logged_artifact:
|
301
|
-
return self._logged_artifact.size
|
302
|
-
sizes: List[int]
|
303
|
-
sizes = []
|
304
|
-
for entry in self._manifest.entries:
|
305
|
-
e_size = self._manifest.entries[entry].size
|
306
|
-
if e_size is not None:
|
307
|
-
sizes.append(e_size)
|
308
|
-
return sum(sizes)
|
309
|
-
|
310
|
-
@property
|
311
|
-
def commit_hash(self) -> str:
|
312
|
-
if self._logged_artifact:
|
313
|
-
return self._logged_artifact.commit_hash
|
314
|
-
|
315
|
-
raise ArtifactNotLoggedError(self, "commit_hash")
|
316
|
-
|
317
|
-
@property
|
318
|
-
def description(self) -> Optional[str]:
|
319
|
-
if self._logged_artifact:
|
320
|
-
return self._logged_artifact.description
|
321
|
-
|
322
|
-
return self._description
|
323
|
-
|
324
|
-
@description.setter
|
325
|
-
def description(self, desc: Optional[str]) -> None:
|
326
|
-
if self._logged_artifact:
|
327
|
-
self._logged_artifact.description = desc
|
328
|
-
return
|
329
|
-
|
330
|
-
self._description = desc
|
331
|
-
|
332
|
-
@property
|
333
|
-
def metadata(self) -> dict:
|
334
|
-
if self._logged_artifact:
|
335
|
-
return self._logged_artifact.metadata
|
336
|
-
|
337
|
-
return self._metadata
|
338
|
-
|
339
|
-
@metadata.setter
|
340
|
-
def metadata(self, metadata: dict) -> None:
|
341
|
-
metadata = _normalize_metadata(metadata)
|
342
|
-
if self._logged_artifact:
|
343
|
-
self._logged_artifact.metadata = metadata
|
344
|
-
return
|
345
|
-
|
346
|
-
self._metadata = metadata
|
347
|
-
|
348
|
-
@property
|
349
|
-
def aliases(self) -> List[str]:
|
350
|
-
if self._logged_artifact:
|
351
|
-
return self._logged_artifact.aliases
|
352
|
-
|
353
|
-
raise ArtifactNotLoggedError(self, "aliases")
|
354
|
-
|
355
|
-
@aliases.setter
|
356
|
-
def aliases(self, aliases: List[str]) -> None:
|
357
|
-
"""Set artifact aliases.
|
358
|
-
|
359
|
-
Arguments:
|
360
|
-
aliases: (list) The list of aliases associated with this artifact.
|
361
|
-
"""
|
362
|
-
if self._logged_artifact:
|
363
|
-
self._logged_artifact.aliases = aliases
|
364
|
-
return
|
365
|
-
|
366
|
-
raise ArtifactNotLoggedError(self, "aliases")
|
367
|
-
|
368
|
-
@property
|
369
|
-
def use_as(self) -> Optional[str]:
|
370
|
-
return self._use_as
|
371
|
-
|
372
|
-
@property
|
373
|
-
def distributed_id(self) -> Optional[str]:
|
374
|
-
return self._distributed_id
|
375
|
-
|
376
|
-
@distributed_id.setter
|
377
|
-
def distributed_id(self, distributed_id: Optional[str]) -> None:
|
378
|
-
self._distributed_id = distributed_id
|
379
|
-
|
380
|
-
@property
|
381
|
-
def incremental(self) -> bool:
|
382
|
-
return self._incremental
|
383
|
-
|
384
|
-
def used_by(self) -> List["wandb.apis.public.Run"]:
|
385
|
-
if self._logged_artifact:
|
386
|
-
return self._logged_artifact.used_by()
|
387
|
-
|
388
|
-
raise ArtifactNotLoggedError(self, "used_by")
|
389
|
-
|
390
|
-
def logged_by(self) -> "wandb.apis.public.Run":
|
391
|
-
if self._logged_artifact:
|
392
|
-
return self._logged_artifact.logged_by()
|
393
|
-
|
394
|
-
raise ArtifactNotLoggedError(self, "logged_by")
|
395
|
-
|
396
|
-
@contextlib.contextmanager
|
397
|
-
def new_file(
|
398
|
-
self, name: str, mode: str = "w", encoding: Optional[str] = None
|
399
|
-
) -> Generator[IO, None, None]:
|
400
|
-
self._ensure_can_add()
|
401
|
-
path = os.path.join(self._artifact_dir.name, name.lstrip("/"))
|
402
|
-
if os.path.exists(path):
|
403
|
-
raise ValueError(f"File with name {name!r} already exists at {path!r}")
|
404
|
-
|
405
|
-
filesystem.mkdir_exists_ok(os.path.dirname(path))
|
406
|
-
try:
|
407
|
-
with util.fsync_open(path, mode, encoding) as f:
|
408
|
-
yield f
|
409
|
-
except UnicodeEncodeError as e:
|
410
|
-
wandb.termerror(
|
411
|
-
f"Failed to open the provided file (UnicodeEncodeError: {e}). Please provide the proper encoding."
|
412
|
-
)
|
413
|
-
raise e
|
414
|
-
self.add_file(path, name=name)
|
415
|
-
|
416
|
-
def add_file(
|
417
|
-
self,
|
418
|
-
local_path: str,
|
419
|
-
name: Optional[str] = None,
|
420
|
-
is_tmp: Optional[bool] = False,
|
421
|
-
) -> ArtifactManifestEntry:
|
422
|
-
self._ensure_can_add()
|
423
|
-
if not os.path.isfile(local_path):
|
424
|
-
raise ValueError("Path is not a file: %s" % local_path)
|
425
|
-
|
426
|
-
name = util.to_forward_slash_path(name or os.path.basename(local_path))
|
427
|
-
digest = md5_file_b64(local_path)
|
428
|
-
|
429
|
-
if is_tmp:
|
430
|
-
file_path, file_name = os.path.split(name)
|
431
|
-
file_name_parts = file_name.split(".")
|
432
|
-
file_name_parts[0] = b64_to_hex_id(digest)[:20]
|
433
|
-
name = os.path.join(file_path, ".".join(file_name_parts))
|
434
|
-
|
435
|
-
return self._add_local_file(name, local_path, digest=digest)
|
436
|
-
|
437
|
-
def add_dir(self, local_path: str, name: Optional[str] = None) -> None:
|
438
|
-
self._ensure_can_add()
|
439
|
-
if not os.path.isdir(local_path):
|
440
|
-
raise ValueError("Path is not a directory: %s" % local_path)
|
441
|
-
|
442
|
-
termlog(
|
443
|
-
"Adding directory to artifact (%s)... "
|
444
|
-
% os.path.join(".", os.path.normpath(local_path)),
|
445
|
-
newline=False,
|
446
|
-
)
|
447
|
-
start_time = time.time()
|
448
|
-
|
449
|
-
paths = []
|
450
|
-
for dirpath, _, filenames in os.walk(local_path, followlinks=True):
|
451
|
-
for fname in filenames:
|
452
|
-
physical_path = os.path.join(dirpath, fname)
|
453
|
-
logical_path = os.path.relpath(physical_path, start=local_path)
|
454
|
-
if name is not None:
|
455
|
-
logical_path = os.path.join(name, logical_path)
|
456
|
-
paths.append((logical_path, physical_path))
|
457
|
-
|
458
|
-
def add_manifest_file(log_phy_path: Tuple[str, str]) -> None:
|
459
|
-
logical_path, physical_path = log_phy_path
|
460
|
-
self._add_local_file(logical_path, physical_path)
|
461
|
-
|
462
|
-
import multiprocessing.dummy # this uses threads
|
463
|
-
|
464
|
-
num_threads = 8
|
465
|
-
pool = multiprocessing.dummy.Pool(num_threads)
|
466
|
-
pool.map(add_manifest_file, paths)
|
467
|
-
pool.close()
|
468
|
-
pool.join()
|
469
|
-
|
470
|
-
termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
|
471
|
-
|
472
|
-
def add_reference(
|
473
|
-
self,
|
474
|
-
uri: Union[ArtifactManifestEntry, str],
|
475
|
-
name: Optional[str] = None,
|
476
|
-
checksum: bool = True,
|
477
|
-
max_objects: Optional[int] = None,
|
478
|
-
) -> Sequence[ArtifactManifestEntry]:
|
479
|
-
self._ensure_can_add()
|
480
|
-
if name is not None:
|
481
|
-
name = util.to_forward_slash_path(name)
|
482
|
-
|
483
|
-
# This is a bit of a hack, we want to check if the uri is a of the type
|
484
|
-
# ArtifactManifestEntry which is a private class returned by Artifact.get_path in
|
485
|
-
# wandb/apis/public.py. If so, then recover the reference URL.
|
486
|
-
if isinstance(uri, ArtifactManifestEntry) and uri.parent_artifact() != self:
|
487
|
-
ref_url_fn = uri.ref_url
|
488
|
-
uri_str = ref_url_fn()
|
489
|
-
elif isinstance(uri, str):
|
490
|
-
uri_str = uri
|
491
|
-
url = urlparse(str(uri_str))
|
492
|
-
if not url.scheme:
|
493
|
-
raise ValueError(
|
494
|
-
"References must be URIs. To reference a local file, use file://"
|
495
|
-
)
|
496
|
-
|
497
|
-
manifest_entries = self._storage_policy.store_reference(
|
498
|
-
self,
|
499
|
-
URIStr(uri_str),
|
500
|
-
name=name,
|
501
|
-
checksum=checksum,
|
502
|
-
max_objects=max_objects,
|
503
|
-
)
|
504
|
-
for entry in manifest_entries:
|
505
|
-
self._manifest.add_entry(entry)
|
506
|
-
|
507
|
-
return manifest_entries
|
508
|
-
|
509
|
-
def add(self, obj: data_types.WBValue, name: str) -> ArtifactManifestEntry:
|
510
|
-
self._ensure_can_add()
|
511
|
-
name = util.to_forward_slash_path(name)
|
512
|
-
|
513
|
-
# This is a "hack" to automatically rename tables added to
|
514
|
-
# the wandb /media/tables directory to their sha-based name.
|
515
|
-
# TODO: figure out a more appropriate convention.
|
516
|
-
is_tmp_name = name.startswith("media/tables")
|
517
|
-
|
518
|
-
# Validate that the object is one of the correct wandb.Media types
|
519
|
-
# TODO: move this to checking subclass of wandb.Media once all are
|
520
|
-
# generally supported
|
521
|
-
allowed_types = [
|
522
|
-
data_types.Bokeh,
|
523
|
-
data_types.JoinedTable,
|
524
|
-
data_types.PartitionedTable,
|
525
|
-
data_types.Table,
|
526
|
-
data_types.Classes,
|
527
|
-
data_types.ImageMask,
|
528
|
-
data_types.BoundingBoxes2D,
|
529
|
-
data_types.Audio,
|
530
|
-
data_types.Image,
|
531
|
-
data_types.Video,
|
532
|
-
data_types.Html,
|
533
|
-
data_types.Object3D,
|
534
|
-
data_types.Molecule,
|
535
|
-
data_types._SavedModel,
|
536
|
-
]
|
537
|
-
|
538
|
-
if not any(isinstance(obj, t) for t in allowed_types):
|
539
|
-
raise ValueError(
|
540
|
-
"Found object of type {}, expected one of {}.".format(
|
541
|
-
obj.__class__, allowed_types
|
542
|
-
)
|
543
|
-
)
|
544
|
-
|
545
|
-
obj_id = id(obj)
|
546
|
-
if obj_id in self._added_objs:
|
547
|
-
return self._added_objs[obj_id].entry
|
548
|
-
|
549
|
-
# If the object is coming from another artifact, save it as a reference
|
550
|
-
ref_path = obj._get_artifact_entry_ref_url()
|
551
|
-
if ref_path is not None:
|
552
|
-
return self.add_reference(ref_path, type(obj).with_suffix(name))[0]
|
553
|
-
|
554
|
-
val = obj.to_json(self)
|
555
|
-
name = obj.with_suffix(name)
|
556
|
-
entry = self._manifest.get_entry_by_path(name)
|
557
|
-
if entry is not None:
|
558
|
-
return entry
|
559
|
-
|
560
|
-
def do_write(f: IO) -> None:
|
561
|
-
import json
|
562
|
-
|
563
|
-
# TODO: Do we need to open with utf-8 codec?
|
564
|
-
f.write(json.dumps(val, sort_keys=True))
|
565
|
-
|
566
|
-
if is_tmp_name:
|
567
|
-
file_path = os.path.join(ARTIFACT_TMP.name, str(id(self)), name)
|
568
|
-
folder_path, _ = os.path.split(file_path)
|
569
|
-
if not os.path.exists(folder_path):
|
570
|
-
os.makedirs(folder_path)
|
571
|
-
with open(file_path, "w") as tmp_f:
|
572
|
-
do_write(tmp_f)
|
573
|
-
else:
|
574
|
-
with self.new_file(name) as f:
|
575
|
-
file_path = f.name
|
576
|
-
do_write(f)
|
577
|
-
|
578
|
-
# Note, we add the file from our temp directory.
|
579
|
-
# It will be added again later on finalize, but succeed since
|
580
|
-
# the checksum should match
|
581
|
-
entry = self.add_file(file_path, name, is_tmp_name)
|
582
|
-
self._added_objs[obj_id] = _AddedObj(entry, obj)
|
583
|
-
if obj._artifact_target is None:
|
584
|
-
obj._set_artifact_target(self, entry.path)
|
585
|
-
|
586
|
-
if is_tmp_name:
|
587
|
-
if os.path.exists(file_path):
|
588
|
-
os.remove(file_path)
|
589
|
-
|
590
|
-
return entry
|
591
|
-
|
592
|
-
def get_path(self, name: str) -> ArtifactManifestEntry:
|
593
|
-
if self._logged_artifact:
|
594
|
-
return self._logged_artifact.get_path(name)
|
595
|
-
|
596
|
-
raise ArtifactNotLoggedError(self, "get_path")
|
597
|
-
|
598
|
-
def get(self, name: str) -> data_types.WBValue:
|
599
|
-
if self._logged_artifact:
|
600
|
-
return self._logged_artifact.get(name)
|
601
|
-
|
602
|
-
raise ArtifactNotLoggedError(self, "get")
|
603
|
-
|
604
|
-
def download(
|
605
|
-
self, root: Optional[str] = None, recursive: bool = False
|
606
|
-
) -> FilePathStr:
|
607
|
-
if self._logged_artifact:
|
608
|
-
return self._logged_artifact.download(root=root, recursive=recursive)
|
609
|
-
|
610
|
-
raise ArtifactNotLoggedError(self, "download")
|
611
|
-
|
612
|
-
def checkout(self, root: Optional[str] = None) -> str:
|
613
|
-
if self._logged_artifact:
|
614
|
-
return self._logged_artifact.checkout(root=root)
|
615
|
-
|
616
|
-
raise ArtifactNotLoggedError(self, "checkout")
|
617
|
-
|
618
|
-
def verify(self, root: Optional[str] = None) -> bool:
|
619
|
-
if self._logged_artifact:
|
620
|
-
return self._logged_artifact.verify(root=root)
|
621
|
-
|
622
|
-
raise ArtifactNotLoggedError(self, "verify")
|
623
|
-
|
624
|
-
def save(
|
625
|
-
self,
|
626
|
-
project: Optional[str] = None,
|
627
|
-
settings: Optional["wandb.wandb_sdk.wandb_settings.Settings"] = None,
|
628
|
-
) -> None:
|
629
|
-
"""Persist any changes made to the artifact.
|
630
|
-
|
631
|
-
If currently in a run, that run will log this artifact. If not currently in a
|
632
|
-
run, a run of type "auto" will be created to track this artifact.
|
633
|
-
|
634
|
-
Arguments:
|
635
|
-
project: (str, optional) A project to use for the artifact in the case that
|
636
|
-
a run is not already in context settings: (wandb.Settings, optional) A
|
637
|
-
settings object to use when initializing an automatic run. Most commonly
|
638
|
-
used in testing harness.
|
639
|
-
|
640
|
-
Returns:
|
641
|
-
None
|
642
|
-
"""
|
643
|
-
if self._incremental:
|
644
|
-
with wandb_lib.telemetry.context() as tel:
|
645
|
-
tel.feature.artifact_incremental = True
|
646
|
-
|
647
|
-
if self._logged_artifact:
|
648
|
-
return self._logged_artifact.save()
|
649
|
-
else:
|
650
|
-
if wandb.run is None:
|
651
|
-
if settings is None:
|
652
|
-
settings = wandb.Settings(silent="true")
|
653
|
-
with wandb.init(
|
654
|
-
project=project, job_type="auto", settings=settings
|
655
|
-
) as run:
|
656
|
-
# redoing this here because in this branch we know we didn't
|
657
|
-
# have the run at the beginning of the method
|
658
|
-
if self._incremental:
|
659
|
-
with wandb_lib.telemetry.context(run=run) as tel:
|
660
|
-
tel.feature.artifact_incremental = True
|
661
|
-
run.log_artifact(self)
|
662
|
-
else:
|
663
|
-
wandb.run.log_artifact(self)
|
664
|
-
|
665
|
-
def delete(self) -> None:
|
666
|
-
if self._logged_artifact:
|
667
|
-
return self._logged_artifact.delete()
|
668
|
-
|
669
|
-
raise ArtifactNotLoggedError(self, "delete")
|
670
|
-
|
671
|
-
def wait(self, timeout: Optional[int] = None) -> ArtifactInterface:
|
672
|
-
"""Wait for an artifact to finish logging.
|
673
|
-
|
674
|
-
Arguments:
|
675
|
-
timeout: (int, optional) Wait up to this long.
|
676
|
-
"""
|
677
|
-
if self._logged_artifact:
|
678
|
-
return self._logged_artifact.wait(timeout) # type: ignore [call-arg]
|
679
|
-
|
680
|
-
raise ArtifactNotLoggedError(self, "wait")
|
681
|
-
|
682
|
-
def get_added_local_path_name(self, local_path: str) -> Optional[str]:
|
683
|
-
"""Get the artifact relative name of a file added by a local filesystem path.
|
684
|
-
|
685
|
-
Arguments:
|
686
|
-
local_path: (str) The local path to resolve into an artifact relative name.
|
687
|
-
|
688
|
-
Returns:
|
689
|
-
str: The artifact relative name.
|
690
|
-
|
691
|
-
Examples:
|
692
|
-
Basic usage
|
693
|
-
```
|
694
|
-
artifact = wandb.Artifact('my_dataset', type='dataset')
|
695
|
-
artifact.add_file('path/to/file.txt', name='artifact/path/file.txt')
|
696
|
-
|
697
|
-
# Returns `artifact/path/file.txt`:
|
698
|
-
name = artifact.get_added_local_path_name('path/to/file.txt')
|
699
|
-
```
|
700
|
-
"""
|
701
|
-
entry = self._added_local_paths.get(local_path, None)
|
702
|
-
if entry is None:
|
703
|
-
return None
|
704
|
-
return entry.path
|
705
|
-
|
706
|
-
def finalize(self) -> None:
|
707
|
-
"""Mark this artifact as final, disallowing further modifications.
|
708
|
-
|
709
|
-
This happens automatically when calling `log_artifact`.
|
710
|
-
|
711
|
-
Returns:
|
712
|
-
None
|
713
|
-
"""
|
714
|
-
if self._final:
|
715
|
-
return self._file_entries
|
716
|
-
|
717
|
-
# mark final after all files are added
|
718
|
-
self._final = True
|
719
|
-
self._digest = self._manifest.digest()
|
720
|
-
|
721
|
-
def json_encode(self) -> Dict[str, Any]:
|
722
|
-
if not self._logged_artifact:
|
723
|
-
raise ArtifactNotLoggedError(self, "json_encode")
|
724
|
-
return util.artifact_to_json(self)
|
725
|
-
|
726
|
-
def _ensure_can_add(self) -> None:
|
727
|
-
if self._final:
|
728
|
-
raise ArtifactFinalizedError(artifact=self)
|
729
|
-
|
730
|
-
def _add_local_file(
|
731
|
-
self, name: str, path: str, digest: Optional[B64MD5] = None
|
732
|
-
) -> ArtifactManifestEntry:
|
733
|
-
with tempfile.NamedTemporaryFile(dir=get_staging_dir(), delete=False) as f:
|
734
|
-
staging_path = f.name
|
735
|
-
shutil.copyfile(path, staging_path)
|
736
|
-
os.chmod(staging_path, 0o400)
|
737
|
-
|
738
|
-
entry = ArtifactManifestEntry(
|
739
|
-
path=util.to_forward_slash_path(name),
|
740
|
-
digest=digest or md5_file_b64(staging_path),
|
741
|
-
size=os.path.getsize(staging_path),
|
742
|
-
local_path=staging_path,
|
743
|
-
)
|
744
|
-
|
745
|
-
self._manifest.add_entry(entry)
|
746
|
-
self._added_local_paths[path] = entry
|
747
|
-
return entry
|
748
|
-
|
749
|
-
|
750
|
-
class ArtifactManifestV1(ArtifactManifest):
|
751
|
-
@classmethod
|
752
|
-
def version(cls) -> int:
|
753
|
-
return 1
|
754
|
-
|
755
|
-
@classmethod
|
756
|
-
def from_manifest_json(cls, manifest_json: Dict) -> "ArtifactManifestV1":
|
757
|
-
if manifest_json["version"] != cls.version():
|
758
|
-
raise ValueError(
|
759
|
-
"Expected manifest version 1, got %s" % manifest_json["version"]
|
760
|
-
)
|
761
|
-
|
762
|
-
storage_policy_name = manifest_json["storagePolicy"]
|
763
|
-
storage_policy_config = manifest_json.get("storagePolicyConfig", {})
|
764
|
-
storage_policy_cls = StoragePolicy.lookup_by_name(storage_policy_name)
|
765
|
-
if storage_policy_cls is None:
|
766
|
-
raise ValueError('Failed to find storage policy "%s"' % storage_policy_name)
|
767
|
-
if not issubclass(storage_policy_cls, WandbStoragePolicy):
|
768
|
-
raise ValueError(
|
769
|
-
"No handler found for storage handler of type '%s'"
|
770
|
-
% storage_policy_name
|
771
|
-
)
|
772
|
-
|
773
|
-
entries: Mapping[str, ArtifactManifestEntry]
|
774
|
-
entries = {
|
775
|
-
name: ArtifactManifestEntry(
|
776
|
-
path=LogicalFilePathStr(name),
|
777
|
-
digest=val["digest"],
|
778
|
-
birth_artifact_id=val.get("birthArtifactID"),
|
779
|
-
ref=val.get("ref"),
|
780
|
-
size=val.get("size"),
|
781
|
-
extra=val.get("extra"),
|
782
|
-
local_path=val.get("local_path"),
|
783
|
-
)
|
784
|
-
for name, val in manifest_json["contents"].items()
|
785
|
-
}
|
786
|
-
|
787
|
-
return cls(storage_policy_cls.from_config(storage_policy_config), entries)
|
788
|
-
|
789
|
-
def __init__(
|
790
|
-
self,
|
791
|
-
storage_policy: "WandbStoragePolicy",
|
792
|
-
entries: Optional[Mapping[str, ArtifactManifestEntry]] = None,
|
793
|
-
) -> None:
|
794
|
-
super().__init__(storage_policy, entries=entries)
|
795
|
-
|
796
|
-
def to_manifest_json(self) -> Dict:
|
797
|
-
"""This is the JSON that's stored in wandb_manifest.json.
|
798
|
-
|
799
|
-
If include_local is True we also include the local paths to files. This is
|
800
|
-
used to represent an artifact that's waiting to be saved on the current
|
801
|
-
system. We don't need to include the local paths in the artifact manifest
|
802
|
-
contents.
|
803
|
-
"""
|
804
|
-
contents = {}
|
805
|
-
for entry in sorted(self.entries.values(), key=lambda k: k.path):
|
806
|
-
json_entry: Dict[str, Any] = {
|
807
|
-
"digest": entry.digest,
|
808
|
-
}
|
809
|
-
if entry.birth_artifact_id:
|
810
|
-
json_entry["birthArtifactID"] = entry.birth_artifact_id
|
811
|
-
if entry.ref:
|
812
|
-
json_entry["ref"] = entry.ref
|
813
|
-
if entry.extra:
|
814
|
-
json_entry["extra"] = entry.extra
|
815
|
-
if entry.size is not None:
|
816
|
-
json_entry["size"] = entry.size
|
817
|
-
contents[entry.path] = json_entry
|
818
|
-
return {
|
819
|
-
"version": self.__class__.version(),
|
820
|
-
"storagePolicy": self.storage_policy.name(),
|
821
|
-
"storagePolicyConfig": self.storage_policy.config() or {},
|
822
|
-
"contents": contents,
|
823
|
-
}
|
824
|
-
|
825
|
-
def digest(self) -> HexMD5:
|
826
|
-
hasher = _md5()
|
827
|
-
hasher.update(b"wandb-artifact-manifest-v1\n")
|
828
|
-
for name, entry in sorted(self.entries.items(), key=lambda kv: kv[0]):
|
829
|
-
hasher.update(f"{name}:{entry.digest}\n".encode())
|
830
|
-
return HexMD5(hasher.hexdigest())
|
831
|
-
|
832
|
-
|
833
|
-
class WandbStoragePolicy(StoragePolicy):
|
834
|
-
@classmethod
|
835
|
-
def name(cls) -> str:
|
836
|
-
return "wandb-storage-policy-v1"
|
837
|
-
|
838
|
-
@classmethod
|
839
|
-
def from_config(cls, config: Dict) -> "WandbStoragePolicy":
|
840
|
-
return cls(config=config)
|
841
|
-
|
842
|
-
def __init__(
|
843
|
-
self,
|
844
|
-
config: Optional[Dict] = None,
|
845
|
-
cache: Optional[ArtifactsCache] = None,
|
846
|
-
api: Optional[InternalApi] = None,
|
847
|
-
) -> None:
|
848
|
-
self._cache = cache or get_artifacts_cache()
|
849
|
-
self._config = config or {}
|
850
|
-
self._session = requests.Session()
|
851
|
-
adapter = requests.adapters.HTTPAdapter(
|
852
|
-
max_retries=_REQUEST_RETRY_STRATEGY,
|
853
|
-
pool_connections=_REQUEST_POOL_CONNECTIONS,
|
854
|
-
pool_maxsize=_REQUEST_POOL_MAXSIZE,
|
855
|
-
)
|
856
|
-
self._session.mount("http://", adapter)
|
857
|
-
self._session.mount("https://", adapter)
|
858
|
-
|
859
|
-
s3 = S3Handler()
|
860
|
-
gcs = GCSHandler()
|
861
|
-
azure = AzureHandler()
|
862
|
-
http = HTTPHandler(self._session)
|
863
|
-
https = HTTPHandler(self._session, scheme="https")
|
864
|
-
artifact = WBArtifactHandler()
|
865
|
-
local_artifact = WBLocalArtifactHandler()
|
866
|
-
file_handler = LocalFileHandler()
|
867
|
-
|
868
|
-
self._api = api or InternalApi()
|
869
|
-
self._handler = MultiHandler(
|
870
|
-
handlers=[
|
871
|
-
s3,
|
872
|
-
gcs,
|
873
|
-
azure,
|
874
|
-
http,
|
875
|
-
https,
|
876
|
-
artifact,
|
877
|
-
local_artifact,
|
878
|
-
file_handler,
|
879
|
-
],
|
880
|
-
default_handler=TrackingHandler(),
|
881
|
-
)
|
882
|
-
|
883
|
-
def config(self) -> Dict:
|
884
|
-
return self._config
|
885
|
-
|
886
|
-
def load_file(
|
887
|
-
self,
|
888
|
-
artifact: ArtifactInterface,
|
889
|
-
manifest_entry: ArtifactManifestEntry,
|
890
|
-
) -> str:
|
891
|
-
path, hit, cache_open = self._cache.check_md5_obj_path(
|
892
|
-
B64MD5(manifest_entry.digest), # TODO(spencerpearson): unsafe cast
|
893
|
-
manifest_entry.size if manifest_entry.size is not None else 0,
|
894
|
-
)
|
895
|
-
if hit:
|
896
|
-
return path
|
897
|
-
|
898
|
-
response = self._session.get(
|
899
|
-
self._file_url(self._api, artifact.entity, manifest_entry),
|
900
|
-
auth=("api", self._api.api_key),
|
901
|
-
stream=True,
|
902
|
-
)
|
903
|
-
response.raise_for_status()
|
904
|
-
|
905
|
-
with cache_open(mode="wb") as file:
|
906
|
-
for data in response.iter_content(chunk_size=16 * 1024):
|
907
|
-
file.write(data)
|
908
|
-
return path
|
909
|
-
|
910
|
-
def store_reference(
|
911
|
-
self,
|
912
|
-
artifact: ArtifactInterface,
|
913
|
-
path: Union[URIStr, FilePathStr],
|
914
|
-
name: Optional[str] = None,
|
915
|
-
checksum: bool = True,
|
916
|
-
max_objects: Optional[int] = None,
|
917
|
-
) -> Sequence[ArtifactManifestEntry]:
|
918
|
-
return self._handler.store_path(
|
919
|
-
artifact, path, name=name, checksum=checksum, max_objects=max_objects
|
920
|
-
)
|
921
|
-
|
922
|
-
def load_reference(
|
923
|
-
self,
|
924
|
-
manifest_entry: ArtifactManifestEntry,
|
925
|
-
local: bool = False,
|
926
|
-
) -> str:
|
927
|
-
return self._handler.load_path(manifest_entry, local)
|
928
|
-
|
929
|
-
def _file_url(
|
930
|
-
self, api: InternalApi, entity_name: str, manifest_entry: ArtifactManifestEntry
|
931
|
-
) -> str:
|
932
|
-
storage_layout = self._config.get("storageLayout", StorageLayout.V1)
|
933
|
-
storage_region = self._config.get("storageRegion", "default")
|
934
|
-
md5_hex = b64_to_hex_id(B64MD5(manifest_entry.digest))
|
935
|
-
|
936
|
-
if storage_layout == StorageLayout.V1:
|
937
|
-
return "{}/artifacts/{}/{}".format(
|
938
|
-
api.settings("base_url"), entity_name, md5_hex
|
939
|
-
)
|
940
|
-
elif storage_layout == StorageLayout.V2:
|
941
|
-
return "{}/artifactsV2/{}/{}/{}/{}".format(
|
942
|
-
api.settings("base_url"),
|
943
|
-
storage_region,
|
944
|
-
entity_name,
|
945
|
-
quote(
|
946
|
-
manifest_entry.birth_artifact_id
|
947
|
-
if manifest_entry.birth_artifact_id is not None
|
948
|
-
else ""
|
949
|
-
),
|
950
|
-
md5_hex,
|
951
|
-
)
|
952
|
-
else:
|
953
|
-
raise Exception(f"unrecognized storage layout: {storage_layout}")
|
954
|
-
|
955
|
-
def store_file_sync(
|
956
|
-
self,
|
957
|
-
artifact_id: str,
|
958
|
-
artifact_manifest_id: str,
|
959
|
-
entry: ArtifactManifestEntry,
|
960
|
-
preparer: "StepPrepare",
|
961
|
-
progress_callback: Optional["progress.ProgressFn"] = None,
|
962
|
-
) -> bool:
|
963
|
-
"""Upload a file to the artifact store.
|
964
|
-
|
965
|
-
Returns:
|
966
|
-
True if the file was a duplicate (did not need to be uploaded),
|
967
|
-
False if it needed to be uploaded or was a reference (nothing to dedupe).
|
968
|
-
"""
|
969
|
-
resp = preparer.prepare_sync(
|
970
|
-
{
|
971
|
-
"artifactID": artifact_id,
|
972
|
-
"artifactManifestID": artifact_manifest_id,
|
973
|
-
"name": entry.path,
|
974
|
-
"md5": entry.digest,
|
975
|
-
}
|
976
|
-
).get()
|
977
|
-
|
978
|
-
entry.birth_artifact_id = resp.birth_artifact_id
|
979
|
-
if resp.upload_url is None:
|
980
|
-
return True
|
981
|
-
if entry.local_path is None:
|
982
|
-
return False
|
983
|
-
|
984
|
-
with open(entry.local_path, "rb") as file:
|
985
|
-
# This fails if we don't send the first byte before the signed URL expires.
|
986
|
-
self._api.upload_file_retry(
|
987
|
-
resp.upload_url,
|
988
|
-
file,
|
989
|
-
progress_callback,
|
990
|
-
extra_headers={
|
991
|
-
header.split(":", 1)[0]: header.split(":", 1)[1]
|
992
|
-
for header in (resp.upload_headers or {})
|
993
|
-
},
|
994
|
-
)
|
995
|
-
self._write_cache(entry)
|
996
|
-
|
997
|
-
return False
|
998
|
-
|
999
|
-
async def store_file_async(
|
1000
|
-
self,
|
1001
|
-
artifact_id: str,
|
1002
|
-
artifact_manifest_id: str,
|
1003
|
-
entry: ArtifactManifestEntry,
|
1004
|
-
preparer: "StepPrepare",
|
1005
|
-
progress_callback: Optional["progress.ProgressFn"] = None,
|
1006
|
-
) -> bool:
|
1007
|
-
"""Async equivalent to `store_file_sync`."""
|
1008
|
-
resp = await preparer.prepare_async(
|
1009
|
-
{
|
1010
|
-
"artifactID": artifact_id,
|
1011
|
-
"artifactManifestID": artifact_manifest_id,
|
1012
|
-
"name": entry.path,
|
1013
|
-
"md5": entry.digest,
|
1014
|
-
}
|
1015
|
-
)
|
1016
|
-
|
1017
|
-
entry.birth_artifact_id = resp.birth_artifact_id
|
1018
|
-
if resp.upload_url is None:
|
1019
|
-
return True
|
1020
|
-
if entry.local_path is None:
|
1021
|
-
return False
|
1022
|
-
|
1023
|
-
with open(entry.local_path, "rb") as file:
|
1024
|
-
# This fails if we don't send the first byte before the signed URL expires.
|
1025
|
-
await self._api.upload_file_retry_async(
|
1026
|
-
resp.upload_url,
|
1027
|
-
file,
|
1028
|
-
progress_callback,
|
1029
|
-
extra_headers={
|
1030
|
-
header.split(":", 1)[0]: header.split(":", 1)[1]
|
1031
|
-
for header in (resp.upload_headers or {})
|
1032
|
-
},
|
1033
|
-
)
|
1034
|
-
|
1035
|
-
self._write_cache(entry)
|
1036
|
-
|
1037
|
-
return False
|
1038
|
-
|
1039
|
-
def _write_cache(self, entry: ArtifactManifestEntry) -> None:
|
1040
|
-
if entry.local_path is None:
|
1041
|
-
return
|
1042
|
-
|
1043
|
-
# Cache upon successful upload.
|
1044
|
-
_, hit, cache_open = self._cache.check_md5_obj_path(
|
1045
|
-
B64MD5(entry.digest),
|
1046
|
-
entry.size if entry.size is not None else 0,
|
1047
|
-
)
|
1048
|
-
if not hit:
|
1049
|
-
with cache_open() as f:
|
1050
|
-
shutil.copyfile(entry.local_path, f.name)
|
1051
|
-
|
1052
|
-
|
1053
|
-
# Don't use this yet!
|
1054
|
-
class __S3BucketPolicy(StoragePolicy): # noqa: N801
|
1055
|
-
@classmethod
|
1056
|
-
def name(cls) -> str:
|
1057
|
-
return "wandb-s3-bucket-policy-v1"
|
1058
|
-
|
1059
|
-
@classmethod
|
1060
|
-
def from_config(cls, config: Dict[str, str]) -> "__S3BucketPolicy":
|
1061
|
-
if "bucket" not in config:
|
1062
|
-
raise ValueError("Bucket name not found in config")
|
1063
|
-
return cls(config["bucket"])
|
1064
|
-
|
1065
|
-
def __init__(self, bucket: str) -> None:
|
1066
|
-
self._bucket = bucket
|
1067
|
-
s3 = S3Handler(bucket)
|
1068
|
-
local = LocalFileHandler()
|
1069
|
-
|
1070
|
-
self._handler = MultiHandler(
|
1071
|
-
handlers=[
|
1072
|
-
s3,
|
1073
|
-
local,
|
1074
|
-
],
|
1075
|
-
default_handler=TrackingHandler(),
|
1076
|
-
)
|
1077
|
-
|
1078
|
-
def config(self) -> Dict[str, str]:
|
1079
|
-
return {"bucket": self._bucket}
|
1080
|
-
|
1081
|
-
def load_path(
|
1082
|
-
self,
|
1083
|
-
manifest_entry: ArtifactManifestEntry,
|
1084
|
-
local: bool = False,
|
1085
|
-
) -> Union[URIStr, FilePathStr]:
|
1086
|
-
return self._handler.load_path(manifest_entry, local=local)
|
1087
|
-
|
1088
|
-
def store_path(
|
1089
|
-
self,
|
1090
|
-
artifact: ArtifactInterface,
|
1091
|
-
path: Union[URIStr, FilePathStr],
|
1092
|
-
name: Optional[str] = None,
|
1093
|
-
checksum: bool = True,
|
1094
|
-
max_objects: Optional[int] = None,
|
1095
|
-
) -> Sequence[ArtifactManifestEntry]:
|
1096
|
-
return self._handler.store_path(
|
1097
|
-
artifact, path, name=name, checksum=checksum, max_objects=max_objects
|
1098
|
-
)
|
1099
|
-
|
1100
|
-
|
1101
|
-
class MultiHandler(StorageHandler):
|
1102
|
-
_handlers: List[StorageHandler]
|
1103
|
-
|
1104
|
-
def __init__(
|
1105
|
-
self,
|
1106
|
-
handlers: Optional[List[StorageHandler]] = None,
|
1107
|
-
default_handler: Optional[StorageHandler] = None,
|
1108
|
-
) -> None:
|
1109
|
-
self._handlers = handlers or []
|
1110
|
-
self._default_handler = default_handler
|
1111
|
-
|
1112
|
-
def _get_handler(self, url: Union[FilePathStr, URIStr]) -> StorageHandler:
|
1113
|
-
parsed_url = urlparse(url)
|
1114
|
-
for handler in self._handlers:
|
1115
|
-
if handler.can_handle(parsed_url):
|
1116
|
-
return handler
|
1117
|
-
if self._default_handler is not None:
|
1118
|
-
return self._default_handler
|
1119
|
-
raise ValueError('No storage handler registered for url "%s"' % str(url))
|
1120
|
-
|
1121
|
-
def load_path(
|
1122
|
-
self,
|
1123
|
-
manifest_entry: ArtifactManifestEntry,
|
1124
|
-
local: bool = False,
|
1125
|
-
) -> Union[URIStr, FilePathStr]:
|
1126
|
-
assert manifest_entry.ref is not None
|
1127
|
-
handler = self._get_handler(manifest_entry.ref)
|
1128
|
-
return handler.load_path(manifest_entry, local=local)
|
1129
|
-
|
1130
|
-
def store_path(
|
1131
|
-
self,
|
1132
|
-
artifact: ArtifactInterface,
|
1133
|
-
path: Union[URIStr, FilePathStr],
|
1134
|
-
name: Optional[str] = None,
|
1135
|
-
checksum: bool = True,
|
1136
|
-
max_objects: Optional[int] = None,
|
1137
|
-
) -> Sequence[ArtifactManifestEntry]:
|
1138
|
-
handler = self._get_handler(path)
|
1139
|
-
return handler.store_path(
|
1140
|
-
artifact, path, name=name, checksum=checksum, max_objects=max_objects
|
1141
|
-
)
|
1142
|
-
|
1143
|
-
|
1144
|
-
class TrackingHandler(StorageHandler):
|
1145
|
-
def __init__(self, scheme: Optional[str] = None) -> None:
|
1146
|
-
"""Track paths with no modification or special processing.
|
1147
|
-
|
1148
|
-
Useful when paths being tracked are on file systems mounted at a standardized
|
1149
|
-
location.
|
1150
|
-
|
1151
|
-
For example, if the data to track is located on an NFS share mounted on
|
1152
|
-
`/data`, then it is sufficient to just track the paths.
|
1153
|
-
"""
|
1154
|
-
self._scheme = scheme or ""
|
1155
|
-
|
1156
|
-
def can_handle(self, parsed_url: "ParseResult") -> bool:
|
1157
|
-
return parsed_url.scheme == self._scheme
|
1158
|
-
|
1159
|
-
def load_path(
|
1160
|
-
self,
|
1161
|
-
manifest_entry: ArtifactManifestEntry,
|
1162
|
-
local: bool = False,
|
1163
|
-
) -> Union[URIStr, FilePathStr]:
|
1164
|
-
if local:
|
1165
|
-
# Likely a user error. The tracking handler is
|
1166
|
-
# oblivious to the underlying paths, so it has
|
1167
|
-
# no way of actually loading it.
|
1168
|
-
url = urlparse(manifest_entry.ref)
|
1169
|
-
raise ValueError(
|
1170
|
-
f"Cannot download file at path {str(manifest_entry.ref)}, scheme {str(url.scheme)} not recognized"
|
1171
|
-
)
|
1172
|
-
# TODO(spencerpearson): should this go through util.to_native_slash_path
|
1173
|
-
# instead of just getting typecast?
|
1174
|
-
return FilePathStr(manifest_entry.path)
|
1175
|
-
|
1176
|
-
def store_path(
|
1177
|
-
self,
|
1178
|
-
artifact: ArtifactInterface,
|
1179
|
-
path: Union[URIStr, FilePathStr],
|
1180
|
-
name: Optional[str] = None,
|
1181
|
-
checksum: bool = True,
|
1182
|
-
max_objects: Optional[int] = None,
|
1183
|
-
) -> Sequence[ArtifactManifestEntry]:
|
1184
|
-
url = urlparse(path)
|
1185
|
-
if name is None:
|
1186
|
-
raise ValueError(
|
1187
|
-
'You must pass name="<entry_name>" when tracking references with unknown schemes. ref: %s'
|
1188
|
-
% path
|
1189
|
-
)
|
1190
|
-
termwarn(
|
1191
|
-
"Artifact references with unsupported schemes cannot be checksummed: %s"
|
1192
|
-
% path
|
1193
|
-
)
|
1194
|
-
name = LogicalFilePathStr(name or url.path[1:]) # strip leading slash
|
1195
|
-
return [ArtifactManifestEntry(path=name, ref=path, digest=path)]
|
1196
|
-
|
1197
|
-
|
1198
|
-
DEFAULT_MAX_OBJECTS = 10000
|
1199
|
-
|
1200
|
-
|
1201
|
-
class LocalFileHandler(StorageHandler):
|
1202
|
-
"""Handles file:// references."""
|
1203
|
-
|
1204
|
-
def __init__(self, scheme: Optional[str] = None) -> None:
|
1205
|
-
"""Track files or directories on a local filesystem.
|
1206
|
-
|
1207
|
-
Expand directories to create an entry for each file contained.
|
1208
|
-
"""
|
1209
|
-
self._scheme = scheme or "file"
|
1210
|
-
self._cache = get_artifacts_cache()
|
1211
|
-
|
1212
|
-
def can_handle(self, parsed_url: "ParseResult") -> bool:
|
1213
|
-
return parsed_url.scheme == self._scheme
|
1214
|
-
|
1215
|
-
def load_path(
|
1216
|
-
self,
|
1217
|
-
manifest_entry: ArtifactManifestEntry,
|
1218
|
-
local: bool = False,
|
1219
|
-
) -> Union[URIStr, FilePathStr]:
|
1220
|
-
if manifest_entry.ref is None:
|
1221
|
-
raise ValueError(f"Cannot add path with no ref: {manifest_entry.path}")
|
1222
|
-
local_path = util.local_file_uri_to_path(str(manifest_entry.ref))
|
1223
|
-
if not os.path.exists(local_path):
|
1224
|
-
raise ValueError(
|
1225
|
-
"Local file reference: Failed to find file at path %s" % local_path
|
1226
|
-
)
|
1227
|
-
|
1228
|
-
path, hit, cache_open = self._cache.check_md5_obj_path(
|
1229
|
-
B64MD5(manifest_entry.digest), # TODO(spencerpearson): unsafe cast
|
1230
|
-
manifest_entry.size if manifest_entry.size is not None else 0,
|
1231
|
-
)
|
1232
|
-
if hit:
|
1233
|
-
return path
|
1234
|
-
|
1235
|
-
md5 = md5_file_b64(local_path)
|
1236
|
-
if md5 != manifest_entry.digest:
|
1237
|
-
raise ValueError(
|
1238
|
-
f"Local file reference: Digest mismatch for path {local_path}: expected {manifest_entry.digest} but found {md5}"
|
1239
|
-
)
|
1240
|
-
|
1241
|
-
filesystem.mkdir_exists_ok(os.path.dirname(path))
|
1242
|
-
|
1243
|
-
with cache_open() as f:
|
1244
|
-
shutil.copy(local_path, f.name)
|
1245
|
-
return path
|
1246
|
-
|
1247
|
-
def store_path(
|
1248
|
-
self,
|
1249
|
-
artifact: ArtifactInterface,
|
1250
|
-
path: Union[URIStr, FilePathStr],
|
1251
|
-
name: Optional[str] = None,
|
1252
|
-
checksum: bool = True,
|
1253
|
-
max_objects: Optional[int] = None,
|
1254
|
-
) -> Sequence[ArtifactManifestEntry]:
|
1255
|
-
local_path = util.local_file_uri_to_path(path)
|
1256
|
-
max_objects = max_objects or DEFAULT_MAX_OBJECTS
|
1257
|
-
# We have a single file or directory
|
1258
|
-
# Note, we follow symlinks for files contained within the directory
|
1259
|
-
entries = []
|
1260
|
-
|
1261
|
-
def md5(path: str) -> B64MD5:
|
1262
|
-
return (
|
1263
|
-
md5_file_b64(path)
|
1264
|
-
if checksum
|
1265
|
-
else md5_string(str(os.stat(path).st_size))
|
1266
|
-
)
|
1267
|
-
|
1268
|
-
if os.path.isdir(local_path):
|
1269
|
-
i = 0
|
1270
|
-
start_time = time.time()
|
1271
|
-
if checksum:
|
1272
|
-
termlog(
|
1273
|
-
'Generating checksum for up to %i files in "%s"...\n'
|
1274
|
-
% (max_objects, local_path),
|
1275
|
-
newline=False,
|
1276
|
-
)
|
1277
|
-
for root, _, files in os.walk(local_path):
|
1278
|
-
for sub_path in files:
|
1279
|
-
i += 1
|
1280
|
-
if i > max_objects:
|
1281
|
-
raise ValueError(
|
1282
|
-
"Exceeded %i objects tracked, pass max_objects to add_reference"
|
1283
|
-
% max_objects
|
1284
|
-
)
|
1285
|
-
physical_path = os.path.join(root, sub_path)
|
1286
|
-
# TODO(spencerpearson): this is not a "logical path" in the sense that
|
1287
|
-
# `util.to_forward_slash_path` returns a "logical path"; it's a relative path
|
1288
|
-
# **on the local filesystem**.
|
1289
|
-
logical_path = os.path.relpath(physical_path, start=local_path)
|
1290
|
-
if name is not None:
|
1291
|
-
logical_path = os.path.join(name, logical_path)
|
1292
|
-
|
1293
|
-
entry = ArtifactManifestEntry(
|
1294
|
-
path=LogicalFilePathStr(logical_path),
|
1295
|
-
ref=FilePathStr(os.path.join(path, logical_path)),
|
1296
|
-
size=os.path.getsize(physical_path),
|
1297
|
-
digest=md5(physical_path),
|
1298
|
-
)
|
1299
|
-
entries.append(entry)
|
1300
|
-
if checksum:
|
1301
|
-
termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
|
1302
|
-
elif os.path.isfile(local_path):
|
1303
|
-
name = name or os.path.basename(local_path)
|
1304
|
-
entry = ArtifactManifestEntry(
|
1305
|
-
path=LogicalFilePathStr(name),
|
1306
|
-
ref=path,
|
1307
|
-
size=os.path.getsize(local_path),
|
1308
|
-
digest=md5(local_path),
|
1309
|
-
)
|
1310
|
-
entries.append(entry)
|
1311
|
-
else:
|
1312
|
-
# TODO: update error message if we don't allow directories.
|
1313
|
-
raise ValueError('Path "%s" must be a valid file or directory path' % path)
|
1314
|
-
return entries
|
1315
|
-
|
1316
|
-
|
-class S3Handler(StorageHandler):
-    _s3: Optional["boto3.resources.base.ServiceResource"]
-    _scheme: str
-    _versioning_enabled: Optional[bool]
-
-    def __init__(self, scheme: Optional[str] = None) -> None:
-        self._scheme = scheme or "s3"
-        self._s3 = None
-        self._versioning_enabled = None
-        self._cache = get_artifacts_cache()
-
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == self._scheme
-
-    def init_boto(self) -> "boto3.resources.base.ServiceResource":
-        if self._s3 is not None:
-            return self._s3
-        boto: "boto3" = util.get_module(
-            "boto3",
-            required="s3:// references requires the boto3 library, run pip install wandb[aws]",
-            lazy=False,
-        )
-        self._s3 = boto.session.Session().resource(
-            "s3",
-            endpoint_url=os.getenv("AWS_S3_ENDPOINT_URL"),
-            region_name=os.getenv("AWS_REGION"),
-        )
-        self._botocore = util.get_module("botocore")
-        return self._s3
-
-    def _parse_uri(self, uri: str) -> Tuple[str, str, Optional[str]]:
-        url = urlparse(uri)
-        query = dict(parse_qsl(url.query))
-
-        bucket = url.netloc
-        key = url.path[1:]  # strip leading slash
-        version = query.get("versionId")
-
-        return bucket, key, version
-
-    def versioning_enabled(self, bucket: str) -> bool:
-        self.init_boto()
-        assert self._s3 is not None  # mypy: unwraps optionality
-        if self._versioning_enabled is not None:
-            return self._versioning_enabled
-        res = self._s3.BucketVersioning(bucket)
-        self._versioning_enabled = res.status == "Enabled"
-        return self._versioning_enabled
-
-    def load_path(
-        self,
-        manifest_entry: ArtifactManifestEntry,
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        if not local:
-            assert manifest_entry.ref is not None
-            return manifest_entry.ref
-
-        assert manifest_entry.ref is not None
-
-        path, hit, cache_open = self._cache.check_etag_obj_path(
-            URIStr(manifest_entry.ref),
-            ETag(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
-            manifest_entry.size if manifest_entry.size is not None else 0,
-        )
-        if hit:
-            return path
-
-        self.init_boto()
-        assert self._s3 is not None  # mypy: unwraps optionality
-        bucket, key, _ = self._parse_uri(manifest_entry.ref)
-        version = manifest_entry.extra.get("versionID")
-
-        extra_args = {}
-        if version is None:
-            # We don't have version information so just get the latest version
-            # and fallback to listing all versions if we don't have a match.
-            obj = self._s3.Object(bucket, key)
-            etag = self._etag_from_obj(obj)
-            if etag != manifest_entry.digest:
-                if self.versioning_enabled(bucket):
-                    # Fallback to listing versions
-                    obj = None
-                    object_versions = self._s3.Bucket(bucket).object_versions.filter(
-                        Prefix=key
-                    )
-                    for object_version in object_versions:
-                        if (
-                            manifest_entry.extra.get("etag")
-                            == object_version.e_tag[1:-1]
-                        ):
-                            obj = object_version.Object()
-                            extra_args["VersionId"] = object_version.version_id
-                            break
-                    if obj is None:
-                        raise ValueError(
-                            "Couldn't find object version for {}/{} matching etag {}".format(
-                                bucket, key, manifest_entry.extra.get("etag")
-                            )
-                        )
-                else:
-                    raise ValueError(
-                        f"Digest mismatch for object {manifest_entry.ref}: expected {manifest_entry.digest} but found {etag}"
-                    )
-        else:
-            obj = self._s3.ObjectVersion(bucket, key, version).Object()
-            extra_args["VersionId"] = version
-
-        with cache_open(mode="wb") as f:
-            obj.download_fileobj(f, ExtraArgs=extra_args)
-        return path
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence[ArtifactManifestEntry]:
-        self.init_boto()
-        assert self._s3 is not None  # mypy: unwraps optionality
-
-        # The passed in path might have query string parameters.
-        # We only need to care about a subset, like version, when
-        # parsing. Once we have that, we can store the rest of the
-        # metadata in the artifact entry itself.
-        bucket, key, version = self._parse_uri(path)
-        path = URIStr(f"{self._scheme}://{bucket}/{key}")
-        if not self.versioning_enabled(bucket) and version:
-            raise ValueError(
-                f"Specifying a versionId is not valid for s3://{bucket} as it does not have versioning enabled."
-            )
-
-        max_objects = max_objects or DEFAULT_MAX_OBJECTS
-        if not checksum:
-            return [
-                ArtifactManifestEntry(
-                    path=LogicalFilePathStr(name or key), ref=path, digest=path
-                )
-            ]
-
-        # If an explicit version is specified, use that. Otherwise, use the head version.
-        objs = (
-            [self._s3.ObjectVersion(bucket, key, version).Object()]
-            if version
-            else [self._s3.Object(bucket, key)]
-        )
-        start_time = None
-        multi = False
-        try:
-            objs[0].load()
-            # S3 doesn't have real folders, however there are cases where the folder key has a valid file which will not
-            # trigger a recursive upload.
-            # we should check the object's metadata says it is a directory and do a multi file upload if it is
-            if "x-directory" in objs[0].content_type:
-                multi = True
-        except self._botocore.exceptions.ClientError as e:
-            if e.response["Error"]["Code"] == "404":
-                multi = True
-            else:
-                raise CommError(
-                    "Unable to connect to S3 ({}): {}".format(
-                        e.response["Error"]["Code"], e.response["Error"]["Message"]
-                    )
-                )
-        if multi:
-            start_time = time.time()
-            termlog(
-                'Generating checksum for up to %i objects with prefix "%s"... '
-                % (max_objects, key),
-                newline=False,
-            )
-            objs = self._s3.Bucket(bucket).objects.filter(Prefix=key).limit(max_objects)
-        # Weird iterator scoping makes us assign this to a local function
-        size = self._size_from_obj
-        entries = [
-            self._entry_from_obj(obj, path, name, prefix=key, multi=multi)
-            for obj in objs
-            if size(obj) > 0
-        ]
-        if start_time is not None:
-            termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
-        if len(entries) > max_objects:
-            raise ValueError(
-                "Exceeded %i objects tracked, pass max_objects to add_reference"
-                % max_objects
-            )
-        return entries
-
-    def _size_from_obj(self, obj: "boto3.s3.Object") -> int:
-        # ObjectSummary has size, Object has content_length
-        size: int
-        if hasattr(obj, "size"):
-            size = obj.size
-        else:
-            size = obj.content_length
-        return size
-
-    def _entry_from_obj(
-        self,
-        obj: "boto3.s3.Object",
-        path: str,
-        name: Optional[str] = None,
-        prefix: str = "",
-        multi: bool = False,
-    ) -> ArtifactManifestEntry:
-        """Create an ArtifactManifestEntry from an S3 object.
-
-        Arguments:
-            obj: The S3 object
-            path: The S3-style path (e.g.: "s3://bucket/file.txt")
-            name: The user assigned name, or None if not specified
-            prefix: The prefix to add (will be the same as `path` for directories)
-            multi: Whether or not this is a multi-object add.
-        """
-        bucket, key, _ = self._parse_uri(path)
-
-        # Always use posix paths, since that's what S3 uses.
-        posix_key = pathlib.PurePosixPath(obj.key)  # the bucket key
-        posix_path = pathlib.PurePosixPath(bucket) / pathlib.PurePosixPath(
-            key
-        )  # the path, with the scheme stripped
-        posix_prefix = pathlib.PurePosixPath(prefix)  # the prefix, if adding a prefix
-        posix_name = pathlib.PurePosixPath(name or "")
-        posix_ref = posix_path
-
-        if name is None:
-            # We're adding a directory (prefix), so calculate a relative path.
-            if str(posix_prefix) in str(posix_key) and posix_prefix != posix_key:
-                posix_name = posix_key.relative_to(posix_prefix)
-                posix_ref = posix_path / posix_name
-            else:
-                posix_name = pathlib.PurePosixPath(posix_key.name)
-                posix_ref = posix_path
-        elif multi:
-            # We're adding a directory with a name override.
-            relpath = posix_key.relative_to(posix_prefix)
-            posix_name = posix_name / relpath
-            posix_ref = posix_path / relpath
-        return ArtifactManifestEntry(
-            path=LogicalFilePathStr(str(posix_name)),
-            ref=URIStr(f"{self._scheme}://{str(posix_ref)}"),
-            digest=ETag(self._etag_from_obj(obj)),
-            size=self._size_from_obj(obj),
-            extra=self._extra_from_obj(obj),
-        )
-
-    @staticmethod
-    def _etag_from_obj(obj: "boto3.s3.Object") -> ETag:
-        etag: ETag
-        etag = obj.e_tag[1:-1]  # escape leading and trailing quote
-        return etag
-
-    @staticmethod
-    def _extra_from_obj(obj: "boto3.s3.Object") -> Dict[str, str]:
-        extra = {
-            "etag": obj.e_tag[1:-1],  # escape leading and trailing quote
-        }
-        # ObjectSummary will never have version_id
-        if hasattr(obj, "version_id") and obj.version_id != "null":
-            extra["versionID"] = obj.version_id
-        return extra
-
-    @staticmethod
-    def _content_addressed_path(md5: str) -> FilePathStr:
-        # TODO: is this the structure we want? not at all human
-        # readable, but that's probably OK. don't want people
-        # poking around in the bucket
-        return FilePathStr(
-            "wandb/%s" % base64.b64encode(md5.encode("ascii")).decode("ascii")
-        )
-
-
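
The `_parse_uri` logic in the removed S3Handler splits an s3:// reference into bucket, key, and an optional version taken from the `versionId` query parameter. A standalone sketch of just that parsing (the function name and example URI are illustrative):

    from urllib.parse import parse_qsl, urlparse


    def parse_s3_uri(uri: str):
        # Bucket comes from the netloc, the key from the path with its
        # leading slash stripped, the version from the query string.
        url = urlparse(uri)
        query = dict(parse_qsl(url.query))
        return url.netloc, url.path[1:], query.get("versionId")


    assert parse_s3_uri("s3://my-bucket/data/train.csv?versionId=abc123") == (
        "my-bucket",
        "data/train.csv",
        "abc123",
    )
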
-class GCSHandler(StorageHandler):
-    _client: Optional["gcs_module.client.Client"]
-    _versioning_enabled: Optional[bool]
-
-    def __init__(self, scheme: Optional[str] = None) -> None:
-        self._scheme = scheme or "gs"
-        self._client = None
-        self._versioning_enabled = None
-        self._cache = get_artifacts_cache()
-
-    def versioning_enabled(self, bucket_path: str) -> bool:
-        if self._versioning_enabled is not None:
-            return self._versioning_enabled
-        self.init_gcs()
-        assert self._client is not None  # mypy: unwraps optionality
-        bucket = self._client.bucket(bucket_path)
-        bucket.reload()
-        self._versioning_enabled = bucket.versioning_enabled
-        return self._versioning_enabled
-
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == self._scheme
-
-    def init_gcs(self) -> "gcs_module.client.Client":
-        if self._client is not None:
-            return self._client
-        storage = util.get_module(
-            "google.cloud.storage",
-            required="gs:// references requires the google-cloud-storage library, run pip install wandb[gcp]",
-        )
-        self._client = storage.Client()
-        return self._client
-
-    def _parse_uri(self, uri: str) -> Tuple[str, str, Optional[str]]:
-        url = urlparse(uri)
-        bucket = url.netloc
-        key = url.path[1:]
-        version = url.fragment if url.fragment else None
-        return bucket, key, version
-
-    def load_path(
-        self,
-        manifest_entry: ArtifactManifestEntry,
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        if not local:
-            assert manifest_entry.ref is not None
-            return manifest_entry.ref
-
-        path, hit, cache_open = self._cache.check_md5_obj_path(
-            B64MD5(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
-            manifest_entry.size if manifest_entry.size is not None else 0,
-        )
-        if hit:
-            return path
-
-        self.init_gcs()
-        assert self._client is not None  # mypy: unwraps optionality
-        assert manifest_entry.ref is not None
-        bucket, key, _ = self._parse_uri(manifest_entry.ref)
-        version = manifest_entry.extra.get("versionID")
-
-        obj = None
-        # First attempt to get the generation specified, this will return None if versioning is not enabled
-        if version is not None:
-            obj = self._client.bucket(bucket).get_blob(key, generation=version)
-
-        if obj is None:
-            # Object versioning is disabled on the bucket, so just get
-            # the latest version and make sure the MD5 matches.
-            obj = self._client.bucket(bucket).get_blob(key)
-            if obj is None:
-                raise ValueError(
-                    f"Unable to download object {manifest_entry.ref} with generation {version}"
-                )
-            md5 = obj.md5_hash
-            if md5 != manifest_entry.digest:
-                raise ValueError(
-                    f"Digest mismatch for object {manifest_entry.ref}: expected {manifest_entry.digest} but found {md5}"
-                )
-
-        with cache_open(mode="wb") as f:
-            obj.download_to_file(f)
-        return path
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence[ArtifactManifestEntry]:
-        self.init_gcs()
-        assert self._client is not None  # mypy: unwraps optionality
-
-        # After parsing any query params / fragments for additional context,
-        # such as version identifiers, pare down the path to just the bucket
-        # and key.
-        bucket, key, version = self._parse_uri(path)
-        path = URIStr(f"{self._scheme}://{bucket}/{key}")
-        max_objects = max_objects or DEFAULT_MAX_OBJECTS
-        if not self.versioning_enabled(bucket) and version:
-            raise ValueError(
-                f"Specifying a versionId is not valid for s3://{bucket} as it does not have versioning enabled."
-            )
-
-        if not checksum:
-            return [
-                ArtifactManifestEntry(
-                    path=LogicalFilePathStr(name or key), ref=path, digest=path
-                )
-            ]
-
-        start_time = None
-        obj = self._client.bucket(bucket).get_blob(key, generation=version)
-        multi = obj is None
-        if multi:
-            start_time = time.time()
-            termlog(
-                'Generating checksum for up to %i objects with prefix "%s"... '
-                % (max_objects, key),
-                newline=False,
-            )
-            objects = self._client.bucket(bucket).list_blobs(
-                prefix=key, max_results=max_objects
-            )
-        else:
-            objects = [obj]
-
-        entries = [
-            self._entry_from_obj(obj, path, name, prefix=key, multi=multi)
-            for obj in objects
-        ]
-        if start_time is not None:
-            termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
-        if len(entries) > max_objects:
-            raise ValueError(
-                "Exceeded %i objects tracked, pass max_objects to add_reference"
-                % max_objects
-            )
-        return entries
-
-    def _entry_from_obj(
-        self,
-        obj: "gcs_module.blob.Blob",
-        path: str,
-        name: Optional[str] = None,
-        prefix: str = "",
-        multi: bool = False,
-    ) -> ArtifactManifestEntry:
-        """Create an ArtifactManifestEntry from a GCS object.
-
-        Arguments:
-            obj: The GCS object
-            path: The GCS-style path (e.g.: "gs://bucket/file.txt")
-            name: The user assigned name, or None if not specified
-            prefix: The prefix to add (will be the same as `path` for directories)
-            multi: Whether or not this is a multi-object add.
-        """
-        bucket, key, _ = self._parse_uri(path)
-
-        # Always use posix paths, since that's what S3 uses.
-        posix_key = pathlib.PurePosixPath(obj.name)  # the bucket key
-        posix_path = pathlib.PurePosixPath(bucket) / pathlib.PurePosixPath(
-            key
-        )  # the path, with the scheme stripped
-        posix_prefix = pathlib.PurePosixPath(prefix)  # the prefix, if adding a prefix
-        posix_name = pathlib.PurePosixPath(name or "")
-        posix_ref = posix_path
-
-        if name is None:
-            # We're adding a directory (prefix), so calculate a relative path.
-            if str(posix_prefix) in str(posix_key) and posix_prefix != posix_key:
-                posix_name = posix_key.relative_to(posix_prefix)
-                posix_ref = posix_path / posix_name
-            else:
-                posix_name = pathlib.PurePosixPath(posix_key.name)
-                posix_ref = posix_path
-        elif multi:
-            # We're adding a directory with a name override.
-            relpath = posix_key.relative_to(posix_prefix)
-            posix_name = posix_name / relpath
-            posix_ref = posix_path / relpath
-        return ArtifactManifestEntry(
-            path=LogicalFilePathStr(str(posix_name)),
-            ref=URIStr(f"{self._scheme}://{str(posix_ref)}"),
-            digest=obj.md5_hash,
-            size=obj.size,
-            extra=self._extra_from_obj(obj),
-        )
-
-    @staticmethod
-    def _extra_from_obj(obj: "gcs_module.blob.Blob") -> Dict[str, str]:
-        return {
-            "etag": obj.etag,
-            "versionID": obj.generation,
-        }
-
-    @staticmethod
-    def _content_addressed_path(md5: str) -> FilePathStr:
-        # TODO: is this the structure we want? not at all human
-        # readable, but that's probably OK. don't want people
-        # poking around in the bucket
-        return FilePathStr(
-            "wandb/%s" % base64.b64encode(md5.encode("ascii")).decode("ascii")
-        )
-
-
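
Unlike the S3 handler, which reads `versionId` from the query string, the removed GCSHandler takes the object generation from the URL fragment. A sketch of that difference (the generation value below is a made-up example):

    from urllib.parse import urlparse


    def parse_gs_uri(uri: str):
        # GCS versions ("generations") ride in the fragment, not the query.
        url = urlparse(uri)
        return url.netloc, url.path[1:], url.fragment or None


    assert parse_gs_uri("gs://my-bucket/model.pt#1680000000000000") == (
        "my-bucket",
        "model.pt",
        "1680000000000000",
    )
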
-class AzureHandler(StorageHandler):
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == "https" and parsed_url.netloc.endswith(
-            ".blob.core.windows.net"
-        )
-
-    def load_path(
-        self,
-        manifest_entry: "ArtifactManifestEntry",
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        assert manifest_entry.ref is not None
-        if not local:
-            return manifest_entry.ref
-
-        path, hit, cache_open = get_artifacts_cache().check_etag_obj_path(
-            URIStr(manifest_entry.ref),
-            ETag(manifest_entry.digest),
-            manifest_entry.size or 0,
-        )
-        if hit:
-            return path
-
-        account_url, container_name, blob_name, query = self._parse_uri(
-            manifest_entry.ref
-        )
-        version_id = manifest_entry.extra.get("versionID")
-        blob_service_client = self._get_module("azure.storage.blob").BlobServiceClient(
-            account_url,
-            credential=self._get_module("azure.identity").DefaultAzureCredential(),
-        )
-        blob_client = blob_service_client.get_blob_client(
-            container=container_name, blob=blob_name
-        )
-        if version_id is None:
-            # Try current version, then all versions.
-            try:
-                downloader = blob_client.download_blob(
-                    etag=manifest_entry.digest,
-                    match_condition=self._get_module(
-                        "azure.core"
-                    ).MatchConditions.IfNotModified,
-                )
-            except self._get_module("azure.core.exceptions").ResourceModifiedError:
-                container_client = blob_service_client.get_container_client(
-                    container_name
-                )
-                for blob_properties in container_client.walk_blobs(
-                    name_starts_with=blob_name, include=["versions"]
-                ):
-                    if (
-                        blob_properties.name == blob_name
-                        and blob_properties.etag == manifest_entry.digest
-                        and blob_properties.version_id is not None
-                    ):
-                        downloader = blob_client.download_blob(
-                            version_id=blob_properties.version_id
-                        )
-                        break
-                else:  # didn't break
-                    raise ValueError(
-                        f"Couldn't find blob version for {manifest_entry.ref} matching "
-                        f"etag {manifest_entry.digest}."
-                    )
-        else:
-            downloader = blob_client.download_blob(version_id=version_id)
-        with cache_open(mode="wb") as f:
-            downloader.readinto(f)
-        return path
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence["ArtifactManifestEntry"]:
-        account_url, container_name, blob_name, query = self._parse_uri(path)
-        path = URIStr(f"{account_url}/{container_name}/{blob_name}")
-
-        if not checksum:
-            return [
-                ArtifactManifestEntry(
-                    path=LogicalFilePathStr(name or blob_name), digest=path, ref=path
-                )
-            ]
-
-        blob_service_client = self._get_module("azure.storage.blob").BlobServiceClient(
-            account_url,
-            credential=self._get_module("azure.identity").DefaultAzureCredential(),
-        )
-        blob_client = blob_service_client.get_blob_client(
-            container=container_name, blob=blob_name
-        )
-        if blob_client.exists(version_id=query.get("versionId")):
-            blob_properties = blob_client.get_blob_properties(
-                version_id=query.get("versionId")
-            )
-            return [
-                self._create_entry(
-                    blob_properties,
-                    path=LogicalFilePathStr(
-                        name or pathlib.PurePosixPath(blob_name).name
-                    ),
-                    ref=URIStr(
-                        f"{account_url}/{container_name}/{blob_properties.name}"
-                    ),
-                )
-            ]
-
-        entries = []
-        container_client = blob_service_client.get_container_client(container_name)
-        max_objects = max_objects or DEFAULT_MAX_OBJECTS
-        for i, blob_properties in enumerate(
-            container_client.list_blobs(name_starts_with=f"{blob_name}/")
-        ):
-            if i >= max_objects:
-                raise ValueError(
-                    f"Exceeded {max_objects} objects tracked, pass max_objects to "
-                    f"add_reference"
-                )
-            suffix = pathlib.PurePosixPath(blob_properties.name).relative_to(blob_name)
-            entries.append(
-                self._create_entry(
-                    blob_properties,
-                    path=LogicalFilePathStr(str(name / suffix if name else suffix)),
-                    ref=URIStr(
-                        f"{account_url}/{container_name}/{blob_properties.name}"
-                    ),
-                )
-            )
-        return entries
-
-    def _get_module(self, name: str) -> ModuleType:
-        module = util.get_module(
-            name,
-            lazy=False,
-            required="Azure references require the azure library, run "
-            "pip install wandb[azure]",
-        )
-        assert isinstance(module, ModuleType)
-        return module
-
-    def _parse_uri(self, uri: str) -> Tuple[str, str, str, Dict[str, str]]:
-        parsed_url = urlparse(uri)
-        query = dict(parse_qsl(parsed_url.query))
-        account_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-        _, container_name, blob_name = parsed_url.path.split("/", 2)
-        return account_url, container_name, blob_name, query
-
-    def _create_entry(
-        self,
-        blob_properties: "azure.storage.blob.BlobProperties",
-        path: LogicalFilePathStr,
-        ref: URIStr,
-    ) -> ArtifactManifestEntry:
-        extra = {"etag": blob_properties.etag.strip('"')}
-        if blob_properties.version_id:
-            extra["versionID"] = blob_properties.version_id
-        return ArtifactManifestEntry(
-            path=path,
-            ref=ref,
-            digest=blob_properties.etag.strip('"'),
-            size=blob_properties.size,
-            extra=extra,
-        )
-
-
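
The removed AzureHandler splits a blob URL into account URL, container, blob name, and query parameters before talking to the Azure SDK. A standalone sketch of that split (the account and container names below are placeholders):

    from urllib.parse import parse_qsl, urlparse


    def parse_azure_uri(uri: str):
        # The first path segment is the container; everything after it is the blob.
        parsed = urlparse(uri)
        account_url = f"{parsed.scheme}://{parsed.netloc}"
        _, container, blob = parsed.path.split("/", 2)
        return account_url, container, blob, dict(parse_qsl(parsed.query))


    assert parse_azure_uri(
        "https://myaccount.blob.core.windows.net/my-container/dir/file.bin?versionId=2023"
    ) == (
        "https://myaccount.blob.core.windows.net",
        "my-container",
        "dir/file.bin",
        {"versionId": "2023"},
    )
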
-class HTTPHandler(StorageHandler):
-    def __init__(self, session: requests.Session, scheme: Optional[str] = None) -> None:
-        self._scheme = scheme or "http"
-        self._cache = get_artifacts_cache()
-        self._session = session
-
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == self._scheme
-
-    def load_path(
-        self,
-        manifest_entry: ArtifactManifestEntry,
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        if not local:
-            assert manifest_entry.ref is not None
-            return manifest_entry.ref
-
-        assert manifest_entry.ref is not None
-
-        path, hit, cache_open = self._cache.check_etag_obj_path(
-            URIStr(manifest_entry.ref),
-            ETag(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
-            manifest_entry.size if manifest_entry.size is not None else 0,
-        )
-        if hit:
-            return path
-
-        response = self._session.get(manifest_entry.ref, stream=True)
-        response.raise_for_status()
-
-        digest: Optional[Union[ETag, FilePathStr, URIStr]]
-        digest, size, extra = self._entry_from_headers(response.headers)
-        digest = digest or manifest_entry.ref
-        if manifest_entry.digest != digest:
-            raise ValueError(
-                f"Digest mismatch for url {manifest_entry.ref}: expected {manifest_entry.digest} but found {digest}"
-            )
-
-        with cache_open(mode="wb") as file:
-            for data in response.iter_content(chunk_size=16 * 1024):
-                file.write(data)
-        return path
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence[ArtifactManifestEntry]:
-        name = LogicalFilePathStr(name or os.path.basename(path))
-        if not checksum:
-            return [ArtifactManifestEntry(path=name, ref=path, digest=path)]
-
-        with self._session.get(path, stream=True) as response:
-            response.raise_for_status()
-            digest: Optional[Union[ETag, FilePathStr, URIStr]]
-            digest, size, extra = self._entry_from_headers(response.headers)
-            digest = digest or path
-        return [
-            ArtifactManifestEntry(
-                path=name, ref=path, digest=digest, size=size, extra=extra
-            )
-        ]
-
-    def _entry_from_headers(
-        self, headers: requests.structures.CaseInsensitiveDict
-    ) -> Tuple[Optional[ETag], Optional[int], Dict[str, str]]:
-        response_headers = {k.lower(): v for k, v in headers.items()}
-        size = None
-        if response_headers.get("content-length", None):
-            size = int(response_headers["content-length"])
-
-        digest = response_headers.get("etag", None)
-        extra = {}
-        if digest:
-            extra["etag"] = digest
-        if digest and digest[:1] == '"' and digest[-1:] == '"':
-            digest = digest[1:-1]  # trim leading and trailing quotes around etag
-        return digest, size, extra
-
-
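
The `_entry_from_headers` logic above derives an entry's size and digest from plain HTTP response headers; note that the digest is the etag with surrounding quotes trimmed, while `extra["etag"]` keeps the raw header value. A dict-based sketch of the same steps (using a plain dict in place of requests' CaseInsensitiveDict):

    def entry_from_headers(headers: dict):
        # Lowercase keys, read size from content-length, digest from etag.
        h = {k.lower(): v for k, v in headers.items()}
        size = int(h["content-length"]) if h.get("content-length") else None
        digest = h.get("etag")
        extra = {"etag": digest} if digest else {}
        if digest and digest[:1] == '"' and digest[-1:] == '"':
            digest = digest[1:-1]  # extra keeps the quoted form, digest does not
        return digest, size, extra


    assert entry_from_headers({"ETag": '"abc"', "Content-Length": "10"}) == (
        "abc",
        10,
        {"etag": '"abc"'},
    )
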
-class WBArtifactHandler(StorageHandler):
-    """Handles loading and storing Artifact reference-type files."""
-
-    _client: Optional[PublicApi]
-
-    def __init__(self) -> None:
-        self._scheme = "wandb-artifact"
-        self._cache = get_artifacts_cache()
-        self._client = None
-
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == self._scheme
-
-    @property
-    def client(self) -> PublicApi:
-        if self._client is None:
-            self._client = PublicApi()
-        return self._client
-
-    def load_path(
-        self,
-        manifest_entry: ArtifactManifestEntry,
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        """Load the file in the specified artifact given its corresponding entry.
-
-        Download the referenced artifact; create and return a new symlink to the caller.
-
-        Arguments:
-            manifest_entry (ArtifactManifestEntry): The index entry to load
-
-        Returns:
-            (os.PathLike): A path to the file represented by `index_entry`
-        """
-        # We don't check for cache hits here. Since we have 0 for size (since this
-        # is a cross-artifact reference which and we've made the choice to store 0
-        # in the size field), we can't confirm if the file is complete. So we just
-        # rely on the dep_artifact entry's download() method to do its own cache
-        # check.
-
-        # Parse the reference path and download the artifact if needed
-        artifact_id = util.host_from_path(manifest_entry.ref)
-        artifact_file_path = util.uri_from_path(manifest_entry.ref)
-
-        dep_artifact = PublicArtifact.from_id(hex_to_b64_id(artifact_id), self.client)
-        link_target_path: FilePathStr
-        if local:
-            link_target_path = dep_artifact.get_path(artifact_file_path).download()
-        else:
-            link_target_path = dep_artifact.get_path(artifact_file_path).ref_target()
-
-        return link_target_path
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence[ArtifactManifestEntry]:
-        """Store the file or directory at the given path into the specified artifact.
-
-        Recursively resolves the reference until the result is a concrete asset.
-
-        Arguments:
-            artifact: The artifact doing the storing path (str): The path to store name
-            (str): If specified, the logical name that should map to `path`
-
-        Returns:
-            (list[ArtifactManifestEntry]): A list of manifest entries to store within
-            the artifact
-        """
-        # Recursively resolve the reference until a concrete asset is found
-        # TODO: Consider resolving server-side for performance improvements.
-        while path is not None and urlparse(path).scheme == self._scheme:
-            artifact_id = util.host_from_path(path)
-            artifact_file_path = util.uri_from_path(path)
-            target_artifact = PublicArtifact.from_id(
-                hex_to_b64_id(artifact_id), self.client
-            )
-
-            # this should only have an effect if the user added the reference by url
-            # string directly (in other words they did not already load the artifact into ram.)
-            target_artifact._load_manifest()
-
-            entry = target_artifact._manifest.get_entry_by_path(artifact_file_path)
-            path = entry.ref
-
-        # Create the path reference
-        path = URIStr(
-            "{}://{}/{}".format(
-                self._scheme,
-                b64_to_hex_id(target_artifact.id),
-                artifact_file_path,
-            )
-        )
-
-        # Return the new entry
-        return [
-            ArtifactManifestEntry(
-                path=LogicalFilePathStr(name or os.path.basename(path)),
-                ref=path,
-                size=0,
-                digest=entry.digest,
-            )
-        ]
-
-
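
For the removed WBArtifactHandler, a `wandb-artifact://` reference encodes the artifact id as the URI host and the file's path inside the artifact as the URI path. A rough sketch using `urlparse` (an approximation of what `util.host_from_path`/`util.uri_from_path` are used for above; the id and path are invented):

    from urllib.parse import urlparse


    def split_artifact_ref(uri: str):
        # host -> (hex-encoded) artifact id, path -> file within the artifact
        parsed = urlparse(uri)
        return parsed.netloc, parsed.path.lstrip("/")


    assert split_artifact_ref("wandb-artifact://abc123def/models/best.h5") == (
        "abc123def",
        "models/best.h5",
    )

Storing such a reference then chases `entry.ref` until it reaches a non-`wandb-artifact://` target, which is why the loop above re-parses `path` on every iteration.
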
-class WBLocalArtifactHandler(StorageHandler):
-    """Handles loading and storing Artifact reference-type files."""
-
-    _client: Optional[PublicApi]
-
-    def __init__(self) -> None:
-        self._scheme = "wandb-client-artifact"
-        self._cache = get_artifacts_cache()
-
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == self._scheme
-
-    def load_path(
-        self,
-        manifest_entry: ArtifactManifestEntry,
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        raise NotImplementedError(
-            "Should not be loading a path for an artifact entry with unresolved client id."
-        )
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence[ArtifactManifestEntry]:
-        """Store the file or directory at the given path within the specified artifact.
-
-        Arguments:
-            artifact: The artifact doing the storing
-            path (str): The path to store
-            name (str): If specified, the logical name that should map to `path`
-
-        Returns:
-            (list[ArtifactManifestEntry]): A list of manifest entries to store within the artifact
-        """
-        client_id = util.host_from_path(path)
-        target_path = util.uri_from_path(path)
-        target_artifact = self._cache.get_client_artifact(client_id)
-        if not isinstance(target_artifact, Artifact):
-            raise RuntimeError("Local Artifact not found - invalid reference")
-        target_entry = target_artifact._manifest.entries[target_path]
-        if target_entry is None:
-            raise RuntimeError("Local entry not found - invalid reference")
-
-        # Return the new entry
-        return [
-            ArtifactManifestEntry(
-                path=LogicalFilePathStr(name or os.path.basename(path)),
-                ref=path,
-                size=0,
-                digest=target_entry.digest,
-            )
-        ]
-
-
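
All of the handlers removed in this hunk share the same shape: `can_handle` routes by URL scheme (or host suffix, for Azure), `load_path` either returns an entry's reference or materializes it locally, and `store_path` turns a reference into manifest entries. A typing.Protocol sketch of that recurring contract, as an illustration only (the real base class is `StorageHandler` and the signatures here are simplified):

    from typing import Optional, Protocol, Sequence
    from urllib.parse import ParseResult


    class StorageHandlerProtocol(Protocol):
        def can_handle(self, parsed_url: ParseResult) -> bool:
            """Return True if this handler owns the URL's scheme."""
            ...

        def load_path(self, manifest_entry, local: bool = False) -> str:
            """Return the entry's reference, or a local path when local=True."""
            ...

        def store_path(
            self,
            artifact,
            path: str,
            name: Optional[str] = None,
            checksum: bool = True,
            max_objects: Optional[int] = None,
        ) -> Sequence:
            """Produce manifest entries for the referenced file(s)."""
            ...
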
-class _ArtifactVersionType(Type):
-    name = "artifactVersion"
-    types = [Artifact, PublicArtifact]
-
-
-TypeRegistry.add(_ArtifactVersionType)