wandb 0.20.1rc20250604__py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl → 0.21.0__py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. wandb/__init__.py +3 -6
  2. wandb/__init__.pyi +24 -23
  3. wandb/analytics/sentry.py +2 -2
  4. wandb/apis/importers/internals/internal.py +0 -3
  5. wandb/apis/internal.py +3 -0
  6. wandb/apis/paginator.py +17 -4
  7. wandb/apis/public/api.py +85 -4
  8. wandb/apis/public/artifacts.py +10 -8
  9. wandb/apis/public/files.py +5 -5
  10. wandb/apis/public/projects.py +44 -3
  11. wandb/apis/public/registries/{utils.py → _utils.py} +12 -12
  12. wandb/apis/public/registries/registries_search.py +2 -2
  13. wandb/apis/public/registries/registry.py +19 -18
  14. wandb/apis/public/reports.py +64 -8
  15. wandb/apis/public/runs.py +16 -23
  16. wandb/automations/__init__.py +10 -10
  17. wandb/automations/_filters/run_metrics.py +0 -2
  18. wandb/automations/_utils.py +0 -2
  19. wandb/automations/actions.py +0 -2
  20. wandb/automations/automations.py +0 -2
  21. wandb/automations/events.py +0 -2
  22. wandb/bin/gpu_stats +0 -0
  23. wandb/bin/wandb-core +0 -0
  24. wandb/cli/beta.py +1 -7
  25. wandb/cli/cli.py +0 -30
  26. wandb/env.py +0 -6
  27. wandb/integration/catboost/catboost.py +6 -2
  28. wandb/integration/kfp/kfp_patch.py +3 -1
  29. wandb/integration/sb3/sb3.py +3 -3
  30. wandb/integration/ultralytics/callback.py +6 -2
  31. wandb/plot/__init__.py +2 -0
  32. wandb/plot/bar.py +30 -29
  33. wandb/plot/confusion_matrix.py +75 -71
  34. wandb/plot/histogram.py +26 -25
  35. wandb/plot/line.py +33 -32
  36. wandb/plot/line_series.py +100 -103
  37. wandb/plot/pr_curve.py +33 -32
  38. wandb/plot/roc_curve.py +38 -38
  39. wandb/plot/scatter.py +27 -27
  40. wandb/proto/v3/wandb_internal_pb2.py +366 -385
  41. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  42. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  43. wandb/proto/v4/wandb_internal_pb2.py +352 -356
  44. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  45. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  46. wandb/proto/v5/wandb_internal_pb2.py +352 -356
  47. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  48. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  49. wandb/proto/v6/wandb_internal_pb2.py +352 -356
  50. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  51. wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
  52. wandb/sdk/artifacts/_generated/__init__.py +12 -1
  53. wandb/sdk/artifacts/_generated/input_types.py +20 -2
  54. wandb/sdk/artifacts/_generated/link_artifact.py +21 -0
  55. wandb/sdk/artifacts/_generated/operations.py +9 -0
  56. wandb/sdk/artifacts/_validators.py +40 -2
  57. wandb/sdk/artifacts/artifact.py +163 -21
  58. wandb/sdk/artifacts/storage_handlers/s3_handler.py +42 -1
  59. wandb/sdk/backend/backend.py +1 -1
  60. wandb/sdk/data_types/base_types/media.py +9 -7
  61. wandb/sdk/data_types/base_types/wb_value.py +6 -6
  62. wandb/sdk/data_types/saved_model.py +3 -3
  63. wandb/sdk/data_types/table.py +41 -41
  64. wandb/sdk/data_types/trace_tree.py +12 -12
  65. wandb/sdk/interface/interface.py +8 -19
  66. wandb/sdk/interface/interface_shared.py +7 -16
  67. wandb/sdk/internal/datastore.py +18 -18
  68. wandb/sdk/internal/handler.py +4 -74
  69. wandb/sdk/internal/internal_api.py +54 -0
  70. wandb/sdk/internal/sender.py +23 -3
  71. wandb/sdk/internal/sender_config.py +9 -0
  72. wandb/sdk/launch/_project_spec.py +3 -3
  73. wandb/sdk/launch/agent/agent.py +3 -3
  74. wandb/sdk/launch/agent/job_status_tracker.py +3 -1
  75. wandb/sdk/launch/utils.py +3 -3
  76. wandb/sdk/lib/console_capture.py +66 -19
  77. wandb/sdk/lib/printer.py +6 -7
  78. wandb/sdk/lib/progress.py +1 -3
  79. wandb/sdk/lib/service/ipc_support.py +13 -0
  80. wandb/sdk/lib/{service_connection.py → service/service_connection.py} +20 -56
  81. wandb/sdk/lib/service/service_port_file.py +105 -0
  82. wandb/sdk/lib/service/service_process.py +111 -0
  83. wandb/sdk/lib/service/service_token.py +164 -0
  84. wandb/sdk/lib/sock_client.py +8 -12
  85. wandb/sdk/wandb_init.py +1 -5
  86. wandb/sdk/wandb_require.py +9 -21
  87. wandb/sdk/wandb_run.py +23 -137
  88. wandb/sdk/wandb_settings.py +233 -80
  89. wandb/sdk/wandb_setup.py +2 -13
  90. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/METADATA +1 -3
  91. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/RECORD +94 -120
  92. wandb/sdk/internal/flow_control.py +0 -263
  93. wandb/sdk/internal/internal.py +0 -401
  94. wandb/sdk/internal/internal_util.py +0 -97
  95. wandb/sdk/internal/system/__init__.py +0 -0
  96. wandb/sdk/internal/system/assets/__init__.py +0 -25
  97. wandb/sdk/internal/system/assets/aggregators.py +0 -31
  98. wandb/sdk/internal/system/assets/asset_registry.py +0 -20
  99. wandb/sdk/internal/system/assets/cpu.py +0 -163
  100. wandb/sdk/internal/system/assets/disk.py +0 -210
  101. wandb/sdk/internal/system/assets/gpu.py +0 -416
  102. wandb/sdk/internal/system/assets/gpu_amd.py +0 -233
  103. wandb/sdk/internal/system/assets/interfaces.py +0 -205
  104. wandb/sdk/internal/system/assets/ipu.py +0 -177
  105. wandb/sdk/internal/system/assets/memory.py +0 -166
  106. wandb/sdk/internal/system/assets/network.py +0 -125
  107. wandb/sdk/internal/system/assets/open_metrics.py +0 -293
  108. wandb/sdk/internal/system/assets/tpu.py +0 -154
  109. wandb/sdk/internal/system/assets/trainium.py +0 -393
  110. wandb/sdk/internal/system/env_probe_helpers.py +0 -13
  111. wandb/sdk/internal/system/system_info.py +0 -248
  112. wandb/sdk/internal/system/system_monitor.py +0 -224
  113. wandb/sdk/internal/writer.py +0 -204
  114. wandb/sdk/lib/service_token.py +0 -93
  115. wandb/sdk/service/__init__.py +0 -0
  116. wandb/sdk/service/_startup_debug.py +0 -22
  117. wandb/sdk/service/port_file.py +0 -53
  118. wandb/sdk/service/server.py +0 -107
  119. wandb/sdk/service/server_sock.py +0 -286
  120. wandb/sdk/service/service.py +0 -252
  121. wandb/sdk/service/streams.py +0 -425
  122. wandb/sdk/wandb_metadata.py +0 -623
  123. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/WHEEL +0 -0
  124. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/entry_points.txt +0 -0
  125. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/licenses/LICENSE +0 -0
wandb/__init__.py CHANGED
@@ -10,7 +10,7 @@ For reference documentation, see https://docs.wandb.com/ref/python.
10
10
  """
11
11
  from __future__ import annotations
12
12
 
13
- __version__ = "0.20.1rc20250604"
13
+ __version__ = "0.21.0"
14
14
 
15
15
 
16
16
  from wandb.errors import Error
@@ -30,9 +30,9 @@ wandb.wandb_lib = wandb_sdk.lib # type: ignore
30
30
 
31
31
  init = wandb_sdk.init
32
32
  setup = wandb_sdk.setup
33
- _attach = wandb_sdk._attach
33
+ attach = _attach = wandb_sdk._attach
34
34
  _sync = wandb_sdk._sync
35
- _teardown = wandb_sdk.teardown
35
+ teardown = _teardown = wandb_sdk.teardown
36
36
  finish = wandb_sdk.finish
37
37
  join = finish
38
38
  login = wandb_sdk.login
@@ -51,9 +51,6 @@ from wandb.errors import CommError, UsageError
51
51
  _preinit = wandb.wandb_lib.preinit # type: ignore
52
52
  _lazyloader = wandb.wandb_lib.lazyloader # type: ignore
53
53
 
54
- # Call import module hook to set up any needed require hooks
55
- wandb.sdk.wandb_require._import_module_hook()
56
-
57
54
  from wandb.integration.torch import wandb_torch
58
55
 
59
56
  # Move this (keras.__init__ expects it at top level)
wandb/__init__.pyi CHANGED
@@ -12,20 +12,20 @@ For reference documentation, see https://docs.wandb.com/ref/python.
12
12
  from __future__ import annotations
13
13
 
14
14
  __all__ = (
15
- "__version__",
15
+ "__version__", # doc:exclude
16
16
  "init",
17
17
  "finish",
18
18
  "setup",
19
19
  "login",
20
- "save",
20
+ "save", # doc:exclude
21
21
  "sweep",
22
22
  "controller",
23
23
  "agent",
24
- "config",
25
- "log",
26
- "summary",
24
+ "config", # doc:exclude
25
+ "log", # doc:exclude
26
+ "summary", # doc:exclude
27
27
  "Api",
28
- "Graph",
28
+ "Graph", # doc:exclude
29
29
  "Image",
30
30
  "Plotly",
31
31
  "Video",
@@ -36,26 +36,27 @@ __all__ = (
36
36
  "Object3D",
37
37
  "Molecule",
38
38
  "Histogram",
39
- "ArtifactTTL",
40
- "log_artifact",
41
- "use_artifact",
42
- "log_model",
43
- "use_model",
44
- "link_model",
45
- "define_metric",
46
- "Error",
47
- "termsetup",
48
- "termlog",
49
- "termerror",
50
- "termwarn",
39
+ "ArtifactTTL", # doc:exclude
40
+ "log_artifact", # doc:exclude
41
+ "use_artifact", # doc:exclude
42
+ "log_model", # doc:exclude
43
+ "use_model", # doc:exclude
44
+ "link_model", # doc:exclude
45
+ "define_metric", # doc:exclude
46
+ "Error", # doc:exclude
47
+ "termsetup", # doc:exclude
48
+ "termlog", # doc:exclude
49
+ "termerror", # doc:exclude
50
+ "termwarn", # doc:exclude
51
51
  "Artifact",
52
52
  "Settings",
53
53
  "teardown",
54
- "watch",
55
- "unwatch",
56
- "plot",
54
+ "watch", # doc:exclude
55
+ "unwatch", # doc:exclude
56
+ "plot", # doc:exclude
57
57
  "plot_table",
58
58
  "restore",
59
+ "Run",
59
60
  )
60
61
 
61
62
  import os
@@ -106,7 +107,7 @@ if TYPE_CHECKING:
106
107
  import wandb
107
108
  from wandb.plot import CustomChart
108
109
 
109
- __version__: str = "0.20.1rc20250604"
110
+ __version__: str = "0.21.0"
110
111
 
111
112
  run: Run | None
112
113
  config: wandb_config.Config
@@ -325,7 +326,7 @@ def init(
325
326
  the UI.
326
327
  If resuming a run, the tags provided here will replace any existing
327
328
  tags. To add tags to a resumed run without overwriting the current
328
- tags, use `run.tags += ["new_tag"]` after calling `run = wandb.init()`.
329
+ tags, use `run.tags += ("new_tag",)` after calling `run = wandb.init()`.
329
330
  config: Sets `wandb.config`, a dictionary-like object for storing input
330
331
  parameters to your run, such as model hyperparameters or data
331
332
  preprocessing settings.
wandb/analytics/sentry.py CHANGED
@@ -15,6 +15,7 @@ from urllib.parse import quote
15
15
  import sentry_sdk # type: ignore
16
16
  import sentry_sdk.scope # type: ignore
17
17
  import sentry_sdk.utils # type: ignore
18
+ from typing_extensions import Never
18
19
 
19
20
  import wandb
20
21
  import wandb.env
@@ -143,7 +144,7 @@ class Sentry:
143
144
 
144
145
  return event_id
145
146
 
146
- def reraise(self, exc: Any) -> None:
147
+ def reraise(self, exc: Any) -> Never:
147
148
  """Re-raise an exception after logging it to Sentry.
148
149
 
149
150
  Use this for top-level exceptions when you want the user to see the traceback.
@@ -209,7 +210,6 @@ class Sentry:
209
210
  "sweep_url",
210
211
  "sweep_id",
211
212
  "deployment",
212
- "x_require_legacy_service",
213
213
  "launch",
214
214
  "_platform",
215
215
  )
@@ -364,9 +364,6 @@ def send_run(
364
364
  sm = AlternateSendManager(
365
365
  settings, sm_record_q, result_q, interface, context_keeper
366
366
  )
367
- # wm = WriteManager(
368
- # settings, wm_record_q, result_q, sm_record_q, interface, context_keeper
369
- # )
370
367
 
371
368
  if extra_arts or extra_used_arts:
372
369
  records = rm.make_artifacts_only_records(extra_arts, extra_used_arts)
wandb/apis/internal.py CHANGED
@@ -211,6 +211,9 @@ class Api:
211
211
  def upsert_run_queue(self, *args, **kwargs):
212
212
  return self.api.upsert_run_queue(*args, **kwargs)
213
213
 
214
+ def create_custom_chart(self, *args, **kwargs):
215
+ return self.api.create_custom_chart(*args, **kwargs)
216
+
214
217
  def update_launch_agent_status(self, *args, **kwargs):
215
218
  return self.api.update_launch_agent_status(*args, **kwargs)
216
219
 
wandb/apis/paginator.py CHANGED
@@ -13,6 +13,8 @@ from typing import (
13
13
  overload,
14
14
  )
15
15
 
16
+ import wandb
17
+
16
18
  if TYPE_CHECKING:
17
19
  from wandb_graphql.language.ast import Document
18
20
 
@@ -112,14 +114,25 @@ class Paginator(Iterator[T]):
112
114
  class SizedPaginator(Paginator[T], Sized):
113
115
  """A Paginator for objects with a known total count."""
114
116
 
117
+ @property
118
+ def length(self) -> int | None:
119
+ wandb.termwarn(
120
+ (
121
+ "`.length` is deprecated and will be removed in a future version. "
122
+ "Use `len(...)` instead."
123
+ ),
124
+ repeat=False,
125
+ )
126
+ return len(self)
127
+
115
128
  def __len__(self) -> int:
116
- if self.length is None:
129
+ if self._length is None:
117
130
  self._load_page()
118
- if self.length is None:
131
+ if self._length is None:
119
132
  raise ValueError("Object doesn't provide length")
120
- return self.length
133
+ return self._length
121
134
 
122
135
  @property
123
136
  @abstractmethod
124
- def length(self) -> int | None:
137
+ def _length(self) -> int | None:
125
138
  raise NotImplementedError
wandb/apis/public/api.py CHANGED
@@ -40,9 +40,9 @@ from wandb._iterutils import one
40
40
  from wandb.apis import public
41
41
  from wandb.apis.normalize import normalize_exceptions
42
42
  from wandb.apis.public.const import RETRY_TIMEDELTA
43
+ from wandb.apis.public.registries._utils import fetch_org_entity_from_organization
43
44
  from wandb.apis.public.registries.registries_search import Registries
44
45
  from wandb.apis.public.registries.registry import Registry
45
- from wandb.apis.public.registries.utils import _fetch_org_entity_from_organization
46
46
  from wandb.apis.public.utils import (
47
47
  PathType,
48
48
  fetch_org_from_settings_or_entity,
@@ -467,6 +467,85 @@ class Api:
467
467
  _default_resource_config=config,
468
468
  )
469
469
 
470
+ def create_custom_chart(
471
+ self,
472
+ entity: str,
473
+ name: str,
474
+ display_name: str,
475
+ spec_type: Literal["vega2"],
476
+ access: Literal["private", "public"],
477
+ spec: Union[str, dict],
478
+ ) -> str:
479
+ """Create a custom chart preset and return its id.
480
+
481
+ Args:
482
+ entity: The entity (user or team) that owns the chart
483
+ name: Unique identifier for the chart preset
484
+ display_name: Human-readable name shown in the UI
485
+ spec_type: Type of specification. Must be "vega2" for Vega-Lite v2 specifications.
486
+ access: Access level for the chart:
487
+ - "private": Chart is only accessible to the entity that created it
488
+ - "public": Chart is publicly accessible
489
+ spec: The Vega/Vega-Lite specification as a dictionary or JSON string
490
+
491
+ Returns:
492
+ The ID of the created chart preset in the format "entity/name"
493
+
494
+ Raises:
495
+ wandb.Error: If chart creation fails
496
+ UnsupportedError: If the server doesn't support custom charts
497
+
498
+ Example:
499
+ ```python
500
+ import wandb
501
+
502
+ api = wandb.Api()
503
+
504
+ # Define a simple bar chart specification
505
+ vega_spec = {
506
+ "$schema": "https://vega.github.io/schema/vega-lite/v6.json",
507
+ "mark": "bar",
508
+ "data": {"name": "wandb"},
509
+ "encoding": {
510
+ "x": {"field": "${field:x}", "type": "ordinal"},
511
+ "y": {"field": "${field:y}", "type": "quantitative"},
512
+ },
513
+ }
514
+
515
+ # Create the custom chart
516
+ chart_id = api.create_custom_chart(
517
+ entity="my-team",
518
+ name="my-bar-chart",
519
+ display_name="My Custom Bar Chart",
520
+ spec_type="vega2",
521
+ access="private",
522
+ spec=vega_spec,
523
+ )
524
+
525
+ # Use with wandb.plot_table()
526
+ chart = wandb.plot_table(
527
+ vega_spec_name=chart_id,
528
+ data_table=my_table,
529
+ fields={"x": "category", "y": "value"},
530
+ )
531
+ ```
532
+ """
533
+ # Convert user-facing lowercase access to backend uppercase
534
+ backend_access = access.upper()
535
+
536
+ api = InternalApi(retry_timedelta=RETRY_TIMEDELTA)
537
+ result = api.create_custom_chart(
538
+ entity=entity,
539
+ name=name,
540
+ display_name=display_name,
541
+ spec_type=spec_type,
542
+ access=backend_access,
543
+ spec=spec,
544
+ )
545
+ if result is None or result.get("chart") is None:
546
+ raise wandb.Error("failed to create custom chart")
547
+ return result["chart"]["id"]
548
+
470
549
  def upsert_run_queue(
471
550
  self,
472
551
  name: str,
@@ -713,7 +792,7 @@ class Api:
713
792
  return public.BetaReport(
714
793
  self.client,
715
794
  {
716
- "display_name": urllib.parse.unquote(name.replace("-", " ")),
795
+ "displayName": urllib.parse.unquote(name.replace("-", " ")),
717
796
  "id": id,
718
797
  "spec": "{}",
719
798
  },
@@ -1514,7 +1593,9 @@ class Api:
1514
1593
 
1515
1594
  Find all collections in the registries with the name "my_collection" and the tag "my_tag"
1516
1595
  ```python
1517
- api.registries().collections(filter={"name": "my_collection", "tag": "my_tag"})
1596
+ api.registries().collections(
1597
+ filter={"name": "my_collection", "tag": "my_tag"}
1598
+ )
1518
1599
  ```
1519
1600
 
1520
1601
  Find all artifact versions in the registries with a collection name that contains "my_collection" and a version that has the alias "best"
@@ -1589,7 +1670,7 @@ class Api:
1589
1670
  organization = organization or fetch_org_from_settings_or_entity(
1590
1671
  self.settings, self.default_entity
1591
1672
  )
1592
- org_entity = _fetch_org_entity_from_organization(self.client, organization)
1673
+ org_entity = fetch_org_entity_from_organization(self.client, organization)
1593
1674
  registry = Registry(self.client, organization, org_entity, name)
1594
1675
  registry.load()
1595
1676
  return registry
@@ -100,7 +100,7 @@ class ArtifactTypes(Paginator["ArtifactType"]):
100
100
  self.last_response = ArtifactTypesFragment.model_validate(conn)
101
101
 
102
102
  @property
103
- def length(self) -> None:
103
+ def _length(self) -> None:
104
104
  # TODO
105
105
  return None
106
106
 
@@ -240,9 +240,9 @@ class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
240
240
  self.last_response = ArtifactCollectionsFragment.model_validate(conn)
241
241
 
242
242
  @property
243
- def length(self):
243
+ def _length(self) -> int:
244
244
  if self.last_response is None:
245
- return None
245
+ self._load_page()
246
246
  return self.last_response.total_count
247
247
 
248
248
  @property
@@ -608,9 +608,9 @@ class Artifacts(SizedPaginator["Artifact"]):
608
608
  self.last_response = ArtifactsFragment.model_validate(conn)
609
609
 
610
610
  @property
611
- def length(self) -> int | None:
611
+ def _length(self) -> int:
612
612
  if self.last_response is None:
613
- return None
613
+ self._load_page()
614
614
  return self.last_response.total_count
615
615
 
616
616
  @property
@@ -698,9 +698,9 @@ class RunArtifacts(SizedPaginator["Artifact"]):
698
698
  self.last_response = self._response_cls.model_validate(inner_data)
699
699
 
700
700
  @property
701
- def length(self) -> int | None:
701
+ def _length(self) -> int:
702
702
  if self.last_response is None:
703
- return None
703
+ self._load_page()
704
704
  return self.last_response.total_count
705
705
 
706
706
  @property
@@ -799,7 +799,9 @@ class ArtifactFiles(SizedPaginator["public.File"]):
799
799
  return [self.artifact.entity, self.artifact.project, self.artifact.name]
800
800
 
801
801
  @property
802
- def length(self) -> int:
802
+ def _length(self) -> int:
803
+ if self.last_response is None:
804
+ self._load_page()
803
805
  return self.artifact.file_count
804
806
 
805
807
  @property
@@ -72,11 +72,11 @@ class Files(SizedPaginator["File"]):
72
72
  super().__init__(client, variables, per_page)
73
73
 
74
74
  @property
75
- def length(self):
76
- if self.last_response:
77
- return self.last_response["project"]["run"]["fileCount"]
78
- else:
79
- return None
75
+ def _length(self):
76
+ if not self.last_response:
77
+ self._load_page()
78
+
79
+ return self.last_response["project"]["run"]["fileCount"]
80
80
 
81
81
  @property
82
82
  def more(self):
@@ -9,6 +9,7 @@ from wandb.apis import public
9
9
  from wandb.apis.attrs import Attrs
10
10
  from wandb.apis.normalize import normalize_exceptions
11
11
  from wandb.apis.paginator import Paginator
12
+ from wandb.apis.public.api import RetryingClient
12
13
  from wandb.sdk.lib import ipython
13
14
 
14
15
  PROJECT_FRAGMENT = """fragment ProjectFragment on Project {
@@ -43,7 +44,19 @@ class Projects(Paginator["Project"]):
43
44
  """.format(PROJECT_FRAGMENT)
44
45
  )
45
46
 
46
- def __init__(self, client, entity, per_page=50):
47
+ def __init__(
48
+ self,
49
+ client: RetryingClient,
50
+ entity: str,
51
+ per_page: int = 50,
52
+ ) -> "Projects":
53
+ """An iterable collection of `Project` objects.
54
+
55
+ Args:
56
+ client: The API client used to query W&B.
57
+ entity: The entity which owns the projects.
58
+ per_page: The number of projects to fetch per request to the API.
59
+ """
47
60
  self.client = client
48
61
  self.entity = entity
49
62
  variables = {
@@ -83,7 +96,31 @@ class Projects(Paginator["Project"]):
83
96
  class Project(Attrs):
84
97
  """A project is a namespace for runs."""
85
98
 
86
- def __init__(self, client, entity, project, attrs):
99
+ QUERY = gql(
100
+ """
101
+ query Project($project: String!, $entity: String!) {
102
+ project(name: $project, entityName: $entity) {
103
+ id
104
+ }
105
+ }
106
+ """
107
+ )
108
+
109
+ def __init__(
110
+ self,
111
+ client: RetryingClient,
112
+ entity: str,
113
+ project: str,
114
+ attrs: dict,
115
+ ) -> "Project":
116
+ """A single project associated with an entity.
117
+
118
+ Args:
119
+ client: The API client used to query W&B.
120
+ entity: The entity which owns the project.
121
+ project: The name of the project to query.
122
+ attrs: The attributes of the project.
123
+ """
87
124
  super().__init__(dict(attrs))
88
125
  self.client = client
89
126
  self.name = project
@@ -143,7 +180,7 @@ class Project(Attrs):
143
180
  )
144
181
  variable_values = {"project": self.name, "entity": self.entity}
145
182
  ret = self.client.execute(query, variable_values)
146
- if ret["project"]["totalSweeps"] < 1:
183
+ if not ret.get("project") or ret["project"]["totalSweeps"] < 1:
147
184
  return []
148
185
 
149
186
  return [
@@ -178,6 +215,10 @@ class Project(Attrs):
178
215
  variable_values = {"projectName": self.name, "entityName": self.entity}
179
216
  try:
180
217
  data = self.client.execute(self._PROJECT_ID, variable_values)
218
+
219
+ if not data.get("project") or not data["project"].get("id"):
220
+ raise ValueError(f"Project {self.name} not found")
221
+
181
222
  self._attrs["id"] = data["project"]["id"]
182
223
  return self._attrs["id"]
183
224
  except (HTTPError, LookupError, TypeError) as e:
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
13
13
  from wandb_gql import gql
14
14
 
15
15
 
16
- class _Visibility(str, Enum):
16
+ class Visibility(str, Enum):
17
17
  # names are what users see/pass into Python methods
18
18
  # values are what's expected by backend API
19
19
  organization = "PRIVATE"
@@ -27,7 +27,7 @@ class _Visibility(str, Enum):
27
27
  )
28
28
 
29
29
 
30
- def _format_gql_artifact_types_input(
30
+ def format_gql_artifact_types_input(
31
31
  artifact_types: Optional[List[str]] = None,
32
32
  ):
33
33
  """Format the artifact types for the GQL input.
@@ -44,7 +44,7 @@ def _format_gql_artifact_types_input(
44
44
  return [{"name": type} for type in new_types]
45
45
 
46
46
 
47
- def _gql_to_registry_visibility(
47
+ def gql_to_registry_visibility(
48
48
  visibility: str,
49
49
  ) -> Literal["organization", "restricted"]:
50
50
  """Convert the GQL visibility to the registry visibility.
@@ -56,25 +56,25 @@ def _gql_to_registry_visibility(
56
56
  The registry visibility.
57
57
  """
58
58
  try:
59
- return _Visibility(visibility).name
59
+ return Visibility(visibility).name
60
60
  except ValueError:
61
61
  raise ValueError(f"Invalid visibility: {visibility!r} from backend")
62
62
 
63
63
 
64
- def _registry_visibility_to_gql(
64
+ def registry_visibility_to_gql(
65
65
  visibility: Literal["organization", "restricted"],
66
66
  ) -> str:
67
67
  """Convert the registry visibility to the GQL visibility."""
68
68
  try:
69
- return _Visibility[visibility].value
69
+ return Visibility[visibility].value
70
70
  except KeyError:
71
71
  raise ValueError(
72
72
  f"Invalid visibility: {visibility!r}. "
73
- f"Must be one of: {', '.join(map(repr, (e.name for e in _Visibility)))}"
73
+ f"Must be one of: {', '.join(map(repr, (e.name for e in Visibility)))}"
74
74
  )
75
75
 
76
76
 
77
- def _ensure_registry_prefix_on_names(query, in_name=False):
77
+ def ensure_registry_prefix_on_names(query, in_name=False):
78
78
  """Traverse the filter to prepend the `name` key value with the registry prefix unless the value is a regex.
79
79
 
80
80
  - in_name: True if we are under a "name" key (or propagating from one).
@@ -89,23 +89,23 @@ def _ensure_registry_prefix_on_names(query, in_name=False):
89
89
  new_dict = {}
90
90
  for key, obj in dct.items():
91
91
  if key == "name":
92
- new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=True)
92
+ new_dict[key] = ensure_registry_prefix_on_names(obj, in_name=True)
93
93
  elif key == "$regex":
94
94
  # For regex operator, we skip transformation of its value.
95
95
  new_dict[key] = obj
96
96
  else:
97
97
  # For any other key, propagate the in_name and skip_transform flags as-is.
98
- new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=in_name)
98
+ new_dict[key] = ensure_registry_prefix_on_names(obj, in_name=in_name)
99
99
  return new_dict
100
100
  if isinstance((objs := query), Sequence):
101
101
  return list(
102
- map(lambda x: _ensure_registry_prefix_on_names(x, in_name=in_name), objs)
102
+ map(lambda x: ensure_registry_prefix_on_names(x, in_name=in_name), objs)
103
103
  )
104
104
  return query
105
105
 
106
106
 
107
107
  @lru_cache(maxsize=10)
108
- def _fetch_org_entity_from_organization(client: "Client", organization: str) -> str:
108
+ def fetch_org_entity_from_organization(client: "Client", organization: str) -> str:
109
109
  """Fetch the org entity from the organization.
110
110
 
111
111
  Args:
@@ -15,7 +15,7 @@ from wandb.sdk.artifacts._graphql_fragments import (
15
15
  _gql_registry_fragment,
16
16
  )
17
17
 
18
- from .utils import _ensure_registry_prefix_on_names
18
+ from ._utils import ensure_registry_prefix_on_names
19
19
 
20
20
 
21
21
  class Registries(Paginator):
@@ -54,7 +54,7 @@ class Registries(Paginator):
54
54
  ):
55
55
  self.client = client
56
56
  self.organization = organization
57
- self.filter = _ensure_registry_prefix_on_names(filter or {})
57
+ self.filter = ensure_registry_prefix_on_names(filter or {})
58
58
  variables = {
59
59
  "organization": organization,
60
60
  "filters": json.dumps(self.filter),
@@ -3,26 +3,27 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
3
3
  from wandb_gql import gql
4
4
 
5
5
  import wandb
6
- from wandb.apis.public.registries._freezable_list import AddOnlyArtifactTypesList
7
- from wandb.apis.public.registries.registries_search import Collections, Versions
8
- from wandb.apis.public.registries.utils import (
9
- _fetch_org_entity_from_organization,
10
- _format_gql_artifact_types_input,
11
- _gql_to_registry_visibility,
12
- _registry_visibility_to_gql,
13
- )
14
6
  from wandb.proto.wandb_internal_pb2 import ServerFeature
15
7
  from wandb.sdk.artifacts._validators import REGISTRY_PREFIX, validate_project_name
16
8
  from wandb.sdk.internal.internal_api import Api as InternalApi
17
- from wandb.sdk.projects._generated.delete_project import DeleteProject
18
- from wandb.sdk.projects._generated.operations import (
9
+ from wandb.sdk.projects._generated import (
19
10
  DELETE_PROJECT_GQL,
20
11
  FETCH_REGISTRY_GQL,
21
12
  RENAME_PROJECT_GQL,
22
13
  UPSERT_REGISTRY_PROJECT_GQL,
14
+ DeleteProject,
15
+ RenameProject,
16
+ UpsertRegistryProject,
17
+ )
18
+
19
+ from ._freezable_list import AddOnlyArtifactTypesList
20
+ from ._utils import (
21
+ fetch_org_entity_from_organization,
22
+ format_gql_artifact_types_input,
23
+ gql_to_registry_visibility,
24
+ registry_visibility_to_gql,
23
25
  )
24
- from wandb.sdk.projects._generated.rename_project import RenameProject
25
- from wandb.sdk.projects._generated.upsert_registry_project import UpsertRegistryProject
26
+ from .registries_search import Collections, Versions
26
27
 
27
28
  if TYPE_CHECKING:
28
29
  from wandb_gql import Client
@@ -62,7 +63,7 @@ class Registry:
62
63
  )
63
64
  self._created_at = attrs.get("createdAt", "")
64
65
  self._updated_at = attrs.get("updatedAt", "")
65
- self._visibility = _gql_to_registry_visibility(attrs.get("access", ""))
66
+ self._visibility = gql_to_registry_visibility(attrs.get("access", ""))
66
67
 
67
68
  @property
68
69
  def full_name(self) -> str:
@@ -220,13 +221,13 @@ class Registry:
220
221
  ValueError: If a registry with the same name already exists in the
221
222
  organization or if the creation fails.
222
223
  """
223
- org_entity = _fetch_org_entity_from_organization(client, organization)
224
+ org_entity = fetch_org_entity_from_organization(client, organization)
224
225
  full_name = REGISTRY_PREFIX + name
225
226
  validate_project_name(full_name)
226
227
  accepted_artifact_types = []
227
228
  if artifact_types:
228
- accepted_artifact_types = _format_gql_artifact_types_input(artifact_types)
229
- visibility_value = _registry_visibility_to_gql(visibility)
229
+ accepted_artifact_types = format_gql_artifact_types_input(artifact_types)
230
+ visibility_value = registry_visibility_to_gql(visibility)
230
231
  registry_creation_error = (
231
232
  f"Failed to create registry {name!r} in organization {organization!r}."
232
233
  )
@@ -310,8 +311,8 @@ class Registry:
310
311
  )
311
312
 
312
313
  validate_project_name(self.full_name)
313
- visibility_value = _registry_visibility_to_gql(self.visibility)
314
- newly_added_types = _format_gql_artifact_types_input(self.artifact_types.draft)
314
+ visibility_value = registry_visibility_to_gql(self.visibility)
315
+ newly_added_types = format_gql_artifact_types_input(self.artifact_types.draft)
315
316
  registry_save_error = f"Failed to save and update registry: {self.name} in organization: {self.organization}"
316
317
  full_saved_name = f"{REGISTRY_PREFIX}{self._saved_name}"
317
318
  try: