metaflow 2.12.8__py2.py3-none-any.whl → 2.12.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. metaflow/__init__.py +2 -0
  2. metaflow/cli.py +12 -4
  3. metaflow/extension_support/plugins.py +1 -0
  4. metaflow/flowspec.py +8 -1
  5. metaflow/lint.py +13 -0
  6. metaflow/metaflow_current.py +0 -8
  7. metaflow/plugins/__init__.py +12 -0
  8. metaflow/plugins/argo/argo_workflows.py +462 -42
  9. metaflow/plugins/argo/argo_workflows_cli.py +60 -3
  10. metaflow/plugins/argo/argo_workflows_decorator.py +38 -7
  11. metaflow/plugins/argo/argo_workflows_deployer.py +290 -0
  12. metaflow/plugins/argo/jobset_input_paths.py +16 -0
  13. metaflow/plugins/aws/batch/batch_decorator.py +16 -13
  14. metaflow/plugins/aws/step_functions/step_functions_cli.py +45 -3
  15. metaflow/plugins/aws/step_functions/step_functions_deployer.py +251 -0
  16. metaflow/plugins/cards/card_cli.py +1 -1
  17. metaflow/plugins/kubernetes/kubernetes.py +279 -52
  18. metaflow/plugins/kubernetes/kubernetes_cli.py +26 -8
  19. metaflow/plugins/kubernetes/kubernetes_client.py +0 -1
  20. metaflow/plugins/kubernetes/kubernetes_decorator.py +56 -44
  21. metaflow/plugins/kubernetes/kubernetes_job.py +6 -6
  22. metaflow/plugins/kubernetes/kubernetes_jobsets.py +510 -272
  23. metaflow/plugins/parallel_decorator.py +108 -8
  24. metaflow/plugins/secrets/secrets_decorator.py +12 -3
  25. metaflow/plugins/test_unbounded_foreach_decorator.py +39 -4
  26. metaflow/runner/deployer.py +386 -0
  27. metaflow/runner/metaflow_runner.py +1 -20
  28. metaflow/runner/nbdeploy.py +130 -0
  29. metaflow/runner/nbrun.py +4 -28
  30. metaflow/runner/utils.py +49 -0
  31. metaflow/runtime.py +246 -134
  32. metaflow/version.py +1 -1
  33. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/METADATA +2 -2
  34. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/RECORD +38 -32
  35. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/WHEEL +1 -1
  36. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/LICENSE +0 -0
  37. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/entry_points.txt +0 -0
  38. {metaflow-2.12.8.dist-info → metaflow-2.12.9.dist-info}/top_level.txt +0 -0
@@ -4,11 +4,13 @@ import os
  import re
  import shlex
  import sys
+ from typing import Tuple, List
  from collections import defaultdict
  from hashlib import sha1
  from math import inf

  from metaflow import JSONType, current
+ from metaflow.graph import DAGNode
  from metaflow.decorators import flow_decorators
  from metaflow.exception import MetaflowException
  from metaflow.includefile import FilePathClass
@@ -47,6 +49,7 @@ from metaflow.metaflow_config import (
  SERVICE_INTERNAL_URL,
  UI_URL,
  )
+ from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
  from metaflow.metaflow_config_funcs import config_values
  from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
  from metaflow.parameters import deploy_time_eval
@@ -54,6 +57,7 @@ from metaflow.plugins.kubernetes.kubernetes import (
  parse_kube_keyvalue_list,
  validate_kube_labels,
  )
+ from metaflow.graph import FlowGraph
  from metaflow.util import (
  compress_list,
  dict_to_cli_options,
@@ -61,6 +65,9 @@ from metaflow.util import (
  to_camelcase,
  to_unicode,
  )
+ from metaflow.plugins.kubernetes.kubernetes_jobsets import (
+ KubernetesArgoJobSet,
+ )

  from .argo_client import ArgoClient

@@ -82,14 +89,14 @@ class ArgoWorkflowsSchedulingException(MetaflowException):
  # 5. Add Metaflow tags to labels/annotations.
  # 6. Support Multi-cluster scheduling - https://github.com/argoproj/argo-workflows/issues/3523#issuecomment-792307297
  # 7. Support R lang.
- # 8. Ping @savin at slack.outerbounds.co for any feature request.
+ # 8. Ping @savin at slack.outerbounds.co for any feature request


  class ArgoWorkflows(object):
  def __init__(
  self,
  name,
- graph,
+ graph: FlowGraph,
  flow,
  code_package_sha,
  code_package_url,
@@ -852,13 +859,13 @@ class ArgoWorkflows(object):
  # Visit every node and yield the uber DAGTemplate(s).
  def _dag_templates(self):
  def _visit(
- node, exit_node=None, templates=None, dag_tasks=None, parent_foreach=None
- ):
- if node.parallel_foreach:
- raise ArgoWorkflowsException(
- "Deploying flows with @parallel decorator(s) "
- "as Argo Workflows is not supported currently."
- )
+ node,
+ exit_node=None,
+ templates=None,
+ dag_tasks=None,
+ parent_foreach=None,
+ ): # Returns Tuple[List[Template], List[DAGTask]]
+ """ """
  # Every for-each node results in a separate subDAG and an equivalent
  # DAGTemplate rooted at the child of the for-each node. Each DAGTemplate
  # has a unique name - the top-level DAGTemplate is named as the name of
@@ -872,7 +879,6 @@
  templates = []
  if exit_node is not None and exit_node is node.name:
  return templates, dag_tasks
-
  if node.name == "start":
  # Start node has no dependencies.
  dag_task = DAGTask(self._sanitize(node.name)).template(
@@ -881,13 +887,86 @@
  elif (
  node.is_inside_foreach
  and self.graph[node.in_funcs[0]].type == "foreach"
+ and not self.graph[node.in_funcs[0]].parallel_foreach
+ # We need to distinguish what is a "regular" foreach (i.e something that doesn't care about to gang semantics)
+ # vs what is a "num_parallel" based foreach (i.e. something that follows gang semantics.)
+ # A `regular` foreach is basically any arbitrary kind of foreach.
  ):
  # Child of a foreach node needs input-paths as well as split-index
  # This child is the first node of the sub workflow and has no dependency
+
  parameters = [
  Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
  Parameter("split-index").value("{{inputs.parameters.split-index}}"),
  ]
+ dag_task = (
+ DAGTask(self._sanitize(node.name))
+ .template(self._sanitize(node.name))
+ .arguments(Arguments().parameters(parameters))
+ )
+ elif node.parallel_step:
+ # This is the step where the @parallel decorator is defined.
+ # Since this DAGTask will call the for the `resource` [based templates]
+ # (https://argo-workflows.readthedocs.io/en/stable/walk-through/kubernetes-resources/)
+ # we have certain constraints on the way we can pass information inside the Jobset manifest
+ # [All templates will have access](https://argo-workflows.readthedocs.io/en/stable/variables/#all-templates)
+ # to the `inputs.parameters` so we will pass down ANY/ALL information using the
+ # input parameters.
+ # We define the usual parameters like input-paths/split-index etc. but we will also
+ # define the following:
+ # - `workerCount`: parameter which will be used to determine the number of
+ # parallel worker jobs
+ # - `jobset-name`: parameter which will be used to determine the name of the jobset.
+ # This parameter needs to be dynamic so that when we have retries we don't
+ # end up using the name of the jobset again (if we do, it will crash since k8s wont allow duplicated job names)
+ # - `retryCount`: parameter which will be used to determine the number of retries
+ # This parameter will *only* be available within the container templates like we
+ # have it for all other DAGTasks and NOT for custom kubernetes resource templates.
+ # So as a work-around, we will set it as the `retryCount` parameter instead of
+ # setting it as a {{ retries }} in the CLI code. Once set as a input parameter,
+ # we can use it in the Jobset Manifest templates as `{{inputs.parameters.retryCount}}`
+ # - `task-id-entropy`: This is a parameter which will help derive task-ids and jobset names. This parameter
+ # contains the relevant amount of entropy to ensure that task-ids and jobset names
+ # are uniquish. We will also use this in the join task to construct the task-ids of
+ # all parallel tasks since the task-ids for parallel task are minted formulaically.
+ parameters = [
+ Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
+ Parameter("num-parallel").value(
+ "{{inputs.parameters.num-parallel}}"
+ ),
+ Parameter("split-index").value("{{inputs.parameters.split-index}}"),
+ Parameter("task-id-entropy").value(
+ "{{inputs.parameters.task-id-entropy}}"
+ ),
+ # we cant just use hyphens with sprig.
+ # https://github.com/argoproj/argo-workflows/issues/10567#issuecomment-1452410948
+ Parameter("workerCount").value(
+ "{{=sprig.int(sprig.sub(sprig.int(inputs.parameters['num-parallel']),1))}}"
+ ),
+ ]
+ if any(d.name == "retry" for d in node.decorators):
+ parameters.extend(
+ [
+ Parameter("retryCount").value("{{retries}}"),
+ # The job-setname needs to be unique for each retry
+ # and we cannot use the `generateName` field in the
+ # Jobset Manifest since we need to construct the subdomain
+ # and control pod domain name pre-hand. So we will use
+ # the retry count to ensure that the jobset name is unique
+ Parameter("jobset-name").value(
+ "js-{{inputs.parameters.task-id-entropy}}{{retries}}",
+ ),
+ ]
+ )
+ else:
+ parameters.extend(
+ [
+ Parameter("jobset-name").value(
+ "js-{{inputs.parameters.task-id-entropy}}",
+ )
+ ]
+ )
+
  dag_task = (
  DAGTask(self._sanitize(node.name))
  .template(self._sanitize(node.name))
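To see in one place what the parameter wiring above amounts to, here is a small standalone sketch; build_parallel_parameters is a hypothetical helper (not metaflow code) that mirrors the derivation described in the comments: workerCount is num-parallel minus the one control task, and jobset-name only gains a retry suffix when a @retry decorator is present.

    # Illustrative sketch only; mirrors the parameter derivation in the diff above.
    def build_parallel_parameters(num_parallel, task_id_entropy, retries=None):
        params = {
            "num-parallel": str(num_parallel),
            "task-id-entropy": task_id_entropy,
            # workerCount excludes the single control task, hence num_parallel - 1
            # (the sprig expression in the template computes the same value).
            "workerCount": str(num_parallel - 1),
        }
        if retries is not None:
            # A retried jobset must never reuse the previous attempt's name.
            params["retryCount"] = str(retries)
            params["jobset-name"] = "js-%s%d" % (task_id_entropy, retries)
        else:
            params["jobset-name"] = "js-%s" % task_id_entropy
        return params

    print(build_parallel_parameters(4, "ab12cd34"))             # workerCount == "3"
    print(build_parallel_parameters(4, "ab12cd34", retries=1))  # jobset-name == "js-ab12cd341"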
@@ -947,8 +1026,8 @@
  .template(self._sanitize(node.name))
  .arguments(Arguments().parameters(parameters))
  )
- dag_tasks.append(dag_task)

+ dag_tasks.append(dag_task)
  # End the workflow if we have reached the end of the flow
  if node.type == "end":
  return [
@@ -974,14 +1053,30 @@
  parent_foreach,
  )
  # For foreach nodes generate a new sub DAGTemplate
+ # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
  elif node.type == "foreach":
  foreach_template_name = self._sanitize(
  "%s-foreach-%s"
  % (
  node.name,
- node.foreach_param,
+ "parallel" if node.parallel_foreach else node.foreach_param
+ # Since foreach's are derived based on `self.next(self.a, foreach="<varname>")`
+ # vs @parallel foreach are done based on `self.next(self.a, num_parallel="<some-number>")`,
+ # we need to ensure that `foreach_template_name` suffix is appropriately set based on the kind
+ # of foreach.
  )
  )
+
+ # There are two separate "DAGTask"s created for the foreach node.
+ # - The first one is a "jump-off" DAGTask where we propagate the
+ # input-paths and split-index. This thing doesn't create
+ # any actual containers and it responsible for only propagating
+ # the parameters.
+ # - The DAGTask that follows first DAGTask is the one
+ # that uses the ContainerTemplate. This DAGTask is named the same
+ # thing as the foreach node. We will leverage a similar pattern for the
+ # @parallel tasks.
+ #
  foreach_task = (
  DAGTask(foreach_template_name)
  .dependencies([self._sanitize(node.name)])
@@ -1005,9 +1100,26 @@
  if parent_foreach
  else []
  )
+ + (
+ # Disabiguate parameters for a regular `foreach` vs a `@parallel` foreach
+ [
+ Parameter("num-parallel").value(
+ "{{tasks.%s.outputs.parameters.num-parallel}}"
+ % self._sanitize(node.name)
+ ),
+ Parameter("task-id-entropy").value(
+ "{{tasks.%s.outputs.parameters.task-id-entropy}}"
+ % self._sanitize(node.name)
+ ),
+ ]
+ if node.parallel_foreach
+ else []
+ )
  )
  )
  .with_param(
+ # For @parallel workloads `num-splits` will be explicitly set to one so that
+ # we can piggyback on the current mechanism with which we leverage argo.
  "{{tasks.%s.outputs.parameters.num-splits}}"
  % self._sanitize(node.name)
  )
@@ -1020,17 +1132,34 @@
  [],
  node.name,
  )
+
+ # How do foreach's work on Argo:
+ # Lets say you have the following dag: (start[sets `foreach="x"`]) --> (task-a [actual foreach]) --> (join) --> (end)
+ # With argo we will :
+ # (start [sets num-splits]) --> (task-a-foreach-(0,0) [dummy task]) --> (task-a) --> (join) --> (end)
+ # The (task-a-foreach-(0,0) [dummy task]) propagates the values of the `split-index` and the input paths.
+ # to the actual foreach task.
  templates.append(
  Template(foreach_template_name)
  .inputs(
  Inputs().parameters(
  [Parameter("input-paths"), Parameter("split-index")]
  + ([Parameter("root-input-path")] if parent_foreach else [])
+ + (
+ [
+ Parameter("num-parallel"),
+ Parameter("task-id-entropy"),
+ # Parameter("workerCount")
+ ]
+ if node.parallel_foreach
+ else []
+ )
  )
  )
  .outputs(
  Outputs().parameters(
  [
+ # non @parallel tasks set task-ids as outputs
  Parameter("task-id").valueFrom(
  {
  "parameter": "{{tasks.%s.outputs.parameters.task-id}}"
@@ -1040,29 +1169,67 @@
  }
  )
  ]
+ if not node.parallel_foreach
+ else [
+ # @parallel tasks set `task-id-entropy` and `num-parallel`
+ # as outputs so task-ids can be derived in the join step.
+ # Both of these values should be propagated from the
+ # jobset labels.
+ Parameter("num-parallel").valueFrom(
+ {
+ "parameter": "{{tasks.%s.outputs.parameters.num-parallel}}"
+ % self._sanitize(
+ self.graph[node.matching_join].in_funcs[0]
+ )
+ }
+ ),
+ Parameter("task-id-entropy").valueFrom(
+ {
+ "parameter": "{{tasks.%s.outputs.parameters.task-id-entropy}}"
+ % self._sanitize(
+ self.graph[node.matching_join].in_funcs[0]
+ )
+ }
+ ),
+ ]
  )
  )
  .dag(DAGTemplate().fail_fast().tasks(dag_tasks_1))
  )
+
  join_foreach_task = (
  DAGTask(self._sanitize(self.graph[node.matching_join].name))
  .template(self._sanitize(self.graph[node.matching_join].name))
  .dependencies([foreach_template_name])
  .arguments(
  Arguments().parameters(
- [
- Parameter("input-paths").value(
- "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
- % (node.name, self._sanitize(node.name))
- ),
- Parameter("split-cardinality").value(
- "{{tasks.%s.outputs.parameters.split-cardinality}}"
- % self._sanitize(node.name)
- ),
- ]
+ (
+ [
+ Parameter("input-paths").value(
+ "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
+ % (node.name, self._sanitize(node.name))
+ ),
+ Parameter("split-cardinality").value(
+ "{{tasks.%s.outputs.parameters.split-cardinality}}"
+ % self._sanitize(node.name)
+ ),
+ ]
+ if not node.parallel_foreach
+ else [
+ Parameter("num-parallel").value(
+ "{{tasks.%s.outputs.parameters.num-parallel}}"
+ % self._sanitize(node.name)
+ ),
+ Parameter("task-id-entropy").value(
+ "{{tasks.%s.outputs.parameters.task-id-entropy}}"
+ % self._sanitize(node.name)
+ ),
+ ]
+ )
  + (
  [
  Parameter("split-index").value(
+ # TODO : Pass down these parameters to the jobset stuff.
  "{{inputs.parameters.split-index}}"
  ),
  Parameter("root-input-path").value(
@@ -1140,7 +1307,17 @@
  # export input_paths as it is used multiple times in the container script
  # and we do not want to repeat the values.
  input_paths_expr = "export INPUT_PATHS=''"
- if node.name != "start":
+ # If node is not a start step or a @parallel join then we will set the input paths.
+ # To set the input-paths as a parameter, we need to ensure that the node
+ # is not (a start node or a parallel join node). Start nodes will have no
+ # input paths and parallel join will derive input paths based on a
+ # formulaic approach using `num-parallel` and `task-id-entropy`.
+ if not (
+ node.name == "start"
+ or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step)
+ ):
+ # For parallel joins we don't pass the INPUT_PATHS but are dynamically constructed.
+ # So we don't need to set the input paths.
  input_paths_expr = (
  "export INPUT_PATHS={{inputs.parameters.input-paths}}"
  )
@@ -1169,13 +1346,23 @@
  task_idx,
  ]
  )
+ if node.parallel_step:
+ task_str = "-".join(
+ [
+ "$TASK_ID_PREFIX",
+ "{{inputs.parameters.task-id-entropy}}", # id_base is addition entropy to based on node-name of the workflow
+ "$TASK_ID_SUFFIX",
+ ]
+ )
+ else:
+ # Generated task_ids need to be non-numeric - see register_task_id in
+ # service.py. We do so by prefixing `t-`
+ _task_id_base = (
+ "$(echo %s | md5sum | cut -d ' ' -f 1 | tail -c 9)" % task_str
+ )
+ task_str = "(t-%s)" % _task_id_base

- # Generated task_ids need to be non-numeric - see register_task_id in
- # service.py. We do so by prefixing `t-`
- task_id_expr = (
- "export METAFLOW_TASK_ID="
- "(t-$(echo %s | md5sum | cut -d ' ' -f 1 | tail -c 9))" % task_str
- )
+ task_id_expr = "export METAFLOW_TASK_ID=" "%s" % task_str
  task_id = "$METAFLOW_TASK_ID"

  # Resolve retry strategy.
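The two task-id schemes above can be sketched in plain Python; derive_task_id is a hypothetical stand-in for the shell expressions emitted into the template (md5 hash trimmed to a short tail for regular steps, prefix-entropy-suffix for @parallel steps).

    import hashlib

    # Illustrative sketch of the two task-id schemes in the diff above; the real
    # values are computed by shell snippets at runtime inside the Argo templates.
    def derive_task_id(task_str, parallel=False, prefix="worker", entropy="ab12cd34", suffix="3"):
        if parallel:
            # @parallel: formulaic id, so the join step can reconstruct it later.
            return "%s-%s-%s" % (prefix, entropy, suffix)
        # Regular steps: hash, keep a short tail, and prefix with "t-" so the id
        # is never purely numeric (see register_task_id in service.py).
        digest = hashlib.md5(task_str.encode()).hexdigest()
        return "t-%s" % digest[-8:]

    print(derive_task_id("argo-myflow-abc/train/0"))                        # e.g. t-1a2b3c4d
    print(derive_task_id("", parallel=True, prefix="control", suffix="0"))  # control-ab12cd34-0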
@@ -1194,9 +1381,20 @@
  user_code_retries = max_user_code_retries
  total_retries = max_user_code_retries + max_error_retries
  # {{retries}} is only available if retryStrategy is specified
+ # and they are only available in the container templates NOT for custom
+ # Kubernetes manifests like Jobsets.
+ # For custom kubernetes manifests, we will pass the retryCount as a parameter
+ # and use that in the manifest.
  retry_count = (
- "{{retries}}" if max_user_code_retries + max_error_retries else 0
+ (
+ "{{retries}}"
+ if not node.parallel_step
+ else "{{inputs.parameters.retryCount}}"
+ )
+ if total_retries
+ else 0
  )
+
  minutes_between_retries = int(minutes_between_retries)

  # Configure log capture.
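In short, the expression above picks where the retry attempt number comes from per template type; a minimal sketch (hypothetical helper, not metaflow code):

    # Container templates can read Argo's {{retries}} directly; custom resource
    # templates (Jobsets) cannot, so they receive retryCount as an input parameter.
    def retry_count_expr(total_retries, parallel_step):
        if not total_retries:
            return 0
        return "{{inputs.parameters.retryCount}}" if parallel_step else "{{retries}}"

    assert retry_count_expr(0, False) == 0
    assert retry_count_expr(2, False) == "{{retries}}"
    assert retry_count_expr(2, True) == "{{inputs.parameters.retryCount}}"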
@@ -1302,13 +1500,24 @@
  foreach_step = next(
  n for n in node.in_funcs if self.graph[n].is_inside_foreach
  )
- input_paths = (
- "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})"
- % (
- foreach_step,
- input_paths,
+ if not self.graph[node.split_parents[-1]].parallel_foreach:
+ input_paths = (
+ "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})"
+ % (
+ foreach_step,
+ input_paths,
+ )
+ )
+ else:
+ # When we run Jobsets with Argo Workflows we need to ensure that `input_paths` are generated using the a formulaic approach
+ # because our current strategy of using volume mounts for outputs won't work with Jobsets
+ input_paths = (
+ "$(python -m metaflow.plugins.argo.jobset_input_paths %s %s {{inputs.parameters.task-id-entropy}} {{inputs.parameters.num-parallel}})"
+ % (
+ run_id,
+ foreach_step,
+ )
  )
- )
  step = [
  "step",
  node.name,
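Because @parallel tasks run inside a Jobset, the join step cannot read their task-ids from volume-mounted outputs and instead rebuilds the input paths from num-parallel and task-id-entropy. A hypothetical sketch of that formulaic enumeration, using the control/worker naming convention set up later in this diff (the real metaflow.plugins.argo.jobset_input_paths module may format its output differently):

    # Hypothetical sketch: enumerate a @parallel step's input paths formulaically.
    def jobset_input_paths(run_id, step_name, task_id_entropy, num_parallel):
        task_ids = ["control-%s-0" % task_id_entropy] + [
            "worker-%s-%d" % (task_id_entropy, i) for i in range(int(num_parallel) - 1)
        ]
        return ",".join("%s/%s/%s" % (run_id, step_name, t) for t in task_ids)

    print(jobset_input_paths("argo-myflow-abc", "train", "ab12cd34", 3))
    # argo-myflow-abc/train/control-ab12cd34-0,argo-myflow-abc/train/worker-ab12cd34-0,...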
@@ -1318,7 +1527,14 @@
  "--max-user-code-retries %d" % user_code_retries,
  "--input-paths %s" % input_paths,
  ]
- if any(self.graph[n].type == "foreach" for n in node.in_funcs):
+ if node.parallel_step:
+ step.append(
+ "--split-index ${MF_CONTROL_INDEX:-$((MF_WORKER_REPLICA_INDEX + 1))}"
+ )
+ # This is needed for setting the value of the UBF context in the CLI.
+ step.append("--ubf-context $UBF_CONTEXT")
+
+ elif any(self.graph[n].type == "foreach" for n in node.in_funcs):
  # Pass split-index to a foreach task
  step.append("--split-index {{inputs.parameters.split-index}}")
  if self.tags:
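The --split-index value above leans on shell parameter expansion: the control pod exports MF_CONTROL_INDEX=0, while workers only have MF_WORKER_REPLICA_INDEX and therefore fall back to index + 1. The same rule in Python, as a hypothetical illustration:

    # Mirrors ${MF_CONTROL_INDEX:-$((MF_WORKER_REPLICA_INDEX + 1))} from the step command.
    def resolve_split_index(control_index=None, worker_replica_index=None):
        if control_index is not None:          # control pod: MF_CONTROL_INDEX is set to "0"
            return int(control_index)
        return int(worker_replica_index) + 1   # workers: shifted by one past the control task

    assert resolve_split_index(control_index="0") == 0
    assert resolve_split_index(worker_replica_index="2") == 3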
@@ -1481,17 +1697,47 @@
  # join task deterministically inside the join task without resorting to
  # passing a rather long list of (albiet compressed)
  inputs = []
- if node.name != "start":
+ # To set the input-paths as a parameter, we need to ensure that the node
+ # is not (a start node or a parallel join node). Start nodes will have no
+ # input paths and parallel join will derive input paths based on a
+ # formulaic approach.
+ if not (
+ node.name == "start"
+ or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step)
+ ):
  inputs.append(Parameter("input-paths"))
  if any(self.graph[n].type == "foreach" for n in node.in_funcs):
  # Fetch split-index from parent
  inputs.append(Parameter("split-index"))
+
  if (
  node.type == "join"
  and self.graph[node.split_parents[-1]].type == "foreach"
  ):
- # append this only for joins of foreaches, not static splits
- inputs.append(Parameter("split-cardinality"))
+ # @parallel join tasks require `num-parallel` and `task-id-entropy`
+ # to construct the input paths, so we pass them down as input parameters.
+ if self.graph[node.split_parents[-1]].parallel_foreach:
+ inputs.extend(
+ [Parameter("num-parallel"), Parameter("task-id-entropy")]
+ )
+ else:
+ # append this only for joins of foreaches, not static splits
+ inputs.append(Parameter("split-cardinality"))
+ # We can use an `elif` condition because the first `if` condition validates if its
+ # a foreach join node, hence we can safely assume that if that condition fails then
+ # we can check if the node is a @parallel node.
+ elif node.parallel_step:
+ inputs.extend(
+ [
+ Parameter("num-parallel"),
+ Parameter("task-id-entropy"),
+ Parameter("jobset-name"),
+ Parameter("workerCount"),
+ ]
+ )
+ if any(d.name == "retry" for d in node.decorators):
+ inputs.append(Parameter("retryCount"))
+
  if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
  if any(
  self.graph[parent].matching_join
@@ -1508,7 +1754,9 @@
  inputs.append(Parameter("root-input-path"))

  outputs = []
- if node.name != "end":
+ # @parallel steps will not have a task-id as an output parameter since task-ids
+ # are derived at runtime.
+ if not (node.name == "end" or node.parallel_step):
  outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]
  if node.type == "foreach":
  # Emit split cardinality from foreach task
@@ -1521,6 +1769,19 @@
  )
  )

+ if node.parallel_foreach:
+ outputs.extend(
+ [
+ Parameter("num-parallel").valueFrom(
+ {"path": "/mnt/out/num_parallel"}
+ ),
+ Parameter("task-id-entropy").valueFrom(
+ {"path": "/mnt/out/task_id_entropy"}
+ ),
+ ]
+ )
+ # Outputs should be defined over here, Not in the _dag_template for the `num_parallel` stuff.
+
  # It makes no sense to set env vars to None (shows up as "None" string)
  # Also we skip some env vars (e.g. in case we want to pull them from KUBERNETES_SECRETS)
  env = {
@@ -1550,6 +1811,156 @@
  # liked to inline this ContainerTemplate and avoid scanning the workflow
  # twice, but due to issues with variable substitution, we will have to
  # live with this routine.
+ if node.parallel_step:
+
+ # Explicitly add the task-id-hint label. This is important because this label
+ # is returned as an Output parameter of this step and is used subsequently an
+ # an input in the join step. Even the num_parallel is used as an output parameter
+ kubernetes_labels = self.kubernetes_labels.copy()
+ jobset_name = "{{inputs.parameters.jobset-name}}"
+ kubernetes_labels[
+ "task_id_entropy"
+ ] = "{{inputs.parameters.task-id-entropy}}"
+ kubernetes_labels["num_parallel"] = "{{inputs.parameters.num-parallel}}"
+ jobset = KubernetesArgoJobSet(
+ kubernetes_sdk=kubernetes_sdk,
+ name=jobset_name,
+ flow_name=self.flow.name,
+ run_id=run_id,
+ step_name=self._sanitize(node.name),
+ task_id=task_id,
+ attempt=retry_count,
+ user=self.username,
+ subdomain=jobset_name,
+ command=cmds,
+ namespace=resources["namespace"],
+ image=resources["image"],
+ image_pull_policy=resources["image_pull_policy"],
+ service_account=resources["service_account"],
+ secrets=(
+ [
+ k
+ for k in (
+ list(
+ []
+ if not resources.get("secrets")
+ else [resources.get("secrets")]
+ if isinstance(resources.get("secrets"), str)
+ else resources.get("secrets")
+ )
+ + KUBERNETES_SECRETS.split(",")
+ + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
+ )
+ if k
+ ]
+ ),
+ node_selector=resources.get("node_selector"),
+ cpu=str(resources["cpu"]),
+ memory=str(resources["memory"]),
+ disk=str(resources["disk"]),
+ gpu=resources["gpu"],
+ gpu_vendor=str(resources["gpu_vendor"]),
+ tolerations=resources["tolerations"],
+ use_tmpfs=use_tmpfs,
+ tmpfs_tempdir=tmpfs_tempdir,
+ tmpfs_size=tmpfs_size,
+ tmpfs_path=tmpfs_path,
+ timeout_in_seconds=run_time_limit,
+ persistent_volume_claims=resources["persistent_volume_claims"],
+ shared_memory=shared_memory,
+ port=port,
+ )
+
+ for k, v in env.items():
+ jobset.environment_variable(k, v)
+
+ for k, v in kubernetes_labels.items():
+ jobset.label(k, v)
+
+ ## -----Jobset specific env vars START here-----
+ jobset.environment_variable(
+ "MF_MASTER_ADDR", jobset.jobset_control_addr
+ )
+ jobset.environment_variable("MF_MASTER_PORT", str(port))
+ jobset.environment_variable(
+ "MF_WORLD_SIZE", "{{inputs.parameters.num-parallel}}"
+ )
+ # for k, v in .items():
+ jobset.environment_variables_from_selectors(
+ {
+ "MF_WORKER_REPLICA_INDEX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
+ "JOBSET_RESTART_ATTEMPT": "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']",
+ "METAFLOW_KUBERNETES_JOBSET_NAME": "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
+ "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
+ "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
+ "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
+ "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
+ "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
+ # `TASK_ID_SUFFIX` is needed for the construction of the task-ids
+ "TASK_ID_SUFFIX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
+ }
+ )
+ annotations = {
+ # setting annotations explicitly as they wont be
+ # passed down from WorkflowTemplate level
+ "metaflow/step_name": node.name,
+ "metaflow/attempt": str(retry_count),
+ "metaflow/run_id": run_id,
+ "metaflow/production_token": self.production_token,
+ "metaflow/owner": self.username,
+ "metaflow/user": "argo-workflows",
+ "metaflow/flow_name": self.flow.name,
+ }
+ if current.get("project_name"):
+ annotations.update(
+ {
+ "metaflow/project_name": current.project_name,
+ "metaflow/branch_name": current.branch_name,
+ "metaflow/project_flow_name": current.project_flow_name,
+ }
+ )
+ for k, v in annotations.items():
+ jobset.annotation(k, v)
+ ## -----Jobset specific env vars END here-----
+ ## ---- Jobset control/workers specific vars START here ----
+ jobset.control.replicas(1)
+ jobset.worker.replicas("{{=asInt(inputs.parameters.workerCount)}}")
+ jobset.control.environment_variable("UBF_CONTEXT", UBF_CONTROL)
+ jobset.worker.environment_variable("UBF_CONTEXT", UBF_TASK)
+ jobset.control.environment_variable("MF_CONTROL_INDEX", "0")
+ # `TASK_ID_PREFIX` needs to explicitly be `control` or `worker`
+ # because the join task uses a formulaic approach to infer the task-ids
+ jobset.control.environment_variable("TASK_ID_PREFIX", "control")
+ jobset.worker.environment_variable("TASK_ID_PREFIX", "worker")
+
+ ## ---- Jobset control/workers specific vars END here ----
+ yield (
+ Template(ArgoWorkflows._sanitize(node.name))
+ .resource(
+ "create",
+ jobset.dump(),
+ "status.terminalState == Completed",
+ "status.terminalState == Failed",
+ )
+ .inputs(Inputs().parameters(inputs))
+ .outputs(
+ Outputs().parameters(
+ [
+ Parameter("task-id-entropy").valueFrom(
+ {"jsonPath": "{.metadata.labels.task_id_entropy}"}
+ ),
+ Parameter("num-parallel").valueFrom(
+ {"jsonPath": "{.metadata.labels.num_parallel}"}
+ ),
+ ]
+ )
+ )
+ .retry_strategy(
+ times=total_retries,
+ minutes_between_retries=minutes_between_retries,
+ )
+ )
+ continue
  yield (
  Template(self._sanitize(node.name))
  # Set @timeout values
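The control/worker wiring above boils down to a simple gang layout: one control job plus workerCount worker jobs, a world size of num_parallel, and fixed TASK_ID_PREFIX values so the join can reconstruct task-ids. A hypothetical sketch of that bookkeeping (plain Python, not the KubernetesArgoJobSet API), reusing the UBF constants imported at the top of this file:

    from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK

    # Illustrative gang layout for a @parallel step with num_parallel tasks.
    def gang_layout(num_parallel):
        return {
            "control": {"replicas": 1,
                        "env": {"UBF_CONTEXT": UBF_CONTROL,
                                "MF_CONTROL_INDEX": "0",
                                "TASK_ID_PREFIX": "control"}},
            "worker": {"replicas": num_parallel - 1,  # the workerCount parameter
                       "env": {"UBF_CONTEXT": UBF_TASK,
                               "TASK_ID_PREFIX": "worker"}},
            "MF_WORLD_SIZE": num_parallel,
        }

    layout = gang_layout(4)
    assert layout["worker"]["replicas"] == 3
    assert layout["MF_WORLD_SIZE"] == 4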
@@ -1838,7 +2249,7 @@
  "fields": [
  {
  "type": "mrkdwn",
- "text": "*Project:* %s" % current.project_name
+ "text": "*Project:* %s" % current.project_name
  },
  {
  "type": "mrkdwn",
@@ -2612,6 +3023,15 @@ class Template(object):
  def to_json(self):
  return self.payload

+ def resource(self, action, manifest, success_criteria, failure_criteria):
+ self.payload["resource"] = {}
+ self.payload["resource"]["action"] = action
+ self.payload["setOwnerReference"] = True
+ self.payload["resource"]["successCondition"] = success_criteria
+ self.payload["resource"]["failureCondition"] = failure_criteria
+ self.payload["resource"]["manifest"] = manifest
+ return self
+
  def __str__(self):
  return json.dumps(self.payload, indent=4)
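For reference, the new Template.resource() builder shown above just stores a Kubernetes resource action in the template payload. A self-contained sketch of the payload it assembles for the @parallel step (the manifest string is a placeholder for jobset.dump()):

    import json

    # Standalone sketch of the payload assembled by Template.resource() above.
    payload = {"name": "my-parallel-step"}
    payload["resource"] = {
        "action": "create",
        "successCondition": "status.terminalState == Completed",
        "failureCondition": "status.terminalState == Failed",
        "manifest": "<serialized JobSet manifest>",  # placeholder for jobset.dump()
    }
    # setOwnerReference makes the created JobSet owned by (and cleaned up with) the workflow.
    payload["setOwnerReference"] = True
    print(json.dumps(payload, indent=4))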