wandb 0.19.6rc4__py3-none-win_amd64.whl → 0.19.8__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +56 -6
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +11 -20
- wandb/data_types.py +1 -1
- wandb/env.py +10 -0
- wandb/filesync/dir_watcher.py +2 -1
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +51 -95
- wandb/sdk/backend/backend.py +17 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +87 -49
- wandb/sdk/interface/interface_queue.py +5 -15
- wandb/sdk/interface/interface_relay.py +7 -22
- wandb/sdk/interface/interface_shared.py +65 -136
- wandb/sdk/interface/interface_sock.py +3 -21
- wandb/sdk/interface/router.py +42 -68
- wandb/sdk/interface/router_queue.py +13 -11
- wandb/sdk/interface/router_relay.py +26 -13
- wandb/sdk/interface/router_sock.py +12 -16
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +3 -19
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/console_capture.py +172 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/redirect.py +102 -76
- wandb/sdk/lib/service_connection.py +37 -17
- wandb/sdk/lib/sock_client.py +6 -56
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/mailbox.py +135 -0
- wandb/sdk/mailbox/mailbox_handle.py +127 -0
- wandb/sdk/mailbox/response_handle.py +167 -0
- wandb/sdk/mailbox/wait_with_progress.py +135 -0
- wandb/sdk/service/server_sock.py +9 -3
- wandb/sdk/service/streams.py +75 -78
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +72 -75
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +90 -97
- wandb/sdk/wandb_settings.py +10 -4
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -10
- wandb/util.py +3 -1
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/METADATA +2 -2
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/RECORD +79 -66
- wandb/sdk/interface/message_future.py +0 -27
- wandb/sdk/interface/message_future_poll.py +0 -50
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/WHEEL +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/service/streams.py
CHANGED
@@ -8,34 +8,27 @@ StreamMux: Container for dictionary of stream threads per runid
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
+
import asyncio
|
11
12
|
import functools
|
12
|
-
import multiprocessing
|
13
13
|
import queue
|
14
14
|
import threading
|
15
15
|
import time
|
16
16
|
from threading import Event
|
17
|
-
from typing import Any, Callable
|
17
|
+
from typing import Any, Callable, NoReturn
|
18
18
|
|
19
19
|
import psutil
|
20
20
|
|
21
21
|
import wandb
|
22
22
|
import wandb.util
|
23
23
|
from wandb.proto import wandb_internal_pb2 as pb
|
24
|
+
from wandb.sdk.interface.interface_relay import InterfaceRelay
|
25
|
+
from wandb.sdk.interface.router_relay import MessageRelayRouter
|
24
26
|
from wandb.sdk.internal.settings_static import SettingsStatic
|
27
|
+
from wandb.sdk.lib import asyncio_compat, progress
|
25
28
|
from wandb.sdk.lib import printer as printerlib
|
26
|
-
from wandb.sdk.
|
27
|
-
from wandb.sdk.lib.mailbox import (
|
28
|
-
Mailbox,
|
29
|
-
MailboxProbe,
|
30
|
-
MailboxProgress,
|
31
|
-
MailboxProgressAll,
|
32
|
-
)
|
29
|
+
from wandb.sdk.mailbox import Mailbox, MailboxHandle, wait_all_with_progress
|
33
30
|
from wandb.sdk.wandb_run import Run
|
34
31
|
|
35
|
-
from ..interface.interface_relay import InterfaceRelay
|
36
|
-
|
37
|
-
# from wandb.sdk.wandb_settings import Settings
|
38
|
-
|
39
32
|
|
40
33
|
class StreamThread(threading.Thread):
|
41
34
|
"""Class to running internal process as a thread."""
|
@@ -61,19 +54,22 @@ class StreamRecord:
|
|
61
54
|
_settings: SettingsStatic
|
62
55
|
_started: bool
|
63
56
|
|
64
|
-
def __init__(self, settings: SettingsStatic
|
57
|
+
def __init__(self, settings: SettingsStatic) -> None:
|
65
58
|
self._started = False
|
66
|
-
self._mailbox =
|
59
|
+
self._mailbox = Mailbox()
|
67
60
|
self._record_q = queue.Queue()
|
68
61
|
self._result_q = queue.Queue()
|
69
62
|
self._relay_q = queue.Queue()
|
70
|
-
|
63
|
+
self._router = MessageRelayRouter(
|
64
|
+
request_queue=self._record_q,
|
65
|
+
response_queue=self._result_q,
|
66
|
+
relay_queue=self._relay_q,
|
67
|
+
mailbox=self._mailbox,
|
68
|
+
)
|
71
69
|
self._iface = InterfaceRelay(
|
72
70
|
record_q=self._record_q,
|
73
71
|
result_q=self._result_q,
|
74
72
|
relay_q=self._relay_q,
|
75
|
-
process=process,
|
76
|
-
process_check=False,
|
77
73
|
mailbox=self._mailbox,
|
78
74
|
)
|
79
75
|
self._settings = settings
|
@@ -84,10 +80,11 @@ class StreamRecord:
|
|
84
80
|
self._wait_thread_active()
|
85
81
|
|
86
82
|
def _wait_thread_active(self) -> None:
|
87
|
-
self._iface.deliver_status().
|
83
|
+
self._iface.deliver_status().wait_or(timeout=None)
|
88
84
|
|
89
85
|
def join(self) -> None:
|
90
86
|
self._iface.join()
|
87
|
+
self._router.join()
|
91
88
|
if self._thread:
|
92
89
|
self._thread.join()
|
93
90
|
|
@@ -141,7 +138,6 @@ class StreamMux:
|
|
141
138
|
_action_q: queue.Queue[StreamAction]
|
142
139
|
_stopped: Event
|
143
140
|
_pid_checked_ts: float | None
|
144
|
-
_mailbox: Mailbox
|
145
141
|
|
146
142
|
def __init__(self) -> None:
|
147
143
|
self._streams_lock = threading.Lock()
|
@@ -151,8 +147,6 @@ class StreamMux:
|
|
151
147
|
self._stopped = Event()
|
152
148
|
self._action_q = queue.Queue()
|
153
149
|
self._pid_checked_ts = None
|
154
|
-
self._mailbox = Mailbox()
|
155
|
-
self._mailbox.enable_keepalive()
|
156
150
|
|
157
151
|
def _get_stopped_event(self) -> Event:
|
158
152
|
# TODO: clean this up, there should be a better way to abstract this
|
@@ -209,7 +203,7 @@ class StreamMux:
|
|
209
203
|
return stream
|
210
204
|
|
211
205
|
def _process_add(self, action: StreamAction) -> None:
|
212
|
-
stream = StreamRecord(action._data
|
206
|
+
stream = StreamRecord(action._data)
|
213
207
|
# run_id = action.stream_id # will want to fix if a streamid != runid
|
214
208
|
settings = action._data
|
215
209
|
thread = StreamThread(
|
@@ -247,41 +241,51 @@ class StreamMux:
|
|
247
241
|
stream.drop()
|
248
242
|
stream.join()
|
249
243
|
|
250
|
-
def
|
251
|
-
handle = probe_handle.get_mailbox_handle()
|
252
|
-
if handle:
|
253
|
-
result = handle.wait(timeout=0, release=False)
|
254
|
-
if not result:
|
255
|
-
return
|
256
|
-
probe_handle.set_probe_result(result)
|
257
|
-
handle = stream.interface.deliver_poll_exit()
|
258
|
-
probe_handle.set_mailbox_handle(handle)
|
259
|
-
|
260
|
-
def _on_progress_exit(self, progress_handle: MailboxProgress) -> None:
|
261
|
-
pass
|
262
|
-
|
263
|
-
def _on_progress_exit_all(
|
244
|
+
async def _finish_all_progress(
|
264
245
|
self,
|
265
246
|
progress_printer: progress.ProgressPrinter,
|
266
|
-
|
247
|
+
streams_to_watch: dict[str, StreamRecord],
|
267
248
|
) -> None:
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
249
|
+
"""Poll the streams and display statistics about them.
|
250
|
+
|
251
|
+
This never returns and must be cancelled.
|
252
|
+
|
253
|
+
Args:
|
254
|
+
progress_printer: Printer to use for displaying finish progress.
|
255
|
+
streams_to_watch: Streams to poll for finish progress.
|
256
|
+
"""
|
257
|
+
results: dict[str, pb.Result | None] = {}
|
258
|
+
|
259
|
+
async def loop_poll_stream(
|
260
|
+
stream_id: str,
|
261
|
+
stream: StreamRecord,
|
262
|
+
) -> NoReturn:
|
263
|
+
while True:
|
264
|
+
start_time = time.monotonic()
|
265
|
+
|
266
|
+
handle = stream.interface.deliver_poll_exit()
|
267
|
+
results[stream_id] = await handle.wait_async(timeout=None)
|
268
|
+
|
269
|
+
elapsed_time = time.monotonic() - start_time
|
270
|
+
if elapsed_time < 1:
|
271
|
+
await asyncio.sleep(1 - elapsed_time)
|
272
|
+
|
273
|
+
async def loop_update_printer() -> NoReturn:
|
274
|
+
while True:
|
275
|
+
poll_exit_responses: list[pb.PollExitResponse] = []
|
276
|
+
for result in results.values():
|
277
|
+
if not result or not result.response:
|
278
|
+
continue
|
279
|
+
if poll_exit_response := result.response.poll_exit_response:
|
280
|
+
poll_exit_responses.append(poll_exit_response)
|
281
|
+
|
282
|
+
progress_printer.update(poll_exit_responses)
|
283
|
+
await asyncio.sleep(1)
|
284
|
+
|
285
|
+
async with asyncio_compat.open_task_group() as task_group:
|
286
|
+
for stream_id, stream in streams_to_watch.items():
|
287
|
+
task_group.start_soon(loop_poll_stream(stream_id, stream))
|
288
|
+
task_group.start_soon(loop_update_printer())
|
285
289
|
|
286
290
|
def _finish_all(self, streams: dict[str, StreamRecord], exit_code: int) -> None:
|
287
291
|
if not streams:
|
@@ -291,7 +295,7 @@ class StreamMux:
|
|
291
295
|
|
292
296
|
# fixme: for now we have a single printer for all streams,
|
293
297
|
# and jupyter is disabled if at least single stream's setting set `_jupyter` to false
|
294
|
-
exit_handles = []
|
298
|
+
exit_handles: list[MailboxHandle[pb.Result]] = []
|
295
299
|
|
296
300
|
# only finish started streams, non started streams failed early
|
297
301
|
started_streams: dict[str, StreamRecord] = {}
|
@@ -302,27 +306,24 @@ class StreamMux:
|
|
302
306
|
|
303
307
|
for stream in started_streams.values():
|
304
308
|
handle = stream.interface.deliver_exit(exit_code)
|
305
|
-
handle.add_progress(self._on_progress_exit)
|
306
|
-
handle.add_probe(functools.partial(self._on_probe_exit, stream=stream))
|
307
309
|
exit_handles.append(handle)
|
308
310
|
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
with progress.progress_printer(printer) as progress_printer:
|
311
|
+
with progress.progress_printer(
|
312
|
+
printer,
|
313
|
+
default_text="Finishing up...",
|
314
|
+
) as progress_printer:
|
315
315
|
# todo: should we wait for the max timeout (?) of all exit handles or just wait forever?
|
316
316
|
# timeout = max(stream._settings._exit_timeout for stream in streams.values())
|
317
|
-
|
318
|
-
|
319
|
-
timeout
|
320
|
-
|
321
|
-
|
317
|
+
wait_all_with_progress(
|
318
|
+
exit_handles,
|
319
|
+
timeout=None,
|
320
|
+
progress_after=1,
|
321
|
+
display_progress=functools.partial(
|
322
|
+
self._finish_all_progress,
|
322
323
|
progress_printer,
|
324
|
+
started_streams,
|
323
325
|
),
|
324
326
|
)
|
325
|
-
assert got_result
|
326
327
|
|
327
328
|
# These could be done in parallel in the future
|
328
329
|
for _sid, stream in started_streams.items():
|
@@ -332,20 +333,16 @@ class StreamMux:
|
|
332
333
|
sampled_history_handle = stream.interface.deliver_request_sampled_history()
|
333
334
|
internal_messages_handle = stream.interface.deliver_internal_messages()
|
334
335
|
|
335
|
-
result = internal_messages_handle.
|
336
|
-
assert result
|
336
|
+
result = internal_messages_handle.wait_or(timeout=None)
|
337
337
|
internal_messages_response = result.response.internal_messages_response
|
338
338
|
|
339
|
-
result = poll_exit_handle.
|
340
|
-
assert result
|
339
|
+
result = poll_exit_handle.wait_or(timeout=None)
|
341
340
|
poll_exit_response = result.response.poll_exit_response
|
342
341
|
|
343
|
-
result = sampled_history_handle.
|
344
|
-
assert result
|
342
|
+
result = sampled_history_handle.wait_or(timeout=None)
|
345
343
|
sampled_history = result.response.sampled_history_response
|
346
344
|
|
347
|
-
result = final_summary_handle.
|
348
|
-
assert result
|
345
|
+
result = final_summary_handle.wait_or(timeout=None)
|
349
346
|
final_summary = result.response.get_summary_response
|
350
347
|
|
351
348
|
Run._footer(
|
wandb/sdk/verify/verify.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
"""Utilities for wandb verify."""
|
2
2
|
|
3
|
+
import contextlib
|
3
4
|
import getpass
|
5
|
+
import io
|
4
6
|
import os
|
5
7
|
import time
|
6
8
|
from functools import partial
|
@@ -163,8 +165,8 @@ def check_run(api: Api) -> bool:
|
|
163
165
|
)
|
164
166
|
print_results(failed_test_strings, False)
|
165
167
|
return False
|
166
|
-
for key, value in
|
167
|
-
if config
|
168
|
+
for key, value in config.items():
|
169
|
+
if prev_run.config.get(key) != value:
|
168
170
|
failed_test_strings.append(
|
169
171
|
"Read config values don't match run config. Contact W&B for support."
|
170
172
|
)
|
@@ -486,6 +488,56 @@ def check_wandb_version(api: Api) -> None:
|
|
486
488
|
print_results(fail_string, warning)
|
487
489
|
|
488
490
|
|
491
|
+
def check_sweeps(api: Api) -> bool:
|
492
|
+
print("Checking sweep creation and agent execution".ljust(72, "."), end="") # noqa: T201
|
493
|
+
failed_test_strings: List[str] = []
|
494
|
+
|
495
|
+
sweep_config = {
|
496
|
+
"method": "random",
|
497
|
+
"metric": {"goal": "minimize", "name": "score"},
|
498
|
+
"parameters": {
|
499
|
+
"x": {"values": [0.01, 0.05, 0.1]},
|
500
|
+
"y": {"values": [1, 2, 3]},
|
501
|
+
},
|
502
|
+
"name": "verify_sweep",
|
503
|
+
}
|
504
|
+
|
505
|
+
try:
|
506
|
+
with contextlib.redirect_stdout(io.StringIO()):
|
507
|
+
sweep_id = wandb.sweep(
|
508
|
+
sweep=sweep_config, project=PROJECT_NAME, entity=api.default_entity
|
509
|
+
)
|
510
|
+
except Exception as e:
|
511
|
+
failed_test_strings.append(f"Failed to create sweep: {e}")
|
512
|
+
print_results(failed_test_strings, False)
|
513
|
+
return False
|
514
|
+
|
515
|
+
if not sweep_id:
|
516
|
+
failed_test_strings.append("Sweep creation returned an invalid ID.")
|
517
|
+
print_results(failed_test_strings, False)
|
518
|
+
return False
|
519
|
+
|
520
|
+
try:
|
521
|
+
|
522
|
+
def objective(config):
|
523
|
+
score = config.x**3 + config.y
|
524
|
+
return score
|
525
|
+
|
526
|
+
def main():
|
527
|
+
with wandb.init(project=PROJECT_NAME) as run:
|
528
|
+
score = objective(run.config)
|
529
|
+
run.log({"score": score})
|
530
|
+
|
531
|
+
wandb.agent(sweep_id, function=main, count=10)
|
532
|
+
except Exception as e:
|
533
|
+
failed_test_strings.append(f"Failed to run sweep agent: {e}")
|
534
|
+
print_results(failed_test_strings, False)
|
535
|
+
return False
|
536
|
+
|
537
|
+
print_results(failed_test_strings, False)
|
538
|
+
return len(failed_test_strings) == 0
|
539
|
+
|
540
|
+
|
489
541
|
def retry_fn(fn: Callable) -> Any:
|
490
542
|
ini_time = time.time()
|
491
543
|
res = None
|
wandb/sdk/wandb_init.py
CHANGED
@@ -34,7 +34,7 @@ from wandb.errors import CommError, Error, UsageError
|
|
34
34
|
from wandb.errors.links import url_registry
|
35
35
|
from wandb.errors.util import ProtobufErrorHandler
|
36
36
|
from wandb.integration import sagemaker
|
37
|
-
from wandb.sdk.lib import runid
|
37
|
+
from wandb.sdk.lib import progress, runid
|
38
38
|
from wandb.sdk.lib.paths import StrPath
|
39
39
|
from wandb.util import _is_artifact_representation
|
40
40
|
|
@@ -42,7 +42,7 @@ from . import wandb_login, wandb_setup
|
|
42
42
|
from .backend.backend import Backend
|
43
43
|
from .lib import SummaryDisabled, filesystem, module, paths, printer, telemetry
|
44
44
|
from .lib.deprecate import Deprecated, deprecate
|
45
|
-
from .
|
45
|
+
from .mailbox import wait_with_progress
|
46
46
|
from .wandb_helper import parse_config
|
47
47
|
from .wandb_run import Run, TeardownHook, TeardownStage
|
48
48
|
from .wandb_settings import Settings
|
@@ -293,6 +293,16 @@ class _WandbInit:
|
|
293
293
|
|
294
294
|
settings.x_start_time = time.time()
|
295
295
|
|
296
|
+
# In shared mode, generate a unique label if not provided.
|
297
|
+
# The label is used to distinguish between system metrics and console logs
|
298
|
+
# from different writers to the same run.
|
299
|
+
if settings._shared and not settings.x_label:
|
300
|
+
# TODO: If executed in a known distributed environment (e.g. Ray or SLURM),
|
301
|
+
# use the env vars to generate a label (e.g. SLURM_JOB_ID or RANK)
|
302
|
+
prefix = settings.host or ""
|
303
|
+
label = runid.generate_id()
|
304
|
+
settings.x_label = f"{prefix}-{label}" if prefix else label
|
305
|
+
|
296
306
|
return settings
|
297
307
|
|
298
308
|
def _load_autoresume_run_id(self, resume_file: pathlib.Path) -> str | None:
|
@@ -672,11 +682,11 @@ class _WandbInit:
|
|
672
682
|
drun._Run__metadata = wandb.sdk.wandb_metadata.Metadata()
|
673
683
|
|
674
684
|
# methods
|
675
|
-
drun.log = lambda data, *_, **__: drun.summary.update(data) # type: ignore
|
676
|
-
drun.finish = lambda *_, **__: module.unset_globals() # type: ignore
|
677
|
-
drun.join = drun.finish # type: ignore
|
678
|
-
drun.define_metric = lambda *_, **__: wandb.sdk.wandb_metric.Metric("dummy") # type: ignore
|
679
|
-
drun.save = lambda *_, **__: False # type: ignore
|
685
|
+
drun.log = lambda data, *_, **__: drun.summary.update(data) # type: ignore[method-assign]
|
686
|
+
drun.finish = lambda *_, **__: module.unset_globals() # type: ignore[method-assign]
|
687
|
+
drun.join = drun.finish # type: ignore[method-assign]
|
688
|
+
drun.define_metric = lambda *_, **__: wandb.sdk.wandb_metric.Metric("dummy") # type: ignore[method-assign]
|
689
|
+
drun.save = lambda *_, **__: False # type: ignore[method-assign]
|
680
690
|
for symbol in (
|
681
691
|
"alert",
|
682
692
|
"finish_artifact",
|
@@ -723,7 +733,7 @@ class _WandbInit:
|
|
723
733
|
def __call__(self, *args: Any, **kwargs: Any) -> _ChainableNoOp:
|
724
734
|
return _ChainableNoOp()
|
725
735
|
|
726
|
-
drun.log_artifact = _ChainableNoOpField()
|
736
|
+
drun.log_artifact = _ChainableNoOpField() # type: ignore[method-assign]
|
727
737
|
# attributes
|
728
738
|
drun._start_time = time.time()
|
729
739
|
drun._starting_step = 0
|
@@ -747,11 +757,6 @@ class _WandbInit:
|
|
747
757
|
)
|
748
758
|
return drun
|
749
759
|
|
750
|
-
def _on_progress_init(self, handle: MailboxProgress) -> None:
|
751
|
-
line = "Waiting for wandb.init()...\r"
|
752
|
-
percent_done = handle.percent_done
|
753
|
-
self.printer.progress_update(line, percent_done=percent_done)
|
754
|
-
|
755
760
|
def init(self, settings: Settings, config: _ConfigParts) -> Run: # noqa: C901
|
756
761
|
self._logger.info("calling init triggers")
|
757
762
|
trigger.call("on_init")
|
@@ -763,28 +768,18 @@ class _WandbInit:
|
|
763
768
|
f"\nconfig: {config.base_no_artifacts}"
|
764
769
|
)
|
765
770
|
|
766
|
-
if (
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
" module is not well-supported."
|
773
|
-
" Please use multiprocessing instead."
|
774
|
-
" Finishing previous run before initializing another."
|
775
|
-
)
|
776
|
-
|
777
|
-
latest_run = self._wl._global_run_stack[-1]
|
778
|
-
self._logger.info(f"found existing run on stack: {latest_run.id}")
|
779
|
-
latest_run.finish()
|
780
|
-
elif wandb.run is not None and os.getpid() == wandb.run._init_pid:
|
781
|
-
self._logger.info("wandb.init() called when a run is still active")
|
771
|
+
if wandb.run is not None and os.getpid() == wandb.run._init_pid:
|
772
|
+
if settings.reinit:
|
773
|
+
self._logger.info(f"finishing previous run: {wandb.run.id}")
|
774
|
+
wandb.run.finish()
|
775
|
+
else:
|
776
|
+
self._logger.info("wandb.init() called while a run is active")
|
782
777
|
|
783
|
-
|
784
|
-
|
785
|
-
|
778
|
+
# NOTE: Updates telemetry on the pre-existing run.
|
779
|
+
with telemetry.context() as tel:
|
780
|
+
tel.feature.init_return_run = True
|
786
781
|
|
787
|
-
|
782
|
+
return wandb.run
|
788
783
|
|
789
784
|
self._logger.info("starting backend")
|
790
785
|
|
@@ -798,12 +793,7 @@ class _WandbInit:
|
|
798
793
|
else:
|
799
794
|
service = None
|
800
795
|
|
801
|
-
|
802
|
-
backend = Backend(
|
803
|
-
settings=settings,
|
804
|
-
service=service,
|
805
|
-
mailbox=mailbox,
|
806
|
-
)
|
796
|
+
backend = Backend(settings=settings, service=service)
|
807
797
|
backend.ensure_launched()
|
808
798
|
self._logger.info("backend started and connected")
|
809
799
|
|
@@ -905,7 +895,6 @@ class _WandbInit:
|
|
905
895
|
run._set_teardown_hooks(self._teardown_hooks)
|
906
896
|
|
907
897
|
assert backend.interface
|
908
|
-
mailbox.enable_keepalive()
|
909
898
|
backend.interface.publish_header()
|
910
899
|
|
911
900
|
# Using GitRepo() blocks & can be slow, depending on user's current git setup.
|
@@ -921,16 +910,6 @@ class _WandbInit:
|
|
921
910
|
)
|
922
911
|
error: wandb.Error | None = None
|
923
912
|
|
924
|
-
# In shared mode, generate a unique label if not provided.
|
925
|
-
# The label is used to distinguish between system metrics and console logs
|
926
|
-
# from different writers to the same run.
|
927
|
-
if settings._shared and not settings.x_label:
|
928
|
-
# TODO: If executed in a known distributed environment (e.g. Ray or SLURM),
|
929
|
-
# use the env vars to generate a label (e.g. SLURM_JOB_ID or RANK)
|
930
|
-
prefix = settings.host or ""
|
931
|
-
label = runid.generate_id()
|
932
|
-
settings.x_label = f"{prefix}-{label}" if prefix else label
|
933
|
-
|
934
913
|
timeout = settings.init_timeout
|
935
914
|
|
936
915
|
self._logger.info(
|
@@ -938,11 +917,18 @@ class _WandbInit:
|
|
938
917
|
)
|
939
918
|
|
940
919
|
run_init_handle = backend.interface.deliver_run(run)
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
920
|
+
|
921
|
+
async def display_init_message() -> None:
|
922
|
+
assert backend.interface
|
923
|
+
|
924
|
+
with progress.progress_printer(
|
925
|
+
self.printer,
|
926
|
+
default_text="Waiting for wandb.init()...",
|
927
|
+
) as progress_printer:
|
928
|
+
await progress.loop_printing_operation_stats(
|
929
|
+
progress_printer,
|
930
|
+
backend.interface,
|
931
|
+
)
|
946
932
|
|
947
933
|
# Raise an error if deliver_run failed.
|
948
934
|
#
|
@@ -951,8 +937,16 @@ class _WandbInit:
|
|
951
937
|
#
|
952
938
|
# TODO: Remove try-except once x_disable_service is removed.
|
953
939
|
try:
|
954
|
-
|
955
|
-
|
940
|
+
try:
|
941
|
+
result = wait_with_progress(
|
942
|
+
run_init_handle,
|
943
|
+
timeout=timeout,
|
944
|
+
progress_after=1,
|
945
|
+
display_progress=display_init_message,
|
946
|
+
)
|
947
|
+
|
948
|
+
except TimeoutError:
|
949
|
+
run_init_handle.cancel(backend.interface)
|
956
950
|
|
957
951
|
# This may either be an issue with the W&B server (a CommError)
|
958
952
|
# or a bug in the SDK (an Error). We cannot distinguish between
|
@@ -963,6 +957,8 @@ class _WandbInit:
|
|
963
957
|
" setting: `wandb.init(settings=wandb.Settings(init_timeout=120))`."
|
964
958
|
)
|
965
959
|
|
960
|
+
assert result.run_result
|
961
|
+
|
966
962
|
if error := ProtobufErrorHandler.to_exception(result.run_result.error):
|
967
963
|
raise error
|
968
964
|
|
@@ -1004,13 +1000,13 @@ class _WandbInit:
|
|
1004
1000
|
assert backend.interface
|
1005
1001
|
|
1006
1002
|
run_start_handle = backend.interface.deliver_run_start(run)
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
1003
|
+
try:
|
1004
|
+
# TODO: add progress to let user know we are doing something
|
1005
|
+
run_start_handle.wait_or(timeout=30)
|
1006
|
+
except TimeoutError:
|
1007
|
+
pass
|
1011
1008
|
|
1012
1009
|
assert self._wl is not None
|
1013
|
-
self._wl._global_run_stack.append(run)
|
1014
1010
|
self.run = run
|
1015
1011
|
|
1016
1012
|
run._handle_launch_artifact_overrides()
|
@@ -1084,8 +1080,7 @@ def _attach(
|
|
1084
1080
|
)
|
1085
1081
|
|
1086
1082
|
# TODO: consolidate this codepath with wandb.init()
|
1087
|
-
|
1088
|
-
backend = Backend(settings=settings, service=service, mailbox=mailbox)
|
1083
|
+
backend = Backend(settings=settings, service=service)
|
1089
1084
|
backend.ensure_launched()
|
1090
1085
|
logger.info("attach backend started and connected")
|
1091
1086
|
|
@@ -1097,14 +1092,13 @@ def _attach(
|
|
1097
1092
|
run._set_backend(backend)
|
1098
1093
|
assert backend.interface
|
1099
1094
|
|
1100
|
-
mailbox.enable_keepalive()
|
1101
|
-
|
1102
1095
|
attach_handle = backend.interface.deliver_attach(attach_id)
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1096
|
+
try:
|
1097
|
+
# TODO: add progress to let user know we are doing something
|
1098
|
+
attach_result = attach_handle.wait_or(timeout=30)
|
1099
|
+
except TimeoutError:
|
1107
1100
|
raise UsageError("Timeout attaching to run")
|
1101
|
+
|
1108
1102
|
attach_response = attach_result.response.attach_response
|
1109
1103
|
if attach_response.error and attach_response.error.message:
|
1110
1104
|
raise UsageError(f"Failed to attach to run: {attach_response.error.message}")
|
@@ -1222,10 +1216,10 @@ def init( # noqa: C901
|
|
1222
1216
|
on the system, such as checking the git root or the current program
|
1223
1217
|
file. If we can't infer the project name, the project will default to
|
1224
1218
|
`"uncategorized"`.
|
1225
|
-
dir:
|
1226
|
-
files
|
1227
|
-
|
1228
|
-
|
1219
|
+
dir: The absolute path to the directory where experiment logs and
|
1220
|
+
metadata files are stored. If not specified, this defaults
|
1221
|
+
to the `./wandb` directory. Note that this does not affect the
|
1222
|
+
location where artifacts are stored when calling `download()`.
|
1229
1223
|
id: A unique identifier for this run, used for resuming. It must be unique
|
1230
1224
|
within the project and cannot be reused once a run is deleted. The
|
1231
1225
|
identifier must not contain any of the following special characters:
|
@@ -1426,7 +1420,7 @@ def init( # noqa: C901
|
|
1426
1420
|
wl: wandb_setup._WandbSetup | None = None
|
1427
1421
|
|
1428
1422
|
try:
|
1429
|
-
wl =
|
1423
|
+
wl = wandb_setup._setup(start_service=False)
|
1430
1424
|
|
1431
1425
|
wi = _WandbInit(wl, init_telemetry)
|
1432
1426
|
|
@@ -1468,6 +1462,9 @@ def init( # noqa: C901
|
|
1468
1462
|
_monkeypatch_tensorboard()
|
1469
1463
|
init_telemetry.feature.tensorboard_sync = True
|
1470
1464
|
|
1465
|
+
if run_settings.x_server_side_derived_summary:
|
1466
|
+
init_telemetry.feature.server_side_derived_summary = True
|
1467
|
+
|
1471
1468
|
return wi.init(run_settings, run_config)
|
1472
1469
|
|
1473
1470
|
except KeyboardInterrupt as e:
|
wandb/sdk/wandb_login.py
CHANGED
@@ -162,8 +162,8 @@ class _WandbLogin:
|
|
162
162
|
repeat=False,
|
163
163
|
)
|
164
164
|
|
165
|
-
def
|
166
|
-
"""Saves the API key
|
165
|
+
def try_save_api_key(self, key: str) -> None:
|
166
|
+
"""Saves the API key to disk for future use."""
|
167
167
|
if self._settings._notebook and not self._settings.silent:
|
168
168
|
wandb.termwarn(
|
169
169
|
"If you're specifying your api key in code, ensure this "
|
@@ -172,7 +172,10 @@ class _WandbLogin:
|
|
172
172
|
"`wandb login` from the command line."
|
173
173
|
)
|
174
174
|
if key:
|
175
|
-
|
175
|
+
try:
|
176
|
+
apikey.write_key(self._settings, key)
|
177
|
+
except apikey.WriteNetrcError as e:
|
178
|
+
wandb.termwarn(str(e))
|
176
179
|
|
177
180
|
def update_session(
|
178
181
|
self,
|
@@ -305,7 +308,7 @@ def _login(
|
|
305
308
|
wlogin._verify_login(key)
|
306
309
|
|
307
310
|
if not key_is_pre_configured:
|
308
|
-
wlogin.
|
311
|
+
wlogin.try_save_api_key(key)
|
309
312
|
wlogin.update_session(key, status=key_status)
|
310
313
|
wlogin._update_global_anonymous_setting()
|
311
314
|
|