runnable 0.34.0a2__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- extensions/nodes/conditional.py +241 -0
- extensions/pipeline_executor/argo.py +32 -26
- runnable/__init__.py +2 -1
- runnable/nodes.py +2 -1
- runnable/sdk.py +64 -12
- runnable/tasks.py +17 -21
- {runnable-0.34.0a2.dist-info → runnable-0.35.0.dist-info}/METADATA +2 -2
- {runnable-0.34.0a2.dist-info → runnable-0.35.0.dist-info}/RECORD +11 -12
- {runnable-0.34.0a2.dist-info → runnable-0.35.0.dist-info}/entry_points.txt +1 -1
- extensions/nodes/torch.py +0 -273
- extensions/nodes/torch_config.py +0 -76
- {runnable-0.34.0a2.dist-info → runnable-0.35.0.dist-info}/WHEEL +0 -0
- {runnable-0.34.0a2.dist-info → runnable-0.35.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,241 @@
|
|
1
|
+
import logging
|
2
|
+
from copy import deepcopy
|
3
|
+
from typing import Any, cast
|
4
|
+
|
5
|
+
from pydantic import Field, field_serializer, field_validator
|
6
|
+
|
7
|
+
from runnable import console, defaults
|
8
|
+
from runnable.datastore import Parameter
|
9
|
+
from runnable.graph import Graph, create_graph
|
10
|
+
from runnable.nodes import CompositeNode, TypeMapVariable
|
11
|
+
|
12
|
+
# Module logger, registered under runnable's shared logger name (defaults.LOGGER_NAME).
logger = logging.getLogger(defaults.LOGGER_NAME)
|
13
|
+
|
14
|
+
|
15
|
+
class ConditionalNode(CompositeNode):
    """
    A composite node that runs the branch whose case name equals the value of a
    run parameter.

    parameter: name -> the parameter which is used for evaluation
    default: Optional[branch] = branch intended to execute if nothing is matched.
    branches: {
        "case1": branch1,
        "case2": branch2,
    }

    Conceptually this is equal to:
        match str(parameter):
            case "case1":
                branch1
            case "case2":
                branch2
            case _:
                default

    Branches are stored under "<internal_name>.<case>"; only the last dotted
    segment is compared with ``str(parameter_value)``.

    NOTE(review): ``default`` is currently only reported by ``get_summary``.
    ``fan_out``, ``execute_as_graph`` and ``fan_in`` do not consult it, so a value
    that matches no case raises in ``fan_out`` instead of falling back.
    """

    node_type: str = Field(default="conditional", serialization_alias="type")

    parameter: str  # the name of the parameter should be isalnum
    default: Graph | None = Field(default=None)  # TODO: Think about the design of this
    branches: dict[str, Graph]
    # The keys of the branches should be isalnum()

    @field_validator("parameter", mode="after")
    @classmethod
    def check_parameter(cls, parameter: str) -> str:
        """
        Validate that the parameter name is alphanumeric.

        Args:
            parameter (str): The parameter name to validate.

        Raises:
            ValueError: If the parameter name is not alphanumeric (the empty
                string is rejected too, since "".isalnum() is False).

        Returns:
            str: The validated parameter name.
        """
        if not parameter.isalnum():
            raise ValueError(f"Parameter '{parameter}' must be alphanumeric.")
        return parameter

    @staticmethod
    def _case_name(internal_branch_name: str) -> str:
        """Return the user-facing case of a branch: the last dotted segment of its name."""
        return internal_branch_name.split(".")[-1]

    def _is_selected(
        self, internal_branch_name: str, parameter_value: str | int | bool | float
    ) -> bool:
        """Return True when ``str(parameter_value)`` equals the branch's case name."""
        return str(parameter_value) == self._case_name(internal_branch_name)

    def get_parameter_value(self) -> str | int | bool | float:
        """
        Get the value of ``self.parameter`` from the run log store of the current run.

        Raises:
            Exception: If the parameter is not present in the run's parameters.
            TypeError: If the value is not an int, float, bool or str.

        Returns:
            str | int | bool | float: The value of the parameter.
        """
        parameters: dict[str, Parameter] = self._context.run_log_store.get_parameters(
            run_id=self._context.run_id
        )

        if self.parameter not in parameters:
            raise Exception(f"Parameter {self.parameter} not found in parameters")

        chosen_parameter_value = parameters[self.parameter].get_value()

        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        if not isinstance(chosen_parameter_value, (int, float, bool, str)):
            raise TypeError(
                f"Parameter '{self.parameter}' must be of type int, float, bool, or str, "
                f"but got {type(chosen_parameter_value).__name__}."
            )

        return chosen_parameter_value

    def get_summary(self) -> dict[str, Any]:
        """Return a summary of the node: name, type, branches, parameter and default."""
        summary = {
            "name": self.name,
            "type": self.node_type,
            "branches": [branch.get_summary() for branch in self.branches.values()],
            "parameter": self.parameter,
            "default": self.default.get_summary() if self.default else None,
        }

        return summary

    @field_serializer("branches")
    def ser_branches(self, branches: dict[str, Graph]) -> dict[str, Graph]:
        """Serialise branches keyed by their case name rather than the internal name."""
        return {self._case_name(name): branch for name, branch in branches.items()}

    @classmethod
    def parse_from_config(cls, config: dict[str, Any]) -> "ConditionalNode":
        """
        Build a ConditionalNode from its configuration.

        Every entry of ``config["branches"]`` becomes a sub-graph whose internal
        branch name is ``<internal_name>.<case>``.

        NOTE: ``branches`` is popped from ``config``, so the caller's dict is mutated.

        Raises:
            Exception: If no branches are configured.
        """
        internal_name = cast(str, config.get("internal_name"))

        config_branches = config.pop("branches", {})
        branches = {}
        for branch_name, branch_config in config_branches.items():
            internal_branch_name = internal_name + "." + branch_name
            branches[internal_branch_name] = create_graph(
                deepcopy(branch_config),
                internal_branch_name=internal_branch_name,
            )

        if not branches:
            raise Exception("A conditional node should have branches")
        return cls(branches=branches, **config)

    def _get_branch_by_name(self, branch_name: str) -> Graph:
        """
        Return the branch registered under the given internal branch name.

        Raises:
            Exception: If no such branch exists.
        """
        if branch_name in self.branches:
            return self.branches[branch_name]

        raise Exception(f"Branch {branch_name} does not exist")

    def fan_out(self, map_variable: TypeMapVariable = None):
        """
        Create a PROCESSING branch log for the branch selected by the parameter.

        This method is restricted to creating branch logs; branches whose case does
        not match the parameter value get no branch log.

        Args:
            map_variable (dict, optional): If the node is part of a map. Defaults to None.

        Raises:
            Exception: If no branch matches the parameter value.
        """
        parameter_value = self.get_parameter_value()

        hit_once = False

        for internal_branch_name in self.branches:
            if not self._is_selected(internal_branch_name, parameter_value):
                # Need not create a branch log for this branch
                continue

            effective_branch_name = self._resolve_map_placeholders(
                internal_branch_name, map_variable=map_variable
            )

            hit_once = True
            branch_log = self._context.run_log_store.create_branch_log(
                effective_branch_name
            )

            console.print(
                f"Branch log created for {effective_branch_name}: {branch_log}"
            )
            branch_log.status = defaults.PROCESSING
            self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)

        if not hit_once:
            available_cases = [self._case_name(name) for name in self.branches]
            raise Exception(
                f"None of the branches matched: parameter '{self.parameter}' has value "
                f"{parameter_value!r}, available cases are {available_cases}"
            )

    def execute_as_graph(self, map_variable: TypeMapVariable = None):
        """
        Execute the branch selected by the parameter and record the outcome.

        Calls ``fan_out`` (creates the branch log), executes the matching branch via
        the context's executor, then calls ``fan_in`` (sets this step's status from
        the branch log).

        From a design perspective, this function should not be called if the execution
        is 3rd party orchestrated; such orchestrators render their own job
        specification and use ``fan_out`` / ``fan_in`` to track the composite step.

        Args:
            map_variable (dict, optional): If the node is part of a map. Defaults to None.
        """
        self.fan_out(map_variable=map_variable)
        parameter_value = self.get_parameter_value()

        for internal_branch_name, branch in self.branches.items():
            if self._is_selected(internal_branch_name, parameter_value):
                # if the condition is met, execute the graph
                logger.debug(f"Executing graph for {branch}")
                self._context.executor.execute_graph(branch, map_variable=map_variable)

        self.fan_in(map_variable=map_variable)

    def fan_in(self, map_variable: TypeMapVariable = None):
        """
        Set this node's step log status from the selected branch's branch log.

        The step is marked SUCCESS when the selected branch's log is SUCCESS and FAIL
        otherwise. If no branch matches, no branch log is inspected and the step is
        marked SUCCESS. 3rd party orchestrators should use this method to find the
        status of the composite step.

        Args:
            map_variable (dict, optional): If the node is part of a map. Defaults to None.
        """
        effective_internal_name = self._resolve_map_placeholders(
            self.internal_name, map_variable=map_variable
        )

        step_success_bool: bool = True
        parameter_value = self.get_parameter_value()

        for internal_branch_name in self.branches:
            if not self._is_selected(internal_branch_name, parameter_value):
                # The branch would not have been executed
                continue

            effective_branch_name = self._resolve_map_placeholders(
                internal_branch_name, map_variable=map_variable
            )

            branch_log = self._context.run_log_store.get_branch_log(
                effective_branch_name, self._context.run_id
            )

            if branch_log.status != defaults.SUCCESS:
                step_success_bool = False

        step_log = self._context.run_log_store.get_step_log(
            effective_internal_name, self._context.run_id
        )

        step_log.status = defaults.SUCCESS if step_success_bool else defaults.FAIL

        self._context.run_log_store.add_step_log(step_log, self._context.run_id)
|
@@ -20,6 +20,7 @@ from pydantic import (
|
|
20
20
|
from pydantic.alias_generators import to_camel
|
21
21
|
from ruamel.yaml import YAML
|
22
22
|
|
23
|
+
from extensions.nodes.conditional import ConditionalNode
|
23
24
|
from extensions.nodes.nodes import MapNode, ParallelNode, TaskNode
|
24
25
|
|
25
26
|
# TODO: Should be part of a wider refactor
|
@@ -307,6 +308,7 @@ class DagTask(BaseModelWIthConfig):
|
|
307
308
|
template: str # Should be name of a container template or dag template
|
308
309
|
arguments: Optional[Arguments] = Field(default=None)
|
309
310
|
with_param: Optional[str] = Field(default=None)
|
311
|
+
when_param: Optional[str] = Field(default=None, serialization_alias="when")
|
310
312
|
depends: Optional[str] = Field(default=None)
|
311
313
|
|
312
314
|
|
@@ -563,6 +565,8 @@ class ArgoExecutor(GenericPipelineExecutor):
|
|
563
565
|
outputs: Optional[Outputs] = None
|
564
566
|
if mode == "out" and node.node_type == "map":
|
565
567
|
outputs = Outputs(parameters=[OutputParameter(name="iterate-on")])
|
568
|
+
if mode == "out" and node.node_type == "conditional":
|
569
|
+
outputs = Outputs(parameters=[OutputParameter(name="case")])
|
566
570
|
|
567
571
|
container_template = ContainerTemplate(
|
568
572
|
name=task_name,
|
@@ -722,6 +726,7 @@ class ArgoExecutor(GenericPipelineExecutor):
|
|
722
726
|
# - We are using withParam and arguments of the map template to send that value in
|
723
727
|
# - The map template should receive that value as a parameter into the template.
|
724
728
|
# - The task then start to use it as inputs.parameters.iterate-on
|
729
|
+
# the when param should be an evaluation
|
725
730
|
|
726
731
|
def _gather_tasks_for_dag_template(
|
727
732
|
self,
|
@@ -767,9 +772,11 @@ class ArgoExecutor(GenericPipelineExecutor):
|
|
767
772
|
|
768
773
|
self._templates.append(template_of_container)
|
769
774
|
|
770
|
-
case "map" | "parallel":
|
771
|
-
assert
|
772
|
-
working_on,
|
775
|
+
case "map" | "parallel" | "conditional":
|
776
|
+
assert (
|
777
|
+
isinstance(working_on, MapNode)
|
778
|
+
or isinstance(working_on, ParallelNode)
|
779
|
+
or isinstance(working_on, ConditionalNode)
|
773
780
|
)
|
774
781
|
node_type = working_on.node_type
|
775
782
|
|
@@ -792,7 +799,8 @@ class ArgoExecutor(GenericPipelineExecutor):
|
|
792
799
|
)
|
793
800
|
|
794
801
|
# Add the composite task
|
795
|
-
with_param = None
|
802
|
+
with_param: Optional[str] = None
|
803
|
+
when_param: Optional[str] = None
|
796
804
|
added_parameters = parameters or []
|
797
805
|
branches = {}
|
798
806
|
if node_type == "map":
|
@@ -807,22 +815,34 @@ class ArgoExecutor(GenericPipelineExecutor):
|
|
807
815
|
elif node_type == "parallel":
|
808
816
|
assert isinstance(working_on, ParallelNode)
|
809
817
|
branches = working_on.branches
|
818
|
+
elif node_type == "conditional":
|
819
|
+
assert isinstance(working_on, ConditionalNode)
|
820
|
+
branches = working_on.branches
|
821
|
+
when_param = (
|
822
|
+
f"{{{{tasks.{task_name}-fan-out.outputs.parameters.case}}}}"
|
823
|
+
)
|
810
824
|
else:
|
811
825
|
raise ValueError("Invalid node type")
|
812
826
|
|
813
827
|
fan_in_depends = ""
|
814
828
|
|
815
829
|
for name, branch in branches.items():
|
830
|
+
match_when = branch.internal_branch_name.split(".")[-1]
|
816
831
|
name = (
|
817
832
|
name.replace(" ", "-").replace(".", "-").replace("_", "-")
|
818
833
|
)
|
819
834
|
|
835
|
+
if node_type == "conditional":
|
836
|
+
assert isinstance(working_on, ConditionalNode)
|
837
|
+
when_param = f"'{match_when}' == {{{{tasks.{task_name}-fan-out.outputs.parameters.case}}}}"
|
838
|
+
|
820
839
|
branch_task = DagTask(
|
821
840
|
name=f"{task_name}-{name}",
|
822
841
|
template=f"{task_name}-{name}",
|
823
842
|
depends=f"{task_name}-fan-out.Succeeded",
|
824
843
|
arguments=Arguments(parameters=added_parameters),
|
825
844
|
with_param=with_param,
|
845
|
+
when_param=when_param,
|
826
846
|
)
|
827
847
|
composite_template.dag.tasks.append(branch_task)
|
828
848
|
|
@@ -836,6 +856,8 @@ class ArgoExecutor(GenericPipelineExecutor):
|
|
836
856
|
),
|
837
857
|
)
|
838
858
|
|
859
|
+
assert isinstance(branch, Graph)
|
860
|
+
|
839
861
|
self._gather_tasks_for_dag_template(
|
840
862
|
dag_template=branch_template,
|
841
863
|
dag=branch,
|
@@ -862,28 +884,6 @@ class ArgoExecutor(GenericPipelineExecutor):
|
|
862
884
|
|
863
885
|
self._templates.append(composite_template)
|
864
886
|
|
865
|
-
case "torch":
|
866
|
-
from extensions.nodes.torch import TorchNode
|
867
|
-
|
868
|
-
assert isinstance(working_on, TorchNode)
|
869
|
-
# TODO: Need to add multi-node functionality
|
870
|
-
# Check notes on the torch node
|
871
|
-
|
872
|
-
template_of_container = self._create_container_template(
|
873
|
-
working_on,
|
874
|
-
task_name=task_name,
|
875
|
-
inputs=Inputs(parameters=parameters),
|
876
|
-
)
|
877
|
-
assert template_of_container.container is not None
|
878
|
-
|
879
|
-
if working_on.node_type == "task":
|
880
|
-
self._expose_secrets_to_task(
|
881
|
-
working_on,
|
882
|
-
container_template=template_of_container.container,
|
883
|
-
)
|
884
|
-
|
885
|
-
self._templates.append(template_of_container)
|
886
|
-
|
887
887
|
self._handle_failures(
|
888
888
|
working_on,
|
889
889
|
dag,
|
@@ -1025,6 +1025,12 @@ class ArgoExecutor(GenericPipelineExecutor):
|
|
1025
1025
|
with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
|
1026
1026
|
json.dump(iterate_on.get_value(), myfile, indent=4)
|
1027
1027
|
|
1028
|
+
if node.node_type == "conditional":
|
1029
|
+
assert isinstance(node, ConditionalNode)
|
1030
|
+
|
1031
|
+
with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
|
1032
|
+
json.dump(node.get_parameter_value(), myfile, indent=4)
|
1033
|
+
|
1028
1034
|
def fan_in(self, node: BaseNode, map_variable: TypeMapVariable = None):
|
1029
1035
|
self._use_volumes()
|
1030
1036
|
super().fan_in(node, map_variable)
|
runnable/__init__.py
CHANGED
runnable/nodes.py
CHANGED
@@ -8,6 +8,7 @@ import runnable.context as context
|
|
8
8
|
from runnable import defaults, exceptions
|
9
9
|
from runnable.datastore import StepLog
|
10
10
|
from runnable.defaults import TypeMapVariable
|
11
|
+
from runnable.graph import Graph
|
11
12
|
|
12
13
|
logger = logging.getLogger(defaults.LOGGER_NAME)
|
13
14
|
|
@@ -218,7 +219,7 @@ class BaseNode(ABC, BaseModel):
|
|
218
219
|
"""
|
219
220
|
|
220
221
|
@abstractmethod
|
221
|
-
def _get_branch_by_name(self, branch_name: str):
|
222
|
+
def _get_branch_by_name(self, branch_name: str) -> Graph:
|
222
223
|
"""
|
223
224
|
Retrieve a branch by name.
|
224
225
|
|
runnable/sdk.py
CHANGED
@@ -26,6 +26,7 @@ from rich.progress import (
|
|
26
26
|
from rich.table import Column
|
27
27
|
from typing_extensions import Self
|
28
28
|
|
29
|
+
from extensions.nodes.conditional import ConditionalNode
|
29
30
|
from extensions.nodes.nodes import (
|
30
31
|
FailNode,
|
31
32
|
MapNode,
|
@@ -50,6 +51,7 @@ StepType = Union[
|
|
50
51
|
"Parallel",
|
51
52
|
"Map",
|
52
53
|
"TorchTask",
|
54
|
+
"Conditional",
|
53
55
|
]
|
54
56
|
|
55
57
|
|
@@ -193,6 +195,9 @@ class BaseTask(BaseTraversal):
|
|
193
195
|
"This method should be implemented in the child class"
|
194
196
|
)
|
195
197
|
|
198
|
+
def as_pipeline(self) -> "Pipeline":
    """Return a ``Pipeline`` whose only step is this task."""
    single_step_pipeline = Pipeline(steps=[self])  # type: ignore
    return single_step_pipeline
|
200
|
+
|
196
201
|
|
197
202
|
class PythonTask(BaseTask):
|
198
203
|
"""
|
@@ -283,14 +288,15 @@ class PythonTask(BaseTask):
|
|
283
288
|
|
284
289
|
|
285
290
|
class TorchTask(BaseTask):
|
286
|
-
entrypoint: str = Field(
|
287
|
-
|
288
|
-
)
|
289
|
-
args_to_torchrun: Dict[str, Any] = Field(
|
290
|
-
|
291
|
-
)
|
291
|
+
# entrypoint: str = Field(
|
292
|
+
# alias="entrypoint", default="torch.distributed.run", frozen=True
|
293
|
+
# )
|
294
|
+
# args_to_torchrun: Dict[str, Any] = Field(
|
295
|
+
# default_factory=dict, alias="args_to_torchrun"
|
296
|
+
# )
|
292
297
|
|
293
298
|
script_to_call: str
|
299
|
+
accelerate_config_file: str
|
294
300
|
|
295
301
|
@computed_field
|
296
302
|
def command_type(self) -> str:
|
@@ -520,6 +526,53 @@ class Parallel(BaseTraversal):
|
|
520
526
|
return node
|
521
527
|
|
522
528
|
|
529
|
+
class Conditional(BaseTraversal):
    """
    SDK step that chooses one of several pipelines based on a parameter value.

    ``create_node`` converts the step into a ``ConditionalNode``, which performs
    the actual selection at run time.

    Attributes:
        branches: Maps a case name to the pipeline for that case. Every case name
            must be alphanumeric.
        parameter: Name of the parameter whose value selects the case. Must be
            alphanumeric (and therefore non-empty).
    """

    branches: Dict[str, "Pipeline"]
    parameter: str  # the name of the parameter should be isalnum

    @field_validator("parameter")
    @classmethod
    def validate_parameter(cls, parameter: str) -> str:
        """Reject parameter names that are empty or contain non-alphanumeric characters."""
        if parameter.isalnum():
            return parameter
        raise AssertionError(
            "The parameter name should be alphanumeric and not empty"
        )

    @field_validator("branches")
    @classmethod
    def validate_branches(
        cls, branches: Dict[str, "Pipeline"]
    ) -> Dict[str, "Pipeline"]:
        """Reject the first case name (in insertion order) that is not alphanumeric."""
        offending = next((case for case in branches if not case.isalnum()), None)
        if offending is not None:
            raise ValueError(f"Branch '{offending}' must be alphanumeric.")
        return branches

    @computed_field  # type: ignore
    @property
    def graph_branches(self) -> Dict[str, graph.Graph]:
        """Copies of each branch pipeline's DAG, keyed by case name."""
        copied: Dict[str, graph.Graph] = {}
        for case_name, branch_pipeline in self.branches.items():
            copied[case_name] = branch_pipeline._dag.model_copy()
        return copied

    def create_node(self) -> ConditionalNode:
        """
        Translate this SDK step into a ``ConditionalNode``.

        Raises:
            AssertionError: If the step has no next node and is not terminal.
        """
        if not self.next_node and not (
            self.terminate_with_failure or self.terminate_with_success
        ):
            raise AssertionError(
                "A node not being terminated must have a user defined next node"
            )

        return ConditionalNode(
            name=self.name,
            internal_name="",
            parameter=self.parameter,
            branches=self.graph_branches,
            next_node=self.next_node,
        )
|
574
|
+
|
575
|
+
|
523
576
|
class Map(BaseTraversal):
|
524
577
|
"""
|
525
578
|
A node that iterates over a list of items and executes a pipeline for each item.
|
@@ -543,7 +596,6 @@ class Map(BaseTraversal):
|
|
543
596
|
iterate_on: str
|
544
597
|
iterate_as: str
|
545
598
|
reducer: Optional[str] = Field(default=None, alias="reducer")
|
546
|
-
overrides: Dict[str, Any] = Field(default_factory=dict)
|
547
599
|
|
548
600
|
@computed_field # type: ignore
|
549
601
|
@property
|
@@ -564,7 +616,6 @@ class Map(BaseTraversal):
|
|
564
616
|
next_node=self.next_node,
|
565
617
|
iterate_on=self.iterate_on,
|
566
618
|
iterate_as=self.iterate_as,
|
567
|
-
overrides=self.overrides,
|
568
619
|
reducer=self.reducer,
|
569
620
|
)
|
570
621
|
|
@@ -984,13 +1035,14 @@ class PythonJob(BaseJob):
|
|
984
1035
|
|
985
1036
|
|
986
1037
|
class TorchJob(BaseJob):
|
987
|
-
entrypoint: str = Field(default="torch.distributed.run", frozen=True)
|
988
|
-
args_to_torchrun: dict[str, str | bool | int | float] = Field(
|
989
|
-
|
990
|
-
) # For example
|
1038
|
+
# entrypoint: str = Field(default="torch.distributed.run", frozen=True)
|
1039
|
+
# args_to_torchrun: dict[str, str | bool | int | float] = Field(
|
1040
|
+
# default_factory=dict
|
1041
|
+
# ) # For example
|
991
1042
|
# {"nproc_per_node": 2, "nnodes": 1,}
|
992
1043
|
|
993
1044
|
script_to_call: str # For example train/script.py
|
1045
|
+
accelerate_config_file: str
|
994
1046
|
|
995
1047
|
def get_task(self) -> RunnableTask:
|
996
1048
|
# Piggy bank on existing tasks as a hack
|
runnable/tasks.py
CHANGED
@@ -5,7 +5,6 @@ import io
|
|
5
5
|
import json
|
6
6
|
import logging
|
7
7
|
import os
|
8
|
-
import runpy
|
9
8
|
import subprocess
|
10
9
|
import sys
|
11
10
|
from datetime import datetime
|
@@ -357,16 +356,15 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
|
|
357
356
|
|
358
357
|
class TorchTaskType(BaseTaskType):
|
359
358
|
task_type: str = Field(default="torch", serialization_alias="command_type")
|
360
|
-
|
361
|
-
entrypoint: str = Field(default="torch.distributed.run", frozen=True)
|
362
|
-
args_to_torchrun: dict[str, str | bool] = Field(default_factory=dict) # For example
|
363
|
-
# {"nproc_per_node": 2, "nnodes": 1,}
|
359
|
+
accelerate_config_file: str
|
364
360
|
|
365
361
|
script_to_call: str # For example train/script.py
|
366
362
|
|
367
363
|
def execute_command(
|
368
364
|
self, map_variable: Dict[str, str | int | float] | None = None
|
369
365
|
) -> StepAttempt:
|
366
|
+
from accelerate.commands import launch
|
367
|
+
|
370
368
|
attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
|
371
369
|
|
372
370
|
with (
|
@@ -376,39 +374,37 @@ class TorchTaskType(BaseTaskType):
|
|
376
374
|
self.expose_secrets() as _,
|
377
375
|
):
|
378
376
|
try:
|
379
|
-
|
380
|
-
|
381
|
-
for key, value in self.args_to_torchrun.items():
|
382
|
-
entry_point_args.append(f"--{key}")
|
383
|
-
if type(value) is not bool:
|
384
|
-
entry_point_args.append(str(value))
|
385
|
-
|
386
|
-
entry_point_args.append(self.script_to_call)
|
377
|
+
script_args = []
|
387
378
|
for key, value in params.items():
|
388
|
-
|
389
|
-
if type(value.value) is not bool:
|
390
|
-
|
379
|
+
script_args.append(f"--{key}")
|
380
|
+
if type(value.value) is not bool:
|
381
|
+
script_args.append(str(value.value))
|
391
382
|
|
392
383
|
# TODO: Check the typing here
|
393
384
|
|
394
385
|
logger.info("Calling the user script with the following parameters:")
|
395
|
-
logger.info(
|
386
|
+
logger.info(script_args)
|
396
387
|
out_file = TeeIO()
|
397
388
|
try:
|
398
389
|
with contextlib.redirect_stdout(out_file):
|
399
|
-
|
400
|
-
|
390
|
+
parser = launch.launch_command_parser()
|
391
|
+
args = parser.parse_args(self.script_to_call)
|
392
|
+
args.training_script = self.script_to_call
|
393
|
+
args.config_file = self.accelerate_config_file
|
394
|
+
args.training_script_args = script_args
|
395
|
+
|
396
|
+
launch.launch_command(args)
|
401
397
|
task_console.print(out_file.getvalue())
|
402
398
|
except Exception as e:
|
403
399
|
raise exceptions.CommandCallError(
|
404
|
-
f"Call to
|
400
|
+
f"Call to script{self.script_to_call} did not succeed."
|
405
401
|
) from e
|
406
402
|
finally:
|
407
403
|
sys.argv = sys.argv[:1]
|
408
404
|
|
409
405
|
attempt_log.status = defaults.SUCCESS
|
410
406
|
except Exception as _e:
|
411
|
-
msg = f"Call to
|
407
|
+
msg = f"Call to script: {self.script_to_call} did not succeed."
|
412
408
|
attempt_log.message = msg
|
413
409
|
task_console.print_exception(show_locals=False)
|
414
410
|
task_console.log(_e, style=defaults.error_style)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: runnable
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.35.0
|
4
4
|
Summary: Add your description here
|
5
5
|
Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
|
6
6
|
License-File: LICENSE
|
@@ -27,8 +27,8 @@ Requires-Dist: ploomber-engine>=0.0.33; extra == 'notebook'
|
|
27
27
|
Provides-Extra: s3
|
28
28
|
Requires-Dist: cloudpathlib[s3]; extra == 's3'
|
29
29
|
Provides-Extra: torch
|
30
|
+
Requires-Dist: accelerate>=1.5.2; extra == 'torch'
|
30
31
|
Requires-Dist: torch>=2.6.0; extra == 'torch'
|
31
|
-
Requires-Dist: torchvision>=0.21.0; extra == 'torch'
|
32
32
|
Description-Content-Type: text/markdown
|
33
33
|
|
34
34
|
|
@@ -14,13 +14,12 @@ extensions/job_executor/local.py,sha256=3ZbCFXBvbLlMp10JTmQJJrjBKG2keHI6SH8hEvmH
|
|
14
14
|
extensions/job_executor/local_container.py,sha256=1JcLJ0zrNSNHdubrSO9miN54iwvPLHqKMZ08aOC8WWo,6886
|
15
15
|
extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqyUecIsb_Vc,286
|
16
16
|
extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
+
extensions/nodes/conditional.py,sha256=m4DGxjqWpjNd2KQPAdVSJ6ridt1BDx2Lt6kmEQa9ghY,8594
|
17
18
|
extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
|
18
19
|
extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
|
19
|
-
extensions/nodes/torch.py,sha256=64DTjdPNSJ8vfMwUN9h9Ly5g9qj-Bga7LSGrfCAO0BY,9389
|
20
|
-
extensions/nodes/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
|
21
20
|
extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
21
|
extensions/pipeline_executor/__init__.py,sha256=wfigTL2T9OHrmE8b2Ydmb8h6hr-oF--Yc2FectC7WaY,24623
|
23
|
-
extensions/pipeline_executor/argo.py,sha256=
|
22
|
+
extensions/pipeline_executor/argo.py,sha256=17hHj3L5oIkoOpCSSbZlliLnOUoN5_JpK_DY0ELWXac,38233
|
24
23
|
extensions/pipeline_executor/local.py,sha256=6oWUJ6b6NvIkpeQJBoCT1hbfX4_6WCB4HzMgHZ4ik1A,1887
|
25
24
|
extensions/pipeline_executor/local_container.py,sha256=3kZ2QCsrq_YjH9dcAz8v05knKShQ_JtbIU-IA_-G538,12724
|
26
25
|
extensions/pipeline_executor/mocked.py,sha256=0sMmypuvstBIv9uQg-WAcPrF3oOFpeEXNi6N8Nzdnl0,5680
|
@@ -42,7 +41,7 @@ extensions/secrets/dotenv.py,sha256=nADHXI6KJ_LUYOIe5EbtYH-21OBebSNVr0Pjb1GlZ7w,
|
|
42
41
|
extensions/secrets/pyproject.toml,sha256=mLJNImNcBlbLKHh-0ugVWT9V83R4RibyyYDtBCSqVF4,282
|
43
42
|
extensions/tasks/torch.py,sha256=oeXRkmuttFIAuBwH7-h4SOVXMDOZXX5mvqI2aFrR3Vo,10283
|
44
43
|
extensions/tasks/torch_config.py,sha256=UjfMitT-TXASRDGR30I2vDRnyk7JQnR-5CsOVidjpSY,2833
|
45
|
-
runnable/__init__.py,sha256=
|
44
|
+
runnable/__init__.py,sha256=eRXLgO-iiSUmNkjjzBjWdBP7Fp--I_vnImyhoGxZUek,709
|
46
45
|
runnable/catalog.py,sha256=4msQxLhLKlsDDrHFnGauPYe-Or-q9g8_RYCn_4dpxaU,4466
|
47
46
|
runnable/cli.py,sha256=3BiKSj95h2Drn__YlchMPZ5rBMafuRb2OGIsVpbsO5Y,8788
|
48
47
|
runnable/context.py,sha256=by5uepmuCP0dmM9BmsliXihSes5QEFejwAsmekcqylE,1388
|
@@ -53,15 +52,15 @@ runnable/exceptions.py,sha256=LFbp0-Qxg2PAMLEVt7w2whhBxSG-5pzUEv5qN-Rc4_c,3003
|
|
53
52
|
runnable/executor.py,sha256=Jr9yJtSH7CzjXJLWx3VWIUAQblstuGqzpFtajv7d39M,15348
|
54
53
|
runnable/graph.py,sha256=poQz5zcvq89ju_u5sYlunQLPbHnXTaUmjcvstPwvT4U,16536
|
55
54
|
runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
|
56
|
-
runnable/nodes.py,sha256=
|
55
|
+
runnable/nodes.py,sha256=CWfKVuGNaKSQpvFYYE1gEiTNouG0xPaA8KKaOxFr8EI,16733
|
57
56
|
runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
|
58
57
|
runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
|
59
|
-
runnable/sdk.py,sha256=
|
58
|
+
runnable/sdk.py,sha256=1gerGsq6EMSbDh2-Ey1vk6e0Sls55t9R29KlblNahi0,36793
|
60
59
|
runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
|
61
|
-
runnable/tasks.py,sha256=
|
60
|
+
runnable/tasks.py,sha256=lOtCninvosGI2bNIzblrzNa-lN7TMwel1KQ1g23M85A,32088
|
62
61
|
runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
|
63
|
-
runnable-0.
|
64
|
-
runnable-0.
|
65
|
-
runnable-0.
|
66
|
-
runnable-0.
|
67
|
-
runnable-0.
|
62
|
+
runnable-0.35.0.dist-info/METADATA,sha256=CgZbaiNCY_mUrcdyOGYV_6zkVwSrGMzqbUdrKQ-LL0U,10166
|
63
|
+
runnable-0.35.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
64
|
+
runnable-0.35.0.dist-info/entry_points.txt,sha256=bLH1QXcc-G8xgJTi4wf6SYQnsG_BxRRvobwa9dYm-js,1935
|
65
|
+
runnable-0.35.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
66
|
+
runnable-0.35.0.dist-info/RECORD,,
|
@@ -14,6 +14,7 @@ local-container = extensions.job_executor.local_container:LocalContainerJobExecu
|
|
14
14
|
mini-k8s-job = extensions.job_executor.k8s:MiniK8sJobExecutor
|
15
15
|
|
16
16
|
[nodes]
|
17
|
+
conditional = extensions.nodes.conditional:ConditionalNode
|
17
18
|
dag = extensions.nodes.nodes:DagNode
|
18
19
|
fail = extensions.nodes.nodes:FailNode
|
19
20
|
map = extensions.nodes.nodes:MapNode
|
@@ -21,7 +22,6 @@ parallel = extensions.nodes.nodes:ParallelNode
|
|
21
22
|
stub = extensions.nodes.nodes:StubNode
|
22
23
|
success = extensions.nodes.nodes:SuccessNode
|
23
24
|
task = extensions.nodes.nodes:TaskNode
|
24
|
-
torch = extensions.nodes.torch:TorchNode
|
25
25
|
|
26
26
|
[pickler]
|
27
27
|
pickle = runnable.pickler:NativePickler
|
extensions/nodes/torch.py
DELETED
@@ -1,273 +0,0 @@
|
|
1
|
-
import importlib
|
2
|
-
import logging
|
3
|
-
import os
|
4
|
-
import random
|
5
|
-
import string
|
6
|
-
from datetime import datetime
|
7
|
-
from pathlib import Path
|
8
|
-
from typing import Any, Callable, Optional
|
9
|
-
|
10
|
-
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
11
|
-
|
12
|
-
from extensions.nodes.torch_config import EasyTorchConfig, TorchConfig
|
13
|
-
from runnable import PythonJob, datastore, defaults
|
14
|
-
from runnable.datastore import StepLog
|
15
|
-
from runnable.nodes import ExecutableNode
|
16
|
-
from runnable.tasks import PythonTaskType, create_task
|
17
|
-
from runnable.utils import TypeMapVariable
|
18
|
-
|
19
|
-
logger = logging.getLogger(defaults.LOGGER_NAME)
|
20
|
-
|
21
|
-
try:
|
22
|
-
from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
|
23
|
-
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
24
|
-
except ImportError:
|
25
|
-
logger.exception("Torch is not installed. Please install torch first.")
|
26
|
-
raise Exception("Torch is not installed. Please install torch first.")
|
27
|
-
|
28
|
-
|
29
|
-
def training_subprocess():
|
30
|
-
"""
|
31
|
-
This function is called by the torch.distributed.launcher.api.elastic_launch
|
32
|
-
It happens in a subprocess and is responsible for executing the user's function
|
33
|
-
|
34
|
-
It is unrelated to the actual node execution, so any cataloging, run_log_store should be
|
35
|
-
handled to match to main process.
|
36
|
-
|
37
|
-
We have these variables to use:
|
38
|
-
|
39
|
-
os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
|
40
|
-
os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
|
41
|
-
self._context.parameters_file or ""
|
42
|
-
)
|
43
|
-
os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
|
44
|
-
os.environ["RUNNABLE_TORCH_COPY_CONTENTS_TO"] = (
|
45
|
-
self._context.catalog_handler.compute_data_folder
|
46
|
-
)
|
47
|
-
os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
|
48
|
-
|
49
|
-
"""
|
50
|
-
command = os.environ.get("RUNNABLE_TORCH_COMMAND")
|
51
|
-
run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
|
52
|
-
parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
|
53
|
-
|
54
|
-
process_run_id = (
|
55
|
-
run_id
|
56
|
-
+ "-"
|
57
|
-
+ os.environ.get("RANK", "")
|
58
|
-
+ "-"
|
59
|
-
+ "".join(random.choices(string.ascii_lowercase, k=3))
|
60
|
-
)
|
61
|
-
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
|
62
|
-
|
63
|
-
delete_env_vars_with_prefix("RUNNABLE_")
|
64
|
-
|
65
|
-
func = get_callable_from_dotted_path(command)
|
66
|
-
|
67
|
-
# The job runs with the default configuration
|
68
|
-
# ALl the execution logs are stored in .catalog
|
69
|
-
job = PythonJob(function=func)
|
70
|
-
|
71
|
-
job.execute(
|
72
|
-
parameters_file=parameters_files,
|
73
|
-
job_id=process_run_id,
|
74
|
-
)
|
75
|
-
|
76
|
-
from runnable.context import run_context
|
77
|
-
|
78
|
-
job_log = run_context.run_log_store.get_run_log_by_id(run_id=run_context.run_id)
|
79
|
-
|
80
|
-
if job_log.status == defaults.FAIL:
|
81
|
-
raise Exception(f"Job {process_run_id} failed")
|
82
|
-
|
83
|
-
|
84
|
-
# TODO: Can this be utils.get_module_and_attr_names
|
85
|
-
def get_callable_from_dotted_path(dotted_path) -> Callable:
|
86
|
-
try:
|
87
|
-
# Split the path into module path and callable object
|
88
|
-
module_path, callable_name = dotted_path.rsplit(".", 1)
|
89
|
-
|
90
|
-
# Import the module
|
91
|
-
module = importlib.import_module(module_path)
|
92
|
-
|
93
|
-
# Get the callable from the module
|
94
|
-
callable_obj = getattr(module, callable_name)
|
95
|
-
|
96
|
-
# Check if the object is callable
|
97
|
-
if not callable(callable_obj):
|
98
|
-
raise TypeError(f"The object {callable_name} is not callable.")
|
99
|
-
|
100
|
-
return callable_obj
|
101
|
-
|
102
|
-
except (ImportError, AttributeError, ValueError) as e:
|
103
|
-
raise ImportError(f"Could not import '{dotted_path}'.") from e
|
104
|
-
|
105
|
-
|
106
|
-
def delete_env_vars_with_prefix(prefix):
|
107
|
-
to_delete = [] # List to keep track of variables to delete
|
108
|
-
|
109
|
-
# Iterate over a list of all environment variable keys
|
110
|
-
for var in os.environ:
|
111
|
-
if var.startswith(prefix):
|
112
|
-
to_delete.append(var)
|
113
|
-
|
114
|
-
# Delete each of the variables collected
|
115
|
-
for var in to_delete:
|
116
|
-
del os.environ[var]
|
117
|
-
|
118
|
-
|
119
|
-
# TODO: The design of this class is not final
|
120
|
-
class TorchNode(ExecutableNode, TorchConfig):
|
121
|
-
node_type: str = Field(default="torch", serialization_alias="type")
|
122
|
-
executable: PythonTaskType = Field(exclude=True)
|
123
|
-
|
124
|
-
# Similar to TaskNode
|
125
|
-
model_config = ConfigDict(extra="allow")
|
126
|
-
|
127
|
-
def get_summary(self) -> dict[str, Any]:
|
128
|
-
summary = {
|
129
|
-
"name": self.name,
|
130
|
-
"type": self.node_type,
|
131
|
-
}
|
132
|
-
|
133
|
-
return summary
|
134
|
-
|
135
|
-
@classmethod
|
136
|
-
def parse_from_config(cls, config: dict[str, Any]) -> "TorchNode":
|
137
|
-
task_config = {
|
138
|
-
k: v for k, v in config.items() if k not in TorchNode.model_fields.keys()
|
139
|
-
}
|
140
|
-
node_config = {
|
141
|
-
k: v for k, v in config.items() if k in TorchNode.model_fields.keys()
|
142
|
-
}
|
143
|
-
|
144
|
-
executable = create_task(task_config)
|
145
|
-
|
146
|
-
assert isinstance(executable, PythonTaskType)
|
147
|
-
return cls(executable=executable, **node_config, **task_config)
|
148
|
-
|
149
|
-
def get_launch_config(self) -> LaunchConfig:
|
150
|
-
internal_log_spec = InternalLogSpecs(**self.model_dump(exclude_none=True))
|
151
|
-
log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
|
152
|
-
**internal_log_spec.model_dump(exclude_none=True)
|
153
|
-
)
|
154
|
-
easy_torch_config = EasyTorchConfig(
|
155
|
-
**self.model_dump(
|
156
|
-
exclude_none=True,
|
157
|
-
)
|
158
|
-
)
|
159
|
-
|
160
|
-
launch_config = LaunchConfig(
|
161
|
-
**easy_torch_config.model_dump(
|
162
|
-
exclude_none=True,
|
163
|
-
),
|
164
|
-
logs_specs=log_spec,
|
165
|
-
run_id=self._context.run_id,
|
166
|
-
)
|
167
|
-
logger.info(f"launch_config: {launch_config}")
|
168
|
-
return launch_config
|
169
|
-
|
170
|
-
def execute(
|
171
|
-
self,
|
172
|
-
mock=False,
|
173
|
-
map_variable: TypeMapVariable = None,
|
174
|
-
attempt_number: int = 1,
|
175
|
-
) -> StepLog:
|
176
|
-
assert (
|
177
|
-
map_variable is None or not map_variable
|
178
|
-
), "TorchNode does not support map_variable"
|
179
|
-
|
180
|
-
step_log = self._context.run_log_store.get_step_log(
|
181
|
-
self._get_step_log_name(map_variable), self._context.run_id
|
182
|
-
)
|
183
|
-
|
184
|
-
# Attempt to call the function or elastic launch
|
185
|
-
launch_config = self.get_launch_config()
|
186
|
-
logger.info(f"launch_config: {launch_config}")
|
187
|
-
|
188
|
-
# ENV variables are shared with the subprocess, use that as communication
|
189
|
-
os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
|
190
|
-
os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
|
191
|
-
self._context.parameters_file or ""
|
192
|
-
)
|
193
|
-
os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
|
194
|
-
|
195
|
-
launcher = elastic_launch(
|
196
|
-
launch_config,
|
197
|
-
training_subprocess,
|
198
|
-
)
|
199
|
-
try:
|
200
|
-
launcher()
|
201
|
-
attempt_log = datastore.StepAttempt(
|
202
|
-
status=defaults.SUCCESS,
|
203
|
-
start_time=str(datetime.now()),
|
204
|
-
end_time=str(datetime.now()),
|
205
|
-
attempt_number=attempt_number,
|
206
|
-
)
|
207
|
-
except Exception as e:
|
208
|
-
attempt_log = datastore.StepAttempt(
|
209
|
-
status=defaults.FAIL,
|
210
|
-
start_time=str(datetime.now()),
|
211
|
-
end_time=str(datetime.now()),
|
212
|
-
attempt_number=attempt_number,
|
213
|
-
)
|
214
|
-
logger.error(f"Error executing TorchNode: {e}")
|
215
|
-
finally:
|
216
|
-
# This can only come from the subprocess
|
217
|
-
if Path(".catalog").exists():
|
218
|
-
os.rename(".catalog", "proc_logs")
|
219
|
-
# Move .catalog and torch_logs to the parent node's catalog location
|
220
|
-
self._context.catalog_handler.put(
|
221
|
-
"proc_logs/**/*", allow_file_not_found_exc=True
|
222
|
-
)
|
223
|
-
|
224
|
-
# TODO: This is not working!!
|
225
|
-
if self.log_dir:
|
226
|
-
self._context.catalog_handler.put(
|
227
|
-
self.log_dir + "/**/*", allow_file_not_found_exc=True
|
228
|
-
)
|
229
|
-
|
230
|
-
delete_env_vars_with_prefix("RUNNABLE_TORCH")
|
231
|
-
|
232
|
-
logger.info(f"attempt_log: {attempt_log}")
|
233
|
-
logger.info(f"Step {self.name} completed with status: {attempt_log.status}")
|
234
|
-
|
235
|
-
step_log.status = attempt_log.status
|
236
|
-
step_log.attempts.append(attempt_log)
|
237
|
-
|
238
|
-
return step_log
|
239
|
-
|
240
|
-
def fan_in(self, map_variable: dict[str, str | int | float] | None = None):
|
241
|
-
# Destroy the service
|
242
|
-
# Destroy the statefulset
|
243
|
-
assert (
|
244
|
-
map_variable is None or not map_variable
|
245
|
-
), "TorchNode does not support map_variable"
|
246
|
-
|
247
|
-
def fan_out(self, map_variable: dict[str, str | int | float] | None = None):
|
248
|
-
# Create a service
|
249
|
-
# Create a statefulset
|
250
|
-
# Gather the IPs and set them as parameters downstream
|
251
|
-
assert (
|
252
|
-
map_variable is None or not map_variable
|
253
|
-
), "TorchNode does not support map_variable"
|
254
|
-
|
255
|
-
|
256
|
-
# This internal model makes it easier to extract the required fields
|
257
|
-
# of log specs from user specification.
|
258
|
-
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
|
259
|
-
class InternalLogSpecs(BaseModel):
|
260
|
-
log_dir: Optional[str] = Field(default="torch_logs")
|
261
|
-
redirects: str = Field(default="0") # Std.NONE
|
262
|
-
tee: str = Field(default="0") # Std.NONE
|
263
|
-
local_ranks_filter: Optional[set[int]] = Field(default=None)
|
264
|
-
|
265
|
-
model_config = ConfigDict(extra="ignore")
|
266
|
-
|
267
|
-
@field_serializer("redirects")
|
268
|
-
def convert_redirects(self, redirects: str) -> Std | dict[int, Std]:
|
269
|
-
return Std.from_str(redirects)
|
270
|
-
|
271
|
-
@field_serializer("tee")
|
272
|
-
def convert_tee(self, tee: str) -> Std | dict[int, Std]:
|
273
|
-
return Std.from_str(tee)
|
extensions/nodes/torch_config.py
DELETED
@@ -1,76 +0,0 @@
|
|
1
|
-
from enum import Enum
|
2
|
-
from typing import Any, Optional
|
3
|
-
|
4
|
-
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
5
|
-
|
6
|
-
|
7
|
-
class StartMethod(str, Enum):
|
8
|
-
spawn = "spawn"
|
9
|
-
fork = "fork"
|
10
|
-
forkserver = "forkserver"
|
11
|
-
|
12
|
-
|
13
|
-
## The idea is the following:
|
14
|
-
# Users can configure any of the options present in TorchConfig class.
|
15
|
-
# The LaunchConfig class will be created from TorchConfig.
|
16
|
-
# The LogSpecs is sent as a parameter to the launch config.
|
17
|
-
|
18
|
-
## NO idea of standalone and how to send it
|
19
|
-
|
20
|
-
|
21
|
-
# The user sees this as part of the config of the node.
|
22
|
-
# It is kept as similar as possible to torchrun
|
23
|
-
class TorchConfig(BaseModel):
|
24
|
-
model_config = ConfigDict(extra="forbid")
|
25
|
-
|
26
|
-
# excluded as LaunchConfig requires min and max nodes
|
27
|
-
nnodes: str = Field(default="1:1", exclude=True, description="min:max")
|
28
|
-
nproc_per_node: int = Field(default=1, description="Number of processes per node")
|
29
|
-
|
30
|
-
# will be used to create the log specs
|
31
|
-
# But they are excluded from dump as logs specs is a class for LaunchConfig
|
32
|
-
# from_str("0") -> Std.NONE
|
33
|
-
# from_str("1") -> Std.OUT
|
34
|
-
# from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
|
35
|
-
log_dir: Optional[str] = Field(default="torch_logs", exclude=True)
|
36
|
-
redirects: str = Field(default="0", exclude=True) # Std.NONE
|
37
|
-
tee: str = Field(default="0", exclude=True) # Std.NONE
|
38
|
-
local_ranks_filter: Optional[set[int]] = Field(default=None, exclude=True)
|
39
|
-
|
40
|
-
role: str | None = Field(default=None)
|
41
|
-
|
42
|
-
# run_id would be the run_id of the context
|
43
|
-
# and sent at the creation of the LaunchConfig
|
44
|
-
|
45
|
-
# This section is about the communication between nodes/processes
|
46
|
-
rdzv_backend: str | None = Field(default="static")
|
47
|
-
rdzv_endpoint: str | None = Field(default="")
|
48
|
-
rdzv_configs: dict[str, Any] = Field(default_factory=dict)
|
49
|
-
rdzv_timeout: int | None = Field(default=None)
|
50
|
-
|
51
|
-
max_restarts: int | None = Field(default=None)
|
52
|
-
monitor_interval: float | None = Field(default=None)
|
53
|
-
start_method: str | None = Field(default=StartMethod.spawn)
|
54
|
-
log_line_prefix_template: str | None = Field(default=None)
|
55
|
-
local_addr: Optional[str] = None
|
56
|
-
|
57
|
-
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L753
|
58
|
-
# master_addr: str | None = Field(default="localhost")
|
59
|
-
# master_port: str | None = Field(default="29500")
|
60
|
-
# training_script: str = Field(default="dummy_training_script")
|
61
|
-
# training_script_args: str = Field(default="")
|
62
|
-
|
63
|
-
|
64
|
-
class EasyTorchConfig(TorchConfig):
|
65
|
-
model_config = ConfigDict(extra="ignore")
|
66
|
-
|
67
|
-
# TODO: Validate min < max
|
68
|
-
@computed_field # type: ignore
|
69
|
-
@property
|
70
|
-
def min_nodes(self) -> int:
|
71
|
-
return int(self.nnodes.split(":")[0])
|
72
|
-
|
73
|
-
@computed_field # type: ignore
|
74
|
-
@property
|
75
|
-
def max_nodes(self) -> int:
|
76
|
-
return int(self.nnodes.split(":")[1])
|
File without changes
|
File without changes
|