wandb 0.21.1__py3-none-musllinux_1_2_aarch64.whl → 0.21.3__py3-none-musllinux_1_2_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +1 -1
- wandb/apis/public/api.py +1 -2
- wandb/apis/public/artifacts.py +3 -5
- wandb/apis/public/registries/_utils.py +14 -16
- wandb/apis/public/registries/registries_search.py +176 -289
- wandb/apis/public/reports.py +13 -10
- wandb/automations/_generated/delete_automation.py +1 -3
- wandb/automations/_generated/enums.py +13 -11
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +47 -2
- wandb/integration/metaflow/data_pandas.py +2 -2
- wandb/integration/metaflow/data_pytorch.py +75 -0
- wandb/integration/metaflow/data_sklearn.py +76 -0
- wandb/integration/metaflow/metaflow.py +16 -87
- wandb/integration/weave/__init__.py +6 -0
- wandb/integration/weave/interface.py +49 -0
- wandb/integration/weave/weave.py +63 -0
- wandb/proto/v3/wandb_internal_pb2.py +3 -2
- wandb/proto/v4/wandb_internal_pb2.py +2 -2
- wandb/proto/v5/wandb_internal_pb2.py +2 -2
- wandb/proto/v6/wandb_internal_pb2.py +2 -2
- wandb/sdk/artifacts/_factories.py +17 -0
- wandb/sdk/artifacts/_generated/__init__.py +221 -13
- wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
- wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
- wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
- wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
- wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
- wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
- wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
- wandb/sdk/artifacts/_generated/enums.py +5 -0
- wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
- wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
- wandb/sdk/artifacts/_generated/fragments.py +279 -41
- wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
- wandb/sdk/artifacts/_generated/operations.py +654 -51
- wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
- wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
- wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
- wandb/sdk/artifacts/_graphql_fragments.py +3 -86
- wandb/sdk/artifacts/_validators.py +6 -4
- wandb/sdk/artifacts/artifact.py +410 -547
- wandb/sdk/artifacts/artifact_file_cache.py +11 -7
- wandb/sdk/artifacts/artifact_manifest.py +10 -9
- wandb/sdk/artifacts/artifact_manifest_entry.py +15 -18
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
- wandb/sdk/data_types/video.py +2 -2
- wandb/sdk/interface/interface_queue.py +1 -4
- wandb/sdk/interface/interface_shared.py +26 -37
- wandb/sdk/interface/interface_sock.py +24 -14
- wandb/sdk/internal/settings_static.py +2 -3
- wandb/sdk/launch/create_job.py +12 -1
- wandb/sdk/launch/inputs/internal.py +25 -24
- wandb/sdk/launch/inputs/schema.py +31 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +24 -29
- wandb/sdk/lib/asyncio_compat.py +16 -16
- wandb/sdk/lib/asyncio_manager.py +252 -0
- wandb/sdk/lib/hashutil.py +13 -4
- wandb/sdk/lib/paths.py +23 -21
- wandb/sdk/lib/printer.py +2 -2
- wandb/sdk/lib/printer_asyncio.py +3 -1
- wandb/sdk/lib/retry.py +185 -78
- wandb/sdk/lib/service/service_client.py +106 -0
- wandb/sdk/lib/service/service_connection.py +20 -26
- wandb/sdk/lib/service/service_token.py +30 -13
- wandb/sdk/mailbox/mailbox.py +13 -5
- wandb/sdk/mailbox/mailbox_handle.py +22 -13
- wandb/sdk/mailbox/response_handle.py +42 -106
- wandb/sdk/mailbox/wait_with_progress.py +7 -42
- wandb/sdk/wandb_init.py +11 -25
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_run.py +92 -56
- wandb/sdk/wandb_settings.py +45 -32
- wandb/sdk/wandb_setup.py +176 -96
- wandb/util.py +1 -1
- {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/METADATA +2 -2
- {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/RECORD +88 -72
- wandb/sdk/interface/interface_relay.py +0 -38
- wandb/sdk/interface/router.py +0 -89
- wandb/sdk/interface/router_queue.py +0 -43
- wandb/sdk/interface/router_relay.py +0 -50
- wandb/sdk/interface/router_sock.py +0 -32
- wandb/sdk/lib/sock_client.py +0 -232
- {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/WHEEL +0 -0
- {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/entry_points.txt +0 -0
- {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/artifacts/artifact.py
CHANGED
@@ -17,11 +17,22 @@ import time
 from collections import deque
 from copy import copy
 from dataclasses import dataclass
-from datetime import
+from datetime import timedelta
 from functools import partial
 from itertools import filterfalse
-from pathlib import PurePosixPath
-from typing import
+from pathlib import Path, PurePosixPath
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    Final,
+    Iterator,
+    Literal,
+    Sequence,
+    Type,
+    cast,
+    final,
+)
 from urllib.parse import quote, urljoin, urlparse
 
 import requests
@@ -43,8 +54,9 @@ from wandb.sdk.data_types._dtypes import Type as WBType
 from wandb.sdk.data_types._dtypes import TypeRegistry
 from wandb.sdk.internal.internal_api import Api as InternalApi
 from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
-from wandb.sdk.lib import
+from wandb.sdk.lib import retry, runid, telemetry
 from wandb.sdk.lib.deprecate import deprecate
+from wandb.sdk.lib.filesystem import check_exists, system_preferred_path
 from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64
 from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath, URIStr
 from wandb.sdk.lib.runid import generate_id
@@ -58,21 +70,45 @@ from wandb.util import (
     vendor_setup,
 )
 
+from ._factories import make_storage_policy
 from ._generated import (
     ADD_ALIASES_GQL,
+    ARTIFACT_BY_ID_GQL,
+    ARTIFACT_BY_NAME_GQL,
+    ARTIFACT_COLLECTION_MEMBERSHIP_FILE_URLS_GQL,
+    ARTIFACT_CREATED_BY_GQL,
+    ARTIFACT_FILE_URLS_GQL,
+    ARTIFACT_TYPE_GQL,
+    ARTIFACT_USED_BY_GQL,
+    ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
     DELETE_ALIASES_GQL,
+    DELETE_ARTIFACT_GQL,
+    FETCH_ARTIFACT_MANIFEST_GQL,
     FETCH_LINKED_ARTIFACTS_GQL,
     LINK_ARTIFACT_GQL,
+    UNLINK_ARTIFACT_GQL,
     UPDATE_ARTIFACT_GQL,
     ArtifactAliasInput,
+    ArtifactByID,
+    ArtifactByName,
     ArtifactCollectionAliasInput,
+    ArtifactCollectionMembershipFileUrls,
+    ArtifactCreatedBy,
+    ArtifactFileUrls,
+    ArtifactFragment,
+    ArtifactType,
+    ArtifactUsedBy,
+    ArtifactViaMembershipByName,
+    FetchArtifactManifest,
     FetchLinkedArtifacts,
+    FileUrlsFragment,
     LinkArtifact,
     LinkArtifactInput,
+    MembershipWithArtifact,
     TagInput,
     UpdateArtifact,
 )
-from ._graphql_fragments import
+from ._graphql_fragments import omit_artifact_fields
 from ._validators import (
     LINKED_ARTIFACT_COLLECTION_TYPE,
     ArtifactPath,
@@ -103,9 +139,6 @@ from .exceptions import (
 )
 from .staging import get_staging_dir
 from .storage_handlers.gcs_handler import _GCSIsADirectoryError
-from .storage_layout import StorageLayout
-from .storage_policies import WANDB_STORAGE_POLICY
-from .storage_policy import StoragePolicy
 
 reset_path = vendor_setup()
 
@@ -119,6 +152,9 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+_MB: Final[int] = 1024 * 1024
+
+
 @final
 @dataclass
 class _DeferredArtifactManifest:
@@ -188,11 +224,6 @@ class Artifact:
         # Internal.
         self._client: RetryingClient | None = None
 
-        storage_policy_cls = StoragePolicy.lookup_by_name(WANDB_STORAGE_POLICY)
-        layout = StorageLayout.V1 if env.get_use_v1_artifacts() else StorageLayout.V2
-        policy_config = {"storageLayout": layout}
-        self._storage_policy = storage_policy_cls.from_config(config=policy_config)
-
        self._tmp_dir: tempfile.TemporaryDirectory | None = None
         self._added_objs: dict[int, tuple[WBValue, ArtifactManifestEntry]] = {}
         self._added_local_paths: dict[str, ArtifactManifestEntry] = {}
@@ -236,7 +267,7 @@ class Artifact:
         self._use_as: str | None = None
         self._state: ArtifactState = ArtifactState.PENDING
         self._manifest: ArtifactManifest | _DeferredArtifactManifest | None = (
-            ArtifactManifestV1(
+            ArtifactManifestV1(storage_policy=make_storage_policy())
        )
         self._commit_hash: str | None = None
         self._file_count: int | None = None
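The storage-policy setup removed from `Artifact.__init__` above now lives behind `make_storage_policy()` in the new `wandb/sdk/artifacts/_factories.py` (+17 lines per the file list). A minimal sketch of such a factory, assuming it simply relocates the removed inline logic; the actual 0.21.3 implementation may differ:

    from wandb import env
    from wandb.sdk.artifacts.storage_layout import StorageLayout
    from wandb.sdk.artifacts.storage_policies import WANDB_STORAGE_POLICY
    from wandb.sdk.artifacts.storage_policy import StoragePolicy

    def make_storage_policy() -> StoragePolicy:
        # Same steps as the removed inline code: look up the registered policy
        # class, pick the storage layout from the environment, build the policy.
        policy_cls = StoragePolicy.lookup_by_name(WANDB_STORAGE_POLICY)
        layout = StorageLayout.V1 if env.get_use_v1_artifacts() else StorageLayout.V2
        return policy_cls.from_config(config={"storageLayout": layout})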
@@ -257,32 +288,22 @@ class Artifact:
         if (artifact := artifact_instance_cache.get(artifact_id)) is not None:
             return artifact
 
-        query =
-
-
-
-
-
-            }
-            """
-            + _gql_artifact_fragment()
-        )
-        response = client.execute(
-            query,
-            variable_values={"id": artifact_id},
-        )
-        attrs = response.get("artifact")
-        if attrs is None:
+        query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
+
+        data = client.execute(query, variable_values={"id": artifact_id})
+        result = ArtifactByID.model_validate(data)
+
+        if (art := result.artifact) is None:
             return None
 
-        src_collection =
-        src_project = src_collection
+        src_collection = art.artifact_sequence
+        src_project = src_collection.project
 
-        entity_name = src_project
-        project_name = src_project
+        entity_name = src_project.entity_name if src_project else ""
+        project_name = src_project.name if src_project else ""
 
-        name = "{}:v{}"
-        return cls._from_attrs(entity_name, project_name, name,
+        name = f"{src_collection.name}:v{art.version_index}"
+        return cls._from_attrs(entity_name, project_name, name, art, client)
 
     @classmethod
     def _membership_from_name(
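This hunk is the template for most of the rewrites below: a hand-written GraphQL string and dict-walking (`response.get("artifact")`) are replaced by a generated query constant plus a pydantic response model whose `model_validate` call enforces the response shape. A generic sketch of the pattern with hypothetical names; the real models are code-generated into `wandb/sdk/artifacts/_generated/`:

    from __future__ import annotations

    from pydantic import BaseModel

    # Hypothetical response models mirroring the query's selection set.
    class ProjectFields(BaseModel):
        id: str
        name: str

    class ProjectByName(BaseModel):
        project: ProjectFields | None = None

    def fetch_project(client, query, gql_vars: dict) -> ProjectFields:
        data = client.execute(query, variable_values=gql_vars)
        result = ProjectByName.model_validate(data)  # raises on shape mismatch
        if (project := result.project) is None:      # walrus guard, as above
            raise ValueError("project not found")
        return project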
@@ -293,7 +314,7 @@ class Artifact:
         name: str,
         client: RetryingClient,
     ) -> Artifact:
-        if not InternalApi()._server_supports(
+        if not (api := InternalApi())._server_supports(
             pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
         ):
             raise UnsupportedError(
@@ -301,69 +322,26 @@ class Artifact:
                 "by this version of wandb server. Consider updating to the latest version."
             )
 
-        query =
-
-
-            project(name: $projectName, entityName: $entityName) {{
-                artifactCollectionMembership(name: $name) {{
-                    id
-                    artifactCollection {{
-                        id
-                        name
-                        project {{
-                            id
-                            entityName
-                            name
-                        }}
-                    }}
-                    artifact {{
-                        ...ArtifactFragment
-                    }}
-                }}
-            }}
-        }}
-        {_gql_artifact_fragment()}
-        """
+        query = gql_compat(
+            ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
+            omit_fields=omit_artifact_fields(api=api),
         )
 
-
-
-
-
-
-        response = client.execute(
-            query,
-            variable_values=query_variable_values,
-        )
-        if not (project_attrs := response.get("project")):
+        gql_vars = {"entityName": entity, "projectName": project, "name": name}
+        data = client.execute(query, variable_values=gql_vars)
+        result = ArtifactViaMembershipByName.model_validate(data)
+
+        if not (project_attrs := result.project):
             raise ValueError(f"project {project!r} not found under entity {entity!r}")
-
+
+        if not (acm_attrs := project_attrs.artifact_collection_membership):
             entity_project = f"{entity}/{project}"
             raise ValueError(
                 f"artifact membership {name!r} not found in {entity_project!r}"
             )
-        if not (ac_attrs := acm_attrs.get("artifactCollection")):
-            raise ValueError("artifact collection not found")
-        if not (
-            (ac_name := ac_attrs.get("name"))
-            and (ac_project_attrs := ac_attrs.get("project"))
-        ):
-            raise ValueError("artifact collection project not found")
-        ac_project = ac_project_attrs.get("name")
-        ac_entity = ac_project_attrs.get("entityName")
-        if is_artifact_registry_project(ac_project) and project == "model-registry":
-            wandb.termwarn(
-                "This model registry has been migrated and will be discontinued. "
-                f"Your request was redirected to the corresponding artifact `{ac_name}` in the new registry. "
-                f"Please update your paths to point to the migrated registry directly, `{ac_project}/{ac_name}`."
-            )
-            entity = ac_entity
-            project = ac_project
-        if not (attrs := acm_attrs.get("artifact")):
-            entity_project = f"{entity}/{project}"
-            raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
 
-
+        target_path = ArtifactPath(prefix=entity, project=project, name=name)
+        return cls._from_membership(acm_attrs, target=target_path, client=client)
 
     @classmethod
     def _from_name(
@@ -375,59 +353,71 @@ class Artifact:
         client: RetryingClient,
         enable_tracking: bool = False,
     ) -> Artifact:
-        if InternalApi()._server_supports(
+        if (api := InternalApi())._server_supports(
             pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
         ):
             return cls._membership_from_name(
-                entity=entity,
-                project=project,
-                name=name,
-                client=client,
+                entity=entity, project=project, name=name, client=client
             )
 
-
+        supports_enable_tracking_gql_var = api.server_project_type_introspection()
+        omit_vars = None if supports_enable_tracking_gql_var else {"enableTracking"}
+
+        gql_vars = {
             "entityName": entity,
             "projectName": project,
             "name": name,
+            "enableTracking": enable_tracking,
         }
-
-
-
-
-        InternalApi().server_project_type_introspection()
+        query = gql_compat(
+            ARTIFACT_BY_NAME_GQL,
+            omit_variables=omit_vars,
+            omit_fields=omit_artifact_fields(api=api),
         )
-        if server_supports_enabling_artifact_usage_tracking:
-            query_vars.append("$enableTracking: Boolean")
-            query_args.append("enableTracking: $enableTracking")
-            query_variable_values["enableTracking"] = enable_tracking
-
-        vars_str = ", ".join(query_vars)
-        args_str = ", ".join(query_args)
-
-        query = gql(
-            f"""
-            query ArtifactByName({vars_str}) {{
-                project(name: $projectName, entityName: $entityName) {{
-                    artifact({args_str}) {{
-                        ...ArtifactFragment
-                    }}
-                }}
-            }}
-            {_gql_artifact_fragment()}
-            """
-        )
-        response = client.execute(
-            query,
-            variable_values=query_variable_values,
-        )
-        project_attrs = response.get("project")
-        if not project_attrs:
-            raise ValueError(f"project '{project}' not found under entity '{entity}'")
-        attrs = project_attrs.get("artifact")
-        if not attrs:
-            raise ValueError(f"artifact '{name}' not found in '{entity}/{project}'")
 
-
+        data = client.execute(query, variable_values=gql_vars)
+        result = ArtifactByName.model_validate(data)
+
+        if not (proj_attrs := result.project):
+            raise ValueError(f"project {project!r} not found under entity {entity!r}")
+
+        if not (art_attrs := proj_attrs.artifact):
+            entity_project = f"{entity}/{project}"
+            raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
+
+        return cls._from_attrs(entity, project, name, art_attrs, client)
+
+    @classmethod
+    def _from_membership(
+        cls,
+        membership: MembershipWithArtifact,
+        target: ArtifactPath,
+        client: RetryingClient,
+    ) -> Artifact:
+        if not (
+            (collection := membership.artifact_collection)
+            and (name := collection.name)
+            and (proj := collection.project)
+        ):
+            raise ValueError("Missing artifact collection project in GraphQL response")
+
+        if is_artifact_registry_project(proj.name) and (
+            target.project == "model-registry"
+        ):
+            wandb.termwarn(
+                "This model registry has been migrated and will be discontinued. "
+                f"Your request was redirected to the corresponding artifact {name!r} in the new registry. "
+                f"Please update your paths to point to the migrated registry directly, '{proj.name}/{name}'."
+            )
+            new_entity, new_project = proj.entity_name, proj.name
+        else:
+            new_entity = cast(str, target.prefix)
+            new_project = cast(str, target.project)
+
+        if not (artifact := membership.artifact):
+            raise ValueError(f"Artifact {target.to_str()!r} not found in response")
+
+        return cls._from_attrs(new_entity, new_project, target.name, artifact, client)
 
     @classmethod
     def _from_attrs(
@@ -435,7 +425,7 @@ class Artifact:
         entity: str,
         project: str,
         name: str,
-        attrs: dict[str, Any],
+        attrs: dict[str, Any] | ArtifactFragment,
         client: RetryingClient,
         aliases: list[str] | None = None,
     ) -> Artifact:
@@ -445,7 +435,9 @@ class Artifact:
         artifact._entity = entity
         artifact._project = project
         artifact._name = name
-
+
+        validated_attrs = ArtifactFragment.model_validate(attrs)
+        artifact._assign_attrs(validated_attrs, aliases)
 
         artifact.finalize()
 
@@ -458,29 +450,24 @@ class Artifact:
     # doesn't make it clear if the artifact is a link or not and have to manually set it.
     def _assign_attrs(
         self,
-
+        art: ArtifactFragment,
         aliases: list[str] | None = None,
         is_link: bool | None = None,
     ) -> None:
         """Update this Artifact's attributes using the server response."""
-        self._id =
-
-        src_version = f"v{attrs['versionIndex']}"
-        src_collection = attrs["artifactSequence"]
-        src_project = src_collection["project"]
+        self._id = art.id
 
-
-
-        self._source_name = f"{src_collection['name']}:{src_version}"
-        self._source_version = src_version
+        src_collection = art.artifact_sequence
+        src_project = src_collection.project
 
-
-
-
-
+        self._source_entity = src_project.entity_name if src_project else ""
+        self._source_project = src_project.name if src_project else ""
+        self._source_name = f"{src_collection.name}:v{art.version_index}"
+        self._source_version = f"v{art.version_index}"
 
-
-
+        self._entity = self._entity or self._source_entity
+        self._project = self._project or self._source_project
+        self._name = self._name or self._source_name
 
         # TODO: Refactor artifact query to fetch artifact via membership instead
         # and get the collection type
@@ -488,33 +475,35 @@ class Artifact:
             self._is_link = (
                 self._entity != self._source_entity
                 or self._project != self._source_project
-                or self._name != self._source_name
+                or self._name.split(":")[0] != self._source_name.split(":")[0]
             )
         else:
             self._is_link = is_link
 
-        self._type =
-        self._description =
-
-        entity = self._entity
-        project = self._project
-        collection, *_ = self._name.split(":")
+        self._type = art.artifact_type.name
+        self._description = art.description
 
-        processed_aliases = []
         # The future of aliases is to move all alias fetches to the membership level
         # so we don't have to do the collection fetches below
         if aliases:
             processed_aliases = aliases
-
+        elif art.aliases:
+            entity = self._entity
+            project = self._project
+            collection = self._name.split(":")[0]
             processed_aliases = [
-
-                for
-                if
-
-
-
-
+                art_alias.alias
+                for art_alias in art.aliases
+                if (
+                    (coll := art_alias.artifact_collection)
+                    and (proj := coll.project)
+                    and proj.entity_name == entity
+                    and proj.name == project
+                    and coll.name == collection
+                )
             ]
+        else:
+            processed_aliases = []
 
         version_aliases = list(filter(alias_is_version_index, processed_aliases))
         other_aliases = list(filterfalse(alias_is_version_index, processed_aliases))
@@ -524,49 +513,42 @@ class Artifact:
                 version_aliases, too_short=TooFewItemsError, too_long=TooManyItemsError
             )
         except TooFewItemsError:
-            version =
+            version = f"v{art.version_index}"  # default to the source version
         except TooManyItemsError:
             msg = f"Expected at most one version alias, got {len(version_aliases)}: {version_aliases!r}"
             raise ValueError(msg) from None
 
         self._version = version
-
-        if ":" not in self._name:
-            self._name = f"{self._name}:{version}"
+        self._name = self._name if (":" in self._name) else f"{self._name}:{version}"
 
         self._aliases = other_aliases
-        self._saved_aliases = copy(
+        self._saved_aliases = copy(self._aliases)
 
-
-        self.
-        self._saved_tags = copy(tags)
+        self._tags = [tag.name for tag in (art.tags or [])]
+        self._saved_tags = copy(self._tags)
 
-
-        self._metadata = validate_metadata(
-            json.loads(metadata_str) if metadata_str else {}
-        )
+        self._metadata = validate_metadata(art.metadata)
 
         self._ttl_duration_seconds = validate_ttl_duration_seconds(
-
+            art.ttl_duration_seconds
         )
         self._ttl_is_inherited = (
-            True if (
+            True if (art.ttl_is_inherited is None) else art.ttl_is_inherited
        )
 
-        self._state = ArtifactState(
+        self._state = ArtifactState(art.state)
 
-
-
-
-
-
-        self._manifest = _DeferredArtifactManifest(manifest_url)
+        self._manifest = (
+            _DeferredArtifactManifest(manifest.file.direct_url)
+            if (manifest := art.current_manifest)
+            else None
+        )
 
-        self._commit_hash =
-        self._file_count =
-        self._created_at =
-        self._updated_at =
-        self._history_step =
+        self._commit_hash = art.commit_hash
+        self._file_count = art.file_count
+        self._created_at = art.created_at
+        self._updated_at = art.updated_at
+        self._history_step = art.history_step
 
     @ensure_logged
     def new_draft(self) -> Artifact:
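The version-resolution context above (`one(version_aliases, too_short=..., too_long=...)`) matches the signature of `more_itertools.one`, which raises the supplied exceptions for zero or multiple items; whether wandb uses that library or a vendored equivalent is not shown in this diff. A small usage sketch:

    from more_itertools import one

    class TooFewItemsError(Exception): ...
    class TooManyItemsError(Exception): ...

    version_aliases = ["v3"]
    try:
        # Returns the single item, or raises one of the given exception types.
        version = one(version_aliases, too_short=TooFewItemsError, too_long=TooManyItemsError)
    except TooFewItemsError:
        version = "v0"  # fall back to the source version, as the new code does
    except TooManyItemsError:
        raise ValueError("expected at most one version alias") from None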
@@ -1063,37 +1045,24 @@ class Artifact:
             return self._manifest
 
         if self._manifest is None:
-
-            ""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-            assert self._client is not None
-            response = self._client.execute(
-                query,
-                variable_values={
-                    "entityName": self._entity,
-                    "projectName": self._project,
-                    "name": self._name,
-                },
-            )
-            attrs = response["project"]["artifact"]
-            manifest_url = attrs["currentManifest"]["file"]["directUrl"]
-            self._manifest = self._load_manifest(manifest_url)
+            if self._client is None:
+                raise RuntimeError("Client not initialized for artifact queries")
+
+            query = gql(FETCH_ARTIFACT_MANIFEST_GQL)
+            gql_vars = {
+                "entityName": self.entity,
+                "projectName": self.project,
+                "name": self.name,
+            }
+            data = self._client.execute(query, variable_values=gql_vars)
+            result = FetchArtifactManifest.model_validate(data)
+            if not (
+                (project := result.project)
+                and (artifact := project.artifact)
+                and (manifest := artifact.current_manifest)
+            ):
+                raise ValueError("Failed to fetch artifact manifest")
+            self._manifest = self._load_manifest(manifest.file.direct_url)
 
         return self._manifest
 
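Both here and in `_assign_attrs`, the manifest is represented only by its `directUrl` until something actually needs it (`_DeferredArtifactManifest` holding the URL, then `_load_manifest(url)`). A stripped-down sketch of that lazy-load idea with simplified names; the real `_load_manifest` is visible at the end of this diff:

    from dataclasses import dataclass

    import requests

    @dataclass
    class DeferredManifest:
        """Hold only the manifest URL; download and parse it on first use."""

        url: str

        def load(self) -> dict:
            with requests.get(self.url, timeout=60) as response:
                response.raise_for_status()
                return response.json()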
@@ -1112,11 +1081,7 @@ class Artifact:
 
         Includes any references tracked by this artifact.
         """
-
-        for entry in self.manifest.entries.values():
-            if entry.size is not None:
-                total_size += entry.size
-        return total_size
+        return sum(entry.size for entry in self.manifest.entries.values() if entry.size)
 
     @property
     @ensure_logged
@@ -1184,7 +1149,7 @@ class Artifact:
         Returns:
             Boolean. `False` if artifact is saved. `True` if artifact is not saved.
         """
-        return self._state
+        return self._state is ArtifactState.PENDING
 
     def _is_draft_save_started(self) -> bool:
         return self._save_handle is not None
@@ -1205,7 +1170,7 @@ class Artifact:
         settings: A settings object to use when initializing an automatic run. Most
             commonly used in testing harness.
         """
-        if self._state
+        if self._state is not ArtifactState.PENDING:
             return self._update()
 
         if self._incremental:
@@ -1266,31 +1231,20 @@ class Artifact:
         return self
 
     def _populate_after_save(self, artifact_id: str) -> None:
-
-        query ArtifactByIDShort($id: ID!) {
-            artifact(id: $id) {
-                ...ArtifactFragment
-            }
-        }
-        """ + _gql_artifact_fragment()
+        assert self._client is not None
 
-        query =
+        query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
 
-
-
-            query,
-            variable_values={"id": artifact_id},
-        )
+        data = self._client.execute(query, variable_values={"id": artifact_id})
+        result = ArtifactByID.model_validate(data)
 
-
-            attrs = response["artifact"]
-        except LookupError:
+        if not (artifact := result.artifact):
             raise ValueError(f"Unable to fetch artifact with id: {artifact_id!r}")
-
-
-
-
-
+
+        # _populate_after_save is only called on source artifacts, not linked artifacts
+        # We have to manually set is_link because we aren't fetching the collection the artifact.
+        # That requires greater refactoring for commitArtifact to return the artifact collection type.
+        self._assign_attrs(artifact, is_link=False)
 
     @normalize_exceptions
     def _update(self) -> None:
@@ -1375,7 +1329,7 @@ class Artifact:
             for alias in self.aliases
         ]
 
-        omit_fields = omit_artifact_fields(
+        omit_fields = omit_artifact_fields()
         omit_variables = set()
 
         if {"ttlIsInherited", "ttlDurationSeconds"} & omit_fields:
@@ -1399,7 +1353,9 @@ class Artifact:
             omit_variables |= {"tagsToAdd", "tagsToDelete"}
 
         mutation = gql_compat(
-            UPDATE_ARTIFACT_GQL,
+            UPDATE_ARTIFACT_GQL,
+            omit_variables=omit_variables,
+            omit_fields=omit_fields,
         )
 
         gql_vars = {
@@ -1417,7 +1373,7 @@ class Artifact:
         result = UpdateArtifact.model_validate(data).update_artifact
         if not (result and (artifact := result.artifact)):
             raise ValueError("Unable to parse updateArtifact response")
-        self._assign_attrs(artifact
+        self._assign_attrs(artifact)
 
         self._ttl_changed = False  # Reset after updating artifact
 
@@ -1481,7 +1437,7 @@ class Artifact:
         self._tmp_dir = tempfile.TemporaryDirectory()
         path = os.path.join(self._tmp_dir.name, name.lstrip("/"))
 
-
+        Path(path).parent.mkdir(parents=True, exist_ok=True)
         try:
             with fsync_open(path, mode, encoding) as f:
                 yield f
@@ -1588,30 +1544,27 @@ class Artifact:
             ValueError: Policy must be "mutable" or "immutable"
         """
         if not os.path.isdir(local_path):
-            raise ValueError(f"Path is not a directory: {local_path}")
+            raise ValueError(f"Path is not a directory: {local_path!r}")
 
         termlog(
-            "Adding directory to artifact ({})... "
-                os.path.join(".", os.path.normpath(local_path))
-            ),
+            f"Adding directory to artifact ({Path('.', local_path)})... ",
             newline=False,
         )
-        start_time = time.
+        start_time = time.monotonic()
 
-        paths =
+        paths: deque[tuple[str, str]] = deque()
+        logical_root = name or ""  # shared prefix, if any, for logical paths
         for dirpath, _, filenames in os.walk(local_path, followlinks=True):
             for fname in filenames:
                 physical_path = os.path.join(dirpath, fname)
                 logical_path = os.path.relpath(physical_path, start=local_path)
-
-                logical_path = os.path.join(name, logical_path)
+                logical_path = os.path.join(logical_root, logical_path)
                 paths.append((logical_path, physical_path))
 
-        def add_manifest_file(
-            logical_path, physical_path = log_phy_path
+        def add_manifest_file(logical_pth: str, physical_pth: str) -> None:
             self._add_local_file(
-                name=
-                path=
+                name=logical_pth,
+                path=physical_pth,
                 skip_cache=skip_cache,
                 policy=policy,
                 overwrite=merge,
@@ -1619,11 +1572,11 @@ class Artifact:
 
         num_threads = 8
         pool = multiprocessing.dummy.Pool(num_threads)
-        pool.
+        pool.starmap(add_manifest_file, paths)
         pool.close()
         pool.join()
 
-        termlog("Done. %.1fs" % (time.
+        termlog("Done. %.1fs" % (time.monotonic() - start_time), prefix=False)
 
     @ensure_not_finalized
     def add_reference(
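`add_dir` now collects `(logical_path, physical_path)` tuples in a deque and hands them to `Pool.starmap`, which unpacks each tuple into the worker's two parameters (the truncated removed lines suggest the old helper took a single packed tuple). `multiprocessing.dummy` is the thread-backed twin of `multiprocessing`, so no pickling is involved. A minimal sketch:

    from multiprocessing.dummy import Pool  # thread pool with the Pool API

    def add_file(logical: str, physical: str) -> None:
        print(f"{physical} -> {logical}")  # stand-in for self._add_local_file(...)

    pairs = [("data/a.txt", "/tmp/src/a.txt"), ("data/b.txt", "/tmp/src/b.txt")]
    with Pool(8) as pool:
        pool.starmap(add_file, pairs)  # each tuple becomes add_file(*pair)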
@@ -1699,7 +1652,7 @@ class Artifact:
                 "References must be URIs. To reference a local file, use file://"
             )
 
-        manifest_entries = self.
+        manifest_entries = self.manifest.storage_policy.store_reference(
             self,
             URIStr(uri_str),
             name=name,
@@ -1861,10 +1814,8 @@ class Artifact:
             return
 
         path = str(PurePosixPath(item))
-        entry
-
-            self.manifest.remove_entry(entry)
-            return
+        if entry := self.manifest.get_entry_by_path(path):
+            return self.manifest.remove_entry(entry)
 
         entries = self.manifest.get_entries_in_directory(path)
         if not entries:
@@ -1922,8 +1873,7 @@ class Artifact:
 
         # If the entry is a reference from another artifact, then get it directly from
         # that artifact.
-        referenced_id
-        if referenced_id:
+        if referenced_id := entry._referenced_artifact_id():
             assert self._client is not None
             artifact = self._from_id(referenced_id, client=self._client)
             assert artifact is not None
@@ -1942,10 +1892,9 @@ class Artifact:
         item_path = item.download()
 
         # Load the object from the JSON blob
-        result = None
-        json_obj = {}
         with open(item_path) as file:
             json_obj = json.load(file)
+
         result = wb_class.from_json(json_obj, self)
         result._set_artifact_source(self, name)
         return result
@@ -1959,10 +1908,9 @@ class Artifact:
         Returns:
             The artifact relative name.
         """
-        entry
-
-
-        return entry.path
+        if entry := self._added_local_paths.get(local_path):
+            return entry.path
+        return None
 
     def _get_obj_entry(
         self, name: str
@@ -1979,8 +1927,7 @@ class Artifact:
         """
         for wb_class in WBValue.type_mapping().values():
             wandb_file_name = wb_class.with_suffix(name)
-            entry
-            if entry is not None:
+            if entry := self.manifest.entries.get(wandb_file_name):
                 return entry, wb_class
         return None, None
 
@@ -2021,7 +1968,7 @@ class Artifact:
         Raises:
             ArtifactNotLoggedError: If the artifact is not logged.
         """
-        root = FilePathStr(
+        root = FilePathStr(root or self._default_root())
         self._add_download_root(root)
 
         # TODO: download artifacts using core when implemented
@@ -2110,14 +2057,14 @@ class Artifact:
         multipart: bool | None = None,
     ) -> FilePathStr:
         nfiles = len(self.manifest.entries)
-
-
-        if nfiles > 5000 or
-            log = True
+        size_mb = self.size / _MB
+
+        if log := (nfiles > 5000 or size_mb > 50):
             termlog(
-                f"Downloading large artifact {self.name}, {
+                f"Downloading large artifact {self.name!r}, {size_mb:.2f}MB. {nfiles!r} files...",
             )
-            start_time =
+            start_time = time.monotonic()
+
         download_logger = ArtifactDownloadLogger(nfiles=nfiles)
 
         def _download_entry(
@@ -2156,44 +2103,62 @@ class Artifact:
                 cookies=_thread_local_api_settings.cookies,
                 headers=_thread_local_api_settings.headers,
             )
+
+            batch_size = env.get_artifact_fetch_file_url_batch_size()
+
             active_futures = set()
-
-
-
-
-
-
-
-
-
+            cursor, has_more = None, True
+            while has_more:
+                files_page = self._fetch_file_urls(cursor=cursor, per_page=batch_size)
+
+                has_more = files_page.page_info.has_next_page
+                cursor = files_page.page_info.end_cursor
+
+                # `File` nodes are formally nullable, so filter them out just in case.
+                file_nodes = (e.node for e in files_page.edges if e.node)
+                for node in file_nodes:
+                    entry = self.get_entry(node.name)
                     # TODO: uncomment once artifact downloads are supported in core
                     # if require_core and entry.ref is None:
                     #     # Handled by core
                     #     continue
-                    entry._download_url =
+                    entry._download_url = node.direct_url
                     if (not path_prefix) or entry.path.startswith(str(path_prefix)):
                         active_futures.add(executor.submit(download_entry, entry))
+
                 # Wait for download threads to catch up.
-
-
+                #
+                # Extra context and observations (tonyyli):
+                # - Even though the ThreadPoolExecutor limits the number of
+                #   concurrently-executed tasks, its internal task queue is unbounded.
+                #   The code below seems intended to ensure that at most `batch_size`
+                #   "backlogged" futures are held in memory at any given time. This seems like
+                #   a reasonable safeguard against unbounded memory consumption.
+                #
+                # - We should probably use a builtin (bounded) Queue or Semaphore here instead.
+                #   Consider this for a future change, or (depending on risk and risk tolerance)
+                #   managing this logic via asyncio instead, if viable.
+                if len(active_futures) > batch_size:
                    for future in concurrent.futures.as_completed(active_futures):
                         future.result()  # check for errors
                         active_futures.remove(future)
-                        if len(active_futures) <=
+                        if len(active_futures) <= batch_size:
                             break
+
             # Check for errors.
             for future in concurrent.futures.as_completed(active_futures):
                 future.result()
 
         if log:
-
-
-
-
-
-
+            # If you're wondering if we can display a `timedelta`, note that it
+            # doesn't really support custom string format specifiers (compared to
+            # e.g. `datetime` objs). To truncate the number of decimal places for
+            # the seconds part, we manually convert/format each part below.
+            dt_secs = abs(time.monotonic() - start_time)
+            hrs, mins = divmod(dt_secs, 3600)
+            mins, secs = divmod(mins, 60)
             termlog(
-                f"Done. {
+                f"Done. {int(hrs):02d}:{int(mins):02d}:{secs:04.1f} ({size_mb / dt_secs:.1f}MB/s)",
                 prefix=False,
             )
         return FilePathStr(root)
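The loop above caps the backlog of pending futures by draining the set whenever it exceeds `batch_size`; the in-diff comment itself suggests a bounded queue or semaphore as the cleaner primitive. A sketch of the semaphore variant the comment alludes to (not the actual wandb implementation):

    import concurrent.futures
    import threading

    MAX_BACKLOG = 5000
    gate = threading.BoundedSemaphore(MAX_BACKLOG)
    urls = [f"https://example.com/file/{i}" for i in range(10)]  # placeholder inputs

    def download(url: str) -> None:
        pass  # stand-in for downloading one manifest entry

    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        futures = []
        for url in urls:
            gate.acquire()  # blocks once MAX_BACKLOG submissions are unfinished
            future = executor.submit(download, url)
            future.add_done_callback(lambda f: gate.release())
            futures.append(future)
        for future in concurrent.futures.as_completed(futures):
            future.result()  # surface any worker exceptions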
@@ -2202,79 +2167,44 @@ class Artifact:
         retry_timedelta=timedelta(minutes=3),
         retryable_exceptions=(requests.RequestException),
     )
-    def _fetch_file_urls(
+    def _fetch_file_urls(
+        self, cursor: str | None, per_page: int = 5000
+    ) -> FileUrlsFragment:
+        if self._client is None:
+            raise RuntimeError("Client not initialized")
+
         if InternalApi()._server_supports(
             pb.ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
         ):
-            query = gql(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                }
-            }
-        }
-        """
-            )
-            assert self._client is not None
-            response = self._client.execute(
-                query,
-                variable_values={
-                    "entityName": self.entity,
-                    "projectName": self.project,
-                    "artifactName": self.name.split(":")[0],
-                    "artifactVersionIndex": self.version,
-                    "cursor": cursor,
-                    "perPage": per_page,
-                },
-                timeout=60,
-            )
-            return response["project"]["artifactCollection"]["artifactMembership"][
-                "files"
-            ]
+            query = gql(ARTIFACT_COLLECTION_MEMBERSHIP_FILE_URLS_GQL)
+            gql_vars = {
+                "entityName": self.entity,
+                "projectName": self.project,
+                "artifactName": self.name.split(":")[0],
+                "artifactVersionIndex": self.version,
+                "cursor": cursor,
+                "perPage": per_page,
+            }
+            data = self._client.execute(query, variable_values=gql_vars, timeout=60)
+            result = ArtifactCollectionMembershipFileUrls.model_validate(data)
+
+            if not (
+                (project := result.project)
+                and (collection := project.artifact_collection)
+                and (membership := collection.artifact_membership)
+                and (files := membership.files)
+            ):
+                raise ValueError(f"Unable to fetch files for artifact: {self.name!r}")
+            return files
         else:
-            query = gql(
-
-
-
-
-
-
-
-                }
-                edges {
-                    node {
-                        name
-                        directUrl
-                    }
-                }
-            }
-        }
-        """
-            )
-            assert self._client is not None
-            response = self._client.execute(
-                query,
-                variable_values={"id": self.id, "cursor": cursor, "perPage": per_page},
-                timeout=60,
-            )
-            return response["artifact"]["files"]
+            query = gql(ARTIFACT_FILE_URLS_GQL)
+            gql_vars = {"id": self.id, "cursor": cursor, "perPage": per_page}
+            data = self._client.execute(query, variable_values=gql_vars, timeout=60)
+            result = ArtifactFileUrls.model_validate(data)
+
+            if not ((artifact := result.artifact) and (files := artifact.files)):
+                raise ValueError(f"Unable to fetch files for artifact: {self.name!r}")
+            return files
 
     @ensure_logged
     def checkout(self, root: str | None = None) -> str:
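`_fetch_file_urls` now returns a single typed page (`FileUrlsFragment`), and the caller in the download loop drives the cursor via `has_next_page`/`end_cursor`. The general shape of consuming such a paged endpoint, with stand-in types:

    from __future__ import annotations

    from dataclasses import dataclass
    from typing import Iterator

    @dataclass
    class Page:
        items: list[str]
        end_cursor: str | None
        has_next_page: bool

    DATA = [f"file-{i}" for i in range(12)]  # fake backing data

    def fetch_page(cursor: str | None, per_page: int) -> Page:
        # Stand-in for the per-page GraphQL call.
        start = int(cursor) if cursor else 0
        end = min(start + per_page, len(DATA))
        return Page(DATA[start:end], str(end), end < len(DATA))

    def iter_all(per_page: int = 5) -> Iterator[str]:
        cursor, has_more = None, True
        while has_more:
            page = fetch_page(cursor=cursor, per_page=per_page)
            has_more, cursor = page.has_next_page, page.end_cursor
            yield from page.items

    assert list(iter_all()) == DATA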
@@ -2395,8 +2325,7 @@ class Artifact:
         # In case we're on a system where the artifact dir has a name corresponding to
         # an unexpected filesystem, we'll check for alternate roots. If one exists we'll
         # use that, otherwise we'll fall back to the system-preferred path.
-
-        return FilePathStr(str(path))
+        return FilePathStr(check_exists(root) or system_preferred_path(root))
 
     def _add_download_root(self, dir_path: str) -> None:
         self._download_roots.add(os.path.abspath(dir_path))
@@ -2440,33 +2369,16 @@ class Artifact:
 
     @normalize_exceptions
     def _delete(self, delete_aliases: bool = False) -> None:
-
-        ""
-
-
-
-
-
-                artifact {
-                    id
-                }
-            }
-        }
-        """
-        )
-        assert self._client is not None
-        self._client.execute(
-            mutation,
-            variable_values={
-                "artifactID": self.id,
-                "deleteAliases": delete_aliases,
-            },
-        )
+        if self._client is None:
+            raise RuntimeError("Client not initialized for artifact mutations")
+
+        mutation = gql(DELETE_ARTIFACT_GQL)
+        gql_vars = {"artifactID": self.id, "deleteAliases": delete_aliases}
+
+        self._client.execute(mutation, variable_values=gql_vars)
 
     @normalize_exceptions
-    def link(
-        self, target_path: str, aliases: list[str] | None = None
-    ) -> Artifact | None:
+    def link(self, target_path: str, aliases: list[str] | None = None) -> Artifact:
         """Link this artifact to a portfolio (a promoted collection of artifacts).
 
         Args:
@@ -2485,7 +2397,7 @@ class Artifact:
             ArtifactNotLoggedError: If the artifact is not logged.
 
         Returns:
-            The linked artifact
+            The linked artifact.
         """
         from wandb import Api
 
@@ -2505,10 +2417,11 @@ class Artifact:
         # Wait until the artifact is committed before trying to link it.
         self.wait()
 
-        api =
+        api = InternalApi()
+        settings = api.settings()
 
         target = ArtifactPath.from_str(target_path).with_defaults(
-            project=
+            project=settings.get("project") or "uncategorized",
         )
 
         # Parse the entity (first part of the path) appropriately,
@@ -2516,11 +2429,8 @@ class Artifact:
         if target.project and is_artifact_registry_project(target.project):
             # In a Registry linking, the entity is used to fetch the organization of the artifact
             # therefore the source artifact's entity is passed to the backend
-
-
-            target.prefix = InternalApi()._resolve_org_entity_name(
-                self.source_entity, organization
-            )
+            org = target.prefix or settings.get("organization") or ""
+            target.prefix = api._resolve_org_entity_name(self.source_entity, org)
         else:
             target = target.with_defaults(prefix=self.source_entity)
 
@@ -2537,21 +2447,35 @@ class Artifact:
             aliases=alias_inputs,
         )
         gql_vars = {"input": gql_input.model_dump(exclude_none=True)}
-        gql_op = gql(LINK_ARTIFACT_GQL)
-        data = self._client.execute(gql_op, variable_values=gql_vars)
 
+        # Newer server versions can return `artifactMembership` directly in the response,
+        # avoiding the need to re-fetch the linked artifact at the end.
+        if api._server_supports(
+            pb.ServerFeature.ARTIFACT_MEMBERSHIP_IN_LINK_ARTIFACT_RESPONSE
+        ):
+            omit_fragments = set()
+        else:
+            # FIXME: Make `gql_compat` omit nested fragment definitions recursively (but safely)
+            omit_fragments = {
+                "MembershipWithArtifact",
+                "ArtifactFragment",
+                "ArtifactFragmentWithoutAliases",
+            }
+
+        gql_op = gql_compat(LINK_ARTIFACT_GQL, omit_fragments=omit_fragments)
+        data = self._client.execute(gql_op, variable_values=gql_vars)
         result = LinkArtifact.model_validate(data).link_artifact
+
+        # Newer server versions can return artifactMembership directly in the response
+        if result and (membership := result.artifact_membership):
+            return self._from_membership(membership, target=target, client=self._client)
+
+        # Fallback to old behavior, which requires re-fetching the linked artifact to return it
         if not (result and (version_idx := result.version_index) is not None):
             raise ValueError("Unable to parse linked artifact version from response")
 
-
-
-
-        try:
-            return api._artifact(linked_path)
-        except Exception as e:
-            wandb.termerror(f"Error fetching link artifact after linking: {e}")
-            return None
+        link_name = f"{target.to_str()}:v{version_idx}"
+        return Api(overrides={"entity": self.source_entity})._artifact(link_name)
 
     @ensure_logged
     def unlink(self) -> None:
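The same compatibility idiom appears throughout this file: probe the server once (`_server_supports(...)`), then have `gql_compat` strip the variables, fields, or fragments an older server would reject. `gql_compat` is wandb-internal; the sketch below only illustrates the branching, with hypothetical query texts:

    # Hypothetical queries: FULL returns the membership inline, LEGACY does not.
    FULL_LINK_GQL = """
    mutation Link($input: LinkArtifactInput!) {
      linkArtifact(input: $input) { versionIndex artifactMembership { id } }
    }
    """
    LEGACY_LINK_GQL = """
    mutation Link($input: LinkArtifactInput!) {
      linkArtifact(input: $input) { versionIndex }
    }
    """

    def pick_link_query(server_supports_membership: bool) -> str:
        # Choose the richer query only when the server advertises support.
        return FULL_LINK_GQL if server_supports_membership else LEGACY_LINK_GQL

Callers then branch on what actually came back, as `link()` does above: use the membership payload when present, otherwise re-fetch the linked artifact by name.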
@@ -2573,28 +2497,14 @@ class Artifact:
 
     @normalize_exceptions
     def _unlink(self) -> None:
-
-        ""
-
-
-
-
-            artifactID
-            success
-            clientMutationId
-          }
-        }
-        """
-        )
-        assert self._client is not None
+        if self._client is None:
+            raise RuntimeError("Client not initialized for artifact mutations")
+
+        mutation = gql(UNLINK_ARTIFACT_GQL)
+        gql_vars = {"artifactID": self.id, "artifactPortfolioID": self.collection.id}
+
         try:
-            self._client.execute(
-                mutation,
-                variable_values={
-                    "artifactID": self.id,
-                    "artifactPortfolioID": self.collection.id,
-                },
-            )
+            self._client.execute(mutation, variable_values=gql_vars)
         except CommError as e:
             raise CommError(
                 f"You do not have permission to unlink the artifact {self.qualified_name}"
@@ -2610,41 +2520,26 @@ class Artifact:
         Raises:
             ArtifactNotLoggedError: If the artifact is not logged.
         """
-
-        ""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-        assert self._client is not None
-        response = self._client.execute(
-            query,
-            variable_values={"id": self.id},
-        )
-        return [
-            Run(
-                self._client,
-                edge["node"]["project"]["entityName"],
-                edge["node"]["project"]["name"],
-                edge["node"]["name"],
-            )
-            for edge in response.get("artifact", {}).get("usedBy", {}).get("edges", [])
-        ]
+        if self._client is None:
+            raise RuntimeError("Client not initialized for artifact queries")
+
+        query = gql(ARTIFACT_USED_BY_GQL)
+        gql_vars = {"id": self.id}
+        data = self._client.execute(query, variable_values=gql_vars)
+        result = ArtifactUsedBy.model_validate(data)
+
+        if (
+            (artifact := result.artifact)
+            and (used_by := artifact.used_by)
+            and (edges := used_by.edges)
+        ):
+            run_nodes = (e.node for e in edges)
+            return [
+                Run(self._client, proj.entity_name, proj.name, run.name)
+                for run in run_nodes
+                if (proj := run.project)
+            ]
+        return []
 
     @ensure_logged
     def logged_by(self) -> Run | None:
@@ -2656,39 +2551,22 @@ class Artifact:
         Raises:
             ArtifactNotLoggedError: If the artifact is not logged.
         """
-
-        ""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        }
-        """
-        )
-        assert self._client is not None
-        response = self._client.execute(
-            query,
-            variable_values={"id": self.id},
-        )
-        creator = response.get("artifact", {}).get("createdBy", {})
-        if creator.get("name") is None:
-            return None
-        return Run(
-            self._client,
-            creator["project"]["entityName"],
-            creator["project"]["name"],
-            creator["name"],
-        )
+        if self._client is None:
+            raise RuntimeError("Client not initialized for artifact queries")
+
+        query = gql(ARTIFACT_CREATED_BY_GQL)
+        gql_vars = {"id": self.id}
+        data = self._client.execute(query, variable_values=gql_vars)
+        result = ArtifactCreatedBy.model_validate(data)
+
+        if (
+            (artifact := result.artifact)
+            and (creator := artifact.created_by)
+            and (name := creator.name)
+            and (project := creator.project)
+        ):
+            return Run(self._client, project.entity_name, project.name, name)
+        return None
 
     @ensure_logged
     def json_encode(self) -> dict[str, Any]:
@@ -2704,37 +2582,22 @@ class Artifact:
         entity_name: str, project_name: str, name: str, client: RetryingClient
     ) -> str | None:
         """Returns the expected type for a given artifact name and project."""
-        query = gql(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-        if ":" not in name:
-            name += ":latest"
-        response = client.execute(
-            query,
-            variable_values={
-                "entityName": entity_name,
-                "projectName": project_name,
-                "name": name,
-            },
-        )
-        return (
-            ((response.get("project") or {}).get("artifact") or {}).get("artifactType")
-            or {}
-        ).get("name")
+        query = gql(ARTIFACT_TYPE_GQL)
+        gql_vars = {
+            "entityName": entity_name,
+            "projectName": project_name,
+            "name": name if (":" in name) else f"{name}:latest",
+        }
+        data = client.execute(query, variable_values=gql_vars)
+        result = ArtifactType.model_validate(data)
+
+        if (
+            (project := result.project)
+            and (artifact := project.artifact)
+            and (artifact_type := artifact.artifact_type)
+        ):
+            return artifact_type.name
+        return None
 
     def _load_manifest(self, url: str) -> ArtifactManifest:
         with requests.get(url) as response: