ob-metaflow 2.11.13.1__py2.py3-none-any.whl → 2.19.7.1rc0__py2.py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (289)
  1. metaflow/R.py +10 -7
  2. metaflow/__init__.py +40 -25
  3. metaflow/_vendor/imghdr/__init__.py +186 -0
  4. metaflow/_vendor/importlib_metadata/__init__.py +1063 -0
  5. metaflow/_vendor/importlib_metadata/_adapters.py +68 -0
  6. metaflow/_vendor/importlib_metadata/_collections.py +30 -0
  7. metaflow/_vendor/importlib_metadata/_compat.py +71 -0
  8. metaflow/_vendor/importlib_metadata/_functools.py +104 -0
  9. metaflow/_vendor/importlib_metadata/_itertools.py +73 -0
  10. metaflow/_vendor/importlib_metadata/_meta.py +48 -0
  11. metaflow/_vendor/importlib_metadata/_text.py +99 -0
  12. metaflow/_vendor/importlib_metadata/py.typed +0 -0
  13. metaflow/_vendor/typeguard/__init__.py +48 -0
  14. metaflow/_vendor/typeguard/_checkers.py +1070 -0
  15. metaflow/_vendor/typeguard/_config.py +108 -0
  16. metaflow/_vendor/typeguard/_decorators.py +233 -0
  17. metaflow/_vendor/typeguard/_exceptions.py +42 -0
  18. metaflow/_vendor/typeguard/_functions.py +308 -0
  19. metaflow/_vendor/typeguard/_importhook.py +213 -0
  20. metaflow/_vendor/typeguard/_memo.py +48 -0
  21. metaflow/_vendor/typeguard/_pytest_plugin.py +127 -0
  22. metaflow/_vendor/typeguard/_suppression.py +86 -0
  23. metaflow/_vendor/typeguard/_transformer.py +1229 -0
  24. metaflow/_vendor/typeguard/_union_transformer.py +55 -0
  25. metaflow/_vendor/typeguard/_utils.py +173 -0
  26. metaflow/_vendor/typeguard/py.typed +0 -0
  27. metaflow/_vendor/typing_extensions.py +3641 -0
  28. metaflow/_vendor/v3_7/importlib_metadata/__init__.py +1063 -0
  29. metaflow/_vendor/v3_7/importlib_metadata/_adapters.py +68 -0
  30. metaflow/_vendor/v3_7/importlib_metadata/_collections.py +30 -0
  31. metaflow/_vendor/v3_7/importlib_metadata/_compat.py +71 -0
  32. metaflow/_vendor/v3_7/importlib_metadata/_functools.py +104 -0
  33. metaflow/_vendor/v3_7/importlib_metadata/_itertools.py +73 -0
  34. metaflow/_vendor/v3_7/importlib_metadata/_meta.py +48 -0
  35. metaflow/_vendor/v3_7/importlib_metadata/_text.py +99 -0
  36. metaflow/_vendor/v3_7/importlib_metadata/py.typed +0 -0
  37. metaflow/_vendor/v3_7/typeguard/__init__.py +48 -0
  38. metaflow/_vendor/v3_7/typeguard/_checkers.py +906 -0
  39. metaflow/_vendor/v3_7/typeguard/_config.py +108 -0
  40. metaflow/_vendor/v3_7/typeguard/_decorators.py +237 -0
  41. metaflow/_vendor/v3_7/typeguard/_exceptions.py +42 -0
  42. metaflow/_vendor/v3_7/typeguard/_functions.py +310 -0
  43. metaflow/_vendor/v3_7/typeguard/_importhook.py +213 -0
  44. metaflow/_vendor/v3_7/typeguard/_memo.py +48 -0
  45. metaflow/_vendor/v3_7/typeguard/_pytest_plugin.py +100 -0
  46. metaflow/_vendor/v3_7/typeguard/_suppression.py +88 -0
  47. metaflow/_vendor/v3_7/typeguard/_transformer.py +1207 -0
  48. metaflow/_vendor/v3_7/typeguard/_union_transformer.py +54 -0
  49. metaflow/_vendor/v3_7/typeguard/_utils.py +169 -0
  50. metaflow/_vendor/v3_7/typeguard/py.typed +0 -0
  51. metaflow/_vendor/v3_7/typing_extensions.py +3072 -0
  52. metaflow/_vendor/yaml/__init__.py +427 -0
  53. metaflow/_vendor/yaml/composer.py +139 -0
  54. metaflow/_vendor/yaml/constructor.py +748 -0
  55. metaflow/_vendor/yaml/cyaml.py +101 -0
  56. metaflow/_vendor/yaml/dumper.py +62 -0
  57. metaflow/_vendor/yaml/emitter.py +1137 -0
  58. metaflow/_vendor/yaml/error.py +75 -0
  59. metaflow/_vendor/yaml/events.py +86 -0
  60. metaflow/_vendor/yaml/loader.py +63 -0
  61. metaflow/_vendor/yaml/nodes.py +49 -0
  62. metaflow/_vendor/yaml/parser.py +589 -0
  63. metaflow/_vendor/yaml/reader.py +185 -0
  64. metaflow/_vendor/yaml/representer.py +389 -0
  65. metaflow/_vendor/yaml/resolver.py +227 -0
  66. metaflow/_vendor/yaml/scanner.py +1435 -0
  67. metaflow/_vendor/yaml/serializer.py +111 -0
  68. metaflow/_vendor/yaml/tokens.py +104 -0
  69. metaflow/cards.py +5 -0
  70. metaflow/cli.py +331 -785
  71. metaflow/cli_args.py +17 -0
  72. metaflow/cli_components/__init__.py +0 -0
  73. metaflow/cli_components/dump_cmd.py +96 -0
  74. metaflow/cli_components/init_cmd.py +52 -0
  75. metaflow/cli_components/run_cmds.py +546 -0
  76. metaflow/cli_components/step_cmd.py +334 -0
  77. metaflow/cli_components/utils.py +140 -0
  78. metaflow/client/__init__.py +1 -0
  79. metaflow/client/core.py +467 -73
  80. metaflow/client/filecache.py +75 -35
  81. metaflow/clone_util.py +7 -1
  82. metaflow/cmd/code/__init__.py +231 -0
  83. metaflow/cmd/develop/stub_generator.py +756 -288
  84. metaflow/cmd/develop/stubs.py +12 -28
  85. metaflow/cmd/main_cli.py +6 -4
  86. metaflow/cmd/make_wrapper.py +78 -0
  87. metaflow/datastore/__init__.py +1 -0
  88. metaflow/datastore/content_addressed_store.py +41 -10
  89. metaflow/datastore/datastore_set.py +11 -2
  90. metaflow/datastore/flow_datastore.py +156 -10
  91. metaflow/datastore/spin_datastore.py +91 -0
  92. metaflow/datastore/task_datastore.py +154 -39
  93. metaflow/debug.py +5 -0
  94. metaflow/decorators.py +404 -78
  95. metaflow/exception.py +8 -2
  96. metaflow/extension_support/__init__.py +527 -376
  97. metaflow/extension_support/_empty_file.py +2 -2
  98. metaflow/extension_support/plugins.py +49 -31
  99. metaflow/flowspec.py +482 -33
  100. metaflow/graph.py +210 -42
  101. metaflow/includefile.py +84 -40
  102. metaflow/lint.py +141 -22
  103. metaflow/meta_files.py +13 -0
  104. metaflow/{metadata → metadata_provider}/heartbeat.py +24 -8
  105. metaflow/{metadata → metadata_provider}/metadata.py +86 -1
  106. metaflow/metaflow_config.py +175 -28
  107. metaflow/metaflow_config_funcs.py +51 -3
  108. metaflow/metaflow_current.py +4 -10
  109. metaflow/metaflow_environment.py +139 -53
  110. metaflow/metaflow_git.py +115 -0
  111. metaflow/metaflow_profile.py +18 -0
  112. metaflow/metaflow_version.py +150 -66
  113. metaflow/mflog/__init__.py +4 -3
  114. metaflow/mflog/save_logs.py +2 -2
  115. metaflow/multicore_utils.py +31 -14
  116. metaflow/package/__init__.py +673 -0
  117. metaflow/packaging_sys/__init__.py +880 -0
  118. metaflow/packaging_sys/backend.py +128 -0
  119. metaflow/packaging_sys/distribution_support.py +153 -0
  120. metaflow/packaging_sys/tar_backend.py +99 -0
  121. metaflow/packaging_sys/utils.py +54 -0
  122. metaflow/packaging_sys/v1.py +527 -0
  123. metaflow/parameters.py +149 -28
  124. metaflow/plugins/__init__.py +74 -5
  125. metaflow/plugins/airflow/airflow.py +40 -25
  126. metaflow/plugins/airflow/airflow_cli.py +22 -5
  127. metaflow/plugins/airflow/airflow_decorator.py +1 -1
  128. metaflow/plugins/airflow/airflow_utils.py +5 -3
  129. metaflow/plugins/airflow/sensors/base_sensor.py +4 -4
  130. metaflow/plugins/airflow/sensors/external_task_sensor.py +2 -2
  131. metaflow/plugins/airflow/sensors/s3_sensor.py +2 -2
  132. metaflow/plugins/argo/argo_client.py +78 -33
  133. metaflow/plugins/argo/argo_events.py +6 -6
  134. metaflow/plugins/argo/argo_workflows.py +2410 -527
  135. metaflow/plugins/argo/argo_workflows_cli.py +571 -121
  136. metaflow/plugins/argo/argo_workflows_decorator.py +43 -12
  137. metaflow/plugins/argo/argo_workflows_deployer.py +106 -0
  138. metaflow/plugins/argo/argo_workflows_deployer_objects.py +453 -0
  139. metaflow/plugins/argo/capture_error.py +73 -0
  140. metaflow/plugins/argo/conditional_input_paths.py +35 -0
  141. metaflow/plugins/argo/exit_hooks.py +209 -0
  142. metaflow/plugins/argo/jobset_input_paths.py +15 -0
  143. metaflow/plugins/argo/param_val.py +19 -0
  144. metaflow/plugins/aws/aws_client.py +10 -3
  145. metaflow/plugins/aws/aws_utils.py +55 -2
  146. metaflow/plugins/aws/batch/batch.py +72 -5
  147. metaflow/plugins/aws/batch/batch_cli.py +33 -10
  148. metaflow/plugins/aws/batch/batch_client.py +4 -3
  149. metaflow/plugins/aws/batch/batch_decorator.py +102 -35
  150. metaflow/plugins/aws/secrets_manager/aws_secrets_manager_secrets_provider.py +13 -10
  151. metaflow/plugins/aws/step_functions/dynamo_db_client.py +0 -3
  152. metaflow/plugins/aws/step_functions/production_token.py +1 -1
  153. metaflow/plugins/aws/step_functions/step_functions.py +65 -8
  154. metaflow/plugins/aws/step_functions/step_functions_cli.py +101 -7
  155. metaflow/plugins/aws/step_functions/step_functions_decorator.py +1 -2
  156. metaflow/plugins/aws/step_functions/step_functions_deployer.py +97 -0
  157. metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py +264 -0
  158. metaflow/plugins/azure/azure_exceptions.py +1 -1
  159. metaflow/plugins/azure/azure_secret_manager_secrets_provider.py +240 -0
  160. metaflow/plugins/azure/azure_tail.py +1 -1
  161. metaflow/plugins/azure/includefile_support.py +2 -0
  162. metaflow/plugins/cards/card_cli.py +66 -30
  163. metaflow/plugins/cards/card_creator.py +25 -1
  164. metaflow/plugins/cards/card_datastore.py +21 -49
  165. metaflow/plugins/cards/card_decorator.py +132 -8
  166. metaflow/plugins/cards/card_modules/basic.py +112 -17
  167. metaflow/plugins/cards/card_modules/bundle.css +1 -1
  168. metaflow/plugins/cards/card_modules/card.py +16 -1
  169. metaflow/plugins/cards/card_modules/chevron/renderer.py +1 -1
  170. metaflow/plugins/cards/card_modules/components.py +665 -28
  171. metaflow/plugins/cards/card_modules/convert_to_native_type.py +36 -7
  172. metaflow/plugins/cards/card_modules/json_viewer.py +232 -0
  173. metaflow/plugins/cards/card_modules/main.css +1 -0
  174. metaflow/plugins/cards/card_modules/main.js +68 -49
  175. metaflow/plugins/cards/card_modules/renderer_tools.py +1 -0
  176. metaflow/plugins/cards/card_modules/test_cards.py +26 -12
  177. metaflow/plugins/cards/card_server.py +39 -14
  178. metaflow/plugins/cards/component_serializer.py +2 -9
  179. metaflow/plugins/cards/metadata.py +22 -0
  180. metaflow/plugins/catch_decorator.py +9 -0
  181. metaflow/plugins/datastores/azure_storage.py +10 -1
  182. metaflow/plugins/datastores/gs_storage.py +6 -2
  183. metaflow/plugins/datastores/local_storage.py +12 -6
  184. metaflow/plugins/datastores/spin_storage.py +12 -0
  185. metaflow/plugins/datatools/local.py +2 -0
  186. metaflow/plugins/datatools/s3/s3.py +126 -75
  187. metaflow/plugins/datatools/s3/s3op.py +254 -121
  188. metaflow/plugins/env_escape/__init__.py +3 -3
  189. metaflow/plugins/env_escape/client_modules.py +102 -72
  190. metaflow/plugins/env_escape/server.py +7 -0
  191. metaflow/plugins/env_escape/stub.py +24 -5
  192. metaflow/plugins/events_decorator.py +343 -185
  193. metaflow/plugins/exit_hook/__init__.py +0 -0
  194. metaflow/plugins/exit_hook/exit_hook_decorator.py +46 -0
  195. metaflow/plugins/exit_hook/exit_hook_script.py +52 -0
  196. metaflow/plugins/gcp/__init__.py +1 -1
  197. metaflow/plugins/gcp/gcp_secret_manager_secrets_provider.py +11 -6
  198. metaflow/plugins/gcp/gs_tail.py +10 -6
  199. metaflow/plugins/gcp/includefile_support.py +3 -0
  200. metaflow/plugins/kubernetes/kube_utils.py +108 -0
  201. metaflow/plugins/kubernetes/kubernetes.py +411 -130
  202. metaflow/plugins/kubernetes/kubernetes_cli.py +168 -36
  203. metaflow/plugins/kubernetes/kubernetes_client.py +104 -2
  204. metaflow/plugins/kubernetes/kubernetes_decorator.py +246 -88
  205. metaflow/plugins/kubernetes/kubernetes_job.py +253 -581
  206. metaflow/plugins/kubernetes/kubernetes_jobsets.py +1071 -0
  207. metaflow/plugins/kubernetes/spot_metadata_cli.py +69 -0
  208. metaflow/plugins/kubernetes/spot_monitor_sidecar.py +109 -0
  209. metaflow/plugins/logs_cli.py +359 -0
  210. metaflow/plugins/{metadata → metadata_providers}/local.py +144 -84
  211. metaflow/plugins/{metadata → metadata_providers}/service.py +103 -26
  212. metaflow/plugins/metadata_providers/spin.py +16 -0
  213. metaflow/plugins/package_cli.py +36 -24
  214. metaflow/plugins/parallel_decorator.py +128 -11
  215. metaflow/plugins/parsers.py +16 -0
  216. metaflow/plugins/project_decorator.py +51 -5
  217. metaflow/plugins/pypi/bootstrap.py +357 -105
  218. metaflow/plugins/pypi/conda_decorator.py +82 -81
  219. metaflow/plugins/pypi/conda_environment.py +187 -52
  220. metaflow/plugins/pypi/micromamba.py +157 -47
  221. metaflow/plugins/pypi/parsers.py +268 -0
  222. metaflow/plugins/pypi/pip.py +88 -13
  223. metaflow/plugins/pypi/pypi_decorator.py +37 -1
  224. metaflow/plugins/pypi/utils.py +48 -2
  225. metaflow/plugins/resources_decorator.py +2 -2
  226. metaflow/plugins/secrets/__init__.py +3 -0
  227. metaflow/plugins/secrets/secrets_decorator.py +26 -181
  228. metaflow/plugins/secrets/secrets_func.py +49 -0
  229. metaflow/plugins/secrets/secrets_spec.py +101 -0
  230. metaflow/plugins/secrets/utils.py +74 -0
  231. metaflow/plugins/tag_cli.py +4 -7
  232. metaflow/plugins/test_unbounded_foreach_decorator.py +41 -6
  233. metaflow/plugins/timeout_decorator.py +3 -3
  234. metaflow/plugins/uv/__init__.py +0 -0
  235. metaflow/plugins/uv/bootstrap.py +128 -0
  236. metaflow/plugins/uv/uv_environment.py +72 -0
  237. metaflow/procpoll.py +1 -1
  238. metaflow/pylint_wrapper.py +5 -1
  239. metaflow/runner/__init__.py +0 -0
  240. metaflow/runner/click_api.py +717 -0
  241. metaflow/runner/deployer.py +470 -0
  242. metaflow/runner/deployer_impl.py +201 -0
  243. metaflow/runner/metaflow_runner.py +714 -0
  244. metaflow/runner/nbdeploy.py +132 -0
  245. metaflow/runner/nbrun.py +225 -0
  246. metaflow/runner/subprocess_manager.py +650 -0
  247. metaflow/runner/utils.py +335 -0
  248. metaflow/runtime.py +1078 -260
  249. metaflow/sidecar/sidecar_worker.py +1 -1
  250. metaflow/system/__init__.py +5 -0
  251. metaflow/system/system_logger.py +85 -0
  252. metaflow/system/system_monitor.py +108 -0
  253. metaflow/system/system_utils.py +19 -0
  254. metaflow/task.py +521 -225
  255. metaflow/tracing/__init__.py +7 -7
  256. metaflow/tracing/span_exporter.py +31 -38
  257. metaflow/tracing/tracing_modules.py +38 -43
  258. metaflow/tuple_util.py +27 -0
  259. metaflow/user_configs/__init__.py +0 -0
  260. metaflow/user_configs/config_options.py +563 -0
  261. metaflow/user_configs/config_parameters.py +598 -0
  262. metaflow/user_decorators/__init__.py +0 -0
  263. metaflow/user_decorators/common.py +144 -0
  264. metaflow/user_decorators/mutable_flow.py +512 -0
  265. metaflow/user_decorators/mutable_step.py +424 -0
  266. metaflow/user_decorators/user_flow_decorator.py +264 -0
  267. metaflow/user_decorators/user_step_decorator.py +749 -0
  268. metaflow/util.py +243 -27
  269. metaflow/vendor.py +23 -7
  270. metaflow/version.py +1 -1
  271. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/Makefile +355 -0
  272. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/Tiltfile +726 -0
  273. ob_metaflow-2.19.7.1rc0.data/data/share/metaflow/devtools/pick_services.sh +105 -0
  274. ob_metaflow-2.19.7.1rc0.dist-info/METADATA +87 -0
  275. ob_metaflow-2.19.7.1rc0.dist-info/RECORD +445 -0
  276. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/WHEEL +1 -1
  277. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/entry_points.txt +1 -0
  278. metaflow/_vendor/v3_5/__init__.py +0 -1
  279. metaflow/_vendor/v3_5/importlib_metadata/__init__.py +0 -644
  280. metaflow/_vendor/v3_5/importlib_metadata/_compat.py +0 -152
  281. metaflow/package.py +0 -188
  282. ob_metaflow-2.11.13.1.dist-info/METADATA +0 -85
  283. ob_metaflow-2.11.13.1.dist-info/RECORD +0 -308
  284. /metaflow/_vendor/{v3_5/zipp.py → zipp.py} +0 -0
  285. /metaflow/{metadata → metadata_provider}/__init__.py +0 -0
  286. /metaflow/{metadata → metadata_provider}/util.py +0 -0
  287. /metaflow/plugins/{metadata → metadata_providers}/__init__.py +0 -0
  288. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info/licenses}/LICENSE +0 -0
  289. {ob_metaflow-2.11.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/top_level.txt +0 -0
metaflow/plugins/argo/argo_workflows.py
@@ -6,22 +6,28 @@ import shlex
  import sys
  from collections import defaultdict
  from hashlib import sha1
+ from math import inf
+ from typing import List

  from metaflow import JSONType, current
  from metaflow.decorators import flow_decorators
  from metaflow.exception import MetaflowException
+ from metaflow.graph import FlowGraph
  from metaflow.includefile import FilePathClass
  from metaflow.metaflow_config import (
  ARGO_EVENTS_EVENT,
  ARGO_EVENTS_EVENT_BUS,
  ARGO_EVENTS_EVENT_SOURCE,
  ARGO_EVENTS_INTERNAL_WEBHOOK_URL,
+ ARGO_EVENTS_SENSOR_NAMESPACE,
  ARGO_EVENTS_SERVICE_ACCOUNT,
  ARGO_EVENTS_WEBHOOK_AUTH,
+ ARGO_WORKFLOWS_CAPTURE_ERROR_SCRIPT,
  ARGO_WORKFLOWS_ENV_VARS_TO_SKIP,
  ARGO_WORKFLOWS_KUBERNETES_SECRETS,
  ARGO_WORKFLOWS_UI_URL,
  AWS_SECRETS_MANAGER_DEFAULT_REGION,
+ AZURE_KEY_VAULT_PREFIX,
  AZURE_STORAGE_BLOB_SERVICE_ENDPOINT,
  CARD_AZUREROOT,
  CARD_GSROOT,
@@ -34,9 +40,7 @@ from metaflow.metaflow_config import (
  DEFAULT_SECRETS_BACKEND_TYPE,
  GCP_SECRET_MANAGER_PREFIX,
  KUBERNETES_FETCH_EC2_METADATA,
- KUBERNETES_LABELS,
  KUBERNETES_NAMESPACE,
- KUBERNETES_NODE_SELECTOR,
  KUBERNETES_SANDBOX_INIT_SCRIPT,
  KUBERNETES_SECRETS,
  S3_ENDPOINT_URL,
@@ -44,14 +48,16 @@ from metaflow.metaflow_config import (
  SERVICE_HEADERS,
  SERVICE_INTERNAL_URL,
  UI_URL,
+ PAGERDUTY_TEMPLATE_URL,
  )
- from metaflow.metaflow_config_funcs import config_values
+ from metaflow.metaflow_config_funcs import config_values, init_config
  from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
  from metaflow.parameters import deploy_time_eval
- from metaflow.plugins.kubernetes.kubernetes import (
- parse_kube_keyvalue_list,
- validate_kube_labels,
- )
+ from metaflow.plugins.kubernetes.kube_utils import qos_requests_and_limits
+
+ from metaflow.plugins.kubernetes.kubernetes_jobsets import KubernetesArgoJobSet
+ from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
+ from metaflow.user_configs.config_options import ConfigInput
  from metaflow.util import (
  compress_list,
  dict_to_cli_options,
@@ -61,12 +67,18 @@ from metaflow.util import (
  )

  from .argo_client import ArgoClient
+ from .exit_hooks import ExitHookHack, HttpExitHook, ContainerHook
+ from metaflow.util import resolve_identity


  class ArgoWorkflowsException(MetaflowException):
  headline = "Argo Workflows error"


+ class ArgoWorkflowsSensorCleanupException(MetaflowException):
+ headline = "Argo Workflows sensor clean up error"
+
+
  class ArgoWorkflowsSchedulingException(MetaflowException):
  headline = "Argo Workflows scheduling error"

@@ -74,21 +86,18 @@ class ArgoWorkflowsSchedulingException(MetaflowException):
  # List of future enhancements -
  # 1. Configure Argo metrics.
  # 2. Support resuming failed workflows within Argo Workflows.
- # 3. Support gang-scheduled clusters for distributed PyTorch/TF - One option is to
- # use volcano - https://github.com/volcano-sh/volcano/tree/master/example/integrations/argo
- # 4. Support GitOps workflows.
- # 5. Add Metaflow tags to labels/annotations.
- # 6. Support Multi-cluster scheduling - https://github.com/argoproj/argo-workflows/issues/3523#issuecomment-792307297
- # 7. Support R lang.
- # 8. Ping @savin at slack.outerbounds.co for any feature request.
+ # 3. Add Metaflow tags to labels/annotations.
+ # 4. Support R lang.
+ # 5. Ping @savin at slack.outerbounds.co for any feature request


  class ArgoWorkflows(object):
  def __init__(
  self,
  name,
- graph,
+ graph: FlowGraph,
  flow,
+ code_package_metadata,
  code_package_sha,
  code_package_url,
  production_token,
@@ -108,6 +117,13 @@ class ArgoWorkflows(object):
  notify_on_success=False,
  notify_slack_webhook_url=None,
  notify_pager_duty_integration_key=None,
+ notify_incident_io_api_key=None,
+ incident_io_alert_source_config_id=None,
+ incident_io_metadata: List[str] = None,
+ enable_heartbeat_daemon=True,
+ enable_error_msg_capture=False,
+ workflow_title=None,
+ workflow_description=None,
  ):
  # Some high-level notes -
  #
@@ -133,9 +149,19 @@ class ArgoWorkflows(object):
  # ensure that your Argo Workflows controller doesn't restrict
  # templateReferencing.

+ # get initial configs
+ self.initial_configs = init_config()
+ for entry in ["OBP_PERIMETER", "OBP_INTEGRATIONS_URL"]:
+ if entry not in self.initial_configs:
+ raise ArgoWorkflowsException(
+ f"{entry} was not found in metaflow config. Please make sure to run `outerbounds configure <...>` command which can be found on the Outerbounds UI or reach out to your Outerbounds support team."
+ )
+
  self.name = name
  self.graph = graph
+ self._parse_conditional_branches()
  self.flow = flow
+ self.code_package_metadata = code_package_metadata
  self.code_package_sha = code_package_sha
  self.code_package_url = code_package_url
  self.production_token = production_token
@@ -155,12 +181,22 @@ class ArgoWorkflows(object):
  self.notify_on_success = notify_on_success
  self.notify_slack_webhook_url = notify_slack_webhook_url
  self.notify_pager_duty_integration_key = notify_pager_duty_integration_key
-
+ self.notify_incident_io_api_key = notify_incident_io_api_key
+ self.incident_io_alert_source_config_id = incident_io_alert_source_config_id
+ self.incident_io_metadata = self.parse_incident_io_metadata(
+ incident_io_metadata
+ )
+ self.enable_heartbeat_daemon = enable_heartbeat_daemon
+ self.enable_error_msg_capture = enable_error_msg_capture
+ self.workflow_title = workflow_title
+ self.workflow_description = workflow_description
  self.parameters = self._process_parameters()
+ self.config_parameters = self._process_config_parameters()
  self.triggers, self.trigger_options = self._process_triggers()
  self._schedule, self._timezone = self._get_schedule()

- self.kubernetes_labels = self._get_kubernetes_labels()
+ self._base_labels = self._base_kubernetes_labels()
+ self._base_annotations = self._base_kubernetes_annotations()
  self._workflow_template = self._compile_workflow_template()
  self._sensor = self._compile_sensor()

@@ -168,6 +204,7 @@ class ArgoWorkflows(object):
  return str(self._workflow_template)

  def deploy(self):
+ self.cleanup_previous_sensors()
  try:
  # Register workflow template.
  ArgoClient(namespace=KUBERNETES_NAMESPACE).register_workflow_template(
@@ -176,6 +213,37 @@ class ArgoWorkflows(object):
  except Exception as e:
  raise ArgoWorkflowsException(str(e))

+ def cleanup_previous_sensors(self):
+ try:
+ client = ArgoClient(namespace=KUBERNETES_NAMESPACE)
+ # Check for existing deployment and do cleanup
+ old_template = client.get_workflow_template(self.name)
+ if not old_template:
+ return None
+ # Clean up old sensors
+ old_sensor_namespace = old_template["metadata"]["annotations"].get(
+ "metaflow/sensor_namespace"
+ )
+
+ if old_sensor_namespace is None:
+ # This workflow was created before sensor annotations
+ # and may have a sensor in the default namespace
+ # we will delete it and it'll get recreated if need be
+ old_sensor_name = ArgoWorkflows._sensor_name(self.name)
+ client.delete_sensor(old_sensor_name, client._namespace)
+ else:
+ # delete old sensor only if it was somewhere else, otherwise it'll get replaced
+ old_sensor_name = old_template["metadata"]["annotations"][
+ "metaflow/sensor_name"
+ ]
+ if (
+ not self._sensor
+ or old_sensor_namespace != ARGO_EVENTS_SENSOR_NAMESPACE
+ ):
+ client.delete_sensor(old_sensor_name, old_sensor_namespace)
+ except Exception as e:
+ raise ArgoWorkflowsSensorCleanupException(str(e))
+
  @staticmethod
  def _sanitize(name):
  # Metaflow allows underscores in node names, which are disallowed in Argo
@@ -184,28 +252,39 @@ class ArgoWorkflows(object):
  return name.replace("_", "-")

  @staticmethod
- def list_templates(flow_name, all=False):
+ def _sensor_name(name):
+ # Unfortunately, Argo Events Sensor names don't allow for
+ # dots (sensors run into an error) which rules out self.name :(
+ return name.replace(".", "-")
+
+ @staticmethod
+ def list_templates(flow_name, all=False, page_size=100):
  client = ArgoClient(namespace=KUBERNETES_NAMESPACE)

- templates = client.get_workflow_templates()
- if templates is None:
- return []
-
- template_names = [
- template["metadata"]["name"]
- for template in templates
- if all
- or flow_name
- == template["metadata"]
- .get("annotations", {})
- .get("metaflow/flow_name", None)
- ]
- return template_names
+ for template in client.get_workflow_templates(page_size=page_size):
+ if all or flow_name == template["metadata"].get("annotations", {}).get(
+ "metaflow/flow_name", None
+ ):
+ yield template["metadata"]["name"]

  @staticmethod
  def delete(name):
  client = ArgoClient(namespace=KUBERNETES_NAMESPACE)

+ # the workflow template might not exist, but we still want to try clean up associated sensors and schedules.
+ workflow_template = client.get_workflow_template(name) or {}
+ workflow_annotations = workflow_template.get("metadata", {}).get(
+ "annotations", {}
+ )
+
+ sensor_name = ArgoWorkflows._sensor_name(
+ workflow_annotations.get("metaflow/sensor_name", name)
+ )
+ # if below is missing then it was deployed before custom sensor namespaces
+ sensor_namespace = workflow_annotations.get(
+ "metaflow/sensor_namespace", KUBERNETES_NAMESPACE
+ )
+
  # Always try to delete the schedule. Failure in deleting the schedule should not
  # be treated as an error, due to any of the following reasons
  # - there might not have been a schedule, or it was deleted by some other means
@@ -215,7 +294,7 @@ class ArgoWorkflows(object):

  # The workflow might have sensors attached to it, which consume actual resources.
  # Try to delete these as well.
- sensor_deleted = client.delete_sensor(name)
+ sensor_deleted = client.delete_sensor(sensor_name, sensor_namespace)

  # After cleaning up related resources, delete the workflow in question.
  # Failure in deleting is treated as critical and will be made visible to the user
@@ -239,6 +318,7 @@ class ArgoWorkflows(object):
  flow_name=flow_name, run_id=name
  )
  )
+ return True

  @staticmethod
  def get_workflow_status(flow_name, name):
@@ -272,6 +352,21 @@ class ArgoWorkflows(object):

  return True

+ @staticmethod
+ def parse_incident_io_metadata(metadata: List[str] = None):
+ "parse key value pairs into a dict for incident.io metadata if given"
+ parsed_metadata = None
+ if metadata is not None:
+ parsed_metadata = {}
+ for kv in metadata:
+ key, value = kv.split("=", 1)
+ if key in parsed_metadata:
+ raise MetaflowException(
+ "Incident.io Metadata *%s* provided multiple times" % key
+ )
+ parsed_metadata[key] = value
+ return parsed_metadata
+
  @classmethod
  def trigger(cls, name, parameters=None):
  if parameters is None:
@@ -291,7 +386,7 @@ class ArgoWorkflows(object):
  try:
  # Check that the workflow was deployed through Metaflow
  workflow_template["metadata"]["annotations"]["metaflow/owner"]
- except KeyError as e:
+ except KeyError:
  raise ArgoWorkflowsException(
  "An existing non-metaflow workflow with the same name as "
  "*%s* already exists in Argo Workflows. \nPlease modify the "
@@ -299,24 +394,75 @@ class ArgoWorkflows(object):
  "Workflows before proceeding." % name
  )
  try:
+ id_parts = resolve_identity().split(":")
+ parts_size = len(id_parts)
+ usertype = id_parts[0] if parts_size > 0 else "unknown"
+ username = id_parts[1] if parts_size > 1 else "unknown"
+
  return ArgoClient(namespace=KUBERNETES_NAMESPACE).trigger_workflow_template(
- name, parameters
+ name,
+ usertype,
+ username,
+ parameters,
  )
  except Exception as e:
  raise ArgoWorkflowsException(str(e))

- @staticmethod
- def _get_kubernetes_labels():
+ def _base_kubernetes_labels(self):
  """
- Get Kubernetes labels from environment variable.
- Parses the string into a dict and validates that values adhere to Kubernetes restrictions.
+ Get shared Kubernetes labels for Argo resources.
  """
- if not KUBERNETES_LABELS:
- return {}
- env_labels = KUBERNETES_LABELS.split(",")
- env_labels = parse_kube_keyvalue_list(env_labels, False)
- validate_kube_labels(env_labels)
- return env_labels
+ # TODO: Add configuration through an environment variable or Metaflow config in the future if required.
+ labels = {"app.kubernetes.io/part-of": "metaflow"}
+
+ return labels
+
+ def _base_kubernetes_annotations(self):
+ """
+ Get shared Kubernetes annotations for Argo resources.
+ """
+ from datetime import datetime, timezone
+
+ # TODO: Add configuration through an environment variable or Metaflow config in the future if required.
+ # base annotations
+ annotations = {
+ "metaflow/production_token": self.production_token,
+ "metaflow/owner": self.username,
+ "metaflow/user": "argo-workflows",
+ "metaflow/flow_name": self.flow.name,
+ "metaflow/deployment_timestamp": str(
+ datetime.now(timezone.utc).isoformat()
+ ),
+ }
+
+ if current.get("project_name"):
+ annotations.update(
+ {
+ "metaflow/project_name": current.project_name,
+ "metaflow/branch_name": current.branch_name,
+ "metaflow/project_flow_name": current.project_flow_name,
+ }
+ )
+
+ # Add Argo Workflows title and description annotations
+ # https://argo-workflows.readthedocs.io/en/latest/title-and-description/
+ # Use CLI-provided values or auto-populate from metadata
+ title = (
+ (self.workflow_title.strip() if self.workflow_title else None)
+ or current.get("project_flow_name")
+ or self.flow.name
+ )
+
+ description = (
+ self.workflow_description.strip() if self.workflow_description else None
+ ) or (self.flow.__doc__.strip() if self.flow.__doc__ else None)
+
+ if title:
+ annotations["workflows.argoproj.io/title"] = title
+ if description:
+ annotations["workflows.argoproj.io/description"] = description
+
+ return annotations

  def _get_schedule(self):
  schedule = self.flow._flow_decorators.get("schedule")
@@ -332,16 +478,14 @@ class ArgoWorkflows(object):
  argo_client.schedule_workflow_template(
  self.name, self._schedule, self._timezone
  )
- # Register sensor. Unfortunately, Argo Events Sensor names don't allow for
- # dots (sensors run into an error) which rules out self.name :(
+ # Register sensor.
  # Metaflow will overwrite any existing sensor.
- sensor_name = self.name.replace(".", "-")
+ sensor_name = ArgoWorkflows._sensor_name(self.name)
  if self._sensor:
- argo_client.register_sensor(sensor_name, self._sensor.to_json())
- else:
- # Since sensors occupy real resources, delete existing sensor if needed
- # Deregister sensors that might have existed before this deployment
- argo_client.delete_sensor(sensor_name)
+ # The new sensor will go into the sensor namespace specified
+ ArgoClient(namespace=ARGO_EVENTS_SENSOR_NAMESPACE).register_sensor(
+ sensor_name, self._sensor.to_json(), ARGO_EVENTS_SENSOR_NAMESPACE
+ )
  except Exception as e:
  raise ArgoWorkflowsSchedulingException(str(e))

@@ -393,7 +537,7 @@ class ArgoWorkflows(object):
  "metaflow/production_token"
  ],
  )
- except KeyError as e:
+ except KeyError:
  raise ArgoWorkflowsException(
  "An existing non-metaflow workflow with the same name as "
  "*%s* already exists in Argo Workflows. \nPlease modify the "
@@ -439,12 +583,22 @@ class ArgoWorkflows(object):
  "case-insensitive." % param.name
  )
  seen.add(norm)
+ # NOTE: We skip config parameters as these do not have dynamic values,
+ # and need to be treated differently.
+ if param.IS_CONFIG_PARAMETER:
+ continue

- if param.kwargs.get("type") == JSONType or isinstance(
- param.kwargs.get("type"), FilePathClass
- ):
- # Special-case this to avoid touching core
+ extra_attrs = {}
+ if param.kwargs.get("type") == JSONType:
+ param_type = str(param.kwargs.get("type").name)
+ elif isinstance(param.kwargs.get("type"), FilePathClass):
  param_type = str(param.kwargs.get("type").name)
+ extra_attrs["is_text"] = getattr(
+ param.kwargs.get("type"), "_is_text", True
+ )
+ extra_attrs["encoding"] = getattr(
+ param.kwargs.get("type"), "_encoding", "utf-8"
+ )
  else:
  param_type = str(param.kwargs.get("type").__name__)

@@ -464,14 +618,47 @@ class ArgoWorkflows(object):
  # the JSON equivalent of None to please argo-workflows. Unfortunately it
  # has the side effect of casting the parameter value to string null during
  # execution - which needs to be fixed imminently.
- if not is_required or default_value is not None:
+ if default_value is None:
+ default_value = json.dumps(None)
+ elif param_type == "JSON":
+ if not isinstance(default_value, str):
+ # once to serialize the default value if needed.
+ default_value = json.dumps(default_value)
+ # adds outer quotes to param
  default_value = json.dumps(default_value)
+ else:
+ # Make argo sensors happy
+ default_value = json.dumps(default_value)
+
  parameters[param.name] = dict(
+ python_var_name=var,
  name=param.name,
  value=default_value,
  type=param_type,
  description=param.kwargs.get("help"),
  is_required=is_required,
+ **extra_attrs,
+ )
+ return parameters
+
+ def _process_config_parameters(self):
+ parameters = []
+ seen = set()
+ for var, param in self.flow._get_parameters():
+ if not param.IS_CONFIG_PARAMETER:
+ continue
+ # Throw an exception if the parameter is specified twice.
+ norm = param.name.lower()
+ if norm in seen:
+ raise MetaflowException(
+ "Parameter *%s* is specified twice. "
+ "Note that parameter names are "
+ "case-insensitive." % param.name
+ )
+ seen.add(norm)
+
+ parameters.append(
+ dict(name=param.name, kv_name=ConfigInput.make_key_name(param.name))
  )
  return parameters

@@ -497,10 +684,17 @@ class ArgoWorkflows(object):
  # convert them to lower case since Metaflow parameters are case
  # insensitive.
  seen = set()
+ # NOTE: We skip config parameters as their values can not be set through event payloads
  params = set(
- [param.name.lower() for var, param in self.flow._get_parameters()]
+ [
+ param.name.lower()
+ for var, param in self.flow._get_parameters()
+ if not param.IS_CONFIG_PARAMETER
+ ]
  )
- for event in self.flow._flow_decorators.get("trigger")[0].triggers:
+ trigger_deco = self.flow._flow_decorators.get("trigger")[0]
+ trigger_deco.format_deploytime_value()
+ for event in trigger_deco.triggers:
  parameters = {}
  # TODO: Add a check to guard against names starting with numerals(?)
  if not re.match(r"^[A-Za-z0-9_.-]+$", event["name"]):
@@ -540,11 +734,23 @@ class ArgoWorkflows(object):

  # @trigger_on_finish decorator
  if self.flow._flow_decorators.get("trigger_on_finish"):
- for event in self.flow._flow_decorators.get("trigger_on_finish")[
- 0
- ].triggers:
+ trigger_on_finish_deco = self.flow._flow_decorators.get(
+ "trigger_on_finish"
+ )[0]
+ trigger_on_finish_deco.format_deploytime_value()
+ for event in trigger_on_finish_deco.triggers:
  # Actual filters are deduced here since we don't have access to
  # the current object in the @trigger_on_finish decorator.
+ project_name = event.get("project") or current.get("project_name")
+ branch_name = event.get("branch") or current.get("branch_name")
+ # validate that we have complete project info for an event name
+ if project_name or branch_name:
+ if not (project_name and branch_name):
+ # if one of the two is missing, we would end up listening to an event that will never be broadcast.
+ raise ArgoWorkflowsException(
+ "Incomplete project info. Please specify both 'project' and 'project_branch' or use the @project decorator"
+ )
+
  triggers.append(
  {
  # Make sure this remains consistent with the event name format
@@ -553,18 +759,16 @@ class ArgoWorkflows(object):
  % ".".join(
  v
  for v in [
- event.get("project") or current.get("project_name"),
- event.get("branch") or current.get("branch_name"),
+ project_name,
+ branch_name,
  event["flow"],
  ]
  if v
  ),
  "filters": {
  "auto-generated-by-metaflow": True,
- "project_name": event.get("project")
- or current.get("project_name"),
- "branch_name": event.get("branch")
- or current.get("branch_name"),
+ "project_name": project_name,
+ "branch_name": branch_name,
  # TODO: Add a time filters to guard against cached events
  },
  "type": "run",
@@ -616,30 +820,19 @@ class ArgoWorkflows(object):
  # generate container templates at the top level (in WorkflowSpec) and maintain
  # references to them within the DAGTask.

- from datetime import datetime, timezone
+ annotations = {}

- annotations = {
- "metaflow/production_token": self.production_token,
- "metaflow/owner": self.username,
- "metaflow/user": "argo-workflows",
- "metaflow/flow_name": self.flow.name,
- "metaflow/deployment_timestamp": str(
- datetime.now(timezone.utc).isoformat()
- ),
- }
+ if self._schedule is not None:
+ # timezone is an optional field and json dumps on None will result in null
+ # hence configuring it to an empty string
+ if self._timezone is None:
+ self._timezone = ""
+ cron_info = {"schedule": self._schedule, "tz": self._timezone}
+ annotations.update({"metaflow/cron": json.dumps(cron_info)})

  if self.parameters:
  annotations.update({"metaflow/parameters": json.dumps(self.parameters)})

- if current.get("project_name"):
- annotations.update(
- {
- "metaflow/project_name": current.project_name,
- "metaflow/branch_name": current.branch_name,
- "metaflow/project_flow_name": current.project_flow_name,
- }
- )
-
  # Some more annotations to populate the Argo UI nicely
  if self.tags:
  annotations.update({"metaflow/tags": json.dumps(self.tags)})
@@ -651,7 +844,9 @@ class ArgoWorkflows(object):
  {key: trigger.get(key) for key in ["name", "type"]}
  for trigger in self.triggers
- )
+ ),
+ "metaflow/sensor_name": ArgoWorkflows._sensor_name(self.name),
+ "metaflow/sensor_namespace": ARGO_EVENTS_SENSOR_NAMESPACE,
  }
  )
  if self.notify_on_error:
@@ -661,6 +856,7 @@ class ArgoWorkflows(object):
  {
  "slack": bool(self.notify_slack_webhook_url),
  "pager_duty": bool(self.notify_pager_duty_integration_key),
+ "incident_io": bool(self.notify_incident_io_api_key),
  }
  )
  }
@@ -672,11 +868,24 @@ class ArgoWorkflows(object):
  {
  "slack": bool(self.notify_slack_webhook_url),
  "pager_duty": bool(self.notify_pager_duty_integration_key),
+ "incident_io": bool(self.notify_incident_io_api_key),
  }
  )
  }
  )
+ try:
+ # Build the DAG based on the DAGNodes given by the FlowGraph for the found FlowSpec class.
+ _steps_info, graph_structure = self.graph.output_steps()
+ graph_info = {
+ # for the time being, we only need the graph_structure. Being mindful of annotation size limits we do not include anything extra.
+ "graph_structure": graph_structure
+ }
+ except Exception:
+ graph_info = None
+
+ dag_annotation = {"metaflow/dag": json.dumps(graph_info)}

+ lifecycle_hooks = self._lifecycle_hooks()
  return (
  WorkflowTemplate()
  .metadata(
@@ -687,9 +896,11 @@ class ArgoWorkflows(object):
  # is released, we should be able to support multi-namespace /
  # multi-cluster scheduling.
  .namespace(KUBERNETES_NAMESPACE)
- .label("app.kubernetes.io/name", "metaflow-flow")
- .label("app.kubernetes.io/part-of", "metaflow")
  .annotations(annotations)
+ .annotations(self._base_annotations)
+ .labels(self._base_labels)
+ .label("app.kubernetes.io/name", "metaflow-flow")
+ .annotations(dag_annotation)
  )
  .spec(
  WorkflowSpec()
@@ -719,10 +930,23 @@ class ArgoWorkflows(object):
  # Set workflow metadata
  .workflow_metadata(
  Metadata()
+ .labels(self._base_labels)
  .label("app.kubernetes.io/name", "metaflow-run")
- .label("app.kubernetes.io/part-of", "metaflow")
  .annotations(
- {**annotations, **{"metaflow/run_id": "argo-{{workflow.name}}"}}
+ {
+ **annotations,
+ **{
+ k: v
+ for k, v in self._base_annotations.items()
+ if k
+ # Skip custom title/description for workflows as this makes it harder to find specific runs.
+ not in [
+ "workflows.argoproj.io/title",
+ "workflows.argoproj.io/description",
+ ]
+ },
+ **{"metaflow/run_id": "argo-{{workflow.name}}"},
+ }
  )
  # TODO: Set dynamic labels using labels_from. Ideally, we would
  # want to expose run_id as a label. It's easy to add labels,
@@ -755,95 +979,251 @@ class ArgoWorkflows(object):
  # Set common pod metadata.
  .pod_metadata(
  Metadata()
+ .labels(self._base_labels)
  .label("app.kubernetes.io/name", "metaflow-task")
- .label("app.kubernetes.io/part-of", "metaflow")
- .annotations(annotations)
- .labels(self.kubernetes_labels)
+ .annotations(
+ {
+ **annotations,
+ **self._base_annotations,
+ **{
+ "metaflow/run_id": "argo-{{workflow.name}}"
+ }, # we want pods of the workflow to have the run_id as an annotation as well
+ }
+ )
  )
  # Set the entrypoint to flow name
  .entrypoint(self.flow.name)
- # Set exit hook handlers if notifications are enabled
+ # OnExit hooks
+ .onExit(
+ "capture-error-hook-fn-preflight"
+ if self.enable_error_msg_capture
+ else None
+ )
+ # Set lifecycle hooks if notifications are enabled
  .hooks(
  {
- **(
- {
- # workflow status maps to Completed
- "notify-slack-on-success": LifecycleHook()
- .expression("workflow.status == 'Succeeded'")
- .template("notify-slack-on-success"),
- }
- if self.notify_on_success and self.notify_slack_webhook_url
- else {}
- ),
- **(
- {
- # workflow status maps to Completed
- "notify-pager-duty-on-success": LifecycleHook()
- .expression("workflow.status == 'Succeeded'")
- .template("notify-pager-duty-on-success"),
- }
- if self.notify_on_success
- and self.notify_pager_duty_integration_key
- else {}
- ),
- **(
- {
- # workflow status maps to Failed or Error
- "notify-slack-on-failure": LifecycleHook()
- .expression("workflow.status == 'Failed'")
- .template("notify-slack-on-error"),
- "notify-slack-on-error": LifecycleHook()
- .expression("workflow.status == 'Error'")
- .template("notify-slack-on-error"),
- }
- if self.notify_on_error and self.notify_slack_webhook_url
- else {}
- ),
- **(
- {
- # workflow status maps to Failed or Error
- "notify-pager-duty-on-failure": LifecycleHook()
- .expression("workflow.status == 'Failed'")
- .template("notify-pager-duty-on-error"),
- "notify-pager-duty-on-error": LifecycleHook()
- .expression("workflow.status == 'Error'")
- .template("notify-pager-duty-on-error"),
- }
- if self.notify_on_error
- and self.notify_pager_duty_integration_key
- else {}
- ),
- # Warning: terrible hack to workaround a bug in Argo Workflow
- # where the hooks listed above do not execute unless
- # there is an explicit exit hook. as and when this
- # bug is patched, we should remove this effectively
- # no-op hook.
- **(
- {"exit": LifecycleHook().template("exit-hook-hack")}
- if self.notify_on_error or self.notify_on_success
- else {}
- ),
+ lifecycle.name: lifecycle
+ for hook in lifecycle_hooks
+ for lifecycle in hook.lifecycle_hooks
  }
  )
  # Top-level DAG template(s)
  .templates(self._dag_templates())
  # Container templates
  .templates(self._container_templates())
+ # Lifecycle hook template(s)
+ .templates([hook.template for hook in lifecycle_hooks])
  # Exit hook template(s)
  .templates(self._exit_hook_templates())
+ # Sidecar templates (Daemon Containers)
+ .templates(self._daemon_templates())
+ )
+ )
+
+ # Visit every node and record information on conditional step structure
+ def _parse_conditional_branches(self):
+ self.conditional_nodes = set()
+ self.conditional_join_nodes = set()
+ self.matching_conditional_join_dict = {}
+ self.recursive_nodes = set()
+
+ node_conditional_parents = {}
+ node_conditional_branches = {}
+
+ def _visit(node, conditional_branch, conditional_parents=None):
+ if not node.type == "split-switch" and not (
+ conditional_branch and conditional_parents
+ ):
+ # skip regular non-conditional nodes entirely
+ return
+
+ if node.type == "split-switch":
+ conditional_branch = conditional_branch + [node.name]
+ c_br = node_conditional_branches.get(node.name, [])
+ node_conditional_branches[node.name] = c_br + [
+ b for b in conditional_branch if b not in c_br
+ ]
+
+ conditional_parents = (
+ [node.name]
+ if not conditional_parents
+ else conditional_parents + [node.name]
+ )
+ node_conditional_parents[node.name] = conditional_parents
+
+ # check for recursion. this split is recursive if any of its out functions are itself.
+ if any(
+ out_func for out_func in node.out_funcs if out_func == node.name
+ ):
+ self.recursive_nodes.add(node.name)
+
+ if conditional_parents and not node.type == "split-switch":
+ node_conditional_parents[node.name] = conditional_parents
+ conditional_branch = conditional_branch + [node.name]
+ c_br = node_conditional_branches.get(node.name, [])
+ node_conditional_branches[node.name] = c_br + [
+ b for b in conditional_branch if b not in c_br
+ ]
+
+ self.conditional_nodes.add(node.name)
+
+ if conditional_branch and conditional_parents:
+ for n in node.out_funcs:
+ child = self.graph[n]
+ if child.name == node.name:
+ continue
+ _visit(child, conditional_branch, conditional_parents)
+
+ # First we visit all nodes to determine conditional parents and branches
+ for n in self.graph:
+ _visit(n, [])
+
+ # helper to clean up conditional info for all children of a node, until a new split-switch is encountered.
+ def _cleanup_conditional_status(node_name, seen):
+ if self.graph[node_name].type == "split-switch":
+ # stop recursive cleanup if we hit a new split-switch
+ return
+ if node_name in self.conditional_nodes:
+ self.conditional_nodes.remove(node_name)
+ node_conditional_parents[node_name] = []
+ node_conditional_branches[node_name] = []
+ for p in self.graph[node_name].out_funcs:
+ if p not in seen:
+ _cleanup_conditional_status(p, seen + [p])
+
+ # Then we traverse again in order to determine conditional join nodes, and matching conditional join info
+ for node in self.graph:
+ if node_conditional_parents.get(node.name, False):
+ # do the required postprocessing for anything requiring node.in_funcs
+
+ # check that in previous parsing we have not closed all conditional in_funcs.
+ # If so, this step can not be conditional either
+ is_conditional = any(
+ in_func in self.conditional_nodes
+ or self.graph[in_func].type == "split-switch"
+ for in_func in node.in_funcs
+ )
+ if is_conditional:
+ self.conditional_nodes.add(node.name)
+ else:
+ if node.name in self.conditional_nodes:
+ self.conditional_nodes.remove(node.name)
+
+ # does this node close the latest conditional parent branches?
+ conditional_in_funcs = [
+ in_func
+ for in_func in node.in_funcs
+ if node_conditional_branches.get(in_func, False)
+ ]
+ closed_conditional_parents = []
+ for last_split_switch in node_conditional_parents.get(node.name, [])[
+ ::-1
+ ]:
+ last_conditional_split_nodes = self.graph[
+ last_split_switch
+ ].out_funcs
+ # NOTE: How do we define a conditional join step?
+ # The idea here is that we check if the conditional branches(e.g. chains of conditional steps leading to) of all the in_funcs
+ # manage to tick off every step name that follows a split-switch
+ # For example, consider the following structure
+ # switch_step -> A, B, C
+ # A -> A2 -> A3 -> A4 -> B2
+ # B -> B2 -> B3 -> C3
+ # C -> C2 -> C3 -> end
+ #
+ # if we look at the in_funcs for C3, they are (C2, B3)
+ # B3 closes off branches started by A and B
+ # C3 closes off branches started by C
+ # therefore C3 is a conditional join step for the 'switch_step'
+ # NOTE: Then what about a skip step?
+ # some switch cases might not introduce any distinct steps of their own, opting to instead skip ahead to a later common step.
+ # Example:
+ # switch_step -> A, B, C
+ # A -> A1 -> B2 -> C
+ # B -> B1 -> B2 -> C
+ #
+ # In this case, C is a skip step as it does not add any conditional branching of its own.
+ # C is also a conditional join, as it closes all branches started by 'switch_step'
+
+ closes_branches = all(
+ (
+ # branch_root_node_name needs to be in at least one conditional_branch for it to be closed.
+ any(
+ branch_root_node_name
+ in node_conditional_branches.get(in_func, [])
+ for in_func in conditional_in_funcs
+ )
+ # need to account for a switch case skipping completely, not having a conditional-branch of its own.
+ if branch_root_node_name != node.name
+ else True
+ )
+ for branch_root_node_name in last_conditional_split_nodes
+ )
+ if closes_branches:
+ closed_conditional_parents.append(last_split_switch)
+
+ self.conditional_join_nodes.add(node.name)
+ self.matching_conditional_join_dict[last_split_switch] = (
+ node.name
+ )
+
+ # Did we close all conditionals? Then this branch and all its children are not conditional anymore (unless a new conditional branch is encountered).
+ if not [
+ p
+ for p in node_conditional_parents.get(node.name, [])
+ if p not in closed_conditional_parents
+ ]:
+ _cleanup_conditional_status(node.name, [])
+
+ def _is_conditional_node(self, node):
+ return node.name in self.conditional_nodes
+
+ def _is_conditional_skip_node(self, node):
+ return (
+ self._is_conditional_node(node)
+ and any(
+ self.graph[in_func].type == "split-switch" for in_func in node.in_funcs
+ )
+ and len(
+ [
+ in_func
+ for in_func in node.in_funcs
+ if self._is_conditional_node(self.graph[in_func])
+ or self.graph[in_func].type == "split-switch"
+ ]
  )
+ > 1
  )

+ def _is_conditional_join_node(self, node):
+ return node.name in self.conditional_join_nodes
+
+ def _many_in_funcs_all_conditional(self, node):
+ cond_in_funcs = [
+ in_func
+ for in_func in node.in_funcs
+ if self._is_conditional_node(self.graph[in_func])
+ ]
+ return len(cond_in_funcs) > 1 and len(cond_in_funcs) == len(node.in_funcs)
+
+ def _is_recursive_node(self, node):
+ return node.name in self.recursive_nodes
+
+ def _matching_conditional_join(self, node):
+ # If no earlier conditional join step is found during parsing, then 'end' is always one.
+ return self.matching_conditional_join_dict.get(node.name, "end")
+
  # Visit every node and yield the uber DAGTemplate(s).
  def _dag_templates(self):
  def _visit(
- node, exit_node=None, templates=None, dag_tasks=None, parent_foreach=None
- ):
- if node.parallel_foreach:
- raise ArgoWorkflowsException(
- "Deploying flows with @parallel decorator(s) "
- "as Argo Workflows is not supported currently."
- )
+ node,
+ exit_node=None,
+ templates=None,
+ dag_tasks=None,
+ parent_foreach=None,
+ seen=None,
+ ): # Returns Tuple[List[Template], List[DAGTask]]
+ """ """
  # Every for-each node results in a separate subDAG and an equivalent
  # DAGTemplate rooted at the child of the for-each node. Each DAGTemplate
  # has a unique name - the top-level DAGTemplate is named as the name of
@@ -851,28 +1231,111 @@ class ArgoWorkflows(object):
851
1231
  # of the for-each node.
852
1232
 
853
1233
  # Emit if we have reached the end of the sub workflow
1234
+ if seen is None:
1235
+ seen = []
854
1236
  if dag_tasks is None:
855
1237
  dag_tasks = []
856
1238
  if templates is None:
857
1239
  templates = []
1240
+
858
1241
  if exit_node is not None and exit_node is node.name:
859
1242
  return templates, dag_tasks
1243
+ if node.name in seen:
1244
+ return templates, dag_tasks
1245
+
1246
+ seen.append(node.name)
860
1247
 
1248
+ # helper variable for recursive conditional inputs
1249
+ has_foreach_inputs = False
861
1250
  if node.name == "start":
862
1251
  # Start node has no dependencies.
863
1252
  dag_task = DAGTask(self._sanitize(node.name)).template(
864
1253
  self._sanitize(node.name)
865
1254
  )
866
- elif (
1255
+ if (
867
1256
  node.is_inside_foreach
868
1257
  and self.graph[node.in_funcs[0]].type == "foreach"
1258
+ and not self.graph[node.in_funcs[0]].parallel_foreach
1259
+ # We need to distinguish what is a "regular" foreach (i.e something that doesn't care about to gang semantics)
1260
+ # vs what is a "num_parallel" based foreach (i.e. something that follows gang semantics.)
1261
+ # A `regular` foreach is basically any arbitrary kind of foreach.
869
1262
  ):
1263
+ # helper variable for recursive conditional inputs
1264
+ has_foreach_inputs = True
870
1265
  # Child of a foreach node needs input-paths as well as split-index
871
1266
  # This child is the first node of the sub workflow and has no dependency
872
1267
  parameters = [
873
1268
  Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
874
1269
  Parameter("split-index").value("{{inputs.parameters.split-index}}"),
875
1270
  ]
1271
+ dag_task = (
1272
+ DAGTask(self._sanitize(node.name))
1273
+ .template(self._sanitize(node.name))
1274
+ .arguments(Arguments().parameters(parameters))
1275
+ )
1276
+ elif node.parallel_step:
1277
+ # This is the step where the @parallel decorator is defined.
1278
+ # Since this DAGTask will call the `resource` [based templates]
1279
+ # (https://argo-workflows.readthedocs.io/en/stable/walk-through/kubernetes-resources/)
1280
+ # we have certain constraints on the way we can pass information inside the Jobset manifest
1281
+ # [All templates will have access](https://argo-workflows.readthedocs.io/en/stable/variables/#all-templates)
1282
+ # to the `inputs.parameters` so we will pass down ANY/ALL information using the
1283
+ # input parameters.
1284
+ # We define the usual parameters like input-paths/split-index etc. but we will also
1285
+ # define the following:
1286
+ # - `workerCount`: parameter which will be used to determine the number of
1287
+ # parallel worker jobs
1288
+ # - `jobset-name`: parameter which will be used to determine the name of the jobset.
1289
+ # This parameter needs to be dynamic so that when we have retries we don't
1290
+ # end up using the name of the jobset again (if we do, it will crash since k8s won't allow duplicate job names)
1291
+ # - `retryCount`: parameter which will be used to determine the number of retries
1292
+ # This parameter will *only* be available within the container templates like we
1293
+ # have it for all other DAGTasks and NOT for custom kubernetes resource templates.
1294
+ # So as a work-around, we will set it as the `retryCount` parameter instead of
1295
+ # setting it as a {{ retries }} in the CLI code. Once set as a input parameter,
1296
+ # we can use it in the Jobset Manifest templates as `{{inputs.parameters.retryCount}}`
1297
+ # - `task-id-entropy`: This is a parameter which will help derive task-ids and jobset names. This parameter
1298
+ # contains the relevant amount of entropy to ensure that task-ids and jobset names
1299
+ # are sufficiently unique. We will also use this in the join task to construct the task-ids of
1300
+ # all parallel tasks since the task-ids for parallel tasks are minted formulaically.
1301
+ parameters = [
1302
+ Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
1303
+ Parameter("num-parallel").value(
1304
+ "{{inputs.parameters.num-parallel}}"
1305
+ ),
1306
+ Parameter("split-index").value("{{inputs.parameters.split-index}}"),
1307
+ Parameter("task-id-entropy").value(
1308
+ "{{inputs.parameters.task-id-entropy}}"
1309
+ ),
1310
+ # We can't just use hyphens with sprig.
1311
+ # https://github.com/argoproj/argo-workflows/issues/10567#issuecomment-1452410948
1312
+ Parameter("workerCount").value(
1313
+ "{{=sprig.int(sprig.sub(sprig.int(inputs.parameters['num-parallel']),1))}}"
1314
+ ),
1315
+ ]
1316
+ if any(d.name == "retry" for d in node.decorators):
1317
+ parameters.extend(
1318
+ [
1319
+ Parameter("retryCount").value("{{retries}}"),
1320
+ # The jobset name needs to be unique for each retry
1321
+ # and we cannot use the `generateName` field in the
1322
+ # Jobset Manifest since we need to construct the subdomain
1323
+ # and control pod domain name beforehand. So we will use
1324
+ # the retry count to ensure that the jobset name is unique
1325
+ Parameter("jobset-name").value(
1326
+ "js-{{inputs.parameters.task-id-entropy}}{{retries}}",
1327
+ ),
1328
+ ]
1329
+ )
1330
+ else:
1331
+ parameters.extend(
1332
+ [
1333
+ Parameter("jobset-name").value(
1334
+ "js-{{inputs.parameters.task-id-entropy}}",
1335
+ )
1336
+ ]
1337
+ )
1338
+
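
As a sanity check on the sprig expression above, here is a small sketch (hypothetical values only) of how the @parallel parameters resolve; `workerCount` is `num-parallel - 1` because the control task occupies one replica, with the remaining replicas assigned to workers (see the `jobset.control.replicas(1)` / `jobset.worker.replicas(...)` wiring further below).

```python
# Illustrative only: what the @parallel input parameters resolve to for a
# hypothetical num_parallel=4 step on its first attempt.
num_parallel = 4
params = {
    "num-parallel": num_parallel,
    "task-id-entropy": "a1b2c3",        # hypothetical entropy value
    "workerCount": num_parallel - 1,    # sprig.sub(num-parallel, 1) -> 3 workers + 1 control
    "jobset-name": "js-a1b2c3",         # "js-{{task-id-entropy}}"; a retry suffix is appended when @retry is set
}
print(params)
```
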
876
1339
  dag_task = (
877
1340
  DAGTask(self._sanitize(node.name))
878
1341
  .template(self._sanitize(node.name))
@@ -887,7 +1350,9 @@ class ArgoWorkflows(object):
887
1350
  "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
888
1351
  % (n, self._sanitize(n))
889
1352
  for n in node.in_funcs
890
- ]
1353
+ ],
1354
+ # NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution.
1355
+ zlibmin=inf,
891
1356
  )
892
1357
  )
893
1358
  ]
@@ -922,23 +1387,89 @@ class ArgoWorkflows(object):
922
1387
  ]
923
1388
  )
924
1389
 
1390
+ conditional_deps = [
1391
+ "%s.Succeeded" % self._sanitize(in_func)
1392
+ for in_func in node.in_funcs
1393
+ if self._is_conditional_node(self.graph[in_func])
1394
+ or self.graph[in_func].type == "split-switch"
1395
+ ]
1396
+ required_deps = [
1397
+ "%s.Succeeded" % self._sanitize(in_func)
1398
+ for in_func in node.in_funcs
1399
+ if not self._is_conditional_node(self.graph[in_func])
1400
+ and self.graph[in_func].type != "split-switch"
1401
+ ]
1402
+ if self._is_conditional_skip_node(
1403
+ node
1404
+ ) or self._many_in_funcs_all_conditional(node):
1405
+ # skip nodes need unique condition handling
1406
+ conditional_deps = [
1407
+ "%s.Succeeded" % self._sanitize(in_func)
1408
+ for in_func in node.in_funcs
1409
+ ]
1410
+ required_deps = []
1411
+
1412
+ both_conditions = required_deps and conditional_deps
1413
+
1414
+ depends_str = "{required}{_and}{conditional}".format(
1415
+ required=("(%s)" if both_conditions else "%s")
1416
+ % " && ".join(required_deps),
1417
+ _and=" && " if both_conditions else "",
1418
+ conditional=("(%s)" if both_conditions else "%s")
1419
+ % " || ".join(conditional_deps),
1420
+ )
925
1421
  dag_task = (
926
1422
  DAGTask(self._sanitize(node.name))
927
- .dependencies(
928
- [self._sanitize(in_func) for in_func in node.in_funcs]
929
- )
1423
+ .depends(depends_str)
930
1424
  .template(self._sanitize(node.name))
931
1425
  .arguments(Arguments().parameters(parameters))
932
1426
  )
933
- dag_tasks.append(dag_task)
934
1427
 
1428
+ # Add conditional if this is the first step in a conditional branch
1429
+ switch_in_funcs = [
1430
+ in_func
1431
+ for in_func in node.in_funcs
1432
+ if self.graph[in_func].type == "split-switch"
1433
+ ]
1434
+ if (
1435
+ self._is_conditional_node(node)
1436
+ or self._is_conditional_skip_node(node)
1437
+ or self._is_conditional_join_node(node)
1438
+ ) and switch_in_funcs:
1439
+ conditional_when = "||".join(
1440
+ [
1441
+ "{{tasks.%s.outputs.parameters.switch-step}}==%s"
1442
+ % (self._sanitize(switch_in_func), node.name)
1443
+ for switch_in_func in switch_in_funcs
1444
+ ]
1445
+ )
1446
+
1447
+ non_switch_in_funcs = [
1448
+ in_func
1449
+ for in_func in node.in_funcs
1450
+ if in_func not in switch_in_funcs
1451
+ ]
1452
+ status_when = ""
1453
+ if non_switch_in_funcs:
1454
+ status_when = "||".join(
1455
+ [
1456
+ "{{tasks.%s.status}}==Succeeded"
1457
+ % self._sanitize(in_func)
1458
+ for in_func in non_switch_in_funcs
1459
+ ]
1460
+ )
1461
+
1462
+ total_when = (
1463
+ f"({status_when}) || ({conditional_when})"
1464
+ if status_when
1465
+ else conditional_when
1466
+ )
1467
+ dag_task.when(total_when)
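
For orientation, a hedged sketch of the `when` guard built above, assuming a hypothetical conditional step `b` whose parents are a split-switch step `a` and a regular step `c`:

```python
# Illustrative only: the guard assembled above for a hypothetical node "b" with
# a split-switch parent "a" and a non-switch parent "c".
conditional_when = "{{tasks.a.outputs.parameters.switch-step}}==b"
status_when = "{{tasks.c.status}}==Succeeded"
total_when = (
    f"({status_when}) || ({conditional_when})" if status_when else conditional_when
)
print(total_when)
# ({{tasks.c.status}}==Succeeded) || ({{tasks.a.outputs.parameters.switch-step}}==b)
```
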
1468
+
1469
+ dag_tasks.append(dag_task)
935
1470
  # End the workflow if we have reached the end of the flow
936
1471
  if node.type == "end":
937
- return [
938
- Template(self.flow.name).dag(
939
- DAGTemplate().fail_fast().tasks(dag_tasks)
940
- )
941
- ] + templates, dag_tasks
1472
+ return templates, dag_tasks
942
1473
  # For split nodes traverse all the children
943
1474
  if node.type == "split":
944
1475
  for n in node.out_funcs:
@@ -948,6 +1479,7 @@ class ArgoWorkflows(object):
948
1479
  templates,
949
1480
  dag_tasks,
950
1481
  parent_foreach,
1482
+ seen,
951
1483
  )
952
1484
  return _visit(
953
1485
  self.graph[node.matching_join],
@@ -955,46 +1487,201 @@ class ArgoWorkflows(object):
955
1487
  templates,
956
1488
  dag_tasks,
957
1489
  parent_foreach,
1490
+ seen,
958
1491
  )
959
- # For foreach nodes generate a new sub DAGTemplate
960
- elif node.type == "foreach":
961
- foreach_template_name = self._sanitize(
962
- "%s-foreach-%s"
963
- % (
964
- node.name,
965
- node.foreach_param,
966
- )
967
- )
968
- foreach_task = (
969
- DAGTask(foreach_template_name)
970
- .dependencies([self._sanitize(node.name)])
971
- .template(foreach_template_name)
972
- .arguments(
973
- Arguments().parameters(
1492
+ elif node.type == "split-switch":
1493
+ if self._is_recursive_node(node):
1494
+ # we need an additional recursive template if the step is recursive
1495
+ # NOTE: in the recursive case, the original step is renamed in the container templates to 'recursive-<step_name>'
1496
+ # so that we do not have to touch the step references in the DAG.
1497
+ #
1498
+ # NOTE: The way that recursion in Argo Workflows is achieved is with the following structure:
1499
+ # - the usual 'example-step' template which would match example_step in flow code is renamed to 'recursive-example-step'
1500
+ # - templates has another template with the original task name: 'example-step'
1501
+ # - the template 'example-step' in turn has steps
1502
+ # - 'example-step-internal' which uses the metaflow step executing template 'recursive-example-step'
1503
+ # - 'example-step-recursion' which calls the parent template 'example-step' if switch-step output from 'example-step-internal' matches the condition.
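
To make that layout concrete, here is a hedged sketch of the structure generated for a hypothetical recursive switch step `example_step` (sanitized to `example-step`). It shows only the shape of the resulting template, not the builder calls used below:

```python
# Illustrative only: rough shape of the recursive template described above for a
# hypothetical step written as `example_step` in the flow code.
recursive_template = {
    "name": "example-step",
    "steps": [
        # 1) run the actual Metaflow step via the renamed container template
        [{"name": "example-step-internal", "template": "recursive-example-step"}],
        # 2) re-enter this template while the step keeps switching back to itself;
        #    the comparison uses the unsanitized step name from the flow code
        [{
            "name": "example-step-recursion",
            "template": "example-step",
            "when": "{{steps.example-step-internal.outputs.parameters.switch-step}}==example_step",
        }],
    ],
}
```
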
1504
+ sanitized_name = self._sanitize(node.name)
1505
+ templates.append(
1506
+ Template(sanitized_name)
1507
+ .steps(
974
1508
  [
975
- Parameter("input-paths").value(
976
- "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
977
- % (node.name, self._sanitize(node.name))
978
- ),
979
- Parameter("split-index").value("{{item}}"),
1509
+ WorkflowStep()
1510
+ .name("%s-internal" % sanitized_name)
1511
+ .template("recursive-%s" % sanitized_name)
1512
+ .arguments(
1513
+ Arguments().parameters(
1514
+ [
1515
+ Parameter("input-paths").value(
1516
+ "{{inputs.parameters.input-paths}}"
1517
+ )
1518
+ ]
1519
+ # Add the additional inputs required by specific node types.
1520
+ # We do not need to cover joins or @parallel, as a split-switch step can not be either one of these.
1521
+ + (
1522
+ [
1523
+ Parameter("split-index").value(
1524
+ "{{inputs.parameters.split-index}}"
1525
+ )
1526
+ ]
1527
+ if has_foreach_inputs
1528
+ else []
1529
+ )
1530
+ )
1531
+ )
980
1532
  ]
981
- + (
982
- [
983
- Parameter("root-input-path").value(
984
- "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
1533
+ )
1534
+ .steps(
1535
+ [
1536
+ WorkflowStep()
1537
+ .name("%s-recursion" % sanitized_name)
1538
+ .template(sanitized_name)
1539
+ .when(
1540
+ "{{steps.%s-internal.outputs.parameters.switch-step}}==%s"
1541
+ % (sanitized_name, node.name)
1542
+ )
1543
+ .arguments(
1544
+ Arguments().parameters(
1545
+ [
1546
+ Parameter("input-paths").value(
1547
+ "argo-{{workflow.name}}/%s/{{steps.%s-internal.outputs.parameters.task-id}}"
1548
+ % (node.name, sanitized_name)
1549
+ )
1550
+ ]
1551
+ + (
1552
+ [
1553
+ Parameter("split-index").value(
1554
+ "{{inputs.parameters.split-index}}"
1555
+ )
1556
+ ]
1557
+ if has_foreach_inputs
1558
+ else []
1559
+ )
1560
+ )
1561
+ ),
1562
+ ]
1563
+ )
1564
+ .inputs(Inputs().parameters(parameters))
1565
+ .outputs(
1566
+ # NOTE: We try to read the output parameters from the recursive template call first (<step>-recursion), and the internal step second (<step>-internal).
1567
+ # This guarantees that we always get the output parameters of the last recursive step that executed.
1568
+ Outputs().parameters(
1569
+ [
1570
+ Parameter("task-id").valueFrom(
1571
+ {
1572
+ "expression": "(steps['%s-recursion']?.outputs ?? steps['%s-internal']?.outputs).parameters['task-id']"
1573
+ % (sanitized_name, sanitized_name)
1574
+ }
1575
+ ),
1576
+ Parameter("switch-step").valueFrom(
1577
+ {
1578
+ "expression": "(steps['%s-recursion']?.outputs ?? steps['%s-internal']?.outputs).parameters['switch-step']"
1579
+ % (sanitized_name, sanitized_name)
1580
+ }
1581
+ ),
1582
+ ]
1583
+ )
1584
+ )
1585
+ )
1586
+ for n in node.out_funcs:
1587
+ _visit(
1588
+ self.graph[n],
1589
+ self._matching_conditional_join(node),
1590
+ templates,
1591
+ dag_tasks,
1592
+ parent_foreach,
1593
+ seen,
1594
+ )
1595
+ return _visit(
1596
+ self.graph[self._matching_conditional_join(node)],
1597
+ exit_node,
1598
+ templates,
1599
+ dag_tasks,
1600
+ parent_foreach,
1601
+ seen,
1602
+ )
1603
+ # For foreach nodes generate a new sub DAGTemplate
1604
+ # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
1605
+ elif node.type == "foreach":
1606
+ foreach_template_name = self._sanitize(
1607
+ "%s-foreach-%s"
1608
+ % (
1609
+ node.name,
1610
+ "parallel" if node.parallel_foreach else node.foreach_param,
1611
+ # Since foreach's are derived based on `self.next(self.a, foreach="<varname>")`
1612
+ # vs @parallel foreach are done based on `self.next(self.a, num_parallel="<some-number>")`,
1613
+ # we need to ensure that `foreach_template_name` suffix is appropriately set based on the kind
1614
+ # of foreach.
1615
+ )
1616
+ )
1617
+
1618
+ # There are two separate "DAGTask"s created for the foreach node.
1619
+ # - The first one is a "jump-off" DAGTask where we propagate the
1620
+ # input-paths and split-index. This thing doesn't create
1621
+ # any actual containers and is responsible only for propagating
1622
+ # the parameters.
1623
+ # - The DAGTask that follows the first DAGTask is the one
1624
+ # that uses the ContainerTemplate. This DAGTask is named the same
1625
+ # thing as the foreach node. We will leverage a similar pattern for the
1626
+ # @parallel tasks.
1627
+ #
1628
+ foreach_task = (
1629
+ DAGTask(foreach_template_name)
1630
+ .depends(f"{self._sanitize(node.name)}.Succeeded")
1631
+ .template(foreach_template_name)
1632
+ .arguments(
1633
+ Arguments().parameters(
1634
+ [
1635
+ Parameter("input-paths").value(
1636
+ "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
1637
+ % (node.name, self._sanitize(node.name))
1638
+ ),
1639
+ Parameter("split-index").value("{{item}}"),
1640
+ ]
1641
+ + (
1642
+ [
1643
+ Parameter("root-input-path").value(
1644
+ "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
985
1645
  % (node.name, self._sanitize(node.name))
986
1646
  ),
987
1647
  ]
988
1648
  if parent_foreach
989
1649
  else []
990
1650
  )
1651
+ + (
1652
+ # Disambiguate parameters for a regular `foreach` vs a `@parallel` foreach
1653
+ [
1654
+ Parameter("num-parallel").value(
1655
+ "{{tasks.%s.outputs.parameters.num-parallel}}"
1656
+ % self._sanitize(node.name)
1657
+ ),
1658
+ Parameter("task-id-entropy").value(
1659
+ "{{tasks.%s.outputs.parameters.task-id-entropy}}"
1660
+ % self._sanitize(node.name)
1661
+ ),
1662
+ ]
1663
+ if node.parallel_foreach
1664
+ else []
1665
+ )
991
1666
  )
992
1667
  )
993
1668
  .with_param(
1669
+ # For @parallel workloads `num-splits` will be explicitly set to one so that
1670
+ # we can piggyback on the current mechanism with which we leverage argo.
994
1671
  "{{tasks.%s.outputs.parameters.num-splits}}"
995
1672
  % self._sanitize(node.name)
996
1673
  )
997
1674
  )
1675
+ # Add conditional if this is the first step in a conditional branch
1676
+ if self._is_conditional_node(node) and not any(
1677
+ self._is_conditional_node(self.graph[in_func])
1678
+ for in_func in node.in_funcs
1679
+ ):
1680
+ in_func = node.in_funcs[0]
1681
+ foreach_task.when(
1682
+ "{{tasks.%s.outputs.parameters.switch-step}}==%s"
1683
+ % (self._sanitize(in_func), node.name)
1684
+ )
998
1685
  dag_tasks.append(foreach_task)
999
1686
  templates, dag_tasks_1 = _visit(
1000
1687
  self.graph[node.out_funcs[0]],
@@ -1002,18 +1689,36 @@ class ArgoWorkflows(object):
1002
1689
  templates,
1003
1690
  [],
1004
1691
  node.name,
1692
+ seen,
1005
1693
  )
1694
+
1695
+ # How do foreaches work on Argo:
1696
+ # Let's say you have the following DAG: (start[sets `foreach="x"`]) --> (task-a [actual foreach]) --> (join) --> (end)
1697
+ # With Argo we will have:
1698
+ # (start [sets num-splits]) --> (task-a-foreach-(0,0) [dummy task]) --> (task-a) --> (join) --> (end)
1699
+ # The (task-a-foreach-(0,0) [dummy task]) propagates the values of the `split-index` and the input paths
1700
+ # to the actual foreach task.
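
A hedged sketch of the jump-off DAGTask emitted above for the comment's example, where `start` is the foreach node splitting on `x`; the values mirror the construction above and the names are illustrative:

```python
# Illustrative only: rough shape of the "jump-off" DAGTask for a hypothetical
# non-@parallel, non-nested foreach where "start" splits on foreach="x".
jump_off_task = {
    "name": "start-foreach-x",
    "template": "start-foreach-x",  # sub-DAG template rooted at the foreach child
    "depends": "start.Succeeded",
    "withParam": "{{tasks.start.outputs.parameters.num-splits}}",
    "arguments": {
        "parameters": [
            {"name": "input-paths",
             "value": "argo-{{workflow.name}}/start/{{tasks.start.outputs.parameters.task-id}}"},
            {"name": "split-index", "value": "{{item}}"},
        ]
    },
}
```
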
1006
1701
  templates.append(
1007
1702
  Template(foreach_template_name)
1008
1703
  .inputs(
1009
1704
  Inputs().parameters(
1010
1705
  [Parameter("input-paths"), Parameter("split-index")]
1011
1706
  + ([Parameter("root-input-path")] if parent_foreach else [])
1707
+ + (
1708
+ [
1709
+ Parameter("num-parallel"),
1710
+ Parameter("task-id-entropy"),
1711
+ # Parameter("workerCount")
1712
+ ]
1713
+ if node.parallel_foreach
1714
+ else []
1715
+ )
1012
1716
  )
1013
1717
  )
1014
1718
  .outputs(
1015
1719
  Outputs().parameters(
1016
1720
  [
1721
+ # non @parallel tasks set task-ids as outputs
1017
1722
  Parameter("task-id").valueFrom(
1018
1723
  {
1019
1724
  "parameter": "{{tasks.%s.outputs.parameters.task-id}}"
@@ -1021,31 +1726,84 @@ class ArgoWorkflows(object):
1021
1726
  self.graph[node.matching_join].in_funcs[0]
1022
1727
  )
1023
1728
  }
1024
- )
1729
+ if not self._is_conditional_join_node(
1730
+ self.graph[node.matching_join]
1731
+ )
1732
+ else
1733
+ # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed.
1734
+ # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md
1735
+ {
1736
+ "expression": "get((%s)?.parameters, 'task-id')"
1737
+ % " ?? ".join(
1738
+ f"tasks['{self._sanitize(func)}']?.outputs"
1739
+ for func in self.graph[
1740
+ node.matching_join
1741
+ ].in_funcs
1742
+ )
1743
+ }
1744
+ ),
1745
+ ]
1746
+ if not node.parallel_foreach
1747
+ else [
1748
+ # @parallel tasks set `task-id-entropy` and `num-parallel`
1749
+ # as outputs so task-ids can be derived in the join step.
1750
+ # Both of these values should be propagated from the
1751
+ # jobset labels.
1752
+ Parameter("num-parallel").valueFrom(
1753
+ {
1754
+ "parameter": "{{tasks.%s.outputs.parameters.num-parallel}}"
1755
+ % self._sanitize(
1756
+ self.graph[node.matching_join].in_funcs[0]
1757
+ )
1758
+ }
1759
+ ),
1760
+ Parameter("task-id-entropy").valueFrom(
1761
+ {
1762
+ "parameter": "{{tasks.%s.outputs.parameters.task-id-entropy}}"
1763
+ % self._sanitize(
1764
+ self.graph[node.matching_join].in_funcs[0]
1765
+ )
1766
+ }
1767
+ ),
1025
1768
  ]
1026
1769
  )
1027
1770
  )
1028
1771
  .dag(DAGTemplate().fail_fast().tasks(dag_tasks_1))
1029
1772
  )
1773
+
1030
1774
  join_foreach_task = (
1031
1775
  DAGTask(self._sanitize(self.graph[node.matching_join].name))
1032
1776
  .template(self._sanitize(self.graph[node.matching_join].name))
1033
- .dependencies([foreach_template_name])
1777
+ .depends(f"{foreach_template_name}.Succeeded")
1034
1778
  .arguments(
1035
1779
  Arguments().parameters(
1036
- [
1037
- Parameter("input-paths").value(
1038
- "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
1039
- % (node.name, self._sanitize(node.name))
1040
- ),
1041
- Parameter("split-cardinality").value(
1042
- "{{tasks.%s.outputs.parameters.split-cardinality}}"
1043
- % self._sanitize(node.name)
1044
- ),
1045
- ]
1780
+ (
1781
+ [
1782
+ Parameter("input-paths").value(
1783
+ "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
1784
+ % (node.name, self._sanitize(node.name))
1785
+ ),
1786
+ Parameter("split-cardinality").value(
1787
+ "{{tasks.%s.outputs.parameters.split-cardinality}}"
1788
+ % self._sanitize(node.name)
1789
+ ),
1790
+ ]
1791
+ if not node.parallel_foreach
1792
+ else [
1793
+ Parameter("num-parallel").value(
1794
+ "{{tasks.%s.outputs.parameters.num-parallel}}"
1795
+ % self._sanitize(node.name)
1796
+ ),
1797
+ Parameter("task-id-entropy").value(
1798
+ "{{tasks.%s.outputs.parameters.task-id-entropy}}"
1799
+ % self._sanitize(node.name)
1800
+ ),
1801
+ ]
1802
+ )
1046
1803
  + (
1047
1804
  [
1048
1805
  Parameter("split-index").value(
1806
+ # TODO : Pass down these parameters to the jobset stuff.
1049
1807
  "{{inputs.parameters.split-index}}"
1050
1808
  ),
1051
1809
  Parameter("root-input-path").value(
@@ -1065,6 +1823,7 @@ class ArgoWorkflows(object):
1065
1823
  templates,
1066
1824
  dag_tasks,
1067
1825
  parent_foreach,
1826
+ seen,
1068
1827
  )
1069
1828
  # For linear nodes continue traversing to the next node
1070
1829
  if node.type in ("linear", "join", "start"):
@@ -1074,6 +1833,7 @@ class ArgoWorkflows(object):
1074
1833
  templates,
1075
1834
  dag_tasks,
1076
1835
  parent_foreach,
1836
+ seen,
1077
1837
  )
1078
1838
  else:
1079
1839
  raise ArgoWorkflowsException(
@@ -1081,7 +1841,17 @@ class ArgoWorkflows(object):
1081
1841
  "Argo Workflows." % (node.type, node.name)
1082
1842
  )
1083
1843
 
1084
- templates, _ = _visit(node=self.graph["start"])
1844
+ # Generate daemon tasks
1845
+ daemon_tasks = [
1846
+ DAGTask("%s-task" % daemon_template.name).template(daemon_template.name)
1847
+ for daemon_template in self._daemon_templates()
1848
+ ]
1849
+
1850
+ templates, dag_tasks = _visit(node=self.graph["start"], dag_tasks=daemon_tasks)
1851
+ # Add the DAG template only after fully traversing the graph so we are guaranteed to have all the dag_tasks collected.
1852
+ templates.append(
1853
+ Template(self.flow.name).dag(DAGTemplate().fail_fast().tasks(dag_tasks))
1854
+ )
1085
1855
  return templates
1086
1856
 
1087
1857
  # Visit every node and yield ContainerTemplates.
@@ -1123,10 +1893,32 @@ class ArgoWorkflows(object):
1123
1893
  # export input_paths as it is used multiple times in the container script
1124
1894
  # and we do not want to repeat the values.
1125
1895
  input_paths_expr = "export INPUT_PATHS=''"
1126
- if node.name != "start":
1896
+ # If node is not a start step or a @parallel join then we will set the input paths.
1897
+ # To set the input-paths as a parameter, we need to ensure that the node
1898
+ # is not (a start node or a parallel join node). Start nodes will have no
1899
+ # input paths and parallel join will derive input paths based on a
1900
+ # formulaic approach using `num-parallel` and `task-id-entropy`.
1901
+ if not (
1902
+ node.name == "start"
1903
+ or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step)
1904
+ ):
1905
+ # For parallel joins we don't pass INPUT_PATHS; the paths are constructed dynamically,
1906
+ # so we don't need to set them here.
1127
1907
  input_paths_expr = (
1128
1908
  "export INPUT_PATHS={{inputs.parameters.input-paths}}"
1129
1909
  )
1910
+ if (
1911
+ self._is_conditional_join_node(node)
1912
+ or self._many_in_funcs_all_conditional(node)
1913
+ or self._is_conditional_skip_node(node)
1914
+ ):
1915
+ # NOTE: Argo template expressions that fail to resolve, output the expression itself as a value.
1916
+ # With conditional steps, some of the input-paths are therefore 'broken' due to containing a nil expression
1917
+ # e.g. "{{ tasks['A'].outputs.parameters.task-id }}" when task A never executed.
1918
+ # We base64 encode the input-paths in order to not pollute the execution environment with templating expressions.
1919
+ # NOTE: Adding conditionals that check if a key exists or not does not work either, due to an issue with how Argo
1920
+ # handles tasks in a nested foreach (withParam template) leading to all such expressions getting evaluated as false.
1921
+ input_paths_expr = "export INPUT_PATHS={{=toBase64(inputs.parameters['input-paths'])}}"
1130
1922
  input_paths = "$(echo $INPUT_PATHS)"
1131
1923
  if any(self.graph[n].type == "foreach" for n in node.in_funcs):
1132
1924
  task_idx = "{{inputs.parameters.split-index}}"
@@ -1142,7 +1934,6 @@ class ArgoWorkflows(object):
1142
1934
  # foreaches
1143
1935
  task_idx = "{{inputs.parameters.split-index}}"
1144
1936
  root_input = "{{inputs.parameters.root-input-path}}"
1145
-
1146
1937
  # Task string to be hashed into an ID
1147
1938
  task_str = "-".join(
1148
1939
  [
@@ -1152,13 +1943,23 @@ class ArgoWorkflows(object):
1152
1943
  task_idx,
1153
1944
  ]
1154
1945
  )
1946
+ if node.parallel_step:
1947
+ task_str = "-".join(
1948
+ [
1949
+ "$TASK_ID_PREFIX",
1950
+ "{{inputs.parameters.task-id-entropy}}",
1951
+ "$TASK_ID_SUFFIX",
1952
+ ]
1953
+ )
1954
+ else:
1955
+ # Generated task_ids need to be non-numeric - see register_task_id in
1956
+ # service.py. We do so by prefixing `t-`
1957
+ _task_id_base = (
1958
+ "$(echo %s | md5sum | cut -d ' ' -f 1 | tail -c 9)" % task_str
1959
+ )
1960
+ task_str = "(t-%s)" % _task_id_base
1155
1961
 
1156
- # Generated task_ids need to be non-numeric - see register_task_id in
1157
- # service.py. We do so by prefixing `t-`
1158
- task_id_expr = (
1159
- "export METAFLOW_TASK_ID="
1160
- "(t-$(echo %s | md5sum | cut -d ' ' -f 1 | tail -c 9))" % task_str
1161
- )
1962
+ task_id_expr = "export METAFLOW_TASK_ID=" "%s" % task_str
1162
1963
  task_id = "$METAFLOW_TASK_ID"
1163
1964
 
1164
1965
  # Resolve retry strategy.
@@ -1177,9 +1978,18 @@ class ArgoWorkflows(object):
1177
1978
  user_code_retries = max_user_code_retries
1178
1979
  total_retries = max_user_code_retries + max_error_retries
1179
1980
  # {{retries}} is only available if retryStrategy is specified
1981
+ # For custom kubernetes manifests, we will pass the retryCount as a parameter
1982
+ # and use that in the manifest.
1180
1983
  retry_count = (
1181
- "{{retries}}" if max_user_code_retries + max_error_retries else 0
1984
+ (
1985
+ "{{retries}}"
1986
+ if not node.parallel_step
1987
+ else "{{inputs.parameters.retryCount}}"
1988
+ )
1989
+ if total_retries
1990
+ else 0
1182
1991
  )
1992
+
1183
1993
  minutes_between_retries = int(minutes_between_retries)
1184
1994
 
1185
1995
  # Configure log capture.
@@ -1206,7 +2016,9 @@ class ArgoWorkflows(object):
1206
2016
  mflog_expr,
1207
2017
  ]
1208
2018
  + self.environment.get_package_commands(
1209
- self.code_package_url, self.flow_datastore.TYPE
2019
+ self.code_package_url,
2020
+ self.flow_datastore.TYPE,
2021
+ self.code_package_metadata,
1210
2022
  )
1211
2023
  )
1212
2024
  step_cmds = self.environment.bootstrap_commands(
@@ -1218,12 +2030,13 @@ class ArgoWorkflows(object):
1218
2030
  decorator.make_decorator_spec()
1219
2031
  for decorator in node.decorators
1220
2032
  if not decorator.statically_defined
2033
+ and decorator.inserted_by is None
1221
2034
  ]
1222
2035
  }
1223
2036
  # FlowDecorators can define their own top-level options. They are
1224
2037
  # responsible for adding their own top-level options and values through
1225
2038
  # the get_top_level_options() hook. See similar logic in runtime.py.
1226
- for deco in flow_decorators():
2039
+ for deco in flow_decorators(self.flow):
1227
2040
  top_opts_dict.update(deco.get_top_level_options())
1228
2041
 
1229
2042
  top_level = list(dict_to_cli_options(top_opts_dict)) + [
@@ -1255,7 +2068,7 @@ class ArgoWorkflows(object):
1255
2068
  # {{foo.bar['param_name']}}.
1256
2069
  # https://argoproj.github.io/argo-events/tutorials/02-parameterization/
1257
2070
  # http://masterminds.github.io/sprig/strings.html
1258
- "--%s={{workflow.parameters.%s}}"
2071
+ "--%s=\\\"$(python -m metaflow.plugins.argo.param_val {{=toBase64(workflow.parameters['%s'])}})\\\""
1259
2072
  % (parameter["name"], parameter["name"])
1260
2073
  for parameter in self.parameters.values()
1261
2074
  ]
@@ -1277,21 +2090,63 @@ class ArgoWorkflows(object):
1277
2090
  ]
1278
2091
  )
1279
2092
  input_paths = "%s/_parameters/%s" % (run_id, task_id_params)
2093
+ # Only for static joins and conditional_joins
2094
+ elif (
2095
+ self._is_conditional_join_node(node)
2096
+ or self._many_in_funcs_all_conditional(node)
2097
+ or self._is_conditional_skip_node(node)
2098
+ ) and not (
2099
+ node.type == "join"
2100
+ and self.graph[node.split_parents[-1]].type == "foreach"
2101
+ ):
2102
+ # We need to pass the set of conditional in_funcs to the pathspec-generating script because, in split-switch skipping cases,
2103
+ # non-conditional input-paths need to be ignored in favour of conditional ones whenever those have executed.
2104
+ skippable_input_steps = ",".join(
2105
+ [
2106
+ in_func
2107
+ for in_func in node.in_funcs
2108
+ if self.graph[in_func].type == "split-switch"
2109
+ ]
2110
+ )
2111
+ input_paths = (
2112
+ "$(python -m metaflow.plugins.argo.conditional_input_paths %s %s)"
2113
+ % (input_paths, skippable_input_steps)
2114
+ )
1280
2115
  elif (
1281
2116
  node.type == "join"
1282
2117
  and self.graph[node.split_parents[-1]].type == "foreach"
1283
2118
  ):
2119
+ # foreach-joins straight out of conditional branches are not yet supported
2120
+ if self._is_conditional_join_node(node) and len(node.in_funcs) > 1:
2121
+ raise ArgoWorkflowsException(
2122
+ "Conditional steps inside a foreach that transition directly into a join step are not currently supported.\n"
2123
+ "As a workaround, add a common step after the conditional steps %s "
2124
+ "that will transition to a join."
2125
+ % ", ".join("*%s*" % f for f in node.in_funcs)
2126
+ )
1284
2127
  # Set aggregated input-paths for a for-each join
1285
2128
  foreach_step = next(
1286
2129
  n for n in node.in_funcs if self.graph[n].is_inside_foreach
1287
2130
  )
1288
- input_paths = (
1289
- "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})"
1290
- % (
1291
- foreach_step,
1292
- input_paths,
2131
+ if not self.graph[node.split_parents[-1]].parallel_foreach:
2132
+ input_paths = (
2133
+ "$(python -m metaflow.plugins.argo.generate_input_paths %s {{workflow.creationTimestamp}} %s {{inputs.parameters.split-cardinality}})"
2134
+ % (
2135
+ foreach_step,
2136
+ input_paths,
2137
+ )
1293
2138
  )
1294
- )
2139
+ else:
2140
+ # Handle @parallel where output from volume mount isn't accessible
2141
+ input_paths = (
2142
+ "$(python -m metaflow.plugins.argo.jobset_input_paths %s %s {{inputs.parameters.task-id-entropy}} {{inputs.parameters.num-parallel}})"
2143
+ % (
2144
+ run_id,
2145
+ foreach_step,
2146
+ )
2147
+ )
2148
+ # NOTE: input-paths might be extremely lengthy so we dump these to disk instead of passing them directly to the cmd
2149
+ step_cmds.append("echo %s >> /tmp/mf-input-paths" % input_paths)
1295
2150
  step = [
1296
2151
  "step",
1297
2152
  node.name,
@@ -1299,9 +2154,16 @@ class ArgoWorkflows(object):
1299
2154
  "--task-id %s" % task_id,
1300
2155
  "--retry-count %s" % retry_count,
1301
2156
  "--max-user-code-retries %d" % user_code_retries,
1302
- "--input-paths %s" % input_paths,
2157
+ "--input-paths-filename /tmp/mf-input-paths",
1303
2158
  ]
1304
- if any(self.graph[n].type == "foreach" for n in node.in_funcs):
2159
+ if node.parallel_step:
2160
+ step.append(
2161
+ "--split-index ${MF_CONTROL_INDEX:-$((MF_WORKER_REPLICA_INDEX + 1))}"
2162
+ )
2163
+ # This is needed for setting the value of the UBF context in the CLI.
2164
+ step.append("--ubf-context $UBF_CONTEXT")
2165
+
2166
+ elif any(self.graph[n].type == "foreach" for n in node.in_funcs):
1305
2167
  # Pass split-index to a foreach task
1306
2168
  step.append("--split-index {{inputs.parameters.split-index}}")
1307
2169
  if self.tags:
@@ -1367,6 +2229,7 @@ class ArgoWorkflows(object):
1367
2229
  **{
1368
2230
  # These values are needed by Metaflow to set its internal
1369
2231
  # state appropriately.
2232
+ "METAFLOW_CODE_METADATA": self.code_package_metadata,
1370
2233
  "METAFLOW_CODE_URL": self.code_package_url,
1371
2234
  "METAFLOW_CODE_SHA": self.code_package_sha,
1372
2235
  "METAFLOW_CODE_DS": self.flow_datastore.TYPE,
@@ -1395,6 +2258,7 @@ class ArgoWorkflows(object):
1395
2258
  },
1396
2259
  **{
1397
2260
  # Some optional values for bookkeeping
2261
+ "METAFLOW_FLOW_FILENAME": os.path.basename(sys.argv[0]),
1398
2262
  "METAFLOW_FLOW_NAME": self.flow.name,
1399
2263
  "METAFLOW_STEP_NAME": node.name,
1400
2264
  "METAFLOW_RUN_ID": run_id,
@@ -1413,20 +2277,30 @@ class ArgoWorkflows(object):
1413
2277
 
1414
2278
  # support Metaflow sandboxes
1415
2279
  env["METAFLOW_INIT_SCRIPT"] = KUBERNETES_SANDBOX_INIT_SCRIPT
2280
+ env["METAFLOW_KUBERNETES_SANDBOX_INIT_SCRIPT"] = (
2281
+ KUBERNETES_SANDBOX_INIT_SCRIPT
2282
+ )
1416
2283
 
1417
2284
  # support for @secret
1418
2285
  env["METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE"] = DEFAULT_SECRETS_BACKEND_TYPE
1419
- env[
1420
- "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION"
1421
- ] = AWS_SECRETS_MANAGER_DEFAULT_REGION
2286
+ env["METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION"] = (
2287
+ AWS_SECRETS_MANAGER_DEFAULT_REGION
2288
+ )
1422
2289
  env["METAFLOW_GCP_SECRET_MANAGER_PREFIX"] = GCP_SECRET_MANAGER_PREFIX
2290
+ env["METAFLOW_AZURE_KEY_VAULT_PREFIX"] = AZURE_KEY_VAULT_PREFIX
1423
2291
 
1424
2292
  # support for Azure
1425
- env[
1426
- "METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT"
1427
- ] = AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
2293
+ env["METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT"] = (
2294
+ AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
2295
+ )
1428
2296
  env["METAFLOW_DATASTORE_SYSROOT_AZURE"] = DATASTORE_SYSROOT_AZURE
1429
2297
  env["METAFLOW_CARD_AZUREROOT"] = CARD_AZUREROOT
2298
+ env["METAFLOW_ARGO_WORKFLOWS_KUBERNETES_SECRETS"] = (
2299
+ ARGO_WORKFLOWS_KUBERNETES_SECRETS
2300
+ )
2301
+ env["METAFLOW_ARGO_WORKFLOWS_ENV_VARS_TO_SKIP"] = (
2302
+ ARGO_WORKFLOWS_ENV_VARS_TO_SKIP
2303
+ )
1430
2304
 
1431
2305
  # support for GCP
1432
2306
  env["METAFLOW_DATASTORE_SYSROOT_GS"] = DATASTORE_SYSROOT_GS
@@ -1449,6 +2323,13 @@ class ArgoWorkflows(object):
1449
2323
  metaflow_version["production_token"] = self.production_token
1450
2324
  env["METAFLOW_VERSION"] = json.dumps(metaflow_version)
1451
2325
 
2326
+ # map config values
2327
+ cfg_env = {
2328
+ param["name"]: param["kv_name"] for param in self.config_parameters
2329
+ }
2330
+ if cfg_env:
2331
+ env["METAFLOW_FLOW_CONFIG_VALUE"] = json.dumps(cfg_env)
2332
+
1452
2333
  # Set the template inputs and outputs for passing state. Very simply,
1453
2334
  # the container template takes in input-paths as input and outputs
1454
2335
  # the task-id (which feeds in as input-paths to the subsequent task).
@@ -1463,17 +2344,45 @@ class ArgoWorkflows(object):
1463
2344
  # join task deterministically inside the join task without resorting to
1464
2345
  # passing a rather long list of (albeit compressed)
1465
2346
  inputs = []
1466
- if node.name != "start":
2347
+ # To set the input-paths as a parameter, we need to ensure that the node
2348
+ # is not (a start node or a parallel join node). Start nodes will have no
2349
+ # input paths and parallel join will derive input paths based on a
2350
+ # formulaic approach.
2351
+ if not (
2352
+ node.name == "start"
2353
+ or (node.type == "join" and self.graph[node.in_funcs[0]].parallel_step)
2354
+ ):
1467
2355
  inputs.append(Parameter("input-paths"))
1468
2356
  if any(self.graph[n].type == "foreach" for n in node.in_funcs):
1469
2357
  # Fetch split-index from parent
1470
2358
  inputs.append(Parameter("split-index"))
2359
+
1471
2360
  if (
1472
2361
  node.type == "join"
1473
2362
  and self.graph[node.split_parents[-1]].type == "foreach"
1474
2363
  ):
1475
- # append this only for joins of foreaches, not static splits
1476
- inputs.append(Parameter("split-cardinality"))
2364
+ # @parallel join tasks require `num-parallel` and `task-id-entropy`
2365
+ # to construct the input paths, so we pass them down as input parameters.
2366
+ if self.graph[node.split_parents[-1]].parallel_foreach:
2367
+ inputs.extend(
2368
+ [Parameter("num-parallel"), Parameter("task-id-entropy")]
2369
+ )
2370
+ else:
2371
+ # append these only for joins of foreaches, not static splits
2372
+ inputs.append(Parameter("split-cardinality"))
2373
+ # check if the node is a @parallel node.
2374
+ elif node.parallel_step:
2375
+ inputs.extend(
2376
+ [
2377
+ Parameter("num-parallel"),
2378
+ Parameter("task-id-entropy"),
2379
+ Parameter("jobset-name"),
2380
+ Parameter("workerCount"),
2381
+ ]
2382
+ )
2383
+ if any(d.name == "retry" for d in node.decorators):
2384
+ inputs.append(Parameter("retryCount"))
2385
+
1477
2386
  if node.is_inside_foreach and self.graph[node.out_funcs[0]].type == "join":
1478
2387
  if any(
1479
2388
  self.graph[parent].matching_join
@@ -1490,8 +2399,17 @@ class ArgoWorkflows(object):
1490
2399
  inputs.append(Parameter("root-input-path"))
1491
2400
 
1492
2401
  outputs = []
1493
- if node.name != "end":
2402
+ # @parallel steps will not have a task-id as an output parameter since task-ids
2403
+ # are derived at runtime.
2404
+ if not (node.name == "end" or node.parallel_step):
1494
2405
  outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]
2406
+
2407
+ # If this step is a split-switch one, we need to output the switch step name
2408
+ if node.type == "split-switch":
2409
+ outputs.append(
2410
+ Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"})
2411
+ )
2412
+
1495
2413
  if node.type == "foreach":
1496
2414
  # Emit split cardinality from foreach task
1497
2415
  outputs.append(
@@ -1503,6 +2421,19 @@ class ArgoWorkflows(object):
1503
2421
  )
1504
2422
  )
1505
2423
 
2424
+ if node.parallel_foreach:
2425
+ outputs.extend(
2426
+ [
2427
+ Parameter("num-parallel").valueFrom(
2428
+ {"path": "/mnt/out/num_parallel"}
2429
+ ),
2430
+ Parameter("task-id-entropy").valueFrom(
2431
+ {"path": "/mnt/out/task_id_entropy"}
2432
+ ),
2433
+ ]
2434
+ )
2435
+ # Outputs for @parallel should be defined here and not in _dag_templates.
2436
+
1506
2437
  # It makes no sense to set env vars to None (shows up as "None" string)
1507
2438
  # Also we skip some env vars (e.g. in case we want to pull them from KUBERNETES_SECRETS)
1508
2439
  env = {
@@ -1512,6 +2443,12 @@ class ArgoWorkflows(object):
1512
2443
  and k not in set(ARGO_WORKFLOWS_ENV_VARS_TO_SKIP.split(","))
1513
2444
  }
1514
2445
 
2446
+ # OBP configs
2447
+ additional_obp_configs = {
2448
+ "OBP_PERIMETER": self.initial_configs["OBP_PERIMETER"],
2449
+ "OBP_INTEGRATIONS_URL": self.initial_configs["OBP_INTEGRATIONS_URL"],
2450
+ }
2451
+
1515
2452
  # Tmpfs variables
1516
2453
  use_tmpfs = resources["use_tmpfs"]
1517
2454
  tmpfs_size = resources["tmpfs_size"]
@@ -1528,262 +2465,938 @@ class ArgoWorkflows(object):
1528
2465
 
1529
2466
  if tmpfs_enabled and tmpfs_tempdir:
1530
2467
  env["METAFLOW_TEMPDIR"] = tmpfs_path
2468
+
2469
+ qos_requests, qos_limits = qos_requests_and_limits(
2470
+ resources["qos"],
2471
+ resources["cpu"],
2472
+ resources["memory"],
2473
+ resources["disk"],
2474
+ )
2475
+
2476
+ security_context = resources.get("security_context", None)
2477
+ _security_context = {}
2478
+ if security_context is not None and len(security_context) > 0:
2479
+ _security_context = {
2480
+ "security_context": kubernetes_sdk.V1SecurityContext(
2481
+ **security_context
2482
+ )
2483
+ }
2484
+
1531
2485
  # Create a ContainerTemplate for this node. Ideally, we would have
1532
2486
  # liked to inline this ContainerTemplate and avoid scanning the workflow
1533
2487
  # twice, but due to issues with variable substitution, we will have to
1534
2488
  # live with this routine.
1535
- yield (
1536
- Template(self._sanitize(node.name))
1537
- # Set @timeout values
1538
- .active_deadline_seconds(run_time_limit)
1539
- # Set service account
1540
- .service_account_name(resources["service_account"])
1541
- # Configure template input
1542
- .inputs(Inputs().parameters(inputs))
1543
- # Configure template output
1544
- .outputs(Outputs().parameters(outputs))
1545
- # Fail fast!
1546
- .fail_fast()
1547
- # Set @retry/@catch values
1548
- .retry_strategy(
1549
- times=total_retries,
1550
- minutes_between_retries=minutes_between_retries,
1551
- )
1552
- .metadata(
1553
- ObjectMeta().annotation("metaflow/step_name", node.name)
1554
- # Unfortunately, we can't set the task_id since it is generated
1555
- # inside the pod. However, it can be inferred from the annotation
1556
- # set by argo-workflows - `workflows.argoproj.io/outputs` - refer
1557
- # the field 'task-id' in 'parameters'
1558
- # .annotation("metaflow/task_id", ...)
1559
- .annotation("metaflow/attempt", retry_count)
1560
- )
1561
- # Set emptyDir volume for state management
1562
- .empty_dir_volume("out")
1563
- # Set tmpfs emptyDir volume if enabled
1564
- .empty_dir_volume(
1565
- "tmpfs-ephemeral-volume",
1566
- medium="Memory",
1567
- size_limit=tmpfs_size if tmpfs_enabled else 0,
1568
- )
1569
- .empty_dir_volume("dhsm", medium="Memory", size_limit=shared_memory)
1570
- .pvc_volumes(resources.get("persistent_volume_claims"))
1571
- # Set node selectors
1572
- .node_selectors(resources.get("node_selector"))
1573
- # Set tolerations
1574
- .tolerations(resources.get("tolerations"))
1575
- # Set container
1576
- .container(
1577
- # TODO: Unify the logic with kubernetes.py
1578
- # Important note - Unfortunately, V1Container uses snakecase while
1579
- # Argo Workflows uses camel. For most of the attributes, both cases
1580
- # are indistinguishable, but unfortunately, not for all - (
1581
- # env_from, value_from, etc.) - so we need to handle the conversion
1582
- # ourselves using to_camelcase. We need to be vigilant about
1583
- # resources attributes in particular where the keys maybe user
1584
- # defined.
1585
- to_camelcase(
1586
- kubernetes_sdk.V1Container(
1587
- name=self._sanitize(node.name),
1588
- command=cmds,
1589
- ports=[kubernetes_sdk.V1ContainerPort(container_port=port)]
1590
- if port
1591
- else None,
1592
- env=[
1593
- kubernetes_sdk.V1EnvVar(name=k, value=str(v))
1594
- for k, v in env.items()
1595
- ]
1596
- # Add environment variables for book-keeping.
1597
- # https://argoproj.github.io/argo-workflows/fields/#fields_155
1598
- + [
1599
- kubernetes_sdk.V1EnvVar(
1600
- name=k,
1601
- value_from=kubernetes_sdk.V1EnvVarSource(
1602
- field_ref=kubernetes_sdk.V1ObjectFieldSelector(
1603
- field_path=str(v)
1604
- )
1605
- ),
1606
- )
1607
- for k, v in {
1608
- "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
1609
- "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
1610
- "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
1611
- "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
1612
- "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
1613
- }.items()
1614
- ],
1615
- image=resources["image"],
1616
- image_pull_policy=resources["image_pull_policy"],
1617
- resources=kubernetes_sdk.V1ResourceRequirements(
1618
- requests={
1619
- "cpu": str(resources["cpu"]),
1620
- "memory": "%sM" % str(resources["memory"]),
1621
- "ephemeral-storage": "%sM" % str(resources["disk"]),
1622
- },
1623
- limits={
1624
- "%s.com/gpu".lower()
1625
- % resources["gpu_vendor"]: str(resources["gpu"])
1626
- for k in [0]
1627
- if resources["gpu"] is not None
1628
- },
1629
- ),
1630
- # Configure secrets
1631
- env_from=[
1632
- kubernetes_sdk.V1EnvFromSource(
1633
- secret_ref=kubernetes_sdk.V1SecretEnvSource(
1634
- name=str(k),
1635
- # optional=True
1636
- )
1637
- )
1638
- for k in list(
2489
+ if node.parallel_step:
2490
+ jobset_name = "{{inputs.parameters.jobset-name}}"
2491
+ jobset = KubernetesArgoJobSet(
2492
+ kubernetes_sdk=kubernetes_sdk,
2493
+ name=jobset_name,
2494
+ flow_name=self.flow.name,
2495
+ run_id=run_id,
2496
+ step_name=self._sanitize(node.name),
2497
+ task_id=task_id,
2498
+ attempt=retry_count,
2499
+ user=self.username,
2500
+ subdomain=jobset_name,
2501
+ command=cmds,
2502
+ namespace=resources["namespace"],
2503
+ image=resources["image"],
2504
+ image_pull_policy=resources["image_pull_policy"],
2505
+ image_pull_secrets=resources["image_pull_secrets"],
2506
+ service_account=resources["service_account"],
2507
+ secrets=(
2508
+ [
2509
+ k
2510
+ for k in (
2511
+ list(
1639
2512
  []
1640
2513
  if not resources.get("secrets")
1641
- else [resources.get("secrets")]
1642
- if isinstance(resources.get("secrets"), str)
1643
- else resources.get("secrets")
2514
+ else (
2515
+ [resources.get("secrets")]
2516
+ if isinstance(resources.get("secrets"), str)
2517
+ else resources.get("secrets")
2518
+ )
1644
2519
  )
1645
2520
  + KUBERNETES_SECRETS.split(",")
1646
2521
  + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
1647
- if k
1648
- ],
1649
- volume_mounts=[
1650
- # Assign a volume mount to pass state to the next task.
1651
- kubernetes_sdk.V1VolumeMount(
1652
- name="out", mount_path="/mnt/out"
1653
- )
2522
+ )
2523
+ if k
2524
+ ]
2525
+ ),
2526
+ node_selector=resources.get("node_selector"),
2527
+ cpu=str(resources["cpu"]),
2528
+ memory=str(resources["memory"]),
2529
+ disk=str(resources["disk"]),
2530
+ gpu=resources["gpu"],
2531
+ gpu_vendor=str(resources["gpu_vendor"]),
2532
+ tolerations=resources["tolerations"],
2533
+ use_tmpfs=use_tmpfs,
2534
+ tmpfs_tempdir=tmpfs_tempdir,
2535
+ tmpfs_size=tmpfs_size,
2536
+ tmpfs_path=tmpfs_path,
2537
+ timeout_in_seconds=run_time_limit,
2538
+ persistent_volume_claims=resources["persistent_volume_claims"],
2539
+ shared_memory=shared_memory,
2540
+ port=port,
2541
+ qos=resources["qos"],
2542
+ security_context=security_context,
2543
+ )
2544
+
2545
+ for k, v in env.items():
2546
+ jobset.environment_variable(k, v)
2547
+
2548
+ for k, v in additional_obp_configs.items():
2549
+ jobset.environment_variable(k, v)
2550
+ # Set labels. Do not allow user-specified task labels to override internal ones.
2551
+ #
2552
+ # Explicitly add the task-id-hint label. This is important because this label
2553
+ # is returned as an Output parameter of this step and is used subsequently as an
2554
+ # input in the join step.
2555
+ kubernetes_labels = {
2556
+ "task_id_entropy": "{{inputs.parameters.task-id-entropy}}",
2557
+ "num_parallel": "{{inputs.parameters.num-parallel}}",
2558
+ "metaflow/argo-workflows-name": "{{workflow.name}}",
2559
+ "workflows.argoproj.io/workflow": "{{workflow.name}}",
2560
+ }
2561
+ jobset.labels(
2562
+ {
2563
+ **resources["labels"],
2564
+ **self._base_labels,
2565
+ **kubernetes_labels,
2566
+ }
2567
+ )
2568
+
2569
+ jobset.environment_variable(
2570
+ "MF_MASTER_ADDR", jobset.jobset_control_addr
2571
+ )
2572
+ jobset.environment_variable("MF_MASTER_PORT", str(port))
2573
+ jobset.environment_variable(
2574
+ "MF_WORLD_SIZE", "{{inputs.parameters.num-parallel}}"
2575
+ )
2576
+ # We need this task-id set so that all the nodes are aware of the control
2577
+ # task's task-id. These "MF_" variables populate the `current.parallel` namedtuple
2578
+ jobset.environment_variable(
2579
+ "MF_PARALLEL_CONTROL_TASK_ID",
2580
+ "control-{{inputs.parameters.task-id-entropy}}-0",
2581
+ )
2582
2583
+ jobset.environment_variables_from_selectors(
2584
+ {
2585
+ "MF_WORKER_REPLICA_INDEX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
2586
+ "JOBSET_RESTART_ATTEMPT": "metadata.annotations['jobset.sigs.k8s.io/restart-attempt']",
2587
+ "METAFLOW_KUBERNETES_JOBSET_NAME": "metadata.annotations['jobset.sigs.k8s.io/jobset-name']",
2588
+ "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
2589
+ "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
2590
+ "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
2591
+ "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
2592
+ "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
2593
+ "TASK_ID_SUFFIX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
2594
+ }
2595
+ )
2596
+
2597
+ # Set annotations. Do not allow user-specified task-specific annotations to override internal ones.
2598
+ annotations = {
2599
+ # Setting annotations explicitly as they won't be
2600
+ # passed down from WorkflowTemplate level
2601
+ "metaflow/step_name": node.name,
2602
+ "metaflow/attempt": str(retry_count),
2603
+ "metaflow/run_id": run_id,
2604
+ }
2605
+
2606
+ jobset.annotations(
2607
+ {
2608
+ **resources["annotations"],
2609
+ **self._base_annotations,
2610
+ **annotations,
2611
+ }
2612
+ )
2613
+
2614
+ jobset.control.replicas(1)
2615
+ jobset.worker.replicas("{{=asInt(inputs.parameters.workerCount)}}")
2616
+ jobset.control.environment_variable("UBF_CONTEXT", UBF_CONTROL)
2617
+ jobset.worker.environment_variable("UBF_CONTEXT", UBF_TASK)
2618
+ jobset.control.environment_variable("MF_CONTROL_INDEX", "0")
2619
+ # `TASK_ID_PREFIX` needs to explicitly be `control` or `worker`
2620
+ # because the join task uses a formulaic approach to infer the task-ids
2621
+ jobset.control.environment_variable("TASK_ID_PREFIX", "control")
2622
+ jobset.worker.environment_variable("TASK_ID_PREFIX", "worker")
2623
+
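
Given the `TASK_ID_PREFIX`/`TASK_ID_SUFFIX` wiring above, the join step can reconstruct the @parallel task-ids formulaically from `task-id-entropy` and `num-parallel`. A hedged sketch, assuming a hypothetical entropy value and worker job indices starting at 0:

```python
# Illustrative only: how @parallel task-ids line up with the env vars set above,
# for a hypothetical task-id-entropy of "a1b2c3" and num_parallel=3.
entropy, num_parallel = "a1b2c3", 3
worker_count = num_parallel - 1                       # one replica is the control task
control_task_id = "control-%s-0" % entropy            # matches MF_PARALLEL_CONTROL_TASK_ID above
worker_task_ids = [
    "worker-%s-%d" % (entropy, job_index)             # TASK_ID_PREFIX-{entropy}-TASK_ID_SUFFIX
    for job_index in range(worker_count)
]
print(control_task_id, worker_task_ids)
```
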
2624
+ yield (
2625
+ Template(ArgoWorkflows._sanitize(node.name))
2626
+ .resource(
2627
+ "create",
2628
+ jobset.dump(),
2629
+ "status.terminalState == Completed",
2630
+ "status.terminalState == Failed",
2631
+ )
2632
+ .inputs(Inputs().parameters(inputs))
2633
+ .outputs(
2634
+ Outputs().parameters(
2635
+ [
2636
+ Parameter("task-id-entropy").valueFrom(
2637
+ {"jsonPath": "{.metadata.labels.task_id_entropy}"}
2638
+ ),
2639
+ Parameter("num-parallel").valueFrom(
2640
+ {"jsonPath": "{.metadata.labels.num_parallel}"}
2641
+ ),
1654
2642
  ]
1655
- # Support tmpfs.
1656
- + (
1657
- [
1658
- kubernetes_sdk.V1VolumeMount(
1659
- name="tmpfs-ephemeral-volume",
1660
- mount_path=tmpfs_path,
1661
- )
2643
+ )
2644
+ )
2645
+ .retry_strategy(
2646
+ times=total_retries,
2647
+ minutes_between_retries=minutes_between_retries,
2648
+ )
2649
+ )
2650
+ else:
2651
+ template_name = self._sanitize(node.name)
2652
+ if self._is_recursive_node(node):
2653
+ # The recursive template has the original step name,
2654
+ # this becomes a template within the recursive ones 'steps'
2655
+ template_name = self._sanitize("recursive-%s" % node.name)
2656
+ yield (
2657
+ Template(template_name)
2658
+ # Set @timeout values
2659
+ .active_deadline_seconds(run_time_limit)
2660
+ # Set service account
2661
+ .service_account_name(resources["service_account"])
2662
+ # Configure template input
2663
+ .inputs(Inputs().parameters(inputs))
2664
+ # Configure template output
2665
+ .outputs(Outputs().parameters(outputs))
2666
+ # Fail fast!
2667
+ .fail_fast()
2668
+ # Set @retry/@catch values
2669
+ .retry_strategy(
2670
+ times=total_retries,
2671
+ minutes_between_retries=minutes_between_retries,
2672
+ )
2673
+ .metadata(
2674
+ ObjectMeta()
2675
+ .annotation("metaflow/step_name", node.name)
2676
+ # Unfortunately, we can't set the task_id since it is generated
2677
+ # inside the pod. However, it can be inferred from the annotation
2678
+ # set by argo-workflows - `workflows.argoproj.io/outputs` - refer
2679
+ # the field 'task-id' in 'parameters'
2680
+ # .annotation("metaflow/task_id", ...)
2681
+ .annotation("metaflow/attempt", retry_count)
2682
+ .annotations(resources["annotations"])
2683
+ .labels(resources["labels"])
2684
+ )
2685
+ # Set emptyDir volume for state management
2686
+ .empty_dir_volume("out")
2687
+ # Set tmpfs emptyDir volume if enabled
2688
+ .empty_dir_volume(
2689
+ "tmpfs-ephemeral-volume",
2690
+ medium="Memory",
2691
+ size_limit=tmpfs_size if tmpfs_enabled else 0,
2692
+ )
2693
+ .empty_dir_volume("dhsm", medium="Memory", size_limit=shared_memory)
2694
+ .pvc_volumes(resources.get("persistent_volume_claims"))
2695
+ # Set node selectors
2696
+ .node_selectors(resources.get("node_selector"))
2697
+ # Set tolerations
2698
+ .tolerations(resources.get("tolerations"))
2699
+ # Set image pull secrets if present. We need to use pod_spec_patch due to Argo not supporting this on a template level.
2700
+ .pod_spec_patch(
2701
+ {
2702
+ "imagePullSecrets": [
2703
+ {"name": secret}
2704
+ for secret in resources["image_pull_secrets"]
2705
+ ]
2706
+ }
2707
+ if resources["image_pull_secrets"]
2708
+ else None
2709
+ )
2710
+ # Set container
2711
+ .container(
2712
+ # TODO: Unify the logic with kubernetes.py
2713
+ # Important note - Unfortunately, V1Container uses snakecase while
2714
+ # Argo Workflows uses camel. For most of the attributes, both cases
2715
+ # are indistinguishable, but unfortunately, not for all - (
2716
+ # env_from, value_from, etc.) - so we need to handle the conversion
2717
+ # ourselves using to_camelcase. We need to be vigilant about
2718
+ # resources attributes in particular where the keys maybe user
2719
+ # defined.
2720
+ to_camelcase(
2721
+ kubernetes_sdk.V1Container(
2722
+ name=self._sanitize(node.name),
2723
+ command=cmds,
2724
+ termination_message_policy="FallbackToLogsOnError",
2725
+ ports=(
2726
+ [
2727
+ kubernetes_sdk.V1ContainerPort(
2728
+ container_port=port
2729
+ )
2730
+ ]
2731
+ if port
2732
+ else None
2733
+ ),
2734
+ env=[
2735
+ kubernetes_sdk.V1EnvVar(name=k, value=str(v))
2736
+ for k, v in env.items()
1662
2737
  ]
1663
- if tmpfs_enabled
1664
- else []
1665
- )
1666
- # Support shared_memory
1667
- + (
1668
- [
1669
- kubernetes_sdk.V1VolumeMount(
1670
- name="dhsm",
1671
- mount_path="/dev/shm",
2738
+ # Add environment variables for book-keeping.
2739
+ # https://argoproj.github.io/argo-workflows/fields/#fields_155
2740
+ + [
2741
+ kubernetes_sdk.V1EnvVar(
2742
+ name=k,
2743
+ value_from=kubernetes_sdk.V1EnvVarSource(
2744
+ field_ref=kubernetes_sdk.V1ObjectFieldSelector(
2745
+ field_path=str(v)
2746
+ )
2747
+ ),
1672
2748
  )
2749
+ for k, v in {
2750
+ "METAFLOW_KUBERNETES_NAMESPACE": "metadata.namespace",
2751
+ "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
2752
+ "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
2753
+ "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
2754
+ "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
2755
+ "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
2756
+ }.items()
1673
2757
  ]
1674
- if shared_memory
1675
- else []
1676
- )
1677
- # Support persistent volume claims.
1678
- + (
1679
- [
2758
+ + [
2759
+ kubernetes_sdk.V1EnvVar(
2760
+ name=k,
2761
+ value=v,
2762
+ )
2763
+ for k, v in additional_obp_configs.items()
2764
+ ],
2765
+ image=resources["image"],
2766
+ image_pull_policy=resources["image_pull_policy"],
2767
+ resources=kubernetes_sdk.V1ResourceRequirements(
2768
+ requests=qos_requests,
2769
+ limits={
2770
+ **qos_limits,
2771
+ **{
2772
+ "%s.com/gpu".lower()
2773
+ % resources["gpu_vendor"]: str(
2774
+ resources["gpu"]
2775
+ )
2776
+ for k in [0]
2777
+ if resources["gpu"] is not None
2778
+ },
2779
+ },
2780
+ ),
2781
+ # Configure secrets
2782
+ env_from=[
2783
+ kubernetes_sdk.V1EnvFromSource(
2784
+ secret_ref=kubernetes_sdk.V1SecretEnvSource(
2785
+ name=str(k),
2786
+ # optional=True
2787
+ )
2788
+ )
2789
+ for k in list(
2790
+ []
2791
+ if not resources.get("secrets")
2792
+ else (
2793
+ [resources.get("secrets")]
2794
+ if isinstance(resources.get("secrets"), str)
2795
+ else resources.get("secrets")
2796
+ )
2797
+ )
2798
+ + KUBERNETES_SECRETS.split(",")
2799
+ + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
2800
+ if k
2801
+ ],
2802
+ volume_mounts=[
2803
+ # Assign a volume mount to pass state to the next task.
1680
2804
  kubernetes_sdk.V1VolumeMount(
1681
- name=claim, mount_path=path
2805
+ name="out", mount_path="/mnt/out"
1682
2806
  )
1683
- for claim, path in resources.get(
1684
- "persistent_volume_claims"
1685
- ).items()
1686
2807
  ]
1687
- if resources.get("persistent_volume_claims") is not None
1688
- else []
1689
- ),
1690
- ).to_dict()
2808
+ # Support tmpfs.
2809
+ + (
2810
+ [
2811
+ kubernetes_sdk.V1VolumeMount(
2812
+ name="tmpfs-ephemeral-volume",
2813
+ mount_path=tmpfs_path,
2814
+ )
2815
+ ]
2816
+ if tmpfs_enabled
2817
+ else []
2818
+ )
2819
+ # Support shared_memory
2820
+ + (
2821
+ [
2822
+ kubernetes_sdk.V1VolumeMount(
2823
+ name="dhsm",
2824
+ mount_path="/dev/shm",
2825
+ )
2826
+ ]
2827
+ if shared_memory
2828
+ else []
2829
+ )
2830
+ # Support persistent volume claims.
2831
+ + (
2832
+ [
2833
+ kubernetes_sdk.V1VolumeMount(
2834
+ name=claim, mount_path=path
2835
+ )
2836
+ for claim, path in resources.get(
2837
+ "persistent_volume_claims"
2838
+ ).items()
2839
+ ]
2840
+ if resources.get("persistent_volume_claims")
2841
+ is not None
2842
+ else []
2843
+ ),
2844
+ **_security_context,
2845
+ ).to_dict()
2846
+ )
1691
2847
  )
1692
2848
  )
2849
+
2850
+ # Return daemon container templates that run for the duration of the workflow (currently just the heartbeat daemon).
2851
+ def _daemon_templates(self):
2852
+ templates = []
2853
+ if self.enable_heartbeat_daemon:
2854
+ templates.append(self._heartbeat_daemon_template())
2855
+ return templates
2856
+
2857
+ # Return lifecycle hooks for workflow execution notifications.
2858
+ def _lifecycle_hooks(self):
2859
+ hooks = []
2860
+ if self.notify_on_error:
2861
+ hooks.append(self._slack_error_template())
2862
+ hooks.append(self._pager_duty_alert_template())
2863
+ hooks.append(self._incident_io_alert_template())
2864
+ if self.notify_on_success:
2865
+ hooks.append(self._slack_success_template())
2866
+ hooks.append(self._pager_duty_change_template())
2867
+ hooks.append(self._incident_io_change_template())
2868
+
2869
+ exit_hook_decos = self.flow._flow_decorators.get("exit_hook", [])
2870
+
2871
+ for deco in exit_hook_decos:
2872
+ hooks.extend(self._lifecycle_hook_from_deco(deco))
2873
+
2874
+ # Clean up None values from templates.
2875
+ hooks = list(filter(None, hooks))
2876
+
2877
+ if hooks:
2878
+ hooks.append(
2879
+ ExitHookHack(
2880
+ url=(
2881
+ self.notify_slack_webhook_url
2882
+ or "https://events.pagerduty.com/v2/enqueue"
2883
+ )
2884
+ )
2885
+ )
2886
+ return hooks
2887
+
2888
+ def _lifecycle_hook_from_deco(self, deco):
2889
+ from kubernetes import client as kubernetes_sdk
2890
+
2891
+ start_step = [step for step in self.graph if step.name == "start"][0]
2892
+ # We want to grab the base image used by the start step, as this is known to be pullable from within the cluster,
2893
+ # and it might contain the required libraries, allowing us to start up faster.
2894
+ start_kube_deco = [
2895
+ deco for deco in start_step.decorators if deco.name == "kubernetes"
2896
+ ][0]
2897
+ resources = dict(start_kube_deco.attributes)
2898
+ kube_defaults = dict(start_kube_deco.defaults)
2899
+
2900
+ # OBP Configs
2901
+ additional_obp_configs = {
2902
+ "OBP_PERIMETER": self.initial_configs["OBP_PERIMETER"],
2903
+ "OBP_INTEGRATIONS_URL": self.initial_configs["OBP_INTEGRATIONS_URL"],
2904
+ }
2905
+
2906
+ run_id_template = "argo-{{workflow.name}}"
2907
+ metaflow_version = self.environment.get_environment_info()
2908
+ metaflow_version["flow_name"] = self.graph.name
2909
+ metaflow_version["production_token"] = self.production_token
2910
+ env = {
2911
+ # These values are needed by Metaflow to set it's internal
2912
+ # state appropriately.
2913
+ "METAFLOW_CODE_URL": self.code_package_url,
2914
+ "METAFLOW_CODE_SHA": self.code_package_sha,
2915
+ "METAFLOW_CODE_DS": self.flow_datastore.TYPE,
2916
+ "METAFLOW_SERVICE_URL": SERVICE_INTERNAL_URL,
2917
+ "METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
2918
+ "METAFLOW_USER": "argo-workflows",
2919
+ "METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
2920
+ "METAFLOW_DEFAULT_METADATA": DEFAULT_METADATA,
2921
+ "METAFLOW_OWNER": self.username,
2922
+ }
2923
+ # pass on the Run pathspec for script
2924
+ env["RUN_PATHSPEC"] = f"{self.graph.name}/{run_id_template}"
2925
+
2926
+ # support Metaflow sandboxes
2927
+ env["METAFLOW_INIT_SCRIPT"] = KUBERNETES_SANDBOX_INIT_SCRIPT
2928
+
2929
+ # support fetching secrets
2930
+ env.update(additional_obp_configs)
2931
+
2932
+ env["METAFLOW_WORKFLOW_NAME"] = "{{workflow.name}}"
2933
+ env["METAFLOW_WORKFLOW_NAMESPACE"] = "{{workflow.namespace}}"
2934
+ env = {
2935
+ k: v
2936
+ for k, v in env.items()
2937
+ if v is not None
2938
+ and k not in set(ARGO_WORKFLOWS_ENV_VARS_TO_SKIP.split(","))
2939
+ }
2940
+
2941
+ def _cmd(fn_name):
2942
+ mflog_expr = export_mflog_env_vars(
2943
+ datastore_type=self.flow_datastore.TYPE,
2944
+ stdout_path="$PWD/.logs/mflog_stdout",
2945
+ stderr_path="$PWD/.logs/mflog_stderr",
2946
+ flow_name=self.flow.name,
2947
+ run_id=run_id_template,
2948
+ step_name=f"_hook_{fn_name}",
2949
+ task_id="1",
2950
+ retry_count="0",
2951
+ )
2952
+ cmds = " && ".join(
2953
+ [
2954
+ # For supporting sandboxes, ensure that a custom script is executed
2955
+ # before anything else is executed. The script is passed in as an
2956
+ # env var.
2957
+ '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"}',
2958
+ "mkdir -p $PWD/.logs",
2959
+ mflog_expr,
2960
+ ]
2961
+ + self.environment.get_package_commands(
2962
+ self.code_package_url, self.flow_datastore.TYPE
2963
+ )[:-1]
2964
+ # Replace the line 'Task in starting'
2965
+ + [f"mflog 'Lifecycle hook {fn_name} is starting.'"]
2966
+ + [
2967
+ f"python -m metaflow.plugins.exit_hook.exit_hook_script {metaflow_version['script']} {fn_name} $RUN_PATHSPEC"
2968
+ ]
2969
+ )
2970
+
2971
+ cmds = shlex.split('bash -c "%s"' % cmds)
2972
+ return cmds
2973
+
2974
+ def _container(cmds):
2975
+ return to_camelcase(
2976
+ kubernetes_sdk.V1Container(
2977
+ name="main",
2978
+ command=cmds,
2979
+ image=deco.attributes["options"].get("image", None)
2980
+ or resources["image"],
2981
+ env=[
2982
+ kubernetes_sdk.V1EnvVar(name=k, value=str(v))
2983
+ for k, v in env.items()
2984
+ ],
2985
+ env_from=[
2986
+ kubernetes_sdk.V1EnvFromSource(
2987
+ secret_ref=kubernetes_sdk.V1SecretEnvSource(
2988
+ name=str(k),
2989
+ # optional=True
2990
+ )
2991
+ )
2992
+ for k in list(
2993
+ []
2994
+ if not resources.get("secrets")
2995
+ else (
2996
+ [resources.get("secrets")]
2997
+ if isinstance(resources.get("secrets"), str)
2998
+ else resources.get("secrets")
2999
+ )
3000
+ )
3001
+ + KUBERNETES_SECRETS.split(",")
3002
+ + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
3003
+ if k
3004
+ ],
3005
+ resources=kubernetes_sdk.V1ResourceRequirements(
3006
+ requests={
3007
+ "cpu": str(kube_defaults["cpu"]),
3008
+ "memory": "%sM" % str(kube_defaults["memory"]),
3009
+ }
3010
+ ),
3011
+ ).to_dict()
3012
+ )
3013
+
3014
+ # create lifecycle hooks from deco
3015
+ hooks = []
3016
+ for success_fn_name in deco.success_hooks:
3017
+ hook = ContainerHook(
3018
+ name=f"success-{success_fn_name.replace('_', '-')}",
3019
+ container=_container(cmds=_cmd(success_fn_name)),
3020
+ service_account_name=resources["service_account"],
3021
+ on_success=True,
3022
+ )
3023
+ hooks.append(hook)
3024
+
3025
+ for error_fn_name in deco.error_hooks:
3026
+ hook = ContainerHook(
3027
+ name=f"error-{error_fn_name.replace('_', '-')}",
3028
+ service_account_name=resources["service_account"],
3029
+ container=_container(cmds=_cmd(error_fn_name)),
3030
+ on_error=True,
1693
3031
  )
3032
+ hooks.append(hook)
3033
+
3034
+ return hooks
1694
3035
 
1695
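As a rough illustration of the mapping implemented in _lifecycle_hook_from_deco above: each function listed in the decorator's success_hooks/error_hooks becomes one ContainerHook whose name is derived from the function name. The hook function names below are hypothetical.

# Hypothetical exit_hook attributes, mirroring deco.success_hooks / deco.error_hooks above.
success_hooks = ["notify_owner"]
error_hooks = ["page_oncall"]

hook_names = [f"success-{fn.replace('_', '-')}" for fn in success_hooks] + [
    f"error-{fn.replace('_', '-')}" for fn in error_hooks
]
print(hook_names)  # ['success-notify-owner', 'error-page-oncall']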
- # Return exit hook templates for workflow execution notifications.
  def _exit_hook_templates(self):
  templates = []
- if self.notify_on_error:
- templates.append(self._slack_error_template())
- templates.append(self._pager_duty_alert_template())
- if self.notify_on_success:
- templates.append(self._slack_success_template())
- templates.append(self._pager_duty_change_template())
- if self.notify_on_error or self.notify_on_success:
- # Warning: terrible hack to workaround a bug in Argo Workflow where the
- # templates listed above do not execute unless there is an
- # explicit exit hook. as and when this bug is patched, we should
- # remove this effectively no-op template.
- # Note: We use the Http template because changing this to an actual no-op container had the side-effect of
- # leaving LifecycleHooks in a pending state even when they have finished execution.
- templates.append(
- Template("exit-hook-hack").http(
- Http("GET")
- .url(
- self.notify_slack_webhook_url
- or "https://events.pagerduty.com/v2/enqueue"
- )
- .success_condition("true == true")
- )
- )
+ if self.enable_error_msg_capture:
+ templates.extend(self._error_msg_capture_hook_templates())
+
  return templates

+ def _error_msg_capture_hook_templates(self):
+ from kubernetes import client as kubernetes_sdk
+
+ start_step = [step for step in self.graph if step.name == "start"][0]
+ # We want to grab the base image used by the start step, as this is known to be pullable from within the cluster,
+ # and it might contain the required libraries, allowing us to start up faster.
+ resources = dict(
+ [deco for deco in start_step.decorators if deco.name == "kubernetes"][
+ 0
+ ].attributes
+ )
+
+ run_id_template = "argo-{{workflow.name}}"
+ metaflow_version = self.environment.get_environment_info()
+ metaflow_version["flow_name"] = self.graph.name
+ metaflow_version["production_token"] = self.production_token
+
+ mflog_expr = export_mflog_env_vars(
+ datastore_type=self.flow_datastore.TYPE,
+ stdout_path="$PWD/.logs/mflog_stdout",
+ stderr_path="$PWD/.logs/mflog_stderr",
+ flow_name=self.flow.name,
+ run_id=run_id_template,
+ step_name="_run_capture_error",
+ task_id="1",
+ retry_count="0",
+ )
+
+ cmds = " && ".join(
+ [
+ # For supporting sandboxes, ensure that a custom script is executed
+ # before anything else is executed. The script is passed in as an
+ # env var.
+ '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"}',
+ "mkdir -p $PWD/.logs",
+ mflog_expr,
+ ]
+ + self.environment.get_package_commands(
+ self.code_package_url,
+ self.flow_datastore.TYPE,
+ self.code_package_metadata,
+ )[:-1]
+ # Replace the line 'Task is starting'
+ # FIXME: this can be brittle.
+ + ["mflog 'Error capture hook is starting.'"]
+ + ["argo_error=$(python -m 'metaflow.plugins.argo.capture_error')"]
+ + ["export METAFLOW_ARGO_ERROR=$argo_error"]
+ + [
+ """python -c 'import json, os; error_obj=os.getenv(\\"METAFLOW_ARGO_ERROR\\");data=json.loads(error_obj); print(data[\\"message\\"])'"""
+ ]
+ + [
+ 'if [ -n \\"${METAFLOW_ARGO_WORKFLOWS_CAPTURE_ERROR_SCRIPT}\\" ]; then eval \\"${METAFLOW_ARGO_WORKFLOWS_CAPTURE_ERROR_SCRIPT}\\"; fi'
+ ]
+ )
+
+ # TODO: Also capture the first failed task id
+ cmds = shlex.split('bash -c "%s"' % cmds)
+ env = {
+ # These values are needed by Metaflow to set its internal
+ # state appropriately.
+ "METAFLOW_CODE_METADATA": self.code_package_metadata,
+ "METAFLOW_CODE_URL": self.code_package_url,
+ "METAFLOW_CODE_SHA": self.code_package_sha,
+ "METAFLOW_CODE_DS": self.flow_datastore.TYPE,
+ "METAFLOW_SERVICE_URL": SERVICE_INTERNAL_URL,
+ "METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
+ "METAFLOW_USER": "argo-workflows",
+ "METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
+ "METAFLOW_DEFAULT_METADATA": DEFAULT_METADATA,
+ "METAFLOW_OWNER": self.username,
+ }
+ # support Metaflow sandboxes
+ env["METAFLOW_INIT_SCRIPT"] = KUBERNETES_SANDBOX_INIT_SCRIPT
+ env["METAFLOW_ARGO_WORKFLOWS_CAPTURE_ERROR_SCRIPT"] = (
+ ARGO_WORKFLOWS_CAPTURE_ERROR_SCRIPT
+ )
+
+ env["METAFLOW_WORKFLOW_NAME"] = "{{workflow.name}}"
+ env["METAFLOW_WORKFLOW_NAMESPACE"] = "{{workflow.namespace}}"
+ env["METAFLOW_ARGO_WORKFLOW_FAILURES"] = "{{workflow.failures}}"
+ env = {
+ k: v
+ for k, v in env.items()
+ if v is not None
+ and k not in set(ARGO_WORKFLOWS_ENV_VARS_TO_SKIP.split(","))
+ }
+ return [
+ Template("error-msg-capture-hook")
+ .service_account_name(resources["service_account"])
+ .container(
+ to_camelcase(
+ kubernetes_sdk.V1Container(
+ name="main",
+ command=cmds,
+ image=resources["image"],
+ env=[
+ kubernetes_sdk.V1EnvVar(name=k, value=str(v))
+ for k, v in env.items()
+ ],
+ env_from=[
+ kubernetes_sdk.V1EnvFromSource(
+ secret_ref=kubernetes_sdk.V1SecretEnvSource(
+ name=str(k),
+ # optional=True
+ )
+ )
+ for k in list(
+ []
+ if not resources.get("secrets")
+ else (
+ [resources.get("secrets")]
+ if isinstance(resources.get("secrets"), str)
+ else resources.get("secrets")
+ )
+ )
+ + KUBERNETES_SECRETS.split(",")
+ + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
+ if k
+ ],
+ resources=kubernetes_sdk.V1ResourceRequirements(
+ # NOTE: base resources for this are kept to a minimum to save on running costs.
+ # This has an adverse effect on startup time for the daemon, which can be completely
+ # alleviated by using a base image that has the required dependencies pre-installed
+ requests={
+ "cpu": "200m",
+ "memory": "100Mi",
+ },
+ limits={
+ "cpu": "200m",
+ "memory": "500Mi",
+ },
+ ),
+ ).to_dict()
+ )
+ ),
+ Template("capture-error-hook-fn-preflight").steps(
+ [
+ WorkflowStep()
+ .name("capture-error-hook-fn-preflight")
+ .template("error-msg-capture-hook")
+ .when("{{workflow.status}} != Succeeded")
+ ]
+ ),
+ ]
+
  def _pager_duty_alert_template(self):
  # https://developer.pagerduty.com/docs/ZG9jOjExMDI5NTgx-send-an-alert-event
  if self.notify_pager_duty_integration_key is None:
  return None
- return Template("notify-pager-duty-on-error").http(
- Http("POST")
- .url("https://events.pagerduty.com/v2/enqueue")
- .header("Content-Type", "application/json")
- .body(
- json.dumps(
- {
- "event_action": "trigger",
- "routing_key": self.notify_pager_duty_integration_key,
- # "dedup_key": self.flow.name, # TODO: Do we need deduplication?
- "payload": {
- "source": "{{workflow.name}}",
- "severity": "info",
- "summary": "Metaflow run %s/argo-{{workflow.name}} failed!"
- % self.flow.name,
- "custom_details": {
- "Flow": self.flow.name,
- "Run ID": "argo-{{workflow.name}}",
- },
+ return HttpExitHook(
+ name="notify-pager-duty-on-error",
+ method="POST",
+ url="https://events.pagerduty.com/v2/enqueue",
+ headers={"Content-Type": "application/json"},
+ body=json.dumps(
+ {
+ "event_action": "trigger",
+ "routing_key": self.notify_pager_duty_integration_key,
+ # "dedup_key": self.flow.name, # TODO: Do we need deduplication?
+ "payload": {
+ "source": "{{workflow.name}}",
+ "severity": "info",
+ "summary": "Metaflow run %s/argo-{{workflow.name}} failed!"
+ % self.flow.name,
+ "custom_details": {
+ "Flow": self.flow.name,
+ "Run ID": "argo-{{workflow.name}}",
  },
- "links": self._pager_duty_notification_links(),
- }
- )
+ },
+ "links": self._pager_duty_notification_links(),
+ }
+ ),
+ on_error=True,
+ )
+
+ def _incident_io_alert_template(self):
+ if self.notify_incident_io_api_key is None:
+ return None
+ if self.incident_io_alert_source_config_id is None:
+ raise MetaflowException(
+ "Creating alerts for errors requires an alert source config ID."
+ )
+ ui_links = self._incident_io_ui_urls_for_run()
+ return HttpExitHook(
+ name="notify-incident-io-on-error",
+ method="POST",
+ url=(
+ "https://api.incident.io/v2/alert_events/http/%s"
+ % self.incident_io_alert_source_config_id
+ ),
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": "Bearer %s" % self.notify_incident_io_api_key,
+ },
+ body=json.dumps(
+ {
+ "idempotency_key": "argo-{{workflow.name}}", # use run id to deduplicate alerts.
+ "status": "firing",
+ "title": "Flow %s has failed." % self.flow.name,
+ "description": "Metaflow run {run_pathspec} failed!{urls}".format(
+ run_pathspec="%s/argo-{{workflow.name}}" % self.flow.name,
+ urls=(
+ "\n\nSee details for the run at:\n\n"
+ + "\n\n".join(ui_links)
+ if ui_links
+ else ""
+ ),
+ ),
+ "source_url": (
+ "%s/%s/%s"
+ % (
+ UI_URL.rstrip("/"),
+ self.flow.name,
+ "argo-{{workflow.name}}",
+ )
+ if UI_URL
+ else None
+ ),
+ "metadata": {
+ **(self.incident_io_metadata or {}),
+ **{
+ "run_status": "failed",
+ "flow_name": self.flow.name,
+ "run_id": "argo-{{workflow.name}}",
+ },
+ },
+ }
+ ),
+ on_error=True,
+ )
+
+ def _incident_io_change_template(self):
+ if self.notify_incident_io_api_key is None:
+ return None
+ if self.incident_io_alert_source_config_id is None:
+ raise MetaflowException(
+ "Creating alerts for successes requires an alert source config ID."
  )
+ ui_links = self._incident_io_ui_urls_for_run()
+ return HttpExitHook(
+ name="notify-incident-io-on-success",
+ method="POST",
+ url=(
+ "https://api.incident.io/v2/alert_events/http/%s"
+ % self.incident_io_alert_source_config_id
+ ),
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": "Bearer %s" % self.notify_incident_io_api_key,
+ },
+ body=json.dumps(
+ {
+ "idempotency_key": "argo-{{workflow.name}}", # use run id to deduplicate alerts.
+ "status": "firing",
+ "title": "Flow %s has succeeded." % self.flow.name,
+ "description": "Metaflow run {run_pathspec} succeeded!{urls}".format(
+ run_pathspec="%s/argo-{{workflow.name}}" % self.flow.name,
+ urls=(
+ "\n\nSee details for the run at:\n\n"
+ + "\n\n".join(ui_links)
+ if ui_links
+ else ""
+ ),
+ ),
+ "source_url": (
+ "%s/%s/%s"
+ % (
+ UI_URL.rstrip("/"),
+ self.flow.name,
+ "argo-{{workflow.name}}",
+ )
+ if UI_URL
+ else None
+ ),
+ "metadata": {
+ **(self.incident_io_metadata or {}),
+ **{
+ "run_status": "succeeded",
+ "flow_name": self.flow.name,
+ "run_id": "argo-{{workflow.name}}",
+ },
+ },
+ }
+ ),
+ on_success=True,
  )

+ def _incident_io_ui_urls_for_run(self):
+ links = []
+ if UI_URL:
+ url = "[Metaflow UI](%s/%s/%s)" % (
+ UI_URL.rstrip("/"),
+ self.flow.name,
+ "argo-{{workflow.name}}",
+ )
+ links.append(url)
+ if ARGO_WORKFLOWS_UI_URL:
+ url = "[Argo UI](%s/workflows/%s/%s)" % (
+ ARGO_WORKFLOWS_UI_URL.rstrip("/"),
+ "{{workflow.namespace}}",
+ "{{workflow.name}}",
+ )
+ links.append(url)
+ return links
+
  def _pager_duty_change_template(self):
  # https://developer.pagerduty.com/docs/ZG9jOjExMDI5NTgy-send-a-change-event
  if self.notify_pager_duty_integration_key is None:
  return None
- return Template("notify-pager-duty-on-success").http(
- Http("POST")
- .url("https://events.pagerduty.com/v2/change/enqueue")
- .header("Content-Type", "application/json")
- .body(
- json.dumps(
- {
- "routing_key": self.notify_pager_duty_integration_key,
- "payload": {
- "summary": "Metaflow run %s/argo-{{workflow.name}} Succeeded"
- % self.flow.name,
- "source": "{{workflow.name}}",
- "custom_details": {
- "Flow": self.flow.name,
- "Run ID": "argo-{{workflow.name}}",
- },
+ return HttpExitHook(
+ name="notify-pager-duty-on-success",
+ method="POST",
+ url="https://events.pagerduty.com/v2/change/enqueue",
+ headers={"Content-Type": "application/json"},
+ body=json.dumps(
+ {
+ "routing_key": self.notify_pager_duty_integration_key,
+ "payload": {
+ "summary": "Metaflow run %s/argo-{{workflow.name}} Succeeded"
+ % self.flow.name,
+ "source": "{{workflow.name}}",
+ "custom_details": {
+ "Flow": self.flow.name,
+ "Run ID": "argo-{{workflow.name}}",
  },
- "links": self._pager_duty_notification_links(),
- }
- )
- )
+ },
+ "links": self._pager_duty_notification_links(),
+ }
+ ),
+ on_success=True,
  )

  def _pager_duty_notification_links(self):
  links = []
  if UI_URL:
+ if PAGERDUTY_TEMPLATE_URL:
+ pdproject = ""
+ pdbranch = ""
+ if getattr(current, "project_name", None):
+ pdproject = current.project_name
+ pdbranch = current.branch_name
+ href_val = PAGERDUTY_TEMPLATE_URL.format(
+ pd_flow=self.flow.name,
+ pd_namespace=KUBERNETES_NAMESPACE,
+ pd_template=self.name,
+ pd_project=pdproject,
+ pd_branch=pdbranch,
+ )
+ else:
+ href_val = "%s/%s/%s" % (
+ UI_URL.rstrip("/"),
+ self.flow.name,
+ "argo-{{workflow.name}}",
+ )
  links.append(
  {
- "href": "%s/%s/%s"
- % (UI_URL.rstrip("/"), self.flow.name, "argo-{{workflow.name}}"),
+ "href": href_val,
  "text": "Metaflow UI",
  }
  )
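The PAGERDUTY_TEMPLATE_URL branch above fills named placeholders with str.format; a small sketch with an assumed (hypothetical) template value:

# Hypothetical template URL; the pd_* placeholder names are the ones used in the diff above.
PAGERDUTY_TEMPLATE_URL = (
    "https://ui.example.com/flows/{pd_flow}?ns={pd_namespace}"
    "&template={pd_template}&project={pd_project}&branch={pd_branch}"
)
href_val = PAGERDUTY_TEMPLATE_URL.format(
    pd_flow="HelloFlow",
    pd_namespace="default",
    pd_template="helloflow",
    pd_project="",
    pd_branch="",
)
print(href_val)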
@@ -1807,7 +3420,7 @@ class ArgoWorkflows(object):
  Use Slack's Block Kit to add general information about the environment and
  execution metadata, including a link to the UI and an optional message.
  """
- ui_link = "%s%s/argo-{{workflow.name}}" % (UI_URL, self.flow.name)
+ ui_link = "%s/%s/argo-{{workflow.name}}" % (UI_URL.rstrip("/"), self.flow.name)
  # fmt: off
  if getattr(current, "project_name", None):
  # Add @project metadata when available.
@@ -1815,12 +3428,12 @@ class ArgoWorkflows(object):
  "type": "section",
  "text": {
  "type": "mrkdwn",
- "text": ":metaflow: Environment details"
+ "text": "Environment details"
  },
  "fields": [
  {
  "type": "mrkdwn",
- "text": "*Project:* %s" % current.project_name
+ "text": "*Project:* %s" % current.project_name
  },
  {
  "type": "mrkdwn",
@@ -1833,7 +3446,7 @@ class ArgoWorkflows(object):
  "type": "section",
  "text": {
  "type": "mrkdwn",
- "text": ":metaflow: Environment details"
+ "text": "Environment details"
  }
  }

@@ -1878,8 +3491,12 @@ class ArgoWorkflows(object):
  blocks = self._get_slack_blocks(message)
  payload = {"text": message, "blocks": blocks}

- return Template("notify-slack-on-error").http(
- Http("POST").url(self.notify_slack_webhook_url).body(json.dumps(payload))
+ return HttpExitHook(
+ name="notify-slack-on-error",
+ method="POST",
+ url=self.notify_slack_webhook_url,
+ body=json.dumps(payload),
+ on_error=True,
  )

  def _slack_success_template(self):
@@ -1894,8 +3511,178 @@ class ArgoWorkflows(object):
  blocks = self._get_slack_blocks(message)
  payload = {"text": message, "blocks": blocks}

- return Template("notify-slack-on-success").http(
- Http("POST").url(self.notify_slack_webhook_url).body(json.dumps(payload))
+ return HttpExitHook(
+ name="notify-slack-on-success",
+ method="POST",
+ url=self.notify_slack_webhook_url,
+ body=json.dumps(payload),
+ on_success=True,
+ )
+
+ def _heartbeat_daemon_template(self):
+ # Use all the affordances available to _parameters task
+ executable = self.environment.executable("_parameters")
+ run_id = "argo-{{workflow.name}}"
+ script_name = os.path.basename(sys.argv[0])
+ entrypoint = [executable, script_name]
+ # FlowDecorators can define their own top-level options. These might affect run level information
+ # so it is important to pass these to the heartbeat process as well, as it might be the first task to register a run.
+ top_opts_dict = {}
+ for deco in flow_decorators(self.flow):
+ top_opts_dict.update(deco.get_top_level_options())
+
+ top_level = list(dict_to_cli_options(top_opts_dict)) + [
+ "--quiet",
+ "--metadata=%s" % self.metadata.TYPE,
+ "--environment=%s" % self.environment.TYPE,
+ "--datastore=%s" % self.flow_datastore.TYPE,
+ "--datastore-root=%s" % self.flow_datastore.datastore_root,
+ "--event-logger=%s" % self.event_logger.TYPE,
+ "--monitor=%s" % self.monitor.TYPE,
+ "--no-pylint",
+ "--with=argo_workflows_internal:auto-emit-argo-events=%i"
+ % self.auto_emit_argo_events,
+ ]
+ heartbeat_cmds = "{entrypoint} {top_level} argo-workflows heartbeat --run_id {run_id} {tags}".format(
+ entrypoint=" ".join(entrypoint),
+ top_level=" ".join(top_level) if top_level else "",
+ run_id=run_id,
+ tags=" ".join(["--tag %s" % t for t in self.tags]) if self.tags else "",
+ )
+
+ # TODO: we do not really need MFLOG logging for the daemon at the moment, but might be good for the future.
+ # Consider if we can do without this setup.
+ # Configure log capture.
+ mflog_expr = export_mflog_env_vars(
+ datastore_type=self.flow_datastore.TYPE,
+ stdout_path="$PWD/.logs/mflog_stdout",
+ stderr_path="$PWD/.logs/mflog_stderr",
+ flow_name=self.flow.name,
+ run_id=run_id,
+ step_name="_run_heartbeat_daemon",
+ task_id="1",
+ retry_count="0",
+ )
+ # TODO: Can the init be trimmed down?
+ # Can we do without get_package_commands fetching the whole code package?
+ init_cmds = " && ".join(
+ [
+ # For supporting sandboxes, ensure that a custom script is executed
+ # before anything else is executed. The script is passed in as an
+ # env var.
+ '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"}',
+ "mkdir -p $PWD/.logs",
+ mflog_expr,
+ ]
+ + self.environment.get_package_commands(
+ self.code_package_url,
+ self.flow_datastore.TYPE,
+ )[:-1]
+ # Replace the line 'Task is starting'
+ # FIXME: this can be brittle.
+ + ["mflog 'Heartbeat daemon is starting.'"]
+ )
+
+ cmd_str = " && ".join([init_cmds, heartbeat_cmds])
+ cmds = shlex.split('bash -c "%s"' % cmd_str)
+
+ # Env required for sending heartbeats to the metadata service, nothing extra.
+ # prod token / runtime info is required to correctly register flow branches
+ env = {
+ # These values are needed by Metaflow to set its internal
+ # state appropriately.
+ "METAFLOW_CODE_METADATA": self.code_package_metadata,
+ "METAFLOW_CODE_URL": self.code_package_url,
+ "METAFLOW_CODE_SHA": self.code_package_sha,
+ "METAFLOW_CODE_DS": self.flow_datastore.TYPE,
+ "METAFLOW_SERVICE_URL": SERVICE_INTERNAL_URL,
+ "METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
+ "METAFLOW_USER": "argo-workflows",
+ "METAFLOW_DATASTORE_SYSROOT_S3": DATASTORE_SYSROOT_S3,
+ "METAFLOW_DATATOOLS_S3ROOT": DATATOOLS_S3ROOT,
+ "METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
+ "METAFLOW_DEFAULT_METADATA": DEFAULT_METADATA,
+ "METAFLOW_CARD_S3ROOT": CARD_S3ROOT,
+ "METAFLOW_KUBERNETES_WORKLOAD": 1,
+ "METAFLOW_KUBERNETES_FETCH_EC2_METADATA": KUBERNETES_FETCH_EC2_METADATA,
+ "METAFLOW_RUNTIME_ENVIRONMENT": "kubernetes",
+ "METAFLOW_OWNER": self.username,
+ "METAFLOW_PRODUCTION_TOKEN": self.production_token, # Used in identity resolving. This affects system tags.
+ }
+ # support Metaflow sandboxes
+ env["METAFLOW_INIT_SCRIPT"] = KUBERNETES_SANDBOX_INIT_SCRIPT
+
+ # cleanup env values
+ env = {
+ k: v
+ for k, v in env.items()
+ if v is not None
+ and k not in set(ARGO_WORKFLOWS_ENV_VARS_TO_SKIP.split(","))
+ }
+
+ # We want to grab the base image used by the start step, as this is known to be pullable from within the cluster,
+ # and it might contain the required libraries, allowing us to start up faster.
+ start_step = next(step for step in self.flow if step.name == "start")
+ resources = dict(
+ [deco for deco in start_step.decorators if deco.name == "kubernetes"][
+ 0
+ ].attributes
+ )
+ from kubernetes import client as kubernetes_sdk
+
+ return (
+ DaemonTemplate("heartbeat-daemon")
+ # NOTE: Even though a retry strategy does not work for Argo daemon containers,
+ # this has the side-effect of protecting the exit hooks of the workflow from failing in case the daemon container errors out.
+ .retry_strategy(10, 1)
+ .service_account_name(resources["service_account"])
+ .container(
+ to_camelcase(
+ kubernetes_sdk.V1Container(
+ name="main",
+ # TODO: Make the image configurable
+ image=resources["image"],
+ command=cmds,
+ env=[
+ kubernetes_sdk.V1EnvVar(name=k, value=str(v))
+ for k, v in env.items()
+ ],
+ env_from=[
+ kubernetes_sdk.V1EnvFromSource(
+ secret_ref=kubernetes_sdk.V1SecretEnvSource(
+ name=str(k),
+ # optional=True
+ )
+ )
+ for k in list(
+ []
+ if not resources.get("secrets")
+ else (
+ [resources.get("secrets")]
+ if isinstance(resources.get("secrets"), str)
+ else resources.get("secrets")
+ )
+ )
+ + KUBERNETES_SECRETS.split(",")
+ + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
+ if k
+ ],
+ resources=kubernetes_sdk.V1ResourceRequirements(
+ # NOTE: base resources for this are kept to a minimum to save on running costs.
+ # This has an adverse effect on startup time for the daemon, which can be completely
+ # alleviated by using a base image that has the required dependencies pre-installed
+ requests={
+ "cpu": "200m",
+ "memory": "100Mi",
+ },
+ limits={
+ "cpu": "200m",
+ "memory": "100Mi",
+ },
+ ),
+ )
+ ).to_dict()
+ )
  )

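For orientation, roughly what heartbeat_cmds expands to for a hypothetical flow file and tag; the format string is the one used above, while the substituted values are made up:

heartbeat_cmds = "{entrypoint} {top_level} argo-workflows heartbeat --run_id {run_id} {tags}".format(
    entrypoint="python helloflow.py",  # hypothetical executable + script name
    top_level="--quiet --metadata=service --datastore=s3 --no-pylint",  # illustrative subset
    run_id="argo-{{workflow.name}}",
    tags="--tag team:data",
)
# -> python helloflow.py --quiet ... argo-workflows heartbeat --run_id argo-{{workflow.name}} --tag team:data
print(heartbeat_cmds)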
  def _compile_sensor(self):
@@ -1997,44 +3784,16 @@ class ArgoWorkflows(object):
  "sdk (https://pypi.org/project/kubernetes/) first."
  )

- labels = {"app.kubernetes.io/part-of": "metaflow"}
-
- annotations = {
- "metaflow/production_token": self.production_token,
- "metaflow/owner": self.username,
- "metaflow/user": "argo-workflows",
- "metaflow/flow_name": self.flow.name,
- }
- if current.get("project_name"):
- annotations.update(
- {
- "metaflow/project_name": current.project_name,
- "metaflow/branch_name": current.branch_name,
- "metaflow/project_flow_name": current.project_flow_name,
- }
- )
-
- # Useful to paint the UI
- trigger_annotations = {
- "metaflow/triggered_by": json.dumps(
- [
- {key: trigger.get(key) for key in ["name", "type"]}
- for trigger in self.triggers
- ]
- )
- }
-
  return (
  Sensor()
  .metadata(
  # Sensor metadata.
  ObjectMeta()
- .name(self.name.replace(".", "-"))
- .namespace(KUBERNETES_NAMESPACE)
+ .name(ArgoWorkflows._sensor_name(self.name))
+ .namespace(ARGO_EVENTS_SENSOR_NAMESPACE)
+ .labels(self._base_labels)
  .label("app.kubernetes.io/name", "metaflow-sensor")
- .label("app.kubernetes.io/part-of", "metaflow")
- .labels(self.kubernetes_labels)
- .annotations(annotations)
+ .annotations(self._base_annotations)
  )
  .spec(
  SensorSpec().template(
@@ -2044,7 +3803,7 @@ class ArgoWorkflows(object):
  ObjectMeta()
  .label("app.kubernetes.io/name", "metaflow-sensor")
  .label("app.kubernetes.io/part-of", "metaflow")
- .annotations(annotations)
+ .annotations(self._base_annotations)
  )
  .container(
  # Run sensor in guaranteed QoS. The sensor isn't doing a lot
@@ -2064,7 +3823,7 @@ class ArgoWorkflows(object):
  "memory": "250Mi",
  },
  ),
- )
+ ).to_dict()
  )
  )
  .service_account_name(ARGO_EVENTS_SERVICE_ACCOUNT)
@@ -2081,8 +3840,8 @@ class ArgoWorkflows(object):
  Trigger().template(
  TriggerTemplate(self.name)
  # Trigger a deployed workflow template
- .argo_workflow_trigger(
- ArgoWorkflowTrigger()
+ .k8s_trigger(
+ StandardK8STrigger()
  .source(
  {
  "resource": {
@@ -2091,6 +3850,7 @@ class ArgoWorkflows(object):
  "metadata": {
  "generateName": "%s-" % self.name,
  "namespace": KUBERNETES_NAMESPACE,
+ # Useful to paint the UI
  "annotations": {
  "metaflow/triggered_by": json.dumps(
  [
@@ -2139,8 +3899,21 @@ class ArgoWorkflows(object):
  # everything within the body.
  # NOTE: We need the conditional logic in order to successfully fall back to the default value
  # when the event payload does not contain a key for a parameter.
- data_template='{{ if (hasKey $.Input.body.payload "%s") }}{{- (.Input.body.payload.%s | toJson) -}}{{- else -}}{{ (fail "use-default-instead") }}{{- end -}}'
- % (v, v),
+ # NOTE: Keys might contain dashes, so use the safer 'get' for fetching the value
+ data_template='{{ if (hasKey $.Input.body.payload "%s") }}%s{{- else -}}{{ (fail "use-default-instead") }}{{- end -}}'
+ % (
+ v,
+ (
+ '{{- $pv:=(get $.Input.body.payload "%s") -}}{{ if kindIs "string" $pv }}{{- $pv | toRawJson -}}{{- else -}}{{ $pv | toRawJson | toRawJson }}{{- end -}}'
+ % v
+ if self.parameters[
+ parameter_name
+ ]["type"]
+ == "JSON"
+ else '{{- (get $.Input.body.payload "%s" | toRawJson) -}}'
+ % v
+ ),
+ ),
  # Unfortunately the sensor needs to
  # record the default values for
  # the parameters - there doesn't seem
@@ -2351,6 +4124,38 @@ class ObjectMeta(object):
  return json.dumps(self.to_json(), indent=4)


+ class WorkflowStep(object):
+ def __init__(self):
+ tree = lambda: defaultdict(tree)
+ self.payload = tree()
+
+ def name(self, name):
+ self.payload["name"] = str(name)
+ return self
+
+ def template(self, template):
+ self.payload["template"] = str(template)
+ return self
+
+ def arguments(self, arguments):
+ self.payload["arguments"] = arguments.to_json()
+ return self
+
+ def when(self, condition):
+ self.payload["when"] = str(condition)
+ return self
+
+ def step(self, expression):
+ self.payload["expression"] = str(expression)
+ return self
+
+ def to_json(self):
+ return self.payload
+
+ def __str__(self):
+ return json.dumps(self.to_json(), indent=4)
+
+
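A sketch of the payload shape that Template(...).steps([WorkflowStep()...]) is expected to produce for the capture-error-hook-fn-preflight template defined earlier in this diff; the dict below is hand-written for illustration (note that "steps" is a list of lists):

# Hand-written approximation of the serialized template payload.
expected = {
    "name": "capture-error-hook-fn-preflight",
    "steps": [
        [
            {
                "name": "capture-error-hook-fn-preflight",
                "template": "error-msg-capture-hook",
                "when": "{{workflow.status}} != Succeeded",
            }
        ]
    ],
}
print(expected)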
  class WorkflowSpec(object):
  # https://argoproj.github.io/argo-workflows/fields/#workflowspec
  # This object sets all Workflow level properties.
@@ -2381,6 +4186,11 @@ class WorkflowSpec(object):
  self.payload["entrypoint"] = entrypoint
  return self

+ def onExit(self, on_exit_template):
+ if on_exit_template:
+ self.payload["onExit"] = on_exit_template
+ return self
+
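A tiny sketch of what the new onExit() setter does to the workflow spec payload, mimicking the builder's defaultdict storage; the "exit-hook-hack" name is the template referenced earlier in the diff, while the entrypoint value is hypothetical:

from collections import defaultdict

tree = lambda: defaultdict(tree)
payload = tree()
payload["entrypoint"] = "mydag"  # hypothetical entrypoint template name
on_exit_template = "exit-hook-hack"
if on_exit_template:  # mirrors WorkflowSpec.onExit above: None is a no-op
    payload["onExit"] = on_exit_template
print(dict(payload))  # {'entrypoint': 'mydag', 'onExit': 'exit-hook-hack'}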
  def parallelism(self, parallelism):
  # Set parallelism at Workflow level
  self.payload["parallelism"] = int(parallelism)
@@ -2469,6 +4279,38 @@ class Metadata(object):
  return json.dumps(self.to_json(), indent=4)


+ class DaemonTemplate(object):
+ def __init__(self, name):
+ tree = lambda: defaultdict(tree)
+ self.name = name
+ self.payload = tree()
+ self.payload["daemon"] = True
+ self.payload["name"] = name
+
+ def container(self, container):
+ self.payload["container"] = container
+ return self
+
+ def service_account_name(self, service_account_name):
+ self.payload["serviceAccountName"] = service_account_name
+ return self
+
+ def retry_strategy(self, times, minutes_between_retries):
+ if times > 0:
+ self.payload["retryStrategy"] = {
+ "retryPolicy": "Always",
+ "limit": times,
+ "backoff": {"duration": "%sm" % minutes_between_retries},
+ }
+ return self
+
+ def to_json(self):
+ return self.payload
+
+ def __str__(self):
+ return json.dumps(self.payload, indent=4)
+
+
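For reference, the retryStrategy payload that DaemonTemplate("heartbeat-daemon").retry_strategy(10, 1) produces per the code above (values are the ones passed by the heartbeat daemon template earlier in this diff):

retry_strategy = {
    "retryPolicy": "Always",
    "limit": 10,
    "backoff": {"duration": "1m"},
}
print(retry_strategy)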
  class Template(object):
  # https://argoproj.github.io/argo-workflows/fields/#template

@@ -2487,6 +4329,18 @@ class Template(object):
  self.payload["dag"] = dag_template.to_json()
  return self

+ def steps(self, steps):
+ if "steps" not in self.payload:
+ self.payload["steps"] = []
+ # steps is a list of lists.
+ # hence we go over every item in the incoming list
+ # serialize it and then append the list to the payload
+ step_list = []
+ for step in steps:
+ step_list.append(step.to_json())
+ self.payload["steps"].append(step_list)
+ return self
+
  def container(self, container):
  # Luckily this can simply be V1Container and we are spared from writing more
  # boilerplate - https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1Container.md.
@@ -2579,6 +4433,14 @@ class Template(object):
  )
  return self

+ def pod_spec_patch(self, pod_spec_patch=None):
+ if pod_spec_patch is None:
+ return self
+
+ self.payload["podSpecPatch"] = json.dumps(pod_spec_patch)
+
+ return self
+
  def node_selectors(self, node_selectors):
  if "nodeSelector" not in self.payload:
  self.payload["nodeSelector"] = {}
@@ -2593,6 +4455,15 @@ class Template(object):
  def to_json(self):
  return self.payload

+ def resource(self, action, manifest, success_criteria, failure_criteria):
+ self.payload["resource"] = {}
+ self.payload["resource"]["action"] = action
+ self.payload["resource"]["setOwnerReference"] = True
+ self.payload["resource"]["successCondition"] = success_criteria
+ self.payload["resource"]["failureCondition"] = failure_criteria
+ self.payload["resource"]["manifest"] = manifest
+ return self
+
  def __str__(self):
  return json.dumps(self.payload, indent=4)

@@ -2712,6 +4583,10 @@ class DAGTask(object):
  self.payload["dependencies"] = dependencies
  return self

+ def depends(self, depends: str):
+ self.payload["depends"] = depends
+ return self
+
  def template(self, template):
  # Template reference
  self.payload["template"] = template
@@ -2723,6 +4598,10 @@ class DAGTask(object):
  self.payload["inline"] = template.to_json()
  return self

+ def when(self, when: str):
+ self.payload["when"] = when
+ return self
+
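The new depends()/when() setters forward Argo's enhanced-dependency and conditional-execution expressions verbatim; the expressions below are hypothetical examples of the strings a caller might pass:

# Hypothetical expressions; Argo evaluates them, these setters just store the strings.
depends_expr = "train.Succeeded && validate.Succeeded"
when_expr = "{{workflow.parameters.deploy}} == true"
payload = {"depends": depends_expr, "when": when_expr}
print(payload)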
  def with_param(self, with_param):
  self.payload["withParam"] = with_param
  return self
@@ -2942,6 +4821,10 @@ class TriggerTemplate(object):
  self.payload = tree()
  self.payload["name"] = name

+ def k8s_trigger(self, k8s_trigger):
+ self.payload["k8s"] = k8s_trigger.to_json()
+ return self
+
  def argo_workflow_trigger(self, argo_workflow_trigger):
  self.payload["argoWorkflow"] = argo_workflow_trigger.to_json()
  return self
@@ -3018,51 +4901,51 @@ class TriggerParameter(object):
  return json.dumps(self.payload, indent=4)


- class Http(object):
- # https://argoproj.github.io/argo-workflows/fields/#http
+ class StandardK8STrigger(object):
+ # https://pkg.go.dev/github.com/argoproj/argo-events/pkg/apis/sensor/v1alpha1#StandardK8STrigger

- def __init__(self, method):
+ def __init__(self):
  tree = lambda: defaultdict(tree)
  self.payload = tree()
- self.payload["method"] = method
- self.payload["headers"] = []
+ self.payload["operation"] = "create"

- def header(self, header, value):
- self.payload["headers"].append({"name": header, "value": value})
+ def operation(self, operation):
+ self.payload["operation"] = operation
  return self

- def body(self, body):
- self.payload["body"] = str(body)
+ def group(self, group):
+ self.payload["group"] = group
  return self

- def url(self, url):
- self.payload["url"] = url
+ def version(self, version):
+ self.payload["version"] = version
  return self

- def success_condition(self, success_condition):
- self.payload["successCondition"] = success_condition
+ def resource(self, resource):
+ self.payload["resource"] = resource
  return self

- def to_json(self):
- return self.payload
-
- def __str__(self):
- return json.dumps(self.payload, indent=4)
-
+ def namespace(self, namespace):
+ self.payload["namespace"] = namespace
+ return self

- class LifecycleHook(object):
- # https://argoproj.github.io/argo-workflows/fields/#lifecyclehook
+ def source(self, source):
+ self.payload["source"] = source
+ return self

- def __init__(self):
- tree = lambda: defaultdict(tree)
- self.payload = tree()
+ def parameters(self, trigger_parameters):
+ if "parameters" not in self.payload:
+ self.payload["parameters"] = []
+ for trigger_parameter in trigger_parameters:
+ self.payload["parameters"].append(trigger_parameter.to_json())
+ return self

- def expression(self, expression):
- self.payload["expression"] = str(expression)
+ def live_object(self, live_object=True):
+ self.payload["liveObject"] = live_object
  return self

- def template(self, template):
- self.payload["template"] = template
+ def patch_strategy(self, patch_strategy):
+ self.payload["patchStrategy"] = patch_strategy
  return self

  def to_json(self):