wandb 0.19.6__py3-none-macosx_11_0_arm64.whl → 0.19.7__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +25 -5
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +11 -20
- wandb/env.py +10 -0
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +1 -1
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +1 -1
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +1 -1
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +43 -88
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +26 -10
- wandb/sdk/interface/interface_queue.py +5 -8
- wandb/sdk/interface/interface_relay.py +1 -6
- wandb/sdk/interface/interface_shared.py +21 -99
- wandb/sdk/interface/interface_sock.py +2 -13
- wandb/sdk/interface/router.py +21 -15
- wandb/sdk/interface/router_queue.py +2 -1
- wandb/sdk/interface/router_relay.py +2 -1
- wandb/sdk/interface/router_sock.py +5 -4
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +0 -18
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/service_connection.py +1 -1
- wandb/sdk/lib/sock_client.py +7 -7
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/handles.py +199 -0
- wandb/sdk/mailbox/mailbox.py +121 -0
- wandb/sdk/mailbox/wait_with_progress.py +134 -0
- wandb/sdk/service/server_sock.py +5 -1
- wandb/sdk/service/streams.py +66 -74
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +61 -61
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +82 -87
- wandb/sdk/wandb_settings.py +3 -3
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -4
- wandb/util.py +3 -1
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/METADATA +2 -2
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/RECORD +71 -58
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/WHEEL +0 -0
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/licenses/LICENSE +0 -0
    
        wandb/sdk/service/server_sock.py
    CHANGED
    
    | @@ -70,6 +70,7 @@ class SockServerInterfaceReaderThread(threading.Thread): | |
| 70 70 | 
             
                        sock_client = self._clients.get_client(sockid)
         | 
| 71 71 | 
             
                        assert sock_client
         | 
| 72 72 | 
             
                        sresp = spb.ServerResponse()
         | 
| 73 | 
            +
                        sresp.request_id = result.control.mailbox_slot
         | 
| 73 74 | 
             
                        sresp.result_communicate.CopyFrom(result)
         | 
| 74 75 | 
             
                        sock_client.send_server_response(sresp)
         | 
| 75 76 |  | 
| @@ -148,7 +149,10 @@ class SockServerReadThread(threading.Thread): | |
| 148 149 | 
             
                    inform_attach_response.settings.CopyFrom(
         | 
| 149 150 | 
             
                        self._mux._streams[stream_id]._settings._proto,
         | 
| 150 151 | 
             
                    )
         | 
| 151 | 
            -
                    response = spb.ServerResponse( | 
| 152 | 
            +
                    response = spb.ServerResponse(
         | 
| 153 | 
            +
                        request_id=sreq.request_id,
         | 
| 154 | 
            +
                        inform_attach_response=inform_attach_response,
         | 
| 155 | 
            +
                    )
         | 
| 152 156 | 
             
                    self._sock_client.send_server_response(response)
         | 
| 153 157 | 
             
                    iface = self._mux.get_stream(stream_id).interface
         | 
| 154 158 |  | 
    
        wandb/sdk/service/streams.py
    CHANGED
    
    | @@ -8,13 +8,13 @@ StreamMux: Container for dictionary of stream threads per runid | |
| 8 8 |  | 
| 9 9 | 
             
            from __future__ import annotations
         | 
| 10 10 |  | 
| 11 | 
            +
            import asyncio
         | 
| 11 12 | 
             
            import functools
         | 
| 12 | 
            -
            import multiprocessing
         | 
| 13 13 | 
             
            import queue
         | 
| 14 14 | 
             
            import threading
         | 
| 15 15 | 
             
            import time
         | 
| 16 16 | 
             
            from threading import Event
         | 
| 17 | 
            -
            from typing import Any, Callable
         | 
| 17 | 
            +
            from typing import Any, Callable, NoReturn
         | 
| 18 18 |  | 
| 19 19 | 
             
            import psutil
         | 
| 20 20 |  | 
| @@ -22,14 +22,9 @@ import wandb | |
| 22 22 | 
             
            import wandb.util
         | 
| 23 23 | 
             
            from wandb.proto import wandb_internal_pb2 as pb
         | 
| 24 24 | 
             
            from wandb.sdk.internal.settings_static import SettingsStatic
         | 
| 25 | 
            +
            from wandb.sdk.lib import asyncio_compat, progress
         | 
| 25 26 | 
             
            from wandb.sdk.lib import printer as printerlib
         | 
| 26 | 
            -
            from wandb.sdk. | 
| 27 | 
            -
            from wandb.sdk.lib.mailbox import (
         | 
| 28 | 
            -
                Mailbox,
         | 
| 29 | 
            -
                MailboxProbe,
         | 
| 30 | 
            -
                MailboxProgress,
         | 
| 31 | 
            -
                MailboxProgressAll,
         | 
| 32 | 
            -
            )
         | 
| 27 | 
            +
            from wandb.sdk.mailbox import Mailbox, MailboxHandle, wait_all_with_progress
         | 
| 33 28 | 
             
            from wandb.sdk.wandb_run import Run
         | 
| 34 29 |  | 
| 35 30 | 
             
            from ..interface.interface_relay import InterfaceRelay
         | 
| @@ -61,19 +56,16 @@ class StreamRecord: | |
| 61 56 | 
             
                _settings: SettingsStatic
         | 
| 62 57 | 
             
                _started: bool
         | 
| 63 58 |  | 
| 64 | 
            -
                def __init__(self, settings: SettingsStatic | 
| 59 | 
            +
                def __init__(self, settings: SettingsStatic) -> None:
         | 
| 65 60 | 
             
                    self._started = False
         | 
| 66 | 
            -
                    self._mailbox =  | 
| 61 | 
            +
                    self._mailbox = Mailbox()
         | 
| 67 62 | 
             
                    self._record_q = queue.Queue()
         | 
| 68 63 | 
             
                    self._result_q = queue.Queue()
         | 
| 69 64 | 
             
                    self._relay_q = queue.Queue()
         | 
| 70 | 
            -
                    process = multiprocessing.current_process()
         | 
| 71 65 | 
             
                    self._iface = InterfaceRelay(
         | 
| 72 66 | 
             
                        record_q=self._record_q,
         | 
| 73 67 | 
             
                        result_q=self._result_q,
         | 
| 74 68 | 
             
                        relay_q=self._relay_q,
         | 
| 75 | 
            -
                        process=process,
         | 
| 76 | 
            -
                        process_check=False,
         | 
| 77 69 | 
             
                        mailbox=self._mailbox,
         | 
| 78 70 | 
             
                    )
         | 
| 79 71 | 
             
                    self._settings = settings
         | 
| @@ -84,7 +76,7 @@ class StreamRecord: | |
| 84 76 | 
             
                    self._wait_thread_active()
         | 
| 85 77 |  | 
| 86 78 | 
             
                def _wait_thread_active(self) -> None:
         | 
| 87 | 
            -
                    self._iface.deliver_status(). | 
| 79 | 
            +
                    self._iface.deliver_status().wait_or(timeout=None)
         | 
| 88 80 |  | 
| 89 81 | 
             
                def join(self) -> None:
         | 
| 90 82 | 
             
                    self._iface.join()
         | 
| @@ -141,7 +133,6 @@ class StreamMux: | |
| 141 133 | 
             
                _action_q: queue.Queue[StreamAction]
         | 
| 142 134 | 
             
                _stopped: Event
         | 
| 143 135 | 
             
                _pid_checked_ts: float | None
         | 
| 144 | 
            -
                _mailbox: Mailbox
         | 
| 145 136 |  | 
| 146 137 | 
             
                def __init__(self) -> None:
         | 
| 147 138 | 
             
                    self._streams_lock = threading.Lock()
         | 
| @@ -151,8 +142,6 @@ class StreamMux: | |
| 151 142 | 
             
                    self._stopped = Event()
         | 
| 152 143 | 
             
                    self._action_q = queue.Queue()
         | 
| 153 144 | 
             
                    self._pid_checked_ts = None
         | 
| 154 | 
            -
                    self._mailbox = Mailbox()
         | 
| 155 | 
            -
                    self._mailbox.enable_keepalive()
         | 
| 156 145 |  | 
| 157 146 | 
             
                def _get_stopped_event(self) -> Event:
         | 
| 158 147 | 
             
                    # TODO: clean this up, there should be a better way to abstract this
         | 
| @@ -209,7 +198,7 @@ class StreamMux: | |
| 209 198 | 
             
                        return stream
         | 
| 210 199 |  | 
| 211 200 | 
             
                def _process_add(self, action: StreamAction) -> None:
         | 
| 212 | 
            -
                    stream = StreamRecord(action._data | 
| 201 | 
            +
                    stream = StreamRecord(action._data)
         | 
| 213 202 | 
             
                    # run_id = action.stream_id  # will want to fix if a streamid != runid
         | 
| 214 203 | 
             
                    settings = action._data
         | 
| 215 204 | 
             
                    thread = StreamThread(
         | 
| @@ -247,41 +236,51 @@ class StreamMux: | |
| 247 236 | 
             
                            stream.drop()
         | 
| 248 237 | 
             
                            stream.join()
         | 
| 249 238 |  | 
| 250 | 
            -
                def  | 
| 251 | 
            -
                    handle = probe_handle.get_mailbox_handle()
         | 
| 252 | 
            -
                    if handle:
         | 
| 253 | 
            -
                        result = handle.wait(timeout=0, release=False)
         | 
| 254 | 
            -
                        if not result:
         | 
| 255 | 
            -
                            return
         | 
| 256 | 
            -
                        probe_handle.set_probe_result(result)
         | 
| 257 | 
            -
                    handle = stream.interface.deliver_poll_exit()
         | 
| 258 | 
            -
                    probe_handle.set_mailbox_handle(handle)
         | 
| 259 | 
            -
             | 
| 260 | 
            -
                def _on_progress_exit(self, progress_handle: MailboxProgress) -> None:
         | 
| 261 | 
            -
                    pass
         | 
| 262 | 
            -
             | 
| 263 | 
            -
                def _on_progress_exit_all(
         | 
| 239 | 
            +
                async def _finish_all_progress(
         | 
| 264 240 | 
             
                    self,
         | 
| 265 241 | 
             
                    progress_printer: progress.ProgressPrinter,
         | 
| 266 | 
            -
                     | 
| 242 | 
            +
                    streams_to_watch: dict[str, StreamRecord],
         | 
| 267 243 | 
             
                ) -> None:
         | 
| 268 | 
            -
                     | 
| 269 | 
            -
             | 
| 270 | 
            -
                     | 
| 271 | 
            -
             | 
| 272 | 
            -
             | 
| 273 | 
            -
             | 
| 274 | 
            -
             | 
| 275 | 
            -
                     | 
| 276 | 
            -
             | 
| 277 | 
            -
             | 
| 278 | 
            -
                     | 
| 279 | 
            -
             | 
| 280 | 
            -
                         | 
| 281 | 
            -
             | 
| 282 | 
            -
             | 
| 283 | 
            -
             | 
| 284 | 
            -
             | 
| 244 | 
            +
                    """Poll the streams and display statistics about them.
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    This never returns and must be cancelled.
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    Args:
         | 
| 249 | 
            +
                        progress_printer: Printer to use for displaying finish progress.
         | 
| 250 | 
            +
                        streams_to_watch: Streams to poll for finish progress.
         | 
| 251 | 
            +
                    """
         | 
| 252 | 
            +
                    results: dict[str, pb.Result | None] = {}
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                    async def loop_poll_stream(
         | 
| 255 | 
            +
                        stream_id: str,
         | 
| 256 | 
            +
                        stream: StreamRecord,
         | 
| 257 | 
            +
                    ) -> NoReturn:
         | 
| 258 | 
            +
                        while True:
         | 
| 259 | 
            +
                            start_time = time.monotonic()
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                            handle = stream.interface.deliver_poll_exit()
         | 
| 262 | 
            +
                            results[stream_id] = await handle.wait_async(timeout=None)
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                            elapsed_time = time.monotonic() - start_time
         | 
| 265 | 
            +
                            if elapsed_time < 1:
         | 
| 266 | 
            +
                                await asyncio.sleep(1 - elapsed_time)
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    async def loop_update_printer() -> NoReturn:
         | 
| 269 | 
            +
                        while True:
         | 
| 270 | 
            +
                            poll_exit_responses: list[pb.PollExitResponse] = []
         | 
| 271 | 
            +
                            for result in results.values():
         | 
| 272 | 
            +
                                if not result or not result.response:
         | 
| 273 | 
            +
                                    continue
         | 
| 274 | 
            +
                                if poll_exit_response := result.response.poll_exit_response:
         | 
| 275 | 
            +
                                    poll_exit_responses.append(poll_exit_response)
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                            progress_printer.update(poll_exit_responses)
         | 
| 278 | 
            +
                            await asyncio.sleep(1)
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                    async with asyncio_compat.open_task_group() as task_group:
         | 
| 281 | 
            +
                        for stream_id, stream in streams_to_watch.items():
         | 
| 282 | 
            +
                            task_group.start_soon(loop_poll_stream(stream_id, stream))
         | 
| 283 | 
            +
                        task_group.start_soon(loop_update_printer())
         | 
| 285 284 |  | 
| 286 285 | 
             
                def _finish_all(self, streams: dict[str, StreamRecord], exit_code: int) -> None:
         | 
| 287 286 | 
             
                    if not streams:
         | 
| @@ -291,7 +290,7 @@ class StreamMux: | |
| 291 290 |  | 
| 292 291 | 
             
                    # fixme: for now we have a single printer for all streams,
         | 
| 293 292 | 
             
                    # and jupyter is disabled if at least single stream's setting set `_jupyter` to false
         | 
| 294 | 
            -
                    exit_handles = []
         | 
| 293 | 
            +
                    exit_handles: list[MailboxHandle] = []
         | 
| 295 294 |  | 
| 296 295 | 
             
                    # only finish started streams, non started streams failed early
         | 
| 297 296 | 
             
                    started_streams: dict[str, StreamRecord] = {}
         | 
| @@ -302,27 +301,24 @@ class StreamMux: | |
| 302 301 |  | 
| 303 302 | 
             
                    for stream in started_streams.values():
         | 
| 304 303 | 
             
                        handle = stream.interface.deliver_exit(exit_code)
         | 
| 305 | 
            -
                        handle.add_progress(self._on_progress_exit)
         | 
| 306 | 
            -
                        handle.add_probe(functools.partial(self._on_probe_exit, stream=stream))
         | 
| 307 304 | 
             
                        exit_handles.append(handle)
         | 
| 308 305 |  | 
| 309 | 
            -
             | 
| 310 | 
            -
                         | 
| 311 | 
            -
                         | 
| 312 | 
            -
             | 
| 313 | 
            -
             | 
| 314 | 
            -
                    with progress.progress_printer(printer) as progress_printer:
         | 
| 306 | 
            +
                    with progress.progress_printer(
         | 
| 307 | 
            +
                        printer,
         | 
| 308 | 
            +
                        default_text="Finishing up...",
         | 
| 309 | 
            +
                    ) as progress_printer:
         | 
| 315 310 | 
             
                        # todo: should we wait for the max timeout (?) of all exit handles or just wait forever?
         | 
| 316 311 | 
             
                        # timeout = max(stream._settings._exit_timeout for stream in streams.values())
         | 
| 317 | 
            -
                         | 
| 318 | 
            -
                             | 
| 319 | 
            -
                            timeout | 
| 320 | 
            -
                             | 
| 321 | 
            -
             | 
| 312 | 
            +
                        wait_all_with_progress(
         | 
| 313 | 
            +
                            exit_handles,
         | 
| 314 | 
            +
                            timeout=None,
         | 
| 315 | 
            +
                            progress_after=1,
         | 
| 316 | 
            +
                            display_progress=functools.partial(
         | 
| 317 | 
            +
                                self._finish_all_progress,
         | 
| 322 318 | 
             
                                progress_printer,
         | 
| 319 | 
            +
                                started_streams,
         | 
| 323 320 | 
             
                            ),
         | 
| 324 321 | 
             
                        )
         | 
| 325 | 
            -
                        assert got_result
         | 
| 326 322 |  | 
| 327 323 | 
             
                    # These could be done in parallel in the future
         | 
| 328 324 | 
             
                    for _sid, stream in started_streams.items():
         | 
| @@ -332,20 +328,16 @@ class StreamMux: | |
| 332 328 | 
             
                        sampled_history_handle = stream.interface.deliver_request_sampled_history()
         | 
| 333 329 | 
             
                        internal_messages_handle = stream.interface.deliver_internal_messages()
         | 
| 334 330 |  | 
| 335 | 
            -
                        result = internal_messages_handle. | 
| 336 | 
            -
                        assert result
         | 
| 331 | 
            +
                        result = internal_messages_handle.wait_or(timeout=None)
         | 
| 337 332 | 
             
                        internal_messages_response = result.response.internal_messages_response
         | 
| 338 333 |  | 
| 339 | 
            -
                        result = poll_exit_handle. | 
| 340 | 
            -
                        assert result
         | 
| 334 | 
            +
                        result = poll_exit_handle.wait_or(timeout=None)
         | 
| 341 335 | 
             
                        poll_exit_response = result.response.poll_exit_response
         | 
| 342 336 |  | 
| 343 | 
            -
                        result = sampled_history_handle. | 
| 344 | 
            -
                        assert result
         | 
| 337 | 
            +
                        result = sampled_history_handle.wait_or(timeout=None)
         | 
| 345 338 | 
             
                        sampled_history = result.response.sampled_history_response
         | 
| 346 339 |  | 
| 347 | 
            -
                        result = final_summary_handle. | 
| 348 | 
            -
                        assert result
         | 
| 340 | 
            +
                        result = final_summary_handle.wait_or(timeout=None)
         | 
| 349 341 | 
             
                        final_summary = result.response.get_summary_response
         | 
| 350 342 |  | 
| 351 343 | 
             
                        Run._footer(
         | 
    
        wandb/sdk/verify/verify.py
    CHANGED
    
    | @@ -1,6 +1,8 @@ | |
| 1 1 | 
             
            """Utilities for wandb verify."""
         | 
| 2 2 |  | 
| 3 | 
            +
            import contextlib
         | 
| 3 4 | 
             
            import getpass
         | 
| 5 | 
            +
            import io
         | 
| 4 6 | 
             
            import os
         | 
| 5 7 | 
             
            import time
         | 
| 6 8 | 
             
            from functools import partial
         | 
| @@ -163,8 +165,8 @@ def check_run(api: Api) -> bool: | |
| 163 165 | 
             
                    )
         | 
| 164 166 | 
             
                    print_results(failed_test_strings, False)
         | 
| 165 167 | 
             
                    return False
         | 
| 166 | 
            -
                for key, value in  | 
| 167 | 
            -
                    if config | 
| 168 | 
            +
                for key, value in config.items():
         | 
| 169 | 
            +
                    if prev_run.config.get(key) != value:
         | 
| 168 170 | 
             
                        failed_test_strings.append(
         | 
| 169 171 | 
             
                            "Read config values don't match run config. Contact W&B for support."
         | 
| 170 172 | 
             
                        )
         | 
| @@ -486,6 +488,56 @@ def check_wandb_version(api: Api) -> None: | |
| 486 488 | 
             
                print_results(fail_string, warning)
         | 
| 487 489 |  | 
| 488 490 |  | 
| 491 | 
            +
            def check_sweeps(api: Api) -> bool:
         | 
| 492 | 
            +
                print("Checking sweep creation and agent execution".ljust(72, "."), end="")  # noqa: T201
         | 
| 493 | 
            +
                failed_test_strings: List[str] = []
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                sweep_config = {
         | 
| 496 | 
            +
                    "method": "random",
         | 
| 497 | 
            +
                    "metric": {"goal": "minimize", "name": "score"},
         | 
| 498 | 
            +
                    "parameters": {
         | 
| 499 | 
            +
                        "x": {"values": [0.01, 0.05, 0.1]},
         | 
| 500 | 
            +
                        "y": {"values": [1, 2, 3]},
         | 
| 501 | 
            +
                    },
         | 
| 502 | 
            +
                    "name": "verify_sweep",
         | 
| 503 | 
            +
                }
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                try:
         | 
| 506 | 
            +
                    with contextlib.redirect_stdout(io.StringIO()):
         | 
| 507 | 
            +
                        sweep_id = wandb.sweep(
         | 
| 508 | 
            +
                            sweep=sweep_config, project=PROJECT_NAME, entity=api.default_entity
         | 
| 509 | 
            +
                        )
         | 
| 510 | 
            +
                except Exception as e:
         | 
| 511 | 
            +
                    failed_test_strings.append(f"Failed to create sweep: {e}")
         | 
| 512 | 
            +
                    print_results(failed_test_strings, False)
         | 
| 513 | 
            +
                    return False
         | 
| 514 | 
            +
             | 
| 515 | 
            +
                if not sweep_id:
         | 
| 516 | 
            +
                    failed_test_strings.append("Sweep creation returned an invalid ID.")
         | 
| 517 | 
            +
                    print_results(failed_test_strings, False)
         | 
| 518 | 
            +
                    return False
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                try:
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                    def objective(config):
         | 
| 523 | 
            +
                        score = config.x**3 + config.y
         | 
| 524 | 
            +
                        return score
         | 
| 525 | 
            +
             | 
| 526 | 
            +
                    def main():
         | 
| 527 | 
            +
                        with wandb.init(project=PROJECT_NAME) as run:
         | 
| 528 | 
            +
                            score = objective(run.config)
         | 
| 529 | 
            +
                            run.log({"score": score})
         | 
| 530 | 
            +
             | 
| 531 | 
            +
                    wandb.agent(sweep_id, function=main, count=10)
         | 
| 532 | 
            +
                except Exception as e:
         | 
| 533 | 
            +
                    failed_test_strings.append(f"Failed to run sweep agent: {e}")
         | 
| 534 | 
            +
                    print_results(failed_test_strings, False)
         | 
| 535 | 
            +
                    return False
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                print_results(failed_test_strings, False)
         | 
| 538 | 
            +
                return len(failed_test_strings) == 0
         | 
| 539 | 
            +
             | 
| 540 | 
            +
             | 
| 489 541 | 
             
            def retry_fn(fn: Callable) -> Any:
         | 
| 490 542 | 
             
                ini_time = time.time()
         | 
| 491 543 | 
             
                res = None
         | 
    
        wandb/sdk/wandb_init.py
    CHANGED
    
    | @@ -34,7 +34,7 @@ from wandb.errors import CommError, Error, UsageError | |
| 34 34 | 
             
            from wandb.errors.links import url_registry
         | 
| 35 35 | 
             
            from wandb.errors.util import ProtobufErrorHandler
         | 
| 36 36 | 
             
            from wandb.integration import sagemaker
         | 
| 37 | 
            -
            from wandb.sdk.lib import runid
         | 
| 37 | 
            +
            from wandb.sdk.lib import progress, runid
         | 
| 38 38 | 
             
            from wandb.sdk.lib.paths import StrPath
         | 
| 39 39 | 
             
            from wandb.util import _is_artifact_representation
         | 
| 40 40 |  | 
| @@ -42,7 +42,7 @@ from . import wandb_login, wandb_setup | |
| 42 42 | 
             
            from .backend.backend import Backend
         | 
| 43 43 | 
             
            from .lib import SummaryDisabled, filesystem, module, paths, printer, telemetry
         | 
| 44 44 | 
             
            from .lib.deprecate import Deprecated, deprecate
         | 
| 45 | 
            -
            from . | 
| 45 | 
            +
            from .mailbox import Mailbox, wait_with_progress
         | 
| 46 46 | 
             
            from .wandb_helper import parse_config
         | 
| 47 47 | 
             
            from .wandb_run import Run, TeardownHook, TeardownStage
         | 
| 48 48 | 
             
            from .wandb_settings import Settings
         | 
| @@ -293,6 +293,16 @@ class _WandbInit: | |
| 293 293 |  | 
| 294 294 | 
             
                    settings.x_start_time = time.time()
         | 
| 295 295 |  | 
| 296 | 
            +
                    # In shared mode, generate a unique label if not provided.
         | 
| 297 | 
            +
                    # The label is used to distinguish between system metrics and console logs
         | 
| 298 | 
            +
                    # from different writers to the same run.
         | 
| 299 | 
            +
                    if settings._shared and not settings.x_label:
         | 
| 300 | 
            +
                        # TODO: If executed in a known distributed environment (e.g. Ray or SLURM),
         | 
| 301 | 
            +
                        #   use the env vars to generate a label (e.g. SLURM_JOB_ID or RANK)
         | 
| 302 | 
            +
                        prefix = settings.host or ""
         | 
| 303 | 
            +
                        label = runid.generate_id()
         | 
| 304 | 
            +
                        settings.x_label = f"{prefix}-{label}" if prefix else label
         | 
| 305 | 
            +
             | 
| 296 306 | 
             
                    return settings
         | 
| 297 307 |  | 
| 298 308 | 
             
                def _load_autoresume_run_id(self, resume_file: pathlib.Path) -> str | None:
         | 
| @@ -747,11 +757,6 @@ class _WandbInit: | |
| 747 757 | 
             
                    )
         | 
| 748 758 | 
             
                    return drun
         | 
| 749 759 |  | 
| 750 | 
            -
                def _on_progress_init(self, handle: MailboxProgress) -> None:
         | 
| 751 | 
            -
                    line = "Waiting for wandb.init()...\r"
         | 
| 752 | 
            -
                    percent_done = handle.percent_done
         | 
| 753 | 
            -
                    self.printer.progress_update(line, percent_done=percent_done)
         | 
| 754 | 
            -
             | 
| 755 760 | 
             
                def init(self, settings: Settings, config: _ConfigParts) -> Run:  # noqa: C901
         | 
| 756 761 | 
             
                    self._logger.info("calling init triggers")
         | 
| 757 762 | 
             
                    trigger.call("on_init")
         | 
| @@ -763,28 +768,18 @@ class _WandbInit: | |
| 763 768 | 
             
                        f"\nconfig: {config.base_no_artifacts}"
         | 
| 764 769 | 
             
                    )
         | 
| 765 770 |  | 
| 766 | 
            -
                    if (
         | 
| 767 | 
            -
                         | 
| 768 | 
            -
             | 
| 769 | 
            -
             | 
| 770 | 
            -
             | 
| 771 | 
            -
             | 
| 772 | 
            -
                                " module is not well-supported."
         | 
| 773 | 
            -
                                " Please use multiprocessing instead."
         | 
| 774 | 
            -
                                " Finishing previous run before initializing another."
         | 
| 775 | 
            -
                            )
         | 
| 776 | 
            -
             | 
| 777 | 
            -
                        latest_run = self._wl._global_run_stack[-1]
         | 
| 778 | 
            -
                        self._logger.info(f"found existing run on stack: {latest_run.id}")
         | 
| 779 | 
            -
                        latest_run.finish()
         | 
| 780 | 
            -
                    elif wandb.run is not None and os.getpid() == wandb.run._init_pid:
         | 
| 781 | 
            -
                        self._logger.info("wandb.init() called when a run is still active")
         | 
| 771 | 
            +
                    if wandb.run is not None and os.getpid() == wandb.run._init_pid:
         | 
| 772 | 
            +
                        if settings.reinit:
         | 
| 773 | 
            +
                            self._logger.info(f"finishing previous run: {wandb.run.id}")
         | 
| 774 | 
            +
                            wandb.run.finish()
         | 
| 775 | 
            +
                        else:
         | 
| 776 | 
            +
                            self._logger.info("wandb.init() called while a run is active")
         | 
| 782 777 |  | 
| 783 | 
            -
             | 
| 784 | 
            -
             | 
| 785 | 
            -
             | 
| 778 | 
            +
                            # NOTE: Updates telemetry on the pre-existing run.
         | 
| 779 | 
            +
                            with telemetry.context() as tel:
         | 
| 780 | 
            +
                                tel.feature.init_return_run = True
         | 
| 786 781 |  | 
| 787 | 
            -
             | 
| 782 | 
            +
                            return wandb.run
         | 
| 788 783 |  | 
| 789 784 | 
             
                    self._logger.info("starting backend")
         | 
| 790 785 |  | 
| @@ -905,7 +900,6 @@ class _WandbInit: | |
| 905 900 | 
             
                    run._set_teardown_hooks(self._teardown_hooks)
         | 
| 906 901 |  | 
| 907 902 | 
             
                    assert backend.interface
         | 
| 908 | 
            -
                    mailbox.enable_keepalive()
         | 
| 909 903 | 
             
                    backend.interface.publish_header()
         | 
| 910 904 |  | 
| 911 905 | 
             
                    # Using GitRepo() blocks & can be slow, depending on user's current git setup.
         | 
| @@ -921,16 +915,6 @@ class _WandbInit: | |
| 921 915 | 
             
                        )
         | 
| 922 916 | 
             
                    error: wandb.Error | None = None
         | 
| 923 917 |  | 
| 924 | 
            -
                    # In shared mode, generate a unique label if not provided.
         | 
| 925 | 
            -
                    # The label is used to distinguish between system metrics and console logs
         | 
| 926 | 
            -
                    # from different writers to the same run.
         | 
| 927 | 
            -
                    if settings._shared and not settings.x_label:
         | 
| 928 | 
            -
                        # TODO: If executed in a known distributed environment (e.g. Ray or SLURM),
         | 
| 929 | 
            -
                        #   use the env vars to generate a label (e.g. SLURM_JOB_ID or RANK)
         | 
| 930 | 
            -
                        prefix = settings.host or ""
         | 
| 931 | 
            -
                        label = runid.generate_id()
         | 
| 932 | 
            -
                        settings.x_label = f"{prefix}-{label}" if prefix else label
         | 
| 933 | 
            -
             | 
| 934 918 | 
             
                    timeout = settings.init_timeout
         | 
| 935 919 |  | 
| 936 920 | 
             
                    self._logger.info(
         | 
| @@ -938,11 +922,18 @@ class _WandbInit: | |
| 938 922 | 
             
                    )
         | 
| 939 923 |  | 
| 940 924 | 
             
                    run_init_handle = backend.interface.deliver_run(run)
         | 
| 941 | 
            -
             | 
| 942 | 
            -
             | 
| 943 | 
            -
                         | 
| 944 | 
            -
             | 
| 945 | 
            -
             | 
| 925 | 
            +
             | 
| 926 | 
            +
                    async def display_init_message() -> None:
         | 
| 927 | 
            +
                        assert backend.interface
         | 
| 928 | 
            +
             | 
| 929 | 
            +
                        with progress.progress_printer(
         | 
| 930 | 
            +
                            self.printer,
         | 
| 931 | 
            +
                            default_text="Waiting for wandb.init()...",
         | 
| 932 | 
            +
                        ) as progress_printer:
         | 
| 933 | 
            +
                            await progress.loop_printing_operation_stats(
         | 
| 934 | 
            +
                                progress_printer,
         | 
| 935 | 
            +
                                backend.interface,
         | 
| 936 | 
            +
                            )
         | 
| 946 937 |  | 
| 947 938 | 
             
                    # Raise an error if deliver_run failed.
         | 
| 948 939 | 
             
                    #
         | 
| @@ -951,8 +942,16 @@ class _WandbInit: | |
| 951 942 | 
             
                    #
         | 
| 952 943 | 
             
                    # TODO: Remove try-except once x_disable_service is removed.
         | 
| 953 944 | 
             
                    try:
         | 
| 954 | 
            -
                         | 
| 955 | 
            -
                             | 
| 945 | 
            +
                        try:
         | 
| 946 | 
            +
                            result = wait_with_progress(
         | 
| 947 | 
            +
                                run_init_handle,
         | 
| 948 | 
            +
                                timeout=timeout,
         | 
| 949 | 
            +
                                progress_after=1,
         | 
| 950 | 
            +
                                display_progress=display_init_message,
         | 
| 951 | 
            +
                            )
         | 
| 952 | 
            +
             | 
| 953 | 
            +
                        except TimeoutError:
         | 
| 954 | 
            +
                            run_init_handle.cancel(backend.interface)
         | 
| 956 955 |  | 
| 957 956 | 
             
                            # This may either be an issue with the W&B server (a CommError)
         | 
| 958 957 | 
             
                            # or a bug in the SDK (an Error). We cannot distinguish between
         | 
| @@ -963,6 +962,8 @@ class _WandbInit: | |
| 963 962 | 
             
                                " setting: `wandb.init(settings=wandb.Settings(init_timeout=120))`."
         | 
| 964 963 | 
             
                            )
         | 
| 965 964 |  | 
| 965 | 
            +
                        assert result.run_result
         | 
| 966 | 
            +
             | 
| 966 967 | 
             
                        if error := ProtobufErrorHandler.to_exception(result.run_result.error):
         | 
| 967 968 | 
             
                            raise error
         | 
| 968 969 |  | 
| @@ -1004,13 +1005,13 @@ class _WandbInit: | |
| 1004 1005 | 
             
                    assert backend.interface
         | 
| 1005 1006 |  | 
| 1006 1007 | 
             
                    run_start_handle = backend.interface.deliver_run_start(run)
         | 
| 1007 | 
            -
                     | 
| 1008 | 
            -
             | 
| 1009 | 
            -
             | 
| 1010 | 
            -
             | 
| 1008 | 
            +
                    try:
         | 
| 1009 | 
            +
                        # TODO: add progress to let user know we are doing something
         | 
| 1010 | 
            +
                        run_start_handle.wait_or(timeout=30)
         | 
| 1011 | 
            +
                    except TimeoutError:
         | 
| 1012 | 
            +
                        pass
         | 
| 1011 1013 |  | 
| 1012 1014 | 
             
                    assert self._wl is not None
         | 
| 1013 | 
            -
                    self._wl._global_run_stack.append(run)
         | 
| 1014 1015 | 
             
                    self.run = run
         | 
| 1015 1016 |  | 
| 1016 1017 | 
             
                    run._handle_launch_artifact_overrides()
         | 
| @@ -1097,14 +1098,13 @@ def _attach( | |
| 1097 1098 | 
             
                run._set_backend(backend)
         | 
| 1098 1099 | 
             
                assert backend.interface
         | 
| 1099 1100 |  | 
| 1100 | 
            -
                mailbox.enable_keepalive()
         | 
| 1101 | 
            -
             | 
| 1102 1101 | 
             
                attach_handle = backend.interface.deliver_attach(attach_id)
         | 
| 1103 | 
            -
                 | 
| 1104 | 
            -
             | 
| 1105 | 
            -
             | 
| 1106 | 
            -
             | 
| 1102 | 
            +
                try:
         | 
| 1103 | 
            +
                    # TODO: add progress to let user know we are doing something
         | 
| 1104 | 
            +
                    attach_result = attach_handle.wait_or(timeout=30)
         | 
| 1105 | 
            +
                except TimeoutError:
         | 
| 1107 1106 | 
             
                    raise UsageError("Timeout attaching to run")
         | 
| 1107 | 
            +
             | 
| 1108 1108 | 
             
                attach_response = attach_result.response.attach_response
         | 
| 1109 1109 | 
             
                if attach_response.error and attach_response.error.message:
         | 
| 1110 1110 | 
             
                    raise UsageError(f"Failed to attach to run: {attach_response.error.message}")
         | 
| @@ -1222,10 +1222,10 @@ def init(  # noqa: C901 | |
| 1222 1222 | 
             
                        on the system, such as checking the git root or the current program
         | 
| 1223 1223 | 
             
                        file. If we can't infer the project name, the project will default to
         | 
| 1224 1224 | 
             
                        `"uncategorized"`.
         | 
| 1225 | 
            -
                    dir:  | 
| 1226 | 
            -
                        files  | 
| 1227 | 
            -
                         | 
| 1228 | 
            -
                         | 
| 1225 | 
            +
                    dir: The absolute path to the directory where experiment logs and
         | 
| 1226 | 
            +
                        metadata files are stored. If not specified, this defaults
         | 
| 1227 | 
            +
                        to the `./wandb` directory. Note that this does not affect the
         | 
| 1228 | 
            +
                        location where artifacts are stored when calling `download()`.
         | 
| 1229 1229 | 
             
                    id: A unique identifier for this run, used for resuming. It must be unique
         | 
| 1230 1230 | 
             
                        within the project and cannot be reused once a run is deleted. The
         | 
| 1231 1231 | 
             
                        identifier must not contain any of the following special characters:
         | 
| @@ -1426,7 +1426,7 @@ def init(  # noqa: C901 | |
| 1426 1426 | 
             
                wl: wandb_setup._WandbSetup | None = None
         | 
| 1427 1427 |  | 
| 1428 1428 | 
             
                try:
         | 
| 1429 | 
            -
                    wl =  | 
| 1429 | 
            +
                    wl = wandb_setup._setup(start_service=False)
         | 
| 1430 1430 |  | 
| 1431 1431 | 
             
                    wi = _WandbInit(wl, init_telemetry)
         | 
| 1432 1432 |  | 
    
        wandb/sdk/wandb_login.py
    CHANGED
    
    | @@ -162,8 +162,8 @@ class _WandbLogin: | |
| 162 162 | 
             
                        repeat=False,
         | 
| 163 163 | 
             
                    )
         | 
| 164 164 |  | 
| 165 | 
            -
                def  | 
| 166 | 
            -
                    """Saves the API key  | 
| 165 | 
            +
                def try_save_api_key(self, key: str) -> None:
         | 
| 166 | 
            +
                    """Saves the API key to disk for future use."""
         | 
| 167 167 | 
             
                    if self._settings._notebook and not self._settings.silent:
         | 
| 168 168 | 
             
                        wandb.termwarn(
         | 
| 169 169 | 
             
                            "If you're specifying your api key in code, ensure this "
         | 
| @@ -172,7 +172,10 @@ class _WandbLogin: | |
| 172 172 | 
             
                            "`wandb login` from the command line."
         | 
| 173 173 | 
             
                        )
         | 
| 174 174 | 
             
                    if key:
         | 
| 175 | 
            -
                         | 
| 175 | 
            +
                        try:
         | 
| 176 | 
            +
                            apikey.write_key(self._settings, key)
         | 
| 177 | 
            +
                        except apikey.WriteNetrcError as e:
         | 
| 178 | 
            +
                            wandb.termwarn(str(e))
         | 
| 176 179 |  | 
| 177 180 | 
             
                def update_session(
         | 
| 178 181 | 
             
                    self,
         | 
| @@ -305,7 +308,7 @@ def _login( | |
| 305 308 | 
             
                    wlogin._verify_login(key)
         | 
| 306 309 |  | 
| 307 310 | 
             
                if not key_is_pre_configured:
         | 
| 308 | 
            -
                    wlogin. | 
| 311 | 
            +
                    wlogin.try_save_api_key(key)
         | 
| 309 312 | 
             
                    wlogin.update_session(key, status=key_status)
         | 
| 310 313 | 
             
                    wlogin._update_global_anonymous_setting()
         | 
| 311 314 |  |