runnable 0.28.7__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -198,7 +198,6 @@ class GenericPipelineExecutor(BasePipelineExecutor):
198
198
  node: BaseNode,
199
199
  map_variable: TypeMapVariable = None,
200
200
  mock: bool = False,
201
- **kwargs,
202
201
  ):
203
202
  """
204
203
  This is the entry point when we do the actual execution of the function.
@@ -232,7 +231,6 @@ class GenericPipelineExecutor(BasePipelineExecutor):
232
231
  map_variable=map_variable,
233
232
  attempt_number=self.step_attempt_number,
234
233
  mock=mock,
235
- **kwargs,
236
234
  )
237
235
 
238
236
  data_catalogs_put: Optional[List[DataCatalog]] = self._sync_catalog(stage="put")
@@ -248,7 +246,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
248
246
 
249
247
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
250
248
 
251
- def add_code_identities(self, node: BaseNode, step_log: StepLog, **kwargs):
249
+ def add_code_identities(self, node: BaseNode, step_log: StepLog):
252
250
  """
253
251
  Add code identities specific to the implementation.
254
252
 
@@ -260,9 +258,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
260
258
  """
261
259
  step_log.code_identities.append(utils.get_git_code_identity())
262
260
 
263
- def execute_from_graph(
264
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
265
- ):
261
+ def execute_from_graph(self, node: BaseNode, map_variable: TypeMapVariable = None):
266
262
  """
267
263
  This is the entry point from the graph execution.
268
264
 
@@ -303,12 +299,12 @@ class GenericPipelineExecutor(BasePipelineExecutor):
303
299
  # Add the step log to the database as per the situation.
304
300
  # If its a terminal node, complete it now
305
301
  if node.node_type in ["success", "fail"]:
306
- self._execute_node(node, map_variable=map_variable, **kwargs)
302
+ self._execute_node(node, map_variable=map_variable)
307
303
  return
308
304
 
309
305
  # We call an internal function to iterate the sub graphs and execute them
310
306
  if node.is_composite:
311
- node.execute_as_graph(map_variable=map_variable, **kwargs)
307
+ node.execute_as_graph(map_variable=map_variable)
312
308
  return
313
309
 
314
310
  task_console.export_text(clear=True)
@@ -317,10 +313,10 @@ class GenericPipelineExecutor(BasePipelineExecutor):
317
313
  console.print(
318
314
  f":runner: Executing the node {task_name} ... ", style="bold color(208)"
319
315
  )
320
- self.trigger_node_execution(node=node, map_variable=map_variable, **kwargs)
316
+ self.trigger_node_execution(node=node, map_variable=map_variable)
321
317
 
322
318
  def trigger_node_execution(
323
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
319
+ self, node: BaseNode, map_variable: TypeMapVariable = None
324
320
  ):
325
321
  """
326
322
  Call this method only if we are responsible for traversing the graph via
@@ -376,7 +372,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
376
372
 
377
373
  return step_log.status, next_node_name
378
374
 
379
- def execute_graph(self, dag: Graph, map_variable: TypeMapVariable = None, **kwargs):
375
+ def execute_graph(self, dag: Graph, map_variable: TypeMapVariable = None):
380
376
  """
381
377
  The parallelization is controlled by the nodes and not by this function.
382
378
 
@@ -430,7 +426,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
430
426
  )
431
427
 
432
428
  try:
433
- self.execute_from_graph(working_on, map_variable=map_variable, **kwargs)
429
+ self.execute_from_graph(working_on, map_variable=map_variable)
434
430
  status, next_node_name = self._get_status_and_next_node_name(
435
431
  current_node=working_on, dag=dag, map_variable=map_variable
436
432
  )
@@ -593,7 +589,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
593
589
  step_log.status = defaults.PROCESSING
594
590
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
595
591
 
596
- node.fan_out(executor=self, map_variable=map_variable)
592
+ node.fan_out(map_variable=map_variable)
597
593
 
598
594
  def fan_in(self, node: BaseNode, map_variable: TypeMapVariable = None):
599
595
  """
@@ -614,7 +610,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
614
610
  map_variable (dict, optional): If the node is of a map state. Defaults to None.
615
611
 
616
612
  """
617
- node.fan_in(executor=self, map_variable=map_variable)
613
+ node.fan_in(map_variable=map_variable)
618
614
 
619
615
  step_log = self._context.run_log_store.get_step_log(
620
616
  node._get_step_log_name(map_variable=map_variable), self._context.run_id
@@ -810,7 +810,6 @@ class ArgoExecutor(GenericPipelineExecutor):
810
810
  self,
811
811
  dag: Graph,
812
812
  map_variable: dict[str, str | int | float] | None = None,
813
- **kwargs,
814
813
  ):
815
814
  # All the arguments set at the spec level can be referred as "{{workflow.parameters.*}}"
816
815
  # We want to use that functionality to override the parameters at the task level
@@ -886,7 +885,6 @@ class ArgoExecutor(GenericPipelineExecutor):
886
885
  self,
887
886
  node: BaseNode,
888
887
  map_variable: dict[str, str | int | float] | None = None,
889
- **kwargs,
890
888
  ):
891
889
  error_on_existing_run_id = os.environ.get("error_on_existing_run_id", "false")
892
890
  exists_ok = error_on_existing_run_id == "false"
@@ -904,7 +902,7 @@ class ArgoExecutor(GenericPipelineExecutor):
904
902
  step_log.status = defaults.PROCESSING
905
903
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
906
904
 
907
- self._execute_node(node=node, map_variable=map_variable, **kwargs)
905
+ self._execute_node(node=node, map_variable=map_variable)
908
906
 
909
907
  # Raise exception if the step failed
910
908
  step_log = self._context.run_log_store.get_step_log(
@@ -29,16 +29,14 @@ class LocalExecutor(GenericPipelineExecutor):
29
29
 
30
30
  _is_local: bool = PrivateAttr(default=True)
31
31
 
32
- def execute_from_graph(
33
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
34
- ):
32
+ def execute_from_graph(self, node: BaseNode, map_variable: TypeMapVariable = None):
35
33
  if not self.object_serialisation:
36
34
  self._context.object_serialisation = False
37
35
 
38
- super().execute_from_graph(node=node, map_variable=map_variable, **kwargs)
36
+ super().execute_from_graph(node=node, map_variable=map_variable)
39
37
 
40
38
  def trigger_node_execution(
41
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
39
+ self, node: BaseNode, map_variable: TypeMapVariable = None
42
40
  ):
43
41
  """
44
42
  In this mode of execution, we prepare for the node execution and execute the node
@@ -47,11 +45,9 @@ class LocalExecutor(GenericPipelineExecutor):
47
45
  node (BaseNode): [description]
48
46
  map_variable (TypeMapVariable, optional): [description]. Defaults to None.
49
47
  """
50
- self.execute_node(node=node, map_variable=map_variable, **kwargs)
48
+ self.execute_node(node=node, map_variable=map_variable)
51
49
 
52
- def execute_node(
53
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
54
- ):
50
+ def execute_node(self, node: BaseNode, map_variable: TypeMapVariable = None):
55
51
  """
56
52
  For local execution, we just execute the node.
57
53
 
@@ -59,4 +55,4 @@ class LocalExecutor(GenericPipelineExecutor):
59
55
  node (BaseNode): _description_
60
56
  map_variable (dict[str, str], optional): _description_. Defaults to None.
61
57
  """
62
- self._execute_node(node=node, map_variable=map_variable, **kwargs)
58
+ self._execute_node(node=node, map_variable=map_variable)
@@ -59,7 +59,7 @@ class LocalContainerExecutor(GenericPipelineExecutor):
59
59
  _container_secrets_location = "/tmp/dotenv"
60
60
  _volumes: Dict[str, Dict[str, str]] = {}
61
61
 
62
- def add_code_identities(self, node: BaseNode, step_log: StepLog, **kwargs):
62
+ def add_code_identities(self, node: BaseNode, step_log: StepLog):
63
63
  """
64
64
  Call the Base class to add the git code identity and add docker identity
65
65
 
@@ -86,18 +86,18 @@ class LocalContainerExecutor(GenericPipelineExecutor):
86
86
  code_id.code_identifier_url = "local docker host"
87
87
  step_log.code_identities.append(code_id)
88
88
 
89
- def execute_node(
90
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
91
- ):
89
+ def execute_node(self, node: BaseNode, map_variable: TypeMapVariable = None):
92
90
  """
93
91
  We are already in the container, we just execute the node.
94
92
  The node is already prepared for execution.
95
93
  """
96
94
  self._use_volumes()
97
- return self._execute_node(node, map_variable, **kwargs)
95
+ return self._execute_node(node, map_variable)
98
96
 
99
97
  def execute_from_graph(
100
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
98
+ self,
99
+ node: BaseNode,
100
+ map_variable: TypeMapVariable = None,
101
101
  ):
102
102
  """
103
103
  This is the entry point from the graph execution.
@@ -139,12 +139,12 @@ class LocalContainerExecutor(GenericPipelineExecutor):
139
139
  # Add the step log to the database as per the situation.
140
140
  # If its a terminal node, complete it now
141
141
  if node.node_type in ["success", "fail"]:
142
- self._execute_node(node, map_variable=map_variable, **kwargs)
142
+ self._execute_node(node, map_variable=map_variable)
143
143
  return
144
144
 
145
145
  # We call an internal function to iterate the sub graphs and execute them
146
146
  if node.is_composite:
147
- node.execute_as_graph(map_variable=map_variable, **kwargs)
147
+ node.execute_as_graph(map_variable=map_variable)
148
148
  return
149
149
 
150
150
  task_console.export_text(clear=True)
@@ -153,10 +153,10 @@ class LocalContainerExecutor(GenericPipelineExecutor):
153
153
  console.print(
154
154
  f":runner: Executing the node {task_name} ... ", style="bold color(208)"
155
155
  )
156
- self.trigger_node_execution(node=node, map_variable=map_variable, **kwargs)
156
+ self.trigger_node_execution(node=node, map_variable=map_variable)
157
157
 
158
158
  def trigger_node_execution(
159
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
159
+ self, node: BaseNode, map_variable: TypeMapVariable = None
160
160
  ):
161
161
  """
162
162
  We come into this step via execute from graph, use trigger job to spin up the container.
@@ -181,7 +181,6 @@ class LocalContainerExecutor(GenericPipelineExecutor):
181
181
  command=command,
182
182
  map_variable=map_variable,
183
183
  auto_remove_container=auto_remove_container,
184
- **kwargs,
185
184
  )
186
185
 
187
186
  step_log = self._context.run_log_store.get_step_log(
@@ -203,7 +202,6 @@ class LocalContainerExecutor(GenericPipelineExecutor):
203
202
  command: str,
204
203
  map_variable: TypeMapVariable = None,
205
204
  auto_remove_container: bool = True,
206
- **kwargs,
207
205
  ):
208
206
  """
209
207
  During the flow run, we have to spin up a container with the docker image mentioned
@@ -36,9 +36,7 @@ class MockedExecutor(GenericPipelineExecutor):
36
36
  def _context(self):
37
37
  return context.run_context
38
38
 
39
- def execute_from_graph(
40
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
41
- ):
39
+ def execute_from_graph(self, node: BaseNode, map_variable: TypeMapVariable = None):
42
40
  """
43
41
  This is the entry point from the graph execution.
44
42
 
@@ -80,18 +78,18 @@ class MockedExecutor(GenericPipelineExecutor):
80
78
  # If its a terminal node, complete it now
81
79
  if node.node_type in ["success", "fail"]:
82
80
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
83
- self._execute_node(node, map_variable=map_variable, **kwargs)
81
+ self._execute_node(node, map_variable=map_variable)
84
82
  return
85
83
 
86
84
  # We call an internal function to iterate the sub graphs and execute them
87
85
  if node.is_composite:
88
86
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
89
- node.execute_as_graph(map_variable=map_variable, **kwargs)
87
+ node.execute_as_graph(map_variable=map_variable)
90
88
  return
91
89
 
92
90
  if node.name not in self.patches:
93
91
  # node is not patched, so mock it
94
- self._execute_node(node, map_variable=map_variable, mock=True, **kwargs)
92
+ self._execute_node(node, map_variable=map_variable, mock=True)
95
93
  else:
96
94
  # node is patched
97
95
  # command as the patch value
@@ -103,9 +101,7 @@ class MockedExecutor(GenericPipelineExecutor):
103
101
  node_name=node.name,
104
102
  )
105
103
  node_to_send.executable = executable
106
- self._execute_node(
107
- node_to_send, map_variable=map_variable, mock=False, **kwargs
108
- )
104
+ self._execute_node(node_to_send, map_variable=map_variable, mock=False)
109
105
 
110
106
  def _resolve_executor_config(self, node: BaseNode):
111
107
  """
@@ -144,9 +140,7 @@ class MockedExecutor(GenericPipelineExecutor):
144
140
 
145
141
  return effective_node_config
146
142
 
147
- def execute_node(
148
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
149
- ):
143
+ def execute_node(self, node: BaseNode, map_variable: TypeMapVariable = None):
150
144
  """
151
145
  The entry point for all executors apart from local.
152
146
  We have already prepared for node execution.
@@ -63,9 +63,7 @@ class RetryExecutor(GenericPipelineExecutor):
63
63
  # Should the parameters be copied from previous execution
64
64
  # self._set_up_for_re_run(params=params)
65
65
 
66
- def execute_from_graph(
67
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
68
- ):
66
+ def execute_from_graph(self, node: BaseNode, map_variable: TypeMapVariable = None):
69
67
  """
70
68
  This is the entry point from the graph execution.
71
69
 
@@ -103,7 +101,7 @@ class RetryExecutor(GenericPipelineExecutor):
103
101
  # If its a terminal node, complete it now
104
102
  if node.node_type in ["success", "fail"]:
105
103
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
106
- self._execute_node(node, map_variable=map_variable, **kwargs)
104
+ self._execute_node(node, map_variable=map_variable)
107
105
  return
108
106
 
109
107
  # In retry step
@@ -118,12 +116,12 @@ class RetryExecutor(GenericPipelineExecutor):
118
116
  # We call an internal function to iterate the sub graphs and execute them
119
117
  if node.is_composite:
120
118
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
121
- node.execute_as_graph(map_variable=map_variable, **kwargs)
119
+ node.execute_as_graph(map_variable=map_variable)
122
120
  return
123
121
 
124
122
  # Executor specific way to trigger a job
125
123
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
126
- self.execute_node(node=node, map_variable=map_variable, **kwargs)
124
+ self.execute_node(node=node, map_variable=map_variable)
127
125
 
128
126
  def _is_step_eligible_for_rerun(
129
127
  self, node: BaseNode, map_variable: TypeMapVariable = None
@@ -174,7 +172,5 @@ class RetryExecutor(GenericPipelineExecutor):
174
172
  self._restart_initiated = True
175
173
  return True
176
174
 
177
- def execute_node(
178
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
179
- ):
180
- self._execute_node(node, map_variable=map_variable, **kwargs)
175
+ def execute_node(self, node: BaseNode, map_variable: TypeMapVariable = None):
176
+ self._execute_node(node, map_variable=map_variable)
@@ -318,7 +318,6 @@ class ChunkedRunLogStore(BaseRunLogStore):
318
318
  tag: str = "",
319
319
  original_run_id: str = "",
320
320
  status: str = defaults.CREATED,
321
- **kwargs,
322
321
  ):
323
322
  """
324
323
  Creates a Run Log object by using the config
@@ -549,7 +548,7 @@ class ChunkedRunLogStore(BaseRunLogStore):
549
548
  )
550
549
 
551
550
  def get_branch_log(
552
- self, internal_branch_name: str, run_id: str, **kwargs
551
+ self, internal_branch_name: str, run_id: str
553
552
  ) -> Union[BranchLog, RunLog]:
554
553
  """
555
554
  Returns the branch log by the internal branch name for the run id
@@ -37,7 +37,7 @@ class DotEnvSecrets(BaseSecrets):
37
37
  """
38
38
  self.secrets = dotenv_values(self.secrets_location)
39
39
 
40
- def get(self, name: str = "", **kwargs) -> str:
40
+ def get(self, name: str = "") -> str:
41
41
  """
42
42
  Get a secret of name from the secrets file.
43
43
 
@@ -0,0 +1,52 @@
1
+ from typing import List, Optional
2
+
3
+ from pydantic import Field, field_validator
4
+
5
+ from runnable import defaults
6
+ from runnable.datastore import StepAttempt
7
+ from runnable.defaults import TypeMapVariable
8
+ from runnable.tasks import BaseTaskType
9
+
10
+
11
+ def run_torch_task(
12
+ rank: int = 1,
13
+ world_size: int = 1,
14
+ entrypoint: str = "some function",
15
+ catalog: Optional[dict[str, List[str]]] = None,
16
+ task_returns: Optional[List[str]] = None,
17
+ secrets: Optional[list[str]] = None,
18
+ ):
19
+ # Entry point that creates a python job using simpler python types
20
+ # and executes them. The run_id for the job is set to be run_id_rank.
21
+ # Since the configuration file is passed as an environment variable,
22
+ # The job will use the configuration file to get the required information.
23
+
24
+ # In pseudocode, the following is done:
25
+ # Create the catalog object
26
+ # Create the secrets and other objects required for the PythonJob
27
+ # Init the process group using:
28
+ # https://github.com/pytorch/examples/blob/main/imagenet/main.py#L140
29
+ # Execute the job; it is expected to read the environment variables
30
+ # to identify the rank, or accept them as parameters in its signature.
31
+ # Once the job is executed, we destroy the process group
32
+ pass
33
+
34
+
35
+ class TorchTaskType(BaseTaskType):
36
+ task_type: str = Field(default="torch", serialization_alias="command_type")
37
+ command: str
38
+ num_gpus: int = Field(default=1, description="Number of GPUs to use")
39
+
40
+ @field_validator("num_gpus")
41
+ @classmethod
42
+ def check_if_cuda_is_available(cls, num_gpus: int) -> int:
43
+ # Import torch and check if cuda is available
44
+ # validate if the number of gpus is less than or equal to available gpus
45
+ return num_gpus
46
+
47
+ def execute_command(
48
+ self,
49
+ map_variable: TypeMapVariable = None,
50
+ ) -> StepAttempt:
51
+ # We have to spawn here
52
+ return StepAttempt(attempt_number=1, status=defaults.SUCCESS)
runnable/__init__.py CHANGED
@@ -31,6 +31,7 @@ from runnable.sdk import ( # noqa
31
31
  ShellTask,
32
32
  Stub,
33
33
  Success,
34
+ Torch,
34
35
  metric,
35
36
  pickled,
36
37
  )
runnable/entrypoints.py CHANGED
@@ -174,7 +174,7 @@ def set_pipeline_spec_from_yaml(run_context: context.Context, pipeline_file: str
174
174
  def set_pipeline_spec_from_python(run_context: context.Context, python_module: str):
175
175
  # Call the SDK to get the dag
176
176
  # Import the module and call the function to get the dag
177
- module_file = python_module.strip(".py")
177
+ module_file = python_module.rstrip(".py")
178
178
  module, func = utils.get_module_and_attr_names(module_file)
179
179
  sys.path.insert(0, os.getcwd()) # Need to add the current directory to path
180
180
  imported_module = importlib.import_module(module)
@@ -429,7 +429,7 @@ def set_job_spec_from_yaml(run_context: context.Context, job_definition_file: st
429
429
 
430
430
  def set_job_spec_from_python(run_context: context.Context, python_module: str):
431
431
  # Import the module and call the function to get the task
432
- module_file = python_module.strip(".py")
432
+ module_file = python_module.rstrip(".py")
433
433
  module, func = utils.get_module_and_attr_names(module_file)
434
434
  sys.path.insert(0, os.getcwd()) # Need to add the current directory to path
435
435
  imported_module = importlib.import_module(module)
runnable/executor.py CHANGED
@@ -107,7 +107,7 @@ class BaseJobExecutor(BaseExecutor):
107
107
  ...
108
108
 
109
109
  @abstractmethod
110
- def add_code_identities(self, job_log: JobLog, **kwargs):
110
+ def add_code_identities(self, job_log: JobLog):
111
111
  """
112
112
  Add code identities specific to the implementation.
113
113
 
@@ -161,7 +161,7 @@ class BasePipelineExecutor(BaseExecutor):
161
161
  _context_node: Optional[BaseNode] = PrivateAttr(default=None)
162
162
 
163
163
  @abstractmethod
164
- def add_code_identities(self, node: BaseNode, step_log: StepLog, **kwargs):
164
+ def add_code_identities(self, node: BaseNode, step_log: StepLog):
165
165
  """
166
166
  Add code identities specific to the implementation.
167
167
 
@@ -204,7 +204,6 @@ class BasePipelineExecutor(BaseExecutor):
204
204
  node: BaseNode,
205
205
  map_variable: TypeMapVariable = None,
206
206
  mock: bool = False,
207
- **kwargs,
208
207
  ):
209
208
  """
210
209
  This is the entry point when we do the actual execution of the function.
@@ -227,9 +226,7 @@ class BasePipelineExecutor(BaseExecutor):
227
226
  ...
228
227
 
229
228
  @abstractmethod
230
- def execute_node(
231
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
232
- ):
229
+ def execute_node(self, node: BaseNode, map_variable: TypeMapVariable = None):
233
230
  """
234
231
  The entry point for all executors apart from local.
235
232
  We have already prepared for node execution.
@@ -244,9 +241,7 @@ class BasePipelineExecutor(BaseExecutor):
244
241
  ...
245
242
 
246
243
  @abstractmethod
247
- def execute_from_graph(
248
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
249
- ):
244
+ def execute_from_graph(self, node: BaseNode, map_variable: TypeMapVariable = None):
250
245
  """
251
246
  This is the entry point from the graph execution.
252
247
 
@@ -294,7 +289,7 @@ class BasePipelineExecutor(BaseExecutor):
294
289
  ...
295
290
 
296
291
  @abstractmethod
297
- def execute_graph(self, dag: Graph, map_variable: TypeMapVariable = None, **kwargs):
292
+ def execute_graph(self, dag: Graph, map_variable: TypeMapVariable = None):
298
293
  """
299
294
  The parallelization is controlled by the nodes and not by this function.
300
295
 
@@ -395,7 +390,7 @@ class BasePipelineExecutor(BaseExecutor):
395
390
 
396
391
  @abstractmethod
397
392
  def trigger_node_execution(
398
- self, node: BaseNode, map_variable: TypeMapVariable = None, **kwargs
393
+ self, node: BaseNode, map_variable: TypeMapVariable = None
399
394
  ):
400
395
  """
401
396
  Executor specific way of triggering jobs when runnable does both traversal and execution
runnable/nodes.py CHANGED
@@ -37,6 +37,7 @@ class BaseNode(ABC, BaseModel):
37
37
  internal_name: str = Field(exclude=True)
38
38
  internal_branch_name: str = Field(default="", exclude=True)
39
39
  is_composite: bool = Field(default=False, exclude=True)
40
+ is_distributed: bool = Field(default=False, exclude=True)
40
41
 
41
42
  @property
42
43
  def _context(self):
@@ -280,7 +281,6 @@ class BaseNode(ABC, BaseModel):
280
281
  mock=False,
281
282
  map_variable: TypeMapVariable = None,
282
283
  attempt_number: int = 1,
283
- **kwargs,
284
284
  ) -> StepLog:
285
285
  """
286
286
  The actual function that does the execution of the command in the config.
@@ -299,7 +299,7 @@ class BaseNode(ABC, BaseModel):
299
299
  """
300
300
 
301
301
  @abstractmethod
302
- def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
302
+ def execute_as_graph(self, map_variable: TypeMapVariable = None):
303
303
  """
304
304
  This function would be called to set up the execution of the individual
305
305
  branches of a composite node.
@@ -314,7 +314,7 @@ class BaseNode(ABC, BaseModel):
314
314
  """
315
315
 
316
316
  @abstractmethod
317
- def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
317
+ def fan_out(self, map_variable: TypeMapVariable = None):
318
318
  """
319
319
  This function would be called to set up the execution of the individual
320
320
  branches of a composite node.
@@ -330,7 +330,7 @@ class BaseNode(ABC, BaseModel):
330
330
  """
331
331
 
332
332
  @abstractmethod
333
- def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
333
+ def fan_in(self, map_variable: TypeMapVariable = None):
334
334
  """
335
335
  This function would be called to tear down the execution of the individual
336
336
  branches of a composite node.
@@ -439,33 +439,25 @@ class ExecutableNode(TraversalNode):
439
439
  "This is an executable node and does not have branches"
440
440
  )
441
441
 
442
- def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
442
+ def execute_as_graph(self, map_variable: TypeMapVariable = None):
443
443
  raise exceptions.NodeMethodCallError(
444
444
  "This is an executable node and does not have a graph"
445
445
  )
446
446
 
447
- def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
447
+ def fan_in(self, map_variable: TypeMapVariable = None):
448
448
  raise exceptions.NodeMethodCallError(
449
449
  "This is an executable node and does not have a fan in"
450
450
  )
451
451
 
452
- def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
452
+ def fan_out(self, map_variable: TypeMapVariable = None):
453
453
  raise exceptions.NodeMethodCallError(
454
454
  "This is an executable node and does not have a fan out"
455
455
  )
456
456
 
457
- def prepare_for_job_execution(self):
458
- raise exceptions.NodeMethodCallError(
459
- "This is an executable node and does not have a prepare_for_job_execution"
460
- )
461
-
462
- def tear_down_after_job_execution(self):
463
- raise exceptions.NodeMethodCallError(
464
- "This is an executable node and does not have a tear_down_after_job_execution",
465
- )
466
-
467
457
 
468
458
  class CompositeNode(TraversalNode):
459
+ is_composite: bool = True
460
+
469
461
  def _get_catalog_settings(self) -> Dict[str, Any]:
470
462
  """
471
463
  If the node defines a catalog settings, return it or None
@@ -485,20 +477,44 @@ class CompositeNode(TraversalNode):
485
477
  mock=False,
486
478
  map_variable: TypeMapVariable = None,
487
479
  attempt_number: int = 1,
488
- **kwargs,
489
480
  ) -> StepLog:
490
481
  raise exceptions.NodeMethodCallError(
491
482
  "This is a composite node and does not have an execute function"
492
483
  )
493
484
 
494
- def prepare_for_job_execution(self):
485
+
486
+ class DistributedNode(TraversalNode):
487
+ """
488
+ Use this node for distributed execution of tasks.
489
+ eg: torch distributed, horovod, etc.
490
+ """
491
+
492
+ is_distributed: bool = True
493
+ catalog: Optional[CatalogStructure] = Field(default=None)
494
+ max_attempts: int = Field(default=1, ge=1)
495
+
496
+ def _get_catalog_settings(self) -> Dict[str, Any]:
497
+ """
498
+ If the node defines a catalog settings, return it or None
499
+
500
+ Returns:
501
+ dict: catalog settings defined as per the node or None
502
+ """
503
+ if self.catalog:
504
+ return self.catalog.model_dump()
505
+ return {}
506
+
507
+ def _get_max_attempts(self) -> int:
508
+ return self.max_attempts
509
+
510
+ def _get_branch_by_name(self, branch_name: str):
495
511
  raise exceptions.NodeMethodCallError(
496
- "This is an executable node and does not have a prepare_for_job_execution"
512
+ "This is an distributed node and does not have branches"
497
513
  )
498
514
 
499
- def tear_down_after_job_execution(self):
515
+ def execute_as_graph(self, map_variable: TypeMapVariable = None):
500
516
  raise exceptions.NodeMethodCallError(
501
- "This is an executable node and does not have a tear_down_after_job_execution"
517
+ "This is an executable node and does not have a graph"
502
518
  )
503
519
 
504
520
 
@@ -524,13 +540,16 @@ class TerminalNode(BaseNode):
524
540
  def _get_max_attempts(self) -> int:
525
541
  return 1
526
542
 
527
- def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
543
+ def execute_as_graph(self, map_variable: TypeMapVariable = None):
528
544
  raise exceptions.TerminalNodeError()
529
545
 
530
- def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
546
+ def fan_in(self, map_variable: TypeMapVariable = None):
531
547
  raise exceptions.TerminalNodeError()
532
548
 
533
- def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
549
+ def fan_out(
550
+ self,
551
+ map_variable: TypeMapVariable = None,
552
+ ):
534
553
  raise exceptions.TerminalNodeError()
535
554
 
536
555
  @classmethod