jarvis-ai-assistant 0.2.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +277 -242
  3. jarvis/jarvis_agent/agent_manager.py +85 -0
  4. jarvis/jarvis_agent/config_editor.py +53 -0
  5. jarvis/jarvis_agent/file_methodology_manager.py +105 -0
  6. jarvis/jarvis_agent/jarvis.py +30 -619
  7. jarvis/jarvis_agent/memory_manager.py +127 -0
  8. jarvis/jarvis_agent/methodology_share_manager.py +174 -0
  9. jarvis/jarvis_agent/prompts.py +18 -3
  10. jarvis/jarvis_agent/share_manager.py +176 -0
  11. jarvis/jarvis_agent/task_analyzer.py +126 -0
  12. jarvis/jarvis_agent/task_manager.py +111 -0
  13. jarvis/jarvis_agent/tool_share_manager.py +139 -0
  14. jarvis/jarvis_code_agent/code_agent.py +26 -20
  15. jarvis/jarvis_data/config_schema.json +37 -4
  16. jarvis/jarvis_platform/ai8.py +13 -1
  17. jarvis/jarvis_platform/base.py +20 -5
  18. jarvis/jarvis_platform/human.py +11 -1
  19. jarvis/jarvis_platform/kimi.py +10 -0
  20. jarvis/jarvis_platform/openai.py +20 -0
  21. jarvis/jarvis_platform/tongyi.py +14 -9
  22. jarvis/jarvis_platform/yuanbao.py +10 -0
  23. jarvis/jarvis_platform_manager/main.py +12 -12
  24. jarvis/jarvis_platform_manager/service.py +9 -4
  25. jarvis/jarvis_tools/registry.py +32 -0
  26. jarvis/jarvis_tools/retrieve_memory.py +36 -8
  27. jarvis/jarvis_tools/search_web.py +1 -1
  28. jarvis/jarvis_utils/clipboard.py +90 -0
  29. jarvis/jarvis_utils/config.py +64 -0
  30. jarvis/jarvis_utils/git_utils.py +17 -7
  31. jarvis/jarvis_utils/globals.py +18 -12
  32. jarvis/jarvis_utils/input.py +118 -16
  33. jarvis/jarvis_utils/methodology.py +48 -5
  34. jarvis/jarvis_utils/utils.py +169 -105
  35. {jarvis_ai_assistant-0.2.8.dist-info → jarvis_ai_assistant-0.3.1.dist-info}/METADATA +1 -1
  36. {jarvis_ai_assistant-0.2.8.dist-info → jarvis_ai_assistant-0.3.1.dist-info}/RECORD +40 -30
  37. {jarvis_ai_assistant-0.2.8.dist-info → jarvis_ai_assistant-0.3.1.dist-info}/WHEEL +0 -0
  38. {jarvis_ai_assistant-0.2.8.dist-info → jarvis_ai_assistant-0.3.1.dist-info}/entry_points.txt +0 -0
  39. {jarvis_ai_assistant-0.2.8.dist-info → jarvis_ai_assistant-0.3.1.dist-info}/licenses/LICENSE +0 -0
  40. {jarvis_ai_assistant-0.2.8.dist-info → jarvis_ai_assistant-0.3.1.dist-info}/top_level.txt +0 -0
@@ -195,11 +195,16 @@ def start_service(
195
195
  if "/" in model:
196
196
  platform_name, model_name = model.split("/", 1)
197
197
  else:
198
- # Use default platform and model if not specified
199
- if default_platform and default_model:
200
- platform_name, model_name = default_platform, default_model
198
+ # Use default platform if not specified in the model name
199
+ if default_platform:
200
+ platform_name = default_platform
201
+ model_name = model
201
202
  else:
202
- platform_name, model_name = "oyi", model # Default to OYI platform
203
+ raise HTTPException(
204
+ status_code=400,
205
+ detail="Model name must be in 'platform/model_name' format "
206
+ "or a default platform must be set.",
207
+ )
203
208
 
204
209
  # Get platform instance
205
210
  platform = get_platform_instance(platform_name, model_name)
@@ -191,6 +191,8 @@ class ToolRegistry(OutputHandlerProtocol):
191
191
  self._load_builtin_tools()
192
192
  self._load_external_tools()
193
193
  self._load_mcp_tools()
194
+ # 应用工具配置组过滤
195
+ self._apply_tool_config_filter()
194
196
 
195
197
  def _get_tool_stats(self) -> Dict[str, int]:
196
198
  """从数据目录获取工具调用统计"""
@@ -258,6 +260,36 @@ class ToolRegistry(OutputHandlerProtocol):
258
260
  name: tool for name, tool in self.tools.items() if name not in names
259
261
  }
260
262
 
263
+ def _apply_tool_config_filter(self) -> None:
264
+ """应用工具配置组的过滤规则"""
265
+ from jarvis.jarvis_utils.config import get_tool_use_list, get_tool_dont_use_list
266
+
267
+ use_list = get_tool_use_list()
268
+ dont_use_list = get_tool_dont_use_list()
269
+
270
+ # 如果配置了 use 列表,只保留列表中的工具
271
+ if use_list:
272
+ filtered_tools = {}
273
+ for tool_name in use_list:
274
+ if tool_name in self.tools:
275
+ filtered_tools[tool_name] = self.tools[tool_name]
276
+ else:
277
+ PrettyOutput.print(
278
+ f"警告: 配置的工具 '{tool_name}' 不存在",
279
+ OutputType.WARNING,
280
+ )
281
+ self.tools = filtered_tools
282
+
283
+ # 如果配置了 dont_use 列表,排除列表中的工具
284
+ if dont_use_list:
285
+ for tool_name in dont_use_list:
286
+ if tool_name in self.tools:
287
+ del self.tools[tool_name]
288
+ PrettyOutput.print(
289
+ f"已排除工具: {tool_name}",
290
+ OutputType.INFO,
291
+ )
292
+
261
293
  def _load_mcp_tools(self) -> None:
262
294
  """加载MCP工具,优先从配置获取,其次从目录扫描"""
263
295
  from jarvis.jarvis_utils.config import get_mcp_config
@@ -4,9 +4,10 @@ import random
4
4
  from pathlib import Path
5
5
  from typing import Any, Dict, List, Optional
6
6
 
7
- from jarvis.jarvis_utils.config import get_data_dir
7
+ from jarvis.jarvis_utils.config import get_data_dir, get_max_input_token_count
8
8
  from jarvis.jarvis_utils.output import OutputType, PrettyOutput
9
9
  from jarvis.jarvis_utils.globals import get_short_term_memories
10
+ from jarvis.jarvis_utils.embedding import get_context_token_count
10
11
 
11
12
 
12
13
  class RetrieveMemoryTool:
@@ -131,14 +132,41 @@ class RetrieveMemoryTool:
131
132
  # 按创建时间排序(最新的在前)
132
133
  all_memories.sort(key=lambda x: x.get("created_at", ""), reverse=True)
133
134
 
134
- # 限制最多返回50条记忆,随机选取
135
- if len(all_memories) > 50:
136
- all_memories = random.sample(all_memories, 50)
137
- # 重新排序,保持时间顺序
138
- all_memories.sort(key=lambda x: x.get("created_at", ""), reverse=True)
135
+ # 获取最大输入token数的2/3作为记忆的token限制
136
+ max_input_tokens = get_max_input_token_count()
137
+ memory_token_limit = int(max_input_tokens * 2 / 3)
139
138
 
140
- # 如果指定了限制,只返回前N个
141
- if limit:
139
+ # 基于token限制和条数限制筛选记忆
140
+ filtered_memories: List[Dict[str, Any]] = []
141
+ total_tokens = 0
142
+
143
+ for memory in all_memories:
144
+ # 计算当前记忆的token数量
145
+ memory_content = json.dumps(memory, ensure_ascii=False)
146
+ memory_tokens = get_context_token_count(memory_content)
147
+
148
+ # 检查是否超过token限制
149
+ if total_tokens + memory_tokens > memory_token_limit:
150
+ PrettyOutput.print(
151
+ f"达到token限制 ({total_tokens}/{memory_token_limit}),停止加载更多记忆",
152
+ OutputType.INFO,
153
+ )
154
+ break
155
+
156
+ # 检查是否超过50条限制
157
+ if len(filtered_memories) >= 50:
158
+ PrettyOutput.print(
159
+ f"达到记忆条数限制 (50条),停止加载更多记忆", OutputType.INFO
160
+ )
161
+ break
162
+
163
+ filtered_memories.append(memory)
164
+ total_tokens += memory_tokens
165
+
166
+ all_memories = filtered_memories
167
+
168
+ # 如果指定了额外的限制,只返回前N个
169
+ if limit and len(all_memories) > limit:
142
170
  all_memories = all_memories[:limit]
143
171
 
144
172
  # 打印结果摘要
@@ -31,7 +31,7 @@ class SearchWebTool:
31
31
  """执行网络搜索、抓取内容并总结结果。"""
32
32
  try:
33
33
  PrettyOutput.print("▶️ 使用 DuckDuckGo 开始网页搜索...", OutputType.INFO)
34
- results = list(DDGS().text(query, max_results=50))
34
+ results = list(DDGS().text(query, max_results=50, page=3))
35
35
 
36
36
  if not results:
37
37
  return {
@@ -0,0 +1,90 @@
1
+ # -*- coding: utf-8 -*-
2
+ import platform
3
+ import subprocess
4
+
5
+ from jarvis.jarvis_utils.output import OutputType, PrettyOutput
6
+
7
+
8
+ def copy_to_clipboard(text: str) -> None:
9
+ """将文本复制到剪贴板,支持Windows、macOS和Linux
10
+
11
+ 参数:
12
+ text: 要复制的文本
13
+ """
14
+ print("--- 剪贴板内容开始 ---")
15
+ print(text)
16
+ print("--- 剪贴板内容结束 ---")
17
+
18
+ system = platform.system()
19
+
20
+ # Windows系统
21
+ if system == "Windows":
22
+ try:
23
+ # 使用Windows的clip命令
24
+ process = subprocess.Popen(
25
+ ["clip"],
26
+ stdin=subprocess.PIPE,
27
+ stdout=subprocess.DEVNULL,
28
+ stderr=subprocess.DEVNULL,
29
+ shell=True,
30
+ )
31
+ if process.stdin:
32
+ process.stdin.write(text.encode("utf-8"))
33
+ process.stdin.close()
34
+ return
35
+ except Exception as e:
36
+ PrettyOutput.print(f"使用Windows clip命令时出错: {e}", OutputType.WARNING)
37
+
38
+ # macOS系统
39
+ elif system == "Darwin":
40
+ try:
41
+ process = subprocess.Popen(
42
+ ["pbcopy"],
43
+ stdin=subprocess.PIPE,
44
+ stdout=subprocess.DEVNULL,
45
+ stderr=subprocess.DEVNULL,
46
+ )
47
+ if process.stdin:
48
+ process.stdin.write(text.encode("utf-8"))
49
+ process.stdin.close()
50
+ return
51
+ except Exception as e:
52
+ PrettyOutput.print(f"使用macOS pbcopy命令时出错: {e}", OutputType.WARNING)
53
+
54
+ # Linux系统
55
+ else:
56
+ # 尝试使用 xsel
57
+ try:
58
+ process = subprocess.Popen(
59
+ ["xsel", "-b", "-i"],
60
+ stdin=subprocess.PIPE,
61
+ stdout=subprocess.DEVNULL,
62
+ stderr=subprocess.DEVNULL,
63
+ )
64
+ if process.stdin:
65
+ process.stdin.write(text.encode("utf-8"))
66
+ process.stdin.close()
67
+ return
68
+ except FileNotFoundError:
69
+ pass # xsel 未安装,继续尝试下一个
70
+ except Exception as e:
71
+ PrettyOutput.print(f"使用xsel时出错: {e}", OutputType.WARNING)
72
+
73
+ # 尝试使用 xclip
74
+ try:
75
+ process = subprocess.Popen(
76
+ ["xclip", "-selection", "clipboard"],
77
+ stdin=subprocess.PIPE,
78
+ stdout=subprocess.DEVNULL,
79
+ stderr=subprocess.DEVNULL,
80
+ )
81
+ if process.stdin:
82
+ process.stdin.write(text.encode("utf-8"))
83
+ process.stdin.close()
84
+ return
85
+ except FileNotFoundError:
86
+ PrettyOutput.print(
87
+ "xsel 和 xclip 均未安装, 无法复制到剪贴板", OutputType.WARNING
88
+ )
89
+ except Exception as e:
90
+ PrettyOutput.print(f"使用xclip时出错: {e}", OutputType.WARNING)
@@ -339,6 +339,16 @@ def is_print_prompt() -> bool:
339
339
  return GLOBAL_CONFIG_DATA.get("JARVIS_PRINT_PROMPT", False) == True
340
340
 
341
341
 
342
+ def is_force_save_memory() -> bool:
343
+ """
344
+ 获取是否强制保存记忆。
345
+
346
+ 返回:
347
+ bool: 如果强制保存记忆则返回True,默认为True
348
+ """
349
+ return GLOBAL_CONFIG_DATA.get("JARVIS_FORCE_SAVE_MEMORY", True) is True
350
+
351
+
342
352
  def is_enable_static_analysis() -> bool:
343
353
  """
344
354
  获取是否启用静态代码分析。
@@ -487,3 +497,57 @@ def get_rag_use_rerank() -> bool:
487
497
  """
488
498
  config = _get_resolved_rag_config()
489
499
  return config.get("use_rerank", True) is True
500
+
501
+
502
+ # ==============================================================================
503
+ # Tool Configuration
504
+ # ==============================================================================
505
+
506
+
507
+ def _get_resolved_tool_config(
508
+ tool_group_override: Optional[str] = None,
509
+ ) -> Dict[str, Any]:
510
+ """
511
+ 解析并合并工具配置,处理工具组。
512
+
513
+ 优先级顺序:
514
+ 1. JARVIS_TOOL_GROUP 中定义的组配置
515
+ 2. 默认配置(所有工具都启用)
516
+
517
+ 返回:
518
+ Dict[str, Any]: 解析后的工具配置字典,包含 'use' 和 'dont_use' 列表
519
+ """
520
+ group_config = {}
521
+ tool_group_name = tool_group_override or GLOBAL_CONFIG_DATA.get("JARVIS_TOOL_GROUP")
522
+ tool_groups = GLOBAL_CONFIG_DATA.get("JARVIS_TOOL_GROUPS", [])
523
+
524
+ if tool_group_name and isinstance(tool_groups, list):
525
+ for group_item in tool_groups:
526
+ if isinstance(group_item, dict) and tool_group_name in group_item:
527
+ group_config = group_item[tool_group_name]
528
+ break
529
+
530
+ # 如果没有找到配置组,返回默认配置(空列表表示使用所有工具)
531
+ return group_config.copy() if group_config else {"use": [], "dont_use": []}
532
+
533
+
534
+ def get_tool_use_list() -> List[str]:
535
+ """
536
+ 获取要使用的工具列表。
537
+
538
+ 返回:
539
+ List[str]: 要使用的工具名称列表,空列表表示使用所有工具
540
+ """
541
+ config = _get_resolved_tool_config()
542
+ return config.get("use", [])
543
+
544
+
545
+ def get_tool_dont_use_list() -> List[str]:
546
+ """
547
+ 获取不使用的工具列表。
548
+
549
+ 返回:
550
+ List[str]: 不使用的工具名称列表
551
+ """
552
+ config = _get_resolved_tool_config()
553
+ return config.get("dont_use", [])
@@ -214,7 +214,9 @@ def handle_commit_workflow() -> bool:
214
214
  Returns:
215
215
  bool: 提交是否成功
216
216
  """
217
- if is_confirm_before_apply_patch() and not user_confirm("是否要提交代码?", default=True):
217
+ if is_confirm_before_apply_patch() and not user_confirm(
218
+ "是否要提交代码?", default=True
219
+ ):
218
220
  revert_change()
219
221
  return False
220
222
 
@@ -280,7 +282,7 @@ def get_latest_commit_hash() -> str:
280
282
  return ""
281
283
 
282
284
 
283
- def get_modified_line_ranges() -> Dict[str, Tuple[int, int]]:
285
+ def get_modified_line_ranges() -> Dict[str, List[Tuple[int, int]]]:
284
286
  """从Git差异中获取所有更改文件的修改行范围
285
287
 
286
288
  返回:
@@ -291,7 +293,7 @@ def get_modified_line_ranges() -> Dict[str, Tuple[int, int]]:
291
293
  diff_output = os.popen("git show").read()
292
294
 
293
295
  # 解析差异以获取修改的文件及其行范围
294
- result = {}
296
+ result: Dict[str, List[Tuple[int, int]]] = {}
295
297
  current_file = None
296
298
 
297
299
  for line in diff_output.splitlines():
@@ -427,7 +429,9 @@ def check_and_update_git_repo(repo_path: str) -> bool:
427
429
  if not in_venv and (
428
430
  "Permission denied" in error_msg or "not writeable" in error_msg
429
431
  ):
430
- if user_confirm("检测到权限问题,是否尝试用户级安装(--user)?", True):
432
+ if user_confirm(
433
+ "检测到权限问题,是否尝试用户级安装(--user)?", True
434
+ ):
431
435
  user_result = subprocess.run(
432
436
  install_cmd + ["--user"],
433
437
  cwd=git_root,
@@ -442,7 +446,9 @@ def check_and_update_git_repo(repo_path: str) -> bool:
442
446
  PrettyOutput.print(f"代码安装失败: {error_msg}", OutputType.ERROR)
443
447
  return False
444
448
  except Exception as e:
445
- PrettyOutput.print(f"安装过程中发生意外错误: {str(e)}", OutputType.ERROR)
449
+ PrettyOutput.print(
450
+ f"安装过程中发生意外错误: {str(e)}", OutputType.ERROR
451
+ )
446
452
  return False
447
453
  # 更新检查日期文件
448
454
  with open(last_check_file, "w") as f:
@@ -476,7 +482,9 @@ def get_diff_file_list() -> List[str]:
476
482
  subprocess.run(["git", "reset"], check=True)
477
483
 
478
484
  if result.returncode != 0:
479
- PrettyOutput.print(f"获取差异文件列表失败: {result.stderr}", OutputType.ERROR)
485
+ PrettyOutput.print(
486
+ f"获取差异文件列表失败: {result.stderr}", OutputType.ERROR
487
+ )
480
488
  return []
481
489
 
482
490
  return [f for f in result.stdout.splitlines() if f]
@@ -626,7 +634,9 @@ def confirm_add_new_files() -> None:
626
634
  need_confirm = True
627
635
 
628
636
  if binary_files:
629
- output_lines.append(f"检测到{len(binary_files)}个二进制文件(选择N将重新检测)")
637
+ output_lines.append(
638
+ f"检测到{len(binary_files)}个二进制文件(选择N将重新检测)"
639
+ )
630
640
  output_lines.append("二进制文件列表:")
631
641
  output_lines.extend(f" - {file}" for file in binary_files)
632
642
  need_confirm = True
@@ -264,7 +264,7 @@ def get_all_memory_tags() -> Dict[str, List[str]]:
264
264
  """
265
265
  获取所有记忆类型中的标签集合。
266
266
  每个类型最多返回200个标签,超过时随机提取。
267
-
267
+
268
268
  返回:
269
269
  Dict[str, List[str]]: 按记忆类型分组的标签列表
270
270
  """
@@ -272,25 +272,27 @@ def get_all_memory_tags() -> Dict[str, List[str]]:
272
272
  import json
273
273
  import random
274
274
  from jarvis.jarvis_utils.config import get_data_dir
275
-
276
- tags_by_type = {
275
+
276
+ tags_by_type: Dict[str, List[str]] = {
277
277
  "short_term": [],
278
278
  "project_long_term": [],
279
- "global_long_term": []
279
+ "global_long_term": [],
280
280
  }
281
-
281
+
282
282
  MAX_TAGS_PER_TYPE = 200
283
-
283
+
284
284
  # 获取短期记忆标签
285
285
  short_term_tags = set()
286
286
  for memory in short_term_memories:
287
287
  short_term_tags.update(memory.get("tags", []))
288
288
  short_term_tags_list = sorted(list(short_term_tags))
289
289
  if len(short_term_tags_list) > MAX_TAGS_PER_TYPE:
290
- tags_by_type["short_term"] = sorted(random.sample(short_term_tags_list, MAX_TAGS_PER_TYPE))
290
+ tags_by_type["short_term"] = sorted(
291
+ random.sample(short_term_tags_list, MAX_TAGS_PER_TYPE)
292
+ )
291
293
  else:
292
294
  tags_by_type["short_term"] = short_term_tags_list
293
-
295
+
294
296
  # 获取项目长期记忆标签
295
297
  project_memory_dir = Path(".jarvis/memory")
296
298
  if project_memory_dir.exists():
@@ -304,10 +306,12 @@ def get_all_memory_tags() -> Dict[str, List[str]]:
304
306
  pass
305
307
  project_tags_list = sorted(list(project_tags))
306
308
  if len(project_tags_list) > MAX_TAGS_PER_TYPE:
307
- tags_by_type["project_long_term"] = sorted(random.sample(project_tags_list, MAX_TAGS_PER_TYPE))
309
+ tags_by_type["project_long_term"] = sorted(
310
+ random.sample(project_tags_list, MAX_TAGS_PER_TYPE)
311
+ )
308
312
  else:
309
313
  tags_by_type["project_long_term"] = project_tags_list
310
-
314
+
311
315
  # 获取全局长期记忆标签
312
316
  global_memory_dir = Path(get_data_dir()) / "memory" / "global_long_term"
313
317
  if global_memory_dir.exists():
@@ -321,8 +325,10 @@ def get_all_memory_tags() -> Dict[str, List[str]]:
321
325
  pass
322
326
  global_tags_list = sorted(list(global_tags))
323
327
  if len(global_tags_list) > MAX_TAGS_PER_TYPE:
324
- tags_by_type["global_long_term"] = sorted(random.sample(global_tags_list, MAX_TAGS_PER_TYPE))
328
+ tags_by_type["global_long_term"] = sorted(
329
+ random.sample(global_tags_list, MAX_TAGS_PER_TYPE)
330
+ )
325
331
  else:
326
332
  tags_by_type["global_long_term"] = global_tags_list
327
-
333
+
328
334
  return tags_by_type
@@ -9,40 +9,141 @@
9
9
  - 用于输入控制的自定义键绑定
10
10
  """
11
11
  import os
12
- from typing import Iterable
12
+ from typing import Iterable, List
13
13
 
14
14
  from colorama import Fore
15
- from colorama import Style as ColoramaStyle # type: ignore
16
- from fuzzywuzzy import process # type: ignore
17
- from prompt_toolkit import PromptSession # type: ignore
18
- from prompt_toolkit.completion import CompleteEvent # type: ignore
15
+ from colorama import Style as ColoramaStyle
16
+ from fuzzywuzzy import process
17
+ from prompt_toolkit import PromptSession
18
+ from prompt_toolkit.application import Application
19
+ from prompt_toolkit.completion import CompleteEvent
19
20
  from prompt_toolkit.completion import (
20
21
  Completer,
21
22
  Completion,
22
23
  PathCompleter,
23
24
  )
24
- from prompt_toolkit.document import Document # type: ignore
25
- from prompt_toolkit.formatted_text import FormattedText # type: ignore
26
- from prompt_toolkit.history import FileHistory # type: ignore
27
- from prompt_toolkit.key_binding import KeyBindings # type: ignore
28
- from prompt_toolkit.styles import Style as PromptStyle # type: ignore
29
-
25
+ from prompt_toolkit.document import Document
26
+ from prompt_toolkit.formatted_text import FormattedText
27
+ from prompt_toolkit.history import FileHistory
28
+ from prompt_toolkit.key_binding import KeyBindings
29
+ from prompt_toolkit.layout.containers import Window
30
+ from prompt_toolkit.layout.controls import FormattedTextControl
31
+ from prompt_toolkit.layout.layout import Layout
32
+ from prompt_toolkit.styles import Style as PromptStyle
33
+
34
+ from jarvis.jarvis_utils.clipboard import copy_to_clipboard
30
35
  from jarvis.jarvis_utils.config import get_data_dir, get_replace_map
36
+ from jarvis.jarvis_utils.globals import get_message_history
31
37
  from jarvis.jarvis_utils.output import OutputType, PrettyOutput
32
38
  from jarvis.jarvis_utils.tag import ot
33
- from jarvis.jarvis_utils.utils import copy_to_clipboard
34
39
 
35
40
  # Sentinel value to indicate that Ctrl+O was pressed
36
41
  CTRL_O_SENTINEL = "__CTRL_O_PRESSED__"
37
42
 
38
43
 
39
- def get_single_line_input(tip: str) -> str:
44
+ def get_single_line_input(tip: str, default: str = "") -> str:
40
45
  """
41
46
  获取支持历史记录的单行输入。
42
47
  """
43
48
  session: PromptSession = PromptSession(history=None)
44
49
  style = PromptStyle.from_dict({"prompt": "ansicyan"})
45
- return session.prompt(f"{tip}", style=style)
50
+ return session.prompt(f"{tip}", default=default, style=style)
51
+
52
+
53
+ def get_choice(tip: str, choices: List[str]) -> str:
54
+ """
55
+ 提供一个可滚动的选择列表供用户选择。
56
+ """
57
+ if not choices:
58
+ raise ValueError("Choices cannot be empty.")
59
+
60
+ try:
61
+ terminal_height = os.get_terminal_size().lines
62
+ except OSError:
63
+ terminal_height = 25 # 如果无法确定终端大小,则使用默认高度
64
+
65
+ # 为提示和缓冲区保留行
66
+ max_visible_choices = max(5, terminal_height - 4)
67
+
68
+ bindings = KeyBindings()
69
+ selected_index = 0
70
+ start_index = 0
71
+
72
+ @bindings.add("up")
73
+ def _(event):
74
+ nonlocal selected_index, start_index
75
+ selected_index = (selected_index - 1 + len(choices)) % len(choices)
76
+ if selected_index < start_index:
77
+ start_index = selected_index
78
+ elif selected_index == len(choices) - 1: # 支持从第一项上翻到最后一项时滚动
79
+ start_index = max(0, len(choices) - max_visible_choices)
80
+ event.app.invalidate()
81
+
82
+ @bindings.add("down")
83
+ def _(event):
84
+ nonlocal selected_index, start_index
85
+ selected_index = (selected_index + 1) % len(choices)
86
+ if selected_index >= start_index + max_visible_choices:
87
+ start_index = selected_index - max_visible_choices + 1
88
+ elif selected_index == 0: # 支持从最后一项下翻到第一项时滚动
89
+ start_index = 0
90
+ event.app.invalidate()
91
+
92
+ @bindings.add("enter")
93
+ def _(event):
94
+ event.app.exit(result=choices[selected_index])
95
+
96
+ def get_prompt_tokens():
97
+ tokens = [("class:question", f"{tip} (使用上下箭头选择, Enter确认)\n")]
98
+
99
+ end_index = min(start_index + max_visible_choices, len(choices))
100
+ visible_choices_slice = choices[start_index:end_index]
101
+
102
+ if start_index > 0:
103
+ tokens.append(("class:indicator", " ... (更多选项在上方) ...\n"))
104
+
105
+ for i, choice in enumerate(visible_choices_slice, start=start_index):
106
+ if i == selected_index:
107
+ tokens.append(("class:selected", f"> {choice}\n"))
108
+ else:
109
+ tokens.append(("", f" {choice}\n"))
110
+
111
+ if end_index < len(choices):
112
+ tokens.append(("class:indicator", " ... (更多选项在下方) ...\n"))
113
+
114
+ return FormattedText(tokens)
115
+
116
+ style = PromptStyle.from_dict(
117
+ {
118
+ "question": "bold",
119
+ "selected": "bg:#696969 #ffffff",
120
+ "indicator": "fg:gray",
121
+ }
122
+ )
123
+
124
+ layout = Layout(
125
+ container=Window(
126
+ content=FormattedTextControl(
127
+ text=get_prompt_tokens,
128
+ focusable=True,
129
+ key_bindings=bindings,
130
+ )
131
+ )
132
+ )
133
+
134
+ app: Application = Application(
135
+ layout=layout,
136
+ key_bindings=bindings,
137
+ style=style,
138
+ mouse_support=True,
139
+ full_screen=True,
140
+ )
141
+
142
+ try:
143
+ result = app.run()
144
+ return result if result is not None else ""
145
+ except (KeyboardInterrupt, EOFError):
146
+ return ""
46
147
 
47
148
 
48
149
  class FileCompleter(Completer):
@@ -160,7 +261,6 @@ def _show_history_and_copy():
160
261
  Displays message history and handles copying to clipboard.
161
262
  This function uses standard I/O and is safe to call outside a prompt session.
162
263
  """
163
- from jarvis.jarvis_utils.globals import get_message_history
164
264
 
165
265
  history = get_message_history()
166
266
  if not history:
@@ -170,7 +270,9 @@ def _show_history_and_copy():
170
270
  print("\n" + "=" * 20 + " 消息历史记录 " + "=" * 20)
171
271
  for i, msg in enumerate(history):
172
272
  cleaned_msg = msg.replace("\n", r"\n")
173
- display_msg = (cleaned_msg[:70] + "...") if len(cleaned_msg) > 70 else cleaned_msg
273
+ display_msg = (
274
+ (cleaned_msg[:70] + "...") if len(cleaned_msg) > 70 else cleaned_msg
275
+ )
174
276
  print(f" {i + 1}: {display_msg.strip()}")
175
277
  print("=" * 58 + "\n")
176
278