runnable 0.35.0__py3-none-any.whl → 0.36.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (42)
  1. extensions/job_executor/__init__.py +3 -4
  2. extensions/job_executor/emulate.py +106 -0
  3. extensions/job_executor/k8s.py +8 -8
  4. extensions/job_executor/local_container.py +13 -14
  5. extensions/nodes/__init__.py +0 -0
  6. extensions/nodes/conditional.py +7 -5
  7. extensions/nodes/fail.py +72 -0
  8. extensions/nodes/map.py +350 -0
  9. extensions/nodes/parallel.py +159 -0
  10. extensions/nodes/stub.py +89 -0
  11. extensions/nodes/success.py +72 -0
  12. extensions/nodes/task.py +92 -0
  13. extensions/pipeline_executor/__init__.py +24 -26
  14. extensions/pipeline_executor/argo.py +18 -15
  15. extensions/pipeline_executor/emulate.py +112 -0
  16. extensions/pipeline_executor/local.py +4 -4
  17. extensions/pipeline_executor/local_container.py +19 -79
  18. extensions/pipeline_executor/mocked.py +4 -4
  19. extensions/pipeline_executor/retry.py +6 -10
  20. extensions/tasks/torch.py +1 -1
  21. runnable/__init__.py +0 -8
  22. runnable/catalog.py +1 -21
  23. runnable/cli.py +0 -59
  24. runnable/context.py +519 -28
  25. runnable/datastore.py +51 -54
  26. runnable/defaults.py +12 -34
  27. runnable/entrypoints.py +82 -440
  28. runnable/exceptions.py +35 -34
  29. runnable/executor.py +13 -20
  30. runnable/names.py +1 -1
  31. runnable/nodes.py +16 -15
  32. runnable/parameters.py +2 -2
  33. runnable/sdk.py +66 -163
  34. runnable/tasks.py +62 -21
  35. runnable/utils.py +6 -268
  36. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/METADATA +1 -1
  37. runnable-0.36.0.dist-info/RECORD +74 -0
  38. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/entry_points.txt +8 -7
  39. extensions/nodes/nodes.py +0 -778
  40. runnable-0.35.0.dist-info/RECORD +0 -66
  41. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/WHEEL +0 -0
  42. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,92 @@
1
+ import logging
2
+ from datetime import datetime
3
+ from typing import Any, Dict
4
+
5
+ from pydantic import ConfigDict, Field
6
+
7
+ from runnable import datastore, defaults
8
+ from runnable.datastore import StepLog
9
+ from runnable.defaults import MapVariableType
10
+ from runnable.nodes import ExecutableNode
11
+ from runnable.tasks import BaseTaskType, create_task
12
+
13
+ logger = logging.getLogger(defaults.LOGGER_NAME)
14
+
15
+
16
class TaskNode(ExecutableNode):
    """
    A node of type Task.

    This node does the actual function execution of the graph in all cases.
    The work itself is held in ``executable`` (built by ``create_task``); the
    node records the outcome of each attempt on its step log.
    """

    # Excluded from serialization; the task-level keys are kept as extra fields instead.
    executable: BaseTaskType = Field(exclude=True)
    node_type: str = Field(default="task", serialization_alias="type")

    # It is technically not allowed as parse_from_config filters them.
    # This is just to get the task level configuration to be present during serialization.
    model_config = ConfigDict(extra="allow")

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "TaskNode":
        """
        Build a task node from a combined node + task configuration.

        Keys that are declared fields of the node model are node configuration;
        every other key is task configuration and is handed to ``create_task``
        to build the executable.

        Args:
            config: The combined node and task configuration.

        Returns:
            The constructed node. The task configuration is also passed through
            as extra fields so it is present when the node is serialized.
        """
        # Use ``cls`` so that subclasses split the config using their own fields.
        node_field_names = set(cls.model_fields)
        node_config = {k: v for k, v in config.items() if k in node_field_names}
        task_config = {k: v for k, v in config.items() if k not in node_field_names}

        executable = create_task(task_config)
        return cls(executable=executable, **node_config, **task_config)

    def get_summary(self) -> Dict[str, Any]:
        """Return a summary of the node: name, type, executable summary and catalog settings."""
        summary = {
            "name": self.name,
            "type": self.node_type,
            "executable": self.executable.get_summary(),
            "catalog": self._get_catalog_settings(),
        }

        return summary

    def execute(
        self,
        mock: bool = False,
        map_variable: MapVariableType = None,
        attempt_number: int = 1,
    ) -> StepLog:
        """
        Execute the task (or mock it) and record the attempt on the step log.

        All that we do in runnable is to come to this point where we actually execute the command.

        Args:
            mock: If True, the task is not run and a successful attempt is
                recorded instead. Could be useful for caching and dry runs.
                Defaults to False.
            map_variable: If the node is part of an internal branch, the map
                variable used to resolve the step log name. Defaults to None.
            attempt_number: The attempt number stamped on the attempt log.
                Defaults to 1.

        Returns:
            StepLog: The step log fetched from the run log store, with its status
            set to this attempt's status and the attempt appended. This method
            does not write the step log back to the run log store.
        """
        step_log = self._context.run_log_store.get_step_log(
            self._get_step_log_name(map_variable), self._context.run_id
        )

        if not mock:
            # Do not run if we are mocking the execution, could be useful for caching and dry runs
            attempt_log = self.executable.execute_command(map_variable=map_variable)
            attempt_log.attempt_number = attempt_number
        else:
            attempt_log = datastore.StepAttempt(
                status=defaults.SUCCESS,
                start_time=str(datetime.now()),
                end_time=str(datetime.now()),
                attempt_number=attempt_number,
            )

        logger.info(f"attempt_log: {attempt_log}")
        logger.info(f"Step {self.name} completed with status: {attempt_log.status}")

        step_log.status = attempt_log.status
        step_log.attempts.append(attempt_log)

        return step_log
@@ -13,7 +13,7 @@ from runnable import (
13
13
  utils,
14
14
  )
15
15
  from runnable.datastore import DataCatalog, JsonParameter, RunLog, StepLog
16
- from runnable.defaults import TypeMapVariable
16
+ from runnable.defaults import MapVariableType
17
17
  from runnable.executor import BasePipelineExecutor
18
18
  from runnable.graph import Graph
19
19
  from runnable.nodes import BaseNode
@@ -40,7 +40,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
40
40
 
41
41
  @property
42
42
  def _context(self):
43
- assert context.run_context
43
+ assert isinstance(context.run_context, context.PipelineContext)
44
44
  return context.run_context
45
45
 
46
46
  def _get_parameters(self) -> Dict[str, JsonParameter]:
@@ -104,7 +104,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
104
104
  )
105
105
 
106
106
  # Update run_config
107
- run_config = utils.get_run_config()
107
+ run_config = self._context.model_dump()
108
108
  logger.debug(f"run_config as seen by executor: {run_config}")
109
109
  self._context.run_log_store.set_run_config(
110
110
  run_id=self._context.run_id, run_config=run_config
@@ -154,12 +154,12 @@ class GenericPipelineExecutor(BasePipelineExecutor):
154
154
  data_catalogs = []
155
155
  for name_pattern in node_catalog_settings.get(stage) or []:
156
156
  if stage == "get":
157
- data_catalog = self._context.catalog_handler.get(
157
+ data_catalog = self._context.catalog.get(
158
158
  name=name_pattern,
159
159
  )
160
160
 
161
161
  elif stage == "put":
162
- data_catalog = self._context.catalog_handler.put(
162
+ data_catalog = self._context.catalog.put(
163
163
  name=name_pattern, allow_file_not_found_exc=allow_file_no_found_exc
164
164
  )
165
165
  else:
@@ -189,14 +189,15 @@ class GenericPipelineExecutor(BasePipelineExecutor):
189
189
  map_variable=map_variable,
190
190
  )
191
191
  task_console.save_text(log_file_name)
192
+ task_console.export_text(clear=True)
192
193
  # Put the log file in the catalog
193
- self._context.catalog_handler.put(name=log_file_name)
194
+ self._context.catalog.put(name=log_file_name)
194
195
  os.remove(log_file_name)
195
196
 
196
197
  def _execute_node(
197
198
  self,
198
199
  node: BaseNode,
199
- map_variable: TypeMapVariable = None,
200
+ map_variable: MapVariableType = None,
200
201
  mock: bool = False,
201
202
  ):
202
203
  """
@@ -250,6 +251,10 @@ class GenericPipelineExecutor(BasePipelineExecutor):
250
251
  console.print(f"Summary of the step: {step_log.internal_name}")
251
252
  console.print(step_log.get_summary(), style=defaults.info_style)
252
253
 
254
+ self.add_task_log_to_catalog(
255
+ name=self._context_node.internal_name, map_variable=map_variable
256
+ )
257
+
253
258
  self._context_node = None
254
259
 
255
260
  self._context.run_log_store.add_step_log(step_log, self._context.run_id)
@@ -266,7 +271,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
266
271
  """
267
272
  step_log.code_identities.append(utils.get_git_code_identity())
268
273
 
269
- def execute_from_graph(self, node: BaseNode, map_variable: TypeMapVariable = None):
274
+ def execute_from_graph(self, node: BaseNode, map_variable: MapVariableType = None):
270
275
  """
271
276
  This is the entry point to from the graph execution.
272
277
 
@@ -315,8 +320,6 @@ class GenericPipelineExecutor(BasePipelineExecutor):
315
320
  node.execute_as_graph(map_variable=map_variable)
316
321
  return
317
322
 
318
- task_console.export_text(clear=True)
319
-
320
323
  task_name = node._resolve_map_placeholders(node.internal_name, map_variable)
321
324
  console.print(
322
325
  f":runner: Executing the node {task_name} ... ", style="bold color(208)"
@@ -324,7 +327,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
324
327
  self.trigger_node_execution(node=node, map_variable=map_variable)
325
328
 
326
329
  def trigger_node_execution(
327
- self, node: BaseNode, map_variable: TypeMapVariable = None
330
+ self, node: BaseNode, map_variable: MapVariableType = None
328
331
  ):
329
332
  """
330
333
  Call this method only if we are responsible for traversing the graph via
@@ -342,7 +345,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
342
345
  pass
343
346
 
344
347
  def _get_status_and_next_node_name(
345
- self, current_node: BaseNode, dag: Graph, map_variable: TypeMapVariable = None
348
+ self, current_node: BaseNode, dag: Graph, map_variable: MapVariableType = None
346
349
  ) -> tuple[str, str]:
347
350
  """
348
351
  Given the current node and the graph, returns the name of the next node to execute.
@@ -380,7 +383,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
380
383
 
381
384
  return step_log.status, next_node_name
382
385
 
383
- def execute_graph(self, dag: Graph, map_variable: TypeMapVariable = None):
386
+ def execute_graph(self, dag: Graph, map_variable: MapVariableType = None):
384
387
  """
385
388
  The parallelization is controlled by the nodes and not by this function.
386
389
 
@@ -409,7 +412,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
409
412
  dag.internal_branch_name or "Graph",
410
413
  map_variable,
411
414
  )
412
- branch_execution_task = self._context.progress.add_task(
415
+ branch_execution_task = context.progress.add_task(
413
416
  f"[dark_orange]Executing {branch_task_name}",
414
417
  total=1,
415
418
  )
@@ -429,7 +432,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
429
432
 
430
433
  depth = " " * ((task_name.count(".")) or 1 - 1)
431
434
 
432
- task_execution = self._context.progress.add_task(
435
+ task_execution = context.progress.add_task(
433
436
  f"{depth}Executing {task_name}", total=1
434
437
  )
435
438
 
@@ -440,20 +443,20 @@ class GenericPipelineExecutor(BasePipelineExecutor):
440
443
  )
441
444
 
442
445
  if status == defaults.SUCCESS:
443
- self._context.progress.update(
446
+ context.progress.update(
444
447
  task_execution,
445
448
  description=f"{depth}[green] {task_name} Completed",
446
449
  completed=True,
447
450
  overflow="fold",
448
451
  )
449
452
  else:
450
- self._context.progress.update(
453
+ context.progress.update(
451
454
  task_execution,
452
455
  description=f"{depth}[red] {task_name} Failed",
453
456
  completed=True,
454
457
  ) # type ignore
455
458
  except Exception as e: # noqa: E722
456
- self._context.progress.update(
459
+ context.progress.update(
457
460
  task_execution,
458
461
  description=f"{depth}[red] {task_name} Errored",
459
462
  completed=True,
@@ -461,11 +464,6 @@ class GenericPipelineExecutor(BasePipelineExecutor):
461
464
  console.print(e, style=defaults.error_style)
462
465
  logger.exception(e)
463
466
  raise
464
- finally:
465
- # Add task log to the catalog
466
- self.add_task_log_to_catalog(
467
- name=working_on.internal_name, map_variable=map_variable
468
- )
469
467
 
470
468
  console.rule(style="[dark orange]")
471
469
 
@@ -475,7 +473,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
475
473
  current_node = next_node_name
476
474
 
477
475
  if branch_execution_task:
478
- self._context.progress.update(
476
+ context.progress.update(
479
477
  branch_execution_task,
480
478
  description=f"[green3] {branch_task_name} completed",
481
479
  completed=True,
@@ -567,7 +565,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
567
565
 
568
566
  return effective_node_config
569
567
 
570
- def fan_out(self, node: BaseNode, map_variable: TypeMapVariable = None):
568
+ def fan_out(self, node: BaseNode, map_variable: MapVariableType = None):
571
569
  """
572
570
  This method is used to appropriately fan-out the execution of a composite node.
573
571
  This is only useful when we want to execute a composite node during 3rd party orchestrators.
@@ -599,7 +597,7 @@ class GenericPipelineExecutor(BasePipelineExecutor):
599
597
 
600
598
  node.fan_out(map_variable=map_variable)
601
599
 
602
- def fan_in(self, node: BaseNode, map_variable: TypeMapVariable = None):
600
+ def fan_in(self, node: BaseNode, map_variable: MapVariableType = None):
603
601
  """
604
602
  This method is used to appropriately fan-in after the execution of a composite node.
605
603
  This is only useful when we want to execute a composite node during 3rd party orchestrators.
@@ -21,13 +21,15 @@ from pydantic.alias_generators import to_camel
21
21
  from ruamel.yaml import YAML
22
22
 
23
23
  from extensions.nodes.conditional import ConditionalNode
24
- from extensions.nodes.nodes import MapNode, ParallelNode, TaskNode
24
+ from extensions.nodes.map import MapNode
25
+ from extensions.nodes.parallel import ParallelNode
26
+ from extensions.nodes.task import TaskNode
25
27
 
26
28
  # TODO: Should be part of a wider refactor
27
29
  # from extensions.nodes.torch import TorchNode
28
30
  from extensions.pipeline_executor import GenericPipelineExecutor
29
- from runnable import defaults, utils
30
- from runnable.defaults import TypeMapVariable
31
+ from runnable import defaults
32
+ from runnable.defaults import MapVariableType
31
33
  from runnable.graph import Graph, search_node_by_internal_name
32
34
  from runnable.nodes import BaseNode
33
35
 
@@ -453,7 +455,7 @@ class ArgoExecutor(GenericPipelineExecutor):
453
455
  """
454
456
 
455
457
  service_name: str = "argo"
456
- _is_local: bool = False
458
+ _should_setup_run_log_at_traversal: bool = PrivateAttr(default=False)
457
459
  mock: bool = False
458
460
 
459
461
  model_config = ConfigDict(
@@ -535,13 +537,13 @@ class ArgoExecutor(GenericPipelineExecutor):
535
537
  parameters: Optional[list[Parameter]],
536
538
  task_name: str,
537
539
  ):
538
- map_variable: TypeMapVariable = {}
540
+ map_variable: MapVariableType = {}
539
541
  for parameter in parameters or []:
540
542
  map_variable[parameter.name] = ( # type: ignore
541
543
  "{{inputs.parameters." + str(parameter.name) + "}}"
542
544
  )
543
545
 
544
- fan_command = utils.get_fan_command(
546
+ fan_command = self._context.get_fan_command(
545
547
  mode=mode,
546
548
  node=node,
547
549
  run_id=self._run_id_as_parameter,
@@ -606,17 +608,17 @@ class ArgoExecutor(GenericPipelineExecutor):
606
608
 
607
609
  inputs = inputs or Inputs(parameters=[])
608
610
 
609
- map_variable: TypeMapVariable = {}
611
+ map_variable: MapVariableType = {}
610
612
  for parameter in inputs.parameters or []:
611
613
  map_variable[parameter.name] = ( # type: ignore
612
614
  "{{inputs.parameters." + str(parameter.name) + "}}"
613
615
  )
614
616
 
615
617
  # command = "runnable execute-single-node"
616
- command = utils.get_node_execution_command(
618
+ command = self._context.get_node_callable_command(
617
619
  node=node,
618
- over_write_run_id=self._run_id_as_parameter,
619
620
  map_variable=map_variable,
621
+ over_write_run_id=self._run_id_as_parameter,
620
622
  log_level=self._log_level_as_parameter,
621
623
  )
622
624
 
@@ -715,6 +717,7 @@ class ArgoExecutor(GenericPipelineExecutor):
715
717
  assert parent_dag_template.dag
716
718
 
717
719
  parent_dag_template.dag.tasks.append(on_failure_task)
720
+
718
721
  self._gather_tasks_for_dag_template(
719
722
  on_failure_dag,
720
723
  dag=dag,
@@ -762,7 +765,7 @@ class ArgoExecutor(GenericPipelineExecutor):
762
765
  depends = task_name
763
766
 
764
767
  match working_on.node_type:
765
- case "task" | "success" | "stub":
768
+ case "task" | "success" | "stub" | "fail":
766
769
  template_of_container = self._create_container_template(
767
770
  working_on,
768
771
  task_name=task_name,
@@ -958,7 +961,7 @@ class ArgoExecutor(GenericPipelineExecutor):
958
961
  f,
959
962
  )
960
963
 
961
- def _implicitly_fail(self, node: BaseNode, map_variable: TypeMapVariable):
964
+ def _implicitly_fail(self, node: BaseNode, map_variable: MapVariableType):
962
965
  assert self._context.dag
963
966
  _, current_branch = search_node_by_internal_name(
964
967
  dag=self._context.dag, internal_name=node.internal_name
@@ -1005,7 +1008,7 @@ class ArgoExecutor(GenericPipelineExecutor):
1005
1008
 
1006
1009
  self._implicitly_fail(node, map_variable)
1007
1010
 
1008
- def fan_out(self, node: BaseNode, map_variable: TypeMapVariable = None):
1011
+ def fan_out(self, node: BaseNode, map_variable: MapVariableType = None):
1009
1012
  # This could be the first step of the graph
1010
1013
  self._use_volumes()
1011
1014
 
@@ -1031,7 +1034,7 @@ class ArgoExecutor(GenericPipelineExecutor):
1031
1034
  with open("/tmp/output.txt", mode="w", encoding="utf-8") as myfile:
1032
1035
  json.dump(node.get_parameter_value(), myfile, indent=4)
1033
1036
 
1034
- def fan_in(self, node: BaseNode, map_variable: TypeMapVariable = None):
1037
+ def fan_in(self, node: BaseNode, map_variable: MapVariableType = None):
1035
1038
  self._use_volumes()
1036
1039
  super().fan_in(node, map_variable)
1037
1040
 
@@ -1042,9 +1045,9 @@ class ArgoExecutor(GenericPipelineExecutor):
1042
1045
  case "chunked-fs":
1043
1046
  self._context.run_log_store.log_folder = self._container_log_location
1044
1047
 
1045
- match self._context.catalog_handler.service_name:
1048
+ match self._context.catalog.service_name:
1046
1049
  case "file-system":
1047
- self._context.catalog_handler.catalog_location = (
1050
+ self._context.catalog.catalog_location = (
1048
1051
  self._container_catalog_location
1049
1052
  )
1050
1053
 
@@ -0,0 +1,112 @@
1
+ import logging
2
+ import shlex
3
+ import subprocess
4
+ import sys
5
+
6
+ from pydantic import PrivateAttr
7
+
8
+ from extensions.pipeline_executor import GenericPipelineExecutor
9
+ from runnable import defaults
10
+ from runnable.defaults import MapVariableType
11
+ from runnable.nodes import BaseNode
12
+
13
+ logger = logging.getLogger(defaults.LOGGER_NAME)
14
+
15
+
16
class Emulator(GenericPipelineExecutor):
    """
    Emulate a pipeline by running each node in its own subprocess on the local computer.

    Every node is executed by re-invoking ``python -m runnable.cli`` with the
    current interpreter (``sys.executable``). Each step therefore runs in a
    separate process, but in the same local environment.

    This has some serious implications on the amount of time it would take to complete the run.
    Also ensure that the local compute is good enough for the compute to happen of all the steps.

    Example config (assuming the plugin is registered under its ``service_name``):

    ```yaml
    pipeline-executor:
      type: emulator
    ```

    """

    service_name: str = "emulator"

    _should_setup_run_log_at_traversal: bool = PrivateAttr(default=True)

    def trigger_node_execution(
        self, node: BaseNode, map_variable: MapVariableType = None
    ):
        """
        Run the node in a subprocess, then check its outcome in the run log.

        If the step log does not report success after the subprocess finishes,
        it is marked as failed and saved back to the run log store.

        Args:
            node (BaseNode): The node to execute.
            map_variable (MapVariableType, optional): The map variable if the node
                is inside a map branch. Defaults to None.

        Raises:
            subprocess.CalledProcessError: If the subprocess exits with a non-zero code.
        """
        command = self._context.get_node_callable_command(
            node, map_variable=map_variable
        )

        # Execute the command in a separate process; this blocks until it exits.
        self.run_click_command(command)

        # Read back the step log (presumably written by the subprocess via the run log store).
        step_log = self._context.run_log_store.get_step_log(
            node._get_step_log_name(map_variable), self._context.run_id
        )
        if step_log.status != defaults.SUCCESS:
            msg = "Node execution inside the emulate failed. Please check the logs.\n"
            logger.error(msg)
            step_log.status = defaults.FAIL
            self._context.run_log_store.add_step_log(step_log, self._context.run_id)

    def execute_node(self, node: BaseNode, map_variable: MapVariableType = None):
        """
        Execute the node in the current process.

        Args:
            node (BaseNode): The node to execute.
            map_variable (MapVariableType, optional): The map variable if the node
                is inside a map branch. Defaults to None.
        """
        self._execute_node(node=node, map_variable=map_variable)

    def run_click_command(self, command: str) -> str:
        """
        Run a runnable CLI command with the current interpreter, streaming its output.

        The first token of ``command`` (the program name) is dropped. The
        remaining arguments are passed to ``python -m runnable.cli`` so the
        subprocess uses the same virtual environment as this process.

        Args:
            command: The full command line as a string, tokenized with ``shlex.split``.

        Returns:
            The combined stdout/stderr output as a string.

        Raises:
            subprocess.CalledProcessError: If the subprocess exits with a non-zero code.
        """
        sub_command = [sys.executable, "-m", "runnable.cli"] + shlex.split(command)[1:]

        output = []
        with subprocess.Popen(
            sub_command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        ) as process:
            assert process.stdout is not None  # guaranteed by stdout=PIPE
            # Iterate lines until EOF, then wait for exit. Unlike a readline/poll
            # loop, this does not spin between the last output and process exit.
            for line in process.stdout:
                print(line, end="")
                output.append(line)
            returncode = process.wait()

        combined = "".join(output)
        if returncode != 0:
            raise subprocess.CalledProcessError(returncode, command, combined)

        return combined
@@ -4,7 +4,7 @@ from pydantic import Field, PrivateAttr
4
4
 
5
5
  from extensions.pipeline_executor import GenericPipelineExecutor
6
6
  from runnable import defaults
7
- from runnable.defaults import TypeMapVariable
7
+ from runnable.defaults import MapVariableType
8
8
  from runnable.nodes import BaseNode
9
9
 
10
10
  logger = logging.getLogger(defaults.LOGGER_NAME)
@@ -32,14 +32,14 @@ class LocalExecutor(GenericPipelineExecutor):
32
32
 
33
33
  _is_local: bool = PrivateAttr(default=True)
34
34
 
35
- def execute_from_graph(self, node: BaseNode, map_variable: TypeMapVariable = None):
35
+ def execute_from_graph(self, node: BaseNode, map_variable: MapVariableType = None):
36
36
  if not self.object_serialisation:
37
37
  self._context.object_serialisation = False
38
38
 
39
39
  super().execute_from_graph(node=node, map_variable=map_variable)
40
40
 
41
41
  def trigger_node_execution(
42
- self, node: BaseNode, map_variable: TypeMapVariable = None
42
+ self, node: BaseNode, map_variable: MapVariableType = None
43
43
  ):
44
44
  """
45
45
  In this mode of execution, we prepare for the node execution and execute the node
@@ -50,7 +50,7 @@ class LocalExecutor(GenericPipelineExecutor):
50
50
  """
51
51
  self.execute_node(node=node, map_variable=map_variable)
52
52
 
53
- def execute_node(self, node: BaseNode, map_variable: TypeMapVariable = None):
53
+ def execute_node(self, node: BaseNode, map_variable: MapVariableType = None):
54
54
  """
55
55
  For local execution, we just execute the node.
56
56