axio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- axio/__init__.py +1 -0
- axio/agent.py +239 -0
- axio/blocks.py +98 -0
- axio/context.py +197 -0
- axio/events.py +66 -0
- axio/exceptions.py +21 -0
- axio/messages.py +21 -0
- axio/models.py +102 -0
- axio/permission.py +50 -0
- axio/selector.py +121 -0
- axio/stream.py +57 -0
- axio/testing.py +87 -0
- axio/tool.py +74 -0
- axio/transport.py +35 -0
- axio/types.py +28 -0
- axio-0.1.0.dist-info/METADATA +8 -0
- axio-0.1.0.dist-info/RECORD +19 -0
- axio-0.1.0.dist-info/WHEEL +4 -0
- axio-0.1.0.dist-info/licenses/LICENSE +21 -0
axio/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
axio/agent.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""Agent: the core agentic loop orchestrating transport, tools, and context."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from collections.abc import AsyncGenerator
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from axio.blocks import TextBlock, ToolResultBlock, ToolUseBlock
|
|
13
|
+
from axio.context import ContextStore
|
|
14
|
+
from axio.events import (
|
|
15
|
+
Error,
|
|
16
|
+
IterationEnd,
|
|
17
|
+
SessionEndEvent,
|
|
18
|
+
StreamEvent,
|
|
19
|
+
TextDelta,
|
|
20
|
+
ToolInputDelta,
|
|
21
|
+
ToolResult,
|
|
22
|
+
ToolUseStart,
|
|
23
|
+
)
|
|
24
|
+
from axio.messages import Message
|
|
25
|
+
from axio.stream import AgentStream
|
|
26
|
+
from axio.tool import Tool
|
|
27
|
+
from axio.transport import CompletionTransport
|
|
28
|
+
from axio.types import StopReason, Usage
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(slots=True)
class Agent:
    """Core agentic loop: streams model output, dispatches tools, persists context.

    Attributes:
        system: System prompt passed to the transport on every completion.
        tools: Tools the model may call; looked up by name at dispatch time.
        transport: Streaming completion backend.
        max_iterations: Upper bound on model/tool round trips per run.
    """

    system: str
    tools: list[Tool]
    transport: CompletionTransport
    max_iterations: int = field(default=50)

    def run_stream(self, user_message: str, context: ContextStore) -> AgentStream:
        """Start the loop lazily and expose it as an AgentStream of events."""
        return AgentStream(self._run_loop(user_message, context))

    async def run(self, user_message: str, context: ContextStore) -> str:
        """Run to completion and return only the final assistant text."""
        return await self.run_stream(user_message, context).get_final_text()

    async def dispatch_tools(self, blocks: list[ToolUseBlock], iteration: int) -> list[ToolResultBlock]:
        """Execute the requested tool calls concurrently.

        Every failure (unknown tool, handler exception) becomes an is_error
        ToolResultBlock instead of propagating, so one bad call cannot abort
        the batch. asyncio.gather preserves input order, so results align
        one-to-one with *blocks*.

        NOTE(review): *iteration* is currently unused in this method —
        confirm whether it is kept for interface symmetry or future logging.
        """
        tool_names = [b.name for b in blocks]
        logger.info("Dispatching %d tool(s): %s", len(blocks), tool_names)

        async def _run_one(block: ToolUseBlock) -> ToolResultBlock:
            # One tool call -> exactly one result block; never raises.
            tool = self._find_tool(block.name)
            if tool is None:
                logger.warning("Unknown tool requested: %s", block.name)
                return ToolResultBlock(tool_use_id=block.id, content=f"Unknown tool: {block.name}", is_error=True)
            logger.debug("Tool %s (id=%s) args=%s", block.name, block.id, json.dumps(block.input)[:200])
            try:
                result = await tool(**block.input)
                # Non-string results are stringified for the transcript.
                content = result if isinstance(result, str) else str(result)
            except Exception as exc:
                logger.error("Tool %s raised %s: %s", block.name, type(exc).__name__, exc, exc_info=True)
                return ToolResultBlock(tool_use_id=block.id, content=str(exc), is_error=True)
            return ToolResultBlock(tool_use_id=block.id, content=content)

        results = list(await asyncio.gather(*[_run_one(b) for b in blocks]))
        error_count = sum(1 for r in results if r.is_error)
        logger.info("Tools complete: %d total, %d errors", len(results), error_count)
        return results

    def _find_tool(self, name: str) -> Tool | None:
        """Linear lookup of a tool by name; None when no tool matches."""
        for tool in self.tools:
            if tool.name == name:
                return tool
        return None

    async def _append(self, context: ContextStore, message: Message) -> None:
        """Persist one message into the context store."""
        await context.append(message)

    @staticmethod
    def _accumulate_text(content: list[TextBlock | ToolUseBlock], delta: str) -> None:
        """Append text delta — merge into last TextBlock or start a new one."""
        if content and isinstance(content[-1], TextBlock):
            # TextBlock is frozen, so "merging" means replacing the last block.
            content[-1] = TextBlock(text=content[-1].text + delta)
        else:
            content.append(TextBlock(text=delta))

    @staticmethod
    def _finalize_pending_tools(
        pending: dict[str, dict[str, Any]],
        usage: Usage,
    ) -> tuple[list[ToolUseBlock], set[str]]:
        """Convert streamed tool-call fragments into ToolUseBlocks.

        Empty argument payloads become ``{}`` (likely output truncation).
        Unparsable JSON also becomes ``{}`` but the call id is reported as
        malformed so the caller can answer it with an error result instead
        of executing it.

        Returns (blocks, malformed_ids).
        """
        blocks: list[ToolUseBlock] = []
        malformed: set[str] = set()
        for tid, info in pending.items():
            raw = "".join(info["json_parts"])
            if not raw:
                logger.warning(
                    "Tool %s (id=%s) received empty arguments (output may be truncated, output_tokens=%d)",
                    info["name"],
                    tid,
                    usage.output_tokens,
                )
                inp: dict[str, Any] = {}
            else:
                try:
                    inp = json.loads(raw)
                except json.JSONDecodeError as exc:
                    logger.warning(
                        "Tool %s (id=%s) has malformed JSON arguments: %s\nRaw: %s",
                        info["name"],
                        tid,
                        exc,
                        raw,
                    )
                    malformed.add(tid)
                    inp = {}
            blocks.append(ToolUseBlock(id=tid, name=info["name"], input=inp))
        return blocks, malformed

    async def _run_loop(self, user_message: str, context: ContextStore) -> AsyncGenerator[StreamEvent, None]:
        """Drive the model/tool loop, yielding StreamEvents as they happen.

        Invariants:
        - Exactly one SessionEndEvent is yielded per run (tracked via
          ``session_end_emitted``, including the BaseException fallback).
        - An assistant message and its tool-result message are appended to
          the context only after tools ran, so cancellation cannot leave an
          orphan tool_use in the persistent store.
        """
        total_usage = Usage(0, 0)
        session_end_emitted = False
        await self._append(context, Message(role="user", content=[TextBlock(text=user_message)]))

        try:
            for iteration in range(1, self.max_iterations + 1):
                history = await context.get_history()
                logger.info("Iteration %d, history length=%d", iteration, len(history))
                active_tools = self.tools

                # Accumulators for this single transport completion.
                content: list[TextBlock | ToolUseBlock] = []
                pending: dict[str, dict[str, Any]] = {}
                stop_reason = StopReason.end_turn
                malformed: set[str] = set()

                try:
                    async for event in self.transport.stream(history, active_tools, self.system):
                        # Forward every raw event to the consumer first.
                        yield event
                        match event:
                            case TextDelta(delta=delta):
                                self._accumulate_text(content, delta)
                            case ToolUseStart(tool_use_id=tid, name=name):
                                pending[tid] = {"name": name, "json_parts": []}
                            case ToolInputDelta(tool_use_id=tid, partial_json=pj):
                                if tid in pending:
                                    pending[tid]["json_parts"].append(pj)
                            case IterationEnd(usage=usage, stop_reason=sr):
                                blocks, malformed = self._finalize_pending_tools(pending, usage)
                                content.extend(blocks)
                                pending.clear()
                                total_usage = total_usage + usage
                                await context.add_context_tokens(usage.input_tokens, usage.output_tokens)
                                stop_reason = sr
                except Exception as exc:
                    # A transport failure ends the whole session, not just this turn.
                    logger.error("Transport error: %s", exc, exc_info=True)
                    yield Error(exception=exc)
                    yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
                    session_end_emitted = True
                    return

                tool_blocks = [b for b in content if isinstance(b, ToolUseBlock)]

                if tool_blocks:
                    if stop_reason != StopReason.tool_use:
                        logger.warning(
                            "Dispatching %d tool(s) despite stop_reason=%s",
                            len(tool_blocks),
                            stop_reason,
                        )

                    # Dispatch tools BEFORE appending to context — cancellation
                    # between here and the two appends below cannot leave orphan
                    # ToolUseBlocks in the persistent context store.
                    valid = [b for b in tool_blocks if b.id not in malformed]
                    error_results = [
                        ToolResultBlock(
                            tool_use_id=b.id,
                            content=(
                                f"Malformed JSON arguments for tool {b.name}."
                                f" Raw input could not be parsed. Please retry the tool call"
                                f" with valid JSON arguments."
                            ),
                            is_error=True,
                        )
                        for b in tool_blocks
                        if b.id in malformed
                    ]
                    dispatched = await self.dispatch_tools(valid, iteration) if valid else []
                    results = dispatched + error_results

                    # Append both messages atomically (assistant + tool results)
                    await self._append(context, Message(role="assistant", content=list(content)))
                    await self._append(context, Message(role="user", content=list(results)))

                    # Yield ToolResult events
                    by_id = {b.id: b for b in tool_blocks}
                    for r in results:
                        block = by_id.get(r.tool_use_id)
                        result_content = (
                            r.content
                            if isinstance(r.content, str)
                            else "\n".join(b.text for b in r.content if isinstance(b, TextBlock))
                        )
                        yield ToolResult(
                            tool_use_id=r.tool_use_id,
                            name=block.name if block else "",
                            is_error=r.is_error,
                            content=result_content,
                            input=block.input if block else {},
                        )
                    continue

                # No tool calls this turn: persist the assistant message alone.
                await self._append(context, Message(role="assistant", content=list(content)))

                match stop_reason:
                    case StopReason.end_turn:
                        logger.debug("End turn: total_usage=%s", total_usage)
                        yield SessionEndEvent(stop_reason=StopReason.end_turn, total_usage=total_usage)
                        session_end_emitted = True
                        return
                    case StopReason.max_tokens | StopReason.error:
                        yield Error(exception=RuntimeError(f"Transport stopped with: {stop_reason}"))
                        yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
                        session_end_emitted = True
                        return

            # for-loop exhausted without reaching a terminal stop reason.
            logger.warning("Max iterations (%d) reached", self.max_iterations)
            yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
            session_end_emitted = True

        except GeneratorExit:
            # Consumer closed the stream; nothing more may be yielded.
            return
        except BaseException:
            # Guarantee a terminal event even on cancellation/unexpected errors.
            if not session_end_emitted:
                yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
            raise
|
axio/blocks.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Content blocks: TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from functools import singledispatch
|
|
8
|
+
from typing import Any, Literal
|
|
9
|
+
|
|
10
|
+
from axio.types import ToolCallID, ToolName
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ContentBlock:
|
|
14
|
+
"""Base class for all content blocks."""
|
|
15
|
+
|
|
16
|
+
__slots__ = ()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True, slots=True)
|
|
20
|
+
class TextBlock(ContentBlock):
|
|
21
|
+
text: str
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(frozen=True, slots=True)
|
|
25
|
+
class ImageBlock(ContentBlock):
|
|
26
|
+
media_type: Literal["image/jpeg", "image/png", "image/gif", "image/webp"]
|
|
27
|
+
data: bytes
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True, slots=True)
|
|
31
|
+
class ToolUseBlock(ContentBlock):
|
|
32
|
+
id: ToolCallID
|
|
33
|
+
name: ToolName
|
|
34
|
+
input: dict[str, Any]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass(frozen=True, slots=True)
|
|
38
|
+
class ToolResultBlock(ContentBlock):
|
|
39
|
+
tool_use_id: ToolCallID
|
|
40
|
+
content: str | list[TextBlock | ImageBlock]
|
|
41
|
+
is_error: bool = False
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@singledispatch
|
|
45
|
+
def to_dict(block: ContentBlock) -> dict[str, Any]:
|
|
46
|
+
"""Serialize a ContentBlock to a plain dict."""
|
|
47
|
+
msg = f"Unknown block type: {type(block).__name__}"
|
|
48
|
+
raise TypeError(msg)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@to_dict.register(TextBlock)
|
|
52
|
+
def _text_to_dict(block: TextBlock) -> dict[str, Any]:
|
|
53
|
+
return {"type": "text", "text": block.text}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@to_dict.register(ImageBlock)
|
|
57
|
+
def _image_to_dict(block: ImageBlock) -> dict[str, Any]:
|
|
58
|
+
return {"type": "image", "media_type": block.media_type, "data": base64.b64encode(block.data).decode()}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@to_dict.register(ToolUseBlock)
|
|
62
|
+
def _tool_use_to_dict(block: ToolUseBlock) -> dict[str, Any]:
|
|
63
|
+
return {"type": "tool_use", "id": block.id, "name": block.name, "input": block.input}
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@to_dict.register(ToolResultBlock)
|
|
67
|
+
def _tool_result_to_dict(block: ToolResultBlock) -> dict[str, Any]:
|
|
68
|
+
if isinstance(block.content, str):
|
|
69
|
+
serialized_content: str | list[dict[str, Any]] = block.content
|
|
70
|
+
else:
|
|
71
|
+
serialized_content = [to_dict(b) for b in block.content]
|
|
72
|
+
return {
|
|
73
|
+
"type": "tool_result",
|
|
74
|
+
"tool_use_id": block.tool_use_id,
|
|
75
|
+
"content": serialized_content,
|
|
76
|
+
"is_error": block.is_error,
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def from_dict(data: dict[str, Any]) -> ContentBlock:
|
|
81
|
+
"""Deserialize a plain dict to a ContentBlock."""
|
|
82
|
+
match data["type"]:
|
|
83
|
+
case "text":
|
|
84
|
+
return TextBlock(text=data["text"])
|
|
85
|
+
case "image":
|
|
86
|
+
return ImageBlock(media_type=data["media_type"], data=base64.b64decode(data["data"]))
|
|
87
|
+
case "tool_use":
|
|
88
|
+
return ToolUseBlock(id=data["id"], name=data["name"], input=data["input"])
|
|
89
|
+
case "tool_result":
|
|
90
|
+
raw = data["content"]
|
|
91
|
+
if isinstance(raw, str):
|
|
92
|
+
content: str | list[TextBlock | ImageBlock] = raw
|
|
93
|
+
else:
|
|
94
|
+
content = [from_dict(b) for b in raw] # type: ignore[misc]
|
|
95
|
+
return ToolResultBlock(tool_use_id=data["tool_use_id"], content=content, is_error=data["is_error"])
|
|
96
|
+
case _:
|
|
97
|
+
msg = f"Unknown block type: {data['type']}"
|
|
98
|
+
raise ValueError(msg)
|
axio/context.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""ContextStore: protocol for conversation history storage."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import copy
|
|
6
|
+
import logging
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Self
|
|
10
|
+
from uuid import uuid4
|
|
11
|
+
|
|
12
|
+
from axio.blocks import TextBlock, ToolResultBlock
|
|
13
|
+
from axio.messages import Message
|
|
14
|
+
from axio.transport import CompletionTransport
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True, slots=True)
class SessionInfo:
    """Summary metadata about one stored conversation session."""

    session_id: str
    message_count: int
    preview: str       # truncated text of the first user message, or "(empty)"
    created_at: str    # empty string in the default list_sessions implementation
    input_tokens: int = 0
    output_tokens: int = 0


class ContextStore(ABC):
    """Abstract storage for conversation history plus running token counters."""

    @property
    @abstractmethod
    def session_id(self) -> str: ...

    @abstractmethod
    async def append(self, message: Message) -> None: ...

    @abstractmethod
    async def get_history(self) -> list[Message]: ...

    @abstractmethod
    async def clear(self) -> None: ...

    @abstractmethod
    async def fork(self) -> ContextStore: ...

    @abstractmethod
    async def set_context_tokens(self, input_tokens: int, output_tokens: int) -> None: ...

    @abstractmethod
    async def get_context_tokens(self) -> tuple[int, int]: ...

    @abstractmethod
    async def close(self) -> None: ...

    async def list_sessions(self) -> list[SessionInfo]:
        """List available sessions. Default: returns a single entry for the current session."""
        history = await self.get_history()
        in_tok, out_tok = await self.get_context_tokens()
        preview = "(empty)"
        # Preview comes from the first TextBlock of the first user message.
        for msg in history:
            if msg.role == "user":
                for block in msg.content:
                    if isinstance(block, TextBlock):
                        text = block.text
                        preview = text[:80] + ("..." if len(text) > 80 else "")
                        break
                break
        return [
            SessionInfo(
                session_id=self.session_id,
                message_count=len(history),
                preview=preview,
                created_at="",
                input_tokens=in_tok,
                output_tokens=out_tok,
            ),
        ]

    async def add_context_tokens(self, input_tokens: int, output_tokens: int) -> None:
        """Increment the running token counters (read-modify-write)."""
        cur_in, cur_out = await self.get_context_tokens()
        await self.set_context_tokens(cur_in + input_tokens, cur_out + output_tokens)

    @classmethod
    async def from_history(cls, history: list[Message]) -> Self:
        """Create a new ContextStore pre-populated with *history*.

        NOTE(review): assumes the concrete subclass has a zero-argument
        constructor (``cls()``) — confirm for subclasses that require
        connection/configuration arguments.
        """
        store = cls()
        for message in history:
            await store.append(message)
        return store

    @classmethod
    async def from_context(cls, context: ContextStore) -> Self:
        """Create a new store holding a copy of another store's history."""
        return await cls.from_history(await context.get_history())
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class MemoryContextStore(ContextStore):
    """In-memory ContextStore backed by a plain list; ``fork()`` deep-copies."""

    def __init__(self, history: list[Message] | None = None) -> None:
        self._session_id = uuid4().hex
        self._history: list[Message] = [] if history is None else list(history)
        self._input_tokens: int = 0
        self._output_tokens: int = 0

    @property
    def session_id(self) -> str:
        """Random hex identifier assigned at construction time."""
        return self._session_id

    async def append(self, message: Message) -> None:
        """Store *message* at the end of the in-memory history."""
        self._history.append(message)

    async def get_history(self) -> list[Message]:
        """Return a shallow copy so callers cannot mutate internal state."""
        return self._history.copy()

    async def clear(self) -> None:
        """Drop every stored message and reset both token counters."""
        self._history.clear()
        self._input_tokens = 0
        self._output_tokens = 0

    async def fork(self) -> MemoryContextStore:
        """Return an independent deep copy (fresh session id, same counters)."""
        clone = MemoryContextStore(copy.deepcopy(self._history))
        clone._input_tokens = self._input_tokens
        clone._output_tokens = self._output_tokens
        return clone

    async def set_context_tokens(self, input_tokens: int, output_tokens: int) -> None:
        """Overwrite the running token counters."""
        self._input_tokens = input_tokens
        self._output_tokens = output_tokens

    async def get_context_tokens(self) -> tuple[int, int]:
        """Return the ``(input_tokens, output_tokens)`` pair."""
        return (self._input_tokens, self._output_tokens)

    async def close(self) -> None:
        """Nothing to release for the in-memory store."""
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
_DEFAULT_COMPACTION_PROMPT = (
    "You are a conversation summarizer. You will see a conversation between"
    " a user and an AI assistant, including tool calls and their results."
    " Produce a concise summary preserving: user goals, decisions made,"
    " key facts, tool outcomes, and state changes. Write as narrative prose,"
    " not as a transcript."
)


async def compact_context(
    context: ContextStore,
    transport: CompletionTransport,
    *,
    max_messages: int = 20,
    keep_recent: int = 6,
    system_prompt: str | None = None,
) -> list[Message] | None:
    """Summarize old messages from *context*, keeping recent ones verbatim.

    Returns a compacted message list ready to populate a fresh store,
    or ``None`` when no compaction is needed (history short enough, no
    safe split point, or the summarization call failed).
    """
    messages = await context.get_history()
    if len(messages) <= max_messages:
        return None

    boundary = _find_safe_boundary(messages, keep_recent)
    if boundary <= 0:
        return None

    head = messages[:boundary]
    tail = messages[boundary:]

    # Deferred import to avoid circular dependency (context ↔ agent)
    from axio.agent import Agent

    summarizer = Agent(
        system=system_prompt or _DEFAULT_COMPACTION_PROMPT,
        tools=[],
        transport=transport,
        max_iterations=1,
    )
    scratch = MemoryContextStore(head)
    try:
        summary = await summarizer.run("Summarize the conversation above.", scratch)
    except Exception:
        # Best-effort: a failed compaction must never lose the real history.
        logger.warning("Context compaction failed, keeping original history", exc_info=True)
        return None

    # Replay the summary as a synthetic user/assistant exchange, then the tail.
    return [
        Message(role="user", content=[TextBlock(text=summary)]),
        Message(role="assistant", content=[TextBlock(text="Understood, context restored.")]),
        *tail,
    ]
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _find_safe_boundary(history: list[Message], keep_recent: int) -> int:
    """Return a split index that never separates a tool_use from its tool_result.

    Starts at ``len(history) - keep_recent`` and walks backwards while the
    message at the boundary carries tool results, so a tool-result message is
    never detached from the assistant turn that requested it.
    """
    boundary = len(history) - keep_recent
    while boundary > 0:
        candidate = history[boundary]
        if not any(isinstance(blk, ToolResultBlock) for blk in candidate.content):
            break
        boundary -= 1
    return boundary
|
axio/events.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""Stream events: all variants emitted by AgentStream."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from axio.types import StopReason, ToolCallID, ToolName, Usage
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True, slots=True)
class ReasoningDelta:
    """Incremental chunk of model reasoning ("thinking") text."""

    index: int
    delta: str


@dataclass(frozen=True, slots=True)
class TextDelta:
    """Incremental chunk of visible assistant text."""

    index: int
    delta: str


@dataclass(frozen=True, slots=True)
class ToolUseStart:
    """Model began a tool call; its arguments follow as ToolInputDelta events."""

    index: int
    tool_use_id: ToolCallID
    name: ToolName


@dataclass(frozen=True, slots=True)
class ToolInputDelta:
    """Fragment of the JSON arguments for an in-progress tool call."""

    index: int
    tool_use_id: ToolCallID
    partial_json: str


@dataclass(frozen=True, slots=True)
class ToolResult:
    """Outcome of one executed tool call (emitted after dispatch)."""

    tool_use_id: ToolCallID
    name: ToolName
    is_error: bool
    content: str = ""
    input: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True, slots=True)
class IterationEnd:
    """One transport completion finished; carries its usage and stop reason."""

    iteration: int
    stop_reason: StopReason
    usage: Usage


@dataclass(frozen=True, slots=True)
class Error:
    """Wraps an exception raised during the session."""

    exception: BaseException


@dataclass(frozen=True, slots=True)
class SessionEndEvent:
    """Terminal event: the agent session is over; totals are final."""

    stop_reason: StopReason
    total_usage: Usage


# Closed union of every event variant an AgentStream may yield.
type StreamEvent = (
    ReasoningDelta | TextDelta | ToolUseStart | ToolInputDelta | ToolResult | IterationEnd | Error | SessionEndEvent
)
|
axio/exceptions.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Exception hierarchy for axio."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class AxioError(Exception):
    """Base exception for all axio errors; catch this to handle any of them."""


class ToolError(AxioError):
    """Base for tool-related errors."""


class GuardError(ToolError):
    """Guard denied or crashed during permission check."""


class HandlerError(ToolError):
    """Handler raised during execution."""


class StreamError(AxioError):
    """Error during stream collection."""
|
axio/messages.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Message: the fundamental unit of conversation history."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
from axio.blocks import ContentBlock, from_dict, to_dict
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(slots=True)
|
|
12
|
+
class Message:
|
|
13
|
+
role: Literal["user", "assistant"]
|
|
14
|
+
content: list[ContentBlock] = field(default_factory=list)
|
|
15
|
+
|
|
16
|
+
def to_dict(self) -> dict[str, Any]:
|
|
17
|
+
return {"role": self.role, "content": [to_dict(b) for b in self.content]}
|
|
18
|
+
|
|
19
|
+
@classmethod
|
|
20
|
+
def from_dict(cls, data: dict[str, Any]) -> Message:
|
|
21
|
+
return cls(role=data["role"], content=[from_dict(b) for b in data["content"]])
|
axio/models.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Transport-agnostic model types: Capability, ModelSpec, ModelRegistry."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import ItemsView, Iterable, Iterator, KeysView, MutableMapping, ValuesView
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from enum import StrEnum
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True, slots=True)
|
|
11
|
+
class TransportMeta:
|
|
12
|
+
"""Metadata a transport plugin declares about itself."""
|
|
13
|
+
|
|
14
|
+
label: str
|
|
15
|
+
api_key_env: str
|
|
16
|
+
role_defaults: dict[str, str]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Capability(StrEnum):
|
|
20
|
+
text = "text"
|
|
21
|
+
vision = "vision"
|
|
22
|
+
reasoning = "reasoning"
|
|
23
|
+
tool_use = "tool_use"
|
|
24
|
+
json_mode = "json_mode"
|
|
25
|
+
structured_outputs = "structured_outputs"
|
|
26
|
+
embedding = "embedding"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True, slots=True)
|
|
30
|
+
class ModelSpec:
|
|
31
|
+
id: str
|
|
32
|
+
capabilities: frozenset[Capability] = frozenset()
|
|
33
|
+
max_output_tokens: int = 8192
|
|
34
|
+
context_window: int = 128000
|
|
35
|
+
input_cost: float = 0.0
|
|
36
|
+
output_cost: float = 0.0
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ModelRegistry(MutableMapping[str, ModelSpec]):
|
|
40
|
+
__slots__ = ("_models",)
|
|
41
|
+
|
|
42
|
+
def __init__(self, models: Iterable[ModelSpec] | None = None) -> None:
|
|
43
|
+
self._models: dict[str, ModelSpec] = {m.id: m for m in (models or [])}
|
|
44
|
+
|
|
45
|
+
def __setitem__(self, key: str, value: ModelSpec, /) -> None:
|
|
46
|
+
if not isinstance(value, ModelSpec):
|
|
47
|
+
raise ValueError("ModelRegistry values must be ModelSpec instances")
|
|
48
|
+
self._models[key] = value
|
|
49
|
+
|
|
50
|
+
def __delitem__(self, key: str, /) -> None:
|
|
51
|
+
del self._models[key]
|
|
52
|
+
|
|
53
|
+
def __getitem__(self, key: str, /) -> ModelSpec:
|
|
54
|
+
return self._models[key]
|
|
55
|
+
|
|
56
|
+
def __len__(self) -> int:
|
|
57
|
+
return len(self._models)
|
|
58
|
+
|
|
59
|
+
def __iter__(self) -> Iterator[ModelSpec]: # type: ignore[override]
|
|
60
|
+
return iter(self._models.values())
|
|
61
|
+
|
|
62
|
+
def __eq__(self, other: object) -> bool:
|
|
63
|
+
if isinstance(other, ModelRegistry):
|
|
64
|
+
return self._models == other._models
|
|
65
|
+
if isinstance(other, dict):
|
|
66
|
+
return self._models == other
|
|
67
|
+
return NotImplemented
|
|
68
|
+
|
|
69
|
+
def __repr__(self) -> str:
|
|
70
|
+
return f"ModelRegistry({self._models!r})"
|
|
71
|
+
|
|
72
|
+
def clear(self) -> None:
|
|
73
|
+
self._models.clear()
|
|
74
|
+
|
|
75
|
+
def keys(self) -> KeysView[str]:
|
|
76
|
+
return self._models.keys()
|
|
77
|
+
|
|
78
|
+
def values(self) -> ValuesView[ModelSpec]:
|
|
79
|
+
return self._models.values()
|
|
80
|
+
|
|
81
|
+
def items(self) -> ItemsView[str, ModelSpec]:
|
|
82
|
+
return self._models.items()
|
|
83
|
+
|
|
84
|
+
def by_prefix(self, prefix: str) -> ModelRegistry:
|
|
85
|
+
return ModelRegistry(v for k, v in self._models.items() if k.startswith(prefix))
|
|
86
|
+
|
|
87
|
+
def by_capability(self, *caps: Capability) -> ModelRegistry:
|
|
88
|
+
required = frozenset(caps)
|
|
89
|
+
return ModelRegistry(v for v in self._models.values() if required <= v.capabilities)
|
|
90
|
+
|
|
91
|
+
def search(self, *q: str) -> ModelRegistry:
|
|
92
|
+
"""search by parts of id"""
|
|
93
|
+
return ModelRegistry(v for k, v in self._models.items() if all(part in k for part in q))
|
|
94
|
+
|
|
95
|
+
def by_cost(self, *, output: bool = False, desc: bool = False) -> ModelRegistry:
|
|
96
|
+
"""Return registry ordered by cost (input by default, output if *output=True*)."""
|
|
97
|
+
attr = "output_cost" if output else "input_cost"
|
|
98
|
+
items = sorted(self._models.values(), key=lambda v: getattr(v, attr), reverse=desc)
|
|
99
|
+
return ModelRegistry(items)
|
|
100
|
+
|
|
101
|
+
def ids(self) -> list[str]:
|
|
102
|
+
return list(self._models)
|
axio/permission.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Permission system: guards that gate tool execution."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from axio.exceptions import GuardError
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PermissionGuard(ABC):
|
|
13
|
+
"""Gate for tool calls. Return handler to allow, raise to deny.
|
|
14
|
+
|
|
15
|
+
Tool calls guards via ``await guard(instance)``.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
async def __call__(self, handler: Any) -> Any:
|
|
19
|
+
return await self.check(handler)
|
|
20
|
+
|
|
21
|
+
@abstractmethod
|
|
22
|
+
async def check(self, handler: Any) -> Any: ...
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ConcurrentGuard(PermissionGuard, ABC):
|
|
26
|
+
"""Guard with concurrency control.
|
|
27
|
+
|
|
28
|
+
Subclass and override ``check()``. ``__call__`` acquires the semaphore
|
|
29
|
+
then delegates to ``check()``. Set ``concurrency`` to control parallelism
|
|
30
|
+
(default 1 — one check at a time).
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
concurrency: int = 1
|
|
34
|
+
|
|
35
|
+
def __init__(self) -> None:
|
|
36
|
+
self._semaphore = asyncio.Semaphore(self.concurrency)
|
|
37
|
+
|
|
38
|
+
async def __call__(self, handler: Any) -> Any:
|
|
39
|
+
async with self._semaphore:
|
|
40
|
+
return await self.check(handler)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class AllowAllGuard(PermissionGuard):
|
|
44
|
+
async def check(self, handler: Any) -> Any:
|
|
45
|
+
return handler
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class DenyAllGuard(PermissionGuard):
|
|
49
|
+
async def check(self, handler: Any) -> Any:
|
|
50
|
+
raise GuardError("denied")
|
axio/selector.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""Tool selector: choose relevant tools per query via embedding similarity."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from collections.abc import AsyncIterator
|
|
7
|
+
from typing import Protocol, runtime_checkable
|
|
8
|
+
|
|
9
|
+
from axio.blocks import TextBlock, ToolResultBlock
|
|
10
|
+
from axio.events import StreamEvent
|
|
11
|
+
from axio.messages import Message
|
|
12
|
+
from axio.tool import Tool
|
|
13
|
+
from axio.transport import CompletionTransport, EmbeddingTransport
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@runtime_checkable
class ToolSelector(Protocol):
    """Strategy that narrows the tool list offered to the model for a conversation."""

    async def select(self, messages: list[Message], tools: list[Tool]) -> list[Tool]: ...
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
|
22
|
+
dot = sum(x * y for x, y in zip(a, b))
|
|
23
|
+
norm_a = math.sqrt(sum(x * x for x in a))
|
|
24
|
+
norm_b = math.sqrt(sum(x * x for x in b))
|
|
25
|
+
if norm_a == 0.0 or norm_b == 0.0:
|
|
26
|
+
return 0.0
|
|
27
|
+
return dot / (norm_a * norm_b)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _extract_last_user_text(messages: list[Message]) -> str | None:
    """Return joined text from the last user message, or None if it's a tool-result iteration.

    Scans backwards for the most recent user message. If it carries TextBlocks
    their texts are space-joined; otherwise (e.g. it holds only ToolResultBlocks
    from a tool-result iteration) there is no query to embed and None is
    returned. None is also returned when no user message exists.
    """
    for msg in reversed(messages):
        if msg.role != "user":
            continue
        texts = [block.text for block in msg.content if isinstance(block, TextBlock)]
        # Only the most recent user message counts: the previous version had a
        # dead `any(isinstance(b, ToolResultBlock) ...)` branch whose both
        # outcomes returned None — collapsed here without behavior change.
        return " ".join(texts) if texts else None
    return None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class EmbeddingToolSelector:
    """Select the ``top_k`` most query-relevant tools via embedding cosine similarity.

    Description embeddings are cached per tool name and refreshed whenever a
    tool's description changes. Tools named in ``pinned`` are always included
    and consume slots from the ``top_k`` budget.
    """

    def __init__(
        self,
        transport: EmbeddingTransport,
        *,
        top_k: int = 5,
        pinned: frozenset[str] = frozenset(),
    ) -> None:
        self._transport = transport
        self._top_k = top_k
        self._pinned = pinned
        # name -> embedding vector of that tool's description
        self._tool_embeddings: dict[str, list[float]] = {}
        # name -> description text that was embedded (staleness check)
        self._tool_descriptions: dict[str, str] = {}

    async def _ensure_embeddings(self, tools: list[Tool]) -> None:
        """Embed descriptions for tools that are new or whose description changed."""
        stale = [
            (tool.name, tool.description)
            for tool in tools
            if self._tool_descriptions.get(tool.name) != tool.description
        ]
        if not stale:
            return

        vectors = await self._transport.embed([desc for _, desc in stale])
        for (name, desc), vector in zip(stale, vectors):
            self._tool_embeddings[name] = vector
            self._tool_descriptions[name] = desc

    async def select(self, messages: list[Message], tools: list[Tool]) -> list[Tool]:
        """Return pinned tools plus the highest-scoring others, at most ``top_k`` total."""
        # Nothing to trim — skip all embedding round-trips.
        if len(tools) <= self._top_k:
            return tools

        query = _extract_last_user_text(messages)
        if query is None:
            # Tool-result iteration (or no user text): keep the full set.
            return tools

        await self._ensure_embeddings(tools)
        query_vec = (await self._transport.embed([query]))[0]

        pinned: list[Tool] = []
        ranked: list[tuple[float, Tool]] = []
        for tool in tools:
            if tool.name in self._pinned:
                pinned.append(tool)
            elif tool.name in self._tool_embeddings:
                score = _cosine_similarity(query_vec, self._tool_embeddings[tool.name])
                ranked.append((score, tool))

        # Stable descending sort, then fill whatever budget pinning left over.
        ranked.sort(key=lambda pair: pair[0], reverse=True)
        budget = max(0, self._top_k - len(pinned))
        return pinned + [tool for _, tool in ranked[:budget]]
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class ToolFilteringTransport:
    """CompletionTransport decorator that filters tools via a ToolSelector."""

    def __init__(self, transport: CompletionTransport, selector: ToolSelector) -> None:
        self._transport = transport
        self._selector = selector

    def stream(self, messages: list[Message], tools: list[Tool], system: str) -> AsyncIterator[StreamEvent]:
        # stream() stays synchronous to match CompletionTransport; the async
        # selection happens lazily inside the returned generator.
        return self._filtered_stream(messages, tools, system)

    async def _filtered_stream(
        self,
        messages: list[Message],
        tools: list[Tool],
        system: str,
    ) -> AsyncIterator[StreamEvent]:
        chosen = await self._selector.select(messages, tools)
        inner = self._transport.stream(messages, chosen, system)
        async for item in inner:
            yield item
|
axio/stream.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""AgentStream: async iterator wrapper over the agent event generator."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import AsyncGenerator
|
|
6
|
+
|
|
7
|
+
from axio.events import Error, SessionEndEvent, StreamEvent, TextDelta
|
|
8
|
+
from axio.exceptions import StreamError
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AgentStream:
    """Async-iterator facade over the agent's event generator.

    Tracks a closed flag so that iterating after exhaustion or ``aclose()``
    simply stops instead of touching the underlying generator again.
    """

    def __init__(self, generator: AsyncGenerator[StreamEvent, None]) -> None:
        self._generator = generator
        self._closed = False

    def __aiter__(self) -> AgentStream:
        return self

    async def __anext__(self) -> StreamEvent:
        if self._closed:
            raise StopAsyncIteration
        try:
            return await self._generator.__anext__()
        except StopAsyncIteration:
            # Remember exhaustion so later __anext__/aclose are no-ops.
            self._closed = True
            raise

    async def aclose(self) -> None:
        """Close the underlying generator (idempotent)."""
        if self._closed:
            return
        self._closed = True
        await self._generator.aclose()

    async def get_final_text(self) -> str:
        """Drain the stream and return all TextDelta payloads concatenated.

        Raises StreamError if an Error event is encountered.
        """
        chunks: list[str] = []
        try:
            async for ev in self:
                if isinstance(ev, Error):
                    raise StreamError(str(ev.exception)) from ev.exception
                if isinstance(ev, TextDelta):
                    chunks.append(ev.delta)
        finally:
            # Always release the generator, even on error.
            await self.aclose()
        return "".join(chunks)

    async def get_session_end(self) -> SessionEndEvent:
        """Drain the stream and return the last SessionEndEvent seen.

        Raises StreamError on an Error event or when no SessionEndEvent arrives.
        """
        final: SessionEndEvent | None = None
        try:
            async for ev in self:
                if isinstance(ev, Error):
                    raise StreamError(str(ev.exception)) from ev.exception
                if isinstance(ev, SessionEndEvent):
                    final = ev
        finally:
            await self.aclose()
        if final is None:
            raise StreamError("Stream ended without SessionEndEvent")
        return final
|
axio/testing.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Shared test helpers: StubTransport, fixtures, response builders."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from collections.abc import AsyncIterator
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from axio.context import MemoryContextStore
|
|
10
|
+
from axio.events import IterationEnd, StreamEvent, TextDelta, ToolInputDelta, ToolUseStart
|
|
11
|
+
from axio.messages import Message
|
|
12
|
+
from axio.tool import Tool, ToolHandler
|
|
13
|
+
from axio.types import StopReason, Usage
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MsgInput(ToolHandler):
    """Minimal ToolHandler for tests: a single ``msg`` field echoed back as JSON."""

    # Sole input field; defines the tool's input schema via the handler model.
    msg: str

    async def __call__(self) -> str:
        # Echo the validated input back as a JSON string.
        return self.model_dump_json()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class StubTransport:
    """A CompletionTransport that yields pre-configured event sequences.

    The first call to stream() replays the first sequence, the second call the
    second, and so on; once the list is exhausted, the last sequence is
    replayed for every further call. (The old docstring claimed sequences were
    "popped" — they never were; calls beyond the list clamp to the last entry.)
    """

    def __init__(self, responses: list[list[StreamEvent]] | None = None) -> None:
        # Copy so later mutation of the caller's list cannot change behavior.
        self._responses: list[list[StreamEvent]] = list(responses or [])
        self._call_count = 0

    async def _generate(self, events: list[StreamEvent]) -> AsyncIterator[StreamEvent]:
        for event in events:
            yield event

    def stream(self, messages: list[Message], tools: list[Tool], system: str) -> AsyncIterator[StreamEvent]:
        # Fail loudly with a clear message instead of the cryptic IndexError
        # that negative clamping produced when no responses were configured.
        if not self._responses:
            raise IndexError("StubTransport has no configured responses")
        # Clamp to the last configured sequence once the list is exhausted.
        idx = min(self._call_count, len(self._responses) - 1)
        events = self._responses[idx]
        self._call_count += 1
        return self._generate(events)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def make_tool_use_response(
    tool_name: str = "echo",
    tool_id: str = "call_1",
    tool_input: dict[str, Any] | None = None,
    iteration: int = 1,
    usage: Usage | None = None,
) -> list[StreamEvent]:
    """Build a standard tool_use response event sequence."""
    payload = tool_input or {"msg": "hi"}
    spent = usage or Usage(10, 5)
    events: list[StreamEvent] = [
        ToolUseStart(0, tool_id, tool_name),
        ToolInputDelta(0, tool_id, json.dumps(payload)),
        IterationEnd(iteration, StopReason.tool_use, spent),
    ]
    return events
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def make_text_response(text: str = "Done", iteration: int = 2, usage: Usage | None = None) -> list[StreamEvent]:
    """Build a standard end_turn text response event sequence."""
    spent = usage or Usage(10, 5)
    return [TextDelta(0, text), IterationEnd(iteration, StopReason.end_turn, spent)]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def make_stub_transport() -> StubTransport:
    """StubTransport preloaded with one two-chunk text response that ends the turn."""
    events = [
        TextDelta(0, "Hello"),
        TextDelta(0, " world"),
        IterationEnd(1, StopReason.end_turn, Usage(10, 5)),
    ]
    return StubTransport([events])
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def make_ephemeral_context() -> MemoryContextStore:
    """Return a fresh in-memory context store (independent per call)."""
    return MemoryContextStore()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def make_echo_tool() -> Tool:
    """Return a Tool wrapping MsgInput, which echoes its input back as JSON."""
    return Tool(name="echo", description="Returns input as JSON", handler=MsgInput)
|
axio/tool.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Tool: frozen dataclass binding a ToolHandler to a name, guard, and concurrency."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from collections.abc import AsyncGenerator
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
|
|
13
|
+
from axio.exceptions import GuardError, HandlerError
|
|
14
|
+
from axio.permission import PermissionGuard
|
|
15
|
+
from axio.types import ToolName
|
|
16
|
+
|
|
17
|
+
type JSONSchema = dict[str, Any]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ToolHandler(BaseModel):
    """Base for tool handlers.

    Subclass fields define the input JSON-schema.
    Override ``async def __call__`` to implement execution logic.
    Pydantic provides ``__repr__`` automatically — override for custom display.
    """

    async def __call__(self) -> str:
        # Subclasses must implement the tool's execution.
        raise NotImplementedError
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(frozen=True, slots=True)
class Tool:
    """Immutable binding of a ToolHandler class to a name, guards, and an optional concurrency cap."""

    name: ToolName
    description: str
    handler: type[ToolHandler]
    guards: tuple[PermissionGuard, ...] = ()
    concurrency: int | None = None

    # Installed in __post_init__ when a cap is set; None means unlimited.
    _semaphore: asyncio.Semaphore | None = field(init=False, default=None, repr=False, compare=False)

    def __post_init__(self) -> None:
        # Frozen dataclass: bypass the frozen __setattr__ to install the semaphore.
        if self.concurrency is not None:
            object.__setattr__(self, "_semaphore", asyncio.Semaphore(self.concurrency))

    @asynccontextmanager
    async def _acquire(self) -> AsyncGenerator[None, None]:
        """Hold the concurrency semaphore for the duration, or no-op when uncapped."""
        sem = self._semaphore
        if sem is not None:
            async with sem:
                yield
        else:
            yield

    @property
    def input_schema(self) -> JSONSchema:
        """JSON schema of the handler's input fields."""
        return self.handler.model_json_schema()

    async def __call__(self, **kwargs: Any) -> Any:
        """Validate kwargs, run each guard in order, then execute the handler.

        Guard failures surface as GuardError and handler failures as
        HandlerError; either type raised directly passes through unchanged,
        while any other exception is wrapped and chained.
        """
        async with self._acquire():
            payload = self.handler.model_validate(kwargs)
            for gate in self.guards:
                try:
                    payload = await gate(payload)
                except GuardError:
                    raise
                except Exception as err:
                    raise GuardError(str(err)) from err
            try:
                return await payload()
            except HandlerError:
                raise
            except Exception as err:
                raise HandlerError(str(err)) from err
|
axio/transport.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Transport protocols: completion, image gen, TTS, STT."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import AsyncIterator
|
|
6
|
+
from typing import Protocol, runtime_checkable
|
|
7
|
+
|
|
8
|
+
from axio.events import StreamEvent
|
|
9
|
+
from axio.messages import Message
|
|
10
|
+
from axio.tool import Tool
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@runtime_checkable
class CompletionTransport(Protocol):
    """Streaming completion backend: messages + tools + system prompt → event stream."""

    def stream(self, messages: list[Message], tools: list[Tool], system: str) -> AsyncIterator[StreamEvent]: ...
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@runtime_checkable
class ImageGenTransport(Protocol):
    """Image generation backend: prompt (+ optional size) → ``n`` images as raw bytes."""

    async def generate(self, prompt: str, *, size: tuple[int, int] | None = None, n: int = 1) -> list[bytes]: ...
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@runtime_checkable
class TTSTransport(Protocol):
    """Text-to-speech backend: text (+ optional voice) → async stream of audio byte chunks."""

    def synthesize(self, text: str, *, voice: str | None = None) -> AsyncIterator[bytes]: ...
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@runtime_checkable
class STTTransport(Protocol):
    """Speech-to-text backend: audio bytes (+ media type, default WAV) → transcript."""

    async def transcribe(self, audio: bytes, media_type: str = "audio/wav") -> str: ...
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@runtime_checkable
class EmbeddingTransport(Protocol):
    """Embedding backend: batch of texts → one float vector per input text."""

    async def embed(self, texts: list[str]) -> list[list[float]]: ...
|
axio/types.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Primitive types: ToolName, ToolCallID, StopReason, Usage."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from enum import StrEnum
|
|
7
|
+
|
|
8
|
+
type ToolName = str
|
|
9
|
+
type ToolCallID = str
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class StopReason(StrEnum):
    """Why a model iteration stopped (values equal member names via StrEnum)."""

    end_turn = "end_turn"      # model finished its reply
    tool_use = "tool_use"      # model requested tool execution
    max_tokens = "max_tokens"  # output token limit reached
    error = "error"            # iteration ended due to an error
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True, slots=True)
|
|
20
|
+
class Usage:
|
|
21
|
+
input_tokens: int
|
|
22
|
+
output_tokens: int
|
|
23
|
+
|
|
24
|
+
def __add__(self, other: Usage) -> Usage:
|
|
25
|
+
return Usage(
|
|
26
|
+
input_tokens=self.input_tokens + other.input_tokens,
|
|
27
|
+
output_tokens=self.output_tokens + other.output_tokens,
|
|
28
|
+
)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
axio/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
2
|
+
axio/agent.py,sha256=HqWptDuRdz1Hns-rcm43iiN3nEq35zmUBihWCH0EcOA,10691
|
|
3
|
+
axio/blocks.py,sha256=MUOTqZcTWTi2d1b7WEoomueaCulkzoJ2QxJEmNiM5qA,3011
|
|
4
|
+
axio/context.py,sha256=CTAaqJVp9aT6BkSkZbOO56b6-Kwaw4VcprpX6Rivzzc,6202
|
|
5
|
+
axio/events.py,sha256=HimirQY5g9mSa1TWpSq1cyxl5SnUMoas3hrWPfT4GXs,1299
|
|
6
|
+
axio/exceptions.py,sha256=pvgKVcVfJ-HsMr2JatDIE7lzdw-mMQZZLszzSCLj3Lk,422
|
|
7
|
+
axio/messages.py,sha256=0-GFC7BNPTyrJaVA584EjNN3c6LjSMaftH9CQR_Tr04,656
|
|
8
|
+
axio/models.py,sha256=FA8h3MQqQS0-oz50QST7ZiYyjiPWxJ73rivEQ1MJbeo,3325
|
|
9
|
+
axio/permission.py,sha256=_BPXZv8QRljwV5wzpPEptvsbDxDtvoLEn97FFyg7GMw,1313
|
|
10
|
+
axio/selector.py,sha256=ehXIYrZ5rUM-UDO6gctbQG8T2vp7nx6ez043IRPk20g,4243
|
|
11
|
+
axio/stream.py,sha256=xLu_D2NefG6NBg4q5xwVjG7jhS-wZzWlyFcpbqKb-zI,1853
|
|
12
|
+
axio/testing.py,sha256=jjr1dmOu4FYKfFOttFbkJDsCGDJV8dbdDpurehTIkvE,2637
|
|
13
|
+
axio/tool.py,sha256=NTXQ3kz1x8mmKAAOntoF-TPgDZ7icsOJ29yR_pIsP1U,2237
|
|
14
|
+
axio/transport.py,sha256=fWK8bUofM_GCZtfO8Cw5AJHjYsGs1GitYoRCCkGYb-4,1036
|
|
15
|
+
axio/types.py,sha256=n18_fDTbZUI1hp6rjvithcJ_5tumKT6lr5Cm3ua0p-I,642
|
|
16
|
+
axio-0.1.0.dist-info/METADATA,sha256=EVl7pr2sFcPvWlzAqB-OQtNDIeBF0RMuhAbwPAUsPHQ,219
|
|
17
|
+
axio-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
18
|
+
axio-0.1.0.dist-info/licenses/LICENSE,sha256=ddOkAXgIM2QzvUDPydeis3tJXfhHxt1wKnmGjbwPPxQ,1074
|
|
19
|
+
axio-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Axio contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|