runnable 0.35.0-py3-none-any.whl → 0.36.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- extensions/job_executor/__init__.py +3 -4
- extensions/job_executor/emulate.py +106 -0
- extensions/job_executor/k8s.py +8 -8
- extensions/job_executor/local_container.py +13 -14
- extensions/nodes/__init__.py +0 -0
- extensions/nodes/conditional.py +7 -5
- extensions/nodes/fail.py +72 -0
- extensions/nodes/map.py +350 -0
- extensions/nodes/parallel.py +159 -0
- extensions/nodes/stub.py +89 -0
- extensions/nodes/success.py +72 -0
- extensions/nodes/task.py +92 -0
- extensions/pipeline_executor/__init__.py +24 -26
- extensions/pipeline_executor/argo.py +20 -20
- extensions/pipeline_executor/emulate.py +112 -0
- extensions/pipeline_executor/local.py +4 -4
- extensions/pipeline_executor/local_container.py +19 -79
- extensions/pipeline_executor/mocked.py +5 -9
- extensions/pipeline_executor/retry.py +6 -10
- runnable/__init__.py +0 -10
- runnable/catalog.py +1 -21
- runnable/cli.py +0 -59
- runnable/context.py +519 -28
- runnable/datastore.py +51 -54
- runnable/defaults.py +12 -34
- runnable/entrypoints.py +82 -440
- runnable/exceptions.py +35 -34
- runnable/executor.py +13 -20
- runnable/names.py +1 -1
- runnable/nodes.py +16 -15
- runnable/parameters.py +2 -2
- runnable/sdk.py +66 -205
- runnable/tasks.py +62 -81
- runnable/utils.py +6 -268
- {runnable-0.35.0.dist-info → runnable-0.36.1.dist-info}/METADATA +1 -4
- runnable-0.36.1.dist-info/RECORD +72 -0
- {runnable-0.35.0.dist-info → runnable-0.36.1.dist-info}/entry_points.txt +8 -7
- extensions/nodes/nodes.py +0 -778
- extensions/tasks/torch.py +0 -286
- extensions/tasks/torch_config.py +0 -76
- runnable-0.35.0.dist-info/RECORD +0 -66
- {runnable-0.35.0.dist-info → runnable-0.36.1.dist-info}/WHEEL +0 -0
- {runnable-0.35.0.dist-info → runnable-0.36.1.dist-info}/licenses/LICENSE +0 -0
runnable/sdk.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations

+import inspect
 import logging
-import os
 import re
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -16,27 +16,17 @@ from pydantic import (
     field_validator,
     model_validator,
 )
-from rich.progress import (
-    BarColumn,
-    Progress,
-    SpinnerColumn,
-    TextColumn,
-    TimeElapsedColumn,
-)
-from rich.table import Column
 from typing_extensions import Self

 from extensions.nodes.conditional import ConditionalNode
-from extensions.nodes.nodes import (
-    FailNode,
-    MapNode,
-    ParallelNode,
-    StubNode,
-    SuccessNode,
-    TaskNode,
-)
-from runnable import console, defaults, entrypoints, exceptions, graph, utils
-from runnable.executor import BaseJobExecutor, BasePipelineExecutor
+from extensions.nodes.fail import FailNode
+from extensions.nodes.map import MapNode
+from extensions.nodes.parallel import ParallelNode
+from extensions.nodes.stub import StubNode
+from extensions.nodes.success import SuccessNode
+from extensions.nodes.task import TaskNode
+from runnable import defaults, graph
+from runnable.executor import BaseJobExecutor
 from runnable.nodes import TraversalNode
 from runnable.tasks import BaseTaskType as RunnableTask
 from runnable.tasks import TaskReturns
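Note: extensions/nodes/nodes.py (removed in this release, -778 lines in the file list above) was split into one module per node type. A hedged sketch of the import migration for code that reached into the old monolithic module:

    # Before 0.36.x (module deleted in this release):
    # from extensions.nodes.nodes import StubNode, TaskNode

    # From 0.36.x, each node class lives in its own module:
    from extensions.nodes.stub import StubNode
    from extensions.nodes.task import TaskNode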
@@ -50,7 +40,6 @@ StepType = Union[
     "ShellTask",
     "Parallel",
     "Map",
-    "TorchTask",
     "Conditional",
 ]

@@ -196,7 +185,7 @@ class BaseTask(BaseTraversal):
     )

     def as_pipeline(self) -> "Pipeline":
-        return Pipeline(steps=[self])  # type: ignore
+        return Pipeline(steps=[self], name=self.internal_name)  # type: ignore


 class PythonTask(BaseTask):
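Note: a task promoted to a pipeline via as_pipeline() now names the pipeline after the task's internal_name instead of leaving it anonymous. A minimal sketch of the SDK surface (the function= keyword follows the documented PythonTask constructor; treat the exact arguments as illustrative):

    from runnable import PythonTask

    def train():
        print("training")

    # One-step pipeline; in 0.36.x its name follows the task's internal_name.
    pipeline = PythonTask(
        name="train", function=train, terminate_with_success=True
    ).as_pipeline()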
@@ -287,27 +276,6 @@ class PythonTask(BaseTask):
         return node.executable


-class TorchTask(BaseTask):
-    # entrypoint: str = Field(
-    #     alias="entrypoint", default="torch.distributed.run", frozen=True
-    # )
-    # args_to_torchrun: Dict[str, Any] = Field(
-    #     default_factory=dict, alias="args_to_torchrun"
-    # )
-
-    script_to_call: str
-    accelerate_config_file: str
-
-    @computed_field
-    def command_type(self) -> str:
-        return "torch"
-
-    def create_job(self) -> RunnableTask:
-        self.terminate_with_success = True
-        node = self.create_node()
-        return node.executable
-
-
 class NotebookTask(BaseTask):
     """
     An execution node of the pipeline of notebook.
@@ -487,6 +455,9 @@ class Stub(BaseTraversal):

         return StubNode.parse_from_config(self.model_dump(exclude_none=True))

+    def as_pipeline(self) -> "Pipeline":
+        return Pipeline(steps=[self])
+

 class Parallel(BaseTraversal):
     """
@@ -786,6 +757,15 @@ class Pipeline(BaseModel):
                 return False
         return True

+    def get_caller(self) -> str:
+        caller_stack = inspect.stack()[2]
+        relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
+
+        module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
+        module_to_call = f"{module_name}.{caller_stack.function}"
+
+        return module_to_call
+
     def execute(
         self,
         configuration_file: str = "",
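Note: get_caller() walks two frames up the stack (the user function that invoked execute()), relativizes the caller's file against the working directory, and rewrites it as a dotted module reference. A standalone sketch of the same transformation (path and function name are illustrative):

    import re

    def module_reference(relative_path: str, function: str) -> str:
        # "examples/pipeline.py" called from main() -> "examples.pipeline.main"
        module_name = re.sub(r"\b.py\b", "", relative_path.replace("/", "."))
        return f"{module_name}.{function}"

    print(module_reference("examples/pipeline.py", "main"))  # examples.pipeline.main

The pattern keeps the unescaped dot from the source, so it strips any single character followed by "py" at a word boundary, not only a literal ".py".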
@@ -803,106 +783,31 @@ class Pipeline(BaseModel):
             # Immediately return as this call is only for getting the pipeline definition
             return {}

-
-
-        run_id = utils.generate_run_id(run_id=run_id)
-
-        parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
+        from runnable import context

-
+        logger.setLevel(log_level)

-        configuration_file = os.environ.get(
-            "RUNNABLE_CONFIGURATION_FILE", configuration_file
-        )
-        run_context = entrypoints.prepare_configurations(
+        service_configurations = context.ServiceConfigurations(
             configuration_file=configuration_file,
-            run_id=run_id,
-            tag=tag,
-            parameters_file=parameters_file,
-        )
-
-        assert isinstance(run_context.executor, BasePipelineExecutor)
-
-        utils.set_runnable_environment_variables(
-            run_id=run_id, configuration_file=configuration_file, tag=tag
+            execution_context=context.ExecutionContext.PIPELINE,
         )

-
-
-
-
-
-
-
-
-
-            # We are not working with executor that does not work in local environment
-            import inspect
-
-            caller_stack = inspect.stack()[1]
-            relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
-
-            module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
-            module_to_call = f"{module_name}.{caller_stack.function}"
-
-            run_context.pipeline_file = f"{module_to_call}.py"
-            run_context.from_sdk = True
-
-        # Prepare for graph execution
-        run_context.executor._set_up_run_log(exists_ok=False)
-
-        with Progress(
-            SpinnerColumn(spinner_name="runner"),
-            TextColumn(
-                "[progress.description]{task.description}", table_column=Column(ratio=2)
-            ),
-            BarColumn(table_column=Column(ratio=1), style="dark_orange"),
-            TimeElapsedColumn(table_column=Column(ratio=1)),
-            console=console,
-            expand=True,
-        ) as progress:
-            pipeline_execution_task = progress.add_task(
-                "[dark_orange] Starting execution .. ", total=1
-            )
-            try:
-                run_context.progress = progress
-
-                run_context.executor.execute_graph(dag=run_context.dag)
+        configurations = {
+            "pipeline_definition_file": self.get_caller(),
+            "parameters_file": parameters_file,
+            "tag": tag,
+            "run_id": run_id,
+            "execution_mode": context.ExecutionMode.PYTHON,
+            "configuration_file": configuration_file,
+            **service_configurations.services,
+        }

-
-
-                return {}
+        run_context = context.PipelineContext.model_validate(configurations)
+        context.run_context = run_context

-                run_log = run_context.run_log_store.get_run_log_by_id(
-                    run_id=run_context.run_id, full=False
-                )
+        assert isinstance(run_context, context.PipelineContext)

-                if run_log.status == defaults.SUCCESS:
-                    progress.update(
-                        pipeline_execution_task,
-                        description="[green] Success",
-                        completed=True,
-                    )
-                else:
-                    progress.update(
-                        pipeline_execution_task,
-                        description="[red] Failed",
-                        completed=True,
-                    )
-                    raise exceptions.ExecutionFailedError(run_context.run_id)
-            except Exception as e:  # noqa: E722
-                console.print(e, style=defaults.error_style)
-                progress.update(
-                    pipeline_execution_task,
-                    description="[red] Errored execution",
-                    completed=True,
-                )
-                raise
-
-        if run_context.executor._is_local:
-            return run_context.run_log_store.get_run_log_by_id(
-                run_id=run_context.run_id
-            )
+        run_context.execute()


 class BaseJob(BaseModel):
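Note: the SDK no longer drives Rich progress bars and the executor from execute() itself; it assembles a plain configuration mapping and hands control to a validated context object. A condensed sketch of the new flow, using only names visible in this diff (ServiceConfigurations.services is assumed to be a dict of resolved service configurations):

    from runnable import context

    # Roughly what Pipeline.execute() now does in 0.36.x:
    service_configurations = context.ServiceConfigurations(
        configuration_file="config.yaml",  # optional
        execution_context=context.ExecutionContext.PIPELINE,
    )

    configurations = {
        "pipeline_definition_file": "examples.pipeline.main",  # from get_caller()
        "parameters_file": "",
        "tag": "",
        "run_id": "",
        "execution_mode": context.ExecutionMode.PYTHON,
        "configuration_file": "config.yaml",
        **service_configurations.services,  # executor, run log store, catalog, ...
    }

    run_context = context.PipelineContext.model_validate(configurations)
    context.run_context = run_context
    run_context.execute()  # run log setup, progress display and graph traversal now live here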
@@ -926,6 +831,15 @@ class BaseJob(BaseModel):
     def get_task(self) -> RunnableTask:
         raise NotImplementedError

+    def get_caller(self) -> str:
+        caller_stack = inspect.stack()[2]
+        relative_to_root = str(Path(caller_stack.filename).relative_to(Path.cwd()))
+
+        module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
+        module_to_call = f"{module_name}.{caller_stack.function}"
+
+        return module_to_call
+
     def return_catalog_settings(self) -> Optional[List[str]]:
         if self.catalog is None:
             return []
@@ -952,65 +866,32 @@ class BaseJob(BaseModel):
         if self._is_called_for_definition():
             # Immediately return as this call is only for getting the job definition
             return {}
-
-
-        run_id = utils.generate_run_id(run_id=job_id)
-
-        parameters_file = os.environ.get("RUNNABLE_PARAMETERS_FILE", parameters_file)
-
-        tag = os.environ.get("RUNNABLE_tag", tag)
+        from runnable import context

-        configuration_file = os.environ.get(
-            "RUNNABLE_CONFIGURATION_FILE", configuration_file
-        )
+        logger.setLevel(log_level)

-        run_context = entrypoints.prepare_configurations(
+        service_configurations = context.ServiceConfigurations(
             configuration_file=configuration_file,
-            run_id=run_id,
-            tag=tag,
-            parameters_file=parameters_file,
-            is_job=True,
-        )
-
-        assert isinstance(run_context.executor, BaseJobExecutor)
-        run_context.from_sdk = True
-
-        utils.set_runnable_environment_variables(
-            run_id=run_id, configuration_file=configuration_file, tag=tag
+            execution_context=context.ExecutionContext.JOB,
         )

-
-
-
-
-
-
-
-
-
-
-
-        module_name = re.sub(r"\b.py\b", "", relative_to_root.replace("/", "."))
-        module_to_call = f"{module_name}.{caller_stack.function}"
-
-        run_context.job_definition_file = f"{module_to_call}.py"
-
-        job = self.get_task()
-        catalog_settings = self.return_catalog_settings()
+        configurations = {
+            "job_definition_file": self.get_caller(),
+            "parameters_file": parameters_file,
+            "tag": tag,
+            "run_id": job_id,
+            "execution_mode": context.ExecutionMode.PYTHON,
+            "configuration_file": configuration_file,
+            "job": self.get_task(),
+            "catalog_settings": self.return_catalog_settings(),
+            **service_configurations.services,
+        }

-        try:
-            run_context.executor.submit_job(job, catalog_settings=catalog_settings)
-        finally:
-            run_context.executor.add_task_log_to_catalog("job")
+        run_context = context.JobContext.model_validate(configurations)

-
-            "Executing the job from the user. We are still in the caller's compute environment"
-        )
+        assert isinstance(run_context.job_executor, BaseJobExecutor)

-
-        return run_context.run_log_store.get_run_log_by_id(
-            run_id=run_context.run_id
-        )
+        run_context.execute()


 class PythonJob(BaseJob):
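Note: jobs follow the same restructuring; the configurations mapping gains "job" and "catalog_settings" entries and is validated into a context.JobContext instead of a PipelineContext. At the SDK surface the call pattern is unchanged. A hedged usage sketch (constructor arguments assumed from the documented PythonJob API):

    from runnable import PythonJob

    def train_model():
        print("training...")

    job = PythonJob(function=train_model)
    job.execute()  # builds a JobContext from the configurations dict and runs it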
@@ -1034,26 +915,6 @@ class PythonJob(BaseJob):
         return task.create_node().executable


-class TorchJob(BaseJob):
-    # entrypoint: str = Field(default="torch.distributed.run", frozen=True)
-    # args_to_torchrun: dict[str, str | bool | int | float] = Field(
-    #     default_factory=dict
-    # )  # For example
-    # {"nproc_per_node": 2, "nnodes": 1,}
-
-    script_to_call: str  # For example train/script.py
-    accelerate_config_file: str
-
-    def get_task(self) -> RunnableTask:
-        # Piggy bank on existing tasks as a hack
-        task = TorchTask(
-            name="dummy",
-            terminate_with_success=True,
-            **self.model_dump(exclude_defaults=True, exclude_none=True),
-        )
-        return task.create_node().executable
-
-
 class NotebookJob(BaseJob):
     notebook: str = Field(serialization_alias="command")
     optional_ploomber_args: Optional[Dict[str, Any]] = Field(
runnable/tasks.py
CHANGED
@@ -25,7 +25,7 @@ from runnable.datastore import (
     Parameter,
     StepAttempt,
 )
-from runnable.defaults import
+from runnable.defaults import MapVariableType

 logger = logging.getLogger(defaults.LOGGER_NAME)

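Note: every execute_command signature below now takes map_variable: MapVariableType = None. The alias itself is defined in runnable/defaults.py and not shown in this diff; judging from the concrete annotation on the removed TorchTaskType (Dict[str, str | int | float] | None), it is presumably close to:

    # Assumed shape of the alias in runnable/defaults.py (not part of this diff):
    from typing import Dict, Optional, Union

    MapVariableType = Optional[Dict[str, Union[str, int, float]]]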
@@ -48,7 +48,29 @@ class TeeIO(io.StringIO):
         self.output_stream.flush()


-
+@contextlib.contextmanager
+def redirect_output():
+    # Set the stream handlers to use the custom TeeIO class
+
+    # Backup the original stdout and stderr
+    original_stdout = sys.stdout
+    original_stderr = sys.stderr
+
+    # Redirect stdout and stderr to custom TeeStream objects
+    sys.stdout = TeeIO(sys.stdout)
+    sys.stderr = TeeIO(sys.stderr)
+
+    # Replace stream for all StreamHandlers to use the new sys.stdout
+    for handler in logging.getLogger().handlers:
+        if isinstance(handler, logging.StreamHandler):
+            handler.stream = sys.stdout
+
+    try:
+        yield sys.stdout, sys.stderr
+    finally:
+        # Restore the original stdout and stderr
+        sys.stdout = original_stdout
+        sys.stderr = original_stderr


 class TaskReturns(BaseModel):
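Note: the tee pattern above captures everything a task writes while still echoing it to the real streams, so output can be logged to the run log and shown live. A self-contained sketch of the idea (TeeIO's write/flush forwarding is assumed from its usage here):

    import contextlib
    import io
    import sys

    class TeeIO(io.StringIO):
        """Buffer writes while forwarding them to the wrapped stream."""

        def __init__(self, output_stream):
            super().__init__()
            self.output_stream = output_stream

        def write(self, s):
            self.output_stream.write(s)  # echo to the original stream
            return super().write(s)      # and keep a copy in the buffer

        def flush(self):
            self.output_stream.flush()

    @contextlib.contextmanager
    def redirect_output():
        original_stdout, original_stderr = sys.stdout, sys.stderr
        sys.stdout, sys.stderr = TeeIO(sys.stdout), TeeIO(sys.stderr)
        try:
            yield sys.stdout, sys.stderr
        finally:
            sys.stdout, sys.stderr = original_stdout, original_stderr

    with redirect_output() as (out, err):
        print("hello from the task")

    print("captured:", out.getvalue())  # -> hello from the task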
@@ -79,7 +101,7 @@ class BaseTaskType(BaseModel):
     def set_secrets_as_env_variables(self):
         # Preparing the environment for the task execution
         for key in self.secrets:
-            secret_value = context.run_context.
+            secret_value = context.run_context.secrets.get(key)
             os.environ[key] = secret_value

     def delete_secrets_from_env_variables(self):
@@ -90,7 +112,7 @@ class BaseTaskType(BaseModel):

     def execute_command(
         self,
-        map_variable:
+        map_variable: MapVariableType = None,
     ) -> StepAttempt:
         """The function to execute the command.

@@ -130,7 +152,7 @@ class BaseTaskType(BaseModel):
         finally:
             self.delete_secrets_from_env_variables()

-    def resolve_unreduced_parameters(self, map_variable:
+    def resolve_unreduced_parameters(self, map_variable: MapVariableType = None):
         """Resolve the unreduced parameters."""
         params = self._context.run_log_store.get_parameters(
             run_id=self._context.run_id
@@ -153,7 +175,7 @@ class BaseTaskType(BaseModel):

     @contextlib.contextmanager
     def execution_context(
-        self, map_variable:
+        self, map_variable: MapVariableType = None, allow_complex: bool = True
     ):
         params = self.resolve_unreduced_parameters(map_variable=map_variable)
         logger.info(f"Parameters available for the execution: {params}")
@@ -267,7 +289,7 @@ class PythonTaskType(BaseTaskType):  # pylint: disable=too-few-public-methods

     def execute_command(
         self,
-        map_variable:
+        map_variable: MapVariableType = None,
     ) -> StepAttempt:
         """Execute the notebook as defined by the command."""
         attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
@@ -289,13 +311,21 @@ class PythonTaskType(BaseTaskType):  # pylint: disable=too-few-public-methods
                 logger.info(
                     f"Calling {func} from {module} with {filtered_parameters}"
                 )
-                out_file = TeeIO()
-
-                with contextlib.redirect_stdout(out_file):
+                context.progress.stop()  # redirecting stdout clashes with rich progress
+                with redirect_output() as (buffer, stderr_buffer):
                     user_set_parameters = f(
                         **filtered_parameters
                     )  # This is a tuple or single value
-
+
+                print(
+                    stderr_buffer.getvalue()
+                )  # To print the logging statements
+
+                # TODO: Avoid double print!!
+                with task_console.capture():
+                    task_console.log(buffer.getvalue())
+                    task_console.log(stderr_buffer.getvalue())
+                context.progress.start()
             except Exception as e:
                 raise exceptions.CommandCallError(
                     f"Function call: {self.command} did not succeed.\n"
@@ -354,66 +384,6 @@ class PythonTaskType(BaseTaskType):  # pylint: disable=too-few-public-methods
         return attempt_log


-class TorchTaskType(BaseTaskType):
-    task_type: str = Field(default="torch", serialization_alias="command_type")
-    accelerate_config_file: str
-
-    script_to_call: str  # For example train/script.py
-
-    def execute_command(
-        self, map_variable: Dict[str, str | int | float] | None = None
-    ) -> StepAttempt:
-        from accelerate.commands import launch
-
-        attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
-
-        with (
-            self.execution_context(
-                map_variable=map_variable, allow_complex=False
-            ) as params,
-            self.expose_secrets() as _,
-        ):
-            try:
-                script_args = []
-                for key, value in params.items():
-                    script_args.append(f"--{key}")
-                    if type(value.value) is not bool:
-                        script_args.append(str(value.value))
-
-                # TODO: Check the typing here
-
-                logger.info("Calling the user script with the following parameters:")
-                logger.info(script_args)
-                out_file = TeeIO()
-                try:
-                    with contextlib.redirect_stdout(out_file):
-                        parser = launch.launch_command_parser()
-                        args = parser.parse_args(self.script_to_call)
-                        args.training_script = self.script_to_call
-                        args.config_file = self.accelerate_config_file
-                        args.training_script_args = script_args
-
-                        launch.launch_command(args)
-                    task_console.print(out_file.getvalue())
-                except Exception as e:
-                    raise exceptions.CommandCallError(
-                        f"Call to script{self.script_to_call} did not succeed."
-                    ) from e
-                finally:
-                    sys.argv = sys.argv[:1]
-
-                attempt_log.status = defaults.SUCCESS
-            except Exception as _e:
-                msg = f"Call to script: {self.script_to_call} did not succeed."
-                attempt_log.message = msg
-                task_console.print_exception(show_locals=False)
-                task_console.log(_e, style=defaults.error_style)
-
-        attempt_log.end_time = str(datetime.now())
-
-        return attempt_log
-
-
 class NotebookTaskType(BaseTaskType):
     """
     --8<-- [start:notebook_reference]
@@ -478,14 +448,15 @@ class NotebookTaskType(BaseTaskType):

         return command

-    def get_notebook_output_path(self, map_variable:
+    def get_notebook_output_path(self, map_variable: MapVariableType = None) -> str:
         tag = ""
         map_variable = map_variable or {}
         for key, value in map_variable.items():
             tag += f"{key}_{value}_"

-        if
-
+        if isinstance(self._context, context.PipelineContext):
+            assert self._context.pipeline_executor._context_node
+            tag += self._context.pipeline_executor._context_node.name

         tag = "".join(x for x in tag if x.isalnum()).strip("-")

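Note: the output notebook's name is tagged with the map variables and, under a pipeline run, the current node's name, then reduced to alphanumerics. The tagging logic in isolation (the node name is hypothetical):

    map_variable = {"chunk": 1}
    tag = ""
    for key, value in map_variable.items():
        tag += f"{key}_{value}_"
    tag += "process-chunk"  # stand-in for _context_node.name
    tag = "".join(x for x in tag if x.isalnum()).strip("-")
    print(tag)  # chunk1processchunk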
@@ -496,7 +467,7 @@ class NotebookTaskType(BaseTaskType):

     def execute_command(
         self,
-        map_variable:
+        map_variable: MapVariableType = None,
     ) -> StepAttempt:
         """Execute the python notebook as defined by the command.

@@ -551,12 +522,20 @@ class NotebookTaskType(BaseTaskType):
             }
             kwds.update(ploomber_optional_args)

-            out_file = TeeIO()
-            with contextlib.redirect_stdout(out_file):
+            context.progress.stop()  # redirecting stdout clashes with rich progress
+
+            with redirect_output() as (buffer, stderr_buffer):
                 pm.execute_notebook(**kwds)
-            task_console.print(out_file.getvalue())

-
+            print(stderr_buffer.getvalue())  # To print the logging statements
+
+            with task_console.capture():
+                task_console.log(buffer.getvalue())
+                task_console.log(stderr_buffer.getvalue())
+
+            context.progress.start()
+
+            context.run_context.catalog.put(name=notebook_output_path)

             client = PloomberClient.from_path(path=notebook_output_path)
             namespace = client.get_namespace()
@@ -674,7 +653,7 @@ class ShellTaskType(BaseTaskType):

     def execute_command(
         self,
-        map_variable:
+        map_variable: MapVariableType = None,
     ) -> StepAttempt:
         # Using shell=True as we want to have chained commands to be executed in the same shell.
         """Execute the shell command as defined by the command.
@@ -698,7 +677,7 @@ class ShellTaskType(BaseTaskType):
         # Expose secrets as environment variables
         if self.secrets:
             for key in self.secrets:
-                secret_value = context.run_context.
+                secret_value = context.run_context.secrets.get(key)
                 subprocess_env[key] = secret_value

         try:
@@ -724,6 +703,7 @@ class ShellTaskType(BaseTaskType):
             capture = False
             return_keys = {x.name: x for x in self.returns}

+            context.progress.stop()  # redirecting stdout clashes with rich progress
             proc = subprocess.Popen(
                 command,
                 shell=True,
@@ -747,6 +727,7 @@ class ShellTaskType(BaseTaskType):
                         continue
                     task_console.print(line, style=defaults.warning_style)

+            context.progress.start()
             output_parameters: Dict[str, Parameter] = {}
             metrics: Dict[str, Parameter] = {}
