wandb 0.19.6rc4__py3-none-any.whl → 0.19.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +25 -5
  3. wandb/apis/public/_generated/__init__.py +21 -0
  4. wandb/apis/public/_generated/base.py +128 -0
  5. wandb/apis/public/_generated/enums.py +4 -0
  6. wandb/apis/public/_generated/input_types.py +4 -0
  7. wandb/apis/public/_generated/operations.py +15 -0
  8. wandb/apis/public/_generated/server_features_query.py +27 -0
  9. wandb/apis/public/_generated/typing_compat.py +14 -0
  10. wandb/apis/public/api.py +192 -6
  11. wandb/apis/public/artifacts.py +13 -45
  12. wandb/apis/public/registries.py +573 -0
  13. wandb/apis/public/utils.py +36 -0
  14. wandb/bin/gpu_stats +0 -0
  15. wandb/cli/cli.py +11 -20
  16. wandb/env.py +10 -0
  17. wandb/proto/v3/wandb_internal_pb2.py +243 -222
  18. wandb/proto/v3/wandb_server_pb2.py +4 -4
  19. wandb/proto/v3/wandb_settings_pb2.py +1 -1
  20. wandb/proto/v4/wandb_internal_pb2.py +226 -222
  21. wandb/proto/v4/wandb_server_pb2.py +4 -4
  22. wandb/proto/v4/wandb_settings_pb2.py +1 -1
  23. wandb/proto/v5/wandb_internal_pb2.py +226 -222
  24. wandb/proto/v5/wandb_server_pb2.py +4 -4
  25. wandb/proto/v5/wandb_settings_pb2.py +1 -1
  26. wandb/sdk/artifacts/_graphql_fragments.py +126 -0
  27. wandb/sdk/artifacts/artifact.py +43 -88
  28. wandb/sdk/backend/backend.py +1 -1
  29. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
  30. wandb/sdk/data_types/helper_types/image_mask.py +12 -6
  31. wandb/sdk/data_types/saved_model.py +35 -46
  32. wandb/sdk/data_types/video.py +7 -16
  33. wandb/sdk/interface/interface.py +26 -10
  34. wandb/sdk/interface/interface_queue.py +5 -8
  35. wandb/sdk/interface/interface_relay.py +1 -6
  36. wandb/sdk/interface/interface_shared.py +21 -99
  37. wandb/sdk/interface/interface_sock.py +2 -13
  38. wandb/sdk/interface/router.py +21 -15
  39. wandb/sdk/interface/router_queue.py +2 -1
  40. wandb/sdk/interface/router_relay.py +2 -1
  41. wandb/sdk/interface/router_sock.py +5 -4
  42. wandb/sdk/internal/handler.py +4 -3
  43. wandb/sdk/internal/internal_api.py +12 -1
  44. wandb/sdk/internal/sender.py +0 -18
  45. wandb/sdk/lib/apikey.py +87 -26
  46. wandb/sdk/lib/asyncio_compat.py +210 -0
  47. wandb/sdk/lib/progress.py +78 -16
  48. wandb/sdk/lib/service_connection.py +1 -1
  49. wandb/sdk/lib/sock_client.py +7 -7
  50. wandb/sdk/mailbox/__init__.py +23 -0
  51. wandb/sdk/mailbox/handles.py +199 -0
  52. wandb/sdk/mailbox/mailbox.py +121 -0
  53. wandb/sdk/mailbox/wait_with_progress.py +134 -0
  54. wandb/sdk/service/server_sock.py +5 -1
  55. wandb/sdk/service/streams.py +66 -74
  56. wandb/sdk/verify/verify.py +54 -2
  57. wandb/sdk/wandb_init.py +61 -61
  58. wandb/sdk/wandb_login.py +7 -4
  59. wandb/sdk/wandb_metadata.py +65 -34
  60. wandb/sdk/wandb_require.py +14 -8
  61. wandb/sdk/wandb_run.py +82 -87
  62. wandb/sdk/wandb_settings.py +3 -3
  63. wandb/sdk/wandb_setup.py +19 -8
  64. wandb/sdk/wandb_sync.py +2 -4
  65. wandb/util.py +3 -1
  66. {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/METADATA +2 -2
  67. {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/RECORD +70 -57
  68. wandb/sdk/lib/mailbox.py +0 -442
  69. {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/WHEEL +0 -0
  70. {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/entry_points.txt +0 -0
  71. {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/lib/progress.py CHANGED
@@ -2,12 +2,15 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import asyncio
5
6
  import contextlib
6
- from typing import Iterable, Iterator
7
+ import time
8
+ from typing import Iterable, Iterator, NoReturn
7
9
 
8
- import wandb
9
10
  from wandb import env
10
11
  from wandb.proto import wandb_internal_pb2 as pb
12
+ from wandb.sdk.interface import interface
13
+ from wandb.sdk.lib import asyncio_compat
11
14
 
12
15
  from . import printer as p
13
16
 
@@ -31,14 +34,67 @@ def print_sync_dedupe_stats(
31
34
  printer.display(f"W&B sync reduced upload amount by {frac:.1%}")
32
35
 
33
36
 
37
async def loop_printing_operation_stats(
    progress: ProgressPrinter,
    interface: interface.InterfaceBase,
) -> None:
    """Continuously poll the service for operation stats and display them.

    Two loops run concurrently in one task group until cancelled:

    * a polling loop requests fresh stats from the service, starting a new
      request no more often than every 0.5 seconds;
    * a redraw loop pushes the most recent stats (if any) to the printer
      roughly every 0.1 seconds.

    Neither loop ends on its own, so this coroutine never returns and must be
    cancelled. It is meant to be used with `mailbox.wait_with_progress()`.

    Args:
        progress: The printer to update with operation stats.
        interface: The interface used to request operation stats.

    Raises:
        HandleAbandonedError: If the mailbox associated with the interface
            becomes closed.
        Exception: Any other problem communicating with the service process.
    """
    # Most recent stats received; shared between the two loops below.
    latest: pb.OperationStats | None = None

    async def redraw_forever() -> NoReturn:
        while True:
            if latest:
                progress.update(latest)
            await asyncio.sleep(0.1)

    async def poll_forever() -> NoReturn:
        nonlocal latest
        min_interval = 0.5

        while True:
            poll_started = time.monotonic()

            result = await interface.deliver_operation_stats().wait_async(
                timeout=None,
            )
            latest = result.response.operations_response.operation_stats

            # Throttle so that fast responses don't flood the service.
            spent = time.monotonic() - poll_started
            if spent < min_interval:
                await asyncio.sleep(min_interval - spent)

    async with asyncio_compat.open_task_group() as group:
        group.start_soon(redraw_forever())
        group.start_soon(poll_forever())
79
+
80
+
34
81
  @contextlib.contextmanager
35
82
  def progress_printer(
36
83
  printer: p.Printer,
37
- settings: wandb.Settings | None = None,
84
+ default_text: str,
38
85
  ) -> Iterator[ProgressPrinter]:
39
- """Context manager providing an object for printing run progress."""
86
+ """Context manager providing an object for printing run progress.
87
+
88
+ Args:
89
+ printer: The printer to use.
90
+ default_text: The text to show if no information is available.
91
+ """
40
92
  with printer.dynamic_text() as text_area:
41
- yield ProgressPrinter(printer, text_area, settings)
93
+ yield ProgressPrinter(
94
+ printer,
95
+ text_area,
96
+ default_text=default_text,
97
+ )
42
98
  printer.progress_close()
43
99
 
44
100
 
@@ -49,28 +105,27 @@ class ProgressPrinter:
49
105
  self,
50
106
  printer: p.Printer,
51
107
  progress_text_area: p.DynamicText | None,
52
- settings: wandb.Settings | None,
108
+ default_text: str,
53
109
  ) -> None:
54
- self._show_operation_stats = (
55
- settings
56
- and settings.x_show_operation_stats
57
- # Not implemented by the legacy service.
58
- and not env.is_require_legacy_service()
59
- )
110
+ # Not implemented by the legacy service.
111
+ self._show_operation_stats = not env.is_require_legacy_service()
60
112
  self._printer = printer
61
113
  self._progress_text_area = progress_text_area
114
+ self._default_text = default_text
62
115
  self._tick = 0
63
116
  self._last_printed_line = ""
64
117
 
65
118
  def update(
66
119
  self,
67
- progress: list[pb.PollExitResponse],
120
+ progress: list[pb.PollExitResponse] | pb.OperationStats,
68
121
  ) -> None:
69
122
  """Update the displayed information."""
70
123
  if not progress:
71
124
  return
72
125
 
73
- if self._show_operation_stats:
126
+ if isinstance(progress, pb.OperationStats):
127
+ self._update_operation_stats([progress])
128
+ elif self._show_operation_stats:
74
129
  self._update_operation_stats(
75
130
  list(response.operation_stats for response in progress)
76
131
  )
@@ -88,6 +143,7 @@ class ProgressPrinter:
88
143
  self._progress_text_area,
89
144
  max_lines=6,
90
145
  loading_symbol=self._printer.loading_symbol(self._tick),
146
+ default_text=self._default_text,
91
147
  ).display(stats_list)
92
148
 
93
149
  else:
@@ -159,6 +215,10 @@ class ProgressPrinter:
159
215
  self._update_progress_text(line, 1.0)
160
216
 
161
217
  def _update_progress_text(self, text: str, progress: float) -> None:
218
+ if text == self._last_printed_line:
219
+ return
220
+ self._last_printed_line = text
221
+
162
222
  if self._progress_text_area:
163
223
  self._progress_text_area.set_text(text)
164
224
  else:
@@ -174,11 +234,13 @@ class _DynamicOperationStatsPrinter:
174
234
  text_area: p.DynamicText,
175
235
  max_lines: int,
176
236
  loading_symbol: str,
237
+ default_text: str,
177
238
  ) -> None:
178
239
  self._printer = printer
179
240
  self._text_area = text_area
180
241
  self._max_lines = max_lines
181
242
  self._loading_symbol = loading_symbol
243
+ self._default_text = default_text
182
244
 
183
245
  self._lines: list[str] = []
184
246
  self._ops_shown = 0
@@ -204,9 +266,9 @@ class _DynamicOperationStatsPrinter:
204
266
 
205
267
  if len(self._lines) == 0:
206
268
  if self._loading_symbol:
207
- self._text_area.set_text(f"{self._loading_symbol} Finishing up...")
269
+ self._text_area.set_text(f"{self._loading_symbol} {self._default_text}")
208
270
  else:
209
- self._text_area.set_text("Finishing up...")
271
+ self._text_area.set_text(self._default_text)
210
272
  else:
211
273
  self._text_area.set_text("\n".join(self._lines))
212
274
 
@@ -12,8 +12,8 @@ from wandb.sdk.interface.interface import InterfaceBase
12
12
  from wandb.sdk.interface.interface_sock import InterfaceSock
13
13
  from wandb.sdk.lib import service_token
14
14
  from wandb.sdk.lib.exit_hooks import ExitHooks
15
- from wandb.sdk.lib.mailbox import Mailbox
16
15
  from wandb.sdk.lib.sock_client import SockClient, SockClientTimeoutError
16
+ from wandb.sdk.mailbox import Mailbox
17
17
  from wandb.sdk.service import service
18
18
 
19
19
 
@@ -177,9 +177,9 @@ class SockClient:
177
177
  inform_finish=inform_finish,
178
178
  inform_teardown=inform_teardown,
179
179
  )
180
- # TODO: this solution is fragile, but for checking attach
181
- # it should be relatively stable.
182
- # This pass would be solved as part of the fix in https://wandb.atlassian.net/browse/WB-8709
180
+
181
+ # HACK: This assumes nothing else is reading on the socket, and that
182
+ # the next response is for this request.
183
183
  response = self.read_server_response(timeout=1)
184
184
 
185
185
  if response is None:
@@ -213,11 +213,13 @@ class SockClient:
213
213
 
214
214
  def send_record_communicate(self, record: "pb.Record") -> None:
215
215
  server_req = spb.ServerRequest()
216
+ server_req.request_id = record.control.mailbox_slot
216
217
  server_req.record_communicate.CopyFrom(record)
217
218
  self.send_server_request(server_req)
218
219
 
219
220
  def send_record_publish(self, record: "pb.Record") -> None:
220
221
  server_req = spb.ServerRequest()
222
+ server_req.request_id = record.control.mailbox_slot
221
223
  server_req.record_publish.CopyFrom(record)
222
224
  self.send_server_request(server_req)
223
225
 
@@ -256,10 +258,8 @@ class SockClient:
256
258
  data = self._sock.recv(self._bufsize)
257
259
  except socket.timeout:
258
260
  break
259
- except ConnectionResetError:
260
- raise SockClientClosedError
261
- except OSError:
262
- raise SockClientClosedError
261
+ except OSError as e:
262
+ raise SockClientClosedError from e
263
263
  finally:
264
264
  if timeout:
265
265
  self._sock.settimeout(None)
@@ -0,0 +1,23 @@
1
+ """A message protocol for the internal service process.
2
+
3
+ The core of W&B is implemented by a side process that asynchronously uploads
4
+ data. The client process (such as this Python code) sends requests to the
5
+ service, and for some requests, the service eventually sends a response.
6
+
7
+ The client can send multiple requests before the service provides a response.
8
+ The Mailbox handles matching responses to requests. An internal thread
9
+ continuously reads data from the service and passes it to the mailbox.
10
+ """
11
+
12
+ from .handles import HandleAbandonedError, MailboxHandle
13
+ from .mailbox import Mailbox, MailboxClosedError
14
+ from .wait_with_progress import wait_all_with_progress, wait_with_progress
15
+
16
# Names re-exported as the public API of this package; everything else in
# the `handles`, `mailbox` and `wait_with_progress` submodules is internal.
__all__ = [
    "HandleAbandonedError",
    "MailboxHandle",
    "Mailbox",
    "MailboxClosedError",
    "wait_all_with_progress",
    "wait_with_progress",
]
@@ -0,0 +1,199 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import math
5
+ import threading
6
+ from typing import TYPE_CHECKING
7
+
8
+ from wandb.proto import wandb_internal_pb2 as pb
9
+
10
+ # Necessary to break an import loop.
11
+ if TYPE_CHECKING:
12
+ from wandb.sdk.interface import interface
13
+
14
+
15
class HandleAbandonedError(Exception):
    """The handle has no result and has been abandoned.

    Raised by `MailboxHandle.wait_or` and `MailboxHandle.wait_async` when the
    handle was abandoned (via `abandon()` or `cancel()`) and no result had
    been delivered to it.
    """
17
+
18
+
19
class MailboxHandle:
    """A thread-safe handle that allows waiting for a response to a request.

    A handle is completed either by `deliver()` (a result arrived) or by
    `abandon()` / `cancel()` (no result will arrive). Threads can block on it
    with `wait_or()`, and asyncio coroutines can await it with `wait_async()`.
    """

    def __init__(self, address: str) -> None:
        self._address = address
        self._lock = threading.Lock()

        # Set once the handle has a result or is abandoned; unblocks `wait_or`.
        self._event = threading.Event()

        self._abandoned = False
        self._result: pb.Result | None = None

        # One entry per pending `wait_async` call. Each asyncio.Event must be
        # set from its own event loop, which `_AsyncioEvent` takes care of.
        self._asyncio_events: dict[asyncio.Event, _AsyncioEvent] = dict()

    def deliver(self, result: pb.Result) -> None:
        """Deliver the response.

        This may only be called once. It is an error to respond to the same
        request more than once. It is a no-op if the handle has been abandoned.

        Args:
            result: The response to the request.

        Raises:
            ValueError: If a result was already delivered to this handle.
        """
        with self._lock:
            if self._abandoned:
                return

            if self._result is not None:
                raise ValueError(
                    f"A response has already been delivered to {self._address}."
                )

            self._result = result
            self._signal_done()

    def cancel(self, iface: interface.InterfaceBase) -> None:
        """Cancel the handle, requesting any associated work to not complete.

        This automatically abandons the handle, as a response is no longer
        guaranteed.

        Args:
            iface: The interface on which to publish the cancel request.
        """
        iface.publish_cancel(self._address)
        self.abandon()

    def abandon(self) -> None:
        """Abandon the handle, indicating it will not receive a response.

        Wakes all current waiters. Waiters that find no delivered result raise
        `HandleAbandonedError`, and any later `deliver()` is ignored.
        """
        with self._lock:
            self._abandoned = True
            self._signal_done()

    def _signal_done(self) -> None:
        """Indicate that the handle either got a result or became abandoned.

        The lock must be held.
        """
        # Unblock threads blocked on `wait_or`.
        self._event.set()

        # Unblock asyncio loops blocked on `wait_async`.
        for asyncio_event in self._asyncio_events.values():
            asyncio_event.set_threadsafe()
        self._asyncio_events.clear()

    def check(self) -> pb.Result | None:
        """Return the result if it has been delivered, otherwise None."""
        with self._lock:
            return self._result

    def wait_or(self, *, timeout: float | None) -> pb.Result:
        """Wait for a response or a timeout.

        This is called `wait_or` because it replaces a method called `wait`
        with different semantics.

        Args:
            timeout: A finite number of seconds or None to never time out.
                If less than or equal to zero, times out immediately unless
                the result is available.

        Returns:
            The result if it arrives before the timeout or has already arrived.

        Raises:
            ValueError: If the timeout is infinite or NaN.
            TimeoutError: If the timeout is reached.
            HandleAbandonedError: If the handle becomes abandoned.
        """
        if timeout is not None and not math.isfinite(timeout):
            raise ValueError("Timeout must be finite or None.")

        if not self._event.wait(timeout=timeout):
            raise TimeoutError(
                f"Timed out waiting for response on {self._address}",
            )

        with self._lock:
            if self._result is not None:
                return self._result

            # The event is only set by `_signal_done`, so no result means
            # the handle was abandoned.
            assert self._abandoned
            raise HandleAbandonedError()

    async def wait_async(self, *, timeout: float | None) -> pb.Result:
        """Wait for a response or timeout.

        This must run in an `asyncio` event loop.

        Args:
            timeout: A finite number of seconds or None to never time out.

        Returns:
            The result if it arrives before the timeout or has already arrived.

        Raises:
            ValueError: If the timeout is infinite or NaN.
            TimeoutError: If the timeout is reached.
            HandleAbandonedError: If the handle becomes abandoned.
        """
        if timeout is not None and not math.isfinite(timeout):
            raise ValueError("Timeout must be finite or None.")

        evt = asyncio.Event()
        # Inside a coroutine, `get_running_loop` is the documented way to get
        # the loop that the event must be set from.
        self._add_asyncio_event(asyncio.get_running_loop(), evt)

        try:
            await asyncio.wait_for(evt.wait(), timeout=timeout)

        except (asyncio.TimeoutError, TimeoutError) as e:
            with self._lock:
                # The handle may have completed right as the timeout expired.
                if self._result is not None:
                    return self._result
                elif self._abandoned:
                    raise HandleAbandonedError()
                else:
                    raise TimeoutError(
                        f"Timed out waiting for response on {self._address}"
                    ) from e

        else:
            with self._lock:
                if self._result is not None:
                    return self._result

                assert self._abandoned
                raise HandleAbandonedError()

        finally:
            # Always unregister, including on cancellation, so the handle
            # does not keep signalling an event nobody awaits.
            self._forget_asyncio_event(evt)

    def _add_asyncio_event(
        self,
        loop: asyncio.AbstractEventLoop,
        event: asyncio.Event,
    ) -> None:
        """Add an event to signal when a result is received.

        If a result already exists, this notifies the event loop immediately.
        """
        asyncio_event = _AsyncioEvent(loop, event)

        with self._lock:
            if self._result is not None or self._abandoned:
                asyncio_event.set_threadsafe()
            else:
                self._asyncio_events[event] = asyncio_event

    def _forget_asyncio_event(self, event: asyncio.Event) -> None:
        """Cancel signalling an event when a result is received."""
        with self._lock:
            self._asyncio_events.pop(event, None)
186
+
187
+
188
+ class _AsyncioEvent:
189
+ def __init__(
190
+ self,
191
+ loop: asyncio.AbstractEventLoop,
192
+ event: asyncio.Event,
193
+ ):
194
+ self._loop = loop
195
+ self._event = event
196
+
197
+ def set_threadsafe(self) -> None:
198
+ """Set the asyncio event in its own loop."""
199
+ self._loop.call_soon_threadsafe(self._event.set)
@@ -0,0 +1,121 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import secrets
5
+ import string
6
+ import threading
7
+
8
+ from wandb.proto import wandb_internal_pb2 as pb
9
+
10
+ from . import handles
11
+
12
+ _logger = logging.getLogger(__name__)
13
+
14
+
15
class MailboxClosedError(Exception):
    """The mailbox has been closed and cannot be used.

    Raised by `Mailbox.require_response` once `Mailbox.close()` has been
    called, since no further responses will be delivered.
    """
17
+
18
+
19
class Mailbox:
    """Matches service responses to requests.

    The mailbox can set an address on a Record and create a handle for
    waiting for a response to that record. Responses are delivered by calling
    `deliver()`. The `close()` method abandons all handles in case the
    service process becomes unreachable.
    """

    def __init__(self) -> None:
        # Pending handles keyed by address. Guarded by `_handles_lock`.
        self._handles: dict[str, handles.MailboxHandle] = {}
        self._handles_lock = threading.Lock()
        self._closed = False

    def require_response(self, request: pb.Record) -> handles.MailboxHandle:
        """Set a response address on a request.

        Args:
            request: The request on which to set a mailbox slot.
                This is mutated. An address must not already be set.

        Returns:
            A handle for waiting for the response to the request.

        Raises:
            ValueError: If the request already has an address.
            MailboxClosedError: If the mailbox has been closed, in which case
                no new responses are expected to be delivered and new handles
                cannot be created. The request is left unmodified.
        """
        if address := request.control.mailbox_slot:
            raise ValueError(f"Request already has an address ({address})")

        with self._handles_lock:
            # Check first so that a failed call does not stamp an address
            # onto the caller's request.
            if self._closed:
                raise MailboxClosedError()

            # `_new_address` reads `_handles`, so it must run under the lock;
            # otherwise two threads could both pick the same unused address.
            address = self._new_address()
            request.control.mailbox_slot = address

            handle = handles.MailboxHandle(address)
            self._handles[address] = handle

        return handle

    def _new_address(self) -> str:
        """Returns an unused address for a request.

        Assumes `_handles_lock` is held.
        """
        alphabet = string.ascii_lowercase + string.digits

        def generate() -> str:
            return "".join(secrets.choice(alphabet) for _ in range(12))

        address = generate()

        # Being extra cautious. This loop will almost never be entered.
        while address in self._handles:
            address = generate()

        return address

    def deliver(self, result: pb.Result) -> None:
        """Deliver a response from the service.

        If the response address is invalid, this does nothing.
        It is a no-op if the mailbox has been closed.
        """
        address = result.control.mailbox_slot
        if not address:
            _logger.error(
                "Received response with no mailbox slot."
                f" Kind: {result.WhichOneof('result_type')}"
            )
            return

        with self._handles_lock:
            # NOTE: If the mailbox is closed, this returns None because
            # we clear the dict.
            handle = self._handles.pop(address, None)

        # It is not an error if there is no handle for the address:
        # handles can be abandoned if the result is no longer needed.
        if handle:
            handle.deliver(result)

    def close(self) -> None:
        """Indicate no further responses will be delivered.

        Abandons all handles. Afterwards, `require_response` raises
        `MailboxClosedError` and `deliver` does nothing.
        """
        with self._handles_lock:
            self._closed = True

            _logger.info(
                f"Closing mailbox, abandoning {len(self._handles)} handles.",
            )

            for handle in self._handles.values():
                handle.abandon()
            self._handles.clear()
@@ -0,0 +1,134 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from typing import Any, Callable, Coroutine, List, cast
5
+
6
+ from wandb.proto import wandb_internal_pb2 as pb
7
+ from wandb.sdk.lib import asyncio_compat
8
+
9
+ from . import handles
10
+
11
+
12
def wait_with_progress(
    handle: handles.MailboxHandle,
    *,
    timeout: float | None,
    progress_after: float,
    display_progress: Callable[[], Coroutine[Any, Any, None]],
) -> pb.Result:
    """Wait for one handle, possibly displaying progress to the user.

    This is a single-handle convenience wrapper: it behaves exactly like
    `wait_all_with_progress([handle], ...)` and returns that handle's result.
    """
    results = wait_all_with_progress(
        [handle],
        timeout=timeout,
        progress_after=progress_after,
        display_progress=display_progress,
    )
    return results[0]
29
+
30
+
31
def wait_all_with_progress(
    handle_list: list[handles.MailboxHandle],
    *,
    timeout: float | None,
    progress_after: float,
    display_progress: Callable[[], Coroutine[Any, Any, None]],
) -> list[pb.Result]:
    """Wait for several handles, showing progress if the wait drags on.

    The handles are first waited on synchronously for up to `progress_after`
    seconds. Only if they are not all resolved by then does this start the
    `display_progress` coroutine and continue waiting asynchronously, so
    quick responses never pay for starting a thread and an asyncio loop.

    Args:
        handle_list: The handles to wait for.
        timeout: A number of seconds after which to raise a TimeoutError,
            or None if this should never timeout.
        progress_after: A number of seconds after which to start the
            display_progress callback. Starting the callback creates a thread
            and starts an asyncio loop, so we want to avoid doing it if
            the handle is resolved quickly.
        display_progress: An asyncio function that displays progress to
            the user. This function is executed on a new thread and cancelled
            if the timeout is exceeded.

    Returns:
        A list where the Nth item is the Nth handle's result.

    Raises:
        TimeoutError: If the overall timeout expires.
        HandleAbandonedError: If any handle becomes abandoned.
        Exception: Any exception from the display function is propagated.
    """
    if not handle_list:
        return []

    # If the whole timeout fits inside the grace period, progress would
    # never be shown, so just block for the full timeout.
    if timeout is not None and timeout <= progress_after:
        return _wait_handles(handle_list, timeout=timeout)

    began_at = time.monotonic()

    # Fast path: wait synchronously during the grace period.
    try:
        return _wait_handles(handle_list, timeout=progress_after)
    except TimeoutError:
        pass

    def remaining() -> float | None:
        """Time left of the overall timeout, or None if unbounded."""
        if timeout is None:
            return None
        return timeout - (time.monotonic() - began_at)

    async def wait_while_displaying() -> list[pb.Result]:
        # The progress display is cancelled as soon as waiting finishes.
        with asyncio_compat.cancel_on_exit(display_progress()):
            return await _wait_handles_async(handle_list, timeout=remaining())

    return asyncio_compat.run(wait_while_displaying)
87
+
88
+
89
+ def _wait_handles(
90
+ handle_list: list[handles.MailboxHandle],
91
+ *,
92
+ timeout: float,
93
+ ) -> list[pb.Result]:
94
+ """Wait for multiple mailbox handles.
95
+
96
+ Returns:
97
+ Each handle's result, in the same order as the given handles.
98
+
99
+ Raises:
100
+ TimeoutError: If the overall timeout expires.
101
+ HandleAbandonedError: If any handle becomes abandoned.
102
+ """
103
+ results: list[pb.Result] = []
104
+
105
+ start_time = time.monotonic()
106
+ for handle in handle_list:
107
+ elapsed_time = time.monotonic() - start_time
108
+ remaining_timeout = timeout - elapsed_time
109
+ results.append(handle.wait_or(timeout=remaining_timeout))
110
+
111
+ return results
112
+
113
+
114
async def _wait_handles_async(
    handle_list: list[handles.MailboxHandle],
    *,
    timeout: float | None,
) -> list[pb.Result]:
    """Asynchronously wait for multiple mailbox handles.

    The asyncio counterpart of `_wait_handles`. All handles are awaited
    concurrently in one task group, each with the same `timeout`.

    Returns:
        Each handle's result, in the same order as the given handles.

    Raises:
        TimeoutError: If a handle's wait times out.
        HandleAbandonedError: If any handle becomes abandoned.
    """
    slots: list[pb.Result | None] = [None] * len(handle_list)

    async def fill_slot(position: int, handle: handles.MailboxHandle) -> None:
        slots[position] = await handle.wait_async(timeout=timeout)

    async with asyncio_compat.open_task_group() as group:
        for position, handle in enumerate(handle_list):
            group.start_soon(fill_slot(position, handle))

    # Every slot is filled once the task group exits without error.
    # NOTE: `list` is not subscriptable until Python 3.10, so we use List.
    return cast(List[pb.Result], slots)