runnable 0.34.0a2__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,241 @@
1
+ import logging
2
+ from copy import deepcopy
3
+ from typing import Any, cast
4
+
5
+ from pydantic import Field, field_serializer, field_validator
6
+
7
+ from runnable import console, defaults
8
+ from runnable.datastore import Parameter
9
+ from runnable.graph import Graph, create_graph
10
+ from runnable.nodes import CompositeNode, TypeMapVariable
11
+
12
+ logger = logging.getLogger(defaults.LOGGER_NAME)
13
+
14
+
15
+ class ConditionalNode(CompositeNode):
16
+ """
17
+ parameter: name -> the parameter which is used for evaluation
18
+ default: Optional[branch] = branch to execute if nothing is matched.
19
+ branches: {
20
+ "case1" : branch1,
21
+ "case2": branch2,
22
+ }
23
+
24
+ Conceptually this is equal to:
25
+ match parameter:
26
+ case "case1":
27
+ branch1
28
+ case "case2":
29
+ branch2
30
+ case _:
31
+ default
32
+
33
+ """
34
+
35
+ node_type: str = Field(default="conditional", serialization_alias="type")
36
+
37
+ parameter: str # the name of the parameter should be isalnum
38
+ default: Graph | None = Field(default=None) # TODO: Think about the design of this
39
+ branches: dict[str, Graph]
40
+ # The keys of the branches should be isalnum()
41
+
42
+ @field_validator("parameter", mode="after")
43
+ @classmethod
44
+ def check_parameter(cls, parameter: str) -> str:
45
+ """
46
+ Validate that the parameter name is alphanumeric.
47
+
48
+ Args:
49
+ parameter (str): The parameter name to validate.
50
+
51
+ Raises:
52
+ ValueError: If the parameter name is not alphanumeric.
53
+
54
+ Returns:
55
+ str: The validated parameter name.
56
+ """
57
+ if not parameter.isalnum():
58
+ raise ValueError(f"Parameter '{parameter}' must be alphanumeric.")
59
+ return parameter
60
+
61
+ def get_parameter_value(self) -> str | int | bool | float:
62
+ """
63
+ Get the parameter value from the context.
64
+
65
+ Returns:
66
+ str | int | bool | float: The value of the parameter.
67
+ """
68
+ parameters: dict[str, Parameter] = self._context.run_log_store.get_parameters(
69
+ run_id=self._context.run_id
70
+ )
71
+
72
+ if self.parameter not in parameters:
73
+ raise Exception(f"Parameter {self.parameter} not found in parameters")
74
+
75
+ chosen_parameter_value = parameters[self.parameter].get_value()
76
+
77
+ assert isinstance(chosen_parameter_value, (int, float, bool, str)), (
78
+ f"Parameter '{self.parameter}' must be of type int, float, bool, or str, "
79
+ f"but got {type(chosen_parameter_value).__name__}."
80
+ )
81
+
82
+ return chosen_parameter_value
83
+
84
+ def get_summary(self) -> dict[str, Any]:
85
+ summary = {
86
+ "name": self.name,
87
+ "type": self.node_type,
88
+ "branches": [branch.get_summary() for branch in self.branches.values()],
89
+ "parameter": self.parameter,
90
+ "default": self.default.get_summary() if self.default else None,
91
+ }
92
+
93
+ return summary
94
+
95
+ @field_serializer("branches")
96
+ def ser_branches(self, branches: dict[str, Graph]) -> dict[str, Graph]:
97
+ ret: dict[str, Graph] = {}
98
+
99
+ for branch_name, branch in branches.items():
100
+ ret[branch_name.split(".")[-1]] = branch
101
+
102
+ return ret
103
+
104
+ @classmethod
105
+ def parse_from_config(cls, config: dict[str, Any]) -> "ConditionalNode":
106
+ internal_name = cast(str, config.get("internal_name"))
107
+
108
+ config_branches = config.pop("branches", {})
109
+ branches = {}
110
+ for branch_name, branch_config in config_branches.items():
111
+ sub_graph = create_graph(
112
+ deepcopy(branch_config),
113
+ internal_branch_name=internal_name + "." + branch_name,
114
+ )
115
+ branches[internal_name + "." + branch_name] = sub_graph
116
+
117
+ if not branches:
118
+ raise Exception("A conditional node should have branches")
119
+ return cls(branches=branches, **config)
120
+
121
+ def _get_branch_by_name(self, branch_name: str) -> Graph:
122
+ if branch_name in self.branches:
123
+ return self.branches[branch_name]
124
+
125
+ raise Exception(f"Branch {branch_name} does not exist")
126
+
127
+ def fan_out(self, map_variable: TypeMapVariable = None):
128
+ """
129
+ This method is restricted to creating branch logs.
130
+ """
131
+ parameter_value = self.get_parameter_value()
132
+
133
+ hit_once = False
134
+
135
+ for internal_branch_name, _ in self.branches.items():
136
+ # the match is done on the last part of the branch name
137
+ result = str(parameter_value) == internal_branch_name.split(".")[-1]
138
+
139
+ if not result:
140
+ # Need not create a branch log for this branch
141
+ continue
142
+
143
+ effective_branch_name = self._resolve_map_placeholders(
144
+ internal_branch_name, map_variable=map_variable
145
+ )
146
+
147
+ hit_once = True
148
+ branch_log = self._context.run_log_store.create_branch_log(
149
+ effective_branch_name
150
+ )
151
+
152
+ console.print(
153
+ f"Branch log created for {effective_branch_name}: {branch_log}"
154
+ )
155
+ branch_log.status = defaults.PROCESSING
156
+ self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)
157
+
158
+ if not hit_once:
159
+ raise Exception(
160
+ "None of the branches were true. Please check your evaluate statements"
161
+ )
162
+
163
+ def execute_as_graph(self, map_variable: TypeMapVariable = None):
164
+ """
165
+ This function does the actual execution of the sub-branches of the conditional node.
166
+
167
+ From a design perspective, this function should not be called if the execution is 3rd party orchestrated.
168
+
169
+ The modes that render the job specifications, do not need to interact with this node at all as they have their
170
+ own internal mechanisms of handing parallel states.
171
+ If they do not, you can find a way using as-is nodes as hack nodes.
172
+
173
+ The execution of a dag, could result in
174
+ * The dag being completely executed with a definite (fail, success) state in case of
175
+ local or local-container execution
176
+ * The dag being in a processing state with PROCESSING status in case of local-aws-batch
177
+
178
+ Only fail state is considered failure during this phase of execution.
179
+
180
+ Args:
181
+ executor (Executor): The Executor as per the use config
182
+ **kwargs: Optional kwargs passed around
183
+ """
184
+ self.fan_out(map_variable=map_variable)
185
+ parameter_value = self.get_parameter_value()
186
+
187
+ for internal_branch_name, branch in self.branches.items():
188
+ result = str(parameter_value) == internal_branch_name.split(".")[-1]
189
+
190
+ if result:
191
+ # if the condition is met, execute the graph
192
+ logger.debug(f"Executing graph for {branch}")
193
+ self._context.executor.execute_graph(branch, map_variable=map_variable)
194
+
195
+ self.fan_in(map_variable=map_variable)
196
+
197
+ def fan_in(self, map_variable: TypeMapVariable = None):
198
+ """
199
+ The general fan in method for a node of type Conditional.
200
+
201
+ 3rd party orchestrators should use this method to find the status of the composite step.
202
+
203
+ Args:
204
+ executor (BaseExecutor): The executor class as defined by the config
205
+ map_variable (dict, optional): If the node is part of a map. Defaults to None.
206
+ """
207
+ effective_internal_name = self._resolve_map_placeholders(
208
+ self.internal_name, map_variable=map_variable
209
+ )
210
+
211
+ step_success_bool: bool = True
212
+ parameter_value = self.get_parameter_value()
213
+
214
+ for internal_branch_name, _ in self.branches.items():
215
+ result = str(parameter_value) == internal_branch_name.split(".")[-1]
216
+
217
+ if not result:
218
+ # The branch would not have been executed
219
+ continue
220
+
221
+ effective_branch_name = self._resolve_map_placeholders(
222
+ internal_branch_name, map_variable=map_variable
223
+ )
224
+
225
+ branch_log = self._context.run_log_store.get_branch_log(
226
+ effective_branch_name, self._context.run_id
227
+ )
228
+
229
+ if branch_log.status != defaults.SUCCESS:
230
+ step_success_bool = False
231
+
232
+ step_log = self._context.run_log_store.get_step_log(
233
+ effective_internal_name, self._context.run_id
234
+ )
235
+
236
+ if step_success_bool: #  If none failed
237
+ step_log.status = defaults.SUCCESS
238
+ else:
239
+ step_log.status = defaults.FAIL
240
+
241
+ self._context.run_log_store.add_step_log(step_log, self._context.run_id)
@@ -20,6 +20,7 @@ from pydantic import (
20
20
  from pydantic.alias_generators import to_camel
21
21
  from ruamel.yaml import YAML
22
22
 
23
+ from extensions.nodes.conditional import ConditionalNode
23
24
  from extensions.nodes.nodes import MapNode, ParallelNode, TaskNode
24
25
 
25
26
  # TODO: Should be part of a wider refactor
@@ -307,6 +308,7 @@ class DagTask(BaseModelWIthConfig):
307
308
  template: str # Should be name of a container template or dag template
308
309
  arguments: Optional[Arguments] = Field(default=None)
309
310
  with_param: Optional[str] = Field(default=None)
311
+ when_param: Optional[str] = Field(default=None, serialization_alias="when")
310
312
  depends: Optional[str] = Field(default=None)
311
313
 
312
314
 
@@ -563,6 +565,8 @@ class ArgoExecutor(GenericPipelineExecutor):
563
565
  outputs: Optional[Outputs] = None
564
566
  if mode == "out" and node.node_type == "map":
565
567
  outputs = Outputs(parameters=[OutputParameter(name="iterate-on")])
568
+ if mode == "out" and node.node_type == "conditional":
569
+ outputs = Outputs(parameters=[OutputParameter(name="case")])
566
570
 
567
571
  container_template = ContainerTemplate(
568
572
  name=task_name,
@@ -722,6 +726,7 @@ class ArgoExecutor(GenericPipelineExecutor):
722
726
  # - We are using withParam and arguments of the map template to send that value in
723
727
  # - The map template should receive that value as a parameter into the template.
724
728
  # - The task then start to use it as inputs.parameters.iterate-on
729
+ # the when param should be an evaluation
725
730
 
726
731
  def _gather_tasks_for_dag_template(
727
732
  self,
@@ -767,9 +772,11 @@ class ArgoExecutor(GenericPipelineExecutor):
767
772
 
768
773
  self._templates.append(template_of_container)
769
774
 
770
- case "map" | "parallel":
771
- assert isinstance(working_on, MapNode) or isinstance(
772
- working_on, ParallelNode
775
+ case "map" | "parallel" | "conditional":
776
+ assert (
777
+ isinstance(working_on, MapNode)
778
+ or isinstance(working_on, ParallelNode)
779
+ or isinstance(working_on, ConditionalNode)
773
780
  )
774
781
  node_type = working_on.node_type
775
782
 
@@ -792,7 +799,8 @@ class ArgoExecutor(GenericPipelineExecutor):
792
799
  )
793
800
 
794
801
  # Add the composite task
795
- with_param = None
802
+ with_param: Optional[str] = None
803
+ when_param: Optional[str] = None
796
804
  added_parameters = parameters or []
797
805
  branches = {}
798
806
  if node_type == "map":
@@ -807,22 +815,34 @@ class ArgoExecutor(GenericPipelineExecutor):
807
815
  elif node_type == "parallel":
808
816
  assert isinstance(working_on, ParallelNode)
809
817
  branches = working_on.branches
818
+ elif node_type == "conditional":
819
+ assert isinstance(working_on, ConditionalNode)
820
+ branches = working_on.branches
821
+ when_param = (
822
+ f"{{{{tasks.{task_name}-fan-out.outputs.parameters.case}}}}"
823
+ )
810
824
  else:
811
825
  raise ValueError("Invalid node type")
812
826
 
813
827
  fan_in_depends = ""
814
828
 
815
829
  for name, branch in branches.items():
830
+ match_when = branch.internal_branch_name.split(".")[-1]
816
831
  name = (
817
832
  name.replace(" ", "-").replace(".", "-").replace("_", "-")
818
833
  )
819
834
 
835
+ if node_type == "conditional":
836
+ assert isinstance(working_on, ConditionalNode)
837
+ when_param = f"'{match_when}' == {{{{tasks.{task_name}-fan-out.outputs.parameters.case}}}}"
838
+
820
839
  branch_task = DagTask(
821
840
  name=f"{task_name}-{name}",
822
841
  template=f"{task_name}-{name}",
823
842
  depends=f"{task_name}-fan-out.Succeeded",
824
843
  arguments=Arguments(parameters=added_parameters),
825
844
  with_param=with_param,
845
+ when_param=when_param,
826
846
  )
827
847
  composite_template.dag.tasks.append(branch_task)
828
848
 
@@ -836,6 +856,8 @@ class ArgoExecutor(GenericPipelineExecutor):
836
856
  ),
837
857
  )
838
858
 
859
+ assert isinstance(branch, Graph)
860
+
839
861
  self._gather_tasks_for_dag_template(
840
862
  dag_template=branch_template,
841
863
  dag=branch,
@@ -862,28 +884,6 @@ class ArgoExecutor(GenericPipelineExecutor):
862
884
 
863
885
  self._templates.append(composite_template)
864
886
 
865
- case "torch":
866
- from extensions.nodes.torch import TorchNode
867
-
868
- assert isinstance(working_on, TorchNode)
869
- # TODO: Need to add multi-node functionality
870
- # Check notes on the torch node
871
-
872
- template_of_container = self._create_container_template(
873
- working_on,
874
- task_name=task_name,
875
- inputs=Inputs(parameters=parameters),
876
- )
877
- assert template_of_container.container is not None
878
-
879
- if working_on.node_type == "task":
880
- self._expose_secrets_to_task(
881
- working_on,
882
- container_template=template_of_container.container,
883
- )
884
-
885
- self._templates.append(template_of_container)
886
-
887
887
  self._handle_failures(
888
888
  working_on,
889
889
  dag,
@@ -1025,6 +1025,12 @@ class ArgoExecutor(GenericPipelineExecutor):
1025
1025
  with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
1026
1026
  json.dump(iterate_on.get_value(), myfile, indent=4)
1027
1027
 
1028
+ if node.node_type == "conditional":
1029
+ assert isinstance(node, ConditionalNode)
1030
+
1031
+ with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
1032
+ json.dump(node.get_parameter_value(), myfile, indent=4)
1033
+
1028
1034
  def fan_in(self, node: BaseNode, map_variable: TypeMapVariable = None):
1029
1035
  self._use_volumes()
1030
1036
  super().fan_in(node, map_variable)
runnable/__init__.py CHANGED
@@ -17,8 +17,9 @@ console.print(":runner: Lets go!!")
17
17
 
18
18
  task_console = Console(record=True)
19
19
 
20
- from runnable.sdk import ( # noqa
20
+ from runnable.sdk import ( # noqa;
21
21
  Catalog,
22
+ Conditional,
22
23
  Fail,
23
24
  Map,
24
25
  NotebookJob,
runnable/nodes.py CHANGED
@@ -8,6 +8,7 @@ import runnable.context as context
8
8
  from runnable import defaults, exceptions
9
9
  from runnable.datastore import StepLog
10
10
  from runnable.defaults import TypeMapVariable
11
+ from runnable.graph import Graph
11
12
 
12
13
  logger = logging.getLogger(defaults.LOGGER_NAME)
13
14
 
@@ -218,7 +219,7 @@ class BaseNode(ABC, BaseModel):
218
219
  """
219
220
 
220
221
  @abstractmethod
221
- def _get_branch_by_name(self, branch_name: str):
222
+ def _get_branch_by_name(self, branch_name: str) -> Graph:
222
223
  """
223
224
  Retrieve a branch by name.
224
225
 
runnable/sdk.py CHANGED
@@ -26,6 +26,7 @@ from rich.progress import (
26
26
  from rich.table import Column
27
27
  from typing_extensions import Self
28
28
 
29
+ from extensions.nodes.conditional import ConditionalNode
29
30
  from extensions.nodes.nodes import (
30
31
  FailNode,
31
32
  MapNode,
@@ -50,6 +51,7 @@ StepType = Union[
50
51
  "Parallel",
51
52
  "Map",
52
53
  "TorchTask",
54
+ "Conditional",
53
55
  ]
54
56
 
55
57
 
@@ -193,6 +195,9 @@ class BaseTask(BaseTraversal):
193
195
  "This method should be implemented in the child class"
194
196
  )
195
197
 
198
+ def as_pipeline(self) -> "Pipeline":
199
+ return Pipeline(steps=[self]) # type: ignore
200
+
196
201
 
197
202
  class PythonTask(BaseTask):
198
203
  """
@@ -283,14 +288,15 @@ class PythonTask(BaseTask):
283
288
 
284
289
 
285
290
  class TorchTask(BaseTask):
286
- entrypoint: str = Field(
287
- alias="entrypoint", default="torch.distributed.run", frozen=True
288
- )
289
- args_to_torchrun: Dict[str, Any] = Field(
290
- default_factory=dict, alias="args_to_torchrun"
291
- )
291
+ # entrypoint: str = Field(
292
+ # alias="entrypoint", default="torch.distributed.run", frozen=True
293
+ # )
294
+ # args_to_torchrun: Dict[str, Any] = Field(
295
+ # default_factory=dict, alias="args_to_torchrun"
296
+ # )
292
297
 
293
298
  script_to_call: str
299
+ accelerate_config_file: str
294
300
 
295
301
  @computed_field
296
302
  def command_type(self) -> str:
@@ -520,6 +526,53 @@ class Parallel(BaseTraversal):
520
526
  return node
521
527
 
522
528
 
529
+ class Conditional(BaseTraversal):
530
+ branches: Dict[str, "Pipeline"]
531
+ parameter: str # the name of the parameter should be isalnum
532
+
533
+ @field_validator("parameter")
534
+ @classmethod
535
+ def validate_parameter(cls, parameter: str) -> str:
536
+ if not parameter.isalnum():
537
+ raise AssertionError(
538
+ "The parameter name should be alphanumeric and not empty"
539
+ )
540
+ return parameter
541
+
542
+ @field_validator("branches")
543
+ @classmethod
544
+ def validate_branches(
545
+ cls, branches: Dict[str, "Pipeline"]
546
+ ) -> Dict[str, "Pipeline"]:
547
+ for branch_name in branches.keys():
548
+ if not branch_name.isalnum():
549
+ raise ValueError(f"Branch '{branch_name}' must be alphanumeric.")
550
+ return branches
551
+
552
+ @computed_field # type: ignore
553
+ @property
554
+ def graph_branches(self) -> Dict[str, graph.Graph]:
555
+ return {
556
+ name: pipeline._dag.model_copy() for name, pipeline in self.branches.items()
557
+ }
558
+
559
+ def create_node(self) -> ConditionalNode:
560
+ if not self.next_node:
561
+ if not (self.terminate_with_failure or self.terminate_with_success):
562
+ raise AssertionError(
563
+ "A node not being terminated must have a user defined next node"
564
+ )
565
+
566
+ node = ConditionalNode(
567
+ name=self.name,
568
+ branches=self.graph_branches,
569
+ internal_name="",
570
+ next_node=self.next_node,
571
+ parameter=self.parameter,
572
+ )
573
+ return node
574
+
575
+
523
576
  class Map(BaseTraversal):
524
577
  """
525
578
  A node that iterates over a list of items and executes a pipeline for each item.
@@ -543,7 +596,6 @@ class Map(BaseTraversal):
543
596
  iterate_on: str
544
597
  iterate_as: str
545
598
  reducer: Optional[str] = Field(default=None, alias="reducer")
546
- overrides: Dict[str, Any] = Field(default_factory=dict)
547
599
 
548
600
  @computed_field # type: ignore
549
601
  @property
@@ -564,7 +616,6 @@ class Map(BaseTraversal):
564
616
  next_node=self.next_node,
565
617
  iterate_on=self.iterate_on,
566
618
  iterate_as=self.iterate_as,
567
- overrides=self.overrides,
568
619
  reducer=self.reducer,
569
620
  )
570
621
 
@@ -984,13 +1035,14 @@ class PythonJob(BaseJob):
984
1035
 
985
1036
 
986
1037
  class TorchJob(BaseJob):
987
- entrypoint: str = Field(default="torch.distributed.run", frozen=True)
988
- args_to_torchrun: dict[str, str | bool | int | float] = Field(
989
- default_factory=dict
990
- ) # For example
1038
+ # entrypoint: str = Field(default="torch.distributed.run", frozen=True)
1039
+ # args_to_torchrun: dict[str, str | bool | int | float] = Field(
1040
+ # default_factory=dict
1041
+ # ) # For example
991
1042
  # {"nproc_per_node": 2, "nnodes": 1,}
992
1043
 
993
1044
  script_to_call: str # For example train/script.py
1045
+ accelerate_config_file: str
994
1046
 
995
1047
  def get_task(self) -> RunnableTask:
996
1048
  # Piggy bank on existing tasks as a hack
runnable/tasks.py CHANGED
@@ -5,7 +5,6 @@ import io
5
5
  import json
6
6
  import logging
7
7
  import os
8
- import runpy
9
8
  import subprocess
10
9
  import sys
11
10
  from datetime import datetime
@@ -357,16 +356,15 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
357
356
 
358
357
  class TorchTaskType(BaseTaskType):
359
358
  task_type: str = Field(default="torch", serialization_alias="command_type")
360
-
361
- entrypoint: str = Field(default="torch.distributed.run", frozen=True)
362
- args_to_torchrun: dict[str, str | bool] = Field(default_factory=dict) # For example
363
- # {"nproc_per_node": 2, "nnodes": 1,}
359
+ accelerate_config_file: str
364
360
 
365
361
  script_to_call: str # For example train/script.py
366
362
 
367
363
  def execute_command(
368
364
  self, map_variable: Dict[str, str | int | float] | None = None
369
365
  ) -> StepAttempt:
366
+ from accelerate.commands import launch
367
+
370
368
  attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
371
369
 
372
370
  with (
@@ -376,39 +374,37 @@ class TorchTaskType(BaseTaskType):
376
374
  self.expose_secrets() as _,
377
375
  ):
378
376
  try:
379
- entry_point_args = [self.entrypoint]
380
-
381
- for key, value in self.args_to_torchrun.items():
382
- entry_point_args.append(f"--{key}")
383
- if type(value) is not bool:
384
- entry_point_args.append(str(value))
385
-
386
- entry_point_args.append(self.script_to_call)
377
+ script_args = []
387
378
  for key, value in params.items():
388
- entry_point_args.append(f"--{key}")
389
- if type(value.value) is not bool: # type: ignore
390
- entry_point_args.append(str(value.value)) # type: ignore
379
+ script_args.append(f"--{key}")
380
+ if type(value.value) is not bool:
381
+ script_args.append(str(value.value))
391
382
 
392
383
  # TODO: Check the typing here
393
384
 
394
385
  logger.info("Calling the user script with the following parameters:")
395
- logger.info(entry_point_args)
386
+ logger.info(script_args)
396
387
  out_file = TeeIO()
397
388
  try:
398
389
  with contextlib.redirect_stdout(out_file):
399
- sys.argv = entry_point_args
400
- runpy.run_module(self.entrypoint, run_name="__main__")
390
+ parser = launch.launch_command_parser()
391
+ args = parser.parse_args([self.script_to_call])
392
+ args.training_script = self.script_to_call
393
+ args.config_file = self.accelerate_config_file
394
+ args.training_script_args = script_args
395
+
396
+ launch.launch_command(args)
401
397
  task_console.print(out_file.getvalue())
402
398
  except Exception as e:
403
399
  raise exceptions.CommandCallError(
404
- f"Call to entrypoint {self.entrypoint} with {self.script_to_call} did not succeed."
400
+ f"Call to script {self.script_to_call} did not succeed."
405
401
  ) from e
406
402
  finally:
407
403
  sys.argv = sys.argv[:1]
408
404
 
409
405
  attempt_log.status = defaults.SUCCESS
410
406
  except Exception as _e:
411
- msg = f"Call to entrypoint {self.entrypoint} with {self.script_to_call} did not succeed."
407
+ msg = f"Call to script: {self.script_to_call} did not succeed."
412
408
  attempt_log.message = msg
413
409
  task_console.print_exception(show_locals=False)
414
410
  task_console.log(_e, style=defaults.error_style)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: runnable
3
- Version: 0.34.0a2
3
+ Version: 0.35.0
4
4
  Summary: Add your description here
5
5
  Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
6
6
  License-File: LICENSE
@@ -27,8 +27,8 @@ Requires-Dist: ploomber-engine>=0.0.33; extra == 'notebook'
27
27
  Provides-Extra: s3
28
28
  Requires-Dist: cloudpathlib[s3]; extra == 's3'
29
29
  Provides-Extra: torch
30
+ Requires-Dist: accelerate>=1.5.2; extra == 'torch'
30
31
  Requires-Dist: torch>=2.6.0; extra == 'torch'
31
- Requires-Dist: torchvision>=0.21.0; extra == 'torch'
32
32
  Description-Content-Type: text/markdown
33
33
 
34
34
 
@@ -14,13 +14,12 @@ extensions/job_executor/local.py,sha256=3ZbCFXBvbLlMp10JTmQJJrjBKG2keHI6SH8hEvmH
14
14
  extensions/job_executor/local_container.py,sha256=1JcLJ0zrNSNHdubrSO9miN54iwvPLHqKMZ08aOC8WWo,6886
15
15
  extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqyUecIsb_Vc,286
16
16
  extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
+ extensions/nodes/conditional.py,sha256=m4DGxjqWpjNd2KQPAdVSJ6ridt1BDx2Lt6kmEQa9ghY,8594
17
18
  extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
18
19
  extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
19
- extensions/nodes/torch.py,sha256=64DTjdPNSJ8vfMwUN9h9Ly5g9qj-Bga7LSGrfCAO0BY,9389
20
- extensions/nodes/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
21
20
  extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
21
  extensions/pipeline_executor/__init__.py,sha256=wfigTL2T9OHrmE8b2Ydmb8h6hr-oF--Yc2FectC7WaY,24623
23
- extensions/pipeline_executor/argo.py,sha256=Xj3rasvJfgdEze_s3ILB77VY92NNk7iO8yT46A-_Y4c,37627
22
+ extensions/pipeline_executor/argo.py,sha256=17hHj3L5oIkoOpCSSbZlliLnOUoN5_JpK_DY0ELWXac,38233
24
23
  extensions/pipeline_executor/local.py,sha256=6oWUJ6b6NvIkpeQJBoCT1hbfX4_6WCB4HzMgHZ4ik1A,1887
25
24
  extensions/pipeline_executor/local_container.py,sha256=3kZ2QCsrq_YjH9dcAz8v05knKShQ_JtbIU-IA_-G538,12724
26
25
  extensions/pipeline_executor/mocked.py,sha256=0sMmypuvstBIv9uQg-WAcPrF3oOFpeEXNi6N8Nzdnl0,5680
@@ -42,7 +41,7 @@ extensions/secrets/dotenv.py,sha256=nADHXI6KJ_LUYOIe5EbtYH-21OBebSNVr0Pjb1GlZ7w,
42
41
  extensions/secrets/pyproject.toml,sha256=mLJNImNcBlbLKHh-0ugVWT9V83R4RibyyYDtBCSqVF4,282
43
42
  extensions/tasks/torch.py,sha256=oeXRkmuttFIAuBwH7-h4SOVXMDOZXX5mvqI2aFrR3Vo,10283
44
43
  extensions/tasks/torch_config.py,sha256=UjfMitT-TXASRDGR30I2vDRnyk7JQnR-5CsOVidjpSY,2833
45
- runnable/__init__.py,sha256=3ZKuvGEkY_zHVQlJtarXd4jkjICxjgnw-bbKN_5SiJI,691
44
+ runnable/__init__.py,sha256=eRXLgO-iiSUmNkjjzBjWdBP7Fp--I_vnImyhoGxZUek,709
46
45
  runnable/catalog.py,sha256=4msQxLhLKlsDDrHFnGauPYe-Or-q9g8_RYCn_4dpxaU,4466
47
46
  runnable/cli.py,sha256=3BiKSj95h2Drn__YlchMPZ5rBMafuRb2OGIsVpbsO5Y,8788
48
47
  runnable/context.py,sha256=by5uepmuCP0dmM9BmsliXihSes5QEFejwAsmekcqylE,1388
@@ -53,15 +52,15 @@ runnable/exceptions.py,sha256=LFbp0-Qxg2PAMLEVt7w2whhBxSG-5pzUEv5qN-Rc4_c,3003
53
52
  runnable/executor.py,sha256=Jr9yJtSH7CzjXJLWx3VWIUAQblstuGqzpFtajv7d39M,15348
54
53
  runnable/graph.py,sha256=poQz5zcvq89ju_u5sYlunQLPbHnXTaUmjcvstPwvT4U,16536
55
54
  runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
56
- runnable/nodes.py,sha256=QGHMznriEz4AcmntHICBZKrDT6zbc7WD1sV0MgwK10c,16691
55
+ runnable/nodes.py,sha256=CWfKVuGNaKSQpvFYYE1gEiTNouG0xPaA8KKaOxFr8EI,16733
57
56
  runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
58
57
  runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
59
- runnable/sdk.py,sha256=Cl6wVJj_pBnHmcszf-kh4nVqbiQaIruGJn06cm9epm4,35097
58
+ runnable/sdk.py,sha256=1gerGsq6EMSbDh2-Ey1vk6e0Sls55t9R29KlblNahi0,36793
60
59
  runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
61
- runnable/tasks.py,sha256=OW9pzjEKMRFpB256KJm__jWwsF37gs-tkIUcfnOTJwA,32382
60
+ runnable/tasks.py,sha256=lOtCninvosGI2bNIzblrzNa-lN7TMwel1KQ1g23M85A,32088
62
61
  runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
63
- runnable-0.34.0a2.dist-info/METADATA,sha256=DzGQTVqxRAN95MoyRc5TQXG_OC85uf6PH5NGtru3qSg,10170
64
- runnable-0.34.0a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
- runnable-0.34.0a2.dist-info/entry_points.txt,sha256=wKfW6aIWMQFlwrwpPBVWlMQDcxQmOupDKNkKyXoPFV4,1917
66
- runnable-0.34.0a2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
- runnable-0.34.0a2.dist-info/RECORD,,
62
+ runnable-0.35.0.dist-info/METADATA,sha256=CgZbaiNCY_mUrcdyOGYV_6zkVwSrGMzqbUdrKQ-LL0U,10166
63
+ runnable-0.35.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
64
+ runnable-0.35.0.dist-info/entry_points.txt,sha256=bLH1QXcc-G8xgJTi4wf6SYQnsG_BxRRvobwa9dYm-js,1935
65
+ runnable-0.35.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
66
+ runnable-0.35.0.dist-info/RECORD,,
@@ -14,6 +14,7 @@ local-container = extensions.job_executor.local_container:LocalContainerJobExecu
14
14
  mini-k8s-job = extensions.job_executor.k8s:MiniK8sJobExecutor
15
15
 
16
16
  [nodes]
17
+ conditional = extensions.nodes.conditional:ConditionalNode
17
18
  dag = extensions.nodes.nodes:DagNode
18
19
  fail = extensions.nodes.nodes:FailNode
19
20
  map = extensions.nodes.nodes:MapNode
@@ -21,7 +22,6 @@ parallel = extensions.nodes.nodes:ParallelNode
21
22
  stub = extensions.nodes.nodes:StubNode
22
23
  success = extensions.nodes.nodes:SuccessNode
23
24
  task = extensions.nodes.nodes:TaskNode
24
- torch = extensions.nodes.torch:TorchNode
25
25
 
26
26
  [pickler]
27
27
  pickle = runnable.pickler:NativePickler
extensions/nodes/torch.py DELETED
@@ -1,273 +0,0 @@
1
- import importlib
2
- import logging
3
- import os
4
- import random
5
- import string
6
- from datetime import datetime
7
- from pathlib import Path
8
- from typing import Any, Callable, Optional
9
-
10
- from pydantic import BaseModel, ConfigDict, Field, field_serializer
11
-
12
- from extensions.nodes.torch_config import EasyTorchConfig, TorchConfig
13
- from runnable import PythonJob, datastore, defaults
14
- from runnable.datastore import StepLog
15
- from runnable.nodes import ExecutableNode
16
- from runnable.tasks import PythonTaskType, create_task
17
- from runnable.utils import TypeMapVariable
18
-
19
- logger = logging.getLogger(defaults.LOGGER_NAME)
20
-
21
- try:
22
- from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
23
- from torch.distributed.launcher.api import LaunchConfig, elastic_launch
24
- except ImportError:
25
- logger.exception("Torch is not installed. Please install torch first.")
26
- raise Exception("Torch is not installed. Please install torch first.")
27
-
28
-
29
- def training_subprocess():
30
- """
31
- This function is called by the torch.distributed.launcher.api.elastic_launch
32
- It happens in a subprocess and is responsible for executing the user's function
33
-
34
- It is unrelated to the actual node execution, so any cataloging, run_log_store should be
35
- handled to match to main process.
36
-
37
- We have these variables to use:
38
-
39
- os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
40
- os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
41
- self._context.parameters_file or ""
42
- )
43
- os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
44
- os.environ["RUNNABLE_TORCH_COPY_CONTENTS_TO"] = (
45
- self._context.catalog_handler.compute_data_folder
46
- )
47
- os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
48
-
49
- """
50
- command = os.environ.get("RUNNABLE_TORCH_COMMAND")
51
- run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
52
- parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
53
-
54
- process_run_id = (
55
- run_id
56
- + "-"
57
- + os.environ.get("RANK", "")
58
- + "-"
59
- + "".join(random.choices(string.ascii_lowercase, k=3))
60
- )
61
- os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
62
-
63
- delete_env_vars_with_prefix("RUNNABLE_")
64
-
65
- func = get_callable_from_dotted_path(command)
66
-
67
- # The job runs with the default configuration
68
- # ALl the execution logs are stored in .catalog
69
- job = PythonJob(function=func)
70
-
71
- job.execute(
72
- parameters_file=parameters_files,
73
- job_id=process_run_id,
74
- )
75
-
76
- from runnable.context import run_context
77
-
78
- job_log = run_context.run_log_store.get_run_log_by_id(run_id=run_context.run_id)
79
-
80
- if job_log.status == defaults.FAIL:
81
- raise Exception(f"Job {process_run_id} failed")
82
-
83
-
84
- # TODO: Can this be utils.get_module_and_attr_names
85
- def get_callable_from_dotted_path(dotted_path) -> Callable:
86
- try:
87
- # Split the path into module path and callable object
88
- module_path, callable_name = dotted_path.rsplit(".", 1)
89
-
90
- # Import the module
91
- module = importlib.import_module(module_path)
92
-
93
- # Get the callable from the module
94
- callable_obj = getattr(module, callable_name)
95
-
96
- # Check if the object is callable
97
- if not callable(callable_obj):
98
- raise TypeError(f"The object {callable_name} is not callable.")
99
-
100
- return callable_obj
101
-
102
- except (ImportError, AttributeError, ValueError) as e:
103
- raise ImportError(f"Could not import '{dotted_path}'.") from e
104
-
105
-
106
- def delete_env_vars_with_prefix(prefix):
107
- to_delete = [] # List to keep track of variables to delete
108
-
109
- # Iterate over a list of all environment variable keys
110
- for var in os.environ:
111
- if var.startswith(prefix):
112
- to_delete.append(var)
113
-
114
- # Delete each of the variables collected
115
- for var in to_delete:
116
- del os.environ[var]
117
-
118
-
119
- # TODO: The design of this class is not final
120
- class TorchNode(ExecutableNode, TorchConfig):
121
- node_type: str = Field(default="torch", serialization_alias="type")
122
- executable: PythonTaskType = Field(exclude=True)
123
-
124
- # Similar to TaskNode
125
- model_config = ConfigDict(extra="allow")
126
-
127
- def get_summary(self) -> dict[str, Any]:
128
- summary = {
129
- "name": self.name,
130
- "type": self.node_type,
131
- }
132
-
133
- return summary
134
-
135
- @classmethod
136
- def parse_from_config(cls, config: dict[str, Any]) -> "TorchNode":
137
- task_config = {
138
- k: v for k, v in config.items() if k not in TorchNode.model_fields.keys()
139
- }
140
- node_config = {
141
- k: v for k, v in config.items() if k in TorchNode.model_fields.keys()
142
- }
143
-
144
- executable = create_task(task_config)
145
-
146
- assert isinstance(executable, PythonTaskType)
147
- return cls(executable=executable, **node_config, **task_config)
148
-
149
- def get_launch_config(self) -> LaunchConfig:
150
- internal_log_spec = InternalLogSpecs(**self.model_dump(exclude_none=True))
151
- log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
152
- **internal_log_spec.model_dump(exclude_none=True)
153
- )
154
- easy_torch_config = EasyTorchConfig(
155
- **self.model_dump(
156
- exclude_none=True,
157
- )
158
- )
159
-
160
- launch_config = LaunchConfig(
161
- **easy_torch_config.model_dump(
162
- exclude_none=True,
163
- ),
164
- logs_specs=log_spec,
165
- run_id=self._context.run_id,
166
- )
167
- logger.info(f"launch_config: {launch_config}")
168
- return launch_config
169
-
170
- def execute(
171
- self,
172
- mock=False,
173
- map_variable: TypeMapVariable = None,
174
- attempt_number: int = 1,
175
- ) -> StepLog:
176
- assert (
177
- map_variable is None or not map_variable
178
- ), "TorchNode does not support map_variable"
179
-
180
- step_log = self._context.run_log_store.get_step_log(
181
- self._get_step_log_name(map_variable), self._context.run_id
182
- )
183
-
184
- # Attempt to call the function or elastic launch
185
- launch_config = self.get_launch_config()
186
- logger.info(f"launch_config: {launch_config}")
187
-
188
- # ENV variables are shared with the subprocess, use that as communication
189
- os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
190
- os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
191
- self._context.parameters_file or ""
192
- )
193
- os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
194
-
195
- launcher = elastic_launch(
196
- launch_config,
197
- training_subprocess,
198
- )
199
- try:
200
- launcher()
201
- attempt_log = datastore.StepAttempt(
202
- status=defaults.SUCCESS,
203
- start_time=str(datetime.now()),
204
- end_time=str(datetime.now()),
205
- attempt_number=attempt_number,
206
- )
207
- except Exception as e:
208
- attempt_log = datastore.StepAttempt(
209
- status=defaults.FAIL,
210
- start_time=str(datetime.now()),
211
- end_time=str(datetime.now()),
212
- attempt_number=attempt_number,
213
- )
214
- logger.error(f"Error executing TorchNode: {e}")
215
- finally:
216
- # This can only come from the subprocess
217
- if Path(".catalog").exists():
218
- os.rename(".catalog", "proc_logs")
219
- # Move .catalog and torch_logs to the parent node's catalog location
220
- self._context.catalog_handler.put(
221
- "proc_logs/**/*", allow_file_not_found_exc=True
222
- )
223
-
224
- # TODO: This is not working!!
225
- if self.log_dir:
226
- self._context.catalog_handler.put(
227
- self.log_dir + "/**/*", allow_file_not_found_exc=True
228
- )
229
-
230
- delete_env_vars_with_prefix("RUNNABLE_TORCH")
231
-
232
- logger.info(f"attempt_log: {attempt_log}")
233
- logger.info(f"Step {self.name} completed with status: {attempt_log.status}")
234
-
235
- step_log.status = attempt_log.status
236
- step_log.attempts.append(attempt_log)
237
-
238
- return step_log
239
-
240
- def fan_in(self, map_variable: dict[str, str | int | float] | None = None):
241
- # Destroy the service
242
- # Destroy the statefulset
243
- assert (
244
- map_variable is None or not map_variable
245
- ), "TorchNode does not support map_variable"
246
-
247
- def fan_out(self, map_variable: dict[str, str | int | float] | None = None):
248
- # Create a service
249
- # Create a statefulset
250
- # Gather the IPs and set them as parameters downstream
251
- assert (
252
- map_variable is None or not map_variable
253
- ), "TorchNode does not support map_variable"
254
-
255
-
256
- # This internal model makes it easier to extract the required fields
257
- # of log specs from user specification.
258
- # https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
259
- class InternalLogSpecs(BaseModel):
260
- log_dir: Optional[str] = Field(default="torch_logs")
261
- redirects: str = Field(default="0") # Std.NONE
262
- tee: str = Field(default="0") # Std.NONE
263
- local_ranks_filter: Optional[set[int]] = Field(default=None)
264
-
265
- model_config = ConfigDict(extra="ignore")
266
-
267
- @field_serializer("redirects")
268
- def convert_redirects(self, redirects: str) -> Std | dict[int, Std]:
269
- return Std.from_str(redirects)
270
-
271
- @field_serializer("tee")
272
- def convert_tee(self, tee: str) -> Std | dict[int, Std]:
273
- return Std.from_str(tee)
@@ -1,76 +0,0 @@
1
- from enum import Enum
2
- from typing import Any, Optional
3
-
4
- from pydantic import BaseModel, ConfigDict, Field, computed_field
5
-
6
-
7
- class StartMethod(str, Enum):
8
- spawn = "spawn"
9
- fork = "fork"
10
- forkserver = "forkserver"
11
-
12
-
13
- ## The idea is the following:
14
- # Users can configure any of the options present in TorchConfig class.
15
- # The LaunchConfig class will be created from TorchConfig.
16
- # The LogSpecs is sent as a parameter to the launch config.
17
-
18
- ## NO idea of standalone and how to send it
19
-
20
-
21
- # The user sees this as part of the config of the node.
22
- # It is kept as similar as possible to torchrun
23
- class TorchConfig(BaseModel):
24
- model_config = ConfigDict(extra="forbid")
25
-
26
- # excluded as LaunchConfig requires min and max nodes
27
- nnodes: str = Field(default="1:1", exclude=True, description="min:max")
28
- nproc_per_node: int = Field(default=1, description="Number of processes per node")
29
-
30
- # will be used to create the log specs
31
- # But they are excluded from dump as logs specs is a class for LaunchConfig
32
- # from_str("0") -> Std.NONE
33
- # from_str("1") -> Std.OUT
34
- # from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
35
- log_dir: Optional[str] = Field(default="torch_logs", exclude=True)
36
- redirects: str = Field(default="0", exclude=True) # Std.NONE
37
- tee: str = Field(default="0", exclude=True) # Std.NONE
38
- local_ranks_filter: Optional[set[int]] = Field(default=None, exclude=True)
39
-
40
- role: str | None = Field(default=None)
41
-
42
- # run_id would be the run_id of the context
43
- # and sent at the creation of the LaunchConfig
44
-
45
- # This section is about the communication between nodes/processes
46
- rdzv_backend: str | None = Field(default="static")
47
- rdzv_endpoint: str | None = Field(default="")
48
- rdzv_configs: dict[str, Any] = Field(default_factory=dict)
49
- rdzv_timeout: int | None = Field(default=None)
50
-
51
- max_restarts: int | None = Field(default=None)
52
- monitor_interval: float | None = Field(default=None)
53
- start_method: str | None = Field(default=StartMethod.spawn)
54
- log_line_prefix_template: str | None = Field(default=None)
55
- local_addr: Optional[str] = None
56
-
57
- # https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L753
58
- # master_addr: str | None = Field(default="localhost")
59
- # master_port: str | None = Field(default="29500")
60
- # training_script: str = Field(default="dummy_training_script")
61
- # training_script_args: str = Field(default="")
62
-
63
-
64
- class EasyTorchConfig(TorchConfig):
65
- model_config = ConfigDict(extra="ignore")
66
-
67
- # TODO: Validate min < max
68
- @computed_field # type: ignore
69
- @property
70
- def min_nodes(self) -> int:
71
- return int(self.nnodes.split(":")[0])
72
-
73
- @computed_field # type: ignore
74
- @property
75
- def max_nodes(self) -> int:
76
- return int(self.nnodes.split(":")[1])