jarvis-ai-assistant 0.1.128__py3-none-any.whl → 0.1.129__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic. Click here for more details.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +8 -13
- jarvis/jarvis_agent/main.py +77 -0
- jarvis/jarvis_code_agent/builtin_input_handler.py +43 -0
- jarvis/jarvis_code_agent/code_agent.py +3 -78
- jarvis/jarvis_code_agent/file_input_handler.py +88 -0
- jarvis/jarvis_code_agent/patch.py +36 -35
- jarvis/jarvis_code_agent/shell_input_handler.py +8 -2
- jarvis/jarvis_multi_agent/__init__.py +51 -40
- jarvis/jarvis_tools/read_code.py +143 -0
- jarvis/jarvis_tools/registry.py +35 -39
- jarvis/jarvis_tools/tool_generator.py +45 -17
- jarvis/jarvis_utils/__init__.py +17 -17
- jarvis/jarvis_utils/config.py +87 -51
- jarvis/jarvis_utils/embedding.py +49 -48
- jarvis/jarvis_utils/git_utils.py +34 -34
- jarvis/jarvis_utils/globals.py +26 -26
- jarvis/jarvis_utils/input.py +61 -45
- jarvis/jarvis_utils/methodology.py +22 -22
- jarvis/jarvis_utils/output.py +62 -62
- jarvis/jarvis_utils/utils.py +2 -2
- {jarvis_ai_assistant-0.1.128.dist-info → jarvis_ai_assistant-0.1.129.dist-info}/METADATA +1 -1
- {jarvis_ai_assistant-0.1.128.dist-info → jarvis_ai_assistant-0.1.129.dist-info}/RECORD +27 -23
- {jarvis_ai_assistant-0.1.128.dist-info → jarvis_ai_assistant-0.1.129.dist-info}/entry_points.txt +2 -0
- {jarvis_ai_assistant-0.1.128.dist-info → jarvis_ai_assistant-0.1.129.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.128.dist-info → jarvis_ai_assistant-0.1.129.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.128.dist-info → jarvis_ai_assistant-0.1.129.dist-info}/top_level.txt +0 -0
jarvis/jarvis_utils/embedding.py
CHANGED
|
@@ -6,35 +6,32 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
6
6
|
from typing import List, Any, Tuple
|
|
7
7
|
from jarvis.jarvis_utils.output import PrettyOutput, OutputType
|
|
8
8
|
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
9
|
def get_context_token_count(text: str) -> int:
|
|
13
|
-
"""
|
|
10
|
+
"""使用分词器获取文本的token数量。
|
|
14
11
|
|
|
15
|
-
|
|
16
|
-
text:
|
|
12
|
+
参数:
|
|
13
|
+
text: 要计算token的输入文本
|
|
17
14
|
|
|
18
|
-
|
|
19
|
-
int:
|
|
15
|
+
返回:
|
|
16
|
+
int: 文本中的token数量
|
|
20
17
|
"""
|
|
21
18
|
try:
|
|
22
|
-
#
|
|
19
|
+
# 使用擅长处理通用文本的快速分词器
|
|
23
20
|
tokenizer = load_tokenizer()
|
|
24
21
|
chunks = split_text_into_chunks(text, 512)
|
|
25
22
|
return sum([len(tokenizer.encode(chunk)) for chunk in chunks]) # type: ignore
|
|
26
23
|
|
|
27
24
|
except Exception as e:
|
|
28
25
|
PrettyOutput.print(f"计算token失败: {str(e)}", OutputType.WARNING)
|
|
29
|
-
#
|
|
30
|
-
return len(text) // 4 #
|
|
26
|
+
# 回退到基于字符的粗略估计
|
|
27
|
+
return len(text) // 4 # 每个token大约4个字符的粗略估计
|
|
31
28
|
|
|
32
29
|
def load_embedding_model() -> SentenceTransformer:
|
|
33
30
|
"""
|
|
34
|
-
|
|
31
|
+
加载句子嵌入模型。
|
|
35
32
|
|
|
36
|
-
|
|
37
|
-
SentenceTransformer:
|
|
33
|
+
返回:
|
|
34
|
+
SentenceTransformer: 加载的嵌入模型
|
|
38
35
|
"""
|
|
39
36
|
model_name = "BAAI/bge-m3"
|
|
40
37
|
cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
|
|
@@ -53,31 +50,33 @@ def load_embedding_model() -> SentenceTransformer:
|
|
|
53
50
|
)
|
|
54
51
|
|
|
55
52
|
return embedding_model
|
|
53
|
+
|
|
56
54
|
def get_embedding(embedding_model: Any, text: str) -> np.ndarray:
|
|
57
55
|
"""
|
|
58
|
-
|
|
56
|
+
为给定文本生成嵌入向量。
|
|
59
57
|
|
|
60
|
-
|
|
61
|
-
embedding_model:
|
|
62
|
-
text:
|
|
58
|
+
参数:
|
|
59
|
+
embedding_model: 使用的嵌入模型
|
|
60
|
+
text: 要嵌入的输入文本
|
|
63
61
|
|
|
64
|
-
|
|
65
|
-
np.ndarray:
|
|
62
|
+
返回:
|
|
63
|
+
np.ndarray: 嵌入向量
|
|
66
64
|
"""
|
|
67
65
|
embedding = embedding_model.encode(text,
|
|
68
66
|
normalize_embeddings=True,
|
|
69
67
|
show_progress_bar=False)
|
|
70
68
|
return np.array(embedding, dtype=np.float32)
|
|
69
|
+
|
|
71
70
|
def get_embedding_batch(embedding_model: Any, texts: List[str]) -> np.ndarray:
|
|
72
71
|
"""
|
|
73
|
-
|
|
72
|
+
为一批文本生成嵌入向量。
|
|
74
73
|
|
|
75
|
-
|
|
76
|
-
embedding_model:
|
|
77
|
-
texts:
|
|
74
|
+
参数:
|
|
75
|
+
embedding_model: 使用的嵌入模型
|
|
76
|
+
texts: 要嵌入的文本列表
|
|
78
77
|
|
|
79
|
-
|
|
80
|
-
np.ndarray:
|
|
78
|
+
返回:
|
|
79
|
+
np.ndarray: 堆叠的嵌入向量
|
|
81
80
|
"""
|
|
82
81
|
try:
|
|
83
82
|
all_vectors = []
|
|
@@ -90,41 +89,41 @@ def get_embedding_batch(embedding_model: Any, texts: List[str]) -> np.ndarray:
|
|
|
90
89
|
return np.zeros((0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32)
|
|
91
90
|
|
|
92
91
|
def split_text_into_chunks(text: str, max_length: int = 512) -> List[str]:
|
|
93
|
-
"""
|
|
92
|
+
"""将文本分割成带重叠窗口的块。
|
|
94
93
|
|
|
95
|
-
|
|
96
|
-
text:
|
|
97
|
-
max_length:
|
|
94
|
+
参数:
|
|
95
|
+
text: 要分割的输入文本
|
|
96
|
+
max_length: 每个块的最大长度
|
|
98
97
|
|
|
99
|
-
|
|
100
|
-
List[str]:
|
|
98
|
+
返回:
|
|
99
|
+
List[str]: 文本块列表
|
|
101
100
|
"""
|
|
102
101
|
chunks = []
|
|
103
102
|
start = 0
|
|
104
103
|
while start < len(text):
|
|
105
104
|
end = start + max_length
|
|
106
|
-
#
|
|
105
|
+
# 找到最近的句子边界
|
|
107
106
|
if end < len(text):
|
|
108
107
|
while end > start and text[end] not in {'.', '!', '?', '\n'}:
|
|
109
108
|
end -= 1
|
|
110
|
-
if end == start: #
|
|
109
|
+
if end == start: # 未找到标点,强制分割
|
|
111
110
|
end = start + max_length
|
|
112
111
|
chunk = text[start:end]
|
|
113
112
|
chunks.append(chunk)
|
|
114
|
-
#
|
|
113
|
+
# 重叠20%的窗口
|
|
115
114
|
start = end - int(max_length * 0.2)
|
|
116
115
|
return chunks
|
|
117
116
|
|
|
118
117
|
def get_embedding_with_chunks(embedding_model: Any, text: str) -> List[np.ndarray]:
|
|
119
118
|
"""
|
|
120
|
-
|
|
119
|
+
为文本块生成嵌入向量。
|
|
121
120
|
|
|
122
|
-
|
|
123
|
-
embedding_model:
|
|
124
|
-
text:
|
|
121
|
+
参数:
|
|
122
|
+
embedding_model: 使用的嵌入模型
|
|
123
|
+
text: 要处理的输入文本
|
|
125
124
|
|
|
126
|
-
|
|
127
|
-
List[np.ndarray]:
|
|
125
|
+
返回:
|
|
126
|
+
List[np.ndarray]: 每个块的嵌入向量列表
|
|
128
127
|
"""
|
|
129
128
|
chunks = split_text_into_chunks(text, 512)
|
|
130
129
|
if not chunks:
|
|
@@ -135,12 +134,13 @@ def get_embedding_with_chunks(embedding_model: Any, text: str) -> List[np.ndarra
|
|
|
135
134
|
vector = get_embedding(embedding_model, chunk)
|
|
136
135
|
vectors.append(vector)
|
|
137
136
|
return vectors
|
|
137
|
+
|
|
138
138
|
def load_tokenizer() -> AutoTokenizer:
|
|
139
139
|
"""
|
|
140
|
-
|
|
140
|
+
加载用于文本处理的分词器。
|
|
141
141
|
|
|
142
|
-
|
|
143
|
-
AutoTokenizer:
|
|
142
|
+
返回:
|
|
143
|
+
AutoTokenizer: 加载的分词器
|
|
144
144
|
"""
|
|
145
145
|
model_name = "gpt2"
|
|
146
146
|
cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
|
|
@@ -159,12 +159,13 @@ def load_tokenizer() -> AutoTokenizer:
|
|
|
159
159
|
)
|
|
160
160
|
|
|
161
161
|
return tokenizer # type: ignore
|
|
162
|
+
|
|
162
163
|
def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
|
|
163
164
|
"""
|
|
164
|
-
|
|
165
|
+
加载重排序模型和分词器。
|
|
165
166
|
|
|
166
|
-
|
|
167
|
-
Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
|
|
167
|
+
返回:
|
|
168
|
+
Tuple[AutoModelForSequenceClassification, AutoTokenizer]: 加载的模型和分词器
|
|
168
169
|
"""
|
|
169
170
|
model_name = "BAAI/bge-reranker-v2-m3"
|
|
170
171
|
cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
|
|
@@ -198,4 +199,4 @@ def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokeniz
|
|
|
198
199
|
model = model.cuda()
|
|
199
200
|
model.eval()
|
|
200
201
|
|
|
201
|
-
return model, tokenizer # type: ignore
|
|
202
|
+
return model, tokenizer # type: ignore
|
jarvis/jarvis_utils/git_utils.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Git
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
-
|
|
6
|
-
-
|
|
7
|
-
-
|
|
8
|
-
-
|
|
9
|
-
-
|
|
2
|
+
Git工具模块
|
|
3
|
+
该模块提供了与Git仓库交互的工具。
|
|
4
|
+
包含以下功能:
|
|
5
|
+
- 查找Git仓库的根目录
|
|
6
|
+
- 检查是否有未提交的更改
|
|
7
|
+
- 获取两个哈希值之间的提交历史
|
|
8
|
+
- 获取最新提交的哈希值
|
|
9
|
+
- 从Git差异中提取修改的行范围
|
|
10
10
|
"""
|
|
11
11
|
import os
|
|
12
12
|
import re
|
|
@@ -14,42 +14,42 @@ import subprocess
|
|
|
14
14
|
from typing import List, Tuple, Dict
|
|
15
15
|
from jarvis.jarvis_utils.output import PrettyOutput, OutputType
|
|
16
16
|
def find_git_root(start_dir="."):
|
|
17
|
-
"""
|
|
17
|
+
"""切换到给定路径的Git根目录"""
|
|
18
18
|
os.chdir(start_dir)
|
|
19
19
|
git_root = os.popen("git rev-parse --show-toplevel").read().strip()
|
|
20
20
|
os.chdir(git_root)
|
|
21
21
|
return git_root
|
|
22
22
|
def has_uncommitted_changes():
|
|
23
|
-
"""
|
|
24
|
-
#
|
|
23
|
+
"""检查Git仓库中是否有未提交的更改"""
|
|
24
|
+
# 静默添加所有更改
|
|
25
25
|
subprocess.run(["git", "add", "."], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
|
26
26
|
|
|
27
|
-
#
|
|
27
|
+
# 检查工作目录更改
|
|
28
28
|
working_changes = subprocess.run(["git", "diff", "--exit-code"],
|
|
29
29
|
stdout=subprocess.DEVNULL,
|
|
30
30
|
stderr=subprocess.DEVNULL).returncode != 0
|
|
31
31
|
|
|
32
|
-
#
|
|
32
|
+
# 检查暂存区更改
|
|
33
33
|
staged_changes = subprocess.run(["git", "diff", "--cached", "--exit-code"],
|
|
34
34
|
stdout=subprocess.DEVNULL,
|
|
35
35
|
stderr=subprocess.DEVNULL).returncode != 0
|
|
36
36
|
|
|
37
|
-
#
|
|
37
|
+
# 静默重置更改
|
|
38
38
|
subprocess.run(["git", "reset"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
|
39
39
|
|
|
40
40
|
return working_changes or staged_changes
|
|
41
41
|
def get_commits_between(start_hash: str, end_hash: str) -> List[Tuple[str, str]]:
|
|
42
|
-
"""
|
|
42
|
+
"""获取两个提交哈希值之间的提交列表
|
|
43
43
|
|
|
44
|
-
|
|
45
|
-
start_hash:
|
|
46
|
-
end_hash:
|
|
44
|
+
参数:
|
|
45
|
+
start_hash: 起始提交哈希值(不包含)
|
|
46
|
+
end_hash: 结束提交哈希值(包含)
|
|
47
47
|
|
|
48
|
-
|
|
49
|
-
List[Tuple[str, str]]:
|
|
48
|
+
返回:
|
|
49
|
+
List[Tuple[str, str]]: (提交哈希值, 提交信息) 元组列表
|
|
50
50
|
"""
|
|
51
51
|
try:
|
|
52
|
-
#
|
|
52
|
+
# 使用git log和pretty格式获取哈希值和信息
|
|
53
53
|
result = subprocess.run(
|
|
54
54
|
['git', 'log', f'{start_hash}..{end_hash}', '--pretty=format:%H|%s'],
|
|
55
55
|
stdout=subprocess.PIPE,
|
|
@@ -71,10 +71,10 @@ def get_commits_between(start_hash: str, end_hash: str) -> List[Tuple[str, str]]
|
|
|
71
71
|
PrettyOutput.print(f"获取commit历史异常: {str(e)}", OutputType.ERROR)
|
|
72
72
|
return []
|
|
73
73
|
def get_latest_commit_hash() -> str:
|
|
74
|
-
"""
|
|
74
|
+
"""获取当前Git仓库的最新提交哈希值
|
|
75
75
|
|
|
76
|
-
|
|
77
|
-
str:
|
|
76
|
+
返回:
|
|
77
|
+
str: 提交哈希值,如果不在Git仓库或发生错误则返回空字符串
|
|
78
78
|
"""
|
|
79
79
|
try:
|
|
80
80
|
result = subprocess.run(
|
|
@@ -89,32 +89,32 @@ def get_latest_commit_hash() -> str:
|
|
|
89
89
|
except Exception:
|
|
90
90
|
return ""
|
|
91
91
|
def get_modified_line_ranges() -> Dict[str, Tuple[int, int]]:
|
|
92
|
-
"""
|
|
92
|
+
"""从Git差异中获取所有更改文件的修改行范围
|
|
93
93
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
94
|
+
返回:
|
|
95
|
+
字典,将文件路径映射到包含修改部分的(起始行, 结束行)范围元组。
|
|
96
|
+
行号从1开始。
|
|
97
97
|
"""
|
|
98
|
-
#
|
|
98
|
+
# 获取所有文件的Git差异
|
|
99
99
|
diff_output = os.popen("git show").read()
|
|
100
100
|
|
|
101
|
-
#
|
|
101
|
+
# 解析差异以获取修改的文件及其行范围
|
|
102
102
|
result = {}
|
|
103
103
|
current_file = None
|
|
104
104
|
|
|
105
105
|
for line in diff_output.splitlines():
|
|
106
|
-
#
|
|
106
|
+
# 匹配类似"+++ b/path/to/file"的行
|
|
107
107
|
file_match = re.match(r"^\+\+\+ b/(.*)", line)
|
|
108
108
|
if file_match:
|
|
109
109
|
current_file = file_match.group(1)
|
|
110
110
|
continue
|
|
111
111
|
|
|
112
|
-
#
|
|
112
|
+
# 匹配类似"@@ -100,5 +100,7 @@"的行,其中+部分显示新行
|
|
113
113
|
range_match = re.match(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@", line)
|
|
114
114
|
if range_match and current_file:
|
|
115
|
-
start_line = int(range_match.group(1)) #
|
|
115
|
+
start_line = int(range_match.group(1)) # 保持从1开始
|
|
116
116
|
line_count = int(range_match.group(2)) if range_match.group(2) else 1
|
|
117
117
|
end_line = start_line + line_count - 1
|
|
118
118
|
result[current_file] = (start_line, end_line)
|
|
119
119
|
|
|
120
|
-
return result
|
|
120
|
+
return result
|
jarvis/jarvis_utils/globals.py
CHANGED
|
@@ -1,24 +1,24 @@
|
|
|
1
1
|
"""
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
-
|
|
6
|
-
-
|
|
7
|
-
-
|
|
2
|
+
全局变量和配置模块
|
|
3
|
+
该模块管理Jarvis系统的全局状态和配置。
|
|
4
|
+
包含:
|
|
5
|
+
- 全局代理管理
|
|
6
|
+
- 带有自定义主题的控制台配置
|
|
7
|
+
- 环境初始化
|
|
8
8
|
"""
|
|
9
9
|
from typing import Any, Set
|
|
10
10
|
import colorama
|
|
11
11
|
import os
|
|
12
12
|
from rich.console import Console
|
|
13
13
|
from rich.theme import Theme
|
|
14
|
-
#
|
|
14
|
+
# 初始化colorama以支持跨平台的彩色文本
|
|
15
15
|
colorama.init()
|
|
16
|
-
#
|
|
16
|
+
# 禁用tokenizers并行以避免多进程问题
|
|
17
17
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
18
|
-
#
|
|
18
|
+
# 全局代理管理
|
|
19
19
|
global_agents: Set[str] = set()
|
|
20
20
|
current_agent_name: str = ""
|
|
21
|
-
#
|
|
21
|
+
# 使用自定义主题配置rich控制台
|
|
22
22
|
custom_theme = Theme({
|
|
23
23
|
"INFO": "yellow",
|
|
24
24
|
"WARNING": "yellow",
|
|
@@ -36,13 +36,13 @@ custom_theme = Theme({
|
|
|
36
36
|
console = Console(theme=custom_theme)
|
|
37
37
|
def make_agent_name(agent_name: str) -> str:
|
|
38
38
|
"""
|
|
39
|
-
|
|
39
|
+
通过附加后缀生成唯一的代理名称(如果必要)。
|
|
40
40
|
|
|
41
|
-
|
|
42
|
-
agent_name:
|
|
41
|
+
参数:
|
|
42
|
+
agent_name: 基础代理名称
|
|
43
43
|
|
|
44
|
-
|
|
45
|
-
str:
|
|
44
|
+
返回:
|
|
45
|
+
str: 唯一的代理名称
|
|
46
46
|
"""
|
|
47
47
|
if agent_name in global_agents:
|
|
48
48
|
i = 1
|
|
@@ -52,31 +52,31 @@ def make_agent_name(agent_name: str) -> str:
|
|
|
52
52
|
return agent_name
|
|
53
53
|
def set_agent(agent_name: str, agent: Any) -> None:
|
|
54
54
|
"""
|
|
55
|
-
|
|
55
|
+
设置当前代理并将其添加到全局代理集合中。
|
|
56
56
|
|
|
57
|
-
|
|
58
|
-
agent_name:
|
|
59
|
-
agent:
|
|
57
|
+
参数:
|
|
58
|
+
agent_name: 代理名称
|
|
59
|
+
agent: 代理对象
|
|
60
60
|
"""
|
|
61
61
|
global_agents.add(agent_name)
|
|
62
62
|
global current_agent_name
|
|
63
63
|
current_agent_name = agent_name
|
|
64
64
|
def get_agent_list() -> str:
|
|
65
65
|
"""
|
|
66
|
-
|
|
66
|
+
获取表示当前代理状态的格式化字符串。
|
|
67
67
|
|
|
68
|
-
|
|
69
|
-
str:
|
|
68
|
+
返回:
|
|
69
|
+
str: 包含代理数量和当前代理名称的格式化字符串
|
|
70
70
|
"""
|
|
71
71
|
return "[" + str(len(global_agents)) + "]" + current_agent_name if global_agents else ""
|
|
72
72
|
def delete_agent(agent_name: str) -> None:
|
|
73
73
|
"""
|
|
74
|
-
|
|
74
|
+
从全局代理集合中删除一个代理。
|
|
75
75
|
|
|
76
|
-
|
|
77
|
-
agent_name:
|
|
76
|
+
参数:
|
|
77
|
+
agent_name: 要删除的代理名称
|
|
78
78
|
"""
|
|
79
79
|
if agent_name in global_agents:
|
|
80
80
|
global_agents.remove(agent_name)
|
|
81
81
|
global current_agent_name
|
|
82
|
-
current_agent_name = ""
|
|
82
|
+
current_agent_name = ""
|
jarvis/jarvis_utils/input.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
"""
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
-
|
|
6
|
-
-
|
|
7
|
-
-
|
|
8
|
-
-
|
|
2
|
+
输入处理模块
|
|
3
|
+
该模块提供了处理Jarvis系统中用户输入的实用工具。
|
|
4
|
+
包含:
|
|
5
|
+
- 支持历史记录的单行输入
|
|
6
|
+
- 增强补全功能的多行输入
|
|
7
|
+
- 带有模糊匹配的文件路径补全
|
|
8
|
+
- 用于输入控制的自定义键绑定
|
|
9
9
|
"""
|
|
10
10
|
from prompt_toolkit import PromptSession
|
|
11
11
|
from prompt_toolkit.styles import Style as PromptStyle
|
|
@@ -18,13 +18,13 @@ from colorama import Fore, Style as ColoramaStyle
|
|
|
18
18
|
from ..jarvis_utils.output import PrettyOutput, OutputType
|
|
19
19
|
def get_single_line_input(tip: str) -> str:
|
|
20
20
|
"""
|
|
21
|
-
|
|
21
|
+
获取支持历史记录的单行输入。
|
|
22
22
|
|
|
23
|
-
|
|
24
|
-
tip:
|
|
23
|
+
参数:
|
|
24
|
+
tip: 要显示的提示信息
|
|
25
25
|
|
|
26
|
-
|
|
27
|
-
str:
|
|
26
|
+
返回:
|
|
27
|
+
str: 用户的输入
|
|
28
28
|
"""
|
|
29
29
|
session = PromptSession(history=None)
|
|
30
30
|
style = PromptStyle.from_dict({
|
|
@@ -33,49 +33,65 @@ def get_single_line_input(tip: str) -> str:
|
|
|
33
33
|
return session.prompt(f"{tip}", style=style)
|
|
34
34
|
class FileCompleter(Completer):
|
|
35
35
|
"""
|
|
36
|
-
|
|
36
|
+
带有模糊匹配的文件路径自定义补全器。
|
|
37
37
|
|
|
38
|
-
|
|
39
|
-
path_completer:
|
|
40
|
-
max_suggestions:
|
|
41
|
-
min_score:
|
|
38
|
+
属性:
|
|
39
|
+
path_completer: 基础路径补全器
|
|
40
|
+
max_suggestions: 显示的最大建议数量
|
|
41
|
+
min_score: 建议的最小匹配分数
|
|
42
42
|
"""
|
|
43
43
|
def __init__(self):
|
|
44
|
-
"""
|
|
44
|
+
"""使用默认设置初始化文件补全器。"""
|
|
45
45
|
self.path_completer = PathCompleter()
|
|
46
46
|
self.max_suggestions = 10
|
|
47
47
|
self.min_score = 10
|
|
48
48
|
def get_completions(self, document: Document, complete_event) -> Completion: # type: ignore
|
|
49
49
|
"""
|
|
50
|
-
|
|
50
|
+
生成带有模糊匹配的文件路径补全建议。
|
|
51
51
|
|
|
52
|
-
|
|
53
|
-
document:
|
|
54
|
-
complete_event:
|
|
52
|
+
参数:
|
|
53
|
+
document: 当前正在编辑的文档
|
|
54
|
+
complete_event: 补全事件
|
|
55
55
|
|
|
56
|
-
|
|
57
|
-
Completion:
|
|
56
|
+
生成:
|
|
57
|
+
Completion: 建议的补全项
|
|
58
58
|
"""
|
|
59
59
|
text = document.text_before_cursor
|
|
60
60
|
cursor_pos = document.cursor_position
|
|
61
|
-
#
|
|
61
|
+
# 查找文本中的所有@位置
|
|
62
62
|
at_positions = [i for i, char in enumerate(text) if char == '@']
|
|
63
63
|
if not at_positions:
|
|
64
64
|
return
|
|
65
|
-
#
|
|
65
|
+
# 获取最后一个@位置
|
|
66
66
|
current_at_pos = at_positions[-1]
|
|
67
|
-
#
|
|
67
|
+
# 如果光标不在最后一个@之后,则不补全
|
|
68
68
|
if cursor_pos <= current_at_pos:
|
|
69
69
|
return
|
|
70
|
-
#
|
|
70
|
+
# 检查@之后是否有空格
|
|
71
71
|
text_after_at = text[current_at_pos + 1:cursor_pos]
|
|
72
72
|
if ' ' in text_after_at:
|
|
73
73
|
return
|
|
74
|
-
#
|
|
74
|
+
# 添加默认建议
|
|
75
|
+
if not text_after_at.strip():
|
|
76
|
+
# 默认建议列表
|
|
77
|
+
default_suggestions = [
|
|
78
|
+
('<CodeBase>', '查询代码库'),
|
|
79
|
+
('<Web>', '网页搜索'),
|
|
80
|
+
('<RAG>', '知识库检索')
|
|
81
|
+
]
|
|
82
|
+
for name, desc in default_suggestions:
|
|
83
|
+
yield Completion(
|
|
84
|
+
text=f"'{name}'",
|
|
85
|
+
start_position=-1,
|
|
86
|
+
display=name,
|
|
87
|
+
display_meta=desc
|
|
88
|
+
) # type: ignore
|
|
89
|
+
return
|
|
90
|
+
# 获取当前@之后的文本
|
|
75
91
|
file_path = text_after_at.strip()
|
|
76
|
-
#
|
|
92
|
+
# 计算替换长度
|
|
77
93
|
replace_length = len(text_after_at) + 1
|
|
78
|
-
#
|
|
94
|
+
# 使用git ls-files获取所有可能的文件
|
|
79
95
|
all_files = []
|
|
80
96
|
try:
|
|
81
97
|
import subprocess
|
|
@@ -87,7 +103,7 @@ class FileCompleter(Completer):
|
|
|
87
103
|
all_files = [line.strip() for line in result.stdout.splitlines() if line.strip()]
|
|
88
104
|
except Exception:
|
|
89
105
|
pass
|
|
90
|
-
#
|
|
106
|
+
# 生成补全建议
|
|
91
107
|
if not file_path:
|
|
92
108
|
scored_files = [(path, 100) for path in all_files[:self.max_suggestions]]
|
|
93
109
|
else:
|
|
@@ -95,7 +111,7 @@ class FileCompleter(Completer):
|
|
|
95
111
|
scored_files = [(m[0], m[1]) for m in scored_files_data]
|
|
96
112
|
scored_files.sort(key=lambda x: x[1], reverse=True)
|
|
97
113
|
scored_files = scored_files[:self.max_suggestions]
|
|
98
|
-
#
|
|
114
|
+
# 生成补全项
|
|
99
115
|
for path, score in scored_files:
|
|
100
116
|
if not file_path or score > self.min_score:
|
|
101
117
|
display_text = path
|
|
@@ -109,31 +125,31 @@ class FileCompleter(Completer):
|
|
|
109
125
|
) # type: ignore
|
|
110
126
|
def get_multiline_input(tip: str) -> str:
|
|
111
127
|
"""
|
|
112
|
-
|
|
128
|
+
获取带有增强补全和确认功能的多行输入。
|
|
113
129
|
|
|
114
|
-
|
|
115
|
-
tip:
|
|
130
|
+
参数:
|
|
131
|
+
tip: 要显示的提示信息
|
|
116
132
|
|
|
117
|
-
|
|
118
|
-
str:
|
|
133
|
+
返回:
|
|
134
|
+
str: 用户的输入,如果取消则返回空字符串
|
|
119
135
|
"""
|
|
120
|
-
#
|
|
136
|
+
# 显示输入说明
|
|
121
137
|
PrettyOutput.section("用户输入 - 使用 @ 触发文件补全,Tab 选择补全项,Ctrl+J 提交,按 Ctrl+C 取消输入", OutputType.USER)
|
|
122
138
|
print(f"{Fore.GREEN}{tip}{ColoramaStyle.RESET_ALL}")
|
|
123
|
-
#
|
|
139
|
+
# 配置键绑定
|
|
124
140
|
bindings = KeyBindings()
|
|
125
141
|
@bindings.add('enter')
|
|
126
142
|
def _(event):
|
|
127
|
-
"""
|
|
143
|
+
"""处理回车键以进行补全或换行。"""
|
|
128
144
|
if event.current_buffer.complete_state:
|
|
129
145
|
event.current_buffer.apply_completion(event.current_buffer.complete_state.current_completion)
|
|
130
146
|
else:
|
|
131
147
|
event.current_buffer.insert_text('\n')
|
|
132
148
|
@bindings.add('c-j')
|
|
133
149
|
def _(event):
|
|
134
|
-
"""
|
|
150
|
+
"""处理Ctrl+J以提交输入。"""
|
|
135
151
|
event.current_buffer.validate_and_handle()
|
|
136
|
-
#
|
|
152
|
+
# 配置提示会话
|
|
137
153
|
style = PromptStyle.from_dict({
|
|
138
154
|
'prompt': 'ansicyan',
|
|
139
155
|
})
|
|
@@ -150,7 +166,7 @@ def get_multiline_input(tip: str) -> str:
|
|
|
150
166
|
prompt = FormattedText([
|
|
151
167
|
('class:prompt', '>>> ')
|
|
152
168
|
])
|
|
153
|
-
#
|
|
169
|
+
# 获取输入
|
|
154
170
|
text = session.prompt(
|
|
155
171
|
prompt,
|
|
156
172
|
style=style,
|
|
@@ -158,4 +174,4 @@ def get_multiline_input(tip: str) -> str:
|
|
|
158
174
|
return text
|
|
159
175
|
except KeyboardInterrupt:
|
|
160
176
|
PrettyOutput.print("输入已取消", OutputType.INFO)
|
|
161
|
-
return ""
|
|
177
|
+
return ""
|