metaflow 2.12.10__py2.py3-none-any.whl → 2.12.11__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. metaflow/client/core.py +6 -6
  2. metaflow/client/filecache.py +16 -3
  3. metaflow/cmd/develop/stub_generator.py +62 -47
  4. metaflow/datastore/content_addressed_store.py +1 -1
  5. metaflow/datastore/task_datastore.py +1 -1
  6. metaflow/decorators.py +2 -4
  7. metaflow/extension_support/__init__.py +3 -3
  8. metaflow/extension_support/plugins.py +3 -3
  9. metaflow/metaflow_config.py +35 -18
  10. metaflow/parameters.py +3 -3
  11. metaflow/plugins/airflow/airflow.py +6 -6
  12. metaflow/plugins/airflow/airflow_utils.py +5 -3
  13. metaflow/plugins/argo/argo_workflows.py +407 -193
  14. metaflow/plugins/argo/argo_workflows_cli.py +17 -4
  15. metaflow/plugins/argo/argo_workflows_decorator.py +6 -13
  16. metaflow/plugins/argo/capture_error.py +70 -0
  17. metaflow/plugins/aws/step_functions/step_functions.py +3 -3
  18. metaflow/plugins/cards/card_modules/basic.py +5 -3
  19. metaflow/plugins/cards/card_modules/convert_to_native_type.py +2 -2
  20. metaflow/plugins/cards/card_modules/renderer_tools.py +1 -0
  21. metaflow/plugins/cards/card_modules/test_cards.py +0 -2
  22. metaflow/plugins/datatools/s3/s3op.py +5 -3
  23. metaflow/plugins/kubernetes/kubernetes.py +1 -0
  24. metaflow/plugins/kubernetes/kubernetes_job.py +10 -8
  25. metaflow/plugins/kubernetes/kubernetes_jobsets.py +15 -14
  26. metaflow/plugins/logs_cli.py +1 -0
  27. metaflow/plugins/pypi/conda_environment.py +1 -3
  28. metaflow/plugins/pypi/pip.py +3 -3
  29. metaflow/plugins/tag_cli.py +3 -3
  30. metaflow/procpoll.py +1 -1
  31. metaflow/runtime.py +1 -0
  32. metaflow/util.py +6 -6
  33. metaflow/version.py +1 -1
  34. {metaflow-2.12.10.dist-info → metaflow-2.12.11.dist-info}/METADATA +2 -2
  35. {metaflow-2.12.10.dist-info → metaflow-2.12.11.dist-info}/RECORD +39 -38
  36. {metaflow-2.12.10.dist-info → metaflow-2.12.11.dist-info}/LICENSE +0 -0
  37. {metaflow-2.12.10.dist-info → metaflow-2.12.11.dist-info}/WHEEL +0 -0
  38. {metaflow-2.12.10.dist-info → metaflow-2.12.11.dist-info}/entry_points.txt +0 -0
  39. {metaflow-2.12.10.dist-info → metaflow-2.12.11.dist-info}/top_level.txt +0 -0
@@ -4,15 +4,15 @@ import os
  import re
  import shlex
  import sys
- from typing import Tuple, List
  from collections import defaultdict
  from hashlib import sha1
  from math import inf
+ from typing import List, Tuple

  from metaflow import JSONType, current
- from metaflow.graph import DAGNode
  from metaflow.decorators import flow_decorators
  from metaflow.exception import MetaflowException
+ from metaflow.graph import DAGNode, FlowGraph
  from metaflow.includefile import FilePathClass
  from metaflow.metaflow_config import (
  ARGO_EVENTS_EVENT,
@@ -21,10 +21,12 @@ from metaflow.metaflow_config import (
  ARGO_EVENTS_INTERNAL_WEBHOOK_URL,
  ARGO_EVENTS_SERVICE_ACCOUNT,
  ARGO_EVENTS_WEBHOOK_AUTH,
+ ARGO_WORKFLOWS_CAPTURE_ERROR_SCRIPT,
  ARGO_WORKFLOWS_ENV_VARS_TO_SKIP,
  ARGO_WORKFLOWS_KUBERNETES_SECRETS,
  ARGO_WORKFLOWS_UI_URL,
  AWS_SECRETS_MANAGER_DEFAULT_REGION,
+ AZURE_KEY_VAULT_PREFIX,
  AZURE_STORAGE_BLOB_SERVICE_ENDPOINT,
  CARD_AZUREROOT,
  CARD_GSROOT,
@@ -36,7 +38,6 @@ from metaflow.metaflow_config import (
  DEFAULT_METADATA,
  DEFAULT_SECRETS_BACKEND_TYPE,
  GCP_SECRET_MANAGER_PREFIX,
- AZURE_KEY_VAULT_PREFIX,
  KUBERNETES_FETCH_EC2_METADATA,
  KUBERNETES_LABELS,
  KUBERNETES_NAMESPACE,
@@ -49,7 +50,6 @@ from metaflow.metaflow_config import (
  SERVICE_INTERNAL_URL,
  UI_URL,
  )
- from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
  from metaflow.metaflow_config_funcs import config_values
  from metaflow.mflog import BASH_SAVE_LOGS, bash_capture_logs, export_mflog_env_vars
  from metaflow.parameters import deploy_time_eval
@@ -57,7 +57,8 @@ from metaflow.plugins.kubernetes.kubernetes import (
  parse_kube_keyvalue_list,
  validate_kube_labels,
  )
- from metaflow.graph import FlowGraph
+ from metaflow.plugins.kubernetes.kubernetes_jobsets import KubernetesArgoJobSet
+ from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
  from metaflow.util import (
  compress_list,
  dict_to_cli_options,
@@ -65,9 +66,6 @@ from metaflow.util import (
  to_camelcase,
  to_unicode,
  )
- from metaflow.plugins.kubernetes.kubernetes_jobsets import (
- KubernetesArgoJobSet,
- )

  from .argo_client import ArgoClient

@@ -118,6 +116,7 @@ class ArgoWorkflows(object):
  notify_slack_webhook_url=None,
  notify_pager_duty_integration_key=None,
  enable_heartbeat_daemon=True,
+ enable_error_msg_capture=False,
  ):
  # Some high-level notes -
  #
@@ -166,7 +165,7 @@ class ArgoWorkflows(object):
  self.notify_slack_webhook_url = notify_slack_webhook_url
  self.notify_pager_duty_integration_key = notify_pager_duty_integration_key
  self.enable_heartbeat_daemon = enable_heartbeat_daemon
-
+ self.enable_error_msg_capture = enable_error_msg_capture
  self.parameters = self._process_parameters()
  self.triggers, self.trigger_options = self._process_triggers()
  self._schedule, self._timezone = self._get_schedule()
@@ -786,6 +785,12 @@ class ArgoWorkflows(object):
  )
  # Set the entrypoint to flow name
  .entrypoint(self.flow.name)
+ # OnExit hooks
+ .onExit(
+ "capture-error-hook-fn-preflight"
+ if self.enable_error_msg_capture
+ else None
+ )
  # Set exit hook handlers if notifications are enabled
  .hooks(
  {
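For illustration only: the onExit wiring above relies on WorkflowSpec.onExit (added further down in this diff) setting the payload field only when a template name is actually passed. A minimal sketch, where `spec` and `enabled` are illustrative names rather than variables from the source:

    # Hedged sketch of the builder behavior; `spec` and `enabled` are illustrative.
    spec = WorkflowSpec().entrypoint("MyFlow")
    enabled = True
    spec.onExit("capture-error-hook-fn-preflight" if enabled else None)
    # With enabled=True, spec.payload["onExit"] == "capture-error-hook-fn-preflight".
    # With enabled=False, onExit(None) is a no-op and the field is never emitted.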
@@ -1063,7 +1068,7 @@ class ArgoWorkflows(object):
  "%s-foreach-%s"
  % (
  node.name,
- "parallel" if node.parallel_foreach else node.foreach_param
+ "parallel" if node.parallel_foreach else node.foreach_param,
  # Since foreach's are derived based on `self.next(self.a, foreach="<varname>")`
  # vs @parallel foreach are done based on `self.next(self.a, num_parallel="<some-number>")`,
  # we need to ensure that `foreach_template_name` suffix is appropriately set based on the kind
@@ -1360,7 +1365,7 @@ class ArgoWorkflows(object):
  task_str = "-".join(
  [
  "$TASK_ID_PREFIX",
- "{{inputs.parameters.task-id-entropy}}", # id_base is addition entropy to based on node-name of the workflow
+ "{{inputs.parameters.task-id-entropy}}",
  "$TASK_ID_SUFFIX",
  ]
  )
@@ -1391,8 +1396,6 @@ class ArgoWorkflows(object):
  user_code_retries = max_user_code_retries
  total_retries = max_user_code_retries + max_error_retries
  # {{retries}} is only available if retryStrategy is specified
- # and they are only available in the container templates NOT for custom
- # Kubernetes manifests like Jobsets.
  # For custom kubernetes manifests, we will pass the retryCount as a parameter
  # and use that in the manifest.
  retry_count = (
@@ -1519,8 +1522,7 @@ class ArgoWorkflows(object):
  )
  )
  else:
- # When we run Jobsets with Argo Workflows we need to ensure that `input_paths` are generated using the a formulaic approach
- # because our current strategy of using volume mounts for outputs won't work with Jobsets
+ # Handle @parallel where output from volume mount isn't accessible
  input_paths = (
  "$(python -m metaflow.plugins.argo.jobset_input_paths %s %s {{inputs.parameters.task-id-entropy}} {{inputs.parameters.num-parallel}})"
  % (
@@ -1659,16 +1661,16 @@ class ArgoWorkflows(object):

  # support for @secret
  env["METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE"] = DEFAULT_SECRETS_BACKEND_TYPE
- env[
- "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION"
- ] = AWS_SECRETS_MANAGER_DEFAULT_REGION
+ env["METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION"] = (
+ AWS_SECRETS_MANAGER_DEFAULT_REGION
+ )
  env["METAFLOW_GCP_SECRET_MANAGER_PREFIX"] = GCP_SECRET_MANAGER_PREFIX
  env["METAFLOW_AZURE_KEY_VAULT_PREFIX"] = AZURE_KEY_VAULT_PREFIX

  # support for Azure
- env[
- "METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT"
- ] = AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
+ env["METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT"] = (
+ AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
+ )
  env["METAFLOW_DATASTORE_SYSROOT_AZURE"] = DATASTORE_SYSROOT_AZURE
  env["METAFLOW_CARD_AZUREROOT"] = CARD_AZUREROOT

@@ -1733,9 +1735,7 @@ class ArgoWorkflows(object):
  else:
  # append this only for joins of foreaches, not static splits
  inputs.append(Parameter("split-cardinality"))
- # We can use an `elif` condition because the first `if` condition validates if its
- # a foreach join node, hence we can safely assume that if that condition fails then
- # we can check if the node is a @parallel node.
+ # check if the node is a @parallel node.
  elif node.parallel_step:
  inputs.extend(
  [
@@ -1790,7 +1790,7 @@ class ArgoWorkflows(object):
  ),
  ]
  )
- # Outputs should be defined over here, Not in the _dag_template for the `num_parallel` stuff.
+ # Outputs should be defined over here and not in the _dag_template for @parallel.

  # It makes no sense to set env vars to None (shows up as "None" string)
  # Also we skip some env vars (e.g. in case we want to pull them from KUBERNETES_SECRETS)
@@ -1817,20 +1817,20 @@ class ArgoWorkflows(object):

  if tmpfs_enabled and tmpfs_tempdir:
  env["METAFLOW_TEMPDIR"] = tmpfs_path
+
  # Create a ContainerTemplate for this node. Ideally, we would have
  # liked to inline this ContainerTemplate and avoid scanning the workflow
  # twice, but due to issues with variable substitution, we will have to
  # live with this routine.
  if node.parallel_step:
-
  # Explicitly add the task-id-hint label. This is important because this label
- # is returned as an Output parameter of this step and is used subsequently an
- # an input in the join step. Even the num_parallel is used as an output parameter
+ # is returned as an Output parameter of this step and is used subsequently as an
+ # an input in the join step.
  kubernetes_labels = self.kubernetes_labels.copy()
  jobset_name = "{{inputs.parameters.jobset-name}}"
- kubernetes_labels[
- "task_id_entropy"
- ] = "{{inputs.parameters.task-id-entropy}}"
+ kubernetes_labels["task_id_entropy"] = (
+ "{{inputs.parameters.task-id-entropy}}"
+ )
  kubernetes_labels["num_parallel"] = "{{inputs.parameters.num-parallel}}"
  jobset = KubernetesArgoJobSet(
  kubernetes_sdk=kubernetes_sdk,
@@ -1854,9 +1854,11 @@ class ArgoWorkflows(object):
  list(
  []
  if not resources.get("secrets")
- else [resources.get("secrets")]
- if isinstance(resources.get("secrets"), str)
- else resources.get("secrets")
+ else (
+ [resources.get("secrets")]
+ if isinstance(resources.get("secrets"), str)
+ else resources.get("secrets")
+ )
  )
  + KUBERNETES_SECRETS.split(",")
  + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
@@ -1887,7 +1889,6 @@ class ArgoWorkflows(object):
  for k, v in kubernetes_labels.items():
  jobset.label(k, v)

- ## -----Jobset specific env vars START here-----
  jobset.environment_variable(
  "MF_MASTER_ADDR", jobset.jobset_control_addr
  )
@@ -1906,7 +1907,6 @@ class ArgoWorkflows(object):
  "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
  "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
  "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
- # `TASK_ID_SUFFIX` is needed for the construction of the task-ids
  "TASK_ID_SUFFIX": "metadata.annotations['jobset.sigs.k8s.io/job-index']",
  }
  )
@@ -1931,8 +1931,7 @@ class ArgoWorkflows(object):
  )
  for k, v in annotations.items():
  jobset.annotation(k, v)
- ## -----Jobset specific env vars END here-----
- ## ---- Jobset control/workers specific vars START here ----
+
  jobset.control.replicas(1)
  jobset.worker.replicas("{{=asInt(inputs.parameters.workerCount)}}")
  jobset.control.environment_variable("UBF_CONTEXT", UBF_CONTROL)
@@ -1943,7 +1942,6 @@ class ArgoWorkflows(object):
  jobset.control.environment_variable("TASK_ID_PREFIX", "control")
  jobset.worker.environment_variable("TASK_ID_PREFIX", "worker")

- ## ---- Jobset control/workers specific vars END here ----
  yield (
  Template(ArgoWorkflows._sanitize(node.name))
  .resource(
@@ -1970,169 +1968,178 @@ class ArgoWorkflows(object):
  minutes_between_retries=minutes_between_retries,
  )
  )
- continue
- yield (
- Template(self._sanitize(node.name))
- # Set @timeout values
- .active_deadline_seconds(run_time_limit)
- # Set service account
- .service_account_name(resources["service_account"])
- # Configure template input
- .inputs(Inputs().parameters(inputs))
- # Configure template output
- .outputs(Outputs().parameters(outputs))
- # Fail fast!
- .fail_fast()
- # Set @retry/@catch values
- .retry_strategy(
- times=total_retries,
- minutes_between_retries=minutes_between_retries,
- )
- .metadata(
- ObjectMeta().annotation("metaflow/step_name", node.name)
- # Unfortunately, we can't set the task_id since it is generated
- # inside the pod. However, it can be inferred from the annotation
- # set by argo-workflows - `workflows.argoproj.io/outputs` - refer
- # the field 'task-id' in 'parameters'
- # .annotation("metaflow/task_id", ...)
- .annotation("metaflow/attempt", retry_count)
- )
- # Set emptyDir volume for state management
- .empty_dir_volume("out")
- # Set tmpfs emptyDir volume if enabled
- .empty_dir_volume(
- "tmpfs-ephemeral-volume",
- medium="Memory",
- size_limit=tmpfs_size if tmpfs_enabled else 0,
- )
- .empty_dir_volume("dhsm", medium="Memory", size_limit=shared_memory)
- .pvc_volumes(resources.get("persistent_volume_claims"))
- # Set node selectors
- .node_selectors(resources.get("node_selector"))
- # Set tolerations
- .tolerations(resources.get("tolerations"))
- # Set container
- .container(
- # TODO: Unify the logic with kubernetes.py
- # Important note - Unfortunately, V1Container uses snakecase while
- # Argo Workflows uses camel. For most of the attributes, both cases
- # are indistinguishable, but unfortunately, not for all - (
- # env_from, value_from, etc.) - so we need to handle the conversion
- # ourselves using to_camelcase. We need to be vigilant about
- # resources attributes in particular where the keys maybe user
- # defined.
- to_camelcase(
- kubernetes_sdk.V1Container(
- name=self._sanitize(node.name),
- command=cmds,
- termination_message_policy="FallbackToLogsOnError",
- ports=[kubernetes_sdk.V1ContainerPort(container_port=port)]
- if port
- else None,
- env=[
- kubernetes_sdk.V1EnvVar(name=k, value=str(v))
- for k, v in env.items()
- ]
- # Add environment variables for book-keeping.
- # https://argoproj.github.io/argo-workflows/fields/#fields_155
- + [
- kubernetes_sdk.V1EnvVar(
- name=k,
- value_from=kubernetes_sdk.V1EnvVarSource(
- field_ref=kubernetes_sdk.V1ObjectFieldSelector(
- field_path=str(v)
+ else:
+ yield (
+ Template(self._sanitize(node.name))
+ # Set @timeout values
+ .active_deadline_seconds(run_time_limit)
+ # Set service account
+ .service_account_name(resources["service_account"])
+ # Configure template input
+ .inputs(Inputs().parameters(inputs))
+ # Configure template output
+ .outputs(Outputs().parameters(outputs))
+ # Fail fast!
+ .fail_fast()
+ # Set @retry/@catch values
+ .retry_strategy(
+ times=total_retries,
+ minutes_between_retries=minutes_between_retries,
+ )
+ .metadata(
+ ObjectMeta().annotation("metaflow/step_name", node.name)
+ # Unfortunately, we can't set the task_id since it is generated
+ # inside the pod. However, it can be inferred from the annotation
+ # set by argo-workflows - `workflows.argoproj.io/outputs` - refer
+ # the field 'task-id' in 'parameters'
+ # .annotation("metaflow/task_id", ...)
+ .annotation("metaflow/attempt", retry_count)
+ )
+ # Set emptyDir volume for state management
+ .empty_dir_volume("out")
+ # Set tmpfs emptyDir volume if enabled
+ .empty_dir_volume(
+ "tmpfs-ephemeral-volume",
+ medium="Memory",
+ size_limit=tmpfs_size if tmpfs_enabled else 0,
+ )
+ .empty_dir_volume("dhsm", medium="Memory", size_limit=shared_memory)
+ .pvc_volumes(resources.get("persistent_volume_claims"))
+ # Set node selectors
+ .node_selectors(resources.get("node_selector"))
+ # Set tolerations
+ .tolerations(resources.get("tolerations"))
+ # Set container
+ .container(
+ # TODO: Unify the logic with kubernetes.py
+ # Important note - Unfortunately, V1Container uses snakecase while
+ # Argo Workflows uses camel. For most of the attributes, both cases
+ # are indistinguishable, but unfortunately, not for all - (
+ # env_from, value_from, etc.) - so we need to handle the conversion
+ # ourselves using to_camelcase. We need to be vigilant about
+ # resources attributes in particular where the keys maybe user
+ # defined.
+ to_camelcase(
+ kubernetes_sdk.V1Container(
+ name=self._sanitize(node.name),
+ command=cmds,
+ termination_message_policy="FallbackToLogsOnError",
+ ports=(
+ [
+ kubernetes_sdk.V1ContainerPort(
+ container_port=port
  )
- ),
- )
- for k, v in {
- "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
- "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
- "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
- "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
- "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
- }.items()
- ],
- image=resources["image"],
- image_pull_policy=resources["image_pull_policy"],
- resources=kubernetes_sdk.V1ResourceRequirements(
- requests={
- "cpu": str(resources["cpu"]),
- "memory": "%sM" % str(resources["memory"]),
- "ephemeral-storage": "%sM" % str(resources["disk"]),
- },
- limits={
- "%s.com/gpu".lower()
- % resources["gpu_vendor"]: str(resources["gpu"])
- for k in [0]
- if resources["gpu"] is not None
- },
- ),
- # Configure secrets
- env_from=[
- kubernetes_sdk.V1EnvFromSource(
- secret_ref=kubernetes_sdk.V1SecretEnvSource(
- name=str(k),
- # optional=True
- )
- )
- for k in list(
- []
- if not resources.get("secrets")
- else (
- [resources.get("secrets")]
- if isinstance(resources.get("secrets"), str)
- else resources.get("secrets")
+ ]
+ if port
+ else None
+ ),
+ env=[
+ kubernetes_sdk.V1EnvVar(name=k, value=str(v))
+ for k, v in env.items()
+ ]
+ # Add environment variables for book-keeping.
+ # https://argoproj.github.io/argo-workflows/fields/#fields_155
+ + [
+ kubernetes_sdk.V1EnvVar(
+ name=k,
+ value_from=kubernetes_sdk.V1EnvVarSource(
+ field_ref=kubernetes_sdk.V1ObjectFieldSelector(
+ field_path=str(v)
+ )
+ ),
  )
- )
- + KUBERNETES_SECRETS.split(",")
- + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
- if k
- ],
- volume_mounts=[
- # Assign a volume mount to pass state to the next task.
- kubernetes_sdk.V1VolumeMount(
- name="out", mount_path="/mnt/out"
- )
- ]
- # Support tmpfs.
- + (
- [
- kubernetes_sdk.V1VolumeMount(
- name="tmpfs-ephemeral-volume",
- mount_path=tmpfs_path,
+ for k, v in {
+ "METAFLOW_KUBERNETES_NAMESPACE": "metadata.namespace",
+ "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
+ "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
+ "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
+ "METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME": "spec.serviceAccountName",
+ "METAFLOW_KUBERNETES_NODE_IP": "status.hostIP",
+ }.items()
+ ],
+ image=resources["image"],
+ image_pull_policy=resources["image_pull_policy"],
+ resources=kubernetes_sdk.V1ResourceRequirements(
+ requests={
+ "cpu": str(resources["cpu"]),
+ "memory": "%sM" % str(resources["memory"]),
+ "ephemeral-storage": "%sM"
+ % str(resources["disk"]),
+ },
+ limits={
+ "%s.com/gpu".lower()
+ % resources["gpu_vendor"]: str(resources["gpu"])
+ for k in [0]
+ if resources["gpu"] is not None
+ },
+ ),
+ # Configure secrets
+ env_from=[
+ kubernetes_sdk.V1EnvFromSource(
+ secret_ref=kubernetes_sdk.V1SecretEnvSource(
+ name=str(k),
+ # optional=True
+ )
  )
- ]
- if tmpfs_enabled
- else []
- )
- # Support shared_memory
- + (
- [
- kubernetes_sdk.V1VolumeMount(
- name="dhsm",
- mount_path="/dev/shm",
+ for k in list(
+ []
+ if not resources.get("secrets")
+ else (
+ [resources.get("secrets")]
+ if isinstance(resources.get("secrets"), str)
+ else resources.get("secrets")
+ )
  )
- ]
- if shared_memory
- else []
- )
- # Support persistent volume claims.
- + (
- [
+ + KUBERNETES_SECRETS.split(",")
+ + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
+ if k
+ ],
+ volume_mounts=[
+ # Assign a volume mount to pass state to the next task.
  kubernetes_sdk.V1VolumeMount(
- name=claim, mount_path=path
+ name="out", mount_path="/mnt/out"
  )
- for claim, path in resources.get(
- "persistent_volume_claims"
- ).items()
  ]
- if resources.get("persistent_volume_claims") is not None
- else []
- ),
- ).to_dict()
+ # Support tmpfs.
+ + (
+ [
+ kubernetes_sdk.V1VolumeMount(
+ name="tmpfs-ephemeral-volume",
+ mount_path=tmpfs_path,
+ )
+ ]
+ if tmpfs_enabled
+ else []
+ )
+ # Support shared_memory
+ + (
+ [
+ kubernetes_sdk.V1VolumeMount(
+ name="dhsm",
+ mount_path="/dev/shm",
+ )
+ ]
+ if shared_memory
+ else []
+ )
+ # Support persistent volume claims.
+ + (
+ [
+ kubernetes_sdk.V1VolumeMount(
+ name=claim, mount_path=path
+ )
+ for claim, path in resources.get(
+ "persistent_volume_claims"
+ ).items()
+ ]
+ if resources.get("persistent_volume_claims")
+ is not None
+ else []
+ ),
+ ).to_dict()
+ )
  )
  )
- )

  # Return daemon container templates for workflow execution notifications.
  def _daemon_templates(self):
@@ -2167,8 +2174,150 @@ class ArgoWorkflows(object):
  .success_condition("true == true")
  )
  )
+ if self.enable_error_msg_capture:
+ templates.extend(self._error_msg_capture_hook_templates())
  return templates

+ def _error_msg_capture_hook_templates(self):
+ from kubernetes import client as kubernetes_sdk
+
+ start_step = [step for step in self.graph if step.name == "start"][0]
+ # We want to grab the base image used by the start step, as this is known to be pullable from within the cluster,
+ # and it might contain the required libraries, allowing us to start up faster.
+ resources = dict(
+ [deco for deco in start_step.decorators if deco.name == "kubernetes"][
+ 0
+ ].attributes
+ )
+
+ run_id_template = "argo-{{workflow.name}}"
+ metaflow_version = self.environment.get_environment_info()
+ metaflow_version["flow_name"] = self.graph.name
+ metaflow_version["production_token"] = self.production_token
+
+ mflog_expr = export_mflog_env_vars(
+ datastore_type=self.flow_datastore.TYPE,
+ stdout_path="$PWD/.logs/mflog_stdout",
+ stderr_path="$PWD/.logs/mflog_stderr",
+ flow_name=self.flow.name,
+ run_id=run_id_template,
+ step_name="_run_capture_error",
+ task_id="1",
+ retry_count="0",
+ )
+
+ cmds = " && ".join(
+ [
+ # For supporting sandboxes, ensure that a custom script is executed
+ # before anything else is executed. The script is passed in as an
+ # env var.
+ '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"}',
+ "mkdir -p $PWD/.logs",
+ mflog_expr,
+ ]
+ + self.environment.get_package_commands(
+ self.code_package_url, self.flow_datastore.TYPE
+ )[:-1]
+ # Replace the line 'Task in starting'
+ # FIXME: this can be brittle.
+ + ["mflog 'Error capture hook is starting.'"]
+ + ["argo_error=$(python -m 'metaflow.plugins.argo.capture_error')"]
+ + ["export METAFLOW_ARGO_ERROR=$argo_error"]
+ + [
+ """python -c 'import json, os; error_obj=os.getenv(\\"METAFLOW_ARGO_ERROR\\");data=json.loads(error_obj); print(data[\\"message\\"])'"""
+ ]
+ + [
+ 'if [ -n \\"${ARGO_WORKFLOWS_CAPTURE_ERROR_SCRIPT}\\" ]; then eval \\"${ARGO_WORKFLOWS_CAPTURE_ERROR_SCRIPT}\\"; fi'
+ ]
+ )
+
+ # TODO: Also capture the first failed task id
+ cmds = shlex.split('bash -c "%s"' % cmds)
+ env = {
+ # These values are needed by Metaflow to set it's internal
+ # state appropriately.
+ "METAFLOW_CODE_URL": self.code_package_url,
+ "METAFLOW_CODE_SHA": self.code_package_sha,
+ "METAFLOW_CODE_DS": self.flow_datastore.TYPE,
+ "METAFLOW_SERVICE_URL": SERVICE_INTERNAL_URL,
+ "METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
+ "METAFLOW_USER": "argo-workflows",
+ "METAFLOW_DEFAULT_DATASTORE": self.flow_datastore.TYPE,
+ "METAFLOW_DEFAULT_METADATA": DEFAULT_METADATA,
+ "METAFLOW_OWNER": self.username,
+ }
+ # support Metaflow sandboxes
+ env["METAFLOW_INIT_SCRIPT"] = KUBERNETES_SANDBOX_INIT_SCRIPT
+ env["METAFLOW_ARGO_WORKFLOWS_CAPTURE_ERROR_SCRIPT"] = (
+ ARGO_WORKFLOWS_CAPTURE_ERROR_SCRIPT
+ )
+
+ env["METAFLOW_WORKFLOW_NAME"] = "{{workflow.name}}"
+ env["METAFLOW_WORKFLOW_NAMESPACE"] = "{{workflow.namespace}}"
+ env["METAFLOW_ARGO_WORKFLOW_FAILURES"] = "{{workflow.failures}}"
+ env = {
+ k: v
+ for k, v in env.items()
+ if v is not None
+ and k not in set(ARGO_WORKFLOWS_ENV_VARS_TO_SKIP.split(","))
+ }
+ return [
+ Template("error-msg-capture-hook").container(
+ to_camelcase(
+ kubernetes_sdk.V1Container(
+ name="main",
+ command=cmds,
+ image=resources["image"],
+ env=[
+ kubernetes_sdk.V1EnvVar(name=k, value=str(v))
+ for k, v in env.items()
+ ],
+ env_from=[
+ kubernetes_sdk.V1EnvFromSource(
+ secret_ref=kubernetes_sdk.V1SecretEnvSource(
+ name=str(k),
+ # optional=True
+ )
+ )
+ for k in list(
+ []
+ if not resources.get("secrets")
+ else (
+ [resources.get("secrets")]
+ if isinstance(resources.get("secrets"), str)
+ else resources.get("secrets")
+ )
+ )
+ + KUBERNETES_SECRETS.split(",")
+ + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
+ if k
+ ],
+ resources=kubernetes_sdk.V1ResourceRequirements(
+ # NOTE: base resources for this are kept to a minimum to save on running costs.
+ # This has an adverse effect on startup time for the daemon, which can be completely
+ # alleviated by using a base image that has the required dependencies pre-installed
+ requests={
+ "cpu": "200m",
+ "memory": "100Mi",
+ },
+ limits={
+ "cpu": "200m",
+ "memory": "500Mi",
+ },
+ ),
+ )
+ )
+ ),
+ Template("capture-error-hook-fn-preflight").steps(
+ [
+ WorkflowStep()
+ .name("capture-error-hook-fn-preflight")
+ .template("error-msg-capture-hook")
+ .when("{{workflow.status}} != Succeeded")
+ ]
+ ),
+ ]
+
  def _pager_duty_alert_template(self):
  # https://developer.pagerduty.com/docs/ZG9jOjExMDI5NTgx-send-an-alert-event
  if self.notify_pager_duty_integration_key is None:
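The inline `python -c` step in the exit-hook command chain above only surfaces the captured error message to the logs; a rough standalone equivalent, assuming METAFLOW_ARGO_ERROR holds the JSON blob emitted by capture_error with a "message" field:

    import json
    import os

    # Sketch of what the inline `python -c` invocation does: read the JSON error
    # object exported as METAFLOW_ARGO_ERROR and print its human-readable message.
    error_obj = os.getenv("METAFLOW_ARGO_ERROR")
    data = json.loads(error_obj)
    print(data["message"])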
@@ -2441,6 +2590,26 @@ class ArgoWorkflows(object):
  kubernetes_sdk.V1EnvVar(name=k, value=str(v))
  for k, v in env.items()
  ],
+ env_from=[
+ kubernetes_sdk.V1EnvFromSource(
+ secret_ref=kubernetes_sdk.V1SecretEnvSource(
+ name=str(k),
+ # optional=True
+ )
+ )
+ for k in list(
+ []
+ if not resources.get("secrets")
+ else (
+ [resources.get("secrets")]
+ if isinstance(resources.get("secrets"), str)
+ else resources.get("secrets")
+ )
+ )
+ + KUBERNETES_SECRETS.split(",")
+ + ARGO_WORKFLOWS_KUBERNETES_SECRETS.split(",")
+ if k
+ ],
  resources=kubernetes_sdk.V1ResourceRequirements(
  # NOTE: base resources for this are kept to a minimum to save on running costs.
  # This has an adverse effect on startup time for the daemon, which can be completely
@@ -2912,6 +3081,34 @@ class ObjectMeta(object):
  return json.dumps(self.to_json(), indent=4)


+ class WorkflowStep(object):
+ def __init__(self):
+ tree = lambda: defaultdict(tree)
+ self.payload = tree()
+
+ def name(self, name):
+ self.payload["name"] = str(name)
+ return self
+
+ def template(self, template):
+ self.payload["template"] = str(template)
+ return self
+
+ def when(self, condition):
+ self.payload["when"] = str(condition)
+ return self
+
+ def step(self, expression):
+ self.payload["expression"] = str(expression)
+ return self
+
+ def to_json(self):
+ return self.payload
+
+ def __str__(self):
+ return json.dumps(self.to_json(), indent=4)
+
+
  class WorkflowSpec(object):
  # https://argoproj.github.io/argo-workflows/fields/#workflowspec
  # This object sets all Workflow level properties.
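For reference, the new WorkflowStep builder above simply accumulates a plain dict. A hedged usage sketch mirroring the exit-hook preflight template defined earlier in this diff:

    # Illustrative only: WorkflowStep serializes to a plain dict payload.
    step = (
        WorkflowStep()
        .name("capture-error-hook-fn-preflight")
        .template("error-msg-capture-hook")
        .when("{{workflow.status}} != Succeeded")
    )
    # step.to_json() -> {"name": "capture-error-hook-fn-preflight",
    #                    "template": "error-msg-capture-hook",
    #                    "when": "{{workflow.status}} != Succeeded"}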
@@ -2942,6 +3139,11 @@ class WorkflowSpec(object):
  self.payload["entrypoint"] = entrypoint
  return self

+ def onExit(self, on_exit_template):
+ if on_exit_template:
+ self.payload["onExit"] = on_exit_template
+ return self
+
  def parallelism(self, parallelism):
  # Set parallelism at Workflow level
  self.payload["parallelism"] = int(parallelism)
@@ -3067,6 +3269,18 @@ class Template(object):
  self.payload["dag"] = dag_template.to_json()
  return self

+ def steps(self, steps):
+ if "steps" not in self.payload:
+ self.payload["steps"] = []
+ # steps is a list of lists.
+ # hence we go over every item in the incoming list
+ # serialize it and then append the list to the payload
+ step_list = []
+ for step in steps:
+ step_list.append(step.to_json())
+ self.payload["steps"].append(step_list)
+ return self
+
  def container(self, container):
  # Luckily this can simply be V1Container and we are spared from writing more
  # boilerplate - https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1Container.md.
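Since Argo Workflows models `steps` as a list of parallel step groups, Template.steps() wraps each incoming list in another list before appending it to the payload. A small sketch of the resulting structure, reusing the WorkflowStep builder added above:

    # Illustrative only: each steps() call appends one group (an inner list).
    t = Template("capture-error-hook-fn-preflight").steps(
        [WorkflowStep().name("s1").template("error-msg-capture-hook")]
    )
    # t.payload["steps"] -> [[{"name": "s1", "template": "error-msg-capture-hook"}]]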