wandb 0.19.6rc4__py3-none-macosx_11_0_arm64.whl → 0.19.8__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +56 -6
  3. wandb/apis/public/_generated/__init__.py +21 -0
  4. wandb/apis/public/_generated/base.py +128 -0
  5. wandb/apis/public/_generated/enums.py +4 -0
  6. wandb/apis/public/_generated/input_types.py +4 -0
  7. wandb/apis/public/_generated/operations.py +15 -0
  8. wandb/apis/public/_generated/server_features_query.py +27 -0
  9. wandb/apis/public/_generated/typing_compat.py +14 -0
  10. wandb/apis/public/api.py +192 -6
  11. wandb/apis/public/artifacts.py +13 -45
  12. wandb/apis/public/registries.py +573 -0
  13. wandb/apis/public/utils.py +36 -0
  14. wandb/bin/gpu_stats +0 -0
  15. wandb/bin/wandb-core +0 -0
  16. wandb/cli/cli.py +11 -20
  17. wandb/data_types.py +1 -1
  18. wandb/env.py +10 -0
  19. wandb/filesync/dir_watcher.py +2 -1
  20. wandb/proto/v3/wandb_internal_pb2.py +243 -222
  21. wandb/proto/v3/wandb_server_pb2.py +4 -4
  22. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  23. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  24. wandb/proto/v4/wandb_internal_pb2.py +226 -222
  25. wandb/proto/v4/wandb_server_pb2.py +4 -4
  26. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  27. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  28. wandb/proto/v5/wandb_internal_pb2.py +226 -222
  29. wandb/proto/v5/wandb_server_pb2.py +4 -4
  30. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  31. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  32. wandb/sdk/artifacts/_graphql_fragments.py +126 -0
  33. wandb/sdk/artifacts/artifact.py +51 -95
  34. wandb/sdk/backend/backend.py +17 -6
  35. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
  36. wandb/sdk/data_types/helper_types/image_mask.py +12 -6
  37. wandb/sdk/data_types/saved_model.py +35 -46
  38. wandb/sdk/data_types/video.py +7 -16
  39. wandb/sdk/interface/interface.py +87 -49
  40. wandb/sdk/interface/interface_queue.py +5 -15
  41. wandb/sdk/interface/interface_relay.py +7 -22
  42. wandb/sdk/interface/interface_shared.py +65 -136
  43. wandb/sdk/interface/interface_sock.py +3 -21
  44. wandb/sdk/interface/router.py +42 -68
  45. wandb/sdk/interface/router_queue.py +13 -11
  46. wandb/sdk/interface/router_relay.py +26 -13
  47. wandb/sdk/interface/router_sock.py +12 -16
  48. wandb/sdk/internal/handler.py +4 -3
  49. wandb/sdk/internal/internal_api.py +12 -1
  50. wandb/sdk/internal/sender.py +3 -19
  51. wandb/sdk/lib/apikey.py +87 -26
  52. wandb/sdk/lib/asyncio_compat.py +210 -0
  53. wandb/sdk/lib/console_capture.py +172 -0
  54. wandb/sdk/lib/progress.py +78 -16
  55. wandb/sdk/lib/redirect.py +102 -76
  56. wandb/sdk/lib/service_connection.py +37 -17
  57. wandb/sdk/lib/sock_client.py +6 -56
  58. wandb/sdk/mailbox/__init__.py +23 -0
  59. wandb/sdk/mailbox/mailbox.py +135 -0
  60. wandb/sdk/mailbox/mailbox_handle.py +127 -0
  61. wandb/sdk/mailbox/response_handle.py +167 -0
  62. wandb/sdk/mailbox/wait_with_progress.py +135 -0
  63. wandb/sdk/service/server_sock.py +9 -3
  64. wandb/sdk/service/streams.py +75 -78
  65. wandb/sdk/verify/verify.py +54 -2
  66. wandb/sdk/wandb_init.py +72 -75
  67. wandb/sdk/wandb_login.py +7 -4
  68. wandb/sdk/wandb_metadata.py +65 -34
  69. wandb/sdk/wandb_require.py +14 -8
  70. wandb/sdk/wandb_run.py +90 -97
  71. wandb/sdk/wandb_settings.py +10 -4
  72. wandb/sdk/wandb_setup.py +19 -8
  73. wandb/sdk/wandb_sync.py +2 -10
  74. wandb/util.py +3 -1
  75. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/METADATA +2 -2
  76. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/RECORD +79 -66
  77. wandb/sdk/interface/message_future.py +0 -27
  78. wandb/sdk/interface/message_future_poll.py +0 -50
  79. wandb/sdk/lib/mailbox.py +0 -442
  80. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/WHEEL +0 -0
  81. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/entry_points.txt +0 -0
  82. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/lib/progress.py CHANGED
@@ -2,12 +2,15 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import asyncio
5
6
  import contextlib
6
- from typing import Iterable, Iterator
7
+ import time
8
+ from typing import Iterable, Iterator, NoReturn
7
9
 
8
- import wandb
9
10
  from wandb import env
10
11
  from wandb.proto import wandb_internal_pb2 as pb
12
+ from wandb.sdk.interface import interface
13
+ from wandb.sdk.lib import asyncio_compat
11
14
 
12
15
  from . import printer as p
13
16
 
@@ -31,14 +34,67 @@ def print_sync_dedupe_stats(
31
34
  printer.display(f"W&B sync reduced upload amount by {frac:.1%}")
32
35
 
33
36
 
37
+ async def loop_printing_operation_stats(
38
+ progress: ProgressPrinter,
39
+ interface: interface.InterfaceBase,
40
+ ) -> None:
41
+ """Poll and display ongoing tasks in the internal service process.
42
+
43
+ This never returns and must be cancelled. This is meant to be used with
44
+ `mailbox.wait_with_progress()`.
45
+
46
+ Args:
47
+ progress: The printer to update with operation stats.
48
+ interface: The interface to use to poll for updates.
49
+
50
+ Raises:
51
+ HandleAbandonedError: If the mailbox associated with the interface
52
+ becomes closed.
53
+ Exception: Any other problem communicating with the service process.
54
+ """
55
+ stats: pb.OperationStats | None = None
56
+
57
+ async def loop_update_screen() -> NoReturn:
58
+ while True:
59
+ if stats:
60
+ progress.update(stats)
61
+ await asyncio.sleep(0.1)
62
+
63
+ async def loop_poll_stats() -> NoReturn:
64
+ nonlocal stats
65
+ while True:
66
+ start_time = time.monotonic()
67
+
68
+ handle = interface.deliver_operation_stats()
69
+ result = await handle.wait_async(timeout=None)
70
+ stats = result.response.operations_response.operation_stats
71
+
72
+ elapsed_time = time.monotonic() - start_time
73
+ if elapsed_time < 0.5:
74
+ await asyncio.sleep(0.5 - elapsed_time)
75
+
76
+ async with asyncio_compat.open_task_group() as task_group:
77
+ task_group.start_soon(loop_update_screen())
78
+ task_group.start_soon(loop_poll_stats())
79
+
80
+
34
81
  @contextlib.contextmanager
35
82
  def progress_printer(
36
83
  printer: p.Printer,
37
- settings: wandb.Settings | None = None,
84
+ default_text: str,
38
85
  ) -> Iterator[ProgressPrinter]:
39
- """Context manager providing an object for printing run progress."""
86
+ """Context manager providing an object for printing run progress.
87
+
88
+ Args:
89
+ printer: The printer to use.
90
+ default_text: The text to show if no information is available.
91
+ """
40
92
  with printer.dynamic_text() as text_area:
41
- yield ProgressPrinter(printer, text_area, settings)
93
+ yield ProgressPrinter(
94
+ printer,
95
+ text_area,
96
+ default_text=default_text,
97
+ )
42
98
  printer.progress_close()
43
99
 
44
100
 
@@ -49,28 +105,27 @@ class ProgressPrinter:
49
105
  self,
50
106
  printer: p.Printer,
51
107
  progress_text_area: p.DynamicText | None,
52
- settings: wandb.Settings | None,
108
+ default_text: str,
53
109
  ) -> None:
54
- self._show_operation_stats = (
55
- settings
56
- and settings.x_show_operation_stats
57
- # Not implemented by the legacy service.
58
- and not env.is_require_legacy_service()
59
- )
110
+ # Not implemented by the legacy service.
111
+ self._show_operation_stats = not env.is_require_legacy_service()
60
112
  self._printer = printer
61
113
  self._progress_text_area = progress_text_area
114
+ self._default_text = default_text
62
115
  self._tick = 0
63
116
  self._last_printed_line = ""
64
117
 
65
118
  def update(
66
119
  self,
67
- progress: list[pb.PollExitResponse],
120
+ progress: list[pb.PollExitResponse] | pb.OperationStats,
68
121
  ) -> None:
69
122
  """Update the displayed information."""
70
123
  if not progress:
71
124
  return
72
125
 
73
- if self._show_operation_stats:
126
+ if isinstance(progress, pb.OperationStats):
127
+ self._update_operation_stats([progress])
128
+ elif self._show_operation_stats:
74
129
  self._update_operation_stats(
75
130
  list(response.operation_stats for response in progress)
76
131
  )
@@ -88,6 +143,7 @@ class ProgressPrinter:
88
143
  self._progress_text_area,
89
144
  max_lines=6,
90
145
  loading_symbol=self._printer.loading_symbol(self._tick),
146
+ default_text=self._default_text,
91
147
  ).display(stats_list)
92
148
 
93
149
  else:
@@ -159,6 +215,10 @@ class ProgressPrinter:
159
215
  self._update_progress_text(line, 1.0)
160
216
 
161
217
  def _update_progress_text(self, text: str, progress: float) -> None:
218
+ if text == self._last_printed_line:
219
+ return
220
+ self._last_printed_line = text
221
+
162
222
  if self._progress_text_area:
163
223
  self._progress_text_area.set_text(text)
164
224
  else:
@@ -174,11 +234,13 @@ class _DynamicOperationStatsPrinter:
174
234
  text_area: p.DynamicText,
175
235
  max_lines: int,
176
236
  loading_symbol: str,
237
+ default_text: str,
177
238
  ) -> None:
178
239
  self._printer = printer
179
240
  self._text_area = text_area
180
241
  self._max_lines = max_lines
181
242
  self._loading_symbol = loading_symbol
243
+ self._default_text = default_text
182
244
 
183
245
  self._lines: list[str] = []
184
246
  self._ops_shown = 0
@@ -204,9 +266,9 @@ class _DynamicOperationStatsPrinter:
204
266
 
205
267
  if len(self._lines) == 0:
206
268
  if self._loading_symbol:
207
- self._text_area.set_text(f"{self._loading_symbol} Finishing up...")
269
+ self._text_area.set_text(f"{self._loading_symbol} {self._default_text}")
208
270
  else:
209
- self._text_area.set_text("Finishing up...")
271
+ self._text_area.set_text(self._default_text)
210
272
  else:
211
273
  self._text_area.set_text("\n".join(self._lines))
212
274
 
wandb/sdk/lib/redirect.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  try:
2
4
  import fcntl
3
5
  import pty
@@ -17,8 +19,10 @@ import sys
17
19
  import threading
18
20
  import time
19
21
  from collections import defaultdict
22
+ from typing import Callable, Iterable, Literal
20
23
 
21
24
  import wandb
25
+ from wandb.sdk.lib import console_capture
22
26
 
23
27
 
24
28
  class _Numpy: # fallback in case numpy is not available
@@ -55,8 +59,6 @@ except ImportError:
55
59
 
56
60
  logger = logging.getLogger("wandb")
57
61
 
58
- _redirects = {"stdout": None, "stderr": None}
59
-
60
62
 
61
63
  ANSI_CSI_RE = re.compile("\001?\033\\[((?:\\d|;)*)([a-zA-Z])\002?")
62
64
  ANSI_OSC_RE = re.compile("\001?\033\\]([^\a]*)(\a)\002?")
@@ -491,7 +493,11 @@ _MIN_CALLBACK_INTERVAL = 2 # seconds
491
493
 
492
494
 
493
495
  class RedirectBase:
494
- def __init__(self, src, cbs=()):
496
+ def __init__(
497
+ self,
498
+ src: Literal["stdout", "stderr"],
499
+ cbs: Iterable[Callable[[str], None]] = (),
500
+ ) -> None:
495
501
  """# Arguments.
496
502
 
497
503
  `src`: Source stream to be redirected. "stdout" or "stderr".
@@ -499,7 +505,7 @@ class RedirectBase:
499
505
 
500
506
  """
501
507
  assert hasattr(sys, src)
502
- self.src = src
508
+ self.src: Literal["stdout", "stderr"] = src
503
509
  self.cbs = cbs
504
510
 
505
511
  @property
@@ -514,71 +520,82 @@ class RedirectBase:
514
520
  def src_wrapped_stream(self):
515
521
  return getattr(sys, self.src)
516
522
 
517
- def save(self):
523
+ def install(self) -> None:
518
524
  pass
519
525
 
520
- def install(self):
521
- curr_redirect = _redirects.get(self.src)
522
- if curr_redirect and curr_redirect != self:
523
- curr_redirect.uninstall()
524
- _redirects[self.src] = self
525
-
526
- def uninstall(self):
527
- if _redirects[self.src] != self:
528
- return
529
- _redirects[self.src] = None
526
+ def uninstall(self) -> None:
527
+ pass
530
528
 
531
529
 
532
530
  class StreamWrapper(RedirectBase):
533
531
  """Patches the write method of current sys.stdout/sys.stderr."""
534
532
 
535
- def __init__(self, src, cbs=()):
533
+ def __init__(
534
+ self,
535
+ src: Literal["stdout", "stderr"],
536
+ cbs: Iterable[Callable[[str], None]] = (),
537
+ ) -> None:
536
538
  super().__init__(src=src, cbs=cbs)
537
- self._installed = False
539
+ self._uninstall: Callable[[], None] | None = None
538
540
  self._emulator = TerminalEmulator()
541
+ self._queue: queue.Queue[str] = queue.Queue()
542
+ self._stopped = threading.Event()
539
543
 
540
- def _emulator_write(self):
544
+ def _emulator_write(self) -> None:
541
545
  while True:
542
546
  if self._queue.empty():
543
547
  if self._stopped.is_set():
544
548
  return
545
549
  time.sleep(0.5)
546
550
  continue
547
- data = []
551
+
552
+ data: list[str] = []
548
553
  while not self._queue.empty():
549
554
  data.append(self._queue.get())
555
+
550
556
  if self._stopped.is_set() and sum(map(len, data)) > 100000:
551
557
  wandb.termlog("Terminal output too large. Logging without processing.")
552
558
  self.flush()
553
- [self.flush(line.encode("utf-8")) for line in data]
559
+
560
+ for line in data:
561
+ self.flush(line)
562
+
554
563
  return
564
+
555
565
  try:
556
566
  self._emulator.write("".join(data))
557
567
  except Exception:
558
568
  pass
559
569
 
560
- def _callback(self):
570
+ def _callback(self) -> None:
561
571
  while not (self._stopped.is_set() and self._queue.empty()):
562
572
  self.flush()
563
573
  time.sleep(_MIN_CALLBACK_INTERVAL)
564
574
 
565
- def install(self):
566
- super().install()
567
- if self._installed:
568
- return
569
- stream = self.src_wrapped_stream
570
- old_write = stream.write
571
- self._prev_callback_timestamp = time.time()
572
- self._old_write = old_write
575
+ def _on_write(self, data: str | bytes, written: int, /) -> None:
576
+ if isinstance(data, bytes):
577
+ written_data = data[:written].decode("utf-8")
578
+ else:
579
+ written_data = data[:written]
573
580
 
574
- def write(data):
575
- self._old_write(data)
576
- self._queue.put(data)
581
+ self._queue.put(written_data)
577
582
 
578
- stream.write = write
583
+ def install(self) -> None:
584
+ if self._uninstall:
585
+ return
586
+
587
+ try:
588
+ if self.src == "stdout":
589
+ self._uninstall = console_capture.capture_stdout(self._on_write)
590
+ else:
591
+ self._uninstall = console_capture.capture_stderr(self._on_write)
592
+ except console_capture.CannotCaptureConsoleError:
593
+ logger.exception("failed to install %s hooks", self.src)
594
+ wandb.termwarn(
595
+ f"Failed to wrap {self.src}. Console logs will not be captured.",
596
+ )
597
+ return
579
598
 
580
- self._queue = queue.Queue()
581
- self._stopped = threading.Event()
582
599
  self._emulator_write_thread = threading.Thread(target=self._emulator_write)
583
600
  self._emulator_write_thread.daemon = True
584
601
  self._emulator_write_thread.start()
@@ -588,25 +605,25 @@ class StreamWrapper(RedirectBase):
588
605
  self._callback_thread.daemon = True
589
606
  self._callback_thread.start()
590
607
 
591
- self._installed = True
592
-
593
- def flush(self, data=None):
608
+ def flush(self, data: str | None = None) -> None:
594
609
  if data is None:
595
610
  try:
596
611
  data = self._emulator.read().encode("utf-8")
597
612
  except Exception:
598
- pass
613
+ logger.exception("exception reading TerminalEmulator")
614
+
599
615
  if data:
600
616
  for cb in self.cbs:
601
617
  try:
602
618
  cb(data)
603
619
  except Exception:
604
- pass # TODO(frz)
620
+ logger.exception("exception in StreamWrapper callback")
605
621
 
606
- def uninstall(self):
607
- if not self._installed:
622
+ def uninstall(self) -> None:
623
+ if not self._uninstall:
608
624
  return
609
- self.src_wrapped_stream.write = self._old_write
625
+
626
+ self._uninstall()
610
627
 
611
628
  self._stopped.set()
612
629
  self._emulator_write_thread.join(timeout=5)
@@ -616,9 +633,6 @@ class StreamWrapper(RedirectBase):
616
633
  wandb.termlog("Done.")
617
634
  self.flush()
618
635
 
619
- self._installed = False
620
- super().uninstall()
621
-
622
636
 
623
637
  class StreamRawWrapper(RedirectBase):
624
638
  """Patches the write method of current sys.stdout/sys.stderr.
@@ -626,40 +640,44 @@ class StreamRawWrapper(RedirectBase):
626
640
  Captures data in a raw form rather than using the emulator
627
641
  """
628
642
 
629
- def __init__(self, src, cbs=()):
643
+ def __init__(
644
+ self,
645
+ src: Literal["stdout", "stderr"],
646
+ cbs: Iterable[Callable[[str], None]] = (),
647
+ ) -> None:
630
648
  super().__init__(src=src, cbs=cbs)
631
- self._installed = False
649
+ self._uninstall: Callable[[], None] | None = None
632
650
 
633
- def save(self):
634
- stream = self.src_wrapped_stream
635
- self._old_write = stream.write
651
+ def _on_write(self, data: str | bytes, written: int, /) -> None:
652
+ if isinstance(data, bytes):
653
+ written_data = data[:written].decode("utf-8")
654
+ else:
655
+ written_data = data[:written]
636
656
 
637
- def install(self):
638
- super().install()
639
- if self._installed:
640
- return
641
- stream = self.src_wrapped_stream
642
- self._prev_callback_timestamp = time.time()
657
+ for cb in self.cbs:
658
+ try:
659
+ cb(written_data)
660
+ except Exception:
661
+ logger.exception("error in %s callback", self.src)
643
662
 
644
- def write(data):
645
- self._old_write(data)
646
- for cb in self.cbs:
647
- try:
648
- cb(data)
649
- except Exception:
650
- # TODO: Figure out why this was needed and log or error out appropriately
651
- # it might have been strange terminals? maybe shutdown cases?
652
- pass
663
+ def install(self) -> None:
664
+ if self._uninstall:
665
+ return
653
666
 
654
- stream.write = write
655
- self._installed = True
667
+ try:
668
+ if self.src == "stdout":
669
+ self._uninstall = console_capture.capture_stdout(self._on_write)
670
+ else:
671
+ self._uninstall = console_capture.capture_stderr(self._on_write)
672
+ except console_capture.CannotCaptureConsoleError:
673
+ logger.exception("failed to install %s hooks", self.src)
674
+ wandb.termwarn(
675
+ f"Failed to wrap {self.src}. Console logs will not be captured.",
676
+ )
656
677
 
657
- def uninstall(self):
658
- if not self._installed:
659
- return
660
- self.src_wrapped_stream.write = self._old_write
661
- self._installed = False
662
- super().uninstall()
678
+ def uninstall(self) -> None:
679
+ if self._uninstall:
680
+ self._uninstall()
663
681
 
664
682
 
665
683
  class _WindowSizeChangeHandler:
@@ -708,6 +726,8 @@ class _WindowSizeChangeHandler:
708
726
 
709
727
  _WSCH = _WindowSizeChangeHandler()
710
728
 
729
+ _redirects: dict[str, Redirect | None] = {"stdout": None, "stderr": None}
730
+
711
731
 
712
732
  class Redirect(RedirectBase):
713
733
  """Redirect low level file descriptors."""
@@ -725,7 +745,11 @@ class Redirect(RedirectBase):
725
745
  return r, w
726
746
 
727
747
  def install(self):
728
- super().install()
748
+ curr_redirect = _redirects.get(self.src)
749
+ if curr_redirect and curr_redirect != self:
750
+ curr_redirect.uninstall()
751
+ _redirects[self.src] = self
752
+
729
753
  if self._installed:
730
754
  return
731
755
  self._pipe_read_fd, self._pipe_write_fd = self._pipe()
@@ -776,7 +800,9 @@ class Redirect(RedirectBase):
776
800
  self.flush()
777
801
 
778
802
  _WSCH.remove_fd(self._pipe_read_fd)
779
- super().uninstall()
803
+
804
+ if _redirects[self.src] == self:
805
+ _redirects[self.src] = None
780
806
 
781
807
  def flush(self, data=None):
782
808
  if data is None:
wandb/sdk/lib/service_connection.py CHANGED
@@ -10,10 +10,11 @@ from wandb.proto import wandb_settings_pb2
10
10
  from wandb.sdk import wandb_settings
11
11
  from wandb.sdk.interface.interface import InterfaceBase
12
12
  from wandb.sdk.interface.interface_sock import InterfaceSock
13
+ from wandb.sdk.interface.router_sock import MessageSockRouter
13
14
  from wandb.sdk.lib import service_token
14
15
  from wandb.sdk.lib.exit_hooks import ExitHooks
15
- from wandb.sdk.lib.mailbox import Mailbox
16
- from wandb.sdk.lib.sock_client import SockClient, SockClientTimeoutError
16
+ from wandb.sdk.lib.sock_client import SockClient, SockClientClosedError
17
+ from wandb.sdk.mailbox import HandleAbandonedError, Mailbox, MailboxClosedError
17
18
  from wandb.sdk.service import service
18
19
 
19
20
 
@@ -115,6 +116,9 @@ class ServiceConnection:
115
116
  """Returns a new ServiceConnection.
116
117
 
117
118
  Args:
119
+ mailbox: The mailbox to use for all communication over the socket.
120
+ router: A handle to the thread that reads from the socket and
121
+ updates the mailbox.
118
122
  client: A socket that's connected to the service.
119
123
  proc: The service process if we own it, or None otherwise.
120
124
  cleanup: A callback to run on teardown before doing anything.
@@ -124,9 +128,12 @@ class ServiceConnection:
124
128
  self._torn_down = False
125
129
  self._cleanup = cleanup
126
130
 
127
- def make_interface(self, mailbox: Mailbox, stream_id: str) -> InterfaceBase:
131
+ self._mailbox = Mailbox()
132
+ self._router = MessageSockRouter(self._client, self._mailbox)
133
+
134
+ def make_interface(self, stream_id: str) -> InterfaceBase:
128
135
  """Returns an interface for communicating with the service."""
129
- return InterfaceSock(self._client, mailbox, stream_id=stream_id)
136
+ return InterfaceSock(self._client, self._mailbox, stream_id=stream_id)
130
137
 
131
138
  def send_record(self, record: pb.Record) -> None:
132
139
  """Sends data to the service."""
@@ -141,13 +148,13 @@ class ServiceConnection:
141
148
  request = spb.ServerInformInitRequest()
142
149
  request.settings.CopyFrom(settings)
143
150
  request._info.stream_id = run_id
144
- self._client.send(inform_init=request)
151
+ self._client.send_server_request(spb.ServerRequest(inform_init=request))
145
152
 
146
153
  def inform_finish(self, run_id: str) -> None:
147
154
  """Sends an finish request to the service."""
148
155
  request = spb.ServerInformFinishRequest()
149
156
  request._info.stream_id = run_id
150
- self._client.send(inform_finish=request)
157
+ self._client.send_server_request(spb.ServerRequest(inform_finish=request))
151
158
 
152
159
  def inform_attach(
153
160
  self,
@@ -157,18 +164,26 @@ class ServiceConnection:
157
164
 
158
165
  Raises a WandbAttachFailedError if attaching is not possible.
159
166
  """
160
- request = spb.ServerInformAttachRequest()
161
- request._info.stream_id = attach_id
167
+ request = spb.ServerRequest()
168
+ request.inform_attach._info.stream_id = attach_id
162
169
 
163
170
  try:
164
- response = self._client.send_and_recv(inform_attach=request)
171
+ handle = self._mailbox.require_response(request)
172
+ self._client.send_server_request(request)
173
+ response = handle.wait_or(timeout=10)
165
174
  return response.inform_attach_response.settings
166
- except SockClientTimeoutError:
175
+
176
+ except (MailboxClosedError, HandleAbandonedError, SockClientClosedError):
177
+ raise WandbAttachFailedError(
178
+ "Failed to attach: the service process is not running.",
179
+ ) from None
180
+
181
+ except TimeoutError:
167
182
  raise WandbAttachFailedError(
168
- "Could not attach because the run does not belong to"
183
+ "Failed to attach because the run does not belong to"
169
184
  " the current service process, or because the service"
170
185
  " process is busy (unlikely)."
171
- )
186
+ ) from None
172
187
 
173
188
  def inform_start(
174
189
  self,
@@ -179,7 +194,7 @@ class ServiceConnection:
179
194
  request = spb.ServerInformStartRequest()
180
195
  request.settings.CopyFrom(settings)
181
196
  request._info.stream_id = run_id
182
- self._client.send(inform_start=request)
197
+ self._client.send_server_request(spb.ServerRequest(inform_start=request))
183
198
 
184
199
  def teardown(self, exit_code: int) -> int:
185
200
  """Shuts down the service process and returns its exit code.
@@ -207,10 +222,15 @@ class ServiceConnection:
207
222
  # Clear the service token to prevent new connections from being made.
208
223
  service_token.clear_service_token()
209
224
 
210
- self._client.send(
211
- inform_teardown=spb.ServerInformTeardownRequest(
212
- exit_code=exit_code,
213
- )
225
+ # Stop reading responses on the socket.
226
+ self._router.join()
227
+
228
+ self._client.send_server_request(
229
+ spb.ServerRequest(
230
+ inform_teardown=spb.ServerInformTeardownRequest(
231
+ exit_code=exit_code,
232
+ )
233
+ ),
214
234
  )
215
235
 
216
236
  return self._proc.join()
wandb/sdk/lib/sock_client.py CHANGED
@@ -150,10 +150,10 @@ class SockClient:
150
150
  with self._lock:
151
151
  self._sendall_with_error_handle(header + data)
152
152
 
153
- def send_server_request(self, msg: Any) -> None:
153
+ def send_server_request(self, msg: spb.ServerRequest) -> None:
154
154
  self._send_message(msg)
155
155
 
156
- def send_server_response(self, msg: Any) -> None:
156
+ def send_server_response(self, msg: spb.ServerResponse) -> None:
157
157
  try:
158
158
  self._send_message(msg)
159
159
  except BrokenPipeError:
@@ -161,63 +161,15 @@ class SockClient:
161
161
  # things like network status poll loop, there might be a better way to quiesce
162
162
  pass
163
163
 
164
- def send_and_recv(
165
- self,
166
- *,
167
- inform_init: Optional[spb.ServerInformInitRequest] = None,
168
- inform_start: Optional[spb.ServerInformStartRequest] = None,
169
- inform_attach: Optional[spb.ServerInformAttachRequest] = None,
170
- inform_finish: Optional[spb.ServerInformFinishRequest] = None,
171
- inform_teardown: Optional[spb.ServerInformTeardownRequest] = None,
172
- ) -> spb.ServerResponse:
173
- self.send(
174
- inform_init=inform_init,
175
- inform_start=inform_start,
176
- inform_attach=inform_attach,
177
- inform_finish=inform_finish,
178
- inform_teardown=inform_teardown,
179
- )
180
- # TODO: this solution is fragile, but for checking attach
181
- # it should be relatively stable.
182
- # This pass would be solved as part of the fix in https://wandb.atlassian.net/browse/WB-8709
183
- response = self.read_server_response(timeout=1)
184
-
185
- if response is None:
186
- raise SockClientTimeoutError("No response after 1 second.")
187
-
188
- return response
189
-
190
- def send(
191
- self,
192
- *,
193
- inform_init: Optional[spb.ServerInformInitRequest] = None,
194
- inform_start: Optional[spb.ServerInformStartRequest] = None,
195
- inform_attach: Optional[spb.ServerInformAttachRequest] = None,
196
- inform_finish: Optional[spb.ServerInformFinishRequest] = None,
197
- inform_teardown: Optional[spb.ServerInformTeardownRequest] = None,
198
- ) -> None:
199
- server_req = spb.ServerRequest()
200
- if inform_init:
201
- server_req.inform_init.CopyFrom(inform_init)
202
- elif inform_start:
203
- server_req.inform_start.CopyFrom(inform_start)
204
- elif inform_attach:
205
- server_req.inform_attach.CopyFrom(inform_attach)
206
- elif inform_finish:
207
- server_req.inform_finish.CopyFrom(inform_finish)
208
- elif inform_teardown:
209
- server_req.inform_teardown.CopyFrom(inform_teardown)
210
- else:
211
- raise Exception("unmatched")
212
- self.send_server_request(server_req)
213
-
214
164
  def send_record_communicate(self, record: "pb.Record") -> None:
215
165
  server_req = spb.ServerRequest()
166
+ server_req.request_id = record.control.mailbox_slot
216
167
  server_req.record_communicate.CopyFrom(record)
217
168
  self.send_server_request(server_req)
218
169
 
219
170
  def send_record_publish(self, record: "pb.Record") -> None:
220
171
  server_req = spb.ServerRequest()
172
+ server_req.request_id = record.control.mailbox_slot
221
173
  server_req.record_publish.CopyFrom(record)
222
174
  self.send_server_request(server_req)
223
175
 
@@ -256,10 +208,8 @@ class SockClient:
256
208
  data = self._sock.recv(self._bufsize)
257
209
  except socket.timeout:
258
210
  break
259
- except ConnectionResetError:
260
- raise SockClientClosedError
261
- except OSError:
262
- raise SockClientClosedError
211
+ except OSError as e:
212
+ raise SockClientClosedError from e
263
213
  finally:
264
214
  if timeout:
265
215
  self._sock.settimeout(None)