wandb 0.19.10__py3-none-musllinux_1_2_aarch64.whl → 0.19.11__py3-none-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90):
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +3 -3
  3. wandb/_pydantic/__init__.py +2 -3
  4. wandb/_pydantic/base.py +11 -31
  5. wandb/_pydantic/utils.py +8 -1
  6. wandb/_pydantic/v1_compat.py +3 -3
  7. wandb/apis/public/api.py +590 -22
  8. wandb/apis/public/artifacts.py +13 -5
  9. wandb/apis/public/automations.py +1 -1
  10. wandb/apis/public/integrations.py +22 -10
  11. wandb/apis/public/registries/__init__.py +0 -0
  12. wandb/apis/public/registries/_freezable_list.py +179 -0
  13. wandb/apis/public/{registries.py → registries/registries_search.py} +22 -129
  14. wandb/apis/public/registries/registry.py +357 -0
  15. wandb/apis/public/registries/utils.py +140 -0
  16. wandb/apis/public/runs.py +58 -56
  17. wandb/automations/__init__.py +16 -24
  18. wandb/automations/_filters/expressions.py +12 -10
  19. wandb/automations/_filters/operators.py +10 -19
  20. wandb/automations/_filters/run_metrics.py +231 -82
  21. wandb/automations/_generated/__init__.py +27 -34
  22. wandb/automations/_generated/create_automation.py +17 -0
  23. wandb/automations/_generated/delete_automation.py +17 -0
  24. wandb/automations/_generated/fragments.py +40 -25
  25. wandb/automations/_generated/{get_triggers.py → get_automations.py} +5 -5
  26. wandb/automations/_generated/{get_triggers_by_entity.py → get_automations_by_entity.py} +7 -5
  27. wandb/automations/_generated/operations.py +35 -98
  28. wandb/automations/_generated/update_automation.py +17 -0
  29. wandb/automations/_utils.py +178 -64
  30. wandb/automations/_validators.py +94 -2
  31. wandb/automations/actions.py +113 -98
  32. wandb/automations/automations.py +47 -69
  33. wandb/automations/events.py +139 -87
  34. wandb/automations/integrations.py +23 -4
  35. wandb/automations/scopes.py +22 -20
  36. wandb/bin/gpu_stats +0 -0
  37. wandb/bin/wandb-core +0 -0
  38. wandb/env.py +11 -0
  39. wandb/old/settings.py +4 -1
  40. wandb/proto/v3/wandb_internal_pb2.py +240 -236
  41. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  42. wandb/proto/v4/wandb_internal_pb2.py +236 -236
  43. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  44. wandb/proto/v5/wandb_internal_pb2.py +236 -236
  45. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  46. wandb/proto/v6/wandb_internal_pb2.py +236 -236
  47. wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
  48. wandb/sdk/artifacts/_generated/__init__.py +42 -1
  49. wandb/sdk/artifacts/_generated/add_aliases.py +21 -0
  50. wandb/sdk/artifacts/_generated/delete_aliases.py +21 -0
  51. wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +67 -0
  52. wandb/sdk/artifacts/_generated/fragments.py +35 -0
  53. wandb/sdk/artifacts/_generated/input_types.py +12 -0
  54. wandb/sdk/artifacts/_generated/operations.py +101 -0
  55. wandb/sdk/artifacts/_generated/update_artifact.py +26 -0
  56. wandb/sdk/artifacts/_graphql_fragments.py +1 -0
  57. wandb/sdk/artifacts/_validators.py +120 -1
  58. wandb/sdk/artifacts/artifact.py +380 -203
  59. wandb/sdk/artifacts/artifact_file_cache.py +4 -6
  60. wandb/sdk/artifacts/artifact_manifest_entry.py +11 -2
  61. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +182 -1
  62. wandb/sdk/artifacts/storage_policy.py +3 -0
  63. wandb/sdk/data_types/video.py +46 -32
  64. wandb/sdk/interface/interface.py +2 -3
  65. wandb/sdk/internal/internal_api.py +21 -31
  66. wandb/sdk/internal/sender.py +5 -2
  67. wandb/sdk/launch/sweeps/utils.py +8 -0
  68. wandb/sdk/projects/_generated/__init__.py +47 -0
  69. wandb/sdk/projects/_generated/delete_project.py +22 -0
  70. wandb/sdk/projects/_generated/enums.py +4 -0
  71. wandb/sdk/projects/_generated/fetch_registry.py +22 -0
  72. wandb/sdk/projects/_generated/fragments.py +41 -0
  73. wandb/sdk/projects/_generated/input_types.py +13 -0
  74. wandb/sdk/projects/_generated/operations.py +88 -0
  75. wandb/sdk/projects/_generated/rename_project.py +27 -0
  76. wandb/sdk/projects/_generated/upsert_registry_project.py +27 -0
  77. wandb/sdk/service/service.py +9 -1
  78. wandb/sdk/wandb_init.py +32 -5
  79. wandb/sdk/wandb_run.py +37 -9
  80. wandb/sdk/wandb_settings.py +6 -7
  81. wandb/sdk/wandb_setup.py +12 -0
  82. wandb/util.py +7 -3
  83. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/METADATA +1 -1
  84. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/RECORD +87 -70
  85. wandb/automations/_generated/create_filter_trigger.py +0 -21
  86. wandb/automations/_generated/delete_trigger.py +0 -19
  87. wandb/automations/_generated/update_filter_trigger.py +0 -21
  88. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/WHEEL +0 -0
  89. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/licenses/LICENSE +0 -0
@@ -14,6 +14,7 @@ import shutil
14
14
  import stat
15
15
  import tempfile
16
16
  import time
17
+ from collections import deque
17
18
  from copy import copy
18
19
  from dataclasses import dataclass
19
20
  from datetime import datetime, timedelta
@@ -28,17 +29,27 @@ import wandb
28
29
  from wandb import data_types, env, util
29
30
  from wandb.apis.normalize import normalize_exceptions
30
31
  from wandb.apis.public import ArtifactCollection, ArtifactFiles, RetryingClient, Run
32
+ from wandb.apis.public.utils import gql_compat
31
33
  from wandb.data_types import WBValue
34
+ from wandb.errors import CommError
32
35
  from wandb.errors.term import termerror, termlog, termwarn
33
36
  from wandb.proto import wandb_internal_pb2 as pb
34
37
  from wandb.proto.wandb_deprecated import Deprecated
35
38
  from wandb.sdk import wandb_setup
36
- from wandb.sdk.artifacts._graphql_fragments import _gql_artifact_fragment
39
+ from wandb.sdk.artifacts._generated.fetch_linked_artifacts import FetchLinkedArtifacts
40
+ from wandb.sdk.artifacts._generated.operations import FETCH_LINKED_ARTIFACTS_GQL
41
+ from wandb.sdk.artifacts._graphql_fragments import (
42
+ _gql_artifact_fragment,
43
+ omit_artifact_fields,
44
+ )
37
45
  from wandb.sdk.artifacts._validators import (
46
+ LINKED_ARTIFACT_COLLECTION_TYPE,
47
+ _LinkArtifactFields,
38
48
  ensure_logged,
39
49
  ensure_not_finalized,
40
50
  is_artifact_registry_project,
41
51
  validate_aliases,
52
+ validate_artifact_name,
42
53
  validate_tags,
43
54
  )
44
55
  from wandb.sdk.artifacts.artifact_download_logger import ArtifactDownloadLogger
@@ -67,6 +78,16 @@ from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath, URIStr
67
78
  from wandb.sdk.lib.runid import generate_id
68
79
  from wandb.sdk.mailbox import MailboxHandle
69
80
 
81
+ from ._generated import (
82
+ ADD_ALIASES_GQL,
83
+ DELETE_ALIASES_GQL,
84
+ UPDATE_ARTIFACT_GQL,
85
+ ArtifactAliasInput,
86
+ ArtifactCollectionAliasInput,
87
+ TagInput,
88
+ UpdateArtifact,
89
+ )
90
+
70
91
  reset_path = util.vendor_setup()
71
92
 
72
93
  from wandb_gql import gql # noqa: E402
@@ -112,6 +133,7 @@ class Artifact:
112
133
  incremental: Use `Artifact.new_draft()` method instead to modify an
113
134
  existing artifact.
114
135
  use_as: W&B Launch specific parameter. Not recommended for general use.
136
+ is_link: Boolean indication of if the artifact is a linked artifact(`True`) or source artifact(`False`).
115
137
 
116
138
  Returns:
117
139
  An `Artifact` object.
@@ -163,12 +185,14 @@ class Artifact:
163
185
  self._sequence_client_id: str = runid.generate_id(128)
164
186
  self._entity: str | None = None
165
187
  self._project: str | None = None
166
- self._name: str = name # includes version after saving
188
+ self._name: str = validate_artifact_name(name) # includes version after saving
167
189
  self._version: str | None = None
168
190
  self._source_entity: str | None = None
169
191
  self._source_project: str | None = None
170
192
  self._source_name: str = name # includes version after saving
171
193
  self._source_version: str | None = None
194
+ self._source_artifact: Artifact | None = None
195
+ self._is_link: bool = False
172
196
  self._type: str = type
173
197
  self._description: str | None = description
174
198
  self._metadata: dict = self._normalize_metadata(metadata)
@@ -192,6 +216,7 @@ class Artifact:
192
216
  self._updated_at: str | None = None
193
217
  self._final: bool = False
194
218
  self._history_step: int | None = None
219
+ self._linked_artifacts: list[Artifact] = []
195
220
 
196
221
  # Cache.
197
222
  artifact_instance_cache[self._client_id] = self
@@ -311,8 +336,13 @@ class Artifact:
311
336
  artifact_instance_cache[artifact.id] = artifact
312
337
  return artifact
313
338
 
339
+ # TODO: Eventually factor out is_link. Have to currently use it since some forms of fetching the artifact
340
+ # doesn't make it clear if the artifact is a link or not and have to manually set it.
314
341
  def _assign_attrs(
315
- self, attrs: dict[str, Any], aliases: list[str] | None = None
342
+ self,
343
+ attrs: dict[str, Any],
344
+ aliases: list[str] | None = None,
345
+ is_link: bool | None = None,
316
346
  ) -> None:
317
347
  """Update this Artifact's attributes using the server response."""
318
348
  self._id = attrs["id"]
@@ -334,6 +364,17 @@ class Artifact:
334
364
  if self._name is None:
335
365
  self._name = self._source_name
336
366
 
367
+ # TODO: Refactor artifact query to fetch artifact via membership instead
368
+ # and get the collection type
369
+ if is_link is None:
370
+ self._is_link = (
371
+ self._entity != self._source_entity
372
+ or self._project != self._source_project
373
+ or self._name != self._source_name
374
+ )
375
+ else:
376
+ self._is_link = is_link
377
+
337
378
  self._type = attrs["artifactType"]["name"]
338
379
  self._description = attrs["description"]
339
380
 
@@ -383,12 +424,12 @@ class Artifact:
383
424
  self._aliases = other_aliases
384
425
  self._saved_aliases = copy(other_aliases)
385
426
 
386
- tags = [obj["name"] for obj in attrs.get("tags", [])]
427
+ tags = [obj["name"] for obj in (attrs.get("tags") or [])]
387
428
  self._tags = tags
388
429
  self._saved_tags = copy(tags)
389
430
 
390
431
  metadata_str = attrs["metadata"]
391
- self.metadata = self._normalize_metadata(
432
+ self._metadata = self._normalize_metadata(
392
433
  json.loads(metadata_str) if metadata_str else {}
393
434
  )
394
435
 
@@ -462,35 +503,48 @@ class Artifact:
462
503
  @property
463
504
  @ensure_logged
464
505
  def entity(self) -> str:
465
- """The name of the entity of the secondary (portfolio) artifact collection."""
506
+ """The name of the entity that the artifact collection belongs to.
507
+
508
+ If the artifact is a link, the entity will be the entity of the linked artifact.
509
+ """
466
510
  assert self._entity is not None
467
511
  return self._entity
468
512
 
469
513
  @property
470
514
  @ensure_logged
471
515
  def project(self) -> str:
472
- """The name of the project of the secondary (portfolio) artifact collection."""
516
+ """The name of the project that the artifact collection belongs to.
517
+
518
+ If the artifact is a link, the project will be the project of the linked artifact.
519
+ """
473
520
  assert self._project is not None
474
521
  return self._project
475
522
 
476
523
  @property
477
524
  def name(self) -> str:
478
- """The artifact name and version in its secondary (portfolio) collection.
525
+ """The artifact name and version of the artifact.
479
526
 
480
- A string with the format `{collection}:{alias}`. Before the artifact is saved,
481
- contains only the name since the version is not yet known.
527
+ A string with the format `{collection}:{alias}`. If fetched before an artifact is logged/saved, the name won't contain the alias.
528
+ If the artifact is a link, the name will be the name of the linked artifact.
482
529
  """
483
530
  return self._name
484
531
 
485
532
  @property
486
533
  def qualified_name(self) -> str:
487
- """The entity/project/name of the secondary (portfolio) collection."""
534
+ """The entity/project/name of the artifact.
535
+
536
+ If the artifact is a link, the qualified name will be the qualified name of the linked artifact path.
537
+ """
488
538
  return f"{self.entity}/{self.project}/{self.name}"
489
539
 
490
540
  @property
491
541
  @ensure_logged
492
542
  def version(self) -> str:
493
- """The artifact's version in its secondary (portfolio) collection."""
543
+ """The artifact's version.
544
+
545
+ A string with the format `v{number}`.
546
+ If the artifact is a link artifact, the version will be from the linked collection.
547
+ """
494
548
  assert self._version is not None
495
549
  return self._version
496
550
 
@@ -513,35 +567,35 @@ class Artifact:
513
567
  @property
514
568
  @ensure_logged
515
569
  def source_entity(self) -> str:
516
- """The name of the entity of the primary (sequence) artifact collection."""
570
+ """The name of the entity of the source artifact."""
517
571
  assert self._source_entity is not None
518
572
  return self._source_entity
519
573
 
520
574
  @property
521
575
  @ensure_logged
522
576
  def source_project(self) -> str:
523
- """The name of the project of the primary (sequence) artifact collection."""
577
+ """The name of the project of the source artifact."""
524
578
  assert self._source_project is not None
525
579
  return self._source_project
526
580
 
527
581
  @property
528
582
  def source_name(self) -> str:
529
- """The artifact name and version in its primary (sequence) collection.
583
+ """The artifact name and version of the source artifact.
530
584
 
531
- A string with the format `{collection}:{alias}`. Before the artifact is saved,
585
+ A string with the format `{source_collection}:{alias}`. Before the artifact is saved,
532
586
  contains only the name since the version is not yet known.
533
587
  """
534
588
  return self._source_name
535
589
 
536
590
  @property
537
591
  def source_qualified_name(self) -> str:
538
- """The entity/project/name of the primary (sequence) collection."""
592
+ """The source_entity/source_project/source_name of the source artifact."""
539
593
  return f"{self.source_entity}/{self.source_project}/{self.source_name}"
540
594
 
541
595
  @property
542
596
  @ensure_logged
543
597
  def source_version(self) -> str:
544
- """The artifact's version in its primary (sequence) collection.
598
+ """The source artifact's version.
545
599
 
546
600
  A string with the format `v{number}`.
547
601
  """
@@ -551,12 +605,60 @@ class Artifact:
551
605
  @property
552
606
  @ensure_logged
553
607
  def source_collection(self) -> ArtifactCollection:
554
- """The artifact's primary (sequence) collection."""
608
+ """The artifact's source collection.
609
+
610
+ The source collection is the collection that the artifact was logged from.
611
+ """
555
612
  base_name = self.source_name.split(":")[0]
556
613
  return ArtifactCollection(
557
614
  self._client, self.source_entity, self.source_project, base_name, self.type
558
615
  )
559
616
 
617
+ @property
618
+ def is_link(self) -> bool:
619
+ """Boolean flag indicating if the artifact is a link artifact.
620
+
621
+ True: The artifact is a link artifact to a source artifact.
622
+ False: The artifact is a source artifact.
623
+ """
624
+ return self._is_link
625
+
626
+ @property
627
+ @ensure_logged
628
+ def linked_artifacts(self) -> list[Artifact]:
629
+ """Returns a list of all the linked artifacts of a source artifact.
630
+
631
+ If the artifact is a link artifact (`artifact.is_link == True`), it will return an empty list.
632
+ Limited to 500 results."""
633
+ if not self.is_link:
634
+ self._linked_artifacts = self._fetch_linked_artifacts()
635
+ return self._linked_artifacts
636
+
637
+ @property
638
+ @ensure_logged
639
+ def source_artifact(self) -> Artifact:
640
+ """Returns the source artifact. The source artifact is the original logged artifact.
641
+
642
+ If the artifact itself is a source artifact (`artifact.is_link == False`), it will return itself."""
643
+ if not self.is_link:
644
+ return self
645
+ if self._source_artifact is None:
646
+ try:
647
+ if self._client is None:
648
+ raise ValueError("Client is not initialized")
649
+ artifact = self._from_name(
650
+ entity=self.source_entity,
651
+ project=self.source_project,
652
+ name=self.source_name,
653
+ client=self._client,
654
+ )
655
+ self._source_artifact = artifact
656
+ except Exception as e:
657
+ raise ValueError(
658
+ f"Unable to fetch source artifact for linked artifact {self.name}"
659
+ ) from e
660
+ return self._source_artifact
661
+
560
662
  @property
561
663
  def type(self) -> str:
562
664
  """The artifact's type. Common types include `dataset` or `model`."""
@@ -576,7 +678,7 @@ class Artifact:
576
678
  except AttributeError:
577
679
  return ""
578
680
 
579
- if self.collection.is_sequence():
681
+ if not self.is_link:
580
682
  return self._construct_standard_url(base_url)
581
683
  if is_artifact_registry_project(self.project):
582
684
  return self._construct_registry_url(base_url)
@@ -660,9 +762,15 @@ class Artifact:
660
762
  standardized team model or dataset card. In the W&B UI the
661
763
  description is rendered as markdown.
662
764
 
765
+ Editing the description will apply the changes to the source artifact and all linked artifacts associated with it.
766
+
663
767
  Args:
664
768
  description: Free text that offers a description of the artifact.
665
769
  """
770
+ if self.is_link:
771
+ wandb.termwarn(
772
+ "Editing the description of this linked artifact will edit the description for the source artifact and it's linked artifacts as well."
773
+ )
666
774
  self._description = description
667
775
 
668
776
  @property
@@ -681,10 +789,15 @@ class Artifact:
681
789
  the class distribution of a dataset.
682
790
 
683
791
  Note: There is currently a limit of 100 total keys.
792
+ Editing the metadata will apply the changes to the source artifact and all linked artifacts associated with it.
684
793
 
685
794
  Args:
686
795
  metadata: Structured data associated with the artifact.
687
796
  """
797
+ if self.is_link:
798
+ wandb.termwarn(
799
+ "Editing the metadata of this linked artifact will edit the metadata for the source artifact and it's linked artifacts as well."
800
+ )
688
801
  self._metadata = self._normalize_metadata(metadata)
689
802
 
690
803
  @property
@@ -725,6 +838,12 @@ class Artifact:
725
838
  if self.type == "wandb-history":
726
839
  raise ValueError("Cannot set artifact TTL for type wandb-history")
727
840
 
841
+ if self.is_link:
842
+ raise ValueError(
843
+ "Cannot set TTL for link artifact. "
844
+ "Unlink the artifact first then set the TTL for the source artifact"
845
+ )
846
+
728
847
  self._ttl_changed = True
729
848
  if isinstance(ttl, ArtifactTTL):
730
849
  if ttl == ArtifactTTL.INHERIT:
@@ -769,7 +888,14 @@ class Artifact:
769
888
  @tags.setter
770
889
  @ensure_logged
771
890
  def tags(self, tags: list[str]) -> None:
772
- """Set the tags associated with this artifact."""
891
+ """Set the tags associated with this artifact.
892
+
893
+ Editing tags will apply the changes to the source artifact and all linked artifacts associated with it.
894
+ """
895
+ if self.is_link:
896
+ wandb.termwarn(
897
+ "Editing tags will apply the changes to the source artifact and all linked artifacts associated with it."
898
+ )
773
899
  self._tags = validate_tags(tags)
774
900
 
775
901
  @property
@@ -1032,11 +1158,21 @@ class Artifact:
1032
1158
  except LookupError:
1033
1159
  raise ValueError(f"Unable to fetch artifact with id: {artifact_id!r}")
1034
1160
  else:
1035
- self._assign_attrs(attrs)
1161
+ # _populate_after_save is only called on source artifacts, not linked artifacts
1162
+ # We have to manually set is_link because we aren't fetching the collection the artifact.
1163
+ # That requires greater refactoring for commitArtifact to return the artifact collection type.
1164
+ self._assign_attrs(attrs, is_link=False)
1036
1165
 
1037
1166
  @normalize_exceptions
1038
1167
  def _update(self) -> None:
1039
1168
  """Persists artifact changes to the wandb backend."""
1169
+ if self._client is None:
1170
+ raise RuntimeError("Client not initialized for artifact mutations")
1171
+
1172
+ entity = self.entity
1173
+ project = self.project
1174
+ collection = self.name.split(":")[0]
1175
+
1040
1176
  aliases = None
1041
1177
  introspect_query = gql(
1042
1178
  """
@@ -1050,180 +1186,109 @@ class Artifact:
1050
1186
  }
1051
1187
  """
1052
1188
  )
1053
- assert self._client is not None
1054
- response = self._client.execute(introspect_query)
1055
- if response.get("AddAliasesInputInfoType"): # wandb backend version >= 0.13.0
1056
- aliases_to_add = set(self._aliases) - set(self._saved_aliases)
1057
- aliases_to_delete = set(self._saved_aliases) - set(self._aliases)
1058
- if aliases_to_add:
1059
- add_mutation = gql(
1060
- """
1061
- mutation addAliases(
1062
- $artifactID: ID!,
1063
- $aliases: [ArtifactCollectionAliasInput!]!,
1064
- ) {
1065
- addAliases(
1066
- input: {artifactID: $artifactID, aliases: $aliases}
1067
- ) {
1068
- success
1069
- }
1070
- }
1071
- """
1072
- )
1073
- assert self._client is not None
1074
- self._client.execute(
1075
- add_mutation,
1076
- variable_values={
1077
- "artifactID": self.id,
1078
- "aliases": [
1079
- {
1080
- "entityName": self._entity,
1081
- "projectName": self._project,
1082
- "artifactCollectionName": self._name.split(":")[0],
1083
- "alias": alias,
1084
- }
1085
- for alias in aliases_to_add
1086
- ],
1087
- },
1088
- )
1089
- if aliases_to_delete:
1090
- delete_mutation = gql(
1091
- """
1092
- mutation deleteAliases(
1093
- $artifactID: ID!,
1094
- $aliases: [ArtifactCollectionAliasInput!]!,
1095
- ) {
1096
- deleteAliases(
1097
- input: {artifactID: $artifactID, aliases: $aliases}
1098
- ) {
1099
- success
1100
- }
1101
- }
1102
- """
1103
- )
1104
- assert self._client is not None
1105
- self._client.execute(
1106
- delete_mutation,
1107
- variable_values={
1108
- "artifactID": self.id,
1109
- "aliases": [
1110
- {
1111
- "entityName": self._entity,
1112
- "projectName": self._project,
1113
- "artifactCollectionName": self._name.split(":")[0],
1114
- "alias": alias,
1115
- }
1116
- for alias in aliases_to_delete
1117
- ],
1118
- },
1119
- )
1120
- self._saved_aliases = copy(self._aliases)
1189
+
1190
+ data = self._client.execute(introspect_query)
1191
+ if data.get("AddAliasesInputInfoType"): # wandb backend version >= 0.13.0
1192
+ alias_props = {
1193
+ "entity_name": entity,
1194
+ "project_name": project,
1195
+ "artifact_collection_name": collection,
1196
+ }
1197
+ if aliases_to_add := (set(self.aliases) - set(self._saved_aliases)):
1198
+ add_mutation = gql(ADD_ALIASES_GQL)
1199
+ add_alias_inputs = [
1200
+ ArtifactCollectionAliasInput(**alias_props, alias=alias)
1201
+ for alias in aliases_to_add
1202
+ ]
1203
+ try:
1204
+ self._client.execute(
1205
+ add_mutation,
1206
+ variable_values={
1207
+ "artifactID": self.id,
1208
+ "aliases": [a.model_dump() for a in add_alias_inputs],
1209
+ },
1210
+ )
1211
+ except CommError as e:
1212
+ raise CommError(
1213
+ "You do not have permission to add"
1214
+ f" {'at least one of the following aliases' if len(aliases_to_add) > 1 else 'the following alias'}"
1215
+ f" to this artifact: {aliases_to_add}"
1216
+ ) from e
1217
+
1218
+ if aliases_to_delete := (set(self._saved_aliases) - set(self.aliases)):
1219
+ delete_mutation = gql(DELETE_ALIASES_GQL)
1220
+ delete_alias_inputs = [
1221
+ ArtifactCollectionAliasInput(**alias_props, alias=alias)
1222
+ for alias in aliases_to_delete
1223
+ ]
1224
+ try:
1225
+ self._client.execute(
1226
+ delete_mutation,
1227
+ variable_values={
1228
+ "artifactID": self.id,
1229
+ "aliases": [a.model_dump() for a in delete_alias_inputs],
1230
+ },
1231
+ )
1232
+ except CommError as e:
1233
+ raise CommError(
1234
+ f"You do not have permission to delete"
1235
+ f" {'at least one of the following aliases' if len(aliases_to_delete) > 1 else 'the following alias'}"
1236
+ f" from this artifact: {aliases_to_delete}"
1237
+ ) from e
1238
+
1239
+ self._saved_aliases = copy(self.aliases)
1240
+
1121
1241
  else: # wandb backend version < 0.13.0
1122
1242
  aliases = [
1123
- {
1124
- "artifactCollectionName": self._name.split(":")[0],
1125
- "alias": alias,
1126
- }
1127
- for alias in self._aliases
1243
+ ArtifactAliasInput(
1244
+ artifact_collection_name=collection, alias=alias
1245
+ ).model_dump()
1246
+ for alias in self.aliases
1128
1247
  ]
1129
1248
 
1130
- mutation_template = """
1131
- mutation updateArtifact(
1132
- $artifactID: ID!
1133
- $description: String
1134
- $metadata: JSONString
1135
- _TTL_DURATION_SECONDS_TYPE_
1136
- _TAGS_TO_ADD_TYPE_
1137
- _TAGS_TO_DELETE_TYPE_
1138
- $aliases: [ArtifactAliasInput!]
1139
- ) {
1140
- updateArtifact(
1141
- input: {
1142
- artifactID: $artifactID,
1143
- description: $description,
1144
- metadata: $metadata,
1145
- _TTL_DURATION_SECONDS_VALUE_
1146
- _TAGS_TO_ADD_VALUE_
1147
- _TAGS_TO_DELETE_VALUE_
1148
- aliases: $aliases
1149
- }
1150
- ) {
1151
- artifact {
1152
- ...ArtifactFragment
1153
- }
1154
- }
1155
- }
1156
- """ + _gql_artifact_fragment()
1249
+ omit_fields = omit_artifact_fields(api=InternalApi())
1250
+ omit_variables = set()
1157
1251
 
1158
- fields = InternalApi().server_artifact_introspection()
1159
- if "ttlIsInherited" in fields:
1160
- mutation_template = (
1161
- mutation_template.replace(
1162
- "_TTL_DURATION_SECONDS_TYPE_",
1163
- "$ttlDurationSeconds: Int64",
1164
- )
1165
- .replace(
1166
- "_TTL_DURATION_SECONDS_VALUE_",
1167
- "ttlDurationSeconds: $ttlDurationSeconds",
1168
- )
1169
- .replace(
1170
- "_TTL_DURATION_SECONDS_FIELDS_",
1171
- "ttlDurationSeconds ttlIsInherited",
1172
- )
1173
- )
1174
- else:
1252
+ if {"ttlIsInherited", "ttlDurationSeconds"} & omit_fields:
1175
1253
  if self._ttl_changed:
1176
1254
  termwarn(
1177
1255
  "Server not compatible with setting Artifact TTLs, please upgrade the server to use Artifact TTL"
1178
1256
  )
1179
- mutation_template = (
1180
- mutation_template.replace("_TTL_DURATION_SECONDS_TYPE_", "")
1181
- .replace("_TTL_DURATION_SECONDS_VALUE_", "")
1182
- .replace("_TTL_DURATION_SECONDS_FIELDS_", "")
1183
- )
1184
1257
 
1185
- tags_to_add = validate_tags(set(self._tags) - set(self._saved_tags))
1186
- tags_to_delete = validate_tags(set(self._saved_tags) - set(self._tags))
1187
- if "tags" in fields:
1188
- mutation_template = (
1189
- mutation_template.replace(
1190
- "_TAGS_TO_ADD_TYPE_", "$tagsToAdd: [TagInput!]"
1191
- )
1192
- .replace("_TAGS_TO_DELETE_TYPE_", "$tagsToDelete: [TagInput!]")
1193
- .replace("_TAGS_TO_ADD_VALUE_", "tagsToAdd: $tagsToAdd")
1194
- .replace("_TAGS_TO_DELETE_VALUE_", "tagsToDelete: $tagsToDelete")
1195
- )
1196
- else:
1197
- if tags_to_add or tags_to_delete:
1258
+ omit_variables |= {"ttlDurationSeconds"}
1259
+
1260
+ tags_to_add = validate_tags(set(self.tags) - set(self._saved_tags))
1261
+ tags_to_del = validate_tags(set(self._saved_tags) - set(self.tags))
1262
+
1263
+ if {"tags"} & omit_fields:
1264
+ if tags_to_add or tags_to_del:
1198
1265
  termwarn(
1199
1266
  "Server not compatible with Artifact tags. "
1200
1267
  "To use Artifact tags, please upgrade the server to v0.85 or higher."
1201
1268
  )
1202
- mutation_template = (
1203
- mutation_template.replace("_TAGS_TO_ADD_TYPE_", "")
1204
- .replace("_TAGS_TO_DELETE_TYPE_", "")
1205
- .replace("_TAGS_TO_ADD_VALUE_", "")
1206
- .replace("_TAGS_TO_DELETE_VALUE_", "")
1207
- )
1208
1269
 
1209
- mutation = gql(mutation_template)
1210
- assert self._client is not None
1270
+ omit_variables |= {"tagsToAdd", "tagsToDelete"}
1211
1271
 
1212
- ttl_duration_input = self._ttl_duration_seconds_to_gql()
1213
- response = self._client.execute(
1214
- mutation,
1215
- variable_values={
1216
- "artifactID": self.id,
1217
- "description": self.description,
1218
- "metadata": util.json_dumps_safer(self.metadata),
1219
- "ttlDurationSeconds": ttl_duration_input,
1220
- "aliases": aliases,
1221
- "tagsToAdd": [{"tagName": tag_name} for tag_name in tags_to_add],
1222
- "tagsToDelete": [{"tagName": tag_name} for tag_name in tags_to_delete],
1223
- },
1272
+ mutation = gql_compat(
1273
+ UPDATE_ARTIFACT_GQL, omit_variables=omit_variables, omit_fields=omit_fields
1224
1274
  )
1225
- attrs = response["updateArtifact"]["artifact"]
1226
- self._assign_attrs(attrs)
1275
+
1276
+ gql_vars = {
1277
+ "artifactID": self.id,
1278
+ "description": self.description,
1279
+ "metadata": util.json_dumps_safer(self.metadata),
1280
+ "ttlDurationSeconds": self._ttl_duration_seconds_to_gql(),
1281
+ "aliases": aliases,
1282
+ "tagsToAdd": [TagInput(tag_name=t).model_dump() for t in tags_to_add],
1283
+ "tagsToDelete": [TagInput(tag_name=t).model_dump() for t in tags_to_del],
1284
+ }
1285
+
1286
+ data = self._client.execute(mutation, variable_values=gql_vars)
1287
+
1288
+ result = UpdateArtifact.model_validate(data).update_artifact
1289
+ if not (result and (artifact := result.artifact)):
1290
+ raise ValueError("Unable to parse updateArtifact response")
1291
+ self._assign_attrs(artifact.model_dump())
1227
1292
 
1228
1293
  self._ttl_changed = False # Reset after updating artifact
1229
1294
 
@@ -1776,6 +1841,7 @@ class Artifact:
1776
1841
  allow_missing_references: bool = False,
1777
1842
  skip_cache: bool | None = None,
1778
1843
  path_prefix: StrPath | None = None,
1844
+ multipart: bool | None = None,
1779
1845
  ) -> FilePathStr:
1780
1846
  """Download the contents of the artifact to the specified root directory.
1781
1847
 
@@ -1792,6 +1858,10 @@ class Artifact:
1792
1858
  specified download directory.
1793
1859
  path_prefix: If specified, only files with a path that starts with the given
1794
1860
  prefix will be downloaded. Uses unix format (forward slashes).
1861
+ multipart: If set to `None` (default), the artifact will be downloaded
1862
+ in parallel using multipart download if individual file size is greater than
1863
+ 2GB. If set to `True` or `False`, the artifact will be downloaded in
1864
+ parallel or serially regardless of the file size.
1795
1865
 
1796
1866
  Returns:
1797
1867
  The path to the downloaded contents.
@@ -1815,6 +1885,7 @@ class Artifact:
1815
1885
  allow_missing_references=allow_missing_references,
1816
1886
  skip_cache=skip_cache,
1817
1887
  path_prefix=path_prefix,
1888
+ multipart=multipart,
1818
1889
  )
1819
1890
 
1820
1891
  def _download_using_core(
@@ -1884,6 +1955,7 @@ class Artifact:
1884
1955
  allow_missing_references: bool = False,
1885
1956
  skip_cache: bool | None = None,
1886
1957
  path_prefix: StrPath | None = None,
1958
+ multipart: bool | None = None,
1887
1959
  ) -> FilePathStr:
1888
1960
  nfiles = len(self.manifest.entries)
1889
1961
  size = sum(e.size or 0 for e in self.manifest.entries.values())
@@ -1900,6 +1972,7 @@ class Artifact:
1900
1972
 
1901
1973
  def _download_entry(
1902
1974
  entry: ArtifactManifestEntry,
1975
+ executor: concurrent.futures.Executor,
1903
1976
  api_key: str | None,
1904
1977
  cookies: dict | None,
1905
1978
  headers: dict | None,
@@ -1909,7 +1982,12 @@ class Artifact:
1909
1982
  _thread_local_api_settings.headers = headers
1910
1983
 
1911
1984
  try:
1912
- entry.download(root, skip_cache=skip_cache)
1985
+ entry.download(
1986
+ root,
1987
+ skip_cache=skip_cache,
1988
+ executor=executor,
1989
+ multipart=multipart,
1990
+ )
1913
1991
  except FileNotFoundError as e:
1914
1992
  if allow_missing_references:
1915
1993
  wandb.termwarn(str(e))
@@ -1920,14 +1998,14 @@ class Artifact:
1920
1998
  return
1921
1999
  download_logger.notify_downloaded()
1922
2000
 
1923
- download_entry = partial(
1924
- _download_entry,
1925
- api_key=_thread_local_api_settings.api_key,
1926
- cookies=_thread_local_api_settings.cookies,
1927
- headers=_thread_local_api_settings.headers,
1928
- )
1929
-
1930
2001
  with concurrent.futures.ThreadPoolExecutor(64) as executor:
2002
+ download_entry = partial(
2003
+ _download_entry,
2004
+ executor=executor,
2005
+ api_key=_thread_local_api_settings.api_key,
2006
+ cookies=_thread_local_api_settings.cookies,
2007
+ headers=_thread_local_api_settings.headers,
2008
+ )
1931
2009
  active_futures = set()
1932
2010
  has_next_page = True
1933
2011
  cursor = None
@@ -1963,8 +2041,9 @@ class Artifact:
1963
2041
  hours = int(delta // 3600)
1964
2042
  minutes = int((delta - hours * 3600) // 60)
1965
2043
  seconds = delta - hours * 3600 - minutes * 60
2044
+ speed = size / 1024 / 1024 / delta
1966
2045
  termlog(
1967
- f"Done. {hours}:{minutes}:{seconds:.1f}",
2046
+ f"Done. {hours}:{minutes}:{seconds:.1f} ({speed:.1f}MB/s)",
1968
2047
  prefix=False,
1969
2048
  )
1970
2049
  return FilePathStr(root)
@@ -2192,6 +2271,8 @@ class Artifact:
2192
2271
  If called on a linked artifact (i.e. a member of a portfolio collection): only the link is deleted, and the
2193
2272
  source artifact is unaffected.
2194
2273
 
2274
+ Use `artifact.unlink()` instead of `artifact.delete()` to remove a link between a source artifact and a linked artifact.
2275
+
2195
2276
  Args:
2196
2277
  delete_aliases: If set to `True`, deletes all aliases associated with the artifact.
2197
2278
  Otherwise, this raises an exception if the artifact has existing
@@ -2201,10 +2282,13 @@ class Artifact:
2201
2282
  Raises:
2202
2283
  ArtifactNotLoggedError: If the artifact is not logged.
2203
2284
  """
2204
- if self.collection.is_sequence():
2205
- self._delete(delete_aliases)
2206
- else:
2285
+ if self.is_link:
2286
+ wandb.termwarn(
2287
+ "Deleting a link artifact will only unlink the artifact from the source artifact and not delete the source artifact and the data of the source artifact."
2288
+ )
2207
2289
  self._unlink()
2290
+ else:
2291
+ self._delete(delete_aliases)
2208
2292
 
2209
2293
  @normalize_exceptions
2210
2294
  def _delete(self, delete_aliases: bool = False) -> None:
@@ -2232,7 +2316,9 @@ class Artifact:
2232
2316
  )
2233
2317
 
2234
2318
  @normalize_exceptions
2235
- def link(self, target_path: str, aliases: list[str] | None = None) -> None:
2319
+ def link(
2320
+ self, target_path: str, aliases: list[str] | None = None
2321
+ ) -> Artifact | None:
2236
2322
  """Link this artifact to a portfolio (a promoted collection of artifacts).
2237
2323
 
2238
2324
  Args:
@@ -2249,12 +2335,20 @@ class Artifact:
2249
2335
 
2250
2336
  Raises:
2251
2337
  ArtifactNotLoggedError: If the artifact is not logged.
2338
+
2339
+ Returns:
2340
+ The linked artifact if linking was successful, otherwise None.
2252
2341
  """
2342
+ if self.is_link:
2343
+ wandb.termwarn(
2344
+ "Linking to a link artifact will result in directly linking to the source artifact of that link artifact."
2345
+ )
2346
+
2253
2347
  singleton = wandb_setup._setup(start_service=False)
2254
2348
 
2255
2349
  if run := singleton.most_recent_active_run:
2256
2350
  # TODO: Deprecate and encourage explicit link_artifact().
2257
- run.link_artifact(self, target_path, aliases)
2351
+ return run.link_artifact(self, target_path, aliases)
2258
2352
 
2259
2353
  else:
2260
2354
  with wandb.init(
@@ -2263,7 +2357,9 @@ class Artifact:
2263
2357
  job_type="auto",
2264
2358
  settings=wandb.Settings(silent="true"),
2265
2359
  ) as run:
2266
- run.link_artifact(self, target_path, aliases)
2360
+ return run.link_artifact(self, target_path, aliases)
2361
+
2362
+ return None
2267
2363
 
2268
2364
  @ensure_logged
2269
2365
  def unlink(self) -> None:
@@ -2274,7 +2370,7 @@ class Artifact:
2274
2370
  ValueError: If the artifact is not linked, i.e. it is not a member of a portfolio collection.
2275
2371
  """
2276
2372
  # Fail early if this isn't a linked artifact to begin with
2277
- if self.collection.is_sequence():
2373
+ if not self.is_link:
2278
2374
  raise ValueError(
2279
2375
  f"Artifact {self.qualified_name!r} is not a linked artifact and cannot be unlinked. "
2280
2376
  f"To delete it, use {self.delete.__qualname__!r} instead."
@@ -2298,17 +2394,22 @@ class Artifact:
2298
2394
  """
2299
2395
  )
2300
2396
  assert self._client is not None
2301
- self._client.execute(
2302
- mutation,
2303
- variable_values={
2304
- "artifactID": self.id,
2305
- "artifactPortfolioID": self.collection.id,
2306
- },
2307
- )
2397
+ try:
2398
+ self._client.execute(
2399
+ mutation,
2400
+ variable_values={
2401
+ "artifactID": self.id,
2402
+ "artifactPortfolioID": self.collection.id,
2403
+ },
2404
+ )
2405
+ except CommError as e:
2406
+ raise CommError(
2407
+ f"You do not have permission to unlink the artifact {self.qualified_name}"
2408
+ ) from e
2308
2409
 
2309
2410
  @ensure_logged
2310
2411
  def used_by(self) -> list[Run]:
2311
- """Get a list of the runs that have used this artifact.
2412
+ """Get a list of the runs that have used this artifact and its linked artifacts.
2312
2413
 
2313
2414
  Returns:
2314
2415
  A list of `Run` objects.
@@ -2470,6 +2571,82 @@ class Artifact:
2470
2571
  return INHERIT
2471
2572
  return self._ttl_duration_seconds or DISABLED
2472
2573
 
2574
+ def _fetch_linked_artifacts(self) -> list[Artifact]:
2575
+ """Fetches all linked artifacts from the server."""
2576
+ if self.id is None:
2577
+ raise ValueError(
2578
+ "Unable to find any artifact memberships for artifact without an ID"
2579
+ )
2580
+ if self._client is None:
2581
+ raise ValueError("Client is not initialized")
2582
+ response = self._client.execute(
2583
+ gql_compat(FETCH_LINKED_ARTIFACTS_GQL),
2584
+ variable_values={"artifactID": self.id},
2585
+ )
2586
+ result = FetchLinkedArtifacts.model_validate(response)
2587
+
2588
+ if not (
2589
+ (artifact := result.artifact)
2590
+ and (memberships := artifact.artifact_memberships)
2591
+ and (membership_edges := memberships.edges)
2592
+ ):
2593
+ raise ValueError("Unable to find any artifact memberships for artifact")
2594
+
2595
+ linked_artifacts: deque[Artifact] = deque()
2596
+ linked_nodes = (
2597
+ node
2598
+ for edge in membership_edges
2599
+ if (
2600
+ (node := edge.node)
2601
+ and (col := node.artifact_collection)
2602
+ and (col.typename__ == LINKED_ARTIFACT_COLLECTION_TYPE)
2603
+ )
2604
+ )
2605
+ for node in linked_nodes:
2606
+ # Trick for O(1) membership check that maintains order
2607
+ alias_names = dict.fromkeys(a.alias for a in node.aliases)
2608
+ version = f"v{node.version_index}"
2609
+ aliases = (
2610
+ [*alias_names, version]
2611
+ if version not in alias_names
2612
+ else [*alias_names]
2613
+ )
2614
+
2615
+ if not (
2616
+ node
2617
+ and (col := node.artifact_collection)
2618
+ and (proj := col.project)
2619
+ and (proj.entity_name and proj.name)
2620
+ ):
2621
+ raise ValueError("Unable to fetch fields for linked artifact")
2622
+
2623
+ link_fields = _LinkArtifactFields(
2624
+ entity_name=proj.entity_name,
2625
+ project_name=proj.name,
2626
+ name=f"{col.name}:{version}",
2627
+ version=version,
2628
+ aliases=aliases,
2629
+ )
2630
+ link = self._create_linked_artifact_using_source_artifact(link_fields)
2631
+ linked_artifacts.append(link)
2632
+ return list(linked_artifacts)
2633
+
2634
+ def _create_linked_artifact_using_source_artifact(
2635
+ self,
2636
+ link_fields: _LinkArtifactFields,
2637
+ ) -> Artifact:
2638
+ """Copies the source artifact to a linked artifact."""
2639
+ linked_artifact = copy(self)
2640
+ linked_artifact._version = link_fields.version
2641
+ linked_artifact._aliases = link_fields.aliases
2642
+ linked_artifact._saved_aliases = copy(link_fields.aliases)
2643
+ linked_artifact._name = link_fields.name
2644
+ linked_artifact._entity = link_fields.entity_name
2645
+ linked_artifact._project = link_fields.project_name
2646
+ linked_artifact._is_link = link_fields.is_link
2647
+ linked_artifact._linked_artifacts = link_fields.linked_artifacts
2648
+ return linked_artifact
2649
+
2473
2650
 
2474
2651
  def _ttl_duration_seconds_from_gql(gql_ttl_duration_seconds: int | None) -> int | None:
2475
2652
  # If gql_ttl_duration_seconds is not positive, its indicating that TTL is DISABLED(-2)