wandb 0.22.1__py3-none-musllinux_1_2_aarch64.whl → 0.22.3__py3-none-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +7 -4
  3. wandb/_pydantic/__init__.py +8 -1
  4. wandb/_pydantic/base.py +54 -18
  5. wandb/_pydantic/field_types.py +8 -3
  6. wandb/_pydantic/pagination.py +46 -0
  7. wandb/_pydantic/utils.py +2 -2
  8. wandb/apis/public/api.py +24 -19
  9. wandb/apis/public/artifacts.py +259 -270
  10. wandb/apis/public/registries/_utils.py +40 -54
  11. wandb/apis/public/registries/registries_search.py +70 -85
  12. wandb/apis/public/registries/registry.py +173 -156
  13. wandb/apis/public/runs.py +27 -6
  14. wandb/apis/public/utils.py +43 -20
  15. wandb/automations/_generated/create_automation.py +2 -2
  16. wandb/automations/_generated/create_generic_webhook_integration.py +4 -4
  17. wandb/automations/_generated/delete_automation.py +2 -2
  18. wandb/automations/_generated/fragments.py +31 -52
  19. wandb/automations/_generated/generic_webhook_integrations_by_entity.py +3 -3
  20. wandb/automations/_generated/get_automations.py +3 -3
  21. wandb/automations/_generated/get_automations_by_entity.py +3 -3
  22. wandb/automations/_generated/input_types.py +9 -9
  23. wandb/automations/_generated/integrations_by_entity.py +3 -3
  24. wandb/automations/_generated/operations.py +6 -6
  25. wandb/automations/_generated/slack_integrations_by_entity.py +3 -3
  26. wandb/automations/_generated/update_automation.py +2 -2
  27. wandb/automations/_utils.py +3 -3
  28. wandb/automations/actions.py +3 -3
  29. wandb/automations/automations.py +6 -5
  30. wandb/bin/gpu_stats +0 -0
  31. wandb/bin/wandb-core +0 -0
  32. wandb/cli/beta.py +23 -3
  33. wandb/cli/beta_leet.py +75 -0
  34. wandb/cli/beta_sync.py +1 -1
  35. wandb/cli/cli.py +34 -7
  36. wandb/errors/term.py +8 -8
  37. wandb/jupyter.py +0 -51
  38. wandb/old/settings.py +6 -6
  39. wandb/proto/v3/wandb_api_pb2.py +86 -0
  40. wandb/proto/v3/wandb_server_pb2.py +38 -37
  41. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  42. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  43. wandb/proto/v4/wandb_api_pb2.py +37 -0
  44. wandb/proto/v4/wandb_server_pb2.py +38 -37
  45. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  46. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  47. wandb/proto/v5/wandb_api_pb2.py +38 -0
  48. wandb/proto/v5/wandb_server_pb2.py +38 -37
  49. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  50. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  51. wandb/proto/v6/wandb_api_pb2.py +48 -0
  52. wandb/proto/v6/wandb_server_pb2.py +38 -37
  53. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  54. wandb/proto/v6/wandb_sync_pb2.py +10 -6
  55. wandb/proto/wandb_api_pb2.py +18 -0
  56. wandb/proto/wandb_generate_proto.py +1 -0
  57. wandb/sdk/artifacts/_generated/__init__.py +96 -40
  58. wandb/sdk/artifacts/_generated/add_aliases.py +3 -3
  59. wandb/sdk/artifacts/_generated/add_artifact_collection_tags.py +26 -0
  60. wandb/sdk/artifacts/_generated/artifact_by_id.py +2 -2
  61. wandb/sdk/artifacts/_generated/artifact_by_name.py +3 -3
  62. wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +27 -8
  63. wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +27 -8
  64. wandb/sdk/artifacts/_generated/artifact_created_by.py +7 -20
  65. wandb/sdk/artifacts/_generated/artifact_file_urls.py +19 -6
  66. wandb/sdk/artifacts/_generated/artifact_membership_by_name.py +26 -0
  67. wandb/sdk/artifacts/_generated/artifact_type.py +5 -5
  68. wandb/sdk/artifacts/_generated/artifact_used_by.py +8 -17
  69. wandb/sdk/artifacts/_generated/artifact_version_files.py +19 -8
  70. wandb/sdk/artifacts/_generated/delete_aliases.py +3 -3
  71. wandb/sdk/artifacts/_generated/delete_artifact.py +4 -4
  72. wandb/sdk/artifacts/_generated/delete_artifact_collection_tags.py +23 -0
  73. wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +4 -4
  74. wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +4 -4
  75. wandb/sdk/artifacts/_generated/delete_registry.py +21 -0
  76. wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +8 -20
  77. wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +13 -35
  78. wandb/sdk/artifacts/_generated/fetch_org_info_from_entity.py +28 -0
  79. wandb/sdk/artifacts/_generated/fetch_registries.py +18 -8
  80. wandb/sdk/{projects → artifacts}/_generated/fetch_registry.py +4 -4
  81. wandb/sdk/artifacts/_generated/fragments.py +183 -333
  82. wandb/sdk/artifacts/_generated/input_types.py +133 -7
  83. wandb/sdk/artifacts/_generated/link_artifact.py +5 -5
  84. wandb/sdk/artifacts/_generated/operations.py +1053 -548
  85. wandb/sdk/artifacts/_generated/project_artifact_collection.py +9 -77
  86. wandb/sdk/artifacts/_generated/project_artifact_collections.py +21 -9
  87. wandb/sdk/artifacts/_generated/project_artifact_type.py +3 -3
  88. wandb/sdk/artifacts/_generated/project_artifact_types.py +19 -6
  89. wandb/sdk/artifacts/_generated/project_artifacts.py +7 -8
  90. wandb/sdk/artifacts/_generated/registry_collections.py +21 -9
  91. wandb/sdk/artifacts/_generated/registry_versions.py +20 -9
  92. wandb/sdk/artifacts/_generated/rename_registry.py +25 -0
  93. wandb/sdk/artifacts/_generated/run_input_artifacts.py +5 -9
  94. wandb/sdk/artifacts/_generated/run_output_artifacts.py +5 -9
  95. wandb/sdk/artifacts/_generated/type_info.py +2 -2
  96. wandb/sdk/artifacts/_generated/unlink_artifact.py +3 -5
  97. wandb/sdk/artifacts/_generated/update_artifact.py +3 -3
  98. wandb/sdk/artifacts/_generated/update_artifact_collection_type.py +28 -0
  99. wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +7 -16
  100. wandb/sdk/artifacts/_generated/update_artifact_sequence.py +7 -16
  101. wandb/sdk/artifacts/_generated/upsert_registry.py +25 -0
  102. wandb/sdk/artifacts/_gqlutils.py +170 -6
  103. wandb/sdk/artifacts/_models/__init__.py +9 -0
  104. wandb/sdk/artifacts/_models/artifact_collection.py +109 -0
  105. wandb/sdk/artifacts/_models/manifest.py +26 -0
  106. wandb/sdk/artifacts/_models/pagination.py +26 -0
  107. wandb/sdk/artifacts/_models/registry.py +100 -0
  108. wandb/sdk/artifacts/_validators.py +45 -27
  109. wandb/sdk/artifacts/artifact.py +249 -244
  110. wandb/sdk/artifacts/artifact_file_cache.py +1 -1
  111. wandb/sdk/artifacts/artifact_manifest.py +37 -32
  112. wandb/sdk/artifacts/artifact_manifest_entry.py +82 -133
  113. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +43 -61
  114. wandb/sdk/artifacts/storage_handler.py +18 -12
  115. wandb/sdk/artifacts/storage_handlers/azure_handler.py +11 -6
  116. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +17 -12
  117. wandb/sdk/artifacts/storage_handlers/http_handler.py +9 -4
  118. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -6
  119. wandb/sdk/artifacts/storage_handlers/multi_handler.py +5 -4
  120. wandb/sdk/artifacts/storage_handlers/s3_handler.py +10 -8
  121. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
  122. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +24 -21
  123. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +4 -2
  124. wandb/sdk/artifacts/storage_policies/_multipart.py +187 -0
  125. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +61 -242
  126. wandb/sdk/artifacts/storage_policy.py +25 -12
  127. wandb/sdk/data_types/image.py +2 -2
  128. wandb/sdk/data_types/object_3d.py +67 -2
  129. wandb/sdk/interface/interface.py +72 -64
  130. wandb/sdk/interface/interface_queue.py +27 -18
  131. wandb/sdk/interface/interface_shared.py +61 -23
  132. wandb/sdk/interface/interface_sock.py +9 -5
  133. wandb/sdk/internal/_generated/server_features_query.py +4 -4
  134. wandb/sdk/internal/job_builder.py +27 -10
  135. wandb/sdk/internal/sender.py +4 -1
  136. wandb/sdk/launch/create_job.py +2 -1
  137. wandb/sdk/launch/inputs/schema.py +13 -10
  138. wandb/sdk/lib/apikey.py +8 -12
  139. wandb/sdk/lib/asyncio_compat.py +1 -1
  140. wandb/sdk/lib/asyncio_manager.py +5 -5
  141. wandb/sdk/lib/console_capture.py +38 -30
  142. wandb/sdk/lib/progress.py +151 -125
  143. wandb/sdk/lib/retry.py +3 -2
  144. wandb/sdk/lib/service/service_connection.py +2 -2
  145. wandb/sdk/lib/wb_logging.py +2 -1
  146. wandb/sdk/mailbox/mailbox.py +1 -1
  147. wandb/sdk/wandb_init.py +11 -14
  148. wandb/sdk/wandb_run.py +14 -48
  149. wandb/sdk/wandb_settings.py +114 -30
  150. {wandb-0.22.1.dist-info → wandb-0.22.3.dist-info}/METADATA +2 -1
  151. {wandb-0.22.1.dist-info → wandb-0.22.3.dist-info}/RECORD +154 -146
  152. wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +0 -26
  153. wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +0 -36
  154. wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +0 -25
  155. wandb/sdk/artifacts/_generated/move_artifact_collection.py +0 -35
  156. wandb/sdk/projects/_generated/__init__.py +0 -26
  157. wandb/sdk/projects/_generated/delete_project.py +0 -22
  158. wandb/sdk/projects/_generated/enums.py +0 -4
  159. wandb/sdk/projects/_generated/fragments.py +0 -41
  160. wandb/sdk/projects/_generated/input_types.py +0 -13
  161. wandb/sdk/projects/_generated/operations.py +0 -88
  162. wandb/sdk/projects/_generated/rename_project.py +0 -27
  163. wandb/sdk/projects/_generated/upsert_registry_project.py +0 -27
  164. {wandb-0.22.1.dist-info → wandb-0.22.3.dist-info}/WHEEL +0 -0
  165. {wandb-0.22.1.dist-info → wandb-0.22.3.dist-info}/entry_points.txt +0 -0
  166. {wandb-0.22.1.dist-info → wandb-0.22.3.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,30 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from contextlib import suppress
4
+ from dataclasses import dataclass
3
5
  from functools import lru_cache
6
+ from typing import TYPE_CHECKING
4
7
 
5
- from wandb_gql import Client, gql
8
+ from wandb_gql import gql
6
9
 
7
- from ._generated import TYPE_INFO_GQL, TypeInfo, TypeInfoFragment
10
+ from wandb._iterutils import one
11
+ from wandb.errors import UnsupportedError
12
+ from wandb.proto.wandb_internal_pb2 import ServerFeature
13
+ from wandb.sdk.artifacts._generated.fetch_org_info_from_entity import (
14
+ FetchOrgInfoFromEntityEntity,
15
+ )
16
+ from wandb.sdk.internal._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
17
+
18
+ from ._generated import (
19
+ FETCH_ORG_INFO_FROM_ENTITY_GQL,
20
+ TYPE_INFO_GQL,
21
+ FetchOrgInfoFromEntity,
22
+ TypeInfo,
23
+ TypeInfoFragment,
24
+ )
25
+
26
+ if TYPE_CHECKING:
27
+ from wandb.apis.public import RetryingClient
8
28
 
9
29
  OMITTABLE_ARTIFACT_FIELDS = frozenset(
10
30
  {
@@ -18,13 +38,58 @@ OMITTABLE_ARTIFACT_FIELDS = frozenset(
18
38
 
19
39
 
20
40
@lru_cache(maxsize=16)
def type_info(client: RetryingClient, typename: str) -> TypeInfoFragment | None:
    """Fetch the introspected GraphQL type info for `typename`.

    Results are memoized (up to 16 entries) keyed on `(client, typename)`.
    Returns the parsed `type` from the response, which may be `None`.
    """
    variables = {"name": typename}
    response = client.execute(gql(TYPE_INFO_GQL), variable_values=variables)
    parsed = TypeInfo.model_validate(response)
    return parsed.type
25
45
 
26
46
 
27
- def supports_enable_tracking_var(client: Client) -> bool:
47
@lru_cache(maxsize=16)
def org_info_from_entity(
    client: RetryingClient, entity: str
) -> FetchOrgInfoFromEntityEntity | None:
    """Look up the organization info associated with the entity named `entity`.

    Results are memoized (up to 16 entries) keyed on `(client, entity)`.
    Returns the parsed `entity` from the response, which may be `None`.
    """
    response = client.execute(
        gql(FETCH_ORG_INFO_FROM_ENTITY_GQL),
        variable_values={"entity": entity},
    )
    return FetchOrgInfoFromEntity.model_validate(response).entity
55
+
56
+
57
@lru_cache(maxsize=16)
def server_features(client: RetryingClient) -> dict[str, bool]:
    """Map each server feature name to whether it is enabled: `{name: is_enabled}`.

    Results are memoized per client instance. An empty mapping is returned
    when the server rejects the `ServerInfo.features` field, or reports no
    server info / features.
    """
    try:
        response = client.execute(gql(SERVER_FEATURES_QUERY_GQL))
    except Exception as e:
        # The `gql` client raises a bare `Exception`, so the only way to tell
        # that the server lacks the `features` field is by the error's text.
        if 'Cannot query field "features" on type "ServerInfo".' not in str(e):
            raise
        return {}

    server_info = ServerFeaturesQuery.model_validate(response).server_info
    if not server_info or not server_info.features:
        return {}
    return {
        feature.name: feature.is_enabled
        for feature in server_info.features
        if feature
    }
76
+
77
+
78
def server_supports(client: RetryingClient, feature: str | int) -> bool:
    """Report whether the server behind `client` has `feature` enabled.

    Good to use for features that have a fallback mechanism for older servers.
    `feature` may be a `ServerFeature` protobuf enum value (int) or a feature
    name (str). Features the server does not report are treated as unsupported.
    """
    # Lookups are keyed by feature *name* rather than enum value because:
    # - the server identifies features by name, not by client-side enum value
    # - the client-side enum may lag behind the server-side list of features
    if isinstance(feature, int):
        feature_name = ServerFeature.Name(feature)
    else:
        feature_name = feature
    enabled = server_features(client).get(feature_name)
    return enabled or False
90
+
91
+
92
+ def supports_enable_tracking_var(client: RetryingClient) -> bool:
28
93
  """Returns True if the server supports the `enableTracking` variable for the `Project.artifact(...)` field."""
29
94
  typ = type_info(client, "Project")
30
95
  if (
@@ -36,12 +101,111 @@ def supports_enable_tracking_var(client: Client) -> bool:
36
101
  return False
37
102
 
38
103
 
39
def allowed_fields(client: RetryingClient, typename: str) -> set[str]:
    """Return the field names the server defines on the GraphQL type `typename`.

    Empty when no type info, or no fields, are returned for that type.
    """
    typ = type_info(client, typename)
    if not typ or not typ.fields:
        return set()
    return {field_def.name for field_def in typ.fields}
43
108
 
44
109
 
45
def omit_artifact_fields(client: RetryingClient) -> set[str]:
    """Return the omittable `Artifact` fields that this server does not define.

    Callers strip these from GraphQL requests for server compatibility.
    """
    supported = allowed_fields(client, "Artifact")
    return {name for name in OMITTABLE_ARTIFACT_FIELDS if name not in supported}
113
+
114
+
115
@dataclass(frozen=True)
class OrgInfo:
    """An organization's display name paired with its org entity name.

    Supports `name in org_info`, which is True when `name` equals either one.
    """

    org_name: str
    entity_name: str

    def __contains__(self, other: str) -> bool:
        known_names = frozenset((self.org_name, self.entity_name))
        return other in known_names
122
+
123
+
124
def resolve_org_entity_name(
    client: RetryingClient,
    non_org_entity: str,
    org_or_entity: str | None = None,
) -> str:
    """Resolve the name of the organization entity to use for Registry operations.

    Args:
        client: The client used to query the server.
        non_org_entity: A (team or personal) entity name whose associated
            organization(s) are looked up.
        org_or_entity: Optional organization identifier: either an org's display
            name or its org entity name. If given, the resolved org must match it.

    Returns:
        The name of the resolved org entity.

    Raises:
        ValueError: If `non_org_entity` is empty, or if no single matching
            organization can be determined.
        UnsupportedError: If the server cannot fetch `Organization.orgEntity`
            and `org_or_entity` was not given.

    Note:
        If the server cannot fetch `Organization.orgEntity`, a non-empty
        `org_or_entity` is returned as-is, without verification.
    """
    if not non_org_entity:
        raise ValueError("Entity name is required to resolve org entity name.")

    if "orgEntity" not in allowed_fields(client, "Organization"):
        if org_or_entity:
            # Server doesn't support fetching orgEntity to match against,
            # so assume the organization as provided is already correct.
            return org_or_entity

        raise UnsupportedError(
            "Fetching Registry artifacts without inputting an organization "
            "is unavailable for your server version. "
            "Please upgrade your server to 0.50.0 or later."
        )

    # Otherwise, fetch candidate orgs to verify/identify the correct orgEntity name.
    entity = org_info_from_entity(client, non_org_entity)

    # ----------------------------------------------------------------------------
    # Team entity: there should be a single organization under the team/org entity.
    if entity and (org := entity.organization) and (org_entity := org.org_entity):
        org_info = OrgInfo(org_name=org.name, entity_name=org_entity.name)
        # If given, `org_or_entity` must match the org's display or entity name.
        if (not org_or_entity) or (org_or_entity in org_info):
            return org_entity.name

    # ----------------------------------------------------------------------------
    # Personal entity: the user may belong to multiple organizations.
    if entity and (user := entity.user) and (orgs := user.organizations):
        # Only orgs that expose an org entity are usable candidates.
        org_infos = [
            OrgInfo(org_name=org.name, entity_name=org_entity.name)
            for org in orgs
            if (org_entity := org.org_entity)
        ]
        if org_or_entity:
            with suppress(StopIteration):
                return next(
                    info.entity_name for info in org_infos if (org_or_entity in info)
                )

            if len(org_infos) == 1:
                raise ValueError(
                    f"Expecting the organization name or entity name to match {org_infos[0].org_name!r} "
                    f"and cannot be linked/fetched with {org_or_entity!r}. "
                    "Please update the target path with the correct organization name."
                )
            if len(org_infos) > 1:
                raise ValueError(
                    "Personal entity belongs to multiple organizations "
                    f"and cannot be linked/fetched with {org_or_entity!r}. "
                    "Please update the target path with the correct organization name "
                    "or use a team entity in the entity settings."
                )
            # No candidate org has an org entity: fall through to the generic error.

        else:
            # If no input organization provided, error if entity belongs to:
            # - multiple orgs, because we cannot determine which one to use.
            # - no orgs, because there's nothing to use.
            return one(
                (org.entity_name for org in org_infos),
                too_short=ValueError(
                    f"Unable to resolve an organization associated with personal entity: {non_org_entity!r}. "
                    "This could be because its a personal entity that doesn't belong to any organizations. "
                    "Please specify the organization in the Registry path or use a team entity in the entity settings."
                ),
                too_long=ValueError(
                    f"Personal entity {non_org_entity!r} belongs to multiple organizations "
                    "and cannot be used without specifying the organization name. "
                    "Please specify the organization in the Registry path or use a team entity in the entity settings."
                ),
            )

    raise ValueError(f"Unable to find organization for entity {non_org_entity!r}.")
@@ -2,3 +2,12 @@
2
2
 
3
3
  Excludes GraphQL-generated classes.
4
4
  """
5
+
6
+ __all__ = [
7
+ "ArtifactsBase",
8
+ "RegistryData",
9
+ "ArtifactCollectionData",
10
+ ]
11
+ from .artifact_collection import ArtifactCollectionData
12
+ from .base_model import ArtifactsBase
13
+ from .registry import RegistryData
@@ -0,0 +1,109 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, List, Optional, Tuple
4
+
5
+ from pydantic import ConfigDict, Field
6
+ from typing_extensions import Self
7
+
8
+ from wandb._pydantic import field_validator
9
+ from wandb.sdk.artifacts._generated import ArtifactCollectionFragment
10
+ from wandb.sdk.artifacts._validators import (
11
+ SOURCE_COLLECTION_TYPENAME,
12
+ validate_artifact_name,
13
+ validate_tags,
14
+ )
15
+
16
+ from .base_model import ArtifactsBase
17
+
18
+
19
class ArtifactCollectionData(ArtifactsBase):
    """Transport-free model for local `ArtifactCollection` data.

    For now, this is separated from the public `ArtifactCollection` model
    to more easily ensure continuity in the public `ArtifactCollection` API.

    Note:
        In a future change, consider making _this_ the public `ArtifactCollection` instead, i.e.:
        - Replace the _existing_ `ArtifactCollection` class with this one
        - Rename _this_ class to `ArtifactCollection`
        Note that this would be a breaking change.
    """

    model_config = ConfigDict(
        str_min_length=1,  # Strings cannot be empty
    )

    typename__: str = Field(alias="__typename", frozen=True, repr=False)
    """The GraphQL `__typename` for this object."""

    id: str = Field(frozen=True, repr=False)
    """The encoded GraphQL ID for this object."""

    name: str
    """The name of this collection."""

    type: str
    """The artifact type of this collection."""

    description: Optional[str] = None
    """The description, if any, for this collection."""

    created_at: str = Field(frozen=True)
    """When this collection was created."""

    project: str = Field(frozen=True)
    """The name of this collection's project."""

    entity: str = Field(frozen=True)
    """The name of the entity that owns this collection's project."""

    aliases: Optional[Tuple[str, ...]] = Field(default=None, frozen=True)
    """All aliases assigned to artifact versions within this collection.

    Note:
        `None` should signal that aliases haven't been (or couldn't be) fetched and parsed yet,
        for any reason, NOT the actual absence of aliases in this collection.
    """

    tags: List[str] = Field(default_factory=list)
    """The tags assigned to this collection.

    Note that this is distinct from any tags assigned to individual artifact versions within this collection.
    """

    @field_validator("name", mode="plain")
    def _validate_name(cls, v: str) -> str:
        return validate_artifact_name(v)

    @field_validator("tags", mode="plain")
    def _validate_tags(cls, v: Any) -> list[str]:
        """Ensure tags is a validated, deduped list of (str) tag names."""
        return validate_tags(v)

    @property
    def is_sequence(self) -> bool:
        """Returns True if the artifact collection is an ArtifactSequence (source collection)."""
        return self.typename__ == SOURCE_COLLECTION_TYPENAME

    @classmethod
    def from_fragment(cls, obj: ArtifactCollectionFragment) -> Self:
        """Instantiate this type from a GraphQL fragment.

        Raises:
            ValueError: If the fragment has no project info.
        """
        if obj.project is None:
            raise ValueError(f"Missing project info in {type(obj)!r} data")

        # Per the `aliases` field contract, absent alias data must stay `None`
        # ("not fetched"), which is distinct from an empty result ("no aliases").
        aliases = (
            None
            if obj.aliases is None
            else [e.node.alias for e in obj.aliases.edges if e.node]
        )

        return cls(
            typename__=obj.typename__,
            id=obj.id,
            name=obj.name,
            type=obj.type.name,
            description=obj.description,
            created_at=obj.created_at,
            project=obj.project.name,
            entity=obj.project.entity.name,
            tags=[e.node.name for e in obj.tags.edges if e.node],
            aliases=aliases,
        )
@@ -0,0 +1,26 @@
1
+ from typing import Any, Dict, Literal, final
2
+
3
+ from wandb._pydantic import field_validator, to_camel
4
+ from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
5
+
6
+ from .base_model import ArtifactsBase
7
+
8
+
9
@final
class ArtifactManifestV1Data(ArtifactsBase, alias_generator=to_camel):
    """Validated representation of a version-1 artifact manifest."""

    version: Literal[1]

    contents: Dict[str, ArtifactManifestEntry]

    storage_policy: str
    storage_policy_config: Dict[str, Any]

    @field_validator("contents", mode="before")
    def _validate_entries(cls, v: Any) -> Any:
        # Each entry is keyed by its `path`, but historically the serialized
        # entry objects omit the `path` field itself. Copy each key back into
        # its entry so Pydantic can build the ArtifactManifestEntry objects.
        restored = {}
        for entry_path, raw_entry in v.items():
            restored[entry_path] = {**dict(raw_entry), "path": entry_path}
        return restored
@@ -0,0 +1,26 @@
1
+ """Artifacts-specific data models for handling paginated results from GraphQL queries."""
2
+
3
+ from wandb._pydantic import Connection, ConnectionWithTotal
4
+
5
+ from .._generated.fragments import (
6
+ ArtifactCollectionFragment,
7
+ ArtifactFragment,
8
+ ArtifactMembershipFragment,
9
+ ArtifactTypeFragment,
10
+ FileFragment,
11
+ FileWithUrlFragment,
12
+ RegistryCollectionFragment,
13
+ RegistryFragment,
14
+ )
15
+
16
# Paginated connection types built on the generic `Connection` model.
ArtifactTypeConnection = Connection[ArtifactTypeFragment]
ArtifactMembershipConnection = Connection[ArtifactMembershipFragment]
FileWithUrlConnection = Connection[FileWithUrlFragment]
ArtifactFileConnection = Connection[FileFragment]
RegistryConnection = Connection[RegistryFragment]

# Paginated connection types built on the generic `ConnectionWithTotal` model.
ArtifactCollectionConnection = ConnectionWithTotal[ArtifactCollectionFragment]
RunArtifactConnection = ConnectionWithTotal[ArtifactFragment]
RegistryCollectionConnection = ConnectionWithTotal[RegistryCollectionFragment]
@@ -0,0 +1,100 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Optional
4
+
5
+ from pydantic import ConfigDict, Field
6
+ from typing_extensions import Self
7
+
8
+ from wandb._pydantic import GQLId, field_validator
9
+ from wandb._strutils import nameof
10
+ from wandb.apis.public.registries._freezable_list import AddOnlyArtifactTypesList
11
+ from wandb.apis.public.registries._utils import Visibility
12
+ from wandb.sdk.artifacts._generated import RegistryFragment
13
+ from wandb.sdk.artifacts._generated.fragments import RegistryFragmentArtifactTypes
14
+ from wandb.sdk.artifacts._validators import REGISTRY_PREFIX, remove_registry_prefix
15
+
16
+ from .base_model import ArtifactsBase
17
+
18
+
19
class RegistryData(ArtifactsBase):
    """Transport-free model holding local `Registry` data.

    Kept separate from the public `Registry` class for now, so that the
    public `Registry` API can stay continuous while this model evolves.
    """

    model_config = ConfigDict(
        str_min_length=1,  # Strings cannot be empty
    )

    id: GQLId = Field(frozen=True)
    """The unique, encoded ID for this registry."""

    created_at: str = Field(frozen=True)
    """When this registry was created."""

    updated_at: Optional[str] = Field(frozen=True)
    """When this registry was last updated."""

    organization: str = Field(frozen=True)
    """The organization of the registry."""

    entity: str = Field(frozen=True)
    """The organization entity of the registry."""

    name: str
    """The name of the registry without the `wandb-registry-` project prefix."""

    description: Optional[str] = None
    """The description, if any, of the registry."""

    allow_all_artifact_types: bool
    """Whether all artifact types are allowed in the registry."""

    artifact_types: AddOnlyArtifactTypesList = Field(
        default_factory=AddOnlyArtifactTypesList
    )
    """The artifact types allowed in the registry.

    The meaning of this list depends on the value of `allow_all_artifact_types`:
    - If True: `artifact_types` are the previously-saved or currently-used types in the registry.
    - If False: `artifact_types` are the only allowed artifact types in the registry.
    """

    visibility: Visibility = Field(alias="access")
    """The visibility of the registry."""

    @property
    def full_name(self) -> str:
        """The registry's full project name, i.e. `name` with the `wandb-registry-` prefix."""
        return REGISTRY_PREFIX + self.name

    @field_validator("artifact_types", mode="plain")
    def _validate_artifact_types(cls, v: Any) -> AddOnlyArtifactTypesList:
        """Coerce the incoming value into an `AddOnlyArtifactTypesList`."""
        if not isinstance(v, RegistryFragmentArtifactTypes):
            # Anything other than a GQL connection is treated as an iterable of names.
            return AddOnlyArtifactTypesList(v)

        # GQL connection object: collect the names of its non-null nodes.
        type_names = (edge.node.name for edge in v.edges if edge.node)
        return AddOnlyArtifactTypesList(type_names)

    @classmethod
    def from_fragment(cls, obj: RegistryFragment) -> Self:
        """Build an instance from a `RegistryFragment` GraphQL result.

        Raises:
            ValueError: If the fragment's entity carries no organization info.
        """
        org = obj.entity.organization
        if not org:
            raise ValueError(
                f"Unable to parse registry organization from {nameof(type(obj))!r} object"
            )

        return cls(
            id=obj.id,
            created_at=obj.created_at,
            updated_at=obj.updated_at,
            organization=org.name,
            entity=obj.entity.name,
            name=remove_registry_prefix(obj.name),
            description=obj.description,
            allow_all_artifact_types=obj.allow_all_artifact_types,
            artifact_types=obj.artifact_types,
            visibility=obj.access,
        )
@@ -6,7 +6,17 @@ import json
6
6
  import re
7
7
  from dataclasses import dataclass, field, replace
8
8
  from functools import wraps
9
- from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, TypeVar, cast
9
+ from typing import (
10
+ TYPE_CHECKING,
11
+ Any,
12
+ Callable,
13
+ Dict,
14
+ Iterable,
15
+ Literal,
16
+ Optional,
17
+ TypeVar,
18
+ cast,
19
+ )
10
20
 
11
21
  from pydantic.dataclasses import dataclass as pydantic_dataclass
12
22
  from typing_extensions import Concatenate, ParamSpec, Self
@@ -35,8 +45,8 @@ MAX_ARTIFACT_METADATA_KEYS: Final[int] = 100
35
45
  ARTIFACT_NAME_MAXLEN: Final[int] = 128
36
46
  ARTIFACT_NAME_INVALID_CHARS: Final[frozenset[str]] = frozenset({"/"})
37
47
 
38
- LINKED_ARTIFACT_COLLECTION_TYPE: Final[str] = gql_typename(ArtifactPortfolioTypeFields)
39
- SOURCE_ARTIFACT_COLLECTION_TYPE: Final[str] = gql_typename(ArtifactSequenceTypeFields)
48
+ LINKED_COLLECTION_TYPENAME: Final[str] = gql_typename(ArtifactPortfolioTypeFields)
49
+ SOURCE_COLLECTION_TYPENAME: Final[str] = gql_typename(ArtifactSequenceTypeFields)
40
50
 
41
51
 
42
52
  @dataclass
@@ -85,11 +95,14 @@ def validate_artifact_name(name: str) -> str:
85
95
  return name
86
96
 
87
97
 
88
- INVALID_URL_CHARACTERS = ("/", "\\", "#", "?", "%", ":", "\r", "\n")
98
+ INVALID_URL_CHARACTERS: Final[frozenset[str]] = frozenset("/\\#?%:\r\n")
89
99
 
100
+ PROJECT_NAME_MAXLEN: Final[int] = 128
101
+ REGISTRY_NAME_MAXLEN: Final[int] = PROJECT_NAME_MAXLEN - len(REGISTRY_PREFIX)
90
102
 
91
- def validate_project_name(name: str) -> None:
92
- """Validates a project name according to W&B rules.
103
+
104
+ def validate_project_name(name: str) -> str:
105
+ """Validates a project name according to W&B rules, returning the original name if successful.
93
106
 
94
107
  Args:
95
108
  name: The project name string.
@@ -97,18 +110,16 @@ def validate_project_name(name: str) -> None:
97
110
  Raises:
98
111
  ValueError: If the name is invalid (too long or contains invalid characters).
99
112
  """
100
- max_len = 128
101
-
102
113
  if not name:
103
114
  raise ValueError("Project name cannot be empty")
104
115
  if not (registry_name := removeprefix(name, REGISTRY_PREFIX)):
105
116
  raise ValueError("Registry name cannot be empty")
106
117
 
107
- if len(name) > max_len:
118
+ if len(name) > PROJECT_NAME_MAXLEN:
108
119
  if registry_name != name:
109
- msg = f"Invalid registry name {registry_name!r}, must be {max_len - len(REGISTRY_PREFIX)} characters or less"
120
+ msg = f"Invalid registry name {registry_name!r}, must be {REGISTRY_NAME_MAXLEN} characters or less"
110
121
  else:
111
- msg = f"Invalid project name {name!r}, must be {max_len} characters or less"
122
+ msg = f"Invalid project name {name!r}, must be {PROJECT_NAME_MAXLEN} characters or less"
112
123
  raise ValueError(msg)
113
124
 
114
125
  # Find the first occurrence of any invalid character
@@ -118,6 +129,10 @@ def validate_project_name(name: str) -> None:
118
129
  raise ValueError(
119
130
  f"Invalid project/registry name {error_name!r}, cannot contain characters: {invalid_chars_repr!s}"
120
131
  )
132
+ return name
133
+
134
+
135
+ ALIAS_NAME_INVALID_CHARS: Final[frozenset[str]] = frozenset("/:")
121
136
 
122
137
 
123
138
  def validate_aliases(aliases: Collection[str] | str) -> list[str]:
@@ -127,36 +142,39 @@ def validate_aliases(aliases: Collection[str] | str) -> list[str]:
127
142
  ValueError: If any of the aliases contain invalid characters.
128
143
  """
129
144
  aliases_list = always_list(aliases)
130
-
131
- invalid_chars = ("/", ":")
132
- if any(char in alias for alias in aliases_list for char in invalid_chars):
145
+ if any(ALIAS_NAME_INVALID_CHARS.intersection(alias) for alias in aliases_list):
146
+ invalid_repr = ", ".join(sorted(map(repr, ALIAS_NAME_INVALID_CHARS)))
133
147
  raise ValueError(
134
- f"Aliases must not contain any of the following characters: {', '.join(invalid_chars)}"
148
+ f"Aliases must not contain any of the following characters: {invalid_repr}"
135
149
  )
136
150
  return aliases_list
137
151
 
138
152
 
139
- def validate_artifact_types_list(artifact_types: list[str]) -> list[str]:
140
- """Return True if the artifact types list is valid, False otherwise."""
141
- artifact_types = always_list(artifact_types)
142
- invalid_chars = ("/", ":")
153
ARTIFACT_TYPE_NAME_MAXLEN: Final[int] = 128
ARTIFACT_TYPE_NAME_INVALID_CHARS: Final[frozenset[str]] = frozenset("/:")


def validate_artifact_types(artifact_types: Collection[str] | str) -> list[str]:
    """Validate the artifact type names and return them as a list.

    Args:
        artifact_types: A single artifact type name, or a collection of names.
            It is converted to a list via `always_list`.

    Returns:
        The list of artifact type names.

    Raises:
        ValueError: If any name is longer than `ARTIFACT_TYPE_NAME_MAXLEN`
            characters or contains any of `ARTIFACT_TYPE_NAME_INVALID_CHARS`.
    """
    artifact_types_list = always_list(artifact_types)
    if any(
        len(name) > ARTIFACT_TYPE_NAME_MAXLEN
        or ARTIFACT_TYPE_NAME_INVALID_CHARS.intersection(name)
        for name in artifact_types_list
    ):
        invalid_repr = ", ".join(sorted(map(repr, ARTIFACT_TYPE_NAME_INVALID_CHARS)))
        # NOTE: the trailing space in the first part separates the two
        # implicitly-concatenated f-strings (previously "...':'and must be...").
        raise ValueError(
            f"Artifact types must not contain any of the following characters: {invalid_repr} "
            f"and must be less than or equal to {ARTIFACT_TYPE_NAME_MAXLEN!r} characters"
        )
    return artifact_types_list
153
171
 
154
172
 
155
173
  TAG_REGEX: re.Pattern[str] = re.compile(r"^[-\w]+( +[-\w]+)*$")
156
174
  """Regex pattern for valid tag names."""
157
175
 
158
176
 
159
- def validate_tags(tags: Collection[str] | str) -> list[str]:
177
+ def validate_tags(tags: Iterable[str] | str) -> list[str]:
160
178
  """Validate the artifact tag names and return them as a deduped list.
161
179
 
162
180
  In the case of duplicates, only keep the first tag, and otherwise maintain the order of appearance.
@@ -204,7 +222,7 @@ def validate_artifact_type(typ: str, name: str) -> str:
204
222
  return typ
205
223
 
206
224
 
207
- def validate_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]:
225
+ def validate_metadata(metadata: dict[str, Any] | str | None) -> dict[str, Any]:
208
226
  """Validate the artifact metadata and return it as a dict."""
209
227
  if metadata is None:
210
228
  return {}