wandb 0.21.1__py3-none-win32.whl → 0.21.2__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/api.py +1 -2
  4. wandb/apis/public/artifacts.py +3 -5
  5. wandb/apis/public/registries/_utils.py +14 -16
  6. wandb/apis/public/registries/registries_search.py +176 -289
  7. wandb/apis/public/reports.py +13 -10
  8. wandb/automations/_generated/delete_automation.py +1 -3
  9. wandb/automations/_generated/enums.py +13 -11
  10. wandb/bin/gpu_stats.exe +0 -0
  11. wandb/bin/wandb-core +0 -0
  12. wandb/cli/cli.py +47 -2
  13. wandb/integration/metaflow/data_pandas.py +2 -2
  14. wandb/integration/metaflow/data_pytorch.py +75 -0
  15. wandb/integration/metaflow/data_sklearn.py +76 -0
  16. wandb/integration/metaflow/metaflow.py +16 -87
  17. wandb/integration/weave/__init__.py +6 -0
  18. wandb/integration/weave/interface.py +49 -0
  19. wandb/integration/weave/weave.py +63 -0
  20. wandb/proto/v3/wandb_internal_pb2.py +3 -2
  21. wandb/proto/v4/wandb_internal_pb2.py +2 -2
  22. wandb/proto/v5/wandb_internal_pb2.py +2 -2
  23. wandb/proto/v6/wandb_internal_pb2.py +2 -2
  24. wandb/sdk/artifacts/_factories.py +17 -0
  25. wandb/sdk/artifacts/_generated/__init__.py +221 -13
  26. wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
  27. wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
  28. wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
  29. wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
  30. wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
  31. wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
  32. wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
  33. wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
  34. wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
  35. wandb/sdk/artifacts/_generated/enums.py +5 -0
  36. wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
  37. wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
  38. wandb/sdk/artifacts/_generated/fragments.py +279 -41
  39. wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
  40. wandb/sdk/artifacts/_generated/operations.py +654 -51
  41. wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
  42. wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
  43. wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
  44. wandb/sdk/artifacts/_graphql_fragments.py +3 -86
  45. wandb/sdk/artifacts/_validators.py +6 -4
  46. wandb/sdk/artifacts/artifact.py +406 -543
  47. wandb/sdk/artifacts/artifact_file_cache.py +10 -6
  48. wandb/sdk/artifacts/artifact_manifest.py +10 -9
  49. wandb/sdk/artifacts/artifact_manifest_entry.py +9 -10
  50. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
  51. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  53. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
  54. wandb/sdk/data_types/video.py +2 -2
  55. wandb/sdk/interface/interface_queue.py +1 -4
  56. wandb/sdk/interface/interface_shared.py +26 -37
  57. wandb/sdk/interface/interface_sock.py +24 -14
  58. wandb/sdk/internal/settings_static.py +2 -3
  59. wandb/sdk/launch/create_job.py +12 -1
  60. wandb/sdk/launch/runner/kubernetes_runner.py +24 -29
  61. wandb/sdk/lib/asyncio_compat.py +16 -16
  62. wandb/sdk/lib/asyncio_manager.py +252 -0
  63. wandb/sdk/lib/hashutil.py +13 -4
  64. wandb/sdk/lib/printer.py +2 -2
  65. wandb/sdk/lib/printer_asyncio.py +3 -1
  66. wandb/sdk/lib/retry.py +185 -78
  67. wandb/sdk/lib/service/service_client.py +106 -0
  68. wandb/sdk/lib/service/service_connection.py +20 -26
  69. wandb/sdk/lib/service/service_token.py +30 -13
  70. wandb/sdk/mailbox/mailbox.py +13 -5
  71. wandb/sdk/mailbox/mailbox_handle.py +22 -13
  72. wandb/sdk/mailbox/response_handle.py +42 -106
  73. wandb/sdk/mailbox/wait_with_progress.py +7 -42
  74. wandb/sdk/wandb_init.py +11 -25
  75. wandb/sdk/wandb_login.py +1 -1
  76. wandb/sdk/wandb_run.py +91 -55
  77. wandb/sdk/wandb_settings.py +45 -32
  78. wandb/sdk/wandb_setup.py +176 -96
  79. wandb/util.py +1 -1
  80. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/METADATA +1 -1
  81. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/RECORD +84 -68
  82. wandb/sdk/interface/interface_relay.py +0 -38
  83. wandb/sdk/interface/router.py +0 -89
  84. wandb/sdk/interface/router_queue.py +0 -43
  85. wandb/sdk/interface/router_relay.py +0 -50
  86. wandb/sdk/interface/router_sock.py +0 -32
  87. wandb/sdk/lib/sock_client.py +0 -232
  88. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/WHEEL +0 -0
  89. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/lib/retry.py CHANGED
@@ -94,109 +94,216 @@ class Retry(Generic[_R]):
94
94
  """The number of iterations the previous __call__ retried."""
95
95
  return self._num_iter
96
96
 
97
- def __call__(self, *args: Any, **kwargs: Any) -> _R:
97
+ def __call__(
98
+ self,
99
+ *args: Any,
100
+ num_retries: Optional[int] = None,
101
+ retry_timedelta: Optional[datetime.timedelta] = None,
102
+ retry_sleep_base: Optional[float] = None,
103
+ retry_cancel_event: Optional[threading.Event] = None,
104
+ check_retry_fn: Optional[CheckRetryFnType] = None,
105
+ **kwargs: Any,
106
+ ) -> _R:
98
107
  """Call the wrapped function, with retries.
99
108
 
100
109
  Args:
101
- retry_timedelta (kwarg): amount of time to retry before giving up.
102
- sleep_base (kwarg): amount of time to sleep upon first failure, all other sleeps
103
- are derived from this one.
110
+ num_retries: The number of retries after which to give up.
111
+ retry_timedelta: An amount of time after which to give up.
112
+ retry_sleep_base: Number of seconds to sleep for the first retry.
113
+ This is used as the base for exponential backoff.
114
+ retry_cancel_event: An event that causes this to raise
115
+ a RetryCancelledException on the next attempted retry.
116
+ check_retry_fn: A custom check for deciding whether an exception
117
+ should be retried. Retrying is prevented if this returns a falsy
118
+ value, even if more retries are left. This may also return a
119
+ timedelta that represents a shorter timeout: retrying is
120
+ prevented if the value is less than the amount of time that has
121
+ passed since the last timedelta was returned.
104
122
  """
105
- retry_timedelta = kwargs.pop("retry_timedelta", self._retry_timedelta)
106
- if retry_timedelta is None:
107
- retry_timedelta = datetime.timedelta(days=365)
108
-
109
- retry_cancel_event = kwargs.pop("retry_cancel_event", self._retry_cancel_event)
110
-
111
- num_retries = kwargs.pop("num_retries", self._num_retries)
112
- if num_retries is None:
113
- num_retries = 1000000
114
-
115
123
  if os.environ.get("WANDB_TEST"):
116
- num_retries = 0
124
+ max_retries = 0
125
+ elif num_retries is not None:
126
+ max_retries = num_retries
127
+ elif self._num_retries is not None:
128
+ max_retries = self._num_retries
129
+ else:
130
+ max_retries = 1000000
117
131
 
118
- sleep_base: float = kwargs.pop("retry_sleep_base", 1)
132
+ if retry_timedelta is not None:
133
+ timeout = retry_timedelta
134
+ elif self._retry_timedelta is not None:
135
+ timeout = self._retry_timedelta
136
+ else:
137
+ timeout = datetime.timedelta(days=365)
119
138
 
120
- # an extra function to allow performing more logic on the filtered exception
121
- check_retry_fn: CheckRetryFnType = kwargs.pop(
122
- "check_retry_fn", self._check_retry_fn
139
+ if retry_sleep_base is not None:
140
+ initial_sleep = retry_sleep_base
141
+ else:
142
+ initial_sleep = 1
143
+
144
+ retry_loop = _RetryLoop(
145
+ max_retries=max_retries,
146
+ timeout=timeout,
147
+ initial_sleep=initial_sleep,
148
+ max_sleep=self.MAX_SLEEP_SECONDS,
149
+ cancel_event=retry_cancel_event or self._retry_cancel_event,
150
+ retry_check=check_retry_fn or self._check_retry_fn,
123
151
  )
124
152
 
125
- sleep = sleep_base
126
- now = NOW_FN()
127
- start_time = now
128
- start_time_triggered = None
129
-
153
+ start_time = NOW_FN()
130
154
  self._num_iter = 0
131
155
 
132
156
  while True:
133
157
  try:
134
158
  result = self._call_fn(*args, **kwargs)
135
- # Only print resolved attempts once every minute
136
- if self._num_iter > 2 and now - self._last_print > datetime.timedelta(
137
- minutes=1
138
- ):
139
- self._last_print = NOW_FN()
140
- if self.retry_callback:
141
- self.retry_callback(
142
- 200,
143
- f"{self._error_prefix} resolved after {NOW_FN() - start_time}, resuming normal operation.",
144
- )
145
- return result
159
+
146
160
  except self._retryable_exceptions as e:
147
- # if the secondary check fails, re-raise
148
- retry_timedelta_triggered = check_retry_fn(e)
149
- if not retry_timedelta_triggered:
161
+ if not retry_loop.should_retry(e):
150
162
  raise
151
163
 
152
- # always enforce num_retries no matter which type of exception was seen
153
- if self._num_iter >= num_retries:
154
- raise
164
+ if self._num_iter == 2:
165
+ logger.info("Retry attempt failed:", exc_info=e)
166
+ self._print_entered_retry_loop(e)
155
167
 
156
- now = NOW_FN()
168
+ retry_loop.wait_before_retry()
169
+ self._num_iter += 1
157
170
 
158
- # handle a triggered secondary check which could have a shortened timeout
159
- if isinstance(retry_timedelta_triggered, datetime.timedelta):
160
- # save the time of the first secondary trigger
161
- if not start_time_triggered:
162
- start_time_triggered = now
171
+ else:
172
+ if self._num_iter > 2:
173
+ self._print_recovered(start_time)
163
174
 
164
- # make sure that we haven't run out of time from secondary trigger
165
- if now - start_time_triggered >= retry_timedelta_triggered:
166
- raise
175
+ return result
167
176
 
168
- # always enforce the default timeout from start of retries
169
- if now - start_time >= retry_timedelta:
170
- raise
177
+ def _print_entered_retry_loop(self, exception: Exception) -> None:
178
+ """Emit a message saying we've begun retrying.
171
179
 
172
- if self._num_iter == 2:
173
- logger.info("Retry attempt failed:", exc_info=e)
174
- if (
175
- isinstance(e, HTTPError)
176
- and e.response is not None
177
- and self.retry_callback is not None
178
- ):
179
- self.retry_callback(e.response.status_code, e.response.text)
180
- else:
181
- # todo: would like to catch other errors, eg wandb.errors.Error, ConnectionError etc
182
- # but some of these can be raised before the retry handler thread (RunStatusChecker) is
183
- # spawned in wandb_init
184
- wandb.termlog(
185
- f"{self._error_prefix} ({e.__class__.__name__}), entering retry loop."
186
- )
187
- # if wandb.env.is_debug():
188
- # traceback.print_exc()
189
- cancelled = self._sleep_check_cancelled(
190
- sleep + random.random() * 0.25 * sleep, cancel_event=retry_cancel_event
180
+ Either calls the retry callback or prints a warning to console.
181
+
182
+ Args:
183
+ exception: The most recent exception we will retry.
184
+ """
185
+ if (
186
+ isinstance(exception, HTTPError)
187
+ and exception.response is not None
188
+ and self.retry_callback is not None
189
+ ):
190
+ self.retry_callback(
191
+ exception.response.status_code,
192
+ exception.response.text,
191
193
  )
194
+ else:
195
+ wandb.termlog(
196
+ f"{self._error_prefix}"
197
+ f" ({exception.__class__.__name__}), entering retry loop."
198
+ )
199
+
200
+ def _print_recovered(self, start_time: datetime.datetime) -> None:
201
+ """Emit a message saying we've recovered after retrying.
202
+
203
+ Args:
204
+ start_time: When we started retrying.
205
+ """
206
+ if not self.retry_callback:
207
+ return
208
+
209
+ now = NOW_FN()
210
+ if now - self._last_print < datetime.timedelta(minutes=1):
211
+ return
212
+ self._last_print = now
213
+
214
+ time_to_recover = now - start_time
215
+ self.retry_callback(
216
+ 200,
217
+ (
218
+ f"{self._error_prefix} resolved after"
219
+ f" {time_to_recover}, resuming normal operation."
220
+ ),
221
+ )
222
+
223
+
224
+ class _RetryLoop:
225
+ """An invocation of a Retry instance."""
226
+
227
+ def __init__(
228
+ self,
229
+ *,
230
+ max_retries: int,
231
+ timeout: datetime.timedelta,
232
+ initial_sleep: float,
233
+ max_sleep: float,
234
+ cancel_event: Optional[threading.Event],
235
+ retry_check: CheckRetryFnType,
236
+ ) -> None:
237
+ """Start a new call of a Retry instance.
238
+
239
+ Args:
240
+ max_retries: The number of retries after which to give up.
241
+ timeout: An amount of time after which to give up.
242
+ initial_sleep: Number of seconds to sleep for the first retry.
243
+ This is used as the base for exponential backoff.
244
+ max_sleep: Maximum number of seconds to sleep between retries.
245
+ cancel_event: An event that's set when the function is cancelled.
246
+ retry_check: A custom check for deciding whether an exception should
247
+ be retried. Retrying is prevented if this returns a falsy value,
248
+ even if more retries are left. This may also return a timedelta
249
+ that represents a shorter timeout: retrying is prevented if the
250
+ value is less than the amount of time that has passed since the
251
+ last timedelta was returned.
252
+ """
253
+ self._max_retries = max_retries
254
+ self._total_retries = 0
255
+
256
+ self._timeout = timeout
257
+ self._start_time = NOW_FN()
258
+
259
+ self._next_sleep_time = initial_sleep
260
+ self._max_sleep = max_sleep
261
+ self._cancel_event = cancel_event
262
+
263
+ self._retry_check = retry_check
264
+ self._last_custom_timeout: Optional[datetime.datetime] = None
265
+
266
+ def should_retry(self, exception: Exception) -> bool:
267
+ """Returns whether an exception should be retried."""
268
+ if self._total_retries >= self._max_retries:
269
+ return False
270
+ self._total_retries += 1
271
+
272
+ now = NOW_FN()
273
+ if now - self._start_time >= self._timeout:
274
+ return False
275
+
276
+ retry_check_result = self._retry_check(exception)
277
+ if not retry_check_result:
278
+ return False
279
+
280
+ if isinstance(retry_check_result, datetime.timedelta):
281
+ if not self._last_custom_timeout:
282
+ self._last_custom_timeout = now
283
+
284
+ if now - self._last_custom_timeout >= retry_check_result:
285
+ return False
286
+
287
+ return True
288
+
289
+ def wait_before_retry(self) -> None:
290
+ """Block until the next retry should happen.
291
+
292
+ Raises:
293
+ RetryCancelledError: If the operation is cancelled.
294
+ """
295
+ sleep_amount = self._next_sleep_time * (1 + random.random() * 0.25)
296
+
297
+ if self._cancel_event:
298
+ cancelled = self._cancel_event.wait(sleep_amount)
192
299
  if cancelled:
193
- raise RetryCancelledError("retry timeout")
194
- sleep *= 2
195
- if sleep > self.MAX_SLEEP_SECONDS:
196
- sleep = self.MAX_SLEEP_SECONDS
197
- now = NOW_FN()
300
+ raise RetryCancelledError("Cancelled while retrying.")
301
+ else:
302
+ SLEEP_FN(sleep_amount)
198
303
 
199
- self._num_iter += 1
304
+ self._next_sleep_time *= 2
305
+ if self._next_sleep_time > self._max_sleep:
306
+ self._next_sleep_time = self._max_sleep
200
307
 
201
308
 
202
309
  _F = TypeVar("_F", bound=Callable)
@@ -0,0 +1,106 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import struct
6
+
7
+ from wandb.proto import wandb_server_pb2 as spb
8
+ from wandb.sdk.lib import asyncio_manager
9
+ from wandb.sdk.mailbox.mailbox import Mailbox
10
+ from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
11
+
12
+ _logger = logging.getLogger(__name__)
13
+
14
+ _HEADER_BYTE_INT_LEN = 5
15
+ _HEADER_BYTE_INT_FMT = "<BI"
16
+
17
+
18
+ class ServiceClient:
19
+ """Implements socket communication with the internal service."""
20
+
21
+ def __init__(
22
+ self,
23
+ asyncer: asyncio_manager.AsyncioManager,
24
+ reader: asyncio.StreamReader,
25
+ writer: asyncio.StreamWriter,
26
+ ) -> None:
27
+ self._asyncer = asyncer
28
+ self._reader = reader
29
+ self._writer = writer
30
+ self._mailbox = Mailbox(asyncer)
31
+ asyncer.run_soon(
32
+ self._forward_responses,
33
+ daemon=True,
34
+ name="ServiceClient._forward_responses",
35
+ )
36
+
37
+ def publish(self, request: spb.ServerRequest) -> None:
38
+ """Send a request without waiting for a response."""
39
+ self._asyncer.run_soon(lambda: self._send_server_request(request))
40
+
41
+ def deliver(
42
+ self,
43
+ request: spb.ServerRequest,
44
+ ) -> MailboxHandle[spb.ServerResponse]:
45
+ """Send a request and return a handle to wait for a response.
46
+
47
+ NOTE: This may mutate the request. The request should not be used
48
+ after.
49
+
50
+ Raises:
51
+ MailboxClosedError: If used after the client is closed or has
52
+ stopped due to an error.
53
+ """
54
+ handle = self._mailbox.require_response(request)
55
+ self._asyncer.run_soon(lambda: self._send_server_request(request))
56
+ return handle
57
+
58
+ async def _send_server_request(self, request: spb.ServerRequest) -> None:
59
+ header = struct.pack(_HEADER_BYTE_INT_FMT, ord("W"), request.ByteSize())
60
+ self._writer.write(header)
61
+
62
+ data = request.SerializeToString()
63
+ self._writer.write(data)
64
+
65
+ await self._writer.drain()
66
+
67
+ def close(self) -> None:
68
+ """Flush and close the socket."""
69
+ self._asyncer.run_soon(self._close)
70
+
71
+ async def _close(self) -> None:
72
+ self._writer.close()
73
+ await self._writer.wait_closed()
74
+
75
+ async def _forward_responses(self) -> None:
76
+ try:
77
+ while response := await self._read_server_response():
78
+ await self._mailbox.deliver(response)
79
+
80
+ except Exception:
81
+ _logger.exception("Error reading server response.")
82
+
83
+ else:
84
+ _logger.info("Reached EOF.")
85
+
86
+ finally:
87
+ self._mailbox.close()
88
+
89
+ async def _read_server_response(self) -> spb.ServerResponse | None:
90
+ try:
91
+ header = await self._reader.readexactly(_HEADER_BYTE_INT_LEN)
92
+ except asyncio.IncompleteReadError as e:
93
+ if e.partial:
94
+ raise
95
+ else:
96
+ return None
97
+
98
+ magic, length = struct.unpack(_HEADER_BYTE_INT_FMT, header)
99
+
100
+ if magic != ord("W"):
101
+ raise ValueError(f"Bad header: {header.hex()}")
102
+
103
+ data = await self._reader.readexactly(length)
104
+ response = spb.ServerResponse()
105
+ response.ParseFromString(data)
106
+ return response
@@ -3,16 +3,15 @@ from __future__ import annotations
3
3
  import atexit
4
4
  from typing import Callable
5
5
 
6
- from wandb.proto import wandb_internal_pb2 as pb
7
6
  from wandb.proto import wandb_server_pb2 as spb
8
7
  from wandb.proto import wandb_settings_pb2
9
8
  from wandb.sdk import wandb_settings
10
9
  from wandb.sdk.interface.interface import InterfaceBase
11
10
  from wandb.sdk.interface.interface_sock import InterfaceSock
12
- from wandb.sdk.interface.router_sock import MessageSockRouter
11
+ from wandb.sdk.lib import asyncio_manager
13
12
  from wandb.sdk.lib.exit_hooks import ExitHooks
14
- from wandb.sdk.lib.sock_client import SockClient, SockClientClosedError
15
- from wandb.sdk.mailbox import HandleAbandonedError, Mailbox, MailboxClosedError
13
+ from wandb.sdk.lib.service.service_client import ServiceClient
14
+ from wandb.sdk.mailbox import HandleAbandonedError, MailboxClosedError
16
15
 
17
16
  from . import service_process, service_token
18
17
 
@@ -22,18 +21,23 @@ class WandbAttachFailedError(Exception):
22
21
 
23
22
 
24
23
  def connect_to_service(
24
+ asyncer: asyncio_manager.AsyncioManager,
25
25
  settings: wandb_settings.Settings,
26
26
  ) -> ServiceConnection:
27
27
  """Connect to the service process, starting one up if necessary."""
28
28
  token = service_token.from_env()
29
29
 
30
30
  if token:
31
- return ServiceConnection(client=token.connect(), proc=None)
31
+ return ServiceConnection(
32
+ client=token.connect(asyncer=asyncer),
33
+ proc=None,
34
+ )
32
35
  else:
33
- return _start_and_connect_service(settings)
36
+ return _start_and_connect_service(asyncer, settings)
34
37
 
35
38
 
36
39
  def _start_and_connect_service(
40
+ asyncer: asyncio_manager.AsyncioManager,
37
41
  settings: wandb_settings.Settings,
38
42
  ) -> ServiceConnection:
39
43
  """Start a service process and returns a connection to it.
@@ -44,7 +48,7 @@ def _start_and_connect_service(
44
48
  """
45
49
  proc = service_process.start(settings)
46
50
 
47
- client = proc.token.connect()
51
+ client = proc.token.connect(asyncer=asyncer)
48
52
  proc.token.save_to_env()
49
53
 
50
54
  hooks = ExitHooks()
@@ -69,7 +73,7 @@ class ServiceConnection:
69
73
 
70
74
  def __init__(
71
75
  self,
72
- client: SockClient,
76
+ client: ServiceClient,
73
77
  proc: service_process.ServiceProcess | None,
74
78
  cleanup: Callable[[], None] | None = None,
75
79
  ):
@@ -88,16 +92,9 @@ class ServiceConnection:
88
92
  self._torn_down = False
89
93
  self._cleanup = cleanup
90
94
 
91
- self._mailbox = Mailbox()
92
- self._router = MessageSockRouter(self._client, self._mailbox)
93
-
94
95
  def make_interface(self, stream_id: str) -> InterfaceBase:
95
96
  """Returns an interface for communicating with the service."""
96
- return InterfaceSock(self._client, self._mailbox, stream_id=stream_id)
97
-
98
- def send_record(self, record: pb.Record) -> None:
99
- """Send data to the service."""
100
- self._client.send_record_publish(record)
97
+ return InterfaceSock(self._client, stream_id=stream_id)
101
98
 
102
99
  def inform_init(
103
100
  self,
@@ -108,13 +105,13 @@ class ServiceConnection:
108
105
  request = spb.ServerInformInitRequest()
109
106
  request.settings.CopyFrom(settings)
110
107
  request._info.stream_id = run_id
111
- self._client.send_server_request(spb.ServerRequest(inform_init=request))
108
+ self._client.publish(spb.ServerRequest(inform_init=request))
112
109
 
113
110
  def inform_finish(self, run_id: str) -> None:
114
111
  """Send an finish request to the service."""
115
112
  request = spb.ServerInformFinishRequest()
116
113
  request._info.stream_id = run_id
117
- self._client.send_server_request(spb.ServerRequest(inform_finish=request))
114
+ self._client.publish(spb.ServerRequest(inform_finish=request))
118
115
 
119
116
  def inform_attach(
120
117
  self,
@@ -128,11 +125,10 @@ class ServiceConnection:
128
125
  request.inform_attach._info.stream_id = attach_id
129
126
 
130
127
  try:
131
- handle = self._mailbox.require_response(request)
132
- self._client.send_server_request(request)
128
+ handle = self._client.deliver(request)
133
129
  response = handle.wait_or(timeout=10)
134
130
 
135
- except (MailboxClosedError, HandleAbandonedError, SockClientClosedError):
131
+ except (MailboxClosedError, HandleAbandonedError):
136
132
  raise WandbAttachFailedError(
137
133
  "Failed to attach: the service process is not running.",
138
134
  ) from None
@@ -156,7 +152,7 @@ class ServiceConnection:
156
152
  request = spb.ServerInformStartRequest()
157
153
  request.settings.CopyFrom(settings)
158
154
  request._info.stream_id = run_id
159
- self._client.send_server_request(spb.ServerRequest(inform_start=request))
155
+ self._client.publish(spb.ServerRequest(inform_start=request))
160
156
 
161
157
  def teardown(self, exit_code: int) -> int | None:
162
158
  """Close the connection.
@@ -178,21 +174,19 @@ class ServiceConnection:
178
174
  if self._cleanup:
179
175
  self._cleanup()
180
176
 
181
- # Stop reading responses on the socket.
182
- self._router.join()
183
-
184
177
  if not self._proc:
185
178
  return None
186
179
 
187
180
  # Clear the service token to prevent new connections to the process.
188
181
  service_token.clear_service_in_env()
189
182
 
190
- self._client.send_server_request(
183
+ self._client.publish(
191
184
  spb.ServerRequest(
192
185
  inform_teardown=spb.ServerInformTeardownRequest(
193
186
  exit_code=exit_code,
194
187
  )
195
188
  ),
196
189
  )
190
+ self._client.close()
197
191
 
198
192
  return self._proc.join()
@@ -1,15 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import abc
4
+ import asyncio
4
5
  import os
5
6
  import re
6
- import socket
7
7
 
8
8
  from typing_extensions import final, override
9
9
 
10
10
  from wandb import env
11
+ from wandb.sdk.lib import asyncio_manager
11
12
  from wandb.sdk.lib.service import ipc_support
12
- from wandb.sdk.lib.sock_client import SockClient
13
+
14
+ from .service_client import ServiceClient
13
15
 
14
16
  _CURRENT_VERSION = "3"
15
17
 
@@ -53,9 +55,16 @@ class ServiceToken(abc.ABC):
53
55
  """A way of connecting to a running service process."""
54
56
 
55
57
  @abc.abstractmethod
56
- def connect(self) -> SockClient:
58
+ def connect(
59
+ self,
60
+ *,
61
+ asyncer: asyncio_manager.AsyncioManager,
62
+ ) -> ServiceClient:
57
63
  """Connect to the service process.
58
64
 
65
+ Args:
66
+ asyncer: A started AsyncioManager for asyncio operations.
67
+
59
68
  Returns:
60
69
  A socket object for communicating with the service.
61
70
 
@@ -81,21 +90,25 @@ class UnixServiceToken(ServiceToken):
81
90
  self._path = path
82
91
 
83
92
  @override
84
- def connect(self) -> SockClient:
93
+ def connect(
94
+ self,
95
+ *,
96
+ asyncer: asyncio_manager.AsyncioManager,
97
+ ) -> ServiceClient:
85
98
  if not ipc_support.SUPPORTS_UNIX:
86
99
  raise WandbServiceConnectionError("AF_UNIX socket not supported")
87
100
 
88
- sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
89
-
90
101
  try:
91
102
  # TODO: This may block indefinitely if the service is unhealthy.
92
- sock.connect(self._path)
103
+ reader, writer = asyncer.run(
104
+ lambda: asyncio.open_unix_connection(self._path),
105
+ )
93
106
  except Exception as e:
94
107
  raise WandbServiceConnectionError(
95
108
  f"Failed to connect to service on socket {self._path}",
96
109
  ) from e
97
110
 
98
- return SockClient(sock)
111
+ return ServiceClient(asyncer, reader, writer)
99
112
 
100
113
  @override
101
114
  def _as_env_string(self):
@@ -128,18 +141,22 @@ class TCPServiceToken(ServiceToken):
128
141
  self._port = port
129
142
 
130
143
  @override
131
- def connect(self) -> SockClient:
132
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
133
-
144
+ def connect(
145
+ self,
146
+ *,
147
+ asyncer: asyncio_manager.AsyncioManager,
148
+ ) -> ServiceClient:
134
149
  try:
135
150
  # TODO: This may block indefinitely if the service is unhealthy.
136
- sock.connect(("localhost", self._port))
151
+ reader, writer = asyncer.run(
152
+ lambda: asyncio.open_connection("localhost", self._port),
153
+ )
137
154
  except Exception as e:
138
155
  raise WandbServiceConnectionError(
139
156
  f"Failed to connect to service on port {self._port}",
140
157
  ) from e
141
158
 
142
- return SockClient(sock)
159
+ return ServiceClient(asyncer, reader, writer)
143
160
 
144
161
  @override
145
162
  def _as_env_string(self):