runnable-0.32.0-py3-none-any.whl → runnable-0.32.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
extensions/job_executor/k8s.py CHANGED
@@ -329,7 +329,7 @@ class GenericK8sJobExecutor(GenericJobExecutor):
 
         logger.info(f"Submitting job: {job.__dict__}")
         if self.mock:
-            print(job.__dict__)
+            logger.info(job.__dict__)
             return
 
         try:
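
Note: in mock (dry-run) mode the resolved job spec now goes through the package logger instead of stdout, so it respects the host application's logging configuration. A minimal sketch of how a caller would surface it, assuming the package logs under the "runnable" namespace (the diff itself only shows the logger.info call):

    import logging

    # Assumed setup for illustration; raise the level to silence dry-run output.
    logging.basicConfig(level=logging.INFO)
    logging.getLogger("runnable").setLevel(logging.INFO)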
extensions/nodes/torch.py CHANGED
@@ -5,7 +5,7 @@ import random
 import string
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 from pydantic import BaseModel, ConfigDict, Field, field_serializer
 
@@ -21,11 +21,12 @@ logger = logging.getLogger(defaults.LOGGER_NAME)
 try:
     from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
     from torch.distributed.launcher.api import LaunchConfig, elastic_launch
-
 except ImportError:
-    raise ImportError("torch is not installed. Please install torch first.")
+    logger.exception("Torch is not installed. Please install torch first.")
 
-print("torch is installed")
+if TYPE_CHECKING:
+    from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
+    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
 
 def training_subprocess():
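
Note: the module now degrades gracefully when torch is absent: the ImportError is logged rather than raised, and the same names are re-imported under TYPE_CHECKING so static type checkers still resolve them without torch being importable at runtime. A generic sketch of the pattern, assuming nothing beyond the standard library and torch itself:

    import logging
    from typing import TYPE_CHECKING

    logger = logging.getLogger(__name__)

    try:
        from torch.distributed.launcher.api import LaunchConfig, elastic_launch
    except ImportError:
        # Import failure is recorded, not fatal; the module still loads.
        logger.exception("torch is not installed. Please install torch first.")

    if TYPE_CHECKING:
        # Evaluated only by type checkers, never at runtime.
        from torch.distributed.launcher.api import LaunchConfig, elastic_launch

One consequence of the pattern as written: when torch is missing, the names stay undefined at runtime, so any later reference fails with NameError rather than ImportError.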
extensions/tasks/torch.py CHANGED
@@ -5,7 +5,7 @@ import random
 import string
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
 from ruamel.yaml import YAML
@@ -17,15 +17,23 @@ from runnable.datastore import StepAttempt
 from runnable.tasks import BaseTaskType
 from runnable.utils import get_module_and_attr_names
 
+logger = logging.getLogger(defaults.LOGGER_NAME)
+
 try:
     from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
     from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
 except ImportError:
-    raise ImportError("torch is not installed. Please install torch first.")
+    logger.exception("torch is not installed")
 
+if TYPE_CHECKING:
+    from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
+    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
-logger = logging.getLogger(defaults.LOGGER_NAME)
+
+def get_min_max_nodes(nnodes: str) -> tuple[int, int]:
+    min_nodes, max_nodes = (int(x) for x in nnodes.split(":"))
+    return min_nodes, max_nodes
 
 
 class TorchTaskType(BaseTaskType, TorchConfig):
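
Note: get_min_max_nodes parses a torchrun-style elastic node range. As written it requires the colon-separated "min:max" form; a bare "2" would not unpack. For example:

    get_min_max_nodes("1:2")  # -> (1, 2)
    get_min_max_nodes("2:2")  # -> (2, 2), i.e. a fixed two-node run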
@@ -60,7 +68,7 @@ class TorchTaskType(BaseTaskType, TorchConfig):
                 exclude_none=True,
             )
         )
-
+        print("###", easy_torch_config)
         launch_config = LaunchConfig(
             **easy_torch_config.model_dump(
                 exclude_none=True,
@@ -77,7 +85,30 @@ class TorchTaskType(BaseTaskType, TorchConfig):
     ):
         assert map_variable is None, "map_variable is not supported for torch"
 
+        # The below should happen only if we are in the node that we want to execute
+        # For a single node, multi worker setup, this should be the entry point
+        # For a multi-node, we need to:
+        # - create a service config
+        # - Create a stateful set with number of nodes
+        # - Create a job to run the torch.distributed.launcher.api.elastic_launch on every node
+        # - the entry point to runnnable could be a way to trigger execution instead of scaling
+        is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
+
+        _, max_nodes = get_min_max_nodes(self.nnodes)
+
+        if max_nodes > 1 and not is_execute:
+            executor = self._context.executor
+            executor.scale_up(self)
+            return StepAttempt(
+                status=defaults.SUCCESS,
+                start_time=str(datetime.now()),
+                end_time=str(datetime.now()),
+                attempt_number=1,
+                message="Triggered a scale up",
+            )
+
         launch_config = self._get_launch_config()
+        print("###****", launch_config)
         logger.info(f"launch_config: {launch_config}")
 
         # ENV variables are shared with the subprocess, use that as communication
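
Note: the new block gates execution on the RUNNABLE_TORCH_EXECUTE environment variable. For a multi-node range (max_nodes > 1) with execution disabled, the task only asks the executor to scale up and records a successful attempt; the scaled-up replicas are then expected to run with execution enabled and fall through to the elastic_launch path. The gate itself is just:

    import os

    # Defaults to "true": a plain invocation executes directly; the
    # scale-up-only path must be opted into on the coordinating node, e.g.
    # by exporting RUNNABLE_TORCH_EXECUTE=false before triggering the task.
    is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"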
@@ -175,9 +206,6 @@ def training_subprocess():
             self._context.parameters_file or ""
         )
         os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
-        os.environ["RUNNABLE_TORCH_COPY_CONTENTS_TO"] = (
-            self._context.catalog_handler.compute_data_folder
-        )
         os.environ["RUNNABLE_TORCH_TORCH_LOGS"] = self.log_dir or ""
 
         """
extensions/tasks/torch_config.py CHANGED
@@ -43,7 +43,7 @@ class TorchConfig(BaseModel):
     # and sent at the creation of the LaunchConfig
 
     # This section is about the communication between nodes/processes
-    rdzv_backend: str | None = Field(default="static")
+    rdzv_backend: str | None = Field(default="")
     rdzv_endpoint: str | None = Field(default="")
     rdzv_configs: dict[str, Any] = Field(default_factory=dict)
     rdzv_timeout: int | None = Field(default=None)
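
Note: the rendezvous backend default drops from "static" to an empty string, consistent with the TorchTask comment later in this diff that the rdzv settings for multi-node runs are expected to be supplied via the environment rather than guessed by the SDK. A hypothetical explicit configuration, using field names from this class (values are illustrative only):

    config = TorchConfig(
        rdzv_backend="c10d",                # torch's recommended elastic backend
        rdzv_endpoint="master-host:29400",  # placeholder host:port
    )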
runnable/executor.py CHANGED
@@ -153,6 +153,14 @@ class BaseJobExecutor(BaseExecutor):
         """
         ...
 
+    # @abstractmethod
+    # def scale_up(self, job: BaseTaskType):
+    #     """
+    #     Scale up the job to run on max_nodes
+    #     This has to also call the entry point
+    #     """
+    #     ...
+
 
 # TODO: Consolidate execute_node, trigger_node_execution, _execute_node
 class BasePipelineExecutor(BaseExecutor):
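
Note: the scale_up hook is sketched but left commented out, so BaseJobExecutor subclasses are not yet required to implement it, even though extensions/tasks/torch.py above already calls executor.scale_up(self); an executor without the method would fail there with an AttributeError. A hypothetical concrete implementation, purely to illustrate the contract described in the commented docstring:

    class StatefulSetJobExecutor(BaseJobExecutor):  # hypothetical subclass
        def scale_up(self, job):
            # e.g. scale a StatefulSet to max_nodes replicas, then re-invoke
            # the entry point on each new pod so every node runs the task
            ...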
runnable/parameters.py CHANGED
@@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict
 from typing_extensions import Callable
 
 from runnable import defaults
-from runnable.datastore import JsonParameter
+from runnable.datastore import JsonParameter, ObjectParameter
 from runnable.defaults import TypeMapVariable
 from runnable.utils import remove_prefix
 
@@ -101,10 +101,13 @@ def filter_arguments_for_func(
             # default value is given in the function signature, nothing further to do.
             continue
 
+        param_value = params[name]
+
         if type(value.annotation) in [
             BaseModel,
             pydantic._internal._model_construction.ModelMetaclass,
-        ]:
+        ] and not isinstance(param_value, ObjectParameter):
+            # Even if the annotation is a pydantic model, it can be passed as an object parameter
             # We try to cast it as a pydantic model if asked
             named_param = params[name].get_value()
 
runnable/sdk.py CHANGED
@@ -5,7 +5,7 @@ import os
 import re
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 from pydantic import (
     BaseModel,
@@ -34,7 +34,7 @@ from extensions.nodes.nodes import (
     SuccessNode,
     TaskNode,
 )
-from extensions.nodes.torch_config import TorchConfig
+from extensions.tasks.torch_config import TorchConfig
 from runnable import console, defaults, entrypoints, exceptions, graph, utils
 from runnable.executor import BaseJobExecutor, BasePipelineExecutor
 from runnable.nodes import TraversalNode
@@ -46,8 +46,6 @@ logger = logging.getLogger(defaults.LOGGER_NAME)
 StepType = Union[
     "Stub", "PythonTask", "NotebookTask", "ShellTask", "Parallel", "Map", "TorchTask"
 ]
-if TYPE_CHECKING:
-    pass
 
 
 def pickled(name: str) -> TaskReturns:
@@ -192,6 +190,8 @@ class BaseTask(BaseTraversal):
 
 
 class TorchTask(BaseTask, TorchConfig):
+    # The user will not know the rnnz variables for multi node
+    # They should be overridden in the environment
     function: Callable = Field(exclude=True)
 
     @field_validator("returns", mode="before")
runnable/tasks.py CHANGED
@@ -760,7 +760,6 @@ def create_task(kwargs_for_init) -> BaseTaskType:
     """
     # The dictionary cannot be modified
 
-    print(kwargs_for_init)
    kwargs = kwargs_for_init.copy()
    command_type = kwargs.pop("command_type", defaults.COMMAND_TYPE)
 
runnable-0.32.0.dist-info/METADATA → runnable-0.32.2.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: runnable
-Version: 0.32.0
+Version: 0.32.2
 Summary: Add your description here
 Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
 License-File: LICENSE
runnable-0.32.0.dist-info/RECORD → runnable-0.32.2.dist-info/RECORD RENAMED
@@ -8,7 +8,7 @@ extensions/catalog/pyproject.toml,sha256=lLNxY6v04c8I5QK_zKw_E6sJTArSJRA_V-79kta
 extensions/catalog/s3.py,sha256=Sw5t8_kVRprn3uGGJCiHn7M9zw1CLaCOFj6YErtfG0o,287
 extensions/job_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/job_executor/__init__.py,sha256=VeLuYCcShCIYT0TNtAXfUF9tOk4ZHoLzdTEvbsz0spM,5870
-extensions/job_executor/k8s.py,sha256=0V7BL7ERmonVMgCsO-J57cxH__v8KomwukMwepH3qgs,16388
+extensions/job_executor/k8s.py,sha256=Jl0s3YryISx-SJIhDhyNskzlUlhy4ynBHEc9DfAXjAY,16394
 extensions/job_executor/k8s_job_spec.yaml,sha256=7aFpxHdO_p6Hkc3YxusUOuAQTD1Myu0yTPX9DrhxbOg,1158
 extensions/job_executor/local.py,sha256=3ZbCFXBvbLlMp10JTmQJJrjBKG2keHI6SH8hEvmHDkA,2230
 extensions/job_executor/local_container.py,sha256=1JcLJ0zrNSNHdubrSO9miN54iwvPLHqKMZ08aOC8WWo,6886
@@ -16,7 +16,7 @@ extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqy
 extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
 extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
-extensions/nodes/torch.py,sha256=h3x5931ePBNckeSXM3JFjSoUnxmIWvDyEpn1AI9TKaU,9347
+extensions/nodes/torch.py,sha256=gydcRX5C7jEdPnxLsAQkpRD_by_0Lp4dFg96xDkRVW0,9510
 extensions/nodes/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
 extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/pipeline_executor/__init__.py,sha256=wfigTL2T9OHrmE8b2Ydmb8h6hr-oF--Yc2FectC7WaY,24623
@@ -40,8 +40,8 @@ extensions/run_log_store/db/integration_FF.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeR
 extensions/secrets/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/secrets/dotenv.py,sha256=nADHXI6KJ_LUYOIe5EbtYH-21OBebSNVr0Pjb1GlZ7w,1573
 extensions/secrets/pyproject.toml,sha256=mLJNImNcBlbLKHh-0ugVWT9V83R4RibyyYDtBCSqVF4,282
-extensions/tasks/torch.py,sha256=R0J_Q6SRAW2Ii0XQbXaaBWTah8TYs4P_48j2M1bIXeA,7983
-extensions/tasks/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
+extensions/tasks/torch.py,sha256=At2eMpJas4sUUjzJfPrEBGamG-k3MsxXU6Bou0h9BEs,9274
+extensions/tasks/torch_config.py,sha256=UjfMitT-TXASRDGR30I2vDRnyk7JQnR-5CsOVidjpSY,2833
 runnable/__init__.py,sha256=3ZKuvGEkY_zHVQlJtarXd4jkjICxjgnw-bbKN_5SiJI,691
 runnable/catalog.py,sha256=4msQxLhLKlsDDrHFnGauPYe-Or-q9g8_RYCn_4dpxaU,4466
 runnable/cli.py,sha256=3BiKSj95h2Drn__YlchMPZ5rBMafuRb2OGIsVpbsO5Y,8788
@@ -50,18 +50,18 @@ runnable/datastore.py,sha256=ZobM1aVkgeUJ2fZYt63IFDsoNzObwc93hdByegS5YKQ,32396
 runnable/defaults.py,sha256=3o9IVGryyCE6PoQTOoaIaHHTbJGEzmdXMcwzOhwAYoI,3518
 runnable/entrypoints.py,sha256=1xCbWVUQLGmg5gkWnAVWFLAUf6j4avP9azX_vuGQUMY,18985
 runnable/exceptions.py,sha256=LFbp0-Qxg2PAMLEVt7w2whhBxSG-5pzUEv5qN-Rc4_c,3003
-runnable/executor.py,sha256=UOsYJ3NkTGw4FTR0iePX7AOJzY7vODhZ62aqrwVMO1c,15143
+runnable/executor.py,sha256=Jr9yJtSH7CzjXJLWx3VWIUAQblstuGqzpFtajv7d39M,15348
 runnable/graph.py,sha256=poQz5zcvq89ju_u5sYlunQLPbHnXTaUmjcvstPwvT4U,16536
 runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
 runnable/nodes.py,sha256=d1eLttMAcV7CTwTEqOuNwZqItANoLUkXJ73Xp-srlyI,17811
-runnable/parameters.py,sha256=sT3DNGczivP9z7r4Cp_brbudg1z4J-zjmvrq3ppIrVs,5089
+runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
 runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
-runnable/sdk.py,sha256=J1PyiHQD2v_0JaqHjY7xSaXwCUMi_mCNr70TsC-SFZU,35012
+runnable/sdk.py,sha256=hwsEGCCFSijm0DZwDJGHmV8jdMuSU_3Pf-vYoomWYHw,35084
 runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
-runnable/tasks.py,sha256=_A0pcTyOGQL-72AicOxracsrwfs2Vg0r4mQyxz3k6Iw,29016
+runnable/tasks.py,sha256=ABRhgiTY8F62pNlqJmVTDjwJwuzp8DqciUEOq1fpt1U,28989
 runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
-runnable-0.32.0.dist-info/METADATA,sha256=t44gRxxaRugnqaRY9gGwweGT0OLvo_inlC3jxrhP3sg,10168
-runnable-0.32.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-runnable-0.32.0.dist-info/entry_points.txt,sha256=uWHbbOSj0jlG54tFHw377xKkfVbjWvb_1Y9L_LgjJ0Q,1925
-runnable-0.32.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-runnable-0.32.0.dist-info/RECORD,,
+runnable-0.32.2.dist-info/METADATA,sha256=fcKKBj2v2AhRQFZ7ALqSdJrKF5r0Wg-QV6HVKqkBpRY,10168
+runnable-0.32.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+runnable-0.32.2.dist-info/entry_points.txt,sha256=uWHbbOSj0jlG54tFHw377xKkfVbjWvb_1Y9L_LgjJ0Q,1925
+runnable-0.32.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+runnable-0.32.2.dist-info/RECORD,,