wandb 0.13.11__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/apis/importers/__init__.py +4 -0
- wandb/apis/importers/base.py +312 -0
- wandb/apis/importers/mlflow.py +113 -0
- wandb/apis/internal.py +9 -0
- wandb/apis/public.py +0 -2
- wandb/cli/cli.py +100 -72
- wandb/docker/__init__.py +33 -5
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/internal/internal_api.py +85 -9
- wandb/sdk/launch/_project_spec.py +45 -55
- wandb/sdk/launch/agent/agent.py +80 -18
- wandb/sdk/launch/builder/build.py +16 -74
- wandb/sdk/launch/builder/docker_builder.py +36 -8
- wandb/sdk/launch/builder/kaniko_builder.py +78 -37
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +68 -18
- wandb/sdk/launch/environment/aws_environment.py +4 -0
- wandb/sdk/launch/launch.py +1 -6
- wandb/sdk/launch/launch_add.py +0 -5
- wandb/sdk/launch/registry/abstract.py +12 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +31 -1
- wandb/sdk/launch/registry/google_artifact_registry.py +32 -0
- wandb/sdk/launch/registry/local_registry.py +15 -1
- wandb/sdk/launch/runner/abstract.py +0 -14
- wandb/sdk/launch/runner/kubernetes_runner.py +25 -19
- wandb/sdk/launch/runner/local_container.py +7 -8
- wandb/sdk/launch/runner/local_process.py +0 -3
- wandb/sdk/launch/runner/sagemaker_runner.py +0 -3
- wandb/sdk/launch/runner/vertex_runner.py +0 -2
- wandb/sdk/launch/sweeps/scheduler.py +39 -10
- wandb/sdk/launch/utils.py +52 -4
- wandb/sdk/wandb_run.py +3 -10
- wandb/sync/sync.py +1 -0
- wandb/util.py +1 -0
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/METADATA +1 -1
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/RECORD +41 -38
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
| @@ -1,13 +1,11 @@ | |
| 1 1 | 
             
            import base64
         | 
| 2 2 | 
             
            import json
         | 
| 3 | 
            +
            import logging
         | 
| 3 4 | 
             
            import tarfile
         | 
| 4 5 | 
             
            import tempfile
         | 
| 5 6 | 
             
            import time
         | 
| 6 7 | 
             
            from typing import Optional
         | 
| 7 8 |  | 
| 8 | 
            -
            import kubernetes  # type: ignore
         | 
| 9 | 
            -
            from kubernetes import client
         | 
| 10 | 
            -
             | 
| 11 9 | 
             
            import wandb
         | 
| 12 10 | 
             
            from wandb.sdk.launch.builder.abstract import AbstractBuilder
         | 
| 13 11 | 
             
            from wandb.sdk.launch.environment.abstract import AbstractEnvironment
         | 
| @@ -16,7 +14,7 @@ from wandb.sdk.launch.registry.elastic_container_registry import ( | |
| 16 14 | 
             
                ElasticContainerRegistry,
         | 
| 17 15 | 
             
            )
         | 
| 18 16 | 
             
            from wandb.sdk.launch.registry.google_artifact_registry import GoogleArtifactRegistry
         | 
| 19 | 
            -
            from wandb. | 
| 17 | 
            +
            from wandb.util import get_module
         | 
| 20 18 |  | 
| 21 19 | 
             
            from .._project_spec import (
         | 
| 22 20 | 
             
                EntryPoint,
         | 
| @@ -24,8 +22,28 @@ from .._project_spec import ( | |
| 24 22 | 
             
                create_metadata_file,
         | 
| 25 23 | 
             
                get_entry_point_command,
         | 
| 26 24 | 
             
            )
         | 
| 27 | 
            -
            from ..utils import  | 
| 28 | 
            -
             | 
| 25 | 
            +
            from ..utils import (
         | 
| 26 | 
            +
                LOG_PREFIX,
         | 
| 27 | 
            +
                LaunchError,
         | 
| 28 | 
            +
                get_kube_context_and_api_client,
         | 
| 29 | 
            +
                sanitize_wandb_api_key,
         | 
| 30 | 
            +
                warn_failed_packages_from_build_logs,
         | 
| 31 | 
            +
            )
         | 
| 32 | 
            +
            from .build import (
         | 
| 33 | 
            +
                _create_docker_build_ctx,
         | 
| 34 | 
            +
                generate_dockerfile,
         | 
| 35 | 
            +
                image_tag_from_dockerfile_and_source,
         | 
| 36 | 
            +
            )
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            get_module(
         | 
| 39 | 
            +
                "kubernetes",
         | 
| 40 | 
            +
                required="Kaniko builder requires the kubernetes package. Please install it with `pip install wandb[launch]`.",
         | 
| 41 | 
            +
            )
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            import kubernetes  # type: ignore # noqa: E402
         | 
| 44 | 
            +
            from kubernetes import client  # noqa: E402
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            _logger = logging.getLogger(__name__)
         | 
| 29 47 |  | 
| 30 48 | 
             
            _DEFAULT_BUILD_TIMEOUT_SECS = 1800  # 30 minute build timeout
         | 
| 31 49 |  | 
| @@ -155,7 +173,7 @@ class KanikoBuilder(AbstractBuilder): | |
| 155 173 | 
             
                    pass
         | 
| 156 174 |  | 
| 157 175 | 
             
                def _create_docker_ecr_config_map(
         | 
| 158 | 
            -
                    self, corev1_client: client.CoreV1Api, repository: str
         | 
| 176 | 
            +
                    self, job_name: str, corev1_client: client.CoreV1Api, repository: str
         | 
| 159 177 | 
             
                ) -> None:
         | 
| 160 178 | 
             
                    if self.registry is None:
         | 
| 161 179 | 
             
                        raise LaunchError("No registry specified for Kaniko build.")
         | 
| @@ -165,7 +183,7 @@ class KanikoBuilder(AbstractBuilder): | |
| 165 183 | 
             
                        api_version="v1",
         | 
| 166 184 | 
             
                        kind="ConfigMap",
         | 
| 167 185 | 
             
                        metadata=client.V1ObjectMeta(
         | 
| 168 | 
            -
                            name="docker-config",
         | 
| 186 | 
            +
                            name=f"docker-config-{job_name}",
         | 
| 169 187 | 
             
                            namespace="wandb",
         | 
| 170 188 | 
             
                        ),
         | 
| 171 189 | 
             
                        data={
         | 
| @@ -177,8 +195,11 @@ class KanikoBuilder(AbstractBuilder): | |
| 177 195 | 
             
                    )
         | 
| 178 196 | 
             
                    corev1_client.create_namespaced_config_map("wandb", ecr_config_map)
         | 
| 179 197 |  | 
| 180 | 
            -
                def _delete_docker_ecr_config_map( | 
| 181 | 
            -
                    client. | 
| 198 | 
            +
                def _delete_docker_ecr_config_map(
         | 
| 199 | 
            +
                    self, job_name: str, client: client.CoreV1Api
         | 
| 200 | 
            +
                ) -> None:
         | 
| 201 | 
            +
                    if self.secret_name:
         | 
| 202 | 
            +
                        client.delete_namespaced_config_map(f"docker-config-{job_name}", "wandb")
         | 
| 182 203 |  | 
| 183 204 | 
             
                def _upload_build_context(self, run_id: str, context_path: str) -> str:
         | 
| 184 205 | 
             
                    # creat a tar archive of the build context and upload it to s3
         | 
| @@ -197,20 +218,28 @@ class KanikoBuilder(AbstractBuilder): | |
| 197 218 | 
             
                    launch_project: LaunchProject,
         | 
| 198 219 | 
             
                    entrypoint: EntryPoint,
         | 
| 199 220 | 
             
                ) -> str:
         | 
| 221 | 
            +
                    # TODO: this should probably throw an error if the registry is a local registry
         | 
| 200 222 | 
             
                    if not self.registry:
         | 
| 201 223 | 
             
                        raise LaunchError("No registry specified for Kaniko build.")
         | 
| 224 | 
            +
                    # kaniko builder doesn't seem to work with a custom user id, need more investigation
         | 
| 225 | 
            +
                    dockerfile_str = generate_dockerfile(
         | 
| 226 | 
            +
                        launch_project, entrypoint, launch_project.resource, "kaniko"
         | 
| 227 | 
            +
                    )
         | 
| 228 | 
            +
                    image_tag = image_tag_from_dockerfile_and_source(launch_project, dockerfile_str)
         | 
| 202 229 | 
             
                    repo_uri = self.registry.get_repo_uri()
         | 
| 203 | 
            -
                    image_uri = repo_uri + ":" +  | 
| 230 | 
            +
                    image_uri = repo_uri + ":" + image_tag
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    if not launch_project.build_required() and self.registry.check_image_exists(
         | 
| 233 | 
            +
                        image_uri
         | 
| 234 | 
            +
                    ):
         | 
| 235 | 
            +
                        return image_uri
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    _logger.info(f"Building image {image_uri}...")
         | 
| 204 238 |  | 
| 205 239 | 
             
                    entry_cmd = " ".join(
         | 
| 206 240 | 
             
                        get_entry_point_command(entrypoint, launch_project.override_args)
         | 
| 207 241 | 
             
                    )
         | 
| 208 242 |  | 
| 209 | 
            -
                    # kaniko builder doesn't seem to work with a custom user id, need more investigation
         | 
| 210 | 
            -
                    dockerfile_str = generate_dockerfile(
         | 
| 211 | 
            -
                        launch_project, entrypoint, launch_project.resource, self.type
         | 
| 212 | 
            -
                    )
         | 
| 213 | 
            -
             | 
| 214 243 | 
             
                    create_metadata_file(
         | 
| 215 244 | 
             
                        launch_project,
         | 
| 216 245 | 
             
                        image_uri,
         | 
| @@ -240,7 +269,8 @@ class KanikoBuilder(AbstractBuilder): | |
| 240 269 |  | 
| 241 270 | 
             
                    try:
         | 
| 242 271 | 
             
                        # core_v1.create_namespaced_config_map("wandb", dockerfile_config_map)
         | 
| 243 | 
            -
                        self. | 
| 272 | 
            +
                        if self.secret_name:
         | 
| 273 | 
            +
                            self._create_docker_ecr_config_map(build_job_name, core_v1, repo_uri)
         | 
| 244 274 | 
             
                        batch_v1.create_namespaced_job("wandb", build_job)
         | 
| 245 275 |  | 
| 246 276 | 
             
                        # wait for double the job deadline since it might take time to schedule
         | 
| @@ -248,6 +278,13 @@ class KanikoBuilder(AbstractBuilder): | |
| 248 278 | 
             
                            batch_v1, build_job_name, 3 * _DEFAULT_BUILD_TIMEOUT_SECS
         | 
| 249 279 | 
             
                        ):
         | 
| 250 280 | 
             
                            raise Exception(f"Failed to build image in kaniko for job {run_id}")
         | 
| 281 | 
            +
                        try:
         | 
| 282 | 
            +
                            logs = batch_v1.read_namespaced_job_log(build_job_name, "wandb")
         | 
| 283 | 
            +
                            warn_failed_packages_from_build_logs(logs, image_uri)
         | 
| 284 | 
            +
                        except Exception as e:
         | 
| 285 | 
            +
                            wandb.termwarn(
         | 
| 286 | 
            +
                                f"{LOG_PREFIX}Failed to get logs for kaniko job {build_job_name}: {e}"
         | 
| 287 | 
            +
                            )
         | 
| 251 288 | 
             
                    except Exception as e:
         | 
| 252 289 | 
             
                        wandb.termerror(
         | 
| 253 290 | 
             
                            f"{LOG_PREFIX}Exception when creating Kubernetes resources: {e}\n"
         | 
| @@ -258,7 +295,8 @@ class KanikoBuilder(AbstractBuilder): | |
| 258 295 | 
             
                        try:
         | 
| 259 296 | 
             
                            # should we clean up the s3 build contexts? can set bucket level policy to auto deletion
         | 
| 260 297 | 
             
                            # core_v1.delete_namespaced_config_map(config_map_name, "wandb")
         | 
| 261 | 
            -
                            self. | 
| 298 | 
            +
                            if self.secret_name:
         | 
| 299 | 
            +
                                self._delete_docker_ecr_config_map(build_job_name, core_v1)
         | 
| 262 300 | 
             
                            batch_v1.delete_namespaced_job(build_job_name, "wandb")
         | 
| 263 301 | 
             
                        except Exception as e:
         | 
| 264 302 | 
             
                            raise LaunchError(f"Exception during Kubernetes resource clean up {e}")
         | 
| @@ -273,25 +311,34 @@ class KanikoBuilder(AbstractBuilder): | |
| 273 311 | 
             
                    build_context_path: str,
         | 
| 274 312 | 
             
                ) -> "client.V1Job":
         | 
| 275 313 | 
             
                    env = []
         | 
| 276 | 
            -
                    volume_mounts = [
         | 
| 277 | 
            -
             | 
| 278 | 
            -
                    ]
         | 
| 279 | 
            -
                    volumes = [
         | 
| 280 | 
            -
                        client.V1Volume(
         | 
| 281 | 
            -
                            name="docker-config",
         | 
| 282 | 
            -
                            config_map=client.V1ConfigMapVolumeSource(
         | 
| 283 | 
            -
                                name="docker-config",
         | 
| 284 | 
            -
                            ),
         | 
| 285 | 
            -
                        ),
         | 
| 286 | 
            -
                    ]
         | 
| 314 | 
            +
                    volume_mounts = []
         | 
| 315 | 
            +
                    volumes = []
         | 
| 287 316 | 
             
                    if bool(self.secret_name) != bool(self.secret_key):
         | 
| 288 317 | 
             
                        raise LaunchError(
         | 
| 289 318 | 
             
                            "Both secret_name and secret_key or neither must be specified "
         | 
| 290 319 | 
             
                            "for kaniko build. You provided only one of them."
         | 
| 291 320 | 
             
                        )
         | 
| 321 | 
            +
                    if isinstance(self.registry, ElasticContainerRegistry):
         | 
| 322 | 
            +
                        env += [
         | 
| 323 | 
            +
                            client.V1EnvVar(
         | 
| 324 | 
            +
                                name="AWS_REGION",
         | 
| 325 | 
            +
                                value=self.registry.environment.region,
         | 
| 326 | 
            +
                            )
         | 
| 327 | 
            +
                        ]
         | 
| 292 328 | 
             
                    if self.secret_name and self.secret_key:
         | 
| 293 | 
            -
                         | 
| 294 | 
            -
             | 
| 329 | 
            +
                        volumes += [
         | 
| 330 | 
            +
                            client.V1Volume(
         | 
| 331 | 
            +
                                name="docker-config",
         | 
| 332 | 
            +
                                config_map=client.V1ConfigMapVolumeSource(
         | 
| 333 | 
            +
                                    name=f"docker-config-{job_name}",
         | 
| 334 | 
            +
                                ),
         | 
| 335 | 
            +
                            ),
         | 
| 336 | 
            +
                        ]
         | 
| 337 | 
            +
                        volume_mounts += [
         | 
| 338 | 
            +
                            client.V1VolumeMount(
         | 
| 339 | 
            +
                                name="docker-config", mount_path="/kaniko/.docker/"
         | 
| 340 | 
            +
                            ),
         | 
| 341 | 
            +
                        ]
         | 
| 295 342 | 
             
                        # TODO: I don't like conditioning on the registry type here. As a
         | 
| 296 343 | 
             
                        # future change I want the registry and environment classes to provide
         | 
| 297 344 | 
             
                        # a list of environment variables and volume mounts that need to be
         | 
| @@ -303,12 +350,6 @@ class KanikoBuilder(AbstractBuilder): | |
| 303 350 | 
             
                        if isinstance(self.registry, ElasticContainerRegistry):
         | 
| 304 351 | 
             
                            mount_path = "/root/.aws"
         | 
| 305 352 | 
             
                            key = "credentials"
         | 
| 306 | 
            -
                            env += [
         | 
| 307 | 
            -
                                client.V1EnvVar(
         | 
| 308 | 
            -
                                    name="AWS_REGION",
         | 
| 309 | 
            -
                                    value=self.registry.environment.region,
         | 
| 310 | 
            -
                                )
         | 
| 311 | 
            -
                            ]
         | 
| 312 353 | 
             
                        elif isinstance(self.registry, GoogleArtifactRegistry):
         | 
| 313 354 | 
             
                            mount_path = "/kaniko/.config/gcloud"
         | 
| 314 355 | 
             
                            key = "config.json"
         | 
| @@ -1,10 +1,13 @@ | |
| 1 1 | 
             
            import json
         | 
| 2 2 | 
             
            import multiprocessing
         | 
| 3 3 | 
             
            import os
         | 
| 4 | 
            +
            import re
         | 
| 4 5 | 
             
            import subprocess
         | 
| 5 6 | 
             
            import sys
         | 
| 6 7 | 
             
            from typing import List, Optional, Set
         | 
| 7 8 |  | 
| 9 | 
            +
            FAILED_PACKAGES_PREFIX = "ERROR: Failed to install: "
         | 
| 10 | 
            +
            FAILED_PACKAGES_POSTFIX = ". During automated build process."
         | 
| 8 11 | 
             
            CORES = multiprocessing.cpu_count()
         | 
| 9 12 | 
             
            ONLY_INCLUDE = {x for x in os.getenv("WANDB_ONLY_INCLUDE", "").split(",") if x != ""}
         | 
| 10 13 | 
             
            OPTS = []
         | 
| @@ -21,7 +24,10 @@ else: | |
| 21 24 |  | 
| 22 25 |  | 
| 23 26 | 
             
            def install_deps(
         | 
| 24 | 
            -
                deps: List[str], | 
| 27 | 
            +
                deps: List[str],
         | 
| 28 | 
            +
                failed: Optional[Set[str]] = None,
         | 
| 29 | 
            +
                extra_index: Optional[str] = None,
         | 
| 30 | 
            +
                opts: Optional[List[str]] = None,
         | 
| 25 31 | 
             
            ) -> Optional[Set[str]]:
         | 
| 26 32 | 
             
                """Install pip dependencies.
         | 
| 27 33 |  | 
| @@ -35,33 +41,45 @@ def install_deps( | |
| 35 41 | 
             
                try:
         | 
| 36 42 | 
             
                    # Include only uri if @ is present
         | 
| 37 43 | 
             
                    clean_deps = [d.split("@")[-1].strip() if "@" in d else d for d in deps]
         | 
| 38 | 
            -
             | 
| 44 | 
            +
                    index_args = ["--extra-index-url", extra_index] if extra_index else []
         | 
| 39 45 | 
             
                    print("installing {}...".format(", ".join(clean_deps)))
         | 
| 46 | 
            +
                    opts = opts or []
         | 
| 47 | 
            +
                    args = ["pip", "install"] + opts + clean_deps + index_args
         | 
| 40 48 | 
             
                    sys.stdout.flush()
         | 
| 41 | 
            -
                    subprocess.check_output(
         | 
| 42 | 
            -
                        ["pip", "install"] + OPTS + clean_deps, stderr=subprocess.STDOUT
         | 
| 43 | 
            -
                    )
         | 
| 44 | 
            -
                    if failed is not None and len(failed) > 0:
         | 
| 45 | 
            -
                        sys.stderr.write(
         | 
| 46 | 
            -
                            "ERROR: Unable to install: {}".format(", ".join(clean_deps))
         | 
| 47 | 
            -
                        )
         | 
| 48 | 
            -
                        sys.stderr.flush()
         | 
| 49 | 
            +
                    subprocess.check_output(args, stderr=subprocess.STDOUT)
         | 
| 49 50 | 
             
                    return failed
         | 
| 50 51 | 
             
                except subprocess.CalledProcessError as e:
         | 
| 51 52 | 
             
                    if failed is None:
         | 
| 52 53 | 
             
                        failed = set()
         | 
| 53 54 | 
             
                    num_failed = len(failed)
         | 
| 54 | 
            -
                    for line in e.output.decode("utf8"):
         | 
| 55 | 
            +
                    for line in e.output.decode("utf8").splitlines():
         | 
| 55 56 | 
             
                        if line.startswith("ERROR:"):
         | 
| 56 | 
            -
                             | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 57 | 
            +
                            clean_dep = find_package_in_error_string(clean_deps, line)
         | 
| 58 | 
            +
                            if clean_dep is not None:
         | 
| 59 | 
            +
                                if clean_dep in deps:
         | 
| 60 | 
            +
                                    failed.add(clean_dep)
         | 
| 61 | 
            +
                                else:
         | 
| 62 | 
            +
                                    for d in deps:
         | 
| 63 | 
            +
                                        if clean_dep in d:
         | 
| 64 | 
            +
                                            failed.add(d.replace(" ", ""))
         | 
| 65 | 
            +
                                            break
         | 
| 66 | 
            +
                    if len(set(clean_deps) - failed) == 0:
         | 
| 67 | 
            +
                        return failed
         | 
| 68 | 
            +
                    elif len(failed) > num_failed:
         | 
| 69 | 
            +
                        return install_deps(
         | 
| 70 | 
            +
                            list(set(clean_deps) - failed),
         | 
| 71 | 
            +
                            failed,
         | 
| 72 | 
            +
                            extra_index=extra_index,
         | 
| 73 | 
            +
                            opts=opts,
         | 
| 74 | 
            +
                        )
         | 
| 59 75 | 
             
                    else:
         | 
| 60 76 | 
             
                        return failed
         | 
| 61 77 |  | 
| 62 78 |  | 
| 63 79 | 
             
            def main() -> None:
         | 
| 64 80 | 
             
                """Install deps in requirements.frozen.txt."""
         | 
| 81 | 
            +
                extra_index = None
         | 
| 82 | 
            +
                torch_reqs = []
         | 
| 65 83 | 
             
                if os.path.exists("requirements.frozen.txt"):
         | 
| 66 84 | 
             
                    with open("requirements.frozen.txt") as f:
         | 
| 67 85 | 
             
                        print("Installing frozen dependencies...")
         | 
| @@ -72,28 +90,60 @@ def main() -> None: | |
| 72 90 | 
             
                                # can't pip install wandb==0.*.*.dev1 through pip. Lets just install wandb for now
         | 
| 73 91 | 
             
                                if req.startswith("wandb==") and "dev1" in req:
         | 
| 74 92 | 
             
                                    req = "wandb"
         | 
| 75 | 
            -
                                 | 
| 93 | 
            +
                                match = re.match(
         | 
| 94 | 
            +
                                    r"torch(vision|audio)?==\d+\.\d+\.\d+(\+(?:cu[\d]{2,3})|(?:cpu))?",
         | 
| 95 | 
            +
                                    req,
         | 
| 96 | 
            +
                                )
         | 
| 97 | 
            +
                                if match:
         | 
| 98 | 
            +
                                    variant = match.group(2)
         | 
| 99 | 
            +
                                    if variant:
         | 
| 100 | 
            +
                                        extra_index = (
         | 
| 101 | 
            +
                                            f"https://download.pytorch.org/whl/{variant[1:]}"
         | 
| 102 | 
            +
                                        )
         | 
| 103 | 
            +
                                    torch_reqs.append(req.strip().replace(" ", ""))
         | 
| 104 | 
            +
                                else:
         | 
| 105 | 
            +
                                    reqs.append(req.strip().replace(" ", ""))
         | 
| 76 106 | 
             
                            else:
         | 
| 77 107 | 
             
                                print(f"Ignoring requirement: {req} from frozen requirements")
         | 
| 78 108 | 
             
                            if len(reqs) >= CORES:
         | 
| 79 | 
            -
                                deps_failed = install_deps(reqs)
         | 
| 109 | 
            +
                                deps_failed = install_deps(reqs, opts=OPTS)
         | 
| 80 110 | 
             
                                reqs = []
         | 
| 81 111 | 
             
                                if deps_failed is not None:
         | 
| 82 112 | 
             
                                    failed = failed.union(deps_failed)
         | 
| 83 113 | 
             
                        if len(reqs) > 0:
         | 
| 84 | 
            -
                            deps_failed = install_deps(reqs)
         | 
| 114 | 
            +
                            deps_failed = install_deps(reqs, opts=OPTS)
         | 
| 85 115 | 
             
                            if deps_failed is not None:
         | 
| 86 116 | 
             
                                failed = failed.union(deps_failed)
         | 
| 87 117 | 
             
                        with open("_wandb_bootstrap_errors.json", "w") as f:
         | 
| 88 118 | 
             
                            f.write(json.dumps({"pip": list(failed)}))
         | 
| 89 119 | 
             
                        if len(failed) > 0:
         | 
| 90 120 | 
             
                            sys.stderr.write(
         | 
| 91 | 
            -
                                 | 
| 121 | 
            +
                                FAILED_PACKAGES_PREFIX + ",".join(failed) + FAILED_PACKAGES_POSTFIX
         | 
| 92 122 | 
             
                            )
         | 
| 93 123 | 
             
                            sys.stderr.flush()
         | 
| 124 | 
            +
                    install_deps(torch_reqs, extra_index=extra_index)
         | 
| 94 125 | 
             
                else:
         | 
| 95 126 | 
             
                    print("No frozen requirements found")
         | 
| 96 127 |  | 
| 97 128 |  | 
| 129 | 
            +
            # hacky way to get the name of the requirement that failed
         | 
| 130 | 
            +
            # attempt last word which is the name of the package often
         | 
| 131 | 
            +
            # fall back to checking all words in the line for the package name
         | 
| 132 | 
            +
            def find_package_in_error_string(deps: List[str], line: str) -> Optional[str]:
         | 
| 133 | 
            +
                # if the last word in the error string is in the list of deps, return it
         | 
| 134 | 
            +
                last_word = line.split(" ")[-1]
         | 
| 135 | 
            +
                if last_word in deps:
         | 
| 136 | 
            +
                    return last_word
         | 
| 137 | 
            +
                # if the last word is not in the list of deps, check all words
         | 
| 138 | 
            +
                # TODO: this could report the wrong package if the error string
         | 
| 139 | 
            +
                # contains a reference to another package in the deps
         | 
| 140 | 
            +
                # before the package that failed to install
         | 
| 141 | 
            +
                for word in line.split(" "):
         | 
| 142 | 
            +
                    if word in deps:
         | 
| 143 | 
            +
                        return word
         | 
| 144 | 
            +
                # if we can't find the package, return None
         | 
| 145 | 
            +
                return None
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 98 148 | 
             
            if __name__ == "__main__":
         | 
| 99 149 | 
             
                main()
         | 
| @@ -70,6 +70,10 @@ class AwsEnvironment(AbstractEnvironment): | |
| 70 70 | 
             
                        session = boto3.Session()
         | 
| 71 71 | 
             
                        region = region or session.region_name
         | 
| 72 72 | 
             
                        credentials = session.get_credentials()
         | 
| 73 | 
            +
                        if not credentials:
         | 
| 74 | 
            +
                            raise LaunchError(
         | 
| 75 | 
            +
                                "Could not create AWS environment from default environment. Please verify that your AWS credentials are configured correctly."
         | 
| 76 | 
            +
                            )
         | 
| 73 77 | 
             
                        access_key = credentials.access_key
         | 
| 74 78 | 
             
                        secret_key = credentials.secret_key
         | 
| 75 79 | 
             
                        session_token = credentials.token
         | 
    
        wandb/sdk/launch/launch.py
    CHANGED
    
    | @@ -55,7 +55,7 @@ def resolve_agent_config(  # noqa: C901 | |
| 55 55 | 
             
                    "api_key": api.api_key,
         | 
| 56 56 | 
             
                    "base_url": api.settings("base_url"),
         | 
| 57 57 | 
             
                    "registry": {},
         | 
| 58 | 
            -
                    " | 
| 58 | 
            +
                    "builder": {},
         | 
| 59 59 | 
             
                    "runner": {},
         | 
| 60 60 | 
             
                }
         | 
| 61 61 | 
             
                user_set_project = False
         | 
| @@ -151,7 +151,6 @@ def _run( | |
| 151 151 | 
             
                resource_args: Optional[Dict[str, Any]],
         | 
| 152 152 | 
             
                launch_config: Optional[Dict[str, Any]],
         | 
| 153 153 | 
             
                synchronous: Optional[bool],
         | 
| 154 | 
            -
                cuda: Optional[bool],
         | 
| 155 154 | 
             
                api: Api,
         | 
| 156 155 | 
             
                run_id: Optional[str],
         | 
| 157 156 | 
             
                repository: Optional[str],
         | 
| @@ -171,7 +170,6 @@ def _run( | |
| 171 170 | 
             
                    parameters,
         | 
| 172 171 | 
             
                    resource_args,
         | 
| 173 172 | 
             
                    launch_config,
         | 
| 174 | 
            -
                    cuda,
         | 
| 175 173 | 
             
                    run_id,
         | 
| 176 174 | 
             
                    repository,
         | 
| 177 175 | 
             
                )
         | 
| @@ -217,7 +215,6 @@ def run( | |
| 217 215 | 
             
                docker_image: Optional[str] = None,
         | 
| 218 216 | 
             
                config: Optional[Dict[str, Any]] = None,
         | 
| 219 217 | 
             
                synchronous: Optional[bool] = True,
         | 
| 220 | 
            -
                cuda: Optional[bool] = None,
         | 
| 221 218 | 
             
                run_id: Optional[str] = None,
         | 
| 222 219 | 
             
                repository: Optional[str] = None,
         | 
| 223 220 | 
             
            ) -> AbstractRun:
         | 
| @@ -247,7 +244,6 @@ def run( | |
| 247 244 | 
             
                    asynchronous runs launched via this method will be terminated. If
         | 
| 248 245 | 
             
                    ``synchronous`` is True and the run fails, the current process will
         | 
| 249 246 | 
             
                    error out as well.
         | 
| 250 | 
            -
                cuda: Whether to build a CUDA-enabled docker image or not
         | 
| 251 247 | 
             
                run_id: ID for the run (To ultimately replace the :name: field)
         | 
| 252 248 | 
             
                repository: string name of repository path for remote registry
         | 
| 253 249 |  | 
| @@ -290,7 +286,6 @@ def run( | |
| 290 286 | 
             
                    resource_args=resource_args,
         | 
| 291 287 | 
             
                    launch_config=config,
         | 
| 292 288 | 
             
                    synchronous=synchronous,
         | 
| 293 | 
            -
                    cuda=cuda,
         | 
| 294 289 | 
             
                    api=api,
         | 
| 295 290 | 
             
                    run_id=run_id,
         | 
| 296 291 | 
             
                    repository=repository,
         | 
    
        wandb/sdk/launch/launch_add.py
    CHANGED
    
    | @@ -44,7 +44,6 @@ def launch_add( | |
| 44 44 | 
             
                params: Optional[Dict[str, Any]] = None,
         | 
| 45 45 | 
             
                project_queue: Optional[str] = None,
         | 
| 46 46 | 
             
                resource_args: Optional[Dict[str, Any]] = None,
         | 
| 47 | 
            -
                cuda: Optional[bool] = None,
         | 
| 48 47 | 
             
                run_id: Optional[str] = None,
         | 
| 49 48 | 
             
                build: Optional[bool] = False,
         | 
| 50 49 | 
             
                repository: Optional[str] = None,
         | 
| @@ -69,7 +68,6 @@ def launch_add( | |
| 69 68 | 
             
                    the parameters used to run the original run.
         | 
| 70 69 | 
             
                resource_args: Resource related arguments for launching runs onto a remote backend.
         | 
| 71 70 | 
             
                    Will be stored on the constructed launch config under ``resource_args``.
         | 
| 72 | 
            -
                cuda: Whether to build a CUDA-enabled docker image or not
         | 
| 73 71 | 
             
                run_id: optional string indicating the id of the launched run
         | 
| 74 72 | 
             
                build: optional flag defaulting to false, requires queue to be set
         | 
| 75 73 | 
             
                    if build, an image is created, creates a job artifact, pushes a reference
         | 
| @@ -116,7 +114,6 @@ def launch_add( | |
| 116 114 | 
             
                    params,
         | 
| 117 115 | 
             
                    project_queue,
         | 
| 118 116 | 
             
                    resource_args,
         | 
| 119 | 
            -
                    cuda,
         | 
| 120 117 | 
             
                    run_id=run_id,
         | 
| 121 118 | 
             
                    build=build,
         | 
| 122 119 | 
             
                    repository=repository,
         | 
| @@ -139,7 +136,6 @@ def _launch_add( | |
| 139 136 | 
             
                params: Optional[Dict[str, Any]],
         | 
| 140 137 | 
             
                project_queue: Optional[str],
         | 
| 141 138 | 
             
                resource_args: Optional[Dict[str, Any]] = None,
         | 
| 142 | 
            -
                cuda: Optional[bool] = None,
         | 
| 143 139 | 
             
                run_id: Optional[str] = None,
         | 
| 144 140 | 
             
                build: Optional[bool] = False,
         | 
| 145 141 | 
             
                repository: Optional[str] = None,
         | 
| @@ -158,7 +154,6 @@ def _launch_add( | |
| 158 154 | 
             
                    params,
         | 
| 159 155 | 
             
                    resource_args,
         | 
| 160 156 | 
             
                    config,
         | 
| 161 | 
            -
                    cuda,
         | 
| 162 157 | 
             
                    run_id,
         | 
| 163 158 | 
             
                    repository,
         | 
| 164 159 | 
             
                )
         | 
| @@ -33,6 +33,18 @@ class AbstractRegistry(ABC): | |
| 33 33 | 
             
                    """
         | 
| 34 34 | 
             
                    raise NotImplementedError
         | 
| 35 35 |  | 
| 36 | 
            +
                @abstractmethod
         | 
| 37 | 
            +
                def check_image_exists(self, image_uri: str) -> bool:
         | 
| 38 | 
            +
                    """Check if an image exists in the registry.
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    Arguments:
         | 
| 41 | 
            +
                        image_uri (str): The URI of the image.
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    Returns:
         | 
| 44 | 
            +
                        bool: True if the image exists.
         | 
| 45 | 
            +
                    """
         | 
| 46 | 
            +
                    raise NotImplementedError
         | 
| 47 | 
            +
             | 
| 36 48 | 
             
                @classmethod
         | 
| 37 49 | 
             
                @abstractmethod
         | 
| 38 50 | 
             
                def from_config(
         | 
| @@ -87,7 +87,6 @@ class ElasticContainerRegistry(AbstractRegistry): | |
| 87 87 | 
             
                    try:
         | 
| 88 88 | 
             
                        session = self.environment.get_session()
         | 
| 89 89 | 
             
                        client = session.client("ecr")
         | 
| 90 | 
            -
                        response = client.describe_registry()
         | 
| 91 90 | 
             
                        response = client.describe_repositories(repositoryNames=[self.repo_name])
         | 
| 92 91 | 
             
                        self.uri = response["repositories"][0]["repositoryUri"].split("/")[0]
         | 
| 93 92 |  | 
| @@ -131,3 +130,34 @@ class ElasticContainerRegistry(AbstractRegistry): | |
| 131 130 | 
             
                        str: The uri of the repository.
         | 
| 132 131 | 
             
                    """
         | 
| 133 132 | 
             
                    return self.uri + "/" + self.repo_name
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def check_image_exists(self, image_uri: str) -> bool:
         | 
| 135 | 
            +
                    """Check if the image tag exists.
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    Arguments:
         | 
| 138 | 
            +
                        image_uri (str): The full image_uri.
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    Returns:
         | 
| 141 | 
            +
                        bool: True if the image tag exists.
         | 
| 142 | 
            +
                    """
         | 
| 143 | 
            +
                    uri, tag = image_uri.split(":")
         | 
| 144 | 
            +
                    if uri != self.get_repo_uri():
         | 
| 145 | 
            +
                        raise LaunchError(
         | 
| 146 | 
            +
                            f"Image uri {image_uri} does not match Elastic Container Registry uri {self.get_repo_uri()}."
         | 
| 147 | 
            +
                        )
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    _logger.debug("Checking if image tag exists.")
         | 
| 150 | 
            +
                    try:
         | 
| 151 | 
            +
                        session = self.environment.get_session()
         | 
| 152 | 
            +
                        client = session.client("ecr")
         | 
| 153 | 
            +
                        response = client.describe_images(
         | 
| 154 | 
            +
                            repositoryName=self.repo_name, imageIds=[{"imageTag": tag}]
         | 
| 155 | 
            +
                        )
         | 
| 156 | 
            +
                        return len(response["imageDetails"]) > 0
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                    except botocore.exceptions.ClientError as e:
         | 
| 159 | 
            +
                        code = e.response["Error"]["Code"]
         | 
| 160 | 
            +
                        if code == "ImageNotFoundException":
         | 
| 161 | 
            +
                            return False
         | 
| 162 | 
            +
                        msg = e.response["Error"]["Message"]
         | 
| 163 | 
            +
                        raise LaunchError(f"Error checking if image tag exists: {code} {msg}")
         | 
| @@ -169,3 +169,35 @@ class GoogleArtifactRegistry(AbstractRegistry): | |
| 169 169 | 
             
                        f"{self.environment.region}-docker.pkg.dev/"
         | 
| 170 170 | 
             
                        f"{self.environment.project}/{self.repository}/{self.image_name}"
         | 
| 171 171 | 
             
                    )
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                def check_image_exists(self, image_uri: str) -> bool:
         | 
| 174 | 
            +
                    """Check if the image exists.
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    Arguments:
         | 
| 177 | 
            +
                        image_uri: The image URI.
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    Returns:
         | 
| 180 | 
            +
                        True if the image exists, False otherwise.
         | 
| 181 | 
            +
                    """
         | 
| 182 | 
            +
                    _logger.info(
         | 
| 183 | 
            +
                        f"Checking if image {image_uri} exists. In Google Artifact Registry {self.uri}."
         | 
| 184 | 
            +
                    )
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    return False
         | 
| 187 | 
            +
                    # TODO: Test GCP Artifact Registry image exists to get working
         | 
| 188 | 
            +
                    # repo_uri, _ = image_uri.split(":")
         | 
| 189 | 
            +
                    # if repo_uri != self.get_repo_uri():
         | 
| 190 | 
            +
                    #     raise LaunchError(
         | 
| 191 | 
            +
                    #         f"The image {image_uri} does not belong to the Google Artifact "
         | 
| 192 | 
            +
                    #         f"Repository {self.get_repo_uri()}."
         | 
| 193 | 
            +
                    #     )
         | 
| 194 | 
            +
                    # credentials = self.environment.get_credentials()
         | 
| 195 | 
            +
                    # request = google.cloud.artifactregistry.GetTagRequest(parent=image_uri)
         | 
| 196 | 
            +
                    # client = google.cloud.artifactregistry.ArtifactRegistryClient(
         | 
| 197 | 
            +
                    #     credentials=credentials
         | 
| 198 | 
            +
                    # )
         | 
| 199 | 
            +
                    # try:
         | 
| 200 | 
            +
                    #     client.get_tag(request=request)
         | 
| 201 | 
            +
                    #     return True
         | 
| 202 | 
            +
                    # except google.api_core.exceptions.NotFound:
         | 
| 203 | 
            +
                    #     return False
         | 
| @@ -1,11 +1,14 @@ | |
| 1 1 | 
             
            """Local registry implementation."""
         | 
| 2 | 
            +
            import logging
         | 
| 2 3 | 
             
            from typing import Tuple
         | 
| 3 4 |  | 
| 4 | 
            -
            from wandb.sdk.launch.utils import LaunchError
         | 
| 5 | 
            +
            from wandb.sdk.launch.utils import LaunchError, docker_image_exists
         | 
| 5 6 |  | 
| 6 7 | 
             
            from ..environment.abstract import AbstractEnvironment
         | 
| 7 8 | 
             
            from .abstract import AbstractRegistry
         | 
| 8 9 |  | 
| 10 | 
            +
            _logger = logging.getLogger(__name__)
         | 
| 11 | 
            +
             | 
| 9 12 |  | 
| 10 13 | 
             
            class LocalRegistry(AbstractRegistry):
         | 
| 11 14 | 
             
                """A local registry.
         | 
| @@ -46,3 +49,14 @@ class LocalRegistry(AbstractRegistry): | |
| 46 49 | 
             
                    Returns: An empty string.
         | 
| 47 50 | 
             
                    """
         | 
| 48 51 | 
             
                    return ""
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def check_image_exists(self, image_uri: str) -> bool:
         | 
| 54 | 
            +
                    """Check if an image exists in the local registry.
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    Arguments:
         | 
| 57 | 
            +
                        image_uri (str): The uri of the image.
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    Returns:
         | 
| 60 | 
            +
                        bool: True.
         | 
| 61 | 
            +
                    """
         | 
| 62 | 
            +
                    return docker_image_exists(image_uri)
         | 
| @@ -10,7 +10,6 @@ from dockerpycreds.utils import find_executable  # type: ignore | |
| 10 10 | 
             
            import wandb
         | 
| 11 11 | 
             
            from wandb import Settings
         | 
| 12 12 | 
             
            from wandb.apis.internal import Api
         | 
| 13 | 
            -
            from wandb.errors import CommError
         | 
| 14 13 | 
             
            from wandb.sdk.launch.builder.abstract import AbstractBuilder
         | 
| 15 14 | 
             
            from wandb.sdk.lib import runid
         | 
| 16 15 |  | 
| @@ -143,19 +142,6 @@ class AbstractRunner(ABC): | |
| 143 142 | 
             
                        sys.exit(1)
         | 
| 144 143 | 
             
                    return True
         | 
| 145 144 |  | 
| 146 | 
            -
                def ack_run_queue_item(self, launch_project: LaunchProject) -> bool:
         | 
| 147 | 
            -
                    if self.backend_config.get("runQueueItemId"):
         | 
| 148 | 
            -
                        try:
         | 
| 149 | 
            -
                            self._api.ack_run_queue_item(
         | 
| 150 | 
            -
                                self.backend_config["runQueueItemId"], launch_project.run_id
         | 
| 151 | 
            -
                            )
         | 
| 152 | 
            -
                        except CommError:
         | 
| 153 | 
            -
                            wandb.termerror(
         | 
| 154 | 
            -
                                "Error acking run queue item. Item lease may have ended or another process may have acked it."
         | 
| 155 | 
            -
                            )
         | 
| 156 | 
            -
                            return False
         | 
| 157 | 
            -
                    return True
         | 
| 158 | 
            -
             | 
| 159 145 | 
             
                @abstractmethod
         | 
| 160 146 | 
             
                def run(
         | 
| 161 147 | 
             
                    self,
         |