runnable 0.28.7__py3-none-any.whl → 0.29.0__py3-none-any.whl
- extensions/job_executor/k8s.py +8 -9
- extensions/job_executor/local.py +7 -5
- extensions/job_executor/local_container.py +7 -5
- extensions/nodes/nodes.py +15 -195
- extensions/nodes/torch.py +169 -0
- extensions/nodes/torch_config.py +33 -0
- extensions/pipeline_executor/__init__.py +10 -14
- extensions/pipeline_executor/argo.py +1 -3
- extensions/pipeline_executor/local.py +6 -10
- extensions/pipeline_executor/local_container.py +10 -12
- extensions/pipeline_executor/mocked.py +6 -12
- extensions/pipeline_executor/retry.py +6 -10
- extensions/run_log_store/generic_chunked.py +1 -2
- extensions/secrets/dotenv.py +1 -1
- extensions/tasks/torch.py +52 -0
- runnable/__init__.py +1 -0
- runnable/entrypoints.py +2 -2
- runnable/executor.py +6 -11
- runnable/nodes.py +44 -25
- runnable/sdk.py +46 -4
- runnable/secrets.py +3 -3
- runnable/tasks.py +0 -4
- {runnable-0.28.7.dist-info → runnable-0.29.0.dist-info}/METADATA +3 -1
- {runnable-0.28.7.dist-info → runnable-0.29.0.dist-info}/RECORD +27 -24
- {runnable-0.28.7.dist-info → runnable-0.29.0.dist-info}/entry_points.txt +1 -0
- {runnable-0.28.7.dist-info → runnable-0.29.0.dist-info}/WHEEL +0 -0
- {runnable-0.28.7.dist-info → runnable-0.29.0.dist-info}/licenses/LICENSE +0 -0
extensions/job_executor/k8s.py
CHANGED
```diff
@@ -11,7 +11,7 @@ from rich import print
 
 from extensions.job_executor import GenericJobExecutor
 from runnable import console, defaults, utils
-from runnable.datastore import DataCatalog
+from runnable.datastore import DataCatalog, StepAttempt
 from runnable.tasks import BaseTaskType
 
 logger = logging.getLogger(defaults.NAME)
@@ -213,10 +213,12 @@ class GenericK8sJobExecutor(GenericJobExecutor):
         job_log = self._context.run_log_store.get_job_log(run_id=self._context.run_id)
         self.add_code_identities(job_log)
 
-        attempt_log = job.execute_command(
-            attempt_number=self.step_attempt_number,
-            mock=self.mock,
-        )
+        if not self.mock:
+            attempt_log = job.execute_command()
+        else:
+            attempt_log = StepAttempt(
+                status=defaults.SUCCESS,
+            )
 
         job_log.status = attempt_log.status
         job_log.attempts.append(attempt_log)
@@ -455,10 +457,7 @@ class K8sJobExecutor(GenericK8sJobExecutor):
         job_log = self._context.run_log_store.get_job_log(run_id=self._context.run_id)
         self.add_code_identities(job_log)
 
-        attempt_log = job.execute_command(
-            attempt_number=self.step_attempt_number,
-            mock=self.mock,
-        )
+        attempt_log = job.execute_command()
 
         job_log.status = attempt_log.status
         job_log.attempts.append(attempt_log)
```
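The same pattern repeats in the local and local-container job executors below: instead of forwarding `attempt_number` and `mock` into `job.execute_command()`, the executor now branches on its own `mock` flag and records a successful `StepAttempt` without running anything. A minimal sketch of the new control flow, with a hypothetical `run_or_mock` helper standing in for the executor method:

```python
from runnable import defaults
from runnable.datastore import StepAttempt

def run_or_mock(job, mock: bool) -> StepAttempt:
    # Real run: execute_command() no longer takes attempt/mock arguments.
    if not mock:
        return job.execute_command()
    # Mocked run: short-circuit with an immediate SUCCESS attempt.
    return StepAttempt(status=defaults.SUCCESS)
```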
extensions/job_executor/local.py
CHANGED
```diff
@@ -3,7 +3,7 @@ from typing import List, Optional
 
 from extensions.job_executor import GenericJobExecutor
 from runnable import console, defaults
-from runnable.datastore import DataCatalog
+from runnable.datastore import DataCatalog, StepAttempt
 from runnable.tasks import BaseTaskType
 
 logger = logging.getLogger(defaults.LOGGER_NAME)
@@ -39,10 +39,12 @@ class LocalJobExecutor(GenericJobExecutor):
         job_log = self._context.run_log_store.get_job_log(run_id=self._context.run_id)
         self.add_code_identities(job_log)
 
-        attempt_log = job.execute_command(
-            attempt_number=self.step_attempt_number,
-            mock=self.mock,
-        )
+        if not self.mock:
+            attempt_log = job.execute_command()
+        else:
+            attempt_log = StepAttempt(
+                status=defaults.SUCCESS,
+            )
 
         job_log.status = attempt_log.status
         job_log.attempts.append(attempt_log)
```
extensions/job_executor/local_container.py
CHANGED
```diff
@@ -6,7 +6,7 @@ from pydantic import Field
 
 from extensions.job_executor import GenericJobExecutor
 from runnable import console, defaults, utils
-from runnable.datastore import DataCatalog
+from runnable.datastore import DataCatalog, StepAttempt
 from runnable.tasks import BaseTaskType
 
 logger = logging.getLogger(defaults.LOGGER_NAME)
@@ -54,10 +54,12 @@ class LocalContainerJobExecutor(GenericJobExecutor):
         job_log = self._context.run_log_store.get_job_log(run_id=self._context.run_id)
         self.add_code_identities(job_log)
 
-        attempt_log = job.execute_command(
-            attempt_number=self.step_attempt_number,
-            mock=self.mock,
-        )
+        if not self.mock:
+            attempt_log = job.execute_command()
+        else:
+            attempt_log = StepAttempt(
+                status=defaults.SUCCESS,
+            )
 
         job_log.status = attempt_log.status
         job_log.attempts.append(attempt_log)
```
extensions/nodes/nodes.py
CHANGED
```diff
@@ -5,15 +5,9 @@ import sys
 from collections import OrderedDict
 from copy import deepcopy
 from datetime import datetime
-from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Union, cast
-
-from pydantic import (
-    ConfigDict,
-    Field,
-    ValidationInfo,
-    field_serializer,
-    field_validator,
-)
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+
+from pydantic import ConfigDict, Field, field_serializer
 
 from runnable import console, datastore, defaults, utils
 from runnable.datastore import (
@@ -73,7 +67,6 @@ class TaskNode(ExecutableNode):
         mock=False,
         map_variable: TypeMapVariable = None,
         attempt_number: int = 1,
-        **kwargs,
     ) -> StepLog:
         """
         All that we do in runnable is to come to this point where we actually execute the command.
@@ -135,7 +128,6 @@ class FailNode(TerminalNode):
         mock=False,
         map_variable: TypeMapVariable = None,
         attempt_number: int = 1,
-        **kwargs,
     ) -> StepLog:
         """
         Execute the failure node.
@@ -199,7 +191,6 @@ class SuccessNode(TerminalNode):
         mock=False,
         map_variable: TypeMapVariable = None,
         attempt_number: int = 1,
-        **kwargs,
    ) -> StepLog:
         """
         Execute the success node.
@@ -255,7 +246,6 @@ class ParallelNode(CompositeNode):
 
     node_type: str = Field(default="parallel", serialization_alias="type")
     branches: Dict[str, Graph]
-    is_composite: bool = Field(default=True, exclude=True)
 
     def get_summary(self) -> Dict[str, Any]:
         summary = {
@@ -298,7 +288,7 @@ class ParallelNode(CompositeNode):
 
         raise Exception(f"Branch {branch_name} does not exist")
 
-    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
+    def fan_out(self, map_variable: TypeMapVariable = None):
         """
         The general fan out method for a node of type Parallel.
         This method assumes that the step log has already been created.
@@ -321,7 +311,7 @@ class ParallelNode(CompositeNode):
         branch_log.status = defaults.PROCESSING
         self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)
 
-    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
+    def execute_as_graph(self, map_variable: TypeMapVariable = None):
         """
         This function does the actual execution of the sub-branches of the parallel node.
 
@@ -342,16 +332,14 @@
             executor (Executor): The Executor as per the use config
             **kwargs: Optional kwargs passed around
         """
-        self.fan_out(map_variable=map_variable, **kwargs)
+        self.fan_out(map_variable=map_variable)
 
         for _, branch in self.branches.items():
-            self._context.executor.execute_graph(
-                branch, map_variable=map_variable, **kwargs
-            )
+            self._context.executor.execute_graph(branch, map_variable=map_variable)
 
-        self.fan_in(map_variable=map_variable, **kwargs)
+        self.fan_in(map_variable=map_variable)
 
-    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
+    def fan_in(self, map_variable: TypeMapVariable = None):
         """
         The general fan in method for a node of type Parallel.
 
@@ -412,7 +400,6 @@ class MapNode(CompositeNode):
     iterate_as: str
     reducer: Optional[str] = Field(default=None)
     branch: Graph
-    is_composite: bool = True
 
     def get_summary(self) -> Dict[str, Any]:
         summary = {
@@ -515,7 +502,7 @@ class MapNode(CompositeNode):
         """
         return self.branch
 
-    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
+    def fan_out(self, map_variable: TypeMapVariable = None):
         """
         The general method to fan out for a node of type map.
         This method assumes that the step log has already been created.
@@ -563,7 +550,7 @@
             parameters=raw_parameters, run_id=self._context.run_id
         )
 
-    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
+    def execute_as_graph(self, map_variable: TypeMapVariable = None):
         """
         This function does the actual execution of the branch of the map node.
 
@@ -607,19 +594,19 @@
         if not isinstance(iterate_on, list):
             raise Exception("Only list is allowed as a valid iterator type")
 
-        self.fan_out(map_variable=map_variable, **kwargs)
+        self.fan_out(map_variable=map_variable)
 
         for iter_variable in iterate_on:
             effective_map_variable = map_variable or OrderedDict()
             effective_map_variable[self.iterate_as] = iter_variable
 
             self._context.executor.execute_graph(
-                self.branch, map_variable=effective_map_variable, **kwargs
+                self.branch, map_variable=effective_map_variable
             )
 
-        self.fan_in(map_variable=map_variable, **kwargs)
+        self.fan_in(map_variable=map_variable)
 
-    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
+    def fan_in(self, map_variable: TypeMapVariable = None):
         """
         The general method to fan in for a node of type map.
 
@@ -714,172 +701,6 @@
         )
 
 
-class DagNode(CompositeNode):
-    """
-    A composite node that internally holds a dag.
-
-    The structure is generally:
-        DagNode:
-            dag_definition: A YAML file that holds the dag in 'dag' block
-
-    The config is expected to have a variable 'dag_definition'.
-    """
-
-    node_type: str = Field(default="dag", serialization_alias="type")
-    dag_definition: str
-    branch: Graph
-    is_composite: bool = True
-    internal_branch_name: Annotated[str, Field(validate_default=True)] = ""
-
-    def get_summary(self) -> Dict[str, Any]:
-        summary = {
-            "name": self.name,
-            "type": self.node_type,
-        }
-        return summary
-
-    @field_validator("internal_branch_name")
-    @classmethod
-    def validate_internal_branch_name(
-        cls, internal_branch_name: str, info: ValidationInfo
-    ):
-        internal_name = info.data["internal_name"]
-        return internal_name + "." + defaults.DAG_BRANCH_NAME
-
-    @field_validator("dag_definition")
-    @classmethod
-    def validate_dag_definition(cls, value):
-        if not value.endswith(".yaml"):  # TODO: Might have a problem with the SDK
-            raise ValueError("dag_definition must be a YAML file")
-        return value
-
-    @classmethod
-    def parse_from_config(cls, config: Dict[str, Any]) -> "DagNode":
-        internal_name = cast(str, config.get("internal_name"))
-
-        if "dag_definition" not in config:
-            raise Exception(f"No dag definition found in {config}")
-
-        dag_config = utils.load_yaml(config["dag_definition"])
-        if "dag" not in dag_config:
-            raise Exception(
-                "No DAG found in dag_definition, please provide it in dag block"
-            )
-
-        branch = create_graph(
-            dag_config["dag"],
-            internal_branch_name=internal_name + "." + defaults.DAG_BRANCH_NAME,
-        )
-
-        return cls(branch=branch, **config)
-
-    def _get_branch_by_name(self, branch_name: str):
-        """
-        Retrieve a branch by name.
-        The name is expected to follow a dot path convention.
-
-        Returns a Graph Object
-
-        Args:
-            branch_name (str): The name of the branch to retrieve
-
-        Raises:
-            Exception: If the branch_name is not 'dag'
-        """
-        if branch_name != self.internal_branch_name:
-            raise Exception(
-                f"Node of type {self.node_type} only allows a branch of name {defaults.DAG_BRANCH_NAME}"
-            )
-
-        return self.branch
-
-    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
-        """
-        The general method to fan out for a node of type dag.
-        The method assumes that the step log has already been created.
-
-        Args:
-            executor (BaseExecutor): The executor class as defined by the config
-            map_variable (dict, optional): _description_. Defaults to None.
-        """
-        effective_branch_name = self._resolve_map_placeholders(
-            self.internal_branch_name, map_variable=map_variable
-        )
-
-        branch_log = self._context.run_log_store.create_branch_log(
-            effective_branch_name
-        )
-        branch_log.status = defaults.PROCESSING
-        self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)
-
-    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
-        """
-        This function does the actual execution of the branch of the dag node.
-
-        From a design perspective, this function should not be called if the execution is 3rd party orchestrated.
-
-        The modes that render the job specifications, do not need to interact with this node at all
-        as they have their own internal mechanisms of handling sub dags.
-        If they do not, you can find a way using as-is nodes as hack nodes.
-
-        The actual logic is :
-            * We just execute the branch as with any other composite nodes
-            * The branch name is called 'dag'
-
-        The execution of a dag, could result in
-            * The dag being completely executed with a definite (fail, success) state in case of
-                local or local-container execution
-            * The dag being in a processing state with PROCESSING status in case of local-aws-batch
-
-        Only fail state is considered failure during this phase of execution.
-
-        Args:
-            executor (Executor): The Executor as per the use config
-            **kwargs: Optional kwargs passed around
-        """
-        self.fan_out(map_variable=map_variable, **kwargs)
-        self._context.executor.execute_graph(
-            self.branch, map_variable=map_variable, **kwargs
-        )
-        self.fan_in(map_variable=map_variable, **kwargs)
-
-    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
-        """
-        The general method to fan in for a node of type dag.
-
-        3rd party orchestrators should call this method to find the status of the step log.
-
-        Args:
-            executor (BaseExecutor): The executor class as defined by the config
-            map_variable (dict, optional): If the node is part of type dag. Defaults to None.
-        """
-        step_success_bool = True
-        effective_branch_name = self._resolve_map_placeholders(
-            self.internal_branch_name, map_variable=map_variable
-        )
-        effective_internal_name = self._resolve_map_placeholders(
-            self.internal_name, map_variable=map_variable
-        )
-
-        branch_log = self._context.run_log_store.get_branch_log(
-            effective_branch_name, self._context.run_id
-        )
-        if branch_log.status != defaults.SUCCESS:
-            step_success_bool = False
-
-        step_log = self._context.run_log_store.get_step_log(
-            effective_internal_name, self._context.run_id
-        )
-        step_log.status = defaults.PROCESSING
-
-        if step_success_bool:  # If none failed and nothing is waiting
-            step_log.status = defaults.SUCCESS
-        else:
-            step_log.status = defaults.FAIL
-
-        self._context.run_log_store.add_step_log(step_log, self._context.run_id)
-
-
 class StubNode(ExecutableNode):
     """
     Stub is a convenience design node.
@@ -926,7 +747,6 @@ class StubNode(ExecutableNode):
         mock=False,
         map_variable: TypeMapVariable = None,
         attempt_number: int = 1,
-        **kwargs,
     ) -> StepLog:
         """
         Do Nothing node.
```
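The recurring edit above is the removal of the catch-all `**kwargs` from `execute`, `fan_out`, `execute_as_graph`, and `fan_in`, together with dropping the `DagNode` type and the explicit `is_composite` declarations. Removing `**kwargs` makes stray keyword arguments fail fast instead of being silently swallowed; a small self-contained illustration of the difference (function names here are hypothetical stand-ins):

```python
def fan_out_old(map_variable=None, **kwargs):
    # Pre-0.29.0 style signature: unknown keywords are accepted and ignored.
    return map_variable

def fan_out_new(map_variable=None):
    # 0.29.0 style signature: the parameter list is the whole contract.
    return map_variable

fan_out_old(map_variable=None, retries=3)      # silently drops retries
try:
    fan_out_new(map_variable=None, retries=3)  # now an error
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'retries'
```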
extensions/nodes/torch.py
ADDED
```diff
@@ -0,0 +1,169 @@
+import importlib
+import logging
+import os
+from datetime import datetime
+from typing import Any, Callable
+
+from pydantic import ConfigDict, Field
+
+from extensions.nodes.torch_config import TorchConfig
+from runnable import PythonJob, datastore, defaults
+from runnable.datastore import StepLog
+from runnable.nodes import DistributedNode
+from runnable.tasks import PythonTaskType, create_task
+from runnable.utils import TypeMapVariable
+
+logger = logging.getLogger(defaults.LOGGER_NAME)
+
+try:
+    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
+    from torch.distributed.run import config_from_args
+except ImportError:
+    raise ImportError("torch is not installed. Please install torch first.")
+
+print("torch is installed")
+
+
+def training_subprocess():
+    command = os.environ.get("RUNNABLE_TORCH_COMMAND")
+    run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
+    parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
+    process_run_id = run_id + "-" + os.environ.get("RANK", "")
+
+    delete_env_vars_with_prefix("RUNNABLE_")
+
+    func = get_callable_from_dotted_path(command)
+    job = PythonJob(function=func)
+
+    job.execute(
+        parameters_file=parameters_files,
+        job_id=process_run_id,
+    )
+
+
+def get_callable_from_dotted_path(dotted_path) -> Callable:
+    try:
+        # Split the path into module path and callable object
+        module_path, callable_name = dotted_path.rsplit(".", 1)
+
+        # Import the module
+        module = importlib.import_module(module_path)
+
+        # Get the callable from the module
+        callable_obj = getattr(module, callable_name)
+
+        # Check if the object is callable
+        if not callable(callable_obj):
+            raise TypeError(f"The object {callable_name} is not callable.")
+
+        return callable_obj
+
+    except (ImportError, AttributeError, ValueError) as e:
+        raise ImportError(f"Could not import '{dotted_path}'.") from e
+
+
+def delete_env_vars_with_prefix(prefix):
+    to_delete = []  # List to keep track of variables to delete
+
+    # Iterate over a list of all environment variable keys
+    for var in os.environ:
+        if var.startswith(prefix):
+            to_delete.append(var)
+
+    # Delete each of the variables collected
+    for var in to_delete:
+        del os.environ[var]
+
+
+class TorchNode(DistributedNode, TorchConfig):
+    node_type: str = Field(default="torch", serialization_alias="type")
+    executable: PythonTaskType = Field(exclude=True)
+
+    # Similar to TaskNode
+    model_config = ConfigDict(extra="allow")
+
+    def get_summary(self) -> dict[str, Any]:
+        summary = {
+            "name": self.name,
+            "type": self.node_type,
+        }
+
+        return summary
+
+    @classmethod
+    def parse_from_config(cls, config: dict[str, Any]) -> "TorchNode":
+        task_config = {
+            k: v for k, v in config.items() if k not in TorchNode.model_fields.keys()
+        }
+        node_config = {
+            k: v for k, v in config.items() if k in TorchNode.model_fields.keys()
+        }
+
+        executable = create_task(task_config)
+
+        assert isinstance(executable, PythonTaskType)
+        return cls(executable=executable, **node_config, **task_config)
+
+    def get_launch_config(self) -> LaunchConfig:
+        config, _, _ = config_from_args(self)
+        config.run_id = self._context.run_id
+        return config
+
+    def execute(
+        self,
+        mock=False,
+        map_variable: TypeMapVariable = None,
+        attempt_number: int = 1,
+    ) -> StepLog:
+        assert map_variable is None, "TorchNode does not support map_variable"
+
+        step_log = self._context.run_log_store.get_step_log(
+            self._get_step_log_name(map_variable), self._context.run_id
+        )
+
+        # Attempt to call the function or elastic launch
+        launch_config = self.get_launch_config()
+        logger.info(f"launch_config: {launch_config}")
+
+        os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
+        os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
+            self._context.parameters_file or ""
+        )
+        os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
+        launcher = elastic_launch(
+            launch_config,
+            training_subprocess,
+        )
+        try:
+            launcher()
+            attempt_log = datastore.StepAttempt(
+                status=defaults.SUCCESS,
+                start_time=str(datetime.now()),
+                end_time=str(datetime.now()),
+                attempt_number=attempt_number,
+            )
+        except Exception as e:
+            attempt_log = datastore.StepAttempt(
+                status=defaults.FAIL,
+                start_time=str(datetime.now()),
+                end_time=str(datetime.now()),
+                attempt_number=attempt_number,
+            )
+            logger.error(f"Error executing TorchNode: {e}")
+
+        delete_env_vars_with_prefix("RUNNABLE_TORCH")
+
+        logger.info(f"attempt_log: {attempt_log}")
+        logger.info(f"Step {self.name} completed with status: {attempt_log.status}")
+
+        step_log.status = attempt_log.status
+        step_log.attempts.append(attempt_log)
+
+        return step_log
+
+    # TODO: Not sure we need these methods
+    def fan_in(self, map_variable: dict[str, str | int | float] | None = None):
+        assert map_variable is None, "TorchNode does not support map_variable"
+
+    def fan_out(self, map_variable: dict[str, str | int | float] | None = None):
+        assert map_variable is None, "TorchNode does not support map_variable"
```
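The parent process and the elastic workers in the new torch node talk through `RUNNABLE_TORCH_*` environment variables: `TorchNode.execute` exports the dotted path of the training function, the parameters file, and the run id before `elastic_launch` spawns workers, and each worker's `training_subprocess` reads them back and derives a per-rank run id before clearing the prefix. A standalone sketch of that handshake (the function path and values are hypothetical; `RANK` is normally set per worker by torch elastic):

```python
import os

# Parent side: what TorchNode.execute exports before elastic_launch.
os.environ["RUNNABLE_TORCH_COMMAND"] = "my_package.train"  # hypothetical
os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = "params.yaml"
os.environ["RUNNABLE_TORCH_RUN_ID"] = "run-001"
os.environ["RANK"] = "0"

# Worker side: what training_subprocess reads back.
command = os.environ.get("RUNNABLE_TORCH_COMMAND")
run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
process_run_id = run_id + "-" + os.environ.get("RANK", "")

assert command == "my_package.train"
assert process_run_id == "run-001-0"
```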
extensions/nodes/torch_config.py
ADDED
```diff
@@ -0,0 +1,33 @@
+from pydantic import BaseModel, Field
+
+
+class TorchConfig(BaseModel):
+    nnodes: str = Field(default="1:1")
+    nproc_per_node: int = Field(default=4)
+
+    rdzv_backend: str = Field(default="static")
+    rdzv_endpoint: str = Field(default="")
+    rdzv_id: str | None = Field(default=None)
+    rdzv_conf: str = Field(default="")
+
+    max_restarts: int = Field(default=3)
+    monitor_interval: float = Field(default=0.1)
+    start_method: str = Field(default="spawn")
+    role: str = Field(default="default_role")
+    log_dir: str = Field(default="torch_logs")
+    redirects: str = Field(default="1")
+    tee: str = Field(default="1")
+    master_addr: str = Field(default="localhost")
+    master_port: str = Field(default="29500")
+    training_script: str = Field(default="dummy_training_script")
+    training_script_args: str = Field(default="")
+
+    # Optional fields
+    local_ranks_filter: str = Field(default="")
+    node_rank: int = Field(default=0)
+    local_addr: str | None = Field(default=None)
+    logs_specs: str | None = Field(default=None)
+    standalone: bool = Field(default=False)
+    module: bool = Field(default=False)
+    no_python: bool = Field(default=False)
+    run_path: bool = Field(default=False)
```
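`TorchConfig` mirrors the argument namespace of torchrun (`torch.distributed.run`), which is what lets `TorchNode.get_launch_config` hand the node itself to `config_from_args`. Since it is a plain pydantic model, individual settings can be overridden per node; a quick sketch of overriding a couple of defaults:

```python
from extensions.nodes.torch_config import TorchConfig

# Override two torchrun-style settings; everything else keeps its default.
config = TorchConfig(nproc_per_node=2, master_port="29501")

print(config.nnodes)          # 1:1
print(config.nproc_per_node)  # 2
print(config.rdzv_backend)    # static
```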