wandb 0.19.6rc4__py3-none-any.whl → 0.19.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +25 -5
  3. wandb/apis/public/_generated/__init__.py +21 -0
  4. wandb/apis/public/_generated/base.py +128 -0
  5. wandb/apis/public/_generated/enums.py +4 -0
  6. wandb/apis/public/_generated/input_types.py +4 -0
  7. wandb/apis/public/_generated/operations.py +15 -0
  8. wandb/apis/public/_generated/server_features_query.py +27 -0
  9. wandb/apis/public/_generated/typing_compat.py +14 -0
  10. wandb/apis/public/api.py +192 -6
  11. wandb/apis/public/artifacts.py +13 -45
  12. wandb/apis/public/registries.py +573 -0
  13. wandb/apis/public/utils.py +36 -0
  14. wandb/bin/gpu_stats +0 -0
  15. wandb/cli/cli.py +11 -20
  16. wandb/env.py +10 -0
  17. wandb/proto/v3/wandb_internal_pb2.py +243 -222
  18. wandb/proto/v3/wandb_server_pb2.py +4 -4
  19. wandb/proto/v3/wandb_settings_pb2.py +1 -1
  20. wandb/proto/v4/wandb_internal_pb2.py +226 -222
  21. wandb/proto/v4/wandb_server_pb2.py +4 -4
  22. wandb/proto/v4/wandb_settings_pb2.py +1 -1
  23. wandb/proto/v5/wandb_internal_pb2.py +226 -222
  24. wandb/proto/v5/wandb_server_pb2.py +4 -4
  25. wandb/proto/v5/wandb_settings_pb2.py +1 -1
  26. wandb/sdk/artifacts/_graphql_fragments.py +126 -0
  27. wandb/sdk/artifacts/artifact.py +43 -88
  28. wandb/sdk/backend/backend.py +1 -1
  29. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
  30. wandb/sdk/data_types/helper_types/image_mask.py +12 -6
  31. wandb/sdk/data_types/saved_model.py +35 -46
  32. wandb/sdk/data_types/video.py +7 -16
  33. wandb/sdk/interface/interface.py +26 -10
  34. wandb/sdk/interface/interface_queue.py +5 -8
  35. wandb/sdk/interface/interface_relay.py +1 -6
  36. wandb/sdk/interface/interface_shared.py +21 -99
  37. wandb/sdk/interface/interface_sock.py +2 -13
  38. wandb/sdk/interface/router.py +21 -15
  39. wandb/sdk/interface/router_queue.py +2 -1
  40. wandb/sdk/interface/router_relay.py +2 -1
  41. wandb/sdk/interface/router_sock.py +5 -4
  42. wandb/sdk/internal/handler.py +4 -3
  43. wandb/sdk/internal/internal_api.py +12 -1
  44. wandb/sdk/internal/sender.py +0 -18
  45. wandb/sdk/lib/apikey.py +87 -26
  46. wandb/sdk/lib/asyncio_compat.py +210 -0
  47. wandb/sdk/lib/progress.py +78 -16
  48. wandb/sdk/lib/service_connection.py +1 -1
  49. wandb/sdk/lib/sock_client.py +7 -7
  50. wandb/sdk/mailbox/__init__.py +23 -0
  51. wandb/sdk/mailbox/handles.py +199 -0
  52. wandb/sdk/mailbox/mailbox.py +121 -0
  53. wandb/sdk/mailbox/wait_with_progress.py +134 -0
  54. wandb/sdk/service/server_sock.py +5 -1
  55. wandb/sdk/service/streams.py +66 -74
  56. wandb/sdk/verify/verify.py +54 -2
  57. wandb/sdk/wandb_init.py +61 -61
  58. wandb/sdk/wandb_login.py +7 -4
  59. wandb/sdk/wandb_metadata.py +65 -34
  60. wandb/sdk/wandb_require.py +14 -8
  61. wandb/sdk/wandb_run.py +82 -87
  62. wandb/sdk/wandb_settings.py +3 -3
  63. wandb/sdk/wandb_setup.py +19 -8
  64. wandb/sdk/wandb_sync.py +2 -4
  65. wandb/util.py +3 -1
  66. {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/METADATA +2 -2
  67. {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/RECORD +70 -57
  68. wandb/sdk/lib/mailbox.py +0 -442
  69. {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/WHEEL +0 -0
  70. {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/entry_points.txt +0 -0
  71. {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/licenses/LICENSE +0 -0
@@ -7,10 +7,10 @@ See interface.py for how interface classes relate to each other.
7
7
  import logging
8
8
  from typing import TYPE_CHECKING, Any, Optional
9
9
 
10
- from ..lib.mailbox import Mailbox
10
+ from wandb.sdk.mailbox import Mailbox
11
+
11
12
  from ..lib.sock_client import SockClient
12
13
  from .interface_shared import InterfaceShared
13
- from .message_future import MessageFuture
14
14
  from .router_sock import MessageSockRouter
15
15
 
16
16
  if TYPE_CHECKING:
@@ -32,7 +32,6 @@ class InterfaceSock(InterfaceShared):
32
32
  # _sock_client is used when abstract method _init_router() is called by constructor
33
33
  self._sock_client = sock_client
34
34
  super().__init__(mailbox=mailbox)
35
- self._process_check = False
36
35
  self._stream_id = stream_id
37
36
 
38
37
  def _init_router(self) -> None:
@@ -45,13 +44,3 @@ class InterfaceSock(InterfaceShared):
45
44
  def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
46
45
  self._assign(record)
47
46
  self._sock_client.send_record_publish(record)
48
-
49
- def _communicate_async(
50
- self, rec: "pb.Record", local: Optional[bool] = None
51
- ) -> MessageFuture:
52
- self._assign(rec)
53
- assert self._router
54
- if self._process_check and self._process and not self._process.is_alive():
55
- raise Exception("The wandb backend process has shutdown")
56
- future = self._router.send_and_receive(rec, local=local)
57
- return future
@@ -10,7 +10,8 @@ import uuid
10
10
  from abc import abstractmethod
11
11
  from typing import TYPE_CHECKING, Dict, Optional
12
12
 
13
- from ..lib import mailbox
13
+ from wandb.sdk import mailbox
14
+
14
15
  from .message_future import MessageFuture
15
16
 
16
17
  if TYPE_CHECKING:
@@ -63,20 +64,25 @@ class MessageRouter:
63
64
  raise NotImplementedError
64
65
 
65
66
  def message_loop(self) -> None:
66
- while not self._join_event.is_set():
67
- try:
68
- msg = self._read_message()
69
- except EOFError:
70
- # On abnormal shutdown the queue will be destroyed underneath
71
- # resulting in EOFError. message_loop needs to exit..
72
- logger.warning("EOFError seen in message_loop")
73
- break
74
- except MessageRouterClosedError:
75
- logger.warning("message_loop has been closed")
76
- break
77
- if not msg:
78
- continue
79
- self._handle_msg_rcv(msg)
67
+ try:
68
+ while not self._join_event.is_set():
69
+ try:
70
+ msg = self._read_message()
71
+ except EOFError:
72
+ # On abnormal shutdown the queue will be destroyed underneath
73
+ # resulting in EOFError. message_loop needs to exit..
74
+ logger.warning("EOFError seen in message_loop")
75
+ break
76
+ except MessageRouterClosedError as e:
77
+ logger.warning("message_loop has been closed", exc_info=e)
78
+ break
79
+ if not msg:
80
+ continue
81
+ self._handle_msg_rcv(msg)
82
+
83
+ finally:
84
+ if self._mailbox:
85
+ self._mailbox.close()
80
86
 
81
87
  def send_and_receive(
82
88
  self, rec: "pb.Record", local: Optional[bool] = None
@@ -7,7 +7,8 @@ Router to manage responses from a queue.
7
7
  import queue
8
8
  from typing import TYPE_CHECKING, Optional
9
9
 
10
- from ..lib.mailbox import Mailbox
10
+ from wandb.sdk.mailbox import Mailbox
11
+
11
12
  from .router import MessageRouter
12
13
 
13
14
  if TYPE_CHECKING:
@@ -6,7 +6,8 @@ Router to manage responses from a queue with relay.
6
6
 
7
7
  from typing import TYPE_CHECKING
8
8
 
9
- from ..lib.mailbox import Mailbox
9
+ from wandb.sdk.mailbox import Mailbox
10
+
10
11
  from .router_queue import MessageQueueRouter
11
12
 
12
13
  if TYPE_CHECKING:
@@ -6,8 +6,9 @@ Router to manage responses from a socket client.
6
6
 
7
7
  from typing import TYPE_CHECKING, Optional
8
8
 
9
- from ..lib.mailbox import Mailbox
10
- from ..lib.sock_client import SockClient, SockClientClosedError
9
+ from wandb.sdk.lib.sock_client import SockClient, SockClientClosedError
10
+ from wandb.sdk.mailbox import Mailbox
11
+
11
12
  from .router import MessageRouter, MessageRouterClosedError
12
13
 
13
14
  if TYPE_CHECKING:
@@ -25,8 +26,8 @@ class MessageSockRouter(MessageRouter):
25
26
  def _read_message(self) -> Optional["pb.Result"]:
26
27
  try:
27
28
  resp = self._sock_client.read_server_response(timeout=1)
28
- except SockClientClosedError:
29
- raise MessageRouterClosedError
29
+ except SockClientClosedError as e:
30
+ raise MessageRouterClosedError from e
30
31
  if not resp:
31
32
  return None
32
33
  msg = resp.result_communicate
@@ -205,9 +205,6 @@ class HandleManager:
205
205
  # defer is used to drive the sender finish state machine
206
206
  self._dispatch_record(record, always_send=True)
207
207
 
208
- def handle_request_login(self, record: Record) -> None:
209
- self._dispatch_record(record)
210
-
211
208
  def handle_request_python_packages(self, record: Record) -> None:
212
209
  self._dispatch_record(record)
213
210
 
@@ -892,6 +889,10 @@ class HandleManager:
892
889
  self._respond_result(result)
893
890
  self._stopped.set()
894
891
 
892
+ def handle_request_operations(self, record: Record) -> None:
893
+ """No-op. Not implemented for the legacy-service."""
894
+ self._respond_result(proto_util._result_from_record(record))
895
+
895
896
  def finish(self) -> None:
896
897
  logger.info("shutting down handler")
897
898
  if self._system_monitor is not None:
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
115
115
  root_dir: Optional[str]
116
116
  api_key: Optional[str]
117
117
  entity: Optional[str]
118
+ organization: Optional[str]
118
119
  project: Optional[str]
119
120
  _extra_http_headers: Optional[Mapping[str, str]]
120
121
  _proxies: Optional[Mapping[str, str]]
@@ -256,6 +257,7 @@ class Api:
256
257
  "root_dir": None,
257
258
  "api_key": None,
258
259
  "entity": None,
260
+ "organization": None,
259
261
  "project": None,
260
262
  "_extra_http_headers": None,
261
263
  "_proxies": None,
@@ -489,7 +491,8 @@ class Api:
489
491
  {
490
492
  "entity": "models",
491
493
  "base_url": "https://api.wandb.ai",
492
- "project": None
494
+ "project": None,
495
+ "organization": "my-org",
493
496
  }
494
497
  """
495
498
  result = self.default_settings.copy()
@@ -504,6 +507,14 @@ class Api:
504
507
  ),
505
508
  env=self._environ,
506
509
  ),
510
+ "organization": env.get_organization(
511
+ self._settings.get(
512
+ Settings.DEFAULT_SECTION,
513
+ "organization",
514
+ fallback=result.get("organization"),
515
+ ),
516
+ env=self._environ,
517
+ ),
507
518
  "project": env.get_project(
508
519
  self._settings.get(
509
520
  Settings.DEFAULT_SECTION,
@@ -544,23 +544,6 @@ class SendManager:
544
544
  logger.warning(f"Error emptying retry queue: {e}")
545
545
  self._respond_result(result)
546
546
 
547
- def send_request_login(self, record: "Record") -> None:
548
- # TODO: do something with api_key or anonymous?
549
- # TODO: return an error if we aren't logged in?
550
- self._api.reauth()
551
- viewer = self.get_viewer_info()
552
- server_info = self.get_server_info()
553
- # self._login_flags = json.loads(viewer.get("flags", "{}"))
554
- # self._login_entity = viewer.get("entity")
555
- if server_info:
556
- logger.info(f"Login server info: {server_info}")
557
- self._entity = viewer.get("entity")
558
- if record.control.req_resp:
559
- result = proto_util._result_from_record(record)
560
- if self._entity:
561
- result.response.login_response.active_entity = self._entity
562
- self._respond_result(result)
563
-
564
547
  def send_exit(self, record: "Record") -> None:
565
548
  # track where the exit came from
566
549
  self._record_exit = record
@@ -1491,7 +1474,6 @@ class SendManager:
1491
1474
  self._job_builder.set_partial_source_id(use.id)
1492
1475
 
1493
1476
  def send_request_log_artifact(self, record: "Record") -> None:
1494
- assert record.control.req_resp
1495
1477
  result = proto_util._result_from_record(record)
1496
1478
  artifact = record.request.log_artifact.artifact
1497
1479
  history_step = record.request.log_artifact.history_step
wandb/sdk/lib/apikey.py CHANGED
@@ -1,5 +1,8 @@
1
1
  """apikey util."""
2
2
 
3
+ from __future__ import annotations
4
+
5
+ import dataclasses
3
6
  import os
4
7
  import platform
5
8
  import stat
@@ -32,8 +35,21 @@ LOGIN_CHOICES = [
32
35
  LOGIN_CHOICE_DRYRUN,
33
36
  ]
34
37
 
38
+
39
+ @dataclasses.dataclass(frozen=True)
40
+ class _NetrcPermissions:
41
+ exists: bool
42
+ read_access: bool
43
+ write_access: bool
44
+
45
+
46
+ class WriteNetrcError(Exception):
47
+ """Raised when we cannot write to the netrc file."""
48
+
49
+
35
50
  Mode = Literal["allow", "must", "never", "false", "true"]
36
51
 
52
+
37
53
  if TYPE_CHECKING:
38
54
  from wandb.sdk.wandb_settings import Settings
39
55
 
@@ -170,29 +186,64 @@ def prompt_api_key( # noqa: C901
170
186
  return key
171
187
 
172
188
 
173
- def write_netrc(host: str, entity: str, key: str) -> Optional[bool]:
189
+ def check_netrc_access(
190
+ netrc_path: str,
191
+ ) -> _NetrcPermissions:
192
+ """Check if we can read and write to the netrc file."""
193
+ file_exists = False
194
+ write_access = False
195
+ read_access = False
196
+ try:
197
+ st = os.stat(netrc_path)
198
+ file_exists = True
199
+ write_access = bool(st.st_mode & stat.S_IWUSR)
200
+ read_access = bool(st.st_mode & stat.S_IRUSR)
201
+ except FileNotFoundError:
202
+ # If the netrc file doesn't exist, we will create it.
203
+ write_access = True
204
+ read_access = True
205
+ except OSError as e:
206
+ wandb.termerror(f"Unable to read permissions for {netrc_path}, {e}")
207
+
208
+ return _NetrcPermissions(
209
+ exists=file_exists,
210
+ write_access=write_access,
211
+ read_access=read_access,
212
+ )
213
+
214
+
215
+ def write_netrc(host: str, entity: str, key: str):
174
216
  """Add our host and key to .netrc."""
175
217
  _, key_suffix = key.split("-", 1) if "-" in key else ("", key)
176
218
  if len(key_suffix) != 40:
177
- wandb.termerror(
219
+ raise ValueError(
178
220
  "API-key must be exactly 40 characters long: {} ({} chars)".format(
179
221
  key_suffix, len(key_suffix)
180
222
  )
181
223
  )
182
- return None
183
- try:
184
- normalized_host = urlparse(host).netloc.split(":")[0]
185
- netrc_path = get_netrc_file_path()
186
- wandb.termlog(
187
- f"Appending key for {normalized_host} to your netrc file: {netrc_path}"
224
+
225
+ normalized_host = urlparse(host).netloc
226
+ netrc_path = get_netrc_file_path()
227
+ netrc_access = check_netrc_access(netrc_path)
228
+
229
+ if not netrc_access.write_access or not netrc_access.read_access:
230
+ raise WriteNetrcError(
231
+ f"Cannot access {netrc_path}. In order to persist your API key, "
232
+ "grant read and write permissions for your user to the file "
233
+ 'or specify a different file with the environment variable "NETRC=<new_netrc_path>".'
188
234
  )
189
- machine_line = f"machine {normalized_host}"
190
- orig_lines = None
191
- try:
192
- with open(netrc_path) as f:
193
- orig_lines = f.read().strip().split("\n")
194
- except OSError:
195
- pass
235
+
236
+ machine_line = f"machine {normalized_host}"
237
+ orig_lines = None
238
+ try:
239
+ with open(netrc_path) as f:
240
+ orig_lines = f.read().strip().split("\n")
241
+ except FileNotFoundError:
242
+ wandb.termlog("No netrc file found, creating one.")
243
+ except OSError as e:
244
+ raise WriteNetrcError(f"Unable to read {netrc_path}") from e
245
+
246
+ try:
196
247
  with open(netrc_path, "w") as f:
197
248
  if orig_lines:
198
249
  # delete this machine from the file if it's already there.
@@ -206,20 +257,22 @@ def write_netrc(host: str, entity: str, key: str) -> Optional[bool]:
206
257
  skip -= 1
207
258
  else:
208
259
  f.write("{}\n".format(line))
260
+
261
+ wandb.termlog(
262
+ f"Appending key for {normalized_host} to your netrc file: {netrc_path}"
263
+ )
209
264
  f.write(
210
265
  textwrap.dedent(
211
266
  """\
212
- machine {host}
213
- login {entity}
214
- password {key}
215
- """
267
+ machine {host}
268
+ login {entity}
269
+ password {key}
270
+ """
216
271
  ).format(host=normalized_host, entity=entity, key=key)
217
272
  )
218
273
  os.chmod(netrc_path, stat.S_IRUSR | stat.S_IWUSR)
219
- return True
220
- except OSError:
221
- wandb.termerror(f"Unable to read {netrc_path}")
222
- return None
274
+ except OSError as e:
275
+ raise WriteNetrcError(f"Unable to write {netrc_path}") from e
223
276
 
224
277
 
225
278
  def write_key(
@@ -250,7 +303,15 @@ def api_key(settings: Optional["Settings"] = None) -> Optional[str]:
250
303
  settings = wandb.setup().settings
251
304
  if settings.api_key:
252
305
  return settings.api_key
253
- auth = get_netrc_auth(settings.base_url)
254
- if auth:
255
- return auth[-1]
306
+
307
+ netrc_access = check_netrc_access(get_netrc_file_path())
308
+ if netrc_access.exists and not netrc_access.read_access:
309
+ wandb.termwarn(f"Cannot access {get_netrc_file_path()}.")
310
+ return None
311
+
312
+ if netrc_access.exists:
313
+ auth = get_netrc_auth(settings.base_url)
314
+ if auth:
315
+ return auth[-1]
316
+
256
317
  return None
@@ -0,0 +1,210 @@
1
+ """Functions for compatibility with asyncio."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import concurrent
7
+ import concurrent.futures
8
+ import contextlib
9
+ import threading
10
+ from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, TypeVar
11
+
12
+ _T = TypeVar("_T")
13
+
14
+
15
+ def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
16
+ """Run `fn` in an asyncio loop in a new thread.
17
+
18
+ This must always be used instead of `asyncio.run` which fails if there is
19
+ an active `asyncio` event loop in the current thread. Since `wandb` was not
20
+ originally designed with `asyncio` in mind, using `asyncio.run` would break
21
+ users who were calling `wandb` methods from an `asyncio` loop.
22
+
23
+ Note that due to starting a new thread, this is slightly slow.
24
+ """
25
+ with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
26
+ runner = _Runner()
27
+ future = executor.submit(runner.run, fn)
28
+
29
+ try:
30
+ return future.result()
31
+
32
+ finally:
33
+ runner.cancel()
34
+
35
+
36
+ class _RunnerCancelledError(Exception):
37
+ """The `_Runner.run()` invocation was cancelled."""
38
+
39
+
40
+ class _Runner:
41
+ """Runs an asyncio event loop allowing cancellation.
42
+
43
+ This is like `asyncio.run()`, except it provides a `cancel()` method
44
+ meant to be called in a `finally` block.
45
+
46
+ Without this, it is impossible to make `asyncio.run()` stop if it runs
47
+ in a non-main thread. In particular, a KeyboardInterrupt causes the
48
+ ThreadPoolExecutor above to block until the asyncio thread completes,
49
+ but there is no way to tell the asyncio thread to cancel its work.
50
+ A second KeyboardInterrupt makes ThreadPoolExecutor give up while the
51
+ asyncio thread still runs in the background, with terrible effects if it
52
+ prints to the user's terminal.
53
+ """
54
+
55
+ def __init__(self) -> None:
56
+ self._lock = threading.Condition()
57
+
58
+ self._is_cancelled = False
59
+ self._started = False
60
+ self._done = False
61
+
62
+ self._loop: asyncio.AbstractEventLoop | None = None
63
+ self._cancel_event: asyncio.Event | None = None
64
+
65
+ def run(self, fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
66
+ """Run a coroutine in asyncio, cancelling it on `cancel()`.
67
+
68
+ Returns:
69
+ The result of the coroutine returned by `fn`.
70
+
71
+ Raises:
72
+ _RunnerCancelledError: If `cancel()` is called.
73
+ """
74
+ return asyncio.run(self._run_or_cancel(fn))
75
+
76
+ async def _run_or_cancel(
77
+ self,
78
+ fn: Callable[[], Coroutine[Any, Any, _T]],
79
+ ) -> _T:
80
+ with self._lock:
81
+ if self._is_cancelled:
82
+ raise _RunnerCancelledError()
83
+
84
+ self._loop = asyncio.get_running_loop()
85
+ self._cancel_event = asyncio.Event()
86
+ self._started = True
87
+
88
+ cancellation_task = asyncio.create_task(self._cancel_event.wait())
89
+ fn_task = asyncio.create_task(fn())
90
+
91
+ try:
92
+ await asyncio.wait(
93
+ [cancellation_task, fn_task],
94
+ return_when=asyncio.FIRST_COMPLETED,
95
+ )
96
+
97
+ if fn_task.done():
98
+ return fn_task.result()
99
+ else:
100
+ raise _RunnerCancelledError()
101
+
102
+ finally:
103
+ cancellation_task.cancel()
104
+ fn_task.cancel()
105
+
106
+ with self._lock:
107
+ self._done = True
108
+
109
+ def cancel(self) -> None:
110
+ """Cancel all asyncio work started by `run()`."""
111
+ with self._lock:
112
+ if self._is_cancelled:
113
+ return
114
+ self._is_cancelled = True
115
+
116
+ if self._done or not self._started:
117
+ # If the runner already finished, no need to cancel it.
118
+ #
119
+ # If the runner hasn't started the loop yet, then it will not
120
+ # as we already set _is_cancelled.
121
+ return
122
+
123
+ assert self._loop
124
+ assert self._cancel_event
125
+ self._loop.call_soon_threadsafe(self._cancel_event.set)
126
+
127
+
128
+ class TaskGroup:
129
+ """Object that `open_task_group()` yields."""
130
+
131
+ def __init__(self) -> None:
132
+ self._tasks: list[asyncio.Task] = []
133
+
134
+ def start_soon(self, coro: Coroutine[Any, Any, Any]) -> None:
135
+ """Schedule a task in the group.
136
+
137
+ Args:
138
+ coro: The return value of the `async` function defining the task.
139
+ """
140
+ self._tasks.append(asyncio.create_task(coro))
141
+
142
+ async def _wait_all(self) -> None:
143
+ """Block until all tasks complete.
144
+
145
+ Raises:
146
+ Exception: If one or more tasks raises an exception, one of these
147
+ is raised arbitrarily.
148
+ """
149
+ done, _ = await asyncio.wait(
150
+ self._tasks,
151
+ # NOTE: Cancelling a task counts as a normal exit,
152
+ # not an exception.
153
+ return_when=concurrent.futures.FIRST_EXCEPTION,
154
+ )
155
+
156
+ for task in done:
157
+ try:
158
+ if exc := task.exception():
159
+ raise exc
160
+ except asyncio.CancelledError:
161
+ pass
162
+
163
+ def _cancel_all(self) -> None:
164
+ """Cancel all tasks."""
165
+ for task in self._tasks:
166
+ # NOTE: It is safe to cancel tasks that have already completed.
167
+ task.cancel()
168
+
169
+
170
+ @contextlib.asynccontextmanager
171
+ async def open_task_group() -> AsyncIterator[TaskGroup]:
172
+ """Create a task group.
173
+
174
+ `asyncio` gained task groups in Python 3.11.
175
+
176
+ This is an async context manager, meant to be used with `async with`.
177
+ On exit, it blocks until all subtasks complete. If any subtask fails, or if
178
+ the current task is cancelled, it cancels all subtasks in the group and
179
+ raises the subtask's exception. If multiple subtasks fail simultaneously,
180
+ one of their exceptions is chosen arbitrarily.
181
+
182
+ NOTE: Subtask exceptions do not propagate until the context manager exits.
183
+ This means that the task group cannot cancel code running inside the
184
+ `async with` block .
185
+ """
186
+ task_group = TaskGroup()
187
+
188
+ try:
189
+ yield task_group
190
+ await task_group._wait_all()
191
+ finally:
192
+ task_group._cancel_all()
193
+
194
+
195
+ @contextlib.contextmanager
196
+ def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
197
+ """Schedule a task, cancelling it when exiting the context manager.
198
+
199
+ If the given coroutine raises an exception, that exception is raised
200
+ when exiting the context manager.
201
+ """
202
+ task = asyncio.create_task(coro)
203
+
204
+ try:
205
+ yield
206
+ finally:
207
+ if task.done() and (exception := task.exception()):
208
+ raise exception
209
+
210
+ task.cancel()