amrita_core 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
amrita_core/config.py ADDED
@@ -0,0 +1,159 @@
1
+ # TODO:分实例的配置
2
+ from __future__ import annotations
3
+
4
+ import random
5
+ import string
6
+ from typing import Literal
7
+
8
+ from pydantic import Field
9
+ from typing_extensions import LiteralString
10
+
11
+ from amrita_core.types import BaseModel
12
+
13
+
14
def random_alnum_string(length: int) -> str:
    """Return a random alphanumeric string of the given length.

    Args:
        length: Number of characters to generate; must be >= 0.

    Returns:
        A string of ``length`` ASCII letters and digits.

    Raises:
        ValueError: If ``length`` is negative.
    """
    # Use the secrets module (CSPRNG) rather than random: CookieConfig uses
    # this to generate a token "for security detection", where a predictable
    # Mersenne-Twister stream would be a weakness.
    import secrets

    if length < 0:
        raise ValueError("Length can't be smaller than zero!")

    chars: LiteralString = string.ascii_letters + string.digits

    return "".join(secrets.choice(chars) for _ in range(length))
21
+
22
+
23
class CookieConfig(BaseModel):
    """Amrita Core's cookie config."""

    # Master switch for the cookie-leak detection mechanism.
    enable_cookie: bool = Field(
        default=True, description="Whether to enable Cookie leak detection mechanism"
    )
    # Random 16-char token; default_factory runs at model instantiation, so
    # every fresh config gets a new value. Presumably matched against model
    # output to detect prompt/cookie leakage — TODO confirm with the caller.
    cookie: str = Field(
        default_factory=lambda: random_alnum_string(16),
        description="Cookie string for security detection",
    )
33
+
34
+
35
class FunctionConfig(BaseModel):
    """Switches controlling context handling and agent/RAG tool calling."""

    # Minimal context = system prompt + last user message only.
    use_minimal_context: bool = Field(
        default=True,
        description="Whether to use minimal context, i.e. system prompt + user's last message (disabling this option will use all context from the message list, which may consume a large amount of Tokens during Agent workflow execution; enabling this option may effectively reduce token usage)",
    )

    tool_calling_mode: Literal["agent", "rag", "none"] = Field(
        default="agent",
        description="Tool calling mode, i.e. whether to use Agent or RAG to call tools",
    )
    # Upper bound on tool invocations within one agent run.
    agent_tool_call_limit: int = Field(
        default=10, description="Tool call limit in agent mode"
    )
    agent_tool_call_notice: Literal["hide", "notify"] = Field(
        default="hide",
        description="Method of showing tool call status in agent mode, hide to conceal, notify to inform",
    )
    agent_thought_mode: Literal[
        "reasoning", "chat", "reasoning-required", "reasoning-optional"
    ] = Field(
        default="chat",
        description="Thinking mode in agent mode, reasoning mode will first perform reasoning process, then execute tasks; "
        "reasoning-required requires task analysis for each Tool Calling; "
        "reasoning-optional does not require reasoning but allows it; "
        "chat mode executes tasks directly",
    )
    agent_reasoning_hide: bool = Field(
        default=False, description="Whether to hide the thought process in agent mode"
    )
    agent_middle_message: bool = Field(
        default=True,
        description="Whether to allow Agent to send intermediate messages to users in agent mode",
    )
    agent_mcp_client_enable: bool = Field(
        default=False, description="Whether to enable MCP client"
    )
    # Note: mutable default is safe here — pydantic deep-copies field defaults
    # per instance, unlike plain Python default arguments.
    agent_mcp_server_scripts: list[str] = Field(
        default=[], description="List of MCP server scripts"
    )
74
+
75
+
76
class LLMConfig(BaseModel):
    """Settings for LLM requests: limits, retries, and token accounting."""

    require_tools: bool = Field(
        default=False,
        description="Whether to force at least one tool to be used per call",
    )
    # NOTE(review): "lenth" is a typo for "length", but the field name is part
    # of the public config schema — renaming would break existing configs.
    memory_lenth_limit: int = Field(
        default=50, description="Maximum number of messages in memory context"
    )
    max_tokens: int = Field(
        default=100,
        description="Maximum number of tokens generated in a single response",
    )
    tokens_count_mode: Literal["word", "bpe", "char"] = Field(
        default="bpe",
        description="Token counting mode: bpe(subwords)/word(words)/char(characters)",
    )
    enable_tokens_limit: bool = Field(
        default=True, description="Whether to enable context length limits"
    )
    session_tokens_windows: int = Field(
        default=5000, description="Session tokens window size"
    )
    llm_timeout: int = Field(
        default=60, description="API request timeout duration (seconds)"
    )
    auto_retry: bool = Field(
        default=True, description="Automatically retry on request failure"
    )
    max_retries: int = Field(default=3, description="Maximum number of retries")
    enable_memory_abstract: bool = Field(
        default=True,
        description="Whether to enable context memory summarization (will delete context and insert a summary into system instruction)",
    )
    # 15e-2 == 0.15, i.e. summarize 15% of the context by default.
    memory_abstract_proportion: float = Field(
        default=15e-2, description="Context summarization proportion (0.15=15%)"
    )
    enable_multi_modal: bool = Field(
        default=True,
        description="Whether to enable multi-modal support (currently only supports image)",
    )
116
+
117
+
118
class AmritaConfig(BaseModel):
    """Top-level Amrita Core configuration aggregating all sub-configs."""

    function_config: FunctionConfig = Field(
        default_factory=FunctionConfig,
        description="Function configuration",
    )
    llm: LLMConfig = Field(
        default_factory=LLMConfig,
        description="LLM configuration",
    )
    cookie: CookieConfig = Field(
        default_factory=CookieConfig, description="Cookie configuration"
    )
130
+
131
+
132
# Process-wide config singleton. A default instance is constructed eagerly,
# but get_config() refuses to serve it until set_config() has been called.
# (Double leading underscores at module level are NOT name-mangled — mangling
# only applies inside class bodies; this is just a "private" naming convention.)
__config = AmritaConfig()
# Flipped to True by the first set_config() call; later calls are no-ops.
__inited: bool = False
134
+
135
+
136
def get_config() -> AmritaConfig:
    """Get amrita config.

    Raises:
        RuntimeError: Raise it if amrita core is not initialized.

    Returns:
        AmritaConfig: Amrita core config
    """
    # Refuse to hand out the eagerly-built default before set_config() ran.
    if not __inited:
        raise RuntimeError("Config is not initialized. Please set config first.")
    return __config
148
+
149
+
150
def set_config(config: AmritaConfig):
    """Set the config (first call wins).

    Despite the name, this does NOT override an already-set config: only the
    first call takes effect; subsequent calls are silently ignored because
    ``__inited`` is already True.

    Args:
        config (AmritaConfig): Configuration object to set
    """
    global __config, __inited
    if not __inited:
        __inited = True
        __config = config
@@ -0,0 +1,90 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import TYPE_CHECKING
6
+
7
+ from amrita_core.types import USER_INPUT, SendMessageWrap
8
+
9
+ if TYPE_CHECKING:
10
+ from amrita_core.chatmanager import ChatObject
11
+
12
+
13
class EventTypeEnum(str, Enum):
    """
    EventTypeEnum class is used to define and manage different event types.
    It encapsulates the string identifiers of event types, providing a
    structured way to handle and retrieve event types.
    """

    COMPLETION = "COMPLETION"
    Nil = "Nil"
    BEFORE_COMPLETION = "BEFORE_COMPLETION"

    def validate(self, name: str) -> bool:
        """Return True if *name* is a known event type.

        Bug fix: the previous ``name in self`` was a *substring* test against
        this member's value (members are also ``str``), so e.g. ``"COMP"``
        wrongly validated against COMPLETION. Member names and values are
        identical in this enum, so name-based membership is the correct check.
        """
        return name in type(self).__members__
27
+
28
+
29
@dataclass
class Event:
    """Base event dispatched to registered matchers.

    Subclasses overwrite ``self._event_type`` in ``__post_init__`` (after
    calling ``super().__post_init__()``) and override ``get_event_type``;
    the base implementation deliberately raises NotImplementedError.
    """

    # Raw user input that triggered the event.
    user_input: USER_INPUT
    # Message context captured at dispatch time.
    original_context: SendMessageWrap
    # Owning chat session (string annotation avoids an import cycle).
    chat_object: "ChatObject"

    def __post_init__(self):
        # Initialize event type as none
        self._event_type = EventTypeEnum.Nil
        # Validate and store messages using SendMessageWrap
        self._context_messages = self.original_context

    @property
    def event_type(self) -> EventTypeEnum:
        return self._event_type

    @property
    def message(self) -> SendMessageWrap:
        # Alias for get_context_messages(), exposed as a property.
        return self._context_messages

    def get_context_messages(self) -> SendMessageWrap:
        return self._context_messages

    def get_event_type(self) -> str:
        # Abstract by convention: concrete subclasses must override this.
        raise NotImplementedError

    def get_user_input(self) -> USER_INPUT:
        return self.user_input
57
+
58
+
59
@dataclass
class CompletionEvent(Event):
    """Event fired after the model produced a completion."""

    # Final text returned by the model for this completion.
    model_response: str

    def __post_init__(self):
        super().__post_init__()
        # Initialize event type as completion event
        self._event_type = EventTypeEnum.COMPLETION

    @property
    def event_type(self) -> EventTypeEnum:
        return EventTypeEnum.COMPLETION

    def get_event_type(self) -> str:
        # Returns the enum member; since EventTypeEnum subclasses str,
        # the declared -> str is still satisfied.
        return EventTypeEnum.COMPLETION

    def get_model_response(self) -> str:
        return self.model_response
77
+
78
+
79
@dataclass
class PreCompletionEvent(Event):
    """Event fired just before a completion request is made."""

    def __post_init__(self):
        super().__post_init__()
        self._event_type = EventTypeEnum.BEFORE_COMPLETION

    @property
    def event_type(self) -> EventTypeEnum:
        return self._event_type

    def get_event_type(self) -> str:
        # Enum member doubles as str (str-based enum).
        return self._event_type
@@ -0,0 +1,14 @@
1
# Root of the matcher control-flow exception hierarchy; Block/Cancel/Pass
# exceptions all derive from it so callers can catch them as one family.
class MatcherException(Exception):
    """Base exception for Matcher."""
3
+
4
+
5
class BlockException(MatcherException):
    """Raised to finish the current handler and block lower-priority ones."""
7
+
8
+
9
class CancelException(MatcherException):
    """Raised to abort event dispatch entirely for the current event."""
11
+
12
+
13
class PassException(MatcherException):
    """Raised to skip the current handler and continue with the next one."""
@@ -0,0 +1,213 @@
1
+ import inspect
2
+ from collections.abc import Awaitable, Callable
3
+ from copy import deepcopy
4
+ from types import FrameType
5
+ from typing import (
6
+ Any,
7
+ ClassVar,
8
+ TypeAlias,
9
+ )
10
+
11
+ from pydantic import BaseModel, Field
12
+ from typing_extensions import Self
13
+
14
+ from amrita_core.logging import debug_log, logger
15
+
16
+ from .event import Event
17
+ from .exception import BlockException, CancelException, PassException
18
+
19
+ ChatException: TypeAlias = BlockException | CancelException | PassException
20
+
21
+
22
class FunctionData(BaseModel, arbitrary_types_allowed=True):
    """Bookkeeping record for one registered event handler."""

    # The async handler callable itself.
    function: Callable[..., Awaitable[Any]] = Field(...)
    # Cached signature, used for annotation-based dependency injection.
    signature: inspect.Signature = Field(...)
    # NOTE(review): captured in Matcher.append_handler, so this frame points
    # at amrita_core's own code rather than the user's registration site, and
    # storing frames keeps their locals alive for the registry's lifetime —
    # confirm this is intended.
    frame: FrameType = Field(...)
    # Dispatch order: lower value runs earlier (see EventRegistry.get_handlers).
    priority: int = Field(...)
    # If True, dispatch stops after this handler completes successfully.
    block: bool = Field(...)
    # Owning Matcher instance (injected as a handler argument at dispatch).
    matcher: Any = Field(...)
29
+
30
+
31
class EventRegistry:
    """Process-wide singleton mapping event types to their handler records."""

    _instance = None
    __event_handlers: ClassVar[dict[str, list[FunctionData]]] = {}

    def __new__(cls) -> Self:
        # Classic singleton: build once, then always return the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def register_handler(self, event_type: str, data: FunctionData):
        """Append a handler record under *event_type*."""
        bucket = self.__event_handlers.setdefault(event_type, [])
        bucket.append(data)

    def get_handlers(self, event_type: str) -> list[FunctionData]:
        """Return the handlers for *event_type*, sorted by ascending priority."""
        handlers = self.__event_handlers.setdefault(event_type, [])
        handlers.sort(key=lambda fd: fd.priority)
        return handlers

    def _all(self) -> dict[str, list[FunctionData]]:
        """Expose the raw registry mapping (introspection helper)."""
        return self.__event_handlers
50
+
51
+
52
class Matcher:
    """Registers event handlers with a priority and a blocking policy."""

    def __init__(self, event_type: str, priority: int = 10, block: bool = True):
        """Constructor, initialize Matcher object.
        Args:
            event_type (str): Event type
            priority (int, optional): Priority. Defaults to 10.
            block (bool, optional): Whether to block subsequent events. Defaults to True.
        """
        if priority <= 0:
            raise ValueError("Event priority cannot be zero or negative!")

        self.event_type = event_type
        self.priority = priority
        self.block = block

    def append_handler(self, func: Callable[..., Awaitable[Any]]):
        """Register *func* in the global EventRegistry for this event type.

        NOTE(review): inspect.currentframe() captures THIS function's frame,
        so the file/line later logged by MatcherManager points into
        amrita_core itself, not at the user's handler definition — confirm
        whether ``frame.f_back`` was intended instead.
        """
        frame = inspect.currentframe()
        assert frame is not None, "Frame is None!!!"
        func_data = FunctionData(
            function=func,
            signature=inspect.signature(func),
            frame=frame,
            priority=self.priority,
            block=self.block,
            matcher=self,
        )
        EventRegistry().register_handler(self.event_type, func_data)

    def handle(self):
        """
        Event handler registration function (decorator factory).

        Returns a decorator that registers the wrapped coroutine function
        and returns it unchanged.
        """

        def wrapper(
            func: Callable[..., Awaitable[Any]],
        ):
            self.append_handler(func)
            return func

        return wrapper

    def stop_process(self):
        """
        Stop the event flow within the current chat plugin and immediately stop the current handler.
        """
        raise BlockException()

    def cancel_matcher(self):
        """
        Stop event processing within the current chat plugin and cancel.
        """
        raise CancelException()

    def pass_event(self):
        """
        Ignore the current handler and continue processing the next one.
        """
        raise PassException()
110
+
111
+
112
class MatcherManager:
    """
    Event handling manager.
    """

    @staticmethod
    async def trigger_event(*args, **kwargs) -> None:
        """
        Trigger a specific type of event and call all registered event handlers for that type.

        The first positional argument that is an ``Event`` selects the event
        type. Handlers run in ascending priority order; arguments are matched
        to each handler's parameters by annotation (positional args by
        isinstance, keyword args by annotation key). Handlers with untyped
        parameters, or whose typed parameters cannot all be satisfied, are
        skipped. A handler may raise PassException (skip it),
        CancelException (abort dispatch), or BlockException (stop after it);
        a handler registered with ``block=True`` also stops dispatch once it
        completes.

        Parameters:
        - event: Event object containing event-related data.
        - **kwargs: Keyword arguments passed to the dependency injection system.
        - *args: Variable arguments passed to the dependency injection system.

        Raises:
            RuntimeError: If no Event instance is present in ``args``.
        """
        event: Event | None = None
        for i in args:
            if isinstance(i, Event):
                event = i
                break
        if not event:
            raise RuntimeError("No event found in args")
        event_type = event.get_event_type()  # Get event type
        priority_tmp = 0
        debug_log(f"Running matchers for event: {event_type}!")
        # Check if there are handlers for this event type
        if matcher_list := EventRegistry().get_handlers(event_type):
            for matcher in matcher_list:
                if matcher.priority != priority_tmp:
                    priority_tmp = matcher.priority
                    debug_log(f"Running matchers for priority {priority_tmp}......")

                signature = matcher.signature
                frame = matcher.frame
                line_number = frame.f_lineno
                file_name = frame.f_code.co_filename
                handler = matcher.function
                # The owning Matcher is injectable alongside the caller's args.
                session_args = [matcher.matcher, *args]
                # Deep-copy so one handler's mutations can't leak into the next.
                session_kwargs = deepcopy(kwargs)

                args_types = {k: v.annotation for k, v in signature.parameters.items()}
                # Use the public Parameter.empty sentinel instead of the
                # private inspect._empty.
                filtered_args_types = {
                    k: v
                    for k, v in args_types.items()
                    if v is not inspect.Parameter.empty
                }
                if args_types != filtered_args_types:
                    failed_args = list(args_types.keys() - filtered_args_types.keys())
                    logger.warning(
                        f"Matcher {matcher.function.__name__} (File: {file_name}: Line {frame.f_lineno!s}) has untyped parameters!"
                        + f"(Args:{''.join(i + ',' for i in failed_args)}).Skipping......"
                    )
                    continue
                # Greedily match each annotated parameter to the first unused
                # positional argument of a compatible type.
                new_args = []
                used_indices = set()
                for param_type in filtered_args_types.values():
                    for i, arg in enumerate(session_args):
                        if i in used_indices:
                            continue
                        if isinstance(arg, param_type):
                            new_args.append(arg)
                            used_indices.add(i)
                            break

                # Get keyword argument type annotations
                kwparams = signature.parameters
                f_kwargs = {  # TODO: kwparams dependent support
                    param_name: session_kwargs[param.annotation]
                    for param_name, param in kwparams.items()
                    if param.annotation in session_kwargs
                }
                # Every annotated parameter must have been satisfied.
                if len(new_args) != len(list(filtered_args_types)):
                    continue

                # Call the handler

                try:
                    logger.info(f"Starting to run Matcher: '{handler.__name__}'")

                    await handler(*new_args, **f_kwargs)
                except PassException:
                    logger.info(
                        f"Matcher '{handler.__name__}'(~{file_name}:{line_number}) was skipped"
                    )
                    continue
                except CancelException:
                    logger.info("Cancelled Matcher processing")
                    return
                except BlockException:
                    break
                except Exception as e:
                    logger.opt(exception=e, colors=True).error(
                        f"An error occurred while running '{handler.__name__}'({file_name}:{line_number}) "
                    )

                    continue
                finally:
                    logger.info(f"Handler {handler.__name__} finished")
                if matcher.block:
                    break
        else:
            logger.warning(
                f"No registered Matcher for {event_type} event, skipping processing."
            )
+ )
amrita_core/hook/on.py ADDED
@@ -0,0 +1,14 @@
1
+ from .event import EventTypeEnum
2
+ from .matcher import Matcher
3
+
4
+
5
def on_completion(priority: int = 10, block: bool = True):
    """Create a Matcher bound to the COMPLETION event."""
    return on_event(EventTypeEnum.COMPLETION, priority, block)
7
+
8
+
9
def on_precompletion(priority: int = 10, block: bool = True):
    """Create a Matcher bound to the BEFORE_COMPLETION event."""
    return on_event(EventTypeEnum.BEFORE_COMPLETION, priority, block)
11
+
12
+
13
def on_event(event_type: EventTypeEnum, priority: int = 10, block: bool = True):
    """Create and return a Matcher for an arbitrary event type."""
    matcher = Matcher(event_type, priority, block)
    return matcher
amrita_core/libchat.py ADDED
@@ -0,0 +1,189 @@
1
+ from __future__ import annotations
2
+
3
+ import typing
4
+ from collections.abc import AsyncGenerator, Generator
5
+
6
+ from pydantic import ValidationError
7
+
8
+ from amrita_core.preset import PresetManager
9
+
10
+ from .config import get_config
11
+ from .logging import debug_log
12
+ from .protocol import (
13
+ AdapterManager,
14
+ ModelAdapter,
15
+ )
16
+ from .tokenizer import hybrid_token_count
17
+ from .tools.models import ToolChoice
18
+ from .types import (
19
+ CONTENT_LIST_TYPE,
20
+ Message,
21
+ ModelPreset,
22
+ ToolCall,
23
+ ToolResult,
24
+ UniResponse,
25
+ UniResponseUsage,
26
+ )
27
+
28
+ T = typing.TypeVar("T")
29
+
30
+
31
def text_generator(
    memory: CONTENT_LIST_TYPE, split_role: bool = False
) -> Generator[str, None, str]:
    """Yield the text portions of a message list.

    Each message may be a pydantic model (dumped via ``model_dump``) or a
    plain dict. String contents are yielded whole; list contents yield each
    ``{"type": "text"}`` part. With ``split_role=True`` every chunk is
    prefixed with a role tag. Messages with ``content=None`` are skipped.
    """
    prefixes = {
        "assistant": "<BOT's response>",
        "user": "<User's query>",
        "tool": "<Tool call>",
    }
    for item in memory:
        record = item.model_dump() if hasattr(item, "model_dump") else item
        content = record["content"]
        if content is None:
            continue
        if isinstance(content, str):
            yield (
                prefixes.get(record["role"], "") + content if split_role else content
            )
        else:
            for part in content:
                if part["type"] == "text" and part.get("text") is not None:
                    yield (
                        prefixes.get(record["role"], "") + part["text"]
                        if split_role
                        else part["text"]
                    )
    return ""
58
+
59
+
60
async def get_tokens(
    memory: CONTENT_LIST_TYPE, response: UniResponse[str, None]
) -> UniResponseUsage[int]:
    """Calculate token counts for messages and response.

    Prefers the usage block reported by the provider; only falls back to
    local counting when any of its fields is missing.

    Args:
        memory: Message history list
        response: Model response

    Returns:
        Object containing token usage information
    """
    if (
        response.usage is not None
        and response.usage.total_tokens is not None
        and response.usage.completion_tokens is not None
        and response.usage.prompt_tokens is not None
    ):
        return response.usage

    # Count both sides with the same configured mode; previously the
    # completion side omitted the mode and silently used the counter's
    # default, making prompt and completion counts inconsistent.
    mode = get_config().llm.tokens_count_mode
    it = hybrid_token_count("".join(text_generator(memory)), mode)
    ot = hybrid_token_count(response.content, mode)
    return UniResponseUsage(
        prompt_tokens=it, total_tokens=it + ot, completion_tokens=ot
    )
+ )
89
+
90
+
91
+ def _validate_msg_list(
92
+ messages: CONTENT_LIST_TYPE,
93
+ ) -> CONTENT_LIST_TYPE:
94
+ validated_messages = []
95
+ for msg in messages:
96
+ if isinstance(msg, dict):
97
+ # Ensure message has role field
98
+ if "role" not in msg:
99
+ raise ValueError("Message dictionary is missing 'role' field")
100
+ try:
101
+ validated_msg = (
102
+ Message.model_validate(msg)
103
+ if msg["role"] != "tool"
104
+ else ToolResult.model_validate(msg)
105
+ )
106
+ except ValidationError as e:
107
+ raise ValueError(f"Invalid message format: {e}")
108
+ validated_messages.append(validated_msg)
109
+ else:
110
+ validated_messages.append(msg)
111
+ return validated_messages
112
+
113
+
114
async def _call_with_reflection(
    preset: ModelPreset,
    call_func: typing.Callable[..., typing.Awaitable[T]],
    *args,
    **kwargs,
) -> T:
    """Resolve the protocol adapter for *preset* and invoke *call_func* with it.

    Args:
        preset: Model preset identifying protocol, endpoint and credentials.
        call_func: Awaitable taking the adapter instance followed by *args/**kwargs.

    Raises:
        ValueError: If no adapter is registered for the preset's protocol.
    """
    adapter_class = AdapterManager().safe_get_adapter(preset.protocol)
    if adapter_class:
        debug_log(
            f"Using adapter {adapter_class.__name__} to handle protocol {preset.protocol}"
        )
    else:
        raise ValueError(f"Undefined protocol adapter: {preset.protocol}")

    debug_log(f"Getting chat for {preset.model}")
    debug_log(f"Preset: {preset.name}")
    # NOTE(review): logs the first 7 chars of the API key to debug output —
    # confirm this is acceptable for the project's threat model.
    debug_log(f"Key: {preset.api_key[:7]}...")
    debug_log(f"Protocol: {preset.protocol}")
    debug_log(f"API URL: {preset.base_url}")
    debug_log(f"Model: {preset.model}")
    adapter = adapter_class(preset)
    return await call_func(adapter, *args, **kwargs)
137
+
138
+
139
async def tools_caller(
    messages: CONTENT_LIST_TYPE,
    tools: list,
    preset: ModelPreset | None = None,
    tool_choice: ToolChoice | None = None,
) -> UniResponse[None, list[ToolCall] | None]:
    """Invoke the model's tool-calling endpoint via the preset's adapter."""

    async def _invoke(
        adapter: ModelAdapter,
        msg_list: CONTENT_LIST_TYPE,
        tool_defs,
        choice,
    ):
        # Delegate to the protocol adapter's tool-calling implementation.
        return await adapter.call_tools(msg_list, tool_defs, choice)

    # Fall back to the default preset when none was supplied.
    preset = preset or PresetManager().get_default_preset()
    return await _call_with_reflection(
        preset, _invoke, messages, tools, tool_choice
    )
157
+
158
+
159
async def call_completion(
    messages: CONTENT_LIST_TYPE,
    preset: ModelPreset | None = None,
) -> AsyncGenerator[str | UniResponse[str, None], None]:
    """Get chat response as an async stream of chunks/responses."""
    # Normalize dict messages into model instances up front.
    messages = _validate_msg_list(messages)
    preset = preset or PresetManager().get_default_preset()

    # _call_api returns a *factory* for an async generator rather than the
    # generator itself, so the awaited reflection call completes before any
    # streaming begins.
    async def _call_api(adapter: ModelAdapter, messages: CONTENT_LIST_TYPE):
        async def inner():
            # NOTE(review): assumes every entry supports model_dump(), i.e.
            # any non-dict inputs were already model instances — confirm.
            async for i in adapter.call_api([(i.model_dump()) for i in messages]):
                yield i

        return inner

    # Call adapter to get chat response
    response = await _call_with_reflection(preset, _call_api, messages)

    async for i in response():
        yield i
179
+
180
+
181
async def get_last_response(
    generator: AsyncGenerator[str | UniResponse[str, None], None],
) -> UniResponse[str, None]:
    """Drain *generator* and return the final UniResponse it produced.

    Raises:
        RuntimeError: If the stream yields no UniResponse at all.
    """
    last = None
    async for item in generator:
        if isinstance(item, UniResponse):
            last = item
    if last is None:
        raise RuntimeError("No response found in generator.")
    return last
+ return ls[-1]