zrb-1.5.6-py3-none-any.whl → zrb-1.5.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/builtin/llm/tool/file.py +8 -7
- zrb/builtin/llm/tool/rag.py +4 -3
- zrb/task/cmd_task.py +2 -1
- zrb/task/llm/agent.py +141 -0
- zrb/task/llm/config.py +83 -0
- zrb/task/llm/context.py +95 -0
- zrb/task/llm/{context_enricher.py → context_enrichment.py} +52 -6
- zrb/task/llm/history.py +153 -3
- zrb/task/llm/history_summarization.py +170 -0
- zrb/task/llm/prompt.py +87 -0
- zrb/task/llm/typing.py +3 -0
- zrb/task/llm_task.py +135 -320
- zrb/util/file.py +9 -0
- {zrb-1.5.6.dist-info → zrb-1.5.8.dist-info}/METADATA +1 -1
- {zrb-1.5.6.dist-info → zrb-1.5.8.dist-info}/RECORD +17 -14
- zrb/task/llm/agent_runner.py +0 -53
- zrb/task/llm/default_context.py +0 -44
- zrb/task/llm/history_summarizer.py +0 -71
- {zrb-1.5.6.dist-info → zrb-1.5.8.dist-info}/WHEEL +0 -0
- {zrb-1.5.6.dist-info → zrb-1.5.8.dist-info}/entry_points.txt +0 -0
zrb/builtin/llm/tool/file.py
CHANGED
@@ -4,8 +4,7 @@ import os
|
|
4
4
|
import re
|
5
5
|
from typing import Any, Optional
|
6
6
|
|
7
|
-
from zrb.util.file import read_file
|
8
|
-
from zrb.util.file import write_file as _write_file
|
7
|
+
from zrb.util.file import read_file, read_file_with_line_numbers, write_file
|
9
8
|
|
10
9
|
DEFAULT_EXCLUDED_PATTERNS = [
|
11
10
|
# Common Python artifacts
|
@@ -182,7 +181,7 @@ def read_from_file(
|
|
182
181
|
start_line: Optional[int] = None,
|
183
182
|
end_line: Optional[int] = None,
|
184
183
|
) -> str:
|
185
|
-
"""Read file content (or specific lines) at a path.
|
184
|
+
"""Read file content (or specific lines) at a path, including line numbers.
|
186
185
|
Args:
|
187
186
|
path (str): Path to read. Pass exactly as provided, including '~'.
|
188
187
|
start_line (Optional[int]): Starting line number (1-based).
|
@@ -191,6 +190,7 @@ def read_from_file(
|
|
191
190
|
Defaults to None (end of file).
|
192
191
|
Returns:
|
193
192
|
str: JSON: {"path": "...", "content": "...", "start_line": N, ...} or {"error": "..."}
|
193
|
+
The content includes line numbers.
|
194
194
|
Raises:
|
195
195
|
Exception: If an error occurs.
|
196
196
|
"""
|
@@ -199,7 +199,7 @@ def read_from_file(
|
|
199
199
|
# Check if file exists
|
200
200
|
if not os.path.exists(abs_path):
|
201
201
|
return json.dumps({"error": f"File {path} does not exist"})
|
202
|
-
content =
|
202
|
+
content = read_file_with_line_numbers(abs_path)
|
203
203
|
lines = content.splitlines()
|
204
204
|
total_lines = len(lines)
|
205
205
|
# Adjust line indices (convert from 1-based to 0-based)
|
@@ -259,7 +259,7 @@ def write_to_file(
|
|
259
259
|
directory = os.path.dirname(abs_path)
|
260
260
|
if directory and not os.path.exists(directory):
|
261
261
|
os.makedirs(directory, exist_ok=True)
|
262
|
-
|
262
|
+
write_file(abs_path, content)
|
263
263
|
result_data = {"success": True, "path": path}
|
264
264
|
if warning:
|
265
265
|
result_data["warning"] = warning
|
@@ -391,6 +391,7 @@ def apply_diff(
|
|
391
391
|
replace_marker (str): Marker for end of replacement block.
|
392
392
|
Defaults to ">>>>>> REPLACE".
|
393
393
|
SEARCH block must exactly match file content including whitespace/indentation.
|
394
|
+
SEARCH block should NOT contains line numbers
|
394
395
|
Format example:
|
395
396
|
[Search Marker, e.g., <<<<<< SEARCH]
|
396
397
|
:start_line:10
|
@@ -414,7 +415,7 @@ def apply_diff(
|
|
414
415
|
return json.dumps(
|
415
416
|
{"success": False, "path": path, "error": f"File not found at {path}"}
|
416
417
|
)
|
417
|
-
content =
|
418
|
+
content = read_file(abs_path)
|
418
419
|
lines = content.splitlines()
|
419
420
|
if start_line < 1 or end_line > len(lines) or start_line > end_line:
|
420
421
|
return json.dumps(
|
@@ -444,7 +445,7 @@ def apply_diff(
|
|
444
445
|
new_content = "\n".join(new_lines)
|
445
446
|
if content.endswith("\n"):
|
446
447
|
new_content += "\n"
|
447
|
-
|
448
|
+
write_file(abs_path, new_content)
|
448
449
|
return json.dumps({"success": True, "path": path})
|
449
450
|
except ValueError as e:
|
450
451
|
raise ValueError(f"Error parsing diff: {e}")
|
zrb/builtin/llm/tool/rag.py
CHANGED
@@ -152,13 +152,14 @@ def create_rag_from_directory(
|
|
152
152
|
|
153
153
|
retrieve.__name__ = tool_name
|
154
154
|
retrieve.__doc__ = dedent(
|
155
|
-
f"""
|
155
|
+
f"""
|
156
|
+
{tool_description}
|
156
157
|
Args:
|
157
158
|
query (str): The user query to search for in documents.
|
158
159
|
Returns:
|
159
160
|
str: JSON string with search results: {{"ids": [...], "documents": [...], ...}}
|
160
|
-
|
161
|
-
)
|
161
|
+
"""
|
162
|
+
).strip()
|
162
163
|
return retrieve
|
163
164
|
|
164
165
|
|
zrb/task/cmd_task.py
CHANGED
@@ -130,7 +130,8 @@ class CmdTask(BaseTask):
|
|
130
130
|
partial(ctx.print, plain=True) if self._should_plain_print else ctx.print
|
131
131
|
)
|
132
132
|
xcom_pid_key = f"{self.name}-pid"
|
133
|
-
ctx.xcom
|
133
|
+
if xcom_pid_key not in ctx.xcom:
|
134
|
+
ctx.xcom[xcom_pid_key] = Xcom([])
|
134
135
|
cmd_result, return_code = await run_command(
|
135
136
|
cmd=[shell, shell_flag, cmd_script],
|
136
137
|
cwd=cwd,
|
zrb/task/llm/agent.py
ADDED
@@ -0,0 +1,141 @@
|
|
1
|
+
import inspect
|
2
|
+
from collections.abc import Callable
|
3
|
+
|
4
|
+
from openai import APIError
|
5
|
+
from pydantic_ai import Agent, Tool
|
6
|
+
from pydantic_ai.agent import AgentRun
|
7
|
+
from pydantic_ai.mcp import MCPServer
|
8
|
+
from pydantic_ai.messages import ModelMessagesTypeAdapter
|
9
|
+
from pydantic_ai.models import Model
|
10
|
+
from pydantic_ai.settings import ModelSettings
|
11
|
+
|
12
|
+
from zrb.context.any_context import AnyContext
|
13
|
+
from zrb.context.any_shared_context import AnySharedContext
|
14
|
+
from zrb.task.llm.error import extract_api_error_details
|
15
|
+
from zrb.task.llm.print_node import print_node
|
16
|
+
from zrb.task.llm.tool_wrapper import wrap_tool
|
17
|
+
from zrb.task.llm.typing import ListOfDict
|
18
|
+
|
19
|
+
ToolOrCallable = Tool | Callable
|
20
|
+
|
21
|
+
|
22
|
+
def create_agent_instance(
|
23
|
+
ctx: AnyContext,
|
24
|
+
model: str | Model | None,
|
25
|
+
system_prompt: str,
|
26
|
+
model_settings: ModelSettings | None,
|
27
|
+
tools_attr: (
|
28
|
+
list[ToolOrCallable] | Callable[[AnySharedContext], list[ToolOrCallable]]
|
29
|
+
),
|
30
|
+
additional_tools: list[ToolOrCallable],
|
31
|
+
mcp_servers_attr: list[MCPServer] | Callable[[AnySharedContext], list[MCPServer]],
|
32
|
+
additional_mcp_servers: list[MCPServer],
|
33
|
+
) -> Agent:
|
34
|
+
"""Creates a new Agent instance with configured tools and servers."""
|
35
|
+
tools_or_callables = list(tools_attr(ctx) if callable(tools_attr) else tools_attr)
|
36
|
+
tools_or_callables.extend(additional_tools)
|
37
|
+
tools = []
|
38
|
+
for tool_or_callable in tools_or_callables:
|
39
|
+
if isinstance(tool_or_callable, Tool):
|
40
|
+
tools.append(tool_or_callable)
|
41
|
+
else:
|
42
|
+
# Inspect original callable for 'ctx' parameter (pydantic-ai context)
|
43
|
+
original_sig = inspect.signature(tool_or_callable)
|
44
|
+
takes_ctx = "ctx" in original_sig.parameters
|
45
|
+
wrapped_tool = wrap_tool(tool_or_callable)
|
46
|
+
tools.append(Tool(wrapped_tool, takes_ctx=takes_ctx))
|
47
|
+
|
48
|
+
mcp_servers = list(
|
49
|
+
mcp_servers_attr(ctx) if callable(mcp_servers_attr) else mcp_servers_attr
|
50
|
+
)
|
51
|
+
mcp_servers.extend(additional_mcp_servers)
|
52
|
+
|
53
|
+
return Agent(
|
54
|
+
model=model,
|
55
|
+
system_prompt=system_prompt,
|
56
|
+
tools=tools,
|
57
|
+
mcp_servers=mcp_servers,
|
58
|
+
model_settings=model_settings,
|
59
|
+
retries=3, # Consider making retries configurable?
|
60
|
+
)
|
61
|
+
|
62
|
+
|
63
|
+
def get_agent(
|
64
|
+
ctx: AnyContext,
|
65
|
+
agent_attr: Agent | Callable[[AnySharedContext], Agent] | None,
|
66
|
+
model: str | Model | None,
|
67
|
+
system_prompt: str,
|
68
|
+
model_settings: ModelSettings | None,
|
69
|
+
tools_attr: (
|
70
|
+
list[ToolOrCallable] | Callable[[AnySharedContext], list[ToolOrCallable]]
|
71
|
+
),
|
72
|
+
additional_tools: list[ToolOrCallable],
|
73
|
+
mcp_servers_attr: list[MCPServer] | Callable[[AnySharedContext], list[MCPServer]],
|
74
|
+
additional_mcp_servers: list[MCPServer],
|
75
|
+
) -> Agent:
|
76
|
+
"""Retrieves the configured Agent instance or creates one if necessary."""
|
77
|
+
if isinstance(agent_attr, Agent):
|
78
|
+
return agent_attr
|
79
|
+
if callable(agent_attr):
|
80
|
+
agent_instance = agent_attr(ctx)
|
81
|
+
if not isinstance(agent_instance, Agent):
|
82
|
+
err_msg = (
|
83
|
+
"Callable agent factory did not return an Agent instance, "
|
84
|
+
f"got: {type(agent_instance)}"
|
85
|
+
)
|
86
|
+
raise TypeError(err_msg)
|
87
|
+
return agent_instance
|
88
|
+
# If no agent provided, create one using the configuration
|
89
|
+
return create_agent_instance(
|
90
|
+
ctx=ctx,
|
91
|
+
model=model,
|
92
|
+
system_prompt=system_prompt,
|
93
|
+
model_settings=model_settings,
|
94
|
+
tools_attr=tools_attr,
|
95
|
+
additional_tools=additional_tools,
|
96
|
+
mcp_servers_attr=mcp_servers_attr,
|
97
|
+
additional_mcp_servers=additional_mcp_servers,
|
98
|
+
)
|
99
|
+
|
100
|
+
|
101
|
+
async def run_agent_iteration(
|
102
|
+
ctx: AnyContext,
|
103
|
+
agent: Agent,
|
104
|
+
user_prompt: str,
|
105
|
+
history_list: ListOfDict,
|
106
|
+
) -> AgentRun:
|
107
|
+
"""
|
108
|
+
Runs a single iteration of the agent execution loop.
|
109
|
+
|
110
|
+
Args:
|
111
|
+
ctx: The task context.
|
112
|
+
agent: The Pydantic AI agent instance.
|
113
|
+
user_prompt: The user's input prompt.
|
114
|
+
history_list: The current conversation history.
|
115
|
+
|
116
|
+
Returns:
|
117
|
+
The agent run result object.
|
118
|
+
|
119
|
+
Raises:
|
120
|
+
Exception: If any error occurs during agent execution.
|
121
|
+
"""
|
122
|
+
async with agent.run_mcp_servers():
|
123
|
+
async with agent.iter(
|
124
|
+
user_prompt=user_prompt,
|
125
|
+
message_history=ModelMessagesTypeAdapter.validate_python(history_list),
|
126
|
+
) as agent_run:
|
127
|
+
async for node in agent_run:
|
128
|
+
# Each node represents a step in the agent's execution
|
129
|
+
# Reference: https://ai.pydantic.dev/agents/#streaming
|
130
|
+
try:
|
131
|
+
await print_node(ctx.print, agent_run, node)
|
132
|
+
except APIError as e:
|
133
|
+
# Extract detailed error information from the response
|
134
|
+
error_details = extract_api_error_details(e)
|
135
|
+
ctx.log_error(f"API Error: {error_details}")
|
136
|
+
raise
|
137
|
+
except Exception as e:
|
138
|
+
ctx.log_error(f"Error processing node: {str(e)}")
|
139
|
+
ctx.log_error(f"Error type: {type(e).__name__}")
|
140
|
+
raise
|
141
|
+
return agent_run
|
zrb/task/llm/config.py
ADDED
@@ -0,0 +1,83 @@
|
|
1
|
+
from typing import Callable
|
2
|
+
|
3
|
+
from pydantic_ai.models import Model
|
4
|
+
from pydantic_ai.settings import ModelSettings
|
5
|
+
|
6
|
+
from zrb.attr.type import StrAttr, fstring
|
7
|
+
from zrb.context.any_context import AnyContext
|
8
|
+
from zrb.context.any_shared_context import AnySharedContext
|
9
|
+
from zrb.llm_config import LLMConfig
|
10
|
+
from zrb.llm_config import llm_config as default_llm_config
|
11
|
+
from zrb.util.attr import get_attr
|
12
|
+
|
13
|
+
|
14
|
+
def get_model_settings(
|
15
|
+
ctx: AnyContext,
|
16
|
+
model_settings_attr: (
|
17
|
+
ModelSettings | Callable[[AnySharedContext], ModelSettings] | None
|
18
|
+
),
|
19
|
+
) -> ModelSettings | None:
|
20
|
+
"""Gets the model settings, resolving callables if necessary."""
|
21
|
+
if callable(model_settings_attr):
|
22
|
+
return model_settings_attr(ctx)
|
23
|
+
return model_settings_attr
|
24
|
+
|
25
|
+
|
26
|
+
def get_model_base_url(
|
27
|
+
ctx: AnyContext,
|
28
|
+
model_base_url_attr: StrAttr | None,
|
29
|
+
render_model_base_url: bool,
|
30
|
+
) -> str | None:
|
31
|
+
"""Gets the model base URL, rendering if configured."""
|
32
|
+
base_url = get_attr(
|
33
|
+
ctx, model_base_url_attr, None, auto_render=render_model_base_url
|
34
|
+
)
|
35
|
+
if isinstance(base_url, str) or base_url is None:
|
36
|
+
return base_url
|
37
|
+
raise ValueError(f"Invalid model base URL: {base_url}")
|
38
|
+
|
39
|
+
|
40
|
+
def get_model_api_key(
|
41
|
+
ctx: AnyContext,
|
42
|
+
model_api_key_attr: StrAttr | None,
|
43
|
+
render_model_api_key: bool,
|
44
|
+
) -> str | None:
|
45
|
+
"""Gets the model API key, rendering if configured."""
|
46
|
+
api_key = get_attr(ctx, model_api_key_attr, None, auto_render=render_model_api_key)
|
47
|
+
if isinstance(api_key, str) or api_key is None:
|
48
|
+
return api_key
|
49
|
+
raise ValueError(f"Invalid model API key: {api_key}")
|
50
|
+
|
51
|
+
|
52
|
+
def get_model(
|
53
|
+
ctx: AnyContext,
|
54
|
+
model_attr: Callable[[AnySharedContext], Model | str | fstring] | Model | None,
|
55
|
+
render_model: bool,
|
56
|
+
model_base_url_attr: StrAttr | None,
|
57
|
+
render_model_base_url: bool,
|
58
|
+
model_api_key_attr: StrAttr | None,
|
59
|
+
render_model_api_key: bool,
|
60
|
+
) -> str | Model | None:
|
61
|
+
"""Gets the model instance or name, handling defaults and configuration."""
|
62
|
+
model = get_attr(ctx, model_attr, None, auto_render=render_model)
|
63
|
+
if model is None:
|
64
|
+
return default_llm_config.get_default_model()
|
65
|
+
if isinstance(model, str):
|
66
|
+
model_base_url = get_model_base_url(
|
67
|
+
ctx, model_base_url_attr, render_model_base_url
|
68
|
+
)
|
69
|
+
model_api_key = get_model_api_key(ctx, model_api_key_attr, render_model_api_key)
|
70
|
+
llm_config = LLMConfig(
|
71
|
+
default_model_name=model,
|
72
|
+
default_base_url=model_base_url,
|
73
|
+
default_api_key=model_api_key,
|
74
|
+
)
|
75
|
+
if model_base_url is None and model_api_key is None:
|
76
|
+
default_model_provider = default_llm_config.get_default_model_provider()
|
77
|
+
if default_model_provider is not None:
|
78
|
+
llm_config.set_default_provider(default_model_provider)
|
79
|
+
return llm_config.get_default_model()
|
80
|
+
# If it's already a Model instance, return it directly
|
81
|
+
if isinstance(model, Model):
|
82
|
+
return model
|
83
|
+
raise ValueError(f"Invalid model type resolved: {type(model)}, value: {model}")
|
zrb/task/llm/context.py
ADDED
@@ -0,0 +1,95 @@
|
|
1
|
+
import datetime
|
2
|
+
import inspect
|
3
|
+
import os
|
4
|
+
import platform
|
5
|
+
import re
|
6
|
+
from collections.abc import Callable
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
from zrb.context.any_context import AnyContext
|
10
|
+
from zrb.context.any_shared_context import AnySharedContext
|
11
|
+
from zrb.util.attr import get_attr
|
12
|
+
from zrb.util.file import read_dir, read_file_with_line_numbers
|
13
|
+
|
14
|
+
|
15
|
+
def get_default_context(user_message: str) -> dict[str, Any]:
|
16
|
+
"""Generates default context including time, OS, and file references."""
|
17
|
+
references = re.findall(r"@(\S+)", user_message)
|
18
|
+
current_references = []
|
19
|
+
|
20
|
+
for ref in references:
|
21
|
+
resource_path = os.path.abspath(os.path.expanduser(ref))
|
22
|
+
if os.path.isfile(resource_path):
|
23
|
+
content = read_file_with_line_numbers(resource_path)
|
24
|
+
current_references.append(
|
25
|
+
{
|
26
|
+
"reference": ref,
|
27
|
+
"name": resource_path,
|
28
|
+
"type": "file",
|
29
|
+
"note": "line numbers are included in the content",
|
30
|
+
"content": content,
|
31
|
+
}
|
32
|
+
)
|
33
|
+
elif os.path.isdir(resource_path):
|
34
|
+
content = read_dir(resource_path)
|
35
|
+
current_references.append(
|
36
|
+
{
|
37
|
+
"reference": ref,
|
38
|
+
"name": resource_path,
|
39
|
+
"type": "directory",
|
40
|
+
"content": content,
|
41
|
+
}
|
42
|
+
)
|
43
|
+
|
44
|
+
return {
|
45
|
+
"current_time": datetime.datetime.now().isoformat(),
|
46
|
+
"current_working_directory": os.getcwd(),
|
47
|
+
"current_os": platform.system(),
|
48
|
+
"os_version": platform.version(),
|
49
|
+
"python_version": platform.python_version(),
|
50
|
+
"current_references": current_references,
|
51
|
+
}
|
52
|
+
|
53
|
+
|
54
|
+
def get_conversation_context(
|
55
|
+
ctx: AnyContext,
|
56
|
+
conversation_context_attr: (
|
57
|
+
dict[str, Any] | Callable[[AnySharedContext], dict[str, Any]] | None
|
58
|
+
),
|
59
|
+
) -> dict[str, Any]:
|
60
|
+
"""
|
61
|
+
Retrieves the conversation context.
|
62
|
+
If a value in the context dict is callable, it executes it with ctx.
|
63
|
+
"""
|
64
|
+
raw_context = get_attr(
|
65
|
+
ctx, conversation_context_attr, {}, auto_render=False
|
66
|
+
) # Context usually shouldn't be rendered
|
67
|
+
if not isinstance(raw_context, dict):
|
68
|
+
ctx.log_warning(
|
69
|
+
f"Conversation context resolved to type {type(raw_context)}, "
|
70
|
+
"expected dict. Returning empty context."
|
71
|
+
)
|
72
|
+
return {}
|
73
|
+
# If conversation_context contains callable value, execute them.
|
74
|
+
processed_context: dict[str, Any] = {}
|
75
|
+
for key, value in raw_context.items():
|
76
|
+
if callable(value):
|
77
|
+
try:
|
78
|
+
# Check if the callable expects 'ctx'
|
79
|
+
sig = inspect.signature(value)
|
80
|
+
if "ctx" in sig.parameters:
|
81
|
+
processed_context[key] = value(ctx)
|
82
|
+
else:
|
83
|
+
processed_context[key] = value()
|
84
|
+
except Exception as e:
|
85
|
+
ctx.log_warning(
|
86
|
+
f"Error executing callable for context key '{key}': {e}. "
|
87
|
+
"Skipping."
|
88
|
+
)
|
89
|
+
processed_context[key] = None
|
90
|
+
else:
|
91
|
+
processed_context[key] = value
|
92
|
+
return processed_context
|
93
|
+
|
94
|
+
|
95
|
+
# Context enrichment functions moved to context_enrichment.py
|
zrb/task/llm/{context_enricher.py → context_enrichment.py}
RENAMED
@@ -8,12 +8,13 @@ from pydantic_ai import Agent
|
|
8
8
|
from pydantic_ai.models import Model
|
9
9
|
from pydantic_ai.settings import ModelSettings
|
10
10
|
|
11
|
+
from zrb.attr.type import BoolAttr
|
11
12
|
from zrb.context.any_context import AnyContext
|
12
|
-
from zrb.task.llm.
|
13
|
-
from zrb.task.llm.
|
13
|
+
from zrb.task.llm.agent import run_agent_iteration
|
14
|
+
from zrb.task.llm.typing import ListOfDict
|
15
|
+
from zrb.util.attr import get_bool_attr
|
14
16
|
|
15
17
|
|
16
|
-
# Configuration model for context enrichment
|
17
18
|
class EnrichmentConfig(BaseModel):
|
18
19
|
model_config = {"arbitrary_types_allowed": True}
|
19
20
|
model: Model | str | None = None
|
@@ -74,13 +75,58 @@ async def enrich_context(
|
|
74
75
|
response = enrichment_run.result.data.response
|
75
76
|
if response:
|
76
77
|
conversation_context.update(response)
|
77
|
-
ctx.log_info("
|
78
|
+
ctx.log_info("Context enriched based on history.")
|
78
79
|
ctx.log_info(
|
79
|
-
f"
|
80
|
+
f"Updated conversation context: {json.dumps(conversation_context)}"
|
80
81
|
)
|
81
82
|
else:
|
82
|
-
ctx.log_warning("Context enrichment
|
83
|
+
ctx.log_warning("Context enrichment returned no data")
|
83
84
|
except Exception as e:
|
84
85
|
ctx.log_warning(f"Error during context enrichment LLM call: {e}")
|
85
86
|
traceback.print_exc()
|
86
87
|
return conversation_context
|
88
|
+
|
89
|
+
|
90
|
+
def should_enrich_context(
|
91
|
+
ctx: AnyContext,
|
92
|
+
history_list: ListOfDict,
|
93
|
+
should_enrich_context_attr: BoolAttr,
|
94
|
+
render_enrich_context: bool,
|
95
|
+
) -> bool:
|
96
|
+
"""Determines if context enrichment should occur based on history and config."""
|
97
|
+
if len(history_list) == 0:
|
98
|
+
return False
|
99
|
+
return get_bool_attr(
|
100
|
+
ctx,
|
101
|
+
should_enrich_context_attr,
|
102
|
+
True, # Default to True if not specified
|
103
|
+
auto_render=render_enrich_context,
|
104
|
+
)
|
105
|
+
|
106
|
+
|
107
|
+
async def maybe_enrich_context(
|
108
|
+
ctx: AnyContext,
|
109
|
+
history_list: ListOfDict,
|
110
|
+
conversation_context: dict[str, Any],
|
111
|
+
should_enrich_context_attr: BoolAttr,
|
112
|
+
render_enrich_context: bool,
|
113
|
+
model: str | Model | None,
|
114
|
+
model_settings: ModelSettings | None,
|
115
|
+
context_enrichment_prompt: str,
|
116
|
+
) -> dict[str, Any]:
|
117
|
+
"""Enriches context based on history if enabled."""
|
118
|
+
if should_enrich_context(
|
119
|
+
ctx, history_list, should_enrich_context_attr, render_enrich_context
|
120
|
+
):
|
121
|
+
# Use the enrich_context function now defined in this file
|
122
|
+
return await enrich_context(
|
123
|
+
ctx=ctx,
|
124
|
+
config=EnrichmentConfig(
|
125
|
+
model=model,
|
126
|
+
settings=model_settings,
|
127
|
+
prompt=context_enrichment_prompt,
|
128
|
+
),
|
129
|
+
conversation_context=conversation_context,
|
130
|
+
history_list=history_list,
|
131
|
+
)
|
132
|
+
return conversation_context
|
zrb/task/llm/history.py
CHANGED
@@ -5,12 +5,14 @@ from typing import Any, Optional
|
|
5
5
|
|
6
6
|
from pydantic import BaseModel
|
7
7
|
|
8
|
+
from zrb.attr.type import StrAttr
|
8
9
|
from zrb.context.any_context import AnyContext
|
9
|
-
from zrb.
|
10
|
+
from zrb.context.any_shared_context import AnySharedContext
|
11
|
+
from zrb.task.llm.typing import ListOfDict
|
12
|
+
from zrb.util.attr import get_str_attr
|
13
|
+
from zrb.util.file import read_file, write_file
|
10
14
|
from zrb.util.run import run_async
|
11
15
|
|
12
|
-
ListOfDict = list[dict[str, Any]]
|
13
|
-
|
14
16
|
|
15
17
|
# Define the new ConversationHistoryData model
|
16
18
|
class ConversationHistoryData(BaseModel):
|
@@ -90,3 +92,151 @@ class ConversationHistoryData(BaseModel):
|
|
90
92
|
f"Error validating/parsing history data from {source}: {e}. Ignoring."
|
91
93
|
)
|
92
94
|
return None
|
95
|
+
|
96
|
+
|
97
|
+
def get_history_file(
|
98
|
+
ctx: AnyContext,
|
99
|
+
conversation_history_file_attr: StrAttr | None,
|
100
|
+
render_history_file: bool,
|
101
|
+
) -> str:
|
102
|
+
"""Gets the path to the conversation history file, rendering if configured."""
|
103
|
+
return get_str_attr(
|
104
|
+
ctx,
|
105
|
+
conversation_history_file_attr,
|
106
|
+
"",
|
107
|
+
auto_render=render_history_file,
|
108
|
+
)
|
109
|
+
|
110
|
+
|
111
|
+
async def read_conversation_history(
|
112
|
+
ctx: AnyContext,
|
113
|
+
conversation_history_reader: (
|
114
|
+
Callable[[AnySharedContext], ConversationHistoryData | dict | list | None]
|
115
|
+
| None
|
116
|
+
),
|
117
|
+
conversation_history_file_attr: StrAttr | None,
|
118
|
+
render_history_file: bool,
|
119
|
+
conversation_history_attr: (
|
120
|
+
ConversationHistoryData
|
121
|
+
| Callable[[AnySharedContext], ConversationHistoryData | dict | list]
|
122
|
+
| dict
|
123
|
+
| list
|
124
|
+
),
|
125
|
+
) -> ConversationHistoryData:
|
126
|
+
"""Reads conversation history from reader, file, or attribute, with validation."""
|
127
|
+
history_file = get_history_file(
|
128
|
+
ctx, conversation_history_file_attr, render_history_file
|
129
|
+
)
|
130
|
+
# Use the class method defined above
|
131
|
+
history_data = await ConversationHistoryData.read_from_sources(
|
132
|
+
ctx=ctx,
|
133
|
+
reader=conversation_history_reader,
|
134
|
+
file_path=history_file,
|
135
|
+
)
|
136
|
+
if history_data:
|
137
|
+
return history_data
|
138
|
+
# Priority 3: Callable or direct conversation_history attribute
|
139
|
+
raw_data_attr: Any = None
|
140
|
+
if callable(conversation_history_attr):
|
141
|
+
try:
|
142
|
+
raw_data_attr = await run_async(conversation_history_attr(ctx))
|
143
|
+
except Exception as e:
|
144
|
+
ctx.log_warning(
|
145
|
+
f"Error executing callable conversation_history attribute: {e}. "
|
146
|
+
"Ignoring."
|
147
|
+
)
|
148
|
+
if raw_data_attr is None:
|
149
|
+
raw_data_attr = conversation_history_attr
|
150
|
+
if raw_data_attr:
|
151
|
+
# Use the class method defined above
|
152
|
+
history_data = ConversationHistoryData.parse_and_validate(
|
153
|
+
ctx, raw_data_attr, "attribute"
|
154
|
+
)
|
155
|
+
if history_data:
|
156
|
+
return history_data
|
157
|
+
# Fallback: Return default value
|
158
|
+
return ConversationHistoryData()
|
159
|
+
|
160
|
+
|
161
|
+
async def write_conversation_history(
|
162
|
+
ctx: AnyContext,
|
163
|
+
history_data: ConversationHistoryData,
|
164
|
+
conversation_history_writer: (
|
165
|
+
Callable[[AnySharedContext, ConversationHistoryData], None] | None
|
166
|
+
),
|
167
|
+
conversation_history_file_attr: StrAttr | None,
|
168
|
+
render_history_file: bool,
|
169
|
+
):
|
170
|
+
"""Writes conversation history using the writer or to a file."""
|
171
|
+
if conversation_history_writer is not None:
|
172
|
+
await run_async(conversation_history_writer(ctx, history_data))
|
173
|
+
history_file = get_history_file(
|
174
|
+
ctx, conversation_history_file_attr, render_history_file
|
175
|
+
)
|
176
|
+
if history_file != "":
|
177
|
+
write_file(history_file, history_data.model_dump_json(indent=2))
|
178
|
+
|
179
|
+
|
180
|
+
async def prepare_initial_state(
|
181
|
+
ctx: AnyContext,
|
182
|
+
conversation_history_reader: (
|
183
|
+
Callable[[AnySharedContext], ConversationHistoryData | dict | list | None]
|
184
|
+
| None
|
185
|
+
),
|
186
|
+
conversation_history_file_attr: StrAttr | None,
|
187
|
+
render_history_file: bool,
|
188
|
+
conversation_history_attr: (
|
189
|
+
ConversationHistoryData
|
190
|
+
| Callable[[AnySharedContext], ConversationHistoryData | dict | list]
|
191
|
+
| dict
|
192
|
+
| list
|
193
|
+
),
|
194
|
+
conversation_context_getter: Callable[[AnyContext], dict[str, Any]],
|
195
|
+
) -> tuple[ListOfDict, dict[str, Any]]:
|
196
|
+
"""Reads history and prepares the initial conversation context."""
|
197
|
+
history_data: ConversationHistoryData = await read_conversation_history(
|
198
|
+
ctx,
|
199
|
+
conversation_history_reader,
|
200
|
+
conversation_history_file_attr,
|
201
|
+
render_history_file,
|
202
|
+
conversation_history_attr,
|
203
|
+
)
|
204
|
+
# Clean the history list to remove context from historical user prompts
|
205
|
+
cleaned_history_list = []
|
206
|
+
for interaction in history_data.history:
|
207
|
+
cleaned_history_list.append(
|
208
|
+
remove_context_from_interaction_history(interaction)
|
209
|
+
)
|
210
|
+
conversation_context = conversation_context_getter(ctx)
|
211
|
+
# Merge history context from loaded data without overwriting existing keys
|
212
|
+
for key, value in history_data.context.items():
|
213
|
+
if key not in conversation_context:
|
214
|
+
conversation_context[key] = value
|
215
|
+
# Return the CLEANED history list
|
216
|
+
return cleaned_history_list, conversation_context
|
217
|
+
|
218
|
+
|
219
|
+
def remove_context_from_interaction_history(
|
220
|
+
interaction: dict[str, Any],
|
221
|
+
) -> dict[str, Any]:
|
222
|
+
try:
|
223
|
+
cleaned_interaction = json.loads(json.dumps(interaction))
|
224
|
+
except Exception:
|
225
|
+
# Fallback to shallow copy if not JSON serializable (less safe)
|
226
|
+
cleaned_interaction = interaction.copy()
|
227
|
+
if "parts" in cleaned_interaction and isinstance(
|
228
|
+
cleaned_interaction["parts"], list
|
229
|
+
):
|
230
|
+
for part in cleaned_interaction["parts"]:
|
231
|
+
is_user_prompt = part.get("part_kind") == "user-prompt"
|
232
|
+
has_str_content = isinstance(part.get("content"), str)
|
233
|
+
if is_user_prompt and has_str_content:
|
234
|
+
content = part["content"]
|
235
|
+
user_message_marker = "# User Message\n"
|
236
|
+
marker_index = content.find(user_message_marker)
|
237
|
+
if marker_index != -1:
|
238
|
+
# Extract message after the marker and strip whitespace
|
239
|
+
start_index = marker_index + len(user_message_marker)
|
240
|
+
part["content"] = content[start_index:].strip()
|
241
|
+
# else: If marker not found, leave content as is (old format/error)
|
242
|
+
return cleaned_interaction
|