aury-agent 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/__init__.py +2 -0
- aury/agents/__init__.py +55 -0
- aury/agents/a2a/__init__.py +168 -0
- aury/agents/backends/__init__.py +196 -0
- aury/agents/backends/artifact/__init__.py +9 -0
- aury/agents/backends/artifact/memory.py +130 -0
- aury/agents/backends/artifact/types.py +133 -0
- aury/agents/backends/code/__init__.py +65 -0
- aury/agents/backends/file/__init__.py +11 -0
- aury/agents/backends/file/local.py +66 -0
- aury/agents/backends/file/types.py +40 -0
- aury/agents/backends/invocation/__init__.py +8 -0
- aury/agents/backends/invocation/memory.py +81 -0
- aury/agents/backends/invocation/types.py +110 -0
- aury/agents/backends/memory/__init__.py +8 -0
- aury/agents/backends/memory/memory.py +179 -0
- aury/agents/backends/memory/types.py +136 -0
- aury/agents/backends/message/__init__.py +9 -0
- aury/agents/backends/message/memory.py +122 -0
- aury/agents/backends/message/types.py +124 -0
- aury/agents/backends/sandbox.py +275 -0
- aury/agents/backends/session/__init__.py +8 -0
- aury/agents/backends/session/memory.py +93 -0
- aury/agents/backends/session/types.py +124 -0
- aury/agents/backends/shell/__init__.py +11 -0
- aury/agents/backends/shell/local.py +110 -0
- aury/agents/backends/shell/types.py +55 -0
- aury/agents/backends/shell.py +209 -0
- aury/agents/backends/snapshot/__init__.py +19 -0
- aury/agents/backends/snapshot/git.py +95 -0
- aury/agents/backends/snapshot/hybrid.py +125 -0
- aury/agents/backends/snapshot/memory.py +86 -0
- aury/agents/backends/snapshot/types.py +59 -0
- aury/agents/backends/state/__init__.py +29 -0
- aury/agents/backends/state/composite.py +49 -0
- aury/agents/backends/state/file.py +57 -0
- aury/agents/backends/state/memory.py +52 -0
- aury/agents/backends/state/sqlite.py +262 -0
- aury/agents/backends/state/types.py +178 -0
- aury/agents/backends/subagent/__init__.py +165 -0
- aury/agents/cli/__init__.py +41 -0
- aury/agents/cli/chat.py +239 -0
- aury/agents/cli/config.py +236 -0
- aury/agents/cli/extensions.py +460 -0
- aury/agents/cli/main.py +189 -0
- aury/agents/cli/session.py +337 -0
- aury/agents/cli/workflow.py +276 -0
- aury/agents/context_providers/__init__.py +66 -0
- aury/agents/context_providers/artifact.py +299 -0
- aury/agents/context_providers/base.py +177 -0
- aury/agents/context_providers/memory.py +70 -0
- aury/agents/context_providers/message.py +130 -0
- aury/agents/context_providers/skill.py +50 -0
- aury/agents/context_providers/subagent.py +46 -0
- aury/agents/context_providers/tool.py +68 -0
- aury/agents/core/__init__.py +83 -0
- aury/agents/core/base.py +573 -0
- aury/agents/core/context.py +797 -0
- aury/agents/core/context_builder.py +303 -0
- aury/agents/core/event_bus/__init__.py +15 -0
- aury/agents/core/event_bus/bus.py +203 -0
- aury/agents/core/factory.py +169 -0
- aury/agents/core/isolator.py +97 -0
- aury/agents/core/logging.py +95 -0
- aury/agents/core/parallel.py +194 -0
- aury/agents/core/runner.py +139 -0
- aury/agents/core/services/__init__.py +5 -0
- aury/agents/core/services/file_session.py +144 -0
- aury/agents/core/services/message.py +53 -0
- aury/agents/core/services/session.py +53 -0
- aury/agents/core/signals.py +109 -0
- aury/agents/core/state.py +363 -0
- aury/agents/core/types/__init__.py +107 -0
- aury/agents/core/types/action.py +176 -0
- aury/agents/core/types/artifact.py +135 -0
- aury/agents/core/types/block.py +736 -0
- aury/agents/core/types/message.py +350 -0
- aury/agents/core/types/recall.py +144 -0
- aury/agents/core/types/session.py +257 -0
- aury/agents/core/types/subagent.py +154 -0
- aury/agents/core/types/tool.py +205 -0
- aury/agents/eval/__init__.py +331 -0
- aury/agents/hitl/__init__.py +57 -0
- aury/agents/hitl/ask_user.py +242 -0
- aury/agents/hitl/compaction.py +230 -0
- aury/agents/hitl/exceptions.py +87 -0
- aury/agents/hitl/permission.py +617 -0
- aury/agents/hitl/revert.py +216 -0
- aury/agents/llm/__init__.py +31 -0
- aury/agents/llm/adapter.py +367 -0
- aury/agents/llm/openai.py +294 -0
- aury/agents/llm/provider.py +476 -0
- aury/agents/mcp/__init__.py +153 -0
- aury/agents/memory/__init__.py +46 -0
- aury/agents/memory/compaction.py +394 -0
- aury/agents/memory/manager.py +465 -0
- aury/agents/memory/processor.py +177 -0
- aury/agents/memory/store.py +187 -0
- aury/agents/memory/types.py +137 -0
- aury/agents/messages/__init__.py +40 -0
- aury/agents/messages/config.py +47 -0
- aury/agents/messages/raw_store.py +224 -0
- aury/agents/messages/store.py +118 -0
- aury/agents/messages/types.py +88 -0
- aury/agents/middleware/__init__.py +31 -0
- aury/agents/middleware/base.py +341 -0
- aury/agents/middleware/chain.py +342 -0
- aury/agents/middleware/message.py +129 -0
- aury/agents/middleware/message_container.py +126 -0
- aury/agents/middleware/raw_message.py +153 -0
- aury/agents/middleware/truncation.py +139 -0
- aury/agents/middleware/types.py +81 -0
- aury/agents/plugin.py +162 -0
- aury/agents/react/__init__.py +4 -0
- aury/agents/react/agent.py +1923 -0
- aury/agents/sandbox/__init__.py +23 -0
- aury/agents/sandbox/local.py +239 -0
- aury/agents/sandbox/remote.py +200 -0
- aury/agents/sandbox/types.py +115 -0
- aury/agents/skill/__init__.py +16 -0
- aury/agents/skill/loader.py +180 -0
- aury/agents/skill/types.py +83 -0
- aury/agents/tool/__init__.py +39 -0
- aury/agents/tool/builtin/__init__.py +23 -0
- aury/agents/tool/builtin/ask_user.py +155 -0
- aury/agents/tool/builtin/bash.py +107 -0
- aury/agents/tool/builtin/delegate.py +726 -0
- aury/agents/tool/builtin/edit.py +121 -0
- aury/agents/tool/builtin/plan.py +277 -0
- aury/agents/tool/builtin/read.py +91 -0
- aury/agents/tool/builtin/thinking.py +111 -0
- aury/agents/tool/builtin/yield_result.py +130 -0
- aury/agents/tool/decorator.py +252 -0
- aury/agents/tool/set.py +204 -0
- aury/agents/usage/__init__.py +12 -0
- aury/agents/usage/tracker.py +236 -0
- aury/agents/workflow/__init__.py +85 -0
- aury/agents/workflow/adapter.py +268 -0
- aury/agents/workflow/dag.py +116 -0
- aury/agents/workflow/dsl.py +575 -0
- aury/agents/workflow/executor.py +659 -0
- aury/agents/workflow/expression.py +136 -0
- aury/agents/workflow/parser.py +182 -0
- aury/agents/workflow/state.py +145 -0
- aury/agents/workflow/types.py +86 -0
- aury_agent-0.0.4.dist-info/METADATA +90 -0
- aury_agent-0.0.4.dist-info/RECORD +149 -0
- aury_agent-0.0.4.dist-info/WHEEL +4 -0
- aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""Evaluation framework for agent testing.
|
|
2
|
+
|
|
3
|
+
TODO: Implement evaluation suite.
|
|
4
|
+
|
|
5
|
+
This module will provide:
|
|
6
|
+
- EvalSuite: Define and run evaluation test suites
|
|
7
|
+
- TestCase: Individual test case definition
|
|
8
|
+
- EvalResult: Test result with metrics
|
|
9
|
+
- Evaluators: Built-in evaluators (exact match, semantic, etc.)
|
|
10
|
+
|
|
11
|
+
Reference: Agent evaluation best practices
|
|
12
|
+
"""
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from enum import Enum
|
|
17
|
+
from typing import Any, Callable, Awaitable, TYPE_CHECKING
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from ..core.base import BaseAgent
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# =============================================================================
|
|
24
|
+
# Test Case
|
|
25
|
+
# =============================================================================
|
|
26
|
+
|
|
27
|
+
@dataclass
class TestCase:
    """A single evaluation test case.

    TODO: Implement test case execution.

    Usage:
        case = TestCase(
            name="simple_math",
            input="What is 2 + 2?",
            expected="4",
            evaluator="contains",
            tags=["math", "basic"],
        )
    """
    # Unique name identifying this case in reports.
    name: str
    # Prompt or structured payload passed to the agent under test.
    input: str | dict[str, Any]
    # Reference output; None when no reference answer exists.
    expected: str | dict[str, Any] | None = None
    # Key into the suite's evaluator registry (see EvalSuite.evaluators).
    evaluator: str = "exact"  # "exact", "contains", "semantic", "custom"
    # Async scorer for evaluator == "custom"; returns a score in [0.0, 1.0].
    custom_evaluator: Callable[[str, str], Awaitable[float]] | None = None
    # Labels used to select which cases a run executes (see filter_by_tags).
    tags: list[str] = field(default_factory=list)
    # Free-form extra data attached to the case.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Per-case timeout in seconds — presumably wall-clock; TODO confirm once run_case exists.
    timeout: float = 30.0
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# =============================================================================
|
|
53
|
+
# Eval Result
|
|
54
|
+
# =============================================================================
|
|
55
|
+
|
|
56
|
+
class EvalStatus(Enum):
    """Evaluation status."""
    PASS = "pass"    # case met its expectation (counted by SuiteResult.passed)
    FAIL = "fail"    # case completed but did not meet its expectation
    ERROR = "error"  # case could not complete (counted by SuiteResult.errors)
    SKIP = "skip"    # case was not executed
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class EvalResult:
    """Result of a single test case.

    TODO: Implement result tracking.
    """
    # Name of the TestCase this result belongs to.
    test_name: str
    # Outcome classification (pass/fail/error/skip).
    status: EvalStatus
    # Evaluator score in [0.0, 1.0].
    score: float  # 0.0 - 1.0
    # Raw agent output, when captured.
    actual_output: str | None = None
    # Reference output the evaluator compared against, when applicable.
    expected_output: str | None = None
    # Error details, meaningful when status is ERROR.
    error_message: str | None = None
    # Duration of the case in milliseconds.
    duration_ms: float = 0.0
    # Free-form extra data about this run.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass
class SuiteResult:
    """Aggregated outcome of one evaluation suite run.

    TODO: Implement suite-level metrics.
    """
    suite_name: str
    results: list[EvalResult] = field(default_factory=list)
    total_duration_ms: float = 0.0

    def _count(self, status: EvalStatus) -> int:
        # Number of recorded results carrying the given status.
        return len([r for r in self.results if r.status == status])

    @property
    def total(self) -> int:
        """Number of recorded results."""
        return len(self.results)

    @property
    def passed(self) -> int:
        """Number of results with PASS status."""
        return self._count(EvalStatus.PASS)

    @property
    def failed(self) -> int:
        """Number of results with FAIL status."""
        return self._count(EvalStatus.FAIL)

    @property
    def errors(self) -> int:
        """Number of results with ERROR status."""
        return self._count(EvalStatus.ERROR)

    @property
    def pass_rate(self) -> float:
        """Fraction of results that passed; 0.0 for an empty suite."""
        if not self.results:
            return 0.0
        return self.passed / self.total

    @property
    def avg_score(self) -> float:
        """Mean score across all results; 0.0 for an empty suite."""
        if not self.results:
            return 0.0
        return sum(r.score for r in self.results) / len(self.results)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# =============================================================================
# Evaluators
# =============================================================================

class Evaluator:
    """Base evaluator interface.

    Subclasses score an actual output against an expected reference,
    returning a float in [0.0, 1.0].
    """

    async def evaluate(
        self,
        actual: str,
        expected: str | None,
        context: dict[str, Any] | None = None,
    ) -> float:
        """Evaluate output against expected.

        Args:
            actual: Output produced by the agent.
            expected: Reference output, or None when no reference exists.
            context: Optional extra information for the evaluator.

        Returns:
            Score from 0.0 to 1.0

        Raises:
            NotImplementedError: Subclasses must override this method.
        """
        raise NotImplementedError("Evaluator subclasses must implement evaluate()")


class ExactMatchEvaluator(Evaluator):
    """Exact string match evaluator.

    Scores 1.0 when *actual* equals *expected* exactly, 0.0 otherwise.
    """

    async def evaluate(
        self,
        actual: str,
        expected: str | None,
        context: dict[str, Any] | None = None,
    ) -> float:
        """Return 1.0 on an exact string match, 0.0 otherwise.

        A missing reference (``expected is None``) scores 0.0, since an
        exact match cannot be established without one.
        """
        if expected is None:
            return 0.0
        return 1.0 if actual == expected else 0.0


class ContainsEvaluator(Evaluator):
    """Check if expected is contained in actual.

    Scores 1.0 when *expected* occurs as a substring of *actual*,
    0.0 otherwise.
    """

    async def evaluate(
        self,
        actual: str,
        expected: str | None,
        context: dict[str, Any] | None = None,
    ) -> float:
        """Return 1.0 when expected is a substring of actual, 0.0 otherwise.

        A missing reference (``expected is None``) scores 0.0.
        """
        if expected is None:
            return 0.0
        return 1.0 if expected in actual else 0.0


class SemanticEvaluator(Evaluator):
    """Semantic similarity evaluator using embeddings.

    TODO: Implement semantic evaluation (requires an embedding backend).
    """

    def __init__(self, threshold: float = 0.8):
        # Minimum similarity for a match — semantics TBD until implemented.
        self.threshold = threshold

    async def evaluate(
        self,
        actual: str,
        expected: str | None,
        context: dict[str, Any] | None = None,
    ) -> float:
        raise NotImplementedError("TODO: SemanticEvaluator not yet implemented")


class LLMJudgeEvaluator(Evaluator):
    """Use LLM as judge for evaluation.

    TODO: Implement LLM judge (requires an LLM provider).
    """

    def __init__(self, criteria: str | None = None):
        # Free-form judging criteria intended for the judge model.
        self.criteria = criteria

    async def evaluate(
        self,
        actual: str,
        expected: str | None,
        context: dict[str, Any] | None = None,
    ) -> float:
        raise NotImplementedError("TODO: LLMJudgeEvaluator not yet implemented")
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# =============================================================================
# EvalSuite
# =============================================================================

class EvalSuite:
    """Evaluation test suite.

    Holds the agent under test, the list of cases, and the evaluator
    registry used to score outputs. Execution (``run`` / ``run_case``)
    is still TODO and raises NotImplementedError; case management
    (``add_case`` / ``filter_by_tags``) is functional.

    Usage:
        suite = EvalSuite(
            name="math_tests",
            agent=my_agent,
            cases=[
                TestCase(name="add", input="2+2?", expected="4"),
                TestCase(name="mul", input="3*4?", expected="12"),
            ],
        )

        result = await suite.run()
        print(f"Pass rate: {result.pass_rate:.1%}")
    """

    def __init__(
        self,
        name: str,
        agent: "BaseAgent",
        cases: list[TestCase],
        evaluators: dict[str, Evaluator] | None = None,
        parallel: bool = False,
        max_workers: int = 4,
    ):
        """Configure the suite.

        Args:
            name: Suite name used in reports.
            agent: Agent under test.
            cases: Initial list of test cases.
            evaluators: Name -> evaluator registry; defaults to the
                built-in exact/contains/semantic/llm_judge set.
            parallel: Whether cases may run concurrently (for ``run``).
            max_workers: Concurrency limit when parallel is enabled.
        """
        self.name = name
        self.agent = agent
        self.cases = cases
        self.evaluators = evaluators or {
            "exact": ExactMatchEvaluator(),
            "contains": ContainsEvaluator(),
            "semantic": SemanticEvaluator(),
            "llm_judge": LLMJudgeEvaluator(),
        }
        self.parallel = parallel
        self.max_workers = max_workers
        # NOTE: construction intentionally does NOT raise. Previously
        # __init__ raised NotImplementedError after fully initializing,
        # which made the class unconstructable and left the implemented
        # add_case/filter_by_tags helpers unreachable. The unimplemented
        # entry points (run / run_case) signal their status themselves.

    async def run(
        self,
        tags: list[str] | None = None,
        verbose: bool = False,
    ) -> SuiteResult:
        """Run all test cases.

        Args:
            tags: Only run cases with these tags (None = all)
            verbose: Print progress

        Returns:
            Suite result with all test results
        """
        raise NotImplementedError("TODO: EvalSuite.run not yet implemented")

    async def run_case(self, case: TestCase) -> EvalResult:
        """Run a single test case.

        Args:
            case: Test case to run

        Returns:
            Evaluation result
        """
        raise NotImplementedError("TODO: EvalSuite.run_case not yet implemented")

    def add_case(self, case: TestCase) -> None:
        """Add a test case to the suite."""
        self.cases.append(case)

    def filter_by_tags(self, tags: list[str]) -> list[TestCase]:
        """Filter cases by tags (a case matches if it has any of *tags*)."""
        return [c for c in self.cases if any(t in c.tags for t in tags)]
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
# =============================================================================
|
|
288
|
+
# Helper functions
|
|
289
|
+
# =============================================================================
|
|
290
|
+
|
|
291
|
+
def test_case(
    name: str,
    input: str | dict[str, Any],
    expected: str | dict[str, Any] | None = None,
    evaluator: str = "exact",
    **kwargs: Any,
) -> TestCase:
    """Convenience function to create test cases.

    Args:
        name: Case name used in reports.
        input: Prompt or structured payload for the agent.
        expected: Reference output, if any.
        evaluator: Evaluator registry key (e.g. "exact", "contains").
        **kwargs: Forwarded to TestCase (tags, metadata, timeout, ...).

    Usage:
        cases = [
            test_case("simple", "2+2?", "4"),
            test_case("semantic", "hello", "hi there", evaluator="semantic"),
        ]
    """
    return TestCase(
        name=name,
        input=input,
        expected=expected,
        evaluator=evaluator,
        **kwargs,
    )
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# Public API of the eval package.
__all__ = [
    # Test case
    "TestCase",
    "test_case",
    # Results
    "EvalStatus",
    "EvalResult",
    "SuiteResult",
    # Evaluators
    "Evaluator",
    "ExactMatchEvaluator",
    "ContainsEvaluator",
    "SemanticEvaluator",
    "LLMJudgeEvaluator",
    # Suite
    "EvalSuite",
]
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""HITL (Human-in-the-Loop) components."""
|
|
2
|
+
from .compaction import (
|
|
3
|
+
SessionCompaction,
|
|
4
|
+
CompactionConfig,
|
|
5
|
+
)
|
|
6
|
+
from .revert import (
|
|
7
|
+
SessionRevert,
|
|
8
|
+
RevertState,
|
|
9
|
+
BlockBackend,
|
|
10
|
+
)
|
|
11
|
+
from .exceptions import (
|
|
12
|
+
SuspendSignal,
|
|
13
|
+
HITLSuspend,
|
|
14
|
+
HITLTimeoutError,
|
|
15
|
+
HITLCancelledError,
|
|
16
|
+
HITLRequest,
|
|
17
|
+
)
|
|
18
|
+
from .ask_user import (
|
|
19
|
+
AskUserTool,
|
|
20
|
+
ConfirmTool,
|
|
21
|
+
)
|
|
22
|
+
from .permission import (
|
|
23
|
+
Permission,
|
|
24
|
+
PermissionRules,
|
|
25
|
+
PermissionSpec,
|
|
26
|
+
RejectedError,
|
|
27
|
+
SkippedError,
|
|
28
|
+
HumanResponse,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
# Public re-export surface of the HITL package (mirrors the imports above).
__all__ = [
    # Compaction
    "SessionCompaction",
    "CompactionConfig",
    # Revert
    "SessionRevert",
    "RevertState",
    "BlockBackend",
    # Signals
    "SuspendSignal",
    "HITLSuspend",
    # Exceptions
    "HITLTimeoutError",
    "HITLCancelledError",
    # Types
    "HITLRequest",
    # Tools
    "AskUserTool",
    "ConfirmTool",
    # Permission
    "Permission",
    "PermissionRules",
    "PermissionSpec",
    "RejectedError",
    "SkippedError",
    "HumanResponse",
]
|
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
"""AskUser tool for LLM to request human input.
|
|
2
|
+
|
|
3
|
+
This tool allows the LLM to pause execution and ask the user
|
|
4
|
+
for clarification or additional information.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Any, TYPE_CHECKING
|
|
9
|
+
|
|
10
|
+
from ..tool import BaseTool, ToolResult
|
|
11
|
+
from ..core.types.session import generate_id
|
|
12
|
+
from ..core.types.block import BlockEvent
|
|
13
|
+
from ..core.signals import HITLSuspend
|
|
14
|
+
from .exceptions import HITLRequest
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from ..core.context import InvocationContext
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AskUserTool(BaseTool):
    """Tool that lets the LLM pause the run and ask the human a question.

    Executing it performs four steps:
    1. Checkpoints the current state (when the context exposes one).
    2. Marks the invocation as suspended, recording the pending request.
    3. Emits a ``hitl_request`` block so the frontend can render the question.
    4. Raises HITLSuspend to unwind out of the agent loop.

    The user's answer arrives later via agent.respond(request_id, response)
    or agent.run(response), which auto-detects the suspended state.

    HITLSuspend derives from BaseException, so generic ``except Exception``
    handlers will not swallow it.
    """

    _name = "ask_user"
    _description = "Ask the user for clarification or additional information. Use this when you need more details to complete a task."
    _parameters = {
        "type": "object",
        "properties": {
            "question": {
                "type": "string",
                "description": "The question to ask the user",
            },
            "options": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional list of suggested answers",
            },
            "context": {
                "type": "string",
                "description": "Additional context about why you're asking",
            },
        },
        "required": ["question"],
    }

    async def execute(
        self,
        question: str,
        options: list[str] | None = None,
        context: str | None = None,
        *,
        ctx: "InvocationContext | None" = None,
    ) -> ToolResult:
        """Suspend execution and forward *question* to the human.

        Args:
            question: Question to ask the user.
            options: Optional predefined answer options.
            context: Additional context about why the question is asked.
            ctx: Invocation context (injected by agent).

        Returns:
            Never returns normally when a context is present; without one
            it returns an error ToolResult.

        Raises:
            HITLSuspend: Always raised (given a context) to suspend execution.
        """
        if ctx is None:
            # No context means nothing to suspend; surface an error instead.
            return ToolResult.error("Cannot ask user without execution context")

        rid = generate_id("req")

        pending = HITLRequest(
            request_id=rid,
            request_type="ask_user",
            message=question,
            options=options,
            tool_name=self._name,
            metadata={"context": context} if context else {},
        )

        # Persist a checkpoint so the run can resume once the answer arrives.
        state = getattr(ctx, "state", None)
        if state is not None:
            await state.checkpoint()

        # Record the suspension so respond()/run() can locate the request.
        if ctx.backends and ctx.backends.invocation:
            patch = {
                "status": "suspended",
                "pending_request_id": rid,
                "pending_request_type": "ask_user",
                "pending_request_data": pending.to_dict(),
            }
            await ctx.backends.invocation.update(ctx.invocation_id, patch)

        # Surface the question to the frontend.
        payload = {
            "request_id": rid,
            "type": "ask_user",
            "question": question,
            "options": options,
            "context": context,
        }
        await ctx.emit(BlockEvent(kind="hitl_request", data=payload))

        # Unwind the agent loop; execution continues in a later invocation.
        raise HITLSuspend(
            request_id=rid,
            request_type="ask_user",
            message=question,
            options=options,
            tool_name=self._name,
            metadata={"context": context} if context else {},
        )
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class ConfirmTool(BaseTool):
    """Yes/no confirmation gate before an action.

    Behaves like AskUserTool but is specialized to binary confirmation:
    it suspends the run and presents a fixed "Yes, proceed" / "No, cancel"
    choice to the user.
    """

    _name = "confirm"
    _description = "Ask the user to confirm an action before proceeding."
    _parameters = {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "description": "The action you want to perform",
            },
            "details": {
                "type": "string",
                "description": "Details about the action",
            },
            "risk_level": {
                "type": "string",
                "enum": ["low", "medium", "high"],
                "description": "Risk level of the action",
            },
        },
        "required": ["action"],
    }

    async def execute(
        self,
        action: str,
        details: str | None = None,
        risk_level: str = "medium",
        *,
        ctx: "InvocationContext | None" = None,
    ) -> ToolResult:
        """Suspend execution and ask the human to confirm *action*.

        Args:
            action: Action to confirm.
            details: Additional details shown to the user.
            risk_level: Risk level (low, medium, high).
            ctx: Invocation context (injected by agent).

        Raises:
            HITLSuspend: Always raised (given a context) to suspend execution.
        """
        if ctx is None:
            return ToolResult.error("Cannot confirm without execution context")

        rid = generate_id("req")

        # Compose the prompt presented to the user.
        prompt = f"Confirm: {action}"
        if details:
            prompt += f"\n\nDetails: {details}"

        pending = HITLRequest(
            request_id=rid,
            request_type="confirm",
            message=prompt,
            options=["Yes, proceed", "No, cancel"],
            tool_name=self._name,
            metadata={
                "action": action,
                "details": details,
                "risk_level": risk_level,
            },
        )

        # Persist a checkpoint so the run can resume after the decision.
        state = getattr(ctx, "state", None)
        if state is not None:
            await state.checkpoint()

        # Record the suspension on the invocation backend.
        if ctx.backends and ctx.backends.invocation:
            patch = {
                "status": "suspended",
                "pending_request_id": rid,
                "pending_request_type": "confirm",
                "pending_request_data": pending.to_dict(),
            }
            await ctx.backends.invocation.update(ctx.invocation_id, patch)

        # Surface the confirmation request to the frontend.
        payload = {
            "request_id": rid,
            "type": "confirm",
            "action": action,
            "details": details,
            "risk_level": risk_level,
            "options": ["Yes, proceed", "No, cancel"],
        }
        await ctx.emit(BlockEvent(kind="hitl_request", data=payload))

        # Unwind the agent loop until the user responds.
        raise HITLSuspend(
            request_id=rid,
            request_type="confirm",
            message=prompt,
            options=["Yes, proceed", "No, cancel"],
            tool_name=self._name,
            metadata=pending.metadata,
        )
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
__all__ = ["AskUserTool", "ConfirmTool"]
|