axio 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- axio-0.1.0/.github/workflows/publish.yml +40 -0
- axio-0.1.0/.github/workflows/tests.yml +44 -0
- axio-0.1.0/LICENSE +21 -0
- axio-0.1.0/Makefile +16 -0
- axio-0.1.0/PKG-INFO +8 -0
- axio-0.1.0/README.md +15 -0
- axio-0.1.0/pyproject.toml +31 -0
- axio-0.1.0/src/axio/__init__.py +1 -0
- axio-0.1.0/src/axio/agent.py +239 -0
- axio-0.1.0/src/axio/blocks.py +98 -0
- axio-0.1.0/src/axio/context.py +197 -0
- axio-0.1.0/src/axio/events.py +66 -0
- axio-0.1.0/src/axio/exceptions.py +21 -0
- axio-0.1.0/src/axio/messages.py +21 -0
- axio-0.1.0/src/axio/models.py +102 -0
- axio-0.1.0/src/axio/permission.py +50 -0
- axio-0.1.0/src/axio/selector.py +121 -0
- axio-0.1.0/src/axio/stream.py +57 -0
- axio-0.1.0/src/axio/testing.py +87 -0
- axio-0.1.0/src/axio/tool.py +74 -0
- axio-0.1.0/src/axio/transport.py +35 -0
- axio-0.1.0/src/axio/types.py +28 -0
- axio-0.1.0/tests/conftest.py +23 -0
- axio-0.1.0/tests/test_agent_branch.py +67 -0
- axio-0.1.0/tests/test_agent_permission.py +154 -0
- axio-0.1.0/tests/test_agent_run.py +192 -0
- axio-0.1.0/tests/test_agent_stream.py +87 -0
- axio-0.1.0/tests/test_agent_tools.py +270 -0
- axio-0.1.0/tests/test_blocks.py +167 -0
- axio-0.1.0/tests/test_context.py +245 -0
- axio-0.1.0/tests/test_events.py +112 -0
- axio-0.1.0/tests/test_exceptions.py +48 -0
- axio-0.1.0/tests/test_permission.py +28 -0
- axio-0.1.0/tests/test_selector.py +235 -0
- axio-0.1.0/tests/test_stream.py +86 -0
- axio-0.1.0/tests/test_tool.py +160 -0
- axio-0.1.0/tests/test_transport.py +46 -0
- axio-0.1.0/tests/test_types.py +61 -0
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
name: Publish to PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
release:
|
|
5
|
+
types: [published]
|
|
6
|
+
|
|
7
|
+
env:
|
|
8
|
+
FORCE_COLOR: 1
|
|
9
|
+
|
|
10
|
+
jobs:
|
|
11
|
+
build:
|
|
12
|
+
runs-on: ubuntu-latest
|
|
13
|
+
steps:
|
|
14
|
+
- uses: actions/checkout@v4
|
|
15
|
+
- uses: astral-sh/setup-uv@v6
|
|
16
|
+
|
|
17
|
+
- name: Set version from release tag
|
|
18
|
+
run: uv version "${GITHUB_REF_NAME#v}"
|
|
19
|
+
|
|
20
|
+
- name: Build package
|
|
21
|
+
run: uv build
|
|
22
|
+
|
|
23
|
+
- uses: actions/upload-artifact@v4
|
|
24
|
+
with:
|
|
25
|
+
name: dist
|
|
26
|
+
path: dist/
|
|
27
|
+
|
|
28
|
+
publish:
|
|
29
|
+
runs-on: ubuntu-latest
|
|
30
|
+
needs: build
|
|
31
|
+
environment: pypi
|
|
32
|
+
permissions:
|
|
33
|
+
id-token: write
|
|
34
|
+
steps:
|
|
35
|
+
- uses: actions/download-artifact@v4
|
|
36
|
+
with:
|
|
37
|
+
name: dist
|
|
38
|
+
path: dist/
|
|
39
|
+
|
|
40
|
+
- uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
name: tests
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [ master, main ]
|
|
6
|
+
pull_request:
|
|
7
|
+
branches: [ master, main ]
|
|
8
|
+
|
|
9
|
+
env:
|
|
10
|
+
FORCE_COLOR: 1
|
|
11
|
+
|
|
12
|
+
jobs:
|
|
13
|
+
ruff:
|
|
14
|
+
runs-on: ubuntu-latest
|
|
15
|
+
steps:
|
|
16
|
+
- uses: actions/checkout@v4
|
|
17
|
+
- uses: astral-sh/setup-uv@v6
|
|
18
|
+
- run: uv sync --frozen
|
|
19
|
+
- run: uv run ruff check
|
|
20
|
+
- run: uv run ruff format --check
|
|
21
|
+
|
|
22
|
+
mypy:
|
|
23
|
+
runs-on: ubuntu-latest
|
|
24
|
+
steps:
|
|
25
|
+
- uses: actions/checkout@v4
|
|
26
|
+
- uses: astral-sh/setup-uv@v6
|
|
27
|
+
- run: uv sync --frozen
|
|
28
|
+
- run: uv run mypy .
|
|
29
|
+
|
|
30
|
+
tests:
|
|
31
|
+
runs-on: ubuntu-latest
|
|
32
|
+
strategy:
|
|
33
|
+
fail-fast: false
|
|
34
|
+
matrix:
|
|
35
|
+
python:
|
|
36
|
+
- "3.12"
|
|
37
|
+
- "3.13"
|
|
38
|
+
steps:
|
|
39
|
+
- uses: actions/checkout@v4
|
|
40
|
+
- uses: astral-sh/setup-uv@v6
|
|
41
|
+
with:
|
|
42
|
+
python-version: ${{ matrix.python }}
|
|
43
|
+
- run: uv sync --frozen
|
|
44
|
+
- run: uv run pytest -vv --cov=axio --cov-report=term-missing
|
axio-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Axio contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
axio-0.1.0/Makefile
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
.PHONY: fmt lint typecheck test all
|
|
2
|
+
|
|
3
|
+
fmt:
|
|
4
|
+
uv run ruff format src/ tests/
|
|
5
|
+
uv run ruff check --fix src/ tests/
|
|
6
|
+
|
|
7
|
+
lint:
|
|
8
|
+
uv run ruff check src/ tests/
|
|
9
|
+
|
|
10
|
+
typecheck:
|
|
11
|
+
uv run mypy src/
|
|
12
|
+
|
|
13
|
+
test:
|
|
14
|
+
uv run pytest tests/ -v
|
|
15
|
+
|
|
16
|
+
all: fmt lint typecheck test
|
axio-0.1.0/PKG-INFO
ADDED
axio-0.1.0/README.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "axio"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Minimal, streaming-first, protocol-driven foundation for LLM-powered agents"
|
|
5
|
+
requires-python = ">=3.12"
|
|
6
|
+
license = {text = "MIT"}
|
|
7
|
+
dependencies = ["pydantic>=2"]
|
|
8
|
+
|
|
9
|
+
[build-system]
|
|
10
|
+
requires = ["hatchling"]
|
|
11
|
+
build-backend = "hatchling.build"
|
|
12
|
+
|
|
13
|
+
[tool.hatch.build.targets.wheel]
|
|
14
|
+
packages = ["src/axio"]
|
|
15
|
+
|
|
16
|
+
[tool.pytest.ini_options]
|
|
17
|
+
asyncio_mode = "auto"
|
|
18
|
+
|
|
19
|
+
[tool.ruff]
|
|
20
|
+
line-length = 119
|
|
21
|
+
target-version = "py312"
|
|
22
|
+
|
|
23
|
+
[tool.ruff.lint]
|
|
24
|
+
select = ["E", "F", "I", "UP"]
|
|
25
|
+
|
|
26
|
+
[tool.mypy]
|
|
27
|
+
strict = true
|
|
28
|
+
python_version = "3.12"
|
|
29
|
+
|
|
30
|
+
[dependency-groups]
|
|
31
|
+
dev = ["pytest>=8", "pytest-asyncio>=0.24", "mypy>=1.14", "ruff>=0.9"]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
"""Agent: the core agentic loop orchestrating transport, tools, and context."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from collections.abc import AsyncGenerator
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from axio.blocks import TextBlock, ToolResultBlock, ToolUseBlock
|
|
13
|
+
from axio.context import ContextStore
|
|
14
|
+
from axio.events import (
|
|
15
|
+
Error,
|
|
16
|
+
IterationEnd,
|
|
17
|
+
SessionEndEvent,
|
|
18
|
+
StreamEvent,
|
|
19
|
+
TextDelta,
|
|
20
|
+
ToolInputDelta,
|
|
21
|
+
ToolResult,
|
|
22
|
+
ToolUseStart,
|
|
23
|
+
)
|
|
24
|
+
from axio.messages import Message
|
|
25
|
+
from axio.stream import AgentStream
|
|
26
|
+
from axio.tool import Tool
|
|
27
|
+
from axio.transport import CompletionTransport
|
|
28
|
+
from axio.types import StopReason, Usage
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(slots=True)
|
|
34
|
+
class Agent:
|
|
35
|
+
system: str
|
|
36
|
+
tools: list[Tool]
|
|
37
|
+
transport: CompletionTransport
|
|
38
|
+
max_iterations: int = field(default=50)
|
|
39
|
+
|
|
40
|
+
def run_stream(self, user_message: str, context: ContextStore) -> AgentStream:
|
|
41
|
+
return AgentStream(self._run_loop(user_message, context))
|
|
42
|
+
|
|
43
|
+
async def run(self, user_message: str, context: ContextStore) -> str:
|
|
44
|
+
return await self.run_stream(user_message, context).get_final_text()
|
|
45
|
+
|
|
46
|
+
async def dispatch_tools(self, blocks: list[ToolUseBlock], iteration: int) -> list[ToolResultBlock]:
|
|
47
|
+
tool_names = [b.name for b in blocks]
|
|
48
|
+
logger.info("Dispatching %d tool(s): %s", len(blocks), tool_names)
|
|
49
|
+
|
|
50
|
+
async def _run_one(block: ToolUseBlock) -> ToolResultBlock:
|
|
51
|
+
tool = self._find_tool(block.name)
|
|
52
|
+
if tool is None:
|
|
53
|
+
logger.warning("Unknown tool requested: %s", block.name)
|
|
54
|
+
return ToolResultBlock(tool_use_id=block.id, content=f"Unknown tool: {block.name}", is_error=True)
|
|
55
|
+
logger.debug("Tool %s (id=%s) args=%s", block.name, block.id, json.dumps(block.input)[:200])
|
|
56
|
+
try:
|
|
57
|
+
result = await tool(**block.input)
|
|
58
|
+
content = result if isinstance(result, str) else str(result)
|
|
59
|
+
except Exception as exc:
|
|
60
|
+
logger.error("Tool %s raised %s: %s", block.name, type(exc).__name__, exc, exc_info=True)
|
|
61
|
+
return ToolResultBlock(tool_use_id=block.id, content=str(exc), is_error=True)
|
|
62
|
+
return ToolResultBlock(tool_use_id=block.id, content=content)
|
|
63
|
+
|
|
64
|
+
results = list(await asyncio.gather(*[_run_one(b) for b in blocks]))
|
|
65
|
+
error_count = sum(1 for r in results if r.is_error)
|
|
66
|
+
logger.info("Tools complete: %d total, %d errors", len(results), error_count)
|
|
67
|
+
return results
|
|
68
|
+
|
|
69
|
+
def _find_tool(self, name: str) -> Tool | None:
|
|
70
|
+
for tool in self.tools:
|
|
71
|
+
if tool.name == name:
|
|
72
|
+
return tool
|
|
73
|
+
return None
|
|
74
|
+
|
|
75
|
+
async def _append(self, context: ContextStore, message: Message) -> None:
|
|
76
|
+
await context.append(message)
|
|
77
|
+
|
|
78
|
+
@staticmethod
|
|
79
|
+
def _accumulate_text(content: list[TextBlock | ToolUseBlock], delta: str) -> None:
|
|
80
|
+
"""Append text delta — merge into last TextBlock or start a new one."""
|
|
81
|
+
if content and isinstance(content[-1], TextBlock):
|
|
82
|
+
content[-1] = TextBlock(text=content[-1].text + delta)
|
|
83
|
+
else:
|
|
84
|
+
content.append(TextBlock(text=delta))
|
|
85
|
+
|
|
86
|
+
@staticmethod
|
|
87
|
+
def _finalize_pending_tools(
|
|
88
|
+
pending: dict[str, dict[str, Any]],
|
|
89
|
+
usage: Usage,
|
|
90
|
+
) -> tuple[list[ToolUseBlock], set[str]]:
|
|
91
|
+
"""Convert streamed tool-call fragments into ToolUseBlocks.
|
|
92
|
+
|
|
93
|
+
Returns (blocks, malformed_ids).
|
|
94
|
+
"""
|
|
95
|
+
blocks: list[ToolUseBlock] = []
|
|
96
|
+
malformed: set[str] = set()
|
|
97
|
+
for tid, info in pending.items():
|
|
98
|
+
raw = "".join(info["json_parts"])
|
|
99
|
+
if not raw:
|
|
100
|
+
logger.warning(
|
|
101
|
+
"Tool %s (id=%s) received empty arguments (output may be truncated, output_tokens=%d)",
|
|
102
|
+
info["name"],
|
|
103
|
+
tid,
|
|
104
|
+
usage.output_tokens,
|
|
105
|
+
)
|
|
106
|
+
inp: dict[str, Any] = {}
|
|
107
|
+
else:
|
|
108
|
+
try:
|
|
109
|
+
inp = json.loads(raw)
|
|
110
|
+
except json.JSONDecodeError as exc:
|
|
111
|
+
logger.warning(
|
|
112
|
+
"Tool %s (id=%s) has malformed JSON arguments: %s\nRaw: %s",
|
|
113
|
+
info["name"],
|
|
114
|
+
tid,
|
|
115
|
+
exc,
|
|
116
|
+
raw,
|
|
117
|
+
)
|
|
118
|
+
malformed.add(tid)
|
|
119
|
+
inp = {}
|
|
120
|
+
blocks.append(ToolUseBlock(id=tid, name=info["name"], input=inp))
|
|
121
|
+
return blocks, malformed
|
|
122
|
+
|
|
123
|
+
async def _run_loop(self, user_message: str, context: ContextStore) -> AsyncGenerator[StreamEvent, None]:
|
|
124
|
+
total_usage = Usage(0, 0)
|
|
125
|
+
session_end_emitted = False
|
|
126
|
+
await self._append(context, Message(role="user", content=[TextBlock(text=user_message)]))
|
|
127
|
+
|
|
128
|
+
try:
|
|
129
|
+
for iteration in range(1, self.max_iterations + 1):
|
|
130
|
+
history = await context.get_history()
|
|
131
|
+
logger.info("Iteration %d, history length=%d", iteration, len(history))
|
|
132
|
+
active_tools = self.tools
|
|
133
|
+
|
|
134
|
+
content: list[TextBlock | ToolUseBlock] = []
|
|
135
|
+
pending: dict[str, dict[str, Any]] = {}
|
|
136
|
+
stop_reason = StopReason.end_turn
|
|
137
|
+
malformed: set[str] = set()
|
|
138
|
+
|
|
139
|
+
try:
|
|
140
|
+
async for event in self.transport.stream(history, active_tools, self.system):
|
|
141
|
+
yield event
|
|
142
|
+
match event:
|
|
143
|
+
case TextDelta(delta=delta):
|
|
144
|
+
self._accumulate_text(content, delta)
|
|
145
|
+
case ToolUseStart(tool_use_id=tid, name=name):
|
|
146
|
+
pending[tid] = {"name": name, "json_parts": []}
|
|
147
|
+
case ToolInputDelta(tool_use_id=tid, partial_json=pj):
|
|
148
|
+
if tid in pending:
|
|
149
|
+
pending[tid]["json_parts"].append(pj)
|
|
150
|
+
case IterationEnd(usage=usage, stop_reason=sr):
|
|
151
|
+
blocks, malformed = self._finalize_pending_tools(pending, usage)
|
|
152
|
+
content.extend(blocks)
|
|
153
|
+
pending.clear()
|
|
154
|
+
total_usage = total_usage + usage
|
|
155
|
+
await context.add_context_tokens(usage.input_tokens, usage.output_tokens)
|
|
156
|
+
stop_reason = sr
|
|
157
|
+
except Exception as exc:
|
|
158
|
+
logger.error("Transport error: %s", exc, exc_info=True)
|
|
159
|
+
yield Error(exception=exc)
|
|
160
|
+
yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
|
|
161
|
+
session_end_emitted = True
|
|
162
|
+
return
|
|
163
|
+
|
|
164
|
+
tool_blocks = [b for b in content if isinstance(b, ToolUseBlock)]
|
|
165
|
+
|
|
166
|
+
if tool_blocks:
|
|
167
|
+
if stop_reason != StopReason.tool_use:
|
|
168
|
+
logger.warning(
|
|
169
|
+
"Dispatching %d tool(s) despite stop_reason=%s",
|
|
170
|
+
len(tool_blocks),
|
|
171
|
+
stop_reason,
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
# Dispatch tools BEFORE appending to context — cancellation
|
|
175
|
+
# between here and the two appends below cannot leave orphan
|
|
176
|
+
# ToolUseBlocks in the persistent context store.
|
|
177
|
+
valid = [b for b in tool_blocks if b.id not in malformed]
|
|
178
|
+
error_results = [
|
|
179
|
+
ToolResultBlock(
|
|
180
|
+
tool_use_id=b.id,
|
|
181
|
+
content=(
|
|
182
|
+
f"Malformed JSON arguments for tool {b.name}."
|
|
183
|
+
f" Raw input could not be parsed. Please retry the tool call"
|
|
184
|
+
f" with valid JSON arguments."
|
|
185
|
+
),
|
|
186
|
+
is_error=True,
|
|
187
|
+
)
|
|
188
|
+
for b in tool_blocks
|
|
189
|
+
if b.id in malformed
|
|
190
|
+
]
|
|
191
|
+
dispatched = await self.dispatch_tools(valid, iteration) if valid else []
|
|
192
|
+
results = dispatched + error_results
|
|
193
|
+
|
|
194
|
+
# Append both messages atomically (assistant + tool results)
|
|
195
|
+
await self._append(context, Message(role="assistant", content=list(content)))
|
|
196
|
+
await self._append(context, Message(role="user", content=list(results)))
|
|
197
|
+
|
|
198
|
+
# Yield ToolResult events
|
|
199
|
+
by_id = {b.id: b for b in tool_blocks}
|
|
200
|
+
for r in results:
|
|
201
|
+
block = by_id.get(r.tool_use_id)
|
|
202
|
+
result_content = (
|
|
203
|
+
r.content
|
|
204
|
+
if isinstance(r.content, str)
|
|
205
|
+
else "\n".join(b.text for b in r.content if isinstance(b, TextBlock))
|
|
206
|
+
)
|
|
207
|
+
yield ToolResult(
|
|
208
|
+
tool_use_id=r.tool_use_id,
|
|
209
|
+
name=block.name if block else "",
|
|
210
|
+
is_error=r.is_error,
|
|
211
|
+
content=result_content,
|
|
212
|
+
input=block.input if block else {},
|
|
213
|
+
)
|
|
214
|
+
continue
|
|
215
|
+
|
|
216
|
+
await self._append(context, Message(role="assistant", content=list(content)))
|
|
217
|
+
|
|
218
|
+
match stop_reason:
|
|
219
|
+
case StopReason.end_turn:
|
|
220
|
+
logger.debug("End turn: total_usage=%s", total_usage)
|
|
221
|
+
yield SessionEndEvent(stop_reason=StopReason.end_turn, total_usage=total_usage)
|
|
222
|
+
session_end_emitted = True
|
|
223
|
+
return
|
|
224
|
+
case StopReason.max_tokens | StopReason.error:
|
|
225
|
+
yield Error(exception=RuntimeError(f"Transport stopped with: {stop_reason}"))
|
|
226
|
+
yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
|
|
227
|
+
session_end_emitted = True
|
|
228
|
+
return
|
|
229
|
+
|
|
230
|
+
logger.warning("Max iterations (%d) reached", self.max_iterations)
|
|
231
|
+
yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
|
|
232
|
+
session_end_emitted = True
|
|
233
|
+
|
|
234
|
+
except GeneratorExit:
|
|
235
|
+
return
|
|
236
|
+
except BaseException:
|
|
237
|
+
if not session_end_emitted:
|
|
238
|
+
yield SessionEndEvent(stop_reason=StopReason.error, total_usage=total_usage)
|
|
239
|
+
raise
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Content blocks: TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import base64
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from functools import singledispatch
|
|
8
|
+
from typing import Any, Literal
|
|
9
|
+
|
|
10
|
+
from axio.types import ToolCallID, ToolName
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ContentBlock:
    """Base class for all content blocks.

    Declares empty ``__slots__`` so that subclasses built with
    ``@dataclass(slots=True)`` do not inherit a per-instance ``__dict__``
    from this base.
    """

    __slots__ = ()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True, slots=True)
class TextBlock(ContentBlock):
    """An immutable run of plain text content."""

    text: str  # the text itself
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(frozen=True, slots=True)
class ImageBlock(ContentBlock):
    """An immutable inline image: raw bytes plus their MIME type.

    ``media_type`` is restricted only by the type annotation; it is not
    validated at runtime.
    """

    # MIME type of ``data``.
    media_type: Literal["image/jpeg", "image/png", "image/gif", "image/webp"]
    # Raw (decoded) image bytes; to_dict/from_dict convert to/from base64 text.
    data: bytes
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True, slots=True)
class ToolUseBlock(ContentBlock):
    """A tool invocation requested by the model.

    The dataclass is frozen, but ``input`` is a mutable dict, so instances are
    not hashable (the generated ``__hash__`` raises ``TypeError``).
    """

    id: ToolCallID  # call identifier; a ToolResultBlock answers it via tool_use_id
    name: ToolName  # name of the tool to invoke
    input: dict[str, Any]  # arguments for the tool (the agent calls tool(**input))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass(frozen=True, slots=True)
class ToolResultBlock(ContentBlock):
    """The outcome of one tool call, answering a ToolUseBlock.

    ``content`` is either plain text or a list of text/image blocks.
    """

    tool_use_id: ToolCallID  # id of the ToolUseBlock this result answers
    content: str | list[TextBlock | ImageBlock]  # tool output
    is_error: bool = False  # True when the call failed (e.g. unknown tool, raised exception)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@singledispatch
def to_dict(block: ContentBlock) -> dict[str, Any]:
    """Convert a content block into a plain dict.

    Concrete block types register their own serializers on this function; this
    fallback only runs for unregistered types and always fails.

    Raises:
        TypeError: no serializer is registered for ``type(block)``.
    """
    type_name = type(block).__name__
    raise TypeError(f"Unknown block type: {type_name}")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@to_dict.register(TextBlock)
def _text_to_dict(block: TextBlock) -> dict[str, Any]:
    """Serialize a TextBlock as ``{"type": "text", "text": <text>}``."""
    return dict(type="text", text=block.text)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@to_dict.register(ImageBlock)
def _image_to_dict(block: ImageBlock) -> dict[str, Any]:
    """Serialize an ImageBlock, encoding its raw bytes as base64 text."""
    # base64 output is pure ASCII, so decoding as ASCII equals the default UTF-8 decode.
    encoded = base64.b64encode(block.data).decode("ascii")
    return {
        "type": "image",
        "media_type": block.media_type,
        "data": encoded,
    }
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@to_dict.register(ToolUseBlock)
def _tool_use_to_dict(block: ToolUseBlock) -> dict[str, Any]:
    """Serialize a ToolUseBlock.

    The block's ``input`` dict is placed in the result as-is (it is not copied).
    """
    return dict(type="tool_use", id=block.id, name=block.name, input=block.input)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@to_dict.register(ToolResultBlock)
def _tool_result_to_dict(block: ToolResultBlock) -> dict[str, Any]:
    """Serialize a ToolResultBlock.

    String content is kept verbatim; list content is serialized element by
    element through ``to_dict``.
    """
    payload = block.content
    serialized_content: str | list[dict[str, Any]] = (
        payload if isinstance(payload, str) else [to_dict(part) for part in payload]
    )
    return {
        "type": "tool_result",
        "tool_use_id": block.tool_use_id,
        "content": serialized_content,
        "is_error": block.is_error,
    }
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def from_dict(data: dict[str, Any]) -> ContentBlock:
|
|
81
|
+
"""Deserialize a plain dict to a ContentBlock."""
|
|
82
|
+
match data["type"]:
|
|
83
|
+
case "text":
|
|
84
|
+
return TextBlock(text=data["text"])
|
|
85
|
+
case "image":
|
|
86
|
+
return ImageBlock(media_type=data["media_type"], data=base64.b64decode(data["data"]))
|
|
87
|
+
case "tool_use":
|
|
88
|
+
return ToolUseBlock(id=data["id"], name=data["name"], input=data["input"])
|
|
89
|
+
case "tool_result":
|
|
90
|
+
raw = data["content"]
|
|
91
|
+
if isinstance(raw, str):
|
|
92
|
+
content: str | list[TextBlock | ImageBlock] = raw
|
|
93
|
+
else:
|
|
94
|
+
content = [from_dict(b) for b in raw] # type: ignore[misc]
|
|
95
|
+
return ToolResultBlock(tool_use_id=data["tool_use_id"], content=content, is_error=data["is_error"])
|
|
96
|
+
case _:
|
|
97
|
+
msg = f"Unknown block type: {data['type']}"
|
|
98
|
+
raise ValueError(msg)
|