wandb 0.19.7__py3-none-macosx_11_0_arm64.whl → 0.19.8__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +32 -2
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/data_types.py +1 -1
- wandb/filesync/dir_watcher.py +2 -1
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/artifact.py +11 -10
- wandb/sdk/backend/backend.py +16 -5
- wandb/sdk/interface/interface.py +65 -43
- wandb/sdk/interface/interface_queue.py +0 -7
- wandb/sdk/interface/interface_relay.py +6 -16
- wandb/sdk/interface/interface_shared.py +47 -40
- wandb/sdk/interface/interface_sock.py +1 -8
- wandb/sdk/interface/router.py +22 -54
- wandb/sdk/interface/router_queue.py +11 -10
- wandb/sdk/interface/router_relay.py +24 -12
- wandb/sdk/interface/router_sock.py +6 -11
- wandb/sdk/internal/sender.py +3 -1
- wandb/sdk/lib/console_capture.py +172 -0
- wandb/sdk/lib/redirect.py +102 -76
- wandb/sdk/lib/service_connection.py +37 -17
- wandb/sdk/lib/sock_client.py +2 -52
- wandb/sdk/mailbox/__init__.py +3 -3
- wandb/sdk/mailbox/mailbox.py +31 -17
- wandb/sdk/mailbox/mailbox_handle.py +127 -0
- wandb/sdk/mailbox/{handles.py → response_handle.py} +34 -66
- wandb/sdk/mailbox/wait_with_progress.py +16 -15
- wandb/sdk/service/server_sock.py +4 -2
- wandb/sdk/service/streams.py +10 -5
- wandb/sdk/wandb_init.py +12 -15
- wandb/sdk/wandb_run.py +8 -10
- wandb/sdk/wandb_settings.py +7 -1
- wandb/sdk/wandb_sync.py +1 -7
- {wandb-0.19.7.dist-info → wandb-0.19.8.dist-info}/METADATA +1 -1
- {wandb-0.19.7.dist-info → wandb-0.19.8.dist-info}/RECORD +44 -44
- wandb/sdk/interface/message_future.py +0 -27
- wandb/sdk/interface/message_future_poll.py +0 -50
- {wandb-0.19.7.dist-info → wandb-0.19.8.dist-info}/WHEEL +0 -0
- {wandb-0.19.7.dist-info → wandb-0.19.8.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.7.dist-info → wandb-0.19.8.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,172 @@
|
|
1
|
+
"""Module for intercepting stdout/stderr.
|
2
|
+
|
3
|
+
This patches the `write()` method of `stdout` and `stderr` on import.
|
4
|
+
Once patched, it is not possible to unpatch or repatch, though individual
|
5
|
+
callbacks can be removed.
|
6
|
+
|
7
|
+
We assume that all other writing methods on the object delegate to `write()`,
|
8
|
+
like `writelines()`. This is not guaranteed to be true, but it is true for
|
9
|
+
common implementations. In particular, CPython's implementation of IOBase's
|
10
|
+
`writelines()` delegates to `write()`.
|
11
|
+
|
12
|
+
It is important to note that this technique interacts poorly with other
|
13
|
+
code that performs similar patching if it also allows unpatching as this
|
14
|
+
discards our modification. This is why we patch on import and do not support
|
15
|
+
unpatching:
|
16
|
+
|
17
|
+
with contextlib.redirect_stderr(...):
|
18
|
+
from ... import console_capture
|
19
|
+
# Here, everything works fine.
|
20
|
+
# Here, callbacks are never called again.
|
21
|
+
|
22
|
+
In particular, it does not work with some combinations of pytest's
|
23
|
+
`capfd` / `capsys` fixtures and pytest's `--capture` option.
|
24
|
+
"""
|
25
|
+
|
26
|
+
from __future__ import annotations
|
27
|
+
|
28
|
+
import sys
|
29
|
+
import threading
|
30
|
+
from typing import IO, AnyStr, Callable, Protocol
|
31
|
+
|
32
|
+
|
33
|
+
class CannotCaptureConsoleError(Exception):
    """Raised when stdout/stderr could not be patched when the module loaded.

    The underlying failure, if any, is attached as ``__cause__``.
    """
|
35
|
+
|
36
|
+
|
37
|
+
class _WriteCallback(Protocol):
    """Signature of a hook that observes data written to stdout or stderr."""

    def __call__(self, data: bytes | str, written: int, /) -> None:
        """Observe one call to the stream's `write()`.

        Args:
            data: The exact object that was handed to `write()` on stdout or
                stderr (either `str` or `bytes`).
            written: The value `write()` returned, i.e. how many characters
                or bytes it reported as written.
        """
        ...
|
53
|
+
|
54
|
+
|
55
|
+
# Guards the two callback registries and the id counter below. It is a plain
# (non-reentrant) lock.
_module_lock = threading.Lock()

# Set at import time if stdout/stderr could not be patched. It is re-raised
# by capture_stdout() and capture_stderr().
_patch_exception: CannotCaptureConsoleError | None = None

# Next id handed out to a registered callback. It only ever increases, so ids
# are never reused.
_next_callback_id: int = 1

# Registered callbacks keyed by id, one registry per stream.
_stdout_callbacks: dict[int, _WriteCallback] = dict()
_stderr_callbacks: dict[int, _WriteCallback] = dict()
|
63
|
+
|
64
|
+
|
65
|
+
def capture_stdout(callback: _WriteCallback) -> Callable[[], None]:
    """Register `callback` to observe every `sys.stdout.write` call.

    Args:
        callback: Invoked after each `sys.stdout.write` with the object passed
            to `write()` and the value `write()` returned. It may run on any
            thread, so it must be thread-safe. Any exception it raises
            propagates to whoever called `write`. See `_WriteCallback` for the
            exact protocol.

    Returns:
        A zero-argument function that unregisters the callback. Calling it
        more than once is harmless.

    Raises:
        CannotCaptureConsoleError: If patching failed when this module was
            imported.
    """
    with _module_lock:
        failure = _patch_exception
        if failure:
            raise failure
        return _insert_disposably(_stdout_callbacks, callback)
|
88
|
+
|
89
|
+
|
90
|
+
def capture_stderr(callback: _WriteCallback) -> Callable[[], None]:
    """Install a callback that runs after every write to sys.stderr.

    Args:
        callback: A callback to invoke after running `sys.stderr.write`.
            This may be called from any thread, so it must be thread-safe.
            Exceptions are propagated to the caller of `write`.
            See `_WriteCallback` for the exact protocol.

    Returns:
        A function to uninstall the callback. Calling it more than once is
        harmless.

    Raises:
        CannotCaptureConsoleError: If patching failed on import.
    """
    with _module_lock:
        if _patch_exception:
            raise _patch_exception

        return _insert_disposably(
            _stderr_callbacks,
            callback,
        )
|
113
|
+
|
114
|
+
|
115
|
+
def _insert_disposably(
    callback_dict: dict[int, _WriteCallback],
    callback: _WriteCallback,
) -> Callable[[], None]:
    """Register `callback` in `callback_dict` under a fresh id.

    Expected to be called with `_module_lock` already held, as
    `capture_stdout` and `capture_stderr` do. This function increments
    `_next_callback_id` and mutates `callback_dict` without taking the lock
    itself.

    Args:
        callback_dict: The per-stream registry to insert into.
        callback: The callback to register.

    Returns:
        A function that removes the callback again. It acquires
        `_module_lock` itself, so it must not be called while that
        (non-reentrant) lock is held. Calls after the first are no-ops.
    """
    global _next_callback_id
    # Named `callback_id` rather than `id` to avoid shadowing the builtin.
    callback_id = _next_callback_id
    _next_callback_id += 1

    disposed = False

    def dispose() -> None:
        nonlocal disposed

        with _module_lock:
            if disposed:
                return

            del callback_dict[callback_id]

            disposed = True

    callback_dict[callback_id] = callback
    return dispose
|
138
|
+
|
139
|
+
|
140
|
+
def _patch(
    stdout_or_stderr: IO[AnyStr],
    callbacks: dict[int, _WriteCallback],
) -> None:
    """Replace `stdout_or_stderr.write` with a wrapper that notifies callbacks.

    The wrapper first calls the original `write`, then invokes every callback
    currently registered in `callbacks` with the written object and
    `write`'s return value, and finally returns that value.

    Raises:
        Whatever reading or assigning the stream's `write` attribute raises.
        The module-level import code turns this into
        `CannotCaptureConsoleError`.
    """
    original = stdout_or_stderr.write

    def write_with_callbacks(s: AnyStr, /) -> int:
        count = original(s)

        # Snapshot under the lock: a callback may add or remove callbacks,
        # and other threads may do the same concurrently. The callbacks
        # themselves run outside the lock.
        with _module_lock:
            snapshot = tuple(callbacks.values())

        for notify in snapshot:
            notify(s, count)

        return count

    # mypy==1.14.1 rejects this assignment (a plain callable replacing an
    # overloaded method), hence the ignore.
    stdout_or_stderr.write = write_with_callbacks  # type: ignore
|
165
|
+
|
166
|
+
|
167
|
+
# Patch both streams exactly once, at import time. A failure is not raised
# here: it is stored in `_patch_exception` and re-raised by capture_stdout()
# and capture_stderr(). If stdout was patched before stderr failed, stdout
# stays patched.
try:
    _patch(sys.stdout, _stdout_callbacks)
    _patch(sys.stderr, _stderr_callbacks)
except Exception as _patch_failure:
    _patch_exception = CannotCaptureConsoleError()
    _patch_exception.__cause__ = _patch_failure
|
wandb/sdk/lib/redirect.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
try:
|
2
4
|
import fcntl
|
3
5
|
import pty
|
@@ -17,8 +19,10 @@ import sys
|
|
17
19
|
import threading
|
18
20
|
import time
|
19
21
|
from collections import defaultdict
|
22
|
+
from typing import Callable, Iterable, Literal
|
20
23
|
|
21
24
|
import wandb
|
25
|
+
from wandb.sdk.lib import console_capture
|
22
26
|
|
23
27
|
|
24
28
|
class _Numpy: # fallback in case numpy is not available
|
@@ -55,8 +59,6 @@ except ImportError:
|
|
55
59
|
|
56
60
|
logger = logging.getLogger("wandb")
|
57
61
|
|
58
|
-
_redirects = {"stdout": None, "stderr": None}
|
59
|
-
|
60
62
|
|
61
63
|
ANSI_CSI_RE = re.compile("\001?\033\\[((?:\\d|;)*)([a-zA-Z])\002?")
|
62
64
|
ANSI_OSC_RE = re.compile("\001?\033\\]([^\a]*)(\a)\002?")
|
@@ -491,7 +493,11 @@ _MIN_CALLBACK_INTERVAL = 2 # seconds
|
|
491
493
|
|
492
494
|
|
493
495
|
class RedirectBase:
|
494
|
-
def __init__(
|
496
|
+
def __init__(
|
497
|
+
self,
|
498
|
+
src: Literal["stdout", "stderr"],
|
499
|
+
cbs: Iterable[Callable[[str], None]] = (),
|
500
|
+
) -> None:
|
495
501
|
"""# Arguments.
|
496
502
|
|
497
503
|
`src`: Source stream to be redirected. "stdout" or "stderr".
|
@@ -499,7 +505,7 @@ class RedirectBase:
|
|
499
505
|
|
500
506
|
"""
|
501
507
|
assert hasattr(sys, src)
|
502
|
-
self.src = src
|
508
|
+
self.src: Literal["stdout", "stderr"] = src
|
503
509
|
self.cbs = cbs
|
504
510
|
|
505
511
|
@property
|
@@ -514,71 +520,82 @@ class RedirectBase:
|
|
514
520
|
def src_wrapped_stream(self):
|
515
521
|
return getattr(sys, self.src)
|
516
522
|
|
517
|
-
def
|
523
|
+
def install(self) -> None:
|
518
524
|
pass
|
519
525
|
|
520
|
-
def
|
521
|
-
|
522
|
-
if curr_redirect and curr_redirect != self:
|
523
|
-
curr_redirect.uninstall()
|
524
|
-
_redirects[self.src] = self
|
525
|
-
|
526
|
-
def uninstall(self):
|
527
|
-
if _redirects[self.src] != self:
|
528
|
-
return
|
529
|
-
_redirects[self.src] = None
|
526
|
+
def uninstall(self) -> None:
|
527
|
+
pass
|
530
528
|
|
531
529
|
|
532
530
|
class StreamWrapper(RedirectBase):
|
533
531
|
"""Patches the write method of current sys.stdout/sys.stderr."""
|
534
532
|
|
535
|
-
def __init__(
|
533
|
+
def __init__(
|
534
|
+
self,
|
535
|
+
src: Literal["stdout", "stderr"],
|
536
|
+
cbs: Iterable[Callable[[str], None]] = (),
|
537
|
+
) -> None:
|
536
538
|
super().__init__(src=src, cbs=cbs)
|
537
|
-
self.
|
539
|
+
self._uninstall: Callable[[], None] | None = None
|
538
540
|
self._emulator = TerminalEmulator()
|
541
|
+
self._queue: queue.Queue[str] = queue.Queue()
|
542
|
+
self._stopped = threading.Event()
|
539
543
|
|
540
|
-
def _emulator_write(self):
|
544
|
+
def _emulator_write(self) -> None:
|
541
545
|
while True:
|
542
546
|
if self._queue.empty():
|
543
547
|
if self._stopped.is_set():
|
544
548
|
return
|
545
549
|
time.sleep(0.5)
|
546
550
|
continue
|
547
|
-
|
551
|
+
|
552
|
+
data: list[str] = []
|
548
553
|
while not self._queue.empty():
|
549
554
|
data.append(self._queue.get())
|
555
|
+
|
550
556
|
if self._stopped.is_set() and sum(map(len, data)) > 100000:
|
551
557
|
wandb.termlog("Terminal output too large. Logging without processing.")
|
552
558
|
self.flush()
|
553
|
-
|
559
|
+
|
560
|
+
for line in data:
|
561
|
+
self.flush(line)
|
562
|
+
|
554
563
|
return
|
564
|
+
|
555
565
|
try:
|
556
566
|
self._emulator.write("".join(data))
|
557
567
|
except Exception:
|
558
568
|
pass
|
559
569
|
|
560
|
-
def _callback(self):
|
570
|
+
def _callback(self) -> None:
|
561
571
|
while not (self._stopped.is_set() and self._queue.empty()):
|
562
572
|
self.flush()
|
563
573
|
time.sleep(_MIN_CALLBACK_INTERVAL)
|
564
574
|
|
565
|
-
def
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
old_write = stream.write
|
571
|
-
self._prev_callback_timestamp = time.time()
|
572
|
-
self._old_write = old_write
|
575
|
+
def _on_write(self, data: str | bytes, written: int, /) -> None:
|
576
|
+
if isinstance(data, bytes):
|
577
|
+
written_data = data[:written].decode("utf-8")
|
578
|
+
else:
|
579
|
+
written_data = data[:written]
|
573
580
|
|
574
|
-
|
575
|
-
self._old_write(data)
|
576
|
-
self._queue.put(data)
|
581
|
+
self._queue.put(written_data)
|
577
582
|
|
578
|
-
|
583
|
+
def install(self) -> None:
|
584
|
+
if self._uninstall:
|
585
|
+
return
|
586
|
+
|
587
|
+
try:
|
588
|
+
if self.src == "stdout":
|
589
|
+
self._uninstall = console_capture.capture_stdout(self._on_write)
|
590
|
+
else:
|
591
|
+
self._uninstall = console_capture.capture_stderr(self._on_write)
|
592
|
+
except console_capture.CannotCaptureConsoleError:
|
593
|
+
logger.exception("failed to install %s hooks", self.src)
|
594
|
+
wandb.termwarn(
|
595
|
+
f"Failed to wrap {self.src}. Console logs will not be captured.",
|
596
|
+
)
|
597
|
+
return
|
579
598
|
|
580
|
-
self._queue = queue.Queue()
|
581
|
-
self._stopped = threading.Event()
|
582
599
|
self._emulator_write_thread = threading.Thread(target=self._emulator_write)
|
583
600
|
self._emulator_write_thread.daemon = True
|
584
601
|
self._emulator_write_thread.start()
|
@@ -588,25 +605,25 @@ class StreamWrapper(RedirectBase):
|
|
588
605
|
self._callback_thread.daemon = True
|
589
606
|
self._callback_thread.start()
|
590
607
|
|
591
|
-
|
592
|
-
|
593
|
-
def flush(self, data=None):
|
608
|
+
def flush(self, data: str | None = None) -> None:
|
594
609
|
if data is None:
|
595
610
|
try:
|
596
611
|
data = self._emulator.read().encode("utf-8")
|
597
612
|
except Exception:
|
598
|
-
|
613
|
+
logger.exception("exception reading TerminalEmulator")
|
614
|
+
|
599
615
|
if data:
|
600
616
|
for cb in self.cbs:
|
601
617
|
try:
|
602
618
|
cb(data)
|
603
619
|
except Exception:
|
604
|
-
|
620
|
+
logger.exception("exception in StreamWrapper callback")
|
605
621
|
|
606
|
-
def uninstall(self):
|
607
|
-
if not self.
|
622
|
+
def uninstall(self) -> None:
|
623
|
+
if not self._uninstall:
|
608
624
|
return
|
609
|
-
|
625
|
+
|
626
|
+
self._uninstall()
|
610
627
|
|
611
628
|
self._stopped.set()
|
612
629
|
self._emulator_write_thread.join(timeout=5)
|
@@ -616,9 +633,6 @@ class StreamWrapper(RedirectBase):
|
|
616
633
|
wandb.termlog("Done.")
|
617
634
|
self.flush()
|
618
635
|
|
619
|
-
self._installed = False
|
620
|
-
super().uninstall()
|
621
|
-
|
622
636
|
|
623
637
|
class StreamRawWrapper(RedirectBase):
|
624
638
|
"""Patches the write method of current sys.stdout/sys.stderr.
|
@@ -626,40 +640,44 @@ class StreamRawWrapper(RedirectBase):
|
|
626
640
|
Captures data in a raw form rather than using the emulator
|
627
641
|
"""
|
628
642
|
|
629
|
-
def __init__(
|
643
|
+
def __init__(
|
644
|
+
self,
|
645
|
+
src: Literal["stdout", "stderr"],
|
646
|
+
cbs: Iterable[Callable[[str], None]] = (),
|
647
|
+
) -> None:
|
630
648
|
super().__init__(src=src, cbs=cbs)
|
631
|
-
self.
|
649
|
+
self._uninstall: Callable[[], None] | None = None
|
632
650
|
|
633
|
-
def
|
634
|
-
|
635
|
-
|
651
|
+
def _on_write(self, data: str | bytes, written: int, /) -> None:
|
652
|
+
if isinstance(data, bytes):
|
653
|
+
written_data = data[:written].decode("utf-8")
|
654
|
+
else:
|
655
|
+
written_data = data[:written]
|
636
656
|
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
self._prev_callback_timestamp = time.time()
|
657
|
+
for cb in self.cbs:
|
658
|
+
try:
|
659
|
+
cb(written_data)
|
660
|
+
except Exception:
|
661
|
+
logger.exception("error in %s callback", self.src)
|
643
662
|
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
try:
|
648
|
-
cb(data)
|
649
|
-
except Exception:
|
650
|
-
# TODO: Figure out why this was needed and log or error out appropriately
|
651
|
-
# it might have been strange terminals? maybe shutdown cases?
|
652
|
-
pass
|
663
|
+
def install(self) -> None:
|
664
|
+
if self._uninstall:
|
665
|
+
return
|
653
666
|
|
654
|
-
|
655
|
-
|
667
|
+
try:
|
668
|
+
if self.src == "stdout":
|
669
|
+
self._uninstall = console_capture.capture_stdout(self._on_write)
|
670
|
+
else:
|
671
|
+
self._uninstall = console_capture.capture_stderr(self._on_write)
|
672
|
+
except console_capture.CannotCaptureConsoleError:
|
673
|
+
logger.exception("failed to install %s hooks", self.src)
|
674
|
+
wandb.termwarn(
|
675
|
+
f"Failed to wrap {self.src}. Console logs will not be captured.",
|
676
|
+
)
|
656
677
|
|
657
|
-
def uninstall(self):
|
658
|
-
if
|
659
|
-
|
660
|
-
self.src_wrapped_stream.write = self._old_write
|
661
|
-
self._installed = False
|
662
|
-
super().uninstall()
|
678
|
+
def uninstall(self) -> None:
|
679
|
+
if self._uninstall:
|
680
|
+
self._uninstall()
|
663
681
|
|
664
682
|
|
665
683
|
class _WindowSizeChangeHandler:
|
@@ -708,6 +726,8 @@ class _WindowSizeChangeHandler:
|
|
708
726
|
|
709
727
|
_WSCH = _WindowSizeChangeHandler()
|
710
728
|
|
729
|
+
_redirects: dict[str, Redirect | None] = {"stdout": None, "stderr": None}
|
730
|
+
|
711
731
|
|
712
732
|
class Redirect(RedirectBase):
|
713
733
|
"""Redirect low level file descriptors."""
|
@@ -725,7 +745,11 @@ class Redirect(RedirectBase):
|
|
725
745
|
return r, w
|
726
746
|
|
727
747
|
def install(self):
|
728
|
-
|
748
|
+
curr_redirect = _redirects.get(self.src)
|
749
|
+
if curr_redirect and curr_redirect != self:
|
750
|
+
curr_redirect.uninstall()
|
751
|
+
_redirects[self.src] = self
|
752
|
+
|
729
753
|
if self._installed:
|
730
754
|
return
|
731
755
|
self._pipe_read_fd, self._pipe_write_fd = self._pipe()
|
@@ -776,7 +800,9 @@ class Redirect(RedirectBase):
|
|
776
800
|
self.flush()
|
777
801
|
|
778
802
|
_WSCH.remove_fd(self._pipe_read_fd)
|
779
|
-
|
803
|
+
|
804
|
+
if _redirects[self.src] == self:
|
805
|
+
_redirects[self.src] = None
|
780
806
|
|
781
807
|
def flush(self, data=None):
|
782
808
|
if data is None:
|
@@ -10,10 +10,11 @@ from wandb.proto import wandb_settings_pb2
|
|
10
10
|
from wandb.sdk import wandb_settings
|
11
11
|
from wandb.sdk.interface.interface import InterfaceBase
|
12
12
|
from wandb.sdk.interface.interface_sock import InterfaceSock
|
13
|
+
from wandb.sdk.interface.router_sock import MessageSockRouter
|
13
14
|
from wandb.sdk.lib import service_token
|
14
15
|
from wandb.sdk.lib.exit_hooks import ExitHooks
|
15
|
-
from wandb.sdk.lib.sock_client import SockClient,
|
16
|
-
from wandb.sdk.mailbox import Mailbox
|
16
|
+
from wandb.sdk.lib.sock_client import SockClient, SockClientClosedError
|
17
|
+
from wandb.sdk.mailbox import HandleAbandonedError, Mailbox, MailboxClosedError
|
17
18
|
from wandb.sdk.service import service
|
18
19
|
|
19
20
|
|
@@ -115,6 +116,9 @@ class ServiceConnection:
|
|
115
116
|
"""Returns a new ServiceConnection.
|
116
117
|
|
117
118
|
Args:
|
119
|
+
mailbox: The mailbox to use for all communication over the socket.
|
120
|
+
router: A handle to the thread that reads from the socket and
|
121
|
+
updates the mailbox.
|
118
122
|
client: A socket that's connected to the service.
|
119
123
|
proc: The service process if we own it, or None otherwise.
|
120
124
|
cleanup: A callback to run on teardown before doing anything.
|
@@ -124,9 +128,12 @@ class ServiceConnection:
|
|
124
128
|
self._torn_down = False
|
125
129
|
self._cleanup = cleanup
|
126
130
|
|
127
|
-
|
131
|
+
self._mailbox = Mailbox()
|
132
|
+
self._router = MessageSockRouter(self._client, self._mailbox)
|
133
|
+
|
134
|
+
def make_interface(self, stream_id: str) -> InterfaceBase:
|
128
135
|
"""Returns an interface for communicating with the service."""
|
129
|
-
return InterfaceSock(self._client,
|
136
|
+
return InterfaceSock(self._client, self._mailbox, stream_id=stream_id)
|
130
137
|
|
131
138
|
def send_record(self, record: pb.Record) -> None:
|
132
139
|
"""Sends data to the service."""
|
@@ -141,13 +148,13 @@ class ServiceConnection:
|
|
141
148
|
request = spb.ServerInformInitRequest()
|
142
149
|
request.settings.CopyFrom(settings)
|
143
150
|
request._info.stream_id = run_id
|
144
|
-
self._client.
|
151
|
+
self._client.send_server_request(spb.ServerRequest(inform_init=request))
|
145
152
|
|
146
153
|
def inform_finish(self, run_id: str) -> None:
|
147
154
|
"""Sends an finish request to the service."""
|
148
155
|
request = spb.ServerInformFinishRequest()
|
149
156
|
request._info.stream_id = run_id
|
150
|
-
self._client.
|
157
|
+
self._client.send_server_request(spb.ServerRequest(inform_finish=request))
|
151
158
|
|
152
159
|
def inform_attach(
|
153
160
|
self,
|
@@ -157,18 +164,26 @@ class ServiceConnection:
|
|
157
164
|
|
158
165
|
Raises a WandbAttachFailedError if attaching is not possible.
|
159
166
|
"""
|
160
|
-
request = spb.
|
161
|
-
request._info.stream_id = attach_id
|
167
|
+
request = spb.ServerRequest()
|
168
|
+
request.inform_attach._info.stream_id = attach_id
|
162
169
|
|
163
170
|
try:
|
164
|
-
|
171
|
+
handle = self._mailbox.require_response(request)
|
172
|
+
self._client.send_server_request(request)
|
173
|
+
response = handle.wait_or(timeout=10)
|
165
174
|
return response.inform_attach_response.settings
|
166
|
-
|
175
|
+
|
176
|
+
except (MailboxClosedError, HandleAbandonedError, SockClientClosedError):
|
177
|
+
raise WandbAttachFailedError(
|
178
|
+
"Failed to attach: the service process is not running.",
|
179
|
+
) from None
|
180
|
+
|
181
|
+
except TimeoutError:
|
167
182
|
raise WandbAttachFailedError(
|
168
|
-
"
|
183
|
+
"Failed to attach because the run does not belong to"
|
169
184
|
" the current service process, or because the service"
|
170
185
|
" process is busy (unlikely)."
|
171
|
-
)
|
186
|
+
) from None
|
172
187
|
|
173
188
|
def inform_start(
|
174
189
|
self,
|
@@ -179,7 +194,7 @@ class ServiceConnection:
|
|
179
194
|
request = spb.ServerInformStartRequest()
|
180
195
|
request.settings.CopyFrom(settings)
|
181
196
|
request._info.stream_id = run_id
|
182
|
-
self._client.
|
197
|
+
self._client.send_server_request(spb.ServerRequest(inform_start=request))
|
183
198
|
|
184
199
|
def teardown(self, exit_code: int) -> int:
|
185
200
|
"""Shuts down the service process and returns its exit code.
|
@@ -207,10 +222,15 @@ class ServiceConnection:
|
|
207
222
|
# Clear the service token to prevent new connections from being made.
|
208
223
|
service_token.clear_service_token()
|
209
224
|
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
225
|
+
# Stop reading responses on the socket.
|
226
|
+
self._router.join()
|
227
|
+
|
228
|
+
self._client.send_server_request(
|
229
|
+
spb.ServerRequest(
|
230
|
+
inform_teardown=spb.ServerInformTeardownRequest(
|
231
|
+
exit_code=exit_code,
|
232
|
+
)
|
233
|
+
),
|
214
234
|
)
|
215
235
|
|
216
236
|
return self._proc.join()
|
wandb/sdk/lib/sock_client.py
CHANGED
@@ -150,10 +150,10 @@ class SockClient:
|
|
150
150
|
with self._lock:
|
151
151
|
self._sendall_with_error_handle(header + data)
|
152
152
|
|
153
|
-
def send_server_request(self, msg:
|
153
|
+
def send_server_request(self, msg: spb.ServerRequest) -> None:
|
154
154
|
self._send_message(msg)
|
155
155
|
|
156
|
-
def send_server_response(self, msg:
|
156
|
+
def send_server_response(self, msg: spb.ServerResponse) -> None:
|
157
157
|
try:
|
158
158
|
self._send_message(msg)
|
159
159
|
except BrokenPipeError:
|
@@ -161,56 +161,6 @@ class SockClient:
|
|
161
161
|
# things like network status poll loop, there might be a better way to quiesce
|
162
162
|
pass
|
163
163
|
|
164
|
-
def send_and_recv(
|
165
|
-
self,
|
166
|
-
*,
|
167
|
-
inform_init: Optional[spb.ServerInformInitRequest] = None,
|
168
|
-
inform_start: Optional[spb.ServerInformStartRequest] = None,
|
169
|
-
inform_attach: Optional[spb.ServerInformAttachRequest] = None,
|
170
|
-
inform_finish: Optional[spb.ServerInformFinishRequest] = None,
|
171
|
-
inform_teardown: Optional[spb.ServerInformTeardownRequest] = None,
|
172
|
-
) -> spb.ServerResponse:
|
173
|
-
self.send(
|
174
|
-
inform_init=inform_init,
|
175
|
-
inform_start=inform_start,
|
176
|
-
inform_attach=inform_attach,
|
177
|
-
inform_finish=inform_finish,
|
178
|
-
inform_teardown=inform_teardown,
|
179
|
-
)
|
180
|
-
|
181
|
-
# HACK: This assumes nothing else is reading on the socket, and that
|
182
|
-
# the next response is for this request.
|
183
|
-
response = self.read_server_response(timeout=1)
|
184
|
-
|
185
|
-
if response is None:
|
186
|
-
raise SockClientTimeoutError("No response after 1 second.")
|
187
|
-
|
188
|
-
return response
|
189
|
-
|
190
|
-
def send(
|
191
|
-
self,
|
192
|
-
*,
|
193
|
-
inform_init: Optional[spb.ServerInformInitRequest] = None,
|
194
|
-
inform_start: Optional[spb.ServerInformStartRequest] = None,
|
195
|
-
inform_attach: Optional[spb.ServerInformAttachRequest] = None,
|
196
|
-
inform_finish: Optional[spb.ServerInformFinishRequest] = None,
|
197
|
-
inform_teardown: Optional[spb.ServerInformTeardownRequest] = None,
|
198
|
-
) -> None:
|
199
|
-
server_req = spb.ServerRequest()
|
200
|
-
if inform_init:
|
201
|
-
server_req.inform_init.CopyFrom(inform_init)
|
202
|
-
elif inform_start:
|
203
|
-
server_req.inform_start.CopyFrom(inform_start)
|
204
|
-
elif inform_attach:
|
205
|
-
server_req.inform_attach.CopyFrom(inform_attach)
|
206
|
-
elif inform_finish:
|
207
|
-
server_req.inform_finish.CopyFrom(inform_finish)
|
208
|
-
elif inform_teardown:
|
209
|
-
server_req.inform_teardown.CopyFrom(inform_teardown)
|
210
|
-
else:
|
211
|
-
raise Exception("unmatched")
|
212
|
-
self.send_server_request(server_req)
|
213
|
-
|
214
164
|
def send_record_communicate(self, record: "pb.Record") -> None:
|
215
165
|
server_req = spb.ServerRequest()
|
216
166
|
server_req.request_id = record.control.mailbox_slot
|
wandb/sdk/mailbox/__init__.py
CHANGED
@@ -9,15 +9,15 @@ The Mailbox handles matching responses to requests. An internal thread
|
|
9
9
|
continuously reads data from the service and passes it to the mailbox.
|
10
10
|
"""
|
11
11
|
|
12
|
-
from .handles import HandleAbandonedError, MailboxHandle
|
13
12
|
from .mailbox import Mailbox, MailboxClosedError
|
13
|
+
from .mailbox_handle import HandleAbandonedError, MailboxHandle
|
14
14
|
from .wait_with_progress import wait_all_with_progress, wait_with_progress
|
15
15
|
|
16
16
|
__all__ = [
|
17
|
-
"HandleAbandonedError",
|
18
|
-
"MailboxHandle",
|
19
17
|
"Mailbox",
|
20
18
|
"MailboxClosedError",
|
19
|
+
"HandleAbandonedError",
|
20
|
+
"MailboxHandle",
|
21
21
|
"wait_all_with_progress",
|
22
22
|
"wait_with_progress",
|
23
23
|
]
|