ob-metaflow 2.15.13.1__py2.py3-none-any.whl → 2.19.7.1rc0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. metaflow/__init__.py +10 -3
  2. metaflow/_vendor/imghdr/__init__.py +186 -0
  3. metaflow/_vendor/yaml/__init__.py +427 -0
  4. metaflow/_vendor/yaml/composer.py +139 -0
  5. metaflow/_vendor/yaml/constructor.py +748 -0
  6. metaflow/_vendor/yaml/cyaml.py +101 -0
  7. metaflow/_vendor/yaml/dumper.py +62 -0
  8. metaflow/_vendor/yaml/emitter.py +1137 -0
  9. metaflow/_vendor/yaml/error.py +75 -0
  10. metaflow/_vendor/yaml/events.py +86 -0
  11. metaflow/_vendor/yaml/loader.py +63 -0
  12. metaflow/_vendor/yaml/nodes.py +49 -0
  13. metaflow/_vendor/yaml/parser.py +589 -0
  14. metaflow/_vendor/yaml/reader.py +185 -0
  15. metaflow/_vendor/yaml/representer.py +389 -0
  16. metaflow/_vendor/yaml/resolver.py +227 -0
  17. metaflow/_vendor/yaml/scanner.py +1435 -0
  18. metaflow/_vendor/yaml/serializer.py +111 -0
  19. metaflow/_vendor/yaml/tokens.py +104 -0
  20. metaflow/cards.py +4 -0
  21. metaflow/cli.py +125 -21
  22. metaflow/cli_components/init_cmd.py +1 -0
  23. metaflow/cli_components/run_cmds.py +204 -40
  24. metaflow/cli_components/step_cmd.py +160 -4
  25. metaflow/client/__init__.py +1 -0
  26. metaflow/client/core.py +198 -130
  27. metaflow/client/filecache.py +59 -32
  28. metaflow/cmd/code/__init__.py +2 -1
  29. metaflow/cmd/develop/stub_generator.py +49 -18
  30. metaflow/cmd/develop/stubs.py +9 -27
  31. metaflow/cmd/make_wrapper.py +30 -0
  32. metaflow/datastore/__init__.py +1 -0
  33. metaflow/datastore/content_addressed_store.py +40 -9
  34. metaflow/datastore/datastore_set.py +10 -1
  35. metaflow/datastore/flow_datastore.py +124 -4
  36. metaflow/datastore/spin_datastore.py +91 -0
  37. metaflow/datastore/task_datastore.py +92 -6
  38. metaflow/debug.py +5 -0
  39. metaflow/decorators.py +331 -82
  40. metaflow/extension_support/__init__.py +414 -356
  41. metaflow/extension_support/_empty_file.py +2 -2
  42. metaflow/flowspec.py +322 -82
  43. metaflow/graph.py +178 -15
  44. metaflow/includefile.py +25 -3
  45. metaflow/lint.py +94 -3
  46. metaflow/meta_files.py +13 -0
  47. metaflow/metadata_provider/metadata.py +13 -2
  48. metaflow/metaflow_config.py +66 -4
  49. metaflow/metaflow_environment.py +91 -25
  50. metaflow/metaflow_profile.py +18 -0
  51. metaflow/metaflow_version.py +16 -1
  52. metaflow/package/__init__.py +673 -0
  53. metaflow/packaging_sys/__init__.py +880 -0
  54. metaflow/packaging_sys/backend.py +128 -0
  55. metaflow/packaging_sys/distribution_support.py +153 -0
  56. metaflow/packaging_sys/tar_backend.py +99 -0
  57. metaflow/packaging_sys/utils.py +54 -0
  58. metaflow/packaging_sys/v1.py +527 -0
  59. metaflow/parameters.py +6 -2
  60. metaflow/plugins/__init__.py +6 -0
  61. metaflow/plugins/airflow/airflow.py +11 -1
  62. metaflow/plugins/airflow/airflow_cli.py +16 -5
  63. metaflow/plugins/argo/argo_client.py +42 -20
  64. metaflow/plugins/argo/argo_events.py +6 -6
  65. metaflow/plugins/argo/argo_workflows.py +1023 -344
  66. metaflow/plugins/argo/argo_workflows_cli.py +396 -94
  67. metaflow/plugins/argo/argo_workflows_decorator.py +9 -0
  68. metaflow/plugins/argo/argo_workflows_deployer_objects.py +75 -49
  69. metaflow/plugins/argo/capture_error.py +5 -2
  70. metaflow/plugins/argo/conditional_input_paths.py +35 -0
  71. metaflow/plugins/argo/exit_hooks.py +209 -0
  72. metaflow/plugins/argo/param_val.py +19 -0
  73. metaflow/plugins/aws/aws_client.py +6 -0
  74. metaflow/plugins/aws/aws_utils.py +33 -1
  75. metaflow/plugins/aws/batch/batch.py +72 -5
  76. metaflow/plugins/aws/batch/batch_cli.py +24 -3
  77. metaflow/plugins/aws/batch/batch_decorator.py +57 -6
  78. metaflow/plugins/aws/step_functions/step_functions.py +28 -3
  79. metaflow/plugins/aws/step_functions/step_functions_cli.py +49 -4
  80. metaflow/plugins/aws/step_functions/step_functions_deployer.py +3 -0
  81. metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py +30 -0
  82. metaflow/plugins/cards/card_cli.py +20 -1
  83. metaflow/plugins/cards/card_creator.py +24 -1
  84. metaflow/plugins/cards/card_datastore.py +21 -49
  85. metaflow/plugins/cards/card_decorator.py +58 -6
  86. metaflow/plugins/cards/card_modules/basic.py +38 -9
  87. metaflow/plugins/cards/card_modules/bundle.css +1 -1
  88. metaflow/plugins/cards/card_modules/chevron/renderer.py +1 -1
  89. metaflow/plugins/cards/card_modules/components.py +592 -3
  90. metaflow/plugins/cards/card_modules/convert_to_native_type.py +34 -5
  91. metaflow/plugins/cards/card_modules/json_viewer.py +232 -0
  92. metaflow/plugins/cards/card_modules/main.css +1 -0
  93. metaflow/plugins/cards/card_modules/main.js +56 -41
  94. metaflow/plugins/cards/card_modules/test_cards.py +22 -6
  95. metaflow/plugins/cards/component_serializer.py +1 -8
  96. metaflow/plugins/cards/metadata.py +22 -0
  97. metaflow/plugins/catch_decorator.py +9 -0
  98. metaflow/plugins/datastores/local_storage.py +12 -6
  99. metaflow/plugins/datastores/spin_storage.py +12 -0
  100. metaflow/plugins/datatools/s3/s3.py +49 -17
  101. metaflow/plugins/datatools/s3/s3op.py +113 -66
  102. metaflow/plugins/env_escape/client_modules.py +102 -72
  103. metaflow/plugins/events_decorator.py +127 -121
  104. metaflow/plugins/exit_hook/__init__.py +0 -0
  105. metaflow/plugins/exit_hook/exit_hook_decorator.py +46 -0
  106. metaflow/plugins/exit_hook/exit_hook_script.py +52 -0
  107. metaflow/plugins/kubernetes/kubernetes.py +12 -1
  108. metaflow/plugins/kubernetes/kubernetes_cli.py +11 -0
  109. metaflow/plugins/kubernetes/kubernetes_decorator.py +25 -6
  110. metaflow/plugins/kubernetes/kubernetes_job.py +12 -4
  111. metaflow/plugins/kubernetes/kubernetes_jobsets.py +31 -30
  112. metaflow/plugins/metadata_providers/local.py +76 -82
  113. metaflow/plugins/metadata_providers/service.py +13 -9
  114. metaflow/plugins/metadata_providers/spin.py +16 -0
  115. metaflow/plugins/package_cli.py +36 -24
  116. metaflow/plugins/parallel_decorator.py +11 -2
  117. metaflow/plugins/parsers.py +16 -0
  118. metaflow/plugins/pypi/bootstrap.py +7 -1
  119. metaflow/plugins/pypi/conda_decorator.py +41 -82
  120. metaflow/plugins/pypi/conda_environment.py +14 -6
  121. metaflow/plugins/pypi/micromamba.py +9 -1
  122. metaflow/plugins/pypi/pip.py +41 -5
  123. metaflow/plugins/pypi/pypi_decorator.py +4 -4
  124. metaflow/plugins/pypi/utils.py +22 -0
  125. metaflow/plugins/secrets/__init__.py +3 -0
  126. metaflow/plugins/secrets/secrets_decorator.py +14 -178
  127. metaflow/plugins/secrets/secrets_func.py +49 -0
  128. metaflow/plugins/secrets/secrets_spec.py +101 -0
  129. metaflow/plugins/secrets/utils.py +74 -0
  130. metaflow/plugins/test_unbounded_foreach_decorator.py +2 -2
  131. metaflow/plugins/timeout_decorator.py +0 -1
  132. metaflow/plugins/uv/bootstrap.py +29 -1
  133. metaflow/plugins/uv/uv_environment.py +5 -3
  134. metaflow/pylint_wrapper.py +5 -1
  135. metaflow/runner/click_api.py +79 -26
  136. metaflow/runner/deployer.py +208 -6
  137. metaflow/runner/deployer_impl.py +32 -12
  138. metaflow/runner/metaflow_runner.py +266 -33
  139. metaflow/runner/subprocess_manager.py +21 -1
  140. metaflow/runner/utils.py +27 -16
  141. metaflow/runtime.py +660 -66
  142. metaflow/task.py +255 -26
  143. metaflow/user_configs/config_options.py +33 -21
  144. metaflow/user_configs/config_parameters.py +220 -58
  145. metaflow/user_decorators/__init__.py +0 -0
  146. metaflow/user_decorators/common.py +144 -0
  147. metaflow/user_decorators/mutable_flow.py +512 -0
  148. metaflow/user_decorators/mutable_step.py +424 -0
  149. metaflow/user_decorators/user_flow_decorator.py +264 -0
  150. metaflow/user_decorators/user_step_decorator.py +749 -0
  151. metaflow/util.py +197 -7
  152. metaflow/vendor.py +23 -7
  153. metaflow/version.py +1 -1
  154. {ob_metaflow-2.15.13.1.data → ob_metaflow-2.19.7.1rc0.data}/data/share/metaflow/devtools/Makefile +13 -2
  155. {ob_metaflow-2.15.13.1.data → ob_metaflow-2.19.7.1rc0.data}/data/share/metaflow/devtools/Tiltfile +107 -7
  156. {ob_metaflow-2.15.13.1.data → ob_metaflow-2.19.7.1rc0.data}/data/share/metaflow/devtools/pick_services.sh +1 -0
  157. {ob_metaflow-2.15.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/METADATA +2 -3
  158. {ob_metaflow-2.15.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/RECORD +162 -121
  159. {ob_metaflow-2.15.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/WHEEL +1 -1
  160. metaflow/_vendor/v3_5/__init__.py +0 -1
  161. metaflow/_vendor/v3_5/importlib_metadata/__init__.py +0 -644
  162. metaflow/_vendor/v3_5/importlib_metadata/_compat.py +0 -152
  163. metaflow/_vendor/v3_5/zipp.py +0 -329
  164. metaflow/info_file.py +0 -25
  165. metaflow/package.py +0 -203
  166. metaflow/user_configs/config_decorators.py +0 -568
  167. {ob_metaflow-2.15.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/entry_points.txt +0 -0
  168. {ob_metaflow-2.15.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/licenses/LICENSE +0 -0
  169. {ob_metaflow-2.15.13.1.dist-info → ob_metaflow-2.19.7.1rc0.dist-info}/top_level.txt +0 -0
@@ -123,6 +123,15 @@ class ArgoWorkflowsInternalDecorator(StepDecorator):
             with open("/mnt/out/split_cardinality", "w") as file:
                 json.dump(flow._foreach_num_splits, file)
 
+        # For conditional branches we need to record the value of the switch to disk, in order to pass it as an
+        # output from the switching step to be used further down the DAG
+        if graph[step_name].type == "split-switch":
+            # TODO: A nicer way to access the chosen step?
+            _out_funcs, _ = flow._transition
+            chosen_step = _out_funcs[0]
+            with open("/mnt/out/switch_step", "w") as file:
+                file.write(chosen_step)
+
         # For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets
         # to run the task. In this case, we cannot set anything in the
         # `/mnt/out` directory, since such form of output mounts are not available to Jobset executions.
@@ -9,59 +9,16 @@ from metaflow.exception import MetaflowException
 from metaflow.plugins.argo.argo_client import ArgoClient
 from metaflow.metaflow_config import KUBERNETES_NAMESPACE
 from metaflow.plugins.argo.argo_workflows import ArgoWorkflows
-from metaflow.runner.deployer import Deployer, DeployedFlow, TriggeredRun
+from metaflow.runner.deployer import (
+    Deployer,
+    DeployedFlow,
+    TriggeredRun,
+    generate_fake_flow_file_contents,
+)
 
 from metaflow.runner.utils import get_lower_level_group, handle_timeout, temporary_fifo
 
 
-def generate_fake_flow_file_contents(
-    flow_name: str, param_info: dict, project_name: Optional[str] = None
-):
-    params_code = ""
-    for _, param_details in param_info.items():
-        param_python_var_name = param_details["python_var_name"]
-        param_name = param_details["name"]
-        param_type = param_details["type"]
-        param_help = param_details["description"]
-        param_required = param_details["is_required"]
-
-        if param_type == "JSON":
-            params_code += (
-                f" {param_python_var_name} = Parameter('{param_name}', "
-                f"type=JSONType, help='''{param_help}''', required={param_required})\n"
-            )
-        elif param_type == "FilePath":
-            is_text = param_details.get("is_text", True)
-            encoding = param_details.get("encoding", "utf-8")
-            params_code += (
-                f" {param_python_var_name} = IncludeFile('{param_name}', "
-                f"is_text={is_text}, encoding='{encoding}', help='''{param_help}''', "
-                f"required={param_required})\n"
-            )
-        else:
-            params_code += (
-                f" {param_python_var_name} = Parameter('{param_name}', "
-                f"type={param_type}, help='''{param_help}''', required={param_required})\n"
-            )
-
-    project_decorator = f"@project(name='{project_name}')\n" if project_name else ""
-
-    contents = f"""\
-from metaflow import FlowSpec, Parameter, IncludeFile, JSONType, step, project
-{project_decorator}class {flow_name}(FlowSpec):
-{params_code}
-    @step
-    def start(self):
-        self.next(self.end)
-    @step
-    def end(self):
-        pass
-if __name__ == '__main__':
-    {flow_name}()
-"""
-    return contents
-
-
 class ArgoWorkflowsTriggeredRun(TriggeredRun):
     """
     A class representing a triggered Argo Workflow execution.
@@ -246,6 +203,38 @@ class ArgoWorkflowsDeployedFlow(DeployedFlow):
 
     TYPE: ClassVar[Optional[str]] = "argo-workflows"
 
+    @classmethod
+    def list_deployed_flows(cls, flow_name: Optional[str] = None):
+        """
+        List all deployed Argo Workflow templates.
+
+        Parameters
+        ----------
+        flow_name : str, optional, default None
+            If specified, only list deployed flows for this specific flow name.
+            If None, list all deployed flows.
+
+        Yields
+        ------
+        ArgoWorkflowsDeployedFlow
+            `ArgoWorkflowsDeployedFlow` objects representing deployed
+            workflow templates on Argo Workflows.
+        """
+        from metaflow.plugins.argo.argo_workflows import ArgoWorkflows
+
+        # When flow_name is None, use all=True to get all templates
+        # When flow_name is specified, use all=False to filter by flow_name
+        all_templates = flow_name is None
+        for template_name in ArgoWorkflows.list_templates(
+            flow_name=flow_name, all=all_templates
+        ):
+            try:
+                deployed_flow = cls.from_deployment(template_name)
+                yield deployed_flow
+            except Exception:
+                # Skip templates that can't be converted to DeployedFlow objects
+                continue
+
     @classmethod
     def from_deployment(cls, identifier: str, metadata: Optional[str] = None):
         """
@@ -321,6 +310,43 @@ class ArgoWorkflowsDeployedFlow(DeployedFlow):
 
         return cls(deployer=d)
 
+    @classmethod
+    def get_triggered_run(
+        cls, identifier: str, run_id: str, metadata: Optional[str] = None
+    ):
+        """
+        Retrieves a `ArgoWorkflowsTriggeredRun` object from an identifier, a run id and
+        optional metadata.
+
+        Parameters
+        ----------
+        identifier : str
+            Deployer specific identifier for the workflow to retrieve
+        run_id : str
+            Run ID for the which to fetch the triggered run object
+        metadata : str, optional, default None
+            Optional deployer specific metadata.
+
+        Returns
+        -------
+        ArgoWorkflowsTriggeredRun
+            A `ArgoWorkflowsTriggeredRun` object representing the
+            triggered run on argo workflows.
+        """
+        deployed_flow_obj = cls.from_deployment(identifier, metadata)
+        return ArgoWorkflowsTriggeredRun(
+            deployer=deployed_flow_obj.deployer,
+            content=json.dumps(
+                {
+                    "metadata": deployed_flow_obj.deployer.metadata,
+                    "pathspec": "/".join(
+                        (deployed_flow_obj.deployer.flow_name, run_id)
+                    ),
+                    "name": run_id,
+                }
+            ),
+        )
+
     @property
     def production_token(self) -> Optional[str]:
         """
@@ -1,6 +1,6 @@
 import json
 import os
-from datetime import datetime
+from datetime import datetime, timezone
 
 ###
 # Algorithm to determine 1st error:
@@ -26,6 +26,9 @@ def parse_workflow_failures():
 def group_failures_by_template(failures):
     groups = {}
     for failure in failures:
+        if failure.get("finishedAt", None) is None:
+            timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
+            failure["finishedAt"] = timestamp
         groups.setdefault(failure["templateName"], []).append(failure)
     return groups
 
@@ -53,7 +56,7 @@ def determine_first_error():
     grouped_failures = group_failures_by_template(failures)
     for group in grouped_failures.values():
         group.sort(
-            key=lambda x: datetime.strptime(x["finishedAt"], "%Y-%m-%dT%H:%M:%SZ")
+            key=lambda g: datetime.strptime(g["finishedAt"], "%Y-%m-%dT%H:%M:%SZ")
        )
 
    earliest_group = grouped_failures[
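The two capture_error changes above work together: failures still missing `finishedAt` get stamped with the current UTC time in exactly the format the sort key parses, so sorting no longer trips over a missing value. A tiny sketch of the round trip:

    # Sketch: the fallback timestamp format matches the strptime format used by the sort key.
    from datetime import datetime, timezone

    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    parsed = datetime.strptime(stamp, "%Y-%m-%dT%H:%M:%SZ")
    print(stamp, parsed)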
@@ -0,0 +1,35 @@
+from math import inf
+import sys
+from metaflow.util import decompress_list, compress_list
+import base64
+
+
+def generate_input_paths(input_paths, skippable_steps):
+    # => run_id/step/:foo,bar
+    # input_paths are base64 encoded due to Argo shenanigans
+    decoded = base64.b64decode(input_paths).decode("utf-8")
+    paths = decompress_list(decoded)
+
+    # some of the paths are going to be malformed due to never having executed per conditional.
+    # strip these out of the list.
+
+    # all pathspecs of leading steps that executed.
+    trimmed = [path for path in paths if not "{{" in path]
+
+    # pathspecs of leading steps that are conditional, and should be used instead of non-conditional ones
+    # e.g. the case of skipping switches: start -> case_step -> conditional_a or end
+    conditionals = [
+        path for path in trimmed if not any(step in path for step in skippable_steps)
+    ]
+    pathspecs_to_use = conditionals if conditionals else trimmed
+    return compress_list(pathspecs_to_use, zlibmin=inf)
+
+
+if __name__ == "__main__":
+    input_paths = sys.argv[1]
+    try:
+        skippable_steps = sys.argv[2].split(",")
+    except IndexError:
+        skippable_steps = []
+
+    print(generate_input_paths(input_paths, skippable_steps))
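A quick sketch of what `generate_input_paths` does, assuming this wheel is installed locally so that `metaflow.util.compress_list` / `decompress_list` and the module above are importable; the pathspecs and step names below are made up:

    # Sketch: drop never-executed ("{{ ... }}") pathspecs and skippable steps, then
    # re-compress the survivors in the form the downstream step consumes.
    import base64
    from metaflow.util import compress_list, decompress_list
    from metaflow.plugins.argo.conditional_input_paths import generate_input_paths

    paths = [
        "argo-hello-1/start/t-01",             # executed, but marked skippable below
        "argo-hello-1/case_a/t-02",            # executed conditional branch
        "argo-hello-1/{{unresolved.task}}/0",  # branch that never ran
    ]
    encoded = base64.b64encode(compress_list(paths).encode("utf-8")).decode("utf-8")
    result = generate_input_paths(encoded, skippable_steps=["start"])
    # Expected: only the conditional pathspec survives.
    print(decompress_list(result))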
@@ -0,0 +1,209 @@
+from collections import defaultdict
+import json
+from typing import Dict, List, Optional
+
+
+class JsonSerializable(object):
+    def to_json(self):
+        return self.payload
+
+    def __str__(self):
+        return json.dumps(self.payload, indent=4)
+
+
+class _LifecycleHook(JsonSerializable):
+    # https://argoproj.github.io/argo-workflows/fields/#lifecyclehook
+
+    def __init__(self, name):
+        tree = lambda: defaultdict(tree)
+        self.name = name
+        self.payload = tree()
+
+    def expression(self, expression):
+        self.payload["expression"] = str(expression)
+        return self
+
+    def template(self, template):
+        self.payload["template"] = template
+        return self
+
+
+class _Template(JsonSerializable):
+    # https://argoproj.github.io/argo-workflows/fields/#template
+
+    def __init__(self, name):
+        tree = lambda: defaultdict(tree)
+        self.name = name
+        self.payload = tree()
+        self.payload["name"] = name
+
+    def http(self, http):
+        self.payload["http"] = http.to_json()
+        return self
+
+    def script(self, script):
+        self.payload["script"] = script.to_json()
+        return self
+
+    def container(self, container):
+        self.payload["container"] = container
+        return self
+
+    def service_account_name(self, service_account_name):
+        self.payload["serviceAccountName"] = service_account_name
+        return self
+
+
+class Hook(object):
+    """
+    Abstraction for Argo Workflows exit hooks.
+    A hook consists of a Template, and one or more LifecycleHooks that trigger the template
+    """
+
+    template: "_Template"
+    lifecycle_hooks: List["_LifecycleHook"]
+
+
+class _HttpSpec(JsonSerializable):
+    # https://argoproj.github.io/argo-workflows/fields/#http
+
+    def __init__(self, method):
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["method"] = method
+        self.payload["headers"] = []
+
+    def header(self, header, value):
+        self.payload["headers"].append({"name": header, "value": value})
+        return self
+
+    def body(self, body):
+        self.payload["body"] = str(body)
+        return self
+
+    def url(self, url):
+        self.payload["url"] = url
+        return self
+
+    def success_condition(self, success_condition):
+        self.payload["successCondition"] = success_condition
+        return self
+
+
+# HTTP hook
+class HttpExitHook(Hook):
+    def __init__(
+        self,
+        name: str,
+        url: str,
+        method: str = "GET",
+        headers: Optional[Dict] = None,
+        body: Optional[str] = None,
+        on_success: bool = False,
+        on_error: bool = False,
+    ):
+        self.template = _Template(name)
+        http = _HttpSpec(method).url(url)
+        if headers is not None:
+            for header, value in headers.items():
+                http.header(header, value)
+
+        if body is not None:
+            http.body(body)
+
+        self.template.http(http)
+
+        self.lifecycle_hooks = []
+
+        if on_success and on_error:
+            raise Exception("Set only one of the on_success/on_error at a time.")
+
+        if on_success:
+            self.lifecycle_hooks.append(
+                _LifecycleHook(name)
+                .expression("workflow.status == 'Succeeded'")
+                .template(self.template.name)
+            )
+
+        if on_error:
+            self.lifecycle_hooks.append(
+                _LifecycleHook(name)
+                .expression("workflow.status == 'Error' || workflow.status == 'Failed'")
+                .template(self.template.name)
+            )
+
+        if not on_success and not on_error:
+            # add an expressionless lifecycle hook
+            self.lifecycle_hooks.append(_LifecycleHook(name).template(name))
+
+
+class ExitHookHack(Hook):
+    # Warning: terrible hack to workaround a bug in Argo Workflow where the
+    # templates listed above do not execute unless there is an
+    # explicit exit hook. as and when this bug is patched, we should
+    # remove this effectively no-op template.
+    # Note: We use the Http template because changing this to an actual no-op container had the side-effect of
+    # leaving LifecycleHooks in a pending state even when they have finished execution.
+    def __init__(
+        self,
+        url,
+        headers=None,
+        body=None,
+    ):
+        self.template = _Template("exit-hook-hack")
+        http = _HttpSpec("GET").url(url)
+        if headers is not None:
+            for header, value in headers.items():
+                http.header(header, value)
+
+        if body is not None:
+            http.body(json.dumps(body))
+
+        http.success_condition("true == true")
+
+        self.template.http(http)
+
+        self.lifecycle_hooks = []
+
+        # add an expressionless lifecycle hook
+        self.lifecycle_hooks.append(_LifecycleHook("exit").template("exit-hook-hack"))
+
+
+class ContainerHook(Hook):
+    def __init__(
+        self,
+        name: str,
+        container: Dict,
+        service_account_name: str = None,
+        on_success: bool = False,
+        on_error: bool = False,
+    ):
+        self.template = _Template(name)
+
+        if service_account_name is not None:
+            self.template.service_account_name(service_account_name)
+
+        self.template.container(container)
+
+        self.lifecycle_hooks = []
+
+        if on_success and on_error:
+            raise Exception("Set only one of the on_success/on_error at a time.")
+
+        if on_success:
+            self.lifecycle_hooks.append(
+                _LifecycleHook(name)
+                .expression("workflow.status == 'Succeeded'")
+                .template(self.template.name)
+            )
+
+        if on_error:
+            self.lifecycle_hooks.append(
+                _LifecycleHook(name)
+                .expression("workflow.status == 'Error' || workflow.status == 'Failed'")
+                .template(self.template.name)
+            )
+
+        if not on_success and not on_error:
+            # add an expressionless lifecycle hook
+            self.lifecycle_hooks.append(_LifecycleHook(name).template(name))
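Since this whole module is new, here is a hedged sketch of how the builders compose, assuming this wheel is installed locally; the hook name, URL, and body are made up:

    # Sketch: an exit hook that fires only when the workflow errors or fails,
    # plus a peek at the Argo-facing JSON it contributes to the workflow spec.
    import json
    from metaflow.plugins.argo.exit_hooks import HttpExitHook

    hook = HttpExitHook(
        name="notify-on-failure",
        url="https://hooks.example.com/argo",  # hypothetical endpoint
        method="POST",
        headers={"Content-Type": "application/json"},
        body=json.dumps({"source": "metaflow"}),
        on_error=True,
    )

    print(json.dumps(hook.template.to_json(), indent=2))  # the _Template payload
    for lifecycle_hook in hook.lifecycle_hooks:
        print(lifecycle_hook)  # the lifecycleHook entries keyed by the hook name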
@@ -0,0 +1,19 @@
+import sys
+import base64
+import json
+
+
+def parse_parameter_value(base64_value):
+    val = base64.b64decode(base64_value).decode("utf-8")
+
+    try:
+        return json.loads(val)
+    except json.decoder.JSONDecodeError:
+        # fallback to using the original value.
+        return val
+
+
+if __name__ == "__main__":
+    base64_val = sys.argv[1]
+
+    print(parse_parameter_value(base64_val))
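The helper decodes a base64-encoded parameter value and prefers JSON whenever it parses; a minimal sketch, assuming the module above is importable from a local install:

    # Sketch: JSON payloads come back as Python objects, anything else falls back to the raw string.
    import base64
    from metaflow.plugins.argo.param_val import parse_parameter_value

    encoded_json = base64.b64encode(b'{"alpha": 0.5}').decode("utf-8")
    encoded_text = base64.b64encode(b"plain-string").decode("utf-8")

    print(parse_parameter_value(encoded_json))  # {'alpha': 0.5}
    print(parse_parameter_value(encoded_text))  # plain-string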
@@ -35,6 +35,12 @@ class Boto3ClientProvider(object):
                 "Could not import module 'boto3'. Install boto3 first."
             )
 
+        # Convert dictionary config to Config object if needed
+        if "config" in client_params and not isinstance(
+            client_params["config"], Config
+        ):
+            client_params["config"] = Config(**client_params["config"])
+
         if module == "s3" and (
             "config" not in client_params or client_params["config"].retries is None
         ):
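The new guard lets callers pass a plain dict where a botocore `Config` object is expected. A minimal sketch of the same normalization, assuming botocore is installed:

    # Sketch: normalize a dict-style client config into botocore.config.Config,
    # mirroring the isinstance check added above.
    from botocore.config import Config

    client_params = {"config": {"connect_timeout": 5, "retries": {"max_attempts": 3}}}
    if "config" in client_params and not isinstance(client_params["config"], Config):
        client_params["config"] = Config(**client_params["config"])

    print(type(client_params["config"]).__name__)  # Config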
@@ -49,7 +49,7 @@ def get_ec2_instance_metadata():
         # Try to get an IMDSv2 token.
         token = requests.put(
             url="http://169.254.169.254/latest/api/token",
-            headers={"X-aws-ec2-metadata-token-ttl-seconds": 100},
+            headers={"X-aws-ec2-metadata-token-ttl-seconds": "100"},
             timeout=timeout,
         ).text
     except:
@@ -208,3 +208,35 @@ def sanitize_batch_tag(key, value):
     _value = re.sub(RE_NOT_PERMITTED, "", value)[:256]
 
     return _key, _value
+
+
+def validate_aws_tag(key: str, value: str):
+    PERMITTED = r"[A-Za-z0-9\s\+\-\=\.\_\:\/\@]"
+
+    AWS_PREFIX = r"^aws\:" # case-insensitive.
+    if re.match(AWS_PREFIX, key, re.IGNORECASE) or re.match(
+        AWS_PREFIX, value, re.IGNORECASE
+    ):
+        raise MetaflowException(
+            "'aws:' is not an allowed prefix for either tag keys or values"
+        )
+
+    if len(key) > 128:
+        raise MetaflowException(
+            "Tag key *%s* is too long. Maximum allowed tag key length is 128." % key
+        )
+    if len(value) > 256:
+        raise MetaflowException(
+            "Tag value *%s* is too long. Maximum allowed tag value length is 256."
+            % value
+        )
+
+    if not re.match(PERMITTED, key):
+        raise MetaflowException(
+            "Key *s* is not permitted. Tags must match pattern: %s" % (key, PERMITTED)
+        )
+    if not re.match(PERMITTED, value):
+        raise MetaflowException(
+            "Value *%s* is not permitted. Tags must match pattern: %s"
+            % (value, PERMITTED)
+        )
@@ -53,20 +53,32 @@ class BatchKilledException(MetaflowException):
 
 
 class Batch(object):
-    def __init__(self, metadata, environment):
+    def __init__(self, metadata, environment, flow_datastore=None):
         self.metadata = metadata
         self.environment = environment
+        self.flow_datastore = flow_datastore
         self._client = BatchClient()
         atexit.register(lambda: self.job.kill() if hasattr(self, "job") else None)
 
-    def _command(self, environment, code_package_url, step_name, step_cmds, task_spec):
+    def _command(
+        self,
+        environment,
+        code_package_metadata,
+        code_package_url,
+        step_name,
+        step_cmds,
+        task_spec,
+        offload_command_to_s3,
+    ):
         mflog_expr = export_mflog_env_vars(
             datastore_type="s3",
             stdout_path=STDOUT_PATH,
             stderr_path=STDERR_PATH,
             **task_spec
         )
-        init_cmds = environment.get_package_commands(code_package_url, "s3")
+        init_cmds = environment.get_package_commands(
+            code_package_url, "s3", code_package_metadata
+        )
         init_expr = " && ".join(init_cmds)
         step_expr = bash_capture_logs(
             " && ".join(environment.bootstrap_commands(step_name, "s3") + step_cmds)
@@ -94,7 +106,43 @@ class Batch(object):
         # We lose the last logs in this scenario (although they are visible
         # still through AWS CloudWatch console).
         cmd_str += "c=$?; %s; exit $c" % BASH_SAVE_LOGS
-        return shlex.split('bash -c "%s"' % cmd_str)
+        command = shlex.split('bash -c "%s"' % cmd_str)
+
+        if not offload_command_to_s3:
+            return command
+
+        # If S3 upload is enabled, we need to modify the command after it's created
+        if self.flow_datastore is None:
+            raise MetaflowException(
+                "Can not offload Batch command to S3 without a datastore configured."
+            )
+
+        from metaflow.plugins.aws.aws_utils import parse_s3_full_path
+
+        # Get the command that was created
+        # Upload the command to S3 during deployment
+        try:
+            command_bytes = cmd_str.encode("utf-8")
+            result_paths = self.flow_datastore.save_data([command_bytes], len_hint=1)
+            s3_path, _key = result_paths[0]
+
+            bucket, s3_object = parse_s3_full_path(s3_path)
+            download_script = "{python} -c '{script}'".format(
+                python=self.environment._python(),
+                script='import boto3, os; ep=os.getenv(\\"METAFLOW_S3_ENDPOINT_URL\\"); boto3.client(\\"s3\\", **({\\"endpoint_url\\":ep} if ep else {})).download_file(\\"%s\\", \\"%s\\", \\"/tmp/step_command.sh\\")'
+                % (bucket, s3_object),
+            )
+            download_cmd = (
+                f"{self.environment._get_install_dependencies_cmd('s3')} && "  # required for boto3 due to the original dependencies cmd getting packaged, and not being downloaded in time.
+                f"{download_script} && "
+                f"chmod +x /tmp/step_command.sh && "
+                f"bash /tmp/step_command.sh"
+            )
+            new_cmd = shlex.split('bash -c "%s"' % download_cmd)
+            return new_cmd
+        except Exception as e:
+            print(f"Warning: Failed to upload command to S3: {e}")
+            print("Falling back to inline command")
 
     def _search_jobs(self, flow_name, run_id, user):
         if user is None:
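When `offload_command_to_s3` is set, the full rendered step command is saved to the datastore and the Batch job only receives a short bootstrap that downloads and executes it, which keeps very long commands out of the job submission itself. An illustrative sketch of that bootstrap shape (not the Batch code path above; the bucket and key are hypothetical, and `aws s3 cp` stands in for the inline boto3 download):

    # Sketch: hand Batch a short command that fetches and runs the real one.
    import shlex

    bucket, key = "my-metaflow-datastore", "data/ab/cdef0123"  # hypothetical datastore location
    bootstrap = (
        "aws s3 cp s3://%s/%s /tmp/step_command.sh && "
        "chmod +x /tmp/step_command.sh && bash /tmp/step_command.sh" % (bucket, key)
    )
    print(shlex.split("bash -c '%s'" % bootstrap))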
@@ -167,6 +215,7 @@ class Batch(object):
         step_name,
         step_cli,
         task_spec,
+        code_package_metadata,
         code_package_sha,
         code_package_url,
         code_package_ds,
@@ -188,6 +237,7 @@ class Batch(object):
         host_volumes=None,
         efs_volumes=None,
         use_tmpfs=None,
+        aws_batch_tags=None,
         tmpfs_tempdir=None,
         tmpfs_size=None,
         tmpfs_path=None,
@@ -195,6 +245,7 @@ class Batch(object):
         ephemeral_storage=None,
         log_driver=None,
         log_options=None,
+        offload_command_to_s3=False,
     ):
         job_name = self._job_name(
             attrs.get("metaflow.user"),
@@ -210,7 +261,13 @@ class Batch(object):
             .job_queue(queue)
             .command(
                 self._command(
-                    self.environment, code_package_url, step_name, [step_cli], task_spec
+                    self.environment,
+                    code_package_metadata,
+                    code_package_url,
+                    step_name,
+                    [step_cli],
+                    task_spec,
+                    offload_command_to_s3,
                 )
             )
             .image(image)
@@ -249,6 +306,7 @@ class Batch(object):
             )
             .task_id(attrs.get("metaflow.task_id"))
             .environment_variable("AWS_DEFAULT_REGION", self._client.region())
+            .environment_variable("METAFLOW_CODE_METADATA", code_package_metadata)
             .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
             .environment_variable("METAFLOW_CODE_URL", code_package_url)
             .environment_variable("METAFLOW_CODE_DS", code_package_ds)
@@ -327,6 +385,11 @@ class Batch(object):
             if key in attrs:
                 k, v = sanitize_batch_tag(key, attrs.get(key))
                 job.tag(k, v)
+
+        if aws_batch_tags is not None:
+            for key, value in aws_batch_tags.items():
+                job.tag(key, value)
+
         return job
 
     def launch_job(
@@ -334,6 +397,7 @@ class Batch(object):
         step_name,
         step_cli,
         task_spec,
+        code_package_metadata,
         code_package_sha,
         code_package_url,
         code_package_ds,
@@ -353,6 +417,7 @@ class Batch(object):
         host_volumes=None,
         efs_volumes=None,
         use_tmpfs=None,
+        aws_batch_tags=None,
         tmpfs_tempdir=None,
         tmpfs_size=None,
         tmpfs_path=None,
@@ -374,6 +439,7 @@ class Batch(object):
             step_name,
             step_cli,
             task_spec,
+            code_package_metadata,
             code_package_sha,
             code_package_url,
             code_package_ds,
@@ -395,6 +461,7 @@ class Batch(object):
             host_volumes=host_volumes,
             efs_volumes=efs_volumes,
             use_tmpfs=use_tmpfs,
+            aws_batch_tags=aws_batch_tags,
             tmpfs_tempdir=tmpfs_tempdir,
             tmpfs_size=tmpfs_size,
             tmpfs_path=tmpfs_path,