wandb 0.15.4__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/internal.py +3 -0
- wandb/apis/public.py +18 -20
- wandb/beta/workflows.py +5 -6
- wandb/cli/cli.py +27 -27
- wandb/data_types.py +2 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -4
- wandb/sdk/artifacts/__init__.py +0 -14
- wandb/sdk/artifacts/artifact.py +1757 -277
- wandb/sdk/artifacts/artifact_manifest_entry.py +26 -6
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/artifacts/artifacts_cache.py +7 -8
- wandb/sdk/artifacts/exceptions.py +4 -4
- wandb/sdk/artifacts/storage_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -6
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/http_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +35 -32
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +5 -9
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -2
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +2 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +24 -16
- wandb/sdk/artifacts/storage_policy.py +3 -3
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +2 -2
- wandb/sdk/data_types/base_types/media.py +5 -6
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +4 -5
- wandb/sdk/data_types/helper_types/classes.py +5 -8
- wandb/sdk/data_types/helper_types/image_mask.py +4 -5
- wandb/sdk/data_types/histogram.py +3 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +4 -5
- wandb/sdk/data_types/molecule.py +2 -2
- wandb/sdk/data_types/object_3d.py +3 -3
- wandb/sdk/data_types/plotly.py +2 -2
- wandb/sdk/data_types/saved_model.py +7 -8
- wandb/sdk/data_types/trace_tree.py +4 -4
- wandb/sdk/data_types/video.py +4 -4
- wandb/sdk/interface/interface.py +8 -10
- wandb/sdk/internal/file_stream.py +2 -3
- wandb/sdk/internal/internal_api.py +99 -4
- wandb/sdk/internal/job_builder.py +15 -7
- wandb/sdk/internal/sender.py +4 -0
- wandb/sdk/internal/settings_static.py +1 -0
- wandb/sdk/launch/_project_spec.py +9 -7
- wandb/sdk/launch/agent/agent.py +115 -58
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +16 -10
- wandb/sdk/launch/builder/docker_builder.py +9 -2
- wandb/sdk/launch/builder/kaniko_builder.py +108 -22
- wandb/sdk/launch/builder/noop.py +3 -1
- wandb/sdk/launch/environment/aws_environment.py +2 -1
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/github_reference.py +30 -18
- wandb/sdk/launch/launch.py +1 -1
- wandb/sdk/launch/loader.py +15 -0
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +38 -4
- wandb/sdk/launch/registry/google_artifact_registry.py +46 -7
- wandb/sdk/launch/runner/abstract.py +19 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +111 -47
- wandb/sdk/launch/runner/local_container.py +101 -48
- wandb/sdk/launch/runner/sagemaker_runner.py +59 -9
- wandb/sdk/launch/runner/vertex_runner.py +8 -4
- wandb/sdk/launch/sweeps/scheduler.py +102 -27
- wandb/sdk/launch/sweeps/utils.py +21 -0
- wandb/sdk/launch/utils.py +19 -7
- wandb/sdk/lib/_settings_toposort_generated.py +3 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +6 -9
- wandb/sdk/wandb_config.py +2 -4
- wandb/sdk/wandb_init.py +2 -0
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +32 -35
- wandb/sdk/wandb_settings.py +10 -3
- wandb/testing/relay.py +15 -2
- wandb/util.py +55 -23
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/METADATA +11 -8
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/RECORD +97 -97
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/artifacts/invalid_artifact.py +0 -23
- wandb/sdk/artifacts/lazy_artifact.py +0 -162
- wandb/sdk/artifacts/local_artifact.py +0 -719
- wandb/sdk/artifacts/public_artifact.py +0 -1188
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,14 @@
|
|
1
1
|
import logging
|
2
2
|
import os
|
3
3
|
import shlex
|
4
|
-
import signal
|
5
4
|
import subprocess
|
6
5
|
import sys
|
6
|
+
import threading
|
7
|
+
import time
|
7
8
|
from typing import Any, Dict, List, Optional
|
8
9
|
|
9
10
|
import wandb
|
11
|
+
from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
|
10
12
|
from wandb.sdk.launch.builder.abstract import AbstractBuilder
|
11
13
|
from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
12
14
|
|
@@ -28,36 +30,57 @@ _logger = logging.getLogger(__name__)
|
|
28
30
|
class LocalSubmittedRun(AbstractRun):
|
29
31
|
"""Instance of ``AbstractRun`` corresponding to a subprocess launched to run an entry point command locally."""
|
30
32
|
|
31
|
-
def __init__(self
|
33
|
+
def __init__(self) -> None:
|
32
34
|
super().__init__()
|
33
|
-
self.
|
35
|
+
self._command_proc: Optional[subprocess.Popen] = None
|
36
|
+
self._stdout: Optional[str] = None
|
37
|
+
self._terminate_flag: bool = False
|
38
|
+
self._thread: Optional[threading.Thread] = None
|
39
|
+
|
40
|
+
def set_command_proc(self, command_proc: subprocess.Popen) -> None:
|
41
|
+
self._command_proc = command_proc
|
42
|
+
|
43
|
+
def set_thread(self, thread: threading.Thread) -> None:
|
44
|
+
self._thread = thread
|
34
45
|
|
35
46
|
@property
|
36
|
-
def id(self) -> str:
|
37
|
-
|
47
|
+
def id(self) -> Optional[str]:
|
48
|
+
if self._command_proc is None:
|
49
|
+
return None
|
50
|
+
return str(self._command_proc.pid)
|
38
51
|
|
39
52
|
def wait(self) -> bool:
|
40
|
-
|
53
|
+
assert self._thread is not None
|
54
|
+
# if command proc is not set
|
55
|
+
# wait for thread to set it
|
56
|
+
if self._command_proc is None:
|
57
|
+
while self._thread.is_alive():
|
58
|
+
time.sleep(5)
|
59
|
+
# command proc can be updated by another thread
|
60
|
+
if self._command_proc is not None:
|
61
|
+
return self._command_proc.wait() == 0 # type: ignore
|
62
|
+
return False
|
63
|
+
|
64
|
+
return self._command_proc.wait() == 0
|
65
|
+
|
66
|
+
def get_logs(self) -> Optional[str]:
|
67
|
+
return self._stdout
|
41
68
|
|
42
69
|
def cancel(self) -> None:
|
43
|
-
#
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
os.killpg(self.command_proc.pid, signal.SIGTERM)
|
50
|
-
else:
|
51
|
-
self.command_proc.terminate()
|
52
|
-
except OSError:
|
53
|
-
# The child process may have exited before we attempted to terminate it, so we
|
54
|
-
# ignore OSErrors raised during child process termination
|
55
|
-
_msg = f"{LOG_PREFIX}Failed to terminate child process PID {self.command_proc.pid}"
|
56
|
-
_logger.debug(_msg)
|
57
|
-
self.command_proc.wait()
|
70
|
+
# thread is set immediately after starting, should always exist
|
71
|
+
assert self._thread is not None
|
72
|
+
|
73
|
+
# cancel called before the thread subprocess has started
|
74
|
+
# indicates to thread to not start command proc if not already started
|
75
|
+
self._terminate_flag = True
|
58
76
|
|
59
77
|
def get_status(self) -> Status:
|
60
|
-
|
78
|
+
assert self._thread is not None, "Failed to get status, self._thread = None"
|
79
|
+
if self._command_proc is None:
|
80
|
+
if self._thread.is_alive():
|
81
|
+
return Status("running")
|
82
|
+
return Status("stopped")
|
83
|
+
exit_code = self._command_proc.poll()
|
61
84
|
if exit_code is None:
|
62
85
|
return Status("running")
|
63
86
|
if exit_code == 0:
|
@@ -77,12 +100,7 @@ class LocalContainerRunner(AbstractRunner):
|
|
77
100
|
super().__init__(api, backend_config)
|
78
101
|
self.environment = environment
|
79
102
|
|
80
|
-
def
|
81
|
-
self,
|
82
|
-
launch_project: LaunchProject,
|
83
|
-
builder: Optional[AbstractBuilder],
|
84
|
-
) -> Optional[AbstractRun]:
|
85
|
-
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
103
|
+
def _populate_docker_args(self, launch_project: LaunchProject) -> Dict[str, Any]:
|
86
104
|
docker_args: Dict[str, Any] = launch_project.resource_args.get(
|
87
105
|
"local-container", {}
|
88
106
|
)
|
@@ -95,6 +113,16 @@ class LocalContainerRunner(AbstractRunner):
|
|
95
113
|
if sys.platform == "linux" or sys.platform == "linux2":
|
96
114
|
docker_args["add-host"] = "host.docker.internal:host-gateway"
|
97
115
|
|
116
|
+
return docker_args
|
117
|
+
|
118
|
+
def run(
|
119
|
+
self,
|
120
|
+
launch_project: LaunchProject,
|
121
|
+
builder: Optional[AbstractBuilder],
|
122
|
+
job_tracker: Optional[JobAndRunStatusTracker] = None,
|
123
|
+
) -> Optional[AbstractRun]:
|
124
|
+
docker_args = self._populate_docker_args(launch_project)
|
125
|
+
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
98
126
|
entry_point = launch_project.get_single_entry_point()
|
99
127
|
env_vars = get_env_vars_dict(launch_project, self._api)
|
100
128
|
|
@@ -106,7 +134,7 @@ class LocalContainerRunner(AbstractRunner):
|
|
106
134
|
_, _, port = self._api.settings("base_url").split(":")
|
107
135
|
env_vars["WANDB_BASE_URL"] = f"http://host.docker.internal:{port}"
|
108
136
|
elif _is_wandb_dev_uri(self._api.settings("base_url")):
|
109
|
-
env_vars["WANDB_BASE_URL"] = "http://host.docker.internal:
|
137
|
+
env_vars["WANDB_BASE_URL"] = "http://host.docker.internal:9001"
|
110
138
|
|
111
139
|
if launch_project.docker_image:
|
112
140
|
# user has provided their own docker image
|
@@ -128,11 +156,7 @@ class LocalContainerRunner(AbstractRunner):
|
|
128
156
|
assert entry_point is not None
|
129
157
|
_logger.info("Building docker image...")
|
130
158
|
assert builder is not None
|
131
|
-
image_uri = builder.build_image(
|
132
|
-
launch_project,
|
133
|
-
entry_point,
|
134
|
-
)
|
135
|
-
|
159
|
+
image_uri = builder.build_image(launch_project, entry_point, job_tracker)
|
136
160
|
_logger.info(f"Docker image built with uri {image_uri}")
|
137
161
|
# entry_cmd and additional_args are empty here because
|
138
162
|
# if launch built the container they've been accounted
|
@@ -167,20 +191,49 @@ def _run_entry_point(command: str, work_dir: Optional[str]) -> AbstractRun:
|
|
167
191
|
if work_dir is None:
|
168
192
|
work_dir = os.getcwd()
|
169
193
|
env = os.environ.copy()
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
194
|
+
run = LocalSubmittedRun()
|
195
|
+
thread = threading.Thread(
|
196
|
+
target=_thread_process_runner,
|
197
|
+
args=(run, ["bash", "-c", command], work_dir, env),
|
198
|
+
)
|
199
|
+
run.set_thread(thread)
|
200
|
+
thread.start()
|
201
|
+
return run
|
202
|
+
|
203
|
+
|
204
|
+
def _thread_process_runner(
|
205
|
+
run: LocalSubmittedRun, args: List[str], work_dir: str, env: Dict[str, str]
|
206
|
+
) -> None:
|
207
|
+
# cancel was called before we started the subprocess
|
208
|
+
if run._terminate_flag:
|
209
|
+
return
|
210
|
+
process = subprocess.Popen(
|
211
|
+
args,
|
212
|
+
close_fds=True,
|
213
|
+
stdout=subprocess.PIPE,
|
214
|
+
stderr=subprocess.STDOUT,
|
215
|
+
universal_newlines=True,
|
216
|
+
bufsize=1,
|
217
|
+
cwd=work_dir,
|
218
|
+
env=env,
|
219
|
+
)
|
220
|
+
run.set_command_proc(process)
|
221
|
+
run._stdout = ""
|
222
|
+
while True:
|
223
|
+
# the agent thread could set the terminate flag
|
224
|
+
if run._terminate_flag:
|
225
|
+
process.terminate() # type: ignore
|
226
|
+
chunk = os.read(process.stdout.fileno(), 4096) # type: ignore
|
227
|
+
if not chunk:
|
228
|
+
break
|
229
|
+
index = chunk.find(b"\r")
|
230
|
+
decoded_chunk = chunk.decode()
|
231
|
+
if index != -1:
|
232
|
+
run._stdout += decoded_chunk
|
233
|
+
print(chunk.decode(), end="")
|
234
|
+
else:
|
235
|
+
run._stdout += decoded_chunk + "\r"
|
236
|
+
print(chunk.decode(), end="\r")
|
184
237
|
|
185
238
|
|
186
239
|
def get_docker_command(
|
@@ -8,6 +8,7 @@ if False:
|
|
8
8
|
|
9
9
|
import wandb
|
10
10
|
from wandb.apis.internal import Api
|
11
|
+
from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
|
11
12
|
from wandb.sdk.launch.builder.abstract import AbstractBuilder
|
12
13
|
from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
|
13
14
|
from wandb.sdk.launch.errors import LaunchError
|
@@ -23,9 +24,15 @@ _logger = logging.getLogger(__name__)
|
|
23
24
|
class SagemakerSubmittedRun(AbstractRun):
|
24
25
|
"""Instance of ``AbstractRun`` corresponding to a subprocess launched to run an entry point command on aws sagemaker."""
|
25
26
|
|
26
|
-
def __init__(
|
27
|
+
def __init__(
|
28
|
+
self,
|
29
|
+
training_job_name: str,
|
30
|
+
client: "boto3.Client",
|
31
|
+
log_client: Optional["boto3.Client"] = None,
|
32
|
+
) -> None:
|
27
33
|
super().__init__()
|
28
34
|
self.client = client
|
35
|
+
self.log_client = log_client
|
29
36
|
self.training_job_name = training_job_name
|
30
37
|
self._status = Status("running")
|
31
38
|
|
@@ -33,6 +40,38 @@ class SagemakerSubmittedRun(AbstractRun):
|
|
33
40
|
def id(self) -> str:
|
34
41
|
return f"sagemaker-{self.training_job_name}"
|
35
42
|
|
43
|
+
def get_logs(self) -> Optional[str]:
|
44
|
+
if self.log_client is None:
|
45
|
+
return None
|
46
|
+
try:
|
47
|
+
describe_res = self.log_client.describe_log_streams(
|
48
|
+
logGroupName="/aws/sagemaker/TrainingJobs",
|
49
|
+
logStreamNamePrefix=self.training_job_name,
|
50
|
+
)
|
51
|
+
if len(describe_res["logStreams"]) == 0:
|
52
|
+
wandb.termwarn(
|
53
|
+
f"Failed to get logs for training job: {self.training_job_name}"
|
54
|
+
)
|
55
|
+
return None
|
56
|
+
log_name = describe_res["logStreams"][0]["logStreamName"]
|
57
|
+
res = self.log_client.get_log_events(
|
58
|
+
logGroupName="/aws/sagemaker/TrainingJobs",
|
59
|
+
logStreamName=log_name,
|
60
|
+
)
|
61
|
+
return "\n".join(
|
62
|
+
[f'{event["timestamp"]}:{event["message"]}' for event in res["events"]]
|
63
|
+
)
|
64
|
+
except self.log_client.exceptions.ResourceNotFoundException:
|
65
|
+
wandb.termwarn(
|
66
|
+
f"Failed to get logs for training job: {self.training_job_name}"
|
67
|
+
)
|
68
|
+
return None
|
69
|
+
except Exception as e:
|
70
|
+
wandb.termwarn(
|
71
|
+
f"Failed to handle logs for training job: {self.training_job_name} with error {str(e)}"
|
72
|
+
)
|
73
|
+
return None
|
74
|
+
|
36
75
|
def wait(self) -> bool:
|
37
76
|
while True:
|
38
77
|
status_state = self.get_status().state
|
@@ -89,6 +128,7 @@ class SageMakerRunner(AbstractRunner):
|
|
89
128
|
self,
|
90
129
|
launch_project: LaunchProject,
|
91
130
|
builder: Optional[AbstractBuilder],
|
131
|
+
job_tracker: Optional[JobAndRunStatusTracker] = None,
|
92
132
|
) -> Optional[AbstractRun]:
|
93
133
|
"""Run a project on Amazon Sagemaker.
|
94
134
|
|
@@ -128,6 +168,13 @@ class SageMakerRunner(AbstractRunner):
|
|
128
168
|
|
129
169
|
# Create a sagemaker client to launch the job.
|
130
170
|
sagemaker_client = session.client("sagemaker")
|
171
|
+
log_client = None
|
172
|
+
try:
|
173
|
+
log_client = session.client("logs")
|
174
|
+
except Exception as e:
|
175
|
+
wandb.termwarn(
|
176
|
+
f"Failed to connect to cloudwatch logs with error {str(e)}, logs will not be available"
|
177
|
+
)
|
131
178
|
|
132
179
|
# if the user provided the image they want to use, use that, but warn it won't have swappable artifacts
|
133
180
|
if (
|
@@ -146,7 +193,9 @@ class SageMakerRunner(AbstractRunner):
|
|
146
193
|
_logger.info(
|
147
194
|
f"Launching sagemaker job on user supplied image with args: {sagemaker_args}"
|
148
195
|
)
|
149
|
-
run = launch_sagemaker_job(
|
196
|
+
run = launch_sagemaker_job(
|
197
|
+
launch_project, sagemaker_args, sagemaker_client, log_client
|
198
|
+
)
|
150
199
|
if self.backend_config[PROJECT_SYNCHRONOUS]:
|
151
200
|
run.wait()
|
152
201
|
return run
|
@@ -158,11 +207,9 @@ class SageMakerRunner(AbstractRunner):
|
|
158
207
|
assert builder is not None
|
159
208
|
# build our own image
|
160
209
|
_logger.info("Building docker image...")
|
161
|
-
image = builder.build_image(
|
162
|
-
launch_project,
|
163
|
-
entry_point,
|
164
|
-
)
|
210
|
+
image = builder.build_image(launch_project, entry_point, job_tracker)
|
165
211
|
_logger.info(f"Docker image built with uri {image}")
|
212
|
+
|
166
213
|
launch_project.fill_macros(image)
|
167
214
|
_logger.info("Connecting to sagemaker client")
|
168
215
|
command_args = get_entry_point_command(
|
@@ -181,7 +228,9 @@ class SageMakerRunner(AbstractRunner):
|
|
181
228
|
launch_project, self._api, role_arn, image, default_output_path
|
182
229
|
)
|
183
230
|
_logger.info(f"Launching sagemaker job with args: {sagemaker_args}")
|
184
|
-
run = launch_sagemaker_job(
|
231
|
+
run = launch_sagemaker_job(
|
232
|
+
launch_project, sagemaker_args, sagemaker_client, log_client
|
233
|
+
)
|
185
234
|
if self.backend_config[PROJECT_SYNCHRONOUS]:
|
186
235
|
run.wait()
|
187
236
|
return run
|
@@ -296,14 +345,15 @@ def launch_sagemaker_job(
|
|
296
345
|
launch_project: LaunchProject,
|
297
346
|
sagemaker_args: Dict[str, Any],
|
298
347
|
sagemaker_client: "boto3.Client",
|
348
|
+
log_client: Optional["boto3.Client"] = None,
|
299
349
|
) -> SagemakerSubmittedRun:
|
300
350
|
training_job_name = sagemaker_args.get("TrainingJobName") or launch_project.run_id
|
301
351
|
resp = sagemaker_client.create_training_job(**sagemaker_args)
|
302
352
|
|
303
353
|
if resp.get("TrainingJobArn") is None:
|
304
|
-
raise LaunchError("
|
354
|
+
raise LaunchError("Failed to create training job when submitting to SageMaker")
|
305
355
|
|
306
|
-
run = SagemakerSubmittedRun(training_job_name, sagemaker_client)
|
356
|
+
run = SagemakerSubmittedRun(training_job_name, sagemaker_client, log_client)
|
307
357
|
wandb.termlog(
|
308
358
|
f"{LOG_PREFIX}Run job submitted with arn: {resp.get('TrainingJobArn')}"
|
309
359
|
)
|
@@ -14,6 +14,7 @@ from wandb.apis.internal import Api
|
|
14
14
|
from wandb.util import get_module
|
15
15
|
|
16
16
|
from .._project_spec import LaunchProject, get_entry_point_command
|
17
|
+
from ..agent.job_status_tracker import JobAndRunStatusTracker
|
17
18
|
from ..builder.abstract import AbstractBuilder
|
18
19
|
from ..builder.build import get_env_vars_dict
|
19
20
|
from ..environment.gcp_environment import GcpEnvironment
|
@@ -35,6 +36,10 @@ class VertexSubmittedRun(AbstractRun):
|
|
35
36
|
# numeric ID of the custom training job
|
36
37
|
return self._job.name # type: ignore
|
37
38
|
|
39
|
+
def get_logs(self) -> Optional[str]:
|
40
|
+
# TODO: implement
|
41
|
+
return None
|
42
|
+
|
38
43
|
@property
|
39
44
|
def name(self) -> str:
|
40
45
|
return self._job.display_name # type: ignore
|
@@ -89,6 +94,7 @@ class VertexRunner(AbstractRunner):
|
|
89
94
|
self,
|
90
95
|
launch_project: LaunchProject,
|
91
96
|
builder: Optional[AbstractBuilder],
|
97
|
+
job_tracker: Optional[JobAndRunStatusTracker] = None,
|
92
98
|
) -> Optional[AbstractRun]:
|
93
99
|
"""Run a Vertex job."""
|
94
100
|
aiplatform = get_module( # noqa: F811
|
@@ -134,10 +140,8 @@ class VertexRunner(AbstractRunner):
|
|
134
140
|
else:
|
135
141
|
assert entry_point is not None
|
136
142
|
assert builder is not None
|
137
|
-
image_uri = builder.build_image(
|
138
|
-
|
139
|
-
entry_point,
|
140
|
-
)
|
143
|
+
image_uri = builder.build_image(launch_project, entry_point, job_tracker)
|
144
|
+
|
141
145
|
launch_project.fill_macros(image_uri)
|
142
146
|
# TODO: how to handle this?
|
143
147
|
entry_cmd = get_entry_point_command(entry_point, launch_project.override_args)
|
@@ -130,10 +130,10 @@ class Scheduler(ABC):
|
|
130
130
|
if resp.get("state") == SchedulerState.CANCELLED.name:
|
131
131
|
self._state = SchedulerState.CANCELLED
|
132
132
|
self._sweep_config = yaml.safe_load(resp["config"])
|
133
|
-
self._num_runs_launched: int =
|
133
|
+
self._num_runs_launched: int = self._get_num_runs_launched(resp["runs"])
|
134
134
|
if self._num_runs_launched > 0:
|
135
135
|
wandb.termlog(
|
136
|
-
f"{LOG_PREFIX}Found {self._num_runs_launched} previous runs for sweep {self._sweep_id}"
|
136
|
+
f"{LOG_PREFIX}Found {self._num_runs_launched} previous valid runs for sweep {self._sweep_id}"
|
137
137
|
)
|
138
138
|
except Exception as e:
|
139
139
|
raise SchedulerError(
|
@@ -295,10 +295,12 @@ class Scheduler(ABC):
|
|
295
295
|
self.state = SchedulerState.RUNNING
|
296
296
|
try:
|
297
297
|
while True:
|
298
|
-
|
298
|
+
self._update_scheduler_run_state()
|
299
299
|
if not self.is_alive:
|
300
300
|
break
|
301
301
|
|
302
|
+
wandb.termlog(f"{LOG_PREFIX}Polling for new runs to launch")
|
303
|
+
|
302
304
|
self._update_run_states()
|
303
305
|
self._poll()
|
304
306
|
if self.state == SchedulerState.FLUSH_RUNS:
|
@@ -316,8 +318,17 @@ class Scheduler(ABC):
|
|
316
318
|
self.state = SchedulerState.FLUSH_RUNS
|
317
319
|
break
|
318
320
|
|
319
|
-
|
320
|
-
|
321
|
+
try:
|
322
|
+
run: Optional[SweepRun] = self._get_next_sweep_run(worker_id)
|
323
|
+
if not run:
|
324
|
+
break
|
325
|
+
except SchedulerError as e:
|
326
|
+
raise SchedulerError(e)
|
327
|
+
except Exception as e:
|
328
|
+
wandb.termerror(
|
329
|
+
f"{LOG_PREFIX}Failed to get next sweep run: {e}"
|
330
|
+
)
|
331
|
+
self.state = SchedulerState.FAILED
|
321
332
|
break
|
322
333
|
|
323
334
|
if self._add_to_launch_queue(run):
|
@@ -356,10 +367,29 @@ class Scheduler(ABC):
|
|
356
367
|
SchedulerState.STOPPED,
|
357
368
|
]:
|
358
369
|
self.state = SchedulerState.FAILED
|
370
|
+
self._set_sweep_state("CRASHED")
|
371
|
+
else:
|
372
|
+
self._set_sweep_state("FINISHED")
|
373
|
+
|
359
374
|
self._stop_runs()
|
360
|
-
self._set_sweep_state("FINISHED")
|
361
375
|
self._wandb_run.finish()
|
362
376
|
|
377
|
+
def _get_num_runs_launched(self, runs: List[Dict[str, Any]]) -> int:
|
378
|
+
"""Returns the number of valid runs in the sweep."""
|
379
|
+
count = 0
|
380
|
+
for run in runs:
|
381
|
+
# if bad run, shouldn't be counted against run cap
|
382
|
+
if run.get("state", "") in ["killed", "crashed"] and not run.get(
|
383
|
+
"summaryMetrics"
|
384
|
+
):
|
385
|
+
_logger.debug(
|
386
|
+
f"excluding run: {run['name']} with state: {run['state']} from run cap \n{run}"
|
387
|
+
)
|
388
|
+
continue
|
389
|
+
count += 1
|
390
|
+
|
391
|
+
return count
|
392
|
+
|
363
393
|
def _try_load_executable(self) -> bool:
|
364
394
|
"""Check existance of valid executable for a run.
|
365
395
|
|
@@ -384,12 +414,17 @@ class Scheduler(ABC):
|
|
384
414
|
def _register_agents(self) -> None:
|
385
415
|
for worker_id in range(self._num_workers):
|
386
416
|
_logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker ({worker_id})")
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
417
|
+
try:
|
418
|
+
agent_config = self._api.register_agent(
|
419
|
+
f"{socket.gethostname()}-{worker_id}", # host
|
420
|
+
sweep_id=self._sweep_id,
|
421
|
+
project_name=self._project,
|
422
|
+
entity=self._entity,
|
423
|
+
)
|
424
|
+
except Exception as e:
|
425
|
+
_logger.debug(f"failed to register agent: {e}")
|
426
|
+
self.fail_sweep(f"failed to register agent: {e}")
|
427
|
+
|
393
428
|
self._workers[worker_id] = _Worker(
|
394
429
|
agent_config=agent_config,
|
395
430
|
agent_id=agent_config["id"],
|
@@ -455,6 +490,30 @@ class Scheduler(ABC):
|
|
455
490
|
|
456
491
|
return False
|
457
492
|
|
493
|
+
def _update_scheduler_run_state(self) -> None:
|
494
|
+
"""Update the scheduler state from state of scheduler run and sweep state."""
|
495
|
+
state: RunState = self._get_run_state(self._wandb_run.id)
|
496
|
+
|
497
|
+
if state == RunState.KILLED:
|
498
|
+
self.state = SchedulerState.STOPPED
|
499
|
+
elif state in [RunState.FAILED, RunState.CRASHED]:
|
500
|
+
self.state = SchedulerState.FAILED
|
501
|
+
elif state == RunState.FINISHED:
|
502
|
+
self.state = SchedulerState.COMPLETED
|
503
|
+
|
504
|
+
try:
|
505
|
+
sweep_state = self._api.get_sweep_state(
|
506
|
+
self._sweep_id, self._entity, self._project
|
507
|
+
)
|
508
|
+
except Exception as e:
|
509
|
+
_logger.debug(f"sweep state error: {sweep_state} e: {e}")
|
510
|
+
return
|
511
|
+
|
512
|
+
if sweep_state in ["FINISHED", "CANCELLED"]:
|
513
|
+
self.state = SchedulerState.COMPLETED
|
514
|
+
elif sweep_state in ["PAUSED", "STOPPED"]:
|
515
|
+
self.state = SchedulerState.FLUSH_RUNS
|
516
|
+
|
458
517
|
def _update_run_states(self) -> None:
|
459
518
|
"""Iterate through runs.
|
460
519
|
|
@@ -530,7 +589,7 @@ class Scheduler(ABC):
|
|
530
589
|
run_state = RunState.UNKNOWN
|
531
590
|
except (AttributeError, ValueError):
|
532
591
|
wandb.termwarn(
|
533
|
-
f"Bad state ({
|
592
|
+
f"Bad state ({run_state}) for run ({run_id}). Error: {traceback.format_exc()}"
|
534
593
|
)
|
535
594
|
run_state = RunState.UNKNOWN
|
536
595
|
return run_state
|
@@ -564,6 +623,35 @@ class Scheduler(ABC):
|
|
564
623
|
base64.b64decode(bytes(_id.encode("utf-8"))).decode("utf-8").split(":")[2]
|
565
624
|
)
|
566
625
|
|
626
|
+
def _make_entry_and_launch_config(
|
627
|
+
self, run: SweepRun
|
628
|
+
) -> Tuple[Optional[List[str]], Dict[str, Dict[str, Any]]]:
|
629
|
+
args = create_sweep_command_args({"args": run.args})
|
630
|
+
entry_point, macro_args = make_launch_sweep_entrypoint(
|
631
|
+
args, self._sweep_config.get("command")
|
632
|
+
)
|
633
|
+
# handle program macro
|
634
|
+
if entry_point and "${program}" in entry_point:
|
635
|
+
if not self._sweep_config.get("program"):
|
636
|
+
raise SchedulerError(
|
637
|
+
f"{LOG_PREFIX}Program macro in command has no corresponding 'program' in sweep config."
|
638
|
+
)
|
639
|
+
pidx = entry_point.index("${program}")
|
640
|
+
entry_point[pidx] = self._sweep_config["program"]
|
641
|
+
|
642
|
+
launch_config = {"overrides": {"run_config": args["args_dict"]}}
|
643
|
+
if macro_args: # pipe in hyperparam args as params to launch
|
644
|
+
launch_config["overrides"]["args"] = macro_args
|
645
|
+
|
646
|
+
if entry_point:
|
647
|
+
unresolved = [x for x in entry_point if str(x).startswith("${")]
|
648
|
+
if unresolved:
|
649
|
+
wandb.termwarn(
|
650
|
+
f"{LOG_PREFIX}Sweep command contains unresolved macros: "
|
651
|
+
f"{unresolved}, see launch docs for supported macros."
|
652
|
+
)
|
653
|
+
return entry_point, launch_config
|
654
|
+
|
567
655
|
def _add_to_launch_queue(self, run: SweepRun) -> bool:
|
568
656
|
"""Convert a sweeprun into a launch job then push to runqueue."""
|
569
657
|
# job and image first from CLI args, then from sweep config
|
@@ -575,25 +663,12 @@ class Scheduler(ABC):
|
|
575
663
|
elif _job is not None and _image_uri is not None:
|
576
664
|
raise SchedulerError(f"{LOG_PREFIX}Sweep has both 'job' and 'image_uri'")
|
577
665
|
|
578
|
-
|
579
|
-
entry_point, macro_args = make_launch_sweep_entrypoint(
|
580
|
-
args, self._sweep_config.get("command")
|
581
|
-
)
|
582
|
-
launch_config = {"overrides": {"run_config": args["args_dict"]}}
|
583
|
-
if macro_args: # pipe in hyperparam args as params to launch
|
584
|
-
launch_config["overrides"]["args"] = macro_args
|
585
|
-
|
666
|
+
entry_point, launch_config = self._make_entry_and_launch_config(run)
|
586
667
|
if entry_point:
|
587
668
|
wandb.termwarn(
|
588
669
|
f"{LOG_PREFIX}Sweep command {entry_point} will override"
|
589
670
|
f' {"job" if _job else "image_uri"} entrypoint'
|
590
671
|
)
|
591
|
-
unresolved = [x for x in entry_point if str(x).startswith("${")]
|
592
|
-
if unresolved:
|
593
|
-
wandb.termwarn(
|
594
|
-
f"{LOG_PREFIX}Sweep command contains unresolved macros: "
|
595
|
-
f"{unresolved}, see launch docs for supported macros."
|
596
|
-
)
|
597
672
|
|
598
673
|
run_id = run.id or generate_id()
|
599
674
|
queued_run = launch_add(
|
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -291,3 +291,24 @@ def check_job_exists(public_api: PublicApi, job: Optional[str]) -> bool:
|
|
291
291
|
wandb.termerror(f"Failed to load job. {e}")
|
292
292
|
return False
|
293
293
|
return True
|
294
|
+
|
295
|
+
|
296
|
+
def get_previous_args(
|
297
|
+
run_spec: Dict[str, Any]
|
298
|
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
299
|
+
"""Parse through previous scheduler run_spec.
|
300
|
+
|
301
|
+
returns scheduler_args and settings.
|
302
|
+
"""
|
303
|
+
scheduler_args = (
|
304
|
+
run_spec.get("overrides", {}).get("run_config", {}).get("scheduler", {})
|
305
|
+
)
|
306
|
+
# also pipe through top level resource setup
|
307
|
+
if run_spec.get("resource"):
|
308
|
+
scheduler_args["resource"] = run_spec["resource"]
|
309
|
+
if run_spec.get("resource_args"):
|
310
|
+
scheduler_args["resource_args"] = run_spec["resource_args"]
|
311
|
+
|
312
|
+
settings = run_spec.get("overrides", {}).get("run_config", {}).get("settings", {})
|
313
|
+
|
314
|
+
return scheduler_args, settings
|
wandb/sdk/launch/utils.py
CHANGED
@@ -28,7 +28,8 @@ FAILED_PACKAGES_REGEX = re.compile(
|
|
28
28
|
)
|
29
29
|
|
30
30
|
if TYPE_CHECKING: # pragma: no cover
|
31
|
-
from wandb.sdk.artifacts.
|
31
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
32
|
+
from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
|
32
33
|
|
33
34
|
|
34
35
|
# TODO: this should be restricted to just Git repos and not S3 and stuff like that
|
@@ -47,7 +48,7 @@ _WANDB_LOCAL_DEV_URI_REGEX = re.compile(
|
|
47
48
|
r"^https?://localhost"
|
48
49
|
) # for testing, not sure if we wanna keep this
|
49
50
|
|
50
|
-
API_KEY_REGEX = r"WANDB_API_KEY=\w+"
|
51
|
+
API_KEY_REGEX = r"WANDB_API_KEY=\w+(-\w+)?"
|
51
52
|
|
52
53
|
MACRO_REGEX = re.compile(r"\$\{(\w+)\}")
|
53
54
|
|
@@ -493,7 +494,7 @@ def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
|
|
493
494
|
|
494
495
|
def check_and_download_code_artifacts(
|
495
496
|
entity: str, project: str, run_name: str, internal_api: Api, project_dir: str
|
496
|
-
) -> Optional["
|
497
|
+
) -> Optional["Artifact"]:
|
497
498
|
_logger.info("Checking for code artifacts")
|
498
499
|
public_api = wandb.PublicApi(
|
499
500
|
overrides={"base_url": internal_api.settings("base_url")}
|
@@ -620,12 +621,23 @@ def make_name_dns_safe(name: str) -> str:
|
|
620
621
|
return resp
|
621
622
|
|
622
623
|
|
623
|
-
def warn_failed_packages_from_build_logs(
|
624
|
+
def warn_failed_packages_from_build_logs(
|
625
|
+
log: str, image_uri: str, api: Api, job_tracker: Optional["JobAndRunStatusTracker"]
|
626
|
+
) -> None:
|
624
627
|
match = FAILED_PACKAGES_REGEX.search(log)
|
625
628
|
if match:
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
+
_msg = f"Failed to install the following packages: {match.group(1)} for image: {image_uri}. Will attempt to launch image without them."
|
630
|
+
wandb.termwarn(_msg)
|
631
|
+
if job_tracker is not None:
|
632
|
+
res = job_tracker.saver.save_contents(
|
633
|
+
_msg, "failed-packages.log", "warning"
|
634
|
+
)
|
635
|
+
api.update_run_queue_item_warning(
|
636
|
+
job_tracker.run_queue_item_id,
|
637
|
+
"Some packages were not successfully installed during the build",
|
638
|
+
"build",
|
639
|
+
res,
|
640
|
+
)
|
629
641
|
|
630
642
|
|
631
643
|
def docker_image_exists(docker_image: str, should_raise: bool = False) -> bool:
|