zrb 1.9.17__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,172 +0,0 @@
1
- import json
2
- import traceback
3
- from typing import TYPE_CHECKING
4
-
5
- from zrb.attr.type import BoolAttr, IntAttr
6
- from zrb.config.llm_config import llm_config
7
- from zrb.config.llm_rate_limitter import LLMRateLimiter, llm_rate_limitter
8
- from zrb.context.any_context import AnyContext
9
- from zrb.task.llm.agent import run_agent_iteration
10
- from zrb.task.llm.history import (
11
- count_part_in_history_list,
12
- replace_system_prompt_in_history_list,
13
- )
14
- from zrb.task.llm.typing import ListOfDict
15
- from zrb.util.attr import get_bool_attr, get_int_attr
16
- from zrb.util.cli.style import stylize_faint
17
-
18
- if TYPE_CHECKING:
19
- from pydantic_ai.models import Model
20
- from pydantic_ai.settings import ModelSettings
21
-
22
-
23
def _count_token_in_history(history_list: ListOfDict) -> int:
    """Return the token count of the JSON-serialized conversation history.

    The history list is dumped to a JSON string and that string is measured
    with the module-level ``llm_rate_limitter``.
    """
    return llm_rate_limitter.count_token(json.dumps(history_list))
27
-
28
-
29
async def enrich_context(
    ctx: AnyContext,
    model: "Model | str | None",
    settings: "ModelSettings | None",
    prompt: str,
    previous_long_term_context: str,
    history_list: ListOfDict,
    rate_limitter: LLMRateLimiter | None = None,
    retries: int = 3,
) -> str:
    """Run an LLM call to update the long-term context.

    Args:
        ctx: Context used for logging and printing progress.
        model: Model handed to the enrichment ``Agent``.
        settings: Model settings handed to the enrichment ``Agent``.
        prompt: System prompt of the enrichment ``Agent``.
        previous_long_term_context: Current long-term context; embedded in
            the user prompt and returned unchanged when enrichment fails.
        history_list: Recent conversation history embedded in the user prompt.
        rate_limitter: Optional rate limiter forwarded to
            ``run_agent_iteration``.
        retries: Retry count handed to the enrichment ``Agent``.

    Returns:
        The new long-term context as a string, or
        ``previous_long_term_context`` when the run produced no output or
        raised an ordinary exception.
    """
    from pydantic_ai import Agent

    ctx.log_info("Attempting to enrich conversation context...")
    # Construct the user prompt according to the new prompt format
    user_prompt = json.dumps(
        {
            "previous_long_term_context": previous_long_term_context,
            "recent_conversation_history": history_list,
        }
    )
    enrichment_agent = Agent(
        model=model,
        system_prompt=prompt,
        model_settings=settings,
        retries=retries,
    )

    try:
        ctx.print(stylize_faint("💡 Enrich Context"), plain=True)
        enrichment_run = await run_agent_iteration(
            ctx=ctx,
            agent=enrichment_agent,
            user_prompt=user_prompt,
            history_list=[],  # Enrichment agent works off the prompt, not history
            rate_limitter=rate_limitter,
        )
        if enrichment_run and enrichment_run.result.output:
            new_long_term_context = str(enrichment_run.result.output)
            usage = enrichment_run.result.usage()
            ctx.print(
                stylize_faint(f"💡 Context Enrichment Token: {usage}"), plain=True
            )
            ctx.print(plain=True)
            ctx.log_info("Context enriched based on history.")
            ctx.log_info(f"Updated long-term context:\n{new_long_term_context}")
            return new_long_term_context
        ctx.log_warning("Context enrichment returned no data.")
    except Exception as e:
        # Catch Exception rather than BaseException so that task cancellation
        # (asyncio.CancelledError), KeyboardInterrupt and SystemExit still
        # propagate instead of being downgraded to a warning.
        ctx.log_warning(f"Error during context enrichment LLM call: {e}")
        traceback.print_exc()

    # Return the original context if enrichment fails
    return previous_long_term_context
84
-
85
-
86
def get_context_enrichment_token_threshold(
    ctx: AnyContext,
    context_enrichment_token_threshold_attr: IntAttr | None,
    render_context_enrichment_token_threshold: bool,
) -> int:
    """Resolve the token threshold that triggers context enrichment.

    The attribute is resolved through ``get_int_attr`` with the configured
    default from ``llm_config``. If that raises ``ValueError``, a warning is
    logged and ``-1`` (no threshold) is returned.
    """
    try:
        threshold = get_int_attr(
            ctx,
            context_enrichment_token_threshold_attr,
            llm_config.default_context_enrichment_token_threshold,
            auto_render=render_context_enrichment_token_threshold,
        )
    except ValueError as error:
        ctx.log_warning(
            f"Could not convert context_enrichment_token_threshold to int: {error}. "
            "Defaulting to -1 (no threshold)."
        )
        threshold = -1
    return threshold
105
-
106
-
107
def should_enrich_context(
    ctx: AnyContext,
    history_list: ListOfDict,
    should_enrich_context_attr: BoolAttr | None,
    render_enrich_context: bool,
    context_enrichment_token_threshold_attr: IntAttr | None,
    render_context_enrichment_token_threshold: bool,
) -> bool:
    """Decide whether the long-term context should be enriched.

    Returns ``False`` when the history has no parts, when the threshold is
    ``-1``, or when the history's token count is below the threshold.
    Otherwise the answer comes from the boolean attribute (falling back to
    ``llm_config.default_enrich_context``).
    """
    if count_part_in_history_list(history_list) == 0:
        return False
    token_limit = get_context_enrichment_token_threshold(
        ctx,
        context_enrichment_token_threshold_attr,
        render_context_enrichment_token_threshold,
    )
    used_tokens = _count_token_in_history(history_list)
    if token_limit == -1:
        # -1 means "no threshold": never enrich.
        return False
    if used_tokens < token_limit:
        return False
    return get_bool_attr(
        ctx,
        should_enrich_context_attr,
        llm_config.default_enrich_context,
        auto_render=render_enrich_context,
    )
138
-
139
-
140
async def maybe_enrich_context(
    ctx: AnyContext,
    history_list: ListOfDict,
    long_term_context: str,
    should_enrich_context_attr: BoolAttr | None,
    render_enrich_context: bool,
    context_enrichment_token_threshold_attr: IntAttr | None,
    render_context_enrichment_token_threshold: bool,
    model: "str | Model | None",
    model_settings: "ModelSettings | None",
    context_enrichment_prompt: str,
    rate_limitter: LLMRateLimiter | None = None,
) -> str:
    """Enrich the long-term context when enrichment is warranted.

    The history is first passed through
    ``replace_system_prompt_in_history_list``; that copy is used both for the
    decision and for the enrichment call. When ``should_enrich_context``
    says no, ``long_term_context`` is returned unchanged.
    """
    sanitized_history = replace_system_prompt_in_history_list(history_list)
    enrichment_needed = should_enrich_context(
        ctx,
        sanitized_history,
        should_enrich_context_attr,
        render_enrich_context,
        context_enrichment_token_threshold_attr,
        render_context_enrichment_token_threshold,
    )
    if not enrichment_needed:
        return long_term_context
    return await enrich_context(
        ctx=ctx,
        model=model,
        settings=model_settings,
        prompt=context_enrichment_prompt,
        previous_long_term_context=long_term_context,
        history_list=sanitized_history,
        rate_limitter=rate_limitter,
    )
zrb/task/llm/history.py DELETED
@@ -1,233 +0,0 @@
1
- import json
2
- import os
3
- from collections.abc import Callable
4
- from copy import deepcopy
5
- from typing import Any, Optional
6
-
7
- from zrb.attr.type import StrAttr
8
- from zrb.context.any_context import AnyContext
9
- from zrb.context.any_shared_context import AnySharedContext
10
- from zrb.task.llm.typing import ListOfDict
11
- from zrb.util.attr import get_str_attr
12
- from zrb.util.file import read_file, write_file
13
- from zrb.util.run import run_async
14
-
15
-
16
- # Define the new ConversationHistoryData model
17
- class ConversationHistoryData:
18
- def __init__(
19
- self,
20
- long_term_context: str = "",
21
- conversation_summary: str = "",
22
- history: Optional[ListOfDict] = None,
23
- messages: Optional[ListOfDict] = None,
24
- ):
25
- self.long_term_context = long_term_context
26
- self.conversation_summary = conversation_summary
27
- self.history = (
28
- history
29
- if history is not None
30
- else (messages if messages is not None else [])
31
- )
32
-
33
- def to_dict(self) -> dict[str, Any]:
34
- return {
35
- "long_term_context": self.long_term_context,
36
- "conversation_summary": self.conversation_summary,
37
- "history": self.history,
38
- }
39
-
40
- def model_dump_json(self, indent: int = 2) -> str:
41
- return json.dumps(self.to_dict(), indent=indent)
42
-
43
- @classmethod
44
- async def read_from_sources(
45
- cls,
46
- ctx: AnyContext,
47
- reader: Callable[[AnyContext], dict[str, Any] | list | None] | None,
48
- file_path: str | None,
49
- ) -> Optional["ConversationHistoryData"]:
50
- """Reads conversation history from various sources with priority."""
51
- # Priority 1: Reader function
52
- if reader:
53
- try:
54
- raw_data = await run_async(reader(ctx))
55
- if raw_data:
56
- instance = cls.parse_and_validate(ctx, raw_data, "reader")
57
- if instance:
58
- return instance
59
- except Exception as e:
60
- ctx.log_warning(
61
- f"Error executing conversation history reader: {e}. Ignoring."
62
- )
63
- # Priority 2: History file
64
- if file_path and os.path.isfile(file_path):
65
- try:
66
- content = read_file(file_path)
67
- raw_data = json.loads(content)
68
- instance = cls.parse_and_validate(ctx, raw_data, f"file '{file_path}'")
69
- if instance:
70
- return instance
71
- except json.JSONDecodeError:
72
- ctx.log_warning(
73
- f"Could not decode JSON from history file '{file_path}'. "
74
- "Ignoring file content."
75
- )
76
- except Exception as e:
77
- ctx.log_warning(
78
- f"Error reading history file '{file_path}': {e}. "
79
- "Ignoring file content."
80
- )
81
- # If neither reader nor file provided valid data
82
- return None
83
-
84
- @classmethod
85
- def parse_and_validate(
86
- cls, ctx: AnyContext, data: Any, source: str
87
- ) -> Optional["ConversationHistoryData"]:
88
- """Parses raw data into ConversationHistoryData, handling validation & old formats."""
89
- try:
90
- if isinstance(data, cls):
91
- return data # Already a valid instance
92
- if isinstance(data, dict):
93
- # This handles both the new format and the old {'context': ..., 'history': ...}
94
- return cls(
95
- long_term_context=data.get("long_term_context", ""),
96
- conversation_summary=data.get("conversation_summary", ""),
97
- history=data.get("history", data.get("messages")),
98
- )
99
- elif isinstance(data, list):
100
- # Handle very old format (just a list) - wrap it
101
- ctx.log_warning(
102
- f"History from {source} contains legacy list format. "
103
- "Wrapping it into the new structure. "
104
- "Consider updating the source format."
105
- )
106
- return cls(history=data)
107
- else:
108
- ctx.log_warning(
109
- f"History data from {source} has unexpected format "
110
- f"(type: {type(data)}). Ignoring."
111
- )
112
- return None
113
- except Exception as e: # Catch validation errors too
114
- ctx.log_warning(
115
- f"Error validating/parsing history data from {source}: {e}. Ignoring."
116
- )
117
- return None
118
-
119
-
120
def get_history_file(
    ctx: AnyContext,
    conversation_history_file_attr: StrAttr | None,
    render_history_file: bool,
) -> str:
    """Resolve the conversation history file path.

    The attribute is resolved through ``get_str_attr`` with an empty string
    as the default; ``render_history_file`` is passed as ``auto_render``.
    """
    history_path = get_str_attr(
        ctx,
        conversation_history_file_attr,
        "",
        auto_render=render_history_file,
    )
    return history_path
132
-
133
-
134
async def read_conversation_history(
    ctx: AnyContext,
    conversation_history_reader: (
        Callable[[AnySharedContext], ConversationHistoryData | dict | list | None]
        | None
    ),
    conversation_history_file_attr: StrAttr | None,
    render_history_file: bool,
    conversation_history_attr: (
        ConversationHistoryData
        | Callable[[AnySharedContext], ConversationHistoryData | dict | list]
        | dict
        | list
    ),
) -> ConversationHistoryData:
    """Read conversation history from reader, file, or attribute.

    Sources are tried in priority order: the reader callable, the history
    file, then ``conversation_history_attr`` (called first when it is
    callable). The first source that parses into a
    ``ConversationHistoryData`` wins.

    Returns:
        The parsed history, or an empty ``ConversationHistoryData`` when no
        source yields usable data.
    """
    history_file = get_history_file(
        ctx, conversation_history_file_attr, render_history_file
    )
    # Priorities 1 and 2 (reader, then file) are handled by the class method
    history_data = await ConversationHistoryData.read_from_sources(
        ctx=ctx,
        reader=conversation_history_reader,
        file_path=history_file,
    )
    if history_data:
        return history_data
    # Priority 3: Callable or direct conversation_history attribute
    raw_data_attr: Any = None
    if callable(conversation_history_attr):
        try:
            raw_data_attr = await run_async(conversation_history_attr(ctx))
        except Exception as e:
            ctx.log_warning(
                f"Error executing callable conversation_history attribute: {e}. "
                "Ignoring."
            )
    else:
        # Only a non-callable attribute is data in its own right. A callable
        # that failed or returned None must not be parsed as history itself,
        # which would only emit a misleading "unexpected format" warning.
        raw_data_attr = conversation_history_attr
    if raw_data_attr:
        history_data = ConversationHistoryData.parse_and_validate(
            ctx, raw_data_attr, "attribute"
        )
        if history_data:
            return history_data
    # Fallback: Return default value
    return ConversationHistoryData()
182
-
183
-
184
async def write_conversation_history(
    ctx: AnyContext,
    history_data: ConversationHistoryData,
    conversation_history_writer: (
        Callable[[AnySharedContext, ConversationHistoryData], None] | None
    ),
    conversation_history_file_attr: StrAttr | None,
    render_history_file: bool,
):
    """Persist conversation history through the writer and/or a file.

    The writer, when given, is invoked first. Afterwards the history is also
    dumped as indented JSON to the resolved history file, unless that path
    is the empty string.
    """
    if conversation_history_writer is not None:
        await run_async(conversation_history_writer(ctx, history_data))
    target_path = get_history_file(
        ctx, conversation_history_file_attr, render_history_file
    )
    if target_path == "":
        return
    serialized = json.dumps(history_data.to_dict(), indent=2)
    write_file(target_path, serialized)
201
-
202
-
203
def replace_system_prompt_in_history_list(
    history_list: ListOfDict, replacement: str = "<main LLM system prompt>"
) -> ListOfDict:
    """
    Build a copy of the history with system-prompt content swapped out.

    Args:
        history_list: List of history items (each item is a dict that may
            carry a 'parts' list).
        replacement: The string written into the 'content' of every part
            whose 'part_kind' is 'system-prompt'.

    Returns:
        A deep copy of ``history_list`` with those parts rewritten; the
        input list is left untouched.
    """
    redacted = deepcopy(history_list)
    for entry in redacted:
        for part in entry.get("parts", []):
            if part.get("part_kind") != "system-prompt":
                continue
            part["content"] = replacement
    return redacted
223
-
224
-
225
def count_part_in_history_list(history_list: ListOfDict) -> int:
    """Return the total number of parts across all history entries.

    An entry with a 'parts' key contributes the length of that list; an
    entry without one counts as a single part.
    """
    return sum(
        len(entry["parts"]) if "parts" in entry else 1 for entry in history_list
    )
File without changes