wandb 0.19.6rc4__py3-none-macosx_11_0_arm64.whl → 0.19.8__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +56 -6
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +11 -20
- wandb/data_types.py +1 -1
- wandb/env.py +10 -0
- wandb/filesync/dir_watcher.py +2 -1
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +51 -95
- wandb/sdk/backend/backend.py +17 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +87 -49
- wandb/sdk/interface/interface_queue.py +5 -15
- wandb/sdk/interface/interface_relay.py +7 -22
- wandb/sdk/interface/interface_shared.py +65 -136
- wandb/sdk/interface/interface_sock.py +3 -21
- wandb/sdk/interface/router.py +42 -68
- wandb/sdk/interface/router_queue.py +13 -11
- wandb/sdk/interface/router_relay.py +26 -13
- wandb/sdk/interface/router_sock.py +12 -16
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +3 -19
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/console_capture.py +172 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/redirect.py +102 -76
- wandb/sdk/lib/service_connection.py +37 -17
- wandb/sdk/lib/sock_client.py +6 -56
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/mailbox.py +135 -0
- wandb/sdk/mailbox/mailbox_handle.py +127 -0
- wandb/sdk/mailbox/response_handle.py +167 -0
- wandb/sdk/mailbox/wait_with_progress.py +135 -0
- wandb/sdk/service/server_sock.py +9 -3
- wandb/sdk/service/streams.py +75 -78
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +72 -75
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +90 -97
- wandb/sdk/wandb_settings.py +10 -4
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -10
- wandb/util.py +3 -1
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/METADATA +2 -2
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/RECORD +79 -66
- wandb/sdk/interface/message_future.py +0 -27
- wandb/sdk/interface/message_future_poll.py +0 -50
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/WHEEL +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/licenses/LICENSE +0 -0
wandb/apis/public/artifacts.py
CHANGED
@@ -12,56 +12,16 @@ from wandb.apis import public
|
|
12
12
|
from wandb.apis.normalize import normalize_exceptions
|
13
13
|
from wandb.apis.paginator import Paginator
|
14
14
|
from wandb.errors.term import termlog
|
15
|
+
from wandb.sdk.artifacts._graphql_fragments import (
|
16
|
+
ARTIFACT_FILES_FRAGMENT,
|
17
|
+
ARTIFACTS_TYPES_FRAGMENT,
|
18
|
+
)
|
15
19
|
from wandb.sdk.lib import deprecate
|
16
20
|
|
17
21
|
if TYPE_CHECKING:
|
18
22
|
from wandb.apis.public import RetryingClient, Run
|
19
23
|
|
20
24
|
|
21
|
-
ARTIFACTS_TYPES_FRAGMENT = """
|
22
|
-
fragment ArtifactTypesFragment on ArtifactTypeConnection {
|
23
|
-
edges {
|
24
|
-
node {
|
25
|
-
id
|
26
|
-
name
|
27
|
-
description
|
28
|
-
createdAt
|
29
|
-
}
|
30
|
-
cursor
|
31
|
-
}
|
32
|
-
pageInfo {
|
33
|
-
endCursor
|
34
|
-
hasNextPage
|
35
|
-
}
|
36
|
-
}
|
37
|
-
"""
|
38
|
-
|
39
|
-
# TODO, factor out common file fragment
|
40
|
-
ARTIFACT_FILES_FRAGMENT = """fragment ArtifactFilesFragment on Artifact {
|
41
|
-
files(names: $fileNames, after: $fileCursor, first: $fileLimit) {
|
42
|
-
edges {
|
43
|
-
node {
|
44
|
-
id
|
45
|
-
name: displayName
|
46
|
-
url
|
47
|
-
sizeBytes
|
48
|
-
storagePath
|
49
|
-
mimetype
|
50
|
-
updatedAt
|
51
|
-
digest
|
52
|
-
md5
|
53
|
-
directUrl
|
54
|
-
}
|
55
|
-
cursor
|
56
|
-
}
|
57
|
-
pageInfo {
|
58
|
-
endCursor
|
59
|
-
hasNextPage
|
60
|
-
}
|
61
|
-
}
|
62
|
-
}"""
|
63
|
-
|
64
|
-
|
65
25
|
class ArtifactTypes(Paginator):
|
66
26
|
QUERY = gql(
|
67
27
|
"""
|
@@ -317,6 +277,7 @@ class ArtifactCollection:
|
|
317
277
|
project: str,
|
318
278
|
name: str,
|
319
279
|
type: str,
|
280
|
+
organization: Optional[str] = None,
|
320
281
|
attrs: Optional[Mapping[str, Any]] = None,
|
321
282
|
):
|
322
283
|
self.client = client
|
@@ -327,11 +288,14 @@ class ArtifactCollection:
|
|
327
288
|
self._type = type
|
328
289
|
self._saved_type = type
|
329
290
|
self._attrs = attrs
|
330
|
-
self.
|
291
|
+
if self._attrs is None:
|
292
|
+
self.load()
|
331
293
|
self._aliases = [a["node"]["alias"] for a in self._attrs["aliases"]["edges"]]
|
332
294
|
self._description = self._attrs["description"]
|
295
|
+
self._created_at = self._attrs["createdAt"]
|
333
296
|
self._tags = [a["node"]["name"] for a in self._attrs["tags"]["edges"]]
|
334
297
|
self._saved_tags = copy(self._tags)
|
298
|
+
self.organization = organization
|
335
299
|
|
336
300
|
@property
|
337
301
|
def id(self):
|
@@ -354,6 +318,10 @@ class ArtifactCollection:
|
|
354
318
|
"""Artifact Collection Aliases."""
|
355
319
|
return self._aliases
|
356
320
|
|
321
|
+
@property
|
322
|
+
def created_at(self):
|
323
|
+
return self._created_at
|
324
|
+
|
357
325
|
def load(self):
|
358
326
|
query = gql(
|
359
327
|
"""
|
@@ -0,0 +1,573 @@
|
|
1
|
+
"""Public API: registries."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from wandb_gql import Client
|
8
|
+
|
9
|
+
from wandb_gql import gql
|
10
|
+
|
11
|
+
import wandb
|
12
|
+
from wandb.apis.paginator import Paginator
|
13
|
+
from wandb.apis.public.artifacts import ArtifactCollection
|
14
|
+
from wandb.sdk.artifacts._graphql_fragments import (
|
15
|
+
_gql_artifact_fragment,
|
16
|
+
_gql_registry_fragment,
|
17
|
+
)
|
18
|
+
from wandb.sdk.artifacts._validators import REGISTRY_PREFIX
|
19
|
+
|
20
|
+
|
21
|
+
class Registries(Paginator):
|
22
|
+
"""Iterator that returns Registries."""
|
23
|
+
|
24
|
+
QUERY = gql(
|
25
|
+
"""
|
26
|
+
query Registries($organization: String!, $filters: JSONString, $cursor: String, $perPage: Int) {
|
27
|
+
organization(name: $organization) {
|
28
|
+
orgEntity {
|
29
|
+
name
|
30
|
+
projects(filters: $filters, after: $cursor, first: $perPage) {
|
31
|
+
pageInfo {
|
32
|
+
endCursor
|
33
|
+
hasNextPage
|
34
|
+
}
|
35
|
+
edges {
|
36
|
+
node {
|
37
|
+
...RegistryFragment
|
38
|
+
}
|
39
|
+
}
|
40
|
+
}
|
41
|
+
}
|
42
|
+
}
|
43
|
+
}
|
44
|
+
"""
|
45
|
+
+ _gql_registry_fragment()
|
46
|
+
)
|
47
|
+
|
48
|
+
def __init__(
|
49
|
+
self,
|
50
|
+
client: "Client",
|
51
|
+
organization: str,
|
52
|
+
filter: Optional[Dict[str, Any]] = None,
|
53
|
+
per_page: Optional[int] = 100,
|
54
|
+
):
|
55
|
+
self.client = client
|
56
|
+
self.organization = organization
|
57
|
+
self.filter = _ensure_registry_prefix_on_names(filter or {})
|
58
|
+
variables = {
|
59
|
+
"organization": organization,
|
60
|
+
"filters": json.dumps(self.filter),
|
61
|
+
}
|
62
|
+
|
63
|
+
super().__init__(client, variables, per_page)
|
64
|
+
|
65
|
+
def __bool__(self):
|
66
|
+
return len(self) > 0 or len(self.objects) > 0
|
67
|
+
|
68
|
+
def __next__(self):
|
69
|
+
# Implement custom next since its possible to load empty pages because of auth
|
70
|
+
self.index += 1
|
71
|
+
while len(self.objects) <= self.index:
|
72
|
+
if not self._load_page():
|
73
|
+
raise StopIteration
|
74
|
+
return self.objects[self.index]
|
75
|
+
|
76
|
+
def collections(self, filter: Optional[Dict[str, Any]] = None) -> "Collections":
|
77
|
+
return Collections(
|
78
|
+
self.client,
|
79
|
+
self.organization,
|
80
|
+
registry_filter=self.filter,
|
81
|
+
collection_filter=filter,
|
82
|
+
)
|
83
|
+
|
84
|
+
def versions(self, filter: Optional[Dict[str, Any]] = None) -> "Versions":
|
85
|
+
return Versions(
|
86
|
+
self.client,
|
87
|
+
self.organization,
|
88
|
+
registry_filter=self.filter,
|
89
|
+
collection_filter=None,
|
90
|
+
artifact_filter=filter,
|
91
|
+
)
|
92
|
+
|
93
|
+
@property
|
94
|
+
def length(self):
|
95
|
+
if self.last_response:
|
96
|
+
return len(
|
97
|
+
self.last_response["organization"]["orgEntity"]["projects"]["edges"]
|
98
|
+
)
|
99
|
+
else:
|
100
|
+
return None
|
101
|
+
|
102
|
+
@property
|
103
|
+
def more(self):
|
104
|
+
if self.last_response:
|
105
|
+
return self.last_response["organization"]["orgEntity"]["projects"][
|
106
|
+
"pageInfo"
|
107
|
+
]["hasNextPage"]
|
108
|
+
else:
|
109
|
+
return True
|
110
|
+
|
111
|
+
@property
|
112
|
+
def cursor(self):
|
113
|
+
if self.last_response:
|
114
|
+
return self.last_response["organization"]["orgEntity"]["projects"][
|
115
|
+
"pageInfo"
|
116
|
+
]["endCursor"]
|
117
|
+
else:
|
118
|
+
return None
|
119
|
+
|
120
|
+
def convert_objects(self):
|
121
|
+
if not self.last_response:
|
122
|
+
return []
|
123
|
+
if (
|
124
|
+
not self.last_response["organization"]
|
125
|
+
or not self.last_response["organization"]["orgEntity"]
|
126
|
+
):
|
127
|
+
raise ValueError(
|
128
|
+
f"Organization '{self.organization}' not found. Please verify the organization name is correct"
|
129
|
+
)
|
130
|
+
|
131
|
+
return [
|
132
|
+
Registry(
|
133
|
+
self.client,
|
134
|
+
self.organization,
|
135
|
+
self.last_response["organization"]["orgEntity"]["name"],
|
136
|
+
r["node"]["name"],
|
137
|
+
r["node"],
|
138
|
+
)
|
139
|
+
for r in self.last_response["organization"]["orgEntity"]["projects"][
|
140
|
+
"edges"
|
141
|
+
]
|
142
|
+
]
|
143
|
+
|
144
|
+
|
145
|
+
class Registry:
|
146
|
+
"""A single registry in the Registry."""
|
147
|
+
|
148
|
+
def __init__(
|
149
|
+
self,
|
150
|
+
client: "Client",
|
151
|
+
organization: str,
|
152
|
+
entity: str,
|
153
|
+
full_name: str,
|
154
|
+
attrs: Dict[str, Any],
|
155
|
+
):
|
156
|
+
self.client = client
|
157
|
+
self._full_name = full_name
|
158
|
+
self._name = full_name.replace(REGISTRY_PREFIX, "")
|
159
|
+
self._entity = entity
|
160
|
+
self._organization = organization
|
161
|
+
self._description = attrs.get("description", "")
|
162
|
+
self._allow_all_artifact_types = attrs.get(
|
163
|
+
"allowAllArtifactTypesInRegistry", False
|
164
|
+
)
|
165
|
+
self._artifact_types = [
|
166
|
+
t["node"]["name"] for t in attrs.get("artifactTypes", {}).get("edges", [])
|
167
|
+
]
|
168
|
+
self._id = attrs.get("id", "")
|
169
|
+
self._created_at = attrs.get("createdAt", "")
|
170
|
+
self._updated_at = attrs.get("updatedAt", "")
|
171
|
+
|
172
|
+
@property
|
173
|
+
def full_name(self):
|
174
|
+
return self._full_name
|
175
|
+
|
176
|
+
@property
|
177
|
+
def name(self):
|
178
|
+
return self._name
|
179
|
+
|
180
|
+
@property
|
181
|
+
def entity(self):
|
182
|
+
return self._entity
|
183
|
+
|
184
|
+
@property
|
185
|
+
def organization(self):
|
186
|
+
return self._organization
|
187
|
+
|
188
|
+
@property
|
189
|
+
def description(self):
|
190
|
+
return self._description
|
191
|
+
|
192
|
+
@property
|
193
|
+
def allow_all_artifact_types(self):
|
194
|
+
return self._allow_all_artifact_types
|
195
|
+
|
196
|
+
@property
|
197
|
+
def artifact_types(self):
|
198
|
+
return self._artifact_types
|
199
|
+
|
200
|
+
@property
|
201
|
+
def created_at(self):
|
202
|
+
return self._created_at
|
203
|
+
|
204
|
+
@property
|
205
|
+
def updated_at(self):
|
206
|
+
return self._updated_at
|
207
|
+
|
208
|
+
@property
|
209
|
+
def path(self):
|
210
|
+
return [self.entity, self.name]
|
211
|
+
|
212
|
+
def collections(self, filter: Optional[Dict[str, Any]] = None):
|
213
|
+
registry_filter = {
|
214
|
+
"name": self.full_name,
|
215
|
+
}
|
216
|
+
return Collections(self.client, self.organization, registry_filter, filter)
|
217
|
+
|
218
|
+
def versions(self, filter: Optional[Dict[str, Any]] = None):
|
219
|
+
registry_filter = {
|
220
|
+
"name": self.full_name,
|
221
|
+
}
|
222
|
+
return Versions(self.client, self.organization, registry_filter, None, filter)
|
223
|
+
|
224
|
+
|
225
|
+
class Collections(Paginator):
|
226
|
+
"""Iterator that returns Artifact collections in the Registry."""
|
227
|
+
|
228
|
+
QUERY = gql(
|
229
|
+
"""
|
230
|
+
query Collections(
|
231
|
+
$organization: String!,
|
232
|
+
$registryFilter: JSONString,
|
233
|
+
$collectionFilter: JSONString,
|
234
|
+
$collectionTypes: [ArtifactCollectionType!],
|
235
|
+
$cursor: String,
|
236
|
+
$perPage: Int
|
237
|
+
) {
|
238
|
+
organization(name: $organization) {
|
239
|
+
orgEntity {
|
240
|
+
name
|
241
|
+
artifactCollections(
|
242
|
+
projectFilters: $registryFilter,
|
243
|
+
filters: $collectionFilter,
|
244
|
+
collectionTypes: $collectionTypes,
|
245
|
+
after: $cursor,
|
246
|
+
first: $perPage
|
247
|
+
) {
|
248
|
+
totalCount
|
249
|
+
pageInfo {
|
250
|
+
endCursor
|
251
|
+
hasNextPage
|
252
|
+
}
|
253
|
+
edges {
|
254
|
+
cursor
|
255
|
+
node {
|
256
|
+
id
|
257
|
+
name
|
258
|
+
description
|
259
|
+
createdAt
|
260
|
+
tags {
|
261
|
+
edges {
|
262
|
+
node {
|
263
|
+
name
|
264
|
+
}
|
265
|
+
}
|
266
|
+
}
|
267
|
+
project {
|
268
|
+
name
|
269
|
+
entity {
|
270
|
+
name
|
271
|
+
}
|
272
|
+
}
|
273
|
+
defaultArtifactType {
|
274
|
+
name
|
275
|
+
}
|
276
|
+
aliases {
|
277
|
+
edges {
|
278
|
+
node {
|
279
|
+
alias
|
280
|
+
}
|
281
|
+
}
|
282
|
+
}
|
283
|
+
}
|
284
|
+
}
|
285
|
+
}
|
286
|
+
}
|
287
|
+
}
|
288
|
+
}
|
289
|
+
"""
|
290
|
+
)
|
291
|
+
|
292
|
+
def __init__(
|
293
|
+
self,
|
294
|
+
client: "Client",
|
295
|
+
organization: str,
|
296
|
+
registry_filter: Optional[Dict[str, Any]] = None,
|
297
|
+
collection_filter: Optional[Dict[str, Any]] = None,
|
298
|
+
per_page: Optional[int] = 100,
|
299
|
+
):
|
300
|
+
self.client = client
|
301
|
+
self.organization = organization
|
302
|
+
self.registry_filter = registry_filter
|
303
|
+
self.collection_filter = collection_filter or {}
|
304
|
+
|
305
|
+
variables = {
|
306
|
+
"registryFilter": json.dumps(self.registry_filter)
|
307
|
+
if self.registry_filter
|
308
|
+
else None,
|
309
|
+
"collectionFilter": json.dumps(self.collection_filter)
|
310
|
+
if self.collection_filter
|
311
|
+
else None,
|
312
|
+
"organization": self.organization,
|
313
|
+
"collectionTypes": ["PORTFOLIO"],
|
314
|
+
"perPage": per_page,
|
315
|
+
}
|
316
|
+
|
317
|
+
super().__init__(client, variables, per_page)
|
318
|
+
|
319
|
+
def __bool__(self):
|
320
|
+
return len(self) > 0 or len(self.objects) > 0
|
321
|
+
|
322
|
+
def __next__(self):
|
323
|
+
# Implement custom next since its possible to load empty pages because of auth
|
324
|
+
self.index += 1
|
325
|
+
while len(self.objects) <= self.index:
|
326
|
+
if not self._load_page():
|
327
|
+
raise StopIteration
|
328
|
+
return self.objects[self.index]
|
329
|
+
|
330
|
+
def versions(self, filter: Optional[Dict[str, Any]] = None) -> "Versions":
|
331
|
+
return Versions(
|
332
|
+
self.client,
|
333
|
+
self.organization,
|
334
|
+
registry_filter=self.registry_filter,
|
335
|
+
collection_filter=self.collection_filter,
|
336
|
+
artifact_filter=filter,
|
337
|
+
)
|
338
|
+
|
339
|
+
@property
|
340
|
+
def length(self):
|
341
|
+
if self.last_response:
|
342
|
+
return self.last_response["organization"]["orgEntity"][
|
343
|
+
"artifactCollections"
|
344
|
+
]["totalCount"]
|
345
|
+
else:
|
346
|
+
return None
|
347
|
+
|
348
|
+
@property
|
349
|
+
def more(self):
|
350
|
+
if self.last_response:
|
351
|
+
return self.last_response["organization"]["orgEntity"][
|
352
|
+
"artifactCollections"
|
353
|
+
]["pageInfo"]["hasNextPage"]
|
354
|
+
else:
|
355
|
+
return True
|
356
|
+
|
357
|
+
@property
|
358
|
+
def cursor(self):
|
359
|
+
if self.last_response:
|
360
|
+
return self.last_response["organization"]["orgEntity"][
|
361
|
+
"artifactCollections"
|
362
|
+
]["pageInfo"]["endCursor"]
|
363
|
+
else:
|
364
|
+
return None
|
365
|
+
|
366
|
+
def convert_objects(self):
|
367
|
+
if not self.last_response:
|
368
|
+
return []
|
369
|
+
if (
|
370
|
+
not self.last_response["organization"]
|
371
|
+
or not self.last_response["organization"]["orgEntity"]
|
372
|
+
):
|
373
|
+
raise ValueError(
|
374
|
+
f"Organization '{self.organization}' not found. Please verify the organization name is correct"
|
375
|
+
)
|
376
|
+
|
377
|
+
return [
|
378
|
+
ArtifactCollection(
|
379
|
+
self.client,
|
380
|
+
r["node"]["project"]["entity"]["name"],
|
381
|
+
r["node"]["project"]["name"],
|
382
|
+
r["node"]["name"],
|
383
|
+
r["node"]["defaultArtifactType"]["name"],
|
384
|
+
self.organization,
|
385
|
+
r["node"],
|
386
|
+
)
|
387
|
+
for r in self.last_response["organization"]["orgEntity"][
|
388
|
+
"artifactCollections"
|
389
|
+
]["edges"]
|
390
|
+
]
|
391
|
+
|
392
|
+
|
393
|
+
class Versions(Paginator):
|
394
|
+
"""Iterator that returns Artifact versions in the Registry."""
|
395
|
+
|
396
|
+
def __init__(
|
397
|
+
self,
|
398
|
+
client: "Client",
|
399
|
+
organization: str,
|
400
|
+
registry_filter: Optional[Dict[str, Any]] = None,
|
401
|
+
collection_filter: Optional[Dict[str, Any]] = None,
|
402
|
+
artifact_filter: Optional[Dict[str, Any]] = None,
|
403
|
+
per_page: int = 100,
|
404
|
+
):
|
405
|
+
self.client = client
|
406
|
+
self.organization = organization
|
407
|
+
self.registry_filter = registry_filter
|
408
|
+
self.collection_filter = collection_filter
|
409
|
+
self.artifact_filter = artifact_filter or {}
|
410
|
+
self.QUERY = gql(
|
411
|
+
"""
|
412
|
+
query Versions(
|
413
|
+
$organization: String!,
|
414
|
+
$registryFilter: JSONString,
|
415
|
+
$collectionFilter: JSONString,
|
416
|
+
$artifactFilter: JSONString,
|
417
|
+
$cursor: String,
|
418
|
+
$perPage: Int
|
419
|
+
) {
|
420
|
+
organization(name: $organization) {
|
421
|
+
orgEntity {
|
422
|
+
name
|
423
|
+
artifactMemberships(
|
424
|
+
projectFilters: $registryFilter,
|
425
|
+
collectionFilters: $collectionFilter,
|
426
|
+
filters: $artifactFilter,
|
427
|
+
after: $cursor,
|
428
|
+
first: $perPage
|
429
|
+
) {
|
430
|
+
pageInfo {
|
431
|
+
endCursor
|
432
|
+
hasNextPage
|
433
|
+
}
|
434
|
+
edges {
|
435
|
+
node {
|
436
|
+
artifactCollection {
|
437
|
+
project {
|
438
|
+
name
|
439
|
+
entity {
|
440
|
+
name
|
441
|
+
}
|
442
|
+
}
|
443
|
+
name
|
444
|
+
}
|
445
|
+
versionIndex
|
446
|
+
artifact {
|
447
|
+
...ArtifactFragment
|
448
|
+
}
|
449
|
+
aliases {
|
450
|
+
alias
|
451
|
+
}
|
452
|
+
}
|
453
|
+
}
|
454
|
+
}
|
455
|
+
}
|
456
|
+
}
|
457
|
+
}
|
458
|
+
"""
|
459
|
+
+ _gql_artifact_fragment(include_aliases=False)
|
460
|
+
)
|
461
|
+
|
462
|
+
variables = {
|
463
|
+
"registryFilter": json.dumps(self.registry_filter)
|
464
|
+
if self.registry_filter
|
465
|
+
else None,
|
466
|
+
"collectionFilter": json.dumps(self.collection_filter)
|
467
|
+
if self.collection_filter
|
468
|
+
else None,
|
469
|
+
"artifactFilter": json.dumps(self.artifact_filter)
|
470
|
+
if self.artifact_filter
|
471
|
+
else None,
|
472
|
+
"organization": self.organization,
|
473
|
+
}
|
474
|
+
|
475
|
+
super().__init__(client, variables, per_page)
|
476
|
+
|
477
|
+
def __next__(self):
|
478
|
+
# Implement custom next since its possible to load empty pages because of auth
|
479
|
+
self.index += 1
|
480
|
+
while len(self.objects) <= self.index:
|
481
|
+
if not self._load_page():
|
482
|
+
raise StopIteration
|
483
|
+
return self.objects[self.index]
|
484
|
+
|
485
|
+
def __bool__(self):
|
486
|
+
return len(self) > 0 or len(self.objects) > 0
|
487
|
+
|
488
|
+
@property
|
489
|
+
def length(self):
|
490
|
+
if self.last_response:
|
491
|
+
return len(
|
492
|
+
self.last_response["organization"]["orgEntity"]["artifactMemberships"][
|
493
|
+
"edges"
|
494
|
+
]
|
495
|
+
)
|
496
|
+
else:
|
497
|
+
return None
|
498
|
+
|
499
|
+
@property
|
500
|
+
def more(self):
|
501
|
+
if self.last_response:
|
502
|
+
return self.last_response["organization"]["orgEntity"][
|
503
|
+
"artifactMemberships"
|
504
|
+
]["pageInfo"]["hasNextPage"]
|
505
|
+
else:
|
506
|
+
return True
|
507
|
+
|
508
|
+
@property
|
509
|
+
def cursor(self):
|
510
|
+
if self.last_response:
|
511
|
+
return self.last_response["organization"]["orgEntity"][
|
512
|
+
"artifactMemberships"
|
513
|
+
]["pageInfo"]["endCursor"]
|
514
|
+
else:
|
515
|
+
return None
|
516
|
+
|
517
|
+
def convert_objects(self):
|
518
|
+
if not self.last_response:
|
519
|
+
return []
|
520
|
+
if (
|
521
|
+
not self.last_response["organization"]
|
522
|
+
or not self.last_response["organization"]["orgEntity"]
|
523
|
+
):
|
524
|
+
raise ValueError(
|
525
|
+
f"Organization '{self.organization}' not found. Please verify the organization name is correct"
|
526
|
+
)
|
527
|
+
|
528
|
+
artifacts = (
|
529
|
+
wandb.Artifact._from_attrs(
|
530
|
+
a["node"]["artifactCollection"]["project"]["entity"]["name"],
|
531
|
+
a["node"]["artifactCollection"]["project"]["name"],
|
532
|
+
a["node"]["artifactCollection"]["name"]
|
533
|
+
+ ":v"
|
534
|
+
+ str(a["node"]["versionIndex"]),
|
535
|
+
a["node"]["artifact"],
|
536
|
+
self.client,
|
537
|
+
[alias["alias"] for alias in a["node"]["aliases"]],
|
538
|
+
)
|
539
|
+
for a in self.last_response["organization"]["orgEntity"][
|
540
|
+
"artifactMemberships"
|
541
|
+
]["edges"]
|
542
|
+
)
|
543
|
+
return artifacts
|
544
|
+
|
545
|
+
|
546
|
+
def _ensure_registry_prefix_on_names(query, in_name=False):
|
547
|
+
"""Traverse the filter to prepend the `name` key value with the registry prefix unless the value is a regex.
|
548
|
+
|
549
|
+
- in_name: True if we are under a "name" key (or propagating from one).
|
550
|
+
|
551
|
+
EX: {"name": "model"} -> {"name": "wandb-registry-model"}
|
552
|
+
"""
|
553
|
+
if isinstance((txt := query), str):
|
554
|
+
if in_name:
|
555
|
+
return txt if txt.startswith(REGISTRY_PREFIX) else f"{REGISTRY_PREFIX}{txt}"
|
556
|
+
return txt
|
557
|
+
if isinstance((dct := query), Mapping):
|
558
|
+
new_dict = {}
|
559
|
+
for key, obj in dct.items():
|
560
|
+
if key == "name":
|
561
|
+
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=True)
|
562
|
+
elif key == "$regex":
|
563
|
+
# For regex operator, we skip transformation of its value.
|
564
|
+
new_dict[key] = obj
|
565
|
+
else:
|
566
|
+
# For any other key, propagate the in_name and skip_transform flags as-is.
|
567
|
+
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=in_name)
|
568
|
+
return new_dict
|
569
|
+
if isinstance((objs := query), Sequence):
|
570
|
+
return list(
|
571
|
+
map(lambda x: _ensure_registry_prefix_on_names(x, in_name=in_name), objs)
|
572
|
+
)
|
573
|
+
return query
|