runnable 0.50.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- extensions/README.md +0 -0
- extensions/__init__.py +0 -0
- extensions/catalog/README.md +0 -0
- extensions/catalog/any_path.py +214 -0
- extensions/catalog/file_system.py +52 -0
- extensions/catalog/minio.py +72 -0
- extensions/catalog/pyproject.toml +14 -0
- extensions/catalog/s3.py +11 -0
- extensions/job_executor/README.md +0 -0
- extensions/job_executor/__init__.py +236 -0
- extensions/job_executor/emulate.py +70 -0
- extensions/job_executor/k8s.py +553 -0
- extensions/job_executor/k8s_job_spec.yaml +37 -0
- extensions/job_executor/local.py +35 -0
- extensions/job_executor/local_container.py +161 -0
- extensions/job_executor/pyproject.toml +16 -0
- extensions/nodes/README.md +0 -0
- extensions/nodes/__init__.py +0 -0
- extensions/nodes/conditional.py +301 -0
- extensions/nodes/fail.py +78 -0
- extensions/nodes/loop.py +394 -0
- extensions/nodes/map.py +477 -0
- extensions/nodes/parallel.py +281 -0
- extensions/nodes/pyproject.toml +15 -0
- extensions/nodes/stub.py +93 -0
- extensions/nodes/success.py +78 -0
- extensions/nodes/task.py +156 -0
- extensions/pipeline_executor/README.md +0 -0
- extensions/pipeline_executor/__init__.py +871 -0
- extensions/pipeline_executor/argo.py +1266 -0
- extensions/pipeline_executor/emulate.py +119 -0
- extensions/pipeline_executor/local.py +226 -0
- extensions/pipeline_executor/local_container.py +369 -0
- extensions/pipeline_executor/mocked.py +159 -0
- extensions/pipeline_executor/pyproject.toml +16 -0
- extensions/run_log_store/README.md +0 -0
- extensions/run_log_store/__init__.py +0 -0
- extensions/run_log_store/any_path.py +100 -0
- extensions/run_log_store/chunked_fs.py +122 -0
- extensions/run_log_store/chunked_minio.py +141 -0
- extensions/run_log_store/file_system.py +91 -0
- extensions/run_log_store/generic_chunked.py +549 -0
- extensions/run_log_store/minio.py +114 -0
- extensions/run_log_store/pyproject.toml +15 -0
- extensions/secrets/README.md +0 -0
- extensions/secrets/dotenv.py +62 -0
- extensions/secrets/pyproject.toml +15 -0
- runnable/__init__.py +108 -0
- runnable/catalog.py +141 -0
- runnable/cli.py +484 -0
- runnable/context.py +730 -0
- runnable/datastore.py +1058 -0
- runnable/defaults.py +159 -0
- runnable/entrypoints.py +390 -0
- runnable/exceptions.py +137 -0
- runnable/executor.py +561 -0
- runnable/gantt.py +1646 -0
- runnable/graph.py +501 -0
- runnable/names.py +546 -0
- runnable/nodes.py +593 -0
- runnable/parameters.py +217 -0
- runnable/pickler.py +96 -0
- runnable/sdk.py +1277 -0
- runnable/secrets.py +92 -0
- runnable/tasks.py +1268 -0
- runnable/telemetry.py +142 -0
- runnable/utils.py +423 -0
- runnable-0.50.0.dist-info/METADATA +189 -0
- runnable-0.50.0.dist-info/RECORD +72 -0
- runnable-0.50.0.dist-info/WHEEL +4 -0
- runnable-0.50.0.dist-info/entry_points.txt +53 -0
- runnable-0.50.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import Field, PrivateAttr
|
|
6
|
+
|
|
7
|
+
from extensions.job_executor import GenericJobExecutor
|
|
8
|
+
from runnable import context, defaults
|
|
9
|
+
from runnable.tasks import BaseTaskType
|
|
10
|
+
|
|
11
|
+
# Module logger; uses the shared runnable logger name from defaults.LOGGER_NAME.
logger = logging.getLogger(defaults.LOGGER_NAME)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LocalContainerJobExecutor(GenericJobExecutor):
|
|
15
|
+
"""
|
|
16
|
+
The LocalJobExecutor is a job executor that runs the job locally.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
service_name: str = "local-container"
|
|
20
|
+
docker_image: str
|
|
21
|
+
auto_remove_container: bool = True
|
|
22
|
+
environment: Dict[str, str] = Field(default_factory=dict)
|
|
23
|
+
|
|
24
|
+
_should_setup_run_log_at_traversal: bool = PrivateAttr(default=True)
|
|
25
|
+
|
|
26
|
+
_container_log_location = "/tmp/run_logs/"
|
|
27
|
+
_container_catalog_location = "/tmp/catalog/"
|
|
28
|
+
_container_secrets_location = "/tmp/dotenv"
|
|
29
|
+
_volumes: Dict[str, Dict[str, str]] = {}
|
|
30
|
+
|
|
31
|
+
def submit_job(self, job: BaseTaskType, catalog_settings=Optional[List[str]]):
|
|
32
|
+
"""
|
|
33
|
+
This method gets invoked by the CLI.
|
|
34
|
+
"""
|
|
35
|
+
self._set_up_run_log()
|
|
36
|
+
self._mount_volumes()
|
|
37
|
+
|
|
38
|
+
# Call the container job
|
|
39
|
+
job_log = self._context.run_log_store.create_job_log()
|
|
40
|
+
self._context.run_log_store.add_job_log(
|
|
41
|
+
run_id=self._context.run_id, job_log=job_log
|
|
42
|
+
)
|
|
43
|
+
self.spin_container()
|
|
44
|
+
|
|
45
|
+
def execute_job(self, job: BaseTaskType, catalog_settings=Optional[List[str]]):
|
|
46
|
+
"""
|
|
47
|
+
This method gets invoked by the CLI.
|
|
48
|
+
"""
|
|
49
|
+
self._use_volumes()
|
|
50
|
+
super().execute_job(job, catalog_settings=catalog_settings)
|
|
51
|
+
|
|
52
|
+
def spin_container(self):
|
|
53
|
+
"""
|
|
54
|
+
This method spins up the container
|
|
55
|
+
"""
|
|
56
|
+
import docker # pylint: disable=C0415
|
|
57
|
+
|
|
58
|
+
try:
|
|
59
|
+
client = docker.from_env()
|
|
60
|
+
api_client = docker.APIClient()
|
|
61
|
+
except Exception as ex:
|
|
62
|
+
logger.exception("Could not get access to docker")
|
|
63
|
+
raise Exception(
|
|
64
|
+
"Could not get the docker socket file, do you have docker installed?"
|
|
65
|
+
) from ex
|
|
66
|
+
|
|
67
|
+
try:
|
|
68
|
+
assert isinstance(self._context, context.JobContext)
|
|
69
|
+
command = self._context.get_job_callable_command()
|
|
70
|
+
logger.info(f"Running the command {command}")
|
|
71
|
+
|
|
72
|
+
docker_image = self.docker_image
|
|
73
|
+
environment = self.environment
|
|
74
|
+
|
|
75
|
+
container = client.containers.create(
|
|
76
|
+
image=docker_image,
|
|
77
|
+
command=command,
|
|
78
|
+
auto_remove=False,
|
|
79
|
+
volumes=self._volumes,
|
|
80
|
+
environment=environment,
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
container.start()
|
|
84
|
+
stream = api_client.logs(
|
|
85
|
+
container=container.id, timestamps=True, stream=True, follow=True
|
|
86
|
+
)
|
|
87
|
+
while True:
|
|
88
|
+
try:
|
|
89
|
+
output = next(stream).decode("utf-8")
|
|
90
|
+
output = output.strip("\r\n")
|
|
91
|
+
logger.info(output)
|
|
92
|
+
print(output)
|
|
93
|
+
except StopIteration:
|
|
94
|
+
logger.info("Docker Run completed")
|
|
95
|
+
break
|
|
96
|
+
|
|
97
|
+
exit_status = api_client.inspect_container(container.id)["State"][
|
|
98
|
+
"ExitCode"
|
|
99
|
+
]
|
|
100
|
+
|
|
101
|
+
if self.auto_remove_container:
|
|
102
|
+
container.remove(force=True)
|
|
103
|
+
|
|
104
|
+
if exit_status != 0:
|
|
105
|
+
msg = f"Docker command failed with exit code {exit_status}"
|
|
106
|
+
raise Exception(msg)
|
|
107
|
+
|
|
108
|
+
except Exception as _e:
|
|
109
|
+
logger.exception("Problems with spinning/running the container")
|
|
110
|
+
raise _e
|
|
111
|
+
|
|
112
|
+
def _mount_volumes(self):
|
|
113
|
+
"""
|
|
114
|
+
Mount the volumes for the container
|
|
115
|
+
"""
|
|
116
|
+
match self._context.run_log_store.service_name:
|
|
117
|
+
case "file-system":
|
|
118
|
+
write_to = self._context.run_log_store.log_folder
|
|
119
|
+
self._volumes[str(Path(write_to).resolve())] = {
|
|
120
|
+
"bind": f"{self._container_log_location}",
|
|
121
|
+
"mode": "rw",
|
|
122
|
+
}
|
|
123
|
+
case "chunked-fs":
|
|
124
|
+
write_to = self._context.run_log_store.log_folder
|
|
125
|
+
self._volumes[str(Path(write_to).resolve())] = {
|
|
126
|
+
"bind": f"{self._container_log_location}",
|
|
127
|
+
"mode": "rw",
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
match self._context.catalog.service_name:
|
|
131
|
+
case "file-system":
|
|
132
|
+
catalog_location = self._context.catalog.catalog_location
|
|
133
|
+
self._volumes[str(Path(catalog_location).resolve())] = {
|
|
134
|
+
"bind": f"{self._container_catalog_location}",
|
|
135
|
+
"mode": "rw",
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
match self._context.secrets.service_name:
|
|
139
|
+
case "dotenv":
|
|
140
|
+
secrets_location = self._context.secrets.location
|
|
141
|
+
self._volumes[str(Path(secrets_location).resolve())] = {
|
|
142
|
+
"bind": f"{self._container_secrets_location}",
|
|
143
|
+
"mode": "ro",
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
def _use_volumes(self):
|
|
147
|
+
match self._context.run_log_store.service_name:
|
|
148
|
+
case "file-system":
|
|
149
|
+
self._context.run_log_store.log_folder = self._container_log_location
|
|
150
|
+
case "chunked-fs":
|
|
151
|
+
self._context.run_log_store.log_folder = self._container_log_location
|
|
152
|
+
|
|
153
|
+
match self._context.catalog.service_name:
|
|
154
|
+
case "file-system":
|
|
155
|
+
self._context.catalog.catalog_location = (
|
|
156
|
+
self._container_catalog_location
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
match self._context.secrets.service_name:
|
|
160
|
+
case "dotenv":
|
|
161
|
+
self._context.secrets.location = self._container_secrets_location
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
[project]
name = "job_executor"
version = "0.0.0"
# Placeholder text replaced with an actual description of this sub-package.
description = "Job executor extensions for runnable"
readme = "README.md"
requires-python = ">=3.10"
dependencies = []


[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"


[tool.hatch.build.targets.wheel]
packages = ["."]
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from typing import Any, Optional, cast
|
|
4
|
+
|
|
5
|
+
from pydantic import Field, field_serializer, field_validator
|
|
6
|
+
|
|
7
|
+
from runnable import console, defaults, exceptions
|
|
8
|
+
from runnable.datastore import Parameter
|
|
9
|
+
from runnable.defaults import IterableParameterModel
|
|
10
|
+
from runnable.graph import Graph, create_graph
|
|
11
|
+
from runnable.nodes import CompositeNode
|
|
12
|
+
|
|
13
|
+
# Module logger; uses the shared runnable logger name from defaults.LOGGER_NAME.
logger = logging.getLogger(defaults.LOGGER_NAME)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ConditionalNode(CompositeNode):
    """
    A composite node that executes the branch selected by a parameter's value.

    parameter: name -> the parameter which is used for evaluation
    default: Optional[branch] = branch to execute if nothing is matched.
    branches: {
        "case1": branch1,
        "case2": branch2,
    }

    Conceptually this is equal to:
    match parameter:
        case "case1":
            branch1
        case "case2":
            branch2
        case _:
            default

    A branch is selected when ``str(<parameter value>)`` equals the last
    dot-separated part of the branch's key. ``branches`` is keyed by
    ``"<internal_name>.<case>"`` (see ``parse_from_config``).

    NOTE(review): ``default`` is stored and summarised but is not consulted by
    ``fan_out``, ``execute_as_graph`` or ``fan_in``. When no branch matches,
    ``fan_out`` raises instead of running a default branch.
    """

    node_type: str = Field(default="conditional", serialization_alias="type")

    parameter: str  # the name of the parameter should be isalnum
    default: Graph | None = Field(default=None)  # TODO: Think about the design of this
    branches: dict[str, Graph]
    # The keys of the branches should be isalnum()

    @field_validator("parameter", mode="after")
    @classmethod
    def check_parameter(cls, parameter: str) -> str:
        """
        Validate that the parameter name is alphanumeric.

        Args:
            parameter (str): The parameter name to validate.

        Raises:
            ValueError: If the parameter name is not alphanumeric.

        Returns:
            str: The validated parameter name.
        """
        if not parameter.isalnum():
            raise ValueError(f"Parameter '{parameter}' must be alphanumeric.")
        return parameter

    @staticmethod
    def _branch_matches(
        internal_branch_name: str, parameter_value: str | int | bool | float
    ) -> bool:
        """Return True if the branch named ``internal_branch_name`` is selected by ``parameter_value``."""
        # The match is done on the last part of the branch name,
        # e.g. "node.case1" is selected by the value "case1".
        return str(parameter_value) == internal_branch_name.split(".")[-1]

    def get_parameter_value(self) -> str | int | bool | float:
        """
        Get the value of ``self.parameter`` from the run's parameters.

        Returns:
            str | int | bool | float: The value of the parameter.

        Raises:
            Exception: If the parameter is not present in the run's parameters.
            TypeError: If the value is not an int, float, bool or str.
        """
        parameters: dict[str, Parameter] = self._context.run_log_store.get_parameters(
            run_id=self._context.run_id
        )

        if self.parameter not in parameters:
            raise Exception(f"Parameter {self.parameter} not found in parameters")

        chosen_parameter_value = parameters[self.parameter].get_value()

        # Explicit check rather than ``assert`` so the validation still runs
        # under ``python -O``.
        if not isinstance(chosen_parameter_value, (int, float, bool, str)):
            raise TypeError(
                f"Parameter '{self.parameter}' must be of type int, float, bool, or str, "
                f"but got {type(chosen_parameter_value).__name__}."
            )

        return chosen_parameter_value

    def get_summary(self) -> dict[str, Any]:
        """Return a summary with the node's name, type, branches, parameter and default."""
        summary = {
            "name": self.name,
            "type": self.node_type,
            "branches": [branch.get_summary() for branch in self.branches.values()],
            "parameter": self.parameter,
            "default": self.default.get_summary() if self.default else None,
        }

        return summary

    @field_serializer("branches")
    def ser_branches(self, branches: dict[str, Graph]) -> dict[str, Graph]:
        """Serialize branches keyed by their short (case) name, not the full internal name."""
        ret: dict[str, Graph] = {}

        for branch_name, branch in branches.items():
            ret[branch_name.split(".")[-1]] = branch

        return ret

    @classmethod
    def parse_from_config(cls, config: dict[str, Any]) -> "ConditionalNode":
        """
        Build a ConditionalNode from its configuration.

        Each entry of ``config["branches"]`` becomes a sub-graph keyed by
        ``"<internal_name>.<branch name>"``. The ``branches`` key is popped
        from ``config``.

        Raises:
            Exception: If no branches are configured.
        """
        internal_name = cast(str, config.get("internal_name"))

        config_branches = config.pop("branches", {})
        branches = {}
        for branch_name, branch_config in config_branches.items():
            sub_graph = create_graph(
                deepcopy(branch_config),
                internal_branch_name=internal_name + "." + branch_name,
            )
            branches[internal_name + "." + branch_name] = sub_graph

        if not branches:
            raise Exception("A conditional node should have branches")
        return cls(branches=branches, **config)

    def _get_branch_by_name(self, branch_name: str) -> Graph:
        """Return the branch keyed by ``branch_name`` (its full internal name)."""
        if branch_name in self.branches:
            return self.branches[branch_name]

        raise Exception(f"Branch {branch_name} does not exist")

    def fan_out(
        self,
        iter_variable: Optional[IterableParameterModel] = None,
    ):
        """
        Create (or reuse) the branch log for the selected branch and mark it PROCESSING.

        This method is restricted to creating branch logs.

        Raises:
            Exception: If no branch matches the parameter value.
        """
        parameter_value = self.get_parameter_value()

        hit_once = False

        for internal_branch_name in self.branches:
            if not self._branch_matches(internal_branch_name, parameter_value):
                # Need not create a branch log for this branch
                continue

            effective_branch_name = self._resolve_map_placeholders(
                internal_branch_name, iter_variable=iter_variable
            )

            hit_once = True
            try:
                branch_log = self._context.run_log_store.get_branch_log(
                    effective_branch_name, self._context.run_id
                )
                console.print(f"Branch log already exists for {effective_branch_name}")
            except exceptions.BranchLogNotFoundError:
                branch_log = self._context.run_log_store.create_branch_log(
                    effective_branch_name
                )
                console.print(f"Branch log created for {effective_branch_name}")

            branch_log.status = defaults.PROCESSING
            self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)

        if not hit_once:
            raise Exception(
                "None of the branches were true. Please check your evaluate statements"
            )

    def execute_as_graph(
        self,
        iter_variable: Optional[IterableParameterModel] = None,
    ):
        """
        Execute the branch selected by the parameter value.

        The flow has three steps:
        - ``fan_out`` prepares the branch log.
        - The matching branch runs through the pipeline executor.
        - ``fan_in`` settles the step status and rolls parameters back.

        From a design perspective, this function should not be called if the execution is 3rd party orchestrated.

        The modes that render the job specifications do not need to interact with this node at all, as they have
        their own internal mechanisms of handling conditional states.

        Only fail state is considered failure during this phase of execution.

        Args:
            iter_variable (IterableParameterModel, optional): Iteration context passed
                through to the executor and used to resolve branch names. Defaults to None.
        """
        self.fan_out(iter_variable=iter_variable)
        parameter_value = self.get_parameter_value()

        for internal_branch_name, branch in self.branches.items():
            if self._branch_matches(internal_branch_name, parameter_value):
                # if the condition is met, execute the graph
                logger.debug(f"Executing graph for {branch}")
                self._context.pipeline_executor.execute_graph(
                    branch, iter_variable=iter_variable
                )

        self.fan_in(iter_variable=iter_variable)

    def fan_in(
        self,
        iter_variable: Optional[IterableParameterModel] = None,
    ):
        """
        Set this node's step status from the executed branch and roll its parameters back.

        The step log is SUCCESS unless the selected branch's log is not SUCCESS.
        On success, the executed branch's parameters are merged into the parent
        scope (``self.internal_branch_name``), with branch values overwriting
        parent values.

        3rd party orchestrators should use this method to find the status of the composite step.

        Args:
            iter_variable (IterableParameterModel, optional): Iteration context used to
                resolve placeholders in step/branch names. Defaults to None.
        """
        effective_internal_name = self._resolve_map_placeholders(
            self.internal_name, iter_variable=iter_variable
        )

        step_success_bool: bool = True
        parameter_value = self.get_parameter_value()
        executed_branch_name = None

        for internal_branch_name in self.branches:
            if not self._branch_matches(internal_branch_name, parameter_value):
                # The branch would not have been executed
                continue

            effective_branch_name = self._resolve_map_placeholders(
                internal_branch_name, iter_variable=iter_variable
            )
            executed_branch_name = effective_branch_name

            branch_log = self._context.run_log_store.get_branch_log(
                effective_branch_name, self._context.run_id
            )

            if branch_log.status != defaults.SUCCESS:
                step_success_bool = False

        step_log = self._context.run_log_store.get_step_log(
            effective_internal_name, self._context.run_id
        )

        step_log.status = defaults.SUCCESS if step_success_bool else defaults.FAIL

        self._context.run_log_store.add_step_log(step_log, self._context.run_id)

        # If we failed, return without parameter rollback
        if not step_success_bool:
            return

        # Roll back parameters from executed branch to parent scope
        if executed_branch_name:
            parent_params = self._context.run_log_store.get_parameters(
                self._context.run_id, internal_branch_name=self.internal_branch_name
            )

            branch_params = self._context.run_log_store.get_parameters(
                self._context.run_id, internal_branch_name=executed_branch_name
            )

            # Merge branch parameters into parent (overwrite with branch values)
            parent_params.update(branch_params)

            self._context.run_log_store.set_parameters(
                parameters=parent_params,
                run_id=self._context.run_id,
                internal_branch_name=self.internal_branch_name,
            )

    async def execute_as_graph_async(
        self,
        iter_variable: Optional[IterableParameterModel] = None,
    ):
        """
        Async conditional execution.

        This is the same flow as ``execute_as_graph``, except the selected
        branch is awaited via ``execute_graph_async``. ``fan_out`` and
        ``fan_in`` remain synchronous.
        """
        self.fan_out(iter_variable=iter_variable)  # sync
        parameter_value = self.get_parameter_value()

        for internal_branch_name, branch in self.branches.items():
            if self._branch_matches(internal_branch_name, parameter_value):
                logger.debug(f"Executing graph for {branch}")
                await self._context.pipeline_executor.execute_graph_async(
                    branch, iter_variable=iter_variable
                )

        self.fan_in(iter_variable=iter_variable)  # sync
|
extensions/nodes/fail.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Any, Dict, Optional, cast
|
|
3
|
+
|
|
4
|
+
from pydantic import Field
|
|
5
|
+
|
|
6
|
+
from runnable import datastore, defaults
|
|
7
|
+
from runnable.datastore import StepLog
|
|
8
|
+
from runnable.defaults import IterableParameterModel
|
|
9
|
+
from runnable.nodes import TerminalNode
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class FailNode(TerminalNode):
    """
    Terminal node that marks the run or branch containing it as failed.

    The node's own step attempt is recorded with SUCCESS status. The failure
    is written to the enclosing run/branch log instead.
    """

    node_type: str = Field(default="fail", serialization_alias="type")

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "FailNode":
        """Build a FailNode from config using the TerminalNode parser."""
        node = super().parse_from_config(config)
        return cast("FailNode", node)

    def get_summary(self) -> Dict[str, Any]:
        """Return a summary holding the node's name and type."""
        return {"name": self.name, "type": self.node_type}

    def execute(
        self,
        mock=False,
        iter_variable: Optional[IterableParameterModel] = None,
        attempt_number: int = 1,
    ) -> StepLog:
        """
        Execute the failure node.

        It records an attempt for this step and sets the enclosing run/branch
        log status to FAIL.

        Args:
            mock (bool, optional): Accepted for interface compatibility; not used here.
                Defaults to False.
            iter_variable (IterableParameterModel, optional): Iteration context used to
                resolve the step and branch log names. Defaults to None.
            attempt_number (int, optional): Number stored on the attempt log. Defaults to 1.

        Returns:
            StepLog: This node's step log, with the new attempt appended.
        """
        store = self._context.run_log_store
        run_id = self._context.run_id

        step_log = store.get_step_log(self._get_step_log_name(iter_variable), run_id)

        started = str(datetime.now())
        finished = str(datetime.now())
        attempt = datastore.StepAttempt(
            status=defaults.SUCCESS,
            start_time=started,
            end_time=finished,
            attempt_number=attempt_number,
            retry_indicator=self._context.retry_indicator,
        )

        # Let the pipeline executor attach code identities to the attempt.
        self._context.pipeline_executor.add_code_identities(
            node=self, attempt_log=attempt
        )

        # The failure is recorded on the enclosing run/branch log, not on this step.
        enclosing_log = store.get_branch_log(
            self._get_branch_log_name(iter_variable), run_id
        )
        enclosing_log.status = defaults.FAIL
        store.add_branch_log(enclosing_log, run_id)

        step_log.status = attempt.status
        step_log.attempts.append(attempt)
        return step_log
|