aury-agent 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/__init__.py +2 -0
- aury/agents/__init__.py +55 -0
- aury/agents/a2a/__init__.py +168 -0
- aury/agents/backends/__init__.py +196 -0
- aury/agents/backends/artifact/__init__.py +9 -0
- aury/agents/backends/artifact/memory.py +130 -0
- aury/agents/backends/artifact/types.py +133 -0
- aury/agents/backends/code/__init__.py +65 -0
- aury/agents/backends/file/__init__.py +11 -0
- aury/agents/backends/file/local.py +66 -0
- aury/agents/backends/file/types.py +40 -0
- aury/agents/backends/invocation/__init__.py +8 -0
- aury/agents/backends/invocation/memory.py +81 -0
- aury/agents/backends/invocation/types.py +110 -0
- aury/agents/backends/memory/__init__.py +8 -0
- aury/agents/backends/memory/memory.py +179 -0
- aury/agents/backends/memory/types.py +136 -0
- aury/agents/backends/message/__init__.py +9 -0
- aury/agents/backends/message/memory.py +122 -0
- aury/agents/backends/message/types.py +124 -0
- aury/agents/backends/sandbox.py +275 -0
- aury/agents/backends/session/__init__.py +8 -0
- aury/agents/backends/session/memory.py +93 -0
- aury/agents/backends/session/types.py +124 -0
- aury/agents/backends/shell/__init__.py +11 -0
- aury/agents/backends/shell/local.py +110 -0
- aury/agents/backends/shell/types.py +55 -0
- aury/agents/backends/shell.py +209 -0
- aury/agents/backends/snapshot/__init__.py +19 -0
- aury/agents/backends/snapshot/git.py +95 -0
- aury/agents/backends/snapshot/hybrid.py +125 -0
- aury/agents/backends/snapshot/memory.py +86 -0
- aury/agents/backends/snapshot/types.py +59 -0
- aury/agents/backends/state/__init__.py +29 -0
- aury/agents/backends/state/composite.py +49 -0
- aury/agents/backends/state/file.py +57 -0
- aury/agents/backends/state/memory.py +52 -0
- aury/agents/backends/state/sqlite.py +262 -0
- aury/agents/backends/state/types.py +178 -0
- aury/agents/backends/subagent/__init__.py +165 -0
- aury/agents/cli/__init__.py +41 -0
- aury/agents/cli/chat.py +239 -0
- aury/agents/cli/config.py +236 -0
- aury/agents/cli/extensions.py +460 -0
- aury/agents/cli/main.py +189 -0
- aury/agents/cli/session.py +337 -0
- aury/agents/cli/workflow.py +276 -0
- aury/agents/context_providers/__init__.py +66 -0
- aury/agents/context_providers/artifact.py +299 -0
- aury/agents/context_providers/base.py +177 -0
- aury/agents/context_providers/memory.py +70 -0
- aury/agents/context_providers/message.py +130 -0
- aury/agents/context_providers/skill.py +50 -0
- aury/agents/context_providers/subagent.py +46 -0
- aury/agents/context_providers/tool.py +68 -0
- aury/agents/core/__init__.py +83 -0
- aury/agents/core/base.py +573 -0
- aury/agents/core/context.py +797 -0
- aury/agents/core/context_builder.py +303 -0
- aury/agents/core/event_bus/__init__.py +15 -0
- aury/agents/core/event_bus/bus.py +203 -0
- aury/agents/core/factory.py +169 -0
- aury/agents/core/isolator.py +97 -0
- aury/agents/core/logging.py +95 -0
- aury/agents/core/parallel.py +194 -0
- aury/agents/core/runner.py +139 -0
- aury/agents/core/services/__init__.py +5 -0
- aury/agents/core/services/file_session.py +144 -0
- aury/agents/core/services/message.py +53 -0
- aury/agents/core/services/session.py +53 -0
- aury/agents/core/signals.py +109 -0
- aury/agents/core/state.py +363 -0
- aury/agents/core/types/__init__.py +107 -0
- aury/agents/core/types/action.py +176 -0
- aury/agents/core/types/artifact.py +135 -0
- aury/agents/core/types/block.py +736 -0
- aury/agents/core/types/message.py +350 -0
- aury/agents/core/types/recall.py +144 -0
- aury/agents/core/types/session.py +257 -0
- aury/agents/core/types/subagent.py +154 -0
- aury/agents/core/types/tool.py +205 -0
- aury/agents/eval/__init__.py +331 -0
- aury/agents/hitl/__init__.py +57 -0
- aury/agents/hitl/ask_user.py +242 -0
- aury/agents/hitl/compaction.py +230 -0
- aury/agents/hitl/exceptions.py +87 -0
- aury/agents/hitl/permission.py +617 -0
- aury/agents/hitl/revert.py +216 -0
- aury/agents/llm/__init__.py +31 -0
- aury/agents/llm/adapter.py +367 -0
- aury/agents/llm/openai.py +294 -0
- aury/agents/llm/provider.py +476 -0
- aury/agents/mcp/__init__.py +153 -0
- aury/agents/memory/__init__.py +46 -0
- aury/agents/memory/compaction.py +394 -0
- aury/agents/memory/manager.py +465 -0
- aury/agents/memory/processor.py +177 -0
- aury/agents/memory/store.py +187 -0
- aury/agents/memory/types.py +137 -0
- aury/agents/messages/__init__.py +40 -0
- aury/agents/messages/config.py +47 -0
- aury/agents/messages/raw_store.py +224 -0
- aury/agents/messages/store.py +118 -0
- aury/agents/messages/types.py +88 -0
- aury/agents/middleware/__init__.py +31 -0
- aury/agents/middleware/base.py +341 -0
- aury/agents/middleware/chain.py +342 -0
- aury/agents/middleware/message.py +129 -0
- aury/agents/middleware/message_container.py +126 -0
- aury/agents/middleware/raw_message.py +153 -0
- aury/agents/middleware/truncation.py +139 -0
- aury/agents/middleware/types.py +81 -0
- aury/agents/plugin.py +162 -0
- aury/agents/react/__init__.py +4 -0
- aury/agents/react/agent.py +1923 -0
- aury/agents/sandbox/__init__.py +23 -0
- aury/agents/sandbox/local.py +239 -0
- aury/agents/sandbox/remote.py +200 -0
- aury/agents/sandbox/types.py +115 -0
- aury/agents/skill/__init__.py +16 -0
- aury/agents/skill/loader.py +180 -0
- aury/agents/skill/types.py +83 -0
- aury/agents/tool/__init__.py +39 -0
- aury/agents/tool/builtin/__init__.py +23 -0
- aury/agents/tool/builtin/ask_user.py +155 -0
- aury/agents/tool/builtin/bash.py +107 -0
- aury/agents/tool/builtin/delegate.py +726 -0
- aury/agents/tool/builtin/edit.py +121 -0
- aury/agents/tool/builtin/plan.py +277 -0
- aury/agents/tool/builtin/read.py +91 -0
- aury/agents/tool/builtin/thinking.py +111 -0
- aury/agents/tool/builtin/yield_result.py +130 -0
- aury/agents/tool/decorator.py +252 -0
- aury/agents/tool/set.py +204 -0
- aury/agents/usage/__init__.py +12 -0
- aury/agents/usage/tracker.py +236 -0
- aury/agents/workflow/__init__.py +85 -0
- aury/agents/workflow/adapter.py +268 -0
- aury/agents/workflow/dag.py +116 -0
- aury/agents/workflow/dsl.py +575 -0
- aury/agents/workflow/executor.py +659 -0
- aury/agents/workflow/expression.py +136 -0
- aury/agents/workflow/parser.py +182 -0
- aury/agents/workflow/state.py +145 -0
- aury/agents/workflow/types.py +86 -0
- aury_agent-0.0.4.dist-info/METADATA +90 -0
- aury_agent-0.0.4.dist-info/RECORD +149 -0
- aury_agent-0.0.4.dist-info/WHEEL +4 -0
- aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
"""Session and Invocation data structures."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Any
|
|
8
|
+
from uuid import uuid4
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def generate_id(prefix: str = "") -> str:
    """Return a short random identifier, optionally namespaced by *prefix*.

    The random part is the first 12 hex characters of a UUID4. With a
    non-empty *prefix* the result looks like ``"<prefix>_<12 hex chars>"``;
    with an empty prefix only the 12 hex characters are returned.
    """
    token = uuid4().hex[:12]
    if not prefix:
        return token
    return f"{prefix}_{token}"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class InvocationState(Enum):
    """Invocation execution state.

    ``Invocation`` starts in PENDING and its ``mark_*`` methods move it to the
    other members. The string values are what ``Invocation.to_dict`` writes
    under ``"state"`` and what ``Invocation.from_dict`` parses back.
    """

    PENDING = "pending"  # Default state of a freshly created Invocation
    RUNNING = "running"  # Set by Invocation.mark_started
    SUSPENDED = "suspended"  # HITL waiting
    PAUSED = "paused"  # User paused
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    ABORTED = "aborted"  # User stopped
    SWITCHED = "switched"  # User switched agent
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class InvocationMode(Enum):
    """Invocation mode: how an Invocation came to exist.

    Serialized via ``.value``; ``Invocation.from_dict`` falls back to
    ``"root"`` when the mode is missing.
    """

    ROOT = "root"  # Root invocation
    DELEGATED = "delegated"  # Delegated sub-agent (parent linked via parent_invocation_id)
    # Note: EMBEDDED sub-agents (SubAgentMode.EMBEDDED) don't create a new
    # Invocation; they run inside the parent's.
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class ControlFrame:
    """One entry of a session's control stack.

    A frame is pushed when control is delegated to another agent. It records
    that agent, the invocation it runs in, the moment the frame was created,
    and (optionally) the invocation that delegated.
    """

    agent_id: str
    invocation_id: str
    entered_at: datetime = field(default_factory=datetime.now)
    parent_invocation_id: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; ``entered_at`` becomes an ISO-8601 string."""
        serialized: dict[str, Any] = {}
        serialized["agent_id"] = self.agent_id
        serialized["invocation_id"] = self.invocation_id
        serialized["entered_at"] = self.entered_at.isoformat()
        serialized["parent_invocation_id"] = self.parent_invocation_id
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ControlFrame":
        """Rebuild a frame from ``to_dict`` output.

        ``agent_id``, ``invocation_id`` and ``entered_at`` are required keys;
        ``parent_invocation_id`` may be absent.
        """
        agent = data["agent_id"]
        invocation = data["invocation_id"]
        entered = datetime.fromisoformat(data["entered_at"])
        return cls(
            agent_id=agent,
            invocation_id=invocation,
            entered_at=entered,
            parent_invocation_id=data.get("parent_invocation_id"),
        )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass
class Session:
    """Session - container for conversations.

    A Session represents a conversation thread that can contain multiple
    invocations (turns). ``parent_id`` links a session to a parent session
    (forked sessions / sub-agent scenarios).

    ``to_dict`` and ``from_dict`` shallow-copy the mutable ``metadata`` and
    ``revert`` dicts, so a serialized snapshot and the live session never
    share the same top-level dict object (mutating one does not silently
    change the other).
    """

    id: str = field(default_factory=lambda: generate_id("sess"))
    root_agent_id: str = ""  # Root agent for this session
    parent_id: str | None = None  # Parent session (e.g. forked sessions)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)  # title, etc.
    is_active: bool = True

    # Control stack for DELEGATED mode; the last element is the active frame.
    control_stack: list[ControlFrame] = field(default_factory=list)

    # Revert state (set while the session is in a reverted state)
    revert: dict[str, Any] | None = None

    @property
    def active_agent_id(self) -> str:
        """Agent currently in control: top of the control stack, else the root agent."""
        if self.control_stack:
            return self.control_stack[-1].agent_id
        return self.root_agent_id

    def push_control(self, frame: ControlFrame) -> None:
        """Push a control frame when delegating to a sub-agent; bumps ``updated_at``."""
        self.control_stack.append(frame)
        self.updated_at = datetime.now()

    def pop_control(self) -> ControlFrame | None:
        """Pop the top control frame when a sub-agent returns control.

        Returns:
            The removed frame, or None when the stack is empty (in which case
            ``updated_at`` is left unchanged).
        """
        if self.control_stack:
            self.updated_at = datetime.now()
            return self.control_stack.pop()
        return None

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary for serialization.

        Datetimes become ISO-8601 strings and control frames are serialized
        via ``ControlFrame.to_dict``. ``metadata`` and ``revert`` are copied
        (shallowly) so the result does not alias this session's state.
        """
        return {
            "id": self.id,
            "root_agent_id": self.root_agent_id,
            "parent_id": self.parent_id,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "metadata": dict(self.metadata),
            "is_active": self.is_active,
            "control_stack": [f.to_dict() for f in self.control_stack],
            "revert": dict(self.revert) if self.revert is not None else None,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Session":
        """Create a Session from ``to_dict`` output.

        ``id``, ``created_at`` and ``updated_at`` are required. Missing or
        null ``metadata``/``control_stack`` are treated as empty. Mutable
        containers are copied so the new session does not alias *data*.
        """
        revert = data.get("revert")
        return cls(
            id=data["id"],
            root_agent_id=data.get("root_agent_id", ""),
            parent_id=data.get("parent_id"),
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
            metadata=dict(data.get("metadata") or {}),
            is_active=data.get("is_active", True),
            control_stack=[ControlFrame.from_dict(f) for f in data.get("control_stack") or []],
            revert=dict(revert) if revert is not None else None,
        )
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@dataclass
class Invocation:
    """A single invocation (turn) within a session.

    An Invocation represents one user input and the agent's response,
    including all tool calls and intermediate steps.

    ``to_dict`` and ``from_dict`` shallow-copy the mutable containers
    (``agent_state``, ``pending_tool_ids``, ``metadata``) so a serialized
    snapshot never shares a top-level container with the live object.
    Nested values inside those containers are still shared.
    """

    id: str = field(default_factory=lambda: generate_id("inv"))
    session_id: str = ""
    agent_id: str = ""  # Which agent executes this invocation
    mode: InvocationMode = InvocationMode.ROOT  # ROOT or DELEGATED
    state: InvocationState = InvocationState.PENDING

    # Relationship (tree structure)
    parent_invocation_id: str | None = None  # Parent invocation (for DELEGATED)

    # Timestamps
    created_at: datetime = field(default_factory=datetime.now)
    started_at: datetime | None = None
    finished_at: datetime | None = None

    # Agent state for resumption
    agent_state: dict[str, Any] | None = None
    pending_tool_ids: list[str] = field(default_factory=list)

    # Execution info
    step_count: int = 0
    snapshot_id: str | None = None  # Pre-execution snapshot for revert

    # Error info
    error: str | None = None

    # Branch (for state isolation)
    branch: str | None = None

    metadata: dict[str, Any] = field(default_factory=dict)

    def mark_started(self) -> None:
        """Set state to RUNNING and record ``started_at`` (overwrites any earlier value)."""
        self.state = InvocationState.RUNNING
        self.started_at = datetime.now()

    def mark_completed(self) -> None:
        """Set state to COMPLETED and record ``finished_at``."""
        self.state = InvocationState.COMPLETED
        self.finished_at = datetime.now()

    def mark_failed(self, error: str) -> None:
        """Set state to FAILED, store *error*, and record ``finished_at``."""
        self.state = InvocationState.FAILED
        self.error = error
        self.finished_at = datetime.now()

    def mark_cancelled(self) -> None:
        """Set state to CANCELLED and record ``finished_at``."""
        self.state = InvocationState.CANCELLED
        self.finished_at = datetime.now()

    def mark_aborted(self) -> None:
        """Set state to ABORTED (stopped by user) and record ``finished_at``."""
        self.state = InvocationState.ABORTED
        self.finished_at = datetime.now()

    def mark_switched(self) -> None:
        """Set state to SWITCHED (ended by agent switch) and record ``finished_at``."""
        self.state = InvocationState.SWITCHED
        self.finished_at = datetime.now()

    def mark_paused(self) -> None:
        """Set state to PAUSED; timestamps are left untouched."""
        self.state = InvocationState.PAUSED

    def mark_suspended(self) -> None:
        """Set state to SUSPENDED (waiting on HITL); timestamps are left untouched."""
        self.state = InvocationState.SUSPENDED

    @property
    def duration_ms(self) -> int | None:
        """Milliseconds between ``started_at`` and ``finished_at``, or None if either is unset."""
        if self.started_at is not None and self.finished_at is not None:
            return int((self.finished_at - self.started_at).total_seconds() * 1000)
        return None

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary for serialization.

        Enums are stored by value, datetimes as ISO-8601 strings (or None).
        Mutable containers are shallow-copied so the result does not alias
        this invocation's state.
        """
        return {
            "id": self.id,
            "session_id": self.session_id,
            "agent_id": self.agent_id,
            "mode": self.mode.value,
            "state": self.state.value,
            "parent_invocation_id": self.parent_invocation_id,
            "created_at": self.created_at.isoformat(),
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "finished_at": self.finished_at.isoformat() if self.finished_at else None,
            "agent_state": dict(self.agent_state) if self.agent_state is not None else None,
            "pending_tool_ids": list(self.pending_tool_ids),
            "step_count": self.step_count,
            "snapshot_id": self.snapshot_id,
            "error": self.error,
            "branch": self.branch,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Invocation":
        """Create an Invocation from ``to_dict`` output.

        ``id``, ``session_id``, ``state`` and ``created_at`` are required;
        ``mode`` defaults to ``"root"``. Missing or null ``pending_tool_ids``
        and ``metadata`` are treated as empty. Mutable containers are copied
        so the new invocation does not alias *data*.
        """
        agent_state = data.get("agent_state")
        return cls(
            id=data["id"],
            session_id=data["session_id"],
            agent_id=data.get("agent_id", ""),
            mode=InvocationMode(data.get("mode", "root")),
            state=InvocationState(data["state"]),
            parent_invocation_id=data.get("parent_invocation_id"),
            created_at=datetime.fromisoformat(data["created_at"]),
            started_at=datetime.fromisoformat(data["started_at"]) if data.get("started_at") else None,
            finished_at=datetime.fromisoformat(data["finished_at"]) if data.get("finished_at") else None,
            agent_state=dict(agent_state) if agent_state is not None else None,
            pending_tool_ids=list(data.get("pending_tool_ids") or []),
            step_count=data.get("step_count", 0),
            snapshot_id=data.get("snapshot_id"),
            error=data.get("error"),
            branch=data.get("branch"),
            metadata=dict(data.get("metadata") or {}),
        )
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""SubAgent input/output types for agent delegation."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Any, Literal, TypedDict, NotRequired, TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from .message import Message
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SubAgentMode(Enum):
    """SubAgent execution mode: whether a sub-agent gets its own invocation."""

    EMBEDDED = "embedded"  # Embedded execution, shares parent's invocation
    DELEGATED = "delegated"  # Delegated execution, creates new invocation
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SubAgentInput(TypedDict):
    """Input for SubAgent invocation.

    Fields supplied by the LLM:
    - agent: which agent to invoke
    - task_context: task context (user intent, background, requirements, etc.)
    - artifact_refs: references to relevant materials

    Other configuration (mode, inherit_messages, summary_mode, etc.)
    is defined in AgentConfig and is not passed in by the LLM.
    """

    # Required
    agent: str  # Agent key

    # Task context - describe the user's intent, background and concrete
    # requirements as fully as possible
    task_context: NotRequired[str]

    # Artifact references - relevant materials, as [{id, summary}, ...]
    artifact_refs: NotRequired[list[dict[str, str]]]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class SubAgentMetadata:
    """Metadata about sub-agent execution.

    ``to_dict`` and ``from_dict`` copy ``token_usage`` so the serialized
    form and the live object never share the same dict.
    """

    child_invocation_id: str
    agent_name: str
    agent_type: str  # "react" | "workflow"
    steps: int = 0
    duration_ms: int = 0
    token_usage: dict[str, int] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict (``token_usage`` is copied)."""
        return {
            "child_invocation_id": self.child_invocation_id,
            "agent_name": self.agent_name,
            "agent_type": self.agent_type,
            "steps": self.steps,
            "duration_ms": self.duration_ms,
            "token_usage": dict(self.token_usage),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SubAgentMetadata":
        """Rebuild from ``to_dict`` output.

        ``child_invocation_id`` and ``agent_name`` are required; ``agent_type``
        defaults to ``"react"``, counters to 0, and a missing or null
        ``token_usage`` to an empty dict.
        """
        return cls(
            child_invocation_id=data["child_invocation_id"],
            agent_name=data["agent_name"],
            agent_type=data.get("agent_type", "react"),
            steps=data.get("steps", 0),
            duration_ms=data.get("duration_ms", 0),
            token_usage=dict(data.get("token_usage") or {}),
        )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class SubAgentResult:
    """Outcome of a sub-agent run, handed back to the parent agent.

    ``output`` is the text summary placed in the parent LLM's context;
    ``data`` and ``state_updates`` optionally carry structured payloads
    alongside it. The ``completed`` / ``failed`` / ``aborted`` classmethods
    build the common cases.
    """

    output: str  # Text output (summary for LLM)
    status: Literal["completed", "aborted", "failed", "switched"]  # Execution status
    data: dict[str, Any] | None = None  # Optional structured data
    state_updates: dict[str, Any] | None = None  # State changes to merge back to parent
    error: str | None = None  # Error info (when failed)
    metadata: SubAgentMetadata | None = None  # Execution metadata

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; ``metadata`` is flattened via its own ``to_dict``."""
        if not self.metadata:
            meta_dict = None
        else:
            meta_dict = self.metadata.to_dict()
        return {
            "output": self.output,
            "status": self.status,
            "data": self.data,
            "state_updates": self.state_updates,
            "error": self.error,
            "metadata": meta_dict,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SubAgentResult":
        """Rebuild from ``to_dict`` output; ``output`` and ``status`` are required."""
        output = data["output"]
        status = data["status"]
        raw_meta = data.get("metadata")
        return cls(
            output=output,
            status=status,
            data=data.get("data"),
            state_updates=data.get("state_updates"),
            error=data.get("error"),
            metadata=SubAgentMetadata.from_dict(raw_meta) if raw_meta else None,
        )

    @classmethod
    def completed(
        cls,
        output: str,
        data: dict[str, Any] | None = None,
        state_updates: dict[str, Any] | None = None,
        metadata: SubAgentMetadata | None = None,
    ) -> "SubAgentResult":
        """Build a result with status ``"completed"``."""
        return cls(
            status="completed",
            output=output,
            metadata=metadata,
            data=data,
            state_updates=state_updates,
        )

    @classmethod
    def failed(cls, error: str, output: str = "") -> "SubAgentResult":
        """Build a ``"failed"`` result; an empty *output* gets a message derived from *error*."""
        if not output:
            output = f"SubAgent failed: {error}"
        return cls(status="failed", output=output, error=error)

    @classmethod
    def aborted(cls, output: str = "SubAgent was aborted") -> "SubAgentResult":
        """Build a result with status ``"aborted"``."""
        return cls(status="aborted", output=output)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "SubAgentMode",
    "SubAgentInput",
    "SubAgentMetadata",
    "SubAgentResult",
]
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""Tool-related type definitions."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Any, Awaitable, Callable
|
|
9
|
+
|
|
10
|
+
from .session import generate_id
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class ToolInfo:
    """Description of a tool as advertised to the LLM."""

    name: str
    description: str
    parameters: dict[str, Any]  # JSON Schema describing the tool's arguments

    def to_llm_schema(self) -> dict[str, Any]:
        """Build the ``name`` / ``description`` / ``input_schema`` mapping for the LLM API.

        ``input_schema`` is the very same object as ``self.parameters`` (no copy).
        """
        schema: dict[str, Any] = dict(name=self.name, description=self.description)
        schema["input_schema"] = self.parameters
        return schema
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class ToolContext:
    """Context passed to tool execution.

    Carries the identifiers of the call site (session, invocation, block,
    call, agent), an abort signal, a coroutine callback for updating the
    call's metadata, and optional hooks used for sub-agent support.
    """

    session_id: str
    invocation_id: str
    block_id: str
    call_id: str
    agent: str
    abort_signal: asyncio.Event
    update_metadata: Callable[[dict[str, Any]], Awaitable[None]]

    # Optional usage tracker
    usage: Any | None = None  # UsageTracker

    # Branch for sub-agent isolation
    branch: str | None = None

    # Caller's middleware chain (for sub-agent inheritance)
    middleware: Any | None = None  # MiddlewareChain

    async def emit(self, block: Any) -> None:
        """Emit a block event.

        Uses the global emit function via ContextVar.
        Works automatically when called within agent.run() context.

        Before forwarding, fills in ``session_id`` / ``invocation_id`` on the
        block (in place) when the block has those attributes and they are
        empty.

        Args:
            block: BlockEvent to emit
        """
        # Imported lazily, presumably to avoid an import cycle with ..context.
        from ..context import emit as global_emit

        # Fill in IDs if not set
        if hasattr(block, 'session_id') and not block.session_id:
            block.session_id = self.session_id
        if hasattr(block, 'invocation_id') and not block.invocation_id:
            block.invocation_id = self.invocation_id

        await global_emit(block)

    async def emit_hitl(self, request_id: str, data: dict[str, Any]) -> None:
        """Emit a HITL request block.

        Convenience method for tools that need user interaction.
        The data format is flexible - can be anything the frontend understands:
        - Choice selection (choices, radio, checkbox)
        - Text input (text, textarea, number)
        - Confirmation (yes/no, approve/reject)
        - Rich content (product cards, file selection, etc.)

        Args:
            request_id: Unique ID for this HITL request. Always wins over any
                ``"request_id"`` key present in *data*.
            data: Arbitrary data dict for frontend to render.
                Common fields: type, question, choices, default, context
        """
        from .block import BlockEvent, BlockKind

        # Previously ``{"request_id": request_id, **data}`` let a stray
        # "request_id" key in *data* override the explicit argument. Drop that
        # key instead, keeping request_id as the first key of the payload.
        payload: dict[str, Any] = {"request_id": request_id}
        payload.update((key, value) for key, value in data.items() if key != "request_id")

        await self.emit(BlockEvent(
            kind=BlockKind.HITL_REQUEST,
            data=payload,
        ))
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass
class ToolResult:
    """Outcome of a tool call as reported back to the LLM.

    ``output`` is the text the model sees; ``is_error`` flags failures.
    """

    output: str
    is_error: bool = False

    @classmethod
    def success(cls, output: str) -> ToolResult:
        """Build a non-error result carrying *output*."""
        return cls(output, False)

    @classmethod
    def error(cls, message: str) -> ToolResult:
        """Build an error result whose output text is *message*."""
        return cls(message, True)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class ToolInvocationState(Enum):
    """Tool invocation state machine.

    ``ToolInvocation`` starts in PARTIAL_CALL, moves to CALL via
    ``mark_call_complete`` and to RESULT via ``mark_result``.
    """

    PARTIAL_CALL = "partial-call"  # Arguments streaming
    CALL = "call"  # Arguments complete, ready to execute
    RESULT = "result"  # Execution complete
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass
class ToolInvocation:
    """Tracks one tool call through the PARTIAL_CALL -> CALL -> RESULT states.

    ``args_raw`` holds the raw (streamed) JSON argument text, ``args`` the
    parsed arguments. ``time`` records when arguments completed (``"start"``)
    and when the result arrived (``"end"``).
    """

    tool_call_id: str
    tool_name: str
    state: ToolInvocationState = ToolInvocationState.PARTIAL_CALL
    args: dict[str, Any] = field(default_factory=dict)
    args_raw: str = ""  # Raw JSON string for streaming
    result: str | None = None
    is_error: bool = False

    # Timing: both entries start as None
    time: dict[str, datetime | None] = field(
        default_factory=lambda: {"start": None, "end": None}
    )
    metadata: dict[str, Any] = field(default_factory=dict)

    def mark_call_complete(self) -> None:
        """Arguments are fully received: enter CALL and stamp ``time["start"]``."""
        self.state = ToolInvocationState.CALL
        self.time.update(start=datetime.now())

    def mark_result(self, result: str, is_error: bool = False) -> None:
        """Execution finished: enter RESULT, store the outcome, stamp ``time["end"]``."""
        self.state, self.result, self.is_error = ToolInvocationState.RESULT, result, is_error
        self.time.update(end=datetime.now())

    @property
    def duration_ms(self) -> int | None:
        """Milliseconds from ``time["start"]`` to ``time["end"]``; None until both are set."""
        began = self.time["start"]
        if not began:
            return None
        ended = self.time["end"]
        if not ended:
            return None
        elapsed = ended - began
        return int(elapsed.total_seconds() * 1000)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class BaseTool:
    """Convenience base class for tools.

    Subclasses customise the class attributes ``_name``, ``_description``,
    ``_parameters`` (a JSON Schema) and optionally ``_config``, and implement
    ``execute``. The read-only properties expose those attributes.
    """

    _name: str = "base_tool"
    _description: str = "Base tool"
    _parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
    _config: ToolConfig | None = None

    @property
    def name(self) -> str:
        """Tool name as advertised to the LLM."""
        return self._name

    @property
    def description(self) -> str:
        """Human/LLM-readable description of the tool."""
        return self._description

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON Schema of the tool's arguments (the class-level object itself)."""
        return self._parameters

    @property
    def config(self) -> ToolConfig:
        """Tool config; a fresh default ``ToolConfig()`` when none was set."""
        return self._config if self._config else ToolConfig()

    async def execute(self, params: dict[str, Any], ctx: ToolContext) -> ToolResult:
        """Run the tool. Subclasses must override this."""
        raise NotImplementedError("Subclass must implement execute()")

    def get_info(self) -> ToolInfo:
        """Bundle name, description and parameters into a ``ToolInfo``."""
        return ToolInfo(self.name, self.description, self.parameters)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
@dataclass
class ToolConfig:
    """Tool configuration.

    ``BaseTool.config`` returns a default instance when a tool sets none.
    The code that enforces these settings is not part of this class.
    """

    is_resumable: bool = False  # Supports pause/resume
    timeout: float | None = None  # Execution timeout in seconds (None presumably = no timeout; confirm)
    requires_permission: bool = False  # Needs HITL approval
    permission_message: str | None = None  # NOTE(review): presumably shown in the approval prompt — confirm

    # Retry configuration
    max_retries: int = 0  # 0 = no retry
    retry_delay: float = 1.0  # Base delay between retries (seconds)
    retry_backoff: float = 2.0  # Exponential backoff multiplier
|