wandb 0.21.1__py3-none-win_amd64.whl → 0.21.2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +1 -1
- wandb/apis/public/api.py +1 -2
- wandb/apis/public/artifacts.py +3 -5
- wandb/apis/public/registries/_utils.py +14 -16
- wandb/apis/public/registries/registries_search.py +176 -289
- wandb/apis/public/reports.py +13 -10
- wandb/automations/_generated/delete_automation.py +1 -3
- wandb/automations/_generated/enums.py +13 -11
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +47 -2
- wandb/integration/metaflow/data_pandas.py +2 -2
- wandb/integration/metaflow/data_pytorch.py +75 -0
- wandb/integration/metaflow/data_sklearn.py +76 -0
- wandb/integration/metaflow/metaflow.py +16 -87
- wandb/integration/weave/__init__.py +6 -0
- wandb/integration/weave/interface.py +49 -0
- wandb/integration/weave/weave.py +63 -0
- wandb/proto/v3/wandb_internal_pb2.py +3 -2
- wandb/proto/v4/wandb_internal_pb2.py +2 -2
- wandb/proto/v5/wandb_internal_pb2.py +2 -2
- wandb/proto/v6/wandb_internal_pb2.py +2 -2
- wandb/sdk/artifacts/_factories.py +17 -0
- wandb/sdk/artifacts/_generated/__init__.py +221 -13
- wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
- wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
- wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
- wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
- wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
- wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
- wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
- wandb/sdk/artifacts/_generated/enums.py +5 -0
- wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
- wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
- wandb/sdk/artifacts/_generated/fragments.py +279 -41
- wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
- wandb/sdk/artifacts/_generated/operations.py +654 -51
- wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
- wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
- wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
- wandb/sdk/artifacts/_graphql_fragments.py +3 -86
- wandb/sdk/artifacts/_validators.py +6 -4
- wandb/sdk/artifacts/artifact.py +406 -543
- wandb/sdk/artifacts/artifact_file_cache.py +10 -6
- wandb/sdk/artifacts/artifact_manifest.py +10 -9
- wandb/sdk/artifacts/artifact_manifest_entry.py +9 -10
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
- wandb/sdk/data_types/video.py +2 -2
- wandb/sdk/interface/interface_queue.py +1 -4
- wandb/sdk/interface/interface_shared.py +26 -37
- wandb/sdk/interface/interface_sock.py +24 -14
- wandb/sdk/internal/settings_static.py +2 -3
- wandb/sdk/launch/create_job.py +12 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +24 -29
- wandb/sdk/lib/asyncio_compat.py +16 -16
- wandb/sdk/lib/asyncio_manager.py +252 -0
- wandb/sdk/lib/hashutil.py +13 -4
- wandb/sdk/lib/printer.py +2 -2
- wandb/sdk/lib/printer_asyncio.py +3 -1
- wandb/sdk/lib/retry.py +185 -78
- wandb/sdk/lib/service/service_client.py +106 -0
- wandb/sdk/lib/service/service_connection.py +20 -26
- wandb/sdk/lib/service/service_token.py +30 -13
- wandb/sdk/mailbox/mailbox.py +13 -5
- wandb/sdk/mailbox/mailbox_handle.py +22 -13
- wandb/sdk/mailbox/response_handle.py +42 -106
- wandb/sdk/mailbox/wait_with_progress.py +7 -42
- wandb/sdk/wandb_init.py +11 -25
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_run.py +91 -55
- wandb/sdk/wandb_settings.py +45 -32
- wandb/sdk/wandb_setup.py +176 -96
- wandb/util.py +1 -1
- {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/METADATA +1 -1
- {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/RECORD +84 -68
- wandb/sdk/interface/interface_relay.py +0 -38
- wandb/sdk/interface/router.py +0 -89
- wandb/sdk/interface/router_queue.py +0 -43
- wandb/sdk/interface/router_relay.py +0 -50
- wandb/sdk/interface/router_sock.py +0 -32
- wandb/sdk/lib/sock_client.py +0 -232
- {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/WHEEL +0 -0
- {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/licenses/LICENSE +0 -0
@@ -7,7 +7,6 @@ import json
|
|
7
7
|
import logging
|
8
8
|
import os
|
9
9
|
import time
|
10
|
-
import uuid
|
11
10
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
12
11
|
|
13
12
|
import yaml
|
@@ -28,6 +27,7 @@ from wandb.sdk.launch.runner.kubernetes_monitor import (
|
|
28
27
|
CustomResource,
|
29
28
|
LaunchKubernetesMonitor,
|
30
29
|
)
|
30
|
+
from wandb.sdk.launch.utils import recursive_macro_sub
|
31
31
|
from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
|
32
32
|
from wandb.util import get_module
|
33
33
|
|
@@ -62,6 +62,9 @@ from kubernetes_asyncio.client.api.core_v1_api import ( # type: ignore # noqa:
|
|
62
62
|
from kubernetes_asyncio.client.api.custom_objects_api import ( # type: ignore # noqa: E402
|
63
63
|
CustomObjectsApi,
|
64
64
|
)
|
65
|
+
from kubernetes_asyncio.client.api.networking_v1_api import ( # type: ignore # noqa: E402
|
66
|
+
NetworkingV1Api,
|
67
|
+
)
|
65
68
|
from kubernetes_asyncio.client.models.v1_secret import ( # type: ignore # noqa: E402
|
66
69
|
V1Secret,
|
67
70
|
)
|
@@ -85,6 +88,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
85
88
|
batch_api: "BatchV1Api",
|
86
89
|
core_api: "CoreV1Api",
|
87
90
|
apps_api: "AppsV1Api",
|
91
|
+
network_api: "NetworkingV1Api",
|
88
92
|
name: str,
|
89
93
|
namespace: Optional[str] = "default",
|
90
94
|
secret: Optional["V1Secret"] = None,
|
@@ -103,6 +107,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
103
107
|
Arguments:
|
104
108
|
batch_api: Kubernetes BatchV1Api object.
|
105
109
|
core_api: Kubernetes CoreV1Api object.
|
110
|
+
network_api: Kubernetes NetworkingV1Api object.
|
106
111
|
name: Name of the job.
|
107
112
|
namespace: Kubernetes namespace.
|
108
113
|
secret: Kubernetes secret.
|
@@ -113,6 +118,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
113
118
|
self.batch_api = batch_api
|
114
119
|
self.core_api = core_api
|
115
120
|
self.apps_api = apps_api
|
121
|
+
self.network_api = network_api
|
116
122
|
self.name = name
|
117
123
|
self.namespace = namespace
|
118
124
|
self._fail_count = 0
|
@@ -207,11 +213,9 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
207
213
|
(self.core_api, "service"),
|
208
214
|
(self.batch_api, "job"),
|
209
215
|
(self.core_api, "pod"),
|
210
|
-
(self.core_api, "config_map"),
|
211
216
|
(self.core_api, "secret"),
|
212
217
|
(self.apps_api, "deployment"),
|
213
|
-
(self.
|
214
|
-
(self.apps_api, "daemon_set"),
|
218
|
+
(self.network_api, "network_policy"),
|
215
219
|
]
|
216
220
|
|
217
221
|
for api_client, resource_type in resource_cleanups:
|
@@ -700,7 +704,6 @@ class KubernetesRunner(AbstractRunner):
|
|
700
704
|
config: Dict[str, Any],
|
701
705
|
namespace: str,
|
702
706
|
run_id: str,
|
703
|
-
auxiliary_resource_label_key: str,
|
704
707
|
launch_project: LaunchProject,
|
705
708
|
api_key_secret: Optional["V1Secret"] = None,
|
706
709
|
wait_for_ready: bool = True,
|
@@ -713,7 +716,6 @@ class KubernetesRunner(AbstractRunner):
|
|
713
716
|
config: The resource configuration to prepare.
|
714
717
|
namespace: The namespace to create the resource in.
|
715
718
|
run_id: The run ID to label the resource with.
|
716
|
-
auxiliary_resource_label_key: The key of the auxiliary resource label.
|
717
719
|
launch_project: The launch project to get environment variables from.
|
718
720
|
api_key_secret: The API key secret to inject.
|
719
721
|
wait_for_ready: Whether to wait for the resource to be ready after creation.
|
@@ -722,25 +724,8 @@ class KubernetesRunner(AbstractRunner):
|
|
722
724
|
config.setdefault("metadata", {})
|
723
725
|
config["metadata"].setdefault("labels", {})
|
724
726
|
config["metadata"]["labels"][WANDB_K8S_RUN_ID] = run_id
|
725
|
-
config["metadata"]["labels"][WANDB_K8S_LABEL_AUXILIARY_RESOURCE] = (
|
726
|
-
auxiliary_resource_label_key
|
727
|
-
)
|
728
727
|
config["metadata"]["labels"]["wandb.ai/created-by"] = "launch-agent"
|
729
728
|
|
730
|
-
if config.get("kind") == "Service" or config.get("kind") == "Deployment":
|
731
|
-
config.setdefault("metadata", {})
|
732
|
-
original_name = config["metadata"].get("name", config.get("kind"))
|
733
|
-
safe_name = make_name_dns_safe(original_name)
|
734
|
-
safe_entity = make_name_dns_safe(launch_project.target_entity or "")
|
735
|
-
safe_project = make_name_dns_safe(launch_project.target_project or "")
|
736
|
-
safe_run_id = make_name_dns_safe(run_id or "")
|
737
|
-
|
738
|
-
new_name = f"{safe_name}-{safe_entity}-{safe_project}-{safe_run_id}"
|
739
|
-
config["metadata"]["name"] = new_name
|
740
|
-
wandb.termlog(
|
741
|
-
f"{LOG_PREFIX}Modified {config.get('kind')} name from '{original_name}' to '{new_name}'"
|
742
|
-
)
|
743
|
-
|
744
729
|
env_vars = launch_project.get_env_vars_dict(
|
745
730
|
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
746
731
|
)
|
@@ -920,19 +905,29 @@ class KubernetesRunner(AbstractRunner):
|
|
920
905
|
batch_api = kubernetes_asyncio.client.BatchV1Api(api_client)
|
921
906
|
core_api = kubernetes_asyncio.client.CoreV1Api(api_client)
|
922
907
|
apps_api = kubernetes_asyncio.client.AppsV1Api(api_client)
|
908
|
+
network_api = kubernetes_asyncio.client.NetworkingV1Api(api_client)
|
923
909
|
|
924
910
|
namespace = self.get_namespace(resource_args, context)
|
925
911
|
job, secret = await self._inject_defaults(
|
926
912
|
resource_args, launch_project, image_uri, namespace, core_api
|
927
913
|
)
|
928
914
|
|
929
|
-
|
930
|
-
|
915
|
+
update_dict = {
|
916
|
+
"project_name": launch_project.target_project,
|
917
|
+
"entity_name": launch_project.target_entity,
|
918
|
+
"run_id": launch_project.run_id,
|
919
|
+
"run_name": launch_project.name,
|
920
|
+
"image_uri": image_uri,
|
921
|
+
"author": launch_project.author,
|
922
|
+
}
|
923
|
+
update_dict.update(os.environ)
|
924
|
+
additional_services: List[Dict[str, Any]] = recursive_macro_sub(
|
925
|
+
launch_project.launch_spec.get("additional_services", []), update_dict
|
926
|
+
)
|
931
927
|
if additional_services:
|
932
928
|
wandb.termlog(
|
933
929
|
f"{LOG_PREFIX}Creating additional services: {additional_services}"
|
934
930
|
)
|
935
|
-
auxiliary_resource_label_key = f"aux-{uuid.uuid4()}"
|
936
931
|
|
937
932
|
wait_for_ready = resource_args.get("wait_for_ready", True)
|
938
933
|
wait_timeout = resource_args.get("wait_timeout", 300)
|
@@ -941,10 +936,9 @@ class KubernetesRunner(AbstractRunner):
|
|
941
936
|
*[
|
942
937
|
self._prepare_resource(
|
943
938
|
api_client,
|
944
|
-
resource.get("config"),
|
939
|
+
resource.get("config", {}),
|
945
940
|
namespace,
|
946
941
|
launch_project.run_id,
|
947
|
-
auxiliary_resource_label_key,
|
948
942
|
launch_project,
|
949
943
|
secret,
|
950
944
|
wait_for_ready,
|
@@ -982,10 +976,11 @@ class KubernetesRunner(AbstractRunner):
|
|
982
976
|
batch_api,
|
983
977
|
core_api,
|
984
978
|
apps_api,
|
979
|
+
network_api,
|
985
980
|
job_name,
|
986
981
|
namespace,
|
987
982
|
secret,
|
988
|
-
|
983
|
+
f"aux-{launch_project.target_entity}-{launch_project.target_project}-{launch_project.run_id}",
|
989
984
|
)
|
990
985
|
if self.backend_config[PROJECT_SYNCHRONOUS]:
|
991
986
|
await submitted_job.wait()
|
wandb/sdk/lib/asyncio_compat.py
CHANGED
@@ -23,7 +23,7 @@ def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
|
|
23
23
|
Note that due to starting a new thread, this is slightly slow.
|
24
24
|
"""
|
25
25
|
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
26
|
-
runner =
|
26
|
+
runner = CancellableRunner()
|
27
27
|
future = executor.submit(runner.run, fn)
|
28
28
|
|
29
29
|
try:
|
@@ -33,15 +33,16 @@ def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
|
|
33
33
|
runner.cancel()
|
34
34
|
|
35
35
|
|
36
|
-
class
|
37
|
-
"""The `
|
36
|
+
class RunnerCancelledError(Exception):
|
37
|
+
"""The `CancellableRunner.run()` invocation was cancelled."""
|
38
38
|
|
39
39
|
|
40
|
-
class
|
40
|
+
class CancellableRunner:
|
41
41
|
"""Runs an asyncio event loop allowing cancellation.
|
42
42
|
|
43
|
-
|
44
|
-
|
43
|
+
The `run()` method is like `asyncio.run()`. The `cancel()` method may
|
44
|
+
be used in a different thread, for instance in a `finally` block, to cancel
|
45
|
+
all tasks, and it is a no-op if `run()` completed.
|
45
46
|
|
46
47
|
Without this, it is impossible to make `asyncio.run()` stop if it runs
|
47
48
|
in a non-main thread. In particular, a KeyboardInterrupt causes the
|
@@ -69,7 +70,7 @@ class _Runner:
|
|
69
70
|
The result of the coroutine returned by `fn`.
|
70
71
|
|
71
72
|
Raises:
|
72
|
-
|
73
|
+
RunnerCancelledError: If `cancel()` is called.
|
73
74
|
"""
|
74
75
|
return asyncio.run(self._run_or_cancel(fn))
|
75
76
|
|
@@ -79,7 +80,7 @@ class _Runner:
|
|
79
80
|
) -> _T:
|
80
81
|
with self._lock:
|
81
82
|
if self._is_cancelled:
|
82
|
-
raise
|
83
|
+
raise RunnerCancelledError()
|
83
84
|
|
84
85
|
self._loop = asyncio.get_running_loop()
|
85
86
|
self._cancel_event = asyncio.Event()
|
@@ -97,7 +98,7 @@ class _Runner:
|
|
97
98
|
if fn_task.done():
|
98
99
|
return fn_task.result()
|
99
100
|
else:
|
100
|
-
raise
|
101
|
+
raise RunnerCancelledError()
|
101
102
|
|
102
103
|
finally:
|
103
104
|
# NOTE: asyncio.run() cancels all tasks after the main task exits,
|
@@ -157,11 +158,9 @@ class TaskGroup:
|
|
157
158
|
)
|
158
159
|
|
159
160
|
for task in done:
|
160
|
-
|
161
|
+
with contextlib.suppress(asyncio.CancelledError):
|
161
162
|
if exc := task.exception():
|
162
163
|
raise exc
|
163
|
-
except asyncio.CancelledError:
|
164
|
-
pass
|
165
164
|
|
166
165
|
def _cancel_all(self) -> None:
|
167
166
|
"""Cancel all tasks."""
|
@@ -199,15 +198,16 @@ async def open_task_group() -> AsyncIterator[TaskGroup]:
|
|
199
198
|
def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
|
200
199
|
"""Schedule a task, cancelling it when exiting the context manager.
|
201
200
|
|
202
|
-
If the
|
203
|
-
|
201
|
+
If the context manager exits successfully but the given coroutine raises
|
202
|
+
an exception, that exception is reraised. The exception is suppressed
|
203
|
+
if the context manager raises an exception.
|
204
204
|
"""
|
205
205
|
task = asyncio.create_task(coro)
|
206
206
|
|
207
207
|
try:
|
208
208
|
yield
|
209
|
-
|
209
|
+
|
210
210
|
if task.done() and (exception := task.exception()):
|
211
211
|
raise exception
|
212
|
-
|
212
|
+
finally:
|
213
213
|
task.cancel()
|
@@ -0,0 +1,252 @@
|
|
1
|
+
"""Implements an asyncio thread suitable for internal wandb use."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import concurrent.futures
|
7
|
+
import contextlib
|
8
|
+
import logging
|
9
|
+
import threading
|
10
|
+
from typing import Any, Callable, Coroutine, TypeVar
|
11
|
+
|
12
|
+
from . import asyncio_compat
|
13
|
+
|
14
|
+
_T = TypeVar("_T")
|
15
|
+
|
16
|
+
_logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class RunCancelledError(Exception):
    """Raised by AsyncioManager.run() when the function it ran was cancelled."""
|
21
|
+
|
22
|
+
|
23
|
+
class AlreadyJoinedError(Exception):
    """Raised when work is submitted to an AsyncioManager after join()."""
|
25
|
+
|
26
|
+
|
27
|
+
class AsyncioManager:
    """Manages a thread running an asyncio loop.

    The thread must be started using start() and should be joined using
    join(). The thread is a daemon thread, so if join() is not invoked,
    the asyncio work could end abruptly when all non-daemon threads exit.

    The run() method allows invoking an async function in the asyncio thread
    and waiting until it completes. The run_soon() method allows running
    an async function without waiting for it.

    Note that although tempting, it is **not** possible to write a safe
    run_in_loop() method that chooses whether to use run() or execute a function
    directly based on whether it's called from the asyncio thread: Suppose a
    function bad() holds a threading.Lock while using run_in_loop() and an
    asyncio task calling bad() is scheduled. If bad() is then invoked in a
    different thread that reaches run_in_loop(), the aforementioned asyncio task
    will deadlock. It is unreasonable to require that run_in_loop() never be
    called while holding a lock (which would apply to the callers of its
    callers, and so on), so it cannot safely exist.
    """

    def __init__(self) -> None:
        """Create the manager; the asyncio thread is not started yet."""
        self._runner = asyncio_compat.CancellableRunner()
        self._thread = threading.Thread(
            target=self._main,
            name="wandb-AsyncioManager-main",
            daemon=True,
        )
        self._lock = threading.Lock()

        self._ready_event = threading.Event()
        """Whether asyncio primitives have been initialized."""

        self._joined = False
        """Whether join() has been called. Guarded by _lock."""

        self._loop: asyncio.AbstractEventLoop
        """A handle for interacting with the asyncio event loop."""

        self._done_event: asyncio.Event
        """Indicates to the asyncio loop that join() was called."""

        self._remaining_tasks = 0
        """The number of tasks remaining. Guarded by _lock."""

        self._task_finished_cond: asyncio.Condition
        """Signalled when _remaining_tasks is decremented."""

    def start(self) -> None:
        """Start the asyncio thread."""
        self._thread.start()

    def join(self) -> None:
        """Stop accepting new asyncio tasks and wait for the remaining ones.

        May be called more than once, including concurrently from several
        threads: only the first call signals the loop to finish, and every
        call blocks until the asyncio thread has exited.
        """
        try:
            # Only read and flip the flag under the lock. The thread must
            # NOT be joined while holding _lock: finishing tasks acquire
            # _lock in _wrap() to decrement _remaining_tasks, so waiting
            # for the thread with the lock held would block the event loop
            # forever and deadlock every joiner.
            with self._lock:
                already_joined = self._joined
                self._joined = True

            if not already_joined:
                # Wait until _loop and _done_event are initialized.
                self._ready_event.wait()

                # Set the done event. The main function will exit once all
                # tasks complete.
                self._loop.call_soon_threadsafe(self._done_event.set)

            # First and subsequent callers alike block until the thread
            # completes.
            self._thread.join()

        finally:
            # Any of the above may get interrupted by Ctrl+C, in which case we
            # should cancel all tasks, since join() can only be called once.
            # This only matters if the KeyboardInterrupt is suppressed.
            self._runner.cancel()

    def run(self, fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
        """Run an async function to completion.

        The function is called in the asyncio thread. Blocks until start()
        is called. This raises an error if called inside an async function,
        and as a consequence, the caller may also not be called inside an
        async function.

        Args:
            fn: The function to run.

        Returns:
            The return value of fn.

        Raises:
            Exception: Any exception raised by fn.
            RunCancelledError: If fn is cancelled, particularly when join()
                is interrupted by Ctrl+C or if it otherwise cancels itself.
            AlreadyJoinedError: If join() was already called.
            ValueError: If called inside an async function.
        """
        # _thread.ident is only meaningful once the thread has started, and
        # the ready event is set from inside that thread.
        self._ready_event.wait()

        # Blocking the asyncio thread on its own future would never return.
        if threading.current_thread().ident == self._thread.ident:
            raise ValueError("Cannot use run() inside async loop.")

        future = self._schedule(fn, daemon=False)

        try:
            return future.result()

        except concurrent.futures.CancelledError:
            raise RunCancelledError from None

        except KeyboardInterrupt:
            # If we're interrupted here, we only cancel this task rather than
            # cancelling all tasks like in join(). This only matters if the
            # interrupt is then suppressed (or delayed) in which case we
            # should let other tasks progress.
            future.cancel()
            raise

    def run_soon(
        self,
        fn: Callable[[], Coroutine[Any, Any, None]],
        *,
        daemon: bool = False,
        name: str | None = None,
    ) -> None:
        """Run an async function without waiting for it to complete.

        The function is called in the asyncio thread. Note that since that's
        a daemon thread, it will not get joined when the main thread exits,
        so fn can stop abruptly.

        Unlike run(), it is OK to call this inside an async function.

        Blocks until start() is called.

        Args:
            fn: The function to run.
            daemon: If true, join() will cancel fn after all non-daemon
                tasks complete. By default, join() blocks until fn
                completes.
            name: An optional name to give to long-running tasks which can
                appear in error traces and be useful to debugging.

        Raises:
            AlreadyJoinedError: If join() was already called.
        """

        # Wrap exceptions so that they're not printed to console.
        async def fn_wrap_exceptions() -> None:
            try:
                await fn()
            except Exception:
                _logger.exception("Uncaught exception in run_soon callback.")

        # The returned future is intentionally dropped: nobody waits on it.
        _ = self._schedule(fn_wrap_exceptions, daemon=daemon, name=name)

    def _schedule(
        self,
        fn: Callable[[], Coroutine[Any, Any, _T]],
        daemon: bool,
        name: str | None = None,
    ) -> concurrent.futures.Future[_T]:
        """Submit fn to the asyncio thread and return a future for its result.

        Args:
            fn: The function to run.
            daemon: If false, the task is counted in _remaining_tasks so
                that the main coroutine waits for it before exiting.
            name: An optional name for the asyncio task.

        Raises:
            AlreadyJoinedError: If join() was already called.
        """
        # Wait for _loop to be initialized.
        self._ready_event.wait()

        # The joined check, the counter increment and the submission happen
        # under the same lock that join() uses to set _joined.
        with self._lock:
            if self._joined:
                raise AlreadyJoinedError

            if not daemon:
                self._remaining_tasks += 1

            return asyncio.run_coroutine_threadsafe(
                self._wrap(fn, daemon=daemon, name=name),
                self._loop,
            )

    async def _wrap(
        self,
        fn: Callable[[], Coroutine[Any, Any, _T]],
        daemon: bool,
        name: str | None,
    ) -> _T:
        """Run fn to completion and possibly decrement _remaining_tasks."""
        try:
            if name and (task := asyncio.current_task()):
                task.set_name(name)

            return await fn()
        finally:
            # Non-daemon tasks were counted in _schedule(); uncount them and
            # wake the main coroutine whether fn returned or raised.
            if not daemon:
                async with self._task_finished_cond:
                    with self._lock:
                        self._remaining_tasks -= 1
                    self._task_finished_cond.notify_all()

    def _main(self) -> None:
        """Run the asyncio loop until join() is called and all tasks finish."""
        # A cancellation error is expected if join() is interrupted.
        #
        # Were it not suppressed, its stacktrace would get printed.
        with contextlib.suppress(asyncio_compat.RunnerCancelledError):
            self._runner.run(self._main_async)

    async def _main_async(self) -> None:
        """Wait until join() is called and all tasks finish."""
        # These must be created inside the running loop; other threads wait
        # on _ready_event before touching them.
        self._loop = asyncio.get_running_loop()
        self._done_event = asyncio.Event()
        self._task_finished_cond = asyncio.Condition()

        self._ready_event.set()

        # Wait until done.
        await self._done_event.wait()

        # Wait for all tasks to complete.
        #
        # Once we exit, asyncio will cancel any leftover tasks.
        async with self._task_finished_cond:
            await self._task_finished_cond.wait_for(
                lambda: self._remaining_tasks <= 0,
            )
|
wandb/sdk/lib/hashutil.py
CHANGED
@@ -6,19 +6,28 @@ import logging
|
|
6
6
|
import mmap
|
7
7
|
import sys
|
8
8
|
import time
|
9
|
-
from typing import TYPE_CHECKING
|
9
|
+
from typing import TYPE_CHECKING
|
10
|
+
|
11
|
+
from typing_extensions import TypeAlias
|
10
12
|
|
11
13
|
from wandb.sdk.lib.paths import StrPath
|
12
14
|
|
13
15
|
if TYPE_CHECKING:
|
14
16
|
import _hashlib # type: ignore[import-not-found]
|
15
17
|
|
16
|
-
ETag = NewType("ETag", str)
|
17
|
-
HexMD5 = NewType("HexMD5", str)
|
18
|
-
B64MD5 = NewType("B64MD5", str)
|
19
18
|
|
20
19
|
logger = logging.getLogger(__name__)
|
21
20
|
|
21
|
+
# In the future, consider relying on pydantic to validate these types via e.g.
|
22
|
+
# - Base64Str: https://docs.pydantic.dev/latest/api/types/#pydantic.types.Base64Str
|
23
|
+
# - a custom EncodedStr + Encoder impl: https://docs.pydantic.dev/latest/api/types/#pydantic.types.EncodedStr
|
24
|
+
#
|
25
|
+
# Note that so long as we continue to support Pydantic v1, the options above will require a compatible shim/backport
|
26
|
+
# implementation, since those types are not in Pydantic v1.
|
27
|
+
ETag: TypeAlias = str
|
28
|
+
HexMD5: TypeAlias = str
|
29
|
+
B64MD5: TypeAlias = str
|
30
|
+
|
22
31
|
|
23
32
|
def _md5(data: bytes = b"") -> _hashlib.HASH:
|
24
33
|
"""Allow FIPS-compliant md5 hash when supported."""
|
wandb/sdk/lib/printer.py
CHANGED
@@ -102,8 +102,8 @@ def new_printer(settings: wandb.Settings | None = None) -> Printer:
|
|
102
102
|
has been called, then global settings are used. Otherwise,
|
103
103
|
settings (such as silent mode) are ignored.
|
104
104
|
"""
|
105
|
-
if not settings and (
|
106
|
-
settings =
|
105
|
+
if not settings and (s := wandb_setup.singleton().settings_if_loaded):
|
106
|
+
settings = s
|
107
107
|
|
108
108
|
if ipython.in_jupyter():
|
109
109
|
return _PrinterJupyter(settings=settings)
|
wandb/sdk/lib/printer_asyncio.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import asyncio
|
2
2
|
from typing import Callable, TypeVar
|
3
3
|
|
4
|
+
from wandb.sdk import wandb_setup
|
4
5
|
from wandb.sdk.lib import asyncio_compat, printer
|
5
6
|
|
6
7
|
_T = TypeVar("_T")
|
@@ -43,4 +44,5 @@ def run_async_with_spinner(
|
|
43
44
|
func_running.set()
|
44
45
|
return res
|
45
46
|
|
46
|
-
|
47
|
+
asyncer = wandb_setup.singleton().asyncer
|
48
|
+
return asyncer.run(_loop_run_with_spinner)
|