nonebot-plugin-amrita 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,111 @@
1
+ import asyncio
2
+ import contextlib
3
+ import sys
4
+
5
+ import amrita_core
6
+ from amrita_core import ChatManager, ChatObject, minimal_init
7
+ from amrita_core.config import (
8
+ AmritaConfig,
9
+ BuiltinAgentConfig,
10
+ CookieConfig,
11
+ FunctionConfig,
12
+ LLMConfig,
13
+ )
14
+ from nonebot import get_driver, logger
15
+ from nonebot import log as nb_log
16
+ from nonebot.plugin import PluginMetadata
17
+
18
+ from . import agent, database, dirty, memory
19
+ from . import config as conf_module
20
+ from .config import Config
21
+ from .database import InsightsModel, UserDataExecutor
22
+ from .memory import CachedUserDataRepository, MemorySchema
23
+
24
# Metadata NoneBot2 reads to describe this plugin; ``type="library"`` marks it
# as a support library, and ``supported_adapters=None`` places no adapter
# restriction on it.
__plugin_meta__ = PluginMetadata(
    name="LibAmritaCore",
    type="library",
    description="Add AmritaCore (a high performance agent core) support to nonebot2",
    usage="View `https://amrita-core.suggar.top/zh` for details.",
    supported_adapters=None,
    config=Config,
)
32
+
33
+
34
def _patch_logger():
    """Reinstall NoneBot2's default stdout sink on the shared Loguru logger.

    All existing Loguru handlers are removed. A single stdout handler is then
    added using NoneBot's default filter and format. The id of that new
    handler is stored in both ``nonebot.log.logger_id`` and
    ``amrita_core.logging.logger_id``.
    """
    logger.remove()
    sink_options = {
        "level": 0,
        "diagnose": False,
        "filter": nb_log.default_filter,
        "format": nb_log.default_format,
    }
    handler_id = logger.add(sys.stdout, **sink_options)
    nb_log.logger_id = handler_id
    amrita_core.logging.logger_id = handler_id


# AmritaCore modifies Loguru's configuration; reset it to NoneBot2's defaults here.
_patch_logger()
47
+
48
+
49
def replace_config(config: Config):
    """Replace the plugin-wide configuration object used at startup.

    Args:
        config: The new plugin configuration.

    Raises:
        TypeError: If ``config`` is not a ``Config`` instance.
    """
    if isinstance(config, Config):
        conf_module._config = config
        return
    raise TypeError("config must be Config")
54
+
55
+
56
@get_driver().on_startup
async def init():
    """Translate the plugin config into an AmritaConfig and initialize AmritaCore.

    Registered as a NoneBot driver startup hook. It reads the current plugin
    config from ``conf_module._config``, so a config installed earlier via
    ``replace_config`` is honoured.
    """
    plugin_cfg = conf_module._config
    await minimal_init(
        AmritaConfig(
            cookie=CookieConfig(
                enable_cookie=plugin_cfg.amrita_cookie_enable,
                cookie=plugin_cfg.amrita_cookie,
            ),
            function_config=FunctionConfig(
                agent_mcp_client_enable=plugin_cfg.amrita_mcp_enable,
                agent_mcp_server_scripts=plugin_cfg.amrita_mcp_clients,
                agent_tool_call_limit=plugin_cfg.amrita_tool_call_limit,
            ),
            builtin=BuiltinAgentConfig(
                tool_calling_mode=plugin_cfg.amrita_agent_mode,
                agent_thought_mode=plugin_cfg.amrita_agent_thought_mode,
            ),
            llm=LLMConfig(
                memory_length_limit=plugin_cfg.amrita_memory_length,
                session_tokens_windows=plugin_cfg.amrita_memory_token_limit,
                max_tokens=plugin_cfg.amrita_prompt_token_limit,
            ),
        )
    )
83
+
84
+
85
@get_driver().on_shutdown
async def shutdown():
    """Terminate every running AmritaCore chat object when NoneBot shuts down.

    Termination is best-effort. A failure on one chat object is logged, with
    its traceback, and does not stop the remaining objects from being
    terminated.
    """
    logger.info("Shutting down AmritaCore...")

    async def kill_all(objs: list[ChatObject]):
        # Iterate over a snapshot in case terminate() removes the object from
        # the live running list; mutating a list while iterating it skips
        # entries.
        for obj in list(objs):
            try:
                obj.terminate()
            except Exception:
                # Keep the cleanup best-effort, but record the failure instead
                # of swallowing it silently.
                logger.opt(exception=True).warning(
                    "Failed to terminate a chat object during shutdown"
                )

    await asyncio.gather(
        *[kill_all(objs) for objs in ChatManager().running_chat_object.values()],
        return_exceptions=True,
    )
98
+
99
+
100
# Public API re-exported by this plugin package: the memory/database models
# and helpers plus AmritaCore's chat types, and the plugin submodules.
__all__ = [
    "CachedUserDataRepository", "ChatManager", "ChatObject", "InsightsModel",
    "MemorySchema", "UserDataExecutor",
    "agent", "database", "dirty", "memory",
]
@@ -0,0 +1,115 @@
1
+ from __future__ import annotations
2
+
3
+ import typing
4
+
5
+ import typing_extensions
6
+ from amrita_core import (
7
+ AgentRuntime as AmRuntime,
8
+ )
9
+ from amrita_core import (
10
+ AgentStrategy,
11
+ AmritaConfig,
12
+ ChatObject,
13
+ ModelPreset,
14
+ PresetManager,
15
+ SessionsManager,
16
+ get_config,
17
+ )
18
+ from amrita_core.builtins.agent import AmritaAgentStrategy
19
+ from amrita_core.consts import DEFAULT_TEMPLATE
20
+ from amrita_core.sessions import SessionData
21
+ from amrita_core.types import Content, Message
22
+ from jinja2 import Template
23
+ from nonebot.adapters import Event
24
+ from nonebot.params import Depends
25
+
26
+ from nonebot_plugin_amrita.database import InsightsModel, make_id
27
+ from nonebot_plugin_amrita.lock import lock_by_session
28
+ from nonebot_plugin_amrita.memory import (
29
+ AwaredMemory,
30
+ CachedUserDataRepository,
31
+ add_usage,
32
+ )
33
+
34
+
35
class AgentSession(AmRuntime):
    """AmritaCore agent runtime bound to a NoneBot session, with usage accounting.

    Every chat object created through :meth:`get_chatobject` is tracked. When
    the session is used as an async context manager, leaving the block adds
    the token usage of the tracked chat objects to both the session's
    metadata and the global insights record.
    """

    # Chat objects created via get_chatobject whose usage has not been recorded yet.
    chat_objs: list[ChatObject]

    def __init__(
        self,
        config: AmritaConfig,
        preset: ModelPreset,
        train: dict[str, str] | Message[str],
        strategy: type[AgentStrategy] = AmritaAgentStrategy,
        template: Template | str = DEFAULT_TEMPLATE,
        session: SessionData | str | None = None,
        no_session: bool = False,
    ):
        super().__init__(config, preset, train, strategy, template, session, no_session)
        self.chat_objs = []

    async def __aenter__(self) -> typing_extensions.Self:
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Record usage of the tracked chat objects, then stop tracking them.

        Exceptions raised inside the ``async with`` body are not suppressed
        (the method returns ``None``). Accounted chat objects are removed
        after the save succeeds. This means a later exit (for example, when
        re-entering ``async with``) does not count the same usage twice.
        """
        del exc_type, exc, tb
        uni_id = self.session_id
        if not self.chat_objs:
            return
        # Serializes the read-modify-write of this session's metadata.
        # Presumably lock_by_session is a per-session async lock — confirm.
        async with lock_by_session(uni_id):
            pending = list(self.chat_objs)
            dm = CachedUserDataRepository()
            metadata = await dm.get_metadata(uni_id)
            insight = await InsightsModel.get()
            for chat_object in pending:
                usage = chat_object.response.usage
                if usage:
                    add_usage(metadata, usage)
                    add_usage(insight, usage)
            await insight.save()
            await dm.update_metadata(metadata)
            # New objects are only ever appended, so dropping the first
            # len(pending) entries removes exactly those just accounted for.
            del self.chat_objs[: len(pending)]

    @classmethod
    async def load_from(
        cls,
        id_or_event: Event | str,
        train: Message[str] | dict[str, str],
        config: AmritaConfig | None = None,
        preset: ModelPreset | None = None,
        **kwargs,
    ) -> AgentSession:
        """Create a session whose context is seeded from stored memory.

        Args:
            id_or_event: An event or an id, turned into a session id via ``make_id``.
            train: Training/system prompt passed to the runtime.
            config: AmritaCore config; defaults to ``get_config()``.
            preset: Model preset; defaults to the preset manager's default preset.
            **kwargs: Extra keyword arguments forwarded to the constructor.
        """
        uni_id = make_id(id_or_event)
        dm = CachedUserDataRepository()
        memory = await dm.get_memory(uni_id)
        config = config or get_config()
        preset = preset or PresetManager().get_default_preset()
        SessionsManager().init_session(uni_id)
        session = SessionsManager().get_session_data(uni_id)
        session.memory = memory.memory_json
        return cls(config, preset, train, session=session, **kwargs)

    async def save_context(self):
        """Persist the runtime's current context as this session's stored memory."""
        session_id = self.session_id
        dm = CachedUserDataRepository()
        context: AwaredMemory = typing.cast(AwaredMemory, self.context)
        mem = await dm.get_memory(session_id)
        mem.memory_json = context
        await dm.update_memory_data(mem)

    @typing_extensions.override
    def get_chatobject(
        self, input: typing.Sequence[Content] | str | None, **kwargs
    ) -> ChatObject:
        """Create a chat object via the base runtime and track it for usage accounting."""
        obj = super().get_chatobject(input, **kwargs)
        self.chat_objs.append(obj)
        return obj
104
+
105
+
106
def SessionDepends(
    train: Message[str] | dict[str, str],
    config: AmritaConfig | None = None,
    preset: ModelPreset | None = None,
    **kwargs,
):
    """Return a NoneBot dependency that builds an AgentSession for the handled event.

    The captured ``train``, ``config``, ``preset`` and extra keyword arguments
    are forwarded to ``AgentSession.load_from`` together with the event that
    NoneBot injects.
    """

    async def _build_session(event: Event) -> AgentSession:
        return await AgentSession.load_from(
            event, train, config=config, preset=preset, **kwargs
        )

    return Depends(_build_session)