zrb 1.5.5__py3-none-any.whl → 1.5.7__py3-none-any.whl
This diff compares the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only.
- zrb/__init__.py +2 -0
- zrb/__main__.py +28 -2
- zrb/builtin/llm/history.py +73 -0
- zrb/builtin/llm/input.py +27 -0
- zrb/builtin/llm/llm_chat.py +4 -61
- zrb/builtin/llm/tool/api.py +39 -17
- zrb/builtin/llm/tool/cli.py +19 -5
- zrb/builtin/llm/tool/file.py +277 -137
- zrb/builtin/llm/tool/rag.py +18 -1
- zrb/builtin/llm/tool/web.py +31 -14
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/error.py +6 -8
- zrb/config.py +1 -0
- zrb/llm_config.py +81 -15
- zrb/task/llm/__init__.py +0 -0
- zrb/task/llm/agent_runner.py +53 -0
- zrb/task/llm/context_enricher.py +86 -0
- zrb/task/llm/default_context.py +45 -0
- zrb/task/llm/error.py +77 -0
- zrb/task/llm/history.py +92 -0
- zrb/task/llm/history_summarizer.py +71 -0
- zrb/task/llm/print_node.py +98 -0
- zrb/task/llm/tool_wrapper.py +88 -0
- zrb/task/llm_task.py +279 -246
- zrb/util/file.py +17 -2
- zrb/util/load.py +2 -0
- {zrb-1.5.5.dist-info → zrb-1.5.7.dist-info}/METADATA +1 -1
- {zrb-1.5.5.dist-info → zrb-1.5.7.dist-info}/RECORD +29 -18
- {zrb-1.5.5.dist-info → zrb-1.5.7.dist-info}/WHEEL +0 -0
- {zrb-1.5.5.dist-info → zrb-1.5.7.dist-info}/entry_points.txt +0 -0
zrb/task/llm_task.py
CHANGED
@@ -1,28 +1,15 @@
-import functools
 import inspect
 import json
-import os
-import traceback
 from collections.abc import Callable
+from textwrap import dedent
 from typing import Any
 
-from openai import APIError
 from pydantic_ai import Agent, Tool
 from pydantic_ai.mcp import MCPServer
-from pydantic_ai.messages import (
-    FinalResultEvent,
-    FunctionToolCallEvent,
-    FunctionToolResultEvent,
-    ModelMessagesTypeAdapter,
-    PartDeltaEvent,
-    PartStartEvent,
-    TextPartDelta,
-    ToolCallPartDelta,
-)
 from pydantic_ai.models import Model
 from pydantic_ai.settings import ModelSettings
 
-from zrb.attr.type import StrAttr, fstring
+from zrb.attr.type import BoolAttr, IntAttr, StrAttr, fstring
 from zrb.context.any_context import AnyContext
 from zrb.context.any_shared_context import AnySharedContext
 from zrb.env.any_env import AnyEnv
@@ -31,12 +18,18 @@ from zrb.llm_config import LLMConfig
 from zrb.llm_config import llm_config as default_llm_config
 from zrb.task.any_task import AnyTask
 from zrb.task.base_task import BaseTask
-from zrb.
-from zrb.
-from zrb.
+from zrb.task.llm.agent_runner import run_agent_iteration
+from zrb.task.llm.context_enricher import EnrichmentConfig, enrich_context
+from zrb.task.llm.default_context import get_default_context
+from zrb.task.llm.history import ConversationHistoryData, ListOfDict
+from zrb.task.llm.history_summarizer import SummarizationConfig, summarize_history
+from zrb.task.llm.tool_wrapper import wrap_tool
+from zrb.util.attr import get_attr, get_bool_attr, get_int_attr, get_str_attr
+from zrb.util.file import write_file
 from zrb.util.run import run_async
 
-ListOfDict
+# ListOfDict moved to history.py
+# Removed old ConversationHistoryData type alias
 ToolOrCallable = Tool | Callable
 
 
@@ -65,6 +58,12 @@ class LLMTask(BaseTask):
         system_prompt: StrAttr | None = None,
         render_system_prompt: bool = True,
         message: StrAttr | None = None,
+        summarization_prompt: StrAttr | None = None,
+        render_summarization_prompt: bool = True,
+        enrich_context: BoolAttr = False,
+        render_enrich_context: bool = True,
+        context_enrichment_prompt: StrAttr | None = None,
+        render_context_enrichment_prompt: bool = True,
         tools: (
             list[ToolOrCallable] | Callable[[AnySharedContext], list[ToolOrCallable]]
         ) = [],
@@ -72,16 +71,29 @@
             list[MCPServer] | Callable[[AnySharedContext], list[MCPServer]]
         ) = [],
         conversation_history: (
-
-
+            ConversationHistoryData  # Use the new BaseModel
+            | Callable[
+                [AnySharedContext], ConversationHistoryData | dict | list
+            ]  # Allow returning raw dict/list too
+            | dict  # Allow raw dict
+            | list  # Allow raw list (old format)
+        ) = ConversationHistoryData(),  # Default to an empty model instance
         conversation_history_reader: (
-            Callable[[AnySharedContext],
+            Callable[[AnySharedContext], ConversationHistoryData | dict | list | None]
+            | None
+            # Allow returning raw dict/list or None
         ) = None,
         conversation_history_writer: (
-            Callable[[AnySharedContext,
+            Callable[[AnySharedContext, ConversationHistoryData], None]
+            | None
+            # Writer expects the model instance
         ) = None,
         conversation_history_file: StrAttr | None = None,
         render_history_file: bool = True,
+        summarize_history: BoolAttr = True,
+        render_summarize_history: bool = True,
+        history_summarization_threshold: IntAttr = 5,  # -1 means no summarization trigger
+        render_history_summarization_threshold: bool = True,
         execute_condition: bool | str | Callable[[AnySharedContext], bool] = True,
         retries: int = 2,
         retry_period: float = 0,
@@ -95,6 +107,9 @@
         upstream: list[AnyTask] | AnyTask | None = None,
         fallback: list[AnyTask] | AnyTask | None = None,
         successor: list[AnyTask] | AnyTask | None = None,
+        conversation_context: (
+            dict[str, Any] | Callable[[AnySharedContext], dict[str, Any]] | None
+        ) = None,
     ):
         super().__init__(
             name=name,
@@ -128,6 +143,12 @@ class LLMTask(BaseTask):
         self._system_prompt = system_prompt
         self._render_system_prompt = render_system_prompt
         self._message = message
+        self._summarization_prompt = summarization_prompt
+        self._render_summarization_prompt = render_summarization_prompt
+        self._should_enrich_context = enrich_context
+        self._render_enrich_context = render_enrich_context
+        self._context_enrichment_prompt = context_enrichment_prompt
+        self._render_context_enrichment_prompt = render_context_enrichment_prompt
         self._tools = tools
         self._additional_tools: list[ToolOrCallable] = []
         self._mcp_servers = mcp_servers
@@ -137,7 +158,14 @@
         self._conversation_history_writer = conversation_history_writer
         self._conversation_history_file = conversation_history_file
         self._render_history_file = render_history_file
+        self._should_summarize_history = summarize_history
+        self._render_summarize_history = render_summarize_history
+        self._history_summarization_threshold = history_summarization_threshold
+        self._render_history_summarization_threshold = (
+            render_history_summarization_threshold
+        )
         self._max_call_iteration = max_call_iteration
+        self._conversation_context = conversation_context
 
     def add_tool(self, tool: ToolOrCallable):
         self._additional_tools.append(tool)
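Taken together, the new constructor parameters let a task opt in to history summarization and context enrichment declaratively. A minimal sketch of the wiring (a hypothetical task, not from this diff; it assumes LLMTask is still exported from the package root as in earlier releases):

from zrb import LLMTask

chat = LLMTask(
    name="chat",
    message="{ctx.input.message}",
    conversation_history_file="~/.zrb-chat-history.json",
    summarize_history=True,
    history_summarization_threshold=5,  # summarize once the history holds 5+ messages
    enrich_context=True,  # run the context-enrichment pass between turns
)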
@@ -145,128 +173,90 @@
     def add_mcp_server(self, mcp_server: MCPServer):
         self._additional_mcp_servers.append(mcp_server)
 
+    def set_should_enrich_context(self, enrich_context: bool):
+        self._should_enrich_context = enrich_context
+
+    def set_should_summarize_history(self, summarize_history: bool):
+        self._should_summarize_history = summarize_history
+
+    def set_history_summarization_threshold(self, summarization_threshold: int):
+        self._history_summarization_threshold = summarization_threshold
+
     async def _exec_action(self, ctx: AnyContext) -> Any:
-
-
+        history_data: ConversationHistoryData = await self._read_conversation_history(
+            ctx
+        )
+        # Extract history list and conversation context
+        history_list = history_data.history
+        conversation_context = self._get_conversation_context(ctx)
+        # Merge history context without overwriting existing keys
+        for key, value in history_data.context.items():
+            if key not in conversation_context:
+                conversation_context[key] = value
+        # Enrich context based on history (if enabled)
+        if self._get_should_enrich_context(ctx, history_list):
+            conversation_context = await enrich_context(
+                ctx=ctx,
+                config=EnrichmentConfig(
+                    model=self._get_model(ctx),
+                    settings=self._get_model_settings(ctx),
+                    prompt=self._get_context_enrichment_prompt(ctx),
+                ),
+                conversation_context=conversation_context,
+                history_list=history_list,
+            )
+        # Get history handling parameters
+        if self._get_should_summarize_history(ctx, history_list):
+            ctx.log_info("Summarize previous conversation")
+            # Summarize the part to be removed and update context
+            conversation_context = await summarize_history(
+                ctx=ctx,
+                config=SummarizationConfig(
+                    model=self._get_model(ctx),
+                    settings=self._get_model_settings(ctx),
+                    prompt=self._get_summarization_prompt(ctx),
+                ),
+                conversation_context=conversation_context,
+                history_list=history_list,  # Pass the full list for context
+            )
+            # Truncate the history list after summarization
+            history_list = []
+        # Construct user prompt
+        user_prompt = self._get_user_prompt(ctx, conversation_context)
+        # Create and run agent
         agent = self._get_agent(ctx)
         try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            ctx.log_error(f"Error processing node: {str(e)}")
-            ctx.log_error(f"Error type: {type(e).__name__}")
-            raise
-            new_history = json.loads(agent_run.result.all_messages_json())
-            await self._write_conversation_history(ctx, new_history)
-            return agent_run.result.data
+            agent_run = await run_agent_iteration(
+                ctx=ctx,
+                agent=agent,
+                user_prompt=user_prompt,
+                history_list=history_list,
+            )
+            if agent_run:
+                new_history_list = json.loads(agent_run.result.all_messages_json())
+                data_to_write = ConversationHistoryData(
+                    context=conversation_context,
+                    history=new_history_list,
+                )
+                await self._write_conversation_history(
+                    ctx, data_to_write
+                )  # Pass the model instance
+                return agent_run.result.data
         except Exception as e:
             ctx.log_error(f"Error in agent execution: {str(e)}")
             raise
 
-    async def _print_node(self, ctx: AnyContext, agent_run: Any, node: Any):
-        if Agent.is_user_prompt_node(node):
-            # A user prompt node => The user has provided input
-            ctx.print(stylize_faint(f">> UserPromptNode: {node.user_prompt}"))
-        elif Agent.is_model_request_node(node):
-            # A model request node => We can stream tokens from the model's request
-            ctx.print(
-                stylize_faint(">> ModelRequestNode: streaming partial request tokens")
-            )
-            async with node.stream(agent_run.ctx) as request_stream:
-                is_streaming = False
-                async for event in request_stream:
-                    if isinstance(event, PartStartEvent):
-                        if is_streaming:
-                            ctx.print("", plain=True)
-                        ctx.print(
-                            stylize_faint(
-                                f"[Request] Starting part {event.index}: {event.part!r}"
-                            ),
-                        )
-                        is_streaming = False
-                    elif isinstance(event, PartDeltaEvent):
-                        if isinstance(event.delta, TextPartDelta):
-                            ctx.print(
-                                stylize_faint(f"{event.delta.content_delta}"),
-                                end="",
-                                plain=is_streaming,
-                            )
-                        elif isinstance(event.delta, ToolCallPartDelta):
-                            ctx.print(
-                                stylize_faint(f"{event.delta.args_delta}"),
-                                end="",
-                                plain=is_streaming,
-                            )
-                        is_streaming = True
-                    elif isinstance(event, FinalResultEvent):
-                        if is_streaming:
-                            ctx.print("", plain=True)
-                        ctx.print(
-                            stylize_faint(f"[Result] tool_name={event.tool_name}"),
-                        )
-                        is_streaming = False
-                if is_streaming:
-                    ctx.print("", plain=True)
-        elif Agent.is_call_tools_node(node):
-            # A handle-response node => The model returned some data, potentially calls a tool
-            ctx.print(
-                stylize_faint(
-                    ">> CallToolsNode: streaming partial response & tool usage"
-                )
-            )
-            async with node.stream(agent_run.ctx) as handle_stream:
-                async for event in handle_stream:
-                    if isinstance(event, FunctionToolCallEvent):
-                        # Handle empty arguments across different providers
-                        if event.part.args == "" or event.part.args is None:
-                            event.part.args = {}
-                        elif isinstance(
-                            event.part.args, str
-                        ) and event.part.args.strip() in ["null", "{}"]:
-                            # Some providers might send "null" or "{}" as a string
-                            event.part.args = {}
-                        # Handle dummy property if present (from our schema sanitization)
-                        if (
-                            isinstance(event.part.args, dict)
-                            and "_dummy" in event.part.args
-                        ):
-                            del event.part.args["_dummy"]
-                        ctx.print(
-                            stylize_faint(
-                                f"[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})"  # noqa
-                            )
-                        )
-                    elif isinstance(event, FunctionToolResultEvent):
-                        ctx.print(
-                            stylize_faint(
-                                f"[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}"  # noqa
-                            )
-                        )
-        elif Agent.is_end_node(node):
-            # Once an End node is reached, the agent run is complete
-            ctx.print(stylize_faint(f"{agent_run.result.data}"))
-
     async def _write_conversation_history(
-        self, ctx: AnyContext,
+        self, ctx: AnyContext, history_data: ConversationHistoryData
     ):
+        # Expects the model instance
         if self._conversation_history_writer is not None:
-
+            # Pass the model instance directly to the writer
+            await run_async(self._conversation_history_writer(ctx, history_data))
         history_file = self._get_history_file(ctx)
         if history_file != "":
-
+            # Use model_dump_json for serialization
+            write_file(history_file, history_data.model_dump_json(indent=2))
 
     def _get_model_settings(self, ctx: AnyContext) -> ModelSettings | None:
         if callable(self._model_settings):
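Note that _write_conversation_history now hands the writer a ConversationHistoryData model rather than a raw message list, so custom persistence hooks should accept that type. A hedged sketch of a matching reader/writer pair (the file path and function names are illustrative, and it assumes ConversationHistoryData is a pydantic v2 model, as its model_dump_json call suggests):

import json

from zrb.task.llm.history import ConversationHistoryData

def read_history(ctx) -> ConversationHistoryData | None:
    # Returning None falls through to the history file / attribute sources.
    try:
        with open("/tmp/conversation.json", "r") as f:
            return ConversationHistoryData.model_validate(json.load(f))
    except FileNotFoundError:
        return None

def write_history(ctx, history_data: ConversationHistoryData) -> None:
    # Mirrors the model_dump_json serialization used by the task itself.
    with open("/tmp/conversation.json", "w") as f:
        f.write(history_data.model_dump_json(indent=2))

These would plug into the constructor through conversation_history_reader and conversation_history_writer.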
@@ -282,10 +272,17 @@
             self._tools(ctx) if callable(self._tools) else self._tools
         )
         tools_or_callables.extend(self._additional_tools)
-        tools = [
-
-
-
+        tools = []
+        for tool_or_callable in tools_or_callables:
+            if isinstance(tool_or_callable, Tool):
+                tools.append(tool_or_callable)
+            else:
+                # Inspect original callable for 'ctx' parameter
+                # This ctx refer to pydantic AI's ctx, not task ctx.
+                original_sig = inspect.signature(tool_or_callable)
+                takes_ctx = "ctx" in original_sig.parameters
+                wrapped_tool = wrap_tool(tool_or_callable)
+                tools.append(Tool(wrapped_tool, takes_ctx=takes_ctx))
         mcp_servers = list(
             self._mcp_servers(ctx) if callable(self._mcp_servers) else self._mcp_servers
         )
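The replacement loop means plain callables no longer need manual Tool construction: the task inspects the original signature for a ctx parameter (pydantic-ai's RunContext, not the zrb task context) before applying wrap_tool. A small illustration using the hypothetical chat task from the earlier sketch (tool bodies are made up):

def get_weather(city: str) -> str:
    # No 'ctx' parameter, so the Tool is registered with takes_ctx=False.
    return f"It is always sunny in {city}."

def describe_deps(ctx) -> str:
    # 'ctx' here would be pydantic-ai's RunContext, so takes_ctx=True.
    return str(ctx.deps)

chat.add_tool(get_weather)
chat.add_tool(describe_deps)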
@@ -345,22 +342,90 @@
             return system_prompt
         return default_llm_config.get_default_system_prompt()
 
-    def
+    def _get_user_prompt(
+        self, ctx: AnyContext, conversation_context: dict[str, Any]
+    ) -> str:
+        user_message = self._get_user_message(ctx)
+        enriched_context = {**get_default_context(user_message), **conversation_context}
+        return dedent(
+            f"""
+            # Context
+            {json.dumps(enriched_context)}
+            # User Message
+            {user_message}
+            """.strip()
+        )
+
+    def _get_user_message(self, ctx: AnyContext) -> str:
         return get_str_attr(ctx, self._message, "How are you?", auto_render=True)
 
-
-
-
-
-
+    def _get_summarization_prompt(self, ctx: AnyContext) -> str:
+        summarization_prompt = get_attr(
+            ctx,
+            self._summarization_prompt,
+            None,
+            auto_render=self._render_summarization_prompt,
+        )
+        if summarization_prompt is not None:
+            return summarization_prompt
+        return default_llm_config.get_default_summarization_prompt()
+
+    def _get_should_enrich_context(
+        self, ctx: AnyContext, history_list: ListOfDict
+    ) -> bool:
+        if len(history_list) == 0:
+            return False
+        return get_bool_attr(
+            ctx,
+            self._should_enrich_context,
+            True,  # Default to True if not specified
+            auto_render=self._render_enrich_context,
+        )
+
+    def _get_context_enrichment_prompt(self, ctx: AnyContext) -> str:
+        context_enrichment_prompt = get_attr(
+            ctx,
+            self._context_enrichment_prompt,
+            None,
+            auto_render=self._render_context_enrichment_prompt,
+        )
+        if context_enrichment_prompt is not None:
+            return context_enrichment_prompt
+        return default_llm_config.get_default_context_enrichment_prompt()
+
+    async def _read_conversation_history(
+        self, ctx: AnyContext
+    ) -> ConversationHistoryData:  # Returns the model instance
+        """Reads conversation history from reader, file, or attribute, with validation."""
         history_file = self._get_history_file(ctx)
-
-
-
-
-
-
-
+        # Priority 1 & 2: Reader and File (handled by ConversationHistoryData)
+        history_data = await ConversationHistoryData.read_from_sources(
+            ctx=ctx,
+            reader=self._conversation_history_reader,
+            file_path=history_file,
+        )
+        if history_data:
+            return history_data
+        # Priority 3: Callable or direct conversation_history attribute
+        raw_data_attr: Any = None
+        if callable(self._conversation_history):
+            try:
+                raw_data_attr = await run_async(self._conversation_history(ctx))
+            except Exception as e:
+                ctx.log_warning(
+                    f"Error executing callable conversation_history attribute: {e}. "
+                    "Ignoring."
+                )
+        if raw_data_attr is None:
+            raw_data_attr = self._conversation_history
+        if raw_data_attr:
+            history_data = ConversationHistoryData.parse_and_validate(
+                ctx, raw_data_attr, "attribute"
+            )
+            if history_data:
+                return history_data
+        # Fallback: Return default value
+        return ConversationHistoryData()
 
     def _get_history_file(self, ctx: AnyContext) -> str:
         return get_str_attr(
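_get_user_prompt now prepends the merged context, serialized as JSON, to every user message. For a conversation_context of {"user_name": "Ana"} and the message "Hi", the model would receive roughly the following (the default keys marked "..." come from get_default_context in default_context.py, which this hunk does not show):

# Context
{..., "user_name": "Ana"}
# User Message
Hi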
@@ -370,103 +435,71 @@
             auto_render=self._render_history_file,
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        inspect.Parameter(
-            "_dummy", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
-        )
-    ]
+    def _get_should_summarize_history(
+        self, ctx: AnyContext, history_list: ListOfDict
+    ) -> bool:
+        history_len = len(history_list)
+        if history_len == 0:
+            return False
+        summarization_threshold = self._get_history_summarization_threshold(ctx)
+        if summarization_threshold == -1:
+            return False
+        if summarization_threshold > history_len:
+            return False
+        return get_bool_attr(
+            ctx,
+            self._should_summarize_history,
+            False,
+            auto_render=self._render_summarize_history,
         )
-        # Override the wrapper's signature so introspection yields a non-empty schema.
-        wrapper.__signature__ = new_sig
-        return wrapper
-    else:
-
-        @functools.wraps(func)
-        async def wrapper(*args, **kwargs):
-            try:
-                return await run_async(func(*args, **kwargs))
-            except Exception as e:
-                # Optionally, you can include more details from traceback if needed.
-                error_details = traceback.format_exc()
-                return json.dumps({"error": f"{e}", "details": f"{error_details}"})
 
-
-
-
-def _extract_api_error_details(error: APIError) -> str:
-    """Extract detailed error information from an APIError."""
-    details = f"{error.message}"
-    # Try to parse the error body as JSON
-    if error.body:
+    def _get_history_summarization_threshold(self, ctx: AnyContext) -> int:
+        # Use get_int_attr with -1 as default (no limit)
         try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    f"\nRaw error message: {raw_error['message']}"
-                )
-            except (KeyError, TypeError, ValueError):
-                # If we can't parse the raw JSON, just include it as is
-                details += f"\nRaw error data: {metadata['raw']}"
-        except json.JSONDecodeError:
-            # If we can't parse the JSON, include the raw body
-            details += f"\nRaw error body: {error.body}"
-        except Exception as e:
-            # Catch any other exceptions during parsing
-            details += f"\nError parsing error body: {str(e)}"
-    # Include request information if available
-    if hasattr(error, "request") and error.request:
-        if hasattr(error.request, "method") and hasattr(error.request, "url"):
-            details += f"\nRequest: {error.request.method} {error.request.url}"
-        # Include a truncated version of the request content if available
-        if hasattr(error.request, "content") and error.request.content:
-            content = error.request.content
-            if isinstance(content, bytes):
+            return get_int_attr(
+                ctx,
+                self._history_summarization_threshold,
+                -1,
+                auto_render=self._render_history_summarization_threshold,
+            )
+        except ValueError as e:
+            ctx.log_warning(
+                f"Could not convert history_summarization_threshold to int: {e}. "
+                "Defaulting to -1 (no threshold)."
+            )
+            return -1
+
+    def _get_conversation_context(self, ctx: AnyContext) -> dict[str, Any]:
+        """
+        Retrieves the conversation context.
+        If a value in the context dict is callable, it executes it with ctx.
+        """
+        raw_context = get_attr(
+            ctx, self._conversation_context, {}, auto_render=False
+        )  # Context usually shouldn't be rendered
+        if not isinstance(raw_context, dict):
+            ctx.log_warning(
+                f"Conversation context resolved to type {type(raw_context)}, "
+                "expected dict. Returning empty context."
+            )
+            return {}
+        # If conversation_context contains callable value, execute them.
+        processed_context: dict[str, Any] = {}
+        for key, value in raw_context.items():
+            if callable(value):
                 try:
-
-
-
-
-
+                    # Check if the callable expects 'ctx'
+                    sig = inspect.signature(value)
+                    if "ctx" in sig.parameters:
+                        processed_context[key] = value(ctx)
+                    else:
+                        processed_context[key] = value()
+                except Exception as e:
+                    ctx.log_warning(
+                        f"Error executing callable for context key '{key}': {e}. "
+                        "Skipping."
+                    )
+                    processed_context[key] = None
+            else:
+                processed_context[key] = value
+        return processed_context
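The summarization gate in _get_should_summarize_history short-circuits three ways before consulting the flag: an empty history never summarizes, a threshold of -1 disables the feature, and a threshold larger than the current history length postpones it. Distilled into one expression (the local names here are illustrative):

should_summarize = (
    len(history_list) > 0
    and threshold != -1
    and threshold <= len(history_list)  # with the default of 5, fires from the 5th message on
    and summarize_history_flag  # the rendered summarize_history attribute
)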
zrb/util/file.py
CHANGED
@@ -3,13 +3,28 @@ import re
 
 
 def read_file(file_path: str, replace_map: dict[str, str] = {}) -> str:
-    with open(
+    with open(
+        os.path.abspath(os.path.expanduser(file_path)), "r", encoding="utf-8"
+    ) as f:
         content = f.read()
     for key, val in replace_map.items():
         content = content.replace(key, val)
     return content
 
 
+def read_file_with_line_numbers(
+    file_path: str, replace_map: dict[str, str] = {}
+) -> str:
+    content = read_file(file_path, replace_map)
+    lines = content.splitlines()
+    numbered_lines = [f"{i + 1} | {line}" for i, line in enumerate(lines)]
+    return "\n".join(numbered_lines)
+
+
+def read_dir(dir_path: str) -> list[str]:
+    return [f for f in os.listdir(os.path.abspath(os.path.expanduser(dir_path)))]
+
+
 def write_file(file_path: str, content: str | list[str]):
     if isinstance(content, list):
         content = "\n".join([line for line in content if line is not None])
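read_file_with_line_numbers prefixes each line with its 1-based index, presumably in support of the reworked line-oriented editing in zrb/builtin/llm/tool/file.py. For example (the file path is hypothetical):

from zrb.util.file import read_file_with_line_numbers

# For a file containing "alpha\nbeta", this prints:
#   1 | alpha
#   2 | beta
print(read_file_with_line_numbers("~/notes.txt"))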
@@ -21,5 +36,5 @@ def write_file(file_path: str, content: str | list[str]):
     content = content.rstrip("\n")
     if should_add_eol:
         content += "\n"
-    with open(file_path, "w") as f:
+    with open(os.path.abspath(os.path.expanduser(file_path)), "w") as f:
         f.write(content)
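Both the read and write paths now normalize their argument with os.path.expanduser and os.path.abspath, so tilde-prefixed and relative paths behave consistently. A quick round-trip sketch (the file path is hypothetical):

from zrb.util.file import read_file, write_file

# None entries in a list payload are dropped, per the write_file body above.
write_file("~/demo.txt", ["first line", None, "second line"])
print(read_file("~/demo.txt"))  # the two lines, with trailing EOL handling as in the hunk above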