jarvis-ai-assistant 0.1.193__py3-none-any.whl → 0.1.195__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92) hide show
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +45 -41
  3. jarvis/jarvis_agent/builtin_input_handler.py +26 -4
  4. jarvis/jarvis_agent/jarvis.py +30 -19
  5. jarvis/jarvis_agent/main.py +20 -12
  6. jarvis/jarvis_agent/output_handler.py +7 -7
  7. jarvis/jarvis_agent/shell_input_handler.py +14 -11
  8. jarvis/jarvis_code_agent/code_agent.py +81 -79
  9. jarvis/jarvis_code_agent/lint.py +92 -105
  10. jarvis/jarvis_code_analysis/checklists/__init__.py +1 -1
  11. jarvis/jarvis_code_analysis/checklists/c_cpp.py +1 -1
  12. jarvis/jarvis_code_analysis/checklists/csharp.py +1 -1
  13. jarvis/jarvis_code_analysis/checklists/data_format.py +1 -1
  14. jarvis/jarvis_code_analysis/checklists/devops.py +1 -1
  15. jarvis/jarvis_code_analysis/checklists/docs.py +1 -1
  16. jarvis/jarvis_code_analysis/checklists/go.py +1 -1
  17. jarvis/jarvis_code_analysis/checklists/infrastructure.py +1 -1
  18. jarvis/jarvis_code_analysis/checklists/java.py +1 -1
  19. jarvis/jarvis_code_analysis/checklists/javascript.py +1 -1
  20. jarvis/jarvis_code_analysis/checklists/kotlin.py +1 -1
  21. jarvis/jarvis_code_analysis/checklists/loader.py +31 -29
  22. jarvis/jarvis_code_analysis/checklists/php.py +1 -1
  23. jarvis/jarvis_code_analysis/checklists/python.py +1 -1
  24. jarvis/jarvis_code_analysis/checklists/ruby.py +1 -1
  25. jarvis/jarvis_code_analysis/checklists/rust.py +1 -1
  26. jarvis/jarvis_code_analysis/checklists/shell.py +1 -1
  27. jarvis/jarvis_code_analysis/checklists/sql.py +1 -1
  28. jarvis/jarvis_code_analysis/checklists/swift.py +1 -1
  29. jarvis/jarvis_code_analysis/checklists/web.py +1 -1
  30. jarvis/jarvis_code_analysis/code_review.py +292 -190
  31. jarvis/jarvis_dev/main.py +73 -56
  32. jarvis/jarvis_git_details/main.py +29 -33
  33. jarvis/jarvis_git_squash/main.py +13 -11
  34. jarvis/jarvis_git_utils/git_commiter.py +15 -5
  35. jarvis/jarvis_mcp/__init__.py +8 -10
  36. jarvis/jarvis_mcp/sse_mcp_client.py +182 -205
  37. jarvis/jarvis_mcp/stdio_mcp_client.py +93 -120
  38. jarvis/jarvis_mcp/streamable_mcp_client.py +117 -142
  39. jarvis/jarvis_methodology/main.py +71 -39
  40. jarvis/jarvis_multi_agent/__init__.py +24 -16
  41. jarvis/jarvis_multi_agent/main.py +10 -4
  42. jarvis/jarvis_platform/__init__.py +1 -1
  43. jarvis/jarvis_platform/base.py +44 -18
  44. jarvis/jarvis_platform/human.py +15 -3
  45. jarvis/jarvis_platform/kimi.py +117 -81
  46. jarvis/jarvis_platform/openai.py +23 -28
  47. jarvis/jarvis_platform/registry.py +43 -29
  48. jarvis/jarvis_platform/tongyi.py +16 -10
  49. jarvis/jarvis_platform/yuanbao.py +197 -144
  50. jarvis/jarvis_platform_manager/main.py +4 -2
  51. jarvis/jarvis_smart_shell/main.py +35 -30
  52. jarvis/jarvis_tools/ask_user.py +8 -16
  53. jarvis/jarvis_tools/base.py +3 -2
  54. jarvis/jarvis_tools/chdir.py +7 -19
  55. jarvis/jarvis_tools/cli/main.py +14 -10
  56. jarvis/jarvis_tools/code_plan.py +10 -31
  57. jarvis/jarvis_tools/create_code_agent.py +6 -11
  58. jarvis/jarvis_tools/create_sub_agent.py +10 -22
  59. jarvis/jarvis_tools/edit_file.py +98 -76
  60. jarvis/jarvis_tools/execute_script.py +46 -46
  61. jarvis/jarvis_tools/file_analyzer.py +22 -34
  62. jarvis/jarvis_tools/file_operation.py +69 -62
  63. jarvis/jarvis_tools/generate_new_tool.py +0 -2
  64. jarvis/jarvis_tools/methodology.py +19 -23
  65. jarvis/jarvis_tools/read_code.py +35 -35
  66. jarvis/jarvis_tools/read_webpage.py +7 -16
  67. jarvis/jarvis_tools/registry.py +63 -30
  68. jarvis/jarvis_tools/rewrite_file.py +26 -29
  69. jarvis/jarvis_tools/search_web.py +5 -8
  70. jarvis/jarvis_tools/virtual_tty.py +133 -122
  71. jarvis/jarvis_utils/__init__.py +0 -1
  72. jarvis/jarvis_utils/builtin_replace_map.py +9 -9
  73. jarvis/jarvis_utils/config.py +60 -37
  74. jarvis/jarvis_utils/embedding.py +24 -19
  75. jarvis/jarvis_utils/file_processors.py +16 -9
  76. jarvis/jarvis_utils/git_utils.py +157 -107
  77. jarvis/jarvis_utils/globals.py +1 -1
  78. jarvis/jarvis_utils/input.py +85 -52
  79. jarvis/jarvis_utils/jarvis_history.py +43 -0
  80. jarvis/jarvis_utils/methodology.py +31 -24
  81. jarvis/jarvis_utils/output.py +164 -80
  82. jarvis/jarvis_utils/tag.py +2 -1
  83. jarvis/jarvis_utils/utils.py +84 -52
  84. {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.195.dist-info}/METADATA +362 -230
  85. jarvis_ai_assistant-0.1.195.dist-info/RECORD +98 -0
  86. jarvis/jarvis_agent/file_input_handler.py +0 -112
  87. jarvis/jarvis_event/__init__.py +0 -0
  88. jarvis_ai_assistant-0.1.193.dist-info/RECORD +0 -99
  89. {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.195.dist-info}/WHEEL +0 -0
  90. {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.195.dist-info}/entry_points.txt +0 -0
  91. {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.195.dist-info}/licenses/LICENSE +0 -0
  92. {jarvis_ai_assistant-0.1.193.dist-info → jarvis_ai_assistant-0.1.195.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,12 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  import os
3
3
  from functools import lru_cache
4
-
5
4
  from typing import Any, Dict, List
6
5
 
7
-
8
6
  import yaml
9
7
 
10
-
11
8
  from jarvis.jarvis_utils.builtin_replace_map import BUILTIN_REPLACE_MAP
12
9
 
13
-
14
10
  # 全局环境变量存储
15
11
 
16
12
  GLOBAL_CONFIG_DATA: Dict[str, Any] = {}
@@ -21,6 +17,7 @@ def set_global_env_data(env_data: Dict[str, Any]) -> None:
21
17
  global GLOBAL_CONFIG_DATA
22
18
  GLOBAL_CONFIG_DATA = env_data
23
19
 
20
+
24
21
  def set_config(key: str, value: Any) -> None:
25
22
  """设置配置"""
26
23
  GLOBAL_CONFIG_DATA[key] = value
@@ -32,48 +29,53 @@ def set_config(key: str, value: Any) -> None:
32
29
  所有配置都从环境变量中读取,带有回退默认值。
33
30
  """
34
31
 
32
+
35
33
  def get_git_commit_prompt() -> str:
36
34
  """
37
35
  获取Git提交提示模板
38
-
36
+
39
37
  返回:
40
38
  str: Git提交信息生成提示模板,如果未配置则返回空字符串
41
39
  """
42
40
  return GLOBAL_CONFIG_DATA.get("JARVIS_GIT_COMMIT_PROMPT", "")
43
41
 
42
+
44
43
  # 输出窗口预留大小
45
44
  INPUT_WINDOW_REVERSE_SIZE = 2048
46
45
 
46
+
47
47
  @lru_cache(maxsize=None)
48
48
  def get_replace_map() -> dict:
49
49
  """
50
50
  获取替换映射表。
51
-
51
+
52
52
  优先使用GLOBAL_CONFIG_DATA['JARVIS_REPLACE_MAP']的配置,
53
53
  如果没有则从数据目录下的replace_map.yaml文件中读取替换映射表,
54
54
  如果文件不存在则返回内置替换映射表。
55
-
55
+
56
56
  返回:
57
57
  dict: 合并后的替换映射表字典(内置+文件中的映射表)
58
58
  """
59
- if 'JARVIS_REPLACE_MAP' in GLOBAL_CONFIG_DATA:
60
- return {**BUILTIN_REPLACE_MAP, **GLOBAL_CONFIG_DATA['JARVIS_REPLACE_MAP']}
61
-
62
- replace_map_path = os.path.join(get_data_dir(), 'replace_map.yaml')
59
+ if "JARVIS_REPLACE_MAP" in GLOBAL_CONFIG_DATA:
60
+ return {**BUILTIN_REPLACE_MAP, **GLOBAL_CONFIG_DATA["JARVIS_REPLACE_MAP"]}
61
+
62
+ replace_map_path = os.path.join(get_data_dir(), "replace_map.yaml")
63
63
  if not os.path.exists(replace_map_path):
64
64
  return BUILTIN_REPLACE_MAP.copy()
65
-
66
- from jarvis.jarvis_utils.output import PrettyOutput, OutputType
65
+
66
+ from jarvis.jarvis_utils.output import OutputType, PrettyOutput
67
+
67
68
  PrettyOutput.print(
68
69
  "警告:使用replace_map.yaml进行配置的方式已被弃用,将在未来版本中移除。"
69
70
  "请迁移到使用GLOBAL_CONFIG_DATA中的JARVIS_REPLACE_MAP配置。",
70
- output_type=OutputType.WARNING
71
+ output_type=OutputType.WARNING,
71
72
  )
72
-
73
- with open(replace_map_path, 'r', encoding='utf-8', errors='ignore') as file:
73
+
74
+ with open(replace_map_path, "r", encoding="utf-8", errors="ignore") as file:
74
75
  file_map = yaml.safe_load(file) or {}
75
76
  return {**BUILTIN_REPLACE_MAP, **file_map}
76
77
 
78
+
77
79
  def get_max_token_count() -> int:
78
80
  """
79
81
  获取模型允许的最大token数量。
@@ -81,7 +83,8 @@ def get_max_token_count() -> int:
81
83
  返回:
82
84
  int: 模型能处理的最大token数量。
83
85
  """
84
- return int(GLOBAL_CONFIG_DATA.get('JARVIS_MAX_TOKEN_COUNT', '960000'))
86
+ return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_TOKEN_COUNT", "960000"))
87
+
85
88
 
86
89
  def get_max_input_token_count() -> int:
87
90
  """
@@ -90,7 +93,7 @@ def get_max_input_token_count() -> int:
90
93
  返回:
91
94
  int: 模型能处理的最大输入token数量。
92
95
  """
93
- return int(GLOBAL_CONFIG_DATA.get('JARVIS_MAX_INPUT_TOKEN_COUNT', '32000'))
96
+ return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_INPUT_TOKEN_COUNT", "32000"))
94
97
 
95
98
 
96
99
  def is_auto_complete() -> bool:
@@ -100,7 +103,7 @@ def is_auto_complete() -> bool:
100
103
  返回:
101
104
  bool: 如果启用了自动补全则返回True,默认为False
102
105
  """
103
- return GLOBAL_CONFIG_DATA.get('JARVIS_AUTO_COMPLETE', False) == True
106
+ return GLOBAL_CONFIG_DATA.get("JARVIS_AUTO_COMPLETE", False) == True
104
107
 
105
108
 
106
109
  def get_shell_name() -> str:
@@ -110,8 +113,10 @@ def get_shell_name() -> str:
110
113
  返回:
111
114
  str: Shell名称(例如bash, zsh),默认为bash
112
115
  """
113
- shell_path = GLOBAL_CONFIG_DATA.get('SHELL', '/bin/bash')
116
+ shell_path = GLOBAL_CONFIG_DATA.get("SHELL", "/bin/bash")
114
117
  return os.path.basename(shell_path)
118
+
119
+
115
120
  def get_normal_platform_name() -> str:
116
121
  """
117
122
  获取正常操作的平台名称。
@@ -119,7 +124,9 @@ def get_normal_platform_name() -> str:
119
124
  返回:
120
125
  str: 平台名称,默认为'yuanbao'
121
126
  """
122
- return GLOBAL_CONFIG_DATA.get('JARVIS_PLATFORM', 'yuanbao')
127
+ return GLOBAL_CONFIG_DATA.get("JARVIS_PLATFORM", "yuanbao")
128
+
129
+
123
130
  def get_normal_model_name() -> str:
124
131
  """
125
132
  获取正常操作的模型名称。
@@ -127,7 +134,7 @@ def get_normal_model_name() -> str:
127
134
  返回:
128
135
  str: 模型名称,默认为'deep_seek'
129
136
  """
130
- return GLOBAL_CONFIG_DATA.get('JARVIS_MODEL', 'deep_seek_v3')
137
+ return GLOBAL_CONFIG_DATA.get("JARVIS_MODEL", "deep_seek_v3")
131
138
 
132
139
 
133
140
  def get_thinking_platform_name() -> str:
@@ -137,7 +144,11 @@ def get_thinking_platform_name() -> str:
137
144
  返回:
138
145
  str: 平台名称,默认为'yuanbao'
139
146
  """
140
- return GLOBAL_CONFIG_DATA.get('JARVIS_THINKING_PLATFORM', GLOBAL_CONFIG_DATA.get('JARVIS_PLATFORM', 'yuanbao'))
147
+ return GLOBAL_CONFIG_DATA.get(
148
+ "JARVIS_THINKING_PLATFORM", GLOBAL_CONFIG_DATA.get("JARVIS_PLATFORM", "yuanbao")
149
+ )
150
+
151
+
141
152
  def get_thinking_model_name() -> str:
142
153
  """
143
154
  获取思考操作的模型名称。
@@ -145,7 +156,10 @@ def get_thinking_model_name() -> str:
145
156
  返回:
146
157
  str: 模型名称,默认为'deep_seek'
147
158
  """
148
- return GLOBAL_CONFIG_DATA.get('JARVIS_THINKING_MODEL', GLOBAL_CONFIG_DATA.get('JARVIS_MODEL', 'deep_seek'))
159
+ return GLOBAL_CONFIG_DATA.get(
160
+ "JARVIS_THINKING_MODEL", GLOBAL_CONFIG_DATA.get("JARVIS_MODEL", "deep_seek")
161
+ )
162
+
149
163
 
150
164
  def is_execute_tool_confirm() -> bool:
151
165
  """
@@ -154,7 +168,9 @@ def is_execute_tool_confirm() -> bool:
154
168
  返回:
155
169
  bool: 如果需要确认则返回True,默认为False
156
170
  """
157
- return GLOBAL_CONFIG_DATA.get('JARVIS_EXECUTE_TOOL_CONFIRM', False) == True
171
+ return GLOBAL_CONFIG_DATA.get("JARVIS_EXECUTE_TOOL_CONFIRM", False) == True
172
+
173
+
158
174
  def is_confirm_before_apply_patch() -> bool:
159
175
  """
160
176
  检查应用补丁前是否需要确认。
@@ -162,7 +178,8 @@ def is_confirm_before_apply_patch() -> bool:
162
178
  返回:
163
179
  bool: 如果需要确认则返回True,默认为False
164
180
  """
165
- return GLOBAL_CONFIG_DATA.get('JARVIS_CONFIRM_BEFORE_APPLY_PATCH', True) == True
181
+ return GLOBAL_CONFIG_DATA.get("JARVIS_CONFIRM_BEFORE_APPLY_PATCH", True) == True
182
+
166
183
 
167
184
  def get_max_tool_call_count() -> int:
168
185
  """
@@ -171,22 +188,23 @@ def get_max_tool_call_count() -> int:
171
188
  返回:
172
189
  int: 最大连续工具调用次数,默认为20
173
190
  """
174
- return int(GLOBAL_CONFIG_DATA.get('JARVIS_MAX_TOOL_CALL_COUNT', '20'))
191
+ return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_TOOL_CALL_COUNT", "20"))
175
192
 
176
193
 
177
194
  def get_data_dir() -> str:
178
195
  """
179
196
  获取Jarvis数据存储目录路径。
180
-
197
+
181
198
  返回:
182
199
  str: 数据目录路径,优先从JARVIS_DATA_PATH环境变量获取,
183
200
  如果未设置或为空,则使用~/.jarvis作为默认值
184
201
  """
185
- data_path = GLOBAL_CONFIG_DATA.get('JARVIS_DATA_PATH', '').strip()
202
+ data_path = GLOBAL_CONFIG_DATA.get("JARVIS_DATA_PATH", "").strip()
186
203
  if not data_path:
187
- return os.path.expanduser('~/.jarvis')
204
+ return os.path.expanduser("~/.jarvis")
188
205
  return data_path
189
206
 
207
+
190
208
  def get_auto_update() -> bool:
191
209
  """
192
210
  获取是否自动更新git仓库。
@@ -194,7 +212,8 @@ def get_auto_update() -> bool:
194
212
  返回:
195
213
  bool: 如果需要自动更新则返回True,默认为True
196
214
  """
197
- return GLOBAL_CONFIG_DATA.get('JARVIS_AUTO_UPDATE', True) == True
215
+ return GLOBAL_CONFIG_DATA.get("JARVIS_AUTO_UPDATE", True) == True
216
+
198
217
 
199
218
  def get_max_big_content_size() -> int:
200
219
  """
@@ -203,7 +222,8 @@ def get_max_big_content_size() -> int:
203
222
  返回:
204
223
  int: 最大大内容大小
205
224
  """
206
- return int(GLOBAL_CONFIG_DATA.get('JARVIS_MAX_BIG_CONTENT_SIZE', '160000'))
225
+ return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_BIG_CONTENT_SIZE", "160000"))
226
+
207
227
 
208
228
  def get_pretty_output() -> bool:
209
229
  """
@@ -212,7 +232,8 @@ def get_pretty_output() -> bool:
212
232
  返回:
213
233
  bool: 如果启用PrettyOutput则返回True,默认为True
214
234
  """
215
- return GLOBAL_CONFIG_DATA.get('JARVIS_PRETTY_OUTPUT', False) == True
235
+ return GLOBAL_CONFIG_DATA.get("JARVIS_PRETTY_OUTPUT", False) == True
236
+
216
237
 
217
238
  def is_use_methodology() -> bool:
218
239
  """
@@ -221,7 +242,8 @@ def is_use_methodology() -> bool:
221
242
  返回:
222
243
  bool: 如果启用方法论则返回True,默认为True
223
244
  """
224
- return GLOBAL_CONFIG_DATA.get('JARVIS_USE_METHODOLOGY', True) == True
245
+ return GLOBAL_CONFIG_DATA.get("JARVIS_USE_METHODOLOGY", True) == True
246
+
225
247
 
226
248
  def is_use_analysis() -> bool:
227
249
  """
@@ -230,7 +252,8 @@ def is_use_analysis() -> bool:
230
252
  返回:
231
253
  bool: 如果启用任务分析则返回True,默认为True
232
254
  """
233
- return GLOBAL_CONFIG_DATA.get('JARVIS_USE_ANALYSIS', True) == True
255
+ return GLOBAL_CONFIG_DATA.get("JARVIS_USE_ANALYSIS", True) == True
256
+
234
257
 
235
258
  def is_print_prompt() -> bool:
236
259
  """
@@ -239,7 +262,7 @@ def is_print_prompt() -> bool:
239
262
  返回:
240
263
  bool: 如果打印提示则返回True,默认为True
241
264
  """
242
- return GLOBAL_CONFIG_DATA.get('JARVIS_PRINT_PROMPT', False) == True
265
+ return GLOBAL_CONFIG_DATA.get("JARVIS_PRINT_PROMPT", False) == True
243
266
 
244
267
 
245
268
  def get_mcp_config() -> List[Dict[str, Any]]:
@@ -249,4 +272,4 @@ def get_mcp_config() -> List[Dict[str, Any]]:
249
272
  返回:
250
273
  List[Dict[str, Any]]: MCP配置项列表,如果未配置则返回空列表
251
274
  """
252
- return GLOBAL_CONFIG_DATA.get("JARVIS_MCP", [])
275
+ return GLOBAL_CONFIG_DATA.get("JARVIS_MCP", [])
@@ -1,9 +1,7 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  import functools
3
3
  import os
4
- from typing import List
5
-
6
- from transformers import AutoTokenizer
4
+ from typing import Any, List
7
5
 
8
6
  from jarvis.jarvis_utils.config import get_data_dir
9
7
  from jarvis.jarvis_utils.output import OutputType, PrettyOutput
@@ -11,6 +9,7 @@ from jarvis.jarvis_utils.output import OutputType, PrettyOutput
11
9
  # 全局缓存,避免重复加载模型
12
10
  _global_tokenizers = {}
13
11
 
12
+
14
13
  def get_context_token_count(text: str) -> int:
15
14
  """使用分词器获取文本的token数量。
16
15
 
@@ -21,12 +20,13 @@ def get_context_token_count(text: str) -> int:
21
20
  int: 文本中的token数量
22
21
  """
23
22
  try:
24
- tokenizer = load_tokenizer()
23
+ from transformers import AutoTokenizer # type: ignore
24
+ tokenizer : AutoTokenizer = load_tokenizer()
25
25
  # 分批处理长文本,确保不超过模型最大长度
26
26
  total_tokens = 0
27
27
  chunk_size = 100 # 每次处理100个字符,避免超过模型最大长度(考虑到中文字符可能被编码成多个token)
28
28
  for i in range(0, len(text), chunk_size):
29
- chunk = text[i:i + chunk_size]
29
+ chunk = text[i : i + chunk_size]
30
30
  tokens = tokenizer.encode(chunk) # type: ignore
31
31
  total_tokens += len(tokens)
32
32
  return total_tokens
@@ -34,7 +34,10 @@ def get_context_token_count(text: str) -> int:
34
34
  PrettyOutput.print(f"计算token失败: {str(e)}", OutputType.WARNING)
35
35
  return len(text) // 4 # 每个token大约4个字符的粗略估计
36
36
 
37
- def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 50) -> List[str]:
37
+
38
+ def split_text_into_chunks(
39
+ text: str, max_length: int = 512, min_length: int = 50
40
+ ) -> List[str]:
38
41
  """将文本分割成块,基于token数量进行切割。
39
42
 
40
43
  参数:
@@ -52,15 +55,18 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
52
55
  chunks = []
53
56
  current_chunk = ""
54
57
  current_tokens = 0
55
-
58
+
56
59
  # 按较大的块处理文本,避免破坏token边界
57
60
  chunk_size = 50 # 每次处理50个字符
58
61
  for i in range(0, len(text), chunk_size):
59
- chunk = text[i:i + chunk_size]
62
+ chunk = text[i : i + chunk_size]
60
63
  chunk_tokens = get_context_token_count(chunk)
61
-
64
+
62
65
  # 如果当前块加上新块会超过最大长度,且当前块已经达到最小长度,则保存当前块
63
- if current_tokens + chunk_tokens > max_length and current_tokens >= min_length:
66
+ if (
67
+ current_tokens + chunk_tokens > max_length
68
+ and current_tokens >= min_length
69
+ ):
64
70
  chunks.append(current_chunk)
65
71
  current_chunk = chunk
66
72
  current_tokens = chunk_tokens
@@ -77,17 +83,20 @@ def split_text_into_chunks(text: str, max_length: int = 512, min_length: int = 5
77
83
  except Exception as e:
78
84
  PrettyOutput.print(f"文本分割失败: {str(e)}", OutputType.WARNING)
79
85
  # 发生错误时回退到简单的字符分割
80
- return [text[i:i + max_length] for i in range(0, len(text), max_length)]
86
+ return [text[i : i + max_length] for i in range(0, len(text), max_length)]
81
87
 
82
88
 
83
89
  @functools.lru_cache(maxsize=1)
84
- def load_tokenizer() -> AutoTokenizer:
90
+ def load_tokenizer() -> Any:
85
91
  """
86
92
  加载用于文本处理的分词器,使用缓存避免重复加载。
87
93
 
88
94
  返回:
89
95
  AutoTokenizer: 加载的分词器
90
96
  """
97
+
98
+ from transformers import AutoTokenizer # type: ignore
99
+
91
100
  model_name = "gpt2"
92
101
  cache_dir = os.path.join(get_data_dir(), "huggingface", "hub")
93
102
 
@@ -97,18 +106,14 @@ def load_tokenizer() -> AutoTokenizer:
97
106
 
98
107
  try:
99
108
  tokenizer = AutoTokenizer.from_pretrained(
100
- model_name,
101
- cache_dir=cache_dir,
102
- local_files_only=True
109
+ model_name, cache_dir=cache_dir, local_files_only=True
103
110
  )
104
111
  except Exception:
105
112
  tokenizer = AutoTokenizer.from_pretrained(
106
- model_name,
107
- cache_dir=cache_dir,
108
- local_files_only=False
113
+ model_name, cache_dir=cache_dir, local_files_only=False
109
114
  )
110
115
 
111
116
  # 保存到全局缓存
112
117
  _global_tokenizers[model_name] = tokenizer
113
118
 
114
- return tokenizer # type: ignore
119
+ return tokenizer # type: ignore
@@ -4,6 +4,7 @@ import unicodedata
4
4
 
5
5
  class FileProcessor:
6
6
  """Base class for file processor"""
7
+
7
8
  @staticmethod
8
9
  def can_handle(file_path: str) -> bool:
9
10
  """Determine if the file can be processed"""
@@ -14,9 +15,11 @@ class FileProcessor:
14
15
  """Extract file text content"""
15
16
  raise NotImplementedError
16
17
 
18
+
17
19
  class TextFileProcessor(FileProcessor):
18
20
  """Text file processor"""
19
- ENCODINGS = ['utf-8', 'gbk', 'gb2312', 'latin1']
21
+
22
+ ENCODINGS = ["utf-8", "gbk", "gb2312", "latin1"]
20
23
  SAMPLE_SIZE = 8192 # Read the first 8KB to detect encoding
21
24
 
22
25
  @staticmethod
@@ -24,16 +27,20 @@ class TextFileProcessor(FileProcessor):
24
27
  """Determine if the file is a text file by trying to decode it"""
25
28
  try:
26
29
  # Read the first part of the file to detect encoding
27
- with open(file_path, 'rb') as f:
30
+ with open(file_path, "rb") as f:
28
31
  sample = f.read(TextFileProcessor.SAMPLE_SIZE)
29
32
 
30
33
  # Check if it contains null bytes (usually represents a binary file)
31
- if b'\x00' in sample:
34
+ if b"\x00" in sample:
32
35
  return False
33
36
 
34
37
  # Check if it contains too many non-printable characters (usually represents a binary file)
35
- non_printable = sum(1 for byte in sample if byte < 32 and byte not in (9, 10, 13)) # tab, newline, carriage return
36
- if non_printable / len(sample) > 0.3: # If non-printable characters exceed 30%, it is considered a binary file
38
+ non_printable = sum(
39
+ 1 for byte in sample if byte < 32 and byte not in (9, 10, 13)
40
+ ) # tab, newline, carriage return
41
+ if (
42
+ non_printable / len(sample) > 0.3
43
+ ): # If non-printable characters exceed 30%, it is considered a binary file
37
44
  return False
38
45
 
39
46
  # Try to decode with different encodings
@@ -55,7 +62,7 @@ class TextFileProcessor(FileProcessor):
55
62
  detected_encoding = None
56
63
  try:
57
64
  # First try to detect encoding
58
- with open(file_path, 'rb') as f:
65
+ with open(file_path, "rb") as f:
59
66
  raw_data = f.read()
60
67
 
61
68
  # Try different encodings
@@ -68,14 +75,14 @@ class TextFileProcessor(FileProcessor):
68
75
  continue
69
76
 
70
77
  if not detected_encoding:
71
- raise UnicodeDecodeError(f"Failed to decode file with supported encodings: {file_path}") # type: ignore
78
+ raise UnicodeDecodeError(f"Failed to decode file with supported encodings: {file_path}") # type: ignore
72
79
 
73
80
  # Use the detected encoding to read the file
74
- with open(file_path, 'r', encoding=detected_encoding, errors='ignore') as f:
81
+ with open(file_path, "r", encoding=detected_encoding, errors="ignore") as f:
75
82
  content = f.read()
76
83
 
77
84
  # Normalize Unicode characters
78
- content = unicodedata.normalize('NFKC', content)
85
+ content = unicodedata.normalize("NFKC", content)
79
86
 
80
87
  return content
81
88