ob-metaflow 2.12.30.2__py2.py3-none-any.whl → 2.13.6.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (96)
  1. metaflow/__init__.py +3 -0
  2. metaflow/cards.py +1 -0
  3. metaflow/cli.py +185 -717
  4. metaflow/cli_args.py +17 -0
  5. metaflow/cli_components/__init__.py +0 -0
  6. metaflow/cli_components/dump_cmd.py +96 -0
  7. metaflow/cli_components/init_cmd.py +51 -0
  8. metaflow/cli_components/run_cmds.py +362 -0
  9. metaflow/cli_components/step_cmd.py +176 -0
  10. metaflow/cli_components/utils.py +140 -0
  11. metaflow/cmd/develop/stub_generator.py +9 -2
  12. metaflow/datastore/flow_datastore.py +2 -2
  13. metaflow/decorators.py +63 -2
  14. metaflow/exception.py +8 -2
  15. metaflow/extension_support/plugins.py +42 -27
  16. metaflow/flowspec.py +176 -23
  17. metaflow/graph.py +28 -27
  18. metaflow/includefile.py +50 -22
  19. metaflow/lint.py +35 -20
  20. metaflow/metadata_provider/heartbeat.py +23 -8
  21. metaflow/metaflow_config.py +10 -1
  22. metaflow/multicore_utils.py +31 -14
  23. metaflow/package.py +17 -3
  24. metaflow/parameters.py +97 -25
  25. metaflow/plugins/__init__.py +22 -0
  26. metaflow/plugins/airflow/airflow.py +18 -17
  27. metaflow/plugins/airflow/airflow_cli.py +1 -0
  28. metaflow/plugins/argo/argo_client.py +0 -2
  29. metaflow/plugins/argo/argo_workflows.py +195 -132
  30. metaflow/plugins/argo/argo_workflows_cli.py +1 -1
  31. metaflow/plugins/argo/argo_workflows_decorator.py +2 -4
  32. metaflow/plugins/argo/argo_workflows_deployer_objects.py +51 -9
  33. metaflow/plugins/argo/jobset_input_paths.py +0 -1
  34. metaflow/plugins/aws/aws_utils.py +6 -1
  35. metaflow/plugins/aws/batch/batch_client.py +1 -3
  36. metaflow/plugins/aws/batch/batch_decorator.py +13 -13
  37. metaflow/plugins/aws/secrets_manager/aws_secrets_manager_secrets_provider.py +13 -10
  38. metaflow/plugins/aws/step_functions/dynamo_db_client.py +0 -3
  39. metaflow/plugins/aws/step_functions/production_token.py +1 -1
  40. metaflow/plugins/aws/step_functions/step_functions.py +33 -1
  41. metaflow/plugins/aws/step_functions/step_functions_cli.py +1 -1
  42. metaflow/plugins/aws/step_functions/step_functions_decorator.py +0 -1
  43. metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py +7 -9
  44. metaflow/plugins/cards/card_cli.py +7 -2
  45. metaflow/plugins/cards/card_creator.py +1 -0
  46. metaflow/plugins/cards/card_decorator.py +79 -8
  47. metaflow/plugins/cards/card_modules/basic.py +56 -5
  48. metaflow/plugins/cards/card_modules/card.py +16 -1
  49. metaflow/plugins/cards/card_modules/components.py +64 -16
  50. metaflow/plugins/cards/card_modules/main.js +27 -25
  51. metaflow/plugins/cards/card_modules/test_cards.py +4 -4
  52. metaflow/plugins/cards/component_serializer.py +1 -1
  53. metaflow/plugins/datatools/s3/s3.py +12 -4
  54. metaflow/plugins/datatools/s3/s3op.py +3 -3
  55. metaflow/plugins/events_decorator.py +338 -186
  56. metaflow/plugins/kubernetes/kube_utils.py +84 -1
  57. metaflow/plugins/kubernetes/kubernetes.py +40 -92
  58. metaflow/plugins/kubernetes/kubernetes_cli.py +32 -7
  59. metaflow/plugins/kubernetes/kubernetes_decorator.py +76 -4
  60. metaflow/plugins/kubernetes/kubernetes_job.py +23 -20
  61. metaflow/plugins/kubernetes/kubernetes_jobsets.py +41 -20
  62. metaflow/plugins/kubernetes/spot_metadata_cli.py +69 -0
  63. metaflow/plugins/kubernetes/spot_monitor_sidecar.py +109 -0
  64. metaflow/plugins/parallel_decorator.py +4 -1
  65. metaflow/plugins/project_decorator.py +33 -5
  66. metaflow/plugins/pypi/bootstrap.py +249 -81
  67. metaflow/plugins/pypi/conda_decorator.py +20 -10
  68. metaflow/plugins/pypi/conda_environment.py +83 -27
  69. metaflow/plugins/pypi/micromamba.py +82 -37
  70. metaflow/plugins/pypi/pip.py +9 -6
  71. metaflow/plugins/pypi/pypi_decorator.py +11 -9
  72. metaflow/plugins/pypi/utils.py +4 -2
  73. metaflow/plugins/timeout_decorator.py +2 -2
  74. metaflow/runner/click_api.py +240 -50
  75. metaflow/runner/deployer.py +1 -1
  76. metaflow/runner/deployer_impl.py +12 -11
  77. metaflow/runner/metaflow_runner.py +68 -34
  78. metaflow/runner/nbdeploy.py +2 -0
  79. metaflow/runner/nbrun.py +1 -1
  80. metaflow/runner/subprocess_manager.py +61 -10
  81. metaflow/runner/utils.py +208 -44
  82. metaflow/runtime.py +216 -112
  83. metaflow/sidecar/sidecar_worker.py +1 -1
  84. metaflow/tracing/tracing_modules.py +4 -1
  85. metaflow/user_configs/__init__.py +0 -0
  86. metaflow/user_configs/config_decorators.py +563 -0
  87. metaflow/user_configs/config_options.py +548 -0
  88. metaflow/user_configs/config_parameters.py +436 -0
  89. metaflow/util.py +22 -0
  90. metaflow/version.py +1 -1
  91. {ob_metaflow-2.12.30.2.dist-info → ob_metaflow-2.13.6.1.dist-info}/METADATA +12 -3
  92. {ob_metaflow-2.12.30.2.dist-info → ob_metaflow-2.13.6.1.dist-info}/RECORD +96 -84
  93. {ob_metaflow-2.12.30.2.dist-info → ob_metaflow-2.13.6.1.dist-info}/WHEEL +1 -1
  94. {ob_metaflow-2.12.30.2.dist-info → ob_metaflow-2.13.6.1.dist-info}/LICENSE +0 -0
  95. {ob_metaflow-2.12.30.2.dist-info → ob_metaflow-2.13.6.1.dist-info}/entry_points.txt +0 -0
  96. {ob_metaflow-2.12.30.2.dist-info → ob_metaflow-2.13.6.1.dist-info}/top_level.txt +0 -0
metaflow/lint.py CHANGED
@@ -52,7 +52,7 @@ def check_reserved_words(graph):
     msg = "Step name *%s* is a reserved word. Choose another name for the " "step."
     for node in graph:
         if node.name in RESERVED:
-            raise LintWarn(msg % node.name)
+            raise LintWarn(msg % node.name, node.func_lineno, node.source_file)
 
 
 @linter.ensure_fundamentals
@@ -76,9 +76,9 @@ def check_that_end_is_end(graph):
     node = graph["end"]
 
     if node.has_tail_next or node.invalid_tail_next:
-        raise LintWarn(msg0, node.tail_next_lineno)
+        raise LintWarn(msg0, node.tail_next_lineno, node.source_file)
     if node.num_args > 1:
-        raise LintWarn(msg1, node.tail_next_lineno)
+        raise LintWarn(msg1, node.tail_next_lineno, node.source_file)
 
 
 @linter.ensure_fundamentals
@@ -90,7 +90,7 @@ def check_step_names(graph):
     )
     for node in graph:
         if re.search("[^a-z0-9_]", node.name) or node.name[0] == "_":
-            raise LintWarn(msg.format(node), node.func_lineno)
+            raise LintWarn(msg.format(node), node.func_lineno, node.source_file)
 
 
 @linter.ensure_fundamentals
@@ -108,11 +108,11 @@ def check_num_args(graph):
     msg2 = "Step *{0.name}* is missing the 'self' argument."
     for node in graph:
         if node.num_args > 2:
-            raise LintWarn(msg0.format(node), node.func_lineno)
+            raise LintWarn(msg0.format(node), node.func_lineno, node.source_file)
         elif node.num_args == 2 and node.type != "join":
-            raise LintWarn(msg1.format(node), node.func_lineno)
+            raise LintWarn(msg1.format(node), node.func_lineno, node.source_file)
         elif node.num_args == 0:
-            raise LintWarn(msg2.format(node), node.func_lineno)
+            raise LintWarn(msg2.format(node), node.func_lineno, node.source_file)
 
 
 @linter.ensure_static_graph
@@ -125,7 +125,7 @@ def check_static_transitions(graph):
     )
     for node in graph:
         if node.type != "end" and not node.has_tail_next:
-            raise LintWarn(msg.format(node), node.func_lineno)
+            raise LintWarn(msg.format(node), node.func_lineno, node.source_file)
 
 
 @linter.ensure_static_graph
@@ -138,7 +138,7 @@ def check_valid_transitions(graph):
     )
     for node in graph:
         if node.type != "end" and node.has_tail_next and node.invalid_tail_next:
-            raise LintWarn(msg.format(node), node.tail_next_lineno)
+            raise LintWarn(msg.format(node), node.tail_next_lineno, node.source_file)
 
 
 @linter.ensure_static_graph
@@ -151,7 +151,11 @@ def check_unknown_transitions(graph):
     for node in graph:
         unknown = [n for n in node.out_funcs if n not in graph]
         if unknown:
-            raise LintWarn(msg.format(node, step=unknown[0]), node.tail_next_lineno)
+            raise LintWarn(
+                msg.format(node, step=unknown[0]),
+                node.tail_next_lineno,
+                node.source_file,
+            )
 
 
 @linter.ensure_acyclicity
@@ -167,7 +171,9 @@ def check_for_acyclicity(graph):
         for n in node.out_funcs:
             if n in seen:
                 path = "->".join(seen + [n])
-                raise LintWarn(msg.format(path), node.tail_next_lineno)
+                raise LintWarn(
+                    msg.format(path), node.tail_next_lineno, node.source_file
+                )
             else:
                 check_path(graph[n], seen + [n])
 
@@ -195,7 +201,7 @@ def check_for_orphans(graph):
     orphans = nodeset - seen
     if orphans:
         orphan = graph[list(orphans)[0]]
-        raise LintWarn(msg.format(orphan), orphan.func_lineno)
+        raise LintWarn(msg.format(orphan), orphan.func_lineno, orphan.source_file)
 
 
 @linter.ensure_static_graph
@@ -230,7 +236,9 @@ def check_split_join_balance(graph):
             if split_stack:
                 _, split_roots = split_stack.pop()
                 roots = ", ".join(split_roots)
-                raise LintWarn(msg0.format(roots=roots))
+                raise LintWarn(
+                    msg0.format(roots=roots), node.func_lineno, node.source_file
+                )
         elif node.type == "join":
             if split_stack:
                 _, split_roots = split_stack[-1]
@@ -243,9 +251,10 @@ def check_split_join_balance(graph):
                             node, paths=paths, num_roots=len(split_roots), roots=roots
                         ),
                         node.func_lineno,
+                        node.source_file,
                     )
             else:
-                raise LintWarn(msg2.format(node), node.func_lineno)
+                raise LintWarn(msg2.format(node), node.func_lineno, node.source_file)
 
         # check that incoming steps come from the same lineage
         # (no cross joins)
@@ -256,7 +265,7 @@ def check_split_join_balance(graph):
                 return tuple(graph[n].split_parents)
 
             if not all_equal(map(parents, node.in_funcs)):
-                raise LintWarn(msg3.format(node), node.func_lineno)
+                raise LintWarn(msg3.format(node), node.func_lineno, node.source_file)
 
         for n in node.out_funcs:
             traverse(graph[n], new_stack)
@@ -276,7 +285,9 @@ def check_empty_foreaches(graph):
         if node.type == "foreach":
             joins = [n for n in node.out_funcs if graph[n].type == "join"]
            if joins:
-                raise LintWarn(msg.format(node, join=joins[0]))
+                raise LintWarn(
+                    msg.format(node, join=joins[0]), node.func_lineno, node.source_file
+                )
 
 
 @linter.ensure_static_graph
@@ -290,7 +301,7 @@ def check_parallel_step_after_next(graph):
         if node.parallel_foreach and not all(
             graph[out_node].parallel_step for out_node in node.out_funcs
         ):
-            raise LintWarn(msg.format(node))
+            raise LintWarn(msg.format(node), node.func_lineno, node.source_file)
 
 
 @linter.ensure_static_graph
@@ -303,7 +314,9 @@ def check_join_followed_by_parallel_step(graph):
     )
     for node in graph:
         if node.parallel_step and not graph[node.out_funcs[0]].type == "join":
-            raise LintWarn(msg.format(node.out_funcs[0]))
+            raise LintWarn(
+                msg.format(node.out_funcs[0]), node.func_lineno, node.source_file
+            )
 
 
 @linter.ensure_static_graph
@@ -318,7 +331,9 @@ def check_parallel_foreach_calls_parallel_step(graph):
             for node2 in graph:
                 if node2.out_funcs and node.name in node2.out_funcs:
                     if not node2.parallel_foreach:
-                        raise LintWarn(msg.format(node, node2))
+                        raise LintWarn(
+                            msg.format(node, node2), node.func_lineno, node.source_file
+                        )
 
 
 @linter.ensure_non_nested_foreach
@@ -331,4 +346,4 @@ def check_nested_foreach(graph):
     for node in graph:
         if node.type == "foreach":
             if any(graph[p].type == "foreach" for p in node.split_parents):
-                raise LintWarn(msg.format(node))
+                raise LintWarn(msg.format(node), node.func_lineno, node.source_file)
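
Every hunk above follows one pattern: each LintWarn is now raised with the offending node's line number and source file in addition to the message, so lint failures can point back at a location in the user's flow file. The new LintWarn constructor itself is not part of this diff; a minimal hypothetical sketch of an exception carrying that position information:

# Hypothetical sketch only; the real LintWarn subclasses MetaflowException
# in metaflow/lint.py and its exact fields may differ.
class LintWarn(Exception):
    def __init__(self, msg, line=None, source_file=None):
        super().__init__(msg)
        self.line = line                # line of the offending step function
        self.source_file = source_file  # file in which the flow is defined

With both fields attached, callers can render "file:line: message" style diagnostics instead of a bare message.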
metaflow/metadata_provider/heartbeat.py CHANGED
@@ -1,11 +1,12 @@
+import json
 import time
+from threading import Thread
+
 import requests
-import json
 
-from threading import Thread
-from metaflow.sidecar import MessageTypes, Message
-from metaflow.metaflow_config import SERVICE_HEADERS
 from metaflow.exception import MetaflowException
+from metaflow.metaflow_config import SERVICE_HEADERS
+from metaflow.sidecar import Message, MessageTypes
 
 HB_URL_KEY = "hb_url"
 
@@ -52,13 +53,27 @@ class MetadataHeartBeat(object):
                 retry_counter = 0
             except HeartBeatException as e:
                 retry_counter = retry_counter + 1
-                time.sleep(4**retry_counter)
+                time.sleep(1.5**retry_counter)
 
     def _heartbeat(self):
         if self.hb_url is not None:
-            response = requests.post(
-                url=self.hb_url, data="{}", headers=self.headers.copy()
-            )
+            try:
+                response = requests.post(
+                    url=self.hb_url, data="{}", headers=self.headers.copy()
+                )
+            except requests.exceptions.ConnectionError as e:
+                raise HeartBeatException(
+                    "HeartBeat request (%s) failed" " (ConnectionError)" % (self.hb_url)
+                )
+            except requests.exceptions.Timeout as e:
+                raise HeartBeatException(
+                    "HeartBeat request (%s) failed" " (Timeout)" % (self.hb_url)
+                )
+            except requests.exceptions.RequestException as e:
+                raise HeartBeatException(
+                    "HeartBeat request (%s) failed"
+                    " (RequestException) %s" % (self.hb_url, str(e))
+                )
             # Unfortunately, response.json() returns a string that we need
             # to cast to json; however when the request encounters an error
             # the return type is a json blob :/
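
Two behavioral changes stand out here: the retry backoff base drops from 4 to 1.5, and the POST is wrapped so that connection errors, timeouts, and any other requests failure surface as HeartBeatException (feeding the existing retry loop) instead of crashing the heartbeat sidecar. The backoff change is easy to quantify:

# Sleep in seconds before retry n: old (4**n) vs. new (1.5**n) schedule
for n in range(1, 6):
    print(n, 4**n, round(1.5**n, 2))
# 1 4 1.5
# 2 16 2.25
# 3 64 3.38
# 4 256 5.06
# 5 1024 7.59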
metaflow/metaflow_config.py CHANGED
@@ -356,6 +356,8 @@ KUBERNETES_PERSISTENT_VOLUME_CLAIMS = from_conf(
 KUBERNETES_SECRETS = from_conf("KUBERNETES_SECRETS", "")
 # Default labels for kubernetes pods
 KUBERNETES_LABELS = from_conf("KUBERNETES_LABELS", "")
+# Default annotations for kubernetes pods
+KUBERNETES_ANNOTATIONS = from_conf("KUBERNETES_ANNOTATIONS", "")
 # Default GPU vendor to use by K8S jobs created by Metaflow (supports nvidia, amd)
 KUBERNETES_GPU_VENDOR = from_conf("KUBERNETES_GPU_VENDOR", "nvidia")
 # Default container image for K8S
@@ -378,6 +380,8 @@ KUBERNETES_PORT = from_conf("KUBERNETES_PORT", None)
 KUBERNETES_CPU = from_conf("KUBERNETES_CPU", None)
 KUBERNETES_MEMORY = from_conf("KUBERNETES_MEMORY", None)
 KUBERNETES_DISK = from_conf("KUBERNETES_DISK", None)
+# Default kubernetes QoS class
+KUBERNETES_QOS = from_conf("KUBERNETES_QOS", "burstable")
 
 ARGO_WORKFLOWS_KUBERNETES_SECRETS = from_conf("ARGO_WORKFLOWS_KUBERNETES_SECRETS", "")
 ARGO_WORKFLOWS_ENV_VARS_TO_SKIP = from_conf("ARGO_WORKFLOWS_ENV_VARS_TO_SKIP", "")
@@ -439,7 +443,7 @@ ESCAPE_HATCH_WARNING = from_conf("ESCAPE_HATCH_WARNING", True)
 ###
 # Debug configuration
 ###
-DEBUG_OPTIONS = ["subcommand", "sidecar", "s3client", "tracing", "stubgen"]
+DEBUG_OPTIONS = ["subcommand", "sidecar", "s3client", "tracing", "stubgen", "userconf"]
 
 for typ in DEBUG_OPTIONS:
     vars()["DEBUG_%s" % typ.upper()] = from_conf("DEBUG_%s" % typ.upper(), False)
@@ -509,6 +513,11 @@ MAX_CPU_PER_TASK = from_conf("MAX_CPU_PER_TASK")
 # lexicographic ordering of attempts. This won't work if MAX_ATTEMPTS > 99.
 MAX_ATTEMPTS = 6
 
+# Feature flag (experimental features that are *explicitly* unsupported)
+
+# Process configs even when using the click_api for Runner/Deployer
+CLICK_API_PROCESS_CONFIG = from_conf("CLICK_API_PROCESS_CONFIG", False)
+
 
 # PINNED_CONDA_LIBS are the libraries that metaflow depends on for execution
 # and are needed within a conda environment
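
Both new Kubernetes knobs follow the from_conf convention used throughout this file, which resolves a value from the Metaflow configuration or a METAFLOW_-prefixed environment variable before falling back to the default. A usage sketch (the env-var naming follows that convention; the values are illustrative):

import os

# Assumed usage: annotations as key=value pairs, QoS class overriding "burstable"
os.environ["METAFLOW_KUBERNETES_ANNOTATIONS"] = "team=ml,owner=data-platform"
os.environ["METAFLOW_KUBERNETES_QOS"] = "guaranteed"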
metaflow/multicore_utils.py CHANGED
@@ -6,7 +6,18 @@ from tempfile import NamedTemporaryFile
 import time
 import metaflow.tracing as tracing
 
-from typing import Any, Callable, Iterable, Iterator, List, Optional
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    NoReturn,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 try:
     # Python 2
@@ -30,7 +41,13 @@ class MulticoreException(Exception):
     pass
 
 
-def _spawn(func, arg, dir):
+_A = TypeVar("_A")
+_R = TypeVar("_R")
+
+
+def _spawn(
+    func: Callable[[_A], _R], arg: _A, dir: Optional[str]
+) -> Union[Tuple[int, str], NoReturn]:
     with NamedTemporaryFile(prefix="parallel_map_", dir=dir, delete=False) as tmpfile:
         output_file = tmpfile.name
 
@@ -63,11 +80,11 @@ def _spawn(func, arg, dir):
 
 
 def parallel_imap_unordered(
-    func: Callable[[Any], Any],
-    iterable: Iterable[Any],
+    func: Callable[[_A], _R],
+    iterable: Iterable[_A],
     max_parallel: Optional[int] = None,
     dir: Optional[str] = None,
-) -> Iterator[Any]:
+) -> Iterator[_R]:
     """
     Parallelizes execution of a function using multiprocessing. The result
     order is not guaranteed.
@@ -79,9 +96,9 @@ def parallel_imap_unordered(
     iterable : Iterable[Any]
         Iterable over arguments to pass to fun
     max_parallel int, optional, default None
-        Maximum parallelism. If not specified, uses the number of CPUs
+        Maximum parallelism. If not specified, it uses the number of CPUs
     dir : str, optional, default None
-        If specified, directory where temporary files are created
+        If specified, it's the directory where temporary files are created
 
     Yields
     ------
@@ -121,14 +138,14 @@
 
 
 def parallel_map(
-    func: Callable[[Any], Any],
-    iterable: Iterable[Any],
+    func: Callable[[_A], _R],
+    iterable: Iterable[_A],
     max_parallel: Optional[int] = None,
     dir: Optional[str] = None,
-) -> List[Any]:
+) -> List[_R]:
     """
     Parallelizes execution of a function using multiprocessing. The result
-    order is that of the arguments in `iterable`
+    order is that of the arguments in `iterable`.
 
     Parameters
     ----------
@@ -137,9 +154,9 @@ def parallel_map(
     iterable : Iterable[Any]
         Iterable over arguments to pass to fun
     max_parallel int, optional, default None
-        Maximum parallelism. If not specified, uses the number of CPUs
+        Maximum parallelism. If not specified, it uses the number of CPUs
     dir : str, optional, default None
-        If specified, directory where temporary files are created
+        If specified, it's the directory where temporary files are created
 
     Returns
     -------
@@ -155,4 +172,4 @@
     res = parallel_imap_unordered(
         wrapper, enumerate(iterable), max_parallel=max_parallel, dir=dir
     )
-    return [r for idx, r in sorted(res)]
+    return [r for _, r in sorted(res)]
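
The typing changes are purely annotational: the _A and _R TypeVars let type checkers infer that parallel_map(f, xs) returns List[R] when f is Callable[[A], R]; runtime behavior is unchanged. A small usage sketch (fork-based under the hood, so POSIX only):

from metaflow.multicore_utils import parallel_map

# Result order matches the input order, inferred as List[int] under the new signature
squares = parallel_map(lambda x: x * x, range(8), max_parallel=4)
assert squares == [0, 1, 4, 9, 16, 25, 36, 49]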
metaflow/package.py CHANGED
@@ -6,6 +6,7 @@ import time
 import json
 from io import BytesIO
 
+from .user_configs.config_parameters import CONFIG_FILE, dump_config_values
 from .extension_support import EXT_PKG, package_mfext_all
 from .metaflow_config import DEFAULT_PACKAGE_SUFFIXES
 from .exception import MetaflowException
@@ -151,11 +152,23 @@ class MetaflowPackage(object):
         for path_tuple in self._walk(flowdir, suffixes=self.suffixes):
             yield path_tuple
 
+    def _add_configs(self, tar):
+        buf = BytesIO()
+        buf.write(json.dumps(dump_config_values(self._flow)).encode("utf-8"))
+        self._add_file(tar, os.path.basename(CONFIG_FILE), buf)
+
     def _add_info(self, tar):
-        info = tarfile.TarInfo(os.path.basename(INFO_FILE))
-        env = self.environment.get_environment_info(include_ext_info=True)
         buf = BytesIO()
-        buf.write(json.dumps(env).encode("utf-8"))
+        buf.write(
+            json.dumps(
+                self.environment.get_environment_info(include_ext_info=True)
+            ).encode("utf-8")
+        )
+        self._add_file(tar, os.path.basename(INFO_FILE), buf)
+
+    @staticmethod
+    def _add_file(tar, filename, buf):
+        info = tarfile.TarInfo(filename)
         buf.seek(0)
         info.size = len(buf.getvalue())
         # Setting this default to Dec 3, 2019
@@ -175,6 +188,7 @@ class MetaflowPackage(object):
             fileobj=buf, mode="w:gz", compresslevel=3, dereference=True
         ) as tar:
             self._add_info(tar)
+            self._add_configs(tar)
             for path, arcname in self.path_tuples():
                 tar.add(path, arcname=arcname, recursive=False, filter=no_mtime)
 
metaflow/parameters.py CHANGED
@@ -3,7 +3,7 @@ import json
 from contextlib import contextmanager
 from threading import local
 
-from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union
+from typing import Any, Callable, Dict, NamedTuple, Optional, TYPE_CHECKING, Type, Union
 
 from metaflow._vendor import click
 
@@ -14,6 +14,9 @@ from .exception import (
     MetaflowException,
 )
 
+if TYPE_CHECKING:
+    from .user_configs.config_parameters import ConfigValue
+
 try:
     # Python2
     strtype = basestring
@@ -32,6 +35,7 @@ ParameterContext = NamedTuple(
         ("parameter_name", str),
         ("logger", Callable[..., None]),
         ("ds_type", str),
+        ("configs", Optional["ConfigValue"]),
     ],
 )
 
@@ -72,6 +76,16 @@ def flow_context(flow_cls):
 context_proto = None
 
 
+def replace_flow_context(flow_cls):
+    """
+    Replace the current flow context with a new flow class. This is used
+    when we change the current flow class after having run user configuration functions
+    """
+    current_flow.flow_cls_stack = current_flow.flow_cls_stack[1:]
+    current_flow.flow_cls_stack.insert(0, flow_cls)
+    current_flow.flow_cls = current_flow.flow_cls_stack[0]
+
+
 class JSONTypeClass(click.ParamType):
     name = "JSON"
 
@@ -151,6 +165,7 @@ class DeployTimeField(object):
         return self._check_type(val, deploy_time)
 
     def _check_type(self, val, deploy_time):
+
         # it is easy to introduce a deploy-time function that accidentally
         # returns a value whose type is not compatible with what is defined
         # in Parameter. Let's catch those mistakes early here, instead of
@@ -158,7 +173,7 @@
 
         # note: this doesn't work with long in Python2 or types defined as
         # click types, e.g. click.INT
-        TYPES = {bool: "bool", int: "int", float: "float", list: "list"}
+        TYPES = {bool: "bool", int: "int", float: "float", list: "list", dict: "dict"}
 
         msg = (
             "The value returned by the deploy-time function for "
@@ -166,7 +181,12 @@
             % (self.parameter_name, self.field)
         )
 
-        if self.parameter_type in TYPES:
+        if isinstance(self.parameter_type, list):
+            if not any(isinstance(val, x) for x in self.parameter_type):
+                msg += "Expected one of the following %s." % TYPES[self.parameter_type]
+                raise ParameterFieldTypeMismatch(msg)
+            return str(val) if self.return_str else val
+        elif self.parameter_type in TYPES:
             if type(val) != self.parameter_type:
                 msg += "Expected a %s." % TYPES[self.parameter_type]
                 raise ParameterFieldTypeMismatch(msg)
@@ -204,12 +224,18 @@
 def deploy_time_eval(value):
     if isinstance(value, DeployTimeField):
         return value(deploy_time=True)
+    elif isinstance(value, DelayedEvaluationParameter):
+        return value(return_str=True)
     else:
         return value
 
 
 # this is called by cli.main
-def set_parameter_context(flow_name, echo, datastore):
+def set_parameter_context(flow_name, echo, datastore, configs):
+    from .user_configs.config_parameters import (
+        ConfigValue,
+    )  # Prevent circular dependency
+
     global context_proto
     context_proto = ParameterContext(
         flow_name=flow_name,
@@ -217,6 +243,7 @@ def set_parameter_context(flow_name, echo, datastore):
         parameter_name=None,
         logger=echo,
         ds_type=datastore.TYPE,
+        configs=ConfigValue(dict(configs)),
     )
 
 
@@ -273,7 +300,11 @@ class Parameter(object):
     ----------
     name : str
         User-visible parameter name.
-    default : str or float or int or bool or `JSONType` or a function.
+    default : Union[str, float, int, bool, Dict[str, Any],
+                Callable[
+                    [ParameterContext], Union[str, float, int, bool, Dict[str, Any]]
+                ],
+            ], optional, default None
         Default value for the parameter. Use a special `JSONType` class to
         indicate that the value must be a valid JSON object. A function
         implies that the parameter corresponds to a *deploy-time parameter*.
@@ -282,15 +313,19 @@ class Parameter(object):
         If `default` is not specified, define the parameter type. Specify
         one of `str`, `float`, `int`, `bool`, or `JSONType`. If None, defaults
         to the type of `default` or `str` if none specified.
-    help : str, optional
+    help : str, optional, default None
         Help text to show in `run --help`.
-    required : bool, default False
-        Require that the user specified a value for the parameter.
-        `required=True` implies that the `default` is not used.
-    show_default : bool, default True
-        If True, show the default value in the help text.
+    required : bool, optional, default None
+        Require that the user specified a value for the parameter. Note that if
+        a default is provide, the required flag is ignored.
+        A value of None is equivalent to False.
+    show_default : bool, optional, default None
+        If True, show the default value in the help text. A value of None is equivalent
+        to True.
     """
 
+    IS_CONFIG_PARAMETER = False
+
     def __init__(
         self,
         name: str,
@@ -301,31 +336,62 @@
                 int,
                 bool,
                 Dict[str, Any],
-                Callable[[], Union[str, float, int, bool, Dict[str, Any]]],
+                Callable[
+                    [ParameterContext], Union[str, float, int, bool, Dict[str, Any]]
+                ],
             ]
         ] = None,
         type: Optional[
             Union[Type[str], Type[float], Type[int], Type[bool], JSONTypeClass]
         ] = None,
         help: Optional[str] = None,
-        required: bool = False,
-        show_default: bool = True,
+        required: Optional[bool] = None,
+        show_default: Optional[bool] = None,
         **kwargs: Dict[str, Any]
     ):
         self.name = name
         self.kwargs = kwargs
-        for k, v in {
+        self._override_kwargs = {
             "default": default,
             "type": type,
             "help": help,
             "required": required,
             "show_default": show_default,
-        }.items():
-            if v is not None:
-                self.kwargs[k] = v
+        }
+
+    def init(self, ignore_errors=False):
+        # Prevent circular import
+        from .user_configs.config_parameters import (
+            resolve_delayed_evaluator,
+            unpack_delayed_evaluator,
+        )
+
+        # Resolve any value from configurations
+        self.kwargs, _ = unpack_delayed_evaluator(
+            self.kwargs, ignore_errors=ignore_errors
+        )
+        # Do it one item at a time so errors are ignored at that level (as opposed to
+        # at the entire kwargs level)
+        self.kwargs = {
+            k: resolve_delayed_evaluator(v, ignore_errors=ignore_errors)
+            for k, v in self.kwargs.items()
+        }
+
+        # This was the behavior before configs: values specified in args would override
+        # stuff in kwargs which is what we implement here as well
+        for key, value in self._override_kwargs.items():
+            if value is not None:
+                self.kwargs[key] = resolve_delayed_evaluator(
+                    value, ignore_errors=ignore_errors
+                )
+        # Set two default values if no-one specified them
+        self.kwargs.setdefault("required", False)
+        self.kwargs.setdefault("show_default", True)
+
+        # Continue processing kwargs free of any configuration values :)
 
         # TODO: check that the type is one of the supported types
-        param_type = self.kwargs["type"] = self._get_type(kwargs)
+        param_type = self.kwargs["type"] = self._get_type(self.kwargs)
 
         reserved_params = [
             "params",
@@ -350,23 +416,27 @@
             raise MetaflowException(
                 "Parameter name '%s' is a reserved "
                 "word. Please use a different "
-                "name for your parameter." % (name)
+                "name for your parameter." % (self.name)
             )
 
         # make sure the user is not trying to pass a function in one of the
         # fields that don't support function-values yet
         for field in ("show_default", "separator", "required"):
-            if callable(kwargs.get(field)):
+            if callable(self.kwargs.get(field)):
                 raise MetaflowException(
                     "Parameter *%s*: Field '%s' cannot "
-                    "have a function as its value" % (name, field)
+                    "have a function as its value" % (self.name, field)
                 )
 
         # default can be defined as a function
         default_field = self.kwargs.get("default")
         if callable(default_field) and not isinstance(default_field, DeployTimeField):
             self.kwargs["default"] = DeployTimeField(
-                name, param_type, "default", self.kwargs["default"], return_str=True
+                self.name,
+                param_type,
+                "default",
+                self.kwargs["default"],
+                return_str=True,
             )
 
         # note that separator doesn't work with DeployTimeFields unless you
@@ -375,7 +445,7 @@
         if self.separator and not self.is_string_type:
             raise MetaflowException(
                 "Parameter *%s*: Separator is only allowed "
-                "for string parameters." % name
+                "for string parameters." % self.name
             )
 
     def __repr__(self):
@@ -431,7 +501,9 @@ def add_custom_parameters(deploy_mode=False):
         flow_cls = getattr(current_flow, "flow_cls", None)
         if flow_cls is None:
             return cmd
-        parameters = [p for _, p in flow_cls._get_parameters()]
+        parameters = [
+            p for _, p in flow_cls._get_parameters() if not p.IS_CONFIG_PARAMETER
+        ]
         for arg in parameters[::-1]:
             kwargs = arg.option_kwargs(deploy_mode)
             cmd.params.insert(0, click.Option(("--" + arg.name,), **kwargs))
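
Taken together, these changes wire the new user-config machinery into parameter resolution: deploy-time default functions are now typed as receiving a ParameterContext, which carries a configs ConfigValue alongside flow_name, logger, and ds_type, and the new Parameter.init() resolves any config-backed kwargs before validation. A hedged usage sketch (the Config named "training" and the attribute-style ConfigValue access are assumptions for illustration):

from metaflow import Config, FlowSpec, Parameter, step

class TrainFlow(FlowSpec):
    # Config supplying values; "training.json" is an assumed example file
    training = Config("training", default="training.json")

    # Deploy-time default: the callable receives a ParameterContext whose new
    # `configs` field exposes the flow's configs (access pattern assumed)
    learning_rate = Parameter(
        "learning_rate",
        default=lambda ctx: str(ctx.configs.training.lr),
    )

    @step
    def start(self):
        print("lr =", self.learning_rate)
        self.next(self.end)

    @step
    def end(self):
        pass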