zrb 1.0.0a5__py3-none-any.whl → 1.0.0a12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/__main__.py +6 -0
- zrb/builtin/__init__.py +24 -6
- zrb/builtin/git.py +16 -13
- zrb/builtin/git_subtree.py +19 -4
- zrb/builtin/group.py +5 -0
- zrb/builtin/setup/asdf/asdf.py +86 -0
- zrb/builtin/setup/asdf/asdf_helper.py +44 -0
- zrb/builtin/setup/common_input.py +35 -0
- zrb/builtin/setup/latex/ubuntu.py +18 -0
- zrb/builtin/setup/tmux/tmux.py +50 -0
- zrb/builtin/setup/tmux/tmux_config.sh +12 -0
- zrb/builtin/setup/tmux/tmux_helper.py +13 -0
- zrb/builtin/setup/ubuntu.py +28 -0
- zrb/builtin/todo.py +89 -21
- zrb/config.py +5 -1
- zrb/input/base_input.py +1 -1
- zrb/input/bool_input.py +1 -1
- zrb/input/float_input.py +1 -1
- zrb/input/int_input.py +1 -1
- zrb/input/option_input.py +1 -1
- zrb/input/password_input.py +1 -1
- zrb/input/text_input.py +1 -1
- zrb/runner/web_app.py +27 -21
- zrb/task/any_task.py +38 -2
- zrb/task/base_task.py +71 -4
- zrb/task/cmd_task.py +27 -3
- zrb/task/llm_task.py +24 -18
- zrb/task/rsync_task.py +8 -8
- zrb/util/cmd/command.py +33 -0
- zrb/util/codemod/add_parent_to_class.py +38 -0
- zrb/util/git.py +35 -17
- zrb/util/git_subtree.py +11 -10
- zrb/util/load.py +4 -1
- zrb/util/string/format.py +12 -2
- zrb/util/todo.py +152 -34
- {zrb-1.0.0a5.dist-info → zrb-1.0.0a12.dist-info}/METADATA +1 -1
- {zrb-1.0.0a5.dist-info → zrb-1.0.0a12.dist-info}/RECORD +39 -29
- {zrb-1.0.0a5.dist-info → zrb-1.0.0a12.dist-info}/WHEEL +0 -0
- {zrb-1.0.0a5.dist-info → zrb-1.0.0a12.dist-info}/entry_points.txt +0 -0
zrb/builtin/todo.py
CHANGED
@@ -4,7 +4,7 @@ import os
|
|
4
4
|
from typing import Any
|
5
5
|
|
6
6
|
from zrb.builtin.group import todo_group
|
7
|
-
from zrb.config import TODO_DIR
|
7
|
+
from zrb.config import TODO_DIR, TODO_VISUAL_FILTER
|
8
8
|
from zrb.context.any_context import AnyContext
|
9
9
|
from zrb.input.str_input import StrInput
|
10
10
|
from zrb.input.text_input import TextInput
|
@@ -13,6 +13,7 @@ from zrb.util.todo import (
|
|
13
13
|
TodoTaskModel,
|
14
14
|
add_durations,
|
15
15
|
cascade_todo_task,
|
16
|
+
get_visual_todo_card,
|
16
17
|
get_visual_todo_list,
|
17
18
|
line_to_todo_task,
|
18
19
|
load_todo_list,
|
@@ -23,7 +24,7 @@ from zrb.util.todo import (
|
|
23
24
|
|
24
25
|
|
25
26
|
@make_task(
|
26
|
-
name="todo
|
27
|
+
name="add-todo",
|
27
28
|
input=[
|
28
29
|
StrInput(
|
29
30
|
name="description",
|
@@ -51,7 +52,7 @@ from zrb.util.todo import (
|
|
51
52
|
group=todo_group,
|
52
53
|
alias="add",
|
53
54
|
)
|
54
|
-
def
|
55
|
+
def add_todo(ctx: AnyContext):
|
55
56
|
todo_file_path = os.path.join(TODO_DIR, "todo.txt")
|
56
57
|
todo_list: list[TodoTaskModel] = []
|
57
58
|
if os.path.isfile(todo_file_path):
|
@@ -77,26 +78,58 @@ def todo_add(ctx: AnyContext):
|
|
77
78
|
)
|
78
79
|
)
|
79
80
|
save_todo_list(todo_file_path, todo_list)
|
80
|
-
return get_visual_todo_list(todo_list)
|
81
|
+
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
|
81
82
|
|
82
83
|
|
83
|
-
@make_task(name="todo
|
84
|
-
def
|
84
|
+
@make_task(name="list-todo", description="📋 List todo", group=todo_group, alias="list")
|
85
|
+
def list_todo(ctx: AnyContext):
|
85
86
|
todo_file_path = os.path.join(TODO_DIR, "todo.txt")
|
86
|
-
|
87
|
+
todo_list: list[TodoTaskModel] = []
|
87
88
|
if os.path.isfile(todo_file_path):
|
88
|
-
|
89
|
-
return get_visual_todo_list(
|
89
|
+
todo_list = load_todo_list(todo_file_path)
|
90
|
+
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
|
90
91
|
|
91
92
|
|
92
93
|
@make_task(
|
93
|
-
name="todo
|
94
|
+
name="show-todo",
|
95
|
+
input=StrInput(name="keyword", prompt="Task keyword", description="Task Keyword"),
|
96
|
+
description="🔍 Show todo",
|
97
|
+
group=todo_group,
|
98
|
+
alias="show",
|
99
|
+
)
|
100
|
+
def show_todo(ctx: AnyContext):
|
101
|
+
todo_file_path = os.path.join(TODO_DIR, "todo.txt")
|
102
|
+
todo_list: list[TodoTaskModel] = []
|
103
|
+
todo_list: list[TodoTaskModel] = []
|
104
|
+
if os.path.isfile(todo_file_path):
|
105
|
+
todo_list = load_todo_list(todo_file_path)
|
106
|
+
# Get todo task
|
107
|
+
todo_task = select_todo_task(todo_list, ctx.input.keyword)
|
108
|
+
if todo_task is None:
|
109
|
+
ctx.log_error("Task not found")
|
110
|
+
return
|
111
|
+
if todo_task.completed:
|
112
|
+
ctx.log_error("Task already completed")
|
113
|
+
return
|
114
|
+
# Update todo task
|
115
|
+
todo_task = cascade_todo_task(todo_task)
|
116
|
+
task_id = todo_task.keyval.get("id", "")
|
117
|
+
log_work_path = os.path.join(TODO_DIR, "log-work", f"{task_id}.json")
|
118
|
+
log_work_list = []
|
119
|
+
if os.path.isfile(log_work_path):
|
120
|
+
with open(log_work_path, "r") as f:
|
121
|
+
log_work_list = json.loads(f.read())
|
122
|
+
return get_visual_todo_card(todo_task, log_work_list)
|
123
|
+
|
124
|
+
|
125
|
+
@make_task(
|
126
|
+
name="complete-todo",
|
94
127
|
input=StrInput(name="keyword", prompt="Task keyword", description="Task Keyword"),
|
95
128
|
description="✅ Complete todo",
|
96
129
|
group=todo_group,
|
97
130
|
alias="complete",
|
98
131
|
)
|
99
|
-
def
|
132
|
+
def complete_todo(ctx: AnyContext):
|
100
133
|
todo_file_path = os.path.join(TODO_DIR, "todo.txt")
|
101
134
|
todo_list: list[TodoTaskModel] = []
|
102
135
|
if os.path.isfile(todo_file_path):
|
@@ -105,7 +138,10 @@ def todo_complete(ctx: AnyContext):
|
|
105
138
|
todo_task = select_todo_task(todo_list, ctx.input.keyword)
|
106
139
|
if todo_task is None:
|
107
140
|
ctx.log_error("Task not found")
|
108
|
-
return get_visual_todo_list(todo_list)
|
141
|
+
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
|
142
|
+
if todo_task.completed:
|
143
|
+
ctx.log_error("Task already completed")
|
144
|
+
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
|
109
145
|
# Update todo task
|
110
146
|
todo_task = cascade_todo_task(todo_task)
|
111
147
|
if todo_task.creation_date is not None:
|
@@ -113,11 +149,43 @@ def todo_complete(ctx: AnyContext):
|
|
113
149
|
todo_task.completed = True
|
114
150
|
# Save todo list
|
115
151
|
save_todo_list(todo_file_path, todo_list)
|
116
|
-
return get_visual_todo_list(todo_list)
|
152
|
+
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
|
117
153
|
|
118
154
|
|
119
155
|
@make_task(
|
120
|
-
name="todo
|
156
|
+
name="archive-todo",
|
157
|
+
description="📚 Archive todo",
|
158
|
+
group=todo_group,
|
159
|
+
alias="archive",
|
160
|
+
)
|
161
|
+
def archive_todo(ctx: AnyContext):
|
162
|
+
todo_file_path = os.path.join(TODO_DIR, "todo.txt")
|
163
|
+
todo_list: list[TodoTaskModel] = []
|
164
|
+
if os.path.isfile(todo_file_path):
|
165
|
+
todo_list = load_todo_list(todo_file_path)
|
166
|
+
working_todo_list = [
|
167
|
+
todo_task for todo_task in todo_list if not todo_task.completed
|
168
|
+
]
|
169
|
+
new_archived_todo_list = [
|
170
|
+
todo_task for todo_task in todo_list if todo_task.completed
|
171
|
+
]
|
172
|
+
if len(new_archived_todo_list) == 0:
|
173
|
+
ctx.print("No completed task to archive")
|
174
|
+
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
|
175
|
+
archive_file_path = os.path.join(TODO_DIR, "archive.txt")
|
176
|
+
if not os.path.isdir(TODO_DIR):
|
177
|
+
os.make_dirs(TODO_DIR, exist_ok=True)
|
178
|
+
archived_todo_list = []
|
179
|
+
if os.path.isfile(archive_file_path):
|
180
|
+
archived_todo_list = load_todo_list(archive_file_path)
|
181
|
+
archived_todo_list += new_archived_todo_list
|
182
|
+
save_todo_list(archive_file_path, archived_todo_list)
|
183
|
+
save_todo_list(todo_file_path, working_todo_list)
|
184
|
+
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
|
185
|
+
|
186
|
+
|
187
|
+
@make_task(
|
188
|
+
name="log-todo",
|
121
189
|
input=[
|
122
190
|
StrInput(name="keyword", prompt="Task keyword", description="Task Keyword"),
|
123
191
|
StrInput(
|
@@ -142,7 +210,7 @@ def todo_complete(ctx: AnyContext):
|
|
142
210
|
group=todo_group,
|
143
211
|
alias="log",
|
144
212
|
)
|
145
|
-
def
|
213
|
+
def log_todo(ctx: AnyContext):
|
146
214
|
todo_file_path = os.path.join(TODO_DIR, "todo.txt")
|
147
215
|
todo_list: list[TodoTaskModel] = []
|
148
216
|
if os.path.isfile(todo_file_path):
|
@@ -151,12 +219,11 @@ def todo_log(ctx: AnyContext):
|
|
151
219
|
todo_task = select_todo_task(todo_list, ctx.input.keyword)
|
152
220
|
if todo_task is None:
|
153
221
|
ctx.log_error("Task not found")
|
154
|
-
return get_visual_todo_list(todo_list)
|
222
|
+
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
|
155
223
|
# Update todo task
|
156
224
|
todo_task = cascade_todo_task(todo_task)
|
157
225
|
current_duration = todo_task.keyval.get("duration", "0")
|
158
226
|
todo_task.keyval["duration"] = add_durations(current_duration, ctx.input.duration)
|
159
|
-
print(current_duration, todo_task.keyval)
|
160
227
|
# Save todo list
|
161
228
|
save_todo_list(todo_file_path, todo_list)
|
162
229
|
# Add log work
|
@@ -176,7 +243,7 @@ def todo_log(ctx: AnyContext):
|
|
176
243
|
)
|
177
244
|
with open(log_work_file_path, "w") as f:
|
178
245
|
f.write(json.dumps(log_work, indent=2))
|
179
|
-
return get_visual_todo_list(todo_list)
|
246
|
+
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
|
180
247
|
|
181
248
|
|
182
249
|
def _get_default_start() -> str:
|
@@ -184,7 +251,7 @@ def _get_default_start() -> str:
|
|
184
251
|
|
185
252
|
|
186
253
|
@make_task(
|
187
|
-
name="todo
|
254
|
+
name="edit-todo",
|
188
255
|
input=[
|
189
256
|
TextInput(
|
190
257
|
name="text",
|
@@ -197,7 +264,7 @@ def _get_default_start() -> str:
|
|
197
264
|
group=todo_group,
|
198
265
|
alias="edit",
|
199
266
|
)
|
200
|
-
def
|
267
|
+
def edit_todo(ctx: AnyContext):
|
201
268
|
todo_list = [
|
202
269
|
cascade_todo_task(line_to_todo_task(line))
|
203
270
|
for line in ctx.input.text.split("\n")
|
@@ -208,7 +275,7 @@ def todo_edit(ctx: AnyContext):
|
|
208
275
|
with open(todo_file_path, "w") as f:
|
209
276
|
f.write(new_content)
|
210
277
|
todo_list = load_todo_list(todo_file_path)
|
211
|
-
return get_visual_todo_list(todo_list)
|
278
|
+
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)
|
212
279
|
|
213
280
|
|
214
281
|
def _get_todo_txt_content() -> str:
|
@@ -217,3 +284,4 @@ def _get_todo_txt_content() -> str:
|
|
217
284
|
return ""
|
218
285
|
with open(todo_file_path, "r") as f:
|
219
286
|
return f.read()
|
287
|
+
return f.read()
|
zrb/config.py
CHANGED
@@ -49,10 +49,14 @@ LOGGING_LEVEL = _get_log_level(os.getenv("ZRB_LOGGING_LEVEL", "WARNING"))
|
|
49
49
|
LOAD_BUILTIN = to_boolean(os.getenv("ZRB_LOAD_BUILTIN", "1"))
|
50
50
|
ENV_PREFIX = os.getenv("ZRB_ENV", "")
|
51
51
|
SHOW_PROMPT = to_boolean(os.getenv("ZRB_SHOW_PROMPT", "1"))
|
52
|
+
WARN_UNRECOMMENDED_COMMAND = to_boolean(
|
53
|
+
os.getenv("ZRB_WARN_UNRECOMMENDED_COMMAND", "1")
|
54
|
+
)
|
52
55
|
SESSION_LOG_DIR = os.getenv(
|
53
56
|
"ZRB_SESSION_LOG_DIR", os.path.expanduser(os.path.join("~", ".zrb-session"))
|
54
57
|
)
|
55
58
|
TODO_DIR = os.getenv("ZRB_TODO_DIR", os.path.expanduser(os.path.join("~", "todo")))
|
59
|
+
TODO_VISUAL_FILTER = os.getenv("ZRB_TODO_FILTER", "")
|
56
60
|
VERSION = metadata.version("zrb")
|
57
61
|
WEB_HTTP_PORT = int(os.getenv("ZRB_WEB_HTTP_PORT", "21213"))
|
58
62
|
LLM_MODEL = os.getenv("ZRB_LLM_MODEL", "ollama_chat/llama3.1")
|
@@ -77,7 +81,7 @@ BANNER = f"""
|
|
77
81
|
zzzzz rr bbbbbb {VERSION} Janggala
|
78
82
|
_ _ . . . _ . _ . . .
|
79
83
|
|
80
|
-
|
84
|
+
Your Automation Powerhouse
|
81
85
|
|
82
86
|
☕ Donate at: https://stalchmst.com/donation
|
83
87
|
🐙 Submit issues/PR at: https://github.com/state-alchemists/zrb
|
zrb/input/base_input.py
CHANGED
zrb/input/bool_input.py
CHANGED
zrb/input/float_input.py
CHANGED
zrb/input/int_input.py
CHANGED
zrb/input/option_input.py
CHANGED
zrb/input/password_input.py
CHANGED
zrb/input/text_input.py
CHANGED
@@ -16,7 +16,7 @@ class TextInput(BaseInput):
|
|
16
16
|
prompt: str | None = None,
|
17
17
|
default_str: str | Callable[[AnySharedContext], str] = "",
|
18
18
|
auto_render: bool = True,
|
19
|
-
allow_empty: bool =
|
19
|
+
allow_empty: bool = False,
|
20
20
|
editor: str = DEFAULT_EDITOR,
|
21
21
|
extension: str = ".txt",
|
22
22
|
comment_start: str | None = None,
|
zrb/runner/web_app.py
CHANGED
@@ -2,7 +2,7 @@ import asyncio
|
|
2
2
|
import os
|
3
3
|
import sys
|
4
4
|
from datetime import datetime, timedelta
|
5
|
-
from typing import Any
|
5
|
+
from typing import Any
|
6
6
|
|
7
7
|
from zrb.config import BANNER, WEB_HTTP_PORT
|
8
8
|
from zrb.context.shared_context import SharedContext
|
@@ -23,7 +23,7 @@ from zrb.util.group import extract_node_from_args, get_node_path
|
|
23
23
|
def create_app(root_group: AnyGroup, port: int = WEB_HTTP_PORT):
|
24
24
|
from contextlib import asynccontextmanager
|
25
25
|
|
26
|
-
from fastapi import FastAPI, HTTPException, Request
|
26
|
+
from fastapi import FastAPI, HTTPException, Query, Request
|
27
27
|
from fastapi.responses import FileResponse, HTMLResponse
|
28
28
|
from fastapi.staticfiles import StaticFiles
|
29
29
|
|
@@ -97,7 +97,13 @@ def create_app(root_group: AnyGroup, port: int = WEB_HTTP_PORT):
|
|
97
97
|
raise HTTPException(status_code=404, detail="Not Found")
|
98
98
|
|
99
99
|
@app.get("/api/{path:path}", response_model=SessionStateLog | SessionStateLogList)
|
100
|
-
async def get_session(
|
100
|
+
async def get_session(
|
101
|
+
path: str,
|
102
|
+
min_start_query: str = Query(default=None, alias="from"),
|
103
|
+
max_start_query: str = Query(default=None, alias="to"),
|
104
|
+
page: int = Query(default=0, alias="page"),
|
105
|
+
limit: int = Query(default=10, alias="limit"),
|
106
|
+
):
|
101
107
|
"""
|
102
108
|
Getting existing session or sessions
|
103
109
|
"""
|
@@ -106,24 +112,30 @@ def create_app(root_group: AnyGroup, port: int = WEB_HTTP_PORT):
|
|
106
112
|
if isinstance(node, AnyTask) and residual_args:
|
107
113
|
if residual_args[0] == "list":
|
108
114
|
task_path = get_node_path(root_group, node)
|
109
|
-
|
115
|
+
max_start_time = (
|
116
|
+
datetime.now()
|
117
|
+
if max_start_query is None
|
118
|
+
else datetime.strptime(max_start_query, "%Y-%m-%d %H:%M:%S")
|
119
|
+
)
|
120
|
+
min_start_time = (
|
121
|
+
max_start_time - timedelta(hours=1)
|
122
|
+
if min_start_query is None
|
123
|
+
else datetime.strptime(min_start_query, "%Y-%m-%d %H:%M:%S")
|
124
|
+
)
|
125
|
+
return list_sessions(
|
126
|
+
task_path, min_start_time, max_start_time, page, limit
|
127
|
+
)
|
110
128
|
else:
|
111
129
|
return read_session(residual_args[0])
|
112
130
|
raise HTTPException(status_code=404, detail="Not Found")
|
113
131
|
|
114
132
|
def list_sessions(
|
115
|
-
task_path:
|
133
|
+
task_path: list[str],
|
134
|
+
min_start_time: datetime,
|
135
|
+
max_start_time: datetime,
|
136
|
+
page: int,
|
137
|
+
limit: int,
|
116
138
|
) -> SessionStateLogList:
|
117
|
-
max_start_time = datetime.now()
|
118
|
-
if "to" in query_params:
|
119
|
-
max_start_time = datetime.strptime(query_params["to"], "%Y-%m-%d %H:%M:%S")
|
120
|
-
min_start_time = max_start_time - timedelta(hours=1)
|
121
|
-
if "from" in query_params:
|
122
|
-
min_start_time = datetime.strptime(
|
123
|
-
query_params["from"], "%Y-%m-%d %H:%M:%S"
|
124
|
-
)
|
125
|
-
page = int(query_params.get("page", 0))
|
126
|
-
limit = int(query_params.get("limit", 10))
|
127
139
|
try:
|
128
140
|
return default_session_state_logger.list(
|
129
141
|
task_path,
|
@@ -142,9 +154,3 @@ def create_app(root_group: AnyGroup, port: int = WEB_HTTP_PORT):
|
|
142
154
|
raise HTTPException(status_code=500, detail=str(e))
|
143
155
|
|
144
156
|
return app
|
145
|
-
|
146
|
-
|
147
|
-
# async def run_web_server(app: FastAPI, port: int = WEB_HTTP_PORT):
|
148
|
-
# config = Config(app=app, host="0.0.0.0", port=port, loop="asyncio")
|
149
|
-
# server = Server(config)
|
150
|
-
# await server.serve()
|
zrb/task/any_task.py
CHANGED
@@ -74,6 +74,12 @@ class AnyTask(ABC):
|
|
74
74
|
"""Task fallbacks"""
|
75
75
|
pass
|
76
76
|
|
77
|
+
@property
|
78
|
+
@abstractmethod
|
79
|
+
def successors(self) -> list["AnyTask"]:
|
80
|
+
"""Task successors"""
|
81
|
+
pass
|
82
|
+
|
77
83
|
@property
|
78
84
|
@abstractmethod
|
79
85
|
def readiness_checks(self) -> list["AnyTask"]:
|
@@ -81,8 +87,38 @@ class AnyTask(ABC):
|
|
81
87
|
pass
|
82
88
|
|
83
89
|
@abstractmethod
|
84
|
-
def
|
85
|
-
"""
|
90
|
+
def append_fallback(self, fallbacks: "AnyTask" | list["AnyTask"]):
|
91
|
+
"""Add the fallback tasks.
|
92
|
+
|
93
|
+
Args:
|
94
|
+
fallbacks (AnyTask | list[AnyTask]): A single fallback task or
|
95
|
+
a list of fallback tasks.
|
96
|
+
"""
|
97
|
+
pass
|
98
|
+
|
99
|
+
@abstractmethod
|
100
|
+
def append_successor(self, successors: "AnyTask" | list["AnyTask"]):
|
101
|
+
"""Add the successor tasks.
|
102
|
+
|
103
|
+
Args:
|
104
|
+
successors (AnyTask | list[AnyTask]): A single successor task or
|
105
|
+
a list of successor tasks.
|
106
|
+
"""
|
107
|
+
pass
|
108
|
+
|
109
|
+
@abstractmethod
|
110
|
+
def append_readiness_check(self, readiness_checks: "AnyTask" | list["AnyTask"]):
|
111
|
+
"""Add the readiness_check tasks.
|
112
|
+
|
113
|
+
Args:
|
114
|
+
readiness_checks (AnyTask | list[AnyTask]): A single readiness_check task or
|
115
|
+
a list of readiness_check tasks.
|
116
|
+
"""
|
117
|
+
pass
|
118
|
+
|
119
|
+
@abstractmethod
|
120
|
+
def append_upstream(self, upstreams: "AnyTask" | list["AnyTask"]):
|
121
|
+
"""Add the upstream tasks that this task depends on.
|
86
122
|
|
87
123
|
Args:
|
88
124
|
upstreams (AnyTask | list[AnyTask]): A single upstream task or
|
zrb/task/base_task.py
CHANGED
@@ -38,6 +38,7 @@ class BaseTask(AnyTask):
|
|
38
38
|
monitor_readiness: bool = False,
|
39
39
|
upstream: list[AnyTask] | AnyTask | None = None,
|
40
40
|
fallback: list[AnyTask] | AnyTask | None = None,
|
41
|
+
successor: list[AnyTask] | AnyTask | None = None,
|
41
42
|
):
|
42
43
|
self._name = name
|
43
44
|
self._color = color
|
@@ -50,6 +51,7 @@ class BaseTask(AnyTask):
|
|
50
51
|
self._retry_period = retry_period
|
51
52
|
self._upstreams = upstream
|
52
53
|
self._fallbacks = fallback
|
54
|
+
self._successors = successor
|
53
55
|
self._readiness_checks = readiness_check
|
54
56
|
self._readiness_check_delay = readiness_check_delay
|
55
57
|
self._readiness_check_period = readiness_check_period
|
@@ -65,17 +67,17 @@ class BaseTask(AnyTask):
|
|
65
67
|
def __rshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
|
66
68
|
try:
|
67
69
|
if isinstance(other, AnyTask):
|
68
|
-
other.
|
70
|
+
other.append_upstream(self)
|
69
71
|
elif isinstance(other, list):
|
70
72
|
for task in other:
|
71
|
-
task.
|
73
|
+
task.append_upstream(self)
|
72
74
|
return other
|
73
75
|
except Exception as e:
|
74
76
|
raise ValueError(f"Invalid operation {self} >> {other}: {e}")
|
75
77
|
|
76
78
|
def __lshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
|
77
79
|
try:
|
78
|
-
self.
|
80
|
+
self.append_upstream(other)
|
79
81
|
return self
|
80
82
|
except Exception as e:
|
81
83
|
raise ValueError(f"Invalid operation {self} << {other}: {e}")
|
@@ -142,6 +144,44 @@ class BaseTask(AnyTask):
|
|
142
144
|
return [self._fallbacks]
|
143
145
|
return self._fallbacks
|
144
146
|
|
147
|
+
def append_fallback(self, fallbacks: AnyTask | list[AnyTask]):
|
148
|
+
fallback_list = [fallbacks] if isinstance(fallbacks, AnyTask) else fallbacks
|
149
|
+
for fallback in fallback_list:
|
150
|
+
self.__append_fallback(fallback)
|
151
|
+
|
152
|
+
def __append_fallback(self, fallback: AnyTask):
|
153
|
+
# Make sure self._fallbacks is a list
|
154
|
+
if self._fallbacks is None:
|
155
|
+
self._fallbacks = []
|
156
|
+
elif isinstance(self._fallbacks, AnyTask):
|
157
|
+
self._fallbacks = [self._fallbacks]
|
158
|
+
# Add fallback if it was not on self._fallbacks
|
159
|
+
if fallback not in self._fallbacks:
|
160
|
+
self._fallbacks.append(fallback)
|
161
|
+
|
162
|
+
@property
|
163
|
+
def successors(self) -> list[AnyTask]:
|
164
|
+
if self._successors is None:
|
165
|
+
return []
|
166
|
+
elif isinstance(self._successors, AnyTask):
|
167
|
+
return [self._successors]
|
168
|
+
return self._successors
|
169
|
+
|
170
|
+
def append_successor(self, successors: AnyTask | list[AnyTask]):
|
171
|
+
successor_list = [successors] if isinstance(successors, AnyTask) else successors
|
172
|
+
for successor in successor_list:
|
173
|
+
self.__append_successor(successor)
|
174
|
+
|
175
|
+
def __append_successor(self, successor: AnyTask):
|
176
|
+
# Make sure self._successors is a list
|
177
|
+
if self._successors is None:
|
178
|
+
self._successors = []
|
179
|
+
elif isinstance(self._successors, AnyTask):
|
180
|
+
self._successors = [self._successors]
|
181
|
+
# Add successor if it was not on self._successors
|
182
|
+
if successor not in self._successors:
|
183
|
+
self._successors.append(successor)
|
184
|
+
|
145
185
|
@property
|
146
186
|
def readiness_checks(self) -> list[AnyTask]:
|
147
187
|
if self._readiness_checks is None:
|
@@ -150,6 +190,25 @@ class BaseTask(AnyTask):
|
|
150
190
|
return [self._readiness_checks]
|
151
191
|
return self._readiness_checks
|
152
192
|
|
193
|
+
def append_readiness_check(self, readiness_checks: AnyTask | list[AnyTask]):
|
194
|
+
readiness_check_list = (
|
195
|
+
[readiness_checks]
|
196
|
+
if isinstance(readiness_checks, AnyTask)
|
197
|
+
else readiness_checks
|
198
|
+
)
|
199
|
+
for readiness_check in readiness_check_list:
|
200
|
+
self.__append_readiness_check(readiness_check)
|
201
|
+
|
202
|
+
def __append_readiness_check(self, readiness_check: AnyTask):
|
203
|
+
# Make sure self._readiness_checks is a list
|
204
|
+
if self._readiness_checks is None:
|
205
|
+
self._readiness_checks = []
|
206
|
+
elif isinstance(self._readiness_checks, AnyTask):
|
207
|
+
self._readiness_checks = [self._readiness_checks]
|
208
|
+
# Add readiness_check if it was not on self._readiness_checks
|
209
|
+
if readiness_check not in self._readiness_checks:
|
210
|
+
self._readiness_checks.append(readiness_check)
|
211
|
+
|
153
212
|
@property
|
154
213
|
def upstreams(self) -> list[AnyTask]:
|
155
214
|
if self._upstreams is None:
|
@@ -158,7 +217,7 @@ class BaseTask(AnyTask):
|
|
158
217
|
return [self._upstreams]
|
159
218
|
return self._upstreams
|
160
219
|
|
161
|
-
def
|
220
|
+
def append_upstream(self, upstreams: AnyTask | list[AnyTask]):
|
162
221
|
upstream_list = [upstreams] if isinstance(upstreams, AnyTask) else upstreams
|
163
222
|
for upstream in upstream_list:
|
164
223
|
self.__append_upstream(upstream)
|
@@ -374,6 +433,7 @@ class BaseTask(AnyTask):
|
|
374
433
|
# Put result on xcom
|
375
434
|
task_xcom: Xcom = ctx.xcom.get(self.name)
|
376
435
|
task_xcom.push(result)
|
436
|
+
await run_async(self.__exec_successors(session))
|
377
437
|
return result
|
378
438
|
except (asyncio.CancelledError, KeyboardInterrupt):
|
379
439
|
ctx.log_info("Marked as failed")
|
@@ -390,6 +450,13 @@ class BaseTask(AnyTask):
|
|
390
450
|
await run_async(self.__exec_fallbacks(session))
|
391
451
|
raise e
|
392
452
|
|
453
|
+
async def __exec_successors(self, session: AnySession) -> Any:
|
454
|
+
successors: list[AnyTask] = self.successors
|
455
|
+
successor_coros = [
|
456
|
+
run_async(successor.exec_chain(session)) for successor in successors
|
457
|
+
]
|
458
|
+
await asyncio.gather(*successor_coros)
|
459
|
+
|
393
460
|
async def __exec_fallbacks(self, session: AnySession) -> Any:
|
394
461
|
fallbacks: list[AnyTask] = self.fallbacks
|
395
462
|
fallback_coros = [
|