wandb 0.22.0__py3-none-win32.whl → 0.22.1__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +3 -3
- wandb/_pydantic/__init__.py +12 -11
- wandb/_pydantic/base.py +49 -19
- wandb/apis/__init__.py +2 -0
- wandb/apis/attrs.py +2 -0
- wandb/apis/importers/internals/internal.py +16 -23
- wandb/apis/internal.py +2 -0
- wandb/apis/normalize.py +2 -0
- wandb/apis/public/__init__.py +3 -2
- wandb/apis/public/api.py +215 -164
- wandb/apis/public/artifacts.py +23 -20
- wandb/apis/public/const.py +2 -0
- wandb/apis/public/files.py +33 -24
- wandb/apis/public/history.py +2 -0
- wandb/apis/public/jobs.py +20 -18
- wandb/apis/public/projects.py +4 -2
- wandb/apis/public/query_generator.py +3 -0
- wandb/apis/public/registries/__init__.py +7 -0
- wandb/apis/public/registries/_freezable_list.py +9 -12
- wandb/apis/public/registries/registries_search.py +8 -6
- wandb/apis/public/registries/registry.py +22 -17
- wandb/apis/public/reports.py +2 -0
- wandb/apis/public/runs.py +261 -57
- wandb/apis/public/sweeps.py +10 -9
- wandb/apis/public/teams.py +2 -0
- wandb/apis/public/users.py +2 -0
- wandb/apis/public/utils.py +16 -15
- wandb/automations/_generated/__init__.py +54 -127
- wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
- wandb/automations/_generated/fragments.py +26 -91
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta_sync.py +9 -11
- wandb/errors/errors.py +3 -3
- wandb/proto/v3/wandb_sync_pb2.py +19 -6
- wandb/proto/v4/wandb_sync_pb2.py +10 -6
- wandb/proto/v5/wandb_sync_pb2.py +10 -6
- wandb/proto/v6/wandb_sync_pb2.py +10 -6
- wandb/sdk/artifacts/_factories.py +7 -2
- wandb/sdk/artifacts/_generated/__init__.py +112 -412
- wandb/sdk/artifacts/_generated/fragments.py +65 -0
- wandb/sdk/artifacts/_generated/operations.py +52 -22
- wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
- wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
- wandb/sdk/artifacts/_generated/type_info.py +19 -0
- wandb/sdk/artifacts/_gqlutils.py +47 -0
- wandb/sdk/artifacts/_models/__init__.py +4 -0
- wandb/sdk/artifacts/_models/base_model.py +20 -0
- wandb/sdk/artifacts/_validators.py +40 -12
- wandb/sdk/artifacts/artifact.py +69 -88
- wandb/sdk/artifacts/artifact_file_cache.py +6 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +61 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +10 -0
- wandb/sdk/data_types/bokeh.py +5 -1
- wandb/sdk/data_types/image.py +17 -6
- wandb/sdk/interface/interface.py +31 -4
- wandb/sdk/interface/interface_queue.py +10 -0
- wandb/sdk/interface/interface_shared.py +0 -7
- wandb/sdk/interface/interface_sock.py +9 -3
- wandb/sdk/internal/_generated/__init__.py +2 -12
- wandb/sdk/internal/sender.py +1 -1
- wandb/sdk/internal/settings_static.py +2 -82
- wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
- wandb/sdk/launch/utils.py +82 -1
- wandb/sdk/lib/progress.py +7 -4
- wandb/sdk/lib/service/service_client.py +5 -9
- wandb/sdk/lib/service/service_connection.py +39 -23
- wandb/sdk/mailbox/mailbox_handle.py +2 -0
- wandb/sdk/projects/_generated/__init__.py +12 -33
- wandb/sdk/wandb_init.py +22 -2
- wandb/sdk/wandb_login.py +53 -27
- wandb/sdk/wandb_run.py +5 -3
- wandb/sdk/wandb_settings.py +50 -13
- wandb/sync/sync.py +7 -2
- wandb/util.py +1 -1
- {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/METADATA +1 -1
- {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/RECORD +81 -78
- wandb/sdk/artifacts/_graphql_fragments.py +0 -19
- {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/WHEEL +0 -0
- {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/licenses/LICENSE +0 -0
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any
|
|
6
6
|
from typing_extensions import override
|
7
7
|
|
8
8
|
from wandb.proto import wandb_server_pb2 as spb
|
9
|
+
from wandb.sdk.lib import asyncio_manager
|
9
10
|
|
10
11
|
from .interface_shared import InterfaceShared
|
11
12
|
|
@@ -21,10 +22,12 @@ logger = logging.getLogger("wandb")
|
|
21
22
|
class InterfaceSock(InterfaceShared):
|
22
23
|
def __init__(
|
23
24
|
self,
|
25
|
+
asyncer: asyncio_manager.AsyncioManager,
|
24
26
|
client: ServiceClient,
|
25
27
|
stream_id: str,
|
26
28
|
) -> None:
|
27
29
|
super().__init__()
|
30
|
+
self._asyncer = asyncer
|
28
31
|
self._client = client
|
29
32
|
self._stream_id = stream_id
|
30
33
|
|
@@ -37,13 +40,16 @@ class InterfaceSock(InterfaceShared):
|
|
37
40
|
self._assign(record)
|
38
41
|
request = spb.ServerRequest()
|
39
42
|
request.record_publish.CopyFrom(record)
|
40
|
-
self._client.publish(request)
|
43
|
+
self._asyncer.run(lambda: self._client.publish(request))
|
41
44
|
|
42
|
-
@override
|
43
45
|
def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]:
|
46
|
+
return self._asyncer.run(lambda: self.deliver_async(record))
|
47
|
+
|
48
|
+
@override
|
49
|
+
async def deliver_async(self, record: pb.Record) -> MailboxHandle[pb.Result]:
|
44
50
|
self._assign(record)
|
45
51
|
request = spb.ServerRequest()
|
46
52
|
request.record_publish.CopyFrom(record)
|
47
53
|
|
48
|
-
handle = self._client.deliver(request)
|
54
|
+
handle = await self._client.deliver(request)
|
49
55
|
return handle.map(lambda response: response.result_communicate)
|
@@ -1,15 +1,5 @@
|
|
1
1
|
# Generated by ariadne-codegen
|
2
2
|
|
3
|
+
__all__ = ["SERVER_FEATURES_QUERY_GQL", "ServerFeaturesQuery"]
|
3
4
|
from .operations import SERVER_FEATURES_QUERY_GQL
|
4
|
-
from .server_features_query import
|
5
|
-
ServerFeaturesQuery,
|
6
|
-
ServerFeaturesQueryServerInfo,
|
7
|
-
ServerFeaturesQueryServerInfoFeatures,
|
8
|
-
)
|
9
|
-
|
10
|
-
__all__ = [
|
11
|
-
"SERVER_FEATURES_QUERY_GQL",
|
12
|
-
"ServerFeaturesQuery",
|
13
|
-
"ServerFeaturesQueryServerInfo",
|
14
|
-
"ServerFeaturesQueryServerInfoFeatures",
|
15
|
-
]
|
5
|
+
from .server_features_query import ServerFeaturesQuery
|
wandb/sdk/internal/sender.py
CHANGED
@@ -343,7 +343,7 @@ class SendManager:
|
|
343
343
|
publish_interface = InterfaceQueue(record_q=record_q)
|
344
344
|
context_keeper = context.ContextKeeper()
|
345
345
|
return SendManager(
|
346
|
-
settings=SettingsStatic(settings
|
346
|
+
settings=SettingsStatic(dict(settings)),
|
347
347
|
record_q=record_q,
|
348
348
|
result_q=result_q,
|
349
349
|
interface=publish_interface,
|
@@ -2,9 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from typing import Any, Iterable
|
4
4
|
|
5
|
-
from wandb.
|
6
|
-
from wandb.sdk.lib import RunMoment
|
7
|
-
from wandb.sdk.wandb_settings import CLIENT_ONLY_SETTINGS, Settings
|
5
|
+
from wandb.sdk.wandb_settings import Settings
|
8
6
|
|
9
7
|
|
10
8
|
class SettingsStatic(Settings):
|
@@ -14,87 +12,9 @@ class SettingsStatic(Settings):
|
|
14
12
|
attributes or items.
|
15
13
|
"""
|
16
14
|
|
17
|
-
def __init__(self,
|
18
|
-
data = self._proto_to_dict(proto)
|
15
|
+
def __init__(self, data: dict[str, Any]) -> None:
|
19
16
|
super().__init__(**data)
|
20
17
|
|
21
|
-
def _proto_to_dict(self, proto: wandb_settings_pb2.Settings) -> dict:
|
22
|
-
data = {}
|
23
|
-
|
24
|
-
exclude_fields = {
|
25
|
-
"model_config",
|
26
|
-
"model_fields",
|
27
|
-
"model_fields_set",
|
28
|
-
"__fields__",
|
29
|
-
"__model_fields_set",
|
30
|
-
"__pydantic_self__",
|
31
|
-
"__pydantic_initialised__",
|
32
|
-
}
|
33
|
-
|
34
|
-
fields = (
|
35
|
-
Settings.model_fields
|
36
|
-
if hasattr(Settings, "model_fields")
|
37
|
-
else Settings.__fields__
|
38
|
-
) # type: ignore [attr-defined]
|
39
|
-
|
40
|
-
fields = {k: v for k, v in fields.items() if k not in exclude_fields} # type: ignore [union-attr]
|
41
|
-
|
42
|
-
forks_specified: list[str] = []
|
43
|
-
for key in fields:
|
44
|
-
if key in CLIENT_ONLY_SETTINGS:
|
45
|
-
continue
|
46
|
-
|
47
|
-
value: Any = None
|
48
|
-
|
49
|
-
field_info = fields[key]
|
50
|
-
annotation = str(field_info.annotation)
|
51
|
-
|
52
|
-
if key == "_stats_open_metrics_filters":
|
53
|
-
# todo: it's an underscored field, refactor into
|
54
|
-
# something more elegant?
|
55
|
-
# I'm really about this. It's ugly, but it works.
|
56
|
-
# Do not try to repeat this at home.
|
57
|
-
value_type = getattr(proto, key).WhichOneof("value")
|
58
|
-
if value_type == "sequence":
|
59
|
-
value = list(getattr(proto, key).sequence.value)
|
60
|
-
elif value_type == "mapping":
|
61
|
-
unpacked_mapping = {}
|
62
|
-
for outer_key, outer_value in getattr(
|
63
|
-
proto, key
|
64
|
-
).mapping.value.items():
|
65
|
-
unpacked_inner = {}
|
66
|
-
for inner_key, inner_value in outer_value.value.items():
|
67
|
-
unpacked_inner[inner_key] = inner_value
|
68
|
-
unpacked_mapping[outer_key] = unpacked_inner
|
69
|
-
value = unpacked_mapping
|
70
|
-
elif key == "fork_from" or key == "resume_from":
|
71
|
-
value = getattr(proto, key)
|
72
|
-
if value.run:
|
73
|
-
value = RunMoment(
|
74
|
-
run=value.run, value=value.value, metric=value.metric
|
75
|
-
)
|
76
|
-
forks_specified.append(key)
|
77
|
-
else:
|
78
|
-
value = None
|
79
|
-
else:
|
80
|
-
if proto.HasField(key): # type: ignore [arg-type]
|
81
|
-
value = getattr(proto, key).value
|
82
|
-
# Convert to list if the field is a sequence
|
83
|
-
if any(t in annotation for t in ("tuple", "Sequence", "list")):
|
84
|
-
value = list(value)
|
85
|
-
else:
|
86
|
-
value = None
|
87
|
-
|
88
|
-
if value is not None:
|
89
|
-
data[key] = value
|
90
|
-
|
91
|
-
if len(forks_specified) > 1:
|
92
|
-
raise ValueError(
|
93
|
-
"Only one of fork_from or resume_from can be specified, not both"
|
94
|
-
)
|
95
|
-
|
96
|
-
return data
|
97
|
-
|
98
18
|
def __setattr__(self, name: str, value: object) -> None:
|
99
19
|
raise AttributeError("Error: SettingsStatic is a readonly object")
|
100
20
|
|
@@ -27,7 +27,11 @@ from wandb.sdk.launch.runner.kubernetes_monitor import (
|
|
27
27
|
CustomResource,
|
28
28
|
LaunchKubernetesMonitor,
|
29
29
|
)
|
30
|
-
from wandb.sdk.launch.utils import
|
30
|
+
from wandb.sdk.launch.utils import (
|
31
|
+
recursive_macro_sub,
|
32
|
+
sanitize_identifiers_for_k8s,
|
33
|
+
yield_containers,
|
34
|
+
)
|
31
35
|
from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
|
32
36
|
from wandb.util import get_module
|
33
37
|
|
@@ -39,6 +43,7 @@ from ..utils import (
|
|
39
43
|
MAX_ENV_LENGTHS,
|
40
44
|
PROJECT_SYNCHRONOUS,
|
41
45
|
get_kube_context_and_api_client,
|
46
|
+
make_k8s_label_safe,
|
42
47
|
make_name_dns_safe,
|
43
48
|
)
|
44
49
|
from .abstract import AbstractRun, AbstractRunner
|
@@ -708,6 +713,7 @@ class KubernetesRunner(AbstractRunner):
|
|
708
713
|
api_key_secret: Optional["V1Secret"] = None,
|
709
714
|
wait_for_ready: bool = True,
|
710
715
|
wait_timeout: int = 300,
|
716
|
+
auxiliary_resource_label_value: Optional[str] = None,
|
711
717
|
) -> None:
|
712
718
|
"""Prepare a service for launch.
|
713
719
|
|
@@ -725,6 +731,10 @@ class KubernetesRunner(AbstractRunner):
|
|
725
731
|
config["metadata"].setdefault("labels", {})
|
726
732
|
config["metadata"]["labels"][WANDB_K8S_RUN_ID] = run_id
|
727
733
|
config["metadata"]["labels"]["wandb.ai/created-by"] = "launch-agent"
|
734
|
+
if auxiliary_resource_label_value:
|
735
|
+
config["metadata"]["labels"][WANDB_K8S_LABEL_AUXILIARY_RESOURCE] = (
|
736
|
+
auxiliary_resource_label_value
|
737
|
+
)
|
728
738
|
|
729
739
|
env_vars = launch_project.get_env_vars_dict(
|
730
740
|
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
@@ -734,6 +744,13 @@ class KubernetesRunner(AbstractRunner):
|
|
734
744
|
}
|
735
745
|
add_wandb_env(config, wandb_config_env)
|
736
746
|
|
747
|
+
if auxiliary_resource_label_value:
|
748
|
+
add_label_to_pods(
|
749
|
+
config,
|
750
|
+
WANDB_K8S_LABEL_AUXILIARY_RESOURCE,
|
751
|
+
auxiliary_resource_label_value,
|
752
|
+
)
|
753
|
+
|
737
754
|
if api_key_secret:
|
738
755
|
for cont in yield_containers(config):
|
739
756
|
env = cont.setdefault("env", [])
|
@@ -751,6 +768,8 @@ class KubernetesRunner(AbstractRunner):
|
|
751
768
|
cont["env"] = env
|
752
769
|
|
753
770
|
try:
|
771
|
+
sanitize_identifiers_for_k8s(config)
|
772
|
+
|
754
773
|
await kubernetes_asyncio.utils.create_from_dict(
|
755
774
|
api_client, config, namespace=namespace
|
756
775
|
)
|
@@ -924,6 +943,9 @@ class KubernetesRunner(AbstractRunner):
|
|
924
943
|
additional_services: List[Dict[str, Any]] = recursive_macro_sub(
|
925
944
|
launch_project.launch_spec.get("additional_services", []), update_dict
|
926
945
|
)
|
946
|
+
auxiliary_resource_label_value = make_k8s_label_safe(
|
947
|
+
f"aux-{launch_project.target_entity}-{launch_project.target_project}-{launch_project.run_id}"
|
948
|
+
)
|
927
949
|
if additional_services:
|
928
950
|
wandb.termlog(
|
929
951
|
f"{LOG_PREFIX}Creating additional services: {additional_services}"
|
@@ -943,6 +965,7 @@ class KubernetesRunner(AbstractRunner):
|
|
943
965
|
secret,
|
944
966
|
wait_for_ready,
|
945
967
|
wait_timeout,
|
968
|
+
auxiliary_resource_label_value,
|
946
969
|
)
|
947
970
|
for resource in additional_services
|
948
971
|
if resource.get("config", {})
|
@@ -980,7 +1003,7 @@ class KubernetesRunner(AbstractRunner):
|
|
980
1003
|
job_name,
|
981
1004
|
namespace,
|
982
1005
|
secret,
|
983
|
-
|
1006
|
+
auxiliary_resource_label_value,
|
984
1007
|
)
|
985
1008
|
if self.backend_config[PROJECT_SYNCHRONOUS]:
|
986
1009
|
await submitted_job.wait()
|
@@ -1131,24 +1154,6 @@ async def maybe_create_imagepull_secret(
|
|
1131
1154
|
raise LaunchError(f"Exception when creating Kubernetes secret: {str(e)}\n")
|
1132
1155
|
|
1133
1156
|
|
1134
|
-
def yield_containers(root: Any) -> Iterator[dict]:
|
1135
|
-
"""Yield all container specs in a manifest.
|
1136
|
-
|
1137
|
-
Recursively traverses the manifest and yields all container specs. Container
|
1138
|
-
specs are identified by the presence of a "containers" key in the value.
|
1139
|
-
"""
|
1140
|
-
if isinstance(root, dict):
|
1141
|
-
for k, v in root.items():
|
1142
|
-
if k == "containers":
|
1143
|
-
if isinstance(v, list):
|
1144
|
-
yield from v
|
1145
|
-
elif isinstance(v, (dict, list)):
|
1146
|
-
yield from yield_containers(v)
|
1147
|
-
elif isinstance(root, list):
|
1148
|
-
for item in root:
|
1149
|
-
yield from yield_containers(item)
|
1150
|
-
|
1151
|
-
|
1152
1157
|
def add_wandb_env(root: Union[dict, list], env_vars: Dict[str, str]) -> None:
|
1153
1158
|
"""Injects wandb environment variables into specs.
|
1154
1159
|
|
wandb/sdk/launch/utils.py
CHANGED
@@ -7,7 +7,17 @@ import re
|
|
7
7
|
import subprocess
|
8
8
|
import sys
|
9
9
|
from collections import defaultdict
|
10
|
-
from typing import
|
10
|
+
from typing import (
|
11
|
+
TYPE_CHECKING,
|
12
|
+
Any,
|
13
|
+
Dict,
|
14
|
+
Iterator,
|
15
|
+
List,
|
16
|
+
Optional,
|
17
|
+
Tuple,
|
18
|
+
Union,
|
19
|
+
cast,
|
20
|
+
)
|
11
21
|
|
12
22
|
import click
|
13
23
|
|
@@ -599,6 +609,32 @@ def make_name_dns_safe(name: str) -> str:
|
|
599
609
|
return resp
|
600
610
|
|
601
611
|
|
612
|
+
def make_k8s_label_safe(value: str) -> str:
    """Coerce ``value`` into a Kubernetes DNS-1123 label.

    Reference:
        https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names

    The returned string contains only lowercase ASCII letters, digits and
    ``-``, begins and ends with an alphanumeric character, and is at most 63
    characters long. Underscores are turned into dashes, all other disallowed
    characters are dropped, and runs of dashes are collapsed. Because the
    result is truncated to 63 characters, distinct long inputs can map to the
    same label.

    Args:
        value: Arbitrary input string.

    Returns:
        The sanitized label.

    Raises:
        LaunchError: If no valid characters remain after sanitization.
    """
    permitted = frozenset("abcdefghijklmnopqrstuvwxyz0123456789-")

    # Treat "_" as a separator (kept as "-") before lower-casing the input.
    lowered = value.replace("_", "-").lower()
    filtered = "".join(ch for ch in lowered if ch in permitted)
    # A run of two or more dashes becomes a single dash.
    collapsed = re.sub(r"-{2,}", "-", filtered)
    # Truncate first, then strip, so the result still ends on an alphanumeric.
    label = collapsed[:63].strip("-")

    if label:
        return label
    raise LaunchError(f"Invalid value for Kubernetes label: {value}")
|
636
|
+
|
637
|
+
|
602
638
|
def warn_failed_packages_from_build_logs(
|
603
639
|
log: str, image_uri: str, api: Api, job_tracker: Optional["JobAndRunStatusTracker"]
|
604
640
|
) -> None:
|
@@ -744,3 +780,48 @@ def get_current_python_version() -> Tuple[str, str]:
|
|
744
780
|
major = full_version[0]
|
745
781
|
version = ".".join(full_version[:2]) if len(full_version) >= 2 else major + ".0"
|
746
782
|
return version, major
|
783
|
+
|
784
|
+
|
785
|
+
def yield_containers(root: Union[dict, list]) -> Iterator[dict]:
    """Walk a manifest depth-first and yield every container spec it holds.

    A list stored under a ``containers`` key is treated as a list of container
    specs: its items are yielded as-is and are not searched any further. A
    ``containers`` value that is not a list is ignored. Every other nested dict
    or list is searched recursively, in its natural iteration order.
    """
    if isinstance(root, list):
        for element in root:
            yield from yield_containers(element)
        return

    if not isinstance(root, dict):
        return

    for key, child in root.items():
        if key != "containers":
            if isinstance(child, (dict, list)):
                yield from yield_containers(child)
        elif isinstance(child, list):
            yield from child
|
801
|
+
|
802
|
+
|
803
|
+
def sanitize_identifiers_for_k8s(root: Any) -> None:
    """Sanitize Kubernetes object names in a manifest, in place.

    Recursively walks ``root`` (nested dicts and lists) and rewrites the
    following with ``make_k8s_label_safe``:

    - ``metadata.name`` of every dict that has a ``metadata`` mapping,
    - the ``name`` of every dict item in a ``containers`` list,
    - any other string value stored under a ``name`` key.

    Anything under an ``env`` key is deliberately left untouched. There, a
    ``name`` is an environment variable name (e.g. ``WANDB_API_KEY`` injected
    by the launch agent). Lower-casing it or turning ``_`` into ``-`` would
    change which variable the container actually sees.

    Args:
        root: A manifest or any fragment of one. Values that are neither a
            dict nor a list are ignored.

    Raises:
        LaunchError: Propagated from ``make_k8s_label_safe`` when a name has
            no characters that are valid in a DNS-1123 label.
    """
    if isinstance(root, list):
        for item in root:
            sanitize_identifiers_for_k8s(item)
        return

    # Only dicts have metadata and nested structures we need to sanitize.
    if not isinstance(root, dict):
        return

    metadata = root.get("metadata")
    if isinstance(metadata, dict) and (name := metadata.get("name")):
        # str() so that non-string names (e.g. a YAML integer) are coerced too.
        metadata["name"] = make_k8s_label_safe(str(name))

    # Handle container names at the level that owns the "containers" list;
    # deeper lists are reached by the recursion below, so the subtree does not
    # need to be re-walked at every level.
    containers = root.get("containers")
    if isinstance(containers, list):
        for container in containers:
            if isinstance(container, dict) and (name := container.get("name")):
                container["name"] = make_k8s_label_safe(str(name))

    # nested names
    for key, value in root.items():
        if key == "env":
            # Environment variable names are not Kubernetes object names.
            continue
        if isinstance(value, (dict, list)):
            sanitize_identifiers_for_k8s(value)
        elif key == "name" and isinstance(value, str):
            root[key] = make_k8s_label_safe(value)
|
wandb/sdk/lib/progress.py
CHANGED
@@ -45,7 +45,11 @@ async def loop_printing_operation_stats(
|
|
45
45
|
while True:
|
46
46
|
start_time = time.monotonic()
|
47
47
|
|
48
|
-
handle = interface.
|
48
|
+
handle = await interface.deliver_async(
|
49
|
+
pb.Record(
|
50
|
+
request=pb.Request(operations=pb.OperationStatsRequest()),
|
51
|
+
)
|
52
|
+
)
|
49
53
|
result = await handle.wait_async(timeout=None)
|
50
54
|
stats = result.response.operations_response.operation_stats
|
51
55
|
|
@@ -141,10 +145,9 @@ class ProgressPrinter:
|
|
141
145
|
if extra_operations > 0:
|
142
146
|
line += f" (+ {extra_operations} more)"
|
143
147
|
|
144
|
-
if line != self._last_printed_line:
|
148
|
+
if line and line != self._last_printed_line:
|
145
149
|
self._printer.display(line)
|
146
|
-
|
147
|
-
self._last_printed_line = line
|
150
|
+
self._last_printed_line = line
|
148
151
|
|
149
152
|
def _update_single_run(
|
150
153
|
self,
|
@@ -24,7 +24,6 @@ class ServiceClient:
|
|
24
24
|
reader: asyncio.StreamReader,
|
25
25
|
writer: asyncio.StreamWriter,
|
26
26
|
) -> None:
|
27
|
-
self._asyncer = asyncer
|
28
27
|
self._reader = reader
|
29
28
|
self._writer = writer
|
30
29
|
self._mailbox = Mailbox(asyncer)
|
@@ -34,11 +33,11 @@ class ServiceClient:
|
|
34
33
|
name="ServiceClient._forward_responses",
|
35
34
|
)
|
36
35
|
|
37
|
-
def publish(self, request: spb.ServerRequest) -> None:
|
36
|
+
async def publish(self, request: spb.ServerRequest) -> None:
|
38
37
|
"""Send a request without waiting for a response."""
|
39
|
-
|
38
|
+
await self._send_server_request(request)
|
40
39
|
|
41
|
-
def deliver(
|
40
|
+
async def deliver(
|
42
41
|
self,
|
43
42
|
request: spb.ServerRequest,
|
44
43
|
) -> MailboxHandle[spb.ServerResponse]:
|
@@ -52,7 +51,7 @@ class ServiceClient:
|
|
52
51
|
stopped due to an error.
|
53
52
|
"""
|
54
53
|
handle = self._mailbox.require_response(request)
|
55
|
-
|
54
|
+
await self._send_server_request(request)
|
56
55
|
return handle
|
57
56
|
|
58
57
|
async def _send_server_request(self, request: spb.ServerRequest) -> None:
|
@@ -64,11 +63,8 @@ class ServiceClient:
|
|
64
63
|
|
65
64
|
await self._writer.drain()
|
66
65
|
|
67
|
-
def close(self) -> None:
|
66
|
+
async def close(self) -> None:
|
68
67
|
"""Flush and close the socket."""
|
69
|
-
self._asyncer.run_soon(self._close)
|
70
|
-
|
71
|
-
async def _close(self) -> None:
|
72
68
|
self._writer.close()
|
73
69
|
await self._writer.wait_closed()
|
74
70
|
|
@@ -31,6 +31,7 @@ def connect_to_service(
|
|
31
31
|
|
32
32
|
if token:
|
33
33
|
return ServiceConnection(
|
34
|
+
asyncer=asyncer,
|
34
35
|
client=token.connect(asyncer=asyncer),
|
35
36
|
proc=None,
|
36
37
|
)
|
@@ -60,6 +61,7 @@ def _start_and_connect_service(
|
|
60
61
|
conn.teardown(hooks.exit_code)
|
61
62
|
|
62
63
|
conn = ServiceConnection(
|
64
|
+
asyncer=asyncer,
|
63
65
|
client=client,
|
64
66
|
proc=proc,
|
65
67
|
cleanup=lambda: atexit.unregister(teardown_atexit),
|
@@ -71,10 +73,14 @@ def _start_and_connect_service(
|
|
71
73
|
|
72
74
|
|
73
75
|
class ServiceConnection:
|
74
|
-
"""A connection to the W&B internal service process.
|
76
|
+
"""A connection to the W&B internal service process.
|
77
|
+
|
78
|
+
None of the synchronous methods may be called in an asyncio context.
|
79
|
+
"""
|
75
80
|
|
76
81
|
def __init__(
|
77
82
|
self,
|
83
|
+
asyncer: asyncio_manager.AsyncioManager,
|
78
84
|
client: ServiceClient,
|
79
85
|
proc: service_process.ServiceProcess | None,
|
80
86
|
cleanup: Callable[[], None] | None = None,
|
@@ -82,13 +88,12 @@ class ServiceConnection:
|
|
82
88
|
"""Returns a new ServiceConnection.
|
83
89
|
|
84
90
|
Args:
|
85
|
-
|
86
|
-
|
87
|
-
updates the mailbox.
|
88
|
-
client: A socket that's connected to the service.
|
91
|
+
asyncer: An asyncio runner.
|
92
|
+
client: A client for communicating with the service over a socket.
|
89
93
|
proc: The service process if we own it, or None otherwise.
|
90
94
|
cleanup: A callback to run on teardown before doing anything.
|
91
95
|
"""
|
96
|
+
self._asyncer = asyncer
|
92
97
|
self._client = client
|
93
98
|
self._proc = proc
|
94
99
|
self._torn_down = False
|
@@ -96,9 +101,13 @@ class ServiceConnection:
|
|
96
101
|
|
97
102
|
def make_interface(self, stream_id: str) -> InterfaceBase:
|
98
103
|
"""Returns an interface for communicating with the service."""
|
99
|
-
return InterfaceSock(
|
104
|
+
return InterfaceSock(
|
105
|
+
self._asyncer,
|
106
|
+
self._client,
|
107
|
+
stream_id=stream_id,
|
108
|
+
)
|
100
109
|
|
101
|
-
def init_sync(
|
110
|
+
async def init_sync(
|
102
111
|
self,
|
103
112
|
paths: set[pathlib.Path],
|
104
113
|
settings: wandb_settings.Settings,
|
@@ -110,10 +119,10 @@ class ServiceConnection:
|
|
110
119
|
)
|
111
120
|
request = spb.ServerRequest(init_sync=init_sync)
|
112
121
|
|
113
|
-
handle = self._client.deliver(request)
|
122
|
+
handle = await self._client.deliver(request)
|
114
123
|
return handle.map(lambda r: r.init_sync_response)
|
115
124
|
|
116
|
-
def sync(
|
125
|
+
async def sync(
|
117
126
|
self,
|
118
127
|
id: str,
|
119
128
|
*,
|
@@ -123,10 +132,10 @@ class ServiceConnection:
|
|
123
132
|
sync = wandb_sync_pb2.ServerSyncRequest(id=id, parallelism=parallelism)
|
124
133
|
request = spb.ServerRequest(sync=sync)
|
125
134
|
|
126
|
-
handle = self._client.deliver(request)
|
135
|
+
handle = await self._client.deliver(request)
|
127
136
|
return handle.map(lambda r: r.sync_response)
|
128
137
|
|
129
|
-
def sync_status(
|
138
|
+
async def sync_status(
|
130
139
|
self,
|
131
140
|
id: str,
|
132
141
|
) -> MailboxHandle[wandb_sync_pb2.ServerSyncStatusResponse]:
|
@@ -134,7 +143,7 @@ class ServiceConnection:
|
|
134
143
|
sync_status = wandb_sync_pb2.ServerSyncStatusRequest(id=id)
|
135
144
|
request = spb.ServerRequest(sync_status=sync_status)
|
136
145
|
|
137
|
-
handle = self._client.deliver(request)
|
146
|
+
handle = await self._client.deliver(request)
|
138
147
|
return handle.map(lambda r: r.sync_status_response)
|
139
148
|
|
140
149
|
def inform_init(
|
@@ -146,13 +155,17 @@ class ServiceConnection:
|
|
146
155
|
request = spb.ServerInformInitRequest()
|
147
156
|
request.settings.CopyFrom(settings)
|
148
157
|
request._info.stream_id = run_id
|
149
|
-
self.
|
158
|
+
self._asyncer.run(
|
159
|
+
lambda: self._client.publish(spb.ServerRequest(inform_init=request))
|
160
|
+
)
|
150
161
|
|
151
162
|
def inform_finish(self, run_id: str) -> None:
|
152
163
|
"""Send an finish request to the service."""
|
153
164
|
request = spb.ServerInformFinishRequest()
|
154
165
|
request._info.stream_id = run_id
|
155
|
-
self.
|
166
|
+
self._asyncer.run(
|
167
|
+
lambda: self._client.publish(spb.ServerRequest(inform_finish=request))
|
168
|
+
)
|
156
169
|
|
157
170
|
def inform_attach(
|
158
171
|
self,
|
@@ -166,7 +179,7 @@ class ServiceConnection:
|
|
166
179
|
request.inform_attach._info.stream_id = attach_id
|
167
180
|
|
168
181
|
try:
|
169
|
-
handle = self._client.deliver(request)
|
182
|
+
handle = self._asyncer.run(lambda: self._client.deliver(request))
|
170
183
|
response = handle.wait_or(timeout=10)
|
171
184
|
|
172
185
|
except (MailboxClosedError, HandleAbandonedError):
|
@@ -210,13 +223,16 @@ class ServiceConnection:
|
|
210
223
|
# Clear the service token to prevent new connections to the process.
|
211
224
|
service_token.clear_service_in_env()
|
212
225
|
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
226
|
+
async def publish_teardown_and_close() -> None:
|
227
|
+
await self._client.publish(
|
228
|
+
spb.ServerRequest(
|
229
|
+
inform_teardown=spb.ServerInformTeardownRequest(
|
230
|
+
exit_code=exit_code,
|
231
|
+
)
|
232
|
+
),
|
233
|
+
)
|
234
|
+
await self._client.close()
|
235
|
+
|
236
|
+
self._asyncer.run(publish_teardown_and_close)
|
221
237
|
|
222
238
|
return self._proc.join()
|
@@ -58,6 +58,8 @@ class MailboxHandle(abc.ABC, Generic[_T]):
|
|
58
58
|
def cancel(self, iface: interface.InterfaceBase) -> None:
|
59
59
|
"""Cancel the handle, requesting any associated work to not complete.
|
60
60
|
|
61
|
+
It is an error to call this from an async function.
|
62
|
+
|
61
63
|
This automatically abandons the handle, as a response is no longer
|
62
64
|
guaranteed.
|
63
65
|
|
@@ -1,47 +1,26 @@
|
|
1
1
|
# Generated by ariadne-codegen
|
2
2
|
|
3
|
-
from .delete_project import DeleteProject, DeleteProjectDeleteModel
|
4
|
-
from .fetch_registry import FetchRegistry, FetchRegistryEntity
|
5
|
-
from .fragments import (
|
6
|
-
RegistryFragment,
|
7
|
-
RegistryFragmentArtifactTypes,
|
8
|
-
RegistryFragmentArtifactTypesEdges,
|
9
|
-
RegistryFragmentArtifactTypesEdgesNode,
|
10
|
-
)
|
11
|
-
from .input_types import ArtifactTypeInput
|
12
|
-
from .operations import (
|
13
|
-
DELETE_PROJECT_GQL,
|
14
|
-
FETCH_REGISTRY_GQL,
|
15
|
-
RENAME_PROJECT_GQL,
|
16
|
-
UPSERT_REGISTRY_PROJECT_GQL,
|
17
|
-
)
|
18
|
-
from .rename_project import (
|
19
|
-
RenameProject,
|
20
|
-
RenameProjectRenameProject,
|
21
|
-
RenameProjectRenameProjectProject,
|
22
|
-
)
|
23
|
-
from .upsert_registry_project import (
|
24
|
-
UpsertRegistryProject,
|
25
|
-
UpsertRegistryProjectUpsertModel,
|
26
|
-
)
|
27
|
-
|
28
3
|
__all__ = [
|
29
4
|
"DELETE_PROJECT_GQL",
|
30
5
|
"FETCH_REGISTRY_GQL",
|
31
6
|
"RENAME_PROJECT_GQL",
|
32
7
|
"UPSERT_REGISTRY_PROJECT_GQL",
|
33
8
|
"FetchRegistry",
|
34
|
-
"FetchRegistryEntity",
|
35
9
|
"RenameProject",
|
36
|
-
"RenameProjectRenameProject",
|
37
|
-
"RenameProjectRenameProjectProject",
|
38
10
|
"UpsertRegistryProject",
|
39
|
-
"UpsertRegistryProjectUpsertModel",
|
40
11
|
"DeleteProject",
|
41
|
-
"DeleteProjectDeleteModel",
|
42
12
|
"ArtifactTypeInput",
|
43
13
|
"RegistryFragment",
|
44
|
-
"RegistryFragmentArtifactTypes",
|
45
|
-
"RegistryFragmentArtifactTypesEdges",
|
46
|
-
"RegistryFragmentArtifactTypesEdgesNode",
|
47
14
|
]
|
15
|
+
from .delete_project import DeleteProject
|
16
|
+
from .fetch_registry import FetchRegistry
|
17
|
+
from .fragments import RegistryFragment
|
18
|
+
from .input_types import ArtifactTypeInput
|
19
|
+
from .operations import (
|
20
|
+
DELETE_PROJECT_GQL,
|
21
|
+
FETCH_REGISTRY_GQL,
|
22
|
+
RENAME_PROJECT_GQL,
|
23
|
+
UPSERT_REGISTRY_PROJECT_GQL,
|
24
|
+
)
|
25
|
+
from .rename_project import RenameProject
|
26
|
+
from .upsert_registry_project import UpsertRegistryProject
|