wandb 0.22.2__py3-none-win32.whl → 0.22.3__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
 - wandb/__init__.pyi +2 -2
 - wandb/_pydantic/__init__.py +8 -1
 - wandb/_pydantic/base.py +54 -18
 - wandb/_pydantic/field_types.py +8 -3
 - wandb/_pydantic/pagination.py +46 -0
 - wandb/_pydantic/utils.py +2 -2
 - wandb/apis/public/api.py +24 -19
 - wandb/apis/public/artifacts.py +259 -270
 - wandb/apis/public/registries/_utils.py +40 -54
 - wandb/apis/public/registries/registries_search.py +70 -85
 - wandb/apis/public/registries/registry.py +173 -156
 - wandb/apis/public/runs.py +27 -6
 - wandb/apis/public/utils.py +43 -20
 - wandb/automations/_generated/create_automation.py +2 -2
 - wandb/automations/_generated/create_generic_webhook_integration.py +4 -4
 - wandb/automations/_generated/delete_automation.py +2 -2
 - wandb/automations/_generated/fragments.py +31 -52
 - wandb/automations/_generated/generic_webhook_integrations_by_entity.py +3 -3
 - wandb/automations/_generated/get_automations.py +3 -3
 - wandb/automations/_generated/get_automations_by_entity.py +3 -3
 - wandb/automations/_generated/input_types.py +9 -9
 - wandb/automations/_generated/integrations_by_entity.py +3 -3
 - wandb/automations/_generated/operations.py +6 -6
 - wandb/automations/_generated/slack_integrations_by_entity.py +3 -3
 - wandb/automations/_generated/update_automation.py +2 -2
 - wandb/automations/_utils.py +3 -3
 - wandb/automations/actions.py +3 -3
 - wandb/automations/automations.py +6 -5
 - wandb/bin/gpu_stats.exe +0 -0
 - wandb/bin/wandb-core +0 -0
 - wandb/cli/beta.py +8 -2
 - wandb/cli/beta_leet.py +2 -1
 - wandb/cli/beta_sync.py +1 -1
 - wandb/errors/term.py +8 -8
 - wandb/jupyter.py +0 -51
 - wandb/old/settings.py +6 -6
 - wandb/proto/v3/wandb_internal_pb2.py +351 -352
 - wandb/proto/v3/wandb_server_pb2.py +38 -37
 - wandb/proto/v3/wandb_settings_pb2.py +2 -2
 - wandb/proto/v3/wandb_sync_pb2.py +19 -6
 - wandb/proto/v4/wandb_internal_pb2.py +351 -352
 - wandb/proto/v4/wandb_server_pb2.py +38 -37
 - wandb/proto/v4/wandb_settings_pb2.py +2 -2
 - wandb/proto/v4/wandb_sync_pb2.py +10 -6
 - wandb/proto/v5/wandb_internal_pb2.py +351 -352
 - wandb/proto/v5/wandb_server_pb2.py +38 -37
 - wandb/proto/v5/wandb_settings_pb2.py +2 -2
 - wandb/proto/v5/wandb_sync_pb2.py +10 -6
 - wandb/proto/v6/wandb_internal_pb2.py +351 -352
 - wandb/proto/v6/wandb_server_pb2.py +38 -37
 - wandb/proto/v6/wandb_settings_pb2.py +2 -2
 - wandb/proto/v6/wandb_sync_pb2.py +10 -6
 - wandb/sdk/artifacts/_generated/__init__.py +96 -40
 - wandb/sdk/artifacts/_generated/add_aliases.py +3 -3
 - wandb/sdk/artifacts/_generated/add_artifact_collection_tags.py +26 -0
 - wandb/sdk/artifacts/_generated/artifact_by_id.py +2 -2
 - wandb/sdk/artifacts/_generated/artifact_by_name.py +3 -3
 - wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +27 -8
 - wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +27 -8
 - wandb/sdk/artifacts/_generated/artifact_created_by.py +7 -20
 - wandb/sdk/artifacts/_generated/artifact_file_urls.py +19 -6
 - wandb/sdk/artifacts/_generated/artifact_membership_by_name.py +26 -0
 - wandb/sdk/artifacts/_generated/artifact_type.py +5 -5
 - wandb/sdk/artifacts/_generated/artifact_used_by.py +8 -17
 - wandb/sdk/artifacts/_generated/artifact_version_files.py +19 -8
 - wandb/sdk/artifacts/_generated/delete_aliases.py +3 -3
 - wandb/sdk/artifacts/_generated/delete_artifact.py +4 -4
 - wandb/sdk/artifacts/_generated/delete_artifact_collection_tags.py +23 -0
 - wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +4 -4
 - wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +4 -4
 - wandb/sdk/artifacts/_generated/delete_registry.py +21 -0
 - wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +8 -20
 - wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +13 -35
 - wandb/sdk/artifacts/_generated/fetch_org_info_from_entity.py +28 -0
 - wandb/sdk/artifacts/_generated/fetch_registries.py +18 -8
 - wandb/sdk/{projects → artifacts}/_generated/fetch_registry.py +4 -4
 - wandb/sdk/artifacts/_generated/fragments.py +183 -333
 - wandb/sdk/artifacts/_generated/input_types.py +133 -7
 - wandb/sdk/artifacts/_generated/link_artifact.py +5 -5
 - wandb/sdk/artifacts/_generated/operations.py +1053 -548
 - wandb/sdk/artifacts/_generated/project_artifact_collection.py +9 -77
 - wandb/sdk/artifacts/_generated/project_artifact_collections.py +21 -9
 - wandb/sdk/artifacts/_generated/project_artifact_type.py +3 -3
 - wandb/sdk/artifacts/_generated/project_artifact_types.py +19 -6
 - wandb/sdk/artifacts/_generated/project_artifacts.py +7 -8
 - wandb/sdk/artifacts/_generated/registry_collections.py +21 -9
 - wandb/sdk/artifacts/_generated/registry_versions.py +20 -9
 - wandb/sdk/artifacts/_generated/rename_registry.py +25 -0
 - wandb/sdk/artifacts/_generated/run_input_artifacts.py +5 -9
 - wandb/sdk/artifacts/_generated/run_output_artifacts.py +5 -9
 - wandb/sdk/artifacts/_generated/type_info.py +2 -2
 - wandb/sdk/artifacts/_generated/unlink_artifact.py +3 -5
 - wandb/sdk/artifacts/_generated/update_artifact.py +3 -3
 - wandb/sdk/artifacts/_generated/update_artifact_collection_type.py +28 -0
 - wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +7 -16
 - wandb/sdk/artifacts/_generated/update_artifact_sequence.py +7 -16
 - wandb/sdk/artifacts/_generated/upsert_registry.py +25 -0
 - wandb/sdk/artifacts/_gqlutils.py +170 -6
 - wandb/sdk/artifacts/_models/__init__.py +9 -0
 - wandb/sdk/artifacts/_models/artifact_collection.py +109 -0
 - wandb/sdk/artifacts/_models/manifest.py +26 -0
 - wandb/sdk/artifacts/_models/pagination.py +26 -0
 - wandb/sdk/artifacts/_models/registry.py +100 -0
 - wandb/sdk/artifacts/_validators.py +45 -27
 - wandb/sdk/artifacts/artifact.py +220 -215
 - wandb/sdk/artifacts/artifact_file_cache.py +1 -1
 - wandb/sdk/artifacts/artifact_manifest.py +37 -32
 - wandb/sdk/artifacts/artifact_manifest_entry.py +80 -125
 - wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +43 -61
 - wandb/sdk/artifacts/storage_handlers/gcs_handler.py +8 -6
 - wandb/sdk/data_types/image.py +2 -2
 - wandb/sdk/interface/interface.py +72 -64
 - wandb/sdk/interface/interface_queue.py +27 -18
 - wandb/sdk/interface/interface_shared.py +61 -23
 - wandb/sdk/interface/interface_sock.py +9 -5
 - wandb/sdk/internal/_generated/server_features_query.py +4 -4
 - wandb/sdk/launch/inputs/schema.py +13 -10
 - wandb/sdk/lib/apikey.py +8 -12
 - wandb/sdk/lib/asyncio_compat.py +1 -1
 - wandb/sdk/lib/asyncio_manager.py +5 -5
 - wandb/sdk/lib/console_capture.py +38 -30
 - wandb/sdk/lib/progress.py +159 -64
 - wandb/sdk/lib/retry.py +3 -2
 - wandb/sdk/lib/service/service_connection.py +2 -2
 - wandb/sdk/lib/wb_logging.py +2 -1
 - wandb/sdk/mailbox/mailbox.py +1 -1
 - wandb/sdk/wandb_init.py +10 -13
 - wandb/sdk/wandb_run.py +9 -46
 - wandb/sdk/wandb_settings.py +102 -19
 - {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/METADATA +2 -1
 - {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/RECORD +135 -134
 - wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +0 -26
 - wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +0 -36
 - wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +0 -25
 - wandb/sdk/artifacts/_generated/move_artifact_collection.py +0 -35
 - wandb/sdk/projects/_generated/__init__.py +0 -26
 - wandb/sdk/projects/_generated/delete_project.py +0 -22
 - wandb/sdk/projects/_generated/enums.py +0 -4
 - wandb/sdk/projects/_generated/fragments.py +0 -41
 - wandb/sdk/projects/_generated/input_types.py +0 -13
 - wandb/sdk/projects/_generated/operations.py +0 -88
 - wandb/sdk/projects/_generated/rename_project.py +0 -27
 - wandb/sdk/projects/_generated/upsert_registry_project.py +0 -27
 - {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/WHEEL +0 -0
 - {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/entry_points.txt +0 -0
 - {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/licenses/LICENSE +0 -0
 
| 
         @@ -7,6 +7,9 @@ META_SCHEMA = { 
     | 
|
| 
       7 
7 
     | 
    
         
             
                    },
         
     | 
| 
       8 
8 
     | 
    
         
             
                    "title": {"type": "string"},
         
     | 
| 
       9 
9 
     | 
    
         
             
                    "description": {"type": "string"},
         
     | 
| 
      
 10 
     | 
    
         
            +
                    "label": {"type": "string"},
         
     | 
| 
      
 11 
     | 
    
         
            +
                    "placeholder": {"type": "string"},
         
     | 
| 
      
 12 
     | 
    
         
            +
                    "required": {"type": "boolean"},
         
     | 
| 
       10 
13 
     | 
    
         
             
                    "format": {"type": "string"},
         
     | 
| 
       11 
14 
     | 
    
         
             
                    "enum": {"type": "array", "items": {"type": ["integer", "number", "string"]}},
         
     | 
| 
       12 
15 
     | 
    
         
             
                    "properties": {"type": "object", "patternProperties": {".*": {"$ref": "#"}}},
         
     | 
| 
         @@ -19,24 +22,24 @@ META_SCHEMA = { 
     | 
|
| 
       19 
22 
     | 
    
         
             
                },
         
     | 
| 
       20 
23 
     | 
    
         
             
                "allOf": [
         
     | 
| 
       21 
24 
     | 
    
         
             
                    {
         
     | 
| 
       22 
     | 
    
         
            -
                        "if": {"properties": {"type": {"const": " 
     | 
| 
      
 25 
     | 
    
         
            +
                        "if": {"properties": {"type": {"const": "integer"}}},
         
     | 
| 
       23 
26 
     | 
    
         
             
                        "then": {
         
     | 
| 
       24 
27 
     | 
    
         
             
                            "properties": {
         
     | 
| 
       25 
     | 
    
         
            -
                                "minimum": {"type":  
     | 
| 
       26 
     | 
    
         
            -
                                "maximum": {"type":  
     | 
| 
       27 
     | 
    
         
            -
                                "exclusiveMinimum": {"type":  
     | 
| 
       28 
     | 
    
         
            -
                                "exclusiveMaximum": {"type":  
     | 
| 
      
 28 
     | 
    
         
            +
                                "minimum": {"type": "integer"},
         
     | 
| 
      
 29 
     | 
    
         
            +
                                "maximum": {"type": "integer"},
         
     | 
| 
      
 30 
     | 
    
         
            +
                                "exclusiveMinimum": {"type": "integer"},
         
     | 
| 
      
 31 
     | 
    
         
            +
                                "exclusiveMaximum": {"type": "integer"},
         
     | 
| 
       29 
32 
     | 
    
         
             
                            }
         
     | 
| 
       30 
33 
     | 
    
         
             
                        },
         
     | 
| 
       31 
34 
     | 
    
         
             
                    },
         
     | 
| 
       32 
35 
     | 
    
         
             
                    {
         
     | 
| 
       33 
     | 
    
         
            -
                        "if": {"properties": {"type": {"const": " 
     | 
| 
      
 36 
     | 
    
         
            +
                        "if": {"properties": {"type": {"const": "number"}}},
         
     | 
| 
       34 
37 
     | 
    
         
             
                        "then": {
         
     | 
| 
       35 
38 
     | 
    
         
             
                            "properties": {
         
     | 
| 
       36 
     | 
    
         
            -
                                "minimum": {"type": "integer"},
         
     | 
| 
       37 
     | 
    
         
            -
                                "maximum": {"type": "integer"},
         
     | 
| 
       38 
     | 
    
         
            -
                                "exclusiveMinimum": {"type": "integer"},
         
     | 
| 
       39 
     | 
    
         
            -
                                "exclusiveMaximum": {"type": "integer"},
         
     | 
| 
      
 39 
     | 
    
         
            +
                                "minimum": {"type": ["integer", "number"]},
         
     | 
| 
      
 40 
     | 
    
         
            +
                                "maximum": {"type": ["integer", "number"]},
         
     | 
| 
      
 41 
     | 
    
         
            +
                                "exclusiveMinimum": {"type": ["integer", "number"]},
         
     | 
| 
      
 42 
     | 
    
         
            +
                                "exclusiveMaximum": {"type": ["integer", "number"]},
         
     | 
| 
       40 
43 
     | 
    
         
             
                            }
         
     | 
| 
       41 
44 
     | 
    
         
             
                        },
         
     | 
| 
       42 
45 
     | 
    
         
             
                    },
         
     | 
    
        wandb/sdk/lib/apikey.py
    CHANGED
    
    | 
         @@ -136,12 +136,6 @@ def prompt_api_key(  # noqa: C901 
     | 
|
| 
       136 
136 
     | 
    
         
             
                if (jupyter and not settings.login_timeout) or no_create:
         
     | 
| 
       137 
137 
     | 
    
         
             
                    choices.remove(LOGIN_CHOICE_NEW)
         
     | 
| 
       138 
138 
     | 
    
         | 
| 
       139 
     | 
    
         
            -
                if jupyter and "google.colab" in sys.modules:
         
     | 
| 
       140 
     | 
    
         
            -
                    log_string = term.LOG_STRING_NOCOLOR
         
     | 
| 
       141 
     | 
    
         
            -
                    key = wandb.jupyter.attempt_colab_login(app_url)  # type: ignore
         
     | 
| 
       142 
     | 
    
         
            -
                    if key is not None:
         
     | 
| 
       143 
     | 
    
         
            -
                        return key  # type: ignore
         
     | 
| 
       144 
     | 
    
         
            -
             
     | 
| 
       145 
139 
     | 
    
         
             
                if anon_mode == "must":
         
     | 
| 
       146 
140 
     | 
    
         
             
                    result = LOGIN_CHOICE_ANON
         
     | 
| 
       147 
141 
     | 
    
         
             
                # If we're not in an interactive environment, default to dry-run.
         
     | 
| 
         @@ -236,9 +230,9 @@ def check_netrc_access( 
     | 
|
| 
       236 
230 
     | 
    
         
             
            def write_netrc(host: str, entity: str, key: str):
         
     | 
| 
       237 
231 
     | 
    
         
             
                """Add our host and key to .netrc."""
         
     | 
| 
       238 
232 
     | 
    
         
             
                _, key_suffix = key.split("-", 1) if "-" in key else ("", key)
         
     | 
| 
       239 
     | 
    
         
            -
                if len(key_suffix)  
     | 
| 
      
 233 
     | 
    
         
            +
                if len(key_suffix) < 40:
         
     | 
| 
       240 
234 
     | 
    
         
             
                    raise ValueError(
         
     | 
| 
       241 
     | 
    
         
            -
                        f"API-key must be  
     | 
| 
      
 235 
     | 
    
         
            +
                        f"API-key must be at least 40 characters long: {key_suffix} ({len(key_suffix)} chars)"
         
     | 
| 
       242 
236 
     | 
    
         
             
                    )
         
     | 
| 
       243 
237 
     | 
    
         | 
| 
       244 
238 
     | 
    
         
             
                normalized_host = urlparse(host).netloc
         
     | 
| 
         @@ -305,12 +299,14 @@ def write_key( 
     | 
|
| 
       305 
299 
     | 
    
         
             
                # TODO(jhr): api shouldn't be optional or it shouldn't be passed, clean up callers
         
     | 
| 
       306 
300 
     | 
    
         
             
                api = api or InternalApi()
         
     | 
| 
       307 
301 
     | 
    
         | 
| 
       308 
     | 
    
         
            -
                #  
     | 
| 
       309 
     | 
    
         
            -
                # variable-length prefix, a dash, then the 40 
     | 
| 
      
 302 
     | 
    
         
            +
                # API keys are strings of at least 40 characters. On-prem API keys have a
         
     | 
| 
      
 303 
     | 
    
         
            +
                # variable-length prefix, a dash, then the string of at least 40 chars.
         
     | 
| 
       310 
304 
     | 
    
         
             
                _, suffix = key.split("-", 1) if "-" in key else ("", key)
         
     | 
| 
       311 
305 
     | 
    
         | 
| 
       312 
     | 
    
         
            -
                if len(suffix)  
     | 
| 
       313 
     | 
    
         
            -
                    raise ValueError( 
     | 
| 
      
 306 
     | 
    
         
            +
                if len(suffix) < 40:
         
     | 
| 
      
 307 
     | 
    
         
            +
                    raise ValueError(
         
     | 
| 
      
 308 
     | 
    
         
            +
                        f"API key must be at least 40 characters long, yours was {len(key)}"
         
     | 
| 
      
 309 
     | 
    
         
            +
                    )
         
     | 
| 
       314 
310 
     | 
    
         | 
| 
       315 
311 
     | 
    
         
             
                write_netrc(settings.base_url, "user", key)
         
     | 
| 
       316 
312 
     | 
    
         | 
    
        wandb/sdk/lib/asyncio_compat.py
    CHANGED
    
    | 
         @@ -133,7 +133,7 @@ class TaskGroup: 
     | 
|
| 
       133 
133 
     | 
    
         
             
                """Object that `open_task_group()` yields."""
         
     | 
| 
       134 
134 
     | 
    
         | 
| 
       135 
135 
     | 
    
         
             
                def __init__(self) -> None:
         
     | 
| 
       136 
     | 
    
         
            -
                    self._tasks: list[asyncio.Task] = []
         
     | 
| 
      
 136 
     | 
    
         
            +
                    self._tasks: list[asyncio.Task[None]] = []
         
     | 
| 
       137 
137 
     | 
    
         | 
| 
       138 
138 
     | 
    
         
             
                def start_soon(self, coro: Coroutine[Any, Any, Any]) -> None:
         
     | 
| 
       139 
139 
     | 
    
         
             
                    """Schedule a task in the group.
         
     | 
    
        wandb/sdk/lib/asyncio_manager.py
    CHANGED
    
    | 
         @@ -7,7 +7,7 @@ import concurrent.futures 
     | 
|
| 
       7 
7 
     | 
    
         
             
            import contextlib
         
     | 
| 
       8 
8 
     | 
    
         
             
            import logging
         
     | 
| 
       9 
9 
     | 
    
         
             
            import threading
         
     | 
| 
       10 
     | 
    
         
            -
            from typing import  
     | 
| 
      
 10 
     | 
    
         
            +
            from typing import Awaitable, Callable, TypeVar
         
     | 
| 
       11 
11 
     | 
    
         | 
| 
       12 
12 
     | 
    
         
             
            from . import asyncio_compat
         
     | 
| 
       13 
13 
     | 
    
         | 
| 
         @@ -104,7 +104,7 @@ class AsyncioManager: 
     | 
|
| 
       104 
104 
     | 
    
         
             
                        # This only matters if the KeyboardInterrupt is suppressed.
         
     | 
| 
       105 
105 
     | 
    
         
             
                        self._runner.cancel()
         
     | 
| 
       106 
106 
     | 
    
         | 
| 
       107 
     | 
    
         
            -
                def run(self, fn: Callable[[],  
     | 
| 
      
 107 
     | 
    
         
            +
                def run(self, fn: Callable[[], Awaitable[_T]]) -> _T:
         
     | 
| 
       108 
108 
     | 
    
         
             
                    """Run an async function to completion.
         
     | 
| 
       109 
109 
     | 
    
         | 
| 
       110 
110 
     | 
    
         
             
                    The function is called in the asyncio thread. Blocks until start()
         
     | 
| 
         @@ -148,7 +148,7 @@ class AsyncioManager: 
     | 
|
| 
       148 
148 
     | 
    
         | 
| 
       149 
149 
     | 
    
         
             
                def run_soon(
         
     | 
| 
       150 
150 
     | 
    
         
             
                    self,
         
     | 
| 
       151 
     | 
    
         
            -
                    fn: Callable[[],  
     | 
| 
      
 151 
     | 
    
         
            +
                    fn: Callable[[], Awaitable[None]],
         
     | 
| 
       152 
152 
     | 
    
         
             
                    *,
         
     | 
| 
       153 
153 
     | 
    
         
             
                    daemon: bool = False,
         
     | 
| 
       154 
154 
     | 
    
         
             
                    name: str | None = None,
         
     | 
| 
         @@ -186,7 +186,7 @@ class AsyncioManager: 
     | 
|
| 
       186 
186 
     | 
    
         | 
| 
       187 
187 
     | 
    
         
             
                def _schedule(
         
     | 
| 
       188 
188 
     | 
    
         
             
                    self,
         
     | 
| 
       189 
     | 
    
         
            -
                    fn: Callable[[],  
     | 
| 
      
 189 
     | 
    
         
            +
                    fn: Callable[[], Awaitable[_T]],
         
     | 
| 
       190 
190 
     | 
    
         
             
                    daemon: bool,
         
     | 
| 
       191 
191 
     | 
    
         
             
                    name: str | None = None,
         
     | 
| 
       192 
192 
     | 
    
         
             
                ) -> concurrent.futures.Future[_T]:
         
     | 
| 
         @@ -207,7 +207,7 @@ class AsyncioManager: 
     | 
|
| 
       207 
207 
     | 
    
         | 
| 
       208 
208 
     | 
    
         
             
                async def _wrap(
         
     | 
| 
       209 
209 
     | 
    
         
             
                    self,
         
     | 
| 
       210 
     | 
    
         
            -
                    fn: Callable[[],  
     | 
| 
      
 210 
     | 
    
         
            +
                    fn: Callable[[], Awaitable[_T]],
         
     | 
| 
       211 
211 
     | 
    
         
             
                    daemon: bool,
         
     | 
| 
       212 
212 
     | 
    
         
             
                    name: str | None,
         
     | 
| 
       213 
213 
     | 
    
         
             
                ) -> _T:
         
     | 
    
        wandb/sdk/lib/console_capture.py
    CHANGED
    
    | 
         @@ -75,9 +75,12 @@ class _WriteCallback(Protocol): 
     | 
|
| 
       75 
75 
     | 
    
         
             
                    """
         
     | 
| 
       76 
76 
     | 
    
         | 
| 
       77 
77 
     | 
    
         | 
| 
       78 
     | 
    
         
            -
             
     | 
| 
       79 
     | 
    
         
            -
            _module_rlock = threading.RLock()
         
     | 
| 
      
 78 
     | 
    
         
            +
            _module_lock = threading.Lock()
         
     | 
| 
       80 
79 
     | 
    
         
             
            _is_writing = False
         
     | 
| 
      
 80 
     | 
    
         
            +
            """Prevents infinite print-capture loops.
         
     | 
| 
      
 81 
     | 
    
         
            +
             
     | 
| 
      
 82 
     | 
    
         
            +
            If a capture callback prints, that output is not captured.
         
     | 
| 
      
 83 
     | 
    
         
            +
            """
         
     | 
| 
       81 
84 
     | 
    
         | 
| 
       82 
85 
     | 
    
         
             
            _patch_exception: CannotCaptureConsoleError | None = None
         
     | 
| 
       83 
86 
     | 
    
         | 
| 
         @@ -99,7 +102,7 @@ def capture_stdout(callback: _WriteCallback) -> Callable[[], None]: 
     | 
|
| 
       99 
102 
     | 
    
         
             
                Raises:
         
     | 
| 
       100 
103 
     | 
    
         
             
                    CannotCaptureConsoleError: If patching failed on import.
         
     | 
| 
       101 
104 
     | 
    
         
             
                """
         
     | 
| 
       102 
     | 
    
         
            -
                with  
     | 
| 
      
 105 
     | 
    
         
            +
                with _module_lock:
         
     | 
| 
       103 
106 
     | 
    
         
             
                    if _patch_exception:
         
     | 
| 
       104 
107 
     | 
    
         
             
                        raise _patch_exception
         
     | 
| 
       105 
108 
     | 
    
         | 
| 
         @@ -121,7 +124,7 @@ def capture_stderr(callback: _WriteCallback) -> Callable[[], None]: 
     | 
|
| 
       121 
124 
     | 
    
         
             
                Raises:
         
     | 
| 
       122 
125 
     | 
    
         
             
                    CannotCaptureConsoleError: If patching failed on import.
         
     | 
| 
       123 
126 
     | 
    
         
             
                """
         
     | 
| 
       124 
     | 
    
         
            -
                with  
     | 
| 
      
 127 
     | 
    
         
            +
                with _module_lock:
         
     | 
| 
       125 
128 
     | 
    
         
             
                    if _patch_exception:
         
     | 
| 
       126 
129 
     | 
    
         
             
                        raise _patch_exception
         
     | 
| 
       127 
130 
     | 
    
         | 
| 
         @@ -144,7 +147,7 @@ def _insert_disposably( 
     | 
|
| 
       144 
147 
     | 
    
         
             
                def dispose() -> None:
         
     | 
| 
       145 
148 
     | 
    
         
             
                    nonlocal disposed
         
     | 
| 
       146 
149 
     | 
    
         | 
| 
       147 
     | 
    
         
            -
                    with  
     | 
| 
      
 150 
     | 
    
         
            +
                    with _module_lock:
         
     | 
| 
       148 
151 
     | 
    
         
             
                        if disposed:
         
     | 
| 
       149 
152 
     | 
    
         
             
                            return
         
     | 
| 
       150 
153 
     | 
    
         | 
| 
         @@ -167,38 +170,43 @@ def _patch( 
     | 
|
| 
       167 
170 
     | 
    
         
             
                    global _is_writing
         
     | 
| 
       168 
171 
     | 
    
         
             
                    n = orig_write(s)
         
     | 
| 
       169 
172 
     | 
    
         | 
| 
       170 
     | 
    
         
            -
                     
     | 
| 
       171 
     | 
    
         
            -
                    # deadlock if a callback invokes write() again.
         
     | 
| 
       172 
     | 
    
         
            -
                    with _module_rlock:
         
     | 
| 
      
 173 
     | 
    
         
            +
                    with _module_lock:
         
     | 
| 
       173 
174 
     | 
    
         
             
                        if _is_writing:
         
     | 
| 
       174 
175 
     | 
    
         
             
                            return n
         
     | 
| 
       175 
     | 
    
         
            -
             
     | 
| 
       176 
176 
     | 
    
         
             
                        _is_writing = True
         
     | 
| 
       177 
     | 
    
         
            -
             
     | 
| 
       178 
     | 
    
         
            -
             
     | 
| 
       179 
     | 
    
         
            -
             
     | 
| 
       180 
     | 
    
         
            -
             
     | 
| 
       181 
     | 
    
         
            -
                         
     | 
| 
       182 
     | 
    
         
            -
             
     | 
| 
       183 
     | 
    
         
            -
             
     | 
| 
       184 
     | 
    
         
            -
             
     | 
| 
       185 
     | 
    
         
            -
                             
     | 
| 
       186 
     | 
    
         
            -
             
     | 
| 
       187 
     | 
    
         
            -
             
     | 
| 
       188 
     | 
    
         
            -
             
     | 
| 
       189 
     | 
    
         
            -
             
     | 
| 
      
 177 
     | 
    
         
            +
             
     | 
| 
      
 178 
     | 
    
         
            +
                        # Invoke callbacks outside of the lock to avoid deadlocks.
         
     | 
| 
      
 179 
     | 
    
         
            +
                        # 1. A callback may print, invoking this again.
         
     | 
| 
      
 180 
     | 
    
         
            +
                        # 2. A callback may block on a different thread which then prints.
         
     | 
| 
      
 181 
     | 
    
         
            +
                        callback_list = list(callbacks.values())
         
     | 
| 
      
 182 
     | 
    
         
            +
             
     | 
| 
      
 183 
     | 
    
         
            +
                    try:
         
     | 
| 
      
 184 
     | 
    
         
            +
                        for cb in callback_list:
         
     | 
| 
      
 185 
     | 
    
         
            +
                            cb(s, n)
         
     | 
| 
      
 186 
     | 
    
         
            +
             
     | 
| 
      
 187 
     | 
    
         
            +
                    except BaseException as e:
         
     | 
| 
      
 188 
     | 
    
         
            +
                        # Clear all callbacks on any exception to avoid infinite loops:
         
     | 
| 
      
 189 
     | 
    
         
            +
                        #
         
     | 
| 
      
 190 
     | 
    
         
            +
                        # * If we re-raise, an exception handler is likely to print
         
     | 
| 
      
 191 
     | 
    
         
            +
                        #   the exception to the console and trigger callbacks again
         
     | 
| 
      
 192 
     | 
    
         
            +
                        # * If we log, we can't guarantee that this doesn't print
         
     | 
| 
      
 193 
     | 
    
         
            +
                        #   to console.
         
     | 
| 
      
 194 
     | 
    
         
            +
                        #
         
     | 
| 
      
 195 
     | 
    
         
            +
                        # This is especially important for KeyboardInterrupt.
         
     | 
| 
      
 196 
     | 
    
         
            +
                        with _module_lock:
         
     | 
| 
       190 
197 
     | 
    
         
             
                            _stderr_callbacks.clear()
         
     | 
| 
       191 
198 
     | 
    
         
             
                            _stdout_callbacks.clear()
         
     | 
| 
       192 
199 
     | 
    
         | 
| 
       193 
     | 
    
         
            -
             
     | 
| 
       194 
     | 
    
         
            -
             
     | 
| 
       195 
     | 
    
         
            -
             
     | 
| 
       196 
     | 
    
         
            -
             
     | 
| 
       197 
     | 
    
         
            -
             
     | 
| 
       198 
     | 
    
         
            -
             
     | 
| 
       199 
     | 
    
         
            -
             
     | 
| 
      
 200 
     | 
    
         
            +
                        if isinstance(e, Exception):
         
     | 
| 
      
 201 
     | 
    
         
            +
                            # We suppress Exceptions so that bugs in W&B code don't
         
     | 
| 
      
 202 
     | 
    
         
            +
                            # cause the user's print() statements to raise errors.
         
     | 
| 
      
 203 
     | 
    
         
            +
                            _logger.exception("Error in console callback, clearing all!")
         
     | 
| 
      
 204 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 205 
     | 
    
         
            +
                            # Re-raise errors like KeyboardInterrupt.
         
     | 
| 
      
 206 
     | 
    
         
            +
                            raise
         
     | 
| 
       200 
207 
     | 
    
         | 
| 
       201 
     | 
    
         
            -
             
     | 
| 
      
 208 
     | 
    
         
            +
                    finally:
         
     | 
| 
      
 209 
     | 
    
         
            +
                        with _module_lock:
         
     | 
| 
       202 
210 
     | 
    
         
             
                            _is_writing = False
         
     | 
| 
       203 
211 
     | 
    
         | 
| 
       204 
212 
     | 
    
         
             
                    return n
         
     | 
    
        wandb/sdk/lib/progress.py
    CHANGED
    
    | 
         @@ -5,7 +5,7 @@ from __future__ import annotations 
     | 
|
| 
       5 
5 
     | 
    
         
             
            import asyncio
         
     | 
| 
       6 
6 
     | 
    
         
             
            import contextlib
         
     | 
| 
       7 
7 
     | 
    
         
             
            import time
         
     | 
| 
       8 
     | 
    
         
            -
            from typing import  
     | 
| 
      
 8 
     | 
    
         
            +
            from typing import Iterator, NoReturn
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
       10 
10 
     | 
    
         
             
            from wandb.proto import wandb_internal_pb2 as pb
         
     | 
| 
       11 
11 
     | 
    
         
             
            from wandb.sdk.interface import interface
         
     | 
| 
         @@ -13,6 +13,10 @@ from wandb.sdk.lib import asyncio_compat 
     | 
|
| 
       13 
13 
     | 
    
         | 
| 
       14 
14 
     | 
    
         
             
            from . import printer as p
         
     | 
| 
       15 
15 
     | 
    
         | 
| 
      
 16 
     | 
    
         
            +
            _INDENT = "  "
         
     | 
| 
      
 17 
     | 
    
         
            +
            _MAX_LINES_TO_PRINT = 6
         
     | 
| 
      
 18 
     | 
    
         
            +
            _MAX_OPS_TO_PRINT = 5
         
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
       16 
20 
     | 
    
         | 
| 
       17 
21 
     | 
    
         
             
            async def loop_printing_operation_stats(
         
     | 
| 
       18 
22 
     | 
    
         
             
                progress: ProgressPrinter,
         
     | 
| 
         @@ -96,101 +100,189 @@ class ProgressPrinter: 
     | 
|
| 
       96 
100 
     | 
    
         
             
                    self._printer = printer
         
     | 
| 
       97 
101 
     | 
    
         
             
                    self._progress_text_area = progress_text_area
         
     | 
| 
       98 
102 
     | 
    
         
             
                    self._default_text = default_text
         
     | 
| 
       99 
     | 
    
         
            -
                    self._tick =  
     | 
| 
      
 103 
     | 
    
         
            +
                    self._tick = -1
         
     | 
| 
       100 
104 
     | 
    
         
             
                    self._last_printed_line = ""
         
     | 
| 
       101 
105 
     | 
    
         | 
| 
       102 
106 
     | 
    
         
             
                def update(
         
     | 
| 
       103 
107 
     | 
    
         
             
                    self,
         
     | 
| 
       104 
     | 
    
         
            -
                     
     | 
| 
      
 108 
     | 
    
         
            +
                    stats_or_groups: pb.OperationStats | dict[str, pb.OperationStats],
         
     | 
| 
       105 
109 
     | 
    
         
             
                ) -> None:
         
     | 
| 
       106 
     | 
    
         
            -
                    """Update the displayed information. 
     | 
| 
       107 
     | 
    
         
            -
             
     | 
| 
      
 110 
     | 
    
         
            +
                    """Update the displayed information.
         
     | 
| 
      
 111 
     | 
    
         
            +
             
     | 
| 
      
 112 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 113 
     | 
    
         
            +
                        stats_or_groups: A single group of operations, or zero or more
         
     | 
| 
      
 114 
     | 
    
         
            +
                            labeled operation groups.
         
     | 
| 
      
 115 
     | 
    
         
            +
                    """
         
     | 
| 
      
 116 
     | 
    
         
            +
                    self._tick += 1
         
     | 
| 
      
 117 
     | 
    
         
            +
             
     | 
| 
      
 118 
     | 
    
         
            +
                    if not self._progress_text_area:
         
     | 
| 
      
 119 
     | 
    
         
            +
                        line = self._to_static_text(stats_or_groups)
         
     | 
| 
      
 120 
     | 
    
         
            +
                        if line and line != self._last_printed_line:
         
     | 
| 
      
 121 
     | 
    
         
            +
                            self._printer.display(line)
         
     | 
| 
      
 122 
     | 
    
         
            +
                            self._last_printed_line = line
         
     | 
| 
       108 
123 
     | 
    
         
             
                        return
         
     | 
| 
       109 
124 
     | 
    
         | 
| 
       110 
     | 
    
         
            -
                     
     | 
| 
       111 
     | 
    
         
            -
             
     | 
| 
       112 
     | 
    
         
            -
             
     | 
| 
       113 
     | 
    
         
            -
                         
     | 
| 
       114 
     | 
    
         
            -
                             
     | 
| 
       115 
     | 
    
         
            -
                         
     | 
| 
      
 125 
     | 
    
         
            +
                    lines = self._to_dynamic_text(stats_or_groups)
         
     | 
| 
      
 126 
     | 
    
         
            +
                    if not lines:
         
     | 
| 
      
 127 
     | 
    
         
            +
                        loading_symbol = self._printer.loading_symbol(self._tick)
         
     | 
| 
      
 128 
     | 
    
         
            +
                        if loading_symbol:
         
     | 
| 
      
 129 
     | 
    
         
            +
                            lines = [f"{loading_symbol} {self._default_text}"]
         
     | 
| 
      
 130 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 131 
     | 
    
         
            +
                            lines = [self._default_text]
         
     | 
| 
       116 
132 
     | 
    
         | 
| 
       117 
     | 
    
         
            -
                    self. 
     | 
| 
      
 133 
     | 
    
         
            +
                    self._progress_text_area.set_text("\n".join(lines))
         
     | 
| 
       118 
134 
     | 
    
         | 
| 
       119 
     | 
    
         
            -
                def  
     | 
| 
       120 
     | 
    
         
            -
                     
     | 
| 
       121 
     | 
    
         
            -
             
     | 
| 
      
 135 
     | 
    
         
            +
                def _to_dynamic_text(
         
     | 
| 
      
 136 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 137 
     | 
    
         
            +
                    stats_or_groups: pb.OperationStats | dict[str, pb.OperationStats],
         
     | 
| 
      
 138 
     | 
    
         
            +
                ) -> list[str]:
         
     | 
| 
      
 139 
     | 
    
         
            +
                    """Returns text to show in a dynamic text area."""
         
     | 
| 
      
 140 
     | 
    
         
            +
                    loading_symbol = self._printer.loading_symbol(self._tick)
         
     | 
| 
      
 141 
     | 
    
         
            +
             
     | 
| 
      
 142 
     | 
    
         
            +
                    if isinstance(stats_or_groups, dict):
         
     | 
| 
      
 143 
     | 
    
         
            +
                        return _GroupedOperationStatsPrinter(
         
     | 
| 
       122 
144 
     | 
    
         
             
                            self._printer,
         
     | 
| 
       123 
     | 
    
         
            -
                             
     | 
| 
       124 
     | 
    
         
            -
                             
     | 
| 
       125 
     | 
    
         
            -
             
     | 
| 
       126 
     | 
    
         
            -
                            default_text=self._default_text,
         
     | 
| 
       127 
     | 
    
         
            -
                        ).display(stats_list)
         
     | 
| 
      
 145 
     | 
    
         
            +
                            _MAX_LINES_TO_PRINT,
         
     | 
| 
      
 146 
     | 
    
         
            +
                            loading_symbol,
         
     | 
| 
      
 147 
     | 
    
         
            +
                        ).render(stats_or_groups)
         
     | 
| 
       128 
148 
     | 
    
         | 
| 
       129 
149 
     | 
    
         
             
                    else:
         
     | 
| 
       130 
     | 
    
         
            -
                         
     | 
| 
       131 
     | 
    
         
            -
             
     | 
| 
       132 
     | 
    
         
            -
             
     | 
| 
       133 
     | 
    
         
            -
                             
     | 
| 
       134 
     | 
    
         
            -
             
     | 
| 
       135 
     | 
    
         
            -
                                    top_level_operations.append(op.desc)
         
     | 
| 
       136 
     | 
    
         
            -
                                else:
         
     | 
| 
       137 
     | 
    
         
            -
                                    extra_operations += 1
         
     | 
| 
       138 
     | 
    
         
            -
             
     | 
| 
       139 
     | 
    
         
            -
                        line = "; ".join(top_level_operations)
         
     | 
| 
       140 
     | 
    
         
            -
                        if extra_operations > 0:
         
     | 
| 
       141 
     | 
    
         
            -
                            line += f" (+ {extra_operations} more)"
         
     | 
| 
      
 150 
     | 
    
         
            +
                        return _OperationStatsPrinter(
         
     | 
| 
      
 151 
     | 
    
         
            +
                            self._printer,
         
     | 
| 
      
 152 
     | 
    
         
            +
                            _MAX_LINES_TO_PRINT,
         
     | 
| 
      
 153 
     | 
    
         
            +
                            loading_symbol,
         
     | 
| 
      
 154 
     | 
    
         
            +
                        ).render(stats_or_groups)
         
     | 
| 
       142 
155 
     | 
    
         | 
| 
       143 
     | 
    
         
            -
             
     | 
| 
       144 
     | 
    
         
            -
             
     | 
| 
       145 
     | 
    
         
            -
             
     | 
| 
      
 156 
     | 
    
         
            +
                def _to_static_text(
         
     | 
| 
      
 157 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 158 
     | 
    
         
            +
                    stats_or_groups: pb.OperationStats | dict[str, pb.OperationStats],
         
     | 
| 
      
 159 
     | 
    
         
            +
                ) -> str:
         
     | 
| 
      
 160 
     | 
    
         
            +
                    """Returns a single line of text to print out."""
         
     | 
| 
      
 161 
     | 
    
         
            +
                    if isinstance(stats_or_groups, dict):
         
     | 
| 
      
 162 
     | 
    
         
            +
                        sorted_prefixed_stats = list(
         
     | 
| 
      
 163 
     | 
    
         
            +
                            (f"[{group}] ", stats)  #
         
     | 
| 
      
 164 
     | 
    
         
            +
                            for group, stats in sorted(stats_or_groups.items())
         
     | 
| 
      
 165 
     | 
    
         
            +
                        )
         
     | 
| 
      
 166 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 167 
     | 
    
         
            +
                        sorted_prefixed_stats = [("", stats_or_groups)]
         
     | 
| 
      
 168 
     | 
    
         
            +
             
     | 
| 
      
 169 
     | 
    
         
            +
                    group_strs: list[str] = []
         
     | 
| 
      
 170 
     | 
    
         
            +
                    total_operations = 0
         
     | 
| 
      
 171 
     | 
    
         
            +
                    total_printed = 0
         
     | 
| 
       146 
172 
     | 
    
         | 
| 
      
 173 
     | 
    
         
            +
                    for prefix, stats in sorted_prefixed_stats:
         
     | 
| 
      
 174 
     | 
    
         
            +
                        total_operations += stats.total_operations
         
     | 
| 
      
 175 
     | 
    
         
            +
                        if not stats.operations:
         
     | 
| 
      
 176 
     | 
    
         
            +
                            continue
         
     | 
| 
      
 177 
     | 
    
         
            +
             
     | 
| 
      
 178 
     | 
    
         
            +
                        group_ops: list[str] = []
         
     | 
| 
      
 179 
     | 
    
         
            +
                        i = 0
         
     | 
| 
      
 180 
     | 
    
         
            +
                        while total_printed < _MAX_OPS_TO_PRINT and i < len(stats.operations):
         
     | 
| 
      
 181 
     | 
    
         
            +
                            group_ops.append(stats.operations[i].desc)
         
     | 
| 
      
 182 
     | 
    
         
            +
                            total_printed += 1
         
     | 
| 
      
 183 
     | 
    
         
            +
                            i += 1
         
     | 
| 
      
 184 
     | 
    
         
            +
             
     | 
| 
      
 185 
     | 
    
         
            +
                        if group_ops:
         
     | 
| 
      
 186 
     | 
    
         
            +
                            group_strs.append(prefix + "; ".join(group_ops))
         
     | 
| 
      
 187 
     | 
    
         
            +
             
     | 
| 
      
 188 
     | 
    
         
            +
                    line = "; ".join(group_strs)
         
     | 
| 
      
 189 
     | 
    
         
            +
                    remaining = total_operations - total_printed
         
     | 
| 
      
 190 
     | 
    
         
            +
                    if total_printed > 0 and remaining > 0:
         
     | 
| 
      
 191 
     | 
    
         
            +
                        line += f" (+ {remaining} more)"
         
     | 
| 
       147 
192 
     | 
    
         | 
| 
       148 
     | 
    
         
            -
             
     | 
| 
       149 
     | 
    
         
            -
             
     | 
| 
      
 193 
     | 
    
         
            +
                    return line
         
     | 
| 
      
 194 
     | 
    
         
            +
             
     | 
| 
      
 195 
     | 
    
         
            +
             
     | 
| 
      
 196 
     | 
    
         
            +
            class _GroupedOperationStatsPrinter:
         
     | 
| 
      
 197 
     | 
    
         
            +
                """Renders a list of labeled operation stats groups into lines of text."""
         
     | 
| 
       150 
198 
     | 
    
         | 
| 
       151 
199 
     | 
    
         
             
                def __init__(
         
     | 
| 
       152 
200 
     | 
    
         
             
                    self,
         
     | 
| 
       153 
201 
     | 
    
         
             
                    printer: p.Printer,
         
     | 
| 
       154 
     | 
    
         
            -
                    text_area: p.DynamicText,
         
     | 
| 
       155 
202 
     | 
    
         
             
                    max_lines: int,
         
     | 
| 
       156 
203 
     | 
    
         
             
                    loading_symbol: str,
         
     | 
| 
       157 
     | 
    
         
            -
                    default_text: str,
         
     | 
| 
       158 
204 
     | 
    
         
             
                ) -> None:
         
     | 
| 
       159 
205 
     | 
    
         
             
                    self._printer = printer
         
     | 
| 
       160 
     | 
    
         
            -
                    self._text_area = text_area
         
     | 
| 
       161 
206 
     | 
    
         
             
                    self._max_lines = max_lines
         
     | 
| 
       162 
207 
     | 
    
         
             
                    self._loading_symbol = loading_symbol
         
     | 
| 
       163 
     | 
    
         
            -
                    self._default_text = default_text
         
     | 
| 
       164 
208 
     | 
    
         | 
| 
       165 
     | 
    
         
            -
             
     | 
| 
       166 
     | 
    
         
            -
                     
     | 
| 
      
 209 
     | 
    
         
            +
                def render(self, groups: dict[str, pb.OperationStats]) -> list[str]:
         
     | 
| 
      
 210 
     | 
    
         
            +
                    """Convert labeled operation stats groups into text to display.
         
     | 
| 
      
 211 
     | 
    
         
            +
             
     | 
| 
      
 212 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 213 
     | 
    
         
            +
                        groups: A mapping from group labels to stats for that group.
         
     | 
| 
      
 214 
     | 
    
         
            +
             
     | 
| 
      
 215 
     | 
    
         
            +
                    Returns:
         
     | 
| 
      
 216 
     | 
    
         
            +
                        The lines of text to print. The lines do not end with the newline
         
     | 
| 
      
 217 
     | 
    
         
            +
                        character. Returns an empty list if there are no operations.
         
     | 
| 
      
 218 
     | 
    
         
            +
                    """
         
     | 
| 
      
 219 
     | 
    
         
            +
                    lines: list[str] = []
         
     | 
| 
      
 220 
     | 
    
         
            +
             
     | 
| 
      
 221 
     | 
    
         
            +
                    for key, stats in sorted(groups.items()):
         
     | 
| 
      
 222 
     | 
    
         
            +
                        # Don't display empty groups.
         
     | 
| 
      
 223 
     | 
    
         
            +
                        if not stats.operations:
         
     | 
| 
      
 224 
     | 
    
         
            +
                            continue
         
     | 
| 
      
 225 
     | 
    
         
            +
             
     | 
| 
      
 226 
     | 
    
         
            +
                        # Ensure enough space left for the group header and at least
         
     | 
| 
      
 227 
     | 
    
         
            +
                        # one line of content.
         
     | 
| 
      
 228 
     | 
    
         
            +
                        remaining_lines = self._max_lines - len(lines)
         
     | 
| 
      
 229 
     | 
    
         
            +
                        if remaining_lines < 2:
         
     | 
| 
      
 230 
     | 
    
         
            +
                            break
         
     | 
| 
      
 231 
     | 
    
         
            +
             
     | 
| 
      
 232 
     | 
    
         
            +
                        # Group header.
         
     | 
| 
      
 233 
     | 
    
         
            +
                        lines.append(key)
         
     | 
| 
       167 
234 
     | 
    
         | 
| 
       168 
     | 
    
         
            -
             
     | 
| 
      
 235 
     | 
    
         
            +
                        # Group content.
         
     | 
| 
      
 236 
     | 
    
         
            +
                        stats_lines = _OperationStatsPrinter(
         
     | 
| 
      
 237 
     | 
    
         
            +
                            printer=self._printer,
         
     | 
| 
      
 238 
     | 
    
         
            +
                            max_lines=remaining_lines - 1,  # minus one for the header
         
     | 
| 
      
 239 
     | 
    
         
            +
                            loading_symbol=self._loading_symbol,
         
     | 
| 
      
 240 
     | 
    
         
            +
                        ).render(stats)
         
     | 
| 
      
 241 
     | 
    
         
            +
                        for line in stats_lines:
         
     | 
| 
      
 242 
     | 
    
         
            +
                            lines.append(f"{_INDENT}{line}")
         
     | 
| 
      
 243 
     | 
    
         
            +
             
     | 
| 
      
 244 
     | 
    
         
            +
                    return lines
         
     | 
| 
      
 245 
     | 
    
         
            +
             
     | 
| 
      
 246 
     | 
    
         
            +
             
     | 
| 
      
 247 
     | 
    
         
            +
            class _OperationStatsPrinter:
         
     | 
| 
      
 248 
     | 
    
         
            +
                """Renders operation stats into lines of text."""
         
     | 
| 
      
 249 
     | 
    
         
            +
             
     | 
| 
      
 250 
     | 
    
         
            +
                def __init__(
         
     | 
| 
       169 
251 
     | 
    
         
             
                    self,
         
     | 
| 
       170 
     | 
    
         
            -
                     
     | 
| 
      
 252 
     | 
    
         
            +
                    printer: p.Printer,
         
     | 
| 
      
 253 
     | 
    
         
            +
                    max_lines: int,
         
     | 
| 
      
 254 
     | 
    
         
            +
                    loading_symbol: str,
         
     | 
| 
       171 
255 
     | 
    
         
             
                ) -> None:
         
     | 
| 
       172 
     | 
    
         
            -
                     
     | 
| 
       173 
     | 
    
         
            -
                     
     | 
| 
       174 
     | 
    
         
            -
                     
     | 
| 
       175 
     | 
    
         
            -
             
     | 
| 
       176 
     | 
    
         
            -
             
     | 
| 
       177 
     | 
    
         
            -
             
     | 
| 
      
 256 
     | 
    
         
            +
                    self._printer = printer
         
     | 
| 
      
 257 
     | 
    
         
            +
                    self._max_lines = max_lines
         
     | 
| 
      
 258 
     | 
    
         
            +
                    self._loading_symbol = loading_symbol
         
     | 
| 
      
 259 
     | 
    
         
            +
             
     | 
| 
      
 260 
     | 
    
         
            +
                    self._lines: list[str] = []
         
     | 
| 
      
 261 
     | 
    
         
            +
                    self._ops_shown = 0
         
     | 
| 
      
 262 
     | 
    
         
            +
             
     | 
| 
      
 263 
     | 
    
         
            +
                def render(self, stats: pb.OperationStats) -> list[str]:
         
     | 
| 
      
 264 
     | 
    
         
            +
                    """Convert the stats into a list of lines to display.
         
     | 
| 
      
 265 
     | 
    
         
            +
             
     | 
| 
      
 266 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 267 
     | 
    
         
            +
                        stats: Collection of operations to display.
         
     | 
| 
       178 
268 
     | 
    
         | 
| 
       179 
     | 
    
         
            -
                     
     | 
| 
      
 269 
     | 
    
         
            +
                    Returns:
         
     | 
| 
      
 270 
     | 
    
         
            +
                        The lines of text to print. The lines do not end with the newline
         
     | 
| 
      
 271 
     | 
    
         
            +
                        character. Returns an empty list if there are no operations.
         
     | 
| 
      
 272 
     | 
    
         
            +
                    """
         
     | 
| 
      
 273 
     | 
    
         
            +
                    for op in stats.operations:
         
     | 
| 
      
 274 
     | 
    
         
            +
                        self._add_operation(op, is_subtask=False, indent="")
         
     | 
| 
      
 275 
     | 
    
         
            +
             
     | 
| 
      
 276 
     | 
    
         
            +
                    if self._ops_shown < stats.total_operations:
         
     | 
| 
       180 
277 
     | 
    
         
             
                        if 1 <= self._max_lines <= len(self._lines):
         
     | 
| 
      
 278 
     | 
    
         
            +
                            self._ops_shown -= 1
         
     | 
| 
       181 
279 
     | 
    
         
             
                            self._lines.pop()
         
     | 
| 
       182 
280 
     | 
    
         | 
| 
       183 
     | 
    
         
            -
                        remaining = total_operations - self._ops_shown
         
     | 
| 
      
 281 
     | 
    
         
            +
                        remaining = stats.total_operations - self._ops_shown
         
     | 
| 
       184 
282 
     | 
    
         | 
| 
       185 
283 
     | 
    
         
             
                        self._lines.append(f"+ {remaining} more task(s)")
         
     | 
| 
       186 
284 
     | 
    
         | 
| 
       187 
     | 
    
         
            -
                     
     | 
| 
       188 
     | 
    
         
            -
                        if self._loading_symbol:
         
     | 
| 
       189 
     | 
    
         
            -
                            self._text_area.set_text(f"{self._loading_symbol} {self._default_text}")
         
     | 
| 
       190 
     | 
    
         
            -
                        else:
         
     | 
| 
       191 
     | 
    
         
            -
                            self._text_area.set_text(self._default_text)
         
     | 
| 
       192 
     | 
    
         
            -
                    else:
         
     | 
| 
       193 
     | 
    
         
            -
                        self._text_area.set_text("\n".join(self._lines))
         
     | 
| 
      
 285 
     | 
    
         
            +
                    return self._lines
         
     | 
| 
       194 
286 
     | 
    
         | 
| 
       195 
287 
     | 
    
         
             
                def _add_operation(self, op: pb.Operation, is_subtask: bool, indent: str) -> None:
         
     | 
| 
       196 
288 
     | 
    
         
             
                    """Add the operation to `self._lines`."""
         
     | 
| 
         @@ -200,14 +292,17 @@ class _DynamicOperationStatsPrinter: 
     | 
|
| 
       200 
292 
     | 
    
         
             
                    if not is_subtask:
         
     | 
| 
       201 
293 
     | 
    
         
             
                        self._ops_shown += 1
         
     | 
| 
       202 
294 
     | 
    
         | 
| 
       203 
     | 
    
         
            -
                     
     | 
| 
      
 295 
     | 
    
         
            +
                    status_indent_level = 0  # alignment for the status message, if any
         
     | 
| 
      
 296 
     | 
    
         
            +
                    parts: list[str] = []
         
     | 
| 
       204 
297 
     | 
    
         | 
| 
       205 
298 
     | 
    
         
             
                    # Subtask indicator.
         
     | 
| 
       206 
299 
     | 
    
         
             
                    if is_subtask and self._printer.supports_unicode:
         
     | 
| 
      
 300 
     | 
    
         
            +
                        status_indent_level += 2  # +1 for space
         
     | 
| 
       207 
301 
     | 
    
         
             
                        parts.append("↳")
         
     | 
| 
       208 
302 
     | 
    
         | 
| 
       209 
303 
     | 
    
         
             
                    # Loading symbol.
         
     | 
| 
       210 
304 
     | 
    
         
             
                    if self._loading_symbol:
         
     | 
| 
      
 305 
     | 
    
         
            +
                        status_indent_level += 2  # +1 for space
         
     | 
| 
       211 
306 
     | 
    
         
             
                        parts.append(self._loading_symbol)
         
     | 
| 
       212 
307 
     | 
    
         | 
| 
       213 
308 
     | 
    
         
             
                    # Task name.
         
     | 
| 
         @@ -225,14 +320,14 @@ class _DynamicOperationStatsPrinter: 
     | 
|
| 
       225 
320 
     | 
    
         
             
                    if op.error_status:
         
     | 
| 
       226 
321 
     | 
    
         
             
                        error_word = self._printer.error("ERROR")
         
     | 
| 
       227 
322 
     | 
    
         
             
                        error_desc = self._printer.secondary_text(op.error_status)
         
     | 
| 
       228 
     | 
    
         
            -
                         
     | 
| 
      
 323 
     | 
    
         
            +
                        status_indent = " " * status_indent_level
         
     | 
| 
       229 
324 
     | 
    
         
             
                        self._lines.append(
         
     | 
| 
       230 
     | 
    
         
            -
                            f"{indent}{ 
     | 
| 
      
 325 
     | 
    
         
            +
                            f"{indent}{status_indent}{error_word} {error_desc}",
         
     | 
| 
       231 
326 
     | 
    
         
             
                        )
         
     | 
| 
       232 
327 
     | 
    
         | 
| 
       233 
328 
     | 
    
         
             
                    # Subtasks.
         
     | 
| 
       234 
329 
     | 
    
         
             
                    if op.subtasks:
         
     | 
| 
       235 
     | 
    
         
            -
                        subtask_indent = indent +  
     | 
| 
      
 330 
     | 
    
         
            +
                        subtask_indent = indent + _INDENT
         
     | 
| 
       236 
331 
     | 
    
         
             
                        for task in op.subtasks:
         
     | 
| 
       237 
332 
     | 
    
         
             
                            self._add_operation(
         
     | 
| 
       238 
333 
     | 
    
         
             
                                task,
         
     | 
    
        wandb/sdk/lib/retry.py
    CHANGED
    
    | 
         @@ -77,9 +77,10 @@ class Retry(Generic[_R]): 
     | 
|
| 
       77 
77 
     | 
    
         
             
                        self._retryable_exceptions = retryable_exceptions
         
     | 
| 
       78 
78 
     | 
    
         
             
                    else:
         
     | 
| 
       79 
79 
     | 
    
         
             
                        self._retryable_exceptions = (TransientError,)
         
     | 
| 
       80 
     | 
    
         
            -
                    self._index = 0
         
     | 
| 
       81 
80 
     | 
    
         
             
                    self.retry_callback = retry_callback
         
     | 
| 
       82 
81 
     | 
    
         | 
| 
      
 82 
     | 
    
         
            +
                    self._num_iter = 0
         
     | 
| 
      
 83 
     | 
    
         
            +
             
     | 
| 
       83 
84 
     | 
    
         
             
                def _sleep_check_cancelled(
         
     | 
| 
       84 
85 
     | 
    
         
             
                    self, wait_seconds: float, cancel_event: Optional[threading.Event]
         
     | 
| 
       85 
86 
     | 
    
         
             
                ) -> bool:
         
     | 
| 
         @@ -194,7 +195,7 @@ class Retry(Generic[_R]): 
     | 
|
| 
       194 
195 
     | 
    
         
             
                    else:
         
     | 
| 
       195 
196 
     | 
    
         
             
                        wandb.termlog(
         
     | 
| 
       196 
197 
     | 
    
         
             
                            f"{self._error_prefix}"
         
     | 
| 
       197 
     | 
    
         
            -
                            f" ({exception.__class__.__name__}), entering retry loop."
         
     | 
| 
      
 198 
     | 
    
         
            +
                            + f" ({exception.__class__.__name__}), entering retry loop."
         
     | 
| 
       198 
199 
     | 
    
         
             
                        )
         
     | 
| 
       199 
200 
     | 
    
         | 
| 
       200 
201 
     | 
    
         
             
                def _print_recovered(self, start_time: datetime.datetime) -> None:
         
     | 
| 
         @@ -190,8 +190,8 @@ class ServiceConnection: 
     | 
|
| 
       190 
190 
     | 
    
         
             
                    except TimeoutError:
         
     | 
| 
       191 
191 
     | 
    
         
             
                        raise WandbAttachFailedError(
         
     | 
| 
       192 
192 
     | 
    
         
             
                            "Failed to attach because the run does not belong to"
         
     | 
| 
       193 
     | 
    
         
            -
                            " the current service process, or because the service"
         
     | 
| 
       194 
     | 
    
         
            -
                            " process is busy (unlikely)."
         
     | 
| 
      
 193 
     | 
    
         
            +
                            + " the current service process, or because the service"
         
     | 
| 
      
 194 
     | 
    
         
            +
                            + " process is busy (unlikely)."
         
     | 
| 
       195 
195 
     | 
    
         
             
                        ) from None
         
     | 
| 
       196 
196 
     | 
    
         | 
| 
       197 
197 
     | 
    
         
             
                    else:
         
     | 
    
        wandb/sdk/lib/wb_logging.py
    CHANGED
    
    | 
         @@ -136,7 +136,7 @@ def add_file_handler(run_id: str, filepath: pathlib.Path) -> logging.Handler: 
     | 
|
| 
       136 
136 
     | 
    
         
             
                return handler
         
     | 
| 
       137 
137 
     | 
    
         | 
| 
       138 
138 
     | 
    
         | 
| 
       139 
     | 
    
         
            -
            class _RunIDFilter 
     | 
| 
      
 139 
     | 
    
         
            +
            class _RunIDFilter:
         
     | 
| 
       140 
140 
     | 
    
         
             
                """Filters out messages logged for a different run."""
         
     | 
| 
       141 
141 
     | 
    
         | 
| 
       142 
142 
     | 
    
         
             
                def __init__(self, run_id: str) -> None:
         
     | 
| 
         @@ -148,6 +148,7 @@ class _RunIDFilter(logging.Filter): 
     | 
|
| 
       148 
148 
     | 
    
         
             
                    self._run_id = run_id
         
     | 
| 
       149 
149 
     | 
    
         | 
| 
       150 
150 
     | 
    
         
             
                def filter(self, record: logging.LogRecord) -> bool:
         
     | 
| 
      
 151 
     | 
    
         
            +
                    """Modify a log record and return whether it matches the run."""
         
     | 
| 
       151 
152 
     | 
    
         
             
                    run_id = _run_id.get()
         
     | 
| 
       152 
153 
     | 
    
         | 
| 
       153 
154 
     | 
    
         
             
                    if run_id is None:
         
     | 
    
        wandb/sdk/mailbox/mailbox.py
    CHANGED
    
    
    
        wandb/sdk/wandb_init.py
    CHANGED
    
    | 
         @@ -12,6 +12,7 @@ from __future__ import annotations 
     | 
|
| 
       12 
12 
     | 
    
         | 
| 
       13 
13 
     | 
    
         
             
            import contextlib
         
     | 
| 
       14 
14 
     | 
    
         
             
            import dataclasses
         
     | 
| 
      
 15 
     | 
    
         
            +
            import functools
         
     | 
| 
       15 
16 
     | 
    
         
             
            import json
         
     | 
| 
       16 
17 
     | 
    
         
             
            import logging
         
     | 
| 
       17 
18 
     | 
    
         
             
            import os
         
     | 
| 
         @@ -988,25 +989,21 @@ class _WandbInit: 
     | 
|
| 
       988 
989 
     | 
    
         | 
| 
       989 
990 
     | 
    
         
             
                    run_init_handle = backend.interface.deliver_run(run)
         
     | 
| 
       990 
991 
     | 
    
         | 
| 
       991 
     | 
    
         
            -
                     
     | 
| 
       992 
     | 
    
         
            -
                        assert backend.interface
         
     | 
| 
       993 
     | 
    
         
            -
             
     | 
| 
      
 992 
     | 
    
         
            +
                    try:
         
     | 
| 
       994 
993 
     | 
    
         
             
                        with progress.progress_printer(
         
     | 
| 
       995 
994 
     | 
    
         
             
                            run_printer,
         
     | 
| 
       996 
995 
     | 
    
         
             
                            default_text="Waiting for wandb.init()...",
         
     | 
| 
       997 
996 
     | 
    
         
             
                        ) as progress_printer:
         
     | 
| 
       998 
     | 
    
         
            -
                             
     | 
| 
       999 
     | 
    
         
            -
                                 
     | 
| 
       1000 
     | 
    
         
            -
                                 
     | 
| 
      
 997 
     | 
    
         
            +
                            result = wait_with_progress(
         
     | 
| 
      
 998 
     | 
    
         
            +
                                run_init_handle,
         
     | 
| 
      
 999 
     | 
    
         
            +
                                timeout=timeout,
         
     | 
| 
      
 1000 
     | 
    
         
            +
                                display_progress=functools.partial(
         
     | 
| 
      
 1001 
     | 
    
         
            +
                                    progress.loop_printing_operation_stats,
         
     | 
| 
      
 1002 
     | 
    
         
            +
                                    progress_printer,
         
     | 
| 
      
 1003 
     | 
    
         
            +
                                    backend.interface,
         
     | 
| 
      
 1004 
     | 
    
         
            +
                                ),
         
     | 
| 
       1001 
1005 
     | 
    
         
             
                            )
         
     | 
| 
       1002 
1006 
     | 
    
         | 
| 
       1003 
     | 
    
         
            -
                    try:
         
     | 
| 
       1004 
     | 
    
         
            -
                        result = wait_with_progress(
         
     | 
| 
       1005 
     | 
    
         
            -
                            run_init_handle,
         
     | 
| 
       1006 
     | 
    
         
            -
                            timeout=timeout,
         
     | 
| 
       1007 
     | 
    
         
            -
                            display_progress=display_init_message,
         
     | 
| 
       1008 
     | 
    
         
            -
                        )
         
     | 
| 
       1009 
     | 
    
         
            -
             
     | 
| 
       1010 
1007 
     | 
    
         
             
                    except TimeoutError:
         
     | 
| 
       1011 
1008 
     | 
    
         
             
                        run_init_handle.cancel(backend.interface)
         
     | 
| 
       1012 
1009 
     | 
    
         |