wandb 0.21.1__py3-none-win_amd64.whl → 0.21.2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/api.py +1 -2
  4. wandb/apis/public/artifacts.py +3 -5
  5. wandb/apis/public/registries/_utils.py +14 -16
  6. wandb/apis/public/registries/registries_search.py +176 -289
  7. wandb/apis/public/reports.py +13 -10
  8. wandb/automations/_generated/delete_automation.py +1 -3
  9. wandb/automations/_generated/enums.py +13 -11
  10. wandb/bin/gpu_stats.exe +0 -0
  11. wandb/bin/wandb-core +0 -0
  12. wandb/cli/cli.py +47 -2
  13. wandb/integration/metaflow/data_pandas.py +2 -2
  14. wandb/integration/metaflow/data_pytorch.py +75 -0
  15. wandb/integration/metaflow/data_sklearn.py +76 -0
  16. wandb/integration/metaflow/metaflow.py +16 -87
  17. wandb/integration/weave/__init__.py +6 -0
  18. wandb/integration/weave/interface.py +49 -0
  19. wandb/integration/weave/weave.py +63 -0
  20. wandb/proto/v3/wandb_internal_pb2.py +3 -2
  21. wandb/proto/v4/wandb_internal_pb2.py +2 -2
  22. wandb/proto/v5/wandb_internal_pb2.py +2 -2
  23. wandb/proto/v6/wandb_internal_pb2.py +2 -2
  24. wandb/sdk/artifacts/_factories.py +17 -0
  25. wandb/sdk/artifacts/_generated/__init__.py +221 -13
  26. wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
  27. wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
  28. wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
  29. wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
  30. wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
  31. wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
  32. wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
  33. wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
  34. wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
  35. wandb/sdk/artifacts/_generated/enums.py +5 -0
  36. wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
  37. wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
  38. wandb/sdk/artifacts/_generated/fragments.py +279 -41
  39. wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
  40. wandb/sdk/artifacts/_generated/operations.py +654 -51
  41. wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
  42. wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
  43. wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
  44. wandb/sdk/artifacts/_graphql_fragments.py +3 -86
  45. wandb/sdk/artifacts/_validators.py +6 -4
  46. wandb/sdk/artifacts/artifact.py +406 -543
  47. wandb/sdk/artifacts/artifact_file_cache.py +10 -6
  48. wandb/sdk/artifacts/artifact_manifest.py +10 -9
  49. wandb/sdk/artifacts/artifact_manifest_entry.py +9 -10
  50. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
  51. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  53. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
  54. wandb/sdk/data_types/video.py +2 -2
  55. wandb/sdk/interface/interface_queue.py +1 -4
  56. wandb/sdk/interface/interface_shared.py +26 -37
  57. wandb/sdk/interface/interface_sock.py +24 -14
  58. wandb/sdk/internal/settings_static.py +2 -3
  59. wandb/sdk/launch/create_job.py +12 -1
  60. wandb/sdk/launch/runner/kubernetes_runner.py +24 -29
  61. wandb/sdk/lib/asyncio_compat.py +16 -16
  62. wandb/sdk/lib/asyncio_manager.py +252 -0
  63. wandb/sdk/lib/hashutil.py +13 -4
  64. wandb/sdk/lib/printer.py +2 -2
  65. wandb/sdk/lib/printer_asyncio.py +3 -1
  66. wandb/sdk/lib/retry.py +185 -78
  67. wandb/sdk/lib/service/service_client.py +106 -0
  68. wandb/sdk/lib/service/service_connection.py +20 -26
  69. wandb/sdk/lib/service/service_token.py +30 -13
  70. wandb/sdk/mailbox/mailbox.py +13 -5
  71. wandb/sdk/mailbox/mailbox_handle.py +22 -13
  72. wandb/sdk/mailbox/response_handle.py +42 -106
  73. wandb/sdk/mailbox/wait_with_progress.py +7 -42
  74. wandb/sdk/wandb_init.py +11 -25
  75. wandb/sdk/wandb_login.py +1 -1
  76. wandb/sdk/wandb_run.py +91 -55
  77. wandb/sdk/wandb_settings.py +45 -32
  78. wandb/sdk/wandb_setup.py +176 -96
  79. wandb/util.py +1 -1
  80. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/METADATA +1 -1
  81. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/RECORD +84 -68
  82. wandb/sdk/interface/interface_relay.py +0 -38
  83. wandb/sdk/interface/router.py +0 -89
  84. wandb/sdk/interface/router_queue.py +0 -43
  85. wandb/sdk/interface/router_relay.py +0 -50
  86. wandb/sdk/interface/router_sock.py +0 -32
  87. wandb/sdk/lib/sock_client.py +0 -232
  88. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/WHEEL +0 -0
  89. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/licenses/LICENSE +0 -0
@@ -17,11 +17,22 @@ import time
17
17
  from collections import deque
18
18
  from copy import copy
19
19
  from dataclasses import dataclass
20
- from datetime import datetime, timedelta
20
+ from datetime import timedelta
21
21
  from functools import partial
22
22
  from itertools import filterfalse
23
- from pathlib import PurePosixPath
24
- from typing import IO, TYPE_CHECKING, Any, Iterator, Literal, Sequence, Type, final
23
+ from pathlib import Path, PurePosixPath
24
+ from typing import (
25
+ IO,
26
+ TYPE_CHECKING,
27
+ Any,
28
+ Final,
29
+ Iterator,
30
+ Literal,
31
+ Sequence,
32
+ Type,
33
+ cast,
34
+ final,
35
+ )
25
36
  from urllib.parse import quote, urljoin, urlparse
26
37
 
27
38
  import requests
@@ -58,21 +69,45 @@ from wandb.util import (
58
69
  vendor_setup,
59
70
  )
60
71
 
72
+ from ._factories import make_storage_policy
61
73
  from ._generated import (
62
74
  ADD_ALIASES_GQL,
75
+ ARTIFACT_BY_ID_GQL,
76
+ ARTIFACT_BY_NAME_GQL,
77
+ ARTIFACT_COLLECTION_MEMBERSHIP_FILE_URLS_GQL,
78
+ ARTIFACT_CREATED_BY_GQL,
79
+ ARTIFACT_FILE_URLS_GQL,
80
+ ARTIFACT_TYPE_GQL,
81
+ ARTIFACT_USED_BY_GQL,
82
+ ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
63
83
  DELETE_ALIASES_GQL,
84
+ DELETE_ARTIFACT_GQL,
85
+ FETCH_ARTIFACT_MANIFEST_GQL,
64
86
  FETCH_LINKED_ARTIFACTS_GQL,
65
87
  LINK_ARTIFACT_GQL,
88
+ UNLINK_ARTIFACT_GQL,
66
89
  UPDATE_ARTIFACT_GQL,
67
90
  ArtifactAliasInput,
91
+ ArtifactByID,
92
+ ArtifactByName,
68
93
  ArtifactCollectionAliasInput,
94
+ ArtifactCollectionMembershipFileUrls,
95
+ ArtifactCreatedBy,
96
+ ArtifactFileUrls,
97
+ ArtifactFragment,
98
+ ArtifactType,
99
+ ArtifactUsedBy,
100
+ ArtifactViaMembershipByName,
101
+ FetchArtifactManifest,
69
102
  FetchLinkedArtifacts,
103
+ FileUrlsFragment,
70
104
  LinkArtifact,
71
105
  LinkArtifactInput,
106
+ MembershipWithArtifact,
72
107
  TagInput,
73
108
  UpdateArtifact,
74
109
  )
75
- from ._graphql_fragments import _gql_artifact_fragment, omit_artifact_fields
110
+ from ._graphql_fragments import omit_artifact_fields
76
111
  from ._validators import (
77
112
  LINKED_ARTIFACT_COLLECTION_TYPE,
78
113
  ArtifactPath,
@@ -103,9 +138,6 @@ from .exceptions import (
103
138
  )
104
139
  from .staging import get_staging_dir
105
140
  from .storage_handlers.gcs_handler import _GCSIsADirectoryError
106
- from .storage_layout import StorageLayout
107
- from .storage_policies import WANDB_STORAGE_POLICY
108
- from .storage_policy import StoragePolicy
109
141
 
110
142
  reset_path = vendor_setup()
111
143
 
@@ -119,6 +151,9 @@ if TYPE_CHECKING:
119
151
  logger = logging.getLogger(__name__)
120
152
 
121
153
 
154
+ _MB: Final[int] = 1024 * 1024
155
+
156
+
122
157
  @final
123
158
  @dataclass
124
159
  class _DeferredArtifactManifest:
@@ -188,11 +223,6 @@ class Artifact:
188
223
  # Internal.
189
224
  self._client: RetryingClient | None = None
190
225
 
191
- storage_policy_cls = StoragePolicy.lookup_by_name(WANDB_STORAGE_POLICY)
192
- layout = StorageLayout.V1 if env.get_use_v1_artifacts() else StorageLayout.V2
193
- policy_config = {"storageLayout": layout}
194
- self._storage_policy = storage_policy_cls.from_config(config=policy_config)
195
-
196
226
  self._tmp_dir: tempfile.TemporaryDirectory | None = None
197
227
  self._added_objs: dict[int, tuple[WBValue, ArtifactManifestEntry]] = {}
198
228
  self._added_local_paths: dict[str, ArtifactManifestEntry] = {}
@@ -236,7 +266,7 @@ class Artifact:
236
266
  self._use_as: str | None = None
237
267
  self._state: ArtifactState = ArtifactState.PENDING
238
268
  self._manifest: ArtifactManifest | _DeferredArtifactManifest | None = (
239
- ArtifactManifestV1(self._storage_policy)
269
+ ArtifactManifestV1(storage_policy=make_storage_policy())
240
270
  )
241
271
  self._commit_hash: str | None = None
242
272
  self._file_count: int | None = None
@@ -257,32 +287,22 @@ class Artifact:
257
287
  if (artifact := artifact_instance_cache.get(artifact_id)) is not None:
258
288
  return artifact
259
289
 
260
- query = gql(
261
- """
262
- query ArtifactByID($id: ID!) {
263
- artifact(id: $id) {
264
- ...ArtifactFragment
265
- }
266
- }
267
- """
268
- + _gql_artifact_fragment()
269
- )
270
- response = client.execute(
271
- query,
272
- variable_values={"id": artifact_id},
273
- )
274
- attrs = response.get("artifact")
275
- if attrs is None:
290
+ query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
291
+
292
+ data = client.execute(query, variable_values={"id": artifact_id})
293
+ result = ArtifactByID.model_validate(data)
294
+
295
+ if (art := result.artifact) is None:
276
296
  return None
277
297
 
278
- src_collection = attrs["artifactSequence"]
279
- src_project = src_collection["project"]
298
+ src_collection = art.artifact_sequence
299
+ src_project = src_collection.project
280
300
 
281
- entity_name = src_project["entityName"] if src_project else ""
282
- project_name = src_project["name"] if src_project else ""
301
+ entity_name = src_project.entity_name if src_project else ""
302
+ project_name = src_project.name if src_project else ""
283
303
 
284
- name = "{}:v{}".format(src_collection["name"], attrs["versionIndex"])
285
- return cls._from_attrs(entity_name, project_name, name, attrs, client)
304
+ name = f"{src_collection.name}:v{art.version_index}"
305
+ return cls._from_attrs(entity_name, project_name, name, art, client)
286
306
 
287
307
  @classmethod
288
308
  def _membership_from_name(
@@ -293,7 +313,7 @@ class Artifact:
293
313
  name: str,
294
314
  client: RetryingClient,
295
315
  ) -> Artifact:
296
- if not InternalApi()._server_supports(
316
+ if not (api := InternalApi())._server_supports(
297
317
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
298
318
  ):
299
319
  raise UnsupportedError(
@@ -301,69 +321,26 @@ class Artifact:
301
321
  "by this version of wandb server. Consider updating to the latest version."
302
322
  )
303
323
 
304
- query = gql(
305
- f"""
306
- query ArtifactByName($entityName: String!, $projectName: String!, $name: String!) {{
307
- project(name: $projectName, entityName: $entityName) {{
308
- artifactCollectionMembership(name: $name) {{
309
- id
310
- artifactCollection {{
311
- id
312
- name
313
- project {{
314
- id
315
- entityName
316
- name
317
- }}
318
- }}
319
- artifact {{
320
- ...ArtifactFragment
321
- }}
322
- }}
323
- }}
324
- }}
325
- {_gql_artifact_fragment()}
326
- """
324
+ query = gql_compat(
325
+ ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
326
+ omit_fields=omit_artifact_fields(api=api),
327
327
  )
328
328
 
329
- query_variable_values: dict[str, Any] = {
330
- "entityName": entity,
331
- "projectName": project,
332
- "name": name,
333
- }
334
- response = client.execute(
335
- query,
336
- variable_values=query_variable_values,
337
- )
338
- if not (project_attrs := response.get("project")):
329
+ gql_vars = {"entityName": entity, "projectName": project, "name": name}
330
+ data = client.execute(query, variable_values=gql_vars)
331
+ result = ArtifactViaMembershipByName.model_validate(data)
332
+
333
+ if not (project_attrs := result.project):
339
334
  raise ValueError(f"project {project!r} not found under entity {entity!r}")
340
- if not (acm_attrs := project_attrs.get("artifactCollectionMembership")):
335
+
336
+ if not (acm_attrs := project_attrs.artifact_collection_membership):
341
337
  entity_project = f"{entity}/{project}"
342
338
  raise ValueError(
343
339
  f"artifact membership {name!r} not found in {entity_project!r}"
344
340
  )
345
- if not (ac_attrs := acm_attrs.get("artifactCollection")):
346
- raise ValueError("artifact collection not found")
347
- if not (
348
- (ac_name := ac_attrs.get("name"))
349
- and (ac_project_attrs := ac_attrs.get("project"))
350
- ):
351
- raise ValueError("artifact collection project not found")
352
- ac_project = ac_project_attrs.get("name")
353
- ac_entity = ac_project_attrs.get("entityName")
354
- if is_artifact_registry_project(ac_project) and project == "model-registry":
355
- wandb.termwarn(
356
- "This model registry has been migrated and will be discontinued. "
357
- f"Your request was redirected to the corresponding artifact `{ac_name}` in the new registry. "
358
- f"Please update your paths to point to the migrated registry directly, `{ac_project}/{ac_name}`."
359
- )
360
- entity = ac_entity
361
- project = ac_project
362
- if not (attrs := acm_attrs.get("artifact")):
363
- entity_project = f"{entity}/{project}"
364
- raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
365
341
 
366
- return cls._from_attrs(entity, project, name, attrs, client)
342
+ target_path = ArtifactPath(prefix=entity, project=project, name=name)
343
+ return cls._from_membership(acm_attrs, target=target_path, client=client)
367
344
 
368
345
  @classmethod
369
346
  def _from_name(
@@ -375,59 +352,71 @@ class Artifact:
375
352
  client: RetryingClient,
376
353
  enable_tracking: bool = False,
377
354
  ) -> Artifact:
378
- if InternalApi()._server_supports(
355
+ if (api := InternalApi())._server_supports(
379
356
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
380
357
  ):
381
358
  return cls._membership_from_name(
382
- entity=entity,
383
- project=project,
384
- name=name,
385
- client=client,
359
+ entity=entity, project=project, name=name, client=client
386
360
  )
387
361
 
388
- query_variable_values: dict[str, Any] = {
362
+ supports_enable_tracking_gql_var = api.server_project_type_introspection()
363
+ omit_vars = None if supports_enable_tracking_gql_var else {"enableTracking"}
364
+
365
+ gql_vars = {
389
366
  "entityName": entity,
390
367
  "projectName": project,
391
368
  "name": name,
369
+ "enableTracking": enable_tracking,
392
370
  }
393
- query_vars = ["$entityName: String!", "$projectName: String!", "$name: String!"]
394
- query_args = ["name: $name"]
395
-
396
- server_supports_enabling_artifact_usage_tracking = (
397
- InternalApi().server_project_type_introspection()
371
+ query = gql_compat(
372
+ ARTIFACT_BY_NAME_GQL,
373
+ omit_variables=omit_vars,
374
+ omit_fields=omit_artifact_fields(api=api),
398
375
  )
399
- if server_supports_enabling_artifact_usage_tracking:
400
- query_vars.append("$enableTracking: Boolean")
401
- query_args.append("enableTracking: $enableTracking")
402
- query_variable_values["enableTracking"] = enable_tracking
403
-
404
- vars_str = ", ".join(query_vars)
405
- args_str = ", ".join(query_args)
406
-
407
- query = gql(
408
- f"""
409
- query ArtifactByName({vars_str}) {{
410
- project(name: $projectName, entityName: $entityName) {{
411
- artifact({args_str}) {{
412
- ...ArtifactFragment
413
- }}
414
- }}
415
- }}
416
- {_gql_artifact_fragment()}
417
- """
418
- )
419
- response = client.execute(
420
- query,
421
- variable_values=query_variable_values,
422
- )
423
- project_attrs = response.get("project")
424
- if not project_attrs:
425
- raise ValueError(f"project '{project}' not found under entity '{entity}'")
426
- attrs = project_attrs.get("artifact")
427
- if not attrs:
428
- raise ValueError(f"artifact '{name}' not found in '{entity}/{project}'")
429
376
 
430
- return cls._from_attrs(entity, project, name, attrs, client)
377
+ data = client.execute(query, variable_values=gql_vars)
378
+ result = ArtifactByName.model_validate(data)
379
+
380
+ if not (proj_attrs := result.project):
381
+ raise ValueError(f"project {project!r} not found under entity {entity!r}")
382
+
383
+ if not (art_attrs := proj_attrs.artifact):
384
+ entity_project = f"{entity}/{project}"
385
+ raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
386
+
387
+ return cls._from_attrs(entity, project, name, art_attrs, client)
388
+
389
+ @classmethod
390
+ def _from_membership(
391
+ cls,
392
+ membership: MembershipWithArtifact,
393
+ target: ArtifactPath,
394
+ client: RetryingClient,
395
+ ) -> Artifact:
396
+ if not (
397
+ (collection := membership.artifact_collection)
398
+ and (name := collection.name)
399
+ and (proj := collection.project)
400
+ ):
401
+ raise ValueError("Missing artifact collection project in GraphQL response")
402
+
403
+ if is_artifact_registry_project(proj.name) and (
404
+ target.project == "model-registry"
405
+ ):
406
+ wandb.termwarn(
407
+ "This model registry has been migrated and will be discontinued. "
408
+ f"Your request was redirected to the corresponding artifact {name!r} in the new registry. "
409
+ f"Please update your paths to point to the migrated registry directly, '{proj.name}/{name}'."
410
+ )
411
+ new_entity, new_project = proj.entity_name, proj.name
412
+ else:
413
+ new_entity = cast(str, target.prefix)
414
+ new_project = cast(str, target.project)
415
+
416
+ if not (artifact := membership.artifact):
417
+ raise ValueError(f"Artifact {target.to_str()!r} not found in response")
418
+
419
+ return cls._from_attrs(new_entity, new_project, target.name, artifact, client)
431
420
 
432
421
  @classmethod
433
422
  def _from_attrs(
@@ -435,7 +424,7 @@ class Artifact:
435
424
  entity: str,
436
425
  project: str,
437
426
  name: str,
438
- attrs: dict[str, Any],
427
+ attrs: dict[str, Any] | ArtifactFragment,
439
428
  client: RetryingClient,
440
429
  aliases: list[str] | None = None,
441
430
  ) -> Artifact:
@@ -445,7 +434,9 @@ class Artifact:
445
434
  artifact._entity = entity
446
435
  artifact._project = project
447
436
  artifact._name = name
448
- artifact._assign_attrs(attrs, aliases)
437
+
438
+ validated_attrs = ArtifactFragment.model_validate(attrs)
439
+ artifact._assign_attrs(validated_attrs, aliases)
449
440
 
450
441
  artifact.finalize()
451
442
 
@@ -458,29 +449,24 @@ class Artifact:
458
449
  # doesn't make it clear if the artifact is a link or not and have to manually set it.
459
450
  def _assign_attrs(
460
451
  self,
461
- attrs: dict[str, Any],
452
+ art: ArtifactFragment,
462
453
  aliases: list[str] | None = None,
463
454
  is_link: bool | None = None,
464
455
  ) -> None:
465
456
  """Update this Artifact's attributes using the server response."""
466
- self._id = attrs["id"]
467
-
468
- src_version = f"v{attrs['versionIndex']}"
469
- src_collection = attrs["artifactSequence"]
470
- src_project = src_collection["project"]
457
+ self._id = art.id
471
458
 
472
- self._source_entity = src_project["entityName"] if src_project else ""
473
- self._source_project = src_project["name"] if src_project else ""
474
- self._source_name = f"{src_collection['name']}:{src_version}"
475
- self._source_version = src_version
459
+ src_collection = art.artifact_sequence
460
+ src_project = src_collection.project
476
461
 
477
- if self._entity is None:
478
- self._entity = self._source_entity
479
- if self._project is None:
480
- self._project = self._source_project
462
+ self._source_entity = src_project.entity_name if src_project else ""
463
+ self._source_project = src_project.name if src_project else ""
464
+ self._source_name = f"{src_collection.name}:v{art.version_index}"
465
+ self._source_version = f"v{art.version_index}"
481
466
 
482
- if self._name is None:
483
- self._name = self._source_name
467
+ self._entity = self._entity or self._source_entity
468
+ self._project = self._project or self._source_project
469
+ self._name = self._name or self._source_name
484
470
 
485
471
  # TODO: Refactor artifact query to fetch artifact via membership instead
486
472
  # and get the collection type
@@ -488,33 +474,35 @@ class Artifact:
488
474
  self._is_link = (
489
475
  self._entity != self._source_entity
490
476
  or self._project != self._source_project
491
- or self._name != self._source_name
477
+ or self._name.split(":")[0] != self._source_name.split(":")[0]
492
478
  )
493
479
  else:
494
480
  self._is_link = is_link
495
481
 
496
- self._type = attrs["artifactType"]["name"]
497
- self._description = attrs["description"]
498
-
499
- entity = self._entity
500
- project = self._project
501
- collection, *_ = self._name.split(":")
482
+ self._type = art.artifact_type.name
483
+ self._description = art.description
502
484
 
503
- processed_aliases = []
504
485
  # The future of aliases is to move all alias fetches to the membership level
505
486
  # so we don't have to do the collection fetches below
506
487
  if aliases:
507
488
  processed_aliases = aliases
508
- else:
489
+ elif art.aliases:
490
+ entity = self._entity
491
+ project = self._project
492
+ collection = self._name.split(":")[0]
509
493
  processed_aliases = [
510
- obj["alias"]
511
- for obj in attrs["aliases"]
512
- if obj["artifactCollection"]
513
- and obj["artifactCollection"]["project"]
514
- and obj["artifactCollection"]["project"]["entityName"] == entity
515
- and obj["artifactCollection"]["project"]["name"] == project
516
- and obj["artifactCollection"]["name"] == collection
494
+ art_alias.alias
495
+ for art_alias in art.aliases
496
+ if (
497
+ (coll := art_alias.artifact_collection)
498
+ and (proj := coll.project)
499
+ and proj.entity_name == entity
500
+ and proj.name == project
501
+ and coll.name == collection
502
+ )
517
503
  ]
504
+ else:
505
+ processed_aliases = []
518
506
 
519
507
  version_aliases = list(filter(alias_is_version_index, processed_aliases))
520
508
  other_aliases = list(filterfalse(alias_is_version_index, processed_aliases))
@@ -524,49 +512,42 @@ class Artifact:
524
512
  version_aliases, too_short=TooFewItemsError, too_long=TooManyItemsError
525
513
  )
526
514
  except TooFewItemsError:
527
- version = src_version # default to the source version
515
+ version = f"v{art.version_index}" # default to the source version
528
516
  except TooManyItemsError:
529
517
  msg = f"Expected at most one version alias, got {len(version_aliases)}: {version_aliases!r}"
530
518
  raise ValueError(msg) from None
531
519
 
532
520
  self._version = version
533
-
534
- if ":" not in self._name:
535
- self._name = f"{self._name}:{version}"
521
+ self._name = self._name if (":" in self._name) else f"{self._name}:{version}"
536
522
 
537
523
  self._aliases = other_aliases
538
- self._saved_aliases = copy(other_aliases)
524
+ self._saved_aliases = copy(self._aliases)
539
525
 
540
- tags = [obj["name"] for obj in (attrs.get("tags") or [])]
541
- self._tags = tags
542
- self._saved_tags = copy(tags)
526
+ self._tags = [tag.name for tag in (art.tags or [])]
527
+ self._saved_tags = copy(self._tags)
543
528
 
544
- metadata_str = attrs["metadata"]
545
- self._metadata = validate_metadata(
546
- json.loads(metadata_str) if metadata_str else {}
547
- )
529
+ self._metadata = validate_metadata(art.metadata)
548
530
 
549
531
  self._ttl_duration_seconds = validate_ttl_duration_seconds(
550
- attrs.get("ttlDurationSeconds")
532
+ art.ttl_duration_seconds
551
533
  )
552
534
  self._ttl_is_inherited = (
553
- True if (attrs.get("ttlIsInherited") is None) else attrs["ttlIsInherited"]
535
+ True if (art.ttl_is_inherited is None) else art.ttl_is_inherited
554
536
  )
555
537
 
556
- self._state = ArtifactState(attrs["state"])
538
+ self._state = ArtifactState(art.state)
557
539
 
558
- try:
559
- manifest_url = attrs["currentManifest"]["file"]["directUrl"]
560
- except (LookupError, TypeError):
561
- self._manifest = None
562
- else:
563
- self._manifest = _DeferredArtifactManifest(manifest_url)
540
+ self._manifest = (
541
+ _DeferredArtifactManifest(manifest.file.direct_url)
542
+ if (manifest := art.current_manifest)
543
+ else None
544
+ )
564
545
 
565
- self._commit_hash = attrs["commitHash"]
566
- self._file_count = attrs["fileCount"]
567
- self._created_at = attrs["createdAt"]
568
- self._updated_at = attrs["updatedAt"]
569
- self._history_step = attrs.get("historyStep", None)
546
+ self._commit_hash = art.commit_hash
547
+ self._file_count = art.file_count
548
+ self._created_at = art.created_at
549
+ self._updated_at = art.updated_at
550
+ self._history_step = art.history_step
570
551
 
571
552
  @ensure_logged
572
553
  def new_draft(self) -> Artifact:
@@ -1063,37 +1044,24 @@ class Artifact:
1063
1044
  return self._manifest
1064
1045
 
1065
1046
  if self._manifest is None:
1066
- query = gql(
1067
- """
1068
- query ArtifactManifest(
1069
- $entityName: String!,
1070
- $projectName: String!,
1071
- $name: String!
1072
- ) {
1073
- project(entityName: $entityName, name: $projectName) {
1074
- artifact(name: $name) {
1075
- currentManifest {
1076
- file {
1077
- directUrl
1078
- }
1079
- }
1080
- }
1081
- }
1082
- }
1083
- """
1084
- )
1085
- assert self._client is not None
1086
- response = self._client.execute(
1087
- query,
1088
- variable_values={
1089
- "entityName": self._entity,
1090
- "projectName": self._project,
1091
- "name": self._name,
1092
- },
1093
- )
1094
- attrs = response["project"]["artifact"]
1095
- manifest_url = attrs["currentManifest"]["file"]["directUrl"]
1096
- self._manifest = self._load_manifest(manifest_url)
1047
+ if self._client is None:
1048
+ raise RuntimeError("Client not initialized for artifact queries")
1049
+
1050
+ query = gql(FETCH_ARTIFACT_MANIFEST_GQL)
1051
+ gql_vars = {
1052
+ "entityName": self.entity,
1053
+ "projectName": self.project,
1054
+ "name": self.name,
1055
+ }
1056
+ data = self._client.execute(query, variable_values=gql_vars)
1057
+ result = FetchArtifactManifest.model_validate(data)
1058
+ if not (
1059
+ (project := result.project)
1060
+ and (artifact := project.artifact)
1061
+ and (manifest := artifact.current_manifest)
1062
+ ):
1063
+ raise ValueError("Failed to fetch artifact manifest")
1064
+ self._manifest = self._load_manifest(manifest.file.direct_url)
1097
1065
 
1098
1066
  return self._manifest
1099
1067
 
@@ -1112,11 +1080,7 @@ class Artifact:
1112
1080
 
1113
1081
  Includes any references tracked by this artifact.
1114
1082
  """
1115
- total_size: int = 0
1116
- for entry in self.manifest.entries.values():
1117
- if entry.size is not None:
1118
- total_size += entry.size
1119
- return total_size
1083
+ return sum(entry.size for entry in self.manifest.entries.values() if entry.size)
1120
1084
 
1121
1085
  @property
1122
1086
  @ensure_logged
@@ -1184,7 +1148,7 @@ class Artifact:
1184
1148
  Returns:
1185
1149
  Boolean. `False` if artifact is saved. `True` if artifact is not saved.
1186
1150
  """
1187
- return self._state == ArtifactState.PENDING
1151
+ return self._state is ArtifactState.PENDING
1188
1152
 
1189
1153
  def _is_draft_save_started(self) -> bool:
1190
1154
  return self._save_handle is not None
@@ -1205,7 +1169,7 @@ class Artifact:
1205
1169
  settings: A settings object to use when initializing an automatic run. Most
1206
1170
  commonly used in testing harness.
1207
1171
  """
1208
- if self._state != ArtifactState.PENDING:
1172
+ if self._state is not ArtifactState.PENDING:
1209
1173
  return self._update()
1210
1174
 
1211
1175
  if self._incremental:
@@ -1266,31 +1230,20 @@ class Artifact:
1266
1230
  return self
1267
1231
 
1268
1232
  def _populate_after_save(self, artifact_id: str) -> None:
1269
- query_template = """
1270
- query ArtifactByIDShort($id: ID!) {
1271
- artifact(id: $id) {
1272
- ...ArtifactFragment
1273
- }
1274
- }
1275
- """ + _gql_artifact_fragment()
1233
+ assert self._client is not None
1276
1234
 
1277
- query = gql(query_template)
1235
+ query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
1278
1236
 
1279
- assert self._client is not None
1280
- response = self._client.execute(
1281
- query,
1282
- variable_values={"id": artifact_id},
1283
- )
1237
+ data = self._client.execute(query, variable_values={"id": artifact_id})
1238
+ result = ArtifactByID.model_validate(data)
1284
1239
 
1285
- try:
1286
- attrs = response["artifact"]
1287
- except LookupError:
1240
+ if not (artifact := result.artifact):
1288
1241
  raise ValueError(f"Unable to fetch artifact with id: {artifact_id!r}")
1289
- else:
1290
- # _populate_after_save is only called on source artifacts, not linked artifacts
1291
- # We have to manually set is_link because we aren't fetching the collection the artifact.
1292
- # That requires greater refactoring for commitArtifact to return the artifact collection type.
1293
- self._assign_attrs(attrs, is_link=False)
1242
+
1243
+ # _populate_after_save is only called on source artifacts, not linked artifacts
1244
+ # We have to manually set is_link because we aren't fetching the collection the artifact.
1245
+ # That requires greater refactoring for commitArtifact to return the artifact collection type.
1246
+ self._assign_attrs(artifact, is_link=False)
1294
1247
 
1295
1248
  @normalize_exceptions
1296
1249
  def _update(self) -> None:
@@ -1375,7 +1328,7 @@ class Artifact:
1375
1328
  for alias in self.aliases
1376
1329
  ]
1377
1330
 
1378
- omit_fields = omit_artifact_fields(api=InternalApi())
1331
+ omit_fields = omit_artifact_fields()
1379
1332
  omit_variables = set()
1380
1333
 
1381
1334
  if {"ttlIsInherited", "ttlDurationSeconds"} & omit_fields:
@@ -1399,7 +1352,9 @@ class Artifact:
1399
1352
  omit_variables |= {"tagsToAdd", "tagsToDelete"}
1400
1353
 
1401
1354
  mutation = gql_compat(
1402
- UPDATE_ARTIFACT_GQL, omit_variables=omit_variables, omit_fields=omit_fields
1355
+ UPDATE_ARTIFACT_GQL,
1356
+ omit_variables=omit_variables,
1357
+ omit_fields=omit_fields,
1403
1358
  )
1404
1359
 
1405
1360
  gql_vars = {
@@ -1417,7 +1372,7 @@ class Artifact:
1417
1372
  result = UpdateArtifact.model_validate(data).update_artifact
1418
1373
  if not (result and (artifact := result.artifact)):
1419
1374
  raise ValueError("Unable to parse updateArtifact response")
1420
- self._assign_attrs(artifact.model_dump())
1375
+ self._assign_attrs(artifact)
1421
1376
 
1422
1377
  self._ttl_changed = False # Reset after updating artifact
1423
1378
 
@@ -1481,7 +1436,7 @@ class Artifact:
1481
1436
  self._tmp_dir = tempfile.TemporaryDirectory()
1482
1437
  path = os.path.join(self._tmp_dir.name, name.lstrip("/"))
1483
1438
 
1484
- filesystem.mkdir_exists_ok(os.path.dirname(path))
1439
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
1485
1440
  try:
1486
1441
  with fsync_open(path, mode, encoding) as f:
1487
1442
  yield f
@@ -1588,30 +1543,27 @@ class Artifact:
1588
1543
  ValueError: Policy must be "mutable" or "immutable"
1589
1544
  """
1590
1545
  if not os.path.isdir(local_path):
1591
- raise ValueError(f"Path is not a directory: {local_path}")
1546
+ raise ValueError(f"Path is not a directory: {local_path!r}")
1592
1547
 
1593
1548
  termlog(
1594
- "Adding directory to artifact ({})... ".format(
1595
- os.path.join(".", os.path.normpath(local_path))
1596
- ),
1549
+ f"Adding directory to artifact ({Path('.', local_path)})... ",
1597
1550
  newline=False,
1598
1551
  )
1599
- start_time = time.time()
1552
+ start_time = time.monotonic()
1600
1553
 
1601
- paths = []
1554
+ paths: deque[tuple[str, str]] = deque()
1555
+ logical_root = name or "" # shared prefix, if any, for logical paths
1602
1556
  for dirpath, _, filenames in os.walk(local_path, followlinks=True):
1603
1557
  for fname in filenames:
1604
1558
  physical_path = os.path.join(dirpath, fname)
1605
1559
  logical_path = os.path.relpath(physical_path, start=local_path)
1606
- if name is not None:
1607
- logical_path = os.path.join(name, logical_path)
1560
+ logical_path = os.path.join(logical_root, logical_path)
1608
1561
  paths.append((logical_path, physical_path))
1609
1562
 
1610
- def add_manifest_file(log_phy_path: tuple[str, str]) -> None:
1611
- logical_path, physical_path = log_phy_path
1563
+ def add_manifest_file(logical_pth: str, physical_pth: str) -> None:
1612
1564
  self._add_local_file(
1613
- name=logical_path,
1614
- path=physical_path,
1565
+ name=logical_pth,
1566
+ path=physical_pth,
1615
1567
  skip_cache=skip_cache,
1616
1568
  policy=policy,
1617
1569
  overwrite=merge,
@@ -1619,11 +1571,11 @@ class Artifact:
1619
1571
 
1620
1572
  num_threads = 8
1621
1573
  pool = multiprocessing.dummy.Pool(num_threads)
1622
- pool.map(add_manifest_file, paths)
1574
+ pool.starmap(add_manifest_file, paths)
1623
1575
  pool.close()
1624
1576
  pool.join()
1625
1577
 
1626
- termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
1578
+ termlog("Done. %.1fs" % (time.monotonic() - start_time), prefix=False)
1627
1579
 
1628
1580
  @ensure_not_finalized
1629
1581
  def add_reference(
@@ -1699,7 +1651,7 @@ class Artifact:
1699
1651
  "References must be URIs. To reference a local file, use file://"
1700
1652
  )
1701
1653
 
1702
- manifest_entries = self._storage_policy.store_reference(
1654
+ manifest_entries = self.manifest.storage_policy.store_reference(
1703
1655
  self,
1704
1656
  URIStr(uri_str),
1705
1657
  name=name,
@@ -1861,10 +1813,8 @@ class Artifact:
1861
1813
  return
1862
1814
 
1863
1815
  path = str(PurePosixPath(item))
1864
- entry = self.manifest.get_entry_by_path(path)
1865
- if entry:
1866
- self.manifest.remove_entry(entry)
1867
- return
1816
+ if entry := self.manifest.get_entry_by_path(path):
1817
+ return self.manifest.remove_entry(entry)
1868
1818
 
1869
1819
  entries = self.manifest.get_entries_in_directory(path)
1870
1820
  if not entries:
@@ -1922,8 +1872,7 @@ class Artifact:
1922
1872
 
1923
1873
  # If the entry is a reference from another artifact, then get it directly from
1924
1874
  # that artifact.
1925
- referenced_id = entry._referenced_artifact_id()
1926
- if referenced_id:
1875
+ if referenced_id := entry._referenced_artifact_id():
1927
1876
  assert self._client is not None
1928
1877
  artifact = self._from_id(referenced_id, client=self._client)
1929
1878
  assert artifact is not None
@@ -1942,10 +1891,9 @@ class Artifact:
1942
1891
  item_path = item.download()
1943
1892
 
1944
1893
  # Load the object from the JSON blob
1945
- result = None
1946
- json_obj = {}
1947
1894
  with open(item_path) as file:
1948
1895
  json_obj = json.load(file)
1896
+
1949
1897
  result = wb_class.from_json(json_obj, self)
1950
1898
  result._set_artifact_source(self, name)
1951
1899
  return result
@@ -1959,10 +1907,9 @@ class Artifact:
1959
1907
  Returns:
1960
1908
  The artifact relative name.
1961
1909
  """
1962
- entry = self._added_local_paths.get(local_path, None)
1963
- if entry is None:
1964
- return None
1965
- return entry.path
1910
+ if entry := self._added_local_paths.get(local_path):
1911
+ return entry.path
1912
+ return None
1966
1913
 
1967
1914
  def _get_obj_entry(
1968
1915
  self, name: str
@@ -1979,8 +1926,7 @@ class Artifact:
1979
1926
  """
1980
1927
  for wb_class in WBValue.type_mapping().values():
1981
1928
  wandb_file_name = wb_class.with_suffix(name)
1982
- entry = self.manifest.entries.get(wandb_file_name)
1983
- if entry is not None:
1929
+ if entry := self.manifest.entries.get(wandb_file_name):
1984
1930
  return entry, wb_class
1985
1931
  return None, None
1986
1932
 
@@ -2110,14 +2056,14 @@ class Artifact:
2110
2056
  multipart: bool | None = None,
2111
2057
  ) -> FilePathStr:
2112
2058
  nfiles = len(self.manifest.entries)
2113
- size = sum(e.size or 0 for e in self.manifest.entries.values())
2114
- log = False
2115
- if nfiles > 5000 or size > 50 * 1024 * 1024:
2116
- log = True
2059
+ size_mb = self.size / _MB
2060
+
2061
+ if log := (nfiles > 5000 or size_mb > 50):
2117
2062
  termlog(
2118
- f"Downloading large artifact {self.name}, {size / (1024 * 1024):.2f}MB. {nfiles} files... ",
2063
+ f"Downloading large artifact {self.name!r}, {size_mb:.2f}MB. {nfiles!r} files...",
2119
2064
  )
2120
- start_time = datetime.now()
2065
+ start_time = time.monotonic()
2066
+
2121
2067
  download_logger = ArtifactDownloadLogger(nfiles=nfiles)
2122
2068
 
2123
2069
  def _download_entry(
@@ -2156,44 +2102,62 @@ class Artifact:
2156
2102
  cookies=_thread_local_api_settings.cookies,
2157
2103
  headers=_thread_local_api_settings.headers,
2158
2104
  )
2105
+
2106
+ batch_size = env.get_artifact_fetch_file_url_batch_size()
2107
+
2159
2108
  active_futures = set()
2160
- has_next_page = True
2161
- cursor = None
2162
- while has_next_page:
2163
- fetch_url_batch_size = env.get_artifact_fetch_file_url_batch_size()
2164
- attrs = self._fetch_file_urls(cursor, fetch_url_batch_size)
2165
- has_next_page = attrs["pageInfo"]["hasNextPage"]
2166
- cursor = attrs["pageInfo"]["endCursor"]
2167
- for edge in attrs["edges"]:
2168
- entry = self.get_entry(edge["node"]["name"])
2109
+ cursor, has_more = None, True
2110
+ while has_more:
2111
+ files_page = self._fetch_file_urls(cursor=cursor, per_page=batch_size)
2112
+
2113
+ has_more = files_page.page_info.has_next_page
2114
+ cursor = files_page.page_info.end_cursor
2115
+
2116
+ # `File` nodes are formally nullable, so filter them out just in case.
2117
+ file_nodes = (e.node for e in files_page.edges if e.node)
2118
+ for node in file_nodes:
2119
+ entry = self.get_entry(node.name)
2169
2120
  # TODO: uncomment once artifact downloads are supported in core
2170
2121
  # if require_core and entry.ref is None:
2171
2122
  # # Handled by core
2172
2123
  # continue
2173
- entry._download_url = edge["node"]["directUrl"]
2124
+ entry._download_url = node.direct_url
2174
2125
  if (not path_prefix) or entry.path.startswith(str(path_prefix)):
2175
2126
  active_futures.add(executor.submit(download_entry, entry))
2127
+
2176
2128
  # Wait for download threads to catch up.
2177
- max_backlog = fetch_url_batch_size
2178
- if len(active_futures) > max_backlog:
2129
+ #
2130
+ # Extra context and observations (tonyyli):
2131
+ # - Even though the ThreadPoolExecutor limits the number of
2132
+ # concurrently-executed tasks, its internal task queue is unbounded.
2133
+ # The code below seems intended to ensure that at most `batch_size`
2134
+ # "backlogged" futures are held in memory at any given time. This seems like
2135
+ # a reasonable safeguard against unbounded memory consumption.
2136
+ #
2137
+ # - We should probably use a builtin (bounded) Queue or Semaphore here instead.
2138
+ # Consider this for a future change, or (depending on risk and risk tolerance)
2139
+ # managing this logic via asyncio instead, if viable.
2140
+ if len(active_futures) > batch_size:
2179
2141
  for future in concurrent.futures.as_completed(active_futures):
2180
2142
  future.result() # check for errors
2181
2143
  active_futures.remove(future)
2182
- if len(active_futures) <= max_backlog:
2144
+ if len(active_futures) <= batch_size:
2183
2145
  break
2146
+
2184
2147
  # Check for errors.
2185
2148
  for future in concurrent.futures.as_completed(active_futures):
2186
2149
  future.result()
2187
2150
 
2188
2151
  if log:
2189
- now = datetime.now()
2190
- delta = abs((now - start_time).total_seconds())
2191
- hours = int(delta // 3600)
2192
- minutes = int((delta - hours * 3600) // 60)
2193
- seconds = delta - hours * 3600 - minutes * 60
2194
- speed = size / 1024 / 1024 / delta
2152
+ # If you're wondering if we can display a `timedelta`, note that it
2153
+ # doesn't really support custom string format specifiers (compared to
2154
+ # e.g. `datetime` objs). To truncate the number of decimal places for
2155
+ # the seconds part, we manually convert/format each part below.
2156
+ dt_secs = abs(time.monotonic() - start_time)
2157
+ hrs, mins = divmod(dt_secs, 3600)
2158
+ mins, secs = divmod(mins, 60)
2195
2159
  termlog(
2196
- f"Done. {hours}:{minutes}:{seconds:.1f} ({speed:.1f}MB/s)",
2160
+ f"Done. {int(hrs):02d}:{int(mins):02d}:{secs:04.1f} ({size_mb / dt_secs:.1f}MB/s)",
2197
2161
  prefix=False,
2198
2162
  )
2199
2163
  return FilePathStr(root)
@@ -2202,79 +2166,44 @@ class Artifact:
2202
2166
  retry_timedelta=timedelta(minutes=3),
2203
2167
  retryable_exceptions=(requests.RequestException),
2204
2168
  )
2205
- def _fetch_file_urls(self, cursor: str | None, per_page: int | None = 5000) -> Any:
2169
+ def _fetch_file_urls(
2170
+ self, cursor: str | None, per_page: int = 5000
2171
+ ) -> FileUrlsFragment:
2172
+ if self._client is None:
2173
+ raise RuntimeError("Client not initialized")
2174
+
2206
2175
  if InternalApi()._server_supports(
2207
2176
  pb.ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
2208
2177
  ):
2209
- query = gql(
2210
- """
2211
- query ArtifactCollectionMembershipFileURLs($entityName: String!, $projectName: String!, \
2212
- $artifactName: String!, $artifactVersionIndex: String!, $cursor: String, $perPage: Int) {
2213
- project(name: $projectName, entityName: $entityName) {
2214
- artifactCollection(name: $artifactName) {
2215
- artifactMembership(aliasName: $artifactVersionIndex) {
2216
- files(after: $cursor, first: $perPage) {
2217
- pageInfo {
2218
- hasNextPage
2219
- endCursor
2220
- }
2221
- edges {
2222
- node {
2223
- name
2224
- directUrl
2225
- }
2226
- }
2227
- }
2228
- }
2229
- }
2230
- }
2231
- }
2232
- """
2233
- )
2234
- assert self._client is not None
2235
- response = self._client.execute(
2236
- query,
2237
- variable_values={
2238
- "entityName": self.entity,
2239
- "projectName": self.project,
2240
- "artifactName": self.name.split(":")[0],
2241
- "artifactVersionIndex": self.version,
2242
- "cursor": cursor,
2243
- "perPage": per_page,
2244
- },
2245
- timeout=60,
2246
- )
2247
- return response["project"]["artifactCollection"]["artifactMembership"][
2248
- "files"
2249
- ]
2178
+ query = gql(ARTIFACT_COLLECTION_MEMBERSHIP_FILE_URLS_GQL)
2179
+ gql_vars = {
2180
+ "entityName": self.entity,
2181
+ "projectName": self.project,
2182
+ "artifactName": self.name.split(":")[0],
2183
+ "artifactVersionIndex": self.version,
2184
+ "cursor": cursor,
2185
+ "perPage": per_page,
2186
+ }
2187
+ data = self._client.execute(query, variable_values=gql_vars, timeout=60)
2188
+ result = ArtifactCollectionMembershipFileUrls.model_validate(data)
2189
+
2190
+ if not (
2191
+ (project := result.project)
2192
+ and (collection := project.artifact_collection)
2193
+ and (membership := collection.artifact_membership)
2194
+ and (files := membership.files)
2195
+ ):
2196
+ raise ValueError(f"Unable to fetch files for artifact: {self.name!r}")
2197
+ return files
2250
2198
  else:
2251
- query = gql(
2252
- """
2253
- query ArtifactFileURLs($id: ID!, $cursor: String, $perPage: Int) {
2254
- artifact(id: $id) {
2255
- files(after: $cursor, first: $perPage) {
2256
- pageInfo {
2257
- hasNextPage
2258
- endCursor
2259
- }
2260
- edges {
2261
- node {
2262
- name
2263
- directUrl
2264
- }
2265
- }
2266
- }
2267
- }
2268
- }
2269
- """
2270
- )
2271
- assert self._client is not None
2272
- response = self._client.execute(
2273
- query,
2274
- variable_values={"id": self.id, "cursor": cursor, "perPage": per_page},
2275
- timeout=60,
2276
- )
2277
- return response["artifact"]["files"]
2199
+ query = gql(ARTIFACT_FILE_URLS_GQL)
2200
+ gql_vars = {"id": self.id, "cursor": cursor, "perPage": per_page}
2201
+ data = self._client.execute(query, variable_values=gql_vars, timeout=60)
2202
+ result = ArtifactFileUrls.model_validate(data)
2203
+
2204
+ if not ((artifact := result.artifact) and (files := artifact.files)):
2205
+ raise ValueError(f"Unable to fetch files for artifact: {self.name!r}")
2206
+ return files
2278
2207
 
2279
2208
  @ensure_logged
2280
2209
  def checkout(self, root: str | None = None) -> str:
@@ -2440,33 +2369,16 @@ class Artifact:
2440
2369
 
2441
2370
  @normalize_exceptions
2442
2371
  def _delete(self, delete_aliases: bool = False) -> None:
2443
- mutation = gql(
2444
- """
2445
- mutation DeleteArtifact($artifactID: ID!, $deleteAliases: Boolean) {
2446
- deleteArtifact(input: {
2447
- artifactID: $artifactID
2448
- deleteAliases: $deleteAliases
2449
- }) {
2450
- artifact {
2451
- id
2452
- }
2453
- }
2454
- }
2455
- """
2456
- )
2457
- assert self._client is not None
2458
- self._client.execute(
2459
- mutation,
2460
- variable_values={
2461
- "artifactID": self.id,
2462
- "deleteAliases": delete_aliases,
2463
- },
2464
- )
2372
+ if self._client is None:
2373
+ raise RuntimeError("Client not initialized for artifact mutations")
2374
+
2375
+ mutation = gql(DELETE_ARTIFACT_GQL)
2376
+ gql_vars = {"artifactID": self.id, "deleteAliases": delete_aliases}
2377
+
2378
+ self._client.execute(mutation, variable_values=gql_vars)
2465
2379
 
2466
2380
  @normalize_exceptions
2467
- def link(
2468
- self, target_path: str, aliases: list[str] | None = None
2469
- ) -> Artifact | None:
2381
+ def link(self, target_path: str, aliases: list[str] | None = None) -> Artifact:
2470
2382
  """Link this artifact to a portfolio (a promoted collection of artifacts).
2471
2383
 
2472
2384
  Args:
@@ -2485,7 +2397,7 @@ class Artifact:
2485
2397
  ArtifactNotLoggedError: If the artifact is not logged.
2486
2398
 
2487
2399
  Returns:
2488
- The linked artifact if linking was successful, otherwise None.
2400
+ The linked artifact.
2489
2401
  """
2490
2402
  from wandb import Api
2491
2403
 
@@ -2505,10 +2417,11 @@ class Artifact:
2505
2417
  # Wait until the artifact is committed before trying to link it.
2506
2418
  self.wait()
2507
2419
 
2508
- api = Api(overrides={"entity": self.source_entity})
2420
+ api = InternalApi()
2421
+ settings = api.settings()
2509
2422
 
2510
2423
  target = ArtifactPath.from_str(target_path).with_defaults(
2511
- project=api.settings.get("project") or "uncategorized",
2424
+ project=settings.get("project") or "uncategorized",
2512
2425
  )
2513
2426
 
2514
2427
  # Parse the entity (first part of the path) appropriately,
@@ -2516,11 +2429,8 @@ class Artifact:
2516
2429
  if target.project and is_artifact_registry_project(target.project):
2517
2430
  # In a Registry linking, the entity is used to fetch the organization of the artifact
2518
2431
  # therefore the source artifact's entity is passed to the backend
2519
- organization = target.prefix or api.settings.get("organization") or ""
2520
-
2521
- target.prefix = InternalApi()._resolve_org_entity_name(
2522
- self.source_entity, organization
2523
- )
2432
+ org = target.prefix or settings.get("organization") or ""
2433
+ target.prefix = api._resolve_org_entity_name(self.source_entity, org)
2524
2434
  else:
2525
2435
  target = target.with_defaults(prefix=self.source_entity)
2526
2436
 
@@ -2537,21 +2447,35 @@ class Artifact:
2537
2447
  aliases=alias_inputs,
2538
2448
  )
2539
2449
  gql_vars = {"input": gql_input.model_dump(exclude_none=True)}
2540
- gql_op = gql(LINK_ARTIFACT_GQL)
2541
- data = self._client.execute(gql_op, variable_values=gql_vars)
2542
2450
 
2451
+ # Newer server versions can return `artifactMembership` directly in the response,
2452
+ # avoiding the need to re-fetch the linked artifact at the end.
2453
+ if api._server_supports(
2454
+ pb.ServerFeature.ARTIFACT_MEMBERSHIP_IN_LINK_ARTIFACT_RESPONSE
2455
+ ):
2456
+ omit_fragments = set()
2457
+ else:
2458
+ # FIXME: Make `gql_compat` omit nested fragment definitions recursively (but safely)
2459
+ omit_fragments = {
2460
+ "MembershipWithArtifact",
2461
+ "ArtifactFragment",
2462
+ "ArtifactFragmentWithoutAliases",
2463
+ }
2464
+
2465
+ gql_op = gql_compat(LINK_ARTIFACT_GQL, omit_fragments=omit_fragments)
2466
+ data = self._client.execute(gql_op, variable_values=gql_vars)
2543
2467
  result = LinkArtifact.model_validate(data).link_artifact
2468
+
2469
+ # Newer server versions can return artifactMembership directly in the response
2470
+ if result and (membership := result.artifact_membership):
2471
+ return self._from_membership(membership, target=target, client=self._client)
2472
+
2473
+ # Fallback to old behavior, which requires re-fetching the linked artifact to return it
2544
2474
  if not (result and (version_idx := result.version_index) is not None):
2545
2475
  raise ValueError("Unable to parse linked artifact version from response")
2546
2476
 
2547
- # Fetch the linked artifact to return it
2548
- linked_path = f"{target.to_str()}:v{version_idx}"
2549
-
2550
- try:
2551
- return api._artifact(linked_path)
2552
- except Exception as e:
2553
- wandb.termerror(f"Error fetching link artifact after linking: {e}")
2554
- return None
2477
+ link_name = f"{target.to_str()}:v{version_idx}"
2478
+ return Api(overrides={"entity": self.source_entity})._artifact(link_name)
2555
2479
 
2556
2480
  @ensure_logged
2557
2481
  def unlink(self) -> None:
@@ -2573,28 +2497,14 @@ class Artifact:
2573
2497
 
2574
2498
  @normalize_exceptions
2575
2499
  def _unlink(self) -> None:
2576
- mutation = gql(
2577
- """
2578
- mutation UnlinkArtifact($artifactID: ID!, $artifactPortfolioID: ID!) {
2579
- unlinkArtifact(
2580
- input: { artifactID: $artifactID, artifactPortfolioID: $artifactPortfolioID }
2581
- ) {
2582
- artifactID
2583
- success
2584
- clientMutationId
2585
- }
2586
- }
2587
- """
2588
- )
2589
- assert self._client is not None
2500
+ if self._client is None:
2501
+ raise RuntimeError("Client not initialized for artifact mutations")
2502
+
2503
+ mutation = gql(UNLINK_ARTIFACT_GQL)
2504
+ gql_vars = {"artifactID": self.id, "artifactPortfolioID": self.collection.id}
2505
+
2590
2506
  try:
2591
- self._client.execute(
2592
- mutation,
2593
- variable_values={
2594
- "artifactID": self.id,
2595
- "artifactPortfolioID": self.collection.id,
2596
- },
2597
- )
2507
+ self._client.execute(mutation, variable_values=gql_vars)
2598
2508
  except CommError as e:
2599
2509
  raise CommError(
2600
2510
  f"You do not have permission to unlink the artifact {self.qualified_name}"
@@ -2610,41 +2520,26 @@ class Artifact:
2610
2520
  Raises:
2611
2521
  ArtifactNotLoggedError: If the artifact is not logged.
2612
2522
  """
2613
- query = gql(
2614
- """
2615
- query ArtifactUsedBy(
2616
- $id: ID!,
2617
- ) {
2618
- artifact(id: $id) {
2619
- usedBy {
2620
- edges {
2621
- node {
2622
- name
2623
- project {
2624
- name
2625
- entityName
2626
- }
2627
- }
2628
- }
2629
- }
2630
- }
2631
- }
2632
- """
2633
- )
2634
- assert self._client is not None
2635
- response = self._client.execute(
2636
- query,
2637
- variable_values={"id": self.id},
2638
- )
2639
- return [
2640
- Run(
2641
- self._client,
2642
- edge["node"]["project"]["entityName"],
2643
- edge["node"]["project"]["name"],
2644
- edge["node"]["name"],
2645
- )
2646
- for edge in response.get("artifact", {}).get("usedBy", {}).get("edges", [])
2647
- ]
2523
+ if self._client is None:
2524
+ raise RuntimeError("Client not initialized for artifact queries")
2525
+
2526
+ query = gql(ARTIFACT_USED_BY_GQL)
2527
+ gql_vars = {"id": self.id}
2528
+ data = self._client.execute(query, variable_values=gql_vars)
2529
+ result = ArtifactUsedBy.model_validate(data)
2530
+
2531
+ if (
2532
+ (artifact := result.artifact)
2533
+ and (used_by := artifact.used_by)
2534
+ and (edges := used_by.edges)
2535
+ ):
2536
+ run_nodes = (e.node for e in edges)
2537
+ return [
2538
+ Run(self._client, proj.entity_name, proj.name, run.name)
2539
+ for run in run_nodes
2540
+ if (proj := run.project)
2541
+ ]
2542
+ return []
2648
2543
 
2649
2544
  @ensure_logged
2650
2545
  def logged_by(self) -> Run | None:
@@ -2656,39 +2551,22 @@ class Artifact:
2656
2551
  Raises:
2657
2552
  ArtifactNotLoggedError: If the artifact is not logged.
2658
2553
  """
2659
- query = gql(
2660
- """
2661
- query ArtifactCreatedBy(
2662
- $id: ID!
2663
- ) {
2664
- artifact(id: $id) {
2665
- createdBy {
2666
- ... on Run {
2667
- name
2668
- project {
2669
- name
2670
- entityName
2671
- }
2672
- }
2673
- }
2674
- }
2675
- }
2676
- """
2677
- )
2678
- assert self._client is not None
2679
- response = self._client.execute(
2680
- query,
2681
- variable_values={"id": self.id},
2682
- )
2683
- creator = response.get("artifact", {}).get("createdBy", {})
2684
- if creator.get("name") is None:
2685
- return None
2686
- return Run(
2687
- self._client,
2688
- creator["project"]["entityName"],
2689
- creator["project"]["name"],
2690
- creator["name"],
2691
- )
2554
+ if self._client is None:
2555
+ raise RuntimeError("Client not initialized for artifact queries")
2556
+
2557
+ query = gql(ARTIFACT_CREATED_BY_GQL)
2558
+ gql_vars = {"id": self.id}
2559
+ data = self._client.execute(query, variable_values=gql_vars)
2560
+ result = ArtifactCreatedBy.model_validate(data)
2561
+
2562
+ if (
2563
+ (artifact := result.artifact)
2564
+ and (creator := artifact.created_by)
2565
+ and (name := creator.name)
2566
+ and (project := creator.project)
2567
+ ):
2568
+ return Run(self._client, project.entity_name, project.name, name)
2569
+ return None
2692
2570
 
2693
2571
  @ensure_logged
2694
2572
  def json_encode(self) -> dict[str, Any]:
@@ -2704,37 +2582,22 @@ class Artifact:
2704
2582
  entity_name: str, project_name: str, name: str, client: RetryingClient
2705
2583
  ) -> str | None:
2706
2584
  """Returns the expected type for a given artifact name and project."""
2707
- query = gql(
2708
- """
2709
- query ArtifactType(
2710
- $entityName: String,
2711
- $projectName: String,
2712
- $name: String!
2713
- ) {
2714
- project(name: $projectName, entityName: $entityName) {
2715
- artifact(name: $name) {
2716
- artifactType {
2717
- name
2718
- }
2719
- }
2720
- }
2721
- }
2722
- """
2723
- )
2724
- if ":" not in name:
2725
- name += ":latest"
2726
- response = client.execute(
2727
- query,
2728
- variable_values={
2729
- "entityName": entity_name,
2730
- "projectName": project_name,
2731
- "name": name,
2732
- },
2733
- )
2734
- return (
2735
- ((response.get("project") or {}).get("artifact") or {}).get("artifactType")
2736
- or {}
2737
- ).get("name")
2585
+ query = gql(ARTIFACT_TYPE_GQL)
2586
+ gql_vars = {
2587
+ "entityName": entity_name,
2588
+ "projectName": project_name,
2589
+ "name": name if (":" in name) else f"{name}:latest",
2590
+ }
2591
+ data = client.execute(query, variable_values=gql_vars)
2592
+ result = ArtifactType.model_validate(data)
2593
+
2594
+ if (
2595
+ (project := result.project)
2596
+ and (artifact := project.artifact)
2597
+ and (artifact_type := artifact.artifact_type)
2598
+ ):
2599
+ return artifact_type.name
2600
+ return None
2738
2601
 
2739
2602
  def _load_manifest(self, url: str) -> ArtifactManifest:
2740
2603
  with requests.get(url) as response: