runnable 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- extensions/nodes/torch.py +71 -9
- extensions/nodes/torch_config.py +15 -35
- extensions/tasks/torch.py +235 -0
- extensions/tasks/torch_config.py +76 -0
- runnable/__init__.py +2 -1
- runnable/entrypoints.py +1 -0
- runnable/sdk.py +50 -50
- runnable/tasks.py +3 -3
- {runnable-0.31.0.dist-info → runnable-0.32.0.dist-info}/METADATA +2 -1
- {runnable-0.31.0.dist-info → runnable-0.32.0.dist-info}/RECORD +13 -11
- {runnable-0.31.0.dist-info → runnable-0.32.0.dist-info}/entry_points.txt +1 -0
- {runnable-0.31.0.dist-info → runnable-0.32.0.dist-info}/WHEEL +0 -0
- {runnable-0.31.0.dist-info → runnable-0.32.0.dist-info}/licenses/LICENSE +0 -0
extensions/nodes/torch.py
CHANGED
@@ -4,11 +4,12 @@ import os
|
|
4
4
|
import random
|
5
5
|
import string
|
6
6
|
from datetime import datetime
|
7
|
-
from
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any, Callable, Optional
|
8
9
|
|
9
|
-
from pydantic import ConfigDict, Field
|
10
|
+
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
10
11
|
|
11
|
-
from extensions.nodes.torch_config import EasyTorchConfig,
|
12
|
+
from extensions.nodes.torch_config import EasyTorchConfig, TorchConfig
|
12
13
|
from runnable import PythonJob, datastore, defaults
|
13
14
|
from runnable.datastore import StepLog
|
14
15
|
from runnable.nodes import DistributedNode
|
@@ -18,7 +19,7 @@ from runnable.utils import TypeMapVariable
|
|
18
19
|
logger = logging.getLogger(defaults.LOGGER_NAME)
|
19
20
|
|
20
21
|
try:
|
21
|
-
from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs
|
22
|
+
from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
|
22
23
|
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
23
24
|
|
24
25
|
except ImportError:
|
@@ -28,9 +29,30 @@ print("torch is installed")
|
|
28
29
|
|
29
30
|
|
30
31
|
def training_subprocess():
|
32
|
+
"""
|
33
|
+
This function is called by the torch.distributed.launcher.api.elastic_launch
|
34
|
+
It happens in a subprocess and is responsible for executing the user's function
|
35
|
+
|
36
|
+
It is unrelated to the actual node execution, so any cataloging, run_log_store should be
|
37
|
+
handled to match the main process.
|
38
|
+
|
39
|
+
We have these variables to use:
|
40
|
+
|
41
|
+
os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
|
42
|
+
os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
|
43
|
+
self._context.parameters_file or ""
|
44
|
+
)
|
45
|
+
os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
|
46
|
+
os.environ["RUNNABLE_TORCH_COPY_CONTENTS_TO"] = (
|
47
|
+
self._context.catalog_handler.compute_data_folder
|
48
|
+
)
|
49
|
+
os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
|
50
|
+
|
51
|
+
"""
|
31
52
|
command = os.environ.get("RUNNABLE_TORCH_COMMAND")
|
32
53
|
run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
|
33
54
|
parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
|
55
|
+
|
34
56
|
process_run_id = (
|
35
57
|
run_id
|
36
58
|
+ "-"
|
@@ -38,10 +60,14 @@ def training_subprocess():
|
|
38
60
|
+ "-"
|
39
61
|
+ "".join(random.choices(string.ascii_lowercase, k=3))
|
40
62
|
)
|
63
|
+
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
|
41
64
|
|
42
65
|
delete_env_vars_with_prefix("RUNNABLE_")
|
43
66
|
|
44
67
|
func = get_callable_from_dotted_path(command)
|
68
|
+
|
69
|
+
# The job runs with the default configuration
|
70
|
+
# All the execution logs are stored in .catalog
|
45
71
|
job = PythonJob(function=func)
|
46
72
|
|
47
73
|
job.execute(
|
@@ -57,6 +83,7 @@ def training_subprocess():
|
|
57
83
|
raise Exception(f"Job {process_run_id} failed")
|
58
84
|
|
59
85
|
|
86
|
+
# TODO: Can this be utils.get_module_and_attr_names
|
60
87
|
def get_callable_from_dotted_path(dotted_path) -> Callable:
|
61
88
|
try:
|
62
89
|
# Split the path into module path and callable object
|
@@ -91,6 +118,7 @@ def delete_env_vars_with_prefix(prefix):
|
|
91
118
|
del os.environ[var]
|
92
119
|
|
93
120
|
|
121
|
+
# TODO: The design of this class is not final
|
94
122
|
class TorchNode(DistributedNode, TorchConfig):
|
95
123
|
node_type: str = Field(default="torch", serialization_alias="type")
|
96
124
|
executable: PythonTaskType = Field(exclude=True)
|
@@ -131,15 +159,15 @@ class TorchNode(DistributedNode, TorchConfig):
|
|
131
159
|
)
|
132
160
|
)
|
133
161
|
|
134
|
-
|
162
|
+
launch_config = LaunchConfig(
|
135
163
|
**easy_torch_config.model_dump(
|
136
164
|
exclude_none=True,
|
137
165
|
),
|
138
166
|
logs_specs=log_spec,
|
139
167
|
run_id=self._context.run_id,
|
140
168
|
)
|
141
|
-
|
142
|
-
return
|
169
|
+
logger.info(f"launch_config: {launch_config}")
|
170
|
+
return launch_config
|
143
171
|
|
144
172
|
def execute(
|
145
173
|
self,
|
@@ -159,13 +187,13 @@ class TorchNode(DistributedNode, TorchConfig):
|
|
159
187
|
launch_config = self.get_launch_config()
|
160
188
|
logger.info(f"launch_config: {launch_config}")
|
161
189
|
|
190
|
+
# ENV variables are shared with the subprocess, use that as communication
|
162
191
|
os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
|
163
192
|
os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
|
164
193
|
self._context.parameters_file or ""
|
165
194
|
)
|
166
195
|
os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
|
167
|
-
|
168
|
-
# default to localhost and 29500
|
196
|
+
|
169
197
|
launcher = elastic_launch(
|
170
198
|
launch_config,
|
171
199
|
training_subprocess,
|
@@ -186,6 +214,20 @@ class TorchNode(DistributedNode, TorchConfig):
|
|
186
214
|
attempt_number=attempt_number,
|
187
215
|
)
|
188
216
|
logger.error(f"Error executing TorchNode: {e}")
|
217
|
+
finally:
|
218
|
+
# This can only come from the subprocess
|
219
|
+
if Path(".catalog").exists():
|
220
|
+
os.rename(".catalog", "proc_logs")
|
221
|
+
# Move .catalog and torch_logs to the parent node's catalog location
|
222
|
+
self._context.catalog_handler.put(
|
223
|
+
"proc_logs/**/*", allow_file_not_found_exc=True
|
224
|
+
)
|
225
|
+
|
226
|
+
# TODO: This is not working!!
|
227
|
+
if self.log_dir:
|
228
|
+
self._context.catalog_handler.put(
|
229
|
+
self.log_dir + "/**/*", allow_file_not_found_exc=True
|
230
|
+
)
|
189
231
|
|
190
232
|
delete_env_vars_with_prefix("RUNNABLE_TORCH")
|
191
233
|
|
@@ -211,3 +253,23 @@ class TorchNode(DistributedNode, TorchConfig):
|
|
211
253
|
assert (
|
212
254
|
map_variable is None or not map_variable
|
213
255
|
), "TorchNode does not support map_variable"
|
256
|
+
|
257
|
+
|
258
|
+
# This internal model makes it easier to extract the required fields
|
259
|
+
# of log specs from user specification.
|
260
|
+
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
|
261
|
+
class InternalLogSpecs(BaseModel):
|
262
|
+
log_dir: Optional[str] = Field(default="torch_logs")
|
263
|
+
redirects: str = Field(default="0") # Std.NONE
|
264
|
+
tee: str = Field(default="0") # Std.NONE
|
265
|
+
local_ranks_filter: Optional[set[int]] = Field(default=None)
|
266
|
+
|
267
|
+
model_config = ConfigDict(extra="ignore")
|
268
|
+
|
269
|
+
@field_serializer("redirects")
|
270
|
+
def convert_redirects(self, redirects: str) -> Std | dict[int, Std]:
|
271
|
+
return Std.from_str(redirects)
|
272
|
+
|
273
|
+
@field_serializer("tee")
|
274
|
+
def convert_tee(self, tee: str) -> Std | dict[int, Std]:
|
275
|
+
return Std.from_str(tee)
|
extensions/nodes/torch_config.py
CHANGED
@@ -10,59 +10,39 @@ class StartMethod(str, Enum):
|
|
10
10
|
forkserver = "forkserver"
|
11
11
|
|
12
12
|
|
13
|
-
# min_nodes: int
|
14
|
-
# max_nodes: int
|
15
|
-
# nproc_per_node: int
|
16
|
-
|
17
|
-
# logs_specs: Optional[LogsSpecs] = None
|
18
|
-
# run_id: str = ""
|
19
|
-
# role: str = "default_role"
|
20
|
-
|
21
|
-
# rdzv_endpoint: str = ""
|
22
|
-
# rdzv_backend: str = "etcd"
|
23
|
-
# rdzv_configs: dict[str, Any] = field(default_factory=dict)
|
24
|
-
# rdzv_timeout: int = -1
|
25
|
-
|
26
|
-
# max_restarts: int = 3
|
27
|
-
# monitor_interval: float = 0.1
|
28
|
-
# start_method: str = "spawn"
|
29
|
-
# log_line_prefix_template: Optional[str] = None
|
30
|
-
# metrics_cfg: dict[str, str] = field(default_factory=dict)
|
31
|
-
# local_addr: Optional[str] = None
|
32
|
-
|
33
13
|
## The idea is the following:
|
34
14
|
# Users can configure any of the options present in TorchConfig class.
|
35
|
-
# The LaunchConfig class will be created from
|
15
|
+
# The LaunchConfig class will be created from TorchConfig.
|
36
16
|
# The LogSpecs is sent as a parameter to the launch config.
|
37
|
-
# None as much as possible to get
|
38
17
|
|
39
18
|
## NO idea of standalone and how to send it
|
40
19
|
|
41
20
|
|
42
|
-
|
43
|
-
|
44
|
-
redirects: int | None = Field(default=None)
|
45
|
-
tee: int | None = Field(default=None)
|
46
|
-
local_ranks_filter: Optional[set[int]] = Field(default=None)
|
47
|
-
|
48
|
-
model_config = ConfigDict(extra="ignore")
|
49
|
-
|
50
|
-
|
21
|
+
# The user sees this as part of the config of the node.
|
22
|
+
# It is kept as similar as possible to torchrun
|
51
23
|
class TorchConfig(BaseModel):
|
52
24
|
model_config = ConfigDict(extra="forbid")
|
53
|
-
|
54
|
-
|
25
|
+
|
26
|
+
# excluded as LaunchConfig requires min and max nodes
|
27
|
+
nnodes: str = Field(default="1:1", exclude=True, description="min:max")
|
28
|
+
nproc_per_node: int = Field(default=1, description="Number of processes per node")
|
55
29
|
|
56
30
|
# will be used to create the log specs
|
31
|
+
# But they are excluded from dump as logs specs is a class for LaunchConfig
|
32
|
+
# from_str("0") -> Std.NONE
|
33
|
+
# from_str("1") -> Std.OUT
|
34
|
+
# from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
|
57
35
|
log_dir: Optional[str] = Field(default="torch_logs", exclude=True)
|
58
|
-
redirects:
|
59
|
-
tee:
|
36
|
+
redirects: str = Field(default="0", exclude=True) # Std.NONE
|
37
|
+
tee: str = Field(default="0", exclude=True) # Std.NONE
|
60
38
|
local_ranks_filter: Optional[set[int]] = Field(default=None, exclude=True)
|
61
39
|
|
62
40
|
role: str | None = Field(default=None)
|
41
|
+
|
63
42
|
# run_id would be the run_id of the context
|
64
43
|
# and sent at the creation of the LaunchConfig
|
65
44
|
|
45
|
+
# This section is about the communication between nodes/processes
|
66
46
|
rdzv_backend: str | None = Field(default="static")
|
67
47
|
rdzv_endpoint: str | None = Field(default="")
|
68
48
|
rdzv_configs: dict[str, Any] = Field(default_factory=dict)
|
@@ -0,0 +1,235 @@
|
|
1
|
+
import importlib
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
import random
|
5
|
+
import string
|
6
|
+
from datetime import datetime
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any, Optional
|
9
|
+
|
10
|
+
from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
|
11
|
+
from ruamel.yaml import YAML
|
12
|
+
|
13
|
+
import runnable.context as context
|
14
|
+
from extensions.tasks.torch_config import EasyTorchConfig, TorchConfig
|
15
|
+
from runnable import Catalog, defaults
|
16
|
+
from runnable.datastore import StepAttempt
|
17
|
+
from runnable.tasks import BaseTaskType
|
18
|
+
from runnable.utils import get_module_and_attr_names
|
19
|
+
|
20
|
+
try:
|
21
|
+
from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
|
22
|
+
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
23
|
+
|
24
|
+
except ImportError:
|
25
|
+
raise ImportError("torch is not installed. Please install torch first.")
|
26
|
+
|
27
|
+
|
28
|
+
logger = logging.getLogger(defaults.LOGGER_NAME)
|
29
|
+
|
30
|
+
|
31
|
+
class TorchTaskType(BaseTaskType, TorchConfig):
|
32
|
+
task_type: str = Field(default="torch", serialization_alias="command_type")
|
33
|
+
catalog: Optional[Catalog] = Field(default=None, alias="catalog")
|
34
|
+
command: str
|
35
|
+
|
36
|
+
@model_validator(mode="before")
|
37
|
+
@classmethod
|
38
|
+
def check_secrets_and_returns(cls, data: Any) -> Any:
|
39
|
+
if isinstance(data, dict):
|
40
|
+
if "secrets" in data and data["secrets"]:
|
41
|
+
raise ValueError("'secrets' is not supported for torch")
|
42
|
+
if "returns" in data and data["returns"]:
|
43
|
+
raise ValueError("'secrets' is not supported for torch")
|
44
|
+
return data
|
45
|
+
|
46
|
+
def get_summary(self) -> dict[str, Any]:
|
47
|
+
return self.model_dump(by_alias=True, exclude_none=True)
|
48
|
+
|
49
|
+
@property
|
50
|
+
def _context(self):
|
51
|
+
return context.run_context
|
52
|
+
|
53
|
+
def _get_launch_config(self) -> LaunchConfig:
|
54
|
+
internal_log_spec = InternalLogSpecs(**self.model_dump(exclude_none=True))
|
55
|
+
log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
|
56
|
+
**internal_log_spec.model_dump(exclude_none=True)
|
57
|
+
)
|
58
|
+
easy_torch_config = EasyTorchConfig(
|
59
|
+
**self.model_dump(
|
60
|
+
exclude_none=True,
|
61
|
+
)
|
62
|
+
)
|
63
|
+
|
64
|
+
launch_config = LaunchConfig(
|
65
|
+
**easy_torch_config.model_dump(
|
66
|
+
exclude_none=True,
|
67
|
+
),
|
68
|
+
logs_specs=log_spec,
|
69
|
+
run_id=self._context.run_id,
|
70
|
+
)
|
71
|
+
logger.info(f"launch_config: {launch_config}")
|
72
|
+
return launch_config
|
73
|
+
|
74
|
+
def execute_command(
|
75
|
+
self,
|
76
|
+
map_variable: defaults.TypeMapVariable = None,
|
77
|
+
):
|
78
|
+
assert map_variable is None, "map_variable is not supported for torch"
|
79
|
+
|
80
|
+
launch_config = self._get_launch_config()
|
81
|
+
logger.info(f"launch_config: {launch_config}")
|
82
|
+
|
83
|
+
# ENV variables are shared with the subprocess, use that as communication
|
84
|
+
os.environ["RUNNABLE_TORCH_COMMAND"] = self.command
|
85
|
+
os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
|
86
|
+
self._context.parameters_file or ""
|
87
|
+
)
|
88
|
+
os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
|
89
|
+
|
90
|
+
launcher = elastic_launch(
|
91
|
+
launch_config,
|
92
|
+
training_subprocess,
|
93
|
+
)
|
94
|
+
try:
|
95
|
+
launcher()
|
96
|
+
attempt_log = StepAttempt(
|
97
|
+
status=defaults.SUCCESS,
|
98
|
+
start_time=str(datetime.now()),
|
99
|
+
end_time=str(datetime.now()),
|
100
|
+
attempt_number=1,
|
101
|
+
)
|
102
|
+
except Exception as e:
|
103
|
+
attempt_log = StepAttempt(
|
104
|
+
status=defaults.FAIL,
|
105
|
+
start_time=str(datetime.now()),
|
106
|
+
end_time=str(datetime.now()),
|
107
|
+
attempt_number=1,
|
108
|
+
)
|
109
|
+
logger.error(f"Error executing TorchNode: {e}")
|
110
|
+
finally:
|
111
|
+
# This can only come from the subprocess
|
112
|
+
if Path("proc_logs").exists():
|
113
|
+
# Move .catalog and torch_logs to the parent node's catalog location
|
114
|
+
self._context.catalog_handler.put(
|
115
|
+
"proc_logs/**/*", allow_file_not_found_exc=True
|
116
|
+
)
|
117
|
+
|
118
|
+
# TODO: This is not working!!
|
119
|
+
if self.log_dir:
|
120
|
+
self._context.catalog_handler.put(
|
121
|
+
self.log_dir + "/**/*", allow_file_not_found_exc=True
|
122
|
+
)
|
123
|
+
|
124
|
+
delete_env_vars_with_prefix("RUNNABLE_TORCH")
|
125
|
+
logger.info(f"attempt_log: {attempt_log}")
|
126
|
+
|
127
|
+
return attempt_log
|
128
|
+
|
129
|
+
|
130
|
+
# This internal model makes it easier to extract the required fields
|
131
|
+
# of log specs from user specification.
|
132
|
+
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
|
133
|
+
class InternalLogSpecs(BaseModel):
|
134
|
+
log_dir: Optional[str] = Field(default="torch_logs")
|
135
|
+
redirects: str = Field(default="0") # Std.NONE
|
136
|
+
tee: str = Field(default="0") # Std.NONE
|
137
|
+
local_ranks_filter: Optional[set[int]] = Field(default=None)
|
138
|
+
|
139
|
+
model_config = ConfigDict(extra="ignore")
|
140
|
+
|
141
|
+
@field_serializer("redirects")
|
142
|
+
def convert_redirects(self, redirects: str) -> Std | dict[int, Std]:
|
143
|
+
return Std.from_str(redirects)
|
144
|
+
|
145
|
+
@field_serializer("tee")
|
146
|
+
def convert_tee(self, tee: str) -> Std | dict[int, Std]:
|
147
|
+
return Std.from_str(tee)
|
148
|
+
|
149
|
+
|
150
|
+
def delete_env_vars_with_prefix(prefix):
|
151
|
+
to_delete = [] # List to keep track of variables to delete
|
152
|
+
|
153
|
+
# Iterate over a list of all environment variable keys
|
154
|
+
for var in os.environ:
|
155
|
+
if var.startswith(prefix):
|
156
|
+
to_delete.append(var)
|
157
|
+
|
158
|
+
# Delete each of the variables collected
|
159
|
+
for var in to_delete:
|
160
|
+
del os.environ[var]
|
161
|
+
|
162
|
+
|
163
|
+
def training_subprocess():
|
164
|
+
"""
|
165
|
+
This function is called by the torch.distributed.launcher.api.elastic_launch
|
166
|
+
It happens in a subprocess and is responsible for executing the user's function
|
167
|
+
|
168
|
+
It is unrelated to the actual node execution, so any cataloging, run_log_store should be
|
169
|
+
handled to match the main process.
|
170
|
+
|
171
|
+
We have these variables to use:
|
172
|
+
|
173
|
+
os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
|
174
|
+
os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
|
175
|
+
self._context.parameters_file or ""
|
176
|
+
)
|
177
|
+
os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
|
178
|
+
os.environ["RUNNABLE_TORCH_COPY_CONTENTS_TO"] = (
|
179
|
+
self._context.catalog_handler.compute_data_folder
|
180
|
+
)
|
181
|
+
os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
|
182
|
+
|
183
|
+
"""
|
184
|
+
from runnable import PythonJob # noqa: F401
|
185
|
+
|
186
|
+
command = os.environ.get("RUNNABLE_TORCH_COMMAND")
|
187
|
+
assert command, "Command is not provided"
|
188
|
+
|
189
|
+
run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
|
190
|
+
parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
|
191
|
+
|
192
|
+
process_run_id = (
|
193
|
+
run_id
|
194
|
+
+ "-"
|
195
|
+
+ os.environ.get("RANK", "")
|
196
|
+
+ "-"
|
197
|
+
+ "".join(random.choices(string.ascii_lowercase, k=3))
|
198
|
+
)
|
199
|
+
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
|
200
|
+
|
201
|
+
# In this subprocess there should not be any RUNNABLE environment variables
|
202
|
+
delete_env_vars_with_prefix("RUNNABLE_")
|
203
|
+
|
204
|
+
module_name, func_name = get_module_and_attr_names(command)
|
205
|
+
module = importlib.import_module(module_name)
|
206
|
+
|
207
|
+
callable_obj = getattr(module, func_name)
|
208
|
+
|
209
|
+
# The job runs with the default configuration
|
210
|
+
# All the execution logs are stored in .catalog
|
211
|
+
job = PythonJob(function=callable_obj)
|
212
|
+
|
213
|
+
config_content = {
|
214
|
+
"catalog": {"type": "file-system", "config": {"catalog_location": "proc_logs"}}
|
215
|
+
}
|
216
|
+
|
217
|
+
temp_config_file = Path("runnable-config.yaml")
|
218
|
+
with open(str(temp_config_file), "w", encoding="utf-8") as config_file:
|
219
|
+
yaml = YAML(typ="safe", pure=True)
|
220
|
+
yaml.dump(config_content, config_file)
|
221
|
+
|
222
|
+
job.execute(
|
223
|
+
parameters_file=parameters_files,
|
224
|
+
job_id=process_run_id,
|
225
|
+
)
|
226
|
+
|
227
|
+
# delete the temp config file
|
228
|
+
temp_config_file.unlink()
|
229
|
+
|
230
|
+
from runnable.context import run_context
|
231
|
+
|
232
|
+
job_log = run_context.run_log_store.get_run_log_by_id(run_id=run_context.run_id)
|
233
|
+
|
234
|
+
if job_log.status == defaults.FAIL:
|
235
|
+
raise Exception(f"Job {process_run_id} failed")
|
@@ -0,0 +1,76 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import Any, Optional
|
3
|
+
|
4
|
+
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
5
|
+
|
6
|
+
|
7
|
+
class StartMethod(str, Enum):
|
8
|
+
spawn = "spawn"
|
9
|
+
fork = "fork"
|
10
|
+
forkserver = "forkserver"
|
11
|
+
|
12
|
+
|
13
|
+
## The idea is the following:
|
14
|
+
# Users can configure any of the options present in TorchConfig class.
|
15
|
+
# The LaunchConfig class will be created from TorchConfig.
|
16
|
+
# The LogSpecs is sent as a parameter to the launch config.
|
17
|
+
|
18
|
+
## NO idea of standalone and how to send it
|
19
|
+
|
20
|
+
|
21
|
+
# The user sees this as part of the config of the node.
|
22
|
+
# It is kept as similar as possible to torchrun
|
23
|
+
class TorchConfig(BaseModel):
|
24
|
+
model_config = ConfigDict(extra="forbid")
|
25
|
+
|
26
|
+
# excluded as LaunchConfig requires min and max nodes
|
27
|
+
nnodes: str = Field(default="1:1", exclude=True, description="min:max")
|
28
|
+
nproc_per_node: int = Field(default=1, description="Number of processes per node")
|
29
|
+
|
30
|
+
# will be used to create the log specs
|
31
|
+
# But they are excluded from dump as logs specs is a class for LaunchConfig
|
32
|
+
# from_str("0") -> Std.NONE
|
33
|
+
# from_str("1") -> Std.OUT
|
34
|
+
# from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
|
35
|
+
log_dir: Optional[str] = Field(default="torch_logs", exclude=True)
|
36
|
+
redirects: str = Field(default="0", exclude=True) # Std.NONE
|
37
|
+
tee: str = Field(default="0", exclude=True) # Std.NONE
|
38
|
+
local_ranks_filter: Optional[set[int]] = Field(default=None, exclude=True)
|
39
|
+
|
40
|
+
role: str | None = Field(default=None)
|
41
|
+
|
42
|
+
# run_id would be the run_id of the context
|
43
|
+
# and sent at the creation of the LaunchConfig
|
44
|
+
|
45
|
+
# This section is about the communication between nodes/processes
|
46
|
+
rdzv_backend: str | None = Field(default="static")
|
47
|
+
rdzv_endpoint: str | None = Field(default="")
|
48
|
+
rdzv_configs: dict[str, Any] = Field(default_factory=dict)
|
49
|
+
rdzv_timeout: int | None = Field(default=None)
|
50
|
+
|
51
|
+
max_restarts: int | None = Field(default=None)
|
52
|
+
monitor_interval: float | None = Field(default=None)
|
53
|
+
start_method: str | None = Field(default=StartMethod.spawn)
|
54
|
+
log_line_prefix_template: str | None = Field(default=None)
|
55
|
+
local_addr: Optional[str] = None
|
56
|
+
|
57
|
+
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L753
|
58
|
+
# master_addr: str | None = Field(default="localhost")
|
59
|
+
# master_port: str | None = Field(default="29500")
|
60
|
+
# training_script: str = Field(default="dummy_training_script")
|
61
|
+
# training_script_args: str = Field(default="")
|
62
|
+
|
63
|
+
|
64
|
+
class EasyTorchConfig(TorchConfig):
|
65
|
+
model_config = ConfigDict(extra="ignore")
|
66
|
+
|
67
|
+
# TODO: Validate min < max
|
68
|
+
@computed_field # type: ignore
|
69
|
+
@property
|
70
|
+
def min_nodes(self) -> int:
|
71
|
+
return int(self.nnodes.split(":")[0])
|
72
|
+
|
73
|
+
@computed_field # type: ignore
|
74
|
+
@property
|
75
|
+
def max_nodes(self) -> int:
|
76
|
+
return int(self.nnodes.split(":")[1])
|
runnable/__init__.py
CHANGED
runnable/entrypoints.py
CHANGED
@@ -129,6 +129,7 @@ def prepare_configurations(
|
|
129
129
|
ServiceConfig,
|
130
130
|
runnable_defaults.get("job-executor", defaults.DEFAULT_JOB_EXECUTOR),
|
131
131
|
)
|
132
|
+
|
132
133
|
assert job_executor_config, "Job executor is not provided"
|
133
134
|
configured_executor = utils.get_provider_by_name_and_type(
|
134
135
|
"job_executor", job_executor_config
|
runnable/sdk.py
CHANGED
@@ -44,10 +44,10 @@ from runnable.tasks import TaskReturns
|
|
44
44
|
logger = logging.getLogger(defaults.LOGGER_NAME)
|
45
45
|
|
46
46
|
StepType = Union[
|
47
|
-
"Stub", "PythonTask", "NotebookTask", "ShellTask", "Parallel", "Map", "
|
47
|
+
"Stub", "PythonTask", "NotebookTask", "ShellTask", "Parallel", "Map", "TorchTask"
|
48
48
|
]
|
49
49
|
if TYPE_CHECKING:
|
50
|
-
|
50
|
+
pass
|
51
51
|
|
52
52
|
|
53
53
|
def pickled(name: str) -> TaskReturns:
|
@@ -191,6 +191,34 @@ class BaseTask(BaseTraversal):
|
|
191
191
|
)
|
192
192
|
|
193
193
|
|
194
|
+
class TorchTask(BaseTask, TorchConfig):
|
195
|
+
function: Callable = Field(exclude=True)
|
196
|
+
|
197
|
+
@field_validator("returns", mode="before")
|
198
|
+
@classmethod
|
199
|
+
def serialize_returns(
|
200
|
+
cls, returns: List[Union[str, TaskReturns]]
|
201
|
+
) -> List[TaskReturns]:
|
202
|
+
assert len(returns) == 0, "Torch tasks cannot return any variables"
|
203
|
+
return []
|
204
|
+
|
205
|
+
@computed_field
|
206
|
+
def command_type(self) -> str:
|
207
|
+
return "torch"
|
208
|
+
|
209
|
+
@computed_field
|
210
|
+
def command(self) -> str:
|
211
|
+
module = self.function.__module__
|
212
|
+
name = self.function.__name__
|
213
|
+
|
214
|
+
return f"{module}.{name}"
|
215
|
+
|
216
|
+
def create_job(self) -> RunnableTask:
|
217
|
+
self.terminate_with_success = True
|
218
|
+
node = self.create_node()
|
219
|
+
return node.executable
|
220
|
+
|
221
|
+
|
194
222
|
class PythonTask(BaseTask):
|
195
223
|
"""
|
196
224
|
An execution node of the pipeline of python functions.
|
@@ -459,43 +487,6 @@ class Stub(BaseTraversal):
|
|
459
487
|
return StubNode.parse_from_config(self.model_dump(exclude_none=True))
|
460
488
|
|
461
489
|
|
462
|
-
class Torch(BaseTraversal, TorchConfig):
|
463
|
-
function: Callable = Field(exclude=True)
|
464
|
-
catalog: Optional[Catalog] = Field(default=None, alias="catalog")
|
465
|
-
overrides: Dict[str, Any] = Field(default_factory=dict, alias="overrides")
|
466
|
-
returns: List[Union[str, TaskReturns]] = Field(
|
467
|
-
default_factory=list, alias="returns"
|
468
|
-
)
|
469
|
-
secrets: List[str] = Field(default_factory=list)
|
470
|
-
|
471
|
-
@computed_field
|
472
|
-
def command_type(self) -> str:
|
473
|
-
return "python"
|
474
|
-
|
475
|
-
@computed_field
|
476
|
-
def command(self) -> str:
|
477
|
-
module = self.function.__module__
|
478
|
-
name = self.function.__name__
|
479
|
-
|
480
|
-
return f"{module}.{name}"
|
481
|
-
|
482
|
-
def create_node(self) -> TorchNode:
|
483
|
-
if not self.next_node:
|
484
|
-
if not (self.terminate_with_failure or self.terminate_with_success):
|
485
|
-
raise AssertionError(
|
486
|
-
"A node not being terminated must have a user defined next node"
|
487
|
-
)
|
488
|
-
|
489
|
-
if self.on_failure:
|
490
|
-
self.on_failure = self.on_failure.steps[0].name # type: ignore
|
491
|
-
|
492
|
-
from extensions.nodes.torch import TorchNode
|
493
|
-
|
494
|
-
return TorchNode.parse_from_config(
|
495
|
-
self.model_dump(exclude_none=True, by_alias=True)
|
496
|
-
)
|
497
|
-
|
498
|
-
|
499
490
|
class Parallel(BaseTraversal):
|
500
491
|
"""
|
501
492
|
A node that executes multiple branches in parallel.
|
@@ -685,6 +676,7 @@ class Pipeline(BaseModel):
|
|
685
676
|
terminal_step: StepType = self.steps[-1]
|
686
677
|
if not terminal_step.terminate_with_failure:
|
687
678
|
terminal_step.terminate_with_success = True
|
679
|
+
terminal_step.next_node = "success"
|
688
680
|
|
689
681
|
# assert that there is only one termination node with success or failure
|
690
682
|
# Assert that there are no duplicate step names
|
@@ -965,7 +957,7 @@ class BaseJob(BaseModel):
|
|
965
957
|
|
966
958
|
|
967
959
|
class PythonJob(BaseJob):
|
968
|
-
function: Callable = Field(
|
960
|
+
function: Callable = Field()
|
969
961
|
|
970
962
|
@property
|
971
963
|
@computed_field
|
@@ -975,14 +967,27 @@ class PythonJob(BaseJob):
|
|
975
967
|
|
976
968
|
return f"{module}.{name}"
|
977
969
|
|
970
|
+
# TODO: can this be simplified to just self.model_dump(exclude_none=True)?
|
978
971
|
def get_task(self) -> RunnableTask:
|
979
972
|
# Piggyback on existing tasks as a hack
|
980
973
|
task = PythonTask(
|
981
974
|
name="dummy",
|
982
975
|
terminate_with_success=True,
|
983
|
-
|
984
|
-
|
985
|
-
|
976
|
+
**self.model_dump(exclude_defaults=True, exclude_none=True),
|
977
|
+
)
|
978
|
+
return task.create_node().executable
|
979
|
+
|
980
|
+
|
981
|
+
class TorchJob(BaseJob, TorchConfig):
|
982
|
+
function: Callable = Field()
|
983
|
+
# min and max should always be 1
|
984
|
+
|
985
|
+
def get_task(self) -> RunnableTask:
|
986
|
+
# Piggyback on existing tasks as a hack
|
987
|
+
task = TorchTask(
|
988
|
+
name="dummy",
|
989
|
+
terminate_with_success=True,
|
990
|
+
**self.model_dump(exclude_defaults=True, exclude_none=True),
|
986
991
|
)
|
987
992
|
return task.create_node().executable
|
988
993
|
|
@@ -998,10 +1003,7 @@ class NotebookJob(BaseJob):
|
|
998
1003
|
task = NotebookTask(
|
999
1004
|
name="dummy",
|
1000
1005
|
terminate_with_success=True,
|
1001
|
-
|
1002
|
-
secrets=self.secrets,
|
1003
|
-
notebook=self.notebook,
|
1004
|
-
optional_ploomber_args=self.optional_ploomber_args,
|
1006
|
+
**self.model_dump(exclude_defaults=True, exclude_none=True),
|
1005
1007
|
)
|
1006
1008
|
return task.create_node().executable
|
1007
1009
|
|
@@ -1014,8 +1016,6 @@ class ShellJob(BaseJob):
|
|
1014
1016
|
task = ShellTask(
|
1015
1017
|
name="dummy",
|
1016
1018
|
terminate_with_success=True,
|
1017
|
-
|
1018
|
-
secrets=self.secrets,
|
1019
|
-
command=self.command,
|
1019
|
+
**self.model_dump(exclude_defaults=True, exclude_none=True),
|
1020
1020
|
)
|
1021
1021
|
return task.create_node().executable
|
runnable/tasks.py
CHANGED
@@ -28,7 +28,6 @@ from runnable.datastore import (
|
|
28
28
|
from runnable.defaults import TypeMapVariable
|
29
29
|
|
30
30
|
logger = logging.getLogger(defaults.LOGGER_NAME)
|
31
|
-
logging.getLogger("stevedore").setLevel(logging.CRITICAL)
|
32
31
|
|
33
32
|
|
34
33
|
class TeeIO(io.StringIO):
|
@@ -49,8 +48,7 @@ class TeeIO(io.StringIO):
|
|
49
48
|
self.output_stream.flush()
|
50
49
|
|
51
50
|
|
52
|
-
|
53
|
-
sys.stdout = buffer
|
51
|
+
sys.stdout = TeeIO()
|
54
52
|
|
55
53
|
|
56
54
|
class TaskReturns(BaseModel):
|
@@ -761,6 +759,8 @@ def create_task(kwargs_for_init) -> BaseTaskType:
|
|
761
759
|
tasks.BaseTaskType: The command object
|
762
760
|
"""
|
763
761
|
# The dictionary cannot be modified
|
762
|
+
|
763
|
+
print(kwargs_for_init)
|
764
764
|
kwargs = kwargs_for_init.copy()
|
765
765
|
command_type = kwargs.pop("command_type", defaults.COMMAND_TYPE)
|
766
766
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: runnable
|
3
|
-
Version: 0.31.0
|
3
|
+
Version: 0.32.0
|
4
4
|
Summary: Add your description here
|
5
5
|
Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
|
6
6
|
License-File: LICENSE
|
@@ -28,6 +28,7 @@ Provides-Extra: s3
|
|
28
28
|
Requires-Dist: cloudpathlib[s3]; extra == 's3'
|
29
29
|
Provides-Extra: torch
|
30
30
|
Requires-Dist: torch>=2.6.0; extra == 'torch'
|
31
|
+
Requires-Dist: torchvision>=0.21.0; extra == 'torch'
|
31
32
|
Description-Content-Type: text/markdown
|
32
33
|
|
33
34
|
|
@@ -16,8 +16,8 @@ extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqy
|
|
16
16
|
extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
17
|
extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
|
18
18
|
extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
|
19
|
-
extensions/nodes/torch.py,sha256=
|
20
|
-
extensions/nodes/torch_config.py,sha256=
|
19
|
+
extensions/nodes/torch.py,sha256=h3x5931ePBNckeSXM3JFjSoUnxmIWvDyEpn1AI9TKaU,9347
|
20
|
+
extensions/nodes/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
|
21
21
|
extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
22
|
extensions/pipeline_executor/__init__.py,sha256=wfigTL2T9OHrmE8b2Ydmb8h6hr-oF--Yc2FectC7WaY,24623
|
23
23
|
extensions/pipeline_executor/argo.py,sha256=AEGSWVZulBL6EsvbVCaeBeTl2m_t5ymc6RFpMKhivis,37946
|
@@ -40,13 +40,15 @@ extensions/run_log_store/db/integration_FF.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeR
|
|
40
40
|
extensions/secrets/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
41
41
|
extensions/secrets/dotenv.py,sha256=nADHXI6KJ_LUYOIe5EbtYH-21OBebSNVr0Pjb1GlZ7w,1573
|
42
42
|
extensions/secrets/pyproject.toml,sha256=mLJNImNcBlbLKHh-0ugVWT9V83R4RibyyYDtBCSqVF4,282
|
43
|
-
|
43
|
+
extensions/tasks/torch.py,sha256=R0J_Q6SRAW2Ii0XQbXaaBWTah8TYs4P_48j2M1bIXeA,7983
|
44
|
+
extensions/tasks/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
|
45
|
+
runnable/__init__.py,sha256=3ZKuvGEkY_zHVQlJtarXd4jkjICxjgnw-bbKN_5SiJI,691
|
44
46
|
runnable/catalog.py,sha256=4msQxLhLKlsDDrHFnGauPYe-Or-q9g8_RYCn_4dpxaU,4466
|
45
47
|
runnable/cli.py,sha256=3BiKSj95h2Drn__YlchMPZ5rBMafuRb2OGIsVpbsO5Y,8788
|
46
48
|
runnable/context.py,sha256=by5uepmuCP0dmM9BmsliXihSes5QEFejwAsmekcqylE,1388
|
47
49
|
runnable/datastore.py,sha256=ZobM1aVkgeUJ2fZYt63IFDsoNzObwc93hdByegS5YKQ,32396
|
48
50
|
runnable/defaults.py,sha256=3o9IVGryyCE6PoQTOoaIaHHTbJGEzmdXMcwzOhwAYoI,3518
|
49
|
-
runnable/entrypoints.py,sha256=
|
51
|
+
runnable/entrypoints.py,sha256=1xCbWVUQLGmg5gkWnAVWFLAUf6j4avP9azX_vuGQUMY,18985
|
50
52
|
runnable/exceptions.py,sha256=LFbp0-Qxg2PAMLEVt7w2whhBxSG-5pzUEv5qN-Rc4_c,3003
|
51
53
|
runnable/executor.py,sha256=UOsYJ3NkTGw4FTR0iePX7AOJzY7vODhZ62aqrwVMO1c,15143
|
52
54
|
runnable/graph.py,sha256=poQz5zcvq89ju_u5sYlunQLPbHnXTaUmjcvstPwvT4U,16536
|
@@ -54,12 +56,12 @@ runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
|
|
54
56
|
runnable/nodes.py,sha256=d1eLttMAcV7CTwTEqOuNwZqItANoLUkXJ73Xp-srlyI,17811
|
55
57
|
runnable/parameters.py,sha256=sT3DNGczivP9z7r4Cp_brbudg1z4J-zjmvrq3ppIrVs,5089
|
56
58
|
runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
|
57
|
-
runnable/sdk.py,sha256=
|
59
|
+
runnable/sdk.py,sha256=J1PyiHQD2v_0JaqHjY7xSaXwCUMi_mCNr70TsC-SFZU,35012
|
58
60
|
runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
|
59
|
-
runnable/tasks.py,sha256=
|
61
|
+
runnable/tasks.py,sha256=_A0pcTyOGQL-72AicOxracsrwfs2Vg0r4mQyxz3k6Iw,29016
|
60
62
|
runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
|
61
|
-
runnable-0.
|
62
|
-
runnable-0.
|
63
|
-
runnable-0.
|
64
|
-
runnable-0.
|
65
|
-
runnable-0.
|
63
|
+
runnable-0.32.0.dist-info/METADATA,sha256=t44gRxxaRugnqaRY9gGwweGT0OLvo_inlC3jxrhP3sg,10168
|
64
|
+
runnable-0.32.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
65
|
+
runnable-0.32.0.dist-info/entry_points.txt,sha256=uWHbbOSj0jlG54tFHw377xKkfVbjWvb_1Y9L_LgjJ0Q,1925
|
66
|
+
runnable-0.32.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
67
|
+
runnable-0.32.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|