runnable 0.34.0a3__py3-none-any.whl → 0.36.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- extensions/job_executor/__init__.py +3 -4
- extensions/job_executor/emulate.py +106 -0
- extensions/job_executor/k8s.py +8 -8
- extensions/job_executor/local_container.py +13 -14
- extensions/nodes/__init__.py +0 -0
- extensions/nodes/conditional.py +243 -0
- extensions/nodes/fail.py +72 -0
- extensions/nodes/map.py +350 -0
- extensions/nodes/parallel.py +159 -0
- extensions/nodes/stub.py +89 -0
- extensions/nodes/success.py +72 -0
- extensions/nodes/task.py +92 -0
- extensions/pipeline_executor/__init__.py +24 -26
- extensions/pipeline_executor/argo.py +50 -41
- extensions/pipeline_executor/emulate.py +112 -0
- extensions/pipeline_executor/local.py +4 -4
- extensions/pipeline_executor/local_container.py +19 -79
- extensions/pipeline_executor/mocked.py +4 -4
- extensions/pipeline_executor/retry.py +6 -10
- extensions/tasks/torch.py +1 -1
- runnable/__init__.py +2 -9
- runnable/catalog.py +1 -21
- runnable/cli.py +0 -59
- runnable/context.py +519 -28
- runnable/datastore.py +51 -54
- runnable/defaults.py +12 -34
- runnable/entrypoints.py +82 -440
- runnable/exceptions.py +35 -34
- runnable/executor.py +13 -20
- runnable/names.py +1 -1
- runnable/nodes.py +18 -16
- runnable/parameters.py +2 -2
- runnable/sdk.py +117 -164
- runnable/tasks.py +62 -21
- runnable/utils.py +6 -268
- {runnable-0.34.0a3.dist-info → runnable-0.36.0.dist-info}/METADATA +1 -2
- runnable-0.36.0.dist-info/RECORD +74 -0
- {runnable-0.34.0a3.dist-info → runnable-0.36.0.dist-info}/entry_points.txt +9 -8
- extensions/nodes/nodes.py +0 -778
- extensions/nodes/torch.py +0 -273
- extensions/nodes/torch_config.py +0 -76
- runnable-0.34.0a3.dist-info/RECORD +0 -67
- {runnable-0.34.0a3.dist-info → runnable-0.36.0.dist-info}/WHEEL +0 -0
- {runnable-0.34.0a3.dist-info → runnable-0.36.0.dist-info}/licenses/LICENSE +0 -0
runnable/sdk.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import inspect
|
3
4
|
import logging
|
4
|
-
import os
|
5
5
|
import re
|
6
6
|
from abc import ABC, abstractmethod
|
7
7
|
from pathlib import Path
|
@@ -16,26 +16,17 @@ from pydantic import (
|
|
16
16
|
field_validator,
|
17
17
|
model_validator,
|
18
18
|
)
|
19
|
-
from rich.progress import (
|
20
|
-
BarColumn,
|
21
|
-
Progress,
|
22
|
-
SpinnerColumn,
|
23
|
-
TextColumn,
|
24
|
-
TimeElapsedColumn,
|
25
|
-
)
|
26
|
-
from rich.table import Column
|
27
19
|
from typing_extensions import Self
|
28
20
|
|
29
|
-
from extensions.nodes.
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
from runnable import
|
38
|
-
from runnable.executor import BaseJobExecutor, BasePipelineExecutor
|
21
|
+
from extensions.nodes.conditional import ConditionalNode
|
22
|
+
from extensions.nodes.fail import FailNode
|
23
|
+
from extensions.nodes.map import MapNode
|
24
|
+
from extensions.nodes.parallel import ParallelNode
|
25
|
+
from extensions.nodes.stub import StubNode
|
26
|
+
from extensions.nodes.success import SuccessNode
|
27
|
+
from extensions.nodes.task import TaskNode
|
28
|
+
from runnable import defaults, graph
|
29
|
+
from runnable.executor import BaseJobExecutor
|
39
30
|
from runnable.nodes import TraversalNode
|
40
31
|
from runnable.tasks import BaseTaskType as RunnableTask
|
41
32
|
from runnable.tasks import TaskReturns
|
@@ -50,6 +41,7 @@ StepType = Union[
|
|
50
41
|
"Parallel",
|
51
42
|
"Map",
|
52
43
|
"TorchTask",
|
44
|
+
"Conditional",
|
53
45
|
]
|
54
46
|
|
55
47
|
|
@@ -193,6 +185,9 @@ class BaseTask(BaseTraversal):
|
|
193
185
|
"This method should be implemented in the child class"
|
194
186
|
)
|
195
187
|
|
188
|
+
def as_pipeline(self) -> "Pipeline":
|
189
|
+
return Pipeline(steps=[self], name=self.internal_name) # type: ignore
|
190
|
+
|
196
191
|
|
197
192
|
class PythonTask(BaseTask):
|
198
193
|
"""
|
@@ -482,6 +477,9 @@ class Stub(BaseTraversal):
|
|
482
477
|
|
483
478
|
return StubNode.parse_from_config(self.model_dump(exclude_none=True))
|
484
479
|
|
480
|
+
def as_pipeline(self) -> "Pipeline":
|
481
|
+
return Pipeline(steps=[self])
|
482
|
+
|
485
483
|
|
486
484
|
class Parallel(BaseTraversal):
|
487
485
|
"""
|
@@ -521,6 +519,53 @@ class Parallel(BaseTraversal):
|
|
521
519
|
return node
|
522
520
|
|
523
521
|
|
522
|
+
class Conditional(BaseTraversal):
|
523
|
+
branches: Dict[str, "Pipeline"]
|
524
|
+
parameter: str # the name of the parameter should be isalnum
|
525
|
+
|
526
|
+
@field_validator("parameter")
|
527
|
+
@classmethod
|
528
|
+
def validate_parameter(cls, parameter: str) -> str:
|
529
|
+
if not parameter.isalnum():
|
530
|
+
raise AssertionError(
|
531
|
+
"The parameter name should be alphanumeric and not empty"
|
532
|
+
)
|
533
|
+
return parameter
|
534
|
+
|
535
|
+
@field_validator("branches")
|
536
|
+
@classmethod
|
537
|
+
def validate_branches(
|
538
|
+
cls, branches: Dict[str, "Pipeline"]
|
539
|
+
) -> Dict[str, "Pipeline"]:
|
540
|
+
for branch_name in branches.keys():
|
541
|
+
if not branch_name.isalnum():
|
542
|
+
raise ValueError(f"Branch '{branch_name}' must be alphanumeric.")
|
543
|
+
return branches
|
544
|
+
|
545
|
+
@computed_field # type: ignore
|
546
|
+
@property
|
547
|
+
def graph_branches(self) -> Dict[str, graph.Graph]:
|
548
|
+
return {
|
549
|
+
name: pipeline._dag.model_copy() for name, pipeline in self.branches.items()
|
550
|
+
}
|
551
|
+
|
552
|
+
def create_node(self) -> ConditionalNode:
|
553
|
+
if not self.next_node:
|
554
|
+
if not (self.terminate_with_failure or self.terminate_with_success):
|
555
|
+
raise AssertionError(
|
556
|
+
"A node not being terminated must have a user defined next node"
|
557
|
+
)
|
558
|
+
|
559
|
+
node = ConditionalNode(
|
560
|
+
name=self.name,
|
561
|
+
branches=self.graph_branches,
|
562
|
+
internal_name="",
|
563
|
+
next_node=self.next_node,
|
564
|
+
parameter=self.parameter,
|
565
|
+
)
|
566
|
+
return node
|
567
|
+
|
568
|
+
|
524
569
|
class Map(BaseTraversal):
|
525
570
|
"""
|
526
571
|
A node that iterates over a list of items and executes a pipeline for each item.
|
@@ -544,7 +589,6 @@ class Map(BaseTraversal):
|
|
544
589
|
iterate_on: str
|
545
590
|
iterate_as: str
|
546
591
|
reducer: Optional[str] = Field(default=None, alias="reducer")
|
547
|
-
overrides: Dict[str, Any] = Field(default_factory=dict)
|
548
592
|
|
549
593
|
@computed_field # type: ignore
|
550
594
|
@property
|
@@ -565,7 +609,6 @@ class Map(BaseTraversal):
|
|
565
609
|
next_node=self.next_node,
|
566
610
|
iterate_on=self.iterate_on,
|
567
611
|
iterate_as=self.iterate_as,
|
568
|
-
overrides=self.overrides,
|
569
612
|
reducer=self.reducer,
|
570
613
|
)
|
571
614
|
|
@@ -736,6 +779,15 @@ class Pipeline(BaseModel):
|
|
736
779
|
return False
|
737
780
|
return True
|
738
781
|
|
782
|
+
def get_caller(self) -> str:
|
783
|
+
caller_stack = inspect.stack()[2]
|
784
|
+
relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
|
785
|
+
|
786
|
+
module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
|
787
|
+
module_to_call = f"{module_name}.{caller_stack.function}"
|
788
|
+
|
789
|
+
return module_to_call
|
790
|
+
|
739
791
|
def execute(
|
740
792
|
self,
|
741
793
|
configuration_file: str = "",
|
@@ -753,106 +805,31 @@ class Pipeline(BaseModel):
|
|
753
805
|
# Immediately return as this call is only for getting the pipeline definition
|
754
806
|
return {}
|
755
807
|
|
756
|
-
|
757
|
-
|
758
|
-
run_id = utils.generate_run_id(run_id=run_id)
|
759
|
-
|
760
|
-
parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
|
808
|
+
from runnable import context
|
761
809
|
|
762
|
-
|
810
|
+
logger.setLevel(log_level)
|
763
811
|
|
764
|
-
|
765
|
-
"RUNNABLE_CONFIGURATION_FILE", configuration_file
|
766
|
-
)
|
767
|
-
run_context = entrypoints.prepare_configurations(
|
812
|
+
service_configurations = context.ServiceConfigurations(
|
768
813
|
configuration_file=configuration_file,
|
769
|
-
|
770
|
-
tag=tag,
|
771
|
-
parameters_file=parameters_file,
|
814
|
+
execution_context=context.ExecutionContext.PIPELINE,
|
772
815
|
)
|
773
816
|
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
console.print("Working with context:")
|
785
|
-
console.print(run_context)
|
786
|
-
console.rule(style="[dark orange]")
|
787
|
-
|
788
|
-
if not run_context.executor._is_local:
|
789
|
-
# We are not working with executor that does not work in local environment
|
790
|
-
import inspect
|
791
|
-
|
792
|
-
caller_stack = inspect.stack()[1]
|
793
|
-
relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
|
794
|
-
|
795
|
-
module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
|
796
|
-
module_to_call = f"{module_name}.{caller_stack.function}"
|
797
|
-
|
798
|
-
run_context.pipeline_file = f"{module_to_call}.py"
|
799
|
-
run_context.from_sdk = True
|
800
|
-
|
801
|
-
# Prepare for graph execution
|
802
|
-
run_context.executor._set_up_run_log(exists_ok=False)
|
803
|
-
|
804
|
-
with Progress(
|
805
|
-
SpinnerColumn(spinner_name="runner"),
|
806
|
-
TextColumn(
|
807
|
-
"[progress.description]{task.description}", table_column=Column(ratio=2)
|
808
|
-
),
|
809
|
-
BarColumn(table_column=Column(ratio=1), style="dark_orange"),
|
810
|
-
TimeElapsedColumn(table_column=Column(ratio=1)),
|
811
|
-
console=console,
|
812
|
-
expand=True,
|
813
|
-
) as progress:
|
814
|
-
pipeline_execution_task = progress.add_task(
|
815
|
-
"[dark_orange] Starting execution .. ", total=1
|
816
|
-
)
|
817
|
-
try:
|
818
|
-
run_context.progress = progress
|
819
|
-
|
820
|
-
run_context.executor.execute_graph(dag=run_context.dag)
|
817
|
+
configurations = {
|
818
|
+
"pipeline_definition_file": self.get_caller(),
|
819
|
+
"parameters_file": parameters_file,
|
820
|
+
"tag": tag,
|
821
|
+
"run_id": run_id,
|
822
|
+
"execution_mode": context.ExecutionMode.PYTHON,
|
823
|
+
"configuration_file": configuration_file,
|
824
|
+
**service_configurations.services,
|
825
|
+
}
|
821
826
|
|
822
|
-
|
823
|
-
|
824
|
-
return {}
|
827
|
+
run_context = context.PipelineContext.model_validate(configurations)
|
828
|
+
context.run_context = run_context
|
825
829
|
|
826
|
-
|
827
|
-
run_id=run_context.run_id, full=False
|
828
|
-
)
|
830
|
+
assert isinstance(run_context, context.PipelineContext)
|
829
831
|
|
830
|
-
|
831
|
-
progress.update(
|
832
|
-
pipeline_execution_task,
|
833
|
-
description="[green] Success",
|
834
|
-
completed=True,
|
835
|
-
)
|
836
|
-
else:
|
837
|
-
progress.update(
|
838
|
-
pipeline_execution_task,
|
839
|
-
description="[red] Failed",
|
840
|
-
completed=True,
|
841
|
-
)
|
842
|
-
raise exceptions.ExecutionFailedError(run_context.run_id)
|
843
|
-
except Exception as e: # noqa: E722
|
844
|
-
console.print(e, style=defaults.error_style)
|
845
|
-
progress.update(
|
846
|
-
pipeline_execution_task,
|
847
|
-
description="[red] Errored execution",
|
848
|
-
completed=True,
|
849
|
-
)
|
850
|
-
raise
|
851
|
-
|
852
|
-
if run_context.executor._is_local:
|
853
|
-
return run_context.run_log_store.get_run_log_by_id(
|
854
|
-
run_id=run_context.run_id
|
855
|
-
)
|
832
|
+
run_context.execute()
|
856
833
|
|
857
834
|
|
858
835
|
class BaseJob(BaseModel):
|
@@ -876,6 +853,15 @@ class BaseJob(BaseModel):
|
|
876
853
|
def get_task(self) -> RunnableTask:
|
877
854
|
raise NotImplementedError
|
878
855
|
|
856
|
+
def get_caller(self) -> str:
|
857
|
+
caller_stack = inspect.stack()[2]
|
858
|
+
relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
|
859
|
+
|
860
|
+
module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
|
861
|
+
module_to_call = f"{module_name}.{caller_stack.function}"
|
862
|
+
|
863
|
+
return module_to_call
|
864
|
+
|
879
865
|
def return_catalog_settings(self) -> Optional[List[str]]:
|
880
866
|
if self.catalog is None:
|
881
867
|
return []
|
@@ -902,65 +888,32 @@ class BaseJob(BaseModel):
|
|
902
888
|
if self._is_called_for_definition():
|
903
889
|
# Immediately return as this call is only for getting the job definition
|
904
890
|
return {}
|
905
|
-
|
906
|
-
|
907
|
-
run_id = utils.generate_run_id(run_id=job_id)
|
908
|
-
|
909
|
-
parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
|
910
|
-
|
911
|
-
tag = os.environ.get("RUNNABLE_tag", tag)
|
891
|
+
from runnable import context
|
912
892
|
|
913
|
-
|
914
|
-
"RUNNABLE_CONFIGURATION_FILE", configuration_file
|
915
|
-
)
|
893
|
+
logger.setLevel(log_level)
|
916
894
|
|
917
|
-
|
895
|
+
service_configurations = context.ServiceConfigurations(
|
918
896
|
configuration_file=configuration_file,
|
919
|
-
|
920
|
-
tag=tag,
|
921
|
-
parameters_file=parameters_file,
|
922
|
-
is_job=True,
|
923
|
-
)
|
924
|
-
|
925
|
-
assert isinstance(run_context.executor, BaseJobExecutor)
|
926
|
-
run_context.from_sdk = True
|
927
|
-
|
928
|
-
utils.set_runnable_environment_variables(
|
929
|
-
run_id=run_id, configuration_file=configuration_file, tag=tag
|
897
|
+
execution_context=context.ExecutionContext.JOB,
|
930
898
|
)
|
931
899
|
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
|
944
|
-
module_to_call = f"{module_name}.{caller_stack.function}"
|
945
|
-
|
946
|
-
run_context.job_definition_file = f"{module_to_call}.py"
|
947
|
-
|
948
|
-
job = self.get_task()
|
949
|
-
catalog_settings = self.return_catalog_settings()
|
900
|
+
configurations = {
|
901
|
+
"job_definition_file": self.get_caller(),
|
902
|
+
"parameters_file": parameters_file,
|
903
|
+
"tag": tag,
|
904
|
+
"run_id": job_id,
|
905
|
+
"execution_mode": context.ExecutionMode.PYTHON,
|
906
|
+
"configuration_file": configuration_file,
|
907
|
+
"job": self.get_task(),
|
908
|
+
"catalog_settings": self.return_catalog_settings(),
|
909
|
+
**service_configurations.services,
|
910
|
+
}
|
950
911
|
|
951
|
-
|
952
|
-
run_context.executor.submit_job(job, catalog_settings=catalog_settings)
|
953
|
-
finally:
|
954
|
-
run_context.executor.add_task_log_to_catalog("job")
|
912
|
+
run_context = context.JobContext.model_validate(configurations)
|
955
913
|
|
956
|
-
|
957
|
-
"Executing the job from the user. We are still in the caller's compute environment"
|
958
|
-
)
|
914
|
+
assert isinstance(run_context.job_executor, BaseJobExecutor)
|
959
915
|
|
960
|
-
|
961
|
-
return run_context.run_log_store.get_run_log_by_id(
|
962
|
-
run_id=run_context.run_id
|
963
|
-
)
|
916
|
+
run_context.execute()
|
964
917
|
|
965
918
|
|
966
919
|
class PythonJob(BaseJob):
|
runnable/tasks.py
CHANGED
@@ -25,7 +25,7 @@ from runnable.datastore import (
|
|
25
25
|
Parameter,
|
26
26
|
StepAttempt,
|
27
27
|
)
|
28
|
-
from runnable.defaults import
|
28
|
+
from runnable.defaults import MapVariableType
|
29
29
|
|
30
30
|
logger = logging.getLogger(defaults.LOGGER_NAME)
|
31
31
|
|
@@ -48,7 +48,29 @@ class TeeIO(io.StringIO):
|
|
48
48
|
self.output_stream.flush()
|
49
49
|
|
50
50
|
|
51
|
-
|
51
|
+
@contextlib.contextmanager
|
52
|
+
def redirect_output():
|
53
|
+
# Set the stream handlers to use the custom TeeIO class
|
54
|
+
|
55
|
+
# Backup the original stdout and stderr
|
56
|
+
original_stdout = sys.stdout
|
57
|
+
original_stderr = sys.stderr
|
58
|
+
|
59
|
+
# Redirect stdout and stderr to custom TeeStream objects
|
60
|
+
sys.stdout = TeeIO(sys.stdout)
|
61
|
+
sys.stderr = TeeIO(sys.stderr)
|
62
|
+
|
63
|
+
# Replace stream for all StreamHandlers to use the new sys.stdout
|
64
|
+
for handler in logging.getLogger().handlers:
|
65
|
+
if isinstance(handler, logging.StreamHandler):
|
66
|
+
handler.stream = sys.stdout
|
67
|
+
|
68
|
+
try:
|
69
|
+
yield sys.stdout, sys.stderr
|
70
|
+
finally:
|
71
|
+
# Restore the original stdout and stderr
|
72
|
+
sys.stdout = original_stdout
|
73
|
+
sys.stderr = original_stderr
|
52
74
|
|
53
75
|
|
54
76
|
class TaskReturns(BaseModel):
|
@@ -79,7 +101,7 @@ class BaseTaskType(BaseModel):
|
|
79
101
|
def set_secrets_as_env_variables(self):
|
80
102
|
# Preparing the environment for the task execution
|
81
103
|
for key in self.secrets:
|
82
|
-
secret_value = context.run_context.
|
104
|
+
secret_value = context.run_context.secrets.get(key)
|
83
105
|
os.environ[key] = secret_value
|
84
106
|
|
85
107
|
def delete_secrets_from_env_variables(self):
|
@@ -90,7 +112,7 @@ class BaseTaskType(BaseModel):
|
|
90
112
|
|
91
113
|
def execute_command(
|
92
114
|
self,
|
93
|
-
map_variable:
|
115
|
+
map_variable: MapVariableType = None,
|
94
116
|
) -> StepAttempt:
|
95
117
|
"""The function to execute the command.
|
96
118
|
|
@@ -130,7 +152,7 @@ class BaseTaskType(BaseModel):
|
|
130
152
|
finally:
|
131
153
|
self.delete_secrets_from_env_variables()
|
132
154
|
|
133
|
-
def resolve_unreduced_parameters(self, map_variable:
|
155
|
+
def resolve_unreduced_parameters(self, map_variable: MapVariableType = None):
|
134
156
|
"""Resolve the unreduced parameters."""
|
135
157
|
params = self._context.run_log_store.get_parameters(
|
136
158
|
run_id=self._context.run_id
|
@@ -153,7 +175,7 @@ class BaseTaskType(BaseModel):
|
|
153
175
|
|
154
176
|
@contextlib.contextmanager
|
155
177
|
def execution_context(
|
156
|
-
self, map_variable:
|
178
|
+
self, map_variable: MapVariableType = None, allow_complex: bool = True
|
157
179
|
):
|
158
180
|
params = self.resolve_unreduced_parameters(map_variable=map_variable)
|
159
181
|
logger.info(f"Parameters available for the execution: {params}")
|
@@ -267,7 +289,7 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
|
|
267
289
|
|
268
290
|
def execute_command(
|
269
291
|
self,
|
270
|
-
map_variable:
|
292
|
+
map_variable: MapVariableType = None,
|
271
293
|
) -> StepAttempt:
|
272
294
|
"""Execute the notebook as defined by the command."""
|
273
295
|
attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
|
@@ -289,13 +311,21 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
|
|
289
311
|
logger.info(
|
290
312
|
f"Calling {func} from {module} with {filtered_parameters}"
|
291
313
|
)
|
292
|
-
|
293
|
-
|
294
|
-
with contextlib.redirect_stdout(out_file):
|
314
|
+
context.progress.stop() # redirecting stdout clashes with rich progress
|
315
|
+
with redirect_output() as (buffer, stderr_buffer):
|
295
316
|
user_set_parameters = f(
|
296
317
|
**filtered_parameters
|
297
318
|
) # This is a tuple or single value
|
298
|
-
|
319
|
+
|
320
|
+
print(
|
321
|
+
stderr_buffer.getvalue()
|
322
|
+
) # To print the logging statements
|
323
|
+
|
324
|
+
# TODO: Avoid double print!!
|
325
|
+
with task_console.capture():
|
326
|
+
task_console.log(buffer.getvalue())
|
327
|
+
task_console.log(stderr_buffer.getvalue())
|
328
|
+
context.progress.start()
|
299
329
|
except Exception as e:
|
300
330
|
raise exceptions.CommandCallError(
|
301
331
|
f"Function call: {self.command} did not succeed.\n"
|
@@ -478,14 +508,15 @@ class NotebookTaskType(BaseTaskType):
|
|
478
508
|
|
479
509
|
return command
|
480
510
|
|
481
|
-
def get_notebook_output_path(self, map_variable:
|
511
|
+
def get_notebook_output_path(self, map_variable: MapVariableType = None) -> str:
|
482
512
|
tag = ""
|
483
513
|
map_variable = map_variable or {}
|
484
514
|
for key, value in map_variable.items():
|
485
515
|
tag += f"{key}_{value}_"
|
486
516
|
|
487
|
-
if
|
488
|
-
|
517
|
+
if isinstance(self._context, context.PipelineContext):
|
518
|
+
assert self._context.pipeline_executor._context_node
|
519
|
+
tag += self._context.pipeline_executor._context_node.name
|
489
520
|
|
490
521
|
tag = "".join(x for x in tag if x.isalnum()).strip("-")
|
491
522
|
|
@@ -496,7 +527,7 @@ class NotebookTaskType(BaseTaskType):
|
|
496
527
|
|
497
528
|
def execute_command(
|
498
529
|
self,
|
499
|
-
map_variable:
|
530
|
+
map_variable: MapVariableType = None,
|
500
531
|
) -> StepAttempt:
|
501
532
|
"""Execute the python notebook as defined by the command.
|
502
533
|
|
@@ -551,12 +582,20 @@ class NotebookTaskType(BaseTaskType):
|
|
551
582
|
}
|
552
583
|
kwds.update(ploomber_optional_args)
|
553
584
|
|
554
|
-
|
555
|
-
|
585
|
+
context.progress.stop() # redirecting stdout clashes with rich progress
|
586
|
+
|
587
|
+
with redirect_output() as (buffer, stderr_buffer):
|
556
588
|
pm.execute_notebook(**kwds)
|
557
|
-
task_console.print(out_file.getvalue())
|
558
589
|
|
559
|
-
|
590
|
+
print(stderr_buffer.getvalue()) # To print the logging statements
|
591
|
+
|
592
|
+
with task_console.capture():
|
593
|
+
task_console.log(buffer.getvalue())
|
594
|
+
task_console.log(stderr_buffer.getvalue())
|
595
|
+
|
596
|
+
context.progress.start()
|
597
|
+
|
598
|
+
context.run_context.catalog.put(name=notebook_output_path)
|
560
599
|
|
561
600
|
client = PloomberClient.from_path(path=notebook_output_path)
|
562
601
|
namespace = client.get_namespace()
|
@@ -674,7 +713,7 @@ class ShellTaskType(BaseTaskType):
|
|
674
713
|
|
675
714
|
def execute_command(
|
676
715
|
self,
|
677
|
-
map_variable:
|
716
|
+
map_variable: MapVariableType = None,
|
678
717
|
) -> StepAttempt:
|
679
718
|
# Using shell=True as we want to have chained commands to be executed in the same shell.
|
680
719
|
"""Execute the shell command as defined by the command.
|
@@ -698,7 +737,7 @@ class ShellTaskType(BaseTaskType):
|
|
698
737
|
# Expose secrets as environment variables
|
699
738
|
if self.secrets:
|
700
739
|
for key in self.secrets:
|
701
|
-
secret_value = context.run_context.
|
740
|
+
secret_value = context.run_context.secrets.get(key)
|
702
741
|
subprocess_env[key] = secret_value
|
703
742
|
|
704
743
|
try:
|
@@ -724,6 +763,7 @@ class ShellTaskType(BaseTaskType):
|
|
724
763
|
capture = False
|
725
764
|
return_keys = {x.name: x for x in self.returns}
|
726
765
|
|
766
|
+
context.progress.stop() # redirecting stdout clashes with rich progress
|
727
767
|
proc = subprocess.Popen(
|
728
768
|
command,
|
729
769
|
shell=True,
|
@@ -747,6 +787,7 @@ class ShellTaskType(BaseTaskType):
|
|
747
787
|
continue
|
748
788
|
task_console.print(line, style=defaults.warning_style)
|
749
789
|
|
790
|
+
context.progress.start()
|
750
791
|
output_parameters: Dict[str, Parameter] = {}
|
751
792
|
metrics: Dict[str, Parameter] = {}
|
752
793
|
|