wandb 0.19.6rc4__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81):
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +56 -6
  3. wandb/apis/public/_generated/__init__.py +21 -0
  4. wandb/apis/public/_generated/base.py +128 -0
  5. wandb/apis/public/_generated/enums.py +4 -0
  6. wandb/apis/public/_generated/input_types.py +4 -0
  7. wandb/apis/public/_generated/operations.py +15 -0
  8. wandb/apis/public/_generated/server_features_query.py +27 -0
  9. wandb/apis/public/_generated/typing_compat.py +14 -0
  10. wandb/apis/public/api.py +192 -6
  11. wandb/apis/public/artifacts.py +13 -45
  12. wandb/apis/public/registries.py +573 -0
  13. wandb/apis/public/utils.py +36 -0
  14. wandb/bin/gpu_stats +0 -0
  15. wandb/cli/cli.py +11 -20
  16. wandb/data_types.py +1 -1
  17. wandb/env.py +10 -0
  18. wandb/filesync/dir_watcher.py +2 -1
  19. wandb/proto/v3/wandb_internal_pb2.py +243 -222
  20. wandb/proto/v3/wandb_server_pb2.py +4 -4
  21. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  22. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  23. wandb/proto/v4/wandb_internal_pb2.py +226 -222
  24. wandb/proto/v4/wandb_server_pb2.py +4 -4
  25. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  26. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  27. wandb/proto/v5/wandb_internal_pb2.py +226 -222
  28. wandb/proto/v5/wandb_server_pb2.py +4 -4
  29. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  30. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  31. wandb/sdk/artifacts/_graphql_fragments.py +126 -0
  32. wandb/sdk/artifacts/artifact.py +51 -95
  33. wandb/sdk/backend/backend.py +17 -6
  34. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
  35. wandb/sdk/data_types/helper_types/image_mask.py +12 -6
  36. wandb/sdk/data_types/saved_model.py +35 -46
  37. wandb/sdk/data_types/video.py +7 -16
  38. wandb/sdk/interface/interface.py +87 -49
  39. wandb/sdk/interface/interface_queue.py +5 -15
  40. wandb/sdk/interface/interface_relay.py +7 -22
  41. wandb/sdk/interface/interface_shared.py +65 -136
  42. wandb/sdk/interface/interface_sock.py +3 -21
  43. wandb/sdk/interface/router.py +42 -68
  44. wandb/sdk/interface/router_queue.py +13 -11
  45. wandb/sdk/interface/router_relay.py +26 -13
  46. wandb/sdk/interface/router_sock.py +12 -16
  47. wandb/sdk/internal/handler.py +4 -3
  48. wandb/sdk/internal/internal_api.py +12 -1
  49. wandb/sdk/internal/sender.py +3 -19
  50. wandb/sdk/lib/apikey.py +87 -26
  51. wandb/sdk/lib/asyncio_compat.py +210 -0
  52. wandb/sdk/lib/console_capture.py +172 -0
  53. wandb/sdk/lib/progress.py +78 -16
  54. wandb/sdk/lib/redirect.py +102 -76
  55. wandb/sdk/lib/service_connection.py +37 -17
  56. wandb/sdk/lib/sock_client.py +6 -56
  57. wandb/sdk/mailbox/__init__.py +23 -0
  58. wandb/sdk/mailbox/mailbox.py +135 -0
  59. wandb/sdk/mailbox/mailbox_handle.py +127 -0
  60. wandb/sdk/mailbox/response_handle.py +167 -0
  61. wandb/sdk/mailbox/wait_with_progress.py +135 -0
  62. wandb/sdk/service/server_sock.py +9 -3
  63. wandb/sdk/service/streams.py +75 -78
  64. wandb/sdk/verify/verify.py +54 -2
  65. wandb/sdk/wandb_init.py +72 -75
  66. wandb/sdk/wandb_login.py +7 -4
  67. wandb/sdk/wandb_metadata.py +65 -34
  68. wandb/sdk/wandb_require.py +14 -8
  69. wandb/sdk/wandb_run.py +90 -97
  70. wandb/sdk/wandb_settings.py +10 -4
  71. wandb/sdk/wandb_setup.py +19 -8
  72. wandb/sdk/wandb_sync.py +2 -10
  73. wandb/util.py +3 -1
  74. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/METADATA +2 -2
  75. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/RECORD +78 -65
  76. wandb/sdk/interface/message_future.py +0 -27
  77. wandb/sdk/interface/message_future_poll.py +0 -50
  78. wandb/sdk/lib/mailbox.py +0 -442
  79. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/WHEEL +0 -0
  80. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/entry_points.txt +0 -0
  81. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/licenses/LICENSE +0 -0
@@ -4,14 +4,14 @@ Router to manage responses from a socket client.
4
4
 
5
5
  """
6
6
 
7
- from typing import TYPE_CHECKING, Optional
7
+ from __future__ import annotations
8
8
 
9
- from ..lib.mailbox import Mailbox
10
- from ..lib.sock_client import SockClient, SockClientClosedError
11
- from .router import MessageRouter, MessageRouterClosedError
9
+ from wandb.proto import wandb_internal_pb2 as pb
10
+ from wandb.proto import wandb_server_pb2 as spb
11
+ from wandb.sdk.lib.sock_client import SockClient, SockClientClosedError
12
+ from wandb.sdk.mailbox import Mailbox
12
13
 
13
- if TYPE_CHECKING:
14
- from wandb.proto import wandb_internal_pb2 as pb
14
+ from .router import MessageRouter, MessageRouterClosedError
15
15
 
16
16
 
17
17
  class MessageSockRouter(MessageRouter):
@@ -22,15 +22,11 @@ class MessageSockRouter(MessageRouter):
22
22
  self._sock_client = sock_client
23
23
  super().__init__(mailbox=mailbox)
24
24
 
25
def _read_message(self) -> spb.ServerResponse | None:
    """Fetch the next server response from the socket client.

    Passes `timeout=1` to `read_server_response`.

    Returns:
        Whatever `read_server_response` returned; presumably None when no
        response arrived within the timeout (TODO confirm against SockClient).

    Raises:
        MessageRouterClosedError: If the socket client raised
            SockClientClosedError; the original error is chained as the cause.
    """
    try:
        response = self._sock_client.read_server_response(timeout=1)
    except SockClientClosedError as closed_error:
        raise MessageRouterClosedError from closed_error
    return response


def _send_message(self, record: pb.Record) -> None:
    """Send `record` through the socket client's `send_record_communicate`."""
    client = self._sock_client
    client.send_record_communicate(record)
@@ -205,9 +205,6 @@ class HandleManager:
205
205
  # defer is used to drive the sender finish state machine
206
206
  self._dispatch_record(record, always_send=True)
207
207
 
208
- def handle_request_login(self, record: Record) -> None:
209
- self._dispatch_record(record)
210
-
211
208
  def handle_request_python_packages(self, record: Record) -> None:
212
209
  self._dispatch_record(record)
213
210
 
@@ -892,6 +889,10 @@ class HandleManager:
892
889
  self._respond_result(result)
893
890
  self._stopped.set()
894
891
 
892
def handle_request_operations(self, record: Record) -> None:
    """No-op. Not implemented for the legacy-service.

    Immediately responds with the result object that
    `proto_util._result_from_record` builds for `record`, without filling in
    any operations data.
    """
    default_result = proto_util._result_from_record(record)
    self._respond_result(default_result)
895
+
895
896
  def finish(self) -> None:
896
897
  logger.info("shutting down handler")
897
898
  if self._system_monitor is not None:
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
115
115
  root_dir: Optional[str]
116
116
  api_key: Optional[str]
117
117
  entity: Optional[str]
118
+ organization: Optional[str]
118
119
  project: Optional[str]
119
120
  _extra_http_headers: Optional[Mapping[str, str]]
120
121
  _proxies: Optional[Mapping[str, str]]
@@ -256,6 +257,7 @@ class Api:
256
257
  "root_dir": None,
257
258
  "api_key": None,
258
259
  "entity": None,
260
+ "organization": None,
259
261
  "project": None,
260
262
  "_extra_http_headers": None,
261
263
  "_proxies": None,
@@ -489,7 +491,8 @@ class Api:
489
491
  {
490
492
  "entity": "models",
491
493
  "base_url": "https://api.wandb.ai",
492
- "project": None
494
+ "project": None,
495
+ "organization": "my-org",
493
496
  }
494
497
  """
495
498
  result = self.default_settings.copy()
@@ -504,6 +507,14 @@ class Api:
504
507
  ),
505
508
  env=self._environ,
506
509
  ),
510
+ "organization": env.get_organization(
511
+ self._settings.get(
512
+ Settings.DEFAULT_SECTION,
513
+ "organization",
514
+ fallback=result.get("organization"),
515
+ ),
516
+ env=self._environ,
517
+ ),
507
518
  "project": env.get_project(
508
519
  self._settings.get(
509
520
  Settings.DEFAULT_SECTION,
@@ -1,6 +1,7 @@
1
1
  """sender."""
2
2
 
3
3
  import contextlib
4
+ import glob
4
5
  import gzip
5
6
  import json
6
7
  import logging
@@ -544,23 +545,6 @@ class SendManager:
544
545
  logger.warning(f"Error emptying retry queue: {e}")
545
546
  self._respond_result(result)
546
547
 
547
- def send_request_login(self, record: "Record") -> None:
548
- # TODO: do something with api_key or anonymous?
549
- # TODO: return an error if we aren't logged in?
550
- self._api.reauth()
551
- viewer = self.get_viewer_info()
552
- server_info = self.get_server_info()
553
- # self._login_flags = json.loads(viewer.get("flags", "{}"))
554
- # self._login_entity = viewer.get("entity")
555
- if server_info:
556
- logger.info(f"Login server info: {server_info}")
557
- self._entity = viewer.get("entity")
558
- if record.control.req_resp:
559
- result = proto_util._result_from_record(record)
560
- if self._entity:
561
- result.response.login_response.active_entity = self._entity
562
- self._respond_result(result)
563
-
564
548
  def send_exit(self, record: "Record") -> None:
565
549
  # track where the exit came from
566
550
  self._record_exit = record
@@ -1425,7 +1409,8 @@ class SendManager:
1425
1409
  for k in files.files:
1426
1410
  # TODO(jhr): fix paths with directories
1427
1411
  self._save_file(
1428
- interface.GlobStr(k.path), interface.file_enum_to_policy(k.policy)
1412
+ interface.GlobStr(glob.escape(k.path)),
1413
+ interface.file_enum_to_policy(k.policy),
1429
1414
  )
1430
1415
 
1431
1416
  def send_header(self, record: "Record") -> None:
@@ -1491,7 +1476,6 @@ class SendManager:
1491
1476
  self._job_builder.set_partial_source_id(use.id)
1492
1477
 
1493
1478
  def send_request_log_artifact(self, record: "Record") -> None:
1494
- assert record.control.req_resp
1495
1479
  result = proto_util._result_from_record(record)
1496
1480
  artifact = record.request.log_artifact.artifact
1497
1481
  history_step = record.request.log_artifact.history_step
wandb/sdk/lib/apikey.py CHANGED
@@ -1,5 +1,8 @@
1
1
  """apikey util."""
2
2
 
3
+ from __future__ import annotations
4
+
5
+ import dataclasses
3
6
  import os
4
7
  import platform
5
8
  import stat
@@ -32,8 +35,21 @@ LOGIN_CHOICES = [
32
35
  LOGIN_CHOICE_DRYRUN,
33
36
  ]
34
37
 
38
+
39
@dataclasses.dataclass(frozen=True)
class _NetrcPermissions:
    """Immutable snapshot of the netrc file's state as seen by this process."""

    # Whether the netrc file was found on disk.
    exists: bool
    # Whether we may read the file.
    read_access: bool
    # Whether we may write the file.
    write_access: bool


class WriteNetrcError(Exception):
    """Raised when we cannot write to the netrc file."""
48
+
49
+
35
50
  # Accepted string values for a `Mode`-typed setting; presumably the login
  # mode (confirm at the call sites).
  Mode = Literal["allow", "must", "never", "false", "true"]
36
51
 
52
+
37
53
  if TYPE_CHECKING:
38
54
  from wandb.sdk.wandb_settings import Settings
39
55
 
@@ -170,29 +186,64 @@ def prompt_api_key( # noqa: C901
170
186
  return key
171
187
 
172
188
 
173
- def write_netrc(host: str, entity: str, key: str) -> Optional[bool]:
189
def check_netrc_access(
    netrc_path: str,
) -> _NetrcPermissions:
    """Check whether the current process can read and write the netrc file.

    Args:
        netrc_path: Path of the netrc file to inspect.

    Returns:
        A `_NetrcPermissions` snapshot. A missing file is reported as
        `exists=False` with read and write access, because it will be
        created. If the file's status cannot be determined for any other
        reason, an error is printed and every field is False.
    """
    try:
        os.stat(netrc_path)
    except FileNotFoundError:
        # If the netrc file doesn't exist, we will create it.
        return _NetrcPermissions(exists=False, read_access=True, write_access=True)
    except OSError as e:
        wandb.termerror(f"Unable to read permissions for {netrc_path}, {e}")
        return _NetrcPermissions(exists=False, read_access=False, write_access=False)

    # Ask the OS for this process's effective access instead of looking at
    # the owner permission bits. Those bits say nothing about the current
    # user when the file belongs to someone else (e.g. a netrc created with
    # sudo), which would make the "Cannot access" checks pass wrongly.
    return _NetrcPermissions(
        exists=True,
        read_access=os.access(netrc_path, os.R_OK),
        write_access=os.access(netrc_path, os.W_OK),
    )
213
+
214
+
215
+ def write_netrc(host: str, entity: str, key: str):
174
216
  """Add our host and key to .netrc."""
175
217
  _, key_suffix = key.split("-", 1) if "-" in key else ("", key)
176
218
  if len(key_suffix) != 40:
177
- wandb.termerror(
219
+ raise ValueError(
178
220
  "API-key must be exactly 40 characters long: {} ({} chars)".format(
179
221
  key_suffix, len(key_suffix)
180
222
  )
181
223
  )
182
- return None
183
- try:
184
- normalized_host = urlparse(host).netloc.split(":")[0]
185
- netrc_path = get_netrc_file_path()
186
- wandb.termlog(
187
- f"Appending key for {normalized_host} to your netrc file: {netrc_path}"
224
+
225
+ normalized_host = urlparse(host).netloc
226
+ netrc_path = get_netrc_file_path()
227
+ netrc_access = check_netrc_access(netrc_path)
228
+
229
+ if not netrc_access.write_access or not netrc_access.read_access:
230
+ raise WriteNetrcError(
231
+ f"Cannot access {netrc_path}. In order to persist your API key, "
232
+ "grant read and write permissions for your user to the file "
233
+ 'or specify a different file with the environment variable "NETRC=<new_netrc_path>".'
188
234
  )
189
- machine_line = f"machine {normalized_host}"
190
- orig_lines = None
191
- try:
192
- with open(netrc_path) as f:
193
- orig_lines = f.read().strip().split("\n")
194
- except OSError:
195
- pass
235
+
236
+ machine_line = f"machine {normalized_host}"
237
+ orig_lines = None
238
+ try:
239
+ with open(netrc_path) as f:
240
+ orig_lines = f.read().strip().split("\n")
241
+ except FileNotFoundError:
242
+ wandb.termlog("No netrc file found, creating one.")
243
+ except OSError as e:
244
+ raise WriteNetrcError(f"Unable to read {netrc_path}") from e
245
+
246
+ try:
196
247
  with open(netrc_path, "w") as f:
197
248
  if orig_lines:
198
249
  # delete this machine from the file if it's already there.
@@ -206,20 +257,22 @@ def write_netrc(host: str, entity: str, key: str) -> Optional[bool]:
206
257
  skip -= 1
207
258
  else:
208
259
  f.write("{}\n".format(line))
260
+
261
+ wandb.termlog(
262
+ f"Appending key for {normalized_host} to your netrc file: {netrc_path}"
263
+ )
209
264
  f.write(
210
265
  textwrap.dedent(
211
266
  """\
212
- machine {host}
213
- login {entity}
214
- password {key}
215
- """
267
+ machine {host}
268
+ login {entity}
269
+ password {key}
270
+ """
216
271
  ).format(host=normalized_host, entity=entity, key=key)
217
272
  )
218
273
  os.chmod(netrc_path, stat.S_IRUSR | stat.S_IWUSR)
219
- return True
220
- except OSError:
221
- wandb.termerror(f"Unable to read {netrc_path}")
222
- return None
274
+ except OSError as e:
275
+ raise WriteNetrcError(f"Unable to write {netrc_path}") from e
223
276
 
224
277
 
225
278
  def write_key(
@@ -250,7 +303,15 @@ def api_key(settings: Optional["Settings"] = None) -> Optional[str]:
250
303
  settings = wandb.setup().settings
251
304
  if settings.api_key:
252
305
  return settings.api_key
253
- auth = get_netrc_auth(settings.base_url)
254
- if auth:
255
- return auth[-1]
306
+
307
+ netrc_access = check_netrc_access(get_netrc_file_path())
308
+ if netrc_access.exists and not netrc_access.read_access:
309
+ wandb.termwarn(f"Cannot access {get_netrc_file_path()}.")
310
+ return None
311
+
312
+ if netrc_access.exists:
313
+ auth = get_netrc_auth(settings.base_url)
314
+ if auth:
315
+ return auth[-1]
316
+
256
317
  return None
@@ -0,0 +1,210 @@
1
+ """Functions for compatibility with asyncio."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import concurrent
7
+ import concurrent.futures
8
+ import contextlib
9
+ import threading
10
+ from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, TypeVar
11
+
12
_T = TypeVar("_T")


def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
    """Run the coroutine made by `fn` on a new event loop in a worker thread.

    Always use this instead of `asyncio.run`, which fails if the calling
    thread already has an active `asyncio` event loop. `wandb` was not
    originally designed with `asyncio` in mind, so calling `asyncio.run`
    directly would break users who invoke `wandb` methods from inside an
    `asyncio` loop.

    Starting a thread makes this slightly slow.

    Args:
        fn: Zero-argument callable returning the coroutine to run. It is
            invoked on the worker thread.

    Returns:
        The coroutine's result.

    Raises:
        Whatever the coroutine raises. If waiting is interrupted (for example
        by KeyboardInterrupt), cancellation of the worker's asyncio work is
        requested and the executor waits for the worker to finish.
    """
    runner = _Runner()
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(runner.run, fn)
        try:
            return pending.result()
        finally:
            # Always ask the worker to stop, e.g. when interrupted above.
            runner.cancel()
34
+
35
+
36
class _RunnerCancelledError(Exception):
    """The `_Runner.run()` invocation was cancelled."""


class _Runner:
    """Runs an asyncio event loop allowing cancellation.

    This is like `asyncio.run()`, except it provides a `cancel()` method
    meant to be called in a `finally` block.

    Without this, it is impossible to make `asyncio.run()` stop if it runs
    in a non-main thread. In particular, a KeyboardInterrupt causes the
    ThreadPoolExecutor in `run()` to block until the asyncio thread completes,
    but there is no way to tell the asyncio thread to cancel its work.
    A second KeyboardInterrupt makes ThreadPoolExecutor give up while the
    asyncio thread still runs in the background, with terrible effects if it
    prints to the user's terminal.
    """

    def __init__(self) -> None:
        # Guards the flags and loop references below across threads.
        self._lock = threading.Condition()

        self._is_cancelled = False  # cancel() has been called
        self._started = False  # the loop is running _run_or_cancel
        self._done = False  # _run_or_cancel has finished its cleanup

        # Set once the loop starts; cancel() uses them to wake the loop.
        self._loop: asyncio.AbstractEventLoop | None = None
        self._cancel_event: asyncio.Event | None = None

    def run(self, fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
        """Run a coroutine in asyncio, cancelling it on `cancel()`.

        Returns:
            The result of the coroutine returned by `fn`.

        Raises:
            _RunnerCancelledError: If `cancel()` takes effect before the
                coroutine finishes.
        """
        return asyncio.run(self._run_or_cancel(fn))

    async def _run_or_cancel(
        self,
        fn: Callable[[], Coroutine[Any, Any, _T]],
    ) -> _T:
        with self._lock:
            if self._is_cancelled:
                raise _RunnerCancelledError()

            self._loop = asyncio.get_running_loop()
            self._cancel_event = asyncio.Event()
            self._started = True

        cancellation_task: asyncio.Task | None = None
        fn_task: asyncio.Task | None = None

        # Task creation happens inside `try` so that `_done` is set even when
        # `fn()` raises synchronously or returns a non-coroutine. Otherwise
        # `cancel()` would later see `_started and not _done` and call
        # `call_soon_threadsafe` on the already-closed loop, raising
        # RuntimeError and masking the caller's real exception.
        try:
            cancellation_task = asyncio.create_task(self._cancel_event.wait())
            fn_task = asyncio.create_task(fn())

            await asyncio.wait(
                [cancellation_task, fn_task],
                return_when=asyncio.FIRST_COMPLETED,
            )

            if fn_task.done():
                return fn_task.result()
            else:
                raise _RunnerCancelledError()

        finally:
            for task in (cancellation_task, fn_task):
                if task is not None:
                    task.cancel()

            with self._lock:
                self._done = True

    def cancel(self) -> None:
        """Cancel all asyncio work started by `run()`.

        Safe to call from any thread and more than once.
        """
        with self._lock:
            if self._is_cancelled:
                return
            self._is_cancelled = True

            if self._done or not self._started:
                # If the runner already finished, no need to cancel it.
                #
                # If the runner hasn't started the loop yet, then it will not
                # as we already set _is_cancelled.
                return

            assert self._loop
            assert self._cancel_event
            # The loop is still running here (it sets `_done` under this
            # lock before shutting down), so this cannot hit a closed loop.
            self._loop.call_soon_threadsafe(self._cancel_event.set)
126
+
127
+
128
class TaskGroup:
    """Object that `open_task_group()` yields.

    Collects the tasks started with `start_soon()` so they can be awaited or
    cancelled together.
    """

    def __init__(self) -> None:
        self._tasks: list[asyncio.Task] = []

    def start_soon(self, coro: Coroutine[Any, Any, Any]) -> None:
        """Schedule a task in the group.

        Uses `asyncio.create_task`, so it must be called from code running in
        an event loop.

        Args:
            coro: The return value of the `async` function defining the task.
        """
        self._tasks.append(asyncio.create_task(coro))

    async def _wait_all(self) -> None:
        """Block until all tasks complete or one of them raises.

        Returns immediately if no task was started: `asyncio.wait` rejects an
        empty collection with ValueError, which would otherwise make an
        unused task group fail on exit.

        Raises:
            Exception: If one or more tasks raises an exception, one of these
                is raised arbitrarily. Cancelled tasks count as a normal exit.
        """
        if not self._tasks:
            return

        done, _ = await asyncio.wait(
            self._tasks,
            # NOTE: Cancelling a task counts as a normal exit,
            # not an exception.
            return_when=asyncio.FIRST_EXCEPTION,
        )

        for task in done:
            # `Task.exception()` raises CancelledError on a cancelled task,
            # so skip those explicitly.
            if task.cancelled():
                continue
            exc = task.exception()
            if exc is not None:
                raise exc

    def _cancel_all(self) -> None:
        """Cancel all tasks."""
        for task in self._tasks:
            # NOTE: It is safe to cancel tasks that have already completed.
            task.cancel()
168
+
169
+
170
@contextlib.asynccontextmanager
async def open_task_group() -> AsyncIterator[TaskGroup]:
    """Create a task group; a stand-in for `asyncio.TaskGroup` (Python 3.11+).

    Meant to be used with `async with`. When the body finishes normally, this
    waits until every subtask completes or one of them raises. Whenever the
    block is left, whether after a subtask failure, an exception in the body,
    or cancellation of the current task, every subtask still running is
    cancelled. If several subtasks fail, one of their exceptions is raised,
    chosen arbitrarily.

    NOTE: Subtask exceptions do not propagate until the context manager exits.
    This means that the task group cannot cancel code running inside the
    `async with` block.
    """
    group = TaskGroup()
    try:
        yield group
        await group._wait_all()
    finally:
        # Cancelling already-finished tasks is harmless.
        group._cancel_all()
193
+
194
+
195
@contextlib.contextmanager
def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
    """Schedule a task, cancelling it when exiting the context manager.

    Must be entered from code running in an event loop, since it calls
    `asyncio.create_task`.

    If the task has already finished with an exception by the time the
    context exits, that exception is raised on exit. A task that was
    cancelled (by someone else) is not treated as an error.

    Args:
        coro: The coroutine to run as a background task.
    """
    task = asyncio.create_task(coro)

    try:
        yield
    finally:
        # `Task.exception()` raises CancelledError for a cancelled task, which
        # would otherwise escape from here unexpectedly; check first.
        if task.done() and not task.cancelled():
            exception = task.exception()
            if exception is not None:
                raise exception

        task.cancel()
@@ -0,0 +1,172 @@
1
+ """Module for intercepting stdout/stderr.
2
+
3
+ This patches the `write()` method of `stdout` and `stderr` on import.
4
+ Once patched, it is not possible to unpatch or repatch, though individual
5
+ callbacks can be removed.
6
+
7
+ We assume that all other writing methods on the object delegate to `write()`,
8
+ like `writelines()`. This is not guaranteed to be true, but it is true for
9
+ common implementations. In particular, CPython's implementation of IOBase's
10
+ `writelines()` delegates to `write()`.
11
+
12
+ It is important to note that this technique interacts poorly with other
13
+ code that performs similar patching if it also allows unpatching as this
14
+ discards our modification. This is why we patch on import and do not support
15
+ unpatching:
16
+
17
+ with contextlib.redirect_stderr(...):
18
+ from ... import console_capture
19
+ # Here, everything works fine.
20
+ # Here, callbacks are never called again.
21
+
22
+ In particular, it does not work with some combinations of pytest's
23
+ `capfd` / `capsys` fixtures and pytest's `--capture` option.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import sys
29
+ import threading
30
+ from typing import IO, AnyStr, Callable, Protocol
31
+
32
+
33
class CannotCaptureConsoleError(Exception):
    """Patching stdout or stderr failed when this module was imported."""


class _WriteCallback(Protocol):
    """Signature for hooks that observe data written to stdout or stderr."""

    def __call__(
        self,
        data: bytes | str,
        written: int,
        /,
    ) -> None:
        """Observe one completed `write()` call.

        Args:
            data: The object that was passed to stdout's or stderr's
                `write()`.
            written: The number of bytes or characters written, i.e. the
                value `write()` returned.
        """
53
+
54
+
55
# Serializes access to the callback tables and the ID counter below.
_module_lock = threading.Lock()

# Set at import time if patching failed; the capture_* functions re-raise it.
_patch_exception: CannotCaptureConsoleError | None = None

# Next ID handed out when a callback is registered.
_next_callback_id: int = 1

# Registered callbacks keyed by ID, one table per stream.
_stdout_callbacks: dict[int, _WriteCallback] = {}
_stderr_callbacks: dict[int, _WriteCallback] = {}
63
+
64
+
65
def capture_stdout(callback: _WriteCallback) -> Callable[[], None]:
    """Register `callback` to run after every `sys.stdout.write()`.

    Args:
        callback: Invoked after the original `sys.stdout.write` returns.
            It may run on any thread, so it must be thread-safe. Exceptions
            it raises propagate to whoever called `write`. See
            `_WriteCallback` for the call signature.

    Returns:
        A function that unregisters the callback.

    Raises:
        CannotCaptureConsoleError: If patching failed on import.
    """
    return _register_write_callback(_stdout_callbacks, callback)


def capture_stderr(callback: _WriteCallback) -> Callable[[], None]:
    """Register `callback` to run after every `sys.stderr.write()`.

    Args:
        callback: Invoked after the original `sys.stderr.write` returns.
            It may run on any thread, so it must be thread-safe. Exceptions
            it raises propagate to whoever called `write`. See
            `_WriteCallback` for the call signature.

    Returns:
        A function that unregisters the callback.

    Raises:
        CannotCaptureConsoleError: If patching failed on import.
    """
    return _register_write_callback(_stderr_callbacks, callback)


def _register_write_callback(
    callbacks: dict[int, _WriteCallback],
    callback: _WriteCallback,
) -> Callable[[], None]:
    """Add `callback` to `callbacks`, failing if import-time patching failed.

    The failure check and the insertion both happen under `_module_lock`.
    """
    with _module_lock:
        if _patch_exception:
            raise _patch_exception
        return _insert_disposably(callbacks, callback)
113
+
114
+
115
def _insert_disposably(
    callback_dict: dict[int, _WriteCallback],
    callback: _WriteCallback,
) -> Callable[[], None]:
    """Store `callback` in `callback_dict` under a fresh ID.

    This performs no locking itself. It reads and increments the module-wide
    `_next_callback_id` and mutates `callback_dict` directly. The callers
    (`capture_stdout` / `capture_stderr`) invoke it while holding
    `_module_lock`.

    Args:
        callback_dict: The per-stream callback table to add to.
        callback: The callback to register.

    Returns:
        A function that removes the callback again. It acquires
        `_module_lock` itself, so it must not be called while that lock is
        held. Calls after the first do nothing.
    """
    global _next_callback_id
    # Named `callback_id` rather than `id` to avoid shadowing the builtin.
    callback_id = _next_callback_id
    _next_callback_id += 1

    disposed = False

    def dispose() -> None:
        nonlocal disposed

        with _module_lock:
            if disposed:
                return

            del callback_dict[callback_id]
            disposed = True

    callback_dict[callback_id] = callback
    return dispose
138
+
139
+
140
def _patch(
    stdout_or_stderr: IO[AnyStr],
    callbacks: dict[int, _WriteCallback],
) -> None:
    """Replace `stdout_or_stderr.write` with a wrapper that notifies callbacks.

    The wrapper calls the original `write` first, then invokes every callback
    in `callbacks` with the written data and the original return value, and
    finally returns that value.
    """
    original_write: Callable[[AnyStr], int] = stdout_or_stderr.write

    def write_and_notify(s: AnyStr, /) -> int:
        count = original_write(s)

        # Take a snapshot under the lock: a callback could, in theory,
        # register or dispose callbacks while we iterate.
        with _module_lock:
            snapshot = [*callbacks.values()]

        for notify in snapshot:
            notify(s, count)

        return count

    # mypy==1.14.1 fails to type-check this:
    # Incompatible types in assignment (expression has type
    # "Callable[[bytes], int]", variable has type overloaded function)
    stdout_or_stderr.write = write_and_notify  # type: ignore
165
+
166
+
167
# Patch both streams once, at import time (see the module docstring for why
# unpatching is not supported).
try:
    _patch(sys.stdout, _stdout_callbacks)
    _patch(sys.stderr, _stderr_callbacks)
except Exception as _patch_exception_cause:
    # Record the failure instead of raising during import; capture_stdout()
    # and capture_stderr() raise it when called. If stdout was patched before
    # stderr failed, the stdout patch stays in place.
    _patch_exception = CannotCaptureConsoleError()
    _patch_exception.__cause__ = _patch_exception_cause