metaflow 2.12.7__py2.py3-none-any.whl → 2.12.9__py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (40)
  1. metaflow/__init__.py +2 -0
  2. metaflow/cli.py +12 -4
  3. metaflow/extension_support/plugins.py +1 -0
  4. metaflow/flowspec.py +8 -1
  5. metaflow/lint.py +13 -0
  6. metaflow/metaflow_current.py +0 -8
  7. metaflow/plugins/__init__.py +12 -0
  8. metaflow/plugins/argo/argo_workflows.py +462 -42
  9. metaflow/plugins/argo/argo_workflows_cli.py +60 -3
  10. metaflow/plugins/argo/argo_workflows_decorator.py +38 -7
  11. metaflow/plugins/argo/argo_workflows_deployer.py +290 -0
  12. metaflow/plugins/argo/jobset_input_paths.py +16 -0
  13. metaflow/plugins/aws/batch/batch_decorator.py +16 -13
  14. metaflow/plugins/aws/step_functions/step_functions_cli.py +45 -3
  15. metaflow/plugins/aws/step_functions/step_functions_deployer.py +251 -0
  16. metaflow/plugins/cards/card_cli.py +1 -1
  17. metaflow/plugins/kubernetes/kubernetes.py +279 -52
  18. metaflow/plugins/kubernetes/kubernetes_cli.py +26 -8
  19. metaflow/plugins/kubernetes/kubernetes_client.py +0 -1
  20. metaflow/plugins/kubernetes/kubernetes_decorator.py +56 -44
  21. metaflow/plugins/kubernetes/kubernetes_job.py +6 -6
  22. metaflow/plugins/kubernetes/kubernetes_jobsets.py +510 -272
  23. metaflow/plugins/parallel_decorator.py +108 -8
  24. metaflow/plugins/pypi/bootstrap.py +1 -1
  25. metaflow/plugins/pypi/micromamba.py +1 -1
  26. metaflow/plugins/secrets/secrets_decorator.py +12 -3
  27. metaflow/plugins/test_unbounded_foreach_decorator.py +39 -4
  28. metaflow/runner/deployer.py +386 -0
  29. metaflow/runner/metaflow_runner.py +1 -20
  30. metaflow/runner/nbdeploy.py +130 -0
  31. metaflow/runner/nbrun.py +4 -28
  32. metaflow/runner/utils.py +49 -0
  33. metaflow/runtime.py +246 -134
  34. metaflow/version.py +1 -1
  35. {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/METADATA +2 -2
  36. {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/RECORD +40 -34
  37. {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/WHEEL +1 -1
  38. {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/LICENSE +0 -0
  39. {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/entry_points.txt +0 -0
  40. {metaflow-2.12.7.dist-info → metaflow-2.12.9.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,14 @@
  import copy
+ import json
  import math
  import random
  import time
- from metaflow.metaflow_current import current
+ from collections import namedtuple
+
  from metaflow.exception import MetaflowException
- from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
- import json
  from metaflow.metaflow_config import KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_VERSION
- from collections import namedtuple
+ from metaflow.tracing import inject_tracing_vars
+ from metaflow.metaflow_config import KUBERNETES_SECRETS


  class KubernetesJobsetException(MetaflowException):
@@ -51,6 +52,8 @@ def k8s_retry(deadline_seconds=60, max_backoff=32):
  return decorator


+ CONTROL_JOB_NAME = "control"
+
  JobsetStatus = namedtuple(
  "JobsetStatus",
  [
@@ -323,18 +326,18 @@ class RunningJobSet(object):
  with client.ApiClient() as api_client:
  api_instance = client.CustomObjectsApi(api_client)
  try:
- jobset = api_instance.get_namespaced_custom_object(
+ obj = api_instance.get_namespaced_custom_object(
  group=self._group,
  version=self._version,
  namespace=self._namespace,
- plural="jobsets",
+ plural=plural,
  name=self._name,
  )

  # Suspend the jobset and set the replica's to Zero.
  #
- jobset["spec"]["suspend"] = True
- for replicated_job in jobset["spec"]["replicatedJobs"]:
+ obj["spec"]["suspend"] = True
+ for replicated_job in obj["spec"]["replicatedJobs"]:
  replicated_job["replicas"] = 0

  api_instance.replace_namespaced_custom_object(
@@ -342,8 +345,8 @@ class RunningJobSet(object):
  version=self._version,
  namespace=self._namespace,
  plural=plural,
- name=jobset["metadata"]["name"],
- body=jobset,
+ name=obj["metadata"]["name"],
+ body=obj,
  )
  except Exception as e:
  raise KubernetesJobsetException(
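The two hunks above rename the fetched object and switch the hard-coded `plural="jobsets"` to the `plural` variable; the suspend sequence itself is unchanged (fetch the custom object, flip `spec.suspend`, zero out every replicated job, write the object back). A minimal standalone sketch of that sequence with the official `kubernetes` Python client, assuming a hypothetical JobSet `my-jobset` in namespace `default` and the upstream JobSet API group/version (Metaflow reads these from `KUBERNETES_JOBSET_GROUP` / `KUBERNETES_JOBSET_VERSION`):

```python
from kubernetes import client, config

config.load_kube_config()  # or load_incluster_config() when running in a pod

group, version, plural = "jobset.x-k8s.io", "v1alpha2", "jobsets"  # assumed defaults
name, namespace = "my-jobset", "default"  # hypothetical JobSet

with client.ApiClient() as api_client:
    api = client.CustomObjectsApi(api_client)
    obj = api.get_namespaced_custom_object(
        group=group, version=version, namespace=namespace, plural=plural, name=name
    )
    # Suspend the JobSet and scale every replicated job to zero, mirroring
    # what RunningJobSet does when a run is killed.
    obj["spec"]["suspend"] = True
    for replicated_job in obj["spec"]["replicatedJobs"]:
        replicated_job["replicas"] = 0
    api.replace_namespaced_custom_object(
        group=group, version=version, namespace=namespace, plural=plural,
        name=obj["metadata"]["name"], body=obj,
    )
```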
@@ -448,203 +451,6 @@ class RunningJobSet(object):
  ).jobset_failed


- class TaskIdConstructor:
- @classmethod
- def jobset_worker_id(cls, control_task_id: str):
- return "".join(
- [control_task_id.replace("control", "worker"), "-", "$WORKER_REPLICA_INDEX"]
- )
-
- @classmethod
- def join_step_task_ids(cls, num_parallel):
- """
- Called within the step decorator to set the `flow._control_mapper_tasks`.
- Setting these allows the flow to know which tasks are needed in the join step.
- We set this in the `task_pre_step` method of the decorator.
- """
- control_task_id = current.task_id
- worker_task_id_base = control_task_id.replace("control", "worker")
- mapper = lambda idx: worker_task_id_base + "-%s" % (str(idx))
- return control_task_id, [mapper(idx) for idx in range(0, num_parallel - 1)]
-
- @classmethod
- def argo(cls):
- pass
-
-
- def _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel):
- return [
- client.V1EnvVar(
- name="MASTER_ADDR",
- value=jobset_main_addr,
- ),
- client.V1EnvVar(
- name="MASTER_PORT",
- value=str(master_port),
- ),
- client.V1EnvVar(
- name="WORLD_SIZE",
- value=str(num_parallel),
- ),
- ] + [
- client.V1EnvVar(
- name="JOBSET_RESTART_ATTEMPT",
- value_from=client.V1EnvVarSource(
- field_ref=client.V1ObjectFieldSelector(
- field_path="metadata.annotations['jobset.sigs.k8s.io/restart-attempt']"
- )
- ),
- ),
- client.V1EnvVar(
- name="METAFLOW_KUBERNETES_JOBSET_NAME",
- value_from=client.V1EnvVarSource(
- field_ref=client.V1ObjectFieldSelector(
- field_path="metadata.annotations['jobset.sigs.k8s.io/jobset-name']"
- )
- ),
- ),
- client.V1EnvVar(
- name="WORKER_REPLICA_INDEX",
- value_from=client.V1EnvVarSource(
- field_ref=client.V1ObjectFieldSelector(
- field_path="metadata.annotations['jobset.sigs.k8s.io/job-index']"
- )
- ),
- ),
- ]
-
-
- def get_control_job(
- client,
- job_spec,
- jobset_main_addr,
- subdomain,
- port=None,
- num_parallel=None,
- namespace=None,
- annotations=None,
- ) -> dict:
- master_port = port
-
- job_spec = copy.deepcopy(job_spec)
- job_spec.parallelism = 1
- job_spec.completions = 1
- job_spec.template.spec.set_hostname_as_fqdn = True
- job_spec.template.spec.subdomain = subdomain
- job_spec.template.metadata.annotations = copy.copy(annotations)
-
- for idx in range(len(job_spec.template.spec.containers[0].command)):
- # CHECK FOR THE ubf_context in the command.
- # Replace the UBF context to the one appropriately matching control/worker.
- # Since we are passing the `step_cli` one time from the top level to one
- # KuberentesJobSet, we need to ensure that UBF context is replaced properly
- # in all the worker jobs.
- if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]:
- job_spec.template.spec.containers[0].command[idx] = (
- job_spec.template.spec.containers[0]
- .command[idx]
- .replace(UBF_CONTROL, UBF_CONTROL + " " + "--split-index 0")
- )
-
- job_spec.template.spec.containers[0].env = (
- job_spec.template.spec.containers[0].env
- + _jobset_specific_env_vars(client, jobset_main_addr, master_port, num_parallel)
- + [
- client.V1EnvVar(
- name="CONTROL_INDEX",
- value=str(0),
- )
- ]
- )
-
- # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178
- return dict(
- name="control",
- template=client.api_client.ApiClient().sanitize_for_serialization(
- client.V1JobTemplateSpec(
- metadata=client.V1ObjectMeta(
- namespace=namespace,
- # We don't set any annotations here
- # since they have been either set in the JobSpec
- # or on the JobSet level
- ),
- spec=job_spec,
- )
- ),
- replicas=1, # The control job will always have 1 replica.
- )
-
-
- def get_worker_job(
- client,
- job_spec,
- job_name,
- jobset_main_addr,
- subdomain,
- control_task_id=None,
- worker_task_id=None,
- replicas=1,
- port=None,
- num_parallel=None,
- namespace=None,
- annotations=None,
- ) -> dict:
- master_port = port
-
- job_spec = copy.deepcopy(job_spec)
- job_spec.parallelism = 1
- job_spec.completions = 1
- job_spec.template.spec.set_hostname_as_fqdn = True
- job_spec.template.spec.subdomain = subdomain
- job_spec.template.metadata.annotations = copy.copy(annotations)
-
- for idx in range(len(job_spec.template.spec.containers[0].command)):
- if control_task_id in job_spec.template.spec.containers[0].command[idx]:
- job_spec.template.spec.containers[0].command[idx] = (
- job_spec.template.spec.containers[0]
- .command[idx]
- .replace(control_task_id, worker_task_id)
- )
- # CHECK FOR THE ubf_context in the command.
- # Replace the UBF context to the one appropriately matching control/worker.
- # Since we are passing the `step_cli` one time from the top level to one
- # KuberentesJobSet, we need to ensure that UBF context is replaced properly
- # in all the worker jobs.
- if UBF_CONTROL in job_spec.template.spec.containers[0].command[idx]:
- # Since all command will have a UBF_CONTROL, we need to replace the UBF_CONTROL
- # with the actual UBF Context and also ensure that we are setting the correct
- # split-index for the worker jobs.
- split_index_str = "--split-index `expr $[WORKER_REPLICA_INDEX] + 1`" # This set in the environment variables below
- job_spec.template.spec.containers[0].command[idx] = (
- job_spec.template.spec.containers[0]
- .command[idx]
- .replace(UBF_CONTROL, UBF_TASK + " " + split_index_str)
- )
-
- job_spec.template.spec.containers[0].env = job_spec.template.spec.containers[
- 0
- ].env + _jobset_specific_env_vars(
- client, jobset_main_addr, master_port, num_parallel
- )
-
- # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L178
- return dict(
- name=job_name,
- template=client.api_client.ApiClient().sanitize_for_serialization(
- client.V1JobTemplateSpec(
- metadata=client.V1ObjectMeta(
- namespace=namespace,
- # We don't set any annotations here
- # since they have been either set in the JobSpec
- # or on the JobSet level
- ),
- spec=job_spec,
- )
- ),
- replicas=replicas,
- )
-
-
  def _make_domain_name(
  jobset_name, main_job_name, main_job_index, main_pod_index, namespace
  ):
@@ -658,97 +464,413 @@ def _make_domain_name(
  )


+ class JobSetSpec(object):
+ def __init__(self, kubernetes_sdk, name, **kwargs):
+ self._kubernetes_sdk = kubernetes_sdk
+ self._kwargs = kwargs
+ self.name = name
+
+ def replicas(self, replicas):
+ self._kwargs["replicas"] = replicas
+ return self
+
+ def step_name(self, step_name):
+ self._kwargs["step_name"] = step_name
+ return self
+
+ def namespace(self, namespace):
+ self._kwargs["namespace"] = namespace
+ return self
+
+ def command(self, command):
+ self._kwargs["command"] = command
+ return self
+
+ def image(self, image):
+ self._kwargs["image"] = image
+ return self
+
+ def cpu(self, cpu):
+ self._kwargs["cpu"] = cpu
+ return self
+
+ def memory(self, mem):
+ self._kwargs["memory"] = mem
+ return self
+
+ def environment_variable(self, name, value):
+ # Never set to None
+ if value is None:
+ return self
+ self._kwargs["environment_variables"] = dict(
+ self._kwargs.get("environment_variables", {}), **{name: value}
+ )
+ return self
+
+ def secret(self, name):
+ if name is None:
+ return self
+ if len(self._kwargs.get("secrets", [])) == 0:
+ self._kwargs["secrets"] = []
+ self._kwargs["secrets"] = list(set(self._kwargs["secrets"] + [name]))
+
+ def environment_variable_from_selector(self, name, label_value):
+ # Never set to None
+ if label_value is None:
+ return self
+ self._kwargs["environment_variables_from_selectors"] = dict(
+ self._kwargs.get("environment_variables_from_selectors", {}),
+ **{name: label_value}
+ )
+ return self
+
+ def label(self, name, value):
+ self._kwargs["labels"] = dict(self._kwargs.get("labels", {}), **{name: value})
+ return self
+
+ def annotation(self, name, value):
+ self._kwargs["annotations"] = dict(
+ self._kwargs.get("annotations", {}), **{name: value}
+ )
+ return self
+
+ def dump(self):
+ client = self._kubernetes_sdk
+ use_tmpfs = self._kwargs["use_tmpfs"]
+ tmpfs_size = self._kwargs["tmpfs_size"]
+ tmpfs_enabled = use_tmpfs or (tmpfs_size and not use_tmpfs)
+ shared_memory = (
+ int(self._kwargs["shared_memory"])
+ if self._kwargs["shared_memory"]
+ else None
+ )
+
+ return dict(
+ name=self.name,
+ template=client.api_client.ApiClient().sanitize_for_serialization(
+ client.V1JobTemplateSpec(
+ metadata=client.V1ObjectMeta(
+ namespace=self._kwargs["namespace"],
+ # We don't set any annotations here
+ # since they have been either set in the JobSpec
+ # or on the JobSet level
+ ),
+ spec=client.V1JobSpec(
+ # Retries are handled by Metaflow when it is responsible for
+ # executing the flow. The responsibility is moved to Kubernetes
+ # when Argo Workflows is responsible for the execution.
+ backoff_limit=self._kwargs.get("retries", 0),
+ completions=1,
+ parallelism=1,
+ ttl_seconds_after_finished=7
+ * 60
+ * 60 # Remove job after a week. TODO: Make this configurable
+ * 24,
+ template=client.V1PodTemplateSpec(
+ metadata=client.V1ObjectMeta(
+ annotations=self._kwargs.get("annotations", {}),
+ labels=self._kwargs.get("labels", {}),
+ namespace=self._kwargs["namespace"],
+ ),
+ spec=client.V1PodSpec(
+ ## --- jobset require podspec deets start----
+ subdomain=self._kwargs["subdomain"],
+ set_hostname_as_fqdn=True,
+ ## --- jobset require podspec deets end ----
+ # Timeout is set on the pod and not the job (important!)
+ active_deadline_seconds=self._kwargs[
+ "timeout_in_seconds"
+ ],
+ # TODO (savin): Enable affinities for GPU scheduling.
+ # affinity=?,
+ containers=[
+ client.V1Container(
+ command=self._kwargs["command"],
+ ports=[]
+ if self._kwargs["port"] is None
+ else [
+ client.V1ContainerPort(
+ container_port=int(self._kwargs["port"])
+ )
+ ],
+ env=[
+ client.V1EnvVar(name=k, value=str(v))
+ for k, v in self._kwargs.get(
+ "environment_variables", {}
+ ).items()
+ ]
+ # And some downward API magic. Add (key, value)
+ # pairs below to make pod metadata available
+ # within Kubernetes container.
+ + [
+ client.V1EnvVar(
+ name=k,
+ value_from=client.V1EnvVarSource(
+ field_ref=client.V1ObjectFieldSelector(
+ field_path=str(v)
+ )
+ ),
+ )
+ for k, v in self._kwargs.get(
+ "environment_variables_from_selectors",
+ {},
+ ).items()
+ ]
+ + [
+ client.V1EnvVar(name=k, value=str(v))
+ for k, v in inject_tracing_vars({}).items()
+ ],
+ env_from=[
+ client.V1EnvFromSource(
+ secret_ref=client.V1SecretEnvSource(
+ name=str(k),
+ # optional=True
+ )
+ )
+ for k in list(
+ self._kwargs.get("secrets", [])
+ )
+ if k
+ ],
+ image=self._kwargs["image"],
+ image_pull_policy=self._kwargs[
+ "image_pull_policy"
+ ],
+ name=self._kwargs["step_name"].replace(
+ "_", "-"
+ ),
+ resources=client.V1ResourceRequirements(
+ requests={
+ "cpu": str(self._kwargs["cpu"]),
+ "memory": "%sM"
+ % str(self._kwargs["memory"]),
+ "ephemeral-storage": "%sM"
+ % str(self._kwargs["disk"]),
+ },
+ limits={
+ "%s.com/gpu".lower()
+ % self._kwargs["gpu_vendor"]: str(
+ self._kwargs["gpu"]
+ )
+ for k in [0]
+ # Don't set GPU limits if gpu isn't specified.
+ if self._kwargs["gpu"] is not None
+ },
+ ),
+ volume_mounts=(
+ [
+ client.V1VolumeMount(
+ mount_path=self._kwargs.get(
+ "tmpfs_path"
+ ),
+ name="tmpfs-ephemeral-volume",
+ )
+ ]
+ if tmpfs_enabled
+ else []
+ )
+ + (
+ [
+ client.V1VolumeMount(
+ mount_path="/dev/shm", name="dhsm"
+ )
+ ]
+ if shared_memory
+ else []
+ )
+ + (
+ [
+ client.V1VolumeMount(
+ mount_path=path, name=claim
+ )
+ for claim, path in self._kwargs[
+ "persistent_volume_claims"
+ ].items()
+ ]
+ if self._kwargs["persistent_volume_claims"]
+ is not None
+ else []
+ ),
+ )
+ ],
+ node_selector=self._kwargs.get("node_selector"),
+ # TODO (savin): Support image_pull_secrets
+ # image_pull_secrets=?,
+ # TODO (savin): Support preemption policies
+ # preemption_policy=?,
+ #
+ # A Container in a Pod may fail for a number of
+ # reasons, such as because the process in it exited
+ # with a non-zero exit code, or the Container was
+ # killed due to OOM etc. If this happens, fail the pod
+ # and let Metaflow handle the retries.
+ restart_policy="Never",
+ service_account_name=self._kwargs["service_account"],
+ # Terminate the container immediately on SIGTERM
+ termination_grace_period_seconds=0,
+ tolerations=[
+ client.V1Toleration(**toleration)
+ for toleration in self._kwargs.get("tolerations")
+ or []
+ ],
+ volumes=(
+ [
+ client.V1Volume(
+ name="tmpfs-ephemeral-volume",
+ empty_dir=client.V1EmptyDirVolumeSource(
+ medium="Memory",
+ # Add default unit as ours differs from Kubernetes default.
+ size_limit="{}Mi".format(tmpfs_size),
+ ),
+ )
+ ]
+ if tmpfs_enabled
+ else []
+ )
+ + (
+ [
+ client.V1Volume(
+ name="dhsm",
+ empty_dir=client.V1EmptyDirVolumeSource(
+ medium="Memory",
+ size_limit="{}Mi".format(shared_memory),
+ ),
+ )
+ ]
+ if shared_memory
+ else []
+ )
+ + (
+ [
+ client.V1Volume(
+ name=claim,
+ persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
+ claim_name=claim
+ ),
+ )
+ for claim in self._kwargs[
+ "persistent_volume_claims"
+ ].keys()
+ ]
+ if self._kwargs["persistent_volume_claims"]
+ is not None
+ else []
+ ),
+ # TODO (savin): Set termination_message_policy
+ ),
+ ),
+ ),
+ )
+ ),
+ replicas=self._kwargs["replicas"],
+ )
+
+
  class KubernetesJobSet(object):
  def __init__(
  self,
  client,
  name=None,
- job_spec=None,
  namespace=None,
  num_parallel=None,
- annotations=None,
- labels=None,
- port=None,
- task_id=None,
+ # explcitly declaring num_parallel because we need to ensure that
+ # num_parallel is an INTEGER and this abstraction is called by the
+ # local runtime abstraction of kubernetes.
+ # Argo will call another abstraction that will allow setting a lot of these
+ # values from the top level argo code.
  **kwargs
  ):
  self._client = client
- self._kwargs = kwargs
+ self._annotations = {}
+ self._labels = {}
  self._group = KUBERNETES_JOBSET_GROUP
  self._version = KUBERNETES_JOBSET_VERSION
+ self._namespace = namespace
  self.name = name

- main_job_name = "control"
- main_job_index = 0
- main_pod_index = 0
- subdomain = self.name
- num_parallel = int(1 if not num_parallel else num_parallel)
- self._namespace = namespace
- jobset_main_addr = _make_domain_name(
- self.name,
- main_job_name,
- main_job_index,
- main_pod_index,
- self._namespace,
+ self._jobset_control_addr = _make_domain_name(
+ name,
+ CONTROL_JOB_NAME,
+ 0,
+ 0,
+ namespace,
  )

- annotations = {} if not annotations else annotations
- labels = {} if not labels else labels
-
- if "metaflow/task_id" in annotations:
- del annotations["metaflow/task_id"]
-
- control_job = get_control_job(
- client=self._client.get(),
- job_spec=job_spec,
- jobset_main_addr=jobset_main_addr,
- subdomain=subdomain,
- port=port,
- num_parallel=num_parallel,
- namespace=namespace,
- annotations=annotations,
+ self._control_spec = JobSetSpec(
+ client.get(), name=CONTROL_JOB_NAME, namespace=namespace, **kwargs
  )
- worker_task_id = TaskIdConstructor.jobset_worker_id(task_id)
- worker_job = get_worker_job(
- client=self._client.get(),
- job_spec=job_spec,
- job_name="worker",
- jobset_main_addr=jobset_main_addr,
- subdomain=subdomain,
- control_task_id=task_id,
- worker_task_id=worker_task_id,
- replicas=num_parallel - 1,
- port=port,
- num_parallel=num_parallel,
- namespace=namespace,
- annotations=annotations,
+ self._worker_spec = JobSetSpec(
+ client.get(), name="worker", namespace=namespace, **kwargs
  )
- worker_jobs = [worker_job]
- # Based on https://github.com/kubernetes-sigs/jobset/blob/v0.5.0/api/jobset/v1alpha2/jobset_types.go#L163
- _kclient = client.get()
- self._jobset = dict(
+ assert (
+ type(num_parallel) == int
+ ), "num_parallel must be an integer" # todo: [final-refactor] : fix-me
+
+ @property
+ def jobset_control_addr(self):
+ return self._jobset_control_addr
+
+ @property
+ def worker(self):
+ return self._worker_spec
+
+ @property
+ def control(self):
+ return self._control_spec
+
+ def environment_variable_from_selector(self, name, label_value):
+ self.worker.environment_variable_from_selector(name, label_value)
+ self.control.environment_variable_from_selector(name, label_value)
+ return self
+
+ def environment_variables_from_selectors(self, env_dict):
+ for name, label_value in env_dict.items():
+ self.worker.environment_variable_from_selector(name, label_value)
+ self.control.environment_variable_from_selector(name, label_value)
+ return self
+
+ def environment_variable(self, name, value):
+ self.worker.environment_variable(name, value)
+ self.control.environment_variable(name, value)
+ return self
+
+ def label(self, name, value):
+ self.worker.label(name, value)
+ self.control.label(name, value)
+ self._labels = dict(self._labels, **{name: value})
+ return self
+
+ def annotation(self, name, value):
+ self.worker.annotation(name, value)
+ self.control.annotation(name, value)
+ self._annotations = dict(self._annotations, **{name: value})
+ return self
+
+ def secret(self, name):
+ self.worker.secret(name)
+ self.control.secret(name)
+ return self
+
+ def dump(self):
+ client = self._client.get()
+ return dict(
  apiVersion=self._group + "/" + self._version,
  kind="JobSet",
- metadata=_kclient.api_client.ApiClient().sanitize_for_serialization(
- _kclient.V1ObjectMeta(
- name=self.name, labels=labels, annotations=annotations
+ metadata=client.api_client.ApiClient().sanitize_for_serialization(
+ client.V1ObjectMeta(
+ name=self.name,
+ labels=self._labels,
+ annotations=self._annotations,
  )
  ),
  spec=dict(
- replicatedJobs=[control_job] + worker_jobs,
+ replicatedJobs=[self.control.dump(), self.worker.dump()],
  suspend=False,
  startupPolicy=None,
  successPolicy=None,
  # The Failure Policy helps setting the number of retries for the jobset.
- # It cannot accept a value of 0 for maxRestarts.
- # So the attempt needs to be smartly set.
- # If there is no retry decorator then we not set maxRestarts and instead we will
- # set the attempt statically to 0. Otherwise we will make the job pickup the attempt
- # from the `V1EnvVarSource.value_from.V1ObjectFieldSelector.field_path` = "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']"
- # failurePolicy={
- # "maxRestarts" : 1
- # },
- # The can be set for ArgoWorkflows
+ # but we don't rely on it and instead rely on either the local scheduler
+ # or the Argo Workflows to handle retries.
  failurePolicy=None,
  network=None,
  ),
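The hunk above replaces the old `get_control_job` / `get_worker_job` helpers with a chainable `JobSetSpec` builder plus shared setters on `KubernetesJobSet`. A hedged usage sketch, not Metaflow's actual call site: the names and values below are hypothetical, and the real caller passes many more kwargs (`image`, `cpu`, `memory`, `command`, `use_tmpfs`, `tmpfs_size`, and so on) that `JobSetSpec.dump()` expects:

```python
# Hypothetical driver code for the builder API shown in this hunk.
jobset = KubernetesJobSet(
    client=kubernetes_client,  # Metaflow's Kubernetes client wrapper (assumed)
    name="js-example",         # hypothetical JobSet name
    namespace="default",
    num_parallel=4,
    # ...plus the per-step kwargs that JobSetSpec.dump() reads (image, cpu, memory, command, ...)
)
# Setters fan out to both the control and the worker spec.
jobset.environment_variable("WORLD_SIZE", "4")
jobset.environment_variable("MASTER_ADDR", jobset.jobset_control_addr)
jobset.label("app.kubernetes.io/managed-by", "metaflow")
jobset.annotation("metaflow/run_id", "example-run")  # hypothetical value
jobset.secret("my-k8s-secret")                       # hypothetical secret name
# Per-replica differences are applied directly on .control / .worker.
jobset.control.replicas(1)
jobset.worker.replicas(3)  # num_parallel - 1 workers
manifest = jobset.dump()   # JobSet custom resource as a dict, ready to submit
```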
@@ -767,7 +889,7 @@ class KubernetesJobSet(object):
  version=self._version,
  namespace=self._namespace,
  plural="jobsets",
- body=self._jobset,
+ body=self.dump(),
  )
  except Exception as e:
  raise KubernetesJobsetException(
@@ -782,3 +904,119 @@ class KubernetesJobSet(object):
  group=self._group,
  version=self._version,
  )
+
+
+ class KubernetesArgoJobSet(object):
+ def __init__(self, kubernetes_sdk, name=None, namespace=None, **kwargs):
+ self._kubernetes_sdk = kubernetes_sdk
+ self._annotations = {}
+ self._labels = {}
+ self._group = KUBERNETES_JOBSET_GROUP
+ self._version = KUBERNETES_JOBSET_VERSION
+ self._namespace = namespace
+ self.name = name
+
+ self._jobset_control_addr = _make_domain_name(
+ name,
+ CONTROL_JOB_NAME,
+ 0,
+ 0,
+ namespace,
+ )
+
+ self._control_spec = JobSetSpec(
+ kubernetes_sdk, name=CONTROL_JOB_NAME, namespace=namespace, **kwargs
+ )
+ self._worker_spec = JobSetSpec(
+ kubernetes_sdk, name="worker", namespace=namespace, **kwargs
+ )
+
+ @property
+ def jobset_control_addr(self):
+ return self._jobset_control_addr
+
+ @property
+ def worker(self):
+ return self._worker_spec
+
+ @property
+ def control(self):
+ return self._control_spec
+
+ def environment_variable_from_selector(self, name, label_value):
+ self.worker.environment_variable_from_selector(name, label_value)
+ self.control.environment_variable_from_selector(name, label_value)
+ return self
+
+ def environment_variables_from_selectors(self, env_dict):
+ for name, label_value in env_dict.items():
+ self.worker.environment_variable_from_selector(name, label_value)
+ self.control.environment_variable_from_selector(name, label_value)
+ return self
+
+ def environment_variable(self, name, value):
+ self.worker.environment_variable(name, value)
+ self.control.environment_variable(name, value)
+ return self
+
+ def label(self, name, value):
+ self.worker.label(name, value)
+ self.control.label(name, value)
+ self._labels = dict(self._labels, **{name: value})
+ return self
+
+ def annotation(self, name, value):
+ self.worker.annotation(name, value)
+ self.control.annotation(name, value)
+ self._annotations = dict(self._annotations, **{name: value})
+ return self
+
+ def dump(self):
+ client = self._kubernetes_sdk
+ import json
+
+ data = json.dumps(
+ client.ApiClient().sanitize_for_serialization(
+ dict(
+ apiVersion=self._group + "/" + self._version,
+ kind="JobSet",
+ metadata=client.api_client.ApiClient().sanitize_for_serialization(
+ client.V1ObjectMeta(
+ name=self.name,
+ labels=self._labels,
+ annotations=self._annotations,
+ )
+ ),
+ spec=dict(
+ replicatedJobs=[self.control.dump(), self.worker.dump()],
+ suspend=False,
+ startupPolicy=None,
+ successPolicy=None,
+ # The Failure Policy helps setting the number of retries for the jobset.
+ # but we don't rely on it and instead rely on either the local scheduler
+ # or the Argo Workflows to handle retries.
+ failurePolicy=None,
+ network=None,
+ ),
+ status=None,
+ )
+ )
+ )
+ # The values we populate in the Jobset manifest (for Argo Workflows) piggybacks on the Argo Workflow's templating engine.
+ # Even though Argo Workflows's templating helps us constructing all the necessary IDs and populating the fields
+ # required by Metaflow, we run into one glitch. When we construct JSON/YAML serializable objects,
+ # anything between two braces such as `{{=asInt(inputs.parameters.workerCount)}}` gets quoted. This is a problem
+ # since we need to pass the value of `inputs.parameters.workerCount` as an integer and not as a string.
+ # If we pass it as a string, the jobset controller will not accept the Jobset CRD we submitted to kubernetes.
+ # To get around this, we need to replace the quoted substring with the unquoted substring because YAML /JSON parsers
+ # won't allow deserialization with the quoting trivially.
+
+ # This is super important because the `inputs.parameters.workerCount` is used to set the number of replicas;
+ # The value for number of replicas is derived from the value of `num_parallel` (which is set in the user-code).
+ # Since the value of `num_parallel` can be dynamic and can change from run to run, we need to ensure that the
+ # value can be passed-down dynamically and is **explicitly set as a integer** in the Jobset Manifest submitted as a
+ # part of the Argo Workflow
+
+ quoted_substring = '"{{=asInt(inputs.parameters.workerCount)}}"'
+ unquoted_substring = "{{=asInt(inputs.parameters.workerCount)}}"
+ return data.replace(quoted_substring, unquoted_substring)
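A small self-contained illustration of the un-quoting step described in the comments above, using a hypothetical manifest fragment and only the standard library. After JSON serialization the Argo expression is a quoted string; the final `replace` turns it back into a bare expression so the JobSet controller receives an integer once Argo Workflows substitutes the value at runtime:

```python
import json

# A replicated-job entry as produced by sanitize_for_serialization: the Argo
# template expression is an ordinary Python string, so json.dumps() quotes it.
spec = {"name": "worker", "replicas": "{{=asInt(inputs.parameters.workerCount)}}"}
data = json.dumps(spec)
# {"name": "worker", "replicas": "{{=asInt(inputs.parameters.workerCount)}}"}

quoted_substring = '"{{=asInt(inputs.parameters.workerCount)}}"'
unquoted_substring = "{{=asInt(inputs.parameters.workerCount)}}"
manifest = data.replace(quoted_substring, unquoted_substring)
# {"name": "worker", "replicas": {{=asInt(inputs.parameters.workerCount)}}}
print(manifest)
```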