zrb 1.5.8__py3-none-any.whl → 1.5.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/builtin/llm/tool/file.py +1 -3
- zrb/llm_config.py +82 -46
- zrb/task/any_task.py +22 -6
- zrb/task/base/__init__.py +0 -0
- zrb/task/base/context.py +108 -0
- zrb/task/base/dependencies.py +57 -0
- zrb/task/base/execution.py +274 -0
- zrb/task/base/lifecycle.py +182 -0
- zrb/task/base/monitoring.py +134 -0
- zrb/task/base/operators.py +41 -0
- zrb/task/base_task.py +91 -381
- zrb/task/llm/context_enrichment.py +6 -3
- zrb/task/llm/history_summarization.py +10 -7
- zrb/task/llm/prompt.py +64 -3
- zrb/task/llm_task.py +26 -16
- zrb/task/scheduler.py +1 -2
- {zrb-1.5.8.dist-info → zrb-1.5.10.dist-info}/METADATA +2 -2
- {zrb-1.5.8.dist-info → zrb-1.5.10.dist-info}/RECORD +20 -13
- {zrb-1.5.8.dist-info → zrb-1.5.10.dist-info}/WHEEL +0 -0
- {zrb-1.5.8.dist-info → zrb-1.5.10.dist-info}/entry_points.txt +0 -0
zrb/task/base_task.py
CHANGED
@@ -1,19 +1,25 @@
 import asyncio
-import os
-from collections.abc import Callable, Coroutine
+from collections.abc import Callable
 from typing import Any
 
 from zrb.attr.type import BoolAttr, fstring
 from zrb.context.any_context import AnyContext
-from zrb.context.shared_context import AnySharedContext, SharedContext
 from zrb.env.any_env import AnyEnv
 from zrb.input.any_input import AnyInput
 from zrb.session.any_session import AnySession
-from zrb.session.session import Session
 from zrb.task.any_task import AnyTask
-from zrb.
-
-
+from zrb.task.base.context import (
+    build_task_context,
+    get_combined_envs,
+    get_combined_inputs,
+)
+from zrb.task.base.execution import (
+    execute_task_action,
+    execute_task_chain,
+    run_default_action,
+)
+from zrb.task.base.lifecycle import execute_root_tasks, run_and_cleanup, run_task_async
+from zrb.task.base.operators import handle_lshift, handle_rshift
 
 
 class BaseTask(AnyTask):
@@ -62,25 +68,13 @@ class BaseTask(AnyTask):
         self._action = action
 
     def __repr__(self):
-        return f"<{self.__class__.__name__} name={self.
+        return f"<{self.__class__.__name__} name={self.name}>"
 
     def __rshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask | list[AnyTask]:
-        try:
-            if isinstance(other, AnyTask):
-                other.append_upstream(self)
-            elif isinstance(other, list):
-                for task in other:
-                    task.append_upstream(self)
-            return other
-        except Exception as e:
-            raise ValueError(f"Invalid operation {self} >> {other}: {e}")
+        return handle_rshift(self, other)
 
     def __lshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
-        try:
-            self.append_upstream(other)
-            return self
-        except Exception as e:
-            raise ValueError(f"Invalid operation {self} << {other}: {e}")
+        return handle_lshift(self, other)
 
     @property
     def name(self) -> str:
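The `>>` and `<<` operators now delegate to `handle_rshift` and `handle_lshift` from the new `zrb/task/base/operators.py`, whose contents are not shown in this diff. The sketch below reconstructs what those helpers presumably do from the inline logic removed above; the parameter names and exact error handling are assumptions, not the released 1.5.10 code.

from zrb.task.any_task import AnyTask


def handle_rshift(
    left_task: AnyTask, right: AnyTask | list[AnyTask]
) -> AnyTask | list[AnyTask]:
    # task_a >> task_b registers task_a as an upstream of task_b (or of every
    # task in a list) and returns the right-hand side, so chains such as
    # a >> b >> c keep composing left to right.
    try:
        if isinstance(right, list):
            for task in right:
                task.append_upstream(left_task)
        else:
            right.append_upstream(left_task)
        return right
    except Exception as e:
        raise ValueError(f"Invalid operation {left_task} >> {right}: {e}")


def handle_lshift(left_task: AnyTask, right: AnyTask | list[AnyTask]) -> AnyTask:
    # task_a << task_b registers task_b (or the list) as upstream of task_a
    # and returns task_a, mirroring the old inline __lshift__ behavior.
    try:
        left_task.append_upstream(right)
        return left_task
    except Exception as e:
        raise ValueError(f"Invalid operation {left_task} << {right}: {e}")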
@@ -104,430 +98,146 @@ class BaseTask(AnyTask):
 
     @property
     def envs(self) -> list[AnyEnv]:
-
-        for upstream in self.upstreams:
-            envs += upstream.envs
-        if isinstance(self._envs, AnyEnv):
-            envs.append(self._envs)
-        elif self._envs is not None:
-            envs += self._envs
-        return [env for env in envs if env is not None]
+        return get_combined_envs(self)
 
     @property
     def inputs(self) -> list[AnyInput]:
-
-        for upstream in self.upstreams:
-            self.__combine_inputs(inputs, upstream.inputs)
-        if self._inputs is not None:
-            self.__combine_inputs(inputs, self._inputs)
-        return [task_input for task_input in inputs if task_input is not None]
-
-    def __combine_inputs(
-        self,
-        inputs: list[AnyInput],
-        other_inputs: list[AnyInput | None] | AnyInput | None,
-    ):
-        input_names = [task_input.name for task_input in inputs]
-        if isinstance(other_inputs, AnyInput):
-            other_inputs = [other_inputs]
-        elif other_inputs is None:
-            other_inputs = []
-        for task_input in other_inputs:
-            if task_input is None:
-                continue
-            if task_input.name not in input_names:
-                inputs.append(task_input)
+        return get_combined_inputs(self)
 
     @property
     def fallbacks(self) -> list[AnyTask]:
+        """Returns the list of fallback tasks."""
         if self._fallbacks is None:
             return []
-        elif isinstance(self._fallbacks,
-            return
-        return self._fallbacks
+        elif isinstance(self._fallbacks, list):
+            return self._fallbacks
+        return [self._fallbacks]  # Assume single task
 
     def append_fallback(self, fallbacks: AnyTask | list[AnyTask]):
-
-        for fallback in fallback_list:
-            self.__append_fallback(fallback)
-
-    def __append_fallback(self, fallback: AnyTask):
-        # Make sure self._fallbacks is a list
+        """Appends fallback tasks, ensuring no duplicates."""
         if self._fallbacks is None:
             self._fallbacks = []
-        elif isinstance(self._fallbacks,
+        elif not isinstance(self._fallbacks, list):
             self._fallbacks = [self._fallbacks]
-
-
-            self._fallbacks
+        to_add = fallbacks if isinstance(fallbacks, list) else [fallbacks]
+        for fb in to_add:
+            if fb not in self._fallbacks:
+                self._fallbacks.append(fb)
 
     @property
     def successors(self) -> list[AnyTask]:
+        """Returns the list of successor tasks."""
         if self._successors is None:
             return []
-        elif isinstance(self._successors,
-            return
-        return self._successors
+        elif isinstance(self._successors, list):
+            return self._successors
+        return [self._successors]  # Assume single task
 
     def append_successor(self, successors: AnyTask | list[AnyTask]):
-
-        for successor in successor_list:
-            self.__append_successor(successor)
-
-    def __append_successor(self, successor: AnyTask):
-        # Make sure self._successors is a list
+        """Appends successor tasks, ensuring no duplicates."""
         if self._successors is None:
            self._successors = []
-        elif isinstance(self._successors,
+        elif not isinstance(self._successors, list):
             self._successors = [self._successors]
-
-
-            self._successors
+        to_add = successors if isinstance(successors, list) else [successors]
+        for succ in to_add:
+            if succ not in self._successors:
+                self._successors.append(succ)
 
     @property
     def readiness_checks(self) -> list[AnyTask]:
+        """Returns the list of readiness check tasks."""
         if self._readiness_checks is None:
             return []
-        elif isinstance(self._readiness_checks,
-            return
-        return self._readiness_checks
+        elif isinstance(self._readiness_checks, list):
+            return self._readiness_checks
+        return [self._readiness_checks]  # Assume single task
 
     def append_readiness_check(self, readiness_checks: AnyTask | list[AnyTask]):
-
-            [readiness_checks]
-            if isinstance(readiness_checks, AnyTask)
-            else readiness_checks
-        )
-        for readiness_check in readiness_check_list:
-            self.__append_readiness_check(readiness_check)
-
-    def __append_readiness_check(self, readiness_check: AnyTask):
-        # Make sure self._readiness_checks is a list
+        """Appends readiness check tasks, ensuring no duplicates."""
         if self._readiness_checks is None:
             self._readiness_checks = []
-        elif isinstance(self._readiness_checks,
+        elif not isinstance(self._readiness_checks, list):
             self._readiness_checks = [self._readiness_checks]
-
-
-
+        to_add = (
+            readiness_checks
+            if isinstance(readiness_checks, list)
+            else [readiness_checks]
+        )
+        for rc in to_add:
+            if rc not in self._readiness_checks:
+                self._readiness_checks.append(rc)
 
     @property
     def upstreams(self) -> list[AnyTask]:
+        """Returns the list of upstream tasks."""
         if self._upstreams is None:
             return []
-        elif isinstance(self._upstreams,
-            return
-        return self._upstreams
+        elif isinstance(self._upstreams, list):
+            return self._upstreams
+        return [self._upstreams]  # Assume single task
 
     def append_upstream(self, upstreams: AnyTask | list[AnyTask]):
-
-        for upstream in upstream_list:
-            self.__append_upstream(upstream)
-
-    def __append_upstream(self, upstream: AnyTask):
-        # Make sure self._upstreams is a list
+        """Appends upstream tasks, ensuring no duplicates."""
         if self._upstreams is None:
             self._upstreams = []
-        elif isinstance(self._upstreams,
+        elif not isinstance(self._upstreams, list):
             self._upstreams = [self._upstreams]
-
-
-            self._upstreams
+        to_add = upstreams if isinstance(upstreams, list) else [upstreams]
+        for up in to_add:
+            if up not in self._upstreams:
+                self._upstreams.append(up)
 
     def get_ctx(self, session: AnySession) -> AnyContext:
-
-        # Enhance session ctx with current task env
-        for env in self.envs:
-            env.update_context(ctx)
-        return ctx
+        return build_task_context(self, session)
 
     def run(
         self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
     ) -> Any:
-
+        """
+        Synchronously runs the task and its dependencies, handling async setup and cleanup.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if pending:
-            try:
-                await asyncio.wait(pending, timeout=5)
-            except asyncio.CancelledError:
-                pass
-        return result
+        Uses `asyncio.run()` internally, which creates a new event loop.
+        WARNING: Do not call this method from within an already running asyncio
+        event loop, as it will raise a RuntimeError. Use `async_run` instead
+        if you are in an async context.
+
+        Args:
+            session (AnySession | None): The session to use. If None, a new one
+                might be created implicitly.
+            str_kwargs (dict[str, str]): String-based key-value arguments for inputs.
+
+        Returns:
+            Any: The final result of the main task execution.
+        """
+        # Use asyncio.run() to execute the async cleanup wrapper
+        return asyncio.run(run_and_cleanup(self, session, str_kwargs))
 
     async def async_run(
         self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
     ) -> Any:
-
-        session = Session(shared_ctx=SharedContext())
-        # Update session
-        self.__fill_shared_context_inputs(session.shared_ctx, str_kwargs)
-        self.__fill_shared_context_envs(session.shared_ctx)
-        result = await run_async(self.exec_root_tasks(session))
-        return result
-
-    def __fill_shared_context_inputs(
-        self, shared_context: AnySharedContext, str_kwargs: dict[str, str] = {}
-    ):
-        for task_input in self.inputs:
-            if task_input.name not in shared_context.input:
-                str_value = str_kwargs.get(task_input.name, None)
-                task_input.update_shared_context(shared_context, str_value)
-
-    def __fill_shared_context_envs(self, shared_context: AnySharedContext):
-        # Inject os environ
-        os_env_map = {
-            key: val for key, val in os.environ.items() if key not in shared_context.env
-        }
-        shared_context.env.update(os_env_map)
+        return await run_task_async(self, session, str_kwargs)
 
     async def exec_root_tasks(self, session: AnySession):
-
-        session.state_logger.write(session.as_state_log())
-        try:
-            log_state = asyncio.create_task(self._log_session_state(session))
-            root_tasks = [
-                task
-                for task in session.get_root_tasks(self)
-                if session.is_allowed_to_run(task)
-            ]
-            root_task_coros = [
-                run_async(root_task.exec_chain(session)) for root_task in root_tasks
-            ]
-            await asyncio.gather(*root_task_coros)
-            await session.wait_deferred()
-            session.terminate()
-            await log_state
-            return session.final_result
-        except IndexError:
-            return None
-        except (asyncio.CancelledError, KeyboardInterrupt):
-            ctx = self.get_ctx(session)
-            ctx.log_info("Session terminated")
-        finally:
-            session.terminate()
-            session.state_logger.write(session.as_state_log())
-            ctx = self.get_ctx(session)
-            ctx.log_debug(session)
-
-    async def _log_session_state(self, session: AnySession):
-        session.state_logger.write(session.as_state_log())
-        while not session.is_terminated:
-            await asyncio.sleep(0.1)
-            session.state_logger.write(session.as_state_log())
-        session.state_logger.write(session.as_state_log())
+        return await execute_root_tasks(self, session)
 
     async def exec_chain(self, session: AnySession):
-
-            return
-        result = await self.exec(session)
-        # Get next tasks
-        nexts = session.get_next_tasks(self)
-        if session.is_terminated or len(nexts) == 0:
-            return result
-        # Run next tasks asynchronously
-        next_coros = [run_async(next.exec_chain(session)) for next in nexts]
-        return await asyncio.gather(*next_coros)
+        return await execute_task_chain(self, session)
 
     async def exec(self, session: AnySession):
-
-        if not session.is_allowed_to_run(self):
-            # Task is not allowed to run, skip it for now.
-            # This will be triggered later
-            ctx.log_info("Not allowed to run")
-            return
-        if not self.__get_execute_condition(session):
-            # Skip the task
-            ctx.log_info("Marked as skipped")
-            session.get_task_status(self).mark_as_skipped()
-            return
-        # Wait for task to be ready
-        return await run_async(self.__exec_action_until_ready(session))
-
-    def __get_execute_condition(self, session: Session) -> bool:
-        ctx = self.get_ctx(session)
-        return get_bool_attr(ctx, self._execute_condition, True, auto_render=True)
-
-    async def __exec_action_until_ready(self, session: AnySession):
-        ctx = self.get_ctx(session)
-        readiness_checks = self.readiness_checks
-        if len(readiness_checks) == 0:
-            ctx.log_info("No readiness checks")
-            # Task has no readiness check
-            result = await run_async(self.__exec_action_and_retry(session))
-            ctx.log_info("Marked as ready")
-            session.get_task_status(self).mark_as_ready()
-            return result
-        # Start the task along with the readiness checks
-        action_coro = asyncio.create_task(
-            run_async(self.__exec_action_and_retry(session))
-        )
-        await asyncio.sleep(self._readiness_check_delay)
-        readiness_check_coros = [
-            run_async(check.exec_chain(session)) for check in readiness_checks
-        ]
-        # Only wait for readiness checks and mark the task as ready
-        ctx.log_info("Start readiness checks")
-        result = await asyncio.gather(*readiness_check_coros)
-        ctx.log_info("Readiness checks completed")
-        ctx.log_info("Marked as ready")
-        session.get_task_status(self).mark_as_ready()
-        # Defer task's coroutines, will be waited later
-        session.defer_action(self, action_coro)
-        if self._monitor_readiness:
-            monitor_and_rerun_coro = asyncio.create_task(
-                run_async(self.__exec_monitoring(session, action_coro))
-            )
-            session.defer_monitoring(self, monitor_and_rerun_coro)
-        return result
-
-    async def __exec_monitoring(self, session: AnySession, action_coro: asyncio.Task):
-        readiness_checks = self.readiness_checks
-        failure_count = 0
-        ctx = self.get_ctx(session)
-        while not session.is_terminated:
-            await asyncio.sleep(self._readiness_check_period)
-            if failure_count < self._readiness_failure_threshold:
-                for readiness_check in readiness_checks:
-                    session.get_task_status(readiness_check).reset_history()
-                    session.get_task_status(readiness_check).reset()
-                readiness_xcom: Xcom = ctx.xcom[self.name]
-                readiness_xcom.clear()
-                readiness_check_coros = [
-                    check.exec_chain(session) for check in readiness_checks
-                ]
-                try:
-                    ctx.log_info("Checking")
-                    await asyncio.wait_for(
-                        asyncio.gather(*readiness_check_coros),
-                        timeout=self._readiness_timeout,
-                    )
-                    ctx.log_info("OK")
-                    continue
-                except (asyncio.CancelledError, KeyboardInterrupt):
-                    for readiness_check in readiness_checks:
-                        ctx.log_info("Marked as failed")
-                        session.get_task_status(readiness_check).mark_as_failed()
-                except asyncio.TimeoutError:
-                    failure_count += 1
-                    ctx.log_info("Detecting failure")
-                    ctx.log_debug(f"Failure count: {failure_count}")
-            # Readiness check failed, reset
-            ctx.log_info("Resetting")
-            action_coro.cancel()
-            session.get_task_status(self).reset()
-            # defer this action
-            ctx.log_info("Running")
-            action_coro: Coroutine = asyncio.create_task(
-                run_async(self.__exec_action_and_retry(session))
-            )
-            session.defer_action(self, action_coro)
-            failure_count = 0
-            ctx.log_info("Continue monitoring")
-
-    async def __exec_action_and_retry(self, session: AnySession) -> Any:
-        """
-        Executes an action with retry logic.
-
-        This method attempts to execute the action defined in `_exec_action` with a specified number of retries.
-        If the action fails, it will retry after a specified period until the maximum number of attempts is reached.
-        If the action succeeds, it marks the task as completed and executes any successors.
-        If the action fails permanently, it marks the task as permanently failed and executes any fallbacks.
-
-        Args:
-            session (AnySession): The session object containing the task status and context.
-
-        Returns:
-            Any: The result of the executed action if successful.
-
-        Raises:
-            Exception: If the action fails permanently after all retry attempts.
-        """
-        ctx = self.get_ctx(session)
-        max_attempt = self._retries + 1
-        ctx.set_max_attempt(max_attempt)
-        for attempt in range(max_attempt):
-            ctx.set_attempt(attempt + 1)
-            if attempt > 0:
-                # apply retry period only if this is not the first attempt
-                await asyncio.sleep(self._retry_period)
-            try:
-                ctx.log_info("Marked as started")
-                session.get_task_status(self).mark_as_started()
-                result = await run_async(self._exec_action(ctx))
-                ctx.log_info("Marked as completed")
-                session.get_task_status(self).mark_as_completed()
-                # Put result on xcom
-                task_xcom: Xcom = ctx.xcom.get(self.name)
-                task_xcom.push(result)
-                self.__skip_fallbacks(session)
-                await run_async(self.__exec_successors(session))
-                return result
-            except (asyncio.CancelledError, KeyboardInterrupt):
-                ctx.log_info("Marked as failed")
-                session.get_task_status(self).mark_as_failed()
-                return
-            except BaseException as e:
-                ctx.log_error(e)
-                if attempt < max_attempt - 1:
-                    ctx.log_info("Marked as failed")
-                    session.get_task_status(self).mark_as_failed()
-                    continue
-                ctx.log_info("Marked as permanently failed")
-                session.get_task_status(self).mark_as_permanently_failed()
-                self.__skip_successors(session)
-                await run_async(self.__exec_fallbacks(session))
-                raise e
-
-    async def __exec_successors(self, session: AnySession) -> Any:
-        successors: list[AnyTask] = self.successors
-        successor_coros = [
-            run_async(successor.exec_chain(session)) for successor in successors
-        ]
-        await asyncio.gather(*successor_coros)
-
-    def __skip_successors(self, session: AnySession) -> Any:
-        for successor in self.successors:
-            session.get_task_status(successor).mark_as_skipped()
-
-    async def __exec_fallbacks(self, session: AnySession) -> Any:
-        fallbacks: list[AnyTask] = self.fallbacks
-        fallback_coros = [
-            run_async(fallback.exec_chain(session)) for fallback in fallbacks
-        ]
-        await asyncio.gather(*fallback_coros)
-
-    def __skip_fallbacks(self, session: AnySession) -> Any:
-        for fallback in self.fallbacks:
-            session.get_task_status(fallback).mark_as_skipped()
+        return await execute_task_action(self, session)
 
     async def _exec_action(self, ctx: AnyContext) -> Any:
-        """
-
-
-
+        """
+        Execute the main action of the task.
+        This is the primary method to override in subclasses for custom action logic.
+        The default implementation handles the '_action' attribute (string or callable).
 
         Args:
-
+            ctx (AnyContext): The execution context for this task.
 
         Returns:
             Any: The result of the action execution.
         """
-
-
-        if isinstance(self._action, str):
-            return ctx.render(self._action)
-        return await run_async(self._action(ctx))
+        # Delegate to the helper function for the default behavior
+        return await run_default_action(self, ctx)
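The `run()` docstring added above warns against calling it from inside a running event loop. Below is a minimal usage sketch under that assumption; the task constructor arguments are illustrative only and are not taken from this diff.

import asyncio

from zrb.task.base_task import BaseTask

# Hypothetical task, for illustration only.
hello_task = BaseTask(name="hello", action=lambda ctx: "hello world")

# Synchronous code: run() wraps execution in asyncio.run(), so it must not be
# called while another event loop is already running.
print(hello_task.run())


# Asynchronous code: use async_run() instead; it delegates to run_task_async()
# without creating a second event loop.
async def main():
    print(await hello_task.async_run())


asyncio.run(main())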
zrb/task/llm/context_enrichment.py
CHANGED
@@ -10,6 +10,7 @@ from pydantic_ai.settings import ModelSettings
 
 from zrb.attr.type import BoolAttr
 from zrb.context.any_context import AnyContext
+from zrb.llm_config import llm_config
 from zrb.task.llm.agent import run_agent_iteration
 from zrb.task.llm.typing import ListOfDict
 from zrb.util.attr import get_bool_attr
@@ -90,16 +91,18 @@ async def enrich_context(
 def should_enrich_context(
     ctx: AnyContext,
     history_list: ListOfDict,
-    should_enrich_context_attr: BoolAttr,
+    should_enrich_context_attr: BoolAttr | None,  # Allow None
     render_enrich_context: bool,
 ) -> bool:
     """Determines if context enrichment should occur based on history and config."""
     if len(history_list) == 0:
         return False
+    # Use llm_config default if attribute is None
+    default_value = llm_config.get_default_enrich_context()
     return get_bool_attr(
         ctx,
         should_enrich_context_attr,
-
+        default_value,  # Pass the default from llm_config
         auto_render=render_enrich_context,
     )
 
@@ -108,7 +111,7 @@ async def maybe_enrich_context(
     ctx: AnyContext,
     history_list: ListOfDict,
     conversation_context: dict[str, Any],
-    should_enrich_context_attr: BoolAttr,
+    should_enrich_context_attr: BoolAttr | None,  # Allow None
     render_enrich_context: bool,
     model: str | Model | None,
     model_settings: ModelSettings | None,
zrb/task/llm/history_summarization.py
CHANGED
@@ -8,6 +8,7 @@ from pydantic_ai.settings import ModelSettings
 
 from zrb.attr.type import BoolAttr, IntAttr
 from zrb.context.any_context import AnyContext
+from zrb.llm_config import llm_config
 from zrb.task.llm.agent import run_agent_iteration
 from zrb.task.llm.typing import ListOfDict
 from zrb.util.attr import get_bool_attr, get_int_attr
@@ -26,7 +27,7 @@ def get_history_part_len(history_list: ListOfDict) -> int:
 
 def get_history_summarization_threshold(
     ctx: AnyContext,
-    history_summarization_threshold_attr: IntAttr,
+    history_summarization_threshold_attr: IntAttr | None,
     render_history_summarization_threshold: bool,
 ) -> int:
     """Gets the history summarization threshold, handling defaults and errors."""
@@ -34,7 +35,8 @@ def get_history_summarization_threshold(
         return get_int_attr(
             ctx,
             history_summarization_threshold_attr,
-
+            # Use llm_config default if attribute is None
+            llm_config.get_default_history_summarization_threshold(),
             auto_render=render_history_summarization_threshold,
         )
     except ValueError as e:
@@ -48,9 +50,9 @@ def get_history_summarization_threshold(
 def should_summarize_history(
     ctx: AnyContext,
     history_list: ListOfDict,
-    should_summarize_history_attr: BoolAttr,
+    should_summarize_history_attr: BoolAttr | None,  # Allow None
     render_summarize_history: bool,
-    history_summarization_threshold_attr: IntAttr,
+    history_summarization_threshold_attr: IntAttr | None,  # Allow None
     render_history_summarization_threshold: bool,
 ) -> bool:
     """Determines if history summarization should occur based on length and config."""
@@ -69,7 +71,8 @@ def should_summarize_history(
     return get_bool_attr(
         ctx,
         should_summarize_history_attr,
-
+        # Use llm_config default if attribute is None
+        llm_config.get_default_summarize_history(),
         auto_render=render_summarize_history,
     )
 
@@ -137,9 +140,9 @@ async def maybe_summarize_history(
     ctx: AnyContext,
     history_list: ListOfDict,
     conversation_context: dict[str, Any],
-    should_summarize_history_attr: BoolAttr,
+    should_summarize_history_attr: BoolAttr | None,  # Allow None
     render_summarize_history: bool,
-    history_summarization_threshold_attr: IntAttr,
+    history_summarization_threshold_attr: IntAttr | None,  # Allow None
     render_history_summarization_threshold: bool,
     model: str | Model | None,
     model_settings: ModelSettings | None,