wandb 0.19.6__py3-none-macosx_11_0_arm64.whl → 0.19.7__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +25 -5
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +11 -20
- wandb/env.py +10 -0
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +1 -1
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +1 -1
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +1 -1
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +43 -88
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +26 -10
- wandb/sdk/interface/interface_queue.py +5 -8
- wandb/sdk/interface/interface_relay.py +1 -6
- wandb/sdk/interface/interface_shared.py +21 -99
- wandb/sdk/interface/interface_sock.py +2 -13
- wandb/sdk/interface/router.py +21 -15
- wandb/sdk/interface/router_queue.py +2 -1
- wandb/sdk/interface/router_relay.py +2 -1
- wandb/sdk/interface/router_sock.py +5 -4
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +0 -18
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/service_connection.py +1 -1
- wandb/sdk/lib/sock_client.py +7 -7
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/handles.py +199 -0
- wandb/sdk/mailbox/mailbox.py +121 -0
- wandb/sdk/mailbox/wait_with_progress.py +134 -0
- wandb/sdk/service/server_sock.py +5 -1
- wandb/sdk/service/streams.py +66 -74
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +61 -61
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +82 -87
- wandb/sdk/wandb_settings.py +3 -3
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -4
- wandb/util.py +3 -1
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/METADATA +2 -2
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/RECORD +71 -58
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/WHEEL +0 -0
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/licenses/LICENSE +0 -0
| @@ -0,0 +1,573 @@ | |
| 1 | 
            +
            """Public API: registries."""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import json
         | 
| 4 | 
            +
            from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            if TYPE_CHECKING:
         | 
| 7 | 
            +
                from wandb_gql import Client
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from wandb_gql import gql
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import wandb
         | 
| 12 | 
            +
            from wandb.apis.paginator import Paginator
         | 
| 13 | 
            +
            from wandb.apis.public.artifacts import ArtifactCollection
         | 
| 14 | 
            +
            from wandb.sdk.artifacts._graphql_fragments import (
         | 
| 15 | 
            +
                _gql_artifact_fragment,
         | 
| 16 | 
            +
                _gql_registry_fragment,
         | 
| 17 | 
            +
            )
         | 
| 18 | 
            +
            from wandb.sdk.artifacts._validators import REGISTRY_PREFIX
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            class Registries(Paginator):
                """Paginated iterator over the registries of an organization.

                Pages are fetched with the ``Registries`` GraphQL query below and every
                project node on a page is wrapped in a :class:`Registry`.
                """

                QUERY = gql(
                    """
                    query Registries($organization: String!, $filters: JSONString, $cursor: String, $perPage: Int) {
                        organization(name: $organization) {
                            orgEntity {
                                name
                                projects(filters: $filters, after: $cursor, first: $perPage) {
                                    pageInfo {
                                        endCursor
                                        hasNextPage
                                    }
                                    edges {
                                        node {
                                            ...RegistryFragment
                                        }
                                    }
                                }
                            }
                        }
                    }
                    """
                    + _gql_registry_fragment()
                )

                def __init__(
                    self,
                    client: "Client",
                    organization: str,
                    filter: Optional[Dict[str, Any]] = None,
                    per_page: Optional[int] = 100,
                ):
                    """Set up the iterator; no request is made until iteration starts.

                    Args:
                        client: GraphQL client used to execute ``QUERY``.
                        organization: Name of the organization whose registries are listed.
                        filter: Optional project filter. It is normalized by
                            ``_ensure_registry_prefix_on_names`` and sent JSON-encoded as
                            the ``filters`` query variable.
                        per_page: Page size passed to the ``Paginator`` base class.
                    """
                    self.client = client
                    self.organization = organization
                    # NOTE(review): the helper (defined elsewhere) presumably adds
                    # REGISTRY_PREFIX to name constraints in the filter — confirm.
                    self.filter = _ensure_registry_prefix_on_names(filter or {})
                    query_vars = {
                        "organization": organization,
                        "filters": json.dumps(self.filter),
                    }
                    super().__init__(client, query_vars, per_page)

                def __bool__(self):
                    # Truthy when the (possibly freshly loaded) page reports entries, or
                    # when items have already been accumulated.
                    return bool(len(self)) or bool(self.objects)

                def __next__(self):
                    # A fetched page can be empty (e.g. entries hidden by permissions), so
                    # keep loading pages until the requested index exists or loading fails.
                    self.index += 1
                    while self.index >= len(self.objects):
                        loaded = self._load_page()
                        if not loaded:
                            raise StopIteration
                    return self.objects[self.index]

                def collections(self, filter: Optional[Dict[str, Any]] = None) -> "Collections":
                    """Return collections across the registries matched by this iterator."""
                    return Collections(
                        self.client,
                        self.organization,
                        registry_filter=self.filter,
                        collection_filter=filter,
                    )

                def versions(self, filter: Optional[Dict[str, Any]] = None) -> "Versions":
                    """Return artifact versions across the registries matched by this iterator."""
                    return Versions(
                        self.client,
                        self.organization,
                        registry_filter=self.filter,
                        collection_filter=None,
                        artifact_filter=filter,
                    )

                def _projects_page(self):
                    # The ``projects`` connection object of the most recently fetched page.
                    return self.last_response["organization"]["orgEntity"]["projects"]

                @property
                def length(self):
                    # Number of edges on the current page (the query has no totalCount);
                    # None until a page has been fetched.
                    if not self.last_response:
                        return None
                    return len(self._projects_page()["edges"])

                @property
                def more(self):
                    # Before the first fetch we must try to load a page.
                    if not self.last_response:
                        return True
                    return self._projects_page()["pageInfo"]["hasNextPage"]

                @property
                def cursor(self):
                    # End cursor of the current page, or None before the first fetch.
                    if not self.last_response:
                        return None
                    return self._projects_page()["pageInfo"]["endCursor"]

                def convert_objects(self):
                    """Turn the current page into ``Registry`` objects.

                    Raises:
                        ValueError: if the response has no organization or org entity.
                    """
                    response = self.last_response
                    if not response:
                        return []
                    org = response["organization"]
                    if not org or not org["orgEntity"]:
                        raise ValueError(
                            f"Organization '{self.organization}' not found. Please verify the organization name is correct"
                        )
                    org_entity = org["orgEntity"]
                    entity_name = org_entity["name"]
                    return [
                        Registry(
                            self.client,
                            self.organization,
                            entity_name,
                            edge["node"]["name"],
                            edge["node"],
                        )
                        for edge in org_entity["projects"]["edges"]
                    ]
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            class Registry:
                """A single registry in the Registry.

                Holds the fields of one registry project node. The node is presumably
                shaped by ``RegistryFragment`` (defined elsewhere — confirm). It also
                offers helpers to list that registry's collections and versions.
                """

                def __init__(
                    self,
                    client: "Client",
                    organization: str,
                    entity: str,
                    full_name: str,
                    attrs: Dict[str, Any],
                ):
                    """Build a registry from a project node.

                    Args:
                        client: GraphQL client reused by ``collections``/``versions``.
                        organization: Organization name the registry belongs to.
                        entity: Name of the organization's entity that owns the project.
                        full_name: Project name as stored on the server, including
                            ``REGISTRY_PREFIX`` when present.
                        attrs: The raw node dict. Keys read here are ``description``,
                            ``allowAllArtifactTypesInRegistry``, ``artifactTypes``,
                            ``id``, ``createdAt`` and ``updatedAt``.
                    """
                    self.client = client
                    self._full_name = full_name
                    # Strip only a *leading* prefix. ``str.replace`` would also delete
                    # any later occurrence of the prefix text inside the name itself.
                    if full_name.startswith(REGISTRY_PREFIX):
                        self._name = full_name[len(REGISTRY_PREFIX) :]
                    else:
                        self._name = full_name
                    self._entity = entity
                    self._organization = organization
                    self._description = attrs.get("description", "")
                    self._allow_all_artifact_types = attrs.get(
                        "allowAllArtifactTypesInRegistry", False
                    )
                    # GraphQL may return null for a connection or its edges. Treat that as
                    # "no artifact types" instead of raising AttributeError/TypeError.
                    artifact_types = attrs.get("artifactTypes") or {}
                    self._artifact_types = [
                        t["node"]["name"] for t in (artifact_types.get("edges") or [])
                    ]
                    self._id = attrs.get("id", "")
                    self._created_at = attrs.get("createdAt", "")
                    self._updated_at = attrs.get("updatedAt", "")

                @property
                def full_name(self) -> str:
                    """Project name as stored on the server (prefix included)."""
                    return self._full_name

                @property
                def name(self) -> str:
                    """Registry name with a leading ``REGISTRY_PREFIX`` removed."""
                    return self._name

                @property
                def entity(self) -> str:
                    """Name of the entity that owns the registry project."""
                    return self._entity

                @property
                def organization(self) -> str:
                    """Organization name this registry was listed under."""
                    return self._organization

                @property
                def description(self):
                    """The node's ``description`` value ("" if the key is absent)."""
                    return self._description

                @property
                def allow_all_artifact_types(self):
                    """The node's ``allowAllArtifactTypesInRegistry`` value (False if absent)."""
                    return self._allow_all_artifact_types

                @property
                def artifact_types(self):
                    """Names taken from the node's ``artifactTypes`` edges."""
                    return self._artifact_types

                @property
                def created_at(self):
                    """The node's ``createdAt`` value ("" if absent)."""
                    return self._created_at

                @property
                def updated_at(self):
                    """The node's ``updatedAt`` value ("" if absent)."""
                    return self._updated_at

                @property
                def path(self):
                    """``[entity, name]``, using the prefix-stripped name."""
                    return [self.entity, self.name]

                def collections(self, filter: Optional[Dict[str, Any]] = None):
                    """Return artifact collections in this registry, optionally filtered."""
                    # Match the registry project by its full (prefixed) server-side name.
                    registry_filter = {
                        "name": self.full_name,
                    }
                    return Collections(self.client, self.organization, registry_filter, filter)

                def versions(self, filter: Optional[Dict[str, Any]] = None):
                    """Return artifact versions in this registry, optionally filtered."""
                    registry_filter = {
                        "name": self.full_name,
                    }
                    return Versions(self.client, self.organization, registry_filter, None, filter)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
             | 
| 225 | 
            +
            class Collections(Paginator):
                """Paginated iterator over artifact collections in the Registry.

                Runs the ``Collections`` GraphQL query below. It always requests the
                ``PORTFOLIO`` collection type and wraps each returned node in an
                ``ArtifactCollection``.
                """

                QUERY = gql(
                    """
                    query Collections(
                        $organization: String!,
                        $registryFilter: JSONString,
                        $collectionFilter: JSONString,
                        $collectionTypes: [ArtifactCollectionType!],
                        $cursor: String,
                        $perPage: Int
                    ) {
                        organization(name: $organization) {
                            orgEntity {
                                name
                                artifactCollections(
                                    projectFilters: $registryFilter,
                                    filters: $collectionFilter,
                                    collectionTypes: $collectionTypes,
                                    after: $cursor,
                                    first: $perPage
                                ) {
                                    totalCount
                                    pageInfo {
                                        endCursor
                                        hasNextPage
                                    }
                                    edges {
                                        cursor
                                        node {
                                            id
                                            name
                                            description
                                            createdAt
                                            tags {
                                                edges {
                                                    node {
                                                        name
                                                    }
                                                }
                                            }
                                            project {
                                                name
                                                entity {
                                                    name
                                                }
                                            }
                                            defaultArtifactType {
                                                name
                                            }
                                            aliases {
                                                edges {
                                                    node {
                                                        alias
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                    """
                )

                def __init__(
                    self,
                    client: "Client",
                    organization: str,
                    registry_filter: Optional[Dict[str, Any]] = None,
                    collection_filter: Optional[Dict[str, Any]] = None,
                    per_page: Optional[int] = 100,
                ):
                    """Set up the iterator; no request is made until iteration starts.

                    Args:
                        client: GraphQL client used to execute ``QUERY``.
                        organization: Organization to search.
                        registry_filter: Filter on registry projects, sent JSON-encoded as
                            ``registryFilter`` (omitted, i.e. None, when empty).
                        collection_filter: Filter on collections, sent JSON-encoded as
                            ``collectionFilter`` (omitted, i.e. None, when empty).
                        per_page: Page size.
                    """
                    self.client = client
                    self.organization = organization
                    self.registry_filter = registry_filter
                    self.collection_filter = collection_filter or {}

                    # Empty/None filters are sent as null rather than "{}".
                    if self.registry_filter:
                        encoded_registry_filter = json.dumps(self.registry_filter)
                    else:
                        encoded_registry_filter = None
                    if self.collection_filter:
                        encoded_collection_filter = json.dumps(self.collection_filter)
                    else:
                        encoded_collection_filter = None

                    query_vars = {
                        "registryFilter": encoded_registry_filter,
                        "collectionFilter": encoded_collection_filter,
                        "organization": self.organization,
                        "collectionTypes": ["PORTFOLIO"],
                        "perPage": per_page,
                    }
                    super().__init__(client, query_vars, per_page)

                def __bool__(self):
                    # Truthy when the reported total is non-zero, or items were already loaded.
                    return bool(len(self)) or bool(self.objects)

                def __next__(self):
                    # A fetched page can be empty (e.g. entries hidden by permissions), so
                    # keep loading pages until the requested index exists or loading fails.
                    self.index += 1
                    while self.index >= len(self.objects):
                        loaded = self._load_page()
                        if not loaded:
                            raise StopIteration
                    return self.objects[self.index]

                def versions(self, filter: Optional[Dict[str, Any]] = None) -> "Versions":
                    """Return artifact versions within the collections matched here."""
                    return Versions(
                        self.client,
                        self.organization,
                        registry_filter=self.registry_filter,
                        collection_filter=self.collection_filter,
                        artifact_filter=filter,
                    )

                def _collections_page(self):
                    # The ``artifactCollections`` connection of the latest fetched page.
                    return self.last_response["organization"]["orgEntity"]["artifactCollections"]

                @property
                def length(self):
                    # Server-reported ``totalCount``; None until a page has been fetched.
                    if not self.last_response:
                        return None
                    return self._collections_page()["totalCount"]

                @property
                def more(self):
                    # Before the first fetch we must try to load a page.
                    if not self.last_response:
                        return True
                    return self._collections_page()["pageInfo"]["hasNextPage"]

                @property
                def cursor(self):
                    # End cursor of the current page, or None before the first fetch.
                    if not self.last_response:
                        return None
                    return self._collections_page()["pageInfo"]["endCursor"]

                def convert_objects(self):
                    """Turn the current page into ``ArtifactCollection`` objects.

                    Raises:
                        ValueError: if the response has no organization or org entity.
                    """
                    response = self.last_response
                    if not response:
                        return []
                    org = response["organization"]
                    if not org or not org["orgEntity"]:
                        raise ValueError(
                            f"Organization '{self.organization}' not found. Please verify the organization name is correct"
                        )

                    converted = []
                    for edge in org["orgEntity"]["artifactCollections"]["edges"]:
                        node = edge["node"]
                        project = node["project"]
                        converted.append(
                            ArtifactCollection(
                                self.client,
                                project["entity"]["name"],
                                project["name"],
                                node["name"],
                                node["defaultArtifactType"]["name"],
                                self.organization,
                                node,
                            )
                        )
                    return converted
         | 
| 391 | 
            +
             | 
| 392 | 
            +
             | 
| 393 | 
            +
            class Versions(Paginator):
                """Iterator that returns Artifact versions in the Registry.

                Pages through ``organization.orgEntity.artifactMemberships`` for one
                organization and yields one ``wandb.Artifact`` per membership edge.
                The optional registry, collection, and artifact filters are
                JSON-encoded and passed to the query as ``projectFilters``,
                ``collectionFilters``, and ``filters`` respectively.
                """

                def __init__(
                    self,
                    client: "Client",
                    organization: str,
                    registry_filter: Optional[Dict[str, Any]] = None,
                    collection_filter: Optional[Dict[str, Any]] = None,
                    artifact_filter: Optional[Dict[str, Any]] = None,
                    per_page: int = 100,
                ):
                    self.client = client
                    self.organization = organization
                    self.registry_filter = registry_filter
                    self.collection_filter = collection_filter
                    self.artifact_filter = artifact_filter or {}
                    self.QUERY = gql(
                        """
                        query Versions(
                            $organization: String!,
                            $registryFilter: JSONString,
                            $collectionFilter: JSONString,
                            $artifactFilter: JSONString,
                            $cursor: String,
                            $perPage: Int
                        ) {
                            organization(name: $organization) {
                                orgEntity {
                                    name
                                    artifactMemberships(
                                        projectFilters: $registryFilter,
                                        collectionFilters: $collectionFilter,
                                        filters: $artifactFilter,
                                        after: $cursor,
                                        first: $perPage
                                    ) {
                                        pageInfo {
                                            endCursor
                                            hasNextPage
                                        }
                                        edges {
                                            node {
                                                artifactCollection {
                                                    project {
                                                        name
                                                        entity {
                                                            name
                                                        }
                                                    }
                                                    name
                                                }
                                                versionIndex
                                                artifact {
                                                    ...ArtifactFragment
                                                }
                                                aliases {
                                                    alias
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        """
                        + _gql_artifact_fragment(include_aliases=False)
                    )

                    # Empty or missing filters are sent as null rather than as "{}".
                    variables = {
                        "registryFilter": json.dumps(self.registry_filter)
                        if self.registry_filter
                        else None,
                        "collectionFilter": json.dumps(self.collection_filter)
                        if self.collection_filter
                        else None,
                        "artifactFilter": json.dumps(self.artifact_filter)
                        if self.artifact_filter
                        else None,
                        "organization": self.organization,
                    }

                    super().__init__(client, variables, per_page)

                def __next__(self):
                    # Custom __next__ because it's possible to load empty pages
                    # (because of auth), so keep loading pages until one yields an
                    # object or no more pages can be loaded.
                    self.index += 1
                    while len(self.objects) <= self.index:
                        if not self._load_page():
                            raise StopIteration
                    return self.objects[self.index]

                def __bool__(self):
                    return len(self) > 0 or len(self.objects) > 0

                def _memberships(self) -> Dict[str, Any]:
                    """Return the ``artifactMemberships`` connection of the last response."""
                    return self.last_response["organization"]["orgEntity"][
                        "artifactMemberships"
                    ]

                @property
                def length(self):
                    # Number of edges in the last fetched response; None before any fetch.
                    if self.last_response:
                        return len(self._memberships()["edges"])
                    else:
                        return None

                @property
                def more(self):
                    # Report that a page is available until a response says otherwise.
                    if self.last_response:
                        return self._memberships()["pageInfo"]["hasNextPage"]
                    else:
                        return True

                @property
                def cursor(self):
                    # End cursor of the last fetched page; None before any fetch.
                    if self.last_response:
                        return self._memberships()["pageInfo"]["endCursor"]
                    else:
                        return None

                def _artifact_from_node(self, node: Dict[str, Any]) -> "wandb.Artifact":
                    """Build a ``wandb.Artifact`` from one ``artifactMemberships`` node.

                    The artifact name is ``<collection name>:v<versionIndex>`` and its
                    aliases are taken from the node's ``aliases`` list.
                    """
                    collection = node["artifactCollection"]
                    return wandb.Artifact._from_attrs(
                        collection["project"]["entity"]["name"],
                        collection["project"]["name"],
                        collection["name"] + ":v" + str(node["versionIndex"]),
                        node["artifact"],
                        self.client,
                        [alias["alias"] for alias in node["aliases"]],
                    )

                def convert_objects(self):
                    """Convert the last fetched page into a list of ``wandb.Artifact``.

                    Returns an empty list when ``last_response`` is empty.

                    Raises:
                        ValueError: If the response has no organization, or the
                            organization has no org entity.
                    """
                    if not self.last_response:
                        return []
                    if (
                        not self.last_response["organization"]
                        or not self.last_response["organization"]["orgEntity"]
                    ):
                        raise ValueError(
                            f"Organization '{self.organization}' not found. Please verify the organization name is correct"
                        )

                    # Materialize a list (like the collections paginator) instead of a
                    # lazy generator, so the result supports len() and re-iteration.
                    return [
                        self._artifact_from_node(edge["node"])
                        for edge in self._memberships()["edges"]
                    ]
         | 
| 544 | 
            +
             | 
| 545 | 
            +
             | 
| 546 | 
            +
            def _ensure_registry_prefix_on_names(query, in_name=False):
         | 
| 547 | 
            +
                """Traverse the filter to prepend the `name` key value with the registry prefix unless the value is a regex.
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                - in_name: True if we are under a "name" key (or propagating from one).
         | 
| 550 | 
            +
             | 
| 551 | 
            +
                EX: {"name": "model"} -> {"name": "wandb-registry-model"}
         | 
| 552 | 
            +
                """
         | 
| 553 | 
            +
                if isinstance((txt := query), str):
         | 
| 554 | 
            +
                    if in_name:
         | 
| 555 | 
            +
                        return txt if txt.startswith(REGISTRY_PREFIX) else f"{REGISTRY_PREFIX}{txt}"
         | 
| 556 | 
            +
                    return txt
         | 
| 557 | 
            +
                if isinstance((dct := query), Mapping):
         | 
| 558 | 
            +
                    new_dict = {}
         | 
| 559 | 
            +
                    for key, obj in dct.items():
         | 
| 560 | 
            +
                        if key == "name":
         | 
| 561 | 
            +
                            new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=True)
         | 
| 562 | 
            +
                        elif key == "$regex":
         | 
| 563 | 
            +
                            # For regex operator, we skip transformation of its value.
         | 
| 564 | 
            +
                            new_dict[key] = obj
         | 
| 565 | 
            +
                        else:
         | 
| 566 | 
            +
                            # For any other key, propagate the in_name and skip_transform flags as-is.
         | 
| 567 | 
            +
                            new_dict[key] = _ensure_registry_prefix_on_names(obj, in_name=in_name)
         | 
| 568 | 
            +
                    return new_dict
         | 
| 569 | 
            +
                if isinstance((objs := query), Sequence):
         | 
| 570 | 
            +
                    return list(
         | 
| 571 | 
            +
                        map(lambda x: _ensure_registry_prefix_on_names(x, in_name=in_name), objs)
         | 
| 572 | 
            +
                    )
         | 
| 573 | 
            +
                return query
         | 
    
        wandb/apis/public/utils.py
    CHANGED
    
    | @@ -1,8 +1,11 @@ | |
| 1 1 | 
             
            import re
         | 
| 2 2 | 
             
            from enum import Enum
         | 
| 3 | 
            +
            from typing import Optional
         | 
| 3 4 | 
             
            from urllib.parse import urlparse
         | 
| 4 5 |  | 
| 6 | 
            +
            from wandb._iterutils import one
         | 
| 5 7 | 
             
            from wandb.sdk.artifacts._validators import is_artifact_registry_project
         | 
| 8 | 
            +
            from wandb.sdk.internal.internal_api import Api as InternalApi
         | 
| 6 9 |  | 
| 7 10 |  | 
| 8 11 | 
             
            def parse_s3_url_to_s3_uri(url) -> str:
         | 
| @@ -66,3 +69,36 @@ def parse_org_from_registry_path(path: str, path_type: PathType) -> str: | |
| 66 69 | 
             
                    if is_artifact_registry_project(project):
         | 
| 67 70 | 
             
                        return org
         | 
| 68 71 | 
             
                return ""
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            def fetch_org_from_settings_or_entity(
                settings: dict, default_entity: Optional[str] = None
            ) -> str:
                """Fetch the org from the settings, or derive it from the entity.

                If ``settings`` has a non-None ``"organization"`` value, it is returned
                as is. Otherwise the entity (``settings["entity"]`` if truthy, else
                ``default_entity``) is used to look up its organizations, and the
                display name of the single organization found is returned.

                Args:
                    settings (dict): Settings mapping; its ``"organization"`` and
                        ``"entity"`` keys are read if present.
                    default_entity (str | None): Entity to fall back on when the
                        settings do not provide one.

                Returns:
                    str: The organization name.

                Raises:
                    ValueError: If no (non-empty) entity is available, if no
                        organizations are found for the entity, or if multiple
                        organizations are found for it.
                """
                organization = settings.get("organization")
                if organization is not None:
                    return organization

                # Fetch the org via the entity. This won't work if the entity is a
                # personal entity that belongs to multiple orgs.
                entity = settings.get("entity") or default_entity
                if not entity:
                    # An empty-string entity cannot be used for the lookup either.
                    raise ValueError(
                        "No entity specified and can't fetch organization from the entity"
                    )
                entity_orgs = InternalApi()._fetch_orgs_and_org_entities_from_entity(entity)
                entity_org = one(
                    entity_orgs,
                    too_short=ValueError(
                        "No organizations found for entity. Please specify an organization in the settings."
                    ),
                    too_long=ValueError(
                        "Multiple organizations found for entity. Please specify an organization in the settings."
                    ),
                )
                return entity_org.display_name
         | 
    
        wandb/bin/gpu_stats
    CHANGED
    
    | Binary file | 
    
        wandb/bin/wandb-core
    CHANGED
    
    | Binary file | 
    
        wandb/cli/cli.py
    CHANGED
    
    | @@ -241,32 +241,21 @@ def login(key, host, cloud, relogin, anonymously, verify, no_offline=False): | |
| 241 241 | 
             
                wandb_sdk.wandb_login._handle_host_wandb_setting(host, cloud)
         | 
| 242 242 | 
             
                # A change in click or the test harness means key can be none...
         | 
| 243 243 | 
             
                key = key[0] if key is not None and len(key) > 0 else None
         | 
| 244 | 
            -
                if key | 
| 245 | 
            -
                    relogin = True
         | 
| 244 | 
            +
                relogin = True if key or relogin else False
         | 
| 246 245 |  | 
| 247 | 
            -
                 | 
| 248 | 
            -
                     | 
| 249 | 
            -
             | 
| 250 | 
            -
             | 
| 251 | 
            -
                    base_url=host,
         | 
| 252 | 
            -
                )
         | 
| 253 | 
            -
             | 
| 254 | 
            -
                try:
         | 
| 255 | 
            -
                    wandb.setup(
         | 
| 256 | 
            -
                        settings=wandb.Settings(
         | 
| 257 | 
            -
                            **{k: v for k, v in login_settings.items() if v is not None}
         | 
| 258 | 
            -
                        )
         | 
| 246 | 
            +
                wandb.setup(
         | 
| 247 | 
            +
                    settings=wandb.Settings(
         | 
| 248 | 
            +
                        x_cli_only_mode=True,
         | 
| 249 | 
            +
                        x_disable_viewer=relogin and not verify,
         | 
| 259 250 | 
             
                    )
         | 
| 260 | 
            -
                 | 
| 261 | 
            -
                    wandb.termerror(str(e))
         | 
| 262 | 
            -
                    sys.exit(1)
         | 
| 251 | 
            +
                )
         | 
| 263 252 |  | 
| 264 253 | 
             
                wandb.login(
         | 
| 265 | 
            -
                    relogin=relogin,
         | 
| 266 | 
            -
                    key=key,
         | 
| 267 254 | 
             
                    anonymous=anon_mode,
         | 
| 268 | 
            -
                    host=host,
         | 
| 269 255 | 
             
                    force=True,
         | 
| 256 | 
            +
                    host=host,
         | 
| 257 | 
            +
                    key=key,
         | 
| 258 | 
            +
                    relogin=relogin,
         | 
| 270 259 | 
             
                    verify=verify,
         | 
| 271 260 | 
             
                )
         | 
| 272 261 |  | 
| @@ -2805,11 +2794,13 @@ def verify(host): | |
| 2805 2794 | 
             
                wandb_verify.check_wandb_version(api)
         | 
| 2806 2795 | 
             
                check_run_success = wandb_verify.check_run(api)
         | 
| 2807 2796 | 
             
                check_artifacts_success = wandb_verify.check_artifacts()
         | 
| 2797 | 
            +
                check_sweeps_success = wandb_verify.check_sweeps(api)
         | 
| 2808 2798 | 
             
                if not (
         | 
| 2809 2799 | 
             
                    check_artifacts_success
         | 
| 2810 2800 | 
             
                    and check_run_success
         | 
| 2811 2801 | 
             
                    and large_post_success
         | 
| 2812 2802 | 
             
                    and url_success
         | 
| 2803 | 
            +
                    and check_sweeps_success
         | 
| 2813 2804 | 
             
                ):
         | 
| 2814 2805 | 
             
                    sys.exit(1)
         | 
| 2815 2806 |  | 
    
        wandb/env.py
    CHANGED
    
    | @@ -34,6 +34,7 @@ USERNAME = "WANDB_USERNAME" | |
| 34 34 | 
             
            USER_EMAIL = "WANDB_USER_EMAIL"
         | 
| 35 35 | 
             
            PROJECT = "WANDB_PROJECT"
         | 
| 36 36 | 
             
            ENTITY = "WANDB_ENTITY"
         | 
| 37 | 
            +
            ORGANIZATION = "WANDB_ORGANIZATION"
         | 
| 37 38 | 
             
            BASE_URL = "WANDB_BASE_URL"
         | 
| 38 39 | 
             
            APP_URL = "WANDB_APP_URL"
         | 
| 39 40 | 
             
            PROGRAM = "WANDB_PROGRAM"
         | 
| @@ -284,6 +285,15 @@ def get_entity( | |
| 284 285 | 
             
                return env.get(ENTITY, default)
         | 
| 285 286 |  | 
| 286 287 |  | 
| 288 | 
            +
            def get_organization(
                default: str | None = None, env: MutableMapping | None = None
            ) -> str | None:
                """Read the organization from the ``WANDB_ORGANIZATION`` variable.

                Args:
                    default: Value returned when the variable is not set.
                    env: Mapping to read from; ``os.environ`` is used when omitted.

                Returns:
                    The variable's value, or ``default`` if it is absent.
                """
                source = os.environ if env is None else env
                return source.get(ORGANIZATION, default)
         | 
| 295 | 
            +
             | 
| 296 | 
            +
             | 
| 287 297 | 
             
            def get_base_url(
         | 
| 288 298 | 
             
                default: str | None = None, env: MutableMapping | None = None
         | 
| 289 299 | 
             
            ) -> str | None:
         |