runnable 0.28.7-py3-none-any.whl → 0.29.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
extensions/job_executor/k8s.py CHANGED
@@ -11,7 +11,7 @@ from rich import print
 
 from extensions.job_executor import GenericJobExecutor
 from runnable import console, defaults, utils
-from runnable.datastore import DataCatalog
+from runnable.datastore import DataCatalog, StepAttempt
 from runnable.tasks import BaseTaskType
 
 logger = logging.getLogger(defaults.NAME)
@@ -213,10 +213,12 @@ class GenericK8sJobExecutor(GenericJobExecutor):
         job_log = self._context.run_log_store.get_job_log(run_id=self._context.run_id)
         self.add_code_identities(job_log)
 
-        attempt_log = job.execute_command(
-            attempt_number=self.step_attempt_number,
-            mock=self.mock,
-        )
+        if not self.mock:
+            attempt_log = job.execute_command()
+        else:
+            attempt_log = StepAttempt(
+                status=defaults.SUCCESS,
+            )
 
         job_log.status = attempt_log.status
         job_log.attempts.append(attempt_log)
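Note: the hunk above replaces the old convention of passing mock and attempt_number into job.execute_command() with an executor-side short circuit: a mocked run never touches the task at all and simply records a successful attempt. The same pattern is applied in LocalJobExecutor and LocalContainerJobExecutor further down. A minimal runnable sketch of the new contract, using simplified stand-ins rather than the actual runnable classes:

from dataclasses import dataclass


@dataclass
class StepAttempt:  # simplified stand-in for runnable.datastore.StepAttempt
    status: str


class Job:  # simplified stand-in for a runnable job wrapper
    def execute_command(self) -> StepAttempt:
        print("running the real task")
        return StepAttempt(status="SUCCESS")


def run_job(job: Job, mock: bool) -> StepAttempt:
    # New contract: execute_command() takes no mock/attempt_number;
    # the executor decides whether the task runs at all.
    if not mock:
        return job.execute_command()
    return StepAttempt(status="SUCCESS")


print(run_job(Job(), mock=True).status)  # SUCCESS, without running the task

Keeping the mock decision in the executor means execute_command() implementations no longer need to know about mocking.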
@@ -455,10 +457,7 @@ class K8sJobExecutor(GenericK8sJobExecutor):
         job_log = self._context.run_log_store.get_job_log(run_id=self._context.run_id)
         self.add_code_identities(job_log)
 
-        attempt_log = job.execute_command(
-            attempt_number=self.step_attempt_number,
-            mock=self.mock,
-        )
+        attempt_log = job.execute_command()
 
         job_log.status = attempt_log.status
         job_log.attempts.append(attempt_log)
extensions/job_executor/local.py CHANGED
@@ -3,7 +3,7 @@ from typing import List, Optional
 
 from extensions.job_executor import GenericJobExecutor
 from runnable import console, defaults
-from runnable.datastore import DataCatalog
+from runnable.datastore import DataCatalog, StepAttempt
 from runnable.tasks import BaseTaskType
 
 logger = logging.getLogger(defaults.LOGGER_NAME)
@@ -39,10 +39,12 @@ class LocalJobExecutor(GenericJobExecutor):
         job_log = self._context.run_log_store.get_job_log(run_id=self._context.run_id)
         self.add_code_identities(job_log)
 
-        attempt_log = job.execute_command(
-            attempt_number=self.step_attempt_number,
-            mock=self.mock,
-        )
+        if not self.mock:
+            attempt_log = job.execute_command()
+        else:
+            attempt_log = StepAttempt(
+                status=defaults.SUCCESS,
+            )
 
         job_log.status = attempt_log.status
         job_log.attempts.append(attempt_log)
extensions/job_executor/local_container.py CHANGED
@@ -6,7 +6,7 @@ from pydantic import Field
 
 from extensions.job_executor import GenericJobExecutor
 from runnable import console, defaults, utils
-from runnable.datastore import DataCatalog
+from runnable.datastore import DataCatalog, StepAttempt
 from runnable.tasks import BaseTaskType
 
 logger = logging.getLogger(defaults.LOGGER_NAME)
@@ -54,10 +54,12 @@ class LocalContainerJobExecutor(GenericJobExecutor):
         job_log = self._context.run_log_store.get_job_log(run_id=self._context.run_id)
         self.add_code_identities(job_log)
 
-        attempt_log = job.execute_command(
-            attempt_number=self.step_attempt_number,
-            mock=self.mock,
-        )
+        if not self.mock:
+            attempt_log = job.execute_command()
+        else:
+            attempt_log = StepAttempt(
+                status=defaults.SUCCESS,
+            )
 
         job_log.status = attempt_log.status
         job_log.attempts.append(attempt_log)
extensions/nodes/nodes.py CHANGED
@@ -5,15 +5,9 @@ import sys
 from collections import OrderedDict
 from copy import deepcopy
 from datetime import datetime
-from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Union, cast
-
-from pydantic import (
-    ConfigDict,
-    Field,
-    ValidationInfo,
-    field_serializer,
-    field_validator,
-)
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+
+from pydantic import ConfigDict, Field, field_serializer
 
 from runnable import console, datastore, defaults, utils
 from runnable.datastore import (
@@ -73,7 +67,6 @@ class TaskNode(ExecutableNode):
         mock=False,
         map_variable: TypeMapVariable = None,
         attempt_number: int = 1,
-        **kwargs,
     ) -> StepLog:
         """
         All that we do in runnable is to come to this point where we actually execute the command.
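Note: this hunk and the matching ones below (FailNode, SuccessNode, StubNode, and the fan_out/execute_as_graph/fan_in methods of ParallelNode and MapNode) drop the catch-all **kwargs from node signatures. A small, self-contained illustration of what that tightening buys, using hypothetical functions rather than runnable's own:

def execute_old(mock=False, attempt_number=1, **kwargs):
    return attempt_number  # a misspelled keyword vanishes into kwargs


def execute_new(mock=False, attempt_number=1):
    return attempt_number


print(execute_old(attempt_numbr=2))  # prints 1: the typo was silently absorbed
try:
    execute_new(attempt_numbr=2)
except TypeError as e:
    print(e)  # ...got an unexpected keyword argument 'attempt_numbr'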
@@ -135,7 +128,6 @@ class FailNode(TerminalNode):
         mock=False,
         map_variable: TypeMapVariable = None,
         attempt_number: int = 1,
-        **kwargs,
     ) -> StepLog:
         """
         Execute the failure node.
@@ -199,7 +191,6 @@ class SuccessNode(TerminalNode):
         mock=False,
         map_variable: TypeMapVariable = None,
         attempt_number: int = 1,
-        **kwargs,
     ) -> StepLog:
         """
         Execute the success node.
@@ -255,7 +246,6 @@ class ParallelNode(CompositeNode):
 
     node_type: str = Field(default="parallel", serialization_alias="type")
     branches: Dict[str, Graph]
-    is_composite: bool = Field(default=True, exclude=True)
 
     def get_summary(self) -> Dict[str, Any]:
         summary = {
@@ -298,7 +288,7 @@ class ParallelNode(CompositeNode):
 
         raise Exception(f"Branch {branch_name} does not exist")
 
-    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
+    def fan_out(self, map_variable: TypeMapVariable = None):
         """
         The general fan out method for a node of type Parallel.
         This method assumes that the step log has already been created.
@@ -321,7 +311,7 @@ class ParallelNode(CompositeNode):
             branch_log.status = defaults.PROCESSING
             self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)
 
-    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
+    def execute_as_graph(self, map_variable: TypeMapVariable = None):
         """
         This function does the actual execution of the sub-branches of the parallel node.
 
@@ -342,16 +332,14 @@ class ParallelNode(CompositeNode):
             executor (Executor): The Executor as per the use config
             **kwargs: Optional kwargs passed around
         """
-        self.fan_out(map_variable=map_variable, **kwargs)
+        self.fan_out(map_variable=map_variable)
 
         for _, branch in self.branches.items():
-            self._context.executor.execute_graph(
-                branch, map_variable=map_variable, **kwargs
-            )
+            self._context.executor.execute_graph(branch, map_variable=map_variable)
 
-        self.fan_in(map_variable=map_variable, **kwargs)
+        self.fan_in(map_variable=map_variable)
 
-    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
+    def fan_in(self, map_variable: TypeMapVariable = None):
         """
         The general fan in method for a node of type Parallel.
 
@@ -412,7 +400,6 @@ class MapNode(CompositeNode):
     iterate_as: str
     reducer: Optional[str] = Field(default=None)
     branch: Graph
-    is_composite: bool = True
 
     def get_summary(self) -> Dict[str, Any]:
         summary = {
@@ -515,7 +502,7 @@ class MapNode(CompositeNode):
         """
         return self.branch
 
-    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
+    def fan_out(self, map_variable: TypeMapVariable = None):
         """
         The general method to fan out for a node of type map.
         This method assumes that the step log has already been created.
@@ -563,7 +550,7 @@ class MapNode(CompositeNode):
             parameters=raw_parameters, run_id=self._context.run_id
         )
 
-    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
+    def execute_as_graph(self, map_variable: TypeMapVariable = None):
         """
         This function does the actual execution of the branch of the map node.
 
@@ -607,19 +594,19 @@ class MapNode(CompositeNode):
         if not isinstance(iterate_on, list):
             raise Exception("Only list is allowed as a valid iterator type")
 
-        self.fan_out(map_variable=map_variable, **kwargs)
+        self.fan_out(map_variable=map_variable)
 
         for iter_variable in iterate_on:
             effective_map_variable = map_variable or OrderedDict()
             effective_map_variable[self.iterate_as] = iter_variable
 
             self._context.executor.execute_graph(
-                self.branch, map_variable=effective_map_variable, **kwargs
+                self.branch, map_variable=effective_map_variable
             )
 
-        self.fan_in(map_variable=map_variable, **kwargs)
+        self.fan_in(map_variable=map_variable)
 
-    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
+    def fan_in(self, map_variable: TypeMapVariable = None):
         """
         The general method to fan in for a node of type map.
 
@@ -714,172 +701,6 @@ class MapNode(CompositeNode):
         )
 
 
-class DagNode(CompositeNode):
-    """
-    A composite node that internally holds a dag.
-
-    The structure is generally:
-        DagNode:
-            dag_definition: A YAML file that holds the dag in 'dag' block
-
-    The config is expected to have a variable 'dag_definition'.
-    """
-
-    node_type: str = Field(default="dag", serialization_alias="type")
-    dag_definition: str
-    branch: Graph
-    is_composite: bool = True
-    internal_branch_name: Annotated[str, Field(validate_default=True)] = ""
-
-    def get_summary(self) -> Dict[str, Any]:
-        summary = {
-            "name": self.name,
-            "type": self.node_type,
-        }
-        return summary
-
-    @field_validator("internal_branch_name")
-    @classmethod
-    def validate_internal_branch_name(
-        cls, internal_branch_name: str, info: ValidationInfo
-    ):
-        internal_name = info.data["internal_name"]
-        return internal_name + "." + defaults.DAG_BRANCH_NAME
-
-    @field_validator("dag_definition")
-    @classmethod
-    def validate_dag_definition(cls, value):
-        if not value.endswith(".yaml"):  # TODO: Might have a problem with the SDK
-            raise ValueError("dag_definition must be a YAML file")
-        return value
-
-    @classmethod
-    def parse_from_config(cls, config: Dict[str, Any]) -> "DagNode":
-        internal_name = cast(str, config.get("internal_name"))
-
-        if "dag_definition" not in config:
-            raise Exception(f"No dag definition found in {config}")
-
-        dag_config = utils.load_yaml(config["dag_definition"])
-        if "dag" not in dag_config:
-            raise Exception(
-                "No DAG found in dag_definition, please provide it in dag block"
-            )
-
-        branch = create_graph(
-            dag_config["dag"],
-            internal_branch_name=internal_name + "." + defaults.DAG_BRANCH_NAME,
-        )
-
-        return cls(branch=branch, **config)
-
-    def _get_branch_by_name(self, branch_name: str):
-        """
-        Retrieve a branch by name.
-        The name is expected to follow a dot path convention.
-
-        Returns a Graph Object
-
-        Args:
-            branch_name (str): The name of the branch to retrieve
-
-        Raises:
-            Exception: If the branch_name is not 'dag'
-        """
-        if branch_name != self.internal_branch_name:
-            raise Exception(
-                f"Node of type {self.node_type} only allows a branch of name {defaults.DAG_BRANCH_NAME}"
-            )
-
-        return self.branch
-
-    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
-        """
-        The general method to fan out for a node of type dag.
-        The method assumes that the step log has already been created.
-
-        Args:
-            executor (BaseExecutor): The executor class as defined by the config
-            map_variable (dict, optional): _description_. Defaults to None.
-        """
-        effective_branch_name = self._resolve_map_placeholders(
-            self.internal_branch_name, map_variable=map_variable
-        )
-
-        branch_log = self._context.run_log_store.create_branch_log(
-            effective_branch_name
-        )
-        branch_log.status = defaults.PROCESSING
-        self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)
-
-    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
-        """
-        This function does the actual execution of the branch of the dag node.
-
-        From a design perspective, this function should not be called if the execution is 3rd party orchestrated.
-
-        The modes that render the job specifications, do not need to interact with this node at all
-        as they have their own internal mechanisms of handling sub dags.
-        If they do not, you can find a way using as-is nodes as hack nodes.
-
-        The actual logic is :
-        * We just execute the branch as with any other composite nodes
-        * The branch name is called 'dag'
-
-        The execution of a dag, could result in
-        * The dag being completely executed with a definite (fail, success) state in case of
-            local or local-container execution
-        * The dag being in a processing state with PROCESSING status in case of local-aws-batch
-
-        Only fail state is considered failure during this phase of execution.
-
-        Args:
-            executor (Executor): The Executor as per the use config
-            **kwargs: Optional kwargs passed around
-        """
-        self.fan_out(map_variable=map_variable, **kwargs)
-        self._context.executor.execute_graph(
-            self.branch, map_variable=map_variable, **kwargs
-        )
-        self.fan_in(map_variable=map_variable, **kwargs)
-
-    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
-        """
-        The general method to fan in for a node of type dag.
-
-        3rd party orchestrators should call this method to find the status of the step log.
-
-        Args:
-            executor (BaseExecutor): The executor class as defined by the config
-            map_variable (dict, optional): If the node is part of type dag. Defaults to None.
-        """
-        step_success_bool = True
-        effective_branch_name = self._resolve_map_placeholders(
-            self.internal_branch_name, map_variable=map_variable
-        )
-        effective_internal_name = self._resolve_map_placeholders(
-            self.internal_name, map_variable=map_variable
-        )
-
-        branch_log = self._context.run_log_store.get_branch_log(
-            effective_branch_name, self._context.run_id
-        )
-        if branch_log.status != defaults.SUCCESS:
-            step_success_bool = False
-
-        step_log = self._context.run_log_store.get_step_log(
-            effective_internal_name, self._context.run_id
-        )
-        step_log.status = defaults.PROCESSING
-
-        if step_success_bool:  # If none failed and nothing is waiting
-            step_log.status = defaults.SUCCESS
-        else:
-            step_log.status = defaults.FAIL
-
-        self._context.run_log_store.add_step_log(step_log, self._context.run_id)
-
-
 class StubNode(ExecutableNode):
     """
     Stub is a convenience design node.
@@ -926,7 +747,6 @@ class StubNode(ExecutableNode):
         mock=False,
         map_variable: TypeMapVariable = None,
         attempt_number: int = 1,
-        **kwargs,
     ) -> StepLog:
         """
         Do Nothing node.
extensions/nodes/torch.py ADDED
@@ -0,0 +1,169 @@
+import importlib
+import logging
+import os
+from datetime import datetime
+from typing import Any, Callable
+
+from pydantic import ConfigDict, Field
+
+from extensions.nodes.torch_config import TorchConfig
+from runnable import PythonJob, datastore, defaults
+from runnable.datastore import StepLog
+from runnable.nodes import DistributedNode
+from runnable.tasks import PythonTaskType, create_task
+from runnable.utils import TypeMapVariable
+
+logger = logging.getLogger(defaults.LOGGER_NAME)
+
+try:
+    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
+    from torch.distributed.run import config_from_args
+except ImportError:
+    raise ImportError("torch is not installed. Please install torch first.")
+
+print("torch is installed")
+
+
+def training_subprocess():
+    command = os.environ.get("RUNNABLE_TORCH_COMMAND")
+    run_id = os.environ.get("RUNNABLE_TORCH_RUN_ID", "")
+    parameters_files = os.environ.get("RUNNABLE_TORCH_PARAMETERS_FILES", "")
+    process_run_id = run_id + "-" + os.environ.get("RANK", "")
+
+    delete_env_vars_with_prefix("RUNNABLE_")
+
+    func = get_callable_from_dotted_path(command)
+    job = PythonJob(function=func)
+
+    job.execute(
+        parameters_file=parameters_files,
+        job_id=process_run_id,
+    )
+
+
+def get_callable_from_dotted_path(dotted_path) -> Callable:
+    try:
+        # Split the path into module path and callable object
+        module_path, callable_name = dotted_path.rsplit(".", 1)
+
+        # Import the module
+        module = importlib.import_module(module_path)
+
+        # Get the callable from the module
+        callable_obj = getattr(module, callable_name)
+
+        # Check if the object is callable
+        if not callable(callable_obj):
+            raise TypeError(f"The object {callable_name} is not callable.")
+
+        return callable_obj
+
+    except (ImportError, AttributeError, ValueError) as e:
+        raise ImportError(f"Could not import '{dotted_path}'.") from e
+
+
+def delete_env_vars_with_prefix(prefix):
+    to_delete = []  # List to keep track of variables to delete
+
+    # Iterate over a list of all environment variable keys
+    for var in os.environ:
+        if var.startswith(prefix):
+            to_delete.append(var)
+
+    # Delete each of the variables collected
+    for var in to_delete:
+        del os.environ[var]
+
+
+class TorchNode(DistributedNode, TorchConfig):
+    node_type: str = Field(default="torch", serialization_alias="type")
+    executable: PythonTaskType = Field(exclude=True)
+
+    # Similar to TaskNode
+    model_config = ConfigDict(extra="allow")
+
+    def get_summary(self) -> dict[str, Any]:
+        summary = {
+            "name": self.name,
+            "type": self.node_type,
+        }
+
+        return summary
+
+    @classmethod
+    def parse_from_config(cls, config: dict[str, Any]) -> "TorchNode":
+        task_config = {
+            k: v for k, v in config.items() if k not in TorchNode.model_fields.keys()
+        }
+        node_config = {
+            k: v for k, v in config.items() if k in TorchNode.model_fields.keys()
+        }
+
+        executable = create_task(task_config)
+
+        assert isinstance(executable, PythonTaskType)
+        return cls(executable=executable, **node_config, **task_config)
+
+    def get_launch_config(self) -> LaunchConfig:
+        config, _, _ = config_from_args(self)
+        config.run_id = self._context.run_id
+        return config
+
+    def execute(
+        self,
+        mock=False,
+        map_variable: TypeMapVariable = None,
+        attempt_number: int = 1,
+    ) -> StepLog:
+        assert map_variable is None, "TorchNode does not support map_variable"
+
+        step_log = self._context.run_log_store.get_step_log(
+            self._get_step_log_name(map_variable), self._context.run_id
+        )
+
+        # Attempt to call the function or elastic launch
+        launch_config = self.get_launch_config()
+        logger.info(f"launch_config: {launch_config}")
+
+        os.environ["RUNNABLE_TORCH_COMMAND"] = self.executable.command
+        os.environ["RUNNABLE_TORCH_PARAMETERS_FILES"] = (
+            self._context.parameters_file or ""
+        )
+        os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
+        launcher = elastic_launch(
+            launch_config,
+            training_subprocess,
+        )
+        try:
+            launcher()
+            attempt_log = datastore.StepAttempt(
+                status=defaults.SUCCESS,
+                start_time=str(datetime.now()),
+                end_time=str(datetime.now()),
+                attempt_number=attempt_number,
+            )
+        except Exception as e:
+            attempt_log = datastore.StepAttempt(
+                status=defaults.FAIL,
+                start_time=str(datetime.now()),
+                end_time=str(datetime.now()),
+                attempt_number=attempt_number,
+            )
+            logger.error(f"Error executing TorchNode: {e}")
+
+        delete_env_vars_with_prefix("RUNNABLE_TORCH")
+
+        logger.info(f"attempt_log: {attempt_log}")
+        logger.info(f"Step {self.name} completed with status: {attempt_log.status}")
+
+        step_log.status = attempt_log.status
+        step_log.attempts.append(attempt_log)
+
+        return step_log
+
+    # TODO: Not sure we need these methods
+    def fan_in(self, map_variable: dict[str, str | int | float] | None = None):
+        assert map_variable is None, "TorchNode does not support map_variable"
+
+    def fan_out(self, map_variable: dict[str, str | int | float] | None = None):
+        assert map_variable is None, "TorchNode does not support map_variable"
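Note: the new TorchNode coordinates with its elastic-launched workers entirely through RUNNABLE_TORCH_* environment variables: the parent serialises the command, run id and parameters file into the environment, and each worker resolves the dotted-path command back to a callable before scrubbing the variables. A stdlib-only sketch of that handoff, with math.sqrt standing in for a real training function:

import importlib
import os


def get_callable_from_dotted_path(dotted_path: str):
    # Same resolution strategy as the helper in the file above.
    module_path, name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), name)


# Parent process: serialise instructions into the environment.
os.environ["RUNNABLE_TORCH_COMMAND"] = "math.sqrt"  # stand-in command

# Worker process: read, resolve, clean up, run.
func = get_callable_from_dotted_path(os.environ["RUNNABLE_TORCH_COMMAND"])
del os.environ["RUNNABLE_TORCH_COMMAND"]
print(func(16))  # 4.0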
extensions/nodes/torch_config.py ADDED
@@ -0,0 +1,33 @@
+from pydantic import BaseModel, Field
+
+
+class TorchConfig(BaseModel):
+    nnodes: str = Field(default="1:1")
+    nproc_per_node: int = Field(default=4)
+
+    rdzv_backend: str = Field(default="static")
+    rdzv_endpoint: str = Field(default="")
+    rdzv_id: str | None = Field(default=None)
+    rdzv_conf: str = Field(default="")
+
+    max_restarts: int = Field(default=3)
+    monitor_interval: float = Field(default=0.1)
+    start_method: str = Field(default="spawn")
+    role: str = Field(default="default_role")
+    log_dir: str = Field(default="torch_logs")
+    redirects: str = Field(default="1")
+    tee: str = Field(default="1")
+    master_addr: str = Field(default="localhost")
+    master_port: str = Field(default="29500")
+    training_script: str = Field(default="dummy_training_script")
+    training_script_args: str = Field(default="")
+
+    # Optional fields
+    local_ranks_filter: str = Field(default="")
+    node_rank: int = Field(default=0)
+    local_addr: str | None = Field(default=None)
+    logs_specs: str | None = Field(default=None)
+    standalone: bool = Field(default=False)
+    module: bool = Field(default=False)
+    no_python: bool = Field(default=False)
+    run_path: bool = Field(default=False)
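Note: TorchConfig mirrors the torchrun argument surface as plain pydantic fields, and TorchNode inherits it so an instance can be handed to torch's config_from_args. A usage sketch, assuming the module is importable as extensions.nodes.torch_config (the path used by the torch node file above) and pydantic v2:

from extensions.nodes.torch_config import TorchConfig

config = TorchConfig(nnodes="1:2", nproc_per_node=2)  # override two defaults
print(config.model_dump(exclude_defaults=True))
# -> {'nnodes': '1:2', 'nproc_per_node': 2}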