wandb 0.19.6__py3-none-macosx_11_0_arm64.whl → 0.19.7__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +25 -5
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +11 -20
- wandb/env.py +10 -0
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +1 -1
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +1 -1
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +1 -1
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +43 -88
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +26 -10
- wandb/sdk/interface/interface_queue.py +5 -8
- wandb/sdk/interface/interface_relay.py +1 -6
- wandb/sdk/interface/interface_shared.py +21 -99
- wandb/sdk/interface/interface_sock.py +2 -13
- wandb/sdk/interface/router.py +21 -15
- wandb/sdk/interface/router_queue.py +2 -1
- wandb/sdk/interface/router_relay.py +2 -1
- wandb/sdk/interface/router_sock.py +5 -4
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +0 -18
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/service_connection.py +1 -1
- wandb/sdk/lib/sock_client.py +7 -7
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/handles.py +199 -0
- wandb/sdk/mailbox/mailbox.py +121 -0
- wandb/sdk/mailbox/wait_with_progress.py +134 -0
- wandb/sdk/service/server_sock.py +5 -1
- wandb/sdk/service/streams.py +66 -74
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +61 -61
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +82 -87
- wandb/sdk/wandb_settings.py +3 -3
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -4
- wandb/util.py +3 -1
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/METADATA +2 -2
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/RECORD +71 -58
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/WHEEL +0 -0
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/lib/progress.py
CHANGED
@@ -2,12 +2,15 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
|
+
import asyncio
|
5
6
|
import contextlib
|
6
|
-
|
7
|
+
import time
|
8
|
+
from typing import Iterable, Iterator, NoReturn
|
7
9
|
|
8
|
-
import wandb
|
9
10
|
from wandb import env
|
10
11
|
from wandb.proto import wandb_internal_pb2 as pb
|
12
|
+
from wandb.sdk.interface import interface
|
13
|
+
from wandb.sdk.lib import asyncio_compat
|
11
14
|
|
12
15
|
from . import printer as p
|
13
16
|
|
@@ -31,14 +34,67 @@ def print_sync_dedupe_stats(
|
|
31
34
|
printer.display(f"W&B sync reduced upload amount by {frac:.1%}")
|
32
35
|
|
33
36
|
|
37
|
+
async def loop_printing_operation_stats(
    progress: ProgressPrinter,
    interface: interface.InterfaceBase,
) -> None:
    """Repeatedly fetch operation stats from the service and display them.

    Two loops run concurrently inside one task group. One requests fresh
    stats from the service, starting a new request no sooner than 0.5
    seconds after the previous one began. The other redraws the most recently
    received stats every 0.1 seconds.

    This never returns and must be cancelled. It is meant to be used with
    `mailbox.wait_with_progress()`.

    Args:
        progress: The printer to update with operation stats.
        interface: The interface to use to poll for updates.

    Raises:
        HandleAbandonedError: If the mailbox associated with the interface
            becomes closed.
        Exception: Any other problem communicating with the service process.
    """
    latest: pb.OperationStats | None = None

    async def redraw_forever() -> NoReturn:
        while True:
            if latest:
                progress.update(latest)
            await asyncio.sleep(0.1)

    async def poll_forever() -> NoReturn:
        nonlocal latest
        while True:
            began = time.monotonic()

            result = await interface.deliver_operation_stats().wait_async(
                timeout=None,
            )
            latest = result.response.operations_response.operation_stats

            # Throttle: wait out whatever is left of the 0.5s polling period.
            pause = 0.5 - (time.monotonic() - began)
            if pause > 0:
                await asyncio.sleep(pause)

    async with asyncio_compat.open_task_group() as group:
        group.start_soon(redraw_forever())
        group.start_soon(poll_forever())
|
79
|
+
|
80
|
+
|
34
81
|
@contextlib.contextmanager
|
35
82
|
def progress_printer(
|
36
83
|
printer: p.Printer,
|
37
|
-
|
84
|
+
default_text: str,
|
38
85
|
) -> Iterator[ProgressPrinter]:
|
39
|
-
"""Context manager providing an object for printing run progress.
|
86
|
+
"""Context manager providing an object for printing run progress.
|
87
|
+
|
88
|
+
Args:
|
89
|
+
printer: The printer to use.
|
90
|
+
default_text: The text to show if no information is available.
|
91
|
+
"""
|
40
92
|
with printer.dynamic_text() as text_area:
|
41
|
-
yield ProgressPrinter(
|
93
|
+
yield ProgressPrinter(
|
94
|
+
printer,
|
95
|
+
text_area,
|
96
|
+
default_text=default_text,
|
97
|
+
)
|
42
98
|
printer.progress_close()
|
43
99
|
|
44
100
|
|
@@ -49,28 +105,27 @@ class ProgressPrinter:
|
|
49
105
|
self,
|
50
106
|
printer: p.Printer,
|
51
107
|
progress_text_area: p.DynamicText | None,
|
52
|
-
|
108
|
+
default_text: str,
|
53
109
|
) -> None:
|
54
|
-
|
55
|
-
|
56
|
-
and settings.x_show_operation_stats
|
57
|
-
# Not implemented by the legacy service.
|
58
|
-
and not env.is_require_legacy_service()
|
59
|
-
)
|
110
|
+
# Not implemented by the legacy service.
|
111
|
+
self._show_operation_stats = not env.is_require_legacy_service()
|
60
112
|
self._printer = printer
|
61
113
|
self._progress_text_area = progress_text_area
|
114
|
+
self._default_text = default_text
|
62
115
|
self._tick = 0
|
63
116
|
self._last_printed_line = ""
|
64
117
|
|
65
118
|
def update(
|
66
119
|
self,
|
67
|
-
progress: list[pb.PollExitResponse],
|
120
|
+
progress: list[pb.PollExitResponse] | pb.OperationStats,
|
68
121
|
) -> None:
|
69
122
|
"""Update the displayed information."""
|
70
123
|
if not progress:
|
71
124
|
return
|
72
125
|
|
73
|
-
if
|
126
|
+
if isinstance(progress, pb.OperationStats):
|
127
|
+
self._update_operation_stats([progress])
|
128
|
+
elif self._show_operation_stats:
|
74
129
|
self._update_operation_stats(
|
75
130
|
list(response.operation_stats for response in progress)
|
76
131
|
)
|
@@ -88,6 +143,7 @@ class ProgressPrinter:
|
|
88
143
|
self._progress_text_area,
|
89
144
|
max_lines=6,
|
90
145
|
loading_symbol=self._printer.loading_symbol(self._tick),
|
146
|
+
default_text=self._default_text,
|
91
147
|
).display(stats_list)
|
92
148
|
|
93
149
|
else:
|
@@ -159,6 +215,10 @@ class ProgressPrinter:
|
|
159
215
|
self._update_progress_text(line, 1.0)
|
160
216
|
|
161
217
|
def _update_progress_text(self, text: str, progress: float) -> None:
|
218
|
+
if text == self._last_printed_line:
|
219
|
+
return
|
220
|
+
self._last_printed_line = text
|
221
|
+
|
162
222
|
if self._progress_text_area:
|
163
223
|
self._progress_text_area.set_text(text)
|
164
224
|
else:
|
@@ -174,11 +234,13 @@ class _DynamicOperationStatsPrinter:
|
|
174
234
|
text_area: p.DynamicText,
|
175
235
|
max_lines: int,
|
176
236
|
loading_symbol: str,
|
237
|
+
default_text: str,
|
177
238
|
) -> None:
|
178
239
|
self._printer = printer
|
179
240
|
self._text_area = text_area
|
180
241
|
self._max_lines = max_lines
|
181
242
|
self._loading_symbol = loading_symbol
|
243
|
+
self._default_text = default_text
|
182
244
|
|
183
245
|
self._lines: list[str] = []
|
184
246
|
self._ops_shown = 0
|
@@ -204,9 +266,9 @@ class _DynamicOperationStatsPrinter:
|
|
204
266
|
|
205
267
|
if len(self._lines) == 0:
|
206
268
|
if self._loading_symbol:
|
207
|
-
self._text_area.set_text(f"{self._loading_symbol}
|
269
|
+
self._text_area.set_text(f"{self._loading_symbol} {self._default_text}")
|
208
270
|
else:
|
209
|
-
self._text_area.set_text(
|
271
|
+
self._text_area.set_text(self._default_text)
|
210
272
|
else:
|
211
273
|
self._text_area.set_text("\n".join(self._lines))
|
212
274
|
|
@@ -12,8 +12,8 @@ from wandb.sdk.interface.interface import InterfaceBase
|
|
12
12
|
from wandb.sdk.interface.interface_sock import InterfaceSock
|
13
13
|
from wandb.sdk.lib import service_token
|
14
14
|
from wandb.sdk.lib.exit_hooks import ExitHooks
|
15
|
-
from wandb.sdk.lib.mailbox import Mailbox
|
16
15
|
from wandb.sdk.lib.sock_client import SockClient, SockClientTimeoutError
|
16
|
+
from wandb.sdk.mailbox import Mailbox
|
17
17
|
from wandb.sdk.service import service
|
18
18
|
|
19
19
|
|
wandb/sdk/lib/sock_client.py
CHANGED
@@ -177,9 +177,9 @@ class SockClient:
|
|
177
177
|
inform_finish=inform_finish,
|
178
178
|
inform_teardown=inform_teardown,
|
179
179
|
)
|
180
|
-
|
181
|
-
#
|
182
|
-
#
|
180
|
+
|
181
|
+
# HACK: This assumes nothing else is reading on the socket, and that
|
182
|
+
# the next response is for this request.
|
183
183
|
response = self.read_server_response(timeout=1)
|
184
184
|
|
185
185
|
if response is None:
|
@@ -213,11 +213,13 @@ class SockClient:
|
|
213
213
|
|
214
214
|
def send_record_communicate(self, record: "pb.Record") -> None:
|
215
215
|
server_req = spb.ServerRequest()
|
216
|
+
server_req.request_id = record.control.mailbox_slot
|
216
217
|
server_req.record_communicate.CopyFrom(record)
|
217
218
|
self.send_server_request(server_req)
|
218
219
|
|
219
220
|
def send_record_publish(self, record: "pb.Record") -> None:
|
220
221
|
server_req = spb.ServerRequest()
|
222
|
+
server_req.request_id = record.control.mailbox_slot
|
221
223
|
server_req.record_publish.CopyFrom(record)
|
222
224
|
self.send_server_request(server_req)
|
223
225
|
|
@@ -256,10 +258,8 @@ class SockClient:
|
|
256
258
|
data = self._sock.recv(self._bufsize)
|
257
259
|
except socket.timeout:
|
258
260
|
break
|
259
|
-
except
|
260
|
-
raise SockClientClosedError
|
261
|
-
except OSError:
|
262
|
-
raise SockClientClosedError
|
261
|
+
except OSError as e:
|
262
|
+
raise SockClientClosedError from e
|
263
263
|
finally:
|
264
264
|
if timeout:
|
265
265
|
self._sock.settimeout(None)
|
@@ -0,0 +1,23 @@
|
|
1
|
+
"""A message protocol for the internal service process.
|
2
|
+
|
3
|
+
The core of W&B is implemented by a side process that asynchronously uploads
|
4
|
+
data. The client process (such as this Python code) sends requests to the
|
5
|
+
service, and for some requests, the service eventually sends a response.
|
6
|
+
|
7
|
+
The client can send multiple requests before the service provides a response.
|
8
|
+
The Mailbox matches responses to their requests. An internal thread
|
9
|
+
continuously reads data from the service and passes it to the mailbox.
|
10
|
+
"""
|
11
|
+
|
12
|
+
from .handles import HandleAbandonedError, MailboxHandle
|
13
|
+
from .mailbox import Mailbox, MailboxClosedError
|
14
|
+
from .wait_with_progress import wait_all_with_progress, wait_with_progress
|
15
|
+
|
16
|
+
__all__ = [
|
17
|
+
"HandleAbandonedError",
|
18
|
+
"MailboxHandle",
|
19
|
+
"Mailbox",
|
20
|
+
"MailboxClosedError",
|
21
|
+
"wait_all_with_progress",
|
22
|
+
"wait_with_progress",
|
23
|
+
]
|
@@ -0,0 +1,199 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import math
|
5
|
+
import threading
|
6
|
+
from typing import TYPE_CHECKING
|
7
|
+
|
8
|
+
from wandb.proto import wandb_internal_pb2 as pb
|
9
|
+
|
10
|
+
# Necessary to break an import loop.
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from wandb.sdk.interface import interface
|
13
|
+
|
14
|
+
|
15
|
+
class HandleAbandonedError(Exception):
    """Raised when a handle was abandoned without ever receiving a result."""
|
17
|
+
|
18
|
+
|
19
|
+
class MailboxHandle:
    """A thread-safe handle that allows waiting for a response to a request.

    The handle becomes "done" when either a result is delivered with
    `deliver()` or the handle is abandoned with `abandon()` / `cancel()`.
    Threads can block until then with `wait_or()`, and asyncio coroutines
    with `wait_async()`.
    """

    def __init__(self, address: str) -> None:
        self._address = address
        self._lock = threading.Lock()

        # Set once the handle is done; unblocks threads waiting in `wait_or`.
        self._event = threading.Event()

        self._abandoned = False
        self._result: pb.Result | None = None

        # One entry per coroutine currently waiting in `wait_async`, keyed by
        # its asyncio.Event so that the waiter can unregister itself.
        self._asyncio_events: dict[asyncio.Event, _AsyncioEvent] = {}

    def deliver(self, result: pb.Result) -> None:
        """Deliver the response.

        This may only be called once. It is an error to respond to the same
        request more than once. It is a no-op if the handle has been abandoned.

        Args:
            result: The response to the request.

        Raises:
            ValueError: If a response was already delivered to this handle.
        """
        with self._lock:
            if self._abandoned:
                return

            if self._result is not None:
                raise ValueError(
                    f"A response has already been delivered to {self._address}."
                )

            self._result = result
            self._signal_done()

    def cancel(self, iface: interface.InterfaceBase) -> None:
        """Cancel the handle, requesting any associated work to not complete.

        This automatically abandons the handle, as a response is no longer
        guaranteed.

        Args:
            iface: The interface on which to publish the cancel request.
        """
        iface.publish_cancel(self._address)
        self.abandon()

    def abandon(self) -> None:
        """Abandon the handle, indicating it will not receive a response.

        A result that was already delivered stays available through
        `check()` and the wait methods.
        """
        with self._lock:
            self._abandoned = True
            self._signal_done()

    def _signal_done(self) -> None:
        """Indicate that the handle either got a result or became abandoned.

        The lock must be held.
        """
        # Unblock threads blocked on `wait_or`.
        self._event.set()

        # Unblock asyncio loops blocked on `wait_async`. Each event is set
        # from its own loop because asyncio primitives are not thread-safe.
        for asyncio_event in self._asyncio_events.values():
            asyncio_event.set_threadsafe()
        self._asyncio_events.clear()

    def check(self) -> pb.Result | None:
        """Returns the result if it's ready."""
        with self._lock:
            return self._result

    def wait_or(self, *, timeout: float | None) -> pb.Result:
        """Wait for a response or a timeout.

        This is called `wait_or` because it replaces a method called `wait`
        with different semantics.

        Args:
            timeout: A finite number of seconds or None to never time out.
                If less than or equal to zero, times out immediately unless
                the result is available.

        Returns:
            The result if it arrives before the timeout or has already arrived.

        Raises:
            ValueError: If the timeout is not finite.
            TimeoutError: If the timeout is reached.
            HandleAbandonedError: If the handle becomes abandoned.
        """
        if timeout is not None and not math.isfinite(timeout):
            raise ValueError("Timeout must be finite or None.")

        if not self._event.wait(timeout=timeout):
            raise TimeoutError(
                f"Timed out waiting for response on {self._address}",
            )

        return self._get_result_or_raise()

    async def wait_async(self, *, timeout: float | None) -> pb.Result:
        """Wait for a response or timeout.

        This must run in an `asyncio` event loop.

        Args:
            timeout: A finite number of seconds or None to never time out.

        Returns:
            The result if it arrives before the timeout or has already arrived.

        Raises:
            ValueError: If the timeout is not finite.
            TimeoutError: If the timeout is reached.
            HandleAbandonedError: If the handle becomes abandoned.
        """
        if timeout is not None and not math.isfinite(timeout):
            raise ValueError("Timeout must be finite or None.")

        evt = asyncio.Event()
        # We are inside a coroutine, so the loop that must be woken up is the
        # one currently running.
        self._add_asyncio_event(asyncio.get_running_loop(), evt)

        try:
            await asyncio.wait_for(evt.wait(), timeout=timeout)

        except (asyncio.TimeoutError, TimeoutError) as e:
            # The handle may have been resolved just as the timeout fired;
            # prefer reporting that over a timeout.
            with self._lock:
                if self._result is not None:
                    return self._result
                elif self._abandoned:
                    raise HandleAbandonedError()
                else:
                    raise TimeoutError(
                        f"Timed out waiting for response on {self._address}"
                    ) from e

        else:
            return self._get_result_or_raise()

        finally:
            self._forget_asyncio_event(evt)

    def _get_result_or_raise(self) -> pb.Result:
        """Return the result of a done handle, or raise if it was abandoned.

        Must only be called once the handle is done. Acquires the lock.

        Raises:
            HandleAbandonedError: If the handle was abandoned without a result.
        """
        with self._lock:
            if self._result is not None:
                return self._result

            assert self._abandoned
            raise HandleAbandonedError()

    def _add_asyncio_event(
        self,
        loop: asyncio.AbstractEventLoop,
        event: asyncio.Event,
    ) -> None:
        """Add an event to signal when a result is received.

        If a result already exists, this notifies the event loop immediately.
        """
        asyncio_event = _AsyncioEvent(loop, event)

        with self._lock:
            if self._result is not None or self._abandoned:
                asyncio_event.set_threadsafe()
            else:
                self._asyncio_events[event] = asyncio_event

    def _forget_asyncio_event(self, event: asyncio.Event) -> None:
        """Cancel signalling an event when a result is received."""
        with self._lock:
            self._asyncio_events.pop(event, None)
|
186
|
+
|
187
|
+
|
188
|
+
class _AsyncioEvent:
|
189
|
+
def __init__(
|
190
|
+
self,
|
191
|
+
loop: asyncio.AbstractEventLoop,
|
192
|
+
event: asyncio.Event,
|
193
|
+
):
|
194
|
+
self._loop = loop
|
195
|
+
self._event = event
|
196
|
+
|
197
|
+
def set_threadsafe(self) -> None:
|
198
|
+
"""Set the asyncio event in its own loop."""
|
199
|
+
self._loop.call_soon_threadsafe(self._event.set)
|
@@ -0,0 +1,121 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import secrets
|
5
|
+
import string
|
6
|
+
import threading
|
7
|
+
|
8
|
+
from wandb.proto import wandb_internal_pb2 as pb
|
9
|
+
|
10
|
+
from . import handles
|
11
|
+
|
12
|
+
_logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class MailboxClosedError(Exception):
    """Raised when using a mailbox after it has been closed."""
|
17
|
+
|
18
|
+
|
19
|
+
class Mailbox:
    """Matches service responses to requests.

    The mailbox can set an address on a Record and create a handle for
    waiting for a response to that record. Responses are delivered by calling
    `deliver()`. The `close()` method abandons all handles in case the
    service process becomes unreachable.
    """

    # Addresses are random strings of this length drawn from this alphabet.
    _ADDRESS_ALPHABET = string.ascii_lowercase + string.digits
    _ADDRESS_LENGTH = 12

    def __init__(self) -> None:
        self._handles: dict[str, handles.MailboxHandle] = {}
        self._handles_lock = threading.Lock()
        self._closed = False

    def require_response(self, request: pb.Record) -> handles.MailboxHandle:
        """Set a response address on a request.

        Args:
            request: The request on which to set a mailbox slot.
                This is mutated. An address must not already be set.

        Returns:
            A handle for waiting for the response to the request.

        Raises:
            ValueError: If the request already has an address.
            MailboxClosedError: If the mailbox has been closed, in which case
                no new responses are expected to be delivered and new handles
                cannot be created. The request is left unmodified.
        """
        if address := request.control.mailbox_slot:
            raise ValueError(f"Request already has an address ({address})")

        with self._handles_lock:
            if self._closed:
                raise MailboxClosedError()

            # `_new_address` reads `_handles`, so it must run under the lock
            # for its uniqueness check to hold until the handle is inserted.
            address = self._new_address()
            request.control.mailbox_slot = address

            handle = handles.MailboxHandle(address)
            self._handles[address] = handle

        return handle

    def _new_address(self) -> str:
        """Returns an unused address for a request.

        Assumes `_handles_lock` is held.
        """
        while True:
            address = "".join(
                secrets.choice(self._ADDRESS_ALPHABET)
                for _ in range(self._ADDRESS_LENGTH)
            )

            # Being extra cautious: a collision is astronomically unlikely.
            if address not in self._handles:
                return address

    def deliver(self, result: pb.Result) -> None:
        """Deliver a response from the service.

        If the response address is invalid, this does nothing.
        It is a no-op if the mailbox has been closed.
        """
        address = result.control.mailbox_slot
        if not address:
            _logger.error(
                "Received response with no mailbox slot. Kind: %s",
                result.WhichOneof("result_type"),
            )
            return

        with self._handles_lock:
            # NOTE: If the mailbox is closed, this returns None because
            # we clear the dict.
            handle = self._handles.pop(address, None)

        # It is not an error if there is no handle for the address:
        # handles can be abandoned if the result is no longer needed.
        if handle:
            handle.deliver(result)

    def close(self) -> None:
        """Indicate no further responses will be delivered.

        Abandons all handles. Later calls to `require_response` raise
        `MailboxClosedError`, and later deliveries are ignored.
        """
        with self._handles_lock:
            self._closed = True

            _logger.info(
                "Closing mailbox, abandoning %d handles.",
                len(self._handles),
            )

            for handle in self._handles.values():
                handle.abandon()
            self._handles.clear()
|
@@ -0,0 +1,134 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import time
|
4
|
+
from typing import Any, Callable, Coroutine, List, cast
|
5
|
+
|
6
|
+
from wandb.proto import wandb_internal_pb2 as pb
|
7
|
+
from wandb.sdk.lib import asyncio_compat
|
8
|
+
|
9
|
+
from . import handles
|
10
|
+
|
11
|
+
|
12
|
+
def wait_with_progress(
    handle: handles.MailboxHandle,
    *,
    timeout: float | None,
    progress_after: float,
    display_progress: Callable[[], Coroutine[Any, Any, None]],
) -> pb.Result:
    """Wait for one handle, possibly displaying progress to the user.

    Convenience wrapper that calls `wait_all_with_progress` with a
    single-element list; see that function for the meaning of each argument
    and for the exceptions raised.
    """
    (result,) = wait_all_with_progress(
        [handle],
        timeout=timeout,
        progress_after=progress_after,
        display_progress=display_progress,
    )
    return result
|
29
|
+
|
30
|
+
|
31
|
+
def wait_all_with_progress(
    handle_list: list[handles.MailboxHandle],
    *,
    timeout: float | None,
    progress_after: float,
    display_progress: Callable[[], Coroutine[Any, Any, None]],
) -> list[pb.Result]:
    """Wait for multiple handles, possibly displaying progress to the user.

    Args:
        handle_list: The handles to wait for.
        timeout: A number of seconds after which to raise a TimeoutError,
            or None if this should never timeout.
        progress_after: A number of seconds after which to start the
            display_progress callback. Starting the callback creates a thread
            and starts an asyncio loop, so we want to avoid doing it if
            the handle is resolved quickly.
        display_progress: An asyncio function that displays progress to
            the user. This function is executed on a new thread and cancelled
            if the timeout is exceeded.

    Returns:
        A list where the Nth item is the Nth handle's result.

    Raises:
        TimeoutError: If the overall timeout expires.
        HandleAbandonedError: If any handle becomes abandoned.
        Exception: Any exception from the display function is propagated.
    """
    if not handle_list:
        return []

    # When the whole budget fits inside the quiet period, progress is never
    # shown, so a plain blocking wait suffices.
    if timeout is not None and timeout <= progress_after:
        return _wait_handles(handle_list, timeout=timeout)

    began = time.monotonic()

    try:
        return _wait_handles(handle_list, timeout=progress_after)
    except TimeoutError:
        # Not finished during the quiet period; fall through to show progress.
        pass

    async def _wait_while_displaying() -> list[pb.Result]:
        with asyncio_compat.cancel_on_exit(display_progress()):
            remaining = (
                None if timeout is None else timeout - (time.monotonic() - began)
            )
            return await _wait_handles_async(handle_list, timeout=remaining)

    return asyncio_compat.run(_wait_while_displaying)
|
87
|
+
|
88
|
+
|
89
|
+
def _wait_handles(
|
90
|
+
handle_list: list[handles.MailboxHandle],
|
91
|
+
*,
|
92
|
+
timeout: float,
|
93
|
+
) -> list[pb.Result]:
|
94
|
+
"""Wait for multiple mailbox handles.
|
95
|
+
|
96
|
+
Returns:
|
97
|
+
Each handle's result, in the same order as the given handles.
|
98
|
+
|
99
|
+
Raises:
|
100
|
+
TimeoutError: If the overall timeout expires.
|
101
|
+
HandleAbandonedError: If any handle becomes abandoned.
|
102
|
+
"""
|
103
|
+
results: list[pb.Result] = []
|
104
|
+
|
105
|
+
start_time = time.monotonic()
|
106
|
+
for handle in handle_list:
|
107
|
+
elapsed_time = time.monotonic() - start_time
|
108
|
+
remaining_timeout = timeout - elapsed_time
|
109
|
+
results.append(handle.wait_or(timeout=remaining_timeout))
|
110
|
+
|
111
|
+
return results
|
112
|
+
|
113
|
+
|
114
|
+
async def _wait_handles_async(
    handle_list: list[handles.MailboxHandle],
    *,
    timeout: float | None,
) -> list[pb.Result]:
    """Asynchronously wait for multiple mailbox handles.

    Unlike `_wait_handles`, all handles are awaited concurrently, each as its
    own task in a task group, and each is given the full `timeout`.

    Returns:
        Each handle's result, in the same order as the given handles.

    Raises:
        TimeoutError: If a handle's wait times out.
        HandleAbandonedError: If any handle becomes abandoned.
    """
    slots: list[pb.Result | None] = [None] * len(handle_list)

    async def fill(position: int, handle: handles.MailboxHandle) -> None:
        slots[position] = await handle.wait_async(timeout=timeout)

    async with asyncio_compat.open_task_group() as group:
        for position, handle in enumerate(handle_list):
            group.start_soon(fill(position, handle))

    # NOTE: `list` is not subscriptable until Python 3.10, so we use List.
    return cast(List[pb.Result], slots)
|