yycode 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. agent/__init__.py +33 -0
  2. agent/acp/__init__.py +2 -0
  3. agent/acp/approval_adapter.py +134 -0
  4. agent/acp/content_adapter.py +45 -0
  5. agent/acp/jsonrpc.py +92 -0
  6. agent/acp/server.py +197 -0
  7. agent/acp/session_manager.py +193 -0
  8. agent/acp/update_adapter.py +192 -0
  9. agent/app_paths.py +25 -0
  10. agent/approval.py +169 -0
  11. agent/cancellation.py +52 -0
  12. agent/change_snapshot.py +186 -0
  13. agent/context_compressor.py +116 -0
  14. agent/graph.py +137 -0
  15. agent/llm_retry.py +434 -0
  16. agent/logger.py +97 -0
  17. agent/lsp/__init__.py +13 -0
  18. agent/lsp/client.py +151 -0
  19. agent/lsp/manager.py +234 -0
  20. agent/lsp/types.py +119 -0
  21. agent/message_context_manager.py +322 -0
  22. agent/message_format.py +105 -0
  23. agent/nodes/llm_node.py +58 -0
  24. agent/nodes/state.py +12 -0
  25. agent/nodes/task_guard_node.py +50 -0
  26. agent/nodes/tools_node.py +70 -0
  27. agent/plan_snapshot.py +70 -0
  28. agent/providers/__init__.py +13 -0
  29. agent/providers/anthropic_provider.py +268 -0
  30. agent/providers/base.py +52 -0
  31. agent/providers/openai_provider.py +279 -0
  32. agent/providers/text_tool_calls.py +118 -0
  33. agent/runtime/approval_service.py +184 -0
  34. agent/runtime/context.py +43 -0
  35. agent/runtime/tool_events.py +368 -0
  36. agent/runtime/tool_executor.py +208 -0
  37. agent/runtime/tool_output.py +261 -0
  38. agent/runtime/tool_registry.py +91 -0
  39. agent/runtime/tool_scheduler.py +35 -0
  40. agent/runtime/workflow_guard.py +217 -0
  41. agent/runtime/workspace.py +5 -0
  42. agent/runtime/workspace_tools.py +22 -0
  43. agent/session.py +787 -0
  44. agent/session_replay.py +95 -0
  45. agent/session_store.py +186 -0
  46. agent/skills.py +254 -0
  47. agent/streaming.py +248 -0
  48. agent/subagent.py +634 -0
  49. agent/task_memory.py +340 -0
  50. agent/todo_manager.py +304 -0
  51. agent/tool_retry.py +106 -0
  52. agent/tui/__init__.py +14 -0
  53. agent/tui/app.py +1325 -0
  54. agent/tui/approval.py +53 -0
  55. agent/tui/commands/__init__.py +6 -0
  56. agent/tui/commands/base.py +48 -0
  57. agent/tui/commands/clear.py +37 -0
  58. agent/tui/commands/help.py +27 -0
  59. agent/tui/commands/registry.py +94 -0
  60. agent/tui/help_content.py +108 -0
  61. agent/tui/renderers.py +1961 -0
  62. agent/tui/runner.py +439 -0
  63. agent/tui/state.py +653 -0
  64. main.py +465 -0
  65. tools/__init__.py +50 -0
  66. tools/apply_patch.py +305 -0
  67. tools/bash.py +76 -0
  68. tools/diff_utils.py +139 -0
  69. tools/edit_file.py +40 -0
  70. tools/git_diff.py +72 -0
  71. tools/git_show.py +65 -0
  72. tools/grep.py +149 -0
  73. tools/list_files.py +90 -0
  74. tools/list_skills.py +24 -0
  75. tools/load_skill.py +30 -0
  76. tools/lsp_definition.py +27 -0
  77. tools/lsp_diagnostics.py +32 -0
  78. tools/lsp_document_symbols.py +23 -0
  79. tools/lsp_hover.py +29 -0
  80. tools/lsp_references.py +37 -0
  81. tools/lsp_utils.py +38 -0
  82. tools/lsp_workspace_symbols.py +23 -0
  83. tools/read_file.py +61 -0
  84. tools/read_many_files.py +50 -0
  85. tools/safety.py +50 -0
  86. tools/subagent.py +57 -0
  87. tools/todo.py +89 -0
  88. tools/verify.py +107 -0
  89. tools/web_search.py +250 -0
  90. tools/workspace.py +36 -0
  91. tools/workspace_state.py +60 -0
  92. tools/write_file.py +88 -0
  93. utils/__init__.py +5 -0
  94. utils/retry.py +13 -0
  95. yycode-0.3.2.data/data/skills/code_review.md +61 -0
  96. yycode-0.3.2.data/data/skills/code_workflow.md +404 -0
  97. yycode-0.3.2.data/data/skills/drawio/SKILL.md +636 -0
  98. yycode-0.3.2.data/data/skills/drawio/agents/openai.yaml +19 -0
  99. yycode-0.3.2.data/data/skills/drawio/assets/demo-erd.drawio +84 -0
  100. yycode-0.3.2.data/data/skills/drawio/assets/demo-layered-cn.drawio +91 -0
  101. yycode-0.3.2.data/data/skills/drawio/assets/demo-layered-cn.png +0 -0
  102. yycode-0.3.2.data/data/skills/drawio/assets/demo-layered.drawio +112 -0
  103. yycode-0.3.2.data/data/skills/drawio/assets/demo-layered.png +0 -0
  104. yycode-0.3.2.data/data/skills/drawio/assets/demo-ml.drawio +90 -0
  105. yycode-0.3.2.data/data/skills/drawio/assets/demo-ring-cn.drawio +68 -0
  106. yycode-0.3.2.data/data/skills/drawio/assets/demo-ring-cn.png +0 -0
  107. yycode-0.3.2.data/data/skills/drawio/assets/demo-ring.drawio +86 -0
  108. yycode-0.3.2.data/data/skills/drawio/assets/demo-ring.png +0 -0
  109. yycode-0.3.2.data/data/skills/drawio/assets/demo-sequence.drawio +116 -0
  110. yycode-0.3.2.data/data/skills/drawio/assets/demo-star-cn.drawio +66 -0
  111. yycode-0.3.2.data/data/skills/drawio/assets/demo-star-cn.png +0 -0
  112. yycode-0.3.2.data/data/skills/drawio/assets/demo-star.drawio +79 -0
  113. yycode-0.3.2.data/data/skills/drawio/assets/demo-star.png +0 -0
  114. yycode-0.3.2.data/data/skills/drawio/assets/demo-uml-class.drawio +64 -0
  115. yycode-0.3.2.data/data/skills/drawio/assets/microservices-example.drawio +173 -0
  116. yycode-0.3.2.data/data/skills/drawio/assets/microservices-example.png +0 -0
  117. yycode-0.3.2.data/data/skills/drawio/assets/workflow-cn.drawio +120 -0
  118. yycode-0.3.2.data/data/skills/drawio/assets/workflow-cn.png +0 -0
  119. yycode-0.3.2.data/data/skills/drawio/assets/workflow.drawio +120 -0
  120. yycode-0.3.2.data/data/skills/drawio/assets/workflow.png +0 -0
  121. yycode-0.3.2.data/data/skills/drawio/docs/index.html +469 -0
  122. yycode-0.3.2.data/data/skills/drawio/docs/zh.html +456 -0
  123. yycode-0.3.2.data/data/skills/drawio/references/style-extraction.md +254 -0
  124. yycode-0.3.2.data/data/skills/drawio/styles/schema.json +112 -0
  125. yycode-0.3.2.data/data/skills/plan.md +115 -0
  126. yycode-0.3.2.data/data/skills/ppt/SKILL.md +254 -0
  127. yycode-0.3.2.dist-info/METADATA +12 -0
  128. yycode-0.3.2.dist-info/RECORD +131 -0
  129. yycode-0.3.2.dist-info/WHEEL +5 -0
  130. yycode-0.3.2.dist-info/entry_points.txt +2 -0
  131. yycode-0.3.2.dist-info/top_level.txt +4 -0
@@ -0,0 +1,193 @@
1
+ """ACP session lifecycle management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import os
7
+ from dataclasses import dataclass, field
8
+ from pathlib import Path
9
+ from typing import Any, Awaitable, Callable
10
+
11
+ from agent.acp.approval_adapter import AcpApprovalAdapter
12
+ from agent.acp.content_adapter import content_blocks_to_text
13
+ from agent.acp.update_adapter import (
14
+ plan_snapshot_to_update,
15
+ replay_event_to_updates,
16
+ stream_event_to_updates,
17
+ )
18
+ from agent.cancellation import CancellationController
19
+ from agent.plan_snapshot import build_plan_snapshot
20
+ from agent.session import Session
21
+ from agent.streaming import StreamEvent
22
+
23
+
24
# Coroutine that emits a notification given (method, params), e.g. "session/update".
Notifier = Callable[[str, dict[str, Any]], Awaitable[None]]
# Coroutine that sends a client request given (method, params) and returns the result dict.
Requester = Callable[[str, dict[str, Any]], Awaitable[dict[str, Any]]]


async def auto_approval_callback(_request: Any) -> bool:
    """Grant every runtime approval request without consulting the ACP client.

    Used in place of the client-backed approval adapter when the session
    manager runs with ``auto_approve=True``; the request is ignored.
    """
    del _request  # deliberately unused: all requests are approved
    return True
+
32
+
33
@dataclass
class AcpManagedSession:
    """Runtime state for one ACP session.

    Bundles the yoyoagent ``Session`` with the approval adapter built for it
    and the controller that tracks its in-flight prompt task.
    """

    # The yoyoagent session that prompt text is sent to.
    session: Session
    # Approval adapter bound to this session; its pending requests are
    # cancelled when a prompt is cancelled and when the manager closes.
    approval_adapter: AcpApprovalAdapter
    # Tracks the currently running prompt task so it can be cancelled.
    cancel_controller: CancellationController = field(default_factory=CancellationController)
+
41
+
42
class AcpSessionManager:
    """Create, load, run, and cancel ACP-backed yoyoagent sessions.

    Live sessions are tracked in ``self.sessions`` keyed by session id.
    Outbound traffic goes through two injected coroutines: ``notify`` emits
    ``session/update`` notifications, and ``request`` is handed to each
    session's ``AcpApprovalAdapter``.
    """

    def __init__(self, notify: Notifier, request: Requester, *, auto_approve: bool = False):
        self.notify = notify
        self.request = request
        # When True, sessions use ``auto_approval_callback`` instead of the
        # client-backed approval adapter's callback.
        self.auto_approve = auto_approve
        self.sessions: dict[str, AcpManagedSession] = {}

    async def new_session(self, params: dict[str, Any]) -> dict[str, Any]:
        """Create a new yoyoagent session for an ACP client.

        Args:
            params: ``session/new`` params; ``cwd`` (or ``workdir``) selects the
                working directory, defaulting to the process cwd.

        Returns:
            ``{"sessionId": <id>}``.

        Raises:
            ValueError: if the working directory is not an existing directory.
        """
        cwd = _resolve_cwd(params)
        session = self._create_session(cwd)
        self.sessions[session.id] = self._managed(session)
        await self._send_available_commands(session)
        return {"sessionId": session.id}

    async def load_session(self, params: dict[str, Any]) -> dict[str, Any] | None:
        """Load a persisted yoyoagent session and replay its display events.

        Replayed history is sent to the client as ``session/update``
        notifications before this returns. If the requested id is already live
        in this manager, that instance is shut down first rather than being
        overwritten (which would leak the old session and any running turn).

        Returns:
            ``{"sessionId": <id>}``.

        Raises:
            ValueError: if ``sessionId`` is missing or the cwd is invalid.
        """
        cwd = _resolve_cwd(params)
        session_id = _session_id_from_params(params)
        previous = self.sessions.pop(session_id, None)
        if previous is not None:
            await self._shutdown_managed(previous)
        session = self._create_session(cwd, session_id=session_id, resume=True)
        self.sessions[session.id] = self._managed(session)
        await self._send_available_commands(session)
        for replay_event in session.replay_view():
            for update in replay_event_to_updates(replay_event):
                await self._send_update(session.id, update)
        return {"sessionId": session.id}

    async def prompt(self, params: dict[str, Any]) -> dict[str, Any]:
        """Run one prompt turn.

        Returns:
            ``{"stopReason": "end_turn"}`` once the turn finishes and a final
            plan snapshot has been sent, or ``{"stopReason": "cancelled"}`` if
            the turn task was cancelled.

        Raises:
            ValueError: if the session id is missing or unknown.
        """
        session_id = _session_id_from_params(params)
        managed = self._require_session(session_id)
        # Clients differ in which key carries the prompt; use the first non-empty one.
        prompt_text = content_blocks_to_text(
            params.get("content")
            or params.get("prompt")
            or params.get("message")
            or params.get("input")
            or ""
        )
        # Run the turn as its own task so `cancel()` can interrupt it.
        task = asyncio.create_task(managed.session.send(prompt_text))
        managed.cancel_controller.set_task(task)
        try:
            await task
        except asyncio.CancelledError:
            # NOTE(review): this also converts a cancellation of the prompt()
            # coroutine itself into a "cancelled" stop reason.
            managed.approval_adapter.cancel_pending()
            return {"stopReason": "cancelled"}
        finally:
            managed.cancel_controller.clear_task(task)
        await self._send_update(
            session_id,
            plan_snapshot_to_update(build_plan_snapshot(managed.session.todo_manager)),
        )
        return {"stopReason": "end_turn"}

    async def cancel(self, params: dict[str, Any]) -> dict[str, Any]:
        """Cancel an active prompt turn.

        Pending approval requests are cancelled before the running turn task.

        Returns:
            ``{"status": <CancelResult.status>}``.
        """
        session_id = _session_id_from_params(params)
        managed = self._require_session(session_id)
        managed.approval_adapter.cancel_pending()
        result = await managed.cancel_controller.cancel()
        return {"status": result.status}

    async def close(self) -> None:
        """Close all managed sessions.

        Every session is shut down even if an earlier one fails; the registry
        is emptied up front and the first failure is re-raised at the end.
        """
        managed_sessions = list(self.sessions.values())
        self.sessions.clear()
        first_error: Exception | None = None
        for managed in managed_sessions:
            try:
                await self._shutdown_managed(managed)
            except Exception as exc:  # keep closing the remaining sessions
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    async def _shutdown_managed(self, managed: AcpManagedSession) -> None:
        """Cancel pending approvals and any running turn, then always close the session."""
        try:
            managed.approval_adapter.cancel_pending()
            await managed.cancel_controller.cancel()
        finally:
            await managed.session.close()

    def _managed(self, session: Session) -> AcpManagedSession:
        """Wire ACP approval and streaming callbacks into *session* and wrap it."""
        approval = AcpApprovalAdapter(
            session.id,
            self.request,
            workdir=session.workdir,
        )
        session.approval_callback = auto_approval_callback if self.auto_approve else approval.callback
        session.stream_callback = self._stream_callback(session)
        # Drop any already-built graph; presumably so it is rebuilt with the
        # callbacks installed above — confirm against Session.
        session._graph = None
        return AcpManagedSession(session=session, approval_adapter=approval)

    def _create_session(
        self,
        cwd: Path,
        *,
        session_id: str | None = None,
        resume: bool = False,
    ) -> Session:
        """Build a message-persisting Session rooted at *cwd*."""
        return Session.from_config(
            workdir=cwd,
            session_id=session_id,
            persist_messages=True,
            resume=resume,
        )

    def _stream_callback(self, session: Session) -> Callable[[StreamEvent], Awaitable[None]]:
        """Return a stream callback that forwards *session*'s events as ACP updates."""

        async def callback(event: StreamEvent) -> None:
            for update in stream_event_to_updates(event, workdir=session.workdir):
                await self._send_update(session.id, update)
            # A todo tool result may change the plan, so push a fresh snapshot.
            if event.event_type == "tool_result" and event.tool_name == "todo":
                await self._send_update(
                    session.id,
                    plan_snapshot_to_update(build_plan_snapshot(session.todo_manager)),
                )

        return callback

    async def _send_update(self, session_id: str, update: dict[str, Any]) -> None:
        """Emit one ``session/update`` notification for *session_id*."""
        await self.notify("session/update", {"sessionId": session_id, "update": update})

    async def _send_available_commands(self, session: Session) -> None:
        """Advertise ``/plan`` plus one slash command per registered skill."""
        commands = [
            {
                "name": "/plan",
                "description": "Discuss requirements and produce an implementation plan without executing changes.",
            }
        ]
        commands.extend(
            {"name": f"/{skill.name}", "description": skill.description}
            for skill in session.skill_registry.list_skills()
        )
        await self._send_update(
            session.id,
            {
                "sessionUpdate": "available_commands_update",
                "commands": commands,
            },
        )

    def _require_session(self, session_id: str) -> AcpManagedSession:
        """Return the managed session for *session_id*.

        Raises:
            ValueError: if no such session is live.
        """
        try:
            return self.sessions[session_id]
        except KeyError:
            raise ValueError(f"Unknown ACP session: {session_id}") from None
+
177
+
178
+ def _resolve_cwd(params: dict[str, Any]) -> Path:
179
+ raw = params.get("cwd") or params.get("workdir") or os.getcwd()
180
+ cwd = Path(str(raw)).expanduser()
181
+ if not cwd.is_absolute():
182
+ cwd = Path.cwd() / cwd
183
+ cwd = cwd.resolve()
184
+ if not cwd.exists() or not cwd.is_dir():
185
+ raise ValueError(f"cwd must be an existing directory: {cwd}")
186
+ return cwd
187
+
188
+
189
+ def _session_id_from_params(params: dict[str, Any]) -> str:
190
+ session_id = params.get("sessionId") or params.get("session_id") or params.get("id")
191
+ if not session_id:
192
+ raise ValueError("sessionId is required")
193
+ return str(session_id)
@@ -0,0 +1,192 @@
1
+ """Map yoyoagent stream and replay events to ACP session updates."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ from agent.plan_snapshot import PlanSnapshot
9
+ from agent.session_replay import ReplayEvent
10
+ from agent.streaming import StreamEvent
11
+
12
+
13
# Tool-name groups consulted by `_tool_kind` to classify a tool call as
# "read", "search", "edit" or "execute" for ACP clients.
READ_TOOLS = {"read_file", "read_many_files", "list_files", "git_show", "workspace_state", "git_diff"}
# Any tool whose name starts with "lsp_" is also treated as "search" by `_tool_kind`.
SEARCH_TOOLS = {"grep", "web_search"}
EDIT_TOOLS = {"apply_patch", "write_file", "edit_file"}
EXECUTE_TOOLS = {"bash", "verify"}
17
+
18
+
19
def stream_event_to_updates(event: StreamEvent, *, workdir: Path | None = None) -> list[dict[str, Any]]:
    """Translate one yoyoagent stream event into ACP ``session/update`` payloads.

    Args:
        event: Stream event emitted by the agent runtime.
        workdir: Base directory used to absolutize relative tool file paths.

    Returns:
        A list holding zero or one payload. Context-compression events and
        unrecognized event types produce an empty list.
    """
    kind = event.event_type

    if kind == "text_delta":
        return [_update("agent_message_chunk", {"content": event.content})]

    if kind == "tool_start":
        call_id = _tool_call_id(event)
        title = event.title or event.content or event.tool_name or "Tool call"
        tool_kind = _tool_kind(event.tool_name)
        status = _tool_status(event.status, default="in_progress")
        content = _tool_content(event)
        locations = _locations(event.file_paths, workdir)
        # Prefer the recorded tool arguments; otherwise expose the whole metadata.
        raw_input = (event.metadata or {}).get("args") or event.metadata or {}
        payload = {
            "toolCallId": call_id,
            "title": title,
            "kind": tool_kind,
            "status": status,
            "content": content,
            "locations": locations,
            "rawInput": raw_input,
            "_meta": {"yoyo": event.to_dict()},
        }
        return [_update("tool_call", payload)]

    if kind == "tool_end":
        payload = {
            "toolCallId": _tool_call_id(event),
            "status": _tool_status(event.status, default="completed"),
            "elapsedMs": event.elapsed_ms,
            "_meta": {"yoyo": event.to_dict()},
        }
        return [_update("tool_call_update", payload)]

    if kind == "tool_result":
        call_id = _tool_call_id(event)
        tool_kind = _tool_kind(event.tool_name)
        status = _tool_status(event.status, default="completed")
        locations = _locations(event.file_paths, workdir)
        payload = {
            "toolCallId": call_id,
            "title": event.title,
            "kind": tool_kind,
            "status": status,
            "content": [{"type": "text", "text": event.content}],
            "locations": locations,
            "rawOutput": event.content,
            "_meta": {"yoyo": event.to_dict()},
        }
        return [_update("tool_call_update", payload)]

    # Context bookkeeping is internal and not surfaced to the client.
    if kind in ("context_compressed", "context_summarized"):
        return []

    if kind == "session_warning":
        label = event.title or 'context'
        return [_update("agent_message_chunk", {"content": f"\n[{label}] {event.content}\n"})]

    if kind == "usage":
        return [_update("usage", {"usage": event.usage or {}, "_meta": {"yoyo": event.to_dict()}})]

    if kind in ("llm_waiting", "llm_timeout", "llm_retry", "llm_error"):
        payload = {
            "title": event.title or "Model status",
            "content": event.content,
            "status": event.status or "running",
            "_meta": {"yoyo": event.to_dict()},
        }
        return [_update("status", payload)]

    return []
86
+
87
+
88
def plan_snapshot_to_update(snapshot: PlanSnapshot) -> dict[str, Any]:
    """Convert a plan snapshot into an ACP ``plan`` session update.

    Each entry contributes its ``id``, ``title``, ``status`` and ``priority``.
    Snapshot-level fields (memory, timestamps, start/completion flags) travel
    under ``_meta.yoyo``.
    """
    entries = [
        {
            "id": item.id,
            "title": item.title,
            "status": item.status,
            "priority": item.priority,
        }
        for item in snapshot.entries
    ]
    yoyo_meta = {
        "memory": snapshot.memory,
        "updatedAt": snapshot.updated_at,
        "taskStarted": snapshot.task_started,
        "taskCompleted": snapshot.task_completed,
    }
    return _update("plan", {"entries": entries, "_meta": {"yoyo": yoyo_meta}})
112
+
113
+
114
def replay_event_to_updates(event: ReplayEvent) -> list[dict[str, Any]]:
    """Convert one persisted replay event into ACP session updates.

    Summaries and assistant text become agent message chunks, user text
    becomes a user message chunk, and tool results become completed
    ``tool_call_update`` payloads. Any other role yields an empty list.
    """
    if event.kind == "summary":
        return [_update("agent_message_chunk", {"content": f"\n[Session summary]\n{event.content}\n"})]

    role = event.role
    if role in ("user", "assistant"):
        chunk_type = "user_message_chunk" if role == "user" else "agent_message_chunk"
        return [_update(chunk_type, {"content": event.content})]

    if role != "tool":
        return []

    meta = event.metadata
    tool_name = str(meta.get("tool_name") or "tool")
    # Fall back to a synthetic id when the stored result carries none.
    call_id = str(meta.get("tool_call_id") or f"replay-{tool_name}")
    payload = {
        "toolCallId": call_id,
        "title": tool_name,
        "kind": _tool_kind(tool_name),
        "status": "completed",
        "content": [{"type": "text", "text": event.content}],
        "rawOutput": event.content,
        "_meta": {"yoyo": {"replay": True, **meta}},
    }
    return [_update("tool_call_update", payload)]
140
+
141
+
142
+ def _update(update_type: str, payload: dict[str, Any]) -> dict[str, Any]:
143
+ return {"sessionUpdate": update_type, **payload}
144
+
145
+
146
+ def _tool_content(event: StreamEvent) -> list[dict[str, str]]:
147
+ detail = event.detail or event.content or ""
148
+ return [{"type": "text", "text": detail}] if detail else []
149
+
150
+
151
+ def _tool_call_id(event: StreamEvent) -> str:
152
+ metadata = event.metadata or {}
153
+ args = metadata.get("args") if isinstance(metadata, dict) else None
154
+ explicit = metadata.get("tool_call_id") or metadata.get("id") if isinstance(metadata, dict) else None
155
+ if explicit:
156
+ return str(explicit)
157
+ if isinstance(args, dict) and args.get("tool_call_id"):
158
+ return str(args["tool_call_id"])
159
+ return f"{event.session_id}:{event.tool_name or event.event_type}"
160
+
161
+
162
+ def _tool_status(status: str | None, *, default: str) -> str:
163
+ if status in {"failed", "denied", "cancelled"}:
164
+ return "failed" if status == "denied" else status
165
+ if status in {"completed", "running", "in_progress", "waiting_for_user"}:
166
+ return "in_progress" if status == "running" else status
167
+ return default
168
+
169
+
170
def _tool_kind(tool_name: str | None) -> str:
    """Classify a tool name into an ACP tool-call kind.

    Checked in order: read tools, search tools (including any ``lsp_*``
    tool), edit tools, execute tools, then ``todo``/``subagent`` as
    ``"think"``. Anything else, including a missing name, is ``"other"``.
    """
    name = tool_name or ""
    if name in READ_TOOLS:
        return "read"
    if name.startswith("lsp_") or name in SEARCH_TOOLS:
        return "search"
    remaining = (
        (EDIT_TOOLS, "edit"),
        (EXECUTE_TOOLS, "execute"),
        (("todo", "subagent"), "think"),
    )
    for group, label in remaining:
        if name in group:
            return label
    return "other"
183
+
184
+
185
+ def _locations(paths: list[str] | None, workdir: Path | None) -> list[dict[str, Any]]:
186
+ locations = []
187
+ for path in paths or []:
188
+ location: dict[str, Any] = {"path": str(path)}
189
+ if workdir is not None and path and not str(path).startswith("/"):
190
+ location["path"] = str((workdir / path).resolve())
191
+ locations.append(location)
192
+ return locations
agent/app_paths.py ADDED
@@ -0,0 +1,25 @@
1
+ """Application-level path helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from pathlib import Path
7
+
8
+
9
def resolve_app_root(raw_app_root: str | Path | None = None) -> Path:
    """Return the yoyoagent application root as an absolute, resolved path.

    Precedence: a truthy *raw_app_root*, then the ``YOYO_APP_ROOT``
    environment variable, then the parent of this module's directory.
    ``~`` is expanded for the first two sources.
    """
    candidate = raw_app_root if raw_app_root else os.environ.get("YOYO_APP_ROOT")
    if not candidate:
        return Path(__file__).resolve().parents[1]
    return Path(candidate).expanduser().resolve()
15
+
16
+
17
def resolve_runtime_data_dir(
    app_root: Path,
    raw_runtime_data_dir: str | Path | None = None,
) -> Path:
    """Return the directory for yoyoagent-owned runtime data.

    A truthy *raw_runtime_data_dir*, or else the ``YOYO_RUNTIME_DATA_DIR``
    environment variable, is expanded and resolved. Without either, the
    *app_root* object is returned unchanged.
    """
    override = raw_runtime_data_dir if raw_runtime_data_dir else os.environ.get("YOYO_RUNTIME_DATA_DIR")
    return Path(override).expanduser().resolve() if override else app_root
agent/approval.py ADDED
@@ -0,0 +1,169 @@
1
+ """Runtime approval helpers for high-risk tool calls."""
2
+
3
+ from dataclasses import dataclass
4
+ import re
5
+ from pathlib import Path
6
+ from typing import Awaitable, Callable, Literal, Optional
7
+
8
+ from tools.safety import approval_required
9
+ from tools.safety import unsafe_command_response
10
+ from tools.apply_patch import preview_apply_patch_diff
11
+ from tools.write_file import preview_write_file_diff
12
+
13
+
14
# Every outcome an approval prompt can produce.
ApprovalDecisionStatus = Literal["approved", "denied", "cancelled"]
# Legacy callback shape: receives a request and resolves to True when approved.
ApprovalCallback = Callable[["ApprovalRequest"], Awaitable[bool]]


class ApprovalDenied(Exception):
    """Signals that the user rejected a runtime approval request.

    The exception message is the formatted request; the request object
    itself is kept on ``.request``.
    """

    def __init__(self, request: "ApprovalRequest"):
        super().__init__(request.format())
        self.request = request


class ApprovalTargetMissing(Exception):
    """Signals that a file-modifying tool call names no detectable target file.

    The message is the model-facing correction text; the offending tool name
    is kept on ``.tool_name``.
    """

    def __init__(self, tool_name: str):
        message = missing_file_target_message(tool_name)
        super().__init__(message)
        self.tool_name = tool_name
32
+
33
+
34
@dataclass(frozen=True)
class ApprovalRequest:
    """Immutable description of a risky tool call that needs user approval.

    ``action`` and ``tool_name`` identify what would run, ``reason`` and
    ``risk`` explain why approval is needed, and the optional ``path``,
    ``command`` and ``diff_preview`` fields describe the concrete target.
    """

    action: str
    tool_name: str
    reason: str
    risk: str
    path: str = ""
    command: str = ""
    diff_preview: str = ""

    def format(self, include_diff: bool = True) -> str:
        """Render the request through ``approval_required``.

        When *include_diff* is true and a diff preview exists, the preview is
        appended after a blank line under a ``diff_preview:`` heading.
        """
        rendered = approval_required(
            action=self.action,
            command=self.command,
            path=self.path,
            reason=self.reason,
            risk=self.risk,
        )
        wants_diff = include_diff and self.diff_preview
        return f"{rendered}\n\ndiff_preview:\n{self.diff_preview}" if wants_diff else rendered
58
+
59
+
60
@dataclass(frozen=True)
class ApprovalDecision:
    """Outcome of an approval prompt, independent of the UI that produced it.

    ``status`` is one of ``"approved"``, ``"denied"`` or ``"cancelled"``.
    """

    status: ApprovalDecisionStatus

    @property
    def approved(self) -> bool:
        """True only for an approved decision; denied and cancelled both count as not approved."""
        return "approved" == self.status


async def callback_from_decider(
    request: ApprovalRequest,
    decider: Callable[[ApprovalRequest], Awaitable[ApprovalDecision]],
) -> bool:
    """Run a decision-returning approval adapter and collapse the result to a bool.

    This lets decision-based adapters serve where the legacy boolean
    approval callback is expected.
    """
    decision = await decider(request)
    return decision.approved
78
+
79
+
80
def approval_request_for_tool(
    tool_name: str,
    args: dict,
    workdir: Path | str | None = None,
) -> Optional[ApprovalRequest]:
    """Build the approval request a tool call needs, if any.

    * ``bash``: a request only when ``unsafe_command_response`` flags the
      command; otherwise ``None``.
    * ``apply_patch``: always a request. The target is ``args["path"]`` or
      the file headers parsed from ``args["patch"]``, and a diff preview is
      attached.
    * ``write_file``: always a request, with a diff preview of the new file.
    * any other tool: ``None``.

    Raises:
        ApprovalTargetMissing: if ``apply_patch``/``write_file`` names no
            detectable target file.
    """
    if tool_name == "bash":
        command = args.get("command", "")
        if not unsafe_command_response(command):
            return None
        return ApprovalRequest(
            action="run_command",
            tool_name=tool_name,
            command=command,
            reason="bash command matches a high-risk command pattern.",
            risk="This operation may be destructive or affect files outside the intended task.",
        )

    if tool_name == "apply_patch":
        target = args.get("path") or _paths_from_unified_diff(args.get("patch", ""))
        if not target:
            raise ApprovalTargetMissing(tool_name)
        preview = preview_apply_patch_diff(
            patch=args.get("patch", ""),
            path=args.get("path", ""),
            old_text=args.get("old_text", ""),
            new_text=args.get("new_text", ""),
            workdir=workdir,
        )
        return ApprovalRequest(
            action="edit_file",
            tool_name=tool_name,
            path=target,
            reason="apply_patch edits workspace files.",
            risk="File edits can overwrite user work or introduce unintended code changes.",
            diff_preview=preview,
        )

    if tool_name != "write_file":
        return None

    new_path = args.get("path")
    if not new_path:
        raise ApprovalTargetMissing(tool_name)
    preview = preview_write_file_diff(
        new_path,
        args.get("content", ""),
        workdir=workdir,
    )
    return ApprovalRequest(
        action="create_file",
        tool_name=tool_name,
        path=new_path,
        reason="write_file creates a new workspace file.",
        risk="Creating files changes the workspace and may add unwanted artifacts.",
        diff_preview=preview,
    )
131
+
132
+
133
def missing_file_target_message(tool_name: str) -> str:
    """Build the model-facing hint returned when a file-writing tool names no target.

    The text lists the argument shapes from which the approval layer can
    identify the file being changed.
    """
    lines = [
        f"File edit blocked for {tool_name}: no target file was detected.",
        "",
        "Retry with an explicit target file using one of these formats:",
        "- apply_patch with path + old_text + new_text",
        "- apply_patch with a unified diff that includes ---/+++ file headers",
        "- apply_patch with Begin Patch lines such as *** Update File: path",
        "- write_file with a path for a brand-new file",
    ]
    return "\n".join(lines)
143
+
144
+
145
def approval_cache_key(request: ApprovalRequest) -> tuple[str, str, str]:
    """Key used to remember approvals within one agent run.

    Two requests share a cached approval when their action, tool name and
    target path all match; reason, risk, command and diff are ignored.
    """
    action, tool, target = request.action, request.tool_name, request.path
    return action, tool, target
148
+
149
+
150
+ def _paths_from_unified_diff(patch: str) -> str:
151
+ paths = []
152
+ for line in patch.splitlines():
153
+ path = None
154
+ begin_patch_match = re.match(r"\*\*\* (?:Add|Update|Delete) File: (.+)$", line)
155
+ if begin_patch_match:
156
+ path = begin_patch_match.group(1).strip()
157
+ elif line.startswith("*** Move to: "):
158
+ path = line[len("*** Move to: "):].strip()
159
+ if line.startswith("diff --git "):
160
+ match = re.match(r"diff --git a/(.+?) b/(.+)$", line)
161
+ if match:
162
+ path = match.group(2)
163
+ elif line.startswith("+++ "):
164
+ raw = line[4:].split("\t", 1)[0].strip()
165
+ if raw != "/dev/null":
166
+ path = raw[2:] if raw.startswith("b/") else raw
167
+ if path and path not in paths:
168
+ paths.append(path)
169
+ return ", ".join(paths)
agent/cancellation.py ADDED
@@ -0,0 +1,52 @@
1
+ """Shared cancellation controller for interactive runners."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import contextlib
7
+ from dataclasses import dataclass
8
+ from typing import Literal
9
+
10
+
11
CancelStatus = Literal["cancelled", "not_running", "already_finished"]


@dataclass(frozen=True)
class CancelResult:
    """Result of a cancellation attempt."""

    # "cancelled": a running task was cancelled and has finished;
    # "not_running": no task was registered;
    # "already_finished": the registered task had already completed.
    status: CancelStatus


class CancellationController:
    """Track and cancel one active asyncio task."""

    def __init__(self) -> None:
        self.current_task: asyncio.Task | None = None

    def set_task(self, task: asyncio.Task) -> None:
        """Register *task* as the active task, replacing any previous one."""
        self.current_task = task

    def clear_task(self, task: asyncio.Task | None = None) -> None:
        """Forget the active task.

        With no argument the slot is cleared unconditionally. With a task it is
        cleared only if that task is still the active one, so a finishing turn
        never clears a newer turn's task.
        """
        if task is None or self.current_task is task:
            self.current_task = None

    def is_running(self) -> bool:
        """Return whether a registered task exists and has not finished."""
        return self.current_task is not None and not self.current_task.done()

    async def cancel(self) -> CancelResult:
        """Cancel the active task, wait for it to finish, and report the outcome.

        Returns:
            ``CancelResult`` with status ``"not_running"``,
            ``"already_finished"`` or ``"cancelled"``.

        Raises:
            asyncio.CancelledError: if the caller of ``cancel()`` is itself
                cancelled while waiting. This is propagated, never swallowed.
            Exception: whatever the task raised, if it finished with an error
                other than cancellation. The task slot is cleared first.
        """
        task = self.current_task
        if task is None:
            return CancelResult("not_running")
        if task.done():
            self.current_task = None
            return CancelResult("already_finished")
        task.cancel()
        # asyncio.wait() waits for completion without re-raising the task's own
        # CancelledError, so a CancelledError escaping here can only be the
        # caller's cancellation, which must propagate.
        await asyncio.wait({task})
        # Clear only our task; one registered meanwhile stays tracked.
        self.clear_task(task)
        if not task.cancelled():
            error = task.exception()
            if error is not None:
                raise error
        return CancelResult("cancelled")