runnable 0.32.0__py3-none-any.whl → 0.32.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- extensions/job_executor/k8s.py +1 -1
- extensions/nodes/torch.py +5 -4
- extensions/tasks/torch.py +35 -7
- extensions/tasks/torch_config.py +1 -1
- runnable/executor.py +8 -0
- runnable/parameters.py +5 -2
- runnable/sdk.py +4 -4
- runnable/tasks.py +0 -1
- {runnable-0.32.0.dist-info → runnable-0.32.2.dist-info}/METADATA +1 -1
- {runnable-0.32.0.dist-info → runnable-0.32.2.dist-info}/RECORD +13 -13
- {runnable-0.32.0.dist-info → runnable-0.32.2.dist-info}/WHEEL +0 -0
- {runnable-0.32.0.dist-info → runnable-0.32.2.dist-info}/entry_points.txt +0 -0
- {runnable-0.32.0.dist-info → runnable-0.32.2.dist-info}/licenses/LICENSE +0 -0
extensions/job_executor/k8s.py
CHANGED
extensions/nodes/torch.py
CHANGED
@@ -5,7 +5,7 @@ import random
 import string
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from pydantic import BaseModel, ConfigDict, Field, field_serializer
 
@@ -21,11 +21,12 @@ logger = logging.getLogger(defaults.LOGGER_NAME)
 try:
     from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
     from torch.distributed.launcher.api import LaunchConfig, elastic_launch
-
 except ImportError:
-
+    logger.exception("Torch is not installed. Please install torch first.")
 
-
+if TYPE_CHECKING:
+    from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
+    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
 
 def training_subprocess():
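Note: the hunk above replaces a hard import failure with a logged warning, while the new TYPE_CHECKING block keeps the torch names visible to static type checkers without importing torch at runtime. A minimal sketch of the same pattern (module name aside, nothing here is specific to runnable):

import logging
from typing import TYPE_CHECKING

logger = logging.getLogger(__name__)

try:
    # Optional heavy dependency: the module should import even without it.
    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
except ImportError:
    # Log and carry on instead of raising, so non-torch code paths still work.
    logger.exception("Torch is not installed. Please install torch first.")

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime.
    from torch.distributed.launcher.api import LaunchConfig, elastic_launch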
extensions/tasks/torch.py
CHANGED
@@ -5,7 +5,7 @@ import random
 import string
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
 from ruamel.yaml import YAML
@@ -17,15 +17,23 @@ from runnable.datastore import StepAttempt
 from runnable.tasks import BaseTaskType
 from runnable.utils import get_module_and_attr_names
 
+logger = logging.getLogger(defaults.LOGGER_NAME)
+
 try:
     from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
     from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
 except ImportError:
-
+    logger.exception("torch is not installed")
 
+if TYPE_CHECKING:
+    from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
+    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
-
+
+def get_min_max_nodes(nnodes: str) -> tuple[int, int]:
+    min_nodes, max_nodes = (int(x) for x in nnodes.split(":"))
+    return min_nodes, max_nodes
 
 
 class TorchTaskType(BaseTaskType, TorchConfig):
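Note: the new get_min_max_nodes helper parses a torchrun-style nnodes spec of the form "min:max". A quick sketch of its behaviour; as written it assumes the colon is always present:

def get_min_max_nodes(nnodes: str) -> tuple[int, int]:
    min_nodes, max_nodes = (int(x) for x in nnodes.split(":"))
    return min_nodes, max_nodes

assert get_min_max_nodes("1:2") == (1, 2)
assert get_min_max_nodes("4:4") == (4, 4)
# A bare "4" (no colon) raises ValueError: the generator yields one value
# for the two-name unpacking.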
@@ -60,7 +68,7 @@ class TorchTaskType(BaseTaskType, TorchConfig):
                 exclude_none=True,
             )
         )
-
+        print("###", easy_torch_config)
         launch_config = LaunchConfig(
             **easy_torch_config.model_dump(
                 exclude_none=True,
@@ -77,7 +85,30 @@ class TorchTaskType(BaseTaskType, TorchConfig):
     ):
         assert map_variable is None, "map_variable is not supported for torch"
 
+        # The below should happen only if we are in the node that we want to execute
+        # For a single node, multi worker setup, this should be the entry point
+        # For a multi-node, we need to:
+        # - create a service config
+        # - Create a stateful set with number of nodes
+        # - Create a job to run the torch.distributed.launcher.api.elastic_launch on every node
+        # - the entry point to runnnable could be a way to trigger execution instead of scaling
+        is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
+
+        _, max_nodes = get_min_max_nodes(self.nnodes)
+
+        if max_nodes > 1 and not is_execute:
+            executor = self._context.executor
+            executor.scale_up(self)
+            return StepAttempt(
+                status=defaults.SUCCESS,
+                start_time=str(datetime.now()),
+                end_time=str(datetime.now()),
+                attempt_number=1,
+                message="Triggered a scale up",
+            )
+
         launch_config = self._get_launch_config()
+        print("###****", launch_config)
         logger.info(f"launch_config: {launch_config}")
 
         # ENV variables are shared with the subprocess, use that as communication
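Note: with this hunk, a multi-node task that has not been told to execute (RUNNABLE_TORCH_EXECUTE != "true") asks its executor to scale up and returns a successful StepAttempt instead of launching workers. A hedged sketch of that control flow (the environment variable name comes from the diff; the executor object is a stand-in):

import os

def run_or_scale(nnodes: str, executor) -> str:
    # Defaults to "true", so a plain invocation executes the workers.
    is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
    _, max_nodes = (int(x) for x in nnodes.split(":"))

    if max_nodes > 1 and not is_execute:
        executor.scale_up()  # provision the remaining nodes first
        return "Triggered a scale up"
    return "Launched workers"  # single node, or a node told to execute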
@@ -175,9 +206,6 @@ def training_subprocess():
        self._context.parameters_file or ""
    )
    os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
-    os.environ["RUNNABLE_TORCH_COPY_CONTENTS_TO"] = (
-        self._context.catalog_handler.compute_data_folder
-    )
    os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
 
    """
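Note: state crosses into the training subprocess only through these RUNNABLE_TORCH_* environment variables; this hunk drops the RUNNABLE_TORCH_COPY_CONTENTS_TO handoff. On the reading side the pattern is simply (variable names from the diff, defaults illustrative):

import os

run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
torch_logs = os.environ.get("RUNNABLE_TORCH_TORCH_LOGS", "")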
extensions/tasks/torch_config.py
CHANGED
@@ -43,7 +43,7 @@ class TorchConfig(BaseModel):
     # and sent at the creation of the LaunchConfig
 
     # This section is about the communication between nodes/processes
-    rdzv_backend: str | None = Field(default="
+    rdzv_backend: str | None = Field(default="")
     rdzv_endpoint: str | None = Field(default="")
     rdzv_configs: dict[str, Any] = Field(default_factory=dict)
     rdzv_timeout: int | None = Field(default=None)
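Note: the rdzv_* fields mirror the rendezvous arguments of torch's elastic launcher, which the task later splats into LaunchConfig. A hedged sketch of that mapping (requires torch; all values illustrative — "c10d" is the backend commonly used for multi-node runs):

from torch.distributed.launcher.api import LaunchConfig

config = LaunchConfig(
    min_nodes=1,
    max_nodes=2,
    nproc_per_node=4,
    rdzv_backend="c10d",                # rendezvous backend
    rdzv_endpoint="master-host:29400",  # illustrative address
    rdzv_configs={},
)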
runnable/executor.py
CHANGED
@@ -153,6 +153,14 @@ class BaseJobExecutor(BaseExecutor):
         """
         ...
 
+    # @abstractmethod
+    # def scale_up(self, job: BaseTaskType):
+    #     """
+    #     Scale up the job to run on max_nodes
+    #     This has to also call the entry point
+    #     """
+    #     ...
+
 
 # TODO: Consolidate execute_node, trigger_node_execution, _execute_node
 class BasePipelineExecutor(BaseExecutor):
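Note: the scale_up hook is only sketched in comments here, even though extensions/tasks/torch.py above already calls executor.scale_up(self). If it were promoted to a real abstract method it would look roughly like this (hypothetical; the commented-out signature is the only guide):

from abc import ABC, abstractmethod

class BaseJobExecutor(ABC):  # simplified stand-in for runnable's class
    @abstractmethod
    def scale_up(self, job) -> None:
        """Scale the job up to max_nodes and call the entry point on each node."""
        ...

class StatefulSetExecutor(BaseJobExecutor):  # hypothetical concrete executor
    def scale_up(self, job) -> None:
        # e.g. patch a StatefulSet's replica count, then trigger execution
        print(f"scaling up {job!r}")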
runnable/parameters.py
CHANGED
@@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict
 from typing_extensions import Callable
 
 from runnable import defaults
-from runnable.datastore import JsonParameter
+from runnable.datastore import JsonParameter, ObjectParameter
 from runnable.defaults import TypeMapVariable
 from runnable.utils import remove_prefix
 
@@ -101,10 +101,13 @@ def filter_arguments_for_func(
             # default value is given in the function signature, nothing further to do.
             continue
 
+        param_value = params[name]
+
         if type(value.annotation) in [
             BaseModel,
             pydantic._internal._model_construction.ModelMetaclass,
-        ]:
+        ] and not isinstance(param_value, ObjectParameter):
+            # Even if the annotation is a pydantic model, it can be passed as an object parameter
             # We try to cast it as a pydantic model if asked
             named_param = params[name].get_value()
 
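Note: the guard handles a real corner case — a function parameter annotated with a pydantic model may still arrive as an ObjectParameter (an already-materialised object) and must not be re-cast from JSON. A self-contained sketch of the same decision with stand-in parameter classes (hypothetical, not runnable's actual datastore types):

from pydantic import BaseModel

class JsonParameter:  # stand-in: value stored as JSON
    def __init__(self, value):
        self.value = value

class ObjectParameter:  # stand-in: value stored as a pickled object
    def __init__(self, value):
        self.value = value

class Pay(BaseModel):
    amount: float = 0.0

def resolve(annotation, param):
    # Cast to the pydantic model only when the stored value is plain JSON;
    # an ObjectParameter already holds the object and passes through untouched.
    if (
        isinstance(annotation, type)
        and issubclass(annotation, BaseModel)
        and not isinstance(param, ObjectParameter)
    ):
        return annotation(**param.value)
    return param.value

assert isinstance(resolve(Pay, JsonParameter({"amount": 3.0})), Pay)
pay = Pay(amount=5.0)
assert resolve(Pay, ObjectParameter(pay)) is pay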
runnable/sdk.py
CHANGED
@@ -5,7 +5,7 @@ import os
 import re
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import
+from typing import Any, Callable, Dict, List, Optional, Union
 
 from pydantic import (
     BaseModel,
@@ -34,7 +34,7 @@ from extensions.nodes.nodes import (
     SuccessNode,
     TaskNode,
 )
-from extensions.
+from extensions.tasks.torch_config import TorchConfig
 from runnable import console, defaults, entrypoints, exceptions, graph, utils
 from runnable.executor import BaseJobExecutor, BasePipelineExecutor
 from runnable.nodes import TraversalNode
@@ -46,8 +46,6 @@ logger = logging.getLogger(defaults.LOGGER_NAME)
 StepType = Union[
     "Stub", "PythonTask", "NotebookTask", "ShellTask", "Parallel", "Map", "TorchTask"
 ]
-if TYPE_CHECKING:
-    pass
 
 
 def pickled(name: str) -> TaskReturns:
@@ -192,6 +190,8 @@ class BaseTask(BaseTraversal):
 
 
 class TorchTask(BaseTask, TorchConfig):
+    # The user will not know the rnnz variables for multi node
+    # They should be overridden in the environment
     function: Callable = Field(exclude=True)
 
     @field_validator("returns", mode="before")
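Note: "rnnz" in the added comment presumably means the rdzv (rendezvous) settings, which for multi-node runs are injected per node through the environment rather than hard-coded in the SDK. Purely as illustration, torch's conventional rendezvous variables can be set like so (values illustrative):

import os

os.environ.setdefault("MASTER_ADDR", "node-0.trainers.svc")  # example address
os.environ.setdefault("MASTER_PORT", "29400")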
runnable/tasks.py
CHANGED
{runnable-0.32.0.dist-info → runnable-0.32.2.dist-info}/RECORD
CHANGED
@@ -8,7 +8,7 @@ extensions/catalog/pyproject.toml,sha256=lLNxY6v04c8I5QK_zKw_E6sJTArSJRA_V-79kta
 extensions/catalog/s3.py,sha256=Sw5t8_kVRprn3uGGJCiHn7M9zw1CLaCOFj6YErtfG0o,287
 extensions/job_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/job_executor/__init__.py,sha256=VeLuYCcShCIYT0TNtAXfUF9tOk4ZHoLzdTEvbsz0spM,5870
-extensions/job_executor/k8s.py,sha256=
+extensions/job_executor/k8s.py,sha256=Jl0s3YryISx-SJIhDhyNskzlUlhy4ynBHEc9DfAXjAY,16394
 extensions/job_executor/k8s_job_spec.yaml,sha256=7aFpxHdO_p6Hkc3YxusUOuAQTD1Myu0yTPX9DrhxbOg,1158
 extensions/job_executor/local.py,sha256=3ZbCFXBvbLlMp10JTmQJJrjBKG2keHI6SH8hEvmHDkA,2230
 extensions/job_executor/local_container.py,sha256=1JcLJ0zrNSNHdubrSO9miN54iwvPLHqKMZ08aOC8WWo,6886
@@ -16,7 +16,7 @@ extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqy
 extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
 extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
-extensions/nodes/torch.py,sha256=
+extensions/nodes/torch.py,sha256=gydcRX5C7jEdPnxLsAQkpRD_by_0Lp4dFg96xDkRVW0,9510
 extensions/nodes/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
 extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/pipeline_executor/__init__.py,sha256=wfigTL2T9OHrmE8b2Ydmb8h6hr-oF--Yc2FectC7WaY,24623
@@ -40,8 +40,8 @@ extensions/run_log_store/db/integration_FF.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeR
 extensions/secrets/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/secrets/dotenv.py,sha256=nADHXI6KJ_LUYOIe5EbtYH-21OBebSNVr0Pjb1GlZ7w,1573
 extensions/secrets/pyproject.toml,sha256=mLJNImNcBlbLKHh-0ugVWT9V83R4RibyyYDtBCSqVF4,282
-extensions/tasks/torch.py,sha256=
-extensions/tasks/torch_config.py,sha256=
+extensions/tasks/torch.py,sha256=At2eMpJas4sUUjzJfPrEBGamG-k3MsxXU6Bou0h9BEs,9274
+extensions/tasks/torch_config.py,sha256=UjfMitT-TXASRDGR30I2vDRnyk7JQnR-5CsOVidjpSY,2833
 runnable/__init__.py,sha256=3ZKuvGEkY_zHVQlJtarXd4jkjICxjgnw-bbKN_5SiJI,691
 runnable/catalog.py,sha256=4msQxLhLKlsDDrHFnGauPYe-Or-q9g8_RYCn_4dpxaU,4466
 runnable/cli.py,sha256=3BiKSj95h2Drn__YlchMPZ5rBMafuRb2OGIsVpbsO5Y,8788
@@ -50,18 +50,18 @@ runnable/datastore.py,sha256=ZobM1aVkgeUJ2fZYt63IFDsoNzObwc93hdByegS5YKQ,32396
 runnable/defaults.py,sha256=3o9IVGryyCE6PoQTOoaIaHHTbJGEzmdXMcwzOhwAYoI,3518
 runnable/entrypoints.py,sha256=1xCbWVUQLGmg5gkWnAVWFLAUf6j4avP9azX_vuGQUMY,18985
 runnable/exceptions.py,sha256=LFbp0-Qxg2PAMLEVt7w2whhBxSG-5pzUEv5qN-Rc4_c,3003
-runnable/executor.py,sha256=
+runnable/executor.py,sha256=Jr9yJtSH7CzjXJLWx3VWIUAQblstuGqzpFtajv7d39M,15348
 runnable/graph.py,sha256=poQz5zcvq89ju_u5sYlunQLPbHnXTaUmjcvstPwvT4U,16536
 runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
 runnable/nodes.py,sha256=d1eLttMAcV7CTwTEqOuNwZqItANoLUkXJ73Xp-srlyI,17811
-runnable/parameters.py,sha256=
+runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
 runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
-runnable/sdk.py,sha256=
+runnable/sdk.py,sha256=hwsEGCCFSijm0DZwDJGHmV8jdMuSU_3Pf-vYoomWYHw,35084
 runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
-runnable/tasks.py,sha256=
+runnable/tasks.py,sha256=ABRhgiTY8F62pNlqJmVTDjwJwuzp8DqciUEOq1fpt1U,28989
 runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
-runnable-0.32.
-runnable-0.32.
-runnable-0.32.
-runnable-0.32.
-runnable-0.32.
+runnable-0.32.2.dist-info/METADATA,sha256=fcKKBj2v2AhRQFZ7ALqSdJrKF5r0Wg-QV6HVKqkBpRY,10168
+runnable-0.32.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+runnable-0.32.2.dist-info/entry_points.txt,sha256=uWHbbOSj0jlG54tFHw377xKkfVbjWvb_1Y9L_LgjJ0Q,1925
+runnable-0.32.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+runnable-0.32.2.dist-info/RECORD,,
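Note: each RECORD row has the form path,sha256=<urlsafe-base64 digest, padding stripped>,<size in bytes> per the wheel spec; the hashes cut off on the removed lines above are a diff-viewer artifact, not empty fields. A small sketch of how an entry can be recomputed and checked:

import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    # RECORD stores the urlsafe base64 of the sha256 digest, '=' padding stripped.
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

# record_entry("runnable/executor.py"), run against the unpacked 0.32.2 wheel,
# should reproduce the corresponding line above.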
{runnable-0.32.0.dist-info → runnable-0.32.2.dist-info}/WHEEL
File without changes
{runnable-0.32.0.dist-info → runnable-0.32.2.dist-info}/entry_points.txt
File without changes
{runnable-0.32.0.dist-info → runnable-0.32.2.dist-info}/licenses/LICENSE
File without changes