runnable 0.36.0__tar.gz → 0.37.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. {runnable-0.36.0 → runnable-0.37.0}/.gitignore +1 -0
  2. {runnable-0.36.0 → runnable-0.37.0}/PKG-INFO +1 -4
  3. {runnable-0.36.0 → runnable-0.37.0}/extensions/catalog/any_path.py +13 -2
  4. {runnable-0.36.0 → runnable-0.37.0}/extensions/job_executor/__init__.py +4 -1
  5. {runnable-0.36.0 → runnable-0.37.0}/extensions/pipeline_executor/__init__.py +3 -1
  6. {runnable-0.36.0 → runnable-0.37.0}/extensions/pipeline_executor/argo.py +2 -5
  7. {runnable-0.36.0 → runnable-0.37.0}/extensions/pipeline_executor/mocked.py +1 -5
  8. {runnable-0.36.0 → runnable-0.37.0}/pyproject.toml +5 -5
  9. {runnable-0.36.0 → runnable-0.37.0}/runnable/__init__.py +0 -2
  10. {runnable-0.36.0 → runnable-0.37.0}/runnable/catalog.py +5 -2
  11. {runnable-0.36.0 → runnable-0.37.0}/runnable/context.py +1 -0
  12. {runnable-0.36.0 → runnable-0.37.0}/runnable/nodes.py +2 -0
  13. {runnable-0.36.0 → runnable-0.37.0}/runnable/sdk.py +8 -42
  14. {runnable-0.36.0 → runnable-0.37.0}/runnable/tasks.py +0 -60
  15. runnable-0.36.0/extensions/tasks/torch.py +0 -286
  16. runnable-0.36.0/extensions/tasks/torch_config.py +0 -76
  17. {runnable-0.36.0 → runnable-0.37.0}/LICENSE +0 -0
  18. {runnable-0.36.0 → runnable-0.37.0}/README.md +0 -0
  19. {runnable-0.36.0 → runnable-0.37.0}/extensions/README.md +0 -0
  20. {runnable-0.36.0 → runnable-0.37.0}/extensions/__init__.py +0 -0
  21. {runnable-0.36.0 → runnable-0.37.0}/extensions/catalog/README.md +0 -0
  22. {runnable-0.36.0 → runnable-0.37.0}/extensions/catalog/file_system.py +0 -0
  23. {runnable-0.36.0 → runnable-0.37.0}/extensions/catalog/minio.py +0 -0
  24. {runnable-0.36.0 → runnable-0.37.0}/extensions/catalog/pyproject.toml +0 -0
  25. {runnable-0.36.0 → runnable-0.37.0}/extensions/catalog/s3.py +0 -0
  26. {runnable-0.36.0 → runnable-0.37.0}/extensions/job_executor/README.md +0 -0
  27. {runnable-0.36.0 → runnable-0.37.0}/extensions/job_executor/emulate.py +0 -0
  28. {runnable-0.36.0 → runnable-0.37.0}/extensions/job_executor/k8s.py +0 -0
  29. {runnable-0.36.0 → runnable-0.37.0}/extensions/job_executor/k8s_job_spec.yaml +0 -0
  30. {runnable-0.36.0 → runnable-0.37.0}/extensions/job_executor/local.py +0 -0
  31. {runnable-0.36.0 → runnable-0.37.0}/extensions/job_executor/local_container.py +0 -0
  32. {runnable-0.36.0 → runnable-0.37.0}/extensions/job_executor/pyproject.toml +0 -0
  33. {runnable-0.36.0 → runnable-0.37.0}/extensions/nodes/README.md +0 -0
  34. {runnable-0.36.0 → runnable-0.37.0}/extensions/nodes/__init__.py +0 -0
  35. {runnable-0.36.0 → runnable-0.37.0}/extensions/nodes/conditional.py +0 -0
  36. {runnable-0.36.0 → runnable-0.37.0}/extensions/nodes/fail.py +0 -0
  37. {runnable-0.36.0 → runnable-0.37.0}/extensions/nodes/map.py +0 -0
  38. {runnable-0.36.0 → runnable-0.37.0}/extensions/nodes/parallel.py +0 -0
  39. {runnable-0.36.0 → runnable-0.37.0}/extensions/nodes/pyproject.toml +0 -0
  40. {runnable-0.36.0 → runnable-0.37.0}/extensions/nodes/stub.py +0 -0
  41. {runnable-0.36.0 → runnable-0.37.0}/extensions/nodes/success.py +0 -0
  42. {runnable-0.36.0 → runnable-0.37.0}/extensions/nodes/task.py +0 -0
  43. {runnable-0.36.0 → runnable-0.37.0}/extensions/pipeline_executor/README.md +0 -0
  44. {runnable-0.36.0 → runnable-0.37.0}/extensions/pipeline_executor/emulate.py +0 -0
  45. {runnable-0.36.0 → runnable-0.37.0}/extensions/pipeline_executor/local.py +0 -0
  46. {runnable-0.36.0 → runnable-0.37.0}/extensions/pipeline_executor/local_container.py +0 -0
  47. {runnable-0.36.0 → runnable-0.37.0}/extensions/pipeline_executor/pyproject.toml +0 -0
  48. {runnable-0.36.0 → runnable-0.37.0}/extensions/pipeline_executor/retry.py +0 -0
  49. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/README.md +0 -0
  50. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/__init__.py +0 -0
  51. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/any_path.py +0 -0
  52. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/chunked_fs.py +0 -0
  53. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/chunked_minio.py +0 -0
  54. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/db/implementation_FF.py +0 -0
  55. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/db/integration_FF.py +0 -0
  56. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/file_system.py +0 -0
  57. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/generic_chunked.py +0 -0
  58. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/minio.py +0 -0
  59. {runnable-0.36.0 → runnable-0.37.0}/extensions/run_log_store/pyproject.toml +0 -0
  60. {runnable-0.36.0 → runnable-0.37.0}/extensions/secrets/README.md +0 -0
  61. {runnable-0.36.0 → runnable-0.37.0}/extensions/secrets/dotenv.py +0 -0
  62. {runnable-0.36.0 → runnable-0.37.0}/extensions/secrets/pyproject.toml +0 -0
  63. {runnable-0.36.0 → runnable-0.37.0}/runnable/cli.py +0 -0
  64. {runnable-0.36.0 → runnable-0.37.0}/runnable/datastore.py +0 -0
  65. {runnable-0.36.0 → runnable-0.37.0}/runnable/defaults.py +0 -0
  66. {runnable-0.36.0 → runnable-0.37.0}/runnable/entrypoints.py +0 -0
  67. {runnable-0.36.0 → runnable-0.37.0}/runnable/exceptions.py +0 -0
  68. {runnable-0.36.0 → runnable-0.37.0}/runnable/executor.py +0 -0
  69. {runnable-0.36.0 → runnable-0.37.0}/runnable/graph.py +0 -0
  70. {runnable-0.36.0 → runnable-0.37.0}/runnable/names.py +0 -0
  71. {runnable-0.36.0 → runnable-0.37.0}/runnable/parameters.py +0 -0
  72. {runnable-0.36.0 → runnable-0.37.0}/runnable/pickler.py +0 -0
  73. {runnable-0.36.0 → runnable-0.37.0}/runnable/secrets.py +0 -0
  74. {runnable-0.36.0 → runnable-0.37.0}/runnable/utils.py +0 -0
@@ -157,3 +157,4 @@ cov.xml
  data/
 
  minikube/
+ .pth # For model saving and loading
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: runnable
- Version: 0.36.0
+ Version: 0.37.0
  Summary: Add your description here
  Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
  License-File: LICENSE
@@ -26,9 +26,6 @@ Provides-Extra: notebook
  Requires-Dist: ploomber-engine>=0.0.33; extra == 'notebook'
  Provides-Extra: s3
  Requires-Dist: cloudpathlib[s3]; extra == 's3'
- Provides-Extra: torch
- Requires-Dist: accelerate>=1.5.2; extra == 'torch'
- Requires-Dist: torch>=2.6.0; extra == 'torch'
  Description-Content-Type: text/markdown
 
 
@@ -95,7 +95,10 @@ class AnyPathCatalog(BaseCatalog):
  return data_catalogs
 
  def put(
- self, name: str, allow_file_not_found_exc: bool = False
+ self,
+ name: str,
+ allow_file_not_found_exc: bool = False,
+ store_copy: bool = True,
  ) -> List[DataCatalog]:
  """
  Put the files matching the glob pattern into the catalog.
@@ -154,7 +157,15 @@ class AnyPathCatalog(BaseCatalog):
  data_catalogs.append(data_catalog)
 
  # TODO: Think about syncing only if the file is changed
- self.upload_to_catalog(file)
+ if store_copy:
+ logger.debug(
+ f"Copying file {file} to the catalog location for run_id: {run_id}"
+ )
+ self.upload_to_catalog(file)
+ else:
+ logger.debug(
+ f"Not copying file {file} to the catalog location for run_id: {run_id}"
+ )
 
  if not data_catalogs and not allow_file_not_found_exc:
  raise Exception(f"Did not find any files matching {name} in {copy_from}")
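In effect, put() still records a data catalog entry for every matched file; store_copy=False only skips the physical copy into the catalog location. A minimal sketch of the call shape (illustrative names, assuming an active run context and a concrete catalog implementation):

    # Hypothetical call from an executor; `catalog` is an AnyPathCatalog subclass instance.
    data_catalogs = catalog.put(
        name="data/*.csv",               # glob pattern relative to the compute data folder
        allow_file_not_found_exc=False,  # still raise if nothing matches
        store_copy=False,                # record the entry but skip upload_to_catalog()
    )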
@@ -29,6 +29,7 @@ class GenericJobExecutor(BaseJobExecutor):
  @property
  def _context(self):
  assert context.run_context
+ assert isinstance(context.run_context, context.JobContext)
  return context.run_context
 
  def _get_parameters(self) -> Dict[str, JsonParameter]:
@@ -147,7 +148,9 @@ class GenericJobExecutor(BaseJobExecutor):
  data_catalogs = []
  for name_pattern in catalog_settings:
  data_catalog = self._context.catalog.put(
- name=name_pattern, allow_file_not_found_exc=allow_file_not_found_exc
+ name=name_pattern,
+ allow_file_not_found_exc=allow_file_not_found_exc,
+ store_copy=self._context.catalog_store_copy,
  )
  logger.debug(f"Added data catalog: {data_catalog} to job log")
 
@@ -160,7 +160,9 @@ class GenericPipelineExecutor(BasePipelineExecutor):
 
  elif stage == "put":
  data_catalog = self._context.catalog.put(
- name=name_pattern, allow_file_not_found_exc=allow_file_no_found_exc
+ name=name_pattern,
+ allow_file_not_found_exc=allow_file_no_found_exc,
+ store_copy=node_catalog_settings.get("store_copy", True),
  )
  else:
  raise Exception(f"Stage {stage} not supported")
@@ -24,9 +24,6 @@ from extensions.nodes.conditional import ConditionalNode
  from extensions.nodes.map import MapNode
  from extensions.nodes.parallel import ParallelNode
  from extensions.nodes.task import TaskNode
-
- # TODO: Should be part of a wider refactor
- # from extensions.nodes.torch import TorchNode
  from extensions.pipeline_executor import GenericPipelineExecutor
  from runnable import defaults
  from runnable.defaults import MapVariableType
@@ -592,7 +589,7 @@ class ArgoExecutor(GenericPipelineExecutor):
  task_name: str,
  inputs: Optional[Inputs] = None,
  ) -> ContainerTemplate:
- assert node.node_type in ["task", "torch", "success", "stub", "fail"]
+ assert node.node_type in ["task", "success", "stub", "fail"]
 
  node_override = None
  if hasattr(node, "overrides"):
@@ -655,7 +652,7 @@ class ArgoExecutor(GenericPipelineExecutor):
  def _set_env_vars_to_task(
  self, working_on: BaseNode, container_template: CoreContainerTemplate
  ):
- if working_on.node_type not in ["task", "torch"]:
+ if working_on.node_type not in ["task"]:
  return
 
  global_envs: dict[str, str] = {}
@@ -6,7 +6,7 @@ from pydantic import ConfigDict, Field
 
  from extensions.nodes.task import TaskNode
  from extensions.pipeline_executor import GenericPipelineExecutor
- from runnable import context, defaults
+ from runnable import defaults
  from runnable.defaults import MapVariableType
  from runnable.nodes import BaseNode
  from runnable.tasks import BaseTaskType
@@ -32,10 +32,6 @@ class MockedExecutor(GenericPipelineExecutor):
 
  patches: Dict[str, Any] = Field(default_factory=dict)
 
- @property
- def _context(self):
- return context.run_context
-
  def execute_from_graph(self, node: BaseNode, map_variable: MapVariableType = None):
  """
  This is the entry point to from the graph execution.
@@ -1,6 +1,6 @@
  [project]
  name = "runnable"
- version = "0.36.0"
+ version = "0.37.0"
  description = "Add your description here"
  readme = "README.md"
  authors = [
@@ -37,10 +37,7 @@ k8s = [
  s3 = [
  "cloudpathlib[s3]"
  ]
- torch = [
- "torch>=2.6.0",
- "accelerate>=1.5.2",
- ]
+
 
  [dependency-groups]
  dev = [
@@ -61,6 +58,9 @@ docs = [
  release = [
  "python-semantic-release>=9.15.2",
  ]
+ examples-torch = [
+ "torch>=2.7.1",
+ ]
 
  [tool.uv.workspace]
  members = ["extensions/catalog",
@@ -24,8 +24,6 @@ from runnable.sdk import ( # noqa;
  ShellTask,
  Stub,
  Success,
- TorchJob,
- TorchTask,
  metric,
  pickled,
  )
@@ -57,7 +57,7 @@ class BaseCatalog(ABC, BaseModel):
 
  @abstractmethod
  def put(
- self, name: str, allow_file_not_found_exc: bool = False
+ self, name: str, allow_file_not_found_exc: bool = False, store_copy: bool = True
  ) -> List[DataCatalog]:
  """
  Put the file by 'name' from the 'compute_data_folder' in the catalog for the run_id.
@@ -120,7 +120,10 @@ class DoNothingCatalog(BaseCatalog):
  return []
 
  def put(
- self, name: str, allow_file_not_found_exc: bool = False
+ self,
+ name: str,
+ allow_file_not_found_exc: bool = False,
+ store_copy: bool = True,
  ) -> List[DataCatalog]:
  """
  Does nothing
@@ -475,6 +475,7 @@ class JobContext(RunnableContext):
  default=None,
  description="Catalog settings to be used for the job.",
  )
+ catalog_store_copy: bool = Field(default=True, alias="catalog_store_copy")
 
  @computed_field # type: ignore
  @cached_property
@@ -411,11 +411,13 @@ class TraversalNode(BaseNode):
  return self.overrides.get(executor_type) or ""
 
 
+ # Unfortunately, this is defined in 2 places. Look in SDK
  class CatalogStructure(BaseModel):
  model_config = ConfigDict(extra="forbid") # Need to forbid
 
  get: List[str] = Field(default_factory=list)
  put: List[str] = Field(default_factory=list)
+ store_copy: bool = Field(default=True, alias="store_copy")
 
 
  class ExecutableNode(TraversalNode):
@@ -40,7 +40,6 @@ StepType = Union[
  "ShellTask",
  "Parallel",
  "Map",
- "TorchTask",
  "Conditional",
  ]
 
@@ -61,6 +60,7 @@ class Catalog(BaseModel):
  Attributes:
  get (List[str]): List of glob patterns to get from central catalog to the compute data folder.
  put (List[str]): List of glob patterns to put into central catalog from the compute data folder.
+ store_copy (bool): Whether to store a copy of the data in the central catalog.
 
  Examples:
  >>> from runnable import Catalog
@@ -75,6 +75,7 @@ class Catalog(BaseModel):
  # compute_data_folder: str = Field(default="", alias="compute_data_folder")
  get: List[str] = Field(default_factory=list, alias="get")
  put: List[str] = Field(default_factory=list, alias="put")
+ store_copy: bool = Field(default=True, alias="store_copy")
 
 
  class BaseTraversal(ABC, BaseModel):
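From the SDK side, store_copy is a per-Catalog switch. A hedged usage sketch (the function and step names are illustrative; the Catalog and PythonTask arguments follow the patterns visible elsewhere in this diff):

    from runnable import Catalog, PythonTask

    def train():
        ...  # writes model.pth into the compute data folder

    # Track *.pth outputs in the run log, but do not copy them into the central catalog.
    catalog = Catalog(put=["*.pth"], store_copy=False)
    train_step = PythonTask(
        name="train",
        function=train,
        catalog=catalog,
        terminate_with_success=True,
    )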
@@ -277,27 +278,6 @@ class PythonTask(BaseTask):
  return node.executable
 
 
- class TorchTask(BaseTask):
- # entrypoint: str = Field(
- # alias="entrypoint", default="torch.distributed.run", frozen=True
- # )
- # args_to_torchrun: Dict[str, Any] = Field(
- # default_factory=dict, alias="args_to_torchrun"
- # )
-
- script_to_call: str
- accelerate_config_file: str
-
- @computed_field
- def command_type(self) -> str:
- return "torch"
-
- def create_job(self) -> RunnableTask:
- self.terminate_with_success = True
- node = self.create_node()
- return node.executable
-
-
  class NotebookTask(BaseTask):
  """
  An execution node of the pipeline of notebook.
@@ -867,6 +847,11 @@ class BaseJob(BaseModel):
  return []
  return self.catalog.put
 
+ def return_bool_catalog_store_copy(self) -> bool:
+ if self.catalog is None:
+ return True
+ return self.catalog.store_copy
+
  def _is_called_for_definition(self) -> bool:
  """
  If the run context is set, we are coming in only to get the pipeline definition.
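For jobs, the new return_bool_catalog_store_copy() helper reads the Catalog's store_copy flag and, as the next hunk shows, copies it onto JobContext.catalog_store_copy just before execution. A hedged example of the intended usage (the function name is illustrative; PythonJob accepting a catalog argument is inferred from BaseJob above):

    from runnable import Catalog, PythonJob

    def train():
        ...  # produces checkpoints in the compute data folder

    job = PythonJob(
        function=train,
        catalog=Catalog(put=["**/*.pth"], store_copy=False),
    )
    job.execute()  # catalog_store_copy is propagated to the JobContext here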
@@ -910,6 +895,7 @@ class BaseJob(BaseModel):
  }
 
  run_context = context.JobContext.model_validate(configurations)
+ run_context.catalog_store_copy = self.return_bool_catalog_store_copy()
 
  assert isinstance(run_context.job_executor, BaseJobExecutor)
 
@@ -937,26 +923,6 @@ class PythonJob(BaseJob):
  return task.create_node().executable
 
 
- class TorchJob(BaseJob):
- # entrypoint: str = Field(default="torch.distributed.run", frozen=True)
- # args_to_torchrun: dict[str, str | bool | int | float] = Field(
- # default_factory=dict
- # ) # For example
- # {"nproc_per_node": 2, "nnodes": 1,}
-
- script_to_call: str # For example train/script.py
- accelerate_config_file: str
-
- def get_task(self) -> RunnableTask:
- # Piggy bank on existing tasks as a hack
- task = TorchTask(
- name="dummy",
- terminate_with_success=True,
- **self.model_dump(exclude_defaults=True, exclude_none=True),
- )
- return task.create_node().executable
-
-
  class NotebookJob(BaseJob):
  notebook: str = Field(serialization_alias="command")
  optional_ploomber_args: Optional[Dict[str, Any]] = Field(
@@ -384,66 +384,6 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
  return attempt_log
 
 
- class TorchTaskType(BaseTaskType):
- task_type: str = Field(default="torch", serialization_alias="command_type")
- accelerate_config_file: str
-
- script_to_call: str # For example train/script.py
-
- def execute_command(
- self, map_variable: Dict[str, str | int | float] | None = None
- ) -> StepAttempt:
- from accelerate.commands import launch
-
- attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
-
- with (
- self.execution_context(
- map_variable=map_variable, allow_complex=False
- ) as params,
- self.expose_secrets() as _,
- ):
- try:
- script_args = []
- for key, value in params.items():
- script_args.append(f"--{key}")
- if type(value.value) is not bool:
- script_args.append(str(value.value))
-
- # TODO: Check the typing here
-
- logger.info("Calling the user script with the following parameters:")
- logger.info(script_args)
- out_file = TeeIO()
- try:
- with contextlib.redirect_stdout(out_file):
- parser = launch.launch_command_parser()
- args = parser.parse_args(self.script_to_call)
- args.training_script = self.script_to_call
- args.config_file = self.accelerate_config_file
- args.training_script_args = script_args
-
- launch.launch_command(args)
- task_console.print(out_file.getvalue())
- except Exception as e:
- raise exceptions.CommandCallError(
- f"Call to script{self.script_to_call} did not succeed."
- ) from e
- finally:
- sys.argv = sys.argv[:1]
-
- attempt_log.status = defaults.SUCCESS
- except Exception as _e:
- msg = f"Call to script: {self.script_to_call} did not succeed."
- attempt_log.message = msg
- task_console.print_exception(show_locals=False)
- task_console.log(_e, style=defaults.error_style)
-
- attempt_log.end_time = str(datetime.now())
-
- return attempt_log
-
-
  class NotebookTaskType(BaseTaskType):
  """
  --8<-- [start:notebook_reference]
@@ -1,286 +0,0 @@
- import importlib
- import logging
- import os
- import random
- import string
- from datetime import datetime
- from pathlib import Path
- from typing import Any, Optional
-
- from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
- from ruamel.yaml import YAML
-
- import runnable.context as context
- from extensions.tasks.torch_config import EasyTorchConfig, TorchConfig
- from runnable import Catalog, defaults
- from runnable.datastore import StepAttempt
- from runnable.tasks import BaseTaskType
- from runnable.utils import get_module_and_attr_names
-
- logger = logging.getLogger(defaults.LOGGER_NAME)
-
- logger = logging.getLogger(defaults.LOGGER_NAME)
-
- try:
- from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
- from torch.distributed.launcher.api import LaunchConfig, elastic_launch
-
- except ImportError as e:
- logger.exception("torch is not installed")
- raise Exception("torch is not installed") from e
-
-
- def get_min_max_nodes(nnodes: str) -> tuple[int, int]:
- min_nodes, max_nodes = (int(x) for x in nnodes.split(":"))
- return min_nodes, max_nodes
-
-
- class TorchTaskType(BaseTaskType, TorchConfig):
- task_type: str = Field(default="torch", serialization_alias="command_type")
- catalog: Optional[Catalog] = Field(default=None, alias="catalog")
- command: str
-
- @model_validator(mode="before")
- @classmethod
- def check_secrets_and_returns(cls, data: Any) -> Any:
- if isinstance(data, dict):
- if "secrets" in data and data["secrets"]:
- raise ValueError("'secrets' is not supported for torch")
- if "returns" in data and data["returns"]:
- raise ValueError("'secrets' is not supported for torch")
- return data
-
- def get_summary(self) -> dict[str, Any]:
- return self.model_dump(by_alias=True, exclude_none=True)
-
- @property
- def _context(self):
- return context.run_context
-
- def _get_launch_config(self) -> LaunchConfig:
- internal_log_spec = InternalLogSpecs(**self.model_dump(exclude_none=True))
- log_spec: DefaultLogsSpecs = DefaultLogsSpecs(
- **internal_log_spec.model_dump(exclude_none=True)
- )
- easy_torch_config = EasyTorchConfig(
- **self.model_dump(
- exclude_none=True,
- )
- )
- print("###", easy_torch_config)
- print("###", easy_torch_config)
- launch_config = LaunchConfig(
- **easy_torch_config.model_dump(
- exclude_none=True,
- ),
- logs_specs=log_spec,
- run_id=self._context.run_id,
- )
- logger.info(f"launch_config: {launch_config}")
- return launch_config
-
- def execute_command(
- self,
- map_variable: defaults.MapVariableType = None,
- ):
- assert map_variable is None, "map_variable is not supported for torch"
-
- # The below should happen only if we are in the node that we want to execute
- # For a single node, multi worker setup, this should be the entry point
- # For a multi-node, we need to:
- # - create a service config
- # - Create a stateful set with number of nodes
- # - Create a job to run the torch.distributed.launcher.api.elastic_launch on every node
- # - the entry point to runnnable could be a way to trigger execution instead of scaling
- is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
-
- _, max_nodes = get_min_max_nodes(self.nnodes)
-
- if max_nodes > 1 and not is_execute:
- executor = self._context.executor
- executor.scale_up(self)
- return StepAttempt(
- status=defaults.SUCCESS,
- start_time=str(datetime.now()),
- end_time=str(datetime.now()),
- attempt_number=1,
- message="Triggered a scale up",
- )
-
- # The below should happen only if we are in the node that we want to execute
- # For a single node, multi worker setup, this should be the entry point
- # For a multi-node, we need to:
- # - create a service config
- # - Create a stateful set with number of nodes
- # - Create a job to run the torch.distributed.launcher.api.elastic_launch on every node
- # - the entry point to runnnable could be a way to trigger execution instead of scaling
- is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
-
- _, max_nodes = get_min_max_nodes(self.nnodes)
-
- if max_nodes > 1 and not is_execute:
- executor = self._context.executor
- executor.scale_up(self)
- return StepAttempt(
- status=defaults.SUCCESS,
- start_time=str(datetime.now()),
- end_time=str(datetime.now()),
- attempt_number=1,
- message="Triggered a scale up",
- )
-
- launch_config = self._get_launch_config()
- print("###****", launch_config)
- print("###****", launch_config)
- logger.info(f"launch_config: {launch_config}")
-
- # ENV variables are shared with the subprocess, use that as communication
- os.environ["RUNNABLE_TORCH_COMMAND"] = self.command
- os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
- self._context.parameters_file or ""
- )
- os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
-
- launcher = elastic_launch(
- launch_config,
- training_subprocess,
- )
- try:
- launcher()
- attempt_log = StepAttempt(
- status=defaults.SUCCESS,
- start_time=str(datetime.now()),
- end_time=str(datetime.now()),
- attempt_number=1,
- )
- except Exception as e:
- attempt_log = StepAttempt(
- status=defaults.FAIL,
- start_time=str(datetime.now()),
- end_time=str(datetime.now()),
- attempt_number=1,
- )
- logger.error(f"Error executing TorchNode: {e}")
- finally:
- # This can only come from the subprocess
- if Path("proc_logs").exists():
- # Move .catalog and torch_logs to the parent node's catalog location
- self._context.catalog_handler.put(
- "proc_logs/**/*", allow_file_not_found_exc=True
- )
-
- # TODO: This is not working!!
- if self.log_dir:
- self._context.catalog_handler.put(
- self.log_dir + "/**/*", allow_file_not_found_exc=True
- )
-
- delete_env_vars_with_prefix("RUNNABLE_TORCH")
- logger.info(f"attempt_log: {attempt_log}")
-
- return attempt_log
-
-
- # This internal model makes it easier to extract the required fields
- # of log specs from user specification.
- # https://github.com/pytorch/pytorch/blob/main/torch/distributed/elastic/multiprocessing/api.py#L243
- class InternalLogSpecs(BaseModel):
- log_dir: Optional[str] = Field(default="torch_logs")
- redirects: str = Field(default="0") # Std.NONE
- tee: str = Field(default="0") # Std.NONE
- local_ranks_filter: Optional[set[int]] = Field(default=None)
-
- model_config = ConfigDict(extra="ignore")
-
- @field_serializer("redirects")
- def convert_redirects(self, redirects: str) -> Std | dict[int, Std]:
- return Std.from_str(redirects)
-
- @field_serializer("tee")
- def convert_tee(self, tee: str) -> Std | dict[int, Std]:
- return Std.from_str(tee)
-
-
- def delete_env_vars_with_prefix(prefix):
- to_delete = [] # List to keep track of variables to delete
-
- # Iterate over a list of all environment variable keys
- for var in os.environ:
- if var.startswith(prefix):
- to_delete.append(var)
-
- # Delete each of the variables collected
- for var in to_delete:
- del os.environ[var]
-
-
- def training_subprocess():
- """
- This function is called by the torch.distributed.launcher.api.elastic_launch
- It happens in a subprocess and is responsible for executing the user's function
-
- It is unrelated to the actual node execution, so any cataloging, run_log_store should be
- handled to match to main process.
-
- We have these variables to use:
-
- os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
- os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
- self._context.parameters_file or ""
- )
- os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
- os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
-
- """
- from runnable import PythonJob # noqa: F401
-
- command = os.environ.get("RUNNABLE_TORCH_COMMAND")
- assert command, "Command is not provided"
-
- run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
- parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
-
- process_run_id = (
- run_id
- + "-"
- + os.environ.get("RANK", "")
- + "-"
- + "".join(random.choices(string.ascii_lowercase, k=3))
- )
- os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
-
- # In this subprocess there shoould not be any RUNNABLE environment variables
- delete_env_vars_with_prefix("RUNNABLE_")
-
- module_name, func_name = get_module_and_attr_names(command)
- module = importlib.import_module(module_name)
-
- callable_obj = getattr(module, func_name)
-
- # The job runs with the default configuration
- # ALl the execution logs are stored in .catalog
- job = PythonJob(function=callable_obj)
-
- config_content = {
- "catalog": {"type": "file-system", "config": {"catalog_location": "proc_logs"}}
- }
-
- temp_config_file = Path("runnable-config.yaml")
- with open(str(temp_config_file), "w", encoding="utf-8") as config_file:
- yaml = YAML(typ="safe", pure=True)
- yaml.dump(config_content, config_file)
-
- job.execute(
- parameters_file=parameters_files,
- job_id=process_run_id,
- )
-
- # delete the temp config file
- temp_config_file.unlink()
-
- from runnable.context import run_context
-
- job_log = run_context.run_log_store.get_run_log_by_id(run_id=run_context.run_id)
-
- if job_log.status == defaults.FAIL:
- raise Exception(f"Job {process_run_id} failed")
@@ -1,76 +0,0 @@
- from enum import Enum
- from typing import Any, Optional
-
- from pydantic import BaseModel, ConfigDict, Field, computed_field
-
-
- class StartMethod(str, Enum):
- spawn = "spawn"
- fork = "fork"
- forkserver = "forkserver"
-
-
- ## The idea is the following:
- # Users can configure any of the options present in TorchConfig class.
- # The LaunchConfig class will be created from TorchConfig.
- # The LogSpecs is sent as a parameter to the launch config.
-
- ## NO idea of standalone and how to send it
-
-
- # The user sees this as part of the config of the node.
- # It is kept as similar as possible to torchrun
- class TorchConfig(BaseModel):
- model_config = ConfigDict(extra="forbid")
-
- # excluded as LaunchConfig requires min and max nodes
- nnodes: str = Field(default="1:1", exclude=True, description="min:max")
- nproc_per_node: int = Field(default=1, description="Number of processes per node")
-
- # will be used to create the log specs
- # But they are excluded from dump as logs specs is a class for LaunchConfig
- # from_str("0") -> Std.NONE
- # from_str("1") -> Std.OUT
- # from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
- log_dir: Optional[str] = Field(default="torch_logs", exclude=True)
- redirects: str = Field(default="0", exclude=True) # Std.NONE
- tee: str = Field(default="0", exclude=True) # Std.NONE
- local_ranks_filter: Optional[set[int]] = Field(default=None, exclude=True)
-
- role: str | None = Field(default=None)
-
- # run_id would be the run_id of the context
- # and sent at the creation of the LaunchConfig
-
- # This section is about the communication between nodes/processes
- rdzv_backend: str | None = Field(default="")
- rdzv_endpoint: str | None = Field(default="")
- rdzv_configs: dict[str, Any] = Field(default_factory=dict)
- rdzv_timeout: int | None = Field(default=None)
-
- max_restarts: int | None = Field(default=None)
- monitor_interval: float | None = Field(default=None)
- start_method: str | None = Field(default=StartMethod.spawn)
- log_line_prefix_template: str | None = Field(default=None)
- local_addr: Optional[str] = None
-
- # https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L753
- # master_addr: str | None = Field(default="localhost")
- # master_port: str | None = Field(default="29500")
- # training_script: str = Field(default="dummy_training_script")
- # training_script_args: str = Field(default="")
-
-
- class EasyTorchConfig(TorchConfig):
- model_config = ConfigDict(extra="ignore")
-
- # TODO: Validate min < max
- @computed_field # type: ignore
- @property
- def min_nodes(self) -> int:
- return int(self.nnodes.split(":")[0])
-
- @computed_field # type: ignore
- @property
- def max_nodes(self) -> int:
- return int(self.nnodes.split(":")[1])