jarvis-ai-assistant 0.1.130__py3-none-any.whl → 0.1.132__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +71 -38
- jarvis/jarvis_agent/builtin_input_handler.py +73 -0
- jarvis/{jarvis_code_agent → jarvis_agent}/file_input_handler.py +1 -1
- jarvis/jarvis_agent/main.py +1 -1
- jarvis/{jarvis_code_agent → jarvis_agent}/patch.py +77 -55
- jarvis/{jarvis_code_agent → jarvis_agent}/shell_input_handler.py +1 -2
- jarvis/jarvis_code_agent/code_agent.py +93 -88
- jarvis/jarvis_dev/main.py +335 -626
- jarvis/jarvis_git_squash/main.py +11 -32
- jarvis/jarvis_lsp/base.py +2 -26
- jarvis/jarvis_lsp/cpp.py +2 -14
- jarvis/jarvis_lsp/go.py +0 -13
- jarvis/jarvis_lsp/python.py +1 -30
- jarvis/jarvis_lsp/registry.py +10 -14
- jarvis/jarvis_lsp/rust.py +0 -12
- jarvis/jarvis_multi_agent/__init__.py +20 -29
- jarvis/jarvis_platform/ai8.py +7 -32
- jarvis/jarvis_platform/base.py +2 -7
- jarvis/jarvis_platform/kimi.py +3 -144
- jarvis/jarvis_platform/ollama.py +54 -68
- jarvis/jarvis_platform/openai.py +0 -4
- jarvis/jarvis_platform/oyi.py +0 -75
- jarvis/jarvis_platform/registry.py +1 -1
- jarvis/jarvis_platform/yuanbao.py +264 -0
- jarvis/jarvis_platform_manager/main.py +3 -3
- jarvis/jarvis_rag/file_processors.py +138 -0
- jarvis/jarvis_rag/main.py +1305 -425
- jarvis/jarvis_tools/ask_codebase.py +227 -41
- jarvis/jarvis_tools/code_review.py +229 -166
- jarvis/jarvis_tools/create_code_agent.py +76 -72
- jarvis/jarvis_tools/create_sub_agent.py +32 -15
- jarvis/jarvis_tools/execute_python_script.py +58 -0
- jarvis/jarvis_tools/execute_shell.py +15 -28
- jarvis/jarvis_tools/execute_shell_script.py +2 -2
- jarvis/jarvis_tools/file_analyzer.py +271 -0
- jarvis/jarvis_tools/file_operation.py +3 -3
- jarvis/jarvis_tools/find_caller.py +213 -0
- jarvis/jarvis_tools/find_symbol.py +211 -0
- jarvis/jarvis_tools/function_analyzer.py +248 -0
- jarvis/jarvis_tools/git_commiter.py +89 -70
- jarvis/jarvis_tools/lsp_find_definition.py +83 -67
- jarvis/jarvis_tools/lsp_find_references.py +62 -46
- jarvis/jarvis_tools/lsp_get_diagnostics.py +90 -74
- jarvis/jarvis_tools/methodology.py +89 -48
- jarvis/jarvis_tools/project_analyzer.py +220 -0
- jarvis/jarvis_tools/read_code.py +24 -3
- jarvis/jarvis_tools/read_webpage.py +195 -81
- jarvis/jarvis_tools/registry.py +132 -11
- jarvis/jarvis_tools/search_web.py +73 -30
- jarvis/jarvis_tools/tool_generator.py +7 -9
- jarvis/jarvis_utils/__init__.py +1 -0
- jarvis/jarvis_utils/config.py +67 -3
- jarvis/jarvis_utils/embedding.py +344 -45
- jarvis/jarvis_utils/git_utils.py +18 -2
- jarvis/jarvis_utils/input.py +7 -4
- jarvis/jarvis_utils/methodology.py +379 -7
- jarvis/jarvis_utils/output.py +5 -3
- jarvis/jarvis_utils/utils.py +62 -10
- {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/METADATA +3 -4
- jarvis_ai_assistant-0.1.132.dist-info/RECORD +82 -0
- {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/entry_points.txt +2 -0
- jarvis/jarvis_c2rust/c2rust.yaml +0 -734
- jarvis/jarvis_code_agent/builtin_input_handler.py +0 -43
- jarvis/jarvis_codebase/__init__.py +0 -0
- jarvis/jarvis_codebase/main.py +0 -1011
- jarvis/jarvis_tools/lsp_get_document_symbols.py +0 -87
- jarvis/jarvis_tools/lsp_prepare_rename.py +0 -130
- jarvis_ai_assistant-0.1.130.dist-info/RECORD +0 -79
- {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/top_level.txt +0 -0
jarvis/jarvis_utils/embedding.py
CHANGED
@@ -3,9 +3,16 @@ import numpy as np
 import torch
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
-from typing import List, Any, Tuple
+from typing import List, Any, Optional, Tuple
+import functools
+
+from yaspin.api import Yaspin
 from jarvis.jarvis_utils.output import PrettyOutput, OutputType

+# Global caches to avoid reloading models
+_global_models = {}
+_global_tokenizers = {}
+
 def get_context_token_count(text: str) -> int:
     """Get the token count of a text using the tokenizer.

@@ -26,9 +33,10 @@ def get_context_token_count(text: str) -> int:
         # Fall back to a rough character-based estimate
         return len(text) // 4  # rough estimate of about 4 characters per token

+@functools.lru_cache(maxsize=1)
 def load_embedding_model() -> SentenceTransformer:
     """
-
+    Load the sentence embedding model, using a cache to avoid reloading.

     Returns:
         SentenceTransformer: the loaded embedding model
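The new `@functools.lru_cache(maxsize=1)` decorator on a zero-argument loader effectively turns it into a singleton accessor. A minimal illustration of that pattern (not code from the package):

    import functools

    @functools.lru_cache(maxsize=1)
    def load_once() -> dict:
        # The body runs only on the first call; later calls return the cached result.
        return {"loaded": True}

    assert load_once() is load_once()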
@@ -36,6 +44,10 @@ def load_embedding_model() -> SentenceTransformer:
     model_name = "BAAI/bge-m3"
     cache_dir = os.path.expanduser("~/.cache/huggingface/hub")

+    # Check the global cache for an already-loaded model
+    if model_name in _global_models:
+        return _global_models[model_name]
+
     try:
         embedding_model = SentenceTransformer(
             model_name,
@@ -49,6 +61,13 @@ def load_embedding_model() -> SentenceTransformer:
             local_files_only=False
         )

+        # Move the model to the GPU if available
+        if torch.cuda.is_available():
+            embedding_model.to(torch.device("cuda"))
+
+        # Save to the global cache
+        _global_models[model_name] = embedding_model
+
         return embedding_model

 def get_embedding(embedding_model: Any, text: str) -> np.ndarray:
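Taken together, the loader is now memoized twice: `lru_cache` guards the call itself and `_global_models` caches the model by name, with a GPU move when one is available. A minimal usage sketch, assuming the package is installed:

    from jarvis.jarvis_utils.embedding import load_embedding_model

    # First call loads BAAI/bge-m3 (and moves it to the GPU when available).
    model = load_embedding_model()

    # Repeated calls return the same cached instance instead of reloading.
    assert load_embedding_model() is model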
@@ -67,77 +86,249 @@ def get_embedding(embedding_model: Any, text: str) -> np.ndarray:
                                       show_progress_bar=False)
     return np.array(embedding, dtype=np.float32)

-def get_embedding_batch(embedding_model: Any, texts: List[str]) -> np.ndarray:
+def get_embedding_batch(embedding_model: Any, prefix: str, texts: List[str], spinner: Optional[Yaspin] = None, batch_size: int = 8) -> np.ndarray:
     """
-
+    Generate embeddings for a batch of texts with efficient batching, optimized for RAG.

     Args:
         embedding_model: the embedding model to use
+        prefix: prefix for the progress indicator
         texts: list of texts to embed
+        spinner: optional progress indicator
+        batch_size: batch size; larger values can be faster but use more memory

     Returns:
         np.ndarray: the stacked embedding vectors
     """
+    # Simple embedding cache to avoid recomputing identical text chunks
+    embedding_cache = {}
+    cache_hits = 0
+
     try:
+        # Preprocessing: split every text into chunks
+        all_chunks = []
+        chunk_indices = []  # track the chunk index range of each original text
+
+        for i, text in enumerate(texts):
+            if spinner:
+                spinner.text = f"{prefix} 预处理中 ({i+1}/{len(texts)}) ..."
+
+            # Preprocess the text: strip extra whitespace, normalize
+            text = ' '.join(text.split()) if text else ""
+
+            # Use the more optimized chunking function
+            chunks = split_text_into_chunks(text, 512)
+            start_idx = len(all_chunks)
+            all_chunks.extend(chunks)
+            end_idx = len(all_chunks)
+            chunk_indices.append((start_idx, end_idx))
+
+        if not all_chunks:
+            return np.zeros((0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32)
+
+        # Process all chunks in batches
         all_vectors = []
-        for
-
-
-
+        for i in range(0, len(all_chunks), batch_size):
+            if spinner:
+                spinner.text = f"{prefix} 批量处理嵌入 ({i+1}/{len(all_chunks)}) ..."
+
+            batch = all_chunks[i:i+batch_size]
+            batch_to_process = []
+            batch_indices = []
+
+            # Check the cache to avoid recomputation
+            for j, chunk in enumerate(batch):
+                chunk_hash = hash(chunk)
+                if chunk_hash in embedding_cache:
+                    all_vectors.append(embedding_cache[chunk_hash])
+                    cache_hits += 1
+                else:
+                    batch_to_process.append(chunk)
+                    batch_indices.append(j)
+
+            if batch_to_process:
+                # Encode the chunks that were not cached
+                batch_vectors = embedding_model.encode(
+                    batch_to_process,
+                    normalize_embeddings=True,
+                    show_progress_bar=False,
+                    convert_to_numpy=True,
+                )
+
+                # Store the results and update the cache
+                if len(batch_to_process) == 1:
+                    vec = batch_vectors
+                    chunk_hash = hash(batch_to_process[0])
+                    embedding_cache[chunk_hash] = vec
+                    all_vectors.append(vec)
+                else:
+                    for j, vec in enumerate(batch_vectors):
+                        chunk_hash = hash(batch_to_process[j])
+                        embedding_cache[chunk_hash] = vec
+                        all_vectors.append(vec)
+
+        # Reassemble the results in the original text order
+        result_vectors = []
+        for start_idx, end_idx in chunk_indices:
+            text_vectors = []
+            for j in range(start_idx, end_idx):
+                if j < len(all_vectors):
+                    text_vectors.append(all_vectors[j])
+
+            if text_vectors:
+                # When a text was split into several chunks, take a weighted average
+                if len(text_vectors) > 1:
+                    # RAG optimization: weighted average with slightly higher weights for earlier chunks
+                    weights = np.linspace(1.0, 0.8, len(text_vectors))
+                    weights = weights / weights.sum()  # normalize the weights
+
+                    # Apply the weights and sum
+                    weighted_sum = np.zeros_like(text_vectors[0])
+                    for i, vec in enumerate(text_vectors):
+                        # Keep vector shapes consistent, handling possible dimension mismatches
+                        vec_array = np.asarray(vec).reshape(weighted_sum.shape)
+                        weighted_sum += vec_array * weights[i]
+
+                    # Normalize the resulting vector
+                    norm = np.linalg.norm(weighted_sum)
+                    if norm > 0:
+                        weighted_sum = weighted_sum / norm
+
+                    result_vectors.append(weighted_sum)
+                else:
+                    # A single chunk is used directly
+                    result_vectors.append(text_vectors[0])
+
+        if spinner and cache_hits > 0:
+            spinner.text = f"{prefix} 缓存命中: {cache_hits}/{len(all_chunks)} 块"
+
+        return np.vstack(result_vectors)
+
     except Exception as e:
         PrettyOutput.print(f"批量嵌入失败: {str(e)}", OutputType.ERROR)
         return np.zeros((0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32)

-def split_text_into_chunks(text: str, max_length: int = 512) -> List[str]:
-    """
+def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 50) -> List[str]:
+    """Split text into chunks with overlapping windows, optimized for RAG retrieval.

     Args:
         text: the input text to split
         max_length: maximum length of each chunk
+        min_length: minimum length of each chunk (only the last chunk may be shorter)

     Returns:
-        List[str]:
+        List[str]: list of chunks, each as close to max_length as possible without exceeding it
     """
+    if not text:
+        return []
+
+    # If the text fits within the maximum length, return it whole
+    if len(text) <= max_length:
+        return [text]
+
+    # Preprocessing: normalize the text and strip extra whitespace
+    text = ' '.join(text.split())
+
+    # Chinese and English punctuation sets for sentence boundaries that improve RAG recall
+    primary_punctuation = {'.', '!', '?', '\n', '。', '!', '?'}  # primary sentence-ending punctuation
+    secondary_punctuation = {';', ':', '…', ';', ':'}  # secondary separators
+    tertiary_punctuation = {',', ',', '、', ')', ')', ']', '】', '}', '》', '"', "'"}  # lowest priority
+
     chunks = []
     start = 0
-    while start < len(text):
-        end = start + max_length
-        # Find the nearest sentence boundary
-        if end < len(text):
-            while end > start and text[end] not in {'.', '!', '?', '\n'}:
-                end -= 1
-            if end == start:  # no punctuation found, force a split
-                end = start + max_length
-        chunk = text[start:end]
-        chunks.append(chunk)
-        # 20% overlapping window
-        start = end - int(max_length * 0.2)
-    return chunks
-
-def get_embedding_with_chunks(embedding_model: Any, text: str) -> List[np.ndarray]:
-    """
-    Generate embedding vectors for the chunks of a text.

-
-
-        text
+    while start < len(text):
+        # Initialize the end position to the maximum possible length
+        end = min(start + max_length, len(text))

-
-
-
-
-
-
+        # Only look for a sentence boundary when this is not the last chunk and end hit max_length
+        if end < len(text) and end == start + max_length:
+            # Prefer paragraph boundaries, which matter most for RAG
+            paragraph_boundary = text.rfind('\n\n', start, end)
+            if paragraph_boundary > start and paragraph_boundary < end - min_length:  # avoid cutting the chunk too short
+                end = paragraph_boundary + 2
+            else:
+                # Look for a sentence boundary, starting from end-1
+                found_boundary = False
+                best_boundary = -1
+
+                # Widen the search range to find a better semantic boundary
+                search_range = min(120, end - start - min_length)  # wider search, but keep the new chunk at least min_length
+
+                # First try primary punctuation (periods and the like)
+                for i in range(end-1, max(start, end-search_range), -1):
+                    if text[i] in primary_punctuation:
+                        best_boundary = i
+                        found_boundary = True
+                        break
+
+                # If no primary punctuation was found, try secondary punctuation (semicolons, colons, ...)
+                if not found_boundary:
+                    for i in range(end-1, max(start, end-search_range), -1):
+                        if text[i] in secondary_punctuation:
+                            best_boundary = i
+                            found_boundary = True
+                            break
+
+                # Finally consider commas and other possible boundaries
+                if not found_boundary:
+                    for i in range(end-1, max(start, end-search_range), -1):
+                        if text[i] in tertiary_punctuation:
+                            best_boundary = i
+                            found_boundary = True
+                            break
+
+                # Use the boundary if one was found and it does not produce a chunk that is too short
+                if found_boundary and (best_boundary - start) >= min_length:
+                    end = best_boundary + 1
+
+        # Add the current chunk, stripping leading and trailing whitespace
+        chunk = text[start:end].strip()
+        if chunk and len(chunk) >= min_length:  # only add non-empty chunks that meet the minimum length
+            chunks.append(chunk)
+        elif chunk and not chunks:  # also add the first chunk even if it is below the minimum length
+            chunks.append(chunk)
+        elif chunk:  # if the chunk is too small, try merging it with the previous one
+            if chunks:
+                if len(chunks[-1]) + len(chunk) <= max_length * 1.1:  # allow slightly exceeding the maximum length
+                    chunks[-1] = chunks[-1] + " " + chunk
+                else:
+                    # if merging would make it too long, add the small chunk on its own (edge case)
+                    chunks.append(chunk)
+
+        # Compute the next start position, tuning the overlap window for better RAG retrieval
+        next_start = end - int(max_length * 0.2)  # 20% overlap window
+
+        # Always make forward progress to avoid an infinite loop
+        if next_start <= start:
+            next_start = start + max(1, min_length // 2)
+
+        start = next_start
+
+    # Finally, check for chunks that are too short and merge adjacent short chunks
+    if len(chunks) > 1:
+        merged_chunks = []
+        i = 0
+        while i < len(chunks):
+            current = chunks[i]
+            # if the current chunk is too short and not the last one, try merging with the next
+            if len(current) < min_length and i < len(chunks) - 1:
+                next_chunk = chunks[i + 1]
+                if len(current) + len(next_chunk) <= max_length * 1.1:
+                    merged_chunks.append(current + " " + next_chunk)
+                    i += 2  # skip the next chunk
+                    continue
+            merged_chunks.append(current)
+            i += 1
+        chunks = merged_chunks

-
-
-        vector = get_embedding(embedding_model, chunk)
-        vectors.append(vector)
-    return vectors
+    return chunks
+

+@functools.lru_cache(maxsize=1)
 def load_tokenizer() -> AutoTokenizer:
     """
-
+    Load the tokenizer used for text processing, using a cache to avoid reloading.

     Returns:
         AutoTokenizer: the loaded tokenizer
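A hedged sketch of how the reworked batch API might be driven; the `prefix` and `spinner` parameters are new in this version, and the sample texts are illustrative:

    from yaspin import yaspin

    from jarvis.jarvis_utils.embedding import (
        get_embedding_batch,
        load_embedding_model,
        split_text_into_chunks,
    )

    model = load_embedding_model()
    texts = ["a short note", "a long document " * 200]

    # Texts longer than 512 characters are split into overlapping chunks first.
    print(len(split_text_into_chunks(texts[1], max_length=512)))

    with yaspin() as spinner:
        # Returns one normalized vector per input text; multi-chunk texts are
        # combined with the position-weighted average shown in the diff above.
        vectors = get_embedding_batch(model, "indexing", texts, spinner=spinner)

    assert vectors.shape[0] == len(texts)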
@@ -145,6 +336,10 @@ def load_tokenizer() -> AutoTokenizer:
     model_name = "gpt2"
     cache_dir = os.path.expanduser("~/.cache/huggingface/hub")

+    # Check the global cache
+    if model_name in _global_tokenizers:
+        return _global_tokenizers[model_name]
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
@@ -158,11 +353,15 @@ def load_tokenizer() -> AutoTokenizer:
             local_files_only=False
         )

+        # Save to the global cache
+        _global_tokenizers[model_name] = tokenizer
+
         return tokenizer  # type: ignore

+@functools.lru_cache(maxsize=1)
 def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
     """
-
+    Load the rerank model and tokenizer, using a cache to avoid reloading.

     Returns:
         Tuple[AutoModelForSequenceClassification, AutoTokenizer]: the loaded model and tokenizer
@@ -170,7 +369,10 @@ def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokeniz
     model_name = "BAAI/bge-reranker-v2-m3"
     cache_dir = os.path.expanduser("~/.cache/huggingface/hub")

-
+    # Check the global cache
+    key = f"rerank_{model_name}"
+    if key in _global_models and f"{key}_tokenizer" in _global_tokenizers:
+        return _global_models[key], _global_tokenizers[f"{key}_tokenizer"]

     try:
         tokenizer = AutoTokenizer.from_pretrained(
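The rerank loader follows the same double-caching pattern. A minimal sketch of calling it directly to score one query-document pair, mirroring how `rerank_results` below drives the model (the pair text is made up):

    import torch

    from jarvis.jarvis_utils.embedding import load_rerank_model

    model, tokenizer = load_rerank_model()
    pair = [("what is RAG?", "RAG combines retrieval with generation.")]

    with torch.no_grad():
        inputs = tokenizer(pair, padding=True, truncation=True,
                           return_tensors="pt", max_length=512)
        if torch.cuda.is_available():
            # load_rerank_model puts the model on the GPU, so the inputs must follow.
            inputs = {k: v.cuda() for k, v in inputs.items()}
        # One raw relevance logit per (query, document) pair.
        score = model(**inputs).logits.squeeze(-1).item()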
@@ -199,4 +401,101 @@ def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokeniz
             model = model.cuda()
         model.eval()

+        # Save to the global cache
+        _global_models[key] = model
+        _global_tokenizers[f"{key}_tokenizer"] = tokenizer
+
         return model, tokenizer  # type: ignore
+
+def rerank_results(query: str, documents: List[str], initial_scores: Optional[List[float]] = None,
+                   batch_size: int = 8, spinner: Optional[Yaspin] = None) -> List[float]:
+    """
+    Rerank retrieval results with a cross-encoder to improve RAG precision.
+
+    Args:
+        query: the query text
+        documents: list of document contents to rerank
+        initial_scores: optional initial retrieval scores; if provided, they are fused with the rerank scores
+        batch_size: batch size
+        spinner: optional progress indicator
+
+    Returns:
+        List[float]: reranked scores, aligned with the input documents
+    """
+    try:
+        if not documents:
+            return []
+
+        # Load the rerank model
+        if spinner:
+            spinner.text = "加载重排序模型..."
+        model, tokenizer = load_rerank_model()
+
+        # Prepare scoring
+        all_scores = []
+
+        # Process in batches
+        for i in range(0, len(documents), batch_size):
+            if spinner:
+                spinner.text = f"重排序进度: {i}/{len(documents)}..."
+
+            # Prepare the current batch
+            batch_docs = documents[i:i+batch_size]
+            pairs = [(query, doc) for doc in batch_docs]
+
+            # Encode the inputs
+            with torch.no_grad():
+                # type-ignore to avoid mypy errors
+                inputs = tokenizer(  # type: ignore
+                    pairs,
+                    padding=True,
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=512
+                )
+
+                # Use GPU acceleration if available
+                if torch.cuda.is_available():
+                    inputs = {k: v.cuda() for k, v in inputs.items()}
+
+                # Get the scores
+                outputs = model(**inputs)  # type: ignore
+                scores = outputs.logits.squeeze(-1).cpu().tolist()
+
+                # Make sure a list is built when there is only one document
+                if len(batch_docs) == 1:
+                    all_scores.append(float(scores))
+                else:
+                    all_scores.extend(scores)
+
+        # Normalize the scores to the 0-1 range
+        if all_scores:
+            min_score = min(all_scores)
+            max_score = max(all_scores)
+            if max_score > min_score:
+                normalized_scores = [(score - min_score) / (max_score - min_score) for score in all_scores]
+            else:
+                normalized_scores = [0.5] * len(all_scores)
+
+            # Fuse with the initial scores if provided
+            if initial_scores and len(initial_scores) == len(normalized_scores):
+                # Weighted-average fusion: initial score weight 0.3, rerank score weight 0.7
+                final_scores = [0.3 * init_score + 0.7 * rerank_score
+                                for init_score, rerank_score in zip(initial_scores, normalized_scores)]
+                return final_scores
+
+            return normalized_scores
+
+        if spinner:
+            spinner.text = "重排序完成"
+
+        # If reranking produced nothing, return the initial scores or a default score
+        return initial_scores if initial_scores else [0.5] * len(documents)
+
+    except Exception as e:
+        PrettyOutput.print(f"重排序失败: {str(e)}", OutputType.ERROR)
+        if spinner:
+            spinner.text = f"重排序失败: {str(e)}"
+
+        # Fall back to the initial scores on error
+        return initial_scores if initial_scores else [0.5] * len(documents)
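A usage sketch for the new `rerank_results` helper; the documents and first-stage scores below are made up, while the 0.3/0.7 fusion comes from the code above:

    from jarvis.jarvis_utils.embedding import rerank_results

    query = "how are embedding models cached?"
    docs = [
        "load_embedding_model() keeps the model in a global cache.",
        "find_git_root() initializes a repository when needed.",
    ]
    # Hypothetical scores from the first-stage retriever, already in 0-1.
    initial = [0.9, 0.2]

    # Cross-encoder scores are min-max normalized, then fused as
    # 0.3 * initial + 0.7 * rerank because initial scores are supplied.
    scores = rerank_results(query, docs, initial_scores=initial)
    ranked = sorted(zip(scores, docs), reverse=True)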
jarvis/jarvis_utils/git_utils.py
CHANGED
@@ -14,9 +14,25 @@ import subprocess
 from typing import List, Tuple, Dict
 from jarvis.jarvis_utils.output import PrettyOutput, OutputType
 def find_git_root(start_dir="."):
-    """
+    """
+    Change to the Git root of the given path, initializing a repository if there is none.
+
+    Args:
+        start_dir (str): directory to start searching from; defaults to the current directory.
+
+    Returns:
+        str: path of the Git repository root. If the directory is not a Git repository, a new one is initialized.
+    """
     os.chdir(start_dir)
-
+    try:
+        git_root = os.popen("git rev-parse --show-toplevel").read().strip()
+        if not git_root:
+            subprocess.run(["git", "init"], check=True)
+            git_root = os.path.abspath(".")
+    except subprocess.CalledProcessError:
+        # Initialize a new repository if this is not a Git repository
+        subprocess.run(["git", "init"], check=True)
+        git_root = os.path.abspath(".")
     os.chdir(git_root)
     return git_root
 def has_uncommitted_changes():
jarvis/jarvis_utils/input.py
CHANGED
@@ -15,7 +15,7 @@ from prompt_toolkit.document import Document
 from prompt_toolkit.key_binding import KeyBindings
 from fuzzywuzzy import process
 from colorama import Fore, Style as ColoramaStyle
-from
+from jarvis.jarvis_utils.output import PrettyOutput, OutputType
 def get_single_line_input(tip: str) -> str:
     """
     Get a single-line input with history support.
@@ -74,10 +74,13 @@ class FileCompleter(Completer):
         # Add default suggestions
         if not text_after_at.strip():
             # Default suggestion list
+            from jarvis.jarvis_utils.utils import ot
             default_suggestions = [
-                (
-                (
-                (
+                (ot("CodeBase"), '查询代码库'),
+                (ot("Web"), '网页搜索'),
+                (ot("RAG"), '知识库检索'),
+                (ot("Summary"), '总结'),
+                (ot("Clear"), '清除历史'),
             ]
             for name, desc in default_suggestions:
                 yield Completion(