wandb 0.21.0__py3-none-win_amd64.whl → 0.21.2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. wandb/__init__.py +16 -14
  2. wandb/__init__.pyi +427 -450
  3. wandb/agents/pyagent.py +41 -12
  4. wandb/analytics/sentry.py +7 -2
  5. wandb/apis/importers/mlflow.py +1 -1
  6. wandb/apis/public/__init__.py +1 -1
  7. wandb/apis/public/api.py +525 -360
  8. wandb/apis/public/artifacts.py +207 -13
  9. wandb/apis/public/automations.py +19 -3
  10. wandb/apis/public/files.py +172 -33
  11. wandb/apis/public/history.py +67 -15
  12. wandb/apis/public/integrations.py +25 -2
  13. wandb/apis/public/jobs.py +90 -2
  14. wandb/apis/public/projects.py +130 -79
  15. wandb/apis/public/query_generator.py +11 -1
  16. wandb/apis/public/registries/_utils.py +14 -16
  17. wandb/apis/public/registries/registries_search.py +183 -304
  18. wandb/apis/public/reports.py +96 -15
  19. wandb/apis/public/runs.py +299 -105
  20. wandb/apis/public/sweeps.py +222 -22
  21. wandb/apis/public/teams.py +41 -4
  22. wandb/apis/public/users.py +45 -4
  23. wandb/automations/_generated/delete_automation.py +1 -3
  24. wandb/automations/_generated/enums.py +13 -11
  25. wandb/beta/workflows.py +66 -30
  26. wandb/bin/gpu_stats.exe +0 -0
  27. wandb/bin/wandb-core +0 -0
  28. wandb/cli/cli.py +127 -3
  29. wandb/env.py +8 -0
  30. wandb/errors/errors.py +4 -1
  31. wandb/integration/lightning/fabric/logger.py +3 -4
  32. wandb/integration/metaflow/__init__.py +6 -0
  33. wandb/integration/metaflow/data_pandas.py +74 -0
  34. wandb/integration/metaflow/data_pytorch.py +75 -0
  35. wandb/integration/metaflow/data_sklearn.py +76 -0
  36. wandb/integration/metaflow/errors.py +13 -0
  37. wandb/integration/metaflow/metaflow.py +167 -223
  38. wandb/integration/openai/fine_tuning.py +1 -2
  39. wandb/integration/weave/__init__.py +6 -0
  40. wandb/integration/weave/interface.py +49 -0
  41. wandb/integration/weave/weave.py +63 -0
  42. wandb/jupyter.py +5 -5
  43. wandb/plot/custom_chart.py +30 -7
  44. wandb/proto/v3/wandb_internal_pb2.py +281 -280
  45. wandb/proto/v3/wandb_telemetry_pb2.py +4 -4
  46. wandb/proto/v4/wandb_internal_pb2.py +280 -280
  47. wandb/proto/v4/wandb_telemetry_pb2.py +4 -4
  48. wandb/proto/v5/wandb_internal_pb2.py +280 -280
  49. wandb/proto/v5/wandb_telemetry_pb2.py +4 -4
  50. wandb/proto/v6/wandb_internal_pb2.py +280 -280
  51. wandb/proto/v6/wandb_telemetry_pb2.py +4 -4
  52. wandb/proto/wandb_deprecated.py +6 -0
  53. wandb/sdk/artifacts/_factories.py +17 -0
  54. wandb/sdk/artifacts/_generated/__init__.py +221 -13
  55. wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
  56. wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
  57. wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
  58. wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
  59. wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
  60. wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
  61. wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
  62. wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
  63. wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
  64. wandb/sdk/artifacts/_generated/enums.py +5 -0
  65. wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
  66. wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
  67. wandb/sdk/artifacts/_generated/fragments.py +279 -41
  68. wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
  69. wandb/sdk/artifacts/_generated/operations.py +654 -51
  70. wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
  71. wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
  72. wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
  73. wandb/sdk/artifacts/_graphql_fragments.py +3 -86
  74. wandb/sdk/artifacts/_internal_artifact.py +19 -8
  75. wandb/sdk/artifacts/_validators.py +14 -4
  76. wandb/sdk/artifacts/artifact.py +512 -618
  77. wandb/sdk/artifacts/artifact_file_cache.py +10 -6
  78. wandb/sdk/artifacts/artifact_manifest.py +10 -9
  79. wandb/sdk/artifacts/artifact_manifest_entry.py +9 -10
  80. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
  81. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
  82. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  83. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
  84. wandb/sdk/data_types/audio.py +38 -10
  85. wandb/sdk/data_types/base_types/media.py +6 -56
  86. wandb/sdk/data_types/graph.py +48 -14
  87. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -3
  88. wandb/sdk/data_types/helper_types/image_mask.py +1 -3
  89. wandb/sdk/data_types/histogram.py +34 -21
  90. wandb/sdk/data_types/html.py +35 -12
  91. wandb/sdk/data_types/image.py +104 -68
  92. wandb/sdk/data_types/molecule.py +32 -19
  93. wandb/sdk/data_types/object_3d.py +36 -17
  94. wandb/sdk/data_types/plotly.py +18 -5
  95. wandb/sdk/data_types/saved_model.py +4 -6
  96. wandb/sdk/data_types/table.py +59 -30
  97. wandb/sdk/data_types/video.py +53 -26
  98. wandb/sdk/integration_utils/auto_logging.py +2 -2
  99. wandb/sdk/interface/interface_queue.py +1 -4
  100. wandb/sdk/interface/interface_shared.py +26 -37
  101. wandb/sdk/interface/interface_sock.py +24 -14
  102. wandb/sdk/internal/internal_api.py +6 -0
  103. wandb/sdk/internal/job_builder.py +6 -0
  104. wandb/sdk/internal/settings_static.py +2 -3
  105. wandb/sdk/launch/agent/agent.py +8 -1
  106. wandb/sdk/launch/agent/run_queue_item_file_saver.py +2 -2
  107. wandb/sdk/launch/create_job.py +15 -2
  108. wandb/sdk/launch/inputs/internal.py +3 -4
  109. wandb/sdk/launch/inputs/schema.py +1 -0
  110. wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
  111. wandb/sdk/launch/runner/kubernetes_runner.py +323 -1
  112. wandb/sdk/launch/sweeps/scheduler.py +2 -3
  113. wandb/sdk/lib/asyncio_compat.py +19 -16
  114. wandb/sdk/lib/asyncio_manager.py +252 -0
  115. wandb/sdk/lib/deprecate.py +1 -7
  116. wandb/sdk/lib/disabled.py +1 -1
  117. wandb/sdk/lib/hashutil.py +27 -5
  118. wandb/sdk/lib/module.py +7 -13
  119. wandb/sdk/lib/printer.py +2 -2
  120. wandb/sdk/lib/printer_asyncio.py +3 -1
  121. wandb/sdk/lib/progress.py +0 -19
  122. wandb/sdk/lib/retry.py +185 -78
  123. wandb/sdk/lib/service/service_client.py +106 -0
  124. wandb/sdk/lib/service/service_connection.py +20 -26
  125. wandb/sdk/lib/service/service_token.py +30 -13
  126. wandb/sdk/mailbox/mailbox.py +13 -5
  127. wandb/sdk/mailbox/mailbox_handle.py +22 -13
  128. wandb/sdk/mailbox/response_handle.py +42 -106
  129. wandb/sdk/mailbox/wait_with_progress.py +7 -42
  130. wandb/sdk/wandb_init.py +77 -116
  131. wandb/sdk/wandb_login.py +19 -15
  132. wandb/sdk/wandb_metric.py +2 -0
  133. wandb/sdk/wandb_run.py +497 -469
  134. wandb/sdk/wandb_settings.py +145 -4
  135. wandb/sdk/wandb_setup.py +204 -124
  136. wandb/sdk/wandb_sweep.py +14 -13
  137. wandb/sdk/wandb_watch.py +4 -6
  138. wandb/sync/sync.py +10 -0
  139. wandb/util.py +58 -1
  140. wandb/wandb_run.py +1 -2
  141. {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/METADATA +1 -1
  142. {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/RECORD +145 -129
  143. wandb/sdk/interface/interface_relay.py +0 -38
  144. wandb/sdk/interface/router.py +0 -89
  145. wandb/sdk/interface/router_queue.py +0 -43
  146. wandb/sdk/interface/router_relay.py +0 -50
  147. wandb/sdk/interface/router_sock.py +0 -32
  148. wandb/sdk/lib/sock_client.py +0 -236
  149. wandb/vendor/pynvml/__init__.py +0 -0
  150. wandb/vendor/pynvml/pynvml.py +0 -4779
  151. {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/WHEEL +0 -0
  152. {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/entry_points.txt +0 -0
  153. {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/licenses/LICENSE +0 -0
@@ -17,11 +17,22 @@ import time
17
17
  from collections import deque
18
18
  from copy import copy
19
19
  from dataclasses import dataclass
20
- from datetime import datetime, timedelta
20
+ from datetime import timedelta
21
21
  from functools import partial
22
22
  from itertools import filterfalse
23
- from pathlib import PurePosixPath
24
- from typing import IO, TYPE_CHECKING, Any, Iterator, Literal, Sequence, Type, final
23
+ from pathlib import Path, PurePosixPath
24
+ from typing import (
25
+ IO,
26
+ TYPE_CHECKING,
27
+ Any,
28
+ Final,
29
+ Iterator,
30
+ Literal,
31
+ Sequence,
32
+ Type,
33
+ cast,
34
+ final,
35
+ )
25
36
  from urllib.parse import quote, urljoin, urlparse
26
37
 
27
38
  import requests
@@ -58,21 +69,45 @@ from wandb.util import (
58
69
  vendor_setup,
59
70
  )
60
71
 
72
+ from ._factories import make_storage_policy
61
73
  from ._generated import (
62
74
  ADD_ALIASES_GQL,
75
+ ARTIFACT_BY_ID_GQL,
76
+ ARTIFACT_BY_NAME_GQL,
77
+ ARTIFACT_COLLECTION_MEMBERSHIP_FILE_URLS_GQL,
78
+ ARTIFACT_CREATED_BY_GQL,
79
+ ARTIFACT_FILE_URLS_GQL,
80
+ ARTIFACT_TYPE_GQL,
81
+ ARTIFACT_USED_BY_GQL,
82
+ ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
63
83
  DELETE_ALIASES_GQL,
84
+ DELETE_ARTIFACT_GQL,
85
+ FETCH_ARTIFACT_MANIFEST_GQL,
64
86
  FETCH_LINKED_ARTIFACTS_GQL,
65
87
  LINK_ARTIFACT_GQL,
88
+ UNLINK_ARTIFACT_GQL,
66
89
  UPDATE_ARTIFACT_GQL,
67
90
  ArtifactAliasInput,
91
+ ArtifactByID,
92
+ ArtifactByName,
68
93
  ArtifactCollectionAliasInput,
94
+ ArtifactCollectionMembershipFileUrls,
95
+ ArtifactCreatedBy,
96
+ ArtifactFileUrls,
97
+ ArtifactFragment,
98
+ ArtifactType,
99
+ ArtifactUsedBy,
100
+ ArtifactViaMembershipByName,
101
+ FetchArtifactManifest,
69
102
  FetchLinkedArtifacts,
103
+ FileUrlsFragment,
70
104
  LinkArtifact,
71
105
  LinkArtifactInput,
106
+ MembershipWithArtifact,
72
107
  TagInput,
73
108
  UpdateArtifact,
74
109
  )
75
- from ._graphql_fragments import _gql_artifact_fragment, omit_artifact_fields
110
+ from ._graphql_fragments import omit_artifact_fields
76
111
  from ._validators import (
77
112
  LINKED_ARTIFACT_COLLECTION_TYPE,
78
113
  ArtifactPath,
@@ -80,6 +115,7 @@ from ._validators import (
80
115
  ensure_logged,
81
116
  ensure_not_finalized,
82
117
  is_artifact_registry_project,
118
+ remove_registry_prefix,
83
119
  validate_aliases,
84
120
  validate_artifact_name,
85
121
  validate_artifact_type,
@@ -102,9 +138,6 @@ from .exceptions import (
102
138
  )
103
139
  from .staging import get_staging_dir
104
140
  from .storage_handlers.gcs_handler import _GCSIsADirectoryError
105
- from .storage_layout import StorageLayout
106
- from .storage_policies import WANDB_STORAGE_POLICY
107
- from .storage_policy import StoragePolicy
108
141
 
109
142
  reset_path = vendor_setup()
110
143
 
@@ -118,6 +151,9 @@ if TYPE_CHECKING:
118
151
  logger = logging.getLogger(__name__)
119
152
 
120
153
 
154
+ _MB: Final[int] = 1024 * 1024
155
+
156
+
121
157
  @final
122
158
  @dataclass
123
159
  class _DeferredArtifactManifest:
@@ -131,25 +167,26 @@ class Artifact:
131
167
 
132
168
  Construct an empty W&B Artifact. Populate an artifacts contents with methods that
133
169
  begin with `add`. Once the artifact has all the desired files, you can call
134
- `wandb.log_artifact()` to log it.
170
+ `run.log_artifact()` to log it.
135
171
 
136
172
  Args:
137
- name: A human-readable name for the artifact. Use the name to identify
173
+ name (str): A human-readable name for the artifact. Use the name to identify
138
174
  a specific artifact in the W&B App UI or programmatically. You can
139
175
  interactively reference an artifact with the `use_artifact` Public API.
140
176
  A name can contain letters, numbers, underscores, hyphens, and dots.
141
177
  The name must be unique across a project.
142
- type: The artifact's type. Use the type of an artifact to both organize
178
+ type (str): The artifact's type. Use the type of an artifact to both organize
143
179
  and differentiate artifacts. You can use any string that contains letters,
144
180
  numbers, underscores, hyphens, and dots. Common types include `dataset` or `model`.
145
- Note: Some types are reserved for internal use and cannot be set by users.
146
- Such types include `job` and types that start with `wandb-`.
147
- description: A description of the artifact. For Model or Dataset Artifacts,
181
+ Include `model` within your type string if you want to link the artifact
182
+ to the W&B Model Registry. Note that some types reserved for internal use
183
+ and cannot be set by users. Such types include `job` and types that start with `wandb-`.
184
+ description (str | None) = None: A description of the artifact. For Model or Dataset Artifacts,
148
185
  add documentation for your standardized team model or dataset card. View
149
186
  an artifact's description programmatically with the `Artifact.description`
150
187
  attribute or programmatically with the W&B App UI. W&B renders the
151
188
  description as markdown in the W&B App.
152
- metadata: Additional information about an artifact. Specify metadata as a
189
+ metadata (dict[str, Any] | None) = None: Additional information about an artifact. Specify metadata as a
153
190
  dictionary of key-value pairs. You can specify no more than 100 total keys.
154
191
  incremental: Use `Artifact.new_draft()` method instead to modify an
155
192
  existing artifact.
@@ -186,11 +223,6 @@ class Artifact:
186
223
  # Internal.
187
224
  self._client: RetryingClient | None = None
188
225
 
189
- storage_policy_cls = StoragePolicy.lookup_by_name(WANDB_STORAGE_POLICY)
190
- layout = StorageLayout.V1 if env.get_use_v1_artifacts() else StorageLayout.V2
191
- policy_config = {"storageLayout": layout}
192
- self._storage_policy = storage_policy_cls.from_config(config=policy_config)
193
-
194
226
  self._tmp_dir: tempfile.TemporaryDirectory | None = None
195
227
  self._added_objs: dict[int, tuple[WBValue, ArtifactManifestEntry]] = {}
196
228
  self._added_local_paths: dict[str, ArtifactManifestEntry] = {}
@@ -234,7 +266,7 @@ class Artifact:
234
266
  self._use_as: str | None = None
235
267
  self._state: ArtifactState = ArtifactState.PENDING
236
268
  self._manifest: ArtifactManifest | _DeferredArtifactManifest | None = (
237
- ArtifactManifestV1(self._storage_policy)
269
+ ArtifactManifestV1(storage_policy=make_storage_policy())
238
270
  )
239
271
  self._commit_hash: str | None = None
240
272
  self._file_count: int | None = None
@@ -255,32 +287,22 @@ class Artifact:
255
287
  if (artifact := artifact_instance_cache.get(artifact_id)) is not None:
256
288
  return artifact
257
289
 
258
- query = gql(
259
- """
260
- query ArtifactByID($id: ID!) {
261
- artifact(id: $id) {
262
- ...ArtifactFragment
263
- }
264
- }
265
- """
266
- + _gql_artifact_fragment()
267
- )
268
- response = client.execute(
269
- query,
270
- variable_values={"id": artifact_id},
271
- )
272
- attrs = response.get("artifact")
273
- if attrs is None:
290
+ query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
291
+
292
+ data = client.execute(query, variable_values={"id": artifact_id})
293
+ result = ArtifactByID.model_validate(data)
294
+
295
+ if (art := result.artifact) is None:
274
296
  return None
275
297
 
276
- src_collection = attrs["artifactSequence"]
277
- src_project = src_collection["project"]
298
+ src_collection = art.artifact_sequence
299
+ src_project = src_collection.project
278
300
 
279
- entity_name = src_project["entityName"] if src_project else ""
280
- project_name = src_project["name"] if src_project else ""
301
+ entity_name = src_project.entity_name if src_project else ""
302
+ project_name = src_project.name if src_project else ""
281
303
 
282
- name = "{}:v{}".format(src_collection["name"], attrs["versionIndex"])
283
- return cls._from_attrs(entity_name, project_name, name, attrs, client)
304
+ name = f"{src_collection.name}:v{art.version_index}"
305
+ return cls._from_attrs(entity_name, project_name, name, art, client)
284
306
 
285
307
  @classmethod
286
308
  def _membership_from_name(
@@ -291,7 +313,7 @@ class Artifact:
291
313
  name: str,
292
314
  client: RetryingClient,
293
315
  ) -> Artifact:
294
- if not InternalApi()._server_supports(
316
+ if not (api := InternalApi())._server_supports(
295
317
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
296
318
  ):
297
319
  raise UnsupportedError(
@@ -299,69 +321,26 @@ class Artifact:
299
321
  "by this version of wandb server. Consider updating to the latest version."
300
322
  )
301
323
 
302
- query = gql(
303
- f"""
304
- query ArtifactByName($entityName: String!, $projectName: String!, $name: String!) {{
305
- project(name: $projectName, entityName: $entityName) {{
306
- artifactCollectionMembership(name: $name) {{
307
- id
308
- artifactCollection {{
309
- id
310
- name
311
- project {{
312
- id
313
- entityName
314
- name
315
- }}
316
- }}
317
- artifact {{
318
- ...ArtifactFragment
319
- }}
320
- }}
321
- }}
322
- }}
323
- {_gql_artifact_fragment()}
324
- """
324
+ query = gql_compat(
325
+ ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
326
+ omit_fields=omit_artifact_fields(api=api),
325
327
  )
326
328
 
327
- query_variable_values: dict[str, Any] = {
328
- "entityName": entity,
329
- "projectName": project,
330
- "name": name,
331
- }
332
- response = client.execute(
333
- query,
334
- variable_values=query_variable_values,
335
- )
336
- if not (project_attrs := response.get("project")):
329
+ gql_vars = {"entityName": entity, "projectName": project, "name": name}
330
+ data = client.execute(query, variable_values=gql_vars)
331
+ result = ArtifactViaMembershipByName.model_validate(data)
332
+
333
+ if not (project_attrs := result.project):
337
334
  raise ValueError(f"project {project!r} not found under entity {entity!r}")
338
- if not (acm_attrs := project_attrs.get("artifactCollectionMembership")):
335
+
336
+ if not (acm_attrs := project_attrs.artifact_collection_membership):
339
337
  entity_project = f"{entity}/{project}"
340
338
  raise ValueError(
341
339
  f"artifact membership {name!r} not found in {entity_project!r}"
342
340
  )
343
- if not (ac_attrs := acm_attrs.get("artifactCollection")):
344
- raise ValueError("artifact collection not found")
345
- if not (
346
- (ac_name := ac_attrs.get("name"))
347
- and (ac_project_attrs := ac_attrs.get("project"))
348
- ):
349
- raise ValueError("artifact collection project not found")
350
- ac_project = ac_project_attrs.get("name")
351
- ac_entity = ac_project_attrs.get("entityName")
352
- if is_artifact_registry_project(ac_project) and project == "model-registry":
353
- wandb.termwarn(
354
- "This model registry has been migrated and will be discontinued. "
355
- f"Your request was redirected to the corresponding artifact `{ac_name}` in the new registry. "
356
- f"Please update your paths to point to the migrated registry directly, `{ac_project}/{ac_name}`."
357
- )
358
- entity = ac_entity
359
- project = ac_project
360
- if not (attrs := acm_attrs.get("artifact")):
361
- entity_project = f"{entity}/{project}"
362
- raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
363
341
 
364
- return cls._from_attrs(entity, project, name, attrs, client)
342
+ target_path = ArtifactPath(prefix=entity, project=project, name=name)
343
+ return cls._from_membership(acm_attrs, target=target_path, client=client)
365
344
 
366
345
  @classmethod
367
346
  def _from_name(
@@ -373,59 +352,71 @@ class Artifact:
373
352
  client: RetryingClient,
374
353
  enable_tracking: bool = False,
375
354
  ) -> Artifact:
376
- if InternalApi()._server_supports(
355
+ if (api := InternalApi())._server_supports(
377
356
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
378
357
  ):
379
358
  return cls._membership_from_name(
380
- entity=entity,
381
- project=project,
382
- name=name,
383
- client=client,
359
+ entity=entity, project=project, name=name, client=client
384
360
  )
385
361
 
386
- query_variable_values: dict[str, Any] = {
362
+ supports_enable_tracking_gql_var = api.server_project_type_introspection()
363
+ omit_vars = None if supports_enable_tracking_gql_var else {"enableTracking"}
364
+
365
+ gql_vars = {
387
366
  "entityName": entity,
388
367
  "projectName": project,
389
368
  "name": name,
369
+ "enableTracking": enable_tracking,
390
370
  }
391
- query_vars = ["$entityName: String!", "$projectName: String!", "$name: String!"]
392
- query_args = ["name: $name"]
393
-
394
- server_supports_enabling_artifact_usage_tracking = (
395
- InternalApi().server_project_type_introspection()
396
- )
397
- if server_supports_enabling_artifact_usage_tracking:
398
- query_vars.append("$enableTracking: Boolean")
399
- query_args.append("enableTracking: $enableTracking")
400
- query_variable_values["enableTracking"] = enable_tracking
401
-
402
- vars_str = ", ".join(query_vars)
403
- args_str = ", ".join(query_args)
404
-
405
- query = gql(
406
- f"""
407
- query ArtifactByName({vars_str}) {{
408
- project(name: $projectName, entityName: $entityName) {{
409
- artifact({args_str}) {{
410
- ...ArtifactFragment
411
- }}
412
- }}
413
- }}
414
- {_gql_artifact_fragment()}
415
- """
371
+ query = gql_compat(
372
+ ARTIFACT_BY_NAME_GQL,
373
+ omit_variables=omit_vars,
374
+ omit_fields=omit_artifact_fields(api=api),
416
375
  )
417
- response = client.execute(
418
- query,
419
- variable_values=query_variable_values,
420
- )
421
- project_attrs = response.get("project")
422
- if not project_attrs:
423
- raise ValueError(f"project '{project}' not found under entity '{entity}'")
424
- attrs = project_attrs.get("artifact")
425
- if not attrs:
426
- raise ValueError(f"artifact '{name}' not found in '{entity}/{project}'")
427
376
 
428
- return cls._from_attrs(entity, project, name, attrs, client)
377
+ data = client.execute(query, variable_values=gql_vars)
378
+ result = ArtifactByName.model_validate(data)
379
+
380
+ if not (proj_attrs := result.project):
381
+ raise ValueError(f"project {project!r} not found under entity {entity!r}")
382
+
383
+ if not (art_attrs := proj_attrs.artifact):
384
+ entity_project = f"{entity}/{project}"
385
+ raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
386
+
387
+ return cls._from_attrs(entity, project, name, art_attrs, client)
388
+
389
+ @classmethod
390
+ def _from_membership(
391
+ cls,
392
+ membership: MembershipWithArtifact,
393
+ target: ArtifactPath,
394
+ client: RetryingClient,
395
+ ) -> Artifact:
396
+ if not (
397
+ (collection := membership.artifact_collection)
398
+ and (name := collection.name)
399
+ and (proj := collection.project)
400
+ ):
401
+ raise ValueError("Missing artifact collection project in GraphQL response")
402
+
403
+ if is_artifact_registry_project(proj.name) and (
404
+ target.project == "model-registry"
405
+ ):
406
+ wandb.termwarn(
407
+ "This model registry has been migrated and will be discontinued. "
408
+ f"Your request was redirected to the corresponding artifact {name!r} in the new registry. "
409
+ f"Please update your paths to point to the migrated registry directly, '{proj.name}/{name}'."
410
+ )
411
+ new_entity, new_project = proj.entity_name, proj.name
412
+ else:
413
+ new_entity = cast(str, target.prefix)
414
+ new_project = cast(str, target.project)
415
+
416
+ if not (artifact := membership.artifact):
417
+ raise ValueError(f"Artifact {target.to_str()!r} not found in response")
418
+
419
+ return cls._from_attrs(new_entity, new_project, target.name, artifact, client)
429
420
 
430
421
  @classmethod
431
422
  def _from_attrs(
@@ -433,7 +424,7 @@ class Artifact:
433
424
  entity: str,
434
425
  project: str,
435
426
  name: str,
436
- attrs: dict[str, Any],
427
+ attrs: dict[str, Any] | ArtifactFragment,
437
428
  client: RetryingClient,
438
429
  aliases: list[str] | None = None,
439
430
  ) -> Artifact:
@@ -443,7 +434,9 @@ class Artifact:
443
434
  artifact._entity = entity
444
435
  artifact._project = project
445
436
  artifact._name = name
446
- artifact._assign_attrs(attrs, aliases)
437
+
438
+ validated_attrs = ArtifactFragment.model_validate(attrs)
439
+ artifact._assign_attrs(validated_attrs, aliases)
447
440
 
448
441
  artifact.finalize()
449
442
 
@@ -456,29 +449,24 @@ class Artifact:
456
449
  # doesn't make it clear if the artifact is a link or not and have to manually set it.
457
450
  def _assign_attrs(
458
451
  self,
459
- attrs: dict[str, Any],
452
+ art: ArtifactFragment,
460
453
  aliases: list[str] | None = None,
461
454
  is_link: bool | None = None,
462
455
  ) -> None:
463
456
  """Update this Artifact's attributes using the server response."""
464
- self._id = attrs["id"]
465
-
466
- src_version = f"v{attrs['versionIndex']}"
467
- src_collection = attrs["artifactSequence"]
468
- src_project = src_collection["project"]
457
+ self._id = art.id
469
458
 
470
- self._source_entity = src_project["entityName"] if src_project else ""
471
- self._source_project = src_project["name"] if src_project else ""
472
- self._source_name = f"{src_collection['name']}:{src_version}"
473
- self._source_version = src_version
459
+ src_collection = art.artifact_sequence
460
+ src_project = src_collection.project
474
461
 
475
- if self._entity is None:
476
- self._entity = self._source_entity
477
- if self._project is None:
478
- self._project = self._source_project
462
+ self._source_entity = src_project.entity_name if src_project else ""
463
+ self._source_project = src_project.name if src_project else ""
464
+ self._source_name = f"{src_collection.name}:v{art.version_index}"
465
+ self._source_version = f"v{art.version_index}"
479
466
 
480
- if self._name is None:
481
- self._name = self._source_name
467
+ self._entity = self._entity or self._source_entity
468
+ self._project = self._project or self._source_project
469
+ self._name = self._name or self._source_name
482
470
 
483
471
  # TODO: Refactor artifact query to fetch artifact via membership instead
484
472
  # and get the collection type
@@ -486,33 +474,35 @@ class Artifact:
486
474
  self._is_link = (
487
475
  self._entity != self._source_entity
488
476
  or self._project != self._source_project
489
- or self._name != self._source_name
477
+ or self._name.split(":")[0] != self._source_name.split(":")[0]
490
478
  )
491
479
  else:
492
480
  self._is_link = is_link
493
481
 
494
- self._type = attrs["artifactType"]["name"]
495
- self._description = attrs["description"]
496
-
497
- entity = self._entity
498
- project = self._project
499
- collection, *_ = self._name.split(":")
482
+ self._type = art.artifact_type.name
483
+ self._description = art.description
500
484
 
501
- processed_aliases = []
502
485
  # The future of aliases is to move all alias fetches to the membership level
503
486
  # so we don't have to do the collection fetches below
504
487
  if aliases:
505
488
  processed_aliases = aliases
506
- else:
489
+ elif art.aliases:
490
+ entity = self._entity
491
+ project = self._project
492
+ collection = self._name.split(":")[0]
507
493
  processed_aliases = [
508
- obj["alias"]
509
- for obj in attrs["aliases"]
510
- if obj["artifactCollection"]
511
- and obj["artifactCollection"]["project"]
512
- and obj["artifactCollection"]["project"]["entityName"] == entity
513
- and obj["artifactCollection"]["project"]["name"] == project
514
- and obj["artifactCollection"]["name"] == collection
494
+ art_alias.alias
495
+ for art_alias in art.aliases
496
+ if (
497
+ (coll := art_alias.artifact_collection)
498
+ and (proj := coll.project)
499
+ and proj.entity_name == entity
500
+ and proj.name == project
501
+ and coll.name == collection
502
+ )
515
503
  ]
504
+ else:
505
+ processed_aliases = []
516
506
 
517
507
  version_aliases = list(filter(alias_is_version_index, processed_aliases))
518
508
  other_aliases = list(filterfalse(alias_is_version_index, processed_aliases))
@@ -522,49 +512,42 @@ class Artifact:
522
512
  version_aliases, too_short=TooFewItemsError, too_long=TooManyItemsError
523
513
  )
524
514
  except TooFewItemsError:
525
- version = src_version # default to the source version
515
+ version = f"v{art.version_index}" # default to the source version
526
516
  except TooManyItemsError:
527
517
  msg = f"Expected at most one version alias, got {len(version_aliases)}: {version_aliases!r}"
528
518
  raise ValueError(msg) from None
529
519
 
530
520
  self._version = version
531
-
532
- if ":" not in self._name:
533
- self._name = f"{self._name}:{version}"
521
+ self._name = self._name if (":" in self._name) else f"{self._name}:{version}"
534
522
 
535
523
  self._aliases = other_aliases
536
- self._saved_aliases = copy(other_aliases)
524
+ self._saved_aliases = copy(self._aliases)
537
525
 
538
- tags = [obj["name"] for obj in (attrs.get("tags") or [])]
539
- self._tags = tags
540
- self._saved_tags = copy(tags)
526
+ self._tags = [tag.name for tag in (art.tags or [])]
527
+ self._saved_tags = copy(self._tags)
541
528
 
542
- metadata_str = attrs["metadata"]
543
- self._metadata = validate_metadata(
544
- json.loads(metadata_str) if metadata_str else {}
545
- )
529
+ self._metadata = validate_metadata(art.metadata)
546
530
 
547
531
  self._ttl_duration_seconds = validate_ttl_duration_seconds(
548
- attrs.get("ttlDurationSeconds")
532
+ art.ttl_duration_seconds
549
533
  )
550
534
  self._ttl_is_inherited = (
551
- True if (attrs.get("ttlIsInherited") is None) else attrs["ttlIsInherited"]
535
+ True if (art.ttl_is_inherited is None) else art.ttl_is_inherited
552
536
  )
553
537
 
554
- self._state = ArtifactState(attrs["state"])
538
+ self._state = ArtifactState(art.state)
555
539
 
556
- try:
557
- manifest_url = attrs["currentManifest"]["file"]["directUrl"]
558
- except (LookupError, TypeError):
559
- self._manifest = None
560
- else:
561
- self._manifest = _DeferredArtifactManifest(manifest_url)
540
+ self._manifest = (
541
+ _DeferredArtifactManifest(manifest.file.direct_url)
542
+ if (manifest := art.current_manifest)
543
+ else None
544
+ )
562
545
 
563
- self._commit_hash = attrs["commitHash"]
564
- self._file_count = attrs["fileCount"]
565
- self._created_at = attrs["createdAt"]
566
- self._updated_at = attrs["updatedAt"]
567
- self._history_step = attrs.get("historyStep", None)
546
+ self._commit_hash = art.commit_hash
547
+ self._file_count = art.file_count
548
+ self._created_at = art.created_at
549
+ self._updated_at = art.updated_at
550
+ self._history_step = art.history_step
568
551
 
569
552
  @ensure_logged
570
553
  def new_draft(self) -> Artifact:
@@ -839,7 +822,7 @@ class Artifact:
839
822
  )
840
823
  return urljoin(
841
824
  base_url,
842
- f"orgs/{org.display_name}/registry/{self._type}?selectionPath={selection_path}&view=membership&version={self._version}",
825
+ f"orgs/{org.display_name}/registry/{remove_registry_prefix(self.project)}?selectionPath={selection_path}&view=membership&version={self.version}",
843
826
  )
844
827
 
845
828
  def _construct_model_registry_url(self, base_url: str) -> str:
@@ -924,7 +907,8 @@ class Artifact:
924
907
  TTL and there is no custom policy set on an artifact.
925
908
 
926
909
  Raises:
927
- ArtifactNotLoggedError: Unable to fetch inherited TTL if the artifact has not been logged or saved
910
+ ArtifactNotLoggedError: Unable to fetch inherited TTL if the
911
+ artifact has not been logged or saved.
928
912
  """
929
913
  if self._ttl_is_inherited and (self.is_draft() or self._ttl_changed):
930
914
  raise ArtifactNotLoggedError(f"{type(self).__name__}.ttl", self)
@@ -976,7 +960,9 @@ class Artifact:
976
960
  @property
977
961
  @ensure_logged
978
962
  def aliases(self) -> list[str]:
979
- """List of one or more semantically-friendly references or identifying "nicknames" assigned to an artifact version.
963
+ """List of one or more semantically-friendly references or
964
+
965
+ identifying "nicknames" assigned to an artifact version.
980
966
 
981
967
  Aliases are mutable references that you can programmatically reference.
982
968
  Change an artifact's alias with the W&B App UI or programmatically.
@@ -1012,6 +998,10 @@ class Artifact:
1012
998
 
1013
999
  @property
1014
1000
  def distributed_id(self) -> str | None:
1001
+ """The distributed ID of the artifact.
1002
+
1003
+ <!-- lazydoc-ignore: internal -->
1004
+ """
1015
1005
  return self._distributed_id
1016
1006
 
1017
1007
  @distributed_id.setter
@@ -1020,6 +1010,10 @@ class Artifact:
1020
1010
 
1021
1011
  @property
1022
1012
  def incremental(self) -> bool:
1013
+ """Boolean flag indicating if the artifact is an incremental artifact.
1014
+
1015
+ <!-- lazydoc-ignore: internal -->
1016
+ """
1023
1017
  return self._incremental
1024
1018
 
1025
1019
  @property
@@ -1050,37 +1044,24 @@ class Artifact:
1050
1044
  return self._manifest
1051
1045
 
1052
1046
  if self._manifest is None:
1053
- query = gql(
1054
- """
1055
- query ArtifactManifest(
1056
- $entityName: String!,
1057
- $projectName: String!,
1058
- $name: String!
1059
- ) {
1060
- project(entityName: $entityName, name: $projectName) {
1061
- artifact(name: $name) {
1062
- currentManifest {
1063
- file {
1064
- directUrl
1065
- }
1066
- }
1067
- }
1068
- }
1069
- }
1070
- """
1071
- )
1072
- assert self._client is not None
1073
- response = self._client.execute(
1074
- query,
1075
- variable_values={
1076
- "entityName": self._entity,
1077
- "projectName": self._project,
1078
- "name": self._name,
1079
- },
1080
- )
1081
- attrs = response["project"]["artifact"]
1082
- manifest_url = attrs["currentManifest"]["file"]["directUrl"]
1083
- self._manifest = self._load_manifest(manifest_url)
1047
+ if self._client is None:
1048
+ raise RuntimeError("Client not initialized for artifact queries")
1049
+
1050
+ query = gql(FETCH_ARTIFACT_MANIFEST_GQL)
1051
+ gql_vars = {
1052
+ "entityName": self.entity,
1053
+ "projectName": self.project,
1054
+ "name": self.name,
1055
+ }
1056
+ data = self._client.execute(query, variable_values=gql_vars)
1057
+ result = FetchArtifactManifest.model_validate(data)
1058
+ if not (
1059
+ (project := result.project)
1060
+ and (artifact := project.artifact)
1061
+ and (manifest := artifact.current_manifest)
1062
+ ):
1063
+ raise ValueError("Failed to fetch artifact manifest")
1064
+ self._manifest = self._load_manifest(manifest.file.direct_url)
1084
1065
 
1085
1066
  return self._manifest
1086
1067
 
@@ -1099,11 +1080,7 @@ class Artifact:
1099
1080
 
1100
1081
  Includes any references tracked by this artifact.
1101
1082
  """
1102
- total_size: int = 0
1103
- for entry in self.manifest.entries.values():
1104
- if entry.size is not None:
1105
- total_size += entry.size
1106
- return total_size
1083
+ return sum(entry.size for entry in self.manifest.entries.values() if entry.size)
1107
1084
 
1108
1085
  @property
1109
1086
  @ensure_logged
@@ -1139,15 +1116,15 @@ class Artifact:
1139
1116
  """The nearest step at which history metrics were logged for the source run of the artifact.
1140
1117
 
1141
1118
  Examples:
1142
- ```python
1143
- run = artifact.logged_by()
1144
- if run and (artifact.history_step is not None):
1145
- history = run.sample_history(
1146
- min_step=artifact.history_step,
1147
- max_step=artifact.history_step + 1,
1148
- keys=["my_metric"],
1149
- )
1150
- ```
1119
+ ```python
1120
+ run = artifact.logged_by()
1121
+ if run and (artifact.history_step is not None):
1122
+ history = run.sample_history(
1123
+ min_step=artifact.history_step,
1124
+ max_step=artifact.history_step + 1,
1125
+ keys=["my_metric"],
1126
+ )
1127
+ ```
1151
1128
  """
1152
1129
  if self._history_step is None:
1153
1130
  return None
@@ -1168,9 +1145,10 @@ class Artifact:
1168
1145
  def is_draft(self) -> bool:
1169
1146
  """Check if artifact is not saved.
1170
1147
 
1171
- Returns: Boolean. `False` if artifact is saved. `True` if artifact is not saved.
1148
+ Returns:
1149
+ Boolean. `False` if artifact is saved. `True` if artifact is not saved.
1172
1150
  """
1173
- return self._state == ArtifactState.PENDING
1151
+ return self._state is ArtifactState.PENDING
1174
1152
 
1175
1153
  def _is_draft_save_started(self) -> bool:
1176
1154
  return self._save_handle is not None
@@ -1191,7 +1169,7 @@ class Artifact:
1191
1169
  settings: A settings object to use when initializing an automatic run. Most
1192
1170
  commonly used in testing harness.
1193
1171
  """
1194
- if self._state != ArtifactState.PENDING:
1172
+ if self._state is not ArtifactState.PENDING:
1195
1173
  return self._update()
1196
1174
 
1197
1175
  if self._incremental:
@@ -1252,31 +1230,20 @@ class Artifact:
1252
1230
  return self
1253
1231
 
1254
1232
  def _populate_after_save(self, artifact_id: str) -> None:
1255
- query_template = """
1256
- query ArtifactByIDShort($id: ID!) {
1257
- artifact(id: $id) {
1258
- ...ArtifactFragment
1259
- }
1260
- }
1261
- """ + _gql_artifact_fragment()
1233
+ assert self._client is not None
1262
1234
 
1263
- query = gql(query_template)
1235
+ query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
1264
1236
 
1265
- assert self._client is not None
1266
- response = self._client.execute(
1267
- query,
1268
- variable_values={"id": artifact_id},
1269
- )
1237
+ data = self._client.execute(query, variable_values={"id": artifact_id})
1238
+ result = ArtifactByID.model_validate(data)
1270
1239
 
1271
- try:
1272
- attrs = response["artifact"]
1273
- except LookupError:
1240
+ if not (artifact := result.artifact):
1274
1241
  raise ValueError(f"Unable to fetch artifact with id: {artifact_id!r}")
1275
- else:
1276
- # _populate_after_save is only called on source artifacts, not linked artifacts
1277
- # We have to manually set is_link because we aren't fetching the collection the artifact.
1278
- # That requires greater refactoring for commitArtifact to return the artifact collection type.
1279
- self._assign_attrs(attrs, is_link=False)
1242
+
1243
+ # _populate_after_save is only called on source artifacts, not linked artifacts
1244
+ # We have to manually set is_link because we aren't fetching the collection the artifact.
1245
+ # That requires greater refactoring for commitArtifact to return the artifact collection type.
1246
+ self._assign_attrs(artifact, is_link=False)
1280
1247
 
1281
1248
  @normalize_exceptions
1282
1249
  def _update(self) -> None:
@@ -1361,7 +1328,7 @@ class Artifact:
1361
1328
  for alias in self.aliases
1362
1329
  ]
1363
1330
 
1364
- omit_fields = omit_artifact_fields(api=InternalApi())
1331
+ omit_fields = omit_artifact_fields()
1365
1332
  omit_variables = set()
1366
1333
 
1367
1334
  if {"ttlIsInherited", "ttlDurationSeconds"} & omit_fields:
@@ -1385,7 +1352,9 @@ class Artifact:
1385
1352
  omit_variables |= {"tagsToAdd", "tagsToDelete"}
1386
1353
 
1387
1354
  mutation = gql_compat(
1388
- UPDATE_ARTIFACT_GQL, omit_variables=omit_variables, omit_fields=omit_fields
1355
+ UPDATE_ARTIFACT_GQL,
1356
+ omit_variables=omit_variables,
1357
+ omit_fields=omit_fields,
1389
1358
  )
1390
1359
 
1391
1360
  gql_vars = {
@@ -1403,7 +1372,7 @@ class Artifact:
1403
1372
  result = UpdateArtifact.model_validate(data).update_artifact
1404
1373
  if not (result and (artifact := result.artifact)):
1405
1374
  raise ValueError("Unable to parse updateArtifact response")
1406
- self._assign_attrs(artifact.model_dump())
1375
+ self._assign_attrs(artifact)
1407
1376
 
1408
1377
  self._ttl_changed = False # Reset after updating artifact
1409
1378
 
@@ -1416,7 +1385,7 @@ class Artifact:
1416
1385
  name: The artifact relative name to get.
1417
1386
 
1418
1387
  Returns:
1419
- W&B object that can be logged with `wandb.log()` and visualized in the W&B UI.
1388
+ W&B object that can be logged with `run.log()` and visualized in the W&B UI.
1420
1389
 
1421
1390
  Raises:
1422
1391
  ArtifactNotLoggedError: If the artifact isn't logged or the run is offline.
@@ -1434,8 +1403,9 @@ class Artifact:
1434
1403
  The added manifest entry
1435
1404
 
1436
1405
  Raises:
1437
- ArtifactFinalizedError: You cannot make changes to the current artifact
1438
- version because it is finalized. Log a new artifact version instead.
1406
+ ArtifactFinalizedError: You cannot make changes to the current
1407
+ artifact version because it is finalized. Log a new artifact
1408
+ version instead.
1439
1409
  """
1440
1410
  return self.add(item, name)
1441
1411
 
@@ -1452,12 +1422,13 @@ class Artifact:
1452
1422
  encoding: The encoding used to open the new file.
1453
1423
 
1454
1424
  Returns:
1455
- A new file object that can be written to. Upon closing, the file will be
1456
- automatically added to the artifact.
1425
+ A new file object that can be written to. Upon closing, the file
1426
+ is automatically added to the artifact.
1457
1427
 
1458
1428
  Raises:
1459
- ArtifactFinalizedError: You cannot make changes to the current artifact
1460
- version because it is finalized. Log a new artifact version instead.
1429
+ ArtifactFinalizedError: You cannot make changes to the current
1430
+ artifact version because it is finalized. Log a new artifact
1431
+ version instead.
1461
1432
  """
1462
1433
  overwrite: bool = "x" not in mode
1463
1434
 
@@ -1465,7 +1436,7 @@ class Artifact:
1465
1436
  self._tmp_dir = tempfile.TemporaryDirectory()
1466
1437
  path = os.path.join(self._tmp_dir.name, name.lstrip("/"))
1467
1438
 
1468
- filesystem.mkdir_exists_ok(os.path.dirname(path))
1439
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
1469
1440
  try:
1470
1441
  with fsync_open(path, mode, encoding) as f:
1471
1442
  yield f
@@ -1496,22 +1467,26 @@ class Artifact:
1496
1467
 
1497
1468
  Args:
1498
1469
  local_path: The path to the file being added.
1499
- name: The path within the artifact to use for the file being added. Defaults
1500
- to the basename of the file.
1470
+ name: The path within the artifact to use for the file being added.
1471
+ Defaults to the basename of the file.
1501
1472
  is_tmp: If true, then the file is renamed deterministically to avoid
1502
1473
  collisions.
1503
- skip_cache: If `True`, W&B will not copy files to the cache after uploading.
1504
- policy: By default, set to "mutable". If set to "mutable", create a temporary copy of the
1505
- file to prevent corruption during upload. If set to "immutable", disable
1506
- protection and rely on the user not to delete or change the file.
1474
+ skip_cache: If `True`, do not copy files to the cache
1475
+ after uploading.
1476
+ policy: By default, set to "mutable". If set to "mutable",
1477
+ create a temporary copy of the file to prevent corruption
1478
+ during upload. If set to "immutable", disable
1479
+ protection and rely on the user not to delete or change the
1480
+ file.
1507
1481
  overwrite: If `True`, overwrite the file if it already exists.
1508
1482
 
1509
1483
  Returns:
1510
1484
  The added manifest entry.
1511
1485
 
1512
1486
  Raises:
1513
- ArtifactFinalizedError: You cannot make changes to the current artifact
1514
- version because it is finalized. Log a new artifact version instead.
1487
+ ArtifactFinalizedError: You cannot make changes to the current
1488
+ artifact version because it is finalized. Log a new artifact
1489
+ version instead.
1515
1490
  ValueError: Policy must be "mutable" or "immutable"
1516
1491
  """
1517
1492
  if not os.path.isfile(local_path):
@@ -1548,48 +1523,47 @@ class Artifact:
1548
1523
 
1549
1524
  Args:
1550
1525
  local_path: The path of the local directory.
1551
- name: The subdirectory name within an artifact. The name you specify appears
1552
- in the W&B App UI nested by artifact's `type`.
1526
+ name: The subdirectory name within an artifact. The name you
1527
+ specify appears in the W&B App UI nested by artifact's `type`.
1553
1528
  Defaults to the root of the artifact.
1554
- skip_cache: If set to `True`, W&B will not copy/move files to the cache while uploading
1555
- policy: "mutable" | "immutable". By default, "mutable"
1556
- "mutable": Create a temporary copy of the file to prevent corruption during upload.
1557
- "immutable": Disable protection, rely on the user not to delete or change the file.
1529
+ skip_cache: If set to `True`, W&B will not copy/move files to
1530
+ the cache while uploading
1531
+ policy: By default, "mutable".
1532
+ - mutable: Create a temporary copy of the file to prevent corruption during upload.
1533
+ - immutable: Disable protection, rely on the user not to delete or change the file.
1558
1534
  merge: If `False` (default), throws ValueError if a file was already added in a previous add_dir call
1559
1535
  and its content has changed. If `True`, overwrites existing files with changed content.
1560
1536
  Always adds new files and never removes files. To replace an entire directory, pass a name when adding the directory
1561
1537
  using `add_dir(local_path, name=my_prefix)` and call `remove(my_prefix)` to remove the directory, then add it again.
1562
1538
 
1563
1539
  Raises:
1564
- ArtifactFinalizedError: You cannot make changes to the current artifact
1565
- version because it is finalized. Log a new artifact version instead.
1540
+ ArtifactFinalizedError: You cannot make changes to the current
1541
+ artifact version because it is finalized. Log a new artifact
1542
+ version instead.
1566
1543
  ValueError: Policy must be "mutable" or "immutable"
1567
1544
  """
1568
1545
  if not os.path.isdir(local_path):
1569
- raise ValueError(f"Path is not a directory: {local_path}")
1546
+ raise ValueError(f"Path is not a directory: {local_path!r}")
1570
1547
 
1571
1548
  termlog(
1572
- "Adding directory to artifact ({})... ".format(
1573
- os.path.join(".", os.path.normpath(local_path))
1574
- ),
1549
+ f"Adding directory to artifact ({Path('.', local_path)})... ",
1575
1550
  newline=False,
1576
1551
  )
1577
- start_time = time.time()
1552
+ start_time = time.monotonic()
1578
1553
 
1579
- paths = []
1554
+ paths: deque[tuple[str, str]] = deque()
1555
+ logical_root = name or "" # shared prefix, if any, for logical paths
1580
1556
  for dirpath, _, filenames in os.walk(local_path, followlinks=True):
1581
1557
  for fname in filenames:
1582
1558
  physical_path = os.path.join(dirpath, fname)
1583
1559
  logical_path = os.path.relpath(physical_path, start=local_path)
1584
- if name is not None:
1585
- logical_path = os.path.join(name, logical_path)
1560
+ logical_path = os.path.join(logical_root, logical_path)
1586
1561
  paths.append((logical_path, physical_path))
1587
1562
 
1588
- def add_manifest_file(log_phy_path: tuple[str, str]) -> None:
1589
- logical_path, physical_path = log_phy_path
1563
+ def add_manifest_file(logical_pth: str, physical_pth: str) -> None:
1590
1564
  self._add_local_file(
1591
- name=logical_path,
1592
- path=physical_path,
1565
+ name=logical_pth,
1566
+ path=physical_pth,
1593
1567
  skip_cache=skip_cache,
1594
1568
  policy=policy,
1595
1569
  overwrite=merge,
@@ -1597,11 +1571,11 @@ class Artifact:
1597
1571
 
1598
1572
  num_threads = 8
1599
1573
  pool = multiprocessing.dummy.Pool(num_threads)
1600
- pool.map(add_manifest_file, paths)
1574
+ pool.starmap(add_manifest_file, paths)
1601
1575
  pool.close()
1602
1576
  pool.join()
1603
1577
 
1604
- termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
1578
+ termlog("Done. %.1fs" % (time.monotonic() - start_time), prefix=False)
1605
1579
 
1606
1580
  @ensure_not_finalized
1607
1581
  def add_reference(
@@ -1621,13 +1595,14 @@ class Artifact:
1621
1595
 
1622
1596
  - http(s): The size and digest of the file will be inferred by the
1623
1597
  `Content-Length` and the `ETag` response headers returned by the server.
1624
- - s3: The checksum and size are pulled from the object metadata. If bucket
1625
- versioning is enabled, then the version ID is also tracked.
1598
+ - s3: The checksum and size are pulled from the object metadata.
1599
+ If bucket versioning is enabled, then the version ID is also tracked.
1626
1600
  - gs: The checksum and size are pulled from the object metadata. If bucket
1627
1601
  versioning is enabled, then the version ID is also tracked.
1628
- - https, domain matching `*.blob.core.windows.net` (Azure): The checksum and size
1629
- are be pulled from the blob metadata. If storage account versioning is
1630
- enabled, then the version ID is also tracked.
1602
+ - https, domain matching `*.blob.core.windows.net`
1603
+ - Azure: The checksum and size are be pulled from the blob metadata.
1604
+ If storage account versioning is enabled, then the version ID is
1605
+ also tracked.
1631
1606
  - file: The checksum and size are pulled from the file system. This scheme
1632
1607
  is useful if you have an NFS share or other externally mounted volume
1633
1608
  containing files you wish to track but not necessarily upload.
@@ -1648,16 +1623,18 @@ class Artifact:
1648
1623
  setting `checksum=False` when adding reference objects, in which case
1649
1624
  a new version will only be created if the reference URI changes.
1650
1625
  max_objects: The maximum number of objects to consider when adding a
1651
- reference that points to directory or bucket store prefix. By default,
1652
- the maximum number of objects allowed for Amazon S3,
1653
- GCS, Azure, and local files is 10,000,000. Other URI schemas do not have a maximum.
1626
+ reference that points to directory or bucket store prefix.
1627
+ By default, the maximum number of objects allowed for Amazon S3,
1628
+ GCS, Azure, and local files is 10,000,000. Other URI schemas
1629
+ do not have a maximum.
1654
1630
 
1655
1631
  Returns:
1656
1632
  The added manifest entries.
1657
1633
 
1658
1634
  Raises:
1659
- ArtifactFinalizedError: You cannot make changes to the current artifact
1660
- version because it is finalized. Log a new artifact version instead.
1635
+ ArtifactFinalizedError: You cannot make changes to the current
1636
+ artifact version because it is finalized. Log a new artifact
1637
+ version instead.
1661
1638
  """
1662
1639
  if name is not None:
1663
1640
  name = LogicalPath(name)
@@ -1674,7 +1651,7 @@ class Artifact:
1674
1651
  "References must be URIs. To reference a local file, use file://"
1675
1652
  )
1676
1653
 
1677
- manifest_entries = self._storage_policy.store_reference(
1654
+ manifest_entries = self.manifest.storage_policy.store_reference(
1678
1655
  self,
1679
1656
  URIStr(uri_str),
1680
1657
  name=name,
@@ -1694,17 +1671,19 @@ class Artifact:
1694
1671
 
1695
1672
  Args:
1696
1673
  obj: The object to add. Currently support one of Bokeh, JoinedTable,
1697
- PartitionedTable, Table, Classes, ImageMask, BoundingBoxes2D, Audio,
1698
- Image, Video, Html, Object3D
1674
+ PartitionedTable, Table, Classes, ImageMask, BoundingBoxes2D,
1675
+ Audio, Image, Video, Html, Object3D
1699
1676
  name: The path within the artifact to add the object.
1700
- overwrite: If True, overwrite existing objects with the same file path (if applicable).
1677
+ overwrite: If True, overwrite existing objects with the same file
1678
+ path if applicable.
1701
1679
 
1702
1680
  Returns:
1703
1681
  The added manifest entry
1704
1682
 
1705
1683
  Raises:
1706
- ArtifactFinalizedError: You cannot make changes to the current artifact
1707
- version because it is finalized. Log a new artifact version instead.
1684
+ ArtifactFinalizedError: You cannot make changes to the current
1685
+ artifact version because it is finalized. Log a new artifact
1686
+ version instead.
1708
1687
  """
1709
1688
  name = LogicalPath(name)
1710
1689
 
@@ -1819,13 +1798,14 @@ class Artifact:
1819
1798
  """Remove an item from the artifact.
1820
1799
 
1821
1800
  Args:
1822
- item: The item to remove. Can be a specific manifest entry or the name of an
1823
- artifact-relative path. If the item matches a directory all items in
1824
- that directory will be removed.
1801
+ item: The item to remove. Can be a specific manifest entry
1802
+ or the name of an artifact-relative path. If the item
1803
+ matches a directory all items in that directory will be removed.
1825
1804
 
1826
1805
  Raises:
1827
- ArtifactFinalizedError: You cannot make changes to the current artifact
1828
- version because it is finalized. Log a new artifact version instead.
1806
+ ArtifactFinalizedError: You cannot make changes to the current
1807
+ artifact version because it is finalized. Log a new artifact
1808
+ version instead.
1829
1809
  FileNotFoundError: If the item isn't found in the artifact.
1830
1810
  """
1831
1811
  if isinstance(item, ArtifactManifestEntry):
@@ -1833,10 +1813,8 @@ class Artifact:
1833
1813
  return
1834
1814
 
1835
1815
  path = str(PurePosixPath(item))
1836
- entry = self.manifest.get_entry_by_path(path)
1837
- if entry:
1838
- self.manifest.remove_entry(entry)
1839
- return
1816
+ if entry := self.manifest.get_entry_by_path(path):
1817
+ return self.manifest.remove_entry(entry)
1840
1818
 
1841
1819
  entries = self.manifest.get_entries_in_directory(path)
1842
1820
  if not entries:
@@ -1881,10 +1859,12 @@ class Artifact:
1881
1859
  name: The artifact relative name to retrieve.
1882
1860
 
1883
1861
  Returns:
1884
- W&B object that can be logged with `wandb.log()` and visualized in the W&B UI.
1862
+ W&B object that can be logged with `run.log()` and
1863
+ visualized in the W&B UI.
1885
1864
 
1886
1865
  Raises:
1887
- ArtifactNotLoggedError: if the artifact isn't logged or the run is offline
1866
+ ArtifactNotLoggedError: if the artifact isn't logged or the
1867
+ run is offline.
1888
1868
  """
1889
1869
  entry, wb_class = self._get_obj_entry(name)
1890
1870
  if entry is None or wb_class is None:
@@ -1892,8 +1872,7 @@ class Artifact:
1892
1872
 
1893
1873
  # If the entry is a reference from another artifact, then get it directly from
1894
1874
  # that artifact.
1895
- referenced_id = entry._referenced_artifact_id()
1896
- if referenced_id:
1875
+ if referenced_id := entry._referenced_artifact_id():
1897
1876
  assert self._client is not None
1898
1877
  artifact = self._from_id(referenced_id, client=self._client)
1899
1878
  assert artifact is not None
@@ -1912,10 +1891,9 @@ class Artifact:
1912
1891
  item_path = item.download()
1913
1892
 
1914
1893
  # Load the object from the JSON blob
1915
- result = None
1916
- json_obj = {}
1917
1894
  with open(item_path) as file:
1918
1895
  json_obj = json.load(file)
1896
+
1919
1897
  result = wb_class.from_json(json_obj, self)
1920
1898
  result._set_artifact_source(self, name)
1921
1899
  return result
@@ -1929,10 +1907,9 @@ class Artifact:
1929
1907
  Returns:
1930
1908
  The artifact relative name.
1931
1909
  """
1932
- entry = self._added_local_paths.get(local_path, None)
1933
- if entry is None:
1934
- return None
1935
- return entry.path
1910
+ if entry := self._added_local_paths.get(local_path):
1911
+ return entry.path
1912
+ return None
1936
1913
 
1937
1914
  def _get_obj_entry(
1938
1915
  self, name: str
@@ -1949,8 +1926,7 @@ class Artifact:
1949
1926
  """
1950
1927
  for wb_class in WBValue.type_mapping().values():
1951
1928
  wandb_file_name = wb_class.with_suffix(name)
1952
- entry = self.manifest.entries.get(wandb_file_name)
1953
- if entry is not None:
1929
+ if entry := self.manifest.entries.get(wandb_file_name):
1954
1930
  return entry, wb_class
1955
1931
  return None, None
1956
1932
 
@@ -2080,14 +2056,14 @@ class Artifact:
2080
2056
  multipart: bool | None = None,
2081
2057
  ) -> FilePathStr:
2082
2058
  nfiles = len(self.manifest.entries)
2083
- size = sum(e.size or 0 for e in self.manifest.entries.values())
2084
- log = False
2085
- if nfiles > 5000 or size > 50 * 1024 * 1024:
2086
- log = True
2059
+ size_mb = self.size / _MB
2060
+
2061
+ if log := (nfiles > 5000 or size_mb > 50):
2087
2062
  termlog(
2088
- f"Downloading large artifact {self.name}, {size / (1024 * 1024):.2f}MB. {nfiles} files... ",
2063
+ f"Downloading large artifact {self.name!r}, {size_mb:.2f}MB. {nfiles!r} files...",
2089
2064
  )
2090
- start_time = datetime.now()
2065
+ start_time = time.monotonic()
2066
+
2091
2067
  download_logger = ArtifactDownloadLogger(nfiles=nfiles)
2092
2068
 
2093
2069
  def _download_entry(
@@ -2126,44 +2102,62 @@ class Artifact:
2126
2102
  cookies=_thread_local_api_settings.cookies,
2127
2103
  headers=_thread_local_api_settings.headers,
2128
2104
  )
2105
+
2106
+ batch_size = env.get_artifact_fetch_file_url_batch_size()
2107
+
2129
2108
  active_futures = set()
2130
- has_next_page = True
2131
- cursor = None
2132
- while has_next_page:
2133
- fetch_url_batch_size = env.get_artifact_fetch_file_url_batch_size()
2134
- attrs = self._fetch_file_urls(cursor, fetch_url_batch_size)
2135
- has_next_page = attrs["pageInfo"]["hasNextPage"]
2136
- cursor = attrs["pageInfo"]["endCursor"]
2137
- for edge in attrs["edges"]:
2138
- entry = self.get_entry(edge["node"]["name"])
2109
+ cursor, has_more = None, True
2110
+ while has_more:
2111
+ files_page = self._fetch_file_urls(cursor=cursor, per_page=batch_size)
2112
+
2113
+ has_more = files_page.page_info.has_next_page
2114
+ cursor = files_page.page_info.end_cursor
2115
+
2116
+ # `File` nodes are formally nullable, so filter them out just in case.
2117
+ file_nodes = (e.node for e in files_page.edges if e.node)
2118
+ for node in file_nodes:
2119
+ entry = self.get_entry(node.name)
2139
2120
  # TODO: uncomment once artifact downloads are supported in core
2140
2121
  # if require_core and entry.ref is None:
2141
2122
  # # Handled by core
2142
2123
  # continue
2143
- entry._download_url = edge["node"]["directUrl"]
2124
+ entry._download_url = node.direct_url
2144
2125
  if (not path_prefix) or entry.path.startswith(str(path_prefix)):
2145
2126
  active_futures.add(executor.submit(download_entry, entry))
2127
+
2146
2128
  # Wait for download threads to catch up.
2147
- max_backlog = fetch_url_batch_size
2148
- if len(active_futures) > max_backlog:
2129
+ #
2130
+ # Extra context and observations (tonyyli):
2131
+ # - Even though the ThreadPoolExecutor limits the number of
2132
+ # concurrently-executed tasks, its internal task queue is unbounded.
2133
+ # The code below seems intended to ensure that at most `batch_size`
2134
+ # "backlogged" futures are held in memory at any given time. This seems like
2135
+ # a reasonable safeguard against unbounded memory consumption.
2136
+ #
2137
+ # - We should probably use a builtin (bounded) Queue or Semaphore here instead.
2138
+ # Consider this for a future change, or (depending on risk and risk tolerance)
2139
+ # managing this logic via asyncio instead, if viable.
2140
+ if len(active_futures) > batch_size:
2149
2141
  for future in concurrent.futures.as_completed(active_futures):
2150
2142
  future.result() # check for errors
2151
2143
  active_futures.remove(future)
2152
- if len(active_futures) <= max_backlog:
2144
+ if len(active_futures) <= batch_size:
2153
2145
  break
2146
+
2154
2147
  # Check for errors.
2155
2148
  for future in concurrent.futures.as_completed(active_futures):
2156
2149
  future.result()
2157
2150
 
2158
2151
  if log:
2159
- now = datetime.now()
2160
- delta = abs((now - start_time).total_seconds())
2161
- hours = int(delta // 3600)
2162
- minutes = int((delta - hours * 3600) // 60)
2163
- seconds = delta - hours * 3600 - minutes * 60
2164
- speed = size / 1024 / 1024 / delta
2152
+ # If you're wondering if we can display a `timedelta`, note that it
2153
+ # doesn't really support custom string format specifiers (compared to
2154
+ # e.g. `datetime` objs). To truncate the number of decimal places for
2155
+ # the seconds part, we manually convert/format each part below.
2156
+ dt_secs = abs(time.monotonic() - start_time)
2157
+ hrs, mins = divmod(dt_secs, 3600)
2158
+ mins, secs = divmod(mins, 60)
2165
2159
  termlog(
2166
- f"Done. {hours}:{minutes}:{seconds:.1f} ({speed:.1f}MB/s)",
2160
+ f"Done. {int(hrs):02d}:{int(mins):02d}:{secs:04.1f} ({size_mb / dt_secs:.1f}MB/s)",
2167
2161
  prefix=False,
2168
2162
  )
2169
2163
  return FilePathStr(root)
@@ -2172,79 +2166,44 @@ class Artifact:
2172
2166
  retry_timedelta=timedelta(minutes=3),
2173
2167
  retryable_exceptions=(requests.RequestException),
2174
2168
  )
2175
- def _fetch_file_urls(self, cursor: str | None, per_page: int | None = 5000) -> Any:
2169
+ def _fetch_file_urls(
2170
+ self, cursor: str | None, per_page: int = 5000
2171
+ ) -> FileUrlsFragment:
2172
+ if self._client is None:
2173
+ raise RuntimeError("Client not initialized")
2174
+
2176
2175
  if InternalApi()._server_supports(
2177
2176
  pb.ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
2178
2177
  ):
2179
- query = gql(
2180
- """
2181
- query ArtifactCollectionMembershipFileURLs($entityName: String!, $projectName: String!, \
2182
- $artifactName: String!, $artifactVersionIndex: String!, $cursor: String, $perPage: Int) {
2183
- project(name: $projectName, entityName: $entityName) {
2184
- artifactCollection(name: $artifactName) {
2185
- artifactMembership(aliasName: $artifactVersionIndex) {
2186
- files(after: $cursor, first: $perPage) {
2187
- pageInfo {
2188
- hasNextPage
2189
- endCursor
2190
- }
2191
- edges {
2192
- node {
2193
- name
2194
- directUrl
2195
- }
2196
- }
2197
- }
2198
- }
2199
- }
2200
- }
2201
- }
2202
- """
2203
- )
2204
- assert self._client is not None
2205
- response = self._client.execute(
2206
- query,
2207
- variable_values={
2208
- "entityName": self.entity,
2209
- "projectName": self.project,
2210
- "artifactName": self.name.split(":")[0],
2211
- "artifactVersionIndex": self.version,
2212
- "cursor": cursor,
2213
- "perPage": per_page,
2214
- },
2215
- timeout=60,
2216
- )
2217
- return response["project"]["artifactCollection"]["artifactMembership"][
2218
- "files"
2219
- ]
2178
+ query = gql(ARTIFACT_COLLECTION_MEMBERSHIP_FILE_URLS_GQL)
2179
+ gql_vars = {
2180
+ "entityName": self.entity,
2181
+ "projectName": self.project,
2182
+ "artifactName": self.name.split(":")[0],
2183
+ "artifactVersionIndex": self.version,
2184
+ "cursor": cursor,
2185
+ "perPage": per_page,
2186
+ }
2187
+ data = self._client.execute(query, variable_values=gql_vars, timeout=60)
2188
+ result = ArtifactCollectionMembershipFileUrls.model_validate(data)
2189
+
2190
+ if not (
2191
+ (project := result.project)
2192
+ and (collection := project.artifact_collection)
2193
+ and (membership := collection.artifact_membership)
2194
+ and (files := membership.files)
2195
+ ):
2196
+ raise ValueError(f"Unable to fetch files for artifact: {self.name!r}")
2197
+ return files
2220
2198
  else:
2221
- query = gql(
2222
- """
2223
- query ArtifactFileURLs($id: ID!, $cursor: String, $perPage: Int) {
2224
- artifact(id: $id) {
2225
- files(after: $cursor, first: $perPage) {
2226
- pageInfo {
2227
- hasNextPage
2228
- endCursor
2229
- }
2230
- edges {
2231
- node {
2232
- name
2233
- directUrl
2234
- }
2235
- }
2236
- }
2237
- }
2238
- }
2239
- """
2240
- )
2241
- assert self._client is not None
2242
- response = self._client.execute(
2243
- query,
2244
- variable_values={"id": self.id, "cursor": cursor, "perPage": per_page},
2245
- timeout=60,
2246
- )
2247
- return response["artifact"]["files"]
2199
+ query = gql(ARTIFACT_FILE_URLS_GQL)
2200
+ gql_vars = {"id": self.id, "cursor": cursor, "perPage": per_page}
2201
+ data = self._client.execute(query, variable_values=gql_vars, timeout=60)
2202
+ result = ArtifactFileUrls.model_validate(data)
2203
+
2204
+ if not ((artifact := result.artifact) and (files := artifact.files)):
2205
+ raise ValueError(f"Unable to fetch files for artifact: {self.name!r}")
2206
+ return files
2248
2207
 
2249
2208
  @ensure_logged
2250
2209
  def checkout(self, root: str | None = None) -> str:
@@ -2320,7 +2279,7 @@ class Artifact:
2320
2279
 
2321
2280
  Args:
2322
2281
  root: The root directory to store the file. Defaults to
2323
- './artifacts/self.name/'.
2282
+ `./artifacts/self.name/`.
2324
2283
 
2325
2284
  Returns:
2326
2285
  The full path of the downloaded file.
@@ -2386,16 +2345,16 @@ class Artifact:
2386
2345
  def delete(self, delete_aliases: bool = False) -> None:
2387
2346
  """Delete an artifact and its files.
2388
2347
 
2389
- If called on a linked artifact (i.e. a member of a portfolio collection): only the link is deleted, and the
2348
+ If called on a linked artifact, only the link is deleted, and the
2390
2349
  source artifact is unaffected.
2391
2350
 
2392
2351
  Use `artifact.unlink()` instead of `artifact.delete()` to remove a link between a source artifact and a linked artifact.
2393
2352
 
2394
2353
  Args:
2395
- delete_aliases: If set to `True`, deletes all aliases associated with the artifact.
2396
- Otherwise, this raises an exception if the artifact has existing
2397
- aliases.
2398
- This parameter is ignored if the artifact is linked (i.e. a member of a portfolio collection).
2354
+ delete_aliases: If set to `True`, deletes all aliases associated
2355
+ with the artifact. Otherwise, this raises an exception if
2356
+ the artifact has existing aliases. This parameter is ignored
2357
+ if the artifact is linked (a member of a portfolio collection).
2399
2358
 
2400
2359
  Raises:
2401
2360
  ArtifactNotLoggedError: If the artifact is not logged.
@@ -2410,33 +2369,16 @@ class Artifact:
2410
2369
 
2411
2370
  @normalize_exceptions
2412
2371
  def _delete(self, delete_aliases: bool = False) -> None:
2413
- mutation = gql(
2414
- """
2415
- mutation DeleteArtifact($artifactID: ID!, $deleteAliases: Boolean) {
2416
- deleteArtifact(input: {
2417
- artifactID: $artifactID
2418
- deleteAliases: $deleteAliases
2419
- }) {
2420
- artifact {
2421
- id
2422
- }
2423
- }
2424
- }
2425
- """
2426
- )
2427
- assert self._client is not None
2428
- self._client.execute(
2429
- mutation,
2430
- variable_values={
2431
- "artifactID": self.id,
2432
- "deleteAliases": delete_aliases,
2433
- },
2434
- )
2372
+ if self._client is None:
2373
+ raise RuntimeError("Client not initialized for artifact mutations")
2374
+
2375
+ mutation = gql(DELETE_ARTIFACT_GQL)
2376
+ gql_vars = {"artifactID": self.id, "deleteAliases": delete_aliases}
2377
+
2378
+ self._client.execute(mutation, variable_values=gql_vars)
2435
2379
 
2436
2380
  @normalize_exceptions
2437
- def link(
2438
- self, target_path: str, aliases: list[str] | None = None
2439
- ) -> Artifact | None:
2381
+ def link(self, target_path: str, aliases: list[str] | None = None) -> Artifact:
2440
2382
  """Link this artifact to a portfolio (a promoted collection of artifacts).
2441
2383
 
2442
2384
  Args:
@@ -2448,14 +2390,14 @@ class Artifact:
2448
2390
  portfolio inside a project, set `target_path` to the following
2449
2391
  schema `{"model-registry"}/{Registered Model Name}` or
2450
2392
  `{entity}/{"model-registry"}/{Registered Model Name}`.
2451
- aliases: A list of strings that uniquely identifies the artifact inside the
2452
- specified portfolio.
2393
+ aliases: A list of strings that uniquely identifies the artifact
2394
+ inside the specified portfolio.
2453
2395
 
2454
2396
  Raises:
2455
2397
  ArtifactNotLoggedError: If the artifact is not logged.
2456
2398
 
2457
2399
  Returns:
2458
- The linked artifact if linking was successful, otherwise None.
2400
+ The linked artifact.
2459
2401
  """
2460
2402
  from wandb import Api
2461
2403
 
@@ -2475,10 +2417,11 @@ class Artifact:
2475
2417
  # Wait until the artifact is committed before trying to link it.
2476
2418
  self.wait()
2477
2419
 
2478
- api = Api(overrides={"entity": self.source_entity})
2420
+ api = InternalApi()
2421
+ settings = api.settings()
2479
2422
 
2480
2423
  target = ArtifactPath.from_str(target_path).with_defaults(
2481
- project=api.settings.get("project") or "uncategorized",
2424
+ project=settings.get("project") or "uncategorized",
2482
2425
  )
2483
2426
 
2484
2427
  # Parse the entity (first part of the path) appropriately,
@@ -2486,11 +2429,8 @@ class Artifact:
2486
2429
  if target.project and is_artifact_registry_project(target.project):
2487
2430
  # In a Registry linking, the entity is used to fetch the organization of the artifact
2488
2431
  # therefore the source artifact's entity is passed to the backend
2489
- organization = target.prefix or api.settings.get("organization") or ""
2490
-
2491
- target.prefix = InternalApi()._resolve_org_entity_name(
2492
- self.source_entity, organization
2493
- )
2432
+ org = target.prefix or settings.get("organization") or ""
2433
+ target.prefix = api._resolve_org_entity_name(self.source_entity, org)
2494
2434
  else:
2495
2435
  target = target.with_defaults(prefix=self.source_entity)
2496
2436
 
@@ -2507,29 +2447,44 @@ class Artifact:
2507
2447
  aliases=alias_inputs,
2508
2448
  )
2509
2449
  gql_vars = {"input": gql_input.model_dump(exclude_none=True)}
2510
- gql_op = gql(LINK_ARTIFACT_GQL)
2511
- data = self._client.execute(gql_op, variable_values=gql_vars)
2512
2450
 
2451
+ # Newer server versions can return `artifactMembership` directly in the response,
2452
+ # avoiding the need to re-fetch the linked artifact at the end.
2453
+ if api._server_supports(
2454
+ pb.ServerFeature.ARTIFACT_MEMBERSHIP_IN_LINK_ARTIFACT_RESPONSE
2455
+ ):
2456
+ omit_fragments = set()
2457
+ else:
2458
+ # FIXME: Make `gql_compat` omit nested fragment definitions recursively (but safely)
2459
+ omit_fragments = {
2460
+ "MembershipWithArtifact",
2461
+ "ArtifactFragment",
2462
+ "ArtifactFragmentWithoutAliases",
2463
+ }
2464
+
2465
+ gql_op = gql_compat(LINK_ARTIFACT_GQL, omit_fragments=omit_fragments)
2466
+ data = self._client.execute(gql_op, variable_values=gql_vars)
2513
2467
  result = LinkArtifact.model_validate(data).link_artifact
2468
+
2469
+ # Newer server versions can return artifactMembership directly in the response
2470
+ if result and (membership := result.artifact_membership):
2471
+ return self._from_membership(membership, target=target, client=self._client)
2472
+
2473
+ # Fallback to old behavior, which requires re-fetching the linked artifact to return it
2514
2474
  if not (result and (version_idx := result.version_index) is not None):
2515
2475
  raise ValueError("Unable to parse linked artifact version from response")
2516
2476
 
2517
- # Fetch the linked artifact to return it
2518
- linked_path = f"{target.to_str()}:v{version_idx}"
2519
-
2520
- try:
2521
- return api._artifact(linked_path)
2522
- except Exception as e:
2523
- wandb.termerror(f"Error fetching link artifact after linking: {e}")
2524
- return None
2477
+ link_name = f"{target.to_str()}:v{version_idx}"
2478
+ return Api(overrides={"entity": self.source_entity})._artifact(link_name)
2525
2479
 
2526
2480
  @ensure_logged
2527
2481
  def unlink(self) -> None:
2528
- """Unlink this artifact if it is currently a member of a portfolio (a promoted collection of artifacts).
2482
+ """Unlink this artifact if it is currently a member of a promoted collection of artifacts.
2529
2483
 
2530
2484
  Raises:
2531
2485
  ArtifactNotLoggedError: If the artifact is not logged.
2532
- ValueError: If the artifact is not linked, i.e. it is not a member of a portfolio collection.
2486
+ ValueError: If the artifact is not linked, in other words,
2487
+ it is not a member of a portfolio collection.
2533
2488
  """
2534
2489
  # Fail early if this isn't a linked artifact to begin with
2535
2490
  if not self.is_link:
@@ -2542,28 +2497,14 @@ class Artifact:
2542
2497
 
2543
2498
  @normalize_exceptions
2544
2499
  def _unlink(self) -> None:
2545
- mutation = gql(
2546
- """
2547
- mutation UnlinkArtifact($artifactID: ID!, $artifactPortfolioID: ID!) {
2548
- unlinkArtifact(
2549
- input: { artifactID: $artifactID, artifactPortfolioID: $artifactPortfolioID }
2550
- ) {
2551
- artifactID
2552
- success
2553
- clientMutationId
2554
- }
2555
- }
2556
- """
2557
- )
2558
- assert self._client is not None
2500
+ if self._client is None:
2501
+ raise RuntimeError("Client not initialized for artifact mutations")
2502
+
2503
+ mutation = gql(UNLINK_ARTIFACT_GQL)
2504
+ gql_vars = {"artifactID": self.id, "artifactPortfolioID": self.collection.id}
2505
+
2559
2506
  try:
2560
- self._client.execute(
2561
- mutation,
2562
- variable_values={
2563
- "artifactID": self.id,
2564
- "artifactPortfolioID": self.collection.id,
2565
- },
2566
- )
2507
+ self._client.execute(mutation, variable_values=gql_vars)
2567
2508
  except CommError as e:
2568
2509
  raise CommError(
2569
2510
  f"You do not have permission to unlink the artifact {self.qualified_name}"
@@ -2579,41 +2520,26 @@ class Artifact:
2579
2520
  Raises:
2580
2521
  ArtifactNotLoggedError: If the artifact is not logged.
2581
2522
  """
2582
- query = gql(
2583
- """
2584
- query ArtifactUsedBy(
2585
- $id: ID!,
2586
- ) {
2587
- artifact(id: $id) {
2588
- usedBy {
2589
- edges {
2590
- node {
2591
- name
2592
- project {
2593
- name
2594
- entityName
2595
- }
2596
- }
2597
- }
2598
- }
2599
- }
2600
- }
2601
- """
2602
- )
2603
- assert self._client is not None
2604
- response = self._client.execute(
2605
- query,
2606
- variable_values={"id": self.id},
2607
- )
2608
- return [
2609
- Run(
2610
- self._client,
2611
- edge["node"]["project"]["entityName"],
2612
- edge["node"]["project"]["name"],
2613
- edge["node"]["name"],
2614
- )
2615
- for edge in response.get("artifact", {}).get("usedBy", {}).get("edges", [])
2616
- ]
2523
+ if self._client is None:
2524
+ raise RuntimeError("Client not initialized for artifact queries")
2525
+
2526
+ query = gql(ARTIFACT_USED_BY_GQL)
2527
+ gql_vars = {"id": self.id}
2528
+ data = self._client.execute(query, variable_values=gql_vars)
2529
+ result = ArtifactUsedBy.model_validate(data)
2530
+
2531
+ if (
2532
+ (artifact := result.artifact)
2533
+ and (used_by := artifact.used_by)
2534
+ and (edges := used_by.edges)
2535
+ ):
2536
+ run_nodes = (e.node for e in edges)
2537
+ return [
2538
+ Run(self._client, proj.entity_name, proj.name, run.name)
2539
+ for run in run_nodes
2540
+ if (proj := run.project)
2541
+ ]
2542
+ return []
2617
2543
 
2618
2544
  @ensure_logged
2619
2545
  def logged_by(self) -> Run | None:
@@ -2625,39 +2551,22 @@ class Artifact:
2625
2551
  Raises:
2626
2552
  ArtifactNotLoggedError: If the artifact is not logged.
2627
2553
  """
2628
- query = gql(
2629
- """
2630
- query ArtifactCreatedBy(
2631
- $id: ID!
2632
- ) {
2633
- artifact(id: $id) {
2634
- createdBy {
2635
- ... on Run {
2636
- name
2637
- project {
2638
- name
2639
- entityName
2640
- }
2641
- }
2642
- }
2643
- }
2644
- }
2645
- """
2646
- )
2647
- assert self._client is not None
2648
- response = self._client.execute(
2649
- query,
2650
- variable_values={"id": self.id},
2651
- )
2652
- creator = response.get("artifact", {}).get("createdBy", {})
2653
- if creator.get("name") is None:
2654
- return None
2655
- return Run(
2656
- self._client,
2657
- creator["project"]["entityName"],
2658
- creator["project"]["name"],
2659
- creator["name"],
2660
- )
2554
+ if self._client is None:
2555
+ raise RuntimeError("Client not initialized for artifact queries")
2556
+
2557
+ query = gql(ARTIFACT_CREATED_BY_GQL)
2558
+ gql_vars = {"id": self.id}
2559
+ data = self._client.execute(query, variable_values=gql_vars)
2560
+ result = ArtifactCreatedBy.model_validate(data)
2561
+
2562
+ if (
2563
+ (artifact := result.artifact)
2564
+ and (creator := artifact.created_by)
2565
+ and (name := creator.name)
2566
+ and (project := creator.project)
2567
+ ):
2568
+ return Run(self._client, project.entity_name, project.name, name)
2569
+ return None
2661
2570
 
2662
2571
  @ensure_logged
2663
2572
  def json_encode(self) -> dict[str, Any]:
@@ -2673,37 +2582,22 @@ class Artifact:
2673
2582
  entity_name: str, project_name: str, name: str, client: RetryingClient
2674
2583
  ) -> str | None:
2675
2584
  """Returns the expected type for a given artifact name and project."""
2676
- query = gql(
2677
- """
2678
- query ArtifactType(
2679
- $entityName: String,
2680
- $projectName: String,
2681
- $name: String!
2682
- ) {
2683
- project(name: $projectName, entityName: $entityName) {
2684
- artifact(name: $name) {
2685
- artifactType {
2686
- name
2687
- }
2688
- }
2689
- }
2690
- }
2691
- """
2692
- )
2693
- if ":" not in name:
2694
- name += ":latest"
2695
- response = client.execute(
2696
- query,
2697
- variable_values={
2698
- "entityName": entity_name,
2699
- "projectName": project_name,
2700
- "name": name,
2701
- },
2702
- )
2703
- return (
2704
- ((response.get("project") or {}).get("artifact") or {}).get("artifactType")
2705
- or {}
2706
- ).get("name")
2585
+ query = gql(ARTIFACT_TYPE_GQL)
2586
+ gql_vars = {
2587
+ "entityName": entity_name,
2588
+ "projectName": project_name,
2589
+ "name": name if (":" in name) else f"{name}:latest",
2590
+ }
2591
+ data = client.execute(query, variable_values=gql_vars)
2592
+ result = ArtifactType.model_validate(data)
2593
+
2594
+ if (
2595
+ (project := result.project)
2596
+ and (artifact := project.artifact)
2597
+ and (artifact_type := artifact.artifact_type)
2598
+ ):
2599
+ return artifact_type.name
2600
+ return None
2707
2601
 
2708
2602
  def _load_manifest(self, url: str) -> ArtifactManifest:
2709
2603
  with requests.get(url) as response: