zrb 1.5.7__py3-none-any.whl → 1.5.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/builtin/llm/tool/rag.py +4 -3
- zrb/task/cmd_task.py +2 -1
- zrb/task/llm/agent.py +141 -0
- zrb/task/llm/config.py +83 -0
- zrb/task/llm/context.py +95 -0
- zrb/task/llm/{context_enricher.py → context_enrichment.py} +52 -6
- zrb/task/llm/history.py +153 -3
- zrb/task/llm/history_summarization.py +170 -0
- zrb/task/llm/prompt.py +87 -0
- zrb/task/llm/typing.py +3 -0
- zrb/task/llm_task.py +135 -320
- {zrb-1.5.7.dist-info → zrb-1.5.8.dist-info}/METADATA +1 -1
- {zrb-1.5.7.dist-info → zrb-1.5.8.dist-info}/RECORD +15 -12
- zrb/task/llm/agent_runner.py +0 -53
- zrb/task/llm/default_context.py +0 -45
- zrb/task/llm/history_summarizer.py +0 -71
- {zrb-1.5.7.dist-info → zrb-1.5.8.dist-info}/WHEEL +0 -0
- {zrb-1.5.7.dist-info → zrb-1.5.8.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,170 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
from pydantic import BaseModel
|
5
|
+
from pydantic_ai import Agent
|
6
|
+
from pydantic_ai.models import Model
|
7
|
+
from pydantic_ai.settings import ModelSettings
|
8
|
+
|
9
|
+
from zrb.attr.type import BoolAttr, IntAttr
|
10
|
+
from zrb.context.any_context import AnyContext
|
11
|
+
from zrb.task.llm.agent import run_agent_iteration
|
12
|
+
from zrb.task.llm.typing import ListOfDict
|
13
|
+
from zrb.util.attr import get_bool_attr, get_int_attr
|
14
|
+
|
15
|
+
|
16
|
+
def get_history_part_len(history_list: ListOfDict) -> int:
    """Count the message parts contained in a serialized history list.

    An entry that has a ``"parts"`` key contributes ``len(entry["parts"])``;
    any entry without that key counts as exactly one part.

    Args:
        history_list: The conversation history as a list of dicts.

    Returns:
        The total number of parts across all entries (0 for an empty list).
    """
    return sum(
        len(entry["parts"]) if "parts" in entry else 1 for entry in history_list
    )
|
25
|
+
|
26
|
+
|
27
|
+
def get_history_summarization_threshold(
    ctx: AnyContext,
    history_summarization_threshold_attr: IntAttr,
    render_history_summarization_threshold: bool,
) -> int:
    """Resolve the history-summarization threshold from a task attribute.

    Args:
        ctx: Task context used to read/render the attribute and log warnings.
        history_summarization_threshold_attr: The configured threshold value.
        render_history_summarization_threshold: Whether the attribute is
            auto-rendered before conversion.

    Returns:
        The integer threshold. ``-1`` ("no threshold") is passed as the
        default to ``get_int_attr``. It is also returned, with a logged
        warning, when ``get_int_attr`` raises ``ValueError``.
    """
    no_threshold = -1
    try:
        threshold = get_int_attr(
            ctx,
            history_summarization_threshold_attr,
            no_threshold,
            auto_render=render_history_summarization_threshold,
        )
    except ValueError as error:
        ctx.log_warning(
            f"Could not convert history_summarization_threshold to int: {error}. "
            "Defaulting to -1 (no threshold)."
        )
        return no_threshold
    return threshold
|
46
|
+
|
47
|
+
|
48
|
+
def should_summarize_history(
    ctx: AnyContext,
    history_list: ListOfDict,
    should_summarize_history_attr: BoolAttr,
    render_summarize_history: bool,
    history_summarization_threshold_attr: IntAttr,
    render_history_summarization_threshold: bool,
) -> bool:
    """Decide whether the conversation history should be summarized now.

    Summarization is requested only when all of the following hold:

    * the history contains at least one part;
    * the resolved threshold is non-negative (``-1``, the "no threshold"
      default, and any other negative value disable summarization);
    * the number of history parts has reached the threshold;
    * the ``should_summarize_history`` attribute resolves to ``True``
      (it defaults to ``False``).

    Args:
        ctx: Task context used to read/render attributes.
        history_list: The conversation history as a list of dicts.
        should_summarize_history_attr: Whether summarization is enabled.
        render_summarize_history: Whether to auto-render the enable flag.
        history_summarization_threshold_attr: Minimum number of history
            parts that triggers summarization.
        render_history_summarization_threshold: Whether to auto-render the
            threshold.

    Returns:
        ``True`` if the history should be summarized, otherwise ``False``.
    """
    history_part_len = get_history_part_len(history_list)
    if history_part_len == 0:
        return False
    summarization_threshold = get_history_summarization_threshold(
        ctx,
        history_summarization_threshold_attr,
        render_history_summarization_threshold,
    )
    # -1 is the "no threshold" sentinel. Treat every other negative value the
    # same way instead of letting it trigger summarization on every call.
    if summarization_threshold < 0:
        return False
    if history_part_len < summarization_threshold:
        return False
    return get_bool_attr(
        ctx,
        should_summarize_history_attr,
        False,  # Summarization is opt-in when the flag is not specified.
        auto_render=render_summarize_history,
    )
|
75
|
+
|
76
|
+
|
77
|
+
class SummarizationConfig(BaseModel):
    """Settings for the LLM agent that summarizes conversation history.

    Attributes:
        model: Model instance or model name handed to the summarization
            ``Agent``. ``None`` is passed through unchanged.
        settings: Optional model settings handed to the agent.
        prompt: System prompt for the summarization agent (required).
        retries: Retry count handed to the agent.
    """

    # Lets pydantic accept field types it cannot validate natively
    # (presumably the pydantic_ai ``Model`` class).
    model_config = dict(arbitrary_types_allowed=True)

    model: Model | str | None = None
    settings: ModelSettings | None = None
    prompt: str
    retries: int = 1
|
83
|
+
|
84
|
+
|
85
|
+
async def _run_history_summarization(
    ctx: AnyContext,
    config: SummarizationConfig,
    conversation_context: dict[str, Any],
    history_list: ListOfDict,
) -> tuple[dict[str, Any], bool]:
    """Summarize history into ``conversation_context`` and report success.

    This is the shared implementation behind ``summarize_history`` and
    ``maybe_summarize_history``.

    On success the summary text is stored in place under
    ``conversation_context["history_summary"]``. On a serialization failure,
    an agent error, or an empty result, a warning is logged and the context
    is left untouched. This is a best-effort step, so those errors are not
    raised. Errors from constructing the ``Agent`` itself do propagate.

    Returns:
        The (possibly updated) context and whether a summary was stored.
    """
    ctx.log_info("Attempting to summarize conversation history...")

    summarization_agent = Agent(
        model=config.model,
        system_prompt=config.prompt,
        tools=[],  # No tools needed for summarization
        mcp_servers=[],
        model_settings=config.settings,
        retries=config.retries,
    )

    # Prepare context and history for the summarization prompt.
    try:
        context_json = json.dumps(conversation_context)
        history_to_summarize_json = json.dumps(history_list)
    except Exception as e:  # best effort: never fail the task over this
        ctx.log_warning(f"Error formatting context/history for summarization: {e}")
        return conversation_context, False
    summarization_user_prompt = (
        f"# Current Context\n{context_json}\n\n"
        f"# Conversation History to Summarize\n{history_to_summarize_json}"
    )

    try:
        summary_run = await run_agent_iteration(
            ctx=ctx,
            agent=summarization_agent,
            user_prompt=summarization_user_prompt,
            history_list=[],  # Summarization agent doesn't need prior history
        )
        summary_data = summary_run.result.data if summary_run else None
    except Exception as e:  # best effort: log and keep the original context
        ctx.log_warning(f"Error during history summarization: {e}")
        return conversation_context, False

    if not summary_data:
        ctx.log_warning("History summarization failed or returned no data.")
        return conversation_context, False

    summary_text = str(summary_data)
    # Update context with the new summary.
    conversation_context["history_summary"] = summary_text
    ctx.log_info("History summarized and added/updated in context.")
    ctx.log_info(f"Conversation summary: {summary_text}")
    return conversation_context, True


async def summarize_history(
    ctx: AnyContext,
    config: SummarizationConfig,
    conversation_context: dict[str, Any],
    history_list: ListOfDict,
) -> dict[str, Any]:
    """Run an LLM call to summarize history and store it in the context.

    Failures during serialization or the LLM call are logged, not raised.

    Args:
        ctx: Task context used for logging and running the agent.
        config: Model, settings, prompt and retries for the summarizer.
        conversation_context: Context dict. It is updated in place.
        history_list: The history to summarize.

    Returns:
        ``conversation_context`` itself. It gains or updates a
        ``"history_summary"`` entry on success and is unchanged otherwise.
    """
    updated_context, _ = await _run_history_summarization(
        ctx, config, conversation_context, history_list
    )
    return updated_context


async def maybe_summarize_history(
    ctx: AnyContext,
    history_list: ListOfDict,
    conversation_context: dict[str, Any],
    should_summarize_history_attr: BoolAttr,
    render_summarize_history: bool,
    history_summarization_threshold_attr: IntAttr,
    render_history_summarization_threshold: bool,
    model: str | Model | None,
    model_settings: ModelSettings | None,
    summarization_prompt: str,
) -> tuple[ListOfDict, dict[str, Any]]:
    """Summarize and truncate history when enabled and the threshold is met.

    The history is replaced by an empty list only when a summary was
    actually stored in the context. If summarization is skipped or fails,
    the original history is returned unchanged so no conversation is lost.

    Returns:
        A ``(history_list, conversation_context)`` tuple.
    """
    if not should_summarize_history(
        ctx,
        history_list,
        should_summarize_history_attr,
        render_summarize_history,
        history_summarization_threshold_attr,
        render_history_summarization_threshold,
    ):
        return history_list, conversation_context

    updated_context, summarized = await _run_history_summarization(
        ctx,
        SummarizationConfig(
            model=model,
            settings=model_settings,
            prompt=summarization_prompt,
        ),
        conversation_context,
        history_list,  # Pass the full list for context
    )
    if not summarized:
        # Dropping the history without a summary would silently lose the
        # conversation. Keep it so a later call can retry.
        return history_list, updated_context
    # The summary now carries the conversation, so the raw history can go.
    return [], updated_context
|
zrb/task/llm/prompt.py
ADDED
@@ -0,0 +1,87 @@
|
|
1
|
+
import json
|
2
|
+
from textwrap import dedent
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from zrb.attr.type import StrAttr
|
6
|
+
from zrb.context.any_context import AnyContext
|
7
|
+
from zrb.llm_config import llm_config as default_llm_config
|
8
|
+
from zrb.task.llm.context import get_default_context # Updated import
|
9
|
+
from zrb.util.attr import get_attr, get_str_attr
|
10
|
+
|
11
|
+
|
12
|
+
def get_system_prompt(
    ctx: AnyContext,
    system_prompt_attr: StrAttr | None,
    render_system_prompt: bool,
) -> str:
    """Return the system prompt for the LLM task.

    The configured attribute is read with ``get_attr``, auto-rendered when
    ``render_system_prompt`` is true. Only a ``None`` result falls back to
    the global LLM config's default system prompt. Any other value,
    including an empty string, is returned as-is.
    """
    configured_prompt = get_attr(
        ctx,
        system_prompt_attr,
        None,
        auto_render=render_system_prompt,
    )
    if configured_prompt is None:
        return default_llm_config.get_default_system_prompt()
    return configured_prompt
|
27
|
+
|
28
|
+
|
29
|
+
def get_user_message(
    ctx: AnyContext,
    message_attr: StrAttr | None,
) -> str:
    """Return the user's message, always auto-rendered.

    ``"How are you?"`` is passed to ``get_str_attr`` as the default value.
    """
    fallback_message = "How are you?"
    user_message = get_str_attr(
        ctx, message_attr, fallback_message, auto_render=True
    )
    return user_message
|
35
|
+
|
36
|
+
|
37
|
+
def get_summarization_prompt(
    ctx: AnyContext,
    summarization_prompt_attr: StrAttr | None,
    render_summarization_prompt: bool,
) -> str:
    """Return the system prompt used for history summarization.

    The configured attribute is read with ``get_attr``, auto-rendered when
    ``render_summarization_prompt`` is true. A ``None`` result falls back to
    the global LLM config's default summarization prompt.
    """
    configured_prompt = get_attr(
        ctx,
        summarization_prompt_attr,
        None,
        auto_render=render_summarization_prompt,
    )
    if configured_prompt is None:
        return default_llm_config.get_default_summarization_prompt()
    return configured_prompt
|
52
|
+
|
53
|
+
|
54
|
+
def get_context_enrichment_prompt(
    ctx: AnyContext,
    context_enrichment_prompt_attr: StrAttr | None,
    render_context_enrichment_prompt: bool,
) -> str:
    """Return the system prompt used for context enrichment.

    The configured attribute is read with ``get_attr``, auto-rendered when
    ``render_context_enrichment_prompt`` is true. A ``None`` result falls
    back to the global LLM config's default context-enrichment prompt.
    """
    configured_prompt = get_attr(
        ctx,
        context_enrichment_prompt_attr,
        None,
        auto_render=render_context_enrichment_prompt,
    )
    if configured_prompt is None:
        return default_llm_config.get_default_context_enrichment_prompt()
    return configured_prompt
|
69
|
+
|
70
|
+
|
71
|
+
def build_user_prompt(
    ctx: AnyContext,
    message_attr: StrAttr | None,
    conversation_context: dict[str, Any],
) -> str:
    """Build the final user prompt: a JSON context section plus the message.

    Layout::

        # Context
        <JSON of the default context merged with conversation_context>
        # User Message
        <user message>

    Keys in ``conversation_context`` override those produced by
    ``get_default_context``. Leading and trailing whitespace of the whole
    prompt is stripped. The user message is otherwise kept verbatim.
    """
    user_message = get_user_message(ctx, message_attr)
    # Later entries win, so the (possibly enriched/summarized) conversation
    # context overrides the defaults.
    enriched_context = {**get_default_context(user_message), **conversation_context}
    # Join the sections explicitly rather than dedenting an interpolated
    # f-string. dedent would also act on the user's text, so a multi-line
    # message could leave the headers indented or strip the message's own
    # indentation.
    sections = [
        "# Context",
        json.dumps(enriched_context),
        "# User Message",
        user_message,
    ]
    return "\n".join(sections).strip()
|
zrb/task/llm/typing.py
ADDED