wandb 0.19.6rc4__py3-none-win32.whl → 0.19.8__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +56 -6
  3. wandb/apis/public/_generated/__init__.py +21 -0
  4. wandb/apis/public/_generated/base.py +128 -0
  5. wandb/apis/public/_generated/enums.py +4 -0
  6. wandb/apis/public/_generated/input_types.py +4 -0
  7. wandb/apis/public/_generated/operations.py +15 -0
  8. wandb/apis/public/_generated/server_features_query.py +27 -0
  9. wandb/apis/public/_generated/typing_compat.py +14 -0
  10. wandb/apis/public/api.py +192 -6
  11. wandb/apis/public/artifacts.py +13 -45
  12. wandb/apis/public/registries.py +573 -0
  13. wandb/apis/public/utils.py +36 -0
  14. wandb/bin/gpu_stats.exe +0 -0
  15. wandb/bin/wandb-core +0 -0
  16. wandb/cli/cli.py +11 -20
  17. wandb/data_types.py +1 -1
  18. wandb/env.py +10 -0
  19. wandb/filesync/dir_watcher.py +2 -1
  20. wandb/proto/v3/wandb_internal_pb2.py +243 -222
  21. wandb/proto/v3/wandb_server_pb2.py +4 -4
  22. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  23. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  24. wandb/proto/v4/wandb_internal_pb2.py +226 -222
  25. wandb/proto/v4/wandb_server_pb2.py +4 -4
  26. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  27. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  28. wandb/proto/v5/wandb_internal_pb2.py +226 -222
  29. wandb/proto/v5/wandb_server_pb2.py +4 -4
  30. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  31. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  32. wandb/sdk/artifacts/_graphql_fragments.py +126 -0
  33. wandb/sdk/artifacts/artifact.py +51 -95
  34. wandb/sdk/backend/backend.py +17 -6
  35. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
  36. wandb/sdk/data_types/helper_types/image_mask.py +12 -6
  37. wandb/sdk/data_types/saved_model.py +35 -46
  38. wandb/sdk/data_types/video.py +7 -16
  39. wandb/sdk/interface/interface.py +87 -49
  40. wandb/sdk/interface/interface_queue.py +5 -15
  41. wandb/sdk/interface/interface_relay.py +7 -22
  42. wandb/sdk/interface/interface_shared.py +65 -136
  43. wandb/sdk/interface/interface_sock.py +3 -21
  44. wandb/sdk/interface/router.py +42 -68
  45. wandb/sdk/interface/router_queue.py +13 -11
  46. wandb/sdk/interface/router_relay.py +26 -13
  47. wandb/sdk/interface/router_sock.py +12 -16
  48. wandb/sdk/internal/handler.py +4 -3
  49. wandb/sdk/internal/internal_api.py +12 -1
  50. wandb/sdk/internal/sender.py +3 -19
  51. wandb/sdk/lib/apikey.py +87 -26
  52. wandb/sdk/lib/asyncio_compat.py +210 -0
  53. wandb/sdk/lib/console_capture.py +172 -0
  54. wandb/sdk/lib/progress.py +78 -16
  55. wandb/sdk/lib/redirect.py +102 -76
  56. wandb/sdk/lib/service_connection.py +37 -17
  57. wandb/sdk/lib/sock_client.py +6 -56
  58. wandb/sdk/mailbox/__init__.py +23 -0
  59. wandb/sdk/mailbox/mailbox.py +135 -0
  60. wandb/sdk/mailbox/mailbox_handle.py +127 -0
  61. wandb/sdk/mailbox/response_handle.py +167 -0
  62. wandb/sdk/mailbox/wait_with_progress.py +135 -0
  63. wandb/sdk/service/server_sock.py +9 -3
  64. wandb/sdk/service/streams.py +75 -78
  65. wandb/sdk/verify/verify.py +54 -2
  66. wandb/sdk/wandb_init.py +72 -75
  67. wandb/sdk/wandb_login.py +7 -4
  68. wandb/sdk/wandb_metadata.py +65 -34
  69. wandb/sdk/wandb_require.py +14 -8
  70. wandb/sdk/wandb_run.py +90 -97
  71. wandb/sdk/wandb_settings.py +10 -4
  72. wandb/sdk/wandb_setup.py +19 -8
  73. wandb/sdk/wandb_sync.py +2 -10
  74. wandb/util.py +3 -1
  75. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/METADATA +2 -2
  76. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/RECORD +79 -66
  77. wandb/sdk/interface/message_future.py +0 -27
  78. wandb/sdk/interface/message_future_poll.py +0 -50
  79. wandb/sdk/lib/mailbox.py +0 -442
  80. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/WHEEL +0 -0
  81. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/entry_points.txt +0 -0
  82. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/licenses/LICENSE +0 -0
wandb/__init__.py CHANGED
@@ -10,7 +10,7 @@ For reference documentation, see https://docs.wandb.com/ref/python.
10
10
  """
11
11
  from __future__ import annotations
12
12
 
13
- __version__ = "0.19.6rc4"
13
+ __version__ = "0.19.8"
14
14
 
15
15
 
16
16
  from wandb.errors import Error
wandb/__init__.pyi CHANGED
@@ -55,6 +55,7 @@ __all__ = (
55
55
  "unwatch",
56
56
  "plot",
57
57
  "plot_table",
58
+ "restore",
58
59
  )
59
60
 
60
61
  import os
@@ -63,10 +64,12 @@ from typing import (
63
64
  Any,
64
65
  Callable,
65
66
  Dict,
67
+ Iterable,
66
68
  List,
67
69
  Literal,
68
70
  Optional,
69
71
  Sequence,
72
+ TextIO,
70
73
  Union,
71
74
  )
72
75
 
@@ -103,7 +106,7 @@ if TYPE_CHECKING:
103
106
  import wandb
104
107
  from wandb.plot import CustomChart
105
108
 
106
- __version__: str = "0.19.6rc4"
109
+ __version__: str = "0.19.8"
107
110
 
108
111
  run: Run | None
109
112
  config: wandb_config.Config
@@ -114,6 +117,25 @@ _sentry: Sentry
114
117
  api: InternalApi
115
118
  patched: Dict[str, List[Callable]]
116
119
 
120
+ def require(
121
+ requirement: str | Iterable[str] | None = None,
122
+ experiment: str | Iterable[str] | None = None,
123
+ ) -> None:
124
+ """Indicate which experimental features are used by the script.
125
+
126
+ This should be called before any other `wandb` functions, ideally right
127
+ after importing `wandb`.
128
+
129
+ Args:
130
+ requirement: The name of a feature to require or an iterable of
131
+ feature names.
132
+ experiment: An alias for `requirement`.
133
+
134
+ Raises:
135
+ wandb.errors.UnsupportedError: If a feature name is unknown.
136
+ """
137
+ ...
138
+
117
139
  def setup(settings: Settings | None = None) -> _WandbSetup:
118
140
  """Prepares W&B for use in the current process and its children.
119
141
 
@@ -271,10 +293,10 @@ def init(
271
293
  on the system, such as checking the git root or the current program
272
294
  file. If we can't infer the project name, the project will default to
273
295
  `"uncategorized"`.
274
- dir: An absolute path to the directory where metadata and downloaded
275
- files will be stored. When calling `download()` on an artifact, files
276
- will be saved to this directory. If not specified, this defaults to
277
- the `./wandb` directory.
296
+ dir: The absolute path to the directory where experiment logs and
297
+ metadata files are stored. If not specified, this defaults
298
+ to the `./wandb` directory. Note that this does not affect the
299
+ location where artifacts are stored when calling `download()`.
278
300
  id: A unique identifier for this run, used for resuming. It must be unique
279
301
  within the project and cannot be reused once a run is deleted. The
280
302
  identifier must not contain any of the following special characters:
@@ -510,7 +532,7 @@ def log(
510
532
  [guides to logging](https://docs.wandb.ai/guides/track/log) for examples,
511
533
  from 3D molecular structures and segmentation masks to PR curves and histograms.
512
534
  You can use `wandb.Table` to log structured data. See our
513
- [guide to logging tables](https://docs.wandb.ai/guides/tables/tables-walkthrough)
535
+ [guide to logging tables](https://docs.wandb.ai/guides/models/tables/tables-walkthrough)
514
536
  for details.
515
537
 
516
538
  The W&B UI organizes metrics with a forward slash (`/`) in their name
@@ -1196,3 +1218,31 @@ def unwatch(
1196
1218
  Optional list of pytorch models that have had watch called on them
1197
1219
  """
1198
1220
  ...
1221
+
1222
+ def restore(
1223
+ name: str,
1224
+ run_path: str | None = None,
1225
+ replace: bool = False,
1226
+ root: str | None = None,
1227
+ ) -> None | TextIO:
1228
+ """Download the specified file from cloud storage.
1229
+
1230
+ File is placed into the current directory or run directory.
1231
+ By default, will only download the file if it doesn't already exist.
1232
+
1233
+ Args:
1234
+ name: the name of the file
1235
+ run_path: optional path to a run to pull files from, e.g. `username/project_name/run_id`.
1236
+ If wandb.init has not been called, this is required.
1237
+ replace: whether to download the file even if it already exists locally
1238
+ root: the directory to download the file to. Defaults to the current
1239
+ directory or the run directory if wandb.init was called.
1240
+
1241
+ Returns:
1242
+ None if it can't find the file, otherwise a file object open for reading
1243
+
1244
+ Raises:
1245
+ wandb.CommError: if we can't connect to the wandb backend
1246
+ ValueError: if the file is not found or can't find run_path
1247
+ """
1248
+ ...
@@ -0,0 +1,21 @@
1
+ # Generated by ariadne-codegen
2
+
3
+ from .base import Base, GQLBase, GQLId, SerializedToJson, Typename
4
+ from .operations import SERVER_FEATURES_QUERY_GQL
5
+ from .server_features_query import (
6
+ ServerFeaturesQuery,
7
+ ServerFeaturesQueryServerInfo,
8
+ ServerFeaturesQueryServerInfoFeatures,
9
+ )
10
+
11
+ __all__ = [
12
+ "Base",
13
+ "GQLBase",
14
+ "GQLId",
15
+ "SerializedToJson",
16
+ "Typename",
17
+ "SERVER_FEATURES_QUERY_GQL",
18
+ "ServerFeaturesQuery",
19
+ "ServerFeaturesQueryServerInfo",
20
+ "ServerFeaturesQueryServerInfoFeatures",
21
+ ]
@@ -0,0 +1,128 @@
1
+ # Generated by ariadne-codegen
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Literal, TypeVar
6
+
7
+ from pydantic import BaseModel, ConfigDict, Field, Json, ValidationError, WrapValidator
8
+ from pydantic.alias_generators import to_camel
9
+ from pydantic.main import IncEx
10
+ from pydantic_core import to_json
11
+ from pydantic_core.core_schema import ValidatorFunctionWrapHandler
12
+
13
+ from .typing_compat import Annotated, override
14
+
15
+
16
+ # Base class for all automation classes/types.
17
+ # Omitted from docstring to avoid inclusion in generated docs.
18
+ class Base(BaseModel):
19
+ model_config = ConfigDict(
20
+ populate_by_name=True,
21
+ validate_assignment=True,
22
+ validate_default=True,
23
+ extra="forbid",
24
+ alias_generator=to_camel,
25
+ use_attribute_docstrings=True,
26
+ from_attributes=True,
27
+ revalidate_instances="always",
28
+ )
29
+
30
+ @override
31
+ def model_dump(
32
+ self,
33
+ *,
34
+ mode: Literal["json", "python"] | str = "json", # NOTE: changed default
35
+ include: IncEx | None = None,
36
+ exclude: IncEx | None = None,
37
+ context: dict[str, Any] | None = None,
38
+ by_alias: bool = True, # NOTE: changed default
39
+ exclude_unset: bool = False,
40
+ exclude_defaults: bool = False,
41
+ exclude_none: bool = False,
42
+ round_trip: bool = True, # NOTE: changed default
43
+ warnings: bool | Literal["none", "warn", "error"] = True,
44
+ serialize_as_any: bool = False,
45
+ ) -> dict[str, Any]:
46
+ return super().model_dump(
47
+ mode=mode,
48
+ include=include,
49
+ exclude=exclude,
50
+ context=context,
51
+ by_alias=by_alias,
52
+ exclude_unset=exclude_unset,
53
+ exclude_defaults=exclude_defaults,
54
+ exclude_none=exclude_none,
55
+ round_trip=round_trip,
56
+ warnings=warnings,
57
+ serialize_as_any=serialize_as_any,
58
+ )
59
+
60
+ @override
61
+ def model_dump_json(
62
+ self,
63
+ *,
64
+ indent: int | None = None,
65
+ include: IncEx | None = None,
66
+ exclude: IncEx | None = None,
67
+ context: dict[str, Any] | None = None,
68
+ by_alias: bool = True, # NOTE: changed default
69
+ exclude_unset: bool = False,
70
+ exclude_defaults: bool = False,
71
+ exclude_none: bool = False,
72
+ round_trip: bool = True, # NOTE: changed default
73
+ warnings: bool | Literal["none", "warn", "error"] = True,
74
+ serialize_as_any: bool = False,
75
+ ) -> str:
76
+ return super().model_dump_json(
77
+ indent=indent,
78
+ include=include,
79
+ exclude=exclude,
80
+ context=context,
81
+ by_alias=by_alias,
82
+ exclude_unset=exclude_unset,
83
+ exclude_defaults=exclude_defaults,
84
+ exclude_none=exclude_none,
85
+ round_trip=round_trip,
86
+ warnings=warnings,
87
+ serialize_as_any=serialize_as_any,
88
+ )
89
+
90
+
91
+ # Base class with extra customization for GQL generated types.
92
+ # Omitted from docstring to avoid inclusion in generated docs.
93
+ class GQLBase(Base):
94
+ model_config = ConfigDict(
95
+ extra="ignore",
96
+ protected_namespaces=(),
97
+ )
98
+
99
+
100
+ # ------------------------------------------------------------------------------
101
+ # Reusable annotations for field types
102
+ T = TypeVar("T")
103
+
104
+ GQLId = Annotated[
105
+ str,
106
+ Field(repr=False, strict=True, frozen=True),
107
+ ]
108
+
109
+ Typename = Annotated[
110
+ T,
111
+ Field(repr=False, alias="__typename", frozen=True),
112
+ ]
113
+
114
+
115
+ def validate_maybe_json(v: Any, handler: ValidatorFunctionWrapHandler) -> Any:
116
+ """Wraps default Json[...] field validator to allow instantiation with an already-decoded value."""
117
+ try:
118
+ return handler(v)
119
+ except ValidationError:
120
+ # Try revalidating after properly jsonifying the value
121
+ return handler(to_json(v, by_alias=True, round_trip=True))
122
+
123
+
124
+ SerializedToJson = Annotated[
125
+ Json[T],
126
+ # Allow lenient instantiation/validation: incoming data may already be deserialized.
127
+ WrapValidator(validate_maybe_json),
128
+ ]
@@ -0,0 +1,4 @@
1
+ # Generated by ariadne-codegen
2
+ # Source: core/api/graphql/schemas/schema-latest.graphql
3
+
4
+ from __future__ import annotations
@@ -0,0 +1,4 @@
1
+ # Generated by ariadne-codegen
2
+ # Source: core/api/graphql/schemas/schema-latest.graphql
3
+
4
+ from __future__ import annotations
@@ -0,0 +1,15 @@
1
+ # Generated by ariadne-codegen
2
+ # Source: tools/graphql_codegen/utils/
3
+
4
+ __all__ = ["SERVER_FEATURES_QUERY_GQL"]
5
+
6
+ SERVER_FEATURES_QUERY_GQL = """
7
+ query ServerFeaturesQuery {
8
+ serverInfo {
9
+ features {
10
+ name
11
+ isEnabled
12
+ }
13
+ }
14
+ }
15
+ """
@@ -0,0 +1,27 @@
1
+ # Generated by ariadne-codegen
2
+ # Source: tools/graphql_codegen/utils/
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import List, Optional
7
+
8
+ from pydantic import Field
9
+
10
+ from .base import GQLBase
11
+
12
+
13
+ class ServerFeaturesQuery(GQLBase):
14
+ server_info: Optional["ServerFeaturesQueryServerInfo"] = Field(alias="serverInfo")
15
+
16
+
17
+ class ServerFeaturesQueryServerInfo(GQLBase):
18
+ features: List[Optional["ServerFeaturesQueryServerInfoFeatures"]]
19
+
20
+
21
+ class ServerFeaturesQueryServerInfoFeatures(GQLBase):
22
+ name: str
23
+ is_enabled: bool = Field(alias="isEnabled")
24
+
25
+
26
+ ServerFeaturesQuery.model_rebuild()
27
+ ServerFeaturesQueryServerInfo.model_rebuild()
@@ -0,0 +1,14 @@
1
+ # Generated by ariadne-codegen
2
+
3
+ """Definitions to ensure compatibility with all supported python versions."""
4
+
5
+ import sys
6
+
7
+ if sys.version_info >= (3, 12):
8
+ from typing import Annotated, override
9
+ else:
10
+ from typing_extensions import Annotated, override
11
+
12
+
13
+ Annnotated = Annotated
14
+ override = override
wandb/apis/public/api.py CHANGED
@@ -25,8 +25,15 @@ import wandb
25
25
  from wandb import env, util
26
26
  from wandb.apis import public
27
27
  from wandb.apis.normalize import normalize_exceptions
28
+ from wandb.apis.public._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
28
29
  from wandb.apis.public.const import RETRY_TIMEDELTA
29
- from wandb.apis.public.utils import PathType, parse_org_from_registry_path
30
+ from wandb.apis.public.registries import Registries
31
+ from wandb.apis.public.utils import (
32
+ PathType,
33
+ fetch_org_from_settings_or_entity,
34
+ parse_org_from_registry_path,
35
+ )
36
+ from wandb.proto.wandb_internal_pb2 import ServerFeature
30
37
  from wandb.sdk.artifacts._validators import is_artifact_registry_project
31
38
  from wandb.sdk.internal.internal_api import Api as InternalApi
32
39
  from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
@@ -248,6 +255,9 @@ class Api:
248
255
  self.settings["entity"] = _overrides["username"]
249
256
  self.settings["base_url"] = self.settings["base_url"].rstrip("/")
250
257
 
258
+ if "organization" in _overrides:
259
+ self.settings["organization"] = _overrides["organization"]
260
+
251
261
  self._viewer = None
252
262
  self._projects = {}
253
263
  self._runs = {}
@@ -279,6 +289,7 @@ class Api:
279
289
  )
280
290
  )
281
291
  self._client = RetryingClient(self._base_client)
292
+ self._server_features_cache = None
282
293
 
283
294
  def create_project(self, name: str, entity: str) -> None:
284
295
  """Create a new project.
@@ -900,12 +911,46 @@ class Api:
900
911
  ):
901
912
  """Return a set of runs from a project that match the filters provided.
902
913
 
903
- You can filter by `config.*`, `summary_metrics.*`, `tags`, `state`, `entity`, `createdAt`, etc.
914
+ Fields you can filter by include:
915
+ - `createdAt`: The timestamp when the run was created. (in ISO 8601 format, e.g. "2023-01-01T12:00:00Z")
916
+ - `displayName`: The human-readable display name of the run. (e.g. "eager-fox-1")
917
+ - `duration`: The total runtime of the run in seconds.
918
+ - `group`: The group name used to organize related runs together.
919
+ - `host`: The hostname where the run was executed.
920
+ - `jobType`: The type of job or purpose of the run.
921
+ - `name`: The unique identifier of the run. (e.g. "a1b2cdef")
922
+ - `state`: The current state of the run.
923
+ - `tags`: The tags associated with the run.
924
+ - `username`: The username of the user who initiated the run
925
+
926
+ Additionally, you can filter by items in the run config or summary metrics,
930
+ such as `config.experiment_name`, `summary_metrics.loss`, etc.
928
+
929
+ For more complex filtering, you can use MongoDB query operators.
930
+ For details, see: https://docs.mongodb.com/manual/reference/operator/query
931
+ The following operations are supported:
932
+ - `$and`
933
+ - `$or`
934
+ - `$nor`
935
+ - `$eq`
936
+ - `$ne`
937
+ - `$gt`
938
+ - `$gte`
939
+ - `$lt`
940
+ - `$lte`
941
+ - `$in`
942
+ - `$nin`
943
+ - `$exists`
944
+ - `$regex`
945
+
904
946
 
905
947
  Examples:
906
948
  Find runs in my_project where config.experiment_name has been set to "foo"
907
949
  ```
908
- api.runs(path="my_entity/my_project", filters={"config.experiment_name": "foo"})
950
+ api.runs(
951
+ path="my_entity/my_project",
952
+ filters={"config.experiment_name": "foo"},
953
+ )
909
954
  ```
910
955
 
911
956
  Find runs in my_project where config.experiment_name has been set to "foo" or "bar"
@@ -932,7 +977,24 @@ class Api:
932
977
  Find runs in my_project where the run name matches a regex (anchors are not supported)
933
978
  ```
934
979
  api.runs(
935
- path="my_entity/my_project", filters={"display_name": {"$regex": "^foo.*"}}
980
+ path="my_entity/my_project",
981
+ filters={"display_name": {"$regex": "^foo.*"}},
982
+ )
983
+ ```
984
+
985
+ Find runs in my_project where config.experiment contains a nested field "category" with value "testing"
986
+ ```
987
+ api.runs(
988
+ path="my_entity/my_project",
989
+ filters={"config.experiment.category": "testing"},
990
+ )
991
+ ```
992
+
993
+ Find runs in my_project with a loss value of 0.5 nested in a dictionary under model1 in the summary metrics
994
+ ```
995
+ api.runs(
996
+ path="my_entity/my_project",
997
+ filters={"summary_metrics.model1.loss": 0.5},
936
998
  )
937
999
  ```
938
1000
 
@@ -947,8 +1009,6 @@ class Api:
947
1009
  You can filter by run properties such as config.key, summary_metrics.key, state, entity, createdAt, etc.
948
1010
  For example: `{"config.experiment_name": "foo"}` would find runs with a config entry
949
1011
  of experiment name set to "foo"
950
- You can compose operations to make more complicated queries,
951
- see Reference for the language is at https://docs.mongodb.com/manual/reference/operator/query
952
1012
  order: (str) Order can be `created_at`, `heartbeat_at`, `config.*.value`, or `summary_metrics.*`.
953
1013
  If you prepend order with a + order is ascending.
954
1014
  If you prepend order with a - order is descending (default).
@@ -1139,6 +1199,12 @@ class Api:
1139
1199
  entity = InternalApi()._resolve_org_entity_name(
1140
1200
  entity=settings_entity, organization=org
1141
1201
  )
1202
+
1203
+ if entity is None:
1204
+ raise ValueError(
1205
+ "Could not determine entity. Please include the entity as part of the collection name path."
1206
+ )
1207
+
1142
1208
  return public.ArtifactCollection(
1143
1209
  self.client, entity, project, collection_name, type_name
1144
1210
  )
@@ -1211,6 +1277,12 @@ class Api:
1211
1277
  entity = InternalApi()._resolve_org_entity_name(
1212
1278
  entity=settings_entity, organization=organization
1213
1279
  )
1280
+
1281
+ if entity is None:
1282
+ raise ValueError(
1283
+ "Could not determine entity. Please include the entity as part of the artifact name path."
1284
+ )
1285
+
1214
1286
  artifact = wandb.Artifact._from_name(
1215
1287
  entity=entity,
1216
1288
  project=project,
@@ -1385,3 +1457,117 @@ class Api:
1385
1457
  return True
1386
1458
  except wandb.errors.CommError:
1387
1459
  return False
1460
+
1461
+ def registries(
1462
+ self,
1463
+ organization: Optional[str] = None,
1464
+ filter: Optional[Dict[str, Any]] = None,
1465
+ ) -> Registries:
1466
+ """Returns a Registry iterator.
1467
+
1468
+ Use the iterator to search and filter registries, collections,
1469
+ or artifact versions across your organization's registry.
1470
+
1471
+ Examples:
1472
+ Find all registries whose names contain "model"
1473
+ ```python
1474
+ import wandb
1475
+
1476
+ api = wandb.Api() # specify an org if your entity belongs to multiple orgs
1477
+ api.registries(filter={"name": {"$regex": "model"}})
1478
+ ```
1479
+
1480
+ Find all collections in the registries with the name "my_collection" and the tag "my_tag"
1481
+ ```python
1482
+ api.registries().collections(filter={"name": "my_collection", "tag": "my_tag"})
1483
+ ```
1484
+
1485
+ Find all artifact versions in the registries with a collection name that contains "my_collection" and a version that has the alias "best"
1486
+ ```python
1487
+ api.registries().collections(
1488
+ filter={"name": {"$regex": "my_collection"}}
1489
+ ).versions(filter={"alias": "best"})
1490
+ ```
1491
+
1492
+ Find all artifact versions in the registries that contain "model" and have the tag "prod" or alias "best"
1493
+ ```python
1494
+ api.registries(filter={"name": {"$regex": "model"}}).versions(
1495
+ filter={"$or": [{"tag": "prod"}, {"alias": "best"}]}
1496
+ )
1497
+ ```
1498
+
1499
+ Args:
1500
+ organization: (str, optional) The organization of the registry to fetch.
1501
+ If not specified, use the organization specified in the user's settings.
1502
+ filter: (dict, optional) MongoDB-style filter to apply to each object in the registry iterator.
1503
+ Fields available to filter for registries are
1504
+ `name`, `description`, `created_at`, `updated_at`.
1505
+ Fields available to filter for collections are
1506
+ `name`, `tag`, `description`, `created_at`, `updated_at`
1507
+ Fields available to filter for versions are
1508
+ `tag`, `alias`, `created_at`, `updated_at`, `metadata`
1509
+
1510
+ Returns:
1511
+ A registry iterator.
1512
+ """
1513
+ if not self._check_server_feature_with_fallback(
1514
+ ServerFeature.ARTIFACT_REGISTRY_SEARCH
1515
+ ):
1516
+ raise RuntimeError(
1517
+ "Registry search API is not enabled on this wandb server version. "
1518
+ "Please upgrade your server version or contact support at support@wandb.com."
1519
+ )
1520
+
1521
+ organization = organization or fetch_org_from_settings_or_entity(
1522
+ self.settings, self.default_entity
1523
+ )
1524
+ return Registries(self.client, organization, filter)
1525
+
1526
+ def _check_server_feature(self, feature: ServerFeature) -> bool:
1527
+ """Check if a server feature is enabled.
1528
+
1529
+ Args:
1530
+ feature (ServerFeature): The feature to check.
1531
+
1532
+ Returns:
1533
+ bool: True if the feature is enabled, False otherwise.
1534
+
1535
+ Raises:
1536
+ Exception: If server doesn't support feature queries or other errors occur
1537
+ """
1538
+ if self._server_features_cache is None:
1539
+ response = self.client.execute(gql(SERVER_FEATURES_QUERY_GQL))
1540
+ self._server_features_cache = ServerFeaturesQuery.model_validate(response)
1541
+
1542
+ feature_name = ServerFeature.Name(feature)
1543
+ if (
1544
+ self._server_features_cache
1545
+ and self._server_features_cache.server_info
1546
+ and self._server_features_cache.server_info.features
1547
+ ):
1548
+ for feature_info in self._server_features_cache.server_info.features:
1549
+ if feature_info and feature_info.name == feature_name:
1550
+ return feature_info.is_enabled
1551
+
1552
+ return False
1553
+
1554
+ def _check_server_feature_with_fallback(self, feature: ServerFeature) -> bool:
1555
+ """Wrapper around check_server_feature that returns False for older servers that don't support feature queries.
1556
+
1557
+ Good to use for features that have a fallback mechanism for older servers.
1558
+
1559
+ Args:
1560
+ feature (ServerFeature): The feature to check.
1561
+
1562
+ Returns:
1563
+ bool: True if the feature is enabled, False otherwise.
1564
+
1565
+ Raises:
1566
+ Exception: If an error other than the server not supporting feature queries occurs.
1567
+ """
1568
+ try:
1569
+ return self._check_server_feature(feature)
1570
+ except Exception as e:
1571
+ if 'Cannot query field "features" on type "ServerInfo".' in str(e):
1572
+ return False
1573
+ raise e