wandb 0.13.10__py3-none-any.whl → 0.14.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -3
- wandb/apis/__init__.py +1 -3
- wandb/apis/importers/__init__.py +4 -0
- wandb/apis/importers/base.py +312 -0
- wandb/apis/importers/mlflow.py +113 -0
- wandb/apis/internal.py +29 -2
- wandb/apis/normalize.py +6 -5
- wandb/apis/public.py +163 -180
- wandb/apis/reports/_templates.py +6 -12
- wandb/apis/reports/report.py +1 -1
- wandb/apis/reports/runset.py +1 -3
- wandb/apis/reports/util.py +12 -10
- wandb/beta/workflows.py +57 -34
- wandb/catboost/__init__.py +1 -2
- wandb/cli/cli.py +215 -133
- wandb/data_types.py +63 -56
- wandb/docker/__init__.py +78 -16
- wandb/docker/auth.py +21 -22
- wandb/env.py +0 -1
- wandb/errors/__init__.py +8 -116
- wandb/errors/term.py +1 -1
- wandb/fastai/__init__.py +1 -2
- wandb/filesync/dir_watcher.py +8 -5
- wandb/filesync/step_prepare.py +76 -75
- wandb/filesync/step_upload.py +1 -2
- wandb/integration/catboost/__init__.py +1 -3
- wandb/integration/catboost/catboost.py +8 -14
- wandb/integration/fastai/__init__.py +7 -13
- wandb/integration/gym/__init__.py +35 -4
- wandb/integration/keras/__init__.py +3 -3
- wandb/integration/keras/callbacks/metrics_logger.py +9 -8
- wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
- wandb/integration/keras/callbacks/tables_builder.py +31 -19
- wandb/integration/kfp/kfp_patch.py +20 -17
- wandb/integration/kfp/wandb_logging.py +1 -2
- wandb/integration/lightgbm/__init__.py +21 -19
- wandb/integration/prodigy/prodigy.py +6 -7
- wandb/integration/sacred/__init__.py +9 -12
- wandb/integration/sagemaker/__init__.py +1 -3
- wandb/integration/sagemaker/auth.py +0 -1
- wandb/integration/sagemaker/config.py +1 -1
- wandb/integration/sagemaker/resources.py +1 -1
- wandb/integration/sb3/sb3.py +8 -4
- wandb/integration/tensorboard/__init__.py +1 -3
- wandb/integration/tensorboard/log.py +8 -8
- wandb/integration/tensorboard/monkeypatch.py +11 -9
- wandb/integration/tensorflow/__init__.py +1 -3
- wandb/integration/xgboost/__init__.py +4 -6
- wandb/integration/yolov8/__init__.py +7 -0
- wandb/integration/yolov8/yolov8.py +250 -0
- wandb/jupyter.py +31 -35
- wandb/lightgbm/__init__.py +1 -2
- wandb/old/settings.py +2 -2
- wandb/plot/bar.py +1 -2
- wandb/plot/confusion_matrix.py +1 -3
- wandb/plot/histogram.py +1 -2
- wandb/plot/line.py +1 -2
- wandb/plot/line_series.py +4 -4
- wandb/plot/pr_curve.py +17 -20
- wandb/plot/roc_curve.py +1 -3
- wandb/plot/scatter.py +1 -2
- wandb/proto/v3/wandb_server_pb2.py +85 -39
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_server_pb2.py +51 -39
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/__init__.py +1 -3
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/_dtypes.py +38 -30
- wandb/sdk/data_types/base_types/json_metadata.py +1 -3
- wandb/sdk/data_types/base_types/media.py +17 -17
- wandb/sdk/data_types/base_types/wb_value.py +33 -26
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
- wandb/sdk/data_types/helper_types/classes.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +12 -12
- wandb/sdk/data_types/histogram.py +5 -4
- wandb/sdk/data_types/html.py +1 -2
- wandb/sdk/data_types/image.py +11 -11
- wandb/sdk/data_types/molecule.py +3 -6
- wandb/sdk/data_types/object_3d.py +1 -2
- wandb/sdk/data_types/plotly.py +1 -2
- wandb/sdk/data_types/saved_model.py +10 -8
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/data_logging.py +5 -5
- wandb/sdk/interface/artifacts.py +288 -266
- wandb/sdk/interface/interface.py +2 -3
- wandb/sdk/interface/interface_grpc.py +1 -1
- wandb/sdk/interface/interface_queue.py +1 -1
- wandb/sdk/interface/interface_relay.py +1 -1
- wandb/sdk/interface/interface_shared.py +1 -2
- wandb/sdk/interface/interface_sock.py +1 -1
- wandb/sdk/interface/message_future.py +1 -1
- wandb/sdk/interface/message_future_poll.py +1 -1
- wandb/sdk/interface/router.py +1 -1
- wandb/sdk/interface/router_queue.py +1 -1
- wandb/sdk/interface/router_relay.py +1 -1
- wandb/sdk/interface/router_sock.py +1 -1
- wandb/sdk/interface/summary_record.py +1 -1
- wandb/sdk/internal/artifacts.py +1 -1
- wandb/sdk/internal/datastore.py +2 -3
- wandb/sdk/internal/file_pusher.py +5 -3
- wandb/sdk/internal/file_stream.py +22 -19
- wandb/sdk/internal/handler.py +5 -4
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +115 -55
- wandb/sdk/internal/job_builder.py +1 -3
- wandb/sdk/internal/profiler.py +1 -1
- wandb/sdk/internal/progress.py +4 -6
- wandb/sdk/internal/sample.py +1 -3
- wandb/sdk/internal/sender.py +28 -16
- wandb/sdk/internal/settings_static.py +5 -5
- wandb/sdk/internal/system/assets/__init__.py +1 -0
- wandb/sdk/internal/system/assets/cpu.py +3 -9
- wandb/sdk/internal/system/assets/disk.py +2 -4
- wandb/sdk/internal/system/assets/gpu.py +6 -18
- wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
- wandb/sdk/internal/system/assets/interfaces.py +50 -22
- wandb/sdk/internal/system/assets/ipu.py +1 -3
- wandb/sdk/internal/system/assets/memory.py +7 -13
- wandb/sdk/internal/system/assets/network.py +4 -8
- wandb/sdk/internal/system/assets/open_metrics.py +283 -0
- wandb/sdk/internal/system/assets/tpu.py +1 -4
- wandb/sdk/internal/system/assets/trainium.py +26 -14
- wandb/sdk/internal/system/system_info.py +2 -3
- wandb/sdk/internal/system/system_monitor.py +52 -20
- wandb/sdk/internal/tb_watcher.py +12 -13
- wandb/sdk/launch/_project_spec.py +54 -65
- wandb/sdk/launch/agent/agent.py +374 -90
- wandb/sdk/launch/builder/abstract.py +61 -7
- wandb/sdk/launch/builder/build.py +81 -110
- wandb/sdk/launch/builder/docker_builder.py +181 -0
- wandb/sdk/launch/builder/kaniko_builder.py +419 -0
- wandb/sdk/launch/builder/noop.py +31 -12
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
- wandb/sdk/launch/environment/abstract.py +28 -0
- wandb/sdk/launch/environment/aws_environment.py +276 -0
- wandb/sdk/launch/environment/gcp_environment.py +271 -0
- wandb/sdk/launch/environment/local_environment.py +65 -0
- wandb/sdk/launch/github_reference.py +3 -8
- wandb/sdk/launch/launch.py +38 -29
- wandb/sdk/launch/launch_add.py +6 -8
- wandb/sdk/launch/loader.py +230 -0
- wandb/sdk/launch/registry/abstract.py +54 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
- wandb/sdk/launch/registry/local_registry.py +62 -0
- wandb/sdk/launch/runner/abstract.py +1 -16
- wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
- wandb/sdk/launch/runner/local_container.py +46 -22
- wandb/sdk/launch/runner/local_process.py +1 -4
- wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
- wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
- wandb/sdk/launch/sweeps/__init__.py +3 -2
- wandb/sdk/launch/sweeps/scheduler.py +132 -39
- wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
- wandb/sdk/launch/utils.py +101 -30
- wandb/sdk/launch/wandb_reference.py +2 -7
- wandb/sdk/lib/_settings_toposort_generate.py +166 -0
- wandb/sdk/lib/_settings_toposort_generated.py +201 -0
- wandb/sdk/lib/apikey.py +2 -4
- wandb/sdk/lib/config_util.py +4 -1
- wandb/sdk/lib/console.py +1 -3
- wandb/sdk/lib/deprecate.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +7 -5
- wandb/sdk/lib/filenames.py +1 -1
- wandb/sdk/lib/filesystem.py +61 -5
- wandb/sdk/lib/git.py +1 -3
- wandb/sdk/lib/import_hooks.py +4 -7
- wandb/sdk/lib/ipython.py +8 -5
- wandb/sdk/lib/lazyloader.py +1 -3
- wandb/sdk/lib/mailbox.py +14 -4
- wandb/sdk/lib/proto_util.py +10 -5
- wandb/sdk/lib/redirect.py +15 -22
- wandb/sdk/lib/reporting.py +1 -3
- wandb/sdk/lib/retry.py +4 -5
- wandb/sdk/lib/runid.py +1 -3
- wandb/sdk/lib/server.py +15 -9
- wandb/sdk/lib/sock_client.py +1 -1
- wandb/sdk/lib/sparkline.py +1 -1
- wandb/sdk/lib/wburls.py +1 -1
- wandb/sdk/service/port_file.py +1 -2
- wandb/sdk/service/service.py +36 -13
- wandb/sdk/service/service_base.py +12 -1
- wandb/sdk/verify/verify.py +5 -7
- wandb/sdk/wandb_artifacts.py +142 -177
- wandb/sdk/wandb_config.py +5 -8
- wandb/sdk/wandb_helper.py +1 -1
- wandb/sdk/wandb_init.py +24 -13
- wandb/sdk/wandb_login.py +9 -9
- wandb/sdk/wandb_manager.py +39 -4
- wandb/sdk/wandb_metric.py +2 -6
- wandb/sdk/wandb_require.py +4 -15
- wandb/sdk/wandb_require_helpers.py +1 -9
- wandb/sdk/wandb_run.py +95 -141
- wandb/sdk/wandb_save.py +1 -3
- wandb/sdk/wandb_settings.py +149 -54
- wandb/sdk/wandb_setup.py +66 -46
- wandb/sdk/wandb_summary.py +13 -10
- wandb/sdk/wandb_sweep.py +6 -7
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/calculate/confusion_matrix.py +1 -1
- wandb/sklearn/calculate/learning_curve.py +1 -1
- wandb/sklearn/calculate/summary_metrics.py +1 -3
- wandb/sklearn/plot/__init__.py +1 -1
- wandb/sklearn/plot/classifier.py +27 -18
- wandb/sklearn/plot/clusterer.py +4 -5
- wandb/sklearn/plot/regressor.py +4 -4
- wandb/sklearn/plot/shared.py +2 -2
- wandb/sync/__init__.py +1 -3
- wandb/sync/sync.py +4 -5
- wandb/testing/relay.py +11 -10
- wandb/trigger.py +1 -1
- wandb/util.py +106 -81
- wandb/viz.py +4 -4
- wandb/wandb_agent.py +50 -50
- wandb/wandb_controller.py +2 -3
- wandb/wandb_run.py +1 -2
- wandb/wandb_torch.py +1 -1
- wandb/xgboost/__init__.py +1 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
- wandb/sdk/launch/builder/docker.py +0 -80
- wandb/sdk/launch/builder/kaniko.py +0 -393
- wandb/sdk/launch/builder/loader.py +0 -32
- wandb/sdk/launch/runner/loader.py +0 -50
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
@@ -8,14 +8,17 @@ from typing import Any, Dict, List, Optional
|
|
8
8
|
|
9
9
|
import wandb
|
10
10
|
from wandb.sdk.launch.builder.abstract import AbstractBuilder
|
11
|
+
from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
11
12
|
|
12
|
-
from .._project_spec import LaunchProject,
|
13
|
-
from ..builder.build import
|
13
|
+
from .._project_spec import LaunchProject, compute_command_args
|
14
|
+
from ..builder.build import get_env_vars_dict
|
14
15
|
from ..utils import (
|
15
16
|
LOG_PREFIX,
|
16
17
|
PROJECT_SYNCHRONOUS,
|
17
18
|
_is_wandb_dev_uri,
|
18
19
|
_is_wandb_local_uri,
|
20
|
+
docker_image_exists,
|
21
|
+
pull_docker_image,
|
19
22
|
sanitize_wandb_api_key,
|
20
23
|
)
|
21
24
|
from .abstract import AbstractRun, AbstractRunner, Status
|
@@ -66,20 +69,24 @@ class LocalSubmittedRun(AbstractRun):
|
|
66
69
|
class LocalContainerRunner(AbstractRunner):
|
67
70
|
"""Runner class, uses a project to create a LocallySubmittedRun."""
|
68
71
|
|
72
|
+
def __init__(
|
73
|
+
self,
|
74
|
+
api: wandb.apis.internal.Api,
|
75
|
+
backend_config: Dict[str, Any],
|
76
|
+
environment: AbstractEnvironment,
|
77
|
+
) -> None:
|
78
|
+
super().__init__(api, backend_config)
|
79
|
+
self.environment = environment
|
80
|
+
|
69
81
|
def run(
|
70
82
|
self,
|
71
83
|
launch_project: LaunchProject,
|
72
|
-
builder: AbstractBuilder,
|
73
|
-
registry_config: Dict[str, Any],
|
84
|
+
builder: Optional[AbstractBuilder],
|
74
85
|
) -> Optional[AbstractRun]:
|
75
86
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
76
87
|
docker_args: Dict[str, Any] = launch_project.resource_args.get(
|
77
88
|
"local-container", {}
|
78
89
|
)
|
79
|
-
# TODO: leaving this here because of existing CLI command
|
80
|
-
# we should likely just tell users to specify the gpus arg directly
|
81
|
-
if launch_project.cuda:
|
82
|
-
docker_args["gpus"] = "all"
|
83
90
|
|
84
91
|
if _is_wandb_local_uri(self._api.settings("base_url")):
|
85
92
|
if sys.platform == "win32":
|
@@ -107,28 +114,39 @@ class LocalContainerRunner(AbstractRunner):
|
|
107
114
|
image_uri = launch_project.image_name
|
108
115
|
if not docker_image_exists(image_uri):
|
109
116
|
pull_docker_image(image_uri)
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
)
|
117
|
+
entry_cmd = []
|
118
|
+
if entry_point is not None:
|
119
|
+
entry_cmd = entry_point.command
|
120
|
+
override_args = compute_command_args(launch_project.override_args)
|
115
121
|
command_str = " ".join(
|
116
|
-
get_docker_command(
|
122
|
+
get_docker_command(
|
123
|
+
image_uri,
|
124
|
+
env_vars,
|
125
|
+
entry_cmd=entry_cmd,
|
126
|
+
docker_args=docker_args,
|
127
|
+
additional_args=override_args,
|
128
|
+
)
|
117
129
|
).strip()
|
118
130
|
else:
|
119
131
|
assert entry_point is not None
|
120
|
-
|
132
|
+
_logger.info("Building docker image...")
|
133
|
+
assert builder is not None
|
121
134
|
image_uri = builder.build_image(
|
122
135
|
launch_project,
|
123
|
-
repository,
|
124
136
|
entry_point,
|
125
137
|
)
|
138
|
+
_logger.info(f"Docker image built with uri {image_uri}")
|
139
|
+
# entry_cmd and additional_args are empty here because
|
140
|
+
# if launch built the container they've been accounted
|
141
|
+
# in the dockerfile and env vars respectively
|
126
142
|
command_str = " ".join(
|
127
|
-
get_docker_command(
|
143
|
+
get_docker_command(
|
144
|
+
image_uri,
|
145
|
+
env_vars,
|
146
|
+
docker_args=docker_args,
|
147
|
+
)
|
128
148
|
).strip()
|
129
149
|
|
130
|
-
if not self.ack_run_queue_item(launch_project):
|
131
|
-
return None
|
132
150
|
sanitized_cmd_str = sanitize_wandb_api_key(command_str)
|
133
151
|
_msg = f"{LOG_PREFIX}Launching run in docker with command: {sanitized_cmd_str}"
|
134
152
|
wandb.termlog(_msg)
|
@@ -170,10 +188,11 @@ def _run_entry_point(command: str, work_dir: Optional[str]) -> AbstractRun:
|
|
170
188
|
def get_docker_command(
|
171
189
|
image: str,
|
172
190
|
env_vars: Dict[str, str],
|
173
|
-
entry_cmd: List[str],
|
191
|
+
entry_cmd: Optional[List[str]] = None,
|
174
192
|
docker_args: Optional[Dict[str, Any]] = None,
|
193
|
+
additional_args: Optional[List[str]] = None,
|
175
194
|
) -> List[str]:
|
176
|
-
"""
|
195
|
+
"""Construct the docker command using the image and docker args.
|
177
196
|
|
178
197
|
Arguments:
|
179
198
|
image: a Docker image to be run
|
@@ -202,8 +221,13 @@ def get_docker_command(
|
|
202
221
|
else:
|
203
222
|
cmd += [prefix, shlex.quote(str(value))]
|
204
223
|
|
224
|
+
if entry_cmd:
|
225
|
+
cmd += ["--entrypoint", entry_cmd[0]]
|
205
226
|
cmd += [shlex.quote(image)]
|
206
|
-
|
227
|
+
if entry_cmd and len(entry_cmd) > 1:
|
228
|
+
cmd += entry_cmd[1:]
|
229
|
+
if additional_args:
|
230
|
+
cmd += additional_args
|
207
231
|
return cmd
|
208
232
|
|
209
233
|
|
@@ -3,13 +3,13 @@ import shlex
|
|
3
3
|
from typing import Any, List, Optional
|
4
4
|
|
5
5
|
import wandb
|
6
|
-
from wandb.errors import LaunchError
|
7
6
|
|
8
7
|
from .._project_spec import LaunchProject, get_entry_point_command
|
9
8
|
from ..builder.build import get_env_vars_dict
|
10
9
|
from ..utils import (
|
11
10
|
LOG_PREFIX,
|
12
11
|
PROJECT_SYNCHRONOUS,
|
12
|
+
LaunchError,
|
13
13
|
_is_wandb_uri,
|
14
14
|
download_wandb_python_deps,
|
15
15
|
parse_wandb_uri,
|
@@ -81,9 +81,6 @@ class LocalProcessRunner(AbstractRunner):
|
|
81
81
|
for env_key, env_value in env_vars.items():
|
82
82
|
cmd += [f"{shlex.quote(env_key)}={shlex.quote(env_value)}"]
|
83
83
|
|
84
|
-
if not self.ack_run_queue_item(launch_project):
|
85
|
-
return None
|
86
|
-
|
87
84
|
entry_cmd = get_entry_point_command(entry_point, launch_project.override_args)
|
88
85
|
cmd += entry_cmd
|
89
86
|
|
@@ -1,23 +1,20 @@
|
|
1
|
-
|
1
|
+
"""Implementation of the SageMakerRunner class."""
|
2
2
|
import logging
|
3
|
-
import os
|
4
|
-
import subprocess
|
5
3
|
import time
|
6
|
-
from typing import Any, Dict, Optional,
|
4
|
+
from typing import Any, Dict, Optional, cast
|
7
5
|
|
8
6
|
if False:
|
9
7
|
import boto3 # type: ignore
|
10
8
|
|
11
9
|
import wandb
|
12
|
-
import wandb.docker as docker
|
13
10
|
from wandb.apis.internal import Api
|
14
|
-
from wandb.errors import LaunchError
|
15
11
|
from wandb.sdk.launch.builder.abstract import AbstractBuilder
|
16
|
-
from wandb.
|
12
|
+
from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
|
13
|
+
from wandb.sdk.launch.utils import LaunchError
|
17
14
|
|
18
15
|
from .._project_spec import LaunchProject, get_entry_point_command
|
19
16
|
from ..builder.build import get_env_vars_dict
|
20
|
-
from ..utils import LOG_PREFIX, PROJECT_SYNCHRONOUS,
|
17
|
+
from ..utils import LOG_PREFIX, PROJECT_SYNCHRONOUS, to_camel_case
|
21
18
|
from .abstract import AbstractRun, AbstractRunner, Status
|
22
19
|
|
23
20
|
_logger = logging.getLogger(__name__)
|
@@ -69,32 +66,49 @@ class SagemakerSubmittedRun(AbstractRun):
|
|
69
66
|
return self._status
|
70
67
|
|
71
68
|
|
72
|
-
class
|
69
|
+
class SageMakerRunner(AbstractRunner):
|
73
70
|
"""Runner class, uses a project to create a SagemakerSubmittedRun."""
|
74
71
|
|
72
|
+
def __init__(
|
73
|
+
self, api: Api, backend_config: Dict[str, Any], environment: AwsEnvironment
|
74
|
+
) -> None:
|
75
|
+
"""Initialize the SagemakerRunner.
|
76
|
+
|
77
|
+
Arguments:
|
78
|
+
api (Api): The API instance.
|
79
|
+
backend_config (Dict[str, Any]): The backend configuration.
|
80
|
+
environment (AwsEnvironment): The AWS environment.
|
81
|
+
|
82
|
+
Raises:
|
83
|
+
LaunchError: If the runner cannot be initialized.
|
84
|
+
"""
|
85
|
+
super().__init__(api, backend_config)
|
86
|
+
self.environment = environment
|
87
|
+
|
75
88
|
def run(
|
76
89
|
self,
|
77
90
|
launch_project: LaunchProject,
|
78
|
-
builder: AbstractBuilder,
|
79
|
-
registry_config: Dict[str, Any],
|
91
|
+
builder: Optional[AbstractBuilder],
|
80
92
|
) -> Optional[AbstractRun]:
|
81
|
-
|
93
|
+
"""Run a project on Amazon Sagemaker.
|
82
94
|
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
95
|
+
Arguments:
|
96
|
+
launch_project (LaunchProject): The project to run.
|
97
|
+
builder (AbstractBuilder): The builder to use.
|
98
|
+
|
99
|
+
Returns:
|
100
|
+
Optional[AbstractRun]: The run instance.
|
101
|
+
|
102
|
+
Raises:
|
103
|
+
LaunchError: If the launch is unsuccessful.
|
104
|
+
"""
|
105
|
+
_logger.info("using AWSSagemakerRunner")
|
91
106
|
|
92
107
|
given_sagemaker_args = launch_project.resource_args.get("sagemaker")
|
93
108
|
if given_sagemaker_args is None:
|
94
109
|
raise LaunchError(
|
95
110
|
"No sagemaker args specified. Specify sagemaker args in resource_args"
|
96
111
|
)
|
97
|
-
validate_sagemaker_requirements(given_sagemaker_args, registry_config)
|
98
112
|
|
99
113
|
default_output_path = self.backend_config.get("runner", {}).get(
|
100
114
|
"s3_output_path"
|
@@ -104,37 +118,22 @@ class AWSSagemakerRunner(AbstractRunner):
|
|
104
118
|
):
|
105
119
|
default_output_path = f"s3://{default_output_path}"
|
106
120
|
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
client = boto3.client("sts")
|
111
|
-
instance_role = True
|
112
|
-
caller_id = client.get_caller_identity()
|
113
|
-
|
114
|
-
except botocore.exceptions.NoCredentialsError:
|
115
|
-
access_key, secret_key = get_aws_credentials(given_sagemaker_args)
|
116
|
-
client = boto3.client(
|
117
|
-
"sts", aws_access_key_id=access_key, aws_secret_access_key=secret_key
|
118
|
-
)
|
119
|
-
caller_id = client.get_caller_identity()
|
120
|
-
|
121
|
+
session = self.environment.get_session()
|
122
|
+
client = session.client("sts")
|
123
|
+
caller_id = client.get_caller_identity()
|
121
124
|
account_id = caller_id["Account"]
|
125
|
+
_logger.info(f"Using account ID {account_id}")
|
122
126
|
role_arn = get_role_arn(given_sagemaker_args, self.backend_config, account_id)
|
123
127
|
entry_point = launch_project.get_single_entry_point()
|
128
|
+
|
129
|
+
# Create a sagemaker client to launch the job.
|
130
|
+
sagemaker_client = session.client("sagemaker")
|
131
|
+
|
124
132
|
# if the user provided the image they want to use, use that, but warn it won't have swappable artifacts
|
125
133
|
if (
|
126
134
|
given_sagemaker_args.get("AlgorithmSpecification", {}).get("TrainingImage")
|
127
135
|
is not None
|
128
136
|
):
|
129
|
-
if instance_role:
|
130
|
-
sagemaker_client = boto3.client("sagemaker", region_name=region)
|
131
|
-
else:
|
132
|
-
sagemaker_client = boto3.client(
|
133
|
-
"sagemaker",
|
134
|
-
region_name=region,
|
135
|
-
aws_access_key_id=access_key,
|
136
|
-
aws_secret_access_key=secret_key,
|
137
|
-
)
|
138
137
|
sagemaker_args = build_sagemaker_args(
|
139
138
|
launch_project,
|
140
139
|
self._api,
|
@@ -152,57 +151,20 @@ class AWSSagemakerRunner(AbstractRunner):
|
|
152
151
|
run.wait()
|
153
152
|
return run
|
154
153
|
|
155
|
-
_logger.info("Connecting to AWS ECR Client")
|
156
|
-
if instance_role:
|
157
|
-
ecr_client = boto3.client("ecr", region_name=region)
|
158
|
-
else:
|
159
|
-
ecr_client = boto3.client(
|
160
|
-
"ecr",
|
161
|
-
region_name=region,
|
162
|
-
aws_access_key_id=access_key,
|
163
|
-
aws_secret_access_key=secret_key,
|
164
|
-
)
|
165
|
-
repository = get_ecr_repository_url(
|
166
|
-
ecr_client, given_sagemaker_args, registry_config
|
167
|
-
)
|
168
|
-
# TODO: handle login credentials gracefully
|
169
|
-
login_credentials = registry_config.get("credentials")
|
170
|
-
if login_credentials is not None:
|
171
|
-
wandb.termwarn(
|
172
|
-
"Ignoring registry credentials for ECR, using those found on the system"
|
173
|
-
)
|
174
|
-
|
175
|
-
if builder.type != "kaniko":
|
176
|
-
_logger.info("Logging in to AWS ECR")
|
177
|
-
login_resp = aws_ecr_login(region, repository)
|
178
|
-
if login_resp is None or "Login Succeeded" not in login_resp:
|
179
|
-
raise LaunchError(f"Unable to login to ECR, response: {login_resp}")
|
180
|
-
|
181
154
|
if launch_project.docker_image:
|
182
155
|
image = launch_project.docker_image
|
183
156
|
else:
|
184
157
|
assert entry_point is not None
|
158
|
+
assert builder is not None
|
185
159
|
# build our own image
|
160
|
+
_logger.info("Building docker image...")
|
186
161
|
image = builder.build_image(
|
187
162
|
launch_project,
|
188
|
-
repository,
|
189
163
|
entry_point,
|
190
164
|
)
|
191
|
-
|
192
|
-
if not self.ack_run_queue_item(launch_project):
|
193
|
-
return None
|
165
|
+
_logger.info(f"Docker image built with uri {image}")
|
194
166
|
|
195
167
|
_logger.info("Connecting to sagemaker client")
|
196
|
-
if instance_role:
|
197
|
-
sagemaker_client = boto3.client("sagemaker", region_name=region)
|
198
|
-
else:
|
199
|
-
sagemaker_client = boto3.client(
|
200
|
-
"sagemaker",
|
201
|
-
region_name=region,
|
202
|
-
aws_access_key_id=access_key,
|
203
|
-
aws_secret_access_key=secret_key,
|
204
|
-
)
|
205
|
-
|
206
168
|
command_args = get_entry_point_command(
|
207
169
|
entry_point, launch_project.override_args
|
208
170
|
)
|
@@ -225,29 +187,15 @@ class AWSSagemakerRunner(AbstractRunner):
|
|
225
187
|
return run
|
226
188
|
|
227
189
|
|
228
|
-
def aws_ecr_login(region: str, registry: str) -> Optional[str]:
|
229
|
-
pw_command = ["aws", "ecr", "get-login-password", "--region", region]
|
230
|
-
try:
|
231
|
-
pw = run_shell(pw_command)[0]
|
232
|
-
except subprocess.CalledProcessError:
|
233
|
-
raise LaunchError(
|
234
|
-
"Unable to get login password. Please ensure you have AWS credentials configured"
|
235
|
-
)
|
236
|
-
try:
|
237
|
-
docker_login_process = docker.login("AWS", pw, registry)
|
238
|
-
except Exception:
|
239
|
-
raise LaunchError(f"Failed to login to ECR {registry}")
|
240
|
-
return docker_login_process
|
241
|
-
|
242
|
-
|
243
190
|
def merge_aws_tag_with_algorithm_specification(
|
244
191
|
algorithm_specification: Optional[Dict[str, Any]], aws_tag: Optional[str]
|
245
192
|
) -> Dict[str, Any]:
|
246
|
-
"""
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
193
|
+
"""Create an AWS AlgorithmSpecification.
|
194
|
+
|
195
|
+
AWS Sagemaker algorithms require a training image and an input mode. If the user
|
196
|
+
does not specify the specification themselves, define the spec minimally using these
|
197
|
+
two fields. Otherwise, if they specify the AlgorithmSpecification set the training
|
198
|
+
image if it is not set.
|
251
199
|
"""
|
252
200
|
if algorithm_specification is None:
|
253
201
|
return {
|
@@ -366,65 +314,10 @@ def launch_sagemaker_job(
|
|
366
314
|
return run
|
367
315
|
|
368
316
|
|
369
|
-
def get_region(
|
370
|
-
sagemaker_args: Dict[str, Any], registry_config_region: Optional[str] = None
|
371
|
-
) -> str:
|
372
|
-
region = sagemaker_args.get("region")
|
373
|
-
if region is None:
|
374
|
-
region = registry_config_region
|
375
|
-
if region is None:
|
376
|
-
region = os.environ.get("AWS_DEFAULT_REGION")
|
377
|
-
if region is None and os.path.exists(os.path.expanduser("~/.aws/config")):
|
378
|
-
config = configparser.ConfigParser()
|
379
|
-
config.read(os.path.expanduser("~/.aws/config"))
|
380
|
-
section = sagemaker_args.get("profile") or "default"
|
381
|
-
try:
|
382
|
-
region = config.get(section, "region")
|
383
|
-
except (configparser.NoOptionError, configparser.NoSectionError):
|
384
|
-
raise LaunchError(
|
385
|
-
"Unable to detemine default region from ~/.aws/config. "
|
386
|
-
"Please specify region in resource args or specify config "
|
387
|
-
"section as 'profile'"
|
388
|
-
)
|
389
|
-
|
390
|
-
if region is None:
|
391
|
-
raise LaunchError(
|
392
|
-
"AWS region not specified and ~/.aws/config not found. Configure AWS"
|
393
|
-
)
|
394
|
-
assert isinstance(region, str)
|
395
|
-
return region
|
396
|
-
|
397
|
-
|
398
|
-
def get_aws_credentials(sagemaker_args: Dict[str, Any]) -> Tuple[str, str]:
|
399
|
-
access_key = os.environ.get("AWS_ACCESS_KEY_ID")
|
400
|
-
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
|
401
|
-
if (
|
402
|
-
access_key is None
|
403
|
-
or secret_key is None
|
404
|
-
and os.path.exists(os.path.expanduser("~/.aws/credentials"))
|
405
|
-
):
|
406
|
-
profile = sagemaker_args.get("profile") or "default"
|
407
|
-
config = configparser.ConfigParser()
|
408
|
-
config.read(os.path.expanduser("~/.aws/credentials"))
|
409
|
-
try:
|
410
|
-
access_key = config.get(profile, "aws_access_key_id")
|
411
|
-
secret_key = config.get(profile, "aws_secret_access_key")
|
412
|
-
except (configparser.NoOptionError, configparser.NoSectionError):
|
413
|
-
raise LaunchError(
|
414
|
-
"Unable to get aws credentials from ~/.aws/credentials. "
|
415
|
-
"Please set aws credentials in environments variables, or "
|
416
|
-
"check your credentials in ~/.aws/credentials. Use resource "
|
417
|
-
"args to specify the profile using 'profile'"
|
418
|
-
)
|
419
|
-
|
420
|
-
if access_key is None or secret_key is None:
|
421
|
-
raise LaunchError("AWS credentials not found")
|
422
|
-
return access_key, secret_key
|
423
|
-
|
424
|
-
|
425
317
|
def get_role_arn(
|
426
318
|
sagemaker_args: Dict[str, Any], backend_config: Dict[str, Any], account_id: str
|
427
319
|
) -> str:
|
320
|
+
"""Get the role arn from the sagemaker args or the backend config."""
|
428
321
|
role_arn = sagemaker_args.get("RoleArn") or sagemaker_args.get("role_arn")
|
429
322
|
if role_arn is None:
|
430
323
|
role_arn = backend_config.get("runner", {}).get("role_arn")
|
@@ -437,55 +330,3 @@ def get_role_arn(
|
|
437
330
|
return role_arn
|
438
331
|
|
439
332
|
return f"arn:aws:iam::{account_id}:role/{role_arn}"
|
440
|
-
|
441
|
-
|
442
|
-
def validate_sagemaker_requirements(
|
443
|
-
given_sagemaker_args: Dict[str, Any], registry_config: Dict[str, Any]
|
444
|
-
) -> None:
|
445
|
-
if (
|
446
|
-
given_sagemaker_args.get(
|
447
|
-
"EcrRepoName", given_sagemaker_args.get("ecr_repo_name")
|
448
|
-
)
|
449
|
-
is None
|
450
|
-
and registry_config.get("url") is None
|
451
|
-
):
|
452
|
-
raise LaunchError(
|
453
|
-
"AWS sagemaker requires an ECR Repository to push the container to "
|
454
|
-
"set this by adding a `EcrRepoName` key to the sagemaker"
|
455
|
-
"field of resource_args or through the url key in the registry section "
|
456
|
-
"of the launch agent config."
|
457
|
-
)
|
458
|
-
|
459
|
-
if registry_config.get("ecr-repo-provider", "aws").lower() != "aws":
|
460
|
-
raise LaunchError(
|
461
|
-
"Sagemaker jobs requires an AWS ECR Repo to push the container to"
|
462
|
-
)
|
463
|
-
|
464
|
-
|
465
|
-
def get_ecr_repository_url(
|
466
|
-
ecr_client: "boto3.Client",
|
467
|
-
given_sagemaker_args: Dict[str, Any],
|
468
|
-
registry_config: Dict[str, Any],
|
469
|
-
) -> str:
|
470
|
-
token = ecr_client.get_authorization_token()
|
471
|
-
ecr_repo_name = given_sagemaker_args.get(
|
472
|
-
"EcrRepoName", given_sagemaker_args.get("ecr_repo_name")
|
473
|
-
)
|
474
|
-
if ecr_repo_name:
|
475
|
-
if not isinstance(ecr_repo_name, str):
|
476
|
-
raise LaunchError("EcrRepoName must be a string")
|
477
|
-
if not ecr_repo_name.startswith("arn:aws:ecr:"):
|
478
|
-
repository = cast(
|
479
|
-
str,
|
480
|
-
token["authorizationData"][0]["proxyEndpoint"].replace("https://", "")
|
481
|
-
+ f"/{ecr_repo_name}",
|
482
|
-
)
|
483
|
-
else:
|
484
|
-
repository = ecr_repo_name
|
485
|
-
else:
|
486
|
-
repository = cast(str, registry_config.get("url", ""))
|
487
|
-
if not repository:
|
488
|
-
raise LaunchError(
|
489
|
-
"Must provide a repository url either through resource args or launch config file"
|
490
|
-
)
|
491
|
-
return repository
|