wandb 0.15.9__py3-none-any.whl → 0.15.11__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +5 -1
- wandb/apis/public.py +137 -17
- wandb/apis/reports/_panels.py +1 -1
- wandb/apis/reports/blocks.py +1 -0
- wandb/apis/reports/report.py +27 -5
- wandb/cli/cli.py +52 -41
- wandb/docker/__init__.py +17 -0
- wandb/docker/auth.py +1 -1
- wandb/env.py +24 -4
- wandb/filesync/step_checksum.py +3 -3
- wandb/integration/openai/openai.py +3 -0
- wandb/integration/ultralytics/__init__.py +9 -0
- wandb/integration/ultralytics/bbox_utils.py +196 -0
- wandb/integration/ultralytics/callback.py +458 -0
- wandb/integration/ultralytics/classification_utils.py +66 -0
- wandb/integration/ultralytics/mask_utils.py +141 -0
- wandb/integration/ultralytics/pose_utils.py +92 -0
- wandb/integration/xgboost/xgboost.py +3 -3
- wandb/integration/yolov8/__init__.py +0 -7
- wandb/integration/yolov8/yolov8.py +22 -3
- wandb/old/settings.py +7 -0
- wandb/plot/line_series.py +0 -1
- wandb/proto/v3/wandb_internal_pb2.py +353 -300
- wandb/proto/v3/wandb_server_pb2.py +37 -41
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +16 -16
- wandb/proto/v4/wandb_internal_pb2.py +272 -260
- wandb/proto/v4/wandb_server_pb2.py +37 -40
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +16 -16
- wandb/proto/wandb_internal_codegen.py +7 -31
- wandb/sdk/artifacts/artifact.py +321 -189
- wandb/sdk/artifacts/artifact_cache.py +14 -0
- wandb/sdk/artifacts/artifact_manifest.py +5 -4
- wandb/sdk/artifacts/artifact_manifest_entry.py +37 -9
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -9
- wandb/sdk/artifacts/artifact_saver.py +13 -50
- wandb/sdk/artifacts/artifact_ttl.py +6 -0
- wandb/sdk/artifacts/artifacts_cache.py +119 -93
- wandb/sdk/artifacts/staging.py +25 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +12 -7
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -3
- wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
- wandb/sdk/artifacts/storage_policies/register.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +4 -3
- wandb/sdk/artifacts/storage_policy.py +4 -2
- wandb/sdk/backend/backend.py +0 -16
- wandb/sdk/data_types/image.py +3 -1
- wandb/sdk/integration_utils/auto_logging.py +38 -13
- wandb/sdk/interface/interface.py +16 -135
- wandb/sdk/interface/interface_shared.py +9 -147
- wandb/sdk/interface/interface_sock.py +0 -26
- wandb/sdk/internal/file_pusher.py +20 -3
- wandb/sdk/internal/file_stream.py +3 -1
- wandb/sdk/internal/handler.py +53 -70
- wandb/sdk/internal/internal_api.py +220 -130
- wandb/sdk/internal/job_builder.py +41 -37
- wandb/sdk/internal/sender.py +7 -25
- wandb/sdk/internal/system/assets/disk.py +144 -11
- wandb/sdk/internal/system/system_info.py +6 -2
- wandb/sdk/launch/__init__.py +5 -0
- wandb/sdk/launch/{launch.py → _launch.py} +53 -54
- wandb/sdk/launch/{launch_add.py → _launch_add.py} +34 -31
- wandb/sdk/launch/_project_spec.py +13 -2
- wandb/sdk/launch/agent/agent.py +103 -59
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +6 -4
- wandb/sdk/launch/builder/build.py +19 -1
- wandb/sdk/launch/builder/docker_builder.py +5 -1
- wandb/sdk/launch/builder/kaniko_builder.py +5 -1
- wandb/sdk/launch/create_job.py +20 -5
- wandb/sdk/launch/loader.py +14 -5
- wandb/sdk/launch/runner/abstract.py +0 -2
- wandb/sdk/launch/runner/kubernetes_monitor.py +329 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +66 -209
- wandb/sdk/launch/runner/local_container.py +5 -2
- wandb/sdk/launch/runner/local_process.py +4 -1
- wandb/sdk/launch/sweeps/scheduler.py +43 -25
- wandb/sdk/launch/sweeps/utils.py +5 -3
- wandb/sdk/launch/utils.py +3 -1
- wandb/sdk/lib/_settings_toposort_generate.py +3 -9
- wandb/sdk/lib/_settings_toposort_generated.py +27 -3
- wandb/sdk/lib/_wburls_generated.py +1 -0
- wandb/sdk/lib/filenames.py +27 -6
- wandb/sdk/lib/filesystem.py +181 -7
- wandb/sdk/lib/fsm.py +5 -3
- wandb/sdk/lib/gql_request.py +3 -0
- wandb/sdk/lib/ipython.py +7 -0
- wandb/sdk/lib/wburls.py +1 -0
- wandb/sdk/service/port_file.py +2 -15
- wandb/sdk/service/server.py +7 -55
- wandb/sdk/service/service.py +56 -26
- wandb/sdk/service/service_base.py +1 -1
- wandb/sdk/service/streams.py +11 -5
- wandb/sdk/verify/verify.py +2 -2
- wandb/sdk/wandb_init.py +8 -2
- wandb/sdk/wandb_manager.py +4 -14
- wandb/sdk/wandb_run.py +143 -53
- wandb/sdk/wandb_settings.py +148 -35
- wandb/testing/relay.py +85 -38
- wandb/util.py +87 -4
- wandb/wandb_torch.py +24 -38
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/METADATA +48 -23
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/RECORD +107 -103
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/WHEEL +1 -1
- wandb/proto/v3/wandb_server_pb2_grpc.py +0 -1422
- wandb/proto/v4/wandb_server_pb2_grpc.py +0 -1422
- wandb/proto/wandb_server_pb2_grpc.py +0 -8
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +0 -61
- wandb/sdk/interface/interface_grpc.py +0 -460
- wandb/sdk/service/server_grpc.py +0 -444
- wandb/sdk/service/service_grpc.py +0 -73
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/LICENSE +0 -0
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/entry_points.txt +0 -0
- {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/top_level.txt +0 -0
@@ -6,13 +6,15 @@ import logging
|
|
6
6
|
import time
|
7
7
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
8
8
|
|
9
|
+
import yaml
|
10
|
+
|
9
11
|
import wandb
|
10
12
|
from wandb.apis.internal import Api
|
11
13
|
from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
12
14
|
from wandb.sdk.launch.registry.abstract import AbstractRegistry
|
13
15
|
from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
|
14
16
|
from wandb.sdk.launch.registry.local_registry import LocalRegistry
|
15
|
-
from wandb.sdk.launch.runner.abstract import
|
17
|
+
from wandb.sdk.launch.runner.abstract import Status
|
16
18
|
from wandb.util import get_module
|
17
19
|
|
18
20
|
from .._project_spec import EntryPoint, LaunchProject
|
@@ -26,6 +28,7 @@ from ..utils import (
|
|
26
28
|
make_name_dns_safe,
|
27
29
|
)
|
28
30
|
from .abstract import AbstractRun, AbstractRunner
|
31
|
+
from .kubernetes_monitor import KubernetesRunMonitor
|
29
32
|
|
30
33
|
get_module(
|
31
34
|
"kubernetes",
|
@@ -43,32 +46,16 @@ from kubernetes.client.models.v1_secret import V1Secret # type: ignore # noqa:
|
|
43
46
|
from kubernetes.client.rest import ApiException # type: ignore # noqa: E402
|
44
47
|
|
45
48
|
TIMEOUT = 5
|
46
|
-
MAX_KUBERNETES_RETRIES = (
|
47
|
-
60 # default 10 second loop time on the agent, this is 10 minutes
|
48
|
-
)
|
49
|
-
FAIL_MESSAGE_INTERVAL = 60
|
50
49
|
|
51
50
|
_logger = logging.getLogger(__name__)
|
52
51
|
|
53
52
|
|
54
|
-
# Dict for mapping possible states of custom objects to the states we want to report
|
55
|
-
# to the agent.
|
56
|
-
CRD_STATE_DICT: Dict[str, State] = {
|
57
|
-
"pending": "starting",
|
58
|
-
"running": "running",
|
59
|
-
"completed": "finished",
|
60
|
-
"failed": "failed",
|
61
|
-
"aborted": "failed",
|
62
|
-
"terminating": "stopping",
|
63
|
-
"terminated": "stopped",
|
64
|
-
}
|
65
|
-
|
66
|
-
|
67
53
|
class KubernetesSubmittedRun(AbstractRun):
|
68
54
|
"""Wrapper for a launched run on Kubernetes."""
|
69
55
|
|
70
56
|
def __init__(
|
71
57
|
self,
|
58
|
+
monitor: KubernetesRunMonitor,
|
72
59
|
batch_api: "BatchV1Api",
|
73
60
|
core_api: "CoreV1Api",
|
74
61
|
name: str,
|
@@ -78,6 +65,14 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
78
65
|
) -> None:
|
79
66
|
"""Initialize a KubernetesSubmittedRun.
|
80
67
|
|
68
|
+
Other implementations of the AbstractRun interface poll on the run
|
69
|
+
when `get_status` is called, but KubernetesSubmittedRun uses
|
70
|
+
Kubernetes watch streams to update the run status. One thread handles
|
71
|
+
events from the job object and another thread handles events from the
|
72
|
+
rank 0 pod. These threads updated the `_status` attributed of the
|
73
|
+
KubernetesSubmittedRun object. When `get_status` is called, the
|
74
|
+
`_status` attribute is returned.
|
75
|
+
|
81
76
|
Arguments:
|
82
77
|
batch_api: Kubernetes BatchV1Api object.
|
83
78
|
core_api: Kubernetes CoreV1Api object.
|
@@ -89,13 +84,11 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
89
84
|
Returns:
|
90
85
|
None.
|
91
86
|
"""
|
87
|
+
self.monitor = monitor
|
92
88
|
self.batch_api = batch_api
|
93
89
|
self.core_api = core_api
|
94
90
|
self.name = name
|
95
91
|
self.namespace = namespace
|
96
|
-
self.job = self.batch_api.read_namespaced_job(
|
97
|
-
name=self.name, namespace=self.namespace
|
98
|
-
)
|
99
92
|
self._fail_count = 0
|
100
93
|
self.pod_names = pod_names
|
101
94
|
self.secret = secret
|
@@ -136,7 +129,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
136
129
|
while True:
|
137
130
|
status = self.get_status()
|
138
131
|
wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
|
139
|
-
if status.state
|
132
|
+
if status.state in ["finished", "failed", "preempted"]:
|
140
133
|
break
|
141
134
|
time.sleep(5)
|
142
135
|
return (
|
@@ -156,98 +149,20 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
156
149
|
)
|
157
150
|
|
158
151
|
def get_status(self) -> Status:
|
159
|
-
|
160
|
-
try:
|
161
|
-
job_response = self.batch_api.read_namespaced_job_status(
|
162
|
-
name=self.name, namespace=self.namespace
|
163
|
-
)
|
164
|
-
except ApiException as e:
|
165
|
-
if e.status == 404:
|
166
|
-
wandb.termerror(
|
167
|
-
f"Could not reach job {self.name} in namespace {self.namespace}"
|
168
|
-
)
|
169
|
-
self._delete_secret_if_completed("failed")
|
170
|
-
return Status("failed")
|
171
|
-
|
172
|
-
status = job_response.status
|
152
|
+
return self.monitor.get_status()
|
173
153
|
|
154
|
+
def cancel(self) -> None:
|
155
|
+
"""Cancel the run."""
|
156
|
+
self.monitor.stop()
|
174
157
|
try:
|
175
|
-
|
176
|
-
|
158
|
+
self.batch_api.delete_namespaced_job(
|
159
|
+
namespace=self.namespace,
|
160
|
+
name=self.name,
|
177
161
|
)
|
178
162
|
except ApiException as e:
|
179
|
-
if e.status == 404:
|
180
|
-
wandb.termerror(
|
181
|
-
f"Could not reach pod {self.pod_names[0]} in namespace {self.namespace}"
|
182
|
-
)
|
183
|
-
self._delete_secret_if_completed("failed")
|
184
|
-
return Status("failed")
|
185
|
-
|
186
|
-
if hasattr(pod.status, "conditions") and pod.status.conditions is not None:
|
187
|
-
for condition in pod.status.conditions:
|
188
|
-
if condition.type == "DisruptionTarget" and condition.reason in [
|
189
|
-
"EvictionByEvictionAPI",
|
190
|
-
"PreemptionByScheduler",
|
191
|
-
"TerminationByKubelet",
|
192
|
-
]:
|
193
|
-
return Status("preempted")
|
194
|
-
if pod.status.phase in ["Pending", "Unknown"]:
|
195
|
-
now = time.time()
|
196
|
-
if self._fail_count == 0:
|
197
|
-
self._fail_first_msg_time = now
|
198
|
-
self._fail_last_msg_time = 0.0
|
199
|
-
self._fail_count += 1
|
200
|
-
if now - self._fail_last_msg_time > FAIL_MESSAGE_INTERVAL:
|
201
|
-
wandb.termlog(
|
202
|
-
f"{LOG_PREFIX}Pod has not started yet for job: {self.name}. Will wait up to {round(10 - (now - self._fail_first_msg_time)/60)} minutes."
|
203
|
-
)
|
204
|
-
self._fail_last_msg_time = now
|
205
|
-
if self._fail_count > MAX_KUBERNETES_RETRIES:
|
206
|
-
raise LaunchError(f"Failed to start job {self.name}")
|
207
|
-
# todo: we only handle the 1 pod case. see https://kubernetes.io/docs/concepts/workloads/controllers/job/#parallel-jobs for multipod handling
|
208
|
-
return_status = None
|
209
|
-
if status.succeeded == 1:
|
210
|
-
return_status = Status("finished")
|
211
|
-
elif status.failed is not None and status.failed >= 1:
|
212
|
-
return_status = Status("failed")
|
213
|
-
elif status.active == 1:
|
214
|
-
return Status("running")
|
215
|
-
elif status.conditions is not None and status.conditions[0].type == "Suspended":
|
216
|
-
return_status = Status("stopped")
|
217
|
-
else:
|
218
|
-
return_status = Status("unknown")
|
219
|
-
|
220
|
-
self._delete_secret_if_completed(return_status.state)
|
221
|
-
return return_status
|
222
|
-
|
223
|
-
def suspend(self) -> None:
|
224
|
-
"""Suspend the run."""
|
225
|
-
self.job.spec.suspend = True
|
226
|
-
self.batch_api.patch_namespaced_job(
|
227
|
-
name=self.name, namespace=self.namespace, body=self.job
|
228
|
-
)
|
229
|
-
timeout = TIMEOUT
|
230
|
-
job_response = self.batch_api.read_namespaced_job_status(
|
231
|
-
name=self.name, namespace=self.namespace
|
232
|
-
)
|
233
|
-
while job_response.status.conditions is None and timeout > 0:
|
234
|
-
time.sleep(1)
|
235
|
-
timeout -= 1
|
236
|
-
job_response = self.batch_api.read_namespaced_job_status(
|
237
|
-
name=self.name, namespace=self.namespace
|
238
|
-
)
|
239
|
-
|
240
|
-
if timeout == 0 or job_response.status.conditions[0].type != "Suspended":
|
241
163
|
raise LaunchError(
|
242
|
-
"Failed to
|
243
|
-
|
244
|
-
)
|
245
|
-
)
|
246
|
-
|
247
|
-
def cancel(self) -> None:
|
248
|
-
"""Cancel the run."""
|
249
|
-
self.suspend()
|
250
|
-
self.batch_api.delete_namespaced_job(name=self.name, namespace=self.namespace)
|
164
|
+
f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}"
|
165
|
+
) from e
|
251
166
|
|
252
167
|
|
253
168
|
class CrdSubmittedRun(AbstractRun):
|
@@ -262,7 +177,7 @@ class CrdSubmittedRun(AbstractRun):
|
|
262
177
|
namespace: str,
|
263
178
|
core_api: CoreV1Api,
|
264
179
|
custom_api: CustomObjectsApi,
|
265
|
-
|
180
|
+
monitor: KubernetesRunMonitor,
|
266
181
|
) -> None:
|
267
182
|
"""Create a run object for tracking the progress of a CRD.
|
268
183
|
|
@@ -274,7 +189,7 @@ class CrdSubmittedRun(AbstractRun):
|
|
274
189
|
namespace: The namespace of the CRD instance.
|
275
190
|
core_api: The Kubernetes core API client.
|
276
191
|
custom_api: The Kubernetes custom object API client.
|
277
|
-
|
192
|
+
monitor: The run monitor.
|
278
193
|
|
279
194
|
Raises:
|
280
195
|
LaunchError: If the CRD instance does not exist.
|
@@ -286,20 +201,8 @@ class CrdSubmittedRun(AbstractRun):
|
|
286
201
|
self.namespace = namespace
|
287
202
|
self.core_api = core_api
|
288
203
|
self.custom_api = custom_api
|
289
|
-
self.pod_names = pod_names
|
290
204
|
self._fail_count = 0
|
291
|
-
|
292
|
-
self.job = self.custom_api.get_namespaced_custom_object(
|
293
|
-
group=self.group,
|
294
|
-
version=self.version,
|
295
|
-
namespace=self.namespace,
|
296
|
-
plural=self.plural,
|
297
|
-
name=self.name,
|
298
|
-
)
|
299
|
-
except ApiException as e:
|
300
|
-
raise LaunchError(
|
301
|
-
f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
|
302
|
-
) from e
|
205
|
+
self.monitor = monitor
|
303
206
|
|
304
207
|
@property
|
305
208
|
def id(self) -> str:
|
@@ -311,7 +214,11 @@ class CrdSubmittedRun(AbstractRun):
|
|
311
214
|
# TODO: test more carefully once we release multi-node support
|
312
215
|
logs: Dict[str, Optional[str]] = {}
|
313
216
|
try:
|
314
|
-
|
217
|
+
pods = self.core_api.list_namespaced_pod(
|
218
|
+
label_selector=f"wandb/run-id={self.name}", namespace=self.namespace
|
219
|
+
)
|
220
|
+
pod_names = [pi.metadata.name for pi in pods.items]
|
221
|
+
for pod_name in pod_names:
|
315
222
|
logs[pod_name] = self.core_api.read_namespaced_pod_log(
|
316
223
|
name=pod_name, namespace=self.namespace
|
317
224
|
)
|
@@ -325,30 +232,7 @@ class CrdSubmittedRun(AbstractRun):
|
|
325
232
|
|
326
233
|
def get_status(self) -> Status:
|
327
234
|
"""Get status of custom object."""
|
328
|
-
|
329
|
-
job_response = self.custom_api.get_namespaced_custom_object_status(
|
330
|
-
group=self.group,
|
331
|
-
version=self.version,
|
332
|
-
namespace=self.namespace,
|
333
|
-
plural=self.plural,
|
334
|
-
name=self.name,
|
335
|
-
)
|
336
|
-
except ApiException as e:
|
337
|
-
raise LaunchError(
|
338
|
-
f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
|
339
|
-
) from e
|
340
|
-
# Custom objects can technically define whater states and format the
|
341
|
-
# response to the status request however they want. This checks for
|
342
|
-
# the most common cases.
|
343
|
-
status = job_response["status"]
|
344
|
-
state = status.get("state")
|
345
|
-
if isinstance(state, dict):
|
346
|
-
state = state.get("phase")
|
347
|
-
if state is None:
|
348
|
-
raise LaunchError(
|
349
|
-
f"Failed to get CRD {self.name} in namespace {self.namespace}: no state found"
|
350
|
-
)
|
351
|
-
return Status(CRD_STATE_DICT.get(state.lower(), "unknown"))
|
235
|
+
return self.monitor.get_status()
|
352
236
|
|
353
237
|
def cancel(self) -> None:
|
354
238
|
"""Cancel the custom object."""
|
@@ -370,10 +254,9 @@ class CrdSubmittedRun(AbstractRun):
|
|
370
254
|
while True:
|
371
255
|
status = self.get_status()
|
372
256
|
wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
|
373
|
-
if status.state != "running":
|
374
|
-
break
|
375
257
|
time.sleep(5)
|
376
|
-
|
258
|
+
if status.state in ["finished", "failed", "preempted"]:
|
259
|
+
return status.state == "finished"
|
377
260
|
|
378
261
|
|
379
262
|
class KubernetesRunner(AbstractRunner):
|
@@ -400,48 +283,6 @@ class KubernetesRunner(AbstractRunner):
|
|
400
283
|
self.environment = environment
|
401
284
|
self.registry = registry
|
402
285
|
|
403
|
-
def wait_job_launch(
|
404
|
-
self,
|
405
|
-
job_name: str,
|
406
|
-
namespace: str,
|
407
|
-
core_api: "CoreV1Api",
|
408
|
-
label: str = "job-name",
|
409
|
-
) -> List[str]:
|
410
|
-
"""Wait for a job to be launched and return the pod names.
|
411
|
-
|
412
|
-
Arguments:
|
413
|
-
job_name: The name of the job.
|
414
|
-
namespace: The namespace of the job.
|
415
|
-
core_api: The Kubernetes core API client.
|
416
|
-
label: The label key to match against job_name.
|
417
|
-
|
418
|
-
Returns:
|
419
|
-
The names of the pods associated with the job.
|
420
|
-
"""
|
421
|
-
pods = core_api.list_namespaced_pod(
|
422
|
-
label_selector=f"{label}={job_name}", namespace=namespace
|
423
|
-
)
|
424
|
-
timeout = TIMEOUT
|
425
|
-
while len(pods.items) == 0 and timeout > 0:
|
426
|
-
time.sleep(1)
|
427
|
-
timeout -= 1
|
428
|
-
pods = core_api.list_namespaced_pod(
|
429
|
-
label_selector=f"{label}={job_name}", namespace=namespace
|
430
|
-
)
|
431
|
-
|
432
|
-
if timeout == 0:
|
433
|
-
raise LaunchError(
|
434
|
-
"No pods found for job {}. Check dashboard to see if job was launched successfully.".format(
|
435
|
-
job_name
|
436
|
-
)
|
437
|
-
)
|
438
|
-
|
439
|
-
pod_names = [pi.metadata.name for pi in pods.items]
|
440
|
-
wandb.termlog(
|
441
|
-
f"{LOG_PREFIX}Job {job_name} created on pod(s) {', '.join(pod_names)}. See logs with e.g. `kubectl logs {pod_names[0]} -n {namespace}`."
|
442
|
-
)
|
443
|
-
return pod_names
|
444
|
-
|
445
286
|
def get_namespace(
|
446
287
|
self, resource_args: Dict[str, Any], context: Dict[str, Any]
|
447
288
|
) -> str:
|
@@ -522,18 +363,10 @@ class KubernetesRunner(AbstractRunner):
|
|
522
363
|
or launch_project.get_single_entry_point()
|
523
364
|
)
|
524
365
|
if launch_project.docker_image:
|
525
|
-
if len(containers) > 1:
|
526
|
-
raise LaunchError(
|
527
|
-
"Invalid specification of multiple containers. See https://docs.wandb.ai/guides/launch for guidance on submitting jobs."
|
528
|
-
)
|
529
366
|
# dont specify run id if user provided image, could have multiple runs
|
530
367
|
containers[0]["image"] = image_uri
|
531
368
|
# TODO: handle secret pulling image from registry
|
532
369
|
elif not any(["image" in cont for cont in containers]):
|
533
|
-
if len(containers) > 1:
|
534
|
-
raise LaunchError(
|
535
|
-
"Launch only builds one container at a time. See https://docs.wandb.ai/guides/launch for guidance on submitting jobs."
|
536
|
-
)
|
537
370
|
assert entry_point is not None
|
538
371
|
# in the non instance case we need to make an imagePullSecret
|
539
372
|
# so the new job can pull the image
|
@@ -638,16 +471,27 @@ class KubernetesRunner(AbstractRunner):
|
|
638
471
|
body=resource_args,
|
639
472
|
)
|
640
473
|
except ApiException as e:
|
474
|
+
body = json.loads(e.body)
|
475
|
+
body_yaml = yaml.dump(body)
|
641
476
|
raise LaunchError(
|
642
|
-
f"Error creating CRD of kind {kind}: {e.status} {e.reason}"
|
477
|
+
f"Error creating CRD of kind {kind}: {e.status} {e.reason}\n{body_yaml}"
|
643
478
|
) from e
|
644
479
|
name = response.get("metadata", {}).get("name")
|
645
480
|
_logger.info(f"Created {kind} {response['metadata']['name']}")
|
646
481
|
core = client.CoreV1Api(api_client)
|
647
|
-
|
648
|
-
|
482
|
+
run_monitor = KubernetesRunMonitor(
|
483
|
+
job_field_selector=f"metadata.name={name}",
|
484
|
+
pod_label_selector=f"wandb/run-id={launch_project.run_id}",
|
485
|
+
namespace=namespace,
|
486
|
+
batch_api=None,
|
487
|
+
core_api=core,
|
488
|
+
custom_api=api,
|
489
|
+
group=group,
|
490
|
+
version=version,
|
491
|
+
plural=plural,
|
649
492
|
)
|
650
|
-
|
493
|
+
run_monitor.start()
|
494
|
+
submitted_run = CrdSubmittedRun(
|
651
495
|
name=name,
|
652
496
|
group=group,
|
653
497
|
version=version,
|
@@ -655,8 +499,11 @@ class KubernetesRunner(AbstractRunner):
|
|
655
499
|
plural=plural,
|
656
500
|
core_api=client.CoreV1Api(api_client),
|
657
501
|
custom_api=api,
|
658
|
-
|
502
|
+
monitor=run_monitor,
|
659
503
|
)
|
504
|
+
if self.backend_config[PROJECT_SYNCHRONOUS]:
|
505
|
+
submitted_run.wait()
|
506
|
+
return submitted_run
|
660
507
|
|
661
508
|
batch_api = kubernetes.client.BatchV1Api(api_client)
|
662
509
|
core_api = kubernetes.client.CoreV1Api(api_client)
|
@@ -674,12 +521,22 @@ class KubernetesRunner(AbstractRunner):
|
|
674
521
|
0
|
675
522
|
] # create_from_yaml returns a nested list of k8s objects
|
676
523
|
job_name = job_response.metadata.name
|
677
|
-
|
524
|
+
|
525
|
+
# Event stream monitor to ensure pod creation and job completion.
|
526
|
+
monitor = KubernetesRunMonitor(
|
527
|
+
job_field_selector=f"metadata.name={job_name}",
|
528
|
+
pod_label_selector=f"job-name={job_name}",
|
529
|
+
namespace=namespace,
|
530
|
+
batch_api=batch_api,
|
531
|
+
core_api=core_api,
|
532
|
+
)
|
533
|
+
monitor.start()
|
678
534
|
submitted_job = KubernetesSubmittedRun(
|
679
|
-
batch_api, core_api, job_name,
|
535
|
+
monitor, batch_api, core_api, job_name, [], namespace, secret
|
680
536
|
)
|
681
537
|
if self.backend_config[PROJECT_SYNCHRONOUS]:
|
682
538
|
submitted_job.wait()
|
539
|
+
|
683
540
|
return submitted_job
|
684
541
|
|
685
542
|
|
@@ -5,7 +5,7 @@ import subprocess
|
|
5
5
|
import sys
|
6
6
|
import threading
|
7
7
|
import time
|
8
|
-
from typing import Any, Dict, List, Optional
|
8
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
9
9
|
|
10
10
|
import wandb
|
11
11
|
from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
@@ -26,6 +26,9 @@ from ..utils import (
|
|
26
26
|
)
|
27
27
|
from .abstract import AbstractRun, AbstractRunner, Status
|
28
28
|
|
29
|
+
if TYPE_CHECKING:
|
30
|
+
from wandb.apis.internal import Api
|
31
|
+
|
29
32
|
_logger = logging.getLogger(__name__)
|
30
33
|
|
31
34
|
|
@@ -95,7 +98,7 @@ class LocalContainerRunner(AbstractRunner):
|
|
95
98
|
|
96
99
|
def __init__(
|
97
100
|
self,
|
98
|
-
api:
|
101
|
+
api: "Api",
|
99
102
|
backend_config: Dict[str, Any],
|
100
103
|
environment: AbstractEnvironment,
|
101
104
|
registry: AbstractRegistry,
|
@@ -46,7 +46,10 @@ class LocalProcessRunner(AbstractRunner):
|
|
46
46
|
_logger.warning(_msg)
|
47
47
|
|
48
48
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
49
|
-
entry_point =
|
49
|
+
entry_point = (
|
50
|
+
launch_project.override_entrypoint
|
51
|
+
or launch_project.get_single_entry_point()
|
52
|
+
)
|
50
53
|
|
51
54
|
cmd: List[Any] = []
|
52
55
|
|
@@ -9,26 +9,28 @@ import traceback
|
|
9
9
|
from abc import ABC, abstractmethod
|
10
10
|
from dataclasses import dataclass
|
11
11
|
from enum import Enum
|
12
|
-
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
12
|
+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
|
13
13
|
|
14
14
|
import click
|
15
15
|
import yaml
|
16
16
|
|
17
17
|
import wandb
|
18
|
-
import wandb.apis.public as public
|
19
|
-
from wandb.apis.internal import Api
|
20
|
-
from wandb.apis.public import Api as PublicApi
|
21
|
-
from wandb.apis.public import QueuedRun, Run
|
22
18
|
from wandb.errors import CommError
|
19
|
+
from wandb.sdk.launch._launch_add import launch_add
|
23
20
|
from wandb.sdk.launch.errors import LaunchError
|
24
|
-
from wandb.sdk.launch.launch_add import launch_add
|
25
21
|
from wandb.sdk.launch.sweeps import SchedulerError
|
26
22
|
from wandb.sdk.launch.sweeps.utils import (
|
27
23
|
create_sweep_command_args,
|
28
24
|
make_launch_sweep_entrypoint,
|
29
25
|
)
|
30
26
|
from wandb.sdk.lib.runid import generate_id
|
31
|
-
|
27
|
+
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
import wandb.apis.public as public
|
30
|
+
from wandb.apis.internal import Api
|
31
|
+
from wandb.apis.public import QueuedRun, Run
|
32
|
+
from wandb.sdk.wandb_run import Run as SdkRun
|
33
|
+
|
32
34
|
|
33
35
|
_logger = logging.getLogger(__name__)
|
34
36
|
LOG_PREFIX = f"{click.style('sched:', fg='cyan')} "
|
@@ -84,7 +86,7 @@ class SweepRun:
|
|
84
86
|
id: str
|
85
87
|
worker_id: int
|
86
88
|
state: RunState = RunState.RUNNING
|
87
|
-
queued_run: Optional[public.QueuedRun] = None
|
89
|
+
queued_run: Optional["public.QueuedRun"] = None
|
88
90
|
args: Optional[Dict[str, Any]] = None
|
89
91
|
logs: Optional[List[str]] = None
|
90
92
|
|
@@ -98,7 +100,7 @@ class Scheduler(ABC):
|
|
98
100
|
|
99
101
|
def __init__(
|
100
102
|
self,
|
101
|
-
api: Api,
|
103
|
+
api: "Api",
|
102
104
|
*args: Optional[Any],
|
103
105
|
polling_sleep: Optional[float] = None,
|
104
106
|
sweep_id: Optional[str] = None,
|
@@ -108,6 +110,8 @@ class Scheduler(ABC):
|
|
108
110
|
num_workers: Optional[Union[int, str]] = None,
|
109
111
|
**kwargs: Optional[Any],
|
110
112
|
):
|
113
|
+
from wandb.apis.public import Api as PublicApi
|
114
|
+
|
111
115
|
self._api = api
|
112
116
|
self._public_api = PublicApi()
|
113
117
|
self._entity = (
|
@@ -244,7 +248,7 @@ class Scheduler(ABC):
|
|
244
248
|
_id: w for _id, w in self._workers.items() if _id not in self.busy_workers
|
245
249
|
}
|
246
250
|
|
247
|
-
def _init_wandb_run(self) -> SdkRun:
|
251
|
+
def _init_wandb_run(self) -> "SdkRun":
|
248
252
|
"""Controls resume or init logic for a scheduler wandb run."""
|
249
253
|
_type = self._kwargs.get("sweep_type", "sweep")
|
250
254
|
run: SdkRun = wandb.init(
|
@@ -346,9 +350,8 @@ class Scheduler(ABC):
|
|
346
350
|
self.exit()
|
347
351
|
raise e
|
348
352
|
else:
|
349
|
-
|
350
|
-
|
351
|
-
if self.state in [SchedulerState.RUNNING, SchedulerState.FLUSH_RUNS]:
|
353
|
+
# scheduler succeeds if at runcap
|
354
|
+
if self.state == SchedulerState.FLUSH_RUNS and self.at_runcap:
|
352
355
|
self.state = SchedulerState.COMPLETED
|
353
356
|
self.exit()
|
354
357
|
|
@@ -362,16 +365,24 @@ class Scheduler(ABC):
|
|
362
365
|
f"{LOG_PREFIX}Failed to save state: {traceback.format_exc()}"
|
363
366
|
)
|
364
367
|
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
368
|
+
status = ""
|
369
|
+
if self.state == SchedulerState.FLUSH_RUNS:
|
370
|
+
self._set_sweep_state("PAUSED")
|
371
|
+
status = "paused"
|
372
|
+
elif self.state == SchedulerState.COMPLETED:
|
373
|
+
self._set_sweep_state("FINISHED")
|
374
|
+
status = "completed"
|
375
|
+
elif self.state in [SchedulerState.CANCELLED, SchedulerState.STOPPED]:
|
376
|
+
self._set_sweep_state("CANCELED") # one L
|
377
|
+
status = "cancelled"
|
378
|
+
self._stop_runs()
|
379
|
+
else:
|
369
380
|
self.state = SchedulerState.FAILED
|
370
381
|
self._set_sweep_state("CRASHED")
|
371
|
-
|
372
|
-
self.
|
382
|
+
status = "crashed"
|
383
|
+
self._stop_runs()
|
373
384
|
|
374
|
-
|
385
|
+
wandb.termlog(f"{LOG_PREFIX}Scheduler {status}")
|
375
386
|
self._wandb_run.finish()
|
376
387
|
|
377
388
|
def _get_num_runs_launched(self, runs: List[Dict[str, Any]]) -> int:
|
@@ -494,6 +505,7 @@ class Scheduler(ABC):
|
|
494
505
|
"""Update the scheduler state from state of scheduler run and sweep state."""
|
495
506
|
state: RunState = self._get_run_state(self._wandb_run.id)
|
496
507
|
|
508
|
+
# map scheduler run-state to scheduler-state
|
497
509
|
if state == RunState.KILLED:
|
498
510
|
self.state = SchedulerState.STOPPED
|
499
511
|
elif state in [RunState.FAILED, RunState.CRASHED]:
|
@@ -501,17 +513,20 @@ class Scheduler(ABC):
|
|
501
513
|
elif state == RunState.FINISHED:
|
502
514
|
self.state = SchedulerState.COMPLETED
|
503
515
|
|
516
|
+
# check sweep state for completed states, overwrite scheduler state
|
504
517
|
try:
|
505
518
|
sweep_state = self._api.get_sweep_state(
|
506
519
|
self._sweep_id, self._entity, self._project
|
507
520
|
)
|
508
521
|
except Exception as e:
|
509
|
-
_logger.debug(f"sweep state error: {
|
522
|
+
_logger.debug(f"sweep state error: {e}")
|
510
523
|
return
|
511
524
|
|
512
|
-
if sweep_state
|
525
|
+
if sweep_state == "FINISHED":
|
513
526
|
self.state = SchedulerState.COMPLETED
|
514
|
-
elif sweep_state in ["
|
527
|
+
elif sweep_state in ["CANCELLED", "STOPPED"]:
|
528
|
+
self.state = SchedulerState.CANCELLED
|
529
|
+
elif sweep_state == "PAUSED":
|
515
530
|
self.state = SchedulerState.FLUSH_RUNS
|
516
531
|
|
517
532
|
def _update_run_states(self) -> None:
|
@@ -674,6 +689,9 @@ class Scheduler(ABC):
|
|
674
689
|
f' {"job" if _job else "image_uri"} entrypoint'
|
675
690
|
)
|
676
691
|
|
692
|
+
# override resource and args of job
|
693
|
+
_job_launch_config = self._wandb_run.config.get("launch") or {}
|
694
|
+
|
677
695
|
run_id = run.id or generate_id()
|
678
696
|
queued_run = launch_add(
|
679
697
|
run_id=run_id,
|
@@ -685,8 +703,8 @@ class Scheduler(ABC):
|
|
685
703
|
entity=self._entity,
|
686
704
|
queue_name=self._kwargs.get("queue"),
|
687
705
|
project_queue=self._project_queue,
|
688
|
-
resource=
|
689
|
-
resource_args=
|
706
|
+
resource=_job_launch_config.get("resource"),
|
707
|
+
resource_args=_job_launch_config.get("resource_args"),
|
690
708
|
author=self._kwargs.get("author"),
|
691
709
|
sweep_id=self._sweep_id,
|
692
710
|
)
|
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
3
|
import re
|
4
|
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
5
5
|
|
6
6
|
import yaml
|
7
7
|
|
8
8
|
import wandb
|
9
9
|
from wandb import util
|
10
|
-
from wandb.apis.public import Api as PublicApi
|
11
10
|
from wandb.sdk.launch.errors import LaunchError
|
12
11
|
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from wandb.apis.public import Api as PublicApi
|
14
|
+
|
13
15
|
DEFAULT_SWEEP_COMMAND: List[str] = [
|
14
16
|
"${env}",
|
15
17
|
"${interpreter}",
|
@@ -276,7 +278,7 @@ def make_launch_sweep_entrypoint(
|
|
276
278
|
return entry_point, macro_args
|
277
279
|
|
278
280
|
|
279
|
-
def check_job_exists(public_api: PublicApi, job: Optional[str]) -> bool:
|
281
|
+
def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
|
280
282
|
"""Check if the job exists using the public api.
|
281
283
|
|
282
284
|
Returns: True if no job is passed, or if the job exists.
|
wandb/sdk/launch/utils.py
CHANGED
@@ -127,7 +127,9 @@ def set_project_entity_defaults(
|
|
127
127
|
prefix = ""
|
128
128
|
if platform.system() != "Windows" and sys.stdout.encoding == "UTF-8":
|
129
129
|
prefix = "🚀 "
|
130
|
-
wandb.termlog(
|
130
|
+
wandb.termlog(
|
131
|
+
f"{LOG_PREFIX}{prefix}Launching run into {entity}{'/' + project if project else ''}"
|
132
|
+
)
|
131
133
|
return project, entity
|
132
134
|
|
133
135
|
|