wandb 0.19.9__py3-none-any.whl → 0.19.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +6 -3
- wandb/_pydantic/__init__.py +14 -8
- wandb/_pydantic/base.py +51 -36
- wandb/_pydantic/utils.py +73 -0
- wandb/_pydantic/v1_compat.py +79 -57
- wandb/apis/public/__init__.py +2 -2
- wandb/apis/public/api.py +684 -4
- wandb/apis/public/artifacts.py +377 -677
- wandb/apis/public/automations.py +69 -0
- wandb/apis/public/integrations.py +180 -0
- wandb/apis/public/projects.py +29 -0
- wandb/apis/public/registries/__init__.py +0 -0
- wandb/apis/public/registries/_freezable_list.py +179 -0
- wandb/apis/public/{registries.py → registries/registries_search.py} +22 -129
- wandb/apis/public/registries/registry.py +357 -0
- wandb/apis/public/registries/utils.py +140 -0
- wandb/apis/public/runs.py +58 -56
- wandb/apis/public/utils.py +107 -1
- wandb/automations/__init__.py +73 -0
- wandb/automations/_filters/__init__.py +40 -0
- wandb/automations/_filters/expressions.py +181 -0
- wandb/automations/_filters/operators.py +258 -0
- wandb/automations/_filters/run_metrics.py +332 -0
- wandb/automations/_generated/__init__.py +177 -0
- wandb/automations/_generated/create_automation.py +17 -0
- wandb/automations/_generated/create_generic_webhook_integration.py +43 -0
- wandb/automations/_generated/delete_automation.py +17 -0
- wandb/automations/_generated/enums.py +33 -0
- wandb/automations/_generated/fragments.py +358 -0
- wandb/automations/_generated/generic_webhook_integrations_by_entity.py +22 -0
- wandb/automations/_generated/get_automations.py +24 -0
- wandb/automations/_generated/get_automations_by_entity.py +26 -0
- wandb/automations/_generated/input_types.py +104 -0
- wandb/automations/_generated/integrations_by_entity.py +22 -0
- wandb/automations/_generated/operations.py +647 -0
- wandb/automations/_generated/slack_integrations_by_entity.py +22 -0
- wandb/automations/_generated/update_automation.py +17 -0
- wandb/automations/_utils.py +237 -0
- wandb/automations/_validators.py +165 -0
- wandb/automations/actions.py +220 -0
- wandb/automations/automations.py +87 -0
- wandb/automations/events.py +287 -0
- wandb/automations/integrations.py +45 -0
- wandb/automations/scopes.py +78 -0
- wandb/beta/workflows.py +9 -10
- wandb/bin/gpu_stats +0 -0
- wandb/cli/cli.py +3 -3
- wandb/env.py +11 -0
- wandb/integration/keras/keras.py +2 -1
- wandb/integration/langchain/wandb_tracer.py +2 -1
- wandb/jupyter.py +137 -118
- wandb/old/settings.py +4 -1
- wandb/old/summary.py +0 -2
- wandb/proto/v3/wandb_internal_pb2.py +297 -292
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +292 -292
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +292 -292
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v6/wandb_base_pb2.py +41 -0
- wandb/proto/v6/wandb_internal_pb2.py +393 -0
- wandb/proto/v6/wandb_server_pb2.py +78 -0
- wandb/proto/v6/wandb_settings_pb2.py +58 -0
- wandb/proto/v6/wandb_telemetry_pb2.py +52 -0
- wandb/proto/wandb_base_pb2.py +2 -0
- wandb/proto/wandb_deprecated.py +8 -0
- wandb/proto/wandb_internal_pb2.py +3 -1
- wandb/proto/wandb_server_pb2.py +2 -0
- wandb/proto/wandb_settings_pb2.py +2 -0
- wandb/proto/wandb_telemetry_pb2.py +2 -0
- wandb/sdk/artifacts/_generated/__init__.py +289 -0
- wandb/sdk/artifacts/_generated/add_aliases.py +21 -0
- wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_version_files.py +36 -0
- wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +36 -0
- wandb/sdk/artifacts/_generated/delete_aliases.py +21 -0
- wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +25 -0
- wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +35 -0
- wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +35 -0
- wandb/sdk/artifacts/_generated/enums.py +17 -0
- wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +67 -0
- wandb/sdk/artifacts/_generated/fragments.py +221 -0
- wandb/sdk/artifacts/_generated/input_types.py +28 -0
- wandb/sdk/artifacts/_generated/move_artifact_collection.py +35 -0
- wandb/sdk/artifacts/_generated/operations.py +611 -0
- wandb/sdk/artifacts/_generated/project_artifact_collection.py +101 -0
- wandb/sdk/artifacts/_generated/project_artifact_collections.py +33 -0
- wandb/sdk/artifacts/_generated/project_artifact_type.py +24 -0
- wandb/sdk/artifacts/_generated/project_artifact_types.py +24 -0
- wandb/sdk/artifacts/_generated/project_artifacts.py +42 -0
- wandb/sdk/artifacts/_generated/run_input_artifacts.py +51 -0
- wandb/sdk/artifacts/_generated/run_output_artifacts.py +51 -0
- wandb/sdk/artifacts/_generated/update_artifact.py +26 -0
- wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +35 -0
- wandb/sdk/artifacts/_generated/update_artifact_sequence.py +35 -0
- wandb/sdk/artifacts/_graphql_fragments.py +57 -79
- wandb/sdk/artifacts/_validators.py +120 -1
- wandb/sdk/artifacts/artifact.py +419 -215
- wandb/sdk/artifacts/artifact_file_cache.py +4 -6
- wandb/sdk/artifacts/artifact_manifest_entry.py +13 -3
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +182 -1
- wandb/sdk/artifacts/storage_policy.py +3 -0
- wandb/sdk/data_types/base_types/media.py +2 -3
- wandb/sdk/data_types/base_types/wb_value.py +34 -11
- wandb/sdk/data_types/html.py +36 -9
- wandb/sdk/data_types/image.py +12 -12
- wandb/sdk/data_types/table.py +5 -0
- wandb/sdk/data_types/trace_tree.py +2 -0
- wandb/sdk/data_types/utils.py +1 -1
- wandb/sdk/data_types/video.py +59 -57
- wandb/sdk/interface/interface.py +4 -3
- wandb/sdk/internal/internal_api.py +21 -31
- wandb/sdk/internal/profiler.py +6 -5
- wandb/sdk/internal/run.py +13 -6
- wandb/sdk/internal/sender.py +5 -2
- wandb/sdk/launch/sweeps/utils.py +8 -0
- wandb/sdk/lib/apikey.py +25 -4
- wandb/sdk/lib/asyncio_compat.py +1 -1
- wandb/sdk/lib/deprecate.py +13 -22
- wandb/sdk/lib/disabled.py +2 -1
- wandb/sdk/lib/printer.py +37 -8
- wandb/sdk/lib/printer_asyncio.py +46 -0
- wandb/sdk/lib/redirect.py +10 -5
- wandb/sdk/projects/_generated/__init__.py +47 -0
- wandb/sdk/projects/_generated/delete_project.py +22 -0
- wandb/sdk/projects/_generated/enums.py +4 -0
- wandb/sdk/projects/_generated/fetch_registry.py +22 -0
- wandb/sdk/projects/_generated/fragments.py +41 -0
- wandb/sdk/projects/_generated/input_types.py +13 -0
- wandb/sdk/projects/_generated/operations.py +88 -0
- wandb/sdk/projects/_generated/rename_project.py +27 -0
- wandb/sdk/projects/_generated/upsert_registry_project.py +27 -0
- wandb/sdk/service/server_sock.py +19 -14
- wandb/sdk/service/service.py +18 -8
- wandb/sdk/service/streams.py +5 -0
- wandb/sdk/verify/verify.py +6 -3
- wandb/sdk/wandb_init.py +217 -70
- wandb/sdk/wandb_login.py +13 -4
- wandb/sdk/wandb_run.py +419 -295
- wandb/sdk/wandb_settings.py +27 -10
- wandb/sdk/wandb_setup.py +61 -0
- wandb/util.py +33 -29
- {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/METADATA +5 -5
- {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/RECORD +152 -82
- wandb/_globals.py +0 -19
- wandb/sdk/internal/_generated/base.py +0 -226
- wandb/sdk/internal/_generated/typing_compat.py +0 -14
- {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/WHEEL +0 -0
- {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,69 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from itertools import chain
|
4
|
+
from typing import TYPE_CHECKING, Any, Iterable, Mapping
|
5
|
+
|
6
|
+
from pydantic import ValidationError
|
7
|
+
from typing_extensions import override
|
8
|
+
from wandb_graphql.language.ast import Document
|
9
|
+
|
10
|
+
from wandb.apis.paginator import Paginator, _Client
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from wandb.automations import Automation
|
14
|
+
from wandb.automations._generated import ProjectConnectionFields
|
15
|
+
|
16
|
+
|
17
|
+
class Automations(Paginator["Automation"]):
    """Paginator over the automations defined on the projects in a search scope.

    Each fetched page is a connection of projects; `convert_objects` flattens the
    ``triggers`` of every project on the page into a single list of `Automation`s.
    The GraphQL query document has no default and must be passed as ``_query``.
    """

    # Parsed page from the most recent fetch; None until a page has been fetched
    # (this is what `more` and `cursor` check for).
    last_response: ProjectConnectionFields | None
    _query: Document

    def __init__(
        self,
        client: _Client,
        variables: Mapping[str, Any],
        per_page: int = 50,
        _query: Document | None = None,
    ):
        super().__init__(client, variables, per_page=per_page)
        if _query is None:
            raise RuntimeError(f"Query required for {type(self).__qualname__}")
        self._query = _query

    @property
    def more(self) -> bool:
        """Whether there are more items to fetch."""
        if self.last_response is None:
            return True
        return self.last_response.page_info.has_next_page

    @property
    def cursor(self) -> str | None:
        """The start cursor to use for the next page."""
        if self.last_response is None:
            return None
        return self.last_response.page_info.end_cursor

    @override
    def _update_response(self) -> None:
        """Fetch and parse the response data for the current page.

        Raises:
            ValueError: If the response does not have the expected shape,
                including when a parent field (e.g. ``searchScope``) is null.
        """
        from wandb.automations._generated import ProjectConnectionFields

        data: dict[str, Any] = self.client.execute(
            self._query, variable_values=self.variables
        )
        try:
            page_data = data["searchScope"]["projects"]
            self.last_response = ProjectConnectionFields.model_validate(page_data)
        except (LookupError, AttributeError, TypeError, ValidationError) as e:
            # TypeError: subscripting a null (None) GraphQL field, e.g. a missing scope.
            raise ValueError("Unexpected response data") from e

    @override
    def convert_objects(self) -> Iterable[Automation]:
        """Parse the current page into a flat list of automations.

        Returns an empty list if no page has been fetched yet.
        """
        from wandb.automations import Automation

        page = self.last_response
        if page is None:
            return []
        return [
            Automation.model_validate(obj)
            for obj in chain.from_iterable(edge.node.triggers for edge in page.edges)
        ]
|
@@ -0,0 +1,180 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Any, Iterable
|
4
|
+
|
5
|
+
from pydantic import ValidationError
|
6
|
+
from typing_extensions import override
|
7
|
+
from wandb_gql import gql
|
8
|
+
from wandb_graphql.language.ast import Document
|
9
|
+
|
10
|
+
from wandb.apis.paginator import Paginator
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from wandb.apis.paginator import _Client
|
14
|
+
from wandb.automations import Integration, SlackIntegration, WebhookIntegration
|
15
|
+
from wandb.automations._generated import (
|
16
|
+
GenericWebhookIntegrationConnectionFields,
|
17
|
+
IntegrationConnectionFields,
|
18
|
+
SlackIntegrationConnectionFields,
|
19
|
+
)
|
20
|
+
|
21
|
+
|
22
|
+
class Integrations(Paginator["Integration"]):
    """Paginator over all integrations (of any type) belonging to an entity."""

    # Parsed page from the most recent fetch; None until a page has been fetched.
    last_response: IntegrationConnectionFields | None
    _query: Document

    def __init__(self, client: _Client, variables: dict[str, Any], per_page: int = 50):
        from wandb.automations._generated import INTEGRATIONS_BY_ENTITY_GQL

        super().__init__(client, variables, per_page=per_page)
        # All integrations for entity
        self._query = gql(INTEGRATIONS_BY_ENTITY_GQL)

    @property
    def more(self) -> bool:
        """Whether there are more Integrations to fetch."""
        if self.last_response is None:
            return True
        return self.last_response.page_info.has_next_page

    @property
    def cursor(self) -> str | None:
        """The start cursor to use for the next page."""
        if self.last_response is None:
            return None
        return self.last_response.page_info.end_cursor

    @override
    def _update_response(self) -> None:
        """Fetch and parse the response data for the current page.

        Raises:
            ValueError: If the response does not have the expected shape,
                including when ``entity`` is null.
        """
        from wandb.automations._generated import IntegrationConnectionFields

        data: dict[str, Any] = self.client.execute(
            self._query, variable_values=self.variables
        )
        try:
            page_data = data["entity"]["integrations"]
            self.last_response = IntegrationConnectionFields.model_validate(page_data)
        except (LookupError, AttributeError, TypeError, ValidationError) as e:
            # TypeError: subscripting a null (None) GraphQL field, e.g. unknown entity.
            raise ValueError("Unexpected response data") from e

    def convert_objects(self) -> Iterable[Integration]:
        """Parse the page data into a list of integrations.

        Returns an empty list if no page has been fetched yet.
        """
        from wandb.automations.integrations import _IntegrationEdge

        page = self.last_response
        if page is None:
            return []
        return [_IntegrationEdge.model_validate(edge).node for edge in page.edges]
|
67
|
+
|
68
|
+
|
69
|
+
class WebhookIntegrations(Paginator["WebhookIntegration"]):
    """Paginator over the generic webhook integrations belonging to an entity."""

    # Parsed page from the most recent fetch; None until a page has been fetched.
    last_response: GenericWebhookIntegrationConnectionFields | None
    _query: Document

    def __init__(self, client: _Client, variables: dict[str, Any], per_page: int = 50):
        from wandb.automations._generated import (
            GENERIC_WEBHOOK_INTEGRATIONS_BY_ENTITY_GQL,
        )

        super().__init__(client, variables, per_page=per_page)
        # Webhook integrations for entity
        self._query = gql(GENERIC_WEBHOOK_INTEGRATIONS_BY_ENTITY_GQL)

    @property
    def more(self) -> bool:
        """Whether there are more webhook integrations to fetch."""
        if self.last_response is None:
            return True
        return self.last_response.page_info.has_next_page

    @property
    def cursor(self) -> str | None:
        """The start cursor to use for the next page."""
        if self.last_response is None:
            return None
        return self.last_response.page_info.end_cursor

    @override
    def _update_response(self) -> None:
        """Fetch and parse the response data for the current page.

        Raises:
            ValueError: If the response does not have the expected shape,
                including when ``entity`` is null.
        """
        from wandb.automations._generated import (
            GenericWebhookIntegrationConnectionFields,
        )

        data: dict[str, Any] = self.client.execute(
            self._query, variable_values=self.variables
        )
        try:
            page_data = data["entity"]["integrations"]
            self.last_response = (
                GenericWebhookIntegrationConnectionFields.model_validate(page_data)
            )
        except (LookupError, AttributeError, TypeError, ValidationError) as e:
            # TypeError: subscripting a null (None) GraphQL field, e.g. unknown entity.
            raise ValueError("Unexpected response data") from e

    def convert_objects(self) -> Iterable[WebhookIntegration]:
        """Parse the page data into a list of webhook integrations.

        Returns an empty list if no page has been fetched yet.
        """
        from wandb.automations import WebhookIntegration

        page = self.last_response
        if page is None:
            return []

        typename = "GenericWebhookIntegration"
        return [
            # Filter on typename__ needed because the GQL response still
            # includes all integration types
            WebhookIntegration.model_validate(node)
            for edge in page.edges
            if (node := edge.node) and (node.typename__ == typename)
        ]
|
126
|
+
|
127
|
+
|
128
|
+
class SlackIntegrations(Paginator["SlackIntegration"]):
    """Paginator over the Slack integrations belonging to an entity."""

    # Parsed page from the most recent fetch; None until a page has been fetched.
    last_response: SlackIntegrationConnectionFields | None
    _query: Document

    def __init__(self, client: _Client, variables: dict[str, Any], per_page: int = 50):
        from wandb.automations._generated import SLACK_INTEGRATIONS_BY_ENTITY_GQL

        super().__init__(client, variables, per_page=per_page)
        # Slack integrations for entity
        self._query = gql(SLACK_INTEGRATIONS_BY_ENTITY_GQL)

    @property
    def more(self) -> bool:
        """Whether there are more Slack integrations to fetch."""
        if self.last_response is None:
            return True
        return self.last_response.page_info.has_next_page

    @property
    def cursor(self) -> str | None:
        """The start cursor to use for the next page."""
        if self.last_response is None:
            return None
        return self.last_response.page_info.end_cursor

    @override
    def _update_response(self) -> None:
        """Fetch and parse the response data for the current page.

        Raises:
            ValueError: If the response does not have the expected shape,
                including when ``entity`` is null.
        """
        from wandb.automations._generated import SlackIntegrationConnectionFields

        data: dict[str, Any] = self.client.execute(
            self._query, variable_values=self.variables
        )
        try:
            page_data = data["entity"]["integrations"]
            self.last_response = SlackIntegrationConnectionFields.model_validate(
                page_data
            )
        except (LookupError, AttributeError, TypeError, ValidationError) as e:
            # TypeError: subscripting a null (None) GraphQL field, e.g. unknown entity.
            raise ValueError("Unexpected response data") from e

    def convert_objects(self) -> Iterable[SlackIntegration]:
        """Parse the page data into a list of Slack integrations.

        Returns an empty list if no page has been fetched yet.
        """
        from wandb.automations import SlackIntegration

        page = self.last_response
        if page is None:
            return []

        typename = "SlackIntegration"
        return [
            # Filter on typename__ needed because the GQL response still
            # includes all integration types
            SlackIntegration.model_validate(node)
            for edge in page.edges
            if (node := edge.node) and (node.typename__ == typename)
        ]
|
wandb/apis/public/projects.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1
1
|
"""Public API: projects."""
|
2
2
|
|
3
|
+
from contextlib import suppress
|
4
|
+
|
5
|
+
from requests import HTTPError
|
3
6
|
from wandb_gql import gql
|
4
7
|
|
5
8
|
from wandb.apis import public
|
@@ -153,3 +156,29 @@ class Project(Attrs):
|
|
153
156
|
)
|
154
157
|
for e in ret["project"]["sweeps"]["edges"]
|
155
158
|
]
|
159
|
+
|
160
|
+
# GraphQL query used by the `id` property to look up this project's ID from its
# project name and entity name.
_PROJECT_ID = gql(
"""
query ProjectID($projectName: String!, $entityName: String!) {
project(name: $projectName, entityName: $entityName) {
id
}
}
"""
)
|
169
|
+
|
170
|
+
@property
def id(self) -> str:
    """Return the project's ID, querying the server if it is not already known.

    The ID is generally not set or fetched when the project is instantiated, so
    it is looked up on demand and cached in ``self._attrs["id"]``. This is needed,
    for example, when using this project as the scope of a new Automation.

    Raises:
        ValueError: If the request fails or the response lacks the project ID.
    """
    try:
        return self._attrs["id"]
    except LookupError:
        pass  # Not cached yet: fall through and ask the server.

    query_vars = {"projectName": self.name, "entityName": self.entity}
    try:
        response = self.client.execute(self._PROJECT_ID, query_vars)
        fetched_id = response["project"]["id"]
        self._attrs["id"] = fetched_id
        return self._attrs["id"]
    except (HTTPError, LookupError, TypeError) as e:
        raise ValueError(f"Unable to fetch project ID: {query_vars!r}") from e
|
File without changes
|
@@ -0,0 +1,179 @@
|
|
1
|
+
from itertools import chain
|
2
|
+
from typing import (
|
3
|
+
Any,
|
4
|
+
Iterable,
|
5
|
+
Iterator,
|
6
|
+
List,
|
7
|
+
MutableSequence,
|
8
|
+
Sequence,
|
9
|
+
Tuple,
|
10
|
+
TypeVar,
|
11
|
+
Union,
|
12
|
+
final,
|
13
|
+
overload,
|
14
|
+
)
|
15
|
+
|
16
|
+
T = TypeVar("T")


class FreezableList(MutableSequence[T]):
    """A list-like container type that only allows adding new items.

    It tracks "saved" (immutable) items followed by "draft" (mutable) items.
    Draft items can be appended, inserted, replaced, and removed; once `freeze()`
    is called they become saved and can no longer be modified or removed.

    Adding a value that is already present (saved or draft) via `append`,
    `insert`, or item assignment is silently ignored, so newly added items never
    duplicate existing ones. Any initial items passed to the constructor are
    saved as-is.

    NOTE: mixin methods that work by re-assigning existing items (e.g.
    `MutableSequence.reverse`) do not reorder draft items, because assigning an
    already-present value is ignored.

    This class is deliberately not ``@final``: `AddOnlyArtifactTypesList`
    subclasses it.
    """

    def __init__(self, iterable: Union[Iterable[T], None] = None, /) -> None:
        self._frozen: Tuple[T, ...] = tuple(iterable or ())
        self._draft: List[T] = []

    def _draft_index(self, index: int, action: str) -> int:
        """Translate an in-range index over all items into an index into the draft.

        Raises:
            IndexError: If `index` is out of range.
            ValueError: If `index` refers to a saved (frozen) item.
        """
        size = len(self)
        if not -size <= index < size:
            raise IndexError("Index out of range")
        # The frozen items are sequentially first and protected from changes
        draft_index = (index % size) - len(self._frozen)
        if draft_index < 0:
            raise ValueError(f"Cannot {action} saved item at index {index!r}")
        return draft_index

    def append(self, value: T) -> None:
        """Append an item to the draft list. No duplicates are allowed."""
        if (value in self._frozen) or (value in self._draft):
            return
        self._draft.append(value)

    def remove(self, value: T) -> None:
        """Remove the first occurrence of value from the draft list."""
        if value in self._frozen:
            raise ValueError(f"Cannot remove item from frozen list: {value!r}")
        self._draft.remove(value)

    def freeze(self) -> None:
        """Freeze any draft items by adding them to the saved tuple."""
        # Filter out duplicates already in saved before extending
        new_items = tuple(item for item in self._draft if item not in self._frozen)
        self._frozen = self._frozen + new_items
        self._draft.clear()

    def __eq__(self, value: object) -> bool:
        # str/bytes are Sequences too, but comparing item-wise against them would
        # make e.g. FreezableList(["a"]) == "a" evaluate to True.
        if not isinstance(value, Sequence) or isinstance(
            value, (str, bytes, bytearray)
        ):
            return NotImplemented
        return list(self) == list(value)

    def __contains__(self, value: Any) -> bool:
        return value in self._frozen or value in self._draft

    def __len__(self) -> int:
        return len(self._frozen) + len(self._draft)

    def __iter__(self) -> Iterator[T]:
        return chain(self._frozen, self._draft)

    @overload
    def __getitem__(self, index: int) -> T: ...

    @overload
    def __getitem__(self, index: slice) -> Sequence[T]: ...

    def __getitem__(self, index: Union[int, slice]) -> Union[T, Sequence[T]]:
        if isinstance(index, slice):
            return [*self._frozen, *self._draft][index]

        # Index directly instead of materializing a combined list, so that
        # index-based mixins (pop, index, __reversed__) stay O(n) overall.
        len_frozen = len(self._frozen)
        size = len_frozen + len(self._draft)
        pos = index + size if index < 0 else index
        if not 0 <= pos < size:
            raise IndexError("list index out of range")
        if pos < len_frozen:
            return self._frozen[pos]
        return self._draft[pos - len_frozen]

    @overload
    def __setitem__(self, index: int, value: T) -> None: ...

    @overload
    def __setitem__(self, index: slice, value: Iterable[T]) -> None: ...

    def __setitem__(
        self, index: Union[int, slice], value: Union[T, Iterable[T]]
    ) -> None:
        if isinstance(index, slice):
            # Setting slices might affect saved items, disallow for simplicity
            raise TypeError(
                f"{type(self).__name__!r} does not support slice assignment"
            )
        # Validate the index first so invalid or frozen positions are reported
        # even when the value happens to be a duplicate.
        draft_index = self._draft_index(index, "assign to")
        if value in self._frozen or value in self._draft:
            # Silently ignore duplicates, consistent with `append`/`insert`.
            return
        self._draft[draft_index] = value

    @overload
    def __delitem__(self, index: int) -> None: ...

    @overload
    def __delitem__(self, index: slice) -> None: ...

    def __delitem__(self, index: Union[int, slice]) -> None:
        if isinstance(index, slice):
            raise TypeError(f"{type(self).__name__!r} does not support slice deletion")
        del self._draft[self._draft_index(index, "delete")]

    def insert(self, index: int, value: T) -> None:
        """Insert item before index.

        Insertion is only allowed at positions within the draft portion of the
        list (i.e. position >= len(frozen_items)). Out-of-bounds indices are
        clamped the same way `list.insert()` clamps them; negative indices are
        interpreted relative to the combined length of frozen and draft items.

        Raises:
            IndexError: If the resulting position falls inside the frozen items.
        """
        if value in self._frozen or value in self._draft:
            # Silently ignore duplicates, similar to append
            return

        len_frozen = len(self._frozen)
        size = len_frozen + len(self._draft)

        # Clamp like `list.insert()`: far-negative -> 0, far-positive -> end.
        pos = max(index + size, 0) if index < 0 else min(index, size)
        if pos < len_frozen:
            raise IndexError(
                f"Cannot insert into the frozen list (index < {len_frozen})"
            )
        self._draft.insert(pos - len_frozen, value)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(frozen={list(self._frozen)!r}, draft={list(self._draft)!r})"

    @property
    def draft(self) -> Tuple[T, ...]:
        """A read-only, tuple copy of the current draft items."""
        return tuple(self._draft)


class AddOnlyArtifactTypesList(FreezableList[str]):
    """A FreezableList of artifact type names whose saved types cannot be removed."""

    def remove(self, value: str) -> None:
        """Remove a draft artifact type.

        Raises:
            ValueError: If the type has already been saved to the registry, or
                (with the standard `list.remove` message) if it is not present.
        """
        # Check explicitly: the parent also raises ValueError for a missing item,
        # which must not be reported as a "saved" type.
        if value in self._frozen:
            raise ValueError(
                f"Cannot remove artifact type: {value!r} that has been saved to the registry"
            )
        super().remove(value)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(saved={list(self._frozen)!r}, draft={list(self._draft)!r})"
|
@@ -1,7 +1,7 @@
|
|
1
|
-
"""Public API: registries."""
|
1
|
+
"""Public API: registries search."""
|
2
2
|
|
3
3
|
import json
|
4
|
-
from typing import TYPE_CHECKING, Any, Dict,
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional
|
5
5
|
|
6
6
|
if TYPE_CHECKING:
|
7
7
|
from wandb_gql import Client
|
@@ -11,11 +11,11 @@ from wandb_gql import gql
|
|
11
11
|
import wandb
|
12
12
|
from wandb.apis.paginator import Paginator
|
13
13
|
from wandb.apis.public.artifacts import ArtifactCollection
|
14
|
+
from wandb.apis.public.registries.utils import _ensure_registry_prefix_on_names
|
14
15
|
from wandb.sdk.artifacts._graphql_fragments import (
|
15
16
|
_gql_artifact_fragment,
|
16
17
|
_gql_registry_fragment,
|
17
18
|
)
|
18
|
-
from wandb.sdk.artifacts._validators import REGISTRY_PREFIX
|
19
19
|
|
20
20
|
|
21
21
|
class Registries(Paginator):
|
@@ -63,7 +63,7 @@ class Registries(Paginator):
|
|
63
63
|
super().__init__(client, variables, per_page)
|
64
64
|
|
65
65
|
def __bool__(self):
|
66
|
-
return
|
66
|
+
return bool(self.objects)
|
67
67
|
|
68
68
|
def __next__(self):
|
69
69
|
# Implement custom next since its possible to load empty pages because of auth
|
@@ -128,6 +128,8 @@ class Registries(Paginator):
|
|
128
128
|
f"Organization '{self.organization}' not found. Please verify the organization name is correct"
|
129
129
|
)
|
130
130
|
|
131
|
+
from wandb.apis.public.registries.registry import Registry
|
132
|
+
|
131
133
|
return [
|
132
134
|
Registry(
|
133
135
|
self.client,
|
@@ -142,86 +144,6 @@ class Registries(Paginator):
|
|
142
144
|
]
|
143
145
|
|
144
146
|
|
145
|
-
class Registry:
    """A single registry in the Registry.

    Wraps the attributes returned by the server for one registry. The registry's
    `name` is its `full_name` with the leading ``REGISTRY_PREFIX`` removed.
    """

    def __init__(
        self,
        client: "Client",
        organization: str,
        entity: str,
        full_name: str,
        attrs: Dict[str, Any],
    ):
        self.client = client
        self._full_name = full_name
        # Strip only a *leading* prefix: `str.replace` would also delete any later
        # occurrence of the prefix text inside the name itself.
        if full_name.startswith(REGISTRY_PREFIX):
            self._name = full_name[len(REGISTRY_PREFIX):]
        else:
            self._name = full_name
        self._entity = entity
        self._organization = organization
        self._description = attrs.get("description", "")
        self._allow_all_artifact_types = attrs.get(
            "allowAllArtifactTypesInRegistry", False
        )
        # Tolerate a null `artifactTypes` / `edges` value as well as a missing one.
        artifact_type_edges = (attrs.get("artifactTypes") or {}).get("edges") or []
        self._artifact_types = [t["node"]["name"] for t in artifact_type_edges]
        self._id = attrs.get("id", "")
        self._created_at = attrs.get("createdAt", "")
        self._updated_at = attrs.get("updatedAt", "")

    @property
    def full_name(self):
        """The registry's full name, including the registry prefix."""
        return self._full_name

    @property
    def name(self):
        """The registry's name without the leading registry prefix."""
        return self._name

    @property
    def entity(self):
        return self._entity

    @property
    def organization(self):
        return self._organization

    @property
    def description(self):
        return self._description

    @property
    def allow_all_artifact_types(self):
        return self._allow_all_artifact_types

    @property
    def artifact_types(self):
        """Names of the artifact types listed in the registry's attributes."""
        return self._artifact_types

    @property
    def created_at(self):
        return self._created_at

    @property
    def updated_at(self):
        return self._updated_at

    @property
    def path(self):
        """``[entity, name]`` for this registry."""
        return [self.entity, self.name]

    def collections(self, filter: Optional[Dict[str, Any]] = None):
        """Return a `Collections` paginator over artifact collections in this registry."""
        registry_filter = {
            "name": self.full_name,
        }
        return Collections(self.client, self.organization, registry_filter, filter)

    def versions(self, filter: Optional[Dict[str, Any]] = None):
        """Return a `Versions` paginator scoped to this registry."""
        registry_filter = {
            "name": self.full_name,
        }
        return Versions(self.client, self.organization, registry_filter, None, filter)
|
223
|
-
|
224
|
-
|
225
147
|
class Collections(Paginator):
|
226
148
|
"""Iterator that returns Artifact collections in the Registry."""
|
227
149
|
|
@@ -303,12 +225,12 @@ class Collections(Paginator):
|
|
303
225
|
self.collection_filter = collection_filter or {}
|
304
226
|
|
305
227
|
variables = {
|
306
|
-
"registryFilter":
|
307
|
-
|
308
|
-
|
309
|
-
"collectionFilter":
|
310
|
-
|
311
|
-
|
228
|
+
"registryFilter": (
|
229
|
+
json.dumps(self.registry_filter) if self.registry_filter else None
|
230
|
+
),
|
231
|
+
"collectionFilter": (
|
232
|
+
json.dumps(self.collection_filter) if self.collection_filter else None
|
233
|
+
),
|
312
234
|
"organization": self.organization,
|
313
235
|
"collectionTypes": ["PORTFOLIO"],
|
314
236
|
"perPage": per_page,
|
@@ -383,6 +305,7 @@ class Collections(Paginator):
|
|
383
305
|
r["node"]["defaultArtifactType"]["name"],
|
384
306
|
self.organization,
|
385
307
|
r["node"],
|
308
|
+
is_sequence=False,
|
386
309
|
)
|
387
310
|
for r in self.last_response["organization"]["orgEntity"][
|
388
311
|
"artifactCollections"
|
@@ -460,15 +383,15 @@ class Versions(Paginator):
|
|
460
383
|
)
|
461
384
|
|
462
385
|
variables = {
|
463
|
-
"registryFilter":
|
464
|
-
|
465
|
-
|
466
|
-
"collectionFilter":
|
467
|
-
|
468
|
-
|
469
|
-
"artifactFilter":
|
470
|
-
|
471
|
-
|
386
|
+
"registryFilter": (
|
387
|
+
json.dumps(self.registry_filter) if self.registry_filter else None
|
388
|
+
),
|
389
|
+
"collectionFilter": (
|
390
|
+
json.dumps(self.collection_filter) if self.collection_filter else None
|
391
|
+
),
|
392
|
+
"artifactFilter": (
|
393
|
+
json.dumps(self.artifact_filter) if self.artifact_filter else None
|
394
|
+
),
|
472
395
|
"organization": self.organization,
|
473
396
|
}
|
474
397
|
|
@@ -541,33 +464,3 @@ class Versions(Paginator):
|
|
541
464
|
]["edges"]
|
542
465
|
)
|
543
466
|
return artifacts
|
544
|
-
|
545
|
-
|
546
|
-
def _ensure_registry_prefix_on_names(query, in_name=False):
|
547
|
-
"""Traverse the filter to prepend the `name` key value with the registry prefix unless the value is a regex.
|
548
|
-
|
549
|
-
- in_name: True if we are under a "name" key (or propagating from one).
|
550
|
-
|
551
|
-
EX: {"name": "model"} -> {"name": "wandb-registry-model"}
|
552
|
-
"""
|
553
|
-
if isinstance((txt := query), str):
|
554
|
-
if in_name:
|
555
|
-
return txt if txt.startswith(REGISTRY_PREFIX) else f"{REGISTRY_PREFIX}{txt}"
|
556
|
-
return txt
|
557
|
-
if isinstance((dct := query), Mapping):
|
558
|
-
new_dict = {}
|
559
|
-
for key, obj in dct.items():
|
560
|
-
if key == "name":
|
561
|
-
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=True)
|
562
|
-
elif key == "$regex":
|
563
|
-
# For regex operator, we skip transformation of its value.
|
564
|
-
new_dict[key] = obj
|
565
|
-
else:
|
566
|
-
# For any other key, propagate the in_name and skip_transform flags as-is.
|
567
|
-
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=in_name)
|
568
|
-
return new_dict
|
569
|
-
if isinstance((objs := query), Sequence):
|
570
|
-
return list(
|
571
|
-
map(lambda x: _ensure_registry_prefix_on_names(x, in_name=in_name), objs)
|
572
|
-
)
|
573
|
-
return query
|