wandb 0.19.10__py3-none-macosx_11_0_arm64.whl → 0.19.11__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +3 -3
- wandb/_pydantic/__init__.py +2 -3
- wandb/_pydantic/base.py +11 -31
- wandb/_pydantic/utils.py +8 -1
- wandb/_pydantic/v1_compat.py +3 -3
- wandb/apis/public/api.py +590 -22
- wandb/apis/public/artifacts.py +13 -5
- wandb/apis/public/automations.py +1 -1
- wandb/apis/public/integrations.py +22 -10
- wandb/apis/public/registries/__init__.py +0 -0
- wandb/apis/public/registries/_freezable_list.py +179 -0
- wandb/apis/public/{registries.py → registries/registries_search.py} +22 -129
- wandb/apis/public/registries/registry.py +357 -0
- wandb/apis/public/registries/utils.py +140 -0
- wandb/apis/public/runs.py +58 -56
- wandb/automations/__init__.py +16 -24
- wandb/automations/_filters/expressions.py +12 -10
- wandb/automations/_filters/operators.py +10 -19
- wandb/automations/_filters/run_metrics.py +231 -82
- wandb/automations/_generated/__init__.py +27 -34
- wandb/automations/_generated/create_automation.py +17 -0
- wandb/automations/_generated/delete_automation.py +17 -0
- wandb/automations/_generated/fragments.py +40 -25
- wandb/automations/_generated/{get_triggers.py → get_automations.py} +5 -5
- wandb/automations/_generated/{get_triggers_by_entity.py → get_automations_by_entity.py} +7 -5
- wandb/automations/_generated/operations.py +35 -98
- wandb/automations/_generated/update_automation.py +17 -0
- wandb/automations/_utils.py +178 -64
- wandb/automations/_validators.py +94 -2
- wandb/automations/actions.py +113 -98
- wandb/automations/automations.py +47 -69
- wandb/automations/events.py +139 -87
- wandb/automations/integrations.py +23 -4
- wandb/automations/scopes.py +22 -20
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/env.py +11 -0
- wandb/old/settings.py +4 -1
- wandb/proto/v3/wandb_internal_pb2.py +240 -236
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +236 -236
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +236 -236
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v6/wandb_internal_pb2.py +236 -236
- wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/_generated/__init__.py +42 -1
- wandb/sdk/artifacts/_generated/add_aliases.py +21 -0
- wandb/sdk/artifacts/_generated/delete_aliases.py +21 -0
- wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +67 -0
- wandb/sdk/artifacts/_generated/fragments.py +35 -0
- wandb/sdk/artifacts/_generated/input_types.py +12 -0
- wandb/sdk/artifacts/_generated/operations.py +101 -0
- wandb/sdk/artifacts/_generated/update_artifact.py +26 -0
- wandb/sdk/artifacts/_graphql_fragments.py +1 -0
- wandb/sdk/artifacts/_validators.py +120 -1
- wandb/sdk/artifacts/artifact.py +380 -203
- wandb/sdk/artifacts/artifact_file_cache.py +4 -6
- wandb/sdk/artifacts/artifact_manifest_entry.py +11 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +182 -1
- wandb/sdk/artifacts/storage_policy.py +3 -0
- wandb/sdk/data_types/video.py +46 -32
- wandb/sdk/interface/interface.py +2 -3
- wandb/sdk/internal/internal_api.py +21 -31
- wandb/sdk/internal/sender.py +5 -2
- wandb/sdk/launch/sweeps/utils.py +8 -0
- wandb/sdk/projects/_generated/__init__.py +47 -0
- wandb/sdk/projects/_generated/delete_project.py +22 -0
- wandb/sdk/projects/_generated/enums.py +4 -0
- wandb/sdk/projects/_generated/fetch_registry.py +22 -0
- wandb/sdk/projects/_generated/fragments.py +41 -0
- wandb/sdk/projects/_generated/input_types.py +13 -0
- wandb/sdk/projects/_generated/operations.py +88 -0
- wandb/sdk/projects/_generated/rename_project.py +27 -0
- wandb/sdk/projects/_generated/upsert_registry_project.py +27 -0
- wandb/sdk/service/service.py +9 -1
- wandb/sdk/wandb_init.py +32 -5
- wandb/sdk/wandb_run.py +37 -9
- wandb/sdk/wandb_settings.py +6 -7
- wandb/sdk/wandb_setup.py +12 -0
- wandb/util.py +7 -3
- {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/METADATA +1 -1
- {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/RECORD +87 -70
- wandb/automations/_generated/create_filter_trigger.py +0 -21
- wandb/automations/_generated/delete_trigger.py +0 -19
- wandb/automations/_generated/update_filter_trigger.py +0 -21
- {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/WHEEL +0 -0
- {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,357 @@
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
2
|
+
|
3
|
+
from wandb_gql import gql
|
4
|
+
|
5
|
+
import wandb
|
6
|
+
from wandb.apis.public.registries._freezable_list import AddOnlyArtifactTypesList
|
7
|
+
from wandb.apis.public.registries.registries_search import Collections, Versions
|
8
|
+
from wandb.apis.public.registries.utils import (
|
9
|
+
_fetch_org_entity_from_organization,
|
10
|
+
_format_gql_artifact_types_input,
|
11
|
+
_gql_to_registry_visibility,
|
12
|
+
_registry_visibility_to_gql,
|
13
|
+
)
|
14
|
+
from wandb.proto.wandb_internal_pb2 import ServerFeature
|
15
|
+
from wandb.sdk.artifacts._validators import REGISTRY_PREFIX, validate_project_name
|
16
|
+
from wandb.sdk.internal.internal_api import Api as InternalApi
|
17
|
+
from wandb.sdk.projects._generated.delete_project import DeleteProject
|
18
|
+
from wandb.sdk.projects._generated.operations import (
|
19
|
+
DELETE_PROJECT_GQL,
|
20
|
+
FETCH_REGISTRY_GQL,
|
21
|
+
RENAME_PROJECT_GQL,
|
22
|
+
UPSERT_REGISTRY_PROJECT_GQL,
|
23
|
+
)
|
24
|
+
from wandb.sdk.projects._generated.rename_project import RenameProject
|
25
|
+
from wandb.sdk.projects._generated.upsert_registry_project import UpsertRegistryProject
|
26
|
+
|
27
|
+
if TYPE_CHECKING:
|
28
|
+
from wandb_gql import Client
|
29
|
+
|
30
|
+
|
31
|
+
class Registry:
    """A single registry in the W&B Registry.

    A registry is backed by a project named `wandb-registry-<name>` owned by
    the organization's entity. Assigning `name`, `description`, `visibility`,
    or `allow_all_artifact_types`, or appending to `artifact_types`, only
    changes local state: call `save()` to persist changes and `load()` to
    refresh from the backend.
    """

    def __init__(
        self,
        client: "Client",
        organization: str,
        entity: str,
        name: str,
        attrs: Optional[Dict[str, Any]] = None,
    ):
        """Initialize a registry handle.

        Args:
            client: The GraphQL client used for backend requests.
            organization: The organization name.
            entity: The organization entity that owns the registry.
            name: The registry name without the `wandb-registry-` prefix.
            attrs: Optional GraphQL project payload used to populate the
                remaining attributes. When omitted, attributes such as
                `description` are not set until `load()` is called.
        """
        self.client = client
        self._name = name
        # Name as last persisted on the backend. `save()` upserts under this
        # name and renames when it differs from `name`.
        self._saved_name = name
        self._entity = entity
        self._organization = organization
        if attrs is not None:
            self._update_attributes(attrs)

    def _update_attributes(self, attrs: Dict[str, Any]) -> None:
        """Populate instance attributes from a GraphQL project payload.

        Raises:
            ValueError: If the payload has no `id`, or if its `access` value
                is not a recognized visibility.
        """
        # Default to None (not "") so a missing id is actually detected.
        self._id = attrs.get("id")
        if self._id is None:
            raise ValueError(f"Registry {self.name}'s id is not found")

        self._description = attrs.get("description", "")
        self._allow_all_artifact_types = attrs.get(
            "allowAllArtifactTypesInRegistry", False
        )
        # Tolerate a null `artifactTypes` connection or null `edges` list.
        type_edges = (attrs.get("artifactTypes") or {}).get("edges") or []
        self._artifact_types = AddOnlyArtifactTypesList(
            edge["node"]["name"] for edge in type_edges
        )
        self._created_at = attrs.get("createdAt", "")
        self._updated_at = attrs.get("updatedAt", "")
        self._visibility = _gql_to_registry_visibility(attrs.get("access", ""))

    @property
    def full_name(self) -> str:
        """Full name of the registry including the `wandb-registry-` prefix."""
        return f"wandb-registry-{self.name}"

    @property
    def name(self) -> str:
        """Name of the registry without the `wandb-registry-` prefix."""
        return self._name

    @name.setter
    def name(self, value: str):
        """Set the registry name locally; call `save()` to rename on the backend."""
        self._name = value

    @property
    def entity(self) -> str:
        """Organization entity of the registry."""
        return self._entity

    @property
    def organization(self) -> str:
        """Organization name of the registry."""
        return self._organization

    @property
    def description(self) -> str:
        """Description of the registry."""
        return self._description

    @description.setter
    def description(self, value: str):
        """Set the description of the registry."""
        self._description = value

    @property
    def allow_all_artifact_types(self) -> bool:
        """Returns whether all artifact types are allowed in the registry.

        If `True` then artifacts of any type can be added to this registry.
        If `False` then artifacts are restricted to the types in `artifact_types` for this registry.
        """
        return self._allow_all_artifact_types

    @allow_all_artifact_types.setter
    def allow_all_artifact_types(self, value: bool):
        """Set whether all artifact types are allowed in the registry."""
        self._allow_all_artifact_types = value

    @property
    def artifact_types(self) -> AddOnlyArtifactTypesList:
        """Returns the artifact types allowed in the registry.

        If `allow_all_artifact_types` is `True` then `artifact_types` reflects the
        types previously saved or currently used in the registry.
        If `allow_all_artifact_types` is `False` then artifacts are restricted to the
        types in `artifact_types`.

        Note:
            Previously saved artifact types cannot be removed.

        Example:
        ```python
        registry.artifact_types.append("model")
        registry.save()  # once saved, the artifact type `model` cannot be removed
        registry.artifact_types.append("accidentally_added")
        registry.artifact_types.remove(
            "accidentally_added"
        )  # Types can only be removed if it has not been saved yet
        ```
        """
        return self._artifact_types

    @property
    def created_at(self) -> str:
        """Timestamp of when the registry was created."""
        return self._created_at

    @property
    def updated_at(self) -> str:
        """Timestamp of when the registry was last updated."""
        return self._updated_at

    @property
    def path(self) -> List[str]:
        """Path of the registry as `[entity, full_name]`."""
        return [self.entity, self.full_name]

    @property
    def visibility(self) -> Literal["organization", "restricted"]:
        """Visibility of the registry.

        Returns:
            Literal["organization", "restricted"]: The visibility level.
                - "organization": Anyone in the organization can view this registry.
                  You can edit their roles later from the settings in the UI.
                - "restricted": Only invited members via the UI can access this registry.
                  Public sharing is disabled.
        """
        return self._visibility

    @visibility.setter
    def visibility(self, value: Literal["organization", "restricted"]):
        """Set the visibility of the registry.

        The value is validated when `save()` converts it for the backend.

        Args:
            value: The visibility level. Options are:
                - "organization": Anyone in the organization can view this registry.
                  You can edit their roles later from the settings in the UI.
                - "restricted": Only invited members via the UI can access this registry.
                  Public sharing is disabled.
        """
        self._visibility = value

    def collections(self, filter: Optional[Dict[str, Any]] = None) -> Collections:
        """Returns the collections belonging to the registry."""
        registry_filter = {
            "name": self.full_name,
        }
        return Collections(self.client, self.organization, registry_filter, filter)

    def versions(self, filter: Optional[Dict[str, Any]] = None) -> Versions:
        """Returns the versions belonging to the registry."""
        registry_filter = {
            "name": self.full_name,
        }
        return Versions(self.client, self.organization, registry_filter, None, filter)

    @classmethod
    def create(
        cls,
        client: "Client",
        organization: str,
        name: str,
        visibility: Literal["organization", "restricted"],
        description: Optional[str] = None,
        artifact_types: Optional[List[str]] = None,
    ):
        """Create a new registry.

        The registry name must be unique within the organization.
        This function should be called using `api.create_registry()`

        Args:
            client: The GraphQL client.
            organization: The name of the organization.
            name: The name of the registry (without the `wandb-registry-` prefix).
            visibility: The visibility level ('organization' or 'restricted').
            description: An optional description for the registry.
            artifact_types: An optional list of allowed artifact types. When
                empty or omitted, all artifact types are allowed.

        Returns:
            Registry: The newly created Registry object.

        Raises:
            ValueError: If a registry with the same name already exists in the
                organization or if the creation fails.
        """
        org_entity = _fetch_org_entity_from_organization(client, organization)
        full_name = REGISTRY_PREFIX + name
        validate_project_name(full_name)
        accepted_artifact_types = []
        if artifact_types:
            accepted_artifact_types = _format_gql_artifact_types_input(artifact_types)
        visibility_value = _registry_visibility_to_gql(visibility)
        registry_creation_error = (
            f"Failed to create registry {name!r} in organization {organization!r}."
        )
        try:
            response = client.execute(
                gql(UPSERT_REGISTRY_PROJECT_GQL),
                variable_values={
                    "description": description,
                    "entityName": org_entity,
                    "name": full_name,
                    "access": visibility_value,
                    "allowAllArtifactTypesInRegistry": not accepted_artifact_types,
                    "artifactTypes": accepted_artifact_types,
                },
            )
        except Exception as e:
            raise ValueError(registry_creation_error) from e
        # A falsy `inserted` means no new project was created (e.g. the name
        # is taken). NOTE(review): because this is an upsert, the existing
        # registry may already have been modified at this point — confirm
        # server behavior.
        if not response["upsertModel"]["inserted"]:
            raise ValueError(registry_creation_error)

        return Registry(
            client,
            organization,
            org_entity,
            name,
            response["upsertModel"]["project"],
        )

    def delete(self) -> None:
        """Delete the registry. This is irreversible.

        Raises:
            ValueError: If the request fails or the backend reports failure.
        """
        delete_error = (
            f"Failed to delete registry: {self.name!r} in organization: {self.organization!r}"
        )
        try:
            response = self.client.execute(
                gql(DELETE_PROJECT_GQL), variable_values={"id": self._id}
            )
            result = DeleteProject.model_validate(response)
        except Exception as e:
            raise ValueError(delete_error) from e
        if not result.delete_model.success:
            raise ValueError(delete_error)

    def load(self) -> None:
        """Load the registry attributes from the backend to reflect the latest saved state.

        Raises:
            ValueError: If the request fails or the registry is not found.
        """
        load_failure_message = (
            f"Failed to load registry {self.name!r} "
            f"in organization {self.organization!r}."
        )
        try:
            response = self.client.execute(
                gql(FETCH_REGISTRY_GQL),
                variable_values={
                    "name": self.full_name,
                    "entityName": self.entity,
                },
            )
        except Exception as e:
            raise ValueError(load_failure_message) from e
        if response["entity"] is None:
            raise ValueError(load_failure_message)
        # Kept as a public attribute for backward compatibility.
        self.attrs = response["entity"]["project"]
        if self.attrs is None:
            raise ValueError(load_failure_message)
        self._update_attributes(self.attrs)

    def save(self) -> None:
        """Save registry attributes to the backend.

        Upserts the registry under its last-saved name, then renames it if
        `name` has changed since it was last saved or constructed.

        Raises:
            RuntimeError: If the server does not support saving registries.
            ValueError: If new artifact types were added while
                `allow_all_artifact_types` is True, or if the upsert fails.
                Errors from `validate_project_name` and
                `_registry_visibility_to_gql` also propagate.
        """
        if not InternalApi()._check_server_feature_with_fallback(
            ServerFeature.INCLUDE_ARTIFACT_TYPES_IN_REGISTRY_CREATION
        ):
            raise RuntimeError(
                "saving the registry is not enabled on this wandb server version. "
                "Please upgrade your server version or contact support at support@wandb.com."
            )

        if self._no_updating_registry_types():
            raise ValueError(
                f"Cannot update artifact types when `allow_all_artifact_types` is {True!r}. Set it to {False!r} first."
            )

        validate_project_name(self.full_name)
        visibility_value = _registry_visibility_to_gql(self.visibility)
        newly_added_types = _format_gql_artifact_types_input(self.artifact_types.draft)
        registry_save_error = f"Failed to save and update registry: {self.name} in organization: {self.organization}"
        full_saved_name = f"{REGISTRY_PREFIX}{self._saved_name}"
        try:
            response = self.client.execute(
                gql(UPSERT_REGISTRY_PROJECT_GQL),
                variable_values={
                    "description": self.description,
                    "entityName": self.entity,
                    # Target the originally saved name so a pending rename
                    # updates the existing registry instead of a new one.
                    "name": full_saved_name,
                    "access": visibility_value,
                    "allowAllArtifactTypesInRegistry": self.allow_all_artifact_types,
                    "artifactTypes": newly_added_types,
                },
            )
            result = UpsertRegistryProject.model_validate(response)
        except Exception as e:
            raise ValueError(registry_save_error) from e
        if result.upsert_model.inserted:
            # Should only happen if `_saved_name` no longer matches an
            # existing registry (e.g. it was modified by hand).
            wandb.termlog(
                f"Created registry {self.name!r} in organization {self.organization!r} on save"
            )
        self._update_attributes(response["upsertModel"]["project"])

        # Apply a pending rename, if any.
        if self._saved_name != self.name:
            response = self.client.execute(
                gql(RENAME_PROJECT_GQL),
                variable_values={
                    "entityName": self.entity,
                    "oldProjectName": full_saved_name,
                    "newProjectName": self.full_name,
                },
            )
            result = RenameProject.model_validate(response)
            self._saved_name = self.name
            if result.rename_project.inserted:
                # Should only happen if `_saved_name` no longer matches an
                # existing registry (e.g. it was modified by hand).
                wandb.termlog(f"Created new registry {self.name!r} on save")

    def _no_updating_registry_types(self) -> bool:
        """Return True if unsaved artifact types were added while all types are allowed."""
        # `draft` holds types the user added that have not been saved yet.
        return len(self.artifact_types.draft) > 0 and self.allow_all_artifact_types
|
@@ -0,0 +1,140 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from functools import lru_cache
|
3
|
+
from typing import TYPE_CHECKING, Any, List, Literal, Mapping, Optional, Sequence
|
4
|
+
|
5
|
+
from wandb.sdk.artifacts._validators import (
|
6
|
+
REGISTRY_PREFIX,
|
7
|
+
validate_artifact_types_list,
|
8
|
+
)
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from wandb_gql import Client
|
12
|
+
|
13
|
+
from wandb_gql import gql
|
14
|
+
|
15
|
+
|
16
|
+
class _Visibility(str, Enum):
|
17
|
+
# names are what users see/pass into Python methods
|
18
|
+
# values are what's expected by backend API
|
19
|
+
organization = "PRIVATE"
|
20
|
+
restricted = "RESTRICTED"
|
21
|
+
|
22
|
+
@classmethod
|
23
|
+
def _missing_(cls, value: object) -> Any:
|
24
|
+
return next(
|
25
|
+
(e for e in cls if e.name == value),
|
26
|
+
None,
|
27
|
+
)
|
28
|
+
|
29
|
+
|
30
|
+
def _format_gql_artifact_types_input(
|
31
|
+
artifact_types: Optional[List[str]] = None,
|
32
|
+
):
|
33
|
+
"""Format the artifact types for the GQL input.
|
34
|
+
|
35
|
+
Args:
|
36
|
+
artifact_types: The artifact types to add to the registry.
|
37
|
+
|
38
|
+
Returns:
|
39
|
+
The artifact types for the GQL input.
|
40
|
+
"""
|
41
|
+
if artifact_types is None:
|
42
|
+
return []
|
43
|
+
new_types = validate_artifact_types_list(artifact_types)
|
44
|
+
return [{"name": type} for type in new_types]
|
45
|
+
|
46
|
+
|
47
|
+
def _gql_to_registry_visibility(
    visibility: str,
) -> Literal["organization", "restricted"]:
    """Convert the GQL visibility to the registry visibility.

    Args:
        visibility: The GQL visibility value (e.g. "PRIVATE" or "RESTRICTED").

    Returns:
        The registry visibility: "organization" or "restricted".

    Raises:
        ValueError: If `visibility` is not a recognized value.
    """
    try:
        return _Visibility(visibility).name
    except ValueError:
        # Suppress the enum's own error context; this message names the value.
        raise ValueError(f"Invalid visibility: {visibility!r} from backend") from None
|
62
|
+
|
63
|
+
|
64
|
+
def _registry_visibility_to_gql(
    visibility: Literal["organization", "restricted"],
) -> str:
    """Convert the registry visibility to the GQL visibility.

    Args:
        visibility: The user-facing visibility ("organization" or "restricted").

    Returns:
        The GQL visibility value ("PRIVATE" or "RESTRICTED").

    Raises:
        ValueError: If `visibility` is not a recognized name.
    """
    try:
        return _Visibility[visibility].value
    except KeyError:
        # Replace the bare KeyError with a message listing the valid options.
        raise ValueError(
            f"Invalid visibility: {visibility!r}. "
            f"Must be one of: {', '.join(map(repr, (e.name for e in _Visibility)))}"
        ) from None
|
75
|
+
|
76
|
+
|
77
|
+
def _ensure_registry_prefix_on_names(query, in_name=False):
|
78
|
+
"""Traverse the filter to prepend the `name` key value with the registry prefix unless the value is a regex.
|
79
|
+
|
80
|
+
- in_name: True if we are under a "name" key (or propagating from one).
|
81
|
+
|
82
|
+
EX: {"name": "model"} -> {"name": "wandb-registry-model"}
|
83
|
+
"""
|
84
|
+
if isinstance((txt := query), str):
|
85
|
+
if in_name:
|
86
|
+
return txt if txt.startswith(REGISTRY_PREFIX) else f"{REGISTRY_PREFIX}{txt}"
|
87
|
+
return txt
|
88
|
+
if isinstance((dct := query), Mapping):
|
89
|
+
new_dict = {}
|
90
|
+
for key, obj in dct.items():
|
91
|
+
if key == "name":
|
92
|
+
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=True)
|
93
|
+
elif key == "$regex":
|
94
|
+
# For regex operator, we skip transformation of its value.
|
95
|
+
new_dict[key] = obj
|
96
|
+
else:
|
97
|
+
# For any other key, propagate the in_name and skip_transform flags as-is.
|
98
|
+
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=in_name)
|
99
|
+
return new_dict
|
100
|
+
if isinstance((objs := query), Sequence):
|
101
|
+
return list(
|
102
|
+
map(lambda x: _ensure_registry_prefix_on_names(x, in_name=in_name), objs)
|
103
|
+
)
|
104
|
+
return query
|
105
|
+
|
106
|
+
|
107
|
+
@lru_cache(maxsize=10)
def _fetch_org_entity_from_organization(client: "Client", organization: str) -> str:
    """Fetch the org entity name for an organization.

    Successful results are memoized per `(client, organization)` pair by
    `lru_cache`. Failures raise, so they are not cached.

    Args:
        client (Client): Graphql client.
        organization (str): The organization to fetch the org entity for.

    Returns:
        str: The org entity name.

    Raises:
        ValueError: If the request fails, if the organization or its org
            entity is not found, or if the org entity name is empty.
    """
    query = gql(
        """
        query FetchOrgEntityFromOrganization($organization: String!) {
            organization(name: $organization) {
                orgEntity {
                    name
                }
            }
        }
        """
    )
    # Only the network call is wrapped, so the specific errors below are not
    # swallowed and re-wrapped as a generic fetch error.
    try:
        response = client.execute(query, variable_values={"organization": organization})
    except Exception as e:
        raise ValueError(
            f"Error fetching org entity for organization: {organization}"
        ) from e

    org = (response or {}).get("organization") or {}
    org_entity = org.get("orgEntity")
    if not org_entity:
        raise ValueError(
            f"Organization entity for organization: {organization} not found"
        )
    name = org_entity.get("name")
    if not name:
        # Previously this error was *returned* (and then cached) instead of raised.
        raise ValueError(
            f"Organization entity for organization: {organization} is empty"
        )
    return name
|
wandb/apis/public/runs.py
CHANGED
@@ -61,37 +61,36 @@ RUN_FRAGMENT = """fragment RunFragment on Run {
|
|
61
61
|
}"""
|
62
62
|
|
63
63
|
|
64
|
+
@normalize_exceptions
def _server_provides_internal_id_for_project(client) -> bool:
    """Returns True if the server exposes the `projectId` field on the `Run` type.

    This check is done by utilizing GraphQL introspection on the available
    fields of the `Run` type. It decides whether `projectId` can be
    requested when fetching a run.

    Note:
        The result is not cached; each call issues one introspection query.
    """
    query = gql(
        """
        query ProbeRunInput {
            RunType: __type(name:"Run") {
                fields {
                    name
                }
            }
        }
        """
    )
    res = client.execute(query)
    # `__type` can resolve to null and `fields` may be absent. Treat both as
    # "not provided" instead of failing with AttributeError/KeyError.
    fields = (res.get("RunType") or {}).get("fields") or []
    return any(field.get("name") == "projectId" for field in fields)
|
86
|
+
|
87
|
+
|
64
88
|
class Runs(SizedPaginator["Run"]):
|
65
89
|
"""An iterable collection of runs associated with a project and optional filter.
|
66
90
|
|
67
91
|
This is generally used indirectly via the `Api`.runs method.
|
68
92
|
"""
|
69
93
|
|
70
|
-
QUERY = gql(
|
71
|
-
"""
|
72
|
-
query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
|
73
|
-
project(name: $project, entityName: $entity) {{
|
74
|
-
internalId
|
75
|
-
runCount(filters: $filters)
|
76
|
-
readOnly
|
77
|
-
runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
|
78
|
-
edges {{
|
79
|
-
node {{
|
80
|
-
...RunFragment
|
81
|
-
}}
|
82
|
-
cursor
|
83
|
-
}}
|
84
|
-
pageInfo {{
|
85
|
-
endCursor
|
86
|
-
hasNextPage
|
87
|
-
}}
|
88
|
-
}}
|
89
|
-
}}
|
90
|
-
}}
|
91
|
-
{}
|
92
|
-
""".format(RUN_FRAGMENT)
|
93
|
-
)
|
94
|
-
|
95
94
|
def __init__(
|
96
95
|
self,
|
97
96
|
client: "RetryingClient",
|
@@ -102,6 +101,32 @@ class Runs(SizedPaginator["Run"]):
|
|
102
101
|
per_page: int = 50,
|
103
102
|
include_sweeps: bool = True,
|
104
103
|
):
|
104
|
+
self.QUERY = gql(
|
105
|
+
f"""#graphql
|
106
|
+
query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
|
107
|
+
project(name: $project, entityName: $entity) {{
|
108
|
+
internalId
|
109
|
+
runCount(filters: $filters)
|
110
|
+
readOnly
|
111
|
+
runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
|
112
|
+
edges {{
|
113
|
+
node {{
|
114
|
+
{"" if _server_provides_internal_id_for_project(client) else "internalId"}
|
115
|
+
...RunFragment
|
116
|
+
}}
|
117
|
+
cursor
|
118
|
+
}}
|
119
|
+
pageInfo {{
|
120
|
+
endCursor
|
121
|
+
hasNextPage
|
122
|
+
}}
|
123
|
+
}}
|
124
|
+
}}
|
125
|
+
}}
|
126
|
+
{RUN_FRAGMENT}
|
127
|
+
"""
|
128
|
+
)
|
129
|
+
|
105
130
|
self.entity = entity
|
106
131
|
self.project = project
|
107
132
|
self._project_internal_id = None
|
@@ -429,7 +454,9 @@ class Run(Attrs):
|
|
429
454
|
}}
|
430
455
|
{}
|
431
456
|
""".format(
|
432
|
-
"projectId"
|
457
|
+
"projectId"
|
458
|
+
if _server_provides_internal_id_for_project(self.client)
|
459
|
+
else "",
|
433
460
|
RUN_FRAGMENT,
|
434
461
|
)
|
435
462
|
)
|
@@ -444,10 +471,6 @@ class Run(Attrs):
|
|
444
471
|
self._attrs = response["project"]["run"]
|
445
472
|
self._state = self._attrs["state"]
|
446
473
|
|
447
|
-
self._project_internal_id = (
|
448
|
-
int(self._attrs["projectId"]) if "projectId" in self._attrs else None
|
449
|
-
)
|
450
|
-
|
451
474
|
if self._include_sweeps and self.sweep_name and not self.sweep:
|
452
475
|
# There may be a lot of runs. Don't bother pulling them all
|
453
476
|
# just for the sake of this one.
|
@@ -459,6 +482,11 @@ class Run(Attrs):
|
|
459
482
|
withRuns=False,
|
460
483
|
)
|
461
484
|
|
485
|
+
if "projectId" in self._attrs:
|
486
|
+
self._project_internal_id = int(self._attrs["projectId"])
|
487
|
+
else:
|
488
|
+
self._project_internal_id = None
|
489
|
+
|
462
490
|
try:
|
463
491
|
self._attrs["summaryMetrics"] = (
|
464
492
|
json.loads(self._attrs["summaryMetrics"])
|
@@ -912,32 +940,6 @@ class Run(Attrs):
|
|
912
940
|
)
|
913
941
|
return artifact
|
914
942
|
|
915
|
-
@normalize_exceptions
|
916
|
-
def _server_provides_internal_id_for_project(self) -> bool:
|
917
|
-
"""Returns True if the server allows us to query the internalId field for a project.
|
918
|
-
|
919
|
-
This check is done by utilizing GraphQL introspection in the available fields on the Project type.
|
920
|
-
"""
|
921
|
-
query_string = """
|
922
|
-
query ProbeRunInput {
|
923
|
-
RunType: __type(name:"Run") {
|
924
|
-
fields {
|
925
|
-
name
|
926
|
-
}
|
927
|
-
}
|
928
|
-
}
|
929
|
-
"""
|
930
|
-
|
931
|
-
# Only perform the query once to avoid extra network calls
|
932
|
-
if self.server_provides_internal_id_field is None:
|
933
|
-
query = gql(query_string)
|
934
|
-
res = self.client.execute(query)
|
935
|
-
self.server_provides_internal_id_field = "projectId" in [
|
936
|
-
x["name"] for x in (res.get("RunType", {}).get("fields", [{}]))
|
937
|
-
]
|
938
|
-
|
939
|
-
return self.server_provides_internal_id_field
|
940
|
-
|
941
943
|
@property
|
942
944
|
def summary(self):
|
943
945
|
if self._summary is None:
|