wandb 0.19.8__py3-none-any.whl → 0.19.10__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. wandb/__init__.py +5 -1
  2. wandb/__init__.pyi +15 -8
  3. wandb/_pydantic/__init__.py +30 -0
  4. wandb/_pydantic/base.py +148 -0
  5. wandb/_pydantic/utils.py +66 -0
  6. wandb/_pydantic/v1_compat.py +284 -0
  7. wandb/apis/paginator.py +82 -38
  8. wandb/apis/public/__init__.py +2 -2
  9. wandb/apis/public/api.py +111 -53
  10. wandb/apis/public/artifacts.py +387 -639
  11. wandb/apis/public/automations.py +69 -0
  12. wandb/apis/public/files.py +2 -2
  13. wandb/apis/public/integrations.py +168 -0
  14. wandb/apis/public/projects.py +32 -2
  15. wandb/apis/public/reports.py +2 -2
  16. wandb/apis/public/runs.py +19 -11
  17. wandb/apis/public/utils.py +107 -1
  18. wandb/automations/__init__.py +81 -0
  19. wandb/automations/_filters/__init__.py +40 -0
  20. wandb/automations/_filters/expressions.py +179 -0
  21. wandb/automations/_filters/operators.py +267 -0
  22. wandb/automations/_filters/run_metrics.py +183 -0
  23. wandb/automations/_generated/__init__.py +184 -0
  24. wandb/automations/_generated/create_filter_trigger.py +21 -0
  25. wandb/automations/_generated/create_generic_webhook_integration.py +43 -0
  26. wandb/automations/_generated/delete_trigger.py +19 -0
  27. wandb/automations/_generated/enums.py +33 -0
  28. wandb/automations/_generated/fragments.py +343 -0
  29. wandb/automations/_generated/generic_webhook_integrations_by_entity.py +22 -0
  30. wandb/automations/_generated/get_triggers.py +24 -0
  31. wandb/automations/_generated/get_triggers_by_entity.py +24 -0
  32. wandb/automations/_generated/input_types.py +104 -0
  33. wandb/automations/_generated/integrations_by_entity.py +22 -0
  34. wandb/automations/_generated/operations.py +710 -0
  35. wandb/automations/_generated/slack_integrations_by_entity.py +22 -0
  36. wandb/automations/_generated/update_filter_trigger.py +21 -0
  37. wandb/automations/_utils.py +123 -0
  38. wandb/automations/_validators.py +73 -0
  39. wandb/automations/actions.py +205 -0
  40. wandb/automations/automations.py +109 -0
  41. wandb/automations/events.py +235 -0
  42. wandb/automations/integrations.py +26 -0
  43. wandb/automations/scopes.py +76 -0
  44. wandb/beta/workflows.py +9 -10
  45. wandb/bin/gpu_stats +0 -0
  46. wandb/cli/cli.py +3 -3
  47. wandb/integration/keras/keras.py +2 -1
  48. wandb/integration/langchain/wandb_tracer.py +2 -1
  49. wandb/integration/metaflow/metaflow.py +19 -17
  50. wandb/integration/sacred/__init__.py +1 -1
  51. wandb/jupyter.py +155 -133
  52. wandb/old/summary.py +0 -2
  53. wandb/proto/v3/wandb_internal_pb2.py +297 -292
  54. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  55. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  56. wandb/proto/v4/wandb_internal_pb2.py +292 -292
  57. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  58. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  59. wandb/proto/v5/wandb_internal_pb2.py +292 -292
  60. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  61. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  62. wandb/proto/v6/wandb_base_pb2.py +41 -0
  63. wandb/proto/v6/wandb_internal_pb2.py +393 -0
  64. wandb/proto/v6/wandb_server_pb2.py +78 -0
  65. wandb/proto/v6/wandb_settings_pb2.py +58 -0
  66. wandb/proto/v6/wandb_telemetry_pb2.py +52 -0
  67. wandb/proto/wandb_base_pb2.py +2 -0
  68. wandb/proto/wandb_deprecated.py +10 -0
  69. wandb/proto/wandb_internal_pb2.py +3 -1
  70. wandb/proto/wandb_server_pb2.py +2 -0
  71. wandb/proto/wandb_settings_pb2.py +2 -0
  72. wandb/proto/wandb_telemetry_pb2.py +2 -0
  73. wandb/sdk/artifacts/_generated/__init__.py +248 -0
  74. wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +43 -0
  75. wandb/sdk/artifacts/_generated/artifact_version_files.py +36 -0
  76. wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +36 -0
  77. wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +25 -0
  78. wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +35 -0
  79. wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +35 -0
  80. wandb/sdk/artifacts/_generated/enums.py +17 -0
  81. wandb/sdk/artifacts/_generated/fragments.py +186 -0
  82. wandb/sdk/artifacts/_generated/input_types.py +16 -0
  83. wandb/sdk/artifacts/_generated/move_artifact_collection.py +35 -0
  84. wandb/sdk/artifacts/_generated/operations.py +510 -0
  85. wandb/sdk/artifacts/_generated/project_artifact_collection.py +101 -0
  86. wandb/sdk/artifacts/_generated/project_artifact_collections.py +33 -0
  87. wandb/sdk/artifacts/_generated/project_artifact_type.py +24 -0
  88. wandb/sdk/artifacts/_generated/project_artifact_types.py +24 -0
  89. wandb/sdk/artifacts/_generated/project_artifacts.py +42 -0
  90. wandb/sdk/artifacts/_generated/run_input_artifacts.py +51 -0
  91. wandb/sdk/artifacts/_generated/run_output_artifacts.py +51 -0
  92. wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +35 -0
  93. wandb/sdk/artifacts/_generated/update_artifact_sequence.py +35 -0
  94. wandb/sdk/artifacts/_graphql_fragments.py +56 -81
  95. wandb/sdk/artifacts/_validators.py +1 -0
  96. wandb/sdk/artifacts/artifact.py +110 -49
  97. wandb/sdk/artifacts/artifact_manifest_entry.py +2 -1
  98. wandb/sdk/artifacts/artifact_saver.py +16 -2
  99. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  100. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +23 -2
  101. wandb/sdk/data_types/audio.py +1 -3
  102. wandb/sdk/data_types/base_types/media.py +13 -7
  103. wandb/sdk/data_types/base_types/wb_value.py +34 -11
  104. wandb/sdk/data_types/html.py +36 -9
  105. wandb/sdk/data_types/image.py +56 -37
  106. wandb/sdk/data_types/molecule.py +1 -5
  107. wandb/sdk/data_types/object_3d.py +2 -1
  108. wandb/sdk/data_types/saved_model.py +7 -9
  109. wandb/sdk/data_types/table.py +5 -0
  110. wandb/sdk/data_types/trace_tree.py +2 -0
  111. wandb/sdk/data_types/utils.py +1 -1
  112. wandb/sdk/data_types/video.py +15 -30
  113. wandb/sdk/interface/interface.py +2 -0
  114. wandb/{apis/public → sdk/internal}/_generated/__init__.py +0 -6
  115. wandb/{apis/public → sdk/internal}/_generated/server_features_query.py +3 -3
  116. wandb/sdk/internal/internal_api.py +138 -47
  117. wandb/sdk/internal/profiler.py +6 -5
  118. wandb/sdk/internal/run.py +13 -6
  119. wandb/sdk/internal/sender.py +2 -0
  120. wandb/sdk/internal/sender_config.py +8 -11
  121. wandb/sdk/internal/settings_static.py +24 -2
  122. wandb/sdk/lib/apikey.py +40 -20
  123. wandb/sdk/lib/asyncio_compat.py +1 -1
  124. wandb/sdk/lib/deprecate.py +13 -22
  125. wandb/sdk/lib/disabled.py +2 -1
  126. wandb/sdk/lib/printer.py +37 -8
  127. wandb/sdk/lib/printer_asyncio.py +46 -0
  128. wandb/sdk/lib/redirect.py +10 -5
  129. wandb/sdk/lib/run_moment.py +4 -6
  130. wandb/sdk/lib/wb_logging.py +161 -0
  131. wandb/sdk/service/server_sock.py +19 -14
  132. wandb/sdk/service/service.py +9 -7
  133. wandb/sdk/service/streams.py +5 -0
  134. wandb/sdk/verify/verify.py +6 -3
  135. wandb/sdk/wandb_config.py +44 -43
  136. wandb/sdk/wandb_init.py +323 -141
  137. wandb/sdk/wandb_login.py +13 -4
  138. wandb/sdk/wandb_metadata.py +107 -91
  139. wandb/sdk/wandb_run.py +529 -325
  140. wandb/sdk/wandb_settings.py +422 -202
  141. wandb/sdk/wandb_setup.py +52 -1
  142. wandb/util.py +29 -29
  143. {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/METADATA +7 -7
  144. {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/RECORD +150 -93
  145. wandb/_globals.py +0 -19
  146. wandb/apis/public/_generated/base.py +0 -128
  147. wandb/apis/public/_generated/typing_compat.py +0 -14
  148. /wandb/{apis/public → sdk/internal}/_generated/enums.py +0 -0
  149. /wandb/{apis/public → sdk/internal}/_generated/input_types.py +0 -0
  150. /wandb/{apis/public → sdk/internal}/_generated/operations.py +0 -0
  151. {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/WHEEL +0 -0
  152. {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/entry_points.txt +0 -0
  153. {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/licenses/LICENSE +0 -0
@@ -3,49 +3,82 @@
3
3
  import json
4
4
  import re
5
5
  from copy import copy
6
- from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Union
6
+ from typing import (
7
+ TYPE_CHECKING,
8
+ Any,
9
+ Iterable,
10
+ List,
11
+ Literal,
12
+ Mapping,
13
+ Optional,
14
+ Sequence,
15
+ Type,
16
+ Union,
17
+ )
7
18
 
19
+ from typing_extensions import override
8
20
  from wandb_gql import Client, gql
9
21
 
10
22
  import wandb
11
23
  from wandb.apis import public
12
24
  from wandb.apis.normalize import normalize_exceptions
13
- from wandb.apis.paginator import Paginator
25
+ from wandb.apis.paginator import Paginator, SizedPaginator
14
26
  from wandb.errors.term import termlog
15
- from wandb.sdk.artifacts._graphql_fragments import (
16
- ARTIFACT_FILES_FRAGMENT,
17
- ARTIFACTS_TYPES_FRAGMENT,
27
+ from wandb.proto.wandb_deprecated import Deprecated
28
+ from wandb.proto.wandb_internal_pb2 import ServerFeature
29
+ from wandb.sdk.artifacts._generated import (
30
+ ARTIFACT_COLLECTION_MEMBERSHIP_FILES_GQL,
31
+ ARTIFACT_VERSION_FILES_GQL,
32
+ CREATE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL,
33
+ DELETE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL,
34
+ DELETE_ARTIFACT_PORTFOLIO_GQL,
35
+ DELETE_ARTIFACT_SEQUENCE_GQL,
36
+ MOVE_ARTIFACT_COLLECTION_GQL,
37
+ PROJECT_ARTIFACT_COLLECTION_GQL,
38
+ PROJECT_ARTIFACT_COLLECTIONS_GQL,
39
+ PROJECT_ARTIFACT_TYPE_GQL,
40
+ PROJECT_ARTIFACT_TYPES_GQL,
41
+ PROJECT_ARTIFACTS_GQL,
42
+ RUN_INPUT_ARTIFACTS_GQL,
43
+ RUN_OUTPUT_ARTIFACTS_GQL,
44
+ UPDATE_ARTIFACT_PORTFOLIO_GQL,
45
+ UPDATE_ARTIFACT_SEQUENCE_GQL,
46
+ ArtifactCollectionMembershipFiles,
47
+ ArtifactCollectionsFragment,
48
+ ArtifactsFragment,
49
+ ArtifactTypeFragment,
50
+ ArtifactTypesFragment,
51
+ ArtifactVersionFiles,
52
+ FilesFragment,
53
+ ProjectArtifactCollection,
54
+ ProjectArtifactCollections,
55
+ ProjectArtifacts,
56
+ ProjectArtifactType,
57
+ ProjectArtifactTypes,
58
+ RunInputArtifactsProjectRunInputArtifacts,
59
+ RunOutputArtifactsProjectRunOutputArtifacts,
18
60
  )
61
+ from wandb.sdk.artifacts._graphql_fragments import omit_artifact_fields
62
+ from wandb.sdk.internal.internal_api import Api as InternalApi
19
63
  from wandb.sdk.lib import deprecate
20
64
 
65
+ from .utils import gql_compat
66
+
21
67
  if TYPE_CHECKING:
22
68
  from wandb.apis.public import RetryingClient, Run
23
69
 
24
70
 
25
- class ArtifactTypes(Paginator):
26
- QUERY = gql(
27
- """
28
- query ProjectArtifacts(
29
- $entityName: String!,
30
- $projectName: String!,
31
- $cursor: String,
32
- ) {{
33
- project(name: $projectName, entityName: $entityName) {{
34
- artifactTypes(after: $cursor) {{
35
- ...ArtifactTypesFragment
36
- }}
37
- }}
38
- }}
39
- {}
40
- """.format(ARTIFACTS_TYPES_FRAGMENT)
41
- )
71
+ class ArtifactTypes(Paginator["ArtifactType"]):
72
+ QUERY = gql(PROJECT_ARTIFACT_TYPES_GQL)
73
+
74
+ last_response: Optional[ArtifactTypesFragment]
42
75
 
43
76
  def __init__(
44
77
  self,
45
78
  client: Client,
46
79
  entity: str,
47
80
  project: str,
48
- per_page: Optional[int] = 50,
81
+ per_page: int = 50,
49
82
  ):
50
83
  self.entity = entity
51
84
  self.project = project
@@ -54,41 +87,54 @@ class ArtifactTypes(Paginator):
54
87
  "entityName": entity,
55
88
  "projectName": project,
56
89
  }
57
-
58
90
  super().__init__(client, variable_values, per_page)
59
91
 
92
+ @override
93
+ def _update_response(self) -> None:
94
+ """Fetch and validate the response data for the current page."""
95
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
96
+ result = ProjectArtifactTypes.model_validate(data)
97
+
98
+ # Extract the inner `*Connection` result for faster/easier access.
99
+ if not ((proj := result.project) and (conn := proj.artifact_types)):
100
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
101
+
102
+ self.last_response = ArtifactTypesFragment.model_validate(conn)
103
+
60
104
  @property
61
- def length(self):
105
+ def length(self) -> None:
62
106
  # TODO
63
107
  return None
64
108
 
65
109
  @property
66
- def more(self):
67
- if self.last_response:
68
- return self.last_response["project"]["artifactTypes"]["pageInfo"][
69
- "hasNextPage"
70
- ]
71
- else:
110
+ def more(self) -> bool:
111
+ if self.last_response is None:
72
112
  return True
113
+ return self.last_response.page_info.has_next_page
73
114
 
74
115
  @property
75
- def cursor(self):
76
- if self.last_response:
77
- return self.last_response["project"]["artifactTypes"]["edges"][-1]["cursor"]
78
- else:
116
+ def cursor(self) -> Optional[str]:
117
+ if self.last_response is None:
79
118
  return None
119
+ return self.last_response.edges[-1].cursor
80
120
 
81
- def update_variables(self):
121
+ def update_variables(self) -> None:
82
122
  self.variables.update({"cursor": self.cursor})
83
123
 
84
- def convert_objects(self):
85
- if self.last_response["project"] is None:
124
+ def convert_objects(self) -> List["ArtifactType"]:
125
+ if self.last_response is None:
86
126
  return []
127
+
87
128
  return [
88
129
  ArtifactType(
89
- self.client, self.entity, self.project, r["node"]["name"], r["node"]
130
+ client=self.client,
131
+ entity=self.entity,
132
+ project=self.project,
133
+ type_name=node.name,
134
+ attrs=node.model_dump(exclude_unset=True),
90
135
  )
91
- for r in self.last_response["project"]["artifactTypes"]["edges"]
136
+ for edge in self.last_response.edges
137
+ if edge.node and (node := ArtifactTypeFragment.model_validate(edge.node))
92
138
  ]
93
139
 
94
140
 
@@ -110,39 +156,19 @@ class ArtifactType:
110
156
  self.load()
111
157
 
112
158
  def load(self):
113
- query = gql(
114
- """
115
- query ProjectArtifactType(
116
- $entityName: String!,
117
- $projectName: String!,
118
- $artifactTypeName: String!
119
- ) {
120
- project(name: $projectName, entityName: $entityName) {
121
- artifactType(name: $artifactTypeName) {
122
- id
123
- name
124
- description
125
- createdAt
126
- }
127
- }
128
- }
129
- """
130
- )
131
- response: Optional[Mapping[str, Any]] = self.client.execute(
132
- query,
159
+ data: Optional[Mapping[str, Any]] = self.client.execute(
160
+ gql(PROJECT_ARTIFACT_TYPE_GQL),
133
161
  variable_values={
134
162
  "entityName": self.entity,
135
163
  "projectName": self.project,
136
164
  "artifactTypeName": self.type,
137
165
  },
138
166
  )
139
- if (
140
- response is None
141
- or response.get("project") is None
142
- or response["project"].get("artifactType") is None
143
- ):
144
- raise ValueError("Could not find artifact type {}".format(self.type))
145
- self._attrs = response["project"]["artifactType"]
167
+ result = ProjectArtifactType.model_validate(data)
168
+ if not ((proj := result.project) and (artifact_type := proj.artifact_type)):
169
+ raise ValueError(f"Could not find artifact type {self.type}")
170
+
171
+ self._attrs = artifact_type.model_dump(exclude_unset=True)
146
172
  return self._attrs
147
173
 
148
174
  @property
@@ -167,14 +193,16 @@ class ArtifactType:
167
193
  return f"<ArtifactType {self.type}>"
168
194
 
169
195
 
170
- class ArtifactCollections(Paginator):
196
+ class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
197
+ last_response: Optional[ArtifactCollectionsFragment]
198
+
171
199
  def __init__(
172
200
  self,
173
201
  client: Client,
174
202
  entity: str,
175
203
  project: str,
176
204
  type_name: str,
177
- per_page: Optional[int] = 50,
205
+ per_page: int = 50,
178
206
  ):
179
207
  self.entity = entity
180
208
  self.project = project
@@ -186,86 +214,65 @@ class ArtifactCollections(Paginator):
186
214
  "artifactTypeName": type_name,
187
215
  }
188
216
 
189
- self.QUERY = gql(
190
- """
191
- query ProjectArtifactCollections(
192
- $entityName: String!,
193
- $projectName: String!,
194
- $artifactTypeName: String!
195
- $cursor: String,
196
- ) {{
197
- project(name: $projectName, entityName: $entityName) {{
198
- artifactType(name: $artifactTypeName) {{
199
- artifactCollections: {}(after: $cursor) {{
200
- pageInfo {{
201
- endCursor
202
- hasNextPage
203
- }}
204
- totalCount
205
- edges {{
206
- node {{
207
- id
208
- name
209
- description
210
- createdAt
211
- }}
212
- cursor
213
- }}
214
- }}
215
- }}
216
- }}
217
- }}
218
- """.format(
219
- artifact_collection_plural_edge_name(
220
- server_supports_artifact_collections_gql_edges(client)
221
- )
222
- )
217
+ if server_supports_artifact_collections_gql_edges(client):
218
+ rename_fields = None
219
+ else:
220
+ rename_fields = {"artifactCollections": "artifactSequences"}
221
+
222
+ self.QUERY = gql_compat(
223
+ PROJECT_ARTIFACT_COLLECTIONS_GQL, rename_fields=rename_fields
223
224
  )
224
225
 
225
226
  super().__init__(client, variable_values, per_page)
226
227
 
228
+ @override
229
+ def _update_response(self) -> None:
230
+ """Fetch and validate the response data for the current page."""
231
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
232
+ result = ProjectArtifactCollections.model_validate(data)
233
+
234
+ # Extract the inner `*Connection` result for faster/easier access.
235
+ if not (
236
+ (proj := result.project)
237
+ and (type_ := proj.artifact_type)
238
+ and (conn := type_.artifact_collections)
239
+ ):
240
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
241
+
242
+ self.last_response = ArtifactCollectionsFragment.model_validate(conn)
243
+
227
244
  @property
228
245
  def length(self):
229
- if self.last_response:
230
- return self.last_response["project"]["artifactType"]["artifactCollections"][
231
- "totalCount"
232
- ]
233
- else:
246
+ if self.last_response is None:
234
247
  return None
248
+ return self.last_response.total_count
235
249
 
236
250
  @property
237
251
  def more(self):
238
- if self.last_response:
239
- return self.last_response["project"]["artifactType"]["artifactCollections"][
240
- "pageInfo"
241
- ]["hasNextPage"]
242
- else:
252
+ if self.last_response is None:
243
253
  return True
254
+ return self.last_response.page_info.has_next_page
244
255
 
245
256
  @property
246
257
  def cursor(self):
247
- if self.last_response:
248
- return self.last_response["project"]["artifactType"]["artifactCollections"][
249
- "edges"
250
- ][-1]["cursor"]
251
- else:
258
+ if self.last_response is None:
252
259
  return None
260
+ return self.last_response.edges[-1].cursor
253
261
 
254
- def update_variables(self):
262
+ def update_variables(self) -> None:
255
263
  self.variables.update({"cursor": self.cursor})
256
264
 
257
- def convert_objects(self):
265
+ def convert_objects(self) -> List["ArtifactCollection"]:
258
266
  return [
259
267
  ArtifactCollection(
260
- self.client,
261
- self.entity,
262
- self.project,
263
- r["node"]["name"],
264
- self.type_name,
268
+ client=self.client,
269
+ entity=self.entity,
270
+ project=self.project,
271
+ name=node.name,
272
+ type=self.type_name,
265
273
  )
266
- for r in self.last_response["project"]["artifactType"][
267
- "artifactCollections"
268
- ]["edges"]
274
+ for edge in self.last_response.edges
275
+ if (node := edge.node)
269
276
  ]
270
277
 
271
278
 
@@ -298,83 +305,38 @@ class ArtifactCollection:
298
305
  self.organization = organization
299
306
 
300
307
  @property
301
- def id(self):
308
+ def id(self) -> str:
302
309
  return self._attrs["id"]
303
310
 
304
311
  @normalize_exceptions
305
- def artifacts(self, per_page=50):
312
+ def artifacts(self, per_page: int = 50) -> "Artifacts":
306
313
  """Artifacts."""
307
314
  return Artifacts(
308
- self.client,
309
- self.entity,
310
- self.project,
311
- self._saved_name,
312
- self._saved_type,
315
+ client=self.client,
316
+ entity=self.entity,
317
+ project=self.project,
318
+ collection_name=self._saved_name,
319
+ type=self._saved_type,
313
320
  per_page=per_page,
314
321
  )
315
322
 
316
323
  @property
317
- def aliases(self):
324
+ def aliases(self) -> List[str]:
318
325
  """Artifact Collection Aliases."""
319
326
  return self._aliases
320
327
 
321
328
  @property
322
- def created_at(self):
329
+ def created_at(self) -> str:
323
330
  return self._created_at
324
331
 
325
332
  def load(self):
326
- query = gql(
327
- """
328
- query ArtifactCollection(
329
- $entityName: String!,
330
- $projectName: String!,
331
- $artifactTypeName: String!,
332
- $artifactCollectionName: String!,
333
- $cursor: String,
334
- $perPage: Int = 1000
335
- ) {{
336
- project(name: $projectName, entityName: $entityName) {{
337
- artifactType(name: $artifactTypeName) {{
338
- artifactCollection: {}(name: $artifactCollectionName) {{
339
- id
340
- name
341
- description
342
- createdAt
343
- tags {{
344
- edges {{
345
- node {{
346
- id
347
- name
348
- }}
349
- }}
350
- }}
351
- aliases(after: $cursor, first: $perPage){{
352
- edges {{
353
- node {{
354
- alias
355
- }}
356
- cursor
357
- }}
358
- pageInfo {{
359
- endCursor
360
- hasNextPage
361
- }}
362
- }}
363
- }}
364
- artifactSequence(name: $artifactCollectionName) {{
365
- __typename
366
- }}
367
- }}
368
- }}
369
- }}
370
- """.format(
371
- artifact_collection_edge_name(
372
- server_supports_artifact_collections_gql_edges(self.client)
373
- )
374
- )
375
- )
333
+ if server_supports_artifact_collections_gql_edges(self.client):
334
+ rename_fields = None
335
+ else:
336
+ rename_fields = {"artifactCollection": "artifactSequence"}
337
+
376
338
  response = self.client.execute(
377
- query,
339
+ gql_compat(PROJECT_ARTIFACT_COLLECTION_GQL, rename_fields=rename_fields),
378
340
  variable_values={
379
341
  "entityName": self.entity,
380
342
  "projectName": self.project,
@@ -382,26 +344,30 @@ class ArtifactCollection:
382
344
  "artifactCollectionName": self._saved_name,
383
345
  },
384
346
  )
385
- if (
386
- response is None
387
- or response.get("project") is None
388
- or response["project"].get("artifactType") is None
389
- or response["project"]["artifactType"].get("artifactCollection") is None
347
+
348
+ result = ProjectArtifactCollection.model_validate(response)
349
+
350
+ if not (
351
+ result.project
352
+ and (proj := result.project)
353
+ and (type_ := proj.artifact_type)
354
+ and (collection := type_.artifact_collection)
390
355
  ):
391
- raise ValueError("Could not find artifact type {}".format(self._saved_type))
392
- sequence = response["project"]["artifactType"]["artifactSequence"]
356
+ raise ValueError(f"Could not find artifact type {self._saved_type}")
357
+
358
+ sequence = type_.artifact_sequence
393
359
  self._is_sequence = (
394
- sequence is not None and sequence["__typename"] == "ArtifactSequence"
395
- )
360
+ sequence is not None
361
+ ) and sequence.typename__ == "ArtifactSequence"
396
362
 
397
363
  if self._attrs is None:
398
- self._attrs = response["project"]["artifactType"]["artifactCollection"]
364
+ self._attrs = collection.model_dump(exclude_unset=True)
399
365
  return self._attrs
400
366
 
401
367
  def change_type(self, new_type: str) -> None:
402
368
  """Deprecated, change type directly with `save` instead."""
403
369
  deprecate.deprecate(
404
- field_name=deprecate.Deprecated.artifact_collection__change_type,
370
+ field_name=Deprecated.artifact_collection__change_type,
405
371
  warning_message="ArtifactCollection.change_type(type) is deprecated, use ArtifactCollection.save() instead.",
406
372
  )
407
373
 
@@ -410,32 +376,13 @@ class ArtifactCollection:
410
376
  termlog(
411
377
  f"Changing artifact collection type of {self._saved_type} to {new_type}"
412
378
  )
413
- template = """
414
- mutation MoveArtifactCollection(
415
- $artifactSequenceID: ID!
416
- $destinationArtifactTypeName: String!
417
- ) {
418
- moveArtifactSequence(
419
- input: {
420
- artifactSequenceID: $artifactSequenceID
421
- destinationArtifactTypeName: $destinationArtifactTypeName
422
- }
423
- ) {
424
- artifactCollection {
425
- id
426
- name
427
- description
428
- __typename
429
- }
430
- }
431
- }
432
- """
433
- variable_values = {
434
- "artifactSequenceID": self.id,
435
- "destinationArtifactTypeName": new_type,
436
- }
437
- mutation = gql(template)
438
- self.client.execute(mutation, variable_values=variable_values)
379
+ self.client.execute(
380
+ gql(MOVE_ARTIFACT_COLLECTION_GQL),
381
+ variable_values={
382
+ "artifactSequenceID": self.id,
383
+ "destinationArtifactTypeName": new_type,
384
+ },
385
+ )
439
386
  self._saved_type = new_type
440
387
  self._type = new_type
441
388
 
@@ -444,40 +391,19 @@ class ArtifactCollection:
444
391
  return self._is_sequence
445
392
 
446
393
  @normalize_exceptions
447
- def delete(self):
394
+ def delete(self) -> None:
448
395
  """Delete the entire artifact collection."""
449
- if self.is_sequence():
450
- mutation = gql(
451
- """
452
- mutation deleteArtifactSequence($id: ID!) {
453
- deleteArtifactSequence(input: {
454
- artifactSequenceID: $id
455
- }) {
456
- artifactCollection {
457
- state
458
- }
459
- }
460
- }
461
- """
462
- )
463
- else:
464
- mutation = gql(
465
- """
466
- mutation deleteArtifactPortfolio($id: ID!) {
467
- deleteArtifactPortfolio(input: {
468
- artifactPortfolioID: $id
469
- }) {
470
- artifactCollection {
471
- state
472
- }
473
- }
474
- }
475
- """
476
- )
477
- self.client.execute(mutation, variable_values={"id": self.id})
396
+ self.client.execute(
397
+ gql(
398
+ DELETE_ARTIFACT_SEQUENCE_GQL
399
+ if self.is_sequence()
400
+ else DELETE_ARTIFACT_PORTFOLIO_GQL
401
+ ),
402
+ variable_values={"id": self.id},
403
+ )
478
404
 
479
405
  @property
480
- def description(self):
406
+ def description(self) -> str:
481
407
  """A description of the artifact collection."""
482
408
  return self._description
483
409
 
@@ -486,7 +412,7 @@ class ArtifactCollection:
486
412
  self._description = description
487
413
 
488
414
  @property
489
- def tags(self):
415
+ def tags(self) -> List[str]:
490
416
  """The tags associated with the artifact collection."""
491
417
  return self._tags
492
418
 
@@ -499,7 +425,7 @@ class ArtifactCollection:
499
425
  self._tags = tags
500
426
 
501
427
  @property
502
- def name(self):
428
+ def name(self) -> str:
503
429
  """The name of the artifact collection."""
504
430
  return self._name
505
431
 
@@ -520,202 +446,80 @@ class ArtifactCollection:
520
446
  )
521
447
  self._type = type
522
448
 
523
- def _update_collection(self):
524
- mutation = gql("""
525
- mutation UpdateArtifactCollection(
526
- $artifactSequenceID: ID!
527
- $name: String
528
- $description: String
529
- ) {
530
- updateArtifactSequence(
531
- input: {
532
- artifactSequenceID: $artifactSequenceID
533
- name: $name
534
- description: $description
535
- }
536
- ) {
537
- artifactCollection {
538
- id
539
- name
540
- description
541
- }
542
- }
543
- }
544
- """)
545
-
546
- variable_values = {
547
- "artifactSequenceID": self.id,
548
- "name": self._name,
549
- "description": self.description,
550
- }
551
- self.client.execute(mutation, variable_values=variable_values)
449
+ def _update_collection(self) -> None:
450
+ self.client.execute(
451
+ gql(
452
+ UPDATE_ARTIFACT_SEQUENCE_GQL
453
+ if self.is_sequence()
454
+ else UPDATE_ARTIFACT_PORTFOLIO_GQL
455
+ ),
456
+ variable_values={
457
+ "id": self.id,
458
+ "name": self.name,
459
+ "description": self.description,
460
+ },
461
+ )
552
462
  self._saved_name = self._name
553
463
 
554
- def _update_collection_type(self):
555
- type_mutation = gql("""
556
- mutation MoveArtifactCollection(
557
- $artifactSequenceID: ID!
558
- $destinationArtifactTypeName: String!
559
- ) {
560
- moveArtifactSequence(
561
- input: {
562
- artifactSequenceID: $artifactSequenceID
563
- destinationArtifactTypeName: $destinationArtifactTypeName
564
- }
565
- ) {
566
- artifactCollection {
567
- id
568
- name
569
- description
570
- __typename
571
- }
572
- }
573
- }
574
- """)
575
-
576
- variable_values = {
577
- "artifactSequenceID": self.id,
578
- "destinationArtifactTypeName": self._type,
579
- }
580
- self.client.execute(type_mutation, variable_values=variable_values)
464
+ def _update_collection_type(self) -> None:
465
+ self.client.execute(
466
+ gql(MOVE_ARTIFACT_COLLECTION_GQL),
467
+ variable_values={
468
+ "artifactSequenceID": self.id,
469
+ "destinationArtifactTypeName": self.type,
470
+ },
471
+ )
581
472
  self._saved_type = self._type
582
473
 
583
- def _update_portfolio(self):
584
- mutation = gql("""
585
- mutation UpdateArtifactPortfolio(
586
- $artifactPortfolioID: ID!
587
- $name: String
588
- $description: String
589
- ) {
590
- updateArtifactPortfolio(
591
- input: {
592
- artifactPortfolioID: $artifactPortfolioID
593
- name: $name
594
- description: $description
595
- }
596
- ) {
597
- artifactCollection {
598
- id
599
- name
600
- description
601
- }
602
- }
603
- }
604
- """)
605
- variable_values = {
606
- "artifactPortfolioID": self.id,
607
- "name": self._name,
608
- "description": self.description,
609
- }
610
- self.client.execute(mutation, variable_values=variable_values)
611
- self._saved_name = self._name
612
-
613
- def _add_tags(self, tags_to_add):
614
- add_mutation = gql(
615
- """
616
- mutation CreateArtifactCollectionTagAssignments(
617
- $entityName: String!
618
- $projectName: String!
619
- $artifactCollectionName: String!
620
- $tags: [TagInput!]!
621
- ) {
622
- createArtifactCollectionTagAssignments(
623
- input: {
624
- entityName: $entityName
625
- projectName: $projectName
626
- artifactCollectionName: $artifactCollectionName
627
- tags: $tags
628
- }
629
- ) {
630
- tags {
631
- id
632
- name
633
- tagCategoryName
634
- }
635
- }
636
- }
637
- """
638
- )
474
+ def _add_tags(self, tags_to_add: Iterable[str]) -> None:
639
475
  self.client.execute(
640
- add_mutation,
476
+ gql(CREATE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL),
641
477
  variable_values={
642
478
  "entityName": self.entity,
643
479
  "projectName": self.project,
644
480
  "artifactCollectionName": self._saved_name,
645
- "tags": [
646
- {
647
- "tagName": tag,
648
- }
649
- for tag in tags_to_add
650
- ],
481
+ "tags": [{"tagName": tag} for tag in tags_to_add],
651
482
  },
652
483
  )
653
484
 
654
- def _delete_tags(self, tags_to_delete):
655
- delete_mutation = gql(
656
- """
657
- mutation DeleteArtifactCollectionTagAssignments(
658
- $entityName: String!
659
- $projectName: String!
660
- $artifactCollectionName: String!
661
- $tags: [TagInput!]!
662
- ) {
663
- deleteArtifactCollectionTagAssignments(
664
- input: {
665
- entityName: $entityName
666
- projectName: $projectName
667
- artifactCollectionName: $artifactCollectionName
668
- tags: $tags
669
- }
670
- ) {
671
- success
672
- }
673
- }
674
- """
675
- )
485
+ def _delete_tags(self, tags_to_delete: Iterable[str]) -> None:
676
486
  self.client.execute(
677
- delete_mutation,
487
+ gql(DELETE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL),
678
488
  variable_values={
679
489
  "entityName": self.entity,
680
490
  "projectName": self.project,
681
491
  "artifactCollectionName": self._saved_name,
682
- "tags": [
683
- {
684
- "tagName": tag,
685
- }
686
- for tag in tags_to_delete
687
- ],
492
+ "tags": [{"tagName": tag} for tag in tags_to_delete],
688
493
  },
689
494
  )
690
495
 
691
496
  def save(self) -> None:
692
497
  """Persist any changes made to the artifact collection."""
693
- if self.is_sequence():
694
- self._update_collection()
498
+ self._update_collection()
695
499
 
696
- if self._saved_type != self._type:
697
- self._update_collection_type()
698
- else:
699
- self._update_portfolio()
500
+ if self.is_sequence() and (self._saved_type != self._type):
501
+ self._update_collection_type()
700
502
 
701
- tags_to_add = set(self._tags) - set(self._saved_tags)
702
- tags_to_delete = set(self._saved_tags) - set(self._tags)
703
- if len(tags_to_add) > 0:
503
+ current_tags = set(self._tags)
504
+ saved_tags = set(self._saved_tags)
505
+ if tags_to_add := (current_tags - saved_tags):
704
506
  self._add_tags(tags_to_add)
705
- if len(tags_to_delete) > 0:
507
+ if tags_to_delete := (saved_tags - current_tags):
706
508
  self._delete_tags(tags_to_delete)
707
509
  self._saved_tags = copy(self._tags)
708
510
 
709
- def __repr__(self):
511
+ def __repr__(self) -> str:
710
512
  return f"<ArtifactCollection {self._name} ({self._type})>"
711
513
 
712
514
 
713
- class Artifacts(Paginator):
515
+ class Artifacts(SizedPaginator["wandb.Artifact"]):
714
516
  """An iterable collection of artifact versions associated with a project and optional filter.
715
517
 
716
518
  This is generally used indirectly via the `Api`.artifact_versions method.
717
519
  """
718
520
 
521
+ last_response: Optional[ArtifactsFragment]
522
+
719
523
  def __init__(
720
524
  self,
721
525
  client: Client,
@@ -728,8 +532,6 @@ class Artifacts(Paginator):
728
532
  per_page: int = 50,
729
533
  tags: Optional[Union[str, List[str]]] = None,
730
534
  ):
731
- from wandb.sdk.artifacts.artifact import _gql_artifact_fragment
732
-
733
535
  self.entity = entity
734
536
  self.collection_name = collection_name
735
537
  self.type = type
@@ -745,154 +547,107 @@ class Artifacts(Paginator):
745
547
  "collection": self.collection_name,
746
548
  "filters": json.dumps(self.filters),
747
549
  }
748
- self.QUERY = gql(
749
- """
750
- query Artifacts($project: String!, $entity: String!, $type: String!, $collection: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
751
- project(name: $project, entityName: $entity) {{
752
- artifactType(name: $type) {{
753
- artifactCollection: {}(name: $collection) {{
754
- name
755
- artifacts(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
756
- totalCount
757
- edges {{
758
- node {{
759
- ...ArtifactFragment
760
- }}
761
- version
762
- cursor
763
- }}
764
- pageInfo {{
765
- endCursor
766
- hasNextPage
767
- }}
768
- }}
769
- }}
770
- }}
771
- }}
772
- }}
773
- {}
774
- """.format(
775
- artifact_collection_edge_name(
776
- server_supports_artifact_collections_gql_edges(client)
777
- ),
778
- _gql_artifact_fragment(),
779
- )
550
+
551
+ if server_supports_artifact_collections_gql_edges(client):
552
+ rename_fields = None
553
+ else:
554
+ rename_fields = {"artifactCollection": "artifactSequence"}
555
+
556
+ self.QUERY = gql_compat(
557
+ PROJECT_ARTIFACTS_GQL,
558
+ omit_fields=omit_artifact_fields(api=InternalApi()),
559
+ rename_fields=rename_fields,
780
560
  )
561
+
781
562
  super().__init__(client, variables, per_page)
782
563
 
564
+ @override
565
+ def _update_response(self) -> None:
566
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
567
+ result = ProjectArtifacts.model_validate(data)
568
+
569
+ # Extract the inner `*Connection` result for faster/easier access.
570
+ if not (
571
+ (proj := result.project)
572
+ and (type_ := proj.artifact_type)
573
+ and (collection := type_.artifact_collection)
574
+ and (conn := collection.artifacts)
575
+ ):
576
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
577
+
578
+ self.last_response = ArtifactsFragment.model_validate(conn)
579
+
783
580
  @property
784
- def length(self):
785
- if self.last_response:
786
- return self.last_response["project"]["artifactType"]["artifactCollection"][
787
- "artifacts"
788
- ]["totalCount"]
789
- else:
581
+ def length(self) -> Optional[int]:
582
+ if self.last_response is None:
790
583
  return None
584
+ return self.last_response.total_count
791
585
 
792
586
  @property
793
- def more(self):
794
- if self.last_response:
795
- return self.last_response["project"]["artifactType"]["artifactCollection"][
796
- "artifacts"
797
- ]["pageInfo"]["hasNextPage"]
798
- else:
587
+ def more(self) -> bool:
588
+ if self.last_response is None:
799
589
  return True
590
+ return self.last_response.page_info.has_next_page
800
591
 
801
592
  @property
802
- def cursor(self):
803
- if self.last_response:
804
- return self.last_response["project"]["artifactType"]["artifactCollection"][
805
- "artifacts"
806
- ]["edges"][-1]["cursor"]
807
- else:
593
+ def cursor(self) -> Optional[str]:
594
+ if self.last_response is None:
808
595
  return None
596
+ return self.last_response.edges[-1].cursor
809
597
 
810
- def convert_objects(self):
811
- collection = self.last_response["project"]["artifactType"]["artifactCollection"]
812
- artifact_edges = collection.get("artifacts", {}).get("edges", [])
598
+ def convert_objects(self) -> List["wandb.Artifact"]:
599
+ artifact_edges = (edge for edge in self.last_response.edges if edge.node)
813
600
  artifacts = (
814
601
  wandb.Artifact._from_attrs(
815
- self.entity,
816
- self.project,
817
- self.collection_name + ":" + a["version"],
818
- a["node"],
819
- self.client,
602
+ entity=self.entity,
603
+ project=self.project,
604
+ name=f"{self.collection_name}:{edge.version}",
605
+ attrs=edge.node.model_dump(exclude_unset=True),
606
+ client=self.client,
820
607
  )
821
- for a in artifact_edges
608
+ for edge in artifact_edges
822
609
  )
823
610
  required_tags = set(self.tags or [])
824
- return [
825
- artifact for artifact in artifacts if required_tags.issubset(artifact.tags)
826
- ]
611
+ return [art for art in artifacts if required_tags.issubset(art.tags)]
827
612
 
828
613
 
829
- class RunArtifacts(Paginator):
830
- def __init__(
831
- self, client: Client, run: "Run", mode="logged", per_page: Optional[int] = 50
832
- ):
833
- from wandb.sdk.artifacts.artifact import _gql_artifact_fragment
834
-
835
- output_query = gql(
836
- """
837
- query RunOutputArtifacts(
838
- $entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
839
- ) {
840
- project(name: $project, entityName: $entity) {
841
- run(name: $runName) {
842
- outputArtifacts(after: $cursor, first: $perPage) {
843
- totalCount
844
- edges {
845
- node {
846
- ...ArtifactFragment
847
- }
848
- cursor
849
- }
850
- pageInfo {
851
- endCursor
852
- hasNextPage
853
- }
854
- }
855
- }
856
- }
857
- }
858
- """
859
- + _gql_artifact_fragment()
860
- )
614
+ class RunArtifacts(SizedPaginator["wandb.Artifact"]):
615
+ last_response: Union[
616
+ RunOutputArtifactsProjectRunOutputArtifacts,
617
+ RunInputArtifactsProjectRunInputArtifacts,
618
+ ]
861
619
 
862
- input_query = gql(
863
- """
864
- query RunInputArtifacts(
865
- $entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
866
- ) {
867
- project(name: $project, entityName: $entity) {
868
- run(name: $runName) {
869
- inputArtifacts(after: $cursor, first: $perPage) {
870
- totalCount
871
- edges {
872
- node {
873
- ...ArtifactFragment
874
- }
875
- cursor
876
- }
877
- pageInfo {
878
- endCursor
879
- hasNextPage
880
- }
881
- }
882
- }
883
- }
884
- }
885
- """
886
- + _gql_artifact_fragment()
887
- )
620
+ #: The pydantic model used to parse the (inner part of the) raw response.
621
+ _response_cls: Type[
622
+ Union[
623
+ RunOutputArtifactsProjectRunOutputArtifacts,
624
+ RunInputArtifactsProjectRunInputArtifacts,
625
+ ]
626
+ ]
888
627
 
628
+ def __init__(
629
+ self,
630
+ client: Client,
631
+ run: "Run",
632
+ mode: Literal["logged", "used"] = "logged",
633
+ per_page: int = 50,
634
+ ):
889
635
  self.run = run
636
+
890
637
  if mode == "logged":
891
638
  self.run_key = "outputArtifacts"
892
- self.QUERY = output_query
639
+ self.QUERY = gql_compat(
640
+ RUN_OUTPUT_ARTIFACTS_GQL,
641
+ omit_fields=omit_artifact_fields(api=InternalApi()),
642
+ )
643
+ self._response_cls = RunOutputArtifactsProjectRunOutputArtifacts
893
644
  elif mode == "used":
894
645
  self.run_key = "inputArtifacts"
895
- self.QUERY = input_query
646
+ self.QUERY = gql_compat(
647
+ RUN_INPUT_ARTIFACTS_GQL,
648
+ omit_fields=omit_artifact_fields(api=InternalApi()),
649
+ )
650
+ self._response_cls = RunInputArtifactsProjectRunInputArtifacts
896
651
  else:
897
652
  raise ValueError("mode must be logged or used")
898
653
 
@@ -901,72 +656,50 @@ class RunArtifacts(Paginator):
901
656
  "project": run.project,
902
657
  "runName": run.id,
903
658
  }
904
-
905
659
  super().__init__(client, variable_values, per_page)
906
660
 
661
+ @override
662
+ def _update_response(self) -> None:
663
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
664
+
665
+ # Extract the inner `*Connection` result for faster/easier access.
666
+ inner_data = data["project"]["run"][self.run_key]
667
+ self.last_response = self._response_cls.model_validate(inner_data)
668
+
907
669
  @property
908
- def length(self):
909
- if self.last_response:
910
- return self.last_response["project"]["run"][self.run_key]["totalCount"]
911
- else:
670
+ def length(self) -> Optional[int]:
671
+ if self.last_response is None:
912
672
  return None
673
+ return self.last_response.total_count
913
674
 
914
675
  @property
915
- def more(self):
916
- if self.last_response:
917
- return self.last_response["project"]["run"][self.run_key]["pageInfo"][
918
- "hasNextPage"
919
- ]
920
- else:
676
+ def more(self) -> bool:
677
+ if self.last_response is None:
921
678
  return True
679
+ return self.last_response.page_info.has_next_page
922
680
 
923
681
  @property
924
- def cursor(self):
925
- if self.last_response:
926
- return self.last_response["project"]["run"][self.run_key]["edges"][-1][
927
- "cursor"
928
- ]
929
- else:
682
+ def cursor(self) -> Optional[str]:
683
+ if self.last_response is None:
930
684
  return None
685
+ return self.last_response.edges[-1].cursor
931
686
 
932
- def convert_objects(self):
687
+ def convert_objects(self) -> List["wandb.Artifact"]:
933
688
  return [
934
689
  wandb.Artifact._from_attrs(
935
- r["node"]["artifactSequence"]["project"]["entityName"],
936
- r["node"]["artifactSequence"]["project"]["name"],
937
- "{}:v{}".format(
938
- r["node"]["artifactSequence"]["name"], r["node"]["versionIndex"]
939
- ),
940
- r["node"],
941
- self.client,
690
+ entity=node.artifact_sequence.project.entity_name,
691
+ project=node.artifact_sequence.project.name,
692
+ name=f"{node.artifact_sequence.name}:v{node.version_index}",
693
+ attrs=node.model_dump(exclude_unset=True),
694
+ client=self.client,
942
695
  )
943
- for r in self.last_response["project"]["run"][self.run_key]["edges"]
696
+ for r in self.last_response.edges
697
+ if (node := r.node)
944
698
  ]
945
699
 
946
700
 
947
- class ArtifactFiles(Paginator):
948
- QUERY = gql(
949
- """
950
- query ArtifactFiles(
951
- $entityName: String!,
952
- $projectName: String!,
953
- $artifactTypeName: String!,
954
- $artifactName: String!
955
- $fileNames: [String!],
956
- $fileCursor: String,
957
- $fileLimit: Int = 50
958
- ) {{
959
- project(name: $projectName, entityName: $entityName) {{
960
- artifactType(name: $artifactTypeName) {{
961
- artifact(name: $artifactName) {{
962
- ...ArtifactFilesFragment
963
- }}
964
- }}
965
- }}
966
- }}
967
- {}
968
- """.format(ARTIFACT_FILES_FRAGMENT)
969
- )
701
+ class ArtifactFiles(SizedPaginator["public.File"]):
702
+ last_response: Optional[FilesFragment]
970
703
 
971
704
  def __init__(
972
705
  self,
@@ -975,59 +708,92 @@ class ArtifactFiles(Paginator):
975
708
  names: Optional[Sequence[str]] = None,
976
709
  per_page: int = 50,
977
710
  ):
711
+ self.query_via_membership = InternalApi()._check_server_feature_with_fallback(
712
+ ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
713
+ )
978
714
  self.artifact = artifact
979
- variables = {
980
- "entityName": artifact.source_entity,
981
- "projectName": artifact.source_project,
982
- "artifactTypeName": artifact.type,
983
- "artifactName": artifact.source_name,
984
- "fileNames": names,
985
- }
715
+
716
+ if self.query_via_membership:
717
+ query_str = ARTIFACT_COLLECTION_MEMBERSHIP_FILES_GQL
718
+ variables = {
719
+ "entityName": artifact.entity,
720
+ "projectName": artifact.project,
721
+ "artifactName": artifact.name.split(":")[0],
722
+ "artifactVersionIndex": artifact.version,
723
+ "fileNames": names,
724
+ }
725
+ else:
726
+ query_str = ARTIFACT_VERSION_FILES_GQL
727
+ variables = {
728
+ "entityName": artifact.source_entity,
729
+ "projectName": artifact.source_project,
730
+ "artifactName": artifact.source_name,
731
+ "artifactTypeName": artifact.type,
732
+ "fileNames": names,
733
+ }
734
+
986
735
  # The server must advertise at least SDK 0.12.21
987
736
  # to get storagePath
988
737
  if not client.version_supported("0.12.21"):
989
- self.QUERY = gql(self.QUERY.loc.source.body.replace("storagePath\n", ""))
738
+ self.QUERY = gql_compat(query_str, omit_fields={"storagePath"})
739
+ else:
740
+ self.QUERY = gql(query_str)
741
+
990
742
  super().__init__(client, variables, per_page)
991
743
 
744
+ @override
745
+ def _update_response(self) -> None:
746
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
747
+
748
+ # Extract the inner `*Connection` result for faster/easier access.
749
+ if self.query_via_membership:
750
+ result = ArtifactCollectionMembershipFiles.model_validate(data)
751
+ conn = result.project.artifact_collection.artifact_membership.files
752
+ else:
753
+ result = ArtifactVersionFiles.model_validate(data)
754
+ conn = result.project.artifact_type.artifact.files
755
+
756
+ if conn is None:
757
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
758
+
759
+ self.last_response = FilesFragment.model_validate(conn)
760
+
992
761
  @property
993
- def path(self):
762
+ def path(self) -> List[str]:
994
763
  return [self.artifact.entity, self.artifact.project, self.artifact.name]
995
764
 
996
765
  @property
997
- def length(self):
766
+ def length(self) -> int:
998
767
  return self.artifact.file_count
999
768
 
1000
769
  @property
1001
- def more(self):
1002
- if self.last_response:
1003
- return self.last_response["project"]["artifactType"]["artifact"]["files"][
1004
- "pageInfo"
1005
- ]["hasNextPage"]
1006
- else:
770
+ def more(self) -> bool:
771
+ if self.last_response is None:
1007
772
  return True
773
+ return self.last_response.page_info.has_next_page
1008
774
 
1009
775
  @property
1010
- def cursor(self):
1011
- if self.last_response:
1012
- return self.last_response["project"]["artifactType"]["artifact"]["files"][
1013
- "edges"
1014
- ][-1]["cursor"]
1015
- else:
776
+ def cursor(self) -> Optional[str]:
777
+ if self.last_response is None:
1016
778
  return None
779
+ return self.last_response.edges[-1].cursor
1017
780
 
1018
- def update_variables(self):
781
+ def update_variables(self) -> None:
1019
782
  self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor})
1020
783
 
1021
- def convert_objects(self):
784
+ def convert_objects(self) -> List["public.File"]:
1022
785
  return [
1023
- public.File(self.client, r["node"])
1024
- for r in self.last_response["project"]["artifactType"]["artifact"]["files"][
1025
- "edges"
1026
- ]
786
+ public.File(
787
+ client=self.client,
788
+ attrs=node.model_dump(exclude_unset=True),
789
+ )
790
+ for edge in self.last_response.edges
791
+ if (node := edge.node)
1027
792
  ]
1028
793
 
1029
- def __repr__(self):
1030
- return "<ArtifactFiles {} ({})>".format("/".join(self.path), len(self))
794
+ def __repr__(self) -> str:
795
+ path_str = "/".join(self.path)
796
+ return f"<ArtifactFiles {path_str} ({len(self)})>"
1031
797
 
1032
798
 
1033
799
  def server_supports_artifact_collections_gql_edges(
@@ -1043,21 +809,3 @@ def server_supports_artifact_collections_gql_edges(
1043
809
  "W&B Local Server version does not support ArtifactCollection gql edges; falling back to using legacy ArtifactSequence. Please update server to at least version 0.9.50."
1044
810
  )
1045
811
  return supported
1046
-
1047
-
1048
- def artifact_collection_edge_name(server_supports_artifact_collections: bool) -> str:
1049
- return (
1050
- "artifactCollection"
1051
- if server_supports_artifact_collections
1052
- else "artifactSequence"
1053
- )
1054
-
1055
-
1056
- def artifact_collection_plural_edge_name(
1057
- server_supports_artifact_collections: bool,
1058
- ) -> str:
1059
- return (
1060
- "artifactCollections"
1061
- if server_supports_artifact_collections
1062
- else "artifactSequences"
1063
- )