aury-agent 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149) hide show
  1. aury/__init__.py +2 -0
  2. aury/agents/__init__.py +55 -0
  3. aury/agents/a2a/__init__.py +168 -0
  4. aury/agents/backends/__init__.py +196 -0
  5. aury/agents/backends/artifact/__init__.py +9 -0
  6. aury/agents/backends/artifact/memory.py +130 -0
  7. aury/agents/backends/artifact/types.py +133 -0
  8. aury/agents/backends/code/__init__.py +65 -0
  9. aury/agents/backends/file/__init__.py +11 -0
  10. aury/agents/backends/file/local.py +66 -0
  11. aury/agents/backends/file/types.py +40 -0
  12. aury/agents/backends/invocation/__init__.py +8 -0
  13. aury/agents/backends/invocation/memory.py +81 -0
  14. aury/agents/backends/invocation/types.py +110 -0
  15. aury/agents/backends/memory/__init__.py +8 -0
  16. aury/agents/backends/memory/memory.py +179 -0
  17. aury/agents/backends/memory/types.py +136 -0
  18. aury/agents/backends/message/__init__.py +9 -0
  19. aury/agents/backends/message/memory.py +122 -0
  20. aury/agents/backends/message/types.py +124 -0
  21. aury/agents/backends/sandbox.py +275 -0
  22. aury/agents/backends/session/__init__.py +8 -0
  23. aury/agents/backends/session/memory.py +93 -0
  24. aury/agents/backends/session/types.py +124 -0
  25. aury/agents/backends/shell/__init__.py +11 -0
  26. aury/agents/backends/shell/local.py +110 -0
  27. aury/agents/backends/shell/types.py +55 -0
  28. aury/agents/backends/shell.py +209 -0
  29. aury/agents/backends/snapshot/__init__.py +19 -0
  30. aury/agents/backends/snapshot/git.py +95 -0
  31. aury/agents/backends/snapshot/hybrid.py +125 -0
  32. aury/agents/backends/snapshot/memory.py +86 -0
  33. aury/agents/backends/snapshot/types.py +59 -0
  34. aury/agents/backends/state/__init__.py +29 -0
  35. aury/agents/backends/state/composite.py +49 -0
  36. aury/agents/backends/state/file.py +57 -0
  37. aury/agents/backends/state/memory.py +52 -0
  38. aury/agents/backends/state/sqlite.py +262 -0
  39. aury/agents/backends/state/types.py +178 -0
  40. aury/agents/backends/subagent/__init__.py +165 -0
  41. aury/agents/cli/__init__.py +41 -0
  42. aury/agents/cli/chat.py +239 -0
  43. aury/agents/cli/config.py +236 -0
  44. aury/agents/cli/extensions.py +460 -0
  45. aury/agents/cli/main.py +189 -0
  46. aury/agents/cli/session.py +337 -0
  47. aury/agents/cli/workflow.py +276 -0
  48. aury/agents/context_providers/__init__.py +66 -0
  49. aury/agents/context_providers/artifact.py +299 -0
  50. aury/agents/context_providers/base.py +177 -0
  51. aury/agents/context_providers/memory.py +70 -0
  52. aury/agents/context_providers/message.py +130 -0
  53. aury/agents/context_providers/skill.py +50 -0
  54. aury/agents/context_providers/subagent.py +46 -0
  55. aury/agents/context_providers/tool.py +68 -0
  56. aury/agents/core/__init__.py +83 -0
  57. aury/agents/core/base.py +573 -0
  58. aury/agents/core/context.py +797 -0
  59. aury/agents/core/context_builder.py +303 -0
  60. aury/agents/core/event_bus/__init__.py +15 -0
  61. aury/agents/core/event_bus/bus.py +203 -0
  62. aury/agents/core/factory.py +169 -0
  63. aury/agents/core/isolator.py +97 -0
  64. aury/agents/core/logging.py +95 -0
  65. aury/agents/core/parallel.py +194 -0
  66. aury/agents/core/runner.py +139 -0
  67. aury/agents/core/services/__init__.py +5 -0
  68. aury/agents/core/services/file_session.py +144 -0
  69. aury/agents/core/services/message.py +53 -0
  70. aury/agents/core/services/session.py +53 -0
  71. aury/agents/core/signals.py +109 -0
  72. aury/agents/core/state.py +363 -0
  73. aury/agents/core/types/__init__.py +107 -0
  74. aury/agents/core/types/action.py +176 -0
  75. aury/agents/core/types/artifact.py +135 -0
  76. aury/agents/core/types/block.py +736 -0
  77. aury/agents/core/types/message.py +350 -0
  78. aury/agents/core/types/recall.py +144 -0
  79. aury/agents/core/types/session.py +257 -0
  80. aury/agents/core/types/subagent.py +154 -0
  81. aury/agents/core/types/tool.py +205 -0
  82. aury/agents/eval/__init__.py +331 -0
  83. aury/agents/hitl/__init__.py +57 -0
  84. aury/agents/hitl/ask_user.py +242 -0
  85. aury/agents/hitl/compaction.py +230 -0
  86. aury/agents/hitl/exceptions.py +87 -0
  87. aury/agents/hitl/permission.py +617 -0
  88. aury/agents/hitl/revert.py +216 -0
  89. aury/agents/llm/__init__.py +31 -0
  90. aury/agents/llm/adapter.py +367 -0
  91. aury/agents/llm/openai.py +294 -0
  92. aury/agents/llm/provider.py +476 -0
  93. aury/agents/mcp/__init__.py +153 -0
  94. aury/agents/memory/__init__.py +46 -0
  95. aury/agents/memory/compaction.py +394 -0
  96. aury/agents/memory/manager.py +465 -0
  97. aury/agents/memory/processor.py +177 -0
  98. aury/agents/memory/store.py +187 -0
  99. aury/agents/memory/types.py +137 -0
  100. aury/agents/messages/__init__.py +40 -0
  101. aury/agents/messages/config.py +47 -0
  102. aury/agents/messages/raw_store.py +224 -0
  103. aury/agents/messages/store.py +118 -0
  104. aury/agents/messages/types.py +88 -0
  105. aury/agents/middleware/__init__.py +31 -0
  106. aury/agents/middleware/base.py +341 -0
  107. aury/agents/middleware/chain.py +342 -0
  108. aury/agents/middleware/message.py +129 -0
  109. aury/agents/middleware/message_container.py +126 -0
  110. aury/agents/middleware/raw_message.py +153 -0
  111. aury/agents/middleware/truncation.py +139 -0
  112. aury/agents/middleware/types.py +81 -0
  113. aury/agents/plugin.py +162 -0
  114. aury/agents/react/__init__.py +4 -0
  115. aury/agents/react/agent.py +1923 -0
  116. aury/agents/sandbox/__init__.py +23 -0
  117. aury/agents/sandbox/local.py +239 -0
  118. aury/agents/sandbox/remote.py +200 -0
  119. aury/agents/sandbox/types.py +115 -0
  120. aury/agents/skill/__init__.py +16 -0
  121. aury/agents/skill/loader.py +180 -0
  122. aury/agents/skill/types.py +83 -0
  123. aury/agents/tool/__init__.py +39 -0
  124. aury/agents/tool/builtin/__init__.py +23 -0
  125. aury/agents/tool/builtin/ask_user.py +155 -0
  126. aury/agents/tool/builtin/bash.py +107 -0
  127. aury/agents/tool/builtin/delegate.py +726 -0
  128. aury/agents/tool/builtin/edit.py +121 -0
  129. aury/agents/tool/builtin/plan.py +277 -0
  130. aury/agents/tool/builtin/read.py +91 -0
  131. aury/agents/tool/builtin/thinking.py +111 -0
  132. aury/agents/tool/builtin/yield_result.py +130 -0
  133. aury/agents/tool/decorator.py +252 -0
  134. aury/agents/tool/set.py +204 -0
  135. aury/agents/usage/__init__.py +12 -0
  136. aury/agents/usage/tracker.py +236 -0
  137. aury/agents/workflow/__init__.py +85 -0
  138. aury/agents/workflow/adapter.py +268 -0
  139. aury/agents/workflow/dag.py +116 -0
  140. aury/agents/workflow/dsl.py +575 -0
  141. aury/agents/workflow/executor.py +659 -0
  142. aury/agents/workflow/expression.py +136 -0
  143. aury/agents/workflow/parser.py +182 -0
  144. aury/agents/workflow/state.py +145 -0
  145. aury/agents/workflow/types.py +86 -0
  146. aury_agent-0.0.4.dist-info/METADATA +90 -0
  147. aury_agent-0.0.4.dist-info/RECORD +149 -0
  148. aury_agent-0.0.4.dist-info/WHEEL +4 -0
  149. aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,257 @@
1
+ """Session and Invocation data structures."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass, field
5
+ from datetime import datetime
6
+ from enum import Enum
7
+ from typing import Any
8
+ from uuid import uuid4
9
+
10
+
11
def generate_id(prefix: str = "") -> str:
    """Return a new random identifier, optionally namespaced by *prefix*.

    The random part is the first 12 hex characters of a UUID4. A non-empty
    *prefix* is joined to it with an underscore (e.g. ``"sess_1a2b3c4d5e6f"``);
    an empty prefix yields just the 12 hex characters.
    """
    token = uuid4().hex[:12]
    if not prefix:
        return token
    return f"{prefix}_{token}"
15
+
16
+
17
class InvocationState(Enum):
    """Lifecycle states an invocation can be in (serialized by ``.value``)."""

    PENDING = "pending"  # created, not started yet
    RUNNING = "running"  # currently executing
    SUSPENDED = "suspended"  # waiting on human-in-the-loop input
    PAUSED = "paused"  # paused by the user
    COMPLETED = "completed"  # finished normally
    FAILED = "failed"  # ended with an error
    CANCELLED = "cancelled"  # cancelled
    ABORTED = "aborted"  # stopped by the user
    SWITCHED = "switched"  # ended because the user switched agent
28
+
29
+
30
class InvocationMode(Enum):
    """How an invocation was started.

    Only two modes create an Invocation record; embedded sub-agent runs reuse
    the parent's invocation and therefore have no member here.
    """

    ROOT = "root"  # top-level invocation of a session
    DELEGATED = "delegated"  # invocation run by a delegated sub-agent
35
+
36
+
37
@dataclass
class ControlFrame:
    """One entry of a session's delegation control stack.

    Records which agent holds control, the invocation it runs under, when it
    took over, and (optionally) the parent invocation it was delegated from.
    """
    agent_id: str
    invocation_id: str
    entered_at: datetime = field(default_factory=datetime.now)
    parent_invocation_id: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; ``entered_at`` becomes an ISO-8601 string."""
        payload: dict[str, Any] = dict(
            agent_id=self.agent_id,
            invocation_id=self.invocation_id,
        )
        payload["entered_at"] = self.entered_at.isoformat()
        payload["parent_invocation_id"] = self.parent_invocation_id
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ControlFrame":
        """Rebuild a frame from the output of :meth:`to_dict`.

        ``agent_id``, ``invocation_id`` and ``entered_at`` are required
        (KeyError otherwise); ``parent_invocation_id`` defaults to None.
        """
        return cls(
            data["agent_id"],
            data["invocation_id"],
            datetime.fromisoformat(data["entered_at"]),
            data.get("parent_invocation_id"),
        )
61
+
62
+
63
@dataclass
class Session:
    """Session - container for conversations.

    A Session represents a conversation thread that can contain multiple
    invocations (turns). Sessions can be nested via ``parent_id`` (e.g. forked
    sessions). ``control_stack`` tracks delegated agents; the top frame (or
    ``root_agent_id`` when the stack is empty) is the active agent.

    ``to_dict``/``from_dict`` shallow-copy the ``metadata`` and ``revert``
    dicts so a serialized snapshot and the live session never share the same
    dict objects (mutating one must not silently change the other).
    """
    id: str = field(default_factory=lambda: generate_id("sess"))
    root_agent_id: str = ""  # Root agent for this session
    parent_id: str | None = None  # For forked sessions
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)  # title, etc.
    is_active: bool = True

    # Control stack for DELEGATED mode
    control_stack: list[ControlFrame] = field(default_factory=list)

    # Revert state (if currently in reverted state)
    revert: dict[str, Any] | None = None

    @property
    def active_agent_id(self) -> str:
        """Get the currently active agent (top of control stack or root)."""
        if self.control_stack:
            return self.control_stack[-1].agent_id
        return self.root_agent_id

    def push_control(self, frame: ControlFrame) -> None:
        """Push control frame when delegating to sub-agent; bumps ``updated_at``."""
        self.control_stack.append(frame)
        self.updated_at = datetime.now()

    def pop_control(self) -> ControlFrame | None:
        """Pop control frame when sub-agent returns control.

        Returns the removed frame (bumping ``updated_at``), or None if the
        stack is already empty (``updated_at`` left untouched).
        """
        if self.control_stack:
            self.updated_at = datetime.now()
            return self.control_stack.pop()
        return None

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization.

        Timestamps become ISO-8601 strings; ``metadata`` and ``revert`` are
        shallow-copied so the result does not alias this session's state.
        """
        return {
            "id": self.id,
            "root_agent_id": self.root_agent_id,
            "parent_id": self.parent_id,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "metadata": dict(self.metadata),
            "is_active": self.is_active,
            "control_stack": [f.to_dict() for f in self.control_stack],
            "revert": dict(self.revert) if self.revert is not None else None,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Session":
        """Create from dictionary.

        ``id``, ``created_at`` and ``updated_at`` are required (KeyError
        otherwise). Missing or None ``metadata``/``control_stack`` fall back
        to empty containers. Dicts taken from *data* are shallow-copied so the
        new session does not alias the caller's dictionary.
        """
        revert = data.get("revert")
        return cls(
            id=data["id"],
            root_agent_id=data.get("root_agent_id", ""),
            parent_id=data.get("parent_id"),
            created_at=datetime.fromisoformat(data["created_at"]),
            updated_at=datetime.fromisoformat(data["updated_at"]),
            metadata=dict(data.get("metadata") or {}),
            is_active=data.get("is_active", True),
            control_stack=[
                ControlFrame.from_dict(f) for f in (data.get("control_stack") or [])
            ],
            revert=dict(revert) if revert is not None else None,
        )
131
+
132
+
133
@dataclass
class Invocation:
    """A single invocation (turn) within a session.

    An Invocation represents one user input and the agent's response,
    including all tool calls and intermediate steps. The ``mark_*`` helpers
    set ``state``; ``mark_completed``/``mark_failed``/``mark_cancelled``/
    ``mark_aborted``/``mark_switched`` also stamp ``finished_at``, while
    ``mark_paused``/``mark_suspended`` only change ``state``.

    ``to_dict``/``from_dict`` shallow-copy ``pending_tool_ids``, ``metadata``
    and ``agent_state`` so a serialized snapshot and the live object never
    share the same list/dict instances.
    """
    id: str = field(default_factory=lambda: generate_id("inv"))
    session_id: str = ""
    agent_id: str = ""  # Which agent executes this invocation
    mode: InvocationMode = InvocationMode.ROOT  # ROOT or DELEGATED
    state: InvocationState = InvocationState.PENDING

    # Relationship (tree structure)
    parent_invocation_id: str | None = None  # Parent invocation (for DELEGATED)

    # Timestamps
    created_at: datetime = field(default_factory=datetime.now)
    started_at: datetime | None = None
    finished_at: datetime | None = None

    # Agent state for resumption
    agent_state: dict[str, Any] | None = None
    pending_tool_ids: list[str] = field(default_factory=list)

    # Execution info
    step_count: int = 0
    snapshot_id: str | None = None  # Pre-execution snapshot for revert

    # Error info
    error: str | None = None

    # Branch (for state isolation)
    branch: str | None = None

    metadata: dict[str, Any] = field(default_factory=dict)

    def mark_started(self) -> None:
        """Mark invocation as started (overwrites any previous ``started_at``)."""
        self.state = InvocationState.RUNNING
        self.started_at = datetime.now()

    def mark_completed(self) -> None:
        """Mark invocation as completed."""
        self.state = InvocationState.COMPLETED
        self.finished_at = datetime.now()

    def mark_failed(self, error: str) -> None:
        """Mark invocation as failed and record the error message."""
        self.state = InvocationState.FAILED
        self.error = error
        self.finished_at = datetime.now()

    def mark_cancelled(self) -> None:
        """Mark invocation as cancelled."""
        self.state = InvocationState.CANCELLED
        self.finished_at = datetime.now()

    def mark_aborted(self) -> None:
        """Mark invocation as aborted by user."""
        self.state = InvocationState.ABORTED
        self.finished_at = datetime.now()

    def mark_switched(self) -> None:
        """Mark invocation as ended due to agent switch."""
        self.state = InvocationState.SWITCHED
        self.finished_at = datetime.now()

    def mark_paused(self) -> None:
        """Mark invocation as paused (no timestamp is recorded)."""
        self.state = InvocationState.PAUSED

    def mark_suspended(self) -> None:
        """Mark invocation as suspended (HITL); no timestamp is recorded."""
        self.state = InvocationState.SUSPENDED

    @property
    def duration_ms(self) -> int | None:
        """Execution duration in whole milliseconds, or None if not both timestamps are set."""
        if self.started_at and self.finished_at:
            return int((self.finished_at - self.started_at).total_seconds() * 1000)
        return None

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization.

        Enums are stored by value and datetimes as ISO-8601 strings. Mutable
        containers are shallow-copied so the result does not alias this
        invocation's state.
        """
        return {
            "id": self.id,
            "session_id": self.session_id,
            "agent_id": self.agent_id,
            "mode": self.mode.value,
            "state": self.state.value,
            "parent_invocation_id": self.parent_invocation_id,
            "created_at": self.created_at.isoformat(),
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "finished_at": self.finished_at.isoformat() if self.finished_at else None,
            "agent_state": dict(self.agent_state) if self.agent_state is not None else None,
            "pending_tool_ids": list(self.pending_tool_ids),
            "step_count": self.step_count,
            "snapshot_id": self.snapshot_id,
            "error": self.error,
            "branch": self.branch,
            "metadata": dict(self.metadata),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Invocation":
        """Create from dictionary.

        ``id``, ``session_id``, ``state`` and ``created_at`` are required
        (KeyError otherwise). A missing or None ``mode`` defaults to ``"root"``;
        missing or None ``pending_tool_ids``/``metadata`` become empty
        containers. Containers taken from *data* are shallow-copied.
        """
        agent_state = data.get("agent_state")
        return cls(
            id=data["id"],
            session_id=data["session_id"],
            agent_id=data.get("agent_id", ""),
            mode=InvocationMode(data.get("mode") or "root"),
            state=InvocationState(data["state"]),
            parent_invocation_id=data.get("parent_invocation_id"),
            created_at=datetime.fromisoformat(data["created_at"]),
            started_at=datetime.fromisoformat(data["started_at"]) if data.get("started_at") else None,
            finished_at=datetime.fromisoformat(data["finished_at"]) if data.get("finished_at") else None,
            agent_state=dict(agent_state) if agent_state is not None else None,
            pending_tool_ids=list(data.get("pending_tool_ids") or []),
            step_count=data.get("step_count", 0),
            snapshot_id=data.get("snapshot_id"),
            error=data.get("error"),
            branch=data.get("branch"),
            metadata=dict(data.get("metadata") or {}),
        )
@@ -0,0 +1,154 @@
1
+ """SubAgent input/output types for agent delegation."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass, field
5
+ from datetime import datetime
6
+ from enum import Enum
7
+ from typing import Any, Literal, TypedDict, NotRequired, TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ from .message import Message
11
+
12
+
13
class SubAgentMode(Enum):
    """How a sub-agent is executed relative to its parent."""

    EMBEDDED = "embedded"  # runs inside the parent's invocation
    DELEGATED = "delegated"  # gets its own, new invocation
17
+
18
+
19
class _SubAgentInputBase(TypedDict):
    """Required keys of :class:`SubAgentInput`."""

    agent: str  # Agent key


class SubAgentInput(_SubAgentInputBase, total=False):
    """Input for SubAgent invocation.

    Input supplied by the LLM:
    - agent: which agent to invoke (required)
    - task_context: task context (user intent, background, requirements, etc.)
    - artifact_refs: references to relevant materials

    Other configuration (mode, inherit_messages, summary_mode, etc.) is
    defined in AgentConfig and is not passed in by the LLM.

    The optional keys are declared via ``total=False`` on this subclass rather
    than with ``NotRequired[...]``: this module uses postponed (string)
    annotations, and TypedDict cannot see ``NotRequired`` inside a string, so
    ``__required_keys__`` would otherwise wrongly list every key as required.
    """

    # Task context - describe the user's intent, background and concrete
    # requirements as fully as possible
    task_context: str

    # Artifact references - relevant materials [{id, summary}, ...]
    artifact_refs: list[dict[str, str]]
38
+
39
+
40
@dataclass
class SubAgentMetadata:
    """Execution statistics reported for one sub-agent run.

    ``agent_type`` is expected to be ``"react"`` or ``"workflow"``; when it
    is missing from serialized data, :meth:`from_dict` assumes ``"react"``.
    """
    child_invocation_id: str
    agent_name: str
    agent_type: str  # "react" | "workflow"
    steps: int = 0
    duration_ms: int = 0
    token_usage: dict[str, int] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize every field, in declaration order, into a plain dict."""
        names = (
            "child_invocation_id",
            "agent_name",
            "agent_type",
            "steps",
            "duration_ms",
            "token_usage",
        )
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SubAgentMetadata":
        """Rebuild from :meth:`to_dict` output.

        ``child_invocation_id`` and ``agent_name`` are required (KeyError
        otherwise); the remaining fields fall back to their defaults.
        """
        return cls(
            data["child_invocation_id"],
            data["agent_name"],
            data.get("agent_type", "react"),
            data.get("steps", 0),
            data.get("duration_ms", 0),
            data.get("token_usage", {}),
        )
70
+
71
+
72
@dataclass
class SubAgentResult:
    """Result returned by sub-agent to parent agent.

    Contains both text output (for LLM context) and structured data. The
    ``completed``/``failed``/``aborted`` factories all accept optional
    execution ``metadata``, so failed or aborted runs can report their child
    invocation and statistics just like completed ones.
    """
    # Text output (summary for LLM)
    output: str

    # Execution status
    status: Literal["completed", "aborted", "failed", "switched"]

    # Structured data (optional)
    data: dict[str, Any] | None = None

    # State changes to merge back to parent
    state_updates: dict[str, Any] | None = None

    # Error info (when failed)
    error: str | None = None

    # Execution metadata
    metadata: SubAgentMetadata | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; ``metadata`` is serialized via its own ``to_dict``."""
        return {
            "output": self.output,
            "status": self.status,
            "data": self.data,
            "state_updates": self.state_updates,
            "error": self.error,
            "metadata": self.metadata.to_dict() if self.metadata else None,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SubAgentResult":
        """Rebuild from :meth:`to_dict` output (``output`` and ``status`` are required)."""
        return cls(
            output=data["output"],
            status=data["status"],
            data=data.get("data"),
            state_updates=data.get("state_updates"),
            error=data.get("error"),
            metadata=SubAgentMetadata.from_dict(data["metadata"]) if data.get("metadata") else None,
        )

    @classmethod
    def completed(
        cls,
        output: str,
        data: dict[str, Any] | None = None,
        state_updates: dict[str, Any] | None = None,
        metadata: SubAgentMetadata | None = None,
    ) -> "SubAgentResult":
        """Create a completed result."""
        return cls(
            output=output,
            status="completed",
            data=data,
            state_updates=state_updates,
            metadata=metadata,
        )

    @classmethod
    def failed(
        cls,
        error: str,
        output: str = "",
        metadata: SubAgentMetadata | None = None,
    ) -> "SubAgentResult":
        """Create a failed result.

        Args:
            error: Error message, stored in ``error``.
            output: Text for the parent LLM; defaults to "SubAgent failed: <error>".
            metadata: Optional execution metadata of the failed run.
        """
        return cls(
            output=output or f"SubAgent failed: {error}",
            status="failed",
            error=error,
            metadata=metadata,
        )

    @classmethod
    def aborted(
        cls,
        output: str = "SubAgent was aborted",
        metadata: SubAgentMetadata | None = None,
    ) -> "SubAgentResult":
        """Create an aborted result, optionally carrying execution metadata."""
        return cls(output=output, status="aborted", metadata=metadata)
147
+
148
+
149
# Public API of this module.
__all__ = ["SubAgentMode", "SubAgentInput", "SubAgentMetadata", "SubAgentResult"]
@@ -0,0 +1,205 @@
1
+ """Tool-related type definitions."""
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ from dataclasses import dataclass, field
6
+ from datetime import datetime
7
+ from enum import Enum
8
+ from typing import Any, Awaitable, Callable
9
+
10
+ from .session import generate_id
11
+
12
+
13
@dataclass
class ToolInfo:
    """Name, description and JSON-Schema parameters of a tool, as exposed to the LLM."""
    name: str
    description: str
    parameters: dict[str, Any]  # JSON Schema

    def to_llm_schema(self) -> dict[str, Any]:
        """Return the tool definition with the parameter schema under ``input_schema``."""
        schema: dict[str, Any] = {"name": self.name}
        schema["description"] = self.description
        schema["input_schema"] = self.parameters
        return schema
27
+
28
+
29
@dataclass
class ToolContext:
    """Context passed to tool execution.

    Carries the identifiers of the surrounding session/invocation/block/call,
    the abort signal, and a callback to update tool metadata. ``emit`` and
    ``emit_hitl`` forward block events through the global ``emit`` function.
    """
    session_id: str
    invocation_id: str
    block_id: str
    call_id: str
    agent: str
    abort_signal: asyncio.Event
    update_metadata: Callable[[dict[str, Any]], Awaitable[None]]

    # Optional usage tracker
    usage: Any | None = None  # UsageTracker

    # Branch for sub-agent isolation
    branch: str | None = None

    # Caller's middleware chain (for sub-agent inheritance)
    middleware: Any | None = None  # MiddlewareChain

    async def emit(self, block: Any) -> None:
        """Emit a block event.

        Uses the global emit function via ContextVar.
        Works automatically when called within agent.run() context.
        If *block* has empty ``session_id``/``invocation_id`` attributes, they
        are filled in place from this context before emitting.

        Args:
            block: BlockEvent to emit
        """
        from ..context import emit as global_emit

        # Fill in IDs if not set
        if hasattr(block, 'session_id') and not block.session_id:
            block.session_id = self.session_id
        if hasattr(block, 'invocation_id') and not block.invocation_id:
            block.invocation_id = self.invocation_id

        await global_emit(block)

    async def emit_hitl(self, request_id: str, data: dict[str, Any]) -> None:
        """Emit a HITL request block.

        Convenience method for tools that need user interaction.
        The data format is flexible - can be anything the frontend understands:
        - Choice selection (choices, radio, checkbox)
        - Text input (text, textarea, number)
        - Confirmation (yes/no, approve/reject)
        - Rich content (product cards, file selection, etc.)

        Args:
            request_id: Unique ID for this HITL request. It always takes
                precedence over any "request_id" key present in *data*.
            data: Arbitrary data dict for frontend to render.
                Common fields: type, question, choices, default, context
        """
        from .block import BlockEvent, BlockKind

        # Keep "request_id" as the first key, then force the explicit argument
        # to win: a stray "request_id" inside ``data`` must not silently
        # replace the id passed as the argument.
        payload: dict[str, Any] = {"request_id": request_id, **data}
        payload["request_id"] = request_id

        await self.emit(BlockEvent(
            kind=BlockKind.HITL_REQUEST,
            data=payload,
        ))
89
+
90
+
91
@dataclass
class ToolResult:
    """Outcome of a tool call as fed back to the LLM: text plus an error flag."""
    output: str
    is_error: bool = False

    @classmethod
    def success(cls, output: str) -> ToolResult:
        """Build a non-error result wrapping *output*."""
        return cls(output, False)

    @classmethod
    def error(cls, message: str) -> ToolResult:
        """Build an error result whose output text is *message*."""
        return cls(message, True)
106
+
107
+
108
class ToolInvocationState(Enum):
    """Progress of a single tool call, from streamed arguments to result."""

    PARTIAL_CALL = "partial-call"  # arguments still streaming in
    CALL = "call"  # arguments complete; ready to run
    RESULT = "result"  # execution finished
113
+
114
+
115
@dataclass
class ToolInvocation:
    """Tracks one tool call through PARTIAL_CALL -> CALL -> RESULT.

    ``args_raw`` holds the raw streamed JSON arguments. ``time["start"]`` is
    stamped when the arguments are complete and ``time["end"]`` when a result
    is recorded.
    """
    tool_call_id: str
    tool_name: str
    state: ToolInvocationState = ToolInvocationState.PARTIAL_CALL
    args: dict[str, Any] = field(default_factory=dict)
    args_raw: str = ""  # Raw JSON string for streaming
    result: str | None = None
    is_error: bool = False

    # Timing: "start" / "end" timestamps, both None until set
    time: dict[str, datetime | None] = field(
        default_factory=lambda: dict(start=None, end=None)
    )
    metadata: dict[str, Any] = field(default_factory=dict)

    def mark_call_complete(self) -> None:
        """Switch to CALL (arguments complete) and stamp the start time."""
        self.state = ToolInvocationState.CALL
        self.time["start"] = datetime.now()

    def mark_result(self, result: str, is_error: bool = False) -> None:
        """Switch to RESULT, store the outcome and stamp the end time."""
        self.state = ToolInvocationState.RESULT
        self.result = result
        self.is_error = is_error
        self.time["end"] = datetime.now()

    @property
    def duration_ms(self) -> int | None:
        """Whole milliseconds between start and end, or None if either is unset."""
        began = self.time["start"]
        if not began:
            return None
        ended = self.time["end"]
        if not ended:
            return None
        return int((ended - began).total_seconds() * 1000)
150
+
151
+
152
class BaseTool:
    """Convenience base class for tools.

    Subclasses override the class attributes ``_name``, ``_description``,
    ``_parameters`` (a JSON-Schema object) and optionally ``_config``, and
    implement :meth:`execute`. Note that ``parameters`` returns the class-level
    dict itself (not a copy), so it is shared by instances that don't override it.
    """

    _name: str = "base_tool"
    _description: str = "Base tool"
    _parameters: dict[str, Any] = {
        "type": "object",
        "properties": {},
        "required": [],
    }
    _config: ToolConfig | None = None

    @property
    def name(self) -> str:
        """Tool name (``_name``)."""
        return self._name

    @property
    def description(self) -> str:
        """Tool description (``_description``)."""
        return self._description

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-Schema for the tool's parameters (``_parameters``)."""
        return self._parameters

    @property
    def config(self) -> ToolConfig:
        """The configured ``_config``, or a newly built default ``ToolConfig``."""
        configured = self._config
        return configured or ToolConfig()

    async def execute(self, params: dict[str, Any], ctx: ToolContext) -> ToolResult:
        """Run the tool; subclasses must override this."""
        raise NotImplementedError("Subclass must implement execute()")

    def get_info(self) -> ToolInfo:
        """Bundle name, description and parameters into a ``ToolInfo``."""
        return ToolInfo(self.name, self.description, self.parameters)
192
+
193
+
194
@dataclass
class ToolConfig:
    """Per-tool execution policy.

    Covers pause/resume support, an optional timeout, whether HITL approval
    is needed (with an optional prompt), and retry settings.
    """

    # Execution behaviour
    is_resumable: bool = False  # tool supports pause/resume
    timeout: float | None = None  # timeout in seconds; None = not set
    requires_permission: bool = False  # needs HITL approval
    permission_message: str | None = None

    # Retries (max_retries == 0 disables retrying)
    max_retries: int = 0
    retry_delay: float = 1.0  # base delay between attempts, in seconds
    retry_backoff: float = 2.0  # exponential backoff multiplier