metaflow-2.11.14-py2.py3-none-any.whl → metaflow-2.11.16-py2.py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- metaflow/__init__.py +3 -0
- metaflow/cli.py +0 -120
- metaflow/clone_util.py +6 -0
- metaflow/datastore/datastore_set.py +1 -1
- metaflow/datastore/flow_datastore.py +32 -6
- metaflow/datastore/task_datastore.py +50 -0
- metaflow/extension_support/plugins.py +2 -0
- metaflow/metaflow_config.py +24 -0
- metaflow/metaflow_environment.py +2 -2
- metaflow/plugins/__init__.py +20 -0
- metaflow/plugins/airflow/airflow.py +7 -0
- metaflow/plugins/argo/argo_workflows.py +17 -0
- metaflow/plugins/aws/batch/batch_cli.py +6 -4
- metaflow/plugins/azure/__init__.py +3 -0
- metaflow/plugins/azure/azure_credential.py +53 -0
- metaflow/plugins/azure/azure_exceptions.py +1 -1
- metaflow/plugins/azure/azure_secret_manager_secrets_provider.py +240 -0
- metaflow/plugins/azure/azure_utils.py +2 -35
- metaflow/plugins/azure/blob_service_client_factory.py +4 -2
- metaflow/plugins/datastores/azure_storage.py +6 -6
- metaflow/plugins/datatools/s3/s3.py +9 -9
- metaflow/plugins/gcp/__init__.py +1 -0
- metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py +169 -0
- metaflow/plugins/gcp/gs_storage_client_factory.py +52 -1
- metaflow/plugins/kubernetes/kubernetes.py +85 -8
- metaflow/plugins/kubernetes/kubernetes_cli.py +24 -1
- metaflow/plugins/kubernetes/kubernetes_client.py +4 -1
- metaflow/plugins/kubernetes/kubernetes_decorator.py +49 -4
- metaflow/plugins/kubernetes/kubernetes_job.py +208 -201
- metaflow/plugins/kubernetes/kubernetes_jobsets.py +784 -0
- metaflow/plugins/logs_cli.py +358 -0
- metaflow/plugins/timeout_decorator.py +2 -1
- metaflow/task.py +1 -12
- metaflow/tuple_util.py +27 -0
- metaflow/util.py +0 -15
- metaflow/version.py +1 -1
- {metaflow-2.11.14.dist-info → metaflow-2.11.16.dist-info}/METADATA +2 -2
- {metaflow-2.11.14.dist-info → metaflow-2.11.16.dist-info}/RECORD +42 -36
- {metaflow-2.11.14.dist-info → metaflow-2.11.16.dist-info}/LICENSE +0 -0
- {metaflow-2.11.14.dist-info → metaflow-2.11.16.dist-info}/WHEEL +0 -0
- {metaflow-2.11.14.dist-info → metaflow-2.11.16.dist-info}/entry_points.txt +0 -0
- {metaflow-2.11.14.dist-info → metaflow-2.11.16.dist-info}/top_level.txt +0 -0
metaflow/plugins/kubernetes/kubernetes_jobsets.py
@@ -0,0 +1,784 @@
import copy
import math
import random
import time
from metaflow.metaflow_current import current
from metaflow.exception import MetaflowException
from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
import json
from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_VERSION
from collections import namedtuple


class KubernetesJobsetException(MetaflowException):
    headline = "Kubernetes jobset error"


# TODO [DUPLICATE CODE]: Refactor this method to a separate file so that
# it can be used by both KubernetesJob and KubernetesJobset
def k8s_retry(deadline_seconds=60, max_backoff=32):
    def decorator(function):
        from functools import wraps

        @wraps(function)
        def wrapper(*args, **kwargs):
            from kubernetes import client

            deadline = time.time() + deadline_seconds
            retry_number = 0

            while True:
                try:
                    result = function(*args, **kwargs)
                    return result
                except client.rest.ApiException as e:
                    if e.status == 500:
                        current_t = time.time()
                        backoff_delay = min(
                            math.pow(2, retry_number) + random.random(), max_backoff
                        )
                        if current_t + backoff_delay < deadline:
                            time.sleep(backoff_delay)
                            retry_number += 1
                            continue  # retry again
                        else:
                            raise
                    else:
                        raise

        return wrapper

    return decorator


JobsetStatus = namedtuple(
    "JobsetStatus",
    [
        "control_pod_failed",  # boolean
        "control_exit_code",
        "control_pod_status",  # string like (<pod-status>):(<container-status>) [used for user-messaging]
        "control_started",
        "control_completed",
        "worker_pods_failed",
        "workers_are_suspended",
        "workers_have_started",
        "all_jobs_are_suspended",
        "jobset_finished",
        "jobset_failed",
        "status_unknown",
        "jobset_was_terminated",
        "some_jobs_are_running",
    ],
)


def _basic_validation_for_js(jobset):
    if not jobset.get("status") or not _retrieve_replicated_job_statuses(jobset):
        return False
    worker_jobs = [
        w for w in jobset.get("spec").get("replicatedJobs") if w["name"] == "worker"
    ]
    if len(worker_jobs) == 0:
        raise KubernetesJobsetException("No worker jobs found in the jobset manifest")
    control_job = [
        w for w in jobset.get("spec").get("replicatedJobs") if w["name"] == "control"
    ]
    if len(control_job) == 0:
        raise KubernetesJobsetException("No control job found in the jobset manifest")
    return True


def _derive_pod_status_and_status_code(control_pod):
    overall_status = None
    control_exit_code = None
    control_pod_failed = False
    if control_pod:
        container_status = None
        pod_status = control_pod.get("status", {}).get("phase")
        container_statuses = control_pod.get("status", {}).get("containerStatuses")
        if container_statuses is None:
            container_status = ": ".join(
                filter(
                    None,
                    [
                        control_pod.get("status", {}).get("reason"),
                        control_pod.get("status", {}).get("message"),
                    ],
                )
            )
        else:
            for k, v in container_statuses[0].get("state", {}).items():
                if v is not None:
                    control_exit_code = v.get("exit_code")
                    container_status = ": ".join(
                        filter(
                            None,
                            [v.get("reason"), v.get("message")],
                        )
                    )
        if container_status is None:
            overall_status = "pod status: %s" % pod_status
        else:
            overall_status = "pod status: %s | container status: %s" % (
                pod_status,
                container_status,
            )
        if pod_status == "Failed":
            control_pod_failed = True
    return overall_status, control_exit_code, control_pod_failed


def _retrieve_replicated_job_statuses(jobset):
    # We needed this abstraction because Jobsets changed their schema
    # in version v0.3.0, where `ReplicatedJobsStatus` became `replicatedJobsStatus`.
    # So to handle users having an older version of jobsets, we need to account
    # for both schemas.
    if jobset.get("status", {}).get("replicatedJobsStatus", None):
        return jobset.get("status").get("replicatedJobsStatus")
    elif jobset.get("status", {}).get("ReplicatedJobsStatus", None):
        return jobset.get("status").get("ReplicatedJobsStatus")
    return None


def _construct_jobset_logical_status(jobset, control_pod=None):
    if not _basic_validation_for_js(jobset):
        return JobsetStatus(
            control_started=False,
            control_completed=False,
            workers_are_suspended=False,
            workers_have_started=False,
            all_jobs_are_suspended=False,
            jobset_finished=False,
            jobset_failed=False,
            status_unknown=True,
            jobset_was_terminated=False,
            control_exit_code=None,
            control_pod_status=None,
            worker_pods_failed=False,
            control_pod_failed=False,
            some_jobs_are_running=False,
        )

    js_status = jobset.get("status")

    control_started = False
    control_completed = False
    workers_are_suspended = False
    workers_have_started = False
    all_jobs_are_suspended = jobset.get("spec", {}).get("suspend", False)
    jobset_finished = False
    jobset_failed = False
    status_unknown = False
    jobset_was_terminated = False
    worker_pods_failed = False
    some_jobs_are_running = False

    total_worker_jobs = [
        w["replicas"]
        for w in jobset.get("spec").get("replicatedJobs", [])
        if w["name"] == "worker"
    ][0]
    total_control_jobs = [
        w["replicas"]
        for w in jobset.get("spec").get("replicatedJobs", [])
        if w["name"] == "control"
    ][0]

    if total_worker_jobs == 0 and total_control_jobs == 0:
        jobset_was_terminated = True

    replicated_job_statuses = _retrieve_replicated_job_statuses(jobset)
    for job_status in replicated_job_statuses:
        if job_status["active"] > 0:
            some_jobs_are_running = True

        if job_status["name"] == "control":
            control_started = job_status["active"] > 0 or job_status["succeeded"] > 0
            control_completed = job_status["succeeded"] > 0
            if job_status["failed"] > 0:
                jobset_failed = True

        if job_status["name"] == "worker":
            workers_have_started = job_status["active"] == total_worker_jobs
            if "suspended" in job_status:
                # `replicatedJobStatus` didn't have a `suspended` field
                # until v0.3.0. So we need to account for that.
                workers_are_suspended = job_status["suspended"] > 0
            if job_status["failed"] > 0:
                worker_pods_failed = True
                jobset_failed = True

    if js_status.get("conditions"):
        for condition in js_status["conditions"]:
            if condition["type"] == "Completed":
                jobset_finished = True
            if condition["type"] == "Failed":
                jobset_failed = True

    (
        overall_status,
        control_exit_code,
        control_pod_failed,
    ) = _derive_pod_status_and_status_code(control_pod)

    return JobsetStatus(
        control_started=control_started,
        control_completed=control_completed,
        workers_are_suspended=workers_are_suspended,
        workers_have_started=workers_have_started,
        all_jobs_are_suspended=all_jobs_are_suspended,
        jobset_finished=jobset_finished,
        jobset_failed=jobset_failed,
        status_unknown=status_unknown,
        jobset_was_terminated=jobset_was_terminated,
        control_exit_code=control_exit_code,
        control_pod_status=overall_status,
        worker_pods_failed=worker_pods_failed,
        control_pod_failed=control_pod_failed,
        some_jobs_are_running=some_jobs_are_running,
    )


class RunningJobSet(object):
    def __init__(self, client, name, namespace, group, version):
        self._client = client
        self._name = name
        self._pod_name = None
        self._namespace = namespace
        self._group = group
        self._version = version
        self._pod = self._fetch_pod()
        self._jobset = self._fetch_jobset()

        import atexit

        def best_effort_kill():
            try:
                self.kill()
            except Exception as ex:
                pass

        atexit.register(best_effort_kill)

    def __repr__(self):
        return "{}('{}/{}')".format(
            self.__class__.__name__, self._namespace, self._name
        )

    @k8s_retry()
    def _fetch_jobset(
        self,
    ):
        # name : name of the jobset.
        # namespace : namespace of the jobset.
        # Query the jobset and return the object's status field as a JSON object.
        client = self._client.get()
        with client.ApiClient() as api_client:
            api_instance = client.CustomObjectsApi(api_client)
            try:
                jobset = api_instance.get_namespaced_custom_object(
                    group=self._group,
                    version=self._version,
                    namespace=self._namespace,
                    plural="jobsets",
                    name=self._name,
                )
                return jobset
            except client.rest.ApiException as e:
                if e.status == 404:
                    raise KubernetesJobsetException(
                        "Unable to locate Kubernetes jobset %s" % self._name
                    )
                raise

    @k8s_retry()
    def _fetch_pod(self):
        # Fetch pod metadata.
        client = self._client.get()
        pods = (
            client.CoreV1Api()
            .list_namespaced_pod(
                namespace=self._namespace,
                label_selector="jobset.sigs.k8s.io/jobset-name={}".format(self._name),
            )
            .to_dict()["items"]
        )
        if pods:
            for pod in pods:
                # Check the labels of the pod to see if
                # `jobset.sigs.k8s.io/replicatedjob-name` is set to `control`.
                if (
                    pod["metadata"]["labels"].get(
                        "jobset.sigs.k8s.io/replicatedjob-name"
                    )
                    == "control"
                ):
                    return pod
        return {}

    def kill(self):
        plural = "jobsets"
        client = self._client.get()
        # Get the jobset
        with client.ApiClient() as api_client:
            api_instance = client.CustomObjectsApi(api_client)
            try:
                jobset = api_instance.get_namespaced_custom_object(
                    group=self._group,
                    version=self._version,
                    namespace=self._namespace,
                    plural="jobsets",
                    name=self._name,
                )

                # Suspend the jobset and set the replicas to zero.
                #
                jobset["spec"]["suspend"] = True
                for replicated_job in jobset["spec"]["replicatedJobs"]:
                    replicated_job["replicas"] = 0

                api_instance.replace_namespaced_custom_object(
                    group=self._group,
                    version=self._version,
                    namespace=self._namespace,
                    plural=plural,
                    name=jobset["metadata"]["name"],
                    body=jobset,
                )
            except Exception as e:
                raise KubernetesJobsetException(
                    "Exception when suspending existing jobset: %s\n" % e
                )

    @property
    def id(self):
        if self._pod_name:
            return "pod %s" % self._pod_name
        if self._pod:
            self._pod_name = self._pod["metadata"]["name"]
            return self.id
        return "jobset %s" % self._name

    @property
    def is_done(self):
        def done():
            return (
                self._jobset_is_completed
                or self._jobset_has_failed
                or self._jobset_was_terminated
            )

        if not done():
            # If not done, fetch newer status
            self._jobset = self._fetch_jobset()
            self._pod = self._fetch_pod()
        return done()

    @property
    def status(self):
        if self.is_done:
            return "Jobset is done"

        status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod)
        if status.status_unknown:
            return "Jobset status is unknown"
        if status.control_started:
            if status.control_pod_status:
                return "Jobset is running: %s" % status.control_pod_status
            return "Jobset is running"
        if status.all_jobs_are_suspended:
            return "Jobset is waiting to be unsuspended"

        return "Jobset waiting for jobs to start"

    @property
    def has_succeeded(self):
        return self.is_done and self._jobset_is_completed

    @property
    def has_failed(self):
        return self.is_done and self._jobset_has_failed

    @property
    def is_running(self):
        if self.is_done:
            return False
        status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod)
        if status.some_jobs_are_running:
            return True
        return False

    @property
    def _jobset_was_terminated(self):
        return _construct_jobset_logical_status(
            self._jobset, control_pod=self._pod
        ).jobset_was_terminated

    @property
    def is_waiting(self):
        return not self.is_done and not self.is_running

    @property
    def reason(self):
        # Return exit code and reason.
        if self.is_done and not self.has_succeeded:
            self._pod = self._fetch_pod()
        elif self.has_succeeded:
            return 0, None
        status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod)
        if status.control_pod_failed:
            return (
                status.control_exit_code,
                "control-pod failed [%s]" % status.control_pod_status,
            )
        elif status.worker_pods_failed:
            return None, "Worker pods failed"
        return None, None

    @property
    def _jobset_is_completed(self):
        return _construct_jobset_logical_status(
            self._jobset, control_pod=self._pod
        ).jobset_finished

    @property
    def _jobset_has_failed(self):
        return _construct_jobset_logical_status(
            self._jobset, control_pod=self._pod
        ).jobset_failed


class TaskIdConstructor:
    @classmethod
    def jobset_worker_id(cls, control_task_id: str):
        return "".join(
            [control_task_id.replace("control", "worker"), "-", "$WORKER_REPLICA_INDEX"]
        )

    @classmethod
    def join_step_task_ids(cls, num_parallel):
        """
        Called within the step decorator to set `flow._control_mapper_tasks`.
        Setting these allows the flow to know which tasks are needed in the join step.
        We set this in the `task_pre_step` method of the decorator.
        """
        control_task_id = current.task_id
        worker_task_id_base = control_task_id.replace("control", "worker")
        mapper = lambda idx: worker_task_id_base + "-%s" % (str(idx))
        return control_task_id, [mapper(idx) for idx in range(0, num_parallel - 1)]

    @classmethod
    def argo(cls):
        pass


def _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel):
    return [
        client.V1EnvVar(
            name="MASTER_ADDR",
            value=jobset_main_addr,
        ),
        client.V1EnvVar(
            name="MASTER_PORT",
            value=str(master_port),
        ),
        client.V1EnvVar(
            name="WORLD_SIZE",
            value=str(num_parallel),
        ),
    ] + [
        client.V1EnvVar(
            name="JOBSET_RESTART_ATTEMPT",
            value_from=client.V1EnvVarSource(
                field_ref=client.V1ObjectFieldSelector(
                    field_path="metadata.annotations['jobset.sigs.k8s.io/restart-attempt']"
                )
            ),
        ),
        client.V1EnvVar(
            name="METAFLOW_KUBERNETES_JOBSET_NAME",
            value_from=client.V1EnvVarSource(
                field_ref=client.V1ObjectFieldSelector(
                    field_path="metadata.annotations['jobset.sigs.k8s.io/jobset-name']"
                )
            ),
        ),
        client.V1EnvVar(
            name="WORKER_REPLICA_INDEX",
            value_from=client.V1EnvVarSource(
                field_ref=client.V1ObjectFieldSelector(
                    field_path="metadata.annotations['jobset.sigs.k8s.io/job-index']"
                )
            ),
        ),
    ]


def get_control_job(
    client,
    job_spec,
    jobset_main_addr,
    subdomain,
    port=None,
    num_parallel=None,
    namespace=None,
    annotations=None,
) -> dict:
    master_port = port

    job_spec = copy.deepcopy(job_spec)
    job_spec.parallelism = 1
    job_spec.completions = 1
    job_spec.template.spec.set_hostname_as_fqdn = True
    job_spec.template.spec.subdomain = subdomain
    job_spec.template.metadata.annotations = copy.copy(annotations)

    for idx in range(len(job_spec.template.spec.containers[0].command)):
        # Check for the ubf_context in the command.
        # Replace the UBF context with the one appropriate for control/worker.
        # Since we are passing the `step_cli` one time from the top level to one
        # KubernetesJobSet, we need to ensure that the UBF context is replaced properly
        # in all the worker jobs.
        if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]:
            job_spec.template.spec.containers[0].command[idx] = (
                job_spec.template.spec.containers[0]
                .command[idx]
                .replace(UBF_CONTROL, UBF_CONTROL + " " + "--split-index 0")
            )

    job_spec.template.spec.containers[0].env = (
        job_spec.template.spec.containers[0].env
        + _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel)
        + [
            client.V1EnvVar(
                name="CONTROL_INDEX",
                value=str(0),
            )
        ]
    )

    # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178
    return dict(
        name="control",
        template=client.api_client.ApiClient().sanitize_for_serialization(
            client.V1JobTemplateSpec(
                metadata=client.V1ObjectMeta(
                    namespace=namespace,
                    # We don't set any annotations here
                    # since they have been either set in the JobSpec
                    # or on the JobSet level
                ),
                spec=job_spec,
            )
        ),
        replicas=1,  # The control job will always have 1 replica.
    )


def get_worker_job(
    client,
    job_spec,
    job_name,
    jobset_main_addr,
    subdomain,
    control_task_id=None,
    worker_task_id=None,
    replicas=1,
    port=None,
    num_parallel=None,
    namespace=None,
    annotations=None,
) -> dict:
    master_port = port

    job_spec = copy.deepcopy(job_spec)
    job_spec.parallelism = 1
    job_spec.completions = 1
    job_spec.template.spec.set_hostname_as_fqdn = True
    job_spec.template.spec.subdomain = subdomain
    job_spec.template.metadata.annotations = copy.copy(annotations)

    for idx in range(len(job_spec.template.spec.containers[0].command)):
        if control_task_id in job_spec.template.spec.containers[0].command[idx]:
            job_spec.template.spec.containers[0].command[idx] = (
                job_spec.template.spec.containers[0]
                .command[idx]
                .replace(control_task_id, worker_task_id)
            )
        # Check for the ubf_context in the command.
        # Replace the UBF context with the one appropriate for control/worker.
        # Since we are passing the `step_cli` one time from the top level to one
        # KubernetesJobSet, we need to ensure that the UBF context is replaced properly
        # in all the worker jobs.
        if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]:
            # Since all commands will have a UBF_CONTROL, we need to replace the UBF_CONTROL
            # with the actual UBF context and also ensure that we are setting the correct
            # split-index for the worker jobs.
            split_index_str = "--split-index `expr $[WORKER_REPLICA_INDEX] + 1`"  # This is set in the environment variables below
            job_spec.template.spec.containers[0].command[idx] = (
                job_spec.template.spec.containers[0]
                .command[idx]
                .replace(UBF_CONTROL, UBF_TASK + " " + split_index_str)
            )

    job_spec.template.spec.containers[0].env = job_spec.template.spec.containers[
        0
    ].env + _jobset_specific_env_vars(
        client, jobset_main_addr, master_port, num_parallel
    )

    # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178
    return dict(
        name=job_name,
        template=client.api_client.ApiClient().sanitize_for_serialization(
            client.V1JobTemplateSpec(
                metadata=client.V1ObjectMeta(
                    namespace=namespace,
                    # We don't set any annotations here
                    # since they have been either set in the JobSpec
                    # or on the JobSet level
                ),
                spec=job_spec,
            )
        ),
        replicas=replicas,
    )


def _make_domain_name(
    jobset_name, main_job_name, main_job_index, main_pod_index, namespace
):
    return "%s-%s-%s-%s.%s.%s.svc.cluster.local" % (
        jobset_name,
        main_job_name,
        main_job_index,
        main_pod_index,
        jobset_name,
        namespace,
    )


class KubernetesJobSet(object):
    def __init__(
        self,
        client,
        name=None,
        job_spec=None,
        namespace=None,
        num_parallel=None,
        annotations=None,
        labels=None,
        port=None,
        task_id=None,
        **kwargs
    ):
        self._client = client
        self._kwargs = kwargs
        self._group = KUBERNETES_JOBSET_GROUP
        self._version = KUBERNETES_JOBSET_VERSION
        self.name = name

        main_job_name = "control"
        main_job_index = 0
        main_pod_index = 0
        subdomain = self.name
        num_parallel = int(1 if not num_parallel else num_parallel)
        self._namespace = namespace
        jobset_main_addr = _make_domain_name(
            self.name,
            main_job_name,
            main_job_index,
            main_pod_index,
            self._namespace,
        )

        annotations = {} if not annotations else annotations
        labels = {} if not labels else labels

        if "metaflow/task_id" in annotations:
            del annotations["metaflow/task_id"]

        control_job = get_control_job(
            client=self._client.get(),
            job_spec=job_spec,
            jobset_main_addr=jobset_main_addr,
            subdomain=subdomain,
            port=port,
            num_parallel=num_parallel,
            namespace=namespace,
            annotations=annotations,
        )
        worker_task_id = TaskIdConstructor.jobset_worker_id(task_id)
        worker_job = get_worker_job(
            client=self._client.get(),
            job_spec=job_spec,
            job_name="worker",
            jobset_main_addr=jobset_main_addr,
            subdomain=subdomain,
            control_task_id=task_id,
            worker_task_id=worker_task_id,
            replicas=num_parallel - 1,
            port=port,
            num_parallel=num_parallel,
            namespace=namespace,
            annotations=annotations,
        )
        worker_jobs = [worker_job]
        # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L163
        _kclient = client.get()
        self._jobset = dict(
            apiVersion=self._group + "/" + self._version,
            kind="JobSet",
            metadata=_kclient.api_client.ApiClient().sanitize_for_serialization(
                _kclient.V1ObjectMeta(
                    name=self.name, labels=labels, annotations=annotations
                )
            ),
            spec=dict(
                replicatedJobs=[control_job] + worker_jobs,
                suspend=False,
                startupPolicy=None,
                successPolicy=None,
                # The failure policy helps set the number of retries for the jobset.
                # It cannot accept a value of 0 for maxRestarts.
                # So the attempt needs to be smartly set.
                # If there is no retry decorator then we do not set maxRestarts and instead we
                # set the attempt statically to 0. Otherwise we make the job pick up the attempt
                # from the `V1EnvVarSource.value_from.V1ObjectFieldSelector.field_path` = "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']"
                # failurePolicy={
                #     "maxRestarts" : 1
                # },
                # This can be set for ArgoWorkflows
                failurePolicy=None,
                network=None,
            ),
            status=None,
        )

    def execute(self):
        client = self._client.get()
        api_instance = client.CoreV1Api()

        with client.ApiClient() as api_client:
            api_instance = client.CustomObjectsApi(api_client)
            try:
                jobset_obj = api_instance.create_namespaced_custom_object(
                    group=self._group,
                    version=self._version,
                    namespace=self._namespace,
                    plural="jobsets",
                    body=self._jobset,
                )
            except Exception as e:
                raise KubernetesJobsetException(
                    "Exception when calling CustomObjectsApi->create_namespaced_custom_object: %s\n"
                    % e
                )

        return RunningJobSet(
            client=self._client,
            name=jobset_obj["metadata"]["name"],
            namespace=jobset_obj["metadata"]["namespace"],
            group=self._group,
            version=self._version,
        )
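
To make the behavior of the new kubernetes_jobsets.py module easier to follow, here is a minimal, illustrative sketch that exercises only its pure helper functions, so no Kubernetes cluster or client object is needed. It assumes the module is importable as shipped in 2.11.16; the manifest dict, the task id "control-task-123", the jobset name "js-0", and the namespace "default" are hypothetical example values, not anything taken from this diff.

# Illustrative sketch only (hypothetical inputs); exercises the pure helpers
# from the new module, so no Kubernetes cluster or client object is required.
from metaflow.plugins.kubernetes.kubernetes_jobsets import (
    TaskIdConstructor,
    _construct_jobset_logical_status,
    _make_domain_name,
)

# A hypothetical JobSet manifest with one control replica and three workers,
# all currently active, shaped the way the status helpers expect.
jobset = {
    "spec": {
        "suspend": False,
        "replicatedJobs": [
            {"name": "control", "replicas": 1},
            {"name": "worker", "replicas": 3},
        ],
    },
    "status": {
        "replicatedJobsStatus": [
            {"name": "control", "active": 1, "succeeded": 0, "failed": 0},
            {"name": "worker", "active": 3, "succeeded": 0, "failed": 0, "suspended": 0},
        ],
        "conditions": [],
    },
}

status = _construct_jobset_logical_status(jobset)  # no control pod passed
print(status.control_started, status.workers_have_started)  # True True
print(status.jobset_finished, status.jobset_failed)  # False False

# Worker task ids are derived from the control task id; the trailing replica
# index is resolved at runtime via the $WORKER_REPLICA_INDEX environment variable.
print(TaskIdConstructor.jobset_worker_id("control-task-123"))
# -> worker-task-123-$WORKER_REPLICA_INDEX

# The FQDN of the control pod that worker pods receive as MASTER_ADDR.
print(_make_domain_name("js-0", "control", 0, 0, "default"))
# -> js-0-control-0-0.js-0.default.svc.cluster.local

The RunningJobSet properties (status, is_running, has_failed, reason) are thin wrappers that apply the same _construct_jobset_logical_status call to the live manifest and control pod fetched from the cluster.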