runnable 0.30.4__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- extensions/nodes/torch.py +22 -5
- extensions/nodes/torch_config.py +93 -30
- {runnable-0.30.4.dist-info → runnable-0.31.0.dist-info}/METADATA +1 -1
- {runnable-0.30.4.dist-info → runnable-0.31.0.dist-info}/RECORD +7 -7
- {runnable-0.30.4.dist-info → runnable-0.31.0.dist-info}/WHEEL +0 -0
- {runnable-0.30.4.dist-info → runnable-0.31.0.dist-info}/entry_points.txt +0 -0
- {runnable-0.30.4.dist-info → runnable-0.31.0.dist-info}/licenses/LICENSE +0 -0
extensions/nodes/torch.py
CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Callable
|
|
8
8
|
|
9
9
|
from pydantic import ConfigDict, Field
|
10
10
|
|
11
|
-
from extensions.nodes.torch_config import TorchConfig
|
11
|
+
from extensions.nodes.torch_config import EasyTorchConfig, InternalLogSpecs, TorchConfig
|
12
12
|
from runnable import PythonJob, datastore, defaults
|
13
13
|
from runnable.datastore import StepLog
|
14
14
|
from runnable.nodes import DistributedNode
|
@@ -18,8 +18,9 @@ from runnable.utils import TypeMapVariable
|
|
18
18
|
# Module logger, named by runnable's ``defaults.LOGGER_NAME``.
logger = logging.getLogger(defaults.LOGGER_NAME)
|
19
19
|
|
20
20
|
try:
|
21
|
+
from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs
|
21
22
|
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
22
|
-
|
23
|
+
|
23
24
|
except ImportError:
|
24
25
|
raise ImportError("torch is not installed. Please install torch first.")
|
25
26
|
|
@@ -120,9 +121,25 @@ class TorchNode(DistributedNode, TorchConfig):
|
|
120
121
|
return cls(executable=executable, **node_config, **task_config)
|
121
122
|
|
122
123
|
def get_launch_config(self) -> LaunchConfig:
    """Build the torch elastic ``LaunchConfig`` for this node.

    ``nnodes`` and the log-spec options are declared with ``exclude=True`` on
    ``TorchConfig``, so ``self.model_dump()`` leaves them out. They are read
    from the attributes directly here; otherwise the user's values would be
    silently replaced by the model defaults.

    Returns:
        LaunchConfig: built from the ``EasyTorchConfig`` dump, with the log
        specs attached and ``run_id`` taken from the current run context.
    """
    # Std sits next to DefaultLogsSpecs in torch's elastic multiprocessing api.
    from torch.distributed.elastic.multiprocessing.api import Std

    # The log options are excluded from model_dump(), so pass them explicitly.
    internal_log_spec = InternalLogSpecs(
        log_dir=self.log_dir,
        redirects=self.redirects,
        tee=self.tee,
        local_ranks_filter=self.local_ranks_filter,
    )
    log_spec_kwargs = internal_log_spec.model_dump(exclude_none=True)
    # DefaultLogsSpecs expects Std flags (or a {local_rank: Std} map), not raw ints.
    for key in ("redirects", "tee"):
        if key in log_spec_kwargs:
            log_spec_kwargs[key] = Std(log_spec_kwargs[key])
    log_spec: DefaultLogsSpecs = DefaultLogsSpecs(**log_spec_kwargs)

    easy_kwargs = self.model_dump(exclude_none=True)
    # nnodes is excluded from model_dump(). Without this, min/max nodes always
    # fall back to the "1:1" default.
    easy_kwargs["nnodes"] = self.nnodes
    easy_torch_config = EasyTorchConfig(**easy_kwargs)

    launch_config = LaunchConfig(
        **easy_torch_config.model_dump(
            exclude_none=True,
        ),
        logs_specs=log_spec,
        run_id=self._context.run_id,
    )
    logger.debug("Torch launch config: %s", launch_config)
    return launch_config
|
126
143
|
|
127
144
|
def execute(
|
128
145
|
self,
|
extensions/nodes/torch_config.py
CHANGED
@@ -1,33 +1,96 @@
|
|
1
|
-
from
|
1
|
+
from enum import Enum
|
2
|
+
from typing import Any, Optional
|
3
|
+
|
4
|
+
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
5
|
+
|
6
|
+
|
7
|
+
class StartMethod(str, Enum):
    """Process start-method names; members compare equal to their string values."""

    # Declaration order is also the iteration order of the enum.
    spawn = "spawn"
    fork = "fork"
    forkserver = "forkserver"
|
11
|
+
|
12
|
+
|
13
|
+
# min_nodes: int
|
14
|
+
# max_nodes: int
|
15
|
+
# nproc_per_node: int
|
16
|
+
|
17
|
+
# logs_specs: Optional[LogsSpecs] = None
|
18
|
+
# run_id: str = ""
|
19
|
+
# role: str = "default_role"
|
20
|
+
|
21
|
+
# rdzv_endpoint: str = ""
|
22
|
+
# rdzv_backend: str = "etcd"
|
23
|
+
# rdzv_configs: dict[str, Any] = field(default_factory=dict)
|
24
|
+
# rdzv_timeout: int = -1
|
25
|
+
|
26
|
+
# max_restarts: int = 3
|
27
|
+
# monitor_interval: float = 0.1
|
28
|
+
# start_method: str = "spawn"
|
29
|
+
# log_line_prefix_template: Optional[str] = None
|
30
|
+
# metrics_cfg: dict[str, str] = field(default_factory=dict)
|
31
|
+
# local_addr: Optional[str] = None
|
32
|
+
|
33
|
+
## The idea is the following:
|
34
|
+
# Users can configure any of the options present in TorchConfig class.
|
35
|
+
# The LaunchConfig class will be created from torch config.
|
36
|
+
# The LogSpecs is sent as a parameter to the launch config.
|
37
|
+
# Leave options as None wherever possible, so that model_dump(exclude_none=True) drops them and torch's own defaults apply.
|
38
|
+
|
39
|
+
## Open question: it is not yet clear whether or how to support a standalone mode, or how to pass it through.
|
40
|
+
|
41
|
+
|
42
|
+
class InternalLogSpecs(BaseModel):
    """Keyword arguments for torch's ``DefaultLogsSpecs``.

    Unknown keys are ignored, so this model can be built from a larger config
    dump. Dump it with ``exclude_none=True`` so that options left as ``None``
    fall back to torch's own defaults.
    """

    model_config = ConfigDict(extra="ignore")

    log_dir: Optional[str] = "torch_logs"
    redirects: Optional[int] = None
    tee: Optional[int] = None
    local_ranks_filter: Optional[set[int]] = None
|
2
49
|
|
3
50
|
|
4
51
|
class TorchConfig(BaseModel):
    """Options a user may set to launch a torch distributed run.

    Most fields end up as ``LaunchConfig`` keyword arguments. Fields declared
    with ``exclude=True`` are kept out of ``model_dump`` because they are not
    ``LaunchConfig`` arguments themselves:
    - ``nnodes`` is split into min/max nodes.
    - The log options are used to build the log specs.
    """

    model_config = ConfigDict(extra="forbid")

    # "MIN:MAX" number of nodes (never dumped).
    nnodes: str = Field("1:1", exclude=True)
    nproc_per_node: int = Field(1)

    # Inputs for the log specs (never dumped).
    log_dir: Optional[str] = Field("torch_logs", exclude=True)
    redirects: Optional[int] = Field(None, exclude=True)
    tee: Optional[int] = Field(None, exclude=True)
    local_ranks_filter: Optional[set[int]] = Field(None, exclude=True)

    role: Optional[str] = Field(None)
    # There is no run_id field: the context's run_id is supplied when the
    # LaunchConfig is created.

    rdzv_backend: Optional[str] = Field("static")
    rdzv_endpoint: Optional[str] = Field("")
    rdzv_configs: dict[str, Any] = Field(default_factory=dict)
    rdzv_timeout: Optional[int] = Field(None)

    max_restarts: Optional[int] = Field(None)
    monitor_interval: Optional[float] = Field(None)
    start_method: Optional[str] = Field(StartMethod.spawn)
    log_line_prefix_template: Optional[str] = Field(None)
    local_addr: Optional[str] = None

    # torchrun options not currently exposed, see
    # https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L753 :
    # master_addr (default "localhost"), master_port (default "29500"),
    # training_script, training_script_args.
|
82
|
+
|
83
|
+
|
84
|
+
def _parse_nnodes(nnodes: str) -> tuple[int, int]:
    """Split a torchrun-style ``nnodes`` value into ``(min_nodes, max_nodes)``.

    Accepts either a single count (``"2"``, meaning exactly two nodes) or a
    ``"MIN:MAX"`` range (``"1:4"``).

    Raises:
        ValueError: if the value is in neither form, a count is not an
            integer, or the bounds do not satisfy ``0 < min_nodes <= max_nodes``.
    """
    parts = nnodes.split(":")
    if len(parts) == 1:
        min_nodes = max_nodes = int(parts[0])
    elif len(parts) == 2:
        min_nodes, max_nodes = int(parts[0]), int(parts[1])
    else:
        raise ValueError(f'nnodes={nnodes!r} is not in "N" or "MIN:MAX" format')
    if not 0 < min_nodes <= max_nodes:
        raise ValueError(f"nnodes={nnodes!r} must satisfy 0 < min_nodes <= max_nodes")
    return min_nodes, max_nodes


class EasyTorchConfig(TorchConfig):
    """TorchConfig whose ``model_dump`` is used as ``LaunchConfig`` kwargs.

    Extra keys are ignored, so it can be built from a broader model's dump.
    ``nnodes`` is exposed through the ``min_nodes``/``max_nodes`` computed fields.
    """

    model_config = ConfigDict(extra="ignore")

    @computed_field  # type: ignore
    @property
    def min_nodes(self) -> int:
        """Lower bound parsed from ``nnodes`` (the count itself for a single value)."""
        return _parse_nnodes(self.nnodes)[0]

    @computed_field  # type: ignore
    @property
    def max_nodes(self) -> int:
        """Upper bound parsed from ``nnodes`` (the count itself for a single value)."""
        return _parse_nnodes(self.nnodes)[1]
|
@@ -16,8 +16,8 @@ extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqy
|
|
16
16
|
extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
17
|
extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
|
18
18
|
extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
|
19
|
-
extensions/nodes/torch.py,sha256=
|
20
|
-
extensions/nodes/torch_config.py,sha256=
|
19
|
+
extensions/nodes/torch.py,sha256=RUelXV7Pa4U5F7Ww3cfRG0Oaz9SkYF3b_CmpFHlpbyI,6885
|
20
|
+
extensions/nodes/torch_config.py,sha256=jfUtkwCYolyKVcFxiMjjwm63yv-HjTKvSQR8JLA7sZg,3151
|
21
21
|
extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
22
|
extensions/pipeline_executor/__init__.py,sha256=wfigTL2T9OHrmE8b2Ydmb8h6hr-oF--Yc2FectC7WaY,24623
|
23
23
|
extensions/pipeline_executor/argo.py,sha256=AEGSWVZulBL6EsvbVCaeBeTl2m_t5ymc6RFpMKhivis,37946
|
@@ -58,8 +58,8 @@ runnable/sdk.py,sha256=NZVQGaL4Zm2hwloRmqEgp8UPbBg9hY1abQGYnOgniPI,35128
|
|
58
58
|
runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
|
59
59
|
runnable/tasks.py,sha256=Qb1IhVxHv68E7vf3M3YCf7MGRHyjmsEEYBpEpiZ4mRI,29062
|
60
60
|
runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
|
61
|
-
runnable-0.
|
62
|
-
runnable-0.
|
63
|
-
runnable-0.
|
64
|
-
runnable-0.
|
65
|
-
runnable-0.
|
61
|
+
runnable-0.31.0.dist-info/METADATA,sha256=9c3Ixkq-Kl0_hiQfDX-KwtSAdSWzMRLJMfEze2oVQhE,10115
|
62
|
+
runnable-0.31.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
63
|
+
runnable-0.31.0.dist-info/entry_points.txt,sha256=PrjKrlfXPZaV_7hz8orGu4FDnatLqnhPOXljyllszdw,1880
|
64
|
+
runnable-0.31.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
65
|
+
runnable-0.31.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|