runnable 0.32.2__py3-none-any.whl → 0.32.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- extensions/nodes/torch.py +4 -7
- extensions/pipeline_executor/argo.py +9 -16
- extensions/tasks/torch.py +29 -6
- runnable/nodes.py +1 -36
- {runnable-0.32.2.dist-info → runnable-0.32.3.dist-info}/METADATA +1 -1
- {runnable-0.32.2.dist-info → runnable-0.32.3.dist-info}/RECORD +9 -9
- {runnable-0.32.2.dist-info → runnable-0.32.3.dist-info}/WHEEL +0 -0
- {runnable-0.32.2.dist-info → runnable-0.32.3.dist-info}/entry_points.txt +0 -0
- {runnable-0.32.2.dist-info → runnable-0.32.3.dist-info}/licenses/LICENSE +0 -0
extensions/nodes/torch.py
CHANGED
@@ -5,14 +5,14 @@ import random
 import string
 from datetime import datetime
 from pathlib import Path
-from typing import
+from typing import Any, Callable, Optional

 from pydantic import BaseModel, ConfigDict, Field, field_serializer

 from extensions.nodes.torch_config import EasyTorchConfig, TorchConfig
 from runnable import PythonJob, datastore, defaults
 from runnable.datastore import StepLog
-from runnable.nodes import
+from runnable.nodes import ExecutableNode
 from runnable.tasks import PythonTaskType, create_task
 from runnable.utils import TypeMapVariable

@@ -23,10 +23,7 @@ try:
     from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 except ImportError:
     logger.exception("Torch is not installed. Please install torch first.")
-
-if TYPE_CHECKING:
-    from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
-    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
+    raise Exception("Torch is not installed. Please install torch first.")


 def training_subprocess():
@@ -120,7 +117,7 @@ def delete_env_vars_with_prefix(prefix):


 # TODO: The design of this class is not final
-class TorchNode(
+class TorchNode(ExecutableNode, TorchConfig):
     node_type: str = Field(default="torch", serialization_alias="type")
     executable: PythonTaskType = Field(exclude=True)

extensions/pipeline_executor/argo.py
CHANGED
@@ -20,14 +20,10 @@ from pydantic import (
 from pydantic.alias_generators import to_camel
 from ruamel.yaml import YAML

-from extensions.nodes.nodes import
-
-
-
-    SuccessNode,
-    TaskNode,
-)
-from extensions.nodes.torch import TorchNode
+from extensions.nodes.nodes import MapNode, ParallelNode, TaskNode
+
+# TODO: Should be part of a wider refactor
+# from extensions.nodes.torch import TorchNode
 from extensions.pipeline_executor import GenericPipelineExecutor
 from runnable import defaults, utils
 from runnable.defaults import TypeMapVariable
@@ -590,12 +586,7 @@ class ArgoExecutor(GenericPipelineExecutor):
         task_name: str,
         inputs: Optional[Inputs] = None,
     ) -> ContainerTemplate:
-        assert
-            isinstance(node, TaskNode)
-            or isinstance(node, StubNode)
-            or isinstance(node, SuccessNode)
-            or isinstance(node, TorchNode)
-        )
+        assert node.node_type in ["task", "torch", "success", "stub", "fail"]

         node_override = None
         if hasattr(node, "overrides"):
@@ -658,7 +649,7 @@ class ArgoExecutor(GenericPipelineExecutor):
     def _set_env_vars_to_task(
         self, working_on: BaseNode, container_template: CoreContainerTemplate
     ):
-        if
+        if working_on.node_type not in ["task", "torch"]:
             return

         global_envs: dict[str, str] = {}
@@ -667,7 +658,7 @@ class ArgoExecutor(GenericPipelineExecutor):
             env_var = cast(EnvVar, env_var)
             global_envs[env_var.name] = env_var.value

-        override_key = working_on.overrides.get(self.service_name, "")
+        override_key = working_on.overrides.get(self.service_name, "")  # type: ignore
         node_override = self.overrides.get(override_key, None)

         # Update the global envs with the node overrides
@@ -878,6 +869,8 @@ class ArgoExecutor(GenericPipelineExecutor):
                 self._templates.append(composite_template)

             case "torch":
+                from extensions.nodes.torch import TorchNode
+
                 assert isinstance(working_on, TorchNode)
                 # TODO: Need to add multi-node functionality
                 # Check notes on the torch node
extensions/tasks/torch.py
CHANGED
@@ -5,7 +5,7 @@ import random
 import string
 from datetime import datetime
 from pathlib import Path
-from typing import
+from typing import Any, Optional

 from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
 from ruamel.yaml import YAML
@@ -19,16 +19,15 @@ from runnable.utils import get_module_and_attr_names

 logger = logging.getLogger(defaults.LOGGER_NAME)

+logger = logging.getLogger(defaults.LOGGER_NAME)
+
 try:
     from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
     from torch.distributed.launcher.api import LaunchConfig, elastic_launch

-except ImportError:
+except ImportError as e:
     logger.exception("torch is not installed")
-
-if TYPE_CHECKING:
-    from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
-    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
+    raise Exception("torch is not installed") from e


 def get_min_max_nodes(nnodes: str) -> tuple[int, int]:
@@ -69,6 +68,7 @@ class TorchTaskType(BaseTaskType, TorchConfig):
             )
         )
         print("###", easy_torch_config)
+        print("###", easy_torch_config)
         launch_config = LaunchConfig(
             **easy_torch_config.model_dump(
                 exclude_none=True,
@@ -96,6 +96,28 @@ class TorchTaskType(BaseTaskType, TorchConfig):

         _, max_nodes = get_min_max_nodes(self.nnodes)

+        if max_nodes > 1 and not is_execute:
+            executor = self._context.executor
+            executor.scale_up(self)
+            return StepAttempt(
+                status=defaults.SUCCESS,
+                start_time=str(datetime.now()),
+                end_time=str(datetime.now()),
+                attempt_number=1,
+                message="Triggered a scale up",
+            )
+
+        # The below should happen only if we are in the node that we want to execute
+        # For a single node, multi worker setup, this should be the entry point
+        # For a multi-node, we need to:
+        # - create a service config
+        # - Create a stateful set with number of nodes
+        # - Create a job to run the torch.distributed.launcher.api.elastic_launch on every node
+        # - the entry point to runnnable could be a way to trigger execution instead of scaling
+        is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
+
+        _, max_nodes = get_min_max_nodes(self.nnodes)
+
         if max_nodes > 1 and not is_execute:
             executor = self._context.executor
             executor.scale_up(self)
@@ -109,6 +131,7 @@ class TorchTaskType(BaseTaskType, TorchConfig):

         launch_config = self._get_launch_config()
         print("###****", launch_config)
+        print("###****", launch_config)
         logger.info(f"launch_config: {launch_config}")

         # ENV variables are shared with the subprocess, use that as communication
runnable/nodes.py
CHANGED
@@ -36,8 +36,8 @@ class BaseNode(ABC, BaseModel):
     name: str
     internal_name: str = Field(exclude=True)
     internal_branch_name: str = Field(default="", exclude=True)
+
     is_composite: bool = Field(default=False, exclude=True)
-    is_distributed: bool = Field(default=False, exclude=True)

     @property
     def _context(self):
@@ -483,41 +483,6 @@ class CompositeNode(TraversalNode):
         )


-class DistributedNode(TraversalNode):
-    """
-    Use this node for distributed execution of tasks.
-    eg: torch distributed, horovod, etc.
-    """
-
-    is_distributed: bool = True
-    catalog: Optional[CatalogStructure] = Field(default=None)
-    max_attempts: int = Field(default=1, ge=1)
-
-    def _get_catalog_settings(self) -> Dict[str, Any]:
-        """
-        If the node defines a catalog settings, return it or None
-
-        Returns:
-            dict: catalog settings defined as per the node or None
-        """
-        if self.catalog:
-            return self.catalog.model_dump()
-        return {}
-
-    def _get_max_attempts(self) -> int:
-        return self.max_attempts
-
-    def _get_branch_by_name(self, branch_name: str):
-        raise exceptions.NodeMethodCallError(
-            "This is an distributed node and does not have branches"
-        )
-
-    def execute_as_graph(self, map_variable: TypeMapVariable = None):
-        raise exceptions.NodeMethodCallError(
-            "This is an executable node and does not have a graph"
-        )
-
-
 class TerminalNode(BaseNode):
     def _get_on_failure_node(self) -> str:
         return ""

{runnable-0.32.2.dist-info → runnable-0.32.3.dist-info}/RECORD
CHANGED
@@ -16,11 +16,11 @@ extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqy
 extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
 extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
-extensions/nodes/torch.py,sha256=
+extensions/nodes/torch.py,sha256=64DTjdPNSJ8vfMwUN9h9Ly5g9qj-Bga7LSGrfCAO0BY,9389
 extensions/nodes/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
 extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/pipeline_executor/__init__.py,sha256=wfigTL2T9OHrmE8b2Ydmb8h6hr-oF--Yc2FectC7WaY,24623
-extensions/pipeline_executor/argo.py,sha256=
+extensions/pipeline_executor/argo.py,sha256=lHM3TM_UnQc4I1ghkuYdeBLpyr4pBLg-Ubnaf55Zw54,37878
 extensions/pipeline_executor/local.py,sha256=6oWUJ6b6NvIkpeQJBoCT1hbfX4_6WCB4HzMgHZ4ik1A,1887
 extensions/pipeline_executor/local_container.py,sha256=3kZ2QCsrq_YjH9dcAz8v05knKShQ_JtbIU-IA_-G538,12724
 extensions/pipeline_executor/mocked.py,sha256=0sMmypuvstBIv9uQg-WAcPrF3oOFpeEXNi6N8Nzdnl0,5680
@@ -40,7 +40,7 @@ extensions/run_log_store/db/integration_FF.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeR
 extensions/secrets/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/secrets/dotenv.py,sha256=nADHXI6KJ_LUYOIe5EbtYH-21OBebSNVr0Pjb1GlZ7w,1573
 extensions/secrets/pyproject.toml,sha256=mLJNImNcBlbLKHh-0ugVWT9V83R4RibyyYDtBCSqVF4,282
-extensions/tasks/torch.py,sha256=
+extensions/tasks/torch.py,sha256=oeXRkmuttFIAuBwH7-h4SOVXMDOZXX5mvqI2aFrR3Vo,10283
 extensions/tasks/torch_config.py,sha256=UjfMitT-TXASRDGR30I2vDRnyk7JQnR-5CsOVidjpSY,2833
 runnable/__init__.py,sha256=3ZKuvGEkY_zHVQlJtarXd4jkjICxjgnw-bbKN_5SiJI,691
 runnable/catalog.py,sha256=4msQxLhLKlsDDrHFnGauPYe-Or-q9g8_RYCn_4dpxaU,4466
@@ -53,15 +53,15 @@ runnable/exceptions.py,sha256=LFbp0-Qxg2PAMLEVt7w2whhBxSG-5pzUEv5qN-Rc4_c,3003
 runnable/executor.py,sha256=Jr9yJtSH7CzjXJLWx3VWIUAQblstuGqzpFtajv7d39M,15348
 runnable/graph.py,sha256=poQz5zcvq89ju_u5sYlunQLPbHnXTaUmjcvstPwvT4U,16536
 runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
-runnable/nodes.py,sha256=
+runnable/nodes.py,sha256=QGHMznriEz4AcmntHICBZKrDT6zbc7WD1sV0MgwK10c,16691
 runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
 runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
 runnable/sdk.py,sha256=hwsEGCCFSijm0DZwDJGHmV8jdMuSU_3Pf-vYoomWYHw,35084
 runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
 runnable/tasks.py,sha256=ABRhgiTY8F62pNlqJmVTDjwJwuzp8DqciUEOq1fpt1U,28989
 runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
-runnable-0.32.
-runnable-0.32.
-runnable-0.32.
-runnable-0.32.
-runnable-0.32.
+runnable-0.32.3.dist-info/METADATA,sha256=l0jxi_VKPXblTa_Kd-fnqKophmB2e8x1Dj1HDbJV570,10168
+runnable-0.32.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+runnable-0.32.3.dist-info/entry_points.txt,sha256=uWHbbOSj0jlG54tFHw377xKkfVbjWvb_1Y9L_LgjJ0Q,1925
+runnable-0.32.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+runnable-0.32.3.dist-info/RECORD,,

{runnable-0.32.2.dist-info → runnable-0.32.3.dist-info}/WHEEL
File without changes
{runnable-0.32.2.dist-info → runnable-0.32.3.dist-info}/entry_points.txt
File without changes
{runnable-0.32.2.dist-info → runnable-0.32.3.dist-info}/licenses/LICENSE
File without changes