runnable 0.32.2__py3-none-any.whl → 0.32.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
extensions/nodes/torch.py CHANGED
@@ -5,14 +5,14 @@ import random
5
5
  import string
6
6
  from datetime import datetime
7
7
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Any, Callable, Optional
8
+ from typing import Any, Callable, Optional
9
9
 
10
10
  from pydantic import BaseModel, ConfigDict, Field, field_serializer
11
11
 
12
12
  from extensions.nodes.torch_config import EasyTorchConfig, TorchConfig
13
13
  from runnable import PythonJob, datastore, defaults
14
14
  from runnable.datastore import StepLog
15
- from runnable.nodes import DistributedNode
15
+ from runnable.nodes import ExecutableNode
16
16
  from runnable.tasks import PythonTaskType, create_task
17
17
  from runnable.utils import TypeMapVariable
18
18
 
@@ -23,10 +23,7 @@ try:
23
23
  from torch.distributed.launcher.api import LaunchConfig, elastic_launch
24
24
  except ImportError:
25
25
  logger.exception("Torch is not installed. Please install torch first.")
26
-
27
- if TYPE_CHECKING:
28
- from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
29
- from torch.distributed.launcher.api import LaunchConfig, elastic_launch
26
+ raise Exception("Torch is not installed. Please install torch first.")
30
27
 
31
28
 
32
29
  def training_subprocess():
@@ -120,7 +117,7 @@ def delete_env_vars_with_prefix(prefix):
120
117
 
121
118
 
122
119
  # TODO: The design of this class is not final
123
- class TorchNode(DistributedNode, TorchConfig):
120
+ class TorchNode(ExecutableNode, TorchConfig):
124
121
  node_type: str = Field(default="torch", serialization_alias="type")
125
122
  executable: PythonTaskType = Field(exclude=True)
126
123
 
@@ -20,14 +20,10 @@ from pydantic import (
20
20
  from pydantic.alias_generators import to_camel
21
21
  from ruamel.yaml import YAML
22
22
 
23
- from extensions.nodes.nodes import (
24
- MapNode,
25
- ParallelNode,
26
- StubNode,
27
- SuccessNode,
28
- TaskNode,
29
- )
30
- from extensions.nodes.torch import TorchNode
23
+ from extensions.nodes.nodes import MapNode, ParallelNode, TaskNode
24
+
25
+ # TODO: Should be part of a wider refactor
26
+ # from extensions.nodes.torch import TorchNode
31
27
  from extensions.pipeline_executor import GenericPipelineExecutor
32
28
  from runnable import defaults, utils
33
29
  from runnable.defaults import TypeMapVariable
@@ -590,12 +586,7 @@ class ArgoExecutor(GenericPipelineExecutor):
590
586
  task_name: str,
591
587
  inputs: Optional[Inputs] = None,
592
588
  ) -> ContainerTemplate:
593
- assert (
594
- isinstance(node, TaskNode)
595
- or isinstance(node, StubNode)
596
- or isinstance(node, SuccessNode)
597
- or isinstance(node, TorchNode)
598
- )
589
+ assert node.node_type in ["task", "torch", "success", "stub", "fail"]
599
590
 
600
591
  node_override = None
601
592
  if hasattr(node, "overrides"):
@@ -658,7 +649,7 @@ class ArgoExecutor(GenericPipelineExecutor):
658
649
  def _set_env_vars_to_task(
659
650
  self, working_on: BaseNode, container_template: CoreContainerTemplate
660
651
  ):
661
- if not (isinstance(working_on, TaskNode) or isinstance(working_on, TorchNode)):
652
+ if working_on.node_type not in ["task", "torch"]:
662
653
  return
663
654
 
664
655
  global_envs: dict[str, str] = {}
@@ -667,7 +658,7 @@ class ArgoExecutor(GenericPipelineExecutor):
667
658
  env_var = cast(EnvVar, env_var)
668
659
  global_envs[env_var.name] = env_var.value
669
660
 
670
- override_key = working_on.overrides.get(self.service_name, "")
661
+ override_key = working_on.overrides.get(self.service_name, "") # type: ignore
671
662
  node_override = self.overrides.get(override_key, None)
672
663
 
673
664
  # Update the global envs with the node overrides
@@ -878,6 +869,8 @@ class ArgoExecutor(GenericPipelineExecutor):
878
869
  self._templates.append(composite_template)
879
870
 
880
871
  case "torch":
872
+ from extensions.nodes.torch import TorchNode
873
+
881
874
  assert isinstance(working_on, TorchNode)
882
875
  # TODO: Need to add multi-node functionality
883
876
  # Check notes on the torch node
extensions/tasks/torch.py CHANGED
@@ -5,7 +5,7 @@ import random
5
5
  import string
6
6
  from datetime import datetime
7
7
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Any, Optional
8
+ from typing import Any, Optional
9
9
 
10
10
  from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
11
11
  from ruamel.yaml import YAML
@@ -19,16 +19,15 @@ from runnable.utils import get_module_and_attr_names
19
19
 
20
20
  logger = logging.getLogger(defaults.LOGGER_NAME)
21
21
 
22
+ logger = logging.getLogger(defaults.LOGGER_NAME)
23
+
22
24
  try:
23
25
  from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
24
26
  from torch.distributed.launcher.api import LaunchConfig, elastic_launch
25
27
 
26
- except ImportError:
28
+ except ImportError as e:
27
29
  logger.exception("torch is not installed")
28
-
29
- if TYPE_CHECKING:
30
- from torch.distributed.elastic.multiprocessing.api import DefaultLogsSpecs, Std
31
- from torch.distributed.launcher.api import LaunchConfig, elastic_launch
30
+ raise Exception("torch is not installed") from e
32
31
 
33
32
 
34
33
  def get_min_max_nodes(nnodes: str) -> tuple[int, int]:
@@ -69,6 +68,7 @@ class TorchTaskType(BaseTaskType, TorchConfig):
69
68
  )
70
69
  )
71
70
  print("###", easy_torch_config)
71
+ print("###", easy_torch_config)
72
72
  launch_config = LaunchConfig(
73
73
  **easy_torch_config.model_dump(
74
74
  exclude_none=True,
@@ -96,6 +96,28 @@ class TorchTaskType(BaseTaskType, TorchConfig):
96
96
 
97
97
  _, max_nodes = get_min_max_nodes(self.nnodes)
98
98
 
99
+ if max_nodes > 1 and not is_execute:
100
+ executor = self._context.executor
101
+ executor.scale_up(self)
102
+ return StepAttempt(
103
+ status=defaults.SUCCESS,
104
+ start_time=str(datetime.now()),
105
+ end_time=str(datetime.now()),
106
+ attempt_number=1,
107
+ message="Triggered a scale up",
108
+ )
109
+
110
+ # The below should happen only if we are in the node that we want to execute
111
+ # For a single node, multi worker setup, this should be the entry point
112
+ # For a multi-node, we need to:
113
+ # - create a service config
114
+ # - Create a stateful set with number of nodes
115
+ # - Create a job to run the torch.distributed.launcher.api.elastic_launch on every node
116
+ # - the entry point to runnnable could be a way to trigger execution instead of scaling
117
+ is_execute = os.environ.get("RUNNABLE_TORCH_EXECUTE", "true") == "true"
118
+
119
+ _, max_nodes = get_min_max_nodes(self.nnodes)
120
+
99
121
  if max_nodes > 1 and not is_execute:
100
122
  executor = self._context.executor
101
123
  executor.scale_up(self)
@@ -109,6 +131,7 @@ class TorchTaskType(BaseTaskType, TorchConfig):
109
131
 
110
132
  launch_config = self._get_launch_config()
111
133
  print("###****", launch_config)
134
+ print("###****", launch_config)
112
135
  logger.info(f"launch_config: {launch_config}")
113
136
 
114
137
  # ENV variables are shared with the subprocess, use that as communication
runnable/nodes.py CHANGED
@@ -36,8 +36,8 @@ class BaseNode(ABC, BaseModel):
36
36
  name: str
37
37
  internal_name: str = Field(exclude=True)
38
38
  internal_branch_name: str = Field(default="", exclude=True)
39
+
39
40
  is_composite: bool = Field(default=False, exclude=True)
40
- is_distributed: bool = Field(default=False, exclude=True)
41
41
 
42
42
  @property
43
43
  def _context(self):
@@ -483,41 +483,6 @@ class CompositeNode(TraversalNode):
483
483
  )
484
484
 
485
485
 
486
- class DistributedNode(TraversalNode):
487
- """
488
- Use this node for distributed execution of tasks.
489
- eg: torch distributed, horovod, etc.
490
- """
491
-
492
- is_distributed: bool = True
493
- catalog: Optional[CatalogStructure] = Field(default=None)
494
- max_attempts: int = Field(default=1, ge=1)
495
-
496
- def _get_catalog_settings(self) -> Dict[str, Any]:
497
- """
498
- If the node defines a catalog settings, return it or None
499
-
500
- Returns:
501
- dict: catalog settings defined as per the node or None
502
- """
503
- if self.catalog:
504
- return self.catalog.model_dump()
505
- return {}
506
-
507
- def _get_max_attempts(self) -> int:
508
- return self.max_attempts
509
-
510
- def _get_branch_by_name(self, branch_name: str):
511
- raise exceptions.NodeMethodCallError(
512
- "This is an distributed node and does not have branches"
513
- )
514
-
515
- def execute_as_graph(self, map_variable: TypeMapVariable = None):
516
- raise exceptions.NodeMethodCallError(
517
- "This is an executable node and does not have a graph"
518
- )
519
-
520
-
521
486
  class TerminalNode(BaseNode):
522
487
  def _get_on_failure_node(self) -> str:
523
488
  return ""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: runnable
3
- Version: 0.32.2
3
+ Version: 0.32.3
4
4
  Summary: Add your description here
5
5
  Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
6
6
  License-File: LICENSE
@@ -16,11 +16,11 @@ extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqy
16
16
  extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
18
18
  extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
19
- extensions/nodes/torch.py,sha256=gydcRX5C7jEdPnxLsAQkpRD_by_0Lp4dFg96xDkRVW0,9510
19
+ extensions/nodes/torch.py,sha256=64DTjdPNSJ8vfMwUN9h9Ly5g9qj-Bga7LSGrfCAO0BY,9389
20
20
  extensions/nodes/torch_config.py,sha256=tO3sG2_fj8a6FmPZZllwKVx3WaRr4QmQYcACseg8YXM,2839
21
21
  extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
22
  extensions/pipeline_executor/__init__.py,sha256=wfigTL2T9OHrmE8b2Ydmb8h6hr-oF--Yc2FectC7WaY,24623
23
- extensions/pipeline_executor/argo.py,sha256=AEGSWVZulBL6EsvbVCaeBeTl2m_t5ymc6RFpMKhivis,37946
23
+ extensions/pipeline_executor/argo.py,sha256=lHM3TM_UnQc4I1ghkuYdeBLpyr4pBLg-Ubnaf55Zw54,37878
24
24
  extensions/pipeline_executor/local.py,sha256=6oWUJ6b6NvIkpeQJBoCT1hbfX4_6WCB4HzMgHZ4ik1A,1887
25
25
  extensions/pipeline_executor/local_container.py,sha256=3kZ2QCsrq_YjH9dcAz8v05knKShQ_JtbIU-IA_-G538,12724
26
26
  extensions/pipeline_executor/mocked.py,sha256=0sMmypuvstBIv9uQg-WAcPrF3oOFpeEXNi6N8Nzdnl0,5680
@@ -40,7 +40,7 @@ extensions/run_log_store/db/integration_FF.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeR
40
40
  extensions/secrets/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
41
41
  extensions/secrets/dotenv.py,sha256=nADHXI6KJ_LUYOIe5EbtYH-21OBebSNVr0Pjb1GlZ7w,1573
42
42
  extensions/secrets/pyproject.toml,sha256=mLJNImNcBlbLKHh-0ugVWT9V83R4RibyyYDtBCSqVF4,282
43
- extensions/tasks/torch.py,sha256=At2eMpJas4sUUjzJfPrEBGamG-k3MsxXU6Bou0h9BEs,9274
43
+ extensions/tasks/torch.py,sha256=oeXRkmuttFIAuBwH7-h4SOVXMDOZXX5mvqI2aFrR3Vo,10283
44
44
  extensions/tasks/torch_config.py,sha256=UjfMitT-TXASRDGR30I2vDRnyk7JQnR-5CsOVidjpSY,2833
45
45
  runnable/__init__.py,sha256=3ZKuvGEkY_zHVQlJtarXd4jkjICxjgnw-bbKN_5SiJI,691
46
46
  runnable/catalog.py,sha256=4msQxLhLKlsDDrHFnGauPYe-Or-q9g8_RYCn_4dpxaU,4466
@@ -53,15 +53,15 @@ runnable/exceptions.py,sha256=LFbp0-Qxg2PAMLEVt7w2whhBxSG-5pzUEv5qN-Rc4_c,3003
53
53
  runnable/executor.py,sha256=Jr9yJtSH7CzjXJLWx3VWIUAQblstuGqzpFtajv7d39M,15348
54
54
  runnable/graph.py,sha256=poQz5zcvq89ju_u5sYlunQLPbHnXTaUmjcvstPwvT4U,16536
55
55
  runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
56
- runnable/nodes.py,sha256=d1eLttMAcV7CTwTEqOuNwZqItANoLUkXJ73Xp-srlyI,17811
56
+ runnable/nodes.py,sha256=QGHMznriEz4AcmntHICBZKrDT6zbc7WD1sV0MgwK10c,16691
57
57
  runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
58
58
  runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
59
59
  runnable/sdk.py,sha256=hwsEGCCFSijm0DZwDJGHmV8jdMuSU_3Pf-vYoomWYHw,35084
60
60
  runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
61
61
  runnable/tasks.py,sha256=ABRhgiTY8F62pNlqJmVTDjwJwuzp8DqciUEOq1fpt1U,28989
62
62
  runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
63
- runnable-0.32.2.dist-info/METADATA,sha256=fcKKBj2v2AhRQFZ7ALqSdJrKF5r0Wg-QV6HVKqkBpRY,10168
64
- runnable-0.32.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
- runnable-0.32.2.dist-info/entry_points.txt,sha256=uWHbbOSj0jlG54tFHw377xKkfVbjWvb_1Y9L_LgjJ0Q,1925
66
- runnable-0.32.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
- runnable-0.32.2.dist-info/RECORD,,
63
+ runnable-0.32.3.dist-info/METADATA,sha256=l0jxi_VKPXblTa_Kd-fnqKophmB2e8x1Dj1HDbJV570,10168
64
+ runnable-0.32.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
+ runnable-0.32.3.dist-info/entry_points.txt,sha256=uWHbbOSj0jlG54tFHw377xKkfVbjWvb_1Y9L_LgjJ0Q,1925
66
+ runnable-0.32.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
+ runnable-0.32.3.dist-info/RECORD,,