zrb 1.5.7__py3-none-any.whl → 1.5.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/builtin/llm/tool/rag.py +4 -3
- zrb/llm_config.py +38 -0
- zrb/task/any_task.py +22 -6
- zrb/task/base/__init__.py +0 -0
- zrb/task/base/context.py +108 -0
- zrb/task/base/dependencies.py +57 -0
- zrb/task/base/execution.py +274 -0
- zrb/task/base/lifecycle.py +182 -0
- zrb/task/base/monitoring.py +134 -0
- zrb/task/base/operators.py +41 -0
- zrb/task/base_task.py +76 -382
- zrb/task/cmd_task.py +2 -1
- zrb/task/llm/agent.py +141 -0
- zrb/task/llm/config.py +83 -0
- zrb/task/llm/context.py +95 -0
- zrb/task/llm/{context_enricher.py → context_enrichment.py} +55 -6
- zrb/task/llm/history.py +153 -3
- zrb/task/llm/history_summarization.py +173 -0
- zrb/task/llm/prompt.py +87 -0
- zrb/task/llm/typing.py +3 -0
- zrb/task/llm_task.py +140 -323
- {zrb-1.5.7.dist-info → zrb-1.5.9.dist-info}/METADATA +2 -2
- {zrb-1.5.7.dist-info → zrb-1.5.9.dist-info}/RECORD +25 -15
- zrb/task/llm/agent_runner.py +0 -53
- zrb/task/llm/default_context.py +0 -45
- zrb/task/llm/history_summarizer.py +0 -71
- {zrb-1.5.7.dist-info → zrb-1.5.9.dist-info}/WHEEL +0 -0
- {zrb-1.5.7.dist-info → zrb-1.5.9.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,182 @@
|
|
1
|
+
import asyncio
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
from zrb.context.shared_context import SharedContext
|
5
|
+
from zrb.session.any_session import AnySession
|
6
|
+
from zrb.session.session import Session
|
7
|
+
from zrb.task.any_task import AnyTask
|
8
|
+
from zrb.task.base.context import fill_shared_context_envs, fill_shared_context_inputs
|
9
|
+
from zrb.util.run import run_async
|
10
|
+
|
11
|
+
|
12
|
+
async def run_and_cleanup(
    task: AnyTask,
    session: AnySession | None = None,
    str_kwargs: dict[str, str] | None = None,
) -> Any:
    """
    Wrapper for async_run that ensures session termination and cleanup of
    other concurrent asyncio tasks. This is the main entry point for `task.run()`.

    Args:
        task: The task to run.
        session: Session to run in. A new ``Session`` with a fresh
            ``SharedContext`` is created when omitted.
        str_kwargs: String-valued inputs forwarded to ``run_task_async``.
            Defaults to a fresh empty dict on every call.

    Returns:
        The result of ``run_task_async(task, session, str_kwargs)``.

    Raises:
        asyncio.CancelledError, KeyboardInterrupt: Logged, then re-raised.
    """
    # Ensure a session exists
    if session is None:
        session = Session(shared_ctx=SharedContext())
    # Never share a mutable default dict between calls.
    if str_kwargs is None:
        str_kwargs = {}

    # Run the task in its own asyncio task so it can be told apart from
    # any other tasks during cleanup.
    main_task_coro = asyncio.create_task(run_task_async(task, session, str_kwargs))

    try:
        return await main_task_coro
    except (asyncio.CancelledError, KeyboardInterrupt) as e:
        ctx = task.get_ctx(session)  # Get context for logging
        ctx.log_warning(f"Run cancelled/interrupted: {e}")
        raise  # Re-raise to propagate
    finally:
        # Ensure session termination if it exists and wasn't terminated by the run
        if session and not session.is_terminated:
            ctx = task.get_ctx(session)  # Get context for logging
            ctx.log_info("Terminating session after run completion/error.")
            session.terminate()

        # Clean up other potentially running asyncio tasks. Both the main task
        # and the task executing this coroutine are excluded: all_tasks()
        # includes the current task, and cancelling it would abort the
        # grace-period wait below immediately (waiting on itself could also
        # never finish before the timeout).
        # Be cautious with blanket cancellation if other background tasks are expected
        current_task = asyncio.current_task()
        pending = [
            t
            for t in asyncio.all_tasks()
            if t is not main_task_coro and t is not current_task and not t.done()
        ]
        if pending:
            ctx = task.get_ctx(session)  # Get context for logging
            ctx.log_debug(f"Cleaning up {len(pending)} pending asyncio tasks...")
            for t in pending:
                t.cancel()
            try:
                # Give cancelled tasks a moment to process cancellation.
                # asyncio.wait() reports results via its return value and does
                # not propagate exceptions raised by the waited tasks.
                await asyncio.wait(pending, timeout=1.0)
            except asyncio.CancelledError:
                # Only raised if this task itself is cancelled during the grace
                # period; cleanup is best-effort, so do not override the
                # outcome of the run.
                pass
            except Exception as cleanup_exc:
                # Log errors during cleanup if necessary
                ctx.log_warning(f"Error during task cleanup: {cleanup_exc}")
|
61
|
+
|
62
|
+
|
63
|
+
async def run_task_async(
    task: AnyTask,
    session: AnySession | None = None,
    str_kwargs: dict[str, str] | None = None,
) -> Any:
    """
    Asynchronous entry point for running a task (`task.async_run()`).
    Sets up the session and initiates the root task execution chain.

    Args:
        task: The task to run.
        session: Session to run in. A new ``Session`` with a fresh
            ``SharedContext`` is created when omitted.
        str_kwargs: String-valued inputs passed to
            ``fill_shared_context_inputs``. Defaults to a fresh empty dict on
            every call.

    Returns:
        Whatever ``execute_root_tasks`` produces for ``task``.
    """
    if session is None:
        session = Session(shared_ctx=SharedContext())
    # Never share a mutable default dict between calls.
    if str_kwargs is None:
        str_kwargs = {}

    # Populate shared context with inputs and environment variables
    fill_shared_context_inputs(task, session.shared_ctx, str_kwargs)
    fill_shared_context_envs(session.shared_ctx)  # Inject OS env vars

    # Start the execution chain from the root tasks
    return await run_async(execute_root_tasks(task, session))
|
82
|
+
|
83
|
+
|
84
|
+
async def execute_root_tasks(task: AnyTask, session: AnySession):
    """
    Drive a whole session run.

    Registers ``task`` as the session's main task, runs every root task that
    the session allows to run, and then waits for deferred actions. While
    this happens, a background ``log_session_state`` task keeps writing the
    session state. Whatever the outcome, the session is terminated, the
    background logger is stopped, and one last state log is written.

    Returns:
        ``session.final_result`` after a normal run. Returns ``None`` when
        there are no runnable root tasks, when an ``IndexError`` escapes, or
        on cancellation / ``KeyboardInterrupt``.
    """
    session.set_main_task(task)
    # Snapshot the state once before anything starts running.
    session.state_logger.write(session.as_state_log())
    ctx = task.get_ctx(session)

    state_logger_task = None
    try:
        state_logger_task = asyncio.create_task(log_session_state(task, session))

        runnable_roots = [
            root
            for root in session.get_root_tasks(task)
            if session.is_allowed_to_run(root)
        ]
        if not runnable_roots:
            # Nothing to run: end the session right away.
            ctx.log_info("No root tasks to execute for this task.")
            session.terminate()
            return None

        ctx.log_info(f"Executing {len(runnable_roots)} root task(s)")
        # Run every root chain concurrently and wait for all of them.
        await asyncio.gather(
            *(run_async(root.exec_chain(session)) for root in runnable_roots)
        )

        ctx.log_info("Waiting for deferred actions...")
        await session.wait_deferred()
        ctx.log_info("Deferred actions complete.")

        session.terminate()
        # With the session terminated, the logger loop exits by itself; wait
        # for it so its last write lands before success is reported.
        if not state_logger_task.done():
            await state_logger_task
        ctx.log_info("Session finished.")
        return session.final_result
    except IndexError:
        ctx.log_error(
            "IndexError during root task execution, potentially session issue."
        )
        session.terminate()
        return None
    except (asyncio.CancelledError, KeyboardInterrupt):
        # NOTE(review): cancellation is converted into a None result here
        # instead of being re-raised, so callers see a normal return.
        ctx.log_warning("Session execution cancelled or interrupted.")
        return None
    finally:
        if not session.is_terminated:
            session.terminate()
        if state_logger_task is not None:
            # Stop the background logger if it is still going, and reap it.
            if not state_logger_task.done():
                state_logger_task.cancel()
            try:
                await state_logger_task
            except asyncio.CancelledError:
                pass
        # Final snapshot, written whether or not the logger ever started.
        session.state_logger.write(session.as_state_log())
        ctx.log_debug(f"Final session state: {session}")
|
162
|
+
|
163
|
+
|
164
|
+
async def log_session_state(task: AnyTask, session: AnySession):
    """
    Periodically logs the session state until the session is terminated.

    The state is written every 0.1 seconds while the session is running. It
    is written once more after termination is observed, or once more if this
    coroutine is cancelled. After that final write, cancellation is
    re-raised, so whoever awaits this task sees a properly cancelled task and
    an outer cancellation is not silently lost. Any other exception is logged
    and ends the logging loop.
    """
    try:
        while not session.is_terminated:
            session.state_logger.write(session.as_state_log())
            await asyncio.sleep(0.1)  # Log interval
        # Log one final time after termination signal
        session.state_logger.write(session.as_state_log())
    except asyncio.CancelledError:
        # Log final state on cancellation too
        session.state_logger.write(session.as_state_log())
        ctx = task.get_ctx(session)
        ctx.log_debug("Session state logger cancelled.")
        # Propagate: suppressing CancelledError would turn a cancellation
        # into a normal completion.
        raise
    except Exception as e:
        # Log any unexpected errors in the logger itself
        ctx = task.get_ctx(session)
        ctx.log_error(f"Error in session state logger: {e}")
|
@@ -0,0 +1,134 @@
|
|
1
|
+
import asyncio
|
2
|
+
|
3
|
+
from zrb.session.any_session import AnySession
|
4
|
+
from zrb.task.base.execution import execute_action_with_retry
|
5
|
+
from zrb.task.base_task import BaseTask
|
6
|
+
from zrb.util.run import run_async
|
7
|
+
from zrb.xcom.xcom import Xcom
|
8
|
+
|
9
|
+
|
10
|
+
def _reset_readiness_checks(ctx, session: AnySession, readiness_checks) -> None:
    """Clear the status history, status, and XCom data of every readiness check."""
    for check in readiness_checks:
        check_status = session.get_task_status(check)
        check_status.reset_history()
        check_status.reset()
        # Clear previous XCom data so stale values from the last round are not reused.
        check_xcom: Xcom = ctx.xcom.get(check.name)
        check_xcom.clear()


def _mark_unfinished_checks_failed(session: AnySession, readiness_checks) -> None:
    """Mark every readiness check that has not finished as failed."""
    for check in readiness_checks:
        check_status = session.get_task_status(check)
        if not check_status.is_finished:
            check_status.mark_as_failed()


async def _restart_action(
    ctx, task: BaseTask, session: AnySession, action_coro: asyncio.Task
) -> asyncio.Task:
    """
    Cancel ``action_coro`` if it is still running, reset ``task``'s status,
    and start a new run of its action (with retries), deferred on ``session``.

    Returns:
        The newly created action task.
    """
    if action_coro and not action_coro.done():
        ctx.log_info("Cancelling original task action...")
        action_coro.cancel()
        try:
            await action_coro  # Allow cancellation to process
        except asyncio.CancelledError:
            ctx.log_info("Original task action cancelled.")
        except Exception as e:
            ctx.log_warning(f"Error during original action cancellation: {e}")

    ctx.log_info("Resetting task status.")
    session.get_task_status(task).reset()

    ctx.log_info("Re-executing task action...")
    new_action_coro = asyncio.create_task(
        run_async(execute_action_with_retry(task, session))
    )
    session.defer_action(task, new_action_coro)
    return new_action_coro


async def monitor_task_readiness(
    task: BaseTask, session: AnySession, action_coro: asyncio.Task
):
    """
    Periodically re-run a task's readiness checks and restart the task when
    they keep failing.

    Every round runs after ``_readiness_check_period`` seconds (default 5.0).
    In each round, the readiness checks are reset and executed again, bounded
    by ``_readiness_timeout`` seconds (default 60).

    A round fails on timeout, on an exception, or when any check does not end
    up completed. A successful round resets the failure count. After
    ``_readiness_failure_threshold`` consecutive failed rounds (default 1,
    clamped to at least 1), three things happen:

    - the running action is cancelled;
    - the task status is reset;
    - the action is re-executed via ``execute_action_with_retry`` and
      deferred on the session.

    Monitoring stops when the session terminates, or when a check round is
    cancelled or interrupted.

    Args:
        task: Task whose readiness is monitored.
        session: Session the task runs in.
        action_coro: The asyncio task running the task's current action.
    """
    ctx = task.get_ctx(session)
    readiness_checks = task.readiness_checks
    readiness_check_period = getattr(task, "_readiness_check_period", 5.0)
    # A threshold below 1 would make the loop skip every check yet still
    # restart the action each period; treat such values as 1.
    readiness_failure_threshold = max(
        1, getattr(task, "_readiness_failure_threshold", 1)
    )
    readiness_timeout = getattr(task, "_readiness_timeout", 60)

    if not readiness_checks:
        ctx.log_debug("No readiness checks defined, monitoring is not applicable.")
        return

    failure_count = 0
    ctx.log_info("Starting readiness monitoring...")

    while not session.is_terminated:
        await asyncio.sleep(readiness_check_period)

        if session.is_terminated:
            break  # Exit loop if session terminated during sleep

        # failure_count is always below the threshold here: it is reset on
        # success and after every restart.
        ctx.log_info("Performing periodic readiness check...")
        _reset_readiness_checks(ctx, session, readiness_checks)

        readiness_check_coros = [
            run_async(check.exec_chain(session)) for check in readiness_checks
        ]

        try:
            # Wait for checks with a timeout
            await asyncio.wait_for(
                asyncio.gather(*readiness_check_coros),
                timeout=readiness_timeout,
            )
            # Check if all checks actually completed successfully
            all_checks_completed = all(
                session.get_task_status(check).is_completed
                for check in readiness_checks
            )
            if all_checks_completed:
                ctx.log_info("Readiness check OK.")
                failure_count = 0  # Reset failure count on success
                continue  # Continue monitoring
            ctx.log_warning(
                "Periodic readiness check failed (tasks did not complete)."
            )
            failure_count += 1

        except asyncio.TimeoutError:
            failure_count += 1
            ctx.log_warning(
                f"Readiness check timed out ({readiness_timeout}s). "
                f"Failure count: {failure_count}/{readiness_failure_threshold}"
            )
            _mark_unfinished_checks_failed(session, readiness_checks)

        except (asyncio.CancelledError, KeyboardInterrupt):
            ctx.log_info("Monitoring cancelled or interrupted.")
            break  # Exit monitoring loop

        except Exception as e:
            failure_count += 1
            ctx.log_error(
                f"Readiness check failed with exception: {e}. "
                f"Failure count: {failure_count}"
            )
            _mark_unfinished_checks_failed(session, readiness_checks)

        # If failure threshold is reached
        if failure_count >= readiness_failure_threshold:
            ctx.log_warning(
                f"Readiness failure threshold ({readiness_failure_threshold}) reached."
            )
            # Keep a reference to the new action for the next monitoring cycle.
            action_coro = await _restart_action(ctx, task, session, action_coro)
            # Reset failure count after attempting restart
            failure_count = 0
            ctx.log_info("Continuing monitoring...")

    ctx.log_info("Stopping readiness monitoring.")
|
@@ -0,0 +1,41 @@
|
|
1
|
+
# Helpers implementing the `>>` and `<<` task-dependency operators.
|
2
|
+
from zrb.task.any_task import AnyTask
|
3
|
+
|
4
|
+
|
5
|
+
def handle_rshift(
    left_task: AnyTask, right_operand: AnyTask | list[AnyTask]
) -> AnyTask | list[AnyTask]:
    """
    Implements the >> operator logic: left_task becomes an upstream for right_operand.

    ``left_task`` is appended as an upstream of ``right_operand``. When
    ``right_operand`` is a list, it is appended to each task in the list.

    Args:
        left_task: Task that becomes the upstream.
        right_operand: Task, or list of tasks, that receives ``left_task`` as
            an upstream.

    Returns:
        ``right_operand`` itself (after mutation).

    Raises:
        ValueError: If appending the upstream fails. The underlying error is
            chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        targets = right_operand if isinstance(right_operand, list) else [right_operand]
        for target in targets:
            # NOTE(review): assumes append_upstream tolerates duplicates.
            target.append_upstream(left_task)
        return right_operand
    except Exception as e:
        # Catch potential errors during append_upstream or type issues
        raise ValueError(
            f"Invalid operation {left_task} >> {right_operand}: {e}"
        ) from e
|
25
|
+
|
26
|
+
|
27
|
+
def handle_lshift(
    left_task: AnyTask, right_operand: AnyTask | list[AnyTask]
) -> AnyTask:
    """
    Implements the << operator logic: right_operand becomes an upstream for left_task.

    ``right_operand`` (a single task or a list of tasks) is passed unchanged
    to ``left_task.append_upstream``.

    Args:
        left_task: Task that receives the new upstream(s).
        right_operand: Task, or list of tasks, to add as upstream(s).

    Returns:
        ``left_task`` itself (after mutation).

    Raises:
        ValueError: If appending the upstream fails. The underlying error is
            chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        # NOTE(review): assumes append_upstream accepts a single task or a list.
        left_task.append_upstream(right_operand)
        return left_task
    except Exception as e:
        # Catch potential errors during append_upstream or type issues
        raise ValueError(
            f"Invalid operation {left_task} << {right_operand}: {e}"
        ) from e
|