runnable-0.35.0-py3-none-any.whl → runnable-0.36.1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (43)
  1. extensions/job_executor/__init__.py +3 -4
  2. extensions/job_executor/emulate.py +106 -0
  3. extensions/job_executor/k8s.py +8 -8
  4. extensions/job_executor/local_container.py +13 -14
  5. extensions/nodes/__init__.py +0 -0
  6. extensions/nodes/conditional.py +7 -5
  7. extensions/nodes/fail.py +72 -0
  8. extensions/nodes/map.py +350 -0
  9. extensions/nodes/parallel.py +159 -0
  10. extensions/nodes/stub.py +89 -0
  11. extensions/nodes/success.py +72 -0
  12. extensions/nodes/task.py +92 -0
  13. extensions/pipeline_executor/__init__.py +24 -26
  14. extensions/pipeline_executor/argo.py +20 -20
  15. extensions/pipeline_executor/emulate.py +112 -0
  16. extensions/pipeline_executor/local.py +4 -4
  17. extensions/pipeline_executor/local_container.py +19 -79
  18. extensions/pipeline_executor/mocked.py +5 -9
  19. extensions/pipeline_executor/retry.py +6 -10
  20. runnable/__init__.py +0 -10
  21. runnable/catalog.py +1 -21
  22. runnable/cli.py +0 -59
  23. runnable/context.py +519 -28
  24. runnable/datastore.py +51 -54
  25. runnable/defaults.py +12 -34
  26. runnable/entrypoints.py +82 -440
  27. runnable/exceptions.py +35 -34
  28. runnable/executor.py +13 -20
  29. runnable/names.py +1 -1
  30. runnable/nodes.py +16 -15
  31. runnable/parameters.py +2 -2
  32. runnable/sdk.py +66 -205
  33. runnable/tasks.py +62 -81
  34. runnable/utils.py +6 -268
  35. {runnable-0.35.0.dist-info → runnable-0.36.1.dist-info}/METADATA +1 -4
  36. runnable-0.36.1.dist-info/RECORD +72 -0
  37. {runnable-0.35.0.dist-info → runnable-0.36.1.dist-info}/entry_points.txt +8 -7
  38. extensions/nodes/nodes.py +0 -778
  39. extensions/tasks/torch.py +0 -286
  40. extensions/tasks/torch_config.py +0 -76
  41. runnable-0.35.0.dist-info/RECORD +0 -66
  42. {runnable-0.35.0.dist-info → runnable-0.36.1.dist-info}/WHEEL +0 -0
  43. {runnable-0.35.0.dist-info → runnable-0.36.1.dist-info}/licenses/LICENSE +0 -0
runnable/sdk.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
+import inspect
 import logging
-import os
 import re
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -16,27 +16,17 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from rich.progress import (
-    BarColumn,
-    Progress,
-    SpinnerColumn,
-    TextColumn,
-    TimeElapsedColumn,
-)
-from rich.table import Column
 from typing_extensions import Self
 
 from extensions.nodes.conditional import ConditionalNode
-from extensions.nodes.nodes import (
-    FailNode,
-    MapNode,
-    ParallelNode,
-    StubNode,
-    SuccessNode,
-    TaskNode,
-)
-from runnable import console, defaults, entrypoints, exceptions, graph, utils
-from runnable.executor import BaseJobExecutor, BasePipelineExecutor
+from extensions.nodes.fail import FailNode
+from extensions.nodes.map import MapNode
+from extensions.nodes.parallel import ParallelNode
+from extensions.nodes.stub import StubNode
+from extensions.nodes.success import SuccessNode
+from extensions.nodes.task import TaskNode
+from runnable import defaults, graph
+from runnable.executor import BaseJobExecutor
 from runnable.nodes import TraversalNode
 from runnable.tasks import BaseTaskType as RunnableTask
 from runnable.tasks import TaskReturns
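The monolithic extensions/nodes/nodes.py is deleted in this release (file 38 in the list above), so downstream code that imported node classes from it must switch to the new per-node modules. An illustrative migration:

    # before (0.35.x)
    from extensions.nodes.nodes import StubNode, TaskNode

    # after (0.36.x): one module per node type
    from extensions.nodes.stub import StubNode
    from extensions.nodes.task import TaskNode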
@@ -50,7 +40,6 @@ StepType = Union[
     "ShellTask",
     "Parallel",
     "Map",
-    "TorchTask",
     "Conditional",
 ]
 
@@ -196,7 +185,7 @@ class BaseTask(BaseTraversal):
         )
 
     def as_pipeline(self) -> "Pipeline":
-        return Pipeline(steps=[self])  # type: ignore
+        return Pipeline(steps=[self], name=self.internal_name)  # type: ignore
 
 
 class PythonTask(BaseTask):
@@ -287,27 +276,6 @@ class PythonTask(BaseTask):
         return node.executable
 
 
-class TorchTask(BaseTask):
-    # entrypoint: str = Field(
-    #     alias="entrypoint", default="torch.distributed.run", frozen=True
-    # )
-    # args_to_torchrun: Dict[str, Any] = Field(
-    #     default_factory=dict, alias="args_to_torchrun"
-    # )
-
-    script_to_call: str
-    accelerate_config_file: str
-
-    @computed_field
-    def command_type(self) -> str:
-        return "torch"
-
-    def create_job(self) -> RunnableTask:
-        self.terminate_with_success = True
-        node = self.create_node()
-        return node.executable
-
-
 class NotebookTask(BaseTask):
     """
     An execution node of the pipeline of notebook.
@@ -487,6 +455,9 @@ class Stub(BaseTraversal):
 
         return StubNode.parse_from_config(self.model_dump(exclude_none=True))
 
+    def as_pipeline(self) -> "Pipeline":
+        return Pipeline(steps=[self])
+
 
 class Parallel(BaseTraversal):
     """
@@ -786,6 +757,15 @@ class Pipeline(BaseModel):
                 return False
         return True
 
+    def get_caller(self) -> str:
+        caller_stack = inspect.stack()[2]
+        relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
+
+        module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
+        module_to_call = f"{module_name}.{caller_stack.function}"
+
+        return module_to_call
+
     def execute(
         self,
         configuration_file: str = "",
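get_caller() lifts the inline inspect logic that previously lived in execute() (deleted in the next hunk); frame [2] of the stack is the user code that called execute(), since frame [1] is execute() itself. A worked example of the path-to-module conversion, assuming the caller's file sits below the working directory:

    import re

    # e.g. a pipeline defined in function "main" of examples/concepts/simple.py
    # (hypothetical path, for illustration)
    relative_to_root = "examples/concepts/simple.py"
    module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
    assert module_name == "examples.concepts.simple"
    # get_caller() would then return "examples.concepts.simple.main"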
@@ -803,106 +783,31 @@
             # Immediately return as this call is only for getting the pipeline definition
             return {}
 
-        logger.setLevel(log_level)
-
-        run_id = utils.generate_run_id(run_id=run_id)
-
-        parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
+        from runnable import context
 
-        tag = os.environ.get("RUNNABLE_tag", tag)
+        logger.setLevel(log_level)
 
-        configuration_file = os.environ.get(
-            "RUNNABLE_CONFIGURATION_FILE", configuration_file
-        )
-        run_context = entrypoints.prepare_configurations(
+        service_configurations = context.ServiceConfigurations(
             configuration_file=configuration_file,
-            run_id=run_id,
-            tag=tag,
-            parameters_file=parameters_file,
-        )
-
-        assert isinstance(run_context.executor, BasePipelineExecutor)
-
-        utils.set_runnable_environment_variables(
-            run_id=run_id, configuration_file=configuration_file, tag=tag
+            execution_context=context.ExecutionContext.PIPELINE,
         )
 
-        dag_definition = self._dag.model_dump(by_alias=True, exclude_none=True)
-        run_context.from_sdk = True
-        run_context.dag = graph.create_graph(dag_definition)
-
-        console.print("Working with context:")
-        console.print(run_context)
-        console.rule(style="[dark orange]")
-
-        if not run_context.executor._is_local:
-            # We are not working with executor that does not work in local environment
-            import inspect
-
-            caller_stack = inspect.stack()[1]
-            relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
-
-            module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
-            module_to_call = f"{module_name}.{caller_stack.function}"
-
-            run_context.pipeline_file = f"{module_to_call}.py"
-            run_context.from_sdk = True
-
-        # Prepare for graph execution
-        run_context.executor._set_up_run_log(exists_ok=False)
-
-        with Progress(
-            SpinnerColumn(spinner_name="runner"),
-            TextColumn(
-                "[progress.description]{task.description}", table_column=Column(ratio=2)
-            ),
-            BarColumn(table_column=Column(ratio=1), style="dark_orange"),
-            TimeElapsedColumn(table_column=Column(ratio=1)),
-            console=console,
-            expand=True,
-        ) as progress:
-            pipeline_execution_task = progress.add_task(
-                "[dark_orange] Starting execution .. ", total=1
-            )
-            try:
-                run_context.progress = progress
-
-                run_context.executor.execute_graph(dag=run_context.dag)
+        configurations = {
+            "pipeline_definition_file": self.get_caller(),
+            "parameters_file": parameters_file,
+            "tag": tag,
+            "run_id": run_id,
+            "execution_mode": context.ExecutionMode.PYTHON,
+            "configuration_file": configuration_file,
+            **service_configurations.services,
+        }
 
-                if not run_context.executor._is_local:
-                    # non local executors just traverse the graph and do nothing
-                    return {}
+        run_context = context.PipelineContext.model_validate(configurations)
+        context.run_context = run_context
 
-                run_log = run_context.run_log_store.get_run_log_by_id(
-                    run_id=run_context.run_id, full=False
-                )
+        assert isinstance(run_context, context.PipelineContext)
 
-                if run_log.status == defaults.SUCCESS:
-                    progress.update(
-                        pipeline_execution_task,
-                        description="[green] Success",
-                        completed=True,
-                    )
-                else:
-                    progress.update(
-                        pipeline_execution_task,
-                        description="[red] Failed",
-                        completed=True,
-                    )
-                    raise exceptions.ExecutionFailedError(run_context.run_id)
-            except Exception as e:  # noqa: E722
-                console.print(e, style=defaults.error_style)
-                progress.update(
-                    pipeline_execution_task,
-                    description="[red] Errored execution",
-                    completed=True,
-                )
-                raise
-
-        if run_context.executor._is_local:
-            return run_context.run_log_store.get_run_log_by_id(
-                run_id=run_context.run_id
-            )
+        run_context.execute()
 
 
 class BaseJob(BaseModel):
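The progress-bar plumbing, run-log inspection, and environment-variable fallbacks all move out of the SDK and into runnable/context.py (note its +519/-28 in the file list); execute() now only assembles a configuration dict, validates it into a PipelineContext, and delegates to run_context.execute(). For orientation, a hedged end-to-end sketch of the documented SDK entry point this rewrites:

    from runnable import Pipeline, PythonTask

    def hello():
        print("hello world")

    pipeline = Pipeline(
        steps=[PythonTask(name="hello", function=hello, terminate_with_success=True)]
    )

    if __name__ == "__main__":
        # Builds a PipelineContext from ServiceConfigurations plus these
        # arguments; "config.yaml" is a hypothetical configuration file.
        pipeline.execute(configuration_file="config.yaml")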
@@ -926,6 +831,15 @@ class BaseJob(BaseModel):
     def get_task(self) -> RunnableTask:
         raise NotImplementedError
 
+    def get_caller(self) -> str:
+        caller_stack = inspect.stack()[2]
+        relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
+
+        module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
+        module_to_call = f"{module_name}.{caller_stack.function}"
+
+        return module_to_call
+
     def return_catalog_settings(self) -> Optional[List[str]]:
         if self.catalog is None:
             return []
@@ -952,65 +866,32 @@
         if self._is_called_for_definition():
             # Immediately return as this call is only for getting the job definition
             return {}
-        logger.setLevel(log_level)
-
-        run_id = utils.generate_run_id(run_id=job_id)
-
-        parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
-
-        tag = os.environ.get("RUNNABLE_tag", tag)
+        from runnable import context
 
-        configuration_file = os.environ.get(
-            "RUNNABLE_CONFIGURATION_FILE", configuration_file
-        )
+        logger.setLevel(log_level)
 
-        run_context = entrypoints.prepare_configurations(
+        service_configurations = context.ServiceConfigurations(
             configuration_file=configuration_file,
-            run_id=run_id,
-            tag=tag,
-            parameters_file=parameters_file,
-            is_job=True,
-        )
-
-        assert isinstance(run_context.executor, BaseJobExecutor)
-        run_context.from_sdk = True
-
-        utils.set_runnable_environment_variables(
-            run_id=run_id, configuration_file=configuration_file, tag=tag
+            execution_context=context.ExecutionContext.JOB,
         )
 
-        console.print("Working with context:")
-        console.print(run_context)
-        console.rule(style="[dark orange]")
-
-        if not run_context.executor._is_local:
-            # We are not working with executor that does not work in local environment
-            import inspect
-
-            caller_stack = inspect.stack()[1]
-            relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
-
-            module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
-            module_to_call = f"{module_name}.{caller_stack.function}"
-
-            run_context.job_definition_file = f"{module_to_call}.py"
-
-        job = self.get_task()
-        catalog_settings = self.return_catalog_settings()
+        configurations = {
+            "job_definition_file": self.get_caller(),
+            "parameters_file": parameters_file,
+            "tag": tag,
+            "run_id": job_id,
+            "execution_mode": context.ExecutionMode.PYTHON,
+            "configuration_file": configuration_file,
+            "job": self.get_task(),
+            "catalog_settings": self.return_catalog_settings(),
+            **service_configurations.services,
+        }
 
-        try:
-            run_context.executor.submit_job(job, catalog_settings=catalog_settings)
-        finally:
-            run_context.executor.add_task_log_to_catalog("job")
+        run_context = context.JobContext.model_validate(configurations)
 
-        logger.info(
-            "Executing the job from the user. We are still in the caller's compute environment"
-        )
+        assert isinstance(run_context.job_executor, BaseJobExecutor)
 
-        if run_context.executor._is_local:
-            return run_context.run_log_store.get_run_log_by_id(
-                run_id=run_context.run_id
-            )
+        run_context.execute()
 
 
 class PythonJob(BaseJob):
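The job path mirrors the pipeline path: the explicit submit_job/add_task_log_to_catalog calls are replaced by a JobContext validated from the merged configuration, with the job and its catalog settings passed in as fields. Illustrative usage via the documented PythonJob wrapper:

    from runnable import PythonJob

    def train():
        print("training")

    # As with Pipeline.execute(), a JobContext is assembled and
    # run_context.execute() drives the submission.
    PythonJob(function=train).execute()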
@@ -1034,26 +915,6 @@ class PythonJob(BaseJob):
         return task.create_node().executable
 
 
-class TorchJob(BaseJob):
-    # entrypoint: str = Field(default="torch.distributed.run", frozen=True)
-    # args_to_torchrun: dict[str, str | bool | int | float] = Field(
-    #     default_factory=dict
-    # )  # For example
-    # {"nproc_per_node": 2, "nnodes": 1,}
-
-    script_to_call: str  # For example train/script.py
-    accelerate_config_file: str
-
-    def get_task(self) -> RunnableTask:
-        # Piggy bank on existing tasks as a hack
-        task = TorchTask(
-            name="dummy",
-            terminate_with_success=True,
-            **self.model_dump(exclude_defaults=True, exclude_none=True),
-        )
-        return task.create_node().executable
-
-
 class NotebookJob(BaseJob):
     notebook: str = Field(serialization_alias="command")
     optional_ploomber_args: Optional[Dict[str, Any]] = Field(
runnable/tasks.py CHANGED
@@ -25,7 +25,7 @@ from runnable.datastore import (
     Parameter,
     StepAttempt,
 )
-from runnable.defaults import TypeMapVariable
+from runnable.defaults import MapVariableType
 
 logger = logging.getLogger(defaults.LOGGER_NAME)
 
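TypeMapVariable is renamed to MapVariableType throughout this file (every execute_command signature below). Downstream code pinned to the old name could bridge the rename with a guarded import (illustrative):

    try:
        from runnable.defaults import MapVariableType
    except ImportError:  # runnable < 0.36
        from runnable.defaults import TypeMapVariable as MapVariableType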
@@ -48,7 +48,29 @@ class TeeIO(io.StringIO):
         self.output_stream.flush()
 
 
-sys.stdout = TeeIO()
+@contextlib.contextmanager
+def redirect_output():
+    # Set the stream handlers to use the custom TeeIO class
+
+    # Backup the original stdout and stderr
+    original_stdout = sys.stdout
+    original_stderr = sys.stderr
+
+    # Redirect stdout and stderr to custom TeeStream objects
+    sys.stdout = TeeIO(sys.stdout)
+    sys.stderr = TeeIO(sys.stderr)
+
+    # Replace stream for all StreamHandlers to use the new sys.stdout
+    for handler in logging.getLogger().handlers:
+        if isinstance(handler, logging.StreamHandler):
+            handler.stream = sys.stdout
+
+    try:
+        yield sys.stdout, sys.stderr
+    finally:
+        # Restore the original stdout and stderr
+        sys.stdout = original_stdout
+        sys.stderr = original_stderr
 
 
 class TaskReturns(BaseModel):
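The module-level sys.stdout = TeeIO() hijack, which ran at import time, becomes a scoped context manager that tees both stdout and stderr and repoints root-logger StreamHandlers for the duration of a task. TeeIO itself is defined just above this hunk; a self-contained sketch of the tee-and-restore pattern it relies on:

    import io
    import sys

    class TeeIO(io.StringIO):
        # Buffer every write while forwarding it to the original stream.
        def __init__(self, output_stream):
            super().__init__()
            self.output_stream = output_stream

        def write(self, s: str) -> int:
            self.output_stream.write(s)
            return super().write(s)

    original = sys.stdout
    sys.stdout = TeeIO(original)
    try:
        print("shown live and captured")
        captured = sys.stdout.getvalue()
    finally:
        sys.stdout = original  # always restore, as the finally clause above does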
@@ -79,7 +101,7 @@ class BaseTaskType(BaseModel):
     def set_secrets_as_env_variables(self):
         # Preparing the environment for the task execution
         for key in self.secrets:
-            secret_value = context.run_context.secrets_handler.get(key)
+            secret_value = context.run_context.secrets.get(key)
             os.environ[key] = secret_value
 
     def delete_secrets_from_env_variables(self):
@@ -90,7 +112,7 @@ class BaseTaskType(BaseModel):
 
     def execute_command(
         self,
-        map_variable: TypeMapVariable = None,
+        map_variable: MapVariableType = None,
     ) -> StepAttempt:
         """The function to execute the command.
 
@@ -130,7 +152,7 @@
         finally:
             self.delete_secrets_from_env_variables()
 
-    def resolve_unreduced_parameters(self, map_variable: TypeMapVariable = None):
+    def resolve_unreduced_parameters(self, map_variable: MapVariableType = None):
         """Resolve the unreduced parameters."""
         params = self._context.run_log_store.get_parameters(
             run_id=self._context.run_id
@@ -153,7 +175,7 @@
 
     @contextlib.contextmanager
     def execution_context(
-        self, map_variable: TypeMapVariable = None, allow_complex: bool = True
+        self, map_variable: MapVariableType = None, allow_complex: bool = True
     ):
         params = self.resolve_unreduced_parameters(map_variable=map_variable)
         logger.info(f"Parameters available for the execution: {params}")
@@ -267,7 +289,7 @@ class PythonTaskType(BaseTaskType):  # pylint: disable=too-few-public-methods
 
     def execute_command(
         self,
-        map_variable: TypeMapVariable = None,
+        map_variable: MapVariableType = None,
     ) -> StepAttempt:
         """Execute the notebook as defined by the command."""
         attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
@@ -289,13 +311,21 @@ class PythonTaskType(BaseTaskType):  # pylint: disable=too-few-public-methods
                 logger.info(
                     f"Calling {func} from {module} with {filtered_parameters}"
                 )
-
-                out_file = TeeIO()
-                with contextlib.redirect_stdout(out_file):
+                context.progress.stop()  # redirecting stdout clashes with rich progress
+                with redirect_output() as (buffer, stderr_buffer):
                     user_set_parameters = f(
                         **filtered_parameters
                     )  # This is a tuple or single value
-                task_console.print(out_file.getvalue())
+
+                    print(
+                        stderr_buffer.getvalue()
+                    )  # To print the logging statements
+
+                # TODO: Avoid double print!!
+                with task_console.capture():
+                    task_console.log(buffer.getvalue())
+                    task_console.log(stderr_buffer.getvalue())
+                context.progress.start()
             except Exception as e:
                 raise exceptions.CommandCallError(
                     f"Function call: {self.command} did not succeed.\n"
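The user function's output is now teed live (with the rich progress display paused around it) and then re-logged into task_console for the persisted task log; the capture() guard keeps that second pass off the terminal. Assuming task_console is a recording rich Console, the suppress-but-record pattern looks like:

    from rich.console import Console

    console = Console(record=True)
    with console.capture():
        # Written to the console's record, not the terminal, sidestepping
        # the double print flagged in the TODO above.
        console.log("captured for the task log")
    persisted = console.export_text()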
@@ -354,66 +384,6 @@ class PythonTaskType(BaseTaskType):  # pylint: disable=too-few-public-methods
         return attempt_log
 
 
-class TorchTaskType(BaseTaskType):
-    task_type: str = Field(default="torch", serialization_alias="command_type")
-    accelerate_config_file: str
-
-    script_to_call: str  # For example train/script.py
-
-    def execute_command(
-        self, map_variable: Dict[str, str | int | float] | None = None
-    ) -> StepAttempt:
-        from accelerate.commands import launch
-
-        attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
-
-        with (
-            self.execution_context(
-                map_variable=map_variable, allow_complex=False
-            ) as params,
-            self.expose_secrets() as _,
-        ):
-            try:
-                script_args = []
-                for key, value in params.items():
-                    script_args.append(f"--{key}")
-                    if type(value.value) is not bool:
-                        script_args.append(str(value.value))
-
-                # TODO: Check the typing here
-
-                logger.info("Calling the user script with the following parameters:")
-                logger.info(script_args)
-                out_file = TeeIO()
-                try:
-                    with contextlib.redirect_stdout(out_file):
-                        parser = launch.launch_command_parser()
-                        args = parser.parse_args(self.script_to_call)
-                        args.training_script = self.script_to_call
-                        args.config_file = self.accelerate_config_file
-                        args.training_script_args = script_args
-
-                        launch.launch_command(args)
-                    task_console.print(out_file.getvalue())
-                except Exception as e:
-                    raise exceptions.CommandCallError(
-                        f"Call to script{self.script_to_call} did not succeed."
-                    ) from e
-                finally:
-                    sys.argv = sys.argv[:1]
-
-                attempt_log.status = defaults.SUCCESS
-            except Exception as _e:
-                msg = f"Call to script: {self.script_to_call} did not succeed."
-                attempt_log.message = msg
-                task_console.print_exception(show_locals=False)
-                task_console.log(_e, style=defaults.error_style)
-
-        attempt_log.end_time = str(datetime.now())
-
-        return attempt_log
-
-
 class NotebookTaskType(BaseTaskType):
     """
     --8<-- [start:notebook_reference]
@@ -478,14 +448,15 @@ class NotebookTaskType(BaseTaskType):
 
         return command
 
-    def get_notebook_output_path(self, map_variable: TypeMapVariable = None) -> str:
+    def get_notebook_output_path(self, map_variable: MapVariableType = None) -> str:
         tag = ""
         map_variable = map_variable or {}
         for key, value in map_variable.items():
             tag += f"{key}_{value}_"
 
-        if hasattr(self._context.executor, "_context_node"):
-            tag += self._context.executor._context_node.name
+        if isinstance(self._context, context.PipelineContext):
+            assert self._context.pipeline_executor._context_node
+            tag += self._context.pipeline_executor._context_node.name
 
         tag = "".join(x for x in tag if x.isalnum()).strip("-")
 
@@ -496,7 +467,7 @@
 
     def execute_command(
         self,
-        map_variable: TypeMapVariable = None,
+        map_variable: MapVariableType = None,
     ) -> StepAttempt:
         """Execute the python notebook as defined by the command.
 
@@ -551,12 +522,20 @@
             }
             kwds.update(ploomber_optional_args)
 
-            out_file = TeeIO()
-            with contextlib.redirect_stdout(out_file):
+            context.progress.stop()  # redirecting stdout clashes with rich progress
+
+            with redirect_output() as (buffer, stderr_buffer):
                 pm.execute_notebook(**kwds)
-            task_console.print(out_file.getvalue())
 
-            context.run_context.catalog_handler.put(name=notebook_output_path)
+            print(stderr_buffer.getvalue())  # To print the logging statements
+
+            with task_console.capture():
+                task_console.log(buffer.getvalue())
+                task_console.log(stderr_buffer.getvalue())
+
+            context.progress.start()
+
+            context.run_context.catalog.put(name=notebook_output_path)
 
             client = PloomberClient.from_path(path=notebook_output_path)
             namespace = client.get_namespace()
@@ -674,7 +653,7 @@ class ShellTaskType(BaseTaskType):
 
     def execute_command(
         self,
-        map_variable: TypeMapVariable = None,
+        map_variable: MapVariableType = None,
     ) -> StepAttempt:
         # Using shell=True as we want to have chained commands to be executed in the same shell.
         """Execute the shell command as defined by the command.
@@ -698,7 +677,7 @@
         # Expose secrets as environment variables
         if self.secrets:
             for key in self.secrets:
-                secret_value = context.run_context.secrets_handler.get(key)
+                secret_value = context.run_context.secrets.get(key)
                 subprocess_env[key] = secret_value
 
         try:
@@ -724,6 +703,7 @@
             capture = False
             return_keys = {x.name: x for x in self.returns}
 
+            context.progress.stop()  # redirecting stdout clashes with rich progress
             proc = subprocess.Popen(
                 command,
                 shell=True,
@@ -747,6 +727,7 @@
                     continue
                 task_console.print(line, style=defaults.warning_style)
 
+            context.progress.start()
             output_parameters: Dict[str, Parameter] = {}
             metrics: Dict[str, Parameter] = {}