wandb 0.21.1__py3-none-musllinux_1_2_aarch64.whl → 0.21.3__py3-none-musllinux_1_2_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +1 -1
- wandb/apis/public/api.py +1 -2
- wandb/apis/public/artifacts.py +3 -5
- wandb/apis/public/registries/_utils.py +14 -16
- wandb/apis/public/registries/registries_search.py +176 -289
- wandb/apis/public/reports.py +13 -10
- wandb/automations/_generated/delete_automation.py +1 -3
- wandb/automations/_generated/enums.py +13 -11
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +47 -2
- wandb/integration/metaflow/data_pandas.py +2 -2
- wandb/integration/metaflow/data_pytorch.py +75 -0
- wandb/integration/metaflow/data_sklearn.py +76 -0
- wandb/integration/metaflow/metaflow.py +16 -87
- wandb/integration/weave/__init__.py +6 -0
- wandb/integration/weave/interface.py +49 -0
- wandb/integration/weave/weave.py +63 -0
- wandb/proto/v3/wandb_internal_pb2.py +3 -2
- wandb/proto/v4/wandb_internal_pb2.py +2 -2
- wandb/proto/v5/wandb_internal_pb2.py +2 -2
- wandb/proto/v6/wandb_internal_pb2.py +2 -2
- wandb/sdk/artifacts/_factories.py +17 -0
- wandb/sdk/artifacts/_generated/__init__.py +221 -13
- wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
- wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
- wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
- wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
- wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
- wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
- wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
- wandb/sdk/artifacts/_generated/enums.py +5 -0
- wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
- wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
- wandb/sdk/artifacts/_generated/fragments.py +279 -41
- wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
- wandb/sdk/artifacts/_generated/operations.py +654 -51
- wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
- wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
- wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
- wandb/sdk/artifacts/_graphql_fragments.py +3 -86
- wandb/sdk/artifacts/_validators.py +6 -4
- wandb/sdk/artifacts/artifact.py +410 -547
- wandb/sdk/artifacts/artifact_file_cache.py +11 -7
- wandb/sdk/artifacts/artifact_manifest.py +10 -9
- wandb/sdk/artifacts/artifact_manifest_entry.py +15 -18
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
- wandb/sdk/data_types/video.py +2 -2
- wandb/sdk/interface/interface_queue.py +1 -4
- wandb/sdk/interface/interface_shared.py +26 -37
- wandb/sdk/interface/interface_sock.py +24 -14
- wandb/sdk/internal/settings_static.py +2 -3
- wandb/sdk/launch/create_job.py +12 -1
- wandb/sdk/launch/inputs/internal.py +25 -24
- wandb/sdk/launch/inputs/schema.py +31 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +24 -29
- wandb/sdk/lib/asyncio_compat.py +16 -16
- wandb/sdk/lib/asyncio_manager.py +252 -0
- wandb/sdk/lib/hashutil.py +13 -4
- wandb/sdk/lib/paths.py +23 -21
- wandb/sdk/lib/printer.py +2 -2
- wandb/sdk/lib/printer_asyncio.py +3 -1
- wandb/sdk/lib/retry.py +185 -78
- wandb/sdk/lib/service/service_client.py +106 -0
- wandb/sdk/lib/service/service_connection.py +20 -26
- wandb/sdk/lib/service/service_token.py +30 -13
- wandb/sdk/mailbox/mailbox.py +13 -5
- wandb/sdk/mailbox/mailbox_handle.py +22 -13
- wandb/sdk/mailbox/response_handle.py +42 -106
- wandb/sdk/mailbox/wait_with_progress.py +7 -42
- wandb/sdk/wandb_init.py +11 -25
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_run.py +92 -56
- wandb/sdk/wandb_settings.py +45 -32
- wandb/sdk/wandb_setup.py +176 -96
- wandb/util.py +1 -1
- {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/METADATA +2 -2
- {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/RECORD +88 -72
- wandb/sdk/interface/interface_relay.py +0 -38
- wandb/sdk/interface/router.py +0 -89
- wandb/sdk/interface/router_queue.py +0 -43
- wandb/sdk/interface/router_relay.py +0 -50
- wandb/sdk/interface/router_sock.py +0 -32
- wandb/sdk/lib/sock_client.py +0 -232
- {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/WHEEL +0 -0
- {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/entry_points.txt +0 -0
- {wandb-0.21.1.dist-info → wandb-0.21.3.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/lib/printer_asyncio.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import asyncio
|
2
2
|
from typing import Callable, TypeVar
|
3
3
|
|
4
|
+
from wandb.sdk import wandb_setup
|
4
5
|
from wandb.sdk.lib import asyncio_compat, printer
|
5
6
|
|
6
7
|
_T = TypeVar("_T")
|
@@ -43,4 +44,5 @@ def run_async_with_spinner(
|
|
43
44
|
func_running.set()
|
44
45
|
return res
|
45
46
|
|
46
|
-
|
47
|
+
asyncer = wandb_setup.singleton().asyncer
|
48
|
+
return asyncer.run(_loop_run_with_spinner)
|
wandb/sdk/lib/retry.py
CHANGED
@@ -94,109 +94,216 @@ class Retry(Generic[_R]):
|
|
94
94
|
"""The number of iterations the previous __call__ retried."""
|
95
95
|
return self._num_iter
|
96
96
|
|
97
|
-
def __call__(
|
97
|
+
def __call__(
|
98
|
+
self,
|
99
|
+
*args: Any,
|
100
|
+
num_retries: Optional[int] = None,
|
101
|
+
retry_timedelta: Optional[datetime.timedelta] = None,
|
102
|
+
retry_sleep_base: Optional[float] = None,
|
103
|
+
retry_cancel_event: Optional[threading.Event] = None,
|
104
|
+
check_retry_fn: Optional[CheckRetryFnType] = None,
|
105
|
+
**kwargs: Any,
|
106
|
+
) -> _R:
|
98
107
|
"""Call the wrapped function, with retries.
|
99
108
|
|
100
109
|
Args:
|
101
|
-
|
102
|
-
|
103
|
-
|
110
|
+
num_retries: The number of retries after which to give up.
|
111
|
+
retry_timedelta: An amount of time after which to give up.
|
112
|
+
retry_sleep_base: Number of seconds to sleep for the first retry.
|
113
|
+
This is used as the base for exponential backoff.
|
114
|
+
retry_cancel_event: An event that causes this to raise
|
115
|
+
a RetryCancelledException on the next attempted retry.
|
116
|
+
check_retry_fn: A custom check for deciding whether an exception
|
117
|
+
should be retried. Retrying is prevented if this returns a falsy
|
118
|
+
value, even if more retries are left. This may also return a
|
119
|
+
timedelta that represents a shorter timeout: retrying is
|
120
|
+
prevented if the value is less than the amount of time that has
|
121
|
+
passed since the last timedelta was returned.
|
104
122
|
"""
|
105
|
-
retry_timedelta = kwargs.pop("retry_timedelta", self._retry_timedelta)
|
106
|
-
if retry_timedelta is None:
|
107
|
-
retry_timedelta = datetime.timedelta(days=365)
|
108
|
-
|
109
|
-
retry_cancel_event = kwargs.pop("retry_cancel_event", self._retry_cancel_event)
|
110
|
-
|
111
|
-
num_retries = kwargs.pop("num_retries", self._num_retries)
|
112
|
-
if num_retries is None:
|
113
|
-
num_retries = 1000000
|
114
|
-
|
115
123
|
if os.environ.get("WANDB_TEST"):
|
116
|
-
|
124
|
+
max_retries = 0
|
125
|
+
elif num_retries is not None:
|
126
|
+
max_retries = num_retries
|
127
|
+
elif self._num_retries is not None:
|
128
|
+
max_retries = self._num_retries
|
129
|
+
else:
|
130
|
+
max_retries = 1000000
|
117
131
|
|
118
|
-
|
132
|
+
if retry_timedelta is not None:
|
133
|
+
timeout = retry_timedelta
|
134
|
+
elif self._retry_timedelta is not None:
|
135
|
+
timeout = self._retry_timedelta
|
136
|
+
else:
|
137
|
+
timeout = datetime.timedelta(days=365)
|
119
138
|
|
120
|
-
|
121
|
-
|
122
|
-
|
139
|
+
if retry_sleep_base is not None:
|
140
|
+
initial_sleep = retry_sleep_base
|
141
|
+
else:
|
142
|
+
initial_sleep = 1
|
143
|
+
|
144
|
+
retry_loop = _RetryLoop(
|
145
|
+
max_retries=max_retries,
|
146
|
+
timeout=timeout,
|
147
|
+
initial_sleep=initial_sleep,
|
148
|
+
max_sleep=self.MAX_SLEEP_SECONDS,
|
149
|
+
cancel_event=retry_cancel_event or self._retry_cancel_event,
|
150
|
+
retry_check=check_retry_fn or self._check_retry_fn,
|
123
151
|
)
|
124
152
|
|
125
|
-
|
126
|
-
now = NOW_FN()
|
127
|
-
start_time = now
|
128
|
-
start_time_triggered = None
|
129
|
-
|
153
|
+
start_time = NOW_FN()
|
130
154
|
self._num_iter = 0
|
131
155
|
|
132
156
|
while True:
|
133
157
|
try:
|
134
158
|
result = self._call_fn(*args, **kwargs)
|
135
|
-
|
136
|
-
if self._num_iter > 2 and now - self._last_print > datetime.timedelta(
|
137
|
-
minutes=1
|
138
|
-
):
|
139
|
-
self._last_print = NOW_FN()
|
140
|
-
if self.retry_callback:
|
141
|
-
self.retry_callback(
|
142
|
-
200,
|
143
|
-
f"{self._error_prefix} resolved after {NOW_FN() - start_time}, resuming normal operation.",
|
144
|
-
)
|
145
|
-
return result
|
159
|
+
|
146
160
|
except self._retryable_exceptions as e:
|
147
|
-
|
148
|
-
retry_timedelta_triggered = check_retry_fn(e)
|
149
|
-
if not retry_timedelta_triggered:
|
161
|
+
if not retry_loop.should_retry(e):
|
150
162
|
raise
|
151
163
|
|
152
|
-
|
153
|
-
|
154
|
-
|
164
|
+
if self._num_iter == 2:
|
165
|
+
logger.info("Retry attempt failed:", exc_info=e)
|
166
|
+
self._print_entered_retry_loop(e)
|
155
167
|
|
156
|
-
|
168
|
+
retry_loop.wait_before_retry()
|
169
|
+
self._num_iter += 1
|
157
170
|
|
158
|
-
|
159
|
-
if
|
160
|
-
|
161
|
-
if not start_time_triggered:
|
162
|
-
start_time_triggered = now
|
171
|
+
else:
|
172
|
+
if self._num_iter > 2:
|
173
|
+
self._print_recovered(start_time)
|
163
174
|
|
164
|
-
|
165
|
-
if now - start_time_triggered >= retry_timedelta_triggered:
|
166
|
-
raise
|
175
|
+
return result
|
167
176
|
|
168
|
-
|
169
|
-
|
170
|
-
raise
|
177
|
+
def _print_entered_retry_loop(self, exception: Exception) -> None:
|
178
|
+
"""Emit a message saying we've begun retrying.
|
171
179
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
f"{self._error_prefix} ({e.__class__.__name__}), entering retry loop."
|
186
|
-
)
|
187
|
-
# if wandb.env.is_debug():
|
188
|
-
# traceback.print_exc()
|
189
|
-
cancelled = self._sleep_check_cancelled(
|
190
|
-
sleep + random.random() * 0.25 * sleep, cancel_event=retry_cancel_event
|
180
|
+
Either calls the retry callback or prints a warning to console.
|
181
|
+
|
182
|
+
Args:
|
183
|
+
exception: The most recent exception we will retry.
|
184
|
+
"""
|
185
|
+
if (
|
186
|
+
isinstance(exception, HTTPError)
|
187
|
+
and exception.response is not None
|
188
|
+
and self.retry_callback is not None
|
189
|
+
):
|
190
|
+
self.retry_callback(
|
191
|
+
exception.response.status_code,
|
192
|
+
exception.response.text,
|
191
193
|
)
|
194
|
+
else:
|
195
|
+
wandb.termlog(
|
196
|
+
f"{self._error_prefix}"
|
197
|
+
f" ({exception.__class__.__name__}), entering retry loop."
|
198
|
+
)
|
199
|
+
|
200
|
+
def _print_recovered(self, start_time: datetime.datetime) -> None:
|
201
|
+
"""Emit a message saying we've recovered after retrying.
|
202
|
+
|
203
|
+
Args:
|
204
|
+
start_time: When we started retrying.
|
205
|
+
"""
|
206
|
+
if not self.retry_callback:
|
207
|
+
return
|
208
|
+
|
209
|
+
now = NOW_FN()
|
210
|
+
if now - self._last_print < datetime.timedelta(minutes=1):
|
211
|
+
return
|
212
|
+
self._last_print = now
|
213
|
+
|
214
|
+
time_to_recover = now - start_time
|
215
|
+
self.retry_callback(
|
216
|
+
200,
|
217
|
+
(
|
218
|
+
f"{self._error_prefix} resolved after"
|
219
|
+
f" {time_to_recover}, resuming normal operation."
|
220
|
+
),
|
221
|
+
)
|
222
|
+
|
223
|
+
|
224
|
+
class _RetryLoop:
|
225
|
+
"""An invocation of a Retry instance."""
|
226
|
+
|
227
|
+
def __init__(
|
228
|
+
self,
|
229
|
+
*,
|
230
|
+
max_retries: int,
|
231
|
+
timeout: datetime.timedelta,
|
232
|
+
initial_sleep: float,
|
233
|
+
max_sleep: float,
|
234
|
+
cancel_event: Optional[threading.Event],
|
235
|
+
retry_check: CheckRetryFnType,
|
236
|
+
) -> None:
|
237
|
+
"""Start a new call of a Retry instance.
|
238
|
+
|
239
|
+
Args:
|
240
|
+
max_retries: The number of retries after which to give up.
|
241
|
+
timeout: An amount of time after which to give up.
|
242
|
+
initial_sleep: Number of seconds to sleep for the first retry.
|
243
|
+
This is used as the base for exponential backoff.
|
244
|
+
max_sleep: Maximum number of seconds to sleep between retries.
|
245
|
+
cancel_event: An event that's set when the function is cancelled.
|
246
|
+
retry_check: A custom check for deciding whether an exception should
|
247
|
+
be retried. Retrying is prevented if this returns a falsy value,
|
248
|
+
even if more retries are left. This may also return a timedelta
|
249
|
+
that represents a shorter timeout: retrying is prevented if the
|
250
|
+
value is less than the amount of time that has passed since the
|
251
|
+
last timedelta was returned.
|
252
|
+
"""
|
253
|
+
self._max_retries = max_retries
|
254
|
+
self._total_retries = 0
|
255
|
+
|
256
|
+
self._timeout = timeout
|
257
|
+
self._start_time = NOW_FN()
|
258
|
+
|
259
|
+
self._next_sleep_time = initial_sleep
|
260
|
+
self._max_sleep = max_sleep
|
261
|
+
self._cancel_event = cancel_event
|
262
|
+
|
263
|
+
self._retry_check = retry_check
|
264
|
+
self._last_custom_timeout: Optional[datetime.datetime] = None
|
265
|
+
|
266
|
+
def should_retry(self, exception: Exception) -> bool:
|
267
|
+
"""Returns whether an exception should be retried."""
|
268
|
+
if self._total_retries >= self._max_retries:
|
269
|
+
return False
|
270
|
+
self._total_retries += 1
|
271
|
+
|
272
|
+
now = NOW_FN()
|
273
|
+
if now - self._start_time >= self._timeout:
|
274
|
+
return False
|
275
|
+
|
276
|
+
retry_check_result = self._retry_check(exception)
|
277
|
+
if not retry_check_result:
|
278
|
+
return False
|
279
|
+
|
280
|
+
if isinstance(retry_check_result, datetime.timedelta):
|
281
|
+
if not self._last_custom_timeout:
|
282
|
+
self._last_custom_timeout = now
|
283
|
+
|
284
|
+
if now - self._last_custom_timeout >= retry_check_result:
|
285
|
+
return False
|
286
|
+
|
287
|
+
return True
|
288
|
+
|
289
|
+
def wait_before_retry(self) -> None:
|
290
|
+
"""Block until the next retry should happen.
|
291
|
+
|
292
|
+
Raises:
|
293
|
+
RetryCancelledError: If the operation is cancelled.
|
294
|
+
"""
|
295
|
+
sleep_amount = self._next_sleep_time * (1 + random.random() * 0.25)
|
296
|
+
|
297
|
+
if self._cancel_event:
|
298
|
+
cancelled = self._cancel_event.wait(sleep_amount)
|
192
299
|
if cancelled:
|
193
|
-
raise RetryCancelledError("
|
194
|
-
|
195
|
-
|
196
|
-
sleep = self.MAX_SLEEP_SECONDS
|
197
|
-
now = NOW_FN()
|
300
|
+
raise RetryCancelledError("Cancelled while retrying.")
|
301
|
+
else:
|
302
|
+
SLEEP_FN(sleep_amount)
|
198
303
|
|
199
|
-
|
304
|
+
self._next_sleep_time *= 2
|
305
|
+
if self._next_sleep_time > self._max_sleep:
|
306
|
+
self._next_sleep_time = self._max_sleep
|
200
307
|
|
201
308
|
|
202
309
|
_F = TypeVar("_F", bound=Callable)
|
@@ -0,0 +1,106 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import logging
|
5
|
+
import struct
|
6
|
+
|
7
|
+
from wandb.proto import wandb_server_pb2 as spb
|
8
|
+
from wandb.sdk.lib import asyncio_manager
|
9
|
+
from wandb.sdk.mailbox.mailbox import Mailbox
|
10
|
+
from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
|
11
|
+
|
12
|
+
_logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
_HEADER_BYTE_INT_LEN = 5
|
15
|
+
_HEADER_BYTE_INT_FMT = "<BI"
|
16
|
+
|
17
|
+
|
18
|
+
class ServiceClient:
|
19
|
+
"""Implements socket communication with the internal service."""
|
20
|
+
|
21
|
+
def __init__(
|
22
|
+
self,
|
23
|
+
asyncer: asyncio_manager.AsyncioManager,
|
24
|
+
reader: asyncio.StreamReader,
|
25
|
+
writer: asyncio.StreamWriter,
|
26
|
+
) -> None:
|
27
|
+
self._asyncer = asyncer
|
28
|
+
self._reader = reader
|
29
|
+
self._writer = writer
|
30
|
+
self._mailbox = Mailbox(asyncer)
|
31
|
+
asyncer.run_soon(
|
32
|
+
self._forward_responses,
|
33
|
+
daemon=True,
|
34
|
+
name="ServiceClient._forward_responses",
|
35
|
+
)
|
36
|
+
|
37
|
+
def publish(self, request: spb.ServerRequest) -> None:
|
38
|
+
"""Send a request without waiting for a response."""
|
39
|
+
self._asyncer.run_soon(lambda: self._send_server_request(request))
|
40
|
+
|
41
|
+
def deliver(
|
42
|
+
self,
|
43
|
+
request: spb.ServerRequest,
|
44
|
+
) -> MailboxHandle[spb.ServerResponse]:
|
45
|
+
"""Send a request and return a handle to wait for a response.
|
46
|
+
|
47
|
+
NOTE: This may mutate the request. The request should not be used
|
48
|
+
after.
|
49
|
+
|
50
|
+
Raises:
|
51
|
+
MailboxClosedError: If used after the client is closed or has
|
52
|
+
stopped due to an error.
|
53
|
+
"""
|
54
|
+
handle = self._mailbox.require_response(request)
|
55
|
+
self._asyncer.run_soon(lambda: self._send_server_request(request))
|
56
|
+
return handle
|
57
|
+
|
58
|
+
async def _send_server_request(self, request: spb.ServerRequest) -> None:
|
59
|
+
header = struct.pack(_HEADER_BYTE_INT_FMT, ord("W"), request.ByteSize())
|
60
|
+
self._writer.write(header)
|
61
|
+
|
62
|
+
data = request.SerializeToString()
|
63
|
+
self._writer.write(data)
|
64
|
+
|
65
|
+
await self._writer.drain()
|
66
|
+
|
67
|
+
def close(self) -> None:
|
68
|
+
"""Flush and close the socket."""
|
69
|
+
self._asyncer.run_soon(self._close)
|
70
|
+
|
71
|
+
async def _close(self) -> None:
|
72
|
+
self._writer.close()
|
73
|
+
await self._writer.wait_closed()
|
74
|
+
|
75
|
+
async def _forward_responses(self) -> None:
|
76
|
+
try:
|
77
|
+
while response := await self._read_server_response():
|
78
|
+
await self._mailbox.deliver(response)
|
79
|
+
|
80
|
+
except Exception:
|
81
|
+
_logger.exception("Error reading server response.")
|
82
|
+
|
83
|
+
else:
|
84
|
+
_logger.info("Reached EOF.")
|
85
|
+
|
86
|
+
finally:
|
87
|
+
self._mailbox.close()
|
88
|
+
|
89
|
+
async def _read_server_response(self) -> spb.ServerResponse | None:
|
90
|
+
try:
|
91
|
+
header = await self._reader.readexactly(_HEADER_BYTE_INT_LEN)
|
92
|
+
except asyncio.IncompleteReadError as e:
|
93
|
+
if e.partial:
|
94
|
+
raise
|
95
|
+
else:
|
96
|
+
return None
|
97
|
+
|
98
|
+
magic, length = struct.unpack(_HEADER_BYTE_INT_FMT, header)
|
99
|
+
|
100
|
+
if magic != ord("W"):
|
101
|
+
raise ValueError(f"Bad header: {header.hex()}")
|
102
|
+
|
103
|
+
data = await self._reader.readexactly(length)
|
104
|
+
response = spb.ServerResponse()
|
105
|
+
response.ParseFromString(data)
|
106
|
+
return response
|
@@ -3,16 +3,15 @@ from __future__ import annotations
|
|
3
3
|
import atexit
|
4
4
|
from typing import Callable
|
5
5
|
|
6
|
-
from wandb.proto import wandb_internal_pb2 as pb
|
7
6
|
from wandb.proto import wandb_server_pb2 as spb
|
8
7
|
from wandb.proto import wandb_settings_pb2
|
9
8
|
from wandb.sdk import wandb_settings
|
10
9
|
from wandb.sdk.interface.interface import InterfaceBase
|
11
10
|
from wandb.sdk.interface.interface_sock import InterfaceSock
|
12
|
-
from wandb.sdk.
|
11
|
+
from wandb.sdk.lib import asyncio_manager
|
13
12
|
from wandb.sdk.lib.exit_hooks import ExitHooks
|
14
|
-
from wandb.sdk.lib.
|
15
|
-
from wandb.sdk.mailbox import HandleAbandonedError,
|
13
|
+
from wandb.sdk.lib.service.service_client import ServiceClient
|
14
|
+
from wandb.sdk.mailbox import HandleAbandonedError, MailboxClosedError
|
16
15
|
|
17
16
|
from . import service_process, service_token
|
18
17
|
|
@@ -22,18 +21,23 @@ class WandbAttachFailedError(Exception):
|
|
22
21
|
|
23
22
|
|
24
23
|
def connect_to_service(
|
24
|
+
asyncer: asyncio_manager.AsyncioManager,
|
25
25
|
settings: wandb_settings.Settings,
|
26
26
|
) -> ServiceConnection:
|
27
27
|
"""Connect to the service process, starting one up if necessary."""
|
28
28
|
token = service_token.from_env()
|
29
29
|
|
30
30
|
if token:
|
31
|
-
return ServiceConnection(
|
31
|
+
return ServiceConnection(
|
32
|
+
client=token.connect(asyncer=asyncer),
|
33
|
+
proc=None,
|
34
|
+
)
|
32
35
|
else:
|
33
|
-
return _start_and_connect_service(settings)
|
36
|
+
return _start_and_connect_service(asyncer, settings)
|
34
37
|
|
35
38
|
|
36
39
|
def _start_and_connect_service(
|
40
|
+
asyncer: asyncio_manager.AsyncioManager,
|
37
41
|
settings: wandb_settings.Settings,
|
38
42
|
) -> ServiceConnection:
|
39
43
|
"""Start a service process and returns a connection to it.
|
@@ -44,7 +48,7 @@ def _start_and_connect_service(
|
|
44
48
|
"""
|
45
49
|
proc = service_process.start(settings)
|
46
50
|
|
47
|
-
client = proc.token.connect()
|
51
|
+
client = proc.token.connect(asyncer=asyncer)
|
48
52
|
proc.token.save_to_env()
|
49
53
|
|
50
54
|
hooks = ExitHooks()
|
@@ -69,7 +73,7 @@ class ServiceConnection:
|
|
69
73
|
|
70
74
|
def __init__(
|
71
75
|
self,
|
72
|
-
client:
|
76
|
+
client: ServiceClient,
|
73
77
|
proc: service_process.ServiceProcess | None,
|
74
78
|
cleanup: Callable[[], None] | None = None,
|
75
79
|
):
|
@@ -88,16 +92,9 @@ class ServiceConnection:
|
|
88
92
|
self._torn_down = False
|
89
93
|
self._cleanup = cleanup
|
90
94
|
|
91
|
-
self._mailbox = Mailbox()
|
92
|
-
self._router = MessageSockRouter(self._client, self._mailbox)
|
93
|
-
|
94
95
|
def make_interface(self, stream_id: str) -> InterfaceBase:
|
95
96
|
"""Returns an interface for communicating with the service."""
|
96
|
-
return InterfaceSock(self._client,
|
97
|
-
|
98
|
-
def send_record(self, record: pb.Record) -> None:
|
99
|
-
"""Send data to the service."""
|
100
|
-
self._client.send_record_publish(record)
|
97
|
+
return InterfaceSock(self._client, stream_id=stream_id)
|
101
98
|
|
102
99
|
def inform_init(
|
103
100
|
self,
|
@@ -108,13 +105,13 @@ class ServiceConnection:
|
|
108
105
|
request = spb.ServerInformInitRequest()
|
109
106
|
request.settings.CopyFrom(settings)
|
110
107
|
request._info.stream_id = run_id
|
111
|
-
self._client.
|
108
|
+
self._client.publish(spb.ServerRequest(inform_init=request))
|
112
109
|
|
113
110
|
def inform_finish(self, run_id: str) -> None:
|
114
111
|
"""Send an finish request to the service."""
|
115
112
|
request = spb.ServerInformFinishRequest()
|
116
113
|
request._info.stream_id = run_id
|
117
|
-
self._client.
|
114
|
+
self._client.publish(spb.ServerRequest(inform_finish=request))
|
118
115
|
|
119
116
|
def inform_attach(
|
120
117
|
self,
|
@@ -128,11 +125,10 @@ class ServiceConnection:
|
|
128
125
|
request.inform_attach._info.stream_id = attach_id
|
129
126
|
|
130
127
|
try:
|
131
|
-
handle = self.
|
132
|
-
self._client.send_server_request(request)
|
128
|
+
handle = self._client.deliver(request)
|
133
129
|
response = handle.wait_or(timeout=10)
|
134
130
|
|
135
|
-
except (MailboxClosedError, HandleAbandonedError
|
131
|
+
except (MailboxClosedError, HandleAbandonedError):
|
136
132
|
raise WandbAttachFailedError(
|
137
133
|
"Failed to attach: the service process is not running.",
|
138
134
|
) from None
|
@@ -156,7 +152,7 @@ class ServiceConnection:
|
|
156
152
|
request = spb.ServerInformStartRequest()
|
157
153
|
request.settings.CopyFrom(settings)
|
158
154
|
request._info.stream_id = run_id
|
159
|
-
self._client.
|
155
|
+
self._client.publish(spb.ServerRequest(inform_start=request))
|
160
156
|
|
161
157
|
def teardown(self, exit_code: int) -> int | None:
|
162
158
|
"""Close the connection.
|
@@ -178,21 +174,19 @@ class ServiceConnection:
|
|
178
174
|
if self._cleanup:
|
179
175
|
self._cleanup()
|
180
176
|
|
181
|
-
# Stop reading responses on the socket.
|
182
|
-
self._router.join()
|
183
|
-
|
184
177
|
if not self._proc:
|
185
178
|
return None
|
186
179
|
|
187
180
|
# Clear the service token to prevent new connections to the process.
|
188
181
|
service_token.clear_service_in_env()
|
189
182
|
|
190
|
-
self._client.
|
183
|
+
self._client.publish(
|
191
184
|
spb.ServerRequest(
|
192
185
|
inform_teardown=spb.ServerInformTeardownRequest(
|
193
186
|
exit_code=exit_code,
|
194
187
|
)
|
195
188
|
),
|
196
189
|
)
|
190
|
+
self._client.close()
|
197
191
|
|
198
192
|
return self._proc.join()
|
@@ -1,15 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import abc
|
4
|
+
import asyncio
|
4
5
|
import os
|
5
6
|
import re
|
6
|
-
import socket
|
7
7
|
|
8
8
|
from typing_extensions import final, override
|
9
9
|
|
10
10
|
from wandb import env
|
11
|
+
from wandb.sdk.lib import asyncio_manager
|
11
12
|
from wandb.sdk.lib.service import ipc_support
|
12
|
-
|
13
|
+
|
14
|
+
from .service_client import ServiceClient
|
13
15
|
|
14
16
|
_CURRENT_VERSION = "3"
|
15
17
|
|
@@ -53,9 +55,16 @@ class ServiceToken(abc.ABC):
|
|
53
55
|
"""A way of connecting to a running service process."""
|
54
56
|
|
55
57
|
@abc.abstractmethod
|
56
|
-
def connect(
|
58
|
+
def connect(
|
59
|
+
self,
|
60
|
+
*,
|
61
|
+
asyncer: asyncio_manager.AsyncioManager,
|
62
|
+
) -> ServiceClient:
|
57
63
|
"""Connect to the service process.
|
58
64
|
|
65
|
+
Args:
|
66
|
+
asyncer: A started AsyncioManager for asyncio operations.
|
67
|
+
|
59
68
|
Returns:
|
60
69
|
A socket object for communicating with the service.
|
61
70
|
|
@@ -81,21 +90,25 @@ class UnixServiceToken(ServiceToken):
|
|
81
90
|
self._path = path
|
82
91
|
|
83
92
|
@override
|
84
|
-
def connect(
|
93
|
+
def connect(
|
94
|
+
self,
|
95
|
+
*,
|
96
|
+
asyncer: asyncio_manager.AsyncioManager,
|
97
|
+
) -> ServiceClient:
|
85
98
|
if not ipc_support.SUPPORTS_UNIX:
|
86
99
|
raise WandbServiceConnectionError("AF_UNIX socket not supported")
|
87
100
|
|
88
|
-
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
89
|
-
|
90
101
|
try:
|
91
102
|
# TODO: This may block indefinitely if the service is unhealthy.
|
92
|
-
|
103
|
+
reader, writer = asyncer.run(
|
104
|
+
lambda: asyncio.open_unix_connection(self._path),
|
105
|
+
)
|
93
106
|
except Exception as e:
|
94
107
|
raise WandbServiceConnectionError(
|
95
108
|
f"Failed to connect to service on socket {self._path}",
|
96
109
|
) from e
|
97
110
|
|
98
|
-
return
|
111
|
+
return ServiceClient(asyncer, reader, writer)
|
99
112
|
|
100
113
|
@override
|
101
114
|
def _as_env_string(self):
|
@@ -128,18 +141,22 @@ class TCPServiceToken(ServiceToken):
|
|
128
141
|
self._port = port
|
129
142
|
|
130
143
|
@override
|
131
|
-
def connect(
|
132
|
-
|
133
|
-
|
144
|
+
def connect(
|
145
|
+
self,
|
146
|
+
*,
|
147
|
+
asyncer: asyncio_manager.AsyncioManager,
|
148
|
+
) -> ServiceClient:
|
134
149
|
try:
|
135
150
|
# TODO: This may block indefinitely if the service is unhealthy.
|
136
|
-
|
151
|
+
reader, writer = asyncer.run(
|
152
|
+
lambda: asyncio.open_connection("localhost", self._port),
|
153
|
+
)
|
137
154
|
except Exception as e:
|
138
155
|
raise WandbServiceConnectionError(
|
139
156
|
f"Failed to connect to service on port {self._port}",
|
140
157
|
) from e
|
141
158
|
|
142
|
-
return
|
159
|
+
return ServiceClient(asyncer, reader, writer)
|
143
160
|
|
144
161
|
@override
|
145
162
|
def _as_env_string(self):
|