metaflow 2.11.15__py2.py3-none-any.whl → 2.11.16__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. metaflow/__init__.py +3 -0
  2. metaflow/clone_util.py +6 -0
  3. metaflow/extension_support/plugins.py +2 -0
  4. metaflow/metaflow_config.py +24 -0
  5. metaflow/metaflow_environment.py +2 -2
  6. metaflow/plugins/__init__.py +19 -0
  7. metaflow/plugins/airflow/airflow.py +7 -0
  8. metaflow/plugins/argo/argo_workflows.py +17 -0
  9. metaflow/plugins/azure/__init__.py +3 -0
  10. metaflow/plugins/azure/azure_credential.py +53 -0
  11. metaflow/plugins/azure/azure_exceptions.py +1 -1
  12. metaflow/plugins/azure/azure_secret_manager_secrets_provider.py +240 -0
  13. metaflow/plugins/azure/azure_utils.py +2 -35
  14. metaflow/plugins/azure/blob_service_client_factory.py +4 -2
  15. metaflow/plugins/datastores/azure_storage.py +6 -6
  16. metaflow/plugins/datatools/s3/s3.py +1 -1
  17. metaflow/plugins/gcp/__init__.py +1 -0
  18. metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py +169 -0
  19. metaflow/plugins/gcp/gs_storage_client_factory.py +52 -1
  20. metaflow/plugins/kubernetes/kubernetes.py +85 -8
  21. metaflow/plugins/kubernetes/kubernetes_cli.py +24 -1
  22. metaflow/plugins/kubernetes/kubernetes_client.py +4 -1
  23. metaflow/plugins/kubernetes/kubernetes_decorator.py +49 -4
  24. metaflow/plugins/kubernetes/kubernetes_job.py +208 -201
  25. metaflow/plugins/kubernetes/kubernetes_jobsets.py +784 -0
  26. metaflow/plugins/timeout_decorator.py +2 -1
  27. metaflow/task.py +1 -12
  28. metaflow/tuple_util.py +27 -0
  29. metaflow/util.py +0 -15
  30. metaflow/version.py +1 -1
  31. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/METADATA +2 -2
  32. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/RECORD +36 -31
  33. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/LICENSE +0 -0
  34. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/WHEEL +0 -0
  35. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/entry_points.txt +0 -0
  36. {metaflow-2.11.15.dist-info → metaflow-2.11.16.dist-info}/top_level.txt +0 -0
metaflow/plugins/kubernetes/kubernetes_jobsets.py (added)
@@ -0,0 +1,784 @@
+ import copy
+ import math
+ import random
+ import time
+ from metaflow.metaflow_current import current
+ from metaflow.exception import MetaflowException
+ from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
+ import json
+ from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_VERSION
+ from collections import namedtuple
+
+
+ class KubernetesJobsetException(MetaflowException):
+     headline = "Kubernetes jobset error"
+
+
+ # TODO [DUPLICATE CODE]: Refactor this method to a separate file so that
+ # it can be used by both KubernetesJob and KubernetesJobSet
+ def k8s_retry(deadline_seconds=60, max_backoff=32):
+     def decorator(function):
+         from functools import wraps
+
+         @wraps(function)
+         def wrapper(*args, **kwargs):
+             from kubernetes import client
+
+             deadline = time.time() + deadline_seconds
+             retry_number = 0
+
+             while True:
+                 try:
+                     result = function(*args, **kwargs)
+                     return result
+                 except client.rest.ApiException as e:
+                     if e.status == 500:
+                         current_t = time.time()
+                         backoff_delay = min(
+                             math.pow(2, retry_number) + random.random(), max_backoff
+                         )
+                         if current_t + backoff_delay < deadline:
+                             time.sleep(backoff_delay)
+                             retry_number += 1
+                             continue  # retry again
+                         else:
+                             raise
+                     else:
+                         raise
+
+         return wrapper
+
+     return decorator
+
+
+ JobsetStatus = namedtuple(
+     "JobsetStatus",
+     [
+         "control_pod_failed",  # boolean
+         "control_exit_code",
+         "control_pod_status",  # string like (<pod-status>):(<container-status>) [used for user-messaging]
+         "control_started",
+         "control_completed",
+         "worker_pods_failed",
+         "workers_are_suspended",
+         "workers_have_started",
+         "all_jobs_are_suspended",
+         "jobset_finished",
+         "jobset_failed",
+         "status_unknown",
+         "jobset_was_terminated",
+         "some_jobs_are_running",
+     ],
+ )
+
+
+ def _basic_validation_for_js(jobset):
+     if not jobset.get("status") or not _retrieve_replicated_job_statuses(jobset):
+         return False
+     worker_jobs = [
+         w for w in jobset.get("spec").get("replicatedJobs") if w["name"] == "worker"
+     ]
+     if len(worker_jobs) == 0:
+         raise KubernetesJobsetException("No worker jobs found in the jobset manifest")
+     control_job = [
+         w for w in jobset.get("spec").get("replicatedJobs") if w["name"] == "control"
+     ]
+     if len(control_job) == 0:
+         raise KubernetesJobsetException("No control job found in the jobset manifest")
+     return True
+
+
+ def _derive_pod_status_and_status_code(control_pod):
+     overall_status = None
+     control_exit_code = None
+     control_pod_failed = False
+     if control_pod:
+         container_status = None
+         pod_status = control_pod.get("status", {}).get("phase")
+         container_statuses = control_pod.get("status", {}).get("containerStatuses")
+         if container_statuses is None:
+             container_status = ": ".join(
+                 filter(
+                     None,
+                     [
+                         control_pod.get("status", {}).get("reason"),
+                         control_pod.get("status", {}).get("message"),
+                     ],
+                 )
+             )
+         else:
+             for k, v in container_statuses[0].get("state", {}).items():
+                 if v is not None:
+                     control_exit_code = v.get("exit_code")
+                     container_status = ": ".join(
+                         filter(
+                             None,
+                             [v.get("reason"), v.get("message")],
+                         )
+                     )
+         if container_status is None:
+             overall_status = "pod status: %s" % pod_status
+         else:
+             overall_status = "pod status: %s | container status: %s" % (
+                 pod_status,
+                 container_status,
+             )
+         if pod_status == "Failed":
+             control_pod_failed = True
+     return overall_status, control_exit_code, control_pod_failed
+
+
+ def _retrieve_replicated_job_statuses(jobset):
+     # We needed this abstraction because Jobsets changed their schema
+     # in version v0.3.0, where `ReplicatedJobsStatus` became `replicatedJobsStatus`.
+     # So to handle users having an older version of jobsets, we need to account
+     # for both schemas.
+     if jobset.get("status", {}).get("replicatedJobsStatus", None):
+         return jobset.get("status").get("replicatedJobsStatus")
+     elif jobset.get("status", {}).get("ReplicatedJobsStatus", None):
+         return jobset.get("status").get("ReplicatedJobsStatus")
+     return None
+
+
+ def _construct_jobset_logical_status(jobset, control_pod=None):
+     if not _basic_validation_for_js(jobset):
+         return JobsetStatus(
+             control_started=False,
+             control_completed=False,
+             workers_are_suspended=False,
+             workers_have_started=False,
+             all_jobs_are_suspended=False,
+             jobset_finished=False,
+             jobset_failed=False,
+             status_unknown=True,
+             jobset_was_terminated=False,
+             control_exit_code=None,
+             control_pod_status=None,
+             worker_pods_failed=False,
+             control_pod_failed=False,
+             some_jobs_are_running=False,
+         )
+
+     js_status = jobset.get("status")
+
+     control_started = False
+     control_completed = False
+     workers_are_suspended = False
+     workers_have_started = False
+     all_jobs_are_suspended = jobset.get("spec", {}).get("suspend", False)
+     jobset_finished = False
+     jobset_failed = False
+     status_unknown = False
+     jobset_was_terminated = False
+     worker_pods_failed = False
+     some_jobs_are_running = False
+
+     total_worker_jobs = [
+         w["replicas"]
+         for w in jobset.get("spec").get("replicatedJobs", [])
+         if w["name"] == "worker"
+     ][0]
+     total_control_jobs = [
+         w["replicas"]
+         for w in jobset.get("spec").get("replicatedJobs", [])
+         if w["name"] == "control"
+     ][0]
+
+     if total_worker_jobs == 0 and total_control_jobs == 0:
+         jobset_was_terminated = True
+
+     replicated_job_statuses = _retrieve_replicated_job_statuses(jobset)
+     for job_status in replicated_job_statuses:
+         if job_status["active"] > 0:
+             some_jobs_are_running = True
+
+         if job_status["name"] == "control":
+             control_started = job_status["active"] > 0 or job_status["succeeded"] > 0
+             control_completed = job_status["succeeded"] > 0
+             if job_status["failed"] > 0:
+                 jobset_failed = True
+
+         if job_status["name"] == "worker":
+             workers_have_started = job_status["active"] == total_worker_jobs
+             if "suspended" in job_status:
+                 # `replicatedJobStatus` didn't have a `suspended` field
+                 # until v0.3.0. So we need to account for that.
+                 workers_are_suspended = job_status["suspended"] > 0
+             if job_status["failed"] > 0:
+                 worker_pods_failed = True
+                 jobset_failed = True
+
+     if js_status.get("conditions"):
+         for condition in js_status["conditions"]:
+             if condition["type"] == "Completed":
+                 jobset_finished = True
+             if condition["type"] == "Failed":
+                 jobset_failed = True
+
+     (
+         overall_status,
+         control_exit_code,
+         control_pod_failed,
+     ) = _derive_pod_status_and_status_code(control_pod)
+
+     return JobsetStatus(
+         control_started=control_started,
+         control_completed=control_completed,
+         workers_are_suspended=workers_are_suspended,
+         workers_have_started=workers_have_started,
+         all_jobs_are_suspended=all_jobs_are_suspended,
+         jobset_finished=jobset_finished,
+         jobset_failed=jobset_failed,
+         status_unknown=status_unknown,
+         jobset_was_terminated=jobset_was_terminated,
+         control_exit_code=control_exit_code,
+         control_pod_status=overall_status,
+         worker_pods_failed=worker_pods_failed,
+         control_pod_failed=control_pod_failed,
+         some_jobs_are_running=some_jobs_are_running,
+     )
+
+
+ class RunningJobSet(object):
+     def __init__(self, client, name, namespace, group, version):
+         self._client = client
+         self._name = name
+         self._pod_name = None
+         self._namespace = namespace
+         self._group = group
+         self._version = version
+         self._pod = self._fetch_pod()
+         self._jobset = self._fetch_jobset()
+
+         import atexit
+
+         def best_effort_kill():
+             try:
+                 self.kill()
+             except Exception as ex:
+                 pass
+
+         atexit.register(best_effort_kill)
+
+     def __repr__(self):
+         return "{}('{}/{}')".format(
+             self.__class__.__name__, self._namespace, self._name
+         )
+
+     @k8s_retry()
+     def _fetch_jobset(
+         self,
+     ):
+         # name : name of jobset.
+         # namespace : namespace of the jobset
+         # Query the jobset and return the object's status field as a JSON object
+         client = self._client.get()
+         with client.ApiClient() as api_client:
+             api_instance = client.CustomObjectsApi(api_client)
+             try:
+                 jobset = api_instance.get_namespaced_custom_object(
+                     group=self._group,
+                     version=self._version,
+                     namespace=self._namespace,
+                     plural="jobsets",
+                     name=self._name,
+                 )
+                 return jobset
+             except client.rest.ApiException as e:
+                 if e.status == 404:
+                     raise KubernetesJobsetException(
+                         "Unable to locate Kubernetes jobset %s" % self._name
+                     )
+                 raise
+
+     @k8s_retry()
+     def _fetch_pod(self):
+         # Fetch pod metadata.
+         client = self._client.get()
+         pods = (
+             client.CoreV1Api()
+             .list_namespaced_pod(
+                 namespace=self._namespace,
+                 label_selector="jobset.sigs.k8s.io/jobset-name={}".format(self._name),
+             )
+             .to_dict()["items"]
+         )
+         if pods:
+             for pod in pods:
+                 # Check the labels of the pod to see if
+                 # `jobset.sigs.k8s.io/replicatedjob-name` is set to `control`.
+                 if (
+                     pod["metadata"]["labels"].get(
+                         "jobset.sigs.k8s.io/replicatedjob-name"
+                     )
+                     == "control"
+                 ):
+                     return pod
+         return {}
+
+     def kill(self):
+         plural = "jobsets"
+         client = self._client.get()
+         # Get the jobset
+         with client.ApiClient() as api_client:
+             api_instance = client.CustomObjectsApi(api_client)
+             try:
+                 jobset = api_instance.get_namespaced_custom_object(
+                     group=self._group,
+                     version=self._version,
+                     namespace=self._namespace,
+                     plural="jobsets",
+                     name=self._name,
+                 )
+
+                 # Suspend the jobset and set the replicas to zero.
+                 #
+                 jobset["spec"]["suspend"] = True
+                 for replicated_job in jobset["spec"]["replicatedJobs"]:
+                     replicated_job["replicas"] = 0
+
+                 api_instance.replace_namespaced_custom_object(
+                     group=self._group,
+                     version=self._version,
+                     namespace=self._namespace,
+                     plural=plural,
+                     name=jobset["metadata"]["name"],
+                     body=jobset,
+                 )
+             except Exception as e:
+                 raise KubernetesJobsetException(
+                     "Exception when suspending existing jobset: %s\n" % e
+                 )
+
+     @property
+     def id(self):
+         if self._pod_name:
+             return "pod %s" % self._pod_name
+         if self._pod:
+             self._pod_name = self._pod["metadata"]["name"]
+             return self.id
+         return "jobset %s" % self._name
+
+     @property
+     def is_done(self):
+         def done():
+             return (
+                 self._jobset_is_completed
+                 or self._jobset_has_failed
+                 or self._jobset_was_terminated
+             )
+
+         if not done():
+             # If not done, fetch newer status
+             self._jobset = self._fetch_jobset()
+             self._pod = self._fetch_pod()
+         return done()
+
+     @property
+     def status(self):
+         if self.is_done:
+             return "Jobset is done"
+
+         status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod)
+         if status.status_unknown:
+             return "Jobset status is unknown"
+         if status.control_started:
+             if status.control_pod_status:
+                 return "Jobset is running: %s" % status.control_pod_status
+             return "Jobset is running"
+         if status.all_jobs_are_suspended:
+             return "Jobset is waiting to be unsuspended"
+
+         return "Jobset waiting for jobs to start"
+
+     @property
+     def has_succeeded(self):
+         return self.is_done and self._jobset_is_completed
+
+     @property
+     def has_failed(self):
+         return self.is_done and self._jobset_has_failed
+
+     @property
+     def is_running(self):
+         if self.is_done:
+             return False
+         status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod)
+         if status.some_jobs_are_running:
+             return True
+         return False
+
+     @property
+     def _jobset_was_terminated(self):
+         return _construct_jobset_logical_status(
+             self._jobset, control_pod=self._pod
+         ).jobset_was_terminated
+
+     @property
+     def is_waiting(self):
+         return not self.is_done and not self.is_running
+
+     @property
+     def reason(self):
+         # return exit code and reason
+         if self.is_done and not self.has_succeeded:
+             self._pod = self._fetch_pod()
+         elif self.has_succeeded:
+             return 0, None
+         status = _construct_jobset_logical_status(self._jobset, control_pod=self._pod)
+         if status.control_pod_failed:
+             return (
+                 status.control_exit_code,
+                 "control-pod failed [%s]" % status.control_pod_status,
+             )
+         elif status.worker_pods_failed:
+             return None, "Worker pods failed"
+         return None, None
+
+     @property
+     def _jobset_is_completed(self):
+         return _construct_jobset_logical_status(
+             self._jobset, control_pod=self._pod
+         ).jobset_finished
+
+     @property
+     def _jobset_has_failed(self):
+         return _construct_jobset_logical_status(
+             self._jobset, control_pod=self._pod
+         ).jobset_failed
+
+
+ class TaskIdConstructor:
+     @classmethod
+     def jobset_worker_id(cls, control_task_id: str):
+         return "".join(
+             [control_task_id.replace("control", "worker"), "-", "$WORKER_REPLICA_INDEX"]
+         )
+
+     @classmethod
+     def join_step_task_ids(cls, num_parallel):
+         """
+         Called within the step decorator to set the `flow._control_mapper_tasks`.
+         Setting these allows the flow to know which tasks are needed in the join step.
+         We set this in the `task_pre_step` method of the decorator.
+         """
+         control_task_id = current.task_id
+         worker_task_id_base = control_task_id.replace("control", "worker")
+         mapper = lambda idx: worker_task_id_base + "-%s" % (str(idx))
+         return control_task_id, [mapper(idx) for idx in range(0, num_parallel - 1)]
+
+     @classmethod
+     def argo(cls):
+         pass
+
+
+ def _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel):
+     return [
+         client.V1EnvVar(
+             name="MASTER_ADDR",
+             value=jobset_main_addr,
+         ),
+         client.V1EnvVar(
+             name="MASTER_PORT",
+             value=str(master_port),
+         ),
+         client.V1EnvVar(
+             name="WORLD_SIZE",
+             value=str(num_parallel),
+         ),
+     ] + [
+         client.V1EnvVar(
+             name="JOBSET_RESTART_ATTEMPT",
+             value_from=client.V1EnvVarSource(
+                 field_ref=client.V1ObjectFieldSelector(
+                     field_path="metadata.annotations['jobset.sigs.k8s.io/restart-attempt']"
+                 )
+             ),
+         ),
+         client.V1EnvVar(
+             name="METAFLOW_KUBERNETES_JOBSET_NAME",
+             value_from=client.V1EnvVarSource(
+                 field_ref=client.V1ObjectFieldSelector(
+                     field_path="metadata.annotations['jobset.sigs.k8s.io/jobset-name']"
+                 )
+             ),
+         ),
+         client.V1EnvVar(
+             name="WORKER_REPLICA_INDEX",
+             value_from=client.V1EnvVarSource(
+                 field_ref=client.V1ObjectFieldSelector(
+                     field_path="metadata.annotations['jobset.sigs.k8s.io/job-index']"
+                 )
+             ),
+         ),
+     ]
+
+
+ def get_control_job(
+     client,
+     job_spec,
+     jobset_main_addr,
+     subdomain,
+     port=None,
+     num_parallel=None,
+     namespace=None,
+     annotations=None,
+ ) -> dict:
+     master_port = port
+
+     job_spec = copy.deepcopy(job_spec)
+     job_spec.parallelism = 1
+     job_spec.completions = 1
+     job_spec.template.spec.set_hostname_as_fqdn = True
+     job_spec.template.spec.subdomain = subdomain
+     job_spec.template.metadata.annotations = copy.copy(annotations)
+
+     for idx in range(len(job_spec.template.spec.containers[0].command)):
+         # Check for the ubf_context in the command and
+         # replace the UBF context with the one appropriately matching control/worker.
+         # Since we are passing the `step_cli` one time from the top level to one
+         # KubernetesJobSet, we need to ensure that the UBF context is replaced properly
+         # in all the worker jobs.
+         if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]:
+             job_spec.template.spec.containers[0].command[idx] = (
+                 job_spec.template.spec.containers[0]
+                 .command[idx]
+                 .replace(UBF_CONTROL, UBF_CONTROL + " " + "--split-index 0")
+             )
+
+     job_spec.template.spec.containers[0].env = (
+         job_spec.template.spec.containers[0].env
+         + _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel)
+         + [
+             client.V1EnvVar(
+                 name="CONTROL_INDEX",
+                 value=str(0),
+             )
+         ]
+     )
+
+     # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178
+     return dict(
+         name="control",
+         template=client.api_client.ApiClient().sanitize_for_serialization(
+             client.V1JobTemplateSpec(
+                 metadata=client.V1ObjectMeta(
+                     namespace=namespace,
+                     # We don't set any annotations here
+                     # since they have already been set either in the JobSpec
+                     # or at the JobSet level
+                 ),
+                 spec=job_spec,
+             )
+         ),
+         replicas=1,  # The control job will always have 1 replica.
+     )
+
+
+ def get_worker_job(
+     client,
+     job_spec,
+     job_name,
+     jobset_main_addr,
+     subdomain,
+     control_task_id=None,
+     worker_task_id=None,
+     replicas=1,
+     port=None,
+     num_parallel=None,
+     namespace=None,
+     annotations=None,
+ ) -> dict:
+     master_port = port
+
+     job_spec = copy.deepcopy(job_spec)
+     job_spec.parallelism = 1
+     job_spec.completions = 1
+     job_spec.template.spec.set_hostname_as_fqdn = True
+     job_spec.template.spec.subdomain = subdomain
+     job_spec.template.metadata.annotations = copy.copy(annotations)
+
+     for idx in range(len(job_spec.template.spec.containers[0].command)):
+         if control_task_id in job_spec.template.spec.containers[0].command[idx]:
+             job_spec.template.spec.containers[0].command[idx] = (
+                 job_spec.template.spec.containers[0]
+                 .command[idx]
+                 .replace(control_task_id, worker_task_id)
+             )
+         # Check for the ubf_context in the command and
+         # replace the UBF context with the one appropriately matching control/worker.
+         # Since we are passing the `step_cli` one time from the top level to one
+         # KubernetesJobSet, we need to ensure that the UBF context is replaced properly
+         # in all the worker jobs.
+         if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]:
+             # Since all commands will contain UBF_CONTROL, we need to replace it
+             # with the actual UBF context and also ensure that we are setting the correct
+             # split-index for the worker jobs.
+             split_index_str = "--split-index `expr $[WORKER_REPLICA_INDEX] + 1`"  # This is set via the environment variables below
+             job_spec.template.spec.containers[0].command[idx] = (
+                 job_spec.template.spec.containers[0]
+                 .command[idx]
+                 .replace(UBF_CONTROL, UBF_TASK + " " + split_index_str)
+             )
+
+     job_spec.template.spec.containers[0].env = job_spec.template.spec.containers[
+         0
+     ].env + _jobset_specific_env_vars(
+         client, jobset_main_addr, master_port, num_parallel
+     )
+
+     # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178
+     return dict(
+         name=job_name,
+         template=client.api_client.ApiClient().sanitize_for_serialization(
+             client.V1JobTemplateSpec(
+                 metadata=client.V1ObjectMeta(
+                     namespace=namespace,
+                     # We don't set any annotations here
+                     # since they have already been set either in the JobSpec
+                     # or at the JobSet level
+                 ),
+                 spec=job_spec,
+             )
+         ),
+         replicas=replicas,
+     )
+
+
+ def _make_domain_name(
+     jobset_name, main_job_name, main_job_index, main_pod_index, namespace
+ ):
+     return "%s-%s-%s-%s.%s.%s.svc.cluster.local" % (
+         jobset_name,
+         main_job_name,
+         main_job_index,
+         main_pod_index,
+         jobset_name,
+         namespace,
+     )
+
+
+ class KubernetesJobSet(object):
+     def __init__(
+         self,
+         client,
+         name=None,
+         job_spec=None,
+         namespace=None,
+         num_parallel=None,
+         annotations=None,
+         labels=None,
+         port=None,
+         task_id=None,
+         **kwargs
+     ):
+         self._client = client
+         self._kwargs = kwargs
+         self._group = KUBERNETES_JOBSET_GROUP
+         self._version = KUBERNETES_JOBSET_VERSION
+         self.name = name
+
+         main_job_name = "control"
+         main_job_index = 0
+         main_pod_index = 0
+         subdomain = self.name
+         num_parallel = int(1 if not num_parallel else num_parallel)
+         self._namespace = namespace
+         jobset_main_addr = _make_domain_name(
+             self.name,
+             main_job_name,
+             main_job_index,
+             main_pod_index,
+             self._namespace,
+         )
+
+         annotations = {} if not annotations else annotations
+         labels = {} if not labels else labels
+
+         if "metaflow/task_id" in annotations:
+             del annotations["metaflow/task_id"]
+
+         control_job = get_control_job(
+             client=self._client.get(),
+             job_spec=job_spec,
+             jobset_main_addr=jobset_main_addr,
+             subdomain=subdomain,
+             port=port,
+             num_parallel=num_parallel,
+             namespace=namespace,
+             annotations=annotations,
+         )
+         worker_task_id = TaskIdConstructor.jobset_worker_id(task_id)
+         worker_job = get_worker_job(
+             client=self._client.get(),
+             job_spec=job_spec,
+             job_name="worker",
+             jobset_main_addr=jobset_main_addr,
+             subdomain=subdomain,
+             control_task_id=task_id,
+             worker_task_id=worker_task_id,
+             replicas=num_parallel - 1,
+             port=port,
+             num_parallel=num_parallel,
+             namespace=namespace,
+             annotations=annotations,
+         )
+         worker_jobs = [worker_job]
+         # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L163
+         _kclient = client.get()
+         self._jobset = dict(
+             apiVersion=self._group + "/" + self._version,
+             kind="JobSet",
+             metadata=_kclient.api_client.ApiClient().sanitize_for_serialization(
+                 _kclient.V1ObjectMeta(
+                     name=self.name, labels=labels, annotations=annotations
+                 )
+             ),
+             spec=dict(
+                 replicatedJobs=[control_job] + worker_jobs,
+                 suspend=False,
+                 startupPolicy=None,
+                 successPolicy=None,
+                 # The failure policy sets the number of retries for the jobset,
+                 # but it cannot accept a value of 0 for maxRestarts.
+                 # So the attempt number needs to be set carefully.
+                 # If there is no retry decorator then we do not set maxRestarts and instead
+                 # set the attempt statically to 0. Otherwise we make the job pick up the attempt
+                 # from `V1EnvVarSource.value_from.V1ObjectFieldSelector.field_path` = "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']"
+                 # failurePolicy={
+                 #     "maxRestarts": 1
+                 # },
+                 # This can be set for Argo Workflows.
+                 failurePolicy=None,
+                 network=None,
+             ),
+             status=None,
+         )
+
+     def execute(self):
+         client = self._client.get()
+         api_instance = client.CoreV1Api()
+
+         with client.ApiClient() as api_client:
+             api_instance = client.CustomObjectsApi(api_client)
+             try:
+                 jobset_obj = api_instance.create_namespaced_custom_object(
+                     group=self._group,
+                     version=self._version,
+                     namespace=self._namespace,
+                     plural="jobsets",
+                     body=self._jobset,
+                 )
+             except Exception as e:
+                 raise KubernetesJobsetException(
+                     "Exception when calling CustomObjectsApi->create_namespaced_custom_object: %s\n"
+                     % e
+                 )
+
+         return RunningJobSet(
+             client=self._client,
+             name=jobset_obj["metadata"]["name"],
+             namespace=jobset_obj["metadata"]["namespace"],
+             group=self._group,
+             version=self._version,
+         )
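
For orientation, a minimal sketch of how the new abstractions in this module fit together: KubernetesJobSet builds the JobSet custom object, execute() submits it, and the returned RunningJobSet is polled for status. The `kubernetes_client` and `job_spec` names, as well as the name, task id, and port values, are hypothetical stand-ins for objects that metaflow's Kubernetes plugin constructs elsewhere; this is not part of the released package.

import time
from metaflow.plugins.kubernetes.kubernetes_jobsets import KubernetesJobSet

# Illustrative only: kubernetes_client stands in for metaflow's client wrapper
# and job_spec for a V1JobSpec built by the Kubernetes plugin.
running = KubernetesJobSet(
    client=kubernetes_client,
    name="js-0123abcd",
    job_spec=job_spec,
    namespace="default",
    num_parallel=4,            # 1 control replicated job + 3 worker replicas
    task_id="control-task-42",
    port=3339,
).execute()                    # creates the JobSet custom object

while not running.is_done:     # is_done re-fetches the JobSet and control pod
    print(running.status)
    time.sleep(10)

exit_code, reason = running.reason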