jarvis-ai-assistant 0.1.193__py3-none-any.whl → 0.1.194__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic. Click here for more details.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +34 -41
- jarvis/jarvis_agent/builtin_input_handler.py +26 -4
- jarvis/jarvis_agent/jarvis.py +38 -22
- jarvis/jarvis_agent/main.py +20 -12
- jarvis/jarvis_agent/output_handler.py +7 -7
- jarvis/jarvis_agent/shell_input_handler.py +14 -11
- jarvis/jarvis_code_agent/code_agent.py +91 -85
- jarvis/jarvis_code_agent/lint.py +92 -105
- jarvis/jarvis_code_analysis/checklists/__init__.py +1 -1
- jarvis/jarvis_code_analysis/checklists/c_cpp.py +1 -1
- jarvis/jarvis_code_analysis/checklists/csharp.py +1 -1
- jarvis/jarvis_code_analysis/checklists/data_format.py +1 -1
- jarvis/jarvis_code_analysis/checklists/devops.py +1 -1
- jarvis/jarvis_code_analysis/checklists/docs.py +1 -1
- jarvis/jarvis_code_analysis/checklists/go.py +1 -1
- jarvis/jarvis_code_analysis/checklists/infrastructure.py +1 -1
- jarvis/jarvis_code_analysis/checklists/java.py +1 -1
- jarvis/jarvis_code_analysis/checklists/javascript.py +1 -1
- jarvis/jarvis_code_analysis/checklists/kotlin.py +1 -1
- jarvis/jarvis_code_analysis/checklists/loader.py +51 -35
- jarvis/jarvis_code_analysis/checklists/php.py +1 -1
- jarvis/jarvis_code_analysis/checklists/python.py +1 -1
- jarvis/jarvis_code_analysis/checklists/ruby.py +1 -1
- jarvis/jarvis_code_analysis/checklists/rust.py +1 -1
- jarvis/jarvis_code_analysis/checklists/shell.py +1 -1
- jarvis/jarvis_code_analysis/checklists/sql.py +1 -1
- jarvis/jarvis_code_analysis/checklists/swift.py +1 -1
- jarvis/jarvis_code_analysis/checklists/web.py +1 -1
- jarvis/jarvis_code_analysis/code_review.py +293 -192
- jarvis/jarvis_dev/main.py +73 -56
- jarvis/jarvis_git_details/main.py +29 -33
- jarvis/jarvis_git_squash/main.py +13 -11
- jarvis/jarvis_git_utils/git_commiter.py +12 -3
- jarvis/jarvis_mcp/__init__.py +8 -10
- jarvis/jarvis_mcp/sse_mcp_client.py +182 -205
- jarvis/jarvis_mcp/stdio_mcp_client.py +93 -120
- jarvis/jarvis_mcp/streamable_mcp_client.py +117 -142
- jarvis/jarvis_methodology/main.py +75 -41
- jarvis/jarvis_multi_agent/__init__.py +24 -16
- jarvis/jarvis_multi_agent/main.py +10 -4
- jarvis/jarvis_platform/__init__.py +1 -1
- jarvis/jarvis_platform/base.py +49 -21
- jarvis/jarvis_platform/human.py +5 -3
- jarvis/jarvis_platform/kimi.py +96 -72
- jarvis/jarvis_platform/openai.py +23 -28
- jarvis/jarvis_platform/registry.py +50 -33
- jarvis/jarvis_platform/tongyi.py +16 -10
- jarvis/jarvis_platform/yuanbao.py +197 -144
- jarvis/jarvis_platform_manager/main.py +4 -2
- jarvis/jarvis_smart_shell/main.py +35 -29
- jarvis/jarvis_tools/ask_user.py +8 -16
- jarvis/jarvis_tools/base.py +3 -2
- jarvis/jarvis_tools/chdir.py +7 -19
- jarvis/jarvis_tools/cli/main.py +14 -10
- jarvis/jarvis_tools/code_plan.py +10 -31
- jarvis/jarvis_tools/create_code_agent.py +10 -13
- jarvis/jarvis_tools/create_sub_agent.py +10 -22
- jarvis/jarvis_tools/edit_file.py +98 -76
- jarvis/jarvis_tools/execute_script.py +46 -46
- jarvis/jarvis_tools/file_analyzer.py +22 -34
- jarvis/jarvis_tools/file_operation.py +69 -62
- jarvis/jarvis_tools/generate_new_tool.py +0 -2
- jarvis/jarvis_tools/methodology.py +19 -23
- jarvis/jarvis_tools/read_code.py +35 -35
- jarvis/jarvis_tools/read_webpage.py +7 -16
- jarvis/jarvis_tools/registry.py +63 -30
- jarvis/jarvis_tools/rewrite_file.py +26 -29
- jarvis/jarvis_tools/search_web.py +5 -8
- jarvis/jarvis_tools/virtual_tty.py +133 -122
- jarvis/jarvis_utils/__init__.py +0 -1
- jarvis/jarvis_utils/builtin_replace_map.py +9 -9
- jarvis/jarvis_utils/config.py +59 -32
- jarvis/jarvis_utils/embedding.py +17 -14
- jarvis/jarvis_utils/file_processors.py +16 -9
- jarvis/jarvis_utils/git_utils.py +140 -99
- jarvis/jarvis_utils/globals.py +1 -1
- jarvis/jarvis_utils/input.py +84 -52
- jarvis/jarvis_utils/methodology.py +28 -21
- jarvis/jarvis_utils/output.py +159 -78
- jarvis/jarvis_utils/tag.py +2 -1
- jarvis/jarvis_utils/utils.py +85 -51
- {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.194.dist-info}/METADATA +314 -230
- jarvis_ai_assistant-0.1.194.dist-info/RECORD +97 -0
- jarvis/jarvis_agent/file_input_handler.py +0 -112
- jarvis/jarvis_event/__init__.py +0 -0
- jarvis_ai_assistant-0.1.193.dist-info/RECORD +0 -99
- {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.194.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.194.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.194.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.194.dist-info}/top_level.txt +0 -0
jarvis/jarvis_utils/config.py
CHANGED
|
@@ -21,6 +21,7 @@ def set_global_env_data(env_data: Dict[str, Any]) -> None:
|
|
|
21
21
|
global GLOBAL_CONFIG_DATA
|
|
22
22
|
GLOBAL_CONFIG_DATA = env_data
|
|
23
23
|
|
|
24
|
+
|
|
24
25
|
def set_config(key: str, value: Any) -> None:
|
|
25
26
|
"""设置配置"""
|
|
26
27
|
GLOBAL_CONFIG_DATA[key] = value
|
|
@@ -32,48 +33,53 @@ def set_config(key: str, value: Any) -> None:
|
|
|
32
33
|
所有配置都从环境变量中读取,带有回退默认值。
|
|
33
34
|
"""
|
|
34
35
|
|
|
36
|
+
|
|
35
37
|
def get_git_commit_prompt() -> str:
|
|
36
38
|
"""
|
|
37
39
|
获取Git提交提示模板
|
|
38
|
-
|
|
40
|
+
|
|
39
41
|
返回:
|
|
40
42
|
str: Git提交信息生成提示模板,如果未配置则返回空字符串
|
|
41
43
|
"""
|
|
42
44
|
return GLOBAL_CONFIG_DATA.get("JARVIS_GIT_COMMIT_PROMPT", "")
|
|
43
45
|
|
|
46
|
+
|
|
44
47
|
# 输出窗口预留大小
|
|
45
48
|
INPUT_WINDOW_REVERSE_SIZE = 2048
|
|
46
49
|
|
|
50
|
+
|
|
47
51
|
@lru_cache(maxsize=None)
|
|
48
52
|
def get_replace_map() -> dict:
|
|
49
53
|
"""
|
|
50
54
|
获取替换映射表。
|
|
51
|
-
|
|
55
|
+
|
|
52
56
|
优先使用GLOBAL_CONFIG_DATA['JARVIS_REPLACE_MAP']的配置,
|
|
53
57
|
如果没有则从数据目录下的replace_map.yaml文件中读取替换映射表,
|
|
54
58
|
如果文件不存在则返回内置替换映射表。
|
|
55
|
-
|
|
59
|
+
|
|
56
60
|
返回:
|
|
57
61
|
dict: 合并后的替换映射表字典(内置+文件中的映射表)
|
|
58
62
|
"""
|
|
59
|
-
if
|
|
60
|
-
return {**BUILTIN_REPLACE_MAP, **GLOBAL_CONFIG_DATA[
|
|
61
|
-
|
|
62
|
-
replace_map_path = os.path.join(get_data_dir(),
|
|
63
|
+
if "JARVIS_REPLACE_MAP" in GLOBAL_CONFIG_DATA:
|
|
64
|
+
return {**BUILTIN_REPLACE_MAP, **GLOBAL_CONFIG_DATA["JARVIS_REPLACE_MAP"]}
|
|
65
|
+
|
|
66
|
+
replace_map_path = os.path.join(get_data_dir(), "replace_map.yaml")
|
|
63
67
|
if not os.path.exists(replace_map_path):
|
|
64
68
|
return BUILTIN_REPLACE_MAP.copy()
|
|
65
|
-
|
|
69
|
+
|
|
66
70
|
from jarvis.jarvis_utils.output import PrettyOutput, OutputType
|
|
71
|
+
|
|
67
72
|
PrettyOutput.print(
|
|
68
73
|
"警告:使用replace_map.yaml进行配置的方式已被弃用,将在未来版本中移除。"
|
|
69
74
|
"请迁移到使用GLOBAL_CONFIG_DATA中的JARVIS_REPLACE_MAP配置。",
|
|
70
|
-
output_type=OutputType.WARNING
|
|
75
|
+
output_type=OutputType.WARNING,
|
|
71
76
|
)
|
|
72
|
-
|
|
73
|
-
with open(replace_map_path,
|
|
77
|
+
|
|
78
|
+
with open(replace_map_path, "r", encoding="utf-8", errors="ignore") as file:
|
|
74
79
|
file_map = yaml.safe_load(file) or {}
|
|
75
80
|
return {**BUILTIN_REPLACE_MAP, **file_map}
|
|
76
81
|
|
|
82
|
+
|
|
77
83
|
def get_max_token_count() -> int:
|
|
78
84
|
"""
|
|
79
85
|
获取模型允许的最大token数量。
|
|
@@ -81,7 +87,8 @@ def get_max_token_count() -> int:
|
|
|
81
87
|
返回:
|
|
82
88
|
int: 模型能处理的最大token数量。
|
|
83
89
|
"""
|
|
84
|
-
return int(GLOBAL_CONFIG_DATA.get(
|
|
90
|
+
return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_TOKEN_COUNT", "960000"))
|
|
91
|
+
|
|
85
92
|
|
|
86
93
|
def get_max_input_token_count() -> int:
|
|
87
94
|
"""
|
|
@@ -90,7 +97,7 @@ def get_max_input_token_count() -> int:
|
|
|
90
97
|
返回:
|
|
91
98
|
int: 模型能处理的最大输入token数量。
|
|
92
99
|
"""
|
|
93
|
-
return int(GLOBAL_CONFIG_DATA.get(
|
|
100
|
+
return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_INPUT_TOKEN_COUNT", "32000"))
|
|
94
101
|
|
|
95
102
|
|
|
96
103
|
def is_auto_complete() -> bool:
|
|
@@ -100,7 +107,7 @@ def is_auto_complete() -> bool:
|
|
|
100
107
|
返回:
|
|
101
108
|
bool: 如果启用了自动补全则返回True,默认为False
|
|
102
109
|
"""
|
|
103
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
110
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_AUTO_COMPLETE", False) == True
|
|
104
111
|
|
|
105
112
|
|
|
106
113
|
def get_shell_name() -> str:
|
|
@@ -110,8 +117,10 @@ def get_shell_name() -> str:
|
|
|
110
117
|
返回:
|
|
111
118
|
str: Shell名称(例如bash, zsh),默认为bash
|
|
112
119
|
"""
|
|
113
|
-
shell_path = GLOBAL_CONFIG_DATA.get(
|
|
120
|
+
shell_path = GLOBAL_CONFIG_DATA.get("SHELL", "/bin/bash")
|
|
114
121
|
return os.path.basename(shell_path)
|
|
122
|
+
|
|
123
|
+
|
|
115
124
|
def get_normal_platform_name() -> str:
|
|
116
125
|
"""
|
|
117
126
|
获取正常操作的平台名称。
|
|
@@ -119,7 +128,9 @@ def get_normal_platform_name() -> str:
|
|
|
119
128
|
返回:
|
|
120
129
|
str: 平台名称,默认为'yuanbao'
|
|
121
130
|
"""
|
|
122
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
131
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_PLATFORM", "yuanbao")
|
|
132
|
+
|
|
133
|
+
|
|
123
134
|
def get_normal_model_name() -> str:
|
|
124
135
|
"""
|
|
125
136
|
获取正常操作的模型名称。
|
|
@@ -127,7 +138,7 @@ def get_normal_model_name() -> str:
|
|
|
127
138
|
返回:
|
|
128
139
|
str: 模型名称,默认为'deep_seek'
|
|
129
140
|
"""
|
|
130
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
141
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_MODEL", "deep_seek_v3")
|
|
131
142
|
|
|
132
143
|
|
|
133
144
|
def get_thinking_platform_name() -> str:
|
|
@@ -137,7 +148,11 @@ def get_thinking_platform_name() -> str:
|
|
|
137
148
|
返回:
|
|
138
149
|
str: 平台名称,默认为'yuanbao'
|
|
139
150
|
"""
|
|
140
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
151
|
+
return GLOBAL_CONFIG_DATA.get(
|
|
152
|
+
"JARVIS_THINKING_PLATFORM", GLOBAL_CONFIG_DATA.get("JARVIS_PLATFORM", "yuanbao")
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
|
|
141
156
|
def get_thinking_model_name() -> str:
|
|
142
157
|
"""
|
|
143
158
|
获取思考操作的模型名称。
|
|
@@ -145,7 +160,10 @@ def get_thinking_model_name() -> str:
|
|
|
145
160
|
返回:
|
|
146
161
|
str: 模型名称,默认为'deep_seek'
|
|
147
162
|
"""
|
|
148
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
163
|
+
return GLOBAL_CONFIG_DATA.get(
|
|
164
|
+
"JARVIS_THINKING_MODEL", GLOBAL_CONFIG_DATA.get("JARVIS_MODEL", "deep_seek")
|
|
165
|
+
)
|
|
166
|
+
|
|
149
167
|
|
|
150
168
|
def is_execute_tool_confirm() -> bool:
|
|
151
169
|
"""
|
|
@@ -154,7 +172,9 @@ def is_execute_tool_confirm() -> bool:
|
|
|
154
172
|
返回:
|
|
155
173
|
bool: 如果需要确认则返回True,默认为False
|
|
156
174
|
"""
|
|
157
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
175
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_EXECUTE_TOOL_CONFIRM", False) == True
|
|
176
|
+
|
|
177
|
+
|
|
158
178
|
def is_confirm_before_apply_patch() -> bool:
|
|
159
179
|
"""
|
|
160
180
|
检查应用补丁前是否需要确认。
|
|
@@ -162,7 +182,8 @@ def is_confirm_before_apply_patch() -> bool:
|
|
|
162
182
|
返回:
|
|
163
183
|
bool: 如果需要确认则返回True,默认为False
|
|
164
184
|
"""
|
|
165
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
185
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_CONFIRM_BEFORE_APPLY_PATCH", True) == True
|
|
186
|
+
|
|
166
187
|
|
|
167
188
|
def get_max_tool_call_count() -> int:
|
|
168
189
|
"""
|
|
@@ -171,22 +192,23 @@ def get_max_tool_call_count() -> int:
|
|
|
171
192
|
返回:
|
|
172
193
|
int: 最大连续工具调用次数,默认为20
|
|
173
194
|
"""
|
|
174
|
-
return int(GLOBAL_CONFIG_DATA.get(
|
|
195
|
+
return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_TOOL_CALL_COUNT", "20"))
|
|
175
196
|
|
|
176
197
|
|
|
177
198
|
def get_data_dir() -> str:
|
|
178
199
|
"""
|
|
179
200
|
获取Jarvis数据存储目录路径。
|
|
180
|
-
|
|
201
|
+
|
|
181
202
|
返回:
|
|
182
203
|
str: 数据目录路径,优先从JARVIS_DATA_PATH环境变量获取,
|
|
183
204
|
如果未设置或为空,则使用~/.jarvis作为默认值
|
|
184
205
|
"""
|
|
185
|
-
data_path = GLOBAL_CONFIG_DATA.get(
|
|
206
|
+
data_path = GLOBAL_CONFIG_DATA.get("JARVIS_DATA_PATH", "").strip()
|
|
186
207
|
if not data_path:
|
|
187
|
-
return os.path.expanduser(
|
|
208
|
+
return os.path.expanduser("~/.jarvis")
|
|
188
209
|
return data_path
|
|
189
210
|
|
|
211
|
+
|
|
190
212
|
def get_auto_update() -> bool:
|
|
191
213
|
"""
|
|
192
214
|
获取是否自动更新git仓库。
|
|
@@ -194,7 +216,8 @@ def get_auto_update() -> bool:
|
|
|
194
216
|
返回:
|
|
195
217
|
bool: 如果需要自动更新则返回True,默认为True
|
|
196
218
|
"""
|
|
197
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
219
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_AUTO_UPDATE", True) == True
|
|
220
|
+
|
|
198
221
|
|
|
199
222
|
def get_max_big_content_size() -> int:
|
|
200
223
|
"""
|
|
@@ -203,7 +226,8 @@ def get_max_big_content_size() -> int:
|
|
|
203
226
|
返回:
|
|
204
227
|
int: 最大大内容大小
|
|
205
228
|
"""
|
|
206
|
-
return int(GLOBAL_CONFIG_DATA.get(
|
|
229
|
+
return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_BIG_CONTENT_SIZE", "160000"))
|
|
230
|
+
|
|
207
231
|
|
|
208
232
|
def get_pretty_output() -> bool:
|
|
209
233
|
"""
|
|
@@ -212,7 +236,8 @@ def get_pretty_output() -> bool:
|
|
|
212
236
|
返回:
|
|
213
237
|
bool: 如果启用PrettyOutput则返回True,默认为True
|
|
214
238
|
"""
|
|
215
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
239
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_PRETTY_OUTPUT", False) == True
|
|
240
|
+
|
|
216
241
|
|
|
217
242
|
def is_use_methodology() -> bool:
|
|
218
243
|
"""
|
|
@@ -221,7 +246,8 @@ def is_use_methodology() -> bool:
|
|
|
221
246
|
返回:
|
|
222
247
|
bool: 如果启用方法论则返回True,默认为True
|
|
223
248
|
"""
|
|
224
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
249
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_USE_METHODOLOGY", True) == True
|
|
250
|
+
|
|
225
251
|
|
|
226
252
|
def is_use_analysis() -> bool:
|
|
227
253
|
"""
|
|
@@ -230,7 +256,8 @@ def is_use_analysis() -> bool:
|
|
|
230
256
|
返回:
|
|
231
257
|
bool: 如果启用任务分析则返回True,默认为True
|
|
232
258
|
"""
|
|
233
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
259
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_USE_ANALYSIS", True) == True
|
|
260
|
+
|
|
234
261
|
|
|
235
262
|
def is_print_prompt() -> bool:
|
|
236
263
|
"""
|
|
@@ -239,7 +266,7 @@ def is_print_prompt() -> bool:
|
|
|
239
266
|
返回:
|
|
240
267
|
bool: 如果打印提示则返回True,默认为True
|
|
241
268
|
"""
|
|
242
|
-
return GLOBAL_CONFIG_DATA.get(
|
|
269
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_PRINT_PROMPT", False) == True
|
|
243
270
|
|
|
244
271
|
|
|
245
272
|
def get_mcp_config() -> List[Dict[str, Any]]:
|
|
@@ -249,4 +276,4 @@ def get_mcp_config() -> List[Dict[str, Any]]:
|
|
|
249
276
|
返回:
|
|
250
277
|
List[Dict[str, Any]]: MCP配置项列表,如果未配置则返回空列表
|
|
251
278
|
"""
|
|
252
|
-
return GLOBAL_CONFIG_DATA.get("JARVIS_MCP", [])
|
|
279
|
+
return GLOBAL_CONFIG_DATA.get("JARVIS_MCP", [])
|
jarvis/jarvis_utils/embedding.py
CHANGED
|
@@ -11,6 +11,7 @@ from jarvis.jarvis_utils.output import OutputType, PrettyOutput
|
|
|
11
11
|
# 全局缓存,避免重复加载模型
|
|
12
12
|
_global_tokenizers = {}
|
|
13
13
|
|
|
14
|
+
|
|
14
15
|
def get_context_token_count(text: str) -> int:
|
|
15
16
|
"""使用分词器获取文本的token数量。
|
|
16
17
|
|
|
@@ -26,7 +27,7 @@ def get_context_token_count(text: str) -> int:
|
|
|
26
27
|
total_tokens = 0
|
|
27
28
|
chunk_size = 100 # 每次处理100个字符,避免超过模型最大长度(考虑到中文字符可能被编码成多个token)
|
|
28
29
|
for i in range(0, len(text), chunk_size):
|
|
29
|
-
chunk = text[i:i + chunk_size]
|
|
30
|
+
chunk = text[i : i + chunk_size]
|
|
30
31
|
tokens = tokenizer.encode(chunk) # type: ignore
|
|
31
32
|
total_tokens += len(tokens)
|
|
32
33
|
return total_tokens
|
|
@@ -34,7 +35,10 @@ def get_context_token_count(text: str) -> int:
|
|
|
34
35
|
PrettyOutput.print(f"计算token失败: {str(e)}", OutputType.WARNING)
|
|
35
36
|
return len(text) // 4 # 每个token大约4个字符的粗略估计
|
|
36
37
|
|
|
37
|
-
|
|
38
|
+
|
|
39
|
+
def split_text_into_chunks(
|
|
40
|
+
text: str, max_length: int = 512, min_length: int = 50
|
|
41
|
+
) -> List[str]:
|
|
38
42
|
"""将文本分割成块,基于token数量进行切割。
|
|
39
43
|
|
|
40
44
|
参数:
|
|
@@ -52,15 +56,18 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
|
|
|
52
56
|
chunks = []
|
|
53
57
|
current_chunk = ""
|
|
54
58
|
current_tokens = 0
|
|
55
|
-
|
|
59
|
+
|
|
56
60
|
# 按较大的块处理文本,避免破坏token边界
|
|
57
61
|
chunk_size = 50 # 每次处理50个字符
|
|
58
62
|
for i in range(0, len(text), chunk_size):
|
|
59
|
-
chunk = text[i:i + chunk_size]
|
|
63
|
+
chunk = text[i : i + chunk_size]
|
|
60
64
|
chunk_tokens = get_context_token_count(chunk)
|
|
61
|
-
|
|
65
|
+
|
|
62
66
|
# 如果当前块加上新块会超过最大长度,且当前块已经达到最小长度,则保存当前块
|
|
63
|
-
if
|
|
67
|
+
if (
|
|
68
|
+
current_tokens + chunk_tokens > max_length
|
|
69
|
+
and current_tokens >= min_length
|
|
70
|
+
):
|
|
64
71
|
chunks.append(current_chunk)
|
|
65
72
|
current_chunk = chunk
|
|
66
73
|
current_tokens = chunk_tokens
|
|
@@ -77,7 +84,7 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
|
|
|
77
84
|
except Exception as e:
|
|
78
85
|
PrettyOutput.print(f"文本分割失败: {str(e)}", OutputType.WARNING)
|
|
79
86
|
# 发生错误时回退到简单的字符分割
|
|
80
|
-
return [text[i:i + max_length] for i in range(0, len(text), max_length)]
|
|
87
|
+
return [text[i : i + max_length] for i in range(0, len(text), max_length)]
|
|
81
88
|
|
|
82
89
|
|
|
83
90
|
@functools.lru_cache(maxsize=1)
|
|
@@ -97,18 +104,14 @@ def load_tokenizer() -> AutoTokenizer:
|
|
|
97
104
|
|
|
98
105
|
try:
|
|
99
106
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
100
|
-
model_name,
|
|
101
|
-
cache_dir=cache_dir,
|
|
102
|
-
local_files_only=True
|
|
107
|
+
model_name, cache_dir=cache_dir, local_files_only=True
|
|
103
108
|
)
|
|
104
109
|
except Exception:
|
|
105
110
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
106
|
-
model_name,
|
|
107
|
-
cache_dir=cache_dir,
|
|
108
|
-
local_files_only=False
|
|
111
|
+
model_name, cache_dir=cache_dir, local_files_only=False
|
|
109
112
|
)
|
|
110
113
|
|
|
111
114
|
# 保存到全局缓存
|
|
112
115
|
_global_tokenizers[model_name] = tokenizer
|
|
113
116
|
|
|
114
|
-
return tokenizer
|
|
117
|
+
return tokenizer # type: ignore
|
|
@@ -4,6 +4,7 @@ import unicodedata
|
|
|
4
4
|
|
|
5
5
|
class FileProcessor:
|
|
6
6
|
"""Base class for file processor"""
|
|
7
|
+
|
|
7
8
|
@staticmethod
|
|
8
9
|
def can_handle(file_path: str) -> bool:
|
|
9
10
|
"""Determine if the file can be processed"""
|
|
@@ -14,9 +15,11 @@ class FileProcessor:
|
|
|
14
15
|
"""Extract file text content"""
|
|
15
16
|
raise NotImplementedError
|
|
16
17
|
|
|
18
|
+
|
|
17
19
|
class TextFileProcessor(FileProcessor):
|
|
18
20
|
"""Text file processor"""
|
|
19
|
-
|
|
21
|
+
|
|
22
|
+
ENCODINGS = ["utf-8", "gbk", "gb2312", "latin1"]
|
|
20
23
|
SAMPLE_SIZE = 8192 # Read the first 8KB to detect encoding
|
|
21
24
|
|
|
22
25
|
@staticmethod
|
|
@@ -24,16 +27,20 @@ class TextFileProcessor(FileProcessor):
|
|
|
24
27
|
"""Determine if the file is a text file by trying to decode it"""
|
|
25
28
|
try:
|
|
26
29
|
# Read the first part of the file to detect encoding
|
|
27
|
-
with open(file_path,
|
|
30
|
+
with open(file_path, "rb") as f:
|
|
28
31
|
sample = f.read(TextFileProcessor.SAMPLE_SIZE)
|
|
29
32
|
|
|
30
33
|
# Check if it contains null bytes (usually represents a binary file)
|
|
31
|
-
if b
|
|
34
|
+
if b"\x00" in sample:
|
|
32
35
|
return False
|
|
33
36
|
|
|
34
37
|
# Check if it contains too many non-printable characters (usually represents a binary file)
|
|
35
|
-
non_printable = sum(
|
|
36
|
-
|
|
38
|
+
non_printable = sum(
|
|
39
|
+
1 for byte in sample if byte < 32 and byte not in (9, 10, 13)
|
|
40
|
+
) # tab, newline, carriage return
|
|
41
|
+
if (
|
|
42
|
+
non_printable / len(sample) > 0.3
|
|
43
|
+
): # If non-printable characters exceed 30%, it is considered a binary file
|
|
37
44
|
return False
|
|
38
45
|
|
|
39
46
|
# Try to decode with different encodings
|
|
@@ -55,7 +62,7 @@ class TextFileProcessor(FileProcessor):
|
|
|
55
62
|
detected_encoding = None
|
|
56
63
|
try:
|
|
57
64
|
# First try to detect encoding
|
|
58
|
-
with open(file_path,
|
|
65
|
+
with open(file_path, "rb") as f:
|
|
59
66
|
raw_data = f.read()
|
|
60
67
|
|
|
61
68
|
# Try different encodings
|
|
@@ -68,14 +75,14 @@ class TextFileProcessor(FileProcessor):
|
|
|
68
75
|
continue
|
|
69
76
|
|
|
70
77
|
if not detected_encoding:
|
|
71
|
-
raise UnicodeDecodeError(f"Failed to decode file with supported encodings: {file_path}")
|
|
78
|
+
raise UnicodeDecodeError(f"Failed to decode file with supported encodings: {file_path}") # type: ignore
|
|
72
79
|
|
|
73
80
|
# Use the detected encoding to read the file
|
|
74
|
-
with open(file_path,
|
|
81
|
+
with open(file_path, "r", encoding=detected_encoding, errors="ignore") as f:
|
|
75
82
|
content = f.read()
|
|
76
83
|
|
|
77
84
|
# Normalize Unicode characters
|
|
78
|
-
content = unicodedata.normalize(
|
|
85
|
+
content = unicodedata.normalize("NFKC", content)
|
|
79
86
|
|
|
80
87
|
return content
|
|
81
88
|
|