aury-agent 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/__init__.py +2 -0
- aury/agents/__init__.py +55 -0
- aury/agents/a2a/__init__.py +168 -0
- aury/agents/backends/__init__.py +196 -0
- aury/agents/backends/artifact/__init__.py +9 -0
- aury/agents/backends/artifact/memory.py +130 -0
- aury/agents/backends/artifact/types.py +133 -0
- aury/agents/backends/code/__init__.py +65 -0
- aury/agents/backends/file/__init__.py +11 -0
- aury/agents/backends/file/local.py +66 -0
- aury/agents/backends/file/types.py +40 -0
- aury/agents/backends/invocation/__init__.py +8 -0
- aury/agents/backends/invocation/memory.py +81 -0
- aury/agents/backends/invocation/types.py +110 -0
- aury/agents/backends/memory/__init__.py +8 -0
- aury/agents/backends/memory/memory.py +179 -0
- aury/agents/backends/memory/types.py +136 -0
- aury/agents/backends/message/__init__.py +9 -0
- aury/agents/backends/message/memory.py +122 -0
- aury/agents/backends/message/types.py +124 -0
- aury/agents/backends/sandbox.py +275 -0
- aury/agents/backends/session/__init__.py +8 -0
- aury/agents/backends/session/memory.py +93 -0
- aury/agents/backends/session/types.py +124 -0
- aury/agents/backends/shell/__init__.py +11 -0
- aury/agents/backends/shell/local.py +110 -0
- aury/agents/backends/shell/types.py +55 -0
- aury/agents/backends/shell.py +209 -0
- aury/agents/backends/snapshot/__init__.py +19 -0
- aury/agents/backends/snapshot/git.py +95 -0
- aury/agents/backends/snapshot/hybrid.py +125 -0
- aury/agents/backends/snapshot/memory.py +86 -0
- aury/agents/backends/snapshot/types.py +59 -0
- aury/agents/backends/state/__init__.py +29 -0
- aury/agents/backends/state/composite.py +49 -0
- aury/agents/backends/state/file.py +57 -0
- aury/agents/backends/state/memory.py +52 -0
- aury/agents/backends/state/sqlite.py +262 -0
- aury/agents/backends/state/types.py +178 -0
- aury/agents/backends/subagent/__init__.py +165 -0
- aury/agents/cli/__init__.py +41 -0
- aury/agents/cli/chat.py +239 -0
- aury/agents/cli/config.py +236 -0
- aury/agents/cli/extensions.py +460 -0
- aury/agents/cli/main.py +189 -0
- aury/agents/cli/session.py +337 -0
- aury/agents/cli/workflow.py +276 -0
- aury/agents/context_providers/__init__.py +66 -0
- aury/agents/context_providers/artifact.py +299 -0
- aury/agents/context_providers/base.py +177 -0
- aury/agents/context_providers/memory.py +70 -0
- aury/agents/context_providers/message.py +130 -0
- aury/agents/context_providers/skill.py +50 -0
- aury/agents/context_providers/subagent.py +46 -0
- aury/agents/context_providers/tool.py +68 -0
- aury/agents/core/__init__.py +83 -0
- aury/agents/core/base.py +573 -0
- aury/agents/core/context.py +797 -0
- aury/agents/core/context_builder.py +303 -0
- aury/agents/core/event_bus/__init__.py +15 -0
- aury/agents/core/event_bus/bus.py +203 -0
- aury/agents/core/factory.py +169 -0
- aury/agents/core/isolator.py +97 -0
- aury/agents/core/logging.py +95 -0
- aury/agents/core/parallel.py +194 -0
- aury/agents/core/runner.py +139 -0
- aury/agents/core/services/__init__.py +5 -0
- aury/agents/core/services/file_session.py +144 -0
- aury/agents/core/services/message.py +53 -0
- aury/agents/core/services/session.py +53 -0
- aury/agents/core/signals.py +109 -0
- aury/agents/core/state.py +363 -0
- aury/agents/core/types/__init__.py +107 -0
- aury/agents/core/types/action.py +176 -0
- aury/agents/core/types/artifact.py +135 -0
- aury/agents/core/types/block.py +736 -0
- aury/agents/core/types/message.py +350 -0
- aury/agents/core/types/recall.py +144 -0
- aury/agents/core/types/session.py +257 -0
- aury/agents/core/types/subagent.py +154 -0
- aury/agents/core/types/tool.py +205 -0
- aury/agents/eval/__init__.py +331 -0
- aury/agents/hitl/__init__.py +57 -0
- aury/agents/hitl/ask_user.py +242 -0
- aury/agents/hitl/compaction.py +230 -0
- aury/agents/hitl/exceptions.py +87 -0
- aury/agents/hitl/permission.py +617 -0
- aury/agents/hitl/revert.py +216 -0
- aury/agents/llm/__init__.py +31 -0
- aury/agents/llm/adapter.py +367 -0
- aury/agents/llm/openai.py +294 -0
- aury/agents/llm/provider.py +476 -0
- aury/agents/mcp/__init__.py +153 -0
- aury/agents/memory/__init__.py +46 -0
- aury/agents/memory/compaction.py +394 -0
- aury/agents/memory/manager.py +465 -0
- aury/agents/memory/processor.py +177 -0
- aury/agents/memory/store.py +187 -0
- aury/agents/memory/types.py +137 -0
- aury/agents/messages/__init__.py +40 -0
- aury/agents/messages/config.py +47 -0
- aury/agents/messages/raw_store.py +224 -0
- aury/agents/messages/store.py +118 -0
- aury/agents/messages/types.py +88 -0
- aury/agents/middleware/__init__.py +31 -0
- aury/agents/middleware/base.py +341 -0
- aury/agents/middleware/chain.py +342 -0
- aury/agents/middleware/message.py +129 -0
- aury/agents/middleware/message_container.py +126 -0
- aury/agents/middleware/raw_message.py +153 -0
- aury/agents/middleware/truncation.py +139 -0
- aury/agents/middleware/types.py +81 -0
- aury/agents/plugin.py +162 -0
- aury/agents/react/__init__.py +4 -0
- aury/agents/react/agent.py +1923 -0
- aury/agents/sandbox/__init__.py +23 -0
- aury/agents/sandbox/local.py +239 -0
- aury/agents/sandbox/remote.py +200 -0
- aury/agents/sandbox/types.py +115 -0
- aury/agents/skill/__init__.py +16 -0
- aury/agents/skill/loader.py +180 -0
- aury/agents/skill/types.py +83 -0
- aury/agents/tool/__init__.py +39 -0
- aury/agents/tool/builtin/__init__.py +23 -0
- aury/agents/tool/builtin/ask_user.py +155 -0
- aury/agents/tool/builtin/bash.py +107 -0
- aury/agents/tool/builtin/delegate.py +726 -0
- aury/agents/tool/builtin/edit.py +121 -0
- aury/agents/tool/builtin/plan.py +277 -0
- aury/agents/tool/builtin/read.py +91 -0
- aury/agents/tool/builtin/thinking.py +111 -0
- aury/agents/tool/builtin/yield_result.py +130 -0
- aury/agents/tool/decorator.py +252 -0
- aury/agents/tool/set.py +204 -0
- aury/agents/usage/__init__.py +12 -0
- aury/agents/usage/tracker.py +236 -0
- aury/agents/workflow/__init__.py +85 -0
- aury/agents/workflow/adapter.py +268 -0
- aury/agents/workflow/dag.py +116 -0
- aury/agents/workflow/dsl.py +575 -0
- aury/agents/workflow/executor.py +659 -0
- aury/agents/workflow/expression.py +136 -0
- aury/agents/workflow/parser.py +182 -0
- aury/agents/workflow/state.py +145 -0
- aury/agents/workflow/types.py +86 -0
- aury_agent-0.0.4.dist-info/METADATA +90 -0
- aury_agent-0.0.4.dist-info/RECORD +149 -0
- aury_agent-0.0.4.dist-info/WHEEL +4 -0
- aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""Message store protocol and implementations.
|
|
2
|
+
|
|
3
|
+
Note: For production use, prefer MessageBackend (backends/message/).
|
|
4
|
+
This module provides a simple protocol and in-memory implementation for testing.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Protocol, runtime_checkable
|
|
9
|
+
|
|
10
|
+
from .types import Message
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@runtime_checkable
class MessageStore(Protocol):
    """Structural interface for a per-session message history store.

    Kept only for backward compatibility; production code should use
    MessageBackend instead. Every operation takes an optional ``namespace``
    that partitions a single session's history.
    """

    async def add(
        self, session_id: str, message: Message, namespace: str | None = None
    ) -> None:
        """Append ``message`` to the history of ``session_id``."""

    async def get_all(
        self, session_id: str, namespace: str | None = None
    ) -> list[Message]:
        """Return every stored message of ``session_id``."""

    async def get_recent(
        self, session_id: str, limit: int, namespace: str | None = None
    ) -> list[Message]:
        """Return the recent messages of ``session_id``; ``limit`` bounds the count."""

    async def delete_by_invocation(
        self, session_id: str, invocation_id: str, namespace: str | None = None
    ) -> int:
        """Remove the messages produced by ``invocation_id`` and return how many were removed."""
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class InMemoryMessageStore:
    """In-memory message store for testing.

    Histories live in a plain dict keyed by ``(session_id, namespace)`` and
    are kept in insertion order. Provides the same async methods as the
    MessageStore protocol.
    """

    def __init__(self) -> None:
        # Keyed by (session_id, namespace). A tuple key, rather than a
        # "session_id:namespace" string, keeps session "a:b" (no namespace)
        # distinct from session "a" in namespace "b".
        self._messages: dict[tuple[str, str | None], list[Message]] = {}

    def _make_key(
        self, session_id: str, namespace: str | None
    ) -> tuple[str, str | None]:
        """Build the storage key; an empty namespace counts as no namespace."""
        return (session_id, namespace or None)

    async def add(
        self,
        session_id: str,
        message: Message,
        namespace: str | None = None,
    ) -> None:
        """Append ``message`` to the session's history."""
        key = self._make_key(session_id, namespace)
        self._messages.setdefault(key, []).append(message)

    async def get_all(
        self,
        session_id: str,
        namespace: str | None = None,
    ) -> list[Message]:
        """Return a copy of all messages of the session, in insertion order."""
        key = self._make_key(session_id, namespace)
        return list(self._messages.get(key, []))

    async def get_recent(
        self,
        session_id: str,
        limit: int,
        namespace: str | None = None,
    ) -> list[Message]:
        """Return (a copy of) the last ``limit`` messages, in insertion order.

        A falsy ``limit`` (0 or None) returns the whole history.

        Raises:
            ValueError: If ``limit`` is negative. A negative value would
                otherwise slice from the front and silently drop the oldest
                messages instead of returning the newest ones.
        """
        if limit and limit < 0:
            raise ValueError(f"limit must be non-negative, got {limit}")
        key = self._make_key(session_id, namespace)
        messages = self._messages.get(key, [])
        # Slicing already produces a new list; copy explicitly in the
        # "everything" case so callers never alias internal state.
        return messages[-limit:] if limit else list(messages)

    async def delete_by_invocation(
        self,
        session_id: str,
        invocation_id: str,
        namespace: str | None = None,
    ) -> int:
        """Remove every message whose ``invocation_id`` matches.

        Returns:
            The number of messages removed (0 when the session is unknown).
        """
        key = self._make_key(session_id, namespace)
        original = self._messages.get(key)
        if original is None:
            return 0
        kept = [m for m in original if m.invocation_id != invocation_id]
        self._messages[key] = kept
        return len(original) - len(kept)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# Public API of the message store module.
__all__ = ["MessageStore", "InMemoryMessageStore"]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Message type definitions.
|
|
2
|
+
|
|
3
|
+
Core types for the message system:
|
|
4
|
+
- Message: A single message in conversation history
|
|
5
|
+
- MessageRole: User/Assistant/Tool/System
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class MessageRole(Enum):
    """Roles a conversation message can have.

    ``Message.role`` is normally a plain string; these members hold the
    canonical values (e.g. ``MessageRole.USER.value == "user"``).
    """

    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
    SYSTEM = "system"


@dataclass
class Message:
    """A message in conversation history.

    Attributes:
        role: Message role ("user"/"assistant"/"tool"/"system"). A
            ``MessageRole`` (or any other Enum) member is also accepted and
            is emitted as its ``.value`` by ``to_dict``/``to_llm_format``.
        content: Message content (a string or a list of content-part dicts)
        invocation_id: Which invocation this message belongs to
        tool_call_id: Tool call ID (for tool messages)
        created_at: When the message was created (``datetime.now()``, i.e.
            naive local time, by default)
        metadata: Additional metadata
    """

    role: str | MessageRole
    content: str | list[dict[str, Any]]
    invocation_id: str = ""
    tool_call_id: str | None = None
    created_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)

    def _role_value(self) -> Any:
        """Return ``role`` as a plain value, unwrapping Enum members."""
        return self.role.value if isinstance(self.role, Enum) else self.role

    def to_dict(self) -> dict[str, Any]:
        """Convert to a plain dict for storage.

        ``tool_call_id`` and ``metadata`` are only included when truthy.
        ``metadata`` is shallow-copied so mutating the result does not alter
        this message.
        """
        result: dict[str, Any] = {
            "role": self._role_value(),
            "content": self.content,
            "invocation_id": self.invocation_id,
            "created_at": self.created_at.isoformat(),
        }
        if self.tool_call_id:
            result["tool_call_id"] = self.tool_call_id
        if self.metadata:
            result["metadata"] = dict(self.metadata)
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Message:
        """Create a Message from a dict such as the one ``to_dict`` produces.

        ``created_at`` may be an ISO-8601 string, a datetime, or missing
        (then the current time is used).

        Raises:
            KeyError: If ``role`` or ``content`` is missing.
        """
        created_at = data.get("created_at")
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        elif created_at is None:
            created_at = datetime.now()

        return cls(
            role=data["role"],
            content=data["content"],
            invocation_id=data.get("invocation_id", ""),
            tool_call_id=data.get("tool_call_id"),
            created_at=created_at,
            # Copy so the message never aliases the caller's dict, and turn an
            # explicit ``None`` into an empty dict to honor the field's type.
            metadata=dict(data.get("metadata") or {}),
        )

    def to_llm_format(self) -> dict[str, Any]:
        """Convert to the ``{"role", "content"[, "tool_call_id"]}`` LLM message shape."""
        result: dict[str, Any] = {
            "role": self._role_value(),
            "content": self.content,
        }
        if self.tool_call_id:
            result["tool_call_id"] = self.tool_call_id
        return result
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Public API of the message type module.
__all__ = ["MessageRole", "Message"]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Middleware system for request/response processing.
|
|
2
|
+
|
|
3
|
+
Middleware provides hooks for intercepting and modifying:
|
|
4
|
+
- LLM requests/responses
|
|
5
|
+
- Agent lifecycle events
|
|
6
|
+
- Tool execution
|
|
7
|
+
- Sub-agent delegation
|
|
8
|
+
- Message persistence
|
|
9
|
+
"""
|
|
10
|
+
from .types import TriggerMode, HookAction, HookResult, MiddlewareConfig
|
|
11
|
+
from .base import Middleware, BaseMiddleware
|
|
12
|
+
from .chain import MiddlewareChain
|
|
13
|
+
from .message_container import MessageContainerMiddleware
|
|
14
|
+
from .message import MessageBackendMiddleware
|
|
15
|
+
from .truncation import MessageTruncationMiddleware
|
|
16
|
+
from .raw_message import RawMessageMiddleware
|
|
17
|
+
|
|
18
|
+
__all__ = [
    # Hook result types, configuration, protocol/base class and chain
    "TriggerMode", "HookAction", "HookResult", "MiddlewareConfig",
    "Middleware", "BaseMiddleware", "MiddlewareChain",
    # Default middlewares
    "MessageContainerMiddleware", "MessageBackendMiddleware",
    "MessageTruncationMiddleware", "RawMessageMiddleware",
]
|
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
"""Middleware protocol and base implementation."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from .types import HookResult, MiddlewareConfig
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from ..core.types.tool import BaseTool, ToolResult
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@runtime_checkable
class Middleware(Protocol):
    """Structural interface every middleware satisfies.

    Two groups of async hooks are declared:

    * LLM hooks (``on_request``, ``on_response``, ``on_error``,
      ``on_model_stream``) receive and return plain dicts / exceptions.
    * Agent lifecycle hooks (agent start/end, tool call/end, sub-agent
      start/end) return a ``HookResult`` that steers execution.

    ``on_message_save`` may rewrite or drop a message before persistence.
    """

    @property
    def config(self) -> MiddlewareConfig:
        """Configuration object of this middleware."""

    # -- LLM request / response ---------------------------------------------

    async def on_request(
        self, request: dict[str, Any], context: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Inspect or rewrite a request before the LLM is called.

        Args:
            request: Request about to be sent.
            context: Execution context.

        Returns:
            The (possibly modified) request, or ``None`` to skip further
            processing.
        """

    async def on_response(
        self, response: dict[str, Any], context: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Inspect or rewrite the response after the LLM call.

        Args:
            response: Response returned by the LLM.
            context: Execution context.

        Returns:
            The (possibly modified) response, or ``None`` to skip further
            processing.
        """

    async def on_error(
        self, error: Exception, context: dict[str, Any]
    ) -> Exception | None:
        """React to an exception raised during processing.

        Args:
            error: The exception that occurred.
            context: Execution context.

        Returns:
            The (possibly replaced) exception, or ``None`` to suppress it.
        """

    async def on_model_stream(
        self, chunk: dict[str, Any], context: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Handle one streaming chunk; when it fires depends on trigger_mode.

        Args:
            chunk: The streaming chunk.
            context: Execution context.

        Returns:
            The (possibly modified) chunk, or ``None`` to skip further
            processing.
        """

    # -- Agent lifecycle ----------------------------------------------------

    async def on_agent_start(
        self, agent_id: str, input_data: Any, context: dict[str, Any]
    ) -> HookResult:
        """Run as an agent begins processing.

        Args:
            agent_id: Identifier of the agent.
            input_data: Input handed to the agent.
            context: Execution context.

        Returns:
            A HookResult that controls the execution flow.
        """

    async def on_agent_end(
        self, agent_id: str, result: Any, context: dict[str, Any]
    ) -> HookResult:
        """Run once an agent has finished processing.

        Args:
            agent_id: Identifier of the agent.
            result: What the agent produced.
            context: Execution context.

        Returns:
            A HookResult; only CONTINUE/STOP are meaningful here.
        """

    async def on_tool_call(
        self, tool: "BaseTool", params: dict[str, Any], context: dict[str, Any]
    ) -> HookResult:
        """Run before a tool executes.

        Args:
            tool: Tool about to be invoked.
            params: Parameters for the tool.
            context: Execution context.

        Returns:
            A HookResult: SKIP skips the tool, RETRY modifies the params.
        """

    async def on_tool_end(
        self, tool: "BaseTool", result: "ToolResult", context: dict[str, Any]
    ) -> HookResult:
        """Run after a tool has executed.

        Args:
            tool: Tool that was invoked.
            result: Result of the tool execution.
            context: Execution context.

        Returns:
            A HookResult: RETRY re-executes the tool.
        """

    async def on_subagent_start(
        self,
        parent_agent_id: str,
        child_agent_id: str,
        mode: str,
        context: dict[str, Any],
    ) -> HookResult:
        """Run when work is being delegated to a sub-agent.

        Args:
            parent_agent_id: Identifier of the delegating agent.
            child_agent_id: Identifier of the sub-agent.
            mode: Delegation mode, "embedded" or "delegated".
            context: Execution context.

        Returns:
            A HookResult: SKIP skips the delegation.
        """

    async def on_subagent_end(
        self,
        parent_agent_id: str,
        child_agent_id: str,
        result: Any,
        context: dict[str, Any],
    ) -> HookResult:
        """Run once a sub-agent has completed.

        Args:
            parent_agent_id: Identifier of the delegating agent.
            child_agent_id: Identifier of the sub-agent.
            result: What the sub-agent produced.
            context: Execution context.

        Returns:
            A HookResult used for post-processing.
        """

    async def on_message_save(
        self, message: dict[str, Any], context: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Transform, filter or block a message before it is saved to history.

        Args:
            message: Message dict carrying 'role', 'content', etc.
            context: Execution context.

        Returns:
            The (possibly modified) message, or ``None`` to skip saving.
        """
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class BaseMiddleware:
    """Pass-through implementation of every Middleware hook.

    Subclasses override only the hooks they need. The inherited defaults
    hand requests, responses, chunks, errors and messages back unchanged,
    and return ``HookResult.proceed()`` from every lifecycle hook.
    """

    # NOTE(review): this MiddlewareConfig instance is created once, when the
    # class is defined, and is shared by every instance and subclass that
    # does not assign its own ``_config``. If MiddlewareConfig is mutable,
    # mutating ``self.config`` on one middleware would affect all of them —
    # confirm against middleware/types.py.
    _config: MiddlewareConfig = MiddlewareConfig()

    @property
    def config(self) -> MiddlewareConfig:
        """Return ``self._config`` (the shared class default unless overridden)."""
        return self._config

    # -- LLM request / response: hand the input back untouched --------------

    async def on_request(
        self, request: dict[str, Any], context: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Return ``request`` unchanged."""
        return request

    async def on_response(
        self, response: dict[str, Any], context: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Return ``response`` unchanged."""
        return response

    async def on_error(
        self, error: Exception, context: dict[str, Any]
    ) -> Exception | None:
        """Return the same exception, i.e. do not suppress it."""
        return error

    async def on_model_stream(
        self, chunk: dict[str, Any], context: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Return ``chunk`` unchanged."""
        return chunk

    # -- Agent lifecycle: always let execution proceed -----------------------

    async def on_agent_start(
        self, agent_id: str, input_data: Any, context: dict[str, Any]
    ) -> HookResult:
        """Let the agent start normally."""
        return HookResult.proceed()

    async def on_agent_end(
        self, agent_id: str, result: Any, context: dict[str, Any]
    ) -> HookResult:
        """Let the agent finish normally."""
        return HookResult.proceed()

    async def on_tool_call(
        self, tool: "BaseTool", params: dict[str, Any], context: dict[str, Any]
    ) -> HookResult:
        """Let the tool run with its original parameters."""
        return HookResult.proceed()

    async def on_tool_end(
        self, tool: "BaseTool", result: "ToolResult", context: dict[str, Any]
    ) -> HookResult:
        """Accept the tool result as-is."""
        return HookResult.proceed()

    async def on_subagent_start(
        self,
        parent_agent_id: str,
        child_agent_id: str,
        mode: str,
        context: dict[str, Any],
    ) -> HookResult:
        """Let the delegation go ahead."""
        return HookResult.proceed()

    async def on_subagent_end(
        self,
        parent_agent_id: str,
        child_agent_id: str,
        result: Any,
        context: dict[str, Any],
    ) -> HookResult:
        """Accept the sub-agent result as-is."""
        return HookResult.proceed()

    async def on_message_save(
        self, message: dict[str, Any], context: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Return ``message`` unchanged so it is saved as-is."""
        return message
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
# Public API: the structural protocol and its pass-through base class.
__all__ = ["Middleware", "BaseMiddleware"]
|