aury-agent 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/__init__.py +2 -0
- aury/agents/__init__.py +55 -0
- aury/agents/a2a/__init__.py +168 -0
- aury/agents/backends/__init__.py +196 -0
- aury/agents/backends/artifact/__init__.py +9 -0
- aury/agents/backends/artifact/memory.py +130 -0
- aury/agents/backends/artifact/types.py +133 -0
- aury/agents/backends/code/__init__.py +65 -0
- aury/agents/backends/file/__init__.py +11 -0
- aury/agents/backends/file/local.py +66 -0
- aury/agents/backends/file/types.py +40 -0
- aury/agents/backends/invocation/__init__.py +8 -0
- aury/agents/backends/invocation/memory.py +81 -0
- aury/agents/backends/invocation/types.py +110 -0
- aury/agents/backends/memory/__init__.py +8 -0
- aury/agents/backends/memory/memory.py +179 -0
- aury/agents/backends/memory/types.py +136 -0
- aury/agents/backends/message/__init__.py +9 -0
- aury/agents/backends/message/memory.py +122 -0
- aury/agents/backends/message/types.py +124 -0
- aury/agents/backends/sandbox.py +275 -0
- aury/agents/backends/session/__init__.py +8 -0
- aury/agents/backends/session/memory.py +93 -0
- aury/agents/backends/session/types.py +124 -0
- aury/agents/backends/shell/__init__.py +11 -0
- aury/agents/backends/shell/local.py +110 -0
- aury/agents/backends/shell/types.py +55 -0
- aury/agents/backends/shell.py +209 -0
- aury/agents/backends/snapshot/__init__.py +19 -0
- aury/agents/backends/snapshot/git.py +95 -0
- aury/agents/backends/snapshot/hybrid.py +125 -0
- aury/agents/backends/snapshot/memory.py +86 -0
- aury/agents/backends/snapshot/types.py +59 -0
- aury/agents/backends/state/__init__.py +29 -0
- aury/agents/backends/state/composite.py +49 -0
- aury/agents/backends/state/file.py +57 -0
- aury/agents/backends/state/memory.py +52 -0
- aury/agents/backends/state/sqlite.py +262 -0
- aury/agents/backends/state/types.py +178 -0
- aury/agents/backends/subagent/__init__.py +165 -0
- aury/agents/cli/__init__.py +41 -0
- aury/agents/cli/chat.py +239 -0
- aury/agents/cli/config.py +236 -0
- aury/agents/cli/extensions.py +460 -0
- aury/agents/cli/main.py +189 -0
- aury/agents/cli/session.py +337 -0
- aury/agents/cli/workflow.py +276 -0
- aury/agents/context_providers/__init__.py +66 -0
- aury/agents/context_providers/artifact.py +299 -0
- aury/agents/context_providers/base.py +177 -0
- aury/agents/context_providers/memory.py +70 -0
- aury/agents/context_providers/message.py +130 -0
- aury/agents/context_providers/skill.py +50 -0
- aury/agents/context_providers/subagent.py +46 -0
- aury/agents/context_providers/tool.py +68 -0
- aury/agents/core/__init__.py +83 -0
- aury/agents/core/base.py +573 -0
- aury/agents/core/context.py +797 -0
- aury/agents/core/context_builder.py +303 -0
- aury/agents/core/event_bus/__init__.py +15 -0
- aury/agents/core/event_bus/bus.py +203 -0
- aury/agents/core/factory.py +169 -0
- aury/agents/core/isolator.py +97 -0
- aury/agents/core/logging.py +95 -0
- aury/agents/core/parallel.py +194 -0
- aury/agents/core/runner.py +139 -0
- aury/agents/core/services/__init__.py +5 -0
- aury/agents/core/services/file_session.py +144 -0
- aury/agents/core/services/message.py +53 -0
- aury/agents/core/services/session.py +53 -0
- aury/agents/core/signals.py +109 -0
- aury/agents/core/state.py +363 -0
- aury/agents/core/types/__init__.py +107 -0
- aury/agents/core/types/action.py +176 -0
- aury/agents/core/types/artifact.py +135 -0
- aury/agents/core/types/block.py +736 -0
- aury/agents/core/types/message.py +350 -0
- aury/agents/core/types/recall.py +144 -0
- aury/agents/core/types/session.py +257 -0
- aury/agents/core/types/subagent.py +154 -0
- aury/agents/core/types/tool.py +205 -0
- aury/agents/eval/__init__.py +331 -0
- aury/agents/hitl/__init__.py +57 -0
- aury/agents/hitl/ask_user.py +242 -0
- aury/agents/hitl/compaction.py +230 -0
- aury/agents/hitl/exceptions.py +87 -0
- aury/agents/hitl/permission.py +617 -0
- aury/agents/hitl/revert.py +216 -0
- aury/agents/llm/__init__.py +31 -0
- aury/agents/llm/adapter.py +367 -0
- aury/agents/llm/openai.py +294 -0
- aury/agents/llm/provider.py +476 -0
- aury/agents/mcp/__init__.py +153 -0
- aury/agents/memory/__init__.py +46 -0
- aury/agents/memory/compaction.py +394 -0
- aury/agents/memory/manager.py +465 -0
- aury/agents/memory/processor.py +177 -0
- aury/agents/memory/store.py +187 -0
- aury/agents/memory/types.py +137 -0
- aury/agents/messages/__init__.py +40 -0
- aury/agents/messages/config.py +47 -0
- aury/agents/messages/raw_store.py +224 -0
- aury/agents/messages/store.py +118 -0
- aury/agents/messages/types.py +88 -0
- aury/agents/middleware/__init__.py +31 -0
- aury/agents/middleware/base.py +341 -0
- aury/agents/middleware/chain.py +342 -0
- aury/agents/middleware/message.py +129 -0
- aury/agents/middleware/message_container.py +126 -0
- aury/agents/middleware/raw_message.py +153 -0
- aury/agents/middleware/truncation.py +139 -0
- aury/agents/middleware/types.py +81 -0
- aury/agents/plugin.py +162 -0
- aury/agents/react/__init__.py +4 -0
- aury/agents/react/agent.py +1923 -0
- aury/agents/sandbox/__init__.py +23 -0
- aury/agents/sandbox/local.py +239 -0
- aury/agents/sandbox/remote.py +200 -0
- aury/agents/sandbox/types.py +115 -0
- aury/agents/skill/__init__.py +16 -0
- aury/agents/skill/loader.py +180 -0
- aury/agents/skill/types.py +83 -0
- aury/agents/tool/__init__.py +39 -0
- aury/agents/tool/builtin/__init__.py +23 -0
- aury/agents/tool/builtin/ask_user.py +155 -0
- aury/agents/tool/builtin/bash.py +107 -0
- aury/agents/tool/builtin/delegate.py +726 -0
- aury/agents/tool/builtin/edit.py +121 -0
- aury/agents/tool/builtin/plan.py +277 -0
- aury/agents/tool/builtin/read.py +91 -0
- aury/agents/tool/builtin/thinking.py +111 -0
- aury/agents/tool/builtin/yield_result.py +130 -0
- aury/agents/tool/decorator.py +252 -0
- aury/agents/tool/set.py +204 -0
- aury/agents/usage/__init__.py +12 -0
- aury/agents/usage/tracker.py +236 -0
- aury/agents/workflow/__init__.py +85 -0
- aury/agents/workflow/adapter.py +268 -0
- aury/agents/workflow/dag.py +116 -0
- aury/agents/workflow/dsl.py +575 -0
- aury/agents/workflow/executor.py +659 -0
- aury/agents/workflow/expression.py +136 -0
- aury/agents/workflow/parser.py +182 -0
- aury/agents/workflow/state.py +145 -0
- aury/agents/workflow/types.py +86 -0
- aury_agent-0.0.4.dist-info/METADATA +90 -0
- aury_agent-0.0.4.dist-info/RECORD +149 -0
- aury_agent-0.0.4.dist-info/WHEEL +4 -0
- aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""Session revert for undoing changes."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any, Protocol, runtime_checkable
|
|
7
|
+
|
|
8
|
+
from ..backends.snapshot import SnapshotBackend, Patch
|
|
9
|
+
from ..backends.state import StateBackend
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@runtime_checkable
class BlockBackend(Protocol):
    """Structural interface for the session block store.

    Any object that provides the three coroutine methods below satisfies the
    protocol. Because it is ``runtime_checkable``, ``isinstance`` checks only
    verify that the methods exist, not their signatures.
    """

    async def get_block(
        self,
        session_id: str,
        block_id: str,
    ) -> dict[str, Any] | None:
        """Return block ``block_id`` of ``session_id``, or ``None`` if absent."""
        ...

    async def delete_block(
        self,
        session_id: str,
        block_id: str,
    ) -> bool:
        """Remove a block; return whether something was deleted."""
        ...

    async def list_blocks_after(
        self,
        session_id: str,
        block_id: str,
        branch: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return the blocks that follow ``block_id``.

        ``branch`` optionally restricts the result to one branch
        (``None`` = all branches).
        """
        ...
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class RevertState:
    """State of a revert operation, as cached and persisted by SessionRevert."""

    # Block the session was reverted to; blocks listed after it are removed by cleanup.
    block_id: str
    # Snapshot tracked before the first revert; unrevert restores it ("" = none).
    snapshot_id: str
    # Snapshot-backend diff against the target block's snapshot ("" when unavailable).
    diff: str = ""
    # Time the revert happened (naive local time from datetime.now).
    reverted_at: datetime = field(default_factory=datetime.now)
    # Branch filter the revert used (None = all branches). Cleanup reuses it so
    # that only the blocks that were actually reverted get deleted.
    branch: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict (``reverted_at`` as an ISO-8601 string)."""
        return {
            "block_id": self.block_id,
            "snapshot_id": self.snapshot_id,
            "diff": self.diff,
            "reverted_at": self.reverted_at.isoformat(),
            "branch": self.branch,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> RevertState:
        """Rebuild a state produced by :meth:`to_dict`.

        Optional keys that are missing or null fall back to their defaults,
        so records written before ``branch`` existed still load.
        """
        reverted_at = data.get("reverted_at")
        return cls(
            block_id=data["block_id"],
            snapshot_id=data["snapshot_id"],
            diff=data.get("diff") or "",
            reverted_at=(
                datetime.fromisoformat(reverted_at) if reverted_at else datetime.now()
            ),
            branch=data.get("branch"),
        )


class SessionRevert:
    """Manage session reverts.

    Handles reverting to previous states, including:
    - File system changes (via the optional snapshot backend)
    - Session blocks (via the block backend; deleted by ``cleanup``)
    - Memory entries are not handled by this class.

    Revert state is cached in memory per session and persisted in the
    ``StateBackend`` under ``NAMESPACE``, so ``unrevert`` and ``cleanup`` also
    work from a fresh instance.
    """

    NAMESPACE = "revert"

    def __init__(
        self,
        state: StateBackend,
        block: BlockBackend,
        snapshot: SnapshotBackend | None = None,
    ):
        """Create a revert manager.

        Args:
            state: Backend used to persist revert state per session.
            block: Block store to read and delete session blocks.
            snapshot: Optional file snapshot backend; without it only
                block bookkeeping is done.
        """
        self._state = state
        self._block = block
        self._snapshot = snapshot
        self._revert_states: dict[str, RevertState] = {}  # session_id -> RevertState

    async def _load_revert_state(self, session_id: str) -> RevertState | None:
        """Return the cached revert state, falling back to the persisted copy."""
        revert_state = self._revert_states.get(session_id)
        if revert_state is None:
            stored = await self._state.get(self.NAMESPACE, session_id)
            if stored:
                revert_state = RevertState.from_dict(stored)
        return revert_state

    async def _clear_revert_state(self, session_id: str) -> None:
        """Drop the revert state from both the cache and the state backend."""
        self._revert_states.pop(session_id, None)
        await self._state.delete(self.NAMESPACE, session_id)

    async def revert(
        self,
        session_id: str,
        block_id: str,
        branch: str | None = None,
    ) -> RevertState:
        """Revert session to state before specified block.

        File changes stored as ``data["patch"]`` on the blocks listed after
        ``block_id`` are rolled back through the snapshot backend. The blocks
        themselves are kept until :meth:`cleanup`.

        NOTE(review): only blocks *after* ``block_id`` are listed, so a patch
        on ``block_id`` itself is not reverted. Confirm this matches the
        "before this block" wording.

        Args:
            session_id: Session to revert
            block_id: Revert to state before this block
            branch: Optional branch filter (None = all branches)

        Returns:
            RevertState with revert info

        Raises:
            ValueError: If ``block_id`` is not found in the session.
        """
        # 1. Validate the target before causing any side effects.
        block_data = await self._block.get_block(session_id, block_id)
        if not block_data:
            raise ValueError(f"Block not found: {block_id}")
        target_snapshot = block_data.get("snapshot_id")

        # 2. Record the state to restore on unrevert. If the session is already
        # reverted, keep the snapshot from before the first revert so unrevert
        # goes back to the original state instead of an intermediate one.
        previous = await self._load_revert_state(session_id)
        current_snapshot = None
        if previous is not None and previous.snapshot_id:
            current_snapshot = previous.snapshot_id
        elif self._snapshot:
            current_snapshot = await self._snapshot.track()

        # 3. Collect and revert the patches recorded after the target block.
        blocks_after = await self._block.list_blocks_after(session_id, block_id, branch)
        if self._snapshot and blocks_after:
            patches = []
            for item in blocks_after:
                # "data" may be missing or explicitly None.
                data = item.get("data") or {}
                if "patch" in data:
                    patches.append(Patch.from_dict(data["patch"]))
            if patches:
                await self._snapshot.revert(patches)

        # 4. Diff against the target block's snapshot, when both exist.
        diff = ""
        if self._snapshot and target_snapshot:
            diff = await self._snapshot.diff(target_snapshot)

        # 5. Create and store the revert state (including the branch filter,
        # so cleanup deletes exactly the blocks this revert covered).
        revert_state = RevertState(
            block_id=block_id,
            snapshot_id=current_snapshot or "",
            diff=diff,
            branch=branch,
        )
        self._revert_states[session_id] = revert_state
        await self._state.set(self.NAMESPACE, session_id, revert_state.to_dict())

        return revert_state

    async def unrevert(self, session_id: str) -> bool:
        """Undo a revert operation.

        Restores the snapshot recorded before the revert (when a snapshot
        backend and snapshot id exist) and clears the revert state.

        Args:
            session_id: Session to unrevert

        Returns:
            True if unrevert was performed, False if the session was not reverted
        """
        revert_state = await self._load_revert_state(session_id)
        if revert_state is None:
            return False

        if self._snapshot and revert_state.snapshot_id:
            await self._snapshot.restore(revert_state.snapshot_id)

        await self._clear_revert_state(session_id)
        return True

    async def cleanup(self, session_id: str) -> int:
        """Clean up revert state after new prompt.

        Deletes the blocks after the revert point, using the same branch
        filter as the revert, then clears the revert state.

        Args:
            session_id: Session to cleanup

        Returns:
            Number of blocks deleted (0 if the session was not reverted)
        """
        revert_state = await self._load_revert_state(session_id)
        if revert_state is None:
            return 0

        blocks_to_delete = await self._block.list_blocks_after(
            session_id,
            revert_state.block_id,
            revert_state.branch,
        )

        deleted = 0
        for item in blocks_to_delete:
            if await self._block.delete_block(session_id, item["id"]):
                deleted += 1

        await self._clear_revert_state(session_id)
        return deleted

    def get_revert_state(self, session_id: str) -> RevertState | None:
        """Get current revert state for session.

        Only consults the in-memory cache. State persisted by another
        instance is not visible here until it is loaded by an async method.
        """
        return self._revert_states.get(session_id)

    def is_reverted(self, session_id: str) -> bool:
        """Check if session is in reverted state (in-memory cache only)."""
        return session_id in self._revert_states
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""LLM Provider protocol and implementations."""
|
|
2
|
+
from .provider import (
|
|
3
|
+
Capabilities,
|
|
4
|
+
ToolCall,
|
|
5
|
+
Usage,
|
|
6
|
+
LLMEvent,
|
|
7
|
+
ToolDefinition,
|
|
8
|
+
LLMMessage,
|
|
9
|
+
LLMProvider,
|
|
10
|
+
MockResponse,
|
|
11
|
+
MockLLMProvider,
|
|
12
|
+
ToolCallMockProvider,
|
|
13
|
+
)
|
|
14
|
+
from .adapter import ModelClientProvider, create_provider
|
|
15
|
+
from .openai import OpenAIProvider
|
|
16
|
+
|
|
17
|
+
# Public API of the llm package; every name is re-exported from a submodule.
__all__ = [
    # Data types and the provider protocol (from .provider)
    "Capabilities",
    "ToolCall",
    "Usage",
    "LLMEvent",
    "ToolDefinition",
    "LLMMessage",
    "LLMProvider",
    # Mock providers (from .provider)
    "MockResponse",
    "MockLLMProvider",
    "ToolCallMockProvider",
    # aury-ai-model adapter and factory (from .adapter)
    "ModelClientProvider",
    "create_provider",
    # OpenAI provider (from .openai)
    "OpenAIProvider",
]
|
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
"""LLM Provider adapter using aury-ai-model ModelClient."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
from typing import Any, AsyncIterator
|
|
6
|
+
|
|
7
|
+
from .provider import (
|
|
8
|
+
LLMProvider,
|
|
9
|
+
LLMEvent,
|
|
10
|
+
LLMMessage,
|
|
11
|
+
ToolCall,
|
|
12
|
+
ToolDefinition,
|
|
13
|
+
Usage,
|
|
14
|
+
Capabilities,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Import from aury-ai-model
|
|
18
|
+
try:
|
|
19
|
+
from aury.ai.model import (
|
|
20
|
+
ModelClient,
|
|
21
|
+
Message,
|
|
22
|
+
StreamEvent,
|
|
23
|
+
msg,
|
|
24
|
+
Text,
|
|
25
|
+
Evt,
|
|
26
|
+
ToolCall as ModelToolCall,
|
|
27
|
+
ToolSpec,
|
|
28
|
+
FunctionToolSpec,
|
|
29
|
+
ToolKind,
|
|
30
|
+
StreamCollector,
|
|
31
|
+
)
|
|
32
|
+
HAS_MODEL_CLIENT = True
|
|
33
|
+
except ImportError:
|
|
34
|
+
HAS_MODEL_CLIENT = False
|
|
35
|
+
ModelClient = None # type: ignore
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ModelClientProvider:
|
|
39
|
+
"""LLM Provider using aury-ai-model ModelClient.
|
|
40
|
+
|
|
41
|
+
This adapter bridges the framework's LLMProvider protocol with
|
|
42
|
+
the aury-ai-model ModelClient.
|
|
43
|
+
|
|
44
|
+
Example:
|
|
45
|
+
>>> provider = ModelClientProvider(
|
|
46
|
+
... provider="openai",
|
|
47
|
+
... model="gpt-4o",
|
|
48
|
+
... )
|
|
49
|
+
>>> async for event in provider.complete(messages):
|
|
50
|
+
... print(event)
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
def __init__(
|
|
54
|
+
self,
|
|
55
|
+
provider: str,
|
|
56
|
+
model: str,
|
|
57
|
+
api_key: str | None = None,
|
|
58
|
+
base_url: str | None = None,
|
|
59
|
+
capabilities: Capabilities | None = None,
|
|
60
|
+
**kwargs: Any,
|
|
61
|
+
):
|
|
62
|
+
"""Initialize ModelClient provider.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
provider: Provider name (openai, anthropic, doubao, etc.)
|
|
66
|
+
model: Model name
|
|
67
|
+
api_key: API key (optional, uses env if not provided)
|
|
68
|
+
base_url: Base URL override
|
|
69
|
+
capabilities: Model capabilities
|
|
70
|
+
**kwargs: Additional ModelClient options
|
|
71
|
+
"""
|
|
72
|
+
if not HAS_MODEL_CLIENT:
|
|
73
|
+
raise ImportError(
|
|
74
|
+
"aury-ai-model is not installed. "
|
|
75
|
+
"Please install it: pip install aury-ai-model[all]"
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
self._provider_name = provider
|
|
79
|
+
self._model_name = model
|
|
80
|
+
self._capabilities = capabilities or Capabilities()
|
|
81
|
+
|
|
82
|
+
# Build ModelClient
|
|
83
|
+
client_kwargs = {
|
|
84
|
+
"provider": provider,
|
|
85
|
+
"model": model,
|
|
86
|
+
}
|
|
87
|
+
if api_key:
|
|
88
|
+
client_kwargs["api_key"] = api_key
|
|
89
|
+
if base_url:
|
|
90
|
+
client_kwargs["base_url"] = base_url
|
|
91
|
+
|
|
92
|
+
# Pass through additional options
|
|
93
|
+
for key in ("default_max_tokens", "default_temperature", "default_top_p"):
|
|
94
|
+
if key in kwargs:
|
|
95
|
+
client_kwargs[key] = kwargs[key]
|
|
96
|
+
|
|
97
|
+
self._client = ModelClient(**client_kwargs)
|
|
98
|
+
self._extra_kwargs = {
|
|
99
|
+
k: v for k, v in kwargs.items()
|
|
100
|
+
if k not in client_kwargs
|
|
101
|
+
}
|
|
102
|
+
self._call_count = 0
|
|
103
|
+
|
|
104
|
+
@property
|
|
105
|
+
def provider(self) -> str:
|
|
106
|
+
return self._provider_name
|
|
107
|
+
|
|
108
|
+
@property
|
|
109
|
+
def model(self) -> str:
|
|
110
|
+
return self._model_name
|
|
111
|
+
|
|
112
|
+
@property
|
|
113
|
+
def call_count(self) -> int:
|
|
114
|
+
"""Get number of LLM calls made."""
|
|
115
|
+
return self._call_count
|
|
116
|
+
|
|
117
|
+
@property
|
|
118
|
+
def capabilities(self) -> Capabilities:
|
|
119
|
+
"""Get model capabilities."""
|
|
120
|
+
return self._capabilities
|
|
121
|
+
|
|
122
|
+
def _convert_messages(self, messages: list[LLMMessage]) -> list[Message]:
|
|
123
|
+
"""Convert LLMMessage to aury-ai-model Message.
|
|
124
|
+
|
|
125
|
+
Supports all message types from aury.ai.model:
|
|
126
|
+
- system: msg.system(text)
|
|
127
|
+
- user: msg.user(text, images=[])
|
|
128
|
+
- assistant: msg.assistant(text, tool_calls=[])
|
|
129
|
+
- tool: msg.tool(result, tool_call_id)
|
|
130
|
+
"""
|
|
131
|
+
result = []
|
|
132
|
+
|
|
133
|
+
for m in messages:
|
|
134
|
+
if m.role == "system":
|
|
135
|
+
result.append(msg.system(
|
|
136
|
+
m.content if isinstance(m.content, str) else str(m.content)
|
|
137
|
+
))
|
|
138
|
+
|
|
139
|
+
elif m.role == "user":
|
|
140
|
+
if isinstance(m.content, str):
|
|
141
|
+
result.append(msg.user(m.content))
|
|
142
|
+
else:
|
|
143
|
+
# Handle multipart content (text + images)
|
|
144
|
+
text_parts = []
|
|
145
|
+
images = []
|
|
146
|
+
for part in m.content:
|
|
147
|
+
if isinstance(part, dict):
|
|
148
|
+
if part.get("type") == "text":
|
|
149
|
+
text_parts.append(part.get("text", ""))
|
|
150
|
+
elif part.get("type") == "image_url":
|
|
151
|
+
url = part.get("image_url", {}).get("url", "")
|
|
152
|
+
if url:
|
|
153
|
+
images.append(url)
|
|
154
|
+
result.append(msg.user(
|
|
155
|
+
text=" ".join(text_parts) if text_parts else None,
|
|
156
|
+
images=images if images else None,
|
|
157
|
+
))
|
|
158
|
+
|
|
159
|
+
elif m.role == "assistant":
|
|
160
|
+
if isinstance(m.content, str):
|
|
161
|
+
result.append(msg.assistant(m.content))
|
|
162
|
+
else:
|
|
163
|
+
# Handle tool calls in assistant message
|
|
164
|
+
text_parts = []
|
|
165
|
+
tool_calls = []
|
|
166
|
+
for part in m.content:
|
|
167
|
+
if isinstance(part, dict):
|
|
168
|
+
if part.get("type") == "text":
|
|
169
|
+
text_parts.append(part.get("text", ""))
|
|
170
|
+
elif part.get("type") == "tool_use":
|
|
171
|
+
tool_calls.append(ModelToolCall(
|
|
172
|
+
id=part.get("id", ""),
|
|
173
|
+
name=part.get("name", ""),
|
|
174
|
+
arguments_json=json.dumps(part.get("input", {})),
|
|
175
|
+
))
|
|
176
|
+
result.append(msg.assistant(
|
|
177
|
+
text=" ".join(text_parts) if text_parts else None,
|
|
178
|
+
tool_calls=tool_calls if tool_calls else None,
|
|
179
|
+
))
|
|
180
|
+
|
|
181
|
+
elif m.role == "tool":
|
|
182
|
+
# Tool result message - two formats supported:
|
|
183
|
+
# 1. Simple: LLMMessage(role="tool", content="result", tool_call_id="xxx")
|
|
184
|
+
# 2. List format: content=[{"type": "tool_result", "content": "...", "tool_use_id": "..."}]
|
|
185
|
+
if m.tool_call_id and isinstance(m.content, str):
|
|
186
|
+
# Simple format
|
|
187
|
+
result.append(msg.tool(
|
|
188
|
+
result=m.content,
|
|
189
|
+
tool_call_id=m.tool_call_id,
|
|
190
|
+
))
|
|
191
|
+
elif isinstance(m.content, list):
|
|
192
|
+
# List format (for compatibility)
|
|
193
|
+
for part in m.content:
|
|
194
|
+
if isinstance(part, dict) and part.get("type") == "tool_result":
|
|
195
|
+
result.append(msg.tool(
|
|
196
|
+
result=str(part.get("content", "")),
|
|
197
|
+
tool_call_id=part.get("tool_use_id", ""),
|
|
198
|
+
))
|
|
199
|
+
|
|
200
|
+
return result
|
|
201
|
+
|
|
202
|
+
def _convert_tools(
|
|
203
|
+
self,
|
|
204
|
+
tools: list[ToolDefinition] | None,
|
|
205
|
+
) -> list[ToolSpec] | None:
|
|
206
|
+
"""Convert ToolDefinition to aury-ai-model ToolSpec."""
|
|
207
|
+
if not tools:
|
|
208
|
+
return None
|
|
209
|
+
|
|
210
|
+
return [
|
|
211
|
+
ToolSpec(
|
|
212
|
+
kind=ToolKind.function,
|
|
213
|
+
function=FunctionToolSpec(
|
|
214
|
+
name=tool.name,
|
|
215
|
+
description=tool.description,
|
|
216
|
+
parameters=tool.input_schema,
|
|
217
|
+
),
|
|
218
|
+
)
|
|
219
|
+
for tool in tools
|
|
220
|
+
]
|
|
221
|
+
|
|
222
|
+
def _convert_stream_event(self, event: StreamEvent) -> LLMEvent | None:
|
|
223
|
+
"""Convert aury-ai-model StreamEvent to LLMEvent."""
|
|
224
|
+
match event.type:
|
|
225
|
+
case Evt.content:
|
|
226
|
+
return LLMEvent(type="content", delta=event.delta)
|
|
227
|
+
|
|
228
|
+
case Evt.thinking:
|
|
229
|
+
return LLMEvent(type="thinking", delta=event.delta)
|
|
230
|
+
|
|
231
|
+
case Evt.tool_call_start:
|
|
232
|
+
if event.tool_call:
|
|
233
|
+
return LLMEvent(
|
|
234
|
+
type="tool_call_start",
|
|
235
|
+
tool_call=ToolCall(
|
|
236
|
+
id=event.tool_call.id,
|
|
237
|
+
name=event.tool_call.name,
|
|
238
|
+
arguments="", # Empty at start
|
|
239
|
+
),
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
case Evt.tool_call_delta:
|
|
243
|
+
if event.tool_call_delta:
|
|
244
|
+
return LLMEvent(
|
|
245
|
+
type="tool_call_delta",
|
|
246
|
+
tool_call_delta=event.tool_call_delta,
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
case Evt.tool_call_progress:
|
|
250
|
+
if event.tool_call_progress:
|
|
251
|
+
return LLMEvent(
|
|
252
|
+
type="tool_call_progress",
|
|
253
|
+
tool_call_progress=event.tool_call_progress,
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
case Evt.tool_call:
|
|
257
|
+
if event.tool_call:
|
|
258
|
+
return LLMEvent(
|
|
259
|
+
type="tool_call",
|
|
260
|
+
tool_call=ToolCall(
|
|
261
|
+
id=event.tool_call.id,
|
|
262
|
+
name=event.tool_call.name,
|
|
263
|
+
arguments=event.tool_call.arguments_json,
|
|
264
|
+
),
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
case Evt.usage:
|
|
268
|
+
if event.usage:
|
|
269
|
+
return LLMEvent(
|
|
270
|
+
type="usage",
|
|
271
|
+
usage=Usage(
|
|
272
|
+
input_tokens=event.usage.input_tokens,
|
|
273
|
+
output_tokens=event.usage.output_tokens,
|
|
274
|
+
cache_read_tokens=getattr(event.usage, 'cache_read_tokens', 0),
|
|
275
|
+
cache_write_tokens=getattr(event.usage, 'cache_write_tokens', 0),
|
|
276
|
+
reasoning_tokens=getattr(event.usage, 'reasoning_tokens', 0),
|
|
277
|
+
),
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
case Evt.completed:
|
|
281
|
+
return LLMEvent(type="completed", finish_reason="end_turn")
|
|
282
|
+
|
|
283
|
+
case Evt.error:
|
|
284
|
+
return LLMEvent(type="error", error=event.error)
|
|
285
|
+
|
|
286
|
+
return None
|
|
287
|
+
|
|
288
|
+
async def complete(
|
|
289
|
+
self,
|
|
290
|
+
messages: list[LLMMessage],
|
|
291
|
+
tools: list[ToolDefinition] | None = None,
|
|
292
|
+
enable_thinking: bool = False,
|
|
293
|
+
reasoning_effort: str | None = None,
|
|
294
|
+
**kwargs: Any,
|
|
295
|
+
) -> AsyncIterator[LLMEvent]:
|
|
296
|
+
"""Generate completion with streaming.
|
|
297
|
+
|
|
298
|
+
Streaming is enabled by default - this method uses ModelClient.astream()
|
|
299
|
+
which always streams responses incrementally.
|
|
300
|
+
|
|
301
|
+
Args:
|
|
302
|
+
messages: Conversation messages
|
|
303
|
+
tools: Available tools
|
|
304
|
+
enable_thinking: Whether to request thinking output
|
|
305
|
+
reasoning_effort: Reasoning effort level ("low", "medium", "high", "max", "auto")
|
|
306
|
+
**kwargs: Additional parameters (temperature, max_tokens, etc.)
|
|
307
|
+
|
|
308
|
+
Yields:
|
|
309
|
+
LLMEvent: Streaming events (content, thinking, tool_call, usage, completed, error)
|
|
310
|
+
"""
|
|
311
|
+
# Convert messages and tools
|
|
312
|
+
model_messages = self._convert_messages(messages)
|
|
313
|
+
model_tools = self._convert_tools(tools)
|
|
314
|
+
|
|
315
|
+
# Merge kwargs
|
|
316
|
+
call_kwargs = {**self._extra_kwargs, **kwargs}
|
|
317
|
+
if model_tools:
|
|
318
|
+
call_kwargs["tools"] = model_tools
|
|
319
|
+
|
|
320
|
+
# Add thinking configuration (for models that support it)
|
|
321
|
+
if enable_thinking:
|
|
322
|
+
call_kwargs["return_thinking"] = True
|
|
323
|
+
if reasoning_effort:
|
|
324
|
+
call_kwargs["reasoning_effort"] = reasoning_effort
|
|
325
|
+
|
|
326
|
+
# Increment call count
|
|
327
|
+
self._call_count += 1
|
|
328
|
+
|
|
329
|
+
# Remove stream from kwargs if present (astream always streams, doesn't accept stream param)
|
|
330
|
+
call_kwargs.pop('stream', None)
|
|
331
|
+
|
|
332
|
+
# Ensure usage events are yielded (for statistics tracking)
|
|
333
|
+
# This ensures usage events are included in the stream
|
|
334
|
+
yield_usage_event = call_kwargs.pop('yield_usage_event', True)
|
|
335
|
+
|
|
336
|
+
# Stream from ModelClient with retry support
|
|
337
|
+
# astream() always streams incrementally - events arrive as they're generated
|
|
338
|
+
async for event in self._client.with_retry(
|
|
339
|
+
max_attempts=3,
|
|
340
|
+
base_delay=1.0,
|
|
341
|
+
max_delay=10.0,
|
|
342
|
+
).astream(
|
|
343
|
+
model_messages,
|
|
344
|
+
yield_usage_event=yield_usage_event,
|
|
345
|
+
**call_kwargs
|
|
346
|
+
):
|
|
347
|
+
converted = self._convert_stream_event(event)
|
|
348
|
+
if converted:
|
|
349
|
+
yield converted
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def create_provider(
    provider: str,
    model: str,
    **kwargs: Any,
) -> LLMProvider:
    """Build the default LLM provider for ``provider``/``model``.

    Every provider name is currently served by ``ModelClientProvider``,
    which delegates to aury-ai-model's ``ModelClient``.

    Args:
        provider: Provider name (openai, anthropic, doubao, etc.)
        model: Model name
        **kwargs: Forwarded unchanged to ``ModelClientProvider``
            (api_key, base_url, capabilities, client and per-call options).

    Returns:
        LLMProvider instance

    Raises:
        ImportError: If aury-ai-model is not installed.
    """
    options: dict[str, Any] = dict(kwargs)
    instance = ModelClientProvider(provider=provider, model=model, **options)
    return instance
|