ob-metaflow 2.16.8.2rc0__py2.py3-none-any.whl → 2.16.8.2rc2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow might be problematic. Click here for more details.

Files changed (61)
  1. metaflow/_vendor/click/core.py +4 -3
  2. metaflow/_vendor/imghdr/__init__.py +1 -7
  3. metaflow/cli.py +2 -11
  4. metaflow/cli_components/run_cmds.py +15 -0
  5. metaflow/client/core.py +1 -6
  6. metaflow/extension_support/__init__.py +3 -4
  7. metaflow/flowspec.py +113 -1
  8. metaflow/graph.py +134 -10
  9. metaflow/lint.py +70 -3
  10. metaflow/metaflow_environment.py +6 -14
  11. metaflow/package/__init__.py +9 -18
  12. metaflow/packaging_sys/__init__.py +43 -53
  13. metaflow/packaging_sys/backend.py +6 -21
  14. metaflow/packaging_sys/tar_backend.py +3 -16
  15. metaflow/packaging_sys/v1.py +21 -21
  16. metaflow/plugins/argo/argo_client.py +14 -31
  17. metaflow/plugins/argo/argo_workflows.py +22 -66
  18. metaflow/plugins/argo/argo_workflows_cli.py +2 -1
  19. metaflow/plugins/argo/argo_workflows_deployer_objects.py +0 -69
  20. metaflow/plugins/aws/step_functions/step_functions.py +6 -0
  21. metaflow/plugins/aws/step_functions/step_functions_deployer_objects.py +0 -30
  22. metaflow/plugins/cards/card_modules/basic.py +14 -2
  23. metaflow/plugins/cards/card_modules/convert_to_native_type.py +1 -7
  24. metaflow/plugins/kubernetes/kubernetes_decorator.py +1 -1
  25. metaflow/plugins/kubernetes/kubernetes_jobsets.py +28 -26
  26. metaflow/plugins/pypi/conda_decorator.py +2 -4
  27. metaflow/runner/click_api.py +7 -14
  28. metaflow/runner/deployer.py +7 -160
  29. metaflow/runner/subprocess_manager.py +12 -20
  30. metaflow/runtime.py +102 -27
  31. metaflow/task.py +46 -25
  32. metaflow/user_decorators/mutable_flow.py +1 -3
  33. metaflow/util.py +29 -0
  34. metaflow/vendor.py +6 -23
  35. metaflow/version.py +1 -1
  36. {ob_metaflow-2.16.8.2rc0.dist-info → ob_metaflow-2.16.8.2rc2.dist-info}/METADATA +2 -2
  37. {ob_metaflow-2.16.8.2rc0.dist-info → ob_metaflow-2.16.8.2rc2.dist-info}/RECORD +44 -61
  38. metaflow/_vendor/yaml/__init__.py +0 -427
  39. metaflow/_vendor/yaml/composer.py +0 -139
  40. metaflow/_vendor/yaml/constructor.py +0 -748
  41. metaflow/_vendor/yaml/cyaml.py +0 -101
  42. metaflow/_vendor/yaml/dumper.py +0 -62
  43. metaflow/_vendor/yaml/emitter.py +0 -1137
  44. metaflow/_vendor/yaml/error.py +0 -75
  45. metaflow/_vendor/yaml/events.py +0 -86
  46. metaflow/_vendor/yaml/loader.py +0 -63
  47. metaflow/_vendor/yaml/nodes.py +0 -49
  48. metaflow/_vendor/yaml/parser.py +0 -589
  49. metaflow/_vendor/yaml/reader.py +0 -185
  50. metaflow/_vendor/yaml/representer.py +0 -389
  51. metaflow/_vendor/yaml/resolver.py +0 -227
  52. metaflow/_vendor/yaml/scanner.py +0 -1435
  53. metaflow/_vendor/yaml/serializer.py +0 -111
  54. metaflow/_vendor/yaml/tokens.py +0 -104
  55. {ob_metaflow-2.16.8.2rc0.data → ob_metaflow-2.16.8.2rc2.data}/data/share/metaflow/devtools/Makefile +0 -0
  56. {ob_metaflow-2.16.8.2rc0.data → ob_metaflow-2.16.8.2rc2.data}/data/share/metaflow/devtools/Tiltfile +0 -0
  57. {ob_metaflow-2.16.8.2rc0.data → ob_metaflow-2.16.8.2rc2.data}/data/share/metaflow/devtools/pick_services.sh +0 -0
  58. {ob_metaflow-2.16.8.2rc0.dist-info → ob_metaflow-2.16.8.2rc2.dist-info}/WHEEL +0 -0
  59. {ob_metaflow-2.16.8.2rc0.dist-info → ob_metaflow-2.16.8.2rc2.dist-info}/entry_points.txt +0 -0
  60. {ob_metaflow-2.16.8.2rc0.dist-info → ob_metaflow-2.16.8.2rc2.dist-info}/licenses/LICENSE +0 -0
  61. {ob_metaflow-2.16.8.2rc0.dist-info → ob_metaflow-2.16.8.2rc2.dist-info}/top_level.txt +0 -0
@@ -719,7 +719,7 @@ class BaseCommand(object):
719
719
  prog_name=None,
720
720
  complete_var=None,
721
721
  standalone_mode=True,
722
- **extra
722
+ **extra,
723
723
  ):
724
724
  """This is the way to invoke a script with all the bells and
725
725
  whistles as a command line application. This will always terminate
@@ -1101,7 +1101,7 @@ class MultiCommand(Command):
1101
1101
  subcommand_metavar=None,
1102
1102
  chain=False,
1103
1103
  result_callback=None,
1104
- **attrs
1104
+ **attrs,
1105
1105
  ):
1106
1106
  Command.__init__(self, name, **attrs)
1107
1107
  if no_args_is_help is None:
@@ -1463,6 +1463,7 @@ class Parameter(object):
1463
1463
  parameter. The old callback format will still work, but it will
1464
1464
  raise a warning to give you a chance to migrate the code easier.
1465
1465
  """
1466
+
1466
1467
  param_type_name = "parameter"
1467
1468
 
1468
1469
  def __init__(
@@ -1708,7 +1709,7 @@ class Option(Parameter):
1708
1709
  hidden=False,
1709
1710
  show_choices=True,
1710
1711
  show_envvar=False,
1711
- **attrs
1712
+ **attrs,
1712
1713
  ):
1713
1714
  default_is_missing = attrs.get("default", _missing) is _missing
1714
1715
  Parameter.__init__(self, param_decls, type=type, **attrs)
@@ -6,13 +6,7 @@ import warnings
6
6
  __all__ = ["what"]
7
7
 
8
8
 
9
- # python-deadlib: Replace deprecation warning not to raise exception
10
- warnings.warn(
11
- f"{__name__} was removed in Python 3.13. "
12
- f"Please be aware that you are currently NOT using standard '{__name__}', "
13
- f"but instead a separately installed 'standard-{__name__}'.",
14
- DeprecationWarning, stacklevel=2
15
- )
9
+ warnings._deprecated(__name__, remove=(3, 13))
16
10
 
17
11
 
18
12
  #-------------------------#
metaflow/cli.py CHANGED
@@ -7,7 +7,6 @@ from datetime import datetime
7
7
 
8
8
  import metaflow.tracing as tracing
9
9
  from metaflow._vendor import click
10
- from metaflow.system import _system_logger, _system_monitor
11
10
 
12
11
  from . import decorators, lint, metaflow_version, parameters, plugins
13
12
  from .cli_args import cli_args
@@ -27,6 +26,7 @@ from .metaflow_config import (
27
26
  DEFAULT_PACKAGE_SUFFIXES,
28
27
  )
29
28
  from .metaflow_current import current
29
+ from metaflow.system import _system_monitor, _system_logger
30
30
  from .metaflow_environment import MetaflowEnvironment
31
31
  from .packaging_sys import MetaflowCodeContent
32
32
  from .plugins import (
@@ -38,9 +38,9 @@ from .plugins import (
38
38
  )
39
39
  from .pylint_wrapper import PyLint
40
40
  from .R import metaflow_r_version, use_r
41
+ from .util import get_latest_run_id, resolve_identity
41
42
  from .user_configs.config_options import LocalFileInput, config_options
42
43
  from .user_configs.config_parameters import ConfigValue
43
- from .util import get_latest_run_id, resolve_identity
44
44
 
45
45
  ERASE_TO_EOL = "\033[K"
46
46
  HIGHLIGHT = "red"
@@ -56,15 +56,6 @@ def echo_dev_null(*args, **kwargs):
56
56
 
57
57
 
58
58
  def echo_always(line, **kwargs):
59
- if kwargs.pop("wrap", False):
60
- import textwrap
61
-
62
- indent_str = INDENT if kwargs.get("indent", None) else ""
63
- effective_width = 80 - len(indent_str)
64
- wrapped = textwrap.wrap(line, width=effective_width, break_long_words=False)
65
- line = "\n".join(indent_str + l for l in wrapped)
66
- kwargs["indent"] = False
67
-
68
59
  kwargs["err"] = kwargs.get("err", True)
69
60
  if kwargs.pop("indent", None):
70
61
  line = "\n".join(INDENT + x for x in line.splitlines())
@@ -13,6 +13,8 @@ from ..package import MetaflowPackage
13
13
  from ..runtime import NativeRuntime
14
14
  from ..system import _system_logger
15
15
 
16
+ # from ..client.core import Run
17
+
16
18
  from ..tagging_util import validate_tags
17
19
  from ..util import get_latest_run_id, write_latest_run_id
18
20
 
@@ -230,6 +232,19 @@ def resume(
230
232
  step_to_rerun, ",".join(list(obj.graph.nodes.keys()))
231
233
  )
232
234
  )
235
+
236
+ ## TODO: instead of checking execution path here, can add a warning later
237
+ ## instead of throwing an error. This is for resuming a step which was not
238
+ ## taken inside a branch i.e. not present in the execution path.
239
+
240
+ # origin_run = Run(f"{obj.flow.name}/{origin_run_id}", _namespace_check=False)
241
+ # executed_steps = {step.path_components[-1] for step in origin_run}
242
+ # if step_to_rerun not in executed_steps:
243
+ # raise CommandException(
244
+ # f"Cannot resume from step '{step_to_rerun}'. This step was not "
245
+ # f"part of the original execution path for run '{origin_run_id}'."
246
+ # )
247
+
233
248
  steps_to_rerun = {step_to_rerun}
234
249
 
235
250
  if run_id:
metaflow/client/core.py CHANGED
@@ -831,12 +831,10 @@ class MetaflowCode(object):
831
831
  )
832
832
  self._code_obj = BytesIO(blobdata)
833
833
  self._info = MetaflowPackage.cls_get_info(self._code_metadata, self._code_obj)
834
- self._code_obj.seek(0)
835
834
  if self._info:
836
835
  self._flowspec = MetaflowPackage.cls_get_content(
837
836
  self._code_metadata, self._code_obj, self._info["script"]
838
837
  )
839
- self._code_obj.seek(0)
840
838
  else:
841
839
  raise MetaflowInternalError("Code package metadata is invalid.")
842
840
 
@@ -887,9 +885,7 @@ class MetaflowCode(object):
887
885
  TarFile for everything in this code package
888
886
  """
889
887
  if self._backend.type == "tgz":
890
- to_return = self._backend.cls_open(self._code_obj)
891
- self._code_obj.seek(0)
892
- return to_return
888
+ return self._backend.cls_open(self._code_obj)
893
889
  raise RuntimeError("Archive is not a tarball")
894
890
 
895
891
  def extract(self) -> TemporaryDirectory:
@@ -925,7 +921,6 @@ class MetaflowCode(object):
925
921
  MetaflowPackage.cls_extract_into(
926
922
  self._code_metadata, self._code_obj, tmp.name, ContentType.USER_CONTENT
927
923
  )
928
- self._code_obj.seek(0)
929
924
  return tmp
930
925
 
931
926
  @property
@@ -205,10 +205,9 @@ def package_mfext_all():
205
205
  # the packaged metaflow_extensions directory "self-contained" so that
206
206
  # python doesn't go and search other parts of the system for more
207
207
  # metaflow_extensions.
208
- if _all_packages:
209
- yield os.path.join(
210
- os.path.dirname(os.path.abspath(__file__)), "_empty_file.py"
211
- ), os.path.join(EXT_PKG, "__init__.py")
208
+ yield os.path.join(
209
+ os.path.dirname(os.path.abspath(__file__)), "_empty_file.py"
210
+ ), os.path.join(EXT_PKG, "__init__.py")
212
211
 
213
212
  for p in _all_packages:
214
213
  for path_tuple in package_mfext_package(p):
metaflow/flowspec.py CHANGED
@@ -788,6 +788,35 @@ class FlowSpec(metaclass=FlowSpecMeta):
788
788
  value = item if _is_primitive_type(item) else reprlib.Repr().repr(item)
789
789
  return basestring(value)[:MAXIMUM_FOREACH_VALUE_CHARS]
790
790
 
791
+ def _validate_switch_cases(self, switch_cases, step):
792
+ resolved_cases = {}
793
+ for case_key, step_method in switch_cases.items():
794
+ if isinstance(case_key, str) and case_key.startswith("config:"):
795
+ full_path = case_key[len("config:") :]
796
+ parts = full_path.split(".", 1)
797
+ if len(parts) == 2:
798
+ config_var_name, config_key_name = parts
799
+ try:
800
+ config_obj = getattr(self, config_var_name)
801
+ resolved_key = str(getattr(config_obj, config_key_name))
802
+ except AttributeError:
803
+ msg = (
804
+ "Step *{step}* references unknown config '{path}' "
805
+ "in switch case.".format(step=step, path=full_path)
806
+ )
807
+ raise InvalidNextException(msg)
808
+ else:
809
+ raise MetaflowInternalError(
810
+ "Invalid config path format in switch case."
811
+ )
812
+ else:
813
+ resolved_key = case_key
814
+
815
+ func_name = step_method.__func__.__name__
816
+ resolved_cases[resolved_key] = func_name
817
+
818
+ return resolved_cases
819
+
791
820
  def next(self, *dsts: Callable[..., None], **kwargs) -> None:
792
821
  """
793
822
  Indicates the next step to execute after this step has completed.
@@ -812,6 +841,15 @@ class FlowSpec(metaclass=FlowSpecMeta):
812
841
  evaluates to an iterator. A task will be launched for each value in the iterator and
813
842
  each task will execute the code specified by the step `foreach_step`.
814
843
 
844
+ - Switch statement:
845
+ ```
846
+ self.next({"case1": self.step_a, "case2": self.step_b}, condition='condition_variable')
847
+ ```
848
+ In this situation, `step_a` and `step_b` are methods in the current class decorated
849
+ with the `@step` decorator and `condition_variable` is a variable name in the current
850
+ class. The value of the condition variable determines which step to execute. If the
851
+ value doesn't match any of the dictionary keys, a RuntimeError is raised.
852
+
815
853
  Parameters
816
854
  ----------
817
855
  dsts : Callable[..., None]
@@ -827,6 +865,7 @@ class FlowSpec(metaclass=FlowSpecMeta):
827
865
 
828
866
  foreach = kwargs.pop("foreach", None)
829
867
  num_parallel = kwargs.pop("num_parallel", None)
868
+ condition = kwargs.pop("condition", None)
830
869
  if kwargs:
831
870
  kw = next(iter(kwargs))
832
871
  msg = (
@@ -843,6 +882,79 @@ class FlowSpec(metaclass=FlowSpecMeta):
843
882
  )
844
883
  raise InvalidNextException(msg)
845
884
 
885
+ # check: switch case using condition
886
+ if condition is not None:
887
+ if len(dsts) != 1 or not isinstance(dsts[0], dict) or not dsts[0]:
888
+ msg = (
889
+ "Step *{step}* has an invalid self.next() transition. "
890
+ "When using 'condition', the transition must be to a single, "
891
+ "non-empty dictionary mapping condition values to step methods.".format(
892
+ step=step
893
+ )
894
+ )
895
+ raise InvalidNextException(msg)
896
+
897
+ if not isinstance(condition, basestring):
898
+ msg = (
899
+ "Step *{step}* has an invalid self.next() transition. "
900
+ "The argument to 'condition' must be a string.".format(step=step)
901
+ )
902
+ raise InvalidNextException(msg)
903
+
904
+ if foreach is not None or num_parallel is not None:
905
+ msg = (
906
+ "Step *{step}* has an invalid self.next() transition. "
907
+ "Switch statements cannot be combined with foreach or num_parallel.".format(
908
+ step=step
909
+ )
910
+ )
911
+ raise InvalidNextException(msg)
912
+
913
+ switch_cases = dsts[0]
914
+
915
+ # Validate that condition variable exists
916
+ try:
917
+ condition_value = getattr(self, condition)
918
+ except AttributeError:
919
+ msg = (
920
+ "Condition variable *self.{var}* in step *{step}* "
921
+ "does not exist. Make sure you set self.{var} in this step.".format(
922
+ step=step, var=condition
923
+ )
924
+ )
925
+ raise InvalidNextException(msg)
926
+
927
+ resolved_switch_cases = self._validate_switch_cases(switch_cases, step)
928
+
929
+ if str(condition_value) not in resolved_switch_cases:
930
+ available_cases = list(resolved_switch_cases.keys())
931
+ raise RuntimeError(
932
+ f"Switch condition variable '{condition}' has value '{condition_value}' "
933
+ f"which is not in the available cases: {available_cases}"
934
+ )
935
+
936
+ # Get the chosen step and set transition directly
937
+ chosen_step = resolved_switch_cases[str(condition_value)]
938
+
939
+ # Validate that the chosen step exists
940
+ if not hasattr(self, chosen_step):
941
+ msg = (
942
+ "Step *{step}* specifies a switch transition to an "
943
+ "unknown step, *{name}*.".format(step=step, name=chosen_step)
944
+ )
945
+ raise InvalidNextException(msg)
946
+
947
+ self._transition = ([chosen_step], None)
948
+ return
949
+
950
+ # Check for an invalid transition: a dictionary used without a 'condition' parameter.
951
+ if len(dsts) == 1 and isinstance(dsts[0], dict):
952
+ msg = (
953
+ "Step *{step}* has an invalid self.next() transition. "
954
+ "Dictionary argument requires 'condition' parameter.".format(step=step)
955
+ )
956
+ raise InvalidNextException(msg)
957
+
846
958
  # check: all destinations are methods of this object
847
959
  funcs = []
848
960
  for i, dst in enumerate(dsts):
@@ -933,7 +1045,7 @@ class FlowSpec(metaclass=FlowSpecMeta):
933
1045
  self._foreach_var = foreach
934
1046
 
935
1047
  # check: non-keyword transitions are valid
936
- if foreach is None:
1048
+ if foreach is None and condition is None:
937
1049
  if len(dsts) < 1:
938
1050
  msg = (
939
1051
  "Step *{step}* has an invalid self.next() transition. "
metaflow/graph.py CHANGED
@@ -68,6 +68,8 @@ class DAGNode(object):
68
68
  self.has_tail_next = False
69
69
  self.invalid_tail_next = False
70
70
  self.num_args = 0
71
+ self.switch_cases = {}
72
+ self.condition = None
71
73
  self.foreach_param = None
72
74
  self.num_parallel = 0
73
75
  self.parallel_foreach = False
@@ -83,6 +85,56 @@ class DAGNode(object):
83
85
  def _expr_str(self, expr):
84
86
  return "%s.%s" % (expr.value.id, expr.attr)
85
87
 
88
+ def _parse_switch_dict(self, dict_node):
89
+ switch_cases = {}
90
+
91
+ if isinstance(dict_node, ast.Dict):
92
+ for key, value in zip(dict_node.keys, dict_node.values):
93
+ case_key = None
94
+
95
+ # handle string literals
96
+ if isinstance(key, ast.Str):
97
+ case_key = key.s
98
+ elif isinstance(key, ast.Constant) and isinstance(key.value, str):
99
+ case_key = key.value
100
+ elif isinstance(key, ast.Attribute):
101
+ if isinstance(key.value, ast.Attribute) and isinstance(
102
+ key.value.value, ast.Name
103
+ ):
104
+ # This handles self.config.some_key
105
+ if key.value.value.id == "self":
106
+ config_var = key.value.attr
107
+ config_key = key.attr
108
+ case_key = f"config:{config_var}.{config_key}"
109
+ else:
110
+ return None
111
+ else:
112
+ return None
113
+
114
+ # handle variables or other dynamic expressions - not allowed
115
+ elif isinstance(key, ast.Name):
116
+ return None
117
+ else:
118
+ # can't statically analyze this key
119
+ return None
120
+
121
+ if case_key is None:
122
+ return None
123
+
124
+ # extract the step name from the value
125
+ if isinstance(value, ast.Attribute) and isinstance(
126
+ value.value, ast.Name
127
+ ):
128
+ if value.value.id == "self":
129
+ step_name = value.attr
130
+ switch_cases[case_key] = step_name
131
+ else:
132
+ return None
133
+ else:
134
+ return None
135
+
136
+ return switch_cases if switch_cases else None
137
+
86
138
  def _parse(self, func_ast, lineno):
87
139
  self.num_args = len(func_ast.args.args)
88
140
  tail = func_ast.body[-1]
@@ -104,7 +156,38 @@ class DAGNode(object):
104
156
  self.has_tail_next = True
105
157
  self.invalid_tail_next = True
106
158
  self.tail_next_lineno = lineno + tail.lineno - 1
107
- self.out_funcs = [e.attr for e in tail.value.args]
159
+
160
+ # Check if first argument is a dictionary (switch case)
161
+ if (
162
+ len(tail.value.args) == 1
163
+ and isinstance(tail.value.args[0], ast.Dict)
164
+ and any(k.arg == "condition" for k in tail.value.keywords)
165
+ ):
166
+ # This is a switch statement
167
+ switch_cases = self._parse_switch_dict(tail.value.args[0])
168
+ condition_name = None
169
+
170
+ # Get condition parameter
171
+ for keyword in tail.value.keywords:
172
+ if keyword.arg == "condition":
173
+ if isinstance(keyword.value, ast.Str):
174
+ condition_name = keyword.value.s
175
+ elif isinstance(keyword.value, ast.Constant) and isinstance(
176
+ keyword.value.value, str
177
+ ):
178
+ condition_name = keyword.value.value
179
+ break
180
+
181
+ if switch_cases and condition_name:
182
+ self.type = "split-switch"
183
+ self.condition = condition_name
184
+ self.switch_cases = switch_cases
185
+ self.out_funcs = list(switch_cases.values())
186
+ self.invalid_tail_next = False
187
+ return
188
+
189
+ else:
190
+ self.out_funcs = [e.attr for e in tail.value.args]
108
191
 
109
192
  keywords = dict(
110
193
  (k.arg, getattr(k.value, "s", None)) for k in tail.value.keywords
@@ -151,6 +234,7 @@ class DAGNode(object):
151
234
  has_tail_next={0.has_tail_next} (line {0.tail_next_lineno})
152
235
  invalid_tail_next={0.invalid_tail_next}
153
236
  foreach_param={0.foreach_param}
237
+ condition={0.condition}
154
238
  parallel_step={0.parallel_step}
155
239
  parallel_foreach={0.parallel_foreach}
156
240
  -> {out}""".format(
@@ -219,6 +303,8 @@ class FlowGraph(object):
219
303
  if node.type in ("split", "foreach"):
220
304
  node.split_parents = split_parents
221
305
  split_parents = split_parents + [node.name]
306
+ elif node.type == "split-switch":
307
+ node.split_parents = split_parents
222
308
  elif node.type == "join":
223
309
  # ignore joins without splits
224
310
  if split_parents:
@@ -259,15 +345,37 @@ class FlowGraph(object):
259
345
  def output_dot(self):
260
346
  def edge_specs():
261
347
  for node in self.nodes.values():
262
- for edge in node.out_funcs:
263
- yield "%s -> %s;" % (node.name, edge)
348
+ if node.type == "split-switch":
349
+ # Label edges for switch cases
350
+ for case_value, step_name in node.switch_cases.items():
351
+ yield (
352
+ '{0} -> {1} [label="{2}" color="blue" fontcolor="blue"];'.format(
353
+ node.name, step_name, case_value
354
+ )
355
+ )
356
+ else:
357
+ for edge in node.out_funcs:
358
+ yield "%s -> %s;" % (node.name, edge)
264
359
 
265
360
  def node_specs():
266
361
  for node in self.nodes.values():
267
- nodetype = "join" if node.num_args > 1 else node.type
268
- yield '"{0.name}"' '[ label = <<b>{0.name}</b> | <font point-size="10">{type}</font>> ' ' fontname = "Helvetica" ' ' shape = "record" ];'.format(
269
- node, type=nodetype
270
- )
362
+ if node.type == "split-switch":
363
+ # Hexagon shape for switch nodes
364
+ condition_label = (
365
+ f"switch: {node.condition}" if node.condition else "switch"
366
+ )
367
+ yield (
368
+ '"{0.name}" '
369
+ '[ label = <<b>{0.name}</b><br/><font point-size="9">{condition}</font>> '
370
+ ' fontname = "Helvetica" '
371
+ ' shape = "hexagon" '
372
+ ' style = "filled" fillcolor = "lightgreen" ];'
373
+ ).format(node, condition=condition_label)
374
+ else:
375
+ nodetype = "join" if node.num_args > 1 else node.type
376
+ yield '"{0.name}"' '[ label = <<b>{0.name}</b> | <font point-size="10">{type}</font>> ' ' fontname = "Helvetica" ' ' shape = "record" ];'.format(
377
+ node, type=nodetype
378
+ )
271
379
 
272
380
  return (
273
381
  "digraph {0.name} {{\n"
@@ -291,6 +399,8 @@ class FlowGraph(object):
291
399
  if node.parallel_foreach:
292
400
  return "split-parallel"
293
401
  return "split-foreach"
402
+ elif node.type == "split-switch":
403
+ return "split-switch"
294
404
  return "unknown" # Should never happen
295
405
 
296
406
  def node_to_dict(name, node):
@@ -325,6 +435,9 @@ class FlowGraph(object):
325
435
  d["foreach_artifact"] = node.foreach_param
326
436
  elif d["type"] == "split-parallel":
327
437
  d["num_parallel"] = node.num_parallel
438
+ elif d["type"] == "split-switch":
439
+ d["condition"] = node.condition
440
+ d["switch_cases"] = node.switch_cases
328
441
  if node.matching_join:
329
442
  d["matching_join"] = node.matching_join
330
443
  return d
@@ -339,8 +452,8 @@ class FlowGraph(object):
339
452
  steps_info[cur_name] = node_dict
340
453
  resulting_list.append(cur_name)
341
454
 
342
- if cur_node.type not in ("start", "linear", "join"):
343
- # We need to look at the different branches for this
455
+ node_type = node_to_type(cur_node)
456
+ if node_type in ("split-static", "split-foreach"):
344
457
  resulting_list.append(
345
458
  [
346
459
  populate_block(s, cur_node.matching_join)
@@ -348,8 +461,19 @@ class FlowGraph(object):
348
461
  ]
349
462
  )
350
463
  cur_name = cur_node.matching_join
464
+ elif node_type == "split-switch":
465
+ all_paths = [
466
+ populate_block(s, end_name) for s in cur_node.out_funcs
467
+ ]
468
+ resulting_list.append(all_paths)
469
+ cur_name = end_name
351
470
  else:
352
- cur_name = cur_node.out_funcs[0]
471
+ # handles only linear, start, and join steps.
472
+ if cur_node.out_funcs:
473
+ cur_name = cur_node.out_funcs[0]
474
+ else:
475
+ # handles terminal nodes or when we jump to 'end_name'.
476
+ break
353
477
  return resulting_list
354
478
 
355
479
  graph_structure = populate_block("start", "end")
metaflow/lint.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import re
2
2
  from .exception import MetaflowException
3
- from .util import all_equal
3
+ from .util import all_equal, get_split_branch_for_node
4
4
 
5
5
 
6
6
  class LintWarn(MetaflowException):
@@ -134,7 +134,13 @@ def check_valid_transitions(graph):
134
134
  msg = (
135
135
  "Step *{0.name}* specifies an invalid self.next() transition. "
136
136
  "Make sure the self.next() expression matches with one of the "
137
- "supported transition types."
137
+ "supported transition types:\n"
138
+ " • Linear: self.next(self.step_name)\n"
139
+ " • Fan-out: self.next(self.step1, self.step2, ...)\n"
140
+ " • Foreach: self.next(self.step, foreach='variable')\n"
141
+ " • Switch: self.next({{\"key\": self.step, ...}}, condition='variable')\n\n"
142
+ "For switch statements, keys must be string literals or config expressions "
143
+ "(self.config.key_name), not variables or numbers."
138
144
  )
139
145
  for node in graph:
140
146
  if node.type != "end" and node.has_tail_next and node.invalid_tail_next:
@@ -232,7 +238,13 @@ def check_split_join_balance(graph):
232
238
  new_stack = split_stack
233
239
  elif node.type in ("split", "foreach"):
234
240
  new_stack = split_stack + [("split", node.out_funcs)]
241
+ elif node.type == "split-switch":
242
+ # For a switch, continue traversal down each path with the same stack
243
+ for n in node.out_funcs:
244
+ traverse(graph[n], split_stack)
245
+ return
235
246
  elif node.type == "end":
247
+ new_stack = split_stack
236
248
  if split_stack:
237
249
  _, split_roots = split_stack.pop()
238
250
  roots = ", ".join(split_roots)
@@ -240,10 +252,25 @@ def check_split_join_balance(graph):
240
252
  msg0.format(roots=roots), node.func_lineno, node.source_file
241
253
  )
242
254
  elif node.type == "join":
255
+ new_stack = split_stack
243
256
  if split_stack:
244
257
  _, split_roots = split_stack[-1]
245
258
  new_stack = split_stack[:-1]
246
- if len(node.in_funcs) != len(split_roots):
259
+
260
+ # Identify the split this join corresponds to from its parentage.
261
+ split_node_name = node.split_parents[-1]
262
+
263
+ # Resolve each incoming function to its root branch from the split.
264
+ resolved_branches = set()
265
+ for in_node_name in node.in_funcs:
266
+ branch = get_split_branch_for_node(
267
+ graph, in_node_name, split_node_name
268
+ )
269
+ if branch:
270
+ resolved_branches.add(branch)
271
+
272
+ # compares the set of resolved branches against the expected branches from the split.
273
+ if len(resolved_branches) != len(split_roots):
247
274
  paths = ", ".join(node.in_funcs)
248
275
  roots = ", ".join(split_roots)
249
276
  raise LintWarn(
@@ -266,6 +293,8 @@ def check_split_join_balance(graph):
266
293
 
267
294
  if not all_equal(map(parents, node.in_funcs)):
268
295
  raise LintWarn(msg3.format(node), node.func_lineno, node.source_file)
296
+ else:
297
+ new_stack = split_stack
269
298
 
270
299
  for n in node.out_funcs:
271
300
  traverse(graph[n], new_stack)
@@ -273,6 +302,44 @@ def check_split_join_balance(graph):
273
302
  traverse(graph["start"], [])
274
303
 
275
304
 
305
+ @linter.ensure_static_graph
306
+ @linter.check
307
+ def check_switch_splits(graph):
308
+ """Check conditional split constraints"""
309
+ msg0 = (
310
+ "Step *{0.name}* is a switch split but defines {num} transitions. "
311
+ "Switch splits must define at least 2 transitions."
312
+ )
313
+ msg1 = "Step *{0.name}* is a switch split but has no condition variable."
314
+ msg2 = "Step *{0.name}* is a switch split but has no switch cases defined."
315
+
316
+ for node in graph:
317
+ if node.type == "split-switch":
318
+ # Check at least 2 outputs
319
+ if len(node.out_funcs) < 2:
320
+ raise LintWarn(
321
+ msg0.format(node, num=len(node.out_funcs)),
322
+ node.func_lineno,
323
+ node.source_file,
324
+ )
325
+
326
+ # Check condition exists
327
+ if not node.condition:
328
+ raise LintWarn(
329
+ msg1.format(node),
330
+ node.func_lineno,
331
+ node.source_file,
332
+ )
333
+
334
+ # Check switch cases exist
335
+ if not node.switch_cases:
336
+ raise LintWarn(
337
+ msg2.format(node),
338
+ node.func_lineno,
339
+ node.source_file,
340
+ )
341
+
342
+
276
343
  @linter.ensure_static_graph
277
344
  @linter.check
278
345
  def check_empty_foreaches(graph):