jarvis-ai-assistant 0.1.134__py3-none-any.whl → 0.1.138__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +201 -79
- jarvis/jarvis_agent/builtin_input_handler.py +16 -6
- jarvis/jarvis_agent/file_input_handler.py +9 -9
- jarvis/jarvis_agent/jarvis.py +10 -10
- jarvis/jarvis_agent/main.py +12 -11
- jarvis/jarvis_agent/output_handler.py +3 -3
- jarvis/jarvis_agent/patch.py +86 -62
- jarvis/jarvis_agent/shell_input_handler.py +5 -3
- jarvis/jarvis_code_agent/code_agent.py +134 -99
- jarvis/jarvis_code_agent/file_select.py +24 -24
- jarvis/jarvis_dev/main.py +45 -51
- jarvis/jarvis_git_details/__init__.py +0 -0
- jarvis/jarvis_git_details/main.py +179 -0
- jarvis/jarvis_git_squash/main.py +7 -7
- jarvis/jarvis_lsp/base.py +11 -11
- jarvis/jarvis_lsp/cpp.py +14 -14
- jarvis/jarvis_lsp/go.py +13 -13
- jarvis/jarvis_lsp/python.py +8 -8
- jarvis/jarvis_lsp/registry.py +21 -21
- jarvis/jarvis_lsp/rust.py +15 -15
- jarvis/jarvis_methodology/main.py +101 -0
- jarvis/jarvis_multi_agent/__init__.py +11 -11
- jarvis/jarvis_multi_agent/main.py +6 -6
- jarvis/jarvis_platform/__init__.py +1 -1
- jarvis/jarvis_platform/ai8.py +67 -89
- jarvis/jarvis_platform/base.py +14 -13
- jarvis/jarvis_platform/kimi.py +25 -28
- jarvis/jarvis_platform/ollama.py +24 -26
- jarvis/jarvis_platform/openai.py +15 -19
- jarvis/jarvis_platform/oyi.py +48 -50
- jarvis/jarvis_platform/registry.py +27 -28
- jarvis/jarvis_platform/yuanbao.py +38 -42
- jarvis/jarvis_platform_manager/main.py +81 -81
- jarvis/jarvis_platform_manager/openai_test.py +21 -21
- jarvis/jarvis_rag/file_processors.py +18 -18
- jarvis/jarvis_rag/main.py +261 -277
- jarvis/jarvis_smart_shell/main.py +12 -12
- jarvis/jarvis_tools/ask_codebase.py +28 -28
- jarvis/jarvis_tools/ask_user.py +8 -8
- jarvis/jarvis_tools/base.py +4 -4
- jarvis/jarvis_tools/chdir.py +9 -9
- jarvis/jarvis_tools/code_review.py +19 -19
- jarvis/jarvis_tools/create_code_agent.py +15 -15
- jarvis/jarvis_tools/execute_python_script.py +3 -3
- jarvis/jarvis_tools/execute_shell.py +11 -11
- jarvis/jarvis_tools/execute_shell_script.py +3 -3
- jarvis/jarvis_tools/file_analyzer.py +29 -29
- jarvis/jarvis_tools/file_operation.py +22 -20
- jarvis/jarvis_tools/find_caller.py +25 -25
- jarvis/jarvis_tools/find_methodolopy.py +65 -0
- jarvis/jarvis_tools/find_symbol.py +24 -24
- jarvis/jarvis_tools/function_analyzer.py +27 -27
- jarvis/jarvis_tools/git_commiter.py +9 -9
- jarvis/jarvis_tools/lsp_get_diagnostics.py +19 -19
- jarvis/jarvis_tools/methodology.py +23 -62
- jarvis/jarvis_tools/project_analyzer.py +29 -33
- jarvis/jarvis_tools/rag.py +15 -15
- jarvis/jarvis_tools/read_code.py +24 -22
- jarvis/jarvis_tools/read_webpage.py +31 -31
- jarvis/jarvis_tools/registry.py +72 -52
- jarvis/jarvis_tools/tool_generator.py +18 -18
- jarvis/jarvis_utils/config.py +23 -23
- jarvis/jarvis_utils/embedding.py +83 -83
- jarvis/jarvis_utils/git_utils.py +20 -20
- jarvis/jarvis_utils/globals.py +18 -6
- jarvis/jarvis_utils/input.py +10 -9
- jarvis/jarvis_utils/methodology.py +140 -136
- jarvis/jarvis_utils/output.py +11 -11
- jarvis/jarvis_utils/utils.py +22 -70
- {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/METADATA +1 -1
- jarvis_ai_assistant-0.1.138.dist-info/RECORD +85 -0
- {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/entry_points.txt +2 -0
- jarvis/jarvis_tools/select_code_files.py +0 -62
- jarvis_ai_assistant-0.1.134.dist-info/RECORD +0 -82
- {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.134.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/top_level.txt +0 -0
jarvis/jarvis_utils/config.py
CHANGED

Every change in this file is whitespace-only: each removed/added pair of lines differs only in invisible trailing or indentation whitespace, mostly on blank separator lines inside docstrings. (Docstrings and comments are translated from Chinese below; identifiers and string values are unchanged.)

@@ -12,35 +12,35 @@ import os
 def get_max_token_count() -> int:
     """
     Get the maximum number of tokens the model allows.
-
+
     Returns:
         int: The maximum number of tokens the model can handle.
     """
     return int(os.getenv('JARVIS_MAX_TOKEN_COUNT', '64000'))  # default 64k
-
+
 def get_thread_count() -> int:
     """
     Get the number of threads used for parallel processing.
-
+
     Returns:
         int: The thread count, default 1.
     """
-    return int(os.getenv('JARVIS_THREAD_COUNT', '1'))
-
+    return int(os.getenv('JARVIS_THREAD_COUNT', '1'))
+
 def is_auto_complete() -> bool:
     """
     Check whether auto-completion is enabled.
-
+
     Returns:
         bool: True if auto-completion is enabled, default False.
     """
     return os.getenv('JARVIS_AUTO_COMPLETE', 'false') == 'true'
-
+

 def get_min_paragraph_length() -> int:
     """
     Get the minimum paragraph length for text processing.
-
+
     Returns:
         int: The minimum character length, default 50.
     """

The hunks at @@ -48,7 @@, @@ -56,7 @@, @@ -64,7 @@, @@ -72,7 @@, @@ -82,7 @@, @@ -90,7 @@, @@ -99,7 @@, and @@ -107,7 @@ repeat the same pattern: one blank docstring separator line gets a whitespace-only rewrite, in get_max_paragraph_length (default 12800), get_shell_name (default bash), get_normal_platform_name (default 'yuanbao'), get_normal_model_name (default 'deep_seek'), get_thinking_platform_name (default 'yuanbao'), get_thinking_model_name (default 'deep_seek'), is_execute_tool_confirm (default False), and is_confirm_before_apply_patch (default False).

@@ -116,18 +116,18 @@ def is_confirm_before_apply_patch() -> bool:
 def get_rag_ignored_paths() -> list:
     """
     Get the list of paths to ignore when building the RAG index.
-
+
     Reads .jarvis/rag_ignore.txt first;
     if that file does not exist, returns the default ignore list.
-
+
     Returns:
         list: The list of ignored paths, including common defaults.
     """
     # Default ignored paths
     default_ignored = [
-        '.git',
-        '__pycache__',
-        'node_modules',
+        '.git',
+        '__pycache__',
+        'node_modules',
         '.jarvis',
         '.jarvis-*',
         'target',

@@ -164,7 +164,7 @@ def get_rag_ignored_paths() -> list:
         '*.xz',
         '*.rar'
     ]
-
+
     # Try to read from the config file
     try:
         config_path = os.path.join('.jarvis', 'rag_ignore.txt')

@@ -174,5 +174,5 @@ def get_rag_ignored_paths() -> list:
             return custom_ignored
     except Exception:
         pass
-
+
     return default_ignored
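All of these getters call os.getenv on every invocation, so setting the variable in the process environment changes the result immediately. A minimal usage sketch, assuming the jarvis.jarvis_utils.config import path implied by this diff:

import os

os.environ['JARVIS_MAX_TOKEN_COUNT'] = '128000'
os.environ['JARVIS_THREAD_COUNT'] = '4'

from jarvis.jarvis_utils.config import get_max_token_count, get_thread_count

print(get_max_token_count())  # 128000 (default would be 64000)
print(get_thread_count())     # 4 (default would be 1)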
jarvis/jarvis_utils/embedding.py
CHANGED

As with config.py, every hunk in this file is whitespace-only; each removed/added line pair is otherwise identical. (Docstrings and comments are translated from Chinese below; string literals are kept as in the source.)

@@ -15,10 +15,10 @@ _global_tokenizers = {}
 
 def get_context_token_count(text: str) -> int:
     """Get the token count of a text using the tokenizer.
-
+
     Args:
         text: The input text to count tokens for.
-
+
     Returns:
         int: The number of tokens in the text.
     """

@@ -27,7 +27,7 @@ def get_context_token_count(text: str) -> int:
         tokenizer = load_tokenizer()
         chunks = split_text_into_chunks(text, 512)
         return sum([len(tokenizer.encode(chunk)) for chunk in chunks])  # type: ignore
-
+
     except Exception as e:
         PrettyOutput.print(f"计算token失败: {str(e)}", OutputType.WARNING)
         # Fall back to a rough character-based estimate
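The counting strategy is tokenize-per-chunk and sum. A minimal standalone sketch using the gpt2 tokenizer that load_tokenizer loads further down in this file; the fixed-width slicing here is a simplified stand-in for split_text_into_chunks:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
text = "some long document " * 200
# Fixed-width 512-character slices as a stand-in for split_text_into_chunks
chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
print(sum(len(tok.encode(chunk)) for chunk in chunks))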
@@ -37,17 +37,17 @@ def get_context_token_count(text: str) -> int:
 def load_embedding_model() -> SentenceTransformer:
     """
     Load the sentence-embedding model, cached to avoid repeated loading.
-
+
     Returns:
         SentenceTransformer: The loaded embedding model.
     """
     model_name = "BAAI/bge-m3"
     cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
-
+
     # Check whether the model is already in the global cache
     if model_name in _global_models:
         return _global_models[model_name]
-
+
     try:
         embedding_model = SentenceTransformer(
             model_name,

@@ -60,28 +60,28 @@ def load_embedding_model() -> SentenceTransformer:
             cache_folder=cache_dir,
             local_files_only=False
         )
-
+
         # Move the model to the GPU if available
         if torch.cuda.is_available():
             embedding_model.to(torch.device("cuda"))
-
+
         # Save to the global cache
         _global_models[model_name] = embedding_model
-
+
         return embedding_model
 
 def get_embedding(embedding_model: Any, text: str) -> np.ndarray:
     """
     Generate an embedding vector for the given text.
-
+
     Args:
         embedding_model: The embedding model to use.
         text: The input text to embed.
-
+
     Returns:
         np.ndarray: The embedding vector.
     """
-    embedding = embedding_model.encode(text,
+    embedding = embedding_model.encode(text,
                                        normalize_embeddings=True,
                                        show_progress_bar=False)
     return np.array(embedding, dtype=np.float32)
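A minimal usage sketch of the loader/encoder pair above, with the same model name and encode flags as in the diff; the global-cache and GPU plumbing are omitted:

import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-m3")  # same model as load_embedding_model()
vec = model.encode("hello jarvis", normalize_embeddings=True, show_progress_bar=False)
vec = np.array(vec, dtype=np.float32)
print(vec.shape)  # bge-m3 dense vectors are 1024-dimensional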
@@ -89,53 +89,53 @@ def get_embedding(embedding_model: Any, text: str) -> np.ndarray:
 def get_embedding_batch(embedding_model: Any, prefix: str, texts: List[str], spinner: Optional[Yaspin] = None, batch_size: int = 8) -> np.ndarray:
     """
     Generate embedding vectors for a batch of texts with efficient batching, optimized for RAG.
-
+
     Args:
         embedding_model: The embedding model to use.
         prefix: Progress-bar prefix.
         texts: The list of texts to embed.
         spinner: Optional progress indicator.
         batch_size: Batch size; larger values may be faster but use more memory.
-
+
     Returns:
         np.ndarray: The stacked embedding vectors.
     """
     # Simple embedding cache to avoid recomputing identical text chunks
     embedding_cache = {}
     cache_hits = 0
-
+
     try:
         # Preprocessing: split all texts into chunks
         all_chunks = []
         chunk_indices = []  # Track the chunk index range of each original text
-
+
         for i, text in enumerate(texts):
             if spinner:
                 spinner.text = f"{prefix} 预处理中 ({i+1}/{len(texts)}) ..."
-
+
             # Preprocess the text: strip extra whitespace, normalize
             text = ' '.join(text.split()) if text else ""
-
+
             # Use the more optimized chunking function
             chunks = split_text_into_chunks(text, 512)
             start_idx = len(all_chunks)
             all_chunks.extend(chunks)
             end_idx = len(all_chunks)
             chunk_indices.append((start_idx, end_idx))
-
+
         if not all_chunks:
             return np.zeros((0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32)
-
+
         # Process all chunks in batches
         all_vectors = []
         for i in range(0, len(all_chunks), batch_size):
             if spinner:
                 spinner.text = f"{prefix} 批量处理嵌入 ({i+1}/{len(all_chunks)}) ..."
-
+
             batch = all_chunks[i:i+batch_size]
             batch_to_process = []
             batch_indices = []
-
+
             # Check the cache to avoid recomputing identical chunks
             for j, chunk in enumerate(batch):
                 chunk_hash = hash(chunk)

@@ -145,16 +145,16 @@ def get_embedding_batch(embedding_model: Any, prefix: str, texts: List[str], spi
             else:
                 batch_to_process.append(chunk)
                 batch_indices.append(j)
-
+
             if batch_to_process:
                 # Encode the uncached chunks
                 batch_vectors = embedding_model.encode(
-                    batch_to_process,
+                    batch_to_process,
                     normalize_embeddings=True,
                     show_progress_bar=False,
                     convert_to_numpy=True,
                 )
-
+
                 # Store the results and update the cache
                 if len(batch_to_process) == 1:
                     vec = batch_vectors

@@ -166,7 +166,7 @@ def get_embedding_batch(embedding_model: Any, prefix: str, texts: List[str], spi
                     chunk_hash = hash(batch_to_process[j])
                     embedding_cache[chunk_hash] = vec
                 all_vectors.append(vec)
-
+
         # Organize the results back into the original text order
         result_vectors = []
         for start_idx, end_idx in chunk_indices:
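The embedding_cache above keys on hash(chunk), so a chunk repeated across texts is encoded only once per call. A toy standalone version of that idea (encode_cached is a hypothetical name):

cache = {}

def encode_cached(model, chunk):
    # Encode each distinct chunk once; repeats are served from the dict.
    key = hash(chunk)
    if key not in cache:
        cache[key] = model.encode(chunk, normalize_embeddings=True,
                                  show_progress_bar=False)
    return cache[key]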
@@ -174,73 +174,73 @@ def get_embedding_batch(embedding_model: Any, prefix: str, texts: List[str], spi
             for j in range(start_idx, end_idx):
                 if j < len(all_vectors):
                     text_vectors.append(all_vectors[j])
-
+
             if text_vectors:
                 # When a text was split into several chunks, use a weighted average
                 if len(text_vectors) > 1:
                     # RAG optimization: weight earlier chunks slightly higher
                     weights = np.linspace(1.0, 0.8, len(text_vectors))
                     weights = weights / weights.sum()  # Normalize the weights
-
+
                     # Apply the weights and sum
                     weighted_sum = np.zeros_like(text_vectors[0])
                     for i, vec in enumerate(text_vectors):
                         # Keep vector shapes consistent; handle possible dimension mismatches
                         vec_array = np.asarray(vec).reshape(weighted_sum.shape)
                         weighted_sum += vec_array * weights[i]
-
+
                     # Normalize the resulting vector
                     norm = np.linalg.norm(weighted_sum)
                     if norm > 0:
                         weighted_sum = weighted_sum / norm
-
+
                     result_vectors.append(weighted_sum)
                 else:
                     # A single chunk is used directly
                     result_vectors.append(text_vectors[0])
-
+
         if spinner and cache_hits > 0:
             spinner.text = f"{prefix} 缓存命中: {cache_hits}/{len(all_chunks)} 块"
-
+
         return np.vstack(result_vectors)
-
+
     except Exception as e:
         PrettyOutput.print(f"批量嵌入失败: {str(e)}", OutputType.ERROR)
         return np.zeros((0, embedding_model.get_sentence_embedding_dimension()), dtype=np.float32)
-
+
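The multi-chunk pooling above is linearly decaying weights plus L2 renormalization. A standalone sketch of just that arithmetic (pool_chunk_vectors is a hypothetical name):

import numpy as np

def pool_chunk_vectors(chunk_vecs):
    # Weights decay linearly from 1.0 to 0.8, then are normalized to sum to 1,
    # so earlier chunks contribute slightly more to the pooled vector.
    weights = np.linspace(1.0, 0.8, len(chunk_vecs))
    weights = weights / weights.sum()
    pooled = sum(w * np.asarray(v) for w, v in zip(weights, chunk_vecs))
    norm = np.linalg.norm(pooled)
    return pooled / norm if norm > 0 else pooled

print(pool_chunk_vectors([np.ones(4), np.zeros(4), np.ones(4)]))  # [0.5 0.5 0.5 0.5]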
(hunk @@ -174,73 continued)
 def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 50) -> List[str]:
     """Split text into chunks with overlapping windows, optimized for RAG retrieval.
-
+
     Args:
         text: The input text to split.
         max_length: The maximum length of each chunk.
         min_length: The minimum length of each chunk (the last chunk may be shorter).
-
+
     Returns:
         List[str]: The list of text chunks, each as close to max_length as possible without exceeding it.
     """
     if not text:
         return []
-
+
     # If the text is shorter than the maximum length, return it whole
     if len(text) <= max_length:
         return [text]
-
+
     # Preprocessing: normalize the text and strip extra whitespace
     text = ' '.join(text.split())
-
+
     # Chinese and English punctuation sets, for sentence boundaries that help RAG recall
     primary_punctuation = {'.', '!', '?', '\n', '。', '!', '?'}  # Main end-of-sentence punctuation
     secondary_punctuation = {';', ':', '…', ';', ':'}  # Secondary separators
     tertiary_punctuation = {',', ',', '、', ')', ')', ']', '】', '}', '》', '"', "'"}  # Lowest priority
-
+
     chunks = []
     start = 0
-
+
     while start < len(text):
         # Initialize the end position to the maximum possible length
         end = min(start + max_length, len(text))
-
+
         # Only look for a sentence boundary when this is not the last chunk and the end hit max length
         if end < len(text) and end == start + max_length:
             # Prefer paragraph boundaries first; these matter most for RAG

@@ -251,17 +251,17 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
             # Look for a sentence boundary, starting from position end-1
             found_boundary = False
             best_boundary = -1
-
+
             # Widen the search range to find a better semantic boundary
             search_range = min(120, end - start - min_length)  # Wider search, but keep the new chunk >= min_length
-
+
             # First try the primary punctuation (periods etc.)
             for i in range(end-1, max(start, end-search_range), -1):
                 if text[i] in primary_punctuation:
                     best_boundary = i
                     found_boundary = True
                     break
-
+
             # If no primary punctuation was found, try the secondary (semicolons, colons etc.)
             if not found_boundary:
                 for i in range(end-1, max(start, end-search_range), -1):

@@ -269,7 +269,7 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
                     best_boundary = i
                     found_boundary = True
                     break
-
+
             # Finally consider commas and other possible boundaries
             if not found_boundary:
                 for i in range(end-1, max(start, end-search_range), -1):

@@ -277,11 +277,11 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
                     best_boundary = i
                     found_boundary = True
                     break
-
+
             # Use the boundary if it was found and does not produce a too-short chunk
             if found_boundary and (best_boundary - start) >= min_length:
                 end = best_boundary + 1
-
+
         # Append the current chunk, stripping leading and trailing whitespace
         chunk = text[start:end].strip()
         if chunk and len(chunk) >= min_length:  # Only keep non-empty chunks that meet the minimum length

@@ -295,16 +295,16 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
             else:
                 # If merging would make the result too long, append this small chunk (special case)
                 chunks.append(chunk)
-
+
         # Compute the next start; the overlap window size improves RAG retrieval quality
         next_start = end - int(max_length * 0.2)  # 20% overlap window
-
+
         # Always make forward progress to avoid an infinite loop
         if next_start <= start:
             next_start = start + max(1, min_length // 2)
-
+
         start = next_start
-
+
     # Finally, check for chunks that are too short and merge adjacent short chunks
     if len(chunks) > 1:
         merged_chunks = []

@@ -321,7 +321,7 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
             merged_chunks.append(current)
             i += 1
         chunks = merged_chunks
-
+
     return chunks
 
 
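The overlap arithmetic above is easy to verify by hand. A toy sketch of just the 20% sliding window (punctuation boundary search and short-chunk merging omitted; window_starts is a hypothetical name):

def window_starts(text_len, max_length=512, min_length=50):
    # Each next window starts 20% of max_length before the previous end,
    # with a forced-progress guard, mirroring the loop in the diff.
    spans, start = [], 0
    while start < text_len:
        end = min(start + max_length, text_len)
        spans.append((start, end))
        next_start = end - int(max_length * 0.2)   # 20% overlap
        if next_start <= start:                    # always make progress
            next_start = start + max(1, min_length // 2)
        start = next_start
    return spans

print(window_starts(1200)[:3])  # [(0, 512), (410, 922), (820, 1200)]

The short tail windows this toy version emits near the end of the text are exactly what the real function's min_length filter and merge pass clean up.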
@@ -329,17 +329,17 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
 def load_tokenizer() -> AutoTokenizer:
     """
     Load the tokenizer used for text processing, cached to avoid repeated loading.
-
+
     Returns:
         AutoTokenizer: The loaded tokenizer.
     """
     model_name = "gpt2"
     cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
-
+
     # Check the global cache
     if model_name in _global_tokenizers:
         return _global_tokenizers[model_name]
-
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,

@@ -352,28 +352,28 @@ def load_tokenizer() -> AutoTokenizer:
             cache_dir=cache_dir,
             local_files_only=False
         )
-
+
         # Save to the global cache
         _global_tokenizers[model_name] = tokenizer
-
+
         return tokenizer  # type: ignore
 
 @functools.lru_cache(maxsize=1)
 def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokenizer]:
     """
     Load the reranking model and tokenizer, cached to avoid repeated loading.
-
+
     Returns:
         Tuple[AutoModelForSequenceClassification, AutoTokenizer]: The loaded model and tokenizer.
     """
     model_name = "BAAI/bge-reranker-v2-m3"
     cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
-
+
     # Check the global cache
     key = f"rerank_{model_name}"
     if key in _global_models and f"{key}_tokenizer" in _global_tokenizers:
         return _global_models[key], _global_tokenizers[f"{key}_tokenizer"]
-
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
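BAAI/bge-reranker-v2-m3 is a cross-encoder, so it scores each (query, document) pair directly rather than embedding the two sides separately. A minimal sketch of the tokenize-and-score pattern that rerank_results applies in the hunks below; caching, batching, and GPU handling are omitted:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "BAAI/bge-reranker-v2-m3"  # same model as load_rerank_model()
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name)
model.eval()

pairs = [("what does jarvis do?", "Jarvis is an AI assistant."),
         ("what does jarvis do?", "Bananas are yellow.")]
with torch.no_grad():
    inputs = tok(pairs, padding=True, truncation=True,
                 return_tensors="pt", max_length=512)
    scores = model(**inputs).logits.squeeze(-1).tolist()
print(scores)  # higher logit = more relevant pair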
@@ -396,53 +396,53 @@ def load_rerank_model() -> Tuple[AutoModelForSequenceClassification, AutoTokeniz
             cache_dir=cache_dir,
             local_files_only=False
         )
-
+
         if torch.cuda.is_available():
             model = model.cuda()
         model.eval()
-
+
         # Save to the global cache
         _global_models[key] = model
         _global_tokenizers[f"{key}_tokenizer"] = tokenizer
-
+
         return model, tokenizer  # type: ignore
 
-def rerank_results(query: str, documents: List[str], initial_scores: Optional[List[float]] = None,
+def rerank_results(query: str, documents: List[str], initial_scores: Optional[List[float]] = None,
                    batch_size: int = 8, spinner: Optional[Yaspin] = None) -> List[float]:
     """
     Rerank retrieval results with a cross-encoder to improve RAG precision.
-
+
     Args:
         query: The query text.
         documents: The list of document contents to rerank.
         initial_scores: Optional initial retrieval scores; if given, they are fused with the rerank scores.
         batch_size: Batch size.
         spinner: Optional progress indicator.
-
+
     Returns:
         List[float]: The reranked scores, aligned with the input documents.
     """
     try:
         if not documents:
             return []
-
+
         # Load the reranking model
         if spinner:
             spinner.text = "加载重排序模型..."
         model, tokenizer = load_rerank_model()
-
+
         # Prepare the scoring
         all_scores = []
-
+
         # Process in batches
         for i in range(0, len(documents), batch_size):
             if spinner:
                 spinner.text = f"重排序进度: {i}/{len(documents)}..."
-
+
             # Prepare the current batch
             batch_docs = documents[i:i+batch_size]
             pairs = [(query, doc) for doc in batch_docs]
-
+
             # Encode the inputs
             with torch.no_grad():
                 # Type-ignore to avoid mypy errors

@@ -453,21 +453,21 @@ def rerank_results(query: str, documents: List[str], initial_scores: Optional[Li
                     return_tensors="pt",
                     max_length=512
                 )
-
+
                 # Use GPU acceleration if available
                 if torch.cuda.is_available():
                     inputs = {k: v.cuda() for k, v in inputs.items()}
-
+
                 # Get the scores
                 outputs = model(**inputs)  # type: ignore
                 scores = outputs.logits.squeeze(-1).cpu().tolist()
-
+
                 # Make sure a single document still yields a list
                 if len(batch_docs) == 1:
                     all_scores.append(float(scores))
                 else:
                     all_scores.extend(scores)
-
+
         # Normalize the scores to the 0-1 range
         if all_scores:
             min_score = min(all_scores)

@@ -476,26 +476,26 @@ def rerank_results(query: str, documents: List[str], initial_scores: Optional[Li
             normalized_scores = [(score - min_score) / (max_score - min_score) for score in all_scores]
         else:
             normalized_scores = [0.5] * len(all_scores)
-
+
         # Fuse the initial scores if provided
         if initial_scores and len(initial_scores) == len(normalized_scores):
             # Weighted-average fusion: initial score weight 0.3, rerank score weight 0.7
-            final_scores = [0.3 * init_score + 0.7 * rerank_score
+            final_scores = [0.3 * init_score + 0.7 * rerank_score
                             for init_score, rerank_score in zip(initial_scores, normalized_scores)]
             return final_scores
-
+
         return normalized_scores
-
+
         if spinner:
             spinner.text = "重排序完成"
-
+
         # If reranking failed, return the initial scores or default scores
         return initial_scores if initial_scores else [0.5] * len(documents)
-
+
     except Exception as e:
         PrettyOutput.print(f"重排序失败: {str(e)}", OutputType.ERROR)
         if spinner:
             spinner.text = f"重排序失败: {str(e)}"
-
+
         # Fall back to the initial scores on error
         return initial_scores if initial_scores else [0.5] * len(documents)