amrita_core 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amrita_core/__init__.py +101 -0
- amrita_core/builtins/__init__.py +7 -0
- amrita_core/builtins/adapter.py +148 -0
- amrita_core/builtins/agent.py +415 -0
- amrita_core/builtins/tools.py +64 -0
- amrita_core/chatmanager.py +896 -0
- amrita_core/config.py +159 -0
- amrita_core/hook/event.py +90 -0
- amrita_core/hook/exception.py +14 -0
- amrita_core/hook/matcher.py +213 -0
- amrita_core/hook/on.py +14 -0
- amrita_core/libchat.py +189 -0
- amrita_core/logging.py +71 -0
- amrita_core/preset.py +166 -0
- amrita_core/protocol.py +101 -0
- amrita_core/tokenizer.py +115 -0
- amrita_core/tools/manager.py +163 -0
- amrita_core/tools/mcp.py +338 -0
- amrita_core/tools/models.py +353 -0
- amrita_core/types.py +274 -0
- amrita_core/utils.py +66 -0
- amrita_core-0.1.0.dist-info/METADATA +73 -0
- amrita_core-0.1.0.dist-info/RECORD +26 -0
- amrita_core-0.1.0.dist-info/WHEEL +5 -0
- amrita_core-0.1.0.dist-info/licenses/LICENSE +661 -0
- amrita_core-0.1.0.dist-info/top_level.txt +1 -0
amrita_core/logging.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# ref: https://github.com/NoneBot/NoneBot2/blob/main/nonebot/log.py
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
from typing import TYPE_CHECKING, Protocol
|
|
7
|
+
|
|
8
|
+
import loguru
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from loguru import Logger, Record
|
|
12
|
+
|
|
13
|
+
# Shared loguru logger used by the whole package.
logger: "Logger" = loguru.logger

# Global switch read by debug_log(); debug messages are only emitted while True.
debug: bool = False
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ToStringAble(Protocol):
    """Structural type for any object that can be rendered with ``str()``.

    Used to annotate the message accepted by ``debug_log``.
    """

    def __str__(self) -> str:
        """Return the string form of the object."""
        ...
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def debug_log(message: ToStringAble) -> None:
    """Log ``message`` at DEBUG level while the module-level ``debug`` flag is on.

    Args:
        message: Any object convertible with ``str()``.

    The record is attributed to the *caller* of this helper (``depth=1``), so
    the ``{name}:{function}:{line}`` fields of the log format point at the code
    that asked for the message instead of always at ``debug_log`` itself.
    """
    # The flag is only read here, so no ``global`` declaration is required.
    if not debug:
        return
    logger.opt(depth=1).debug(message)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class LoguruHandler(logging.Handler):
    """``logging.Handler`` that forwards standard-library log records to loguru.

    Attach it to a stdlib logger so that records emitted through
    :mod:`logging` are re-emitted by the shared loguru ``logger``.
    """

    def emit(self, record: logging.LogRecord):
        """Re-emit ``record`` through loguru.

        The level is mapped by name when loguru knows it; otherwise the numeric
        level is used. The message text is forwarded verbatim and is NOT parsed
        as loguru color markup. Foreign messages may legitimately contain
        tag-like text (``<module>``, HTML, ...), and loguru rejects unknown tags
        with a ValueError.
        """
        try:
            try:
                level: str | int = logger.level(record.levelname).name
            except ValueError:
                level = record.levelno

            # Step out of this frame (depth 0) and every frame belonging to the
            # logging module, so loguru reports the code that issued the call.
            frame, depth = inspect.currentframe(), 0
            while frame and (
                depth == 0 or frame.f_code.co_filename == logging.__file__
            ):
                frame = frame.f_back
                depth += 1

            logger.opt(depth=depth, exception=record.exc_info).log(
                level, record.getMessage()
            )
        except RecursionError:
            # Mirrors logging.StreamHandler.emit: never mask recursion errors.
            raise
        except Exception:
            # Handler contract: a logging failure must not propagate into the
            # code that merely tried to log something.
            self.handleError(record)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def default_filter(record: "Record"):
    """Default filter for logging, change level from Environment

    Keeps records whose level number is at least the threshold taken from the
    ``LOG_LEVEL`` environment variable. The variable is read on every call, so
    changing it at runtime takes effect immediately. Accepted values:

    * a loguru level name, matched as given first and then upper-cased
      (so ``"debug"`` works as well as ``"DEBUG"``);
    * a non-negative integer severity such as ``"10"``.

    Unset or blank means ``"INFO"``.

    Raises:
        ValueError: if ``LOG_LEVEL`` names a level loguru does not know.
    """
    raw_level = os.environ.get("LOG_LEVEL", "").strip() or "INFO"
    if raw_level.isdecimal():
        levelno = int(raw_level)
    else:
        try:
            levelno = logger.level(raw_level).no
        except ValueError:
            # Fall back to the canonical upper-case spelling of built-in levels.
            levelno = logger.level(raw_level.upper()).no
    return record["level"].no >= levelno
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
default_format: str = " | ".join(
    (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green>",  # timestamp
        "<level>{level: <7}</level>",  # level name, left-aligned to width 7
        "<magenta>{name}:{function}:{line}</magenta>",  # call site
        "<level>{message}</level>",  # the log message
    )
)
"""Default log format"""
|
|
59
|
+
|
|
60
|
+
# Reset the shared logger's sinks, then install a single stdout sink.
# level=0 lets every record reach default_filter, which applies the
# LOG_LEVEL threshold itself; default_format controls the line layout.
# NOTE(review): diagnose=False presumably keeps local variable values out of
# rendered tracebacks (see loguru `add()` docs) — confirm this is intended.
logger.remove()
logger_id = logger.add(
    sys.stdout,
    level=0,
    diagnose=False,
    filter=default_filter,
    format=default_format,
)
"""Default log handler id"""

# NOTE(review): appears to be an autodoc exclusion map (NoneBot convention)
# hiding logger_id from generated API docs — confirm.
__autodoc__ = {"logger_id": False}
|
amrita_core/preset.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import time
|
|
3
|
+
import typing
|
|
4
|
+
|
|
5
|
+
from typing_extensions import Self
|
|
6
|
+
|
|
7
|
+
from .logging import debug_log
|
|
8
|
+
from .protocol import AdapterManager
|
|
9
|
+
from .tokenizer import hybrid_token_count
|
|
10
|
+
from .types import BaseModel, Message, ModelPreset, TextContent, UniResponse
|
|
11
|
+
|
|
12
|
+
# Fixed two-message conversation (system prompt + user greeting) that
# PresetManager.test_single_preset sends through a preset's adapter.
TEST_MSG_PROMPT: Message[list[TextContent]] = Message(
    role="system",
    content=[TextContent(text="You are a helpful assistant.", type="text")],
)
TEST_MSG_USER: Message[list[TextContent]] = Message(
    role="user",
    content=[
        TextContent(text="Hello, please briefly introduce yourself.", type="text")
    ],
)

# Message sequence passed to call_api(); the text of each message's first
# content part is also what the report's prompt token count is computed from.
TEST_MSG_LIST: list[Message[list[TextContent]]] = [
    TEST_MSG_PROMPT,
    TEST_MSG_USER,
]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class PresetReport(BaseModel):
    """Result of probing one model preset with the fixed test conversation."""

    preset_name: str  # Name of the preset
    preset_data: ModelPreset  # Preset data
    test_input: tuple[Message, Message]  # Test input: (system prompt, user message)
    test_output: Message | None  # Test output; None when the test failed
    token_prompt: int  # Token count of the prompt
    token_completion: int  # Token count of the completion (0 on failure)
    status: bool  # Test result: True on success
    message: str  # Test result message ("" on success, error text on failure)
    time_used: float  # Duration of the API call in seconds (0 on failure)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class PresetManager:
    """
    PresetManager is a singleton class that manages presets.

    ``PresetManager()`` always returns the same instance; presets are stored
    by name in a dict created when that instance is first built.
    """

    _default_preset: ModelPreset | None = None
    _presets: dict[str, ModelPreset]
    _instance = None

    def __new__(cls) -> Self:
        if cls._instance is None:
            cls._presets = {}
            cls._instance = super().__new__(cls)
        return cls._instance

    def set_default_preset(self, preset: ModelPreset | str) -> None:
        """
        Set the default preset.

        Args:
            preset: A preset object, or the name of a registered preset. A
                preset object whose name is not registered yet is added first.

        Raises:
            ValueError: If ``preset`` is a name that is not registered.
        """
        if isinstance(preset, str):
            preset = self.get_preset(preset)
        if preset.name not in self._presets:
            self.add_preset(preset)
        self._default_preset = preset

    def get_default_preset(self) -> ModelPreset:
        """
        Get the default preset.

        If no default was set explicitly, a registered preset is picked at
        random and remembered as the default.

        Raises:
            IndexError: If no default is set and no preset is registered.
        """
        if self._default_preset is None:
            if not self._presets:
                # Same exception type random.choice() would raise, but with a
                # message that explains what is actually missing.
                raise IndexError(
                    "No presets registered; cannot choose a default preset"
                )
            self._default_preset = random.choice(list(self._presets.values()))
        return self._default_preset

    def get_preset(self, name: str) -> ModelPreset:
        """
        Get a preset by name.

        Raises:
            ValueError: If no preset with that name is registered.
        """
        if name not in self._presets:
            raise ValueError(f"Preset {name} not found")
        return self._presets[name]

    def add_preset(self, preset: ModelPreset) -> None:
        """
        Add a preset.

        Raises:
            ValueError: If a preset with the same name is already registered.
        """
        if preset.name in self._presets:
            raise ValueError(f"Preset {preset.name} already exists")
        self._presets[preset.name] = preset

    def get_all_presets(self) -> list[ModelPreset]:
        """
        Get all presets (a new list, in registration order).
        """
        return list(self._presets.values())

    @staticmethod
    def _failed_report(
        preset: ModelPreset, prompt_tokens: int, message: str
    ) -> PresetReport:
        """Build the report returned when ``preset`` could not be tested successfully."""
        return PresetReport(
            preset_name=preset.name,
            preset_data=preset,
            test_input=(TEST_MSG_PROMPT, TEST_MSG_USER),
            test_output=None,
            token_prompt=prompt_tokens,
            token_completion=0,
            status=False,
            message=message,
            time_used=0,
        )

    async def test_single_preset(self, preset: ModelPreset | str) -> PresetReport:
        """Test a single preset for parallel execution

        Sends ``TEST_MSG_LIST`` through the adapter registered for
        ``preset.protocol`` and summarizes the outcome. Adapter problems
        (unknown protocol, errors raised while calling, no final response) are
        not raised; they yield a report with ``status=False`` and an
        explanation in ``message``.

        Args:
            preset: The preset, or the name of a registered preset.

        Raises:
            ValueError: If ``preset`` is a name that is not registered.
        """
        if isinstance(preset, str):
            preset = self.get_preset(preset)
        debug_log(f"Testing preset: {preset.name}...")
        prompt_tokens = hybrid_token_count(
            "".join(
                [typing.cast(TextContent, msg.content[0]).text for msg in TEST_MSG_LIST]
            )
        )

        adapter = AdapterManager().safe_get_adapter(preset.protocol)
        if adapter is None:
            return self._failed_report(
                preset,
                prompt_tokens,
                f"Undefined protocol adapter: {preset.protocol}",
            )

        try:
            # perf_counter() is monotonic, unlike time.time(), so wall-clock
            # adjustments during the call cannot distort the measured duration.
            time_start = time.perf_counter()
            debug_log(f"Calling preset: {preset.name}...")
            data: UniResponse | None = None
            # Drain the whole stream but keep only the first UniResponse;
            # items of any other type (plain str) are ignored here.
            async for item in adapter(preset).call_api(TEST_MSG_LIST):
                if data is None and isinstance(item, UniResponse):
                    data = item
            time_delta = time.perf_counter() - time_start
            if data is None:
                # Report this explicitly instead of an opaque IndexError.
                raise RuntimeError(
                    f"Adapter for protocol {preset.protocol} returned no UniResponse"
                )
            debug_log(
                f"Successfully called preset {preset.name}, took {time_delta:.2f} seconds"
            )
            return PresetReport(
                preset_name=preset.name,
                preset_data=preset,
                test_input=(TEST_MSG_PROMPT, TEST_MSG_USER),
                test_output=Message[list[TextContent]](
                    role="assistant",
                    content=[TextContent(type="text", text=data.content)],
                ),
                token_prompt=prompt_tokens,
                token_completion=hybrid_token_count(data.content),
                status=True,
                message="",
                time_used=time_delta,
            )
        except Exception as e:
            debug_log(f"Error occurred while testing preset {preset.name}: {e}")
            # Some exceptions (e.g. a bare TimeoutError()) stringify to "";
            # fall back to the type name so the report always explains itself.
            return self._failed_report(
                preset, prompt_tokens, str(e) or type(e).__name__
            )

    async def test_presets(self) -> typing.AsyncGenerator[PresetReport, None]:
        """Test every registered preset one after another, yielding each report."""
        presets: list[ModelPreset] = self.get_all_presets()
        debug_log(f"Starting to test all presets ({len(presets)} total)...")
        for preset in presets:
            yield await self.test_single_preset(preset)
|
amrita_core/protocol.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
from collections.abc import AsyncGenerator, Iterable
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
from .logging import logger
|
|
8
|
+
from .tools.models import ToolChoice, ToolFunctionSchema
|
|
9
|
+
from .types import ModelPreset, ToolCall, UniResponse
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class ModelAdapter:
    """Base class for model adapter

    Each subclass is registered with ``AdapterManager`` when its class
    statement runs, under the name(s) returned by ``get_adapter_protocol``.
    A subclass opts out by setting ``__abstract__ = True`` in its OWN class
    body. The flag is intentionally not inherited, so concrete subclasses of
    an abstract intermediate base are still registered.

    NOTE: this class does not derive from ``abc.ABC``, so ``@abstractmethod``
    below documents intent but is not enforced at instantiation.
    """

    preset: ModelPreset
    __override__: bool = False  # Whether to allow overriding existing adapters

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # Inspect only the class's own namespace: getattr() would inherit
        # ``__abstract__ = True`` from an abstract base and silently skip
        # every concrete adapter derived from it.
        if not cls.__dict__.get("__abstract__", False):
            AdapterManager().register_adapter(cls)

    @abstractmethod
    async def call_api(
        self, messages: Iterable
    ) -> AsyncGenerator[str | UniResponse[str, None], None]:
        """Send ``messages`` to the model; yield ``str`` items and/or ``UniResponse`` objects."""
        yield ""

    async def call_tools(
        self,
        messages: Iterable,
        tools: list[ToolFunctionSchema],
        tool_choice: ToolChoice | None = None,
    ) -> UniResponse[None, list[ToolCall] | None]:
        """Ask the model for tool calls.

        The base implementation raises ``NotImplementedError``; adapters that
        support tool calling override it.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_adapter_protocol() -> str | tuple[str, ...]:
        """Return the protocol name(s) this adapter handles."""
        ...

    @property
    def protocol(self):
        """Get model protocol adapter (the value of ``get_adapter_protocol()``)."""
        return self.get_adapter_protocol()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class AdapterManager:
    """Singleton registry mapping protocol names to model adapter classes."""

    __instance = None
    _adapter_class: dict[str, type[ModelAdapter]]

    def __new__(cls):
        if cls.__instance is None:
            cls.__instance = super().__new__(cls)
            cls.__instance._adapter_class = {}
        return cls.__instance

    def get_adapters(self) -> dict[str, type[ModelAdapter]]:
        """Get all registered adapters (the live mapping, not a copy)."""
        return self._adapter_class

    def safe_get_adapter(self, protocol: str) -> type[ModelAdapter] | None:
        """Get the adapter for ``protocol``, or ``None`` if none is registered."""
        return self._adapter_class.get(protocol)

    def get_adapter(self, protocol: str) -> type[ModelAdapter]:
        """Get the adapter for ``protocol``.

        Raises:
            ValueError: If no adapter is registered for ``protocol``.
        """
        if protocol not in self._adapter_class:
            raise ValueError(f"No adapter found for protocol {protocol}")
        return self._adapter_class[protocol]

    def register_adapter(self, adapter: type[ModelAdapter]):
        """Register ``adapter`` under every protocol name it declares.

        ``adapter.get_adapter_protocol()`` must return a string or a tuple of
        strings. ``None`` (a subclass that never implemented the method)
        registers nothing. Registration is all-or-nothing: every name is
        validated before any of them is stored.

        Raises:
            TypeError: If the protocol is neither ``None``, a string, nor a
                tuple of strings.
            ValueError: If a name is already registered and the adapter's
                ``__override__`` flag is false.
        """
        protocol = adapter.get_adapter_protocol()
        if protocol is None:
            return
        if isinstance(protocol, str):
            protocols: tuple[str, ...] = (protocol,)
        elif isinstance(protocol, tuple) and all(isinstance(p, str) for p in protocol):
            # dict.fromkeys drops repeated names while keeping declaration order.
            protocols = tuple(dict.fromkeys(protocol))
        else:
            raise TypeError(
                "Model protocol adapter must be a string or tuple of strings"
            )

        override = getattr(adapter, "__override__", False)
        if not override:
            # Check every name first so a conflict leaves the registry untouched.
            for p in protocols:
                if p in self._adapter_class:
                    raise ValueError(f"Model protocol adapter {p} is already registered")

        for p in protocols:
            if p in self._adapter_class:
                logger.warning(
                    f"Model protocol adapter {p} has been registered by {self._adapter_class[p].__name__}, overriding existing adapter"
                )
            self._adapter_class[p] = adapter
|
amrita_core/tokenizer.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from functools import lru_cache
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
import jieba
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@lru_cache(maxsize=2048)
def hybrid_token_count(
    text: str,
    mode: Literal["word", "bpe", "char"] = "word",
    truncate_mode: Literal["head", "tail", "middle"] = "head",
    max_tokens: int = 2048,
) -> int:
    """
    Calculate token count for mixed Chinese-English text, supporting word, subword, and character modes

    Results are memoized (LRU cache, 2048 entries) keyed on all arguments.

    Args:
        text: Input text
        mode: Tokenization mode ['char'(character-level), 'word'(word-level), 'bpe'(mixed mode)], default word.
            In 'word' mode the returned count is capped at ``max_tokens``.
        truncate_mode: Truncation mode ['head'(head truncation), 'tail'(tail truncation), 'middle'(middle truncation)], default head.
            Forwarded to ``Tokenizer``; it is not used when only counting tokens.
        max_tokens: Cap applied to the count in 'word' mode, default 2048
            (the value previously hard-wired through ``Tokenizer``'s default).

    Returns:
        int: Number of tokens
    """
    return Tokenizer(
        max_tokens=max_tokens, mode=mode, truncate_mode=truncate_mode
    ).count_tokens(text=text)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Tokenizer:
    """General purpose text tokenizer

    Text is split into word / punctuation chunks with a regex. ASCII-only
    chunks are kept as words, and any other chunk (e.g. Chinese) is segmented
    with ``jieba``. Modes:

    * ``"char"``: one token per character.
    * ``"word"``: regex + jieba tokens, capped at ``max_tokens`` (first tokens kept).
    * ``"bpe"``: the same tokens as ``"word"``, without the cap.
    """

    def __init__(
        self,
        max_tokens: int = 2048,
        mode: Literal["word", "bpe", "char"] = "bpe",
        truncate_mode: Literal["head", "tail", "middle"] = "head",
    ):
        """
        Initialize the tokenizer

        :param max_tokens: Maximum token limit, default 2048 (caps tokenize() only in Word mode; also the limit used by truncate()). Values <= 0 mean "no tokens".
        :param mode: Tokenization mode ['char'(character-level), 'word'(word-level), 'bpe'(mixed mode)], default bpe
        :param truncate_mode: Truncation mode used by truncate() ['head'(drop from the start), 'tail'(drop from the end), 'middle'(keep both ends)], default head
        """
        self.max_tokens = max_tokens
        self.mode = mode
        self.truncate_mode = truncate_mode
        self._word_pattern = re.compile(r"\w+|[^\w\s]")  # Match words or punctuation

    def tokenize(self, text: str) -> list[str]:
        """Perform tokenization operation, returning a list of tokens

        Args:
            text: Input text

        Returns:
            List[str]: List of tokens
        """
        if self.mode == "char":
            return list(text)

        # Mixed Chinese-English tokenization strategy
        tokens: list[str] = []
        for chunk in self._word_pattern.findall(text):
            if not chunk.strip():
                continue
            if self._is_english(chunk):
                tokens.extend(chunk.split())
            else:
                tokens.extend(jieba.lcut(chunk))

        if self.mode == "word":
            # Clamp at 0: a negative slice bound would silently drop tokens
            # from the end instead of yielding an empty result.
            return tokens[: max(self.max_tokens, 0)]
        return tokens

    def truncate(self, tokens: list[str]) -> list[str]:
        """Perform token truncation operation

        Args:
            tokens: List of tokens

        Returns:
            List[str]: Truncated list of tokens (the input list itself when it
            already fits within ``max_tokens``)
        """
        limit = self.max_tokens
        if limit <= 0:
            # tokens[-0:] is the whole list, so the slicing below cannot
            # express "keep nothing"; handle it explicitly.
            return []
        if len(tokens) <= limit:
            return tokens

        if self.truncate_mode == "head":
            return tokens[-limit:]
        if self.truncate_mode == "tail":
            return tokens[:limit]
        # middle mode preserves head and tail
        head_len = limit // 2
        tail_len = limit - head_len
        return tokens[:head_len] + tokens[-tail_len:]

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in text

        Args:
            text: Input text

        Returns:
            int: Number of tokens
        """
        return len(self.tokenize(text))

    def _is_english(self, text: str) -> bool:
        """Check if the text is English (i.e. consists only of ASCII characters)

        Args:
            text: Input text

        Returns:
            bool: Whether the text is English
        """
        return text.isascii()
|
|
amrita_core/tools/manager.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
from collections.abc import Awaitable, Callable
|
|
3
|
+
from typing import Any, ClassVar, overload
|
|
4
|
+
|
|
5
|
+
from typing_extensions import Self
|
|
6
|
+
|
|
7
|
+
from .models import FunctionDefinitionSchema, ToolContext, ToolData, ToolFunctionSchema
|
|
8
|
+
|
|
9
|
+
# Type of the ``default`` fallback accepted by the ToolsManager getters.
T = typing.TypeVar("T")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ToolsManager:
    """Process-wide registry of callable tools.

    ``ToolsManager()`` always returns the same instance, and the registry
    lives in class-level containers. The getters only expose a tool while it
    is registered, not disabled, and its ``enable_if()`` returns a truthy
    value. ``has_tool`` checks only the first two conditions.
    """

    _instance = None
    _models: ClassVar[dict[str, ToolData]] = {}
    # Disabled tools, has_tool and get_tool will not return disabled tools
    _disabled_tools: ClassVar[set[str]] = set()

    def __new__(cls) -> Self:
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        return cls._instance

    def _iter_active(self) -> typing.Iterator[tuple[str, ToolData]]:
        """Yield ``(name, tool)`` for each tool that is not disabled and whose ``enable_if()`` is truthy."""
        for tool_name, tool in self._models.items():
            if tool_name in self._disabled_tools:
                continue
            if tool.enable_if():
                yield tool_name, tool

    def has_tool(self, name: str) -> bool:
        """Return True when ``name`` is registered and not disabled."""
        if name in self._disabled_tools:
            return False
        return name in self._models

    @overload
    def get_tool(self, name: str) -> ToolData | None: ...
    @overload
    def get_tool(self, name: str, default: T) -> ToolData | T: ...
    def get_tool(self, name: str, default: T = None) -> ToolData | T | None:
        """Return tool ``name``, or ``default`` if it is unknown, disabled, or not currently enabled."""
        if self.has_tool(name):
            candidate: ToolData = self._models[name]
            if candidate.enable_if():
                return candidate
        return default

    @overload
    def get_tool_meta(self, name: str) -> ToolFunctionSchema | None: ...
    @overload
    def get_tool_meta(self, name: str, default: T) -> ToolFunctionSchema | T: ...
    def get_tool_meta(
        self, name: str, default: T | None = None
    ) -> ToolFunctionSchema | None | T:
        """Return the schema of tool ``name`` (same visibility rules as ``get_tool``), else ``default``."""
        found = self.get_tool(name)
        if isinstance(found, ToolData):
            return found.data
        return default

    @overload
    def get_tool_func(
        self, name: str, default: T
    ) -> (
        Callable[[dict[str, Any]], Awaitable[str]]
        | Callable[[ToolContext], Awaitable[str | None]]
        | T
    ): ...
    @overload
    def get_tool_func(
        self,
        name: str,
    ) -> (
        Callable[[dict[str, Any]], Awaitable[str]]
        | Callable[[ToolContext], Awaitable[str | None]]
        | None
    ): ...
    def get_tool_func(
        self, name: str, default: T | None = None
    ) -> (
        Callable[[dict[str, Any]], Awaitable[str]]
        | Callable[[ToolContext], Awaitable[str | None]]
        | None
        | T
    ):
        """Return the callable of tool ``name`` (same visibility rules as ``get_tool``), else ``default``."""
        found = self.get_tool(name)
        if isinstance(found, ToolData):
            return found.func
        return default

    def get_tools(self) -> dict[str, ToolData]:
        """Return a new name -> tool mapping of all currently usable tools."""
        return dict(self._iter_active())

    def tools_meta(self) -> dict[str, ToolFunctionSchema]:
        """Return a new name -> schema mapping of all currently usable tools."""
        return {tool_name: tool.data for tool_name, tool in self._iter_active()}

    def tools_meta_dict(self, **kwargs) -> dict[str, dict[str, Any]]:
        """Like ``tools_meta`` but with each schema dumped via ``model_dump(**kwargs)``."""
        return {
            tool_name: tool.data.model_dump(**kwargs)
            for tool_name, tool in self._iter_active()
        }

    def register_tool(self, tool: ToolData) -> None:
        """Register ``tool`` under its function name; ValueError if the name is taken."""
        tool_name = tool.data.function.name
        if tool_name in self._models:
            raise ValueError(f"Tool {tool_name} already exists")
        self._models[tool_name] = tool

    def remove_tool(self, name: str) -> None:
        """Forget tool ``name`` and its disabled state; unknown names are ignored."""
        self._models.pop(name, None)
        self._disabled_tools.discard(name)

    def enable_tool(self, name: str) -> None:
        """Re-enable a disabled tool; ValueError if it is not disabled."""
        if name not in self._disabled_tools:
            raise ValueError(f"Tool {name} is not disabled")
        self._disabled_tools.remove(name)

    def disable_tool(self, name: str) -> None:
        """Disable a registered tool; ValueError if it is unknown or already disabled."""
        if not self.has_tool(name):
            raise ValueError(f"Tool {name} does not exist or has been disabled")
        self._disabled_tools.add(name)

    def get_disabled_tools(self) -> list[str]:
        """Return the names of all disabled tools as a new list."""
        return [*self._disabled_tools]
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def on_tools(
    data: FunctionDefinitionSchema,
    custom_run: bool = False,
    strict: bool = False,
    enable_if: Callable[[], bool] = lambda: True,
) -> Callable[
    ...,
    Callable[[dict[str, Any]], Awaitable[str]]
    | Callable[[ToolContext], Awaitable[str | None]],
]:
    """Tool registration decorator

    Wraps ``data`` in a ``ToolFunctionSchema`` (``type="function"``). When the
    decorator is applied, it registers the decorated function with
    ``ToolsManager`` and returns the function unchanged.

    Args:
        data (FunctionDefinitionSchema): Function metadata; its ``name`` is the
            key the tool is registered under.
        custom_run (bool, optional): Whether to enable custom run mode. Defaults to False.
        strict (bool, optional): Whether to enable strict mode. Defaults to False.
        enable_if (Callable[[], bool], optional): Callback evaluated on lookup;
            while it returns False, the ToolsManager getters hide the tool.
            Defaults to always True.

    Returns:
        A decorator that registers the function and returns it unchanged.

    Raises:
        ValueError: Raised by the decorator when a tool with the same name is
            already registered.
    """

    def decorator(
        func: Callable[[dict[str, Any]], Awaitable[str]]
        | Callable[[ToolContext], Awaitable[str | None]],
    ):
        tool_data = ToolData(
            func=func,
            data=ToolFunctionSchema(function=data, type="function", strict=strict),
            custom_run=custom_run,
            enable_if=enable_if,
        )
        ToolsManager().register_tool(tool_data)
        return func

    return decorator
|