squidbot 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,416 @@
1
+ """
2
+ Plugin Hooks System
3
+
4
+ Provides lifecycle hooks for plugins to intercept and modify behavior.
5
+ Inspired by OpenClaw's hook architecture.
6
+ """
7
+
8
+ import asyncio
9
+ import logging
10
+ from dataclasses import dataclass, field
11
+ from enum import Enum
12
+ from typing import Any, Awaitable, Callable
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class HookName(str, Enum):
    """Names of the lifecycle hooks a plugin can subscribe to.

    The enum mixes in ``str``, so every member compares equal to its string
    value (for example ``HookName.AGENT_END == "agent_end"``).
    """

    # --- Agent lifecycle ---
    BEFORE_AGENT_START = "before_agent_start"
    AGENT_END = "agent_end"

    # --- Messages ---
    MESSAGE_RECEIVED = "message_received"
    MESSAGE_SENDING = "message_sending"
    MESSAGE_SENT = "message_sent"

    # --- Tools ---
    BEFORE_TOOL_CALL = "before_tool_call"
    AFTER_TOOL_CALL = "after_tool_call"

    # --- Sessions ---
    SESSION_START = "session_start"
    SESSION_END = "session_end"
36
+
37
+
38
+ # ============================================================
39
+ # Hook Event/Context/Result Types
40
+ # ============================================================
41
+
42
+
43
+ @dataclass
44
+ class BeforeAgentStartEvent:
45
+ """Event for before_agent_start hook."""
46
+
47
+ prompt: str
48
+ messages: list[dict] = field(default_factory=list)
49
+ session_id: str | None = None
50
+
51
+
52
+ @dataclass
53
+ class BeforeAgentStartResult:
54
+ """Result from before_agent_start hook."""
55
+
56
+ system_prompt: str | None = None # Override system prompt
57
+ prepend_context: str | None = None # Prepend to context
58
+
59
+
60
+ @dataclass
61
+ class AgentEndEvent:
62
+ """Event for agent_end hook."""
63
+
64
+ messages: list[dict]
65
+ success: bool
66
+ response: str = ""
67
+ error: str | None = None
68
+ duration_ms: float | None = None
69
+
70
+
71
@dataclass
class MessageReceivedEvent:
    """Payload handed to ``message_received`` hook handlers."""

    # Required string fields.
    sender: str
    content: str
    channel: str
    # Free-form extra data; each instance gets its own empty dict by default.
    metadata: dict = field(default_factory=dict)
79
+
80
+
81
@dataclass
class MessageSendingEvent:
    """Payload handed to ``message_sending`` hook handlers."""

    # Required string fields.
    recipient: str
    content: str
    channel: str
    # Free-form extra data; each instance gets its own empty dict by default.
    metadata: dict = field(default_factory=dict)
89
+
90
+
91
+ @dataclass
92
+ class MessageSendingResult:
93
+ """Result from message_sending hook."""
94
+
95
+ content: str | None = None # Modified content
96
+ cancel: bool = False # Prevent sending
97
+
98
+
99
+ @dataclass
100
+ class MessageSentEvent:
101
+ """Event for message_sent hook."""
102
+
103
+ recipient: str
104
+ content: str
105
+ channel: str
106
+ success: bool
107
+ error: str | None = None
108
+
109
+
110
@dataclass
class BeforeToolCallEvent:
    """Payload handed to ``before_tool_call`` hook handlers."""

    # Required: name of the tool.
    tool_name: str
    # Tool parameters; each instance gets its own empty dict by default.
    params: dict = field(default_factory=dict)
116
+
117
+
118
+ @dataclass
119
+ class BeforeToolCallResult:
120
+ """Result from before_tool_call hook."""
121
+
122
+ params: dict | None = None # Modified params
123
+ block: bool = False # Prevent tool call
124
+ block_reason: str | None = None
125
+
126
+
127
+ @dataclass
128
+ class AfterToolCallEvent:
129
+ """Event for after_tool_call hook."""
130
+
131
+ tool_name: str
132
+ params: dict
133
+ result: Any = None
134
+ error: str | None = None
135
+ duration_ms: float | None = None
136
+
137
+
138
+ @dataclass
139
+ class SessionStartEvent:
140
+ """Event for session_start hook."""
141
+
142
+ session_id: str
143
+ channel: str
144
+ user_id: str | None = None
145
+
146
+
147
+ @dataclass
148
+ class SessionEndEvent:
149
+ """Event for session_end hook."""
150
+
151
+ session_id: str
152
+ message_count: int
153
+ duration_ms: float | None = None
154
+
155
+
156
+ @dataclass
157
+ class HookContext:
158
+ """Context passed to all hooks."""
159
+
160
+ plugin_id: str
161
+ session_id: str | None = None
162
+ channel: str | None = None
163
+ metadata: dict = field(default_factory=dict)
164
+
165
+
166
+ # ============================================================
167
+ # Hook Registration
168
+ # ============================================================
169
+
170
+
171
@dataclass
class HookRegistration:
    """One handler registered by one plugin for one hook."""

    # Identifier of the plugin that owns the handler.
    plugin_id: str
    # Hook the handler is attached to.
    hook_name: HookName
    # Callable to invoke; may return a plain value or an awaitable.
    handler: Callable[..., Awaitable[Any] | Any]
    # Ordering key: a higher priority runs first. Defaults to 0.
    priority: int = 0
179
+
180
+
181
class HookRegistry:
    """Registry for all hook handlers.

    Registrations are stored in one list in insertion order. Priority
    ordering is applied on lookup (see ``get_hooks``), not on insertion.
    """

    def __init__(self):
        self._hooks: list[HookRegistration] = []

    def register(
        self,
        plugin_id: str,
        hook_name: HookName,
        handler: Callable,
        priority: int = 0,
    ) -> None:
        """Register a hook handler.

        Args:
            plugin_id: Identifier of the plugin that owns the handler.
            hook_name: Hook the handler should be attached to.
            handler: Callable invoked when the hook runs.
            priority: Ordering key; higher values run first.
        """
        self._hooks.append(
            HookRegistration(
                plugin_id=plugin_id,
                hook_name=hook_name,
                handler=handler,
                priority=priority,
            )
        )
        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        logger.debug(
            "Registered hook %s from plugin %s (priority=%s)",
            hook_name.value,
            plugin_id,
            priority,
        )

    def unregister(self, plugin_id: str) -> int:
        """Unregister all hooks for a plugin. Returns count removed."""
        before = len(self._hooks)
        self._hooks = [h for h in self._hooks if h.plugin_id != plugin_id]
        return before - len(self._hooks)

    def get_hooks(self, hook_name: HookName) -> list[HookRegistration]:
        """Get all hooks for a given hook name, sorted by priority (highest first).

        The sort is stable, so hooks with equal priority keep their
        registration order.
        """
        return sorted(
            (h for h in self._hooks if h.hook_name == hook_name),
            key=lambda h: h.priority,
            reverse=True,
        )

    def has_hooks(self, hook_name: HookName) -> bool:
        """Check if any hooks are registered for a hook name."""
        return any(h.hook_name == hook_name for h in self._hooks)

    def get_hook_count(self, hook_name: HookName) -> int:
        """Get count of hooks for a hook name."""
        return sum(1 for h in self._hooks if h.hook_name == hook_name)

    def get_hook_count_for_plugin(self, plugin_id: str) -> int:
        """Get count of hooks registered by a plugin, across all hook names.

        The plugin loader probes for this method with ``hasattr`` and records
        a hook count of 0 when it is missing.
        """
        return sum(1 for h in self._hooks if h.plugin_id == plugin_id)

    def list_all(self) -> list[dict]:
        """List all registered hooks.

        Returns:
            One dict per registration, in insertion order, with the keys
            ``plugin_id``, ``hook_name`` (the enum's string value) and
            ``priority``.
        """
        return [
            {
                "plugin_id": h.plugin_id,
                "hook_name": h.hook_name.value,
                "priority": h.priority,
            }
            for h in self._hooks
        ]
239
+
240
+
241
+ # ============================================================
242
+ # Hook Runner
243
+ # ============================================================
244
+
245
+
246
class HookRunner:
    """Executes hooks with proper ordering and error handling.

    Two execution styles are used:

    * "void" hooks: all handlers are started together and their return
      values are discarded.
    * "modifying" hooks: handlers run one at a time, highest priority first,
      and their non-None results are folded into a single result.

    When ``catch_errors`` is true (the default), an exception raised by a
    handler is logged and treated as a ``None`` result; otherwise it
    propagates to the caller.
    """

    def __init__(self, registry: HookRegistry, catch_errors: bool = True):
        self._registry = registry
        self._catch_errors = catch_errors

    @staticmethod
    def _prefer(new: Any, old: Any) -> Any:
        """Return ``new`` unless it is None, in which case return ``old``.

        Used when merging results so that falsy-but-meaningful overrides
        (``""``, ``{}``) from a later hook are not discarded.
        """
        return old if new is None else new

    async def _run_handler(
        self, hook: HookRegistration, event: Any, ctx: HookContext
    ) -> Any:
        """Run a single hook handler.

        Returns the handler's result (awaited when it is awaitable), or None
        when the handler raised and errors are being caught.
        """
        # Local import keeps this block self-contained.
        import inspect

        try:
            result = hook.handler(event, ctx)
            # isawaitable also covers Futures, Tasks and objects defining
            # __await__, not just coroutine objects.
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            if self._catch_errors:
                logger.error(
                    f"Hook {hook.hook_name.value} from {hook.plugin_id} failed: {e}",
                    exc_info=True,  # keep the traceback for diagnosis
                )
                return None
            raise

    async def _run_void_hook(
        self, hook_name: HookName, event: Any, ctx: HookContext
    ) -> None:
        """Run a void hook (fire-and-forget, parallel execution).

        All handlers for ``hook_name`` are awaited together; their return
        values are ignored.
        """
        hooks = self._registry.get_hooks(hook_name)
        if not hooks:
            return

        tasks = [self._run_handler(hook, event, ctx) for hook in hooks]
        await asyncio.gather(*tasks, return_exceptions=self._catch_errors)

    async def _run_modifying_hook(
        self,
        hook_name: HookName,
        event: Any,
        ctx: HookContext,
        merge_fn: Callable[[Any, Any], Any] | None = None,
    ) -> Any:
        """Run a modifying hook (sequential, ordered by priority).

        Every handler receives the same ``event``. Handlers returning None
        are skipped. The first non-None result becomes the accumulator; each
        later one is combined via ``merge_fn(accumulator, new)`` or, without
        a ``merge_fn``, simply replaces it.

        Returns:
            The accumulated result, or None when no handler produced one.
        """
        hooks = self._registry.get_hooks(hook_name)
        if not hooks:
            return None

        result = None
        for hook in hooks:
            handler_result = await self._run_handler(hook, event, ctx)
            if handler_result is None:
                continue
            if result is not None and merge_fn:
                result = merge_fn(result, handler_result)
            else:
                result = handler_result

        return result

    # ============================================================
    # Agent Hooks
    # ============================================================

    async def run_before_agent_start(
        self, event: BeforeAgentStartEvent, ctx: HookContext
    ) -> BeforeAgentStartResult | None:
        """Run before_agent_start hooks.

        When several hooks return results, the latest non-None
        ``system_prompt`` wins and non-empty ``prepend_context`` values are
        joined with a blank line, in execution order.
        """
        prefer = self._prefer

        def merge(acc: BeforeAgentStartResult, next_: BeforeAgentStartResult):
            return BeforeAgentStartResult(
                system_prompt=prefer(next_.system_prompt, acc.system_prompt),
                prepend_context=(
                    f"{acc.prepend_context}\n\n{next_.prepend_context}"
                    if acc.prepend_context and next_.prepend_context
                    else (next_.prepend_context or acc.prepend_context)
                ),
            )

        return await self._run_modifying_hook(
            HookName.BEFORE_AGENT_START, event, ctx, merge
        )

    async def run_agent_end(self, event: AgentEndEvent, ctx: HookContext) -> None:
        """Run agent_end hooks."""
        await self._run_void_hook(HookName.AGENT_END, event, ctx)

    # ============================================================
    # Message Hooks
    # ============================================================

    async def run_message_received(
        self, event: MessageReceivedEvent, ctx: HookContext
    ) -> None:
        """Run message_received hooks."""
        await self._run_void_hook(HookName.MESSAGE_RECEIVED, event, ctx)

    async def run_message_sending(
        self, event: MessageSendingEvent, ctx: HookContext
    ) -> MessageSendingResult | None:
        """Run message_sending hooks.

        When several hooks return results, the latest non-None ``content``
        wins (an empty string counts as a modification) and ``cancel`` is
        true if any hook set it.
        """
        prefer = self._prefer

        def merge(acc: MessageSendingResult, next_: MessageSendingResult):
            return MessageSendingResult(
                content=prefer(next_.content, acc.content),
                cancel=next_.cancel or acc.cancel,
            )

        return await self._run_modifying_hook(
            HookName.MESSAGE_SENDING, event, ctx, merge
        )

    async def run_message_sent(self, event: MessageSentEvent, ctx: HookContext) -> None:
        """Run message_sent hooks."""
        await self._run_void_hook(HookName.MESSAGE_SENT, event, ctx)

    # ============================================================
    # Tool Hooks
    # ============================================================

    async def run_before_tool_call(
        self, event: BeforeToolCallEvent, ctx: HookContext
    ) -> BeforeToolCallResult | None:
        """Run before_tool_call hooks.

        When several hooks return results, the latest non-None ``params``
        wins (an empty dict counts as a modification), ``block`` is true if
        any hook set it, and the latest non-empty ``block_reason`` is kept.
        """
        prefer = self._prefer

        def merge(acc: BeforeToolCallResult, next_: BeforeToolCallResult):
            return BeforeToolCallResult(
                params=prefer(next_.params, acc.params),
                block=next_.block or acc.block,
                block_reason=next_.block_reason or acc.block_reason,
            )

        return await self._run_modifying_hook(
            HookName.BEFORE_TOOL_CALL, event, ctx, merge
        )

    async def run_after_tool_call(
        self, event: AfterToolCallEvent, ctx: HookContext
    ) -> None:
        """Run after_tool_call hooks."""
        await self._run_void_hook(HookName.AFTER_TOOL_CALL, event, ctx)

    # ============================================================
    # Session Hooks
    # ============================================================

    async def run_session_start(
        self, event: SessionStartEvent, ctx: HookContext
    ) -> None:
        """Run session_start hooks."""
        await self._run_void_hook(HookName.SESSION_START, event, ctx)

    async def run_session_end(self, event: SessionEndEvent, ctx: HookContext) -> None:
        """Run session_end hooks."""
        await self._run_void_hook(HookName.SESSION_END, event, ctx)
399
+
400
+
401
+ # ============================================================
402
+ # Global Instances
403
+ # ============================================================
404
+
405
# Module-level singletons: one registry and one runner bound to it.
_hook_registry = HookRegistry()
_hook_runner = HookRunner(_hook_registry)


def get_hook_registry() -> HookRegistry:
    """Return the module-wide ``HookRegistry`` instance."""
    return _hook_registry


def get_hook_runner() -> HookRunner:
    """Return the module-wide ``HookRunner`` bound to the global registry."""
    return _hook_runner
@@ -0,0 +1,248 @@
1
+ """
2
+ Plugin Loader and Registry
3
+
4
+ Handles plugin discovery, loading, and lifecycle management.
5
+ """
6
+
7
+ import importlib
8
+ import importlib.util
9
+ import logging
10
+ import pkgutil
11
+ from dataclasses import dataclass, field
12
+ from pathlib import Path
13
+ from typing import TYPE_CHECKING
14
+
15
+ if TYPE_CHECKING:
16
+ from .base import Plugin
17
+
18
+ from ..tools.base import Tool
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclass
24
+ class PluginInfo:
25
+ """Information about a loaded plugin."""
26
+
27
+ plugin: "Plugin"
28
+ enabled: bool = True
29
+ load_error: str | None = None
30
+ hook_count: int = 0
31
+
32
+
33
class PluginRegistry:
    """Central registry for all loaded plugins.

    Plugins are stored by manifest id. The combined tool list of all enabled
    plugins is cached and invalidated whenever the set of registered or
    enabled plugins changes.
    """

    def __init__(self):
        self._plugins: dict[str, PluginInfo] = {}
        self._tool_cache: list[Tool] | None = None

    def register(self, plugin: "Plugin") -> bool:
        """Register a plugin. Returns True if successful.

        The plugin's hooks are registered first, then the plugin is
        activated. If any step raises, the error is logged, hooks already
        registered for the plugin are removed again, and False is returned.
        A plugin whose id is already registered is skipped (returns False).
        """
        from .base import PluginApi
        from .hooks import get_hook_registry

        plugin_id = None
        hook_registry = None  # set only once hooks may have been registered

        try:
            manifest = plugin.manifest
            plugin_id = manifest.id

            if plugin_id in self._plugins:
                logger.warning(f"Plugin '{plugin_id}' already registered, skipping")
                return False

            # Create plugin API for hook registration
            hook_registry = get_hook_registry()
            api = PluginApi(plugin_id, hook_registry)

            # Register hooks first
            plugin.register_hooks(api)
            if hasattr(hook_registry, "get_hook_count_for_plugin"):
                hook_count = hook_registry.get_hook_count_for_plugin(plugin_id)
            else:
                # Fall back to counting via list_all() instead of reporting 0.
                hook_count = sum(
                    1 for h in hook_registry.list_all() if h["plugin_id"] == plugin_id
                )

            # Then activate
            plugin.activate()

            self._plugins[plugin_id] = PluginInfo(
                plugin=plugin,
                enabled=True,
                hook_count=hook_count,
            )
            self._tool_cache = None  # Invalidate cache

            logger.info(f"Registered plugin: {manifest.name} v{manifest.version}")
            return True

        except Exception as e:
            # Roll back hooks so a plugin that failed to register or activate
            # does not leave live handlers behind in the hook registry.
            if hook_registry is not None:
                try:
                    hook_registry.unregister(plugin_id)
                except Exception as cleanup_error:
                    logger.warning(
                        f"Error removing hooks for plugin '{plugin_id}': {cleanup_error}"
                    )
            logger.error(f"Failed to register plugin: {e}")
            return False

    def unregister(self, plugin_id: str) -> bool:
        """Unregister a plugin by ID.

        Removes the plugin's hooks, calls its ``deactivate()`` (errors are
        logged, not raised) and drops it from the registry. Returns False
        when the id is unknown.
        """
        from .hooks import get_hook_registry

        if plugin_id not in self._plugins:
            return False

        info = self._plugins[plugin_id]

        # Unregister hooks
        hook_registry = get_hook_registry()
        hook_registry.unregister(plugin_id)

        # Deactivate plugin
        try:
            info.plugin.deactivate()
        except Exception as e:
            logger.warning(f"Error deactivating plugin '{plugin_id}': {e}")

        del self._plugins[plugin_id]
        self._tool_cache = None
        return True

    def get_plugin(self, plugin_id: str) -> "Plugin | None":
        """Get a plugin by ID (enabled or not), or None if unknown."""
        info = self._plugins.get(plugin_id)
        return info.plugin if info else None

    def get_all_plugins(self) -> list["Plugin"]:
        """Get all registered plugins that are currently enabled."""
        return [info.plugin for info in self._plugins.values() if info.enabled]

    def get_all_tools(self) -> list[Tool]:
        """Get all tools from all enabled plugins.

        A plugin whose ``get_tools()`` raises is logged and skipped. The
        returned list is a copy, so callers may mutate it without corrupting
        the internal cache.
        """
        if self._tool_cache is None:
            tools = []
            for info in self._plugins.values():
                if not info.enabled:
                    continue
                try:
                    tools.extend(info.plugin.get_tools())
                except Exception as e:
                    logger.error(
                        f"Error getting tools from plugin '{info.plugin.manifest.id}': {e}"
                    )
            self._tool_cache = tools

        return list(self._tool_cache)

    def enable_plugin(self, plugin_id: str) -> bool:
        """Enable a plugin. Returns False when the id is unknown."""
        if plugin_id not in self._plugins:
            return False
        self._plugins[plugin_id].enabled = True
        self._tool_cache = None
        return True

    def disable_plugin(self, plugin_id: str) -> bool:
        """Disable a plugin. Returns False when the id is unknown.

        Only the ``enabled`` flag and the tool cache are touched here; the
        plugin's hooks are not removed from the hook registry.
        """
        if plugin_id not in self._plugins:
            return False
        self._plugins[plugin_id].enabled = False
        self._tool_cache = None
        return True

    def list_plugins(self) -> list[dict]:
        """List all plugins with their status.

        Returns:
            One dict per registered plugin (enabled or not) with the keys
            ``id``, ``name``, ``version``, ``description``, ``enabled``,
            ``tools`` (tool names) and ``hooks`` (hook names). If a plugin's
            ``get_tools()`` raises, the error is logged and its ``tools``
            entry is an empty list.
        """
        from .hooks import get_hook_registry

        # Snapshot the hook list once and group hook names by plugin.
        hooks_by_plugin: dict[str, list[str]] = {}
        for h in get_hook_registry().list_all():
            hooks_by_plugin.setdefault(h["plugin_id"], []).append(h["hook_name"])

        result = []
        for plugin_id, info in self._plugins.items():
            manifest = info.plugin.manifest

            try:
                tool_names = [t.name for t in info.plugin.get_tools()]
            except Exception as e:
                logger.error(f"Error getting tools from plugin '{plugin_id}': {e}")
                tool_names = []

            result.append(
                {
                    "id": plugin_id,
                    "name": manifest.name,
                    "version": manifest.version,
                    "description": manifest.description,
                    "enabled": info.enabled,
                    "tools": tool_names,
                    "hooks": hooks_by_plugin.get(plugin_id, []),
                }
            )
        return result
173
+
174
+
175
# Module-level singleton shared by the loader functions in this module.
_registry = PluginRegistry()


def get_registry() -> PluginRegistry:
    """Return the module-wide ``PluginRegistry`` instance."""
    return _registry
182
+
183
+
184
def load_builtin_plugins() -> None:
    """Load all built-in plugins from the plugins directory.

    Every module found next to this file (except the infrastructure modules
    listed in ``excluded``) is imported. If the module defines
    ``get_plugin()``, its return value is registered when it is a ``Plugin``
    instance. Otherwise every ``Plugin`` subclass defined in that module (or
    in its own subpackage) is instantiated without arguments and registered.

    Failures are logged per module or per class and never propagate.
    """
    from .base import Plugin

    plugins_dir = Path(__file__).parent

    # Infrastructure modules that live in the plugins directory but are not plugins.
    excluded = {"base", "loader", "hooks", "__init__"}

    # Resolve relative imports against this module's own package rather than
    # a hard-coded name; this is expected to be "squidbot.plugins".
    package = __package__ or "squidbot.plugins"

    for _, module_name, _ in pkgutil.iter_modules([str(plugins_dir)]):
        if module_name in excluded:
            continue

        try:
            module = importlib.import_module(f".{module_name}", package=package)

            # Look for get_plugin() function or Plugin subclass
            if hasattr(module, "get_plugin"):
                plugin = module.get_plugin()
                if isinstance(plugin, Plugin):
                    _registry.register(plugin)
                else:
                    logger.warning(
                        f"Plugin module '{module_name}': get_plugin() returned "
                        f"{type(plugin).__name__}, not a Plugin; skipping"
                    )
                continue

            # Search for Plugin subclasses
            own_prefix = module.__name__ + "."
            for attr_name, attr in vars(module).items():
                if not (
                    isinstance(attr, type)
                    and issubclass(attr, Plugin)
                    and attr is not Plugin
                ):
                    continue

                # Skip classes merely imported into this module (e.g. a base
                # class from a sibling plugin); instantiating them here would
                # create the same plugin more than once.
                defined_in = getattr(attr, "__module__", "") or ""
                if defined_in != module.__name__ and not defined_in.startswith(own_prefix):
                    continue

                try:
                    plugin = attr()
                    _registry.register(plugin)
                except Exception as e:
                    logger.error(f"Failed to instantiate plugin {attr_name}: {e}")

        except Exception as e:
            logger.error(f"Failed to load plugin module '{module_name}': {e}")
226
+
227
+
228
def load_external_plugins(plugins_dir: Path) -> None:
    """Load plugins from an external directory.

    Each top-level ``*.py`` file whose name does not start with ``_`` is
    executed as a module. If the module defines ``get_plugin()``, the
    returned object is handed to the global registry. Files are processed in
    sorted order so that load order does not depend on the filesystem.

    Does nothing when ``plugins_dir`` does not exist. Failures are logged per
    file and never propagate.

    Args:
        plugins_dir: Directory to scan for plugin files.
    """
    if not plugins_dir.exists():
        return

    # SECURITY NOTE(review): every matching file is executed with full
    # interpreter privileges; plugins_dir must only ever point at trusted code.
    for plugin_path in sorted(plugins_dir.glob("*.py")):
        if plugin_path.name.startswith("_"):
            continue

        try:
            spec = importlib.util.spec_from_file_location(plugin_path.stem, plugin_path)
            if not (spec and spec.loader):
                logger.warning(
                    f"Could not create an import spec for external plugin '{plugin_path.name}'"
                )
                continue

            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            if hasattr(module, "get_plugin"):
                plugin = module.get_plugin()
                _registry.register(plugin)
            else:
                # Previously skipped silently, which made a misnamed entry
                # point hard to diagnose.
                logger.warning(
                    f"External plugin '{plugin_path.name}' defines no get_plugin(); skipping"
                )

        except Exception as e:
            logger.error(f"Failed to load external plugin '{plugin_path.name}': {e}")