zrb 1.15.3__py3-none-any.whl → 2.0.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of zrb might be problematic.
- zrb/__init__.py +118 -133
- zrb/attr/type.py +10 -7
- zrb/builtin/__init__.py +55 -1
- zrb/builtin/git.py +12 -1
- zrb/builtin/group.py +31 -15
- zrb/builtin/llm/chat.py +147 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +7 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +5 -5
- zrb/builtin/project/add/fastapp/fastapp_util.py +1 -1
- zrb/builtin/searxng/config/settings.yml +5671 -0
- zrb/builtin/searxng/start.py +21 -0
- zrb/builtin/shell/autocomplete/bash.py +4 -3
- zrb/builtin/shell/autocomplete/zsh.py +4 -3
- zrb/callback/callback.py +8 -1
- zrb/cmd/cmd_result.py +2 -1
- zrb/config/config.py +555 -169
- zrb/config/helper.py +84 -0
- zrb/config/web_auth_config.py +50 -35
- zrb/context/any_shared_context.py +20 -3
- zrb/context/context.py +39 -5
- zrb/context/print_fn.py +13 -0
- zrb/context/shared_context.py +17 -8
- zrb/group/any_group.py +3 -3
- zrb/group/group.py +3 -3
- zrb/input/any_input.py +5 -1
- zrb/input/base_input.py +18 -6
- zrb/input/option_input.py +41 -1
- zrb/input/text_input.py +7 -24
- zrb/llm/agent/__init__.py +9 -0
- zrb/llm/agent/agent.py +215 -0
- zrb/llm/agent/summarizer.py +20 -0
- zrb/llm/app/__init__.py +10 -0
- zrb/llm/app/completion.py +281 -0
- zrb/llm/app/confirmation/allow_tool.py +66 -0
- zrb/llm/app/confirmation/handler.py +178 -0
- zrb/llm/app/confirmation/replace_confirmation.py +77 -0
- zrb/llm/app/keybinding.py +34 -0
- zrb/llm/app/layout.py +117 -0
- zrb/llm/app/lexer.py +155 -0
- zrb/llm/app/redirection.py +28 -0
- zrb/llm/app/style.py +16 -0
- zrb/llm/app/ui.py +733 -0
- zrb/llm/config/__init__.py +4 -0
- zrb/llm/config/config.py +122 -0
- zrb/llm/config/limiter.py +247 -0
- zrb/llm/history_manager/__init__.py +4 -0
- zrb/llm/history_manager/any_history_manager.py +23 -0
- zrb/llm/history_manager/file_history_manager.py +91 -0
- zrb/llm/history_processor/summarizer.py +108 -0
- zrb/llm/note/__init__.py +3 -0
- zrb/llm/note/manager.py +122 -0
- zrb/llm/prompt/__init__.py +29 -0
- zrb/llm/prompt/claude_compatibility.py +92 -0
- zrb/llm/prompt/compose.py +55 -0
- zrb/llm/prompt/default.py +51 -0
- zrb/llm/prompt/markdown/file_extractor.md +112 -0
- zrb/llm/prompt/markdown/mandate.md +23 -0
- zrb/llm/prompt/markdown/persona.md +3 -0
- zrb/llm/prompt/markdown/repo_extractor.md +112 -0
- zrb/llm/prompt/markdown/repo_summarizer.md +29 -0
- zrb/llm/prompt/markdown/summarizer.md +21 -0
- zrb/llm/prompt/note.py +41 -0
- zrb/llm/prompt/system_context.py +46 -0
- zrb/llm/prompt/zrb.py +41 -0
- zrb/llm/skill/__init__.py +3 -0
- zrb/llm/skill/manager.py +86 -0
- zrb/llm/task/__init__.py +4 -0
- zrb/llm/task/llm_chat_task.py +316 -0
- zrb/llm/task/llm_task.py +245 -0
- zrb/llm/tool/__init__.py +39 -0
- zrb/llm/tool/bash.py +75 -0
- zrb/llm/tool/code.py +266 -0
- zrb/llm/tool/file.py +419 -0
- zrb/llm/tool/note.py +70 -0
- zrb/{builtin/llm → llm}/tool/rag.py +33 -37
- zrb/llm/tool/search/brave.py +53 -0
- zrb/llm/tool/search/searxng.py +47 -0
- zrb/llm/tool/search/serpapi.py +47 -0
- zrb/llm/tool/skill.py +19 -0
- zrb/llm/tool/sub_agent.py +70 -0
- zrb/llm/tool/web.py +97 -0
- zrb/llm/tool/zrb_task.py +66 -0
- zrb/llm/util/attachment.py +101 -0
- zrb/llm/util/prompt.py +104 -0
- zrb/llm/util/stream_response.py +178 -0
- zrb/runner/cli.py +21 -20
- zrb/runner/common_util.py +24 -19
- zrb/runner/web_route/task_input_api_route.py +5 -5
- zrb/runner/web_util/user.py +7 -3
- zrb/session/any_session.py +12 -9
- zrb/session/session.py +38 -17
- zrb/task/any_task.py +24 -3
- zrb/task/base/context.py +42 -22
- zrb/task/base/execution.py +67 -55
- zrb/task/base/lifecycle.py +14 -7
- zrb/task/base/monitoring.py +12 -7
- zrb/task/base_task.py +113 -50
- zrb/task/base_trigger.py +16 -6
- zrb/task/cmd_task.py +6 -0
- zrb/task/http_check.py +11 -5
- zrb/task/make_task.py +5 -3
- zrb/task/rsync_task.py +30 -10
- zrb/task/scaffolder.py +7 -4
- zrb/task/scheduler.py +7 -4
- zrb/task/tcp_check.py +6 -4
- zrb/util/ascii_art/art/bee.txt +17 -0
- zrb/util/ascii_art/art/cat.txt +9 -0
- zrb/util/ascii_art/art/ghost.txt +16 -0
- zrb/util/ascii_art/art/panda.txt +17 -0
- zrb/util/ascii_art/art/rose.txt +14 -0
- zrb/util/ascii_art/art/unicorn.txt +15 -0
- zrb/util/ascii_art/banner.py +92 -0
- zrb/util/attr.py +54 -39
- zrb/util/cli/markdown.py +32 -0
- zrb/util/cli/text.py +30 -0
- zrb/util/cmd/command.py +33 -10
- zrb/util/file.py +61 -33
- zrb/util/git.py +2 -2
- zrb/util/{llm/prompt.py → markdown.py} +2 -3
- zrb/util/match.py +78 -0
- zrb/util/run.py +3 -3
- zrb/util/string/conversion.py +1 -1
- zrb/util/truncate.py +23 -0
- zrb/util/yaml.py +204 -0
- zrb/xcom/xcom.py +10 -0
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/METADATA +41 -27
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/RECORD +129 -131
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/WHEEL +1 -1
- zrb/attr/__init__.py +0 -0
- zrb/builtin/llm/chat_session.py +0 -311
- zrb/builtin/llm/history.py +0 -71
- zrb/builtin/llm/input.py +0 -27
- zrb/builtin/llm/llm_ask.py +0 -187
- zrb/builtin/llm/previous-session.js +0 -21
- zrb/builtin/llm/tool/__init__.py +0 -0
- zrb/builtin/llm/tool/api.py +0 -71
- zrb/builtin/llm/tool/cli.py +0 -38
- zrb/builtin/llm/tool/code.py +0 -254
- zrb/builtin/llm/tool/file.py +0 -626
- zrb/builtin/llm/tool/sub_agent.py +0 -137
- zrb/builtin/llm/tool/web.py +0 -195
- zrb/builtin/project/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/service/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/__init__.py +0 -0
- zrb/builtin/project/create/__init__.py +0 -0
- zrb/builtin/shell/__init__.py +0 -0
- zrb/builtin/shell/autocomplete/__init__.py +0 -0
- zrb/callback/__init__.py +0 -0
- zrb/cmd/__init__.py +0 -0
- zrb/config/default_prompt/file_extractor_system_prompt.md +0 -12
- zrb/config/default_prompt/interactive_system_prompt.md +0 -35
- zrb/config/default_prompt/persona.md +0 -1
- zrb/config/default_prompt/repo_extractor_system_prompt.md +0 -112
- zrb/config/default_prompt/repo_summarizer_system_prompt.md +0 -10
- zrb/config/default_prompt/summarization_prompt.md +0 -16
- zrb/config/default_prompt/system_prompt.md +0 -32
- zrb/config/llm_config.py +0 -243
- zrb/config/llm_context/config.py +0 -129
- zrb/config/llm_context/config_parser.py +0 -46
- zrb/config/llm_rate_limitter.py +0 -137
- zrb/content_transformer/__init__.py +0 -0
- zrb/context/__init__.py +0 -0
- zrb/dot_dict/__init__.py +0 -0
- zrb/env/__init__.py +0 -0
- zrb/group/__init__.py +0 -0
- zrb/input/__init__.py +0 -0
- zrb/runner/__init__.py +0 -0
- zrb/runner/web_route/__init__.py +0 -0
- zrb/runner/web_route/home_page/__init__.py +0 -0
- zrb/session/__init__.py +0 -0
- zrb/session_state_log/__init__.py +0 -0
- zrb/session_state_logger/__init__.py +0 -0
- zrb/task/__init__.py +0 -0
- zrb/task/base/__init__.py +0 -0
- zrb/task/llm/__init__.py +0 -0
- zrb/task/llm/agent.py +0 -243
- zrb/task/llm/config.py +0 -103
- zrb/task/llm/conversation_history.py +0 -128
- zrb/task/llm/conversation_history_model.py +0 -242
- zrb/task/llm/default_workflow/coding.md +0 -24
- zrb/task/llm/default_workflow/copywriting.md +0 -17
- zrb/task/llm/default_workflow/researching.md +0 -18
- zrb/task/llm/error.py +0 -95
- zrb/task/llm/history_summarization.py +0 -216
- zrb/task/llm/print_node.py +0 -101
- zrb/task/llm/prompt.py +0 -325
- zrb/task/llm/tool_wrapper.py +0 -220
- zrb/task/llm/typing.py +0 -3
- zrb/task/llm_task.py +0 -341
- zrb/task_status/__init__.py +0 -0
- zrb/util/__init__.py +0 -0
- zrb/util/cli/__init__.py +0 -0
- zrb/util/cmd/__init__.py +0 -0
- zrb/util/codemod/__init__.py +0 -0
- zrb/util/string/__init__.py +0 -0
- zrb/xcom/__init__.py +0 -0
- {zrb-1.15.3.dist-info → zrb-2.0.0a4.dist-info}/entry_points.txt +0 -0
zrb/llm/config/config.py
ADDED
@@ -0,0 +1,122 @@
+from typing import TYPE_CHECKING
+
+from zrb.config.config import CFG
+
+if TYPE_CHECKING:
+    from pydantic_ai.models import Model
+    from pydantic_ai.providers import Provider
+    from pydantic_ai.settings import ModelSettings
+
+
+class LLMConfig:
+    """
+    Configuration provider for Pollux.
+    Allows runtime configuration while falling back to ZRB global settings.
+    """
+
+    def __init__(self):
+        self._model: "str | Model | None" = None
+        self._model_settings: "ModelSettings | None" = None
+        self._system_prompt: str | None = None
+        self._summarization_prompt: str | None = None
+
+        # Optional overrides for provider resolution
+        self._api_key: str | None = None
+        self._base_url: str | None = None
+        self._provider: "str | Provider | None" = None
+
+    # --- Model ---
+
+    @property
+    def model(self) -> "str | Model":
+        """
+        The LLM model to use. Returns a model string (e.g. 'openai:gpt-4o')
+        or a pydantic_ai Model object.
+        """
+        if self._model is not None:
+            return self._model
+
+        model_name = CFG.LLM_MODEL or "openai:gpt-4o"
+        provider = self.provider
+
+        return self._resolve_model(model_name, provider)
+
+    @model.setter
+    def model(self, value: "str | Model"):
+        self._model = value
+
+    # --- Model Settings ---
+
+    @property
+    def model_settings(self) -> "ModelSettings | None":
+        """Runtime settings for the model (temperature, etc.)."""
+        return self._model_settings
+
+    @model_settings.setter
+    def model_settings(self, value: "ModelSettings"):
+        self._model_settings = value
+
+    # --- Provider Helpers (Advanced) ---
+
+    @property
+    def api_key(self) -> str | None:
+        return self._api_key or getattr(CFG, "LLM_API_KEY", None)
+
+    @api_key.setter
+    def api_key(self, value: str):
+        self._api_key = value
+
+    @property
+    def base_url(self) -> str | None:
+        return self._base_url or getattr(CFG, "LLM_BASE_URL", None)
+
+    @base_url.setter
+    def base_url(self, value: str):
+        self._base_url = value
+
+    @property
+    def provider(self) -> "str | Provider":
+        """Resolves the model provider based on config."""
+        if self._provider is not None:
+            return self._provider
+
+        # If API Key or Base URL is set, we assume OpenAI-compatible provider
+        if self.api_key or self.base_url:
+            from pydantic_ai.providers.openai import OpenAIProvider
+
+            return OpenAIProvider(api_key=self.api_key, base_url=self.base_url)
+
+        return "openai"
+
+    @provider.setter
+    def provider(self, value: "str | Provider"):
+        self._provider = value
+
+    # --- Internal Logic ---
+
+    def _resolve_model(
+        self, model_name: str, provider: "str | Provider"
+    ) -> "str | Model":
+        # Strip existing provider prefix if present
+        clean_model_name = model_name.split(":", 1)[-1]
+
+        # 1. Provider is an Object (e.g. OpenAIProvider created from custom config)
+        # We check specific types we know how to wrap
+        try:
+            from pydantic_ai.models.openai import OpenAIChatModel
+            from pydantic_ai.providers.openai import OpenAIProvider
+
+            if isinstance(provider, OpenAIProvider):
+                return OpenAIChatModel(model_name=clean_model_name, provider=provider)
+        except ImportError:
+            pass
+
+        # 2. Provider is a String
+        if isinstance(provider, str):
+            return f"{provider}:{clean_model_name}"
+
+        # 3. Fallback (Provider is None or unknown object)
+        return model_name
+
+
+llm_config = LLMConfig()
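The module exposes a shared llm_config singleton whose properties fall back to the global ZRB CFG settings. A minimal usage sketch (the endpoint, key, and model values below are placeholders, not package defaults):

from zrb.llm.config.config import llm_config

# Point the shared config at an OpenAI-compatible endpoint (placeholder values).
llm_config.api_key = "sk-local-dummy"
llm_config.base_url = "http://localhost:11434/v1"

# With api_key/base_url set, `provider` returns an OpenAIProvider, and `model`
# resolves CFG.LLM_MODEL (or "openai:gpt-4o") into an OpenAIChatModel bound to it.
resolved = llm_config.model

# An explicit override short-circuits resolution and is returned as-is.
llm_config.model = "openai:gpt-4o-mini"
assert llm_config.model == "openai:gpt-4o-mini"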
zrb/llm/config/limiter.py
ADDED
@@ -0,0 +1,247 @@
+import asyncio
+import json
+import time
+from collections import deque
+from typing import Any, Callable
+
+from zrb.config.config import CFG
+
+
+class LLMLimiter:
+    """
+    Manages LLM constraints: Context Window (Pruning) and Rate Limits (Throttling).
+    Designed as a singleton to share limits across tasks.
+    """
+
+    def __init__(self):
+        # Sliding window logs
+        self._request_log: deque[float] = deque()
+        self._token_log: deque[tuple[float, int]] = deque()
+
+        # Internal overrides
+        self._max_requests_per_minute: int | None = None
+        self._max_tokens_per_minute: int | None = None
+        self._max_tokens_per_request: int | None = None
+        self._throttle_check_interval: float | None = None
+
+    # --- Configuration Properties ---
+
+    @property
+    def max_requests_per_minute(self) -> int:
+        if self._max_requests_per_minute is not None:
+            return self._max_requests_per_minute
+        return getattr(CFG, "LLM_MAX_REQUESTS_PER_MINUTE", None) or 60
+
+    @max_requests_per_minute.setter
+    def max_requests_per_minute(self, value: int):
+        self._max_requests_per_minute = value
+
+    @property
+    def max_tokens_per_minute(self) -> int:
+        if self._max_tokens_per_minute is not None:
+            return self._max_tokens_per_minute
+        return getattr(CFG, "LLM_MAX_TOKENS_PER_MINUTE", None) or 100_000
+
+    @max_tokens_per_minute.setter
+    def max_tokens_per_minute(self, value: int):
+        self._max_tokens_per_minute = value
+
+    @property
+    def max_tokens_per_request(self) -> int:
+        if self._max_tokens_per_request is not None:
+            return self._max_tokens_per_request
+        return getattr(CFG, "LLM_MAX_TOKENS_PER_REQUEST", None) or 16_000
+
+    @max_tokens_per_request.setter
+    def max_tokens_per_request(self, value: int):
+        self._max_tokens_per_request = value
+
+    @property
+    def throttle_check_interval(self) -> float:
+        if self._throttle_check_interval is not None:
+            return self._throttle_check_interval
+        return getattr(CFG, "LLM_THROTTLE_SLEEP", None) or 0.1
+
+    @throttle_check_interval.setter
+    def throttle_check_interval(self, value: float):
+        self._throttle_check_interval = value
+
+    @property
+    def use_tiktoken(self) -> bool:
+        return getattr(CFG, "USE_TIKTOKEN", False)
+
+    @property
+    def tiktoken_encoding(self) -> str:
+        return getattr(CFG, "TIKTOKEN_ENCODING_NAME", "cl100k_base")
+
+    # --- Public API ---
+
+    def fit_context_window(self, history: list[Any], new_message: Any) -> list[Any]:
+        """
+        Prunes the history (removing oldest turns) so that 'history + new_message'
+        fits within 'max_tokens_per_request'.
+        Ensures strict tool call pairing by removing full conversation turns.
+        """
+        if not history:
+            return history
+
+        # Import message types locally to avoid circular deps or startup cost
+        try:
+            from pydantic_ai.messages import (
+                ModelRequest,
+                ToolReturnPart,
+                UserPromptPart,
+            )
+        except ImportError:
+            # Fallback if pydantic_ai is not installed (unlikely in context)
+            return []
+
+        def is_turn_start(msg: Any) -> bool:
+            """Identify start of a new user interaction (User Prompt without Tool Return)."""
+            if not isinstance(msg, ModelRequest):
+                return False
+            has_user = any(isinstance(p, UserPromptPart) for p in msg.parts)
+            has_return = any(isinstance(p, ToolReturnPart) for p in msg.parts)
+            return has_user and not has_return
+
+        new_msg_tokens = self._count_tokens(new_message)
+        if new_msg_tokens > self.max_tokens_per_request:
+            return []
+
+        pruned_history = list(history)
+
+        while pruned_history:
+            history_tokens = self._count_tokens(pruned_history)
+            total_tokens = history_tokens + new_msg_tokens
+
+            if total_tokens <= self.max_tokens_per_request:
+                break
+
+            # Pruning Strategy: Find the start of the *next* turn and cut everything before it.
+            # We start searching from index 1 because removing index 0 (current start) is the goal.
+            next_turn_index = -1
+            for i in range(1, len(pruned_history)):
+                if is_turn_start(pruned_history[i]):
+                    next_turn_index = i
+                    break
+
+            if next_turn_index != -1:
+                # Remove everything up to the next turn
+                pruned_history = pruned_history[next_turn_index:]
+            else:
+                # No subsequent turns found.
+                # This implies the history contains only one (potentially long) turn or partial fragments.
+                # To satisfy the limit, we must clear the history entirely.
+                pruned_history = []
+
+        return pruned_history
+
+    async def acquire(self, content: Any, notifier: Callable[[str], Any] | None = None):
+        """
+        Acquires permission to proceed with the given content.
+        Calculates token count internally and waits if rate limits are exceeded.
+        """
+        # Calculate tokens once
+        estimated_tokens = self._count_tokens(content)
+
+        # 1. Prune logs older than 60 seconds
+        self._prune_logs()
+
+        # 2. Check limits loop
+        notified = False
+        while not self._can_proceed(estimated_tokens):
+            wait_time = self._calculate_wait_time(estimated_tokens)
+            reason = self._get_limit_reason(estimated_tokens)
+
+            if notifier:
+                msg = f"Rate Limit Reached: {reason}. Waiting {wait_time:.1f}s..."
+                # Only notify once or if status changes? Simple is better.
+                notifier(msg)
+                notified = True
+
+            await asyncio.sleep(self.throttle_check_interval)
+            self._prune_logs()
+
+        if notified and notifier:
+            notifier("") # Clear status
+
+        # 3. Record usage
+        now = time.time()
+        self._request_log.append(now)
+        self._token_log.append((now, estimated_tokens))
+
+    def count_tokens(self, content: Any) -> int:
+        """Public alias for internal counter."""
+        return self._count_tokens(content)
+
+    # --- Internal Helpers ---
+
+    def _count_tokens(self, content: Any) -> int:
+        text = self._to_str(content)
+        if self.use_tiktoken:
+            try:
+                import tiktoken
+
+                enc = tiktoken.get_encoding(self.tiktoken_encoding)
+                return len(enc.encode(text))
+            except ImportError:
+                pass
+        # Fallback approximation (char/4)
+        return len(text) // 4
+
+    def _to_str(self, content: Any) -> str:
+        if isinstance(content, str):
+            return content
+        try:
+            return json.dumps(content, default=str)
+        except Exception:
+            return str(content)
+
+    def _prune_logs(self):
+        now = time.time()
+        window_start = now - 60
+
+        while self._request_log and self._request_log[0] < window_start:
+            self._request_log.popleft()
+
+        while self._token_log and self._token_log[0][0] < window_start:
+            self._token_log.popleft()
+
+    def _can_proceed(self, tokens: int) -> bool:
+        requests_ok = len(self._request_log) < self.max_requests_per_minute
+
+        current_tokens = sum(t for _, t in self._token_log)
+        tokens_ok = (current_tokens + tokens) <= self.max_tokens_per_minute
+
+        return requests_ok and tokens_ok
+
+    def _get_limit_reason(self, tokens: int) -> str:
+        if len(self._request_log) >= self.max_requests_per_minute:
+            return f"Max Requests ({self.max_requests_per_minute}/min)"
+        return f"Max Tokens ({self.max_tokens_per_minute}/min)"
+
+    def _calculate_wait_time(self, tokens: int) -> float:
+        now = time.time()
+        # Default wait
+        wait = 1.0
+
+        # If request limit hit, wait until oldest request expires
+        if len(self._request_log) >= self.max_requests_per_minute:
+            oldest = self._request_log[0]
+            wait = max(0.1, 60 - (now - oldest))
+
+        # If token limit hit, wait until enough tokens expire
+        current_tokens = sum(t for _, t in self._token_log)
+        if current_tokens + tokens > self.max_tokens_per_minute:
+            needed = (current_tokens + tokens) - self.max_tokens_per_minute
+            freed = 0
+            for ts, count in self._token_log:
+                freed += count
+                if freed >= needed:
+                    wait = max(wait, 60 - (now - ts))
+                    break
+
+        return wait
+
+
+llm_limiter = LLMLimiter()
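llm_limiter is the shared singleton the LLM tasks go through. A hedged usage sketch (the limits, prompt, and notifier below are illustrative, not package defaults):

import asyncio

from zrb.llm.config.limiter import llm_limiter

async def main():
    # Override the CFG-backed defaults for this process only.
    llm_limiter.max_requests_per_minute = 30
    llm_limiter.max_tokens_per_minute = 50_000

    prompt = "Summarize the latest deployment logs."

    # Sleeps in throttle_check_interval steps until the sliding 60s window has
    # room, then records the request and its estimated token count.
    await llm_limiter.acquire(prompt, notifier=print)

    # Prunes whole conversation turns so that history + prompt stays within
    # max_tokens_per_request (empty here, so it is returned unchanged).
    history = llm_limiter.fit_context_window([], prompt)

asyncio.run(main())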
zrb/llm/history_manager/any_history_manager.py
ADDED
@@ -0,0 +1,23 @@
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from pydantic_ai import ModelMessage
+
+
+class AnyHistoryManager(ABC):
+    @abstractmethod
+    def load(self, conversation_name: str) -> "list[ModelMessage]":
+        pass
+
+    @abstractmethod
+    def save(self, conversation_name: str):
+        pass
+
+    @abstractmethod
+    def update(self, conversation_name: str, messages: "list[ModelMessage]"):
+        pass
+
+    @abstractmethod
+    def search(self, keyword: str) -> list[str]:
+        pass
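The interface separates update (mutate the in-memory view) from save (persist it). A minimal in-memory implementation, shown here only to illustrate the contract; it is not part of the package (the shipped implementation is FileHistoryManager below):

from zrb.llm.history_manager.any_history_manager import AnyHistoryManager

class InMemoryHistoryManager(AnyHistoryManager):
    def __init__(self):
        self._store: dict[str, list] = {}

    def load(self, conversation_name: str) -> list:
        return self._store.get(conversation_name, [])

    def update(self, conversation_name: str, messages: list):
        self._store[conversation_name] = messages

    def save(self, conversation_name: str):
        pass  # nothing extra to persist for an in-memory store

    def search(self, keyword: str) -> list[str]:
        return [name for name in self._store if keyword.lower() in name.lower()]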
zrb/llm/history_manager/file_history_manager.py
ADDED
@@ -0,0 +1,91 @@
+import json
+import os
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from pydantic_ai import ModelMessage
+
+from zrb.llm.history_manager.any_history_manager import AnyHistoryManager
+from zrb.util.match import fuzzy_match
+
+
+class FileHistoryManager(AnyHistoryManager):
+    def __init__(self, history_dir: str):
+        self._history_dir = os.path.expanduser(history_dir)
+        self._cache: "dict[str, list[ModelMessage]]" = {}
+        if not os.path.exists(self._history_dir):
+            os.makedirs(self._history_dir, exist_ok=True)
+
+    def _get_file_path(self, conversation_name: str) -> str:
+        # Sanitize conversation name to be safe for filename
+        safe_name = "".join(
+            c for c in conversation_name if c.isalnum() or c in (" ", ".", "_", "-")
+        ).strip()
+        if not safe_name:
+            safe_name = "default"
+        return os.path.join(self._history_dir, f"{safe_name}.json")
+
+    def load(self, conversation_name: str) -> "list[ModelMessage]":
+        from pydantic_ai.messages import ModelMessagesTypeAdapter
+
+        if conversation_name in self._cache:
+            return self._cache[conversation_name]
+
+        file_path = self._get_file_path(conversation_name)
+        if not os.path.exists(file_path):
+            return []
+
+        try:
+            with open(file_path, "r", encoding="utf-8") as f:
+                content = f.read()
+            if not content.strip():
+                return []
+            data = json.loads(content)
+            messages = ModelMessagesTypeAdapter.validate_python(data)
+            self._cache[conversation_name] = messages
+            return messages
+        except (json.JSONDecodeError, OSError) as e:
+            # Log error or warn? For now, return empty list or re-raise.
+            # Returning empty list is safer for UI not to crash.
+            print(f"Warning: Failed to load history for {conversation_name}: {e}")
+            return []
+
+    def update(self, conversation_name: str, messages: "list[ModelMessage]"):
+        self._cache[conversation_name] = messages
+
+    def save(self, conversation_name: str):
+        from pydantic_ai.messages import ModelMessagesTypeAdapter
+
+        if conversation_name not in self._cache:
+            return
+
+        messages = self._cache[conversation_name]
+        file_path = self._get_file_path(conversation_name)
+
+        try:
+            data = ModelMessagesTypeAdapter.dump_python(messages, mode="json")
+            with open(file_path, "w", encoding="utf-8") as f:
+                json.dump(data, f, indent=2)
+        except OSError as e:
+            print(f"Error: Failed to save history for {conversation_name}: {e}")
+
+    def search(self, keyword: str) -> list[str]:
+        if not os.path.exists(self._history_dir):
+            return []
+
+        matches = []
+        for filename in os.listdir(self._history_dir):
+            if not filename.endswith(".json"):
+                continue
+
+            # Remove extension to get session name
+            session_name = filename[:-5]
+
+            is_match, score = fuzzy_match(session_name, keyword)
+            if is_match:
+                matches.append((session_name, score))
+
+        # Sort by score (lower is better)
+        matches.sort(key=lambda x: x[1])
+
+        return [m[0] for m in matches]
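A usage sketch (the directory and conversation names are examples, not zrb defaults):

from zrb.llm.history_manager.file_history_manager import FileHistoryManager

history_manager = FileHistoryManager("~/.zrb/llm-history")

messages = history_manager.load("deploy-review")   # [] when the conversation is new
# ... append ModelMessage objects produced by an agent run ...
history_manager.update("deploy-review", messages)  # refresh the in-memory cache
history_manager.save("deploy-review")              # write deploy-review.json to disk

print(history_manager.search("deploy"))            # fuzzy-matched conversation names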
zrb/llm/history_processor/summarizer.py
ADDED
@@ -0,0 +1,108 @@
+from typing import TYPE_CHECKING, Any, Awaitable, Callable
+
+from zrb.config.config import CFG
+from zrb.llm.agent.summarizer import create_summarizer_agent
+from zrb.llm.config.limiter import LLMLimiter
+from zrb.llm.config.limiter import llm_limiter as default_llm_limiter
+from zrb.util.markdown import make_markdown_section
+
+if TYPE_CHECKING:
+    from pydantic_ai.messages import ModelMessage
+else:
+    ModelMessage = Any
+
+
+def is_turn_start(msg: Any) -> bool:
+    """Identify start of a new user interaction (User Prompt without Tool Return)."""
+    from pydantic_ai.messages import ModelRequest, ToolReturnPart, UserPromptPart
+
+    if not isinstance(msg, ModelRequest):
+        return False
+    # In pydantic_ai, ModelRequest parts can be list of various parts
+    has_user = any(isinstance(p, UserPromptPart) for p in msg.parts)
+    has_return = any(isinstance(p, ToolReturnPart) for p in msg.parts)
+    return has_user and not has_return
+
+
+async def summarize_history(
+    messages: "list[ModelMessage]",
+    agent: Any = None,
+    summary_window: int | None = None,
+) -> "list[ModelMessage]":
+    """
+    Summarizes the history, keeping the last `summary_window` messages intact.
+    Returns a new list of messages where older messages are replaced by a summary.
+    """
+    from pydantic_ai.messages import ModelRequest, UserPromptPart
+
+    if summary_window is None:
+        summary_window = CFG.LLM_HISTORY_SUMMARIZATION_WINDOW
+    if len(messages) <= summary_window:
+        return messages
+
+    # Determine split index
+    # We want to keep at least summary_window messages.
+    # So split_idx <= len(messages) - summary_window.
+    # We search backwards from there for a clean turn start.
+    start_search_idx = max(0, len(messages) - summary_window)
+    split_idx = -1
+
+    # Iterate backwards from start_search_idx to 0
+    for i in range(start_search_idx, -1, -1):
+        if is_turn_start(messages[i]):
+            split_idx = i
+            break
+
+    if split_idx <= 0:
+        return messages
+
+    to_summarize = messages[:split_idx]
+    to_keep = messages[split_idx:]
+
+    # Simple text representation for now
+    history_text = "\n".join([str(m) for m in to_summarize])
+
+    summarizer_agent = agent or create_summarizer_agent()
+
+    # Run the summarizer agent
+    result = await summarizer_agent.run(
+        f"Summarize this conversation history:\n{history_text}"
+    )
+    summary_text = result.output
+
+    # Create a summary message injected as user context
+    summary_message = ModelRequest(
+        parts=[
+            UserPromptPart(
+                content=make_markdown_section(
+                    "Previous conversation summary", summary_text
+                )
+            )
+        ]
+    )
+
+    return [summary_message] + to_keep
+
+
+def create_summarizer_history_processor(
+    agent: Any = None,
+    limiter: LLMLimiter | None = None,
+    token_threshold: int | None = None,
+    summary_window: int | None = None,
+) -> "Callable[[list[ModelMessage]], Awaitable[list[ModelMessage]]]":
+    """
+    Creates a history processor that auto-summarizes history when it exceeds `token_threshold`.
+    """
+    llm_limiter = limiter or default_llm_limiter
+    if token_threshold is None:
+        token_threshold = CFG.LLM_HISTORY_SUMMARIZATION_TOKEN_THRESHOLD
+
+    async def process_history(messages: "list[ModelMessage]") -> "list[ModelMessage]":
+        current_tokens = llm_limiter.count_tokens(messages)
+
+        if current_tokens <= token_threshold:
+            return messages
+
+        return await summarize_history(messages, agent, summary_window)
+
+    return process_history
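The processor returned by create_summarizer_history_processor can be called directly on a message list, or attached to an agent as a history processor (recent pydantic_ai versions accept a history_processors list on Agent). A hedged sketch; the threshold and window values are examples, not defaults:

import asyncio

from zrb.llm.history_processor.summarizer import create_summarizer_history_processor

process_history = create_summarizer_history_processor(
    token_threshold=8_000,  # summarize once estimated history tokens exceed 8k
    summary_window=10,      # keep roughly the last 10 messages verbatim
)

async def trim(messages):
    # Under the threshold the messages pass through unchanged; above it, older
    # turns are replaced by a single "Previous conversation summary" message.
    return await process_history(messages)

print(asyncio.run(trim([])))  # [] -> unchanged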
zrb/llm/note/__init__.py
ADDED