jarvis-ai-assistant 0.1.208__py3-none-any.whl → 0.1.210__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jarvis/__init__.py +1 -1
- jarvis/jarvis_agent/__init__.py +9 -59
- jarvis/jarvis_agent/edit_file_handler.py +1 -1
- jarvis/jarvis_code_agent/code_agent.py +55 -8
- jarvis/jarvis_code_agent/lint.py +1 -1
- jarvis/jarvis_data/config_schema.json +0 -25
- jarvis/jarvis_git_utils/git_commiter.py +2 -2
- jarvis/jarvis_platform/kimi.py +20 -11
- jarvis/jarvis_platform/tongyi.py +84 -74
- jarvis/jarvis_platform/yuanbao.py +60 -54
- jarvis/jarvis_tools/ask_user.py +0 -1
- jarvis/jarvis_tools/file_analyzer.py +0 -3
- jarvis/jarvis_utils/config.py +4 -49
- jarvis/jarvis_utils/embedding.py +6 -51
- jarvis/jarvis_utils/git_utils.py +74 -11
- jarvis/jarvis_utils/http.py +169 -0
- jarvis/jarvis_utils/utils.py +186 -63
- {jarvis_ai_assistant-0.1.208.dist-info → jarvis_ai_assistant-0.1.210.dist-info}/METADATA +5 -10
- {jarvis_ai_assistant-0.1.208.dist-info → jarvis_ai_assistant-0.1.210.dist-info}/RECORD +23 -24
- {jarvis_ai_assistant-0.1.208.dist-info → jarvis_ai_assistant-0.1.210.dist-info}/entry_points.txt +1 -0
- jarvis/jarvis_data/huggingface.tar.gz +0 -0
- jarvis/jarvis_utils/jarvis_history.py +0 -98
- {jarvis_ai_assistant-0.1.208.dist-info → jarvis_ai_assistant-0.1.210.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.208.dist-info → jarvis_ai_assistant-0.1.210.dist-info}/licenses/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.208.dist-info → jarvis_ai_assistant-0.1.210.dist-info}/top_level.txt +0 -0
jarvis/jarvis_platform/tongyi.py
CHANGED
@@ -5,9 +5,8 @@ import time
 import uuid
 from typing import Any, Dict, Generator, List, Tuple
 
-import requests  # type: ignore
-
 from jarvis.jarvis_platform.base import BasePlatform
+from jarvis.jarvis_utils import http
 from jarvis.jarvis_utils.output import OutputType, PrettyOutput
 from jarvis.jarvis_utils.utils import while_success
 
@@ -160,82 +159,93 @@ class TongyiPlatform(BasePlatform):
         }
 
         try:
-            … (2 removed lines not shown in the source diff view)
+            # 使用新的stream_post接口发送消息请求,获取流式响应
+            response_stream = while_success(
+                lambda: http.stream_post(url, headers=headers, json=payload),
                 sleep_time=5,
             )
-            … (1 removed line not shown in the source diff view)
-                raise Exception(f"HTTP {response.status_code}: {response.text}")
+
             msg_id = ""
             session_id = ""
             thinking_content = ""
             text_content = ""
             in_thinking = False
-            … (5 removed lines not shown in the source diff view)
-                        continue
+            response_data = b""
+
+            # 处理流式响应
+            for chunk in response_stream:
+                response_data += chunk
 
+                # 尝试解析SSE格式的数据
                 try:
-                    … (57 removed lines not shown in the source diff view)
+                    # 查找完整的数据行
+                    lines = response_data.decode("utf-8").split("\n")
+                    response_data = b""  # 重置缓冲区
+
+                    for line in lines:
+                        if not line.strip():
+                            continue
+
+                        # SSE格式的行通常以"data: "开头
+                        if line.startswith("data: "):
+                            try:
+                                data = json.loads(line[6:])
+                                # 记录消息ID和会话ID
+                                if "msgId" in data:
+                                    msg_id = data["msgId"]
+                                if "sessionId" in data:
+                                    session_id = data["sessionId"]
+
+                                if "contents" in data and len(data["contents"]) > 0:
+                                    for content in data["contents"]:
+                                        if content.get("contentType") == "think":
+                                            if not in_thinking:
+                                                yield "<think>\n\n"
+                                                in_thinking = True
+                                            if content.get("incremental"):
+                                                tmp_content = json.loads(
+                                                    content.get("content")
+                                                )["content"]
+                                                thinking_content += tmp_content
+                                                yield tmp_content
+                                            else:
+                                                tmp_content = json.loads(
+                                                    content.get("content")
+                                                )["content"]
+                                                if len(thinking_content) < len(
+                                                    tmp_content
+                                                ):
+                                                    yield tmp_content[
+                                                        len(thinking_content) :
+                                                    ]
+                                                    thinking_content = tmp_content
+                                                else:
+                                                    yield "\r\n</think>\n"[
+                                                        len(thinking_content)
+                                                        - len(tmp_content) :
+                                                    ]
+                                                    thinking_content = tmp_content
+                                                    in_thinking = False
+                                        elif content.get("contentType") == "text":
+                                            if in_thinking:
+                                                continue
+                                            if content.get("incremental"):
+                                                tmp_content = content.get("content")
+                                                text_content += tmp_content
+                                                yield tmp_content
+                                            else:
+                                                tmp_content = content.get("content")
+                                                if len(text_content) < len(tmp_content):
+                                                    yield tmp_content[
+                                                        len(text_content) :
+                                                    ]
+                                                    text_content = tmp_content
+
+                            except json.JSONDecodeError:
+                                continue
+
+                except UnicodeDecodeError:
+                    # 如果解码失败,继续累积数据
                     continue
 
             self.msg_id = msg_id
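The call sites in this hunk (and the ones below) replace direct requests usage with the new jarvis/jarvis_utils/http.py helper, which this release adds (+169 lines in the file list) but which is not expanded in this diff. The sketch below is only an inference from those call sites: post/put return a requests-style response with status_code/text/json(), and stream_post yields raw byte chunks that the platform code buffers and parses as SSE. The function names come from the diff; the bodies, the timeout default, and the assumption that requests remains the underlying transport are guesses, not the shipped module.

# Hypothetical sketch of the jarvis/jarvis_utils/http.py surface implied by the
# call sites in this diff; details are assumptions, not the actual module.
from typing import Any, Generator

import requests  # assumed underlying transport


def post(url: str, **kwargs: Any) -> requests.Response:
    # Plain POST; callers still inspect response.status_code / .text / .json().
    return requests.post(url, timeout=kwargs.pop("timeout", 600), **kwargs)


def put(url: str, **kwargs: Any) -> requests.Response:
    return requests.put(url, timeout=kwargs.pop("timeout", 600), **kwargs)


def stream_post(url: str, **kwargs: Any) -> Generator[bytes, None, None]:
    # Streaming POST: yield raw byte chunks so callers can buffer and parse SSE.
    with requests.post(
        url, stream=True, timeout=kwargs.pop("timeout", 600), **kwargs
    ) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                yield chunk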
@@ -258,7 +268,7 @@ class TongyiPlatform(BasePlatform):
 
         try:
             response = while_success(
-                lambda: …
+                lambda: http.post(url, headers=headers, json=payload), sleep_time=5
             )
             if response.status_code != 200:
                 raise Exception(f"HTTP {response.status_code}: {response.text}")
@@ -314,7 +324,7 @@ class TongyiPlatform(BasePlatform):
         print(f"📤 正在上传文件: {file_name}")
 
         # Upload file
-        response = …
+        response = http.post(
             upload_token["host"], data=form_data, files=files
         )
 
@@ -349,7 +359,7 @@ class TongyiPlatform(BasePlatform):
             "dir": upload_token["dir"],
         }
 
-        response = …
+        response = http.post(url, headers=headers, json=payload)
         if response.status_code != 200:
             print(f"❌ 获取下载链接失败: HTTP {response.status_code}")
             return False
@@ -381,7 +391,7 @@ class TongyiPlatform(BasePlatform):
             "fileSize": os.path.getsize(file_path),
         }
 
-        add_response = …
+        add_response = http.post(
             add_url, headers=headers, json=add_payload
         )
         if add_response.status_code != 200:
@@ -464,7 +474,7 @@ class TongyiPlatform(BasePlatform):
 
         try:
             response = while_success(
-                lambda: …
+                lambda: http.post(url, headers=headers, json=payload), sleep_time=5
             )
             if response.status_code != 200:
                 PrettyOutput.print(
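Every HTTP call above is wrapped in while_success(lambda: ..., sleep_time=5) from jarvis.jarvis_utils.utils. That helper is not expanded in this diff, so the snippet below is only a hypothetical stand-in for the retry contract the call sites appear to rely on: keep re-invoking the callable, sleeping between attempts, until it returns without raising.

# Hypothetical retry helper matching the call pattern above; the real
# while_success in jarvis.jarvis_utils.utils may log, back off, or cap retries.
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def while_success(func: Callable[[], T], sleep_time: float = 5) -> T:
    while True:
        try:
            return func()
        except Exception:
            time.sleep(sleep_time)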
|
@@ -7,10 +7,10 @@ import time
|
|
7
7
|
import urllib.parse
|
8
8
|
from typing import Dict, Generator, List, Tuple
|
9
9
|
|
10
|
-
import requests # type: ignore
|
11
10
|
from PIL import Image # type: ignore
|
12
11
|
|
13
12
|
from jarvis.jarvis_platform.base import BasePlatform
|
13
|
+
from jarvis.jarvis_utils import http
|
14
14
|
from jarvis.jarvis_utils.output import OutputType, PrettyOutput
|
15
15
|
from jarvis.jarvis_utils.utils import while_success
|
16
16
|
|
@@ -37,7 +37,7 @@ class YuanbaoPlatform(BasePlatform):
         self.conversation_id = ""  # 会话ID,用于标识当前对话
         # 从环境变量中获取必要参数
         self.cookies = os.getenv("YUANBAO_COOKIES")  # 认证cookies
-        self.agent_id = …
+        self.agent_id = "naQivTmsDa"
 
         if not self.cookies:
             PrettyOutput.print("YUANBAO_COOKIES 未设置", OutputType.WARNING)
@@ -95,7 +95,7 @@ class YuanbaoPlatform(BasePlatform):
 
         try:
             response = while_success(
-                lambda: …
+                lambda: http.post(url, headers=headers, data=payload),
                 sleep_time=5,
             )
             response_json = response.json()
@@ -254,7 +254,7 @@ class YuanbaoPlatform(BasePlatform):
 
         try:
             response = while_success(
-                lambda: …
+                lambda: http.post(url, headers=headers, json=payload),
                 sleep_time=5,
             )
 
@@ -331,7 +331,7 @@ class YuanbaoPlatform(BasePlatform):
         )
 
         # Upload the file
-        response = …
+        response = http.put(url, headers=headers, data=file_content)
 
         if response.status_code not in [200, 204]:
             PrettyOutput.print(
|
@@ -468,60 +468,66 @@ class YuanbaoPlatform(BasePlatform):
|
|
468
468
|
payload["displayPrompt"] = payload["prompt"]
|
469
469
|
|
470
470
|
try:
|
471
|
-
#
|
472
|
-
|
473
|
-
lambda:
|
474
|
-
url, headers=headers, json=payload, stream=True, timeout=600
|
475
|
-
),
|
471
|
+
# 使用新的stream_post接口发送消息请求,获取流式响应
|
472
|
+
response_stream = while_success(
|
473
|
+
lambda: http.stream_post(url, headers=headers, json=payload),
|
476
474
|
sleep_time=5,
|
477
475
|
)
|
478
476
|
|
479
|
-
# 检查响应状态
|
480
|
-
if response.status_code != 200:
|
481
|
-
error_msg = f"发送消息失败,状态码: {response.status_code}"
|
482
|
-
if hasattr(response, "text"):
|
483
|
-
error_msg += f", 响应: {response.text}"
|
484
|
-
raise Exception(error_msg)
|
485
|
-
|
486
477
|
in_thinking = False
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
478
|
+
response_data = b""
|
479
|
+
|
480
|
+
# 处理流式响应
|
481
|
+
for chunk in response_stream:
|
482
|
+
response_data += chunk
|
483
|
+
|
484
|
+
# 尝试解析SSE格式的数据
|
485
|
+
try:
|
486
|
+
# 查找完整的数据行
|
487
|
+
lines = response_data.decode("utf-8").split("\n")
|
488
|
+
response_data = b"" # 重置缓冲区
|
489
|
+
|
490
|
+
for line in lines:
|
491
|
+
if not line.strip():
|
492
|
+
continue
|
493
|
+
|
494
|
+
# SSE格式的行通常以"data: "开头
|
495
|
+
if line.startswith("data: "):
|
496
|
+
try:
|
497
|
+
data_str = line[6:] # 移除"data: "前缀
|
498
|
+
|
499
|
+
# 检查结束标志
|
500
|
+
if data_str == "[DONE]":
|
501
|
+
self.first_chat = False
|
502
|
+
return None
|
503
|
+
|
504
|
+
data = json.loads(data_str)
|
505
|
+
|
506
|
+
# 处理文本类型的消息
|
507
|
+
if data.get("type") == "text":
|
508
|
+
if in_thinking:
|
509
|
+
yield "</think>\n"
|
510
|
+
in_thinking = False
|
511
|
+
msg = data.get("msg", "")
|
512
|
+
if msg:
|
513
|
+
yield msg
|
514
|
+
|
515
|
+
# 处理思考中的消息
|
516
|
+
elif data.get("type") == "think":
|
517
|
+
if not in_thinking:
|
518
|
+
yield "<think>\n"
|
519
|
+
in_thinking = True
|
520
|
+
think_content = data.get("content", "")
|
521
|
+
if think_content:
|
522
|
+
yield think_content
|
523
|
+
|
524
|
+
except json.JSONDecodeError:
|
525
|
+
pass
|
526
|
+
|
527
|
+
except UnicodeDecodeError:
|
528
|
+
# 如果解码失败,继续累积数据
|
491
529
|
continue
|
492
530
|
|
493
|
-
line_str = line.decode("utf-8")
|
494
|
-
|
495
|
-
# SSE格式的行通常以"data: "开头
|
496
|
-
if line_str.startswith("data: "):
|
497
|
-
try:
|
498
|
-
data_str = line_str[6:] # 移除"data: "前缀
|
499
|
-
data = json.loads(data_str)
|
500
|
-
|
501
|
-
# 处理文本类型的消息
|
502
|
-
if data.get("type") == "text":
|
503
|
-
if in_thinking:
|
504
|
-
yield "</think>\n"
|
505
|
-
in_thinking = False
|
506
|
-
msg = data.get("msg", "")
|
507
|
-
if msg:
|
508
|
-
yield msg
|
509
|
-
|
510
|
-
# 处理思考中的消息
|
511
|
-
elif data.get("type") == "think":
|
512
|
-
if not in_thinking:
|
513
|
-
yield "<think>\n"
|
514
|
-
in_thinking = True
|
515
|
-
think_content = data.get("content", "")
|
516
|
-
if think_content:
|
517
|
-
yield think_content
|
518
|
-
|
519
|
-
except json.JSONDecodeError:
|
520
|
-
pass
|
521
|
-
|
522
|
-
# 检测结束标志
|
523
|
-
elif line_str == "data: [DONE]":
|
524
|
-
return None
|
525
531
|
self.first_chat = False
|
526
532
|
return None
|
527
533
|
|
@@ -547,7 +553,7 @@ class YuanbaoPlatform(BasePlatform):
 
         try:
             response = while_success(
-                lambda: …
+                lambda: http.post(url, headers=headers, json=payload),
                 sleep_time=5,
             )
 
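Both streaming rewrites follow the same pattern: accumulate raw bytes, decode and split into lines, pick out "data: " payloads, wrap "think" events in <think>…</think> markers, and stop on the [DONE] sentinel. The snippet below condenses the Yuanbao version into a self-contained generator runnable against canned chunks; the sample payloads are invented for illustration.

# Condensed restatement of the SSE handling added above, runnable offline.
import json
from typing import Iterable, Iterator


def render_sse(chunks: Iterable[bytes]) -> Iterator[str]:
    in_thinking = False
    buffer = b""
    for chunk in chunks:
        buffer += chunk
        try:
            lines = buffer.decode("utf-8").split("\n")
            buffer = b""  # reset once the bytes decode cleanly
        except UnicodeDecodeError:
            continue  # keep accumulating partial multi-byte sequences
        for line in lines:
            if not line.startswith("data: "):
                continue
            data_str = line[6:]
            if data_str == "[DONE]":
                return
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                continue
            if data.get("type") == "think":
                if not in_thinking:
                    yield "<think>\n"
                    in_thinking = True
                yield data.get("content", "")
            elif data.get("type") == "text":
                if in_thinking:
                    yield "</think>\n"
                    in_thinking = False
                yield data.get("msg", "")


# One think event followed by a text event and the end marker.
sample = [
    b'data: {"type": "think", "content": "planning..."}\n',
    b'data: {"type": "text", "msg": "Hello"}\ndata: [DONE]\n',
]
print("".join(render_sse(sample)))  # <think>\nplanning...</think>\nHello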
jarvis/jarvis_tools/ask_user.py
CHANGED
jarvis/jarvis_utils/config.py
CHANGED
@@ -3,7 +3,7 @@ import os
 from functools import lru_cache
 from typing import Any, Dict, List
 
-import yaml
+import yaml  # type: ignore
 
 from jarvis.jarvis_utils.builtin_replace_map import BUILTIN_REPLACE_MAP
 
@@ -96,16 +96,6 @@ def get_max_input_token_count() -> int:
     return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_INPUT_TOKEN_COUNT", "32000"))
 
 
-def is_auto_complete() -> bool:
-    """
-    检查是否启用了自动补全功能。
-
-    返回:
-        bool: 如果启用了自动补全则返回True,默认为False
-    """
-    return GLOBAL_CONFIG_DATA.get("JARVIS_AUTO_COMPLETE", False) == True
-
-
 def get_shell_name() -> str:
     """
     获取系统shell名称。
@@ -119,10 +109,6 @@ def get_shell_name() -> str:
     3. 最后从环境变量SHELL获取
     4. 如果都未配置,则默认返回bash
     """
-    shell_name = GLOBAL_CONFIG_DATA.get("JARVIS_SHELL")
-    if shell_name:
-        return shell_name.lower()
-
     shell_path = GLOBAL_CONFIG_DATA.get("SHELL", os.getenv("SHELL", "/bin/bash"))
     return os.path.basename(shell_path).lower()
 
@@ -191,16 +177,6 @@ def is_confirm_before_apply_patch() -> bool:
     return GLOBAL_CONFIG_DATA.get("JARVIS_CONFIRM_BEFORE_APPLY_PATCH", False) == True
 
 
-def get_max_tool_call_count() -> int:
-    """
-    获取最大工具调用次数。
-
-    返回:
-        int: 最大连续工具调用次数,默认为20
-    """
-    return int(GLOBAL_CONFIG_DATA.get("JARVIS_MAX_TOOL_CALL_COUNT", "20"))
-
-
 def get_data_dir() -> str:
     """
     获取Jarvis数据存储目录路径。
@@ -209,20 +185,9 @@ def get_data_dir() -> str:
         str: 数据目录路径,优先从JARVIS_DATA_PATH环境变量获取,
              如果未设置或为空,则使用~/.jarvis作为默认值
     """
-    … (3 removed lines not shown in the source diff view)
-    return data_path
-
-
-def get_auto_update() -> bool:
-    """
-    获取是否自动更新git仓库。
-
-    返回:
-        bool: 如果需要自动更新则返回True,默认为True
-    """
-    return GLOBAL_CONFIG_DATA.get("JARVIS_AUTO_UPDATE", True) == True
+    return os.path.expanduser(
+        GLOBAL_CONFIG_DATA.get("JARVIS_DATA_PATH", "~/.jarvis").strip()
+    )
 
 
 def get_max_big_content_size() -> int:
@@ -275,16 +240,6 @@ def is_print_prompt() -> bool:
     return GLOBAL_CONFIG_DATA.get("JARVIS_PRINT_PROMPT", False) == True
 
 
-def get_history_count() -> int:
-    """
-    获取是否启用历史记录功能。
-
-    返回:
-        bool: 如果启用历史记录则返回True,默认为False
-    """
-    return GLOBAL_CONFIG_DATA.get("JARVIS_USE_HISTORY_COUNT", 0)
-
-
 def get_mcp_config() -> List[Dict[str, Any]]:
     """
     获取MCP配置列表。
jarvis/jarvis_utils/embedding.py
CHANGED
@@ -1,17 +1,11 @@
 # -*- coding: utf-8 -*-
-import …
-import os
-from typing import Any, List
+from typing import List
 
-from jarvis.jarvis_utils.config import get_data_dir
 from jarvis.jarvis_utils.output import OutputType, PrettyOutput
 
-# 全局缓存,避免重复加载模型
-_global_tokenizers = {}
-
 
 def get_context_token_count(text: str) -> int:
-    """
+    """使用tiktoken获取文本的token数量。
 
     参数:
         text: 要计算token的输入文本
@@ -20,16 +14,10 @@ def get_context_token_count(text: str) -> int:
         int: 文本中的token数量
     """
     try:
-        … (4 removed lines not shown in the source diff view)
-        chunk_size = 100  # 每次处理100个字符,避免超过模型最大长度(考虑到中文字符可能被编码成多个token)
-        for i in range(0, len(text), chunk_size):
-            chunk = text[i : i + chunk_size]
-            tokens = tokenizer.encode(chunk)  # type: ignore
-            total_tokens += len(tokens)
-        return total_tokens
+        import tiktoken
+
+        encoding = tiktoken.get_encoding("cl100k_base")
+        return len(encoding.encode(text))
     except Exception as e:
         PrettyOutput.print(f"计算token失败: {str(e)}", OutputType.WARNING)
         return len(text) // 4  # 每个token大约4个字符的粗略估计
@@ -84,36 +72,3 @@ def split_text_into_chunks(
         PrettyOutput.print(f"文本分割失败: {str(e)}", OutputType.WARNING)
         # 发生错误时回退到简单的字符分割
         return [text[i : i + max_length] for i in range(0, len(text), max_length)]
-
-
-@functools.lru_cache(maxsize=1)
-def load_tokenizer() -> Any:
-    """
-    加载用于文本处理的分词器,使用缓存避免重复加载。
-
-    返回:
-        AutoTokenizer: 加载的分词器
-    """
-
-    from transformers import AutoTokenizer  # type: ignore
-
-    model_name = "gpt2"
-    cache_dir = os.path.join(get_data_dir(), "huggingface", "hub")
-
-    # 检查全局缓存
-    if model_name in _global_tokenizers:
-        return _global_tokenizers[model_name]
-
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_name, cache_dir=cache_dir, local_files_only=True
-        )
-    except Exception:
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_name, cache_dir=cache_dir, local_files_only=False
-        )
-
-    # 保存到全局缓存
-    _global_tokenizers[model_name] = tokenizer
-
-    return tokenizer  # type: ignore