zrb 1.5.4__py3-none-any.whl → 1.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zrb/task/llm_task.py CHANGED
@@ -1,28 +1,15 @@
1
- import functools
2
1
  import inspect
3
2
  import json
4
- import os
5
- import traceback
6
3
  from collections.abc import Callable
4
+ from textwrap import dedent
7
5
  from typing import Any
8
6
 
9
- from openai import APIError
10
7
  from pydantic_ai import Agent, Tool
11
8
  from pydantic_ai.mcp import MCPServer
12
- from pydantic_ai.messages import (
13
- FinalResultEvent,
14
- FunctionToolCallEvent,
15
- FunctionToolResultEvent,
16
- ModelMessagesTypeAdapter,
17
- PartDeltaEvent,
18
- PartStartEvent,
19
- TextPartDelta,
20
- ToolCallPartDelta,
21
- )
22
9
  from pydantic_ai.models import Model
23
10
  from pydantic_ai.settings import ModelSettings
24
11
 
25
- from zrb.attr.type import StrAttr, fstring
12
+ from zrb.attr.type import BoolAttr, IntAttr, StrAttr, fstring
26
13
  from zrb.context.any_context import AnyContext
27
14
  from zrb.context.any_shared_context import AnySharedContext
28
15
  from zrb.env.any_env import AnyEnv
@@ -31,12 +18,18 @@ from zrb.llm_config import LLMConfig
31
18
  from zrb.llm_config import llm_config as default_llm_config
32
19
  from zrb.task.any_task import AnyTask
33
20
  from zrb.task.base_task import BaseTask
34
- from zrb.util.attr import get_attr, get_str_attr
35
- from zrb.util.cli.style import stylize_faint
36
- from zrb.util.file import read_file, write_file
21
+ from zrb.task.llm.agent_runner import run_agent_iteration
22
+ from zrb.task.llm.context_enricher import EnrichmentConfig, enrich_context
23
+ from zrb.task.llm.default_context import get_default_context
24
+ from zrb.task.llm.history import ConversationHistoryData, ListOfDict
25
+ from zrb.task.llm.history_summarizer import SummarizationConfig, summarize_history
26
+ from zrb.task.llm.tool_wrapper import wrap_tool
27
+ from zrb.util.attr import get_attr, get_bool_attr, get_int_attr, get_str_attr
28
+ from zrb.util.file import write_file
37
29
  from zrb.util.run import run_async
38
30
 
39
- ListOfDict = list[dict[str, Any]]
31
+ # ListOfDict moved to history.py
32
+ # Removed old ConversationHistoryData type alias
40
33
  ToolOrCallable = Tool | Callable
41
34
 
42
35
 
@@ -65,6 +58,12 @@ class LLMTask(BaseTask):
65
58
  system_prompt: StrAttr | None = None,
66
59
  render_system_prompt: bool = True,
67
60
  message: StrAttr | None = None,
61
+ summarization_prompt: StrAttr | None = None,
62
+ render_summarization_prompt: bool = True,
63
+ enrich_context: BoolAttr = False,
64
+ render_enrich_context: bool = True,
65
+ context_enrichment_prompt: StrAttr | None = None,
66
+ render_context_enrichment_prompt: bool = True,
68
67
  tools: (
69
68
  list[ToolOrCallable] | Callable[[AnySharedContext], list[ToolOrCallable]]
70
69
  ) = [],
@@ -72,16 +71,29 @@ class LLMTask(BaseTask):
72
71
  list[MCPServer] | Callable[[AnySharedContext], list[MCPServer]]
73
72
  ) = [],
74
73
  conversation_history: (
75
- ListOfDict | Callable[[AnySharedContext], ListOfDict]
76
- ) = [],
74
+ ConversationHistoryData # Use the new BaseModel
75
+ | Callable[
76
+ [AnySharedContext], ConversationHistoryData | dict | list
77
+ ] # Allow returning raw dict/list too
78
+ | dict # Allow raw dict
79
+ | list # Allow raw list (old format)
80
+ ) = ConversationHistoryData(), # Default to an empty model instance
77
81
  conversation_history_reader: (
78
- Callable[[AnySharedContext], ListOfDict] | None
82
+ Callable[[AnySharedContext], ConversationHistoryData | dict | list | None]
83
+ | None
84
+ # Allow returning raw dict/list or None
79
85
  ) = None,
80
86
  conversation_history_writer: (
81
- Callable[[AnySharedContext, ListOfDict], None] | None
87
+ Callable[[AnySharedContext, ConversationHistoryData], None]
88
+ | None
89
+ # Writer expects the model instance
82
90
  ) = None,
83
91
  conversation_history_file: StrAttr | None = None,
84
92
  render_history_file: bool = True,
93
+ summarize_history: BoolAttr = True,
94
+ render_summarize_history: bool = True,
95
+ history_summarization_threshold: IntAttr = 5, # -1 means no summarization trigger
96
+ render_history_summarization_threshold: bool = True,
85
97
  execute_condition: bool | str | Callable[[AnySharedContext], bool] = True,
86
98
  retries: int = 2,
87
99
  retry_period: float = 0,
@@ -95,6 +107,9 @@ class LLMTask(BaseTask):
95
107
  upstream: list[AnyTask] | AnyTask | None = None,
96
108
  fallback: list[AnyTask] | AnyTask | None = None,
97
109
  successor: list[AnyTask] | AnyTask | None = None,
110
+ conversation_context: (
111
+ dict[str, Any] | Callable[[AnySharedContext], dict[str, Any]] | None
112
+ ) = None,
98
113
  ):
99
114
  super().__init__(
100
115
  name=name,
@@ -128,6 +143,12 @@ class LLMTask(BaseTask):
128
143
  self._system_prompt = system_prompt
129
144
  self._render_system_prompt = render_system_prompt
130
145
  self._message = message
146
+ self._summarization_prompt = summarization_prompt
147
+ self._render_summarization_prompt = render_summarization_prompt
148
+ self._should_enrich_context = enrich_context
149
+ self._render_enrich_context = render_enrich_context
150
+ self._context_enrichment_prompt = context_enrichment_prompt
151
+ self._render_context_enrichment_prompt = render_context_enrichment_prompt
131
152
  self._tools = tools
132
153
  self._additional_tools: list[ToolOrCallable] = []
133
154
  self._mcp_servers = mcp_servers
@@ -137,7 +158,14 @@ class LLMTask(BaseTask):
137
158
  self._conversation_history_writer = conversation_history_writer
138
159
  self._conversation_history_file = conversation_history_file
139
160
  self._render_history_file = render_history_file
161
+ self._should_summarize_history = summarize_history
162
+ self._render_summarize_history = render_summarize_history
163
+ self._history_summarization_threshold = history_summarization_threshold
164
+ self._render_history_summarization_threshold = (
165
+ render_history_summarization_threshold
166
+ )
140
167
  self._max_call_iteration = max_call_iteration
168
+ self._conversation_context = conversation_context
141
169
 
142
170
  def add_tool(self, tool: ToolOrCallable):
143
171
  self._additional_tools.append(tool)
@@ -145,128 +173,90 @@ class LLMTask(BaseTask):
145
173
  def add_mcp_server(self, mcp_server: MCPServer):
146
174
  self._additional_mcp_servers.append(mcp_server)
147
175
 
176
+ def set_should_enrich_context(self, enrich_context: bool):
177
+ self._should_enrich_context = enrich_context
178
+
179
+ def set_should_summarize_history(self, summarize_history: bool):
180
+ self._should_summarize_history = summarize_history
181
+
182
+ def set_history_summarization_threshold(self, summarization_threshold: int):
183
+ self._history_summarization_threshold = summarization_threshold
184
+
148
185
  async def _exec_action(self, ctx: AnyContext) -> Any:
149
- history = await self._read_conversation_history(ctx)
150
- user_prompt = self._get_message(ctx)
186
+ history_data: ConversationHistoryData = await self._read_conversation_history(
187
+ ctx
188
+ )
189
+ # Extract history list and conversation context
190
+ history_list = history_data.history
191
+ conversation_context = self._get_conversation_context(ctx)
192
+ # Merge history context without overwriting existing keys
193
+ for key, value in history_data.context.items():
194
+ if key not in conversation_context:
195
+ conversation_context[key] = value
196
+ # Enrich context based on history (if enabled)
197
+ if self._get_should_enrich_context(ctx, history_list):
198
+ conversation_context = await enrich_context(
199
+ ctx=ctx,
200
+ config=EnrichmentConfig(
201
+ model=self._get_model(ctx),
202
+ settings=self._get_model_settings(ctx),
203
+ prompt=self._get_context_enrichment_prompt(ctx),
204
+ ),
205
+ conversation_context=conversation_context,
206
+ history_list=history_list,
207
+ )
208
+ # Get history handling parameters
209
+ if self._get_should_summarize_history(ctx, history_list):
210
+ ctx.log_info("Summarize previous conversation")
211
+ # Summarize the part to be removed and update context
212
+ conversation_context = await summarize_history(
213
+ ctx=ctx,
214
+ config=SummarizationConfig(
215
+ model=self._get_model(ctx),
216
+ settings=self._get_model_settings(ctx),
217
+ prompt=self._get_summarization_prompt(ctx),
218
+ ),
219
+ conversation_context=conversation_context,
220
+ history_list=history_list, # Pass the full list for context
221
+ )
222
+ # Truncate the history list after summarization
223
+ history_list = []
224
+ # Construct user prompt
225
+ user_prompt = self._get_user_prompt(ctx, conversation_context)
226
+ # Create and run agent
151
227
  agent = self._get_agent(ctx)
152
228
  try:
153
- async with agent.run_mcp_servers():
154
- async with agent.iter(
155
- user_prompt=user_prompt,
156
- message_history=ModelMessagesTypeAdapter.validate_python(history),
157
- ) as agent_run:
158
- async for node in agent_run:
159
- # Each node represents a step in the agent's execution
160
- # Reference: https://ai.pydantic.dev/agents/#streaming
161
- try:
162
- await self._print_node(ctx, agent_run, node)
163
- except APIError as e:
164
- # Extract detailed error information from the response
165
- error_details = _extract_api_error_details(e)
166
- ctx.log_error(f"API Error: {error_details}")
167
- raise
168
- except Exception as e:
169
- ctx.log_error(f"Error processing node: {str(e)}")
170
- ctx.log_error(f"Error type: {type(e).__name__}")
171
- raise
172
- new_history = json.loads(agent_run.result.all_messages_json())
173
- await self._write_conversation_history(ctx, new_history)
174
- return agent_run.result.data
229
+ agent_run = await run_agent_iteration(
230
+ ctx=ctx,
231
+ agent=agent,
232
+ user_prompt=user_prompt,
233
+ history_list=history_list,
234
+ )
235
+ if agent_run:
236
+ new_history_list = json.loads(agent_run.result.all_messages_json())
237
+ data_to_write = ConversationHistoryData(
238
+ context=conversation_context,
239
+ history=new_history_list,
240
+ )
241
+ await self._write_conversation_history(
242
+ ctx, data_to_write
243
+ ) # Pass the model instance
244
+ return agent_run.result.data
175
245
  except Exception as e:
176
246
  ctx.log_error(f"Error in agent execution: {str(e)}")
177
247
  raise
178
248
 
179
- async def _print_node(self, ctx: AnyContext, agent_run: Any, node: Any):
180
- if Agent.is_user_prompt_node(node):
181
- # A user prompt node => The user has provided input
182
- ctx.print(stylize_faint(f">> UserPromptNode: {node.user_prompt}"))
183
- elif Agent.is_model_request_node(node):
184
- # A model request node => We can stream tokens from the model"s request
185
- ctx.print(
186
- stylize_faint(">> ModelRequestNode: streaming partial request tokens")
187
- )
188
- async with node.stream(agent_run.ctx) as request_stream:
189
- is_streaming = False
190
- async for event in request_stream:
191
- if isinstance(event, PartStartEvent):
192
- if is_streaming:
193
- ctx.print("", plain=True)
194
- ctx.print(
195
- stylize_faint(
196
- f"[Request] Starting part {event.index}: {event.part!r}"
197
- ),
198
- )
199
- is_streaming = False
200
- elif isinstance(event, PartDeltaEvent):
201
- if isinstance(event.delta, TextPartDelta):
202
- ctx.print(
203
- stylize_faint(f"{event.delta.content_delta}"),
204
- end="",
205
- plain=is_streaming,
206
- )
207
- elif isinstance(event.delta, ToolCallPartDelta):
208
- ctx.print(
209
- stylize_faint(f"{event.delta.args_delta}"),
210
- end="",
211
- plain=is_streaming,
212
- )
213
- is_streaming = True
214
- elif isinstance(event, FinalResultEvent):
215
- if is_streaming:
216
- ctx.print("", plain=True)
217
- ctx.print(
218
- stylize_faint(f"[Result] tool_name={event.tool_name}"),
219
- )
220
- is_streaming = False
221
- if is_streaming:
222
- ctx.print("", plain=True)
223
- elif Agent.is_call_tools_node(node):
224
- # A handle-response node => The model returned some data, potentially calls a tool
225
- ctx.print(
226
- stylize_faint(
227
- ">> CallToolsNode: streaming partial response & tool usage"
228
- )
229
- )
230
- async with node.stream(agent_run.ctx) as handle_stream:
231
- async for event in handle_stream:
232
- if isinstance(event, FunctionToolCallEvent):
233
- # Handle empty arguments across different providers
234
- if event.part.args == "" or event.part.args is None:
235
- event.part.args = {}
236
- elif isinstance(
237
- event.part.args, str
238
- ) and event.part.args.strip() in ["null", "{}"]:
239
- # Some providers might send "null" or "{}" as a string
240
- event.part.args = {}
241
- # Handle dummy property if present (from our schema sanitization)
242
- if (
243
- isinstance(event.part.args, dict)
244
- and "_dummy" in event.part.args
245
- ):
246
- del event.part.args["_dummy"]
247
- ctx.print(
248
- stylize_faint(
249
- f"[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})" # noqa
250
- )
251
- )
252
- elif isinstance(event, FunctionToolResultEvent):
253
- ctx.print(
254
- stylize_faint(
255
- f"[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}" # noqa
256
- )
257
- )
258
- elif Agent.is_end_node(node):
259
- # Once an End node is reached, the agent run is complete
260
- ctx.print(stylize_faint(f"{agent_run.result.data}"))
261
-
262
249
  async def _write_conversation_history(
263
- self, ctx: AnyContext, conversations: list[Any]
250
+ self, ctx: AnyContext, history_data: ConversationHistoryData
264
251
  ):
252
+ # Expects the model instance
265
253
  if self._conversation_history_writer is not None:
266
- await run_async(self._conversation_history_writer(ctx, conversations))
254
+ # Pass the model instance directly to the writer
255
+ await run_async(self._conversation_history_writer(ctx, history_data))
267
256
  history_file = self._get_history_file(ctx)
268
257
  if history_file != "":
269
- write_file(history_file, json.dumps(conversations, indent=2))
258
+ # Use model_dump_json for serialization
259
+ write_file(history_file, history_data.model_dump_json(indent=2))
270
260
 
271
261
  def _get_model_settings(self, ctx: AnyContext) -> ModelSettings | None:
272
262
  if callable(self._model_settings):
@@ -282,10 +272,17 @@ class LLMTask(BaseTask):
282
272
  self._tools(ctx) if callable(self._tools) else self._tools
283
273
  )
284
274
  tools_or_callables.extend(self._additional_tools)
285
- tools = [
286
- tool if isinstance(tool, Tool) else Tool(_wrap_tool(tool), takes_ctx=False)
287
- for tool in tools_or_callables
288
- ]
275
+ tools = []
276
+ for tool_or_callable in tools_or_callables:
277
+ if isinstance(tool_or_callable, Tool):
278
+ tools.append(tool_or_callable)
279
+ else:
280
+ # Inspect original callable for 'ctx' parameter
281
+ # This ctx refer to pydantic AI's ctx, not task ctx.
282
+ original_sig = inspect.signature(tool_or_callable)
283
+ takes_ctx = "ctx" in original_sig.parameters
284
+ wrapped_tool = wrap_tool(tool_or_callable)
285
+ tools.append(Tool(wrapped_tool, takes_ctx=takes_ctx))
289
286
  mcp_servers = list(
290
287
  self._mcp_servers(ctx) if callable(self._mcp_servers) else self._mcp_servers
291
288
  )
@@ -345,22 +342,90 @@ class LLMTask(BaseTask):
345
342
  return system_prompt
346
343
  return default_llm_config.get_default_system_prompt()
347
344
 
348
- def _get_message(self, ctx: AnyContext) -> str:
345
+ def _get_user_prompt(
346
+ self, ctx: AnyContext, conversation_context: dict[str, Any]
347
+ ) -> str:
348
+ user_message = self._get_user_message(ctx)
349
+ enriched_context = {**get_default_context(user_message), **conversation_context}
350
+ return dedent(
351
+ f"""
352
+ # Context
353
+ {json.dumps(enriched_context)}
354
+ # User Message
355
+ {user_message}
356
+ """.strip()
357
+ )
358
+
359
+ def _get_user_message(self, ctx: AnyContext) -> str:
349
360
  return get_str_attr(ctx, self._message, "How are you?", auto_render=True)
350
361
 
351
- async def _read_conversation_history(self, ctx: AnyContext) -> ListOfDict:
352
- if self._conversation_history_reader is not None:
353
- return await run_async(self._conversation_history_reader(ctx))
354
- if callable(self._conversation_history):
355
- return self._conversation_history(ctx)
362
+ def _get_summarization_prompt(self, ctx: AnyContext) -> str:
363
+ summarization_prompt = get_attr(
364
+ ctx,
365
+ self._summarization_prompt,
366
+ None,
367
+ auto_render=self._render_summarization_prompt,
368
+ )
369
+ if summarization_prompt is not None:
370
+ return summarization_prompt
371
+ return default_llm_config.get_default_summarization_prompt()
372
+
373
+ def _get_should_enrich_context(
374
+ self, ctx: AnyContext, history_list: ListOfDict
375
+ ) -> bool:
376
+ if len(history_list) == 0:
377
+ return False
378
+ return get_bool_attr(
379
+ ctx,
380
+ self._should_enrich_context,
381
+ True, # Default to True if not specified
382
+ auto_render=self._render_enrich_context,
383
+ )
384
+
385
+ def _get_context_enrichment_prompt(self, ctx: AnyContext) -> str:
386
+ context_enrichment_prompt = get_attr(
387
+ ctx,
388
+ self._context_enrichment_prompt,
389
+ None,
390
+ auto_render=self._render_context_enrichment_prompt,
391
+ )
392
+ if context_enrichment_prompt is not None:
393
+ return context_enrichment_prompt
394
+ return default_llm_config.get_default_context_enrichment_prompt()
395
+
396
+ async def _read_conversation_history(
397
+ self, ctx: AnyContext
398
+ ) -> ConversationHistoryData: # Returns the model instance
399
+ """Reads conversation history from reader, file, or attribute, with validation."""
356
400
  history_file = self._get_history_file(ctx)
357
- if (
358
- len(self._conversation_history) == 0
359
- and history_file != ""
360
- and os.path.isfile(history_file)
361
- ):
362
- return json.loads(read_file(history_file))
363
- return self._conversation_history
401
+ # Priority 1 & 2: Reader and File (handled by ConversationHistoryData)
402
+ history_data = await ConversationHistoryData.read_from_sources(
403
+ ctx=ctx,
404
+ reader=self._conversation_history_reader,
405
+ file_path=history_file,
406
+ )
407
+ if history_data:
408
+ return history_data
409
+ # Priority 3: Callable or direct conversation_history attribute
410
+ raw_data_attr: Any = None
411
+ if callable(self._conversation_history):
412
+ try:
413
+ raw_data_attr = await run_async(self._conversation_history(ctx))
414
+ except Exception as e:
415
+ ctx.log_warning(
416
+ f"Error executing callable conversation_history attribute: {e}. "
417
+ "Ignoring."
418
+ )
419
+ if raw_data_attr is None:
420
+ raw_data_attr = self._conversation_history
421
+ if raw_data_attr:
422
+ history_data = ConversationHistoryData.parse_and_validate(
423
+ ctx, raw_data_attr, "attribute"
424
+ )
425
+ if history_data:
426
+ return history_data
427
+ # Fallback: Return default value
428
+ return ConversationHistoryData()
364
429
 
365
430
  def _get_history_file(self, ctx: AnyContext) -> str:
366
431
  return get_str_attr(
@@ -370,103 +435,71 @@ class LLMTask(BaseTask):
370
435
  auto_render=self._render_history_file,
371
436
  )
372
437
 
373
-
374
- def _wrap_tool(func):
375
- sig = inspect.signature(func)
376
- if len(sig.parameters) == 0:
377
-
378
- @functools.wraps(func)
379
- async def wrapper(_dummy=None):
380
- try:
381
- return await run_async(func())
382
- except Exception as e:
383
- # Optionally, you can include more details from traceback if needed.
384
- error_details = traceback.format_exc()
385
- return f"Error: {e}\nDetails: {error_details}"
386
-
387
- new_sig = inspect.Signature(
388
- parameters=[
389
- inspect.Parameter(
390
- "_dummy", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
391
- )
392
- ]
438
+ def _get_should_summarize_history(
439
+ self, ctx: AnyContext, history_list: ListOfDict
440
+ ) -> bool:
441
+ history_len = len(history_list)
442
+ if history_len == 0:
443
+ return False
444
+ summarization_threshold = self._get_history_summarization_threshold(ctx)
445
+ if summarization_threshold == -1:
446
+ return False
447
+ if summarization_threshold > history_len:
448
+ return False
449
+ return get_bool_attr(
450
+ ctx,
451
+ self._should_summarize_history,
452
+ False,
453
+ auto_render=self._render_summarize_history,
393
454
  )
394
- # Override the wrapper's signature so introspection yields a non-empty schema.
395
- wrapper.__signature__ = new_sig
396
- return wrapper
397
- else:
398
-
399
- @functools.wraps(func)
400
- async def wrapper(*args, **kwargs):
401
- try:
402
- return await run_async(func(*args, **kwargs))
403
- except Exception as e:
404
- # Optionally, you can include more details from traceback if needed.
405
- error_details = traceback.format_exc()
406
- return f"Error: {e}\nDetails: {error_details}"
407
455
 
408
- return wrapper
409
-
410
-
411
- def _extract_api_error_details(error: APIError) -> str:
412
- """Extract detailed error information from an APIError."""
413
- details = f"{error.message}"
414
- # Try to parse the error body as JSON
415
- if error.body:
456
+ def _get_history_summarization_threshold(self, ctx: AnyContext) -> int:
457
+ # Use get_int_attr with -1 as default (no limit)
416
458
  try:
417
- if isinstance(error.body, str):
418
- body_json = json.loads(error.body)
419
- elif isinstance(error.body, bytes):
420
- body_json = json.loads(error.body.decode("utf-8"))
421
- else:
422
- body_json = error.body
423
- # Extract error details from the JSON structure
424
- if isinstance(body_json, dict):
425
- if "error" in body_json:
426
- error_obj = body_json["error"]
427
- if isinstance(error_obj, dict):
428
- if "message" in error_obj:
429
- details += f"\nProvider message: {error_obj['message']}"
430
- if "code" in error_obj:
431
- details += f"\nError code: {error_obj['code']}"
432
- if "status" in error_obj:
433
- details += f"\nStatus: {error_obj['status']}"
434
- # Check for metadata that might contain provider-specific information
435
- if "metadata" in body_json and isinstance(body_json["metadata"], dict):
436
- metadata = body_json["metadata"]
437
- if "provider_name" in metadata:
438
- details += f"\nProvider: {metadata['provider_name']}"
439
- if "raw" in metadata:
440
- try:
441
- raw_json = json.loads(metadata["raw"])
442
- if "error" in raw_json and isinstance(
443
- raw_json["error"], dict
444
- ):
445
- raw_error = raw_json["error"]
446
- if "message" in raw_error:
447
- details += (
448
- f"\nRaw error message: {raw_error['message']}"
449
- )
450
- except (KeyError, TypeError, ValueError):
451
- # If we can't parse the raw JSON, just include it as is
452
- details += f"\nRaw error data: {metadata['raw']}"
453
- except json.JSONDecodeError:
454
- # If we can't parse the JSON, include the raw body
455
- details += f"\nRaw error body: {error.body}"
456
- except Exception as e:
457
- # Catch any other exceptions during parsing
458
- details += f"\nError parsing error body: {str(e)}"
459
- # Include request information if available
460
- if hasattr(error, "request") and error.request:
461
- if hasattr(error.request, "method") and hasattr(error.request, "url"):
462
- details += f"\nRequest: {error.request.method} {error.request.url}"
463
- # Include a truncated version of the request content if available
464
- if hasattr(error.request, "content") and error.request.content:
465
- content = error.request.content
466
- if isinstance(content, bytes):
459
+ return get_int_attr(
460
+ ctx,
461
+ self._history_summarization_threshold,
462
+ -1,
463
+ auto_render=self._render_history_summarization_threshold,
464
+ )
465
+ except ValueError as e:
466
+ ctx.log_warning(
467
+ f"Could not convert history_summarization_threshold to int: {e}. "
468
+ "Defaulting to -1 (no threshold)."
469
+ )
470
+ return -1
471
+
472
+ def _get_conversation_context(self, ctx: AnyContext) -> dict[str, Any]:
473
+ """
474
+ Retrieves the conversation context.
475
+ If a value in the context dict is callable, it executes it with ctx.
476
+ """
477
+ raw_context = get_attr(
478
+ ctx, self._conversation_context, {}, auto_render=False
479
+ ) # Context usually shouldn't be rendered
480
+ if not isinstance(raw_context, dict):
481
+ ctx.log_warning(
482
+ f"Conversation context resolved to type {type(raw_context)}, "
483
+ "expected dict. Returning empty context."
484
+ )
485
+ return {}
486
+ # If conversation_context contains callable value, execute them.
487
+ processed_context: dict[str, Any] = {}
488
+ for key, value in raw_context.items():
489
+ if callable(value):
467
490
  try:
468
- content = content.decode("utf-8")
469
- except UnicodeDecodeError:
470
- content = str(content)
471
- details += f"\nRequest content: {content}"
472
- return details
491
+ # Check if the callable expects 'ctx'
492
+ sig = inspect.signature(value)
493
+ if "ctx" in sig.parameters:
494
+ processed_context[key] = value(ctx)
495
+ else:
496
+ processed_context[key] = value()
497
+ except Exception as e:
498
+ ctx.log_warning(
499
+ f"Error executing callable for context key '{key}': {e}. "
500
+ "Skipping."
501
+ )
502
+ processed_context[key] = None
503
+ else:
504
+ processed_context[key] = value
505
+ return processed_context
zrb/util/file.py CHANGED
@@ -3,13 +3,19 @@ import re
3
3
 
4
4
 
5
5
  def read_file(file_path: str, replace_map: dict[str, str] = {}) -> str:
6
- with open(file_path, "r", encoding="utf-8") as f:
6
+ with open(
7
+ os.path.abspath(os.path.expanduser(file_path)), "r", encoding="utf-8"
8
+ ) as f:
7
9
  content = f.read()
8
10
  for key, val in replace_map.items():
9
11
  content = content.replace(key, val)
10
12
  return content
11
13
 
12
14
 
15
+ def read_dir(dir_path: str) -> list[str]:
16
+ return [f for f in os.listdir(os.path.abspath(os.path.expanduser(dir_path)))]
17
+
18
+
13
19
  def write_file(file_path: str, content: str | list[str]):
14
20
  if isinstance(content, list):
15
21
  content = "\n".join([line for line in content if line is not None])
@@ -21,5 +27,5 @@ def write_file(file_path: str, content: str | list[str]):
21
27
  content = content.rstrip("\n")
22
28
  if should_add_eol:
23
29
  content += "\n"
24
- with open(file_path, "w") as f:
30
+ with open(os.path.abspath(os.path.expanduser(file_path)), "w") as f:
25
31
  f.write(content)
zrb/util/load.py CHANGED
@@ -23,6 +23,8 @@ def load_file(script_path: str, sys_path_index: int = 0) -> Any | None:
23
23
  # Add script dir path to Python path
24
24
  os.environ["PYTHONPATH"] = _get_new_python_path(script_dir_path)
25
25
  spec = importlib.util.spec_from_file_location(module_name, script_path)
26
+ if spec is None:
27
+ return None
26
28
  module = importlib.util.module_from_spec(spec)
27
29
  spec.loader.exec_module(module)
28
30
  return module