amrita_core-0.1.0-py3-none-any.whl

amrita_core/logging.py ADDED
@@ -0,0 +1,71 @@
+ # ref: https://github.com/NoneBot/NoneBot2/blob/main/nonebot/log.py
+ import inspect
+ import logging
+ import os
+ import sys
+ from typing import TYPE_CHECKING, Protocol
+
+ import loguru
+
+ if TYPE_CHECKING:
+     from loguru import Logger, Record
+
+ logger: "Logger" = loguru.logger
+
+ debug: bool = False
+
+
+ class ToStringAble(Protocol):
+     def __str__(self) -> str: ...
+
+
+ def debug_log(message: ToStringAble) -> None:
+     global debug
+     if debug:
+         logger.debug(message)
+
+
+ class LoguruHandler(logging.Handler):
+     def emit(self, record: logging.LogRecord):
+         try:
+             level = logger.level(record.levelname).name
+         except ValueError:
+             level = record.levelno
+
+         frame, depth = inspect.currentframe(), 0
+         while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
+             frame = frame.f_back
+             depth += 1
+
+         logger.opt(depth=depth, exception=record.exc_info, colors=True).log(
+             level, record.getMessage()
+         )
+
+
+ def default_filter(record: "Record"):
+     """Default filter for logging, change level from Environment"""
+     log_level = os.environ.get("LOG_LEVEL", "INFO")
+     levelno = logger.level(log_level).no if isinstance(log_level, str) else log_level
+     return record["level"].no >= levelno
+
+
+ default_format: str = (
+     "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
+     "<level>{level: <7}</level> | "
+     "<magenta>{name}:{function}:{line}</magenta> | "
+     "<level>{message}</level>"
+ )
+ """Default log format"""
+
+ logger.remove()
+ logger_id = logger.add(
+     sys.stdout,
+     level=0,
+     diagnose=False,
+     filter=default_filter,
+     format=default_format,
+ )
+ """Default log handler id"""
+
+
+ __autodoc__ = {"logger_id": False}
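
A common way to use LoguruHandler is to route standard-library logging records into loguru so that third-party libraries share the same sink, format, and LOG_LEVEL filter. A minimal sketch, assuming the file above is importable as amrita_core.logging:

import logging

from amrita_core.logging import LoguruHandler, logger

# Send every stdlib logging record through loguru; level 0 defers filtering
# to default_filter, which reads LOG_LEVEL from the environment.
logging.basicConfig(handlers=[LoguruHandler()], level=0, force=True)

logging.getLogger("httpx").warning("routed through loguru")
logger.info("direct loguru call, same sink and format")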
amrita_core/preset.py ADDED
@@ -0,0 +1,166 @@
+ import random
+ import time
+ import typing
+
+ from typing_extensions import Self
+
+ from .logging import debug_log
+ from .protocol import AdapterManager
+ from .tokenizer import hybrid_token_count
+ from .types import BaseModel, Message, ModelPreset, TextContent, UniResponse
+
+ TEST_MSG_PROMPT: Message[list[TextContent]] = Message(
+     role="system",
+     content=[TextContent(text="You are a helpful assistant.", type="text")],
+ )
+ TEST_MSG_USER: Message[list[TextContent]] = Message(
+     role="user",
+     content=[
+         TextContent(text="Hello, please briefly introduce yourself.", type="text")
+     ],
+ )
+
+ TEST_MSG_LIST: list[Message[list[TextContent]]] = [
+     TEST_MSG_PROMPT,
+     TEST_MSG_USER,
+ ]
+
+
+ class PresetReport(BaseModel):
+     preset_name: str  # Name of the preset
+     preset_data: ModelPreset  # Preset data
+     test_input: tuple[Message, Message]  # Test input
+     test_output: Message | None  # Test output
+     token_prompt: int  # Token count of the prompt
+     token_completion: int  # Token count of the completion
+     status: bool  # Test result
+     message: str  # Test result message
+     time_used: float
+
+
+ class PresetManager:
+     """
+     PresetManager is a singleton class that manages presets.
+     """
+
+     _default_preset: ModelPreset | None = None
+     _presets: dict[str, ModelPreset]
+     _instance = None
+
+     def __new__(cls) -> Self:
+         if cls._instance is None:
+             cls._presets = {}
+             cls._instance = super().__new__(cls)
+         return cls._instance
+
+     def set_default_preset(self, preset: ModelPreset | str) -> None:
+         """
+         Set the default preset.
+         """
+         if isinstance(preset, str):
+             preset = self.get_preset(preset)
+         if preset.name not in self._presets:
+             self.add_preset(preset)
+         self._default_preset = preset
+
+     def get_default_preset(self) -> ModelPreset:
+         """
+         Get the default preset.
+         """
+         if self._default_preset is None:
+             self._default_preset = random.choice(list(self._presets.values()))
+         return self._default_preset
+
+     def get_preset(self, name: str) -> ModelPreset:
+         """
+         Get a preset by name.
+         """
+         if name not in self._presets:
+             raise ValueError(f"Preset {name} not found")
+         return self._presets[name]
+
+     def add_preset(self, preset: ModelPreset) -> None:
+         """
+         Add a preset.
+         """
+         if preset.name in self._presets:
+             raise ValueError(f"Preset {preset.name} already exists")
+         self._presets[preset.name] = preset
+
+     def get_all_presets(self) -> list[ModelPreset]:
+         """
+         Get all presets.
+         """
+         return list(self._presets.values())
+
+     async def test_single_preset(self, preset: ModelPreset | str) -> PresetReport:
+         """Test a single preset for parallel execution"""
+         if isinstance(preset, str):
+             preset = self.get_preset(preset)
+         debug_log(f"Testing preset: {preset.name}...")
+         prompt_tokens = hybrid_token_count(
+             "".join(
+                 [typing.cast(TextContent, msg.content[0]).text for msg in TEST_MSG_LIST]
+             )
+         )
+
+         adapter = AdapterManager().safe_get_adapter(preset.protocol)
+         if adapter is None:
+             return PresetReport(
+                 preset_name=preset.name,
+                 preset_data=preset,
+                 test_input=(TEST_MSG_PROMPT, TEST_MSG_USER),
+                 test_output=None,
+                 token_prompt=prompt_tokens,
+                 token_completion=0,
+                 status=False,
+                 message=f"Undefined protocol adapter: {preset.protocol}",
+                 time_used=0,
+             )
+
+         try:
+             time_start = time.time()
+             debug_log(f"Calling preset: {preset.name}...")
+             data = [  # noqa: RUF015
+                 i
+                 async for i in adapter(preset).call_api(TEST_MSG_LIST)
+                 if isinstance(i, UniResponse)
+             ][0]
+             time_end = time.time()
+             time_delta = time_end - time_start
+             debug_log(
+                 f"Successfully called preset {preset.name}, took {time_delta:.2f} seconds"
+             )
+             return PresetReport(
+                 preset_name=preset.name,
+                 preset_data=preset,
+                 test_input=(TEST_MSG_PROMPT, TEST_MSG_USER),
+                 test_output=Message[list[TextContent]](
+                     role="assistant",
+                     content=[TextContent(type="text", text=data.content)],
+                 ),
+                 token_prompt=prompt_tokens,
+                 token_completion=hybrid_token_count(data.content),
+                 status=True,
+                 message="",
+                 time_used=time_delta,
+             )
+         except Exception as e:
+             debug_log(f"Error occurred while testing preset {preset.name}: {e}")
+             return PresetReport(
+                 preset_name=preset.name,
+                 preset_data=preset,
+                 test_input=(TEST_MSG_PROMPT, TEST_MSG_USER),
+                 test_output=None,
+                 token_prompt=prompt_tokens,
+                 token_completion=0,
+                 status=False,
+                 message=str(e),
+                 time_used=0,
+             )
+
+     async def test_presets(self) -> typing.AsyncGenerator[PresetReport, None]:
+         presets: list[ModelPreset] = self.get_all_presets()
+         debug_log(f"Starting to test all presets ({len(presets)} total)...")
+         for preset in presets:
+             yield await self.test_single_preset(preset)
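
Usage sketch: register a preset, pick a default, and stream connectivity reports. ModelPreset is defined in amrita_core/types.py, which is not part of this diff, so any constructor arguments beyond the name and protocol fields used above are assumptions:

import asyncio

from amrita_core.preset import PresetManager
from amrita_core.types import ModelPreset


async def main() -> None:
    manager = PresetManager()
    # Only "name" and "protocol" are relied on by this module; any other
    # ModelPreset fields (API key, base URL, ...) are hypothetical here.
    manager.add_preset(ModelPreset(name="demo", protocol="openai"))
    manager.set_default_preset("demo")

    async for report in manager.test_presets():
        print(report.preset_name, report.status, report.time_used, report.message)


asyncio.run(main())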
@@ -0,0 +1,101 @@
+ from __future__ import annotations
+
+ from abc import abstractmethod
+ from collections.abc import AsyncGenerator, Iterable
+ from dataclasses import dataclass
+
+ from .logging import logger
+ from .tools.models import ToolChoice, ToolFunctionSchema
+ from .types import ModelPreset, ToolCall, UniResponse
+
+
+ @dataclass
+ class ModelAdapter:
+     """Base class for model adapter"""
+
+     preset: ModelPreset
+     __override__: bool = False  # Whether to allow overriding existing adapters
+
+     def __init_subclass__(cls) -> None:
+         super().__init_subclass__()
+         if not getattr(cls, "__abstract__", False):
+             AdapterManager().register_adapter(cls)
+
+     @abstractmethod
+     async def call_api(
+         self, messages: Iterable
+     ) -> AsyncGenerator[str | UniResponse[str, None], None]:
+         yield ""
+
+     async def call_tools(
+         self,
+         messages: Iterable,
+         tools: list[ToolFunctionSchema],
+         tool_choice: ToolChoice | None = None,
+     ) -> UniResponse[None, list[ToolCall] | None]:
+         raise NotImplementedError
+
+     @staticmethod
+     @abstractmethod
+     def get_adapter_protocol() -> str | tuple[str, ...]: ...
+
+     @property
+     def protocol(self):
+         """Get model protocol adapter"""
+         return self.get_adapter_protocol()
+
+
+ class AdapterManager:
+     __instance = None
+     _adapter_class: dict[str, type[ModelAdapter]]
+
+     def __new__(cls):
+         if cls.__instance is None:
+             cls.__instance = super().__new__(cls)
+             cls.__instance._adapter_class = {}
+         return cls.__instance
+
+     def get_adapters(self) -> dict[str, type[ModelAdapter]]:
+         """Get all registered adapters"""
+         return self._adapter_class
+
+     def safe_get_adapter(self, protocol: str) -> type[ModelAdapter] | None:
+         """Get adapter, returning None if the protocol is unknown"""
+         return self._adapter_class.get(protocol)
+
+     def get_adapter(self, protocol: str) -> type[ModelAdapter]:
+         """Get adapter, raising if the protocol is unknown"""
+         if protocol not in self._adapter_class:
+             raise ValueError(f"No adapter found for protocol {protocol}")
+         return self._adapter_class[protocol]
+
+     def register_adapter(self, adapter: type[ModelAdapter]):
+         """Register adapter"""
+         protocol = adapter.get_adapter_protocol()
+         override = adapter.__override__ if hasattr(adapter, "__override__") else False
+         if isinstance(protocol, str):
+             if protocol in self._adapter_class:
+                 if not override:
+                     raise ValueError(
+                         f"Model protocol adapter {protocol} is already registered"
+                     )
+                 logger.warning(
+                     f"Model protocol adapter {protocol} has been registered by {self._adapter_class[protocol].__name__}, overriding existing adapter"
+                 )
+
+             self._adapter_class[protocol] = adapter
+         elif isinstance(protocol, tuple):
+             for p in protocol:
+                 if not isinstance(p, str):
+                     raise TypeError(
+                         "Model protocol adapter must be a string or tuple of strings"
+                     )
+                 if p in self._adapter_class:
+                     if not override:
+                         raise ValueError(
+                             f"Model protocol adapter {p} is already registered"
+                         )
+                     logger.warning(
+                         f"Model protocol adapter {p} has been registered by {self._adapter_class[p].__name__}, overriding existing adapter"
+                     )
+                 self._adapter_class[p] = adapter
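
Because ModelAdapter.__init_subclass__ registers every non-abstract subclass with the AdapterManager singleton, simply defining an adapter class makes it resolvable by protocol name. A minimal sketch; the module path amrita_core.protocol is inferred from the "from .protocol import AdapterManager" line in preset.py, and the echo behaviour is an assumption:

from collections.abc import AsyncGenerator, Iterable

from amrita_core.protocol import AdapterManager, ModelAdapter
from amrita_core.types import UniResponse


class EchoAdapter(ModelAdapter):
    """Toy adapter; registered automatically when the class is created."""

    @staticmethod
    def get_adapter_protocol() -> str:
        return "echo"

    async def call_api(
        self, messages: Iterable
    ) -> AsyncGenerator[str | UniResponse[str, None], None]:
        # A real adapter would stream chunks and finish with a UniResponse;
        # UniResponse's constructor is not shown in this diff, so only a
        # plain string chunk is yielded here.
        yield "echo"


# The subclass is now registered under the "echo" protocol.
assert AdapterManager().safe_get_adapter("echo") is EchoAdapter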
@@ -0,0 +1,115 @@
+ import re
+ from functools import lru_cache
+ from typing import Literal
+
+ import jieba
+
+
+ @lru_cache(maxsize=2048)
+ def hybrid_token_count(
+     text: str,
+     mode: Literal["word", "bpe", "char"] = "word",
+     truncate_mode: Literal["head", "tail", "middle"] = "head",
+ ) -> int:
+     """
+     Calculate token count for mixed Chinese-English text, supporting word, subword, and character modes
+
+     Args:
+         text: Input text
+         mode: Tokenization mode ['char' (character-level), 'word' (word-level), 'bpe' (mixed mode)], default 'word'
+         truncate_mode: Truncation mode ['head' (head truncation), 'tail' (tail truncation), 'middle' (middle truncation)], default 'head'
+
+     Returns:
+         int: Number of tokens
+     """
+     return Tokenizer(mode=mode, truncate_mode=truncate_mode).count_tokens(text=text)
+
+
+ class Tokenizer:
+     """General purpose text tokenizer"""
+
+     def __init__(
+         self,
+         max_tokens: int = 2048,
+         mode: Literal["word", "bpe", "char"] = "bpe",
+         truncate_mode: Literal["head", "tail", "middle"] = "head",
+     ):
+         """
+         Initialize the tokenizer
+
+         :param max_tokens: Maximum token limit, default 2048 (only effective in 'word' mode)
+         :param mode: Tokenization mode ['char' (character-level), 'word' (word-level), 'bpe' (mixed mode)], default 'bpe'
+         :param truncate_mode: Truncation mode ['head' (head truncation), 'tail' (tail truncation), 'middle' (middle truncation)], default 'head'
+         """
+         self.max_tokens = max_tokens
+         self.mode = mode
+         self.truncate_mode = truncate_mode
+         self._word_pattern = re.compile(r"\w+|[^\w\s]")  # Match words or punctuation
+
+     def tokenize(self, text: str) -> list[str]:
+         """Perform tokenization operation, returning a list of tokens
+
+         Args:
+             text: Input text
+
+         Returns:
+             list[str]: List of tokens
+         """
+         if self.mode == "char":
+             return list(text)
+
+         # Mixed Chinese-English tokenization strategy
+         tokens = []
+         for chunk in re.findall(self._word_pattern, text):
+             if chunk.strip() == "":
+                 continue
+
+             if self._is_english(chunk):
+                 tokens.extend(chunk.split())
+             else:
+                 tokens.extend(jieba.lcut(chunk))
+
+         return tokens[: self.max_tokens] if self.mode == "word" else tokens
+
+     def truncate(self, tokens: list[str]) -> list[str]:
+         """Perform token truncation operation
+
+         Args:
+             tokens: List of tokens
+
+         Returns:
+             list[str]: Truncated list of tokens
+         """
+         if len(tokens) <= self.max_tokens:
+             return tokens
+
+         if self.truncate_mode == "head":
+             return tokens[-self.max_tokens :]
+         elif self.truncate_mode == "tail":
+             return tokens[: self.max_tokens]
+         else:  # middle mode preserves head and tail
+             head_len = self.max_tokens // 2
+             tail_len = self.max_tokens - head_len
+             return tokens[:head_len] + tokens[-tail_len:]
+
+     def count_tokens(self, text: str) -> int:
+         """Count the number of tokens in text
+
+         Args:
+             text: Input text
+
+         Returns:
+             int: Number of tokens
+         """
+         return len(self.tokenize(text))
+
+     def _is_english(self, text: str) -> bool:
+         """Check if the text is English
+
+         Args:
+             text: Input text
+
+         Returns:
+             bool: Whether the text is English
+         """
+         return all(ord(c) < 128 for c in text)
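
Usage sketch for the counter and the truncation helper; the module path amrita_core.tokenizer is inferred from the import in preset.py:

from amrita_core.tokenizer import Tokenizer, hybrid_token_count

# Cached, word-mode token count for a mixed Chinese-English string.
print(hybrid_token_count("Hello world, 你好世界!"))

# Explicit tokenizer with a tiny budget; "middle" truncation keeps head and tail.
tok = Tokenizer(max_tokens=4, mode="bpe", truncate_mode="middle")
tokens = tok.tokenize("The quick brown fox jumps over the lazy dog")
print(tok.truncate(tokens))  # ['The', 'quick', 'lazy', 'dog']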
@@ -0,0 +1,163 @@
+ import typing
+ from collections.abc import Awaitable, Callable
+ from typing import Any, ClassVar, overload
+
+ from typing_extensions import Self
+
+ from .models import FunctionDefinitionSchema, ToolContext, ToolData, ToolFunctionSchema
+
+ T = typing.TypeVar("T")
+
+
+ class ToolsManager:
+     _instance = None
+     _models: ClassVar[dict[str, ToolData]] = {}
+     _disabled_tools: ClassVar[set[str]] = (
+         set()
+     )  # Disabled tools, has_tool and get_tool will not return disabled tools
+
+     def __new__(cls) -> Self:
+         if cls._instance is None:
+             cls._instance = super().__new__(cls)
+         return cls._instance
+
+     def has_tool(self, name: str) -> bool:
+         return False if name in self._disabled_tools else name in self._models
+
+     @overload
+     def get_tool(self, name: str) -> ToolData | None: ...
+     @overload
+     def get_tool(self, name: str, default: T) -> ToolData | T: ...
+     def get_tool(self, name: str, default: T = None) -> ToolData | T | None:
+         if not self.has_tool(name):
+             return default
+         tool: ToolData = self._models[name]
+         return tool if tool.enable_if() else default
+
+     @overload
+     def get_tool_meta(self, name: str) -> ToolFunctionSchema | None: ...
+     @overload
+     def get_tool_meta(self, name: str, default: T) -> ToolFunctionSchema | T: ...
+     def get_tool_meta(
+         self, name: str, default: T | None = None
+     ) -> ToolFunctionSchema | None | T:
+         func_data = self.get_tool(name)
+         if func_data is None:
+             return default
+         if isinstance(func_data, ToolData):
+             return func_data.data
+         return default
+
+     @overload
+     def get_tool_func(
+         self, name: str, default: T
+     ) -> (
+         Callable[[dict[str, Any]], Awaitable[str]]
+         | Callable[[ToolContext], Awaitable[str | None]]
+         | T
+     ): ...
+     @overload
+     def get_tool_func(
+         self,
+         name: str,
+     ) -> (
+         Callable[[dict[str, Any]], Awaitable[str]]
+         | Callable[[ToolContext], Awaitable[str | None]]
+         | None
+     ): ...
+     def get_tool_func(
+         self, name: str, default: T | None = None
+     ) -> (
+         Callable[[dict[str, Any]], Awaitable[str]]
+         | Callable[[ToolContext], Awaitable[str | None]]
+         | None
+         | T
+     ):
+         func_data = self.get_tool(name)
+         if func_data is None:
+             return default
+         if isinstance(func_data, ToolData):
+             return func_data.func
+         return default
+
+     def get_tools(self) -> dict[str, ToolData]:
+         return {
+             name: data
+             for name, data in self._models.items()
+             if (name not in self._disabled_tools and data.enable_if())
+         }
+
+     def tools_meta(self) -> dict[str, ToolFunctionSchema]:
+         return {
+             k: v.data
+             for k, v in self._models.items()
+             if (k not in self._disabled_tools and v.enable_if())
+         }
+
+     def tools_meta_dict(self, **kwargs) -> dict[str, dict[str, Any]]:
+         return {
+             k: v.data.model_dump(**kwargs)
+             for k, v in self._models.items()
+             if (k not in self._disabled_tools and v.enable_if())
+         }
+
+     def register_tool(self, tool: ToolData) -> None:
+         if tool.data.function.name not in self._models:
+             self._models[tool.data.function.name] = tool
+         else:
+             raise ValueError(f"Tool {tool.data.function.name} already exists")
+
+     def remove_tool(self, name: str) -> None:
+         self._models.pop(name, None)
+         if name in self._disabled_tools:
+             self._disabled_tools.remove(name)
+
+     def enable_tool(self, name: str) -> None:
+         if name in self._disabled_tools:
+             self._disabled_tools.remove(name)
+         else:
+             raise ValueError(f"Tool {name} is not disabled")
+
+     def disable_tool(self, name: str) -> None:
+         if self.has_tool(name):
+             self._disabled_tools.add(name)
+         else:
+             raise ValueError(f"Tool {name} does not exist or has been disabled")
+
+     def get_disabled_tools(self) -> list[str]:
+         return list(self._disabled_tools)
+
+
+ def on_tools(
+     data: FunctionDefinitionSchema,
+     custom_run: bool = False,
+     strict: bool = False,
+     enable_if: Callable[[], bool] = lambda: True,
+ ) -> Callable[
+     ...,
+     Callable[[dict[str, Any]], Awaitable[str]]
+     | Callable[[ToolContext], Awaitable[str | None]],
+ ]:
+     """Tool registration decorator
+
+     Args:
+         data (FunctionDefinitionSchema): Function metadata
+         custom_run (bool, optional): Whether to enable custom run mode. Defaults to False.
+         strict (bool, optional): Whether to enable strict mode. Defaults to False.
+         enable_if (Callable[[], bool], optional): Predicate deciding whether the tool is enabled. Defaults to a lambda returning True.
+     """
+
+     def decorator(
+         func: Callable[[dict[str, Any]], Awaitable[str]]
+         | Callable[[ToolContext], Awaitable[str | None]],
+     ):
+         tool_data = ToolData(
+             func=func,
+             data=ToolFunctionSchema(function=data, type="function", strict=strict),
+             custom_run=custom_run,
+             enable_if=enable_if,
+         )
+         ToolsManager().register_tool(tool_data)
+         return func
+
+     return decorator
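
Usage sketch for the on_tools decorator. FunctionDefinitionSchema, ToolContext, and ToolData live in the tools models module, which is not part of this diff, so the schema field names below (name/description/parameters, modelled on the OpenAI function-calling format) and the manager module path are assumptions:

from typing import Any

from amrita_core.tools.models import FunctionDefinitionSchema
from amrita_core.tools.manager import ToolsManager, on_tools  # hypothetical module path


@on_tools(
    data=FunctionDefinitionSchema(  # field names are assumed
        name="get_weather",
        description="Look up the weather for a city",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    )
)
async def get_weather(args: dict[str, Any]) -> str:
    return f"Sunny in {args['city']}"


manager = ToolsManager()
print(manager.has_tool("get_weather"))   # True
manager.disable_tool("get_weather")
print(manager.get_tool("get_weather"))   # None, disabled tools are hidden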