wandb 0.15.4__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/internal.py +3 -0
- wandb/apis/public.py +18 -20
- wandb/beta/workflows.py +5 -6
- wandb/cli/cli.py +27 -27
- wandb/data_types.py +2 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -4
- wandb/sdk/artifacts/__init__.py +0 -14
- wandb/sdk/artifacts/artifact.py +1757 -277
- wandb/sdk/artifacts/artifact_manifest_entry.py +26 -6
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/artifacts/artifacts_cache.py +7 -8
- wandb/sdk/artifacts/exceptions.py +4 -4
- wandb/sdk/artifacts/storage_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -6
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/http_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +35 -32
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +5 -9
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -2
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +2 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +24 -16
- wandb/sdk/artifacts/storage_policy.py +3 -3
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +2 -2
- wandb/sdk/data_types/base_types/media.py +5 -6
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +4 -5
- wandb/sdk/data_types/helper_types/classes.py +5 -8
- wandb/sdk/data_types/helper_types/image_mask.py +4 -5
- wandb/sdk/data_types/histogram.py +3 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +4 -5
- wandb/sdk/data_types/molecule.py +2 -2
- wandb/sdk/data_types/object_3d.py +3 -3
- wandb/sdk/data_types/plotly.py +2 -2
- wandb/sdk/data_types/saved_model.py +7 -8
- wandb/sdk/data_types/trace_tree.py +4 -4
- wandb/sdk/data_types/video.py +4 -4
- wandb/sdk/interface/interface.py +8 -10
- wandb/sdk/internal/file_stream.py +2 -3
- wandb/sdk/internal/internal_api.py +99 -4
- wandb/sdk/internal/job_builder.py +15 -7
- wandb/sdk/internal/sender.py +4 -0
- wandb/sdk/internal/settings_static.py +1 -0
- wandb/sdk/launch/_project_spec.py +9 -7
- wandb/sdk/launch/agent/agent.py +115 -58
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +16 -10
- wandb/sdk/launch/builder/docker_builder.py +9 -2
- wandb/sdk/launch/builder/kaniko_builder.py +108 -22
- wandb/sdk/launch/builder/noop.py +3 -1
- wandb/sdk/launch/environment/aws_environment.py +2 -1
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/github_reference.py +30 -18
- wandb/sdk/launch/launch.py +1 -1
- wandb/sdk/launch/loader.py +15 -0
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +38 -4
- wandb/sdk/launch/registry/google_artifact_registry.py +46 -7
- wandb/sdk/launch/runner/abstract.py +19 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +111 -47
- wandb/sdk/launch/runner/local_container.py +101 -48
- wandb/sdk/launch/runner/sagemaker_runner.py +59 -9
- wandb/sdk/launch/runner/vertex_runner.py +8 -4
- wandb/sdk/launch/sweeps/scheduler.py +102 -27
- wandb/sdk/launch/sweeps/utils.py +21 -0
- wandb/sdk/launch/utils.py +19 -7
- wandb/sdk/lib/_settings_toposort_generated.py +3 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +6 -9
- wandb/sdk/wandb_config.py +2 -4
- wandb/sdk/wandb_init.py +2 -0
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +32 -35
- wandb/sdk/wandb_settings.py +10 -3
- wandb/testing/relay.py +15 -2
- wandb/util.py +55 -23
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/METADATA +11 -8
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/RECORD +97 -97
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/artifacts/invalid_artifact.py +0 -23
- wandb/sdk/artifacts/lazy_artifact.py +0 -162
- wandb/sdk/artifacts/local_artifact.py +0 -719
- wandb/sdk/artifacts/public_artifact.py +0 -1188
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -1,1188 +0,0 @@
|
|
1
|
-
"""Public (saved) artifact."""
|
2
|
-
import contextlib
|
3
|
-
import datetime
|
4
|
-
import json
|
5
|
-
import os
|
6
|
-
import platform
|
7
|
-
import re
|
8
|
-
import urllib
|
9
|
-
from functools import partial
|
10
|
-
from typing import (
|
11
|
-
IO,
|
12
|
-
TYPE_CHECKING,
|
13
|
-
Any,
|
14
|
-
Dict,
|
15
|
-
Generator,
|
16
|
-
Iterable,
|
17
|
-
List,
|
18
|
-
Mapping,
|
19
|
-
Optional,
|
20
|
-
Sequence,
|
21
|
-
Set,
|
22
|
-
Tuple,
|
23
|
-
Type,
|
24
|
-
Union,
|
25
|
-
)
|
26
|
-
|
27
|
-
import requests
|
28
|
-
|
29
|
-
import wandb
|
30
|
-
from wandb import util
|
31
|
-
from wandb.apis.normalize import normalize_exceptions
|
32
|
-
from wandb.apis.public import ArtifactFiles, RetryingClient, Run
|
33
|
-
from wandb.data_types import WBValue
|
34
|
-
from wandb.env import get_artifact_dir
|
35
|
-
from wandb.errors.term import termlog
|
36
|
-
from wandb.sdk.artifacts.artifact import Artifact as ArtifactInterface
|
37
|
-
from wandb.sdk.artifacts.artifact_download_logger import ArtifactDownloadLogger
|
38
|
-
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
|
39
|
-
from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
|
40
|
-
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
41
|
-
from wandb.sdk.lib.hashutil import hex_to_b64_id, md5_file_b64
|
42
|
-
from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath
|
43
|
-
|
44
|
-
if TYPE_CHECKING:
|
45
|
-
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
46
|
-
from wandb.sdk.artifacts.local_artifact import Artifact as LocalArtifact
|
47
|
-
|
48
|
-
reset_path = util.vendor_setup()
|
49
|
-
|
50
|
-
from wandb_gql import gql # noqa: E402
|
51
|
-
|
52
|
-
reset_path()
|
53
|
-
|
54
|
-
ARTIFACT_FRAGMENT = """
|
55
|
-
fragment ArtifactFragment on Artifact {
|
56
|
-
id
|
57
|
-
digest
|
58
|
-
description
|
59
|
-
state
|
60
|
-
size
|
61
|
-
createdAt
|
62
|
-
updatedAt
|
63
|
-
labels
|
64
|
-
metadata
|
65
|
-
fileCount
|
66
|
-
versionIndex
|
67
|
-
aliases {
|
68
|
-
artifactCollectionName
|
69
|
-
alias
|
70
|
-
}
|
71
|
-
artifactSequence {
|
72
|
-
id
|
73
|
-
name
|
74
|
-
}
|
75
|
-
artifactType {
|
76
|
-
id
|
77
|
-
name
|
78
|
-
project {
|
79
|
-
name
|
80
|
-
entity {
|
81
|
-
name
|
82
|
-
}
|
83
|
-
}
|
84
|
-
}
|
85
|
-
commitHash
|
86
|
-
}
|
87
|
-
"""
|
88
|
-
|
89
|
-
|
90
|
-
class Artifact(ArtifactInterface):
|
91
|
-
"""A wandb Artifact.
|
92
|
-
|
93
|
-
An artifact that has been logged, including all its attributes, links to the runs
|
94
|
-
that use it, and a link to the run that logged it.
|
95
|
-
|
96
|
-
Examples:
|
97
|
-
Basic usage
|
98
|
-
```
|
99
|
-
api = wandb.Api()
|
100
|
-
artifact = api.artifact('project/artifact:alias')
|
101
|
-
|
102
|
-
# Get information about the artifact...
|
103
|
-
artifact.digest
|
104
|
-
artifact.aliases
|
105
|
-
```
|
106
|
-
|
107
|
-
Updating an artifact
|
108
|
-
```
|
109
|
-
artifact = api.artifact('project/artifact:alias')
|
110
|
-
|
111
|
-
# Update the description
|
112
|
-
artifact.description = 'My new description'
|
113
|
-
|
114
|
-
# Selectively update metadata keys
|
115
|
-
artifact.metadata["oldKey"] = "new value"
|
116
|
-
|
117
|
-
# Replace the metadata entirely
|
118
|
-
artifact.metadata = {"newKey": "new value"}
|
119
|
-
|
120
|
-
# Add an alias
|
121
|
-
artifact.aliases.append('best')
|
122
|
-
|
123
|
-
# Remove an alias
|
124
|
-
artifact.aliases.remove('latest')
|
125
|
-
|
126
|
-
# Completely replace the aliases
|
127
|
-
artifact.aliases = ['replaced']
|
128
|
-
|
129
|
-
# Persist all artifact modifications
|
130
|
-
artifact.save()
|
131
|
-
```
|
132
|
-
|
133
|
-
Artifact graph traversal
|
134
|
-
```
|
135
|
-
artifact = api.artifact('project/artifact:alias')
|
136
|
-
|
137
|
-
# Walk up and down the graph from an artifact:
|
138
|
-
producer_run = artifact.logged_by()
|
139
|
-
consumer_runs = artifact.used_by()
|
140
|
-
|
141
|
-
# Walk up and down the graph from a run:
|
142
|
-
logged_artifacts = run.logged_artifacts()
|
143
|
-
used_artifacts = run.used_artifacts()
|
144
|
-
```
|
145
|
-
|
146
|
-
Deleting an artifact
|
147
|
-
```
|
148
|
-
artifact = api.artifact('project/artifact:alias')
|
149
|
-
artifact.delete()
|
150
|
-
```
|
151
|
-
"""
|
152
|
-
|
153
|
-
QUERY = gql(
|
154
|
-
"""
|
155
|
-
query ArtifactWithCurrentManifest(
|
156
|
-
$id: ID!,
|
157
|
-
) {
|
158
|
-
artifact(id: $id) {
|
159
|
-
currentManifest {
|
160
|
-
id
|
161
|
-
file {
|
162
|
-
id
|
163
|
-
directUrl
|
164
|
-
}
|
165
|
-
}
|
166
|
-
...ArtifactFragment
|
167
|
-
}
|
168
|
-
}
|
169
|
-
%s
|
170
|
-
"""
|
171
|
-
% ARTIFACT_FRAGMENT
|
172
|
-
)
|
173
|
-
|
174
|
-
@classmethod
|
175
|
-
def from_id(cls, artifact_id: str, client: RetryingClient) -> Optional["Artifact"]:
|
176
|
-
artifact = get_artifacts_cache().get_artifact(artifact_id)
|
177
|
-
if artifact is not None:
|
178
|
-
assert isinstance(artifact, Artifact)
|
179
|
-
return artifact
|
180
|
-
response: Mapping[str, Any] = client.execute(
|
181
|
-
Artifact.QUERY,
|
182
|
-
variable_values={"id": artifact_id},
|
183
|
-
)
|
184
|
-
|
185
|
-
if response.get("artifact") is None:
|
186
|
-
return None
|
187
|
-
p = response.get("artifact", {}).get("artifactType", {}).get("project", {})
|
188
|
-
project = p.get("name") # defaults to None
|
189
|
-
entity = p.get("entity", {}).get("name")
|
190
|
-
name = "{}:v{}".format(
|
191
|
-
response["artifact"]["artifactSequence"]["name"],
|
192
|
-
response["artifact"]["versionIndex"],
|
193
|
-
)
|
194
|
-
artifact = cls(
|
195
|
-
client=client,
|
196
|
-
entity=entity,
|
197
|
-
project=project,
|
198
|
-
name=name,
|
199
|
-
attrs=response["artifact"],
|
200
|
-
)
|
201
|
-
index_file_url = response["artifact"]["currentManifest"]["file"]["directUrl"]
|
202
|
-
with requests.get(index_file_url) as req:
|
203
|
-
req.raise_for_status()
|
204
|
-
artifact._manifest = ArtifactManifest.from_manifest_json(
|
205
|
-
json.loads(util.ensure_text(req.content))
|
206
|
-
)
|
207
|
-
|
208
|
-
artifact._load_dependent_manifests()
|
209
|
-
|
210
|
-
return artifact
|
211
|
-
|
212
|
-
def __init__(
|
213
|
-
self,
|
214
|
-
client: RetryingClient,
|
215
|
-
entity: str,
|
216
|
-
project: str,
|
217
|
-
name: str,
|
218
|
-
attrs: Optional[Dict[str, Any]] = None,
|
219
|
-
) -> None:
|
220
|
-
self.client = client
|
221
|
-
self._entity = entity
|
222
|
-
self._project = project
|
223
|
-
self._name = name
|
224
|
-
self._artifact_collection_name = name.split(":")[0]
|
225
|
-
self._attrs = attrs or self._load()
|
226
|
-
|
227
|
-
# The entity and project above are taken from the passed-in artifact version path
|
228
|
-
# so if the user is pulling an artifact version from an artifact portfolio, the entity/project
|
229
|
-
# of that portfolio may be different than the birth entity/project of the artifact version.
|
230
|
-
self._source_project = (
|
231
|
-
self._attrs.get("artifactType", {}).get("project", {}).get("name")
|
232
|
-
)
|
233
|
-
self._source_entity = (
|
234
|
-
self._attrs.get("artifactType", {})
|
235
|
-
.get("project", {})
|
236
|
-
.get("entity", {})
|
237
|
-
.get("name")
|
238
|
-
)
|
239
|
-
self._metadata = json.loads(self._attrs.get("metadata") or "{}")
|
240
|
-
self._description = self._attrs.get("description", None)
|
241
|
-
self._source_name = "{}:v{}".format(
|
242
|
-
self._attrs["artifactSequence"]["name"], self._attrs.get("versionIndex")
|
243
|
-
)
|
244
|
-
self._source_version = "v{}".format(self._attrs.get("versionIndex"))
|
245
|
-
# We will only show aliases under the Collection this artifact version is fetched from
|
246
|
-
# _aliases will be a mutable copy on which the user can append or remove aliases
|
247
|
-
self._aliases: List[str] = [
|
248
|
-
a["alias"]
|
249
|
-
for a in self._attrs["aliases"]
|
250
|
-
if not re.match(r"^v\d+$", a["alias"])
|
251
|
-
and a["artifactCollectionName"] == self._artifact_collection_name
|
252
|
-
]
|
253
|
-
self._frozen_aliases: List[str] = [a for a in self._aliases]
|
254
|
-
self._manifest: Optional[ArtifactManifest] = None
|
255
|
-
self._is_downloaded: bool = False
|
256
|
-
self._dependent_artifacts: List["Artifact"] = []
|
257
|
-
self._download_roots: Set[str] = set()
|
258
|
-
get_artifacts_cache().store_artifact(self)
|
259
|
-
|
260
|
-
@property
|
261
|
-
def id(self) -> Optional[str]:
|
262
|
-
return self._attrs["id"]
|
263
|
-
|
264
|
-
@property
|
265
|
-
def entity(self) -> str:
|
266
|
-
return self._entity
|
267
|
-
|
268
|
-
@property
|
269
|
-
def project(self) -> str:
|
270
|
-
return self._project
|
271
|
-
|
272
|
-
@property
|
273
|
-
def name(self) -> str:
|
274
|
-
return self._name
|
275
|
-
|
276
|
-
@property
|
277
|
-
def version(self) -> str:
|
278
|
-
"""The artifact's version index under the given artifact collection.
|
279
|
-
|
280
|
-
A string with the format "v{number}".
|
281
|
-
"""
|
282
|
-
for a in self._attrs["aliases"]:
|
283
|
-
if a[
|
284
|
-
"artifactCollectionName"
|
285
|
-
] == self._artifact_collection_name and util.alias_is_version_index(
|
286
|
-
a["alias"]
|
287
|
-
):
|
288
|
-
return a["alias"]
|
289
|
-
raise NotImplementedError
|
290
|
-
|
291
|
-
@property
|
292
|
-
def source_entity(self) -> str:
|
293
|
-
return self._source_entity
|
294
|
-
|
295
|
-
@property
|
296
|
-
def source_project(self) -> str:
|
297
|
-
return self._source_project
|
298
|
-
|
299
|
-
@property
|
300
|
-
def source_name(self) -> str:
|
301
|
-
return self._source_name
|
302
|
-
|
303
|
-
@property
|
304
|
-
def source_version(self) -> str:
|
305
|
-
"""The artifact's version index under its parent artifact collection.
|
306
|
-
|
307
|
-
A string with the format "v{number}".
|
308
|
-
"""
|
309
|
-
return self._source_version
|
310
|
-
|
311
|
-
@property
|
312
|
-
def file_count(self) -> int:
|
313
|
-
return self._attrs["fileCount"]
|
314
|
-
|
315
|
-
@property
|
316
|
-
def metadata(self) -> dict:
|
317
|
-
return self._metadata
|
318
|
-
|
319
|
-
@metadata.setter
|
320
|
-
def metadata(self, metadata: dict) -> None:
|
321
|
-
self._metadata = metadata
|
322
|
-
|
323
|
-
@property
|
324
|
-
def manifest(self) -> ArtifactManifest:
|
325
|
-
return self._load_manifest()
|
326
|
-
|
327
|
-
@property
|
328
|
-
def digest(self) -> str:
|
329
|
-
return self._attrs["digest"]
|
330
|
-
|
331
|
-
@property
|
332
|
-
def state(self) -> str:
|
333
|
-
return self._attrs["state"]
|
334
|
-
|
335
|
-
@property
|
336
|
-
def size(self) -> int:
|
337
|
-
return self._attrs["size"]
|
338
|
-
|
339
|
-
@property
|
340
|
-
def created_at(self) -> str:
|
341
|
-
"""The time at which the artifact was created."""
|
342
|
-
return self._attrs["createdAt"]
|
343
|
-
|
344
|
-
@property
|
345
|
-
def updated_at(self) -> str:
|
346
|
-
"""The time at which the artifact was last updated."""
|
347
|
-
return self._attrs["updatedAt"] or self._attrs["createdAt"]
|
348
|
-
|
349
|
-
@property
|
350
|
-
def description(self) -> Optional[str]:
|
351
|
-
return self._description
|
352
|
-
|
353
|
-
@description.setter
|
354
|
-
def description(self, desc: Optional[str]) -> None:
|
355
|
-
self._description = desc
|
356
|
-
|
357
|
-
@property
|
358
|
-
def type(self) -> str:
|
359
|
-
return self._attrs["artifactType"]["name"]
|
360
|
-
|
361
|
-
@property
|
362
|
-
def commit_hash(self) -> str:
|
363
|
-
return self._attrs.get("commitHash", "")
|
364
|
-
|
365
|
-
@property
|
366
|
-
def aliases(self) -> List[str]:
|
367
|
-
"""The aliases associated with this artifact.
|
368
|
-
|
369
|
-
Returns:
|
370
|
-
List[str]: The aliases associated with this artifact.
|
371
|
-
|
372
|
-
"""
|
373
|
-
return self._aliases
|
374
|
-
|
375
|
-
@aliases.setter
|
376
|
-
def aliases(self, aliases: List[str]) -> None:
|
377
|
-
for alias in aliases:
|
378
|
-
if any(char in alias for char in ["/", ":"]):
|
379
|
-
raise ValueError(
|
380
|
-
'Invalid alias "%s", slashes and colons are disallowed' % alias
|
381
|
-
)
|
382
|
-
self._aliases = aliases
|
383
|
-
|
384
|
-
@staticmethod
|
385
|
-
def expected_type(
|
386
|
-
client: RetryingClient, name: str, entity_name: str, project_name: str
|
387
|
-
) -> Optional[str]:
|
388
|
-
"""Returns the expected type for a given artifact name and project."""
|
389
|
-
query = gql(
|
390
|
-
"""
|
391
|
-
query ArtifactType(
|
392
|
-
$entityName: String,
|
393
|
-
$projectName: String,
|
394
|
-
$name: String!
|
395
|
-
) {
|
396
|
-
project(name: $projectName, entityName: $entityName) {
|
397
|
-
artifact(name: $name) {
|
398
|
-
artifactType {
|
399
|
-
name
|
400
|
-
}
|
401
|
-
}
|
402
|
-
}
|
403
|
-
}
|
404
|
-
"""
|
405
|
-
)
|
406
|
-
if ":" not in name:
|
407
|
-
name += ":latest"
|
408
|
-
|
409
|
-
response = client.execute(
|
410
|
-
query,
|
411
|
-
variable_values={
|
412
|
-
"entityName": entity_name,
|
413
|
-
"projectName": project_name,
|
414
|
-
"name": name,
|
415
|
-
},
|
416
|
-
)
|
417
|
-
|
418
|
-
project = response.get("project")
|
419
|
-
if project is not None:
|
420
|
-
artifact = project.get("artifact")
|
421
|
-
if artifact is not None:
|
422
|
-
artifact_type = artifact.get("artifactType")
|
423
|
-
if artifact_type is not None:
|
424
|
-
return artifact_type.get("name")
|
425
|
-
|
426
|
-
return None
|
427
|
-
|
428
|
-
@property
|
429
|
-
def _use_as(self) -> Optional[str]:
|
430
|
-
return self._attrs.get("_use_as")
|
431
|
-
|
432
|
-
@_use_as.setter
|
433
|
-
def _use_as(self, use_as: Optional[str]) -> None:
|
434
|
-
self._attrs["_use_as"] = use_as
|
435
|
-
|
436
|
-
@normalize_exceptions
|
437
|
-
def link(self, target_path: str, aliases: Optional[List[str]] = None) -> None:
|
438
|
-
if ":" in target_path:
|
439
|
-
raise ValueError(
|
440
|
-
f"target_path {target_path} cannot contain `:` because it is not an alias."
|
441
|
-
)
|
442
|
-
|
443
|
-
portfolio, project, entity = util._parse_entity_project_item(target_path)
|
444
|
-
aliases = util._resolve_aliases(aliases)
|
445
|
-
|
446
|
-
run_entity = wandb.run.entity if wandb.run else None
|
447
|
-
run_project = wandb.run.project if wandb.run else None
|
448
|
-
entity = entity or run_entity or self.entity
|
449
|
-
project = project or run_project or self.project
|
450
|
-
|
451
|
-
mutation = gql(
|
452
|
-
"""
|
453
|
-
mutation LinkArtifact($artifactID: ID!, $artifactPortfolioName: String!, $entityName: String!, $projectName: String!, $aliases: [ArtifactAliasInput!]) {
|
454
|
-
linkArtifact(input: {artifactID: $artifactID, artifactPortfolioName: $artifactPortfolioName,
|
455
|
-
entityName: $entityName,
|
456
|
-
projectName: $projectName,
|
457
|
-
aliases: $aliases
|
458
|
-
}) {
|
459
|
-
versionIndex
|
460
|
-
}
|
461
|
-
}
|
462
|
-
"""
|
463
|
-
)
|
464
|
-
self.client.execute(
|
465
|
-
mutation,
|
466
|
-
variable_values={
|
467
|
-
"artifactID": self.id,
|
468
|
-
"artifactPortfolioName": portfolio,
|
469
|
-
"entityName": entity,
|
470
|
-
"projectName": project,
|
471
|
-
"aliases": [
|
472
|
-
{"alias": alias, "artifactCollectionName": portfolio}
|
473
|
-
for alias in aliases
|
474
|
-
],
|
475
|
-
},
|
476
|
-
)
|
477
|
-
|
478
|
-
@normalize_exceptions
|
479
|
-
def delete(self, delete_aliases: bool = False) -> None:
|
480
|
-
"""Delete an artifact and its files.
|
481
|
-
|
482
|
-
Examples:
|
483
|
-
Delete all the "model" artifacts a run has logged:
|
484
|
-
```
|
485
|
-
runs = api.runs(path="my_entity/my_project")
|
486
|
-
for run in runs:
|
487
|
-
for artifact in run.logged_artifacts():
|
488
|
-
if artifact.type == "model":
|
489
|
-
artifact.delete(delete_aliases=True)
|
490
|
-
```
|
491
|
-
|
492
|
-
Arguments:
|
493
|
-
delete_aliases: (bool) If true, deletes all aliases associated with the artifact.
|
494
|
-
Otherwise, this raises an exception if the artifact has existing aliases.
|
495
|
-
"""
|
496
|
-
mutation = gql(
|
497
|
-
"""
|
498
|
-
mutation DeleteArtifact($artifactID: ID!, $deleteAliases: Boolean) {
|
499
|
-
deleteArtifact(input: {
|
500
|
-
artifactID: $artifactID
|
501
|
-
deleteAliases: $deleteAliases
|
502
|
-
}) {
|
503
|
-
artifact {
|
504
|
-
id
|
505
|
-
}
|
506
|
-
}
|
507
|
-
}
|
508
|
-
"""
|
509
|
-
)
|
510
|
-
self.client.execute(
|
511
|
-
mutation,
|
512
|
-
variable_values={
|
513
|
-
"artifactID": self.id,
|
514
|
-
"deleteAliases": delete_aliases,
|
515
|
-
},
|
516
|
-
)
|
517
|
-
|
518
|
-
@contextlib.contextmanager
|
519
|
-
def new_file(
|
520
|
-
self, name: str, mode: str = "w", encoding: Optional[str] = None
|
521
|
-
) -> Generator[IO, None, None]:
|
522
|
-
raise ValueError("Cannot add files to an artifact once it has been saved")
|
523
|
-
|
524
|
-
def add_file(
|
525
|
-
self,
|
526
|
-
local_path: str,
|
527
|
-
name: Optional[str] = None,
|
528
|
-
is_tmp: Optional[bool] = False,
|
529
|
-
) -> "ArtifactManifestEntry":
|
530
|
-
raise ValueError("Cannot add files to an artifact once it has been saved")
|
531
|
-
|
532
|
-
def add_dir(self, local_path: str, name: Optional[str] = None) -> None:
|
533
|
-
raise ValueError("Cannot add files to an artifact once it has been saved")
|
534
|
-
|
535
|
-
def add_reference(
|
536
|
-
self,
|
537
|
-
uri: Union["ArtifactManifestEntry", str],
|
538
|
-
name: Optional[StrPath] = None,
|
539
|
-
checksum: bool = True,
|
540
|
-
max_objects: Optional[int] = None,
|
541
|
-
) -> Sequence["ArtifactManifestEntry"]:
|
542
|
-
raise ValueError("Cannot add files to an artifact once it has been saved")
|
543
|
-
|
544
|
-
def add(self, obj: WBValue, name: StrPath) -> "ArtifactManifestEntry":
|
545
|
-
raise ValueError("Cannot add files to an artifact once it has been saved")
|
546
|
-
|
547
|
-
def remove(self, item: Union[str, "os.PathLike", "ArtifactManifestEntry"]) -> None:
|
548
|
-
raise ValueError("Cannot remove files from an artifact once it has been saved")
|
549
|
-
|
550
|
-
def _add_download_root(self, dir_path: str) -> None:
|
551
|
-
"""Make `dir_path` a root directory for this artifact."""
|
552
|
-
self._download_roots.add(os.path.abspath(dir_path))
|
553
|
-
|
554
|
-
def _is_download_root(self, dir_path: str) -> bool:
|
555
|
-
"""Determine if `dir_path` is a root directory for this artifact."""
|
556
|
-
return dir_path in self._download_roots
|
557
|
-
|
558
|
-
def _local_path_to_name(self, file_path: str) -> Optional[str]:
|
559
|
-
"""Convert a local file path to a path entry in the artifact."""
|
560
|
-
abs_file_path = os.path.abspath(file_path)
|
561
|
-
abs_file_parts = abs_file_path.split(os.sep)
|
562
|
-
for i in range(len(abs_file_parts) + 1):
|
563
|
-
if self._is_download_root(os.path.join(os.sep, *abs_file_parts[:i])):
|
564
|
-
return os.path.join(*abs_file_parts[i:])
|
565
|
-
return None
|
566
|
-
|
567
|
-
def _get_obj_entry(
|
568
|
-
self, name: str
|
569
|
-
) -> Tuple[Optional["ArtifactManifestEntry"], Optional[Type[WBValue]]]:
|
570
|
-
"""Return an object entry by name, handling any type suffixes.
|
571
|
-
|
572
|
-
When objects are added with `.add(obj, name)`, the name is typically changed to
|
573
|
-
include the suffix of the object type when serializing to JSON. So we need to be
|
574
|
-
able to resolve a name, without tasking the user with appending .THING.json.
|
575
|
-
This method returns an entry if it exists by a suffixed name.
|
576
|
-
|
577
|
-
Args:
|
578
|
-
name: (str) name used when adding
|
579
|
-
"""
|
580
|
-
self._load_manifest()
|
581
|
-
|
582
|
-
type_mapping = WBValue.type_mapping()
|
583
|
-
for artifact_type_str in type_mapping:
|
584
|
-
wb_class = type_mapping[artifact_type_str]
|
585
|
-
wandb_file_name = wb_class.with_suffix(name)
|
586
|
-
entry = self.manifest.entries.get(wandb_file_name)
|
587
|
-
if entry is not None:
|
588
|
-
return entry, wb_class
|
589
|
-
return None, None
|
590
|
-
|
591
|
-
def get_path(self, name: StrPath) -> "ArtifactManifestEntry":
|
592
|
-
name = LogicalPath(name)
|
593
|
-
manifest = self._load_manifest()
|
594
|
-
entry = manifest.entries.get(name) or self._get_obj_entry(name)[0]
|
595
|
-
if entry is None:
|
596
|
-
raise KeyError("Path not contained in artifact: %s" % name)
|
597
|
-
entry._parent_artifact = self
|
598
|
-
return entry
|
599
|
-
|
600
|
-
def get(self, name: str) -> Optional[WBValue]:
|
601
|
-
entry, wb_class = self._get_obj_entry(name)
|
602
|
-
if entry is None or wb_class is None:
|
603
|
-
return None
|
604
|
-
# If the entry is a reference from another artifact, then get it directly from that artifact
|
605
|
-
if self._manifest_entry_is_artifact_reference(entry):
|
606
|
-
artifact = self._get_ref_artifact_from_entry(entry)
|
607
|
-
return artifact.get(util.uri_from_path(entry.ref))
|
608
|
-
|
609
|
-
# Special case for wandb.Table. This is intended to be a short term optimization.
|
610
|
-
# Since tables are likely to download many other assets in artifact(s), we eagerly download
|
611
|
-
# the artifact using the parallelized `artifact.download`. In the future, we should refactor
|
612
|
-
# the deserialization pattern such that this special case is not needed.
|
613
|
-
if wb_class == wandb.Table:
|
614
|
-
self.download(recursive=True)
|
615
|
-
|
616
|
-
# Get the ArtifactManifestEntry
|
617
|
-
item = self.get_path(entry.path)
|
618
|
-
item_path = item.download()
|
619
|
-
|
620
|
-
# Load the object from the JSON blob
|
621
|
-
result = None
|
622
|
-
json_obj = {}
|
623
|
-
with open(item_path) as file:
|
624
|
-
json_obj = json.load(file)
|
625
|
-
result = wb_class.from_json(json_obj, self)
|
626
|
-
result._set_artifact_source(self, name)
|
627
|
-
return result
|
628
|
-
|
629
|
-
def download(
|
630
|
-
self, root: Optional[str] = None, recursive: bool = False
|
631
|
-
) -> FilePathStr:
|
632
|
-
dirpath = root or self._default_root()
|
633
|
-
self._add_download_root(dirpath)
|
634
|
-
manifest = self._load_manifest()
|
635
|
-
nfiles = len(manifest.entries)
|
636
|
-
size = sum(e.size or 0 for e in manifest.entries.values())
|
637
|
-
log = False
|
638
|
-
if nfiles > 5000 or size > 50 * 1024 * 1024:
|
639
|
-
log = True
|
640
|
-
termlog(
|
641
|
-
"Downloading large artifact {}, {:.2f}MB. {} files... ".format(
|
642
|
-
self.name, size / (1024 * 1024), nfiles
|
643
|
-
),
|
644
|
-
)
|
645
|
-
start_time = datetime.datetime.now()
|
646
|
-
|
647
|
-
# Force all the files to download into the same directory.
|
648
|
-
# Download in parallel
|
649
|
-
import multiprocessing.dummy # this uses threads
|
650
|
-
|
651
|
-
download_logger = ArtifactDownloadLogger(nfiles=nfiles)
|
652
|
-
|
653
|
-
def _download_file_with_thread_local_api_settings(
|
654
|
-
name: str,
|
655
|
-
root: str,
|
656
|
-
download_logger: ArtifactDownloadLogger,
|
657
|
-
tlas_api_key: Optional[str],
|
658
|
-
tlas_cookies: Optional[Dict],
|
659
|
-
tlas_headers: Optional[Dict],
|
660
|
-
) -> StrPath:
|
661
|
-
_thread_local_api_settings.api_key = tlas_api_key
|
662
|
-
_thread_local_api_settings.cookies = tlas_cookies
|
663
|
-
_thread_local_api_settings.headers = tlas_headers
|
664
|
-
|
665
|
-
return self._download_file(name, root, download_logger)
|
666
|
-
|
667
|
-
pool = multiprocessing.dummy.Pool(32)
|
668
|
-
pool.map(
|
669
|
-
partial(
|
670
|
-
_download_file_with_thread_local_api_settings,
|
671
|
-
root=dirpath,
|
672
|
-
download_logger=download_logger,
|
673
|
-
tlas_api_key=_thread_local_api_settings.api_key,
|
674
|
-
tlas_headers={**(_thread_local_api_settings.headers or {})},
|
675
|
-
tlas_cookies={**(_thread_local_api_settings.cookies or {})},
|
676
|
-
),
|
677
|
-
manifest.entries,
|
678
|
-
)
|
679
|
-
if recursive:
|
680
|
-
pool.map(lambda artifact: artifact.download(), self._dependent_artifacts)
|
681
|
-
pool.close()
|
682
|
-
pool.join()
|
683
|
-
|
684
|
-
self._is_downloaded = True
|
685
|
-
|
686
|
-
if log:
|
687
|
-
now = datetime.datetime.now()
|
688
|
-
delta = abs((now - start_time).total_seconds())
|
689
|
-
hours = int(delta // 3600)
|
690
|
-
minutes = int((delta - hours * 3600) // 60)
|
691
|
-
seconds = delta - hours * 3600 - minutes * 60
|
692
|
-
termlog(
|
693
|
-
f"Done. {hours}:{minutes}:{seconds:.1f}",
|
694
|
-
prefix=False,
|
695
|
-
)
|
696
|
-
return FilePathStr(dirpath)
|
697
|
-
|
698
|
-
def checkout(self, root: Optional[str] = None) -> str:
|
699
|
-
dirpath = root or self._default_root(include_version=False)
|
700
|
-
|
701
|
-
for root, _, files in os.walk(dirpath):
|
702
|
-
for file in files:
|
703
|
-
full_path = os.path.join(root, file)
|
704
|
-
artifact_path = os.path.relpath(full_path, start=dirpath)
|
705
|
-
try:
|
706
|
-
self.get_path(artifact_path)
|
707
|
-
except KeyError:
|
708
|
-
# File is not part of the artifact, remove it.
|
709
|
-
os.remove(full_path)
|
710
|
-
|
711
|
-
return self.download(root=dirpath)
|
712
|
-
|
713
|
-
def verify(self, root: Optional[str] = None) -> None:
|
714
|
-
dirpath = root or self._default_root()
|
715
|
-
manifest = self._load_manifest()
|
716
|
-
ref_count = 0
|
717
|
-
|
718
|
-
for root, _, files in os.walk(dirpath):
|
719
|
-
for file in files:
|
720
|
-
full_path = os.path.join(root, file)
|
721
|
-
artifact_path = os.path.relpath(full_path, start=dirpath)
|
722
|
-
try:
|
723
|
-
self.get_path(artifact_path)
|
724
|
-
except KeyError:
|
725
|
-
raise ValueError(
|
726
|
-
"Found file {} which is not a member of artifact {}".format(
|
727
|
-
full_path, self.name
|
728
|
-
)
|
729
|
-
)
|
730
|
-
|
731
|
-
for entry in manifest.entries.values():
|
732
|
-
if entry.ref is None:
|
733
|
-
if md5_file_b64(os.path.join(dirpath, entry.path)) != entry.digest:
|
734
|
-
raise ValueError("Digest mismatch for file: %s" % entry.path)
|
735
|
-
else:
|
736
|
-
ref_count += 1
|
737
|
-
if ref_count > 0:
|
738
|
-
print("Warning: skipped verification of %s refs" % ref_count)
|
739
|
-
|
740
|
-
def file(self, root: Optional[str] = None) -> StrPath:
|
741
|
-
"""Download a single file artifact to dir specified by the root.
|
742
|
-
|
743
|
-
Arguments:
|
744
|
-
root: (str, optional) The root directory in which to place the file. Defaults to './artifacts/self.name/'.
|
745
|
-
|
746
|
-
Returns:
|
747
|
-
(str): The full path of the downloaded file.
|
748
|
-
"""
|
749
|
-
if root is None:
|
750
|
-
root = os.path.join(".", "artifacts", self.name)
|
751
|
-
|
752
|
-
manifest = self._load_manifest()
|
753
|
-
nfiles = len(manifest.entries)
|
754
|
-
if nfiles > 1:
|
755
|
-
raise ValueError(
|
756
|
-
"This artifact contains more than one file, call `.download()` to get all files or call "
|
757
|
-
'.get_path("filename").download()'
|
758
|
-
)
|
759
|
-
|
760
|
-
return self._download_file(list(manifest.entries)[0], root=root)
|
761
|
-
|
762
|
-
def _download_file(
|
763
|
-
self,
|
764
|
-
name: str,
|
765
|
-
root: str,
|
766
|
-
download_logger: Optional[ArtifactDownloadLogger] = None,
|
767
|
-
) -> StrPath:
|
768
|
-
# download file into cache and copy to target dir
|
769
|
-
downloaded_path = self.get_path(name).download(root)
|
770
|
-
if download_logger is not None:
|
771
|
-
download_logger.notify_downloaded()
|
772
|
-
return downloaded_path
|
773
|
-
|
774
|
-
def _default_root(self, include_version: bool = True) -> str:
    """Return the default local directory for this artifact's files.

    The directory is ``get_artifact_dir()`` joined with ``self.source_name``.
    When ``include_version`` is False, the name is cut at its first ":".
    On Windows, every ":" after the drive prefix is replaced with "-".
    """
    if include_version:
        name = self.source_name
    else:
        name = self.source_name.split(":")[0]
    root = os.path.join(get_artifact_dir(), name)
    if platform.system() != "Windows":
        return root
    # Keep the drive (e.g. "C:") intact; only sanitize the remainder.
    drive, remainder = os.path.splitdrive(root)
    return drive + remainder.replace(":", "-")
|
781
|
-
|
782
|
-
def json_encode(self) -> Dict[str, Any]:
    """Return the JSON representation produced by ``util.artifact_to_json``."""
    encoded = util.artifact_to_json(self)
    return encoded
|
784
|
-
|
785
|
-
@normalize_exceptions
def save(self) -> None:
    """Persists artifact changes to the wandb backend.

    The server is first probed for the ``AddAliasesInput`` GraphQL type.
    The description and metadata are then sent through ``updateArtifact``.
    Aliases are included in that mutation only when the probe finds no such
    type (per the original notes, a backend older than 0.13.0). Otherwise
    they are sent as null, and ``_save_alias_changes`` (called at the end)
    applies the alias diff.
    """
    probe_query = gql(
        """
        query ProbeServerAddAliasesInput {
            AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
                name
                inputFields {
                    name
                }
            }
        }
        """
    )
    update_mutation = gql(
        """
        mutation updateArtifact(
            $artifactID: ID!,
            $description: String,
            $metadata: JSONString,
            $aliases: [ArtifactAliasInput!]
        ) {
            updateArtifact(input: {
                artifactID: $artifactID,
                description: $description,
                metadata: $metadata,
                aliases: $aliases
            }) {
                artifact {
                    id
                }
            }
        }
        """
    )
    probe_result = self.client.execute(probe_query)
    has_alias_endpoints = probe_result.get("AddAliasesInputInfoType")

    alias_inputs = None
    if not has_alias_endpoints:
        # Backend < 0.13.0: aliases can only be set through updateArtifact.
        # Newer backends get them from _save_alias_changes below instead.
        alias_inputs = [
            {"artifactCollectionName": self._artifact_collection_name, "alias": alias}
            for alias in self._aliases
        ]

    variables = {
        "artifactID": self.id,
        "description": self.description,
        "metadata": util.json_dumps_safer(self.metadata),
        "aliases": alias_inputs,
    }
    self.client.execute(update_mutation, variable_values=variables)
    # Save locally modified aliases
    self._save_alias_changes()
|
848
|
-
|
849
|
-
def wait(self) -> "Artifact":
    """Return this artifact itself; nothing is awaited here."""
    artifact = self
    return artifact
|
851
|
-
|
852
|
-
@normalize_exceptions
def _save_alias_changes(self) -> None:
    """Persist alias changes on this artifact to the wandb backend.

    Called by artifact.save(). The aliases added or removed since the last
    persisted snapshot (``self._frozen_aliases``) are sent through the
    ``addAliases`` / ``deleteAliases`` mutations. If the server does not
    expose the ``AddAliasesInput`` type, this returns without changes.
    """
    aliases_to_add = set(self._aliases) - set(self._frozen_aliases)
    aliases_to_remove = set(self._frozen_aliases) - set(self._aliases)

    # Introspect
    introspect_query = gql(
        """
        query ProbeServerAddAliasesInput {
            AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
                name
                inputFields {
                    name
                }
            }
        }
        """
    )
    res = self.client.execute(introspect_query)
    valid = res.get("AddAliasesInputInfoType")
    if not valid:
        return

    def _alias_inputs(aliases):
        # Build the ArtifactCollectionAliasInput list shared by both mutations.
        return [
            {
                "artifactCollectionName": self._artifact_collection_name,
                "alias": alias,
                "entityName": self._entity,
                "projectName": self._project,
            }
            for alias in aliases
        ]

    if aliases_to_add:
        add_mutation = gql(
            """
            mutation addAliases(
                $artifactID: ID!,
                $aliases: [ArtifactCollectionAliasInput!]!,
            ) {
                addAliases(
                    input: {
                        artifactID: $artifactID,
                        aliases: $aliases,
                    }
                ) {
                    success
                }
            }
            """
        )
        self.client.execute(
            add_mutation,
            variable_values={
                "artifactID": self.id,
                "aliases": _alias_inputs(aliases_to_add),
            },
        )

    if aliases_to_remove:
        delete_mutation = gql(
            """
            mutation deleteAliases(
                $artifactID: ID!,
                $aliases: [ArtifactCollectionAliasInput!]!,
            ) {
                deleteAliases(
                    input: {
                        artifactID: $artifactID,
                        aliases: $aliases,
                    }
                ) {
                    success
                }
            }
            """
        )
        self.client.execute(
            delete_mutation,
            variable_values={
                "artifactID": self.id,
                "aliases": _alias_inputs(aliases_to_remove),
            },
        )

    # Reset local state with a COPY. Assigning self._aliases directly made
    # both attributes the same list, so later in-place edits to
    # self._aliases silently updated the "persisted" snapshot too and were
    # never detected as a diff on the next save.
    self._frozen_aliases = list(self._aliases)
|
949
|
-
|
950
|
-
# TODO: not yet public, but we probably want something like this.
def _list(self) -> Iterable[str]:
    """Return the keys of the loaded manifest's ``entries`` mapping."""
    return self._load_manifest().entries.keys()
|
954
|
-
|
955
|
-
def __repr__(self) -> str:
    """Return a short debug string of the form ``<Artifact {id}>``."""
    return "<Artifact {}>".format(self.id)
|
957
|
-
|
958
|
-
def _load(self) -> Dict[str, Any]:
    """Fetch this artifact's attributes from the backend.

    Runs the ``Artifact`` query (selecting ``ArtifactFragment``) for
    ``self.entity`` / ``self.project`` / ``self.name``.

    Returns:
        The ``project.artifact`` object from the GraphQL response.

    Raises:
        ValueError: If the query fails and ``self.name`` has no ":" alias and
            is not 32 characters long. Also raised if the response has no
            project or artifact; any query error that was tolerated is then
            attached as ``__cause__``.
    """
    query = gql(
        """
        query Artifact(
            $entityName: String,
            $projectName: String,
            $name: String!
        ) {
            project(name: $projectName, entityName: $entityName) {
                artifact(name: $name) {
                    ...ArtifactFragment
                }
            }
        }
        %s
        """
        % ARTIFACT_FRAGMENT
    )
    response = None
    query_error = None
    try:
        response = self.client.execute(
            query,
            variable_values={
                "entityName": self.entity,
                "projectName": self.project,
                "name": self.name,
            },
        )
    except Exception as err:
        # we check for this after doing the call, since the backend supports raw digest lookups
        # which don't include ":" and are 32 characters long
        if ":" not in self.name and len(self.name) != 32:
            raise ValueError(
                'Attempted to fetch artifact without alias (e.g. "<artifact_name>:v3" or "<artifact_name>:latest")'
            ) from err
        # Remember the failure so the error below still explains the real cause
        # instead of silently discarding it.
        query_error = err
    if (
        response is None
        or response.get("project") is None
        or response["project"].get("artifact") is None
    ):
        raise ValueError(
            f'Project {self.entity}/{self.project} does not contain artifact: "{self.name}"'
        ) from query_error
    return response["project"]["artifact"]
|
1002
|
-
|
1003
|
-
def files(
    self, names: Optional[List[str]] = None, per_page: int = 50
) -> ArtifactFiles:
    """Iterate over all files stored in this artifact.

    Arguments:
        names: (list of str, optional) The filename paths relative to the
            root of the artifact you wish to list.
        per_page: (int, default 50) The number of files to return per request

    Returns:
        (`ArtifactFiles`): An iterator containing `File` objects
    """
    client = self.client
    file_iterator = ArtifactFiles(client, self, names, per_page)
    return file_iterator
|
1017
|
-
|
1018
|
-
def _load_manifest(self) -> ArtifactManifest:
    """Return this artifact's manifest, fetching and caching it on first use.

    On the first call, the manifest file's ``directUrl`` is looked up via
    GraphQL. The JSON file at that URL is downloaded with ``requests`` and
    parsed into ``self._manifest``. Then ``_load_dependent_manifests`` runs.
    Later calls return the cached ``self._manifest``.

    Raises:
        requests.HTTPError: If downloading the manifest file fails
            (via ``raise_for_status``).
    """
    if self._manifest is not None:
        return self._manifest

    manifest_query = gql(
        """
        query ArtifactManifest(
            $entityName: String!,
            $projectName: String!,
            $name: String!
        ) {
            project(name: $projectName, entityName: $entityName) {
                artifact(name: $name) {
                    currentManifest {
                        id
                        file {
                            id
                            directUrl
                        }
                    }
                }
            }
        }
        """
    )
    result = self.client.execute(
        manifest_query,
        variable_values={
            "entityName": self.entity,
            "projectName": self.project,
            "name": self.name,
        },
    )

    manifest_file = result["project"]["artifact"]["currentManifest"]["file"]
    with requests.get(manifest_file["directUrl"]) as resp:
        resp.raise_for_status()
        manifest_json = json.loads(util.ensure_text(resp.content))
        self._manifest = ArtifactManifest.from_manifest_json(manifest_json)

    # self._manifest is set by now, so dependency resolution can read it.
    self._load_dependent_manifests()
    return self._manifest
|
1062
|
-
|
1063
|
-
def _load_dependent_manifests(self) -> None:
    """Interrogate entries and ensure we have loaded their manifests.

    Each manifest entry that references another artifact is resolved to that
    artifact. Its manifest is loaded, and it is appended to
    ``self._dependent_artifacts`` once per distinct artifact id.
    """
    # Make sure dependencies are avail
    manifest = self.manifest  # read once instead of twice per entry
    known_ids = {dep.id for dep in self._dependent_artifacts}
    for entry in manifest.entries.values():
        if not self._manifest_entry_is_artifact_reference(entry):
            continue
        dep_artifact = self._get_ref_artifact_from_entry(entry)
        # Dedupe on id rather than with ``in``. Each entry is resolved through
        # Artifact.from_id, which presumably builds a fresh object per call.
        # Unless Artifact defines id-based equality, an ``in`` check would not
        # match, and the same dependency would be loaded and appended once per
        # referencing entry.
        if dep_artifact.id in known_ids:
            continue
        dep_artifact._load_manifest()
        self._dependent_artifacts.append(dep_artifact)
        known_ids.add(dep_artifact.id)
|
1073
|
-
|
1074
|
-
@staticmethod
def _manifest_entry_is_artifact_reference(entry: "ArtifactManifestEntry") -> bool:
    """Tell whether ``entry.ref`` is a ``wandb-artifact`` URI.

    Returns False when the entry has no ref at all.
    """
    ref = entry.ref
    if ref is None:
        return False
    return urllib.parse.urlparse(ref).scheme == "wandb-artifact"
|
1081
|
-
|
1082
|
-
def _get_ref_artifact_from_entry(
    self, entry: "ArtifactManifestEntry"
) -> "Artifact":
    """Helper function returns the referenced artifact from an entry.

    The host part of ``entry.ref`` (via ``util.host_from_path``) is converted
    with ``hex_to_b64_id``. The artifact is then looked up with
    ``Artifact.from_id``.

    Raises:
        ValueError: If ``Artifact.from_id`` returns None for the reference.
    """
    artifact_id = util.host_from_path(entry.ref)
    artifact = Artifact.from_id(hex_to_b64_id(artifact_id), self.client)
    if artifact is None:
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let None flow on and fail later with a
        # confusing AttributeError.
        raise ValueError(f"Unable to find artifact referenced by {entry.ref}")
    return artifact
|
1090
|
-
|
1091
|
-
def used_by(self) -> List[Run]:
    """Retrieve the runs which use this artifact directly.

    Returns:
        [Run]: a list of Run objects which use this artifact. The list is
        empty if the response has no artifact or no ``usedBy`` edges.
    """
    query = gql(
        """
        query ArtifactUsedBy(
            $id: ID!,
            $before: String,
            $after: String,
            $first: Int,
            $last: Int
        ) {
            artifact(id: $id) {
                usedBy(before: $before, after: $after, first: $first, last: $last) {
                    edges {
                        node {
                            name
                            project {
                                name
                                entityName
                            }
                        }
                    }
                }
            }
        }
        """
    )
    response = self.client.execute(
        query,
        variable_values={"id": self.id},
    )
    # GraphQL reports missing objects as explicit nulls, so a key can be
    # present with value None. ``.get(key, {})`` would then return None and
    # crash on the next ``.get``; ``or {}`` treats null like absent.
    artifact = response.get("artifact") or {}
    connection = artifact.get("usedBy") or {}
    edges = connection.get("edges") or []
    # yes, "name" is actually id
    return [
        Run(
            self.client,
            edge["node"]["project"]["entityName"],
            edge["node"]["project"]["name"],
            edge["node"]["name"],
        )
        for edge in edges
    ]
|
1137
|
-
|
1138
|
-
def logged_by(self) -> Optional[Run]:
    """Retrieve the run which logged this artifact.

    Returns:
        Run: Run object which logged this artifact. Returns None if the
        response has no artifact, no creator, or a creator without Run
        fields.
    """
    query = gql(
        """
        query ArtifactCreatedBy(
            $id: ID!
        ) {
            artifact(id: $id) {
                createdBy {
                    ... on Run {
                        name
                        project {
                            name
                            entityName
                        }
                    }
                }
            }
        }
        """
    )
    response = self.client.execute(
        query,
        variable_values={"id": self.id},
    )
    # A null "artifact" would make ``.get("artifact", {}).get(...)`` crash,
    # so normalize null to an empty dict first.
    artifact = response.get("artifact") or {}
    run_obj = artifact.get("createdBy")
    # A creator that is not a Run matches no fields of the ``... on Run``
    # inline fragment, so it arrives without "project". Such a creator, or a
    # missing one, cannot be turned into a Run.
    if not run_obj or "project" not in run_obj:
        return None
    return Run(
        self.client,
        run_obj["project"]["entityName"],
        run_obj["project"]["name"],
        run_obj["name"],
    )
|
1176
|
-
|
1177
|
-
def new_draft(self) -> "LocalArtifact":
    """Create a new draft artifact with the same content as this committed artifact.

    The artifact returned can be extended or modified and logged as a new version.
    Its name is this artifact's name up to the first ":". The description
    is copied and the metadata is deep-copied. The manifest is rebuilt from
    this artifact's manifest JSON.
    """
    import copy

    artifact = wandb.Artifact(self.name.split(":")[0], self.type)
    artifact._description = self.description
    # Deep-copy so edits to the draft's metadata cannot leak back into the
    # object held by this committed artifact (previously both shared it).
    artifact._metadata = copy.deepcopy(self.metadata)
    artifact._manifest = ArtifactManifest.from_manifest_json(
        self.manifest.to_manifest_json()
    )
    return artifact
|