wandb 0.22.2__py3-none-macosx_12_0_arm64.whl → 0.22.3__py3-none-macosx_12_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +2 -2
- wandb/_pydantic/__init__.py +8 -1
- wandb/_pydantic/base.py +54 -18
- wandb/_pydantic/field_types.py +8 -3
- wandb/_pydantic/pagination.py +46 -0
- wandb/_pydantic/utils.py +2 -2
- wandb/apis/public/api.py +24 -19
- wandb/apis/public/artifacts.py +259 -270
- wandb/apis/public/registries/_utils.py +40 -54
- wandb/apis/public/registries/registries_search.py +70 -85
- wandb/apis/public/registries/registry.py +173 -156
- wandb/apis/public/runs.py +27 -6
- wandb/apis/public/utils.py +43 -20
- wandb/automations/_generated/create_automation.py +2 -2
- wandb/automations/_generated/create_generic_webhook_integration.py +4 -4
- wandb/automations/_generated/delete_automation.py +2 -2
- wandb/automations/_generated/fragments.py +31 -52
- wandb/automations/_generated/generic_webhook_integrations_by_entity.py +3 -3
- wandb/automations/_generated/get_automations.py +3 -3
- wandb/automations/_generated/get_automations_by_entity.py +3 -3
- wandb/automations/_generated/input_types.py +9 -9
- wandb/automations/_generated/integrations_by_entity.py +3 -3
- wandb/automations/_generated/operations.py +6 -6
- wandb/automations/_generated/slack_integrations_by_entity.py +3 -3
- wandb/automations/_generated/update_automation.py +2 -2
- wandb/automations/_utils.py +3 -3
- wandb/automations/actions.py +3 -3
- wandb/automations/automations.py +6 -5
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +8 -2
- wandb/cli/beta_leet.py +2 -1
- wandb/cli/beta_sync.py +1 -1
- wandb/errors/term.py +8 -8
- wandb/jupyter.py +0 -51
- wandb/old/settings.py +6 -6
- wandb/proto/v3/wandb_internal_pb2.py +351 -352
- wandb/proto/v3/wandb_server_pb2.py +38 -37
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_sync_pb2.py +19 -6
- wandb/proto/v4/wandb_internal_pb2.py +351 -352
- wandb/proto/v4/wandb_server_pb2.py +38 -37
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_sync_pb2.py +10 -6
- wandb/proto/v5/wandb_internal_pb2.py +351 -352
- wandb/proto/v5/wandb_server_pb2.py +38 -37
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_sync_pb2.py +10 -6
- wandb/proto/v6/wandb_internal_pb2.py +351 -352
- wandb/proto/v6/wandb_server_pb2.py +38 -37
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_sync_pb2.py +10 -6
- wandb/sdk/artifacts/_generated/__init__.py +96 -40
- wandb/sdk/artifacts/_generated/add_aliases.py +3 -3
- wandb/sdk/artifacts/_generated/add_artifact_collection_tags.py +26 -0
- wandb/sdk/artifacts/_generated/artifact_by_id.py +2 -2
- wandb/sdk/artifacts/_generated/artifact_by_name.py +3 -3
- wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +27 -8
- wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +27 -8
- wandb/sdk/artifacts/_generated/artifact_created_by.py +7 -20
- wandb/sdk/artifacts/_generated/artifact_file_urls.py +19 -6
- wandb/sdk/artifacts/_generated/artifact_membership_by_name.py +26 -0
- wandb/sdk/artifacts/_generated/artifact_type.py +5 -5
- wandb/sdk/artifacts/_generated/artifact_used_by.py +8 -17
- wandb/sdk/artifacts/_generated/artifact_version_files.py +19 -8
- wandb/sdk/artifacts/_generated/delete_aliases.py +3 -3
- wandb/sdk/artifacts/_generated/delete_artifact.py +4 -4
- wandb/sdk/artifacts/_generated/delete_artifact_collection_tags.py +23 -0
- wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +4 -4
- wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +4 -4
- wandb/sdk/artifacts/_generated/delete_registry.py +21 -0
- wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +8 -20
- wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +13 -35
- wandb/sdk/artifacts/_generated/fetch_org_info_from_entity.py +28 -0
- wandb/sdk/artifacts/_generated/fetch_registries.py +18 -8
- wandb/sdk/{projects → artifacts}/_generated/fetch_registry.py +4 -4
- wandb/sdk/artifacts/_generated/fragments.py +183 -333
- wandb/sdk/artifacts/_generated/input_types.py +133 -7
- wandb/sdk/artifacts/_generated/link_artifact.py +5 -5
- wandb/sdk/artifacts/_generated/operations.py +1053 -548
- wandb/sdk/artifacts/_generated/project_artifact_collection.py +9 -77
- wandb/sdk/artifacts/_generated/project_artifact_collections.py +21 -9
- wandb/sdk/artifacts/_generated/project_artifact_type.py +3 -3
- wandb/sdk/artifacts/_generated/project_artifact_types.py +19 -6
- wandb/sdk/artifacts/_generated/project_artifacts.py +7 -8
- wandb/sdk/artifacts/_generated/registry_collections.py +21 -9
- wandb/sdk/artifacts/_generated/registry_versions.py +20 -9
- wandb/sdk/artifacts/_generated/rename_registry.py +25 -0
- wandb/sdk/artifacts/_generated/run_input_artifacts.py +5 -9
- wandb/sdk/artifacts/_generated/run_output_artifacts.py +5 -9
- wandb/sdk/artifacts/_generated/type_info.py +2 -2
- wandb/sdk/artifacts/_generated/unlink_artifact.py +3 -5
- wandb/sdk/artifacts/_generated/update_artifact.py +3 -3
- wandb/sdk/artifacts/_generated/update_artifact_collection_type.py +28 -0
- wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +7 -16
- wandb/sdk/artifacts/_generated/update_artifact_sequence.py +7 -16
- wandb/sdk/artifacts/_generated/upsert_registry.py +25 -0
- wandb/sdk/artifacts/_gqlutils.py +170 -6
- wandb/sdk/artifacts/_models/__init__.py +9 -0
- wandb/sdk/artifacts/_models/artifact_collection.py +109 -0
- wandb/sdk/artifacts/_models/manifest.py +26 -0
- wandb/sdk/artifacts/_models/pagination.py +26 -0
- wandb/sdk/artifacts/_models/registry.py +100 -0
- wandb/sdk/artifacts/_validators.py +45 -27
- wandb/sdk/artifacts/artifact.py +220 -215
- wandb/sdk/artifacts/artifact_file_cache.py +1 -1
- wandb/sdk/artifacts/artifact_manifest.py +37 -32
- wandb/sdk/artifacts/artifact_manifest_entry.py +80 -125
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +43 -61
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +8 -6
- wandb/sdk/data_types/image.py +2 -2
- wandb/sdk/interface/interface.py +72 -64
- wandb/sdk/interface/interface_queue.py +27 -18
- wandb/sdk/interface/interface_shared.py +61 -23
- wandb/sdk/interface/interface_sock.py +9 -5
- wandb/sdk/internal/_generated/server_features_query.py +4 -4
- wandb/sdk/launch/inputs/schema.py +13 -10
- wandb/sdk/lib/apikey.py +8 -12
- wandb/sdk/lib/asyncio_compat.py +1 -1
- wandb/sdk/lib/asyncio_manager.py +5 -5
- wandb/sdk/lib/console_capture.py +38 -30
- wandb/sdk/lib/progress.py +159 -64
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/service/service_connection.py +2 -2
- wandb/sdk/lib/wb_logging.py +2 -1
- wandb/sdk/mailbox/mailbox.py +1 -1
- wandb/sdk/wandb_init.py +10 -13
- wandb/sdk/wandb_run.py +9 -46
- wandb/sdk/wandb_settings.py +102 -19
- {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/METADATA +2 -1
- {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/RECORD +135 -134
- wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +0 -26
- wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +0 -36
- wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +0 -25
- wandb/sdk/artifacts/_generated/move_artifact_collection.py +0 -35
- wandb/sdk/projects/_generated/__init__.py +0 -26
- wandb/sdk/projects/_generated/delete_project.py +0 -22
- wandb/sdk/projects/_generated/enums.py +0 -4
- wandb/sdk/projects/_generated/fragments.py +0 -41
- wandb/sdk/projects/_generated/input_types.py +0 -13
- wandb/sdk/projects/_generated/operations.py +0 -88
- wandb/sdk/projects/_generated/rename_project.py +0 -27
- wandb/sdk/projects/_generated/upsert_registry_project.py +0 -27
- {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/WHEEL +0 -0
- {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/entry_points.txt +0 -0
- {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/artifacts/_gqlutils.py
CHANGED
|
@@ -1,10 +1,30 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from contextlib import suppress
|
|
4
|
+
from dataclasses import dataclass
|
|
3
5
|
from functools import lru_cache
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
4
7
|
|
|
5
|
-
from wandb_gql import
|
|
8
|
+
from wandb_gql import gql
|
|
6
9
|
|
|
7
|
-
from .
|
|
10
|
+
from wandb._iterutils import one
|
|
11
|
+
from wandb.errors import UnsupportedError
|
|
12
|
+
from wandb.proto.wandb_internal_pb2 import ServerFeature
|
|
13
|
+
from wandb.sdk.artifacts._generated.fetch_org_info_from_entity import (
|
|
14
|
+
FetchOrgInfoFromEntityEntity,
|
|
15
|
+
)
|
|
16
|
+
from wandb.sdk.internal._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
|
|
17
|
+
|
|
18
|
+
from ._generated import (
|
|
19
|
+
FETCH_ORG_INFO_FROM_ENTITY_GQL,
|
|
20
|
+
TYPE_INFO_GQL,
|
|
21
|
+
FetchOrgInfoFromEntity,
|
|
22
|
+
TypeInfo,
|
|
23
|
+
TypeInfoFragment,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from wandb.apis.public import RetryingClient
|
|
8
28
|
|
|
9
29
|
OMITTABLE_ARTIFACT_FIELDS = frozenset(
|
|
10
30
|
{
|
|
@@ -18,13 +38,58 @@ OMITTABLE_ARTIFACT_FIELDS = frozenset(
|
|
|
18
38
|
|
|
19
39
|
|
|
20
40
|
@lru_cache(maxsize=16)
def type_info(client: RetryingClient, typename: str) -> TypeInfoFragment | None:
    """Look up GraphQL type info for the type named `typename`.

    Executes the `TYPE_INFO_GQL` query through `client` and returns the
    `type` field of the validated response.  Calls are memoized by
    `lru_cache`, keyed on the `(client, typename)` arguments.
    """
    query = gql(TYPE_INFO_GQL)
    response = client.execute(query, variable_values={"name": typename})
    parsed = TypeInfo.model_validate(response)
    return parsed.type
|
|
25
45
|
|
|
26
46
|
|
|
27
|
-
|
|
47
|
+
@lru_cache(maxsize=16)
def org_info_from_entity(
    client: RetryingClient, entity: str
) -> FetchOrgInfoFromEntityEntity | None:
    """Fetch the organization info associated with the entity named `entity`.

    Executes the `FETCH_ORG_INFO_FROM_ENTITY_GQL` query through `client` and
    returns the `entity` field of the validated response.  Calls are memoized
    by `lru_cache`, keyed on the `(client, entity)` arguments.
    """
    query = gql(FETCH_ORG_INFO_FROM_ENTITY_GQL)
    variables = {"entity": entity}
    response = client.execute(query, variable_values=variables)
    parsed = FetchOrgInfoFromEntity.model_validate(response)
    return parsed.entity
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@lru_cache(maxsize=16)
def server_features(client: RetryingClient) -> dict[str, bool]:
    """Return a `{feature_name: is_enabled}` mapping reported by the server.

    An empty dict is returned when the server rejects the `features` field
    on `ServerInfo`, or when the response carries no server info/features.
    Calls are memoized by `lru_cache`, keyed on the `client` argument.
    """
    try:
        response = client.execute(gql(SERVER_FEATURES_QUERY_GQL))
    except Exception as e:
        # The `gql` client raises a bare `Exception`, so the only way to
        # recognise "field not supported" is to match on the message text.
        unsupported = 'Cannot query field "features" on type "ServerInfo".' in str(e)
        if not unsupported:
            raise
        return {}

    parsed = ServerFeaturesQuery.model_validate(response)

    info = parsed.server_info
    if not info:
        return {}

    feature_list = info.features
    if not feature_list:
        return {}

    # Falsy entries in the list are skipped.
    return {item.name: item.is_enabled for item in feature_list if item}
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def server_supports(client: RetryingClient, feature: str | int) -> bool:
    """Report whether the server behind `client` has `feature` enabled.

    `feature` may be given as a feature name, or as the integer value of the
    `ServerFeature` protobuf enum, which is converted to its name first.
    Features missing from `server_features(client)` count as unsupported.
    """
    # Lookups are done by name (str) rather than enum value (int) because:
    # - the server identifies features by name, not by the client-side enum value
    # - the client-side enum may lag behind the server-side list of flags
    if isinstance(feature, int):
        feature_name = ServerFeature.Name(feature)
    else:
        feature_name = feature

    enabled = server_features(client).get(feature_name)
    return enabled or False
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def supports_enable_tracking_var(client: RetryingClient) -> bool:
|
|
28
93
|
"""Returns True if the server supports the `enableTracking` variable for the `Project.artifact(...)` field."""
|
|
29
94
|
typ = type_info(client, "Project")
|
|
30
95
|
if (
|
|
@@ -36,12 +101,111 @@ def supports_enable_tracking_var(client: Client) -> bool:
|
|
|
36
101
|
return False
|
|
37
102
|
|
|
38
103
|
|
|
39
|
-
def allowed_fields(client:
|
|
104
|
+
def allowed_fields(client: RetryingClient, typename: str) -> set[str]:
    """Collect the field names of the GraphQL type named `typename`.

    Returns an empty set when `type_info` yields nothing for the type, or
    when the returned type info has no fields.
    """
    info = type_info(client, typename)
    if not (info and info.fields):
        return set()
    return {field.name for field in info.fields}
|
|
43
108
|
|
|
44
109
|
|
|
45
|
-
def omit_artifact_fields(client:
|
|
110
|
+
def omit_artifact_fields(client: RetryingClient) -> set[str]:
    """List the omittable `Artifact` fields that the server does not expose.

    These are the members of `OMITTABLE_ARTIFACT_FIELDS` that are absent from
    `allowed_fields(client, "Artifact")`; they should be dropped from GraphQL
    requests for server compatibility.
    """
    supported = allowed_fields(client, "Artifact")
    return {name for name in OMITTABLE_ARTIFACT_FIELDS if name not in supported}
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass(frozen=True)
class OrgInfo:
    """Immutable pair of names that identify one organization.

    Supports `name in org_info`, which is true when `name` equals either the
    organization's name or its entity name.
    """

    # The organization name.
    org_name: str
    # The name of the organization's entity.
    entity_name: str

    def __contains__(self, other: object) -> bool:
        """Return True if `other` equals `org_name` or `entity_name`."""
        # A tuple (rather than a set) is compared by equality only, so an
        # unhashable operand yields False instead of raising TypeError.
        return other in (self.org_name, self.entity_name)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def resolve_org_entity_name(
    client: RetryingClient,
    non_org_entity: str,
    org_or_entity: str | None = None,
) -> str:
    """Resolve the org entity name to use for the given entity.

    Args:
        client: Client used to inspect the server schema and to fetch the
            organization info for `non_org_entity`.
        non_org_entity: Name of a team or personal entity.  Must be non-empty.
        org_or_entity: Optional organization hint.  May be empty, an
            organization's display name, or an org entity name.

    Returns:
        The org entity name.  If the server's `Organization` type has no
        `orgEntity` field, `org_or_entity` is returned as given, without
        verification.  Otherwise the fetched org entity name is returned after
        checking that `org_or_entity`, when provided, matches either the
        organization's display name or its entity name.

    Raises:
        ValueError: If `non_org_entity` is empty, if `org_or_entity` matches
            none of the candidate organizations, or if no single organization
            can be determined for the entity.
        UnsupportedError: If the server cannot report `orgEntity` and no
            `org_or_entity` was provided.
    """
    if not non_org_entity:
        raise ValueError("Entity name is required to resolve org entity name.")

    if "orgEntity" not in allowed_fields(client, "Organization"):
        if org_or_entity:
            # Server doesn't support fetching orgEntity to match against,
            # so assume the provided value is already correct.
            return org_or_entity

        raise UnsupportedError(
            "Fetching Registry artifacts without inputting an organization "
            "is unavailable for your server version. "
            "Please upgrade your server to 0.50.0 or later."
        )

    # Otherwise, fetch candidate orgs to verify/identify the correct orgEntity name.
    entity = org_info_from_entity(client, non_org_entity)

    # ------------------------------------------------------------------------
    # Team entity: there should be a single organization on the entity itself.
    if entity and (org := entity.organization) and (org_entity := org.org_entity):
        # If a hint was given, it must match the org's name or entity name.
        org_info = OrgInfo(org_name=org.name, entity_name=org_entity.name)
        if (not org_or_entity) or (org_or_entity in org_info):
            return org_entity.name

    # ------------------------------------------------------------------------
    # Personal entity: the user may belong to multiple organizations.
    if entity and (user := entity.user) and (orgs := user.organizations):
        org_infos = [
            OrgInfo(org_name=org.name, entity_name=org_entity.name)
            for org in orgs
            if (org_entity := org.org_entity)
        ]

        if not org_or_entity:
            # Without a hint, exactly one candidate org is required:
            # - multiple orgs: we cannot determine which one to use.
            # - no orgs: there is nothing to use.
            return one(
                (org.entity_name for org in org_infos),
                too_short=ValueError(
                    f"Unable to resolve an organization associated with personal entity: {non_org_entity!r}. "
                    "This could be because its a personal entity that doesn't belong to any organizations. "
                    "Please specify the organization in the Registry path or use a team entity in the entity settings."
                ),
                too_long=ValueError(
                    f"Personal entity {non_org_entity!r} belongs to multiple organizations "
                    "and cannot be used without specifying the organization name. "
                    "Please specify the organization in the Registry path or use a team entity in the entity settings."
                ),
            )

        # Return the first candidate that the hint matches, if any.
        with suppress(StopIteration):
            return next(
                info.entity_name for info in org_infos if (org_or_entity in info)
            )

        if len(org_infos) == 1:
            raise ValueError(
                f"Expecting the organization name or entity name to match {org_infos[0].org_name!r} "
                f"and cannot be linked/fetched with {org_or_entity!r}. "
                "Please update the target path with the correct organization name."
            )
        raise ValueError(
            "Personal entity belongs to multiple organizations "
            f"and cannot be linked/fetched with {org_or_entity!r}. "
            "Please update the target path with the correct organization name "
            "or use a team entity in the entity settings."
        )

    raise ValueError(f"Unable to find organization for entity {non_org_entity!r}.")
|
|
@@ -2,3 +2,12 @@
|
|
|
2
2
|
|
|
3
3
|
Excludes GraphQL-generated classes.
|
|
4
4
|
"""
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"ArtifactsBase",
|
|
8
|
+
"RegistryData",
|
|
9
|
+
"ArtifactCollectionData",
|
|
10
|
+
]
|
|
11
|
+
from .artifact_collection import ArtifactCollectionData
|
|
12
|
+
from .base_model import ArtifactsBase
|
|
13
|
+
from .registry import RegistryData
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, List, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field
|
|
6
|
+
from typing_extensions import Self
|
|
7
|
+
|
|
8
|
+
from wandb._pydantic import field_validator
|
|
9
|
+
from wandb.sdk.artifacts._generated import ArtifactCollectionFragment
|
|
10
|
+
from wandb.sdk.artifacts._validators import (
|
|
11
|
+
SOURCE_COLLECTION_TYPENAME,
|
|
12
|
+
validate_artifact_name,
|
|
13
|
+
validate_tags,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from .base_model import ArtifactsBase
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ArtifactCollectionData(ArtifactsBase):
    """Transport-free model for local `ArtifactCollection` data.

    For now, this is separated from the public `ArtifactCollection` model
    to more easily ensure continuity in the public `ArtifactCollection` API.

    Note:
        In a future change, consider making _this_ the public `ArtifactCollection` instead, i.e.:
        - Replace the _existing_ `ArtifactCollection` class with this one
        - Rename _this_ class to `ArtifactCollection`
        Note that this would be a breaking change.
    """

    # Strings cannot be empty
    model_config = ConfigDict(str_min_length=1)

    typename__: str = Field(alias="__typename", frozen=True, repr=False)
    """The GraphQL `__typename` for this object."""

    id: str = Field(frozen=True, repr=False)
    """The encoded GraphQL ID for this object."""

    name: str
    """The name of this collection."""

    type: str
    """The artifact type of this collection."""

    description: Optional[str] = None
    """The description, if any, for this collection."""

    created_at: str = Field(frozen=True)
    """When this collection was created."""

    project: str = Field(frozen=True)
    """The name of this collection's project."""

    entity: str = Field(frozen=True)
    """The name of the entity that owns this collection's project."""

    aliases: Optional[Tuple[str, ...]] = Field(default=None, frozen=True)
    """All aliases assigned to artifact versions within this collection.

    Note:
        `None` should signal that aliases haven't been (or couldn't be) fetched and parsed yet,
        for any reason, NOT the actual absence of aliases in this collection.
    """

    tags: List[str] = Field(default_factory=list)
    """The tags assigned to this collection.

    Note that this is distinct from any tags assigned to individual artifact versions within this collection.
    """

    @field_validator("name", mode="plain")
    def _validate_name(cls, v: str) -> str:
        """Delegate validation of the collection name to `validate_artifact_name`."""
        validated = validate_artifact_name(v)
        return validated

    @field_validator("tags", mode="plain")
    def _validate_tags(cls, v: Any) -> list[str]:
        """Delegate validation of the tag names to `validate_tags`."""
        validated = validate_tags(v)
        return validated

    @property
    def is_sequence(self) -> bool:
        """Whether this collection's typename is `SOURCE_COLLECTION_TYPENAME` (an ArtifactSequence)."""
        typename = self.typename__
        return typename == SOURCE_COLLECTION_TYPENAME

    @classmethod
    def from_fragment(cls, obj: ArtifactCollectionFragment) -> Self:
        """Build an instance from a GraphQL `ArtifactCollectionFragment`.

        Raises:
            ValueError: If the fragment carries no project info.
        """
        if obj.project is None:
            raise ValueError(f"Missing project info in {type(obj)!r} data")

        # Values are read off the fragment in field order before construction.
        kwargs = {
            "typename__": obj.typename__,
            "id": obj.id,
            "name": obj.name,
            "type": obj.type.name,
            "description": obj.description,
            "created_at": obj.created_at,
            "project": obj.project.name,
            "entity": obj.project.entity.name,
            # Edges without a node are skipped.
            "tags": [edge.node.name for edge in obj.tags.edges if edge.node],
            "aliases": (
                [edge.node.alias for edge in obj.aliases.edges if edge.node]
                if obj.aliases
                else []
            ),
        }
        return cls(**kwargs)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from typing import Any, Dict, Literal, final
|
|
2
|
+
|
|
3
|
+
from wandb._pydantic import field_validator, to_camel
|
|
4
|
+
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
|
5
|
+
|
|
6
|
+
from .base_model import ArtifactsBase
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@final
class ArtifactManifestV1Data(ArtifactsBase, alias_generator=to_camel):
    """Data model for the v1 artifact manifest."""

    version: Literal[1]

    contents: Dict[str, ArtifactManifestEntry]

    storage_policy: str
    storage_policy_config: Dict[str, Any]

    @field_validator("contents", mode="before")
    def _validate_entries(cls, v: Any) -> Any:
        """Re-insert each entry's dict key as its `path` before validation."""
        # The keys of `contents` are the `entry.path` values, which have
        # historically been dropped from the JSON objects themselves.
        # Pydantic then converts the resulting dicts -> ArtifactManifestEntries.
        restored = {}
        for path, entry in v.items():
            restored[path] = {**dict(entry), "path": path}
        return restored
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Artifacts-specific data models for handling paginated results from GraphQL queries."""
|
|
2
|
+
|
|
3
|
+
from wandb._pydantic import Connection, ConnectionWithTotal
|
|
4
|
+
|
|
5
|
+
from .._generated.fragments import (
|
|
6
|
+
ArtifactCollectionFragment,
|
|
7
|
+
ArtifactFragment,
|
|
8
|
+
ArtifactMembershipFragment,
|
|
9
|
+
ArtifactTypeFragment,
|
|
10
|
+
FileFragment,
|
|
11
|
+
FileWithUrlFragment,
|
|
12
|
+
RegistryCollectionFragment,
|
|
13
|
+
RegistryFragment,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
ArtifactTypeConnection = Connection[ArtifactTypeFragment]
|
|
17
|
+
ArtifactCollectionConnection = ConnectionWithTotal[ArtifactCollectionFragment]
|
|
18
|
+
ArtifactMembershipConnection = Connection[ArtifactMembershipFragment]
|
|
19
|
+
|
|
20
|
+
FileWithUrlConnection = Connection[FileWithUrlFragment]
|
|
21
|
+
ArtifactFileConnection = Connection[FileFragment]
|
|
22
|
+
|
|
23
|
+
RunArtifactConnection = ConnectionWithTotal[ArtifactFragment]
|
|
24
|
+
|
|
25
|
+
RegistryConnection = Connection[RegistryFragment]
|
|
26
|
+
RegistryCollectionConnection = ConnectionWithTotal[RegistryCollectionFragment]
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import ConfigDict, Field
|
|
6
|
+
from typing_extensions import Self
|
|
7
|
+
|
|
8
|
+
from wandb._pydantic import GQLId, field_validator
|
|
9
|
+
from wandb._strutils import nameof
|
|
10
|
+
from wandb.apis.public.registries._freezable_list import AddOnlyArtifactTypesList
|
|
11
|
+
from wandb.apis.public.registries._utils import Visibility
|
|
12
|
+
from wandb.sdk.artifacts._generated import RegistryFragment
|
|
13
|
+
from wandb.sdk.artifacts._generated.fragments import RegistryFragmentArtifactTypes
|
|
14
|
+
from wandb.sdk.artifacts._validators import REGISTRY_PREFIX, remove_registry_prefix
|
|
15
|
+
|
|
16
|
+
from .base_model import ArtifactsBase
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RegistryData(ArtifactsBase):
    """Transport-free model for local `Registry` data.

    For now, this is separated from the public `Registry` class
    to more easily ensure continuity in the public `Registry` API.
    """

    # Strings cannot be empty
    model_config = ConfigDict(str_min_length=1)

    id: GQLId = Field(frozen=True)
    """The unique, encoded ID for this registry."""

    created_at: str = Field(frozen=True)
    """When this registry was created."""

    updated_at: Optional[str] = Field(frozen=True)
    """When this registry was last updated."""

    organization: str = Field(frozen=True)
    """The organization of the registry."""

    entity: str = Field(frozen=True)
    """The organization entity of the registry."""

    name: str
    """The name of the registry without the `wandb-registry-` project prefix."""

    description: Optional[str] = None
    """The description, if any, of the registry."""

    allow_all_artifact_types: bool
    """Whether all artifact types are allowed in the registry."""

    artifact_types: AddOnlyArtifactTypesList = Field(
        default_factory=AddOnlyArtifactTypesList
    )
    """The artifact types allowed in the registry.

    The meaning of this list depends on the value of `allow_all_artifact_types`:
    - If True: `artifact_types` are the previously-saved or currently-used types in the registry.
    - If False: `artifact_types` are the only allowed artifact types in the registry.
    """

    visibility: Visibility = Field(alias="access")
    """The visibility of the registry."""

    @property
    def full_name(self) -> str:
        """The registry's project name: `REGISTRY_PREFIX` followed by `name`."""
        return REGISTRY_PREFIX + f"{self.name}"

    @field_validator("artifact_types", mode="plain")
    def _validate_artifact_types(cls, v: Any) -> AddOnlyArtifactTypesList:
        """Convert the incoming value into an `AddOnlyArtifactTypesList`."""
        if isinstance(v, RegistryFragmentArtifactTypes):
            # A GQL connection object: pull the type names off its edges,
            # skipping edges that have no node.
            v = (edge.node.name for edge in v.edges if edge.node)

        # Anything else is assumed to already be an iterable of strings.
        return AddOnlyArtifactTypesList(v)

    @classmethod
    def from_fragment(cls, obj: RegistryFragment) -> Self:
        """Build an instance from a GraphQL `RegistryFragment`.

        Raises:
            ValueError: If the fragment's entity has no organization.
        """
        if not obj.entity.organization:
            raise ValueError(
                f"Unable to parse registry organization from {nameof(type(obj))!r} object"
            )

        # Values are read off the fragment in field order before construction.
        kwargs = {
            "id": obj.id,
            "created_at": obj.created_at,
            "updated_at": obj.updated_at,
            "organization": obj.entity.organization.name,
            "entity": obj.entity.name,
            "name": remove_registry_prefix(obj.name),
            "description": obj.description,
            "allow_all_artifact_types": obj.allow_all_artifact_types,
            "artifact_types": obj.artifact_types,
            "visibility": obj.access,
        }
        return cls(**kwargs)
|
|
@@ -6,7 +6,17 @@ import json
|
|
|
6
6
|
import re
|
|
7
7
|
from dataclasses import dataclass, field, replace
|
|
8
8
|
from functools import wraps
|
|
9
|
-
from typing import
|
|
9
|
+
from typing import (
|
|
10
|
+
TYPE_CHECKING,
|
|
11
|
+
Any,
|
|
12
|
+
Callable,
|
|
13
|
+
Dict,
|
|
14
|
+
Iterable,
|
|
15
|
+
Literal,
|
|
16
|
+
Optional,
|
|
17
|
+
TypeVar,
|
|
18
|
+
cast,
|
|
19
|
+
)
|
|
10
20
|
|
|
11
21
|
from pydantic.dataclasses import dataclass as pydantic_dataclass
|
|
12
22
|
from typing_extensions import Concatenate, ParamSpec, Self
|
|
@@ -35,8 +45,8 @@ MAX_ARTIFACT_METADATA_KEYS: Final[int] = 100
|
|
|
35
45
|
ARTIFACT_NAME_MAXLEN: Final[int] = 128
|
|
36
46
|
ARTIFACT_NAME_INVALID_CHARS: Final[frozenset[str]] = frozenset({"/"})
|
|
37
47
|
|
|
38
|
-
|
|
39
|
-
|
|
48
|
+
LINKED_COLLECTION_TYPENAME: Final[str] = gql_typename(ArtifactPortfolioTypeFields)
|
|
49
|
+
SOURCE_COLLECTION_TYPENAME: Final[str] = gql_typename(ArtifactSequenceTypeFields)
|
|
40
50
|
|
|
41
51
|
|
|
42
52
|
@dataclass
|
|
@@ -85,11 +95,14 @@ def validate_artifact_name(name: str) -> str:
|
|
|
85
95
|
return name
|
|
86
96
|
|
|
87
97
|
|
|
88
|
-
INVALID_URL_CHARACTERS = ("
|
|
98
|
+
INVALID_URL_CHARACTERS: Final[frozenset[str]] = frozenset("/\\#?%:\r\n")
|
|
89
99
|
|
|
100
|
+
PROJECT_NAME_MAXLEN: Final[int] = 128
|
|
101
|
+
REGISTRY_NAME_MAXLEN: Final[int] = PROJECT_NAME_MAXLEN - len(REGISTRY_PREFIX)
|
|
90
102
|
|
|
91
|
-
|
|
92
|
-
|
|
103
|
+
|
|
104
|
+
def validate_project_name(name: str) -> str:
|
|
105
|
+
"""Validates a project name according to W&B rules, returning the original name if successful.
|
|
93
106
|
|
|
94
107
|
Args:
|
|
95
108
|
name: The project name string.
|
|
@@ -97,18 +110,16 @@ def validate_project_name(name: str) -> None:
|
|
|
97
110
|
Raises:
|
|
98
111
|
ValueError: If the name is invalid (too long or contains invalid characters).
|
|
99
112
|
"""
|
|
100
|
-
max_len = 128
|
|
101
|
-
|
|
102
113
|
if not name:
|
|
103
114
|
raise ValueError("Project name cannot be empty")
|
|
104
115
|
if not (registry_name := removeprefix(name, REGISTRY_PREFIX)):
|
|
105
116
|
raise ValueError("Registry name cannot be empty")
|
|
106
117
|
|
|
107
|
-
if len(name) >
|
|
118
|
+
if len(name) > PROJECT_NAME_MAXLEN:
|
|
108
119
|
if registry_name != name:
|
|
109
|
-
msg = f"Invalid registry name {registry_name!r}, must be {
|
|
120
|
+
msg = f"Invalid registry name {registry_name!r}, must be {REGISTRY_NAME_MAXLEN} characters or less"
|
|
110
121
|
else:
|
|
111
|
-
msg = f"Invalid project name {name!r}, must be {
|
|
122
|
+
msg = f"Invalid project name {name!r}, must be {PROJECT_NAME_MAXLEN} characters or less"
|
|
112
123
|
raise ValueError(msg)
|
|
113
124
|
|
|
114
125
|
# Find the first occurrence of any invalid character
|
|
@@ -118,6 +129,10 @@ def validate_project_name(name: str) -> None:
|
|
|
118
129
|
raise ValueError(
|
|
119
130
|
f"Invalid project/registry name {error_name!r}, cannot contain characters: {invalid_chars_repr!s}"
|
|
120
131
|
)
|
|
132
|
+
return name
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
ALIAS_NAME_INVALID_CHARS: Final[frozenset[str]] = frozenset("/:")
|
|
121
136
|
|
|
122
137
|
|
|
123
138
|
def validate_aliases(aliases: Collection[str] | str) -> list[str]:
|
|
@@ -127,36 +142,39 @@ def validate_aliases(aliases: Collection[str] | str) -> list[str]:
|
|
|
127
142
|
ValueError: If any of the aliases contain invalid characters.
|
|
128
143
|
"""
|
|
129
144
|
aliases_list = always_list(aliases)
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
if any(char in alias for alias in aliases_list for char in invalid_chars):
|
|
145
|
+
if any(ALIAS_NAME_INVALID_CHARS.intersection(alias) for alias in aliases_list):
|
|
146
|
+
invalid_repr = ", ".join(sorted(map(repr, ALIAS_NAME_INVALID_CHARS)))
|
|
133
147
|
raise ValueError(
|
|
134
|
-
f"Aliases must not contain any of the following characters: {
|
|
148
|
+
f"Aliases must not contain any of the following characters: {invalid_repr}"
|
|
135
149
|
)
|
|
136
150
|
return aliases_list
|
|
137
151
|
|
|
138
152
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
153
|
+
ARTIFACT_TYPE_NAME_MAXLEN: Final[int] = 128
|
|
154
|
+
ARTIFACT_TYPE_NAME_INVALID_CHARS: Final[frozenset[str]] = frozenset("/:")
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def validate_artifact_types(artifact_types: Collection[str] | str) -> list[str]:
    """Validate the artifact type names and return them as a list.

    Args:
        artifact_types: A single artifact type name, or a collection of names.

    Returns:
        The names as produced by `always_list(artifact_types)`.

    Raises:
        ValueError: If any name is longer than `ARTIFACT_TYPE_NAME_MAXLEN`
            or contains a character from `ARTIFACT_TYPE_NAME_INVALID_CHARS`.
    """
    artifact_types_list = always_list(artifact_types)
    if any(
        len(name) > ARTIFACT_TYPE_NAME_MAXLEN
        or ARTIFACT_TYPE_NAME_INVALID_CHARS.intersection(name)
        for name in artifact_types_list
    ):
        invalid_repr = ", ".join(sorted(map(repr, ARTIFACT_TYPE_NAME_INVALID_CHARS)))
        # Adjacent string literals are joined verbatim, so the first fragment
        # must end with a space to keep the two clauses apart.
        raise ValueError(
            f"Artifact types must not contain any of the following characters: {invalid_repr} "
            f"and must be less than or equal to {ARTIFACT_TYPE_NAME_MAXLEN!r} characters"
        )
    return artifact_types_list
|
|
153
171
|
|
|
154
172
|
|
|
155
173
|
TAG_REGEX: re.Pattern[str] = re.compile(r"^[-\w]+( +[-\w]+)*$")
|
|
156
174
|
"""Regex pattern for valid tag names."""
|
|
157
175
|
|
|
158
176
|
|
|
159
|
-
def validate_tags(tags:
|
|
177
|
+
def validate_tags(tags: Iterable[str] | str) -> list[str]:
|
|
160
178
|
"""Validate the artifact tag names and return them as a deduped list.
|
|
161
179
|
|
|
162
180
|
In the case of duplicates, only keep the first tag, and otherwise maintain the order of appearance.
|
|
@@ -204,7 +222,7 @@ def validate_artifact_type(typ: str, name: str) -> str:
|
|
|
204
222
|
return typ
|
|
205
223
|
|
|
206
224
|
|
|
207
|
-
def validate_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]:
|
|
225
|
+
def validate_metadata(metadata: dict[str, Any] | str | None) -> dict[str, Any]:
|
|
208
226
|
"""Validate the artifact metadata and return it as a dict."""
|
|
209
227
|
if metadata is None:
|
|
210
228
|
return {}
|