wandb 0.19.9__py3-none-win_amd64.whl → 0.19.11__py3-none-win_amd64.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (156)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +6 -3
  3. wandb/_pydantic/__init__.py +14 -8
  4. wandb/_pydantic/base.py +51 -36
  5. wandb/_pydantic/utils.py +73 -0
  6. wandb/_pydantic/v1_compat.py +79 -57
  7. wandb/apis/public/__init__.py +2 -2
  8. wandb/apis/public/api.py +684 -4
  9. wandb/apis/public/artifacts.py +377 -677
  10. wandb/apis/public/automations.py +69 -0
  11. wandb/apis/public/integrations.py +180 -0
  12. wandb/apis/public/projects.py +29 -0
  13. wandb/apis/public/registries/__init__.py +0 -0
  14. wandb/apis/public/registries/_freezable_list.py +179 -0
  15. wandb/apis/public/{registries.py → registries/registries_search.py} +22 -129
  16. wandb/apis/public/registries/registry.py +357 -0
  17. wandb/apis/public/registries/utils.py +140 -0
  18. wandb/apis/public/runs.py +58 -56
  19. wandb/apis/public/utils.py +107 -1
  20. wandb/automations/__init__.py +73 -0
  21. wandb/automations/_filters/__init__.py +40 -0
  22. wandb/automations/_filters/expressions.py +181 -0
  23. wandb/automations/_filters/operators.py +258 -0
  24. wandb/automations/_filters/run_metrics.py +332 -0
  25. wandb/automations/_generated/__init__.py +177 -0
  26. wandb/automations/_generated/create_automation.py +17 -0
  27. wandb/automations/_generated/create_generic_webhook_integration.py +43 -0
  28. wandb/automations/_generated/delete_automation.py +17 -0
  29. wandb/automations/_generated/enums.py +33 -0
  30. wandb/automations/_generated/fragments.py +358 -0
  31. wandb/automations/_generated/generic_webhook_integrations_by_entity.py +22 -0
  32. wandb/automations/_generated/get_automations.py +24 -0
  33. wandb/automations/_generated/get_automations_by_entity.py +26 -0
  34. wandb/automations/_generated/input_types.py +104 -0
  35. wandb/automations/_generated/integrations_by_entity.py +22 -0
  36. wandb/automations/_generated/operations.py +647 -0
  37. wandb/automations/_generated/slack_integrations_by_entity.py +22 -0
  38. wandb/automations/_generated/update_automation.py +17 -0
  39. wandb/automations/_utils.py +237 -0
  40. wandb/automations/_validators.py +165 -0
  41. wandb/automations/actions.py +220 -0
  42. wandb/automations/automations.py +87 -0
  43. wandb/automations/events.py +287 -0
  44. wandb/automations/integrations.py +45 -0
  45. wandb/automations/scopes.py +78 -0
  46. wandb/beta/workflows.py +9 -10
  47. wandb/bin/gpu_stats.exe +0 -0
  48. wandb/bin/wandb-core +0 -0
  49. wandb/cli/cli.py +3 -3
  50. wandb/env.py +11 -0
  51. wandb/integration/keras/keras.py +2 -1
  52. wandb/integration/langchain/wandb_tracer.py +2 -1
  53. wandb/jupyter.py +137 -118
  54. wandb/old/settings.py +4 -1
  55. wandb/old/summary.py +0 -2
  56. wandb/proto/v3/wandb_internal_pb2.py +297 -292
  57. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  58. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  59. wandb/proto/v4/wandb_internal_pb2.py +292 -292
  60. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  61. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  62. wandb/proto/v5/wandb_internal_pb2.py +292 -292
  63. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  64. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  65. wandb/proto/v6/wandb_base_pb2.py +41 -0
  66. wandb/proto/v6/wandb_internal_pb2.py +393 -0
  67. wandb/proto/v6/wandb_server_pb2.py +78 -0
  68. wandb/proto/v6/wandb_settings_pb2.py +58 -0
  69. wandb/proto/v6/wandb_telemetry_pb2.py +52 -0
  70. wandb/proto/wandb_base_pb2.py +2 -0
  71. wandb/proto/wandb_deprecated.py +8 -0
  72. wandb/proto/wandb_internal_pb2.py +3 -1
  73. wandb/proto/wandb_server_pb2.py +2 -0
  74. wandb/proto/wandb_settings_pb2.py +2 -0
  75. wandb/proto/wandb_telemetry_pb2.py +2 -0
  76. wandb/sdk/artifacts/_generated/__init__.py +289 -0
  77. wandb/sdk/artifacts/_generated/add_aliases.py +21 -0
  78. wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +43 -0
  79. wandb/sdk/artifacts/_generated/artifact_version_files.py +36 -0
  80. wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +36 -0
  81. wandb/sdk/artifacts/_generated/delete_aliases.py +21 -0
  82. wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +25 -0
  83. wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +35 -0
  84. wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +35 -0
  85. wandb/sdk/artifacts/_generated/enums.py +17 -0
  86. wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +67 -0
  87. wandb/sdk/artifacts/_generated/fragments.py +221 -0
  88. wandb/sdk/artifacts/_generated/input_types.py +28 -0
  89. wandb/sdk/artifacts/_generated/move_artifact_collection.py +35 -0
  90. wandb/sdk/artifacts/_generated/operations.py +611 -0
  91. wandb/sdk/artifacts/_generated/project_artifact_collection.py +101 -0
  92. wandb/sdk/artifacts/_generated/project_artifact_collections.py +33 -0
  93. wandb/sdk/artifacts/_generated/project_artifact_type.py +24 -0
  94. wandb/sdk/artifacts/_generated/project_artifact_types.py +24 -0
  95. wandb/sdk/artifacts/_generated/project_artifacts.py +42 -0
  96. wandb/sdk/artifacts/_generated/run_input_artifacts.py +51 -0
  97. wandb/sdk/artifacts/_generated/run_output_artifacts.py +51 -0
  98. wandb/sdk/artifacts/_generated/update_artifact.py +26 -0
  99. wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +35 -0
  100. wandb/sdk/artifacts/_generated/update_artifact_sequence.py +35 -0
  101. wandb/sdk/artifacts/_graphql_fragments.py +57 -79
  102. wandb/sdk/artifacts/_validators.py +120 -1
  103. wandb/sdk/artifacts/artifact.py +419 -215
  104. wandb/sdk/artifacts/artifact_file_cache.py +4 -6
  105. wandb/sdk/artifacts/artifact_manifest_entry.py +13 -3
  106. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  107. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +182 -1
  108. wandb/sdk/artifacts/storage_policy.py +3 -0
  109. wandb/sdk/data_types/base_types/media.py +2 -3
  110. wandb/sdk/data_types/base_types/wb_value.py +34 -11
  111. wandb/sdk/data_types/html.py +36 -9
  112. wandb/sdk/data_types/image.py +12 -12
  113. wandb/sdk/data_types/table.py +5 -0
  114. wandb/sdk/data_types/trace_tree.py +2 -0
  115. wandb/sdk/data_types/utils.py +1 -1
  116. wandb/sdk/data_types/video.py +59 -57
  117. wandb/sdk/interface/interface.py +4 -3
  118. wandb/sdk/internal/internal_api.py +21 -31
  119. wandb/sdk/internal/profiler.py +6 -5
  120. wandb/sdk/internal/run.py +13 -6
  121. wandb/sdk/internal/sender.py +5 -2
  122. wandb/sdk/launch/sweeps/utils.py +8 -0
  123. wandb/sdk/lib/apikey.py +25 -4
  124. wandb/sdk/lib/asyncio_compat.py +1 -1
  125. wandb/sdk/lib/deprecate.py +13 -22
  126. wandb/sdk/lib/disabled.py +2 -1
  127. wandb/sdk/lib/printer.py +37 -8
  128. wandb/sdk/lib/printer_asyncio.py +46 -0
  129. wandb/sdk/lib/redirect.py +10 -5
  130. wandb/sdk/projects/_generated/__init__.py +47 -0
  131. wandb/sdk/projects/_generated/delete_project.py +22 -0
  132. wandb/sdk/projects/_generated/enums.py +4 -0
  133. wandb/sdk/projects/_generated/fetch_registry.py +22 -0
  134. wandb/sdk/projects/_generated/fragments.py +41 -0
  135. wandb/sdk/projects/_generated/input_types.py +13 -0
  136. wandb/sdk/projects/_generated/operations.py +88 -0
  137. wandb/sdk/projects/_generated/rename_project.py +27 -0
  138. wandb/sdk/projects/_generated/upsert_registry_project.py +27 -0
  139. wandb/sdk/service/server_sock.py +19 -14
  140. wandb/sdk/service/service.py +18 -8
  141. wandb/sdk/service/streams.py +5 -0
  142. wandb/sdk/verify/verify.py +6 -3
  143. wandb/sdk/wandb_init.py +217 -70
  144. wandb/sdk/wandb_login.py +13 -4
  145. wandb/sdk/wandb_run.py +419 -295
  146. wandb/sdk/wandb_settings.py +27 -10
  147. wandb/sdk/wandb_setup.py +61 -0
  148. wandb/util.py +33 -29
  149. {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/METADATA +5 -5
  150. {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/RECORD +153 -83
  151. wandb/_globals.py +0 -19
  152. wandb/sdk/internal/_generated/base.py +0 -226
  153. wandb/sdk/internal/_generated/typing_compat.py +0 -14
  154. {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/WHEEL +0 -0
  155. {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/entry_points.txt +0 -0
  156. {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/licenses/LICENSE +0 -0
@@ -3,8 +3,20 @@
3
3
  import json
4
4
  import re
5
5
  from copy import copy
6
- from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Union
6
+ from typing import (
7
+ TYPE_CHECKING,
8
+ Any,
9
+ Iterable,
10
+ List,
11
+ Literal,
12
+ Mapping,
13
+ Optional,
14
+ Sequence,
15
+ Type,
16
+ Union,
17
+ )
7
18
 
19
+ from typing_extensions import override
8
20
  from wandb_gql import Client, gql
9
21
 
10
22
  import wandb
@@ -12,35 +24,58 @@ from wandb.apis import public
12
24
  from wandb.apis.normalize import normalize_exceptions
13
25
  from wandb.apis.paginator import Paginator, SizedPaginator
14
26
  from wandb.errors.term import termlog
27
+ from wandb.proto.wandb_deprecated import Deprecated
15
28
  from wandb.proto.wandb_internal_pb2 import ServerFeature
16
- from wandb.sdk.artifacts._graphql_fragments import (
17
- ARTIFACT_FILES_FRAGMENT,
18
- ARTIFACTS_TYPES_FRAGMENT,
29
+ from wandb.sdk.artifacts._generated import (
30
+ ARTIFACT_COLLECTION_MEMBERSHIP_FILES_GQL,
31
+ ARTIFACT_VERSION_FILES_GQL,
32
+ CREATE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL,
33
+ DELETE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL,
34
+ DELETE_ARTIFACT_PORTFOLIO_GQL,
35
+ DELETE_ARTIFACT_SEQUENCE_GQL,
36
+ MOVE_ARTIFACT_COLLECTION_GQL,
37
+ PROJECT_ARTIFACT_COLLECTION_GQL,
38
+ PROJECT_ARTIFACT_COLLECTIONS_GQL,
39
+ PROJECT_ARTIFACT_TYPE_GQL,
40
+ PROJECT_ARTIFACT_TYPES_GQL,
41
+ PROJECT_ARTIFACTS_GQL,
42
+ RUN_INPUT_ARTIFACTS_GQL,
43
+ RUN_OUTPUT_ARTIFACTS_GQL,
44
+ UPDATE_ARTIFACT_PORTFOLIO_GQL,
45
+ UPDATE_ARTIFACT_SEQUENCE_GQL,
46
+ ArtifactCollectionMembershipFiles,
47
+ ArtifactCollectionsFragment,
48
+ ArtifactsFragment,
49
+ ArtifactTypeFragment,
50
+ ArtifactTypesFragment,
51
+ ArtifactVersionFiles,
52
+ FilesFragment,
53
+ ProjectArtifactCollection,
54
+ ProjectArtifactCollections,
55
+ ProjectArtifacts,
56
+ ProjectArtifactType,
57
+ ProjectArtifactTypes,
58
+ RunInputArtifactsProjectRunInputArtifacts,
59
+ RunOutputArtifactsProjectRunOutputArtifacts,
60
+ )
61
+ from wandb.sdk.artifacts._graphql_fragments import omit_artifact_fields
62
+ from wandb.sdk.artifacts._validators import (
63
+ SOURCE_ARTIFACT_COLLECTION_TYPE,
64
+ validate_artifact_name,
19
65
  )
20
66
  from wandb.sdk.internal.internal_api import Api as InternalApi
21
67
  from wandb.sdk.lib import deprecate
22
68
 
69
+ from .utils import gql_compat
70
+
23
71
  if TYPE_CHECKING:
24
72
  from wandb.apis.public import RetryingClient, Run
25
73
 
26
74
 
27
75
  class ArtifactTypes(Paginator["ArtifactType"]):
28
- QUERY = gql(
29
- """
30
- query ProjectArtifacts(
31
- $entityName: String!,
32
- $projectName: String!,
33
- $cursor: String,
34
- ) {{
35
- project(name: $projectName, entityName: $entityName) {{
36
- artifactTypes(after: $cursor) {{
37
- ...ArtifactTypesFragment
38
- }}
39
- }}
40
- }}
41
- {}
42
- """.format(ARTIFACTS_TYPES_FRAGMENT)
43
- )
76
+ QUERY = gql(PROJECT_ARTIFACT_TYPES_GQL)
77
+
78
+ last_response: Optional[ArtifactTypesFragment]
44
79
 
45
80
  def __init__(
46
81
  self,
@@ -56,41 +91,54 @@ class ArtifactTypes(Paginator["ArtifactType"]):
56
91
  "entityName": entity,
57
92
  "projectName": project,
58
93
  }
59
-
60
94
  super().__init__(client, variable_values, per_page)
61
95
 
96
+ @override
97
+ def _update_response(self) -> None:
98
+ """Fetch and validate the response data for the current page."""
99
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
100
+ result = ProjectArtifactTypes.model_validate(data)
101
+
102
+ # Extract the inner `*Connection` result for faster/easier access.
103
+ if not ((proj := result.project) and (conn := proj.artifact_types)):
104
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
105
+
106
+ self.last_response = ArtifactTypesFragment.model_validate(conn)
107
+
62
108
  @property
63
109
  def length(self) -> None:
64
110
  # TODO
65
111
  return None
66
112
 
67
113
  @property
68
- def more(self):
69
- if self.last_response:
70
- return self.last_response["project"]["artifactTypes"]["pageInfo"][
71
- "hasNextPage"
72
- ]
73
- else:
114
+ def more(self) -> bool:
115
+ if self.last_response is None:
74
116
  return True
117
+ return self.last_response.page_info.has_next_page
75
118
 
76
119
  @property
77
- def cursor(self):
78
- if self.last_response:
79
- return self.last_response["project"]["artifactTypes"]["edges"][-1]["cursor"]
80
- else:
120
+ def cursor(self) -> Optional[str]:
121
+ if self.last_response is None:
81
122
  return None
123
+ return self.last_response.edges[-1].cursor
82
124
 
83
- def update_variables(self):
125
+ def update_variables(self) -> None:
84
126
  self.variables.update({"cursor": self.cursor})
85
127
 
86
- def convert_objects(self):
87
- if self.last_response["project"] is None:
128
+ def convert_objects(self) -> List["ArtifactType"]:
129
+ if self.last_response is None:
88
130
  return []
131
+
89
132
  return [
90
133
  ArtifactType(
91
- self.client, self.entity, self.project, r["node"]["name"], r["node"]
134
+ client=self.client,
135
+ entity=self.entity,
136
+ project=self.project,
137
+ type_name=node.name,
138
+ attrs=node.model_dump(exclude_unset=True),
92
139
  )
93
- for r in self.last_response["project"]["artifactTypes"]["edges"]
140
+ for edge in self.last_response.edges
141
+ if edge.node and (node := ArtifactTypeFragment.model_validate(edge.node))
94
142
  ]
95
143
 
96
144
 
@@ -112,39 +160,19 @@ class ArtifactType:
112
160
  self.load()
113
161
 
114
162
  def load(self):
115
- query = gql(
116
- """
117
- query ProjectArtifactType(
118
- $entityName: String!,
119
- $projectName: String!,
120
- $artifactTypeName: String!
121
- ) {
122
- project(name: $projectName, entityName: $entityName) {
123
- artifactType(name: $artifactTypeName) {
124
- id
125
- name
126
- description
127
- createdAt
128
- }
129
- }
130
- }
131
- """
132
- )
133
- response: Optional[Mapping[str, Any]] = self.client.execute(
134
- query,
163
+ data: Optional[Mapping[str, Any]] = self.client.execute(
164
+ gql(PROJECT_ARTIFACT_TYPE_GQL),
135
165
  variable_values={
136
166
  "entityName": self.entity,
137
167
  "projectName": self.project,
138
168
  "artifactTypeName": self.type,
139
169
  },
140
170
  )
141
- if (
142
- response is None
143
- or response.get("project") is None
144
- or response["project"].get("artifactType") is None
145
- ):
146
- raise ValueError("Could not find artifact type {}".format(self.type))
147
- self._attrs = response["project"]["artifactType"]
171
+ result = ProjectArtifactType.model_validate(data)
172
+ if not ((proj := result.project) and (artifact_type := proj.artifact_type)):
173
+ raise ValueError(f"Could not find artifact type {self.type}")
174
+
175
+ self._attrs = artifact_type.model_dump(exclude_unset=True)
148
176
  return self._attrs
149
177
 
150
178
  @property
@@ -170,6 +198,8 @@ class ArtifactType:
170
198
 
171
199
 
172
200
  class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
201
+ last_response: Optional[ArtifactCollectionsFragment]
202
+
173
203
  def __init__(
174
204
  self,
175
205
  client: Client,
@@ -188,86 +218,65 @@ class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
188
218
  "artifactTypeName": type_name,
189
219
  }
190
220
 
191
- self.QUERY = gql(
192
- """
193
- query ProjectArtifactCollections(
194
- $entityName: String!,
195
- $projectName: String!,
196
- $artifactTypeName: String!
197
- $cursor: String,
198
- ) {{
199
- project(name: $projectName, entityName: $entityName) {{
200
- artifactType(name: $artifactTypeName) {{
201
- artifactCollections: {}(after: $cursor) {{
202
- pageInfo {{
203
- endCursor
204
- hasNextPage
205
- }}
206
- totalCount
207
- edges {{
208
- node {{
209
- id
210
- name
211
- description
212
- createdAt
213
- }}
214
- cursor
215
- }}
216
- }}
217
- }}
218
- }}
219
- }}
220
- """.format(
221
- artifact_collection_plural_edge_name(
222
- server_supports_artifact_collections_gql_edges(client)
223
- )
224
- )
221
+ if server_supports_artifact_collections_gql_edges(client):
222
+ rename_fields = None
223
+ else:
224
+ rename_fields = {"artifactCollections": "artifactSequences"}
225
+
226
+ self.QUERY = gql_compat(
227
+ PROJECT_ARTIFACT_COLLECTIONS_GQL, rename_fields=rename_fields
225
228
  )
226
229
 
227
230
  super().__init__(client, variable_values, per_page)
228
231
 
232
+ @override
233
+ def _update_response(self) -> None:
234
+ """Fetch and validate the response data for the current page."""
235
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
236
+ result = ProjectArtifactCollections.model_validate(data)
237
+
238
+ # Extract the inner `*Connection` result for faster/easier access.
239
+ if not (
240
+ (proj := result.project)
241
+ and (type_ := proj.artifact_type)
242
+ and (conn := type_.artifact_collections)
243
+ ):
244
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
245
+
246
+ self.last_response = ArtifactCollectionsFragment.model_validate(conn)
247
+
229
248
  @property
230
249
  def length(self):
231
- if self.last_response:
232
- return self.last_response["project"]["artifactType"]["artifactCollections"][
233
- "totalCount"
234
- ]
235
- else:
250
+ if self.last_response is None:
236
251
  return None
252
+ return self.last_response.total_count
237
253
 
238
254
  @property
239
255
  def more(self):
240
- if self.last_response:
241
- return self.last_response["project"]["artifactType"]["artifactCollections"][
242
- "pageInfo"
243
- ]["hasNextPage"]
244
- else:
256
+ if self.last_response is None:
245
257
  return True
258
+ return self.last_response.page_info.has_next_page
246
259
 
247
260
  @property
248
261
  def cursor(self):
249
- if self.last_response:
250
- return self.last_response["project"]["artifactType"]["artifactCollections"][
251
- "edges"
252
- ][-1]["cursor"]
253
- else:
262
+ if self.last_response is None:
254
263
  return None
264
+ return self.last_response.edges[-1].cursor
255
265
 
256
- def update_variables(self):
266
+ def update_variables(self) -> None:
257
267
  self.variables.update({"cursor": self.cursor})
258
268
 
259
- def convert_objects(self):
269
+ def convert_objects(self) -> List["ArtifactCollection"]:
260
270
  return [
261
271
  ArtifactCollection(
262
- self.client,
263
- self.entity,
264
- self.project,
265
- r["node"]["name"],
266
- self.type_name,
272
+ client=self.client,
273
+ entity=self.entity,
274
+ project=self.project,
275
+ name=node.name,
276
+ type=self.type_name,
267
277
  )
268
- for r in self.last_response["project"]["artifactType"][
269
- "artifactCollections"
270
- ]["edges"]
278
+ for edge in self.last_response.edges
279
+ if (node := edge.node)
271
280
  ]
272
281
 
273
282
 
@@ -281,16 +290,20 @@ class ArtifactCollection:
281
290
  type: str,
282
291
  organization: Optional[str] = None,
283
292
  attrs: Optional[Mapping[str, Any]] = None,
293
+ is_sequence: Optional[bool] = None,
284
294
  ):
285
295
  self.client = client
286
296
  self.entity = entity
287
297
  self.project = project
288
- self._name = name
298
+ self._name = validate_artifact_name(name)
289
299
  self._saved_name = name
290
300
  self._type = type
291
301
  self._saved_type = type
292
302
  self._attrs = attrs
293
- if self._attrs is None:
303
+ if is_sequence is not None:
304
+ self._is_sequence = is_sequence
305
+ is_loaded = attrs is not None and is_sequence is not None
306
+ if not is_loaded:
294
307
  self.load()
295
308
  self._aliases = [a["node"]["alias"] for a in self._attrs["aliases"]["edges"]]
296
309
  self._description = self._attrs["description"]
@@ -300,83 +313,38 @@ class ArtifactCollection:
300
313
  self.organization = organization
301
314
 
302
315
  @property
303
- def id(self):
316
+ def id(self) -> str:
304
317
  return self._attrs["id"]
305
318
 
306
319
  @normalize_exceptions
307
- def artifacts(self, per_page=50):
320
+ def artifacts(self, per_page: int = 50) -> "Artifacts":
308
321
  """Artifacts."""
309
322
  return Artifacts(
310
- self.client,
311
- self.entity,
312
- self.project,
313
- self._saved_name,
314
- self._saved_type,
323
+ client=self.client,
324
+ entity=self.entity,
325
+ project=self.project,
326
+ collection_name=self._saved_name,
327
+ type=self._saved_type,
315
328
  per_page=per_page,
316
329
  )
317
330
 
318
331
  @property
319
- def aliases(self):
332
+ def aliases(self) -> List[str]:
320
333
  """Artifact Collection Aliases."""
321
334
  return self._aliases
322
335
 
323
336
  @property
324
- def created_at(self):
337
+ def created_at(self) -> str:
325
338
  return self._created_at
326
339
 
327
340
  def load(self):
328
- query = gql(
329
- """
330
- query ArtifactCollection(
331
- $entityName: String!,
332
- $projectName: String!,
333
- $artifactTypeName: String!,
334
- $artifactCollectionName: String!,
335
- $cursor: String,
336
- $perPage: Int = 1000
337
- ) {{
338
- project(name: $projectName, entityName: $entityName) {{
339
- artifactType(name: $artifactTypeName) {{
340
- artifactCollection: {}(name: $artifactCollectionName) {{
341
- id
342
- name
343
- description
344
- createdAt
345
- tags {{
346
- edges {{
347
- node {{
348
- id
349
- name
350
- }}
351
- }}
352
- }}
353
- aliases(after: $cursor, first: $perPage){{
354
- edges {{
355
- node {{
356
- alias
357
- }}
358
- cursor
359
- }}
360
- pageInfo {{
361
- endCursor
362
- hasNextPage
363
- }}
364
- }}
365
- }}
366
- artifactSequence(name: $artifactCollectionName) {{
367
- __typename
368
- }}
369
- }}
370
- }}
371
- }}
372
- """.format(
373
- artifact_collection_edge_name(
374
- server_supports_artifact_collections_gql_edges(self.client)
375
- )
376
- )
377
- )
341
+ if server_supports_artifact_collections_gql_edges(self.client):
342
+ rename_fields = None
343
+ else:
344
+ rename_fields = {"artifactCollection": "artifactSequence"}
345
+
378
346
  response = self.client.execute(
379
- query,
347
+ gql_compat(PROJECT_ARTIFACT_COLLECTION_GQL, rename_fields=rename_fields),
380
348
  variable_values={
381
349
  "entityName": self.entity,
382
350
  "projectName": self.project,
@@ -384,26 +352,30 @@ class ArtifactCollection:
384
352
  "artifactCollectionName": self._saved_name,
385
353
  },
386
354
  )
387
- if (
388
- response is None
389
- or response.get("project") is None
390
- or response["project"].get("artifactType") is None
391
- or response["project"]["artifactType"].get("artifactCollection") is None
355
+
356
+ result = ProjectArtifactCollection.model_validate(response)
357
+
358
+ if not (
359
+ result.project
360
+ and (proj := result.project)
361
+ and (type_ := proj.artifact_type)
362
+ and (collection := type_.artifact_collection)
392
363
  ):
393
- raise ValueError("Could not find artifact type {}".format(self._saved_type))
394
- sequence = response["project"]["artifactType"]["artifactSequence"]
364
+ raise ValueError(f"Could not find artifact type {self._saved_type}")
365
+
366
+ sequence = type_.artifact_sequence
395
367
  self._is_sequence = (
396
- sequence is not None and sequence["__typename"] == "ArtifactSequence"
397
- )
368
+ sequence is not None
369
+ ) and sequence.typename__ == SOURCE_ARTIFACT_COLLECTION_TYPE
398
370
 
399
371
  if self._attrs is None:
400
- self._attrs = response["project"]["artifactType"]["artifactCollection"]
372
+ self._attrs = collection.model_dump(exclude_unset=True)
401
373
  return self._attrs
402
374
 
403
375
  def change_type(self, new_type: str) -> None:
404
376
  """Deprecated, change type directly with `save` instead."""
405
377
  deprecate.deprecate(
406
- field_name=deprecate.Deprecated.artifact_collection__change_type,
378
+ field_name=Deprecated.artifact_collection__change_type,
407
379
  warning_message="ArtifactCollection.change_type(type) is deprecated, use ArtifactCollection.save() instead.",
408
380
  )
409
381
 
@@ -412,32 +384,13 @@ class ArtifactCollection:
412
384
  termlog(
413
385
  f"Changing artifact collection type of {self._saved_type} to {new_type}"
414
386
  )
415
- template = """
416
- mutation MoveArtifactCollection(
417
- $artifactSequenceID: ID!
418
- $destinationArtifactTypeName: String!
419
- ) {
420
- moveArtifactSequence(
421
- input: {
422
- artifactSequenceID: $artifactSequenceID
423
- destinationArtifactTypeName: $destinationArtifactTypeName
424
- }
425
- ) {
426
- artifactCollection {
427
- id
428
- name
429
- description
430
- __typename
431
- }
432
- }
433
- }
434
- """
435
- variable_values = {
436
- "artifactSequenceID": self.id,
437
- "destinationArtifactTypeName": new_type,
438
- }
439
- mutation = gql(template)
440
- self.client.execute(mutation, variable_values=variable_values)
387
+ self.client.execute(
388
+ gql(MOVE_ARTIFACT_COLLECTION_GQL),
389
+ variable_values={
390
+ "artifactSequenceID": self.id,
391
+ "destinationArtifactTypeName": new_type,
392
+ },
393
+ )
441
394
  self._saved_type = new_type
442
395
  self._type = new_type
443
396
 
@@ -446,40 +399,19 @@ class ArtifactCollection:
446
399
  return self._is_sequence
447
400
 
448
401
  @normalize_exceptions
449
- def delete(self):
402
+ def delete(self) -> None:
450
403
  """Delete the entire artifact collection."""
451
- if self.is_sequence():
452
- mutation = gql(
453
- """
454
- mutation deleteArtifactSequence($id: ID!) {
455
- deleteArtifactSequence(input: {
456
- artifactSequenceID: $id
457
- }) {
458
- artifactCollection {
459
- state
460
- }
461
- }
462
- }
463
- """
464
- )
465
- else:
466
- mutation = gql(
467
- """
468
- mutation deleteArtifactPortfolio($id: ID!) {
469
- deleteArtifactPortfolio(input: {
470
- artifactPortfolioID: $id
471
- }) {
472
- artifactCollection {
473
- state
474
- }
475
- }
476
- }
477
- """
478
- )
479
- self.client.execute(mutation, variable_values={"id": self.id})
404
+ self.client.execute(
405
+ gql(
406
+ DELETE_ARTIFACT_SEQUENCE_GQL
407
+ if self.is_sequence()
408
+ else DELETE_ARTIFACT_PORTFOLIO_GQL
409
+ ),
410
+ variable_values={"id": self.id},
411
+ )
480
412
 
481
413
  @property
482
- def description(self):
414
+ def description(self) -> str:
483
415
  """A description of the artifact collection."""
484
416
  return self._description
485
417
 
@@ -488,7 +420,7 @@ class ArtifactCollection:
488
420
  self._description = description
489
421
 
490
422
  @property
491
- def tags(self):
423
+ def tags(self) -> List[str]:
492
424
  """The tags associated with the artifact collection."""
493
425
  return self._tags
494
426
 
@@ -501,13 +433,13 @@ class ArtifactCollection:
501
433
  self._tags = tags
502
434
 
503
435
  @property
504
- def name(self):
436
+ def name(self) -> str:
505
437
  """The name of the artifact collection."""
506
438
  return self._name
507
439
 
508
440
  @name.setter
509
- def name(self, name: List[str]) -> None:
510
- self._name = name
441
+ def name(self, name: str) -> None:
442
+ self._name = validate_artifact_name(name)
511
443
 
512
444
  @property
513
445
  def type(self):
@@ -522,193 +454,69 @@ class ArtifactCollection:
522
454
  )
523
455
  self._type = type
524
456
 
525
- def _update_collection(self):
526
- mutation = gql("""
527
- mutation UpdateArtifactCollection(
528
- $artifactSequenceID: ID!
529
- $name: String
530
- $description: String
531
- ) {
532
- updateArtifactSequence(
533
- input: {
534
- artifactSequenceID: $artifactSequenceID
535
- name: $name
536
- description: $description
537
- }
538
- ) {
539
- artifactCollection {
540
- id
541
- name
542
- description
543
- }
544
- }
545
- }
546
- """)
547
-
548
- variable_values = {
549
- "artifactSequenceID": self.id,
550
- "name": self._name,
551
- "description": self.description,
552
- }
553
- self.client.execute(mutation, variable_values=variable_values)
457
+ def _update_collection(self) -> None:
458
+ self.client.execute(
459
+ gql(
460
+ UPDATE_ARTIFACT_SEQUENCE_GQL
461
+ if self.is_sequence()
462
+ else UPDATE_ARTIFACT_PORTFOLIO_GQL
463
+ ),
464
+ variable_values={
465
+ "id": self.id,
466
+ "name": self.name,
467
+ "description": self.description,
468
+ },
469
+ )
554
470
  self._saved_name = self._name
555
471
 
556
- def _update_collection_type(self):
557
- type_mutation = gql("""
558
- mutation MoveArtifactCollection(
559
- $artifactSequenceID: ID!
560
- $destinationArtifactTypeName: String!
561
- ) {
562
- moveArtifactSequence(
563
- input: {
564
- artifactSequenceID: $artifactSequenceID
565
- destinationArtifactTypeName: $destinationArtifactTypeName
566
- }
567
- ) {
568
- artifactCollection {
569
- id
570
- name
571
- description
572
- __typename
573
- }
574
- }
575
- }
576
- """)
577
-
578
- variable_values = {
579
- "artifactSequenceID": self.id,
580
- "destinationArtifactTypeName": self._type,
581
- }
582
- self.client.execute(type_mutation, variable_values=variable_values)
472
+ def _update_collection_type(self) -> None:
473
+ self.client.execute(
474
+ gql(MOVE_ARTIFACT_COLLECTION_GQL),
475
+ variable_values={
476
+ "artifactSequenceID": self.id,
477
+ "destinationArtifactTypeName": self.type,
478
+ },
479
+ )
583
480
  self._saved_type = self._type
584
481
 
585
- def _update_portfolio(self):
586
- mutation = gql("""
587
- mutation UpdateArtifactPortfolio(
588
- $artifactPortfolioID: ID!
589
- $name: String
590
- $description: String
591
- ) {
592
- updateArtifactPortfolio(
593
- input: {
594
- artifactPortfolioID: $artifactPortfolioID
595
- name: $name
596
- description: $description
597
- }
598
- ) {
599
- artifactCollection {
600
- id
601
- name
602
- description
603
- }
604
- }
605
- }
606
- """)
607
- variable_values = {
608
- "artifactPortfolioID": self.id,
609
- "name": self._name,
610
- "description": self.description,
611
- }
612
- self.client.execute(mutation, variable_values=variable_values)
613
- self._saved_name = self._name
614
-
615
- def _add_tags(self, tags_to_add):
616
- add_mutation = gql(
617
- """
618
- mutation CreateArtifactCollectionTagAssignments(
619
- $entityName: String!
620
- $projectName: String!
621
- $artifactCollectionName: String!
622
- $tags: [TagInput!]!
623
- ) {
624
- createArtifactCollectionTagAssignments(
625
- input: {
626
- entityName: $entityName
627
- projectName: $projectName
628
- artifactCollectionName: $artifactCollectionName
629
- tags: $tags
630
- }
631
- ) {
632
- tags {
633
- id
634
- name
635
- tagCategoryName
636
- }
637
- }
638
- }
639
- """
640
- )
482
+ def _add_tags(self, tags_to_add: Iterable[str]) -> None:
641
483
  self.client.execute(
642
- add_mutation,
484
+ gql(CREATE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL),
643
485
  variable_values={
644
486
  "entityName": self.entity,
645
487
  "projectName": self.project,
646
488
  "artifactCollectionName": self._saved_name,
647
- "tags": [
648
- {
649
- "tagName": tag,
650
- }
651
- for tag in tags_to_add
652
- ],
489
+ "tags": [{"tagName": tag} for tag in tags_to_add],
653
490
  },
654
491
  )
655
492
 
656
- def _delete_tags(self, tags_to_delete):
657
- delete_mutation = gql(
658
- """
659
- mutation DeleteArtifactCollectionTagAssignments(
660
- $entityName: String!
661
- $projectName: String!
662
- $artifactCollectionName: String!
663
- $tags: [TagInput!]!
664
- ) {
665
- deleteArtifactCollectionTagAssignments(
666
- input: {
667
- entityName: $entityName
668
- projectName: $projectName
669
- artifactCollectionName: $artifactCollectionName
670
- tags: $tags
671
- }
672
- ) {
673
- success
674
- }
675
- }
676
- """
677
- )
493
+ def _delete_tags(self, tags_to_delete: Iterable[str]) -> None:
678
494
  self.client.execute(
679
- delete_mutation,
495
+ gql(DELETE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL),
680
496
  variable_values={
681
497
  "entityName": self.entity,
682
498
  "projectName": self.project,
683
499
  "artifactCollectionName": self._saved_name,
684
- "tags": [
685
- {
686
- "tagName": tag,
687
- }
688
- for tag in tags_to_delete
689
- ],
500
+ "tags": [{"tagName": tag} for tag in tags_to_delete],
690
501
  },
691
502
  )
692
503
 
693
504
  def save(self) -> None:
694
505
  """Persist any changes made to the artifact collection."""
695
- if self.is_sequence():
696
- self._update_collection()
506
+ self._update_collection()
697
507
 
698
- if self._saved_type != self._type:
699
- self._update_collection_type()
700
- else:
701
- self._update_portfolio()
508
+ if self.is_sequence() and (self._saved_type != self._type):
509
+ self._update_collection_type()
702
510
 
703
- tags_to_add = set(self._tags) - set(self._saved_tags)
704
- tags_to_delete = set(self._saved_tags) - set(self._tags)
705
- if len(tags_to_add) > 0:
511
+ current_tags = set(self._tags)
512
+ saved_tags = set(self._saved_tags)
513
+ if tags_to_add := (current_tags - saved_tags):
706
514
  self._add_tags(tags_to_add)
707
- if len(tags_to_delete) > 0:
515
+ if tags_to_delete := (saved_tags - current_tags):
708
516
  self._delete_tags(tags_to_delete)
709
517
  self._saved_tags = copy(self._tags)
710
518
 
711
- def __repr__(self):
519
+ def __repr__(self) -> str:
712
520
  return f"<ArtifactCollection {self._name} ({self._type})>"
713
521
 
714
522
 
@@ -718,6 +526,8 @@ class Artifacts(SizedPaginator["wandb.Artifact"]):
718
526
  This is generally used indirectly via the `Api`.artifact_versions method.
719
527
  """
720
528
 
529
+ last_response: Optional[ArtifactsFragment]
530
+
721
531
  def __init__(
722
532
  self,
723
533
  client: Client,
@@ -730,8 +540,6 @@ class Artifacts(SizedPaginator["wandb.Artifact"]):
730
540
  per_page: int = 50,
731
541
  tags: Optional[Union[str, List[str]]] = None,
732
542
  ):
733
- from wandb.sdk.artifacts.artifact import _gql_artifact_fragment
734
-
735
543
  self.entity = entity
736
544
  self.collection_name = collection_name
737
545
  self.type = type
@@ -747,152 +555,107 @@ class Artifacts(SizedPaginator["wandb.Artifact"]):
747
555
  "collection": self.collection_name,
748
556
  "filters": json.dumps(self.filters),
749
557
  }
750
- self.QUERY = gql(
751
- """
752
- query Artifacts($project: String!, $entity: String!, $type: String!, $collection: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
753
- project(name: $project, entityName: $entity) {{
754
- artifactType(name: $type) {{
755
- artifactCollection: {}(name: $collection) {{
756
- name
757
- artifacts(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
758
- totalCount
759
- edges {{
760
- node {{
761
- ...ArtifactFragment
762
- }}
763
- version
764
- cursor
765
- }}
766
- pageInfo {{
767
- endCursor
768
- hasNextPage
769
- }}
770
- }}
771
- }}
772
- }}
773
- }}
774
- }}
775
- {}
776
- """.format(
777
- artifact_collection_edge_name(
778
- server_supports_artifact_collections_gql_edges(client)
779
- ),
780
- _gql_artifact_fragment(),
781
- )
558
+
559
+ if server_supports_artifact_collections_gql_edges(client):
560
+ rename_fields = None
561
+ else:
562
+ rename_fields = {"artifactCollection": "artifactSequence"}
563
+
564
+ self.QUERY = gql_compat(
565
+ PROJECT_ARTIFACTS_GQL,
566
+ omit_fields=omit_artifact_fields(api=InternalApi()),
567
+ rename_fields=rename_fields,
782
568
  )
569
+
783
570
  super().__init__(client, variables, per_page)
784
571
 
572
+ @override
573
+ def _update_response(self) -> None:
574
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
575
+ result = ProjectArtifacts.model_validate(data)
576
+
577
+ # Extract the inner `*Connection` result for faster/easier access.
578
+ if not (
579
+ (proj := result.project)
580
+ and (type_ := proj.artifact_type)
581
+ and (collection := type_.artifact_collection)
582
+ and (conn := collection.artifacts)
583
+ ):
584
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
585
+
586
+ self.last_response = ArtifactsFragment.model_validate(conn)
587
+
785
588
  @property
786
- def length(self):
787
- if self.last_response:
788
- return self.last_response["project"]["artifactType"]["artifactCollection"][
789
- "artifacts"
790
- ]["totalCount"]
791
- else:
589
+ def length(self) -> Optional[int]:
590
+ if self.last_response is None:
792
591
  return None
592
+ return self.last_response.total_count
793
593
 
794
594
  @property
795
- def more(self):
796
- if self.last_response:
797
- return self.last_response["project"]["artifactType"]["artifactCollection"][
798
- "artifacts"
799
- ]["pageInfo"]["hasNextPage"]
800
- else:
595
+ def more(self) -> bool:
596
+ if self.last_response is None:
801
597
  return True
598
+ return self.last_response.page_info.has_next_page
802
599
 
803
600
  @property
804
- def cursor(self):
805
- if self.last_response:
806
- return self.last_response["project"]["artifactType"]["artifactCollection"][
807
- "artifacts"
808
- ]["edges"][-1]["cursor"]
809
- else:
601
+ def cursor(self) -> Optional[str]:
602
+ if self.last_response is None:
810
603
  return None
604
+ return self.last_response.edges[-1].cursor
811
605
 
812
- def convert_objects(self):
813
- collection = self.last_response["project"]["artifactType"]["artifactCollection"]
814
- artifact_edges = collection.get("artifacts", {}).get("edges", [])
606
+ def convert_objects(self) -> List["wandb.Artifact"]:
607
+ artifact_edges = (edge for edge in self.last_response.edges if edge.node)
815
608
  artifacts = (
816
609
  wandb.Artifact._from_attrs(
817
- self.entity,
818
- self.project,
819
- self.collection_name + ":" + a["version"],
820
- a["node"],
821
- self.client,
610
+ entity=self.entity,
611
+ project=self.project,
612
+ name=f"{self.collection_name}:{edge.version}",
613
+ attrs=edge.node.model_dump(exclude_unset=True),
614
+ client=self.client,
822
615
  )
823
- for a in artifact_edges
616
+ for edge in artifact_edges
824
617
  )
825
618
  required_tags = set(self.tags or [])
826
- return [
827
- artifact for artifact in artifacts if required_tags.issubset(artifact.tags)
828
- ]
619
+ return [art for art in artifacts if required_tags.issubset(art.tags)]
829
620
 
830
621
 
831
622
  class RunArtifacts(SizedPaginator["wandb.Artifact"]):
832
- def __init__(self, client: Client, run: "Run", mode="logged", per_page: int = 50):
833
- from wandb.sdk.artifacts.artifact import _gql_artifact_fragment
834
-
835
- output_query = gql(
836
- """
837
- query RunOutputArtifacts(
838
- $entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
839
- ) {
840
- project(name: $project, entityName: $entity) {
841
- run(name: $runName) {
842
- outputArtifacts(after: $cursor, first: $perPage) {
843
- totalCount
844
- edges {
845
- node {
846
- ...ArtifactFragment
847
- }
848
- cursor
849
- }
850
- pageInfo {
851
- endCursor
852
- hasNextPage
853
- }
854
- }
855
- }
856
- }
857
- }
858
- """
859
- + _gql_artifact_fragment()
860
- )
861
-
862
- input_query = gql(
863
- """
864
- query RunInputArtifacts(
865
- $entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
866
- ) {
867
- project(name: $project, entityName: $entity) {
868
- run(name: $runName) {
869
- inputArtifacts(after: $cursor, first: $perPage) {
870
- totalCount
871
- edges {
872
- node {
873
- ...ArtifactFragment
874
- }
875
- cursor
876
- }
877
- pageInfo {
878
- endCursor
879
- hasNextPage
880
- }
881
- }
882
- }
883
- }
884
- }
885
- """
886
- + _gql_artifact_fragment()
887
- )
623
+ last_response: Union[
624
+ RunOutputArtifactsProjectRunOutputArtifacts,
625
+ RunInputArtifactsProjectRunInputArtifacts,
626
+ ]
627
+
628
+ #: The pydantic model used to parse the (inner part of the) raw response.
629
+ _response_cls: Type[
630
+ Union[
631
+ RunOutputArtifactsProjectRunOutputArtifacts,
632
+ RunInputArtifactsProjectRunInputArtifacts,
633
+ ]
634
+ ]
888
635
 
636
+ def __init__(
637
+ self,
638
+ client: Client,
639
+ run: "Run",
640
+ mode: Literal["logged", "used"] = "logged",
641
+ per_page: int = 50,
642
+ ):
889
643
  self.run = run
644
+
890
645
  if mode == "logged":
891
646
  self.run_key = "outputArtifacts"
892
- self.QUERY = output_query
647
+ self.QUERY = gql_compat(
648
+ RUN_OUTPUT_ARTIFACTS_GQL,
649
+ omit_fields=omit_artifact_fields(api=InternalApi()),
650
+ )
651
+ self._response_cls = RunOutputArtifactsProjectRunOutputArtifacts
893
652
  elif mode == "used":
894
653
  self.run_key = "inputArtifacts"
895
- self.QUERY = input_query
654
+ self.QUERY = gql_compat(
655
+ RUN_INPUT_ARTIFACTS_GQL,
656
+ omit_fields=omit_artifact_fields(api=InternalApi()),
657
+ )
658
+ self._response_cls = RunInputArtifactsProjectRunInputArtifacts
896
659
  else:
897
660
  raise ValueError("mode must be logged or used")
898
661
 
@@ -901,99 +664,50 @@ class RunArtifacts(SizedPaginator["wandb.Artifact"]):
901
664
  "project": run.project,
902
665
  "runName": run.id,
903
666
  }
904
-
905
667
  super().__init__(client, variable_values, per_page)
906
668
 
669
+ @override
670
+ def _update_response(self) -> None:
671
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
672
+
673
+ # Extract the inner `*Connection` result for faster/easier access.
674
+ inner_data = data["project"]["run"][self.run_key]
675
+ self.last_response = self._response_cls.model_validate(inner_data)
676
+
907
677
  @property
908
- def length(self):
909
- if self.last_response:
910
- return self.last_response["project"]["run"][self.run_key]["totalCount"]
911
- else:
678
+ def length(self) -> Optional[int]:
679
+ if self.last_response is None:
912
680
  return None
681
+ return self.last_response.total_count
913
682
 
914
683
  @property
915
- def more(self):
916
- if self.last_response:
917
- return self.last_response["project"]["run"][self.run_key]["pageInfo"][
918
- "hasNextPage"
919
- ]
920
- else:
684
+ def more(self) -> bool:
685
+ if self.last_response is None:
921
686
  return True
687
+ return self.last_response.page_info.has_next_page
922
688
 
923
689
  @property
924
- def cursor(self):
925
- if self.last_response:
926
- return self.last_response["project"]["run"][self.run_key]["edges"][-1][
927
- "cursor"
928
- ]
929
- else:
690
+ def cursor(self) -> Optional[str]:
691
+ if self.last_response is None:
930
692
  return None
693
+ return self.last_response.edges[-1].cursor
931
694
 
932
- def convert_objects(self):
695
+ def convert_objects(self) -> List["wandb.Artifact"]:
933
696
  return [
934
697
  wandb.Artifact._from_attrs(
935
- r["node"]["artifactSequence"]["project"]["entityName"],
936
- r["node"]["artifactSequence"]["project"]["name"],
937
- "{}:v{}".format(
938
- r["node"]["artifactSequence"]["name"], r["node"]["versionIndex"]
939
- ),
940
- r["node"],
941
- self.client,
698
+ entity=node.artifact_sequence.project.entity_name,
699
+ project=node.artifact_sequence.project.name,
700
+ name=f"{node.artifact_sequence.name}:v{node.version_index}",
701
+ attrs=node.model_dump(exclude_unset=True),
702
+ client=self.client,
942
703
  )
943
- for r in self.last_response["project"]["run"][self.run_key]["edges"]
704
+ for r in self.last_response.edges
705
+ if (node := r.node)
944
706
  ]
945
707
 
946
708
 
947
709
  class ArtifactFiles(SizedPaginator["public.File"]):
948
- ARTIFACT_VERSION_FILES_QUERY = gql(
949
- f"""
950
- query ArtifactFiles(
951
- $entityName: String!,
952
- $projectName: String!,
953
- $artifactTypeName: String!,
954
- $artifactName: String!
955
- $fileNames: [String!],
956
- $fileCursor: String,
957
- $fileLimit: Int = 50
958
- ) {{
959
- project(name: $projectName, entityName: $entityName) {{
960
- artifactType(name: $artifactTypeName) {{
961
- artifact(name: $artifactName) {{
962
- files(names: $fileNames, after: $fileCursor, first: $fileLimit) {{
963
- ...FilesFragment
964
- }}
965
- }}
966
- }}
967
- }}
968
- }}
969
- {ARTIFACT_FILES_FRAGMENT}
970
- """
971
- )
972
-
973
- ARTIFACT_COLLECTION_MEMBERSHIP_FILES_QUERY = gql(
974
- f"""
975
- query ArtifactCollectionMembershipFiles(
976
- $entityName: String!,
977
- $projectName: String!,
978
- $artifactName: String!,
979
- $artifactVersionIndex: String!,
980
- $fileNames: [String!],
981
- $fileCursor: String,
982
- $fileLimit: Int = 50
983
- ) {{
984
- project(name: $projectName, entityName: $entityName) {{
985
- artifactCollection(name: $artifactName) {{
986
- artifactMembership (aliasName: $artifactVersionIndex) {{
987
- files(names: $fileNames, after: $fileCursor, first: $fileLimit) {{
988
- ...FilesFragment
989
- }}
990
- }}
991
- }}
992
- }}
993
- }}
994
- {ARTIFACT_FILES_FRAGMENT}
995
- """
996
- )
710
+ last_response: Optional[FilesFragment]
997
711
 
998
712
  def __init__(
999
713
  self,
@@ -1006,15 +720,9 @@ class ArtifactFiles(SizedPaginator["public.File"]):
1006
720
  ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
1007
721
  )
1008
722
  self.artifact = artifact
1009
- variables = {
1010
- "entityName": artifact.source_entity,
1011
- "projectName": artifact.source_project,
1012
- "artifactTypeName": artifact.type,
1013
- "artifactName": artifact.source_name,
1014
- "fileNames": names,
1015
- }
723
+
1016
724
  if self.query_via_membership:
1017
- self.QUERY = self.ARTIFACT_COLLECTION_MEMBERSHIP_FILES_QUERY
725
+ query_str = ARTIFACT_COLLECTION_MEMBERSHIP_FILES_GQL
1018
726
  variables = {
1019
727
  "entityName": artifact.entity,
1020
728
  "projectName": artifact.project,
@@ -1023,67 +731,77 @@ class ArtifactFiles(SizedPaginator["public.File"]):
1023
731
  "fileNames": names,
1024
732
  }
1025
733
  else:
1026
- self.QUERY = self.ARTIFACT_VERSION_FILES_QUERY
734
+ query_str = ARTIFACT_VERSION_FILES_GQL
735
+ variables = {
736
+ "entityName": artifact.source_entity,
737
+ "projectName": artifact.source_project,
738
+ "artifactName": artifact.source_name,
739
+ "artifactTypeName": artifact.type,
740
+ "fileNames": names,
741
+ }
742
+
1027
743
  # The server must advertise at least SDK 0.12.21
1028
744
  # to get storagePath
1029
745
  if not client.version_supported("0.12.21"):
1030
- self.QUERY = gql(self.QUERY.loc.source.body.replace("storagePath\n", ""))
746
+ self.QUERY = gql_compat(query_str, omit_fields={"storagePath"})
747
+ else:
748
+ self.QUERY = gql(query_str)
749
+
1031
750
  super().__init__(client, variables, per_page)
1032
751
 
752
+ @override
753
+ def _update_response(self) -> None:
754
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
755
+
756
+ # Extract the inner `*Connection` result for faster/easier access.
757
+ if self.query_via_membership:
758
+ result = ArtifactCollectionMembershipFiles.model_validate(data)
759
+ conn = result.project.artifact_collection.artifact_membership.files
760
+ else:
761
+ result = ArtifactVersionFiles.model_validate(data)
762
+ conn = result.project.artifact_type.artifact.files
763
+
764
+ if conn is None:
765
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
766
+
767
+ self.last_response = FilesFragment.model_validate(conn)
768
+
1033
769
  @property
1034
- def path(self):
770
+ def path(self) -> List[str]:
1035
771
  return [self.artifact.entity, self.artifact.project, self.artifact.name]
1036
772
 
1037
773
  @property
1038
- def length(self):
774
+ def length(self) -> int:
1039
775
  return self.artifact.file_count
1040
776
 
1041
777
  @property
1042
- def more(self):
1043
- if self.last_response:
1044
- if self.query_via_membership:
1045
- return self.last_response["project"]["artifactCollection"][
1046
- "artifactMembership"
1047
- ]["files"]["pageInfo"]["hasNextPage"]
1048
- return self.last_response["project"]["artifactType"]["artifact"]["files"][
1049
- "pageInfo"
1050
- ]["hasNextPage"]
1051
- else:
778
+ def more(self) -> bool:
779
+ if self.last_response is None:
1052
780
  return True
781
+ return self.last_response.page_info.has_next_page
1053
782
 
1054
783
  @property
1055
- def cursor(self):
1056
- if self.last_response:
1057
- if self.query_via_membership:
1058
- return self.last_response["project"]["artifactCollection"][
1059
- "artifactMembership"
1060
- ]["files"]["edges"][-1]["cursor"]
1061
- return self.last_response["project"]["artifactType"]["artifact"]["files"][
1062
- "edges"
1063
- ][-1]["cursor"]
1064
- else:
784
+ def cursor(self) -> Optional[str]:
785
+ if self.last_response is None:
1065
786
  return None
787
+ return self.last_response.edges[-1].cursor
1066
788
 
1067
- def update_variables(self):
789
+ def update_variables(self) -> None:
1068
790
  self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor})
1069
791
 
1070
- def convert_objects(self):
1071
- if self.query_via_membership:
1072
- return [
1073
- public.File(self.client, r["node"])
1074
- for r in self.last_response["project"]["artifactCollection"][
1075
- "artifactMembership"
1076
- ]["files"]["edges"]
1077
- ]
792
+ def convert_objects(self) -> List["public.File"]:
1078
793
  return [
1079
- public.File(self.client, r["node"])
1080
- for r in self.last_response["project"]["artifactType"]["artifact"]["files"][
1081
- "edges"
1082
- ]
794
+ public.File(
795
+ client=self.client,
796
+ attrs=node.model_dump(exclude_unset=True),
797
+ )
798
+ for edge in self.last_response.edges
799
+ if (node := edge.node)
1083
800
  ]
1084
801
 
1085
- def __repr__(self):
1086
- return "<ArtifactFiles {} ({})>".format("/".join(self.path), len(self))
802
+ def __repr__(self) -> str:
803
+ path_str = "/".join(self.path)
804
+ return f"<ArtifactFiles {path_str} ({len(self)})>"
1087
805
 
1088
806
 
1089
807
  def server_supports_artifact_collections_gql_edges(
@@ -1099,21 +817,3 @@ def server_supports_artifact_collections_gql_edges(
1099
817
  "W&B Local Server version does not support ArtifactCollection gql edges; falling back to using legacy ArtifactSequence. Please update server to at least version 0.9.50."
1100
818
  )
1101
819
  return supported
1102
-
1103
-
1104
- def artifact_collection_edge_name(server_supports_artifact_collections: bool) -> str:
1105
- return (
1106
- "artifactCollection"
1107
- if server_supports_artifact_collections
1108
- else "artifactSequence"
1109
- )
1110
-
1111
-
1112
- def artifact_collection_plural_edge_name(
1113
- server_supports_artifact_collections: bool,
1114
- ) -> str:
1115
- return (
1116
- "artifactCollections"
1117
- if server_supports_artifact_collections
1118
- else "artifactSequences"
1119
- )