wandb 0.21.1__py3-none-musllinux_1_2_aarch64.whl → 0.21.3__py3-none-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/api.py +1 -2
  4. wandb/apis/public/artifacts.py +3 -5
  5. wandb/apis/public/registries/_utils.py +14 -16
  6. wandb/apis/public/registries/registries_search.py +176 -289
  7. wandb/apis/public/reports.py +13 -10
  8. wandb/automations/_generated/delete_automation.py +1 -3
  9. wandb/automations/_generated/enums.py +13 -11
  10. wandb/bin/gpu_stats +0 -0
  11. wandb/bin/wandb-core +0 -0
  12. wandb/cli/cli.py +47 -2
  13. wandb/integration/metaflow/data_pandas.py +2 -2
  14. wandb/integration/metaflow/data_pytorch.py +75 -0
  15. wandb/integration/metaflow/data_sklearn.py +76 -0
  16. wandb/integration/metaflow/metaflow.py +16 -87
  17. wandb/integration/weave/__init__.py +6 -0
  18. wandb/integration/weave/interface.py +49 -0
  19. wandb/integration/weave/weave.py +63 -0
  20. wandb/proto/v3/wandb_internal_pb2.py +3 -2
  21. wandb/proto/v4/wandb_internal_pb2.py +2 -2
  22. wandb/proto/v5/wandb_internal_pb2.py +2 -2
  23. wandb/proto/v6/wandb_internal_pb2.py +2 -2
  24. wandb/sdk/artifacts/_factories.py +17 -0
  25. wandb/sdk/artifacts/_generated/__init__.py +221 -13
  26. wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
  27. wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
  28. wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
  29. wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
  30. wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
  31. wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
  32. wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
  33. wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
  34. wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
  35. wandb/sdk/artifacts/_generated/enums.py +5 -0
  36. wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
  37. wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
  38. wandb/sdk/artifacts/_generated/fragments.py +279 -41
  39. wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
  40. wandb/sdk/artifacts/_generated/operations.py +654 -51
  41. wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
  42. wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
  43. wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
  44. wandb/sdk/artifacts/_graphql_fragments.py +3 -86
  45. wandb/sdk/artifacts/_validators.py +6 -4
  46. wandb/sdk/artifacts/artifact.py +410 -547
  47. wandb/sdk/artifacts/artifact_file_cache.py +11 -7
  48. wandb/sdk/artifacts/artifact_manifest.py +10 -9
  49. wandb/sdk/artifacts/artifact_manifest_entry.py +15 -18
  50. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
  51. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
  52. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
  53. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  54. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
  55. wandb/sdk/data_types/video.py +2 -2
  56. wandb/sdk/interface/interface_queue.py +1 -4
  57. wandb/sdk/interface/interface_shared.py +26 -37
  58. wandb/sdk/interface/interface_sock.py +24 -14
  59. wandb/sdk/internal/settings_static.py +2 -3
  60. wandb/sdk/launch/create_job.py +12 -1
  61. wandb/sdk/launch/inputs/internal.py +25 -24
  62. wandb/sdk/launch/inputs/schema.py +31 -1
  63. wandb/sdk/launch/runner/kubernetes_runner.py +24 -29
  64. wandb/sdk/lib/asyncio_compat.py +16 -16
  65. wandb/sdk/lib/asyncio_manager.py +252 -0
  66. wandb/sdk/lib/hashutil.py +13 -4
  67. wandb/sdk/lib/paths.py +23 -21
  68. wandb/sdk/lib/printer.py +2 -2
  69. wandb/sdk/lib/printer_asyncio.py +3 -1
  70. wandb/sdk/lib/retry.py +185 -78
  71. wandb/sdk/lib/service/service_client.py +106 -0
  72. wandb/sdk/lib/service/service_connection.py +20 -26
  73. wandb/sdk/lib/service/service_token.py +30 -13
  74. wandb/sdk/mailbox/mailbox.py +13 -5
  75. wandb/sdk/mailbox/mailbox_handle.py +22 -13
  76. wandb/sdk/mailbox/response_handle.py +42 -106
  77. wandb/sdk/mailbox/wait_with_progress.py +7 -42
  78. wandb/sdk/wandb_init.py +11 -25
  79. wandb/sdk/wandb_login.py +1 -1
  80. wandb/sdk/wandb_run.py +92 -56
  81. wandb/sdk/wandb_settings.py +45 -32
  82. wandb/sdk/wandb_setup.py +176 -96
  83. wandb/util.py +1 -1
  84. {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/METADATA +2 -2
  85. {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/RECORD +88 -72
  86. wandb/sdk/interface/interface_relay.py +0 -38
  87. wandb/sdk/interface/router.py +0 -89
  88. wandb/sdk/interface/router_queue.py +0 -43
  89. wandb/sdk/interface/router_relay.py +0 -50
  90. wandb/sdk/interface/router_sock.py +0 -32
  91. wandb/sdk/lib/sock_client.py +0 -232
  92. {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/WHEEL +0 -0
  93. {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/entry_points.txt +0 -0
  94. {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/licenses/LICENSE +0 -0
@@ -17,11 +17,22 @@ import time
17
17
  from collections import deque
18
18
  from copy import copy
19
19
  from dataclasses import dataclass
20
- from datetime import datetime, timedelta
20
+ from datetime import timedelta
21
21
  from functools import partial
22
22
  from itertools import filterfalse
23
- from pathlib import PurePosixPath
24
- from typing import IO, TYPE_CHECKING, Any, Iterator, Literal, Sequence, Type, final
23
+ from pathlib import Path, PurePosixPath
24
+ from typing import (
25
+ IO,
26
+ TYPE_CHECKING,
27
+ Any,
28
+ Final,
29
+ Iterator,
30
+ Literal,
31
+ Sequence,
32
+ Type,
33
+ cast,
34
+ final,
35
+ )
25
36
  from urllib.parse import quote, urljoin, urlparse
26
37
 
27
38
  import requests
@@ -43,8 +54,9 @@ from wandb.sdk.data_types._dtypes import Type as WBType
43
54
  from wandb.sdk.data_types._dtypes import TypeRegistry
44
55
  from wandb.sdk.internal.internal_api import Api as InternalApi
45
56
  from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
46
- from wandb.sdk.lib import filesystem, retry, runid, telemetry
57
+ from wandb.sdk.lib import retry, runid, telemetry
47
58
  from wandb.sdk.lib.deprecate import deprecate
59
+ from wandb.sdk.lib.filesystem import check_exists, system_preferred_path
48
60
  from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64
49
61
  from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath, URIStr
50
62
  from wandb.sdk.lib.runid import generate_id
@@ -58,21 +70,45 @@ from wandb.util import (
58
70
  vendor_setup,
59
71
  )
60
72
 
73
+ from ._factories import make_storage_policy
61
74
  from ._generated import (
62
75
  ADD_ALIASES_GQL,
76
+ ARTIFACT_BY_ID_GQL,
77
+ ARTIFACT_BY_NAME_GQL,
78
+ ARTIFACT_COLLECTION_MEMBERSHIP_FILE_URLS_GQL,
79
+ ARTIFACT_CREATED_BY_GQL,
80
+ ARTIFACT_FILE_URLS_GQL,
81
+ ARTIFACT_TYPE_GQL,
82
+ ARTIFACT_USED_BY_GQL,
83
+ ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
63
84
  DELETE_ALIASES_GQL,
85
+ DELETE_ARTIFACT_GQL,
86
+ FETCH_ARTIFACT_MANIFEST_GQL,
64
87
  FETCH_LINKED_ARTIFACTS_GQL,
65
88
  LINK_ARTIFACT_GQL,
89
+ UNLINK_ARTIFACT_GQL,
66
90
  UPDATE_ARTIFACT_GQL,
67
91
  ArtifactAliasInput,
92
+ ArtifactByID,
93
+ ArtifactByName,
68
94
  ArtifactCollectionAliasInput,
95
+ ArtifactCollectionMembershipFileUrls,
96
+ ArtifactCreatedBy,
97
+ ArtifactFileUrls,
98
+ ArtifactFragment,
99
+ ArtifactType,
100
+ ArtifactUsedBy,
101
+ ArtifactViaMembershipByName,
102
+ FetchArtifactManifest,
69
103
  FetchLinkedArtifacts,
104
+ FileUrlsFragment,
70
105
  LinkArtifact,
71
106
  LinkArtifactInput,
107
+ MembershipWithArtifact,
72
108
  TagInput,
73
109
  UpdateArtifact,
74
110
  )
75
- from ._graphql_fragments import _gql_artifact_fragment, omit_artifact_fields
111
+ from ._graphql_fragments import omit_artifact_fields
76
112
  from ._validators import (
77
113
  LINKED_ARTIFACT_COLLECTION_TYPE,
78
114
  ArtifactPath,
@@ -103,9 +139,6 @@ from .exceptions import (
103
139
  )
104
140
  from .staging import get_staging_dir
105
141
  from .storage_handlers.gcs_handler import _GCSIsADirectoryError
106
- from .storage_layout import StorageLayout
107
- from .storage_policies import WANDB_STORAGE_POLICY
108
- from .storage_policy import StoragePolicy
109
142
 
110
143
  reset_path = vendor_setup()
111
144
 
@@ -119,6 +152,9 @@ if TYPE_CHECKING:
119
152
  logger = logging.getLogger(__name__)
120
153
 
121
154
 
155
+ _MB: Final[int] = 1024 * 1024
156
+
157
+
122
158
  @final
123
159
  @dataclass
124
160
  class _DeferredArtifactManifest:
@@ -188,11 +224,6 @@ class Artifact:
188
224
  # Internal.
189
225
  self._client: RetryingClient | None = None
190
226
 
191
- storage_policy_cls = StoragePolicy.lookup_by_name(WANDB_STORAGE_POLICY)
192
- layout = StorageLayout.V1 if env.get_use_v1_artifacts() else StorageLayout.V2
193
- policy_config = {"storageLayout": layout}
194
- self._storage_policy = storage_policy_cls.from_config(config=policy_config)
195
-
196
227
  self._tmp_dir: tempfile.TemporaryDirectory | None = None
197
228
  self._added_objs: dict[int, tuple[WBValue, ArtifactManifestEntry]] = {}
198
229
  self._added_local_paths: dict[str, ArtifactManifestEntry] = {}
@@ -236,7 +267,7 @@ class Artifact:
236
267
  self._use_as: str | None = None
237
268
  self._state: ArtifactState = ArtifactState.PENDING
238
269
  self._manifest: ArtifactManifest | _DeferredArtifactManifest | None = (
239
- ArtifactManifestV1(self._storage_policy)
270
+ ArtifactManifestV1(storage_policy=make_storage_policy())
240
271
  )
241
272
  self._commit_hash: str | None = None
242
273
  self._file_count: int | None = None
@@ -257,32 +288,22 @@ class Artifact:
257
288
  if (artifact := artifact_instance_cache.get(artifact_id)) is not None:
258
289
  return artifact
259
290
 
260
- query = gql(
261
- """
262
- query ArtifactByID($id: ID!) {
263
- artifact(id: $id) {
264
- ...ArtifactFragment
265
- }
266
- }
267
- """
268
- + _gql_artifact_fragment()
269
- )
270
- response = client.execute(
271
- query,
272
- variable_values={"id": artifact_id},
273
- )
274
- attrs = response.get("artifact")
275
- if attrs is None:
291
+ query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
292
+
293
+ data = client.execute(query, variable_values={"id": artifact_id})
294
+ result = ArtifactByID.model_validate(data)
295
+
296
+ if (art := result.artifact) is None:
276
297
  return None
277
298
 
278
- src_collection = attrs["artifactSequence"]
279
- src_project = src_collection["project"]
299
+ src_collection = art.artifact_sequence
300
+ src_project = src_collection.project
280
301
 
281
- entity_name = src_project["entityName"] if src_project else ""
282
- project_name = src_project["name"] if src_project else ""
302
+ entity_name = src_project.entity_name if src_project else ""
303
+ project_name = src_project.name if src_project else ""
283
304
 
284
- name = "{}:v{}".format(src_collection["name"], attrs["versionIndex"])
285
- return cls._from_attrs(entity_name, project_name, name, attrs, client)
305
+ name = f"{src_collection.name}:v{art.version_index}"
306
+ return cls._from_attrs(entity_name, project_name, name, art, client)
286
307
 
287
308
  @classmethod
288
309
  def _membership_from_name(
@@ -293,7 +314,7 @@ class Artifact:
293
314
  name: str,
294
315
  client: RetryingClient,
295
316
  ) -> Artifact:
296
- if not InternalApi()._server_supports(
317
+ if not (api := InternalApi())._server_supports(
297
318
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
298
319
  ):
299
320
  raise UnsupportedError(
@@ -301,69 +322,26 @@ class Artifact:
301
322
  "by this version of wandb server. Consider updating to the latest version."
302
323
  )
303
324
 
304
- query = gql(
305
- f"""
306
- query ArtifactByName($entityName: String!, $projectName: String!, $name: String!) {{
307
- project(name: $projectName, entityName: $entityName) {{
308
- artifactCollectionMembership(name: $name) {{
309
- id
310
- artifactCollection {{
311
- id
312
- name
313
- project {{
314
- id
315
- entityName
316
- name
317
- }}
318
- }}
319
- artifact {{
320
- ...ArtifactFragment
321
- }}
322
- }}
323
- }}
324
- }}
325
- {_gql_artifact_fragment()}
326
- """
325
+ query = gql_compat(
326
+ ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
327
+ omit_fields=omit_artifact_fields(api=api),
327
328
  )
328
329
 
329
- query_variable_values: dict[str, Any] = {
330
- "entityName": entity,
331
- "projectName": project,
332
- "name": name,
333
- }
334
- response = client.execute(
335
- query,
336
- variable_values=query_variable_values,
337
- )
338
- if not (project_attrs := response.get("project")):
330
+ gql_vars = {"entityName": entity, "projectName": project, "name": name}
331
+ data = client.execute(query, variable_values=gql_vars)
332
+ result = ArtifactViaMembershipByName.model_validate(data)
333
+
334
+ if not (project_attrs := result.project):
339
335
  raise ValueError(f"project {project!r} not found under entity {entity!r}")
340
- if not (acm_attrs := project_attrs.get("artifactCollectionMembership")):
336
+
337
+ if not (acm_attrs := project_attrs.artifact_collection_membership):
341
338
  entity_project = f"{entity}/{project}"
342
339
  raise ValueError(
343
340
  f"artifact membership {name!r} not found in {entity_project!r}"
344
341
  )
345
- if not (ac_attrs := acm_attrs.get("artifactCollection")):
346
- raise ValueError("artifact collection not found")
347
- if not (
348
- (ac_name := ac_attrs.get("name"))
349
- and (ac_project_attrs := ac_attrs.get("project"))
350
- ):
351
- raise ValueError("artifact collection project not found")
352
- ac_project = ac_project_attrs.get("name")
353
- ac_entity = ac_project_attrs.get("entityName")
354
- if is_artifact_registry_project(ac_project) and project == "model-registry":
355
- wandb.termwarn(
356
- "This model registry has been migrated and will be discontinued. "
357
- f"Your request was redirected to the corresponding artifact `{ac_name}` in the new registry. "
358
- f"Please update your paths to point to the migrated registry directly, `{ac_project}/{ac_name}`."
359
- )
360
- entity = ac_entity
361
- project = ac_project
362
- if not (attrs := acm_attrs.get("artifact")):
363
- entity_project = f"{entity}/{project}"
364
- raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
365
342
 
366
- return cls._from_attrs(entity, project, name, attrs, client)
343
+ target_path = ArtifactPath(prefix=entity, project=project, name=name)
344
+ return cls._from_membership(acm_attrs, target=target_path, client=client)
367
345
 
368
346
  @classmethod
369
347
  def _from_name(
@@ -375,59 +353,71 @@ class Artifact:
375
353
  client: RetryingClient,
376
354
  enable_tracking: bool = False,
377
355
  ) -> Artifact:
378
- if InternalApi()._server_supports(
356
+ if (api := InternalApi())._server_supports(
379
357
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
380
358
  ):
381
359
  return cls._membership_from_name(
382
- entity=entity,
383
- project=project,
384
- name=name,
385
- client=client,
360
+ entity=entity, project=project, name=name, client=client
386
361
  )
387
362
 
388
- query_variable_values: dict[str, Any] = {
363
+ supports_enable_tracking_gql_var = api.server_project_type_introspection()
364
+ omit_vars = None if supports_enable_tracking_gql_var else {"enableTracking"}
365
+
366
+ gql_vars = {
389
367
  "entityName": entity,
390
368
  "projectName": project,
391
369
  "name": name,
370
+ "enableTracking": enable_tracking,
392
371
  }
393
- query_vars = ["$entityName: String!", "$projectName: String!", "$name: String!"]
394
- query_args = ["name: $name"]
395
-
396
- server_supports_enabling_artifact_usage_tracking = (
397
- InternalApi().server_project_type_introspection()
372
+ query = gql_compat(
373
+ ARTIFACT_BY_NAME_GQL,
374
+ omit_variables=omit_vars,
375
+ omit_fields=omit_artifact_fields(api=api),
398
376
  )
399
- if server_supports_enabling_artifact_usage_tracking:
400
- query_vars.append("$enableTracking: Boolean")
401
- query_args.append("enableTracking: $enableTracking")
402
- query_variable_values["enableTracking"] = enable_tracking
403
-
404
- vars_str = ", ".join(query_vars)
405
- args_str = ", ".join(query_args)
406
-
407
- query = gql(
408
- f"""
409
- query ArtifactByName({vars_str}) {{
410
- project(name: $projectName, entityName: $entityName) {{
411
- artifact({args_str}) {{
412
- ...ArtifactFragment
413
- }}
414
- }}
415
- }}
416
- {_gql_artifact_fragment()}
417
- """
418
- )
419
- response = client.execute(
420
- query,
421
- variable_values=query_variable_values,
422
- )
423
- project_attrs = response.get("project")
424
- if not project_attrs:
425
- raise ValueError(f"project '{project}' not found under entity '{entity}'")
426
- attrs = project_attrs.get("artifact")
427
- if not attrs:
428
- raise ValueError(f"artifact '{name}' not found in '{entity}/{project}'")
429
377
 
430
- return cls._from_attrs(entity, project, name, attrs, client)
378
+ data = client.execute(query, variable_values=gql_vars)
379
+ result = ArtifactByName.model_validate(data)
380
+
381
+ if not (proj_attrs := result.project):
382
+ raise ValueError(f"project {project!r} not found under entity {entity!r}")
383
+
384
+ if not (art_attrs := proj_attrs.artifact):
385
+ entity_project = f"{entity}/{project}"
386
+ raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
387
+
388
+ return cls._from_attrs(entity, project, name, art_attrs, client)
389
+
390
+ @classmethod
391
+ def _from_membership(
392
+ cls,
393
+ membership: MembershipWithArtifact,
394
+ target: ArtifactPath,
395
+ client: RetryingClient,
396
+ ) -> Artifact:
397
+ if not (
398
+ (collection := membership.artifact_collection)
399
+ and (name := collection.name)
400
+ and (proj := collection.project)
401
+ ):
402
+ raise ValueError("Missing artifact collection project in GraphQL response")
403
+
404
+ if is_artifact_registry_project(proj.name) and (
405
+ target.project == "model-registry"
406
+ ):
407
+ wandb.termwarn(
408
+ "This model registry has been migrated and will be discontinued. "
409
+ f"Your request was redirected to the corresponding artifact {name!r} in the new registry. "
410
+ f"Please update your paths to point to the migrated registry directly, '{proj.name}/{name}'."
411
+ )
412
+ new_entity, new_project = proj.entity_name, proj.name
413
+ else:
414
+ new_entity = cast(str, target.prefix)
415
+ new_project = cast(str, target.project)
416
+
417
+ if not (artifact := membership.artifact):
418
+ raise ValueError(f"Artifact {target.to_str()!r} not found in response")
419
+
420
+ return cls._from_attrs(new_entity, new_project, target.name, artifact, client)
431
421
 
432
422
  @classmethod
433
423
  def _from_attrs(
@@ -435,7 +425,7 @@ class Artifact:
435
425
  entity: str,
436
426
  project: str,
437
427
  name: str,
438
- attrs: dict[str, Any],
428
+ attrs: dict[str, Any] | ArtifactFragment,
439
429
  client: RetryingClient,
440
430
  aliases: list[str] | None = None,
441
431
  ) -> Artifact:
@@ -445,7 +435,9 @@ class Artifact:
445
435
  artifact._entity = entity
446
436
  artifact._project = project
447
437
  artifact._name = name
448
- artifact._assign_attrs(attrs, aliases)
438
+
439
+ validated_attrs = ArtifactFragment.model_validate(attrs)
440
+ artifact._assign_attrs(validated_attrs, aliases)
449
441
 
450
442
  artifact.finalize()
451
443
 
@@ -458,29 +450,24 @@ class Artifact:
458
450
  # doesn't make it clear if the artifact is a link or not and have to manually set it.
459
451
  def _assign_attrs(
460
452
  self,
461
- attrs: dict[str, Any],
453
+ art: ArtifactFragment,
462
454
  aliases: list[str] | None = None,
463
455
  is_link: bool | None = None,
464
456
  ) -> None:
465
457
  """Update this Artifact's attributes using the server response."""
466
- self._id = attrs["id"]
467
-
468
- src_version = f"v{attrs['versionIndex']}"
469
- src_collection = attrs["artifactSequence"]
470
- src_project = src_collection["project"]
458
+ self._id = art.id
471
459
 
472
- self._source_entity = src_project["entityName"] if src_project else ""
473
- self._source_project = src_project["name"] if src_project else ""
474
- self._source_name = f"{src_collection['name']}:{src_version}"
475
- self._source_version = src_version
460
+ src_collection = art.artifact_sequence
461
+ src_project = src_collection.project
476
462
 
477
- if self._entity is None:
478
- self._entity = self._source_entity
479
- if self._project is None:
480
- self._project = self._source_project
463
+ self._source_entity = src_project.entity_name if src_project else ""
464
+ self._source_project = src_project.name if src_project else ""
465
+ self._source_name = f"{src_collection.name}:v{art.version_index}"
466
+ self._source_version = f"v{art.version_index}"
481
467
 
482
- if self._name is None:
483
- self._name = self._source_name
468
+ self._entity = self._entity or self._source_entity
469
+ self._project = self._project or self._source_project
470
+ self._name = self._name or self._source_name
484
471
 
485
472
  # TODO: Refactor artifact query to fetch artifact via membership instead
486
473
  # and get the collection type
@@ -488,33 +475,35 @@ class Artifact:
488
475
  self._is_link = (
489
476
  self._entity != self._source_entity
490
477
  or self._project != self._source_project
491
- or self._name != self._source_name
478
+ or self._name.split(":")[0] != self._source_name.split(":")[0]
492
479
  )
493
480
  else:
494
481
  self._is_link = is_link
495
482
 
496
- self._type = attrs["artifactType"]["name"]
497
- self._description = attrs["description"]
498
-
499
- entity = self._entity
500
- project = self._project
501
- collection, *_ = self._name.split(":")
483
+ self._type = art.artifact_type.name
484
+ self._description = art.description
502
485
 
503
- processed_aliases = []
504
486
  # The future of aliases is to move all alias fetches to the membership level
505
487
  # so we don't have to do the collection fetches below
506
488
  if aliases:
507
489
  processed_aliases = aliases
508
- else:
490
+ elif art.aliases:
491
+ entity = self._entity
492
+ project = self._project
493
+ collection = self._name.split(":")[0]
509
494
  processed_aliases = [
510
- obj["alias"]
511
- for obj in attrs["aliases"]
512
- if obj["artifactCollection"]
513
- and obj["artifactCollection"]["project"]
514
- and obj["artifactCollection"]["project"]["entityName"] == entity
515
- and obj["artifactCollection"]["project"]["name"] == project
516
- and obj["artifactCollection"]["name"] == collection
495
+ art_alias.alias
496
+ for art_alias in art.aliases
497
+ if (
498
+ (coll := art_alias.artifact_collection)
499
+ and (proj := coll.project)
500
+ and proj.entity_name == entity
501
+ and proj.name == project
502
+ and coll.name == collection
503
+ )
517
504
  ]
505
+ else:
506
+ processed_aliases = []
518
507
 
519
508
  version_aliases = list(filter(alias_is_version_index, processed_aliases))
520
509
  other_aliases = list(filterfalse(alias_is_version_index, processed_aliases))
@@ -524,49 +513,42 @@ class Artifact:
524
513
  version_aliases, too_short=TooFewItemsError, too_long=TooManyItemsError
525
514
  )
526
515
  except TooFewItemsError:
527
- version = src_version # default to the source version
516
+ version = f"v{art.version_index}" # default to the source version
528
517
  except TooManyItemsError:
529
518
  msg = f"Expected at most one version alias, got {len(version_aliases)}: {version_aliases!r}"
530
519
  raise ValueError(msg) from None
531
520
 
532
521
  self._version = version
533
-
534
- if ":" not in self._name:
535
- self._name = f"{self._name}:{version}"
522
+ self._name = self._name if (":" in self._name) else f"{self._name}:{version}"
536
523
 
537
524
  self._aliases = other_aliases
538
- self._saved_aliases = copy(other_aliases)
525
+ self._saved_aliases = copy(self._aliases)
539
526
 
540
- tags = [obj["name"] for obj in (attrs.get("tags") or [])]
541
- self._tags = tags
542
- self._saved_tags = copy(tags)
527
+ self._tags = [tag.name for tag in (art.tags or [])]
528
+ self._saved_tags = copy(self._tags)
543
529
 
544
- metadata_str = attrs["metadata"]
545
- self._metadata = validate_metadata(
546
- json.loads(metadata_str) if metadata_str else {}
547
- )
530
+ self._metadata = validate_metadata(art.metadata)
548
531
 
549
532
  self._ttl_duration_seconds = validate_ttl_duration_seconds(
550
- attrs.get("ttlDurationSeconds")
533
+ art.ttl_duration_seconds
551
534
  )
552
535
  self._ttl_is_inherited = (
553
- True if (attrs.get("ttlIsInherited") is None) else attrs["ttlIsInherited"]
536
+ True if (art.ttl_is_inherited is None) else art.ttl_is_inherited
554
537
  )
555
538
 
556
- self._state = ArtifactState(attrs["state"])
539
+ self._state = ArtifactState(art.state)
557
540
 
558
- try:
559
- manifest_url = attrs["currentManifest"]["file"]["directUrl"]
560
- except (LookupError, TypeError):
561
- self._manifest = None
562
- else:
563
- self._manifest = _DeferredArtifactManifest(manifest_url)
541
+ self._manifest = (
542
+ _DeferredArtifactManifest(manifest.file.direct_url)
543
+ if (manifest := art.current_manifest)
544
+ else None
545
+ )
564
546
 
565
- self._commit_hash = attrs["commitHash"]
566
- self._file_count = attrs["fileCount"]
567
- self._created_at = attrs["createdAt"]
568
- self._updated_at = attrs["updatedAt"]
569
- self._history_step = attrs.get("historyStep", None)
547
+ self._commit_hash = art.commit_hash
548
+ self._file_count = art.file_count
549
+ self._created_at = art.created_at
550
+ self._updated_at = art.updated_at
551
+ self._history_step = art.history_step
570
552
 
571
553
  @ensure_logged
572
554
  def new_draft(self) -> Artifact:
@@ -1063,37 +1045,24 @@ class Artifact:
1063
1045
  return self._manifest
1064
1046
 
1065
1047
  if self._manifest is None:
1066
- query = gql(
1067
- """
1068
- query ArtifactManifest(
1069
- $entityName: String!,
1070
- $projectName: String!,
1071
- $name: String!
1072
- ) {
1073
- project(entityName: $entityName, name: $projectName) {
1074
- artifact(name: $name) {
1075
- currentManifest {
1076
- file {
1077
- directUrl
1078
- }
1079
- }
1080
- }
1081
- }
1082
- }
1083
- """
1084
- )
1085
- assert self._client is not None
1086
- response = self._client.execute(
1087
- query,
1088
- variable_values={
1089
- "entityName": self._entity,
1090
- "projectName": self._project,
1091
- "name": self._name,
1092
- },
1093
- )
1094
- attrs = response["project"]["artifact"]
1095
- manifest_url = attrs["currentManifest"]["file"]["directUrl"]
1096
- self._manifest = self._load_manifest(manifest_url)
1048
+ if self._client is None:
1049
+ raise RuntimeError("Client not initialized for artifact queries")
1050
+
1051
+ query = gql(FETCH_ARTIFACT_MANIFEST_GQL)
1052
+ gql_vars = {
1053
+ "entityName": self.entity,
1054
+ "projectName": self.project,
1055
+ "name": self.name,
1056
+ }
1057
+ data = self._client.execute(query, variable_values=gql_vars)
1058
+ result = FetchArtifactManifest.model_validate(data)
1059
+ if not (
1060
+ (project := result.project)
1061
+ and (artifact := project.artifact)
1062
+ and (manifest := artifact.current_manifest)
1063
+ ):
1064
+ raise ValueError("Failed to fetch artifact manifest")
1065
+ self._manifest = self._load_manifest(manifest.file.direct_url)
1097
1066
 
1098
1067
  return self._manifest
1099
1068
 
@@ -1112,11 +1081,7 @@ class Artifact:
1112
1081
 
1113
1082
  Includes any references tracked by this artifact.
1114
1083
  """
1115
- total_size: int = 0
1116
- for entry in self.manifest.entries.values():
1117
- if entry.size is not None:
1118
- total_size += entry.size
1119
- return total_size
1084
+ return sum(entry.size for entry in self.manifest.entries.values() if entry.size)
1120
1085
 
1121
1086
  @property
1122
1087
  @ensure_logged
@@ -1184,7 +1149,7 @@ class Artifact:
1184
1149
  Returns:
1185
1150
  Boolean. `False` if artifact is saved. `True` if artifact is not saved.
1186
1151
  """
1187
- return self._state == ArtifactState.PENDING
1152
+ return self._state is ArtifactState.PENDING
1188
1153
 
1189
1154
  def _is_draft_save_started(self) -> bool:
1190
1155
  return self._save_handle is not None
@@ -1205,7 +1170,7 @@ class Artifact:
1205
1170
  settings: A settings object to use when initializing an automatic run. Most
1206
1171
  commonly used in testing harness.
1207
1172
  """
1208
- if self._state != ArtifactState.PENDING:
1173
+ if self._state is not ArtifactState.PENDING:
1209
1174
  return self._update()
1210
1175
 
1211
1176
  if self._incremental:
@@ -1266,31 +1231,20 @@ class Artifact:
1266
1231
  return self
1267
1232
 
1268
1233
  def _populate_after_save(self, artifact_id: str) -> None:
1269
- query_template = """
1270
- query ArtifactByIDShort($id: ID!) {
1271
- artifact(id: $id) {
1272
- ...ArtifactFragment
1273
- }
1274
- }
1275
- """ + _gql_artifact_fragment()
1234
+ assert self._client is not None
1276
1235
 
1277
- query = gql(query_template)
1236
+ query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
1278
1237
 
1279
- assert self._client is not None
1280
- response = self._client.execute(
1281
- query,
1282
- variable_values={"id": artifact_id},
1283
- )
1238
+ data = self._client.execute(query, variable_values={"id": artifact_id})
1239
+ result = ArtifactByID.model_validate(data)
1284
1240
 
1285
- try:
1286
- attrs = response["artifact"]
1287
- except LookupError:
1241
+ if not (artifact := result.artifact):
1288
1242
  raise ValueError(f"Unable to fetch artifact with id: {artifact_id!r}")
1289
- else:
1290
- # _populate_after_save is only called on source artifacts, not linked artifacts
1291
- # We have to manually set is_link because we aren't fetching the collection the artifact.
1292
- # That requires greater refactoring for commitArtifact to return the artifact collection type.
1293
- self._assign_attrs(attrs, is_link=False)
1243
+
1244
+ # _populate_after_save is only called on source artifacts, not linked artifacts
1245
+ # We have to manually set is_link because we aren't fetching the collection the artifact.
1246
+ # That requires greater refactoring for commitArtifact to return the artifact collection type.
1247
+ self._assign_attrs(artifact, is_link=False)
1294
1248
 
1295
1249
  @normalize_exceptions
1296
1250
  def _update(self) -> None:
@@ -1375,7 +1329,7 @@ class Artifact:
1375
1329
  for alias in self.aliases
1376
1330
  ]
1377
1331
 
1378
- omit_fields = omit_artifact_fields(api=InternalApi())
1332
+ omit_fields = omit_artifact_fields()
1379
1333
  omit_variables = set()
1380
1334
 
1381
1335
  if {"ttlIsInherited", "ttlDurationSeconds"} & omit_fields:
@@ -1399,7 +1353,9 @@ class Artifact:
1399
1353
  omit_variables |= {"tagsToAdd", "tagsToDelete"}
1400
1354
 
1401
1355
  mutation = gql_compat(
1402
- UPDATE_ARTIFACT_GQL, omit_variables=omit_variables, omit_fields=omit_fields
1356
+ UPDATE_ARTIFACT_GQL,
1357
+ omit_variables=omit_variables,
1358
+ omit_fields=omit_fields,
1403
1359
  )
1404
1360
 
1405
1361
  gql_vars = {
@@ -1417,7 +1373,7 @@ class Artifact:
1417
1373
  result = UpdateArtifact.model_validate(data).update_artifact
1418
1374
  if not (result and (artifact := result.artifact)):
1419
1375
  raise ValueError("Unable to parse updateArtifact response")
1420
- self._assign_attrs(artifact.model_dump())
1376
+ self._assign_attrs(artifact)
1421
1377
 
1422
1378
  self._ttl_changed = False # Reset after updating artifact
1423
1379
 
@@ -1481,7 +1437,7 @@ class Artifact:
1481
1437
  self._tmp_dir = tempfile.TemporaryDirectory()
1482
1438
  path = os.path.join(self._tmp_dir.name, name.lstrip("/"))
1483
1439
 
1484
- filesystem.mkdir_exists_ok(os.path.dirname(path))
1440
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
1485
1441
  try:
1486
1442
  with fsync_open(path, mode, encoding) as f:
1487
1443
  yield f
@@ -1588,30 +1544,27 @@ class Artifact:
1588
1544
  ValueError: Policy must be "mutable" or "immutable"
1589
1545
  """
1590
1546
  if not os.path.isdir(local_path):
1591
- raise ValueError(f"Path is not a directory: {local_path}")
1547
+ raise ValueError(f"Path is not a directory: {local_path!r}")
1592
1548
 
1593
1549
  termlog(
1594
- "Adding directory to artifact ({})... ".format(
1595
- os.path.join(".", os.path.normpath(local_path))
1596
- ),
1550
+ f"Adding directory to artifact ({Path('.', local_path)})... ",
1597
1551
  newline=False,
1598
1552
  )
1599
- start_time = time.time()
1553
+ start_time = time.monotonic()
1600
1554
 
1601
- paths = []
1555
+ paths: deque[tuple[str, str]] = deque()
1556
+ logical_root = name or "" # shared prefix, if any, for logical paths
1602
1557
  for dirpath, _, filenames in os.walk(local_path, followlinks=True):
1603
1558
  for fname in filenames:
1604
1559
  physical_path = os.path.join(dirpath, fname)
1605
1560
  logical_path = os.path.relpath(physical_path, start=local_path)
1606
- if name is not None:
1607
- logical_path = os.path.join(name, logical_path)
1561
+ logical_path = os.path.join(logical_root, logical_path)
1608
1562
  paths.append((logical_path, physical_path))
1609
1563
 
1610
- def add_manifest_file(log_phy_path: tuple[str, str]) -> None:
1611
- logical_path, physical_path = log_phy_path
1564
+ def add_manifest_file(logical_pth: str, physical_pth: str) -> None:
1612
1565
  self._add_local_file(
1613
- name=logical_path,
1614
- path=physical_path,
1566
+ name=logical_pth,
1567
+ path=physical_pth,
1615
1568
  skip_cache=skip_cache,
1616
1569
  policy=policy,
1617
1570
  overwrite=merge,
@@ -1619,11 +1572,11 @@ class Artifact:
1619
1572
 
1620
1573
  num_threads = 8
1621
1574
  pool = multiprocessing.dummy.Pool(num_threads)
1622
- pool.map(add_manifest_file, paths)
1575
+ pool.starmap(add_manifest_file, paths)
1623
1576
  pool.close()
1624
1577
  pool.join()
1625
1578
 
1626
- termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
1579
+ termlog("Done. %.1fs" % (time.monotonic() - start_time), prefix=False)
1627
1580
 
1628
1581
  @ensure_not_finalized
1629
1582
  def add_reference(
@@ -1699,7 +1652,7 @@ class Artifact:
1699
1652
  "References must be URIs. To reference a local file, use file://"
1700
1653
  )
1701
1654
 
1702
- manifest_entries = self._storage_policy.store_reference(
1655
+ manifest_entries = self.manifest.storage_policy.store_reference(
1703
1656
  self,
1704
1657
  URIStr(uri_str),
1705
1658
  name=name,
@@ -1861,10 +1814,8 @@ class Artifact:
1861
1814
  return
1862
1815
 
1863
1816
  path = str(PurePosixPath(item))
1864
- entry = self.manifest.get_entry_by_path(path)
1865
- if entry:
1866
- self.manifest.remove_entry(entry)
1867
- return
1817
+ if entry := self.manifest.get_entry_by_path(path):
1818
+ return self.manifest.remove_entry(entry)
1868
1819
 
1869
1820
  entries = self.manifest.get_entries_in_directory(path)
1870
1821
  if not entries:
@@ -1922,8 +1873,7 @@ class Artifact:
1922
1873
 
1923
1874
  # If the entry is a reference from another artifact, then get it directly from
1924
1875
  # that artifact.
1925
- referenced_id = entry._referenced_artifact_id()
1926
- if referenced_id:
1876
+ if referenced_id := entry._referenced_artifact_id():
1927
1877
  assert self._client is not None
1928
1878
  artifact = self._from_id(referenced_id, client=self._client)
1929
1879
  assert artifact is not None
@@ -1942,10 +1892,9 @@ class Artifact:
1942
1892
  item_path = item.download()
1943
1893
 
1944
1894
  # Load the object from the JSON blob
1945
- result = None
1946
- json_obj = {}
1947
1895
  with open(item_path) as file:
1948
1896
  json_obj = json.load(file)
1897
+
1949
1898
  result = wb_class.from_json(json_obj, self)
1950
1899
  result._set_artifact_source(self, name)
1951
1900
  return result
@@ -1959,10 +1908,9 @@ class Artifact:
1959
1908
  Returns:
1960
1909
  The artifact relative name.
1961
1910
  """
1962
- entry = self._added_local_paths.get(local_path, None)
1963
- if entry is None:
1964
- return None
1965
- return entry.path
1911
+ if entry := self._added_local_paths.get(local_path):
1912
+ return entry.path
1913
+ return None
1966
1914
 
1967
1915
  def _get_obj_entry(
1968
1916
  self, name: str
@@ -1979,8 +1927,7 @@ class Artifact:
1979
1927
  """
1980
1928
  for wb_class in WBValue.type_mapping().values():
1981
1929
  wandb_file_name = wb_class.with_suffix(name)
1982
- entry = self.manifest.entries.get(wandb_file_name)
1983
- if entry is not None:
1930
+ if entry := self.manifest.entries.get(wandb_file_name):
1984
1931
  return entry, wb_class
1985
1932
  return None, None
1986
1933
 
@@ -2021,7 +1968,7 @@ class Artifact:
2021
1968
  Raises:
2022
1969
  ArtifactNotLoggedError: If the artifact is not logged.
2023
1970
  """
2024
- root = FilePathStr(str(root or self._default_root()))
1971
+ root = FilePathStr(root or self._default_root())
2025
1972
  self._add_download_root(root)
2026
1973
 
2027
1974
  # TODO: download artifacts using core when implemented
@@ -2110,14 +2057,14 @@ class Artifact:
2110
2057
  multipart: bool | None = None,
2111
2058
  ) -> FilePathStr:
2112
2059
  nfiles = len(self.manifest.entries)
2113
- size = sum(e.size or 0 for e in self.manifest.entries.values())
2114
- log = False
2115
- if nfiles > 5000 or size > 50 * 1024 * 1024:
2116
- log = True
2060
+ size_mb = self.size / _MB
2061
+
2062
+ if log := (nfiles > 5000 or size_mb > 50):
2117
2063
  termlog(
2118
- f"Downloading large artifact {self.name}, {size / (1024 * 1024):.2f}MB. {nfiles} files... ",
2064
+ f"Downloading large artifact {self.name!r}, {size_mb:.2f}MB. {nfiles!r} files...",
2119
2065
  )
2120
- start_time = datetime.now()
2066
+ start_time = time.monotonic()
2067
+
2121
2068
  download_logger = ArtifactDownloadLogger(nfiles=nfiles)
2122
2069
 
2123
2070
  def _download_entry(
@@ -2156,44 +2103,62 @@ class Artifact:
2156
2103
  cookies=_thread_local_api_settings.cookies,
2157
2104
  headers=_thread_local_api_settings.headers,
2158
2105
  )
2106
+
2107
+ batch_size = env.get_artifact_fetch_file_url_batch_size()
2108
+
2159
2109
  active_futures = set()
2160
- has_next_page = True
2161
- cursor = None
2162
- while has_next_page:
2163
- fetch_url_batch_size = env.get_artifact_fetch_file_url_batch_size()
2164
- attrs = self._fetch_file_urls(cursor, fetch_url_batch_size)
2165
- has_next_page = attrs["pageInfo"]["hasNextPage"]
2166
- cursor = attrs["pageInfo"]["endCursor"]
2167
- for edge in attrs["edges"]:
2168
- entry = self.get_entry(edge["node"]["name"])
2110
+ cursor, has_more = None, True
2111
+ while has_more:
2112
+ files_page = self._fetch_file_urls(cursor=cursor, per_page=batch_size)
2113
+
2114
+ has_more = files_page.page_info.has_next_page
2115
+ cursor = files_page.page_info.end_cursor
2116
+
2117
+ # `File` nodes are formally nullable, so filter them out just in case.
2118
+ file_nodes = (e.node for e in files_page.edges if e.node)
2119
+ for node in file_nodes:
2120
+ entry = self.get_entry(node.name)
2169
2121
  # TODO: uncomment once artifact downloads are supported in core
2170
2122
  # if require_core and entry.ref is None:
2171
2123
  # # Handled by core
2172
2124
  # continue
2173
- entry._download_url = edge["node"]["directUrl"]
2125
+ entry._download_url = node.direct_url
2174
2126
  if (not path_prefix) or entry.path.startswith(str(path_prefix)):
2175
2127
  active_futures.add(executor.submit(download_entry, entry))
2128
+
2176
2129
  # Wait for download threads to catch up.
2177
- max_backlog = fetch_url_batch_size
2178
- if len(active_futures) > max_backlog:
2130
+ #
2131
+ # Extra context and observations (tonyyli):
2132
+ # - Even though the ThreadPoolExecutor limits the number of
2133
+ # concurrently-executed tasks, its internal task queue is unbounded.
2134
+ # The code below seems intended to ensure that at most `batch_size`
2135
+ # "backlogged" futures are held in memory at any given time. This seems like
2136
+ # a reasonable safeguard against unbounded memory consumption.
2137
+ #
2138
+ # - We should probably use a builtin (bounded) Queue or Semaphore here instead.
2139
+ # Consider this for a future change, or (depending on risk and risk tolerance)
2140
+ # managing this logic via asyncio instead, if viable.
2141
+ if len(active_futures) > batch_size:
2179
2142
  for future in concurrent.futures.as_completed(active_futures):
2180
2143
  future.result() # check for errors
2181
2144
  active_futures.remove(future)
2182
- if len(active_futures) <= max_backlog:
2145
+ if len(active_futures) <= batch_size:
2183
2146
  break
2147
+
2184
2148
  # Check for errors.
2185
2149
  for future in concurrent.futures.as_completed(active_futures):
2186
2150
  future.result()
2187
2151
 
2188
2152
  if log:
2189
- now = datetime.now()
2190
- delta = abs((now - start_time).total_seconds())
2191
- hours = int(delta // 3600)
2192
- minutes = int((delta - hours * 3600) // 60)
2193
- seconds = delta - hours * 3600 - minutes * 60
2194
- speed = size / 1024 / 1024 / delta
2153
+ # If you're wondering if we can display a `timedelta`, note that it
2154
+ # doesn't really support custom string format specifiers (compared to
2155
+ # e.g. `datetime` objs). To truncate the number of decimal places for
2156
+ # the seconds part, we manually convert/format each part below.
2157
+ dt_secs = abs(time.monotonic() - start_time)
2158
+ hrs, mins = divmod(dt_secs, 3600)
2159
+ mins, secs = divmod(mins, 60)
2195
2160
  termlog(
2196
- f"Done. {hours}:{minutes}:{seconds:.1f} ({speed:.1f}MB/s)",
2161
+ f"Done. {int(hrs):02d}:{int(mins):02d}:{secs:04.1f} ({size_mb / dt_secs:.1f}MB/s)",
2197
2162
  prefix=False,
2198
2163
  )
2199
2164
  return FilePathStr(root)
@@ -2202,79 +2167,44 @@ class Artifact:
2202
2167
  retry_timedelta=timedelta(minutes=3),
2203
2168
  retryable_exceptions=(requests.RequestException),
2204
2169
  )
2205
- def _fetch_file_urls(self, cursor: str | None, per_page: int | None = 5000) -> Any:
2170
+ def _fetch_file_urls(
2171
+ self, cursor: str | None, per_page: int = 5000
2172
+ ) -> FileUrlsFragment:
2173
+ if self._client is None:
2174
+ raise RuntimeError("Client not initialized")
2175
+
2206
2176
  if InternalApi()._server_supports(
2207
2177
  pb.ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
2208
2178
  ):
2209
- query = gql(
2210
- """
2211
- query ArtifactCollectionMembershipFileURLs($entityName: String!, $projectName: String!, \
2212
- $artifactName: String!, $artifactVersionIndex: String!, $cursor: String, $perPage: Int) {
2213
- project(name: $projectName, entityName: $entityName) {
2214
- artifactCollection(name: $artifactName) {
2215
- artifactMembership(aliasName: $artifactVersionIndex) {
2216
- files(after: $cursor, first: $perPage) {
2217
- pageInfo {
2218
- hasNextPage
2219
- endCursor
2220
- }
2221
- edges {
2222
- node {
2223
- name
2224
- directUrl
2225
- }
2226
- }
2227
- }
2228
- }
2229
- }
2230
- }
2231
- }
2232
- """
2233
- )
2234
- assert self._client is not None
2235
- response = self._client.execute(
2236
- query,
2237
- variable_values={
2238
- "entityName": self.entity,
2239
- "projectName": self.project,
2240
- "artifactName": self.name.split(":")[0],
2241
- "artifactVersionIndex": self.version,
2242
- "cursor": cursor,
2243
- "perPage": per_page,
2244
- },
2245
- timeout=60,
2246
- )
2247
- return response["project"]["artifactCollection"]["artifactMembership"][
2248
- "files"
2249
- ]
2179
+ query = gql(ARTIFACT_COLLECTION_MEMBERSHIP_FILE_URLS_GQL)
2180
+ gql_vars = {
2181
+ "entityName": self.entity,
2182
+ "projectName": self.project,
2183
+ "artifactName": self.name.split(":")[0],
2184
+ "artifactVersionIndex": self.version,
2185
+ "cursor": cursor,
2186
+ "perPage": per_page,
2187
+ }
2188
+ data = self._client.execute(query, variable_values=gql_vars, timeout=60)
2189
+ result = ArtifactCollectionMembershipFileUrls.model_validate(data)
2190
+
2191
+ if not (
2192
+ (project := result.project)
2193
+ and (collection := project.artifact_collection)
2194
+ and (membership := collection.artifact_membership)
2195
+ and (files := membership.files)
2196
+ ):
2197
+ raise ValueError(f"Unable to fetch files for artifact: {self.name!r}")
2198
+ return files
2250
2199
  else:
2251
- query = gql(
2252
- """
2253
- query ArtifactFileURLs($id: ID!, $cursor: String, $perPage: Int) {
2254
- artifact(id: $id) {
2255
- files(after: $cursor, first: $perPage) {
2256
- pageInfo {
2257
- hasNextPage
2258
- endCursor
2259
- }
2260
- edges {
2261
- node {
2262
- name
2263
- directUrl
2264
- }
2265
- }
2266
- }
2267
- }
2268
- }
2269
- """
2270
- )
2271
- assert self._client is not None
2272
- response = self._client.execute(
2273
- query,
2274
- variable_values={"id": self.id, "cursor": cursor, "perPage": per_page},
2275
- timeout=60,
2276
- )
2277
- return response["artifact"]["files"]
2200
+ query = gql(ARTIFACT_FILE_URLS_GQL)
2201
+ gql_vars = {"id": self.id, "cursor": cursor, "perPage": per_page}
2202
+ data = self._client.execute(query, variable_values=gql_vars, timeout=60)
2203
+ result = ArtifactFileUrls.model_validate(data)
2204
+
2205
+ if not ((artifact := result.artifact) and (files := artifact.files)):
2206
+ raise ValueError(f"Unable to fetch files for artifact: {self.name!r}")
2207
+ return files
2278
2208
 
2279
2209
  @ensure_logged
2280
2210
  def checkout(self, root: str | None = None) -> str:
@@ -2395,8 +2325,7 @@ class Artifact:
2395
2325
  # In case we're on a system where the artifact dir has a name corresponding to
2396
2326
  # an unexpected filesystem, we'll check for alternate roots. If one exists we'll
2397
2327
  # use that, otherwise we'll fall back to the system-preferred path.
2398
- path = filesystem.check_exists(root) or filesystem.system_preferred_path(root)
2399
- return FilePathStr(str(path))
2328
+ return FilePathStr(check_exists(root) or system_preferred_path(root))
2400
2329
 
2401
2330
  def _add_download_root(self, dir_path: str) -> None:
2402
2331
  self._download_roots.add(os.path.abspath(dir_path))
@@ -2440,33 +2369,16 @@ class Artifact:
2440
2369
 
2441
2370
  @normalize_exceptions
2442
2371
  def _delete(self, delete_aliases: bool = False) -> None:
2443
- mutation = gql(
2444
- """
2445
- mutation DeleteArtifact($artifactID: ID!, $deleteAliases: Boolean) {
2446
- deleteArtifact(input: {
2447
- artifactID: $artifactID
2448
- deleteAliases: $deleteAliases
2449
- }) {
2450
- artifact {
2451
- id
2452
- }
2453
- }
2454
- }
2455
- """
2456
- )
2457
- assert self._client is not None
2458
- self._client.execute(
2459
- mutation,
2460
- variable_values={
2461
- "artifactID": self.id,
2462
- "deleteAliases": delete_aliases,
2463
- },
2464
- )
2372
+ if self._client is None:
2373
+ raise RuntimeError("Client not initialized for artifact mutations")
2374
+
2375
+ mutation = gql(DELETE_ARTIFACT_GQL)
2376
+ gql_vars = {"artifactID": self.id, "deleteAliases": delete_aliases}
2377
+
2378
+ self._client.execute(mutation, variable_values=gql_vars)
2465
2379
 
2466
2380
  @normalize_exceptions
2467
- def link(
2468
- self, target_path: str, aliases: list[str] | None = None
2469
- ) -> Artifact | None:
2381
+ def link(self, target_path: str, aliases: list[str] | None = None) -> Artifact:
2470
2382
  """Link this artifact to a portfolio (a promoted collection of artifacts).
2471
2383
 
2472
2384
  Args:
@@ -2485,7 +2397,7 @@ class Artifact:
2485
2397
  ArtifactNotLoggedError: If the artifact is not logged.
2486
2398
 
2487
2399
  Returns:
2488
- The linked artifact if linking was successful, otherwise None.
2400
+ The linked artifact.
2489
2401
  """
2490
2402
  from wandb import Api
2491
2403
 
@@ -2505,10 +2417,11 @@ class Artifact:
2505
2417
  # Wait until the artifact is committed before trying to link it.
2506
2418
  self.wait()
2507
2419
 
2508
- api = Api(overrides={"entity": self.source_entity})
2420
+ api = InternalApi()
2421
+ settings = api.settings()
2509
2422
 
2510
2423
  target = ArtifactPath.from_str(target_path).with_defaults(
2511
- project=api.settings.get("project") or "uncategorized",
2424
+ project=settings.get("project") or "uncategorized",
2512
2425
  )
2513
2426
 
2514
2427
  # Parse the entity (first part of the path) appropriately,
@@ -2516,11 +2429,8 @@ class Artifact:
2516
2429
  if target.project and is_artifact_registry_project(target.project):
2517
2430
  # In a Registry linking, the entity is used to fetch the organization of the artifact
2518
2431
  # therefore the source artifact's entity is passed to the backend
2519
- organization = target.prefix or api.settings.get("organization") or ""
2520
-
2521
- target.prefix = InternalApi()._resolve_org_entity_name(
2522
- self.source_entity, organization
2523
- )
2432
+ org = target.prefix or settings.get("organization") or ""
2433
+ target.prefix = api._resolve_org_entity_name(self.source_entity, org)
2524
2434
  else:
2525
2435
  target = target.with_defaults(prefix=self.source_entity)
2526
2436
 
@@ -2537,21 +2447,35 @@ class Artifact:
2537
2447
  aliases=alias_inputs,
2538
2448
  )
2539
2449
  gql_vars = {"input": gql_input.model_dump(exclude_none=True)}
2540
- gql_op = gql(LINK_ARTIFACT_GQL)
2541
- data = self._client.execute(gql_op, variable_values=gql_vars)
2542
2450
 
2451
+ # Newer server versions can return `artifactMembership` directly in the response,
2452
+ # avoiding the need to re-fetch the linked artifact at the end.
2453
+ if api._server_supports(
2454
+ pb.ServerFeature.ARTIFACT_MEMBERSHIP_IN_LINK_ARTIFACT_RESPONSE
2455
+ ):
2456
+ omit_fragments = set()
2457
+ else:
2458
+ # FIXME: Make `gql_compat` omit nested fragment definitions recursively (but safely)
2459
+ omit_fragments = {
2460
+ "MembershipWithArtifact",
2461
+ "ArtifactFragment",
2462
+ "ArtifactFragmentWithoutAliases",
2463
+ }
2464
+
2465
+ gql_op = gql_compat(LINK_ARTIFACT_GQL, omit_fragments=omit_fragments)
2466
+ data = self._client.execute(gql_op, variable_values=gql_vars)
2543
2467
  result = LinkArtifact.model_validate(data).link_artifact
2468
+
2469
+ # Newer server versions can return artifactMembership directly in the response
2470
+ if result and (membership := result.artifact_membership):
2471
+ return self._from_membership(membership, target=target, client=self._client)
2472
+
2473
+ # Fallback to old behavior, which requires re-fetching the linked artifact to return it
2544
2474
  if not (result and (version_idx := result.version_index) is not None):
2545
2475
  raise ValueError("Unable to parse linked artifact version from response")
2546
2476
 
2547
- # Fetch the linked artifact to return it
2548
- linked_path = f"{target.to_str()}:v{version_idx}"
2549
-
2550
- try:
2551
- return api._artifact(linked_path)
2552
- except Exception as e:
2553
- wandb.termerror(f"Error fetching link artifact after linking: {e}")
2554
- return None
2477
+ link_name = f"{target.to_str()}:v{version_idx}"
2478
+ return Api(overrides={"entity": self.source_entity})._artifact(link_name)
2555
2479
 
2556
2480
  @ensure_logged
2557
2481
  def unlink(self) -> None:
@@ -2573,28 +2497,14 @@ class Artifact:
2573
2497
 
2574
2498
  @normalize_exceptions
2575
2499
  def _unlink(self) -> None:
2576
- mutation = gql(
2577
- """
2578
- mutation UnlinkArtifact($artifactID: ID!, $artifactPortfolioID: ID!) {
2579
- unlinkArtifact(
2580
- input: { artifactID: $artifactID, artifactPortfolioID: $artifactPortfolioID }
2581
- ) {
2582
- artifactID
2583
- success
2584
- clientMutationId
2585
- }
2586
- }
2587
- """
2588
- )
2589
- assert self._client is not None
2500
+ if self._client is None:
2501
+ raise RuntimeError("Client not initialized for artifact mutations")
2502
+
2503
+ mutation = gql(UNLINK_ARTIFACT_GQL)
2504
+ gql_vars = {"artifactID": self.id, "artifactPortfolioID": self.collection.id}
2505
+
2590
2506
  try:
2591
- self._client.execute(
2592
- mutation,
2593
- variable_values={
2594
- "artifactID": self.id,
2595
- "artifactPortfolioID": self.collection.id,
2596
- },
2597
- )
2507
+ self._client.execute(mutation, variable_values=gql_vars)
2598
2508
  except CommError as e:
2599
2509
  raise CommError(
2600
2510
  f"You do not have permission to unlink the artifact {self.qualified_name}"
@@ -2610,41 +2520,26 @@ class Artifact:
2610
2520
  Raises:
2611
2521
  ArtifactNotLoggedError: If the artifact is not logged.
2612
2522
  """
2613
- query = gql(
2614
- """
2615
- query ArtifactUsedBy(
2616
- $id: ID!,
2617
- ) {
2618
- artifact(id: $id) {
2619
- usedBy {
2620
- edges {
2621
- node {
2622
- name
2623
- project {
2624
- name
2625
- entityName
2626
- }
2627
- }
2628
- }
2629
- }
2630
- }
2631
- }
2632
- """
2633
- )
2634
- assert self._client is not None
2635
- response = self._client.execute(
2636
- query,
2637
- variable_values={"id": self.id},
2638
- )
2639
- return [
2640
- Run(
2641
- self._client,
2642
- edge["node"]["project"]["entityName"],
2643
- edge["node"]["project"]["name"],
2644
- edge["node"]["name"],
2645
- )
2646
- for edge in response.get("artifact", {}).get("usedBy", {}).get("edges", [])
2647
- ]
2523
+ if self._client is None:
2524
+ raise RuntimeError("Client not initialized for artifact queries")
2525
+
2526
+ query = gql(ARTIFACT_USED_BY_GQL)
2527
+ gql_vars = {"id": self.id}
2528
+ data = self._client.execute(query, variable_values=gql_vars)
2529
+ result = ArtifactUsedBy.model_validate(data)
2530
+
2531
+ if (
2532
+ (artifact := result.artifact)
2533
+ and (used_by := artifact.used_by)
2534
+ and (edges := used_by.edges)
2535
+ ):
2536
+ run_nodes = (e.node for e in edges)
2537
+ return [
2538
+ Run(self._client, proj.entity_name, proj.name, run.name)
2539
+ for run in run_nodes
2540
+ if (proj := run.project)
2541
+ ]
2542
+ return []
2648
2543
 
2649
2544
  @ensure_logged
2650
2545
  def logged_by(self) -> Run | None:
@@ -2656,39 +2551,22 @@ class Artifact:
2656
2551
  Raises:
2657
2552
  ArtifactNotLoggedError: If the artifact is not logged.
2658
2553
  """
2659
- query = gql(
2660
- """
2661
- query ArtifactCreatedBy(
2662
- $id: ID!
2663
- ) {
2664
- artifact(id: $id) {
2665
- createdBy {
2666
- ... on Run {
2667
- name
2668
- project {
2669
- name
2670
- entityName
2671
- }
2672
- }
2673
- }
2674
- }
2675
- }
2676
- """
2677
- )
2678
- assert self._client is not None
2679
- response = self._client.execute(
2680
- query,
2681
- variable_values={"id": self.id},
2682
- )
2683
- creator = response.get("artifact", {}).get("createdBy", {})
2684
- if creator.get("name") is None:
2685
- return None
2686
- return Run(
2687
- self._client,
2688
- creator["project"]["entityName"],
2689
- creator["project"]["name"],
2690
- creator["name"],
2691
- )
2554
+ if self._client is None:
2555
+ raise RuntimeError("Client not initialized for artifact queries")
2556
+
2557
+ query = gql(ARTIFACT_CREATED_BY_GQL)
2558
+ gql_vars = {"id": self.id}
2559
+ data = self._client.execute(query, variable_values=gql_vars)
2560
+ result = ArtifactCreatedBy.model_validate(data)
2561
+
2562
+ if (
2563
+ (artifact := result.artifact)
2564
+ and (creator := artifact.created_by)
2565
+ and (name := creator.name)
2566
+ and (project := creator.project)
2567
+ ):
2568
+ return Run(self._client, project.entity_name, project.name, name)
2569
+ return None
2692
2570
 
2693
2571
  @ensure_logged
2694
2572
  def json_encode(self) -> dict[str, Any]:
@@ -2704,37 +2582,22 @@ class Artifact:
2704
2582
  entity_name: str, project_name: str, name: str, client: RetryingClient
2705
2583
  ) -> str | None:
2706
2584
  """Returns the expected type for a given artifact name and project."""
2707
- query = gql(
2708
- """
2709
- query ArtifactType(
2710
- $entityName: String,
2711
- $projectName: String,
2712
- $name: String!
2713
- ) {
2714
- project(name: $projectName, entityName: $entityName) {
2715
- artifact(name: $name) {
2716
- artifactType {
2717
- name
2718
- }
2719
- }
2720
- }
2721
- }
2722
- """
2723
- )
2724
- if ":" not in name:
2725
- name += ":latest"
2726
- response = client.execute(
2727
- query,
2728
- variable_values={
2729
- "entityName": entity_name,
2730
- "projectName": project_name,
2731
- "name": name,
2732
- },
2733
- )
2734
- return (
2735
- ((response.get("project") or {}).get("artifact") or {}).get("artifactType")
2736
- or {}
2737
- ).get("name")
2585
+ query = gql(ARTIFACT_TYPE_GQL)
2586
+ gql_vars = {
2587
+ "entityName": entity_name,
2588
+ "projectName": project_name,
2589
+ "name": name if (":" in name) else f"{name}:latest",
2590
+ }
2591
+ data = client.execute(query, variable_values=gql_vars)
2592
+ result = ArtifactType.model_validate(data)
2593
+
2594
+ if (
2595
+ (project := result.project)
2596
+ and (artifact := project.artifact)
2597
+ and (artifact_type := artifact.artifact_type)
2598
+ ):
2599
+ return artifact_type.name
2600
+ return None
2738
2601
 
2739
2602
  def _load_manifest(self, url: str) -> ArtifactManifest:
2740
2603
  with requests.get(url) as response: