devpilot-agentic-cli 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +1 -0
- agent/a2a_client.py +94 -0
- agent/a2a_server.py +148 -0
- agent/cli.py +233 -0
- agent/config.py +232 -0
- agent/context.py +182 -0
- agent/history.py +172 -0
- agent/loop.py +102 -0
- agent/mcp_client.py +104 -0
- agent/providers/__init__.py +4 -0
- agent/providers/anthropic_provider.py +169 -0
- agent/providers/base.py +148 -0
- agent/providers/factory.py +35 -0
- agent/providers/openai_provider.py +194 -0
- agent/providers/system_prompt.py +132 -0
- agent/setup_wizard.py +309 -0
- agent/tools/__init__.py +15 -0
- agent/tools/a2a.py +56 -0
- agent/tools/base.py +52 -0
- agent/tools/diagram.py +131 -0
- agent/tools/doc_gen.py +163 -0
- agent/tools/fs.py +411 -0
- agent/tools/git_ops.py +145 -0
- agent/tools/registry.py +219 -0
- agent/tools/search_code.py +120 -0
- agent/tools/shell.py +118 -0
- agent/tools/web_search.py +105 -0
- agent/tui/__init__.py +3 -0
- agent/tui/app.py +557 -0
- agent/ui.py +263 -0
- devpilot_agentic_cli-1.0.0.dist-info/METADATA +288 -0
- devpilot_agentic_cli-1.0.0.dist-info/RECORD +35 -0
- devpilot_agentic_cli-1.0.0.dist-info/WHEEL +5 -0
- devpilot_agentic_cli-1.0.0.dist-info/entry_points.txt +2 -0
- devpilot_agentic_cli-1.0.0.dist-info/top_level.txt +1 -0
agent/mcp_client.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""
|
|
2
|
+
agent/mcp_client.py
|
|
3
|
+
───────────────────
|
|
4
|
+
MCP Client Integration (Sprint 3).
|
|
5
|
+
Connects to servers defined in mcp_servers.json, discovers tools,
|
|
6
|
+
and registers them into the ToolRegistry.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
from contextlib import AsyncExitStack
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
from mcp.client.session import ClientSession
|
|
14
|
+
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
15
|
+
from mcp.types import TextContent
|
|
16
|
+
|
|
17
|
+
from agent.tools import ToolRegistry, ToolResult
|
|
18
|
+
from agent.ui import UI
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class MCPManager:
    """Manage connections to multiple MCP servers.

    Every server listed in the JSON config file is launched over stdio, its
    tools are listed, and each tool is registered into a ``ToolRegistry``
    with an executor that forwards calls to that server's session.
    """

    def __init__(self, config_path: Path):
        self.config_path = config_path
        # Owns the transports/sessions of every successfully connected server.
        self.exit_stack = AsyncExitStack()
        self.sessions: dict[str, ClientSession] = {}

    def _load_server_items(self) -> list:
        """Read the config file and return a list of ``(name, server_config)`` pairs.

        Accepts the official ``{"mcpServers": {name: config}}`` layout (or the
        ``"servers"`` key), as well as a list of config dicts that may carry a
        ``"name"``. Returns an empty list when the file is missing, unreadable
        or malformed; problems are reported through ``UI.print_error``.
        """
        if not self.config_path.exists():
            return []

        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (ValueError, OSError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            UI.print_error(f"Failed to read mcp_servers.json: {e}")
            return []

        if not isinstance(data, dict):
            UI.print_error("Failed to read mcp_servers.json: top-level value must be a JSON object")
            return []

        servers = data.get("mcpServers", data.get("servers", {}))
        if isinstance(servers, dict):
            # Official MCP config format: a dict mapping name -> config.
            return list(servers.items())
        if isinstance(servers, list):
            # Fallback: a list of config dicts with an optional "name" key.
            return [
                (s.get("name", f"server_{i}") if isinstance(s, dict) else f"server_{i}", s)
                for i, s in enumerate(servers)
            ]
        UI.print_error("Failed to read mcp_servers.json: 'mcpServers' must be an object or a list")
        return []

    @staticmethod
    def _make_executor(session_ref: ClientSession, tool_name: str):
        """Return an async executor that forwards a tool call to ``session_ref``.

        Only ``TextContent`` items of the MCP result are forwarded; other
        content types are dropped. Any exception becomes an error ToolResult.
        """

        async def _executor(tool_input: dict) -> ToolResult:
            try:
                result = await session_ref.call_tool(tool_name, tool_input)
                text_contents = [c.text for c in result.content if isinstance(c, TextContent)]
                output = "\n".join(text_contents)
                return ToolResult(output, is_error=result.isError)
            except Exception as e:
                return ToolResult(f"MCP execution error: {e}", is_error=True)

        return _executor

    async def _connect_server(
        self,
        name: str,
        command: str,
        args: list,
        env: "dict[str, str] | None",
        registry: ToolRegistry,
    ) -> None:
        """Start one server, register its tools, and keep it open on success.

        The server's contexts are entered on a dedicated stack. A failure
        part-way through (e.g. ``initialize`` or ``list_tools`` raising) then
        tears down that server's subprocess immediately, instead of leaving it
        on the manager-wide stack until ``close()``.
        """
        server_stack = AsyncExitStack()
        try:
            server_params = StdioServerParameters(command=command, args=args, env=env)
            read, write = await server_stack.enter_async_context(stdio_client(server_params))
            session = await server_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()

            # Fetch and register tools in the canonical schema format.
            tools_response = await session.list_tools()
            for mcp_tool in tools_response.tools:
                canonical_schema = {
                    "name": mcp_tool.name,
                    "description": mcp_tool.description or "",
                    "input_schema": mcp_tool.inputSchema,
                    "_mcp_server_id": name,
                }
                registry.register_mcp_tool(canonical_schema, self._make_executor(session, mcp_tool.name))
        except Exception as e:
            UI.print_error(f"Failed to connect to MCP server '{name}': {e}")
            registry.deregister_mcp_tools(name)
            try:
                await server_stack.aclose()
            except Exception:
                # Best-effort teardown; the connection error was already reported.
                pass
            return

        # Transfer ownership of this server's contexts to the manager-wide stack.
        await self.exit_stack.enter_async_context(server_stack)
        self.sessions[name] = session
        UI.print_info(f"Connected to MCP server: {name} ({len(tools_response.tools)} tools)")

    async def connect_all(self, registry: ToolRegistry) -> None:
        """Connect to all servers in mcp_servers.json and register their tools.

        Servers with ``"enabled": false``, a non-object config, or no
        ``"command"`` are skipped. A server that fails to connect is reported
        and has its tools deregistered; the remaining servers still connect.
        """
        for name, server_config in self._load_server_items():
            if not isinstance(server_config, dict):
                UI.print_error(f"MCP server '{name}' config must be an object. Skipping.")
                continue
            if server_config.get("enabled", True) is False:
                continue

            command = server_config.get("command")
            args = server_config.get("args", [])
            if not command:
                UI.print_error(f"MCP server '{name}' missing 'command'. Skipping.")
                continue

            await self._connect_server(name, command, args, server_config.get("env"), registry)

    async def close(self) -> None:
        """Close every server connection opened by ``connect_all``."""
        try:
            await self.exit_stack.aclose()
        finally:
            self.sessions.clear()
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""
|
|
2
|
+
agent/providers/anthropic_provider.py
|
|
3
|
+
───────────────────────────────────────
|
|
4
|
+
Anthropic (Claude) model provider.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import anthropic
|
|
13
|
+
|
|
14
|
+
from agent.config import Config
|
|
15
|
+
from agent.providers.base import BaseProvider, ProviderResponse, ToolUseBlock
|
|
16
|
+
from agent.providers.system_prompt import build_system_prompt
|
|
17
|
+
|
|
18
|
+
_MAX_TOKENS = 16000
|
|
19
|
+
_MAX_TOKENS_NO_THINKING = 8096
|
|
20
|
+
_MAX_RETRIES = 3
|
|
21
|
+
_RETRY_BASE_DELAY = 1.0
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class AnthropicProvider(BaseProvider):
    """Model provider backed by the Anthropic Messages API."""

    def __init__(self, config: Config) -> None:
        config.validate_api_key()
        self._config = config
        self._client = anthropic.AsyncAnthropic(api_key=config.active_api_key)

    @staticmethod
    def _is_retryable(exc: Exception) -> bool:
        """Return True for transient failures that are worth retrying.

        Connection/timeout errors, request timeouts (408), rate limits (429)
        and server-side errors (5xx) are transient. Other 4xx errors (bad
        request, auth, not found) would fail identically on every retry.
        """
        if isinstance(exc, anthropic.APIConnectionError):
            return True
        if isinstance(exc, anthropic.APIStatusError):
            return exc.status_code in (408, 429) or exc.status_code >= 500
        return False

    async def _create_with_retry(self, **kwargs: Any) -> anthropic.types.Message:
        """Call ``messages.create`` with exponential back-off on transient errors."""
        for attempt in range(_MAX_RETRIES + 1):
            try:
                return await self._client.messages.create(**kwargs)  # type: ignore[call-overload]
            except (anthropic.APIStatusError, anthropic.APIConnectionError) as exc:
                if attempt == _MAX_RETRIES or not self._is_retryable(exc):
                    raise
                delay = _RETRY_BASE_DELAY * (2 ** attempt)
                print(
                    f"  [DevPilot] API error ({exc.__class__.__name__}), "
                    f"retrying in {delay:.0f}s … (attempt {attempt + 1}/{_MAX_RETRIES})"
                )
                await asyncio.sleep(delay)
        raise RuntimeError("Retry loop exited unexpectedly")

    @staticmethod
    def _parse_content(
        content: list[Any],
    ) -> tuple[str | None, list[ToolUseBlock], list[dict], str | None]:
        """Convert SDK content blocks into ``(text, tool_uses, raw_content, thinking)``.

        ``raw_content`` is the canonical assistant history entry. Thinking
        blocks keep their ``signature``, and redacted-thinking blocks keep their
        ``data``, because the API requires them verbatim when they are sent
        back on the next turn. Multiple text/thinking blocks (e.g. with
        interleaved thinking) are joined with newlines instead of keeping only
        the last one.
        """
        text_parts: list[str] = []
        thinking_parts: list[str] = []
        tool_uses: list[ToolUseBlock] = []
        raw_content: list[dict] = []

        for block in content:
            if block.type == "thinking":
                thinking_parts.append(block.thinking)
                entry: dict[str, Any] = {"type": "thinking", "thinking": block.thinking}
                signature = getattr(block, "signature", None)
                if signature is not None:
                    entry["signature"] = signature
                raw_content.append(entry)
            elif block.type == "redacted_thinking":
                raw_content.append({"type": "redacted_thinking", "data": block.data})
            elif block.type == "text":
                text_parts.append(block.text)
                raw_content.append({"type": "text", "text": block.text})
            elif block.type == "tool_use":
                tool_uses.append(ToolUseBlock(id=block.id, name=block.name, input=block.input))
                raw_content.append(
                    {"type": "tool_use", "id": block.id, "name": block.name, "input": block.input}
                )

        text = "\n".join(text_parts) if text_parts else None
        thinking = "\n".join(thinking_parts) if thinking_parts else None
        return text, tool_uses, raw_content, thinking

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict],
        system: str | None = None,
    ) -> ProviderResponse:
        """Send the conversation in one non-streaming request and normalise the reply."""
        kwargs: dict[str, Any] = dict(
            model=self._config.model,
            system=system or build_system_prompt(),
            messages=messages,
            tools=tools,
        )

        if self._config.extended_thinking:
            # The non-beta messages.create() does not accept a ``betas`` argument
            # (only client.beta.messages.create does), so opt into interleaved
            # thinking via the beta header instead.
            kwargs["extra_headers"] = {"anthropic-beta": "interleaved-thinking-2025-05-14"}
            kwargs["thinking"] = {
                "type": "enabled",
                "budget_tokens": self._config.thinking_budget,
            }
            # max_tokens must exceed the thinking budget to leave room for output.
            kwargs["max_tokens"] = max(_MAX_TOKENS, self._config.thinking_budget + 4096)
        else:
            kwargs["max_tokens"] = _MAX_TOKENS_NO_THINKING

        response = await self._create_with_retry(**kwargs)
        text, tool_uses, raw_content, thinking_text = self._parse_content(response.content)

        stop_reason = "tool_use" if tool_uses else (response.stop_reason or "end_turn")
        return ProviderResponse(
            text=text,
            tool_uses=tool_uses,
            stop_reason=stop_reason,
            assistant_message={"role": "assistant", "content": raw_content},
            thinking=thinking_text,
        )

    async def chat_stream(
        self,
        messages: list[dict],
        tools: list[dict],
        system: str | None = None,
    ) -> ProviderResponse:
        """Stream text tokens to the UI as they arrive, then return the final response.

        Requests with extended thinking are not streamed; they go through
        chat(). If the stream raises an API status error, this falls back to
        chat(), which applies the retry policy.
        """
        if self._config.extended_thinking:
            return await self.chat(messages, tools, system=system)

        from agent.ui import UI, console

        kwargs: dict[str, Any] = dict(
            model=self._config.model,
            max_tokens=_MAX_TOKENS_NO_THINKING,
            system=system or build_system_prompt(),
            messages=messages,
            tools=tools,
        )

        try:
            async with self._client.messages.stream(**kwargs) as stream:
                # Console spacing is only printed when no TUI app is attached.
                if not getattr(UI, "_tui_app", None):
                    console.print()

                async for text_delta in stream.text_stream:
                    UI.print_stream_token(text_delta)

                if not getattr(UI, "_tui_app", None):
                    console.print()
                    console.print()

                final = await stream.get_final_message()
        except (anthropic.RateLimitError, anthropic.APIStatusError):
            # NOTE: if the error happens mid-stream, tokens already printed will
            # be followed by the full non-streamed reply.
            return await self.chat(messages, tools, system=system)

        text, tool_uses, raw_content, thinking_text = self._parse_content(final.content)

        stop_reason = "tool_use" if tool_uses else (final.stop_reason or "end_turn")
        return ProviderResponse(
            text=text,
            tool_uses=tool_uses,
            stop_reason=stop_reason,
            assistant_message={"role": "assistant", "content": raw_content},
            thinking=thinking_text,
            streamed_text=True,
        )

    def make_tool_result_message(self, tool_use_id: str, content: str, is_error: bool = False) -> dict:
        """Wrap a tool's output in a canonical user message carrying a tool_result block."""
        return {
            "role": "user",
            "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": content, "is_error": is_error}],
        }

    def make_user_message(self, text: str) -> dict:
        """Wrap plain text in a canonical user message."""
        return {"role": "user", "content": text}
|
agent/providers/base.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""
|
|
2
|
+
agent/providers/base.py
|
|
3
|
+
────────────────────────
|
|
4
|
+
Abstract interface and shared data types for all model providers.
|
|
5
|
+
|
|
6
|
+
The agentic loop works exclusively with these types — it has zero knowledge
|
|
7
|
+
of whether Anthropic, OpenAI, or any other provider is underneath.
|
|
8
|
+
|
|
9
|
+
─── Canonical message format (stored in conversation history) ───────────────
|
|
10
|
+
|
|
11
|
+
User text message:
|
|
12
|
+
{"role": "user", "content": "some text"}
|
|
13
|
+
|
|
14
|
+
User message carrying tool results:
|
|
15
|
+
{"role": "user", "content": [
|
|
16
|
+
{
|
|
17
|
+
"type": "tool_result",
|
|
18
|
+
"tool_use_id": "<id that matches the tool_use block>",
|
|
19
|
+
"content": "<tool output string>",
|
|
20
|
+
"is_error": False
|
|
21
|
+
}
|
|
22
|
+
]}
|
|
23
|
+
|
|
24
|
+
Assistant message (text and/or tool calls):
|
|
25
|
+
{"role": "assistant", "content": [
|
|
26
|
+
{"type": "text", "text": "..."},
|
|
27
|
+
{"type": "tool_use", "id": "...", "name": "...", "input": {...}}
|
|
28
|
+
]}
|
|
29
|
+
|
|
30
|
+
─── Canonical tool schema format (fed to the provider) ──────────────────────
|
|
31
|
+
|
|
32
|
+
{
|
|
33
|
+
"name": "tool_name",
|
|
34
|
+
"description": "What the tool does.",
|
|
35
|
+
"input_schema": { ← Anthropic-style JSON Schema
|
|
36
|
+
"type": "object",
|
|
37
|
+
"properties": { ... },
|
|
38
|
+
"required": [ ... ]
|
|
39
|
+
}
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
Providers that use a different schema format (e.g., OpenAI uses
|
|
43
|
+
"function.parameters") are responsible for converting internally.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
from __future__ import annotations
|
|
47
|
+
|
|
48
|
+
from abc import ABC, abstractmethod
|
|
49
|
+
from dataclasses import dataclass, field
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# ── Shared data types ─────────────────────────────────────────────────────────
|
|
53
|
+
|
|
54
|
+
@dataclass
class ToolUseBlock:
    """One tool call that the model asked for.

    Attributes:
        id:    Unique identifier; the matching tool_result echoes it back.
        name:  Tool name; must match a tool name in the registry.
        input: Arguments the model wants to pass to the tool.
    """

    id: str
    name: str
    input: dict
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass
class ProviderResponse:
    """Provider-agnostic result of one chat()/chat_stream() call.

    The agentic loop consumes only this object, never raw SDK types.
    """

    # Prose text produced by the model, or None when there was none.
    text: str | None
    # Tool invocations requested by the model (may be empty).
    tool_uses: list[ToolUseBlock]
    # Stop reason, e.g. "end_turn", "tool_use" or "max_tokens".
    stop_reason: str
    # Canonical assistant history entry, ready to append to the conversation.
    assistant_message: dict
    # Extended-thinking text, if the provider returned any (Anthropic only).
    thinking: str | None = None
    # True when the text was already printed live while streaming.
    streamed_text: bool = False

    @property
    def has_tool_uses(self) -> bool:
        """Whether the model asked for at least one tool call."""
        return bool(self.tool_uses)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# ── Abstract provider interface ───────────────────────────────────────────────
|
|
82
|
+
|
|
83
|
+
class BaseProvider(ABC):
    """
    Interface that every model provider implements.

    Modules outside the provider package should rely only on this interface
    and on the shared data types, never on SDK-specific types.
    """

    @abstractmethod
    async def chat(
        self,
        messages: list[dict],
        tools: list[dict],
        system: str | None = None,
    ) -> ProviderResponse:
        """
        Send the conversation to the model and normalise its reply.

        Args:
            messages : Whole conversation history, canonical format.
            tools    : Tool schemas, canonical (Anthropic-style) format.
            system   : System prompt override; None means the provider default.

        Returns:
            ProviderResponse holding text, tool_uses, stop_reason and the
            assistant_message to append to history.
        """
        ...

    async def chat_stream(
        self,
        messages: list[dict],
        tools: list[dict],
        system: str | None = None,
    ) -> ProviderResponse:
        """
        Stream the reply to the console as it is generated, then return the
        same kind of ProviderResponse as chat().

        This base version does not stream at all: it simply delegates to
        chat(). Providers with streaming support override it.
        """
        response = await self.chat(messages, tools, system=system)
        return response

    @abstractmethod
    def make_tool_result_message(
        self,
        tool_use_id: str,
        content: str,
        is_error: bool = False,
    ) -> dict:
        """
        Build the canonical history entry that returns a tool's output to
        the model; it is appended right after the tool has run.

        Args:
            tool_use_id : ID of the ToolUseBlock this result answers.
            content     : Tool output, already converted to a string.
            is_error    : True when the tool failed.
        """
        ...

    @abstractmethod
    def make_user_message(self, text: str) -> dict:
        """
        Wrap plain text in a canonical user message dict, e.g. to inject
        the initial task into the conversation.
        """
        ...
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""
|
|
2
|
+
agent/providers/factory.py
|
|
3
|
+
───────────────────────────
|
|
4
|
+
Single factory function that reads config and instantiates
|
|
5
|
+
the correct provider. All other modules use this instead of
|
|
6
|
+
importing individual providers directly.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from agent.config import Config
|
|
12
|
+
from agent.providers.base import BaseProvider
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def create_provider(config: Config) -> BaseProvider:
    """
    Build the BaseProvider implementation selected by ``config.provider``.

    Provider modules are imported lazily, so only the SDK actually in use
    gets loaded.

    Raises:
        ValueError  : ``config.provider`` is not a supported name.
        ConfigError : The provider's API key is missing (raised by its __init__).
    """
    name = config.provider

    if name == "openai":
        from agent.providers.openai_provider import OpenAIProvider
        return OpenAIProvider(config)

    if name == "anthropic":
        from agent.providers.anthropic_provider import AnthropicProvider
        return AnthropicProvider(config)

    message = (
        f"Unknown provider '{name}'. "
        "Valid options are: 'anthropic', 'openai'."
    )
    raise ValueError(message)
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""
|
|
2
|
+
agent/providers/openai_provider.py
|
|
3
|
+
────────────────────────────────────
|
|
4
|
+
OpenAI-compatible model provider.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import json
|
|
11
|
+
from typing import Any, cast
|
|
12
|
+
|
|
13
|
+
import openai
|
|
14
|
+
from openai import AsyncOpenAI
|
|
15
|
+
from openai.types.chat import ChatCompletion
|
|
16
|
+
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
|
17
|
+
|
|
18
|
+
from agent.config import Config
|
|
19
|
+
from agent.providers.base import BaseProvider, ProviderResponse, ToolUseBlock
|
|
20
|
+
from agent.providers.system_prompt import build_system_prompt
|
|
21
|
+
|
|
22
|
+
_MAX_TOKENS = 4096
|
|
23
|
+
_MAX_RETRIES = 3
|
|
24
|
+
_RETRY_BASE_DELAY = 1.0
|
|
25
|
+
|
|
26
|
+
_STOP_REASON_MAP: dict[str, str] = {
|
|
27
|
+
"stop": "end_turn",
|
|
28
|
+
"tool_calls": "tool_use",
|
|
29
|
+
"length": "max_tokens",
|
|
30
|
+
"content_filter": "end_turn",
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class OpenAIProvider(BaseProvider):
    """Model provider for the OpenAI Chat Completions API and compatible servers."""

    def __init__(self, config: Config) -> None:
        config.validate_api_key()
        self._config = config
        self._client = AsyncOpenAI(api_key=config.active_api_key, base_url=config.base_url)

    @staticmethod
    def _to_openai_messages(messages: list[dict]) -> list[dict]:
        """Convert canonical history into Chat Completions messages.

        String content passes through unchanged. Assistant block lists become
        one message: text blocks are joined with spaces and tool_use blocks
        become ``tool_calls``; other block types (e.g. thinking) are dropped.
        In user block lists, each tool_result becomes a ``role="tool"`` message
        and any other block becomes its own user message.
        """
        result: list[dict] = []
        for msg in messages:
            role: str = msg["role"]
            content: Any = msg["content"]

            if isinstance(content, str):
                result.append({"role": role, "content": content})
                continue

            if role == "assistant":
                text_parts: list[str] = []
                tool_calls: list[dict] = []
                for block in content:
                    if block["type"] == "text":
                        text_parts.append(block["text"])
                    elif block["type"] == "tool_use":
                        tool_calls.append({
                            "id": block["id"],
                            "type": "function",
                            "function": {"name": block["name"], "arguments": json.dumps(block["input"])},
                        })
                oai_msg: dict[str, Any] = {
                    "role": "assistant",
                    "content": " ".join(text_parts) if text_parts else None,
                }
                if tool_calls:
                    oai_msg["tool_calls"] = tool_calls
                result.append(oai_msg)

            elif role == "user":
                for block in content:
                    if block["type"] == "tool_result":
                        result.append({
                            "role": "tool",
                            "tool_call_id": block["tool_use_id"],
                            "content": block["content"],
                        })
                    else:
                        result.append({"role": "user", "content": block.get("text", "")})

        return result

    @staticmethod
    def _to_openai_tools(tools: list[dict]) -> list[dict]:
        """Convert canonical tool schemas into Chat Completions function tools."""
        return [
            {
                "type": "function",
                "function": {
                    "name": t["name"],
                    "description": t.get("description", ""),
                    "parameters": t.get("input_schema", {}),
                },
            }
            for t in tools
        ]

    @staticmethod
    def _is_retryable(exc: Exception) -> bool:
        """Return True for transient failures that are worth retrying.

        Connection/timeout errors, request timeouts (408), rate limits (429)
        and server-side errors (5xx) are transient. Other 4xx errors would
        fail identically on every retry.
        """
        if isinstance(exc, openai.APIConnectionError):
            return True
        if isinstance(exc, openai.APIStatusError):
            return exc.status_code in (408, 429) or exc.status_code >= 500
        return False

    async def _create_with_retry(self, **kwargs: Any) -> ChatCompletion:
        """Call ``chat.completions.create`` with exponential back-off on transient errors."""
        for attempt in range(_MAX_RETRIES + 1):
            try:
                return cast(ChatCompletion, await self._client.chat.completions.create(**kwargs))
            except (openai.APIStatusError, openai.APIConnectionError) as exc:
                if attempt == _MAX_RETRIES or not self._is_retryable(exc):
                    raise
                delay = _RETRY_BASE_DELAY * (2 ** attempt)
                print(
                    f"  [DevPilot] API error ({exc.__class__.__name__}), "
                    f"retrying in {delay:.0f}s … (attempt {attempt + 1}/{_MAX_RETRIES})"
                )
                await asyncio.sleep(delay)
        raise RuntimeError("Unreachable")

    @staticmethod
    def _parse_arguments(raw: Any) -> dict:
        """Coerce model-supplied tool arguments into a dict.

        A dict is returned as-is and a JSON string is decoded. Anything else
        (malformed JSON, a JSON array or scalar, None, empty string) yields
        ``{}``, so ``ToolUseBlock.input`` is always a dict.
        """
        if isinstance(raw, dict):
            return raw
        if isinstance(raw, str):
            try:
                decoded = json.loads(raw or "{}")
            except (ValueError, RecursionError):
                return {}
            return decoded if isinstance(decoded, dict) else {}
        return {}

    @classmethod
    def _extract_inline_tool_call(cls, text: str) -> tuple[str, dict] | None:
        """Detect a JSON tool call that a local model wrote into its prose.

        Some OpenAI-compatible local models (e.g. via Ollama) emit
        ``{"name": ..., "arguments": ...}`` as plain text instead of a native
        tool call. The first fenced ```json block is preferred; otherwise the
        span from the first ``{`` to the last ``}`` is tried. Returns
        ``(name, arguments)``, or None when no such object is found.
        """
        import re

        match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
        if match:
            candidate = match.group(1)
        else:
            start, end = text.find("{"), text.rfind("}")
            candidate = text[start:end + 1] if start != -1 and end > start else ""
        if not candidate:
            return None

        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):
            return None

        if not (isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed):
            return None
        name = parsed["name"]
        if not isinstance(name, str) or not name:
            return None
        # "arguments" may itself be a JSON-encoded string, as in native tool calls.
        return name, cls._parse_arguments(parsed["arguments"])

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict],
        system: str | None = None,
    ) -> ProviderResponse:
        """Send the conversation in one request and normalise the reply."""
        oai_messages = self._to_openai_messages(messages)
        oai_messages.insert(0, {"role": "system", "content": system or build_system_prompt()})

        call_kwargs: dict[str, Any] = {
            "model": self._config.model,
            "max_completion_tokens": _MAX_TOKENS,
            "messages": oai_messages,
        }
        oai_tools = self._to_openai_tools(tools)
        if oai_tools:
            call_kwargs["tools"] = oai_tools

        response = await self._create_with_retry(**call_kwargs)
        choice = response.choices[0]
        oai_msg = choice.message

        text: str | None = oai_msg.content or None
        tool_uses: list[ToolUseBlock] = []
        raw_content: list[dict] = []

        if text:
            raw_content.append({"type": "text", "text": text})

        tool_calls = oai_msg.tool_calls
        if tool_calls:
            for tc in tool_calls:
                tc_typed = cast(ChatCompletionMessageToolCall, tc)
                fn_name: str = tc_typed.function.name
                tool_input = self._parse_arguments(tc_typed.function.arguments)
                tool_uses.append(ToolUseBlock(id=tc_typed.id, name=fn_name, input=tool_input))
                raw_content.append(
                    {"type": "tool_use", "id": tc_typed.id, "name": fn_name, "input": tool_input}
                )
        elif text:
            # Fallback: detect JSON tool calls written inline by local models (e.g. Ollama).
            inline_call = self._extract_inline_tool_call(text)
            if inline_call is not None:
                import uuid

                fn_name, tool_input = inline_call
                fake_id = f"call_{uuid.uuid4().hex[:8]}"
                tool_uses.append(ToolUseBlock(id=fake_id, name=fn_name, input=tool_input))
                raw_content.append(
                    {"type": "tool_use", "id": fake_id, "name": fn_name, "input": tool_input}
                )

        # Like the Anthropic provider, report "tool_use" whenever tools were
        # requested, including calls recovered from inline JSON, where the
        # server's finish_reason is still "stop".
        if tool_uses:
            stop_reason = "tool_use"
        else:
            stop_reason = _STOP_REASON_MAP.get(choice.finish_reason or "stop", "end_turn")

        return ProviderResponse(
            text=text,
            tool_uses=tool_uses,
            stop_reason=stop_reason,
            assistant_message={"role": "assistant", "content": raw_content},
        )

    def make_tool_result_message(self, tool_use_id: str, content: str, is_error: bool = False) -> dict:
        """Wrap a tool's output in a canonical user message carrying a tool_result block."""
        return {
            "role": "user",
            "content": [{"type": "tool_result", "tool_use_id": tool_use_id, "content": content, "is_error": is_error}],
        }

    def make_user_message(self, text: str) -> dict:
        """Wrap plain text in a canonical user message."""
        return {"role": "user", "content": text}
|