wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
wandb/apis/public.py
CHANGED
@@ -15,20 +15,15 @@ import datetime
|
|
15
15
|
import io
|
16
16
|
import json
|
17
17
|
import logging
|
18
|
-
import multiprocessing.dummy # this uses threads
|
19
18
|
import os
|
20
19
|
import platform
|
21
|
-
import re
|
22
20
|
import shutil
|
23
21
|
import tempfile
|
24
22
|
import time
|
25
23
|
import urllib
|
26
|
-
from collections import namedtuple
|
27
|
-
from functools import partial
|
28
24
|
from typing import (
|
29
25
|
TYPE_CHECKING,
|
30
26
|
Any,
|
31
|
-
Callable,
|
32
27
|
Dict,
|
33
28
|
List,
|
34
29
|
Mapping,
|
@@ -45,21 +40,19 @@ import wandb
|
|
45
40
|
from wandb import __version__, env, util
|
46
41
|
from wandb.apis.internal import Api as InternalApi
|
47
42
|
from wandb.apis.normalize import normalize_exceptions
|
48
|
-
from wandb.data_types import WBValue
|
49
|
-
from wandb.env import get_artifact_dir
|
50
43
|
from wandb.errors import CommError
|
51
|
-
from wandb.errors.term import termlog
|
52
44
|
from wandb.sdk.data_types._dtypes import InvalidType, Type, TypeRegistry
|
53
|
-
from wandb.sdk.
|
45
|
+
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
46
|
+
from wandb.sdk.launch.errors import LaunchError
|
54
47
|
from wandb.sdk.launch.utils import (
|
55
48
|
LAUNCH_DEFAULT_PROJECT,
|
56
|
-
LaunchError,
|
57
49
|
_fetch_git_repo,
|
58
50
|
apply_patch,
|
51
|
+
convert_jupyter_notebook_to_script,
|
59
52
|
)
|
60
|
-
from wandb.sdk.lib import
|
53
|
+
from wandb.sdk.lib import ipython, retry, runid
|
61
54
|
from wandb.sdk.lib.gql_request import GraphQLSession
|
62
|
-
from wandb.sdk.lib.
|
55
|
+
from wandb.sdk.lib.paths import LogicalPath
|
63
56
|
|
64
57
|
if TYPE_CHECKING:
|
65
58
|
import wandb.apis.reports
|
@@ -144,41 +137,6 @@ fragment ArtifactTypesFragment on ArtifactTypeConnection {
|
|
144
137
|
}
|
145
138
|
"""
|
146
139
|
|
147
|
-
ARTIFACT_FRAGMENT = """
|
148
|
-
fragment ArtifactFragment on Artifact {
|
149
|
-
id
|
150
|
-
digest
|
151
|
-
description
|
152
|
-
state
|
153
|
-
size
|
154
|
-
createdAt
|
155
|
-
updatedAt
|
156
|
-
labels
|
157
|
-
metadata
|
158
|
-
fileCount
|
159
|
-
versionIndex
|
160
|
-
aliases {
|
161
|
-
artifactCollectionName
|
162
|
-
alias
|
163
|
-
}
|
164
|
-
artifactSequence {
|
165
|
-
id
|
166
|
-
name
|
167
|
-
}
|
168
|
-
artifactType {
|
169
|
-
id
|
170
|
-
name
|
171
|
-
project {
|
172
|
-
name
|
173
|
-
entity {
|
174
|
-
name
|
175
|
-
}
|
176
|
-
}
|
177
|
-
}
|
178
|
-
commitHash
|
179
|
-
}
|
180
|
-
"""
|
181
|
-
|
182
140
|
# TODO, factor out common file fragment
|
183
141
|
ARTIFACT_FILES_FRAGMENT = """fragment ArtifactFilesFragment on Artifact {
|
184
142
|
files(names: $fileNames, after: $fileCursor, first: $fileLimit) {
|
@@ -407,7 +365,7 @@ class Api:
|
|
407
365
|
self.settings = InternalApi().settings()
|
408
366
|
_overrides = overrides or {}
|
409
367
|
self._api_key = api_key
|
410
|
-
if self.api_key is None:
|
368
|
+
if self.api_key is None and _thread_local_api_settings.cookies is None:
|
411
369
|
wandb.login(host=_overrides.get("base_url"))
|
412
370
|
self.settings.update(_overrides)
|
413
371
|
if "username" in _overrides and "entity" not in _overrides:
|
@@ -424,15 +382,23 @@ class Api:
|
|
424
382
|
self._reports = {}
|
425
383
|
self._default_entity = None
|
426
384
|
self._timeout = timeout if timeout is not None else self._HTTP_TIMEOUT
|
385
|
+
auth = None
|
386
|
+
if not _thread_local_api_settings.cookies:
|
387
|
+
auth = ("api", self.api_key)
|
427
388
|
self._base_client = Client(
|
428
389
|
transport=GraphQLSession(
|
429
|
-
headers={
|
390
|
+
headers={
|
391
|
+
"User-Agent": self.user_agent,
|
392
|
+
"Use-Admin-Privileges": "true",
|
393
|
+
**(_thread_local_api_settings.headers or {}),
|
394
|
+
},
|
430
395
|
use_json=True,
|
431
396
|
# this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
|
432
397
|
# https://bugs.python.org/issue22889
|
433
398
|
timeout=self._timeout,
|
434
|
-
auth=
|
399
|
+
auth=auth,
|
435
400
|
url="%s/graphql" % self.settings["base_url"],
|
401
|
+
cookies=_thread_local_api_settings.cookies,
|
436
402
|
)
|
437
403
|
)
|
438
404
|
self._client = RetryingClient(self._base_client)
|
@@ -525,6 +491,9 @@ class Api:
|
|
525
491
|
|
526
492
|
@property
|
527
493
|
def api_key(self):
|
494
|
+
# just use thread local api key if it's set
|
495
|
+
if _thread_local_api_settings.api_key:
|
496
|
+
return _thread_local_api_settings.api_key
|
528
497
|
if self._api_key is not None:
|
529
498
|
return self._api_key
|
530
499
|
auth = requests.utils.get_netrc_auth(self.settings["base_url"])
|
@@ -939,14 +908,13 @@ class Api:
|
|
939
908
|
|
940
909
|
@normalize_exceptions
|
941
910
|
def artifact(self, name, type=None):
|
942
|
-
"""Return a single artifact by parsing path in the form `entity/project/
|
911
|
+
"""Return a single artifact by parsing path in the form `entity/project/name`.
|
943
912
|
|
944
913
|
Arguments:
|
945
914
|
name: (str) An artifact name. May be prefixed with entity/project. Valid names
|
946
915
|
can be in the following forms:
|
947
916
|
name:version
|
948
917
|
name:alias
|
949
|
-
digest
|
950
918
|
type: (str, optional) The type of artifact to fetch.
|
951
919
|
|
952
920
|
Returns:
|
@@ -955,7 +923,9 @@ class Api:
|
|
955
923
|
if name is None:
|
956
924
|
raise ValueError("You must specify name= to fetch an artifact.")
|
957
925
|
entity, project, artifact_name = self._parse_artifact_path(name)
|
958
|
-
artifact = Artifact(
|
926
|
+
artifact = wandb.Artifact._from_name(
|
927
|
+
entity, project, artifact_name, self.client
|
928
|
+
)
|
959
929
|
if type is not None and artifact.type != type:
|
960
930
|
raise ValueError(
|
961
931
|
f"type {type} specified but this artifact is of type {artifact.type}"
|
@@ -966,6 +936,10 @@ class Api:
|
|
966
936
|
def job(self, name, path=None):
|
967
937
|
if name is None:
|
968
938
|
raise ValueError("You must specify name= to fetch a job.")
|
939
|
+
elif name.count("/") != 2 or ":" not in name:
|
940
|
+
raise ValueError(
|
941
|
+
"Invalid job specification. A job must be of the form: <entity>/<project>/<job-name>:<alias-or-version>"
|
942
|
+
)
|
969
943
|
return Job(self, name, path)
|
970
944
|
|
971
945
|
|
@@ -2036,7 +2010,7 @@ class Run(Attrs):
|
|
2036
2010
|
root = os.path.abspath(root)
|
2037
2011
|
name = os.path.relpath(path, root)
|
2038
2012
|
with open(os.path.join(root, name), "rb") as f:
|
2039
|
-
api.push({
|
2013
|
+
api.push({LogicalPath(name): f})
|
2040
2014
|
return Files(self.client, self, [name])[0]
|
2041
2015
|
|
2042
2016
|
@normalize_exceptions
|
@@ -2166,10 +2140,10 @@ class Run(Attrs):
|
|
2166
2140
|
)
|
2167
2141
|
api.set_current_run_id(self.id)
|
2168
2142
|
|
2169
|
-
if isinstance(artifact, Artifact):
|
2143
|
+
if isinstance(artifact, wandb.Artifact) and not artifact.is_draft():
|
2170
2144
|
api.use_artifact(artifact.id, use_as=use_as or artifact.name)
|
2171
2145
|
return artifact
|
2172
|
-
elif isinstance(artifact, wandb.Artifact):
|
2146
|
+
elif isinstance(artifact, wandb.Artifact) and artifact.is_draft():
|
2173
2147
|
raise ValueError(
|
2174
2148
|
"Only existing artifacts are accepted by this api. "
|
2175
2149
|
"Manually create one with `wandb artifacts put`"
|
@@ -2194,7 +2168,7 @@ class Run(Attrs):
|
|
2194
2168
|
)
|
2195
2169
|
api.set_current_run_id(self.id)
|
2196
2170
|
|
2197
|
-
if isinstance(artifact, Artifact):
|
2171
|
+
if isinstance(artifact, wandb.Artifact) and not artifact.is_draft():
|
2198
2172
|
artifact_collection_name = artifact.name.split(":")[0]
|
2199
2173
|
api.create_artifact(
|
2200
2174
|
artifact.type,
|
@@ -2203,7 +2177,7 @@ class Run(Attrs):
|
|
2203
2177
|
aliases=aliases,
|
2204
2178
|
)
|
2205
2179
|
return artifact
|
2206
|
-
elif isinstance(artifact, wandb.Artifact):
|
2180
|
+
elif isinstance(artifact, wandb.Artifact) and artifact.is_draft():
|
2207
2181
|
raise ValueError(
|
2208
2182
|
"Only existing artifacts are accepted by this api. "
|
2209
2183
|
"Manually create one with `wandb artifacts put`"
|
@@ -3812,72 +3786,72 @@ class ProjectArtifactCollections(Paginator):
|
|
3812
3786
|
|
3813
3787
|
|
3814
3788
|
class RunArtifacts(Paginator):
|
3815
|
-
|
3816
|
-
"""
|
3817
|
-
|
3818
|
-
|
3819
|
-
|
3820
|
-
|
3821
|
-
|
3822
|
-
|
3823
|
-
|
3824
|
-
|
3825
|
-
|
3826
|
-
|
3789
|
+
def __init__(
|
3790
|
+
self, client: Client, run: "Run", mode="logged", per_page: Optional[int] = 50
|
3791
|
+
):
|
3792
|
+
output_query = gql(
|
3793
|
+
"""
|
3794
|
+
query RunOutputArtifacts(
|
3795
|
+
$entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
|
3796
|
+
) {
|
3797
|
+
project(name: $project, entityName: $entity) {
|
3798
|
+
run(name: $runName) {
|
3799
|
+
outputArtifacts(after: $cursor, first: $perPage) {
|
3800
|
+
totalCount
|
3801
|
+
edges {
|
3802
|
+
node {
|
3803
|
+
...ArtifactFragment
|
3804
|
+
}
|
3805
|
+
cursor
|
3806
|
+
}
|
3807
|
+
pageInfo {
|
3808
|
+
endCursor
|
3809
|
+
hasNextPage
|
3827
3810
|
}
|
3828
|
-
cursor
|
3829
|
-
}
|
3830
|
-
pageInfo {
|
3831
|
-
endCursor
|
3832
|
-
hasNextPage
|
3833
3811
|
}
|
3834
3812
|
}
|
3835
3813
|
}
|
3836
3814
|
}
|
3837
|
-
|
3838
|
-
|
3839
|
-
|
3840
|
-
|
3841
|
-
)
|
3815
|
+
%s
|
3816
|
+
"""
|
3817
|
+
% wandb.Artifact._GQL_FRAGMENT
|
3818
|
+
)
|
3842
3819
|
|
3843
|
-
|
3844
|
-
|
3845
|
-
|
3846
|
-
|
3847
|
-
|
3848
|
-
|
3849
|
-
|
3850
|
-
|
3851
|
-
|
3852
|
-
|
3853
|
-
|
3854
|
-
|
3820
|
+
input_query = gql(
|
3821
|
+
"""
|
3822
|
+
query RunInputArtifacts(
|
3823
|
+
$entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
|
3824
|
+
) {
|
3825
|
+
project(name: $project, entityName: $entity) {
|
3826
|
+
run(name: $runName) {
|
3827
|
+
inputArtifacts(after: $cursor, first: $perPage) {
|
3828
|
+
totalCount
|
3829
|
+
edges {
|
3830
|
+
node {
|
3831
|
+
...ArtifactFragment
|
3832
|
+
}
|
3833
|
+
cursor
|
3834
|
+
}
|
3835
|
+
pageInfo {
|
3836
|
+
endCursor
|
3837
|
+
hasNextPage
|
3855
3838
|
}
|
3856
|
-
cursor
|
3857
|
-
}
|
3858
|
-
pageInfo {
|
3859
|
-
endCursor
|
3860
|
-
hasNextPage
|
3861
3839
|
}
|
3862
3840
|
}
|
3863
3841
|
}
|
3864
3842
|
}
|
3865
|
-
|
3866
|
-
|
3867
|
-
|
3868
|
-
|
3869
|
-
)
|
3843
|
+
%s
|
3844
|
+
"""
|
3845
|
+
% wandb.Artifact._GQL_FRAGMENT
|
3846
|
+
)
|
3870
3847
|
|
3871
|
-
def __init__(
|
3872
|
-
self, client: Client, run: "Run", mode="logged", per_page: Optional[int] = 50
|
3873
|
-
):
|
3874
3848
|
self.run = run
|
3875
3849
|
if mode == "logged":
|
3876
3850
|
self.run_key = "outputArtifacts"
|
3877
|
-
self.QUERY =
|
3851
|
+
self.QUERY = output_query
|
3878
3852
|
elif mode == "used":
|
3879
3853
|
self.run_key = "inputArtifacts"
|
3880
|
-
self.QUERY =
|
3854
|
+
self.QUERY = input_query
|
3881
3855
|
else:
|
3882
3856
|
raise ValueError("mode must be logged or used")
|
3883
3857
|
|
@@ -3916,14 +3890,14 @@ class RunArtifacts(Paginator):
|
|
3916
3890
|
|
3917
3891
|
def convert_objects(self):
|
3918
3892
|
return [
|
3919
|
-
Artifact(
|
3920
|
-
self.client,
|
3893
|
+
wandb.Artifact._from_attrs(
|
3921
3894
|
self.run.entity,
|
3922
3895
|
self.run.project,
|
3923
3896
|
"{}:v{}".format(
|
3924
3897
|
r["node"]["artifactSequence"]["name"], r["node"]["versionIndex"]
|
3925
3898
|
),
|
3926
3899
|
r["node"],
|
3900
|
+
self.client,
|
3927
3901
|
)
|
3928
3902
|
for r in self.last_response["project"]["run"][self.run_key]["edges"]
|
3929
3903
|
]
|
@@ -4109,1293 +4083,154 @@ class ArtifactCollection:
|
|
4109
4083
|
return f"<ArtifactCollection {self.name} ({self.type})>"
|
4110
4084
|
|
4111
4085
|
|
4112
|
-
class
|
4086
|
+
class ArtifactVersions(Paginator):
|
4087
|
+
"""An iterable collection of artifact versions associated with a project and optional filter.
|
4088
|
+
|
4089
|
+
This is generally used indirectly via the `Api`.artifact_versions method.
|
4090
|
+
"""
|
4091
|
+
|
4113
4092
|
def __init__(
|
4114
4093
|
self,
|
4115
|
-
|
4116
|
-
|
4094
|
+
client: Client,
|
4095
|
+
entity: str,
|
4096
|
+
project: str,
|
4097
|
+
collection_name: str,
|
4098
|
+
type: str,
|
4099
|
+
filters: Optional[Mapping[str, Any]] = None,
|
4100
|
+
order: Optional[str] = None,
|
4101
|
+
per_page: int = 50,
|
4117
4102
|
):
|
4118
|
-
|
4119
|
-
|
4120
|
-
|
4121
|
-
|
4122
|
-
|
4123
|
-
|
4124
|
-
|
4125
|
-
|
4103
|
+
self.entity = entity
|
4104
|
+
self.collection_name = collection_name
|
4105
|
+
self.type = type
|
4106
|
+
self.project = project
|
4107
|
+
self.filters = {"state": "COMMITTED"} if filters is None else filters
|
4108
|
+
self.order = order
|
4109
|
+
variables = {
|
4110
|
+
"project": self.project,
|
4111
|
+
"entity": self.entity,
|
4112
|
+
"order": self.order,
|
4113
|
+
"type": self.type,
|
4114
|
+
"collection": self.collection_name,
|
4115
|
+
"filters": json.dumps(self.filters),
|
4116
|
+
}
|
4117
|
+
self.QUERY = gql(
|
4118
|
+
"""
|
4119
|
+
query Artifacts($project: String!, $entity: String!, $type: String!, $collection: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
|
4120
|
+
project(name: $project, entityName: $entity) {{
|
4121
|
+
artifactType(name: $type) {{
|
4122
|
+
artifactCollection: {}(name: $collection) {{
|
4123
|
+
name
|
4124
|
+
artifacts(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
|
4125
|
+
totalCount
|
4126
|
+
edges {{
|
4127
|
+
node {{
|
4128
|
+
...ArtifactFragment
|
4129
|
+
}}
|
4130
|
+
version
|
4131
|
+
cursor
|
4132
|
+
}}
|
4133
|
+
pageInfo {{
|
4134
|
+
endCursor
|
4135
|
+
hasNextPage
|
4136
|
+
}}
|
4137
|
+
}}
|
4138
|
+
}}
|
4139
|
+
}}
|
4140
|
+
}}
|
4141
|
+
}}
|
4142
|
+
{}
|
4143
|
+
""".format(
|
4144
|
+
artifact_collection_edge_name(
|
4145
|
+
server_supports_artifact_collections_gql_edges(client)
|
4146
|
+
),
|
4147
|
+
wandb.Artifact._GQL_FRAGMENT,
|
4148
|
+
)
|
4126
4149
|
)
|
4127
|
-
|
4150
|
+
super().__init__(client, variables, per_page)
|
4128
4151
|
|
4129
4152
|
@property
|
4130
|
-
def
|
4131
|
-
|
4132
|
-
|
4133
|
-
|
4134
|
-
|
4135
|
-
|
4136
|
-
|
4137
|
-
|
4138
|
-
def copy(self, cache_path, target_path):
|
4139
|
-
raise NotImplementedError()
|
4140
|
-
|
4141
|
-
def download(self, root=None):
|
4142
|
-
root = root or self._parent_artifact._default_root()
|
4143
|
-
dest_path = os.path.join(root, self.path)
|
4144
|
-
|
4145
|
-
self._parent_artifact._add_download_root(root)
|
4146
|
-
manifest = self._parent_artifact._load_manifest()
|
4147
|
-
|
4148
|
-
# Skip checking the cache (and possibly downloading) if the file already exists
|
4149
|
-
# and has the digest we're expecting.
|
4150
|
-
entry = manifest.entries[self.path]
|
4151
|
-
if os.path.exists(dest_path) and entry.digest == md5_file_b64(dest_path):
|
4152
|
-
return dest_path
|
4153
|
+
def length(self):
|
4154
|
+
if self.last_response:
|
4155
|
+
return self.last_response["project"]["artifactType"]["artifactCollection"][
|
4156
|
+
"artifacts"
|
4157
|
+
]["totalCount"]
|
4158
|
+
else:
|
4159
|
+
return None
|
4153
4160
|
|
4154
|
-
|
4155
|
-
|
4161
|
+
@property
|
4162
|
+
def more(self):
|
4163
|
+
if self.last_response:
|
4164
|
+
return self.last_response["project"]["artifactType"]["artifactCollection"][
|
4165
|
+
"artifacts"
|
4166
|
+
]["pageInfo"]["hasNextPage"]
|
4156
4167
|
else:
|
4157
|
-
|
4168
|
+
return True
|
4158
4169
|
|
4159
|
-
|
4170
|
+
@property
|
4171
|
+
def cursor(self):
|
4172
|
+
if self.last_response:
|
4173
|
+
return self.last_response["project"]["artifactType"]["artifactCollection"][
|
4174
|
+
"artifacts"
|
4175
|
+
]["edges"][-1]["cursor"]
|
4176
|
+
else:
|
4177
|
+
return None
|
4160
4178
|
|
4161
|
-
def
|
4162
|
-
|
4163
|
-
|
4164
|
-
|
4165
|
-
|
4166
|
-
|
4179
|
+
def convert_objects(self):
|
4180
|
+
if self.last_response["project"]["artifactType"]["artifactCollection"] is None:
|
4181
|
+
return []
|
4182
|
+
return [
|
4183
|
+
wandb.Artifact._from_attrs(
|
4184
|
+
self.entity,
|
4185
|
+
self.project,
|
4186
|
+
self.collection_name + ":" + a["version"],
|
4187
|
+
a["node"],
|
4188
|
+
self.client,
|
4167
4189
|
)
|
4168
|
-
|
4169
|
-
|
4170
|
-
|
4171
|
-
|
4172
|
-
"wandb-artifact://"
|
4173
|
-
+ b64_to_hex_id(self._parent_artifact.id)
|
4174
|
-
+ "/"
|
4175
|
-
+ self.path
|
4176
|
-
)
|
4177
|
-
|
4178
|
-
|
4179
|
-
class _ArtifactDownloadLogger:
|
4180
|
-
def __init__(
|
4181
|
-
self,
|
4182
|
-
nfiles: int,
|
4183
|
-
clock_for_testing: Callable[[], float] = time.monotonic,
|
4184
|
-
termlog_for_testing=termlog,
|
4185
|
-
) -> None:
|
4186
|
-
self._nfiles = nfiles
|
4187
|
-
self._clock = clock_for_testing
|
4188
|
-
self._termlog = termlog_for_testing
|
4189
|
-
|
4190
|
-
self._n_files_downloaded = 0
|
4191
|
-
self._spinner_index = 0
|
4192
|
-
self._last_log_time = self._clock()
|
4193
|
-
self._lock = multiprocessing.dummy.Lock()
|
4194
|
-
|
4195
|
-
def notify_downloaded(self) -> None:
|
4196
|
-
with self._lock:
|
4197
|
-
self._n_files_downloaded += 1
|
4198
|
-
if self._n_files_downloaded == self._nfiles:
|
4199
|
-
self._termlog(
|
4200
|
-
f" {self._nfiles} of {self._nfiles} files downloaded. ",
|
4201
|
-
# ^ trailing spaces to wipe out ellipsis from previous logs
|
4202
|
-
newline=True,
|
4203
|
-
)
|
4204
|
-
self._last_log_time = self._clock()
|
4205
|
-
elif self._clock() - self._last_log_time > 0.1:
|
4206
|
-
self._spinner_index += 1
|
4207
|
-
spinner = r"-\|/"[self._spinner_index % 4]
|
4208
|
-
self._termlog(
|
4209
|
-
f"{spinner} {self._n_files_downloaded} of {self._nfiles} files downloaded...\r",
|
4210
|
-
newline=False,
|
4211
|
-
)
|
4212
|
-
self._last_log_time = self._clock()
|
4213
|
-
|
4214
|
-
|
4215
|
-
class Artifact(artifacts.Artifact):
|
4216
|
-
"""A wandb Artifact.
|
4217
|
-
|
4218
|
-
An artifact that has been logged, including all its attributes, links to the runs
|
4219
|
-
that use it, and a link to the run that logged it.
|
4220
|
-
|
4221
|
-
Examples:
|
4222
|
-
Basic usage
|
4223
|
-
```
|
4224
|
-
api = wandb.Api()
|
4225
|
-
artifact = api.artifact('project/artifact:alias')
|
4226
|
-
|
4227
|
-
# Get information about the artifact...
|
4228
|
-
artifact.digest
|
4229
|
-
artifact.aliases
|
4230
|
-
```
|
4231
|
-
|
4232
|
-
Updating an artifact
|
4233
|
-
```
|
4234
|
-
artifact = api.artifact('project/artifact:alias')
|
4235
|
-
|
4236
|
-
# Update the description
|
4237
|
-
artifact.description = 'My new description'
|
4238
|
-
|
4239
|
-
# Selectively update metadata keys
|
4240
|
-
artifact.metadata["oldKey"] = "new value"
|
4241
|
-
|
4242
|
-
# Replace the metadata entirely
|
4243
|
-
artifact.metadata = {"newKey": "new value"}
|
4244
|
-
|
4245
|
-
# Add an alias
|
4246
|
-
artifact.aliases.append('best')
|
4247
|
-
|
4248
|
-
# Remove an alias
|
4249
|
-
artifact.aliases.remove('latest')
|
4250
|
-
|
4251
|
-
# Completely replace the aliases
|
4252
|
-
artifact.aliases = ['replaced']
|
4253
|
-
|
4254
|
-
# Persist all artifact modifications
|
4255
|
-
artifact.save()
|
4256
|
-
```
|
4257
|
-
|
4258
|
-
Artifact graph traversal
|
4259
|
-
```
|
4260
|
-
artifact = api.artifact('project/artifact:alias')
|
4261
|
-
|
4262
|
-
# Walk up and down the graph from an artifact:
|
4263
|
-
producer_run = artifact.logged_by()
|
4264
|
-
consumer_runs = artifact.used_by()
|
4265
|
-
|
4266
|
-
# Walk up and down the graph from a run:
|
4267
|
-
logged_artifacts = run.logged_artifacts()
|
4268
|
-
used_artifacts = run.used_artifacts()
|
4269
|
-
```
|
4190
|
+
for a in self.last_response["project"]["artifactType"][
|
4191
|
+
"artifactCollection"
|
4192
|
+
]["artifacts"]["edges"]
|
4193
|
+
]
|
4270
4194
|
|
4271
|
-
Deleting an artifact
|
4272
|
-
```
|
4273
|
-
artifact = api.artifact('project/artifact:alias')
|
4274
|
-
artifact.delete()
|
4275
|
-
```
|
4276
|
-
"""
|
4277
4195
|
|
4196
|
+
class ArtifactFiles(Paginator):
|
4278
4197
|
QUERY = gql(
|
4279
4198
|
"""
|
4280
|
-
query
|
4281
|
-
$
|
4199
|
+
query ArtifactFiles(
|
4200
|
+
$entityName: String!,
|
4201
|
+
$projectName: String!,
|
4202
|
+
$artifactTypeName: String!,
|
4203
|
+
$artifactName: String!
|
4204
|
+
$fileNames: [String!],
|
4205
|
+
$fileCursor: String,
|
4206
|
+
$fileLimit: Int = 50
|
4282
4207
|
) {
|
4283
|
-
|
4284
|
-
|
4285
|
-
|
4286
|
-
|
4287
|
-
id
|
4288
|
-
directUrl
|
4208
|
+
project(name: $projectName, entityName: $entityName) {
|
4209
|
+
artifactType(name: $artifactTypeName) {
|
4210
|
+
artifact(name: $artifactName) {
|
4211
|
+
...ArtifactFilesFragment
|
4289
4212
|
}
|
4290
4213
|
}
|
4291
|
-
...ArtifactFragment
|
4292
4214
|
}
|
4293
4215
|
}
|
4294
4216
|
%s
|
4295
4217
|
"""
|
4296
|
-
%
|
4297
|
-
)
|
4298
|
-
|
4299
|
-
@classmethod
|
4300
|
-
def from_id(cls, artifact_id: str, client: Client):
|
4301
|
-
artifact = artifacts.get_artifacts_cache().get_artifact(artifact_id)
|
4302
|
-
if artifact is not None:
|
4303
|
-
return artifact
|
4304
|
-
response: Mapping[str, Any] = client.execute(
|
4305
|
-
Artifact.QUERY,
|
4306
|
-
variable_values={"id": artifact_id},
|
4307
|
-
)
|
4308
|
-
|
4309
|
-
name = None
|
4310
|
-
if response.get("artifact") is not None:
|
4311
|
-
if response["artifact"].get("aliases") is not None:
|
4312
|
-
aliases = response["artifact"]["aliases"]
|
4313
|
-
name = ":".join(
|
4314
|
-
[aliases[0]["artifactCollectionName"], aliases[0]["alias"]]
|
4315
|
-
)
|
4316
|
-
if len(aliases) > 1:
|
4317
|
-
for alias in aliases:
|
4318
|
-
if alias["alias"] != "latest":
|
4319
|
-
name = ":".join(
|
4320
|
-
[alias["artifactCollectionName"], alias["alias"]]
|
4321
|
-
)
|
4322
|
-
break
|
4323
|
-
|
4324
|
-
p = response.get("artifact", {}).get("artifactType", {}).get("project", {})
|
4325
|
-
project = p.get("name") # defaults to None
|
4326
|
-
entity = p.get("entity", {}).get("name")
|
4327
|
-
|
4328
|
-
artifact = cls(
|
4329
|
-
client=client,
|
4330
|
-
entity=entity,
|
4331
|
-
project=project,
|
4332
|
-
name=name,
|
4333
|
-
attrs=response["artifact"],
|
4334
|
-
)
|
4335
|
-
index_file_url = response["artifact"]["currentManifest"]["file"][
|
4336
|
-
"directUrl"
|
4337
|
-
]
|
4338
|
-
with requests.get(index_file_url) as req:
|
4339
|
-
req.raise_for_status()
|
4340
|
-
artifact._manifest = artifacts.ArtifactManifest.from_manifest_json(
|
4341
|
-
json.loads(util.ensure_text(req.content))
|
4342
|
-
)
|
4343
|
-
|
4344
|
-
artifact._load_dependent_manifests()
|
4345
|
-
|
4346
|
-
return artifact
|
4347
|
-
|
4348
|
-
def __init__(self, client, entity, project, name, attrs=None):
|
4349
|
-
self.client = client
|
4350
|
-
self._entity = entity
|
4351
|
-
self._project = project
|
4352
|
-
self._artifact_name = name
|
4353
|
-
self._artifact_collection_name = name.split(":")[0]
|
4354
|
-
self._attrs = attrs
|
4355
|
-
if self._attrs is None:
|
4356
|
-
self._load()
|
4357
|
-
|
4358
|
-
# The entity and project above are taken from the passed-in artifact version path
|
4359
|
-
# so if the user is pulling an artifact version from an artifact portfolio, the entity/project
|
4360
|
-
# of that portfolio may be different than the birth entity/project of the artifact version.
|
4361
|
-
self._birth_project = (
|
4362
|
-
self._attrs.get("artifactType", {}).get("project", {}).get("name")
|
4363
|
-
)
|
4364
|
-
self._birth_entity = (
|
4365
|
-
self._attrs.get("artifactType", {})
|
4366
|
-
.get("project", {})
|
4367
|
-
.get("entity", {})
|
4368
|
-
.get("name")
|
4369
|
-
)
|
4370
|
-
self._metadata = json.loads(self._attrs.get("metadata") or "{}")
|
4371
|
-
self._description = self._attrs.get("description", None)
|
4372
|
-
self._sequence_name = self._attrs["artifactSequence"]["name"]
|
4373
|
-
self._sequence_version_index = self._attrs.get("versionIndex", None)
|
4374
|
-
# We will only show aliases under the Collection this artifact version is fetched from
|
4375
|
-
# _aliases will be a mutable copy on which the user can append or remove aliases
|
4376
|
-
self._aliases = [
|
4377
|
-
a["alias"]
|
4378
|
-
for a in self._attrs["aliases"]
|
4379
|
-
if not re.match(r"^v\d+$", a["alias"])
|
4380
|
-
and a["artifactCollectionName"] == self._artifact_collection_name
|
4381
|
-
]
|
4382
|
-
self._frozen_aliases = [a for a in self._aliases]
|
4383
|
-
self._manifest = None
|
4384
|
-
self._is_downloaded = False
|
4385
|
-
self._dependent_artifacts = []
|
4386
|
-
self._download_roots = set()
|
4387
|
-
artifacts.get_artifacts_cache().store_artifact(self)
|
4388
|
-
|
@property
def id(self):
    """Server-side ID of this artifact version."""
    return self._attrs["id"]

@property
def file_count(self):
    """Number of files stored in this artifact."""
    return self._attrs["fileCount"]

@property
def source_version(self):
    """The artifact's version index under its parent artifact collection.

    A string with the format "v{number}".
    """
    return f"v{self._sequence_version_index}"

@property
def version(self):
    """The artifact's version index under the given artifact collection.

    A string with the format "v{number}", or None when no version alias
    exists in this collection.
    """
    for alias_attr in self._attrs["aliases"]:
        same_collection = (
            alias_attr["artifactCollectionName"] == self._artifact_collection_name
        )
        if same_collection and util.alias_is_version_index(alias_attr["alias"]):
            return alias_attr["alias"]
    return None

@property
def entity(self):
    """Entity this artifact version was fetched under."""
    return self._entity

@property
def project(self):
    """Project this artifact version was fetched under."""
    return self._project

@property
def metadata(self):
    """User-supplied metadata attached to this artifact."""
    return self._metadata

@metadata.setter
def metadata(self, metadata):
    self._metadata = metadata

@property
def manifest(self):
    """The artifact's manifest, fetched lazily from the backend."""
    return self._load_manifest()

@property
def digest(self):
    """Content digest of the artifact."""
    return self._attrs["digest"]

@property
def state(self):
    """Lifecycle state string reported by the backend."""
    return self._attrs["state"]

@property
def size(self):
    """Total size of the artifact's contents, in bytes."""
    return self._attrs["size"]

@property
def created_at(self):
    """The time at which the artifact was created."""
    return self._attrs["createdAt"]

@property
def updated_at(self):
    """The time at which the artifact was last updated."""
    return self._attrs["updatedAt"] or self._attrs["createdAt"]

@property
def description(self):
    """Free-form description of the artifact."""
    return self._description

@description.setter
def description(self, desc):
    self._description = desc

@property
def type(self):
    """Name of the artifact's type."""
    return self._attrs["artifactType"]["name"]

@property
def commit_hash(self):
    """Commit hash assigned when the artifact was committed ('' if absent)."""
    return self._attrs.get("commitHash", "")

@property
def name(self):
    """Qualified name "<sequence>:v<index>", or the digest if unversioned."""
    if self._sequence_version_index is None:
        return self.digest
    return f"{self._sequence_name}:v{self._sequence_version_index}"
@property
def aliases(self):
    """The aliases associated with this artifact.

    Returns:
        List[str]: The aliases associated with this artifact.
    """
    return self._aliases

@aliases.setter
def aliases(self, aliases):
    # Slashes and colons would collide with path and version separators.
    for alias in aliases:
        if "/" in alias or ":" in alias:
            raise ValueError(
                'Invalid alias "%s", slashes and colons are disallowed' % alias
            )
    self._aliases = aliases
@staticmethod
def expected_type(client, name, entity_name, project_name):
    """Returns the expected type for a given artifact name and project.

    Returns None when the artifact (or its type) cannot be resolved.
    """
    query = gql(
        """
        query ArtifactType(
            $entityName: String,
            $projectName: String,
            $name: String!
        ) {
            project(name: $projectName, entityName: $entityName) {
                artifact(name: $name) {
                    artifactType {
                        name
                    }
                }
            }
        }
        """
    )
    # Default to the "latest" alias when no version/alias was supplied.
    if ":" not in name:
        name += ":latest"

    response = client.execute(
        query,
        variable_values={
            "entityName": entity_name,
            "projectName": project_name,
            "name": name,
        },
    )

    # Walk the nested response defensively; a missing level at any depth
    # means the type could not be determined.
    project = response.get("project")
    if project is None:
        return None
    artifact = project.get("artifact")
    if artifact is None:
        return None
    artifact_type = artifact.get("artifactType")
    if artifact_type is None:
        return None
    return artifact_type.get("name")
@property
def _use_as(self):
    """Label recorded when the artifact was consumed via `use_artifact(use_as=...)`."""
    return self._attrs.get("_use_as")

@_use_as.setter
def _use_as(self, use_as):
    self._attrs["_use_as"] = use_as
    return use_as
@normalize_exceptions
def link(self, target_path: str, aliases=None):
    """Link this artifact version into the portfolio at `target_path`.

    Arguments:
        target_path: "[entity/][project/]portfolio" path. Must not contain
            a ":" (aliases are passed separately).
        aliases: Optional aliases to attach to the linked version.

    Returns:
        True on success.
    """
    if ":" in target_path:
        raise ValueError(
            f"target_path {target_path} cannot contain `:` because it is not an alias."
        )

    portfolio, project, entity = util._parse_entity_project_item(target_path)
    aliases = util._resolve_aliases(aliases)

    # Fill in missing path pieces from the active run first, then from
    # this artifact's own entity/project.
    EmptyRunProps = namedtuple("Empty", "entity project")
    run_props = wandb.run if wandb.run else EmptyRunProps(entity=None, project=None)
    entity = entity or run_props.entity or self.entity
    project = project or run_props.project or self.project

    mutation = gql(
        """
        mutation LinkArtifact($artifactID: ID!, $artifactPortfolioName: String!, $entityName: String!, $projectName: String!, $aliases: [ArtifactAliasInput!]) {
            linkArtifact(input: {artifactID: $artifactID, artifactPortfolioName: $artifactPortfolioName,
                entityName: $entityName,
                projectName: $projectName,
                aliases: $aliases
            }) {
                versionIndex
            }
        }
        """
    )
    self.client.execute(
        mutation,
        variable_values={
            "artifactID": self.id,
            "artifactPortfolioName": portfolio,
            "entityName": entity,
            "projectName": project,
            "aliases": [
                {"alias": alias, "artifactCollectionName": portfolio}
                for alias in aliases
            ],
        },
    )
    return True
@normalize_exceptions
def delete(self, delete_aliases=False):
    """Delete an artifact and its files.

    Examples:
        Delete all the "model" artifacts a run has logged:
        ```
        runs = api.runs(path="my_entity/my_project")
        for run in runs:
            for artifact in run.logged_artifacts():
                if artifact.type == "model":
                    artifact.delete(delete_aliases=True)
        ```

    Arguments:
        delete_aliases: (bool) If true, deletes all aliases associated with the artifact.
            Otherwise, this raises an exception if the artifact has existing aliases.
    """
    mutation = gql(
        """
        mutation DeleteArtifact($artifactID: ID!, $deleteAliases: Boolean) {
            deleteArtifact(input: {
                artifactID: $artifactID
                deleteAliases: $deleteAliases
            }) {
                artifact {
                    id
                }
            }
        }
        """
    )
    self.client.execute(
        mutation,
        variable_values={
            "artifactID": self.id,
            "deleteAliases": delete_aliases,
        },
    )
    return True
def new_file(self, name, mode=None):
    """Unsupported: a saved artifact is immutable."""
    raise ValueError("Cannot add files to an artifact once it has been saved")

def add_file(self, local_path, name=None, is_tmp=False):
    """Unsupported: a saved artifact is immutable."""
    raise ValueError("Cannot add files to an artifact once it has been saved")

def add_dir(self, path, name=None):
    """Unsupported: a saved artifact is immutable."""
    raise ValueError("Cannot add files to an artifact once it has been saved")

def add_reference(self, uri, name=None, checksum=True, max_objects=None):
    """Unsupported: a saved artifact is immutable."""
    raise ValueError("Cannot add files to an artifact once it has been saved")

def add(self, obj, name):
    """Unsupported: a saved artifact is immutable."""
    raise ValueError("Cannot add files to an artifact once it has been saved")
def _add_download_root(self, dir_path):
    """Make `dir_path` a root directory for this artifact."""
    self._download_roots.add(os.path.abspath(dir_path))

def _is_download_root(self, dir_path):
    """Determine if `dir_path` is a root directory for this artifact."""
    return dir_path in self._download_roots

def _local_path_to_name(self, file_path):
    """Convert a local file path to a path entry in the artifact.

    Returns None when `file_path` is not under any known download root.
    """
    parts = os.path.abspath(file_path).split(os.sep)
    # Test each path prefix (shortest first) against the known roots; the
    # remainder of the path is the entry name within the artifact.
    for i in range(len(parts) + 1):
        if self._is_download_root(os.path.join(os.sep, *parts[:i])):
            return os.path.join(*parts[i:])
    return None
def _get_obj_entry(self, name):
    """Return an object entry by name, handling any type suffixes.

    When objects are added with `.add(obj, name)`, the name is typically changed to
    include the suffix of the object type when serializing to JSON. So we need to be
    able to resolve a name, without tasking the user with appending .THING.json.
    This method returns an entry if it exists by a suffixed name.

    Args:
        name: (str) name used when adding

    Returns:
        (entry, wb_class) on a match, else (None, None).
    """
    self._load_manifest()

    for wb_class in WBValue.type_mapping().values():
        entry = self._manifest.entries.get(wb_class.with_suffix(name))
        if entry is not None:
            return entry, wb_class
    return None, None
def get_path(self, name):
    """Return the downloadable entry stored at `name`.

    Raises:
        KeyError: If neither the plain name nor a type-suffixed variant
            exists in the manifest.
    """
    manifest = self._load_manifest()
    entry = manifest.entries.get(name) or self._get_obj_entry(name)[0]
    if entry is None:
        raise KeyError("Path not contained in artifact: %s" % name)

    return _DownloadedArtifactEntry(entry, self)
def get(self, name):
    """Deserialize and return the WBValue stored at `name`, or None.

    Args:
        name: (str) name the object was added under.
    """
    entry, wb_class = self._get_obj_entry(name)
    if entry is None:
        return None

    # If the entry is a reference from another artifact, then get it directly from that artifact
    if self._manifest_entry_is_artifact_reference(entry):
        ref_artifact = self._get_ref_artifact_from_entry(entry)
        return ref_artifact.get(util.uri_from_path(entry.ref))

    # Special case for wandb.Table. This is intended to be a short term optimization.
    # Since tables are likely to download many other assets in artifact(s), we eagerly download
    # the artifact using the parallelized `artifact.download`. In the future, we should refactor
    # the deserialization pattern such that this special case is not needed.
    if wb_class == wandb.Table:
        self.download(recursive=True)

    # Download just this entry, then rebuild the object from its JSON blob.
    item_path = self.get_path(entry.path).download()
    with open(item_path) as file:
        json_obj = json.load(file)
    result = wb_class.from_json(json_obj, self)
    result._set_artifact_source(self, name)
    return result
def download(self, root=None, recursive=False):
    """Download the artifact's contents into a directory.

    Arguments:
        root: (str, optional) Target directory; defaults to `_default_root()`.
        recursive: (bool) Also download artifacts referenced by this one.

    Returns:
        (str): The directory the files were downloaded into.
    """
    dirpath = root or self._default_root()
    self._add_download_root(dirpath)
    manifest = self._load_manifest()
    nfiles = len(manifest.entries)
    size = sum(e.size for e in manifest.entries.values())
    log = nfiles > 5000 or size > 50 * 1024 * 1024
    start_time = None
    if log:
        termlog(
            "Downloading large artifact {}, {:.2f}MB. {} files... ".format(
                self._artifact_name, size / (1024 * 1024), nfiles
            ),
        )
        start_time = datetime.datetime.now()

    # Force all the files to download into the same directory.
    # Download in parallel
    import multiprocessing.dummy  # this uses threads

    download_logger = _ArtifactDownloadLogger(nfiles=nfiles)

    pool = multiprocessing.dummy.Pool(32)
    try:
        pool.map(
            partial(self._download_file, root=dirpath, download_logger=download_logger),
            manifest.entries,
        )
        if recursive:
            pool.map(lambda artifact: artifact.download(), self._dependent_artifacts)
    finally:
        # FIX: previously the pool was only closed on the success path, so a
        # failed download leaked 32 worker threads. Always shut it down.
        pool.close()
        pool.join()

    self._is_downloaded = True

    if log:
        delta = abs((datetime.datetime.now() - start_time).total_seconds())
        hours = int(delta // 3600)
        minutes = int((delta - hours * 3600) // 60)
        seconds = delta - hours * 3600 - minutes * 60
        termlog(
            f"Done. {hours}:{minutes}:{seconds:.1f}",
            prefix=False,
        )
    return dirpath
def checkout(self, root=None):
    """Make `root` contain exactly this artifact's files.

    Deletes any file under `root` that is not a member of the artifact,
    then downloads the artifact into place.

    Returns:
        (str): The checkout directory.
    """
    dirpath = root or self._default_root(include_version=False)

    # Prune files on disk that do not belong to the artifact. (Renamed the
    # walk variable so it no longer shadows the `root` parameter.)
    for walk_dir, _, filenames in os.walk(dirpath):
        for filename in filenames:
            full_path = os.path.join(walk_dir, filename)
            artifact_path = util.to_forward_slash_path(
                os.path.relpath(full_path, start=dirpath)
            )
            try:
                self.get_path(artifact_path)
            except KeyError:
                # File is not part of the artifact, remove it.
                os.remove(full_path)

    return self.download(root=dirpath)
def verify(self, root=None):
    """Verify that the files under `root` match this artifact's manifest.

    Raises:
        ValueError: If a file on disk is not in the manifest, or a local
            file's digest differs from the recorded one.
    """
    dirpath = root or self._default_root()
    manifest = self._load_manifest()
    ref_count = 0

    # Pass 1: every file on disk must be a member of the artifact.
    for walk_dir, _, filenames in os.walk(dirpath):
        for filename in filenames:
            full_path = os.path.join(walk_dir, filename)
            artifact_path = util.to_forward_slash_path(
                os.path.relpath(full_path, start=dirpath)
            )
            try:
                self.get_path(artifact_path)
            except KeyError:
                raise ValueError(
                    "Found file {} which is not a member of artifact {}".format(
                        full_path, self.name
                    )
                )

    # Pass 2: every non-reference entry must match its recorded digest.
    for entry in manifest.entries.values():
        if entry.ref is None:
            if md5_file_b64(os.path.join(dirpath, entry.path)) != entry.digest:
                raise ValueError("Digest mismatch for file: %s" % entry.path)
        else:
            ref_count += 1
    if ref_count > 0:
        print("Warning: skipped verification of %s refs" % ref_count)
def file(self, root=None):
    """Download a single file artifact to dir specified by the root.

    Arguments:
        root: (str, optional) The root directory in which to place the file. Defaults to './artifacts/self.name/'.

    Returns:
        (str): The full path of the downloaded file.

    Raises:
        ValueError: If the artifact contains zero or more than one file.
    """
    if root is None:
        root = os.path.join(".", "artifacts", self.name)

    manifest = self._load_manifest()
    nfiles = len(manifest.entries)
    if nfiles > 1:
        raise ValueError(
            "This artifact contains more than one file, call `.download()` to get all files or call "
            '.get_path("filename").download()'
        )
    if nfiles == 0:
        # FIX: an empty artifact previously fell through to an opaque
        # IndexError from list(...)[0]; fail with a clear message instead.
        raise ValueError("This artifact contains no files")

    return self._download_file(next(iter(manifest.entries)), root=root)
def _download_file(
    self, name, root, download_logger: Optional[_ArtifactDownloadLogger] = None
):
    """Download one manifest entry into `root`, reporting progress if asked."""
    # download file into cache and copy to target dir
    downloaded_path = self.get_path(name).download(root)
    if download_logger is not None:
        download_logger.notify_downloaded()
    return downloaded_path
def _default_root(self, include_version=True):
    """Compute the default download directory for this artifact."""
    name = self.name if include_version else self._sequence_name
    root = os.path.join(get_artifact_dir(), name)
    # Windows forbids ":" outside the drive letter, so replace the version
    # separator in the path tail.
    if platform.system() == "Windows":
        head, tail = os.path.splitdrive(root)
        root = head + tail.replace(":", "-")
    return root
def json_encode(self):
    """Serialize this artifact to its JSON-safe media representation."""
    return util.artifact_to_json(self)
@normalize_exceptions
def save(self):
    """Persists artifact changes to the wandb backend."""
    mutation = gql(
        """
        mutation updateArtifact(
            $artifactID: ID!,
            $description: String,
            $metadata: JSONString,
            $aliases: [ArtifactAliasInput!]
        ) {
            updateArtifact(input: {
                artifactID: $artifactID,
                description: $description,
                metadata: $metadata,
                aliases: $aliases
            }) {
                artifact {
                    id
                }
            }
        }
        """
    )
    # Probe whether the server exposes the dedicated alias endpoints.
    introspect_query = gql(
        """
        query ProbeServerAddAliasesInput {
            AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
                name
                inputFields {
                    name
                }
            }
        }
        """
    )
    res = self.client.execute(introspect_query)
    aliases = None
    if not res.get("AddAliasesInputInfoType"):
        # If valid, wandb backend version >= 0.13.0.
        # This means we can safely remove aliases from this updateArtifact request since we'll be calling
        # the alias endpoints below in _save_alias_changes.
        # If not valid, wandb backend version < 0.13.0. This requires aliases to be sent in updateArtifact.
        aliases = [
            {
                "artifactCollectionName": self._artifact_collection_name,
                "alias": alias,
            }
            for alias in self._aliases
        ]

    self.client.execute(
        mutation,
        variable_values={
            "artifactID": self.id,
            "description": self.description,
            "metadata": util.json_dumps_safer(self.metadata),
            "aliases": aliases,
        },
    )
    # Save locally modified aliases
    self._save_alias_changes()
    return True
def wait(self):
    """No-op for API-fetched artifacts; returns self for call-chaining parity."""
    return self
@normalize_exceptions
def _save_alias_changes(self):
    """Persist alias changes on this artifact to the wandb backend.

    Called by artifact.save().
    """
    aliases_to_add = set(self._aliases) - set(self._frozen_aliases)
    aliases_to_remove = set(self._frozen_aliases) - set(self._aliases)

    # Introspect: the alias endpoints exist only on backends exposing
    # AddAliasesInput; on older servers this is a silent no-op.
    introspect_query = gql(
        """
        query ProbeServerAddAliasesInput {
            AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
                name
                inputFields {
                    name
                }
            }
        }
        """
    )
    res = self.client.execute(introspect_query)
    if not res.get("AddAliasesInputInfoType"):
        return

    def _alias_inputs(alias_names):
        # Shared payload shape for both the add and delete mutations.
        return [
            {
                "artifactCollectionName": self._artifact_collection_name,
                "alias": alias,
                "entityName": self._entity,
                "projectName": self._project,
            }
            for alias in alias_names
        ]

    if len(aliases_to_add) > 0:
        add_mutation = gql(
            """
            mutation addAliases(
                $artifactID: ID!,
                $aliases: [ArtifactCollectionAliasInput!]!,
            ) {
                addAliases(
                    input: {
                        artifactID: $artifactID,
                        aliases: $aliases,
                    }
                ) {
                    success
                }
            }
            """
        )
        self.client.execute(
            add_mutation,
            variable_values={
                "artifactID": self.id,
                "aliases": _alias_inputs(aliases_to_add),
            },
        )

    if len(aliases_to_remove) > 0:
        delete_mutation = gql(
            """
            mutation deleteAliases(
                $artifactID: ID!,
                $aliases: [ArtifactCollectionAliasInput!]!,
            ) {
                deleteAliases(
                    input: {
                        artifactID: $artifactID,
                        aliases: $aliases,
                    }
                ) {
                    success
                }
            }
            """
        )
        self.client.execute(
            delete_mutation,
            variable_values={
                "artifactID": self.id,
                "aliases": _alias_inputs(aliases_to_remove),
            },
        )

    # reset local state
    self._frozen_aliases = self._aliases
    return True
# TODO: not yet public, but we probably want something like this.
def _list(self):
    """Return the names of all entries in the artifact's manifest."""
    return self._load_manifest().entries.keys()
def __repr__(self):
    """Terse debugging representation keyed by the server-side id."""
    return "<Artifact {}>".format(self.id)
def _load(self):
    """Fetch this artifact's attributes from the backend into `_attrs`.

    Raises:
        ValueError: If the name lacks an alias suffix, or the artifact is
            not found in the entity/project.
    """
    query = gql(
        """
        query Artifact(
            $entityName: String,
            $projectName: String,
            $name: String!
        ) {
            project(name: $projectName, entityName: $entityName) {
                artifact(name: $name) {
                    ...ArtifactFragment
                }
            }
        }
        %s
        """
        % ARTIFACT_FRAGMENT
    )
    response = None
    try:
        response = self.client.execute(
            query,
            variable_values={
                "entityName": self.entity,
                "projectName": self.project,
                "name": self._artifact_name,
            },
        )
    except Exception:
        # we check for this after doing the call, since the backend supports raw digest lookups
        # which don't include ":" and are 32 characters long
        if ":" not in self._artifact_name and len(self._artifact_name) != 32:
            raise ValueError(
                'Attempted to fetch artifact without alias (e.g. "<artifact_name>:v3" or "<artifact_name>:latest")'
            )
    if (
        response is None
        or response.get("project") is None
        or response["project"].get("artifact") is None
    ):
        raise ValueError(
            f'Project {self.entity}/{self.project} does not contain artifact: "{self._artifact_name}"'
        )
    self._attrs = response["project"]["artifact"]
    return self._attrs
|
-
def files(self, names=None, per_page=50):
|
5084
|
-
"""Iterate over all files stored in this artifact.
|
5085
|
-
|
5086
|
-
Arguments:
|
5087
|
-
names: (list of str, optional) The filename paths relative to the
|
5088
|
-
root of the artifact you wish to list.
|
5089
|
-
per_page: (int, default 50) The number of files to return per request
|
5090
|
-
|
5091
|
-
Returns:
|
5092
|
-
(`ArtifactFiles`): An iterator containing `File` objects
|
5093
|
-
"""
|
5094
|
-
return ArtifactFiles(self.client, self, names, per_page)
|
5095
|
-
|
5096
|
-
def _load_manifest(self):
|
5097
|
-
if self._manifest is None:
|
5098
|
-
query = gql(
|
5099
|
-
"""
|
5100
|
-
query ArtifactManifest(
|
5101
|
-
$entityName: String!,
|
5102
|
-
$projectName: String!,
|
5103
|
-
$name: String!
|
5104
|
-
) {
|
5105
|
-
project(name: $projectName, entityName: $entityName) {
|
5106
|
-
artifact(name: $name) {
|
5107
|
-
currentManifest {
|
5108
|
-
id
|
5109
|
-
file {
|
5110
|
-
id
|
5111
|
-
directUrl
|
5112
|
-
}
|
5113
|
-
}
|
5114
|
-
}
|
5115
|
-
}
|
5116
|
-
}
|
5117
|
-
"""
|
5118
|
-
)
|
5119
|
-
response = self.client.execute(
|
5120
|
-
query,
|
5121
|
-
variable_values={
|
5122
|
-
"entityName": self.entity,
|
5123
|
-
"projectName": self.project,
|
5124
|
-
"name": self._artifact_name,
|
5125
|
-
},
|
5126
|
-
)
|
5127
|
-
|
5128
|
-
index_file_url = response["project"]["artifact"]["currentManifest"]["file"][
|
5129
|
-
"directUrl"
|
5130
|
-
]
|
5131
|
-
with requests.get(index_file_url) as req:
|
5132
|
-
req.raise_for_status()
|
5133
|
-
self._manifest = artifacts.ArtifactManifest.from_manifest_json(
|
5134
|
-
json.loads(util.ensure_text(req.content))
|
5135
|
-
)
|
5136
|
-
|
5137
|
-
self._load_dependent_manifests()
|
5138
|
-
|
5139
|
-
return self._manifest
|
5140
|
-
|
def _load_dependent_manifests(self):
    """Interrogate entries and ensure we have loaded their manifests."""
    # Make sure dependencies are avail
    for entry in self._manifest.entries.values():
        if self._manifest_entry_is_artifact_reference(entry):
            dep_artifact = self._get_ref_artifact_from_entry(entry)
            if dep_artifact not in self._dependent_artifacts:
                dep_artifact._load_manifest()
                self._dependent_artifacts.append(dep_artifact)
@staticmethod
def _manifest_entry_is_artifact_reference(entry):
    """Determine if an ArtifactManifestEntry is an artifact reference."""
    if entry.ref is None:
        return False
    return urllib.parse.urlparse(entry.ref).scheme == "wandb-artifact"
|
-
def _get_ref_artifact_from_entry(self, entry):
|
5161
|
-
"""Helper function returns the referenced artifact from an entry."""
|
5162
|
-
artifact_id = util.host_from_path(entry.ref)
|
5163
|
-
return Artifact.from_id(hex_to_b64_id(artifact_id), self.client)
|
5164
|
-
|
def used_by(self):
    """Retrieve the runs which use this artifact directly.

    Returns:
        [Run]: a list of Run objects which use this artifact
    """
    query = gql(
        """
        query ArtifactUsedBy(
            $id: ID!,
            $before: String,
            $after: String,
            $first: Int,
            $last: Int
        ) {
            artifact(id: $id) {
                usedBy(before: $before, after: $after, first: $first, last: $last) {
                    edges {
                        node {
                            name
                            project {
                                name
                                entityName
                            }
                        }
                    }
                }
            }
        }
        """
    )
    response = self.client.execute(
        query,
        variable_values={"id": self.id},
    )
    # FIX: the server can return an explicit null for "artifact" or
    # "usedBy"; `dict.get(key, {})` does not guard against that (the key
    # exists with value None), so coerce each level with `or`.
    edges = ((response.get("artifact") or {}).get("usedBy") or {}).get("edges") or []
    # yes, "name" is actually id
    return [
        Run(
            self.client,
            edge["node"]["project"]["entityName"],
            edge["node"]["project"]["name"],
            edge["node"]["name"],
        )
        for edge in edges
    ]
def logged_by(self):
    """Retrieve the run which logged this artifact.

    Returns:
        Run: Run object which logged this artifact, or None when the
        artifact was not created by a run (or the creator is unavailable).
    """
    query = gql(
        """
        query ArtifactCreatedBy(
            $id: ID!
        ) {
            artifact(id: $id) {
                createdBy {
                    ... on Run {
                        name
                        project {
                            name
                            entityName
                        }
                    }
                }
            }
        }
        """
    )
    response = self.client.execute(
        query,
        variable_values={"id": self.id},
    )
    # FIX: `createdBy` may be an explicit null, or a non-Run creator, in
    # which case the `... on Run` fragment yields an empty dict. The old
    # `is not None` check let the empty dict through and raised KeyError;
    # require the Run fields to be present before constructing the Run.
    run_obj = (response.get("artifact") or {}).get("createdBy") or {}
    if run_obj.get("name") and run_obj.get("project"):
        return Run(
            self.client,
            run_obj["project"]["entityName"],
            run_obj["project"]["name"],
            run_obj["name"],
        )
    return None
|
-
|
class ArtifactVersions(Paginator):
    """An iterable collection of artifact versions associated with a project and optional filter.

    This is generally used indirectly via the `Api`.artifact_versions method.
    """

    def __init__(
        self,
        client: Client,
        entity: str,
        project: str,
        collection_name: str,
        type: str,
        filters: Optional[Mapping[str, Any]] = None,
        order: Optional[str] = None,
        per_page: int = 50,
    ):
        self.entity = entity
        self.collection_name = collection_name
        self.type = type
        self.project = project
        # Only committed versions are listed unless the caller overrides.
        self.filters = filters if filters is not None else {"state": "COMMITTED"}
        self.order = order
        variables = {
            "project": self.project,
            "entity": self.entity,
            "order": self.order,
            "type": self.type,
            "collection": self.collection_name,
            "filters": json.dumps(self.filters),
        }
        # The collection edge name depends on the server's capabilities, so
        # the query is built per-instance rather than as a class attribute.
        self.QUERY = gql(
            """
            query Artifacts($project: String!, $entity: String!, $type: String!, $collection: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
                project(name: $project, entityName: $entity) {{
                    artifactType(name: $type) {{
                        artifactCollection: {}(name: $collection) {{
                            name
                            artifacts(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
                                totalCount
                                edges {{
                                    node {{
                                        ...ArtifactFragment
                                    }}
                                    version
                                    cursor
                                }}
                                pageInfo {{
                                    endCursor
                                    hasNextPage
                                }}
                            }}
                        }}
                    }}
                }}
            }}
            {}
            """.format(
                artifact_collection_edge_name(
                    server_supports_artifact_collections_gql_edges(client)
                ),
                ARTIFACT_FRAGMENT,
            )
        )
        super().__init__(client, variables, per_page)

    @property
    def length(self):
        """Total number of versions, or None before the first fetch."""
        if not self.last_response:
            return None
        return self.last_response["project"]["artifactType"]["artifactCollection"][
            "artifacts"
        ]["totalCount"]

    @property
    def more(self):
        """Whether another page exists; assumed True before the first fetch."""
        if not self.last_response:
            return True
        return self.last_response["project"]["artifactType"]["artifactCollection"][
            "artifacts"
        ]["pageInfo"]["hasNextPage"]

    @property
    def cursor(self):
        """Pagination cursor of the last edge, or None before the first fetch."""
        if not self.last_response:
            return None
        return self.last_response["project"]["artifactType"]["artifactCollection"][
            "artifacts"
        ]["edges"][-1]["cursor"]

    def convert_objects(self):
        """Materialize the last response's edges as `Artifact` objects."""
        collection = self.last_response["project"]["artifactType"]["artifactCollection"]
        if collection is None:
            return []
        return [
            Artifact(
                self.client,
                self.entity,
                self.project,
                self.collection_name + ":" + edge["version"],
                edge["node"],
            )
            for edge in collection["artifacts"]["edges"]
        ]
5360
|
-
|
5361
|
-
class ArtifactFiles(Paginator):
|
5362
|
-
QUERY = gql(
|
5363
|
-
"""
|
5364
|
-
query ArtifactFiles(
|
5365
|
-
$entityName: String!,
|
5366
|
-
$projectName: String!,
|
5367
|
-
$artifactTypeName: String!,
|
5368
|
-
$artifactName: String!
|
5369
|
-
$fileNames: [String!],
|
5370
|
-
$fileCursor: String,
|
5371
|
-
$fileLimit: Int = 50
|
5372
|
-
) {
|
5373
|
-
project(name: $projectName, entityName: $entityName) {
|
5374
|
-
artifactType(name: $artifactTypeName) {
|
5375
|
-
artifact(name: $artifactName) {
|
5376
|
-
...ArtifactFilesFragment
|
5377
|
-
}
|
5378
|
-
}
|
5379
|
-
}
|
5380
|
-
}
|
5381
|
-
%s
|
5382
|
-
"""
|
5383
|
-
% ARTIFACT_FILES_FRAGMENT
|
4218
|
+
% ARTIFACT_FILES_FRAGMENT
|
5384
4219
|
)
|
5385
4220
|
|
5386
4221
|
def __init__(
|
5387
4222
|
self,
|
5388
4223
|
client: Client,
|
5389
|
-
artifact: Artifact,
|
4224
|
+
artifact: "wandb.Artifact",
|
5390
4225
|
names: Optional[Sequence[str]] = None,
|
5391
4226
|
per_page: int = 50,
|
5392
4227
|
):
|
5393
4228
|
self.artifact = artifact
|
5394
4229
|
variables = {
|
5395
|
-
"entityName": artifact.
|
5396
|
-
"projectName": artifact.
|
4230
|
+
"entityName": artifact.source_entity,
|
4231
|
+
"projectName": artifact.source_project,
|
5397
4232
|
"artifactTypeName": artifact.type,
|
5398
|
-
"artifactName": artifact.
|
4233
|
+
"artifactName": artifact.source_name,
|
5399
4234
|
"fileNames": names,
|
5400
4235
|
}
|
5401
4236
|
# The server must advertise at least SDK 0.12.21
|
@@ -5452,6 +4287,7 @@ class Job:
|
|
5452
4287
|
_entity: str
|
5453
4288
|
_project: str
|
5454
4289
|
_entrypoint: List[str]
|
4290
|
+
_notebook_job: bool
|
5455
4291
|
|
5456
4292
|
def __init__(self, api: Api, name, path: Optional[str] = None) -> None:
|
5457
4293
|
try:
|
@@ -5468,22 +4304,25 @@ class Job:
|
|
5468
4304
|
self._entity = api.default_entity
|
5469
4305
|
|
5470
4306
|
with open(os.path.join(self._fpath, "wandb-job.json")) as f:
|
5471
|
-
self.
|
5472
|
-
|
5473
|
-
|
4307
|
+
self._job_info: Mapping[str, Any] = json.load(f)
|
4308
|
+
source_info = self._job_info.get("source", {})
|
4309
|
+
# only use notebook job if entrypoint not set and notebook is set
|
4310
|
+
self._notebook_job = source_info.get("notebook", False)
|
4311
|
+
self._entrypoint = source_info.get("entrypoint")
|
4312
|
+
self._args = source_info.get("args")
|
5474
4313
|
self._requirements_file = os.path.join(self._fpath, "requirements.frozen.txt")
|
5475
4314
|
self._input_types = TypeRegistry.type_from_dict(
|
5476
|
-
self.
|
4315
|
+
self._job_info.get("input_types")
|
5477
4316
|
)
|
5478
4317
|
self._output_types = TypeRegistry.type_from_dict(
|
5479
|
-
self.
|
4318
|
+
self._job_info.get("output_types")
|
5480
4319
|
)
|
5481
4320
|
|
5482
|
-
if self.
|
4321
|
+
if self._job_info.get("source_type") == "artifact":
|
5483
4322
|
self._set_configure_launch_project(self._configure_launch_project_artifact)
|
5484
|
-
if self.
|
4323
|
+
if self._job_info.get("source_type") == "repo":
|
5485
4324
|
self._set_configure_launch_project(self._configure_launch_project_repo)
|
5486
|
-
if self.
|
4325
|
+
if self._job_info.get("source_type") == "image":
|
5487
4326
|
self._set_configure_launch_project(self._configure_launch_project_container)
|
5488
4327
|
|
5489
4328
|
@property
|
@@ -5493,8 +4332,26 @@ class Job:
|
|
5493
4332
|
def _set_configure_launch_project(self, func):
|
5494
4333
|
self.configure_launch_project = func
|
5495
4334
|
|
4335
|
+
def _get_code_artifact(self, artifact_string):
|
4336
|
+
artifact_string, base_url, is_id = util.parse_artifact_string(artifact_string)
|
4337
|
+
if is_id:
|
4338
|
+
code_artifact = wandb.Artifact._from_id(artifact_string, self._api._client)
|
4339
|
+
else:
|
4340
|
+
code_artifact = self._api.artifact(name=artifact_string, type="code")
|
4341
|
+
if code_artifact is None:
|
4342
|
+
raise LaunchError("No code artifact found")
|
4343
|
+
return code_artifact
|
4344
|
+
|
4345
|
+
def _configure_launch_project_notebook(self, launch_project):
|
4346
|
+
new_fname = convert_jupyter_notebook_to_script(
|
4347
|
+
self._entrypoint[-1], launch_project.project_dir
|
4348
|
+
)
|
4349
|
+
new_entrypoint = self._entrypoint
|
4350
|
+
new_entrypoint[-1] = new_fname
|
4351
|
+
launch_project.add_entry_point(new_entrypoint)
|
4352
|
+
|
5496
4353
|
def _configure_launch_project_repo(self, launch_project):
|
5497
|
-
git_info = self.
|
4354
|
+
git_info = self._job_info.get("source", {}).get("git", {})
|
5498
4355
|
_fetch_git_repo(
|
5499
4356
|
launch_project.project_dir,
|
5500
4357
|
git_info["remote"],
|
@@ -5504,27 +4361,30 @@ class Job:
|
|
5504
4361
|
with open(os.path.join(self._fpath, "diff.patch")) as f:
|
5505
4362
|
apply_patch(f.read(), launch_project.project_dir)
|
5506
4363
|
shutil.copy(self._requirements_file, launch_project.project_dir)
|
5507
|
-
launch_project.
|
5508
|
-
|
4364
|
+
launch_project.python_version = self._job_info.get("runtime")
|
4365
|
+
if self._notebook_job:
|
4366
|
+
self._configure_launch_project_notebook(launch_project)
|
4367
|
+
else:
|
4368
|
+
launch_project.add_entry_point(self._entrypoint)
|
5509
4369
|
|
5510
4370
|
def _configure_launch_project_artifact(self, launch_project):
|
5511
|
-
artifact_string = self.
|
4371
|
+
artifact_string = self._job_info.get("source", {}).get("artifact")
|
5512
4372
|
if artifact_string is None:
|
5513
4373
|
raise LaunchError(f"Job {self.name} had no source artifact")
|
5514
|
-
|
5515
|
-
|
5516
|
-
|
5517
|
-
else:
|
5518
|
-
code_artifact = self._api.artifact(name=artifact_string, type="code")
|
5519
|
-
if code_artifact is None:
|
5520
|
-
raise LaunchError("No code artifact found")
|
5521
|
-
code_artifact.download(launch_project.project_dir)
|
4374
|
+
|
4375
|
+
code_artifact = self._get_code_artifact(artifact_string)
|
4376
|
+
launch_project.python_version = self._job_info.get("runtime")
|
5522
4377
|
shutil.copy(self._requirements_file, launch_project.project_dir)
|
5523
|
-
|
5524
|
-
launch_project.
|
4378
|
+
|
4379
|
+
code_artifact.download(launch_project.project_dir)
|
4380
|
+
|
4381
|
+
if self._notebook_job:
|
4382
|
+
self._configure_launch_project_notebook(launch_project)
|
4383
|
+
else:
|
4384
|
+
launch_project.add_entry_point(self._entrypoint)
|
5525
4385
|
|
5526
4386
|
def _configure_launch_project_container(self, launch_project):
|
5527
|
-
launch_project.docker_image = self.
|
4387
|
+
launch_project.docker_image = self._job_info.get("source", {}).get("image")
|
5528
4388
|
if launch_project.docker_image is None:
|
5529
4389
|
raise LaunchError(
|
5530
4390
|
"Job had malformed source dictionary without an image key"
|
@@ -5550,7 +4410,7 @@ class Job:
|
|
5550
4410
|
run_config = {}
|
5551
4411
|
for key, item in config.items():
|
5552
4412
|
if util._is_artifact_object(item):
|
5553
|
-
if isinstance(item, wandb.Artifact) and item.
|
4413
|
+
if isinstance(item, wandb.Artifact) and item.is_draft():
|
5554
4414
|
raise ValueError("Cannot queue jobs with unlogged artifacts")
|
5555
4415
|
run_config[key] = util.artifact_to_json(item)
|
5556
4416
|
|