metaflow 2.12.7__py2.py3-none-any.whl → 2.12.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metaflow/__init__.py +2 -0
- metaflow/cli.py +12 -4
- metaflow/extension_support/plugins.py +1 -0
- metaflow/flowspec.py +8 -1
- metaflow/lint.py +13 -0
- metaflow/metaflow_current.py +0 -8
- metaflow/plugins/__init__.py +12 -0
- metaflow/plugins/argo/argo_workflows.py +462 -42
- metaflow/plugins/argo/argo_workflows_cli.py +60 -3
- metaflow/plugins/argo/argo_workflows_decorator.py +38 -7
- metaflow/plugins/argo/argo_workflows_deployer.py +290 -0
- metaflow/plugins/argo/jobset_input_paths.py +16 -0
- metaflow/plugins/aws/batch/batch_decorator.py +16 -13
- metaflow/plugins/aws/step_functions/step_functions_cli.py +45 -3
- metaflow/plugins/aws/step_functions/step_functions_deployer.py +251 -0
- metaflow/plugins/cards/card_cli.py +1 -1
- metaflow/plugins/kubernetes/kubernetes.py +279 -52
- metaflow/plugins/kubernetes/kubernetes_cli.py +26 -8
- metaflow/plugins/kubernetes/kubernetes_client.py +0 -1
- metaflow/plugins/kubernetes/kubernetes_decorator.py +56 -44
- metaflow/plugins/kubernetes/kubernetes_job.py +6 -6
- metaflow/plugins/kubernetes/kubernetes_jobsets.py +510 -272
- metaflow/plugins/parallel_decorator.py +108 -8
- metaflow/plugins/pypi/bootstrap.py +1 -1
- metaflow/plugins/pypi/micromamba.py +1 -1
- metaflow/plugins/secrets/secrets_decorator.py +12 -3
- metaflow/plugins/test_unbounded_foreach_decorator.py +39 -4
- metaflow/runner/deployer.py +386 -0
- metaflow/runner/metaflow_runner.py +1 -20
- metaflow/runner/nbdeploy.py +130 -0
- metaflow/runner/nbrun.py +4 -28
- metaflow/runner/utils.py +49 -0
- metaflow/runtime.py +246 -134
- metaflow/version.py +1 -1
- {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/METADATA +2 -2
- {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/RECORD +40 -34
- {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/WHEEL +1 -1
- {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/LICENSE +0 -0
- {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/entry_points.txt +0 -0
- {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,14 @@
|
|
1
1
|
import copy
|
2
|
+
import json
|
2
3
|
import math
|
3
4
|
import random
|
4
5
|
import time
|
5
|
-
from
|
6
|
+
from collections import namedtuple
|
7
|
+
|
6
8
|
from metaflow.exception import MetaflowException
|
7
|
-
from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
|
8
|
-
import json
|
9
9
|
from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_VERSION
|
10
|
-
from
|
10
|
+
from metaflow.tracing import inject_tracing_vars
|
11
|
+
from metaflow.metaflow_config import KUBERNETES_SECRETS
|
11
12
|
|
12
13
|
|
13
14
|
class KubernetesJobsetException(MetaflowException):
|
@@ -51,6 +52,8 @@ def k8s_retry(deadline_seconds=60, max_backoff=32):
|
|
51
52
|
return decorator
|
52
53
|
|
53
54
|
|
55
|
+
CONTROL_JOB_NAME = "control"
|
56
|
+
|
54
57
|
JobsetStatus = namedtuple(
|
55
58
|
"JobsetStatus",
|
56
59
|
[
|
@@ -323,18 +326,18 @@ class RunningJobSet(object):
|
|
323
326
|
with client.ApiClient() as api_client:
|
324
327
|
api_instance = client.CustomObjectsApi(api_client)
|
325
328
|
try:
|
326
|
-
|
329
|
+
obj = api_instance.get_namespaced_custom_object(
|
327
330
|
group=self._group,
|
328
331
|
version=self._version,
|
329
332
|
namespace=self._namespace,
|
330
|
-
plural=
|
333
|
+
plural=plural,
|
331
334
|
name=self._name,
|
332
335
|
)
|
333
336
|
|
334
337
|
# Suspend the jobset and set the replica's to Zero.
|
335
338
|
#
|
336
|
-
|
337
|
-
for replicated_job in
|
339
|
+
obj["spec"]["suspend"] = True
|
340
|
+
for replicated_job in obj["spec"]["replicatedJobs"]:
|
338
341
|
replicated_job["replicas"] = 0
|
339
342
|
|
340
343
|
api_instance.replace_namespaced_custom_object(
|
@@ -342,8 +345,8 @@ class RunningJobSet(object):
|
|
342
345
|
version=self._version,
|
343
346
|
namespace=self._namespace,
|
344
347
|
plural=plural,
|
345
|
-
name=
|
346
|
-
body=
|
348
|
+
name=obj["metadata"]["name"],
|
349
|
+
body=obj,
|
347
350
|
)
|
348
351
|
except Exception as e:
|
349
352
|
raise KubernetesJobsetException(
|
@@ -448,203 +451,6 @@ class RunningJobSet(object):
|
|
448
451
|
).jobset_failed
|
449
452
|
|
450
453
|
|
451
|
-
class TaskIdConstructor:
|
452
|
-
@classmethod
|
453
|
-
def jobset_worker_id(cls, control_task_id: str):
|
454
|
-
return "".join(
|
455
|
-
[control_task_id.replace("control", "worker"), "-", "$WORKER_REPLICA_INDEX"]
|
456
|
-
)
|
457
|
-
|
458
|
-
@classmethod
|
459
|
-
def join_step_task_ids(cls, num_parallel):
|
460
|
-
"""
|
461
|
-
Called within the step decorator to set the `flow._control_mapper_tasks`.
|
462
|
-
Setting these allows the flow to know which tasks are needed in the join step.
|
463
|
-
We set this in the `task_pre_step` method of the decorator.
|
464
|
-
"""
|
465
|
-
control_task_id = current.task_id
|
466
|
-
worker_task_id_base = control_task_id.replace("control", "worker")
|
467
|
-
mapper = lambda idx: worker_task_id_base + "-%s" % (str(idx))
|
468
|
-
return control_task_id, [mapper(idx) for idx in range(0, num_parallel - 1)]
|
469
|
-
|
470
|
-
@classmethod
|
471
|
-
def argo(cls):
|
472
|
-
pass
|
473
|
-
|
474
|
-
|
475
|
-
def _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel):
|
476
|
-
return [
|
477
|
-
client.V1EnvVar(
|
478
|
-
name="MASTER_ADDR",
|
479
|
-
value=jobset_main_addr,
|
480
|
-
),
|
481
|
-
client.V1EnvVar(
|
482
|
-
name="MASTER_PORT",
|
483
|
-
value=str(master_port),
|
484
|
-
),
|
485
|
-
client.V1EnvVar(
|
486
|
-
name="WORLD_SIZE",
|
487
|
-
value=str(num_parallel),
|
488
|
-
),
|
489
|
-
] + [
|
490
|
-
client.V1EnvVar(
|
491
|
-
name="JOBSET_RESTART_ATTEMPT",
|
492
|
-
value_from=client.V1EnvVarSource(
|
493
|
-
field_ref=client.V1ObjectFieldSelector(
|
494
|
-
field_path="metadata.annotations['jobset.sigs.k8s.io/restart-attempt']"
|
495
|
-
)
|
496
|
-
),
|
497
|
-
),
|
498
|
-
client.V1EnvVar(
|
499
|
-
name="METAFLOW_KUBERNETES_JOBSET_NAME",
|
500
|
-
value_from=client.V1EnvVarSource(
|
501
|
-
field_ref=client.V1ObjectFieldSelector(
|
502
|
-
field_path="metadata.annotations['jobset.sigs.k8s.io/jobset-name']"
|
503
|
-
)
|
504
|
-
),
|
505
|
-
),
|
506
|
-
client.V1EnvVar(
|
507
|
-
name="WORKER_REPLICA_INDEX",
|
508
|
-
value_from=client.V1EnvVarSource(
|
509
|
-
field_ref=client.V1ObjectFieldSelector(
|
510
|
-
field_path="metadata.annotations['jobset.sigs.k8s.io/job-index']"
|
511
|
-
)
|
512
|
-
),
|
513
|
-
),
|
514
|
-
]
|
515
|
-
|
516
|
-
|
517
|
-
def get_control_job(
|
518
|
-
client,
|
519
|
-
job_spec,
|
520
|
-
jobset_main_addr,
|
521
|
-
subdomain,
|
522
|
-
port=None,
|
523
|
-
num_parallel=None,
|
524
|
-
namespace=None,
|
525
|
-
annotations=None,
|
526
|
-
) -> dict:
|
527
|
-
master_port = port
|
528
|
-
|
529
|
-
job_spec = copy.deepcopy(job_spec)
|
530
|
-
job_spec.parallelism = 1
|
531
|
-
job_spec.completions = 1
|
532
|
-
job_spec.template.spec.set_hostname_as_fqdn = True
|
533
|
-
job_spec.template.spec.subdomain = subdomain
|
534
|
-
job_spec.template.metadata.annotations = copy.copy(annotations)
|
535
|
-
|
536
|
-
for idx in range(len(job_spec.template.spec.containers[0].command)):
|
537
|
-
# CHECK FOR THE ubf_context in the command.
|
538
|
-
# Replace the UBF context to the one appropriately matching control/worker.
|
539
|
-
# Since we are passing the `step_cli` one time from the top level to one
|
540
|
-
# KuberentesJobSet, we need to ensure that UBF context is replaced properly
|
541
|
-
# in all the worker jobs.
|
542
|
-
if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]:
|
543
|
-
job_spec.template.spec.containers[0].command[idx] = (
|
544
|
-
job_spec.template.spec.containers[0]
|
545
|
-
.command[idx]
|
546
|
-
.replace(UBF_CONTROL, UBF_CONTROL + " " + "--split-index 0")
|
547
|
-
)
|
548
|
-
|
549
|
-
job_spec.template.spec.containers[0].env = (
|
550
|
-
job_spec.template.spec.containers[0].env
|
551
|
-
+ _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel)
|
552
|
-
+ [
|
553
|
-
client.V1EnvVar(
|
554
|
-
name="CONTROL_INDEX",
|
555
|
-
value=str(0),
|
556
|
-
)
|
557
|
-
]
|
558
|
-
)
|
559
|
-
|
560
|
-
# Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178
|
561
|
-
return dict(
|
562
|
-
name="control",
|
563
|
-
template=client.api_client.ApiClient().sanitize_for_serialization(
|
564
|
-
client.V1JobTemplateSpec(
|
565
|
-
metadata=client.V1ObjectMeta(
|
566
|
-
namespace=namespace,
|
567
|
-
# We don't set any annotations here
|
568
|
-
# since they have been either set in the JobSpec
|
569
|
-
# or on the JobSet level
|
570
|
-
),
|
571
|
-
spec=job_spec,
|
572
|
-
)
|
573
|
-
),
|
574
|
-
replicas=1, # The control job will always have 1 replica.
|
575
|
-
)
|
576
|
-
|
577
|
-
|
578
|
-
def get_worker_job(
|
579
|
-
client,
|
580
|
-
job_spec,
|
581
|
-
job_name,
|
582
|
-
jobset_main_addr,
|
583
|
-
subdomain,
|
584
|
-
control_task_id=None,
|
585
|
-
worker_task_id=None,
|
586
|
-
replicas=1,
|
587
|
-
port=None,
|
588
|
-
num_parallel=None,
|
589
|
-
namespace=None,
|
590
|
-
annotations=None,
|
591
|
-
) -> dict:
|
592
|
-
master_port = port
|
593
|
-
|
594
|
-
job_spec = copy.deepcopy(job_spec)
|
595
|
-
job_spec.parallelism = 1
|
596
|
-
job_spec.completions = 1
|
597
|
-
job_spec.template.spec.set_hostname_as_fqdn = True
|
598
|
-
job_spec.template.spec.subdomain = subdomain
|
599
|
-
job_spec.template.metadata.annotations = copy.copy(annotations)
|
600
|
-
|
601
|
-
for idx in range(len(job_spec.template.spec.containers[0].command)):
|
602
|
-
if control_task_id in job_spec.template.spec.containers[0].command[idx]:
|
603
|
-
job_spec.template.spec.containers[0].command[idx] = (
|
604
|
-
job_spec.template.spec.containers[0]
|
605
|
-
.command[idx]
|
606
|
-
.replace(control_task_id, worker_task_id)
|
607
|
-
)
|
608
|
-
# CHECK FOR THE ubf_context in the command.
|
609
|
-
# Replace the UBF context to the one appropriately matching control/worker.
|
610
|
-
# Since we are passing the `step_cli` one time from the top level to one
|
611
|
-
# KuberentesJobSet, we need to ensure that UBF context is replaced properly
|
612
|
-
# in all the worker jobs.
|
613
|
-
if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]:
|
614
|
-
# Since all command will have a UBF_CONTROL, we need to replace the UBF_CONTROL
|
615
|
-
# with the actual UBF Context and also ensure that we are setting the correct
|
616
|
-
# split-index for the worker jobs.
|
617
|
-
split_index_str = "--split-index `expr $[WORKER_REPLICA_INDEX] + 1`" # This set in the environment variables below
|
618
|
-
job_spec.template.spec.containers[0].command[idx] = (
|
619
|
-
job_spec.template.spec.containers[0]
|
620
|
-
.command[idx]
|
621
|
-
.replace(UBF_CONTROL, UBF_TASK + " " + split_index_str)
|
622
|
-
)
|
623
|
-
|
624
|
-
job_spec.template.spec.containers[0].env = job_spec.template.spec.containers[
|
625
|
-
0
|
626
|
-
].env + _jobset_specific_env_vars(
|
627
|
-
client, jobset_main_addr, master_port, num_parallel
|
628
|
-
)
|
629
|
-
|
630
|
-
# Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178
|
631
|
-
return dict(
|
632
|
-
name=job_name,
|
633
|
-
template=client.api_client.ApiClient().sanitize_for_serialization(
|
634
|
-
client.V1JobTemplateSpec(
|
635
|
-
metadata=client.V1ObjectMeta(
|
636
|
-
namespace=namespace,
|
637
|
-
# We don't set any annotations here
|
638
|
-
# since they have been either set in the JobSpec
|
639
|
-
# or on the JobSet level
|
640
|
-
),
|
641
|
-
spec=job_spec,
|
642
|
-
)
|
643
|
-
),
|
644
|
-
replicas=replicas,
|
645
|
-
)
|
646
|
-
|
647
|
-
|
648
454
|
def _make_domain_name(
|
649
455
|
jobset_name, main_job_name, main_job_index, main_pod_index, namespace
|
650
456
|
):
|
@@ -658,97 +464,413 @@ def _make_domain_name(
|
|
658
464
|
)
|
659
465
|
|
660
466
|
|
467
|
+
class JobSetSpec(object):
|
468
|
+
def __init__(self, kubernetes_sdk, name, **kwargs):
|
469
|
+
self._kubernetes_sdk = kubernetes_sdk
|
470
|
+
self._kwargs = kwargs
|
471
|
+
self.name = name
|
472
|
+
|
473
|
+
def replicas(self, replicas):
|
474
|
+
self._kwargs["replicas"] = replicas
|
475
|
+
return self
|
476
|
+
|
477
|
+
def step_name(self, step_name):
|
478
|
+
self._kwargs["step_name"] = step_name
|
479
|
+
return self
|
480
|
+
|
481
|
+
def namespace(self, namespace):
|
482
|
+
self._kwargs["namespace"] = namespace
|
483
|
+
return self
|
484
|
+
|
485
|
+
def command(self, command):
|
486
|
+
self._kwargs["command"] = command
|
487
|
+
return self
|
488
|
+
|
489
|
+
def image(self, image):
|
490
|
+
self._kwargs["image"] = image
|
491
|
+
return self
|
492
|
+
|
493
|
+
def cpu(self, cpu):
|
494
|
+
self._kwargs["cpu"] = cpu
|
495
|
+
return self
|
496
|
+
|
497
|
+
def memory(self, mem):
|
498
|
+
self._kwargs["memory"] = mem
|
499
|
+
return self
|
500
|
+
|
501
|
+
def environment_variable(self, name, value):
|
502
|
+
# Never set to None
|
503
|
+
if value is None:
|
504
|
+
return self
|
505
|
+
self._kwargs["environment_variables"] = dict(
|
506
|
+
self._kwargs.get("environment_variables", {}), **{name: value}
|
507
|
+
)
|
508
|
+
return self
|
509
|
+
|
510
|
+
def secret(self, name):
|
511
|
+
if name is None:
|
512
|
+
return self
|
513
|
+
if len(self._kwargs.get("secrets", [])) == 0:
|
514
|
+
self._kwargs["secrets"] = []
|
515
|
+
self._kwargs["secrets"] = list(set(self._kwargs["secrets"] + [name]))
|
516
|
+
|
517
|
+
def environment_variable_from_selector(self, name, label_value):
|
518
|
+
# Never set to None
|
519
|
+
if label_value is None:
|
520
|
+
return self
|
521
|
+
self._kwargs["environment_variables_from_selectors"] = dict(
|
522
|
+
self._kwargs.get("environment_variables_from_selectors", {}),
|
523
|
+
**{name: label_value}
|
524
|
+
)
|
525
|
+
return self
|
526
|
+
|
527
|
+
def label(self, name, value):
|
528
|
+
self._kwargs["labels"] = dict(self._kwargs.get("labels", {}), **{name: value})
|
529
|
+
return self
|
530
|
+
|
531
|
+
def annotation(self, name, value):
|
532
|
+
self._kwargs["annotations"] = dict(
|
533
|
+
self._kwargs.get("annotations", {}), **{name: value}
|
534
|
+
)
|
535
|
+
return self
|
536
|
+
|
537
|
+
def dump(self):
|
538
|
+
client = self._kubernetes_sdk
|
539
|
+
use_tmpfs = self._kwargs["use_tmpfs"]
|
540
|
+
tmpfs_size = self._kwargs["tmpfs_size"]
|
541
|
+
tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)
|
542
|
+
shared_memory = (
|
543
|
+
int(self._kwargs["shared_memory"])
|
544
|
+
if self._kwargs["shared_memory"]
|
545
|
+
else None
|
546
|
+
)
|
547
|
+
|
548
|
+
return dict(
|
549
|
+
name=self.name,
|
550
|
+
template=client.api_client.ApiClient().sanitize_for_serialization(
|
551
|
+
client.V1JobTemplateSpec(
|
552
|
+
metadata=client.V1ObjectMeta(
|
553
|
+
namespace=self._kwargs["namespace"],
|
554
|
+
# We don't set any annotations here
|
555
|
+
# since they have been either set in the JobSpec
|
556
|
+
# or on the JobSet level
|
557
|
+
),
|
558
|
+
spec=client.V1JobSpec(
|
559
|
+
# Retries are handled by Metaflow when it is responsible for
|
560
|
+
# executing the flow. The responsibility is moved to Kubernetes
|
561
|
+
# when Argo Workflows is responsible for the execution.
|
562
|
+
backoff_limit=self._kwargs.get("retries", 0),
|
563
|
+
completions=1,
|
564
|
+
parallelism=1,
|
565
|
+
ttl_seconds_after_finished=7
|
566
|
+
* 60
|
567
|
+
* 60 # Remove job after a week. TODO: Make this configurable
|
568
|
+
* 24,
|
569
|
+
template=client.V1PodTemplateSpec(
|
570
|
+
metadata=client.V1ObjectMeta(
|
571
|
+
annotations=self._kwargs.get("annotations", {}),
|
572
|
+
labels=self._kwargs.get("labels", {}),
|
573
|
+
namespace=self._kwargs["namespace"],
|
574
|
+
),
|
575
|
+
spec=client.V1PodSpec(
|
576
|
+
## --- jobset require podspec deets start----
|
577
|
+
subdomain=self._kwargs["subdomain"],
|
578
|
+
set_hostname_as_fqdn=True,
|
579
|
+
## --- jobset require podspec deets end ----
|
580
|
+
# Timeout is set on the pod and not the job (important!)
|
581
|
+
active_deadline_seconds=self._kwargs[
|
582
|
+
"timeout_in_seconds"
|
583
|
+
],
|
584
|
+
# TODO (savin): Enable affinities for GPU scheduling.
|
585
|
+
# affinity=?,
|
586
|
+
containers=[
|
587
|
+
client.V1Container(
|
588
|
+
command=self._kwargs["command"],
|
589
|
+
ports=[]
|
590
|
+
if self._kwargs["port"] is None
|
591
|
+
else [
|
592
|
+
client.V1ContainerPort(
|
593
|
+
container_port=int(self._kwargs["port"])
|
594
|
+
)
|
595
|
+
],
|
596
|
+
env=[
|
597
|
+
client.V1EnvVar(name=k, value=str(v))
|
598
|
+
for k, v in self._kwargs.get(
|
599
|
+
"environment_variables", {}
|
600
|
+
).items()
|
601
|
+
]
|
602
|
+
# And some downward API magic. Add (key, value)
|
603
|
+
# pairs below to make pod metadata available
|
604
|
+
# within Kubernetes container.
|
605
|
+
+ [
|
606
|
+
client.V1EnvVar(
|
607
|
+
name=k,
|
608
|
+
value_from=client.V1EnvVarSource(
|
609
|
+
field_ref=client.V1ObjectFieldSelector(
|
610
|
+
field_path=str(v)
|
611
|
+
)
|
612
|
+
),
|
613
|
+
)
|
614
|
+
for k, v in self._kwargs.get(
|
615
|
+
"environment_variables_from_selectors",
|
616
|
+
{},
|
617
|
+
).items()
|
618
|
+
]
|
619
|
+
+ [
|
620
|
+
client.V1EnvVar(name=k, value=str(v))
|
621
|
+
for k, v in inject_tracing_vars({}).items()
|
622
|
+
],
|
623
|
+
env_from=[
|
624
|
+
client.V1EnvFromSource(
|
625
|
+
secret_ref=client.V1SecretEnvSource(
|
626
|
+
name=str(k),
|
627
|
+
# optional=True
|
628
|
+
)
|
629
|
+
)
|
630
|
+
for k in list(
|
631
|
+
self._kwargs.get("secrets", [])
|
632
|
+
)
|
633
|
+
if k
|
634
|
+
],
|
635
|
+
image=self._kwargs["image"],
|
636
|
+
image_pull_policy=self._kwargs[
|
637
|
+
"image_pull_policy"
|
638
|
+
],
|
639
|
+
name=self._kwargs["step_name"].replace(
|
640
|
+
"_", "-"
|
641
|
+
),
|
642
|
+
resources=client.V1ResourceRequirements(
|
643
|
+
requests={
|
644
|
+
"cpu": str(self._kwargs["cpu"]),
|
645
|
+
"memory": "%sM"
|
646
|
+
% str(self._kwargs["memory"]),
|
647
|
+
"ephemeral-storage": "%sM"
|
648
|
+
% str(self._kwargs["disk"]),
|
649
|
+
},
|
650
|
+
limits={
|
651
|
+
"%s.com/gpu".lower()
|
652
|
+
% self._kwargs["gpu_vendor"]: str(
|
653
|
+
self._kwargs["gpu"]
|
654
|
+
)
|
655
|
+
for k in [0]
|
656
|
+
# Don't set GPU limits if gpu isn't specified.
|
657
|
+
if self._kwargs["gpu"] is not None
|
658
|
+
},
|
659
|
+
),
|
660
|
+
volume_mounts=(
|
661
|
+
[
|
662
|
+
client.V1VolumeMount(
|
663
|
+
mount_path=self._kwargs.get(
|
664
|
+
"tmpfs_path"
|
665
|
+
),
|
666
|
+
name="tmpfs-ephemeral-volume",
|
667
|
+
)
|
668
|
+
]
|
669
|
+
if tmpfs_enabled
|
670
|
+
else []
|
671
|
+
)
|
672
|
+
+ (
|
673
|
+
[
|
674
|
+
client.V1VolumeMount(
|
675
|
+
mount_path="/dev/shm", name="dhsm"
|
676
|
+
)
|
677
|
+
]
|
678
|
+
if shared_memory
|
679
|
+
else []
|
680
|
+
)
|
681
|
+
+ (
|
682
|
+
[
|
683
|
+
client.V1VolumeMount(
|
684
|
+
mount_path=path, name=claim
|
685
|
+
)
|
686
|
+
for claim, path in self._kwargs[
|
687
|
+
"persistent_volume_claims"
|
688
|
+
].items()
|
689
|
+
]
|
690
|
+
if self._kwargs["persistent_volume_claims"]
|
691
|
+
is not None
|
692
|
+
else []
|
693
|
+
),
|
694
|
+
)
|
695
|
+
],
|
696
|
+
node_selector=self._kwargs.get("node_selector"),
|
697
|
+
# TODO (savin): Support image_pull_secrets
|
698
|
+
# image_pull_secrets=?,
|
699
|
+
# TODO (savin): Support preemption policies
|
700
|
+
# preemption_policy=?,
|
701
|
+
#
|
702
|
+
# A Container in a Pod may fail for a number of
|
703
|
+
# reasons, such as because the process in it exited
|
704
|
+
# with a non-zero exit code, or the Container was
|
705
|
+
# killed due to OOM etc. If this happens, fail the pod
|
706
|
+
# and let Metaflow handle the retries.
|
707
|
+
restart_policy="Never",
|
708
|
+
service_account_name=self._kwargs["service_account"],
|
709
|
+
# Terminate the container immediately on SIGTERM
|
710
|
+
termination_grace_period_seconds=0,
|
711
|
+
tolerations=[
|
712
|
+
client.V1Toleration(**toleration)
|
713
|
+
for toleration in self._kwargs.get("tolerations")
|
714
|
+
or []
|
715
|
+
],
|
716
|
+
volumes=(
|
717
|
+
[
|
718
|
+
client.V1Volume(
|
719
|
+
name="tmpfs-ephemeral-volume",
|
720
|
+
empty_dir=client.V1EmptyDirVolumeSource(
|
721
|
+
medium="Memory",
|
722
|
+
# Add default unit as ours differs from Kubernetes default.
|
723
|
+
size_limit="{}Mi".format(tmpfs_size),
|
724
|
+
),
|
725
|
+
)
|
726
|
+
]
|
727
|
+
if tmpfs_enabled
|
728
|
+
else []
|
729
|
+
)
|
730
|
+
+ (
|
731
|
+
[
|
732
|
+
client.V1Volume(
|
733
|
+
name="dhsm",
|
734
|
+
empty_dir=client.V1EmptyDirVolumeSource(
|
735
|
+
medium="Memory",
|
736
|
+
size_limit="{}Mi".format(shared_memory),
|
737
|
+
),
|
738
|
+
)
|
739
|
+
]
|
740
|
+
if shared_memory
|
741
|
+
else []
|
742
|
+
)
|
743
|
+
+ (
|
744
|
+
[
|
745
|
+
client.V1Volume(
|
746
|
+
name=claim,
|
747
|
+
persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
|
748
|
+
claim_name=claim
|
749
|
+
),
|
750
|
+
)
|
751
|
+
for claim in self._kwargs[
|
752
|
+
"persistent_volume_claims"
|
753
|
+
].keys()
|
754
|
+
]
|
755
|
+
if self._kwargs["persistent_volume_claims"]
|
756
|
+
is not None
|
757
|
+
else []
|
758
|
+
),
|
759
|
+
# TODO (savin): Set termination_message_policy
|
760
|
+
),
|
761
|
+
),
|
762
|
+
),
|
763
|
+
)
|
764
|
+
),
|
765
|
+
replicas=self._kwargs["replicas"],
|
766
|
+
)
|
767
|
+
|
768
|
+
|
661
769
|
class KubernetesJobSet(object):
|
662
770
|
def __init__(
|
663
771
|
self,
|
664
772
|
client,
|
665
773
|
name=None,
|
666
|
-
job_spec=None,
|
667
774
|
namespace=None,
|
668
775
|
num_parallel=None,
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
776
|
+
# explcitly declaring num_parallel because we need to ensure that
|
777
|
+
# num_parallel is an INTEGER and this abstraction is called by the
|
778
|
+
# local runtime abstraction of kubernetes.
|
779
|
+
# Argo will call another abstraction that will allow setting a lot of these
|
780
|
+
# values from the top level argo code.
|
673
781
|
**kwargs
|
674
782
|
):
|
675
783
|
self._client = client
|
676
|
-
self.
|
784
|
+
self._annotations = {}
|
785
|
+
self._labels = {}
|
677
786
|
self._group = KUBERNETES_JOBSET_GROUP
|
678
787
|
self._version = KUBERNETES_JOBSET_VERSION
|
788
|
+
self._namespace = namespace
|
679
789
|
self.name = name
|
680
790
|
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
jobset_main_addr = _make_domain_name(
|
688
|
-
self.name,
|
689
|
-
main_job_name,
|
690
|
-
main_job_index,
|
691
|
-
main_pod_index,
|
692
|
-
self._namespace,
|
791
|
+
self._jobset_control_addr = _make_domain_name(
|
792
|
+
name,
|
793
|
+
CONTROL_JOB_NAME,
|
794
|
+
0,
|
795
|
+
0,
|
796
|
+
namespace,
|
693
797
|
)
|
694
798
|
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
if "metaflow/task_id" in annotations:
|
699
|
-
del annotations["metaflow/task_id"]
|
700
|
-
|
701
|
-
control_job = get_control_job(
|
702
|
-
client=self._client.get(),
|
703
|
-
job_spec=job_spec,
|
704
|
-
jobset_main_addr=jobset_main_addr,
|
705
|
-
subdomain=subdomain,
|
706
|
-
port=port,
|
707
|
-
num_parallel=num_parallel,
|
708
|
-
namespace=namespace,
|
709
|
-
annotations=annotations,
|
799
|
+
self._control_spec = JobSetSpec(
|
800
|
+
client.get(), name=CONTROL_JOB_NAME, namespace=namespace, **kwargs
|
710
801
|
)
|
711
|
-
|
712
|
-
|
713
|
-
client=self._client.get(),
|
714
|
-
job_spec=job_spec,
|
715
|
-
job_name="worker",
|
716
|
-
jobset_main_addr=jobset_main_addr,
|
717
|
-
subdomain=subdomain,
|
718
|
-
control_task_id=task_id,
|
719
|
-
worker_task_id=worker_task_id,
|
720
|
-
replicas=num_parallel - 1,
|
721
|
-
port=port,
|
722
|
-
num_parallel=num_parallel,
|
723
|
-
namespace=namespace,
|
724
|
-
annotations=annotations,
|
802
|
+
self._worker_spec = JobSetSpec(
|
803
|
+
client.get(), name="worker", namespace=namespace, **kwargs
|
725
804
|
)
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
805
|
+
assert (
|
806
|
+
type(num_parallel) == int
|
807
|
+
), "num_parallel must be an integer" # todo: [final-refactor] : fix-me
|
808
|
+
|
809
|
+
@property
|
810
|
+
def jobset_control_addr(self):
|
811
|
+
return self._jobset_control_addr
|
812
|
+
|
813
|
+
@property
|
814
|
+
def worker(self):
|
815
|
+
return self._worker_spec
|
816
|
+
|
817
|
+
@property
|
818
|
+
def control(self):
|
819
|
+
return self._control_spec
|
820
|
+
|
821
|
+
def environment_variable_from_selector(self, name, label_value):
|
822
|
+
self.worker.environment_variable_from_selector(name, label_value)
|
823
|
+
self.control.environment_variable_from_selector(name, label_value)
|
824
|
+
return self
|
825
|
+
|
826
|
+
def environment_variables_from_selectors(self, env_dict):
|
827
|
+
for name, label_value in env_dict.items():
|
828
|
+
self.worker.environment_variable_from_selector(name, label_value)
|
829
|
+
self.control.environment_variable_from_selector(name, label_value)
|
830
|
+
return self
|
831
|
+
|
832
|
+
def environment_variable(self, name, value):
|
833
|
+
self.worker.environment_variable(name, value)
|
834
|
+
self.control.environment_variable(name, value)
|
835
|
+
return self
|
836
|
+
|
837
|
+
def label(self, name, value):
|
838
|
+
self.worker.label(name, value)
|
839
|
+
self.control.label(name, value)
|
840
|
+
self._labels = dict(self._labels, **{name: value})
|
841
|
+
return self
|
842
|
+
|
843
|
+
def annotation(self, name, value):
|
844
|
+
self.worker.annotation(name, value)
|
845
|
+
self.control.annotation(name, value)
|
846
|
+
self._annotations = dict(self._annotations, **{name: value})
|
847
|
+
return self
|
848
|
+
|
849
|
+
def secret(self, name):
|
850
|
+
self.worker.secret(name)
|
851
|
+
self.control.secret(name)
|
852
|
+
return self
|
853
|
+
|
854
|
+
def dump(self):
|
855
|
+
client = self._client.get()
|
856
|
+
return dict(
|
730
857
|
apiVersion=self._group + "/" + self._version,
|
731
858
|
kind="JobSet",
|
732
|
-
metadata=
|
733
|
-
|
734
|
-
name=self.name,
|
859
|
+
metadata=client.api_client.ApiClient().sanitize_for_serialization(
|
860
|
+
client.V1ObjectMeta(
|
861
|
+
name=self.name,
|
862
|
+
labels=self._labels,
|
863
|
+
annotations=self._annotations,
|
735
864
|
)
|
736
865
|
),
|
737
866
|
spec=dict(
|
738
|
-
replicatedJobs=[
|
867
|
+
replicatedJobs=[self.control.dump(), self.worker.dump()],
|
739
868
|
suspend=False,
|
740
869
|
startupPolicy=None,
|
741
870
|
successPolicy=None,
|
742
871
|
# The Failure Policy helps setting the number of retries for the jobset.
|
743
|
-
#
|
744
|
-
#
|
745
|
-
# If there is no retry decorator then we not set maxRestarts and instead we will
|
746
|
-
# set the attempt statically to 0. Otherwise we will make the job pickup the attempt
|
747
|
-
# from the `V1EnvVarSource.value_from.V1ObjectFieldSelector.field_path` = "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']"
|
748
|
-
# failurePolicy={
|
749
|
-
# "maxRestarts" : 1
|
750
|
-
# },
|
751
|
-
# The can be set for ArgoWorkflows
|
872
|
+
# but we don't rely on it and instead rely on either the local scheduler
|
873
|
+
# or the Argo Workflows to handle retries.
|
752
874
|
failurePolicy=None,
|
753
875
|
network=None,
|
754
876
|
),
|
@@ -767,7 +889,7 @@ class KubernetesJobSet(object):
|
|
767
889
|
version=self._version,
|
768
890
|
namespace=self._namespace,
|
769
891
|
plural="jobsets",
|
770
|
-
body=self.
|
892
|
+
body=self.dump(),
|
771
893
|
)
|
772
894
|
except Exception as e:
|
773
895
|
raise KubernetesJobsetException(
|
@@ -782,3 +904,119 @@ class KubernetesJobSet(object):
|
|
782
904
|
group=self._group,
|
783
905
|
version=self._version,
|
784
906
|
)
|
907
|
+
|
908
|
+
|
909
|
+
class KubernetesArgoJobSet(object):
|
910
|
+
def __init__(self, kubernetes_sdk, name=None, namespace=None, **kwargs):
|
911
|
+
self._kubernetes_sdk = kubernetes_sdk
|
912
|
+
self._annotations = {}
|
913
|
+
self._labels = {}
|
914
|
+
self._group = KUBERNETES_JOBSET_GROUP
|
915
|
+
self._version = KUBERNETES_JOBSET_VERSION
|
916
|
+
self._namespace = namespace
|
917
|
+
self.name = name
|
918
|
+
|
919
|
+
self._jobset_control_addr = _make_domain_name(
|
920
|
+
name,
|
921
|
+
CONTROL_JOB_NAME,
|
922
|
+
0,
|
923
|
+
0,
|
924
|
+
namespace,
|
925
|
+
)
|
926
|
+
|
927
|
+
self._control_spec = JobSetSpec(
|
928
|
+
kubernetes_sdk, name=CONTROL_JOB_NAME, namespace=namespace, **kwargs
|
929
|
+
)
|
930
|
+
self._worker_spec = JobSetSpec(
|
931
|
+
kubernetes_sdk, name="worker", namespace=namespace, **kwargs
|
932
|
+
)
|
933
|
+
|
934
|
+
@property
|
935
|
+
def jobset_control_addr(self):
|
936
|
+
return self._jobset_control_addr
|
937
|
+
|
938
|
+
@property
|
939
|
+
def worker(self):
|
940
|
+
return self._worker_spec
|
941
|
+
|
942
|
+
@property
|
943
|
+
def control(self):
|
944
|
+
return self._control_spec
|
945
|
+
|
946
|
+
def environment_variable_from_selector(self, name, label_value):
|
947
|
+
self.worker.environment_variable_from_selector(name, label_value)
|
948
|
+
self.control.environment_variable_from_selector(name, label_value)
|
949
|
+
return self
|
950
|
+
|
951
|
+
def environment_variables_from_selectors(self, env_dict):
|
952
|
+
for name, label_value in env_dict.items():
|
953
|
+
self.worker.environment_variable_from_selector(name, label_value)
|
954
|
+
self.control.environment_variable_from_selector(name, label_value)
|
955
|
+
return self
|
956
|
+
|
957
|
+
def environment_variable(self, name, value):
|
958
|
+
self.worker.environment_variable(name, value)
|
959
|
+
self.control.environment_variable(name, value)
|
960
|
+
return self
|
961
|
+
|
962
|
+
def label(self, name, value):
|
963
|
+
self.worker.label(name, value)
|
964
|
+
self.control.label(name, value)
|
965
|
+
self._labels = dict(self._labels, **{name: value})
|
966
|
+
return self
|
967
|
+
|
968
|
+
def annotation(self, name, value):
|
969
|
+
self.worker.annotation(name, value)
|
970
|
+
self.control.annotation(name, value)
|
971
|
+
self._annotations = dict(self._annotations, **{name: value})
|
972
|
+
return self
|
973
|
+
|
974
|
+
def dump(self):
|
975
|
+
client = self._kubernetes_sdk
|
976
|
+
import json
|
977
|
+
|
978
|
+
data = json.dumps(
|
979
|
+
client.ApiClient().sanitize_for_serialization(
|
980
|
+
dict(
|
981
|
+
apiVersion=self._group + "/" + self._version,
|
982
|
+
kind="JobSet",
|
983
|
+
metadata=client.api_client.ApiClient().sanitize_for_serialization(
|
984
|
+
client.V1ObjectMeta(
|
985
|
+
name=self.name,
|
986
|
+
labels=self._labels,
|
987
|
+
annotations=self._annotations,
|
988
|
+
)
|
989
|
+
),
|
990
|
+
spec=dict(
|
991
|
+
replicatedJobs=[self.control.dump(), self.worker.dump()],
|
992
|
+
suspend=False,
|
993
|
+
startupPolicy=None,
|
994
|
+
successPolicy=None,
|
995
|
+
# The Failure Policy helps setting the number of retries for the jobset.
|
996
|
+
# but we don't rely on it and instead rely on either the local scheduler
|
997
|
+
# or the Argo Workflows to handle retries.
|
998
|
+
failurePolicy=None,
|
999
|
+
network=None,
|
1000
|
+
),
|
1001
|
+
status=None,
|
1002
|
+
)
|
1003
|
+
)
|
1004
|
+
)
|
1005
|
+
# The values we populate in the Jobset manifest (for Argo Workflows) piggybacks on the Argo Workflow's templating engine.
|
1006
|
+
# Even though Argo Workflows's templating helps us constructing all the necessary IDs and populating the fields
|
1007
|
+
# required by Metaflow, we run into one glitch. When we construct JSON/YAML serializable objects,
|
1008
|
+
# anything between two braces such as `{{=asInt(inputs.parameters.workerCount)}}` gets quoted. This is a problem
|
1009
|
+
# since we need to pass the value of `inputs.parameters.workerCount` as an integer and not as a string.
|
1010
|
+
# If we pass it as a string, the jobset controller will not accept the Jobset CRD we submitted to kubernetes.
|
1011
|
+
# To get around this, we need to replace the quoted substring with the unquoted substring because YAML /JSON parsers
|
1012
|
+
# won't allow deserialization with the quoting trivially.
|
1013
|
+
|
1014
|
+
# This is super important because the `inputs.parameters.workerCount` is used to set the number of replicas;
|
1015
|
+
# The value for number of replicas is derived from the value of `num_parallel` (which is set in the user-code).
|
1016
|
+
# Since the value of `num_parallel` can be dynamic and can change from run to run, we need to ensure that the
|
1017
|
+
# value can be passed-down dynamically and is **explicitly set as a integer** in the Jobset Manifest submitted as a
|
1018
|
+
# part of the Argo Workflow
|
1019
|
+
|
1020
|
+
quoted_substring = '"{{=asInt(inputs.parameters.workerCount)}}"'
|
1021
|
+
unquoted_substring = "{{=asInt(inputs.parameters.workerCount)}}"
|
1022
|
+
return data.replace(quoted_substring, unquoted_substring)
|