jarvis-ai-assistant 0.1.130__py3-none-any.whl → 0.1.132__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72):
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +71 -38
  3. jarvis/jarvis_agent/builtin_input_handler.py +73 -0
  4. jarvis/{jarvis_code_agent → jarvis_agent}/file_input_handler.py +1 -1
  5. jarvis/jarvis_agent/main.py +1 -1
  6. jarvis/{jarvis_code_agent → jarvis_agent}/patch.py +77 -55
  7. jarvis/{jarvis_code_agent → jarvis_agent}/shell_input_handler.py +1 -2
  8. jarvis/jarvis_code_agent/code_agent.py +93 -88
  9. jarvis/jarvis_dev/main.py +335 -626
  10. jarvis/jarvis_git_squash/main.py +11 -32
  11. jarvis/jarvis_lsp/base.py +2 -26
  12. jarvis/jarvis_lsp/cpp.py +2 -14
  13. jarvis/jarvis_lsp/go.py +0 -13
  14. jarvis/jarvis_lsp/python.py +1 -30
  15. jarvis/jarvis_lsp/registry.py +10 -14
  16. jarvis/jarvis_lsp/rust.py +0 -12
  17. jarvis/jarvis_multi_agent/__init__.py +20 -29
  18. jarvis/jarvis_platform/ai8.py +7 -32
  19. jarvis/jarvis_platform/base.py +2 -7
  20. jarvis/jarvis_platform/kimi.py +3 -144
  21. jarvis/jarvis_platform/ollama.py +54 -68
  22. jarvis/jarvis_platform/openai.py +0 -4
  23. jarvis/jarvis_platform/oyi.py +0 -75
  24. jarvis/jarvis_platform/registry.py +1 -1
  25. jarvis/jarvis_platform/yuanbao.py +264 -0
  26. jarvis/jarvis_platform_manager/main.py +3 -3
  27. jarvis/jarvis_rag/file_processors.py +138 -0
  28. jarvis/jarvis_rag/main.py +1305 -425
  29. jarvis/jarvis_tools/ask_codebase.py +227 -41
  30. jarvis/jarvis_tools/code_review.py +229 -166
  31. jarvis/jarvis_tools/create_code_agent.py +76 -72
  32. jarvis/jarvis_tools/create_sub_agent.py +32 -15
  33. jarvis/jarvis_tools/execute_python_script.py +58 -0
  34. jarvis/jarvis_tools/execute_shell.py +15 -28
  35. jarvis/jarvis_tools/execute_shell_script.py +2 -2
  36. jarvis/jarvis_tools/file_analyzer.py +271 -0
  37. jarvis/jarvis_tools/file_operation.py +3 -3
  38. jarvis/jarvis_tools/find_caller.py +213 -0
  39. jarvis/jarvis_tools/find_symbol.py +211 -0
  40. jarvis/jarvis_tools/function_analyzer.py +248 -0
  41. jarvis/jarvis_tools/git_commiter.py +89 -70
  42. jarvis/jarvis_tools/lsp_find_definition.py +83 -67
  43. jarvis/jarvis_tools/lsp_find_references.py +62 -46
  44. jarvis/jarvis_tools/lsp_get_diagnostics.py +90 -74
  45. jarvis/jarvis_tools/methodology.py +89 -48
  46. jarvis/jarvis_tools/project_analyzer.py +220 -0
  47. jarvis/jarvis_tools/read_code.py +24 -3
  48. jarvis/jarvis_tools/read_webpage.py +195 -81
  49. jarvis/jarvis_tools/registry.py +132 -11
  50. jarvis/jarvis_tools/search_web.py +73 -30
  51. jarvis/jarvis_tools/tool_generator.py +7 -9
  52. jarvis/jarvis_utils/__init__.py +1 -0
  53. jarvis/jarvis_utils/config.py +67 -3
  54. jarvis/jarvis_utils/embedding.py +344 -45
  55. jarvis/jarvis_utils/git_utils.py +18 -2
  56. jarvis/jarvis_utils/input.py +7 -4
  57. jarvis/jarvis_utils/methodology.py +379 -7
  58. jarvis/jarvis_utils/output.py +5 -3
  59. jarvis/jarvis_utils/utils.py +62 -10
  60. {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/METADATA +3 -4
  61. jarvis_ai_assistant-0.1.132.dist-info/RECORD +82 -0
  62. {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/entry_points.txt +2 -0
  63. jarvis/jarvis_c2rust/c2rust.yaml +0 -734
  64. jarvis/jarvis_code_agent/builtin_input_handler.py +0 -43
  65. jarvis/jarvis_codebase/__init__.py +0 -0
  66. jarvis/jarvis_codebase/main.py +0 -1011
  67. jarvis/jarvis_tools/lsp_get_document_symbols.py +0 -87
  68. jarvis/jarvis_tools/lsp_prepare_rename.py +0 -130
  69. jarvis_ai_assistant-0.1.130.dist-info/RECORD +0 -79
  70. {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/LICENSE +0 -0
  71. {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/WHEEL +0 -0
  72. {jarvis_ai_assistant-0.1.130.dist-info → jarvis_ai_assistant-0.1.132.dist-info}/top_level.txt +0 -0
@@ -3,9 +3,16 @@ import numpy as np
3
3
  import torch
4
4
  from sentence_transformers import SentenceTransformer
5
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
- from typing import List, Any, Tuple
6
+ from typing import List, Any, Optional, Tuple
7
+ import functools
8
+
9
+ from yaspin.api import Yaspin
7
10
  from jarvis.jarvis_utils.output import PrettyOutput, OutputType
8
11
 
12
+ # 全局缓存,避免重复加载模型
13
+ _global_models = {}
14
+ _global_tokenizers = {}
15
+
9
16
  def get_context_token_count(text: str) -> int:
10
17
  """使用分词器获取文本的token数量。
11
18
 
@@ -26,9 +33,10 @@ def get_context_token_count(text: str) -> int:
26
33
  # 回退到基于字符的粗略估计
27
34
  return len(text) // 4 # 每个token大约4个字符的粗略估计
28
35
 
36
+ @functools.lru_cache(maxsize=1)
29
37
  def load_embedding_model() -> SentenceTransformer:
30
38
  """
31
- 加载句子嵌入模型。
39
+ 加载句子嵌入模型,使用缓存避免重复加载。
32
40
 
33
41
  返回:
34
42
  SentenceTransformer: 加载的嵌入模型
@@ -36,6 +44,10 @@ def load_embedding_model() -> SentenceTransformer:
36
44
  model_name = "BAAI/bge-m3"
37
45
  cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
38
46
 
47
+ # 检查全局缓存中是否已有模型
48
+ if model_name in _global_models:
49
+ return _global_models[model_name]
50
+
39
51
  try:
40
52
  embedding_model = SentenceTransformer(
41
53
  model_name,
@@ -49,6 +61,13 @@ def load_embedding_model() -> SentenceTransformer:
49
61
  local_files_only=False
50
62
  )
51
63
 
64
+ # 如果可用,将模型移到GPU上
65
+ if torch.cuda.is_available():
66
+ embedding_model.to(torch.device("cuda"))
67
+
68
+ # 保存到全局缓存
69
+ _global_models[model_name] = embedding_model
70
+
52
71
  return embedding_model
53
72
 
54
73
  def get_embedding(embedding_model: Any, text: str) -> np.ndarray:
@@ -67,77 +86,249 @@ def get_embedding(embedding_model: Any, text: str) -> np.ndarray:
67
86
  show_progress_bar=False)
68
87
  return np.array(embedding, dtype=np.float32)
69
88
 
70
def get_embedding_batch(embedding_model: Any, prefix: str, texts: List[str], spinner: Optional[Yaspin] = None, batch_size: int = 8) -> np.ndarray:
    """
    Embed a list of texts in batches and return one vector per non-empty text.

    Each text is whitespace-normalised and split with
    ``split_text_into_chunks(text, 512)``. All chunks are encoded in batches of
    ``batch_size``; identical chunks are encoded only once. A text that produced
    several chunks is represented by a weighted average of its chunk vectors
    (weights fall linearly from 1.0 to 0.8 so earlier chunks count slightly
    more), re-normalised to unit length.

    Args:
        embedding_model: Model exposing ``encode(...)`` and
            ``get_sentence_embedding_dimension()``.
        prefix: Text placed in front of every spinner status message.
        texts: Texts to embed.
        spinner: Optional progress indicator whose ``text`` attribute is updated.
        batch_size: Number of chunks passed to ``encode`` per call.

    Returns:
        np.ndarray: Stacked embedding vectors, one row per text that yielded at
        least one chunk, in input order. Texts that are empty after
        normalisation contribute no row. An array of shape ``(0, dim)`` is
        returned when there is nothing to embed or when any error occurs.
    """
    # Chunk text -> vector. Keyed by the chunk itself rather than hash(chunk),
    # so two different chunks can never share an entry through a hash collision.
    embedding_cache = {}
    cache_hits = 0

    try:
        # Phase 1: split every text and remember which slice of all_chunks it owns.
        all_chunks: List[str] = []
        chunk_spans = []
        total_texts = len(texts)
        for text_no, text in enumerate(texts, start=1):
            if spinner:
                spinner.text = f"{prefix} 预处理中 ({text_no}/{total_texts}) ..."

            text = ' '.join(text.split()) if text else ""
            first = len(all_chunks)
            all_chunks.extend(split_text_into_chunks(text, 512))
            chunk_spans.append((first, len(all_chunks)))

        if not all_chunks:
            return np.zeros((0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32)

        # Phase 2: encode batch by batch. Vectors are written back by position,
        # so all_vectors[k] always belongs to all_chunks[k] even when a batch
        # mixes cache hits and misses.
        total_chunks = len(all_chunks)
        all_vectors: List[Any] = [None] * total_chunks
        for batch_start in range(0, total_chunks, batch_size):
            if spinner:
                spinner.text = f"{prefix} 批量处理嵌入 ({batch_start+1}/{total_chunks}) ..."

            batch = all_chunks[batch_start:batch_start + batch_size]

            to_encode = []
            for chunk in batch:
                if chunk in embedding_cache:
                    cache_hits += 1
                else:
                    to_encode.append(chunk)

            if to_encode:
                encoded = np.asarray(embedding_model.encode(
                    to_encode,
                    normalize_embeddings=True,
                    show_progress_bar=False,
                    convert_to_numpy=True,
                ))
                # Force (n, dim) so a one-item batch is handled like any other.
                encoded = encoded.reshape(len(to_encode), -1)
                for chunk, vec in zip(to_encode, encoded):
                    embedding_cache[chunk] = vec

            for offset, chunk in enumerate(batch):
                all_vectors[batch_start + offset] = embedding_cache[chunk]

        # Phase 3: collapse each text's chunk vectors into a single vector.
        result_vectors = []
        for first, last in chunk_spans:
            text_vectors = all_vectors[first:last]
            if not text_vectors:
                continue  # empty text: nothing to embed

            if len(text_vectors) == 1:
                result_vectors.append(text_vectors[0])
                continue

            stacked = np.stack(text_vectors)
            weights = np.linspace(1.0, 0.8, len(text_vectors))
            weights = weights / weights.sum()
            combined = (stacked * weights[:, None]).sum(axis=0).astype(stacked.dtype)

            # Averaging unit vectors shortens them; restore unit length.
            norm = np.linalg.norm(combined)
            if norm > 0:
                combined = combined / norm
            result_vectors.append(combined)

        if spinner and cache_hits > 0:
            spinner.text = f"{prefix} 缓存命中: {cache_hits}/{len(all_chunks)} 块"

        return np.vstack(result_vectors)

    except Exception as e:
        # Best-effort API: report the failure and hand back an empty matrix.
        PrettyOutput.print(f"批量嵌入失败: {str(e)}", OutputType.ERROR)
        return np.zeros((0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32)
90
210
 
91
def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 50) -> List[str]:
    """Split text into overlapping chunks that end on punctuation where possible.

    Text no longer than ``max_length`` is returned unchanged as a single chunk.
    Longer text first has all whitespace runs collapsed to single spaces, then
    is cut into windows of at most ``max_length`` characters. Each cut is moved
    back (by up to 120 characters) to the nearest sentence-ending mark, or
    failing that to a clause separator, or failing that to a comma/closing
    bracket. Consecutive chunks overlap by 20% of ``max_length``.

    Args:
        text: Input text to split.
        max_length: Maximum length of a chunk (merged short chunks may exceed
            it by up to 10%).
        min_length: Preferred minimum chunk length; shorter pieces are merged
            into a neighbour when the result stays within the 10% allowance.

    Returns:
        List[str]: The chunks in document order; empty for empty input.
    """
    if not text:
        return []

    if len(text) <= max_length:
        return [text]

    # Collapse all whitespace. After this the text holds no newlines, so
    # boundaries can only be found at punctuation, never at paragraph breaks.
    text = ' '.join(text.split())
    text_len = len(text)

    # Boundary characters in priority order, ASCII and full-width CJK forms.
    punctuation_tiers = (
        {'.', '!', '?', '。', '!', '?'},                     # sentence end
        {';', ':', '…', ';', ':'},                          # clause separators
        {',', ',', '、', ')', ')', ']', '】', '}', '》', '"', "'"},  # weakest
    )

    overlap = int(max_length * 0.2)
    chunks: List[str] = []
    start = 0

    while start < text_len:
        end = min(start + max_length, text_len)

        # Only a full-size window that is not the last one needs a boundary.
        if end < text_len:
            # Never look back so far that the chunk would drop below min_length.
            search_range = min(120, end - start - min_length)
            lower = max(start, end - search_range)

            boundary = -1
            for tier in punctuation_tiers:
                boundary = next(
                    (i for i in range(end - 1, lower, -1) if text[i] in tier),
                    -1,
                )
                if boundary != -1:
                    break

            if boundary != -1 and boundary - start >= min_length:
                end = boundary + 1

        chunk = text[start:end].strip()
        if chunk:
            if len(chunk) >= min_length or not chunks:
                chunks.append(chunk)
            elif len(chunks[-1]) + len(chunk) <= max_length * 1.1:
                # Too short to stand alone: fold it into the previous chunk.
                chunks[-1] = chunks[-1] + " " + chunk
            else:
                chunks.append(chunk)

        # The window reached the end of the text. Stop here: rewinding by the
        # overlap would only emit chunks that duplicate the tail just stored.
        if end >= text_len:
            break

        next_start = end - overlap
        if next_start <= start:
            # Guarantee forward progress without skipping past unread text.
            next_start = min(end, start + max(1, min_length // 2))
        start = next_start

    # Final pass: merge a short chunk into its successor when that fits.
    if len(chunks) > 1:
        merged_chunks: List[str] = []
        i = 0
        while i < len(chunks):
            current = chunks[i]
            if len(current) < min_length and i < len(chunks) - 1:
                following = chunks[i + 1]
                if len(current) + len(following) <= max_length * 1.1:
                    merged_chunks.append(current + " " + following)
                    i += 2
                    continue
            merged_chunks.append(current)
            i += 1
        chunks = merged_chunks

    return chunks
326
+
137
327
 
328
+ @functools.lru_cache(maxsize=1)
138
329
  def load_tokenizer() -> AutoTokenizer:
139
330
  """
140
- 加载用于文本处理的分词器。
331
+ 加载用于文本处理的分词器,使用缓存避免重复加载。
141
332
 
142
333
  返回:
143
334
  AutoTokenizer: 加载的分词器
@@ -145,6 +336,10 @@ def load_tokenizer() -> AutoTokenizer:
145
336
  model_name = "gpt2"
146
337
  cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
147
338
 
339
+ # 检查全局缓存
340
+ if model_name in _global_tokenizers:
341
+ return _global_tokenizers[model_name]
342
+
148
343
  try:
149
344
  tokenizer = AutoTokenizer.from_pretrained(
150
345
  model_name,
@@ -158,11 +353,15 @@ def load_tokenizer() -> AutoTokenizer:
158
353
  local_files_only=False
159
354
  )
160
355
 
356
+ # 保存到全局缓存
357
+ _global_tokenizers[model_name] = tokenizer
358
+
161
359
  return tokenizer # type: ignore
162
360
 
361
+ @functools.lru_cache(maxsize=1)
163
362
  def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
164
363
  """
165
- 加载重排序模型和分词器。
364
+ 加载重排序模型和分词器,使用缓存避免重复加载。
166
365
 
167
366
  返回:
168
367
  Tuple[AutoModelForSequenceClassification, AutoTokenizer]: 加载的模型和分词器
@@ -170,7 +369,10 @@ def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokeniz
170
369
  model_name = "BAAI/bge-reranker-v2-m3"
171
370
  cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
172
371
 
173
- PrettyOutput.print(f"加载重排序模型: {model_name}...", OutputType.INFO)
372
+ # 检查全局缓存
373
+ key = f"rerank_{model_name}"
374
+ if key in _global_models and f"{key}_tokenizer" in _global_tokenizers:
375
+ return _global_models[key], _global_tokenizers[f"{key}_tokenizer"]
174
376
 
175
377
  try:
176
378
  tokenizer = AutoTokenizer.from_pretrained(
@@ -199,4 +401,101 @@ def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokeniz
199
401
  model = model.cuda()
200
402
  model.eval()
201
403
 
404
+ # 保存到全局缓存
405
+ _global_models[key] = model
406
+ _global_tokenizers[f"{key}_tokenizer"] = tokenizer
407
+
202
408
  return model, tokenizer # type: ignore
409
+
410
def rerank_results(query: str, documents: List[str], initial_scores: Optional[List[float]] = None,
                   batch_size: int = 8, spinner: Optional[Yaspin] = None) -> List[float]:
    """
    Score documents against a query with the cross-encoder rerank model.

    The model returned by ``load_rerank_model()`` scores every
    ``(query, document)`` pair in batches. Raw scores are min-max normalised
    to the 0-1 range (all 0.5 when every score is equal). When
    ``initial_scores`` has the same length as ``documents`` the result is
    ``0.3 * initial + 0.7 * rerank``.

    Args:
        query: Query text.
        documents: Document contents to score.
        initial_scores: Optional first-stage retrieval scores to blend in.
        batch_size: Number of pairs scored per forward pass.
        spinner: Optional progress indicator whose ``text`` attribute is updated.

    Returns:
        List[float]: One score per document, in input order. On any failure
        the function falls back to ``initial_scores`` (when its length matches
        ``documents``) or to 0.5 for every document.
    """
    if not documents:
        return []

    try:
        if spinner:
            spinner.text = "加载重排序模型..."
        model, tokenizer = load_rerank_model()

        all_scores: List[float] = []
        total = len(documents)

        for offset in range(0, total, batch_size):
            if spinner:
                spinner.text = f"重排序进度: {offset}/{len(documents)}..."

            pairs = [(query, doc) for doc in documents[offset:offset + batch_size]]

            with torch.no_grad():
                inputs = tokenizer(  # type: ignore
                    pairs,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512
                )

                if torch.cuda.is_available():
                    inputs = {k: v.cuda() for k, v in inputs.items()}

                outputs = model(**inputs)  # type: ignore
                # reshape(-1) always yields a 1-D tensor, so tolist() returns a
                # list for every batch size, including a single document.
                all_scores.extend(outputs.logits.reshape(-1).float().cpu().tolist())

        # Min-max normalise so rerank scores can be blended with initial ones.
        lowest = min(all_scores)
        highest = max(all_scores)
        if highest > lowest:
            spread = highest - lowest
            normalized_scores = [(score - lowest) / spread for score in all_scores]
        else:
            normalized_scores = [0.5] * len(all_scores)

        if spinner:
            spinner.text = "重排序完成"

        if initial_scores and len(initial_scores) == len(normalized_scores):
            return [0.3 * init_score + 0.7 * rerank_score
                    for init_score, rerank_score in zip(initial_scores, normalized_scores)]

        return normalized_scores

    except Exception as e:
        # Best-effort API: report the failure and fall back to usable scores.
        PrettyOutput.print(f"重排序失败: {str(e)}", OutputType.ERROR)
        if spinner:
            spinner.text = f"重排序失败: {str(e)}"

        if initial_scores and len(initial_scores) == len(documents):
            return initial_scores
        return [0.5] * len(documents)
@@ -14,9 +14,25 @@ import subprocess
14
14
  from typing import List, Tuple, Dict
15
15
  from jarvis.jarvis_utils.output import PrettyOutput, OutputType
16
16
def find_git_root(start_dir="."):
    """
    Change into the Git root of ``start_dir``, creating a repository if needed.

    The process working directory is first changed to ``start_dir``. If that
    directory is inside a Git work tree, its top-level directory is used;
    otherwise ``git init`` is run there and ``start_dir`` itself becomes the
    root. The working directory is left at the returned root.

    Args:
        start_dir (str): Directory to start from; defaults to the current one.

    Returns:
        str: Path of the Git repository root.

    Raises:
        subprocess.CalledProcessError: If ``git init`` fails.
    """
    os.chdir(start_dir)

    # Ask git directly and inspect the exit status. A non-zero status (not a
    # repository) is expected here, so stderr is discarded instead of leaking
    # git's "fatal: not a git repository" message to the terminal.
    probe = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        text=True,
    )
    git_root = probe.stdout.strip() if probe.returncode == 0 else ""

    if not git_root:
        subprocess.run(["git", "init"], check=True)
        git_root = os.path.abspath(".")

    os.chdir(git_root)
    return git_root
22
38
  def has_uncommitted_changes():
@@ -15,7 +15,7 @@ from prompt_toolkit.document import Document
15
15
  from prompt_toolkit.key_binding import KeyBindings
16
16
  from fuzzywuzzy import process
17
17
  from colorama import Fore, Style as ColoramaStyle
18
- from ..jarvis_utils.output import PrettyOutput, OutputType
18
+ from jarvis.jarvis_utils.output import PrettyOutput, OutputType
19
19
  def get_single_line_input(tip: str) -> str:
20
20
  """
21
21
  获取支持历史记录的单行输入。
@@ -74,10 +74,13 @@ class FileCompleter(Completer):
74
74
  # 添加默认建议
75
75
  if not text_after_at.strip():
76
76
  # 默认建议列表
77
+ from jarvis.jarvis_utils.utils import ot
77
78
  default_suggestions = [
78
- ('<CodeBase>', '查询代码库'),
79
- ('<Web>', '网页搜索'),
80
- ('<RAG>', '知识库检索')
79
+ (ot("CodeBase"), '查询代码库'),
80
+ (ot("Web"), '网页搜索'),
81
+ (ot("RAG"), '知识库检索'),
82
+ (ot("Summary"), '总结'),
83
+ (ot("Clear"), '清除历史'),
81
84
  ]
82
85
  for name, desc in default_suggestions:
83
86
  yield Completion(