wandb 0.19.6rc4__py3-none-any.whl → 0.19.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +25 -5
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats +0 -0
- wandb/cli/cli.py +11 -20
- wandb/env.py +10 -0
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +1 -1
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +1 -1
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +1 -1
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +43 -88
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +26 -10
- wandb/sdk/interface/interface_queue.py +5 -8
- wandb/sdk/interface/interface_relay.py +1 -6
- wandb/sdk/interface/interface_shared.py +21 -99
- wandb/sdk/interface/interface_sock.py +2 -13
- wandb/sdk/interface/router.py +21 -15
- wandb/sdk/interface/router_queue.py +2 -1
- wandb/sdk/interface/router_relay.py +2 -1
- wandb/sdk/interface/router_sock.py +5 -4
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +0 -18
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/service_connection.py +1 -1
- wandb/sdk/lib/sock_client.py +7 -7
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/handles.py +199 -0
- wandb/sdk/mailbox/mailbox.py +121 -0
- wandb/sdk/mailbox/wait_with_progress.py +134 -0
- wandb/sdk/service/server_sock.py +5 -1
- wandb/sdk/service/streams.py +66 -74
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +61 -61
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +82 -87
- wandb/sdk/wandb_settings.py +3 -3
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -4
- wandb/util.py +3 -1
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/METADATA +2 -2
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/RECORD +70 -57
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/WHEEL +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/service/server_sock.py
CHANGED
@@ -70,6 +70,7 @@ class SockServerInterfaceReaderThread(threading.Thread):
|
|
70
70
|
sock_client = self._clients.get_client(sockid)
|
71
71
|
assert sock_client
|
72
72
|
sresp = spb.ServerResponse()
|
73
|
+
sresp.request_id = result.control.mailbox_slot
|
73
74
|
sresp.result_communicate.CopyFrom(result)
|
74
75
|
sock_client.send_server_response(sresp)
|
75
76
|
|
@@ -148,7 +149,10 @@ class SockServerReadThread(threading.Thread):
|
|
148
149
|
inform_attach_response.settings.CopyFrom(
|
149
150
|
self._mux._streams[stream_id]._settings._proto,
|
150
151
|
)
|
151
|
-
response = spb.ServerResponse(
|
152
|
+
response = spb.ServerResponse(
|
153
|
+
request_id=sreq.request_id,
|
154
|
+
inform_attach_response=inform_attach_response,
|
155
|
+
)
|
152
156
|
self._sock_client.send_server_response(response)
|
153
157
|
iface = self._mux.get_stream(stream_id).interface
|
154
158
|
|
wandb/sdk/service/streams.py
CHANGED
@@ -8,13 +8,13 @@ StreamMux: Container for dictionary of stream threads per runid
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
+
import asyncio
|
11
12
|
import functools
|
12
|
-
import multiprocessing
|
13
13
|
import queue
|
14
14
|
import threading
|
15
15
|
import time
|
16
16
|
from threading import Event
|
17
|
-
from typing import Any, Callable
|
17
|
+
from typing import Any, Callable, NoReturn
|
18
18
|
|
19
19
|
import psutil
|
20
20
|
|
@@ -22,14 +22,9 @@ import wandb
|
|
22
22
|
import wandb.util
|
23
23
|
from wandb.proto import wandb_internal_pb2 as pb
|
24
24
|
from wandb.sdk.internal.settings_static import SettingsStatic
|
25
|
+
from wandb.sdk.lib import asyncio_compat, progress
|
25
26
|
from wandb.sdk.lib import printer as printerlib
|
26
|
-
from wandb.sdk.
|
27
|
-
from wandb.sdk.lib.mailbox import (
|
28
|
-
Mailbox,
|
29
|
-
MailboxProbe,
|
30
|
-
MailboxProgress,
|
31
|
-
MailboxProgressAll,
|
32
|
-
)
|
27
|
+
from wandb.sdk.mailbox import Mailbox, MailboxHandle, wait_all_with_progress
|
33
28
|
from wandb.sdk.wandb_run import Run
|
34
29
|
|
35
30
|
from ..interface.interface_relay import InterfaceRelay
|
@@ -61,19 +56,16 @@ class StreamRecord:
|
|
61
56
|
_settings: SettingsStatic
|
62
57
|
_started: bool
|
63
58
|
|
64
|
-
def __init__(self, settings: SettingsStatic
|
59
|
+
def __init__(self, settings: SettingsStatic) -> None:
|
65
60
|
self._started = False
|
66
|
-
self._mailbox =
|
61
|
+
self._mailbox = Mailbox()
|
67
62
|
self._record_q = queue.Queue()
|
68
63
|
self._result_q = queue.Queue()
|
69
64
|
self._relay_q = queue.Queue()
|
70
|
-
process = multiprocessing.current_process()
|
71
65
|
self._iface = InterfaceRelay(
|
72
66
|
record_q=self._record_q,
|
73
67
|
result_q=self._result_q,
|
74
68
|
relay_q=self._relay_q,
|
75
|
-
process=process,
|
76
|
-
process_check=False,
|
77
69
|
mailbox=self._mailbox,
|
78
70
|
)
|
79
71
|
self._settings = settings
|
@@ -84,7 +76,7 @@ class StreamRecord:
|
|
84
76
|
self._wait_thread_active()
|
85
77
|
|
86
78
|
def _wait_thread_active(self) -> None:
|
87
|
-
self._iface.deliver_status().
|
79
|
+
self._iface.deliver_status().wait_or(timeout=None)
|
88
80
|
|
89
81
|
def join(self) -> None:
|
90
82
|
self._iface.join()
|
@@ -141,7 +133,6 @@ class StreamMux:
|
|
141
133
|
_action_q: queue.Queue[StreamAction]
|
142
134
|
_stopped: Event
|
143
135
|
_pid_checked_ts: float | None
|
144
|
-
_mailbox: Mailbox
|
145
136
|
|
146
137
|
def __init__(self) -> None:
|
147
138
|
self._streams_lock = threading.Lock()
|
@@ -151,8 +142,6 @@ class StreamMux:
|
|
151
142
|
self._stopped = Event()
|
152
143
|
self._action_q = queue.Queue()
|
153
144
|
self._pid_checked_ts = None
|
154
|
-
self._mailbox = Mailbox()
|
155
|
-
self._mailbox.enable_keepalive()
|
156
145
|
|
157
146
|
def _get_stopped_event(self) -> Event:
|
158
147
|
# TODO: clean this up, there should be a better way to abstract this
|
@@ -209,7 +198,7 @@ class StreamMux:
|
|
209
198
|
return stream
|
210
199
|
|
211
200
|
def _process_add(self, action: StreamAction) -> None:
|
212
|
-
stream = StreamRecord(action._data
|
201
|
+
stream = StreamRecord(action._data)
|
213
202
|
# run_id = action.stream_id # will want to fix if a streamid != runid
|
214
203
|
settings = action._data
|
215
204
|
thread = StreamThread(
|
@@ -247,41 +236,51 @@ class StreamMux:
|
|
247
236
|
stream.drop()
|
248
237
|
stream.join()
|
249
238
|
|
250
|
-
def
|
251
|
-
handle = probe_handle.get_mailbox_handle()
|
252
|
-
if handle:
|
253
|
-
result = handle.wait(timeout=0, release=False)
|
254
|
-
if not result:
|
255
|
-
return
|
256
|
-
probe_handle.set_probe_result(result)
|
257
|
-
handle = stream.interface.deliver_poll_exit()
|
258
|
-
probe_handle.set_mailbox_handle(handle)
|
259
|
-
|
260
|
-
def _on_progress_exit(self, progress_handle: MailboxProgress) -> None:
|
261
|
-
pass
|
262
|
-
|
263
|
-
def _on_progress_exit_all(
|
239
|
+
async def _finish_all_progress(
|
264
240
|
self,
|
265
241
|
progress_printer: progress.ProgressPrinter,
|
266
|
-
|
242
|
+
streams_to_watch: dict[str, StreamRecord],
|
267
243
|
) -> None:
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
244
|
+
"""Poll the streams and display statistics about them.
|
245
|
+
|
246
|
+
This never returns and must be cancelled.
|
247
|
+
|
248
|
+
Args:
|
249
|
+
progress_printer: Printer to use for displaying finish progress.
|
250
|
+
streams_to_watch: Streams to poll for finish progress.
|
251
|
+
"""
|
252
|
+
results: dict[str, pb.Result | None] = {}
|
253
|
+
|
254
|
+
async def loop_poll_stream(
|
255
|
+
stream_id: str,
|
256
|
+
stream: StreamRecord,
|
257
|
+
) -> NoReturn:
|
258
|
+
while True:
|
259
|
+
start_time = time.monotonic()
|
260
|
+
|
261
|
+
handle = stream.interface.deliver_poll_exit()
|
262
|
+
results[stream_id] = await handle.wait_async(timeout=None)
|
263
|
+
|
264
|
+
elapsed_time = time.monotonic() - start_time
|
265
|
+
if elapsed_time < 1:
|
266
|
+
await asyncio.sleep(1 - elapsed_time)
|
267
|
+
|
268
|
+
async def loop_update_printer() -> NoReturn:
|
269
|
+
while True:
|
270
|
+
poll_exit_responses: list[pb.PollExitResponse] = []
|
271
|
+
for result in results.values():
|
272
|
+
if not result or not result.response:
|
273
|
+
continue
|
274
|
+
if poll_exit_response := result.response.poll_exit_response:
|
275
|
+
poll_exit_responses.append(poll_exit_response)
|
276
|
+
|
277
|
+
progress_printer.update(poll_exit_responses)
|
278
|
+
await asyncio.sleep(1)
|
279
|
+
|
280
|
+
async with asyncio_compat.open_task_group() as task_group:
|
281
|
+
for stream_id, stream in streams_to_watch.items():
|
282
|
+
task_group.start_soon(loop_poll_stream(stream_id, stream))
|
283
|
+
task_group.start_soon(loop_update_printer())
|
285
284
|
|
286
285
|
def _finish_all(self, streams: dict[str, StreamRecord], exit_code: int) -> None:
|
287
286
|
if not streams:
|
@@ -291,7 +290,7 @@ class StreamMux:
|
|
291
290
|
|
292
291
|
# fixme: for now we have a single printer for all streams,
|
293
292
|
# and jupyter is disabled if at least single stream's setting set `_jupyter` to false
|
294
|
-
exit_handles = []
|
293
|
+
exit_handles: list[MailboxHandle] = []
|
295
294
|
|
296
295
|
# only finish started streams, non started streams failed early
|
297
296
|
started_streams: dict[str, StreamRecord] = {}
|
@@ -302,27 +301,24 @@ class StreamMux:
|
|
302
301
|
|
303
302
|
for stream in started_streams.values():
|
304
303
|
handle = stream.interface.deliver_exit(exit_code)
|
305
|
-
handle.add_progress(self._on_progress_exit)
|
306
|
-
handle.add_probe(functools.partial(self._on_probe_exit, stream=stream))
|
307
304
|
exit_handles.append(handle)
|
308
305
|
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
with progress.progress_printer(printer) as progress_printer:
|
306
|
+
with progress.progress_printer(
|
307
|
+
printer,
|
308
|
+
default_text="Finishing up...",
|
309
|
+
) as progress_printer:
|
315
310
|
# todo: should we wait for the max timeout (?) of all exit handles or just wait forever?
|
316
311
|
# timeout = max(stream._settings._exit_timeout for stream in streams.values())
|
317
|
-
|
318
|
-
|
319
|
-
timeout
|
320
|
-
|
321
|
-
|
312
|
+
wait_all_with_progress(
|
313
|
+
exit_handles,
|
314
|
+
timeout=None,
|
315
|
+
progress_after=1,
|
316
|
+
display_progress=functools.partial(
|
317
|
+
self._finish_all_progress,
|
322
318
|
progress_printer,
|
319
|
+
started_streams,
|
323
320
|
),
|
324
321
|
)
|
325
|
-
assert got_result
|
326
322
|
|
327
323
|
# These could be done in parallel in the future
|
328
324
|
for _sid, stream in started_streams.items():
|
@@ -332,20 +328,16 @@ class StreamMux:
|
|
332
328
|
sampled_history_handle = stream.interface.deliver_request_sampled_history()
|
333
329
|
internal_messages_handle = stream.interface.deliver_internal_messages()
|
334
330
|
|
335
|
-
result = internal_messages_handle.
|
336
|
-
assert result
|
331
|
+
result = internal_messages_handle.wait_or(timeout=None)
|
337
332
|
internal_messages_response = result.response.internal_messages_response
|
338
333
|
|
339
|
-
result = poll_exit_handle.
|
340
|
-
assert result
|
334
|
+
result = poll_exit_handle.wait_or(timeout=None)
|
341
335
|
poll_exit_response = result.response.poll_exit_response
|
342
336
|
|
343
|
-
result = sampled_history_handle.
|
344
|
-
assert result
|
337
|
+
result = sampled_history_handle.wait_or(timeout=None)
|
345
338
|
sampled_history = result.response.sampled_history_response
|
346
339
|
|
347
|
-
result = final_summary_handle.
|
348
|
-
assert result
|
340
|
+
result = final_summary_handle.wait_or(timeout=None)
|
349
341
|
final_summary = result.response.get_summary_response
|
350
342
|
|
351
343
|
Run._footer(
|
wandb/sdk/verify/verify.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
"""Utilities for wandb verify."""
|
2
2
|
|
3
|
+
import contextlib
|
3
4
|
import getpass
|
5
|
+
import io
|
4
6
|
import os
|
5
7
|
import time
|
6
8
|
from functools import partial
|
@@ -163,8 +165,8 @@ def check_run(api: Api) -> bool:
|
|
163
165
|
)
|
164
166
|
print_results(failed_test_strings, False)
|
165
167
|
return False
|
166
|
-
for key, value in
|
167
|
-
if config
|
168
|
+
for key, value in config.items():
|
169
|
+
if prev_run.config.get(key) != value:
|
168
170
|
failed_test_strings.append(
|
169
171
|
"Read config values don't match run config. Contact W&B for support."
|
170
172
|
)
|
@@ -486,6 +488,56 @@ def check_wandb_version(api: Api) -> None:
|
|
486
488
|
print_results(fail_string, warning)
|
487
489
|
|
488
490
|
|
491
|
+
def check_sweeps(api: Api) -> bool:
|
492
|
+
print("Checking sweep creation and agent execution".ljust(72, "."), end="") # noqa: T201
|
493
|
+
failed_test_strings: List[str] = []
|
494
|
+
|
495
|
+
sweep_config = {
|
496
|
+
"method": "random",
|
497
|
+
"metric": {"goal": "minimize", "name": "score"},
|
498
|
+
"parameters": {
|
499
|
+
"x": {"values": [0.01, 0.05, 0.1]},
|
500
|
+
"y": {"values": [1, 2, 3]},
|
501
|
+
},
|
502
|
+
"name": "verify_sweep",
|
503
|
+
}
|
504
|
+
|
505
|
+
try:
|
506
|
+
with contextlib.redirect_stdout(io.StringIO()):
|
507
|
+
sweep_id = wandb.sweep(
|
508
|
+
sweep=sweep_config, project=PROJECT_NAME, entity=api.default_entity
|
509
|
+
)
|
510
|
+
except Exception as e:
|
511
|
+
failed_test_strings.append(f"Failed to create sweep: {e}")
|
512
|
+
print_results(failed_test_strings, False)
|
513
|
+
return False
|
514
|
+
|
515
|
+
if not sweep_id:
|
516
|
+
failed_test_strings.append("Sweep creation returned an invalid ID.")
|
517
|
+
print_results(failed_test_strings, False)
|
518
|
+
return False
|
519
|
+
|
520
|
+
try:
|
521
|
+
|
522
|
+
def objective(config):
|
523
|
+
score = config.x**3 + config.y
|
524
|
+
return score
|
525
|
+
|
526
|
+
def main():
|
527
|
+
with wandb.init(project=PROJECT_NAME) as run:
|
528
|
+
score = objective(run.config)
|
529
|
+
run.log({"score": score})
|
530
|
+
|
531
|
+
wandb.agent(sweep_id, function=main, count=10)
|
532
|
+
except Exception as e:
|
533
|
+
failed_test_strings.append(f"Failed to run sweep agent: {e}")
|
534
|
+
print_results(failed_test_strings, False)
|
535
|
+
return False
|
536
|
+
|
537
|
+
print_results(failed_test_strings, False)
|
538
|
+
return len(failed_test_strings) == 0
|
539
|
+
|
540
|
+
|
489
541
|
def retry_fn(fn: Callable) -> Any:
|
490
542
|
ini_time = time.time()
|
491
543
|
res = None
|
wandb/sdk/wandb_init.py
CHANGED
@@ -34,7 +34,7 @@ from wandb.errors import CommError, Error, UsageError
|
|
34
34
|
from wandb.errors.links import url_registry
|
35
35
|
from wandb.errors.util import ProtobufErrorHandler
|
36
36
|
from wandb.integration import sagemaker
|
37
|
-
from wandb.sdk.lib import runid
|
37
|
+
from wandb.sdk.lib import progress, runid
|
38
38
|
from wandb.sdk.lib.paths import StrPath
|
39
39
|
from wandb.util import _is_artifact_representation
|
40
40
|
|
@@ -42,7 +42,7 @@ from . import wandb_login, wandb_setup
|
|
42
42
|
from .backend.backend import Backend
|
43
43
|
from .lib import SummaryDisabled, filesystem, module, paths, printer, telemetry
|
44
44
|
from .lib.deprecate import Deprecated, deprecate
|
45
|
-
from .
|
45
|
+
from .mailbox import Mailbox, wait_with_progress
|
46
46
|
from .wandb_helper import parse_config
|
47
47
|
from .wandb_run import Run, TeardownHook, TeardownStage
|
48
48
|
from .wandb_settings import Settings
|
@@ -293,6 +293,16 @@ class _WandbInit:
|
|
293
293
|
|
294
294
|
settings.x_start_time = time.time()
|
295
295
|
|
296
|
+
# In shared mode, generate a unique label if not provided.
|
297
|
+
# The label is used to distinguish between system metrics and console logs
|
298
|
+
# from different writers to the same run.
|
299
|
+
if settings._shared and not settings.x_label:
|
300
|
+
# TODO: If executed in a known distributed environment (e.g. Ray or SLURM),
|
301
|
+
# use the env vars to generate a label (e.g. SLURM_JOB_ID or RANK)
|
302
|
+
prefix = settings.host or ""
|
303
|
+
label = runid.generate_id()
|
304
|
+
settings.x_label = f"{prefix}-{label}" if prefix else label
|
305
|
+
|
296
306
|
return settings
|
297
307
|
|
298
308
|
def _load_autoresume_run_id(self, resume_file: pathlib.Path) -> str | None:
|
@@ -747,11 +757,6 @@ class _WandbInit:
|
|
747
757
|
)
|
748
758
|
return drun
|
749
759
|
|
750
|
-
def _on_progress_init(self, handle: MailboxProgress) -> None:
|
751
|
-
line = "Waiting for wandb.init()...\r"
|
752
|
-
percent_done = handle.percent_done
|
753
|
-
self.printer.progress_update(line, percent_done=percent_done)
|
754
|
-
|
755
760
|
def init(self, settings: Settings, config: _ConfigParts) -> Run: # noqa: C901
|
756
761
|
self._logger.info("calling init triggers")
|
757
762
|
trigger.call("on_init")
|
@@ -763,28 +768,18 @@ class _WandbInit:
|
|
763
768
|
f"\nconfig: {config.base_no_artifacts}"
|
764
769
|
)
|
765
770
|
|
766
|
-
if (
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
" module is not well-supported."
|
773
|
-
" Please use multiprocessing instead."
|
774
|
-
" Finishing previous run before initializing another."
|
775
|
-
)
|
776
|
-
|
777
|
-
latest_run = self._wl._global_run_stack[-1]
|
778
|
-
self._logger.info(f"found existing run on stack: {latest_run.id}")
|
779
|
-
latest_run.finish()
|
780
|
-
elif wandb.run is not None and os.getpid() == wandb.run._init_pid:
|
781
|
-
self._logger.info("wandb.init() called when a run is still active")
|
771
|
+
if wandb.run is not None and os.getpid() == wandb.run._init_pid:
|
772
|
+
if settings.reinit:
|
773
|
+
self._logger.info(f"finishing previous run: {wandb.run.id}")
|
774
|
+
wandb.run.finish()
|
775
|
+
else:
|
776
|
+
self._logger.info("wandb.init() called while a run is active")
|
782
777
|
|
783
|
-
|
784
|
-
|
785
|
-
|
778
|
+
# NOTE: Updates telemetry on the pre-existing run.
|
779
|
+
with telemetry.context() as tel:
|
780
|
+
tel.feature.init_return_run = True
|
786
781
|
|
787
|
-
|
782
|
+
return wandb.run
|
788
783
|
|
789
784
|
self._logger.info("starting backend")
|
790
785
|
|
@@ -905,7 +900,6 @@ class _WandbInit:
|
|
905
900
|
run._set_teardown_hooks(self._teardown_hooks)
|
906
901
|
|
907
902
|
assert backend.interface
|
908
|
-
mailbox.enable_keepalive()
|
909
903
|
backend.interface.publish_header()
|
910
904
|
|
911
905
|
# Using GitRepo() blocks & can be slow, depending on user's current git setup.
|
@@ -921,16 +915,6 @@ class _WandbInit:
|
|
921
915
|
)
|
922
916
|
error: wandb.Error | None = None
|
923
917
|
|
924
|
-
# In shared mode, generate a unique label if not provided.
|
925
|
-
# The label is used to distinguish between system metrics and console logs
|
926
|
-
# from different writers to the same run.
|
927
|
-
if settings._shared and not settings.x_label:
|
928
|
-
# TODO: If executed in a known distributed environment (e.g. Ray or SLURM),
|
929
|
-
# use the env vars to generate a label (e.g. SLURM_JOB_ID or RANK)
|
930
|
-
prefix = settings.host or ""
|
931
|
-
label = runid.generate_id()
|
932
|
-
settings.x_label = f"{prefix}-{label}" if prefix else label
|
933
|
-
|
934
918
|
timeout = settings.init_timeout
|
935
919
|
|
936
920
|
self._logger.info(
|
@@ -938,11 +922,18 @@ class _WandbInit:
|
|
938
922
|
)
|
939
923
|
|
940
924
|
run_init_handle = backend.interface.deliver_run(run)
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
925
|
+
|
926
|
+
async def display_init_message() -> None:
|
927
|
+
assert backend.interface
|
928
|
+
|
929
|
+
with progress.progress_printer(
|
930
|
+
self.printer,
|
931
|
+
default_text="Waiting for wandb.init()...",
|
932
|
+
) as progress_printer:
|
933
|
+
await progress.loop_printing_operation_stats(
|
934
|
+
progress_printer,
|
935
|
+
backend.interface,
|
936
|
+
)
|
946
937
|
|
947
938
|
# Raise an error if deliver_run failed.
|
948
939
|
#
|
@@ -951,8 +942,16 @@ class _WandbInit:
|
|
951
942
|
#
|
952
943
|
# TODO: Remove try-except once x_disable_service is removed.
|
953
944
|
try:
|
954
|
-
|
955
|
-
|
945
|
+
try:
|
946
|
+
result = wait_with_progress(
|
947
|
+
run_init_handle,
|
948
|
+
timeout=timeout,
|
949
|
+
progress_after=1,
|
950
|
+
display_progress=display_init_message,
|
951
|
+
)
|
952
|
+
|
953
|
+
except TimeoutError:
|
954
|
+
run_init_handle.cancel(backend.interface)
|
956
955
|
|
957
956
|
# This may either be an issue with the W&B server (a CommError)
|
958
957
|
# or a bug in the SDK (an Error). We cannot distinguish between
|
@@ -963,6 +962,8 @@ class _WandbInit:
|
|
963
962
|
" setting: `wandb.init(settings=wandb.Settings(init_timeout=120))`."
|
964
963
|
)
|
965
964
|
|
965
|
+
assert result.run_result
|
966
|
+
|
966
967
|
if error := ProtobufErrorHandler.to_exception(result.run_result.error):
|
967
968
|
raise error
|
968
969
|
|
@@ -1004,13 +1005,13 @@ class _WandbInit:
|
|
1004
1005
|
assert backend.interface
|
1005
1006
|
|
1006
1007
|
run_start_handle = backend.interface.deliver_run_start(run)
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
1008
|
+
try:
|
1009
|
+
# TODO: add progress to let user know we are doing something
|
1010
|
+
run_start_handle.wait_or(timeout=30)
|
1011
|
+
except TimeoutError:
|
1012
|
+
pass
|
1011
1013
|
|
1012
1014
|
assert self._wl is not None
|
1013
|
-
self._wl._global_run_stack.append(run)
|
1014
1015
|
self.run = run
|
1015
1016
|
|
1016
1017
|
run._handle_launch_artifact_overrides()
|
@@ -1097,14 +1098,13 @@ def _attach(
|
|
1097
1098
|
run._set_backend(backend)
|
1098
1099
|
assert backend.interface
|
1099
1100
|
|
1100
|
-
mailbox.enable_keepalive()
|
1101
|
-
|
1102
1101
|
attach_handle = backend.interface.deliver_attach(attach_id)
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1102
|
+
try:
|
1103
|
+
# TODO: add progress to let user know we are doing something
|
1104
|
+
attach_result = attach_handle.wait_or(timeout=30)
|
1105
|
+
except TimeoutError:
|
1107
1106
|
raise UsageError("Timeout attaching to run")
|
1107
|
+
|
1108
1108
|
attach_response = attach_result.response.attach_response
|
1109
1109
|
if attach_response.error and attach_response.error.message:
|
1110
1110
|
raise UsageError(f"Failed to attach to run: {attach_response.error.message}")
|
@@ -1222,10 +1222,10 @@ def init( # noqa: C901
|
|
1222
1222
|
on the system, such as checking the git root or the current program
|
1223
1223
|
file. If we can't infer the project name, the project will default to
|
1224
1224
|
`"uncategorized"`.
|
1225
|
-
dir:
|
1226
|
-
files
|
1227
|
-
|
1228
|
-
|
1225
|
+
dir: The absolute path to the directory where experiment logs and
|
1226
|
+
metadata files are stored. If not specified, this defaults
|
1227
|
+
to the `./wandb` directory. Note that this does not affect the
|
1228
|
+
location where artifacts are stored when calling `download()`.
|
1229
1229
|
id: A unique identifier for this run, used for resuming. It must be unique
|
1230
1230
|
within the project and cannot be reused once a run is deleted. The
|
1231
1231
|
identifier must not contain any of the following special characters:
|
@@ -1426,7 +1426,7 @@ def init( # noqa: C901
|
|
1426
1426
|
wl: wandb_setup._WandbSetup | None = None
|
1427
1427
|
|
1428
1428
|
try:
|
1429
|
-
wl =
|
1429
|
+
wl = wandb_setup._setup(start_service=False)
|
1430
1430
|
|
1431
1431
|
wi = _WandbInit(wl, init_telemetry)
|
1432
1432
|
|
wandb/sdk/wandb_login.py
CHANGED
@@ -162,8 +162,8 @@ class _WandbLogin:
|
|
162
162
|
repeat=False,
|
163
163
|
)
|
164
164
|
|
165
|
-
def
|
166
|
-
"""Saves the API key
|
165
|
+
def try_save_api_key(self, key: str) -> None:
|
166
|
+
"""Saves the API key to disk for future use."""
|
167
167
|
if self._settings._notebook and not self._settings.silent:
|
168
168
|
wandb.termwarn(
|
169
169
|
"If you're specifying your api key in code, ensure this "
|
@@ -172,7 +172,10 @@ class _WandbLogin:
|
|
172
172
|
"`wandb login` from the command line."
|
173
173
|
)
|
174
174
|
if key:
|
175
|
-
|
175
|
+
try:
|
176
|
+
apikey.write_key(self._settings, key)
|
177
|
+
except apikey.WriteNetrcError as e:
|
178
|
+
wandb.termwarn(str(e))
|
176
179
|
|
177
180
|
def update_session(
|
178
181
|
self,
|
@@ -305,7 +308,7 @@ def _login(
|
|
305
308
|
wlogin._verify_login(key)
|
306
309
|
|
307
310
|
if not key_is_pre_configured:
|
308
|
-
wlogin.
|
311
|
+
wlogin.try_save_api_key(key)
|
309
312
|
wlogin.update_session(key, status=key_status)
|
310
313
|
wlogin._update_global_anonymous_setting()
|
311
314
|
|