yycode 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +33 -0
- agent/acp/__init__.py +2 -0
- agent/acp/approval_adapter.py +134 -0
- agent/acp/content_adapter.py +45 -0
- agent/acp/jsonrpc.py +92 -0
- agent/acp/server.py +197 -0
- agent/acp/session_manager.py +193 -0
- agent/acp/update_adapter.py +192 -0
- agent/app_paths.py +25 -0
- agent/approval.py +169 -0
- agent/cancellation.py +52 -0
- agent/change_snapshot.py +186 -0
- agent/context_compressor.py +116 -0
- agent/graph.py +137 -0
- agent/llm_retry.py +434 -0
- agent/logger.py +97 -0
- agent/lsp/__init__.py +13 -0
- agent/lsp/client.py +151 -0
- agent/lsp/manager.py +234 -0
- agent/lsp/types.py +119 -0
- agent/message_context_manager.py +322 -0
- agent/message_format.py +105 -0
- agent/nodes/llm_node.py +58 -0
- agent/nodes/state.py +12 -0
- agent/nodes/task_guard_node.py +50 -0
- agent/nodes/tools_node.py +70 -0
- agent/plan_snapshot.py +70 -0
- agent/providers/__init__.py +13 -0
- agent/providers/anthropic_provider.py +268 -0
- agent/providers/base.py +52 -0
- agent/providers/openai_provider.py +279 -0
- agent/providers/text_tool_calls.py +118 -0
- agent/runtime/approval_service.py +184 -0
- agent/runtime/context.py +43 -0
- agent/runtime/tool_events.py +368 -0
- agent/runtime/tool_executor.py +208 -0
- agent/runtime/tool_output.py +261 -0
- agent/runtime/tool_registry.py +91 -0
- agent/runtime/tool_scheduler.py +35 -0
- agent/runtime/workflow_guard.py +217 -0
- agent/runtime/workspace.py +5 -0
- agent/runtime/workspace_tools.py +22 -0
- agent/session.py +787 -0
- agent/session_replay.py +95 -0
- agent/session_store.py +186 -0
- agent/skills.py +254 -0
- agent/streaming.py +248 -0
- agent/subagent.py +634 -0
- agent/task_memory.py +340 -0
- agent/todo_manager.py +304 -0
- agent/tool_retry.py +106 -0
- agent/tui/__init__.py +14 -0
- agent/tui/app.py +1325 -0
- agent/tui/approval.py +53 -0
- agent/tui/commands/__init__.py +6 -0
- agent/tui/commands/base.py +48 -0
- agent/tui/commands/clear.py +37 -0
- agent/tui/commands/help.py +27 -0
- agent/tui/commands/registry.py +94 -0
- agent/tui/help_content.py +108 -0
- agent/tui/renderers.py +1961 -0
- agent/tui/runner.py +439 -0
- agent/tui/state.py +653 -0
- main.py +465 -0
- tools/__init__.py +50 -0
- tools/apply_patch.py +305 -0
- tools/bash.py +76 -0
- tools/diff_utils.py +139 -0
- tools/edit_file.py +40 -0
- tools/git_diff.py +72 -0
- tools/git_show.py +65 -0
- tools/grep.py +149 -0
- tools/list_files.py +90 -0
- tools/list_skills.py +24 -0
- tools/load_skill.py +30 -0
- tools/lsp_definition.py +27 -0
- tools/lsp_diagnostics.py +32 -0
- tools/lsp_document_symbols.py +23 -0
- tools/lsp_hover.py +29 -0
- tools/lsp_references.py +37 -0
- tools/lsp_utils.py +38 -0
- tools/lsp_workspace_symbols.py +23 -0
- tools/read_file.py +61 -0
- tools/read_many_files.py +50 -0
- tools/safety.py +50 -0
- tools/subagent.py +57 -0
- tools/todo.py +89 -0
- tools/verify.py +107 -0
- tools/web_search.py +250 -0
- tools/workspace.py +36 -0
- tools/workspace_state.py +60 -0
- tools/write_file.py +88 -0
- utils/__init__.py +5 -0
- utils/retry.py +13 -0
- yycode-0.3.2.data/data/skills/code_review.md +61 -0
- yycode-0.3.2.data/data/skills/code_workflow.md +404 -0
- yycode-0.3.2.data/data/skills/drawio/SKILL.md +636 -0
- yycode-0.3.2.data/data/skills/drawio/agents/openai.yaml +19 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-erd.drawio +84 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-layered-cn.drawio +91 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-layered-cn.png +0 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-layered.drawio +112 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-layered.png +0 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-ml.drawio +90 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-ring-cn.drawio +68 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-ring-cn.png +0 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-ring.drawio +86 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-ring.png +0 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-sequence.drawio +116 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-star-cn.drawio +66 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-star-cn.png +0 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-star.drawio +79 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-star.png +0 -0
- yycode-0.3.2.data/data/skills/drawio/assets/demo-uml-class.drawio +64 -0
- yycode-0.3.2.data/data/skills/drawio/assets/microservices-example.drawio +173 -0
- yycode-0.3.2.data/data/skills/drawio/assets/microservices-example.png +0 -0
- yycode-0.3.2.data/data/skills/drawio/assets/workflow-cn.drawio +120 -0
- yycode-0.3.2.data/data/skills/drawio/assets/workflow-cn.png +0 -0
- yycode-0.3.2.data/data/skills/drawio/assets/workflow.drawio +120 -0
- yycode-0.3.2.data/data/skills/drawio/assets/workflow.png +0 -0
- yycode-0.3.2.data/data/skills/drawio/docs/index.html +469 -0
- yycode-0.3.2.data/data/skills/drawio/docs/zh.html +456 -0
- yycode-0.3.2.data/data/skills/drawio/references/style-extraction.md +254 -0
- yycode-0.3.2.data/data/skills/drawio/styles/schema.json +112 -0
- yycode-0.3.2.data/data/skills/plan.md +115 -0
- yycode-0.3.2.data/data/skills/ppt/SKILL.md +254 -0
- yycode-0.3.2.dist-info/METADATA +12 -0
- yycode-0.3.2.dist-info/RECORD +131 -0
- yycode-0.3.2.dist-info/WHEEL +5 -0
- yycode-0.3.2.dist-info/entry_points.txt +2 -0
- yycode-0.3.2.dist-info/top_level.txt +4 -0
agent/message_format.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""Helpers for converting LangChain messages into provider-neutral payloads."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def messages_to_provider_format(messages: list[BaseMessage]) -> list[dict]:
    """Translate a LangChain message history into provider-neutral dicts.

    * ``HumanMessage`` -> ``{"role": "user", "content": <content>}``.
    * ``AIMessage`` -> ``{"role": "assistant", "content": ...}`` built by
      ``_assistant_content``, plus a ``reasoning_content`` key when
      ``_assistant_reasoning_content`` finds a non-empty value.
    * Each maximal run of consecutive ``ToolMessage`` objects is folded into a
      single ``user`` turn whose content is a list of ``tool_result`` blocks.
    * Any other message type is skipped.
    """
    converted: list[dict] = []
    position = 0
    total = len(messages)
    while position < total:
        current = messages[position]

        if isinstance(current, HumanMessage):
            converted.append({"role": "user", "content": current.content})
        elif isinstance(current, AIMessage):
            turn: dict[str, Any] = {
                "role": "assistant",
                "content": _assistant_content(current),
            }
            reasoning = _assistant_reasoning_content(current)
            if reasoning:
                turn["reasoning_content"] = reasoning
            converted.append(turn)
        elif isinstance(current, ToolMessage):
            # Gather every adjacent tool result into one user turn; the inner
            # loop advances ``position`` itself, so skip the shared increment.
            results: list[dict[str, Any]] = []
            while position < total and isinstance(messages[position], ToolMessage):
                result_msg = messages[position]
                results.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": result_msg.tool_call_id,
                        "content": result_msg.content,
                    }
                )
                position += 1
            converted.append({"role": "user", "content": results})
            continue

        position += 1
    return converted
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _assistant_content(message: AIMessage) -> Any:
    """Build the ``content`` payload for an assistant turn.

    Tool calls come from ``additional_kwargs["tool_calls_data"]``, falling back
    to ``message.tool_calls`` (or an empty list).

    With ``provider_blocks`` present, those blocks are reused. Dict blocks typed
    ``reasoning_content`` or ``tool_use`` are dropped, and fresh ``tool_use``
    blocks are appended from the tool calls. If nothing is left, the plain
    ``message.content`` is returned instead.

    Without provider blocks, the plain content is returned when there are no
    tool calls. Otherwise the result is a block list: a ``text`` block (only if
    the content is truthy) followed by the ``tool_use`` blocks.
    """
    extras = message.additional_kwargs
    raw_blocks = extras.get("provider_blocks")
    calls = extras.get("tool_calls_data") or message.tool_calls or []

    if not raw_blocks:
        if not calls:
            return message.content
        blocks: list[dict[str, Any]] = []
        if message.content:
            blocks.append({"type": "text", "text": str(message.content)})
        return blocks + _tool_use_blocks(calls)

    # Reasoning and tool_use blocks are rebuilt elsewhere, so strip them here.
    dropped_types = {"reasoning_content", "tool_use"}
    kept = [
        blk
        for blk in raw_blocks
        if not isinstance(blk, dict) or blk.get("type") not in dropped_types
    ]
    kept += _tool_use_blocks(calls)
    return kept or message.content
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _assistant_reasoning_content(message: AIMessage) -> str | None:
    """Return the assistant's reasoning text as a string, or ``None``.

    A truthy ``additional_kwargs["reasoning_content"]`` wins. Otherwise the
    first dict block in ``provider_blocks`` whose ``type`` is
    ``reasoning_content`` and that has a truthy ``reasoning_content`` (or,
    failing that, ``text``) field supplies the value.
    """
    extras = message.additional_kwargs
    explicit = extras.get("reasoning_content")
    if explicit:
        return str(explicit)

    candidates = (
        blk.get("reasoning_content") or blk.get("text")
        for blk in extras.get("provider_blocks") or []
        if isinstance(blk, dict) and blk.get("type") == "reasoning_content"
    )
    return next((str(val) for val in candidates if val), None)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _tool_use_blocks(tool_calls: list[Any]) -> list[dict[str, Any]]:
    """Render tool calls as ``tool_use`` content blocks.

    Each call may be a dict or an object with ``id``/``name``/``args``
    attributes; fields are read through ``_tool_call_field``. A falsy ``args``
    becomes an empty dict in the block's ``input``.
    """
    return [
        {
            "type": "tool_use",
            "id": _tool_call_field(call, "id"),
            "name": _tool_call_field(call, "name"),
            "input": _tool_call_field(call, "args") or {},
        }
        for call in tool_calls
    ]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _tool_call_field(tool_call: Any, field: str) -> Any:
|
|
103
|
+
if isinstance(tool_call, dict):
|
|
104
|
+
return tool_call.get(field)
|
|
105
|
+
return getattr(tool_call, field, None)
|
agent/nodes/llm_node.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""LLM graph node."""
|
|
2
|
+
|
|
3
|
+
from langchain_core.messages import AIMessage
|
|
4
|
+
|
|
5
|
+
from agent.llm_retry import chat_with_retry
|
|
6
|
+
from agent.message_format import messages_to_provider_format
|
|
7
|
+
from agent.nodes.state import AgentState
|
|
8
|
+
from agent.runtime.context import AgentRuntimeContext
|
|
9
|
+
from agent.streaming import StreamEvent, make_provider_stream_callback
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def create_llm_node(runtime: AgentRuntimeContext):
    """Return the async graph node that queries the LLM for the main agent.

    On each call, the returned node:

    * converts ``state["messages"]`` with ``messages_to_provider_format``;
    * sends them through ``chat_with_retry``, with events tagged
      ``source="main"`` and the runtime's session id;
    * emits a ``usage`` ``StreamEvent`` when the runtime has a stream callback
      and the response reports usage;
    * returns the reply as a single ``AIMessage``.

    The provider's tool calls, content blocks (when present), raw response and
    usage are kept in the message's ``additional_kwargs``.
    """
    # The provider-level streaming callback is bound once, at creation time.
    provider_stream_callback = make_provider_stream_callback(
        runtime.stream_callback,
        source="main",
        session_id=runtime.session_id,
    )

    async def llm_node(state: AgentState) -> AgentState:
        response = await chat_with_retry(
            runtime.provider,
            messages=messages_to_provider_format(state["messages"]),
            tools=runtime.tools,
            system_prompt=runtime.system_prompt,
            stream_callback=provider_stream_callback,
            event_callback=runtime.stream_callback,
            source="main",
            session_id=runtime.session_id,
        )

        event_sink = runtime.stream_callback
        if event_sink and response.usage:
            usage_event = StreamEvent(
                source="main",
                session_id=runtime.session_id,
                event_type="usage",
                usage=response.usage,
            )
            await event_sink(usage_event)

        # LangChain-shaped copies of the provider tool calls.
        langchain_calls = []
        for call in response.tool_calls:
            langchain_calls.append(
                {
                    "name": call.name,
                    "args": dict(call.args or {}),
                    "id": call.id,
                }
            )

        reply = AIMessage(content=response.content, tool_calls=langchain_calls)
        extras = reply.additional_kwargs
        # Routing and tool execution read the provider's own ToolCall objects.
        extras["tool_calls_data"] = response.tool_calls
        if response.content_blocks:
            extras["provider_blocks"] = response.content_blocks
        extras["raw_response"] = response.raw_response
        extras["usage"] = response.usage
        return {"messages": [reply]}

    return llm_node
|
agent/nodes/state.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Shared LangGraph state types."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated, TypedDict
|
|
4
|
+
|
|
5
|
+
from langchain_core.messages import BaseMessage
|
|
6
|
+
from langgraph.graph.message import add_messages
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AgentState(TypedDict):
    """State passed between the nodes of the agent graph.

    The ``messages`` channel uses LangGraph's ``add_messages`` reducer. Nodes
    therefore return only the messages they want to add (possibly none),
    not the full history.
    """

    # Conversation history; node outputs are merged in via add_messages.
    messages: Annotated[list[BaseMessage], add_messages]
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Task State guard graph node."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from langchain_core.messages import HumanMessage
|
|
6
|
+
from langgraph.graph import END
|
|
7
|
+
|
|
8
|
+
from agent.nodes.state import AgentState
|
|
9
|
+
from agent.todo_manager import TodoManager
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def create_task_guard_node(todo_manager: TodoManager):
    """Return a node that stops the agent from finishing with open tasks.

    When ``todo_manager`` reports incomplete Task State, the node adds one
    ephemeral ``HumanMessage`` holding the manager's finish-blocker text,
    tagged with ``ephemeral_kind="task_guard"``. Otherwise it adds nothing.
    """

    async def task_guard_node(state: AgentState) -> AgentState:
        if todo_manager.has_incomplete_task_state():
            reminder = HumanMessage(
                content=todo_manager.get_finish_blocker_message(),
                additional_kwargs={
                    "context_ephemeral": True,
                    "ephemeral_kind": "task_guard",
                },
            )
            return {"messages": [reminder]}
        return {"messages": []}

    return task_guard_node
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def route_after_llm(state: AgentState) -> Literal["tools", "task_guard"]:
    """Choose the node that follows the LLM step.

    Returns ``"tools"`` when the newest message has non-empty
    ``tool_calls_data`` in its ``additional_kwargs``; otherwise returns
    ``"task_guard"``.
    """
    newest = state["messages"][-1]
    if newest.additional_kwargs.get("tool_calls_data", []):
        return "tools"
    return "task_guard"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def route_after_tools(state: AgentState) -> Literal["llm", END]:
    """Choose the node that follows tool execution.

    The run ends only when the newest message has ``task_completed_final``
    set to exactly ``True`` in its ``additional_kwargs``. Otherwise control
    returns to the LLM.
    """
    flagged = state["messages"][-1].additional_kwargs.get("task_completed_final")
    return END if flagged is True else "llm"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def route_after_task_guard(state: AgentState, todo_manager: TodoManager) -> Literal["llm", END]:
    """Choose the node that follows the task guard.

    Returns ``"llm"`` while ``todo_manager`` still reports incomplete Task
    State, and ``END`` otherwise. ``state`` is not consulted.
    """
    if todo_manager.has_incomplete_task_state():
        return "llm"
    return END
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Tools graph node."""
|
|
2
|
+
|
|
3
|
+
from langchain_core.messages import AIMessage, HumanMessage
|
|
4
|
+
|
|
5
|
+
from agent.nodes.state import AgentState
|
|
6
|
+
from agent.runtime.approval_service import ApprovalService
|
|
7
|
+
from agent.runtime.context import AgentRuntimeContext
|
|
8
|
+
from agent.runtime.tool_executor import ToolExecutor
|
|
9
|
+
from agent.runtime.tool_registry import RuntimeToolRegistry
|
|
10
|
+
from agent.runtime.tool_scheduler import execute_tool_calls
|
|
11
|
+
from agent.runtime.workflow_guard import WorkflowGuard
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def create_tools_node(runtime: AgentRuntimeContext):
    """Return the node that executes the tool calls requested by the LLM.

    The registry, workflow guard, approval service and executor are built once
    here and shared by every invocation of the returned node.

    Each invocation:

    * runs the newest message's ``tool_calls_data`` via ``execute_tool_calls``;
    * records the batch with the todo manager;
    * appends workflow-guard messages and any repeated-incomplete reminder.

    When the batch contained a ``todo`` call, the todo manager allows finishing,
    and the requesting message had non-blank content, that content is re-added
    as an ``AIMessage`` flagged ``task_completed_final``.
    """
    registry = RuntimeToolRegistry(runtime)
    workflow_guard = WorkflowGuard(runtime, registry)
    approval_service = ApprovalService(
        runtime.approval_callback,
        runtime.workflow_state,
        runtime.stream_callback,
        runtime.session_id,
        source=runtime.source,
        role=runtime.role,
        parent_session_id=runtime.parent_session_id,
        workdir=runtime.workdir,
    )
    executor = ToolExecutor(runtime, registry, workflow_guard, approval_service)

    async def tools_node(state: AgentState) -> AgentState:
        requester = state["messages"][-1]
        requested = requester.additional_kwargs.get("tool_calls_data", [])
        results = await execute_tool_calls(
            requested,
            executor.execute,
            registry.can_run_concurrently,
        )

        todos = runtime.todo_manager
        touched_todo = any(call.name == "todo" for call in requested)
        if requested:
            # A batch is recorded as "todo" if any call in it was one;
            # otherwise it is recorded under the first tool's name.
            todos.record_tool_call("todo" if touched_todo else requested[0].name)

        extra = workflow_guard.after_batch_messages(requested)
        reminder_text = todos.consume_repeated_incomplete_message()
        if reminder_text:
            extra.append(
                HumanMessage(
                    content=reminder_text,
                    additional_kwargs={
                        "context_ephemeral": True,
                        "ephemeral_kind": "task_repeated_reminder",
                    },
                )
            )

        final_text = requester.content
        if touched_todo and todos.can_finish_task() and str(final_text or "").strip():
            # Preserve the model's closing answer so routing can end the run.
            extra.append(
                AIMessage(
                    content=final_text,
                    additional_kwargs={"task_completed_final": True},
                )
            )
        return {"messages": results + extra}

    return tools_node
|
agent/plan_snapshot.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Shared task plan snapshot models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
from agent.todo_manager import TodoManager
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
PlanStatus = Literal["pending", "in_progress", "completed"]


@dataclass(frozen=True)
class PlanEntry:
    """A single immutable task-plan item, shaped for UI/protocol adapters."""

    # Identifier of the plan item, as a string.
    id: str
    # Human-readable description of the item.
    title: str
    # Progress state of the item.
    status: PlanStatus
    # Display priority label.
    priority: str = field(default="medium")


@dataclass(frozen=True)
class PlanSnapshot:
    """Point-in-time view of the task plan plus its compact memory.

    The instance itself is frozen. However, ``entries`` and ``memory`` are
    ordinary list/dict objects, so their contents stay mutable, and hashing a
    snapshot raises ``TypeError``.
    """

    entries: list[PlanEntry] = field(default_factory=list)
    memory: dict[str, Any] = field(default_factory=dict)
    # Timestamp string; empty when not set.
    updated_at: str = field(default="")
    task_started: bool = field(default=False)
    task_completed: bool = field(default=False)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def build_plan_snapshot(todo_manager: TodoManager | None) -> PlanSnapshot:
    """Capture the todo manager's current task plan as a ``PlanSnapshot``.

    Without a manager, an empty snapshot stamped with the current time is
    returned.

    Otherwise each dict item in ``get_task_state()["items"]`` becomes a
    ``PlanEntry``:

    * the id falls back to the item's 1-based position when missing or falsy;
    * the title comes from ``text``;
    * the status is normalized;
    * in-progress items get ``"high"`` priority, all others ``"medium"``.

    Non-dict items are skipped but still count toward positions.
    """
    if todo_manager is None:
        return PlanSnapshot(updated_at=_utc_now())

    state = todo_manager.get_task_state()
    entries: list[PlanEntry] = []
    for position, item in enumerate(state.get("items") or [], start=1):
        if not isinstance(item, dict):
            continue
        raw_status = item.get("status")
        entries.append(
            PlanEntry(
                id=str(item.get("id") or position),
                title=str(item.get("text") or ""),
                status=_normalize_status(raw_status),
                priority="high" if raw_status == "in_progress" else "medium",
            )
        )

    return PlanSnapshot(
        entries=entries,
        memory=dict(state.get("memory") or {}),
        updated_at=_utc_now(),
        task_started=bool(todo_manager.task_state_started),
        task_completed=bool(todo_manager.task_completed),
    )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _normalize_status(value: object) -> PlanStatus:
|
|
62
|
+
if value == "completed":
|
|
63
|
+
return "completed"
|
|
64
|
+
if value == "in_progress":
|
|
65
|
+
return "in_progress"
|
|
66
|
+
return "pending"
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _utc_now() -> str:
|
|
70
|
+
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""LLM providers package."""
|
|
2
|
+
|
|
3
|
+
from .base import LLMProvider, ChatResponse, ToolCall
|
|
4
|
+
from .anthropic_provider import AnthropicProvider
|
|
5
|
+
from .openai_provider import OpenAIProvider
|
|
6
|
+
|
|
7
|
+
__all__ = [
    # Provider-neutral base types.
    "LLMProvider",
    "ChatResponse",
    "ToolCall",
    # Concrete provider implementations.
    "AnthropicProvider",
    "OpenAIProvider",
]
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""Anthropic LLM provider implementation."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any, Optional, Callable
|
|
5
|
+
|
|
6
|
+
from anthropic import AsyncAnthropic
|
|
7
|
+
|
|
8
|
+
from agent.logger import get_logger
|
|
9
|
+
|
|
10
|
+
from .base import LLMProvider, ChatResponse, ToolCall
|
|
11
|
+
from .text_tool_calls import TextToolCallStreamFilter, parse_text_tool_calls
|
|
12
|
+
|
|
13
|
+
logger = get_logger(__name__)

# Default ``max_tokens`` for chat requests (the value that used to be hard-coded).
_DEFAULT_MAX_TOKENS = 8000


class AnthropicProvider(LLMProvider):
    """Anthropic Claude API provider.

    ``chat`` sends a non-streaming request first. If that fails, it retries
    once in streaming mode, unless the non-streaming path had already sent
    output to the stream callback. In that case a retry would repeat the
    request and duplicate the already-emitted output, so the error is raised
    instead.
    """

    # Class-level default; ``__init__`` sets the per-instance value.
    max_tokens: int = _DEFAULT_MAX_TOKENS

    def __init__(
        self,
        api_key: str,
        model: str,
        base_url: Optional[str] = None,
        max_tokens: int = _DEFAULT_MAX_TOKENS,
    ):
        """Create the provider.

        Args:
            api_key: API key passed to ``AsyncAnthropic``.
            model: Model name sent with every request.
            base_url: Optional API base URL override (e.g. an
                Anthropic-compatible endpoint).
            max_tokens: ``max_tokens`` sent with every chat request.
        """
        self.model = model
        self.max_tokens = max_tokens
        self.client = AsyncAnthropic(
            api_key=api_key,
            base_url=base_url,
        )

    def _request_kwargs(
        self,
        messages: list[dict],
        system_prompt: Optional[str],
        tools: Optional[list[dict]],
    ) -> dict[str, Any]:
        """Build the request fields shared by ``chat`` and ``count_tokens``.

        ``system`` and ``tools`` are included only when truthy.
        """
        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
        }
        if system_prompt:
            kwargs["system"] = system_prompt
        if tools:
            kwargs["tools"] = tools
        return kwargs

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict],
        system_prompt: Optional[str] = None,
        stream_callback: Optional[Callable[[str, str], Any]] = None,
    ) -> ChatResponse:
        """Send a chat request to the Anthropic API.

        Args:
            messages: Provider-format message list.
            tools: Tool schemas; omitted from the request when empty.
            system_prompt: Optional system prompt.
            stream_callback: Optional async callback, awaited as
                ``callback(event_type, text)`` for ``text_delta`` events and,
                in streaming mode, ``thinking_start``/``thinking_delta``/
                ``thinking_end`` events.

        Returns:
            The normalized ``ChatResponse``.

        Raises:
            Exception: the streaming error when both modes fail, or the
            non-streaming error when it happens after output was already
            forwarded to ``stream_callback``.
        """
        import traceback

        kwargs = self._request_kwargs(messages, system_prompt, tools)
        kwargs["max_tokens"] = self.max_tokens

        emitted = False

        async def tracked_callback(event_type: str, text: str) -> Any:
            # Mark before awaiting, so a callback that fails midway also counts.
            nonlocal emitted
            emitted = True
            return await stream_callback(event_type, text)

        try:
            # First try non-streaming mode (more reliable for compatible APIs)
            return await self._chat_non_streaming(
                kwargs, tracked_callback if stream_callback else None
            )
        except Exception as e:
            if emitted:
                # Output already reached the caller; falling back would send a
                # second request and duplicate that output.
                logger.error(
                    f"Non-streaming failed after output was emitted: {type(e).__name__}: {e}"
                )
                raise
            logger.warning(f"Non-streaming failed, trying streaming: {e}")
            try:
                return await self._chat_streaming(kwargs, stream_callback)
            except Exception as e2:
                logger.error(f"Both modes failed. Last error: {type(e2).__name__}: {e2}")
                logger.error(f"Traceback:\n{traceback.format_exc()}")
                raise

    async def _chat_non_streaming(
        self,
        kwargs: dict,
        stream_callback: Optional[Callable[[str, str], Any]] = None,
    ) -> ChatResponse:
        """Non-streaming chat mode.

        After the full response arrives, text blocks are passed through a
        ``TextToolCallStreamFilter`` and the safe portions are forwarded to
        ``stream_callback`` as ``text_delta`` events.
        """
        message = await self.client.messages.create(**kwargs)

        current_text = ""
        tool_calls_data: list[dict[str, Any]] = []
        text_filter = TextToolCallStreamFilter()

        for block in message.content:
            if block.type == "text":
                current_text += block.text
                if stream_callback:
                    for safe_text in text_filter.feed(block.text):
                        await stream_callback("text_delta", safe_text)
            elif block.type == "tool_use":
                tool_calls_data.append({
                    "name": block.name,
                    "args": block.input,
                    "id": block.id,
                })

        return await self._finalize_response(
            message, current_text, tool_calls_data, text_filter, stream_callback
        )

    async def _chat_streaming(
        self,
        kwargs: dict,
        stream_callback: Optional[Callable[[str, str], Any]] = None,
    ) -> ChatResponse:
        """Streaming chat mode (fallback).

        Forwards thinking events and filtered ``text_delta`` events to
        ``stream_callback`` as they arrive, and assembles each tool call's
        arguments from its ``input_json_delta`` fragments.
        """
        current_text = ""
        tool_calls_data: list[dict[str, Any]] = []
        current_tool_use: Optional[dict[str, Any]] = None
        in_thinking = False
        text_filter = TextToolCallStreamFilter()

        async with self.client.messages.stream(**kwargs) as stream:
            async for event in stream:
                if event.type == "content_block_start":
                    block = event.content_block
                    if block.type == "thinking":
                        in_thinking = True
                        if stream_callback:
                            await stream_callback("thinking_start", "")
                    elif block.type == "text":
                        in_thinking = False
                    elif block.type == "tool_use":
                        in_thinking = False
                        current_tool_use = {
                            "name": block.name,
                            "id": block.id,
                            "args": "",
                        }
                elif event.type == "content_block_delta":
                    delta = event.delta
                    if delta.type == "thinking_delta":
                        if stream_callback:
                            await stream_callback("thinking_delta", delta.thinking)
                    elif delta.type == "text_delta":
                        current_text += delta.text
                        if stream_callback:
                            for safe_text in text_filter.feed(delta.text):
                                await stream_callback("text_delta", safe_text)
                    elif delta.type == "input_json_delta":
                        if current_tool_use:
                            current_tool_use["args"] += delta.partial_json
                elif event.type == "content_block_stop":
                    if in_thinking:
                        if stream_callback:
                            await stream_callback("thinking_end", "")
                        in_thinking = False
                    elif current_tool_use:
                        raw_args = current_tool_use["args"]
                        try:
                            args = json.loads(raw_args)
                        except json.JSONDecodeError:
                            # Empty input (no-argument tools) is expected; anything
                            # else means the arguments were lost, so say so.
                            if raw_args.strip():
                                logger.warning(
                                    f"Could not parse streamed arguments for tool "
                                    f"{current_tool_use['name']!r}; using empty args"
                                )
                            args = {}
                        tool_calls_data.append({
                            "name": current_tool_use["name"],
                            "args": args,
                            "id": current_tool_use["id"],
                        })
                        current_tool_use = None

            final_message = await stream.get_final_message()

        return await self._finalize_response(
            final_message, current_text, tool_calls_data, text_filter, stream_callback
        )

    async def _finalize_response(
        self,
        message: Any,
        current_text: str,
        tool_calls_data: list[dict[str, Any]],
        text_filter: TextToolCallStreamFilter,
        stream_callback: Optional[Callable[[str, str], Any]],
    ) -> ChatResponse:
        """Build the ``ChatResponse`` shared by both request modes.

        Native ``tool_use`` calls come first. If tool calls are also parsed out
        of the text, they are appended, the text is replaced by its cleaned
        form, and the filter is not flushed. Otherwise any text still buffered
        in ``text_filter`` is flushed to ``stream_callback``.
        """
        usage = self._extract_usage(getattr(message, "usage", None))
        content_blocks = self._normalize_content_blocks(message.content)

        tool_calls = [
            ToolCall(id=tc["id"], name=tc["name"], args=tc["args"])
            for tc in tool_calls_data
        ]
        cleaned_text, text_tool_calls = parse_text_tool_calls(current_text)
        if text_tool_calls:
            current_text = cleaned_text
            tool_calls.extend(text_tool_calls)
        elif stream_callback:
            for safe_text in text_filter.flush():
                await stream_callback("text_delta", safe_text)

        return ChatResponse(
            content=current_text,
            tool_calls=tool_calls,
            content_blocks=content_blocks,
            raw_response=message,
            usage=usage,
        )

    async def close(self) -> None:
        """Close the underlying ``AsyncAnthropic`` client."""
        await self.client.close()

    async def count_tokens(
        self,
        messages: list[dict],
        system_prompt: Optional[str] = None,
        tools: Optional[list[dict]] = None,
    ) -> Optional[int]:
        """Count input tokens using the Anthropic-compatible count endpoint.

        Returns the endpoint's ``input_tokens`` as an int. Returns ``None``
        when the response has no such field, or when the call fails (the
        failure is logged as a warning).
        """
        try:
            kwargs = self._request_kwargs(messages, system_prompt, tools)
            response = await self.client.messages.count_tokens(**kwargs)
            input_tokens = getattr(response, "input_tokens", None)
            return int(input_tokens) if input_tokens is not None else None
        except Exception:
            logger.warning("Count tokens not supported, falling back to estimation")
            return None

    def _extract_usage(self, usage: Any) -> Optional[dict[str, int]]:
        """Normalize Anthropic usage data.

        Returns ``None`` when ``usage`` is missing or reports neither input
        nor output tokens. A missing count becomes 0, and ``total_tokens`` is
        their sum.
        """
        if usage is None:
            return None
        input_tokens = getattr(usage, "input_tokens", None)
        output_tokens = getattr(usage, "output_tokens", None)
        if input_tokens is None and output_tokens is None:
            return None
        return {
            "input_tokens": input_tokens or 0,
            "output_tokens": output_tokens or 0,
            "total_tokens": (input_tokens or 0) + (output_tokens or 0),
        }

    def _normalize_content_blocks(self, blocks: Any) -> list[dict[str, Any]]:
        """Convert Anthropic content blocks into serializable dicts.

        Handles these block types; any other type is dropped:

        * ``text``;
        * ``thinking``, keeping ``signature`` when present;
        * ``redacted_thinking``, keeping ``data``/``signature`` when present;
        * ``tool_use``, with ``input`` defaulting to ``{}``.
        """
        normalized: list[dict[str, Any]] = []
        for block in blocks or []:
            block_type = getattr(block, "type", None)
            if block_type == "text":
                normalized.append({"type": "text", "text": getattr(block, "text", "")})
            elif block_type == "thinking":
                thinking_block = {
                    "type": "thinking",
                    "thinking": getattr(block, "thinking", ""),
                }
                signature = getattr(block, "signature", None)
                if signature:
                    thinking_block["signature"] = signature
                normalized.append(thinking_block)
            elif block_type == "redacted_thinking":
                data = {"type": "redacted_thinking"}
                for field in ("data", "signature"):
                    value = getattr(block, field, None)
                    if value:
                        data[field] = value
                normalized.append(data)
            elif block_type == "tool_use":
                normalized.append(
                    {
                        "type": "tool_use",
                        "id": getattr(block, "id", None),
                        "name": getattr(block, "name", None),
                        "input": getattr(block, "input", None) or {},
                    }
                )
        return normalized
|