runnable 0.36.0__tar.gz → 0.36.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. {runnable-0.36.0 → runnable-0.36.1}/PKG-INFO +1 -4
  2. {runnable-0.36.0 → runnable-0.36.1}/extensions/pipeline_executor/argo.py +2 -5
  3. {runnable-0.36.0 → runnable-0.36.1}/extensions/pipeline_executor/mocked.py +1 -5
  4. {runnable-0.36.0 → runnable-0.36.1}/pyproject.toml +2 -5
  5. {runnable-0.36.0 → runnable-0.36.1}/runnable/__init__.py +0 -2
  6. {runnable-0.36.0 → runnable-0.36.1}/runnable/sdk.py +0 -42
  7. {runnable-0.36.0 → runnable-0.36.1}/runnable/tasks.py +0 -60
  8. runnable-0.36.0/extensions/tasks/torch.py +0 -286
  9. runnable-0.36.0/extensions/tasks/torch_config.py +0 -76
  10. {runnable-0.36.0 → runnable-0.36.1}/.gitignore +0 -0
  11. {runnable-0.36.0 → runnable-0.36.1}/LICENSE +0 -0
  12. {runnable-0.36.0 → runnable-0.36.1}/README.md +0 -0
  13. {runnable-0.36.0 → runnable-0.36.1}/extensions/README.md +0 -0
  14. {runnable-0.36.0 → runnable-0.36.1}/extensions/__init__.py +0 -0
  15. {runnable-0.36.0 → runnable-0.36.1}/extensions/catalog/README.md +0 -0
  16. {runnable-0.36.0 → runnable-0.36.1}/extensions/catalog/any_path.py +0 -0
  17. {runnable-0.36.0 → runnable-0.36.1}/extensions/catalog/file_system.py +0 -0
  18. {runnable-0.36.0 → runnable-0.36.1}/extensions/catalog/minio.py +0 -0
  19. {runnable-0.36.0 → runnable-0.36.1}/extensions/catalog/pyproject.toml +0 -0
  20. {runnable-0.36.0 → runnable-0.36.1}/extensions/catalog/s3.py +0 -0
  21. {runnable-0.36.0 → runnable-0.36.1}/extensions/job_executor/README.md +0 -0
  22. {runnable-0.36.0 → runnable-0.36.1}/extensions/job_executor/__init__.py +0 -0
  23. {runnable-0.36.0 → runnable-0.36.1}/extensions/job_executor/emulate.py +0 -0
  24. {runnable-0.36.0 → runnable-0.36.1}/extensions/job_executor/k8s.py +0 -0
  25. {runnable-0.36.0 → runnable-0.36.1}/extensions/job_executor/k8s_job_spec.yaml +0 -0
  26. {runnable-0.36.0 → runnable-0.36.1}/extensions/job_executor/local.py +0 -0
  27. {runnable-0.36.0 → runnable-0.36.1}/extensions/job_executor/local_container.py +0 -0
  28. {runnable-0.36.0 → runnable-0.36.1}/extensions/job_executor/pyproject.toml +0 -0
  29. {runnable-0.36.0 → runnable-0.36.1}/extensions/nodes/README.md +0 -0
  30. {runnable-0.36.0 → runnable-0.36.1}/extensions/nodes/__init__.py +0 -0
  31. {runnable-0.36.0 → runnable-0.36.1}/extensions/nodes/conditional.py +0 -0
  32. {runnable-0.36.0 → runnable-0.36.1}/extensions/nodes/fail.py +0 -0
  33. {runnable-0.36.0 → runnable-0.36.1}/extensions/nodes/map.py +0 -0
  34. {runnable-0.36.0 → runnable-0.36.1}/extensions/nodes/parallel.py +0 -0
  35. {runnable-0.36.0 → runnable-0.36.1}/extensions/nodes/pyproject.toml +0 -0
  36. {runnable-0.36.0 → runnable-0.36.1}/extensions/nodes/stub.py +0 -0
  37. {runnable-0.36.0 → runnable-0.36.1}/extensions/nodes/success.py +0 -0
  38. {runnable-0.36.0 → runnable-0.36.1}/extensions/nodes/task.py +0 -0
  39. {runnable-0.36.0 → runnable-0.36.1}/extensions/pipeline_executor/README.md +0 -0
  40. {runnable-0.36.0 → runnable-0.36.1}/extensions/pipeline_executor/__init__.py +0 -0
  41. {runnable-0.36.0 → runnable-0.36.1}/extensions/pipeline_executor/emulate.py +0 -0
  42. {runnable-0.36.0 → runnable-0.36.1}/extensions/pipeline_executor/local.py +0 -0
  43. {runnable-0.36.0 → runnable-0.36.1}/extensions/pipeline_executor/local_container.py +0 -0
  44. {runnable-0.36.0 → runnable-0.36.1}/extensions/pipeline_executor/pyproject.toml +0 -0
  45. {runnable-0.36.0 → runnable-0.36.1}/extensions/pipeline_executor/retry.py +0 -0
  46. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/README.md +0 -0
  47. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/__init__.py +0 -0
  48. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/any_path.py +0 -0
  49. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/chunked_fs.py +0 -0
  50. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/chunked_minio.py +0 -0
  51. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/db/implementation_FF.py +0 -0
  52. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/db/integration_FF.py +0 -0
  53. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/file_system.py +0 -0
  54. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/generic_chunked.py +0 -0
  55. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/minio.py +0 -0
  56. {runnable-0.36.0 → runnable-0.36.1}/extensions/run_log_store/pyproject.toml +0 -0
  57. {runnable-0.36.0 → runnable-0.36.1}/extensions/secrets/README.md +0 -0
  58. {runnable-0.36.0 → runnable-0.36.1}/extensions/secrets/dotenv.py +0 -0
  59. {runnable-0.36.0 → runnable-0.36.1}/extensions/secrets/pyproject.toml +0 -0
  60. {runnable-0.36.0 → runnable-0.36.1}/runnable/catalog.py +0 -0
  61. {runnable-0.36.0 → runnable-0.36.1}/runnable/cli.py +0 -0
  62. {runnable-0.36.0 → runnable-0.36.1}/runnable/context.py +0 -0
  63. {runnable-0.36.0 → runnable-0.36.1}/runnable/datastore.py +0 -0
  64. {runnable-0.36.0 → runnable-0.36.1}/runnable/defaults.py +0 -0
  65. {runnable-0.36.0 → runnable-0.36.1}/runnable/entrypoints.py +0 -0
  66. {runnable-0.36.0 → runnable-0.36.1}/runnable/exceptions.py +0 -0
  67. {runnable-0.36.0 → runnable-0.36.1}/runnable/executor.py +0 -0
  68. {runnable-0.36.0 → runnable-0.36.1}/runnable/graph.py +0 -0
  69. {runnable-0.36.0 → runnable-0.36.1}/runnable/names.py +0 -0
  70. {runnable-0.36.0 → runnable-0.36.1}/runnable/nodes.py +0 -0
  71. {runnable-0.36.0 → runnable-0.36.1}/runnable/parameters.py +0 -0
  72. {runnable-0.36.0 → runnable-0.36.1}/runnable/pickler.py +0 -0
  73. {runnable-0.36.0 → runnable-0.36.1}/runnable/secrets.py +0 -0
  74. {runnable-0.36.0 → runnable-0.36.1}/runnable/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: runnable
3
- Version: 0.36.0
3
+ Version: 0.36.1
4
4
  Summary: Add your description here
5
5
  Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
6
6
  License-File: LICENSE
@@ -26,9 +26,6 @@ Provides-Extra: notebook
26
26
  Requires-Dist: ploomber-engine>=0.0.33; extra == 'notebook'
27
27
  Provides-Extra: s3
28
28
  Requires-Dist: cloudpathlib[s3]; extra == 's3'
29
- Provides-Extra: torch
30
- Requires-Dist: accelerate>=1.5.2; extra == 'torch'
31
- Requires-Dist: torch>=2.6.0; extra == 'torch'
32
29
  Description-Content-Type: text/markdown
33
30
 
34
31
 
@@ -24,9 +24,6 @@ from extensions.nodes.conditional import ConditionalNode
24
24
  from extensions.nodes.map import MapNode
25
25
  from extensions.nodes.parallel import ParallelNode
26
26
  from extensions.nodes.task import TaskNode
27
-
28
- # TODO: Should be part of a wider refactor
29
- # from extensions.nodes.torch import TorchNode
30
27
  from extensions.pipeline_executor import GenericPipelineExecutor
31
28
  from runnable import defaults
32
29
  from runnable.defaults import MapVariableType
@@ -592,7 +589,7 @@ class ArgoExecutor(GenericPipelineExecutor):
592
589
  task_name: str,
593
590
  inputs: Optional[Inputs] = None,
594
591
  ) -> ContainerTemplate:
595
- assert node.node_type in ["task", "torch", "success", "stub", "fail"]
592
+ assert node.node_type in ["task", "success", "stub", "fail"]
596
593
 
597
594
  node_override = None
598
595
  if hasattr(node, "overrides"):
@@ -655,7 +652,7 @@ class ArgoExecutor(GenericPipelineExecutor):
655
652
  def _set_env_vars_to_task(
656
653
  self, working_on: BaseNode, container_template: CoreContainerTemplate
657
654
  ):
658
- if working_on.node_type not in ["task", "torch"]:
655
+ if working_on.node_type not in ["task"]:
659
656
  return
660
657
 
661
658
  global_envs: dict[str, str] = {}
@@ -6,7 +6,7 @@ from pydantic import ConfigDict, Field
6
6
 
7
7
  from extensions.nodes.task import TaskNode
8
8
  from extensions.pipeline_executor import GenericPipelineExecutor
9
- from runnable import context, defaults
9
+ from runnable import defaults
10
10
  from runnable.defaults import MapVariableType
11
11
  from runnable.nodes import BaseNode
12
12
  from runnable.tasks import BaseTaskType
@@ -32,10 +32,6 @@ class MockedExecutor(GenericPipelineExecutor):
32
32
 
33
33
  patches: Dict[str, Any] = Field(default_factory=dict)
34
34
 
35
- @property
36
- def _context(self):
37
- return context.run_context
38
-
39
35
  def execute_from_graph(self, node: BaseNode, map_variable: MapVariableType = None):
40
36
  """
41
37
  This is the entry point to from the graph execution.
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "runnable"
3
- version = "0.36.0"
3
+ version = "0.36.1"
4
4
  description = "Add your description here"
5
5
  readme = "README.md"
6
6
  authors = [
@@ -37,10 +37,7 @@ k8s = [
37
37
  s3 = [
38
38
  "cloudpathlib[s3]"
39
39
  ]
40
- torch = [
41
- "torch>=2.6.0",
42
- "accelerate>=1.5.2",
43
- ]
40
+
44
41
 
45
42
  [dependency-groups]
46
43
  dev = [
@@ -24,8 +24,6 @@ from runnable.sdk import ( # noqa;
24
24
  ShellTask,
25
25
  Stub,
26
26
  Success,
27
- TorchJob,
28
- TorchTask,
29
27
  metric,
30
28
  pickled,
31
29
  )
@@ -40,7 +40,6 @@ StepType = Union[
40
40
  "ShellTask",
41
41
  "Parallel",
42
42
  "Map",
43
- "TorchTask",
44
43
  "Conditional",
45
44
  ]
46
45
 
@@ -277,27 +276,6 @@ class PythonTask(BaseTask):
277
276
  return node.executable
278
277
 
279
278
 
280
- class TorchTask(BaseTask):
281
- # entrypoint: str = Field(
282
- # alias="entrypoint", default="torch.distributed.run", frozen=True
283
- # )
284
- # args_to_torchrun: Dict[str, Any] = Field(
285
- # default_factory=dict, alias="args_to_torchrun"
286
- # )
287
-
288
- script_to_call: str
289
- accelerate_config_file: str
290
-
291
- @computed_field
292
- def command_type(self) -> str:
293
- return "torch"
294
-
295
- def create_job(self) -> RunnableTask:
296
- self.terminate_with_success = True
297
- node = self.create_node()
298
- return node.executable
299
-
300
-
301
279
  class NotebookTask(BaseTask):
302
280
  """
303
281
  An execution node of the pipeline of notebook.
@@ -937,26 +915,6 @@ class PythonJob(BaseJob):
937
915
  return task.create_node().executable
938
916
 
939
917
 
940
- class TorchJob(BaseJob):
941
- # entrypoint: str = Field(default="torch.distributed.run", frozen=True)
942
- # args_to_torchrun: dict[str, str | bool | int | float] = Field(
943
- # default_factory=dict
944
- # ) # For example
945
- # {"nproc_per_node": 2, "nnodes": 1,}
946
-
947
- script_to_call: str # For example train/script.py
948
- accelerate_config_file: str
949
-
950
- def get_task(self) -> RunnableTask:
951
- # Piggy bank on existing tasks as a hack
952
- task = TorchTask(
953
- name="dummy",
954
- terminate_with_success=True,
955
- **self.model_dump(exclude_defaults=True, exclude_none=True),
956
- )
957
- return task.create_node().executable
958
-
959
-
960
918
  class NotebookJob(BaseJob):
961
919
  notebook: str = Field(serialization_alias="command")
962
920
  optional_ploomber_args: Optional[Dict[str, Any]] = Field(
@@ -384,66 +384,6 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
384
384
  return attempt_log
385
385
 
386
386
 
387
- class TorchTaskType(BaseTaskType):
388
- task_type: str = Field(default="torch", serialization_alias="command_type")
389
- accelerate_config_file: str
390
-
391
- script_to_call: str # For example train/script.py
392
-
393
- def execute_command(
394
- self, map_variable: Dict[str, str | int | float] | None = None
395
- ) -> StepAttempt:
396
- from accelerate.commands import launch
397
-
398
- attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
399
-
400
- with (
401
- self.execution_context(
402
- map_variable=map_variable, allow_complex=False
403
- ) as params,
404
- self.expose_secrets() as _,
405
- ):
406
- try:
407
- script_args = []
408
- for key, value in params.items():
409
- script_args.append(f"--{key}")
410
- if type(value.value) is not bool:
411
- script_args.append(str(value.value))
412
-
413
- # TODO: Check the typing here
414
-
415
- logger.info("Calling the user script with the following parameters:")
416
- logger.info(script_args)
417
- out_file = TeeIO()
418
- try:
419
- with contextlib.redirect_stdout(out_file):
420
- parser = launch.launch_command_parser()
421
- args = parser.parse_args(self.script_to_call)
422
- args.training_script = self.script_to_call
423
- args.config_file = self.accelerate_config_file
424
- args.training_script_args = script_args
425
-
426
- launch.launch_command(args)
427
- task_console.print(out_file.getvalue())
428
- except Exception as e:
429
- raise exceptions.CommandCallError(
430
- f"Call to script{self.script_to_call} did not succeed."
431
- ) from e
432
- finally:
433
- sys.argv = sys.argv[:1]
434
-
435
- attempt_log.status = defaults.SUCCESS
436
- except Exception as _e:
437
- msg = f"Call to script: {self.script_to_call} did not succeed."
438
- attempt_log.message = msg
439
- task_console.print_exception(show_locals=False)
440
- task_console.log(_e, style=defaults.error_style)
441
-
442
- attempt_log.end_time = str(datetime.now())
443
-
444
- return attempt_log
445
-
446
-
447
387
  class NotebookTaskType(BaseTaskType):
448
388
  """
449
389
  --8<-- [start:notebook_reference]
@@ -1,286 +0,0 @@
1
- import importlib
2
- import logging
3
- import os
4
- import random
5
- import string
6
- from datetime import datetime
7
- from pathlib import Path
8
- from typing import Any, Optional
9
-
10
- from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
11
- from ruamel.yaml import YAML
12
-
13
- import runnable.context as context
14
- from extensions.tasks.torch_config import EasyTorchConfig, TorchConfig
15
- from runnable import Catalog, defaults
16
- from runnable.datastore import StepAttempt
17
- from runnable.tasks import BaseTaskType
18
- from runnable.utils import get_module_and_attr_names
19
-
20
- logger = logging.getLogger(defaults.LOGGER_NAME)
21
-
22
- logger = logging.getLogger(defaults.LOGGER_NAME)
23
-
24
- try:
25
- from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
26
- from torch.distributed.launcher.api import LaunchConfig, elastic_launch
27
-
28
- except ImportError as e:
29
- logger.exception("torch is not installed")
30
- raise Exception("torch is not installed") from e
31
-
32
-
33
- def get_min_max_nodes(nnodes: str) -> tuple[int, int]:
34
- min_nodes, max_nodes = (int(x) for x in nnodes.split(":"))
35
- return min_nodes, max_nodes
36
-
37
-
38
- class TorchTaskType(BaseTaskType, TorchConfig):
39
- task_type: str = Field(default="torch", serialization_alias="command_type")
40
- catalog: Optional[Catalog] = Field(default=None, alias="catalog")
41
- command: str
42
-
43
- @model_validator(mode="before")
44
- @classmethod
45
- def check_secrets_and_returns(cls, data: Any) -> Any:
46
- if isinstance(data, dict):
47
- if "secrets" in data and data["secrets"]:
48
- raise ValueError("'secrets' is not supported for torch")
49
- if "returns" in data and data["returns"]:
50
- raise ValueError("'secrets' is not supported for torch")
51
- return data
52
-
53
- def get_summary(self) -> dict[str, Any]:
54
- return self.model_dump(by_alias=True, exclude_none=True)
55
-
56
- @property
57
- def _context(self):
58
- return context.run_context
59
-
60
- def _get_launch_config(self) -> LaunchConfig:
61
- internal_log_spec = InternalLogSpecs(**self.model_dump(exclude_none=True))
62
- log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
63
- **internal_log_spec.model_dump(exclude_none=True)
64
- )
65
- easy_torch_config = EasyTorchConfig(
66
- **self.model_dump(
67
- exclude_none=True,
68
- )
69
- )
70
- print("###", easy_torch_config)
71
- print("###", easy_torch_config)
72
- launch_config = LaunchConfig(
73
- **easy_torch_config.model_dump(
74
- exclude_none=True,
75
- ),
76
- logs_specs=log_spec,
77
- run_id=self._context.run_id,
78
- )
79
- logger.info(f"launch_config: {launch_config}")
80
- return launch_config
81
-
82
- def execute_command(
83
- self,
84
- map_variable: defaults.MapVariableType = None,
85
- ):
86
- assert map_variable is None, "map_variable is not supported for torch"
87
-
88
- # The below should happen only if we are in the node that we want to execute
89
- # For a single node, multi worker setup, this should be the entry point
90
- # For a multi-node, we need to:
91
- # - create a service config
92
- # - Create a stateful set with number of nodes
93
- # - Create a job to run the torch.distributed.launcher.api.elastic_launch on every node
94
- # - the entry point to runnnable could be a way to trigger execution instead of scaling
95
- is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
96
-
97
- _, max_nodes = get_min_max_nodes(self.nnodes)
98
-
99
- if max_nodes > 1 and not is_execute:
100
- executor = self._context.executor
101
- executor.scale_up(self)
102
- return StepAttempt(
103
- status=defaults.SUCCESS,
104
- start_time=str(datetime.now()),
105
- end_time=str(datetime.now()),
106
- attempt_number=1,
107
- message="Triggered a scale up",
108
- )
109
-
110
- # The below should happen only if we are in the node that we want to execute
111
- # For a single node, multi worker setup, this should be the entry point
112
- # For a multi-node, we need to:
113
- # - create a service config
114
- # - Create a stateful set with number of nodes
115
- # - Create a job to run the torch.distributed.launcher.api.elastic_launch on every node
116
- # - the entry point to runnnable could be a way to trigger execution instead of scaling
117
- is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
118
-
119
- _, max_nodes = get_min_max_nodes(self.nnodes)
120
-
121
- if max_nodes > 1 and not is_execute:
122
- executor = self._context.executor
123
- executor.scale_up(self)
124
- return StepAttempt(
125
- status=defaults.SUCCESS,
126
- start_time=str(datetime.now()),
127
- end_time=str(datetime.now()),
128
- attempt_number=1,
129
- message="Triggered a scale up",
130
- )
131
-
132
- launch_config = self._get_launch_config()
133
- print("###****", launch_config)
134
- print("###****", launch_config)
135
- logger.info(f"launch_config: {launch_config}")
136
-
137
- # ENV variables are shared with the subprocess, use that as communication
138
- os.environ["RUNNABLE_TORCH_COMMAND"] = self.command
139
- os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
140
- self._context.parameters_file or ""
141
- )
142
- os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
143
-
144
- launcher = elastic_launch(
145
- launch_config,
146
- training_subprocess,
147
- )
148
- try:
149
- launcher()
150
- attempt_log = StepAttempt(
151
- status=defaults.SUCCESS,
152
- start_time=str(datetime.now()),
153
- end_time=str(datetime.now()),
154
- attempt_number=1,
155
- )
156
- except Exception as e:
157
- attempt_log = StepAttempt(
158
- status=defaults.FAIL,
159
- start_time=str(datetime.now()),
160
- end_time=str(datetime.now()),
161
- attempt_number=1,
162
- )
163
- logger.error(f"Error executing TorchNode: {e}")
164
- finally:
165
- # This can only come from the subprocess
166
- if Path("proc_logs").exists():
167
- # Move .catalog and torch_logs to the parent node's catalog location
168
- self._context.catalog_handler.put(
169
- "proc_logs/**/*", allow_file_not_found_exc=True
170
- )
171
-
172
- # TODO: This is not working!!
173
- if self.log_dir:
174
- self._context.catalog_handler.put(
175
- self.log_dir + "/**/*", allow_file_not_found_exc=True
176
- )
177
-
178
- delete_env_vars_with_prefix("RUNNABLE_TORCH")
179
- logger.info(f"attempt_log: {attempt_log}")
180
-
181
- return attempt_log
182
-
183
-
184
- # This internal model makes it easier to extract the required fields
185
- # of log specs from user specification.
186
- # https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
187
- class InternalLogSpecs(BaseModel):
188
- log_dir: Optional[str] = Field(default="torch_logs")
189
- redirects: str = Field(default="0") # Std.NONE
190
- tee: str = Field(default="0") # Std.NONE
191
- local_ranks_filter: Optional[set[int]] = Field(default=None)
192
-
193
- model_config = ConfigDict(extra="ignore")
194
-
195
- @field_serializer("redirects")
196
- def convert_redirects(self, redirects: str) -> Std | dict[int, Std]:
197
- return Std.from_str(redirects)
198
-
199
- @field_serializer("tee")
200
- def convert_tee(self, tee: str) -> Std | dict[int, Std]:
201
- return Std.from_str(tee)
202
-
203
-
204
- def delete_env_vars_with_prefix(prefix):
205
- to_delete = [] # List to keep track of variables to delete
206
-
207
- # Iterate over a list of all environment variable keys
208
- for var in os.environ:
209
- if var.startswith(prefix):
210
- to_delete.append(var)
211
-
212
- # Delete each of the variables collected
213
- for var in to_delete:
214
- del os.environ[var]
215
-
216
-
217
- def training_subprocess():
218
- """
219
- This function is called by the torch.distributed.launcher.api.elastic_launch
220
- It happens in a subprocess and is responsible for executing the user's function
221
-
222
- It is unrelated to the actual node execution, so any cataloging, run_log_store should be
223
- handled to match to main process.
224
-
225
- We have these variables to use:
226
-
227
- os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
228
- os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
229
- self._context.parameters_file or ""
230
- )
231
- os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
232
- os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
233
-
234
- """
235
- from runnable import PythonJob # noqa: F401
236
-
237
- command = os.environ.get("RUNNABLE_TORCH_COMMAND")
238
- assert command, "Command is not provided"
239
-
240
- run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
241
- parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
242
-
243
- process_run_id = (
244
- run_id
245
- + "-"
246
- + os.environ.get("RANK", "")
247
- + "-"
248
- + "".join(random.choices(string.ascii_lowercase, k=3))
249
- )
250
- os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
251
-
252
- # In this subprocess there shoould not be any RUNNABLE environment variables
253
- delete_env_vars_with_prefix("RUNNABLE_")
254
-
255
- module_name, func_name = get_module_and_attr_names(command)
256
- module = importlib.import_module(module_name)
257
-
258
- callable_obj = getattr(module, func_name)
259
-
260
- # The job runs with the default configuration
261
- # ALl the execution logs are stored in .catalog
262
- job = PythonJob(function=callable_obj)
263
-
264
- config_content = {
265
- "catalog": {"type": "file-system", "config": {"catalog_location": "proc_logs"}}
266
- }
267
-
268
- temp_config_file = Path("runnable-config.yaml")
269
- with open(str(temp_config_file), "w", encoding="utf-8") as config_file:
270
- yaml = YAML(typ="safe", pure=True)
271
- yaml.dump(config_content, config_file)
272
-
273
- job.execute(
274
- parameters_file=parameters_files,
275
- job_id=process_run_id,
276
- )
277
-
278
- # delete the temp config file
279
- temp_config_file.unlink()
280
-
281
- from runnable.context import run_context
282
-
283
- job_log = run_context.run_log_store.get_run_log_by_id(run_id=run_context.run_id)
284
-
285
- if job_log.status == defaults.FAIL:
286
- raise Exception(f"Job {process_run_id} failed")
@@ -1,76 +0,0 @@
1
- from enum import Enum
2
- from typing import Any, Optional
3
-
4
- from pydantic import BaseModel, ConfigDict, Field, computed_field
5
-
6
-
7
- class StartMethod(str, Enum):
8
- spawn = "spawn"
9
- fork = "fork"
10
- forkserver = "forkserver"
11
-
12
-
13
- ## The idea is the following:
14
- # Users can configure any of the options present in TorchConfig class.
15
- # The LaunchConfig class will be created from TorchConfig.
16
- # The LogSpecs is sent as a parameter to the launch config.
17
-
18
- ## NO idea of standalone and how to send it
19
-
20
-
21
- # The user sees this as part of the config of the node.
22
- # It is kept as similar as possible to torchrun
23
- class TorchConfig(BaseModel):
24
- model_config = ConfigDict(extra="forbid")
25
-
26
- # excluded as LaunchConfig requires min and max nodes
27
- nnodes: str = Field(default="1:1", exclude=True, description="min:max")
28
- nproc_per_node: int = Field(default=1, description="Number of processes per node")
29
-
30
- # will be used to create the log specs
31
- # But they are excluded from dump as logs specs is a class for LaunchConfig
32
- # from_str("0") -> Std.NONE
33
- # from_str("1") -> Std.OUT
34
- # from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
35
- log_dir: Optional[str] = Field(default="torch_logs", exclude=True)
36
- redirects: str = Field(default="0", exclude=True) # Std.NONE
37
- tee: str = Field(default="0", exclude=True) # Std.NONE
38
- local_ranks_filter: Optional[set[int]] = Field(default=None, exclude=True)
39
-
40
- role: str | None = Field(default=None)
41
-
42
- # run_id would be the run_id of the context
43
- # and sent at the creation of the LaunchConfig
44
-
45
- # This section is about the communication between nodes/processes
46
- rdzv_backend: str | None = Field(default="")
47
- rdzv_endpoint: str | None = Field(default="")
48
- rdzv_configs: dict[str, Any] = Field(default_factory=dict)
49
- rdzv_timeout: int | None = Field(default=None)
50
-
51
- max_restarts: int | None = Field(default=None)
52
- monitor_interval: float | None = Field(default=None)
53
- start_method: str | None = Field(default=StartMethod.spawn)
54
- log_line_prefix_template: str | None = Field(default=None)
55
- local_addr: Optional[str] = None
56
-
57
- # https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L753
58
- # master_addr: str | None = Field(default="localhost")
59
- # master_port: str | None = Field(default="29500")
60
- # training_script: str = Field(default="dummy_training_script")
61
- # training_script_args: str = Field(default="")
62
-
63
-
64
- class EasyTorchConfig(TorchConfig):
65
- model_config = ConfigDict(extra="ignore")
66
-
67
- # TODO: Validate min < max
68
- @computed_field # type: ignore
69
- @property
70
- def min_nodes(self) -> int:
71
- return int(self.nnodes.split(":")[0])
72
-
73
- @computed_field # type: ignore
74
- @property
75
- def max_nodes(self) -> int:
76
- return int(self.nnodes.split(":")[1])
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes