metaflow 2.12.8__py2.py3-none-any.whl → 2.12.10__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. metaflow/__init__.py +2 -0
  2. metaflow/cli.py +12 -4
  3. metaflow/extension_support/plugins.py +1 -0
  4. metaflow/flowspec.py +8 -1
  5. metaflow/lint.py +13 -0
  6. metaflow/metaflow_current.py +0 -8
  7. metaflow/plugins/__init__.py +12 -0
  8. metaflow/plugins/argo/argo_workflows.py +616 -46
  9. metaflow/plugins/argo/argo_workflows_cli.py +70 -3
  10. metaflow/plugins/argo/argo_workflows_decorator.py +38 -7
  11. metaflow/plugins/argo/argo_workflows_deployer.py +290 -0
  12. metaflow/plugins/argo/daemon.py +59 -0
  13. metaflow/plugins/argo/jobset_input_paths.py +16 -0
  14. metaflow/plugins/aws/batch/batch_decorator.py +16 -13
  15. metaflow/plugins/aws/step_functions/step_functions_cli.py +45 -3
  16. metaflow/plugins/aws/step_functions/step_functions_deployer.py +251 -0
  17. metaflow/plugins/cards/card_cli.py +1 -1
  18. metaflow/plugins/kubernetes/kubernetes.py +279 -52
  19. metaflow/plugins/kubernetes/kubernetes_cli.py +26 -8
  20. metaflow/plugins/kubernetes/kubernetes_client.py +0 -1
  21. metaflow/plugins/kubernetes/kubernetes_decorator.py +56 -44
  22. metaflow/plugins/kubernetes/kubernetes_job.py +7 -6
  23. metaflow/plugins/kubernetes/kubernetes_jobsets.py +511 -272
  24. metaflow/plugins/parallel_decorator.py +108 -8
  25. metaflow/plugins/secrets/secrets_decorator.py +12 -3
  26. metaflow/plugins/test_unbounded_foreach_decorator.py +39 -4
  27. metaflow/runner/deployer.py +386 -0
  28. metaflow/runner/metaflow_runner.py +1 -20
  29. metaflow/runner/nbdeploy.py +130 -0
  30. metaflow/runner/nbrun.py +4 -28
  31. metaflow/runner/utils.py +49 -0
  32. metaflow/runtime.py +246 -134
  33. metaflow/version.py +1 -1
  34. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/METADATA +2 -2
  35. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/RECORD +39 -32
  36. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/WHEEL +1 -1
  37. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/LICENSE +0 -0
  38. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/entry_points.txt +0 -0
  39. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/top_level.txt +0 -0
metaflow/plugins/kubernetes/kubernetes_jobsets.py
@@ -1,13 +1,14 @@
  import copy
+ import json
  import math
  import random
  import time
- from metaflow.metaflow_current import current
+ from collections import namedtuple
+
  from metaflow.exception import MetaflowException
- from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
- import json
  from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_VERSION
- from collections import namedtuple
+ from metaflow.tracing import inject_tracing_vars
+ from metaflow.metaflow_config import KUBERNETES_SECRETS


  class KubernetesJobsetException(MetaflowException):
@@ -51,6 +52,8 @@ def k8s_retry(deadline_seconds=60, max_backoff=32):
      return decorator


+ CONTROL_JOB_NAME = "control"
+
  JobsetStatus = namedtuple(
      "JobsetStatus",
      [
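Only the tail of `k8s_retry` is visible in this hunk. For readers unfamiliar with the pattern, here is a minimal sketch of what a deadline-plus-backoff retry decorator with this signature typically looks like; the body below is an illustrative assumption, not Metaflow's actual implementation:

    import functools
    import random
    import time

    def k8s_retry_sketch(deadline_seconds=60, max_backoff=32):
        # Retry a flaky Kubernetes API call with exponential backoff and
        # jitter, giving up once the deadline would be exceeded.
        def decorator(function):
            @functools.wraps(function)
            def wrapper(*args, **kwargs):
                deadline = time.time() + deadline_seconds
                attempt = 0
                while True:
                    try:
                        return function(*args, **kwargs)
                    except Exception:
                        backoff = min(2**attempt + random.random(), max_backoff)
                        if time.time() + backoff > deadline:
                            raise
                        time.sleep(backoff)
                        attempt += 1

            return wrapper

        return decorator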
@@ -323,18 +326,18 @@ class RunningJobSet(object):
          with client.ApiClient() as api_client:
              api_instance = client.CustomObjectsApi(api_client)
              try:
-                 jobset = api_instance.get_namespaced_custom_object(
+                 obj = api_instance.get_namespaced_custom_object(
                      group=self._group,
                      version=self._version,
                      namespace=self._namespace,
-                     plural="jobsets",
+                     plural=plural,
                      name=self._name,
                  )

                  # Suspend the jobset and set the replicas to zero.
                  #
-                 jobset["spec"]["suspend"] = True
-                 for replicated_job in jobset["spec"]["replicatedJobs"]:
+                 obj["spec"]["suspend"] = True
+                 for replicated_job in obj["spec"]["replicatedJobs"]:
                      replicated_job["replicas"] = 0

                  api_instance.replace_namespaced_custom_object(
@@ -342,8 +345,8 @@ class RunningJobSet(object):
                      version=self._version,
                      namespace=self._namespace,
                      plural=plural,
-                     name=jobset["metadata"]["name"],
-                     body=jobset,
+                     name=obj["metadata"]["name"],
+                     body=obj,
                  )
              except Exception as e:
                  raise KubernetesJobsetException(
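The two hunks above generalize what appears to be the suspend/kill path of `RunningJobSet`: read the JobSet custom object, flip `spec.suspend`, zero out the replicas, and write the object back. A self-contained sketch of that read-modify-write cycle, assuming the official `kubernetes` Python client; `jobset.x-k8s.io/v1alpha2` is the upstream JobSet API group and version, which Metaflow reads from `KUBERNETES_JOBSET_GROUP` / `KUBERNETES_JOBSET_VERSION`:

    from kubernetes import client, config

    def suspend_jobset(name, namespace="default"):
        # Read-modify-write: fetch the JobSet CRD, suspend it, zero the
        # replicas, and replace the stored object.
        config.load_kube_config()  # or load_incluster_config() inside a pod
        api = client.CustomObjectsApi()
        obj = api.get_namespaced_custom_object(
            group="jobset.x-k8s.io",
            version="v1alpha2",
            namespace=namespace,
            plural="jobsets",
            name=name,
        )
        obj["spec"]["suspend"] = True
        for replicated_job in obj["spec"]["replicatedJobs"]:
            replicated_job["replicas"] = 0
        api.replace_namespaced_custom_object(
            group="jobset.x-k8s.io",
            version="v1alpha2",
            namespace=namespace,
            plural="jobsets",
            name=obj["metadata"]["name"],
            body=obj,
        )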
@@ -448,203 +451,6 @@ class RunningJobSet(object):
          ).jobset_failed


- class TaskIdConstructor:
-     @classmethod
-     def jobset_worker_id(cls, control_task_id: str):
-         return "".join(
-             [control_task_id.replace("control", "worker"), "-", "$WORKER_REPLICA_INDEX"]
-         )
-
-     @classmethod
-     def join_step_task_ids(cls, num_parallel):
-         """
-         Called within the step decorator to set the `flow._control_mapper_tasks`.
-         Setting these allows the flow to know which tasks are needed in the join step.
-         We set this in the `task_pre_step` method of the decorator.
-         """
-         control_task_id = current.task_id
-         worker_task_id_base = control_task_id.replace("control", "worker")
-         mapper = lambda idx: worker_task_id_base + "-%s" % (str(idx))
-         return control_task_id, [mapper(idx) for idx in range(0, num_parallel - 1)]
-
-     @classmethod
-     def argo(cls):
-         pass
-
-
- def _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel):
-     return [
-         client.V1EnvVar(
-             name="MASTER_ADDR",
-             value=jobset_main_addr,
-         ),
-         client.V1EnvVar(
-             name="MASTER_PORT",
-             value=str(master_port),
-         ),
-         client.V1EnvVar(
-             name="WORLD_SIZE",
-             value=str(num_parallel),
-         ),
-     ] + [
-         client.V1EnvVar(
-             name="JOBSET_RESTART_ATTEMPT",
-             value_from=client.V1EnvVarSource(
-                 field_ref=client.V1ObjectFieldSelector(
-                     field_path="metadata.annotations['jobset.sigs.k8s.io/restart-attempt']"
-                 )
-             ),
-         ),
-         client.V1EnvVar(
-             name="METAFLOW_KUBERNETES_JOBSET_NAME",
-             value_from=client.V1EnvVarSource(
-                 field_ref=client.V1ObjectFieldSelector(
-                     field_path="metadata.annotations['jobset.sigs.k8s.io/jobset-name']"
-                 )
-             ),
-         ),
-         client.V1EnvVar(
-             name="WORKER_REPLICA_INDEX",
-             value_from=client.V1EnvVarSource(
-                 field_ref=client.V1ObjectFieldSelector(
-                     field_path="metadata.annotations['jobset.sigs.k8s.io/job-index']"
-                 )
-             ),
-         ),
-     ]
-
-
- def get_control_job(
-     client,
-     job_spec,
-     jobset_main_addr,
-     subdomain,
-     port=None,
-     num_parallel=None,
-     namespace=None,
-     annotations=None,
- ) -> dict:
-     master_port = port
-
-     job_spec = copy.deepcopy(job_spec)
-     job_spec.parallelism = 1
-     job_spec.completions = 1
-     job_spec.template.spec.set_hostname_as_fqdn = True
-     job_spec.template.spec.subdomain = subdomain
-     job_spec.template.metadata.annotations = copy.copy(annotations)
-
-     for idx in range(len(job_spec.template.spec.containers[0].command)):
-         # CHECK FOR THE ubf_context in the command.
-         # Replace the UBF context to the one appropriately matching control/worker.
-         # Since we are passing the `step_cli` one time from the top level to one
-         # KuberentesJobSet, we need to ensure that UBF context is replaced properly
-         # in all the worker jobs.
-         if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]:
-             job_spec.template.spec.containers[0].command[idx] = (
-                 job_spec.template.spec.containers[0]
-                 .command[idx]
-                 .replace(UBF_CONTROL, UBF_CONTROL + " " + "--split-index 0")
-             )
-
-     job_spec.template.spec.containers[0].env = (
-         job_spec.template.spec.containers[0].env
-         + _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel)
-         + [
-             client.V1EnvVar(
-                 name="CONTROL_INDEX",
-                 value=str(0),
-             )
-         ]
-     )
-
-     # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178
-     return dict(
-         name="control",
-         template=client.api_client.ApiClient().sanitize_for_serialization(
-             client.V1JobTemplateSpec(
-                 metadata=client.V1ObjectMeta(
-                     namespace=namespace,
-                     # We don't set any annotations here
-                     # since they have been either set in the JobSpec
-                     # or on the JobSet level
-                 ),
-                 spec=job_spec,
-             )
-         ),
-         replicas=1,  # The control job will always have 1 replica.
-     )
-
-
- def get_worker_job(
-     client,
-     job_spec,
-     job_name,
-     jobset_main_addr,
-     subdomain,
-     control_task_id=None,
-     worker_task_id=None,
-     replicas=1,
-     port=None,
-     num_parallel=None,
-     namespace=None,
-     annotations=None,
- ) -> dict:
-     master_port = port
-
-     job_spec = copy.deepcopy(job_spec)
-     job_spec.parallelism = 1
-     job_spec.completions = 1
-     job_spec.template.spec.set_hostname_as_fqdn = True
-     job_spec.template.spec.subdomain = subdomain
-     job_spec.template.metadata.annotations = copy.copy(annotations)
-
-     for idx in range(len(job_spec.template.spec.containers[0].command)):
-         if control_task_id in job_spec.template.spec.containers[0].command[idx]:
-             job_spec.template.spec.containers[0].command[idx] = (
-                 job_spec.template.spec.containers[0]
-                 .command[idx]
-                 .replace(control_task_id, worker_task_id)
-             )
-         # CHECK FOR THE ubf_context in the command.
-         # Replace the UBF context to the one appropriately matching control/worker.
-         # Since we are passing the `step_cli` one time from the top level to one
-         # KuberentesJobSet, we need to ensure that UBF context is replaced properly
-         # in all the worker jobs.
-         if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]:
-             # Since all command will have a UBF_CONTROL, we need to replace the UBF_CONTROL
-             # with the actual UBF Context and also ensure that we are setting the correct
-             # split-index for the worker jobs.
-             split_index_str = "--split-index `expr $[WORKER_REPLICA_INDEX] + 1`"  # This set in the environment variables below
-             job_spec.template.spec.containers[0].command[idx] = (
-                 job_spec.template.spec.containers[0]
-                 .command[idx]
-                 .replace(UBF_CONTROL, UBF_TASK + " " + split_index_str)
-             )
-
-     job_spec.template.spec.containers[0].env = job_spec.template.spec.containers[
-         0
-     ].env + _jobset_specific_env_vars(
-         client, jobset_main_addr, master_port, num_parallel
-     )
-
-     # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178
-     return dict(
-         name=job_name,
-         template=client.api_client.ApiClient().sanitize_for_serialization(
-             client.V1JobTemplateSpec(
-                 metadata=client.V1ObjectMeta(
-                     namespace=namespace,
-                     # We don't set any annotations here
-                     # since they have been either set in the JobSpec
-                     # or on the JobSet level
-                 ),
-                 spec=job_spec,
-             )
-         ),
-         replicas=replicas,
-     )
-
-
  def _make_domain_name(
      jobset_name, main_job_name, main_job_index, main_pod_index, namespace
  ):
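Only the signature of `_make_domain_name` appears here as context. Based on the JobSet controller's pod-naming convention (pods are named `<jobset>-<replicatedJob>-<jobIndex>-<podIndex>` and resolve under a headless service sharing the JobSet's name), a plausible reconstruction, labeled hypothetical:

    def _make_domain_name_sketch(
        jobset_name, main_job_name, main_job_index, main_pod_index, namespace
    ):
        # Hypothetical reconstruction; the real body is not part of this diff.
        return "%s-%s-%s-%s.%s.%s.svc.cluster.local" % (
            jobset_name,
            main_job_name,
            main_job_index,
            main_pod_index,
            jobset_name,  # the subdomain defaults to the JobSet's name
            namespace,
        )

    # _make_domain_name_sketch("js-123", "control", 0, 0, "default")
    # -> "js-123-control-0-0.js-123.default.svc.cluster.local"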
@@ -658,97 +464,414 @@ def _make_domain_name(
      )


+ class JobSetSpec(object):
+     def __init__(self, kubernetes_sdk, name, **kwargs):
+         self._kubernetes_sdk = kubernetes_sdk
+         self._kwargs = kwargs
+         self.name = name
+
+     def replicas(self, replicas):
+         self._kwargs["replicas"] = replicas
+         return self
+
+     def step_name(self, step_name):
+         self._kwargs["step_name"] = step_name
+         return self
+
+     def namespace(self, namespace):
+         self._kwargs["namespace"] = namespace
+         return self
+
+     def command(self, command):
+         self._kwargs["command"] = command
+         return self
+
+     def image(self, image):
+         self._kwargs["image"] = image
+         return self
+
+     def cpu(self, cpu):
+         self._kwargs["cpu"] = cpu
+         return self
+
+     def memory(self, mem):
+         self._kwargs["memory"] = mem
+         return self
+
+     def environment_variable(self, name, value):
+         # Never set to None
+         if value is None:
+             return self
+         self._kwargs["environment_variables"] = dict(
+             self._kwargs.get("environment_variables", {}), **{name: value}
+         )
+         return self
+
+     def secret(self, name):
+         if name is None:
+             return self
+         if len(self._kwargs.get("secrets", [])) == 0:
+             self._kwargs["secrets"] = []
+         self._kwargs["secrets"] = list(set(self._kwargs["secrets"] + [name]))
+
+     def environment_variable_from_selector(self, name, label_value):
+         # Never set to None
+         if label_value is None:
+             return self
+         self._kwargs["environment_variables_from_selectors"] = dict(
+             self._kwargs.get("environment_variables_from_selectors", {}),
+             **{name: label_value}
+         )
+         return self
+
+     def label(self, name, value):
+         self._kwargs["labels"] = dict(self._kwargs.get("labels", {}), **{name: value})
+         return self
+
+     def annotation(self, name, value):
+         self._kwargs["annotations"] = dict(
+             self._kwargs.get("annotations", {}), **{name: value}
+         )
+         return self
+
+     def dump(self):
+         client = self._kubernetes_sdk
+         use_tmpfs = self._kwargs["use_tmpfs"]
+         tmpfs_size = self._kwargs["tmpfs_size"]
+         tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)
+         shared_memory = (
+             int(self._kwargs["shared_memory"])
+             if self._kwargs["shared_memory"]
+             else None
+         )
+
+         return dict(
+             name=self.name,
+             template=client.api_client.ApiClient().sanitize_for_serialization(
+                 client.V1JobTemplateSpec(
+                     metadata=client.V1ObjectMeta(
+                         namespace=self._kwargs["namespace"],
+                         # We don't set any annotations here
+                         # since they have been either set in the JobSpec
+                         # or on the JobSet level
+                     ),
+                     spec=client.V1JobSpec(
+                         # Retries are handled by Metaflow when it is responsible for
+                         # executing the flow. The responsibility is moved to Kubernetes
+                         # when Argo Workflows is responsible for the execution.
+                         backoff_limit=self._kwargs.get("retries", 0),
+                         completions=1,
+                         parallelism=1,
+                         ttl_seconds_after_finished=7
+                         * 60
+                         * 60  # Remove job after a week. TODO: Make this configurable
+                         * 24,
+                         template=client.V1PodTemplateSpec(
+                             metadata=client.V1ObjectMeta(
+                                 annotations=self._kwargs.get("annotations", {}),
+                                 labels=self._kwargs.get("labels", {}),
+                                 namespace=self._kwargs["namespace"],
+                             ),
+                             spec=client.V1PodSpec(
+                                 ## --- JobSet-required pod spec details start ----
+                                 subdomain=self._kwargs["subdomain"],
+                                 set_hostname_as_fqdn=True,
+                                 ## --- JobSet-required pod spec details end ----
+                                 # Timeout is set on the pod and not the job (important!)
+                                 active_deadline_seconds=self._kwargs[
+                                     "timeout_in_seconds"
+                                 ],
+                                 # TODO (savin): Enable affinities for GPU scheduling.
+                                 # affinity=?,
+                                 containers=[
+                                     client.V1Container(
+                                         command=self._kwargs["command"],
+                                         termination_message_policy="FallbackToLogsOnError",
+                                         ports=[]
+                                         if self._kwargs["port"] is None
+                                         else [
+                                             client.V1ContainerPort(
+                                                 container_port=int(self._kwargs["port"])
+                                             )
+                                         ],
+                                         env=[
+                                             client.V1EnvVar(name=k, value=str(v))
+                                             for k, v in self._kwargs.get(
+                                                 "environment_variables", {}
+                                             ).items()
+                                         ]
+                                         # And some downward API magic. Add (key, value)
+                                         # pairs below to make pod metadata available
+                                         # within the Kubernetes container.
+                                         + [
+                                             client.V1EnvVar(
+                                                 name=k,
+                                                 value_from=client.V1EnvVarSource(
+                                                     field_ref=client.V1ObjectFieldSelector(
+                                                         field_path=str(v)
+                                                     )
+                                                 ),
+                                             )
+                                             for k, v in self._kwargs.get(
+                                                 "environment_variables_from_selectors",
+                                                 {},
+                                             ).items()
+                                         ]
+                                         + [
+                                             client.V1EnvVar(name=k, value=str(v))
+                                             for k, v in inject_tracing_vars({}).items()
+                                         ],
+                                         env_from=[
+                                             client.V1EnvFromSource(
+                                                 secret_ref=client.V1SecretEnvSource(
+                                                     name=str(k),
+                                                     # optional=True
+                                                 )
+                                             )
+                                             for k in list(
+                                                 self._kwargs.get("secrets", [])
+                                             )
+                                             if k
+                                         ],
+                                         image=self._kwargs["image"],
+                                         image_pull_policy=self._kwargs[
+                                             "image_pull_policy"
+                                         ],
+                                         name=self._kwargs["step_name"].replace(
+                                             "_", "-"
+                                         ),
+                                         resources=client.V1ResourceRequirements(
+                                             requests={
+                                                 "cpu": str(self._kwargs["cpu"]),
+                                                 "memory": "%sM"
+                                                 % str(self._kwargs["memory"]),
+                                                 "ephemeral-storage": "%sM"
+                                                 % str(self._kwargs["disk"]),
+                                             },
+                                             limits={
+                                                 "%s.com/gpu".lower()
+                                                 % self._kwargs["gpu_vendor"]: str(
+                                                     self._kwargs["gpu"]
+                                                 )
+                                                 for k in [0]
+                                                 # Don't set GPU limits if gpu isn't specified.
+                                                 if self._kwargs["gpu"] is not None
+                                             },
+                                         ),
+                                         volume_mounts=(
+                                             [
+                                                 client.V1VolumeMount(
+                                                     mount_path=self._kwargs.get(
+                                                         "tmpfs_path"
+                                                     ),
+                                                     name="tmpfs-ephemeral-volume",
+                                                 )
+                                             ]
+                                             if tmpfs_enabled
+                                             else []
+                                         )
+                                         + (
+                                             [
+                                                 client.V1VolumeMount(
+                                                     mount_path="/dev/shm", name="dhsm"
+                                                 )
+                                             ]
+                                             if shared_memory
+                                             else []
+                                         )
+                                         + (
+                                             [
+                                                 client.V1VolumeMount(
+                                                     mount_path=path, name=claim
+                                                 )
+                                                 for claim, path in self._kwargs[
+                                                     "persistent_volume_claims"
+                                                 ].items()
+                                             ]
+                                             if self._kwargs["persistent_volume_claims"]
+                                             is not None
+                                             else []
+                                         ),
+                                     )
+                                 ],
+                                 node_selector=self._kwargs.get("node_selector"),
+                                 # TODO (savin): Support image_pull_secrets
+                                 # image_pull_secrets=?,
+                                 # TODO (savin): Support preemption policies
+                                 # preemption_policy=?,
+                                 #
+                                 # A Container in a Pod may fail for a number of
+                                 # reasons, such as because the process in it exited
+                                 # with a non-zero exit code, or the Container was
+                                 # killed due to OOM etc. If this happens, fail the pod
+                                 # and let Metaflow handle the retries.
+                                 restart_policy="Never",
+                                 service_account_name=self._kwargs["service_account"],
+                                 # Terminate the container immediately on SIGTERM
+                                 termination_grace_period_seconds=0,
+                                 tolerations=[
+                                     client.V1Toleration(**toleration)
+                                     for toleration in self._kwargs.get("tolerations")
+                                     or []
+                                 ],
+                                 volumes=(
+                                     [
+                                         client.V1Volume(
+                                             name="tmpfs-ephemeral-volume",
+                                             empty_dir=client.V1EmptyDirVolumeSource(
+                                                 medium="Memory",
+                                                 # Add default unit as ours differs from Kubernetes default.
+                                                 size_limit="{}Mi".format(tmpfs_size),
+                                             ),
+                                         )
+                                     ]
+                                     if tmpfs_enabled
+                                     else []
+                                 )
+                                 + (
+                                     [
+                                         client.V1Volume(
+                                             name="dhsm",
+                                             empty_dir=client.V1EmptyDirVolumeSource(
+                                                 medium="Memory",
+                                                 size_limit="{}Mi".format(shared_memory),
+                                             ),
+                                         )
+                                     ]
+                                     if shared_memory
+                                     else []
+                                 )
+                                 + (
+                                     [
+                                         client.V1Volume(
+                                             name=claim,
+                                             persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
+                                                 claim_name=claim
+                                             ),
+                                         )
+                                         for claim in self._kwargs[
+                                             "persistent_volume_claims"
+                                         ].keys()
+                                     ]
+                                     if self._kwargs["persistent_volume_claims"]
+                                     is not None
+                                     else []
+                                 ),
+                                 # TODO (savin): Set termination_message_policy
+                             ),
+                         ),
+                     ),
+                 )
+             ),
+             replicas=self._kwargs["replicas"],
+         )
+
+
  class KubernetesJobSet(object):
      def __init__(
          self,
          client,
          name=None,
-         job_spec=None,
          namespace=None,
          num_parallel=None,
-         annotations=None,
-         labels=None,
-         port=None,
-         task_id=None,
+         # Explicitly declaring num_parallel because we need to ensure that
+         # num_parallel is an INTEGER, and this abstraction is called by the
+         # local runtime abstraction of Kubernetes.
+         # Argo will call another abstraction that allows setting a lot of these
+         # values from the top-level Argo code.
          **kwargs
      ):
          self._client = client
-         self._kwargs = kwargs
+         self._annotations = {}
+         self._labels = {}
          self._group = KUBERNETES_JOBSET_GROUP
          self._version = KUBERNETES_JOBSET_VERSION
+         self._namespace = namespace
          self.name = name

-         main_job_name = "control"
-         main_job_index = 0
-         main_pod_index = 0
-         subdomain = self.name
-         num_parallel = int(1 if not num_parallel else num_parallel)
-         self._namespace = namespace
-         jobset_main_addr = _make_domain_name(
-             self.name,
-             main_job_name,
-             main_job_index,
-             main_pod_index,
-             self._namespace,
+         self._jobset_control_addr = _make_domain_name(
+             name,
+             CONTROL_JOB_NAME,
+             0,
+             0,
+             namespace,
          )

-         annotations = {} if not annotations else annotations
-         labels = {} if not labels else labels
-
-         if "metaflow/task_id" in annotations:
-             del annotations["metaflow/task_id"]
-
-         control_job = get_control_job(
-             client=self._client.get(),
-             job_spec=job_spec,
-             jobset_main_addr=jobset_main_addr,
-             subdomain=subdomain,
-             port=port,
-             num_parallel=num_parallel,
-             namespace=namespace,
-             annotations=annotations,
+         self._control_spec = JobSetSpec(
+             client.get(), name=CONTROL_JOB_NAME, namespace=namespace, **kwargs
          )
-         worker_task_id = TaskIdConstructor.jobset_worker_id(task_id)
-         worker_job = get_worker_job(
-             client=self._client.get(),
-             job_spec=job_spec,
-             job_name="worker",
-             jobset_main_addr=jobset_main_addr,
-             subdomain=subdomain,
-             control_task_id=task_id,
-             worker_task_id=worker_task_id,
-             replicas=num_parallel - 1,
-             port=port,
-             num_parallel=num_parallel,
-             namespace=namespace,
-             annotations=annotations,
+         self._worker_spec = JobSetSpec(
+             client.get(), name="worker", namespace=namespace, **kwargs
          )
-         worker_jobs = [worker_job]
-         # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L163
-         _kclient = client.get()
-         self._jobset = dict(
+         assert (
+             type(num_parallel) == int
+         ), "num_parallel must be an integer"  # todo: [final-refactor] : fix-me
+
+     @property
+     def jobset_control_addr(self):
+         return self._jobset_control_addr
+
+     @property
+     def worker(self):
+         return self._worker_spec
+
+     @property
+     def control(self):
+         return self._control_spec
+
+     def environment_variable_from_selector(self, name, label_value):
+         self.worker.environment_variable_from_selector(name, label_value)
+         self.control.environment_variable_from_selector(name, label_value)
+         return self
+
+     def environment_variables_from_selectors(self, env_dict):
+         for name, label_value in env_dict.items():
+             self.worker.environment_variable_from_selector(name, label_value)
+             self.control.environment_variable_from_selector(name, label_value)
+         return self
+
+     def environment_variable(self, name, value):
+         self.worker.environment_variable(name, value)
+         self.control.environment_variable(name, value)
+         return self
+
+     def label(self, name, value):
+         self.worker.label(name, value)
+         self.control.label(name, value)
+         self._labels = dict(self._labels, **{name: value})
+         return self
+
+     def annotation(self, name, value):
+         self.worker.annotation(name, value)
+         self.control.annotation(name, value)
+         self._annotations = dict(self._annotations, **{name: value})
+         return self
+
+     def secret(self, name):
+         self.worker.secret(name)
+         self.control.secret(name)
+         return self
+
+     def dump(self):
+         client = self._client.get()
+         return dict(
              apiVersion=self._group + "/" + self._version,
              kind="JobSet",
-             metadata=_kclient.api_client.ApiClient().sanitize_for_serialization(
-                 _kclient.V1ObjectMeta(
-                     name=self.name, labels=labels, annotations=annotations
+             metadata=client.api_client.ApiClient().sanitize_for_serialization(
+                 client.V1ObjectMeta(
+                     name=self.name,
+                     labels=self._labels,
+                     annotations=self._annotations,
                  )
              ),
              spec=dict(
-                 replicatedJobs=[control_job] + worker_jobs,
+                 replicatedJobs=[self.control.dump(), self.worker.dump()],
                  suspend=False,
                  startupPolicy=None,
                  successPolicy=None,
                  # The failure policy helps set the number of retries for the JobSet,
-                 # It cannot accept a value of 0 for maxRestarts.
-                 # So the attempt needs to be smartly set.
-                 # If there is no retry decorator then we not set maxRestarts and instead we will
-                 # set the attempt statically to 0. Otherwise we will make the job pickup the attempt
-                 # from the `V1EnvVarSource.value_from.V1ObjectFieldSelector.field_path` = "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']"
-                 # failurePolicy={
-                 #     "maxRestarts" : 1
-                 # },
-                 # The can be set for ArgoWorkflows
+                 # but we don't rely on it and instead rely on either the local scheduler
+                 # or Argo Workflows to handle retries.
                  failurePolicy=None,
                  network=None,
              ),
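Taken together, this hunk replaces the one-shot `get_control_job` / `get_worker_job` constructors with chainable `JobSetSpec` builders hanging off `KubernetesJobSet.control` and `.worker`. A hedged usage sketch follows: `KubernetesClient()` and all argument values are placeholders, and `dump()` additionally expects kwargs such as `use_tmpfs`, `tmpfs_size`, `subdomain`, `port`, and `timeout_in_seconds` to be supplied:

    jobset = KubernetesJobSet(
        client=KubernetesClient(),  # placeholder for Metaflow's Kubernetes client wrapper
        name="js-0123",
        namespace="default",
        num_parallel=4,  # must be an int; the constructor asserts this
    )
    jobset.control.replicas(1)
    jobset.worker.replicas(3)  # num_parallel - 1 workers
    for spec in (jobset.control, jobset.worker):
        spec.image("python:3.11").cpu(2).memory(4096).command(
            ["python", "flow.py", "step", "train"]
        )
    jobset.environment_variable("MASTER_ADDR", jobset.jobset_control_addr)
    jobset.label("app", "metaflow")
    jobset.annotation("metaflow/run_id", "argo-xyz")
    manifest = jobset.dump()  # the JobSet CRD body submitted in the next hunk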
@@ -767,7 +890,7 @@ class KubernetesJobSet(object):
                  version=self._version,
                  namespace=self._namespace,
                  plural="jobsets",
-                 body=self._jobset,
+                 body=self.dump(),
              )
          except Exception as e:
              raise KubernetesJobsetException(
@@ -782,3 +905,119 @@ class KubernetesJobSet(object):
              group=self._group,
              version=self._version,
          )
+
+
+ class KubernetesArgoJobSet(object):
+     def __init__(self, kubernetes_sdk, name=None, namespace=None, **kwargs):
+         self._kubernetes_sdk = kubernetes_sdk
+         self._annotations = {}
+         self._labels = {}
+         self._group = KUBERNETES_JOBSET_GROUP
+         self._version = KUBERNETES_JOBSET_VERSION
+         self._namespace = namespace
+         self.name = name
+
+         self._jobset_control_addr = _make_domain_name(
+             name,
+             CONTROL_JOB_NAME,
+             0,
+             0,
+             namespace,
+         )
+
+         self._control_spec = JobSetSpec(
+             kubernetes_sdk, name=CONTROL_JOB_NAME, namespace=namespace, **kwargs
+         )
+         self._worker_spec = JobSetSpec(
+             kubernetes_sdk, name="worker", namespace=namespace, **kwargs
+         )
+
+     @property
+     def jobset_control_addr(self):
+         return self._jobset_control_addr
+
+     @property
+     def worker(self):
+         return self._worker_spec
+
+     @property
+     def control(self):
+         return self._control_spec
+
+     def environment_variable_from_selector(self, name, label_value):
+         self.worker.environment_variable_from_selector(name, label_value)
+         self.control.environment_variable_from_selector(name, label_value)
+         return self
+
+     def environment_variables_from_selectors(self, env_dict):
+         for name, label_value in env_dict.items():
+             self.worker.environment_variable_from_selector(name, label_value)
+             self.control.environment_variable_from_selector(name, label_value)
+         return self
+
+     def environment_variable(self, name, value):
+         self.worker.environment_variable(name, value)
+         self.control.environment_variable(name, value)
+         return self
+
+     def label(self, name, value):
+         self.worker.label(name, value)
+         self.control.label(name, value)
+         self._labels = dict(self._labels, **{name: value})
+         return self
+
+     def annotation(self, name, value):
+         self.worker.annotation(name, value)
+         self.control.annotation(name, value)
+         self._annotations = dict(self._annotations, **{name: value})
+         return self
+
+     def dump(self):
+         client = self._kubernetes_sdk
+         import json
+
+         data = json.dumps(
+             client.ApiClient().sanitize_for_serialization(
+                 dict(
+                     apiVersion=self._group + "/" + self._version,
+                     kind="JobSet",
+                     metadata=client.api_client.ApiClient().sanitize_for_serialization(
+                         client.V1ObjectMeta(
+                             name=self.name,
+                             labels=self._labels,
+                             annotations=self._annotations,
+                         )
+                     ),
+                     spec=dict(
+                         replicatedJobs=[self.control.dump(), self.worker.dump()],
+                         suspend=False,
+                         startupPolicy=None,
+                         successPolicy=None,
+                         # The failure policy helps set the number of retries for the JobSet,
+                         # but we don't rely on it and instead rely on either the local scheduler
+                         # or Argo Workflows to handle retries.
+                         failurePolicy=None,
+                         network=None,
+                     ),
+                     status=None,
+                 )
+             )
+         )
+         # The values we populate in the JobSet manifest (for Argo Workflows) piggyback on Argo Workflows' templating engine.
+         # Even though Argo Workflows' templating helps us construct all the necessary IDs and populate the fields
+         # required by Metaflow, we run into one glitch. When we construct JSON/YAML-serializable objects,
+         # anything between two braces such as `{{=asInt(inputs.parameters.workerCount)}}` gets quoted. This is a problem
+         # since we need to pass the value of `inputs.parameters.workerCount` as an integer and not as a string.
+         # If we pass it as a string, the JobSet controller will not accept the JobSet CRD we submit to Kubernetes.
+         # To get around this, we replace the quoted substring with the unquoted substring, because YAML/JSON parsers
+         # won't trivially allow deserialization with the quoting in place.
+
+         # This is super important because `inputs.parameters.workerCount` is used to set the number of replicas;
+         # the value for the number of replicas is derived from the value of `num_parallel` (which is set in the user code).
+         # Since the value of `num_parallel` can be dynamic and can change from run to run, we need to ensure that the
+         # value can be passed down dynamically and is **explicitly set as an integer** in the JobSet manifest submitted as
+         # part of the Argo Workflow.
+         quoted_substring = '"{{=asInt(inputs.parameters.workerCount)}}"'
+         unquoted_substring = "{{=asInt(inputs.parameters.workerCount)}}"
+         return data.replace(quoted_substring, unquoted_substring)
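The closing comment describes a real serialization wrinkle. A minimal standard-library demonstration of the problem and of the textual fix that `dump()` applies:

    import json

    # json.dumps() quotes the Argo template expression, so the intended
    # integer field would reach the JobSet controller as a string...
    manifest = json.dumps({"replicas": "{{=asInt(inputs.parameters.workerCount)}}"})
    assert manifest == '{"replicas": "{{=asInt(inputs.parameters.workerCount)}}"}'

    # ...so the quotes are stripped textually, leaving a bare expression for
    # Argo to substitute with an integer before the CRD reaches Kubernetes.
    quoted = '"{{=asInt(inputs.parameters.workerCount)}}"'
    unquoted = "{{=asInt(inputs.parameters.workerCount)}}"
    assert manifest.replace(quoted, unquoted) == (
        '{"replicas": {{=asInt(inputs.parameters.workerCount)}}}'
    )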