wandb 0.19.9__py3-none-any.whl → 0.19.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +4 -1
  3. wandb/_pydantic/__init__.py +14 -7
  4. wandb/_pydantic/base.py +44 -9
  5. wandb/_pydantic/utils.py +66 -0
  6. wandb/_pydantic/v1_compat.py +78 -56
  7. wandb/apis/public/__init__.py +2 -2
  8. wandb/apis/public/api.py +114 -2
  9. wandb/apis/public/artifacts.py +365 -673
  10. wandb/apis/public/automations.py +69 -0
  11. wandb/apis/public/integrations.py +168 -0
  12. wandb/apis/public/projects.py +29 -0
  13. wandb/apis/public/utils.py +107 -1
  14. wandb/automations/__init__.py +81 -0
  15. wandb/automations/_filters/__init__.py +40 -0
  16. wandb/automations/_filters/expressions.py +179 -0
  17. wandb/automations/_filters/operators.py +267 -0
  18. wandb/automations/_filters/run_metrics.py +183 -0
  19. wandb/automations/_generated/__init__.py +184 -0
  20. wandb/automations/_generated/create_filter_trigger.py +21 -0
  21. wandb/automations/_generated/create_generic_webhook_integration.py +43 -0
  22. wandb/automations/_generated/delete_trigger.py +19 -0
  23. wandb/automations/_generated/enums.py +33 -0
  24. wandb/automations/_generated/fragments.py +343 -0
  25. wandb/automations/_generated/generic_webhook_integrations_by_entity.py +22 -0
  26. wandb/automations/_generated/get_triggers.py +24 -0
  27. wandb/automations/_generated/get_triggers_by_entity.py +24 -0
  28. wandb/automations/_generated/input_types.py +104 -0
  29. wandb/automations/_generated/integrations_by_entity.py +22 -0
  30. wandb/automations/_generated/operations.py +710 -0
  31. wandb/automations/_generated/slack_integrations_by_entity.py +22 -0
  32. wandb/automations/_generated/update_filter_trigger.py +21 -0
  33. wandb/automations/_utils.py +123 -0
  34. wandb/automations/_validators.py +73 -0
  35. wandb/automations/actions.py +205 -0
  36. wandb/automations/automations.py +109 -0
  37. wandb/automations/events.py +235 -0
  38. wandb/automations/integrations.py +26 -0
  39. wandb/automations/scopes.py +76 -0
  40. wandb/beta/workflows.py +9 -10
  41. wandb/bin/gpu_stats +0 -0
  42. wandb/cli/cli.py +3 -3
  43. wandb/integration/keras/keras.py +2 -1
  44. wandb/integration/langchain/wandb_tracer.py +2 -1
  45. wandb/jupyter.py +137 -118
  46. wandb/old/summary.py +0 -2
  47. wandb/proto/v3/wandb_internal_pb2.py +293 -292
  48. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  49. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  50. wandb/proto/v4/wandb_internal_pb2.py +292 -292
  51. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  52. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  53. wandb/proto/v5/wandb_internal_pb2.py +292 -292
  54. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  55. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  56. wandb/proto/v6/wandb_base_pb2.py +41 -0
  57. wandb/proto/v6/wandb_internal_pb2.py +393 -0
  58. wandb/proto/v6/wandb_server_pb2.py +78 -0
  59. wandb/proto/v6/wandb_settings_pb2.py +58 -0
  60. wandb/proto/v6/wandb_telemetry_pb2.py +52 -0
  61. wandb/proto/wandb_base_pb2.py +2 -0
  62. wandb/proto/wandb_deprecated.py +8 -0
  63. wandb/proto/wandb_internal_pb2.py +3 -1
  64. wandb/proto/wandb_server_pb2.py +2 -0
  65. wandb/proto/wandb_settings_pb2.py +2 -0
  66. wandb/proto/wandb_telemetry_pb2.py +2 -0
  67. wandb/sdk/artifacts/_generated/__init__.py +248 -0
  68. wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +43 -0
  69. wandb/sdk/artifacts/_generated/artifact_version_files.py +36 -0
  70. wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +36 -0
  71. wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +25 -0
  72. wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +35 -0
  73. wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +35 -0
  74. wandb/sdk/artifacts/_generated/enums.py +17 -0
  75. wandb/sdk/artifacts/_generated/fragments.py +186 -0
  76. wandb/sdk/artifacts/_generated/input_types.py +16 -0
  77. wandb/sdk/artifacts/_generated/move_artifact_collection.py +35 -0
  78. wandb/sdk/artifacts/_generated/operations.py +510 -0
  79. wandb/sdk/artifacts/_generated/project_artifact_collection.py +101 -0
  80. wandb/sdk/artifacts/_generated/project_artifact_collections.py +33 -0
  81. wandb/sdk/artifacts/_generated/project_artifact_type.py +24 -0
  82. wandb/sdk/artifacts/_generated/project_artifact_types.py +24 -0
  83. wandb/sdk/artifacts/_generated/project_artifacts.py +42 -0
  84. wandb/sdk/artifacts/_generated/run_input_artifacts.py +51 -0
  85. wandb/sdk/artifacts/_generated/run_output_artifacts.py +51 -0
  86. wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +35 -0
  87. wandb/sdk/artifacts/_generated/update_artifact_sequence.py +35 -0
  88. wandb/sdk/artifacts/_graphql_fragments.py +56 -79
  89. wandb/sdk/artifacts/artifact.py +40 -13
  90. wandb/sdk/artifacts/artifact_manifest_entry.py +2 -1
  91. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  92. wandb/sdk/data_types/base_types/media.py +2 -3
  93. wandb/sdk/data_types/base_types/wb_value.py +34 -11
  94. wandb/sdk/data_types/html.py +36 -9
  95. wandb/sdk/data_types/image.py +12 -12
  96. wandb/sdk/data_types/table.py +5 -0
  97. wandb/sdk/data_types/trace_tree.py +2 -0
  98. wandb/sdk/data_types/utils.py +1 -1
  99. wandb/sdk/data_types/video.py +14 -26
  100. wandb/sdk/interface/interface.py +2 -0
  101. wandb/sdk/internal/profiler.py +6 -5
  102. wandb/sdk/internal/run.py +13 -6
  103. wandb/sdk/lib/apikey.py +25 -4
  104. wandb/sdk/lib/asyncio_compat.py +1 -1
  105. wandb/sdk/lib/deprecate.py +13 -22
  106. wandb/sdk/lib/disabled.py +2 -1
  107. wandb/sdk/lib/printer.py +37 -8
  108. wandb/sdk/lib/printer_asyncio.py +46 -0
  109. wandb/sdk/lib/redirect.py +10 -5
  110. wandb/sdk/service/server_sock.py +19 -14
  111. wandb/sdk/service/service.py +9 -7
  112. wandb/sdk/service/streams.py +5 -0
  113. wandb/sdk/verify/verify.py +6 -3
  114. wandb/sdk/wandb_init.py +185 -65
  115. wandb/sdk/wandb_login.py +13 -4
  116. wandb/sdk/wandb_run.py +382 -286
  117. wandb/sdk/wandb_settings.py +21 -3
  118. wandb/sdk/wandb_setup.py +49 -0
  119. wandb/util.py +29 -29
  120. {wandb-0.19.9.dist-info → wandb-0.19.10.dist-info}/METADATA +5 -5
  121. {wandb-0.19.9.dist-info → wandb-0.19.10.dist-info}/RECORD +124 -71
  122. wandb/_globals.py +0 -19
  123. wandb/sdk/internal/_generated/base.py +0 -226
  124. wandb/sdk/internal/_generated/typing_compat.py +0 -14
  125. {wandb-0.19.9.dist-info → wandb-0.19.10.dist-info}/WHEEL +0 -0
  126. {wandb-0.19.9.dist-info → wandb-0.19.10.dist-info}/entry_points.txt +0 -0
  127. {wandb-0.19.9.dist-info → wandb-0.19.10.dist-info}/licenses/LICENSE +0 -0
@@ -3,8 +3,20 @@
3
3
  import json
4
4
  import re
5
5
  from copy import copy
6
- from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Union
6
+ from typing import (
7
+ TYPE_CHECKING,
8
+ Any,
9
+ Iterable,
10
+ List,
11
+ Literal,
12
+ Mapping,
13
+ Optional,
14
+ Sequence,
15
+ Type,
16
+ Union,
17
+ )
7
18
 
19
+ from typing_extensions import override
8
20
  from wandb_gql import Client, gql
9
21
 
10
22
  import wandb
@@ -12,35 +24,54 @@ from wandb.apis import public
12
24
  from wandb.apis.normalize import normalize_exceptions
13
25
  from wandb.apis.paginator import Paginator, SizedPaginator
14
26
  from wandb.errors.term import termlog
27
+ from wandb.proto.wandb_deprecated import Deprecated
15
28
  from wandb.proto.wandb_internal_pb2 import ServerFeature
16
- from wandb.sdk.artifacts._graphql_fragments import (
17
- ARTIFACT_FILES_FRAGMENT,
18
- ARTIFACTS_TYPES_FRAGMENT,
29
+ from wandb.sdk.artifacts._generated import (
30
+ ARTIFACT_COLLECTION_MEMBERSHIP_FILES_GQL,
31
+ ARTIFACT_VERSION_FILES_GQL,
32
+ CREATE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL,
33
+ DELETE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL,
34
+ DELETE_ARTIFACT_PORTFOLIO_GQL,
35
+ DELETE_ARTIFACT_SEQUENCE_GQL,
36
+ MOVE_ARTIFACT_COLLECTION_GQL,
37
+ PROJECT_ARTIFACT_COLLECTION_GQL,
38
+ PROJECT_ARTIFACT_COLLECTIONS_GQL,
39
+ PROJECT_ARTIFACT_TYPE_GQL,
40
+ PROJECT_ARTIFACT_TYPES_GQL,
41
+ PROJECT_ARTIFACTS_GQL,
42
+ RUN_INPUT_ARTIFACTS_GQL,
43
+ RUN_OUTPUT_ARTIFACTS_GQL,
44
+ UPDATE_ARTIFACT_PORTFOLIO_GQL,
45
+ UPDATE_ARTIFACT_SEQUENCE_GQL,
46
+ ArtifactCollectionMembershipFiles,
47
+ ArtifactCollectionsFragment,
48
+ ArtifactsFragment,
49
+ ArtifactTypeFragment,
50
+ ArtifactTypesFragment,
51
+ ArtifactVersionFiles,
52
+ FilesFragment,
53
+ ProjectArtifactCollection,
54
+ ProjectArtifactCollections,
55
+ ProjectArtifacts,
56
+ ProjectArtifactType,
57
+ ProjectArtifactTypes,
58
+ RunInputArtifactsProjectRunInputArtifacts,
59
+ RunOutputArtifactsProjectRunOutputArtifacts,
19
60
  )
61
+ from wandb.sdk.artifacts._graphql_fragments import omit_artifact_fields
20
62
  from wandb.sdk.internal.internal_api import Api as InternalApi
21
63
  from wandb.sdk.lib import deprecate
22
64
 
65
+ from .utils import gql_compat
66
+
23
67
  if TYPE_CHECKING:
24
68
  from wandb.apis.public import RetryingClient, Run
25
69
 
26
70
 
27
71
  class ArtifactTypes(Paginator["ArtifactType"]):
28
- QUERY = gql(
29
- """
30
- query ProjectArtifacts(
31
- $entityName: String!,
32
- $projectName: String!,
33
- $cursor: String,
34
- ) {{
35
- project(name: $projectName, entityName: $entityName) {{
36
- artifactTypes(after: $cursor) {{
37
- ...ArtifactTypesFragment
38
- }}
39
- }}
40
- }}
41
- {}
42
- """.format(ARTIFACTS_TYPES_FRAGMENT)
43
- )
72
+ QUERY = gql(PROJECT_ARTIFACT_TYPES_GQL)
73
+
74
+ last_response: Optional[ArtifactTypesFragment]
44
75
 
45
76
  def __init__(
46
77
  self,
@@ -56,41 +87,54 @@ class ArtifactTypes(Paginator["ArtifactType"]):
56
87
  "entityName": entity,
57
88
  "projectName": project,
58
89
  }
59
-
60
90
  super().__init__(client, variable_values, per_page)
61
91
 
92
+ @override
93
+ def _update_response(self) -> None:
94
+ """Fetch and validate the response data for the current page."""
95
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
96
+ result = ProjectArtifactTypes.model_validate(data)
97
+
98
+ # Extract the inner `*Connection` result for faster/easier access.
99
+ if not ((proj := result.project) and (conn := proj.artifact_types)):
100
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
101
+
102
+ self.last_response = ArtifactTypesFragment.model_validate(conn)
103
+
62
104
  @property
63
105
  def length(self) -> None:
64
106
  # TODO
65
107
  return None
66
108
 
67
109
  @property
68
- def more(self):
69
- if self.last_response:
70
- return self.last_response["project"]["artifactTypes"]["pageInfo"][
71
- "hasNextPage"
72
- ]
73
- else:
110
+ def more(self) -> bool:
111
+ if self.last_response is None:
74
112
  return True
113
+ return self.last_response.page_info.has_next_page
75
114
 
76
115
  @property
77
- def cursor(self):
78
- if self.last_response:
79
- return self.last_response["project"]["artifactTypes"]["edges"][-1]["cursor"]
80
- else:
116
+ def cursor(self) -> Optional[str]:
117
+ if self.last_response is None:
81
118
  return None
119
+ return self.last_response.edges[-1].cursor
82
120
 
83
- def update_variables(self):
121
+ def update_variables(self) -> None:
84
122
  self.variables.update({"cursor": self.cursor})
85
123
 
86
- def convert_objects(self):
87
- if self.last_response["project"] is None:
124
+ def convert_objects(self) -> List["ArtifactType"]:
125
+ if self.last_response is None:
88
126
  return []
127
+
89
128
  return [
90
129
  ArtifactType(
91
- self.client, self.entity, self.project, r["node"]["name"], r["node"]
130
+ client=self.client,
131
+ entity=self.entity,
132
+ project=self.project,
133
+ type_name=node.name,
134
+ attrs=node.model_dump(exclude_unset=True),
92
135
  )
93
- for r in self.last_response["project"]["artifactTypes"]["edges"]
136
+ for edge in self.last_response.edges
137
+ if edge.node and (node := ArtifactTypeFragment.model_validate(edge.node))
94
138
  ]
95
139
 
96
140
 
@@ -112,39 +156,19 @@ class ArtifactType:
112
156
  self.load()
113
157
 
114
158
  def load(self):
115
- query = gql(
116
- """
117
- query ProjectArtifactType(
118
- $entityName: String!,
119
- $projectName: String!,
120
- $artifactTypeName: String!
121
- ) {
122
- project(name: $projectName, entityName: $entityName) {
123
- artifactType(name: $artifactTypeName) {
124
- id
125
- name
126
- description
127
- createdAt
128
- }
129
- }
130
- }
131
- """
132
- )
133
- response: Optional[Mapping[str, Any]] = self.client.execute(
134
- query,
159
+ data: Optional[Mapping[str, Any]] = self.client.execute(
160
+ gql(PROJECT_ARTIFACT_TYPE_GQL),
135
161
  variable_values={
136
162
  "entityName": self.entity,
137
163
  "projectName": self.project,
138
164
  "artifactTypeName": self.type,
139
165
  },
140
166
  )
141
- if (
142
- response is None
143
- or response.get("project") is None
144
- or response["project"].get("artifactType") is None
145
- ):
146
- raise ValueError("Could not find artifact type {}".format(self.type))
147
- self._attrs = response["project"]["artifactType"]
167
+ result = ProjectArtifactType.model_validate(data)
168
+ if not ((proj := result.project) and (artifact_type := proj.artifact_type)):
169
+ raise ValueError(f"Could not find artifact type {self.type}")
170
+
171
+ self._attrs = artifact_type.model_dump(exclude_unset=True)
148
172
  return self._attrs
149
173
 
150
174
  @property
@@ -170,6 +194,8 @@ class ArtifactType:
170
194
 
171
195
 
172
196
  class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
197
+ last_response: Optional[ArtifactCollectionsFragment]
198
+
173
199
  def __init__(
174
200
  self,
175
201
  client: Client,
@@ -188,86 +214,65 @@ class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
188
214
  "artifactTypeName": type_name,
189
215
  }
190
216
 
191
- self.QUERY = gql(
192
- """
193
- query ProjectArtifactCollections(
194
- $entityName: String!,
195
- $projectName: String!,
196
- $artifactTypeName: String!
197
- $cursor: String,
198
- ) {{
199
- project(name: $projectName, entityName: $entityName) {{
200
- artifactType(name: $artifactTypeName) {{
201
- artifactCollections: {}(after: $cursor) {{
202
- pageInfo {{
203
- endCursor
204
- hasNextPage
205
- }}
206
- totalCount
207
- edges {{
208
- node {{
209
- id
210
- name
211
- description
212
- createdAt
213
- }}
214
- cursor
215
- }}
216
- }}
217
- }}
218
- }}
219
- }}
220
- """.format(
221
- artifact_collection_plural_edge_name(
222
- server_supports_artifact_collections_gql_edges(client)
223
- )
224
- )
217
+ if server_supports_artifact_collections_gql_edges(client):
218
+ rename_fields = None
219
+ else:
220
+ rename_fields = {"artifactCollections": "artifactSequences"}
221
+
222
+ self.QUERY = gql_compat(
223
+ PROJECT_ARTIFACT_COLLECTIONS_GQL, rename_fields=rename_fields
225
224
  )
226
225
 
227
226
  super().__init__(client, variable_values, per_page)
228
227
 
228
+ @override
229
+ def _update_response(self) -> None:
230
+ """Fetch and validate the response data for the current page."""
231
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
232
+ result = ProjectArtifactCollections.model_validate(data)
233
+
234
+ # Extract the inner `*Connection` result for faster/easier access.
235
+ if not (
236
+ (proj := result.project)
237
+ and (type_ := proj.artifact_type)
238
+ and (conn := type_.artifact_collections)
239
+ ):
240
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
241
+
242
+ self.last_response = ArtifactCollectionsFragment.model_validate(conn)
243
+
229
244
  @property
230
245
  def length(self):
231
- if self.last_response:
232
- return self.last_response["project"]["artifactType"]["artifactCollections"][
233
- "totalCount"
234
- ]
235
- else:
246
+ if self.last_response is None:
236
247
  return None
248
+ return self.last_response.total_count
237
249
 
238
250
  @property
239
251
  def more(self):
240
- if self.last_response:
241
- return self.last_response["project"]["artifactType"]["artifactCollections"][
242
- "pageInfo"
243
- ]["hasNextPage"]
244
- else:
252
+ if self.last_response is None:
245
253
  return True
254
+ return self.last_response.page_info.has_next_page
246
255
 
247
256
  @property
248
257
  def cursor(self):
249
- if self.last_response:
250
- return self.last_response["project"]["artifactType"]["artifactCollections"][
251
- "edges"
252
- ][-1]["cursor"]
253
- else:
258
+ if self.last_response is None:
254
259
  return None
260
+ return self.last_response.edges[-1].cursor
255
261
 
256
- def update_variables(self):
262
+ def update_variables(self) -> None:
257
263
  self.variables.update({"cursor": self.cursor})
258
264
 
259
- def convert_objects(self):
265
+ def convert_objects(self) -> List["ArtifactCollection"]:
260
266
  return [
261
267
  ArtifactCollection(
262
- self.client,
263
- self.entity,
264
- self.project,
265
- r["node"]["name"],
266
- self.type_name,
268
+ client=self.client,
269
+ entity=self.entity,
270
+ project=self.project,
271
+ name=node.name,
272
+ type=self.type_name,
267
273
  )
268
- for r in self.last_response["project"]["artifactType"][
269
- "artifactCollections"
270
- ]["edges"]
274
+ for edge in self.last_response.edges
275
+ if (node := edge.node)
271
276
  ]
272
277
 
273
278
 
@@ -300,83 +305,38 @@ class ArtifactCollection:
300
305
  self.organization = organization
301
306
 
302
307
  @property
303
- def id(self):
308
+ def id(self) -> str:
304
309
  return self._attrs["id"]
305
310
 
306
311
  @normalize_exceptions
307
- def artifacts(self, per_page=50):
312
+ def artifacts(self, per_page: int = 50) -> "Artifacts":
308
313
  """Artifacts."""
309
314
  return Artifacts(
310
- self.client,
311
- self.entity,
312
- self.project,
313
- self._saved_name,
314
- self._saved_type,
315
+ client=self.client,
316
+ entity=self.entity,
317
+ project=self.project,
318
+ collection_name=self._saved_name,
319
+ type=self._saved_type,
315
320
  per_page=per_page,
316
321
  )
317
322
 
318
323
  @property
319
- def aliases(self):
324
+ def aliases(self) -> List[str]:
320
325
  """Artifact Collection Aliases."""
321
326
  return self._aliases
322
327
 
323
328
  @property
324
- def created_at(self):
329
+ def created_at(self) -> str:
325
330
  return self._created_at
326
331
 
327
332
  def load(self):
328
- query = gql(
329
- """
330
- query ArtifactCollection(
331
- $entityName: String!,
332
- $projectName: String!,
333
- $artifactTypeName: String!,
334
- $artifactCollectionName: String!,
335
- $cursor: String,
336
- $perPage: Int = 1000
337
- ) {{
338
- project(name: $projectName, entityName: $entityName) {{
339
- artifactType(name: $artifactTypeName) {{
340
- artifactCollection: {}(name: $artifactCollectionName) {{
341
- id
342
- name
343
- description
344
- createdAt
345
- tags {{
346
- edges {{
347
- node {{
348
- id
349
- name
350
- }}
351
- }}
352
- }}
353
- aliases(after: $cursor, first: $perPage){{
354
- edges {{
355
- node {{
356
- alias
357
- }}
358
- cursor
359
- }}
360
- pageInfo {{
361
- endCursor
362
- hasNextPage
363
- }}
364
- }}
365
- }}
366
- artifactSequence(name: $artifactCollectionName) {{
367
- __typename
368
- }}
369
- }}
370
- }}
371
- }}
372
- """.format(
373
- artifact_collection_edge_name(
374
- server_supports_artifact_collections_gql_edges(self.client)
375
- )
376
- )
377
- )
333
+ if server_supports_artifact_collections_gql_edges(self.client):
334
+ rename_fields = None
335
+ else:
336
+ rename_fields = {"artifactCollection": "artifactSequence"}
337
+
378
338
  response = self.client.execute(
379
- query,
339
+ gql_compat(PROJECT_ARTIFACT_COLLECTION_GQL, rename_fields=rename_fields),
380
340
  variable_values={
381
341
  "entityName": self.entity,
382
342
  "projectName": self.project,
@@ -384,26 +344,30 @@ class ArtifactCollection:
384
344
  "artifactCollectionName": self._saved_name,
385
345
  },
386
346
  )
387
- if (
388
- response is None
389
- or response.get("project") is None
390
- or response["project"].get("artifactType") is None
391
- or response["project"]["artifactType"].get("artifactCollection") is None
347
+
348
+ result = ProjectArtifactCollection.model_validate(response)
349
+
350
+ if not (
351
+ result.project
352
+ and (proj := result.project)
353
+ and (type_ := proj.artifact_type)
354
+ and (collection := type_.artifact_collection)
392
355
  ):
393
- raise ValueError("Could not find artifact type {}".format(self._saved_type))
394
- sequence = response["project"]["artifactType"]["artifactSequence"]
356
+ raise ValueError(f"Could not find artifact type {self._saved_type}")
357
+
358
+ sequence = type_.artifact_sequence
395
359
  self._is_sequence = (
396
- sequence is not None and sequence["__typename"] == "ArtifactSequence"
397
- )
360
+ sequence is not None
361
+ ) and sequence.typename__ == "ArtifactSequence"
398
362
 
399
363
  if self._attrs is None:
400
- self._attrs = response["project"]["artifactType"]["artifactCollection"]
364
+ self._attrs = collection.model_dump(exclude_unset=True)
401
365
  return self._attrs
402
366
 
403
367
  def change_type(self, new_type: str) -> None:
404
368
  """Deprecated, change type directly with `save` instead."""
405
369
  deprecate.deprecate(
406
- field_name=deprecate.Deprecated.artifact_collection__change_type,
370
+ field_name=Deprecated.artifact_collection__change_type,
407
371
  warning_message="ArtifactCollection.change_type(type) is deprecated, use ArtifactCollection.save() instead.",
408
372
  )
409
373
 
@@ -412,32 +376,13 @@ class ArtifactCollection:
412
376
  termlog(
413
377
  f"Changing artifact collection type of {self._saved_type} to {new_type}"
414
378
  )
415
- template = """
416
- mutation MoveArtifactCollection(
417
- $artifactSequenceID: ID!
418
- $destinationArtifactTypeName: String!
419
- ) {
420
- moveArtifactSequence(
421
- input: {
422
- artifactSequenceID: $artifactSequenceID
423
- destinationArtifactTypeName: $destinationArtifactTypeName
424
- }
425
- ) {
426
- artifactCollection {
427
- id
428
- name
429
- description
430
- __typename
431
- }
432
- }
433
- }
434
- """
435
- variable_values = {
436
- "artifactSequenceID": self.id,
437
- "destinationArtifactTypeName": new_type,
438
- }
439
- mutation = gql(template)
440
- self.client.execute(mutation, variable_values=variable_values)
379
+ self.client.execute(
380
+ gql(MOVE_ARTIFACT_COLLECTION_GQL),
381
+ variable_values={
382
+ "artifactSequenceID": self.id,
383
+ "destinationArtifactTypeName": new_type,
384
+ },
385
+ )
441
386
  self._saved_type = new_type
442
387
  self._type = new_type
443
388
 
@@ -446,40 +391,19 @@ class ArtifactCollection:
446
391
  return self._is_sequence
447
392
 
448
393
  @normalize_exceptions
449
- def delete(self):
394
+ def delete(self) -> None:
450
395
  """Delete the entire artifact collection."""
451
- if self.is_sequence():
452
- mutation = gql(
453
- """
454
- mutation deleteArtifactSequence($id: ID!) {
455
- deleteArtifactSequence(input: {
456
- artifactSequenceID: $id
457
- }) {
458
- artifactCollection {
459
- state
460
- }
461
- }
462
- }
463
- """
464
- )
465
- else:
466
- mutation = gql(
467
- """
468
- mutation deleteArtifactPortfolio($id: ID!) {
469
- deleteArtifactPortfolio(input: {
470
- artifactPortfolioID: $id
471
- }) {
472
- artifactCollection {
473
- state
474
- }
475
- }
476
- }
477
- """
478
- )
479
- self.client.execute(mutation, variable_values={"id": self.id})
396
+ self.client.execute(
397
+ gql(
398
+ DELETE_ARTIFACT_SEQUENCE_GQL
399
+ if self.is_sequence()
400
+ else DELETE_ARTIFACT_PORTFOLIO_GQL
401
+ ),
402
+ variable_values={"id": self.id},
403
+ )
480
404
 
481
405
  @property
482
- def description(self):
406
+ def description(self) -> str:
483
407
  """A description of the artifact collection."""
484
408
  return self._description
485
409
 
@@ -488,7 +412,7 @@ class ArtifactCollection:
488
412
  self._description = description
489
413
 
490
414
  @property
491
- def tags(self):
415
+ def tags(self) -> List[str]:
492
416
  """The tags associated with the artifact collection."""
493
417
  return self._tags
494
418
 
@@ -501,7 +425,7 @@ class ArtifactCollection:
501
425
  self._tags = tags
502
426
 
503
427
  @property
504
- def name(self):
428
+ def name(self) -> str:
505
429
  """The name of the artifact collection."""
506
430
  return self._name
507
431
 
@@ -522,193 +446,69 @@ class ArtifactCollection:
522
446
  )
523
447
  self._type = type
524
448
 
525
- def _update_collection(self):
526
- mutation = gql("""
527
- mutation UpdateArtifactCollection(
528
- $artifactSequenceID: ID!
529
- $name: String
530
- $description: String
531
- ) {
532
- updateArtifactSequence(
533
- input: {
534
- artifactSequenceID: $artifactSequenceID
535
- name: $name
536
- description: $description
537
- }
538
- ) {
539
- artifactCollection {
540
- id
541
- name
542
- description
543
- }
544
- }
545
- }
546
- """)
547
-
548
- variable_values = {
549
- "artifactSequenceID": self.id,
550
- "name": self._name,
551
- "description": self.description,
552
- }
553
- self.client.execute(mutation, variable_values=variable_values)
449
+ def _update_collection(self) -> None:
450
+ self.client.execute(
451
+ gql(
452
+ UPDATE_ARTIFACT_SEQUENCE_GQL
453
+ if self.is_sequence()
454
+ else UPDATE_ARTIFACT_PORTFOLIO_GQL
455
+ ),
456
+ variable_values={
457
+ "id": self.id,
458
+ "name": self.name,
459
+ "description": self.description,
460
+ },
461
+ )
554
462
  self._saved_name = self._name
555
463
 
556
- def _update_collection_type(self):
557
- type_mutation = gql("""
558
- mutation MoveArtifactCollection(
559
- $artifactSequenceID: ID!
560
- $destinationArtifactTypeName: String!
561
- ) {
562
- moveArtifactSequence(
563
- input: {
564
- artifactSequenceID: $artifactSequenceID
565
- destinationArtifactTypeName: $destinationArtifactTypeName
566
- }
567
- ) {
568
- artifactCollection {
569
- id
570
- name
571
- description
572
- __typename
573
- }
574
- }
575
- }
576
- """)
577
-
578
- variable_values = {
579
- "artifactSequenceID": self.id,
580
- "destinationArtifactTypeName": self._type,
581
- }
582
- self.client.execute(type_mutation, variable_values=variable_values)
464
+ def _update_collection_type(self) -> None:
465
+ self.client.execute(
466
+ gql(MOVE_ARTIFACT_COLLECTION_GQL),
467
+ variable_values={
468
+ "artifactSequenceID": self.id,
469
+ "destinationArtifactTypeName": self.type,
470
+ },
471
+ )
583
472
  self._saved_type = self._type
584
473
 
585
- def _update_portfolio(self):
586
- mutation = gql("""
587
- mutation UpdateArtifactPortfolio(
588
- $artifactPortfolioID: ID!
589
- $name: String
590
- $description: String
591
- ) {
592
- updateArtifactPortfolio(
593
- input: {
594
- artifactPortfolioID: $artifactPortfolioID
595
- name: $name
596
- description: $description
597
- }
598
- ) {
599
- artifactCollection {
600
- id
601
- name
602
- description
603
- }
604
- }
605
- }
606
- """)
607
- variable_values = {
608
- "artifactPortfolioID": self.id,
609
- "name": self._name,
610
- "description": self.description,
611
- }
612
- self.client.execute(mutation, variable_values=variable_values)
613
- self._saved_name = self._name
614
-
615
- def _add_tags(self, tags_to_add):
616
- add_mutation = gql(
617
- """
618
- mutation CreateArtifactCollectionTagAssignments(
619
- $entityName: String!
620
- $projectName: String!
621
- $artifactCollectionName: String!
622
- $tags: [TagInput!]!
623
- ) {
624
- createArtifactCollectionTagAssignments(
625
- input: {
626
- entityName: $entityName
627
- projectName: $projectName
628
- artifactCollectionName: $artifactCollectionName
629
- tags: $tags
630
- }
631
- ) {
632
- tags {
633
- id
634
- name
635
- tagCategoryName
636
- }
637
- }
638
- }
639
- """
640
- )
474
+ def _add_tags(self, tags_to_add: Iterable[str]) -> None:
641
475
  self.client.execute(
642
- add_mutation,
476
+ gql(CREATE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL),
643
477
  variable_values={
644
478
  "entityName": self.entity,
645
479
  "projectName": self.project,
646
480
  "artifactCollectionName": self._saved_name,
647
- "tags": [
648
- {
649
- "tagName": tag,
650
- }
651
- for tag in tags_to_add
652
- ],
481
+ "tags": [{"tagName": tag} for tag in tags_to_add],
653
482
  },
654
483
  )
655
484
 
656
- def _delete_tags(self, tags_to_delete):
657
- delete_mutation = gql(
658
- """
659
- mutation DeleteArtifactCollectionTagAssignments(
660
- $entityName: String!
661
- $projectName: String!
662
- $artifactCollectionName: String!
663
- $tags: [TagInput!]!
664
- ) {
665
- deleteArtifactCollectionTagAssignments(
666
- input: {
667
- entityName: $entityName
668
- projectName: $projectName
669
- artifactCollectionName: $artifactCollectionName
670
- tags: $tags
671
- }
672
- ) {
673
- success
674
- }
675
- }
676
- """
677
- )
485
+ def _delete_tags(self, tags_to_delete: Iterable[str]) -> None:
678
486
  self.client.execute(
679
- delete_mutation,
487
+ gql(DELETE_ARTIFACT_COLLECTION_TAG_ASSIGNMENTS_GQL),
680
488
  variable_values={
681
489
  "entityName": self.entity,
682
490
  "projectName": self.project,
683
491
  "artifactCollectionName": self._saved_name,
684
- "tags": [
685
- {
686
- "tagName": tag,
687
- }
688
- for tag in tags_to_delete
689
- ],
492
+ "tags": [{"tagName": tag} for tag in tags_to_delete],
690
493
  },
691
494
  )
692
495
 
693
496
  def save(self) -> None:
694
497
  """Persist any changes made to the artifact collection."""
695
- if self.is_sequence():
696
- self._update_collection()
498
+ self._update_collection()
697
499
 
698
- if self._saved_type != self._type:
699
- self._update_collection_type()
700
- else:
701
- self._update_portfolio()
500
+ if self.is_sequence() and (self._saved_type != self._type):
501
+ self._update_collection_type()
702
502
 
703
- tags_to_add = set(self._tags) - set(self._saved_tags)
704
- tags_to_delete = set(self._saved_tags) - set(self._tags)
705
- if len(tags_to_add) > 0:
503
+ current_tags = set(self._tags)
504
+ saved_tags = set(self._saved_tags)
505
+ if tags_to_add := (current_tags - saved_tags):
706
506
  self._add_tags(tags_to_add)
707
- if len(tags_to_delete) > 0:
507
+ if tags_to_delete := (saved_tags - current_tags):
708
508
  self._delete_tags(tags_to_delete)
709
509
  self._saved_tags = copy(self._tags)
710
510
 
711
- def __repr__(self):
511
+ def __repr__(self) -> str:
712
512
  return f"<ArtifactCollection {self._name} ({self._type})>"
713
513
 
714
514
 
@@ -718,6 +518,8 @@ class Artifacts(SizedPaginator["wandb.Artifact"]):
718
518
  This is generally used indirectly via the `Api`.artifact_versions method.
719
519
  """
720
520
 
521
+ last_response: Optional[ArtifactsFragment]
522
+
721
523
  def __init__(
722
524
  self,
723
525
  client: Client,
@@ -730,8 +532,6 @@ class Artifacts(SizedPaginator["wandb.Artifact"]):
730
532
  per_page: int = 50,
731
533
  tags: Optional[Union[str, List[str]]] = None,
732
534
  ):
733
- from wandb.sdk.artifacts.artifact import _gql_artifact_fragment
734
-
735
535
  self.entity = entity
736
536
  self.collection_name = collection_name
737
537
  self.type = type
@@ -747,152 +547,107 @@ class Artifacts(SizedPaginator["wandb.Artifact"]):
747
547
  "collection": self.collection_name,
748
548
  "filters": json.dumps(self.filters),
749
549
  }
750
- self.QUERY = gql(
751
- """
752
- query Artifacts($project: String!, $entity: String!, $type: String!, $collection: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
753
- project(name: $project, entityName: $entity) {{
754
- artifactType(name: $type) {{
755
- artifactCollection: {}(name: $collection) {{
756
- name
757
- artifacts(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
758
- totalCount
759
- edges {{
760
- node {{
761
- ...ArtifactFragment
762
- }}
763
- version
764
- cursor
765
- }}
766
- pageInfo {{
767
- endCursor
768
- hasNextPage
769
- }}
770
- }}
771
- }}
772
- }}
773
- }}
774
- }}
775
- {}
776
- """.format(
777
- artifact_collection_edge_name(
778
- server_supports_artifact_collections_gql_edges(client)
779
- ),
780
- _gql_artifact_fragment(),
781
- )
550
+
551
+ if server_supports_artifact_collections_gql_edges(client):
552
+ rename_fields = None
553
+ else:
554
+ rename_fields = {"artifactCollection": "artifactSequence"}
555
+
556
+ self.QUERY = gql_compat(
557
+ PROJECT_ARTIFACTS_GQL,
558
+ omit_fields=omit_artifact_fields(api=InternalApi()),
559
+ rename_fields=rename_fields,
782
560
  )
561
+
783
562
  super().__init__(client, variables, per_page)
784
563
 
564
+ @override
565
+ def _update_response(self) -> None:
566
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
567
+ result = ProjectArtifacts.model_validate(data)
568
+
569
+ # Extract the inner `*Connection` result for faster/easier access.
570
+ if not (
571
+ (proj := result.project)
572
+ and (type_ := proj.artifact_type)
573
+ and (collection := type_.artifact_collection)
574
+ and (conn := collection.artifacts)
575
+ ):
576
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
577
+
578
+ self.last_response = ArtifactsFragment.model_validate(conn)
579
+
785
580
  @property
786
- def length(self):
787
- if self.last_response:
788
- return self.last_response["project"]["artifactType"]["artifactCollection"][
789
- "artifacts"
790
- ]["totalCount"]
791
- else:
581
+ def length(self) -> Optional[int]:
582
+ if self.last_response is None:
792
583
  return None
584
+ return self.last_response.total_count
793
585
 
794
586
  @property
795
- def more(self):
796
- if self.last_response:
797
- return self.last_response["project"]["artifactType"]["artifactCollection"][
798
- "artifacts"
799
- ]["pageInfo"]["hasNextPage"]
800
- else:
587
+ def more(self) -> bool:
588
+ if self.last_response is None:
801
589
  return True
590
+ return self.last_response.page_info.has_next_page
802
591
 
803
592
  @property
804
- def cursor(self):
805
- if self.last_response:
806
- return self.last_response["project"]["artifactType"]["artifactCollection"][
807
- "artifacts"
808
- ]["edges"][-1]["cursor"]
809
- else:
593
+ def cursor(self) -> Optional[str]:
594
+ if self.last_response is None:
810
595
  return None
596
+ return self.last_response.edges[-1].cursor
811
597
 
812
- def convert_objects(self):
813
- collection = self.last_response["project"]["artifactType"]["artifactCollection"]
814
- artifact_edges = collection.get("artifacts", {}).get("edges", [])
598
+ def convert_objects(self) -> List["wandb.Artifact"]:
599
+ artifact_edges = (edge for edge in self.last_response.edges if edge.node)
815
600
  artifacts = (
816
601
  wandb.Artifact._from_attrs(
817
- self.entity,
818
- self.project,
819
- self.collection_name + ":" + a["version"],
820
- a["node"],
821
- self.client,
602
+ entity=self.entity,
603
+ project=self.project,
604
+ name=f"{self.collection_name}:{edge.version}",
605
+ attrs=edge.node.model_dump(exclude_unset=True),
606
+ client=self.client,
822
607
  )
823
- for a in artifact_edges
608
+ for edge in artifact_edges
824
609
  )
825
610
  required_tags = set(self.tags or [])
826
- return [
827
- artifact for artifact in artifacts if required_tags.issubset(artifact.tags)
828
- ]
611
+ return [art for art in artifacts if required_tags.issubset(art.tags)]
829
612
 
830
613
 
831
614
  class RunArtifacts(SizedPaginator["wandb.Artifact"]):
832
- def __init__(self, client: Client, run: "Run", mode="logged", per_page: int = 50):
833
- from wandb.sdk.artifacts.artifact import _gql_artifact_fragment
834
-
835
- output_query = gql(
836
- """
837
- query RunOutputArtifacts(
838
- $entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
839
- ) {
840
- project(name: $project, entityName: $entity) {
841
- run(name: $runName) {
842
- outputArtifacts(after: $cursor, first: $perPage) {
843
- totalCount
844
- edges {
845
- node {
846
- ...ArtifactFragment
847
- }
848
- cursor
849
- }
850
- pageInfo {
851
- endCursor
852
- hasNextPage
853
- }
854
- }
855
- }
856
- }
857
- }
858
- """
859
- + _gql_artifact_fragment()
860
- )
861
-
862
- input_query = gql(
863
- """
864
- query RunInputArtifacts(
865
- $entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
866
- ) {
867
- project(name: $project, entityName: $entity) {
868
- run(name: $runName) {
869
- inputArtifacts(after: $cursor, first: $perPage) {
870
- totalCount
871
- edges {
872
- node {
873
- ...ArtifactFragment
874
- }
875
- cursor
876
- }
877
- pageInfo {
878
- endCursor
879
- hasNextPage
880
- }
881
- }
882
- }
883
- }
884
- }
885
- """
886
- + _gql_artifact_fragment()
887
- )
615
+ last_response: Union[
616
+ RunOutputArtifactsProjectRunOutputArtifacts,
617
+ RunInputArtifactsProjectRunInputArtifacts,
618
+ ]
619
+
620
+ #: The pydantic model used to parse the (inner part of the) raw response.
621
+ _response_cls: Type[
622
+ Union[
623
+ RunOutputArtifactsProjectRunOutputArtifacts,
624
+ RunInputArtifactsProjectRunInputArtifacts,
625
+ ]
626
+ ]
888
627
 
628
+ def __init__(
629
+ self,
630
+ client: Client,
631
+ run: "Run",
632
+ mode: Literal["logged", "used"] = "logged",
633
+ per_page: int = 50,
634
+ ):
889
635
  self.run = run
636
+
890
637
  if mode == "logged":
891
638
  self.run_key = "outputArtifacts"
892
- self.QUERY = output_query
639
+ self.QUERY = gql_compat(
640
+ RUN_OUTPUT_ARTIFACTS_GQL,
641
+ omit_fields=omit_artifact_fields(api=InternalApi()),
642
+ )
643
+ self._response_cls = RunOutputArtifactsProjectRunOutputArtifacts
893
644
  elif mode == "used":
894
645
  self.run_key = "inputArtifacts"
895
- self.QUERY = input_query
646
+ self.QUERY = gql_compat(
647
+ RUN_INPUT_ARTIFACTS_GQL,
648
+ omit_fields=omit_artifact_fields(api=InternalApi()),
649
+ )
650
+ self._response_cls = RunInputArtifactsProjectRunInputArtifacts
896
651
  else:
897
652
  raise ValueError("mode must be logged or used")
898
653
 
@@ -901,99 +656,50 @@ class RunArtifacts(SizedPaginator["wandb.Artifact"]):
901
656
  "project": run.project,
902
657
  "runName": run.id,
903
658
  }
904
-
905
659
  super().__init__(client, variable_values, per_page)
906
660
 
661
+ @override
662
+ def _update_response(self) -> None:
663
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
664
+
665
+ # Extract the inner `*Connection` result for faster/easier access.
666
+ inner_data = data["project"]["run"][self.run_key]
667
+ self.last_response = self._response_cls.model_validate(inner_data)
668
+
907
669
  @property
908
- def length(self):
909
- if self.last_response:
910
- return self.last_response["project"]["run"][self.run_key]["totalCount"]
911
- else:
670
+ def length(self) -> Optional[int]:
671
+ if self.last_response is None:
912
672
  return None
673
+ return self.last_response.total_count
913
674
 
914
675
  @property
915
- def more(self):
916
- if self.last_response:
917
- return self.last_response["project"]["run"][self.run_key]["pageInfo"][
918
- "hasNextPage"
919
- ]
920
- else:
676
+ def more(self) -> bool:
677
+ if self.last_response is None:
921
678
  return True
679
+ return self.last_response.page_info.has_next_page
922
680
 
923
681
  @property
924
- def cursor(self):
925
- if self.last_response:
926
- return self.last_response["project"]["run"][self.run_key]["edges"][-1][
927
- "cursor"
928
- ]
929
- else:
682
+ def cursor(self) -> Optional[str]:
683
+ if self.last_response is None:
930
684
  return None
685
+ return self.last_response.edges[-1].cursor
931
686
 
932
- def convert_objects(self):
687
+ def convert_objects(self) -> List["wandb.Artifact"]:
933
688
  return [
934
689
  wandb.Artifact._from_attrs(
935
- r["node"]["artifactSequence"]["project"]["entityName"],
936
- r["node"]["artifactSequence"]["project"]["name"],
937
- "{}:v{}".format(
938
- r["node"]["artifactSequence"]["name"], r["node"]["versionIndex"]
939
- ),
940
- r["node"],
941
- self.client,
690
+ entity=node.artifact_sequence.project.entity_name,
691
+ project=node.artifact_sequence.project.name,
692
+ name=f"{node.artifact_sequence.name}:v{node.version_index}",
693
+ attrs=node.model_dump(exclude_unset=True),
694
+ client=self.client,
942
695
  )
943
- for r in self.last_response["project"]["run"][self.run_key]["edges"]
696
+ for r in self.last_response.edges
697
+ if (node := r.node)
944
698
  ]
945
699
 
946
700
 
947
701
  class ArtifactFiles(SizedPaginator["public.File"]):
948
- ARTIFACT_VERSION_FILES_QUERY = gql(
949
- f"""
950
- query ArtifactFiles(
951
- $entityName: String!,
952
- $projectName: String!,
953
- $artifactTypeName: String!,
954
- $artifactName: String!
955
- $fileNames: [String!],
956
- $fileCursor: String,
957
- $fileLimit: Int = 50
958
- ) {{
959
- project(name: $projectName, entityName: $entityName) {{
960
- artifactType(name: $artifactTypeName) {{
961
- artifact(name: $artifactName) {{
962
- files(names: $fileNames, after: $fileCursor, first: $fileLimit) {{
963
- ...FilesFragment
964
- }}
965
- }}
966
- }}
967
- }}
968
- }}
969
- {ARTIFACT_FILES_FRAGMENT}
970
- """
971
- )
972
-
973
- ARTIFACT_COLLECTION_MEMBERSHIP_FILES_QUERY = gql(
974
- f"""
975
- query ArtifactCollectionMembershipFiles(
976
- $entityName: String!,
977
- $projectName: String!,
978
- $artifactName: String!,
979
- $artifactVersionIndex: String!,
980
- $fileNames: [String!],
981
- $fileCursor: String,
982
- $fileLimit: Int = 50
983
- ) {{
984
- project(name: $projectName, entityName: $entityName) {{
985
- artifactCollection(name: $artifactName) {{
986
- artifactMembership (aliasName: $artifactVersionIndex) {{
987
- files(names: $fileNames, after: $fileCursor, first: $fileLimit) {{
988
- ...FilesFragment
989
- }}
990
- }}
991
- }}
992
- }}
993
- }}
994
- {ARTIFACT_FILES_FRAGMENT}
995
- """
996
- )
702
+ last_response: Optional[FilesFragment]
997
703
 
998
704
  def __init__(
999
705
  self,
@@ -1006,15 +712,9 @@ class ArtifactFiles(SizedPaginator["public.File"]):
1006
712
  ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
1007
713
  )
1008
714
  self.artifact = artifact
1009
- variables = {
1010
- "entityName": artifact.source_entity,
1011
- "projectName": artifact.source_project,
1012
- "artifactTypeName": artifact.type,
1013
- "artifactName": artifact.source_name,
1014
- "fileNames": names,
1015
- }
715
+
1016
716
  if self.query_via_membership:
1017
- self.QUERY = self.ARTIFACT_COLLECTION_MEMBERSHIP_FILES_QUERY
717
+ query_str = ARTIFACT_COLLECTION_MEMBERSHIP_FILES_GQL
1018
718
  variables = {
1019
719
  "entityName": artifact.entity,
1020
720
  "projectName": artifact.project,
@@ -1023,67 +723,77 @@ class ArtifactFiles(SizedPaginator["public.File"]):
1023
723
  "fileNames": names,
1024
724
  }
1025
725
  else:
1026
- self.QUERY = self.ARTIFACT_VERSION_FILES_QUERY
726
+ query_str = ARTIFACT_VERSION_FILES_GQL
727
+ variables = {
728
+ "entityName": artifact.source_entity,
729
+ "projectName": artifact.source_project,
730
+ "artifactName": artifact.source_name,
731
+ "artifactTypeName": artifact.type,
732
+ "fileNames": names,
733
+ }
734
+
1027
735
  # The server must advertise at least SDK 0.12.21
1028
736
  # to get storagePath
1029
737
  if not client.version_supported("0.12.21"):
1030
- self.QUERY = gql(self.QUERY.loc.source.body.replace("storagePath\n", ""))
738
+ self.QUERY = gql_compat(query_str, omit_fields={"storagePath"})
739
+ else:
740
+ self.QUERY = gql(query_str)
741
+
1031
742
  super().__init__(client, variables, per_page)
1032
743
 
744
+ @override
745
+ def _update_response(self) -> None:
746
+ data = self.client.execute(self.QUERY, variable_values=self.variables)
747
+
748
+ # Extract the inner `*Connection` result for faster/easier access.
749
+ if self.query_via_membership:
750
+ result = ArtifactCollectionMembershipFiles.model_validate(data)
751
+ conn = result.project.artifact_collection.artifact_membership.files
752
+ else:
753
+ result = ArtifactVersionFiles.model_validate(data)
754
+ conn = result.project.artifact_type.artifact.files
755
+
756
+ if conn is None:
757
+ raise ValueError(f"Unable to parse {type(self).__name__!r} response data")
758
+
759
+ self.last_response = FilesFragment.model_validate(conn)
760
+
1033
761
  @property
1034
- def path(self):
762
+ def path(self) -> List[str]:
1035
763
  return [self.artifact.entity, self.artifact.project, self.artifact.name]
1036
764
 
1037
765
  @property
1038
- def length(self):
766
+ def length(self) -> int:
1039
767
  return self.artifact.file_count
1040
768
 
1041
769
  @property
1042
- def more(self):
1043
- if self.last_response:
1044
- if self.query_via_membership:
1045
- return self.last_response["project"]["artifactCollection"][
1046
- "artifactMembership"
1047
- ]["files"]["pageInfo"]["hasNextPage"]
1048
- return self.last_response["project"]["artifactType"]["artifact"]["files"][
1049
- "pageInfo"
1050
- ]["hasNextPage"]
1051
- else:
770
+ def more(self) -> bool:
771
+ if self.last_response is None:
1052
772
  return True
773
+ return self.last_response.page_info.has_next_page
1053
774
 
1054
775
  @property
1055
- def cursor(self):
1056
- if self.last_response:
1057
- if self.query_via_membership:
1058
- return self.last_response["project"]["artifactCollection"][
1059
- "artifactMembership"
1060
- ]["files"]["edges"][-1]["cursor"]
1061
- return self.last_response["project"]["artifactType"]["artifact"]["files"][
1062
- "edges"
1063
- ][-1]["cursor"]
1064
- else:
776
+ def cursor(self) -> Optional[str]:
777
+ if self.last_response is None:
1065
778
  return None
779
+ return self.last_response.edges[-1].cursor
1066
780
 
1067
- def update_variables(self):
781
+ def update_variables(self) -> None:
1068
782
  self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor})
1069
783
 
1070
- def convert_objects(self):
1071
- if self.query_via_membership:
1072
- return [
1073
- public.File(self.client, r["node"])
1074
- for r in self.last_response["project"]["artifactCollection"][
1075
- "artifactMembership"
1076
- ]["files"]["edges"]
1077
- ]
784
+ def convert_objects(self) -> List["public.File"]:
1078
785
  return [
1079
- public.File(self.client, r["node"])
1080
- for r in self.last_response["project"]["artifactType"]["artifact"]["files"][
1081
- "edges"
1082
- ]
786
+ public.File(
787
+ client=self.client,
788
+ attrs=node.model_dump(exclude_unset=True),
789
+ )
790
+ for edge in self.last_response.edges
791
+ if (node := edge.node)
1083
792
  ]
1084
793
 
1085
- def __repr__(self):
1086
- return "<ArtifactFiles {} ({})>".format("/".join(self.path), len(self))
794
+ def __repr__(self) -> str:
795
+ path_str = "/".join(self.path)
796
+ return f"<ArtifactFiles {path_str} ({len(self)})>"
1087
797
 
1088
798
 
1089
799
  def server_supports_artifact_collections_gql_edges(
@@ -1099,21 +809,3 @@ def server_supports_artifact_collections_gql_edges(
1099
809
  "W&B Local Server version does not support ArtifactCollection gql edges; falling back to using legacy ArtifactSequence. Please update server to at least version 0.9.50."
1100
810
  )
1101
811
  return supported
1102
-
1103
-
1104
- def artifact_collection_edge_name(server_supports_artifact_collections: bool) -> str:
1105
- return (
1106
- "artifactCollection"
1107
- if server_supports_artifact_collections
1108
- else "artifactSequence"
1109
- )
1110
-
1111
-
1112
- def artifact_collection_plural_edge_name(
1113
- server_supports_artifact_collections: bool,
1114
- ) -> str:
1115
- return (
1116
- "artifactCollections"
1117
- if server_supports_artifact_collections
1118
- else "artifactSequences"
1119
- )