jarvis-ai-assistant 0.1.193__py3-none-any.whl → 0.1.195__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +45 -41
- jarvis/jarvis_agent/builtin_input_handler.py +26 -4
- jarvis/jarvis_agent/jarvis.py +30 -19
- jarvis/jarvis_agent/main.py +20 -12
- jarvis/jarvis_agent/output_handler.py +7 -7
- jarvis/jarvis_agent/shell_input_handler.py +14 -11
- jarvis/jarvis_code_agent/code_agent.py +81 -79
- jarvis/jarvis_code_agent/lint.py +92 -105
- jarvis/jarvis_code_analysis/checklists/__init__.py +1 -1
- jarvis/jarvis_code_analysis/checklists/c_cpp.py +1 -1
- jarvis/jarvis_code_analysis/checklists/csharp.py +1 -1
- jarvis/jarvis_code_analysis/checklists/data_format.py +1 -1
- jarvis/jarvis_code_analysis/checklists/devops.py +1 -1
- jarvis/jarvis_code_analysis/checklists/docs.py +1 -1
- jarvis/jarvis_code_analysis/checklists/go.py +1 -1
- jarvis/jarvis_code_analysis/checklists/infrastructure.py +1 -1
- jarvis/jarvis_code_analysis/checklists/java.py +1 -1
- jarvis/jarvis_code_analysis/checklists/javascript.py +1 -1
- jarvis/jarvis_code_analysis/checklists/kotlin.py +1 -1
- jarvis/jarvis_code_analysis/checklists/loader.py +31 -29
- jarvis/jarvis_code_analysis/checklists/php.py +1 -1
- jarvis/jarvis_code_analysis/checklists/python.py +1 -1
- jarvis/jarvis_code_analysis/checklists/ruby.py +1 -1
- jarvis/jarvis_code_analysis/checklists/rust.py +1 -1
- jarvis/jarvis_code_analysis/checklists/shell.py +1 -1
- jarvis/jarvis_code_analysis/checklists/sql.py +1 -1
- jarvis/jarvis_code_analysis/checklists/swift.py +1 -1
- jarvis/jarvis_code_analysis/checklists/web.py +1 -1
- jarvis/jarvis_code_analysis/code_review.py +292 -190
- jarvis/jarvis_dev/main.py +73 -56
- jarvis/jarvis_git_details/main.py +29 -33
- jarvis/jarvis_git_squash/main.py +13 -11
- jarvis/jarvis_git_utils/git_commiter.py +15 -5
- jarvis/jarvis_mcp/__init__.py +8 -10
- jarvis/jarvis_mcp/sse_mcp_client.py +182 -205
- jarvis/jarvis_mcp/stdio_mcp_client.py +93 -120
- jarvis/jarvis_mcp/streamable_mcp_client.py +117 -142
- jarvis/jarvis_methodology/main.py +71 -39
- jarvis/jarvis_multi_agent/__init__.py +24 -16
- jarvis/jarvis_multi_agent/main.py +10 -4
- jarvis/jarvis_platform/__init__.py +1 -1
- jarvis/jarvis_platform/base.py +44 -18
- jarvis/jarvis_platform/human.py +15 -3
- jarvis/jarvis_platform/kimi.py +117 -81
- jarvis/jarvis_platform/openai.py +23 -28
- jarvis/jarvis_platform/registry.py +43 -29
- jarvis/jarvis_platform/tongyi.py +16 -10
- jarvis/jarvis_platform/yuanbao.py +197 -144
- jarvis/jarvis_platform_manager/main.py +4 -2
- jarvis/jarvis_smart_shell/main.py +35 -30
- jarvis/jarvis_tools/ask_user.py +8 -16
- jarvis/jarvis_tools/base.py +3 -2
- jarvis/jarvis_tools/chdir.py +7 -19
- jarvis/jarvis_tools/cli/main.py +14 -10
- jarvis/jarvis_tools/code_plan.py +10 -31
- jarvis/jarvis_tools/create_code_agent.py +6 -11
- jarvis/jarvis_tools/create_sub_agent.py +10 -22
- jarvis/jarvis_tools/edit_file.py +98 -76
- jarvis/jarvis_tools/execute_script.py +46 -46
- jarvis/jarvis_tools/file_analyzer.py +22 -34
- jarvis/jarvis_tools/file_operation.py +69 -62
- jarvis/jarvis_tools/generate_new_tool.py +0 -2
- jarvis/jarvis_tools/methodology.py +19 -23
- jarvis/jarvis_tools/read_code.py +35 -35
- jarvis/jarvis_tools/read_webpage.py +7 -16
- jarvis/jarvis_tools/registry.py +63 -30
- jarvis/jarvis_tools/rewrite_file.py +26 -29
- jarvis/jarvis_tools/search_web.py +5 -8
- jarvis/jarvis_tools/virtual_tty.py +133 -122
- jarvis/jarvis_utils/__init__.py +0 -1
- jarvis/jarvis_utils/builtin_replace_map.py +9 -9
- jarvis/jarvis_utils/config.py +60 -37
- jarvis/jarvis_utils/embedding.py +24 -19
- jarvis/jarvis_utils/file_processors.py +16 -9
- jarvis/jarvis_utils/git_utils.py +157 -107
- jarvis/jarvis_utils/globals.py +1 -1
- jarvis/jarvis_utils/input.py +85 -52
- jarvis/jarvis_utils/jarvis_history.py +43 -0
- jarvis/jarvis_utils/methodology.py +31 -24
- jarvis/jarvis_utils/output.py +164 -80
- jarvis/jarvis_utils/tag.py +2 -1
- jarvis/jarvis_utils/utils.py +84 -52
- {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.195.dist-info}/METADATA +362 -230
- jarvis_ai_assistant-0.1.195.dist-info/RECORD +98 -0
- jarvis/jarvis_agent/file_input_handler.py +0 -112
- jarvis/jarvis_event/__init__.py +0 -0
- jarvis_ai_assistant-0.1.193.dist-info/RECORD +0 -99
- {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.195.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.195.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.195.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.195.dist-info}/top_level.txt +0 -0
jarvis/jarvis_utils/config.py
CHANGED
@@ -1,16 +1,12 @@
 # -*- coding: utf-8 -*-
 import os
 from functools import lru_cache
-
 from typing import Any, Dict, List

-
 import yaml

-
 from jarvis.jarvis_utils.builtin_replace_map import BUILTIN_REPLACE_MAP

-
 # 全局环境变量存储

 GLOBAL_CONFIG_DATA: Dict[str, Any] = {}
@@ -21,6 +17,7 @@ def set_global_env_data(env_data: Dict[str, Any]) -> None:
     global GLOBAL_CONFIG_DATA
     GLOBAL_CONFIG_DATA = env_data

+
 def set_config(key: str, value: Any) -> None:
     """设置配置"""
     GLOBAL_CONFIG_DATA[key] = value
@@ -32,48 +29,53 @@ def set_config(key: str, value: Any) -> None:
 所有配置都从环境变量中读取,带有回退默认值。
 """

+
 def get_git_commit_prompt() -> str:
     """
     获取Git提交提示模板
-
+
     返回:
         str: Git提交信息生成提示模板,如果未配置则返回空字符串
     """
     return GLOBAL_CONFIG_DATA.get("JARVIS_GIT_COMMIT_PROMPT", "")

+
 # 输出窗口预留大小
 INPUT_WINDOW_REVERSE_SIZE = 2048

+
 @lru_cache(maxsize=None)
 def get_replace_map() -> dict:
     """
     获取替换映射表。
-
+
     优先使用GLOBAL_CONFIG_DATA['JARVIS_REPLACE_MAP']的配置,
     如果没有则从数据目录下的replace_map.yaml文件中读取替换映射表,
     如果文件不存在则返回内置替换映射表。
-
+
     返回:
         dict: 合并后的替换映射表字典(内置+文件中的映射表)
     """
-    if
-    return {**BUILTIN_REPLACE_MAP, **GLOBAL_CONFIG_DATA[
-
-    replace_map_path = os.path.join(get_data_dir(),
+    if "JARVIS_REPLACE_MAP" in GLOBAL_CONFIG_DATA:
+        return {**BUILTIN_REPLACE_MAP, **GLOBAL_CONFIG_DATA["JARVIS_REPLACE_MAP"]}
+
+    replace_map_path = os.path.join(get_data_dir(), "replace_map.yaml")
     if not os.path.exists(replace_map_path):
         return BUILTIN_REPLACE_MAP.copy()
-
-    from jarvis.jarvis_utils.output import
+
+    from jarvis.jarvis_utils.output import OutputType, PrettyOutput
+
     PrettyOutput.print(
         "警告:使用replace_map.yaml进行配置的方式已被弃用,将在未来版本中移除。"
         "请迁移到使用GLOBAL_CONFIG_DATA中的JARVIS_REPLACE_MAP配置。",
-        output_type=OutputType.WARNING
+        output_type=OutputType.WARNING,
     )
-
-    with open(replace_map_path,
+
+    with open(replace_map_path, "r", encoding="utf-8", errors="ignore") as file:
         file_map = yaml.safe_load(file) or {}
         return {**BUILTIN_REPLACE_MAP, **file_map}

+
 def get_max_token_count() -> int:
     """
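The get_replace_map hunk above deprecates replace_map.yaml in favor of a JARVIS_REPLACE_MAP entry in the global configuration, merged over BUILTIN_REPLACE_MAP. A minimal usage sketch, not part of the diff; it assumes set_global_env_data and get_replace_map are importable from jarvis.jarvis_utils.config, and the mapping entry is hypothetical:

```python
from jarvis.jarvis_utils.config import get_replace_map, set_global_env_data

# Supply the replace map through configuration instead of replace_map.yaml.
set_global_env_data(
    {"JARVIS_REPLACE_MAP": {"example_key": "example replacement"}}  # hypothetical entry
)

# Entries from BUILTIN_REPLACE_MAP are merged with the configured ones.
merged = get_replace_map()
print(sorted(merged))
```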
@@ -81,7 +83,8 @@ def get_max_token_count() -> int:
     返回:
         int: 模型能处理的最大token数量。
     """
-    return int(GLOBAL_CONFIG_DATA.get(
+    return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_TOKEN_COUNT", "960000"))
+

 def get_max_input_token_count() -> int:
     """
@@ -90,7 +93,7 @@ def get_max_input_token_count() -> int:
     返回:
         int: 模型能处理的最大输入token数量。
     """
-    return int(GLOBAL_CONFIG_DATA.get(
+    return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_INPUT_TOKEN_COUNT", "32000"))


 def is_auto_complete() -> bool:
@@ -100,7 +103,7 @@ def is_auto_complete() -> bool:
     返回:
         bool: 如果启用了自动补全则返回True,默认为False
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get("JARVIS_AUTO_COMPLETE", False) == True


 def get_shell_name() -> str:
@@ -110,8 +113,10 @@ def get_shell_name() -> str:
     返回:
         str: Shell名称(例如bash, zsh),默认为bash
     """
-    shell_path = GLOBAL_CONFIG_DATA.get(
+    shell_path = GLOBAL_CONFIG_DATA.get("SHELL", "/bin/bash")
     return os.path.basename(shell_path)
+
+
 def get_normal_platform_name() -> str:
     """
     获取正常操作的平台名称。
@@ -119,7 +124,9 @@ def get_normal_platform_name() -> str:
     返回:
         str: 平台名称,默认为'yuanbao'
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get("JARVIS_PLATFORM", "yuanbao")
+
+
 def get_normal_model_name() -> str:
     """
     获取正常操作的模型名称。
@@ -127,7 +134,7 @@ def get_normal_model_name() -> str:
     返回:
         str: 模型名称,默认为'deep_seek'
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get("JARVIS_MODEL", "deep_seek_v3")


 def get_thinking_platform_name() -> str:
@@ -137,7 +144,11 @@ def get_thinking_platform_name() -> str:
     返回:
         str: 平台名称,默认为'yuanbao'
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get(
+        "JARVIS_THINKING_PLATFORM", GLOBAL_CONFIG_DATA.get("JARVIS_PLATFORM", "yuanbao")
+    )
+
+
 def get_thinking_model_name() -> str:
     """
     获取思考操作的模型名称。
@@ -145,7 +156,10 @@ def get_thinking_model_name() -> str:
     返回:
         str: 模型名称,默认为'deep_seek'
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get(
+        "JARVIS_THINKING_MODEL", GLOBAL_CONFIG_DATA.get("JARVIS_MODEL", "deep_seek")
+    )
+

 def is_execute_tool_confirm() -> bool:
     """
@@ -154,7 +168,9 @@ def is_execute_tool_confirm() -> bool:
     返回:
         bool: 如果需要确认则返回True,默认为False
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get("JARVIS_EXECUTE_TOOL_CONFIRM", False) == True
+
+
 def is_confirm_before_apply_patch() -> bool:
     """
     检查应用补丁前是否需要确认。
@@ -162,7 +178,8 @@ def is_confirm_before_apply_patch() -> bool:
     返回:
         bool: 如果需要确认则返回True,默认为False
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get("JARVIS_CONFIRM_BEFORE_APPLY_PATCH", True) == True
+

 def get_max_tool_call_count() -> int:
     """
@@ -171,22 +188,23 @@ def get_max_tool_call_count() -> int:
     返回:
         int: 最大连续工具调用次数,默认为20
     """
-    return int(GLOBAL_CONFIG_DATA.get(
+    return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_TOOL_CALL_COUNT", "20"))


 def get_data_dir() -> str:
     """
     获取Jarvis数据存储目录路径。
-
+
     返回:
         str: 数据目录路径,优先从JARVIS_DATA_PATH环境变量获取,
             如果未设置或为空,则使用~/.jarvis作为默认值
     """
-    data_path = GLOBAL_CONFIG_DATA.get(
+    data_path = GLOBAL_CONFIG_DATA.get("JARVIS_DATA_PATH", "").strip()
     if not data_path:
-        return os.path.expanduser(
+        return os.path.expanduser("~/.jarvis")
     return data_path

+
 def get_auto_update() -> bool:
     """
     获取是否自动更新git仓库。
@@ -194,7 +212,8 @@ def get_auto_update() -> bool:
     返回:
         bool: 如果需要自动更新则返回True,默认为True
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get("JARVIS_AUTO_UPDATE", True) == True
+

 def get_max_big_content_size() -> int:
     """
@@ -203,7 +222,8 @@ def get_max_big_content_size() -> int:
     返回:
         int: 最大大内容大小
     """
-    return int(GLOBAL_CONFIG_DATA.get(
+    return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_BIG_CONTENT_SIZE", "160000"))
+

 def get_pretty_output() -> bool:
     """
@@ -212,7 +232,8 @@ def get_pretty_output() -> bool:
     返回:
         bool: 如果启用PrettyOutput则返回True,默认为True
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get("JARVIS_PRETTY_OUTPUT", False) == True
+

 def is_use_methodology() -> bool:
     """
@@ -221,7 +242,8 @@ def is_use_methodology() -> bool:
     返回:
         bool: 如果启用方法论则返回True,默认为True
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get("JARVIS_USE_METHODOLOGY", True) == True
+

 def is_use_analysis() -> bool:
     """
@@ -230,7 +252,8 @@ def is_use_analysis() -> bool:
     返回:
         bool: 如果启用任务分析则返回True,默认为True
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get("JARVIS_USE_ANALYSIS", True) == True
+

 def is_print_prompt() -> bool:
     """
@@ -239,7 +262,7 @@ def is_print_prompt() -> bool:
     返回:
         bool: 如果打印提示则返回True,默认为True
     """
-    return GLOBAL_CONFIG_DATA.get(
+    return GLOBAL_CONFIG_DATA.get("JARVIS_PRINT_PROMPT", False) == True


 def get_mcp_config() -> List[Dict[str, Any]]:
@@ -249,4 +272,4 @@ def get_mcp_config() -> List[Dict[str, Any]]:
     返回:
         List[Dict[str, Any]]: MCP配置项列表,如果未配置则返回空列表
     """
-    return GLOBAL_CONFIG_DATA.get("JARVIS_MCP", [])
+    return GLOBAL_CONFIG_DATA.get("JARVIS_MCP", [])
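In the hunks above, get_thinking_platform_name and get_thinking_model_name fall back to JARVIS_PLATFORM and JARVIS_MODEL when no dedicated thinking setting is present. A minimal sketch of that fallback, not part of the diff; the accessors are assumed importable from jarvis.jarvis_utils.config and the platform value is hypothetical:

```python
from jarvis.jarvis_utils.config import (
    get_normal_platform_name,
    get_thinking_platform_name,
    set_global_env_data,
)

# Only JARVIS_PLATFORM is configured; the thinking accessor falls back to it.
set_global_env_data({"JARVIS_PLATFORM": "openai"})  # hypothetical platform name
print(get_normal_platform_name())    # -> openai
print(get_thinking_platform_name())  # -> openai (no JARVIS_THINKING_PLATFORM set)
```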
jarvis/jarvis_utils/embedding.py
CHANGED
@@ -1,9 +1,7 @@
 # -*- coding: utf-8 -*-
 import functools
 import os
-from typing import List
-
-from transformers import AutoTokenizer
+from typing import Any, List

 from jarvis.jarvis_utils.config import get_data_dir
 from jarvis.jarvis_utils.output import OutputType, PrettyOutput
@@ -11,6 +9,7 @@ from jarvis.jarvis_utils.output import OutputType, PrettyOutput
 # 全局缓存,避免重复加载模型
 _global_tokenizers = {}

+
 def get_context_token_count(text: str) -> int:
     """使用分词器获取文本的token数量。

@@ -21,12 +20,13 @@ def get_context_token_count(text: str) -> int:
         int: 文本中的token数量
     """
     try:
-
+        from transformers import AutoTokenizer # type: ignore
+        tokenizer : AutoTokenizer = load_tokenizer()
         # 分批处理长文本,确保不超过模型最大长度
         total_tokens = 0
         chunk_size = 100 # 每次处理100个字符,避免超过模型最大长度(考虑到中文字符可能被编码成多个token)
         for i in range(0, len(text), chunk_size):
-            chunk = text[i:i + chunk_size]
+            chunk = text[i : i + chunk_size]
             tokens = tokenizer.encode(chunk) # type: ignore
             total_tokens += len(tokens)
         return total_tokens
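In the hunks above, the transformers import moves inside get_context_token_count, which counts tokens by encoding the text in 100-character slices and summing the result. A standalone sketch of the same slicing idea, not part of the diff; it loads the gpt2 tokenizer directly rather than through load_tokenizer:

```python
from transformers import AutoTokenizer  # requires the transformers package


def count_tokens(text: str, chunk_size: int = 100) -> int:
    """Count tokens by encoding the text in small slices, as the diff does."""
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    total = 0
    for i in range(0, len(text), chunk_size):
        # A short slice keeps each encode() call well under the model's
        # maximum sequence length.
        total += len(tokenizer.encode(text[i : i + chunk_size]))
    return total


print(count_tokens("Jarvis 是一个 AI 助手。" * 40))
```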
@@ -34,7 +34,10 @@ def get_context_token_count(text: str) -> int:
         PrettyOutput.print(f"计算token失败: {str(e)}", OutputType.WARNING)
         return len(text) // 4 # 每个token大约4个字符的粗略估计

-
+
+
+def split_text_into_chunks(
+    text: str, max_length: int = 512, min_length: int = 50
+) -> List[str]:
     """将文本分割成块,基于token数量进行切割。

     参数:
@@ -52,15 +55,18 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
         chunks = []
         current_chunk = ""
         current_tokens = 0
-
+
         # 按较大的块处理文本,避免破坏token边界
         chunk_size = 50 # 每次处理50个字符
         for i in range(0, len(text), chunk_size):
-            chunk = text[i:i + chunk_size]
+            chunk = text[i : i + chunk_size]
             chunk_tokens = get_context_token_count(chunk)
-
+
             # 如果当前块加上新块会超过最大长度,且当前块已经达到最小长度,则保存当前块
-            if
+            if (
+                current_tokens + chunk_tokens > max_length
+                and current_tokens >= min_length
+            ):
                 chunks.append(current_chunk)
                 current_chunk = chunk
                 current_tokens = chunk_tokens
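split_text_into_chunks (new signature in the hunk above) accumulates 50-character slices until a chunk would exceed max_length tokens, flushing only once the current chunk holds at least min_length tokens. An illustrative call, assuming the function is importable from jarvis.jarvis_utils.embedding with the signature shown:

```python
from jarvis.jarvis_utils.embedding import split_text_into_chunks

long_text = "Jarvis project documentation text. " * 300  # any long string
chunks = split_text_into_chunks(long_text, max_length=512, min_length=50)
print(len(chunks), "chunks")
```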
@@ -77,17 +83,20 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
     except Exception as e:
         PrettyOutput.print(f"文本分割失败: {str(e)}", OutputType.WARNING)
         # 发生错误时回退到简单的字符分割
-        return [text[i:i + max_length] for i in range(0, len(text), max_length)]
+        return [text[i : i + max_length] for i in range(0, len(text), max_length)]


 @functools.lru_cache(maxsize=1)
-def load_tokenizer() ->
+def load_tokenizer() -> Any:
     """
     加载用于文本处理的分词器,使用缓存避免重复加载。

     返回:
         AutoTokenizer: 加载的分词器
     """
+
+    from transformers import AutoTokenizer # type: ignore
+
     model_name = "gpt2"
     cache_dir = os.path.join(get_data_dir(), "huggingface", "hub")

@@ -97,18 +106,14 @@ def load_tokenizer() -> AutoTokenizer:

     try:
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            cache_dir=cache_dir,
-            local_files_only=True
+            model_name, cache_dir=cache_dir, local_files_only=True
         )
     except Exception:
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name,
-            cache_dir=cache_dir,
-            local_files_only=False
+            model_name, cache_dir=cache_dir, local_files_only=False
         )

     # 保存到全局缓存
     _global_tokenizers[model_name] = tokenizer

-    return tokenizer
+    return tokenizer # type: ignore
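load_tokenizer now imports transformers lazily, memoizes the result with functools.lru_cache(maxsize=1), and tries the local HuggingFace cache (local_files_only=True) before falling back to a download. A minimal sketch of the same pattern, not part of the diff; the cache directory here is an example path rather than the package's get_data_dir() value:

```python
import functools
from typing import Any


@functools.lru_cache(maxsize=1)
def load_tokenizer_cached(cache_dir: str = "/tmp/hf-cache") -> Any:  # example path
    # Deferred import: the module stays importable even when transformers
    # is only needed once token counting is actually used.
    from transformers import AutoTokenizer

    try:
        # Prefer a copy that is already in the local cache.
        return AutoTokenizer.from_pretrained(
            "gpt2", cache_dir=cache_dir, local_files_only=True
        )
    except Exception:
        # Otherwise allow a one-time download.
        return AutoTokenizer.from_pretrained(
            "gpt2", cache_dir=cache_dir, local_files_only=False
        )
```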
jarvis/jarvis_utils/file_processors.py
CHANGED
@@ -4,6 +4,7 @@ import unicodedata

 class FileProcessor:
     """Base class for file processor"""
+
     @staticmethod
     def can_handle(file_path: str) -> bool:
         """Determine if the file can be processed"""
@@ -14,9 +15,11 @@ class FileProcessor:
         """Extract file text content"""
         raise NotImplementedError

+
 class TextFileProcessor(FileProcessor):
     """Text file processor"""
-
+
+    ENCODINGS = ["utf-8", "gbk", "gb2312", "latin1"]
     SAMPLE_SIZE = 8192 # Read the first 8KB to detect encoding

     @staticmethod
@@ -24,16 +27,20 @@ class TextFileProcessor(FileProcessor):
         """Determine if the file is a text file by trying to decode it"""
         try:
             # Read the first part of the file to detect encoding
-            with open(file_path,
+            with open(file_path, "rb") as f:
                 sample = f.read(TextFileProcessor.SAMPLE_SIZE)

             # Check if it contains null bytes (usually represents a binary file)
-            if b
+            if b"\x00" in sample:
                 return False

             # Check if it contains too many non-printable characters (usually represents a binary file)
-            non_printable = sum(
-
+            non_printable = sum(
+                1 for byte in sample if byte < 32 and byte not in (9, 10, 13)
+            ) # tab, newline, carriage return
+            if (
+                non_printable / len(sample) > 0.3
+            ): # If non-printable characters exceed 30%, it is considered a binary file
                 return False

             # Try to decode with different encodings
@@ -55,7 +62,7 @@ class TextFileProcessor(FileProcessor):
         detected_encoding = None
         try:
             # First try to detect encoding
-            with open(file_path,
+            with open(file_path, "rb") as f:
                 raw_data = f.read()

             # Try different encodings
@@ -68,14 +75,14 @@ class TextFileProcessor(FileProcessor):
                     continue

             if not detected_encoding:
-                raise UnicodeDecodeError(f"Failed to decode file with supported encodings: {file_path}")
+                raise UnicodeDecodeError(f"Failed to decode file with supported encodings: {file_path}") # type: ignore

             # Use the detected encoding to read the file
-            with open(file_path,
+            with open(file_path, "r", encoding=detected_encoding, errors="ignore") as f:
                 content = f.read()

             # Normalize Unicode characters
-            content = unicodedata.normalize(
+            content = unicodedata.normalize("NFKC", content)

             return content

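TextFileProcessor.can_handle treats a file as binary if its first 8KB (SAMPLE_SIZE) contains a NUL byte or if more than 30% of the sampled bytes are non-printable (excluding tab, newline, and carriage return), and only then tries decoding with the ENCODINGS list. A standalone sketch of just the byte-level heuristic, not part of the diff:

```python
def looks_like_text(file_path: str, sample_size: int = 8192) -> bool:
    with open(file_path, "rb") as f:
        sample = f.read(sample_size)
    if not sample:
        return True  # an empty file is trivially decodable
    if b"\x00" in sample:
        return False  # NUL bytes almost always indicate a binary file
    # Count control characters other than tab, newline, and carriage return.
    non_printable = sum(1 for byte in sample if byte < 32 and byte not in (9, 10, 13))
    return non_printable / len(sample) <= 0.3
```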