zrb 1.5.8__py3-none-any.whl → 1.5.10__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
zrb/task/base_task.py CHANGED
@@ -1,19 +1,25 @@
1
1
  import asyncio
2
- import os
3
- from collections.abc import Callable, Coroutine
2
+ from collections.abc import Callable
4
3
  from typing import Any
5
4
 
6
5
  from zrb.attr.type import BoolAttr, fstring
7
6
  from zrb.context.any_context import AnyContext
8
- from zrb.context.shared_context import AnySharedContext, SharedContext
9
7
  from zrb.env.any_env import AnyEnv
10
8
  from zrb.input.any_input import AnyInput
11
9
  from zrb.session.any_session import AnySession
12
- from zrb.session.session import Session
13
10
  from zrb.task.any_task import AnyTask
14
- from zrb.util.attr import get_bool_attr
15
- from zrb.util.run import run_async
16
- from zrb.xcom.xcom import Xcom
11
+ from zrb.task.base.context import (
12
+ build_task_context,
13
+ get_combined_envs,
14
+ get_combined_inputs,
15
+ )
16
+ from zrb.task.base.execution import (
17
+ execute_task_action,
18
+ execute_task_chain,
19
+ run_default_action,
20
+ )
21
+ from zrb.task.base.lifecycle import execute_root_tasks, run_and_cleanup, run_task_async
22
+ from zrb.task.base.operators import handle_lshift, handle_rshift
17
23
 
18
24
 
19
25
  class BaseTask(AnyTask):
@@ -62,25 +68,13 @@ class BaseTask(AnyTask):
62
68
  self._action = action
63
69
 
64
70
  def __repr__(self):
65
- return f"<{self.__class__.__name__} name={self._name}>"
71
+ return f"<{self.__class__.__name__} name={self.name}>"
66
72
 
67
73
  def __rshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask | list[AnyTask]:
68
- try:
69
- if isinstance(other, AnyTask):
70
- other.append_upstream(self)
71
- elif isinstance(other, list):
72
- for task in other:
73
- task.append_upstream(self)
74
- return other
75
- except Exception as e:
76
- raise ValueError(f"Invalid operation {self} >> {other}: {e}")
74
+ return handle_rshift(self, other)
77
75
 
78
76
  def __lshift__(self, other: AnyTask | list[AnyTask]) -> AnyTask:
79
- try:
80
- self.append_upstream(other)
81
- return self
82
- except Exception as e:
83
- raise ValueError(f"Invalid operation {self} << {other}: {e}")
77
+ return handle_lshift(self, other)
84
78
 
85
79
  @property
86
80
  def name(self) -> str:
@@ -104,430 +98,146 @@ class BaseTask(AnyTask):
104
98
 
105
99
  @property
106
100
  def envs(self) -> list[AnyEnv]:
107
- envs = []
108
- for upstream in self.upstreams:
109
- envs += upstream.envs
110
- if isinstance(self._envs, AnyEnv):
111
- envs.append(self._envs)
112
- elif self._envs is not None:
113
- envs += self._envs
114
- return [env for env in envs if env is not None]
101
+ return get_combined_envs(self)
115
102
 
116
103
  @property
117
104
  def inputs(self) -> list[AnyInput]:
118
- inputs = []
119
- for upstream in self.upstreams:
120
- self.__combine_inputs(inputs, upstream.inputs)
121
- if self._inputs is not None:
122
- self.__combine_inputs(inputs, self._inputs)
123
- return [task_input for task_input in inputs if task_input is not None]
124
-
125
- def __combine_inputs(
126
- self,
127
- inputs: list[AnyInput],
128
- other_inputs: list[AnyInput | None] | AnyInput | None,
129
- ):
130
- input_names = [task_input.name for task_input in inputs]
131
- if isinstance(other_inputs, AnyInput):
132
- other_inputs = [other_inputs]
133
- elif other_inputs is None:
134
- other_inputs = []
135
- for task_input in other_inputs:
136
- if task_input is None:
137
- continue
138
- if task_input.name not in input_names:
139
- inputs.append(task_input)
105
+ return get_combined_inputs(self)
140
106
 
141
107
  @property
142
108
  def fallbacks(self) -> list[AnyTask]:
109
+ """Returns the list of fallback tasks."""
143
110
  if self._fallbacks is None:
144
111
  return []
145
- elif isinstance(self._fallbacks, AnyTask):
146
- return [self._fallbacks]
147
- return self._fallbacks
112
+ elif isinstance(self._fallbacks, list):
113
+ return self._fallbacks
114
+ return [self._fallbacks] # Assume single task
148
115
 
149
116
  def append_fallback(self, fallbacks: AnyTask | list[AnyTask]):
150
- fallback_list = [fallbacks] if isinstance(fallbacks, AnyTask) else fallbacks
151
- for fallback in fallback_list:
152
- self.__append_fallback(fallback)
153
-
154
- def __append_fallback(self, fallback: AnyTask):
155
- # Make sure self._fallbacks is a list
117
+ """Appends fallback tasks, ensuring no duplicates."""
156
118
  if self._fallbacks is None:
157
119
  self._fallbacks = []
158
- elif isinstance(self._fallbacks, AnyTask):
120
+ elif not isinstance(self._fallbacks, list):
159
121
  self._fallbacks = [self._fallbacks]
160
- # Add fallback if it was not on self._fallbacks
161
- if fallback not in self._fallbacks:
162
- self._fallbacks.append(fallback)
122
+ to_add = fallbacks if isinstance(fallbacks, list) else [fallbacks]
123
+ for fb in to_add:
124
+ if fb not in self._fallbacks:
125
+ self._fallbacks.append(fb)
163
126
 
164
127
  @property
165
128
  def successors(self) -> list[AnyTask]:
129
+ """Returns the list of successor tasks."""
166
130
  if self._successors is None:
167
131
  return []
168
- elif isinstance(self._successors, AnyTask):
169
- return [self._successors]
170
- return self._successors
132
+ elif isinstance(self._successors, list):
133
+ return self._successors
134
+ return [self._successors] # Assume single task
171
135
 
172
136
  def append_successor(self, successors: AnyTask | list[AnyTask]):
173
- successor_list = [successors] if isinstance(successors, AnyTask) else successors
174
- for successor in successor_list:
175
- self.__append_successor(successor)
176
-
177
- def __append_successor(self, successor: AnyTask):
178
- # Make sure self._successors is a list
137
+ """Appends successor tasks, ensuring no duplicates."""
179
138
  if self._successors is None:
180
139
  self._successors = []
181
- elif isinstance(self._successors, AnyTask):
140
+ elif not isinstance(self._successors, list):
182
141
  self._successors = [self._successors]
183
- # Add successor if it was not on self._successors
184
- if successor not in self._successors:
185
- self._successors.append(successor)
142
+ to_add = successors if isinstance(successors, list) else [successors]
143
+ for succ in to_add:
144
+ if succ not in self._successors:
145
+ self._successors.append(succ)
186
146
 
187
147
  @property
188
148
  def readiness_checks(self) -> list[AnyTask]:
149
+ """Returns the list of readiness check tasks."""
189
150
  if self._readiness_checks is None:
190
151
  return []
191
- elif isinstance(self._readiness_checks, AnyTask):
192
- return [self._readiness_checks]
193
- return self._readiness_checks
152
+ elif isinstance(self._readiness_checks, list):
153
+ return self._readiness_checks
154
+ return [self._readiness_checks] # Assume single task
194
155
 
195
156
  def append_readiness_check(self, readiness_checks: AnyTask | list[AnyTask]):
196
- readiness_check_list = (
197
- [readiness_checks]
198
- if isinstance(readiness_checks, AnyTask)
199
- else readiness_checks
200
- )
201
- for readiness_check in readiness_check_list:
202
- self.__append_readiness_check(readiness_check)
203
-
204
- def __append_readiness_check(self, readiness_check: AnyTask):
205
- # Make sure self._readiness_checks is a list
157
+ """Appends readiness check tasks, ensuring no duplicates."""
206
158
  if self._readiness_checks is None:
207
159
  self._readiness_checks = []
208
- elif isinstance(self._readiness_checks, AnyTask):
160
+ elif not isinstance(self._readiness_checks, list):
209
161
  self._readiness_checks = [self._readiness_checks]
210
- # Add readiness_check if it was not on self._readiness_checks
211
- if readiness_check not in self._readiness_checks:
212
- self._readiness_checks.append(readiness_check)
162
+ to_add = (
163
+ readiness_checks
164
+ if isinstance(readiness_checks, list)
165
+ else [readiness_checks]
166
+ )
167
+ for rc in to_add:
168
+ if rc not in self._readiness_checks:
169
+ self._readiness_checks.append(rc)
213
170
 
214
171
  @property
215
172
  def upstreams(self) -> list[AnyTask]:
173
+ """Returns the list of upstream tasks."""
216
174
  if self._upstreams is None:
217
175
  return []
218
- elif isinstance(self._upstreams, AnyTask):
219
- return [self._upstreams]
220
- return self._upstreams
176
+ elif isinstance(self._upstreams, list):
177
+ return self._upstreams
178
+ return [self._upstreams] # Assume single task
221
179
 
222
180
  def append_upstream(self, upstreams: AnyTask | list[AnyTask]):
223
- upstream_list = [upstreams] if isinstance(upstreams, AnyTask) else upstreams
224
- for upstream in upstream_list:
225
- self.__append_upstream(upstream)
226
-
227
- def __append_upstream(self, upstream: AnyTask):
228
- # Make sure self._upstreams is a list
181
+ """Appends upstream tasks, ensuring no duplicates."""
229
182
  if self._upstreams is None:
230
183
  self._upstreams = []
231
- elif isinstance(self._upstreams, AnyTask):
184
+ elif not isinstance(self._upstreams, list):
232
185
  self._upstreams = [self._upstreams]
233
- # Add upstream if it was not on self._upstreams
234
- if upstream not in self._upstreams:
235
- self._upstreams.append(upstream)
186
+ to_add = upstreams if isinstance(upstreams, list) else [upstreams]
187
+ for up in to_add:
188
+ if up not in self._upstreams:
189
+ self._upstreams.append(up)
236
190
 
237
191
  def get_ctx(self, session: AnySession) -> AnyContext:
238
- ctx = session.get_ctx(self)
239
- # Enhance session ctx with current task env
240
- for env in self.envs:
241
- env.update_context(ctx)
242
- return ctx
192
+ return build_task_context(self, session)
243
193
 
244
194
  def run(
245
195
  self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
246
196
  ) -> Any:
247
- return asyncio.run(self._run_and_cleanup(session, str_kwargs))
197
+ """
198
+ Synchronously runs the task and its dependencies, handling async setup and cleanup.
248
199
 
249
- async def _run_and_cleanup(
250
- self,
251
- session: AnySession | None = None,
252
- str_kwargs: dict[str, str] = {},
253
- ) -> Any:
254
- current_task = asyncio.create_task(self.async_run(session, str_kwargs))
255
- try:
256
- result = await current_task
257
- finally:
258
- if session and not session.is_terminated:
259
- session.terminate()
260
- # Cancel all running tasks except the current one
261
- pending = [task for task in asyncio.all_tasks() if task is not current_task]
262
- for task in pending:
263
- task.cancel()
264
- if pending:
265
- try:
266
- await asyncio.wait(pending, timeout=5)
267
- except asyncio.CancelledError:
268
- pass
269
- return result
200
+ Uses `asyncio.run()` internally, which creates a new event loop.
201
+ WARNING: Do not call this method from within an already running asyncio
202
+ event loop, as it will raise a RuntimeError. Use `async_run` instead
203
+ if you are in an async context.
204
+
205
+ Args:
206
+ session (AnySession | None): The session to use. If None, a new one
207
+ might be created implicitly.
208
+ str_kwargs (dict[str, str]): String-based key-value arguments for inputs.
209
+
210
+ Returns:
211
+ Any: The final result of the main task execution.
212
+ """
213
+ # Use asyncio.run() to execute the async cleanup wrapper
214
+ return asyncio.run(run_and_cleanup(self, session, str_kwargs))
270
215
 
271
216
  async def async_run(
272
217
  self, session: AnySession | None = None, str_kwargs: dict[str, str] = {}
273
218
  ) -> Any:
274
- if session is None:
275
- session = Session(shared_ctx=SharedContext())
276
- # Update session
277
- self.__fill_shared_context_inputs(session.shared_ctx, str_kwargs)
278
- self.__fill_shared_context_envs(session.shared_ctx)
279
- result = await run_async(self.exec_root_tasks(session))
280
- return result
281
-
282
- def __fill_shared_context_inputs(
283
- self, shared_context: AnySharedContext, str_kwargs: dict[str, str] = {}
284
- ):
285
- for task_input in self.inputs:
286
- if task_input.name not in shared_context.input:
287
- str_value = str_kwargs.get(task_input.name, None)
288
- task_input.update_shared_context(shared_context, str_value)
289
-
290
- def __fill_shared_context_envs(self, shared_context: AnySharedContext):
291
- # Inject os environ
292
- os_env_map = {
293
- key: val for key, val in os.environ.items() if key not in shared_context.env
294
- }
295
- shared_context.env.update(os_env_map)
219
+ return await run_task_async(self, session, str_kwargs)
296
220
 
297
221
  async def exec_root_tasks(self, session: AnySession):
298
- session.set_main_task(self)
299
- session.state_logger.write(session.as_state_log())
300
- try:
301
- log_state = asyncio.create_task(self._log_session_state(session))
302
- root_tasks = [
303
- task
304
- for task in session.get_root_tasks(self)
305
- if session.is_allowed_to_run(task)
306
- ]
307
- root_task_coros = [
308
- run_async(root_task.exec_chain(session)) for root_task in root_tasks
309
- ]
310
- await asyncio.gather(*root_task_coros)
311
- await session.wait_deferred()
312
- session.terminate()
313
- await log_state
314
- return session.final_result
315
- except IndexError:
316
- return None
317
- except (asyncio.CancelledError, KeyboardInterrupt):
318
- ctx = self.get_ctx(session)
319
- ctx.log_info("Session terminated")
320
- finally:
321
- session.terminate()
322
- session.state_logger.write(session.as_state_log())
323
- ctx = self.get_ctx(session)
324
- ctx.log_debug(session)
325
-
326
- async def _log_session_state(self, session: AnySession):
327
- session.state_logger.write(session.as_state_log())
328
- while not session.is_terminated:
329
- await asyncio.sleep(0.1)
330
- session.state_logger.write(session.as_state_log())
331
- session.state_logger.write(session.as_state_log())
222
+ return await execute_root_tasks(self, session)
332
223
 
333
224
  async def exec_chain(self, session: AnySession):
334
- if session.is_terminated or not session.is_allowed_to_run(self):
335
- return
336
- result = await self.exec(session)
337
- # Get next tasks
338
- nexts = session.get_next_tasks(self)
339
- if session.is_terminated or len(nexts) == 0:
340
- return result
341
- # Run next tasks asynchronously
342
- next_coros = [run_async(next.exec_chain(session)) for next in nexts]
343
- return await asyncio.gather(*next_coros)
225
+ return await execute_task_chain(self, session)
344
226
 
345
227
  async def exec(self, session: AnySession):
346
- ctx = self.get_ctx(session)
347
- if not session.is_allowed_to_run(self):
348
- # Task is not allowed to run, skip it for now.
349
- # This will be triggered later
350
- ctx.log_info("Not allowed to run")
351
- return
352
- if not self.__get_execute_condition(session):
353
- # Skip the task
354
- ctx.log_info("Marked as skipped")
355
- session.get_task_status(self).mark_as_skipped()
356
- return
357
- # Wait for task to be ready
358
- return await run_async(self.__exec_action_until_ready(session))
359
-
360
- def __get_execute_condition(self, session: Session) -> bool:
361
- ctx = self.get_ctx(session)
362
- return get_bool_attr(ctx, self._execute_condition, True, auto_render=True)
363
-
364
- async def __exec_action_until_ready(self, session: AnySession):
365
- ctx = self.get_ctx(session)
366
- readiness_checks = self.readiness_checks
367
- if len(readiness_checks) == 0:
368
- ctx.log_info("No readiness checks")
369
- # Task has no readiness check
370
- result = await run_async(self.__exec_action_and_retry(session))
371
- ctx.log_info("Marked as ready")
372
- session.get_task_status(self).mark_as_ready()
373
- return result
374
- # Start the task along with the readiness checks
375
- action_coro = asyncio.create_task(
376
- run_async(self.__exec_action_and_retry(session))
377
- )
378
- await asyncio.sleep(self._readiness_check_delay)
379
- readiness_check_coros = [
380
- run_async(check.exec_chain(session)) for check in readiness_checks
381
- ]
382
- # Only wait for readiness checks and mark the task as ready
383
- ctx.log_info("Start readiness checks")
384
- result = await asyncio.gather(*readiness_check_coros)
385
- ctx.log_info("Readiness checks completed")
386
- ctx.log_info("Marked as ready")
387
- session.get_task_status(self).mark_as_ready()
388
- # Defer task's coroutines, will be waited later
389
- session.defer_action(self, action_coro)
390
- if self._monitor_readiness:
391
- monitor_and_rerun_coro = asyncio.create_task(
392
- run_async(self.__exec_monitoring(session, action_coro))
393
- )
394
- session.defer_monitoring(self, monitor_and_rerun_coro)
395
- return result
396
-
397
- async def __exec_monitoring(self, session: AnySession, action_coro: asyncio.Task):
398
- readiness_checks = self.readiness_checks
399
- failure_count = 0
400
- ctx = self.get_ctx(session)
401
- while not session.is_terminated:
402
- await asyncio.sleep(self._readiness_check_period)
403
- if failure_count < self._readiness_failure_threshold:
404
- for readiness_check in readiness_checks:
405
- session.get_task_status(readiness_check).reset_history()
406
- session.get_task_status(readiness_check).reset()
407
- readiness_xcom: Xcom = ctx.xcom[self.name]
408
- readiness_xcom.clear()
409
- readiness_check_coros = [
410
- check.exec_chain(session) for check in readiness_checks
411
- ]
412
- try:
413
- ctx.log_info("Checking")
414
- await asyncio.wait_for(
415
- asyncio.gather(*readiness_check_coros),
416
- timeout=self._readiness_timeout,
417
- )
418
- ctx.log_info("OK")
419
- continue
420
- except (asyncio.CancelledError, KeyboardInterrupt):
421
- for readiness_check in readiness_checks:
422
- ctx.log_info("Marked as failed")
423
- session.get_task_status(readiness_check).mark_as_failed()
424
- except asyncio.TimeoutError:
425
- failure_count += 1
426
- ctx.log_info("Detecting failure")
427
- ctx.log_debug(f"Failure count: {failure_count}")
428
- # Readiness check failed, reset
429
- ctx.log_info("Resetting")
430
- action_coro.cancel()
431
- session.get_task_status(self).reset()
432
- # defer this action
433
- ctx.log_info("Running")
434
- action_coro: Coroutine = asyncio.create_task(
435
- run_async(self.__exec_action_and_retry(session))
436
- )
437
- session.defer_action(self, action_coro)
438
- failure_count = 0
439
- ctx.log_info("Continue monitoring")
440
-
441
- async def __exec_action_and_retry(self, session: AnySession) -> Any:
442
- """
443
- Executes an action with retry logic.
444
-
445
- This method attempts to execute the action defined in `_exec_action` with a specified number of retries.
446
- If the action fails, it will retry after a specified period until the maximum number of attempts is reached.
447
- If the action succeeds, it marks the task as completed and executes any successors.
448
- If the action fails permanently, it marks the task as permanently failed and executes any fallbacks.
449
-
450
- Args:
451
- session (AnySession): The session object containing the task status and context.
452
-
453
- Returns:
454
- Any: The result of the executed action if successful.
455
-
456
- Raises:
457
- Exception: If the action fails permanently after all retry attempts.
458
- """
459
- ctx = self.get_ctx(session)
460
- max_attempt = self._retries + 1
461
- ctx.set_max_attempt(max_attempt)
462
- for attempt in range(max_attempt):
463
- ctx.set_attempt(attempt + 1)
464
- if attempt > 0:
465
- # apply retry period only if this is not the first attempt
466
- await asyncio.sleep(self._retry_period)
467
- try:
468
- ctx.log_info("Marked as started")
469
- session.get_task_status(self).mark_as_started()
470
- result = await run_async(self._exec_action(ctx))
471
- ctx.log_info("Marked as completed")
472
- session.get_task_status(self).mark_as_completed()
473
- # Put result on xcom
474
- task_xcom: Xcom = ctx.xcom.get(self.name)
475
- task_xcom.push(result)
476
- self.__skip_fallbacks(session)
477
- await run_async(self.__exec_successors(session))
478
- return result
479
- except (asyncio.CancelledError, KeyboardInterrupt):
480
- ctx.log_info("Marked as failed")
481
- session.get_task_status(self).mark_as_failed()
482
- return
483
- except BaseException as e:
484
- ctx.log_error(e)
485
- if attempt < max_attempt - 1:
486
- ctx.log_info("Marked as failed")
487
- session.get_task_status(self).mark_as_failed()
488
- continue
489
- ctx.log_info("Marked as permanently failed")
490
- session.get_task_status(self).mark_as_permanently_failed()
491
- self.__skip_successors(session)
492
- await run_async(self.__exec_fallbacks(session))
493
- raise e
494
-
495
- async def __exec_successors(self, session: AnySession) -> Any:
496
- successors: list[AnyTask] = self.successors
497
- successor_coros = [
498
- run_async(successor.exec_chain(session)) for successor in successors
499
- ]
500
- await asyncio.gather(*successor_coros)
501
-
502
- def __skip_successors(self, session: AnySession) -> Any:
503
- for successor in self.successors:
504
- session.get_task_status(successor).mark_as_skipped()
505
-
506
- async def __exec_fallbacks(self, session: AnySession) -> Any:
507
- fallbacks: list[AnyTask] = self.fallbacks
508
- fallback_coros = [
509
- run_async(fallback.exec_chain(session)) for fallback in fallbacks
510
- ]
511
- await asyncio.gather(*fallback_coros)
512
-
513
- def __skip_fallbacks(self, session: AnySession) -> Any:
514
- for fallback in self.fallbacks:
515
- session.get_task_status(fallback).mark_as_skipped()
228
+ return await execute_task_action(self, session)
516
229
 
517
230
  async def _exec_action(self, ctx: AnyContext) -> Any:
518
- """Execute the main action of the task.
519
- By default will render and run the _action attribute.
520
-
521
- Override this method to define custom action.
231
+ """
232
+ Execute the main action of the task.
233
+ This is the primary method to override in subclasses for custom action logic.
234
+ The default implementation handles the '_action' attribute (string or callable).
522
235
 
523
236
  Args:
524
- session (AnySession): The shared session.
237
+ ctx (AnyContext): The execution context for this task.
525
238
 
526
239
  Returns:
527
240
  Any: The result of the action execution.
528
241
  """
529
- if self._action is None:
530
- return
531
- if isinstance(self._action, str):
532
- return ctx.render(self._action)
533
- return await run_async(self._action(ctx))
242
+ # Delegate to the helper function for the default behavior
243
+ return await run_default_action(self, ctx)
@@ -10,6 +10,7 @@ from pydantic_ai.settings import ModelSettings
10
10
 
11
11
  from zrb.attr.type import BoolAttr
12
12
  from zrb.context.any_context import AnyContext
13
+ from zrb.llm_config import llm_config
13
14
  from zrb.task.llm.agent import run_agent_iteration
14
15
  from zrb.task.llm.typing import ListOfDict
15
16
  from zrb.util.attr import get_bool_attr
@@ -90,16 +91,18 @@ async def enrich_context(
90
91
  def should_enrich_context(
91
92
  ctx: AnyContext,
92
93
  history_list: ListOfDict,
93
- should_enrich_context_attr: BoolAttr,
94
+ should_enrich_context_attr: BoolAttr | None, # Allow None
94
95
  render_enrich_context: bool,
95
96
  ) -> bool:
96
97
  """Determines if context enrichment should occur based on history and config."""
97
98
  if len(history_list) == 0:
98
99
  return False
100
+ # Use llm_config default if attribute is None
101
+ default_value = llm_config.get_default_enrich_context()
99
102
  return get_bool_attr(
100
103
  ctx,
101
104
  should_enrich_context_attr,
102
- True, # Default to True if not specified
105
+ default_value, # Pass the default from llm_config
103
106
  auto_render=render_enrich_context,
104
107
  )
105
108
 
@@ -108,7 +111,7 @@ async def maybe_enrich_context(
108
111
  ctx: AnyContext,
109
112
  history_list: ListOfDict,
110
113
  conversation_context: dict[str, Any],
111
- should_enrich_context_attr: BoolAttr,
114
+ should_enrich_context_attr: BoolAttr | None, # Allow None
112
115
  render_enrich_context: bool,
113
116
  model: str | Model | None,
114
117
  model_settings: ModelSettings | None,
@@ -8,6 +8,7 @@ from pydantic_ai.settings import ModelSettings
8
8
 
9
9
  from zrb.attr.type import BoolAttr, IntAttr
10
10
  from zrb.context.any_context import AnyContext
11
+ from zrb.llm_config import llm_config
11
12
  from zrb.task.llm.agent import run_agent_iteration
12
13
  from zrb.task.llm.typing import ListOfDict
13
14
  from zrb.util.attr import get_bool_attr, get_int_attr
@@ -26,7 +27,7 @@ def get_history_part_len(history_list: ListOfDict) -> int:
26
27
 
27
28
  def get_history_summarization_threshold(
28
29
  ctx: AnyContext,
29
- history_summarization_threshold_attr: IntAttr,
30
+ history_summarization_threshold_attr: IntAttr | None,
30
31
  render_history_summarization_threshold: bool,
31
32
  ) -> int:
32
33
  """Gets the history summarization threshold, handling defaults and errors."""
@@ -34,7 +35,8 @@ def get_history_summarization_threshold(
34
35
  return get_int_attr(
35
36
  ctx,
36
37
  history_summarization_threshold_attr,
37
- -1, # Default to -1 (no threshold)
38
+ # Use llm_config default if attribute is None
39
+ llm_config.get_default_history_summarization_threshold(),
38
40
  auto_render=render_history_summarization_threshold,
39
41
  )
40
42
  except ValueError as e:
@@ -48,9 +50,9 @@ def get_history_summarization_threshold(
48
50
  def should_summarize_history(
49
51
  ctx: AnyContext,
50
52
  history_list: ListOfDict,
51
- should_summarize_history_attr: BoolAttr,
53
+ should_summarize_history_attr: BoolAttr | None, # Allow None
52
54
  render_summarize_history: bool,
53
- history_summarization_threshold_attr: IntAttr,
55
+ history_summarization_threshold_attr: IntAttr | None, # Allow None
54
56
  render_history_summarization_threshold: bool,
55
57
  ) -> bool:
56
58
  """Determines if history summarization should occur based on length and config."""
@@ -69,7 +71,8 @@ def should_summarize_history(
69
71
  return get_bool_attr(
70
72
  ctx,
71
73
  should_summarize_history_attr,
72
- False, # Default to False if not specified
74
+ # Use llm_config default if attribute is None
75
+ llm_config.get_default_summarize_history(),
73
76
  auto_render=render_summarize_history,
74
77
  )
75
78
 
@@ -137,9 +140,9 @@ async def maybe_summarize_history(
137
140
  ctx: AnyContext,
138
141
  history_list: ListOfDict,
139
142
  conversation_context: dict[str, Any],
140
- should_summarize_history_attr: BoolAttr,
143
+ should_summarize_history_attr: BoolAttr | None, # Allow None
141
144
  render_summarize_history: bool,
142
- history_summarization_threshold_attr: IntAttr,
145
+ history_summarization_threshold_attr: IntAttr | None, # Allow None
143
146
  render_history_summarization_threshold: bool,
144
147
  model: str | Model | None,
145
148
  model_settings: ModelSettings | None,