wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2101 @@
|
|
1
|
+
"""Artifact class."""
|
2
|
+
import concurrent.futures
|
3
|
+
import contextlib
|
4
|
+
import datetime
|
5
|
+
import json
|
6
|
+
import multiprocessing.dummy
|
7
|
+
import os
|
8
|
+
import platform
|
9
|
+
import re
|
10
|
+
import shutil
|
11
|
+
import tempfile
|
12
|
+
import time
|
13
|
+
from copy import copy
|
14
|
+
from functools import partial
|
15
|
+
from pathlib import PurePosixPath
|
16
|
+
from typing import (
|
17
|
+
IO,
|
18
|
+
TYPE_CHECKING,
|
19
|
+
Any,
|
20
|
+
Dict,
|
21
|
+
Generator,
|
22
|
+
List,
|
23
|
+
Optional,
|
24
|
+
Sequence,
|
25
|
+
Set,
|
26
|
+
Tuple,
|
27
|
+
Type,
|
28
|
+
Union,
|
29
|
+
cast,
|
30
|
+
)
|
31
|
+
from urllib.parse import urlparse
|
32
|
+
|
33
|
+
import requests
|
34
|
+
|
35
|
+
import wandb
|
36
|
+
from wandb import data_types, env, util
|
37
|
+
from wandb.apis.normalize import normalize_exceptions
|
38
|
+
from wandb.apis.public import ArtifactFiles, RetryingClient, Run
|
39
|
+
from wandb.data_types import WBValue
|
40
|
+
from wandb.errors.term import termerror, termlog, termwarn
|
41
|
+
from wandb.sdk.artifacts.artifact_download_logger import ArtifactDownloadLogger
|
42
|
+
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
|
43
|
+
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
44
|
+
from wandb.sdk.artifacts.artifact_manifests.artifact_manifest_v1 import (
|
45
|
+
ArtifactManifestV1,
|
46
|
+
)
|
47
|
+
from wandb.sdk.artifacts.artifact_saver import get_staging_dir
|
48
|
+
from wandb.sdk.artifacts.artifact_state import ArtifactState
|
49
|
+
from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
|
50
|
+
from wandb.sdk.artifacts.exceptions import (
|
51
|
+
ArtifactFinalizedError,
|
52
|
+
ArtifactNotLoggedError,
|
53
|
+
WaitTimeoutError,
|
54
|
+
)
|
55
|
+
from wandb.sdk.artifacts.storage_layout import StorageLayout
|
56
|
+
from wandb.sdk.artifacts.storage_policies.wandb_storage_policy import WandbStoragePolicy
|
57
|
+
from wandb.sdk.data_types._dtypes import Type as WBType
|
58
|
+
from wandb.sdk.data_types._dtypes import TypeRegistry
|
59
|
+
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
60
|
+
from wandb.sdk.lib import filesystem, retry, runid, telemetry
|
61
|
+
from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64
|
62
|
+
from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath, URIStr
|
63
|
+
|
64
|
+
reset_path = util.vendor_setup()
|
65
|
+
|
66
|
+
from wandb_gql import gql # noqa: E402
|
67
|
+
|
68
|
+
reset_path()
|
69
|
+
|
70
|
+
if TYPE_CHECKING:
|
71
|
+
from wandb.sdk.interface.message_future import MessageFuture
|
72
|
+
|
73
|
+
|
74
|
+
class Artifact:
|
75
|
+
"""Flexible and lightweight building block for dataset and model versioning.
|
76
|
+
|
77
|
+
Constructs an empty artifact whose contents can be populated using its `add` family
|
78
|
+
of functions. Once the artifact has all the desired files, you can call
|
79
|
+
`wandb.log_artifact()` to log it.
|
80
|
+
|
81
|
+
Arguments:
|
82
|
+
name: A human-readable name for this artifact, which is how you can identify
|
83
|
+
this artifact in the UI or reference it in `use_artifact` calls. Names can
|
84
|
+
contain letters, numbers, underscores, hyphens, and dots. The name must be
|
85
|
+
unique across a project.
|
86
|
+
type: The type of the artifact, which is used to organize and differentiate
|
87
|
+
artifacts. Common types include `dataset` or `model`, but you can use any
|
88
|
+
string containing letters, numbers, underscores, hyphens, and dots.
|
89
|
+
description: Free text that offers a description of the artifact. The
|
90
|
+
description is markdown rendered in the UI, so this is a good place to place
|
91
|
+
tables, links, etc.
|
92
|
+
metadata: Structured data associated with the artifact, for example class
|
93
|
+
distribution of a dataset. This will eventually be queryable and plottable
|
94
|
+
in the UI. There is a hard limit of 100 total keys.
|
95
|
+
|
96
|
+
Returns:
|
97
|
+
An `Artifact` object.
|
98
|
+
|
99
|
+
Examples:
|
100
|
+
Basic usage:
|
101
|
+
```
|
102
|
+
wandb.init()
|
103
|
+
|
104
|
+
artifact = wandb.Artifact("mnist", type="dataset")
|
105
|
+
artifact.add_dir("mnist/")
|
106
|
+
wandb.log_artifact(artifact)
|
107
|
+
```
|
108
|
+
"""
|
109
|
+
|
110
|
+
_TMP_DIR = tempfile.TemporaryDirectory("wandb-artifacts")
|
111
|
+
_GQL_FRAGMENT = """
|
112
|
+
fragment ArtifactFragment on Artifact {
|
113
|
+
id
|
114
|
+
artifactSequence {
|
115
|
+
project {
|
116
|
+
entityName
|
117
|
+
name
|
118
|
+
}
|
119
|
+
name
|
120
|
+
}
|
121
|
+
versionIndex
|
122
|
+
artifactType {
|
123
|
+
name
|
124
|
+
}
|
125
|
+
description
|
126
|
+
metadata
|
127
|
+
aliases {
|
128
|
+
artifactCollectionName
|
129
|
+
alias
|
130
|
+
}
|
131
|
+
state
|
132
|
+
commitHash
|
133
|
+
fileCount
|
134
|
+
createdAt
|
135
|
+
updatedAt
|
136
|
+
}
|
137
|
+
"""
|
138
|
+
|
139
|
+
def __init__(
    self,
    name: str,
    type: str,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    incremental: bool = False,
    use_as: Optional[str] = None,
) -> None:
    """Initialize a draft artifact, validating the name and type up front."""
    # Reject invalid inputs before any state is created.
    if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
        raise ValueError(
            f"Artifact name may only contain alphanumeric characters, dashes, "
            f"underscores, and dots. Invalid name: {name}"
        )
    if type == "job" or type.startswith("wandb-"):
        raise ValueError(
            "Artifact types 'job' and 'wandb-*' are reserved for internal use. "
            "Please use a different type."
        )
    if incremental:
        termwarn("Using experimental arg `incremental`")

    # Internal.
    self._client: Optional[RetryingClient] = None
    layout = StorageLayout.V1 if env.get_use_v1_artifacts() else StorageLayout.V2
    self._storage_policy = WandbStoragePolicy(
        config={
            "storageLayout": layout,
            # TODO: storage region
        }
    )
    self._tmp_dir: Optional[tempfile.TemporaryDirectory] = None
    self._added_objs: Dict[
        int, Tuple[data_types.WBValue, ArtifactManifestEntry]
    ] = {}
    self._added_local_paths: Dict[str, ArtifactManifestEntry] = {}
    self._save_future: Optional["MessageFuture"] = None
    self._dependent_artifacts: Set["Artifact"] = set()
    self._download_roots: Set[str] = set()
    # Properties.
    self._id: Optional[str] = None
    self._client_id: str = runid.generate_id(128)
    self._sequence_client_id: str = runid.generate_id(128)
    self._entity: Optional[str] = None
    self._project: Optional[str] = None
    self._name: str = name  # includes version after saving
    self._version: Optional[str] = None
    self._source_entity: Optional[str] = None
    self._source_project: Optional[str] = None
    self._source_name: str = name  # includes version after saving
    self._source_version: Optional[str] = None
    self._type: str = type
    self._description: Optional[str] = description
    self._metadata: dict = self._normalize_metadata(metadata)
    self._aliases: List[str] = []
    self._saved_aliases: List[str] = []
    self._distributed_id: Optional[str] = None
    self._incremental: bool = incremental
    self._use_as: Optional[str] = use_as
    self._state: ArtifactState = ArtifactState.PENDING
    self._manifest: Optional[ArtifactManifest] = ArtifactManifestV1(
        self._storage_policy
    )
    self._commit_hash: Optional[str] = None
    self._file_count: Optional[int] = None
    self._created_at: Optional[str] = None
    self._updated_at: Optional[str] = None
    self._final: bool = False
    # Register the draft in the process-wide cache so it can be found by id.
    get_artifacts_cache().store_client_artifact(self)
|
211
|
+
|
212
|
+
def __repr__(self) -> str:
    """Return a short debug representation of this artifact."""
    # Prefer the server-assigned id; drafts fall back to their name.
    identifier = self.id if self.id else self.name
    return f"<Artifact {identifier}>"
|
214
|
+
|
215
|
+
@classmethod
|
216
|
+
def _from_id(cls, artifact_id: str, client: RetryingClient) -> Optional["Artifact"]:
|
217
|
+
artifact = get_artifacts_cache().get_artifact(artifact_id)
|
218
|
+
if artifact is not None:
|
219
|
+
return artifact
|
220
|
+
|
221
|
+
query = gql(
|
222
|
+
"""
|
223
|
+
query ArtifactByID($id: ID!) {
|
224
|
+
artifact(id: $id) {
|
225
|
+
...ArtifactFragment
|
226
|
+
currentManifest {
|
227
|
+
file {
|
228
|
+
directUrl
|
229
|
+
}
|
230
|
+
}
|
231
|
+
}
|
232
|
+
}
|
233
|
+
"""
|
234
|
+
+ cls._GQL_FRAGMENT
|
235
|
+
)
|
236
|
+
response = client.execute(
|
237
|
+
query,
|
238
|
+
variable_values={"id": artifact_id},
|
239
|
+
)
|
240
|
+
attrs = response.get("artifact")
|
241
|
+
if attrs is None:
|
242
|
+
return None
|
243
|
+
entity = attrs["artifactSequence"]["project"]["entityName"]
|
244
|
+
project = attrs["artifactSequence"]["project"]["name"]
|
245
|
+
name = "{}:v{}".format(attrs["artifactSequence"]["name"], attrs["versionIndex"])
|
246
|
+
return cls._from_attrs(entity, project, name, attrs, client)
|
247
|
+
|
248
|
+
@classmethod
|
249
|
+
def _from_name(
|
250
|
+
cls, entity: str, project: str, name: str, client: RetryingClient
|
251
|
+
) -> "Artifact":
|
252
|
+
query = gql(
|
253
|
+
"""
|
254
|
+
query ArtifactByName(
|
255
|
+
$entityName: String!,
|
256
|
+
$projectName: String!,
|
257
|
+
$name: String!
|
258
|
+
) {
|
259
|
+
project(name: $projectName, entityName: $entityName) {
|
260
|
+
artifact(name: $name) {
|
261
|
+
...ArtifactFragment
|
262
|
+
}
|
263
|
+
}
|
264
|
+
}
|
265
|
+
"""
|
266
|
+
+ cls._GQL_FRAGMENT
|
267
|
+
)
|
268
|
+
response = client.execute(
|
269
|
+
query,
|
270
|
+
variable_values={
|
271
|
+
"entityName": entity,
|
272
|
+
"projectName": project,
|
273
|
+
"name": name,
|
274
|
+
},
|
275
|
+
)
|
276
|
+
attrs = response.get("project", {}).get("artifact")
|
277
|
+
if attrs is None:
|
278
|
+
raise ValueError(
|
279
|
+
f"Unable to fetch artifact with name {entity}/{project}/{name}"
|
280
|
+
)
|
281
|
+
return cls._from_attrs(entity, project, name, attrs, client)
|
282
|
+
|
283
|
+
@classmethod
|
284
|
+
def _from_attrs(
|
285
|
+
cls,
|
286
|
+
entity: str,
|
287
|
+
project: str,
|
288
|
+
name: str,
|
289
|
+
attrs: Dict[str, Any],
|
290
|
+
client: RetryingClient,
|
291
|
+
) -> "Artifact":
|
292
|
+
# Placeholder is required to skip validation.
|
293
|
+
artifact = cls("placeholder", type="placeholder")
|
294
|
+
artifact._client = client
|
295
|
+
artifact._id = attrs["id"]
|
296
|
+
artifact._entity = entity
|
297
|
+
artifact._project = project
|
298
|
+
artifact._name = name
|
299
|
+
version_aliases = [
|
300
|
+
alias["alias"]
|
301
|
+
for alias in attrs.get("aliases", [])
|
302
|
+
if alias["artifactCollectionName"] == name.split(":")[0]
|
303
|
+
and util.alias_is_version_index(alias["alias"])
|
304
|
+
]
|
305
|
+
# assert len(version_aliases) == 1
|
306
|
+
artifact._version = version_aliases[0]
|
307
|
+
artifact._source_entity = attrs["artifactSequence"]["project"]["entityName"]
|
308
|
+
artifact._source_project = attrs["artifactSequence"]["project"]["name"]
|
309
|
+
artifact._source_name = "{}:v{}".format(
|
310
|
+
attrs["artifactSequence"]["name"], attrs["versionIndex"]
|
311
|
+
)
|
312
|
+
artifact._source_version = "v{}".format(attrs["versionIndex"])
|
313
|
+
artifact._type = attrs["artifactType"]["name"]
|
314
|
+
artifact._description = attrs["description"]
|
315
|
+
artifact.metadata = cls._normalize_metadata(
|
316
|
+
json.loads(attrs["metadata"] or "{}")
|
317
|
+
)
|
318
|
+
artifact._aliases = [
|
319
|
+
alias["alias"]
|
320
|
+
for alias in attrs.get("aliases", [])
|
321
|
+
if alias["artifactCollectionName"] == name.split(":")[0]
|
322
|
+
and not util.alias_is_version_index(alias["alias"])
|
323
|
+
]
|
324
|
+
artifact._saved_aliases = copy(artifact._aliases)
|
325
|
+
artifact._state = ArtifactState(attrs["state"])
|
326
|
+
if "currentManifest" in attrs:
|
327
|
+
artifact._load_manifest(attrs["currentManifest"]["file"]["directUrl"])
|
328
|
+
else:
|
329
|
+
artifact._manifest = None
|
330
|
+
artifact._commit_hash = attrs["commitHash"]
|
331
|
+
artifact._file_count = attrs["fileCount"]
|
332
|
+
artifact._created_at = attrs["createdAt"]
|
333
|
+
artifact._updated_at = attrs["updatedAt"]
|
334
|
+
artifact._final = True
|
335
|
+
# Cache.
|
336
|
+
get_artifacts_cache().store_artifact(artifact)
|
337
|
+
return artifact
|
338
|
+
|
339
|
+
def new_draft(self) -> "Artifact":
    """Create a new draft artifact with the same content as this committed artifact.

    The artifact returned can be extended or modified and logged as a new version.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "new_draft")

    # Strip the ":version" suffix: the draft belongs to the same collection.
    collection = self.source_name.split(":")[0]
    draft = Artifact(collection, self.type)
    draft._description = self.description
    draft._metadata = self.metadata
    # Deep-copy the manifest via its JSON round trip.
    draft._manifest = ArtifactManifest.from_manifest_json(
        self.manifest.to_manifest_json()
    )
    return draft
|
357
|
+
|
358
|
+
# Properties.
|
359
|
+
|
360
|
+
@property
def id(self) -> Optional[str]:
    """The artifact's ID."""
    # Committed artifacts always carry a server-assigned id; drafts have none.
    if self._state != ArtifactState.PENDING:
        assert self._id is not None
        return self._id
    return None
|
367
|
+
|
368
|
+
@property
|
369
|
+
def entity(self) -> str:
|
370
|
+
"""The name of the entity of the secondary (portfolio) artifact collection."""
|
371
|
+
if self._state == ArtifactState.PENDING:
|
372
|
+
raise ArtifactNotLoggedError(self, "entity")
|
373
|
+
assert self._entity is not None
|
374
|
+
return self._entity
|
375
|
+
|
376
|
+
@property
|
377
|
+
def project(self) -> str:
|
378
|
+
"""The name of the project of the secondary (portfolio) artifact collection."""
|
379
|
+
if self._state == ArtifactState.PENDING:
|
380
|
+
raise ArtifactNotLoggedError(self, "project")
|
381
|
+
assert self._project is not None
|
382
|
+
return self._project
|
383
|
+
|
384
|
+
@property
|
385
|
+
def name(self) -> str:
|
386
|
+
"""The artifact name and version in its secondary (portfolio) collection.
|
387
|
+
|
388
|
+
A string with the format {collection}:{alias}. Before the artifact is saved,
|
389
|
+
contains only the name since the version is not yet known.
|
390
|
+
"""
|
391
|
+
return self._name
|
392
|
+
|
393
|
+
@property
def qualified_name(self) -> str:
    """The entity/project/name of the secondary (portfolio) collection."""
    return "/".join([self.entity, self.project, self.name])
|
397
|
+
|
398
|
+
@property
|
399
|
+
def version(self) -> str:
|
400
|
+
"""The artifact's version in its secondary (portfolio) collection."""
|
401
|
+
if self._state == ArtifactState.PENDING:
|
402
|
+
raise ArtifactNotLoggedError(self, "version")
|
403
|
+
assert self._version is not None
|
404
|
+
return self._version
|
405
|
+
|
406
|
+
@property
|
407
|
+
def source_entity(self) -> str:
|
408
|
+
"""The name of the entity of the primary (sequence) artifact collection."""
|
409
|
+
if self._state == ArtifactState.PENDING:
|
410
|
+
raise ArtifactNotLoggedError(self, "source_entity")
|
411
|
+
assert self._source_entity is not None
|
412
|
+
return self._source_entity
|
413
|
+
|
414
|
+
@property
|
415
|
+
def source_project(self) -> str:
|
416
|
+
"""The name of the project of the primary (sequence) artifact collection."""
|
417
|
+
if self._state == ArtifactState.PENDING:
|
418
|
+
raise ArtifactNotLoggedError(self, "source_project")
|
419
|
+
assert self._source_project is not None
|
420
|
+
return self._source_project
|
421
|
+
|
422
|
+
@property
|
423
|
+
def source_name(self) -> str:
|
424
|
+
"""The artifact name and version in its primary (sequence) collection.
|
425
|
+
|
426
|
+
A string with the format {collection}:{alias}. Before the artifact is saved,
|
427
|
+
contains only the name since the version is not yet known.
|
428
|
+
"""
|
429
|
+
return self._source_name
|
430
|
+
|
431
|
+
@property
|
432
|
+
def source_qualified_name(self) -> str:
|
433
|
+
"""The entity/project/name of the primary (sequence) collection."""
|
434
|
+
return f"{self.source_entity}/{self.source_project}/{self.source_name}"
|
435
|
+
|
436
|
+
@property
|
437
|
+
def source_version(self) -> str:
|
438
|
+
"""The artifact's version in its primary (sequence) collection.
|
439
|
+
|
440
|
+
A string with the format "v{number}".
|
441
|
+
"""
|
442
|
+
if self._state == ArtifactState.PENDING:
|
443
|
+
raise ArtifactNotLoggedError(self, "source_version")
|
444
|
+
assert self._source_version is not None
|
445
|
+
return self._source_version
|
446
|
+
|
447
|
+
@property
|
448
|
+
def type(self) -> str:
|
449
|
+
"""The artifact's type."""
|
450
|
+
return self._type
|
451
|
+
|
452
|
+
@property
|
453
|
+
def description(self) -> Optional[str]:
|
454
|
+
"""The artifact description.
|
455
|
+
|
456
|
+
Free text that offers a user-set description of the artifact.
|
457
|
+
"""
|
458
|
+
return self._description
|
459
|
+
|
460
|
+
@description.setter
def description(self, description: Optional[str]) -> None:
    """Set the description of the artifact.

    The description is markdown rendered in the UI, so this is a good place to put
    links, etc.

    Arguments:
        description: Free text that offers a description of the artifact.
    """
    # Fix: the docstring previously documented the parameter as `desc`,
    # which does not match the actual parameter name `description`.
    self._description = description
|
471
|
+
|
472
|
+
@property
|
473
|
+
def metadata(self) -> dict:
|
474
|
+
"""User-defined artifact metadata.
|
475
|
+
|
476
|
+
Structured data associated with the artifact.
|
477
|
+
"""
|
478
|
+
return self._metadata
|
479
|
+
|
480
|
+
@metadata.setter
def metadata(self, metadata: dict) -> None:
    """Set user-defined artifact metadata.

    Metadata set this way will eventually be queryable and plottable in the UI;
    e.g. the class distribution of a dataset.

    Note: There is currently a limit of 100 total keys.

    Arguments:
        metadata: Structured data associated with the artifact.
    """
    # Normalization enforces the key limit and coerces values before storing.
    normalized = self._normalize_metadata(metadata)
    self._metadata = normalized
|
493
|
+
|
494
|
+
@property
|
495
|
+
def aliases(self) -> List[str]:
|
496
|
+
"""The aliases associated with this artifact.
|
497
|
+
|
498
|
+
The list is mutable and calling `save()` will persist all alias changes.
|
499
|
+
"""
|
500
|
+
if self._state == ArtifactState.PENDING:
|
501
|
+
raise ArtifactNotLoggedError(self, "aliases")
|
502
|
+
return self._aliases
|
503
|
+
|
504
|
+
@aliases.setter
def aliases(self, aliases: List[str]) -> None:
    """Set the aliases associated with this artifact."""
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "aliases")

    # "/" and ":" are reserved separators in qualified artifact names,
    # so no alias may contain them.
    for alias in aliases:
        if "/" in alias or ":" in alias:
            raise ValueError(
                "Aliases must not contain any of the following characters: /, :"
            )
    self._aliases = aliases
|
515
|
+
|
516
|
+
@property
|
517
|
+
def distributed_id(self) -> Optional[str]:
|
518
|
+
return self._distributed_id
|
519
|
+
|
520
|
+
@distributed_id.setter
|
521
|
+
def distributed_id(self, distributed_id: Optional[str]) -> None:
|
522
|
+
self._distributed_id = distributed_id
|
523
|
+
|
524
|
+
@property
|
525
|
+
def incremental(self) -> bool:
|
526
|
+
return self._incremental
|
527
|
+
|
528
|
+
@property
|
529
|
+
def use_as(self) -> Optional[str]:
|
530
|
+
return self._use_as
|
531
|
+
|
532
|
+
@property
def state(self) -> str:
    """The status of the artifact. One of: "PENDING", "COMMITTED", or "DELETED"."""
    # Expose the enum's string value rather than the enum member itself.
    current = self._state
    return current.value
|
536
|
+
|
537
|
+
@property
|
538
|
+
def manifest(self) -> ArtifactManifest:
|
539
|
+
"""The artifact's manifest.
|
540
|
+
|
541
|
+
The manifest lists all of its contents, and can't be changed once the artifact
|
542
|
+
has been logged.
|
543
|
+
"""
|
544
|
+
if self._manifest is None:
|
545
|
+
query = gql(
|
546
|
+
"""
|
547
|
+
query ArtifactManifest(
|
548
|
+
$entityName: String!,
|
549
|
+
$projectName: String!,
|
550
|
+
$name: String!
|
551
|
+
) {
|
552
|
+
project(entityName: $entityName, name: $projectName) {
|
553
|
+
artifact(name: $name) {
|
554
|
+
currentManifest {
|
555
|
+
file {
|
556
|
+
directUrl
|
557
|
+
}
|
558
|
+
}
|
559
|
+
}
|
560
|
+
}
|
561
|
+
}
|
562
|
+
"""
|
563
|
+
)
|
564
|
+
assert self._client is not None
|
565
|
+
response = self._client.execute(
|
566
|
+
query,
|
567
|
+
variable_values={
|
568
|
+
"entityName": self._entity,
|
569
|
+
"projectName": self._project,
|
570
|
+
"name": self._name,
|
571
|
+
},
|
572
|
+
)
|
573
|
+
attrs = response["project"]["artifact"]
|
574
|
+
self._load_manifest(attrs["currentManifest"]["file"]["directUrl"])
|
575
|
+
assert self._manifest is not None
|
576
|
+
return self._manifest
|
577
|
+
|
578
|
+
@property
|
579
|
+
def digest(self) -> str:
|
580
|
+
"""The logical digest of the artifact.
|
581
|
+
|
582
|
+
The digest is the checksum of the artifact's contents. If an artifact has the
|
583
|
+
same digest as the current `latest` version, then `log_artifact` is a no-op.
|
584
|
+
"""
|
585
|
+
return self.manifest.digest()
|
586
|
+
|
587
|
+
@property
def size(self) -> int:
    """The total size of the artifact in bytes.

    Includes any references tracked by this artifact.
    """
    # Entries without a known size (e.g. some references) contribute nothing.
    return sum(
        entry.size
        for entry in self.manifest.entries.values()
        if entry.size is not None
    )
|
598
|
+
|
599
|
+
@property
|
600
|
+
def commit_hash(self) -> str:
|
601
|
+
"""The hash returned when this artifact was committed."""
|
602
|
+
if self._state == ArtifactState.PENDING:
|
603
|
+
raise ArtifactNotLoggedError(self, "commit_hash")
|
604
|
+
assert self._commit_hash is not None
|
605
|
+
return self._commit_hash
|
606
|
+
|
607
|
+
@property
|
608
|
+
def file_count(self) -> int:
|
609
|
+
"""The number of files (including references)."""
|
610
|
+
if self._state == ArtifactState.PENDING:
|
611
|
+
raise ArtifactNotLoggedError(self, "file_count")
|
612
|
+
assert self._file_count is not None
|
613
|
+
return self._file_count
|
614
|
+
|
615
|
+
@property
|
616
|
+
def created_at(self) -> str:
|
617
|
+
"""The time at which the artifact was created."""
|
618
|
+
if self._state == ArtifactState.PENDING:
|
619
|
+
raise ArtifactNotLoggedError(self, "created_at")
|
620
|
+
assert self._created_at is not None
|
621
|
+
return self._created_at
|
622
|
+
|
623
|
+
@property
def updated_at(self) -> str:
    """The time at which the artifact was last updated."""
    if self._state == ArtifactState.PENDING:
        # Bug fix: this previously raised with "created_at" (copy-paste from
        # the created_at property), mislabeling the attribute in the error.
        raise ArtifactNotLoggedError(self, "updated_at")
    assert self._created_at is not None
    # Artifacts never modified after creation fall back to their creation time.
    return self._updated_at or self._created_at
|
630
|
+
|
631
|
+
# State management.
|
632
|
+
|
633
|
+
def finalize(self) -> None:
|
634
|
+
"""Mark this artifact as final, disallowing further modifications.
|
635
|
+
|
636
|
+
This happens automatically when calling `log_artifact`.
|
637
|
+
"""
|
638
|
+
self._final = True
|
639
|
+
|
640
|
+
def _ensure_can_add(self) -> None:
|
641
|
+
if self._final:
|
642
|
+
raise ArtifactFinalizedError(artifact=self)
|
643
|
+
|
644
|
+
def is_draft(self) -> bool:
    """Whether the artifact is a draft, i.e. it hasn't been saved yet."""
    # Enum members compare by identity, so `is` is equivalent to `==` here.
    return self._state is ArtifactState.PENDING
|
647
|
+
|
648
|
+
def _is_draft_save_started(self) -> bool:
|
649
|
+
return self._save_future is not None
|
650
|
+
|
651
|
+
    def save(
        self,
        project: Optional[str] = None,
        settings: Optional["wandb.wandb_sdk.wandb_settings.Settings"] = None,
    ) -> None:
        """Persist any changes made to the artifact.

        If currently in a run, that run will log this artifact. If not currently in a
        run, a run of type "auto" will be created to track this artifact.

        Arguments:
            project: A project to use for the artifact in the case that a run is not
                already in context
            settings: A settings object to use when initializing an automatic run. Most
                commonly used in testing harness.
        """
        # Already-logged artifacts can't be re-uploaded; only push metadata
        # changes (description, aliases, ...) to the backend.
        if self._state != ArtifactState.PENDING:
            return self._update()

        if self._incremental:
            with telemetry.context() as tel:
                tel.feature.artifact_incremental = True

        if wandb.run is None:
            if settings is None:
                settings = wandb.Settings(silent="true")
            with wandb.init(project=project, job_type="auto", settings=settings) as run:
                # redoing this here because in this branch we know we didn't
                # have the run at the beginning of the method
                if self._incremental:
                    with telemetry.context(run=run) as tel:
                        tel.feature.artifact_incremental = True
                run.log_artifact(self)
        else:
            wandb.run.log_artifact(self)
|
686
|
+
|
687
|
+
    def _set_save_future(
        self, save_future: "MessageFuture", client: RetryingClient
    ) -> None:
        # Record the in-flight save so wait() can block on it, and keep the
        # client needed to re-fetch the artifact's attributes afterwards.
        self._save_future = save_future
        self._client = client
|
692
|
+
|
693
|
+
    def wait(self, timeout: Optional[int] = None) -> "Artifact":
        """Wait for this artifact to finish logging, if needed.

        Arguments:
            timeout: Wait up to this long.

        Raises:
            ArtifactNotLoggedError: if no save was ever started for this draft.
            WaitTimeoutError: if the upload doesn't complete within `timeout`.
            ValueError: if the backend reported an error for the upload.
        """
        if self._state == ArtifactState.PENDING:
            if self._save_future is None:
                raise ArtifactNotLoggedError(self, "wait")
            # A falsy result means the future timed out rather than resolved.
            result = self._save_future.get(timeout)
            if not result:
                raise WaitTimeoutError(
                    "Artifact upload wait timed out, failed to fetch Artifact response"
                )
            response = result.response.log_artifact_response
            if response.error_message:
                raise ValueError(response.error_message)
            self._populate_after_save(response.artifact_id)
        return self
|
712
|
+
|
713
|
+
    def _populate_after_save(self, artifact_id: str) -> None:
        # Re-fetch the saved artifact's attributes so this object reflects the
        # backend's view (final name/version, aliases, state, manifest, ...).
        query = gql(
            """
            query ArtifactByIDShort($id: ID!) {
                artifact(id: $id) {
                    artifactSequence {
                        project {
                            entityName
                            name
                        }
                        name
                    }
                    versionIndex
                    aliases {
                        artifactCollectionName
                        alias
                    }
                    state
                    currentManifest {
                        file {
                            directUrl
                        }
                    }
                    commitHash
                    fileCount
                    createdAt
                    updatedAt
                }
            }
            """
        )
        assert self._client is not None
        response = self._client.execute(
            query,
            variable_values={"id": artifact_id},
        )
        attrs = response.get("artifact")
        if attrs is None:
            raise ValueError(f"Unable to fetch artifact with id {artifact_id}")
        self._id = artifact_id
        self._entity = attrs["artifactSequence"]["project"]["entityName"]
        self._project = attrs["artifactSequence"]["project"]["name"]
        self._name = "{}:v{}".format(
            attrs["artifactSequence"]["name"], attrs["versionIndex"]
        )
        self._version = "v{}".format(attrs["versionIndex"])
        # A freshly-saved artifact is by definition its own source.
        self._source_entity = self._entity
        self._source_project = self._project
        self._source_name = self._name
        self._source_version = self._version
        # Keep only aliases belonging to this collection, excluding the
        # auto-generated version-index aliases (e.g. "v3").
        self._aliases = [
            alias["alias"]
            for alias in attrs.get("aliases", [])
            if alias["artifactCollectionName"] == self._name.split(":")[0]
            and not util.alias_is_version_index(alias["alias"])
        ]
        self._state = ArtifactState(attrs["state"])
        # Download and parse the manifest eagerly so entries are available.
        with requests.get(attrs["currentManifest"]["file"]["directUrl"]) as request:
            request.raise_for_status()
            self._manifest = ArtifactManifest.from_manifest_json(
                json.loads(util.ensure_text(request.content))
            )
        self._commit_hash = attrs["commitHash"]
        self._file_count = attrs["fileCount"]
        self._created_at = attrs["createdAt"]
        self._updated_at = attrs["updatedAt"]
|
779
|
+
|
780
|
+
    @normalize_exceptions
    def _update(self) -> None:
        """Persists artifact changes to the wandb backend."""
        aliases = None
        # Probe the backend: newer servers expose AddAliasesInput and accept
        # dedicated alias add/delete mutations.
        introspect_query = gql(
            """
            query ProbeServerAddAliasesInput {
                AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
                    name
                    inputFields {
                        name
                    }
                }
            }
            """
        )
        assert self._client is not None
        response = self._client.execute(introspect_query)
        if response.get("AddAliasesInputInfoType"):  # wandb backend version >= 0.13.0
            # Send only the alias delta relative to the last saved state.
            aliases_to_add = set(self._aliases) - set(self._saved_aliases)
            aliases_to_delete = set(self._saved_aliases) - set(self._aliases)
            if len(aliases_to_add) > 0:
                add_mutation = gql(
                    """
                    mutation addAliases(
                        $artifactID: ID!,
                        $aliases: [ArtifactCollectionAliasInput!]!,
                    ) {
                        addAliases(
                            input: {artifactID: $artifactID, aliases: $aliases}
                        ) {
                            success
                        }
                    }
                    """
                )
                assert self._client is not None
                self._client.execute(
                    add_mutation,
                    variable_values={
                        "artifactID": self.id,
                        "aliases": [
                            {
                                "entityName": self._entity,
                                "projectName": self._project,
                                "artifactCollectionName": self._name.split(":")[0],
                                "alias": alias,
                            }
                            for alias in aliases_to_add
                        ],
                    },
                )
            if len(aliases_to_delete) > 0:
                delete_mutation = gql(
                    """
                    mutation deleteAliases(
                        $artifactID: ID!,
                        $aliases: [ArtifactCollectionAliasInput!]!,
                    ) {
                        deleteAliases(
                            input: {artifactID: $artifactID, aliases: $aliases}
                        ) {
                            success
                        }
                    }
                    """
                )
                assert self._client is not None
                self._client.execute(
                    delete_mutation,
                    variable_values={
                        "artifactID": self.id,
                        "aliases": [
                            {
                                "entityName": self._entity,
                                "projectName": self._project,
                                "artifactCollectionName": self._name.split(":")[0],
                                "alias": alias,
                            }
                            for alias in aliases_to_delete
                        ],
                    },
                )
            self._saved_aliases = copy(self._aliases)
        else:  # wandb backend version < 0.13.0
            # Older backends take the full alias list on updateArtifact; leave
            # `aliases` as None on new backends so it is omitted below.
            aliases = [
                {
                    "artifactCollectionName": self._name.split(":")[0],
                    "alias": alias,
                }
                for alias in self._aliases
            ]

        mutation = gql(
            """
            mutation updateArtifact(
                $artifactID: ID!,
                $description: String,
                $metadata: JSONString,
                $aliases: [ArtifactAliasInput!]
            ) {
                updateArtifact(
                    input: {
                        artifactID: $artifactID,
                        description: $description,
                        metadata: $metadata,
                        aliases: $aliases
                    }
                ) {
                    artifact {
                        id
                    }
                }
            }
            """
        )
        assert self._client is not None
        self._client.execute(
            mutation,
            variable_values={
                "artifactID": self.id,
                "description": self.description,
                "metadata": util.json_dumps_safer(self.metadata),
                "aliases": aliases,
            },
        )
|
906
|
+
|
907
|
+
# Adding, removing, getting entries.
|
908
|
+
|
909
|
+
    def __getitem__(self, name: str) -> Optional[data_types.WBValue]:
        """Get the WBValue object located at the artifact relative `name`.

        Arguments:
            name: The artifact relative name to get

        Raises:
            ArtifactNotLoggedError: if the artifact isn't logged or the run is offline

        Examples:
            Basic usage:
            ```
            artifact = wandb.Artifact("my_table", type="dataset")
            table = wandb.Table(
                columns=["a", "b", "c"],
                data=[(i, i * 2, 2**i) for i in range(10)]
            )
            artifact["my_table"] = table

            wandb.log_artifact(artifact)
            ```

            Retrieving an object:
            ```
            artifact = wandb.use_artifact("my_table:latest")
            table = artifact["my_table"]
            ```
        """
        # Delegates to get(), which handles type-suffix name resolution.
        return self.get(name)
|
938
|
+
|
939
|
+
    def __setitem__(self, name: str, item: data_types.WBValue) -> ArtifactManifestEntry:
        """Add `item` to the artifact at path `name`.

        Arguments:
            name: The path within the artifact to add the object.
            item: The object to add.

        Returns:
            The added manifest entry

        Raises:
            ArtifactFinalizedError: if the artifact has already been finalized.

        Examples:
            Basic usage:
            ```
            artifact = wandb.Artifact("my_table", type="dataset")
            table = wandb.Table(
                columns=["a", "b", "c"],
                data=[(i, i * 2, 2**i) for i in range(10)]
            )
            artifact["my_table"] = table

            wandb.log_artifact(artifact)
            ```

            Retrieving an object:
            ```
            artifact = wandb.use_artifact("my_table:latest")
            table = artifact["my_table"]
            ```
        """
        # Delegates to add(); note the (item, name) argument order there.
        return self.add(item, name)
|
972
|
+
|
973
|
+
    @contextlib.contextmanager
    def new_file(
        self, name: str, mode: str = "w", encoding: Optional[str] = None
    ) -> Generator[IO, None, None]:
        """Open a new temporary file that will be automatically added to the artifact.

        Arguments:
            name: The name of the new file being added to the artifact.
            mode: The mode in which to open the new file.
            encoding: The encoding in which to open the new file.

        Returns:
            A new file object that can be written to. Upon closing, the file will be
            automatically added to the artifact.

        Raises:
            ArtifactFinalizedError: if the artifact has already been finalized.

        Examples:
            ```
            artifact = wandb.Artifact("my_data", type="dataset")
            with artifact.new_file("hello.txt") as f:
                f.write("hello!")
            wandb.log_artifact(artifact)
            ```
        """
        self._ensure_can_add()
        # Lazily create one shared temp directory for all new_file calls.
        if self._tmp_dir is None:
            self._tmp_dir = tempfile.TemporaryDirectory()
        path = os.path.join(self._tmp_dir.name, name.lstrip("/"))
        if os.path.exists(path):
            raise ValueError(f"File with name {name!r} already exists at {path!r}")

        filesystem.mkdir_exists_ok(os.path.dirname(path))
        try:
            with util.fsync_open(path, mode, encoding) as f:
                yield f
        except UnicodeEncodeError as e:
            termerror(
                f"Failed to open the provided file (UnicodeEncodeError: {e}). Please "
                f"provide the proper encoding."
            )
            raise e

        # Only reached if the caller's `with` body exited cleanly: the file is
        # added to the artifact after it is fully written and closed.
        self.add_file(path, name=name)
|
1018
|
+
|
1019
|
+
def add_file(
|
1020
|
+
self,
|
1021
|
+
local_path: str,
|
1022
|
+
name: Optional[str] = None,
|
1023
|
+
is_tmp: Optional[bool] = False,
|
1024
|
+
) -> ArtifactManifestEntry:
|
1025
|
+
"""Add a local file to the artifact.
|
1026
|
+
|
1027
|
+
Arguments:
|
1028
|
+
local_path: The path to the file being added.
|
1029
|
+
name: The path within the artifact to use for the file being added. Defaults
|
1030
|
+
to the basename of the file.
|
1031
|
+
is_tmp: If true, then the file is renamed deterministically to avoid
|
1032
|
+
collisions.
|
1033
|
+
|
1034
|
+
Returns:
|
1035
|
+
The added manifest entry
|
1036
|
+
|
1037
|
+
Raises:
|
1038
|
+
ArtifactFinalizedError: if the artifact has already been finalized
|
1039
|
+
|
1040
|
+
Examples:
|
1041
|
+
Add a file without an explicit name:
|
1042
|
+
```
|
1043
|
+
# Add as `file.txt'
|
1044
|
+
artifact.add_file("path/to/file.txt")
|
1045
|
+
```
|
1046
|
+
|
1047
|
+
Add a file with an explicit name:
|
1048
|
+
```
|
1049
|
+
# Add as 'new/path/file.txt'
|
1050
|
+
artifact.add_file("path/to/file.txt", name="new/path/file.txt")
|
1051
|
+
```
|
1052
|
+
"""
|
1053
|
+
self._ensure_can_add()
|
1054
|
+
if not os.path.isfile(local_path):
|
1055
|
+
raise ValueError("Path is not a file: %s" % local_path)
|
1056
|
+
|
1057
|
+
name = LogicalPath(name or os.path.basename(local_path))
|
1058
|
+
digest = md5_file_b64(local_path)
|
1059
|
+
|
1060
|
+
if is_tmp:
|
1061
|
+
file_path, file_name = os.path.split(name)
|
1062
|
+
file_name_parts = file_name.split(".")
|
1063
|
+
file_name_parts[0] = b64_to_hex_id(digest)[:20]
|
1064
|
+
name = os.path.join(file_path, ".".join(file_name_parts))
|
1065
|
+
|
1066
|
+
return self._add_local_file(name, local_path, digest=digest)
|
1067
|
+
|
1068
|
+
    def add_dir(self, local_path: str, name: Optional[str] = None) -> None:
        """Add a local directory to the artifact.

        Arguments:
            local_path: The path to the directory being added.
            name: The path within the artifact to use for the directory being added.
                Defaults to the root of the artifact.

        Raises:
            ArtifactFinalizedError: if the artifact has already been finalized

        Examples:
            Add a directory without an explicit name:
            ```
            # All files in `my_dir/` are added at the root of the artifact.
            artifact.add_dir("my_dir/")
            ```

            Add a directory and name it explicitly:
            ```
            # All files in `my_dir/` are added under `destination/`.
            artifact.add_dir("my_dir/", name="destination")
            ```
        """
        self._ensure_can_add()
        if not os.path.isdir(local_path):
            raise ValueError("Path is not a directory: %s" % local_path)

        termlog(
            "Adding directory to artifact (%s)... "
            % os.path.join(".", os.path.normpath(local_path)),
            newline=False,
        )
        start_time = time.time()

        # Collect (artifact-relative, filesystem) path pairs for every file.
        paths = []
        for dirpath, _, filenames in os.walk(local_path, followlinks=True):
            for fname in filenames:
                physical_path = os.path.join(dirpath, fname)
                logical_path = os.path.relpath(physical_path, start=local_path)
                if name is not None:
                    logical_path = os.path.join(name, logical_path)
                paths.append((logical_path, physical_path))

        def add_manifest_file(log_phy_path: Tuple[str, str]) -> None:
            logical_path, physical_path = log_phy_path
            self._add_local_file(logical_path, physical_path)

        # multiprocessing.dummy.Pool is a *thread* pool, so the hashing/copy in
        # _add_local_file overlaps across files.
        num_threads = 8
        pool = multiprocessing.dummy.Pool(num_threads)
        pool.map(add_manifest_file, paths)
        pool.close()
        pool.join()

        termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
|
1123
|
+
|
1124
|
+
def add_reference(
|
1125
|
+
self,
|
1126
|
+
uri: Union[ArtifactManifestEntry, str],
|
1127
|
+
name: Optional[StrPath] = None,
|
1128
|
+
checksum: bool = True,
|
1129
|
+
max_objects: Optional[int] = None,
|
1130
|
+
) -> Sequence[ArtifactManifestEntry]:
|
1131
|
+
"""Add a reference denoted by a URI to the artifact.
|
1132
|
+
|
1133
|
+
Unlike adding files or directories, references are NOT uploaded to W&B. However,
|
1134
|
+
artifact methods such as `download()` can be used regardless of whether the
|
1135
|
+
artifact contains references or uploaded files.
|
1136
|
+
|
1137
|
+
By default, W&B offers special handling for the following schemes:
|
1138
|
+
|
1139
|
+
- http(s): The size and digest of the file will be inferred by the
|
1140
|
+
`Content-Length` and the `ETag` response headers returned by the server.
|
1141
|
+
- s3: The checksum and size will be pulled from the object metadata. If bucket
|
1142
|
+
versioning is enabled, then the version ID is also tracked.
|
1143
|
+
- gs: The checksum and size will be pulled from the object metadata. If bucket
|
1144
|
+
versioning is enabled, then the version ID is also tracked.
|
1145
|
+
- https, domain matching *.blob.core.windows.net (Azure): The checksum and size
|
1146
|
+
will be pulled from the blob metadata. If storage account versioning is
|
1147
|
+
enabled, then the version ID is also tracked.
|
1148
|
+
- file: The checksum and size will be pulled from the file system. This scheme
|
1149
|
+
is useful if you have an NFS share or other externally mounted volume
|
1150
|
+
containing files you wish to track but not necessarily upload.
|
1151
|
+
|
1152
|
+
For any other scheme, the digest is just a hash of the URI and the size is left
|
1153
|
+
blank.
|
1154
|
+
|
1155
|
+
Arguments:
|
1156
|
+
uri: The URI path of the reference to add. Can be an object returned from
|
1157
|
+
Artifact.get_path to store a reference to another artifact's entry.
|
1158
|
+
name: The path within the artifact to place the contents of this reference
|
1159
|
+
checksum: Whether or not to checksum the resource(s) located at the
|
1160
|
+
reference URI. Checksumming is strongly recommended as it enables
|
1161
|
+
automatic integrity validation, however it can be disabled to speed up
|
1162
|
+
artifact creation. (default: True)
|
1163
|
+
max_objects: The maximum number of objects to consider when adding a
|
1164
|
+
reference that points to directory or bucket store prefix. For S3 and
|
1165
|
+
GCS, this limit is 10,000 by default but is uncapped for other URI
|
1166
|
+
schemes. (default: None)
|
1167
|
+
|
1168
|
+
Returns:
|
1169
|
+
The added manifest entries.
|
1170
|
+
|
1171
|
+
Raises:
|
1172
|
+
ArtifactFinalizedError: if the artifact has already been finalized.
|
1173
|
+
|
1174
|
+
Examples:
|
1175
|
+
Add an HTTP link:
|
1176
|
+
```python
|
1177
|
+
# Adds `file.txt` to the root of the artifact as a reference.
|
1178
|
+
artifact.add_reference("http://myserver.com/file.txt")
|
1179
|
+
```
|
1180
|
+
|
1181
|
+
Add an S3 prefix without an explicit name:
|
1182
|
+
```python
|
1183
|
+
# All objects under `prefix/` will be added at the root of the artifact.
|
1184
|
+
artifact.add_reference("s3://mybucket/prefix")
|
1185
|
+
```
|
1186
|
+
|
1187
|
+
Add a GCS prefix with an explicit name:
|
1188
|
+
```python
|
1189
|
+
# All objects under `prefix/` will be added under `path/` at the artifact
|
1190
|
+
# root.
|
1191
|
+
artifact.add_reference("gs://mybucket/prefix", name="path")
|
1192
|
+
```
|
1193
|
+
"""
|
1194
|
+
self._ensure_can_add()
|
1195
|
+
if name is not None:
|
1196
|
+
name = LogicalPath(name)
|
1197
|
+
|
1198
|
+
# This is a bit of a hack, we want to check if the uri is a of the type
|
1199
|
+
# ArtifactManifestEntry. If so, then recover the reference URL.
|
1200
|
+
if isinstance(uri, ArtifactManifestEntry):
|
1201
|
+
uri_str = uri.ref_url()
|
1202
|
+
elif isinstance(uri, str):
|
1203
|
+
uri_str = uri
|
1204
|
+
url = urlparse(str(uri_str))
|
1205
|
+
if not url.scheme:
|
1206
|
+
raise ValueError(
|
1207
|
+
"References must be URIs. To reference a local file, use file://"
|
1208
|
+
)
|
1209
|
+
|
1210
|
+
manifest_entries = self._storage_policy.store_reference(
|
1211
|
+
self,
|
1212
|
+
URIStr(uri_str),
|
1213
|
+
name=name,
|
1214
|
+
checksum=checksum,
|
1215
|
+
max_objects=max_objects,
|
1216
|
+
)
|
1217
|
+
for entry in manifest_entries:
|
1218
|
+
self.manifest.add_entry(entry)
|
1219
|
+
|
1220
|
+
return manifest_entries
|
1221
|
+
|
1222
|
+
    def add(self, obj: data_types.WBValue, name: StrPath) -> ArtifactManifestEntry:
        """Add wandb.WBValue `obj` to the artifact.

        Arguments:
            obj: The object to add. Currently support one of Bokeh, JoinedTable,
                PartitionedTable, Table, Classes, ImageMask, BoundingBoxes2D, Audio,
                Image, Video, Html, Object3D
            name: The path within the artifact to add the object.

        Returns:
            The added manifest entry

        Raises:
            ArtifactFinalizedError: if the artifact has already been finalized

        Examples:
            Basic usage:
            ```
            artifact = wandb.Artifact("my_table", type="dataset")
            table = wandb.Table(
                columns=["a", "b", "c"],
                data=[(i, i * 2, 2**i) for i in range(10)]
            )
            artifact.add(table, "my_table")

            wandb.log_artifact(artifact)
            ```

            Retrieve an object:
            ```
            artifact = wandb.use_artifact("my_table:latest")
            table = artifact.get("my_table")
            ```
        """
        self._ensure_can_add()
        name = LogicalPath(name)

        # This is a "hack" to automatically rename tables added to
        # the wandb /media/tables directory to their sha-based name.
        # TODO: figure out a more appropriate convention.
        is_tmp_name = name.startswith("media/tables")

        # Validate that the object is one of the correct wandb.Media types
        # TODO: move this to checking subclass of wandb.Media once all are
        # generally supported
        allowed_types = [
            data_types.Bokeh,
            data_types.JoinedTable,
            data_types.PartitionedTable,
            data_types.Table,
            data_types.Classes,
            data_types.ImageMask,
            data_types.BoundingBoxes2D,
            data_types.Audio,
            data_types.Image,
            data_types.Video,
            data_types.Html,
            data_types.Object3D,
            data_types.Molecule,
            data_types._SavedModel,
        ]

        if not any(isinstance(obj, t) for t in allowed_types):
            raise ValueError(
                "Found object of type {}, expected one of {}.".format(
                    obj.__class__, allowed_types
                )
            )

        # Adding the same object twice returns the entry from the first add.
        obj_id = id(obj)
        if obj_id in self._added_objs:
            return self._added_objs[obj_id][1]

        # If the object is coming from another artifact, save it as a reference
        ref_path = obj._get_artifact_entry_ref_url()
        if ref_path is not None:
            return self.add_reference(ref_path, type(obj).with_suffix(name))[0]

        val = obj.to_json(self)
        name = obj.with_suffix(name)
        entry = self.manifest.get_entry_by_path(name)
        if entry is not None:
            return entry

        def do_write(f: IO) -> None:
            import json

            # TODO: Do we need to open with utf-8 codec?
            f.write(json.dumps(val, sort_keys=True))

        if is_tmp_name:
            file_path = os.path.join(self._TMP_DIR.name, str(id(self)), name)
            folder_path, _ = os.path.split(file_path)
            if not os.path.exists(folder_path):
                os.makedirs(folder_path)
            with open(file_path, "w") as tmp_f:
                do_write(tmp_f)
        else:
            with self.new_file(name) as f:
                file_path = f.name
                do_write(f)

        # Note, we add the file from our temp directory.
        # It will be added again later on finalize, but succeed since
        # the checksum should match
        entry = self.add_file(file_path, name, is_tmp_name)
        # We store a reference to the obj so that its id doesn't get reused.
        self._added_objs[obj_id] = (obj, entry)
        if obj._artifact_target is None:
            obj._set_artifact_target(self, entry.path)

        if is_tmp_name:
            if os.path.exists(file_path):
                os.remove(file_path)

        return entry
|
1338
|
+
|
1339
|
+
    def _add_local_file(
        self, name: StrPath, path: StrPath, digest: Optional[B64MD5] = None
    ) -> ArtifactManifestEntry:
        # Snapshot the file into the staging area so later modifications to the
        # source don't change what gets uploaded; mark the copy read-only.
        with tempfile.NamedTemporaryFile(dir=get_staging_dir(), delete=False) as f:
            staging_path = f.name
            shutil.copyfile(path, staging_path)
            os.chmod(staging_path, 0o400)

        entry = ArtifactManifestEntry(
            path=name,
            digest=digest or md5_file_b64(staging_path),
            size=os.path.getsize(staging_path),
            local_path=staging_path,
        )

        self.manifest.add_entry(entry)
        # Remember the original local path so get_added_local_path_name works.
        self._added_local_paths[os.fspath(path)] = entry
        return entry
|
1357
|
+
|
1358
|
+
def remove(self, item: Union[StrPath, "ArtifactManifestEntry"]) -> None:
|
1359
|
+
"""Remove an item from the artifact.
|
1360
|
+
|
1361
|
+
Arguments:
|
1362
|
+
item: the item to remove. Can be a specific manifest entry or the name of an
|
1363
|
+
artifact-relative path. If the item matches a directory all items in
|
1364
|
+
that directory will be removed.
|
1365
|
+
|
1366
|
+
Raises:
|
1367
|
+
ArtifactFinalizedError: if the artifact has already been finalized.
|
1368
|
+
FileNotFoundError: if the item isn't found in the artifact.
|
1369
|
+
"""
|
1370
|
+
self._ensure_can_add()
|
1371
|
+
|
1372
|
+
if isinstance(item, ArtifactManifestEntry):
|
1373
|
+
self.manifest.remove_entry(item)
|
1374
|
+
return
|
1375
|
+
|
1376
|
+
path = str(PurePosixPath(item))
|
1377
|
+
entry = self.manifest.get_entry_by_path(path)
|
1378
|
+
if entry:
|
1379
|
+
self.manifest.remove_entry(entry)
|
1380
|
+
return
|
1381
|
+
|
1382
|
+
entries = self.manifest.get_entries_in_directory(path)
|
1383
|
+
if not entries:
|
1384
|
+
raise FileNotFoundError(f"No such file or directory: {path}")
|
1385
|
+
for entry in entries:
|
1386
|
+
self.manifest.remove_entry(entry)
|
1387
|
+
|
1388
|
+
    def get_path(self, name: StrPath) -> ArtifactManifestEntry:
        """Get the entry with the given name.

        Arguments:
            name: The artifact relative name to get

        Raises:
            ArtifactNotLoggedError: if the artifact isn't logged or the run is offline
            KeyError: if the artifact doesn't contain an entry with the given name

        Examples:
            Basic usage:
            ```
            # Run logging the artifact
            with wandb.init() as r:
                artifact = wandb.Artifact("my_dataset", type="dataset")
                artifact.add_file("path/to/file.txt")
                wandb.log_artifact(artifact)

            # Run using the artifact
            with wandb.init() as r:
                artifact = r.use_artifact("my_dataset:latest")
                path = artifact.get_path("file.txt")

                # Can now download 'file.txt' directly:
                path.download()
            ```
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "get_path")

        name = LogicalPath(name)
        # Fall back to suffixed-object lookup (e.g. "foo" -> "foo.table.json").
        entry = self.manifest.entries.get(name) or self._get_obj_entry(name)[0]
        if entry is None:
            raise KeyError("Path not contained in artifact: %s" % name)
        entry._parent_artifact = self
        return entry
|
1425
|
+
|
1426
|
+
    def get(self, name: str) -> Optional[data_types.WBValue]:
        """Get the WBValue object located at the artifact relative `name`.

        Arguments:
            name: The artifact relative name to get

        Returns:
            The deserialized object, or None if no entry matches `name`.

        Raises:
            ArtifactNotLoggedError: if the artifact isn't logged or the run is offline

        Examples:
            Basic usage:
            ```
            # Run logging the artifact
            with wandb.init() as r:
                artifact = wandb.Artifact("my_dataset", type="dataset")
                table = wandb.Table(
                    columns=["a", "b", "c"],
                    data=[(i, i * 2, 2**i) for i in range(10)]
                )
                artifact.add(table, "my_table")
                wandb.log_artifact(artifact)

            # Run using the artifact
            with wandb.init() as r:
                artifact = r.use_artifact("my_dataset:latest")
                table = artifact.get("my_table")
            ```
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "get")

        entry, wb_class = self._get_obj_entry(name)
        if entry is None or wb_class is None:
            return None

        # If the entry is a reference from another artifact, then get it directly from
        # that artifact.
        if entry._is_artifact_reference():
            assert self._client is not None
            artifact = entry._get_referenced_artifact(self._client)
            return artifact.get(util.uri_from_path(entry.ref))

        # Special case for wandb.Table. This is intended to be a short term
        # optimization. Since tables are likely to download many other assets in
        # artifact(s), we eagerly download the artifact using the parallelized
        # `artifact.download`. In the future, we should refactor the deserialization
        # pattern such that this special case is not needed.
        if wb_class == wandb.Table:
            self.download(recursive=True)

        # Get the ArtifactManifestEntry
        item = self.get_path(entry.path)
        item_path = item.download()

        # Load the object from the JSON blob
        result = None
        json_obj = {}
        with open(item_path) as file:
            json_obj = json.load(file)
        result = wb_class.from_json(json_obj, self)
        # Remember where this object came from so re-adding it becomes a
        # reference rather than a copy.
        result._set_artifact_source(self, name)
        return result
|
1488
|
+
|
1489
|
+
def get_added_local_path_name(self, local_path: str) -> Optional[str]:
|
1490
|
+
"""Get the artifact relative name of a file added by a local filesystem path.
|
1491
|
+
|
1492
|
+
Arguments:
|
1493
|
+
local_path: The local path to resolve into an artifact relative name.
|
1494
|
+
|
1495
|
+
Returns:
|
1496
|
+
The artifact relative name.
|
1497
|
+
|
1498
|
+
Examples:
|
1499
|
+
Basic usage:
|
1500
|
+
```
|
1501
|
+
artifact = wandb.Artifact("my_dataset", type="dataset")
|
1502
|
+
artifact.add_file("path/to/file.txt", name="artifact/path/file.txt")
|
1503
|
+
|
1504
|
+
# Returns `artifact/path/file.txt`:
|
1505
|
+
name = artifact.get_added_local_path_name("path/to/file.txt")
|
1506
|
+
```
|
1507
|
+
"""
|
1508
|
+
entry = self._added_local_paths.get(local_path, None)
|
1509
|
+
if entry is None:
|
1510
|
+
return None
|
1511
|
+
return entry.path
|
1512
|
+
|
1513
|
+
def _get_obj_entry(
    self, name: str
) -> Tuple[Optional["ArtifactManifestEntry"], Optional[Type[WBValue]]]:
    """Return an object entry by name, handling any type suffixes.

    When objects are added with `.add(obj, name)`, the stored name typically gains
    a suffix derived from the object's type during JSON serialization. To spare
    callers from appending `.THING.json` by hand, this probes every registered
    WBValue type's suffixed variant of `name` against the manifest.

    Arguments:
        name: name used when adding

    Returns:
        An (entry, type) pair for the first suffixed match, else (None, None).
    """
    for candidate_type in WBValue.type_mapping().values():
        suffixed_name = candidate_type.with_suffix(name)
        matching_entry = self.manifest.entries.get(suffixed_name)
        if matching_entry is not None:
            return matching_entry, candidate_type
    return None, None
|
1532
|
+
|
1533
|
+
# Downloading.
|
1534
|
+
|
1535
|
+
def download(
    self,
    root: Optional[str] = None,
    recursive: bool = False,
    allow_missing_references: bool = False,
) -> FilePathStr:
    """Download the contents of the artifact to the specified root directory.

    NOTE: Any existing files at `root` are left untouched. Explicitly delete
    root before calling `download` if you want the contents of `root` to exactly
    match the artifact.

    Arguments:
        root: The directory in which to download this artifact's files.
        recursive: If true, then all dependent artifacts are eagerly downloaded.
            Otherwise, the dependent artifacts are downloaded as needed.
        allow_missing_references: If true, a reference entry whose target cannot
            be found is skipped with a warning instead of raising.

    Returns:
        The path to the downloaded contents.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "download")

    root = root or self._default_root()
    self._add_download_root(root)

    # Announce (and later time) downloads that are large by file count or bytes.
    nfiles = len(self.manifest.entries)
    size = sum(e.size or 0 for e in self.manifest.entries.values())
    log = False
    if nfiles > 5000 or size > 50 * 1024 * 1024:
        log = True
        termlog(
            "Downloading large artifact {}, {:.2f}MB. {} files... ".format(
                self.name, size / (1024 * 1024), nfiles
            ),
        )
        start_time = datetime.datetime.now()
    download_logger = ArtifactDownloadLogger(nfiles=nfiles)

    def _download_entry(
        entry: ArtifactManifestEntry,
        api_key: Optional[str],
        cookies: Optional[Dict],
        headers: Optional[Dict],
    ) -> None:
        # Worker threads don't see the submitting thread's thread-local auth
        # settings; re-seed them so entry.download() can authenticate.
        _thread_local_api_settings.api_key = api_key
        _thread_local_api_settings.cookies = cookies
        _thread_local_api_settings.headers = headers

        try:
            entry.download(root)
        except FileNotFoundError as e:
            # Dangling reference entry; optionally tolerated.
            if allow_missing_references:
                wandb.termwarn(str(e))
                return
            raise
        download_logger.notify_downloaded()

    # Capture the submitting thread's auth settings once, at bind time.
    download_entry = partial(
        _download_entry,
        api_key=_thread_local_api_settings.api_key,
        cookies=_thread_local_api_settings.cookies,
        headers=_thread_local_api_settings.headers,
    )

    # NOTE: worker count is hard-coded to 64 threads.
    with concurrent.futures.ThreadPoolExecutor(64) as executor:
        active_futures = set()
        has_next_page = True
        cursor = None
        while has_next_page:
            # Page through file metadata (name + direct URL), 5000 per page.
            attrs = self._fetch_file_urls(cursor)
            has_next_page = attrs["pageInfo"]["hasNextPage"]
            cursor = attrs["pageInfo"]["endCursor"]
            for edge in attrs["edges"]:
                entry = self.get_path(edge["node"]["name"])
                entry._download_url = edge["node"]["directUrl"]
                active_futures.add(executor.submit(download_entry, entry))
            # Wait for download threads to catch up.
            max_backlog = 5000
            if len(active_futures) > max_backlog:
                for future in concurrent.futures.as_completed(active_futures):
                    future.result()  # check for errors
                    active_futures.remove(future)
                    if len(active_futures) <= max_backlog:
                        break
        # Check for errors.
        for future in concurrent.futures.as_completed(active_futures):
            future.result()

    if recursive:
        # Eagerly pull every artifact this one references.
        for dependent_artifact in self._dependent_artifacts:
            dependent_artifact.download()

    if log:
        now = datetime.datetime.now()
        delta = abs((now - start_time).total_seconds())
        hours = int(delta // 3600)
        minutes = int((delta - hours * 3600) // 60)
        seconds = delta - hours * 3600 - minutes * 60
        termlog(
            f"Done. {hours}:{minutes}:{seconds:.1f}",
            prefix=False,
        )
    return FilePathStr(root)
|
1642
|
+
|
1643
|
+
@retry.retriable(
    retry_timedelta=datetime.timedelta(minutes=3),
    # Fix: the original passed `(requests.RequestException)`, which is just the
    # bare class (no trailing comma, so not a tuple). Make the 1-tuple explicit
    # so the retryable-exceptions contract is unambiguous.
    retryable_exceptions=(requests.RequestException,),
)
def _fetch_file_urls(self, cursor: Optional[str]) -> Any:
    """Fetch one page (up to 5000) of this artifact's file names and direct URLs.

    Arguments:
        cursor: GraphQL pagination cursor; None fetches the first page.

    Returns:
        The `files` connection dict containing `pageInfo` and `edges`.
    """
    query = gql(
        """
        query ArtifactFileURLs($id: ID!, $cursor: String) {
            artifact(id: $id) {
                files(after: $cursor, first: 5000) {
                    pageInfo {
                        hasNextPage
                        endCursor
                    }
                    edges {
                        node {
                            name
                            directUrl
                        }
                    }
                }
            }
        }
        """
    )
    assert self._client is not None
    response = self._client.execute(
        query,
        variable_values={"id": self.id, "cursor": cursor},
        timeout=60,
    )
    return response["artifact"]["files"]
|
1675
|
+
|
1676
|
+
def checkout(self, root: Optional[str] = None) -> str:
    """Replace the specified root directory with the contents of the artifact.

    WARNING: This will DELETE all files in `root` that are not included in the
    artifact.

    Arguments:
        root: The directory to replace with this artifact's files.

    Returns:
        The path to the checked out contents.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "checkout")

    target_dir = root or self._default_root(include_version=False)

    # Sweep the tree first: any on-disk file with no manifest entry is removed
    # so the final contents mirror the artifact exactly.
    for parent, _, filenames in os.walk(target_dir):
        for filename in filenames:
            on_disk = os.path.join(parent, filename)
            rel_name = os.path.relpath(on_disk, start=target_dir)
            try:
                self.get_path(rel_name)
            except KeyError:
                os.remove(on_disk)

    return self.download(root=target_dir)
|
1707
|
+
|
1708
|
+
def verify(self, root: Optional[str] = None) -> None:
    """Verify that the actual contents of an artifact match the manifest.

    All files in the directory are checksummed and the checksums are then
    cross-referenced against the artifact's manifest.

    NOTE: References are not verified.

    Arguments:
        root: The directory to verify. If None artifact will be downloaded to
            './artifacts/self.name/'

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
        ValueError: If the verification fails.
    """
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "verify")

    root = root or self._default_root()

    # First pass: every file on disk must correspond to a manifest entry.
    for dirpath, _, files in os.walk(root):
        for file in files:
            full_path = os.path.join(dirpath, file)
            artifact_path = os.path.relpath(full_path, start=root)
            try:
                self.get_path(artifact_path)
            except KeyError:
                raise ValueError(
                    "Found file {} which is not a member of artifact {}".format(
                        full_path, self.name
                    )
                )

    # Second pass: every non-reference entry's on-disk checksum must match.
    ref_count = 0
    for entry in self.manifest.entries.values():
        if entry.ref is None:
            if md5_file_b64(os.path.join(root, entry.path)) != entry.digest:
                raise ValueError("Digest mismatch for file: %s" % entry.path)
        else:
            ref_count += 1
    if ref_count > 0:
        # Consistency fix: emit the warning via termlog (used throughout this
        # module, e.g. in download()) rather than a bare print().
        termlog("Warning: skipped verification of %s refs" % ref_count)
|
1751
|
+
|
1752
|
+
def file(self, root: Optional[str] = None) -> StrPath:
    """Download a single file artifact to dir specified by the root.

    Arguments:
        root: The root directory in which to place the file. Defaults to
            './artifacts/self.name/'.

    Returns:
        The full path of the downloaded file.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
        ValueError: if the artifact contains more than one file
    """
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "file")

    if root is None:
        root = os.path.join(".", "artifacts", self.name)

    entries = self.manifest.entries
    if len(entries) > 1:
        raise ValueError(
            "This artifact contains more than one file, call `.download()` to get "
            'all files or call .get_path("filename").download()'
        )

    # Exactly one entry remains; download it into root.
    sole_name = next(iter(entries))
    return self.get_path(sole_name).download(root)
|
1779
|
+
|
1780
|
+
def files(
    self, names: Optional[List[str]] = None, per_page: int = 50
) -> ArtifactFiles:
    """Iterate over all files stored in this artifact.

    Arguments:
        names: The filename paths relative to the root of the artifact you wish to
            list.
        per_page: The number of files to return per request

    Returns:
        An iterator containing `File` objects

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    # File listing pages against the server, so the artifact must be logged.
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "files")
    return ArtifactFiles(self._client, self, names, per_page)
|
1800
|
+
|
1801
|
+
def _default_root(self, include_version: bool = True) -> str:
    """Compute the default download directory for this artifact.

    Arguments:
        include_version: When False, the ":vN"/alias suffix is stripped from
            the artifact name before building the path.

    Returns:
        A path under the configured artifact directory, sanitized for Windows.
    """
    if include_version:
        dir_name = self.name
    else:
        dir_name = self.name.split(":")[0]
    path = os.path.join(env.get_artifact_dir(), dir_name)
    # Windows forbids ":" inside path components; keep the drive colon but
    # replace any others (e.g. from "name:v3") with "-".
    if platform.system() == "Windows":
        drive, remainder = os.path.splitdrive(path)
        path = drive + remainder.replace(":", "-")
    return path
|
1808
|
+
|
1809
|
+
def _add_download_root(self, dir_path: str) -> None:
|
1810
|
+
self._download_roots.add(os.path.abspath(dir_path))
|
1811
|
+
|
1812
|
+
def _local_path_to_name(self, file_path: str) -> Optional[str]:
|
1813
|
+
"""Convert a local file path to a path entry in the artifact."""
|
1814
|
+
abs_file_path = os.path.abspath(file_path)
|
1815
|
+
abs_file_parts = abs_file_path.split(os.sep)
|
1816
|
+
for i in range(len(abs_file_parts) + 1):
|
1817
|
+
if os.path.join(os.sep, *abs_file_parts[:i]) in self._download_roots:
|
1818
|
+
return os.path.join(*abs_file_parts[i:])
|
1819
|
+
return None
|
1820
|
+
|
1821
|
+
# Others.
|
1822
|
+
|
1823
|
+
def delete(self, delete_aliases: bool = False) -> None:
    """Delete an artifact and its files.

    Arguments:
        delete_aliases: If true, deletes all aliases associated with the artifact.
            Otherwise, this raises an exception if the artifact has existing
            aliases.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged

    Examples:
        Delete all the "model" artifacts a run has logged:
        ```
        runs = api.runs(path="my_entity/my_project")
        for run in runs:
            for artifact in run.logged_artifacts():
                if artifact.type == "model":
                    artifact.delete(delete_aliases=True)
        ```
    """
    # Only a logged (server-side) artifact can be deleted.
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "delete")
    self._delete(delete_aliases)
|
1847
|
+
|
1848
|
+
@normalize_exceptions
def _delete(self, delete_aliases: bool = False) -> None:
    """Issue the DeleteArtifact mutation for this artifact's server-side id.

    Arguments:
        delete_aliases: Forwarded to the backend; when false the server is
            expected to reject deleting an artifact that still has aliases.
    """
    mutation = gql(
        """
        mutation DeleteArtifact($artifactID: ID!, $deleteAliases: Boolean) {
            deleteArtifact(input: {
                artifactID: $artifactID
                deleteAliases: $deleteAliases
            }) {
                artifact {
                    id
                }
            }
        }
        """
    )
    assert self._client is not None
    self._client.execute(
        mutation,
        variable_values={
            "artifactID": self.id,
            "deleteAliases": delete_aliases,
        },
    )
|
1872
|
+
|
1873
|
+
def link(self, target_path: str, aliases: Optional[List[str]] = None) -> None:
    """Link this artifact to a portfolio (a promoted collection of artifacts).

    Arguments:
        target_path: The path to the portfolio. It must take the form {portfolio},
            {project}/{portfolio} or {entity}/{project}/{portfolio}.
        aliases: A list of strings which uniquely identifies the artifact inside the
            specified portfolio.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    # Linking requires a logged (server-side) artifact.
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "link")
    self._link(target_path, aliases)
|
1888
|
+
|
1889
|
+
@normalize_exceptions
def _link(self, target_path: str, aliases: Optional[List[str]] = None) -> None:
    """Execute the LinkArtifact mutation to add this artifact to a portfolio.

    Arguments:
        target_path: {portfolio}, {project}/{portfolio}, or
            {entity}/{project}/{portfolio}; must not contain ":".
        aliases: Aliases to attach inside the portfolio (resolved via
            util._resolve_aliases).
    """
    if ":" in target_path:
        raise ValueError(
            f"target_path {target_path} cannot contain `:` because it is not an "
            f"alias."
        )

    portfolio, project, entity = util._parse_entity_project_item(target_path)
    aliases = util._resolve_aliases(aliases)

    # Fall back to the active run's entity/project, then this artifact's own.
    run_entity = wandb.run.entity if wandb.run else None
    run_project = wandb.run.project if wandb.run else None
    entity = entity or run_entity or self.entity
    project = project or run_project or self.project

    mutation = gql(
        """
        mutation LinkArtifact(
            $artifactID: ID!,
            $artifactPortfolioName: String!,
            $entityName: String!,
            $projectName: String!,
            $aliases: [ArtifactAliasInput!]
        ) {
            linkArtifact(
                input: {
                    artifactID: $artifactID,
                    artifactPortfolioName: $artifactPortfolioName,
                    entityName: $entityName,
                    projectName: $projectName,
                    aliases: $aliases
                }
            ) {
                versionIndex
            }
        }
        """
    )
    assert self._client is not None
    self._client.execute(
        mutation,
        variable_values={
            "artifactID": self.id,
            "artifactPortfolioName": portfolio,
            "entityName": entity,
            "projectName": project,
            "aliases": [
                {"alias": alias, "artifactCollectionName": portfolio}
                for alias in aliases
            ],
        },
    )
|
1942
|
+
|
1943
|
+
def used_by(self) -> List[Run]:
    """Get a list of the runs that have used this artifact.

    Returns:
        A list of `Run` objects, one per run that declared this artifact as input.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "used_by")

    query = gql(
        """
        query ArtifactUsedBy(
            $id: ID!,
        ) {
            artifact(id: $id) {
                usedBy {
                    edges {
                        node {
                            name
                            project {
                                name
                                entityName
                            }
                        }
                    }
                }
            }
        }
        """
    )
    assert self._client is not None
    response = self._client.execute(
        query,
        variable_values={"id": self.id},
    )
    # Robustness fix: `.get("artifact", {})` still yields None when the server
    # sends an explicit null, which would crash the chained .get(). Use the
    # `or {}` idiom this module already uses in `_expected_type`.
    edges = ((response.get("artifact") or {}).get("usedBy") or {}).get("edges", [])
    return [
        Run(
            self._client,
            edge["node"]["project"]["entityName"],
            edge["node"]["project"]["name"],
            edge["node"]["name"],
        )
        for edge in edges
    ]
|
1987
|
+
|
1988
|
+
def logged_by(self) -> Optional[Run]:
    """Get the run that first logged this artifact.

    Returns:
        The creating `Run`, or None when the artifact was not created by a run
        (e.g. createdBy is absent or not a Run).

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "logged_by")

    query = gql(
        """
        query ArtifactCreatedBy(
            $id: ID!
        ) {
            artifact(id: $id) {
                createdBy {
                    ... on Run {
                        name
                        project {
                            name
                            entityName
                        }
                    }
                }
            }
        }
        """
    )
    assert self._client is not None
    response = self._client.execute(
        query,
        variable_values={"id": self.id},
    )
    # Robustness fix: `.get("artifact", {})` does not guard against an explicit
    # null value; chain with `or {}` as done in `_expected_type`.
    creator = (response.get("artifact") or {}).get("createdBy") or {}
    if creator.get("name") is None:
        return None
    return Run(
        self._client,
        creator["project"]["entityName"],
        creator["project"]["name"],
        creator["name"],
    )
|
2030
|
+
|
2031
|
+
def json_encode(self) -> Dict[str, Any]:
    """Serialize this (logged) artifact into a JSON-friendly dict.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    # Serialization relies on server-side identity, so require a logged artifact.
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "json_encode")
    return util.artifact_to_json(self)
|
2035
|
+
|
2036
|
+
@staticmethod
def _expected_type(
    entity_name: str, project_name: str, name: str, client: RetryingClient
) -> Optional[str]:
    """Returns the expected type for a given artifact name and project."""
    query = gql(
        """
        query ArtifactType(
            $entityName: String,
            $projectName: String,
            $name: String!
        ) {
            project(name: $projectName, entityName: $entityName) {
                artifact(name: $name) {
                    artifactType {
                        name
                    }
                }
            }
        }
        """
    )
    # An unqualified name defaults to the latest version.
    if ":" not in name:
        name += ":latest"
    response = client.execute(
        query,
        variable_values={
            "entityName": entity_name,
            "projectName": project_name,
            "name": name,
        },
    )
    # Defensive chained lookups: any level may be missing or explicitly null,
    # in which case the result is None.
    return (
        ((response.get("project") or {}).get("artifact") or {}).get("artifactType")
        or {}
    ).get("name")
|
2072
|
+
|
2073
|
+
@staticmethod
|
2074
|
+
def _normalize_metadata(metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
2075
|
+
if metadata is None:
|
2076
|
+
return {}
|
2077
|
+
if not isinstance(metadata, dict):
|
2078
|
+
raise TypeError(f"metadata must be dict, not {type(metadata)}")
|
2079
|
+
return cast(
|
2080
|
+
Dict[str, Any], json.loads(json.dumps(util.json_friendly_val(metadata)))
|
2081
|
+
)
|
2082
|
+
|
2083
|
+
def _load_manifest(self, url: str) -> None:
    """Fetch and parse the manifest JSON at `url`, then resolve artifact refs.

    Side effects: sets `self._manifest` and adds every referenced artifact to
    `self._dependent_artifacts`.

    Arguments:
        url: Direct URL of the manifest JSON file.
    """
    with requests.get(url) as request:
        request.raise_for_status()
        self._manifest = ArtifactManifest.from_manifest_json(
            json.loads(util.ensure_text(request.content))
        )
    # Resolve wandb-artifact:// reference entries into dependent artifacts so
    # recursive downloads know what to fetch.
    for entry in self.manifest.entries.values():
        if entry._is_artifact_reference():
            assert self._client is not None
            dep_artifact = entry._get_referenced_artifact(self._client)
            self._dependent_artifacts.add(dep_artifact)
|
2094
|
+
|
2095
|
+
|
2096
|
+
class _ArtifactVersionType(WBType):
    # Media-type descriptor that lets the weave type system represent
    # Artifact instances as "artifactVersion".
    name = "artifactVersion"
    types = [Artifact]
|
2099
|
+
|
2100
|
+
|
2101
|
+
# Register at import time so Artifact values are recognized by the type system.
TypeRegistry.add(_ArtifactVersionType)
|