wandb 0.20.1rc20250604__py3-none-musllinux_1_2_aarch64.whl → 0.21.0__py3-none-musllinux_1_2_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +3 -6
- wandb/__init__.pyi +24 -23
- wandb/analytics/sentry.py +2 -2
- wandb/apis/importers/internals/internal.py +0 -3
- wandb/apis/internal.py +3 -0
- wandb/apis/paginator.py +17 -4
- wandb/apis/public/api.py +85 -4
- wandb/apis/public/artifacts.py +10 -8
- wandb/apis/public/files.py +5 -5
- wandb/apis/public/projects.py +44 -3
- wandb/apis/public/registries/{utils.py → _utils.py} +12 -12
- wandb/apis/public/registries/registries_search.py +2 -2
- wandb/apis/public/registries/registry.py +19 -18
- wandb/apis/public/reports.py +64 -8
- wandb/apis/public/runs.py +16 -23
- wandb/automations/__init__.py +10 -10
- wandb/automations/_filters/run_metrics.py +0 -2
- wandb/automations/_utils.py +0 -2
- wandb/automations/actions.py +0 -2
- wandb/automations/automations.py +0 -2
- wandb/automations/events.py +0 -2
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +1 -7
- wandb/cli/cli.py +0 -30
- wandb/env.py +0 -6
- wandb/integration/catboost/catboost.py +6 -2
- wandb/integration/kfp/kfp_patch.py +3 -1
- wandb/integration/sb3/sb3.py +3 -3
- wandb/integration/ultralytics/callback.py +6 -2
- wandb/plot/__init__.py +2 -0
- wandb/plot/bar.py +30 -29
- wandb/plot/confusion_matrix.py +75 -71
- wandb/plot/histogram.py +26 -25
- wandb/plot/line.py +33 -32
- wandb/plot/line_series.py +100 -103
- wandb/plot/pr_curve.py +33 -32
- wandb/plot/roc_curve.py +38 -38
- wandb/plot/scatter.py +27 -27
- wandb/proto/v3/wandb_internal_pb2.py +366 -385
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +352 -356
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +352 -356
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v6/wandb_internal_pb2.py +352 -356
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/_generated/__init__.py +12 -1
- wandb/sdk/artifacts/_generated/input_types.py +20 -2
- wandb/sdk/artifacts/_generated/link_artifact.py +21 -0
- wandb/sdk/artifacts/_generated/operations.py +9 -0
- wandb/sdk/artifacts/_validators.py +40 -2
- wandb/sdk/artifacts/artifact.py +163 -21
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +42 -1
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/base_types/media.py +9 -7
- wandb/sdk/data_types/base_types/wb_value.py +6 -6
- wandb/sdk/data_types/saved_model.py +3 -3
- wandb/sdk/data_types/table.py +41 -41
- wandb/sdk/data_types/trace_tree.py +12 -12
- wandb/sdk/interface/interface.py +8 -19
- wandb/sdk/interface/interface_shared.py +7 -16
- wandb/sdk/internal/datastore.py +18 -18
- wandb/sdk/internal/handler.py +4 -74
- wandb/sdk/internal/internal_api.py +54 -0
- wandb/sdk/internal/sender.py +23 -3
- wandb/sdk/internal/sender_config.py +9 -0
- wandb/sdk/launch/_project_spec.py +3 -3
- wandb/sdk/launch/agent/agent.py +3 -3
- wandb/sdk/launch/agent/job_status_tracker.py +3 -1
- wandb/sdk/launch/utils.py +3 -3
- wandb/sdk/lib/console_capture.py +66 -19
- wandb/sdk/lib/printer.py +6 -7
- wandb/sdk/lib/progress.py +1 -3
- wandb/sdk/lib/service/ipc_support.py +13 -0
- wandb/sdk/lib/{service_connection.py → service/service_connection.py} +20 -56
- wandb/sdk/lib/service/service_port_file.py +105 -0
- wandb/sdk/lib/service/service_process.py +111 -0
- wandb/sdk/lib/service/service_token.py +164 -0
- wandb/sdk/lib/sock_client.py +8 -12
- wandb/sdk/wandb_init.py +1 -5
- wandb/sdk/wandb_require.py +9 -21
- wandb/sdk/wandb_run.py +23 -137
- wandb/sdk/wandb_settings.py +233 -80
- wandb/sdk/wandb_setup.py +2 -13
- {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/METADATA +1 -3
- {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/RECORD +94 -120
- wandb/sdk/internal/flow_control.py +0 -263
- wandb/sdk/internal/internal.py +0 -401
- wandb/sdk/internal/internal_util.py +0 -97
- wandb/sdk/internal/system/__init__.py +0 -0
- wandb/sdk/internal/system/assets/__init__.py +0 -25
- wandb/sdk/internal/system/assets/aggregators.py +0 -31
- wandb/sdk/internal/system/assets/asset_registry.py +0 -20
- wandb/sdk/internal/system/assets/cpu.py +0 -163
- wandb/sdk/internal/system/assets/disk.py +0 -210
- wandb/sdk/internal/system/assets/gpu.py +0 -416
- wandb/sdk/internal/system/assets/gpu_amd.py +0 -233
- wandb/sdk/internal/system/assets/interfaces.py +0 -205
- wandb/sdk/internal/system/assets/ipu.py +0 -177
- wandb/sdk/internal/system/assets/memory.py +0 -166
- wandb/sdk/internal/system/assets/network.py +0 -125
- wandb/sdk/internal/system/assets/open_metrics.py +0 -293
- wandb/sdk/internal/system/assets/tpu.py +0 -154
- wandb/sdk/internal/system/assets/trainium.py +0 -393
- wandb/sdk/internal/system/env_probe_helpers.py +0 -13
- wandb/sdk/internal/system/system_info.py +0 -248
- wandb/sdk/internal/system/system_monitor.py +0 -224
- wandb/sdk/internal/writer.py +0 -204
- wandb/sdk/lib/service_token.py +0 -93
- wandb/sdk/service/__init__.py +0 -0
- wandb/sdk/service/_startup_debug.py +0 -22
- wandb/sdk/service/port_file.py +0 -53
- wandb/sdk/service/server.py +0 -107
- wandb/sdk/service/server_sock.py +0 -286
- wandb/sdk/service/service.py +0 -252
- wandb/sdk/service/streams.py +0 -425
- wandb/sdk/wandb_metadata.py +0 -623
- {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/WHEEL +0 -0
- {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/licenses/LICENSE +0 -0
wandb/__init__.py
CHANGED
@@ -10,7 +10,7 @@ For reference documentation, see https://docs.wandb.com/ref/python.
|
|
10
10
|
"""
|
11
11
|
from __future__ import annotations
|
12
12
|
|
13
|
-
__version__ = "0.
|
13
|
+
__version__ = "0.21.0"
|
14
14
|
|
15
15
|
|
16
16
|
from wandb.errors import Error
|
@@ -30,9 +30,9 @@ wandb.wandb_lib = wandb_sdk.lib # type: ignore
|
|
30
30
|
|
31
31
|
init = wandb_sdk.init
|
32
32
|
setup = wandb_sdk.setup
|
33
|
-
_attach = wandb_sdk._attach
|
33
|
+
attach = _attach = wandb_sdk._attach
|
34
34
|
_sync = wandb_sdk._sync
|
35
|
-
_teardown = wandb_sdk.teardown
|
35
|
+
teardown = _teardown = wandb_sdk.teardown
|
36
36
|
finish = wandb_sdk.finish
|
37
37
|
join = finish
|
38
38
|
login = wandb_sdk.login
|
@@ -51,9 +51,6 @@ from wandb.errors import CommError, UsageError
|
|
51
51
|
_preinit = wandb.wandb_lib.preinit # type: ignore
|
52
52
|
_lazyloader = wandb.wandb_lib.lazyloader # type: ignore
|
53
53
|
|
54
|
-
# Call import module hook to set up any needed require hooks
|
55
|
-
wandb.sdk.wandb_require._import_module_hook()
|
56
|
-
|
57
54
|
from wandb.integration.torch import wandb_torch
|
58
55
|
|
59
56
|
# Move this (keras.__init__ expects it at top level)
|
wandb/__init__.pyi
CHANGED
@@ -12,20 +12,20 @@ For reference documentation, see https://docs.wandb.com/ref/python.
|
|
12
12
|
from __future__ import annotations
|
13
13
|
|
14
14
|
__all__ = (
|
15
|
-
"__version__",
|
15
|
+
"__version__", # doc:exclude
|
16
16
|
"init",
|
17
17
|
"finish",
|
18
18
|
"setup",
|
19
19
|
"login",
|
20
|
-
"save",
|
20
|
+
"save", # doc:exclude
|
21
21
|
"sweep",
|
22
22
|
"controller",
|
23
23
|
"agent",
|
24
|
-
"config",
|
25
|
-
"log",
|
26
|
-
"summary",
|
24
|
+
"config", # doc:exclude
|
25
|
+
"log", # doc:exclude
|
26
|
+
"summary", # doc:exclude
|
27
27
|
"Api",
|
28
|
-
"Graph",
|
28
|
+
"Graph", # doc:exclude
|
29
29
|
"Image",
|
30
30
|
"Plotly",
|
31
31
|
"Video",
|
@@ -36,26 +36,27 @@ __all__ = (
|
|
36
36
|
"Object3D",
|
37
37
|
"Molecule",
|
38
38
|
"Histogram",
|
39
|
-
"ArtifactTTL",
|
40
|
-
"log_artifact",
|
41
|
-
"use_artifact",
|
42
|
-
"log_model",
|
43
|
-
"use_model",
|
44
|
-
"link_model",
|
45
|
-
"define_metric",
|
46
|
-
"Error",
|
47
|
-
"termsetup",
|
48
|
-
"termlog",
|
49
|
-
"termerror",
|
50
|
-
"termwarn",
|
39
|
+
"ArtifactTTL", # doc:exclude
|
40
|
+
"log_artifact", # doc:exclude
|
41
|
+
"use_artifact", # doc:exclude
|
42
|
+
"log_model", # doc:exclude
|
43
|
+
"use_model", # doc:exclude
|
44
|
+
"link_model", # doc:exclude
|
45
|
+
"define_metric", # doc:exclude
|
46
|
+
"Error", # doc:exclude
|
47
|
+
"termsetup", # doc:exclude
|
48
|
+
"termlog", # doc:exclude
|
49
|
+
"termerror", # doc:exclude
|
50
|
+
"termwarn", # doc:exclude
|
51
51
|
"Artifact",
|
52
52
|
"Settings",
|
53
53
|
"teardown",
|
54
|
-
"watch",
|
55
|
-
"unwatch",
|
56
|
-
"plot",
|
54
|
+
"watch", # doc:exclude
|
55
|
+
"unwatch", # doc:exclude
|
56
|
+
"plot", # doc:exclude
|
57
57
|
"plot_table",
|
58
58
|
"restore",
|
59
|
+
"Run",
|
59
60
|
)
|
60
61
|
|
61
62
|
import os
|
@@ -106,7 +107,7 @@ if TYPE_CHECKING:
|
|
106
107
|
import wandb
|
107
108
|
from wandb.plot import CustomChart
|
108
109
|
|
109
|
-
__version__: str = "0.
|
110
|
+
__version__: str = "0.21.0"
|
110
111
|
|
111
112
|
run: Run | None
|
112
113
|
config: wandb_config.Config
|
@@ -325,7 +326,7 @@ def init(
|
|
325
326
|
the UI.
|
326
327
|
If resuming a run, the tags provided here will replace any existing
|
327
328
|
tags. To add tags to a resumed run without overwriting the current
|
328
|
-
tags, use `run.tags +=
|
329
|
+
tags, use `run.tags += ("new_tag",)` after calling `run = wandb.init()`.
|
329
330
|
config: Sets `wandb.config`, a dictionary-like object for storing input
|
330
331
|
parameters to your run, such as model hyperparameters or data
|
331
332
|
preprocessing settings.
|
wandb/analytics/sentry.py
CHANGED
@@ -15,6 +15,7 @@ from urllib.parse import quote
|
|
15
15
|
import sentry_sdk # type: ignore
|
16
16
|
import sentry_sdk.scope # type: ignore
|
17
17
|
import sentry_sdk.utils # type: ignore
|
18
|
+
from typing_extensions import Never
|
18
19
|
|
19
20
|
import wandb
|
20
21
|
import wandb.env
|
@@ -143,7 +144,7 @@ class Sentry:
|
|
143
144
|
|
144
145
|
return event_id
|
145
146
|
|
146
|
-
def reraise(self, exc: Any) ->
|
147
|
+
def reraise(self, exc: Any) -> Never:
|
147
148
|
"""Re-raise an exception after logging it to Sentry.
|
148
149
|
|
149
150
|
Use this for top-level exceptions when you want the user to see the traceback.
|
@@ -209,7 +210,6 @@ class Sentry:
|
|
209
210
|
"sweep_url",
|
210
211
|
"sweep_id",
|
211
212
|
"deployment",
|
212
|
-
"x_require_legacy_service",
|
213
213
|
"launch",
|
214
214
|
"_platform",
|
215
215
|
)
|
@@ -364,9 +364,6 @@ def send_run(
|
|
364
364
|
sm = AlternateSendManager(
|
365
365
|
settings, sm_record_q, result_q, interface, context_keeper
|
366
366
|
)
|
367
|
-
# wm = WriteManager(
|
368
|
-
# settings, wm_record_q, result_q, sm_record_q, interface, context_keeper
|
369
|
-
# )
|
370
367
|
|
371
368
|
if extra_arts or extra_used_arts:
|
372
369
|
records = rm.make_artifacts_only_records(extra_arts, extra_used_arts)
|
wandb/apis/internal.py
CHANGED
@@ -211,6 +211,9 @@ class Api:
|
|
211
211
|
def upsert_run_queue(self, *args, **kwargs):
|
212
212
|
return self.api.upsert_run_queue(*args, **kwargs)
|
213
213
|
|
214
|
+
def create_custom_chart(self, *args, **kwargs):
|
215
|
+
return self.api.create_custom_chart(*args, **kwargs)
|
216
|
+
|
214
217
|
def update_launch_agent_status(self, *args, **kwargs):
|
215
218
|
return self.api.update_launch_agent_status(*args, **kwargs)
|
216
219
|
|
wandb/apis/paginator.py
CHANGED
@@ -13,6 +13,8 @@ from typing import (
|
|
13
13
|
overload,
|
14
14
|
)
|
15
15
|
|
16
|
+
import wandb
|
17
|
+
|
16
18
|
if TYPE_CHECKING:
|
17
19
|
from wandb_graphql.language.ast import Document
|
18
20
|
|
@@ -112,14 +114,25 @@ class Paginator(Iterator[T]):
|
|
112
114
|
class SizedPaginator(Paginator[T], Sized):
|
113
115
|
"""A Paginator for objects with a known total count."""
|
114
116
|
|
117
|
+
@property
|
118
|
+
def length(self) -> int | None:
|
119
|
+
wandb.termwarn(
|
120
|
+
(
|
121
|
+
"`.length` is deprecated and will be removed in a future version. "
|
122
|
+
"Use `len(...)` instead."
|
123
|
+
),
|
124
|
+
repeat=False,
|
125
|
+
)
|
126
|
+
return len(self)
|
127
|
+
|
115
128
|
def __len__(self) -> int:
|
116
|
-
if self.
|
129
|
+
if self._length is None:
|
117
130
|
self._load_page()
|
118
|
-
if self.
|
131
|
+
if self._length is None:
|
119
132
|
raise ValueError("Object doesn't provide length")
|
120
|
-
return self.
|
133
|
+
return self._length
|
121
134
|
|
122
135
|
@property
|
123
136
|
@abstractmethod
|
124
|
-
def
|
137
|
+
def _length(self) -> int | None:
|
125
138
|
raise NotImplementedError
|
wandb/apis/public/api.py
CHANGED
@@ -40,9 +40,9 @@ from wandb._iterutils import one
|
|
40
40
|
from wandb.apis import public
|
41
41
|
from wandb.apis.normalize import normalize_exceptions
|
42
42
|
from wandb.apis.public.const import RETRY_TIMEDELTA
|
43
|
+
from wandb.apis.public.registries._utils import fetch_org_entity_from_organization
|
43
44
|
from wandb.apis.public.registries.registries_search import Registries
|
44
45
|
from wandb.apis.public.registries.registry import Registry
|
45
|
-
from wandb.apis.public.registries.utils import _fetch_org_entity_from_organization
|
46
46
|
from wandb.apis.public.utils import (
|
47
47
|
PathType,
|
48
48
|
fetch_org_from_settings_or_entity,
|
@@ -467,6 +467,85 @@ class Api:
|
|
467
467
|
_default_resource_config=config,
|
468
468
|
)
|
469
469
|
|
470
|
+
def create_custom_chart(
|
471
|
+
self,
|
472
|
+
entity: str,
|
473
|
+
name: str,
|
474
|
+
display_name: str,
|
475
|
+
spec_type: Literal["vega2"],
|
476
|
+
access: Literal["private", "public"],
|
477
|
+
spec: Union[str, dict],
|
478
|
+
) -> str:
|
479
|
+
"""Create a custom chart preset and return its id.
|
480
|
+
|
481
|
+
Args:
|
482
|
+
entity: The entity (user or team) that owns the chart
|
483
|
+
name: Unique identifier for the chart preset
|
484
|
+
display_name: Human-readable name shown in the UI
|
485
|
+
spec_type: Type of specification. Must be "vega2" for Vega-Lite v2 specifications.
|
486
|
+
access: Access level for the chart:
|
487
|
+
- "private": Chart is only accessible to the entity that created it
|
488
|
+
- "public": Chart is publicly accessible
|
489
|
+
spec: The Vega/Vega-Lite specification as a dictionary or JSON string
|
490
|
+
|
491
|
+
Returns:
|
492
|
+
The ID of the created chart preset in the format "entity/name"
|
493
|
+
|
494
|
+
Raises:
|
495
|
+
wandb.Error: If chart creation fails
|
496
|
+
UnsupportedError: If the server doesn't support custom charts
|
497
|
+
|
498
|
+
Example:
|
499
|
+
```python
|
500
|
+
import wandb
|
501
|
+
|
502
|
+
api = wandb.Api()
|
503
|
+
|
504
|
+
# Define a simple bar chart specification
|
505
|
+
vega_spec = {
|
506
|
+
"$schema": "https://vega.github.io/schema/vega-lite/v6.json",
|
507
|
+
"mark": "bar",
|
508
|
+
"data": {"name": "wandb"},
|
509
|
+
"encoding": {
|
510
|
+
"x": {"field": "${field:x}", "type": "ordinal"},
|
511
|
+
"y": {"field": "${field:y}", "type": "quantitative"},
|
512
|
+
},
|
513
|
+
}
|
514
|
+
|
515
|
+
# Create the custom chart
|
516
|
+
chart_id = api.create_custom_chart(
|
517
|
+
entity="my-team",
|
518
|
+
name="my-bar-chart",
|
519
|
+
display_name="My Custom Bar Chart",
|
520
|
+
spec_type="vega2",
|
521
|
+
access="private",
|
522
|
+
spec=vega_spec,
|
523
|
+
)
|
524
|
+
|
525
|
+
# Use with wandb.plot_table()
|
526
|
+
chart = wandb.plot_table(
|
527
|
+
vega_spec_name=chart_id,
|
528
|
+
data_table=my_table,
|
529
|
+
fields={"x": "category", "y": "value"},
|
530
|
+
)
|
531
|
+
```
|
532
|
+
"""
|
533
|
+
# Convert user-facing lowercase access to backend uppercase
|
534
|
+
backend_access = access.upper()
|
535
|
+
|
536
|
+
api = InternalApi(retry_timedelta=RETRY_TIMEDELTA)
|
537
|
+
result = api.create_custom_chart(
|
538
|
+
entity=entity,
|
539
|
+
name=name,
|
540
|
+
display_name=display_name,
|
541
|
+
spec_type=spec_type,
|
542
|
+
access=backend_access,
|
543
|
+
spec=spec,
|
544
|
+
)
|
545
|
+
if result is None or result.get("chart") is None:
|
546
|
+
raise wandb.Error("failed to create custom chart")
|
547
|
+
return result["chart"]["id"]
|
548
|
+
|
470
549
|
def upsert_run_queue(
|
471
550
|
self,
|
472
551
|
name: str,
|
@@ -713,7 +792,7 @@ class Api:
|
|
713
792
|
return public.BetaReport(
|
714
793
|
self.client,
|
715
794
|
{
|
716
|
-
"
|
795
|
+
"displayName": urllib.parse.unquote(name.replace("-", " ")),
|
717
796
|
"id": id,
|
718
797
|
"spec": "{}",
|
719
798
|
},
|
@@ -1514,7 +1593,9 @@ class Api:
|
|
1514
1593
|
|
1515
1594
|
Find all collections in the registries with the name "my_collection" and the tag "my_tag"
|
1516
1595
|
```python
|
1517
|
-
api.registries().collections(
|
1596
|
+
api.registries().collections(
|
1597
|
+
filter={"name": "my_collection", "tag": "my_tag"}
|
1598
|
+
)
|
1518
1599
|
```
|
1519
1600
|
|
1520
1601
|
Find all artifact versions in the registries with a collection name that contains "my_collection" and a version that has the alias "best"
|
@@ -1589,7 +1670,7 @@ class Api:
|
|
1589
1670
|
organization = organization or fetch_org_from_settings_or_entity(
|
1590
1671
|
self.settings, self.default_entity
|
1591
1672
|
)
|
1592
|
-
org_entity = _fetch_org_entity_from_organization(self.client, organization)
|
1673
|
+
org_entity = fetch_org_entity_from_organization(self.client, organization)
|
1593
1674
|
registry = Registry(self.client, organization, org_entity, name)
|
1594
1675
|
registry.load()
|
1595
1676
|
return registry
|
wandb/apis/public/artifacts.py
CHANGED
@@ -100,7 +100,7 @@ class ArtifactTypes(Paginator["ArtifactType"]):
|
|
100
100
|
self.last_response = ArtifactTypesFragment.model_validate(conn)
|
101
101
|
|
102
102
|
@property
|
103
|
-
def
|
103
|
+
def _length(self) -> None:
|
104
104
|
# TODO
|
105
105
|
return None
|
106
106
|
|
@@ -240,9 +240,9 @@ class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
|
|
240
240
|
self.last_response = ArtifactCollectionsFragment.model_validate(conn)
|
241
241
|
|
242
242
|
@property
|
243
|
-
def
|
243
|
+
def _length(self) -> int:
|
244
244
|
if self.last_response is None:
|
245
|
-
|
245
|
+
self._load_page()
|
246
246
|
return self.last_response.total_count
|
247
247
|
|
248
248
|
@property
|
@@ -608,9 +608,9 @@ class Artifacts(SizedPaginator["Artifact"]):
|
|
608
608
|
self.last_response = ArtifactsFragment.model_validate(conn)
|
609
609
|
|
610
610
|
@property
|
611
|
-
def
|
611
|
+
def _length(self) -> int:
|
612
612
|
if self.last_response is None:
|
613
|
-
|
613
|
+
self._load_page()
|
614
614
|
return self.last_response.total_count
|
615
615
|
|
616
616
|
@property
|
@@ -698,9 +698,9 @@ class RunArtifacts(SizedPaginator["Artifact"]):
|
|
698
698
|
self.last_response = self._response_cls.model_validate(inner_data)
|
699
699
|
|
700
700
|
@property
|
701
|
-
def
|
701
|
+
def _length(self) -> int:
|
702
702
|
if self.last_response is None:
|
703
|
-
|
703
|
+
self._load_page()
|
704
704
|
return self.last_response.total_count
|
705
705
|
|
706
706
|
@property
|
@@ -799,7 +799,9 @@ class ArtifactFiles(SizedPaginator["public.File"]):
|
|
799
799
|
return [self.artifact.entity, self.artifact.project, self.artifact.name]
|
800
800
|
|
801
801
|
@property
|
802
|
-
def
|
802
|
+
def _length(self) -> int:
|
803
|
+
if self.last_response is None:
|
804
|
+
self._load_page()
|
803
805
|
return self.artifact.file_count
|
804
806
|
|
805
807
|
@property
|
wandb/apis/public/files.py
CHANGED
@@ -72,11 +72,11 @@ class Files(SizedPaginator["File"]):
|
|
72
72
|
super().__init__(client, variables, per_page)
|
73
73
|
|
74
74
|
@property
|
75
|
-
def
|
76
|
-
if self.last_response:
|
77
|
-
|
78
|
-
|
79
|
-
|
75
|
+
def _length(self):
|
76
|
+
if not self.last_response:
|
77
|
+
self._load_page()
|
78
|
+
|
79
|
+
return self.last_response["project"]["run"]["fileCount"]
|
80
80
|
|
81
81
|
@property
|
82
82
|
def more(self):
|
wandb/apis/public/projects.py
CHANGED
@@ -9,6 +9,7 @@ from wandb.apis import public
|
|
9
9
|
from wandb.apis.attrs import Attrs
|
10
10
|
from wandb.apis.normalize import normalize_exceptions
|
11
11
|
from wandb.apis.paginator import Paginator
|
12
|
+
from wandb.apis.public.api import RetryingClient
|
12
13
|
from wandb.sdk.lib import ipython
|
13
14
|
|
14
15
|
PROJECT_FRAGMENT = """fragment ProjectFragment on Project {
|
@@ -43,7 +44,19 @@ class Projects(Paginator["Project"]):
|
|
43
44
|
""".format(PROJECT_FRAGMENT)
|
44
45
|
)
|
45
46
|
|
46
|
-
def __init__(
|
47
|
+
def __init__(
|
48
|
+
self,
|
49
|
+
client: RetryingClient,
|
50
|
+
entity: str,
|
51
|
+
per_page: int = 50,
|
52
|
+
) -> "Projects":
|
53
|
+
"""An iterable collection of `Project` objects.
|
54
|
+
|
55
|
+
Args:
|
56
|
+
client: The API client used to query W&B.
|
57
|
+
entity: The entity which owns the projects.
|
58
|
+
per_page: The number of projects to fetch per request to the API.
|
59
|
+
"""
|
47
60
|
self.client = client
|
48
61
|
self.entity = entity
|
49
62
|
variables = {
|
@@ -83,7 +96,31 @@ class Projects(Paginator["Project"]):
|
|
83
96
|
class Project(Attrs):
|
84
97
|
"""A project is a namespace for runs."""
|
85
98
|
|
86
|
-
|
99
|
+
QUERY = gql(
|
100
|
+
"""
|
101
|
+
query Project($project: String!, $entity: String!) {
|
102
|
+
project(name: $project, entityName: $entity) {
|
103
|
+
id
|
104
|
+
}
|
105
|
+
}
|
106
|
+
"""
|
107
|
+
)
|
108
|
+
|
109
|
+
def __init__(
|
110
|
+
self,
|
111
|
+
client: RetryingClient,
|
112
|
+
entity: str,
|
113
|
+
project: str,
|
114
|
+
attrs: dict,
|
115
|
+
) -> "Project":
|
116
|
+
"""A single project associated with an entity.
|
117
|
+
|
118
|
+
Args:
|
119
|
+
client: The API client used to query W&B.
|
120
|
+
entity: The entity which owns the project.
|
121
|
+
project: The name of the project to query.
|
122
|
+
attrs: The attributes of the project.
|
123
|
+
"""
|
87
124
|
super().__init__(dict(attrs))
|
88
125
|
self.client = client
|
89
126
|
self.name = project
|
@@ -143,7 +180,7 @@ class Project(Attrs):
|
|
143
180
|
)
|
144
181
|
variable_values = {"project": self.name, "entity": self.entity}
|
145
182
|
ret = self.client.execute(query, variable_values)
|
146
|
-
if ret["project"]["totalSweeps"] < 1:
|
183
|
+
if not ret.get("project") or ret["project"]["totalSweeps"] < 1:
|
147
184
|
return []
|
148
185
|
|
149
186
|
return [
|
@@ -178,6 +215,10 @@ class Project(Attrs):
|
|
178
215
|
variable_values = {"projectName": self.name, "entityName": self.entity}
|
179
216
|
try:
|
180
217
|
data = self.client.execute(self._PROJECT_ID, variable_values)
|
218
|
+
|
219
|
+
if not data.get("project") or not data["project"].get("id"):
|
220
|
+
raise ValueError(f"Project {self.name} not found")
|
221
|
+
|
181
222
|
self._attrs["id"] = data["project"]["id"]
|
182
223
|
return self._attrs["id"]
|
183
224
|
except (HTTPError, LookupError, TypeError) as e:
|
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|
13
13
|
from wandb_gql import gql
|
14
14
|
|
15
15
|
|
16
|
-
class _Visibility(str, Enum):
|
16
|
+
class Visibility(str, Enum):
|
17
17
|
# names are what users see/pass into Python methods
|
18
18
|
# values are what's expected by backend API
|
19
19
|
organization = "PRIVATE"
|
@@ -27,7 +27,7 @@ class _Visibility(str, Enum):
|
|
27
27
|
)
|
28
28
|
|
29
29
|
|
30
|
-
def _format_gql_artifact_types_input(
|
30
|
+
def format_gql_artifact_types_input(
|
31
31
|
artifact_types: Optional[List[str]] = None,
|
32
32
|
):
|
33
33
|
"""Format the artifact types for the GQL input.
|
@@ -44,7 +44,7 @@ def _format_gql_artifact_types_input(
|
|
44
44
|
return [{"name": type} for type in new_types]
|
45
45
|
|
46
46
|
|
47
|
-
def _gql_to_registry_visibility(
|
47
|
+
def gql_to_registry_visibility(
|
48
48
|
visibility: str,
|
49
49
|
) -> Literal["organization", "restricted"]:
|
50
50
|
"""Convert the GQL visibility to the registry visibility.
|
@@ -56,25 +56,25 @@ def _gql_to_registry_visibility(
|
|
56
56
|
The registry visibility.
|
57
57
|
"""
|
58
58
|
try:
|
59
|
-
return _Visibility(visibility).name
|
59
|
+
return Visibility(visibility).name
|
60
60
|
except ValueError:
|
61
61
|
raise ValueError(f"Invalid visibility: {visibility!r} from backend")
|
62
62
|
|
63
63
|
|
64
|
-
def _registry_visibility_to_gql(
|
64
|
+
def registry_visibility_to_gql(
|
65
65
|
visibility: Literal["organization", "restricted"],
|
66
66
|
) -> str:
|
67
67
|
"""Convert the registry visibility to the GQL visibility."""
|
68
68
|
try:
|
69
|
-
return _Visibility[visibility].value
|
69
|
+
return Visibility[visibility].value
|
70
70
|
except KeyError:
|
71
71
|
raise ValueError(
|
72
72
|
f"Invalid visibility: {visibility!r}. "
|
73
|
-
f"Must be one of: {', '.join(map(repr, (e.name for e in _Visibility)))}"
|
73
|
+
f"Must be one of: {', '.join(map(repr, (e.name for e in Visibility)))}"
|
74
74
|
)
|
75
75
|
|
76
76
|
|
77
|
-
def _ensure_registry_prefix_on_names(query, in_name=False):
|
77
|
+
def ensure_registry_prefix_on_names(query, in_name=False):
|
78
78
|
"""Traverse the filter to prepend the `name` key value with the registry prefix unless the value is a regex.
|
79
79
|
|
80
80
|
- in_name: True if we are under a "name" key (or propagating from one).
|
@@ -89,23 +89,23 @@ def _ensure_registry_prefix_on_names(query, in_name=False):
|
|
89
89
|
new_dict = {}
|
90
90
|
for key, obj in dct.items():
|
91
91
|
if key == "name":
|
92
|
-
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=True)
|
92
|
+
new_dict[key] = ensure_registry_prefix_on_names(obj, in_name=True)
|
93
93
|
elif key == "$regex":
|
94
94
|
# For regex operator, we skip transformation of its value.
|
95
95
|
new_dict[key] = obj
|
96
96
|
else:
|
97
97
|
# For any other key, propagate the in_name and skip_transform flags as-is.
|
98
|
-
new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=in_name)
|
98
|
+
new_dict[key] = ensure_registry_prefix_on_names(obj, in_name=in_name)
|
99
99
|
return new_dict
|
100
100
|
if isinstance((objs := query), Sequence):
|
101
101
|
return list(
|
102
|
-
map(lambda x: _ensure_registry_prefix_on_names(x, in_name=in_name), objs)
|
102
|
+
map(lambda x: ensure_registry_prefix_on_names(x, in_name=in_name), objs)
|
103
103
|
)
|
104
104
|
return query
|
105
105
|
|
106
106
|
|
107
107
|
@lru_cache(maxsize=10)
|
108
|
-
def _fetch_org_entity_from_organization(client: "Client", organization: str) -> str:
|
108
|
+
def fetch_org_entity_from_organization(client: "Client", organization: str) -> str:
|
109
109
|
"""Fetch the org entity from the organization.
|
110
110
|
|
111
111
|
Args:
|
@@ -15,7 +15,7 @@ from wandb.sdk.artifacts._graphql_fragments import (
|
|
15
15
|
_gql_registry_fragment,
|
16
16
|
)
|
17
17
|
|
18
|
-
from .utils import _ensure_registry_prefix_on_names
|
18
|
+
from ._utils import ensure_registry_prefix_on_names
|
19
19
|
|
20
20
|
|
21
21
|
class Registries(Paginator):
|
@@ -54,7 +54,7 @@ class Registries(Paginator):
|
|
54
54
|
):
|
55
55
|
self.client = client
|
56
56
|
self.organization = organization
|
57
|
-
self.filter = _ensure_registry_prefix_on_names(filter or {})
|
57
|
+
self.filter = ensure_registry_prefix_on_names(filter or {})
|
58
58
|
variables = {
|
59
59
|
"organization": organization,
|
60
60
|
"filters": json.dumps(self.filter),
|
@@ -3,26 +3,27 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
|
3
3
|
from wandb_gql import gql
|
4
4
|
|
5
5
|
import wandb
|
6
|
-
from wandb.apis.public.registries._freezable_list import AddOnlyArtifactTypesList
|
7
|
-
from wandb.apis.public.registries.registries_search import Collections, Versions
|
8
|
-
from wandb.apis.public.registries.utils import (
|
9
|
-
_fetch_org_entity_from_organization,
|
10
|
-
_format_gql_artifact_types_input,
|
11
|
-
_gql_to_registry_visibility,
|
12
|
-
_registry_visibility_to_gql,
|
13
|
-
)
|
14
6
|
from wandb.proto.wandb_internal_pb2 import ServerFeature
|
15
7
|
from wandb.sdk.artifacts._validators import REGISTRY_PREFIX, validate_project_name
|
16
8
|
from wandb.sdk.internal.internal_api import Api as InternalApi
|
17
|
-
from wandb.sdk.projects._generated
|
18
|
-
from wandb.sdk.projects._generated.operations import (
|
9
|
+
from wandb.sdk.projects._generated import (
|
19
10
|
DELETE_PROJECT_GQL,
|
20
11
|
FETCH_REGISTRY_GQL,
|
21
12
|
RENAME_PROJECT_GQL,
|
22
13
|
UPSERT_REGISTRY_PROJECT_GQL,
|
14
|
+
DeleteProject,
|
15
|
+
RenameProject,
|
16
|
+
UpsertRegistryProject,
|
17
|
+
)
|
18
|
+
|
19
|
+
from ._freezable_list import AddOnlyArtifactTypesList
|
20
|
+
from ._utils import (
|
21
|
+
fetch_org_entity_from_organization,
|
22
|
+
format_gql_artifact_types_input,
|
23
|
+
gql_to_registry_visibility,
|
24
|
+
registry_visibility_to_gql,
|
23
25
|
)
|
24
|
-
from
|
25
|
-
from wandb.sdk.projects._generated.upsert_registry_project import UpsertRegistryProject
|
26
|
+
from .registries_search import Collections, Versions
|
26
27
|
|
27
28
|
if TYPE_CHECKING:
|
28
29
|
from wandb_gql import Client
|
@@ -62,7 +63,7 @@ class Registry:
|
|
62
63
|
)
|
63
64
|
self._created_at = attrs.get("createdAt", "")
|
64
65
|
self._updated_at = attrs.get("updatedAt", "")
|
65
|
-
self._visibility = _gql_to_registry_visibility(attrs.get("access", ""))
|
66
|
+
self._visibility = gql_to_registry_visibility(attrs.get("access", ""))
|
66
67
|
|
67
68
|
@property
|
68
69
|
def full_name(self) -> str:
|
@@ -220,13 +221,13 @@ class Registry:
|
|
220
221
|
ValueError: If a registry with the same name already exists in the
|
221
222
|
organization or if the creation fails.
|
222
223
|
"""
|
223
|
-
org_entity = _fetch_org_entity_from_organization(client, organization)
|
224
|
+
org_entity = fetch_org_entity_from_organization(client, organization)
|
224
225
|
full_name = REGISTRY_PREFIX + name
|
225
226
|
validate_project_name(full_name)
|
226
227
|
accepted_artifact_types = []
|
227
228
|
if artifact_types:
|
228
|
-
accepted_artifact_types = _format_gql_artifact_types_input(artifact_types)
|
229
|
-
visibility_value = _registry_visibility_to_gql(visibility)
|
229
|
+
accepted_artifact_types = format_gql_artifact_types_input(artifact_types)
|
230
|
+
visibility_value = registry_visibility_to_gql(visibility)
|
230
231
|
registry_creation_error = (
|
231
232
|
f"Failed to create registry {name!r} in organization {organization!r}."
|
232
233
|
)
|
@@ -310,8 +311,8 @@ class Registry:
|
|
310
311
|
)
|
311
312
|
|
312
313
|
validate_project_name(self.full_name)
|
313
|
-
visibility_value = _registry_visibility_to_gql(self.visibility)
|
314
|
-
newly_added_types = _format_gql_artifact_types_input(self.artifact_types.draft)
|
314
|
+
visibility_value = registry_visibility_to_gql(self.visibility)
|
315
|
+
newly_added_types = format_gql_artifact_types_input(self.artifact_types.draft)
|
315
316
|
registry_save_error = f"Failed to save and update registry: {self.name} in organization: {self.organization}"
|
316
317
|
full_saved_name = f"{REGISTRY_PREFIX}{self._saved_name}"
|
317
318
|
try:
|