metaflow 2.17.2__py2.py3-none-any.whl → 2.17.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -222,6 +222,9 @@ class TaskDataStore(object):
222
222
@property
def pathspec_index(self):
    """Pathspec of this task annotated with its foreach (and iteration) indices.

    The result is ``<run_id>/<step_name>[<foreach indices>]``. The foreach
    indices are the ``index`` of every frame in the ``_foreach_stack``
    artifact, comma-separated. When the datastore also holds an
    ``_iteration_stack`` artifact, its entries are appended as a second
    comma-separated bracketed group: ``...[<foreach indices>][<iterations>]``.
    """
    foreach_indices = ",".join(
        str(frame.index) for frame in self["_foreach_stack"]
    )
    if "_iteration_stack" not in self:
        return "%s/%s[%s]" % (self._run_id, self._step_name, foreach_indices)
    iteration_indices = ",".join(str(item) for item in self["_iteration_stack"])
    return "%s/%s[%s][%s]" % (
        self._run_id,
        self._step_name,
        foreach_indices,
        iteration_indices,
    )
226
229
 
227
230
  @property
metaflow/graph.py CHANGED
@@ -478,7 +478,9 @@ class FlowGraph(object):
478
478
  cur_name = cur_node.matching_join
479
479
  elif node_type == "split-switch":
480
480
  all_paths = [
481
- populate_block(s, end_name) for s in cur_node.out_funcs
481
+ populate_block(s, end_name)
482
+ for s in cur_node.out_funcs
483
+ if s != cur_name
482
484
  ]
483
485
  resulting_list.append(all_paths)
484
486
  cur_name = end_name
metaflow/lint.py CHANGED
@@ -175,6 +175,8 @@ def check_for_acyclicity(graph):
175
175
 
176
176
  def check_path(node, seen):
177
177
  for n in node.out_funcs:
178
+ if node.type == "split-switch" and n == node.name:
179
+ continue
178
180
  if n in seen:
179
181
  path = "->".join(seen + [n])
180
182
  raise LintWarn(
@@ -241,6 +243,8 @@ def check_split_join_balance(graph):
241
243
  elif node.type == "split-switch":
242
244
  # For a switch, continue traversal down each path with the same stack
243
245
  for n in node.out_funcs:
246
+ if node.type == "split-switch" and n == node.name:
247
+ continue
244
248
  traverse(graph[n], split_stack)
245
249
  return
246
250
  elif node.type == "end":
@@ -293,6 +297,8 @@ def check_split_join_balance(graph):
293
297
  new_stack = split_stack
294
298
 
295
299
  for n in node.out_funcs:
300
+ if node.type == "split-switch" and n == node.name:
301
+ continue
296
302
  traverse(graph[n], new_stack)
297
303
 
298
304
  traverse(graph["start"], [])
@@ -143,6 +143,7 @@ class ArgoWorkflows(object):
143
143
 
144
144
  self.name = name
145
145
  self.graph = graph
146
+ self._parse_conditional_branches()
146
147
  self.flow = flow
147
148
  self.code_package_metadata = code_package_metadata
148
149
  self.code_package_sha = code_package_sha
@@ -920,6 +921,121 @@ class ArgoWorkflows(object):
920
921
  )
921
922
  )
922
923
 
924
# Visit every node and record information on conditional step structure
def _parse_conditional_branches(self):
    """Classify graph steps with respect to ``split-switch`` (conditional) steps.

    Populates three attributes on ``self`` from ``self.graph``:

    * ``conditional_nodes`` -- names of non-switch steps reached from a
      ``split-switch`` whose enclosing switches were not all closed by the
      time the second pass processed them.
    * ``conditional_join_nodes`` -- names of steps found to close at least
      one enclosing ``split-switch``.
    * ``matching_conditional_join_dict`` -- maps a ``split-switch`` step name
      to the name of the step recorded as closing it.

    Pass 1 walks the graph from every switch step and records, per step, the
    stack of enclosing switch names (``node_conditional_parents``) and the
    list of step names traversed since the outermost switch
    (``node_conditional_branches``). Pass 2 uses each step's ``in_funcs`` to
    detect where those branches converge again.
    """
    self.conditional_nodes = set()
    self.conditional_join_nodes = set()
    self.matching_conditional_join_dict = {}

    # step name -> list of enclosing split-switch step names (outermost first)
    node_conditional_parents = {}
    # step name -> list of step names traversed from the outermost switch
    node_conditional_branches = {}

    def _visit(node, seen, conditional_branch, conditional_parents=None):
        # Record switch-scope information for `node`, then recurse into its
        # children while still inside a conditional scope.
        if not node.type == "split-switch" and not (
            conditional_branch and conditional_parents
        ):
            # skip regular non-conditional nodes entirely
            return

        if node.type == "split-switch":
            # A switch opens a new conditional scope: push it onto both stacks.
            conditional_branch = conditional_branch + [node.name]
            node_conditional_branches[node.name] = conditional_branch

            conditional_parents = (
                [node.name]
                if not conditional_parents
                else conditional_parents + [node.name]
            )
            node_conditional_parents[node.name] = conditional_parents

        if conditional_parents and not node.type == "split-switch":
            # Ordinary step below at least one switch: tentatively conditional
            # (pass 2 may remove it again once the switch is closed).
            node_conditional_parents[node.name] = conditional_parents
            conditional_branch = conditional_branch + [node.name]
            node_conditional_branches[node.name] = conditional_branch

            self.conditional_nodes.add(node.name)

        if conditional_branch and conditional_parents:
            for n in node.out_funcs:
                child = self.graph[n]
                # `seen` only holds the names on the current path, so a step
                # reachable by several paths is visited once per path; later
                # visits overwrite the dict entries written by earlier ones.
                if n not in seen:
                    _visit(
                        child, seen + [n], conditional_branch, conditional_parents
                    )

    # First we visit all nodes to determine conditional parents and branches
    # Iterating self.graph yields node objects (`_visit` reads `.type`/`.name`);
    # with an empty branch only split-switch nodes start a traversal.
    # NOTE(review): the starting node is not put in `seen`, so a switch that
    # lists itself in out_funcs is re-entered once with its name pushed twice
    # on both stacks -- confirm whether self-looping switches can reach this.
    for n in self.graph:
        _visit(n, [], [])

    # Then we traverse again in order to determine conditional join nodes, and matching conditional join info
    # NOTE(review): this pass mutates state that later iterations read
    # (conditional_nodes, node_conditional_parents of out_funcs), so its result
    # depends on the order self.graph yields nodes -- presumably topological;
    # confirm.
    for node in self.graph:
        if node_conditional_parents.get(node.name, False):
            # do the required postprocessing for anything requiring node.in_funcs

            # check that in previous parsing we have not closed all conditional in_funcs.
            # If so, this step can not be conditional either
            is_conditional = any(
                in_func in self.conditional_nodes
                or self.graph[in_func].type == "split-switch"
                for in_func in node.in_funcs
            )
            if is_conditional:
                self.conditional_nodes.add(node.name)
            else:
                if node.name in self.conditional_nodes:
                    self.conditional_nodes.remove(node.name)

            # does this node close the latest conditional parent branches?
            # Only predecessors that lie on some recorded conditional branch count.
            conditional_in_funcs = [
                in_func
                for in_func in node.in_funcs
                if node_conditional_branches.get(in_func, False)
            ]
            closed_conditional_parents = []
            # Check enclosing switches innermost-first.
            for last_split_switch in node_conditional_parents.get(node.name, [])[
                ::-1
            ]:
                last_conditional_split_nodes = self.graph[
                    last_split_switch
                ].out_funcs
                # p needs to be in at least one conditional_branch for it to be closed.
                # i.e. every target of the switch must appear on the branch of
                # some conditional predecessor of this node.
                if all(
                    any(
                        p in node_conditional_branches.get(in_func, [])
                        for in_func in conditional_in_funcs
                    )
                    for p in last_conditional_split_nodes
                ):
                    closed_conditional_parents.append(last_split_switch)

                    self.conditional_join_nodes.add(node.name)
                    self.matching_conditional_join_dict[last_split_switch] = (
                        node.name
                    )

            # Did we close all conditionals? Then this branch and all its children are not conditional anymore (unless a new conditional branch is encountered).
            if not [
                p
                for p in node_conditional_parents.get(node.name, [])
                if p not in closed_conditional_parents
            ]:
                if node.name in self.conditional_nodes:
                    self.conditional_nodes.remove(node.name)
                node_conditional_parents[node.name] = []
                # Clearing the direct children's parents makes this loop skip
                # them when they come up, since their lookup is then falsy.
                for p in node.out_funcs:
                    if p in self.conditional_nodes:
                        self.conditional_nodes.remove(p)
                    node_conditional_parents[p] = []
1029
+
1030
+ def _is_conditional_node(self, node):
1031
+ return node.name in self.conditional_nodes
1032
+
1033
+ def _is_conditional_join_node(self, node):
1034
+ return node.name in self.conditional_join_nodes
1035
+
1036
+ def _matching_conditional_join(self, node):
1037
+ return self.matching_conditional_join_dict.get(node.name, None)
1038
+
923
1039
  # Visit every node and yield the uber DAGTemplate(s).
924
1040
  def _dag_templates(self):
925
1041
  def _visit(
@@ -941,6 +1057,7 @@ class ArgoWorkflows(object):
941
1057
  dag_tasks = []
942
1058
  if templates is None:
943
1059
  templates = []
1060
+
944
1061
  if exit_node is not None and exit_node is node.name:
945
1062
  return templates, dag_tasks
946
1063
  if node.name == "start":
@@ -948,12 +1065,7 @@ class ArgoWorkflows(object):
948
1065
  dag_task = DAGTask(self._sanitize(node.name)).template(
949
1066
  self._sanitize(node.name)
950
1067
  )
951
- if node.type == "split-switch":
952
- raise ArgoWorkflowsException(
953
- "Deploying flows with switch statement "
954
- "to Argo Workflows is not supported currently."
955
- )
956
- elif (
1068
+ if (
957
1069
  node.is_inside_foreach
958
1070
  and self.graph[node.in_funcs[0]].type == "foreach"
959
1071
  and not self.graph[node.in_funcs[0]].parallel_foreach
@@ -1087,15 +1199,43 @@ class ArgoWorkflows(object):
1087
1199
  ]
1088
1200
  )
1089
1201
 
1202
+ conditional_deps = [
1203
+ "%s.Succeeded" % self._sanitize(in_func)
1204
+ for in_func in node.in_funcs
1205
+ if self._is_conditional_node(self.graph[in_func])
1206
+ ]
1207
+ required_deps = [
1208
+ "%s.Succeeded" % self._sanitize(in_func)
1209
+ for in_func in node.in_funcs
1210
+ if not self._is_conditional_node(self.graph[in_func])
1211
+ ]
1212
+ both_conditions = required_deps and conditional_deps
1213
+
1214
+ depends_str = "{required}{_and}{conditional}".format(
1215
+ required=("(%s)" if both_conditions else "%s")
1216
+ % " && ".join(required_deps),
1217
+ _and=" && " if both_conditions else "",
1218
+ conditional=("(%s)" if both_conditions else "%s")
1219
+ % " || ".join(conditional_deps),
1220
+ )
1090
1221
  dag_task = (
1091
1222
  DAGTask(self._sanitize(node.name))
1092
- .dependencies(
1093
- [self._sanitize(in_func) for in_func in node.in_funcs]
1094
- )
1223
+ .depends(depends_str)
1095
1224
  .template(self._sanitize(node.name))
1096
1225
  .arguments(Arguments().parameters(parameters))
1097
1226
  )
1098
1227
 
1228
+ # Add conditional if this is the first step in a conditional branch
1229
+ if (
1230
+ self._is_conditional_node(node)
1231
+ and self.graph[node.in_funcs[0]].type == "split-switch"
1232
+ ):
1233
+ in_func = node.in_funcs[0]
1234
+ dag_task.when(
1235
+ "{{tasks.%s.outputs.parameters.switch-step}}==%s"
1236
+ % (self._sanitize(in_func), node.name)
1237
+ )
1238
+
1099
1239
  dag_tasks.append(dag_task)
1100
1240
  # End the workflow if we have reached the end of the flow
1101
1241
  if node.type == "end":
@@ -1121,6 +1261,23 @@ class ArgoWorkflows(object):
1121
1261
  dag_tasks,
1122
1262
  parent_foreach,
1123
1263
  )
1264
+ elif node.type == "split-switch":
1265
+ for n in node.out_funcs:
1266
+ _visit(
1267
+ self.graph[n],
1268
+ self._matching_conditional_join(node),
1269
+ templates,
1270
+ dag_tasks,
1271
+ parent_foreach,
1272
+ )
1273
+
1274
+ return _visit(
1275
+ self.graph[self._matching_conditional_join(node)],
1276
+ exit_node,
1277
+ templates,
1278
+ dag_tasks,
1279
+ parent_foreach,
1280
+ )
1124
1281
  # For foreach nodes generate a new sub DAGTemplate
1125
1282
  # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
1126
1283
  elif node.type == "foreach":
@@ -1148,7 +1305,7 @@ class ArgoWorkflows(object):
1148
1305
  #
1149
1306
  foreach_task = (
1150
1307
  DAGTask(foreach_template_name)
1151
- .dependencies([self._sanitize(node.name)])
1308
+ .depends(f"{self._sanitize(node.name)}.Succeeded")
1152
1309
  .template(foreach_template_name)
1153
1310
  .arguments(
1154
1311
  Arguments().parameters(
@@ -1193,6 +1350,16 @@ class ArgoWorkflows(object):
1193
1350
  % self._sanitize(node.name)
1194
1351
  )
1195
1352
  )
1353
+ # Add conditional if this is the first step in a conditional branch
1354
+ if self._is_conditional_node(node) and not any(
1355
+ self._is_conditional_node(self.graph[in_func])
1356
+ for in_func in node.in_funcs
1357
+ ):
1358
+ in_func = node.in_funcs[0]
1359
+ foreach_task.when(
1360
+ "{{tasks.%s.outputs.parameters.switch-step}}==%s"
1361
+ % (self._sanitize(in_func), node.name)
1362
+ )
1196
1363
  dag_tasks.append(foreach_task)
1197
1364
  templates, dag_tasks_1 = _visit(
1198
1365
  self.graph[node.out_funcs[0]],
@@ -1236,7 +1403,22 @@ class ArgoWorkflows(object):
1236
1403
  self.graph[node.matching_join].in_funcs[0]
1237
1404
  )
1238
1405
  }
1239
- )
1406
+ if not self._is_conditional_join_node(
1407
+ self.graph[node.matching_join]
1408
+ )
1409
+ else
1410
+ # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed.
1411
+ # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md
1412
+ {
1413
+ "expression": "get((%s)?.parameters, 'task-id')"
1414
+ % " ?? ".join(
1415
+ f"tasks['{self._sanitize(func)}']?.outputs"
1416
+ for func in self.graph[
1417
+ node.matching_join
1418
+ ].in_funcs
1419
+ )
1420
+ }
1421
+ ),
1240
1422
  ]
1241
1423
  if not node.parallel_foreach
1242
1424
  else [
@@ -1269,7 +1451,7 @@ class ArgoWorkflows(object):
1269
1451
  join_foreach_task = (
1270
1452
  DAGTask(self._sanitize(self.graph[node.matching_join].name))
1271
1453
  .template(self._sanitize(self.graph[node.matching_join].name))
1272
- .dependencies([foreach_template_name])
1454
+ .depends(f"{foreach_template_name}.Succeeded")
1273
1455
  .arguments(
1274
1456
  Arguments().parameters(
1275
1457
  (
@@ -1396,6 +1578,14 @@ class ArgoWorkflows(object):
1396
1578
  input_paths_expr = (
1397
1579
  "export INPUT_PATHS={{inputs.parameters.input-paths}}"
1398
1580
  )
1581
+ if self._is_conditional_join_node(node):
1582
+ # NOTE: Argo template expressions that fail to resolve, output the expression itself as a value.
1583
+ # With conditional steps, some of the input-paths are therefore 'broken' due to containing a nil expression
1584
+ # e.g. "{{ tasks['A'].outputs.parameters.task-id }}" when task A never executed.
1585
+ # We base64 encode the input-paths in order to not pollute the execution environment with templating expressions.
1586
+ # NOTE: Adding conditionals that check if a key exists or not does not work either, due to an issue with how Argo
1587
+ # handles tasks in a nested foreach (withParam template) leading to all such expressions getting evaluated as false.
1588
+ input_paths_expr = "export INPUT_PATHS={{=toBase64(inputs.parameters['input-paths'])}}"
1399
1589
  input_paths = "$(echo $INPUT_PATHS)"
1400
1590
  if any(self.graph[n].type == "foreach" for n in node.in_funcs):
1401
1591
  task_idx = "{{inputs.parameters.split-index}}"
@@ -1411,7 +1601,6 @@ class ArgoWorkflows(object):
1411
1601
  # foreaches
1412
1602
  task_idx = "{{inputs.parameters.split-index}}"
1413
1603
  root_input = "{{inputs.parameters.root-input-path}}"
1414
-
1415
1604
  # Task string to be hashed into an ID
1416
1605
  task_str = "-".join(
1417
1606
  [
@@ -1568,10 +1757,25 @@ class ArgoWorkflows(object):
1568
1757
  ]
1569
1758
  )
1570
1759
  input_paths = "%s/_parameters/%s" % (run_id, task_id_params)
1760
+ # Only for static joins and conditional_joins
1761
+ elif self._is_conditional_join_node(node) and not (
1762
+ node.type == "join"
1763
+ and self.graph[node.split_parents[-1]].type == "foreach"
1764
+ ):
1765
+ input_paths = (
1766
+ "$(python -m metaflow.plugins.argo.conditional_input_paths %s)"
1767
+ % input_paths
1768
+ )
1571
1769
  elif (
1572
1770
  node.type == "join"
1573
1771
  and self.graph[node.split_parents[-1]].type == "foreach"
1574
1772
  ):
1773
+ # foreach-joins straight out of conditional branches are not yet supported
1774
+ if self._is_conditional_join_node(node):
1775
+ raise ArgoWorkflowsException(
1776
+ "Conditionals steps that transition directly into a join step are not currently supported. "
1777
+ "As a workaround, you can add a normal step after the conditional steps that transitions to a join step."
1778
+ )
1575
1779
  # Set aggregated input-paths for a for-each join
1576
1780
  foreach_step = next(
1577
1781
  n for n in node.in_funcs if self.graph[n].is_inside_foreach
@@ -1814,7 +2018,7 @@ class ArgoWorkflows(object):
1814
2018
  [Parameter("num-parallel"), Parameter("task-id-entropy")]
1815
2019
  )
1816
2020
  else:
1817
- # append this only for joins of foreaches, not static splits
2021
+ # append these only for joins of foreaches, not static splits
1818
2022
  inputs.append(Parameter("split-cardinality"))
1819
2023
  # check if the node is a @parallel node.
1820
2024
  elif node.parallel_step:
@@ -1849,6 +2053,13 @@ class ArgoWorkflows(object):
1849
2053
  # are derived at runtime.
1850
2054
  if not (node.name == "end" or node.parallel_step):
1851
2055
  outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]
2056
+
2057
+ # If this step is a split-switch one, we need to output the switch step name
2058
+ if node.type == "split-switch":
2059
+ outputs.append(
2060
+ Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"})
2061
+ )
2062
+
1852
2063
  if node.type == "foreach":
1853
2064
  # Emit split cardinality from foreach task
1854
2065
  outputs.append(
@@ -3981,6 +4192,10 @@ class DAGTask(object):
3981
4192
  self.payload["dependencies"] = dependencies
3982
4193
  return self
3983
4194
 
4195
def depends(self, depends: str):
    """Set the task's Argo ``depends`` expression and return the task.

    The string is stored verbatim under the ``"depends"`` payload key,
    e.g. ``"start.Succeeded"``; returning ``self`` allows call chaining.
    """
    payload = self.payload
    payload["depends"] = depends
    return self
4198
+
3984
4199
  def template(self, template):
3985
4200
  # Template reference
3986
4201
  self.payload["template"] = template
@@ -3992,6 +4207,10 @@ class DAGTask(object):
3992
4207
  self.payload["inline"] = template.to_json()
3993
4208
  return self
3994
4209
 
4210
def when(self, when: str):
    """Set the task's Argo ``when`` condition and return the task.

    The expression is stored verbatim under the ``"when"`` payload key;
    returning ``self`` allows call chaining.
    """
    payload = self.payload
    payload["when"] = when
    return self
4213
+
3995
4214
def with_param(self, with_param):
    """Set the task's ``withParam`` payload entry and return the task for chaining."""
    payload = self.payload
    payload["withParam"] = with_param
    return self
@@ -123,6 +123,15 @@ class ArgoWorkflowsInternalDecorator(StepDecorator):
123
123
  with open("/mnt/out/split_cardinality", "w") as file:
124
124
  json.dump(flow._foreach_num_splits, file)
125
125
 
126
+ # For conditional branches we need to record the value of the switch to disk, in order to pass it as an
127
+ # output from the switching step to be used further down the DAG
128
+ if graph[step_name].type == "split-switch":
129
+ # TODO: A nicer way to access the chosen step?
130
+ _out_funcs, _ = flow._transition
131
+ chosen_step = _out_funcs[0]
132
+ with open("/mnt/out/switch_step", "w") as file:
133
+ file.write(chosen_step)
134
+
126
135
  # For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets
127
136
  # to run the task. In this case, we cannot set anything in the
128
137
  # `/mnt/out` directory, since such form of output mounts are not available to Jobset executions.
@@ -0,0 +1,21 @@
1
+ from math import inf
2
+ import sys
3
+ from metaflow.util import decompress_list, compress_list
4
+ import base64
5
+
6
+
7
def generate_input_paths(input_paths):
    """Decode Argo-supplied input paths and drop those of tasks that never ran.

    Used for conditional join steps. Argo hands the step its ``input-paths``
    parameter base64-encoded. Paths that refer to a conditional branch that
    did not execute still contain the unresolved ``{{...}}`` template
    expression.

    Parameters
    ----------
    input_paths : str
        Base64 encoding of a UTF-8 path list in ``compress_list`` format,
        e.g. ``run_id/step/:foo,bar`` once decoded.

    Returns
    -------
    str
        The surviving paths, re-encoded with ``compress_list``.
    """
    # input_paths are base64 encoded due to Argo shenanigans
    decoded = base64.b64decode(input_paths).decode("utf-8")
    paths = decompress_list(decoded)

    # Paths of tasks that never executed (branch not taken) keep their literal,
    # unresolved template expression; strip these out of the list.
    trimmed = [path for path in paths if "{{" not in path]
    # zlibmin=inf presumably disables zlib compression of the result so the
    # output stays a readable list -- confirm against metaflow.util.compress_list.
    return compress_list(trimmed, zlibmin=inf)


if __name__ == "__main__":
    # Invoked as: python -m metaflow.plugins.argo.conditional_input_paths <b64 paths>
    if len(sys.argv) < 2:
        sys.exit(
            "usage: python -m metaflow.plugins.argo.conditional_input_paths "
            "<base64-encoded input paths>"
        )
    print(generate_input_paths(sys.argv[1]))