runnable 0.35.0__py3-none-any.whl → 0.36.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. extensions/job_executor/__init__.py +3 -4
  2. extensions/job_executor/emulate.py +106 -0
  3. extensions/job_executor/k8s.py +8 -8
  4. extensions/job_executor/local_container.py +13 -14
  5. extensions/nodes/__init__.py +0 -0
  6. extensions/nodes/conditional.py +7 -5
  7. extensions/nodes/fail.py +72 -0
  8. extensions/nodes/map.py +350 -0
  9. extensions/nodes/parallel.py +159 -0
  10. extensions/nodes/stub.py +89 -0
  11. extensions/nodes/success.py +72 -0
  12. extensions/nodes/task.py +92 -0
  13. extensions/pipeline_executor/__init__.py +24 -26
  14. extensions/pipeline_executor/argo.py +18 -15
  15. extensions/pipeline_executor/emulate.py +112 -0
  16. extensions/pipeline_executor/local.py +4 -4
  17. extensions/pipeline_executor/local_container.py +19 -79
  18. extensions/pipeline_executor/mocked.py +4 -4
  19. extensions/pipeline_executor/retry.py +6 -10
  20. extensions/tasks/torch.py +1 -1
  21. runnable/__init__.py +0 -8
  22. runnable/catalog.py +1 -21
  23. runnable/cli.py +0 -59
  24. runnable/context.py +519 -28
  25. runnable/datastore.py +51 -54
  26. runnable/defaults.py +12 -34
  27. runnable/entrypoints.py +82 -440
  28. runnable/exceptions.py +35 -34
  29. runnable/executor.py +13 -20
  30. runnable/names.py +1 -1
  31. runnable/nodes.py +16 -15
  32. runnable/parameters.py +2 -2
  33. runnable/sdk.py +66 -163
  34. runnable/tasks.py +62 -21
  35. runnable/utils.py +6 -268
  36. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/METADATA +1 -1
  37. runnable-0.36.0.dist-info/RECORD +74 -0
  38. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/entry_points.txt +8 -7
  39. extensions/nodes/nodes.py +0 -778
  40. runnable-0.35.0.dist-info/RECORD +0 -66
  41. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/WHEEL +0 -0
  42. {runnable-0.35.0.dist-info → runnable-0.36.0.dist-info}/licenses/LICENSE +0 -0
runnable/sdk.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import inspect
3
4
  import logging
4
- import os
5
5
  import re
6
6
  from abc import ABC, abstractmethod
7
7
  from pathlib import Path
@@ -16,27 +16,17 @@ from pydantic import (
16
16
  field_validator,
17
17
  model_validator,
18
18
  )
19
- from rich.progress import (
20
- BarColumn,
21
- Progress,
22
- SpinnerColumn,
23
- TextColumn,
24
- TimeElapsedColumn,
25
- )
26
- from rich.table import Column
27
19
  from typing_extensions import Self
28
20
 
29
21
  from extensions.nodes.conditional import ConditionalNode
30
- from extensions.nodes.nodes import (
31
- FailNode,
32
- MapNode,
33
- ParallelNode,
34
- StubNode,
35
- SuccessNode,
36
- TaskNode,
37
- )
38
- from runnable import console, defaults, entrypoints, exceptions, graph, utils
39
- from runnable.executor import BaseJobExecutor, BasePipelineExecutor
22
+ from extensions.nodes.fail import FailNode
23
+ from extensions.nodes.map import MapNode
24
+ from extensions.nodes.parallel import ParallelNode
25
+ from extensions.nodes.stub import StubNode
26
+ from extensions.nodes.success import SuccessNode
27
+ from extensions.nodes.task import TaskNode
28
+ from runnable import defaults, graph
29
+ from runnable.executor import BaseJobExecutor
40
30
  from runnable.nodes import TraversalNode
41
31
  from runnable.tasks import BaseTaskType as RunnableTask
42
32
  from runnable.tasks import TaskReturns
@@ -196,7 +186,7 @@ class BaseTask(BaseTraversal):
196
186
  )
197
187
 
198
188
  def as_pipeline(self) -> "Pipeline":
199
- return Pipeline(steps=[self]) # type: ignore
189
+ return Pipeline(steps=[self], name=self.internal_name) # type: ignore
200
190
 
201
191
 
202
192
  class PythonTask(BaseTask):
@@ -487,6 +477,9 @@ class Stub(BaseTraversal):
487
477
 
488
478
  return StubNode.parse_from_config(self.model_dump(exclude_none=True))
489
479
 
480
+ def as_pipeline(self) -> "Pipeline":
481
+ return Pipeline(steps=[self])
482
+
490
483
 
491
484
  class Parallel(BaseTraversal):
492
485
  """
@@ -786,6 +779,15 @@ class Pipeline(BaseModel):
786
779
  return False
787
780
  return True
788
781
 
782
+ def get_caller(self) -> str:
783
+ caller_stack = inspect.stack()[2]
784
+ relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
785
+
786
+ module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
787
+ module_to_call = f"{module_name}.{caller_stack.function}"
788
+
789
+ return module_to_call
790
+
789
791
  def execute(
790
792
  self,
791
793
  configuration_file: str = "",
@@ -803,106 +805,31 @@ class Pipeline(BaseModel):
803
805
  # Immediately return as this call is only for getting the pipeline definition
804
806
  return {}
805
807
 
806
- logger.setLevel(log_level)
807
-
808
- run_id = utils.generate_run_id(run_id=run_id)
808
+ from runnable import context
809
809
 
810
- parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
811
-
812
- tag = os.environ.get("RUNNABLE_tag", tag)
810
+ logger.setLevel(log_level)
813
811
 
814
- configuration_file = os.environ.get(
815
- "RUNNABLE_CONFIGURATION_FILE", configuration_file
816
- )
817
- run_context = entrypoints.prepare_configurations(
812
+ service_configurations = context.ServiceConfigurations(
818
813
  configuration_file=configuration_file,
819
- run_id=run_id,
820
- tag=tag,
821
- parameters_file=parameters_file,
822
- )
823
-
824
- assert isinstance(run_context.executor, BasePipelineExecutor)
825
-
826
- utils.set_runnable_environment_variables(
827
- run_id=run_id, configuration_file=configuration_file, tag=tag
814
+ execution_context=context.ExecutionContext.PIPELINE,
828
815
  )
829
816
 
830
- dag_definition = self._dag.model_dump(by_alias=True, exclude_none=True)
831
- run_context.from_sdk = True
832
- run_context.dag = graph.create_graph(dag_definition)
833
-
834
- console.print("Working with context:")
835
- console.print(run_context)
836
- console.rule(style="[dark orange]")
837
-
838
- if not run_context.executor._is_local:
839
- # We are not working with executor that does not work in local environment
840
- import inspect
841
-
842
- caller_stack = inspect.stack()[1]
843
- relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
844
-
845
- module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
846
- module_to_call = f"{module_name}.{caller_stack.function}"
847
-
848
- run_context.pipeline_file = f"{module_to_call}.py"
849
- run_context.from_sdk = True
850
-
851
- # Prepare for graph execution
852
- run_context.executor._set_up_run_log(exists_ok=False)
853
-
854
- with Progress(
855
- SpinnerColumn(spinner_name="runner"),
856
- TextColumn(
857
- "[progress.description]{task.description}", table_column=Column(ratio=2)
858
- ),
859
- BarColumn(table_column=Column(ratio=1), style="dark_orange"),
860
- TimeElapsedColumn(table_column=Column(ratio=1)),
861
- console=console,
862
- expand=True,
863
- ) as progress:
864
- pipeline_execution_task = progress.add_task(
865
- "[dark_orange] Starting execution .. ", total=1
866
- )
867
- try:
868
- run_context.progress = progress
817
+ configurations = {
818
+ "pipeline_definition_file": self.get_caller(),
819
+ "parameters_file": parameters_file,
820
+ "tag": tag,
821
+ "run_id": run_id,
822
+ "execution_mode": context.ExecutionMode.PYTHON,
823
+ "configuration_file": configuration_file,
824
+ **service_configurations.services,
825
+ }
869
826
 
870
- run_context.executor.execute_graph(dag=run_context.dag)
827
+ run_context = context.PipelineContext.model_validate(configurations)
828
+ context.run_context = run_context
871
829
 
872
- if not run_context.executor._is_local:
873
- # non local executors just traverse the graph and do nothing
874
- return {}
830
+ assert isinstance(run_context, context.PipelineContext)
875
831
 
876
- run_log = run_context.run_log_store.get_run_log_by_id(
877
- run_id=run_context.run_id, full=False
878
- )
879
-
880
- if run_log.status == defaults.SUCCESS:
881
- progress.update(
882
- pipeline_execution_task,
883
- description="[green] Success",
884
- completed=True,
885
- )
886
- else:
887
- progress.update(
888
- pipeline_execution_task,
889
- description="[red] Failed",
890
- completed=True,
891
- )
892
- raise exceptions.ExecutionFailedError(run_context.run_id)
893
- except Exception as e: # noqa: E722
894
- console.print(e, style=defaults.error_style)
895
- progress.update(
896
- pipeline_execution_task,
897
- description="[red] Errored execution",
898
- completed=True,
899
- )
900
- raise
901
-
902
- if run_context.executor._is_local:
903
- return run_context.run_log_store.get_run_log_by_id(
904
- run_id=run_context.run_id
905
- )
832
+ run_context.execute()
906
833
 
907
834
 
908
835
  class BaseJob(BaseModel):
@@ -926,6 +853,15 @@ class BaseJob(BaseModel):
926
853
  def get_task(self) -> RunnableTask:
927
854
  raise NotImplementedError
928
855
 
856
+ def get_caller(self) -> str:
857
+ caller_stack = inspect.stack()[2]
858
+ relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
859
+
860
+ module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
861
+ module_to_call = f"{module_name}.{caller_stack.function}"
862
+
863
+ return module_to_call
864
+
929
865
  def return_catalog_settings(self) -> Optional[List[str]]:
930
866
  if self.catalog is None:
931
867
  return []
@@ -952,65 +888,32 @@ class BaseJob(BaseModel):
952
888
  if self._is_called_for_definition():
953
889
  # Immediately return as this call is only for getting the job definition
954
890
  return {}
955
- logger.setLevel(log_level)
956
-
957
- run_id = utils.generate_run_id(run_id=job_id)
891
+ from runnable import context
958
892
 
959
- parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
960
-
961
- tag = os.environ.get("RUNNABLE_tag", tag)
962
-
963
- configuration_file = os.environ.get(
964
- "RUNNABLE_CONFIGURATION_FILE", configuration_file
965
- )
893
+ logger.setLevel(log_level)
966
894
 
967
- run_context = entrypoints.prepare_configurations(
895
+ service_configurations = context.ServiceConfigurations(
968
896
  configuration_file=configuration_file,
969
- run_id=run_id,
970
- tag=tag,
971
- parameters_file=parameters_file,
972
- is_job=True,
897
+ execution_context=context.ExecutionContext.JOB,
973
898
  )
974
899
 
975
- assert isinstance(run_context.executor, BaseJobExecutor)
976
- run_context.from_sdk = True
977
-
978
- utils.set_runnable_environment_variables(
979
- run_id=run_id, configuration_file=configuration_file, tag=tag
980
- )
981
-
982
- console.print("Working with context:")
983
- console.print(run_context)
984
- console.rule(style="[dark orange]")
985
-
986
- if not run_context.executor._is_local:
987
- # We are not working with executor that does not work in local environment
988
- import inspect
989
-
990
- caller_stack = inspect.stack()[1]
991
- relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
992
-
993
- module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
994
- module_to_call = f"{module_name}.{caller_stack.function}"
995
-
996
- run_context.job_definition_file = f"{module_to_call}.py"
900
+ configurations = {
901
+ "job_definition_file": self.get_caller(),
902
+ "parameters_file": parameters_file,
903
+ "tag": tag,
904
+ "run_id": job_id,
905
+ "execution_mode": context.ExecutionMode.PYTHON,
906
+ "configuration_file": configuration_file,
907
+ "job": self.get_task(),
908
+ "catalog_settings": self.return_catalog_settings(),
909
+ **service_configurations.services,
910
+ }
997
911
 
998
- job = self.get_task()
999
- catalog_settings = self.return_catalog_settings()
912
+ run_context = context.JobContext.model_validate(configurations)
1000
913
 
1001
- try:
1002
- run_context.executor.submit_job(job, catalog_settings=catalog_settings)
1003
- finally:
1004
- run_context.executor.add_task_log_to_catalog("job")
914
+ assert isinstance(run_context.job_executor, BaseJobExecutor)
1005
915
 
1006
- logger.info(
1007
- "Executing the job from the user. We are still in the caller's compute environment"
1008
- )
1009
-
1010
- if run_context.executor._is_local:
1011
- return run_context.run_log_store.get_run_log_by_id(
1012
- run_id=run_context.run_id
1013
- )
916
+ run_context.execute()
1014
917
 
1015
918
 
1016
919
  class PythonJob(BaseJob):
runnable/tasks.py CHANGED
@@ -25,7 +25,7 @@ from runnable.datastore import (
25
25
  Parameter,
26
26
  StepAttempt,
27
27
  )
28
- from runnable.defaults import TypeMapVariable
28
+ from runnable.defaults import MapVariableType
29
29
 
30
30
  logger = logging.getLogger(defaults.LOGGER_NAME)
31
31
 
@@ -48,7 +48,29 @@ class TeeIO(io.StringIO):
48
48
  self.output_stream.flush()
49
49
 
50
50
 
51
- sys.stdout = TeeIO()
51
+ @contextlib.contextmanager
52
+ def redirect_output():
53
+ # Set the stream handlers to use the custom TeeIO class
54
+
55
+ # Backup the original stdout and stderr
56
+ original_stdout = sys.stdout
57
+ original_stderr = sys.stderr
58
+
59
+ # Redirect stdout and stderr to custom TeeStream objects
60
+ sys.stdout = TeeIO(sys.stdout)
61
+ sys.stderr = TeeIO(sys.stderr)
62
+
63
+ # Replace stream for all StreamHandlers to use the new sys.stdout
64
+ for handler in logging.getLogger().handlers:
65
+ if isinstance(handler, logging.StreamHandler):
66
+ handler.stream = sys.stdout
67
+
68
+ try:
69
+ yield sys.stdout, sys.stderr
70
+ finally:
71
+ # Restore the original stdout and stderr
72
+ sys.stdout = original_stdout
73
+ sys.stderr = original_stderr
52
74
 
53
75
 
54
76
  class TaskReturns(BaseModel):
@@ -79,7 +101,7 @@ class BaseTaskType(BaseModel):
79
101
  def set_secrets_as_env_variables(self):
80
102
  # Preparing the environment for the task execution
81
103
  for key in self.secrets:
82
- secret_value = context.run_context.secrets_handler.get(key)
104
+ secret_value = context.run_context.secrets.get(key)
83
105
  os.environ[key] = secret_value
84
106
 
85
107
  def delete_secrets_from_env_variables(self):
@@ -90,7 +112,7 @@ class BaseTaskType(BaseModel):
90
112
 
91
113
  def execute_command(
92
114
  self,
93
- map_variable: TypeMapVariable = None,
115
+ map_variable: MapVariableType = None,
94
116
  ) -> StepAttempt:
95
117
  """The function to execute the command.
96
118
 
@@ -130,7 +152,7 @@ class BaseTaskType(BaseModel):
130
152
  finally:
131
153
  self.delete_secrets_from_env_variables()
132
154
 
133
- def resolve_unreduced_parameters(self, map_variable: TypeMapVariable = None):
155
+ def resolve_unreduced_parameters(self, map_variable: MapVariableType = None):
134
156
  """Resolve the unreduced parameters."""
135
157
  params = self._context.run_log_store.get_parameters(
136
158
  run_id=self._context.run_id
@@ -153,7 +175,7 @@ class BaseTaskType(BaseModel):
153
175
 
154
176
  @contextlib.contextmanager
155
177
  def execution_context(
156
- self, map_variable: TypeMapVariable = None, allow_complex: bool = True
178
+ self, map_variable: MapVariableType = None, allow_complex: bool = True
157
179
  ):
158
180
  params = self.resolve_unreduced_parameters(map_variable=map_variable)
159
181
  logger.info(f"Parameters available for the execution: {params}")
@@ -267,7 +289,7 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
267
289
 
268
290
  def execute_command(
269
291
  self,
270
- map_variable: TypeMapVariable = None,
292
+ map_variable: MapVariableType = None,
271
293
  ) -> StepAttempt:
272
294
  """Execute the notebook as defined by the command."""
273
295
  attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
@@ -289,13 +311,21 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
289
311
  logger.info(
290
312
  f"Calling {func} from {module} with {filtered_parameters}"
291
313
  )
292
-
293
- out_file = TeeIO()
294
- with contextlib.redirect_stdout(out_file):
314
+ context.progress.stop() # redirecting stdout clashes with rich progress
315
+ with redirect_output() as (buffer, stderr_buffer):
295
316
  user_set_parameters = f(
296
317
  **filtered_parameters
297
318
  ) # This is a tuple or single value
298
- task_console.print(out_file.getvalue())
319
+
320
+ print(
321
+ stderr_buffer.getvalue()
322
+ ) # To print the logging statements
323
+
324
+ # TODO: Avoid double print!!
325
+ with task_console.capture():
326
+ task_console.log(buffer.getvalue())
327
+ task_console.log(stderr_buffer.getvalue())
328
+ context.progress.start()
299
329
  except Exception as e:
300
330
  raise exceptions.CommandCallError(
301
331
  f"Function call: {self.command} did not succeed.\n"
@@ -478,14 +508,15 @@ class NotebookTaskType(BaseTaskType):
478
508
 
479
509
  return command
480
510
 
481
- def get_notebook_output_path(self, map_variable: TypeMapVariable = None) -> str:
511
+ def get_notebook_output_path(self, map_variable: MapVariableType = None) -> str:
482
512
  tag = ""
483
513
  map_variable = map_variable or {}
484
514
  for key, value in map_variable.items():
485
515
  tag += f"{key}_{value}_"
486
516
 
487
- if hasattr(self._context.executor, "_context_node"):
488
- tag += self._context.executor._context_node.name
517
+ if isinstance(self._context, context.PipelineContext):
518
+ assert self._context.pipeline_executor._context_node
519
+ tag += self._context.pipeline_executor._context_node.name
489
520
 
490
521
  tag = "".join(x for x in tag if x.isalnum()).strip("-")
491
522
 
@@ -496,7 +527,7 @@ class NotebookTaskType(BaseTaskType):
496
527
 
497
528
  def execute_command(
498
529
  self,
499
- map_variable: TypeMapVariable = None,
530
+ map_variable: MapVariableType = None,
500
531
  ) -> StepAttempt:
501
532
  """Execute the python notebook as defined by the command.
502
533
 
@@ -551,12 +582,20 @@ class NotebookTaskType(BaseTaskType):
551
582
  }
552
583
  kwds.update(ploomber_optional_args)
553
584
 
554
- out_file = TeeIO()
555
- with contextlib.redirect_stdout(out_file):
585
+ context.progress.stop() # redirecting stdout clashes with rich progress
586
+
587
+ with redirect_output() as (buffer, stderr_buffer):
556
588
  pm.execute_notebook(**kwds)
557
- task_console.print(out_file.getvalue())
558
589
 
559
- context.run_context.catalog_handler.put(name=notebook_output_path)
590
+ print(stderr_buffer.getvalue()) # To print the logging statements
591
+
592
+ with task_console.capture():
593
+ task_console.log(buffer.getvalue())
594
+ task_console.log(stderr_buffer.getvalue())
595
+
596
+ context.progress.start()
597
+
598
+ context.run_context.catalog.put(name=notebook_output_path)
560
599
 
561
600
  client = PloomberClient.from_path(path=notebook_output_path)
562
601
  namespace = client.get_namespace()
@@ -674,7 +713,7 @@ class ShellTaskType(BaseTaskType):
674
713
 
675
714
  def execute_command(
676
715
  self,
677
- map_variable: TypeMapVariable = None,
716
+ map_variable: MapVariableType = None,
678
717
  ) -> StepAttempt:
679
718
  # Using shell=True as we want to have chained commands to be executed in the same shell.
680
719
  """Execute the shell command as defined by the command.
@@ -698,7 +737,7 @@ class ShellTaskType(BaseTaskType):
698
737
  # Expose secrets as environment variables
699
738
  if self.secrets:
700
739
  for key in self.secrets:
701
- secret_value = context.run_context.secrets_handler.get(key)
740
+ secret_value = context.run_context.secrets.get(key)
702
741
  subprocess_env[key] = secret_value
703
742
 
704
743
  try:
@@ -724,6 +763,7 @@ class ShellTaskType(BaseTaskType):
724
763
  capture = False
725
764
  return_keys = {x.name: x for x in self.returns}
726
765
 
766
+ context.progress.stop() # redirecting stdout clashes with rich progress
727
767
  proc = subprocess.Popen(
728
768
  command,
729
769
  shell=True,
@@ -747,6 +787,7 @@ class ShellTaskType(BaseTaskType):
747
787
  continue
748
788
  task_console.print(line, style=defaults.warning_style)
749
789
 
790
+ context.progress.start()
750
791
  output_parameters: Dict[str, Parameter] = {}
751
792
  metrics: Dict[str, Parameter] = {}
752
793