runnable 0.34.0a1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of runnable might be problematic. Click here for more details.
- extensions/catalog/any_path.py +13 -2
- extensions/job_executor/__init__.py +7 -5
- extensions/job_executor/emulate.py +106 -0
- extensions/job_executor/k8s.py +8 -8
- extensions/job_executor/local_container.py +13 -14
- extensions/nodes/__init__.py +0 -0
- extensions/nodes/conditional.py +243 -0
- extensions/nodes/fail.py +72 -0
- extensions/nodes/map.py +350 -0
- extensions/nodes/parallel.py +159 -0
- extensions/nodes/stub.py +89 -0
- extensions/nodes/success.py +72 -0
- extensions/nodes/task.py +92 -0
- extensions/pipeline_executor/__init__.py +27 -27
- extensions/pipeline_executor/argo.py +52 -46
- extensions/pipeline_executor/emulate.py +112 -0
- extensions/pipeline_executor/local.py +4 -4
- extensions/pipeline_executor/local_container.py +19 -79
- extensions/pipeline_executor/mocked.py +5 -9
- extensions/pipeline_executor/retry.py +6 -10
- runnable/__init__.py +2 -11
- runnable/catalog.py +6 -23
- runnable/cli.py +145 -48
- runnable/context.py +520 -28
- runnable/datastore.py +51 -54
- runnable/defaults.py +12 -34
- runnable/entrypoints.py +82 -440
- runnable/exceptions.py +35 -34
- runnable/executor.py +13 -20
- runnable/gantt.py +1141 -0
- runnable/graph.py +1 -1
- runnable/names.py +1 -1
- runnable/nodes.py +20 -16
- runnable/parameters.py +108 -51
- runnable/sdk.py +125 -204
- runnable/tasks.py +62 -85
- runnable/utils.py +6 -268
- runnable-1.0.0.dist-info/METADATA +122 -0
- runnable-1.0.0.dist-info/RECORD +73 -0
- {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/entry_points.txt +9 -8
- extensions/nodes/nodes.py +0 -778
- extensions/nodes/torch.py +0 -273
- extensions/nodes/torch_config.py +0 -76
- extensions/tasks/torch.py +0 -286
- extensions/tasks/torch_config.py +0 -76
- runnable-0.34.0a1.dist-info/METADATA +0 -267
- runnable-0.34.0a1.dist-info/RECORD +0 -67
- {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/WHEEL +0 -0
- {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/licenses/LICENSE +0 -0
extensions/nodes/torch.py
DELETED
|
@@ -1,273 +0,0 @@
|
|
|
1
|
-
import importlib
|
|
2
|
-
import logging
|
|
3
|
-
import os
|
|
4
|
-
import random
|
|
5
|
-
import string
|
|
6
|
-
from datetime import datetime
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
from typing import Any, Callable, Optional
|
|
9
|
-
|
|
10
|
-
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
|
11
|
-
|
|
12
|
-
from extensions.nodes.torch_config import EasyTorchConfig, TorchConfig
|
|
13
|
-
from runnable import PythonJob, datastore, defaults
|
|
14
|
-
from runnable.datastore import StepLog
|
|
15
|
-
from runnable.nodes import ExecutableNode
|
|
16
|
-
from runnable.tasks import PythonTaskType, create_task
|
|
17
|
-
from runnable.utils import TypeMapVariable
|
|
18
|
-
|
|
19
|
-
logger = logging.getLogger(defaults.LOGGER_NAME)

# torch is an optional dependency, but every definition below relies on
# torch.distributed, so fail loudly at import time when it is missing.
try:
    from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
except ImportError as e:
    logger.exception("Torch is not installed. Please install torch first.")
    # Chain the original ImportError so the actual missing module stays visible
    # in the traceback.
    raise Exception("Torch is not installed. Please install torch first.") from e
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def training_subprocess():
    """
    Worker entry point passed to torch.distributed.launcher.api.elastic_launch.

    It runs in a subprocess and executes the user's function as a standalone
    PythonJob. It is unrelated to the parent node's execution, so any cataloging
    or run-log handling has to be reconciled by the parent process afterwards.

    The parent (TorchNode.execute) communicates through environment variables:

        RUNNABLE_TORCH_COMMAND          dotted path of the user's function (required)
        RUNNABLE_TORCH_PARAMETERS_FILES parameters file for the job ("" if none)
        RUNNABLE_TORCH_RUN_ID           run id of the parent run

    Raises:
        ValueError: if RUNNABLE_TORCH_COMMAND is not set.
        ImportError: if the command cannot be imported.
        Exception: if the job's run log reports a failure.
    """
    command = os.environ.get("RUNNABLE_TORCH_COMMAND")
    if not command:
        # Fail with a clear message instead of "Could not import 'None'".
        raise ValueError("RUNNABLE_TORCH_COMMAND is not set; nothing to execute")

    run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
    parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")

    # One job id per worker: parent run id, the worker's RANK and a short random
    # suffix so that concurrent workers do not collide.
    suffix = "".join(random.choices(string.ascii_lowercase, k=3))
    process_run_id = run_id + "-" + os.environ.get("RANK", "") + "-" + suffix

    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

    # The worker's job must not inherit the parent's RUNNABLE_* configuration.
    delete_env_vars_with_prefix("RUNNABLE_")

    func = get_callable_from_dotted_path(command)

    # The job runs with the default configuration. The parent looks for the
    # worker's execution logs in .catalog.
    job = PythonJob(function=func)
    job.execute(
        parameters_file=parameters_files,
        job_id=process_run_id,
    )

    # Imported after execute: the job's run context is read from this module.
    from runnable.context import run_context

    job_log = run_context.run_log_store.get_run_log_by_id(run_id=run_context.run_id)

    if job_log.status == defaults.FAIL:
        raise Exception(f"Job {process_run_id} failed")
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
# TODO: Can this be utils.get_module_and_attr_names
|
|
85
|
-
def get_callable_from_dotted_path(dotted_path) -> Callable:
    """Import and return the callable named by ``dotted_path``.

    ``dotted_path`` has the form ``"<module path>.<attribute>"``. The part after
    the last dot is looked up on the imported module.

    Raises:
        ImportError: if the path contains no dot, the module cannot be imported,
            or the attribute does not exist. The original error is chained.
        TypeError: if the attribute exists but is not callable.
    """
    try:
        module_name, attr_name = dotted_path.rsplit(".", 1)
        target = getattr(importlib.import_module(module_name), attr_name)
    except (ImportError, AttributeError, ValueError) as err:
        raise ImportError(f"Could not import '{dotted_path}'.") from err

    if callable(target):
        return target
    raise TypeError(f"The object {attr_name} is not callable.")
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def delete_env_vars_with_prefix(prefix):
    """Remove every environment variable whose name starts with ``prefix``.

    Matching names are collected first, so ``os.environ`` is never mutated while
    it is being iterated.
    """
    doomed = [name for name in os.environ if name.startswith(prefix)]
    for name in doomed:
        os.environ.pop(name)
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
# TODO: The design of this class is not final
|
|
120
|
-
class TorchNode(ExecutableNode, TorchConfig):
    """Pipeline node that runs a Python function under torch's elastic launcher.

    It combines the generic ``ExecutableNode`` fields with the torchrun-like
    options of ``TorchConfig``. The user's function is executed by
    ``training_subprocess`` in every worker started by ``elastic_launch``. This
    process records the outcome in the step log and uploads worker logs to the
    catalog.
    """

    node_type: str = Field(default="torch", serialization_alias="type")
    executable: PythonTaskType = Field(exclude=True)

    # Similar to TaskNode: task-level keys are kept on the model as extras.
    model_config = ConfigDict(extra="allow")

    def get_summary(self) -> dict[str, Any]:
        """Return a minimal summary (name and type) of this node."""
        return {
            "name": self.name,
            "type": self.node_type,
        }

    @classmethod
    def parse_from_config(cls, config: dict[str, Any]) -> "TorchNode":
        """Build a TorchNode from a pipeline definition.

        Keys that are model fields configure the node. The remaining keys describe
        the task and are used to create the executable, which must be a
        ``PythonTaskType``. The task keys are also stored on the node as extras.
        """
        node_fields = TorchNode.model_fields.keys()
        task_config = {k: v for k, v in config.items() if k not in node_fields}
        node_config = {k: v for k, v in config.items() if k in node_fields}

        executable = create_task(task_config)

        assert isinstance(executable, PythonTaskType)
        return cls(executable=executable, **node_config, **task_config)

    def get_launch_config(self) -> LaunchConfig:
        """Translate this node's torch options into a torch ``LaunchConfig``.

        ``TorchConfig`` declares ``nnodes``, ``log_dir``, ``redirects``, ``tee``
        and ``local_ranks_filter`` with ``exclude=True``, so they never appear in
        ``model_dump()``. They are read off the instance explicitly; otherwise the
        user's values would be silently replaced by the defaults.
        """
        internal_log_spec = InternalLogSpecs(
            log_dir=self.log_dir,
            redirects=self.redirects,
            tee=self.tee,
            local_ranks_filter=self.local_ranks_filter,
        )
        log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
            **internal_log_spec.model_dump(exclude_none=True)
        )

        torch_options = self.model_dump(exclude_none=True)
        # min_nodes / max_nodes are derived from nnodes, which the dump omits.
        torch_options["nnodes"] = self.nnodes
        easy_torch_config = EasyTorchConfig(**torch_options)

        launch_config = LaunchConfig(
            **easy_torch_config.model_dump(exclude_none=True),
            logs_specs=log_spec,
            run_id=self._context.run_id,
        )
        logger.info(f"launch_config: {launch_config}")
        return launch_config

    def execute(
        self,
        mock=False,
        map_variable: TypeMapVariable = None,
        attempt_number: int = 1,
    ) -> StepLog:
        """Launch the user's function with elastic_launch and record the attempt.

        Args:
            mock: Not used by this node.
            map_variable: Must be empty; TorchNode does not support map_variable.
            attempt_number: Recorded on the step attempt.

        Returns:
            The step log with the new attempt appended and its status updated.
            A failure of the launch is recorded as a FAIL attempt, not raised.
        """
        assert not map_variable, "TorchNode does not support map_variable"

        step_log = self._context.run_log_store.get_step_log(
            self._get_step_log_name(map_variable), self._context.run_id
        )

        launch_config = self.get_launch_config()

        # Environment variables are inherited by the worker subprocesses. They are
        # how training_subprocess learns what to run.
        os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
        os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
            self._context.parameters_file or ""
        )
        os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id

        launcher = elastic_launch(
            launch_config,
            training_subprocess,
        )

        # Taken before launching so the attempt records the real duration.
        start_time = str(datetime.now())
        try:
            launcher()
            attempt_log = datastore.StepAttempt(
                status=defaults.SUCCESS,
                start_time=start_time,
                end_time=str(datetime.now()),
                attempt_number=attempt_number,
            )
        except Exception as e:
            attempt_log = datastore.StepAttempt(
                status=defaults.FAIL,
                start_time=start_time,
                end_time=str(datetime.now()),
                attempt_number=attempt_number,
                message=str(e),
            )
            logger.error(f"Error executing TorchNode: {e}")
        finally:
            # Clear the hand-off variables first, so they do not leak into later
            # steps even if a catalog operation below raises.
            delete_env_vars_with_prefix("RUNNABLE_TORCH")

            # Only the worker jobs create .catalog. Move it aside as proc_logs
            # and upload it to this node's catalog location.
            # NOTE(review): os.rename fails if a non-empty proc_logs already exists.
            if Path(".catalog").exists():
                os.rename(".catalog", "proc_logs")
                self._context.catalog_handler.put(
                    "proc_logs/**/*", allow_file_not_found_exc=True
                )

            # TODO: uploading log_dir is reported as not working.
            if self.log_dir:
                self._context.catalog_handler.put(
                    self.log_dir + "/**/*", allow_file_not_found_exc=True
                )

        logger.info(f"attempt_log: {attempt_log}")
        logger.info(f"Step {self.name} completed with status: {attempt_log.status}")

        step_log.status = attempt_log.status
        step_log.attempts.append(attempt_log)

        return step_log

    def fan_in(self, map_variable: dict[str, str | int | float] | None = None):
        """Placeholder hook; only checks that no map_variable is given."""
        # Planned: destroy the service and the statefulset created by fan_out.
        assert not map_variable, "TorchNode does not support map_variable"

    def fan_out(self, map_variable: dict[str, str | int | float] | None = None):
        """Placeholder hook; only checks that no map_variable is given."""
        # Planned: create a service and a statefulset, gather the IPs and set
        # them as parameters downstream.
        assert not map_variable, "TorchNode does not support map_variable"
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
# This internal model makes it easier to extract the required fields
|
|
257
|
-
# of log specs from user specification.
|
|
258
|
-
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
|
|
259
|
-
class InternalLogSpecs(BaseModel):
    """The subset of torch options that maps onto ``DefaultLogsSpecs`` arguments.

    ``redirects`` and ``tee`` are stored as torchrun-style strings and converted
    with ``Std.from_str`` only when the model is dumped. Unknown keys are ignored,
    so the model can be built from a larger dump.
    """

    model_config = ConfigDict(extra="ignore")

    log_dir: Optional[str] = "torch_logs"
    redirects: str = "0"  # "0" -> Std.NONE
    tee: str = "0"  # "0" -> Std.NONE
    local_ranks_filter: Optional[set[int]] = None

    @field_serializer("redirects")
    def convert_redirects(self, value: str) -> Std | dict[int, Std]:
        # Either a single Std or a per-rank {rank: Std} mapping.
        return Std.from_str(value)

    @field_serializer("tee")
    def convert_tee(self, value: str) -> Std | dict[int, Std]:
        # Same torchrun syntax as ``redirects``.
        return Std.from_str(value)
|
extensions/nodes/torch_config.py
DELETED
|
@@ -1,76 +0,0 @@
|
|
|
1
|
-
from enum import Enum
|
|
2
|
-
from typing import Any, Optional
|
|
3
|
-
|
|
4
|
-
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
class StartMethod(str, Enum):
    """Start methods for the launched worker processes.

    Members are also ``str`` instances, so they compare equal to their values.
    """

    spawn = "spawn"  # default for TorchConfig.start_method
    fork = "fork"
    forkserver = "forkserver"
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
## The idea is the following:
|
|
14
|
-
# Users can configure any of the options present in TorchConfig class.
|
|
15
|
-
# The LaunchConfig class will be created from TorchConfig.
|
|
16
|
-
# The LogSpecs is sent as a parameter to the launch config.
|
|
17
|
-
|
|
18
|
-
## TODO: it is unclear how to support torchrun's standalone mode or how to pass it through.
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
# The user sees this as part of the config of the node.
|
|
22
|
-
# It is kept as similar as possible to torchrun
|
|
23
|
-
class TorchConfig(BaseModel):
    """User-facing torch options, kept as close as possible to torchrun's flags.

    A LaunchConfig is created from these options. The log-related options are
    used to build the log specs, which are passed to the LaunchConfig as a
    separate object.
    """

    model_config = ConfigDict(extra="forbid")

    # "min:max" number of nodes. Excluded from dumps because LaunchConfig takes
    # min_nodes and max_nodes instead.
    nnodes: str = Field(default="1:1", exclude=True, description="min:max")
    nproc_per_node: int = Field(default=1, description="Number of processes per node")

    # Log options. They feed the log specs and are excluded from dumps because
    # the log specs are a separate LaunchConfig argument.
    # redirects / tee use torchrun's syntax:
    #   from_str("0") -> Std.NONE
    #   from_str("1") -> Std.OUT
    #   from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
    log_dir: str | None = Field(default="torch_logs", exclude=True)
    redirects: str = Field(default="0", exclude=True)
    tee: str = Field(default="0", exclude=True)
    local_ranks_filter: set[int] | None = Field(default=None, exclude=True)

    role: str | None = None

    # run_id is not configurable here: the context's run_id is supplied when the
    # LaunchConfig is created.

    # Communication between nodes/processes (rendezvous).
    rdzv_backend: str | None = "static"
    rdzv_endpoint: str | None = ""
    rdzv_configs: dict[str, Any] = Field(default_factory=dict)
    rdzv_timeout: int | None = None

    max_restarts: int | None = None
    monitor_interval: float | None = None
    start_method: str | None = StartMethod.spawn
    log_line_prefix_template: str | None = None
    local_addr: str | None = None

    # torchrun options not modelled here
    # (https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L753):
    # master_addr (default "localhost"), master_port (default "29500"),
    # training_script and training_script_args.
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def _parse_nnodes(nnodes: str) -> tuple[int, int]:
    """Parse a torchrun-style ``nnodes`` value into ``(min_nodes, max_nodes)``.

    Accepts ``"min:max"`` or a single ``"N"``, meaning exactly N nodes.

    Raises:
        ValueError: if the value is malformed or min exceeds max.
    """
    parts = nnodes.split(":")
    if len(parts) == 1:
        min_nodes = max_nodes = int(parts[0])
    elif len(parts) == 2:
        min_nodes, max_nodes = int(parts[0]), int(parts[1])
    else:
        raise ValueError(f"nnodes must be 'N' or 'min:max', got {nnodes!r}")
    if min_nodes > max_nodes:
        raise ValueError(
            f"nnodes min ({min_nodes}) must not exceed max ({max_nodes})"
        )
    return min_nodes, max_nodes


class EasyTorchConfig(TorchConfig):
    """TorchConfig whose dump supplies LaunchConfig's ``min_nodes``/``max_nodes``.

    ``nnodes`` is expanded into the ``min_nodes`` and ``max_nodes`` computed
    fields. Unknown keys are ignored, so it can be built from a larger dump.
    """

    model_config = ConfigDict(extra="ignore")

    @computed_field  # type: ignore
    @property
    def min_nodes(self) -> int:
        return _parse_nnodes(self.nnodes)[0]

    @computed_field  # type: ignore
    @property
    def max_nodes(self) -> int:
        return _parse_nnodes(self.nnodes)[1]
|
extensions/tasks/torch.py
DELETED
|
@@ -1,286 +0,0 @@
|
|
|
1
|
-
import importlib
|
|
2
|
-
import logging
|
|
3
|
-
import os
|
|
4
|
-
import random
|
|
5
|
-
import string
|
|
6
|
-
from datetime import datetime
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
from typing import Any, Optional
|
|
9
|
-
|
|
10
|
-
from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
|
|
11
|
-
from ruamel.yaml import YAML
|
|
12
|
-
|
|
13
|
-
import runnable.context as context
|
|
14
|
-
from extensions.tasks.torch_config import EasyTorchConfig, TorchConfig
|
|
15
|
-
from runnable import Catalog, defaults
|
|
16
|
-
from runnable.datastore import StepAttempt
|
|
17
|
-
from runnable.tasks import BaseTaskType
|
|
18
|
-
from runnable.utils import get_module_and_attr_names
|
|
19
|
-
|
|
20
|
-
logger = logging.getLogger(defaults.LOGGER_NAME)

# torch is an optional dependency, but everything below needs torch.distributed.
try:
    from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
    from torch.distributed.launcher.api import LaunchConfig, elastic_launch

except ImportError as e:
    logger.exception("torch is not installed")
    raise Exception("torch is not installed") from e
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def get_min_max_nodes(nnodes: str) -> tuple[int, int]:
    """Parse a torchrun-style ``nnodes`` value into ``(min_nodes, max_nodes)``.

    Accepts ``"min:max"`` or a single ``"N"``, meaning exactly N nodes, as
    torchrun does.

    Raises:
        ValueError: if the value is malformed or ``min > max``.
    """
    parts = nnodes.split(":")
    if len(parts) == 1:
        min_nodes = max_nodes = int(parts[0])
    elif len(parts) == 2:
        min_nodes, max_nodes = int(parts[0]), int(parts[1])
    else:
        raise ValueError(f"nnodes must be 'N' or 'min:max', got {nnodes!r}")
    if min_nodes > max_nodes:
        raise ValueError(
            f"nnodes min ({min_nodes}) must not exceed max ({max_nodes})"
        )
    return min_nodes, max_nodes
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
class TorchTaskType(BaseTaskType, TorchConfig):
    """Task that runs a Python function under torch's elastic launcher.

    ``command`` is the dotted path of the user's function. It is executed by
    ``training_subprocess`` in every worker started by ``elastic_launch``.
    ``secrets`` and ``returns`` are rejected during validation.
    """

    task_type: str = Field(default="torch", serialization_alias="command_type")
    catalog: Optional[Catalog] = Field(default=None, alias="catalog")
    command: str

    @model_validator(mode="before")
    @classmethod
    def check_secrets_and_returns(cls, data: Any) -> Any:
        """Reject raw input that sets ``secrets`` or ``returns``.

        Raises:
            ValueError: naming the unsupported option.
        """
        if isinstance(data, dict):
            if "secrets" in data and data["secrets"]:
                raise ValueError("'secrets' is not supported for torch")
            if "returns" in data and data["returns"]:
                raise ValueError("'returns' is not supported for torch")
        return data

    def get_summary(self) -> dict[str, Any]:
        """Return the task configuration using serialization aliases, without None values."""
        return self.model_dump(by_alias=True, exclude_none=True)

    @property
    def _context(self):
        """The global run context (``runnable.context.run_context``)."""
        return context.run_context

    def _get_launch_config(self) -> LaunchConfig:
        """Translate this task's torch options into a torch ``LaunchConfig``.

        TorchConfig may declare the log options and ``nnodes`` with
        ``exclude=True`` (the extensions.nodes variant does). That drops them
        from ``model_dump()`` and silently falls back to the defaults, so they
        are read off the instance instead.
        """
        log_options = {
            name: getattr(self, name)
            for name in InternalLogSpecs.model_fields
            if hasattr(self, name)
        }
        internal_log_spec = InternalLogSpecs(**log_options)
        log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
            **internal_log_spec.model_dump(exclude_none=True)
        )

        torch_options = self.model_dump(exclude_none=True)
        torch_options["nnodes"] = self.nnodes
        easy_torch_config = EasyTorchConfig(**torch_options)

        launch_config = LaunchConfig(
            **easy_torch_config.model_dump(exclude_none=True),
            logs_specs=log_spec,
            run_id=self._context.run_id,
        )
        logger.info(f"launch_config: {launch_config}")
        return launch_config

    def execute_command(
        self,
        map_variable: defaults.TypeMapVariable = None,
    ):
        """Run the user's function under elastic_launch and return the attempt.

        Scale-up path: when the max node count is > 1 and RUNNABLE_TORCH_EXECUTE
        is set to anything other than "true" (it defaults to "true"), this only
        calls ``executor.scale_up(self)`` and returns a SUCCESS attempt.

        Returns:
            StepAttempt: SUCCESS if the launcher completed, FAIL if it raised.
            Launcher errors are logged and recorded, not re-raised.
        """
        assert map_variable is None, "map_variable is not supported for torch"

        # The launch below should happen only on the node that is meant to
        # execute. A single-node, multi-worker setup uses this as its entry point.
        # A multi-node setup needs to:
        # - create a service config
        # - create a stateful set with the number of nodes
        # - create a job that runs torch.distributed.launcher.api.elastic_launch
        #   on every node
        # - possibly use the runnable entry point to trigger execution instead
        #   of scaling
        is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"

        _, max_nodes = get_min_max_nodes(self.nnodes)

        if max_nodes > 1 and not is_execute:
            executor = self._context.executor
            executor.scale_up(self)
            return StepAttempt(
                status=defaults.SUCCESS,
                start_time=str(datetime.now()),
                end_time=str(datetime.now()),
                attempt_number=1,
                message="Triggered a scale up",
            )

        launch_config = self._get_launch_config()

        # Environment variables are inherited by the worker subprocesses. They are
        # how training_subprocess learns what to run.
        os.environ["RUNNABLE_TORCH_COMMAND"] = self.command
        os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
            self._context.parameters_file or ""
        )
        os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id

        launcher = elastic_launch(
            launch_config,
            training_subprocess,
        )

        # Taken before launching so the attempt records the real duration.
        start_time = str(datetime.now())
        try:
            launcher()
            attempt_log = StepAttempt(
                status=defaults.SUCCESS,
                start_time=start_time,
                end_time=str(datetime.now()),
                attempt_number=1,
            )
        except Exception as e:
            attempt_log = StepAttempt(
                status=defaults.FAIL,
                start_time=start_time,
                end_time=str(datetime.now()),
                attempt_number=1,
                message=str(e),
            )
            logger.error(f"Error executing TorchNode: {e}")
        finally:
            # Clear the hand-off variables first, so they do not leak into later
            # steps even if a catalog upload below raises.
            delete_env_vars_with_prefix("RUNNABLE_TORCH")

            # proc_logs is the catalog location that training_subprocess
            # configures for the worker jobs.
            if Path("proc_logs").exists():
                self._context.catalog_handler.put(
                    "proc_logs/**/*", allow_file_not_found_exc=True
                )

            # TODO: uploading log_dir is reported as not working.
            if self.log_dir:
                self._context.catalog_handler.put(
                    self.log_dir + "/**/*", allow_file_not_found_exc=True
                )

        logger.info(f"attempt_log: {attempt_log}")

        return attempt_log
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
# This internal model makes it easier to extract the required fields
|
|
185
|
-
# of log specs from user specification.
|
|
186
|
-
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
|
|
187
|
-
class InternalLogSpecs(BaseModel):
    """The subset of torch options that maps onto ``DefaultLogsSpecs`` arguments.

    ``redirects`` and ``tee`` are stored as torchrun-style strings and converted
    with ``Std.from_str`` only when the model is dumped. Unknown keys are ignored.
    """

    model_config = ConfigDict(extra="ignore")

    log_dir: Optional[str] = "torch_logs"
    redirects: str = "0"  # "0" -> Std.NONE
    tee: str = "0"  # "0" -> Std.NONE
    local_ranks_filter: Optional[set[int]] = None

    @field_serializer("redirects")
    def convert_redirects(self, value: str) -> Std | dict[int, Std]:
        # Either a single Std or a per-rank {rank: Std} mapping.
        return Std.from_str(value)

    @field_serializer("tee")
    def convert_tee(self, value: str) -> Std | dict[int, Std]:
        # Same torchrun syntax as ``redirects``.
        return Std.from_str(value)
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
def delete_env_vars_with_prefix(prefix):
    """Unset every environment variable whose name begins with ``prefix``."""
    # Snapshot the matching keys first: deleting while iterating os.environ
    # directly would mutate the mapping mid-iteration.
    for name in [key for key in os.environ if key.startswith(prefix)]:
        del os.environ[name]
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
def training_subprocess():
    """
    Worker entry point passed to torch.distributed.launcher.api.elastic_launch.

    It runs in a subprocess and executes the user's function as a standalone
    PythonJob. It is unrelated to the parent task's execution, so any cataloging
    or run-log handling has to be reconciled by the parent process afterwards.

    The parent (TorchTaskType.execute_command) communicates through environment
    variables:

        RUNNABLE_TORCH_COMMAND          dotted path of the user's function (required)
        RUNNABLE_TORCH_PARAMETERS_FILES parameters file for the job ("" if none)
        RUNNABLE_TORCH_RUN_ID           run id of the parent run

    Raises:
        ValueError: if RUNNABLE_TORCH_COMMAND is not set.
        Exception: if the job's run log reports a failure.
    """
    from runnable import PythonJob

    command = os.environ.get("RUNNABLE_TORCH_COMMAND")
    if not command:
        # Raised explicitly: an assert would be stripped under ``python -O``.
        raise ValueError("Command is not provided")

    run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
    parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")

    # One job id per worker: parent run id, the worker's RANK and a short random
    # suffix so that concurrent workers do not collide.
    suffix = "".join(random.choices(string.ascii_lowercase, k=3))
    process_run_id = run_id + "-" + os.environ.get("RANK", "") + "-" + suffix

    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

    # This subprocess must not see any RUNNABLE_* environment variables.
    delete_env_vars_with_prefix("RUNNABLE_")

    module_name, func_name = get_module_and_attr_names(command)
    module = importlib.import_module(module_name)
    callable_obj = getattr(module, func_name)

    # The job runs with the default configuration plus a file-system catalog at
    # proc_logs, which the parent uploads after the launch.
    job = PythonJob(function=callable_obj)

    config_content = {
        "catalog": {"type": "file-system", "config": {"catalog_location": "proc_logs"}}
    }

    # NOTE(review): this presumably relies on job.execute picking up
    # ./runnable-config.yaml. It overwrites, and later removes, any such file
    # the user already has in the working directory.
    temp_config_file = Path("runnable-config.yaml")
    with open(str(temp_config_file), "w", encoding="utf-8") as config_file:
        yaml = YAML(typ="safe", pure=True)
        yaml.dump(config_content, config_file)

    try:
        job.execute(
            parameters_file=parameters_files,
            job_id=process_run_id,
        )
    finally:
        # Always remove the temporary config, even when the job raises. Other
        # workers of the same launch may remove the same path first, so a
        # missing file is not an error.
        temp_config_file.unlink(missing_ok=True)

    # Imported after execute: the job's run context is read from this module.
    from runnable.context import run_context

    job_log = run_context.run_log_store.get_run_log_by_id(run_id=run_context.run_id)

    if job_log.status == defaults.FAIL:
        raise Exception(f"Job {process_run_id} failed")
|