wandb 0.19.6__py3-none-macosx_11_0_arm64.whl → 0.19.7__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +25 -5
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +11 -20
- wandb/env.py +10 -0
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +1 -1
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +1 -1
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +1 -1
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +43 -88
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +26 -10
- wandb/sdk/interface/interface_queue.py +5 -8
- wandb/sdk/interface/interface_relay.py +1 -6
- wandb/sdk/interface/interface_shared.py +21 -99
- wandb/sdk/interface/interface_sock.py +2 -13
- wandb/sdk/interface/router.py +21 -15
- wandb/sdk/interface/router_queue.py +2 -1
- wandb/sdk/interface/router_relay.py +2 -1
- wandb/sdk/interface/router_sock.py +5 -4
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +0 -18
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/service_connection.py +1 -1
- wandb/sdk/lib/sock_client.py +7 -7
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/handles.py +199 -0
- wandb/sdk/mailbox/mailbox.py +121 -0
- wandb/sdk/mailbox/wait_with_progress.py +134 -0
- wandb/sdk/service/server_sock.py +5 -1
- wandb/sdk/service/streams.py +66 -74
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +61 -61
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +82 -87
- wandb/sdk/wandb_settings.py +3 -3
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -4
- wandb/util.py +3 -1
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/METADATA +2 -2
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/RECORD +71 -58
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/WHEEL +0 -0
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/licenses/LICENSE +0 -0
| @@ -7,10 +7,10 @@ See interface.py for how interface classes relate to each other. | |
| 7 7 | 
             
            import logging
         | 
| 8 8 | 
             
            from typing import TYPE_CHECKING, Any, Optional
         | 
| 9 9 |  | 
| 10 | 
            -
            from  | 
| 10 | 
            +
            from wandb.sdk.mailbox import Mailbox
         | 
| 11 | 
            +
             | 
| 11 12 | 
             
            from ..lib.sock_client import SockClient
         | 
| 12 13 | 
             
            from .interface_shared import InterfaceShared
         | 
| 13 | 
            -
            from .message_future import MessageFuture
         | 
| 14 14 | 
             
            from .router_sock import MessageSockRouter
         | 
| 15 15 |  | 
| 16 16 | 
             
            if TYPE_CHECKING:
         | 
| @@ -32,7 +32,6 @@ class InterfaceSock(InterfaceShared): | |
| 32 32 | 
             
                    # _sock_client is used when abstract method _init_router() is called by constructor
         | 
| 33 33 | 
             
                    self._sock_client = sock_client
         | 
| 34 34 | 
             
                    super().__init__(mailbox=mailbox)
         | 
| 35 | 
            -
                    self._process_check = False
         | 
| 36 35 | 
             
                    self._stream_id = stream_id
         | 
| 37 36 |  | 
| 38 37 | 
             
                def _init_router(self) -> None:
         | 
| @@ -45,13 +44,3 @@ class InterfaceSock(InterfaceShared): | |
| 45 44 | 
             
                def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
         | 
| 46 45 | 
             
                    self._assign(record)
         | 
| 47 46 | 
             
                    self._sock_client.send_record_publish(record)
         | 
| 48 | 
            -
             | 
| 49 | 
            -
                def _communicate_async(
         | 
| 50 | 
            -
                    self, rec: "pb.Record", local: Optional[bool] = None
         | 
| 51 | 
            -
                ) -> MessageFuture:
         | 
| 52 | 
            -
                    self._assign(rec)
         | 
| 53 | 
            -
                    assert self._router
         | 
| 54 | 
            -
                    if self._process_check and self._process and not self._process.is_alive():
         | 
| 55 | 
            -
                        raise Exception("The wandb backend process has shutdown")
         | 
| 56 | 
            -
                    future = self._router.send_and_receive(rec, local=local)
         | 
| 57 | 
            -
                    return future
         | 
    
        wandb/sdk/interface/router.py
    CHANGED
    
    | @@ -10,7 +10,8 @@ import uuid | |
| 10 10 | 
             
            from abc import abstractmethod
         | 
| 11 11 | 
             
            from typing import TYPE_CHECKING, Dict, Optional
         | 
| 12 12 |  | 
| 13 | 
            -
            from  | 
| 13 | 
            +
            from wandb.sdk import mailbox
         | 
| 14 | 
            +
             | 
| 14 15 | 
             
            from .message_future import MessageFuture
         | 
| 15 16 |  | 
| 16 17 | 
             
            if TYPE_CHECKING:
         | 
| @@ -63,20 +64,25 @@ class MessageRouter: | |
| 63 64 | 
             
                    raise NotImplementedError
         | 
| 64 65 |  | 
| 65 66 | 
             
                def message_loop(self) -> None:
         | 
| 66 | 
            -
                     | 
| 67 | 
            -
                         | 
| 68 | 
            -
                             | 
| 69 | 
            -
             | 
| 70 | 
            -
                             | 
| 71 | 
            -
             | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
| 74 | 
            -
             | 
| 75 | 
            -
                             | 
| 76 | 
            -
             | 
| 77 | 
            -
             | 
| 78 | 
            -
                             | 
| 79 | 
            -
             | 
| 67 | 
            +
                    try:
         | 
| 68 | 
            +
                        while not self._join_event.is_set():
         | 
| 69 | 
            +
                            try:
         | 
| 70 | 
            +
                                msg = self._read_message()
         | 
| 71 | 
            +
                            except EOFError:
         | 
| 72 | 
            +
                                # On abnormal shutdown the queue will be destroyed underneath
         | 
| 73 | 
            +
                                # resulting in EOFError.  message_loop needs to exit..
         | 
| 74 | 
            +
                                logger.warning("EOFError seen in message_loop")
         | 
| 75 | 
            +
                                break
         | 
| 76 | 
            +
                            except MessageRouterClosedError as e:
         | 
| 77 | 
            +
                                logger.warning("message_loop has been closed", exc_info=e)
         | 
| 78 | 
            +
                                break
         | 
| 79 | 
            +
                            if not msg:
         | 
| 80 | 
            +
                                continue
         | 
| 81 | 
            +
                            self._handle_msg_rcv(msg)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    finally:
         | 
| 84 | 
            +
                        if self._mailbox:
         | 
| 85 | 
            +
                            self._mailbox.close()
         | 
| 80 86 |  | 
| 81 87 | 
             
                def send_and_receive(
         | 
| 82 88 | 
             
                    self, rec: "pb.Record", local: Optional[bool] = None
         | 
| @@ -6,8 +6,9 @@ Router to manage responses from a socket client. | |
| 6 6 |  | 
| 7 7 | 
             
            from typing import TYPE_CHECKING, Optional
         | 
| 8 8 |  | 
| 9 | 
            -
            from  | 
| 10 | 
            -
            from  | 
| 9 | 
            +
            from wandb.sdk.lib.sock_client import SockClient, SockClientClosedError
         | 
| 10 | 
            +
            from wandb.sdk.mailbox import Mailbox
         | 
| 11 | 
            +
             | 
| 11 12 | 
             
            from .router import MessageRouter, MessageRouterClosedError
         | 
| 12 13 |  | 
| 13 14 | 
             
            if TYPE_CHECKING:
         | 
| @@ -25,8 +26,8 @@ class MessageSockRouter(MessageRouter): | |
| 25 26 | 
             
                def _read_message(self) -> Optional["pb.Result"]:
         | 
| 26 27 | 
             
                    try:
         | 
| 27 28 | 
             
                        resp = self._sock_client.read_server_response(timeout=1)
         | 
| 28 | 
            -
                    except SockClientClosedError:
         | 
| 29 | 
            -
                        raise MessageRouterClosedError
         | 
| 29 | 
            +
                    except SockClientClosedError as e:
         | 
| 30 | 
            +
                        raise MessageRouterClosedError from e
         | 
| 30 31 | 
             
                    if not resp:
         | 
| 31 32 | 
             
                        return None
         | 
| 32 33 | 
             
                    msg = resp.result_communicate
         | 
    
        wandb/sdk/internal/handler.py
    CHANGED
    
    | @@ -205,9 +205,6 @@ class HandleManager: | |
| 205 205 | 
             
                    # defer is used to drive the sender finish state machine
         | 
| 206 206 | 
             
                    self._dispatch_record(record, always_send=True)
         | 
| 207 207 |  | 
| 208 | 
            -
                def handle_request_login(self, record: Record) -> None:
         | 
| 209 | 
            -
                    self._dispatch_record(record)
         | 
| 210 | 
            -
             | 
| 211 208 | 
             
                def handle_request_python_packages(self, record: Record) -> None:
         | 
| 212 209 | 
             
                    self._dispatch_record(record)
         | 
| 213 210 |  | 
| @@ -892,6 +889,10 @@ class HandleManager: | |
| 892 889 | 
             
                    self._respond_result(result)
         | 
| 893 890 | 
             
                    self._stopped.set()
         | 
| 894 891 |  | 
| 892 | 
            +
                def handle_request_operations(self, record: Record) -> None:
         | 
| 893 | 
            +
                    """No-op. Not implemented for the legacy-service."""
         | 
| 894 | 
            +
                    self._respond_result(proto_util._result_from_record(record))
         | 
| 895 | 
            +
             | 
| 895 896 | 
             
                def finish(self) -> None:
         | 
| 896 897 | 
             
                    logger.info("shutting down handler")
         | 
| 897 898 | 
             
                    if self._system_monitor is not None:
         | 
| @@ -115,6 +115,7 @@ if TYPE_CHECKING: | |
| 115 115 | 
             
                    root_dir: Optional[str]
         | 
| 116 116 | 
             
                    api_key: Optional[str]
         | 
| 117 117 | 
             
                    entity: Optional[str]
         | 
| 118 | 
            +
                    organization: Optional[str]
         | 
| 118 119 | 
             
                    project: Optional[str]
         | 
| 119 120 | 
             
                    _extra_http_headers: Optional[Mapping[str, str]]
         | 
| 120 121 | 
             
                    _proxies: Optional[Mapping[str, str]]
         | 
| @@ -256,6 +257,7 @@ class Api: | |
| 256 257 | 
             
                        "root_dir": None,
         | 
| 257 258 | 
             
                        "api_key": None,
         | 
| 258 259 | 
             
                        "entity": None,
         | 
| 260 | 
            +
                        "organization": None,
         | 
| 259 261 | 
             
                        "project": None,
         | 
| 260 262 | 
             
                        "_extra_http_headers": None,
         | 
| 261 263 | 
             
                        "_proxies": None,
         | 
| @@ -489,7 +491,8 @@ class Api: | |
| 489 491 | 
             
                            {
         | 
| 490 492 | 
             
                                "entity": "models",
         | 
| 491 493 | 
             
                                "base_url": "https://api.wandb.ai",
         | 
| 492 | 
            -
                                "project": None
         | 
| 494 | 
            +
                                "project": None,
         | 
| 495 | 
            +
                                "organization": "my-org",
         | 
| 493 496 | 
             
                            }
         | 
| 494 497 | 
             
                    """
         | 
| 495 498 | 
             
                    result = self.default_settings.copy()
         | 
| @@ -504,6 +507,14 @@ class Api: | |
| 504 507 | 
             
                                ),
         | 
| 505 508 | 
             
                                env=self._environ,
         | 
| 506 509 | 
             
                            ),
         | 
| 510 | 
            +
                            "organization": env.get_organization(
         | 
| 511 | 
            +
                                self._settings.get(
         | 
| 512 | 
            +
                                    Settings.DEFAULT_SECTION,
         | 
| 513 | 
            +
                                    "organization",
         | 
| 514 | 
            +
                                    fallback=result.get("organization"),
         | 
| 515 | 
            +
                                ),
         | 
| 516 | 
            +
                                env=self._environ,
         | 
| 517 | 
            +
                            ),
         | 
| 507 518 | 
             
                            "project": env.get_project(
         | 
| 508 519 | 
             
                                self._settings.get(
         | 
| 509 520 | 
             
                                    Settings.DEFAULT_SECTION,
         | 
    
        wandb/sdk/internal/sender.py
    CHANGED
    
    | @@ -544,23 +544,6 @@ class SendManager: | |
| 544 544 | 
             
                            logger.warning(f"Error emptying retry queue: {e}")
         | 
| 545 545 | 
             
                    self._respond_result(result)
         | 
| 546 546 |  | 
| 547 | 
            -
                def send_request_login(self, record: "Record") -> None:
         | 
| 548 | 
            -
                    # TODO: do something with api_key or anonymous?
         | 
| 549 | 
            -
                    # TODO: return an error if we aren't logged in?
         | 
| 550 | 
            -
                    self._api.reauth()
         | 
| 551 | 
            -
                    viewer = self.get_viewer_info()
         | 
| 552 | 
            -
                    server_info = self.get_server_info()
         | 
| 553 | 
            -
                    # self._login_flags = json.loads(viewer.get("flags", "{}"))
         | 
| 554 | 
            -
                    # self._login_entity = viewer.get("entity")
         | 
| 555 | 
            -
                    if server_info:
         | 
| 556 | 
            -
                        logger.info(f"Login server info: {server_info}")
         | 
| 557 | 
            -
                    self._entity = viewer.get("entity")
         | 
| 558 | 
            -
                    if record.control.req_resp:
         | 
| 559 | 
            -
                        result = proto_util._result_from_record(record)
         | 
| 560 | 
            -
                        if self._entity:
         | 
| 561 | 
            -
                            result.response.login_response.active_entity = self._entity
         | 
| 562 | 
            -
                        self._respond_result(result)
         | 
| 563 | 
            -
             | 
| 564 547 | 
             
                def send_exit(self, record: "Record") -> None:
         | 
| 565 548 | 
             
                    # track where the exit came from
         | 
| 566 549 | 
             
                    self._record_exit = record
         | 
| @@ -1491,7 +1474,6 @@ class SendManager: | |
| 1491 1474 | 
             
                        self._job_builder.set_partial_source_id(use.id)
         | 
| 1492 1475 |  | 
| 1493 1476 | 
             
                def send_request_log_artifact(self, record: "Record") -> None:
         | 
| 1494 | 
            -
                    assert record.control.req_resp
         | 
| 1495 1477 | 
             
                    result = proto_util._result_from_record(record)
         | 
| 1496 1478 | 
             
                    artifact = record.request.log_artifact.artifact
         | 
| 1497 1479 | 
             
                    history_step = record.request.log_artifact.history_step
         | 
    
        wandb/sdk/lib/apikey.py
    CHANGED
    
    | @@ -1,5 +1,8 @@ | |
| 1 1 | 
             
            """apikey util."""
         | 
| 2 2 |  | 
| 3 | 
            +
            from __future__ import annotations
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import dataclasses
         | 
| 3 6 | 
             
            import os
         | 
| 4 7 | 
             
            import platform
         | 
| 5 8 | 
             
            import stat
         | 
| @@ -32,8 +35,21 @@ LOGIN_CHOICES = [ | |
| 32 35 | 
             
                LOGIN_CHOICE_DRYRUN,
         | 
| 33 36 | 
             
            ]
         | 
| 34 37 |  | 
| 38 | 
            +
             | 
| 39 | 
            +
            @dataclasses.dataclass(frozen=True)
         | 
| 40 | 
            +
            class _NetrcPermissions:
         | 
| 41 | 
            +
                exists: bool
         | 
| 42 | 
            +
                read_access: bool
         | 
| 43 | 
            +
                write_access: bool
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            class WriteNetrcError(Exception):
         | 
| 47 | 
            +
                """Raised when we cannot write to the netrc file."""
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 35 50 | 
             
            Mode = Literal["allow", "must", "never", "false", "true"]
         | 
| 36 51 |  | 
| 52 | 
            +
             | 
| 37 53 | 
             
            if TYPE_CHECKING:
         | 
| 38 54 | 
             
                from wandb.sdk.wandb_settings import Settings
         | 
| 39 55 |  | 
| @@ -170,29 +186,64 @@ def prompt_api_key(  # noqa: C901 | |
| 170 186 | 
             
                return key
         | 
| 171 187 |  | 
| 172 188 |  | 
| 173 | 
            -
            def  | 
| 189 | 
            +
            def check_netrc_access(
         | 
| 190 | 
            +
                netrc_path: str,
         | 
| 191 | 
            +
            ) -> _NetrcPermissions:
         | 
| 192 | 
            +
                """Check if we can read and write to the netrc file."""
         | 
| 193 | 
            +
                file_exists = False
         | 
| 194 | 
            +
                write_access = False
         | 
| 195 | 
            +
                read_access = False
         | 
| 196 | 
            +
                try:
         | 
| 197 | 
            +
                    st = os.stat(netrc_path)
         | 
| 198 | 
            +
                    file_exists = True
         | 
| 199 | 
            +
                    write_access = bool(st.st_mode & stat.S_IWUSR)
         | 
| 200 | 
            +
                    read_access = bool(st.st_mode & stat.S_IRUSR)
         | 
| 201 | 
            +
                except FileNotFoundError:
         | 
| 202 | 
            +
                    # If the netrc file doesn't exist, we will create it.
         | 
| 203 | 
            +
                    write_access = True
         | 
| 204 | 
            +
                    read_access = True
         | 
| 205 | 
            +
                except OSError as e:
         | 
| 206 | 
            +
                    wandb.termerror(f"Unable to read permissions for {netrc_path}, {e}")
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                return _NetrcPermissions(
         | 
| 209 | 
            +
                    exists=file_exists,
         | 
| 210 | 
            +
                    write_access=write_access,
         | 
| 211 | 
            +
                    read_access=read_access,
         | 
| 212 | 
            +
                )
         | 
| 213 | 
            +
             | 
| 214 | 
            +
             | 
| 215 | 
            +
            def write_netrc(host: str, entity: str, key: str):
         | 
| 174 216 | 
             
                """Add our host and key to .netrc."""
         | 
| 175 217 | 
             
                _, key_suffix = key.split("-", 1) if "-" in key else ("", key)
         | 
| 176 218 | 
             
                if len(key_suffix) != 40:
         | 
| 177 | 
            -
                     | 
| 219 | 
            +
                    raise ValueError(
         | 
| 178 220 | 
             
                        "API-key must be exactly 40 characters long: {} ({} chars)".format(
         | 
| 179 221 | 
             
                            key_suffix, len(key_suffix)
         | 
| 180 222 | 
             
                        )
         | 
| 181 223 | 
             
                    )
         | 
| 182 | 
            -
             | 
| 183 | 
            -
                 | 
| 184 | 
            -
             | 
| 185 | 
            -
             | 
| 186 | 
            -
             | 
| 187 | 
            -
             | 
| 224 | 
            +
             | 
| 225 | 
            +
                normalized_host = urlparse(host).netloc
         | 
| 226 | 
            +
                netrc_path = get_netrc_file_path()
         | 
| 227 | 
            +
                netrc_access = check_netrc_access(netrc_path)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                if not netrc_access.write_access or not netrc_access.read_access:
         | 
| 230 | 
            +
                    raise WriteNetrcError(
         | 
| 231 | 
            +
                        f"Cannot access {netrc_path}. In order to persist your API key, "
         | 
| 232 | 
            +
                        "grant read and write permissions for your user to the file "
         | 
| 233 | 
            +
                        'or specify a different file with the environment variable "NETRC=<new_netrc_path>".'
         | 
| 188 234 | 
             
                    )
         | 
| 189 | 
            -
             | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 192 | 
            -
             | 
| 193 | 
            -
             | 
| 194 | 
            -
             | 
| 195 | 
            -
             | 
| 235 | 
            +
             | 
| 236 | 
            +
                machine_line = f"machine {normalized_host}"
         | 
| 237 | 
            +
                orig_lines = None
         | 
| 238 | 
            +
                try:
         | 
| 239 | 
            +
                    with open(netrc_path) as f:
         | 
| 240 | 
            +
                        orig_lines = f.read().strip().split("\n")
         | 
| 241 | 
            +
                except FileNotFoundError:
         | 
| 242 | 
            +
                    wandb.termlog("No netrc file found, creating one.")
         | 
| 243 | 
            +
                except OSError as e:
         | 
| 244 | 
            +
                    raise WriteNetrcError(f"Unable to read {netrc_path}") from e
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                try:
         | 
| 196 247 | 
             
                    with open(netrc_path, "w") as f:
         | 
| 197 248 | 
             
                        if orig_lines:
         | 
| 198 249 | 
             
                            # delete this machine from the file if it's already there.
         | 
| @@ -206,20 +257,22 @@ def write_netrc(host: str, entity: str, key: str) -> Optional[bool]: | |
| 206 257 | 
             
                                    skip -= 1
         | 
| 207 258 | 
             
                                else:
         | 
| 208 259 | 
             
                                    f.write("{}\n".format(line))
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                        wandb.termlog(
         | 
| 262 | 
            +
                            f"Appending key for {normalized_host} to your netrc file: {netrc_path}"
         | 
| 263 | 
            +
                        )
         | 
| 209 264 | 
             
                        f.write(
         | 
| 210 265 | 
             
                            textwrap.dedent(
         | 
| 211 266 | 
             
                                """\
         | 
| 212 | 
            -
             | 
| 213 | 
            -
             | 
| 214 | 
            -
             | 
| 215 | 
            -
             | 
| 267 | 
            +
                                machine {host}
         | 
| 268 | 
            +
                                  login {entity}
         | 
| 269 | 
            +
                                  password {key}
         | 
| 270 | 
            +
                                """
         | 
| 216 271 | 
             
                            ).format(host=normalized_host, entity=entity, key=key)
         | 
| 217 272 | 
             
                        )
         | 
| 218 273 | 
             
                    os.chmod(netrc_path, stat.S_IRUSR | stat.S_IWUSR)
         | 
| 219 | 
            -
             | 
| 220 | 
            -
             | 
| 221 | 
            -
                    wandb.termerror(f"Unable to read {netrc_path}")
         | 
| 222 | 
            -
                    return None
         | 
| 274 | 
            +
                except OSError as e:
         | 
| 275 | 
            +
                    raise WriteNetrcError(f"Unable to write {netrc_path}") from e
         | 
| 223 276 |  | 
| 224 277 |  | 
| 225 278 | 
             
            def write_key(
         | 
| @@ -250,7 +303,15 @@ def api_key(settings: Optional["Settings"] = None) -> Optional[str]: | |
| 250 303 | 
             
                    settings = wandb.setup().settings
         | 
| 251 304 | 
             
                if settings.api_key:
         | 
| 252 305 | 
             
                    return settings.api_key
         | 
| 253 | 
            -
             | 
| 254 | 
            -
                 | 
| 255 | 
            -
             | 
| 306 | 
            +
             | 
| 307 | 
            +
                netrc_access = check_netrc_access(get_netrc_file_path())
         | 
| 308 | 
            +
                if netrc_access.exists and not netrc_access.read_access:
         | 
| 309 | 
            +
                    wandb.termwarn(f"Cannot access {get_netrc_file_path()}.")
         | 
| 310 | 
            +
                    return None
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                if netrc_access.exists:
         | 
| 313 | 
            +
                    auth = get_netrc_auth(settings.base_url)
         | 
| 314 | 
            +
                    if auth:
         | 
| 315 | 
            +
                        return auth[-1]
         | 
| 316 | 
            +
             | 
| 256 317 | 
             
                return None
         | 
| @@ -0,0 +1,210 @@ | |
| 1 | 
            +
            """Functions for compatibility with asyncio."""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from __future__ import annotations
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import asyncio
         | 
| 6 | 
            +
            import concurrent
         | 
| 7 | 
            +
            import concurrent.futures
         | 
| 8 | 
            +
            import contextlib
         | 
| 9 | 
            +
            import threading
         | 
| 10 | 
            +
            from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, TypeVar
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            _T = TypeVar("_T")
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
                """Execute the coroutine produced by `fn` on an event loop in a new thread.

                Use this in place of `asyncio.run`, which raises when the calling thread
                already has an active `asyncio` event loop. `wandb` was not originally
                written around `asyncio`, so calling `asyncio.run` directly would break
                users who invoke `wandb` from inside their own loop.

                A single-worker thread pool is created for every call, which makes this
                slightly slow.

                Args:
                    fn: Zero-argument callable returning the coroutine to run.

                Returns:
                    Whatever the coroutine returned by `fn` returns.
                """
                runner = _Runner()
                pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

                with pool:
                    pending = pool.submit(runner.run, fn)

                    try:
                        outcome = pending.result()
                    finally:
                        # Always signal the runner, so that an interrupted wait does not
                        # leave the event loop thread running in the background.
                        runner.cancel()

                return outcome
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            class _RunnerCancelledError(Exception):
                """Raised when a `_Runner.run()` call ends because it was cancelled."""
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            class _Runner:
                """Runs an asyncio event loop allowing cancellation.

                This is like `asyncio.run()`, except it provides a `cancel()` method
                meant to be called in a `finally` block.

                Without this, it is impossible to make `asyncio.run()` stop if it runs
                in a non-main thread. In particular, a KeyboardInterrupt causes the
                ThreadPoolExecutor used by `run()` to block until the asyncio thread
                completes, but there is no way to tell the asyncio thread to cancel its
                work. A second KeyboardInterrupt makes ThreadPoolExecutor give up while
                the asyncio thread still runs in the background, with terrible effects
                if it prints to the user's terminal.
                """

                def __init__(self) -> None:
                    # Guards all fields below: `run()` executes on the thread hosting the
                    # event loop, while `cancel()` may be called from another thread.
                    self._lock = threading.Condition()

                    # `cancel()` has been called.
                    self._is_cancelled = False
                    # `_loop` and `_cancel_event` have been populated.
                    self._started = False
                    # `_run_or_cancel()` has finished; the loop may be closed.
                    self._done = False

                    self._loop: asyncio.AbstractEventLoop | None = None
                    self._cancel_event: asyncio.Event | None = None

                def run(self, fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
                    """Run a coroutine in asyncio, cancelling it on `cancel()`.

                    Args:
                        fn: Zero-argument callable returning the coroutine to run.

                    Returns:
                        The result of the coroutine returned by `fn`.

                    Raises:
                        _RunnerCancelledError: If `cancel()` is called.
                    """
                    return asyncio.run(self._run_or_cancel(fn))

                async def _run_or_cancel(
                    self,
                    fn: Callable[[], Coroutine[Any, Any, _T]],
                ) -> _T:
                    """Await `fn()` until it completes or `cancel()` is called.

                    Returns:
                        The result of the coroutine returned by `fn`.

                    Raises:
                        _RunnerCancelledError: If `cancel()` was called before the
                            coroutine finished (including before it started).
                    """
                    with self._lock:
                        if self._is_cancelled:
                            raise _RunnerCancelledError()

                        self._loop = asyncio.get_running_loop()
                        self._cancel_event = asyncio.Event()
                        self._started = True

                    cancellation_task: asyncio.Task | None = None
                    fn_task: asyncio.Task | None = None

                    try:
                        # Create the tasks inside the `try`: if `fn()` raises before
                        # returning a coroutine, `_done` must still be recorded below.
                        # Otherwise a later `cancel()` would try to schedule work on a
                        # loop that `asyncio.run()` has already closed.
                        cancellation_task = asyncio.create_task(self._cancel_event.wait())
                        fn_task = asyncio.create_task(fn())

                        await asyncio.wait(
                            [cancellation_task, fn_task],
                            return_when=asyncio.FIRST_COMPLETED,
                        )

                        if fn_task.done():
                            return fn_task.result()
                        else:
                            raise _RunnerCancelledError()

                    finally:
                        # Cancelling an already finished task is a no-op.
                        if cancellation_task is not None:
                            cancellation_task.cancel()
                        if fn_task is not None:
                            fn_task.cancel()

                        with self._lock:
                            self._done = True

                def cancel(self) -> None:
                    """Cancel all asyncio work started by `run()`.

                    Only the first call has any effect. The cancellation is delivered
                    with `call_soon_threadsafe`, so this may be called from a thread
                    other than the one running the event loop.
                    """
                    with self._lock:
                        if self._is_cancelled:
                            return
                        self._is_cancelled = True

                        if self._done or not self._started:
                            # If the runner already finished, no need to cancel it.
                            #
                            # If the runner hasn't started the loop yet, then it will not
                            # as we already set _is_cancelled.
                            return

                        # Both are set together with `_started` under the lock.
                        assert self._loop is not None
                        assert self._cancel_event is not None
                        self._loop.call_soon_threadsafe(self._cancel_event.set)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
             | 
| 128 | 
            +
            class TaskGroup:
                """Object that `open_task_group()` yields.

                Collects the asyncio tasks started through `start_soon()` so that they
                can be awaited or cancelled together.
                """

                def __init__(self) -> None:
                    # Every task ever scheduled through `start_soon()`, in order.
                    self._tasks: list[asyncio.Task] = []

                def start_soon(self, coro: Coroutine[Any, Any, Any]) -> None:
                    """Schedule a task in the group.

                    Args:
                        coro: The return value of the `async` function defining the task.
                    """
                    self._tasks.append(asyncio.create_task(coro))

                async def _wait_all(self) -> None:
                    """Block until all tasks complete.

                    Returns immediately if no task was ever scheduled.

                    Raises:
                        Exception: If one or more tasks raises an exception, one of these
                            is raised arbitrarily.
                    """
                    if not self._tasks:
                        # `asyncio.wait()` raises ValueError when given no tasks, which
                        # would make an empty task group fail on exit.
                        return

                    done, _ = await asyncio.wait(
                        self._tasks,
                        # NOTE: Cancelling a task counts as a normal exit,
                        #   not an exception.
                        return_when=asyncio.FIRST_EXCEPTION,
                    )

                    for task in done:
                        if task.cancelled():
                            # `task.exception()` would raise CancelledError here.
                            continue

                        exc = task.exception()
                        if exc is not None:
                            raise exc

                def _cancel_all(self) -> None:
                    """Cancel all tasks."""
                    for task in self._tasks:
                        # NOTE: It is safe to cancel tasks that have already completed.
                        task.cancel()
         | 
| 168 | 
            +
             | 
| 169 | 
            +
             | 
| 170 | 
            +
            @contextlib.asynccontextmanager
            async def open_task_group() -> AsyncIterator[TaskGroup]:
                """Open a group of asyncio subtasks tied to an `async with` block.

                This fills in for the task groups that `asyncio` only gained in
                Python 3.11.

                Use it as `async with open_task_group() as group:`. Leaving the block
                waits for every subtask to complete. If a subtask fails, or the current
                task is cancelled, all subtasks in the group are cancelled and the
                subtask's exception is raised. When several subtasks fail at the same
                time, which exception is raised is arbitrary.

                NOTE: Subtask exceptions do not propagate until the context manager
                exits, so the task group cannot cancel code running inside the
                `async with` block.
                """
                group = TaskGroup()

                try:
                    yield group
                except BaseException:
                    # The body failed or was cancelled: skip waiting and clean up.
                    raise
                else:
                    await group._wait_all()
                finally:
                    group._cancel_all()
         | 
| 193 | 
            +
             | 
| 194 | 
            +
             | 
| 195 | 
            +
            @contextlib.contextmanager
            def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
                """Schedule a task, cancelling it when exiting the context manager.

                The task is created with `asyncio.create_task()`, so the context manager
                must be entered while an event loop is running.

                If the given coroutine has already raised an exception by the time the
                context manager exits, that exception is raised from the exit. A task
                that has already been cancelled counts as a normal exit and raises
                nothing.

                Args:
                    coro: The coroutine to run as a background task.
                """
                task = asyncio.create_task(coro)

                try:
                    yield
                finally:
                    if not task.done():
                        task.cancel()
                    elif not task.cancelled():
                        # `task.exception()` raises CancelledError for a cancelled task,
                        # hence the `cancelled()` check above.
                        exception = task.exception()
                        if exception is not None:
                            raise exception
         |