metaflow 2.12.8__py2.py3-none-any.whl → 2.12.10__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. metaflow/__init__.py +2 -0
  2. metaflow/cli.py +12 -4
  3. metaflow/extension_support/plugins.py +1 -0
  4. metaflow/flowspec.py +8 -1
  5. metaflow/lint.py +13 -0
  6. metaflow/metaflow_current.py +0 -8
  7. metaflow/plugins/__init__.py +12 -0
  8. metaflow/plugins/argo/argo_workflows.py +616 -46
  9. metaflow/plugins/argo/argo_workflows_cli.py +70 -3
  10. metaflow/plugins/argo/argo_workflows_decorator.py +38 -7
  11. metaflow/plugins/argo/argo_workflows_deployer.py +290 -0
  12. metaflow/plugins/argo/daemon.py +59 -0
  13. metaflow/plugins/argo/jobset_input_paths.py +16 -0
  14. metaflow/plugins/aws/batch/batch_decorator.py +16 -13
  15. metaflow/plugins/aws/step_functions/step_functions_cli.py +45 -3
  16. metaflow/plugins/aws/step_functions/step_functions_deployer.py +251 -0
  17. metaflow/plugins/cards/card_cli.py +1 -1
  18. metaflow/plugins/kubernetes/kubernetes.py +279 -52
  19. metaflow/plugins/kubernetes/kubernetes_cli.py +26 -8
  20. metaflow/plugins/kubernetes/kubernetes_client.py +0 -1
  21. metaflow/plugins/kubernetes/kubernetes_decorator.py +56 -44
  22. metaflow/plugins/kubernetes/kubernetes_job.py +7 -6
  23. metaflow/plugins/kubernetes/kubernetes_jobsets.py +511 -272
  24. metaflow/plugins/parallel_decorator.py +108 -8
  25. metaflow/plugins/secrets/secrets_decorator.py +12 -3
  26. metaflow/plugins/test_unbounded_foreach_decorator.py +39 -4
  27. metaflow/runner/deployer.py +386 -0
  28. metaflow/runner/metaflow_runner.py +1 -20
  29. metaflow/runner/nbdeploy.py +130 -0
  30. metaflow/runner/nbrun.py +4 -28
  31. metaflow/runner/utils.py +49 -0
  32. metaflow/runtime.py +246 -134
  33. metaflow/version.py +1 -1
  34. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/METADATA +2 -2
  35. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/RECORD +39 -32
  36. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/WHEEL +1 -1
  37. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/LICENSE +0 -0
  38. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/entry_points.txt +0 -0
  39. {metaflow-2.12.8.dist-info → metaflow-2.12.10.dist-info}/top_level.txt +0 -0
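
Note: the hunks below all appear to come from metaflow/plugins/argo/argo_workflows.py (the largest change in this release, +616 -46). The bulk of the additions teach the Argo Workflows compiler to run @parallel / num_parallel foreach steps as Kubernetes JobSets and to attach a heartbeat daemon sidecar. As a rough, hypothetical illustration of the flow shape this targets (flow and step names are invented here, and it assumes the parallel decorator is importable from the top-level metaflow package like other step decorators):

# Illustrative sketch only (not part of the diff): a gang-scheduled foreach of the
# kind the new Argo support compiles into a Kubernetes JobSet.
from metaflow import FlowSpec, parallel, step


class GangScheduledFlow(FlowSpec):
    @step
    def start(self):
        # num_parallel launches a fixed-size gang of workers for the next step.
        self.next(self.train, num_parallel=4)

    @parallel
    @step
    def train(self):
        # Each worker runs as one replica of the JobSet when deployed to Argo.
        self.next(self.join)

    @step
    def join(self, inputs):
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    GangScheduledFlow()
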
@@ -4,11 +4,13 @@ import os
  import re
  import shlex
  import sys
+ from typing import Tuple, List
  from collections import defaultdict
  from hashlib import sha1
  from math import inf

  from metaflow import JSONType, current
+ from metaflow.graph import DAGNode
  from metaflow.decorators import flow_decorators
  from metaflow.exception import MetaflowException
  from metaflow.includefile import FilePathClass
@@ -47,6 +49,7 @@ from metaflow.metaflow_config import (
  SERVICE_INTERNAL_URL,
  UI_URL,
  )
+ from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
  from metaflow.metaflow_config_funcs import config_values
  from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
  from metaflow.parameters import deploy_time_eval
@@ -54,6 +57,7 @@ from metaflow.plugins.kubernetes.kubernetes import (
  parse_kube_keyvalue_list,
  validate_kube_labels,
  )
+ from metaflow.graph import FlowGraph
  from metaflow.util import (
  compress_list,
  dict_to_cli_options,
@@ -61,6 +65,9 @@ from metaflow.util import (
  to_camelcase,
  to_unicode,
  )
+ from metaflow.plugins.kubernetes.kubernetes_jobsets import (
+ KubernetesArgoJobSet,
+ )

  from .argo_client import ArgoClient

@@ -82,14 +89,14 @@ class ArgoWorkflowsSchedulingException(MetaflowException):
  # 5. Add Metaflow tags to labels/annotations.
  # 6. Support Multi-cluster scheduling - https://github.com/argoproj/argo-workflows/issues/3523#issuecomment-792307297
  # 7. Support R lang.
- # 8. Ping @savin at slack.outerbounds.co for any feature request.
+ # 8. Ping @savin at slack.outerbounds.co for any feature request


  class ArgoWorkflows(object):
  def __init__(
  self,
  name,
- graph,
+ graph: FlowGraph,
  flow,
  code_package_sha,
  code_package_url,
@@ -110,6 +117,7 @@ class ArgoWorkflows(object):
  notify_on_success=False,
  notify_slack_webhook_url=None,
  notify_pager_duty_integration_key=None,
+ enable_heartbeat_daemon=True,
  ):
  # Some high-level notes -
  #
@@ -157,6 +165,7 @@ class ArgoWorkflows(object):
  self.notify_on_success = notify_on_success
  self.notify_slack_webhook_url = notify_slack_webhook_url
  self.notify_pager_duty_integration_key = notify_pager_duty_integration_key
+ self.enable_heartbeat_daemon = enable_heartbeat_daemon

  self.parameters = self._process_parameters()
  self.triggers, self.trigger_options = self._process_triggers()
@@ -846,19 +855,21 @@ class ArgoWorkflows(object):
  .templates(self._container_templates())
  # Exit hook template(s)
  .templates(self._exit_hook_templates())
+ # Sidecar templates (Daemon Containers)
+ .templates(self._daemon_templates())
  )
  )

  # Visit every node and yield the uber DAGTemplate(s).
  def _dag_templates(self):
  def _visit(
- node, exit_node=None, templates=None, dag_tasks=None, parent_foreach=None
- ):
- if node.parallel_foreach:
- raise ArgoWorkflowsException(
- "Deploying flows with @parallel decorator(s) "
- "as Argo Workflows is not supported currently."
- )
+ node,
+ exit_node=None,
+ templates=None,
+ dag_tasks=None,
+ parent_foreach=None,
+ ): # Returns Tuple[List[Template], List[DAGTask]]
+ """ """
  # Every for-each node results in a separate subDAG and an equivalent
  # DAGTemplate rooted at the child of the for-each node. Each DAGTemplate
  # has a unique name - the top-level DAGTemplate is named as the name of
@@ -872,7 +883,6 @@ class ArgoWorkflows(object):
  templates = []
  if exit_node is not None and exit_node is node.name:
  return templates, dag_tasks
-
  if node.name == "start":
  # Start node has no dependencies.
  dag_task = DAGTask(self._sanitize(node.name)).template(
@@ -881,13 +891,86 @@ class ArgoWorkflows(object):
  elif (
  node.is_inside_foreach
  and self.graph[node.in_funcs[0]].type == "foreach"
+ and not self.graph[node.in_funcs[0]].parallel_foreach
+ # We need to distinguish what is a "regular" foreach (i.e something that doesn't care about to gang semantics)
+ # vs what is a "num_parallel" based foreach (i.e. something that follows gang semantics.)
+ # A `regular` foreach is basically any arbitrary kind of foreach.
  ):
  # Child of a foreach node needs input-paths as well as split-index
  # This child is the first node of the sub workflow and has no dependency
+
+ parameters = [
+ Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
+ Parameter("split-index").value("{{inputs.parameters.split-index}}"),
+ ]
+ dag_task = (
+ DAGTask(self._sanitize(node.name))
+ .template(self._sanitize(node.name))
+ .arguments(Arguments().parameters(parameters))
+ )
+ elif node.parallel_step:
+ # This is the step where the @parallel decorator is defined.
+ # Since this DAGTask will call the for the `resource` [based templates]
+ # (https://argo-workflows.readthedocs.io/en/stable/walk-through/kubernetes-resources/)
+ # we have certain constraints on the way we can pass information inside the Jobset manifest
+ # [All templates will have access](https://argo-workflows.readthedocs.io/en/stable/variables/#all-templates)
+ # to the `inputs.parameters` so we will pass down ANY/ALL information using the
+ # input parameters.
+ # We define the usual parameters like input-paths/split-index etc. but we will also
+ # define the following:
+ # - `workerCount`: parameter which will be used to determine the number of
+ # parallel worker jobs
+ # - `jobset-name`: parameter which will be used to determine the name of the jobset.
+ # This parameter needs to be dynamic so that when we have retries we don't
+ # end up using the name of the jobset again (if we do, it will crash since k8s wont allow duplicated job names)
+ # - `retryCount`: parameter which will be used to determine the number of retries
+ # This parameter will *only* be available within the container templates like we
+ # have it for all other DAGTasks and NOT for custom kubernetes resource templates.
+ # So as a work-around, we will set it as the `retryCount` parameter instead of
+ # setting it as a {{ retries }} in the CLI code. Once set as a input parameter,
+ # we can use it in the Jobset Manifest templates as `{{inputs.parameters.retryCount}}`
+ # - `task-id-entropy`: This is a parameter which will help derive task-ids and jobset names. This parameter
+ # contains the relevant amount of entropy to ensure that task-ids and jobset names
+ # are uniquish. We will also use this in the join task to construct the task-ids of
+ # all parallel tasks since the task-ids for parallel task are minted formulaically.
  parameters = [
  Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
+ Parameter("num-parallel").value(
+ "{{inputs.parameters.num-parallel}}"
+ ),
  Parameter("split-index").value("{{inputs.parameters.split-index}}"),
+ Parameter("task-id-entropy").value(
+ "{{inputs.parameters.task-id-entropy}}"
+ ),
+ # we cant just use hyphens with sprig.
+ # https://github.com/argoproj/argo-workflows/issues/10567#issuecomment-1452410948
+ Parameter("workerCount").value(
+ "{{=sprig.int(sprig.sub(sprig.int(inputs.parameters['num-parallel']),1))}}"
+ ),
  ]
+ if any(d.name == "retry" for d in node.decorators):
+ parameters.extend(
+ [
+ Parameter("retryCount").value("{{retries}}"),
+ # The job-setname needs to be unique for each retry
+ # and we cannot use the `generateName` field in the
+ # Jobset Manifest since we need to construct the subdomain
+ # and control pod domain name pre-hand. So we will use
+ # the retry count to ensure that the jobset name is unique
+ Parameter("jobset-name").value(
+ "js-{{inputs.parameters.task-id-entropy}}{{retries}}",
+ ),
+ ]
+ )
+ else:
+ parameters.extend(
+ [
+ Parameter("jobset-name").value(
+ "js-{{inputs.parameters.task-id-entropy}}",
+ )
+ ]
+ )
+
  dag_task = (
  DAGTask(self._sanitize(node.name))
  .template(self._sanitize(node.name))
@@ -947,8 +1030,8 @@ class ArgoWorkflows(object):
  .template(self._sanitize(node.name))
  .arguments(Arguments().parameters(parameters))
  )
- dag_tasks.append(dag_task)

+ dag_tasks.append(dag_task)
  # End the workflow if we have reached the end of the flow
  if node.type == "end":
  return [
@@ -974,14 +1057,30 @@ class ArgoWorkflows(object):
  parent_foreach,
  )
  # For foreach nodes generate a new sub DAGTemplate
+ # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
  elif node.type == "foreach":
  foreach_template_name = self._sanitize(
  "%s-foreach-%s"
  % (
  node.name,
- node.foreach_param,
+ "parallel" if node.parallel_foreach else node.foreach_param
+ # Since foreach's are derived based on `self.next(self.a, foreach="<varname>")`
+ # vs @parallel foreach are done based on `self.next(self.a, num_parallel="<some-number>")`,
+ # we need to ensure that `foreach_template_name` suffix is appropriately set based on the kind
+ # of foreach.
  )
  )
+
+ # There are two separate "DAGTask"s created for the foreach node.
+ # - The first one is a "jump-off" DAGTask where we propagate the
+ # input-paths and split-index. This thing doesn't create
+ # any actual containers and it responsible for only propagating
+ # the parameters.
+ # - The DAGTask that follows first DAGTask is the one
+ # that uses the ContainerTemplate. This DAGTask is named the same
+ # thing as the foreach node. We will leverage a similar pattern for the
+ # @parallel tasks.
+ #
  foreach_task = (
  DAGTask(foreach_template_name)
  .dependencies([self._sanitize(node.name)])
@@ -1005,9 +1104,26 @@ class ArgoWorkflows(object):
  if parent_foreach
  else []
  )
+ + (
+ # Disabiguate parameters for a regular `foreach` vs a `@parallel` foreach
+ [
+ Parameter("num-parallel").value(
+ "{{tasks.%s.outputs.parameters.num-parallel}}"
+ % self._sanitize(node.name)
+ ),
+ Parameter("task-id-entropy").value(
+ "{{tasks.%s.outputs.parameters.task-id-entropy}}"
+ % self._sanitize(node.name)
+ ),
+ ]
+ if node.parallel_foreach
+ else []
+ )
  )
  )
  .with_param(
+ # For @parallel workloads `num-splits` will be explicitly set to one so that
+ # we can piggyback on the current mechanism with which we leverage argo.
  "{{tasks.%s.outputs.parameters.num-splits}}"
  % self._sanitize(node.name)
  )
@@ -1020,17 +1136,34 @@ class ArgoWorkflows(object):
  [],
  node.name,
  )
+
+ # How do foreach's work on Argo:
+ # Lets say you have the following dag: (start[sets `foreach="x"`]) --> (task-a [actual foreach]) --> (join) --> (end)
+ # With argo we will :
+ # (start [sets num-splits]) --> (task-a-foreach-(0,0) [dummy task]) --> (task-a) --> (join) --> (end)
+ # The (task-a-foreach-(0,0) [dummy task]) propagates the values of the `split-index` and the input paths.
+ # to the actual foreach task.
  templates.append(
  Template(foreach_template_name)
  .inputs(
  Inputs().parameters(
  [Parameter("input-paths"), Parameter("split-index")]
  + ([Parameter("root-input-path")] if parent_foreach else [])
+ + (
+ [
+ Parameter("num-parallel"),
+ Parameter("task-id-entropy"),
+ # Parameter("workerCount")
+ ]
+ if node.parallel_foreach
+ else []
+ )
  )
  )
  .outputs(
  Outputs().parameters(
  [
+ # non @parallel tasks set task-ids as outputs
  Parameter("task-id").valueFrom(
  {
  "parameter": "{{tasks.%s.outputs.parameters.task-id}}"
@@ -1040,29 +1173,67 @@ class ArgoWorkflows(object):
  }
  )
  ]
+ if not node.parallel_foreach
+ else [
+ # @parallel tasks set `task-id-entropy` and `num-parallel`
+ # as outputs so task-ids can be derived in the join step.
+ # Both of these values should be propagated from the
+ # jobset labels.
+ Parameter("num-parallel").valueFrom(
+ {
+ "parameter": "{{tasks.%s.outputs.parameters.num-parallel}}"
+ % self._sanitize(
+ self.graph[node.matching_join].in_funcs[0]
+ )
+ }
+ ),
+ Parameter("task-id-entropy").valueFrom(
+ {
+ "parameter": "{{tasks.%s.outputs.parameters.task-id-entropy}}"
+ % self._sanitize(
+ self.graph[node.matching_join].in_funcs[0]
+ )
+ }
+ ),
+ ]
  )
  )
  .dag(DAGTemplate().fail_fast().tasks(dag_tasks_1))
  )
+
  join_foreach_task = (
  DAGTask(self._sanitize(self.graph[node.matching_join].name))
  .template(self._sanitize(self.graph[node.matching_join].name))
  .dependencies([foreach_template_name])
  .arguments(
  Arguments().parameters(
- [
- Parameter("input-paths").value(
- "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
- % (node.name, self._sanitize(node.name))
- ),
- Parameter("split-cardinality").value(
- "{{tasks.%s.outputs.parameters.split-cardinality}}"
- % self._sanitize(node.name)
- ),
- ]
+ (
+ [
+ Parameter("input-paths").value(
+ "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
+ % (node.name, self._sanitize(node.name))
+ ),
+ Parameter("split-cardinality").value(
+ "{{tasks.%s.outputs.parameters.split-cardinality}}"
+ % self._sanitize(node.name)
+ ),
+ ]
+ if not node.parallel_foreach
+ else [
+ Parameter("num-parallel").value(
+ "{{tasks.%s.outputs.parameters.num-parallel}}"
+ % self._sanitize(node.name)
+ ),
+ Parameter("task-id-entropy").value(
+ "{{tasks.%s.outputs.parameters.task-id-entropy}}"
+ % self._sanitize(node.name)
+ ),
+ ]
+ )
  + (
  [
  Parameter("split-index").value(
+ # TODO : Pass down these parameters to the jobset stuff.
  "{{inputs.parameters.split-index}}"
  ),
  Parameter("root-input-path").value(
@@ -1098,7 +1269,13 @@ class ArgoWorkflows(object):
  "Argo Workflows." % (node.type, node.name)
  )

- templates, _ = _visit(node=self.graph["start"])
+ # Generate daemon tasks
+ daemon_tasks = [
+ DAGTask("%s-task" % daemon_template.name).template(daemon_template.name)
+ for daemon_template in self._daemon_templates()
+ ]
+
+ templates, _ = _visit(node=self.graph["start"], dag_tasks=daemon_tasks)
  return templates

  # Visit every node and yield ContainerTemplates.
@@ -1140,7 +1317,17 @@ class ArgoWorkflows(object):
  # export input_paths as it is used multiple times in the container script
  # and we do not want to repeat the values.
  input_paths_expr = "export INPUT_PATHS=''"
- if node.name != "start":
+ # If node is not a start step or a @parallel join then we will set the input paths.
+ # To set the input-paths as a parameter, we need to ensure that the node
+ # is not (a start node or a parallel join node). Start nodes will have no
+ # input paths and parallel join will derive input paths based on a
+ # formulaic approach using `num-parallel` and `task-id-entropy`.
+ if not (
+ node.name == "start"
+ or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step)
+ ):
+ # For parallel joins we don't pass the INPUT_PATHS but are dynamically constructed.
+ # So we don't need to set the input paths.
  input_paths_expr = (
  "export INPUT_PATHS={{inputs.parameters.input-paths}}"
  )
@@ -1169,13 +1356,23 @@ class ArgoWorkflows(object):
  task_idx,
  ]
  )
+ if node.parallel_step:
+ task_str = "-".join(
+ [
+ "$TASK_ID_PREFIX",
+ "{{inputs.parameters.task-id-entropy}}", # id_base is addition entropy to based on node-name of the workflow
+ "$TASK_ID_SUFFIX",
+ ]
+ )
+ else:
+ # Generated task_ids need to be non-numeric - see register_task_id in
+ # service.py. We do so by prefixing `t-`
+ _task_id_base = (
+ "$(echo %s | md5sum | cut -d ' ' -f 1 | tail -c 9)" % task_str
+ )
+ task_str = "(t-%s)" % _task_id_base

- # Generated task_ids need to be non-numeric - see register_task_id in
- # service.py. We do so by prefixing `t-`
- task_id_expr = (
- "export METAFLOW_TASK_ID="
- "(t-$(echo %s | md5sum | cut -d ' ' -f 1 | tail -c 9))" % task_str
- )
+ task_id_expr = "export METAFLOW_TASK_ID=" "%s" % task_str
  task_id = "$METAFLOW_TASK_ID"

  # Resolve retry strategy.
@@ -1194,9 +1391,20 @@ class ArgoWorkflows(object):
  user_code_retries = max_user_code_retries
  total_retries = max_user_code_retries + max_error_retries
  # {{retries}} is only available if retryStrategy is specified
+ # and they are only available in the container templates NOT for custom
+ # Kubernetes manifests like Jobsets.
+ # For custom kubernetes manifests, we will pass the retryCount as a parameter
+ # and use that in the manifest.
  retry_count = (
- "{{retries}}" if max_user_code_retries + max_error_retries else 0
+ (
+ "{{retries}}"
+ if not node.parallel_step
+ else "{{inputs.parameters.retryCount}}"
+ )
+ if total_retries
+ else 0
  )
+
  minutes_between_retries = int(minutes_between_retries)

  # Configure log capture.
@@ -1302,13 +1510,24 @@ class ArgoWorkflows(object):
  foreach_step = next(
  n for n in node.in_funcs if self.graph[n].is_inside_foreach
  )
- input_paths = (
- "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})"
- % (
- foreach_step,
- input_paths,
+ if not self.graph[node.split_parents[-1]].parallel_foreach:
+ input_paths = (
+ "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})"
+ % (
+ foreach_step,
+ input_paths,
+ )
+ )
+ else:
+ # When we run Jobsets with Argo Workflows we need to ensure that `input_paths` are generated using the a formulaic approach
+ # because our current strategy of using volume mounts for outputs won't work with Jobsets
+ input_paths = (
+ "$(python -m metaflow.plugins.argo.jobset_input_paths %s %s {{inputs.parameters.task-id-entropy}} {{inputs.parameters.num-parallel}})"
+ % (
+ run_id,
+ foreach_step,
+ )
  )
- )
  step = [
  "step",
  node.name,
@@ -1318,7 +1537,14 @@ class ArgoWorkflows(object):
  "--max-user-code-retries %d" % user_code_retries,
  "--input-paths %s" % input_paths,
  ]
- if any(self.graph[n].type == "foreach" for n in node.in_funcs):
+ if node.parallel_step:
+ step.append(
+ "--split-index ${MF_CONTROL_INDEX:-$((MF_WORKER_REPLICA_INDEX + 1))}"
+ )
+ # This is needed for setting the value of the UBF context in the CLI.
+ step.append("--ubf-context $UBF_CONTEXT")
+
+ elif any(self.graph[n].type == "foreach" for n in node.in_funcs):
  # Pass split-index to a foreach task
  step.append("--split-index {{inputs.parameters.split-index}}")
  if self.tags:
@@ -1481,17 +1707,47 @@ class ArgoWorkflows(object):
  # join task deterministically inside the join task without resorting to
  # passing a rather long list of (albiet compressed)
  inputs = []
- if node.name != "start":
+ # To set the input-paths as a parameter, we need to ensure that the node
+ # is not (a start node or a parallel join node). Start nodes will have no
+ # input paths and parallel join will derive input paths based on a
+ # formulaic approach.
+ if not (
+ node.name == "start"
+ or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step)
+ ):
  inputs.append(Parameter("input-paths"))
  if any(self.graph[n].type == "foreach" for n in node.in_funcs):
  # Fetch split-index from parent
  inputs.append(Parameter("split-index"))
+
  if (
  node.type == "join"
  and self.graph[node.split_parents[-1]].type == "foreach"
  ):
- # append this only for joins of foreaches, not static splits
- inputs.append(Parameter("split-cardinality"))
+ # @parallel join tasks require `num-parallel` and `task-id-entropy`
+ # to construct the input paths, so we pass them down as input parameters.
+ if self.graph[node.split_parents[-1]].parallel_foreach:
+ inputs.extend(
+ [Parameter("num-parallel"), Parameter("task-id-entropy")]
+ )
+ else:
+ # append this only for joins of foreaches, not static splits
+ inputs.append(Parameter("split-cardinality"))
+ # We can use an `elif` condition because the first `if` condition validates if its
+ # a foreach join node, hence we can safely assume that if that condition fails then
+ # we can check if the node is a @parallel node.
+ elif node.parallel_step:
+ inputs.extend(
+ [
+ Parameter("num-parallel"),
+ Parameter("task-id-entropy"),
+ Parameter("jobset-name"),
+ Parameter("workerCount"),
+ ]
+ )
+ if any(d.name == "retry" for d in node.decorators):
+ inputs.append(Parameter("retryCount"))
+
  if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
  if any(
  self.graph[parent].matching_join
@@ -1508,7 +1764,9 @@ class ArgoWorkflows(object):
  inputs.append(Parameter("root-input-path"))

  outputs = []
- if node.name != "end":
+ # @parallel steps will not have a task-id as an output parameter since task-ids
+ # are derived at runtime.
+ if not (node.name == "end" or node.parallel_step):
  outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]
  if node.type == "foreach":
  # Emit split cardinality from foreach task
@@ -1521,6 +1779,19 @@ class ArgoWorkflows(object):
  )
  )

+ if node.parallel_foreach:
+ outputs.extend(
+ [
+ Parameter("num-parallel").valueFrom(
+ {"path": "/mnt/out/num_parallel"}
+ ),
+ Parameter("task-id-entropy").valueFrom(
+ {"path": "/mnt/out/task_id_entropy"}
+ ),
+ ]
+ )
+ # Outputs should be defined over here, Not in the _dag_template for the `num_parallel` stuff.
+
  # It makes no sense to set env vars to None (shows up as "None" string)
  # Also we skip some env vars (e.g. in case we want to pull them from KUBERNETES_SECRETS)
  env = {
@@ -1550,6 +1821,156 @@ class ArgoWorkflows(object):
  # liked to inline this ContainerTemplate and avoid scanning the workflow
  # twice, but due to issues with variable substitution, we will have to
  # live with this routine.
+ if node.parallel_step:
+
+ # Explicitly add the task-id-hint label. This is important because this label
+ # is returned as an Output parameter of this step and is used subsequently an
+ # an input in the join step. Even the num_parallel is used as an output parameter
+ kubernetes_labels = self.kubernetes_labels.copy()
+ jobset_name = "{{inputs.parameters.jobset-name}}"
+ kubernetes_labels[
+ "task_id_entropy"
+ ] = "{{inputs.parameters.task-id-entropy}}"
+ kubernetes_labels["num_parallel"] = "{{inputs.parameters.num-parallel}}"
+ jobset = KubernetesArgoJobSet(
+ kubernetes_sdk=kubernetes_sdk,
+ name=jobset_name,
+ flow_name=self.flow.name,
+ run_id=run_id,
+ step_name=self._sanitize(node.name),
+ task_id=task_id,
+ attempt=retry_count,
+ user=self.username,
+ subdomain=jobset_name,
+ command=cmds,
+ namespace=resources["namespace"],
+ image=resources["image"],
+ image_pull_policy=resources["image_pull_policy"],
+ service_account=resources["service_account"],
+ secrets=(
+ [
+ k
+ for k in (
+ list(
+ []
+ if not resources.get("secrets")
+ else [resources.get("secrets")]
+ if isinstance(resources.get("secrets"), str)
+ else resources.get("secrets")
+ )
+ + KUBERNETES_SECRETS.split(",")
+ + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
+ )
+ if k
+ ]
+ ),
+ node_selector=resources.get("node_selector"),
+ cpu=str(resources["cpu"]),
+ memory=str(resources["memory"]),
+ disk=str(resources["disk"]),
+ gpu=resources["gpu"],
+ gpu_vendor=str(resources["gpu_vendor"]),
+ tolerations=resources["tolerations"],
+ use_tmpfs=use_tmpfs,
+ tmpfs_tempdir=tmpfs_tempdir,
+ tmpfs_size=tmpfs_size,
+ tmpfs_path=tmpfs_path,
+ timeout_in_seconds=run_time_limit,
+ persistent_volume_claims=resources["persistent_volume_claims"],
+ shared_memory=shared_memory,
+ port=port,
+ )
+
+ for k, v in env.items():
+ jobset.environment_variable(k, v)
+
+ for k, v in kubernetes_labels.items():
+ jobset.label(k, v)
+
+ ## -----Jobset specific env vars START here-----
+ jobset.environment_variable(
+ "MF_MASTER_ADDR", jobset.jobset_control_addr
+ )
+ jobset.environment_variable("MF_MASTER_PORT", str(port))
+ jobset.environment_variable(
+ "MF_WORLD_SIZE", "{{inputs.parameters.num-parallel}}"
+ )
+ # for k, v in .items():
+ jobset.environment_variables_from_selectors(
+ {
+ "MF_WORKER_REPLICA_INDEX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
+ "JOBSET_RESTART_ATTEMPT": "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']",
+ "METAFLOW_KUBERNETES_JOBSET_NAME": "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
+ "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
+ "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
+ "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
+ "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
+ "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
+ # `TASK_ID_SUFFIX` is needed for the construction of the task-ids
+ "TASK_ID_SUFFIX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
+ }
+ )
+ annotations = {
+ # setting annotations explicitly as they wont be
+ # passed down from WorkflowTemplate level
+ "metaflow/step_name": node.name,
+ "metaflow/attempt": str(retry_count),
+ "metaflow/run_id": run_id,
+ "metaflow/production_token": self.production_token,
+ "metaflow/owner": self.username,
+ "metaflow/user": "argo-workflows",
+ "metaflow/flow_name": self.flow.name,
+ }
+ if current.get("project_name"):
+ annotations.update(
+ {
+ "metaflow/project_name": current.project_name,
+ "metaflow/branch_name": current.branch_name,
+ "metaflow/project_flow_name": current.project_flow_name,
+ }
+ )
+ for k, v in annotations.items():
+ jobset.annotation(k, v)
+ ## -----Jobset specific env vars END here-----
+ ## ---- Jobset control/workers specific vars START here ----
+ jobset.control.replicas(1)
+ jobset.worker.replicas("{{=asInt(inputs.parameters.workerCount)}}")
+ jobset.control.environment_variable("UBF_CONTEXT", UBF_CONTROL)
+ jobset.worker.environment_variable("UBF_CONTEXT", UBF_TASK)
+ jobset.control.environment_variable("MF_CONTROL_INDEX", "0")
+ # `TASK_ID_PREFIX` needs to explicitly be `control` or `worker`
+ # because the join task uses a formulaic approach to infer the task-ids
+ jobset.control.environment_variable("TASK_ID_PREFIX", "control")
+ jobset.worker.environment_variable("TASK_ID_PREFIX", "worker")
+
+ ## ---- Jobset control/workers specific vars END here ----
+ yield (
+ Template(ArgoWorkflows._sanitize(node.name))
+ .resource(
+ "create",
+ jobset.dump(),
+ "status.terminalState == Completed",
+ "status.terminalState == Failed",
+ )
+ .inputs(Inputs().parameters(inputs))
+ .outputs(
+ Outputs().parameters(
+ [
+ Parameter("task-id-entropy").valueFrom(
+ {"jsonPath": "{.metadata.labels.task_id_entropy}"}
+ ),
+ Parameter("num-parallel").valueFrom(
+ {"jsonPath": "{.metadata.labels.num_parallel}"}
+ ),
+ ]
+ )
+ )
+ .retry_strategy(
+ times=total_retries,
+ minutes_between_retries=minutes_between_retries,
+ )
+ )
+ continue
  yield (
  Template(self._sanitize(node.name))
  # Set @timeout values
@@ -1604,6 +2025,7 @@ class ArgoWorkflows(object):
  kubernetes_sdk.V1Container(
  name=self._sanitize(node.name),
  command=cmds,
+ termination_message_policy="FallbackToLogsOnError",
  ports=[kubernetes_sdk.V1ContainerPort(container_port=port)]
  if port
  else None,
@@ -1656,9 +2078,11 @@ class ArgoWorkflows(object):
  for k in list(
  []
  if not resources.get("secrets")
- else [resources.get("secrets")]
- if isinstance(resources.get("secrets"), str)
- else resources.get("secrets")
+ else (
+ [resources.get("secrets")]
+ if isinstance(resources.get("secrets"), str)
+ else resources.get("secrets")
+ )
  )
  + KUBERNETES_SECRETS.split(",")
  + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
@@ -1710,6 +2134,13 @@ class ArgoWorkflows(object):
  )
  )

+ # Return daemon container templates for workflow execution notifications.
+ def _daemon_templates(self):
+ templates = []
+ if self.enable_heartbeat_daemon:
+ templates.append(self._heartbeat_daemon_template())
+ return templates
+
  # Return exit hook templates for workflow execution notifications.
  def _exit_hook_templates(self):
  templates = []
@@ -1838,7 +2269,7 @@ class ArgoWorkflows(object):
  "fields": [
  {
  "type": "mrkdwn",
- "text": "*Project:* %s" % current.project_name
+ "text": "*Project:* %s" % current.project_name
  },
  {
  "type": "mrkdwn",
@@ -1916,6 +2347,117 @@ class ArgoWorkflows(object):
  Http("POST").url(self.notify_slack_webhook_url).body(json.dumps(payload))
  )

+ def _heartbeat_daemon_template(self):
+ # Use all the affordances available to _parameters task
+ executable = self.environment.executable("_parameters")
+ run_id = "argo-{{workflow.name}}"
+ entrypoint = [executable, "-m metaflow.plugins.argo.daemon"]
+ heartbeat_cmds = "{entrypoint} --flow_name {flow_name} --run_id {run_id} {tags} heartbeat".format(
+ entrypoint=" ".join(entrypoint),
+ flow_name=self.flow.name,
+ run_id=run_id,
+ tags=" ".join(["--tag %s" % t for t in self.tags]) if self.tags else "",
+ )
+
+ # TODO: we do not really need MFLOG logging for the daemon at the moment, but might be good for the future.
+ # Consider if we can do without this setup.
+ # Configure log capture.
+ mflog_expr = export_mflog_env_vars(
+ datastore_type=self.flow_datastore.TYPE,
+ stdout_path="$PWD/.logs/mflog_stdout",
+ stderr_path="$PWD/.logs/mflog_stderr",
+ flow_name=self.flow.name,
+ run_id=run_id,
+ step_name="_run_heartbeat_daemon",
+ task_id="1",
+ retry_count="0",
+ )
+ # TODO: Can the init be trimmed down?
+ # Can we do without get_package_commands fetching the whole code package?
+ init_cmds = " && ".join(
+ [
+ # For supporting sandboxes, ensure that a custom script is executed
+ # before anything else is executed. The script is passed in as an
+ # env var.
+ '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"}',
+ "mkdir -p $PWD/.logs",
+ mflog_expr,
+ ]
+ + self.environment.get_package_commands(
+ self.code_package_url, self.flow_datastore.TYPE
+ )[:-1]
+ # Replace the line 'Task in starting'
+ # FIXME: this can be brittle.
+ + ["mflog 'Heartbeat daemon is starting.'"]
+ )
+
+ cmd_str = " && ".join([init_cmds, heartbeat_cmds])
+ cmds = shlex.split('bash -c "%s"' % cmd_str)
+
+ # TODO: Check that this is the minimal env.
+ # Env required for sending heartbeats to the metadata service, nothing extra.
+ env = {
+ # These values are needed by Metaflow to set it's internal
+ # state appropriately.
+ "METAFLOW_CODE_URL": self.code_package_url,
+ "METAFLOW_CODE_SHA": self.code_package_sha,
+ "METAFLOW_CODE_DS": self.flow_datastore.TYPE,
+ "METAFLOW_SERVICE_URL": SERVICE_INTERNAL_URL,
+ "METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
+ "METAFLOW_USER": "argo-workflows",
+ "METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
+ "METAFLOW_DEFAULT_METADATA": DEFAULT_METADATA,
+ "METAFLOW_OWNER": self.username,
+ }
+ # support Metaflow sandboxes
+ env["METAFLOW_INIT_SCRIPT"] = KUBERNETES_SANDBOX_INIT_SCRIPT
+
+ # cleanup env values
+ env = {
+ k: v
+ for k, v in env.items()
+ if v is not None
+ and k not in set(ARGO_WORKFLOWS_ENV_VARS_TO_SKIP.split(","))
+ }
+
+ # We want to grab the base image used by the start step, as this is known to be pullable from within the cluster,
+ # and it might contain the required libraries, allowing us to start up faster.
+ start_step = next(step for step in self.flow if step.name == "start")
+ resources = dict(
+ [deco for deco in start_step.decorators if deco.name == "kubernetes"][
+ 0
+ ].attributes
+ )
+ from kubernetes import client as kubernetes_sdk
+
+ return DaemonTemplate("heartbeat-daemon").container(
+ to_camelcase(
+ kubernetes_sdk.V1Container(
+ name="main",
+ # TODO: Make the image configurable
+ image=resources["image"],
+ command=cmds,
+ env=[
+ kubernetes_sdk.V1EnvVar(name=k, value=str(v))
+ for k, v in env.items()
+ ],
+ resources=kubernetes_sdk.V1ResourceRequirements(
+ # NOTE: base resources for this are kept to a minimum to save on running costs.
+ # This has an adverse effect on startup time for the daemon, which can be completely
+ # alleviated by using a base image that has the required dependencies pre-installed
+ requests={
+ "cpu": "200m",
+ "memory": "100Mi",
+ },
+ limits={
+ "cpu": "200m",
+ "memory": "100Mi",
+ },
+ ),
+ )
+ )
+ )
+
  def _compile_sensor(self):
  # This method compiles a Metaflow @trigger decorator into Argo Events Sensor.
  #
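
Aside (not part of the diff): the _heartbeat_daemon_template above assembles the daemon command from the format string shown in the hunk; for a hypothetical flow name and tag it would expand roughly as follows.

# Rough illustration of the heartbeat command assembled above;
# "HelloFlow" and the tag value are invented for this example.
entrypoint = ["python", "-m metaflow.plugins.argo.daemon"]
heartbeat_cmds = "{entrypoint} --flow_name {flow_name} --run_id {run_id} {tags} heartbeat".format(
    entrypoint=" ".join(entrypoint),
    flow_name="HelloFlow",
    run_id="argo-{{workflow.name}}",
    tags="--tag team:ml",
)
print(heartbeat_cmds)
# python -m metaflow.plugins.argo.daemon --flow_name HelloFlow --run_id argo-{{workflow.name}} --tag team:ml heartbeat
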
@@ -2488,6 +3030,25 @@ class Metadata(object):
  return json.dumps(self.to_json(), indent=4)


+ class DaemonTemplate(object):
+ def __init__(self, name):
+ tree = lambda: defaultdict(tree)
+ self.name = name
+ self.payload = tree()
+ self.payload["daemon"] = True
+ self.payload["name"] = name
+
+ def container(self, container):
+ self.payload["container"] = container
+ return self
+
+ def to_json(self):
+ return self.payload
+
+ def __str__(self):
+ return json.dumps(self.payload, indent=4)
+
+
  class Template(object):
  # https://argoproj.github.io/argo-workflows/fields/#template
@@ -2612,6 +3173,15 @@ class Template(object):
  def to_json(self):
  return self.payload

+ def resource(self, action, manifest, success_criteria, failure_criteria):
+ self.payload["resource"] = {}
+ self.payload["resource"]["action"] = action
+ self.payload["setOwnerReference"] = True
+ self.payload["resource"]["successCondition"] = success_criteria
+ self.payload["resource"]["failureCondition"] = failure_criteria
+ self.payload["resource"]["manifest"] = manifest
+ return self
+
  def __str__(self):
  return json.dumps(self.payload, indent=4)
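
End note (not part of the diff): the new Template.resource() helper fills in the Argo "resource" template fields that the JobSet-based @parallel steps rely on. A minimal sketch of the payload shape it produces, with a placeholder manifest value and a hypothetical step name:

# Minimal sketch of the payload produced by Template.resource();
# the manifest string here is a placeholder, not real output.
import json

payload = {
    "name": "train",  # hypothetical step name
    "setOwnerReference": True,
    "resource": {
        "action": "create",
        "successCondition": "status.terminalState == Completed",
        "failureCondition": "status.terminalState == Failed",
        "manifest": "<JobSet manifest from KubernetesArgoJobSet.dump()>",
    },
}
print(json.dumps(payload, indent=4))
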