wandb 0.13.10__py3-none-any.whl → 0.14.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (228) hide show
  1. wandb/__init__.py +2 -3
  2. wandb/apis/__init__.py +1 -3
  3. wandb/apis/importers/__init__.py +4 -0
  4. wandb/apis/importers/base.py +312 -0
  5. wandb/apis/importers/mlflow.py +113 -0
  6. wandb/apis/internal.py +29 -2
  7. wandb/apis/normalize.py +6 -5
  8. wandb/apis/public.py +163 -180
  9. wandb/apis/reports/_templates.py +6 -12
  10. wandb/apis/reports/report.py +1 -1
  11. wandb/apis/reports/runset.py +1 -3
  12. wandb/apis/reports/util.py +12 -10
  13. wandb/beta/workflows.py +57 -34
  14. wandb/catboost/__init__.py +1 -2
  15. wandb/cli/cli.py +215 -133
  16. wandb/data_types.py +63 -56
  17. wandb/docker/__init__.py +78 -16
  18. wandb/docker/auth.py +21 -22
  19. wandb/env.py +0 -1
  20. wandb/errors/__init__.py +8 -116
  21. wandb/errors/term.py +1 -1
  22. wandb/fastai/__init__.py +1 -2
  23. wandb/filesync/dir_watcher.py +8 -5
  24. wandb/filesync/step_prepare.py +76 -75
  25. wandb/filesync/step_upload.py +1 -2
  26. wandb/integration/catboost/__init__.py +1 -3
  27. wandb/integration/catboost/catboost.py +8 -14
  28. wandb/integration/fastai/__init__.py +7 -13
  29. wandb/integration/gym/__init__.py +35 -4
  30. wandb/integration/keras/__init__.py +3 -3
  31. wandb/integration/keras/callbacks/metrics_logger.py +9 -8
  32. wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
  33. wandb/integration/keras/callbacks/tables_builder.py +31 -19
  34. wandb/integration/kfp/kfp_patch.py +20 -17
  35. wandb/integration/kfp/wandb_logging.py +1 -2
  36. wandb/integration/lightgbm/__init__.py +21 -19
  37. wandb/integration/prodigy/prodigy.py +6 -7
  38. wandb/integration/sacred/__init__.py +9 -12
  39. wandb/integration/sagemaker/__init__.py +1 -3
  40. wandb/integration/sagemaker/auth.py +0 -1
  41. wandb/integration/sagemaker/config.py +1 -1
  42. wandb/integration/sagemaker/resources.py +1 -1
  43. wandb/integration/sb3/sb3.py +8 -4
  44. wandb/integration/tensorboard/__init__.py +1 -3
  45. wandb/integration/tensorboard/log.py +8 -8
  46. wandb/integration/tensorboard/monkeypatch.py +11 -9
  47. wandb/integration/tensorflow/__init__.py +1 -3
  48. wandb/integration/xgboost/__init__.py +4 -6
  49. wandb/integration/yolov8/__init__.py +7 -0
  50. wandb/integration/yolov8/yolov8.py +250 -0
  51. wandb/jupyter.py +31 -35
  52. wandb/lightgbm/__init__.py +1 -2
  53. wandb/old/settings.py +2 -2
  54. wandb/plot/bar.py +1 -2
  55. wandb/plot/confusion_matrix.py +1 -3
  56. wandb/plot/histogram.py +1 -2
  57. wandb/plot/line.py +1 -2
  58. wandb/plot/line_series.py +4 -4
  59. wandb/plot/pr_curve.py +17 -20
  60. wandb/plot/roc_curve.py +1 -3
  61. wandb/plot/scatter.py +1 -2
  62. wandb/proto/v3/wandb_server_pb2.py +85 -39
  63. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  64. wandb/proto/v4/wandb_server_pb2.py +51 -39
  65. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  66. wandb/sdk/__init__.py +1 -3
  67. wandb/sdk/backend/backend.py +1 -1
  68. wandb/sdk/data_types/_dtypes.py +38 -30
  69. wandb/sdk/data_types/base_types/json_metadata.py +1 -3
  70. wandb/sdk/data_types/base_types/media.py +17 -17
  71. wandb/sdk/data_types/base_types/wb_value.py +33 -26
  72. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
  73. wandb/sdk/data_types/helper_types/classes.py +1 -1
  74. wandb/sdk/data_types/helper_types/image_mask.py +12 -12
  75. wandb/sdk/data_types/histogram.py +5 -4
  76. wandb/sdk/data_types/html.py +1 -2
  77. wandb/sdk/data_types/image.py +11 -11
  78. wandb/sdk/data_types/molecule.py +3 -6
  79. wandb/sdk/data_types/object_3d.py +1 -2
  80. wandb/sdk/data_types/plotly.py +1 -2
  81. wandb/sdk/data_types/saved_model.py +10 -8
  82. wandb/sdk/data_types/video.py +1 -1
  83. wandb/sdk/integration_utils/data_logging.py +5 -5
  84. wandb/sdk/interface/artifacts.py +288 -266
  85. wandb/sdk/interface/interface.py +2 -3
  86. wandb/sdk/interface/interface_grpc.py +1 -1
  87. wandb/sdk/interface/interface_queue.py +1 -1
  88. wandb/sdk/interface/interface_relay.py +1 -1
  89. wandb/sdk/interface/interface_shared.py +1 -2
  90. wandb/sdk/interface/interface_sock.py +1 -1
  91. wandb/sdk/interface/message_future.py +1 -1
  92. wandb/sdk/interface/message_future_poll.py +1 -1
  93. wandb/sdk/interface/router.py +1 -1
  94. wandb/sdk/interface/router_queue.py +1 -1
  95. wandb/sdk/interface/router_relay.py +1 -1
  96. wandb/sdk/interface/router_sock.py +1 -1
  97. wandb/sdk/interface/summary_record.py +1 -1
  98. wandb/sdk/internal/artifacts.py +1 -1
  99. wandb/sdk/internal/datastore.py +2 -3
  100. wandb/sdk/internal/file_pusher.py +5 -3
  101. wandb/sdk/internal/file_stream.py +22 -19
  102. wandb/sdk/internal/handler.py +5 -4
  103. wandb/sdk/internal/internal.py +1 -1
  104. wandb/sdk/internal/internal_api.py +115 -55
  105. wandb/sdk/internal/job_builder.py +1 -3
  106. wandb/sdk/internal/profiler.py +1 -1
  107. wandb/sdk/internal/progress.py +4 -6
  108. wandb/sdk/internal/sample.py +1 -3
  109. wandb/sdk/internal/sender.py +28 -16
  110. wandb/sdk/internal/settings_static.py +5 -5
  111. wandb/sdk/internal/system/assets/__init__.py +1 -0
  112. wandb/sdk/internal/system/assets/cpu.py +3 -9
  113. wandb/sdk/internal/system/assets/disk.py +2 -4
  114. wandb/sdk/internal/system/assets/gpu.py +6 -18
  115. wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
  116. wandb/sdk/internal/system/assets/interfaces.py +50 -22
  117. wandb/sdk/internal/system/assets/ipu.py +1 -3
  118. wandb/sdk/internal/system/assets/memory.py +7 -13
  119. wandb/sdk/internal/system/assets/network.py +4 -8
  120. wandb/sdk/internal/system/assets/open_metrics.py +283 -0
  121. wandb/sdk/internal/system/assets/tpu.py +1 -4
  122. wandb/sdk/internal/system/assets/trainium.py +26 -14
  123. wandb/sdk/internal/system/system_info.py +2 -3
  124. wandb/sdk/internal/system/system_monitor.py +52 -20
  125. wandb/sdk/internal/tb_watcher.py +12 -13
  126. wandb/sdk/launch/_project_spec.py +54 -65
  127. wandb/sdk/launch/agent/agent.py +374 -90
  128. wandb/sdk/launch/builder/abstract.py +61 -7
  129. wandb/sdk/launch/builder/build.py +81 -110
  130. wandb/sdk/launch/builder/docker_builder.py +181 -0
  131. wandb/sdk/launch/builder/kaniko_builder.py +419 -0
  132. wandb/sdk/launch/builder/noop.py +31 -12
  133. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
  134. wandb/sdk/launch/environment/abstract.py +28 -0
  135. wandb/sdk/launch/environment/aws_environment.py +276 -0
  136. wandb/sdk/launch/environment/gcp_environment.py +271 -0
  137. wandb/sdk/launch/environment/local_environment.py +65 -0
  138. wandb/sdk/launch/github_reference.py +3 -8
  139. wandb/sdk/launch/launch.py +38 -29
  140. wandb/sdk/launch/launch_add.py +6 -8
  141. wandb/sdk/launch/loader.py +230 -0
  142. wandb/sdk/launch/registry/abstract.py +54 -0
  143. wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
  144. wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
  145. wandb/sdk/launch/registry/local_registry.py +62 -0
  146. wandb/sdk/launch/runner/abstract.py +1 -16
  147. wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
  148. wandb/sdk/launch/runner/local_container.py +46 -22
  149. wandb/sdk/launch/runner/local_process.py +1 -4
  150. wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
  151. wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
  152. wandb/sdk/launch/sweeps/__init__.py +3 -2
  153. wandb/sdk/launch/sweeps/scheduler.py +132 -39
  154. wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
  155. wandb/sdk/launch/utils.py +101 -30
  156. wandb/sdk/launch/wandb_reference.py +2 -7
  157. wandb/sdk/lib/_settings_toposort_generate.py +166 -0
  158. wandb/sdk/lib/_settings_toposort_generated.py +201 -0
  159. wandb/sdk/lib/apikey.py +2 -4
  160. wandb/sdk/lib/config_util.py +4 -1
  161. wandb/sdk/lib/console.py +1 -3
  162. wandb/sdk/lib/deprecate.py +3 -3
  163. wandb/sdk/lib/file_stream_utils.py +7 -5
  164. wandb/sdk/lib/filenames.py +1 -1
  165. wandb/sdk/lib/filesystem.py +61 -5
  166. wandb/sdk/lib/git.py +1 -3
  167. wandb/sdk/lib/import_hooks.py +4 -7
  168. wandb/sdk/lib/ipython.py +8 -5
  169. wandb/sdk/lib/lazyloader.py +1 -3
  170. wandb/sdk/lib/mailbox.py +14 -4
  171. wandb/sdk/lib/proto_util.py +10 -5
  172. wandb/sdk/lib/redirect.py +15 -22
  173. wandb/sdk/lib/reporting.py +1 -3
  174. wandb/sdk/lib/retry.py +4 -5
  175. wandb/sdk/lib/runid.py +1 -3
  176. wandb/sdk/lib/server.py +15 -9
  177. wandb/sdk/lib/sock_client.py +1 -1
  178. wandb/sdk/lib/sparkline.py +1 -1
  179. wandb/sdk/lib/wburls.py +1 -1
  180. wandb/sdk/service/port_file.py +1 -2
  181. wandb/sdk/service/service.py +36 -13
  182. wandb/sdk/service/service_base.py +12 -1
  183. wandb/sdk/verify/verify.py +5 -7
  184. wandb/sdk/wandb_artifacts.py +142 -177
  185. wandb/sdk/wandb_config.py +5 -8
  186. wandb/sdk/wandb_helper.py +1 -1
  187. wandb/sdk/wandb_init.py +24 -13
  188. wandb/sdk/wandb_login.py +9 -9
  189. wandb/sdk/wandb_manager.py +39 -4
  190. wandb/sdk/wandb_metric.py +2 -6
  191. wandb/sdk/wandb_require.py +4 -15
  192. wandb/sdk/wandb_require_helpers.py +1 -9
  193. wandb/sdk/wandb_run.py +95 -141
  194. wandb/sdk/wandb_save.py +1 -3
  195. wandb/sdk/wandb_settings.py +149 -54
  196. wandb/sdk/wandb_setup.py +66 -46
  197. wandb/sdk/wandb_summary.py +13 -10
  198. wandb/sdk/wandb_sweep.py +6 -7
  199. wandb/sdk/wandb_watch.py +1 -1
  200. wandb/sklearn/calculate/confusion_matrix.py +1 -1
  201. wandb/sklearn/calculate/learning_curve.py +1 -1
  202. wandb/sklearn/calculate/summary_metrics.py +1 -3
  203. wandb/sklearn/plot/__init__.py +1 -1
  204. wandb/sklearn/plot/classifier.py +27 -18
  205. wandb/sklearn/plot/clusterer.py +4 -5
  206. wandb/sklearn/plot/regressor.py +4 -4
  207. wandb/sklearn/plot/shared.py +2 -2
  208. wandb/sync/__init__.py +1 -3
  209. wandb/sync/sync.py +4 -5
  210. wandb/testing/relay.py +11 -10
  211. wandb/trigger.py +1 -1
  212. wandb/util.py +106 -81
  213. wandb/viz.py +4 -4
  214. wandb/wandb_agent.py +50 -50
  215. wandb/wandb_controller.py +2 -3
  216. wandb/wandb_run.py +1 -2
  217. wandb/wandb_torch.py +1 -1
  218. wandb/xgboost/__init__.py +1 -2
  219. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
  220. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
  221. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
  222. wandb/sdk/launch/builder/docker.py +0 -80
  223. wandb/sdk/launch/builder/kaniko.py +0 -393
  224. wandb/sdk/launch/builder/loader.py +0 -32
  225. wandb/sdk/launch/runner/loader.py +0 -50
  226. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
  227. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
  228. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
@@ -8,14 +8,17 @@ from typing import Any, Dict, List, Optional
8
8
 
9
9
  import wandb
10
10
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
11
+ from wandb.sdk.launch.environment.abstract import AbstractEnvironment
11
12
 
12
- from .._project_spec import LaunchProject, get_entry_point_command
13
- from ..builder.build import docker_image_exists, get_env_vars_dict, pull_docker_image
13
+ from .._project_spec import LaunchProject, compute_command_args
14
+ from ..builder.build import get_env_vars_dict
14
15
  from ..utils import (
15
16
  LOG_PREFIX,
16
17
  PROJECT_SYNCHRONOUS,
17
18
  _is_wandb_dev_uri,
18
19
  _is_wandb_local_uri,
20
+ docker_image_exists,
21
+ pull_docker_image,
19
22
  sanitize_wandb_api_key,
20
23
  )
21
24
  from .abstract import AbstractRun, AbstractRunner, Status
@@ -66,20 +69,24 @@ class LocalSubmittedRun(AbstractRun):
66
69
  class LocalContainerRunner(AbstractRunner):
67
70
  """Runner class, uses a project to create a LocallySubmittedRun."""
68
71
 
72
+ def __init__(
73
+ self,
74
+ api: wandb.apis.internal.Api,
75
+ backend_config: Dict[str, Any],
76
+ environment: AbstractEnvironment,
77
+ ) -> None:
78
+ super().__init__(api, backend_config)
79
+ self.environment = environment
80
+
69
81
  def run(
70
82
  self,
71
83
  launch_project: LaunchProject,
72
- builder: AbstractBuilder,
73
- registry_config: Dict[str, Any],
84
+ builder: Optional[AbstractBuilder],
74
85
  ) -> Optional[AbstractRun]:
75
86
  synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
76
87
  docker_args: Dict[str, Any] = launch_project.resource_args.get(
77
88
  "local-container", {}
78
89
  )
79
- # TODO: leaving this here because of existing CLI command
80
- # we should likely just tell users to specify the gpus arg directly
81
- if launch_project.cuda:
82
- docker_args["gpus"] = "all"
83
90
 
84
91
  if _is_wandb_local_uri(self._api.settings("base_url")):
85
92
  if sys.platform == "win32":
@@ -107,28 +114,39 @@ class LocalContainerRunner(AbstractRunner):
107
114
  image_uri = launch_project.image_name
108
115
  if not docker_image_exists(image_uri):
109
116
  pull_docker_image(image_uri)
110
- env_vars.pop("WANDB_RUN_ID")
111
- # if they've given an override to the entrypoint
112
- entry_cmd = get_entry_point_command(
113
- entry_point, launch_project.override_args
114
- )
117
+ entry_cmd = []
118
+ if entry_point is not None:
119
+ entry_cmd = entry_point.command
120
+ override_args = compute_command_args(launch_project.override_args)
115
121
  command_str = " ".join(
116
- get_docker_command(image_uri, env_vars, entry_cmd, docker_args)
122
+ get_docker_command(
123
+ image_uri,
124
+ env_vars,
125
+ entry_cmd=entry_cmd,
126
+ docker_args=docker_args,
127
+ additional_args=override_args,
128
+ )
117
129
  ).strip()
118
130
  else:
119
131
  assert entry_point is not None
120
- repository: Optional[str] = registry_config.get("url")
132
+ _logger.info("Building docker image...")
133
+ assert builder is not None
121
134
  image_uri = builder.build_image(
122
135
  launch_project,
123
- repository,
124
136
  entry_point,
125
137
  )
138
+ _logger.info(f"Docker image built with uri {image_uri}")
139
+ # entry_cmd and additional_args are empty here because
140
+ # if launch built the container they've been accounted
141
+ # in the dockerfile and env vars respectively
126
142
  command_str = " ".join(
127
- get_docker_command(image_uri, env_vars, [""], docker_args)
143
+ get_docker_command(
144
+ image_uri,
145
+ env_vars,
146
+ docker_args=docker_args,
147
+ )
128
148
  ).strip()
129
149
 
130
- if not self.ack_run_queue_item(launch_project):
131
- return None
132
150
  sanitized_cmd_str = sanitize_wandb_api_key(command_str)
133
151
  _msg = f"{LOG_PREFIX}Launching run in docker with command: {sanitized_cmd_str}"
134
152
  wandb.termlog(_msg)
@@ -170,10 +188,11 @@ def _run_entry_point(command: str, work_dir: Optional[str]) -> AbstractRun:
170
188
  def get_docker_command(
171
189
  image: str,
172
190
  env_vars: Dict[str, str],
173
- entry_cmd: List[str],
191
+ entry_cmd: Optional[List[str]] = None,
174
192
  docker_args: Optional[Dict[str, Any]] = None,
193
+ additional_args: Optional[List[str]] = None,
175
194
  ) -> List[str]:
176
- """Constructs the docker command using the image and docker args.
195
+ """Construct the docker command using the image and docker args.
177
196
 
178
197
  Arguments:
179
198
  image: a Docker image to be run
@@ -202,8 +221,13 @@ def get_docker_command(
202
221
  else:
203
222
  cmd += [prefix, shlex.quote(str(value))]
204
223
 
224
+ if entry_cmd:
225
+ cmd += ["--entrypoint", entry_cmd[0]]
205
226
  cmd += [shlex.quote(image)]
206
- cmd += entry_cmd
227
+ if entry_cmd and len(entry_cmd) > 1:
228
+ cmd += entry_cmd[1:]
229
+ if additional_args:
230
+ cmd += additional_args
207
231
  return cmd
208
232
 
209
233
 
@@ -3,13 +3,13 @@ import shlex
3
3
  from typing import Any, List, Optional
4
4
 
5
5
  import wandb
6
- from wandb.errors import LaunchError
7
6
 
8
7
  from .._project_spec import LaunchProject, get_entry_point_command
9
8
  from ..builder.build import get_env_vars_dict
10
9
  from ..utils import (
11
10
  LOG_PREFIX,
12
11
  PROJECT_SYNCHRONOUS,
12
+ LaunchError,
13
13
  _is_wandb_uri,
14
14
  download_wandb_python_deps,
15
15
  parse_wandb_uri,
@@ -81,9 +81,6 @@ class LocalProcessRunner(AbstractRunner):
81
81
  for env_key, env_value in env_vars.items():
82
82
  cmd += [f"{shlex.quote(env_key)}={shlex.quote(env_value)}"]
83
83
 
84
- if not self.ack_run_queue_item(launch_project):
85
- return None
86
-
87
84
  entry_cmd = get_entry_point_command(entry_point, launch_project.override_args)
88
85
  cmd += entry_cmd
89
86
 
@@ -1,23 +1,20 @@
1
- import configparser
1
+ """Implementation of the SageMakerRunner class."""
2
2
  import logging
3
- import os
4
- import subprocess
5
3
  import time
6
- from typing import Any, Dict, Optional, Tuple, cast
4
+ from typing import Any, Dict, Optional, cast
7
5
 
8
6
  if False:
9
7
  import boto3 # type: ignore
10
8
 
11
9
  import wandb
12
- import wandb.docker as docker
13
10
  from wandb.apis.internal import Api
14
- from wandb.errors import LaunchError
15
11
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
16
- from wandb.util import get_module
12
+ from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
13
+ from wandb.sdk.launch.utils import LaunchError
17
14
 
18
15
  from .._project_spec import LaunchProject, get_entry_point_command
19
16
  from ..builder.build import get_env_vars_dict
20
- from ..utils import LOG_PREFIX, PROJECT_SYNCHRONOUS, run_shell, to_camel_case
17
+ from ..utils import LOG_PREFIX, PROJECT_SYNCHRONOUS, to_camel_case
21
18
  from .abstract import AbstractRun, AbstractRunner, Status
22
19
 
23
20
  _logger = logging.getLogger(__name__)
@@ -69,32 +66,49 @@ class SagemakerSubmittedRun(AbstractRun):
69
66
  return self._status
70
67
 
71
68
 
72
- class AWSSagemakerRunner(AbstractRunner):
69
+ class SageMakerRunner(AbstractRunner):
73
70
  """Runner class, uses a project to create a SagemakerSubmittedRun."""
74
71
 
72
+ def __init__(
73
+ self, api: Api, backend_config: Dict[str, Any], environment: AwsEnvironment
74
+ ) -> None:
75
+ """Initialize the SagemakerRunner.
76
+
77
+ Arguments:
78
+ api (Api): The API instance.
79
+ backend_config (Dict[str, Any]): The backend configuration.
80
+ environment (AwsEnvironment): The AWS environment.
81
+
82
+ Raises:
83
+ LaunchError: If the runner cannot be initialized.
84
+ """
85
+ super().__init__(api, backend_config)
86
+ self.environment = environment
87
+
75
88
  def run(
76
89
  self,
77
90
  launch_project: LaunchProject,
78
- builder: AbstractBuilder,
79
- registry_config: Dict[str, Any],
91
+ builder: Optional[AbstractBuilder],
80
92
  ) -> Optional[AbstractRun]:
81
- _logger.info("using AWSSagemakerRunner")
93
+ """Run a project on Amazon Sagemaker.
82
94
 
83
- boto3 = get_module(
84
- "boto3",
85
- "AWSSagemakerRunner requires boto3 to be installed, install with pip install wandb[launch]",
86
- )
87
- botocore = get_module(
88
- "botocore",
89
- "AWSSagemakerRunner requires botocore to be installed, install with pip install wandb[launch]",
90
- )
95
+ Arguments:
96
+ launch_project (LaunchProject): The project to run.
97
+ builder (AbstractBuilder): The builder to use.
98
+
99
+ Returns:
100
+ Optional[AbstractRun]: The run instance.
101
+
102
+ Raises:
103
+ LaunchError: If the launch is unsuccessful.
104
+ """
105
+ _logger.info("using AWSSagemakerRunner")
91
106
 
92
107
  given_sagemaker_args = launch_project.resource_args.get("sagemaker")
93
108
  if given_sagemaker_args is None:
94
109
  raise LaunchError(
95
110
  "No sagemaker args specified. Specify sagemaker args in resource_args"
96
111
  )
97
- validate_sagemaker_requirements(given_sagemaker_args, registry_config)
98
112
 
99
113
  default_output_path = self.backend_config.get("runner", {}).get(
100
114
  "s3_output_path"
@@ -104,37 +118,22 @@ class AWSSagemakerRunner(AbstractRunner):
104
118
  ):
105
119
  default_output_path = f"s3://{default_output_path}"
106
120
 
107
- region = get_region(given_sagemaker_args, registry_config.get("region"))
108
- instance_role: bool = False
109
- try:
110
- client = boto3.client("sts")
111
- instance_role = True
112
- caller_id = client.get_caller_identity()
113
-
114
- except botocore.exceptions.NoCredentialsError:
115
- access_key, secret_key = get_aws_credentials(given_sagemaker_args)
116
- client = boto3.client(
117
- "sts", aws_access_key_id=access_key, aws_secret_access_key=secret_key
118
- )
119
- caller_id = client.get_caller_identity()
120
-
121
+ session = self.environment.get_session()
122
+ client = session.client("sts")
123
+ caller_id = client.get_caller_identity()
121
124
  account_id = caller_id["Account"]
125
+ _logger.info(f"Using account ID {account_id}")
122
126
  role_arn = get_role_arn(given_sagemaker_args, self.backend_config, account_id)
123
127
  entry_point = launch_project.get_single_entry_point()
128
+
129
+ # Create a sagemaker client to launch the job.
130
+ sagemaker_client = session.client("sagemaker")
131
+
124
132
  # if the user provided the image they want to use, use that, but warn it won't have swappable artifacts
125
133
  if (
126
134
  given_sagemaker_args.get("AlgorithmSpecification", {}).get("TrainingImage")
127
135
  is not None
128
136
  ):
129
- if instance_role:
130
- sagemaker_client = boto3.client("sagemaker", region_name=region)
131
- else:
132
- sagemaker_client = boto3.client(
133
- "sagemaker",
134
- region_name=region,
135
- aws_access_key_id=access_key,
136
- aws_secret_access_key=secret_key,
137
- )
138
137
  sagemaker_args = build_sagemaker_args(
139
138
  launch_project,
140
139
  self._api,
@@ -152,57 +151,20 @@ class AWSSagemakerRunner(AbstractRunner):
152
151
  run.wait()
153
152
  return run
154
153
 
155
- _logger.info("Connecting to AWS ECR Client")
156
- if instance_role:
157
- ecr_client = boto3.client("ecr", region_name=region)
158
- else:
159
- ecr_client = boto3.client(
160
- "ecr",
161
- region_name=region,
162
- aws_access_key_id=access_key,
163
- aws_secret_access_key=secret_key,
164
- )
165
- repository = get_ecr_repository_url(
166
- ecr_client, given_sagemaker_args, registry_config
167
- )
168
- # TODO: handle login credentials gracefully
169
- login_credentials = registry_config.get("credentials")
170
- if login_credentials is not None:
171
- wandb.termwarn(
172
- "Ignoring registry credentials for ECR, using those found on the system"
173
- )
174
-
175
- if builder.type != "kaniko":
176
- _logger.info("Logging in to AWS ECR")
177
- login_resp = aws_ecr_login(region, repository)
178
- if login_resp is None or "Login Succeeded" not in login_resp:
179
- raise LaunchError(f"Unable to login to ECR, response: {login_resp}")
180
-
181
154
  if launch_project.docker_image:
182
155
  image = launch_project.docker_image
183
156
  else:
184
157
  assert entry_point is not None
158
+ assert builder is not None
185
159
  # build our own image
160
+ _logger.info("Building docker image...")
186
161
  image = builder.build_image(
187
162
  launch_project,
188
- repository,
189
163
  entry_point,
190
164
  )
191
-
192
- if not self.ack_run_queue_item(launch_project):
193
- return None
165
+ _logger.info(f"Docker image built with uri {image}")
194
166
 
195
167
  _logger.info("Connecting to sagemaker client")
196
- if instance_role:
197
- sagemaker_client = boto3.client("sagemaker", region_name=region)
198
- else:
199
- sagemaker_client = boto3.client(
200
- "sagemaker",
201
- region_name=region,
202
- aws_access_key_id=access_key,
203
- aws_secret_access_key=secret_key,
204
- )
205
-
206
168
  command_args = get_entry_point_command(
207
169
  entry_point, launch_project.override_args
208
170
  )
@@ -225,29 +187,15 @@ class AWSSagemakerRunner(AbstractRunner):
225
187
  return run
226
188
 
227
189
 
228
- def aws_ecr_login(region: str, registry: str) -> Optional[str]:
229
- pw_command = ["aws", "ecr", "get-login-password", "--region", region]
230
- try:
231
- pw = run_shell(pw_command)[0]
232
- except subprocess.CalledProcessError:
233
- raise LaunchError(
234
- "Unable to get login password. Please ensure you have AWS credentials configured"
235
- )
236
- try:
237
- docker_login_process = docker.login("AWS", pw, registry)
238
- except Exception:
239
- raise LaunchError(f"Failed to login to ECR {registry}")
240
- return docker_login_process
241
-
242
-
243
190
  def merge_aws_tag_with_algorithm_specification(
244
191
  algorithm_specification: Optional[Dict[str, Any]], aws_tag: Optional[str]
245
192
  ) -> Dict[str, Any]:
246
- """
247
- AWS Sagemaker algorithms require a training image and an input mode.
248
- If the user does not specify the specification themselves, define the spec
249
- minimally using these two fields. Otherwise, if they specify the AlgorithmSpecification
250
- set the training image if it is not set.
193
+ """Create an AWS AlgorithmSpecification.
194
+
195
+ AWS Sagemaker algorithms require a training image and an input mode. If the user
196
+ does not specify the specification themselves, define the spec minimally using these
197
+ two fields. Otherwise, if they specify the AlgorithmSpecification set the training
198
+ image if it is not set.
251
199
  """
252
200
  if algorithm_specification is None:
253
201
  return {
@@ -366,65 +314,10 @@ def launch_sagemaker_job(
366
314
  return run
367
315
 
368
316
 
369
- def get_region(
370
- sagemaker_args: Dict[str, Any], registry_config_region: Optional[str] = None
371
- ) -> str:
372
- region = sagemaker_args.get("region")
373
- if region is None:
374
- region = registry_config_region
375
- if region is None:
376
- region = os.environ.get("AWS_DEFAULT_REGION")
377
- if region is None and os.path.exists(os.path.expanduser("~/.aws/config")):
378
- config = configparser.ConfigParser()
379
- config.read(os.path.expanduser("~/.aws/config"))
380
- section = sagemaker_args.get("profile") or "default"
381
- try:
382
- region = config.get(section, "region")
383
- except (configparser.NoOptionError, configparser.NoSectionError):
384
- raise LaunchError(
385
- "Unable to detemine default region from ~/.aws/config. "
386
- "Please specify region in resource args or specify config "
387
- "section as 'profile'"
388
- )
389
-
390
- if region is None:
391
- raise LaunchError(
392
- "AWS region not specified and ~/.aws/config not found. Configure AWS"
393
- )
394
- assert isinstance(region, str)
395
- return region
396
-
397
-
398
- def get_aws_credentials(sagemaker_args: Dict[str, Any]) -> Tuple[str, str]:
399
- access_key = os.environ.get("AWS_ACCESS_KEY_ID")
400
- secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
401
- if (
402
- access_key is None
403
- or secret_key is None
404
- and os.path.exists(os.path.expanduser("~/.aws/credentials"))
405
- ):
406
- profile = sagemaker_args.get("profile") or "default"
407
- config = configparser.ConfigParser()
408
- config.read(os.path.expanduser("~/.aws/credentials"))
409
- try:
410
- access_key = config.get(profile, "aws_access_key_id")
411
- secret_key = config.get(profile, "aws_secret_access_key")
412
- except (configparser.NoOptionError, configparser.NoSectionError):
413
- raise LaunchError(
414
- "Unable to get aws credentials from ~/.aws/credentials. "
415
- "Please set aws credentials in environments variables, or "
416
- "check your credentials in ~/.aws/credentials. Use resource "
417
- "args to specify the profile using 'profile'"
418
- )
419
-
420
- if access_key is None or secret_key is None:
421
- raise LaunchError("AWS credentials not found")
422
- return access_key, secret_key
423
-
424
-
425
317
  def get_role_arn(
426
318
  sagemaker_args: Dict[str, Any], backend_config: Dict[str, Any], account_id: str
427
319
  ) -> str:
320
+ """Get the role arn from the sagemaker args or the backend config."""
428
321
  role_arn = sagemaker_args.get("RoleArn") or sagemaker_args.get("role_arn")
429
322
  if role_arn is None:
430
323
  role_arn = backend_config.get("runner", {}).get("role_arn")
@@ -437,55 +330,3 @@ def get_role_arn(
437
330
  return role_arn
438
331
 
439
332
  return f"arn:aws:iam::{account_id}:role/{role_arn}"
440
-
441
-
442
- def validate_sagemaker_requirements(
443
- given_sagemaker_args: Dict[str, Any], registry_config: Dict[str, Any]
444
- ) -> None:
445
- if (
446
- given_sagemaker_args.get(
447
- "EcrRepoName", given_sagemaker_args.get("ecr_repo_name")
448
- )
449
- is None
450
- and registry_config.get("url") is None
451
- ):
452
- raise LaunchError(
453
- "AWS sagemaker requires an ECR Repository to push the container to "
454
- "set this by adding a `EcrRepoName` key to the sagemaker"
455
- "field of resource_args or through the url key in the registry section "
456
- "of the launch agent config."
457
- )
458
-
459
- if registry_config.get("ecr-repo-provider", "aws").lower() != "aws":
460
- raise LaunchError(
461
- "Sagemaker jobs requires an AWS ECR Repo to push the container to"
462
- )
463
-
464
-
465
- def get_ecr_repository_url(
466
- ecr_client: "boto3.Client",
467
- given_sagemaker_args: Dict[str, Any],
468
- registry_config: Dict[str, Any],
469
- ) -> str:
470
- token = ecr_client.get_authorization_token()
471
- ecr_repo_name = given_sagemaker_args.get(
472
- "EcrRepoName", given_sagemaker_args.get("ecr_repo_name")
473
- )
474
- if ecr_repo_name:
475
- if not isinstance(ecr_repo_name, str):
476
- raise LaunchError("EcrRepoName must be a string")
477
- if not ecr_repo_name.startswith("arn:aws:ecr:"):
478
- repository = cast(
479
- str,
480
- token["authorizationData"][0]["proxyEndpoint"].replace("https://", "")
481
- + f"/{ecr_repo_name}",
482
- )
483
- else:
484
- repository = ecr_repo_name
485
- else:
486
- repository = cast(str, registry_config.get("url", ""))
487
- if not repository:
488
- raise LaunchError(
489
- "Must provide a repository url either through resource args or launch config file"
490
- )
491
- return repository