runnable 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
extensions/nodes/torch.py CHANGED
@@ -4,11 +4,12 @@ import os
4
4
  import random
5
5
  import string
6
6
  from datetime import datetime
7
- from typing import Any, Callable
7
+ from pathlib import Path
8
+ from typing import Any, Callable, Optional
8
9
 
9
- from pydantic import ConfigDict, Field
10
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer
10
11
 
11
- from extensions.nodes.torch_config import EasyTorchConfig, InternalLogSpecs, TorchConfig
12
+ from extensions.nodes.torch_config import EasyTorchConfig, TorchConfig
12
13
  from runnable import PythonJob, datastore, defaults
13
14
  from runnable.datastore import StepLog
14
15
  from runnable.nodes import DistributedNode
@@ -18,7 +19,7 @@ from runnable.utils import TypeMapVariable
18
19
  logger = logging.getLogger(defaults.LOGGER_NAME)
19
20
 
20
21
  try:
21
- from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs
22
+ from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
22
23
  from torch.distributed.launcher.api import LaunchConfig, elastic_launch
23
24
 
24
25
  except ImportError:
@@ -28,9 +29,30 @@ print("torch is installed")
28
29
 
29
30
 
30
31
  def training_subprocess():
32
+ """
33
+ This function is called by the torch.distributed.launcher.api.elastic_launch
34
+ It happens in a subprocess and is responsible for executing the user's function
35
+
36
+ It is unrelated to the actual node execution, so any cataloging, run_log_store should be
37
+ handled so that it matches the main process.
38
+
39
+ We have these variables to use:
40
+
41
+ os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
42
+ os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
43
+ self._context.parameters_file or ""
44
+ )
45
+ os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
46
+ os.environ["RUNNABLE_TORCH_COPY_CONTENTS_TO"] = (
47
+ self._context.catalog_handler.compute_data_folder
48
+ )
49
+ os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
50
+
51
+ """
31
52
  command = os.environ.get("RUNNABLE_TORCH_COMMAND")
32
53
  run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
33
54
  parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
55
+
34
56
  process_run_id = (
35
57
  run_id
36
58
  + "-"
@@ -38,10 +60,14 @@ def training_subprocess():
38
60
  + "-"
39
61
  + "".join(random.choices(string.ascii_lowercase, k=3))
40
62
  )
63
+ os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
41
64
 
42
65
  delete_env_vars_with_prefix("RUNNABLE_")
43
66
 
44
67
  func = get_callable_from_dotted_path(command)
68
+
69
+ # The job runs with the default configuration
70
+ # All the execution logs are stored in .catalog
45
71
  job = PythonJob(function=func)
46
72
 
47
73
  job.execute(
@@ -57,6 +83,7 @@ def training_subprocess():
57
83
  raise Exception(f"Job {process_run_id} failed")
58
84
 
59
85
 
86
+ # TODO: Can this be utils.get_module_and_attr_names
60
87
  def get_callable_from_dotted_path(dotted_path) -> Callable:
61
88
  try:
62
89
  # Split the path into module path and callable object
@@ -91,6 +118,7 @@ def delete_env_vars_with_prefix(prefix):
91
118
  del os.environ[var]
92
119
 
93
120
 
121
+ # TODO: The design of this class is not final
94
122
  class TorchNode(DistributedNode, TorchConfig):
95
123
  node_type: str = Field(default="torch", serialization_alias="type")
96
124
  executable: PythonTaskType = Field(exclude=True)
@@ -131,15 +159,15 @@ class TorchNode(DistributedNode, TorchConfig):
131
159
  )
132
160
  )
133
161
 
134
- laugch_config = LaunchConfig(
162
+ launch_config = LaunchConfig(
135
163
  **easy_torch_config.model_dump(
136
164
  exclude_none=True,
137
165
  ),
138
166
  logs_specs=log_spec,
139
167
  run_id=self._context.run_id,
140
168
  )
141
- print(laugch_config)
142
- return laugch_config
169
+ logger.info(f"launch_config: {launch_config}")
170
+ return launch_config
143
171
 
144
172
  def execute(
145
173
  self,
@@ -159,13 +187,13 @@ class TorchNode(DistributedNode, TorchConfig):
159
187
  launch_config = self.get_launch_config()
160
188
  logger.info(f"launch_config: {launch_config}")
161
189
 
190
+ # ENV variables are shared with the subprocess, use that as communication
162
191
  os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
163
192
  os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
164
193
  self._context.parameters_file or ""
165
194
  )
166
195
  os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
167
- # retrieve the master address and port from the parameters
168
- # default to localhost and 29500
196
+
169
197
  launcher = elastic_launch(
170
198
  launch_config,
171
199
  training_subprocess,
@@ -186,6 +214,20 @@ class TorchNode(DistributedNode, TorchConfig):
186
214
  attempt_number=attempt_number,
187
215
  )
188
216
  logger.error(f"Error executing TorchNode: {e}")
217
+ finally:
218
+ # This can only come from the subprocess
219
+ if Path(".catalog").exists():
220
+ os.rename(".catalog", "proc_logs")
221
+ # Move .catalog and torch_logs to the parent node's catalog location
222
+ self._context.catalog_handler.put(
223
+ "proc_logs/**/*", allow_file_not_found_exc=True
224
+ )
225
+
226
+ # TODO: This is not working!!
227
+ if self.log_dir:
228
+ self._context.catalog_handler.put(
229
+ self.log_dir + "/**/*", allow_file_not_found_exc=True
230
+ )
189
231
 
190
232
  delete_env_vars_with_prefix("RUNNABLE_TORCH")
191
233
 
@@ -211,3 +253,23 @@ class TorchNode(DistributedNode, TorchConfig):
211
253
  assert (
212
254
  map_variable is None or not map_variable
213
255
  ), "TorchNode does not support map_variable"
256
+
257
+
258
+ # This internal model makes it easier to extract the required fields
259
+ # of log specs from user specification.
260
+ # https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
261
+ class InternalLogSpecs(BaseModel):
262
+ log_dir: Optional[str] = Field(default="torch_logs")
263
+ redirects: str = Field(default="0") # Std.NONE
264
+ tee: str = Field(default="0") # Std.NONE
265
+ local_ranks_filter: Optional[set[int]] = Field(default=None)
266
+
267
+ model_config = ConfigDict(extra="ignore")
268
+
269
+ @field_serializer("redirects")
270
+ def convert_redirects(self, redirects: str) -> Std | dict[int, Std]:
271
+ return Std.from_str(redirects)
272
+
273
+ @field_serializer("tee")
274
+ def convert_tee(self, tee: str) -> Std | dict[int, Std]:
275
+ return Std.from_str(tee)
@@ -10,59 +10,39 @@ class StartMethod(str, Enum):
10
10
  forkserver = "forkserver"
11
11
 
12
12
 
13
- # min_nodes: int
14
- # max_nodes: int
15
- # nproc_per_node: int
16
-
17
- # logs_specs: Optional[LogsSpecs] = None
18
- # run_id: str = ""
19
- # role: str = "default_role"
20
-
21
- # rdzv_endpoint: str = ""
22
- # rdzv_backend: str = "etcd"
23
- # rdzv_configs: dict[str, Any] = field(default_factory=dict)
24
- # rdzv_timeout: int = -1
25
-
26
- # max_restarts: int = 3
27
- # monitor_interval: float = 0.1
28
- # start_method: str = "spawn"
29
- # log_line_prefix_template: Optional[str] = None
30
- # metrics_cfg: dict[str, str] = field(default_factory=dict)
31
- # local_addr: Optional[str] = None
32
-
33
13
  ## The idea is the following:
34
14
  # Users can configure any of the options present in TorchConfig class.
35
- # The LaunchConfig class will be created from torch config.
15
+ # The LaunchConfig class will be created from TorchConfig.
36
16
  # The LogSpecs is sent as a parameter to the launch config.
37
- # None as much as possible to get
38
17
 
39
18
  ## NO idea of standalone and how to send it
40
19
 
41
20
 
42
- class InternalLogSpecs(BaseModel):
43
- log_dir: Optional[str] = Field(default="torch_logs")
44
- redirects: int | None = Field(default=None)
45
- tee: int | None = Field(default=None)
46
- local_ranks_filter: Optional[set[int]] = Field(default=None)
47
-
48
- model_config = ConfigDict(extra="ignore")
49
-
50
-
21
+ # The user sees this as part of the config of the node.
22
+ # It is kept as similar as possible to torchrun
51
23
  class TorchConfig(BaseModel):
52
24
  model_config = ConfigDict(extra="forbid")
53
- nnodes: str = Field(default="1:1", exclude=True)
54
- nproc_per_node: int = Field(default=1)
25
+
26
+ # excluded as LaunchConfig requires min and max nodes
27
+ nnodes: str = Field(default="1:1", exclude=True, description="min:max")
28
+ nproc_per_node: int = Field(default=1, description="Number of processes per node")
55
29
 
56
30
  # will be used to create the log specs
31
+ # But they are excluded from dump as logs specs is a class for LaunchConfig
32
+ # from_str("0") -> Std.NONE
33
+ # from_str("1") -> Std.OUT
34
+ # from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
57
35
  log_dir: Optional[str] = Field(default="torch_logs", exclude=True)
58
- redirects: int | None = Field(default=None, exclude=True)
59
- tee: int | None = Field(default=None, exclude=True)
36
+ redirects: str = Field(default="0", exclude=True) # Std.NONE
37
+ tee: str = Field(default="0", exclude=True) # Std.NONE
60
38
  local_ranks_filter: Optional[set[int]] = Field(default=None, exclude=True)
61
39
 
62
40
  role: str | None = Field(default=None)
41
+
63
42
  # run_id would be the run_id of the context
64
43
  # and sent at the creation of the LaunchConfig
65
44
 
45
+ # This section is about the communication between nodes/processes
66
46
  rdzv_backend: str | None = Field(default="static")
67
47
  rdzv_endpoint: str | None = Field(default="")
68
48
  rdzv_configs: dict[str, Any] = Field(default_factory=dict)
@@ -0,0 +1,235 @@
1
+ import importlib
2
+ import logging
3
+ import os
4
+ import random
5
+ import string
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import Any, Optional
9
+
10
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
11
+ from ruamel.yaml import YAML
12
+
13
+ import runnable.context as context
14
+ from extensions.tasks.torch_config import EasyTorchConfig, TorchConfig
15
+ from runnable import Catalog, defaults
16
+ from runnable.datastore import StepAttempt
17
+ from runnable.tasks import BaseTaskType
18
+ from runnable.utils import get_module_and_attr_names
19
+
20
+ try:
21
+ from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
22
+ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
23
+
24
+ except ImportError:
25
+ raise ImportError("torch is not installed. Please install torch first.")
26
+
27
+
28
+ logger = logging.getLogger(defaults.LOGGER_NAME)
29
+
30
+
31
class TorchTaskType(BaseTaskType, TorchConfig):
    """Task type that runs a python callable under torch's elastic launcher.

    The user-facing options are inherited from ``TorchConfig``. ``command`` is
    the dotted path of the callable; it is handed to ``training_subprocess``
    through ``RUNNABLE_TORCH_*`` environment variables.
    """

    task_type: str = Field(default="torch", serialization_alias="command_type")
    catalog: Optional[Catalog] = Field(default=None, alias="catalog")
    command: str

    @model_validator(mode="before")
    @classmethod
    def check_secrets_and_returns(cls, data: Any) -> Any:
        """Reject ``secrets`` and ``returns``; neither is supported for torch.

        Raises:
            ValueError: if either key is present with a truthy value.
        """
        if isinstance(data, dict):
            if data.get("secrets"):
                raise ValueError("'secrets' is not supported for torch")
            if data.get("returns"):
                # The original message named 'secrets' here by mistake.
                raise ValueError("'returns' is not supported for torch")
        return data

    def get_summary(self) -> dict[str, Any]:
        """Return the serialized task (by alias, without ``None`` values)."""
        return self.model_dump(by_alias=True, exclude_none=True)

    @property
    def _context(self):
        # The run context is a module-level global of runnable.context.
        return context.run_context

    def _get_launch_config(self) -> LaunchConfig:
        """Build the torch ``LaunchConfig`` from the user's configuration.

        ``nnodes``, ``log_dir``, ``redirects``, ``tee`` and
        ``local_ranks_filter`` are declared with ``exclude=True`` on
        ``TorchConfig`` so that they never leak into ``LaunchConfig``. The
        same flag also removes them from ``self.model_dump()``, so they are
        read from the instance explicitly; otherwise the helper models below
        would always fall back to their defaults.
        """
        user_config = self.model_dump(exclude_none=True)

        internal_log_spec = InternalLogSpecs(
            log_dir=self.log_dir,
            redirects=self.redirects,
            tee=self.tee,
            local_ranks_filter=self.local_ranks_filter,
        )
        log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
            **internal_log_spec.model_dump(exclude_none=True)
        )

        # EasyTorchConfig ignores unknown keys (task_type, command, ...).
        easy_torch_config = EasyTorchConfig(**{**user_config, "nnodes": self.nnodes})

        launch_config = LaunchConfig(
            **easy_torch_config.model_dump(
                exclude_none=True,
            ),
            logs_specs=log_spec,
            run_id=self._context.run_id,
        )
        logger.info(f"launch_config: {launch_config}")
        return launch_config

    def _sync_logs_to_catalog(self) -> None:
        """Push the subprocess catalog and the torch logs to the catalog."""
        # proc_logs can only come from the subprocess (see training_subprocess)
        if Path("proc_logs").exists():
            self._context.catalog_handler.put(
                "proc_logs/**/*", allow_file_not_found_exc=True
            )

        # TODO: This is not working!!
        if self.log_dir:
            self._context.catalog_handler.put(
                self.log_dir + "/**/*", allow_file_not_found_exc=True
            )

    def execute_command(
        self,
        map_variable: defaults.TypeMapVariable = None,
    ):
        """Launch the training subprocesses and report the outcome.

        Args:
            map_variable: must be ``None``; torch tasks cannot be mapped.

        Returns:
            A ``StepAttempt`` with status SUCCESS or FAIL. Exceptions raised
            by the launcher are logged and reported as FAIL, not re-raised.
        """
        assert map_variable is None, "map_variable is not supported for torch"

        # _get_launch_config already logs the configuration.
        launch_config = self._get_launch_config()

        # ENV variables are shared with the subprocess, use that as communication
        os.environ["RUNNABLE_TORCH_COMMAND"] = self.command
        os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
            self._context.parameters_file or ""
        )
        os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id

        launcher = elastic_launch(
            launch_config,
            training_subprocess,
        )

        # Take the start time before the run so the attempt has a real duration.
        start_time = str(datetime.now())
        status = defaults.FAIL
        try:
            launcher()
            status = defaults.SUCCESS
        except Exception as e:
            logger.exception(f"Error executing torch task: {e}")
        finally:
            # Clean the environment first so a failing catalog sync cannot
            # leave RUNNABLE_TORCH_* variables behind.
            delete_env_vars_with_prefix("RUNNABLE_TORCH")
            self._sync_logs_to_catalog()

        attempt_log = StepAttempt(
            status=status,
            start_time=start_time,
            end_time=str(datetime.now()),
            attempt_number=1,
        )
        logger.info(f"attempt_log: {attempt_log}")

        return attempt_log
128
+
129
+
130
+ # This internal model makes it easier to extract the required fields
131
+ # of log specs from user specification.
132
+ # https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
133
class InternalLogSpecs(BaseModel):
    """Subset of the user configuration that describes torch's log specs.

    Unknown keys are ignored, so a full configuration dump can be passed in.
    ``redirects`` and ``tee`` are kept as strings and converted with
    ``Std.from_str`` when the model is serialized.
    """

    model_config = ConfigDict(extra="ignore")

    log_dir: Optional[str] = Field(default="torch_logs")
    redirects: str = Field(default="0")  # Std.NONE
    tee: str = Field(default="0")  # Std.NONE
    local_ranks_filter: Optional[set[int]] = Field(default=None)

    @field_serializer("redirects")
    def convert_redirects(self, redirects: str) -> Std | dict[int, Std]:
        """Serialize ``redirects`` through ``Std.from_str``."""
        parsed_redirects = Std.from_str(redirects)
        return parsed_redirects

    @field_serializer("tee")
    def convert_tee(self, tee: str) -> Std | dict[int, Std]:
        """Serialize ``tee`` through ``Std.from_str``."""
        parsed_tee = Std.from_str(tee)
        return parsed_tee
148
+
149
+
150
def delete_env_vars_with_prefix(prefix):
    """Remove every environment variable whose name starts with ``prefix``.

    Args:
        prefix: leading string of the variable names to remove.

    Variables that do not match are left untouched; calling this when nothing
    matches is a no-op.
    """
    # Snapshot the names first: os.environ must not change while iterated.
    doomed = [name for name in os.environ if name.startswith(prefix)]

    for name in doomed:
        # pop() with a default tolerates a variable that is already gone.
        os.environ.pop(name, None)
161
+
162
+
163
def training_subprocess():
    """Run the user's function inside a process started by ``elastic_launch``.

    This runs in a subprocess and is unrelated to the actual task execution
    in the main process, so cataloging and the run log store are configured
    here independently.

    The main process communicates through these environment variables:

        RUNNABLE_TORCH_COMMAND: dotted path of the function to run (required)
        RUNNABLE_TORCH_PARAMETERS_FILES: parameters file, may be empty
        RUNNABLE_TORCH_RUN_ID: run id of the main process

    Raises:
        AssertionError: if RUNNABLE_TORCH_COMMAND is not set.
        Exception: if the run log of the job reports a failure.
    """
    from runnable import PythonJob  # noqa: F401

    command = os.environ.get("RUNNABLE_TORCH_COMMAND")
    assert command, "Command is not provided"

    run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
    parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")

    # <run id>-<RANK>-<3 random letters>
    process_run_id = (
        run_id
        + "-"
        + os.environ.get("RANK", "")
        + "-"
        + "".join(random.choices(string.ascii_lowercase, k=3))
    )
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

    # In this subprocess there should not be any RUNNABLE environment variables
    delete_env_vars_with_prefix("RUNNABLE_")

    module_name, func_name = get_module_and_attr_names(command)
    module = importlib.import_module(module_name)

    callable_obj = getattr(module, func_name)

    job = PythonJob(function=callable_obj)

    # Point the file-system catalog of the job at proc_logs; the main process
    # looks for that folder after the launcher returns.
    config_content = {
        "catalog": {"type": "file-system", "config": {"catalog_location": "proc_logs"}}
    }

    temp_config_file = Path("runnable-config.yaml")
    with open(str(temp_config_file), "w", encoding="utf-8") as config_file:
        yaml = YAML(typ="safe", pure=True)
        yaml.dump(config_content, config_file)

    try:
        job.execute(
            parameters_file=parameters_files,
            job_id=process_run_id,
        )
    finally:
        # Always remove the temp config, even when the job raises.
        # missing_ok: the file name is fixed, so another process working in
        # the same directory may already have removed it.
        # NOTE(review): processes sharing a directory also share this file
        # while it exists - confirm that is acceptable.
        temp_config_file.unlink(missing_ok=True)

    from runnable.context import run_context

    job_log = run_context.run_log_store.get_run_log_by_id(run_id=run_context.run_id)

    if job_log.status == defaults.FAIL:
        raise Exception(f"Job {process_run_id} failed")
@@ -0,0 +1,76 @@
1
+ from enum import Enum
2
+ from typing import Any, Optional
3
+
4
+ from pydantic import BaseModel, ConfigDict, Field, computed_field
5
+
6
+
7
class StartMethod(str, Enum):
    """Names of the process start methods.

    ``spawn`` is used as the default of ``TorchConfig.start_method``.
    """

    spawn = "spawn"
    fork = "fork"
    forkserver = "forkserver"
11
+
12
+
13
+ ## The idea is the following:
14
+ # Users can configure any of the options present in TorchConfig class.
15
+ # The LaunchConfig class will be created from TorchConfig.
16
+ # The LogSpecs is sent as a parameter to the launch config.
17
+
18
+ ## NO idea of standalone and how to send it
19
+
20
+
21
+ # The user sees this as part of the config of the node.
22
+ # It is kept as similar as possible to torchrun
23
class TorchConfig(BaseModel):
    """Torch options that the user sees as part of the configuration.

    It is kept as similar as possible to torchrun. Unknown options are
    rejected (``extra="forbid"``). Fields marked ``exclude=True`` are left out
    of the serialized model; they are consumed separately (node range and log
    specs) when the LaunchConfig is created.
    """

    model_config = ConfigDict(extra="forbid")

    # excluded as LaunchConfig requires min and max nodes
    nnodes: str = Field(default="1:1", exclude=True, description="min:max")
    nproc_per_node: int = Field(default=1, description="Number of processes per node")

    # will be used to create the log specs
    # But they are excluded from dump as logs specs is a class for LaunchConfig
    # redirects and tee are strings in the format understood by Std.from_str:
    # from_str("0") -> Std.NONE
    # from_str("1") -> Std.OUT
    # from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
    log_dir: Optional[str] = Field(default="torch_logs", exclude=True)
    redirects: str = Field(default="0", exclude=True)  # Std.NONE
    tee: str = Field(default="0", exclude=True)  # Std.NONE
    local_ranks_filter: Optional[set[int]] = Field(default=None, exclude=True)

    role: str | None = Field(default=None)

    # run_id would be the run_id of the context
    # and sent at the creation of the LaunchConfig

    # This section is about the communication between nodes/processes
    rdzv_backend: str | None = Field(default="static")
    rdzv_endpoint: str | None = Field(default="")
    rdzv_configs: dict[str, Any] = Field(default_factory=dict)
    rdzv_timeout: int | None = Field(default=None)

    # None values are dropped (exclude_none) when the LaunchConfig is built
    max_restarts: int | None = Field(default=None)
    monitor_interval: float | None = Field(default=None)
    start_method: str | None = Field(default=StartMethod.spawn)
    log_line_prefix_template: str | None = Field(default=None)
    local_addr: Optional[str] = None

    # Options of torchrun that are not exposed (yet):
    # https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L753
    # master_addr: str | None = Field(default="localhost")
    # master_port: str | None = Field(default="29500")
    # training_script: str = Field(default="dummy_training_script")
    # training_script_args: str = Field(default="")
62
+
63
+
64
class EasyTorchConfig(TorchConfig):
    """TorchConfig variant that exposes ``nnodes`` as min/max node counts.

    Unknown keys are ignored, so it can be created from the dump of a larger
    model. ``min_nodes`` and ``max_nodes`` are computed fields and therefore
    part of ``model_dump()``.

    ``nnodes`` is either ``"min:max"`` or, as with torchrun, a single number
    that is used for both bounds.
    """

    model_config = ConfigDict(extra="ignore")

    # TODO: Validate min < max
    @computed_field  # type: ignore
    @property
    def min_nodes(self) -> int:
        """Lower bound of the node range (text before the first ``:``)."""
        return int(self.nnodes.split(":")[0])

    @computed_field  # type: ignore
    @property
    def max_nodes(self) -> int:
        """Upper bound of the node range.

        Falls back to the single given value when ``nnodes`` has no ``:``;
        previously that raised IndexError.
        """
        bounds = self.nnodes.split(":")
        return int(bounds[1] if len(bounds) > 1 else bounds[0])
runnable/__init__.py CHANGED
@@ -31,7 +31,8 @@ from runnable.sdk import ( # noqa
31
31
  ShellTask,
32
32
  Stub,
33
33
  Success,
34
- Torch,
34
+ TorchJob,
35
+ TorchTask,
35
36
  metric,
36
37
  pickled,
37
38
  )
runnable/entrypoints.py CHANGED
@@ -129,6 +129,7 @@ def prepare_configurations(
129
129
  ServiceConfig,
130
130
  runnable_defaults.get("job-executor", defaults.DEFAULT_JOB_EXECUTOR),
131
131
  )
132
+
132
133
  assert job_executor_config, "Job executor is not provided"
133
134
  configured_executor = utils.get_provider_by_name_and_type(
134
135
  "job_executor", job_executor_config
runnable/sdk.py CHANGED
@@ -44,10 +44,10 @@ from runnable.tasks import TaskReturns
44
44
  logger = logging.getLogger(defaults.LOGGER_NAME)
45
45
 
46
46
  StepType = Union[
47
- "Stub", "PythonTask", "NotebookTask", "ShellTask", "Parallel", "Map", "Torch"
47
+ "Stub", "PythonTask", "NotebookTask", "ShellTask", "Parallel", "Map", "TorchTask"
48
48
  ]
49
49
  if TYPE_CHECKING:
50
- from extensions.nodes.torch import TorchNode
50
+ pass
51
51
 
52
52
 
53
53
  def pickled(name: str) -> TaskReturns:
@@ -191,6 +191,34 @@ class BaseTask(BaseTraversal):
191
191
  )
192
192
 
193
193
 
194
+ class TorchTask(BaseTask, TorchConfig):
195
+ function: Callable = Field(exclude=True)
196
+
197
+ @field_validator("returns", mode="before")
198
+ @classmethod
199
+ def serialize_returns(
200
+ cls, returns: List[Union[str, TaskReturns]]
201
+ ) -> List[TaskReturns]:
202
+ assert len(returns) == 0, "Torch tasks cannot return any variables"
203
+ return []
204
+
205
+ @computed_field
206
+ def command_type(self) -> str:
207
+ return "torch"
208
+
209
+ @computed_field
210
+ def command(self) -> str:
211
+ module = self.function.__module__
212
+ name = self.function.__name__
213
+
214
+ return f"{module}.{name}"
215
+
216
+ def create_job(self) -> RunnableTask:
217
+ self.terminate_with_success = True
218
+ node = self.create_node()
219
+ return node.executable
220
+
221
+
194
222
  class PythonTask(BaseTask):
195
223
  """
196
224
  An execution node of the pipeline of python functions.
@@ -459,43 +487,6 @@ class Stub(BaseTraversal):
459
487
  return StubNode.parse_from_config(self.model_dump(exclude_none=True))
460
488
 
461
489
 
462
- class Torch(BaseTraversal, TorchConfig):
463
- function: Callable = Field(exclude=True)
464
- catalog: Optional[Catalog] = Field(default=None, alias="catalog")
465
- overrides: Dict[str, Any] = Field(default_factory=dict, alias="overrides")
466
- returns: List[Union[str, TaskReturns]] = Field(
467
- default_factory=list, alias="returns"
468
- )
469
- secrets: List[str] = Field(default_factory=list)
470
-
471
- @computed_field
472
- def command_type(self) -> str:
473
- return "python"
474
-
475
- @computed_field
476
- def command(self) -> str:
477
- module = self.function.__module__
478
- name = self.function.__name__
479
-
480
- return f"{module}.{name}"
481
-
482
- def create_node(self) -> TorchNode:
483
- if not self.next_node:
484
- if not (self.terminate_with_failure or self.terminate_with_success):
485
- raise AssertionError(
486
- "A node not being terminated must have a user defined next node"
487
- )
488
-
489
- if self.on_failure:
490
- self.on_failure = self.on_failure.steps[0].name # type: ignore
491
-
492
- from extensions.nodes.torch import TorchNode
493
-
494
- return TorchNode.parse_from_config(
495
- self.model_dump(exclude_none=True, by_alias=True)
496
- )
497
-
498
-
499
490
  class Parallel(BaseTraversal):
500
491
  """
501
492
  A node that executes multiple branches in parallel.
@@ -685,6 +676,7 @@ class Pipeline(BaseModel):
685
676
  terminal_step: StepType = self.steps[-1]
686
677
  if not terminal_step.terminate_with_failure:
687
678
  terminal_step.terminate_with_success = True
679
+ terminal_step.next_node = "success"
688
680
 
689
681
  # assert that there is only one termination node with success or failure
690
682
  # Assert that there are no duplicate step names
@@ -965,7 +957,7 @@ class BaseJob(BaseModel):
965
957
 
966
958
 
967
959
  class PythonJob(BaseJob):
968
- function: Callable = Field(exclude=True)
960
+ function: Callable = Field()
969
961
 
970
962
  @property
971
963
  @computed_field
@@ -975,14 +967,27 @@ class PythonJob(BaseJob):
975
967
 
976
968
  return f"{module}.{name}"
977
969
 
970
+ # TODO: can this be simplified to just self.model_dump(exclude_none=True)?
978
971
  def get_task(self) -> RunnableTask:
979
972
  # Piggy bank on existing tasks as a hack
980
973
  task = PythonTask(
981
974
  name="dummy",
982
975
  terminate_with_success=True,
983
- returns=self.returns,
984
- secrets=self.secrets,
985
- function=self.function,
976
+ **self.model_dump(exclude_defaults=True, exclude_none=True),
977
+ )
978
+ return task.create_node().executable
979
+
980
+
981
+ class TorchJob(BaseJob, TorchConfig):
982
+ function: Callable = Field()
983
+ # min and max should always be 1
984
+
985
+ def get_task(self) -> RunnableTask:
986
+ # Piggy bank on existing tasks as a hack
987
+ task = TorchTask(
988
+ name="dummy",
989
+ terminate_with_success=True,
990
+ **self.model_dump(exclude_defaults=True, exclude_none=True),
986
991
  )
987
992
  return task.create_node().executable
988
993
 
@@ -998,10 +1003,7 @@ class NotebookJob(BaseJob):
998
1003
  task = NotebookTask(
999
1004
  name="dummy",
1000
1005
  terminate_with_success=True,
1001
- returns=self.returns,
1002
- secrets=self.secrets,
1003
- notebook=self.notebook,
1004
- optional_ploomber_args=self.optional_ploomber_args,
1006
+ **self.model_dump(exclude_defaults=True, exclude_none=True),
1005
1007
  )
1006
1008
  return task.create_node().executable
1007
1009
 
@@ -1014,8 +1016,6 @@ class ShellJob(BaseJob):
1014
1016
  task = ShellTask(
1015
1017
  name="dummy",
1016
1018
  terminate_with_success=True,
1017
- returns=self.returns,
1018
- secrets=self.secrets,
1019
- command=self.command,
1019
+ **self.model_dump(exclude_defaults=True, exclude_none=True),
1020
1020
  )
1021
1021
  return task.create_node().executable
runnable/tasks.py CHANGED
@@ -28,7 +28,6 @@ from runnable.datastore import (
28
28
  from runnable.defaults import TypeMapVariable
29
29
 
30
30
  logger = logging.getLogger(defaults.LOGGER_NAME)
31
- logging.getLogger("stevedore").setLevel(logging.CRITICAL)
32
31
 
33
32
 
34
33
  class TeeIO(io.StringIO):
@@ -49,8 +48,7 @@ class TeeIO(io.StringIO):
49
48
  self.output_stream.flush()
50
49
 
51
50
 
52
- buffer = TeeIO()
53
- sys.stdout = buffer
51
+ sys.stdout = TeeIO()
54
52
 
55
53
 
56
54
  class TaskReturns(BaseModel):
@@ -761,6 +759,8 @@ def create_task(kwargs_for_init) -> BaseTaskType:
761
759
  tasks.BaseTaskType: The command object
762
760
  """
763
761
  # The dictionary cannot be modified
762
+
763
+ print(kwargs_for_init)
764
764
  kwargs = kwargs_for_init.copy()
765
765
  command_type = kwargs.pop("command_type", defaults.COMMAND_TYPE)
766
766
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: runnable
3
- Version: 0.31.0
3
+ Version: 0.32.0
4
4
  Summary: Add your description here
5
5
  Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
6
6
  License-File: LICENSE
@@ -28,6 +28,7 @@ Provides-Extra: s3
28
28
  Requires-Dist: cloudpathlib[s3]; extra == 's3'
29
29
  Provides-Extra: torch
30
30
  Requires-Dist: torch>=2.6.0; extra == 'torch'
31
+ Requires-Dist: torchvision>=0.21.0; extra == 'torch'
31
32
  Description-Content-Type: text/markdown
32
33
 
33
34
 
@@ -16,8 +16,8 @@ extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqy
16
16
  extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
18
18
  extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
19
- extensions/nodes/torch.py,sha256=RUelXV7Pa4U5F7Ww3cfRG0Oaz9SkYF3b_CmpFHlpbyI,6885
20
- extensions/nodes/torch_config.py,sha256=jfUtkwCYolyKVcFxiMjjwm63yv-HjTKvSQR8JLA7sZg,3151
19
+ extensions/nodes/torch.py,sha256=h3x5931ePBNckeSXM3JFjSoUnxmIWvDyEpn1AI9TKaU,9347
20
+ extensions/nodes/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
21
21
  extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
22
  extensions/pipeline_executor/__init__.py,sha256=wfigTL2T9OHrmE8b2Ydmb8h6hr-oF--Yc2FectC7WaY,24623
23
23
  extensions/pipeline_executor/argo.py,sha256=AEGSWVZulBL6EsvbVCaeBeTl2m_t5ymc6RFpMKhivis,37946
@@ -40,13 +40,15 @@ extensions/run_log_store/db/integration_FF.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeR
40
40
  extensions/secrets/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
41
41
  extensions/secrets/dotenv.py,sha256=nADHXI6KJ_LUYOIe5EbtYH-21OBebSNVr0Pjb1GlZ7w,1573
42
42
  extensions/secrets/pyproject.toml,sha256=mLJNImNcBlbLKHh-0ugVWT9V83R4RibyyYDtBCSqVF4,282
43
- runnable/__init__.py,sha256=swvqdCjeddn40o4zjsluyahdVcU0r1arSRrxmRsvFEQ,673
43
+ extensions/tasks/torch.py,sha256=R0J_Q6SRAW2Ii0XQbXaaBWTah8TYs4P_48j2M1bIXeA,7983
44
+ extensions/tasks/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
45
+ runnable/__init__.py,sha256=3ZKuvGEkY_zHVQlJtarXd4jkjICxjgnw-bbKN_5SiJI,691
44
46
  runnable/catalog.py,sha256=4msQxLhLKlsDDrHFnGauPYe-Or-q9g8_RYCn_4dpxaU,4466
45
47
  runnable/cli.py,sha256=3BiKSj95h2Drn__YlchMPZ5rBMafuRb2OGIsVpbsO5Y,8788
46
48
  runnable/context.py,sha256=by5uepmuCP0dmM9BmsliXihSes5QEFejwAsmekcqylE,1388
47
49
  runnable/datastore.py,sha256=ZobM1aVkgeUJ2fZYt63IFDsoNzObwc93hdByegS5YKQ,32396
48
50
  runnable/defaults.py,sha256=3o9IVGryyCE6PoQTOoaIaHHTbJGEzmdXMcwzOhwAYoI,3518
49
- runnable/entrypoints.py,sha256=cDbhtmLUWdBh9K6hNusfQpSd5NadcX8V1K2JEDf_YAg,18984
51
+ runnable/entrypoints.py,sha256=1xCbWVUQLGmg5gkWnAVWFLAUf6j4avP9azX_vuGQUMY,18985
50
52
  runnable/exceptions.py,sha256=LFbp0-Qxg2PAMLEVt7w2whhBxSG-5pzUEv5qN-Rc4_c,3003
51
53
  runnable/executor.py,sha256=UOsYJ3NkTGw4FTR0iePX7AOJzY7vODhZ62aqrwVMO1c,15143
52
54
  runnable/graph.py,sha256=poQz5zcvq89ju_u5sYlunQLPbHnXTaUmjcvstPwvT4U,16536
@@ -54,12 +56,12 @@ runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
54
56
  runnable/nodes.py,sha256=d1eLttMAcV7CTwTEqOuNwZqItANoLUkXJ73Xp-srlyI,17811
55
57
  runnable/parameters.py,sha256=sT3DNGczivP9z7r4Cp_brbudg1z4J-zjmvrq3ppIrVs,5089
56
58
  runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
57
- runnable/sdk.py,sha256=NZVQGaL4Zm2hwloRmqEgp8UPbBg9hY1abQGYnOgniPI,35128
59
+ runnable/sdk.py,sha256=J1PyiHQD2v_0JaqHjY7xSaXwCUMi_mCNr70TsC-SFZU,35012
58
60
  runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
59
- runnable/tasks.py,sha256=Qb1IhVxHv68E7vf3M3YCf7MGRHyjmsEEYBpEpiZ4mRI,29062
61
+ runnable/tasks.py,sha256=_A0pcTyOGQL-72AicOxracsrwfs2Vg0r4mQyxz3k6Iw,29016
60
62
  runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
61
- runnable-0.31.0.dist-info/METADATA,sha256=9c3Ixkq-Kl0_hiQfDX-KwtSAdSWzMRLJMfEze2oVQhE,10115
62
- runnable-0.31.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
63
- runnable-0.31.0.dist-info/entry_points.txt,sha256=PrjKrlfXPZaV_7hz8orGu4FDnatLqnhPOXljyllszdw,1880
64
- runnable-0.31.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
65
- runnable-0.31.0.dist-info/RECORD,,
63
+ runnable-0.32.0.dist-info/METADATA,sha256=t44gRxxaRugnqaRY9gGwweGT0OLvo_inlC3jxrhP3sg,10168
64
+ runnable-0.32.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
+ runnable-0.32.0.dist-info/entry_points.txt,sha256=uWHbbOSj0jlG54tFHw377xKkfVbjWvb_1Y9L_LgjJ0Q,1925
66
+ runnable-0.32.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
+ runnable-0.32.0.dist-info/RECORD,,
@@ -49,3 +49,4 @@ env-secrets = runnable.secrets:EnvSecretsManager
49
49
  notebook = runnable.tasks:NotebookTaskType
50
50
  python = runnable.tasks:PythonTaskType
51
51
  shell = runnable.tasks:ShellTaskType
52
+ torch = extensions.tasks.torch:TorchTaskType