wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,14 @@
1
1
  import logging
2
2
  import os
3
3
  import shlex
4
- import signal
5
4
  import subprocess
6
5
  import sys
6
+ import threading
7
+ import time
7
8
  from typing import Any, Dict, List, Optional
8
9
 
9
10
  import wandb
11
+ from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
10
12
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
11
13
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
12
14
 
@@ -17,7 +19,6 @@ from ..utils import (
17
19
  PROJECT_SYNCHRONOUS,
18
20
  _is_wandb_dev_uri,
19
21
  _is_wandb_local_uri,
20
- docker_image_exists,
21
22
  pull_docker_image,
22
23
  sanitize_wandb_api_key,
23
24
  )
@@ -29,36 +30,57 @@ _logger = logging.getLogger(__name__)
29
30
  class LocalSubmittedRun(AbstractRun):
30
31
  """Instance of ``AbstractRun`` corresponding to a subprocess launched to run an entry point command locally."""
31
32
 
32
- def __init__(self, command_proc: "subprocess.Popen[bytes]") -> None:
33
+ def __init__(self) -> None:
33
34
  super().__init__()
34
- self.command_proc = command_proc
35
+ self._command_proc: Optional[subprocess.Popen] = None
36
+ self._stdout: Optional[str] = None
37
+ self._terminate_flag: bool = False
38
+ self._thread: Optional[threading.Thread] = None
39
+
40
+ def set_command_proc(self, command_proc: subprocess.Popen) -> None:
41
+ self._command_proc = command_proc
42
+
43
+ def set_thread(self, thread: threading.Thread) -> None:
44
+ self._thread = thread
35
45
 
36
46
  @property
37
- def id(self) -> str:
38
- return str(self.command_proc.pid)
47
+ def id(self) -> Optional[str]:
48
+ if self._command_proc is None:
49
+ return None
50
+ return str(self._command_proc.pid)
39
51
 
40
52
  def wait(self) -> bool:
41
- return self.command_proc.wait() == 0
53
+ assert self._thread is not None
54
+ # if command proc is not set
55
+ # wait for thread to set it
56
+ if self._command_proc is None:
57
+ while self._thread.is_alive():
58
+ time.sleep(5)
59
+ # command proc can be updated by another thread
60
+ if self._command_proc is not None:
61
+ return self._command_proc.wait() == 0 # type: ignore
62
+ return False
63
+
64
+ return self._command_proc.wait() == 0
65
+
66
+ def get_logs(self) -> Optional[str]:
67
+ return self._stdout
42
68
 
43
69
  def cancel(self) -> None:
44
- # Interrupt child process if it hasn't already exited
45
- if self.command_proc.poll() is None:
46
- # Kill the the process tree rooted at the child if it's the leader of its own process
47
- # group, otherwise just kill the child
48
- try:
49
- if self.command_proc.pid == os.getpgid(self.command_proc.pid):
50
- os.killpg(self.command_proc.pid, signal.SIGTERM)
51
- else:
52
- self.command_proc.terminate()
53
- except OSError:
54
- # The child process may have exited before we attempted to terminate it, so we
55
- # ignore OSErrors raised during child process termination
56
- _msg = f"{LOG_PREFIX}Failed to terminate child process PID {self.command_proc.pid}"
57
- _logger.debug(_msg)
58
- self.command_proc.wait()
70
+ # thread is set immediately after starting, should always exist
71
+ assert self._thread is not None
72
+
73
+ # cancel called before the thread subprocess has started
74
+ # indicates to thread to not start command proc if not already started
75
+ self._terminate_flag = True
59
76
 
60
77
  def get_status(self) -> Status:
61
- exit_code = self.command_proc.poll()
78
+ assert self._thread is not None, "Failed to get status, self._thread = None"
79
+ if self._command_proc is None:
80
+ if self._thread.is_alive():
81
+ return Status("running")
82
+ return Status("stopped")
83
+ exit_code = self._command_proc.poll()
62
84
  if exit_code is None:
63
85
  return Status("running")
64
86
  if exit_code == 0:
@@ -78,12 +100,7 @@ class LocalContainerRunner(AbstractRunner):
78
100
  super().__init__(api, backend_config)
79
101
  self.environment = environment
80
102
 
81
- def run(
82
- self,
83
- launch_project: LaunchProject,
84
- builder: Optional[AbstractBuilder],
85
- ) -> Optional[AbstractRun]:
86
- synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
103
+ def _populate_docker_args(self, launch_project: LaunchProject) -> Dict[str, Any]:
87
104
  docker_args: Dict[str, Any] = launch_project.resource_args.get(
88
105
  "local-container", {}
89
106
  )
@@ -96,6 +113,16 @@ class LocalContainerRunner(AbstractRunner):
96
113
  if sys.platform == "linux" or sys.platform == "linux2":
97
114
  docker_args["add-host"] = "host.docker.internal:host-gateway"
98
115
 
116
+ return docker_args
117
+
118
+ def run(
119
+ self,
120
+ launch_project: LaunchProject,
121
+ builder: Optional[AbstractBuilder],
122
+ job_tracker: Optional[JobAndRunStatusTracker] = None,
123
+ ) -> Optional[AbstractRun]:
124
+ docker_args = self._populate_docker_args(launch_project)
125
+ synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
99
126
  entry_point = launch_project.get_single_entry_point()
100
127
  env_vars = get_env_vars_dict(launch_project, self._api)
101
128
 
@@ -107,13 +134,12 @@ class LocalContainerRunner(AbstractRunner):
107
134
  _, _, port = self._api.settings("base_url").split(":")
108
135
  env_vars["WANDB_BASE_URL"] = f"http://host.docker.internal:{port}"
109
136
  elif _is_wandb_dev_uri(self._api.settings("base_url")):
110
- env_vars["WANDB_BASE_URL"] = "http://host.docker.internal:9002"
137
+ env_vars["WANDB_BASE_URL"] = "http://host.docker.internal:9001"
111
138
 
112
139
  if launch_project.docker_image:
113
140
  # user has provided their own docker image
114
141
  image_uri = launch_project.image_name
115
- if not docker_image_exists(image_uri):
116
- pull_docker_image(image_uri)
142
+ pull_docker_image(image_uri)
117
143
  entry_cmd = []
118
144
  if entry_point is not None:
119
145
  entry_cmd = entry_point.command
@@ -130,10 +156,7 @@ class LocalContainerRunner(AbstractRunner):
130
156
  assert entry_point is not None
131
157
  _logger.info("Building docker image...")
132
158
  assert builder is not None
133
- image_uri = builder.build_image(
134
- launch_project,
135
- entry_point,
136
- )
159
+ image_uri = builder.build_image(launch_project, entry_point, job_tracker)
137
160
  _logger.info(f"Docker image built with uri {image_uri}")
138
161
  # entry_cmd and additional_args are empty here because
139
162
  # if launch built the container they've been accounted
@@ -145,7 +168,7 @@ class LocalContainerRunner(AbstractRunner):
145
168
  docker_args=docker_args,
146
169
  )
147
170
  ).strip()
148
-
171
+ launch_project.fill_macros(image_uri)
149
172
  sanitized_cmd_str = sanitize_wandb_api_key(command_str)
150
173
  _msg = f"{LOG_PREFIX}Launching run in docker with command: {sanitized_cmd_str}"
151
174
  wandb.termlog(_msg)
@@ -168,20 +191,49 @@ def _run_entry_point(command: str, work_dir: Optional[str]) -> AbstractRun:
168
191
  if work_dir is None:
169
192
  work_dir = os.getcwd()
170
193
  env = os.environ.copy()
171
- if os.name == "nt":
172
- # we are running on windows
173
- process = subprocess.Popen(
174
- ["cmd", "/c", command], close_fds=True, cwd=work_dir, env=env
175
- )
176
- else:
177
- process = subprocess.Popen(
178
- ["bash", "-c", command],
179
- close_fds=True,
180
- cwd=work_dir,
181
- env=env,
182
- )
183
-
184
- return LocalSubmittedRun(process)
194
+ run = LocalSubmittedRun()
195
+ thread = threading.Thread(
196
+ target=_thread_process_runner,
197
+ args=(run, ["bash", "-c", command], work_dir, env),
198
+ )
199
+ run.set_thread(thread)
200
+ thread.start()
201
+ return run
202
+
203
+
204
+ def _thread_process_runner(
205
+ run: LocalSubmittedRun, args: List[str], work_dir: str, env: Dict[str, str]
206
+ ) -> None:
207
+ # cancel was called before we started the subprocess
208
+ if run._terminate_flag:
209
+ return
210
+ process = subprocess.Popen(
211
+ args,
212
+ close_fds=True,
213
+ stdout=subprocess.PIPE,
214
+ stderr=subprocess.STDOUT,
215
+ universal_newlines=True,
216
+ bufsize=1,
217
+ cwd=work_dir,
218
+ env=env,
219
+ )
220
+ run.set_command_proc(process)
221
+ run._stdout = ""
222
+ while True:
223
+ # the agent thread could set the terminate flag
224
+ if run._terminate_flag:
225
+ process.terminate() # type: ignore
226
+ chunk = os.read(process.stdout.fileno(), 4096) # type: ignore
227
+ if not chunk:
228
+ break
229
+ index = chunk.find(b"\r")
230
+ decoded_chunk = chunk.decode()
231
+ if index != -1:
232
+ run._stdout += decoded_chunk
233
+ print(chunk.decode(), end="")
234
+ else:
235
+ run._stdout += decoded_chunk + "\r"
236
+ print(chunk.decode(), end="\r")
185
237
 
186
238
 
187
239
  def get_docker_command(
@@ -6,10 +6,10 @@ import wandb
6
6
 
7
7
  from .._project_spec import LaunchProject, get_entry_point_command
8
8
  from ..builder.build import get_env_vars_dict
9
+ from ..errors import LaunchError
9
10
  from ..utils import (
10
11
  LOG_PREFIX,
11
12
  PROJECT_SYNCHRONOUS,
12
- LaunchError,
13
13
  _is_wandb_uri,
14
14
  download_wandb_python_deps,
15
15
  parse_wandb_uri,
@@ -8,9 +8,10 @@ if False:
8
8
 
9
9
  import wandb
10
10
  from wandb.apis.internal import Api
11
+ from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
11
12
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
12
13
  from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
13
- from wandb.sdk.launch.utils import LaunchError
14
+ from wandb.sdk.launch.errors import LaunchError
14
15
 
15
16
  from .._project_spec import LaunchProject, get_entry_point_command
16
17
  from ..builder.build import get_env_vars_dict
@@ -23,9 +24,15 @@ _logger = logging.getLogger(__name__)
23
24
  class SagemakerSubmittedRun(AbstractRun):
24
25
  """Instance of ``AbstractRun`` corresponding to a subprocess launched to run an entry point command on aws sagemaker."""
25
26
 
26
- def __init__(self, training_job_name: str, client: "boto3.Client") -> None:
27
+ def __init__(
28
+ self,
29
+ training_job_name: str,
30
+ client: "boto3.Client",
31
+ log_client: Optional["boto3.Client"] = None,
32
+ ) -> None:
27
33
  super().__init__()
28
34
  self.client = client
35
+ self.log_client = log_client
29
36
  self.training_job_name = training_job_name
30
37
  self._status = Status("running")
31
38
 
@@ -33,6 +40,38 @@ class SagemakerSubmittedRun(AbstractRun):
33
40
  def id(self) -> str:
34
41
  return f"sagemaker-{self.training_job_name}"
35
42
 
43
+ def get_logs(self) -> Optional[str]:
44
+ if self.log_client is None:
45
+ return None
46
+ try:
47
+ describe_res = self.log_client.describe_log_streams(
48
+ logGroupName="/aws/sagemaker/TrainingJobs",
49
+ logStreamNamePrefix=self.training_job_name,
50
+ )
51
+ if len(describe_res["logStreams"]) == 0:
52
+ wandb.termwarn(
53
+ f"Failed to get logs for training job: {self.training_job_name}"
54
+ )
55
+ return None
56
+ log_name = describe_res["logStreams"][0]["logStreamName"]
57
+ res = self.log_client.get_log_events(
58
+ logGroupName="/aws/sagemaker/TrainingJobs",
59
+ logStreamName=log_name,
60
+ )
61
+ return "\n".join(
62
+ [f'{event["timestamp"]}:{event["message"]}' for event in res["events"]]
63
+ )
64
+ except self.log_client.exceptions.ResourceNotFoundException:
65
+ wandb.termwarn(
66
+ f"Failed to get logs for training job: {self.training_job_name}"
67
+ )
68
+ return None
69
+ except Exception as e:
70
+ wandb.termwarn(
71
+ f"Failed to handle logs for training job: {self.training_job_name} with error {str(e)}"
72
+ )
73
+ return None
74
+
36
75
  def wait(self) -> bool:
37
76
  while True:
38
77
  status_state = self.get_status().state
@@ -89,6 +128,7 @@ class SageMakerRunner(AbstractRunner):
89
128
  self,
90
129
  launch_project: LaunchProject,
91
130
  builder: Optional[AbstractBuilder],
131
+ job_tracker: Optional[JobAndRunStatusTracker] = None,
92
132
  ) -> Optional[AbstractRun]:
93
133
  """Run a project on Amazon Sagemaker.
94
134
 
@@ -128,6 +168,13 @@ class SageMakerRunner(AbstractRunner):
128
168
 
129
169
  # Create a sagemaker client to launch the job.
130
170
  sagemaker_client = session.client("sagemaker")
171
+ log_client = None
172
+ try:
173
+ log_client = session.client("logs")
174
+ except Exception as e:
175
+ wandb.termwarn(
176
+ f"Failed to connect to cloudwatch logs with error {str(e)}, logs will not be available"
177
+ )
131
178
 
132
179
  # if the user provided the image they want to use, use that, but warn it won't have swappable artifacts
133
180
  if (
@@ -146,7 +193,9 @@ class SageMakerRunner(AbstractRunner):
146
193
  _logger.info(
147
194
  f"Launching sagemaker job on user supplied image with args: {sagemaker_args}"
148
195
  )
149
- run = launch_sagemaker_job(launch_project, sagemaker_args, sagemaker_client)
196
+ run = launch_sagemaker_job(
197
+ launch_project, sagemaker_args, sagemaker_client, log_client
198
+ )
150
199
  if self.backend_config[PROJECT_SYNCHRONOUS]:
151
200
  run.wait()
152
201
  return run
@@ -158,12 +207,10 @@ class SageMakerRunner(AbstractRunner):
158
207
  assert builder is not None
159
208
  # build our own image
160
209
  _logger.info("Building docker image...")
161
- image = builder.build_image(
162
- launch_project,
163
- entry_point,
164
- )
210
+ image = builder.build_image(launch_project, entry_point, job_tracker)
165
211
  _logger.info(f"Docker image built with uri {image}")
166
212
 
213
+ launch_project.fill_macros(image)
167
214
  _logger.info("Connecting to sagemaker client")
168
215
  command_args = get_entry_point_command(
169
216
  entry_point, launch_project.override_args
@@ -181,7 +228,9 @@ class SageMakerRunner(AbstractRunner):
181
228
  launch_project, self._api, role_arn, image, default_output_path
182
229
  )
183
230
  _logger.info(f"Launching sagemaker job with args: {sagemaker_args}")
184
- run = launch_sagemaker_job(launch_project, sagemaker_args, sagemaker_client)
231
+ run = launch_sagemaker_job(
232
+ launch_project, sagemaker_args, sagemaker_client, log_client
233
+ )
185
234
  if self.backend_config[PROJECT_SYNCHRONOUS]:
186
235
  run.wait()
187
236
  return run
@@ -296,14 +345,15 @@ def launch_sagemaker_job(
296
345
  launch_project: LaunchProject,
297
346
  sagemaker_args: Dict[str, Any],
298
347
  sagemaker_client: "boto3.Client",
348
+ log_client: Optional["boto3.Client"] = None,
299
349
  ) -> SagemakerSubmittedRun:
300
350
  training_job_name = sagemaker_args.get("TrainingJobName") or launch_project.run_id
301
351
  resp = sagemaker_client.create_training_job(**sagemaker_args)
302
352
 
303
353
  if resp.get("TrainingJobArn") is None:
304
- raise LaunchError("Unable to create training job")
354
+ raise LaunchError("Failed to create training job when submitting to SageMaker")
305
355
 
306
- run = SagemakerSubmittedRun(training_job_name, sagemaker_client)
356
+ run = SagemakerSubmittedRun(training_job_name, sagemaker_client, log_client)
307
357
  wandb.termlog(
308
358
  f"{LOG_PREFIX}Run job submitted with arn: {resp.get('TrainingJobArn')}"
309
359
  )
@@ -14,10 +14,12 @@ from wandb.apis.internal import Api
14
14
  from wandb.util import get_module
15
15
 
16
16
  from .._project_spec import LaunchProject, get_entry_point_command
17
+ from ..agent.job_status_tracker import JobAndRunStatusTracker
17
18
  from ..builder.abstract import AbstractBuilder
18
19
  from ..builder.build import get_env_vars_dict
19
20
  from ..environment.gcp_environment import GcpEnvironment
20
- from ..utils import LOG_PREFIX, PROJECT_SYNCHRONOUS, LaunchError, run_shell
21
+ from ..errors import LaunchError
22
+ from ..utils import LOG_PREFIX, PROJECT_SYNCHRONOUS, run_shell
21
23
  from .abstract import AbstractRun, AbstractRunner, Status
22
24
 
23
25
  GCP_CONSOLE_URI = "https://console.cloud.google.com"
@@ -34,6 +36,10 @@ class VertexSubmittedRun(AbstractRun):
34
36
  # numeric ID of the custom training job
35
37
  return self._job.name # type: ignore
36
38
 
39
+ def get_logs(self) -> Optional[str]:
40
+ # TODO: implement
41
+ return None
42
+
37
43
  @property
38
44
  def name(self) -> str:
39
45
  return self._job.display_name # type: ignore
@@ -88,6 +94,7 @@ class VertexRunner(AbstractRunner):
88
94
  self,
89
95
  launch_project: LaunchProject,
90
96
  builder: Optional[AbstractBuilder],
97
+ job_tracker: Optional[JobAndRunStatusTracker] = None,
91
98
  ) -> Optional[AbstractRun]:
92
99
  """Run a Vertex job."""
93
100
  aiplatform = get_module( # noqa: F811
@@ -133,11 +140,9 @@ class VertexRunner(AbstractRunner):
133
140
  else:
134
141
  assert entry_point is not None
135
142
  assert builder is not None
136
- image_uri = builder.build_image(
137
- launch_project,
138
- entry_point,
139
- )
143
+ image_uri = builder.build_image(launch_project, entry_point, job_tracker)
140
144
 
145
+ launch_project.fill_macros(image_uri)
141
146
  # TODO: how to handle this?
142
147
  entry_cmd = get_entry_point_command(entry_point, launch_project.override_args)
143
148
 
@@ -3,8 +3,6 @@ from typing import Any, Callable, Dict
3
3
 
4
4
  log = logging.getLogger(__name__)
5
5
 
6
- SCHEDULER_URI = "placeholder-uri-scheduler"
7
-
8
6
 
9
7
  class SchedulerError(Exception):
10
8
  """Raised when a known error occurs with wandb sweep scheduler."""
@@ -19,20 +17,20 @@ def _import_sweep_scheduler() -> Any:
19
17
 
20
18
 
21
19
  _WANDB_SCHEDULERS: Dict[str, Callable] = {
22
- "sweep": _import_sweep_scheduler,
20
+ "wandb": _import_sweep_scheduler,
23
21
  }
24
22
 
25
23
 
26
- def load_scheduler(scheduler_name: str) -> Any:
27
- scheduler_name = scheduler_name.lower()
28
- if scheduler_name not in _WANDB_SCHEDULERS:
24
+ def load_scheduler(scheduler_type: str) -> Any:
25
+ scheduler_type = scheduler_type.lower()
26
+ if scheduler_type not in _WANDB_SCHEDULERS:
29
27
  raise SchedulerError(
30
28
  f"The `scheduler_name` argument must be one of "
31
- f"{list(_WANDB_SCHEDULERS.keys())}, got: {scheduler_name}"
29
+ f"{list(_WANDB_SCHEDULERS.keys())}, got: {scheduler_type}"
32
30
  )
33
31
 
34
- log.warn(f"Loading dependencies for Scheduler of type: {scheduler_name}")
35
- import_func = _WANDB_SCHEDULERS[scheduler_name]
32
+ log.warn(f"Loading dependencies for Scheduler of type: {scheduler_type}")
33
+ import_func = _WANDB_SCHEDULERS[scheduler_type]
36
34
  return import_func()
37
35
 
38
36