wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -1,27 +1,32 @@
|
|
1
|
+
"""Implementation of KubernetesRunner class for wandb launch."""
|
2
|
+
|
1
3
|
import base64
|
2
4
|
import json
|
3
5
|
import logging
|
4
6
|
import time
|
5
|
-
from typing import Any, Dict, List, Optional, Tuple
|
7
|
+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
6
8
|
|
7
9
|
import wandb
|
8
10
|
from wandb.apis.internal import Api
|
11
|
+
from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
|
9
12
|
from wandb.sdk.launch.builder.abstract import AbstractBuilder
|
10
13
|
from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
11
14
|
from wandb.sdk.launch.registry.abstract import AbstractRegistry
|
15
|
+
from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
|
12
16
|
from wandb.sdk.launch.registry.local_registry import LocalRegistry
|
17
|
+
from wandb.sdk.launch.runner.abstract import State, Status
|
13
18
|
from wandb.util import get_module
|
14
19
|
|
15
20
|
from .._project_spec import EntryPoint, LaunchProject
|
16
21
|
from ..builder.build import get_env_vars_dict
|
22
|
+
from ..errors import LaunchError
|
17
23
|
from ..utils import (
|
18
24
|
LOG_PREFIX,
|
19
25
|
PROJECT_SYNCHRONOUS,
|
20
|
-
LaunchError,
|
21
26
|
get_kube_context_and_api_client,
|
22
27
|
make_name_dns_safe,
|
23
28
|
)
|
24
|
-
from .abstract import AbstractRun, AbstractRunner
|
29
|
+
from .abstract import AbstractRun, AbstractRunner
|
25
30
|
|
26
31
|
get_module(
|
27
32
|
"kubernetes",
|
@@ -31,8 +36,12 @@ get_module(
|
|
31
36
|
from kubernetes import client # type: ignore # noqa: E402
|
32
37
|
from kubernetes.client.api.batch_v1_api import BatchV1Api # type: ignore # noqa: E402
|
33
38
|
from kubernetes.client.api.core_v1_api import CoreV1Api # type: ignore # noqa: E402
|
39
|
+
from kubernetes.client.api.custom_objects_api import ( # type: ignore # noqa: E402
|
40
|
+
CustomObjectsApi,
|
41
|
+
)
|
34
42
|
from kubernetes.client.models.v1_job import V1Job # type: ignore # noqa: E402
|
35
43
|
from kubernetes.client.models.v1_secret import V1Secret # type: ignore # noqa: E402
|
44
|
+
from kubernetes.client.rest import ApiException # type: ignore # noqa: E402
|
36
45
|
|
37
46
|
TIMEOUT = 5
|
38
47
|
MAX_KUBERNETES_RETRIES = (
|
@@ -43,7 +52,22 @@ FAIL_MESSAGE_INTERVAL = 60
|
|
43
52
|
_logger = logging.getLogger(__name__)
|
44
53
|
|
45
54
|
|
55
|
+
# Dict for mapping possible states of custom objects to the states we want to report
|
56
|
+
# to the agent.
|
57
|
+
CRD_STATE_DICT: Dict[str, State] = {
|
58
|
+
"pending": "starting",
|
59
|
+
"running": "running",
|
60
|
+
"completed": "finished",
|
61
|
+
"failed": "failed",
|
62
|
+
"aborted": "failed",
|
63
|
+
"terminating": "stopping",
|
64
|
+
"terminated": "stopped",
|
65
|
+
}
|
66
|
+
|
67
|
+
|
46
68
|
class KubernetesSubmittedRun(AbstractRun):
|
69
|
+
"""Wrapper for a launched run on Kubernetes."""
|
70
|
+
|
47
71
|
def __init__(
|
48
72
|
self,
|
49
73
|
batch_api: "BatchV1Api",
|
@@ -53,6 +77,19 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
53
77
|
namespace: Optional[str] = "default",
|
54
78
|
secret: Optional["V1Secret"] = None,
|
55
79
|
) -> None:
|
80
|
+
"""Initialize a KubernetesSubmittedRun.
|
81
|
+
|
82
|
+
Arguments:
|
83
|
+
batch_api: Kubernetes BatchV1Api object.
|
84
|
+
core_api: Kubernetes CoreV1Api object.
|
85
|
+
name: Name of the job.
|
86
|
+
pod_names: List of pod names.
|
87
|
+
namespace: Kubernetes namespace.
|
88
|
+
secret: Kubernetes secret.
|
89
|
+
|
90
|
+
Returns:
|
91
|
+
None.
|
92
|
+
"""
|
56
93
|
self.batch_api = batch_api
|
57
94
|
self.core_api = core_api
|
58
95
|
self.name = name
|
@@ -66,14 +103,37 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
66
103
|
|
67
104
|
@property
|
68
105
|
def id(self) -> str:
|
106
|
+
"""Return the run id."""
|
69
107
|
return self.name
|
70
108
|
|
109
|
+
def get_logs(self) -> Optional[str]:
|
110
|
+
try:
|
111
|
+
logs = self.core_api.read_namespaced_pod_log(
|
112
|
+
name=self.pod_names[0], namespace=self.namespace
|
113
|
+
)
|
114
|
+
if logs:
|
115
|
+
return str(logs)
|
116
|
+
else:
|
117
|
+
wandb.termwarn(
|
118
|
+
f"Retrieved no logs for kubernetes pod(s): {self.pod_names}"
|
119
|
+
)
|
120
|
+
return None
|
121
|
+
except Exception as e:
|
122
|
+
wandb.termerror(f"{LOG_PREFIX}Failed to get pod logs: {e}")
|
123
|
+
return None
|
124
|
+
|
71
125
|
def get_job(self) -> "V1Job":
|
126
|
+
"""Return the job object."""
|
72
127
|
return self.batch_api.read_namespaced_job(
|
73
128
|
name=self.name, namespace=self.namespace
|
74
129
|
)
|
75
130
|
|
76
131
|
def wait(self) -> bool:
|
132
|
+
"""Wait for the run to finish.
|
133
|
+
|
134
|
+
Returns:
|
135
|
+
True if the run finished successfully, False otherwise.
|
136
|
+
"""
|
77
137
|
while True:
|
78
138
|
status = self.get_status()
|
79
139
|
wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
|
@@ -85,14 +145,23 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
85
145
|
) # todo: not sure if this (copied from aws runner) is the right approach? should we return false on failure
|
86
146
|
|
87
147
|
def get_status(self) -> Status:
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
148
|
+
"""Return the run status."""
|
149
|
+
try:
|
150
|
+
job_response = self.batch_api.read_namespaced_job_status(
|
151
|
+
name=self.name, namespace=self.namespace
|
152
|
+
)
|
153
|
+
status = job_response.status
|
154
|
+
|
155
|
+
pod = self.core_api.read_namespaced_pod(
|
156
|
+
name=self.pod_names[0], namespace=self.namespace
|
157
|
+
)
|
158
|
+
except ApiException as e:
|
159
|
+
if "(404)" not in str(e):
|
160
|
+
raise
|
161
|
+
# 404 = Pod/job not reachable
|
162
|
+
wandb.termlog(f"{LOG_PREFIX}Job or pod disconnected for job: {self.name}")
|
163
|
+
return Status("preempted")
|
92
164
|
|
93
|
-
pod = self.core_api.read_namespaced_pod(
|
94
|
-
name=self.pod_names[0], namespace=self.namespace
|
95
|
-
)
|
96
165
|
if pod.status.phase in ["Pending", "Unknown"]:
|
97
166
|
now = time.time()
|
98
167
|
if self._fail_count == 0:
|
@@ -111,7 +180,13 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
111
180
|
if status.succeeded == 1:
|
112
181
|
return_status = Status("finished")
|
113
182
|
elif status.failed is not None and status.failed >= 1:
|
114
|
-
|
183
|
+
if status.conditions[0].reason == "BackoffLimitExceeded":
|
184
|
+
wandb.termlog(
|
185
|
+
f"{LOG_PREFIX}Job or pod disconnected for job: {self.name}"
|
186
|
+
)
|
187
|
+
return_status = Status("preempted")
|
188
|
+
else:
|
189
|
+
return_status = Status("failed")
|
115
190
|
elif status.active == 1:
|
116
191
|
return Status("running")
|
117
192
|
elif status.conditions is not None and status.conditions[0].type == "Suspended":
|
@@ -133,6 +208,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
133
208
|
return return_status
|
134
209
|
|
135
210
|
def suspend(self) -> None:
|
211
|
+
"""Suspend the run."""
|
136
212
|
self.job.spec.suspend = True
|
137
213
|
self.batch_api.patch_namespaced_job(
|
138
214
|
name=self.name, namespace=self.namespace, body=self.job
|
@@ -156,29 +232,183 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
156
232
|
)
|
157
233
|
|
158
234
|
def cancel(self) -> None:
|
235
|
+
"""Cancel the run."""
|
159
236
|
self.suspend()
|
160
237
|
self.batch_api.delete_namespaced_job(name=self.name, namespace=self.namespace)
|
161
238
|
|
162
239
|
|
240
|
+
class CrdSubmittedRun(AbstractRun):
    """Run submitted to a CRD backend, e.g. Volcano.

    Tracks a namespaced Kubernetes custom object through the
    CustomObjectsApi and reads logs of its pods through the CoreV1Api.
    """

    def __init__(
        self,
        group: str,
        version: str,
        plural: str,
        name: str,
        namespace: str,
        core_api: CoreV1Api,
        custom_api: CustomObjectsApi,
        pod_names: List[str],
    ) -> None:
        """Create a run object for tracking the progress of a CRD.

        Arguments:
            group: The API group of the CRD.
            version: The API version of the CRD.
            plural: The plural name of the CRD.
            name: The name of the CRD instance.
            namespace: The namespace of the CRD instance.
            core_api: The Kubernetes core API client.
            custom_api: The Kubernetes custom object API client.
            pod_names: The names of the pods associated with the CRD instance.

        Raises:
            LaunchError: If the CRD instance cannot be fetched.
        """
        self.group = group
        self.version = version
        self.plural = plural
        self.name = name
        self.namespace = namespace
        self.core_api = core_api
        self.custom_api = custom_api
        self.pod_names = pod_names
        self._fail_count = 0
        # Fetch the object once up front so a bad name/namespace fails here
        # rather than on the first status poll.
        try:
            self.job = self.custom_api.get_namespaced_custom_object(
                group=self.group,
                version=self.version,
                namespace=self.namespace,
                plural=self.plural,
                name=self.name,
            )
        except ApiException as e:
            raise LaunchError(
                f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
            ) from e

    @property
    def id(self) -> str:
        """Get the name of the custom object."""
        return self.name

    def get_logs(self) -> Optional[str]:
        """Get logs for custom object.

        Returns:
            The logs of every pod in ``pod_names`` joined into one string,
            each section prefixed with ``Pod <name>:``. None if there are no
            pods or if reading any pod's log raises an ApiException.
        """
        # TODO: test more carefully once we release multi-node support
        logs: Dict[str, Optional[str]] = {}
        try:
            for pod_name in self.pod_names:
                logs[pod_name] = self.core_api.read_namespaced_pod_log(
                    name=pod_name, namespace=self.namespace
                )
        except ApiException as e:
            wandb.termwarn(f"Failed to get logs for {self.name}: {str(e)}")
            return None
        if not logs:
            return None
        logs_as_array = [f"Pod {pod_name}:\n{log}" for pod_name, log in logs.items()]
        return "\n".join(logs_as_array)

    def get_status(self) -> Status:
        """Get status of custom object.

        Returns:
            The Status mapped from the object's reported state through
            CRD_STATE_DICT; states not in that mapping become "unknown".

        Raises:
            LaunchError: If the status request fails, or if the response
                carries no usable state (missing ``status``, missing
                ``state``/``phase``, or a non-string state).
        """
        try:
            job_response = self.custom_api.get_namespaced_custom_object_status(
                group=self.group,
                version=self.version,
                namespace=self.namespace,
                plural=self.plural,
                name=self.name,
            )
        except ApiException as e:
            raise LaunchError(
                f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
            ) from e
        # Custom objects can technically define whatever states and format the
        # response to the status request however they want. This checks for
        # the most common cases.
        # Read defensively: a response without a "status" mapping must surface
        # as the same LaunchError as a missing state, not as a bare KeyError.
        status = job_response.get("status") if isinstance(job_response, dict) else None
        state = status.get("state") if isinstance(status, dict) else None
        if isinstance(state, dict):
            state = state.get("phase")
        if not isinstance(state, str):
            raise LaunchError(
                f"Failed to get CRD {self.name} in namespace {self.namespace}: no state found"
            )
        return Status(CRD_STATE_DICT.get(state.lower(), "unknown"))

    def cancel(self) -> None:
        """Cancel the custom object by deleting it.

        Raises:
            LaunchError: If the delete request fails.
        """
        try:
            self.custom_api.delete_namespaced_custom_object(
                group=self.group,
                version=self.version,
                namespace=self.namespace,
                plural=self.plural,
                name=self.name,
            )
        except ApiException as e:
            raise LaunchError(
                f"Failed to delete CRD {self.name} in namespace {self.namespace}: {str(e)}"
            ) from e

    def wait(self) -> bool:
        """Wait for this custom object to finish running.

        Polls the status every 5 seconds for as long as it is "running".
        Any other state ends the wait, including non-terminal ones.

        Returns:
            True if the last observed state is "finished", False otherwise.
        """
        while True:
            status = self.get_status()
            wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
            if status.state != "running":
                break
            time.sleep(5)
        return status.state == "finished"
|
364
|
+
|
365
|
+
|
163
366
|
class KubernetesRunner(AbstractRunner):
|
367
|
+
"""Launches runs onto kubernetes."""
|
368
|
+
|
164
369
|
def __init__(
|
165
370
|
self, api: Api, backend_config: Dict[str, Any], environment: AbstractEnvironment
|
166
371
|
) -> None:
|
372
|
+
"""Create a Kubernetes runner.
|
373
|
+
|
374
|
+
Arguments:
|
375
|
+
api: The API client object.
|
376
|
+
backend_config: The backend configuration.
|
377
|
+
environment: The environment to launch runs into.
|
378
|
+
|
379
|
+
Raises:
|
380
|
+
LaunchError: If the Kubernetes configuration is invalid.
|
381
|
+
"""
|
167
382
|
super().__init__(api, backend_config)
|
168
383
|
self.environment = environment
|
169
384
|
|
170
385
|
def wait_job_launch(
|
171
|
-
self,
|
386
|
+
self,
|
387
|
+
job_name: str,
|
388
|
+
namespace: str,
|
389
|
+
core_api: "CoreV1Api",
|
390
|
+
label: str = "job-name",
|
172
391
|
) -> List[str]:
|
392
|
+
"""Wait for a job to be launched and return the pod names.
|
393
|
+
|
394
|
+
Arguments:
|
395
|
+
job_name: The name of the job.
|
396
|
+
namespace: The namespace of the job.
|
397
|
+
core_api: The Kubernetes core API client.
|
398
|
+
label: The label key to match against job_name.
|
399
|
+
|
400
|
+
Returns:
|
401
|
+
The names of the pods associated with the job.
|
402
|
+
"""
|
173
403
|
pods = core_api.list_namespaced_pod(
|
174
|
-
label_selector=f"
|
404
|
+
label_selector=f"{label}={job_name}", namespace=namespace
|
175
405
|
)
|
176
406
|
timeout = TIMEOUT
|
177
407
|
while len(pods.items) == 0 and timeout > 0:
|
178
408
|
time.sleep(1)
|
179
409
|
timeout -= 1
|
180
410
|
pods = core_api.list_namespaced_pod(
|
181
|
-
label_selector=f"
|
411
|
+
label_selector=f"{label}={job_name}", namespace=namespace
|
182
412
|
)
|
183
413
|
|
184
414
|
if timeout == 0:
|
@@ -197,6 +427,15 @@ class KubernetesRunner(AbstractRunner):
|
|
197
427
|
def get_namespace(
|
198
428
|
self, resource_args: Dict[str, Any], context: Dict[str, Any]
|
199
429
|
) -> str:
|
430
|
+
"""Get the namespace to launch into.
|
431
|
+
|
432
|
+
Arguments:
|
433
|
+
resource_args: The resource args to launch.
|
434
|
+
context: The k8s config context.
|
435
|
+
|
436
|
+
Returns:
|
437
|
+
The namespace to launch into.
|
438
|
+
"""
|
200
439
|
default_namespace = (
|
201
440
|
context["context"].get("namespace", "default") if context else "default"
|
202
441
|
)
|
@@ -213,8 +452,20 @@ class KubernetesRunner(AbstractRunner):
|
|
213
452
|
builder: Optional[AbstractBuilder],
|
214
453
|
namespace: str,
|
215
454
|
core_api: "CoreV1Api",
|
455
|
+
job_tracker: Optional[JobAndRunStatusTracker],
|
216
456
|
) -> Tuple[Dict[str, Any], Optional["V1Secret"]]:
|
217
|
-
"""Apply our default values, return job dict and secret.
|
457
|
+
"""Apply our default values, return job dict and secret.
|
458
|
+
|
459
|
+
Arguments:
|
460
|
+
resource_args (Dict[str, Any]): The resource args to launch.
|
461
|
+
launch_project (LaunchProject): The launch project.
|
462
|
+
builder (Optional[AbstractBuilder]): The builder.
|
463
|
+
namespace (str): The namespace.
|
464
|
+
core_api (CoreV1Api): The core api.
|
465
|
+
|
466
|
+
Returns:
|
467
|
+
Tuple[Dict[str, Any], Optional["V1Secret"]]: The resource args and secret.
|
468
|
+
"""
|
218
469
|
job: Dict[str, Any] = {
|
219
470
|
"apiVersion": "batch/v1",
|
220
471
|
"kind": "Job",
|
@@ -253,7 +504,9 @@ class KubernetesRunner(AbstractRunner):
|
|
253
504
|
"Invalid specification of multiple containers. See https://docs.wandb.ai/guides/launch for guidance on submitting jobs."
|
254
505
|
)
|
255
506
|
# dont specify run id if user provided image, could have multiple runs
|
256
|
-
|
507
|
+
image_uri = launch_project.docker_image
|
508
|
+
containers[0]["image"] = image_uri
|
509
|
+
launch_project.fill_macros(image_uri)
|
257
510
|
# TODO: handle secret pulling image from registry
|
258
511
|
elif not any(["image" in cont for cont in containers]):
|
259
512
|
if len(containers) > 1:
|
@@ -262,7 +515,9 @@ class KubernetesRunner(AbstractRunner):
|
|
262
515
|
)
|
263
516
|
assert entry_point is not None
|
264
517
|
assert builder is not None
|
265
|
-
image_uri = builder.build_image(launch_project, entry_point)
|
518
|
+
image_uri = builder.build_image(launch_project, entry_point, job_tracker)
|
519
|
+
image_uri = image_uri.replace("https://", "")
|
520
|
+
launch_project.fill_macros(image_uri)
|
266
521
|
# in the non instance case we need to make an imagePullSecret
|
267
522
|
# so the new job can pull the image
|
268
523
|
if not builder.registry:
|
@@ -276,8 +531,8 @@ class KubernetesRunner(AbstractRunner):
|
|
276
531
|
pod_spec["imagePullSecrets"] = [
|
277
532
|
{"name": f"regcred-{launch_project.run_id}"}
|
278
533
|
]
|
279
|
-
|
280
534
|
containers[0]["image"] = image_uri
|
535
|
+
launch_project.fill_macros(image_uri)
|
281
536
|
|
282
537
|
inject_entrypoint_and_args(
|
283
538
|
containers,
|
@@ -306,8 +561,18 @@ class KubernetesRunner(AbstractRunner):
|
|
306
561
|
def run(
|
307
562
|
self,
|
308
563
|
launch_project: LaunchProject,
|
309
|
-
builder:
|
564
|
+
builder: AbstractBuilder,
|
565
|
+
job_tracker: Optional[JobAndRunStatusTracker] = None,
|
310
566
|
) -> Optional[AbstractRun]: # noqa: C901
|
567
|
+
"""Execute a launch project on Kubernetes.
|
568
|
+
|
569
|
+
Arguments:
|
570
|
+
launch_project: The launch project to execute.
|
571
|
+
builder: The builder to use to build the image.
|
572
|
+
|
573
|
+
Returns:
|
574
|
+
The run object if the run was successful, otherwise None.
|
575
|
+
"""
|
311
576
|
kubernetes = get_module( # noqa: F811
|
312
577
|
"kubernetes",
|
313
578
|
required="Kubernetes runner requires the kubernetes package. Please"
|
@@ -316,23 +581,86 @@ class KubernetesRunner(AbstractRunner):
|
|
316
581
|
resource_args = launch_project.resource_args.get("kubernetes", {})
|
317
582
|
if not resource_args:
|
318
583
|
wandb.termlog(
|
319
|
-
f"{LOG_PREFIX}Note: no resource args specified. Add a
|
584
|
+
f"{LOG_PREFIX}Note: no resource args specified. Add a "
|
585
|
+
"Kubernetes yaml spec or other options in a json file "
|
586
|
+
"with --resource-args <json>."
|
320
587
|
)
|
321
588
|
_logger.info(f"Running Kubernetes job with resource args: {resource_args}")
|
322
589
|
|
323
590
|
context, api_client = get_kube_context_and_api_client(kubernetes, resource_args)
|
324
591
|
|
592
|
+
# If the user specified an alternate api, we will execute this
|
593
|
+
# run by creating a custom object.
|
594
|
+
api_version = resource_args.get("apiVersion", "batch/v1")
|
595
|
+
if api_version not in ["batch/v1", "batch/v1beta1"]:
|
596
|
+
entrypoint = launch_project.get_single_entry_point()
|
597
|
+
if launch_project.docker_image:
|
598
|
+
image_uri = launch_project.docker_image
|
599
|
+
else:
|
600
|
+
assert entrypoint is not None
|
601
|
+
image_uri = builder.build_image(launch_project, entrypoint, job_tracker)
|
602
|
+
launch_project.fill_macros(image_uri)
|
603
|
+
env_vars = get_env_vars_dict(launch_project, self._api)
|
604
|
+
# Crawl the resource args and add our env vars to the containers.
|
605
|
+
add_wandb_env(launch_project.resource_args, env_vars)
|
606
|
+
# Crawl the resource args and add our labels to the pods. This is
|
607
|
+
# necessary for the agent to find the pods later on.
|
608
|
+
add_label_to_pods(
|
609
|
+
launch_project.resource_args, "wandb/run-id", launch_project.run_id
|
610
|
+
)
|
611
|
+
overrides = {}
|
612
|
+
if launch_project.override_args:
|
613
|
+
overrides["args"] = launch_project.override_args
|
614
|
+
if launch_project.override_entrypoint:
|
615
|
+
overrides["command"] = launch_project.override_entrypoint.command
|
616
|
+
add_entrypoint_args_overrides(
|
617
|
+
launch_project.resource_args,
|
618
|
+
overrides,
|
619
|
+
)
|
620
|
+
api = client.CustomObjectsApi(api_client)
|
621
|
+
# Infer the attributes of a custom object from the apiVersion and/or
|
622
|
+
# a kind: attribute in the resource args.
|
623
|
+
namespace = self.get_namespace(resource_args, context)
|
624
|
+
group = resource_args.get("group", api_version.split("/")[0])
|
625
|
+
version = api_version.split("/")[1]
|
626
|
+
kind = resource_args.get("kind", version)
|
627
|
+
plural = f"{kind.lower()}s"
|
628
|
+
try:
|
629
|
+
response = api.create_namespaced_custom_object(
|
630
|
+
group=group,
|
631
|
+
version=version,
|
632
|
+
namespace=namespace,
|
633
|
+
plural=plural,
|
634
|
+
body=launch_project.resource_args.get("kubernetes"),
|
635
|
+
)
|
636
|
+
except ApiException as e:
|
637
|
+
raise LaunchError(
|
638
|
+
f"Error creating CRD of kind {kind}: {e.status} {e.reason}"
|
639
|
+
) from e
|
640
|
+
name = response.get("metadata", {}).get("name")
|
641
|
+
_logger.info(f"Created {kind} {response['metadata']['name']}")
|
642
|
+
core = client.CoreV1Api(api_client)
|
643
|
+
pod_names = self.wait_job_launch(
|
644
|
+
launch_project.run_id, namespace, core, label="wandb/run-id"
|
645
|
+
)
|
646
|
+
return CrdSubmittedRun(
|
647
|
+
name=name,
|
648
|
+
group=group,
|
649
|
+
version=version,
|
650
|
+
namespace=namespace,
|
651
|
+
plural=plural,
|
652
|
+
core_api=client.CoreV1Api(api_client),
|
653
|
+
custom_api=api,
|
654
|
+
pod_names=pod_names,
|
655
|
+
)
|
656
|
+
|
325
657
|
batch_api = kubernetes.client.BatchV1Api(api_client)
|
326
658
|
core_api = kubernetes.client.CoreV1Api(api_client)
|
327
659
|
|
328
660
|
namespace = self.get_namespace(resource_args, context)
|
329
661
|
|
330
662
|
job, secret = self._inject_defaults(
|
331
|
-
resource_args,
|
332
|
-
launch_project,
|
333
|
-
builder,
|
334
|
-
namespace,
|
335
|
-
core_api,
|
663
|
+
resource_args, launch_project, builder, namespace, core_api, job_tracker
|
336
664
|
)
|
337
665
|
|
338
666
|
msg = "Creating Kubernetes job"
|
@@ -364,6 +692,17 @@ def inject_entrypoint_and_args(
|
|
364
692
|
override_args: List[str],
|
365
693
|
should_override_entrypoint: bool,
|
366
694
|
) -> None:
|
695
|
+
"""Inject the entrypoint and args into the containers.
|
696
|
+
|
697
|
+
Arguments:
|
698
|
+
containers: The containers to inject the entrypoint and args into.
|
699
|
+
entry_point: The entrypoint to inject.
|
700
|
+
override_args: The args to inject.
|
701
|
+
should_override_entrypoint: Whether to override the entrypoint.
|
702
|
+
|
703
|
+
Returns:
|
704
|
+
None
|
705
|
+
"""
|
367
706
|
for i in range(len(containers)):
|
368
707
|
if override_args:
|
369
708
|
containers[i]["args"] = override_args
|
@@ -379,8 +718,21 @@ def maybe_create_imagepull_secret(
|
|
379
718
|
run_id: str,
|
380
719
|
namespace: str,
|
381
720
|
) -> Optional["V1Secret"]:
|
721
|
+
"""Create a secret for pulling images from a private registry.
|
722
|
+
|
723
|
+
Arguments:
|
724
|
+
core_api: The Kubernetes CoreV1Api object.
|
725
|
+
registry: The registry to pull from.
|
726
|
+
run_id: The run id.
|
727
|
+
namespace: The namespace to create the secret in.
|
728
|
+
|
729
|
+
Returns:
|
730
|
+
A secret if one was created, otherwise None.
|
731
|
+
"""
|
382
732
|
secret = None
|
383
|
-
if isinstance(registry, LocalRegistry)
|
733
|
+
if isinstance(registry, LocalRegistry) or isinstance(
|
734
|
+
registry, AzureContainerRegistry
|
735
|
+
):
|
384
736
|
# Secret not required
|
385
737
|
return None
|
386
738
|
uname, token = registry.get_username_password()
|
@@ -406,3 +758,104 @@ def maybe_create_imagepull_secret(
|
|
406
758
|
return core_api.create_namespaced_secret(namespace, secret)
|
407
759
|
except Exception as e:
|
408
760
|
raise LaunchError(f"Exception when creating Kubernetes secret: {str(e)}\n")
|
761
|
+
|
762
|
+
|
763
|
+
def add_wandb_env(root: Union[dict, list], env_vars: Dict[str, str]) -> None:
    """Inject wandb environment variables into every container of a spec.

    Recursively walks the spec; every list found under a "containers" key is
    treated as a list of container specs. Each variable is appended to the
    container's "env" list as ``{"name": ..., "value": ...}``. Entries already
    present in the list are kept, so a name that was already set ends up
    listed twice with the injected entry last.

    WANDB_RUN_ID is treated specially: if it is provided in env_vars, it is
    only injected into the first container modified by this call.

    The spec is modified in place. ``env_vars`` itself is not modified.

    Arguments:
        root: The spec to modify.
        env_vars: The environment variables to inject.

    Returns: None.
    """

    def yield_containers(node: Any) -> Iterator[dict]:
        # Yield the items of every list stored under a "containers" key.
        # Container entries themselves are not searched further.
        if isinstance(node, dict):
            for key, value in node.items():
                if key == "containers":
                    if isinstance(value, list):
                        yield from value
                elif isinstance(value, (dict, list)):
                    yield from yield_containers(value)
        elif isinstance(node, list):
            for item in node:
                yield from yield_containers(item)

    # Work on a copy so dropping WANDB_RUN_ID below never leaks back into the
    # caller's dictionary.
    pending = dict(env_vars)
    for cont in yield_containers(root):
        if not isinstance(cont, dict):
            # Malformed container entry; nothing to attach variables to.
            continue
        env = cont.get("env")
        if env is None:
            # Covers both a missing key and an explicit `env: null`.
            env = []
            cont["env"] = env
        env.extend({"name": key, "value": value} for key, value in pending.items())
        # After we have set WANDB_RUN_ID once, we don't want to set it again
        pending.pop("WANDB_RUN_ID", None)
|
800
|
+
|
801
|
+
|
802
|
+
def add_label_to_pods(
    manifest: Union[dict, list], label_key: str, label_value: str
) -> None:
    """Add a label to all pod specs in a manifest.

    Recursively traverses the manifest and adds the label to all pod specs.
    Pod specs are identified by the presence of a "spec" key whose value is a
    mapping containing a "containers" key. The label is written to
    ``metadata.labels`` of each such object, creating those mappings when
    they are missing or null. The manifest is modified in place.

    Arguments:
        manifest: The manifest to modify.
        label_key: The label key to add.
        label_value: The label value to add.

    Returns: None.
    """

    def yield_pods(node: Any) -> Iterator[dict]:
        if isinstance(node, list):
            for item in node:
                yield from yield_pods(item)
        elif isinstance(node, dict):
            spec = node.get("spec")
            # Require a mapping: a null "spec" would raise on the membership
            # test, and a string "spec" would do a substring match.
            if isinstance(spec, dict) and "containers" in spec:
                yield node
            for value in node.values():
                if isinstance(value, (dict, list)):
                    yield from yield_pods(value)

    # Collect first, then mutate, so the traversal never observes a
    # dictionary that is being modified.
    for pod in list(yield_pods(manifest)):
        metadata = pod.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}
            pod["metadata"] = metadata
        labels = metadata.get("labels")
        if not isinstance(labels, dict):
            labels = {}
            metadata["labels"] = labels
        labels[label_key] = label_value
|
834
|
+
|
835
|
+
|
836
|
+
def add_entrypoint_args_overrides(manifest: Union[dict, list], overrides: dict) -> None:
    """Add entrypoint and args overrides to all containers in a manifest.

    Recursively traverses the manifest and applies the overrides to all
    containers. Containers are identified by the presence of a "spec" key
    whose value is a mapping with a "containers" list. The manifest is
    modified in place; values that are neither lists nor mappings are ignored.

    Arguments:
        manifest: The manifest to modify.
        overrides: Dictionary that may contain a "command" key (the
            entrypoint override) and/or an "args" key. Only the keys that
            are present are written to each container.

    Returns: None.
    """
    if isinstance(manifest, list):
        for item in manifest:
            add_entrypoint_args_overrides(item, overrides)
    elif isinstance(manifest, dict):
        spec = manifest.get("spec")
        # Require a mapping: a null "spec" would raise on the membership
        # test, and a string "spec" would do a substring match.
        if isinstance(spec, dict) and "containers" in spec:
            containers = spec["containers"]
            if isinstance(containers, list):
                for container in containers:
                    if not isinstance(container, dict):
                        # Malformed container entry; nothing to override.
                        continue
                    if "command" in overrides:
                        container["command"] = overrides["command"]
                    if "args" in overrides:
                        container["args"] = overrides["args"]
        for value in manifest.values():
            add_entrypoint_args_overrides(value, overrides)
|