jarvis-ai-assistant 0.1.134__py3-none-any.whl → 0.1.138__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of jarvis-ai-assistant might be problematic.

Files changed (78)
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +201 -79
  3. jarvis/jarvis_agent/builtin_input_handler.py +16 -6
  4. jarvis/jarvis_agent/file_input_handler.py +9 -9
  5. jarvis/jarvis_agent/jarvis.py +10 -10
  6. jarvis/jarvis_agent/main.py +12 -11
  7. jarvis/jarvis_agent/output_handler.py +3 -3
  8. jarvis/jarvis_agent/patch.py +86 -62
  9. jarvis/jarvis_agent/shell_input_handler.py +5 -3
  10. jarvis/jarvis_code_agent/code_agent.py +134 -99
  11. jarvis/jarvis_code_agent/file_select.py +24 -24
  12. jarvis/jarvis_dev/main.py +45 -51
  13. jarvis/jarvis_git_details/__init__.py +0 -0
  14. jarvis/jarvis_git_details/main.py +179 -0
  15. jarvis/jarvis_git_squash/main.py +7 -7
  16. jarvis/jarvis_lsp/base.py +11 -11
  17. jarvis/jarvis_lsp/cpp.py +14 -14
  18. jarvis/jarvis_lsp/go.py +13 -13
  19. jarvis/jarvis_lsp/python.py +8 -8
  20. jarvis/jarvis_lsp/registry.py +21 -21
  21. jarvis/jarvis_lsp/rust.py +15 -15
  22. jarvis/jarvis_methodology/main.py +101 -0
  23. jarvis/jarvis_multi_agent/__init__.py +11 -11
  24. jarvis/jarvis_multi_agent/main.py +6 -6
  25. jarvis/jarvis_platform/__init__.py +1 -1
  26. jarvis/jarvis_platform/ai8.py +67 -89
  27. jarvis/jarvis_platform/base.py +14 -13
  28. jarvis/jarvis_platform/kimi.py +25 -28
  29. jarvis/jarvis_platform/ollama.py +24 -26
  30. jarvis/jarvis_platform/openai.py +15 -19
  31. jarvis/jarvis_platform/oyi.py +48 -50
  32. jarvis/jarvis_platform/registry.py +27 -28
  33. jarvis/jarvis_platform/yuanbao.py +38 -42
  34. jarvis/jarvis_platform_manager/main.py +81 -81
  35. jarvis/jarvis_platform_manager/openai_test.py +21 -21
  36. jarvis/jarvis_rag/file_processors.py +18 -18
  37. jarvis/jarvis_rag/main.py +261 -277
  38. jarvis/jarvis_smart_shell/main.py +12 -12
  39. jarvis/jarvis_tools/ask_codebase.py +28 -28
  40. jarvis/jarvis_tools/ask_user.py +8 -8
  41. jarvis/jarvis_tools/base.py +4 -4
  42. jarvis/jarvis_tools/chdir.py +9 -9
  43. jarvis/jarvis_tools/code_review.py +19 -19
  44. jarvis/jarvis_tools/create_code_agent.py +15 -15
  45. jarvis/jarvis_tools/execute_python_script.py +3 -3
  46. jarvis/jarvis_tools/execute_shell.py +11 -11
  47. jarvis/jarvis_tools/execute_shell_script.py +3 -3
  48. jarvis/jarvis_tools/file_analyzer.py +29 -29
  49. jarvis/jarvis_tools/file_operation.py +22 -20
  50. jarvis/jarvis_tools/find_caller.py +25 -25
  51. jarvis/jarvis_tools/find_methodolopy.py +65 -0
  52. jarvis/jarvis_tools/find_symbol.py +24 -24
  53. jarvis/jarvis_tools/function_analyzer.py +27 -27
  54. jarvis/jarvis_tools/git_commiter.py +9 -9
  55. jarvis/jarvis_tools/lsp_get_diagnostics.py +19 -19
  56. jarvis/jarvis_tools/methodology.py +23 -62
  57. jarvis/jarvis_tools/project_analyzer.py +29 -33
  58. jarvis/jarvis_tools/rag.py +15 -15
  59. jarvis/jarvis_tools/read_code.py +24 -22
  60. jarvis/jarvis_tools/read_webpage.py +31 -31
  61. jarvis/jarvis_tools/registry.py +72 -52
  62. jarvis/jarvis_tools/tool_generator.py +18 -18
  63. jarvis/jarvis_utils/config.py +23 -23
  64. jarvis/jarvis_utils/embedding.py +83 -83
  65. jarvis/jarvis_utils/git_utils.py +20 -20
  66. jarvis/jarvis_utils/globals.py +18 -6
  67. jarvis/jarvis_utils/input.py +10 -9
  68. jarvis/jarvis_utils/methodology.py +140 -136
  69. jarvis/jarvis_utils/output.py +11 -11
  70. jarvis/jarvis_utils/utils.py +22 -70
  71. {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/METADATA +1 -1
  72. jarvis_ai_assistant-0.1.138.dist-info/RECORD +85 -0
  73. {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/entry_points.txt +2 -0
  74. jarvis/jarvis_tools/select_code_files.py +0 -62
  75. jarvis_ai_assistant-0.1.134.dist-info/RECORD +0 -82
  76. {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/LICENSE +0 -0
  77. {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/WHEEL +0 -0
  78. {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/top_level.txt +0 -0
jarvis/jarvis_utils/config.py

@@ -12,35 +12,35 @@ import os
 def get_max_token_count() -> int:
     """
     Get the maximum number of tokens the model allows.
-
+
     Returns:
         int: The maximum number of tokens the model can handle.
     """
     return int(os.getenv('JARVIS_MAX_TOKEN_COUNT', '64000'))  # default 64k
-
+
 def get_thread_count() -> int:
     """
     Get the number of threads used for parallel processing.
-
+
     Returns:
         int: The thread count, defaults to 1
     """
-    return int(os.getenv('JARVIS_THREAD_COUNT', '1'))
-
+    return int(os.getenv('JARVIS_THREAD_COUNT', '1'))
+
 def is_auto_complete() -> bool:
     """
     Check whether auto-completion is enabled.
-
+
     Returns:
         bool: True if auto-completion is enabled, defaults to False
     """
     return os.getenv('JARVIS_AUTO_COMPLETE', 'false') == 'true'
-
+
 
 def get_min_paragraph_length() -> int:
     """
     Get the minimum paragraph length for text processing.
-
+
     Returns:
         int: The minimum character length, defaults to 50
     """
@@ -48,7 +48,7 @@ def get_min_paragraph_length() -> int:
 def get_max_paragraph_length() -> int:
     """
     Get the maximum paragraph length for text processing.
-
+
     Returns:
         int: The maximum character length, defaults to 12800
     """
@@ -56,7 +56,7 @@ def get_max_paragraph_length() -> int:
 def get_shell_name() -> str:
     """
     Get the system shell name.
-
+
     Returns:
         str: The shell name (e.g. bash, zsh), defaults to bash
     """
@@ -64,7 +64,7 @@ def get_shell_name() -> str:
 def get_normal_platform_name() -> str:
     """
     Get the platform name for normal operations.
-
+
     Returns:
         str: The platform name, defaults to 'yuanbao'
     """
@@ -72,7 +72,7 @@ def get_normal_platform_name() -> str:
 def get_normal_model_name() -> str:
     """
     Get the model name for normal operations.
-
+
     Returns:
         str: The model name, defaults to 'deep_seek'
     """
@@ -82,7 +82,7 @@ def get_normal_model_name() -> str:
 def get_thinking_platform_name() -> str:
     """
     Get the platform name for thinking operations.
-
+
     Returns:
         str: The platform name, defaults to 'yuanbao'
     """
@@ -90,7 +90,7 @@ def get_thinking_platform_name() -> str:
 def get_thinking_model_name() -> str:
     """
     Get the model name for thinking operations.
-
+
     Returns:
         str: The model name, defaults to 'deep_seek'
     """
@@ -99,7 +99,7 @@ def get_thinking_model_name() -> str:
 def is_execute_tool_confirm() -> bool:
     """
     Check whether tool execution requires confirmation.
-
+
     Returns:
         bool: True if confirmation is required, defaults to False
     """
@@ -107,7 +107,7 @@ def is_execute_tool_confirm() -> bool:
 def is_confirm_before_apply_patch() -> bool:
     """
     Check whether confirmation is required before applying a patch.
-
+
     Returns:
         bool: True if confirmation is required, defaults to False
     """
@@ -116,18 +116,18 @@ def is_confirm_before_apply_patch() -> bool:
 def get_rag_ignored_paths() -> list:
     """
     Get the list of paths to ignore during RAG indexing.
-
+
     First tries to read from the .jarvis/rag_ignore.txt file;
     if that file does not exist, returns the default ignore list.
-
+
     Returns:
         list: The list of ignored paths, defaults to common ignore paths
     """
     # Default ignored paths
     default_ignored = [
-        '.git',
-        '__pycache__',
-        'node_modules',
+        '.git',
+        '__pycache__',
+        'node_modules',
         '.jarvis',
         '.jarvis-*',
         'target',
@@ -164,7 +164,7 @@ def get_rag_ignored_paths() -> list:
         '*.xz',
         '*.rar'
     ]
-
+
     # Try to read from the config file
     try:
        config_path = os.path.join('.jarvis', 'rag_ignore.txt')
@@ -174,5 +174,5 @@ def get_rag_ignored_paths() -> list:
            return custom_ignored
    except Exception:
        pass
-
+
    return default_ignored
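
Every getter in config.py follows the same pattern: read a JARVIS_* environment variable and fall back to a default. The hunks above are whitespace-only cleanups (trailing whitespace stripped), not behavior changes. A minimal sketch of how calling code can drive these getters (the environment values here are illustrative, not from the diff):

import os

# The getters read os.getenv at call time, so values can be set at runtime.
os.environ['JARVIS_MAX_TOKEN_COUNT'] = '32000'
os.environ['JARVIS_AUTO_COMPLETE'] = 'true'

from jarvis.jarvis_utils.config import get_max_token_count, is_auto_complete

print(get_max_token_count())  # 32000
print(is_auto_complete())     # True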
jarvis/jarvis_utils/embedding.py

@@ -15,10 +15,10 @@ _global_tokenizers = {}
 
 def get_context_token_count(text: str) -> int:
     """Get the number of tokens in a text using the tokenizer.
-
+
     Args:
         text: The input text to count tokens for
-
+
     Returns:
         int: The number of tokens in the text
     """
@@ -27,7 +27,7 @@ def get_context_token_count(text: str) -> int:
         tokenizer = load_tokenizer()
         chunks = split_text_into_chunks(text, 512)
         return sum([len(tokenizer.encode(chunk)) for chunk in chunks])  # type: ignore
-
+
     except Exception as e:
         PrettyOutput.print(f"Token counting failed: {str(e)}", OutputType.WARNING)
         # Fall back to a rough character-based estimate
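
get_context_token_count tokenizes the text in 512-character chunks and sums the per-chunk counts, falling back to a character-based estimate on failure. The same pattern, standalone, with the gpt2 tokenizer that load_tokenizer uses further down (the chunking here is naive fixed-width, not the package's boundary-aware splitter):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def count_tokens(text: str, chunk_size: int = 512) -> int:
    # Fixed-width chunks keep each encode call small; per-chunk counts are summed.
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    return sum(len(tokenizer.encode(chunk)) for chunk in chunks)

print(count_tokens("hello world " * 100))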
@@ -37,17 +37,17 @@ def get_context_token_count(text: str) -> int:
 def load_embedding_model() -> SentenceTransformer:
     """
     Load the sentence embedding model, using a cache to avoid repeated loading.
-
+
     Returns:
         SentenceTransformer: The loaded embedding model
     """
     model_name = "BAAI/bge-m3"
     cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
-
+
     # Check whether the model is already in the global cache
     if model_name in _global_models:
         return _global_models[model_name]
-
+
     try:
         embedding_model = SentenceTransformer(
             model_name,
@@ -60,28 +60,28 @@ def load_embedding_model() -> SentenceTransformer:
             cache_folder=cache_dir,
             local_files_only=False
         )
-
+
         # Move the model to the GPU if available
         if torch.cuda.is_available():
             embedding_model.to(torch.device("cuda"))
-
+
         # Save to the global cache
         _global_models[model_name] = embedding_model
-
+
         return embedding_model
 
 def get_embedding(embedding_model: Any, text: str) -> np.ndarray:
     """
     Generate an embedding vector for the given text.
-
+
     Args:
         embedding_model: The embedding model to use
         text: The input text to embed
-
+
     Returns:
         np.ndarray: The embedding vector
     """
-    embedding = embedding_model.encode(text,
+    embedding = embedding_model.encode(text,
                                        normalize_embeddings=True,
                                        show_progress_bar=False)
     return np.array(embedding, dtype=np.float32)
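
get_embedding is a thin wrapper over SentenceTransformer.encode with normalized output, which makes dot products equal to cosine similarity. A self-contained sketch of the same call, using the BAAI/bge-m3 model named in the hunk above (the similarity check at the end is an illustration, not package code):

import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-m3")

def embed(text: str) -> np.ndarray:
    # normalize_embeddings=True yields unit vectors, so a @ b is cosine similarity.
    vec = model.encode(text, normalize_embeddings=True, show_progress_bar=False)
    return np.array(vec, dtype=np.float32)

a, b = embed("token counting"), embed("counting tokens")
print(float(a @ b))  # close to 1.0 for paraphrases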
@@ -89,53 +89,53 @@ def get_embedding(embedding_model: Any, text: str) -> np.ndarray:
 def get_embedding_batch(embedding_model: Any, prefix: str, texts: List[str], spinner: Optional[Yaspin] = None, batch_size: int = 8) -> np.ndarray:
     """
     Generate embedding vectors for a batch of texts with efficient batching, optimized for RAG.
-
+
     Args:
         embedding_model: The embedding model to use
         prefix: Progress bar prefix
         texts: The list of texts to embed
         spinner: Optional progress indicator
         batch_size: Batch size; larger values may be faster but use more memory
-
+
     Returns:
         np.ndarray: The stacked embedding vectors
     """
     # Simple embedding cache to avoid recomputing identical text chunks
     embedding_cache = {}
     cache_hits = 0
-
+
     try:
         # Preprocessing: split all texts into chunks
         all_chunks = []
         chunk_indices = []  # Track the chunk index range of each original text
-
+
         for i, text in enumerate(texts):
             if spinner:
                 spinner.text = f"{prefix} preprocessing ({i+1}/{len(texts)}) ..."
-
+
             # Preprocess the text: remove extra whitespace, normalize
             text = ' '.join(text.split()) if text else ""
-
+
             # Use the more optimized chunking function
             chunks = split_text_into_chunks(text, 512)
             start_idx = len(all_chunks)
             all_chunks.extend(chunks)
             end_idx = len(all_chunks)
             chunk_indices.append((start_idx, end_idx))
-
+
         if not all_chunks:
             return np.zeros((0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32)
-
+
         # Process all chunks in batches
         all_vectors = []
         for i in range(0, len(all_chunks), batch_size):
             if spinner:
                 spinner.text = f"{prefix} batch embedding ({i+1}/{len(all_chunks)}) ..."
-
+
             batch = all_chunks[i:i+batch_size]
             batch_to_process = []
             batch_indices = []
-
+
             # Check the cache to avoid recomputation
             for j, chunk in enumerate(batch):
                 chunk_hash = hash(chunk)
@@ -145,16 +145,16 @@ def get_embedding_batch(embedding_model: Any, prefix: str, texts: List[str], spi
                 else:
                     batch_to_process.append(chunk)
                     batch_indices.append(j)
-
+
             if batch_to_process:
                 # Process the uncached chunks
                 batch_vectors = embedding_model.encode(
-                    batch_to_process,
+                    batch_to_process,
                     normalize_embeddings=True,
                     show_progress_bar=False,
                     convert_to_numpy=True,
                 )
-
+
                 # Process the results and update the cache
                 if len(batch_to_process) == 1:
                     vec = batch_vectors
@@ -166,7 +166,7 @@ def get_embedding_batch(embedding_model: Any, prefix: str, texts: List[str], spi
                     chunk_hash = hash(batch_to_process[j])
                     embedding_cache[chunk_hash] = vec
                 all_vectors.append(vec)
-
+
         # Organize the results back into the original text order
         result_vectors = []
         for start_idx, end_idx in chunk_indices:
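
The cache in get_embedding_batch keys on hash(chunk) so a repeated chunk is encoded only once per call. A stripped-down sketch of that cache-miss pattern (illustrative only; encode_with_cache is not a package function):

import numpy as np

embedding_cache = {}

def encode_with_cache(model, chunks):
    # Send only cache misses to the model; reuse vectors for repeated chunks.
    misses = [c for c in chunks if hash(c) not in embedding_cache]
    if misses:
        vectors = model.encode(misses, normalize_embeddings=True, convert_to_numpy=True)
        for chunk, vec in zip(misses, np.atleast_2d(vectors)):
            embedding_cache[hash(chunk)] = vec
    return [embedding_cache[hash(c)] for c in chunks]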
@@ -174,73 +174,73 @@ def get_embedding_batch(embedding_model: Any, prefix: str, texts: List[str], spi
             for j in range(start_idx, end_idx):
                 if j < len(all_vectors):
                     text_vectors.append(all_vectors[j])
-
+
             if text_vectors:
                 # When a text was split into multiple chunks, take a weighted average
                 if len(text_vectors) > 1:
                     # RAG optimization: weighted-average the chunks, with earlier chunks weighted slightly higher
                     weights = np.linspace(1.0, 0.8, len(text_vectors))
                     weights = weights / weights.sum()  # Normalize the weights
-
+
                     # Apply the weights and sum
                     weighted_sum = np.zeros_like(text_vectors[0])
                     for i, vec in enumerate(text_vectors):
                         # Keep vector shapes consistent; handle possible dimension mismatches
                         vec_array = np.asarray(vec).reshape(weighted_sum.shape)
                         weighted_sum += vec_array * weights[i]
-
+
                     # Normalize the result vector
                     norm = np.linalg.norm(weighted_sum)
                     if norm > 0:
                         weighted_sum = weighted_sum / norm
-
+
                     result_vectors.append(weighted_sum)
                 else:
                     # A single chunk is used directly
                     result_vectors.append(text_vectors[0])
-
+
         if spinner and cache_hits > 0:
             spinner.text = f"{prefix} cache hits: {cache_hits}/{len(all_chunks)} chunks"
-
+
         return np.vstack(result_vectors)
-
+
     except Exception as e:
         PrettyOutput.print(f"Batch embedding failed: {str(e)}", OutputType.ERROR)
         return np.zeros((0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32)
-
+
 def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 50) -> List[str]:
     """Split text into chunks with overlapping windows, optimized for RAG retrieval.
-
+
     Args:
         text: The input text to split
         max_length: The maximum length of each chunk
         min_length: The minimum length of each chunk (the last chunk may be shorter)
-
+
     Returns:
         List[str]: The list of text chunks, each as close to max_length as possible without exceeding it
     """
     if not text:
         return []
-
+
     # If the text is shorter than the maximum length, return it whole
     if len(text) <= max_length:
         return [text]
-
+
     # Preprocessing: normalize the text and remove extra whitespace
     text = ' '.join(text.split())
-
+
     # Chinese and English punctuation sets; sentence boundaries tuned for RAG recall
     primary_punctuation = {'.', '!', '?', '\n', '。', '!', '?'}  # Primary sentence-ending punctuation
     secondary_punctuation = {';', ':', '…', ';', ':'}  # Secondary separators
     tertiary_punctuation = {',', ',', '、', ')', ')', ']', '】', '}', '》', '"', "'"}  # Lowest priority
-
+
     chunks = []
     start = 0
-
+
     while start < len(text):
         # Initialize the end position to the maximum possible length
         end = min(start + max_length, len(text))
-
+
         # Only try to find a sentence boundary when this is not the last chunk and the end equals the maximum length
         if end < len(text) and end == start + max_length:
             # Prefer paragraph boundaries; this matters especially for RAG
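
The multi-chunk branch above is easy to verify by hand: np.linspace(1.0, 0.8, n) gives the first chunk the most mass, the weights are normalized to sum to 1, and the weighted sum is re-normalized to unit length. A worked example with toy 2-D vectors (not from the package):

import numpy as np

text_vectors = [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([1.0, 1.0])]

weights = np.linspace(1.0, 0.8, len(text_vectors))  # [1.0, 0.9, 0.8]
weights = weights / weights.sum()                   # [0.370..., 0.333..., 0.296...]

weighted_sum = sum(w * v for w, v in zip(weights, text_vectors))
weighted_sum = weighted_sum / np.linalg.norm(weighted_sum)
print(weighted_sum)  # unit-length blend, with the first chunk weighted slightly higher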
@@ -251,17 +251,17 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
             # Look for a sentence boundary, starting from position end-1
             found_boundary = False
             best_boundary = -1
-
+
             # Widen the search range to find a better semantic boundary
             search_range = min(120, end - start - min_length)  # Wider search, but keep the new chunk at least min_length long
-
+
             # First try the primary punctuation (periods, etc.)
             for i in range(end-1, max(start, end-search_range), -1):
                 if text[i] in primary_punctuation:
                     best_boundary = i
                     found_boundary = True
                     break
-
+
             # If no primary punctuation was found, try the secondary punctuation (semicolons, colons, etc.)
             if not found_boundary:
                 for i in range(end-1, max(start, end-search_range), -1):
@@ -269,7 +269,7 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
                         best_boundary = i
                         found_boundary = True
                         break
-
+
             # Finally consider commas and other possible boundaries
             if not found_boundary:
                 for i in range(end-1, max(start, end-search_range), -1):
@@ -277,11 +277,11 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
                         best_boundary = i
                         found_boundary = True
                         break
-
+
             # If a suitable boundary was found and it does not produce a chunk that is too short, use it
             if found_boundary and (best_boundary - start) >= min_length:
                 end = best_boundary + 1
-
+
         # Add the current chunk, stripping leading and trailing whitespace
         chunk = text[start:end].strip()
         if chunk and len(chunk) >= min_length:  # Only add non-empty chunks that meet the minimum length
@@ -295,16 +295,16 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
         else:
             # If merging would make the chunk too long, add this small chunk (special case)
             chunks.append(chunk)
-
+
         # Compute the start of the next chunk; the overlap window improves RAG retrieval quality
         next_start = end - int(max_length * 0.2)  # 20% overlap window
-
+
         # Always make forward progress to avoid an infinite loop
         if next_start <= start:
             next_start = start + max(1, min_length // 2)
-
+
         start = next_start
-
+
     # Finally, check for chunks that are too short and try to merge adjacent short chunks
     if len(chunks) > 1:
         merged_chunks = []
@@ -321,7 +321,7 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
             merged_chunks.append(current)
             i += 1
         chunks = merged_chunks
-
+
     return chunks
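
split_text_into_chunks walks the text in max_length windows, snaps each cut to the highest-priority punctuation found within search_range, and then steps forward by 80% of max_length so consecutive chunks overlap by 20%. The overlap idea in isolation (fixed windows only; the package version additionally snaps to punctuation and merges short chunks):

def split_with_overlap(text: str, max_length: int = 512) -> list:
    # Step by 80% of the window so adjacent chunks share a 20% overlap.
    step = max_length - int(max_length * 0.2)
    return [text[i:i + max_length] for i in range(0, len(text), step)]

chunks = split_with_overlap("lorem ipsum " * 200, max_length=100)
print(len(chunks))
print(chunks[0][-20:] == chunks[1][:20])  # True: the 20-character overlap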
@@ -329,17 +329,17 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
 def load_tokenizer() -> AutoTokenizer:
     """
     Load the tokenizer for text processing, using a cache to avoid repeated loading.
-
+
     Returns:
         AutoTokenizer: The loaded tokenizer
     """
     model_name = "gpt2"
     cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
-
+
     # Check the global cache
     if model_name in _global_tokenizers:
         return _global_tokenizers[model_name]
-
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
@@ -352,28 +352,28 @@ def load_tokenizer() -> AutoTokenizer:
             cache_dir=cache_dir,
             local_files_only=False
         )
-
+
         # Save to the global cache
         _global_tokenizers[model_name] = tokenizer
-
+
         return tokenizer  # type: ignore
 
 @functools.lru_cache(maxsize=1)
 def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
     """
     Load the rerank model and tokenizer, using a cache to avoid repeated loading.
-
+
     Returns:
         Tuple[AutoModelForSequenceClassification, AutoTokenizer]: The loaded model and tokenizer
     """
     model_name = "BAAI/bge-reranker-v2-m3"
     cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
-
+
     # Check the global cache
     key = f"rerank_{model_name}"
     if key in _global_models and f"{key}_tokenizer" in _global_tokenizers:
         return _global_models[key], _global_tokenizers[f"{key}_tokenizer"]
-
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
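
Note that load_rerank_model is cached twice: the @functools.lru_cache(maxsize=1) decorator already guarantees a single load for this zero-argument function, and the manual _global_models / _global_tokenizers lookup adds a second, redundant layer. The decorator alone gives the load-once behavior (toy stand-in, not package code):

import functools

@functools.lru_cache(maxsize=1)
def load_expensive_resource() -> str:
    # Runs only on the first call; later calls return the cached result.
    print("loading...")
    return "resource"

load_expensive_resource()  # prints "loading..."
load_expensive_resource()  # served from cache, no print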
@@ -396,53 +396,53 @@ def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokeniz
             cache_dir=cache_dir,
             local_files_only=False
         )
-
+
         if torch.cuda.is_available():
             model = model.cuda()
         model.eval()
-
+
         # Save to the global cache
         _global_models[key] = model
         _global_tokenizers[f"{key}_tokenizer"] = tokenizer
-
+
         return model, tokenizer  # type: ignore
 
-def rerank_results(query: str, documents: List[str], initial_scores: Optional[List[float]] = None,
+def rerank_results(query: str, documents: List[str], initial_scores: Optional[List[float]] = None,
                    batch_size: int = 8, spinner: Optional[Yaspin] = None) -> List[float]:
     """
     Rerank retrieval results with a cross-encoder to improve RAG precision.
-
+
     Args:
         query: The query text
         documents: The list of document contents to rerank
         initial_scores: Optional initial retrieval scores; if provided, they are fused with the rerank scores
         batch_size: Batch size
         spinner: Optional progress indicator
-
+
     Returns:
         List[float]: The reranked scores, aligned with the input documents
     """
     try:
         if not documents:
             return []
-
+
         # Load the rerank model
         if spinner:
             spinner.text = "Loading rerank model..."
         model, tokenizer = load_rerank_model()
-
+
         # Prepare for scoring
         all_scores = []
-
+
         # Process in batches
         for i in range(0, len(documents), batch_size):
             if spinner:
                 spinner.text = f"Rerank progress: {i}/{len(documents)}..."
-
+
             # Prepare the current batch
             batch_docs = documents[i:i+batch_size]
             pairs = [(query, doc) for doc in batch_docs]
-
+
             # Encode the inputs
             with torch.no_grad():
                 # Use type ignore to avoid mypy errors
@@ -453,21 +453,21 @@ def rerank_results(query: str, documents: List[str], initial_scores: Optional[Li
                     return_tensors="pt",
                     max_length=512
                 )
-
+
                 # Use GPU acceleration if available
                 if torch.cuda.is_available():
                     inputs = {k: v.cuda() for k, v in inputs.items()}
-
+
                 # Get the scores
                 outputs = model(**inputs)  # type: ignore
                 scores = outputs.logits.squeeze(-1).cpu().tolist()
-
+
             # If there is only one document, make sure a list is returned
             if len(batch_docs) == 1:
                 all_scores.append(float(scores))
             else:
                 all_scores.extend(scores)
-
+
         # Normalize the scores to the 0-1 range
         if all_scores:
             min_score = min(all_scores)
@@ -476,26 +476,26 @@ def rerank_results(query: str, documents: List[str], initial_scores: Optional[Li
             normalized_scores = [(score - min_score) / (max_score - min_score) for score in all_scores]
         else:
             normalized_scores = [0.5] * len(all_scores)
-
+
         # Fuse with the initial scores if provided
         if initial_scores and len(initial_scores) == len(normalized_scores):
             # Weighted-average fusion: initial scores weighted 0.3, rerank scores weighted 0.7
-            final_scores = [0.3 * init_score + 0.7 * rerank_score
+            final_scores = [0.3 * init_score + 0.7 * rerank_score
                             for init_score, rerank_score in zip(initial_scores, normalized_scores)]
             return final_scores
-
+
         return normalized_scores
-
+
         if spinner:
             spinner.text = "Rerank complete"
-
+
        # If reranking fails, return the initial scores or default scores
        return initial_scores if initial_scores else [0.5] * len(documents)
-
+
     except Exception as e:
         PrettyOutput.print(f"Rerank failed: {str(e)}", OutputType.ERROR)
         if spinner:
             spinner.text = f"Rerank failed: {str(e)}"
-
+
         # Fall back to the initial scores on error
         return initial_scores if initial_scores else [0.5] * len(documents)
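
The scoring pipeline at the end of rerank_results is plain arithmetic: cross-encoder logits are min-max normalized to [0, 1], then, when initial retrieval scores are supplied, blended 30/70 in favor of the reranker. (Also note that the two returns in the fusion block leave the later spinner update and fallback unreachable; the except clause is the effective error path.) A worked example with toy numbers, not from the package:

raw = [2.0, -1.0, 0.5]                              # cross-encoder logits
lo, hi = min(raw), max(raw)
normalized = [(s - lo) / (hi - lo) for s in raw]    # [1.0, 0.0, 0.5]

initial = [0.9, 0.2, 0.6]                           # retrieval scores
final = [0.3 * i + 0.7 * r for i, r in zip(initial, normalized)]
print(final)                                        # approximately [0.97, 0.06, 0.53]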