runnable 0.34.0a1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of runnable might be problematic. See the registry's page for this release for more details.

Files changed (49)
  1. extensions/catalog/any_path.py +13 -2
  2. extensions/job_executor/__init__.py +7 -5
  3. extensions/job_executor/emulate.py +106 -0
  4. extensions/job_executor/k8s.py +8 -8
  5. extensions/job_executor/local_container.py +13 -14
  6. extensions/nodes/__init__.py +0 -0
  7. extensions/nodes/conditional.py +243 -0
  8. extensions/nodes/fail.py +72 -0
  9. extensions/nodes/map.py +350 -0
  10. extensions/nodes/parallel.py +159 -0
  11. extensions/nodes/stub.py +89 -0
  12. extensions/nodes/success.py +72 -0
  13. extensions/nodes/task.py +92 -0
  14. extensions/pipeline_executor/__init__.py +27 -27
  15. extensions/pipeline_executor/argo.py +52 -46
  16. extensions/pipeline_executor/emulate.py +112 -0
  17. extensions/pipeline_executor/local.py +4 -4
  18. extensions/pipeline_executor/local_container.py +19 -79
  19. extensions/pipeline_executor/mocked.py +5 -9
  20. extensions/pipeline_executor/retry.py +6 -10
  21. runnable/__init__.py +2 -11
  22. runnable/catalog.py +6 -23
  23. runnable/cli.py +145 -48
  24. runnable/context.py +520 -28
  25. runnable/datastore.py +51 -54
  26. runnable/defaults.py +12 -34
  27. runnable/entrypoints.py +82 -440
  28. runnable/exceptions.py +35 -34
  29. runnable/executor.py +13 -20
  30. runnable/gantt.py +1141 -0
  31. runnable/graph.py +1 -1
  32. runnable/names.py +1 -1
  33. runnable/nodes.py +20 -16
  34. runnable/parameters.py +108 -51
  35. runnable/sdk.py +125 -204
  36. runnable/tasks.py +62 -85
  37. runnable/utils.py +6 -268
  38. runnable-1.0.0.dist-info/METADATA +122 -0
  39. runnable-1.0.0.dist-info/RECORD +73 -0
  40. {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/entry_points.txt +9 -8
  41. extensions/nodes/nodes.py +0 -778
  42. extensions/nodes/torch.py +0 -273
  43. extensions/nodes/torch_config.py +0 -76
  44. extensions/tasks/torch.py +0 -286
  45. extensions/tasks/torch_config.py +0 -76
  46. runnable-0.34.0a1.dist-info/METADATA +0 -267
  47. runnable-0.34.0a1.dist-info/RECORD +0 -67
  48. {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/WHEEL +0 -0
  49. {runnable-0.34.0a1.dist-info → runnable-1.0.0.dist-info}/licenses/LICENSE +0 -0
extensions/nodes/torch.py DELETED
@@ -1,273 +0,0 @@
1
- import importlib
2
- import logging
3
- import os
4
- import random
5
- import string
6
- from datetime import datetime
7
- from pathlib import Path
8
- from typing import Any, Callable, Optional
9
-
10
- from pydantic import BaseModel, ConfigDict, Field, field_serializer
11
-
12
- from extensions.nodes.torch_config import EasyTorchConfig, TorchConfig
13
- from runnable import PythonJob, datastore, defaults
14
- from runnable.datastore import StepLog
15
- from runnable.nodes import ExecutableNode
16
- from runnable.tasks import PythonTaskType, create_task
17
- from runnable.utils import TypeMapVariable
18
-
19
- logger = logging.getLogger(defaults.LOGGER_NAME)
20
-
21
- try:
22
- from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
23
- from torch.distributed.launcher.api import LaunchConfig, elastic_launch
24
- except ImportError:
25
- logger.exception("Torch is not installed. Please install torch first.")
26
- raise Exception("Torch is not installed. Please install torch first.")
27
-
28
-
29
- def training_subprocess():
30
- """
31
- This function is called by the torch.distributed.launcher.api.elastic_launch
32
- It happens in a subprocess and is responsible for executing the user's function
33
-
34
- It is unrelated to the actual node execution, so any cataloging, run_log_store should be
35
- handled to match to main process.
36
-
37
- We have these variables to use:
38
-
39
- os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
40
- os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
41
- self._context.parameters_file or ""
42
- )
43
- os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
44
- os.environ["RUNNABLE_TORCH_COPY_CONTENTS_TO"] = (
45
- self._context.catalog_handler.compute_data_folder
46
- )
47
- os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
48
-
49
- """
50
- command = os.environ.get("RUNNABLE_TORCH_COMMAND")
51
- run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
52
- parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
53
-
54
- process_run_id = (
55
- run_id
56
- + "-"
57
- + os.environ.get("RANK", "")
58
- + "-"
59
- + "".join(random.choices(string.ascii_lowercase, k=3))
60
- )
61
- os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
62
-
63
- delete_env_vars_with_prefix("RUNNABLE_")
64
-
65
- func = get_callable_from_dotted_path(command)
66
-
67
- # The job runs with the default configuration
68
- # ALl the execution logs are stored in .catalog
69
- job = PythonJob(function=func)
70
-
71
- job.execute(
72
- parameters_file=parameters_files,
73
- job_id=process_run_id,
74
- )
75
-
76
- from runnable.context import run_context
77
-
78
- job_log = run_context.run_log_store.get_run_log_by_id(run_id=run_context.run_id)
79
-
80
- if job_log.status == defaults.FAIL:
81
- raise Exception(f"Job {process_run_id} failed")
82
-
83
-
84
- # TODO: Can this be utils.get_module_and_attr_names
85
- def get_callable_from_dotted_path(dotted_path) -> Callable:
86
- try:
87
- # Split the path into module path and callable object
88
- module_path, callable_name = dotted_path.rsplit(".", 1)
89
-
90
- # Import the module
91
- module = importlib.import_module(module_path)
92
-
93
- # Get the callable from the module
94
- callable_obj = getattr(module, callable_name)
95
-
96
- # Check if the object is callable
97
- if not callable(callable_obj):
98
- raise TypeError(f"The object {callable_name} is not callable.")
99
-
100
- return callable_obj
101
-
102
- except (ImportError, AttributeError, ValueError) as e:
103
- raise ImportError(f"Could not import '{dotted_path}'.") from e
104
-
105
-
106
- def delete_env_vars_with_prefix(prefix):
107
- to_delete = [] # List to keep track of variables to delete
108
-
109
- # Iterate over a list of all environment variable keys
110
- for var in os.environ:
111
- if var.startswith(prefix):
112
- to_delete.append(var)
113
-
114
- # Delete each of the variables collected
115
- for var in to_delete:
116
- del os.environ[var]
117
-
118
-
119
- # TODO: The design of this class is not final
120
- class TorchNode(ExecutableNode, TorchConfig):
121
- node_type: str = Field(default="torch", serialization_alias="type")
122
- executable: PythonTaskType = Field(exclude=True)
123
-
124
- # Similar to TaskNode
125
- model_config = ConfigDict(extra="allow")
126
-
127
- def get_summary(self) -> dict[str, Any]:
128
- summary = {
129
- "name": self.name,
130
- "type": self.node_type,
131
- }
132
-
133
- return summary
134
-
135
- @classmethod
136
- def parse_from_config(cls, config: dict[str, Any]) -> "TorchNode":
137
- task_config = {
138
- k: v for k, v in config.items() if k not in TorchNode.model_fields.keys()
139
- }
140
- node_config = {
141
- k: v for k, v in config.items() if k in TorchNode.model_fields.keys()
142
- }
143
-
144
- executable = create_task(task_config)
145
-
146
- assert isinstance(executable, PythonTaskType)
147
- return cls(executable=executable, **node_config, **task_config)
148
-
149
- def get_launch_config(self) -> LaunchConfig:
150
- internal_log_spec = InternalLogSpecs(**self.model_dump(exclude_none=True))
151
- log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
152
- **internal_log_spec.model_dump(exclude_none=True)
153
- )
154
- easy_torch_config = EasyTorchConfig(
155
- **self.model_dump(
156
- exclude_none=True,
157
- )
158
- )
159
-
160
- launch_config = LaunchConfig(
161
- **easy_torch_config.model_dump(
162
- exclude_none=True,
163
- ),
164
- logs_specs=log_spec,
165
- run_id=self._context.run_id,
166
- )
167
- logger.info(f"launch_config: {launch_config}")
168
- return launch_config
169
-
170
- def execute(
171
- self,
172
- mock=False,
173
- map_variable: TypeMapVariable = None,
174
- attempt_number: int = 1,
175
- ) -> StepLog:
176
- assert (
177
- map_variable is None or not map_variable
178
- ), "TorchNode does not support map_variable"
179
-
180
- step_log = self._context.run_log_store.get_step_log(
181
- self._get_step_log_name(map_variable), self._context.run_id
182
- )
183
-
184
- # Attempt to call the function or elastic launch
185
- launch_config = self.get_launch_config()
186
- logger.info(f"launch_config: {launch_config}")
187
-
188
- # ENV variables are shared with the subprocess, use that as communication
189
- os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
190
- os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
191
- self._context.parameters_file or ""
192
- )
193
- os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
194
-
195
- launcher = elastic_launch(
196
- launch_config,
197
- training_subprocess,
198
- )
199
- try:
200
- launcher()
201
- attempt_log = datastore.StepAttempt(
202
- status=defaults.SUCCESS,
203
- start_time=str(datetime.now()),
204
- end_time=str(datetime.now()),
205
- attempt_number=attempt_number,
206
- )
207
- except Exception as e:
208
- attempt_log = datastore.StepAttempt(
209
- status=defaults.FAIL,
210
- start_time=str(datetime.now()),
211
- end_time=str(datetime.now()),
212
- attempt_number=attempt_number,
213
- )
214
- logger.error(f"Error executing TorchNode: {e}")
215
- finally:
216
- # This can only come from the subprocess
217
- if Path(".catalog").exists():
218
- os.rename(".catalog", "proc_logs")
219
- # Move .catalog and torch_logs to the parent node's catalog location
220
- self._context.catalog_handler.put(
221
- "proc_logs/**/*", allow_file_not_found_exc=True
222
- )
223
-
224
- # TODO: This is not working!!
225
- if self.log_dir:
226
- self._context.catalog_handler.put(
227
- self.log_dir + "/**/*", allow_file_not_found_exc=True
228
- )
229
-
230
- delete_env_vars_with_prefix("RUNNABLE_TORCH")
231
-
232
- logger.info(f"attempt_log: {attempt_log}")
233
- logger.info(f"Step {self.name} completed with status: {attempt_log.status}")
234
-
235
- step_log.status = attempt_log.status
236
- step_log.attempts.append(attempt_log)
237
-
238
- return step_log
239
-
240
- def fan_in(self, map_variable: dict[str, str | int | float] | None = None):
241
- # Destroy the service
242
- # Destroy the statefulset
243
- assert (
244
- map_variable is None or not map_variable
245
- ), "TorchNode does not support map_variable"
246
-
247
- def fan_out(self, map_variable: dict[str, str | int | float] | None = None):
248
- # Create a service
249
- # Create a statefulset
250
- # Gather the IPs and set them as parameters downstream
251
- assert (
252
- map_variable is None or not map_variable
253
- ), "TorchNode does not support map_variable"
254
-
255
-
256
- # This internal model makes it easier to extract the required fields
257
- # of log specs from user specification.
258
- # https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
259
- class InternalLogSpecs(BaseModel):
260
- log_dir: Optional[str] = Field(default="torch_logs")
261
- redirects: str = Field(default="0") # Std.NONE
262
- tee: str = Field(default="0") # Std.NONE
263
- local_ranks_filter: Optional[set[int]] = Field(default=None)
264
-
265
- model_config = ConfigDict(extra="ignore")
266
-
267
- @field_serializer("redirects")
268
- def convert_redirects(self, redirects: str) -> Std | dict[int, Std]:
269
- return Std.from_str(redirects)
270
-
271
- @field_serializer("tee")
272
- def convert_tee(self, tee: str) -> Std | dict[int, Std]:
273
- return Std.from_str(tee)
extensions/nodes/torch_config.py DELETED
@@ -1,76 +0,0 @@
1
- from enum import Enum
2
- from typing import Any, Optional
3
-
4
- from pydantic import BaseModel, ConfigDict, Field, computed_field
5
-
6
-
7
- class StartMethod(str, Enum):
8
- spawn = "spawn"
9
- fork = "fork"
10
- forkserver = "forkserver"
11
-
12
-
13
- ## The idea is the following:
14
- # Users can configure any of the options present in TorchConfig class.
15
- # The LaunchConfig class will be created from TorchConfig.
16
- # The LogSpecs is sent as a parameter to the launch config.
17
-
18
- ## NO idea of standalone and how to send it
19
-
20
-
21
- # The user sees this as part of the config of the node.
22
- # It is kept as similar as possible to torchrun
23
- class TorchConfig(BaseModel):
24
- model_config = ConfigDict(extra="forbid")
25
-
26
- # excluded as LaunchConfig requires min and max nodes
27
- nnodes: str = Field(default="1:1", exclude=True, description="min:max")
28
- nproc_per_node: int = Field(default=1, description="Number of processes per node")
29
-
30
- # will be used to create the log specs
31
- # But they are excluded from dump as logs specs is a class for LaunchConfig
32
- # from_str("0") -> Std.NONE
33
- # from_str("1") -> Std.OUT
34
- # from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
35
- log_dir: Optional[str] = Field(default="torch_logs", exclude=True)
36
- redirects: str = Field(default="0", exclude=True) # Std.NONE
37
- tee: str = Field(default="0", exclude=True) # Std.NONE
38
- local_ranks_filter: Optional[set[int]] = Field(default=None, exclude=True)
39
-
40
- role: str | None = Field(default=None)
41
-
42
- # run_id would be the run_id of the context
43
- # and sent at the creation of the LaunchConfig
44
-
45
- # This section is about the communication between nodes/processes
46
- rdzv_backend: str | None = Field(default="static")
47
- rdzv_endpoint: str | None = Field(default="")
48
- rdzv_configs: dict[str, Any] = Field(default_factory=dict)
49
- rdzv_timeout: int | None = Field(default=None)
50
-
51
- max_restarts: int | None = Field(default=None)
52
- monitor_interval: float | None = Field(default=None)
53
- start_method: str | None = Field(default=StartMethod.spawn)
54
- log_line_prefix_template: str | None = Field(default=None)
55
- local_addr: Optional[str] = None
56
-
57
- # https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L753
58
- # master_addr: str | None = Field(default="localhost")
59
- # master_port: str | None = Field(default="29500")
60
- # training_script: str = Field(default="dummy_training_script")
61
- # training_script_args: str = Field(default="")
62
-
63
-
64
- class EasyTorchConfig(TorchConfig):
65
- model_config = ConfigDict(extra="ignore")
66
-
67
- # TODO: Validate min < max
68
- @computed_field # type: ignore
69
- @property
70
- def min_nodes(self) -> int:
71
- return int(self.nnodes.split(":")[0])
72
-
73
- @computed_field # type: ignore
74
- @property
75
- def max_nodes(self) -> int:
76
- return int(self.nnodes.split(":")[1])
extensions/tasks/torch.py DELETED
@@ -1,286 +0,0 @@
1
- import importlib
2
- import logging
3
- import os
4
- import random
5
- import string
6
- from datetime import datetime
7
- from pathlib import Path
8
- from typing import Any, Optional
9
-
10
- from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
11
- from ruamel.yaml import YAML
12
-
13
- import runnable.context as context
14
- from extensions.tasks.torch_config import EasyTorchConfig, TorchConfig
15
- from runnable import Catalog, defaults
16
- from runnable.datastore import StepAttempt
17
- from runnable.tasks import BaseTaskType
18
- from runnable.utils import get_module_and_attr_names
19
-
20
- logger = logging.getLogger(defaults.LOGGER_NAME)
21
-
22
- logger = logging.getLogger(defaults.LOGGER_NAME)
23
-
24
- try:
25
- from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
26
- from torch.distributed.launcher.api import LaunchConfig, elastic_launch
27
-
28
- except ImportError as e:
29
- logger.exception("torch is not installed")
30
- raise Exception("torch is not installed") from e
31
-
32
-
33
- def get_min_max_nodes(nnodes: str) -> tuple[int, int]:
34
- min_nodes, max_nodes = (int(x) for x in nnodes.split(":"))
35
- return min_nodes, max_nodes
36
-
37
-
38
- class TorchTaskType(BaseTaskType, TorchConfig):
39
- task_type: str = Field(default="torch", serialization_alias="command_type")
40
- catalog: Optional[Catalog] = Field(default=None, alias="catalog")
41
- command: str
42
-
43
- @model_validator(mode="before")
44
- @classmethod
45
- def check_secrets_and_returns(cls, data: Any) -> Any:
46
- if isinstance(data, dict):
47
- if "secrets" in data and data["secrets"]:
48
- raise ValueError("'secrets' is not supported for torch")
49
- if "returns" in data and data["returns"]:
50
- raise ValueError("'secrets' is not supported for torch")
51
- return data
52
-
53
- def get_summary(self) -> dict[str, Any]:
54
- return self.model_dump(by_alias=True, exclude_none=True)
55
-
56
- @property
57
- def _context(self):
58
- return context.run_context
59
-
60
- def _get_launch_config(self) -> LaunchConfig:
61
- internal_log_spec = InternalLogSpecs(**self.model_dump(exclude_none=True))
62
- log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
63
- **internal_log_spec.model_dump(exclude_none=True)
64
- )
65
- easy_torch_config = EasyTorchConfig(
66
- **self.model_dump(
67
- exclude_none=True,
68
- )
69
- )
70
- print("###", easy_torch_config)
71
- print("###", easy_torch_config)
72
- launch_config = LaunchConfig(
73
- **easy_torch_config.model_dump(
74
- exclude_none=True,
75
- ),
76
- logs_specs=log_spec,
77
- run_id=self._context.run_id,
78
- )
79
- logger.info(f"launch_config: {launch_config}")
80
- return launch_config
81
-
82
- def execute_command(
83
- self,
84
- map_variable: defaults.TypeMapVariable = None,
85
- ):
86
- assert map_variable is None, "map_variable is not supported for torch"
87
-
88
- # The below should happen only if we are in the node that we want to execute
89
- # For a single node, multi worker setup, this should be the entry point
90
- # For a multi-node, we need to:
91
- # - create a service config
92
- # - Create a stateful set with number of nodes
93
- # - Create a job to run the torch.distributed.launcher.api.elastic_launch on every node
94
- # - the entry point to runnnable could be a way to trigger execution instead of scaling
95
- is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
96
-
97
- _, max_nodes = get_min_max_nodes(self.nnodes)
98
-
99
- if max_nodes > 1 and not is_execute:
100
- executor = self._context.executor
101
- executor.scale_up(self)
102
- return StepAttempt(
103
- status=defaults.SUCCESS,
104
- start_time=str(datetime.now()),
105
- end_time=str(datetime.now()),
106
- attempt_number=1,
107
- message="Triggered a scale up",
108
- )
109
-
110
- # The below should happen only if we are in the node that we want to execute
111
- # For a single node, multi worker setup, this should be the entry point
112
- # For a multi-node, we need to:
113
- # - create a service config
114
- # - Create a stateful set with number of nodes
115
- # - Create a job to run the torch.distributed.launcher.api.elastic_launch on every node
116
- # - the entry point to runnnable could be a way to trigger execution instead of scaling
117
- is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
118
-
119
- _, max_nodes = get_min_max_nodes(self.nnodes)
120
-
121
- if max_nodes > 1 and not is_execute:
122
- executor = self._context.executor
123
- executor.scale_up(self)
124
- return StepAttempt(
125
- status=defaults.SUCCESS,
126
- start_time=str(datetime.now()),
127
- end_time=str(datetime.now()),
128
- attempt_number=1,
129
- message="Triggered a scale up",
130
- )
131
-
132
- launch_config = self._get_launch_config()
133
- print("###****", launch_config)
134
- print("###****", launch_config)
135
- logger.info(f"launch_config: {launch_config}")
136
-
137
- # ENV variables are shared with the subprocess, use that as communication
138
- os.environ["RUNNABLE_TORCH_COMMAND"] = self.command
139
- os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
140
- self._context.parameters_file or ""
141
- )
142
- os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
143
-
144
- launcher = elastic_launch(
145
- launch_config,
146
- training_subprocess,
147
- )
148
- try:
149
- launcher()
150
- attempt_log = StepAttempt(
151
- status=defaults.SUCCESS,
152
- start_time=str(datetime.now()),
153
- end_time=str(datetime.now()),
154
- attempt_number=1,
155
- )
156
- except Exception as e:
157
- attempt_log = StepAttempt(
158
- status=defaults.FAIL,
159
- start_time=str(datetime.now()),
160
- end_time=str(datetime.now()),
161
- attempt_number=1,
162
- )
163
- logger.error(f"Error executing TorchNode: {e}")
164
- finally:
165
- # This can only come from the subprocess
166
- if Path("proc_logs").exists():
167
- # Move .catalog and torch_logs to the parent node's catalog location
168
- self._context.catalog_handler.put(
169
- "proc_logs/**/*", allow_file_not_found_exc=True
170
- )
171
-
172
- # TODO: This is not working!!
173
- if self.log_dir:
174
- self._context.catalog_handler.put(
175
- self.log_dir + "/**/*", allow_file_not_found_exc=True
176
- )
177
-
178
- delete_env_vars_with_prefix("RUNNABLE_TORCH")
179
- logger.info(f"attempt_log: {attempt_log}")
180
-
181
- return attempt_log
182
-
183
-
184
- # This internal model makes it easier to extract the required fields
185
- # of log specs from user specification.
186
- # https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
187
- class InternalLogSpecs(BaseModel):
188
- log_dir: Optional[str] = Field(default="torch_logs")
189
- redirects: str = Field(default="0") # Std.NONE
190
- tee: str = Field(default="0") # Std.NONE
191
- local_ranks_filter: Optional[set[int]] = Field(default=None)
192
-
193
- model_config = ConfigDict(extra="ignore")
194
-
195
- @field_serializer("redirects")
196
- def convert_redirects(self, redirects: str) -> Std | dict[int, Std]:
197
- return Std.from_str(redirects)
198
-
199
- @field_serializer("tee")
200
- def convert_tee(self, tee: str) -> Std | dict[int, Std]:
201
- return Std.from_str(tee)
202
-
203
-
204
- def delete_env_vars_with_prefix(prefix):
205
- to_delete = [] # List to keep track of variables to delete
206
-
207
- # Iterate over a list of all environment variable keys
208
- for var in os.environ:
209
- if var.startswith(prefix):
210
- to_delete.append(var)
211
-
212
- # Delete each of the variables collected
213
- for var in to_delete:
214
- del os.environ[var]
215
-
216
-
217
- def training_subprocess():
218
- """
219
- This function is called by the torch.distributed.launcher.api.elastic_launch
220
- It happens in a subprocess and is responsible for executing the user's function
221
-
222
- It is unrelated to the actual node execution, so any cataloging, run_log_store should be
223
- handled to match to main process.
224
-
225
- We have these variables to use:
226
-
227
- os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
228
- os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
229
- self._context.parameters_file or ""
230
- )
231
- os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
232
- os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
233
-
234
- """
235
- from runnable import PythonJob # noqa: F401
236
-
237
- command = os.environ.get("RUNNABLE_TORCH_COMMAND")
238
- assert command, "Command is not provided"
239
-
240
- run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
241
- parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
242
-
243
- process_run_id = (
244
- run_id
245
- + "-"
246
- + os.environ.get("RANK", "")
247
- + "-"
248
- + "".join(random.choices(string.ascii_lowercase, k=3))
249
- )
250
- os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
251
-
252
- # In this subprocess there shoould not be any RUNNABLE environment variables
253
- delete_env_vars_with_prefix("RUNNABLE_")
254
-
255
- module_name, func_name = get_module_and_attr_names(command)
256
- module = importlib.import_module(module_name)
257
-
258
- callable_obj = getattr(module, func_name)
259
-
260
- # The job runs with the default configuration
261
- # ALl the execution logs are stored in .catalog
262
- job = PythonJob(function=callable_obj)
263
-
264
- config_content = {
265
- "catalog": {"type": "file-system", "config": {"catalog_location": "proc_logs"}}
266
- }
267
-
268
- temp_config_file = Path("runnable-config.yaml")
269
- with open(str(temp_config_file), "w", encoding="utf-8") as config_file:
270
- yaml = YAML(typ="safe", pure=True)
271
- yaml.dump(config_content, config_file)
272
-
273
- job.execute(
274
- parameters_file=parameters_files,
275
- job_id=process_run_id,
276
- )
277
-
278
- # delete the temp config file
279
- temp_config_file.unlink()
280
-
281
- from runnable.context import run_context
282
-
283
- job_log = run_context.run_log_store.get_run_log_by_id(run_id=run_context.run_id)
284
-
285
- if job_log.status == defaults.FAIL:
286
- raise Exception(f"Job {process_run_id} failed")