neobot-modloader 1.0.0a7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neobot_modloader-1.0.0a7/PKG-INFO +11 -0
- neobot_modloader-1.0.0a7/README.md +0 -0
- neobot_modloader-1.0.0a7/pyproject.toml +22 -0
- neobot_modloader-1.0.0a7/src/neobot_modloader/__init__.py +19 -0
- neobot_modloader-1.0.0a7/src/neobot_modloader/context.py +230 -0
- neobot_modloader-1.0.0a7/src/neobot_modloader/events.py +288 -0
- neobot_modloader-1.0.0a7/src/neobot_modloader/hooks.py +200 -0
- neobot_modloader-1.0.0a7/src/neobot_modloader/loader.py +160 -0
- neobot_modloader-1.0.0a7/src/neobot_modloader/manager.py +151 -0
- neobot_modloader-1.0.0a7/src/neobot_modloader/plugin.py +39 -0
- neobot_modloader-1.0.0a7/src/neobot_modloader/py.typed +0 -0
- neobot_modloader-1.0.0a7/src/neobot_modloader/runtime.py +101 -0
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: neobot-modloader
|
|
3
|
+
Version: 1.0.0a7
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Author: wsrsq, tangtian
|
|
6
|
+
Author-email: wsrsq <wsrsq001@163.com>, tangtian <a14b@126.com>
|
|
7
|
+
Requires-Dist: neobot-adapter
|
|
8
|
+
Requires-Dist: neobot-contracts
|
|
9
|
+
Requires-Python: >=3.13
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
|
|
File without changes
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "neobot-modloader"
|
|
3
|
+
version = "1.0.0-alpha.7"
|
|
4
|
+
description = "Add your description here"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [
|
|
7
|
+
{ name = "wsrsq", email = "wsrsq001@163.com" },
|
|
8
|
+
{ name = "tangtian", email = "a14b@126.com" },
|
|
9
|
+
]
|
|
10
|
+
requires-python = ">=3.13"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"neobot-adapter",
|
|
13
|
+
"neobot-contracts",
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
[tool.uv.sources]
|
|
17
|
+
neobot-adapter = { workspace = true }
|
|
18
|
+
neobot-contracts = { workspace = true }
|
|
19
|
+
|
|
20
|
+
[build-system]
|
|
21
|
+
requires = ["uv_build>=0.9.27,<0.10.0"]
|
|
22
|
+
build-backend = "uv_build"
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from neobot_modloader.context import PluginContext
|
|
4
|
+
from neobot_modloader.events import PluginEventBus
|
|
5
|
+
from neobot_modloader.hooks import PluginHookBus
|
|
6
|
+
from neobot_modloader.loader import FilesystemPluginLoader
|
|
7
|
+
from neobot_modloader.manager import DefaultPluginManager
|
|
8
|
+
from neobot_modloader.plugin import BasePlugin
|
|
9
|
+
from neobot_modloader.runtime import PluginRuntime
|
|
10
|
+
|
|
11
|
+
# Public API of the package: names re-exported from the submodules imported above.
__all__ = [
    "BasePlugin",
    "DefaultPluginManager",
    "FilesystemPluginLoader",
    "PluginContext",
    "PluginEventBus",
    "PluginHookBus",
    "PluginRuntime",
]
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Mapping
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from neobot_adapter.model.response import SendMsgResponse
|
|
9
|
+
from neobot_contracts.models import ConversationRef
|
|
10
|
+
from neobot_contracts.ports.logging import Logger, NullLogger
|
|
11
|
+
|
|
12
|
+
from neobot_modloader.events import PluginEventBus
|
|
13
|
+
|
|
14
|
+
# Outgoing message body accepted by PluginContext's send helpers: a plain string
# or a list of segment dicts (presumably adapter message segments — confirm
# against neobot_adapter).
MessagePayload = str | list[dict[str, Any]]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class PluginAgentRegistrar:
    """Per-plugin facade over a shared agent registry.

    Agents are registered under the namespaced key
    ``plugin:<plugin_name>:<local_name>`` so plugins cannot collide, and the
    registrar remembers what it registered so it can list and remove it later.
    A registrar only ever removes names inside its own namespace.
    """

    def __init__(
        self,
        *,
        plugin_name: str,
        registry: Any | None,
        record_registration: Any | None,
    ) -> None:
        # registry: shared registry object, used duck-typed through its
        #   ``register``, ``unregister`` and ``names`` members.
        # record_registration: optional callback ``(registered_name, agent)``
        #   invoked after every successful registration.
        self._plugin_name = plugin_name
        self._registry = registry
        self._record_registration = record_registration
        self._registered: dict[str, Any] = {}

    @property
    def names(self) -> list[str]:
        """Namespaced names of the agents registered through this registrar."""
        return list(self._registered)

    def register(self, name: str, agent: Any) -> str:
        """Register *agent* under this plugin's namespace and return its full name.

        Raises:
            RuntimeError: no agent registry was supplied.
            TypeError: *name* is not a string or *agent* lacks required members.
            ValueError: *name* is malformed, or the namespaced name is already
                taken locally or in the shared registry.
        """
        if self._registry is None:
            raise RuntimeError("Agent registry is not available")
        local_name = self._validate_name(name)
        self._validate_agent(agent)
        registered_name = self._registered_name(local_name)
        if registered_name in self._registered:
            raise ValueError(f"插件 Agent 已注册: {registered_name}")
        if registered_name in self._registry_names():
            raise ValueError(f"Agent 已注册: {registered_name}")
        self._registry.register(registered_name, agent)
        self._registered[registered_name] = agent
        if self._record_registration is not None:
            self._record_registration(registered_name, agent)
        return registered_name

    def unregister(self, registered_name: str) -> Any | None:
        """Remove an agent registered in this plugin's namespace.

        Names outside ``plugin:<plugin_name>:`` are ignored (``None`` is
        returned and the shared registry is not touched), so a plugin cannot
        remove agents that belong to other plugins or to the host.

        Returns:
            The agent reported by the registry's ``unregister`` when it returns
            one, otherwise the locally tracked agent, or ``None``.
        """
        if not self._owns(registered_name):
            return None
        agent = self._registered.pop(registered_name, None)
        if self._registry is None:
            return agent
        unregister = getattr(self._registry, "unregister", None)
        if callable(unregister):
            removed = unregister(registered_name)
            return removed if removed is not None else agent
        return agent

    def snapshot(self) -> list[dict[str, str]]:
        """Return ``{"name", "description"}`` dicts for the registered agents."""
        return [
            {"name": name, "description": str(getattr(agent, "description", ""))}
            for name, agent in self._registered.items()
        ]

    def list_agents(self, name: str | None = None) -> str:
        """Return a human-readable listing of this plugin's agents.

        With *name* (a local, un-namespaced name) describe only that agent or
        report it as not found; otherwise list every registered agent.
        """
        if name is not None:
            local_name = self._validate_name(name)
            registered_name = self._registered_name(local_name)
            agent = self._registered.get(registered_name)
            if agent is None:
                return f"Agent '{registered_name}' not found"
            return f"Agent {registered_name}: {getattr(agent, 'description', '')}"
        if not self._registered:
            return "No agents available"
        lines = [
            f"- {registered_name}: {getattr(agent, 'description', '')}"
            for registered_name, agent in self._registered.items()
        ]
        return "Available agents:\n" + "\n".join(lines)

    def _registered_name(self, local_name: str) -> str:
        """Build the namespaced registry key for *local_name*."""
        return f"plugin:{self._plugin_name}:{local_name}"

    def _owns(self, registered_name: Any) -> bool:
        """True if *registered_name* is a key inside this plugin's namespace."""
        if not isinstance(registered_name, str):
            return False
        prefix = self._registered_name("")
        local = registered_name[len(prefix):]
        # Local names are non-empty and never contain ':' (see _validate_name).
        return registered_name.startswith(prefix) and bool(local) and ":" not in local

    def _registry_names(self) -> Any:
        """Names already in the shared registry.

        The registry's ``names`` may be an attribute/property or a method; a
        missing or ``None`` value is treated as empty.
        """
        names = getattr(self._registry, "names", None)
        if callable(names):
            names = names()
        return names if names is not None else ()

    @staticmethod
    def _validate_name(name: str) -> str:
        """Return *name* if it is a usable local agent name, else raise."""
        if not isinstance(name, str):
            raise TypeError("Agent name must be a string")
        if not name:
            raise ValueError("Agent name cannot be empty")
        if name != name.strip():
            raise ValueError("Agent name cannot contain leading or trailing whitespace")
        if ":" in name:
            raise ValueError("Agent name cannot contain ':'")
        return name

    @staticmethod
    def _validate_agent(agent: Any) -> None:
        """Raise TypeError unless *agent* has the members the registry relies on."""
        missing: list[str] = []
        for attr in ("description", "tool_definitions"):
            if not hasattr(agent, attr):
                missing.append(attr)
        for method in ("invoke", "stream_invoke", "close"):
            if not callable(getattr(agent, method, None)):
                missing.append(method)
        if missing:
            raise TypeError(f"Plugin agent is missing required attributes: {', '.join(missing)}")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class PluginContext:
    """Per-plugin facade over the bot runtime.

    Exposes the plugin's identity, directories, configuration and logger, plus
    helpers to send messages through the adapter, read message text and the
    conversation target from events, subscribe to events (``self.on``) and
    register agents (``self.agents``).
    """

    def __init__(
        self,
        *,
        plugin_name: str,
        plugin_dir: Path,
        data_dir: Path,
        config: Mapping[str, Any] | None,
        logger: Logger | None,
        adapter: Any,
        hook_bus: Any | None = None,
        record_subscription: Any | None = None,
        agent_registry: Any | None = None,
        record_agent_registration: Any | None = None,
    ) -> None:
        self._plugin_name = plugin_name
        self._plugin_dir = plugin_dir
        self._data_dir = data_dir
        # Private copy: later changes to the caller's mapping are not visible here.
        self._config = dict(config or {})
        self._logger = logger or NullLogger()
        self._adapter = adapter
        # Created eagerly so the plugin can write into it straight away.
        self._data_dir.mkdir(parents=True, exist_ok=True)
        self.agents = PluginAgentRegistrar(
            plugin_name=plugin_name,
            registry=agent_registry,
            record_registration=record_agent_registration,
        )

        from neobot_modloader.hooks import PluginHookBus

        # Without a shared hook bus this context gets a private one.
        bus = hook_bus or PluginHookBus(logger=self._logger)
        recorder = record_subscription or (lambda _subscription: None)
        self.on = PluginEventBus(hook_bus=bus, logger=self._logger, record_subscription=recorder)

    @property
    def plugin_name(self) -> str:
        """Name of the plugin this context belongs to."""
        return self._plugin_name

    @property
    def plugin_dir(self) -> Path:
        """Directory the plugin was loaded from."""
        return self._plugin_dir

    @property
    def data_dir(self) -> Path:
        """Writable per-plugin data directory (created by the constructor)."""
        return self._data_dir

    @property
    def config(self) -> Mapping[str, Any]:
        """The plugin's configuration mapping."""
        return self._config

    @property
    def logger(self) -> Logger:
        """Logger for the plugin (a NullLogger when none was supplied)."""
        return self._logger

    async def send_private(self, user_id: int, message: MessagePayload) -> SendMsgResponse:
        """Send *message* to *user_id* through the adapter's private-message API."""
        return await self._adapter.send_private_msg(user_id, message)

    async def send_group(self, group_id: int, message: MessagePayload) -> SendMsgResponse:
        """Send *message* to *group_id* through the adapter's group-message API."""
        return await self._adapter.send_group_msg(group_id, message)

    async def send(
        self,
        conversation: ConversationRef,
        message: MessagePayload,
    ) -> SendMsgResponse:
        """Send *message* to an arbitrary conversation through the adapter."""
        return await self._adapter.send(conversation, message)

    async def reply(self, event: dict[str, Any] | BaseModel, message: MessagePayload) -> SendMsgResponse:
        """Send *message* back to the conversation *event* came from."""
        target = self.conversation_from_event(event)
        return await self.send(target, message)

    def message_text(self, event: dict[str, Any] | BaseModel) -> str:
        """Return the plain text of a message event.

        ``raw_message`` wins when present; otherwise ``message`` is returned
        as-is if it is a string, or its text segments are concatenated if it is
        a list. A missing message yields ``""``.
        """
        data = self._event_to_dict(event)
        raw = data.get("raw_message")
        if raw is not None:
            return str(raw)
        body = data.get("message")
        if body is None:
            return ""
        if isinstance(body, str):
            return body
        if isinstance(body, list):
            return "".join(map(self._segment_text, body))
        return str(body)

    def conversation_from_event(self, event: dict[str, Any] | BaseModel) -> ConversationRef:
        """Work out which conversation *event* belongs to.

        A private ``message_type`` with a ``user_id`` maps to that user; any
        other event with a ``group_id`` maps to the group; otherwise a
        ``user_id`` maps to a private chat.

        Raises:
            ValueError: the event carries neither a group nor a user id.
        """
        data = self._event_to_dict(event)
        user_id = data.get("user_id")
        group_id = data.get("group_id")
        if data.get("message_type") == "private" and user_id is not None:
            return ConversationRef(kind="private", id=str(user_id))
        if group_id is not None:
            return ConversationRef(kind="group", id=str(group_id))
        if user_id is not None:
            return ConversationRef(kind="private", id=str(user_id))
        raise ValueError(f"无法从事件推断会话: plugin={self.plugin_name}")

    def require_config(self, key: str) -> Any:
        """Return config value *key*, raising KeyError when it is absent."""
        if key in self._config:
            return self._config[key]
        raise KeyError(f"插件 {self.plugin_name!r} 缺少配置项 {key!r}")

    def _event_to_dict(self, event: dict[str, Any] | BaseModel) -> dict[str, Any]:
        """Dump pydantic models; shallow-copy anything else into a dict."""
        return event.model_dump(mode="python") if isinstance(event, BaseModel) else dict(event)

    def _segment_text(self, segment: Any) -> str:
        """Text carried by one message segment; non-text segments give ``""``."""
        if isinstance(segment, BaseModel):
            segment = segment.model_dump(mode="python")
        if not isinstance(segment, Mapping):
            return str(segment)
        payload = segment.get("data") if segment.get("type") == "text" else None
        return str(payload.get("text", "")) if isinstance(payload, Mapping) else ""
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import re
|
|
5
|
+
from collections.abc import Awaitable, Callable, Sequence
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from neobot_contracts.ports.logging import Logger, NullLogger
|
|
9
|
+
|
|
10
|
+
from neobot_modloader.hooks import PluginHookBus, Rule
|
|
11
|
+
|
|
12
|
+
# Plugin event handler: any callable (sync or async); typed loosely because the
# hook bus decides how to invoke it.
EventHandler = Callable[..., Any]
# Callback handed each subscription object that PluginEventBus creates.
SubscriptionRecorder = Callable[[Any], None]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PluginEventBus:
    """Decorator API that plugins use to subscribe handlers to bot events.

    Each decorator forwards its filters to ``hook_bus.subscribe``, passes the
    returned subscription to the recorder callback, and hands the decorated
    function back unchanged.
    """

    def __init__(
        self,
        *,
        hook_bus: PluginHookBus,
        logger: Logger | None = None,
        record_subscription: SubscriptionRecorder | None = None,
    ) -> None:
        self._hook_bus = hook_bus
        self._logger = logger or NullLogger()
        self._record_subscription = record_subscription or (lambda _subscription: None)

    def message(
        self,
        *,
        group: bool = False,
        private: bool = False,
        sub_type: str | None = None,
        rule: Rule | None = None,
        priority: int = 0,
        timeout: float | None = None,
        block: bool = False,
        block_ai_reply: bool = False,
        regex: str | re.Pattern[str] | None = None,
        keywords: str | Sequence[str] | None = None,
        contains: str | Sequence[str] | None = None,
        not_contains: str | Sequence[str] | None = None,
    ) -> Callable[[EventHandler], EventHandler]:
        """Subscribe to message events.

        ``group`` / ``private`` restrict the message type and are mutually
        exclusive. The text filters (``regex`` must match somewhere,
        ``keywords`` need any one present, ``contains`` need all present,
        ``not_contains`` must all be absent) are checked before ``rule``.

        Raises:
            ValueError: both ``group`` and ``private`` are True.
        """
        if group and private:
            raise ValueError("group 和 private 不能同时为 True")
        if group:
            message_type = "group"
        elif private:
            message_type = "private"
        else:
            message_type = None
        text_rule = _build_message_rule(
            rule=rule,
            regex=regex,
            keywords=keywords,
            contains=contains,
            not_contains=not_contains,
        )
        return self.event(
            post_type="message",
            message_type=message_type,
            sub_type=sub_type,
            rule=text_rule,
            priority=priority,
            timeout=timeout,
            block=block,
            block_ai_reply=block_ai_reply,
        )

    def notice(
        self,
        notice_type: str | None = None,
        *,
        sub_type: str | None = None,
        rule: Rule | None = None,
        priority: int = 0,
        timeout: float | None = None,
        block: bool = False,
        block_ai_reply: bool = False,
    ) -> Callable[[EventHandler], EventHandler]:
        """Subscribe to notice events, optionally of one ``notice_type``."""
        return self.event(
            post_type="notice",
            notice_type=notice_type,
            sub_type=sub_type,
            rule=rule,
            priority=priority,
            timeout=timeout,
            block=block,
            block_ai_reply=block_ai_reply,
        )

    def request(
        self,
        request_type: str | None = None,
        *,
        sub_type: str | None = None,
        rule: Rule | None = None,
        priority: int = 0,
        timeout: float | None = None,
        block: bool = False,
        block_ai_reply: bool = False,
    ) -> Callable[[EventHandler], EventHandler]:
        """Subscribe to request events, optionally of one ``request_type``."""
        return self.event(
            post_type="request",
            request_type=request_type,
            sub_type=sub_type,
            rule=rule,
            priority=priority,
            timeout=timeout,
            block=block,
            block_ai_reply=block_ai_reply,
        )

    def meta_event(
        self,
        meta_event_type: str | None = None,
        *,
        sub_type: str | None = None,
        rule: Rule | None = None,
        priority: int = 0,
        timeout: float | None = None,
        block: bool = False,
        block_ai_reply: bool = False,
    ) -> Callable[[EventHandler], EventHandler]:
        """Subscribe to meta events, optionally of one ``meta_event_type``."""
        return self.event(
            post_type="meta_event",
            meta_event_type=meta_event_type,
            sub_type=sub_type,
            rule=rule,
            priority=priority,
            timeout=timeout,
            block=block,
            block_ai_reply=block_ai_reply,
        )

    def event(
        self,
        *,
        post_type: str | None = None,
        message_type: str | None = None,
        notice_type: str | None = None,
        request_type: str | None = None,
        meta_event_type: str | None = None,
        sub_type: str | None = None,
        rule: Rule | None = None,
        priority: int = 0,
        timeout: float | None = None,
        block: bool = False,
        block_ai_reply: bool = False,
    ) -> Callable[[EventHandler], EventHandler]:
        """Generic subscription; every filter left as ``None`` matches anything."""
        return self._decorator(
            post_type=post_type,
            timeout=timeout,
            block=block,
            block_ai_reply=block_ai_reply,
            message_type=message_type,
            notice_type=notice_type,
            request_type=request_type,
            meta_event_type=meta_event_type,
            sub_type=sub_type,
            rule=rule,
            priority=priority,
        )

    def _decorator(
        self,
        *,
        post_type: str | None,
        timeout: float | None,
        block: bool,
        block_ai_reply: bool,
        message_type: str | None,
        notice_type: str | None,
        request_type: str | None,
        meta_event_type: str | None,
        sub_type: str | None,
        rule: Rule | None,
        priority: int,
    ) -> Callable[[EventHandler], EventHandler]:
        """Build the decorator that subscribes a handler with these options."""
        options: dict[str, Any] = {
            "post_type": post_type,
            "message_type": message_type,
            "notice_type": notice_type,
            "request_type": request_type,
            "meta_event_type": meta_event_type,
            "sub_type": sub_type,
            "rule": rule,
            "priority": priority,
            "timeout": timeout,
            "block": block,
            "block_ai_reply": block_ai_reply,
        }

        def register(handler: EventHandler) -> EventHandler:
            subscription = self._hook_bus.subscribe(handler, **options, logger=self._logger)
            self._record_subscription(subscription)
            return handler

        return register
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _build_message_rule(
    *,
    rule: Rule | None,
    regex: str | re.Pattern[str] | None,
    keywords: str | Sequence[str] | None,
    contains: str | Sequence[str] | None,
    not_contains: str | Sequence[str] | None,
) -> Rule | None:
    """Combine text filters with an optional user rule into one async rule.

    When no text filter is given, *rule* is returned untouched. Otherwise the
    returned coroutine function checks, in order: the regex must match
    somewhere in the message text, at least one keyword must occur, every
    ``contains`` value must occur, and no ``not_contains`` value may occur.
    Only when all of those pass is *rule* (sync or async) consulted.
    """
    if all(option is None for option in (regex, keywords, contains, not_contains)):
        return rule

    any_of = _to_text_list(keywords)
    all_of = _to_text_list(contains)
    none_of = _to_text_list(not_contains)
    pattern = re.compile(regex) if isinstance(regex, str) else regex

    async def _matches(event: dict[str, Any]) -> bool:
        text = _message_text(event)
        if pattern is not None and pattern.search(text) is None:
            return False
        if any_of and all(word not in text for word in any_of):
            return False
        if all_of and any(word not in text for word in all_of):
            return False
        if any(word in text for word in none_of):
            return False
        if rule is None:
            return True
        verdict = rule(event)
        return bool(await verdict if inspect.isawaitable(verdict) else verdict)

    return _matches
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _to_text_list(value: str | Sequence[str] | None) -> list[str]:
|
|
246
|
+
if value is None:
|
|
247
|
+
return []
|
|
248
|
+
if isinstance(value, str):
|
|
249
|
+
return [value]
|
|
250
|
+
return [str(item) for item in value]
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _message_text(event: Any) -> str:
|
|
254
|
+
data = _event_to_dict(event)
|
|
255
|
+
raw_message = data.get("raw_message")
|
|
256
|
+
if raw_message is not None:
|
|
257
|
+
return str(raw_message)
|
|
258
|
+
|
|
259
|
+
message = data.get("message")
|
|
260
|
+
if isinstance(message, str):
|
|
261
|
+
return message
|
|
262
|
+
if isinstance(message, list):
|
|
263
|
+
return "".join(_segment_text(segment) for segment in message)
|
|
264
|
+
if message is None:
|
|
265
|
+
return ""
|
|
266
|
+
return str(message)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _event_to_dict(event: Any) -> dict[str, Any]:
|
|
270
|
+
if isinstance(event, dict):
|
|
271
|
+
return event
|
|
272
|
+
if hasattr(event, "model_dump"):
|
|
273
|
+
dumped = event.model_dump(mode="python")
|
|
274
|
+
return dumped if isinstance(dumped, dict) else {}
|
|
275
|
+
return {}
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _segment_text(segment: Any) -> str:
|
|
279
|
+
if hasattr(segment, "model_dump"):
|
|
280
|
+
segment = segment.model_dump(mode="python")
|
|
281
|
+
if not isinstance(segment, dict):
|
|
282
|
+
return str(segment)
|
|
283
|
+
if segment.get("type") != "text":
|
|
284
|
+
return ""
|
|
285
|
+
data = segment.get("data")
|
|
286
|
+
if isinstance(data, dict):
|
|
287
|
+
return str(data.get("text", ""))
|
|
288
|
+
return ""
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import inspect
|
|
5
|
+
import threading
|
|
6
|
+
from collections.abc import Awaitable, Callable
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Any, get_type_hints
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
|
|
12
|
+
from neobot_contracts.ports.logging import Logger, NullLogger
|
|
13
|
+
|
|
14
|
+
# Event filter predicate: receives the raw event dict and returns a truthy/falsy
# result, either directly or as an awaitable.
Rule = Callable[[dict[str, Any]], bool | Awaitable[bool]]
# Plugin event handler; called with a single (possibly model-coerced) event.
EventHandler = Callable[..., Any]
# Callback invoked with the raw event whenever a hook blocks the AI reply.
ReplyBlockRecorder = Callable[[Any], None]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class HookSubscription:
    """Handle for one registered hook; ``unsubscribe`` detaches it.

    The handle stays active if the removal callback raises, so the call can be
    retried; once removal succeeds, further calls are no-ops.
    """

    _unsubscribe: Callable[[], None]
    _active: bool = True

    def unsubscribe(self) -> None:
        """Remove the hook once; later calls do nothing."""
        if self._active:
            self._unsubscribe()
            self._active = False
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(slots=True)
class HookRegistration:
    """One subscribed handler together with the filters that select its events."""

    handler: EventHandler
    post_type: str | None = None
    message_type: str | None = None
    notice_type: str | None = None
    request_type: str | None = None
    meta_event_type: str | None = None
    sub_type: str | None = None
    rule: Rule | None = None
    priority: int = 0
    timeout: float | None = None
    block: bool = False
    block_ai_reply: bool = False
    logger: Logger = field(default_factory=NullLogger)
    # Pydantic model to validate events into before calling the handler, if any.
    event_model: type[BaseModel] | None = None

    def matches(self, event: dict[str, Any]) -> bool:
        """True when every set (truthy) type filter equals the event's field."""
        filters = (
            ("post_type", self.post_type),
            ("message_type", self.message_type),
            ("notice_type", self.notice_type),
            ("request_type", self.request_type),
            ("meta_event_type", self.meta_event_type),
            ("sub_type", self.sub_type),
        )
        return all(not expected or event.get(key) == expected for key, expected in filters)

    def coerce(self, event: dict[str, Any]) -> Any:
        """Validate *event* into ``event_model`` when one is set, else pass it through."""
        model = self.event_model
        return event if model is None else model.model_validate(event)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class PluginHookBus:
    """Registry of plugin event hooks plus the dispatcher that runs them.

    Hooks are kept sorted by descending ``priority``. ``list.sort`` is stable,
    so hooks with equal priority keep their registration order. The hook list
    is guarded by a re-entrant lock, and ``dispatch`` iterates over a snapshot
    taken under that lock.
    """

    def __init__(
        self,
        *,
        logger: Logger | None = None,
        record_ai_reply_block: ReplyBlockRecorder | None = None,
    ) -> None:
        self._logger = logger or NullLogger()
        self._record_ai_reply_block = record_ai_reply_block
        self._hooks: list[HookRegistration] = []
        self._lock = threading.RLock()

    def subscribe(
        self,
        handler: EventHandler,
        *,
        post_type: str | None = None,
        message_type: str | None = None,
        notice_type: str | None = None,
        request_type: str | None = None,
        meta_event_type: str | None = None,
        sub_type: str | None = None,
        rule: Rule | None = None,
        priority: int = 0,
        timeout: float | None = None,
        block: bool = False,
        block_ai_reply: bool = False,
        logger: Logger | None = None,
    ) -> HookSubscription:
        """Register *handler* with the given filters.

        Args:
            handler: callable invoked with the event. The event is validated
                into the pydantic model annotated on the handler's first
                parameter, when there is one.
            post_type .. sub_type: exact-match filters; ``None``/empty = any.
            rule: extra predicate on the raw event dict (sync or async).
            priority: higher values are dispatched first.
            timeout: seconds to wait for the handler (``None`` = no limit).
            block: after a successful call, consume the event and stop dispatch.
            block_ai_reply: after a successful call, ask the context to
                suppress the AI reply.
            logger: per-hook logger; defaults to the bus logger.

        Returns:
            A HookSubscription whose ``unsubscribe`` removes this hook.
        """
        registration = HookRegistration(
            handler=handler,
            post_type=post_type,
            message_type=message_type,
            notice_type=notice_type,
            request_type=request_type,
            meta_event_type=meta_event_type,
            sub_type=sub_type,
            rule=rule,
            priority=priority,
            timeout=timeout,
            block=block,
            block_ai_reply=block_ai_reply,
            logger=logger or self._logger,
            event_model=_extract_event_model(handler),
        )
        with self._lock:
            self._hooks.append(registration)
            self._hooks.sort(key=lambda item: item.priority, reverse=True)

        def _unsubscribe() -> None:
            with self._lock:
                self._hooks = [item for item in self._hooks if item is not registration]

        return HookSubscription(_unsubscribe)

    async def dispatch(self, ctx: Any) -> None:
        """Run the hooks matching ``ctx.raw_event`` in priority order.

        Does nothing unless ``ctx.raw_event`` is a dict, and stops as soon as
        ``ctx.consumed`` is truthy. Hooks whose rule is falsy or raises are
        skipped. If model coercion fails, the raw dict is passed instead.
        Handler failures and timeouts are logged without stopping dispatch.
        Only a successfully handled hook applies ``block_ai_reply`` / ``block``.
        """
        event = getattr(ctx, "raw_event", None)
        if not isinstance(event, dict):
            return
        with self._lock:
            hooks = [hook for hook in self._hooks if hook.matches(event)]

        for hook in hooks:
            if getattr(ctx, "consumed", False):
                break
            if hook.rule is not None:
                try:
                    rule_result = hook.rule(event)
                    if inspect.isawaitable(rule_result):
                        rule_result = await rule_result
                    if not rule_result:
                        continue
                except Exception as exc:
                    hook.logger.exception(f"插件事件规则执行失败 ({self._handler_label(hook.handler)}): {exc}")
                    continue

            try:
                payload = hook.coerce(event)
            except Exception as exc:
                hook.logger.exception(f"插件事件模型转换失败 ({self._handler_label(hook.handler)}): {exc}")
                payload = event

            handled = await self._call_hook(hook, payload)
            if not handled:
                continue

            if hook.block_ai_reply:
                block_ai_reply = getattr(ctx, "block_ai_reply", None)
                if callable(block_ai_reply):
                    block_ai_reply()
                if self._record_ai_reply_block is not None:
                    self._record_ai_reply_block(event)
            if hook.block:
                consume = getattr(ctx, "consume", None)
                if callable(consume):
                    consume()
                break

    async def _call_hook(self, hook: HookRegistration, event: Any) -> bool:
        """Invoke one hook's handler under its timeout; True on success."""
        try:
            call = _call_handler(hook.handler, event)
            if hook.timeout is None:
                await call
            else:
                await asyncio.wait_for(call, timeout=hook.timeout)
            return True
        except TimeoutError:
            hook.logger.warning(f"插件事件处理超时: {self._handler_label(hook.handler)}")
            return False
        except Exception as exc:
            hook.logger.exception(f"插件事件处理失败 ({self._handler_label(hook.handler)}): {exc}")
            return False

    @staticmethod
    def _handler_label(handler: Any) -> str:
        """Best-effort ``module.qualname`` of *handler* for log messages.

        Callable instances and ``functools.partial`` objects have no
        ``__qualname__``. For those, fall back to their type so that building
        a log message inside an ``except`` block can never raise.
        """
        module = getattr(handler, "__module__", None) or type(handler).__module__
        qualname = getattr(handler, "__qualname__", None) or type(handler).__qualname__
        return f"{module}.{qualname}"
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _extract_event_model(handler: EventHandler) -> type[BaseModel] | None:
|
|
184
|
+
try:
|
|
185
|
+
hints = get_type_hints(handler)
|
|
186
|
+
except Exception:
|
|
187
|
+
return None
|
|
188
|
+
params = list(inspect.signature(handler).parameters.values())
|
|
189
|
+
if not params:
|
|
190
|
+
return None
|
|
191
|
+
annotation = hints.get(params[0].name)
|
|
192
|
+
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
|
193
|
+
return annotation
|
|
194
|
+
return None
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
async def _call_handler(handler: EventHandler, event: Any) -> Any:
|
|
198
|
+
if inspect.iscoroutinefunction(handler):
|
|
199
|
+
return await handler(event)
|
|
200
|
+
return await asyncio.to_thread(handler, event)
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import importlib.util
|
|
5
|
+
import re
|
|
6
|
+
import sys
|
|
7
|
+
import tomllib
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from types import ModuleType
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from neobot_contracts.ports.logging import Logger, NullLogger
|
|
14
|
+
|
|
15
|
+
from neobot_modloader.plugin import FunctionPlugin
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass(frozen=True, slots=True)
class LoadedPlugin:
    """A plugin entry that was imported and instantiated successfully."""

    name: str  # validated plugin name (the plugin object's own ``name`` when it sets one)
    version: str  # plugin-reported version; single-file plugins fall back to "0.1.0"
    plugin: Any  # object produced by the loader's plugin factory for this entry
    plugin_dir: Path  # for single-file plugins, the directory containing the .py file
    config: dict[str, Any]  # plugin configuration; empty for single-file plugins
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True, slots=True)
class PluginLoadError:
    """Record of a plugin entry that could not be loaded (a result value, not an exception)."""

    name: str  # best-known name, e.g. the file stem of a single-file plugin
    plugin_dir: Path  # plugin root (if it is not a directory) or the failing file's parent
    error: Exception  # the exception raised while loading
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Outcome of loading one plugin entry: a loaded plugin or a captured failure.
PluginLoadResult = LoadedPlugin | PluginLoadError
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class FilesystemPluginLoader:
|
|
38
|
+
def __init__(self, logger: Logger | None = None) -> None:
    """Create a loader that reports progress and failures to *logger*.

    A missing (or falsy) logger is replaced by a ``NullLogger``.
    """
    self._logger = logger if logger else NullLogger()
|
|
40
|
+
|
|
41
|
+
def load_all(self, plugin_dir: Path) -> list[PluginLoadResult]:
|
|
42
|
+
plugin_dir = plugin_dir.resolve()
|
|
43
|
+
if not plugin_dir.exists():
|
|
44
|
+
self._logger.info(f"插件目录不存在,已按空目录处理: {plugin_dir}")
|
|
45
|
+
return []
|
|
46
|
+
if not plugin_dir.is_dir():
|
|
47
|
+
return [
|
|
48
|
+
PluginLoadError(
|
|
49
|
+
name=plugin_dir.name,
|
|
50
|
+
plugin_dir=plugin_dir,
|
|
51
|
+
error=NotADirectoryError(str(plugin_dir)),
|
|
52
|
+
)
|
|
53
|
+
]
|
|
54
|
+
|
|
55
|
+
results: list[PluginLoadResult] = []
|
|
56
|
+
for entry in sorted(plugin_dir.iterdir(), key=lambda item: item.name):
|
|
57
|
+
if entry.name.startswith("_"):
|
|
58
|
+
continue
|
|
59
|
+
if entry.is_file() and entry.suffix == ".py":
|
|
60
|
+
results.append(self._load_file(entry))
|
|
61
|
+
elif entry.is_dir() and (entry / "__init__.py").is_file():
|
|
62
|
+
results.append(self._load_package(entry))
|
|
63
|
+
return results
|
|
64
|
+
|
|
65
|
+
def _load_file(self, path: Path) -> PluginLoadResult:
|
|
66
|
+
name = path.stem
|
|
67
|
+
try:
|
|
68
|
+
self._validate_plugin_name(name)
|
|
69
|
+
module = self._import_module(path, name)
|
|
70
|
+
plugin = self._create_plugin(module, name=name, version="0.1.0")
|
|
71
|
+
plugin_name = str(getattr(plugin, "name", name) or name)
|
|
72
|
+
self._validate_plugin_name(plugin_name)
|
|
73
|
+
version = str(getattr(plugin, "version", "0.1.0") or "0.1.0")
|
|
74
|
+
return LoadedPlugin(
|
|
75
|
+
name=plugin_name,
|
|
76
|
+
version=version,
|
|
77
|
+
plugin=plugin,
|
|
78
|
+
plugin_dir=path.parent,
|
|
79
|
+
config={},
|
|
80
|
+
)
|
|
81
|
+
except Exception as exc:
|
|
82
|
+
self._logger.exception(f"插件加载失败 ({path}): {exc}")
|
|
83
|
+
return PluginLoadError(name=name, plugin_dir=path.parent, error=exc)
|
|
84
|
+
|
|
85
|
+
def _load_package(self, path: Path) -> PluginLoadResult:
|
|
86
|
+
manifest_name = path.name
|
|
87
|
+
try:
|
|
88
|
+
metadata = self._read_manifest(path / "plugin.toml")
|
|
89
|
+
name = str(metadata.get("name") or path.name)
|
|
90
|
+
self._validate_plugin_name(name)
|
|
91
|
+
version = str(metadata.get("version") or "0.1.0")
|
|
92
|
+
config = metadata.get("config") or {}
|
|
93
|
+
if not isinstance(config, dict):
|
|
94
|
+
raise TypeError("plugin.toml 的 [config] 必须是 table")
|
|
95
|
+
|
|
96
|
+
module = self._import_module(path / "__init__.py", name)
|
|
97
|
+
plugin = self._create_plugin(module, name=name, version=version)
|
|
98
|
+
plugin_name = str(getattr(plugin, "name", name) or name)
|
|
99
|
+
self._validate_plugin_name(plugin_name)
|
|
100
|
+
plugin_version = str(getattr(plugin, "version", version) or version)
|
|
101
|
+
return LoadedPlugin(
|
|
102
|
+
name=plugin_name,
|
|
103
|
+
version=plugin_version,
|
|
104
|
+
plugin=plugin,
|
|
105
|
+
plugin_dir=path,
|
|
106
|
+
config=dict(config),
|
|
107
|
+
)
|
|
108
|
+
except Exception as exc:
|
|
109
|
+
self._logger.exception(f"插件加载失败 ({path}): {exc}")
|
|
110
|
+
return PluginLoadError(name=manifest_name, plugin_dir=path, error=exc)
|
|
111
|
+
|
|
112
|
+
def _read_manifest(self, path: Path) -> dict[str, Any]:
|
|
113
|
+
if not path.is_file():
|
|
114
|
+
return {}
|
|
115
|
+
with path.open("rb") as file:
|
|
116
|
+
return tomllib.load(file)
|
|
117
|
+
|
|
118
|
+
def _import_module(self, path: Path, plugin_name: str) -> ModuleType:
|
|
119
|
+
module_name = self._module_name(path, plugin_name)
|
|
120
|
+
sys.modules.setdefault("neobot_user_plugins", ModuleType("neobot_user_plugins"))
|
|
121
|
+
spec = importlib.util.spec_from_file_location(module_name, path)
|
|
122
|
+
if spec is None or spec.loader is None:
|
|
123
|
+
raise ImportError(f"无法创建插件模块 spec: {path}")
|
|
124
|
+
module = importlib.util.module_from_spec(spec)
|
|
125
|
+
if path.name == "__init__.py":
|
|
126
|
+
module.__package__ = module_name
|
|
127
|
+
module.__path__ = [str(path.parent)]
|
|
128
|
+
sys.modules[module_name] = module
|
|
129
|
+
try:
|
|
130
|
+
spec.loader.exec_module(module)
|
|
131
|
+
except Exception:
|
|
132
|
+
sys.modules.pop(module_name, None)
|
|
133
|
+
raise
|
|
134
|
+
return module
|
|
135
|
+
|
|
136
|
+
def _create_plugin(self, module: ModuleType, *, name: str, version: str) -> Any:
|
|
137
|
+
setup = getattr(module, "setup", None)
|
|
138
|
+
if callable(setup):
|
|
139
|
+
return FunctionPlugin(name=name, version=version, setup=setup)
|
|
140
|
+
|
|
141
|
+
plugin = getattr(module, "plugin", None)
|
|
142
|
+
if plugin is not None:
|
|
143
|
+
return plugin
|
|
144
|
+
|
|
145
|
+
create_plugin = getattr(module, "create_plugin", None)
|
|
146
|
+
if callable(create_plugin):
|
|
147
|
+
return create_plugin()
|
|
148
|
+
|
|
149
|
+
raise ValueError("插件模块未导出 setup(ctx)、plugin 或 create_plugin()")
|
|
150
|
+
|
|
151
|
+
def _module_name(self, path: Path, plugin_name: str) -> str:
|
|
152
|
+
digest = hashlib.sha1(str(path.resolve()).encode("utf-8")).hexdigest()[:12]
|
|
153
|
+
safe_name = re.sub(r"\W", "_", plugin_name)
|
|
154
|
+
return f"neobot_user_plugins.{safe_name}_{digest}"
|
|
155
|
+
|
|
156
|
+
def _validate_plugin_name(self, name: str) -> None:
|
|
157
|
+
if not name:
|
|
158
|
+
raise ValueError("插件名不能为空")
|
|
159
|
+
if "/" in name or "\\" in name or name in {".", ".."}:
|
|
160
|
+
raise ValueError(f"非法插件名: {name!r}")
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from neobot_contracts.ports.logging import Logger, NullLogger
|
|
8
|
+
from neobot_contracts.ports.plugin import PluginState
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(slots=True)
class PluginRecord:
    """Mutable bookkeeping that DefaultPluginManager keeps for one registered plugin."""

    name: str  # registry key (the context's plugin_name)
    plugin: Any  # object whose on_load / on_start / on_stop hooks are invoked
    context: Any  # passed to on_load; its optional `agents` attribute is used for cleanup
    state: PluginState = PluginState.UNLOADED  # current lifecycle state
    # Objects exposing unsubscribe(); drained when the plugin stops or fails.
    subscriptions: list[Any] = field(default_factory=list)
    # (registered_name, agent) pairs; unregistered/closed when the plugin stops or fails.
    agent_registrations: list[tuple[str, Any]] = field(default_factory=list)
    # Last lifecycle-hook exception; reset to None after a successful load/start.
    error: Exception | None = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DefaultPluginManager:
    """Track registered plugins and drive their lifecycle hooks.

    Plugins are keyed by ``context.plugin_name``. ``on_load``, ``on_start``
    and ``on_stop`` may be plain functions or return awaitables. When a hook
    fails, and whenever a plugin is stopped, the subscriptions and agent
    registrations recorded for that plugin are torn down.
    """

    def __init__(self, logger: Logger | None = None) -> None:
        self._logger = logger or NullLogger()
        # plugin name -> record; insertion order drives load/start (and reverse stop) order.
        self._records: dict[str, PluginRecord] = {}

    def register(self, plugin: Any, context: Any) -> None:
        """Start tracking *plugin* under ``context.plugin_name`` (initially UNLOADED).

        Raises:
            ValueError: a plugin with the same name is already registered.
        """
        name = context.plugin_name
        if name in self._records:
            raise ValueError(f"插件已注册: {name}")
        self._records[name] = PluginRecord(name=name, plugin=plugin, context=context)

    def get_plugin(self, name: str) -> Any | None:
        """Return the registered plugin object, or ``None`` for unknown names."""
        record = self._records.get(name)
        if record is None:
            return None
        return record.plugin

    def get_state(self, name: str) -> PluginState:
        """Return the plugin's lifecycle state; unknown names report UNLOADED."""
        record = self._records.get(name)
        if record is None:
            return PluginState.UNLOADED
        return record.state

    def get_subscriptions(self, name: str) -> list[Any]:
        """Return a copy of the plugin's recorded subscriptions ([] if unknown)."""
        record = self._records.get(name)
        return [] if record is None else list(record.subscriptions)

    def record_subscription(self, name: str, subscription: Any) -> None:
        """Remember *subscription* so it is unsubscribed when the plugin stops or fails.

        Raises:
            KeyError: no plugin is registered under *name*.
        """
        self._require_record(name).subscriptions.append(subscription)

    def record_agent_registration(self, name: str, registered_name: str, agent: Any) -> None:
        """Remember an agent registration so it is undone when the plugin stops or fails.

        Raises:
            KeyError: no plugin is registered under *name*.
        """
        self._require_record(name).agent_registrations.append((registered_name, agent))

    def _require_record(self, name: str) -> PluginRecord:
        """Return the record for *name*, raising KeyError if it is not registered."""
        record = self._records.get(name)
        if record is None:
            raise KeyError(f"插件未注册: {name}")
        return record

    async def load_plugin(self, name: str) -> None:
        """Run ``on_load(context)`` for a plugin that is UNLOADED or STOPPED.

        Plugins in any other state are left untouched. On failure the record
        moves to ERROR and its resources are torn down.
        """
        record = self._records[name]
        if record.state not in {PluginState.UNLOADED, PluginState.STOPPED}:
            return
        try:
            await self._maybe_await(record.plugin.on_load(record.context))
        except Exception as exc:
            record.error = exc
            record.state = PluginState.ERROR
            self._logger.exception(f"插件加载失败 ({name}): {exc}")
            await self._teardown(record)
        else:
            record.state = PluginState.LOADED
            record.error = None

    async def start_plugin(self, name: str) -> None:
        """Run ``on_start()`` for a LOADED plugin, reloading it first if STOPPED.

        On success the record becomes RUNNING. On failure it moves to ERROR
        and its resources are torn down.
        """
        record = self._records[name]
        if record.state is PluginState.STOPPED:
            await self.load_plugin(name)
        if record.state is not PluginState.LOADED:
            return
        try:
            await self._maybe_await(record.plugin.on_start())
        except Exception as exc:
            record.error = exc
            record.state = PluginState.ERROR
            self._logger.exception(f"插件启动失败 ({name}): {exc}")
            await self._teardown(record)
        else:
            record.state = PluginState.RUNNING
            record.error = None

    async def stop_plugin(self, name: str) -> None:
        """Stop a plugin and release everything recorded for it.

        ``on_stop()`` is only called for LOADED/RUNNING plugins; its failure
        is logged and stored in ``record.error`` but does not prevent cleanup.
        Plugins already in ERROR keep that state after cleanup; all others
        end up STOPPED.
        """
        record = self._records[name]
        previous_state = record.state
        if previous_state in {PluginState.UNLOADED, PluginState.STOPPED}:
            return
        if previous_state in {PluginState.LOADED, PluginState.RUNNING}:
            try:
                await self._maybe_await(record.plugin.on_stop())
            except Exception as exc:
                record.error = exc
                self._logger.exception(f"插件停止失败 ({name}): {exc}")
        await self._teardown(record)
        if previous_state is not PluginState.ERROR:
            record.state = PluginState.STOPPED

    async def load_all(self) -> None:
        """Load every registered plugin in registration order."""
        for name in [*self._records]:
            await self.load_plugin(name)

    async def start_all(self) -> None:
        """Start every registered plugin in registration order."""
        for name in [*self._records]:
            await self.start_plugin(name)

    async def stop_all(self) -> None:
        """Stop every registered plugin in reverse registration order."""
        for name in [*self._records][::-1]:
            await self.stop_plugin(name)

    async def _maybe_await(self, value: Any) -> Any:
        """Await *value* if it is awaitable, otherwise return it unchanged."""
        return await value if inspect.isawaitable(value) else value

    async def _teardown(self, record: PluginRecord) -> None:
        """Drop the plugin's subscriptions, then unregister/close its agents."""
        self._unsubscribe_all(record)
        await self._cleanup_agents(record)

    def _unsubscribe_all(self, record: PluginRecord) -> None:
        """Detach and unsubscribe every recorded subscription, logging failures."""
        # Swap in a fresh list first so the record is clean even if an
        # unsubscribe call raises.
        pending, record.subscriptions = record.subscriptions, []
        for subscription in pending:
            try:
                subscription.unsubscribe()
            except Exception as exc:
                self._logger.exception(f"插件订阅清理失败 ({record.name}): {exc}")

    async def _cleanup_agents(self, record: PluginRecord) -> None:
        """Unregister and close every agent the plugin registered, logging failures.

        Unregistration goes through ``record.context.agents.unregister`` when
        that exists. ``agent.close()`` is called (and awaited if needed) when
        the agent defines it. A failure in one step does not skip the other.
        """
        pending, record.agent_registrations = record.agent_registrations, []
        registrar = getattr(record.context, "agents", None)
        for registered_name, agent in pending:
            try:
                unregister = getattr(registrar, "unregister", None)
                if callable(unregister):
                    unregister(registered_name)
            except Exception as exc:
                self._logger.exception(f"插件 Agent 注销失败 ({record.name}/{registered_name}): {exc}")
            try:
                close = getattr(agent, "close", None)
                if callable(close):
                    await self._maybe_await(close())
            except Exception as exc:
                self._logger.exception(f"插件 Agent 关闭失败 ({record.name}/{registered_name}): {exc}")
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from neobot_contracts.ports.plugin import PluginContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BasePlugin:
    """Convenience base class whose lifecycle hooks are all no-ops.

    Subclasses override only the hooks they need and may set ``name`` and
    ``version`` as class attributes.
    """

    name: str = ""
    version: str = "0.1.0"

    async def on_load(self, ctx: PluginContext) -> None:
        """Lifecycle hook invoked with the plugin context on load; does nothing by default."""

    async def on_start(self) -> None:
        """Lifecycle hook invoked when the plugin is started; does nothing by default."""

    async def on_stop(self) -> None:
        """Lifecycle hook invoked when the plugin is stopped; does nothing by default."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class FunctionPlugin(BasePlugin):
    """Adapt a plain ``setup(ctx)`` callable to the plugin lifecycle.

    ``setup`` is called from ``on_load`` with the plugin context. It may be a
    regular function or return an awaitable, in which case the awaitable is
    awaited. ``on_start`` / ``on_stop`` keep the inherited no-op behaviour.
    """

    def __init__(
        self,
        *,
        name: str,
        setup: Callable[[PluginContext], Any],
        version: str = "0.1.0",
    ) -> None:
        self._setup = setup
        self.version = version
        self.name = name

    async def on_load(self, ctx: PluginContext) -> None:
        outcome = self._setup(ctx)
        if not inspect.isawaitable(outcome):
            return
        await outcome
|
|
File without changes
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from neobot_contracts.ports.logging import Logger, NullLogger
|
|
7
|
+
|
|
8
|
+
from neobot_modloader.context import PluginContext
|
|
9
|
+
from neobot_modloader.hooks import PluginHookBus
|
|
10
|
+
from neobot_modloader.loader import FilesystemPluginLoader, LoadedPlugin, PluginLoadError
|
|
11
|
+
from neobot_modloader.manager import DefaultPluginManager
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PluginRuntime:
    """Wire the filesystem loader, plugin manager and hook bus together.

    ``load_all`` discovers plugins under ``plugin_dir`` and registers them
    with the manager. ``load_registered``, ``start_all`` and ``stop_all``
    drive their lifecycle. Collaborators that are not injected are built
    with loggers obtained from ``logger_factory``.
    """

    def __init__(
        self,
        *,
        plugin_dir: Path,
        data_dir: Path,
        adapter: Any,
        logger_factory: Any,
        loader: FilesystemPluginLoader | None = None,
        manager: DefaultPluginManager | None = None,
        logger: Logger | None = None,
        agent_registry: Any | None = None,
        hook_bus: PluginHookBus | None = None,
        record_ai_reply_block: Any | None = None,
    ) -> None:
        self.plugin_dir = plugin_dir.resolve()
        self.data_dir = data_dir.resolve()
        self.adapter = adapter
        self.logger_factory = logger_factory
        self.agent_registry = agent_registry
        self.record_ai_reply_block = record_ai_reply_block
        # Use explicit `is None` checks: an injected collaborator that happens
        # to be falsy (e.g. a container-like object defining __len__) must not
        # be silently replaced by a fresh default instance.
        self.logger = logger if logger is not None else self._get_logger("modloader.runtime")
        self.hook_bus = (
            hook_bus
            if hook_bus is not None
            else PluginHookBus(
                logger=self._get_logger("modloader.hooks"),
                record_ai_reply_block=record_ai_reply_block,
            )
        )
        self.loader = (
            loader
            if loader is not None
            else FilesystemPluginLoader(logger=self._get_logger("modloader.loader"))
        )
        self.manager = (
            manager
            if manager is not None
            else DefaultPluginManager(logger=self._get_logger("modloader.manager"))
        )

    def load_all(self) -> None:
        """Scan ``plugin_dir`` and register every successfully imported plugin.

        Both directories are created if missing. Plugins that fail to import
        or to register are logged and counted as errors in the summary line.
        """
        self.plugin_dir.mkdir(parents=True, exist_ok=True)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.logger.info(f"插件目录: {self.plugin_dir}")

        loaded_count = 0
        error_count = 0
        for result in self.loader.load_all(self.plugin_dir):
            if isinstance(result, PluginLoadError):
                error_count += 1
                self.logger.error(f"插件加载跳过 ({result.name}): {result.error}")
                continue
            # Registration can fail too (e.g. duplicate plugin name); such
            # plugins must not be reported as loaded.
            if self._register(result):
                loaded_count += 1
            else:
                error_count += 1
        self.logger.info(f"插件扫描完成: loaded={loaded_count}, errors={error_count}")

    async def load_registered(self) -> None:
        """Run the load hook of every registered plugin."""
        await self.manager.load_all()

    async def start_all(self) -> None:
        """Load (if needed) and then start every registered plugin."""
        await self.manager.load_all()
        await self.manager.start_all()

    async def stop_all(self) -> None:
        """Stop every registered plugin."""
        await self.manager.stop_all()

    def _register(self, loaded: LoadedPlugin) -> bool:
        """Build a PluginContext for *loaded* and register it with the manager.

        Returns:
            True if the plugin was registered. False if building the context
            or registering it raised; the error is logged, not propagated, so
            one bad plugin cannot abort the whole scan.
        """
        try:
            logger = self._get_logger(f"plugin.{loaded.name}")
            context = PluginContext(
                plugin_name=loaded.name,
                plugin_dir=loaded.plugin_dir,
                data_dir=self.data_dir / loaded.name,
                config=loaded.config,
                logger=logger,
                adapter=self.adapter,
                hook_bus=self.hook_bus,
                # Bind the plugin name via a default argument so each lambda
                # keeps its own plugin's name (avoids late binding).
                record_subscription=lambda subscription, name=loaded.name: self.manager.record_subscription(
                    name, subscription
                ),
                agent_registry=self.agent_registry,
                record_agent_registration=lambda registered_name, agent, name=loaded.name: self.manager.record_agent_registration(
                    name, registered_name, agent
                ),
            )
            self.manager.register(loaded.plugin, context)
        except Exception as exc:
            self.logger.exception(f"插件注册失败 ({loaded.name}): {exc}")
            return False
        return True

    def _get_logger(self, name: str) -> Logger:
        """Return ``logger_factory.get_logger(name)``, or a NullLogger if unavailable."""
        get_logger = getattr(self.logger_factory, "get_logger", None)
        if callable(get_logger):
            return get_logger(name)
        return NullLogger()
|