wandb 0.19.9__py3-none-win32.whl → 0.19.11__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +6 -3
  3. wandb/_pydantic/__init__.py +14 -8
  4. wandb/_pydantic/base.py +51 -36
  5. wandb/_pydantic/utils.py +73 -0
  6. wandb/_pydantic/v1_compat.py +79 -57
  7. wandb/apis/public/__init__.py +2 -2
  8. wandb/apis/public/api.py +684 -4
  9. wandb/apis/public/artifacts.py +377 -677
  10. wandb/apis/public/automations.py +69 -0
  11. wandb/apis/public/integrations.py +180 -0
  12. wandb/apis/public/projects.py +29 -0
  13. wandb/apis/public/registries/__init__.py +0 -0
  14. wandb/apis/public/registries/_freezable_list.py +179 -0
  15. wandb/apis/public/{registries.py → registries/registries_search.py} +22 -129
  16. wandb/apis/public/registries/registry.py +357 -0
  17. wandb/apis/public/registries/utils.py +140 -0
  18. wandb/apis/public/runs.py +58 -56
  19. wandb/apis/public/utils.py +107 -1
  20. wandb/automations/__init__.py +73 -0
  21. wandb/automations/_filters/__init__.py +40 -0
  22. wandb/automations/_filters/expressions.py +181 -0
  23. wandb/automations/_filters/operators.py +258 -0
  24. wandb/automations/_filters/run_metrics.py +332 -0
  25. wandb/automations/_generated/__init__.py +177 -0
  26. wandb/automations/_generated/create_automation.py +17 -0
  27. wandb/automations/_generated/create_generic_webhook_integration.py +43 -0
  28. wandb/automations/_generated/delete_automation.py +17 -0
  29. wandb/automations/_generated/enums.py +33 -0
  30. wandb/automations/_generated/fragments.py +358 -0
  31. wandb/automations/_generated/generic_webhook_integrations_by_entity.py +22 -0
  32. wandb/automations/_generated/get_automations.py +24 -0
  33. wandb/automations/_generated/get_automations_by_entity.py +26 -0
  34. wandb/automations/_generated/input_types.py +104 -0
  35. wandb/automations/_generated/integrations_by_entity.py +22 -0
  36. wandb/automations/_generated/operations.py +647 -0
  37. wandb/automations/_generated/slack_integrations_by_entity.py +22 -0
  38. wandb/automations/_generated/update_automation.py +17 -0
  39. wandb/automations/_utils.py +237 -0
  40. wandb/automations/_validators.py +165 -0
  41. wandb/automations/actions.py +220 -0
  42. wandb/automations/automations.py +87 -0
  43. wandb/automations/events.py +287 -0
  44. wandb/automations/integrations.py +45 -0
  45. wandb/automations/scopes.py +78 -0
  46. wandb/beta/workflows.py +9 -10
  47. wandb/bin/gpu_stats.exe +0 -0
  48. wandb/bin/wandb-core +0 -0
  49. wandb/cli/cli.py +3 -3
  50. wandb/env.py +11 -0
  51. wandb/integration/keras/keras.py +2 -1
  52. wandb/integration/langchain/wandb_tracer.py +2 -1
  53. wandb/jupyter.py +137 -118
  54. wandb/old/settings.py +4 -1
  55. wandb/old/summary.py +0 -2
  56. wandb/proto/v3/wandb_internal_pb2.py +297 -292
  57. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  58. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  59. wandb/proto/v4/wandb_internal_pb2.py +292 -292
  60. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  61. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  62. wandb/proto/v5/wandb_internal_pb2.py +292 -292
  63. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  64. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  65. wandb/proto/v6/wandb_base_pb2.py +41 -0
  66. wandb/proto/v6/wandb_internal_pb2.py +393 -0
  67. wandb/proto/v6/wandb_server_pb2.py +78 -0
  68. wandb/proto/v6/wandb_settings_pb2.py +58 -0
  69. wandb/proto/v6/wandb_telemetry_pb2.py +52 -0
  70. wandb/proto/wandb_base_pb2.py +2 -0
  71. wandb/proto/wandb_deprecated.py +8 -0
  72. wandb/proto/wandb_internal_pb2.py +3 -1
  73. wandb/proto/wandb_server_pb2.py +2 -0
  74. wandb/proto/wandb_settings_pb2.py +2 -0
  75. wandb/proto/wandb_telemetry_pb2.py +2 -0
  76. wandb/sdk/artifacts/_generated/__init__.py +289 -0
  77. wandb/sdk/artifacts/_generated/add_aliases.py +21 -0
  78. wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +43 -0
  79. wandb/sdk/artifacts/_generated/artifact_version_files.py +36 -0
  80. wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +36 -0
  81. wandb/sdk/artifacts/_generated/delete_aliases.py +21 -0
  82. wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +25 -0
  83. wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +35 -0
  84. wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +35 -0
  85. wandb/sdk/artifacts/_generated/enums.py +17 -0
  86. wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +67 -0
  87. wandb/sdk/artifacts/_generated/fragments.py +221 -0
  88. wandb/sdk/artifacts/_generated/input_types.py +28 -0
  89. wandb/sdk/artifacts/_generated/move_artifact_collection.py +35 -0
  90. wandb/sdk/artifacts/_generated/operations.py +611 -0
  91. wandb/sdk/artifacts/_generated/project_artifact_collection.py +101 -0
  92. wandb/sdk/artifacts/_generated/project_artifact_collections.py +33 -0
  93. wandb/sdk/artifacts/_generated/project_artifact_type.py +24 -0
  94. wandb/sdk/artifacts/_generated/project_artifact_types.py +24 -0
  95. wandb/sdk/artifacts/_generated/project_artifacts.py +42 -0
  96. wandb/sdk/artifacts/_generated/run_input_artifacts.py +51 -0
  97. wandb/sdk/artifacts/_generated/run_output_artifacts.py +51 -0
  98. wandb/sdk/artifacts/_generated/update_artifact.py +26 -0
  99. wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +35 -0
  100. wandb/sdk/artifacts/_generated/update_artifact_sequence.py +35 -0
  101. wandb/sdk/artifacts/_graphql_fragments.py +57 -79
  102. wandb/sdk/artifacts/_validators.py +120 -1
  103. wandb/sdk/artifacts/artifact.py +419 -215
  104. wandb/sdk/artifacts/artifact_file_cache.py +4 -6
  105. wandb/sdk/artifacts/artifact_manifest_entry.py +13 -3
  106. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  107. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +182 -1
  108. wandb/sdk/artifacts/storage_policy.py +3 -0
  109. wandb/sdk/data_types/base_types/media.py +2 -3
  110. wandb/sdk/data_types/base_types/wb_value.py +34 -11
  111. wandb/sdk/data_types/html.py +36 -9
  112. wandb/sdk/data_types/image.py +12 -12
  113. wandb/sdk/data_types/table.py +5 -0
  114. wandb/sdk/data_types/trace_tree.py +2 -0
  115. wandb/sdk/data_types/utils.py +1 -1
  116. wandb/sdk/data_types/video.py +59 -57
  117. wandb/sdk/interface/interface.py +4 -3
  118. wandb/sdk/internal/internal_api.py +21 -31
  119. wandb/sdk/internal/profiler.py +6 -5
  120. wandb/sdk/internal/run.py +13 -6
  121. wandb/sdk/internal/sender.py +5 -2
  122. wandb/sdk/launch/sweeps/utils.py +8 -0
  123. wandb/sdk/lib/apikey.py +25 -4
  124. wandb/sdk/lib/asyncio_compat.py +1 -1
  125. wandb/sdk/lib/deprecate.py +13 -22
  126. wandb/sdk/lib/disabled.py +2 -1
  127. wandb/sdk/lib/printer.py +37 -8
  128. wandb/sdk/lib/printer_asyncio.py +46 -0
  129. wandb/sdk/lib/redirect.py +10 -5
  130. wandb/sdk/projects/_generated/__init__.py +47 -0
  131. wandb/sdk/projects/_generated/delete_project.py +22 -0
  132. wandb/sdk/projects/_generated/enums.py +4 -0
  133. wandb/sdk/projects/_generated/fetch_registry.py +22 -0
  134. wandb/sdk/projects/_generated/fragments.py +41 -0
  135. wandb/sdk/projects/_generated/input_types.py +13 -0
  136. wandb/sdk/projects/_generated/operations.py +88 -0
  137. wandb/sdk/projects/_generated/rename_project.py +27 -0
  138. wandb/sdk/projects/_generated/upsert_registry_project.py +27 -0
  139. wandb/sdk/service/server_sock.py +19 -14
  140. wandb/sdk/service/service.py +18 -8
  141. wandb/sdk/service/streams.py +5 -0
  142. wandb/sdk/verify/verify.py +6 -3
  143. wandb/sdk/wandb_init.py +217 -70
  144. wandb/sdk/wandb_login.py +13 -4
  145. wandb/sdk/wandb_run.py +419 -295
  146. wandb/sdk/wandb_settings.py +27 -10
  147. wandb/sdk/wandb_setup.py +61 -0
  148. wandb/util.py +33 -29
  149. {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/METADATA +5 -5
  150. {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/RECORD +153 -83
  151. wandb/_globals.py +0 -19
  152. wandb/sdk/internal/_generated/base.py +0 -226
  153. wandb/sdk/internal/_generated/typing_compat.py +0 -14
  154. {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/WHEEL +0 -0
  155. {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/entry_points.txt +0 -0
  156. {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,357 @@
1
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
2
+
3
+ from wandb_gql import gql
4
+
5
+ import wandb
6
+ from wandb.apis.public.registries._freezable_list import AddOnlyArtifactTypesList
7
+ from wandb.apis.public.registries.registries_search import Collections, Versions
8
+ from wandb.apis.public.registries.utils import (
9
+ _fetch_org_entity_from_organization,
10
+ _format_gql_artifact_types_input,
11
+ _gql_to_registry_visibility,
12
+ _registry_visibility_to_gql,
13
+ )
14
+ from wandb.proto.wandb_internal_pb2 import ServerFeature
15
+ from wandb.sdk.artifacts._validators import REGISTRY_PREFIX, validate_project_name
16
+ from wandb.sdk.internal.internal_api import Api as InternalApi
17
+ from wandb.sdk.projects._generated.delete_project import DeleteProject
18
+ from wandb.sdk.projects._generated.operations import (
19
+ DELETE_PROJECT_GQL,
20
+ FETCH_REGISTRY_GQL,
21
+ RENAME_PROJECT_GQL,
22
+ UPSERT_REGISTRY_PROJECT_GQL,
23
+ )
24
+ from wandb.sdk.projects._generated.rename_project import RenameProject
25
+ from wandb.sdk.projects._generated.upsert_registry_project import UpsertRegistryProject
26
+
27
+ if TYPE_CHECKING:
28
+ from wandb_gql import Client
29
+
30
+
31
+ class Registry:
32
+ """A single registry in the Registry."""
33
+
34
+ def __init__(
35
+ self,
36
+ client: "Client",
37
+ organization: str,
38
+ entity: str,
39
+ name: str,
40
+ attrs: Optional[Dict[str, Any]] = None,
41
+ ):
42
+ self.client = client
43
+ self._name = name
44
+ self._saved_name = name
45
+ self._entity = entity
46
+ self._organization = organization
47
+ if attrs is not None:
48
+ self._update_attributes(attrs)
49
+
50
+ def _update_attributes(self, attrs: Dict[str, Any]) -> None:
51
+ """Helper method to update instance attributes from a dictionary."""
52
+ self._id = attrs.get("id", "")
53
+ if self._id is None:
54
+ raise ValueError(f"Registry {self.name}'s id is not found")
55
+
56
+ self._description = attrs.get("description", "")
57
+ self._allow_all_artifact_types = attrs.get(
58
+ "allowAllArtifactTypesInRegistry", False
59
+ )
60
+ self._artifact_types = AddOnlyArtifactTypesList(
61
+ t["node"]["name"] for t in attrs.get("artifactTypes", {}).get("edges", [])
62
+ )
63
+ self._created_at = attrs.get("createdAt", "")
64
+ self._updated_at = attrs.get("updatedAt", "")
65
+ self._visibility = _gql_to_registry_visibility(attrs.get("access", ""))
66
+
67
+ @property
68
+ def full_name(self) -> str:
69
+ """Full name of the registry including the `wandb-registry-` prefix."""
70
+ return f"wandb-registry-{self.name}"
71
+
72
+ @property
73
+ def name(self) -> str:
74
+ """Name of the registry without the `wandb-registry-` prefix."""
75
+ return self._name
76
+
77
+ @name.setter
78
+ def name(self, value: str):
79
+ self._name = value
80
+
81
+ @property
82
+ def entity(self) -> str:
83
+ """Organization entity of the registry."""
84
+ return self._entity
85
+
86
+ @property
87
+ def organization(self) -> str:
88
+ """Organization name of the registry."""
89
+ return self._organization
90
+
91
+ @property
92
+ def description(self) -> str:
93
+ """Description of the registry."""
94
+ return self._description
95
+
96
+ @description.setter
97
+ def description(self, value: str):
98
+ """Set the description of the registry."""
99
+ self._description = value
100
+
101
+ @property
102
+ def allow_all_artifact_types(self):
103
+ """Returns whether all artifact types are allowed in the registry.
104
+
105
+ If `True` then artifacts of any type can be added to this registry.
106
+ If `False` then artifacts are restricted to the types in `artifact_types` for this registry.
107
+ """
108
+ return self._allow_all_artifact_types
109
+
110
+ @allow_all_artifact_types.setter
111
+ def allow_all_artifact_types(self, value: bool):
112
+ """Set whether all artifact types are allowed in the registry."""
113
+ self._allow_all_artifact_types = value
114
+
115
+ @property
116
+ def artifact_types(self) -> AddOnlyArtifactTypesList:
117
+ """Returns the artifact types allowed in the registry.
118
+
119
+ If `allow_all_artifact_types` is `True` then `artifact_types` reflects the
120
+ types previously saved or currently used in the registry.
121
+ If `allow_all_artifact_types` is `False` then artifacts are restricted to the
122
+ types in `artifact_types`.
123
+
124
+ Note:
125
+ Previously saved artifact types cannot be removed.
126
+
127
+ Example:
128
+ ```python
129
+ registry.artifact_types.append("model")
130
+ registry.save() # once saved, the artifact type `model` cannot be removed
131
+ registry.artifact_types.append("accidentally_added")
132
+ registry.artifact_types.remove(
133
+ "accidentally_added"
134
+ ) # Types can only be removed if it has not been saved yet
135
+ ```
136
+ """
137
+ return self._artifact_types
138
+
139
+ @property
140
+ def created_at(self) -> str:
141
+ """Timestamp of when the registry was created."""
142
+ return self._created_at
143
+
144
+ @property
145
+ def updated_at(self) -> str:
146
+ """Timestamp of when the registry was last updated."""
147
+ return self._updated_at
148
+
149
+ @property
150
+ def path(self):
151
+ return [self.entity, self.full_name]
152
+
153
+ @property
154
+ def visibility(self) -> Literal["organization", "restricted"]:
155
+ """Visibility of the registry.
156
+
157
+ Returns:
158
+ Literal["organization", "restricted"]: The visibility level.
159
+ - "organization": Anyone in the organization can view this registry.
160
+ You can edit their roles later from the settings in the UI.
161
+ - "restricted": Only invited members via the UI can access this registry.
162
+ Public sharing is disabled.
163
+ """
164
+ return self._visibility
165
+
166
+ @visibility.setter
167
+ def visibility(self, value: Literal["organization", "restricted"]):
168
+ """Set the visibility of the registry.
169
+
170
+ Args:
171
+ value: The visibility level. Options are:
172
+ - "organization": Anyone in the organization can view this registry.
173
+ You can edit their roles later from the settings in the UI.
174
+ - "restricted": Only invited members via the UI can access this registry.
175
+ Public sharing is disabled.
176
+ """
177
+ self._visibility = value
178
+
179
+ def collections(self, filter: Optional[Dict[str, Any]] = None) -> Collections:
180
+ """Returns the collections belonging to the registry."""
181
+ registry_filter = {
182
+ "name": self.full_name,
183
+ }
184
+ return Collections(self.client, self.organization, registry_filter, filter)
185
+
186
+ def versions(self, filter: Optional[Dict[str, Any]] = None) -> Versions:
187
+ """Returns the versions belonging to the registry."""
188
+ registry_filter = {
189
+ "name": self.full_name,
190
+ }
191
+ return Versions(self.client, self.organization, registry_filter, None, filter)
192
+
193
+ @classmethod
194
+ def create(
195
+ cls,
196
+ client: "Client",
197
+ organization: str,
198
+ name: str,
199
+ visibility: Literal["organization", "restricted"],
200
+ description: Optional[str] = None,
201
+ artifact_types: Optional[List[str]] = None,
202
+ ):
203
+ """Create a new registry.
204
+
205
+ The registry name must be unique within the organization.
206
+ This function should be called using `api.create_registry()`
207
+
208
+ Args:
209
+ client: The GraphQL client.
210
+ organization: The name of the organization.
211
+ name: The name of the registry (without the `wandb-registry-` prefix).
212
+ visibility: The visibility level ('organization' or 'restricted').
213
+ description: An optional description for the registry.
214
+ artifact_types: An optional list of allowed artifact types.
215
+
216
+ Returns:
217
+ Registry: The newly created Registry object.
218
+
219
+ Raises:
220
+ ValueError: If a registry with the same name already exists in the
221
+ organization or if the creation fails.
222
+ """
223
+ org_entity = _fetch_org_entity_from_organization(client, organization)
224
+ full_name = REGISTRY_PREFIX + name
225
+ validate_project_name(full_name)
226
+ accepted_artifact_types = []
227
+ if artifact_types:
228
+ accepted_artifact_types = _format_gql_artifact_types_input(artifact_types)
229
+ visibility_value = _registry_visibility_to_gql(visibility)
230
+ registry_creation_error = (
231
+ f"Failed to create registry {name!r} in organization {organization!r}."
232
+ )
233
+ try:
234
+ response = client.execute(
235
+ gql(UPSERT_REGISTRY_PROJECT_GQL),
236
+ variable_values={
237
+ "description": description,
238
+ "entityName": org_entity,
239
+ "name": full_name,
240
+ "access": visibility_value,
241
+ "allowAllArtifactTypesInRegistry": not accepted_artifact_types,
242
+ "artifactTypes": accepted_artifact_types,
243
+ },
244
+ )
245
+ except Exception:
246
+ raise ValueError(registry_creation_error)
247
+ if not response["upsertModel"]["inserted"]:
248
+ raise ValueError(registry_creation_error)
249
+
250
+ return Registry(
251
+ client,
252
+ organization,
253
+ org_entity,
254
+ name,
255
+ response["upsertModel"]["project"],
256
+ )
257
+
258
+ def delete(self) -> None:
259
+ """Delete the registry. This is irreversible."""
260
+ try:
261
+ response = self.client.execute(
262
+ gql(DELETE_PROJECT_GQL), variable_values={"id": self._id}
263
+ )
264
+ result = DeleteProject.model_validate(response)
265
+ except Exception:
266
+ raise ValueError(
267
+ f"Failed to delete registry: {self.name!r} in organization: {self.organization!r}"
268
+ )
269
+ if not result.delete_model.success:
270
+ raise ValueError(
271
+ f"Failed to delete registry: {self.name!r} in organization: {self.organization!r}"
272
+ )
273
+
274
+ def load(self) -> None:
275
+ """Load the registry attributes from the backend to reflect the latest saved state."""
276
+ load_failure_message = (
277
+ f"Failed to load registry {self.name!r} "
278
+ f"in organization {self.organization!r}."
279
+ )
280
+ try:
281
+ response = self.client.execute(
282
+ gql(FETCH_REGISTRY_GQL),
283
+ variable_values={
284
+ "name": self.full_name,
285
+ "entityName": self.entity,
286
+ },
287
+ )
288
+ except Exception:
289
+ raise ValueError(load_failure_message)
290
+ if response["entity"] is None:
291
+ raise ValueError(load_failure_message)
292
+ self.attrs = response["entity"]["project"]
293
+ if self.attrs is None:
294
+ raise ValueError(load_failure_message)
295
+ self._update_attributes(self.attrs)
296
+
297
+ def save(self) -> None:
298
+ """Save registry attributes to the backend."""
299
+ if not InternalApi()._check_server_feature_with_fallback(
300
+ ServerFeature.INCLUDE_ARTIFACT_TYPES_IN_REGISTRY_CREATION
301
+ ):
302
+ raise RuntimeError(
303
+ "saving the registry is not enabled on this wandb server version. "
304
+ "Please upgrade your server version or contact support at support@wandb.com."
305
+ )
306
+
307
+ if self._no_updating_registry_types():
308
+ raise ValueError(
309
+ f"Cannot update artifact types when `allows_all_artifact_types` is {True!r}. Set it to {False!r} first."
310
+ )
311
+
312
+ validate_project_name(self.full_name)
313
+ visibility_value = _registry_visibility_to_gql(self.visibility)
314
+ newly_added_types = _format_gql_artifact_types_input(self.artifact_types.draft)
315
+ registry_save_error = f"Failed to save and update registry: {self.name} in organization: {self.organization}"
316
+ full_saved_name = f"{REGISTRY_PREFIX}{self._saved_name}"
317
+ try:
318
+ response = self.client.execute(
319
+ gql(UPSERT_REGISTRY_PROJECT_GQL),
320
+ variable_values={
321
+ "description": self.description,
322
+ "entityName": self.entity,
323
+ "name": full_saved_name, # this makes it so we are updating the original registry in case the name has changed
324
+ "access": visibility_value,
325
+ "allowAllArtifactTypesInRegistry": self.allow_all_artifact_types,
326
+ "artifactTypes": newly_added_types,
327
+ },
328
+ )
329
+ result = UpsertRegistryProject.model_validate(response)
330
+ except Exception:
331
+ raise ValueError(registry_save_error)
332
+ if result.upsert_model.inserted:
333
+ # This is not suppose trigger unless the user has messed with the `_saved_name` variable
334
+ wandb.termlog(
335
+ f"Created registry {self.name!r} in organization {self.organization!r} on save"
336
+ )
337
+ self._update_attributes(response["upsertModel"]["project"])
338
+
339
+ # Update the name of the registry if it has changed
340
+ if self._saved_name != self.name:
341
+ response = self.client.execute(
342
+ gql(RENAME_PROJECT_GQL),
343
+ variable_values={
344
+ "entityName": self.entity,
345
+ "oldProjectName": full_saved_name,
346
+ "newProjectName": self.full_name,
347
+ },
348
+ )
349
+ result = RenameProject.model_validate(response)
350
+ self._saved_name = self.name
351
+ if result.rename_project.inserted:
352
+ # This is not suppose trigger unless the user has messed with the `_saved_name` variable
353
+ wandb.termlog(f"Created new registry {self.name!r} on save")
354
+
355
+ def _no_updating_registry_types(self) -> bool:
356
+ # artifact types draft means user assigned types to add that are not yet saved
357
+ return len(self.artifact_types.draft) > 0 and self.allow_all_artifact_types
@@ -0,0 +1,140 @@
1
+ from enum import Enum
2
+ from functools import lru_cache
3
+ from typing import TYPE_CHECKING, Any, List, Literal, Mapping, Optional, Sequence
4
+
5
+ from wandb.sdk.artifacts._validators import (
6
+ REGISTRY_PREFIX,
7
+ validate_artifact_types_list,
8
+ )
9
+
10
+ if TYPE_CHECKING:
11
+ from wandb_gql import Client
12
+
13
+ from wandb_gql import gql
14
+
15
+
16
class _Visibility(str, Enum):
    """Mapping between user-facing registry visibility names and backend values.

    Member *names* are what users pass to the Python API; member *values* are
    the strings the backend GraphQL API expects.
    """

    organization = "PRIVATE"
    restricted = "RESTRICTED"

    @classmethod
    def _missing_(cls, value: object) -> Any:
        # Value lookup failed: fall back to matching on the member name so
        # `_Visibility("organization")` resolves too. Returning None makes
        # Enum raise the usual ValueError.
        for member in cls:
            if member.name == value:
                return member
        return None
28
+
29
+
30
def _format_gql_artifact_types_input(
    artifact_types: Optional[List[str]] = None,
):
    """Build the GraphQL `artifactTypes` input from artifact type names.

    Args:
        artifact_types: The artifact types to add to the registry. `None`
            yields an empty list.

    Returns:
        A list of `{"name": <type>}` dicts, one per type returned by
        `validate_artifact_types_list`, in that order. Any validation error
        raised by that helper propagates unchanged.
    """
    if artifact_types is None:
        return []
    validated_types = validate_artifact_types_list(artifact_types)
    # Use a descriptive name rather than shadowing the builtin `type`.
    return [{"name": type_name} for type_name in validated_types]
45
+
46
+
47
def _gql_to_registry_visibility(
    visibility: str,
) -> Literal["organization", "restricted"]:
    """Translate a backend `access` value into the user-facing visibility name.

    Args:
        visibility: The visibility string reported by the GraphQL API
            (e.g. "PRIVATE"). A member name is also accepted via
            `_Visibility._missing_`.

    Returns:
        The registry visibility name ("organization" or "restricted").

    Raises:
        ValueError: If the value matches no known visibility.
    """
    try:
        member = _Visibility(visibility)
    except ValueError:
        raise ValueError(f"Invalid visibility: {visibility!r} from backend")
    return member.name
62
+
63
+
64
def _registry_visibility_to_gql(
    visibility: Literal["organization", "restricted"],
) -> str:
    """Translate a user-facing visibility name into the backend GraphQL value.

    Raises:
        ValueError: If `visibility` is not a known visibility name.
    """
    try:
        member = _Visibility[visibility]
    except KeyError:
        valid_names = ", ".join(repr(option.name) for option in _Visibility)
        raise ValueError(
            f"Invalid visibility: {visibility!r}. Must be one of: {valid_names}"
        )
    return member.value
75
+
76
+
77
def _ensure_registry_prefix_on_names(query, in_name=False):
    """Recursively prefix registry names in a filter with `REGISTRY_PREFIX`.

    Strings found beneath a "name" key are given the prefix unless they
    already start with it. Values of a "$regex" key are left untouched.
    Mappings and non-string sequences are walked recursively; sequences come
    back as lists. Any other value is returned unchanged.

    - in_name: True when the current value sits beneath a "name" key.

    EX: {"name": "model"} -> {"name": "wandb-registry-model"}
    """
    if isinstance(query, str):
        if not in_name or query.startswith(REGISTRY_PREFIX):
            return query
        return f"{REGISTRY_PREFIX}{query}"

    if isinstance(query, Mapping):
        transformed = {}
        for key, value in query.items():
            if key == "$regex":
                # Regex patterns are passed through verbatim.
                transformed[key] = value
            else:
                # A "name" key switches prefixing on for everything beneath it;
                # any other key inherits the current mode.
                transformed[key] = _ensure_registry_prefix_on_names(
                    value, in_name=(key == "name") or in_name
                )
        return transformed

    if isinstance(query, Sequence):
        return [_ensure_registry_prefix_on_names(item, in_name=in_name) for item in query]

    return query
105
+
106
+
107
@lru_cache(maxsize=10)
def _fetch_org_entity_from_organization(client: "Client", organization: str) -> str:
    """Fetch the org entity from the organization.

    Successful lookups are memoized per `(client, organization)` pair.
    Failures raise, so they are never cached.

    Args:
        client (Client): Graphql client.
        organization (str): The organization to fetch the org entity for.

    Returns:
        The name of the organization's org entity.

    Raises:
        ValueError: If the request fails, or if the organization or its org
            entity is missing or has an empty name.
    """
    query = gql(
        """
        query FetchOrgEntityFromOrganization($organization: String!) {
            organization(name: $organization) {
                orgEntity {
                    name
                }
            }
        }
        """
    )
    # Keep the try narrow: only the network call is wrapped, so the specific
    # "not found" / "empty" errors below are not re-wrapped with a generic message.
    try:
        response = client.execute(query, variable_values={"organization": organization})
    except Exception as e:
        raise ValueError(
            f"Error fetching org entity for organization: {organization}"
        ) from e

    org = (response or {}).get("organization")
    org_entity = org.get("orgEntity") if org else None
    if not org_entity:
        raise ValueError(
            f"Organization entity for organization: {organization} not found"
        )
    org_entity_name = org_entity.get("name")
    if not org_entity_name:
        # Previously this *returned* the ValueError object, which callers then
        # used as an entity name (and lru_cache memoized it).
        raise ValueError(
            f"Organization entity for organization: {organization} is empty"
        )
    return org_entity_name
wandb/apis/public/runs.py CHANGED
@@ -61,37 +61,36 @@ RUN_FRAGMENT = """fragment RunFragment on Run {
61
61
  }"""
62
62
 
63
63
 
64
@normalize_exceptions
def _server_provides_internal_id_for_project(client) -> bool:
    """Returns True if the server exposes a `projectId` field on the `Run` type.

    This check is done by utilizing GraphQL introspection (`__type(name: "Run")`)
    to list the fields available on the Run type. The run queries in this
    module use the result to decide which fields to request.

    Note: the result is not cached, so every call issues one introspection
    request.

    Args:
        client: A GraphQL client exposing `execute(query)`; presumably the
            public-API `RetryingClient` -- verify against callers.

    Returns:
        True if `projectId` is listed among the `Run` fields, else False.
    """
    query_string = """
    query ProbeRunInput {
        RunType: __type(name:"Run") {
            fields {
                name
            }
        }
    }
    """

    res = client.execute(gql(query_string))
    # Introspection yields null for an unknown type, and `fields` may also be
    # null; treat either as "field not available" instead of crashing on
    # `.get` of None or a missing "name" key.
    run_type = (res or {}).get("RunType") or {}
    fields = run_type.get("fields") or []
    return any(field.get("name") == "projectId" for field in fields if field)
86
+
87
+
64
88
  class Runs(SizedPaginator["Run"]):
65
89
  """An iterable collection of runs associated with a project and optional filter.
66
90
 
67
91
  This is generally used indirectly via the `Api`.runs method.
68
92
  """
69
93
 
70
- QUERY = gql(
71
- """
72
- query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
73
- project(name: $project, entityName: $entity) {{
74
- internalId
75
- runCount(filters: $filters)
76
- readOnly
77
- runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
78
- edges {{
79
- node {{
80
- ...RunFragment
81
- }}
82
- cursor
83
- }}
84
- pageInfo {{
85
- endCursor
86
- hasNextPage
87
- }}
88
- }}
89
- }}
90
- }}
91
- {}
92
- """.format(RUN_FRAGMENT)
93
- )
94
-
95
94
  def __init__(
96
95
  self,
97
96
  client: "RetryingClient",
@@ -102,6 +101,32 @@ class Runs(SizedPaginator["Run"]):
102
101
  per_page: int = 50,
103
102
  include_sweeps: bool = True,
104
103
  ):
104
+ self.QUERY = gql(
105
+ f"""#graphql
106
+ query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
107
+ project(name: $project, entityName: $entity) {{
108
+ internalId
109
+ runCount(filters: $filters)
110
+ readOnly
111
+ runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
112
+ edges {{
113
+ node {{
114
+ {"" if _server_provides_internal_id_for_project(client) else "internalId"}
115
+ ...RunFragment
116
+ }}
117
+ cursor
118
+ }}
119
+ pageInfo {{
120
+ endCursor
121
+ hasNextPage
122
+ }}
123
+ }}
124
+ }}
125
+ }}
126
+ {RUN_FRAGMENT}
127
+ """
128
+ )
129
+
105
130
  self.entity = entity
106
131
  self.project = project
107
132
  self._project_internal_id = None
@@ -429,7 +454,9 @@ class Run(Attrs):
429
454
  }}
430
455
  {}
431
456
  """.format(
432
- "projectId" if self._server_provides_internal_id_for_project() else "",
457
+ "projectId"
458
+ if _server_provides_internal_id_for_project(self.client)
459
+ else "",
433
460
  RUN_FRAGMENT,
434
461
  )
435
462
  )
@@ -444,10 +471,6 @@ class Run(Attrs):
444
471
  self._attrs = response["project"]["run"]
445
472
  self._state = self._attrs["state"]
446
473
 
447
- self._project_internal_id = (
448
- int(self._attrs["projectId"]) if "projectId" in self._attrs else None
449
- )
450
-
451
474
  if self._include_sweeps and self.sweep_name and not self.sweep:
452
475
  # There may be a lot of runs. Don't bother pulling them all
453
476
  # just for the sake of this one.
@@ -459,6 +482,11 @@ class Run(Attrs):
459
482
  withRuns=False,
460
483
  )
461
484
 
485
+ if "projectId" in self._attrs:
486
+ self._project_internal_id = int(self._attrs["projectId"])
487
+ else:
488
+ self._project_internal_id = None
489
+
462
490
  try:
463
491
  self._attrs["summaryMetrics"] = (
464
492
  json.loads(self._attrs["summaryMetrics"])
@@ -912,32 +940,6 @@ class Run(Attrs):
912
940
  )
913
941
  return artifact
914
942
 
915
- @normalize_exceptions
916
- def _server_provides_internal_id_for_project(self) -> bool:
917
- """Returns True if the server allows us to query the internalId field for a project.
918
-
919
- This check is done by utilizing GraphQL introspection in the available fields on the Project type.
920
- """
921
- query_string = """
922
- query ProbeRunInput {
923
- RunType: __type(name:"Run") {
924
- fields {
925
- name
926
- }
927
- }
928
- }
929
- """
930
-
931
- # Only perform the query once to avoid extra network calls
932
- if self.server_provides_internal_id_field is None:
933
- query = gql(query_string)
934
- res = self.client.execute(query)
935
- self.server_provides_internal_id_field = "projectId" in [
936
- x["name"] for x in (res.get("RunType", {}).get("fields", [{}]))
937
- ]
938
-
939
- return self.server_provides_internal_id_field
940
-
941
943
  @property
942
944
  def summary(self):
943
945
  if self._summary is None: