wandb 0.19.8__py3-none-macosx_11_0_arm64.whl → 0.19.9__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +5 -1
- wandb/__init__.pyi +12 -8
- wandb/_pydantic/__init__.py +23 -0
- wandb/_pydantic/base.py +113 -0
- wandb/_pydantic/v1_compat.py +262 -0
- wandb/apis/paginator.py +82 -38
- wandb/apis/public/api.py +10 -64
- wandb/apis/public/artifacts.py +73 -17
- wandb/apis/public/files.py +2 -2
- wandb/apis/public/projects.py +3 -2
- wandb/apis/public/reports.py +2 -2
- wandb/apis/public/runs.py +19 -11
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/integration/metaflow/metaflow.py +19 -17
- wandb/integration/sacred/__init__.py +1 -1
- wandb/jupyter.py +18 -15
- wandb/proto/v3/wandb_internal_pb2.py +7 -3
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v4/wandb_internal_pb2.py +3 -3
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v5/wandb_internal_pb2.py +3 -3
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +4 -4
- wandb/proto/wandb_deprecated.py +2 -0
- wandb/sdk/artifacts/_graphql_fragments.py +18 -20
- wandb/sdk/artifacts/_validators.py +1 -0
- wandb/sdk/artifacts/artifact.py +70 -36
- wandb/sdk/artifacts/artifact_saver.py +16 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +23 -2
- wandb/sdk/data_types/audio.py +1 -3
- wandb/sdk/data_types/base_types/media.py +11 -4
- wandb/sdk/data_types/image.py +44 -25
- wandb/sdk/data_types/molecule.py +1 -5
- wandb/sdk/data_types/object_3d.py +2 -1
- wandb/sdk/data_types/saved_model.py +7 -9
- wandb/sdk/data_types/video.py +1 -4
- wandb/{apis/public → sdk/internal}/_generated/__init__.py +0 -6
- wandb/sdk/internal/_generated/base.py +226 -0
- wandb/{apis/public → sdk/internal}/_generated/server_features_query.py +3 -3
- wandb/{apis/public → sdk/internal}/_generated/typing_compat.py +1 -1
- wandb/sdk/internal/internal_api.py +138 -47
- wandb/sdk/internal/sender.py +2 -0
- wandb/sdk/internal/sender_config.py +8 -11
- wandb/sdk/internal/settings_static.py +24 -2
- wandb/sdk/lib/apikey.py +15 -16
- wandb/sdk/lib/run_moment.py +4 -6
- wandb/sdk/lib/wb_logging.py +161 -0
- wandb/sdk/wandb_config.py +44 -43
- wandb/sdk/wandb_init.py +141 -79
- wandb/sdk/wandb_metadata.py +107 -91
- wandb/sdk/wandb_run.py +152 -44
- wandb/sdk/wandb_settings.py +403 -201
- wandb/sdk/wandb_setup.py +3 -1
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/METADATA +3 -3
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/RECORD +64 -60
- wandb/apis/public/_generated/base.py +0 -128
- /wandb/{apis/public → sdk/internal}/_generated/enums.py +0 -0
- /wandb/{apis/public → sdk/internal}/_generated/input_types.py +0 -0
- /wandb/{apis/public → sdk/internal}/_generated/operations.py +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/WHEEL +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/licenses/LICENSE +0 -0
    
        wandb/apis/public/api.py
    CHANGED
    
    | @@ -25,7 +25,6 @@ import wandb | |
| 25 25 | 
             
            from wandb import env, util
         | 
| 26 26 | 
             
            from wandb.apis import public
         | 
| 27 27 | 
             
            from wandb.apis.normalize import normalize_exceptions
         | 
| 28 | 
            -
            from wandb.apis.public._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
         | 
| 29 28 | 
             
            from wandb.apis.public.const import RETRY_TIMEDELTA
         | 
| 30 29 | 
             
            from wandb.apis.public.registries import Registries
         | 
| 31 30 | 
             
            from wandb.apis.public.utils import (
         | 
| @@ -289,7 +288,7 @@ class Api: | |
| 289 288 | 
             
                        )
         | 
| 290 289 | 
             
                    )
         | 
| 291 290 | 
             
                    self._client = RetryingClient(self._base_client)
         | 
| 292 | 
            -
                    self._server_features_cache = None
         | 
| 291 | 
            +
                    self._server_features_cache: Optional[dict[str, bool]] = None
         | 
| 293 292 |  | 
| 294 293 | 
             
                def create_project(self, name: str, entity: str) -> None:
         | 
| 295 294 | 
             
                    """Create a new project.
         | 
| @@ -757,15 +756,14 @@ class Api: | |
| 757 756 | 
             
                    return parts
         | 
| 758 757 |  | 
| 759 758 | 
             
                def projects(
         | 
| 760 | 
            -
                    self, entity: Optional[str] = None, per_page:  | 
| 759 | 
            +
                    self, entity: Optional[str] = None, per_page: int = 200
         | 
| 761 760 | 
             
                ) -> "public.Projects":
         | 
| 762 761 | 
             
                    """Get projects for a given entity.
         | 
| 763 762 |  | 
| 764 763 | 
             
                    Args:
         | 
| 765 764 | 
             
                        entity: (str) Name of the entity requested.  If None, will fall back to the
         | 
| 766 765 | 
             
                            default entity passed to `Api`.  If no default entity, will raise a `ValueError`.
         | 
| 767 | 
            -
                        per_page: (int) Sets the page size for query pagination.   | 
| 768 | 
            -
                            Usually there is no reason to change this.
         | 
| 766 | 
            +
                        per_page: (int) Sets the page size for query pagination.  Usually there is no reason to change this.
         | 
| 769 767 |  | 
| 770 768 | 
             
                    Returns:
         | 
| 771 769 | 
             
                        A `Projects` object which is an iterable collection of `Project` objects.
         | 
| @@ -808,7 +806,7 @@ class Api: | |
| 808 806 | 
             
                    return public.Project(self.client, entity, name, {})
         | 
| 809 807 |  | 
| 810 808 | 
             
                def reports(
         | 
| 811 | 
            -
                    self, path: str = "", name: Optional[str] = None, per_page:  | 
| 809 | 
            +
                    self, path: str = "", name: Optional[str] = None, per_page: int = 50
         | 
| 812 810 | 
             
                ) -> "public.Reports":
         | 
| 813 811 | 
             
                    """Get reports for a given project path.
         | 
| 814 812 |  | 
| @@ -817,8 +815,7 @@ class Api: | |
| 817 815 | 
             
                    Args:
         | 
| 818 816 | 
             
                        path: (str) path to project the report resides in, should be in the form: "entity/project"
         | 
| 819 817 | 
             
                        name: (str, optional) optional name of the report requested.
         | 
| 820 | 
            -
                        per_page: (int) Sets the page size for query pagination.   | 
| 821 | 
            -
                            Usually there is no reason to change this.
         | 
| 818 | 
            +
                        per_page: (int) Sets the page size for query pagination.  Usually there is no reason to change this.
         | 
| 822 819 |  | 
| 823 820 | 
             
                    Returns:
         | 
| 824 821 | 
             
                        A `Reports` object which is an iterable collection of `BetaReport` objects.
         | 
| @@ -1153,15 +1150,14 @@ class Api: | |
| 1153 1150 |  | 
| 1154 1151 | 
             
                @normalize_exceptions
         | 
| 1155 1152 | 
             
                def artifact_collections(
         | 
| 1156 | 
            -
                    self, project_name: str, type_name: str, per_page:  | 
| 1153 | 
            +
                    self, project_name: str, type_name: str, per_page: int = 50
         | 
| 1157 1154 | 
             
                ) -> "public.ArtifactCollections":
         | 
| 1158 1155 | 
             
                    """Return a collection of matching artifact collections.
         | 
| 1159 1156 |  | 
| 1160 1157 | 
             
                    Args:
         | 
| 1161 1158 | 
             
                        project_name: (str) The name of the project to filter on.
         | 
| 1162 1159 | 
             
                        type_name: (str) The name of the artifact type to filter on.
         | 
| 1163 | 
            -
                        per_page: (int | 
| 1164 | 
            -
                            Usually there is no reason to change this.
         | 
| 1160 | 
            +
                        per_page: (int) Sets the page size for query pagination.  Usually there is no reason to change this.
         | 
| 1165 1161 |  | 
| 1166 1162 | 
             
                    Returns:
         | 
| 1167 1163 | 
             
                        An iterable `ArtifactCollections` object.
         | 
| @@ -1226,7 +1222,7 @@ class Api: | |
| 1226 1222 | 
             
                    self,
         | 
| 1227 1223 | 
             
                    type_name: str,
         | 
| 1228 1224 | 
             
                    name: str,
         | 
| 1229 | 
            -
                    per_page:  | 
| 1225 | 
            +
                    per_page: int = 50,
         | 
| 1230 1226 | 
             
                    tags: Optional[List[str]] = None,
         | 
| 1231 1227 | 
             
                ) -> "public.Artifacts":
         | 
| 1232 1228 | 
             
                    """Return an `Artifacts` collection from the given parameters.
         | 
| @@ -1234,8 +1230,7 @@ class Api: | |
| 1234 1230 | 
             
                    Args:
         | 
| 1235 1231 | 
             
                        type_name: (str) The type of artifacts to fetch.
         | 
| 1236 1232 | 
             
                        name: (str) An artifact collection name. May be prefixed with entity/project.
         | 
| 1237 | 
            -
                        per_page: (int | 
| 1238 | 
            -
                            Usually there is no reason to change this.
         | 
| 1233 | 
            +
                        per_page: (int) Sets the page size for query pagination.  Usually there is no reason to change this.
         | 
| 1239 1234 | 
             
                        tags: (list[str], optional) Only return artifacts with all of these tags.
         | 
| 1240 1235 |  | 
| 1241 1236 | 
             
                    Returns:
         | 
| @@ -1510,7 +1505,7 @@ class Api: | |
| 1510 1505 | 
             
                    Returns:
         | 
| 1511 1506 | 
             
                        A registry iterator.
         | 
| 1512 1507 | 
             
                    """
         | 
| 1513 | 
            -
                    if not  | 
| 1508 | 
            +
                    if not InternalApi()._check_server_feature_with_fallback(
         | 
| 1514 1509 | 
             
                        ServerFeature.ARTIFACT_REGISTRY_SEARCH
         | 
| 1515 1510 | 
             
                    ):
         | 
| 1516 1511 | 
             
                        raise RuntimeError(
         | 
| @@ -1522,52 +1517,3 @@ class Api: | |
| 1522 1517 | 
             
                        self.settings, self.default_entity
         | 
| 1523 1518 | 
             
                    )
         | 
| 1524 1519 | 
             
                    return Registries(self.client, organization, filter)
         | 
| 1525 | 
            -
             | 
| 1526 | 
            -
                def _check_server_feature(self, feature: ServerFeature) -> bool:
         | 
| 1527 | 
            -
                    """Check if a server feature is enabled.
         | 
| 1528 | 
            -
             | 
| 1529 | 
            -
                    Args:
         | 
| 1530 | 
            -
                        feature (ServerFeature): The feature to check.
         | 
| 1531 | 
            -
             | 
| 1532 | 
            -
                    Returns:
         | 
| 1533 | 
            -
                        bool: True if the feature is enabled, False otherwise.
         | 
| 1534 | 
            -
             | 
| 1535 | 
            -
                    Raises:
         | 
| 1536 | 
            -
                        Exception: If server doesn't support feature queries or other errors occur
         | 
| 1537 | 
            -
                    """
         | 
| 1538 | 
            -
                    if self._server_features_cache is None:
         | 
| 1539 | 
            -
                        response = self.client.execute(gql(SERVER_FEATURES_QUERY_GQL))
         | 
| 1540 | 
            -
                        self._server_features_cache = ServerFeaturesQuery.model_validate(response)
         | 
| 1541 | 
            -
             | 
| 1542 | 
            -
                    feature_name = ServerFeature.Name(feature)
         | 
| 1543 | 
            -
                    if (
         | 
| 1544 | 
            -
                        self._server_features_cache
         | 
| 1545 | 
            -
                        and self._server_features_cache.server_info
         | 
| 1546 | 
            -
                        and self._server_features_cache.server_info.features
         | 
| 1547 | 
            -
                    ):
         | 
| 1548 | 
            -
                        for feature_info in self._server_features_cache.server_info.features:
         | 
| 1549 | 
            -
                            if feature_info and feature_info.name == feature_name:
         | 
| 1550 | 
            -
                                return feature_info.is_enabled
         | 
| 1551 | 
            -
             | 
| 1552 | 
            -
                    return False
         | 
| 1553 | 
            -
             | 
| 1554 | 
            -
                def _check_server_feature_with_fallback(self, feature: ServerFeature) -> bool:
         | 
| 1555 | 
            -
                    """Wrapper around check_server_feature that warns and returns False for older unsupported servers.
         | 
| 1556 | 
            -
             | 
| 1557 | 
            -
                    Good to use for features that have a fallback mechanism for older servers.
         | 
| 1558 | 
            -
             | 
| 1559 | 
            -
                    Args:
         | 
| 1560 | 
            -
                        feature (ServerFeature): The feature to check.
         | 
| 1561 | 
            -
             | 
| 1562 | 
            -
                    Returns:
         | 
| 1563 | 
            -
                        bool: True if the feature is enabled, False otherwise.
         | 
| 1564 | 
            -
             | 
| 1565 | 
            -
                    Exceptions:
         | 
| 1566 | 
            -
                        Exception: If an error other than the server not supporting feature queries occurs.
         | 
| 1567 | 
            -
                    """
         | 
| 1568 | 
            -
                    try:
         | 
| 1569 | 
            -
                        return self._check_server_feature(feature)
         | 
| 1570 | 
            -
                    except Exception as e:
         | 
| 1571 | 
            -
                        if 'Cannot query field "features" on type "ServerInfo".' in str(e):
         | 
| 1572 | 
            -
                            return False
         | 
| 1573 | 
            -
                        raise e
         | 
    
        wandb/apis/public/artifacts.py
    CHANGED
    
    | @@ -10,19 +10,21 @@ from wandb_gql import Client, gql | |
| 10 10 | 
             
            import wandb
         | 
| 11 11 | 
             
            from wandb.apis import public
         | 
| 12 12 | 
             
            from wandb.apis.normalize import normalize_exceptions
         | 
| 13 | 
            -
            from wandb.apis.paginator import Paginator
         | 
| 13 | 
            +
            from wandb.apis.paginator import Paginator, SizedPaginator
         | 
| 14 14 | 
             
            from wandb.errors.term import termlog
         | 
| 15 | 
            +
            from wandb.proto.wandb_internal_pb2 import ServerFeature
         | 
| 15 16 | 
             
            from wandb.sdk.artifacts._graphql_fragments import (
         | 
| 16 17 | 
             
                ARTIFACT_FILES_FRAGMENT,
         | 
| 17 18 | 
             
                ARTIFACTS_TYPES_FRAGMENT,
         | 
| 18 19 | 
             
            )
         | 
| 20 | 
            +
            from wandb.sdk.internal.internal_api import Api as InternalApi
         | 
| 19 21 | 
             
            from wandb.sdk.lib import deprecate
         | 
| 20 22 |  | 
| 21 23 | 
             
            if TYPE_CHECKING:
         | 
| 22 24 | 
             
                from wandb.apis.public import RetryingClient, Run
         | 
| 23 25 |  | 
| 24 26 |  | 
| 25 | 
            -
            class ArtifactTypes(Paginator):
         | 
| 27 | 
            +
            class ArtifactTypes(Paginator["ArtifactType"]):
         | 
| 26 28 | 
             
                QUERY = gql(
         | 
| 27 29 | 
             
                    """
         | 
| 28 30 | 
             
                    query ProjectArtifacts(
         | 
| @@ -45,7 +47,7 @@ class ArtifactTypes(Paginator): | |
| 45 47 | 
             
                    client: Client,
         | 
| 46 48 | 
             
                    entity: str,
         | 
| 47 49 | 
             
                    project: str,
         | 
| 48 | 
            -
                    per_page:  | 
| 50 | 
            +
                    per_page: int = 50,
         | 
| 49 51 | 
             
                ):
         | 
| 50 52 | 
             
                    self.entity = entity
         | 
| 51 53 | 
             
                    self.project = project
         | 
| @@ -58,7 +60,7 @@ class ArtifactTypes(Paginator): | |
| 58 60 | 
             
                    super().__init__(client, variable_values, per_page)
         | 
| 59 61 |  | 
| 60 62 | 
             
                @property
         | 
| 61 | 
            -
                def length(self):
         | 
| 63 | 
            +
                def length(self) -> None:
         | 
| 62 64 | 
             
                    # TODO
         | 
| 63 65 | 
             
                    return None
         | 
| 64 66 |  | 
| @@ -167,14 +169,14 @@ class ArtifactType: | |
| 167 169 | 
             
                    return f"<ArtifactType {self.type}>"
         | 
| 168 170 |  | 
| 169 171 |  | 
| 170 | 
            -
            class ArtifactCollections( | 
| 172 | 
            +
            class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
         | 
| 171 173 | 
             
                def __init__(
         | 
| 172 174 | 
             
                    self,
         | 
| 173 175 | 
             
                    client: Client,
         | 
| 174 176 | 
             
                    entity: str,
         | 
| 175 177 | 
             
                    project: str,
         | 
| 176 178 | 
             
                    type_name: str,
         | 
| 177 | 
            -
                    per_page:  | 
| 179 | 
            +
                    per_page: int = 50,
         | 
| 178 180 | 
             
                ):
         | 
| 179 181 | 
             
                    self.entity = entity
         | 
| 180 182 | 
             
                    self.project = project
         | 
| @@ -710,7 +712,7 @@ class ArtifactCollection: | |
| 710 712 | 
             
                    return f"<ArtifactCollection {self._name} ({self._type})>"
         | 
| 711 713 |  | 
| 712 714 |  | 
| 713 | 
            -
            class Artifacts( | 
| 715 | 
            +
            class Artifacts(SizedPaginator["wandb.Artifact"]):
         | 
| 714 716 | 
             
                """An iterable collection of artifact versions associated with a project and optional filter.
         | 
| 715 717 |  | 
| 716 718 | 
             
                This is generally used indirectly via the `Api`.artifact_versions method.
         | 
| @@ -826,10 +828,8 @@ class Artifacts(Paginator): | |
| 826 828 | 
             
                    ]
         | 
| 827 829 |  | 
| 828 830 |  | 
| 829 | 
            -
            class RunArtifacts( | 
| 830 | 
            -
                def __init__(
         | 
| 831 | 
            -
                    self, client: Client, run: "Run", mode="logged", per_page: Optional[int] = 50
         | 
| 832 | 
            -
                ):
         | 
| 831 | 
            +
            class RunArtifacts(SizedPaginator["wandb.Artifact"]):
         | 
| 832 | 
            +
                def __init__(self, client: Client, run: "Run", mode="logged", per_page: int = 50):
         | 
| 833 833 | 
             
                    from wandb.sdk.artifacts.artifact import _gql_artifact_fragment
         | 
| 834 834 |  | 
| 835 835 | 
             
                    output_query = gql(
         | 
| @@ -944,9 +944,9 @@ class RunArtifacts(Paginator): | |
| 944 944 | 
             
                    ]
         | 
| 945 945 |  | 
| 946 946 |  | 
| 947 | 
            -
            class ArtifactFiles( | 
| 948 | 
            -
                 | 
| 949 | 
            -
                    """
         | 
| 947 | 
            +
            class ArtifactFiles(SizedPaginator["public.File"]):
         | 
| 948 | 
            +
                ARTIFACT_VERSION_FILES_QUERY = gql(
         | 
| 949 | 
            +
                    f"""
         | 
| 950 950 | 
             
                    query ArtifactFiles(
         | 
| 951 951 | 
             
                        $entityName: String!,
         | 
| 952 952 | 
             
                        $projectName: String!,
         | 
| @@ -959,13 +959,40 @@ class ArtifactFiles(Paginator): | |
| 959 959 | 
             
                        project(name: $projectName, entityName: $entityName) {{
         | 
| 960 960 | 
             
                            artifactType(name: $artifactTypeName) {{
         | 
| 961 961 | 
             
                                artifact(name: $artifactName) {{
         | 
| 962 | 
            -
                                     | 
| 962 | 
            +
                                    files(names: $fileNames, after: $fileCursor, first: $fileLimit) {{
         | 
| 963 | 
            +
                                        ...FilesFragment
         | 
| 964 | 
            +
                                    }}
         | 
| 963 965 | 
             
                                }}
         | 
| 964 966 | 
             
                            }}
         | 
| 965 967 | 
             
                        }}
         | 
| 966 968 | 
             
                    }}
         | 
| 967 | 
            -
                    {}
         | 
| 968 | 
            -
                """ | 
| 969 | 
            +
                    {ARTIFACT_FILES_FRAGMENT}
         | 
| 970 | 
            +
                """
         | 
| 971 | 
            +
                )
         | 
| 972 | 
            +
             | 
| 973 | 
            +
                ARTIFACT_COLLECTION_MEMBERSHIP_FILES_QUERY = gql(
         | 
| 974 | 
            +
                    f"""
         | 
| 975 | 
            +
                    query ArtifactCollectionMembershipFiles(
         | 
| 976 | 
            +
                        $entityName: String!,
         | 
| 977 | 
            +
                        $projectName: String!,
         | 
| 978 | 
            +
                        $artifactName: String!,
         | 
| 979 | 
            +
                        $artifactVersionIndex: String!,
         | 
| 980 | 
            +
                        $fileNames: [String!],
         | 
| 981 | 
            +
                        $fileCursor: String,
         | 
| 982 | 
            +
                        $fileLimit: Int = 50
         | 
| 983 | 
            +
                    ) {{
         | 
| 984 | 
            +
                        project(name: $projectName, entityName: $entityName) {{
         | 
| 985 | 
            +
                            artifactCollection(name: $artifactName) {{
         | 
| 986 | 
            +
                                artifactMembership (aliasName: $artifactVersionIndex) {{
         | 
| 987 | 
            +
                                    files(names: $fileNames, after: $fileCursor, first: $fileLimit) {{
         | 
| 988 | 
            +
                                        ...FilesFragment
         | 
| 989 | 
            +
                                    }}
         | 
| 990 | 
            +
                                }}
         | 
| 991 | 
            +
                            }}
         | 
| 992 | 
            +
                        }}
         | 
| 993 | 
            +
                    }}
         | 
| 994 | 
            +
                    {ARTIFACT_FILES_FRAGMENT}
         | 
| 995 | 
            +
                    """
         | 
| 969 996 | 
             
                )
         | 
| 970 997 |  | 
| 971 998 | 
             
                def __init__(
         | 
| @@ -975,6 +1002,9 @@ class ArtifactFiles(Paginator): | |
| 975 1002 | 
             
                    names: Optional[Sequence[str]] = None,
         | 
| 976 1003 | 
             
                    per_page: int = 50,
         | 
| 977 1004 | 
             
                ):
         | 
| 1005 | 
            +
                    self.query_via_membership = InternalApi()._check_server_feature_with_fallback(
         | 
| 1006 | 
            +
                        ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
         | 
| 1007 | 
            +
                    )
         | 
| 978 1008 | 
             
                    self.artifact = artifact
         | 
| 979 1009 | 
             
                    variables = {
         | 
| 980 1010 | 
             
                        "entityName": artifact.source_entity,
         | 
| @@ -983,6 +1013,17 @@ class ArtifactFiles(Paginator): | |
| 983 1013 | 
             
                        "artifactName": artifact.source_name,
         | 
| 984 1014 | 
             
                        "fileNames": names,
         | 
| 985 1015 | 
             
                    }
         | 
| 1016 | 
            +
                    if self.query_via_membership:
         | 
| 1017 | 
            +
                        self.QUERY = self.ARTIFACT_COLLECTION_MEMBERSHIP_FILES_QUERY
         | 
| 1018 | 
            +
                        variables = {
         | 
| 1019 | 
            +
                            "entityName": artifact.entity,
         | 
| 1020 | 
            +
                            "projectName": artifact.project,
         | 
| 1021 | 
            +
                            "artifactName": artifact.name.split(":")[0],
         | 
| 1022 | 
            +
                            "artifactVersionIndex": artifact.version,
         | 
| 1023 | 
            +
                            "fileNames": names,
         | 
| 1024 | 
            +
                        }
         | 
| 1025 | 
            +
                    else:
         | 
| 1026 | 
            +
                        self.QUERY = self.ARTIFACT_VERSION_FILES_QUERY
         | 
| 986 1027 | 
             
                    # The server must advertise at least SDK 0.12.21
         | 
| 987 1028 | 
             
                    # to get storagePath
         | 
| 988 1029 | 
             
                    if not client.version_supported("0.12.21"):
         | 
| @@ -1000,6 +1041,10 @@ class ArtifactFiles(Paginator): | |
| 1000 1041 | 
             
                @property
         | 
| 1001 1042 | 
             
                def more(self):
         | 
| 1002 1043 | 
             
                    if self.last_response:
         | 
| 1044 | 
            +
                        if self.query_via_membership:
         | 
| 1045 | 
            +
                            return self.last_response["project"]["artifactCollection"][
         | 
| 1046 | 
            +
                                "artifactMembership"
         | 
| 1047 | 
            +
                            ]["files"]["pageInfo"]["hasNextPage"]
         | 
| 1003 1048 | 
             
                        return self.last_response["project"]["artifactType"]["artifact"]["files"][
         | 
| 1004 1049 | 
             
                            "pageInfo"
         | 
| 1005 1050 | 
             
                        ]["hasNextPage"]
         | 
| @@ -1009,6 +1054,10 @@ class ArtifactFiles(Paginator): | |
| 1009 1054 | 
             
                @property
         | 
| 1010 1055 | 
             
                def cursor(self):
         | 
| 1011 1056 | 
             
                    if self.last_response:
         | 
| 1057 | 
            +
                        if self.query_via_membership:
         | 
| 1058 | 
            +
                            return self.last_response["project"]["artifactCollection"][
         | 
| 1059 | 
            +
                                "artifactMembership"
         | 
| 1060 | 
            +
                            ]["files"]["edges"][-1]["cursor"]
         | 
| 1012 1061 | 
             
                        return self.last_response["project"]["artifactType"]["artifact"]["files"][
         | 
| 1013 1062 | 
             
                            "edges"
         | 
| 1014 1063 | 
             
                        ][-1]["cursor"]
         | 
| @@ -1019,6 +1068,13 @@ class ArtifactFiles(Paginator): | |
| 1019 1068 | 
             
                    self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor})
         | 
| 1020 1069 |  | 
| 1021 1070 | 
             
                def convert_objects(self):
         | 
| 1071 | 
            +
                    if self.query_via_membership:
         | 
| 1072 | 
            +
                        return [
         | 
| 1073 | 
            +
                            public.File(self.client, r["node"])
         | 
| 1074 | 
            +
                            for r in self.last_response["project"]["artifactCollection"][
         | 
| 1075 | 
            +
                                "artifactMembership"
         | 
| 1076 | 
            +
                            ]["files"]["edges"]
         | 
| 1077 | 
            +
                        ]
         | 
| 1022 1078 | 
             
                    return [
         | 
| 1023 1079 | 
             
                        public.File(self.client, r["node"])
         | 
| 1024 1080 | 
             
                        for r in self.last_response["project"]["artifactType"]["artifact"]["files"][
         | 
    
        wandb/apis/public/files.py
    CHANGED
    
    | @@ -12,7 +12,7 @@ import wandb | |
| 12 12 | 
             
            from wandb import util
         | 
| 13 13 | 
             
            from wandb.apis.attrs import Attrs
         | 
| 14 14 | 
             
            from wandb.apis.normalize import normalize_exceptions
         | 
| 15 | 
            -
            from wandb.apis.paginator import  | 
| 15 | 
            +
            from wandb.apis.paginator import SizedPaginator
         | 
| 16 16 | 
             
            from wandb.apis.public import utils
         | 
| 17 17 | 
             
            from wandb.apis.public.api import Api
         | 
| 18 18 | 
             
            from wandb.apis.public.const import RETRY_TIMEDELTA
         | 
| @@ -41,7 +41,7 @@ FILE_FRAGMENT = """fragment RunFilesFragment on Run { | |
| 41 41 | 
             
            }"""
         | 
| 42 42 |  | 
| 43 43 |  | 
| 44 | 
            -
            class Files( | 
| 44 | 
            +
            class Files(SizedPaginator["File"]):
         | 
| 45 45 | 
             
                """An iterable collection of `File` objects."""
         | 
| 46 46 |  | 
| 47 47 | 
             
                QUERY = gql(
         | 
    
        wandb/apis/public/projects.py
    CHANGED
    
    | @@ -17,7 +17,7 @@ PROJECT_FRAGMENT = """fragment ProjectFragment on Project { | |
| 17 17 | 
             
            }"""
         | 
| 18 18 |  | 
| 19 19 |  | 
| 20 | 
            -
            class Projects(Paginator):
         | 
| 20 | 
            +
            class Projects(Paginator["Project"]):
         | 
| 21 21 | 
             
                """An iterable collection of `Project` objects."""
         | 
| 22 22 |  | 
| 23 23 | 
             
                QUERY = gql(
         | 
| @@ -49,7 +49,8 @@ class Projects(Paginator): | |
| 49 49 | 
             
                    super().__init__(client, variables, per_page)
         | 
| 50 50 |  | 
| 51 51 | 
             
                @property
         | 
| 52 | 
            -
                def length(self):
         | 
| 52 | 
            +
                def length(self) -> None:
         | 
| 53 | 
            +
                    # For backwards compatibility, even though this isn't a SizedPaginator
         | 
| 53 54 | 
             
                    return None
         | 
| 54 55 |  | 
| 55 56 | 
             
                @property
         | 
    
        wandb/apis/public/reports.py
    CHANGED
    
    | @@ -9,11 +9,11 @@ from wandb_gql import gql | |
| 9 9 | 
             
            import wandb
         | 
| 10 10 | 
             
            from wandb.apis import public
         | 
| 11 11 | 
             
            from wandb.apis.attrs import Attrs
         | 
| 12 | 
            -
            from wandb.apis.paginator import  | 
| 12 | 
            +
            from wandb.apis.paginator import SizedPaginator
         | 
| 13 13 | 
             
            from wandb.sdk.lib import ipython
         | 
| 14 14 |  | 
| 15 15 |  | 
| 16 | 
            -
            class Reports( | 
| 16 | 
            +
            class Reports(SizedPaginator["BetaReport"]):
         | 
| 17 17 | 
             
                """Reports is an iterable collection of `BetaReport` objects."""
         | 
| 18 18 |  | 
| 19 19 | 
             
                QUERY = gql(
         | 
    
        wandb/apis/public/runs.py
    CHANGED
    
    | @@ -24,7 +24,7 @@ from wandb.apis import public | |
| 24 24 | 
             
            from wandb.apis.attrs import Attrs
         | 
| 25 25 | 
             
            from wandb.apis.internal import Api as InternalApi
         | 
| 26 26 | 
             
            from wandb.apis.normalize import normalize_exceptions
         | 
| 27 | 
            -
            from wandb.apis.paginator import  | 
| 27 | 
            +
            from wandb.apis.paginator import SizedPaginator
         | 
| 28 28 | 
             
            from wandb.apis.public.const import RETRY_TIMEDELTA
         | 
| 29 29 | 
             
            from wandb.sdk.lib import ipython, json_util, runid
         | 
| 30 30 | 
             
            from wandb.sdk.lib.paths import LogicalPath
         | 
| @@ -61,7 +61,7 @@ RUN_FRAGMENT = """fragment RunFragment on Run { | |
| 61 61 | 
             
            }"""
         | 
| 62 62 |  | 
| 63 63 |  | 
| 64 | 
            -
            class Runs( | 
| 64 | 
            +
            class Runs(SizedPaginator["Run"]):
         | 
| 65 65 | 
             
                """An iterable collection of runs associated with a project and optional filter.
         | 
| 66 66 |  | 
| 67 67 | 
             
                This is generally used indirectly via the `Api`.runs method.
         | 
| @@ -421,16 +421,15 @@ class Run(Attrs): | |
| 421 421 | 
             
                        """
         | 
| 422 422 | 
             
                    query Run($project: String!, $entity: String!, $name: String!) {{
         | 
| 423 423 | 
             
                        project(name: $project, entityName: $entity) {{
         | 
| 424 | 
            -
                            {}
         | 
| 425 424 | 
             
                            run(name: $name) {{
         | 
| 425 | 
            +
                                {}
         | 
| 426 426 | 
             
                                ...RunFragment
         | 
| 427 427 | 
             
                            }}
         | 
| 428 428 | 
             
                        }}
         | 
| 429 429 | 
             
                    }}
         | 
| 430 430 | 
             
                    {}
         | 
| 431 431 | 
             
                    """.format(
         | 
| 432 | 
            -
                             | 
| 433 | 
            -
                            "internalId" if self._server_provides_internal_id_for_project() else "",
         | 
| 432 | 
            +
                            "projectId" if self._server_provides_internal_id_for_project() else "",
         | 
| 434 433 | 
             
                            RUN_FRAGMENT,
         | 
| 435 434 | 
             
                        )
         | 
| 436 435 | 
             
                    )
         | 
| @@ -444,7 +443,11 @@ class Run(Attrs): | |
| 444 443 | 
             
                            raise ValueError("Could not find run {}".format(self))
         | 
| 445 444 | 
             
                        self._attrs = response["project"]["run"]
         | 
| 446 445 | 
             
                        self._state = self._attrs["state"]
         | 
| 447 | 
            -
             | 
| 446 | 
            +
             | 
| 447 | 
            +
                        self._project_internal_id = (
         | 
| 448 | 
            +
                            int(self._attrs["projectId"]) if "projectId" in self._attrs else None
         | 
| 449 | 
            +
                        )
         | 
| 450 | 
            +
             | 
| 448 451 | 
             
                        if self._include_sweeps and self.sweep_name and not self.sweep:
         | 
| 449 452 | 
             
                            # There may be a lot of runs. Don't bother pulling them all
         | 
| 450 453 | 
             
                            # just for the sake of this one.
         | 
| @@ -847,7 +850,12 @@ class Run(Attrs): | |
| 847 850 | 
             
                    api.set_current_run_id(self.id)
         | 
| 848 851 |  | 
| 849 852 | 
             
                    if isinstance(artifact, wandb.Artifact) and not artifact.is_draft():
         | 
| 850 | 
            -
                        api.use_artifact( | 
| 853 | 
            +
                        api.use_artifact(
         | 
| 854 | 
            +
                            artifact.id,
         | 
| 855 | 
            +
                            use_as=use_as or artifact.name,
         | 
| 856 | 
            +
                            artifact_entity_name=artifact.entity,
         | 
| 857 | 
            +
                            artifact_project_name=artifact.project,
         | 
| 858 | 
            +
                        )
         | 
| 851 859 | 
             
                        return artifact
         | 
| 852 860 | 
             
                    elif isinstance(artifact, wandb.Artifact) and artifact.is_draft():
         | 
| 853 861 | 
             
                        raise ValueError(
         | 
| @@ -911,8 +919,8 @@ class Run(Attrs): | |
| 911 919 | 
             
                    This check is done by utilizing GraphQL introspection in the available fields on the Project type.
         | 
| 912 920 | 
             
                    """
         | 
| 913 921 | 
             
                    query_string = """
         | 
| 914 | 
            -
                       query  | 
| 915 | 
            -
                             | 
| 922 | 
            +
                       query ProbeRunInput {
         | 
| 923 | 
            +
                            RunType: __type(name:"Run") {
         | 
| 916 924 | 
             
                                fields {
         | 
| 917 925 | 
             
                                    name
         | 
| 918 926 | 
             
                                }
         | 
| @@ -924,8 +932,8 @@ class Run(Attrs): | |
| 924 932 | 
             
                    if self.server_provides_internal_id_field is None:
         | 
| 925 933 | 
             
                        query = gql(query_string)
         | 
| 926 934 | 
             
                        res = self.client.execute(query)
         | 
| 927 | 
            -
                        self.server_provides_internal_id_field = " | 
| 928 | 
            -
                            x["name"] for x in (res.get(" | 
| 935 | 
            +
                        self.server_provides_internal_id_field = "projectId" in [
         | 
| 936 | 
            +
                            x["name"] for x in (res.get("RunType", {}).get("fields", [{}]))
         | 
| 929 937 | 
             
                        ]
         | 
| 930 938 |  | 
| 931 939 | 
             
                    return self.server_provides_internal_id_field
         | 
    
        wandb/bin/gpu_stats
    CHANGED
    
    | Binary file | 
    
        wandb/bin/wandb-core
    CHANGED
    
    | Binary file | 
| @@ -13,6 +13,7 @@ import inspect | |
| 13 13 | 
             
            import pickle
         | 
| 14 14 | 
             
            from functools import wraps
         | 
| 15 15 | 
             
            from pathlib import Path
         | 
| 16 | 
            +
            from typing import Union
         | 
| 16 17 |  | 
| 17 18 | 
             
            import wandb
         | 
| 18 19 | 
             
            from wandb.sdk.lib import telemetry as wb_telemetry
         | 
| @@ -25,17 +26,18 @@ except ImportError as e: | |
| 25 26 | 
             
                ) from e
         | 
| 26 27 |  | 
| 27 28 | 
             
            try:
         | 
| 28 | 
            -
                from  | 
| 29 | 
            +
                from plum import dispatch
         | 
| 29 30 | 
             
            except ImportError as e:
         | 
| 30 31 | 
             
                raise Exception(
         | 
| 31 | 
            -
                    "Error: ` | 
| 32 | 
            +
                    "Error: `plum-dispatch` not installed >> "
         | 
| 33 | 
            +
                    "This integration requires plum-dispatch! To fix, please `pip install -Uqq plum-dispatch`"
         | 
| 32 34 | 
             
                ) from e
         | 
| 33 35 |  | 
| 34 36 |  | 
| 35 37 | 
             
            try:
         | 
| 36 38 | 
             
                import pandas as pd
         | 
| 37 39 |  | 
| 38 | 
            -
                @ | 
| 40 | 
            +
                @dispatch  # noqa: F811
         | 
| 39 41 | 
             
                def _wandb_use(
         | 
| 40 42 | 
             
                    name: str,
         | 
| 41 43 | 
             
                    data: pd.DataFrame,
         | 
| @@ -52,7 +54,7 @@ try: | |
| 52 54 | 
             
                        run.use_artifact(f"{name}:latest")
         | 
| 53 55 | 
             
                        wandb.termlog(f"Using artifact: {name} ({type(data)})")
         | 
| 54 56 |  | 
| 55 | 
            -
                @ | 
| 57 | 
            +
                @dispatch  # noqa: F811
         | 
| 56 58 | 
             
                def wandb_track(
         | 
| 57 59 | 
             
                    name: str,
         | 
| 58 60 | 
             
                    data: pd.DataFrame,
         | 
| @@ -81,7 +83,7 @@ try: | |
| 81 83 | 
             
                import torch
         | 
| 82 84 | 
             
                import torch.nn as nn
         | 
| 83 85 |  | 
| 84 | 
            -
                @ | 
| 86 | 
            +
                @dispatch  # noqa: F811
         | 
| 85 87 | 
             
                def _wandb_use(
         | 
| 86 88 | 
             
                    name: str,
         | 
| 87 89 | 
             
                    data: nn.Module,
         | 
| @@ -98,7 +100,7 @@ try: | |
| 98 100 | 
             
                        run.use_artifact(f"{name}:latest")
         | 
| 99 101 | 
             
                        wandb.termlog(f"Using artifact: {name} ({type(data)})")
         | 
| 100 102 |  | 
| 101 | 
            -
                @ | 
| 103 | 
            +
                @dispatch  # noqa: F811
         | 
| 102 104 | 
             
                def wandb_track(
         | 
| 103 105 | 
             
                    name: str,
         | 
| 104 106 | 
             
                    data: nn.Module,
         | 
| @@ -126,7 +128,7 @@ except ImportError: | |
| 126 128 | 
             
            try:
         | 
| 127 129 | 
             
                from sklearn.base import BaseEstimator
         | 
| 128 130 |  | 
| 129 | 
            -
                @ | 
| 131 | 
            +
                @dispatch  # noqa: F811
         | 
| 130 132 | 
             
                def _wandb_use(
         | 
| 131 133 | 
             
                    name: str,
         | 
| 132 134 | 
             
                    data: BaseEstimator,
         | 
| @@ -143,7 +145,7 @@ try: | |
| 143 145 | 
             
                        run.use_artifact(f"{name}:latest")
         | 
| 144 146 | 
             
                        wandb.termlog(f"Using artifact: {name} ({type(data)})")
         | 
| 145 147 |  | 
| 146 | 
            -
                @ | 
| 148 | 
            +
                @dispatch  # noqa: F811
         | 
| 147 149 | 
             
                def wandb_track(
         | 
| 148 150 | 
             
                    name: str,
         | 
| 149 151 | 
             
                    data: BaseEstimator,
         | 
| @@ -192,10 +194,10 @@ class ArtifactProxy: | |
| 192 194 | 
             
                    return getattr(self.flow, key)
         | 
| 193 195 |  | 
| 194 196 |  | 
| 195 | 
            -
            @ | 
| 197 | 
            +
            @dispatch  # noqa: F811
         | 
| 196 198 | 
             
            def wandb_track(
         | 
| 197 199 | 
             
                name: str,
         | 
| 198 | 
            -
                data:  | 
| 200 | 
            +
                data: Union[dict, list, set, str, int, float, bool],
         | 
| 199 201 | 
             
                run=None,
         | 
| 200 202 | 
             
                testing=False,
         | 
| 201 203 | 
             
                *args,
         | 
| @@ -207,7 +209,7 @@ def wandb_track( | |
| 207 209 | 
             
                run.log({name: data})
         | 
| 208 210 |  | 
| 209 211 |  | 
| 210 | 
            -
            @ | 
| 212 | 
            +
            @dispatch  # noqa: F811
         | 
| 211 213 | 
             
            def wandb_track(
         | 
| 212 214 | 
             
                name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs
         | 
| 213 215 | 
             
            ):
         | 
| @@ -225,7 +227,7 @@ def wandb_track( | |
| 225 227 |  | 
| 226 228 |  | 
| 227 229 | 
             
            # this is the base case
         | 
| 228 | 
            -
            @ | 
| 230 | 
            +
            @dispatch  # noqa: F811
         | 
| 229 231 | 
             
            def wandb_track(
         | 
| 230 232 | 
             
                name: str, data, others=False, run=None, testing=False, *args, **kwargs
         | 
| 231 233 | 
             
            ):
         | 
| @@ -240,7 +242,7 @@ def wandb_track( | |
| 240 242 | 
             
                    wandb.termlog(f"Logging artifact: {name} ({type(data)})")
         | 
| 241 243 |  | 
| 242 244 |  | 
| 243 | 
            -
            @ | 
| 245 | 
            +
            @dispatch  # noqa: F811
         | 
| 244 246 | 
             
            def wandb_use(name: str, data, *args, **kwargs):
         | 
| 245 247 | 
             
                try:
         | 
| 246 248 | 
             
                    return _wandb_use(name, data, *args, **kwargs)
         | 
| @@ -252,14 +254,14 @@ def wandb_use(name: str, data, *args, **kwargs): | |
| 252 254 | 
             
                    )
         | 
| 253 255 |  | 
| 254 256 |  | 
| 255 | 
            -
            @ | 
| 257 | 
            +
            @dispatch  # noqa: F811
         | 
| 256 258 | 
             
            def wandb_use(
         | 
| 257 | 
            -
                name: str, data:  | 
| 259 | 
            +
                name: str, data: Union[dict, list, set, str, int, float, bool], *args, **kwargs
         | 
| 258 260 | 
             
            ):  # type: ignore
         | 
| 259 261 | 
             
                pass  # do nothing for these types
         | 
| 260 262 |  | 
| 261 263 |  | 
| 262 | 
            -
            @ | 
| 264 | 
            +
            @dispatch  # noqa: F811
         | 
| 263 265 | 
             
            def _wandb_use(
         | 
| 264 266 | 
             
                name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs
         | 
| 265 267 | 
             
            ):  # type: ignore
         | 
| @@ -271,7 +273,7 @@ def _wandb_use( | |
| 271 273 | 
             
                    wandb.termlog(f"Using artifact: {name} ({type(data)})")
         | 
| 272 274 |  | 
| 273 275 |  | 
| 274 | 
            -
            @ | 
| 276 | 
            +
            @dispatch  # noqa: F811
         | 
| 275 277 | 
             
            def _wandb_use(name: str, data, others=False, run=None, testing=False, *args, **kwargs):  # type: ignore
         | 
| 276 278 | 
             
                if testing:
         | 
| 277 279 | 
             
                    return "others" if others else None
         | 
| @@ -23,7 +23,7 @@ class WandbObserver(RunObserver): | |
| 23 23 | 
             
                    job_type — the type of job you are logging, e.g. eval, worker, ps (default: training)
         | 
| 24 24 | 
             
                    save_code — save the main python or notebook file to wandb to enable diffing (default: editable from your settings page)
         | 
| 25 25 | 
             
                    group — a string by which to group other runs; see Grouping
         | 
| 26 | 
            -
                    reinit —  | 
| 26 | 
            +
                    reinit — Shorthand for the reinit setting that defines what to do when `wandb.init()` is called while a run is active. See the setting's documentation.
         | 
| 27 27 | 
             
                    id — A unique ID for this run primarily used for Resuming. It must be globally unique, and if you delete a run you can't reuse the ID. Use the name field for a descriptive, useful name for the run. The ID cannot contain special characters.
         | 
| 28 28 | 
             
                    resume — if set to True, the run auto resumes; can also be a unique string for manual resuming; see Resuming (default: False)
         | 
| 29 29 | 
             
                    anonymous — can be "allow", "never", or "must". This enables or explicitly disables anonymous logging. (default: never)
         |