agentinc-sdk 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentinc/__init__.py +0 -0
- agentinc/sdk/__init__.py +39 -0
- agentinc/sdk/agent.py +285 -0
- agentinc/sdk/memory/__init__.py +14 -0
- agentinc/sdk/memory/base.py +12 -0
- agentinc/sdk/memory/redis.py +48 -0
- agentinc/sdk/protocol.py +19 -0
- agentinc/sdk/providers/__init__.py +29 -0
- agentinc/sdk/providers/anthropic.py +107 -0
- agentinc/sdk/providers/base.py +15 -0
- agentinc/sdk/providers/gemini.py +117 -0
- agentinc/sdk/providers/openai.py +127 -0
- agentinc/sdk/raw.py +134 -0
- agentinc/sdk/schemas.py +76 -0
- agentinc/sdk/serve.py +162 -0
- agentinc/sdk/tool.py +119 -0
- agentinc_sdk-0.2.0.dist-info/METADATA +240 -0
- agentinc_sdk-0.2.0.dist-info/RECORD +20 -0
- agentinc_sdk-0.2.0.dist-info/WHEEL +4 -0
- agentinc_sdk-0.2.0.dist-info/licenses/LICENSE +190 -0
agentinc/__init__.py
ADDED
|
File without changes
|
agentinc/sdk/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from .agent import Agent
|
|
2
|
+
from .protocol import AgentFactory, AgentProtocol, ToolProtocol
|
|
3
|
+
from .raw import RawAdapter
|
|
4
|
+
from .schemas import (
|
|
5
|
+
AgentInput,
|
|
6
|
+
AgentOutput,
|
|
7
|
+
DataConfig,
|
|
8
|
+
MCPConfig,
|
|
9
|
+
MemoryConfig,
|
|
10
|
+
Message,
|
|
11
|
+
ModelConfig,
|
|
12
|
+
ToolCall,
|
|
13
|
+
ToolSchema,
|
|
14
|
+
)
|
|
15
|
+
from .tool import ToolWrapper, tool
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
# Core
|
|
19
|
+
"Agent",
|
|
20
|
+
"AgentProtocol",
|
|
21
|
+
"ToolProtocol",
|
|
22
|
+
"AgentFactory",
|
|
23
|
+
# Schemas
|
|
24
|
+
"AgentInput",
|
|
25
|
+
"AgentOutput",
|
|
26
|
+
"Message",
|
|
27
|
+
"ToolCall",
|
|
28
|
+
"ToolSchema",
|
|
29
|
+
# Config
|
|
30
|
+
"ModelConfig",
|
|
31
|
+
"MemoryConfig",
|
|
32
|
+
"MCPConfig",
|
|
33
|
+
"DataConfig",
|
|
34
|
+
# Tools
|
|
35
|
+
"ToolWrapper",
|
|
36
|
+
"tool",
|
|
37
|
+
# Deprecated
|
|
38
|
+
"RawAdapter",
|
|
39
|
+
]
|
agentinc/sdk/agent.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import uuid
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any, AsyncIterator
|
|
7
|
+
|
|
8
|
+
from .memory import memory_for
|
|
9
|
+
from .memory.base import Memory
|
|
10
|
+
from .providers import provider_for
|
|
11
|
+
from .providers.base import Provider
|
|
12
|
+
from .schemas import (
|
|
13
|
+
AgentInput,
|
|
14
|
+
AgentOutput,
|
|
15
|
+
DataConfig,
|
|
16
|
+
MCPConfig,
|
|
17
|
+
MemoryConfig,
|
|
18
|
+
Message,
|
|
19
|
+
ModelConfig,
|
|
20
|
+
ToolCall,
|
|
21
|
+
ToolSchema,
|
|
22
|
+
)
|
|
23
|
+
from .tool import ToolWrapper, _build_schema
|
|
24
|
+
|
|
25
|
+
log = logging.getLogger("agentinc.sdk.agent")


def _wrap_tool(fn: Callable) -> ToolWrapper:
    """Coerce *fn* into a ``ToolWrapper``.

    Existing ``ToolWrapper`` instances are returned unchanged. Plain callables
    get a schema built by ``_build_schema`` from the function's ``__name__``
    and its docstring (an empty description when it has none).
    """
    if not isinstance(fn, ToolWrapper):
        description = fn.__doc__ or ""
        built = _build_schema(fn, name=fn.__name__, description=description)
        return ToolWrapper(fn=fn, schema=built)
    return fn
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Agent:
    """
    High-level agent that wires together a provider, tools, MCP servers,
    Redis memory, and optional context into a single AgentProtocol-compatible object.

    Usage::

        agent = Agent(
            role="You are a helpful assistant.",
            model={"model": "gpt-4o-mini", "api_key": "sk-…"},
            tools=[my_function],
            memory={"type": "redis", "connection": "redis://localhost:6379"},
        )
        serve(agent, name="my-agent", port=8000)

    Args:
        role: System prompt text; always the first part of the system message.
        model: Model configuration; ``provider_for`` selects the backend from it.
        tools: Plain callables (wrapped via ``_wrap_tool``) or ``ToolWrapper``s.
        mcps: MCP server configs whose tools are discovered on the first ``run()``.
        memory: Optional memory-backend config passed to ``memory_for``.
        context: Optional extra text appended to the system prompt.
        data: Reserved for RAG; currently accepted and only logged.
    """

    def __init__(
        self,
        role: str,
        model: ModelConfig,
        tools: list[Callable] | None = None,
        mcps: list[MCPConfig] | None = None,
        memory: MemoryConfig | None = None,
        context: str | None = None,
        data: DataConfig | None = None,
    ) -> None:
        self._role = role
        self._context = context
        self._model_config = model

        self._provider: Provider = provider_for(model)
        self._memory: Memory | None = memory_for(memory) if memory else None
        self._mcp_configs: list[MCPConfig] = mcps or []

        # Wrap plain functions as ToolWrapper instances, keyed by schema name
        # (a later tool with the same name replaces an earlier one).
        self._local_tools: dict[str, ToolWrapper] = {
            w.schema().name: w for w in (_wrap_tool(t) for t in (tools or []))
        }

        # MCP tool registry — populated on first run()
        self._mcp_tools: dict[str, Any] = {}
        self._mcp_ready = False

        if data is not None:
            log.debug("data= parameter accepted but not yet implemented (reserved for RAG)")

    # ------------------------------------------------------------------
    # AgentProtocol
    # ------------------------------------------------------------------

    async def run(self, input: AgentInput) -> AsyncIterator[AgentOutput]:
        """Stream the agent's answer to ``input.message``.

        Text chunks from the provider are yielded as they arrive. When the
        model requests tools, they are executed and their results are fed back
        to the provider. This repeats until a provider round ends without tool
        calls, after which a final ``AgentOutput(content="", done=True)`` is
        yielded. If memory is configured, the user message and the assistant's
        visible reply are then persisted under ``metadata["session_id"]``.
        """
        if self._mcp_configs and not self._mcp_ready:
            await self._init_mcp()

        session_id = input.metadata.get("session_id") or str(uuid.uuid4())

        # Load history from memory backend (empty list if no memory configured)
        persisted: list[Message] = []
        if self._memory:
            persisted = await self._memory.load(session_id)

        # A non-empty caller-supplied history *replaces* the persisted one
        # (they are not merged).
        history = list(input.history) if input.history else persisted

        # Build system prompt
        system_parts = [self._role]
        if self._context:
            system_parts.append(self._context)

        messages: list[dict] = [{"role": "system", "content": "\n\n".join(system_parts)}]
        messages.extend(self._message_to_dict(msg) for msg in history)
        messages.append({"role": "user", "content": input.message})

        # Collect all tool schemas
        tool_schemas = [w.schema() for w in self._local_tools.values()]
        tool_schemas += [
            ToolSchema(name=name, description=spec["description"], parameters=spec["parameters"])
            for name, spec in self._mcp_tools.items()
        ]

        # Agentic loop
        new_messages: list[dict] = []
        reply_parts: list[str] = []  # all text streamed to the caller, across rounds
        while True:
            tool_calls_batch: list[ToolCall] = []
            turn_text: list[str] = []

            async for chunk in self._provider.complete(messages + new_messages, tool_schemas):
                if chunk.tool_calls:
                    tool_calls_batch.extend(chunk.tool_calls)
                elif chunk.content:
                    turn_text.append(chunk.content)
                    yield chunk
                if chunk.done:
                    break

            reply_parts.extend(turn_text)

            if not tool_calls_batch:
                yield AgentOutput(content="", done=True)
                break

            # Append the assistant tool-call turn. Any text the model produced
            # before requesting tools belongs to the same turn, so keep it.
            assistant_turn: dict = {
                "role": "assistant",
                "tool_calls": [self._tool_call_to_dict(tc) for tc in tool_calls_batch],
            }
            if turn_text:
                assistant_turn["content"] = "".join(turn_text)
            new_messages.append(assistant_turn)

            # Dispatch all tool calls and append results
            for tc in tool_calls_batch:
                result = await self._dispatch_tool(tc)
                new_messages.append({
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": result,
                })

        # Persist the plain conversation: the user's message and the assistant's
        # visible reply. Intermediate tool-call / tool-result messages are not
        # stored, because a "tool" message saved without its tool_call_id and
        # matching assistant "tool_calls" turn is rejected on the next request.
        if self._memory:
            updated = list(history)
            updated.append(Message(role="user", content=input.message))
            reply = "".join(reply_parts)
            if reply:
                updated.append(Message(role="assistant", content=reply))
            await self._memory.save(session_id, updated)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    async def _dispatch_tool(self, tc: ToolCall) -> str:
        """Run *tc* against local tools first, then MCP tools.

        Returns the tool's result text, or an ``"Error: ..."`` string that is
        fed back to the model when the tool is unknown or raises.
        """
        if tc.name in self._local_tools:
            try:
                return await self._local_tools[tc.name].call(tc)
            except Exception as exc:
                # Mirror the MCP path: report the failure to the model instead
                # of aborting the whole run, but keep the traceback in the logs.
                log.exception("Local tool %r raised", tc.name)
                return f"Error calling tool '{tc.name}': {exc}"
        if tc.name in self._mcp_tools:
            return await self._call_mcp_tool(tc)
        return f"Error: tool '{tc.name}' not found"

    async def _init_mcp(self) -> None:
        """Discover tools from every configured MCP server.

        A server that fails is logged and skipped. The registry is marked
        ready afterwards either way, so failed servers are not retried.
        """
        for config in self._mcp_configs:
            try:
                tools = await _fetch_mcp_tools(config)
                self._mcp_tools.update(tools)
            except Exception:
                log.exception("Failed to initialise MCP server: %s", config)
        self._mcp_ready = True

    async def _call_mcp_tool(self, tc: ToolCall) -> str:
        """Invoke an MCP tool; errors are returned as text for the model."""
        spec = self._mcp_tools.get(tc.name)
        if spec is None:
            return f"Error: MCP tool '{tc.name}' not found"
        try:
            return await _invoke_mcp_tool(spec["config"], tc.name, tc.arguments)
        except Exception as exc:
            return f"Error calling MCP tool '{tc.name}': {exc}"

    @staticmethod
    def _tool_call_to_dict(tc: ToolCall) -> dict:
        """Render *tc* as an OpenAI-style ``tool_calls`` entry.

        In that format ``arguments`` is a JSON string. ``str(dict)`` would
        produce a Python repr (single quotes, ``True``/``None``), which is not
        valid JSON, so the arguments are encoded with ``json.dumps``.
        """
        import json

        arguments = tc.arguments if isinstance(tc.arguments, str) else json.dumps(tc.arguments)
        return {
            "id": tc.id,
            "type": "function",
            "function": {"name": tc.name, "arguments": arguments},
        }

    @staticmethod
    def _message_to_dict(msg: Message) -> dict:
        """Convert a stored ``Message`` into an OpenAI-style chat message dict."""
        d: dict = {"role": msg.role, "content": msg.content or ""}
        if msg.tool_calls:
            d["tool_calls"] = [Agent._tool_call_to_dict(tc) for tc in msg.tool_calls]
        if msg.tool_call_id:
            d["tool_call_id"] = msg.tool_call_id
        return d
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# ------------------------------------------------------------------
|
|
215
|
+
# MCP helpers (lazy-import mcp package)
|
|
216
|
+
# ------------------------------------------------------------------
|
|
217
|
+
|
|
218
|
+
async def _fetch_mcp_tools(config: MCPConfig) -> dict[str, dict]:
|
|
219
|
+
"""Connect to an MCP server and return its tool schemas."""
|
|
220
|
+
try:
|
|
221
|
+
from mcp import ClientSession, StdioServerParameters
|
|
222
|
+
from mcp.client.stdio import stdio_client
|
|
223
|
+
from mcp.client.sse import sse_client
|
|
224
|
+
except ImportError:
|
|
225
|
+
raise ImportError(
|
|
226
|
+
"MCP support requires the mcp extra: pip install 'agentinc-sdk[mcp]'"
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
tools: dict[str, dict] = {}
|
|
230
|
+
|
|
231
|
+
if config["type"] == "stdio":
|
|
232
|
+
server_params = StdioServerParameters(
|
|
233
|
+
command=config["command"],
|
|
234
|
+
args=config.get("args", []),
|
|
235
|
+
)
|
|
236
|
+
async with stdio_client(server_params) as (read, write):
|
|
237
|
+
async with ClientSession(read, write) as session:
|
|
238
|
+
await session.initialize()
|
|
239
|
+
result = await session.list_tools()
|
|
240
|
+
for tool in result.tools:
|
|
241
|
+
tools[tool.name] = {
|
|
242
|
+
"description": tool.description or "",
|
|
243
|
+
"parameters": tool.inputSchema or {"type": "object", "properties": {}},
|
|
244
|
+
"config": config,
|
|
245
|
+
}
|
|
246
|
+
|
|
247
|
+
elif config["type"] == "sse":
|
|
248
|
+
async with sse_client(config["url"]) as (read, write):
|
|
249
|
+
async with ClientSession(read, write) as session:
|
|
250
|
+
await session.initialize()
|
|
251
|
+
result = await session.list_tools()
|
|
252
|
+
for tool in result.tools:
|
|
253
|
+
tools[tool.name] = {
|
|
254
|
+
"description": tool.description or "",
|
|
255
|
+
"parameters": tool.inputSchema or {"type": "object", "properties": {}},
|
|
256
|
+
"config": config,
|
|
257
|
+
}
|
|
258
|
+
|
|
259
|
+
return tools
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
async def _invoke_mcp_tool(config: MCPConfig, name: str, arguments: dict) -> str:
|
|
263
|
+
from mcp import ClientSession, StdioServerParameters
|
|
264
|
+
from mcp.client.stdio import stdio_client
|
|
265
|
+
from mcp.client.sse import sse_client
|
|
266
|
+
|
|
267
|
+
if config["type"] == "stdio":
|
|
268
|
+
server_params = StdioServerParameters(
|
|
269
|
+
command=config["command"],
|
|
270
|
+
args=config.get("args", []),
|
|
271
|
+
)
|
|
272
|
+
async with stdio_client(server_params) as (read, write):
|
|
273
|
+
async with ClientSession(read, write) as session:
|
|
274
|
+
await session.initialize()
|
|
275
|
+
result = await session.call_tool(name, arguments)
|
|
276
|
+
return str(result.content)
|
|
277
|
+
|
|
278
|
+
elif config["type"] == "sse":
|
|
279
|
+
async with sse_client(config["url"]) as (read, write):
|
|
280
|
+
async with ClientSession(read, write) as session:
|
|
281
|
+
await session.initialize()
|
|
282
|
+
result = await session.call_tool(name, arguments)
|
|
283
|
+
return str(result.content)
|
|
284
|
+
|
|
285
|
+
return "Error: unsupported MCP transport"
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from ..schemas import MemoryConfig
|
|
4
|
+
from .base import Memory
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def memory_for(config: MemoryConfig) -> Memory:
|
|
8
|
+
if config["type"] == "redis":
|
|
9
|
+
from .redis import RedisMemory
|
|
10
|
+
return RedisMemory(config)
|
|
11
|
+
raise ValueError(f"Unsupported memory type: {config['type']!r}")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
__all__ = ["Memory", "memory_for"]
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Protocol, runtime_checkable
|
|
4
|
+
|
|
5
|
+
from ..schemas import Message
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@runtime_checkable
class Memory(Protocol):
    """Structural interface for per-session conversation-history stores.

    Any object with async ``load``, ``save`` and ``clear`` methods of these
    shapes satisfies it (``isinstance`` checks only that the methods exist).
    """

    async def load(self, session_id: str) -> list[Message]:
        """Return the stored history for *session_id*."""
        ...

    async def save(self, session_id: str, history: list[Message]) -> None:
        """Store *history* for *session_id*."""
        ...

    async def clear(self, session_id: str) -> None:
        """Remove the stored history for *session_id*."""
        ...
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
from ..schemas import MemoryConfig, Message
|
|
7
|
+
|
|
8
|
+
log = logging.getLogger("agentinc.sdk.memory.redis")

_TTL = 86400  # 24 hours
_KEY_PREFIX = "agentinc:session"


class RedisMemory:
    """Conversation history stored as one JSON string per session in Redis.

    Every ``save`` overwrites the session's whole history and (re)sets its
    expiry to ``_TTL`` seconds.
    """

    def __init__(self, config: MemoryConfig) -> None:
        try:
            import redis.asyncio as aioredis
        except ImportError as exc:
            raise ImportError(
                "Redis memory requires the memory extra: "
                "pip install 'agentinc-sdk[memory]'"
            ) from exc
        self._redis = aioredis.from_url(
            config["connection"],
            username=config.get("user"),
            password=config.get("password"),
            decode_responses=True,
        )

    def _key(self, session_id: str) -> str:
        """Redis key holding the history of *session_id*."""
        return f"{_KEY_PREFIX}:{session_id}:history"

    async def load(self, session_id: str) -> list[Message]:
        """Return the stored history, or ``[]`` if it is missing or unreadable."""
        raw = await self._redis.get(self._key(session_id))
        if not raw:
            return []
        try:
            return [Message(**m) for m in json.loads(raw)]
        except Exception:
            # Best-effort: a corrupt or schema-incompatible payload starts a
            # fresh history instead of failing the request, but keep the
            # traceback so the cause can be diagnosed.
            log.warning(
                "Failed to deserialise history for session %s", session_id, exc_info=True
            )
            return []

    async def save(self, session_id: str, history: list[Message]) -> None:
        """Overwrite the session's history and refresh its expiry."""
        payload = json.dumps([m.model_dump() for m in history])
        await self._redis.set(self._key(session_id), payload, ex=_TTL)

    async def clear(self, session_id: str) -> None:
        """Delete the session's stored history."""
        await self._redis.delete(self._key(session_id))
|
agentinc/sdk/protocol.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, AsyncIterator, Protocol, runtime_checkable
|
|
4
|
+
|
|
5
|
+
from .schemas import AgentInput, AgentOutput, ToolCall, ToolSchema
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@runtime_checkable
class ToolProtocol(Protocol):
    """Structural interface for a tool the agent can offer to a model."""

    def schema(self) -> ToolSchema:
        """Describe the tool (name, description, parameters) for the model."""
        ...

    async def call(self, tool_call: ToolCall) -> str:
        """Execute *tool_call* and return the result text for the model."""
        ...
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@runtime_checkable
class AgentProtocol(Protocol):
    """Structural interface for anything that can answer an ``AgentInput``.

    ``run`` is declared as a plain ``def`` returning an async iterator, which
    an ``async def`` generator method satisfies.
    """

    def run(self, input: AgentInput) -> AsyncIterator[AgentOutput]:
        """Stream the response to *input* as ``AgentOutput`` chunks."""
        ...


# Exported alias; currently just ``Any``, so it imposes no constraint.
# NOTE(review): presumably meant for agent-factory callables — confirm before narrowing.
AgentFactory = Any
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from ..schemas import ModelConfig
|
|
4
|
+
from .base import Provider
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def provider_for(config: ModelConfig) -> Provider:
|
|
8
|
+
"""Return the correct Provider instance based on model name or base_url."""
|
|
9
|
+
model = config["model"]
|
|
10
|
+
base_url = config.get("base_url")
|
|
11
|
+
|
|
12
|
+
if base_url:
|
|
13
|
+
from .openai import OpenAIProvider
|
|
14
|
+
return OpenAIProvider(config)
|
|
15
|
+
|
|
16
|
+
if model.startswith("claude"):
|
|
17
|
+
from .anthropic import AnthropicProvider
|
|
18
|
+
return AnthropicProvider(config)
|
|
19
|
+
|
|
20
|
+
if model.startswith("gemini"):
|
|
21
|
+
from .gemini import GeminiProvider
|
|
22
|
+
return GeminiProvider(config)
|
|
23
|
+
|
|
24
|
+
# Default: OpenAI (covers gpt-*, o1-*, o3-*, etc.)
|
|
25
|
+
from .openai import OpenAIProvider
|
|
26
|
+
return OpenAIProvider(config)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
__all__ = ["Provider", "provider_for"]
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import AsyncIterator
|
|
5
|
+
|
|
6
|
+
from ..schemas import AgentOutput, ModelConfig, ToolCall, ToolSchema
|
|
7
|
+
|
|
8
|
+
log = logging.getLogger("agentinc.sdk.providers.anthropic")
|
|
9
|
+
|
|
10
|
+
_SYSTEM_KEY = "__system__"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _strip_system(messages: list[dict]) -> tuple[str, list[dict]]:
|
|
14
|
+
"""Pull the system message out; Anthropic takes it as a separate param."""
|
|
15
|
+
system = ""
|
|
16
|
+
rest = []
|
|
17
|
+
for m in messages:
|
|
18
|
+
if m["role"] == "system":
|
|
19
|
+
system = m.get("content", "")
|
|
20
|
+
else:
|
|
21
|
+
rest.append(m)
|
|
22
|
+
return system, rest
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AnthropicProvider:
    """Provider backed by the Anthropic Messages API (``anthropic.AsyncAnthropic``).

    It accepts the OpenAI-style message dicts produced by the agent loop and
    converts them to Anthropic's format:
    * the system prompt is passed separately;
    * assistant ``tool_calls`` become ``tool_use`` content blocks;
    * ``tool`` messages become ``tool_result`` blocks in a user turn.
    """

    def __init__(self, config: ModelConfig) -> None:
        try:
            import anthropic as _anthropic
        except ImportError as exc:
            raise ImportError(
                "Anthropic provider requires the anthropic extra: "
                "pip install 'agentinc-sdk[anthropic]'"
            ) from exc
        self._anthropic = _anthropic
        self._client = self._anthropic.AsyncAnthropic(api_key=config["api_key"])
        self._model = config["model"]
        # The Messages API requires max_tokens; an optional "max_tokens" config
        # key overrides the 4096 default.
        self._max_tokens = config.get("max_tokens") or 4096

    async def complete(
        self,
        messages: list[dict],
        tools: list[ToolSchema],
        stream: bool = True,
    ) -> AsyncIterator[AgentOutput]:
        """Generate a completion for OpenAI-style *messages*.

        Yields text chunks and then either a chunk carrying ``tool_calls``
        (when the model stopped to use tools) or a ``done=True`` chunk.
        """
        system, msgs = _strip_system(messages)

        kwargs: dict = {
            "model": self._model,
            "max_tokens": self._max_tokens,
            "messages": self._to_anthropic_messages(msgs),
        }
        if system:
            kwargs["system"] = system
        if tools:
            kwargs["tools"] = [
                {"name": t.name, "description": t.description, "input_schema": t.parameters}
                for t in tools
            ]

        chunks = self._stream(kwargs) if stream else self._blocking(kwargs)
        async for chunk in chunks:
            yield chunk

    async def _stream(self, kwargs: dict) -> AsyncIterator[AgentOutput]:
        """Streaming path: yield text deltas, then tool calls or a done marker."""
        async with self._client.messages.stream(**kwargs) as stream:
            async for text in stream.text_stream:
                yield AgentOutput(content=text, done=False)
            msg = await stream.get_final_message()

        if msg.stop_reason == "tool_use":
            yield AgentOutput(tool_calls=self._tool_calls_from(msg), done=False)
        else:
            yield AgentOutput(content="", done=True)

    async def _blocking(self, kwargs: dict) -> AsyncIterator[AgentOutput]:
        """Non-streaming path: one request, then tool calls or the full text."""
        msg = await self._client.messages.create(**kwargs)

        if msg.stop_reason == "tool_use":
            yield AgentOutput(tool_calls=self._tool_calls_from(msg), done=False)
        else:
            text = "".join(b.text for b in msg.content if hasattr(b, "text"))
            yield AgentOutput(content=text, done=True)

    @staticmethod
    def _tool_calls_from(msg) -> list[ToolCall]:
        """Extract ``tool_use`` blocks from an Anthropic message as ToolCalls."""
        return [
            ToolCall(id=block.id, name=block.name, arguments=block.input or {})
            for block in msg.content
            if block.type == "tool_use"
        ]

    @staticmethod
    def _to_anthropic_messages(messages: list[dict]) -> list[dict]:
        """Convert OpenAI-style non-system messages to Anthropic message params.

        * ``assistant`` with ``tool_calls`` -> optional text block plus one
          ``tool_use`` block per call (arguments parsed back into a dict).
        * ``tool`` -> ``tool_result`` block in a user turn. Consecutive tool
          results are grouped into a single user turn.
        * anything else -> ``{"role", "content"}``. Extra keys such as
          ``tool_call_id`` are dropped because the Anthropic API does not
          accept them.
        """
        import ast
        import json

        def _parse_arguments(raw) -> dict:
            # Arguments are normally a JSON string; older callers produced a
            # Python repr via str(dict), which literal_eval parses safely.
            if isinstance(raw, dict):
                return raw
            if not raw:
                return {}
            for loader in (json.loads, ast.literal_eval):
                try:
                    value = loader(raw)
                except (ValueError, SyntaxError, TypeError):
                    continue
                if isinstance(value, dict):
                    return value
            return {}

        converted: list[dict] = []
        for m in messages:
            role = m["role"]
            if role == "tool":
                block = {
                    "type": "tool_result",
                    "tool_use_id": m.get("tool_call_id", ""),
                    "content": m.get("content") or "",
                }
                previous = converted[-1] if converted else None
                if (
                    previous is not None
                    and previous["role"] == "user"
                    and isinstance(previous["content"], list)
                    and all(b.get("type") == "tool_result" for b in previous["content"])
                ):
                    previous["content"].append(block)
                else:
                    converted.append({"role": "user", "content": [block]})
            elif role == "assistant" and m.get("tool_calls"):
                blocks: list[dict] = []
                if m.get("content"):
                    blocks.append({"type": "text", "text": m["content"]})
                for call in m["tool_calls"]:
                    fn = call.get("function") or {}
                    blocks.append({
                        "type": "tool_use",
                        "id": call["id"],
                        "name": fn.get("name", ""),
                        "input": _parse_arguments(fn.get("arguments")),
                    })
                converted.append({"role": "assistant", "content": blocks})
            else:
                converted.append({"role": role, "content": m.get("content") or ""})
        return converted
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import AsyncIterator, Protocol, runtime_checkable
|
|
4
|
+
|
|
5
|
+
from ..schemas import AgentOutput, ToolSchema
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@runtime_checkable
class Provider(Protocol):
    """Structural interface for LLM backends used by the agent loop.

    ``complete`` is declared as a plain ``def`` returning an ``AsyncIterator``.
    Implementations are ``async def`` generator functions, and calling one
    returns the iterator directly; the agent consumes it with ``async for`` and
    never awaits it. Declaring it ``async def`` here would type the call as
    returning a coroutine, matching neither implementations nor callers. This
    mirrors how ``AgentProtocol.run`` is declared.
    """

    def complete(
        self,
        messages: list[dict],
        tools: list[ToolSchema],
        stream: bool = True,
    ) -> AsyncIterator[AgentOutput]:
        """Yield ``AgentOutput`` chunks for OpenAI-style *messages*, offering *tools*."""
        ...
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import AsyncIterator
|
|
5
|
+
|
|
6
|
+
from ..schemas import AgentOutput, ModelConfig, ToolCall, ToolSchema
|
|
7
|
+
|
|
8
|
+
log = logging.getLogger("agentinc.sdk.providers.gemini")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _to_gemini_messages(messages: list[dict]) -> tuple[str, list[dict]]:
|
|
12
|
+
"""Convert OpenAI-style messages to Gemini contents + system instruction."""
|
|
13
|
+
system = ""
|
|
14
|
+
contents = []
|
|
15
|
+
for m in messages:
|
|
16
|
+
role = m["role"]
|
|
17
|
+
content = m.get("content", "")
|
|
18
|
+
if role == "system":
|
|
19
|
+
system = content
|
|
20
|
+
elif role == "assistant":
|
|
21
|
+
contents.append({"role": "model", "parts": [{"text": content}]})
|
|
22
|
+
elif role == "tool":
|
|
23
|
+
contents.append({
|
|
24
|
+
"role": "user",
|
|
25
|
+
"parts": [{"function_response": {
|
|
26
|
+
"name": m.get("name", ""),
|
|
27
|
+
"response": {"result": content},
|
|
28
|
+
}}],
|
|
29
|
+
})
|
|
30
|
+
else:
|
|
31
|
+
contents.append({"role": "user", "parts": [{"text": content}]})
|
|
32
|
+
return system, contents
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class GeminiProvider:
    """Provider backed by Google's ``google-genai`` SDK (async client)."""

    def __init__(self, config: ModelConfig) -> None:
        try:
            from google import genai
            from google.genai import types as genai_types
        except ImportError as exc:
            raise ImportError(
                "Gemini provider requires the gemini extra: "
                "pip install 'agentinc-sdk[gemini]'"
            ) from exc
        self._genai = genai
        self._types = genai_types
        self._client = genai.Client(api_key=config["api_key"])
        self._model = config["model"]

    async def complete(
        self,
        messages: list[dict],
        tools: list[ToolSchema],
        stream: bool = True,
    ) -> AsyncIterator[AgentOutput]:
        """Generate a completion for OpenAI-style *messages*.

        Yields text chunks and then either a chunk carrying ``tool_calls`` or
        a ``done=True`` chunk.
        """
        system, contents = _to_gemini_messages(messages)

        config_kwargs: dict = {}
        if system:
            config_kwargs["system_instruction"] = system
        if tools:
            declarations = [
                self._types.FunctionDeclaration(
                    name=t.name,
                    description=t.description,
                    parameters=t.parameters,
                )
                for t in tools
            ]
            config_kwargs["tools"] = [self._types.Tool(function_declarations=declarations)]

        gen_config = self._types.GenerateContentConfig(**config_kwargs) if config_kwargs else None

        chunks = self._stream(contents, gen_config) if stream else self._blocking(contents, gen_config)
        async for chunk in chunks:
            yield chunk

    def _request_kwargs(self, contents, gen_config) -> dict:
        """Build ``generate_content*`` keyword arguments; ``config`` only if set."""
        kwargs = {"model": self._model, "contents": contents}
        if gen_config:
            kwargs["config"] = gen_config
        return kwargs

    @staticmethod
    def _tool_calls_from(function_calls) -> list[ToolCall]:
        """Convert Gemini function calls to ToolCalls.

        Gemini may omit the call id, in which case the function name is used.
        """
        return [
            ToolCall(id=fc.id or fc.name, name=fc.name, arguments=dict(fc.args or {}))
            for fc in function_calls
        ]

    async def _stream(self, contents, gen_config) -> AsyncIterator[AgentOutput]:
        """Streaming path: text deltas, then tool calls (stopping) or a done marker."""
        response_stream = await self._client.aio.models.generate_content_stream(
            **self._request_kwargs(contents, gen_config)
        )
        async for chunk in response_stream:
            if chunk.function_calls:
                yield AgentOutput(tool_calls=self._tool_calls_from(chunk.function_calls), done=False)
                return
            if chunk.text:
                yield AgentOutput(content=chunk.text, done=False)

        yield AgentOutput(content="", done=True)

    async def _blocking(self, contents, gen_config) -> AsyncIterator[AgentOutput]:
        """Non-streaming path: one request, then tool calls or the full text."""
        response = await self._client.aio.models.generate_content(
            **self._request_kwargs(contents, gen_config)
        )

        if response.function_calls:
            yield AgentOutput(tool_calls=self._tool_calls_from(response.function_calls), done=False)
        else:
            yield AgentOutput(content=response.text or "", done=True)
|