wandb 0.18.3__py3-none-win32.whl → 0.18.4__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +16 -7
- wandb/__init__.pyi +96 -63
- wandb/analytics/sentry.py +91 -88
- wandb/apis/public/api.py +18 -4
- wandb/apis/public/runs.py +53 -2
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +178 -0
- wandb/cli/cli.py +5 -171
- wandb/data_types.py +3 -0
- wandb/env.py +74 -73
- wandb/errors/term.py +300 -43
- wandb/proto/v3/wandb_internal_pb2.py +263 -223
- wandb/proto/v3/wandb_server_pb2.py +57 -37
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_internal_pb2.py +226 -218
- wandb/proto/v4/wandb_server_pb2.py +41 -37
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_internal_pb2.py +226 -218
- wandb/proto/v5/wandb_server_pb2.py +41 -37
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/sdk/__init__.py +3 -3
- wandb/sdk/artifacts/_validators.py +41 -8
- wandb/sdk/artifacts/artifact.py +32 -1
- wandb/sdk/artifacts/artifact_file_cache.py +1 -2
- wandb/sdk/data_types/_dtypes.py +7 -3
- wandb/sdk/data_types/video.py +15 -6
- wandb/sdk/interface/interface.py +2 -0
- wandb/sdk/internal/internal_api.py +122 -5
- wandb/sdk/internal/sender.py +16 -3
- wandb/sdk/launch/inputs/internal.py +1 -1
- wandb/sdk/lib/module.py +12 -0
- wandb/sdk/lib/printer.py +291 -105
- wandb/sdk/lib/progress.py +274 -0
- wandb/sdk/service/streams.py +21 -11
- wandb/sdk/wandb_init.py +58 -54
- wandb/sdk/wandb_run.py +380 -454
- wandb/sdk/wandb_settings.py +2 -0
- wandb/sdk/wandb_watch.py +17 -11
- wandb/util.py +6 -2
- {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/METADATA +4 -3
- {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/RECORD +45 -43
- wandb/bin/nvidia_gpu_stats.exe +0 -0
- {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/WHEEL +0 -0
- {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.3.dist-info → wandb-0.18.4.dist-info}/licenses/LICENSE +0 -0
    
        wandb/__init__.py
    CHANGED
    
    | @@ -8,9 +8,10 @@ For scripts and interactive notebooks, see https://github.com/wandb/examples. | |
| 8 8 |  | 
| 9 9 | 
             
            For reference documentation, see https://docs.wandb.com/ref/python.
         | 
| 10 10 | 
             
            """
         | 
| 11 | 
            -
             | 
| 11 | 
            +
            from __future__ import annotations
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            __version__ = "0.18.4"
         | 
| 12 14 |  | 
| 13 | 
            -
            from typing import Optional
         | 
| 14 15 |  | 
| 15 16 | 
             
            from wandb.errors import Error
         | 
| 16 17 |  | 
| @@ -28,8 +29,6 @@ setup = wandb_sdk.setup | |
| 28 29 | 
             
            _attach = wandb_sdk._attach
         | 
| 29 30 | 
             
            _sync = wandb_sdk._sync
         | 
| 30 31 | 
             
            _teardown = wandb_sdk.teardown
         | 
| 31 | 
            -
            watch = wandb_sdk.watch
         | 
| 32 | 
            -
            unwatch = wandb_sdk.unwatch
         | 
| 33 32 | 
             
            finish = wandb_sdk.finish
         | 
| 34 33 | 
             
            join = finish
         | 
| 35 34 | 
             
            login = wandb_sdk.login
         | 
| @@ -112,10 +111,12 @@ def _assert_is_user_process(): | |
| 112 111 | 
             
            # globals
         | 
| 113 112 | 
             
            Api = PublicApi
         | 
| 114 113 | 
             
            api = InternalApi()
         | 
| 115 | 
            -
            run:  | 
| 114 | 
            +
            run: wandb_sdk.wandb_run.Run | None = None
         | 
| 116 115 | 
             
            config = _preinit.PreInitObject("wandb.config", wandb_sdk.wandb_config.Config)
         | 
| 117 116 | 
             
            summary = _preinit.PreInitObject("wandb.summary", wandb_sdk.wandb_summary.Summary)
         | 
| 118 117 | 
             
            log = _preinit.PreInitCallable("wandb.log", wandb_sdk.wandb_run.Run.log)  # type: ignore
         | 
| 118 | 
            +
            watch = _preinit.PreInitCallable("wandb.watch", wandb_sdk.wandb_run.Run.watch)  # type: ignore
         | 
| 119 | 
            +
            unwatch = _preinit.PreInitCallable("wandb.unwatch", wandb_sdk.wandb_run.Run.unwatch)  # type: ignore
         | 
| 119 120 | 
             
            save = _preinit.PreInitCallable("wandb.save", wandb_sdk.wandb_run.Run.save)  # type: ignore
         | 
| 120 121 | 
             
            restore = wandb_sdk.wandb_run.restore
         | 
| 121 122 | 
             
            use_artifact = _preinit.PreInitCallable(
         | 
| @@ -200,9 +201,16 @@ if "dev" in __version__: | |
| 200 201 | 
             
                import wandb.env
         | 
| 201 202 | 
             
                import os
         | 
| 202 203 |  | 
| 203 | 
            -
                #  | 
| 204 | 
            +
                # Disable error reporting in dev versions.
         | 
| 204 205 | 
             
                os.environ[wandb.env.ERROR_REPORTING] = os.environ.get(
         | 
| 205 | 
            -
                    wandb.env.ERROR_REPORTING, | 
| 206 | 
            +
                    wandb.env.ERROR_REPORTING,
         | 
| 207 | 
            +
                    "false",
         | 
| 208 | 
            +
                )
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                # Enable new features in dev versions.
         | 
| 211 | 
            +
                os.environ["WANDB__SHOW_OPERATION_STATS"] = os.environ.get(
         | 
| 212 | 
            +
                    "WANDB__SHOW_OPERATION_STATS",
         | 
| 213 | 
            +
                    "true",
         | 
| 206 214 | 
             
                )
         | 
| 207 215 |  | 
| 208 216 | 
             
            _sentry = _Sentry()
         | 
| @@ -242,4 +250,5 @@ __all__ = ( | |
| 242 250 | 
             
                "link_model",
         | 
| 243 251 | 
             
                "define_metric",
         | 
| 244 252 | 
             
                "watch",
         | 
| 253 | 
            +
                "unwatch",
         | 
| 245 254 | 
             
            )
         | 
    
        wandb/__init__.pyi
    CHANGED
    
    | @@ -52,6 +52,7 @@ __all__ = ( | |
| 52 52 | 
             
                "Settings",
         | 
| 53 53 | 
             
                "teardown",
         | 
| 54 54 | 
             
                "watch",
         | 
| 55 | 
            +
                "unwatch",
         | 
| 55 56 | 
             
            )
         | 
| 56 57 |  | 
| 57 58 | 
             
            import os
         | 
| @@ -95,12 +96,14 @@ from wandb.wandb_controller import _WandbController | |
| 95 96 | 
             
            if TYPE_CHECKING:
         | 
| 96 97 | 
             
                import torch  # type: ignore [import-not-found]
         | 
| 97 98 |  | 
| 98 | 
            -
             | 
| 99 | 
            +
                from wandb.plot.viz import CustomChart
         | 
| 100 | 
            +
             | 
| 101 | 
            +
            __version__: str = "0.18.4"
         | 
| 99 102 |  | 
| 100 103 | 
             
            run: Run | None
         | 
| 101 104 | 
             
            config: wandb_config.Config
         | 
| 102 105 | 
             
            summary: wandb_summary.Summary
         | 
| 103 | 
            -
            Api: PublicApi
         | 
| 106 | 
            +
            Api: type[PublicApi]
         | 
| 104 107 |  | 
| 105 108 | 
             
            # private attributes
         | 
| 106 109 | 
             
            _sentry: Sentry
         | 
| @@ -181,32 +184,32 @@ def teardown(exit_code: Optional[int] = None) -> None: | |
| 181 184 | 
             
                ...
         | 
| 182 185 |  | 
| 183 186 | 
             
            def init(
         | 
| 184 | 
            -
                job_type:  | 
| 185 | 
            -
                dir:  | 
| 186 | 
            -
                config:  | 
| 187 | 
            -
                project:  | 
| 188 | 
            -
                entity:  | 
| 189 | 
            -
                reinit:  | 
| 190 | 
            -
                tags:  | 
| 191 | 
            -
                group:  | 
| 192 | 
            -
                name:  | 
| 193 | 
            -
                notes:  | 
| 194 | 
            -
                magic:  | 
| 195 | 
            -
                config_exclude_keys:  | 
| 196 | 
            -
                config_include_keys:  | 
| 197 | 
            -
                anonymous:  | 
| 198 | 
            -
                mode:  | 
| 199 | 
            -
                allow_val_change:  | 
| 200 | 
            -
                resume:  | 
| 201 | 
            -
                force:  | 
| 202 | 
            -
                tensorboard:  | 
| 203 | 
            -
                sync_tensorboard:  | 
| 204 | 
            -
                monitor_gym:  | 
| 205 | 
            -
                save_code:  | 
| 206 | 
            -
                id:  | 
| 207 | 
            -
                fork_from:  | 
| 208 | 
            -
                resume_from:  | 
| 209 | 
            -
                settings:  | 
| 187 | 
            +
                job_type: str | None = None,
         | 
| 188 | 
            +
                dir: StrPath | None = None,
         | 
| 189 | 
            +
                config: dict | str | None = None,
         | 
| 190 | 
            +
                project: str | None = None,
         | 
| 191 | 
            +
                entity: str | None = None,
         | 
| 192 | 
            +
                reinit: bool | None = None,
         | 
| 193 | 
            +
                tags: Sequence | None = None,
         | 
| 194 | 
            +
                group: str | None = None,
         | 
| 195 | 
            +
                name: str | None = None,
         | 
| 196 | 
            +
                notes: str | None = None,
         | 
| 197 | 
            +
                magic: dict | str | bool | None = None,
         | 
| 198 | 
            +
                config_exclude_keys: list[str] | None = None,
         | 
| 199 | 
            +
                config_include_keys: list[str] | None = None,
         | 
| 200 | 
            +
                anonymous: str | None = None,
         | 
| 201 | 
            +
                mode: str | None = None,
         | 
| 202 | 
            +
                allow_val_change: bool | None = None,
         | 
| 203 | 
            +
                resume: bool | str | None = None,
         | 
| 204 | 
            +
                force: bool | None = None,
         | 
| 205 | 
            +
                tensorboard: bool | None = None,  # alias for sync_tensorboard
         | 
| 206 | 
            +
                sync_tensorboard: bool | None = None,
         | 
| 207 | 
            +
                monitor_gym: bool | None = None,
         | 
| 208 | 
            +
                save_code: bool | None = None,
         | 
| 209 | 
            +
                id: str | None = None,
         | 
| 210 | 
            +
                fork_from: str | None = None,
         | 
| 211 | 
            +
                resume_from: str | None = None,
         | 
| 212 | 
            +
                settings: Settings | dict[str, Any] | None = None,
         | 
| 210 213 | 
             
            ) -> Run:
         | 
| 211 214 | 
             
                r"""Start a new run to track and log to W&B.
         | 
| 212 215 |  | 
| @@ -420,7 +423,7 @@ def init( | |
| 420 423 | 
             
                """
         | 
| 421 424 | 
             
                ...
         | 
| 422 425 |  | 
| 423 | 
            -
            def finish(exit_code:  | 
| 426 | 
            +
            def finish(exit_code: int | None = None, quiet: bool | None = None) -> None:
         | 
| 424 427 | 
             
                """Mark a run as finished, and finish uploading all data.
         | 
| 425 428 |  | 
| 426 429 | 
             
                This is used when creating multiple runs in the same process.
         | 
| @@ -470,10 +473,10 @@ def login( | |
| 470 473 | 
             
                ...
         | 
| 471 474 |  | 
| 472 475 | 
             
            def log(
         | 
| 473 | 
            -
                data:  | 
| 474 | 
            -
                step:  | 
| 475 | 
            -
                commit:  | 
| 476 | 
            -
                sync:  | 
| 476 | 
            +
                data: dict[str, Any],
         | 
| 477 | 
            +
                step: int | None = None,
         | 
| 478 | 
            +
                commit: bool | None = None,
         | 
| 479 | 
            +
                sync: bool | None = None,
         | 
| 477 480 | 
             
            ) -> None:
         | 
| 478 481 | 
             
                """Upload run data.
         | 
| 479 482 |  | 
| @@ -704,10 +707,10 @@ def log( | |
| 704 707 | 
             
                ...
         | 
| 705 708 |  | 
| 706 709 | 
             
            def save(
         | 
| 707 | 
            -
                glob_str:  | 
| 708 | 
            -
                base_path:  | 
| 710 | 
            +
                glob_str: str | os.PathLike | None = None,
         | 
| 711 | 
            +
                base_path: str | os.PathLike | None = None,
         | 
| 709 712 | 
             
                policy: PolicyName = "live",
         | 
| 710 | 
            -
            ) ->  | 
| 713 | 
            +
            ) -> bool | list[str]:
         | 
| 711 714 | 
             
                """Sync one or more files to W&B.
         | 
| 712 715 |  | 
| 713 716 | 
             
                Relative paths are relative to the current working directory.
         | 
| @@ -846,12 +849,12 @@ def agent( | |
| 846 849 |  | 
| 847 850 | 
             
            def define_metric(
         | 
| 848 851 | 
             
                name: str,
         | 
| 849 | 
            -
                step_metric:  | 
| 850 | 
            -
                step_sync:  | 
| 851 | 
            -
                hidden:  | 
| 852 | 
            -
                summary:  | 
| 853 | 
            -
                goal:  | 
| 854 | 
            -
                overwrite:  | 
| 852 | 
            +
                step_metric: str | wandb_metric.Metric | None = None,
         | 
| 853 | 
            +
                step_sync: bool | None = None,
         | 
| 854 | 
            +
                hidden: bool | None = None,
         | 
| 855 | 
            +
                summary: str | None = None,
         | 
| 856 | 
            +
                goal: str | None = None,
         | 
| 857 | 
            +
                overwrite: bool | None = None,
         | 
| 855 858 | 
             
            ) -> wandb_metric.Metric:
         | 
| 856 859 | 
             
                """Customize metrics logged with `wandb.log()`.
         | 
| 857 860 |  | 
| @@ -882,11 +885,11 @@ def define_metric( | |
| 882 885 | 
             
                ...
         | 
| 883 886 |  | 
| 884 887 | 
             
            def log_artifact(
         | 
| 885 | 
            -
                artifact_or_path:  | 
| 886 | 
            -
                name:  | 
| 887 | 
            -
                type:  | 
| 888 | 
            -
                aliases:  | 
| 889 | 
            -
                tags:  | 
| 888 | 
            +
                artifact_or_path: Artifact | StrPath,
         | 
| 889 | 
            +
                name: str | None = None,
         | 
| 890 | 
            +
                type: str | None = None,
         | 
| 891 | 
            +
                aliases: list[str] | None = None,
         | 
| 892 | 
            +
                tags: list[str] | None = None,
         | 
| 890 893 | 
             
            ) -> Artifact:
         | 
| 891 894 | 
             
                """Declare an artifact as an output of a run.
         | 
| 892 895 |  | 
| @@ -915,10 +918,10 @@ def log_artifact( | |
| 915 918 | 
             
                ...
         | 
| 916 919 |  | 
| 917 920 | 
             
            def use_artifact(
         | 
| 918 | 
            -
                artifact_or_name:  | 
| 919 | 
            -
                type:  | 
| 920 | 
            -
                aliases:  | 
| 921 | 
            -
                use_as:  | 
| 921 | 
            +
                artifact_or_name: str | Artifact,
         | 
| 922 | 
            +
                type: str | None = None,
         | 
| 923 | 
            +
                aliases: list[str] | None = None,
         | 
| 924 | 
            +
                use_as: str | None = None,
         | 
| 922 925 | 
             
            ) -> Artifact:
         | 
| 923 926 | 
             
                """Declare an artifact as an input to a run.
         | 
| 924 927 |  | 
| @@ -926,8 +929,9 @@ def use_artifact( | |
| 926 929 |  | 
| 927 930 | 
             
                Arguments:
         | 
| 928 931 | 
             
                    artifact_or_name: (str or Artifact) An artifact name.
         | 
| 929 | 
            -
                        May be prefixed with entity/project/. | 
| 930 | 
            -
                         | 
| 932 | 
            +
                        May be prefixed with project/ or entity/project/.
         | 
| 933 | 
            +
                        If no entity is specified in the name, the Run or API setting's entity is used.
         | 
| 934 | 
            +
                        Valid names can be in the following forms:
         | 
| 931 935 | 
             
                            - name:version
         | 
| 932 936 | 
             
                            - name:alias
         | 
| 933 937 | 
             
                        You can also pass an Artifact object created by calling `wandb.Artifact`
         | 
| @@ -943,8 +947,8 @@ def use_artifact( | |
| 943 947 |  | 
| 944 948 | 
             
            def log_model(
         | 
| 945 949 | 
             
                path: StrPath,
         | 
| 946 | 
            -
                name:  | 
| 947 | 
            -
                aliases:  | 
| 950 | 
            +
                name: str | None = None,
         | 
| 951 | 
            +
                aliases: list[str] | None = None,
         | 
| 948 952 | 
             
            ) -> None:
         | 
| 949 953 | 
             
                """Logs a model artifact containing the contents inside the 'path' to a run and marks it as an output to this run.
         | 
| 950 954 |  | 
| @@ -1031,8 +1035,8 @@ def use_model(name: str) -> FilePathStr: | |
| 1031 1035 | 
             
            def link_model(
         | 
| 1032 1036 | 
             
                path: StrPath,
         | 
| 1033 1037 | 
             
                registered_model_name: str,
         | 
| 1034 | 
            -
                name:  | 
| 1035 | 
            -
                aliases:  | 
| 1038 | 
            +
                name: str | None = None,
         | 
| 1039 | 
            +
                aliases: list[str] | None = None,
         | 
| 1036 1040 | 
             
            ) -> None:
         | 
| 1037 1041 | 
             
                """Log a model artifact version and link it to a registered model in the model registry.
         | 
| 1038 1042 |  | 
| @@ -1099,6 +1103,28 @@ def link_model( | |
| 1099 1103 | 
             
                """
         | 
| 1100 1104 | 
             
                ...
         | 
| 1101 1105 |  | 
| 1106 | 
            +
            def plot_table(
         | 
| 1107 | 
            +
                vega_spec_name: str,
         | 
| 1108 | 
            +
                data_table: Table,
         | 
| 1109 | 
            +
                fields: dict[str, Any],
         | 
| 1110 | 
            +
                string_fields: dict[str, Any] | None = None,
         | 
| 1111 | 
            +
                split_table: bool | None = False,
         | 
| 1112 | 
            +
            ) -> CustomChart:
         | 
| 1113 | 
            +
                """Create a custom plot on a table.
         | 
| 1114 | 
            +
             | 
| 1115 | 
            +
                Arguments:
         | 
| 1116 | 
            +
                    vega_spec_name: the name of the spec for the plot
         | 
| 1117 | 
            +
                    data_table: a wandb.Table object containing the data to
         | 
| 1118 | 
            +
                        be used on the visualization
         | 
| 1119 | 
            +
                    fields: a dict mapping from table keys to fields that the custom
         | 
| 1120 | 
            +
                        visualization needs
         | 
| 1121 | 
            +
                    string_fields: a dict that provides values for any string constants
         | 
| 1122 | 
            +
                        the custom visualization needs
         | 
| 1123 | 
            +
                    split_table: a boolean that indicates whether the table should be in
         | 
| 1124 | 
            +
                        a separate section in the UI
         | 
| 1125 | 
            +
                """
         | 
| 1126 | 
            +
                ...
         | 
| 1127 | 
            +
             | 
| 1102 1128 | 
             
            def watch(
         | 
| 1103 1129 | 
             
                models: torch.nn.Module | Sequence[torch.nn.Module],
         | 
| 1104 1130 | 
             
                criterion: torch.F | None = None,
         | 
| @@ -1106,7 +1132,7 @@ def watch( | |
| 1106 1132 | 
             
                log_freq: int = 1000,
         | 
| 1107 1133 | 
             
                idx: int | None = None,
         | 
| 1108 1134 | 
             
                log_graph: bool = False,
         | 
| 1109 | 
            -
            ) ->  | 
| 1135 | 
            +
            ) -> None:
         | 
| 1110 1136 | 
             
                """Hooks into the given PyTorch model(s) to monitor gradients and the model's computational graph.
         | 
| 1111 1137 |  | 
| 1112 1138 | 
             
                This function can track parameters, gradients, or both during training. It should be
         | 
| @@ -1124,16 +1150,23 @@ def watch( | |
| 1124 1150 | 
             
                        Frequency (in batches) to log gradients and parameters. (default=1000)
         | 
| 1125 1151 | 
             
                    idx (Optional[int]):
         | 
| 1126 1152 | 
             
                        Index used when tracking multiple models with `wandb.watch`. (default=None)
         | 
| 1127 | 
            -
             | 
| 1153 | 
            +
                    log_graph (bool):
         | 
| 1128 1154 | 
             
                        Whether to log the model's computational graph. (default=False)
         | 
| 1129 1155 |  | 
| 1130 | 
            -
                Returns:
         | 
| 1131 | 
            -
                    wandb.Graph:
         | 
| 1132 | 
            -
                        The graph object, which will be populated after the first backward pass.
         | 
| 1133 | 
            -
             | 
| 1134 1156 | 
             
                Raises:
         | 
| 1135 1157 | 
             
                    ValueError:
         | 
| 1136 1158 | 
             
                        If `wandb.init` has not been called or if any of the models are not instances
         | 
| 1137 1159 | 
             
                        of `torch.nn.Module`.
         | 
| 1138 1160 | 
             
                """
         | 
| 1139 1161 | 
             
                ...
         | 
| 1162 | 
            +
             | 
| 1163 | 
            +
            def unwatch(
         | 
| 1164 | 
            +
                models: torch.nn.Module | Sequence[torch.nn.Module] | None = None,
         | 
| 1165 | 
            +
            ) -> None:
         | 
| 1166 | 
            +
                """Remove pytorch model topology, gradient and parameter hooks.
         | 
| 1167 | 
            +
             | 
| 1168 | 
            +
                Args:
         | 
| 1169 | 
            +
                    models (torch.nn.Module | Sequence[torch.nn.Module]):
         | 
| 1170 | 
            +
                        Optional list of pytorch models that have had watch called on them
         | 
| 1171 | 
            +
                """
         | 
| 1172 | 
            +
                ...
         | 
    
        wandb/analytics/sentry.py
    CHANGED
    
    | @@ -1,3 +1,5 @@ | |
| 1 | 
            +
            from __future__ import annotations
         | 
| 2 | 
            +
             | 
| 1 3 | 
             
            __all__ = ("Sentry",)
         | 
| 2 4 |  | 
| 3 5 |  | 
| @@ -7,7 +9,7 @@ import os | |
| 7 9 | 
             
            import pathlib
         | 
| 8 10 | 
             
            import sys
         | 
| 9 11 | 
             
            from types import TracebackType
         | 
| 10 | 
            -
            from typing import TYPE_CHECKING, Any, Callable | 
| 12 | 
            +
            from typing import TYPE_CHECKING, Any, Callable
         | 
| 11 13 | 
             
            from urllib.parse import quote
         | 
| 12 14 |  | 
| 13 15 | 
             
            if sys.version_info >= (3, 8):
         | 
| @@ -16,6 +18,7 @@ else: | |
| 16 18 | 
             
                from typing_extensions import Literal
         | 
| 17 19 |  | 
| 18 20 | 
             
            import sentry_sdk  # type: ignore
         | 
| 21 | 
            +
            import sentry_sdk.scope  # type: ignore
         | 
| 19 22 | 
             
            import sentry_sdk.utils  # type: ignore
         | 
| 20 23 |  | 
| 21 24 | 
             
            import wandb
         | 
| @@ -36,7 +39,7 @@ def _safe_noop(func: Callable) -> Callable: | |
| 36 39 | 
             
                """Decorator to ensure that Sentry methods do nothing if disabled and don't raise."""
         | 
| 37 40 |  | 
| 38 41 | 
             
                @functools.wraps(func)
         | 
| 39 | 
            -
                def wrapper(self:  | 
| 42 | 
            +
                def wrapper(self: type[Sentry], *args: Any, **kwargs: Any) -> Any:
         | 
| 40 43 | 
             
                    if self._disabled:
         | 
| 41 44 | 
             
                        return None
         | 
| 42 45 | 
             
                    try:
         | 
| @@ -59,7 +62,7 @@ class Sentry: | |
| 59 62 |  | 
| 60 63 | 
             
                    self.dsn = os.environ.get(wandb.env.SENTRY_DSN, SENTRY_DEFAULT_DSN)
         | 
| 61 64 |  | 
| 62 | 
            -
                    self. | 
| 65 | 
            +
                    self.scope: sentry_sdk.scope.Scope | None = None
         | 
| 63 66 |  | 
| 64 67 | 
             
                    # ensure we always end the Sentry session
         | 
| 65 68 | 
             
                    atexit.register(self.end_session)
         | 
| @@ -87,47 +90,50 @@ class Sentry: | |
| 87 90 | 
             
                        environment=self.environment,
         | 
| 88 91 | 
             
                        release=wandb.__version__,
         | 
| 89 92 | 
             
                    )
         | 
| 90 | 
            -
                    self. | 
| 93 | 
            +
                    self.scope = sentry_sdk.get_global_scope().fork()
         | 
| 94 | 
            +
                    self.scope.clear()
         | 
| 95 | 
            +
                    self.scope.set_client(client)
         | 
| 91 96 |  | 
| 92 97 | 
             
                @_safe_noop
         | 
| 93 | 
            -
                def message(self, message: str, repeat: bool = True) -> None:
         | 
| 98 | 
            +
                def message(self, message: str, repeat: bool = True) -> str | None:
         | 
| 94 99 | 
             
                    """Send a message to Sentry."""
         | 
| 95 100 | 
             
                    if not repeat and message in self._sent_messages:
         | 
| 96 | 
            -
                        return
         | 
| 101 | 
            +
                        return None
         | 
| 97 102 | 
             
                    self._sent_messages.add(message)
         | 
| 98 | 
            -
                     | 
| 103 | 
            +
                    with sentry_sdk.scope.use_isolation_scope(self.scope):  # type: ignore
         | 
| 104 | 
            +
                        return sentry_sdk.capture_message(message)  # type: ignore
         | 
| 99 105 |  | 
| 100 106 | 
             
                @_safe_noop
         | 
| 101 107 | 
             
                def exception(
         | 
| 102 108 | 
             
                    self,
         | 
| 103 | 
            -
                    exc:  | 
| 104 | 
            -
             | 
| 105 | 
            -
             | 
| 106 | 
            -
                         | 
| 107 | 
            -
             | 
| 108 | 
            -
             | 
| 109 | 
            -
             | 
| 110 | 
            -
             | 
| 111 | 
            -
                        None,
         | 
| 112 | 
            -
                    ],
         | 
| 109 | 
            +
                    exc: str
         | 
| 110 | 
            +
                    | BaseException
         | 
| 111 | 
            +
                    | tuple[
         | 
| 112 | 
            +
                        type[BaseException] | None,
         | 
| 113 | 
            +
                        BaseException | None,
         | 
| 114 | 
            +
                        TracebackType | None,
         | 
| 115 | 
            +
                    ]
         | 
| 116 | 
            +
                    | None,
         | 
| 113 117 | 
             
                    handled: bool = False,
         | 
| 114 | 
            -
                    status:  | 
| 115 | 
            -
                ) -> None:
         | 
| 118 | 
            +
                    status: SessionStatus | None = None,
         | 
| 119 | 
            +
                ) -> str | None:
         | 
| 116 120 | 
             
                    """Log an exception to Sentry."""
         | 
| 117 | 
            -
                     | 
| 118 | 
            -
             | 
| 119 | 
            -
                     | 
| 120 | 
            -
                        exc_info = sentry_sdk.utils.exc_info_from_error( | 
| 121 | 
            +
                    if isinstance(exc, str):
         | 
| 122 | 
            +
                        exc_info = sentry_sdk.utils.exc_info_from_error(Exception(exc))
         | 
| 123 | 
            +
                    elif isinstance(exc, BaseException):
         | 
| 124 | 
            +
                        exc_info = sentry_sdk.utils.exc_info_from_error(exc)
         | 
| 121 125 | 
             
                    else:
         | 
| 122 126 | 
             
                        exc_info = sys.exc_info()
         | 
| 123 127 |  | 
| 124 | 
            -
                    event,  | 
| 128 | 
            +
                    event, _ = sentry_sdk.utils.event_from_exception(
         | 
| 125 129 | 
             
                        exc_info,
         | 
| 126 | 
            -
                        client_options=self. | 
| 130 | 
            +
                        client_options=self.scope.get_client().options,  # type: ignore
         | 
| 127 131 | 
             
                        mechanism={"type": "generic", "handled": handled},
         | 
| 128 132 | 
             
                    )
         | 
| 133 | 
            +
                    event_id = None
         | 
| 129 134 | 
             
                    try:
         | 
| 130 | 
            -
                         | 
| 135 | 
            +
                        with sentry_sdk.scope.use_isolation_scope(self.scope):  # type: ignore
         | 
| 136 | 
            +
                            event_id = sentry_sdk.capture_event(event)  # type: ignore
         | 
| 131 137 | 
             
                    except Exception:
         | 
| 132 138 | 
             
                        pass
         | 
| 133 139 |  | 
| @@ -136,11 +142,11 @@ class Sentry: | |
| 136 142 | 
             
                    status = status or ("crashed" if not handled else "errored")  # type: ignore
         | 
| 137 143 | 
             
                    self.mark_session(status=status)
         | 
| 138 144 |  | 
| 139 | 
            -
                    client | 
| 145 | 
            +
                    client = self.scope.get_client()  # type: ignore
         | 
| 140 146 | 
             
                    if client is not None:
         | 
| 141 147 | 
             
                        client.flush()
         | 
| 142 148 |  | 
| 143 | 
            -
                    return  | 
| 149 | 
            +
                    return event_id
         | 
| 144 150 |  | 
| 145 151 | 
             
                def reraise(self, exc: Any) -> None:
         | 
| 146 152 | 
             
                    """Re-raise an exception after logging it to Sentry.
         | 
| @@ -157,33 +163,31 @@ class Sentry: | |
| 157 163 | 
             
                @_safe_noop
         | 
| 158 164 | 
             
                def start_session(self) -> None:
         | 
| 159 165 | 
             
                    """Start a new session."""
         | 
| 160 | 
            -
                    assert self. | 
| 166 | 
            +
                    assert self.scope is not None
         | 
| 161 167 | 
             
                    # get the current client and scope
         | 
| 162 | 
            -
                     | 
| 163 | 
            -
                    session = scope._session
         | 
| 168 | 
            +
                    session = self.scope._session
         | 
| 164 169 |  | 
| 165 170 | 
             
                    # if there's no session, start one
         | 
| 166 171 | 
             
                    if session is None:
         | 
| 167 | 
            -
                        self. | 
| 172 | 
            +
                        self.scope.start_session()
         | 
| 168 173 |  | 
| 169 174 | 
             
                @_safe_noop
         | 
| 170 175 | 
             
                def end_session(self) -> None:
         | 
| 171 176 | 
             
                    """End the current session."""
         | 
| 172 | 
            -
                    assert self. | 
| 177 | 
            +
                    assert self.scope is not None
         | 
| 173 178 | 
             
                    # get the current client and scope
         | 
| 174 | 
            -
                    client | 
| 175 | 
            -
                    session = scope._session
         | 
| 179 | 
            +
                    client = self.scope.get_client()
         | 
| 180 | 
            +
                    session = self.scope._session
         | 
| 176 181 |  | 
| 177 182 | 
             
                    if session is not None and client is not None:
         | 
| 178 | 
            -
                        self. | 
| 183 | 
            +
                        self.scope.end_session()
         | 
| 179 184 | 
             
                        client.flush()
         | 
| 180 185 |  | 
| 181 186 | 
             
                @_safe_noop
         | 
| 182 | 
            -
                def mark_session(self, status:  | 
| 187 | 
            +
                def mark_session(self, status: SessionStatus | None = None) -> None:
         | 
| 183 188 | 
             
                    """Mark the current session with a status."""
         | 
| 184 | 
            -
                    assert self. | 
| 185 | 
            -
                     | 
| 186 | 
            -
                    session = scope._session
         | 
| 189 | 
            +
                    assert self.scope is not None
         | 
| 190 | 
            +
                    session = self.scope._session
         | 
| 187 191 |  | 
| 188 192 | 
             
                    if session is not None:
         | 
| 189 193 | 
             
                        session.update(status=status)
         | 
| @@ -191,8 +195,8 @@ class Sentry: | |
| 191 195 | 
             
                @_safe_noop
         | 
| 192 196 | 
             
                def configure_scope(
         | 
| 193 197 | 
             
                    self,
         | 
| 194 | 
            -
                    tags:  | 
| 195 | 
            -
                    process_context:  | 
| 198 | 
            +
                    tags: dict[str, Any] | None = None,
         | 
| 199 | 
            +
                    process_context: str | None = None,
         | 
| 196 200 | 
             
                ) -> None:
         | 
| 197 201 | 
             
                    """Configure the Sentry scope for the current thread.
         | 
| 198 202 |  | 
| @@ -201,7 +205,7 @@ class Sentry: | |
| 201 205 | 
             
                    all events sent from this thread. It also tries to start a session
         | 
| 202 206 | 
             
                    if one doesn't already exist for this thread.
         | 
| 203 207 | 
             
                    """
         | 
| 204 | 
            -
                    assert self. | 
| 208 | 
            +
                    assert self.scope is not None
         | 
| 205 209 | 
             
                    settings_tags = (
         | 
| 206 210 | 
             
                        "entity",
         | 
| 207 211 | 
             
                        "project",
         | 
| @@ -215,51 +219,50 @@ class Sentry: | |
| 215 219 | 
             
                        "launch",
         | 
| 216 220 | 
             
                    )
         | 
| 217 221 |  | 
| 218 | 
            -
                     | 
| 219 | 
            -
             | 
| 220 | 
            -
             | 
| 221 | 
            -
             | 
| 222 | 
            -
                         | 
| 223 | 
            -
             | 
| 224 | 
            -
             | 
| 225 | 
            -
             | 
| 226 | 
            -
                         | 
| 227 | 
            -
             | 
| 228 | 
            -
             | 
| 229 | 
            -
                         | 
| 230 | 
            -
             | 
| 231 | 
            -
                             | 
| 232 | 
            -
             | 
| 233 | 
            -
             | 
| 234 | 
            -
                         | 
| 235 | 
            -
             | 
| 236 | 
            -
                         | 
| 237 | 
            -
             | 
| 238 | 
            -
                         | 
| 239 | 
            -
             | 
| 240 | 
            -
                         | 
| 241 | 
            -
             | 
| 242 | 
            -
             | 
| 243 | 
            -
             | 
| 244 | 
            -
             | 
| 245 | 
            -
                         | 
| 246 | 
            -
             | 
| 247 | 
            -
                             | 
| 248 | 
            -
             | 
| 249 | 
            -
             | 
| 250 | 
            -
                             | 
| 251 | 
            -
             | 
| 252 | 
            -
             | 
| 253 | 
            -
                                 | 
| 254 | 
            -
             | 
| 255 | 
            -
             | 
| 256 | 
            -
             | 
| 257 | 
            -
                             | 
| 258 | 
            -
             | 
| 259 | 
            -
             | 
| 260 | 
            -
             | 
| 261 | 
            -
                         | 
| 262 | 
            -
                            scope.user = {"email": email}  # noqa
         | 
| 222 | 
            +
                    self.scope.set_tag("platform", wandb.util.get_platform_name())
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    # set context
         | 
| 225 | 
            +
                    if process_context:
         | 
| 226 | 
            +
                        self.scope.set_tag("process_context", process_context)
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    # apply settings tags
         | 
| 229 | 
            +
                    if tags is None:
         | 
| 230 | 
            +
                        return None
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    for tag in settings_tags:
         | 
| 233 | 
            +
                        val = tags.get(tag, None)
         | 
| 234 | 
            +
                        if val not in (None, ""):
         | 
| 235 | 
            +
                            self.scope.set_tag(tag, val)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    if tags.get("_colab", None):
         | 
| 238 | 
            +
                        python_runtime = "colab"
         | 
| 239 | 
            +
                    elif tags.get("_jupyter", None):
         | 
| 240 | 
            +
                        python_runtime = "jupyter"
         | 
| 241 | 
            +
                    elif tags.get("_ipython", None):
         | 
| 242 | 
            +
                        python_runtime = "ipython"
         | 
| 243 | 
            +
                    else:
         | 
| 244 | 
            +
                        python_runtime = "python"
         | 
| 245 | 
            +
                    self.scope.set_tag("python_runtime", python_runtime)
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                    # Construct run_url and sweep_url given run_id and sweep_id
         | 
| 248 | 
            +
                    for obj in ("run", "sweep"):
         | 
| 249 | 
            +
                        obj_id, obj_url = f"{obj}_id", f"{obj}_url"
         | 
| 250 | 
            +
                        if tags.get(obj_url, None):
         | 
| 251 | 
            +
                            continue
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                        try:
         | 
| 254 | 
            +
                            app_url = wandb.util.app_url(tags["base_url"])  # type: ignore
         | 
| 255 | 
            +
                            entity, project = (quote(tags[k]) for k in ("entity", "project"))  # type: ignore
         | 
| 256 | 
            +
                            self.scope.set_tag(
         | 
| 257 | 
            +
                                obj_url,
         | 
| 258 | 
            +
                                f"{app_url}/{entity}/{project}/{obj}s/{tags[obj_id]}",
         | 
| 259 | 
            +
                            )
         | 
| 260 | 
            +
                        except Exception:
         | 
| 261 | 
            +
                            pass
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                    email = tags.get("email")
         | 
| 264 | 
            +
                    if email:
         | 
| 265 | 
            +
                        self.scope.user = {"email": email}  # noqa
         | 
| 263 266 |  | 
| 264 267 | 
             
                    # todo: add back the option to pass general tags see: c645f625d1c1a3db4a6b0e2aa8e924fee101904c (wandb/util.py)
         | 
| 265 268 |  | 
    
        wandb/apis/public/api.py
    CHANGED
    
    | @@ -27,6 +27,7 @@ from wandb.apis import public | |
| 27 27 | 
             
            from wandb.apis.internal import Api as InternalApi
         | 
| 28 28 | 
             
            from wandb.apis.normalize import normalize_exceptions
         | 
| 29 29 | 
             
            from wandb.apis.public.const import RETRY_TIMEDELTA
         | 
| 30 | 
            +
            from wandb.sdk.artifacts._validators import is_artifact_registry_project
         | 
| 30 31 | 
             
            from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
         | 
| 31 32 | 
             
            from wandb.sdk.launch.utils import LAUNCH_DEFAULT_PROJECT
         | 
| 32 33 | 
             
            from wandb.sdk.lib import retry, runid
         | 
| @@ -1142,11 +1143,12 @@ class Api: | |
| 1142 1143 |  | 
| 1143 1144 | 
             
                @normalize_exceptions
         | 
| 1144 1145 | 
             
                def artifact(self, name, type=None):
         | 
| 1145 | 
            -
                    """Return a single artifact by parsing path in the form `entity/project/name`.
         | 
| 1146 | 
            +
                    """Return a single artifact by parsing path in the form `project/name` or `entity/project/name`.
         | 
| 1146 1147 |  | 
| 1147 1148 | 
             
                    Arguments:
         | 
| 1148 | 
            -
                        name: (str) An artifact name. May be prefixed with entity/project | 
| 1149 | 
            -
             | 
| 1149 | 
            +
                        name: (str) An artifact name. May be prefixed with project/ or entity/project/.
         | 
| 1150 | 
            +
                                If no entity is specified in the name, the Run or API setting's entity is used.
         | 
| 1151 | 
            +
                            Valid names can be in the following forms:
         | 
| 1150 1152 | 
             
                                name:version
         | 
| 1151 1153 | 
             
                                name:alias
         | 
| 1152 1154 | 
             
                        type: (str, optional) The type of artifact to fetch.
         | 
| @@ -1157,8 +1159,20 @@ class Api: | |
| 1157 1159 | 
             
                    if name is None:
         | 
| 1158 1160 | 
             
                        raise ValueError("You must specify name= to fetch an artifact.")
         | 
| 1159 1161 | 
             
                    entity, project, artifact_name = self._parse_artifact_path(name)
         | 
| 1162 | 
            +
             | 
| 1163 | 
            +
                    organization = ""
         | 
| 1164 | 
            +
                    # If its an Registry artifact, the entity is an org instead
         | 
| 1165 | 
            +
                    if is_artifact_registry_project(project):
         | 
| 1166 | 
            +
                        # Update `organization` only if an organization name was provided,
         | 
| 1167 | 
            +
                        # otherwise use the default that you already set above.
         | 
| 1168 | 
            +
                        try:
         | 
| 1169 | 
            +
                            organization, _, _ = name.split("/")
         | 
| 1170 | 
            +
                        except ValueError:
         | 
| 1171 | 
            +
                            organization = ""
         | 
| 1172 | 
            +
                        # set entity to match the settings since in above code it was potentially set to an org
         | 
| 1173 | 
            +
                        entity = self.settings["entity"] or self.default_entity
         | 
| 1160 1174 | 
             
                    artifact = wandb.Artifact._from_name(
         | 
| 1161 | 
            -
                        entity, project, artifact_name, self.client
         | 
| 1175 | 
            +
                        entity, project, artifact_name, self.client, organization
         | 
| 1162 1176 | 
             
                    )
         | 
| 1163 1177 | 
             
                    if type is not None and artifact.type != type:
         | 
| 1164 1178 | 
             
                        raise ValueError(
         |