ob-metaflow 2.17.1.0__py2.py3-none-any.whl → 2.18.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow might be problematic. Click here for more details.

Files changed (29):
  1. metaflow/cli_components/run_cmds.py +15 -0
  2. metaflow/datastore/task_datastore.py +3 -0
  3. metaflow/flowspec.py +91 -1
  4. metaflow/graph.py +154 -13
  5. metaflow/lint.py +94 -3
  6. metaflow/plugins/argo/argo_workflows.py +367 -11
  7. metaflow/plugins/argo/argo_workflows_decorator.py +9 -0
  8. metaflow/plugins/argo/conditional_input_paths.py +21 -0
  9. metaflow/plugins/aws/step_functions/step_functions.py +6 -0
  10. metaflow/plugins/cards/card_modules/basic.py +14 -2
  11. metaflow/plugins/cards/card_modules/main.css +1 -0
  12. metaflow/plugins/cards/card_modules/main.js +31 -31
  13. metaflow/plugins/catch_decorator.py +9 -0
  14. metaflow/plugins/package_cli.py +1 -1
  15. metaflow/plugins/parallel_decorator.py +7 -0
  16. metaflow/runtime.py +217 -34
  17. metaflow/task.py +129 -34
  18. metaflow/user_configs/config_parameters.py +3 -1
  19. metaflow/user_decorators/user_step_decorator.py +31 -6
  20. metaflow/version.py +1 -1
  21. {ob_metaflow-2.17.1.0.dist-info → ob_metaflow-2.18.0.1.dist-info}/METADATA +2 -2
  22. {ob_metaflow-2.17.1.0.dist-info → ob_metaflow-2.18.0.1.dist-info}/RECORD +29 -27
  23. {ob_metaflow-2.17.1.0.data → ob_metaflow-2.18.0.1.data}/data/share/metaflow/devtools/Makefile +0 -0
  24. {ob_metaflow-2.17.1.0.data → ob_metaflow-2.18.0.1.data}/data/share/metaflow/devtools/Tiltfile +0 -0
  25. {ob_metaflow-2.17.1.0.data → ob_metaflow-2.18.0.1.data}/data/share/metaflow/devtools/pick_services.sh +0 -0
  26. {ob_metaflow-2.17.1.0.dist-info → ob_metaflow-2.18.0.1.dist-info}/WHEEL +0 -0
  27. {ob_metaflow-2.17.1.0.dist-info → ob_metaflow-2.18.0.1.dist-info}/entry_points.txt +0 -0
  28. {ob_metaflow-2.17.1.0.dist-info → ob_metaflow-2.18.0.1.dist-info}/licenses/LICENSE +0 -0
  29. {ob_metaflow-2.17.1.0.dist-info → ob_metaflow-2.18.0.1.dist-info}/top_level.txt +0 -0
@@ -152,6 +152,7 @@ class ArgoWorkflows(object):
152
152
 
153
153
  self.name = name
154
154
  self.graph = graph
155
+ self._parse_conditional_branches()
155
156
  self.flow = flow
156
157
  self.code_package_metadata = code_package_metadata
157
158
  self.code_package_sha = code_package_sha
@@ -929,6 +930,131 @@ class ArgoWorkflows(object):
929
930
  )
930
931
  )
931
932
 
933
# Visit every node and record information on conditional step structure
def _parse_conditional_branches(self):
    """Walk the flow graph and record which steps are conditional.

    Populates:
      - self.conditional_nodes: steps that run only when an upstream
        split-switch selects their branch
      - self.conditional_join_nodes: steps that close conditional branches
      - self.matching_conditional_join_dict: split-switch name -> name of
        the step that closes it
      - self.recursive_nodes: split-switch steps that can transition back
        to themselves
    """
    self.conditional_nodes = set()
    self.conditional_join_nodes = set()
    self.matching_conditional_join_dict = {}
    self.recursive_nodes = set()

    # Per-step bookkeeping built during the first traversal:
    # step name -> enclosing split-switch step names (innermost last)
    node_conditional_parents = {}
    # step name -> step names on the conditional branch leading here
    node_conditional_branches = {}

    def _visit(node, seen, conditional_branch, conditional_parents=None):
        if not node.type == "split-switch" and not (
            conditional_branch and conditional_parents
        ):
            # skip regular non-conditional nodes entirely
            return

        if node.type == "split-switch":
            conditional_branch = conditional_branch + [node.name]
            node_conditional_branches[node.name] = conditional_branch

            conditional_parents = (
                [node.name]
                if not conditional_parents
                else conditional_parents + [node.name]
            )
            node_conditional_parents[node.name] = conditional_parents

            # check for recursion. this split is recursive if any of its out functions are itself.
            if any(
                out_func for out_func in node.out_funcs if out_func == node.name
            ):
                self.recursive_nodes.add(node.name)

        if conditional_parents and not node.type == "split-switch":
            # A non-switch step inside an open conditional branch inherits
            # the branch bookkeeping and is itself conditional.
            node_conditional_parents[node.name] = conditional_parents
            conditional_branch = conditional_branch + [node.name]
            node_conditional_branches[node.name] = conditional_branch

            self.conditional_nodes.add(node.name)

        if conditional_branch and conditional_parents:
            # Recurse into children not yet on this path; `seen` guards
            # against looping forever on recursive switches.
            for n in node.out_funcs:
                child = self.graph[n]
                if n not in seen:
                    _visit(
                        child, seen + [n], conditional_branch, conditional_parents
                    )

    # First we visit all nodes to determine conditional parents and branches
    for n in self.graph:
        _visit(n, [], [])

    # Then we traverse again in order to determine conditional join nodes, and matching conditional join info
    for node in self.graph:
        if node_conditional_parents.get(node.name, False):
            # do the required postprocessing for anything requiring node.in_funcs

            # check that in previous parsing we have not closed all conditional in_funcs.
            # If so, this step can not be conditional either
            is_conditional = any(
                in_func in self.conditional_nodes
                or self.graph[in_func].type == "split-switch"
                for in_func in node.in_funcs
            )
            if is_conditional:
                self.conditional_nodes.add(node.name)
            else:
                if node.name in self.conditional_nodes:
                    self.conditional_nodes.remove(node.name)

            # does this node close the latest conditional parent branches?
            conditional_in_funcs = [
                in_func
                for in_func in node.in_funcs
                if node_conditional_branches.get(in_func, False)
            ]
            closed_conditional_parents = []
            for last_split_switch in node_conditional_parents.get(node.name, [])[
                ::-1
            ]:
                last_conditional_split_nodes = self.graph[
                    last_split_switch
                ].out_funcs
                # p needs to be in at least one conditional_branch for it to be closed.
                if all(
                    any(
                        p in node_conditional_branches.get(in_func, [])
                        for in_func in conditional_in_funcs
                    )
                    for p in last_conditional_split_nodes
                ):
                    closed_conditional_parents.append(last_split_switch)

                    # NOTE(review): indentation reconstructed from a
                    # whitespace-mangled diff — this records `node` as the
                    # join closing `last_split_switch`; confirm placement
                    # against the released source.
                    self.conditional_join_nodes.add(node.name)
                    self.matching_conditional_join_dict[last_split_switch] = (
                        node.name
                    )

            # Did we close all conditionals? Then this branch and all its children are not conditional anymore (unless a new conditional branch is encountered).
            if not [
                p
                for p in node_conditional_parents.get(node.name, [])
                if p not in closed_conditional_parents
            ]:
                if node.name in self.conditional_nodes:
                    self.conditional_nodes.remove(node.name)
                node_conditional_parents[node.name] = []
                for p in node.out_funcs:
                    if p in self.conditional_nodes:
                        self.conditional_nodes.remove(p)
                    # NOTE(review): reset applies to every child `p`;
                    # nesting inferred from lost indentation — verify.
                    node_conditional_parents[p] = []
1045
+
1046
def _is_conditional_node(self, node):
    # True if `node` executes only when an upstream split-switch selects
    # its branch (as recorded by _parse_conditional_branches).
    return node.name in self.conditional_nodes

def _is_conditional_join_node(self, node):
    # True if `node` closes (joins) one or more conditional branches.
    return node.name in self.conditional_join_nodes

def _is_recursive_node(self, node):
    # True if `node` is a split-switch that can transition back to itself.
    return node.name in self.recursive_nodes

def _matching_conditional_join(self, node):
    # Name of the step that closes the conditional split started at
    # `node`, or None if `node` does not start a (closed) conditional split.
    return self.matching_conditional_join_dict.get(node.name, None)
1057
+
932
1058
  # Visit every node and yield the uber DAGTemplate(s).
933
1059
  def _dag_templates(self):
934
1060
  def _visit(
@@ -937,6 +1063,7 @@ class ArgoWorkflows(object):
937
1063
  templates=None,
938
1064
  dag_tasks=None,
939
1065
  parent_foreach=None,
1066
+ seen=None,
940
1067
  ): # Returns Tuple[List[Template], List[DAGTask]]
941
1068
  """ """
942
1069
  # Every for-each node results in a separate subDAG and an equivalent
@@ -946,18 +1073,28 @@ class ArgoWorkflows(object):
946
1073
  # of the for-each node.
947
1074
 
948
1075
  # Emit if we have reached the end of the sub workflow
1076
+ if seen is None:
1077
+ seen = []
949
1078
  if dag_tasks is None:
950
1079
  dag_tasks = []
951
1080
  if templates is None:
952
1081
  templates = []
1082
+
953
1083
  if exit_node is not None and exit_node is node.name:
954
1084
  return templates, dag_tasks
1085
+ if node.name in seen:
1086
+ return templates, dag_tasks
1087
+
1088
+ seen.append(node.name)
1089
+
1090
+ # helper variable for recursive conditional inputs
1091
+ has_foreach_inputs = False
955
1092
  if node.name == "start":
956
1093
  # Start node has no dependencies.
957
1094
  dag_task = DAGTask(self._sanitize(node.name)).template(
958
1095
  self._sanitize(node.name)
959
1096
  )
960
- elif (
1097
+ if (
961
1098
  node.is_inside_foreach
962
1099
  and self.graph[node.in_funcs[0]].type == "foreach"
963
1100
  and not self.graph[node.in_funcs[0]].parallel_foreach
@@ -965,9 +1102,10 @@ class ArgoWorkflows(object):
965
1102
  # vs what is a "num_parallel" based foreach (i.e. something that follows gang semantics.)
966
1103
  # A `regular` foreach is basically any arbitrary kind of foreach.
967
1104
  ):
1105
+ # helper variable for recursive conditional inputs
1106
+ has_foreach_inputs = True
968
1107
  # Child of a foreach node needs input-paths as well as split-index
969
1108
  # This child is the first node of the sub workflow and has no dependency
970
-
971
1109
  parameters = [
972
1110
  Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
973
1111
  Parameter("split-index").value("{{inputs.parameters.split-index}}"),
@@ -1091,15 +1229,43 @@ class ArgoWorkflows(object):
1091
1229
  ]
1092
1230
  )
1093
1231
 
1232
+ conditional_deps = [
1233
+ "%s.Succeeded" % self._sanitize(in_func)
1234
+ for in_func in node.in_funcs
1235
+ if self._is_conditional_node(self.graph[in_func])
1236
+ ]
1237
+ required_deps = [
1238
+ "%s.Succeeded" % self._sanitize(in_func)
1239
+ for in_func in node.in_funcs
1240
+ if not self._is_conditional_node(self.graph[in_func])
1241
+ ]
1242
+ both_conditions = required_deps and conditional_deps
1243
+
1244
+ depends_str = "{required}{_and}{conditional}".format(
1245
+ required=("(%s)" if both_conditions else "%s")
1246
+ % " && ".join(required_deps),
1247
+ _and=" && " if both_conditions else "",
1248
+ conditional=("(%s)" if both_conditions else "%s")
1249
+ % " || ".join(conditional_deps),
1250
+ )
1094
1251
  dag_task = (
1095
1252
  DAGTask(self._sanitize(node.name))
1096
- .dependencies(
1097
- [self._sanitize(in_func) for in_func in node.in_funcs]
1098
- )
1253
+ .depends(depends_str)
1099
1254
  .template(self._sanitize(node.name))
1100
1255
  .arguments(Arguments().parameters(parameters))
1101
1256
  )
1102
1257
 
1258
+ # Add conditional if this is the first step in a conditional branch
1259
+ if (
1260
+ self._is_conditional_node(node)
1261
+ and self.graph[node.in_funcs[0]].type == "split-switch"
1262
+ ):
1263
+ in_func = node.in_funcs[0]
1264
+ dag_task.when(
1265
+ "{{tasks.%s.outputs.parameters.switch-step}}==%s"
1266
+ % (self._sanitize(in_func), node.name)
1267
+ )
1268
+
1103
1269
  dag_tasks.append(dag_task)
1104
1270
  # End the workflow if we have reached the end of the flow
1105
1271
  if node.type == "end":
@@ -1117,6 +1283,7 @@ class ArgoWorkflows(object):
1117
1283
  templates,
1118
1284
  dag_tasks,
1119
1285
  parent_foreach,
1286
+ seen,
1120
1287
  )
1121
1288
  return _visit(
1122
1289
  self.graph[node.matching_join],
@@ -1124,6 +1291,119 @@ class ArgoWorkflows(object):
1124
1291
  templates,
1125
1292
  dag_tasks,
1126
1293
  parent_foreach,
1294
+ seen,
1295
+ )
1296
+ elif node.type == "split-switch":
1297
+ if self._is_recursive_node(node):
1298
+ # we need an additional recursive template if the step is recursive
1299
+ # NOTE: in the recursive case, the original step is renamed in the container templates to 'recursive-<step_name>'
1300
+ # so that we do not have to touch the step references in the DAG.
1301
+ #
1302
+ # NOTE: The way that recursion in Argo Workflows is achieved is with the following structure:
1303
+ # - the usual 'example-step' template which would match example_step in flow code is renamed to 'recursive-example-step'
1304
+ # - templates has another template with the original task name: 'example-step'
1305
+ # - the template 'example-step' in turn has steps
1306
+ # - 'example-step-internal' which uses the metaflow step executing template 'recursive-example-step'
1307
+ # - 'example-step-recursion' which calls the parent template 'example-step' if switch-step output from 'example-step-internal' matches the condition.
1308
+ sanitized_name = self._sanitize(node.name)
1309
+ templates.append(
1310
+ Template(sanitized_name)
1311
+ .steps(
1312
+ [
1313
+ WorkflowStep()
1314
+ .name("%s-internal" % sanitized_name)
1315
+ .template("recursive-%s" % sanitized_name)
1316
+ .arguments(
1317
+ Arguments().parameters(
1318
+ [
1319
+ Parameter("input-paths").value(
1320
+ "{{inputs.parameters.input-paths}}"
1321
+ )
1322
+ ]
1323
+ # Add the additional inputs required by specific node types.
1324
+ # We do not need to cover joins or @parallel, as a split-switch step can not be either one of these.
1325
+ + (
1326
+ [
1327
+ Parameter("split-index").value(
1328
+ "{{inputs.parameters.split-index}}"
1329
+ )
1330
+ ]
1331
+ if has_foreach_inputs
1332
+ else []
1333
+ )
1334
+ )
1335
+ )
1336
+ ]
1337
+ )
1338
+ .steps(
1339
+ [
1340
+ WorkflowStep()
1341
+ .name("%s-recursion" % sanitized_name)
1342
+ .template(sanitized_name)
1343
+ .when(
1344
+ "{{steps.%s-internal.outputs.parameters.switch-step}}==%s"
1345
+ % (sanitized_name, node.name)
1346
+ )
1347
+ .arguments(
1348
+ Arguments().parameters(
1349
+ [
1350
+ Parameter("input-paths").value(
1351
+ "argo-{{workflow.name}}/%s/{{steps.%s-internal.outputs.parameters.task-id}}"
1352
+ % (node.name, sanitized_name)
1353
+ )
1354
+ ]
1355
+ + (
1356
+ [
1357
+ Parameter("split-index").value(
1358
+ "{{inputs.parameters.split-index}}"
1359
+ )
1360
+ ]
1361
+ if has_foreach_inputs
1362
+ else []
1363
+ )
1364
+ )
1365
+ ),
1366
+ ]
1367
+ )
1368
+ .inputs(Inputs().parameters(parameters))
1369
+ .outputs(
1370
+ # NOTE: We try to read the output parameters from the recursive template call first (<step>-recursion), and the internal step second (<step>-internal).
1371
+ # This guarantees that we always get the output parameters of the last recursive step that executed.
1372
+ Outputs().parameters(
1373
+ [
1374
+ Parameter("task-id").valueFrom(
1375
+ {
1376
+ "expression": "(steps['%s-recursion']?.outputs ?? steps['%s-internal']?.outputs).parameters['task-id']"
1377
+ % (sanitized_name, sanitized_name)
1378
+ }
1379
+ ),
1380
+ Parameter("switch-step").valueFrom(
1381
+ {
1382
+ "expression": "(steps['%s-recursion']?.outputs ?? steps['%s-internal']?.outputs).parameters['switch-step']"
1383
+ % (sanitized_name, sanitized_name)
1384
+ }
1385
+ ),
1386
+ ]
1387
+ )
1388
+ )
1389
+ )
1390
+ for n in node.out_funcs:
1391
+ _visit(
1392
+ self.graph[n],
1393
+ self._matching_conditional_join(node),
1394
+ templates,
1395
+ dag_tasks,
1396
+ parent_foreach,
1397
+ seen,
1398
+ )
1399
+
1400
+ return _visit(
1401
+ self.graph[self._matching_conditional_join(node)],
1402
+ exit_node,
1403
+ templates,
1404
+ dag_tasks,
1405
+ parent_foreach,
1406
+ seen,
1127
1407
  )
1128
1408
  # For foreach nodes generate a new sub DAGTemplate
1129
1409
  # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
@@ -1152,7 +1432,7 @@ class ArgoWorkflows(object):
1152
1432
  #
1153
1433
  foreach_task = (
1154
1434
  DAGTask(foreach_template_name)
1155
- .dependencies([self._sanitize(node.name)])
1435
+ .depends(f"{self._sanitize(node.name)}.Succeeded")
1156
1436
  .template(foreach_template_name)
1157
1437
  .arguments(
1158
1438
  Arguments().parameters(
@@ -1197,6 +1477,16 @@ class ArgoWorkflows(object):
1197
1477
  % self._sanitize(node.name)
1198
1478
  )
1199
1479
  )
1480
+ # Add conditional if this is the first step in a conditional branch
1481
+ if self._is_conditional_node(node) and not any(
1482
+ self._is_conditional_node(self.graph[in_func])
1483
+ for in_func in node.in_funcs
1484
+ ):
1485
+ in_func = node.in_funcs[0]
1486
+ foreach_task.when(
1487
+ "{{tasks.%s.outputs.parameters.switch-step}}==%s"
1488
+ % (self._sanitize(in_func), node.name)
1489
+ )
1200
1490
  dag_tasks.append(foreach_task)
1201
1491
  templates, dag_tasks_1 = _visit(
1202
1492
  self.graph[node.out_funcs[0]],
@@ -1204,6 +1494,7 @@ class ArgoWorkflows(object):
1204
1494
  templates,
1205
1495
  [],
1206
1496
  node.name,
1497
+ seen,
1207
1498
  )
1208
1499
 
1209
1500
  # How do foreach's work on Argo:
@@ -1240,7 +1531,22 @@ class ArgoWorkflows(object):
1240
1531
  self.graph[node.matching_join].in_funcs[0]
1241
1532
  )
1242
1533
  }
1243
- )
1534
+ if not self._is_conditional_join_node(
1535
+ self.graph[node.matching_join]
1536
+ )
1537
+ else
1538
+ # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed.
1539
+ # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md
1540
+ {
1541
+ "expression": "get((%s)?.parameters, 'task-id')"
1542
+ % " ?? ".join(
1543
+ f"tasks['{self._sanitize(func)}']?.outputs"
1544
+ for func in self.graph[
1545
+ node.matching_join
1546
+ ].in_funcs
1547
+ )
1548
+ }
1549
+ ),
1244
1550
  ]
1245
1551
  if not node.parallel_foreach
1246
1552
  else [
@@ -1273,7 +1579,7 @@ class ArgoWorkflows(object):
1273
1579
  join_foreach_task = (
1274
1580
  DAGTask(self._sanitize(self.graph[node.matching_join].name))
1275
1581
  .template(self._sanitize(self.graph[node.matching_join].name))
1276
- .dependencies([foreach_template_name])
1582
+ .depends(f"{foreach_template_name}.Succeeded")
1277
1583
  .arguments(
1278
1584
  Arguments().parameters(
1279
1585
  (
@@ -1322,6 +1628,7 @@ class ArgoWorkflows(object):
1322
1628
  templates,
1323
1629
  dag_tasks,
1324
1630
  parent_foreach,
1631
+ seen,
1325
1632
  )
1326
1633
  # For linear nodes continue traversing to the next node
1327
1634
  if node.type in ("linear", "join", "start"):
@@ -1331,6 +1638,7 @@ class ArgoWorkflows(object):
1331
1638
  templates,
1332
1639
  dag_tasks,
1333
1640
  parent_foreach,
1641
+ seen,
1334
1642
  )
1335
1643
  else:
1336
1644
  raise ArgoWorkflowsException(
@@ -1400,6 +1708,14 @@ class ArgoWorkflows(object):
1400
1708
  input_paths_expr = (
1401
1709
  "export INPUT_PATHS={{inputs.parameters.input-paths}}"
1402
1710
  )
1711
+ if self._is_conditional_join_node(node):
1712
+ # NOTE: Argo template expressions that fail to resolve, output the expression itself as a value.
1713
+ # With conditional steps, some of the input-paths are therefore 'broken' due to containing a nil expression
1714
+ # e.g. "{{ tasks['A'].outputs.parameters.task-id }}" when task A never executed.
1715
+ # We base64 encode the input-paths in order to not pollute the execution environment with templating expressions.
1716
+ # NOTE: Adding conditionals that check if a key exists or not does not work either, due to an issue with how Argo
1717
+ # handles tasks in a nested foreach (withParam template) leading to all such expressions getting evaluated as false.
1718
+ input_paths_expr = "export INPUT_PATHS={{=toBase64(inputs.parameters['input-paths'])}}"
1403
1719
  input_paths = "$(echo $INPUT_PATHS)"
1404
1720
  if any(self.graph[n].type == "foreach" for n in node.in_funcs):
1405
1721
  task_idx = "{{inputs.parameters.split-index}}"
@@ -1415,7 +1731,6 @@ class ArgoWorkflows(object):
1415
1731
  # foreaches
1416
1732
  task_idx = "{{inputs.parameters.split-index}}"
1417
1733
  root_input = "{{inputs.parameters.root-input-path}}"
1418
-
1419
1734
  # Task string to be hashed into an ID
1420
1735
  task_str = "-".join(
1421
1736
  [
@@ -1572,10 +1887,27 @@ class ArgoWorkflows(object):
1572
1887
  ]
1573
1888
  )
1574
1889
  input_paths = "%s/_parameters/%s" % (run_id, task_id_params)
1890
+ # Only for static joins and conditional_joins
1891
+ elif self._is_conditional_join_node(node) and not (
1892
+ node.type == "join"
1893
+ and self.graph[node.split_parents[-1]].type == "foreach"
1894
+ ):
1895
+ input_paths = (
1896
+ "$(python -m metaflow.plugins.argo.conditional_input_paths %s)"
1897
+ % input_paths
1898
+ )
1575
1899
  elif (
1576
1900
  node.type == "join"
1577
1901
  and self.graph[node.split_parents[-1]].type == "foreach"
1578
1902
  ):
1903
+ # foreach-joins straight out of conditional branches are not yet supported
1904
+ if self._is_conditional_join_node(node):
1905
+ raise ArgoWorkflowsException(
1906
+ "Conditional steps inside a foreach that transition directly into a join step are not currently supported.\n"
1907
+ "As a workaround, add a common step after the conditional steps %s "
1908
+ "that will transition to a join."
1909
+ % ", ".join("*%s*" % f for f in node.in_funcs)
1910
+ )
1579
1911
  # Set aggregated input-paths for a for-each join
1580
1912
  foreach_step = next(
1581
1913
  n for n in node.in_funcs if self.graph[n].is_inside_foreach
@@ -1818,7 +2150,7 @@ class ArgoWorkflows(object):
1818
2150
  [Parameter("num-parallel"), Parameter("task-id-entropy")]
1819
2151
  )
1820
2152
  else:
1821
- # append this only for joins of foreaches, not static splits
2153
+ # append these only for joins of foreaches, not static splits
1822
2154
  inputs.append(Parameter("split-cardinality"))
1823
2155
  # check if the node is a @parallel node.
1824
2156
  elif node.parallel_step:
@@ -1853,6 +2185,13 @@ class ArgoWorkflows(object):
1853
2185
  # are derived at runtime.
1854
2186
  if not (node.name == "end" or node.parallel_step):
1855
2187
  outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]
2188
+
2189
+ # If this step is a split-switch one, we need to output the switch step name
2190
+ if node.type == "split-switch":
2191
+ outputs.append(
2192
+ Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"})
2193
+ )
2194
+
1856
2195
  if node.type == "foreach":
1857
2196
  # Emit split cardinality from foreach task
1858
2197
  outputs.append(
@@ -2091,8 +2430,13 @@ class ArgoWorkflows(object):
2091
2430
  )
2092
2431
  )
2093
2432
  else:
2433
+ template_name = self._sanitize(node.name)
2434
+ if self._is_recursive_node(node):
2435
+ # The recursive template has the original step name,
2436
+ # this becomes a template within the recursive ones 'steps'
2437
+ template_name = self._sanitize("recursive-%s" % node.name)
2094
2438
  yield (
2095
- Template(self._sanitize(node.name))
2439
+ Template(template_name)
2096
2440
  # Set @timeout values
2097
2441
  .active_deadline_seconds(run_time_limit)
2098
2442
  # Set service account
@@ -3585,6 +3929,10 @@ class WorkflowStep(object):
3585
3929
  self.payload["template"] = str(template)
3586
3930
  return self
3587
3931
 
3932
def arguments(self, arguments):
    """Attach an Arguments object to this workflow step; returns self for chaining."""
    # Serialize once, then store on the underlying payload dict.
    serialized = arguments.to_json()
    self.payload["arguments"] = serialized
    # Fluent builder interface.
    return self
3935
+
3588
3936
  def when(self, condition):
3589
3937
  self.payload["when"] = str(condition)
3590
3938
  return self
@@ -4027,6 +4375,10 @@ class DAGTask(object):
4027
4375
  self.payload["dependencies"] = dependencies
4028
4376
  return self
4029
4377
 
4378
def depends(self, depends: str):
    """Set the Argo `depends` expression (e.g. "a.Succeeded && b.Succeeded")."""
    # Stored verbatim; mutually exclusive with the `dependencies` list form.
    self.payload["depends"] = depends
    return self
4381
+
4030
4382
  def template(self, template):
4031
4383
  # Template reference
4032
4384
  self.payload["template"] = template
@@ -4038,6 +4390,10 @@ class DAGTask(object):
4038
4390
  self.payload["inline"] = template.to_json()
4039
4391
  return self
4040
4392
 
4393
def when(self, when: str):
    """Set the Argo `when` condition gating execution of this DAG task."""
    # Stored verbatim; Argo evaluates the expression at runtime.
    self.payload["when"] = when
    return self
4396
+
4041
4397
  def with_param(self, with_param):
4042
4398
  self.payload["withParam"] = with_param
4043
4399
  return self
@@ -123,6 +123,15 @@ class ArgoWorkflowsInternalDecorator(StepDecorator):
123
123
  with open("/mnt/out/split_cardinality", "w") as file:
124
124
  json.dump(flow._foreach_num_splits, file)
125
125
 
126
+ # For conditional branches we need to record the value of the switch to disk, in order to pass it as an
127
+ # output from the switching step to be used further down the DAG
128
+ if graph[step_name].type == "split-switch":
129
+ # TODO: A nicer way to access the chosen step?
130
+ _out_funcs, _ = flow._transition
131
+ chosen_step = _out_funcs[0]
132
+ with open("/mnt/out/switch_step", "w") as file:
133
+ file.write(chosen_step)
134
+
126
135
  # For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets
127
136
  # to run the task. In this case, we cannot set anything in the
128
137
  # `/mnt/out` directory, since such form of output mounts are not available to Jobset executions.
@@ -0,0 +1,21 @@
1
import base64
import sys
from math import inf

from metaflow.util import decompress_list, compress_list


def generate_input_paths(input_paths):
    """Decode base64-encoded input paths and drop entries from branches
    that never executed.

    Argo template expressions that fail to resolve are emitted verbatim,
    so any path still containing a "{{" templating marker belongs to a
    conditional branch that did not run; such paths are stripped.

    Parameters
    ----------
    input_paths : str
        Base64-encoded, compressed list of input paths
        (=> run_id/step/:foo,bar). Base64 is used due to Argo shenanigans.

    Returns
    -------
    str
        Compressed list of the surviving paths. zlibmin=inf keeps the
        output as a plain joined string (never zlib-compressed).
    """
    decoded = base64.b64decode(input_paths).decode("utf-8")
    paths = decompress_list(decoded)

    # some of the paths are going to be malformed due to never having
    # executed per conditional. strip these out of the list.
    trimmed = [path for path in paths if "{{" not in path]
    return compress_list(trimmed, zlibmin=inf)


if __name__ == "__main__":
    print(generate_input_paths(sys.argv[1]))
@@ -317,6 +317,12 @@ class StepFunctions(object):
317
317
  "to AWS Step Functions is not supported currently."
318
318
  )
319
319
 
320
+ if node.type == "split-switch":
321
+ raise StepFunctionsException(
322
+ "Deploying flows with switch statement "
323
+ "to AWS Step Functions is not supported currently."
324
+ )
325
+
320
326
  # Assign an AWS Batch job to the AWS Step Functions state
321
327
  # and pass the intermediate state by exposing `JobId` and
322
328
  # `Parameters` to the child job(s) as outputs. `Index` and
@@ -20,12 +20,15 @@ def transform_flow_graph(step_info):
20
20
  return "split"
21
21
  elif node_type == "split-parallel" or node_type == "split-foreach":
22
22
  return "foreach"
23
+ elif node_type == "split-switch":
24
+ return "switch"
23
25
  return "unknown" # Should never happen
24
26
 
25
27
  graph_dict = {}
26
28
  for stepname in step_info:
27
- graph_dict[stepname] = {
28
- "type": node_to_type(step_info[stepname]["type"]),
29
+ node_type = node_to_type(step_info[stepname]["type"])
30
+ node_info = {
31
+ "type": node_type,
29
32
  "box_next": step_info[stepname]["type"] not in ("linear", "join"),
30
33
  "box_ends": (
31
34
  None
@@ -35,6 +38,15 @@ def transform_flow_graph(step_info):
35
38
  "next": step_info[stepname]["next"],
36
39
  "doc": step_info[stepname]["doc"],
37
40
  }
41
+
42
+ if node_type == "switch":
43
+ if "condition" in step_info[stepname]:
44
+ node_info["condition"] = step_info[stepname]["condition"]
45
+ if "switch_cases" in step_info[stepname]:
46
+ node_info["switch_cases"] = step_info[stepname]["switch_cases"]
47
+
48
+ graph_dict[stepname] = node_info
49
+
38
50
  return graph_dict
39
51
 
40
52