runnable 0.30.4__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
extensions/nodes/torch.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Callable
8
8
 
9
9
  from pydantic import ConfigDict, Field
10
10
 
11
- from extensions.nodes.torch_config import TorchConfig
11
+ from extensions.nodes.torch_config import EasyTorchConfig, InternalLogSpecs, TorchConfig
12
12
  from runnable import PythonJob, datastore, defaults
13
13
  from runnable.datastore import StepLog
14
14
  from runnable.nodes import DistributedNode
@@ -18,8 +18,9 @@ from runnable.utils import TypeMapVariable
18
18
  logger = logging.getLogger(defaults.LOGGER_NAME)
19
19
 
20
20
  try:
21
+ from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs
21
22
  from torch.distributed.launcher.api import LaunchConfig, elastic_launch
22
- from torch.distributed.run import config_from_args
23
+
23
24
  except ImportError:
24
25
  raise ImportError("torch is not installed. Please install torch first.")
25
26
 
@@ -120,9 +121,25 @@ class TorchNode(DistributedNode, TorchConfig):
120
121
  return cls(executable=executable, **node_config, **task_config)
121
122
 
122
123
  def get_launch_config(self) -> LaunchConfig:
123
- config, _, _ = config_from_args(self)
124
- config.run_id = self._context.run_id
125
- return config
124
+ internal_log_spec = InternalLogSpecs(**self.model_dump(exclude_none=True))
125
+ log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
126
+ **internal_log_spec.model_dump(exclude_none=True)
127
+ )
128
+ easy_torch_config = EasyTorchConfig(
129
+ **self.model_dump(
130
+ exclude_none=True,
131
+ )
132
+ )
133
+
134
+ laugch_config = LaunchConfig(
135
+ **easy_torch_config.model_dump(
136
+ exclude_none=True,
137
+ ),
138
+ logs_specs=log_spec,
139
+ run_id=self._context.run_id,
140
+ )
141
+ print(laugch_config)
142
+ return laugch_config
126
143
 
127
144
  def execute(
128
145
  self,
@@ -1,33 +1,96 @@
1
- from pydantic import BaseModel, Field
1
+ from enum import Enum
2
+ from typing import Any, Optional
3
+
4
+ from pydantic import BaseModel, ConfigDict, Field, computed_field
5
+
6
+
7
class StartMethod(str, Enum):
    """Process start methods selectable for the torch launcher's ``start_method``.

    Members subclass ``str``, so each one compares equal to its plain value.
    """

    # Member names are identical to their values.
    spawn = "spawn"
    fork = "fork"
    forkserver = "forkserver"
11
+
12
+
13
+ # min_nodes: int
14
+ # max_nodes: int
15
+ # nproc_per_node: int
16
+
17
+ # logs_specs: Optional[LogsSpecs] = None
18
+ # run_id: str = ""
19
+ # role: str = "default_role"
20
+
21
+ # rdzv_endpoint: str = ""
22
+ # rdzv_backend: str = "etcd"
23
+ # rdzv_configs: dict[str, Any] = field(default_factory=dict)
24
+ # rdzv_timeout: int = -1
25
+
26
+ # max_restarts: int = 3
27
+ # monitor_interval: float = 0.1
28
+ # start_method: str = "spawn"
29
+ # log_line_prefix_template: Optional[str] = None
30
+ # metrics_cfg: dict[str, str] = field(default_factory=dict)
31
+ # local_addr: Optional[str] = None
32
+
33
+ ## The idea is the following:
34
+ # Users can configure any of the options present in TorchConfig class.
35
+ # The LaunchConfig class will be created from torch config.
36
+ # The LogSpecs is sent as a parameter to the launch config.
37
+ # Options default to None wherever possible so that, once None values are dropped, torch's own defaults apply.
38
+
39
+ ## TODO: it is not yet clear how to support torchrun's standalone mode or how to pass it through.
40
+
41
+
42
class InternalLogSpecs(BaseModel):
    """Logging options that are forwarded to torch's ``DefaultLogsSpecs``.

    Unknown keys are ignored, so a full node/config dump can be passed in and
    only the logging-related fields are kept.
    """

    model_config = ConfigDict(extra="ignore")

    # Log directory handed to DefaultLogsSpecs.
    log_dir: Optional[str] = Field(default="torch_logs")
    # Stream redirect / tee flags; None means "not set" and is dropped by
    # callers that dump with exclude_none=True.
    redirects: Optional[int] = Field(default=None)
    tee: Optional[int] = Field(default=None)
    # Set of local ranks, forwarded unchanged.
    local_ranks_filter: Optional[set[int]] = Field(default=None)
2
49
 
3
50
 
4
51
  class TorchConfig(BaseModel):
5
- nnodes: str = Field(default="1:1")
6
- nproc_per_node: int = Field(default=4)
7
-
8
- rdzv_backend: str = Field(default="static")
9
- rdzv_endpoint: str = Field(default="")
10
- rdzv_id: str | None = Field(default=None)
11
- rdzv_conf: str = Field(default="")
12
-
13
- max_restarts: int = Field(default=3)
14
- monitor_interval: float = Field(default=0.1)
15
- start_method: str = Field(default="spawn")
16
- role: str = Field(default="default_role")
17
- log_dir: str = Field(default="torch_logs")
18
- redirects: str = Field(default="1")
19
- tee: str = Field(default="1")
20
- master_addr: str = Field(default="localhost")
21
- master_port: str = Field(default="29500")
22
- training_script: str = Field(default="dummy_training_script")
23
- training_script_args: str = Field(default="")
24
-
25
- # Optional fields
26
- local_ranks_filter: str = Field(default="")
27
- node_rank: int = Field(default=0)
28
- local_addr: str | None = Field(default=None)
29
- logs_specs: str | None = Field(default=None)
30
- standalone: bool = Field(default=False)
31
- module: bool = Field(default=False)
32
- no_python: bool = Field(default=False)
33
- run_path: bool = Field(default=False)
52
+ model_config = ConfigDict(extra="forbid")
53
+ nnodes: str = Field(default="1:1", exclude=True)
54
+ nproc_per_node: int = Field(default=1)
55
+
56
+ # will be used to create the log specs
57
+ log_dir: Optional[str] = Field(default="torch_logs", exclude=True)
58
+ redirects: int | None = Field(default=None, exclude=True)
59
+ tee: int | None = Field(default=None, exclude=True)
60
+ local_ranks_filter: Optional[set[int]] = Field(default=None, exclude=True)
61
+
62
+ role: str | None = Field(default=None)
63
+ # run_id would be the run_id of the context
64
+ # and sent at the creation of the LaunchConfig
65
+
66
+ rdzv_backend: str | None = Field(default="static")
67
+ rdzv_endpoint: str | None = Field(default="")
68
+ rdzv_configs: dict[str, Any] = Field(default_factory=dict)
69
+ rdzv_timeout: int | None = Field(default=None)
70
+
71
+ max_restarts: int | None = Field(default=None)
72
+ monitor_interval: float | None = Field(default=None)
73
+ start_method: str | None = Field(default=StartMethod.spawn)
74
+ log_line_prefix_template: str | None = Field(default=None)
75
+ local_addr: Optional[str] = None
76
+
77
+ # https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L753
78
+ # master_addr: str | None = Field(default="localhost")
79
+ # master_port: str | None = Field(default="29500")
80
+ # training_script: str = Field(default="dummy_training_script")
81
+ # training_script_args: str = Field(default="")
82
+
83
+
84
class EasyTorchConfig(TorchConfig):
    """TorchConfig whose dump lines up with ``LaunchConfig`` keyword arguments.

    ``nnodes`` is excluded from the dump (it is declared ``exclude=True`` on
    TorchConfig) and is replaced by the computed ``min_nodes`` / ``max_nodes``.
    Unknown keys are ignored so a whole node dump can be passed in.
    """

    model_config = ConfigDict(extra="ignore")

    def _node_bounds(self) -> tuple[int, int]:
        """Parse ``nnodes`` into ``(min_nodes, max_nodes)``.

        Accepts ``"N"`` (min = max = N) or ``"MIN:MAX"``.

        Raises:
            ValueError: if ``nnodes`` is not in one of those forms, is not
                integral, or does not satisfy ``0 < MIN <= MAX``.
        """
        parts = self.nnodes.split(":")
        if len(parts) == 1:
            # A single value fixes both bounds, matching torchrun's --nnodes.
            parts = parts * 2
        if len(parts) != 2:
            raise ValueError(
                f'nnodes must be "N" or "MIN:MAX", got {self.nnodes!r}'
            )
        try:
            min_nodes, max_nodes = (int(part) for part in parts)
        except ValueError as err:
            raise ValueError(
                f'nnodes must be "N" or "MIN:MAX" with integers, got {self.nnodes!r}'
            ) from err
        if not 0 < min_nodes <= max_nodes:
            raise ValueError(
                f"nnodes must satisfy 0 < MIN <= MAX, got {self.nnodes!r}"
            )
        return min_nodes, max_nodes

    @computed_field  # type: ignore
    @property
    def min_nodes(self) -> int:
        """Lower bound on the number of nodes, parsed from ``nnodes``."""
        return self._node_bounds()[0]

    @computed_field  # type: ignore
    @property
    def max_nodes(self) -> int:
        """Upper bound on the number of nodes, parsed from ``nnodes``."""
        return self._node_bounds()[1]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: runnable
3
- Version: 0.30.4
3
+ Version: 0.31.0
4
4
  Summary: Add your description here
5
5
  Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
6
6
  License-File: LICENSE
@@ -16,8 +16,8 @@ extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqy
16
16
  extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
18
18
  extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
19
- extensions/nodes/torch.py,sha256=id0_HVkRcqL9_DPOI-b53vaDwRgVfGB-zZS3yrRej9g,6318
20
- extensions/nodes/torch_config.py,sha256=yDvDADpnLhQsNtfH8qIztLHQ2LhYiOJEWljxpH9GZzs,1222
19
+ extensions/nodes/torch.py,sha256=RUelXV7Pa4U5F7Ww3cfRG0Oaz9SkYF3b_CmpFHlpbyI,6885
20
+ extensions/nodes/torch_config.py,sha256=jfUtkwCYolyKVcFxiMjjwm63yv-HjTKvSQR8JLA7sZg,3151
21
21
  extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
22
  extensions/pipeline_executor/__init__.py,sha256=wfigTL2T9OHrmE8b2Ydmb8h6hr-oF--Yc2FectC7WaY,24623
23
23
  extensions/pipeline_executor/argo.py,sha256=AEGSWVZulBL6EsvbVCaeBeTl2m_t5ymc6RFpMKhivis,37946
@@ -58,8 +58,8 @@ runnable/sdk.py,sha256=NZVQGaL4Zm2hwloRmqEgp8UPbBg9hY1abQGYnOgniPI,35128
58
58
  runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
59
59
  runnable/tasks.py,sha256=Qb1IhVxHv68E7vf3M3YCf7MGRHyjmsEEYBpEpiZ4mRI,29062
60
60
  runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
61
- runnable-0.30.4.dist-info/METADATA,sha256=S-5zecrqE4tU5MW4Fe1-2F-Q_hLU7fAXZ2oo9xVRRUw,10115
62
- runnable-0.30.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
63
- runnable-0.30.4.dist-info/entry_points.txt,sha256=PrjKrlfXPZaV_7hz8orGu4FDnatLqnhPOXljyllszdw,1880
64
- runnable-0.30.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
65
- runnable-0.30.4.dist-info/RECORD,,
61
+ runnable-0.31.0.dist-info/METADATA,sha256=9c3Ixkq-Kl0_hiQfDX-KwtSAdSWzMRLJMfEze2oVQhE,10115
62
+ runnable-0.31.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
63
+ runnable-0.31.0.dist-info/entry_points.txt,sha256=PrjKrlfXPZaV_7hz8orGu4FDnatLqnhPOXljyllszdw,1880
64
+ runnable-0.31.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
65
+ runnable-0.31.0.dist-info/RECORD,,