chcode 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chcode/__init__.py +0 -0
- chcode/__main__.py +5 -0
- chcode/agent_setup.py +395 -0
- chcode/agents/__init__.py +0 -0
- chcode/agents/definitions.py +158 -0
- chcode/agents/loader.py +104 -0
- chcode/agents/runner.py +159 -0
- chcode/chat.py +1630 -0
- chcode/cli.py +142 -0
- chcode/config.py +571 -0
- chcode/display.py +325 -0
- chcode/prompts.py +640 -0
- chcode/session.py +149 -0
- chcode/skill_manager.py +165 -0
- chcode/utils/__init__.py +3 -0
- chcode/utils/enhanced_chat_openai.py +368 -0
- chcode/utils/git_checker.py +38 -0
- chcode/utils/git_manager.py +261 -0
- chcode/utils/modelscope_ratelimit.py +65 -0
- chcode/utils/multimodal.py +268 -0
- chcode/utils/shell/__init__.py +17 -0
- chcode/utils/shell/output.py +63 -0
- chcode/utils/shell/provider.py +128 -0
- chcode/utils/shell/result.py +14 -0
- chcode/utils/shell/semantics.py +55 -0
- chcode/utils/shell/session.py +159 -0
- chcode/utils/skill_loader.py +565 -0
- chcode/utils/text_utils.py +14 -0
- chcode/utils/tool_result_pipeline.py +244 -0
- chcode/utils/tools.py +1724 -0
- chcode/vision_config.py +371 -0
- chcode-0.1.0.dist-info/METADATA +275 -0
- chcode-0.1.0.dist-info/RECORD +36 -0
- chcode-0.1.0.dist-info/WHEEL +4 -0
- chcode-0.1.0.dist-info/entry_points.txt +2 -0
- chcode-0.1.0.dist-info/licenses/LICENSE +21 -0
chcode/chat.py
ADDED
|
@@ -0,0 +1,1630 @@
|
|
|
1
|
+
"""
|
|
2
|
+
主聊天 REPL — 类 Claude Code 终端体验
|
|
3
|
+
|
|
4
|
+
prompt_toolkit 多行输入 + rich 流式输出 + 斜杠命令 + HITL 审批
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import json
|
|
11
|
+
import os
|
|
12
|
+
import re
|
|
13
|
+
import shutil
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
import openai
|
|
17
|
+
from rich.text import Text
|
|
18
|
+
from prompt_toolkit import PromptSession
|
|
19
|
+
from prompt_toolkit.completion import Completer, Completion
|
|
20
|
+
from prompt_toolkit.formatted_text import HTML
|
|
21
|
+
from prompt_toolkit.key_binding import KeyBindings
|
|
22
|
+
from prompt_toolkit.layout.dimension import Dimension
|
|
23
|
+
from prompt_toolkit.styles import Style
|
|
24
|
+
|
|
25
|
+
from langchain_core.messages import (
|
|
26
|
+
AIMessage,
|
|
27
|
+
AIMessageChunk,
|
|
28
|
+
ToolMessage,
|
|
29
|
+
RemoveMessage,
|
|
30
|
+
HumanMessage,
|
|
31
|
+
BaseMessage,
|
|
32
|
+
)
|
|
33
|
+
from chcode.utils import get_text_content
|
|
34
|
+
from langgraph.types import Command
|
|
35
|
+
|
|
36
|
+
import chcode.display as _display
|
|
37
|
+
from chcode.display import (
|
|
38
|
+
console,
|
|
39
|
+
render_error,
|
|
40
|
+
render_info,
|
|
41
|
+
render_success,
|
|
42
|
+
render_warning,
|
|
43
|
+
render_welcome,
|
|
44
|
+
render_conversation,
|
|
45
|
+
render_ai_start,
|
|
46
|
+
render_ai_chunk,
|
|
47
|
+
render_ai_end,
|
|
48
|
+
get_context_usage_text,
|
|
49
|
+
)
|
|
50
|
+
from chcode.prompts import select, confirm, select_or_custom, text, checkbox
|
|
51
|
+
from chcode.config import (
|
|
52
|
+
get_default_model_config,
|
|
53
|
+
load_workplace,
|
|
54
|
+
save_workplace,
|
|
55
|
+
configure_new_model,
|
|
56
|
+
first_run_configure,
|
|
57
|
+
edit_current_model,
|
|
58
|
+
switch_model,
|
|
59
|
+
ensure_config_dir,
|
|
60
|
+
get_context_window_size,
|
|
61
|
+
)
|
|
62
|
+
from chcode.session import SessionManager
|
|
63
|
+
from chcode.utils.skill_loader import SkillAgentContext, SkillLoader
|
|
64
|
+
from chcode.agent_setup import (
|
|
65
|
+
build_agent,
|
|
66
|
+
create_checkpointer,
|
|
67
|
+
INNER_MODEL_CONFIG,
|
|
68
|
+
reset_budget_state,
|
|
69
|
+
get_fallback_model,
|
|
70
|
+
set_fallback_models, # noqa: F401 # 重新导出供其他模块使用
|
|
71
|
+
advance_fallback,
|
|
72
|
+
ModelSwitchError,
|
|
73
|
+
)
|
|
74
|
+
from chcode.skill_manager import manage_skills
|
|
75
|
+
from chcode.utils.git_checker import check_git_availability
|
|
76
|
+
from chcode.utils.git_manager import GitManager
|
|
77
|
+
from chcode.utils.modelscope_ratelimit import get_ratelimit, is_modelscope_model
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# ─── 命令自动补全 ──────────────────────────────────────
|
|
81
|
+
|
|
82
|
+
# Slash commands offered by the completer, mapped to the description shown
# next to them in the drop-down. Insertion order is the order they appear.
SLASH_COMMANDS: dict[str, str] = dict(
    [
        ("/new", "新会话"),
        ("/history", "历史会话"),
        ("/model", "模型管理(新建/编辑/切换)"),
        ("/vision", "视觉模型配置"),
        ("/messages", "管理历史消息(编辑/分叉/删除)"),
        ("/compress", "压缩会话"),
        ("/skill", "技能管理"),
        ("/search", "配置 Tavily 搜索 API Key"),
        ("/workdir", "切换工作目录"),
        ("/mode", "切换 Common/Yolo 模式"),
        ("/git", "Git 状态"),
        ("/langsmith", "LangSmith 追踪开关"),
        ("/tools", "显示内置工具"),
        ("/help", "显示帮助"),
        ("/quit", "退出"),
    ]
)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class SlashCommandCompleter(Completer):
    """Complete slash commands as soon as the input starts with "/".

    Everything typed before the cursor is lower-cased and prefix-matched
    against the keys of ``SLASH_COMMANDS``. Each match replaces the typed
    text with the full command and shows the command's description as the
    completion's meta text.
    """

    def get_completions(self, document, complete_event):
        typed = document.text_before_cursor
        # Only the "/command" form is completed; ordinary chat text is not.
        if not typed.startswith("/"):
            return
        prefix = typed.lower()
        # A negative start position replaces the whole typed prefix.
        offset = -len(prefix)
        for command, description in SLASH_COMMANDS.items():
            if not command.startswith(prefix):
                continue
            yield Completion(
                command,
                start_position=offset,
                display=command,
                display_meta=description,
            )
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# ─── 辅助函数 ──────────────────────────────────────────
|
|
129
|
+
|
|
130
|
+
# 简易的 BBCode 风格标记语言解析 (论坛或聊天软件)
|
|
131
|
+
_RE_TAG_SPLIT = re.compile(r"(\[/?[^\]]+\])")
|
|
132
|
+
_RE_TAG_OPEN = re.compile(r"^\[([^\]]+)\]$")
|
|
133
|
+
_RE_TAG_CLOSE = re.compile(r"^\[/([^\]]*)\]$")
|
|
134
|
+
|
|
135
|
+
_RICH_TAG_MAP = {
|
|
136
|
+
"bold": "b",
|
|
137
|
+
"italic": "i",
|
|
138
|
+
"red": "fg:red",
|
|
139
|
+
"green": "fg:green",
|
|
140
|
+
"yellow": "fg:yellow",
|
|
141
|
+
"blue": "fg:blue",
|
|
142
|
+
"dim": "fg:#888888",
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _rich_to_html(text: str) -> str:
|
|
147
|
+
parts = _RE_TAG_SPLIT.split(text)
|
|
148
|
+
opened: list[str] = []
|
|
149
|
+
result: list[str] = []
|
|
150
|
+
|
|
151
|
+
for part in parts:
|
|
152
|
+
close_m = _RE_TAG_CLOSE.match(part)
|
|
153
|
+
open_m = _RE_TAG_OPEN.match(part) if not close_m else None
|
|
154
|
+
if close_m:
|
|
155
|
+
while opened:
|
|
156
|
+
tag = opened.pop()
|
|
157
|
+
result.append(f"</{tag}>")
|
|
158
|
+
elif open_m:
|
|
159
|
+
tags = open_m.group(1).split()
|
|
160
|
+
for t in tags:
|
|
161
|
+
mapped = _RICH_TAG_MAP.get(t)
|
|
162
|
+
if mapped:
|
|
163
|
+
if mapped.startswith("fg:"):
|
|
164
|
+
result.append(f'<style fg="{mapped[3:]}">')
|
|
165
|
+
opened.append("style")
|
|
166
|
+
else:
|
|
167
|
+
result.append(f"<{mapped}>")
|
|
168
|
+
opened.append(mapped)
|
|
169
|
+
else:
|
|
170
|
+
result.append(part)
|
|
171
|
+
|
|
172
|
+
return "".join(result)
|
|
173
|
+
|
|
174
|
+
# Return the tail of a message list that starts at its last message of a type.
def find_and_slice_from_end(lst, x):
    """Return ``lst`` from its last element whose ``.type == x`` to the end.

    The search runs from the back, so the slice begins at the *last* match.
    An empty list is returned when no element has that type.
    """
    last_match = next(
        (idx for idx in reversed(range(len(lst))) if lst[idx].type == x),
        None,
    )
    if last_match is None:
        return []
    return lst[last_match:]
|
|
181
|
+
|
|
182
|
+
# 消息分组
|
|
183
|
+
def _group_messages_by_turn(messages: list) -> list[list]:
|
|
184
|
+
"""
|
|
185
|
+
将消息按轮次分组(参考 chagent 逻辑)
|
|
186
|
+
从一个 HumanMessage 开始,到下一个 HumanMessage 之前为一组
|
|
187
|
+
"""
|
|
188
|
+
groups = []
|
|
189
|
+
current_group = []
|
|
190
|
+
|
|
191
|
+
for msg in messages:
|
|
192
|
+
if msg.type == "human": # 下一组消息的第一个消息:HumanMessage
|
|
193
|
+
if current_group: # 当前消息组
|
|
194
|
+
groups.append(current_group)
|
|
195
|
+
current_group = [msg] # 把下一组消息的第一个消息:HumanMessage,放入新的消息组
|
|
196
|
+
else:
|
|
197
|
+
current_group.append(msg) # 把下一组消息的其余消息也放入新的消息组
|
|
198
|
+
|
|
199
|
+
if current_group: # 所有消息都遍历完 还没放入消息组
|
|
200
|
+
groups.append(current_group) # 所以需要放入消息组
|
|
201
|
+
|
|
202
|
+
return groups
|
|
203
|
+
|
|
204
|
+
# One-line label for a conversation turn, used in selection lists.
def _get_group_display(group: list) -> str:
    """Return a short label for a turn taken from its first human message.

    The message text is cut to its first 60 characters, with newlines turned
    into spaces, and "..." is appended when the full text was longer.
    Returns "(空消息组)" when the group has no human message.
    """
    first_human = next((m for m in group if m.type == "human"), None)
    if first_human is None:
        return "(空消息组)"
    full_text = get_text_content(first_human.content)
    suffix = "..." if len(full_text) > 60 else ""
    return full_text[:60].replace("\n", " ") + suffix
|
|
215
|
+
|
|
216
|
+
# 收集即将被压缩的消息的消息id组
|
|
217
|
+
def _collect_ids_from_group(
|
|
218
|
+
group_index: int, groups: list, mode: str = "edit"
|
|
219
|
+
) -> tuple[list[str], list[str]]:
|
|
220
|
+
all_ids = [m.id for group in groups for m in group]
|
|
221
|
+
no_need_ids = []
|
|
222
|
+
for i, group in enumerate(groups):
|
|
223
|
+
if i >= group_index:
|
|
224
|
+
no_need_ids.extend([m.id for m in group])
|
|
225
|
+
return no_need_ids, all_ids
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# ─── 主聊天类 ──────────────────────────────────────────
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class ChatREPL:
|
|
232
|
+
def __init__(self):
|
|
233
|
+
self.workplace_path: Path | None = None # 工作目录路径
|
|
234
|
+
self.model_config: dict = {} # 模型参数
|
|
235
|
+
self.yolo = False # Yolo模式
|
|
236
|
+
self.agent = None # agent实例
|
|
237
|
+
self.checkpointer = None # 检查点实例
|
|
238
|
+
self.session_mgr: SessionManager | None = None # 会话管理器
|
|
239
|
+
self.git_manager: GitManager | None = None # git管理器
|
|
240
|
+
self.git = False # git是否激活
|
|
241
|
+
self._git_cp_count = 0 # git提交数
|
|
242
|
+
self._stop_requested = False # 暂停agent的flag
|
|
243
|
+
self._processing = False
|
|
244
|
+
# 初始化 prompt-toolkit 会话(用于命令自动补全)
|
|
245
|
+
self._prompt_session = None
|
|
246
|
+
# 编辑缓冲区(用于 /edit 命令)
|
|
247
|
+
self._edit_buffer: str | None = None
|
|
248
|
+
# 中断恢复缓冲区(中断时将内容填回输入框,不进入编辑模式)
|
|
249
|
+
self._interrupt_buffer: str | None = None
|
|
250
|
+
# SkillLoader 复用,避免每条消息重建
|
|
251
|
+
self._skill_loader: SkillLoader | None = None
|
|
252
|
+
# 上下文用量缓存
|
|
253
|
+
self._context_text: str = ""
|
|
254
|
+
# Windows 保留名(不能作为文件名)
|
|
255
|
+
self.WINDOWS_RESERVED_NAMES = {
|
|
256
|
+
"nul",
|
|
257
|
+
"con",
|
|
258
|
+
"aux",
|
|
259
|
+
"prn",
|
|
260
|
+
"com1",
|
|
261
|
+
"com2",
|
|
262
|
+
"com3",
|
|
263
|
+
"com4",
|
|
264
|
+
"lpt1",
|
|
265
|
+
"lpt2",
|
|
266
|
+
"lpt3",
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
# ─── 清理 ────────────────────────────────────────
|
|
270
|
+
|
|
271
|
+
async def close(self) -> None:
    """Release held resources (the checkpointer's aiosqlite connection).

    Closing is best-effort. An ``Exception`` raised while closing the
    connection is ignored, and the checkpointer reference is dropped
    afterwards. Calling this again once the checkpointer is ``None`` does
    nothing.
    """
    checkpointer = self.checkpointer
    if checkpointer is None:
        return
    try:
        await checkpointer.conn.close()
    except Exception:
        # Shutdown must not fail because the connection could not be closed.
        pass
    self.checkpointer = None
|
|
279
|
+
|
|
280
|
+
# ─── 初始化 ────────────────────────────────────────
|
|
281
|
+
|
|
282
|
+
async def initialize(self) -> bool:
    """Prepare everything the REPL needs before the first prompt.

    Steps, in order:
    1. Ensure the config directory exists.
    2. Use the current directory as the workspace and create ``.chat/``,
       ``.chat/sessions/`` and ``.chat/skills/`` under it.
    3. Create the session manager.
    4. Load the default model config, falling back to the interactive
       first-run setup.
    5. Open the checkpointer at ``.chat/sessions/checkpointer.db``.
    6. Print the banner.
    7. Build the agent in a worker thread.
    8. Set up Git and load the command history.

    Returns:
        ``False`` if the first-run model setup was cancelled, else ``True``.
    """
    ensure_config_dir()

    cwd = Path.cwd()
    self.workplace_path = cwd
    sessions_dir = cwd / ".chat" / "sessions"
    for folder in (cwd / ".chat", sessions_dir, cwd / ".chat" / "skills"):
        folder.mkdir(exist_ok=True)

    self.session_mgr = SessionManager(cwd)

    self.model_config = get_default_model_config() or {}
    if not self.model_config:
        configured = await first_run_configure()
        if configured is None:
            return False
        self.model_config = configured

    self.checkpointer = await create_checkpointer(sessions_dir / "checkpointer.db")

    console.print(
        "[dim cyan]"
        " ███████╗ ██╗ ██╗ ███████╗ ██████╗ █████╗ ████████╗\n"
        "██╔═════╝ ██║ ██║ ██╔═════╝ ██╔═══██╗ ██╔══██╗ ██╔═════╝\n"
        "██║ ████████║ ██║ ██║ ██║ ██║ ██╗ ████████╗\n"
        "██║ ██╔═══██║ ██║ ██║ ██║ ██║ ██╔╝ ██╔═════╝\n"
        "████████╗ ██║ ██║ ████████╗ ╚██████╔╝ █████╔═╝ ████████╗\n"
        " ╚══════╝ ╚═╝ ╚═╝ ╚══════╝ ╚═════╝ ╚════╝ ╚══════╝[/dim cyan]"
    )
    # Building the agent can be slow, so it runs in a worker thread.
    self.agent = await asyncio.to_thread(
        build_agent, self.model_config, self.checkpointer, None, self.yolo
    )

    # Git setup (its availability check and `init` run in worker threads).
    await self._init_git()
    self._init_readline_history()
    return True
|
|
331
|
+
|
|
332
|
+
# 初始化命令历史
|
|
333
|
+
def _init_readline_history(self):
|
|
334
|
+
"""初始化 readline 历史(跨会话保存)"""
|
|
335
|
+
try:
|
|
336
|
+
import readline
|
|
337
|
+
|
|
338
|
+
history_path = Path.home() / ".chat" / "history"
|
|
339
|
+
history_path.parent.mkdir(exist_ok=True)
|
|
340
|
+
if history_path.exists():
|
|
341
|
+
readline.read_history_file(str(history_path))
|
|
342
|
+
readline.set_history_length(1000)
|
|
343
|
+
except ImportError:
|
|
344
|
+
pass
|
|
345
|
+
|
|
346
|
+
# 保存命令历史
|
|
347
|
+
def _save_readline_history(self):
|
|
348
|
+
"""保存 readline 历史"""
|
|
349
|
+
try:
|
|
350
|
+
import readline
|
|
351
|
+
|
|
352
|
+
history_path = Path.home() / ".chat" / "history"
|
|
353
|
+
history_path.parent.mkdir(exist_ok=True)
|
|
354
|
+
readline.write_history_file(str(history_path))
|
|
355
|
+
except ImportError:
|
|
356
|
+
pass
|
|
357
|
+
|
|
358
|
+
return True
|
|
359
|
+
|
|
360
|
+
async def _init_git(self) -> None:
    """Enable Git checkpoint support when Git is available.

    The availability check runs in a worker thread. If Git is available, a
    ``GitManager`` is created for the workspace and ``init`` is run (also
    in a thread) when the workspace is not yet a repository. Git is then
    marked active and the checkpoint count is cached for the status bar.
    Nothing happens when Git is unavailable.
    """
    git_ok, _status, _version = await asyncio.to_thread(check_git_availability)
    if not git_ok:
        return
    manager = GitManager(str(self.workplace_path))
    self.git_manager = manager
    if not manager.is_repo():
        await asyncio.to_thread(manager.init)
    self.git = True
    self._git_cp_count = manager.count_checkpoints()
|
|
369
|
+
|
|
370
|
+
# ─── 主循环 ────────────────────────────────────────
|
|
371
|
+
|
|
372
|
+
async def run(self) -> None:
    """Run the interactive loop until the user quits.

    Each iteration reads input and ignores blank lines. Lines starting with
    "/" go to ``_handle_command``; anything else goes to ``_process_input``,
    after which a warning is shown if LangSmith tracing was switched off
    during the turn.

    Ctrl+C while a turn is processing only sets ``_stop_requested``. Ctrl+C
    while idle, ``EOFError``, or ``None`` from ``_get_input`` end the loop.
    Any other exception is reported and the loop continues.
    """
    render_welcome()
    tracing_key = "LANGCHAIN_TRACING_V2"

    while True:
        try:
            raw = await self._get_input()
            if raw is None:
                break
            line = raw.strip()
            if not line:
                continue

            if line.startswith("/"):
                await self._handle_command(line)
                continue

            tracing_before = os.environ.get(tracing_key, "false").lower()
            await self._process_input(line)
            # Tracing may be turned off automatically during the turn
            # (e.g. after a 429 quota error).
            if (
                tracing_before == "true"
                and os.environ.get(tracing_key, "false").lower() != tracing_before
            ):
                render_warning(
                    "LangSmith 追踪已因配额耗尽自动关闭 "
                    "(/langsmith 可手动管理)"
                )
        except KeyboardInterrupt:
            if not self._processing:
                console.print(Text("\n再见!", style="dim"))
                break
            self._stop_requested = True
        except EOFError:
            break
        except Exception as e:
            render_error(f"Unexpected error: {e}")
|
|
412
|
+
|
|
413
|
+
async def _get_input(self) -> str | None:
    """Read one block of user input with prompt_toolkit.

    The ``PromptSession`` is built lazily on the first call and then reused.
    It provides:
    - slash-command completion;
    - Enter to submit and Ctrl+J to insert a newline;
    - Tab to toggle Common/YOLO mode on an empty buffer, and to trigger or
      cycle completion otherwise;
    - a bottom status toolbar;
    - an input area whose height follows its content.

    A pending ``_edit_buffer`` (preferred) or ``_interrupt_buffer`` is
    pre-filled as the default text and then cleared.

    Returns:
        The entered text, or ``None`` if the prompt ended with EOF or
        Ctrl+C.
    """
    import html
    import time

    # Whether text from an interrupted turn should be restored into the prompt.
    interrupt_mode = self._interrupt_buffer is not None

    if self._prompt_session is None:
        from prompt_toolkit.layout.containers import Window
        from prompt_toolkit.layout.controls import BufferControl

        completer = SlashCommandCompleter()

        # Custom keys: Enter submits, Ctrl+J inserts a newline.
        kb = KeyBindings()

        @kb.add("enter")
        def _submit(event):
            event.current_buffer.validate_and_handle()

        @kb.add("c-j")  # intended as the "Ctrl+Enter -> newline" key
        def _newline(event):
            event.current_buffer.insert_text("\n")

        @kb.add("tab")
        def _tab_toggle_mode(event):
            buff = event.current_buffer
            if buff.text:
                # This binding takes precedence over prompt_toolkit's own
                # Tab handling, so returning here would swallow the key.
                # Start (or cycle through) completion explicitly instead.
                if buff.complete_state:
                    buff.complete_next()
                else:
                    buff.start_completion(select_first=False)
                return
            self.yolo = not self.yolo
            from chcode.agent_setup import update_hitl_config

            update_hitl_config(self.yolo)
            event.app.renderer._last_rendered_width = 0  # force a toolbar redraw

        _last_width = 0
        _last_width_time = 0.0

        def _bottom_toolbar():
            nonlocal _last_width, _last_width_time
            # The toolbar is redrawn every refresh_interval (0.1 s); re-query
            # the terminal width at most about once per second.
            now = time.monotonic()
            if now - _last_width_time > 1.0:
                _last_width = shutil.get_terminal_size().columns
                _last_width_time = now
            width = _last_width or shutil.get_terminal_size().columns
            sep = "\u2500" * width

            # Plain values are XML-escaped: a model name or path containing
            # "<" or "&" would otherwise break prompt_toolkit's HTML parser.
            parts = [
                html.escape(str(self.model_config.get("model", "未设置")), quote=False)
            ]
            context_text = getattr(self, "_context_text", "")
            if context_text:
                parts.append(_rich_to_html(context_text))
            parts.append(
                "普通模式" if not self.yolo else "<ansired>YOLO 模式</ansired>"
            )
            if self.git and self.git_manager and self.git_manager.is_repo():
                parts.append(f"Git ({self._git_cp_count} cp)")
            if self.workplace_path:
                wp = html.escape(str(self.workplace_path), quote=False)
                parts.append(f"cwd: {wp}")
            status = " │ ".join(parts)

            ratelimit_line = ""
            if is_modelscope_model(self.model_config):
                rl = get_ratelimit()
                if rl:
                    total = f"{rl['total_remaining']}/{rl['total_limit']}"
                    model_name = html.escape(
                        str(self.model_config.get("model", "")).split("/")[-1],
                        quote=False,
                    )
                    model_rl = f"{rl['model_remaining']}/{rl['model_limit']}"
                    ratelimit_line = f"\n<ansicyan>魔搭今日免费额度剩余: 全局 {total} │ 模型({model_name}) {model_rl}</ansicyan>"
            return HTML(f"<ansiblue>{sep}</ansiblue>\n{status}{ratelimit_line}")

        self._prompt_session = PromptSession(
            multiline=True,
            key_bindings=kb,
            completer=completer,
            complete_while_typing=True,
            reserve_space_for_menu=0,
            bottom_toolbar=_bottom_toolbar,
            refresh_interval=0.1,
            style=Style.from_dict(
                {
                    "completion-menu.completion": "bg:#008888 #ffffff",
                    "completion-menu.completion.current": "bg:#00aaaa #000000",
                    "completion-menu.meta.completion": "bg:#008888 #ffffff",
                    "completion-menu.meta.completion.current": "bg:#00aaaa #000000",
                    "bottom-toolbar": "noreverse bg:#1a1a2e #aaaaaa",
                }
            ),
        )

        def _dynamic_buffer_height():
            # Leave room for the completion menu (at most 7 rows) while it
            # is open; otherwise fit the number of lines in the buffer.
            buff = self._prompt_session.default_buffer
            if buff.complete_state is not None:
                n = len(buff.complete_state.completions)
                needed = min(n + 2, 7)
                return Dimension(min=needed, max=needed)
            line_count = buff.text.count("\n") + 1
            return Dimension(min=line_count, max=line_count)

        def _find_buffer_window(container):
            # Depth-first search for the Window that displays a BufferControl.
            if isinstance(container, Window):
                if isinstance(getattr(container, "content", None), BufferControl):
                    return container
            for attr in ("content", "children", "alternative_content"):
                child = getattr(container, attr, None)
                if child is None:
                    continue
                children = child if isinstance(child, list) else [child]
                for c in children:
                    result = _find_buffer_window(c)
                    if result:
                        return result
            return None

        buffer_window = _find_buffer_window(
            self._prompt_session.app.layout.container
        )
        if buffer_window:
            buffer_window.height = _dynamic_buffer_height

    try:
        # Pre-fill the edit buffer or the interrupt-restore buffer, then clear it.
        if self._edit_buffer is not None:
            default_text = self._edit_buffer
            self._edit_buffer = None
        elif interrupt_mode:
            default_text = self._interrupt_buffer
            self._interrupt_buffer = None
        else:
            default_text = ""

        width = shutil.get_terminal_size().columns
        sep = "\u2500" * width
        prompt_text = f"{sep}\n > "

        # PromptSession.prompt is blocking, so it runs in a worker thread.
        result = await asyncio.to_thread(
            self._prompt_session.prompt,
            HTML(f"<ansiblue>{prompt_text}</ansiblue>"),
            default=default_text,
        )
        if result is not None:
            self._save_readline_history()
        return result
    except (EOFError, KeyboardInterrupt):
        return None
|
|
559
|
+
|
|
560
|
+
# ─── 斜杠命令 ──────────────────────────────────────
|
|
561
|
+
|
|
562
|
+
async def _handle_command(self, cmd: str) -> None:
|
|
563
|
+
"""处理斜杠命令"""
|
|
564
|
+
parts = cmd.strip().split(maxsplit=1)
|
|
565
|
+
command = parts[0].lower()
|
|
566
|
+
arg = parts[1] if len(parts) > 1 else ""
|
|
567
|
+
|
|
568
|
+
handlers = {
|
|
569
|
+
"/new": self._cmd_new,
|
|
570
|
+
"/model": self._cmd_model,
|
|
571
|
+
"/vision": self._cmd_vision,
|
|
572
|
+
"/skill": self._cmd_skill,
|
|
573
|
+
"/history": self._cmd_history,
|
|
574
|
+
"/compress": self._cmd_compress,
|
|
575
|
+
"/git": self._cmd_git,
|
|
576
|
+
"/search": self._cmd_search,
|
|
577
|
+
"/mode": self._cmd_mode,
|
|
578
|
+
"/workdir": self._cmd_workdir,
|
|
579
|
+
"/tools": self._cmd_tools,
|
|
580
|
+
"/langsmith": self._cmd_langsmith,
|
|
581
|
+
"/messages": self._cmd_messages,
|
|
582
|
+
"/help": self._cmd_help,
|
|
583
|
+
"/quit": self._cmd_quit,
|
|
584
|
+
}
|
|
585
|
+
|
|
586
|
+
handler = handlers.get(command)
|
|
587
|
+
if handler:
|
|
588
|
+
await handler(arg)
|
|
589
|
+
else:
|
|
590
|
+
render_warning(f"未知命令: {command},输入 /help 查看帮助")
|
|
591
|
+
|
|
592
|
+
async def _cmd_new(self, _arg: str) -> None:
    """Handle ``/new``: reset the budget state and start a fresh session.

    The argument is ignored. Prints a success message and redraws the
    status bar.
    """
    reset_budget_state()
    sessions = self.session_mgr
    sessions.new_session()
    render_success("新会话已开始")
    self._render_status_bar()
|
|
597
|
+
|
|
598
|
+
async def _cmd_model(self, arg: str) -> None:
    """Manage model configurations (``/model [new|edit|switch]``).

    A recognised sub-command runs directly. Matching ignores case and
    surrounding whitespace, consistent with how the command name is matched.
    Otherwise a menu is shown.

    When the chosen action returns a config, it becomes the active
    ``model_config``, the summarisation model is updated to it, and the
    status bar is redrawn.
    """
    actions = {
        "new": configure_new_model,
        "edit": edit_current_model,
        "switch": switch_model,
    }
    choice = arg.strip().lower()
    if choice not in actions:
        picked = await select(
            "模型管理:",
            [
                "新建模型 (/model new)",
                "编辑当前模型 (/model edit)",
                "切换模型 (/model switch)",
            ],
        )
        if picked is None:
            return
        if "新建" in picked:
            choice = "new"
        elif "编辑" in picked:
            choice = "edit"
        elif "切换" in picked:
            choice = "switch"
        else:
            return

    config = await actions[choice]()
    if not config:
        return
    self.model_config = config
    from chcode.agent_setup import update_summarization_model

    update_summarization_model(config)
    self._render_status_bar()
|
|
631
|
+
|
|
632
|
+
async def _cmd_langsmith(self, _arg: str) -> None:
    """Toggle LangSmith tracing (``/langsmith``) through ``LANGCHAIN_TRACING_V2``.

    Shows the current state and lets the user choose on or off. Only the
    current process's environment variable is changed here.
    """
    enabled_now = os.environ.get("LANGCHAIN_TRACING_V2", "false").lower() == "true"
    choice = await select(
        f"LangSmith 追踪: {'开启' if enabled_now else '关闭'}",
        ["开启追踪", "关闭追踪"],
    )
    if choice is None:
        return
    turn_on = "开启" in choice
    os.environ["LANGCHAIN_TRACING_V2"] = "true" if turn_on else "false"
    render_success(f"LangSmith 追踪已{'开启' if turn_on else '关闭'}")
|
|
643
|
+
|
|
644
|
+
async def _cmd_tools(self, _arg: str) -> None:
    """List the built-in tools (``/tools``) with the first line of each description.

    When the current model is multimodal, a note says images/videos are
    embedded directly and the ``vision`` tool is shown as disabled. Tool
    names and descriptions are markup-escaped, so square brackets in them
    print literally instead of being parsed (and dropped) as Rich tags.
    """
    from chcode.utils.tools import ALL_TOOLS
    from chcode.utils.multimodal import is_multimodal_model
    from rich.markup import escape

    current_model = (self.model_config or {}).get("model", "")
    native_vision = is_multimodal_model(current_model)

    console.print("[bold]内置工具[/bold]")
    console.print()
    if native_vision:
        console.print("[dim]当前模型支持原生视觉,图片/视频将直接嵌入消息[/dim]")
        console.print()
    for t in ALL_TOOLS:
        name = t.name
        desc = t.description.split("\n")[0] if t.description else ""
        # Pad before escaping so the visible column width stays 16.
        label = escape(f"{name:<16}")
        desc = escape(desc)
        if name == "vision" and native_vision:
            console.print(f" [dim]{label}[/dim] {desc} (已禁用)")
        else:
            console.print(f" [cyan]{label}[/cyan] {desc}")
    console.print()
|
|
664
|
+
|
|
665
|
+
async def _cmd_skill(self, _arg: str) -> None:
    """Open the interactive skill manager (``/skill``) for the current session manager.

    Shows an error instead when no session manager has been set up yet.
    """
    manager = self.session_mgr
    if manager:
        await manage_skills(manager)
        return
    render_error("请先初始化工作目录")
|
|
670
|
+
|
|
671
|
+
async def _cmd_history(self, _arg: str) -> None:
    """Browse saved sessions (``/history``) and load, rename or delete one.

    Does nothing until the session manager, checkpointer and agent exist.
    At most the last 50 sessions from ``list_sessions`` are offered. Each
    label is the display name plus the thread id, or just the id when no
    distinct name exists. Deleting the currently active session starts a
    new one.
    """
    mgr = self.session_mgr
    if not (mgr and self.checkpointer and self.agent):
        return

    sessions = await mgr.list_sessions(self.checkpointer)
    if not sessions:
        render_warning("没有历史会话")
        return
    sessions = sessions[-50:]

    display_names = await mgr.get_display_names(sessions, self.agent)
    label_to_tid: dict[str, str] = {}
    labels: list[str] = []
    for tid in sessions:
        shown = display_names.get(tid, tid)
        label = f"{shown} ({tid})" if shown != tid else shown
        label_to_tid[label] = tid
        labels.append(label)

    choice = await select("选择历史会话:", labels + ["返回"])
    if choice is None or choice == "返回":
        return
    selected_tid = label_to_tid[choice]

    op = await select("操作:", ["加载此会话", "重命名此会话", "删除此会话", "返回"])
    if op == "加载此会话":
        mgr.set_thread(selected_tid)
        await self._load_conversation()
        self._render_status_bar()
    elif op == "重命名此会话":
        current_name = mgr._load_names().get(selected_tid, "")
        new_name = await text("输入新名称(留空恢复默认):", default=current_name)
        if new_name is not None:
            mgr.rename_session(selected_tid, new_name)
            render_success("会话已重命名")
    elif op == "删除此会话":
        if await confirm(f"确定删除会话 {selected_tid}?", default=False):
            await mgr.delete_session(selected_tid, self.checkpointer)
            render_success("会话已删除")
            # The active session was removed: switch to a fresh one.
            if selected_tid == mgr.thread_id:
                await self._cmd_new("")
|
|
715
|
+
|
|
716
|
+
async def _cmd_compress(self, _arg: str) -> None:
    """Compress the current conversation into a summary (``/compress``).

    The last two turns (everything from the second-to-last human message
    onwards) are kept as they are. Every earlier message is flagged
    ``composed`` in ``additional_kwargs`` and sent, with image/video content
    blocks removed, to the configured model. The request carries a hidden
    instruction asking for ``{"summary": ...}`` JSON.

    The instruction, plus an AI message holding either the summary or an
    error text, is then written back to the agent state as node ``"model"``,
    and the conversation is reloaded. Nothing is sent to the model when
    there are no messages older than the last two turns.
    """
    if not self.model_config:
        render_warning("请先配置模型")
        return

    ok = await confirm("确定压缩当前会话?", default=True)
    if not ok:
        return

    def _parse_summary(content: str) -> str:
        """Extract the ``summary`` value from the model's JSON reply.

        Strips a surrounding markdown code fence, then locates the JSON
        object that contains ``"summary"`` (the model may print reasoning
        text before it). Non-string summaries are serialised to JSON text.
        Raises ``ValueError`` (including ``json.JSONDecodeError``) when no
        usable summary is found.
        """
        content = content.strip()
        if content.startswith("```"):
            content = re.sub(r"^```(?:json)?\s*\n?", "", content)
            content = re.sub(r"\n?```\s*$", "", content)
        json_match = re.search(r'\{[^{}]*"summary"[^{}]*\}', content)
        if json_match:
            content = json_match.group()
        else:
            # The summary may contain nested objects: fall back to brace
            # matching. NOTE: braces inside JSON strings are not handled.
            depth = 0
            start = -1
            for i, ch in enumerate(content):
                if ch == "{":
                    if depth == 0:
                        start = i
                    depth += 1
                elif ch == "}":
                    depth -= 1
                    if depth == 0 and start >= 0:
                        candidate = content[start : i + 1]
                        if '"summary"' in candidate:
                            content = candidate
                            break
        data = json.loads(content)
        summary = data.get("summary", "")
        if summary and not isinstance(summary, str):
            summary = json.dumps(summary, ensure_ascii=False)
        if not summary:
            raise ValueError("LLM 返回结果缺少 summary 字段")
        return summary

    render_info("压缩中...")
    try:
        state = await self.agent.aget_state(self.session_mgr.config)
        # A brand-new thread has no "messages" key yet.
        messages: list[BaseMessage] = state.values.get("messages") or []

        # Keep the tail that starts at the second-to-last human message.
        recent_messages: list[BaseMessage] = []
        humans_seen = 0
        for msg in reversed(messages):
            recent_messages.append(msg)
            if isinstance(msg, HumanMessage):
                humans_seen += 1
                if humans_seen == 2:
                    break
        recent_messages.reverse()  # back to chronological order
        recent_ids = {msg.id for msg in recent_messages}

        pre_messages: list[BaseMessage] = []
        for msg in messages:
            if msg.id in recent_ids:
                continue
            msg.additional_kwargs["composed"] = True
            # Drop base64 images/videos: an oversized payload can make the
            # API return empty choices.
            if isinstance(msg.content, list):
                clean_blocks = [
                    b
                    for b in msg.content
                    if not isinstance(b, dict)
                    or b.get("type") not in ("image_url", "video_url")
                ]
                if clean_blocks != msg.content:
                    msg = msg.model_copy(update={"content": clean_blocks})
            pre_messages.append(msg)

        if not pre_messages:
            render_warning("没有可压缩的历史消息")
            return

        from chcode.utils.enhanced_chat_openai import EnhancedChatOpenAI

        model = EnhancedChatOpenAI(**self.model_config)

        human_msg = HumanMessage(
            content='以你的角度用第二人称压缩会话,严格按以下JSON格式输出,不要使用markdown代码块:\n{"summary": "压缩内容"}',
            additional_kwargs={"hide": True, "composed": True},
        )

        try:
            raw_resp = await asyncio.to_thread(
                model.invoke, pre_messages + [human_msg]
            )
            content = raw_resp.content
            # Some providers return a list of content blocks instead of a str.
            if not isinstance(content, str):
                content = get_text_content(content)
            ai_content = _parse_summary(content)
            failed = False
        except Exception as e:
            ai_content = f"会话压缩失败: {e}"
            failed = True

        if failed:
            ai_message = AIMessage(
                ai_content,
                additional_kwargs={"error": True, "composed": True},
                usage_metadata={
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0,
                },
            )
        else:
            ai_message = AIMessage(
                f"历史对话已压缩: {ai_content}",
                additional_kwargs={"hide": True},
            )

        await self.agent.aupdate_state(
            self.session_mgr.config,
            {"messages": pre_messages + [human_msg, ai_message] + recent_messages},
            as_node="model",
        )
        await self._load_conversation()
        if failed:
            render_error(ai_content)
        else:
            render_success("会话压缩完成")
    except Exception as e:
        render_error(f"压缩失败: {e}")
|
|
833
|
+
|
|
834
|
+
async def _cmd_git(self, _arg: str) -> None:
    """Handle ``/git``: report Git availability and checkpoint status.

    When no ``git_manager`` exists yet, probe the local Git installation with
    ``check_git_availability`` and, if usable, initialise it via ``_init_git``.
    Otherwise report whether the workplace is a repository and how many
    checkpoints it holds, caching the count in ``self._git_cp_count``.
    """
    if not self.git_manager:
        is_available, status, version = await asyncio.to_thread(
            check_git_availability
        )
        if is_available:
            render_success(f"Git {version}")
            await self._init_git()
        else:
            render_error(f"Git 不可用: {status}")
        return

    # Other git_manager calls in this class (rollback, add_commit) are run via
    # asyncio.to_thread as blocking work; do the same here so the prompt is not
    # stalled while Git is queried.
    is_repo = await asyncio.to_thread(self.git_manager.is_repo)
    if is_repo:
        count = await asyncio.to_thread(self.git_manager.count_checkpoints)
        self._git_cp_count = count
        render_success(f"Git 仓库已初始化 ({count} 个检查点)")
    else:
        render_warning("Git 仓库未初始化")
|
|
852
|
+
|
|
853
|
+
async def _cmd_vision(self, _arg: str) -> None:  # pragma: no cover
    """Handle ``/vision``: open the interactive vision-model configuration."""  # pragma: no cover
    # Deferred import: the vision configuration module is resolved on first use.
    from chcode import vision_config  # pragma: no cover

    await vision_config.configure_vision_interactive()  # pragma: no cover
|
|
857
|
+
|
|
858
|
+
async def _cmd_search(self, _arg: str) -> None:
    """Handle ``/search``: show, configure, or clear the Tavily API key.

    The stored key is displayed masked. "配置 API Key" prompts for a new key
    and "清除 API Key" stores an empty one. Either way the value is persisted
    with ``save_tavily_api_key`` and pushed to the tools via
    ``update_tavily_api_key``. "返回" or cancelling leaves everything as is.
    """
    from chcode.config import load_tavily_api_key, save_tavily_api_key
    from chcode.utils.tools import update_tavily_api_key

    current = load_tavily_api_key()
    if not current:
        masked = "未配置"
    elif len(current) > 10:
        masked = f"{current[:6]}...{current[-4:]}"
    else:
        # Too short for a prefix/suffix preview without exposing most of the
        # secret, so hide it entirely.
        masked = "******"
    render_info(f"当前 Tavily API Key: {masked}")

    action = await select("操作:", ["配置 API Key", "清除 API Key", "返回"])
    if action is None or action == "返回":
        return

    if action == "清除 API Key":
        save_tavily_api_key("")
        update_tavily_api_key("")
        render_success("Tavily API Key 已清除")
        return

    new_key = await text("请输入 Tavily API Key:")
    # Pasted keys often carry a trailing space/newline. Stored verbatim, such a
    # key would presumably be rejected by the API, so trim it first.
    new_key = new_key.strip() if new_key else new_key
    if new_key:
        save_tavily_api_key(new_key)
        update_tavily_api_key(new_key)
        render_success("Tavily API Key 已保存")
    else:
        render_warning("未输入,已取消")
|
|
887
|
+
|
|
888
|
+
async def _cmd_mode(self, _arg: str) -> None:
    """Handle ``/mode``: switch between Common and Yolo approval modes.

    Common mode asks the user to approve risky tool calls; Yolo mode approves
    them all automatically. The choice is stored on ``self.yolo`` and passed
    to ``update_hitl_config``. Cancelling the menu changes nothing.
    """
    mode_options = ["Common (手动批准风险操作)", "Yolo (自动批准所有操作)"]
    picked = await select("选择模式:", mode_options)
    if picked is None:
        return

    self.yolo = "Yolo" in picked
    from chcode.agent_setup import update_hitl_config

    update_hitl_config(self.yolo)
    render_success(f"已切换到 {'Yolo' if self.yolo else 'Common'} 模式")
|
|
901
|
+
|
|
902
|
+
async def _cmd_workdir(self, _arg: str) -> None:
    """Handle ``/workdir``: switch the working directory and rebuild the session.

    The user picks the saved workplace or types a custom path. On success the
    process ``chdir``s there and the path is persisted with ``save_workplace``.
    The ``.chat`` scaffolding is created, and the checkpointer, session
    manager, agent and Git integration are re-created for the new directory.
    Invalid input is reported and leaves all state untouched.
    """
    saved = load_workplace()
    choices = [str(saved)] if saved else []

    result = await select_or_custom(
        "选择工作目录:",
        choices,
        custom_label="自定义路径...",
        custom_prompt="请输入工作目录路径: ",
    )
    if not result:
        return

    # Make the path absolute *before* chdir. A relative path kept as-is would
    # be re-interpreted against the new cwd when the .chat paths are built
    # below (e.g. "proj" -> "proj/proj/.chat", whose mkdir then fails).
    new_path = Path(result).expanduser().resolve()
    if not new_path.exists():
        render_error("路径不存在")
        return
    if not new_path.is_dir():
        # Reject files before mutating any state; os.chdir would otherwise fail
        # after workplace_path had already been switched.
        render_error("路径不是目录")
        return

    self.workplace_path = new_path
    self._skill_loader = None  # skill search paths depend on the workplace
    os.chdir(self.workplace_path)
    save_workplace(self.workplace_path)

    # Recreate the per-workplace .chat scaffolding.
    chat_dir = self.workplace_path / ".chat"
    chat_dir.mkdir(exist_ok=True)
    (chat_dir / "sessions").mkdir(exist_ok=True)
    (chat_dir / "skills").mkdir(exist_ok=True)

    # Close the previous checkpointer connection (best-effort).
    if self.checkpointer is not None:
        try:
            await self.checkpointer.conn.close()
        except Exception:
            pass
        self.checkpointer = None

    # Rebuild session manager, checkpointer and agent for the new directory.
    self.session_mgr = SessionManager(self.workplace_path)
    db_path = self.workplace_path / ".chat" / "sessions" / "checkpointer.db"
    self.checkpointer = await create_checkpointer(db_path)
    self.agent = await asyncio.to_thread(
        build_agent,
        self.model_config,
        self.checkpointer,
        None,
        self.yolo,
    )

    await self._init_git()
    render_success(f"工作目录: {self.workplace_path}")
    self._render_status_bar()
|
|
954
|
+
|
|
955
|
+
async def _cmd_help(self, _arg: str) -> None:
    """Handle ``/help``: print a table listing every slash command."""
    from rich.table import Table

    command_rows = (
        ("/new", "新会话"),
        ("/history", "历史会话"),
        ("/model", "模型管理(新建/编辑/切换)"),
        ("/vision", "视觉模型配置"),
        ("/messages", "管理历史消息(编辑/分叉/删除)"),
        ("/compress", "压缩会话"),
        ("/skill", "技能管理"),
        ("/search", "配置 Tavily 搜索 API Key"),
        ("/workdir", "切换工作目录"),
        ("/mode", "切换 Common/Yolo 模式"),
        ("/git", "Git 状态"),
        ("/langsmith", "LangSmith 追踪开关"),
        ("/tools", "显示内置工具"),
        ("/help", "显示此帮助"),
        ("/quit", "退出"),
    )

    help_table = Table(title="命令列表")
    help_table.add_column("命令", style="cyan")
    help_table.add_column("说明")
    for command_name, description in command_rows:
        help_table.add_row(command_name, description)
    console.print(help_table)
|
|
981
|
+
|
|
982
|
+
async def _cmd_quit(self, _arg: str) -> None:
    """Handle ``/quit`` by raising ``EOFError``.

    Presumably the REPL loop treats this like end-of-input (Ctrl-D) and
    exits; confirm against the caller.
    """
    raise EOFError
|
|
984
|
+
|
|
985
|
+
# ─── 消息管理命令 ──────────────────────────────────
|
|
986
|
+
|
|
987
|
+
async def _cmd_messages(self, _arg: str) -> None:
    """Handle ``/messages``: edit, fork, or delete earlier conversation turns.

    The state's messages are grouped per turn with ``_group_messages_by_turn``
    and listed as ``"[N] <summary>"`` options. After choosing an action:

    * delete - pick several groups; all of their messages are removed.
    * edit   - pick one group. Git is rolled back (when enabled) and the ids
      from ``_collect_ids_from_group(..., mode="edit")`` are deleted (per the
      prompt, everything after this group). The group's human text is then
      loaded into ``self._edit_buffer`` for re-sending.
    * fork   - pick one group and a target directory. The workspace (plus
      ``.git``) is copied there, and a new session is started holding the
      turns up to and including the chosen group. Git is then rolled back in
      the new directory.

    Picking "返回" or declining a confirmation goes back to the action menu.
    """
    if not self.agent or not self.session_mgr:
        render_error("Agent 未初始化")
        return

    state = await self.agent.aget_state(self.session_mgr.config)
    messages: list[BaseMessage] = state.values.get("messages", [])

    groups = _group_messages_by_turn(messages)
    if not groups:
        render_warning("没有可管理的消息")
        return

    # Option labels ("[N] <summary>", 1-based) do not depend on the chosen
    # action, and groups is never mutated in the loop: build them once.
    options = [
        f"[{idx + 1}] {_get_group_display(group)}"
        for idx, group in enumerate(groups)
    ]

    def parse_index(option: str) -> int | None:
        """Map an ``"[N] ..."`` label to a 0-based group index, or None."""
        try:
            idx = int(option.split("]")[0].replace("[", "")) - 1
        except (ValueError, IndexError):
            return None
        return idx if 0 <= idx < len(groups) else None

    while True:
        # Step 1: choose the operation.
        action = await select("选择操作:", ["编辑消息", "分叉消息", "删除消息"])
        if not action:
            return

        if action == "删除消息":
            # Multi-select.
            chosen_list = await checkbox(
                "选择要删除的消息组(空格选择,回车确认):", options
            )
            if not chosen_list:
                continue  # back to the action menu

            ok = await confirm(
                f"确定删除 {len(chosen_list)} 个消息组?", default=False
            )
            if not ok:
                continue

            delete_ids = []
            for chosen in chosen_list:
                sel_idx = parse_index(chosen)
                if sel_idx is not None:
                    delete_ids.extend(m.id for m in groups[sel_idx])

            if not delete_ids:
                render_error("没有有效的选择")
                continue

            await self._delete_messages(delete_ids)
            render_success(f"已删除 {len(chosen_list)} 个消息组")
            return

        # Edit / fork: single-select one message group.
        if action == "编辑消息":
            hint = "选择要编辑的消息组(编辑后将删除此消息组之后的所有内容):"
        else:
            hint = "选择 Fork 点(此消息组将保留在分支中):"

        chosen = await select(hint, options + ["返回"])
        if not chosen:
            return
        if chosen == "返回":
            continue

        sel_idx = parse_index(chosen)
        if sel_idx is None:
            render_error("无效的选择")
            continue

        if action == "编辑消息":
            edit_msg = next(
                (msg for msg in groups[sel_idx] if msg.type == "human"), None
            )
            if edit_msg is None:
                render_warning("该组没有 HumanMessage")
                continue

            ok = await confirm(
                "确定编辑此消息组?编辑后将删除此消息组之后的所有内容。",
                default=False,
            )
            if not ok:
                continue

            no_need_ids, all_ids = _collect_ids_from_group(
                sel_idx, groups, mode="edit"
            )

            if self.git and self.git_manager:
                try:
                    await asyncio.to_thread(
                        self.git_manager.rollback, no_need_ids, all_ids
                    )
                except Exception as e:
                    render_warning(f"Git 回滚失败: {e}")

            await self._delete_messages(no_need_ids)

            self._edit_buffer = get_text_content(edit_msg.content)
            render_success("消息已加载到输入框,修改后发送即可重新生成")
            return

        elif action == "分叉消息":
            ok = await confirm(
                f"确定从第 {sel_idx + 1} 条消息组创建分支?", default=True
            )
            if not ok:
                continue

            no_need_ids, all_ids = _collect_ids_from_group(
                sel_idx, groups, mode="fork"
            )

            saved = load_workplace()
            if saved:
                choices = [str(saved), "自定义路径..."]
            else:
                choices = ["自定义路径..."]

            new_path_str = await select_or_custom("选择新工作目录:", choices)
            if not new_path_str:
                continue

            # Resolve before chdir so a relative path is interpreted against the
            # current directory, not re-resolved under the new one.
            new_path = Path(new_path_str).expanduser().resolve()
            if not new_path.exists():
                render_error("路径不存在")
                continue
            if not new_path.is_dir():
                render_error("路径不是目录")
                continue

            old_path = self.workplace_path
            # Compare resolved paths. An unresolved comparison would treat the
            # same directory as different and copy it onto itself.
            old_resolved = Path(old_path).resolve() if old_path else None
            same_dir = old_resolved == new_path
            if (
                old_resolved is not None
                and not same_dir
                and new_path.is_relative_to(old_resolved)
            ):
                # Copying a directory into its own subtree would recurse into
                # the freshly created copy.
                render_error("新工作目录不能位于当前工作目录之内")
                continue

            self.workplace_path = new_path
            self._skill_loader = None  # skill search paths depend on the workplace
            os.chdir(self.workplace_path)
            save_workplace(self.workplace_path)

            chat_dir = self.workplace_path / ".chat"
            chat_dir.mkdir(exist_ok=True)
            (chat_dir / "sessions").mkdir(exist_ok=True)
            (chat_dir / "skills").mkdir(exist_ok=True)

            if not same_dir:
                render_info("复制工作目录文件...")
                try:
                    await asyncio.to_thread(self._copy_dir, old_path, new_path)
                    # Copy .git separately to keep the checkpoint history
                    # (_copy_dir skips dot-prefixed entries).
                    old_git = old_path / ".git"
                    new_git = new_path / ".git"
                    if old_git.exists() and old_git.is_dir():
                        await asyncio.to_thread(
                            shutil.copytree, old_git, new_git, dirs_exist_ok=True
                        )
                    # NOTE(review): .chat is never copied (dot-prefixed), so this
                    # wipes any sessions that already existed in the target
                    # directory - confirm that is intended.
                    sessions_path = self.workplace_path / ".chat" / "sessions"
                    if sessions_path.exists():
                        await asyncio.to_thread(shutil.rmtree, sessions_path)
                    sessions_path.mkdir(exist_ok=True)
                except Exception:
                    import traceback

                    tb = traceback.format_exc()
                    render_error(f"复制文件失败:\n{tb}")
                    # Undo every switch made above, including the persisted
                    # workplace, so a restart does not reopen the half-copied
                    # target directory.
                    self.workplace_path = old_path
                    self._skill_loader = None
                    os.chdir(self.workplace_path)
                    if saved:
                        save_workplace(saved)
                    return

            # Close the previous checkpointer connection (best-effort).
            if self.checkpointer is not None:
                try:
                    await self.checkpointer.conn.close()
                except Exception:
                    pass
                self.checkpointer = None

            self.session_mgr = SessionManager(self.workplace_path)
            db_path = self.workplace_path / ".chat" / "sessions" / "checkpointer.db"
            self.checkpointer = await create_checkpointer(db_path)

            self.agent = await asyncio.to_thread(
                build_agent,
                self.model_config,
                self.checkpointer,
                None,
                self.yolo,
            )

            # The branch keeps every group up to and including the fork point.
            need_messages = [
                msg for group in groups[: sel_idx + 1] for msg in group
            ]
            await self.agent.aupdate_state(
                self.session_mgr.config,
                {"messages": need_messages},
            )

            # Initialise Git first, then roll the new working directory back.
            await self._init_git()
            if self.git and self.git_manager:
                try:
                    await asyncio.to_thread(
                        self.git_manager.rollback, no_need_ids, all_ids
                    )
                except Exception as e:
                    render_warning(f"Git 回滚失败: {e}")

            render_success(f"分支已创建!工作目录: {self.workplace_path}")
            await self._load_conversation()
            self._render_status_bar()
            return
|
|
1212
|
+
|
|
1213
|
+
async def _handle_agent_error(self, error: Exception) -> None:
    """Record or roll back the current turn after the agent raised ``error``.

    The current turn is the slice of state messages starting at the last
    ``HumanMessage``. If it contains no ``AIMessage`` yet, the whole turn is
    deleted. Otherwise an ``AIMessage`` describing ``error``, flagged
    ``error``/``composed``, is appended as if emitted by the ``model`` node.

    Best-effort: a failure while inspecting or updating the state is shown as
    a warning and never propagates, so it cannot mask the original error.
    """
    try:
        state = await self.agent.aget_state(self.session_mgr.config)
        messages: list[BaseMessage] = state.values.get("messages", [])

        # Index of the last HumanMessage, or -1 if there is none.
        last_human_idx = next(
            (
                idx
                for idx in range(len(messages) - 1, -1, -1)
                if isinstance(messages[idx], HumanMessage)
            ),
            -1,
        )

        if last_human_idx >= 0:
            current_group = messages[last_human_idx:]
            if not any(isinstance(m, AIMessage) for m in current_group):
                # Nothing was produced for this turn: drop it entirely.
                await self._delete_messages([m.id for m in current_group])
                return

        # The turn already has AI output: keep it and append the error.
        error_msg = AIMessage(
            f"Agent 执行错误: {error}",
            additional_kwargs={"error": True, "composed": True},
        )
        await self.agent.aupdate_state(
            self.session_mgr.config,
            {"messages": [error_msg]},
            as_node="model",
        )
    except Exception as exc:
        render_warning(f"保存错误状态失败: {exc}")
|
|
1246
|
+
|
|
1247
|
+
async def _handle_cancel(self, user_input: str) -> None:
    """Clean up the current turn after the user interrupted the agent.

    The current turn is the slice of state messages from the last
    ``HumanMessage`` onward. If it has no ``AIMessage`` yet, the whole turn is
    deleted and the stripped ``user_input`` is stored in
    ``self._interrupt_buffer`` so the input box can be refilled. Otherwise an
    ``AIMessage`` saying the reply stopped unexpectedly, flagged
    ``error``/``composed``, is appended as if from the ``model`` node.

    Best-effort: failures are shown as a warning and never propagate.
    """
    try:
        state = await self.agent.aget_state(self.session_mgr.config)
        messages: list[BaseMessage] = state.values.get("messages", [])

        # Index of the last HumanMessage, or -1 if there is none.
        last_human_idx = next(
            (
                idx
                for idx in range(len(messages) - 1, -1, -1)
                if isinstance(messages[idx], HumanMessage)
            ),
            -1,
        )

        if last_human_idx >= 0:
            current_group = messages[last_human_idx:]
            if not any(isinstance(m, AIMessage) for m in current_group):
                # No AI output yet: remove the turn and hand the text back.
                await self._delete_messages([m.id for m in current_group])
                self._interrupt_buffer = user_input.strip()
                return

        # The turn already has AI output: mark it as stopped.
        error_msg = AIMessage(
            "该消息意外停止",
            additional_kwargs={"error": True, "composed": True},
        )
        await self.agent.aupdate_state(
            self.session_mgr.config,
            {"messages": [error_msg]},
            as_node="model",
        )
    except Exception as exc:
        render_warning(f"处理中断失败: {exc}")
|
|
1281
|
+
|
|
1282
|
+
async def _delete_messages(self, message_ids: list[str]) -> None:
    """Remove the messages whose ids are in ``message_ids`` from the session.

    Issues a single ``aupdate_state`` call containing one ``RemoveMessage``
    per distinct id. Does nothing when the agent/session is not initialised or
    when no ids are given, so callers need not guard empty selections.
    """
    if not self.agent or not self.session_mgr:
        return

    # dict.fromkeys de-duplicates while keeping the caller's order.
    unique_ids = list(dict.fromkeys(message_ids))
    if not unique_ids:
        return  # skip a state update that would remove nothing

    await self.agent.aupdate_state(
        self.session_mgr.config,
        {"messages": [RemoveMessage(id=mid) for mid in unique_ids]},
    )
|
|
1293
|
+
|
|
1294
|
+
def _copy_dir(self, src: Path, dst: Path):
    """Copy the top-level entries of ``src`` into ``dst`` (blocking; run in a thread).

    Dot-prefixed entries and entries whose stem (lower-cased) is in
    ``self.WINDOWS_RESERVED_NAMES`` are skipped. Directories are merged into
    existing targets. A failure on one entry is printed and does not stop the
    remaining copies.
    """
    for entry in src.iterdir():
        entry_name = entry.name
        if entry_name.startswith("."):
            continue
        if entry.stem.lower() in self.WINDOWS_RESERVED_NAMES:
            print(f"跳过 Windows 保留名: {entry_name}")
            continue

        target = dst / entry_name
        is_directory = entry.is_dir()
        try:
            if is_directory:
                shutil.copytree(entry, target, dirs_exist_ok=True)
            else:
                shutil.copy2(entry, target)
        except Exception as exc:
            message = (
                f"复制目录失败: {entry_name}, {exc}"
                if is_directory
                else f"复制文件失败: {entry_name}, {exc}"
            )
            print(message)
|
|
1313
|
+
|
|
1314
|
+
def _render_status_bar(self) -> None:
    """Intentionally a no-op.

    The status bar is drawn by the prompt's ``bottom_toolbar``. This hook is
    kept so command handlers can request a refresh without knowing that.
    """
|
|
1317
|
+
|
|
1318
|
+
# ─── 对话处理 ──────────────────────────────────────
|
|
1319
|
+
|
|
1320
|
+
async def _process_input(self, user_input: str) -> None:
    """Send ``user_input`` to the agent and stream the reply to the console.

    The agent is streamed with ``stream_mode=["messages", "updates"]``:

    * HITL interrupts: decisions are collected and the run resumes with a
      ``Command(resume=...)``.
    * ``ModelSwitchError``: switch to the next fallback model and retry.
    * User stop / cancellation: delegated to ``_handle_cancel``.
    * Any other error: reported and delegated to ``_handle_agent_error``.

    Afterwards ``_post_process`` (context usage + Git checkpoint) is scheduled
    in the background. ``self._processing`` is True for the duration.
    """
    self._processing = True
    self._stop_requested = False

    ai_started = False

    try:
        input_data = self._make_agent_input(user_input)
        skill_agent_context = self._make_skill_agent_context()

        while True:
            interrupt_chunk = None

            try:
                async for mode, payload in self.agent.astream(
                    input_data,
                    self.session_mgr.config,
                    stream_mode=["messages", "updates"],
                    context=skill_agent_context,
                ):
                    if self._stop_requested:
                        # Route a user stop through the cancellation handler below.
                        raise asyncio.CancelledError()

                    if mode == "messages":
                        chunk = payload[0]
                        content = get_text_content(chunk.content)
                        additional_kwargs = chunk.additional_kwargs

                        if additional_kwargs.get("hide", ""):
                            continue

                        if isinstance(chunk, AIMessageChunk):
                            reasoning = additional_kwargs.get("reasoning")
                            # Reasoning is only echoed when no sub-agents are
                            # rendering output of their own.
                            if (
                                reasoning
                                and not _display._subagent_parallel
                                and _display._subagent_count == 0
                            ):
                                console.print(reasoning, end="", style="dim")
                            if not ai_started:
                                if not content:
                                    continue
                                ai_started = True
                                render_ai_start()
                            render_ai_chunk(content or "")

                        elif isinstance(chunk, ToolMessage):
                            # A tool result ends the current AI text segment.
                            ai_started = False

                    elif mode == "updates" and "__interrupt__" in payload:
                        interrupt_chunk = payload

            except asyncio.CancelledError:
                await self._handle_cancel(user_input)
                console.print(Text("\n[已中断]", style="dim"), "\n")
                break
            except ModelSwitchError:
                # The current model asked to be replaced by a fallback.
                fallback = get_fallback_model()
                if not fallback:
                    render_error("没有更多备用模型可用")
                    await self._handle_agent_error(ModelSwitchError("所有模型均失败"))
                    break

                console.print(f"[yellow]正在切换到备用模型: {fallback.get('model', 'unknown')}[/yellow]")
                advance_fallback()
                try:
                    new_agent = await asyncio.to_thread(
                        build_agent,
                        fallback,
                        self.checkpointer,
                        None,
                        self.yolo,
                    )
                except Exception as e:
                    render_error(f"切换模型失败: {e}")
                    await self._handle_agent_error(e)
                    break

                # Commit the switch only once the new agent exists, so that
                # model_config and agent never disagree.
                self.model_config = fallback
                self.agent = new_agent
                skill_agent_context = self._make_skill_agent_context()
                console.print("[green]已切换到备用模型,自动重试中...[/green]")
                # NOTE(review): the retry re-sends the same input_data. If the
                # failed run already checkpointed the user message, it may be
                # appended twice - confirm against the graph's behaviour.
                continue
            except Exception as e:
                # Covers openai.APIError and everything else alike.
                render_error(f"Agent 执行错误: {e}")
                await self._handle_agent_error(e)
                break

            if self._stop_requested:
                break

            if interrupt_chunk is None:
                break

            # HITL approval, then resume the interrupted run.
            decisions = await self._collect_decisions_async(interrupt_chunk)
            input_data = Command(resume={"decisions": decisions})

        if ai_started:
            render_ai_end()

        # Post-processing (context usage + Git commit) runs in the background
        # so the prompt is not blocked. The event loop keeps only weak
        # references to tasks, so hold a strong one until the task finishes.
        task = asyncio.create_task(self._post_process())
        pending = getattr(self, "_background_tasks", None)
        if pending is None:
            pending = self._background_tasks = set()
        pending.add(task)
        task.add_done_callback(pending.discard)

    finally:
        self._processing = False

def _make_agent_input(self, user_input: str) -> dict:
    """Build the ``astream`` input for a new user turn.

    For multimodal models, image/video paths found in ``user_input``
    (relative to the workplace) are embedded with
    ``build_multimodal_message``. Otherwise the raw text is sent.
    """
    from chcode.utils.multimodal import (
        build_multimodal_message,
        extract_media_paths,
        is_multimodal_model,
    )

    current_model = (self.model_config or {}).get("model", "")
    if self.workplace_path and is_multimodal_model(current_model):
        media_paths = extract_media_paths(user_input, self.workplace_path)
        if media_paths:
            message = build_multimodal_message(user_input, media_paths)
            render_info(f"[已嵌入 {len(media_paths)} 个媒体文件]")
            return {"messages": message}
    return {"messages": user_input}

def _make_skill_agent_context(self):
    """Create the agent run context, lazily creating the skill loader.

    Skills are searched in ``<workplace>/.chat/skills`` and
    ``~/.chat/skills``. Falls back to ``INNER_MODEL_CONFIG`` when no model is
    configured.
    """
    if self._skill_loader is None:
        from chcode.utils.skill_loader import SkillLoader

        self._skill_loader = SkillLoader(
            [
                self.workplace_path / ".chat/skills",
                Path.home() / ".chat/skills",
            ]
        )

    return SkillAgentContext(
        skill_loader=self._skill_loader,
        working_directory=self.workplace_path,
        model_config=self.model_config or INNER_MODEL_CONFIG,
        thread_id=self.session_mgr.thread_id,
    )
|
|
1468
|
+
|
|
1469
|
+
async def _post_process(self) -> None:
    """Background work after a reply: refresh context usage, commit a Git checkpoint.

    Updates ``self._context_text`` from the session messages and the model's
    context window size. When Git is enabled, commits a checkpoint tagged with
    the ids of the latest turn and caches the returned count in
    ``self._git_cp_count``. Runs as a background task, so every step is
    best-effort; a failure in one step does not skip the other.
    """
    try:
        state = await self.agent.aget_state(self.session_mgr.config)
        messages = state.values.get("messages", [])
    except Exception:
        return  # nothing to do without the session state

    try:
        # model_config may be None (see the `or {}` guards elsewhere); don't let
        # that abort the Git checkpoint below.
        model_name = (self.model_config or {}).get("model", "")
        max_ctx = get_context_window_size(model_name)
        self._context_text = get_context_usage_text(messages, max_ctx)
    except Exception:
        pass  # the usage indicator is cosmetic

    try:
        if not (self.git and self.git_manager):
            return
        # Messages from the last human message onward, i.e. the latest turn.
        new_msgs = find_and_slice_from_end(messages, "human")
        ids = [m.id for m in new_msgs]
        result = await asyncio.to_thread(
            self.git_manager.add_commit, "&".join(ids)
        )
        # bool is a subclass of int; only a real count updates the cache.
        if isinstance(result, int) and not isinstance(result, bool):
            self._git_cp_count = result
    except Exception:
        pass  # background task: never surface errors into the prompt
|
|
1488
|
+
|
|
1489
|
+
async def _collect_decisions_async(self, interrupt_chunk) -> list[dict]:
|
|
1490
|
+
"""收集 HITL 决策"""
|
|
1491
|
+
console.print() # 确保 AI 输出和 HITL 之间有换行
|
|
1492
|
+
decisions = []
|
|
1493
|
+
for interrupt in interrupt_chunk["__interrupt__"]:
|
|
1494
|
+
action_requests = interrupt.value["action_requests"]
|
|
1495
|
+
|
|
1496
|
+
for action_request in action_requests:
|
|
1497
|
+
name = action_request["name"]
|
|
1498
|
+
args = action_request["args"]
|
|
1499
|
+
|
|
1500
|
+
content = ""
|
|
1501
|
+
match name:
|
|
1502
|
+
case "bash":
|
|
1503
|
+
content = args.get("command", "")
|
|
1504
|
+
case "write_file":
|
|
1505
|
+
content = f"写入文件: {args.get('file_path')}\n内容: {args.get('content', '')[:200]}"
|
|
1506
|
+
case "edit":
|
|
1507
|
+
file_path = args.get("file_path", "")
|
|
1508
|
+
old_str = args.get("old_string", "")
|
|
1509
|
+
new_str = args.get("new_string", "")
|
|
1510
|
+
render_warning(f"[HITL] edit 修改文件: {file_path}")
|
|
1511
|
+
import difflib
|
|
1512
|
+
from rich.table import Table
|
|
1513
|
+
|
|
1514
|
+
# 查找 old_str 在文件中的起始行号
|
|
1515
|
+
start_line = 1
|
|
1516
|
+
try:
|
|
1517
|
+
content = await asyncio.to_thread(
|
|
1518
|
+
Path(file_path).read_text, encoding="utf-8"
|
|
1519
|
+
)
|
|
1520
|
+
for i, line in enumerate(content.splitlines(), 1):
|
|
1521
|
+
if old_str.splitlines()[0] in line:
|
|
1522
|
+
start_line = i
|
|
1523
|
+
break
|
|
1524
|
+
except Exception:
|
|
1525
|
+
pass
|
|
1526
|
+
old_lines = old_str.splitlines()
|
|
1527
|
+
new_lines = new_str.splitlines()
|
|
1528
|
+
table = Table(
|
|
1529
|
+
show_header=False,
|
|
1530
|
+
show_edge=False,
|
|
1531
|
+
padding=(0, 1),
|
|
1532
|
+
border_style="dim",
|
|
1533
|
+
)
|
|
1534
|
+
table.add_column("old", ratio=1)
|
|
1535
|
+
table.add_column("new", ratio=1)
|
|
1536
|
+
sm = difflib.SequenceMatcher(None, old_lines, new_lines)
|
|
1537
|
+
old_num = start_line
|
|
1538
|
+
new_num = start_line
|
|
1539
|
+
for tag, i1, i2, j1, j2 in sm.get_opcodes():
|
|
1540
|
+
if tag == "equal":
|
|
1541
|
+
for k in range(i2 - i1):
|
|
1542
|
+
table.add_row(
|
|
1543
|
+
Text(
|
|
1544
|
+
f" {old_num:>3} {old_lines[i1 + k]}",
|
|
1545
|
+
style="dim",
|
|
1546
|
+
),
|
|
1547
|
+
Text(
|
|
1548
|
+
f" {new_num:>3} {new_lines[j1 + k]}",
|
|
1549
|
+
style="dim",
|
|
1550
|
+
),
|
|
1551
|
+
)
|
|
1552
|
+
old_num += 1
|
|
1553
|
+
new_num += 1
|
|
1554
|
+
elif tag == "replace":
|
|
1555
|
+
max_len = max(i2 - i1, j2 - j1)
|
|
1556
|
+
for k in range(max_len):
|
|
1557
|
+
old_text = (
|
|
1558
|
+
Text(
|
|
1559
|
+
f"{old_num:>3} - {old_lines[i1 + k]}",
|
|
1560
|
+
style="red",
|
|
1561
|
+
)
|
|
1562
|
+
if k < i2 - i1
|
|
1563
|
+
else None
|
|
1564
|
+
)
|
|
1565
|
+
new_text = (
|
|
1566
|
+
Text(
|
|
1567
|
+
f"{new_num:>3} + {new_lines[j1 + k]}",
|
|
1568
|
+
style="green",
|
|
1569
|
+
)
|
|
1570
|
+
if k < j2 - j1
|
|
1571
|
+
else None
|
|
1572
|
+
)
|
|
1573
|
+
table.add_row(old_text, new_text)
|
|
1574
|
+
if k < i2 - i1:
|
|
1575
|
+
old_num += 1
|
|
1576
|
+
if k < j2 - j1:
|
|
1577
|
+
new_num += 1
|
|
1578
|
+
elif tag == "delete":
|
|
1579
|
+
for k in range(i2 - i1):
|
|
1580
|
+
table.add_row(
|
|
1581
|
+
Text(
|
|
1582
|
+
f"{old_num:>3} - {old_lines[i1 + k]}",
|
|
1583
|
+
style="red",
|
|
1584
|
+
)
|
|
1585
|
+
)
|
|
1586
|
+
old_num += 1
|
|
1587
|
+
elif tag == "insert":
|
|
1588
|
+
for k in range(j2 - j1):
|
|
1589
|
+
table.add_row(
|
|
1590
|
+
None,
|
|
1591
|
+
Text(
|
|
1592
|
+
f"{new_num:>3} + {new_lines[j1 + k]}",
|
|
1593
|
+
style="green",
|
|
1594
|
+
),
|
|
1595
|
+
)
|
|
1596
|
+
new_num += 1
|
|
1597
|
+
console.print(table)
|
|
1598
|
+
content = None # 已直接渲染,跳过通用渲染
|
|
1599
|
+
|
|
1600
|
+
if self.yolo:
|
|
1601
|
+
select_action = True
|
|
1602
|
+
else:
|
|
1603
|
+
if content is not None:
|
|
1604
|
+
render_warning(f"[HITL] {name}")
|
|
1605
|
+
console.print(Text(f" {content[:500]}", style="dim"))
|
|
1606
|
+
result = await select(
|
|
1607
|
+
"操作:",
|
|
1608
|
+
["approve (批准)", "reject (拒绝)"],
|
|
1609
|
+
)
|
|
1610
|
+
select_action = result != "reject (拒绝)" if result else False
|
|
1611
|
+
|
|
1612
|
+
extra = {}
|
|
1613
|
+
if not select_action:
|
|
1614
|
+
extra["message"] = "用户已拒绝"
|
|
1615
|
+
decision = {"type": "approve" if select_action else "reject"}
|
|
1616
|
+
decision.update(extra)
|
|
1617
|
+
decisions.append(decision)
|
|
1618
|
+
|
|
1619
|
+
return decisions
|
|
1620
|
+
|
|
1621
|
+
async def _load_conversation(self) -> None:
    """Render the current session's message history, if an agent exists.

    Errors while fetching or rendering are reported through ``render_error``
    rather than raised.
    """
    if self.agent:
        try:
            snapshot = await self.agent.aget_state(self.session_mgr.config)
            render_conversation(snapshot.values.get("messages", []))
        except Exception as exc:
            render_error(f"加载对话失败: {exc}")
|