ob-metaflow 2.10.7.4__py2.py3-none-any.whl → 2.10.9.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ob-metaflow might be problematic.
- metaflow/cards.py +2 -0
- metaflow/decorators.py +1 -1
- metaflow/metaflow_config.py +4 -0
- metaflow/plugins/__init__.py +4 -0
- metaflow/plugins/airflow/airflow_cli.py +1 -1
- metaflow/plugins/argo/argo_workflows.py +5 -0
- metaflow/plugins/argo/argo_workflows_cli.py +1 -1
- metaflow/plugins/aws/aws_utils.py +1 -1
- metaflow/plugins/aws/batch/batch.py +4 -0
- metaflow/plugins/aws/batch/batch_cli.py +3 -0
- metaflow/plugins/aws/batch/batch_client.py +40 -11
- metaflow/plugins/aws/batch/batch_decorator.py +1 -0
- metaflow/plugins/aws/step_functions/step_functions.py +1 -0
- metaflow/plugins/aws/step_functions/step_functions_cli.py +1 -1
- metaflow/plugins/azure/azure_exceptions.py +1 -1
- metaflow/plugins/cards/card_cli.py +413 -28
- metaflow/plugins/cards/card_client.py +16 -7
- metaflow/plugins/cards/card_creator.py +228 -0
- metaflow/plugins/cards/card_datastore.py +124 -26
- metaflow/plugins/cards/card_decorator.py +40 -86
- metaflow/plugins/cards/card_modules/base.html +12 -0
- metaflow/plugins/cards/card_modules/basic.py +74 -8
- metaflow/plugins/cards/card_modules/bundle.css +1 -170
- metaflow/plugins/cards/card_modules/card.py +65 -0
- metaflow/plugins/cards/card_modules/components.py +446 -81
- metaflow/plugins/cards/card_modules/convert_to_native_type.py +9 -3
- metaflow/plugins/cards/card_modules/main.js +250 -21
- metaflow/plugins/cards/card_modules/test_cards.py +117 -0
- metaflow/plugins/cards/card_resolver.py +0 -2
- metaflow/plugins/cards/card_server.py +361 -0
- metaflow/plugins/cards/component_serializer.py +506 -42
- metaflow/plugins/cards/exception.py +20 -1
- metaflow/plugins/datastores/azure_storage.py +1 -2
- metaflow/plugins/datastores/gs_storage.py +1 -2
- metaflow/plugins/datastores/s3_storage.py +2 -1
- metaflow/plugins/datatools/s3/s3.py +24 -11
- metaflow/plugins/env_escape/client.py +2 -12
- metaflow/plugins/env_escape/client_modules.py +18 -14
- metaflow/plugins/env_escape/server.py +18 -11
- metaflow/plugins/env_escape/utils.py +12 -0
- metaflow/plugins/gcp/gs_exceptions.py +1 -1
- metaflow/plugins/gcp/gs_utils.py +1 -1
- metaflow/plugins/kubernetes/kubernetes.py +43 -6
- metaflow/plugins/kubernetes/kubernetes_cli.py +40 -1
- metaflow/plugins/kubernetes/kubernetes_decorator.py +73 -6
- metaflow/plugins/kubernetes/kubernetes_job.py +536 -161
- metaflow/plugins/pypi/conda_environment.py +5 -6
- metaflow/plugins/pypi/pip.py +2 -2
- metaflow/plugins/pypi/utils.py +15 -0
- metaflow/task.py +1 -0
- metaflow/version.py +1 -1
- {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/METADATA +1 -1
- {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/RECORD +57 -55
- {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/LICENSE +0 -0
- {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/WHEEL +0 -0
- {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/entry_points.txt +0 -0
- {ob_metaflow-2.10.7.4.dist-info → ob_metaflow-2.10.9.2.dist-info}/top_level.txt +0 -0
@@ -2,20 +2,18 @@ import json
 import math
 import random
 import time
-
-
-
+import os
+import socket
+import copy
 
 from metaflow.exception import MetaflowException
 from metaflow.metaflow_config import KUBERNETES_SECRETS
 
 CLIENT_REFRESH_INTERVAL_SECONDS = 300
 
-
 class KubernetesJobException(MetaflowException):
     headline = "Kubernetes job error"
 
-
 # Implements truncated exponential backoff from
 # https://cloud.google.com/storage/docs/retry-strategy#exponential-backoff
 def k8s_retry(deadline_seconds=60, max_backoff=32):
@@ -78,107 +76,260 @@ class KubernetesJob(object):
         tmpfs_size = self._kwargs["tmpfs_size"]
         tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        jobset_name = "js-%s" % self._kwargs["attrs"]["metaflow.task_id"].split('-')[-1]
+        main_job_name = "control"
+        main_job_index = 0
+        main_pod_index = 0
+        subdomain = jobset_name
+        master_port = int(self._kwargs['port']) if self._kwargs['port'] else None
+
+        passwordless_ssh = self._kwargs["attrs"]["requires_passwordless_ssh"]
+        if passwordless_ssh:
+            passwordless_ssh_service_name = subdomain
+            passwordless_ssh_service_selector = {
+                "passwordless-ssh-jobset": "true"
+            }
+        else:
+            passwordless_ssh_service_name = None
+            passwordless_ssh_service_selector = {}
+
+        fqdn_suffix = "%s.svc.cluster.local" % self._kwargs["namespace"]
+        jobset_main_addr = "%s-%s-%s-%s.%s.%s" % (
+            jobset_name,
+            main_job_name,
+            main_job_index,
+            main_pod_index,
+            subdomain,
+            fqdn_suffix,
+        )
+
+        def _install_jobset(
+            repo_url="https://github.com/kubernetes-sigs/jobset",
+            python_sdk_path="jobset/sdk/python",
+        ):
+
+            # TODO (Eddie): Remove this and suggest to user.
+
+            import subprocess
+            import tempfile
+            import shutil
+            import os
+
+            with open(os.devnull, "wb") as devnull:
+                cwd = os.getcwd()
+                tmp_dir = tempfile.mkdtemp()
+                os.chdir(tmp_dir)
+                subprocess.check_call(
+                    ["git", "clone", repo_url], stdout=devnull, stderr=subprocess.STDOUT
+                )
+                tmp_python_sdk_path = os.path.join(tmp_dir, python_sdk_path)
+                os.chdir(tmp_python_sdk_path)
+                subprocess.check_call(
+                    ["pip", "install", "."], stdout=devnull, stderr=subprocess.STDOUT
+                )
+                os.chdir(cwd)
+                shutil.rmtree(tmp_dir)
+
+        def _get_passwordless_ssh_service():
+
+            return client.V1Service(
+                api_version="v1",
+                kind="Service",
+                metadata=client.V1ObjectMeta(
+                    name=passwordless_ssh_service_name,
+                    namespace=self._kwargs["namespace"]
+                ),
+                spec=client.V1ServiceSpec(
+                    cluster_ip="None",
+                    internal_traffic_policy="Cluster",
+                    ip_families=["IPv4"],
+                    ip_family_policy="SingleStack",
+                    selector=passwordless_ssh_service_selector,
+                    session_affinity="None",
+                    type="ClusterIP",
+                    ports=[
+                        client.V1ServicePort(
+                            name="control",
+                            port=22,
+                            protocol="TCP",
+                            target_port=22
+                        )
+                    ]
+                )
+            )
+
+        def _get_replicated_job(job_name, parallelism, command):
+            return jobset.models.jobset_v1alpha2_replicated_job.JobsetV1alpha2ReplicatedJob(
+                name=job_name,
+                template=client.V1JobTemplateSpec(
                     metadata=client.V1ObjectMeta(
                         annotations=self._kwargs.get("annotations", {}),
                         labels=self._kwargs.get("labels", {}),
                         namespace=self._kwargs["namespace"],
                     ),
-                    spec=client.
-                        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    spec=client.V1JobSpec(
+                        parallelism=parallelism, # how many jobs can run at once
+                        completions=parallelism, # how many Pods the JobSet creates in total
+                        backoff_limit=0,
+                        ttl_seconds_after_finished=7
+                        * 60
+                        * 60
+                        * 24,
+                        template=client.V1PodTemplateSpec(
+                            metadata=client.V1ObjectMeta(
+                                annotations=self._kwargs.get("annotations", {}),
+                                labels={
+                                    **self._kwargs.get("labels", {}),
+                                    **passwordless_ssh_service_selector, # TODO: necessary?
+                                    # TODO: cluster-name, app.kubernetes.io/name necessary?
+                                },
+                                namespace=self._kwargs["namespace"],
+                            ),
+                            spec=client.V1PodSpec(
+                                active_deadline_seconds=self._kwargs[
+                                    "timeout_in_seconds"
+                                ],
+                                containers=[
+                                    client.V1Container(
+                                        command=command,
+                                        ports=[client.V1ContainerPort(container_port=master_port)] if master_port and job_name=="control" else [],
+                                        env=[
+                                            client.V1EnvVar(name=k, value=str(v))
+                                            for k, v in self._kwargs.get(
+                                                "environment_variables", {}
+                                            ).items()
+                                        ]
+                                        + [
+                                            client.V1EnvVar(
+                                                name=k,
+                                                value_from=client.V1EnvVarSource(
+                                                    field_ref=client.V1ObjectFieldSelector(
+                                                        field_path=str(v)
+                                                    )
+                                                ),
+                                            )
+                                            for k, v in {
+                                                "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
+                                                "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
+                                                "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
+                                                "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
+                                                "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
+                                            }.items()
+                                        ]
+                                        # Mimicking the AWS Batch Multinode env vars.
+                                        + [
+                                            client.V1EnvVar(
+                                                name="MASTER_ADDR",
+                                                value=jobset_main_addr,
+                                            ),
+                                            client.V1EnvVar(
+                                                name="MASTER_PORT",
+                                                value=str(master_port),
+                                            ),
+                                            client.V1EnvVar(
+                                                name="RANK",
+                                                value_from=client.V1EnvVarSource(
+                                                    field_ref=client.V1ObjectFieldSelector(
+                                                        field_path="metadata.annotations['batch.kubernetes.io/job-completion-index']"
+                                                    )
+                                                ),
+                                            ),
+                                            client.V1EnvVar(
+                                                name="WORLD_SIZE",
+                                                value=str(self._kwargs["num_parallel"]),
+                                            ),
+                                            client.V1EnvVar(
+                                                name="PYTHONUNBUFFERED",
+                                                value="0",
+                                            ),
+                                        ],
+                                        env_from=[
+                                            client.V1EnvFromSource(
+                                                secret_ref=client.V1SecretEnvSource(
+                                                    name=str(k),
+                                                    # optional=True
+                                                )
                                             )
+                                            for k in list(
+                                                self._kwargs.get("secrets", [])
+                                            )
+                                            + KUBERNETES_SECRETS.split(",")
+                                            if k
+                                        ],
+                                        image=self._kwargs["image"],
+                                        image_pull_policy=self._kwargs[
+                                            "image_pull_policy"
+                                        ],
+                                        name=self._kwargs["step_name"].replace(
+                                            "_", "-"
                                         ),
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                                        resources=client.V1ResourceRequirements(
+                                            requests={
+                                                "cpu": str(self._kwargs["cpu"]),
+                                                "memory": "%sM"
+                                                % str(self._kwargs["memory"]),
+                                                "ephemeral-storage": "%sM"
+                                                % str(self._kwargs["disk"]),
+                                            },
+                                            limits={
+                                                "%s.com/gpu".lower()
+                                                % self._kwargs["gpu_vendor"]: str(
+                                                    self._kwargs["gpu"]
+                                                )
+                                                for k in [0]
+                                                # Don't set GPU limits if gpu isn't specified.
+                                                if self._kwargs["gpu"] is not None
+                                            },
+                                        ),
+                                        volume_mounts=(
+                                            [
+                                                client.V1VolumeMount(
+                                                    mount_path=self._kwargs.get(
+                                                        "tmpfs_path"
+                                                    ),
+                                                    name="tmpfs-ephemeral-volume",
+                                                )
+                                            ]
+                                            if tmpfs_enabled
+                                            else []
                                         )
+                                        + (
+                                            [
+                                                client.V1VolumeMount(
+                                                    mount_path=path, name=claim
+                                                )
+                                                for claim, path in self._kwargs[
+                                                    "persistent_volume_claims"
+                                                ].items()
+                                            ]
+                                            if self._kwargs["persistent_volume_claims"]
+                                            is not None
+                                            else []
+                                        ),
                                     )
-                                            for k in list(self._kwargs.get("secrets", []))
-                                            + KUBERNETES_SECRETS.split(",")
-                                            if k
                                 ],
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                                            for k in [0]
-                                            # Don't set GPU limits if gpu isn't specified.
-                                            if self._kwargs["gpu"] is not None
-                                        },
-                                    ),
-                                    volume_mounts=(
+                                node_selector=self._kwargs.get("node_selector"),
+                                restart_policy="Never",
+
+                                set_hostname_as_fqdn=True, # configure pod hostname as pod's FQDN
+                                share_process_namespace=False, # default
+                                subdomain=subdomain, # FQDN = <hostname>.<subdomain>.<pod namespace>.svc.<cluster domain>
+
+                                service_account_name=self._kwargs["service_account"],
+                                termination_grace_period_seconds=0,
+                                tolerations=[
+                                    client.V1Toleration(**toleration)
+                                    for toleration in self._kwargs.get("tolerations")
+                                    or []
+                                ],
+                                volumes=(
                                     [
-                                        client.
-                                            mount_path=self._kwargs.get("tmpfs_path"),
+                                        client.V1Volume(
                                             name="tmpfs-ephemeral-volume",
+                                            empty_dir=client.V1EmptyDirVolumeSource(
+                                                medium="Memory",
+                                                size_limit="{}Mi".format(tmpfs_size),
+                                            ),
                                         )
                                     ]
                                     if tmpfs_enabled
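The block above derives the control pod's stable address from the JobSet naming scheme: pods of a replicated job are named <jobset>-<replicatedJob>-<jobIndex>-<podIndex> and, together with the headless subdomain, resolve under <namespace>.svc.cluster.local. A minimal sketch of the string that jobset_main_addr ends up holding, assuming a task id ending in "abc123" and the "default" namespace (both hypothetical example values):

    jobset_name = "js-abc123"        # "js-%s" % task_id.split("-")[-1]
    main_job_name = "control"
    main_job_index = 0
    main_pod_index = 0
    subdomain = jobset_name          # doubles as the headless Service name
    fqdn_suffix = "default.svc.cluster.local"

    jobset_main_addr = "%s-%s-%s-%s.%s.%s" % (
        jobset_name,
        main_job_name,
        main_job_index,
        main_pod_index,
        subdomain,
        fqdn_suffix,
    )
    # -> "js-abc123-control-0-0.js-abc123.default.svc.cluster.local"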
@@ -186,72 +337,264 @@ class KubernetesJob(object):
                                 )
                                 + (
                                     [
-                                        client.
-
+                                        client.V1Volume(
+                                            name=claim,
+                                            persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
+                                                claim_name=claim
+                                            ),
                                         )
-                                        for claim
+                                        for claim in self._kwargs[
                                             "persistent_volume_claims"
-                                        ].
+                                        ].keys()
                                     ]
                                     if self._kwargs["persistent_volume_claims"]
                                     is not None
                                     else []
                                 ),
-                            )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                            ),
+                        ),
+                    ),
+                ),
+            )
+
+        if "num_parallel" in self._kwargs and self._kwargs["num_parallel"] >= 1:
+
+            try:
+                import jobset
+            except ImportError:
+                _install_jobset()
+                import jobset
+
+            main_commands = copy.copy(self._kwargs["command"])
+            main_commands[-1] = main_commands[-1].replace(
+                "[multinode-args]", "--split-index 0"
+            )
+
+            task_id = self._kwargs["attrs"]["metaflow.task_id"]
+            secondary_commands = copy.copy(self._kwargs["command"])
+            # RANK needs +1 because control node is not in the worker index group, yet we want global nodes.
+            # Technically, control and worker could be same replicated job type, but cleaner to separate for future use cases.
+            secondary_commands[-1] = secondary_commands[-1].replace(
+                "[multinode-args]", "--split-index `expr $RANK + 1`"
+            )
+            secondary_commands[-1] = secondary_commands[-1].replace(
+                "ubf_control", "ubf_task"
+            )
+            secondary_commands[-1] = secondary_commands[-1].replace(
+                task_id,
+                task_id.replace("control-", "") + "-node-`expr $RANK + 1`",
+            )
+
+            if passwordless_ssh:
+                if not os.path.exists("/usr/sbin/sshd"):
+                    raise KubernetesJobException(
+                        "This @parallel decorator requires sshd to be installed in the container image."
+                        "Please install OpenSSH."
+                    )
+
+                # run sshd in background
+                main_commands[-1] = "/usr/sbin/sshd -D & %s" % main_commands[-1]
+                secondary_commands[-1] = "/usr/sbin/sshd -D & %s" % secondary_commands[-1]
+
+            self._jobset = jobset.models.jobset_v1alpha2_job_set.JobsetV1alpha2JobSet(
+                api_version="jobset.x-k8s.io/v1alpha2",
+                kind="JobSet",
+                metadata=client.V1ObjectMeta(
+                    annotations=self._kwargs.get("annotations", {}),
+                    labels=self._kwargs.get("labels", {}),
+                    name=jobset_name,
+                    namespace=self._kwargs["namespace"],
+                ),
+                spec=jobset.models.jobset_v1alpha2_job_set_spec.JobsetV1alpha2JobSetSpec(
+                    network=jobset.models.jobset_v1alpha2_network.JobsetV1alpha2Network(
+                        enable_dns_hostnames=True if not self._kwargs['attrs']['requires_passwordless_ssh'] else False,
+                        subdomain=subdomain
+                    ),
+                    replicated_jobs=[
+                        _get_replicated_job("control", 1, main_commands),
+                        _get_replicated_job(
+                            "worker",
+                            self._kwargs["num_parallel"] - 1,
+                            secondary_commands,
+                        ),
+                    ],
+                ),
+            )
+            self._passwordless_ssh_service = _get_passwordless_ssh_service()
+        else:
+            self._job = client.V1Job(
+                api_version="batch/v1",
+                kind="Job",
+                metadata=client.V1ObjectMeta(
+                    # Annotations are for humans
+                    annotations=self._kwargs.get("annotations", {}),
+                    # While labels are for Kubernetes
+                    labels=self._kwargs.get("labels", {}),
+                    generate_name=self._kwargs["generate_name"],
+                    namespace=self._kwargs["namespace"],  # Defaults to `default`
+                ),
+                spec=client.V1JobSpec(
+                    # Retries are handled by Metaflow when it is responsible for
+                    # executing the flow. The responsibility is moved to Kubernetes
+                    # when Argo Workflows is responsible for the execution.
+                    backoff_limit=self._kwargs.get("retries", 0),
+                    completions=1,  # A single non-indexed pod job
+                    ttl_seconds_after_finished=7
+                    * 60
+                    * 60  # Remove job after a week. TODO: Make this configurable
+                    * 24,
+                    template=client.V1PodTemplateSpec(
+                        metadata=client.V1ObjectMeta(
+                            annotations=self._kwargs.get("annotations", {}),
+                            labels=self._kwargs.get("labels", {}),
+                            namespace=self._kwargs["namespace"],
+                        ),
+                        spec=client.V1PodSpec(
+                            # Timeout is set on the pod and not the job (important!)
+                            active_deadline_seconds=self._kwargs["timeout_in_seconds"],
+                            # TODO (savin): Enable affinities for GPU scheduling.
+                            # affinity=?,
+                            containers=[
+                                client.V1Container(
+                                    command=self._kwargs["command"],
+                                    env=[
+                                        client.V1EnvVar(name=k, value=str(v))
+                                        for k, v in self._kwargs.get(
+                                            "environment_variables", {}
+                                        ).items()
+                                    ]
+                                    # And some downward API magic. Add (key, value)
+                                    # pairs below to make pod metadata available
+                                    # within Kubernetes container.
+                                    + [
+                                        client.V1EnvVar(
+                                            name=k,
+                                            value_from=client.V1EnvVarSource(
+                                                field_ref=client.V1ObjectFieldSelector(
+                                                    field_path=str(v)
+                                                )
+                                            ),
+                                        )
+                                        for k, v in {
+                                            "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
+                                            "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
+                                            "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
+                                            "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
+                                            "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
+                                        }.items()
+                                    ],
+                                    env_from=[
+                                        client.V1EnvFromSource(
+                                            secret_ref=client.V1SecretEnvSource(
+                                                name=str(k),
+                                                # optional=True
+                                            )
+                                        )
+                                        for k in list(self._kwargs.get("secrets", []))
+                                        + KUBERNETES_SECRETS.split(",")
+                                        if k
+                                    ],
+                                    image=self._kwargs["image"],
+                                    image_pull_policy=self._kwargs["image_pull_policy"],
+                                    name=self._kwargs["step_name"].replace("_", "-"),
+                                    resources=client.V1ResourceRequirements(
+                                        requests={
+                                            "cpu": str(self._kwargs["cpu"]),
+                                            "memory": "%sM"
+                                            % str(self._kwargs["memory"]),
+                                            "ephemeral-storage": "%sM"
+                                            % str(self._kwargs["disk"]),
+                                        },
+                                        limits={
+                                            "%s.com/gpu".lower()
+                                            % self._kwargs["gpu_vendor"]: str(
+                                                self._kwargs["gpu"]
+                                            )
+                                            for k in [0]
+                                            # Don't set GPU limits if gpu isn't specified.
+                                            if self._kwargs["gpu"] is not None
+                                        },
                                     ),
-
-
-
-
-
-
-
-
-
-
-
+                                    volume_mounts=(
+                                        [
+                                            client.V1VolumeMount(
+                                                mount_path=self._kwargs.get(
+                                                    "tmpfs_path"
+                                                ),
+                                                name="tmpfs-ephemeral-volume",
+                                            )
+                                        ]
+                                        if tmpfs_enabled
+                                        else []
+                                    )
+                                    + (
+                                        [
+                                            client.V1VolumeMount(
+                                                mount_path=path, name=claim
+                                            )
+                                            for claim, path in self._kwargs[
+                                                "persistent_volume_claims"
+                                            ].items()
+                                        ]
+                                        if self._kwargs["persistent_volume_claims"]
+                                        is not None
+                                        else []
                                     ),
                                 )
-
-
-
-
-
-
+                            ],
+                            node_selector=self._kwargs.get("node_selector"),
+                            # TODO (savin): Support image_pull_secrets
+                            # image_pull_secrets=?,
+                            # TODO (savin): Support preemption policies
+                            # preemption_policy=?,
+                            #
+                            # A Container in a Pod may fail for a number of
+                            # reasons, such as because the process in it exited
+                            # with a non-zero exit code, or the Container was
+                            # killed due to OOM etc. If this happens, fail the pod
+                            # and let Metaflow handle the retries.
+                            restart_policy="Never",
+                            service_account_name=self._kwargs["service_account"],
+                            # Terminate the container immediately on SIGTERM
+                            termination_grace_period_seconds=0,
+                            tolerations=[
+                                client.V1Toleration(**toleration)
+                                for toleration in self._kwargs.get("tolerations") or []
+                            ],
+                            volumes=(
+                                [
+                                    client.V1Volume(
+                                        name="tmpfs-ephemeral-volume",
+                                        empty_dir=client.V1EmptyDirVolumeSource(
+                                            medium="Memory",
+                                            # Add default unit as ours differs from Kubernetes default.
+                                            size_limit="{}Mi".format(tmpfs_size),
+                                        ),
+                                    )
+                                ]
+                                if tmpfs_enabled
+                                else []
+                            )
+                            + (
+                                [
+                                    client.V1Volume(
+                                        name=claim,
+                                        persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
+                                            claim_name=claim
+                                        ),
+                                    )
+                                    for claim in self._kwargs[
+                                        "persistent_volume_claims"
+                                    ].keys()
+                                ]
+                                if self._kwargs["persistent_volume_claims"] is not None
+                                else []
+                            ),
+                            # TODO (savin): Set termination_message_policy
                         ),
-                        # TODO (savin): Set termination_message_policy
                     ),
                 ),
-            )
-        )
+            )
         return self
 
     def execute(self):
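The parallel branch mirrors the AWS Batch multi-node environment: MASTER_ADDR and MASTER_PORT point at the control pod, RANK is injected from the pod's batch.kubernetes.io/job-completion-index annotation via the downward API, and WORLD_SIZE carries num_parallel. As a hypothetical illustration only (this consumer code is not part of the package), a worker task could join a process group from those variables, assuming PyTorch happens to be installed in the image:

    import os

    import torch.distributed as dist

    world_size = int(os.environ["WORLD_SIZE"])   # num_parallel
    # Worker pods add 1 because the control replicated job holds global rank 0,
    # matching the `expr $RANK + 1` rewrite applied to secondary_commands above.
    worker_rank = int(os.environ["RANK"]) + 1

    dist.init_process_group(
        backend="gloo",
        init_method="env://",  # env:// reads MASTER_ADDR / MASTER_PORT from the environment
        rank=worker_rank,
        world_size=world_size,
    )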
@@ -262,19 +605,53 @@ class KubernetesJob(object):
             # achieve the guarantees that we are seeking.
             # https://github.com/kubernetes/enhancements/issues/1040
             # Hopefully, we will be able to get creative with kube-batch
-
-
-            .
-
+
+            if "num_parallel" in self._kwargs and self._kwargs["num_parallel"] >= 1:
+                # TODO (Eddie): this is kinda gross. fix it.
+                if self._kwargs["attrs"]["requires_passwordless_ssh"]:
+                    api_instance = client.CoreV1Api()
+                    api_response = api_instance.create_namespaced_service(namespace=self._kwargs['namespace'], body=self._passwordless_ssh_service)
+
+                with client.ApiClient() as api_client:
+                    api_instance = client.CustomObjectsApi(api_client)
+
+                    response = api_instance.create_namespaced_custom_object(
+                        body=self._jobset,
+                        group="jobset.x-k8s.io",
+                        version="v1alpha2",
+                        namespace=self._kwargs["namespace"],
+                        plural="jobsets",
                     )
-
-
-
-
-
-
-
-
+
+                    # HACK: Give K8s some time to actually create the job
+                    time.sleep(10)
+
+                    # TODO (Eddie): Remove hack and make RunningJobSet.
+                    # There are many jobs running that should be monitored.
+                    job_name = "%s-control-0" % response["metadata"]["name"]
+                    fake_id = 123
+                    return RunningJob(
+                        client=self._client,
+                        name=job_name,
+                        uid=fake_id,
+                        namespace=response["metadata"]["namespace"],
+                    )
+
+            else:
+                response = (
+                    client.BatchV1Api()
+                    .create_namespaced_job(
+                        body=self._job, namespace=self._kwargs["namespace"]
+                    )
+                    .to_dict()
+                )
+                return RunningJob(
+                    client=self._client,
+                    name=response["metadata"]["name"],
+                    uid=response["metadata"]["uid"],
+                    namespace=response["metadata"]["namespace"],
+                )
+
         except client.rest.ApiException as e:
             raise KubernetesJobException(
                 "Unable to launch Kubernetes job.\n %s"
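As the TODOs in this hunk note, only the control pod's Job is wrapped in RunningJob for monitoring. A sketch of how the JobSet created here could be polled with the same CustomObjectsApi coordinates; the namespace and name are example values, and reading status.conditions is an assumption about the JobSet CRD rather than something this diff relies on:

    from kubernetes import client, config

    config.load_incluster_config()  # or config.load_kube_config() outside the cluster
    api = client.CustomObjectsApi()

    js = api.get_namespaced_custom_object(
        group="jobset.x-k8s.io",
        version="v1alpha2",
        namespace="default",   # example namespace
        plural="jobsets",
        name="js-abc123",      # example JobSet name
    )
    # Terminal JobSets typically report a "Completed" or "Failed" condition.
    print([c.get("type") for c in js.get("status", {}).get("conditions", [])])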
@@ -330,7 +707,6 @@ class KubernetesJob(object):
 
 
 class RunningJob(object):
-
     # State Machine implementation for the lifecycle behavior documented in
     # https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/
     #
@@ -450,7 +826,6 @@ class RunningJob(object):
         client = self._client.get()
         if not self.is_done:
             if self.is_running:
-
                 # Case 1.
                 from kubernetes.stream import stream
 