amrita_core-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amrita_core/__init__.py +101 -0
- amrita_core/builtins/__init__.py +7 -0
- amrita_core/builtins/adapter.py +148 -0
- amrita_core/builtins/agent.py +415 -0
- amrita_core/builtins/tools.py +64 -0
- amrita_core/chatmanager.py +896 -0
- amrita_core/config.py +159 -0
- amrita_core/hook/event.py +90 -0
- amrita_core/hook/exception.py +14 -0
- amrita_core/hook/matcher.py +213 -0
- amrita_core/hook/on.py +14 -0
- amrita_core/libchat.py +189 -0
- amrita_core/logging.py +71 -0
- amrita_core/preset.py +166 -0
- amrita_core/protocol.py +101 -0
- amrita_core/tokenizer.py +115 -0
- amrita_core/tools/manager.py +163 -0
- amrita_core/tools/mcp.py +338 -0
- amrita_core/tools/models.py +353 -0
- amrita_core/types.py +274 -0
- amrita_core/utils.py +66 -0
- amrita_core-0.1.0.dist-info/METADATA +73 -0
- amrita_core-0.1.0.dist-info/RECORD +26 -0
- amrita_core-0.1.0.dist-info/WHEEL +5 -0
- amrita_core-0.1.0.dist-info/licenses/LICENSE +661 -0
- amrita_core-0.1.0.dist-info/top_level.txt +1 -0
amrita_core/config.py
ADDED
@@ -0,0 +1,159 @@
+# TODO: per-instance configuration
+from __future__ import annotations
+
+import random
+import string
+from typing import Literal
+
+from pydantic import Field
+from typing_extensions import LiteralString
+
+from amrita_core.types import BaseModel
+
+
+def random_alnum_string(length: int) -> str:
+    if length < 0:
+        raise ValueError("Length can't be smaller than zero!")
+
+    chars: LiteralString = string.ascii_letters + string.digits
+
+    return "".join(random.choice(chars) for _ in range(length))
+
+
+class CookieConfig(BaseModel):
+    """Amrita Core's cookie config"""
+
+    enable_cookie: bool = Field(
+        default=True, description="Whether to enable Cookie leak detection mechanism"
+    )
+    cookie: str = Field(
+        default_factory=lambda: random_alnum_string(16),
+        description="Cookie string for security detection",
+    )
+
+
+class FunctionConfig(BaseModel):
+    use_minimal_context: bool = Field(
+        default=True,
+        description="Whether to use minimal context, i.e. system prompt + user's last message (disabling this option will use all context from the message list, which may consume a large amount of Tokens during Agent workflow execution; enabling this option may effectively reduce token usage)",
+    )
+
+    tool_calling_mode: Literal["agent", "rag", "none"] = Field(
+        default="agent",
+        description="Tool calling mode, i.e. whether to use Agent or RAG to call tools",
+    )
+    agent_tool_call_limit: int = Field(
+        default=10, description="Tool call limit in agent mode"
+    )
+    agent_tool_call_notice: Literal["hide", "notify"] = Field(
+        default="hide",
+        description="Method of showing tool call status in agent mode, hide to conceal, notify to inform",
+    )
+    agent_thought_mode: Literal[
+        "reasoning", "chat", "reasoning-required", "reasoning-optional"
+    ] = Field(
+        default="chat",
+        description="Thinking mode in agent mode, reasoning mode will first perform reasoning process, then execute tasks; "
+        "reasoning-required requires task analysis for each Tool Calling; "
+        "reasoning-optional does not require reasoning but allows it; "
+        "chat mode executes tasks directly",
+    )
+    agent_reasoning_hide: bool = Field(
+        default=False, description="Whether to hide the thought process in agent mode"
+    )
+    agent_middle_message: bool = Field(
+        default=True,
+        description="Whether to allow Agent to send intermediate messages to users in agent mode",
+    )
+    agent_mcp_client_enable: bool = Field(
+        default=False, description="Whether to enable MCP client"
+    )
+    agent_mcp_server_scripts: list[str] = Field(
+        default=[], description="List of MCP server scripts"
+    )
+
+
+class LLMConfig(BaseModel):
+    require_tools: bool = Field(
+        default=False,
+        description="Whether to force at least one tool to be used per call",
+    )
+    memory_lenth_limit: int = Field(
+        default=50, description="Maximum number of messages in memory context"
+    )
+    max_tokens: int = Field(
+        default=100,
+        description="Maximum number of tokens generated in a single response",
+    )
+    tokens_count_mode: Literal["word", "bpe", "char"] = Field(
+        default="bpe",
+        description="Token counting mode: bpe(subwords)/word(words)/char(characters)",
+    )
+    enable_tokens_limit: bool = Field(
+        default=True, description="Whether to enable context length limits"
+    )
+    session_tokens_windows: int = Field(
+        default=5000, description="Session tokens window size"
+    )
+    llm_timeout: int = Field(
+        default=60, description="API request timeout duration (seconds)"
+    )
+    auto_retry: bool = Field(
+        default=True, description="Automatically retry on request failure"
+    )
+    max_retries: int = Field(default=3, description="Maximum number of retries")
+    enable_memory_abstract: bool = Field(
+        default=True,
+        description="Whether to enable context memory summarization (will delete context and insert a summary into system instruction)",
+    )
+    memory_abstract_proportion: float = Field(
+        default=15e-2, description="Context summarization proportion (0.15=15%)"
+    )
+    enable_multi_modal: bool = Field(
+        default=True,
+        description="Whether to enable multi-modal support (currently only supports image)",
+    )
+
+
+class AmritaConfig(BaseModel):
+    function_config: FunctionConfig = Field(
+        default_factory=FunctionConfig,
+        description="Function configuration",
+    )
+    llm: LLMConfig = Field(
+        default_factory=LLMConfig,
+        description="LLM configuration",
+    )
+    cookie: CookieConfig = Field(
+        default_factory=CookieConfig, description="Cookie configuration"
+    )
+
+
+__config = AmritaConfig()
+__inited: bool = False
+
+
+def get_config() -> AmritaConfig:
+    """Get amrita config
+
+    Raises:
+        RuntimeError: Raise it if amrita core is not initialized.
+
+    Returns:
+        AmritaConfig: Amrita core config
+    """
+    if not __inited:
+        raise RuntimeError("Config is not initialized. Please set config first.")
+    return __config
+
+
+def set_config(config: AmritaConfig):
+    """Override the config.
+
+    Args:
+        config (AmritaConfig): Configuration object to set
+    """
+    global __config, __inited
+    if not __inited:
+        __inited = True
+    __config = config
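Read as a whole, config.py implements an initialize-before-read contract around a module-level singleton: get_config() raises RuntimeError until set_config() has been called once. A minimal sketch of that call order (the field values are illustrative, not recommendations):

# Sketch: set_config() must run once before any get_config() call.
# Values shown are illustrative only.
from amrita_core.config import AmritaConfig, LLMConfig, get_config, set_config

set_config(AmritaConfig(llm=LLMConfig(max_tokens=512, llm_timeout=120)))

config = get_config()          # safe now; raises RuntimeError before set_config()
print(config.llm.max_tokens)   # 512
print(config.cookie.cookie)    # fresh 16-character alphanumeric string per process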
amrita_core/hook/event.py
ADDED
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import TYPE_CHECKING
+
+from amrita_core.types import USER_INPUT, SendMessageWrap
+
+if TYPE_CHECKING:
+    from amrita_core.chatmanager import ChatObject
+
+
+class EventTypeEnum(str, Enum):
+    """
+    EventTypeEnum class is used to define and manage different event types.
+    It encapsulates the string identifiers of event types, providing a structured way
+    to handle and retrieve event types.
+
+    """
+
+    COMPLETION = "COMPLETION"
+    Nil = "Nil"
+    BEFORE_COMPLETION = "BEFORE_COMPLETION"
+
+    def validate(self, name: str) -> bool:
+        return name in self
+
+
+@dataclass
+class Event:
+    user_input: USER_INPUT
+    original_context: SendMessageWrap
+    chat_object: "ChatObject"
+
+    def __post_init__(self):
+        # Initialize event type as none
+        self._event_type = EventTypeEnum.Nil
+        # Validate and store messages using SendMessageWrap
+        self._context_messages = self.original_context
+
+    @property
+    def event_type(self) -> EventTypeEnum:
+        return self._event_type
+
+    @property
+    def message(self) -> SendMessageWrap:
+        return self._context_messages
+
+    def get_context_messages(self) -> SendMessageWrap:
+        return self._context_messages
+
+    def get_event_type(self) -> str:
+        raise NotImplementedError
+
+    def get_user_input(self) -> USER_INPUT:
+        return self.user_input
+
+
+@dataclass
+class CompletionEvent(Event):
+    model_response: str
+
+    def __post_init__(self):
+        super().__post_init__()
+        # Initialize event type as completion event
+        self._event_type = EventTypeEnum.COMPLETION
+
+    @property
+    def event_type(self):
+        return EventTypeEnum.COMPLETION
+
+    def get_event_type(self) -> str:
+        return EventTypeEnum.COMPLETION
+
+    def get_model_response(self) -> str:
+        return self.model_response
+
+
+@dataclass
+class PreCompletionEvent(Event):
+    def __post_init__(self):
+        super().__post_init__()
+        self._event_type = EventTypeEnum.BEFORE_COMPLETION
+
+    @property
+    def event_type(self) -> EventTypeEnum:
+        return self._event_type
+
+    def get_event_type(self) -> str:
+        return self._event_type
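Note that the base Event leaves get_event_type() unimplemented, so consumers are expected to branch on the concrete subclasses or the enum. A small, hypothetical dispatch helper over the hierarchy above (only the imports and enum members come from the package):

# Hypothetical consumer-side dispatcher; not part of amrita_core itself.
from amrita_core.hook.event import CompletionEvent, Event, EventTypeEnum

def describe(event: Event) -> str:
    # Base Event raises NotImplementedError from get_event_type(),
    # so only concrete subclasses should reach this helper.
    if isinstance(event, CompletionEvent):
        return f"model replied: {event.get_model_response()}"
    if event.get_event_type() == EventTypeEnum.BEFORE_COMPLETION:
        return "about to request a completion"
    return "unclassified event"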
amrita_core/hook/matcher.py
ADDED
@@ -0,0 +1,213 @@
+import inspect
+from collections.abc import Awaitable, Callable
+from copy import deepcopy
+from types import FrameType
+from typing import (
+    Any,
+    ClassVar,
+    TypeAlias,
+)
+
+from pydantic import BaseModel, Field
+from typing_extensions import Self
+
+from amrita_core.logging import debug_log, logger
+
+from .event import Event
+from .exception import BlockException, CancelException, PassException
+
+ChatException: TypeAlias = BlockException | CancelException | PassException
+
+
+class FunctionData(BaseModel, arbitrary_types_allowed=True):
+    function: Callable[..., Awaitable[Any]] = Field(...)
+    signature: inspect.Signature = Field(...)
+    frame: FrameType = Field(...)
+    priority: int = Field(...)
+    block: bool = Field(...)
+    matcher: Any = Field(...)
+
+
+class EventRegistry:
+    _instance = None
+    __event_handlers: ClassVar[dict[str, list[FunctionData]]] = {}
+
+    def __new__(cls) -> Self:
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def register_handler(self, event_type: str, data: FunctionData):
+        self.__event_handlers.setdefault(event_type, []).append(data)
+
+    def get_handlers(self, event_type: str) -> list[FunctionData]:
+        self.__event_handlers.setdefault(event_type, [])
+        self.__event_handlers[event_type].sort(key=lambda x: x.priority, reverse=False)
+        return self.__event_handlers[event_type]
+
+    def _all(self) -> dict[str, list[FunctionData]]:
+        return self.__event_handlers
+
+
+class Matcher:
+    def __init__(self, event_type: str, priority: int = 10, block: bool = True):
+        """Constructor, initialize Matcher object.
+        Args:
+            event_type (str): Event type
+            priority (int, optional): Priority. Defaults to 10.
+            block (bool, optional): Whether to block subsequent events. Defaults to True.
+        """
+        if priority <= 0:
+            raise ValueError("Event priority cannot be zero or negative!")
+
+        self.event_type = event_type
+        self.priority = priority
+        self.block = block
+
+    def append_handler(self, func: Callable[..., Awaitable[Any]]):
+        frame = inspect.currentframe()
+        assert frame is not None, "Frame is None!!!"
+        func_data = FunctionData(
+            function=func,
+            signature=inspect.signature(func),
+            frame=frame,
+            priority=self.priority,
+            block=self.block,
+            matcher=self,
+        )
+        EventRegistry().register_handler(self.event_type, func_data)
+
+    def handle(self):
+        """
+        Event handler registration function
+        """
+
+        def wrapper(
+            func: Callable[..., Awaitable[Any]],
+        ):
+            self.append_handler(func)
+            return func
+
+        return wrapper
+
+    def stop_process(self):
+        """
+        Stop the event flow within the current chat plugin and immediately stop the current handler.
+        """
+        raise BlockException()
+
+    def cancel_matcher(self):
+        """
+        Stop event processing within the current chat plugin and cancel.
+        """
+        raise CancelException()
+
+    def pass_event(self):
+        """
+        Ignore the current handler and continue processing the next one.
+        """
+        raise PassException()
+
+
+class MatcherManager:
+    """
+    Event handling manager.
+    """
+
+    @staticmethod
+    async def trigger_event(*args, **kwargs) -> None:
+        """
+        Trigger a specific type of event and call all registered event handlers for that type.
+
+        Parameters:
+        - event: Event object containing event-related data.
+        - **kwargs: Keyword arguments passed to the dependency injection system.
+        - *args: Variable arguments passed to the dependency injection system.
+        """
+        event: Event | None = None
+        for i in args:
+            if isinstance(i, Event):
+                event = i
+                break
+        if not event:
+            raise RuntimeError("No event found in args")
+        event_type = event.get_event_type()  # Get event type
+        priority_tmp = 0
+        debug_log(f"Running matchers for event: {event_type}!")
+        # Check if there are handlers for this event type
+        if matcher_list := EventRegistry().get_handlers(event_type):
+            for matcher in matcher_list:
+                if matcher.priority != priority_tmp:
+                    priority_tmp = matcher.priority
+                    debug_log(f"Running matchers for priority {priority_tmp}......")
+
+                signature = matcher.signature
+                frame = matcher.frame
+                line_number = frame.f_lineno
+                file_name = frame.f_code.co_filename
+                handler = matcher.function
+                session_args = [matcher.matcher, *args]
+                session_kwargs = {**deepcopy(kwargs)}
+
+                args_types = {k: v.annotation for k, v in signature.parameters.items()}
+                filtered_args_types = {
+                    k: v for k, v in args_types.items() if v is not inspect._empty
+                }
+                if args_types != filtered_args_types:
+                    failed_args = list(args_types.keys() - filtered_args_types.keys())
+                    logger.warning(
+                        f"Matcher {matcher.function.__name__} (File: {file_name}: Line {frame.f_lineno!s}) has untyped parameters!"
+                        + f"(Args:{''.join(i + ',' for i in failed_args)}).Skipping......"
+                    )
+                    continue
+                new_args = []
+                used_indices = set()
+                for param_type in filtered_args_types.values():
+                    for i, arg in enumerate(session_args):
+                        if i in used_indices:
+                            continue
+                        if isinstance(arg, param_type):
+                            new_args.append(arg)
+                            used_indices.add(i)
+                            break
+
+                # Get keyword argument type annotations
+                kwparams = signature.parameters
+                f_kwargs = {  # TODO: kwparams dependent support
+                    param_name: session_kwargs[param.annotation]
+                    for param_name, param in kwparams.items()
+                    if param.annotation in session_kwargs
+                }
+                if len(new_args) != len(list(filtered_args_types)):
+                    continue
+
+                # Call the handler
+
+                try:
+                    logger.info(f"Starting to run Matcher: '{handler.__name__}'")
+
+                    await handler(*new_args, **f_kwargs)
+                except PassException:
+                    logger.info(
+                        f"Matcher '{handler.__name__}'(~{file_name}:{line_number}) was skipped"
+                    )
+                    continue
+                except CancelException:
+                    logger.info("Cancelled Matcher processing")
+                    return
+                except BlockException:
+                    break
+                except Exception as e:
+                    logger.opt(exception=e, colors=True).error(
+                        f"An error occurred while running '{handler.__name__}'({file_name}:{line_number}) "
+                    )
+
+                    continue
+                finally:
+                    logger.info(f"Handler {handler.__name__} finished")
+                if matcher.block:
+                    break
+        else:
+            logger.warning(
+                f"No registered Matcher for {event_type} event, skipping processing."
+            )
amrita_core/hook/on.py
ADDED
@@ -0,0 +1,14 @@
+from .event import EventTypeEnum
+from .matcher import Matcher
+
+
+def on_completion(priority: int = 10, block: bool = True):
+    return on_event(EventTypeEnum.COMPLETION, priority, block)
+
+
+def on_precompletion(priority: int = 10, block: bool = True):
+    return on_event(EventTypeEnum.BEFORE_COMPLETION, priority, block)
+
+
+def on_event(event_type: EventTypeEnum, priority: int = 10, block: bool = True):
+    return Matcher(event_type, priority, block)
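on.py is the public entry point to the matcher system: each on_* helper returns a Matcher, and Matcher.handle() yields the registering decorator. Two behaviors from matcher.py are worth keeping in mind: trigger_event skips any handler with unannotated parameters, and it injects arguments by isinstance checks against those annotations. A sketch of registering a completion handler under those rules (event construction is elided, since a CompletionEvent needs a live ChatObject):

# Sketch: every parameter must be annotated, or trigger_event skips the handler.
from amrita_core.hook.event import CompletionEvent
from amrita_core.hook.matcher import Matcher, MatcherManager
from amrita_core.hook.on import on_completion

@on_completion(priority=5, block=False).handle()
async def log_completion(matcher: Matcher, event: CompletionEvent):
    print(f"completion: {event.get_model_response()}")
    if not event.get_model_response():
        matcher.cancel_matcher()  # raises CancelException; trigger_event stops here

# Elsewhere, once a CompletionEvent `ev` has been built:
#     await MatcherManager.trigger_event(ev)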
amrita_core/libchat.py
ADDED
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+import typing
+from collections.abc import AsyncGenerator, Generator
+
+from pydantic import ValidationError
+
+from amrita_core.preset import PresetManager
+
+from .config import get_config
+from .logging import debug_log
+from .protocol import (
+    AdapterManager,
+    ModelAdapter,
+)
+from .tokenizer import hybrid_token_count
+from .tools.models import ToolChoice
+from .types import (
+    CONTENT_LIST_TYPE,
+    Message,
+    ModelPreset,
+    ToolCall,
+    ToolResult,
+    UniResponse,
+    UniResponseUsage,
+)
+
+T = typing.TypeVar("T")
+
+
+def text_generator(
+    memory: CONTENT_LIST_TYPE, split_role: bool = False
+) -> Generator[str, None, str]:
+    memory_l = [(i.model_dump() if hasattr(i, "model_dump") else i) for i in memory]
+    role_map = {
+        "assistant": "<BOT's response>",
+        "user": "<User's query>",
+        "tool": "<Tool call>",
+    }
+    for st in memory_l:
+        if st["content"] is None:
+            continue
+        if isinstance(st["content"], str):
+            yield (
+                st["content"]
+                if not split_role
+                else role_map.get(st["role"], "") + st["content"]
+            )
+        else:
+            for s in st["content"]:
+                if s["type"] == "text" and s.get("text") is not None:
+                    yield (
+                        s["text"]
+                        if not split_role
+                        else role_map.get(st["role"], "") + s["text"]
+                    )
+    return ""
+
+
+async def get_tokens(
+    memory: CONTENT_LIST_TYPE, response: UniResponse[str, None]
+) -> UniResponseUsage[int]:
+    """Calculate token counts for messages and response
+
+    Args:
+        memory: Message history list
+        response: Model response
+
+    Returns:
+        Object containing token usage information
+    """
+    if (
+        response.usage is not None
+        and response.usage.total_tokens is not None
+        and response.usage.completion_tokens is not None
+        and response.usage.prompt_tokens is not None
+    ):
+        return response.usage
+
+    it = hybrid_token_count(
+        "".join(list(text_generator(memory))),
+        get_config().llm.tokens_count_mode,
+    )
+
+    ot = hybrid_token_count(response.content)
+    return UniResponseUsage(
+        prompt_tokens=it, total_tokens=it + ot, completion_tokens=ot
+    )
+
+
+def _validate_msg_list(
+    messages: CONTENT_LIST_TYPE,
+) -> CONTENT_LIST_TYPE:
+    validated_messages = []
+    for msg in messages:
+        if isinstance(msg, dict):
+            # Ensure message has role field
+            if "role" not in msg:
+                raise ValueError("Message dictionary is missing 'role' field")
+            try:
+                validated_msg = (
+                    Message.model_validate(msg)
+                    if msg["role"] != "tool"
+                    else ToolResult.model_validate(msg)
+                )
+            except ValidationError as e:
+                raise ValueError(f"Invalid message format: {e}")
+            validated_messages.append(validated_msg)
+        else:
+            validated_messages.append(msg)
+    return validated_messages
+
+
+async def _call_with_reflection(
+    preset: ModelPreset,
+    call_func: typing.Callable[..., typing.Awaitable[T]],
+    *args,
+    **kwargs,
+) -> T:
+    """Call the specified function with the list of presets"""
+    adapter_class = AdapterManager().safe_get_adapter(preset.protocol)
+    if adapter_class:
+        debug_log(
+            f"Using adapter {adapter_class.__name__} to handle protocol {preset.protocol}"
+        )
+    else:
+        raise ValueError(f"Undefined protocol adapter: {preset.protocol}")
+
+    debug_log(f"Getting chat for {preset.model}")
+    debug_log(f"Preset: {preset.name}")
+    debug_log(f"Key: {preset.api_key[:7]}...")
+    debug_log(f"Protocol: {preset.protocol}")
+    debug_log(f"API URL: {preset.base_url}")
+    debug_log(f"Model: {preset.model}")
+    adapter = adapter_class(preset)
+    return await call_func(adapter, *args, **kwargs)
+
+
+async def tools_caller(
+    messages: CONTENT_LIST_TYPE,
+    tools: list,
+    preset: ModelPreset | None = None,
+    tool_choice: ToolChoice | None = None,
+) -> UniResponse[None, list[ToolCall] | None]:
+    async def _call_tools(
+        adapter: ModelAdapter,
+        messages: CONTENT_LIST_TYPE,
+        tools,
+        tool_choice,
+    ):
+        return await adapter.call_tools(messages, tools, tool_choice)
+
+    preset = preset or PresetManager().get_default_preset()
+    return await _call_with_reflection(
+        preset, _call_tools, messages, tools, tool_choice
+    )
+
+
+async def call_completion(
+    messages: CONTENT_LIST_TYPE,
+    preset: ModelPreset | None = None,
+) -> AsyncGenerator[str | UniResponse[str, None], None]:
+    """Get chat response"""
+    messages = _validate_msg_list(messages)
+    preset = preset or PresetManager().get_default_preset()
+
+    async def _call_api(adapter: ModelAdapter, messages: CONTENT_LIST_TYPE):
+        async def inner():
+            async for i in adapter.call_api([(i.model_dump()) for i in messages]):
+                yield i
+
+        return inner
+
+    # Call adapter to get chat response
+    response = await _call_with_reflection(preset, _call_api, messages)
+
+    async for i in response():
+        yield i
+
+
+async def get_last_response(
+    generator: AsyncGenerator[str | UniResponse[str, None], None],
+) -> UniResponse[str, None]:
+    ls: list[UniResponse[str, None]] = [
+        i async for i in generator if isinstance(i, UniResponse)
+    ]
+    if len(ls) == 0:
+        raise RuntimeError("No response found in generator.")
+    return ls[-1]
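call_completion is an async generator that yields streaming str chunks and finishes with a structured UniResponse; get_last_response exists to drain that stream and keep only the final response. A sketch of the pairing (it assumes set_config() has run and a default preset is registered, as the module itself requires; the dict message shape follows what _validate_msg_list accepts):

# Sketch: drain the completion stream and keep the final UniResponse.
import asyncio

from amrita_core.libchat import call_completion, get_last_response, get_tokens

async def main() -> None:
    messages = [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "ping"},
    ]
    # get_last_response discards the str chunks and returns the last UniResponse.
    response = await get_last_response(call_completion(messages))
    print(response.content)
    # get_tokens falls back to local counting when the provider omits usage data.
    usage = await get_tokens(messages, response)
    print(usage.total_tokens)

asyncio.run(main())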